In [1]:
import os
import warnings

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader

In [2]:
# Locations of each split: an image folder plus a CSV that has an `image` column.
train_image_dir = "train/images"
train_csv = "train.csv"

val_image_dir = "val/images"
val_csv = "val.csv"


def _has_image_file(image_dir, image_id):
    """Return True when `<image_id>.jpg` or `<image_id>.png` exists in `image_dir`."""
    return any(
        os.path.exists(os.path.join(image_dir, f"{image_id}{ext}"))
        for ext in (".jpg", ".png")  # .jpg is checked first, then .png
    )


def _load_rows_with_images(csv_path, image_dir):
    """Read `csv_path` and keep only the rows whose image file is present on disk.

    The returned frame has a fresh 0..n-1 index so positional and label-based
    indexing agree after rows have been dropped.
    """
    rows = pd.read_csv(csv_path)
    on_disk = rows["image"].apply(
        lambda image_id: _has_image_file(image_dir, image_id)
    )
    return rows[on_disk].reset_index(drop=True)


df_train = _load_rows_with_images(train_csv, train_image_dir)
df_val = _load_rows_with_images(val_csv, val_image_dir)


In [3]:
from sklearn.preprocessing import LabelEncoder
# Fit the label -> integer mapping on the training labels only, then apply the
# very same mapping to the validation labels so class ids agree across splits.
le = LabelEncoder()
train_label_codes = le.fit_transform(df_train["label"])
df_train["label_enc"] = train_label_codes
val_label_codes = le.transform(df_val["label"])
df_val["label_enc"] = val_label_codes



In [4]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

