In [1]:
import torch

from melanoma_classification.model import get_dermmel_classifier_v1
from melanoma_classification.utils import (
    get_device,
    production_transform,
    train_transform,
)
from pathlib import Path
from training.trainer import train
from utils.dermmel import DermMel

In [None]:
# Seed torch's global RNG before anything stochastic runs so that a fresh
# "Restart Kernel -> Run All" is repeatable for torch-driven randomness
# (e.g. DataLoader shuffling when no explicit generator is passed).
# NOTE(review): if the transforms or the trainer draw from `random`/`numpy`,
# those generators need seeding as well — confirm in the project code.
SEED = 42
torch.manual_seed(SEED)

# Device handle that is later passed to train(); how it is chosen is decided
# inside melanoma_classification.utils.get_device.
device = get_device()

## Create Vision Transformer (ViT)


In [None]:
# Build the classifier through the project's v1 factory function.
vit = get_dermmel_classifier_v1()
# Initialise the model from the weights identified by "deit_base_patch16_224".
# NOTE(review): presumably a DeiT-base (patch 16, 224 px) checkpoint name that
# load_pretrained_weights resolves itself — confirm there, including whether it
# needs network access on first use.
vit.load_pretrained_weights("deit_base_patch16_224")

## Prepare data


In [None]:
# Mini-batch size shared by the training and validation loaders.
BATCH_SIZE = 8

# Training split ("train_sep") with the training transform; the loader
# reshuffles the samples every epoch.
train_dataset = DermMel("../data", split="train_sep", transform=train_transform())
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0
)

# Validation split with the production transform. Shuffling is disabled here:
# evaluation gains nothing from a random order, and a fixed order keeps
# per-batch validation output comparable between epochs and runs.
valid_dataset = DermMel("../data", split="valid", transform=production_transform())
valid_dataloader = torch.utils.data.DataLoader(
    valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0
)

# Quick visual sanity check on index -1 of the training set
# (presumably the last sample — confirm in DermMel.visualize_image).
train_dataset.visualize_image(-1)

## Training


In [6]:
# Loss used for training.
criterion = torch.nn.CrossEntropyLoss()

# Per-component learning rates as (parameters, lr) pairs: the class token and
# positional embedding get the smallest rate, the classifier the largest.
lr_by_params = [
    (vit.cls_token, 1e-7),
    (vit.pos_embed, 1e-7),
    (vit.patch_embedding.parameters(), 1e-6),
    (vit.transformer_layers.parameters(), 1e-5),
    (vit.norm.parameters(), 1e-6),
    (vit.classifier.parameters(), 1e-4),
]

# One AdamW parameter group per entry above, in the same order.
optimizer = torch.optim.AdamW(
    [{"params": group_params, "lr": group_lr} for group_params, group_lr in lr_by_params]
)

# Multiply every group's learning rate by 0.1 once the monitored quantity has
# not decreased for more than 3 consecutive scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, factor=0.1, patience=3, mode="min"
)

In [None]:
# Launch the training run. Keyword arguments are grouped by concern; their
# exact semantics are defined by training.trainer.train.
train(
    # Model and target device.
    model=vit,
    device=device,
    # Data.
    train_loader=train_dataloader,
    val_loader=valid_dataloader,
    # Optimisation.
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    # Schedule. NOTE(review): freezed_epochs presumably keeps part of the
    # model frozen for the first 5 of the 20 epochs — confirm in the trainer.
    num_epochs=20,
    freezed_epochs=5,
    # Checkpointing: written under checkpoints/dermmel_tmp, every epoch.
    checkpoint_path=Path("checkpoints/dermmel_tmp"),
    save_every_n_epochs=1,
    # To resume from an earlier checkpoint, uncomment:
    # checkpoint_model_file="checkpoint_epoch_6.pth",
)