In [1]:
import os
import torch
from torchinfo import summary
from pathlib import Path

## Create Vision Transformer (ViT)

In [2]:
from melanoma_classification import VisionTransformer

In [None]:
# Seed PyTorch's global RNG before anything random happens so that a
# Restart & Run All is reproducible: this covers the weights of the new
# classifier head (presumably randomly initialised by `set_classifier`),
# torch-based augmentations, and the DataLoader shuffle order (which draws
# from the global generator when no explicit `generator` is passed).
SEED = 42
torch.manual_seed(SEED)

# Build the ViT, load the 'deit_base_patch16_224' pretrained weights, then
# replace the classification head with a 2-class output
# (presumably melanoma vs. non-melanoma — confirm label order against DermMel).
vit = VisionTransformer()
vit.load_pretrained_weights('deit_base_patch16_224')
vit.set_classifier(num_classes=2)

## Prepare data

In [None]:
from utils.dermmel import DermMel
from melanoma_classification.utils.transformations import (
    train_transform,
    production_transform
)

# Data-loading configuration shared by the train and validation loaders.
DATA_DIR = '../data'
BATCH_SIZE = 8
NUM_WORKERS = 0


def make_dermmel_loader(split, transform, shuffle):
    """Build a DermMel dataset for one split and wrap it in a DataLoader.

    Args:
        split: DermMel split name, e.g. 'train_sep' or 'valid'.
        transform: transform passed through to DermMel for every sample.
        shuffle: whether the DataLoader reshuffles samples each epoch.

    Returns:
        A ``(dataset, loader)`` tuple.
    """
    dataset = DermMel(DATA_DIR, split=split, transform=transform)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        num_workers=NUM_WORKERS,
    )
    return dataset, loader


# Training data: augmented and reshuffled every epoch.
train_dataset, train_dataloader = make_dermmel_loader(
    'train_sep', train_transform(), shuffle=True
)

# Validation data: fixed order. Shuffling brings no benefit here, and a fixed
# order keeps batch composition (including the last partial batch) identical
# between runs, so per-batch-averaged validation metrics are reproducible.
valid_dataset, valid_dataloader = make_dermmel_loader(
    'valid', production_transform(), shuffle=False
)

train_dataset.visualize_image(-1)

## Training

In [5]:
# Loss over the classifier logits (integer class targets).
criterion = torch.nn.CrossEntropyLoss()

# AdamW over every parameter of the model.
# NOTE(review): 1e-3 is high for fine-tuning pretrained ViT weights; values of
# 1e-4 or lower are more common — confirm this is intended.
LEARNING_RATE = 1e-3
optimizer = torch.optim.AdamW(params=vit.parameters(), lr=LEARNING_RATE)

# Multiply the learning rate by 0.1 every 10 scheduler.step() calls.
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=10,
    gamma=0.1,
)

In [6]:
# Init device: prefer CUDA, then Apple-silicon MPS, and fall back to CPU.
# `torch.backends.mps.is_available()` is the documented MPS check; the
# `torch.mps.is_available()` spelling is missing in many torch releases and
# would raise AttributeError on machines without CUDA.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

In [None]:
from training.trainer import train

# Run configuration.
NUM_EPOCHS = 20
# Presumably the number of initial epochs trained with part of the model
# frozen — confirm against training.trainer.train.
FROZEN_EPOCHS = 10
CHECKPOINT_DIR = Path("checkpoints/dermmel")
# NOTE(review): presumably resumes from a checkpoint written by an earlier run;
# on a fresh checkout this file will not exist — confirm how `train` handles it.
RESUME_CHECKPOINT = "checkpoint_epoch_6.pth"
SAVE_EVERY_N_EPOCHS = 1

train(
    model=vit,
    train_loader=train_dataloader,
    val_loader=valid_dataloader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=NUM_EPOCHS,
    freezed_epochs=FROZEN_EPOCHS,
    device=device,
    checkpoint_model_file=RESUME_CHECKPOINT,
    checkpoint_path=CHECKPOINT_DIR,
    save_every_n_epochs=SAVE_EVERY_N_EPOCHS,
)