class DermaDataset(Dataset):
    """Image dataset backed by a DataFrame and a folder of image files.

    Each row of ``df`` is read for two columns: ``image`` (file name without
    extension) and ``label_enc`` (class id, converted to a ``torch.long`` tensor).

    Args:
        df: DataFrame with one row per sample; indexed positionally via ``iloc``.
        img_dir: Directory containing ``<image>.jpg`` or ``<image>.png`` files.
        transform: Optional callable applied to the RGB PIL image. When ``None``
            the PIL image is returned unchanged.
    """

    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the backing DataFrame."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(img, label)`` on success, or ``None`` when the image cannot be
            opened or decoded. Callers must drop ``None`` entries (for example
            with a filtering ``collate_fn``).
        """
        row = self.df.iloc[idx]
        img_name = row["image"]

        img_path_jpg = os.path.join(self.img_dir, f"{img_name}.jpg")
        img_path_png = os.path.join(self.img_dir, f"{img_name}.png")

        # Prefer the .jpg file; fall back to .png when no .jpg exists.
        path = img_path_jpg if os.path.exists(img_path_jpg) else img_path_png

        try:
            # The context manager closes the underlying file even when decoding
            # fails half-way; convert() returns an independent in-memory image.
            with Image.open(path) as raw:
                img = raw.convert("RGB")
        except Exception as exc:
            # Best-effort loading: keep going, but leave a trace of which file
            # was skipped instead of dropping it silently.
            warnings.warn(f"Skipping unreadable image {path!r}: {exc}")
            return None  # mark as invalid

        if self.transform is not None:
            img = self.transform(img)
        label = torch.tensor(row["label_enc"], dtype=torch.long)
        return img, label



In [5]:
from torchvision import transforms

bicubic = transforms.InterpolationMode.BICUBIC


def _tensor_and_normalize():
    """Return the two closing steps shared by both pipelines.

    The mean/std values look like the usual ImageNet channel statistics
    (NOTE(review): confirm they match what the pretrained backbone expects).
    """
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]


# Training: random geometric and photometric augmentation on a 518x518 crop.
train_augmentations = [
    transforms.RandomResizedCrop(518, scale=(0.75, 1.0), interpolation=bicubic),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.2),
    transforms.RandomAffine(degrees=15, translate=(0.1, 0.1)),
    transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
    transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
]
transform_train = transforms.Compose(train_augmentations + _tensor_and_normalize())

# Validation: deterministic resize followed by a 518x518 centre crop.
val_preprocessing = [
    transforms.Resize(518, interpolation=bicubic),
    transforms.CenterCrop(518),
]
transform_val = transforms.Compose(val_preprocessing + _tensor_and_normalize())


train_set = DermaDataset(df_train, train_image_dir, transform_train)
val_set = DermaDataset(df_val, val_image_dir, transform_val)



In [6]:
def safe_collate(batch):
    """Collate ``(image, label)`` samples, ignoring samples that are ``None``.

    Args:
        batch: Sequence of ``(image_tensor, label)`` pairs; entries may be
            ``None`` for samples that failed to load.

    Returns:
        ``(images, labels)`` with images stacked along a new leading dimension,
        or ``None`` when no usable sample remains.
    """
    kept = [sample for sample in batch if sample is not None]
    if not kept:
        return None
    image_tensors, label_values = zip(*kept)
    stacked_images = torch.stack(image_tensors, 0)
    stacked_labels = torch.tensor(label_values)
    return stacked_images, stacked_labels


# Settings common to both loaders; only shuffling and drop_last differ per split.
shared_loader_args = dict(
    batch_size=32,
    num_workers=4,
    pin_memory=True,
    collate_fn=safe_collate,  # filters out samples the dataset returned as None
)

# Training: reshuffled every epoch, incomplete final batch dropped.
train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **shared_loader_args)

# Validation: fixed order, every sample kept.
val_loader = DataLoader(val_set, shuffle=False, **shared_loader_args)




In [7]:
import timm
import torch.nn as nn

dino_model = timm.create_model('vit_base_patch14_dinov2', pretrained=True)
num_classes = len(le.classes_)

# Freeze the whole backbone first, then re-enable gradients for the last two
# transformer blocks only.
dino_model.requires_grad_(False)
dino_model.blocks[-2:].requires_grad_(True)

# Replacement classification head: LayerNorm followed by a 3-layer MLP.
# It is created after the freeze, so its parameters are trainable.
embed_dim = dino_model.embed_dim
head_layers = [
    nn.LayerNorm(embed_dim),
    nn.Linear(embed_dim, 1024),
    nn.GELU(),
    nn.Dropout(0.4),
    nn.Linear(1024, 512),
    nn.GELU(),
    nn.Dropout(0.3),
    nn.Linear(512, num_classes),
]
dino_model.head = nn.Sequential(*head_layers)

model = dino_model.cuda()

import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

# Cross-entropy with label smoothing.
criterion = nn.CrossEntropyLoss(label_smoothing=0.15)

# Two parameter groups: a small learning rate for the unfrozen backbone blocks
# and a larger one for the freshly initialised head.
backbone_group = dict(params=dino_model.blocks[-2:].parameters(), lr=5e-5)
head_group = dict(params=model.head.parameters(), lr=1e-3)
optimizer = optim.AdamW(
    [backbone_group, head_group],
    weight_decay=0.02,
    betas=(0.9, 0.999),
)

scheduler = CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,
    T_mult=1,
    eta_min=1e-6,
)



In [None]:
from tqdm import tqdm

def eval_model(loader):
    """Evaluate the module-level ``model`` with the module-level ``criterion``.

    Args:
        loader: Iterable yielding ``(imgs, labels)`` batches. A batch may be
            ``None`` (a filtering collate function can produce that when every
            sample of the batch failed to load); such batches are skipped.

    Returns:
        tuple: ``(accuracy, mean_loss)``. Both are averaged over the samples
        actually evaluated, so a short final batch or skipped batches do not
        bias the loss. This assumes ``criterion`` returns a per-batch mean.
        When no sample was evaluated the result is ``(0.0, nan)``.
    """
    model.eval()

    # Run on whatever device the model lives on instead of hard-coding CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0
    loss_sum = 0.0

    with torch.no_grad():
        for batch in loader:
            if batch is None:
                continue
            imgs, labels = batch
            imgs, labels = imgs.to(device), labels.to(device)
            preds = model(imgs)
            batch_size = labels.size(0)
            # Re-weight the batch mean by its size to get a true per-sample average.
            loss_sum += criterion(preds, labels).item() * batch_size
            correct += (preds.argmax(1) == labels).sum().item()
            total += batch_size

    if total == 0:
        # Nothing to evaluate: report zero accuracy and an undefined loss.
        return 0.0, float("nan")
    return correct / total, loss_sum / total

NUM_EPOCHS = 25
best_val_acc = 0
patience = 7           # epochs without validation-accuracy improvement before stopping
patience_counter = 0

for epoch in range(NUM_EPOCHS):
    model.train()
    train_loss = 0
    num_batches = 0    # batches actually trained on in this epoch

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}")
    for batch in pbar:
        # The collate function yields None when every image of a batch failed
        # to load; unpacking that would raise, so skip it.
        if batch is None:
            continue
        imgs, labels = batch
        imgs, labels = imgs.cuda(), labels.cuda()

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()

        # Clip the global gradient norm before the optimizer update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        train_loss += loss.item()
        num_batches += 1
        pbar.set_postfix({'loss': loss.item()})

    val_acc, val_loss = eval_model(val_loader)
    scheduler.step()  # the learning-rate schedule advances once per epoch

    # Average over the batches that were really processed (guarding against zero).
    print(f"Epoch {epoch+1} | Train Loss: {train_loss/max(num_batches, 1):.4f} | Val Acc: {val_acc:.4f} | Val Loss: {val_loss:.4f}")

    if val_acc > best_val_acc:
        # New best: reset patience and checkpoint the weights.
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), "dino_derma_classifier.pth")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

print(f"\nBest validation accuracy: {best_val_acc:.4f}")
# Restore the best checkpoint written during training.
model.load_state_dict(torch.load("dino_derma_classifier.pth"))

Epoch 1/25: 100%|██████████| 45/45 [02:58<00:00,  3.97s/it, loss=0.575]


Epoch 1 | Train Loss: 0.5534 | Val Acc: 0.8568 | Val Loss: 0.5039


Epoch 2/25: 100%|██████████| 45/45 [02:57<00:00,  3.95s/it, loss=0.508]


Epoch 2 | Train Loss: 0.4210 | Val Acc: 0.8761 | Val Loss: 0.4523


Epoch 3/25: 100%|██████████| 45/45 [02:59<00:00,  3.99s/it, loss=0.411]


Epoch 3 | Train Loss: 0.3641 | Val Acc: 0.8910 | Val Loss: 0.4198


Epoch 4/25:  67%|██████▋   | 30/45 [02:01<00:57,  3.83s/it, loss=0.356]

Epoch 4 | Train Loss: 0.3558 | Val Acc: 0.9444 | Val Loss: 0.3460


Epoch 5/25: 100%|██████████| 45/45 [02:57<00:00,  3.96s/it, loss=0.427]


Epoch 5 | Train Loss: 0.3353 | Val Acc: 0.9444 | Val Loss: 0.3508


Epoch 6/25: 100%|██████████| 45/45 [02:59<00:00,  3.98s/it, loss=0.273]


Epoch 6 | Train Loss: 0.3206 | Val Acc: 0.9551 | Val Loss: 0.3332


Epoch 7/25: 100%|██████████| 45/45 [03:00<00:00,  4.01s/it, loss=0.289]


Epoch 7 | Train Loss: 0.3068 | Val Acc: 0.9466 | Val Loss: 0.3564


Epoch 8/25: 100%|██████████| 45/45 [02:58<00:00,  3.97s/it, loss=0.286]


Epoch 8 | Train Loss: 0.2877 | Val Acc: 0.9594 | Val Loss: 0.3310


Epoch 9/25: 100%|██████████| 45/45 [02:59<00:00,  4.00s/it, loss=0.279]


Epoch 9 | Train Loss: 0.2833 | Val Acc: 0.9573 | Val Loss: 0.3305


Epoch 10/25: 100%|██████████| 45/45 [02:59<00:00,  3.98s/it, loss=0.286]


Epoch 10 | Train Loss: 0.2847 | Val Acc: 0.9594 | Val Loss: 0.3289


Epoch 11/25: 100%|██████████| 45/45 [03:01<00:00,  4.02s/it, loss=0.285]


Epoch 11 | Train Loss: 0.3094 | Val Acc: 0.9295 | Val Loss: 0.3928


Epoch 12/25: 100%|██████████| 45/45 [02:58<00:00,  3.98s/it, loss=0.293]


Epoch 12 | Train Loss: 0.3210 | Val Acc: 0.9466 | Val Loss: 0.3542


Epoch 13/25: 100%|██████████| 45/45 [02:58<00:00,  3.97s/it, loss=0.289]


Epoch 13 | Train Loss: 0.3258 | Val Acc: 0.9744 | Val Loss: 0.3248


Epoch 14/25: 100%|██████████| 45/45 [02:59<00:00,  3.99s/it, loss=0.3]  


Epoch 14 | Train Loss: 0.3250 | Val Acc: 0.9615 | Val Loss: 0.3324


Epoch 15/25: 100%|██████████| 45/45 [02:58<00:00,  3.97s/it, loss=0.271]


Epoch 15 | Train Loss: 0.2972 | Val Acc: 0.9594 | Val Loss: 0.3379


Epoch 16/25: 100%|██████████| 45/45 [02:58<00:00,  3.97s/it, loss=0.27] 


Epoch 16 | Train Loss: 0.2964 | Val Acc: 0.9530 | Val Loss: 0.3511


Epoch 17/25: 100%|██████████| 45/45 [02:59<00:00,  3.98s/it, loss=0.27] 


Epoch 17 | Train Loss: 0.2789 | Val Acc: 0.9573 | Val Loss: 0.3412


Epoch 18/25: 100%|██████████| 45/45 [02:58<00:00,  3.98s/it, loss=0.268]


Epoch 18 | Train Loss: 0.2788 | Val Acc: 0.9594 | Val Loss: 0.3349


Epoch 19/25: 100%|██████████| 45/45 [02:58<00:00,  3.96s/it, loss=0.268]


Epoch 19 | Train Loss: 0.2758 | Val Acc: 0.9615 | Val Loss: 0.3380


Epoch 20/25: 100%|██████████| 45/45 [02:58<00:00,  3.96s/it, loss=0.268]


Epoch 20 | Train Loss: 0.2751 | Val Acc: 0.9615 | Val Loss: 0.3400
Early stopping at epoch 20

Best validation accuracy: 0.9744


<All keys matched successfully>