In [1]:
# Cell 1: Imports và thiết lập chung
import os
import time
import math
import copy
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, Dataset, Subset, WeightedRandomSampler

from torchvision import transforms, models
from torchvision.models import resnet18
from sklearn.model_selection import KFold
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

from PIL import Image
from tqdm.notebook import tqdm

# Use the GPU when available; models and batches below are moved to `device`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device: ", device) if False else print("Device:", device)

# Seed every RNG the notebook relies on (Python `random`, NumPy, PyTorch) so that
# augmentations, DataLoader shuffling and weight initialisation are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNGs of all devices, including CUDA


Device: cuda


In [2]:
# Cell 2 (sửa lại): Dataset cho SimCLR Pretraining, xử lý trường hợp 'file' đã chứa đường dẫn đầy đủ

class SimCLRUnlabeledDataset(Dataset):
    """Unlabeled image dataset yielding two independently augmented views per image (SimCLR).

    Args:
        csv_file: path to a CSV with a ``file`` column. Each entry may be a bare image
            filename or a relative/absolute path to the image.
        image_dir: directory searched when an entry is not an existing path; only the
            entry's basename is joined onto it.
        transform: callable applied to the RGB PIL image. It is called twice per sample,
            so it must be random for the two views to differ.

    Raises:
        ValueError: if the CSV has no ``file`` column.
    """

    def __init__(self, csv_file, image_dir, transform):
        self.df = pd.read_csv(csv_file)
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if 'file' not in self.df.columns:
            raise ValueError(
                f"{csv_file} must contain a 'file' column; found {list(self.df.columns)}"
            )
        self.image_dir = image_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def _resolve_path(self, file_entry):
        """Map one CSV ``file`` entry to an image path.

        The entry is used as-is when it names an existing file; otherwise its basename
        is joined onto ``image_dir``. The entry is coerced to ``str`` first because
        pandas may parse purely numeric filenames (e.g. ``123``) as integers.
        """
        file_entry = str(file_entry)
        if os.path.isfile(file_entry):
            return file_entry
        return os.path.join(self.image_dir, os.path.basename(file_entry))

    def __getitem__(self, idx):
        """Return ``(xi, xj)``: two transformed views of the image at row ``idx``.

        Raises:
            FileNotFoundError: if the resolved image path does not exist.
        """
        img_path = self._resolve_path(self.df.iloc[idx]['file'])
        if not os.path.isfile(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")
        # Open inside a context manager so the file handle is closed promptly;
        # convert() returns a separate, fully loaded image.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        # Two independent draws of the random transform give the positive pair.
        xi = self.transform(img)
        xj = self.transform(img)
        return xi, xj


In [3]:
# Cell 3: Transforms cho SimCLR Pretraining

# Stochastic augmentation pipeline for SimCLR. Each call draws fresh random parameters,
# so applying it twice to the same image yields two different views.
# NOTE(review): the reference SimCLR recipe wraps ColorJitter in RandomApply(p=0.8) and
# GaussianBlur in RandomApply(p=0.5); here both are applied to every view — confirm intended.
simclr_transform = transforms.Compose(
    [
        # Geometry: random crop covering 20-100% of the image area, resized to 224x224,
        # followed by independent horizontal and vertical flips.
        transforms.RandomResizedCrop(size=224, scale=(0.2, 1.0)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        # Appearance: strong colour jitter, occasional grayscale, then Gaussian blur.
        transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2),
        transforms.RandomGrayscale(p=0.1),
        transforms.GaussianBlur(kernel_size=23, sigma=(0.1, 2.0)),
        # To tensor, then normalise with the standard ImageNet channel mean/std.
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)


In [4]:
# Cell 4: Đường dẫn và DataLoader cho Unlabeled 90%

# Unlabeled (90%) split: the CSV lists images in its 'file' column; bare filenames are
# resolved against TRAIN_IMG_DIR.
UNLABELED_CSV = '../data/isic2018/labels/train_unlabeled.csv'
TRAIN_IMG_DIR = '../data/isic2018/train'

unlabeled_ds = SimCLRUnlabeledDataset(
    UNLABELED_CSV,
    TRAIN_IMG_DIR,
    transform=simclr_transform,
)

# Reshuffle every epoch; drop_last keeps every batch at exactly 256 pairs, so the
# number of negatives per anchor in the contrastive loss stays constant.
unlabeled_loader = DataLoader(
    unlabeled_ds,
    batch_size=256,
    shuffle=True,
    num_workers=0,
    pin_memory=True,
    drop_last=True,
)

print("Unlabeled dataset size:", len(unlabeled_ds))
print("Batches per epoch:", len(unlabeled_loader))


Unlabeled dataset size: 9014
Batches per epoch: 35


In [5]:
# Cell 5: Định nghĩa mô hình SimCLR (Encoder + Projection Head)

class SimCLRModel(nn.Module):
    """SimCLR network: backbone encoder followed by a 2-layer MLP projection head.

    Args:
        base_encoder: factory called as ``base_encoder(pretrained=False)`` that returns a
            backbone whose final classifier is the attribute ``fc`` (as in torchvision
            ResNets). ``fc`` is replaced by ``nn.Identity`` so the encoder emits features.
        projection_dim: width of the projected embedding fed to the contrastive loss.
    """

    def __init__(self, base_encoder=resnet18, projection_dim=128):
        super(SimCLRModel, self).__init__()
        # 1) Backbone, randomly initialised (pretrained=False).
        self.encoder = base_encoder(pretrained=False)
        # Read the feature width from the classifier before removing it, so the head fits
        # any backbone (e.g. 2048 for ResNet50), not only ResNet18's 512. Fall back to 512
        # when the backbone's `fc` does not expose `in_features`.
        encoder_dim = getattr(getattr(self.encoder, 'fc', None), 'in_features', 512)
        self.encoder.fc = nn.Identity()
        # 2) Projection head: encoder_dim -> encoder_dim -> projection_dim.
        self.projector = nn.Sequential(
            nn.Linear(encoder_dim, encoder_dim),
            nn.ReLU(inplace=True),
            nn.Linear(encoder_dim, projection_dim),
        )

    def forward(self, x):
        """Return L2-normalised projections of shape (N, projection_dim)."""
        h = self.encoder(x)
        # Flatten defensively in case the backbone returns (N, C, 1, 1) pooled maps;
        # torch.flatten also handles non-contiguous tensors, unlike .view().
        if h.dim() > 2:
            h = torch.flatten(h, 1)
        z = self.projector(h)
        # Unit-norm embeddings, so dot products in the loss are cosine similarities.
        return F.normalize(z, dim=1)


In [6]:
# Cell 6: NT-Xent Loss (SimCLR Contrastive Loss)

class NTXentLoss(nn.Module):
    """NT-Xent (normalised temperature-scaled cross-entropy) loss used by SimCLR.

    The 2N embeddings of a batch of N pairs are compared by cosine similarity. The
    positive for anchor i is its other view (index i+N or i-N). The remaining 2N-2
    samples (all but itself and its positive) are negatives. The result is the mean
    cross-entropy over all 2N anchors.

    Args:
        batch_size: expected number of pairs per batch. The mask for this size is
            precomputed; batches of any other size get a mask built on the fly.
        temperature: divisor applied to the similarities before the softmax.
    """

    def __init__(self, batch_size, temperature=0.5):
        super(NTXentLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        # Kept for backward compatibility. forward() uses the equivalent
        # normalize + matmul, which avoids a (2N, 2N, dim) intermediate tensor.
        self.similarity_f = nn.CosineSimilarity(dim=2)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        # True everywhere except on self-pairs and positive pairs.
        self.register_buffer("mask", self._get_correlated_mask(batch_size))

    def _get_correlated_mask(self, N):
        """Return a (2N, 2N) bool mask that is False on the diagonal and at offsets ±N."""
        keep = ~torch.eye(2 * N, dtype=torch.bool)
        idx = torch.arange(N)
        keep[idx, idx + N] = False
        keep[idx + N, idx] = False
        return keep

    def forward(self, zi, zj):
        """Compute the loss for two batches of view embeddings.

        Args:
            zi, zj: embeddings of shape (N, dim); row k of zi and row k of zj come
                from the same image.

        Returns:
            Scalar tensor: mean NT-Xent loss over the 2N anchors.

        Raises:
            ValueError: if zi and zj have different shapes.
        """
        if zi.shape != zj.shape:
            raise ValueError(
                f"zi and zj must have the same shape, got {tuple(zi.shape)} and {tuple(zj.shape)}"
            )
        N = zi.size(0)
        z = F.normalize(torch.cat([zi, zj], dim=0), dim=1)   # (2N, dim)
        sim = (z @ z.T) / self.temperature                   # (2N, 2N) cosine / T

        # Positive of anchor i is i+N (first half) or i-N (second half).
        pos = torch.cat([torch.diag(sim, N), torch.diag(sim, -N)], dim=0).view(2 * N, 1)

        if N == self.batch_size:
            mask = self.mask
        else:
            # Slicing the precomputed mask is wrong here: the positive diagonals
            # sit at offset ±N, which moves with N.
            mask = self._get_correlated_mask(N)
        neg = sim.masked_select(mask.to(sim.device)).view(2 * N, -1)  # (2N, 2N-2)

        logits = torch.cat([pos, neg], dim=1)                # column 0 is the positive
        labels = torch.zeros(2 * N, dtype=torch.long, device=sim.device)
        return self.criterion(logits, labels) / (2 * N)


In [7]:
# Cell 7: Pretrain SimCLR

# Pretrain SimCLR on the unlabeled split, resuming from the newest checkpoint if present.
model_simclr = SimCLRModel(base_encoder=resnet18, projection_dim=128).to(device)
optimizer_simclr = optim.AdamW(model_simclr.parameters(), lr=3e-4, weight_decay=1e-4)
# Take the batch size from the loader so the loss and the data can never drift apart.
nt_xent_loss = NTXentLoss(batch_size=unlabeled_loader.batch_size, temperature=0.5).to(device)
# Optional per-step LR scheduler. It is defined explicitly (rather than probed via
# locals()), so a stale `scheduler` left in the kernel is never picked up by accident.
scheduler = None

start_epoch = 1
num_epochs_pretrain = 100
ckpt_every = 5
ckpt_dir = 'checkpoints/simclr_pretrain'
os.makedirs(ckpt_dir, exist_ok=True)


def _checkpoint_epoch(fname):
    """Return the epoch encoded in 'simclr_epoch<N>.pth', or None for any other file."""
    prefix, suffix = 'simclr_epoch', '.pth'
    if not (fname.startswith(prefix) and fname.endswith(suffix)):
        return None
    stem = fname[len(prefix):-len(suffix)]
    return int(stem) if stem.isdigit() else None


def _find_latest_checkpoint(directory):
    """Return the filename of the highest-epoch checkpoint in `directory`, or None."""
    # Compare epochs numerically: sorting names puts 'simclr_epoch100' before 'simclr_epoch65'.
    best_epoch, best_file = -1, None
    for fname in os.listdir(directory):
        ep_num = _checkpoint_epoch(fname)
        if ep_num is not None and ep_num > best_epoch:
            best_epoch, best_file = ep_num, fname
    return best_file


def _train_one_epoch(epoch):
    """Run one SimCLR epoch over `unlabeled_loader`; return the mean batch loss."""
    model_simclr.train()
    running_loss = 0.0
    pbar = tqdm(unlabeled_loader, desc=f"SimCLR Epoch {epoch}/{num_epochs_pretrain}", leave=False)
    for step, (xi, xj) in enumerate(pbar, start=1):
        xi = xi.to(device)
        xj = xj.to(device)
        zi = model_simclr(xi)
        zj = model_simclr(xj)
        loss_simclr = nt_xent_loss(zi, zj)

        optimizer_simclr.zero_grad()
        loss_simclr.backward()
        optimizer_simclr.step()
        if scheduler is not None:
            scheduler.step()

        running_loss += loss_simclr.item()
        # Use our own step count: tqdm only refreshes pbar.n every `mininterval`.
        pbar.set_postfix({'loss': f"{running_loss / step:.4f}"})
    return running_loss / max(len(unlabeled_loader), 1)


def _save_checkpoint(epoch):
    """Save model/optimizer (and scheduler, if any) state for `epoch`; return the path."""
    path = os.path.join(ckpt_dir, f'simclr_epoch{epoch}.pth')
    state = {
        'epoch': epoch,
        'model_state_dict': model_simclr.state_dict(),
        'optimizer_state_dict': optimizer_simclr.state_dict(),
    }
    if scheduler is not None:
        state['scheduler_state_dict'] = scheduler.state_dict()
    torch.save(state, path)
    return path


# Resume from the most recent checkpoint, if any.
latest_ckpt = _find_latest_checkpoint(ckpt_dir)
if latest_ckpt is not None:
    ckpt_path = os.path.join(ckpt_dir, latest_ckpt)
    ckpt = torch.load(ckpt_path, map_location=device)
    model_simclr.load_state_dict(ckpt['model_state_dict'])
    optimizer_simclr.load_state_dict(ckpt['optimizer_state_dict'])
    if scheduler is not None and 'scheduler_state_dict' in ckpt:
        scheduler.load_state_dict(ckpt['scheduler_state_dict'])
    start_epoch = ckpt['epoch'] + 1
    print(f"Resumed from checkpoint {latest_ckpt}, starting at epoch {start_epoch}")

for epoch in range(start_epoch, num_epochs_pretrain + 1):
    avg_loss = _train_one_epoch(epoch)
    print(f"Epoch {epoch:02d} | SimCLR Loss = {avg_loss:.4f}")

    # Checkpoint every `ckpt_every` epochs and always after the final epoch.
    if epoch % ckpt_every == 0 or epoch == num_epochs_pretrain:
        ckpt_path = _save_checkpoint(epoch)
        print(f"↳ Saved checkpoint: {ckpt_path}")



Resumed from checkpoint simclr_epoch65.pth, starting at epoch 66


SimCLR Epoch 66/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 66 | SimCLR Loss = 4.8035


SimCLR Epoch 67/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 67 | SimCLR Loss = 4.7998


SimCLR Epoch 68/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 68 | SimCLR Loss = 4.8042


SimCLR Epoch 69/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 69 | SimCLR Loss = 4.8022


SimCLR Epoch 70/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 70 | SimCLR Loss = 4.7959
↳ Saved checkpoint: checkpoints/simclr_pretrain/simclr_epoch70.pth


SimCLR Epoch 71/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 71 | SimCLR Loss = 4.7927


SimCLR Epoch 72/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 72 | SimCLR Loss = 4.7940


SimCLR Epoch 73/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 73 | SimCLR Loss = 4.7953


SimCLR Epoch 74/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 74 | SimCLR Loss = 4.7907


SimCLR Epoch 75/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 75 | SimCLR Loss = 4.7865
↳ Saved checkpoint: checkpoints/simclr_pretrain/simclr_epoch75.pth


SimCLR Epoch 76/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 76 | SimCLR Loss = 4.7840


SimCLR Epoch 77/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 77 | SimCLR Loss = 4.7856


SimCLR Epoch 78/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 78 | SimCLR Loss = 4.7862


SimCLR Epoch 79/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 79 | SimCLR Loss = 4.7760


SimCLR Epoch 80/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 80 | SimCLR Loss = 4.7782
↳ Saved checkpoint: checkpoints/simclr_pretrain/simclr_epoch80.pth


SimCLR Epoch 81/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 81 | SimCLR Loss = 4.7688


SimCLR Epoch 82/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 82 | SimCLR Loss = 4.7718


SimCLR Epoch 83/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 83 | SimCLR Loss = 4.7716


SimCLR Epoch 84/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 84 | SimCLR Loss = 4.7736


SimCLR Epoch 85/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 85 | SimCLR Loss = 4.7792
↳ Saved checkpoint: checkpoints/simclr_pretrain/simclr_epoch85.pth


SimCLR Epoch 86/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 86 | SimCLR Loss = 4.7719


SimCLR Epoch 87/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 87 | SimCLR Loss = 4.7768


SimCLR Epoch 88/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 88 | SimCLR Loss = 4.7732


SimCLR Epoch 89/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 89 | SimCLR Loss = 4.7721


SimCLR Epoch 90/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 90 | SimCLR Loss = 4.7675
↳ Saved checkpoint: checkpoints/simclr_pretrain/simclr_epoch90.pth


SimCLR Epoch 91/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 91 | SimCLR Loss = 4.7653


SimCLR Epoch 92/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 92 | SimCLR Loss = 4.7600


SimCLR Epoch 93/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 93 | SimCLR Loss = 4.7614


SimCLR Epoch 94/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 94 | SimCLR Loss = 4.7607


SimCLR Epoch 95/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 95 | SimCLR Loss = 4.7711
↳ Saved checkpoint: checkpoints/simclr_pretrain/simclr_epoch95.pth


SimCLR Epoch 96/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 96 | SimCLR Loss = 4.7585


SimCLR Epoch 97/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 97 | SimCLR Loss = 4.7618


SimCLR Epoch 98/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 98 | SimCLR Loss = 4.7602


SimCLR Epoch 99/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 99 | SimCLR Loss = 4.7595


SimCLR Epoch 100/100:   0%|          | 0/35 [00:00<?, ?it/s]

Epoch 100 | SimCLR Loss = 4.7558
↳ Saved checkpoint: checkpoints/simclr_pretrain/simclr_epoch100.pth
