> # Preparation environnement 

In [1]:
!pip install -q medmnist

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m87.2/87.2 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m30.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m12.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.5/207.5 MB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m21.1/21.1 MB[0m [31m76.3 MB/s

In [2]:
import time, copy, random
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.colors as colors
from sklearn.decomposition import PCA

import torch
import torch.nn as nn
import torch.nn.functional as F       
from torch.utils.data import DataLoader, Subset, Dataset

import torchvision.models as models
import torchvision.transforms as T
import torchvision.transforms.functional as TF   

from medmnist.dataset import PathMNIST
from medmnist import Evaluator

from tqdm.auto import tqdm
from sklearn.manifold import TSNE
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_auc_score



> # Définir un modèle supervisé de base (ResNet18)

> # BarlowTwin 1% 

In [3]:
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Calcul mean et std

In [4]:
tmp_transform = T.ToTensor()
train_raw = PathMNIST(split='train', transform=tmp_transform, download=True)

loader = DataLoader(train_raw, batch_size=512, num_workers=2)

n_pixels = 0
mean = torch.zeros(3)
M2   = torch.zeros(3)           

for imgs, _ in loader:
    bs, c, h, w = imgs.shape
    imgs = imgs.view(bs, c, -1)                     
    batch_sum = imgs.sum((0, 2))                
    batch_pix = bs * h * w
    n_pixels += batch_pix

    delta = batch_sum / batch_pix - mean        
    mean += delta * (batch_pix / n_pixels)      

    M2 += ((imgs - mean.view(1, -1, 1))**2).sum((0, 2))

std = torch.sqrt(M2 / (n_pixels - 1))
PATH_MEAN = tuple(mean.numpy())
PATH_STD  = tuple(std.numpy())

print("MEAN :", PATH_MEAN)
print("STD  :", PATH_STD)


100%|██████████| 206M/206M [00:07<00:00, 26.4MB/s]


MEAN : (0.74054503, 0.5329823, 0.705829)
STD  : (0.12368046, 0.17675981, 0.1244284)


In [5]:
class RandomRotate90:
    def __init__(self, p=0.75, interpolation=T.InterpolationMode.BILINEAR):
        self.p = p
        self.interpolation = interpolation
    def __call__(self, img):
        if random.random() < self.p:
            angle = random.choice((0, 90, 180, 270))
            if angle:                                    
                img = TF.rotate(img, angle,
                                interpolation=self.interpolation)
        return img

In [6]:
def get_backbone_output_dim(backbone):
    device = next(backbone.parameters()).device
    dummy   = torch.zeros(1, 3, 28, 28, device=device)
    with torch.no_grad():
        return backbone(dummy).shape[1]

class Classifier(nn.Module):
    def __init__(self, backbone, n_classes):
        super().__init__()
        self.backbone = backbone
        feat_dim = get_backbone_output_dim(backbone)
        self.head = nn.Linear(feat_dim, n_classes)

    def forward(self, x):
        return self.head(self.backbone(x))

In [7]:
def make_subset(ds, pct):
    if pct >= 1.0:                       
        return ds
    y = ds.labels.squeeze()                       
    sss = StratifiedShuffleSplit(n_splits=1, train_size=pct, random_state=SEED)
    idx, _ = next(sss.split(np.zeros(len(y)), y))
    return Subset(ds, idx)

### 1. Transforms & Two-View Dataset

Ces six augmentations viennent du papier original Barlow Twins (Zbontar et al., 2021), section 3.1 “Image Augmentations”. 

*“We apply the following data augmentations to each image in a pair:
random cropping and resizing (scale between 0.08 and 1.0), random horizontal flipping, color jittering (with probability 0.8), random grayscale conversion (with probability 0.2), Gaussian blur (always applied for the first view; 50% for the second), and normalization.”*



In [8]:
'''ssl_transform = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomAffine(degrees=10, translate=(0.05, 0.05)),  
    T.ColorJitter(0.1, 0.1, 0.05, 0.02),
    T.RandomApply([T.GaussianBlur(3, sigma=(0.1, 0.5))], p=0.5),
    T.ToTensor(),
    T.Normalize([.5], [.5])
])'''

ssl_transform = T.Compose([
    T.RandomAffine(
        degrees=10,
        translate=(0.1, 0.1),
        scale=(0.9, 1.1),
        interpolation=T.InterpolationMode.BILINEAR,  # préserve les détails
    ),
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip(),
    RandomRotate90(p=0.75),
    T.ColorJitter(brightness=0.12, contrast=0.12,
                  saturation=0.12, hue=0.03),
    # Bruit de microscopes
    T.RandomApply(
        [T.GaussianBlur(kernel_size=3, sigma=(.1, 1.0))], p=0.3),
    T.ToTensor(),
    T.Normalize(PATH_MEAN, PATH_STD),
])


'''eval_transform = T.Compose([
    T.ToTensor(),
    T.Normalize([.5], [.5])
])'''

eval_transform = T.Compose([
    T.ToTensor(),
    T.Normalize(PATH_MEAN, PATH_STD),
])

data_flag  = 'pathmnist'
base_train = PathMNIST(split='train', download=True)  

class TwoViewDataset(Dataset):
    def __init__(self, base_ds, transform):
        self.base_ds  = base_ds
        self.transform = transform

    def __len__(self):
        return len(self.base_ds)

    def __getitem__(self, idx):
        img, _ = self.base_ds[idx]
        v1 = self.transform(img)
        v2 = self.transform(img)
        return v1, v2

ssl_dataset = TwoViewDataset(base_train, ssl_transform)
ssl_loader  = DataLoader(
    ssl_dataset,
    batch_size=512,      
    shuffle=True,
    num_workers=2,
    persistent_workers=True,
    pin_memory=True,
    drop_last=True
)

train_labeled = PathMNIST(split='train', transform=eval_transform, download=True)
val_set       = PathMNIST(split='val'  , transform=eval_transform, download=True)
test_set      = PathMNIST(split='test' , transform=eval_transform, download=True)
n_classes     = len(np.unique(train_labeled.labels))


### 2. Barlow Twins : modèle & loss

Backbone (avec fc en moins) + projecteur + loss

In [9]:
class BarlowTwins(nn.Module):
    def __init__(self, proj_dim=2048, hidden_dim=2048):
        super().__init__()
        self.backbone = models.resnet18(weights=None)

        self.backbone.conv1 = nn.Conv2d(3, 64, 3, 1, 1, bias=False)
        self.backbone.bn1   = nn.BatchNorm2d(64, eps=1e-3, momentum=0.03)
        self.backbone.maxpool = nn.Identity()

        feat_dim = self.backbone.fc.in_features
        self.backbone.fc = nn.Identity()

        self.projector = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim), nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, proj_dim, bias=False)
        )

    def forward(self, x1, x2):
        z1 = self.projector(self.backbone(x1))
        z2 = self.projector(self.backbone(x2))
        return z1, z2

def barlow_twins_loss_verbose(z1, z2, lambd=5e-3, eps=1e-9):
    N, D = z1.size()
    z1 = (z1 - z1.mean(0)) / (z1.std(0, unbiased=False) + eps)
    z2 = (z2 - z2.mean(0)) / (z2.std(0, unbiased=False) + eps)
    c = (z1.T @ z2) / N

    on_diag  = torch.diagonal(c).add_(-1).pow_(2).sum()
    off_diag = (c.flatten()[1:].view(D-1, D+1)[:, :-1]).pow_(2).sum()
    loss = on_diag + lambd * off_diag
    return loss, on_diag.detach(), off_diag.detach()


### 3. Pré-entraînement SSL

In [10]:
'''def extract_feats(model, loader):
    model.eval()
    feats, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            feats.append(model.backbone(x).cpu())
            labels.append(y.squeeze())          
    return torch.cat(feats), torch.cat(labels)

def plot_tsne_points(emb, lab, title, fname, n_classes=9):
    cmap = cm.get_cmap('tab10', n_classes)
    norm = colors.Normalize(vmin=0, vmax=n_classes-1)

    plt.figure(figsize=(5,5))
    plt.scatter(emb[:,0], emb[:,1], c=lab, s=3, cmap=cmap, norm=norm)
    plt.title(title)

    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    plt.colorbar(sm, ticks=range(n_classes), label="classe")

    plt.tight_layout()
    plt.savefig(fname, dpi=300)
    plt.close()'''



'def extract_feats(model, loader):\n    model.eval()\n    feats, labels = [], []\n    with torch.no_grad():\n        for x, y in loader:\n            x = x.to(device)\n            feats.append(model.backbone(x).cpu())\n            labels.append(y.squeeze())          \n    return torch.cat(feats), torch.cat(labels)\n\ndef plot_tsne_points(emb, lab, title, fname, n_classes=9):\n    cmap = cm.get_cmap(\'tab10\', n_classes)\n    norm = colors.Normalize(vmin=0, vmax=n_classes-1)\n\n    plt.figure(figsize=(5,5))\n    plt.scatter(emb[:,0], emb[:,1], c=lab, s=3, cmap=cmap, norm=norm)\n    plt.title(title)\n\n    sm = cm.ScalarMappable(cmap=cmap, norm=norm)\n    sm.set_array([])\n    plt.colorbar(sm, ticks=range(n_classes), label="classe")\n\n    plt.tight_layout()\n    plt.savefig(fname, dpi=300)\n    plt.close()'

In [11]:
def extract_feats(model, loader):
    model.eval()
    feats, labels = [], []
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            feats.append(model.backbone(x).cpu())
            labels.append(y.squeeze())      
    return torch.cat(feats), torch.cat(labels)

def plot_tsne_points(emb2d, lab, title, fname, n_classes=9):
    cmap = cm.get_cmap('tab10', n_classes)
    norm = colors.Normalize(vmin=0, vmax=n_classes-1)

    plt.figure(figsize=(5, 5))
    plt.scatter(emb2d[:, 0], emb2d[:, 1],
                c=lab, s=3, cmap=cmap, norm=norm)
    plt.title(title)

    sm = cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    plt.colorbar(sm, ticks=range(n_classes), label="classe")

    plt.tight_layout()
    plt.savefig(fname, dpi=300)
    plt.close()

def make_tsne(model, base_loader, title, fname,
              n_classes=9, keep_per_class=300):
    emb, lab = extract_feats(model, base_loader)

    idx_list = []
    for c in range(n_classes):
        idx_c = torch.nonzero(lab == c).squeeze()
        idx_list.append(idx_c[:keep_per_class])
    idx = torch.cat(idx_list)
    emb, lab = emb[idx], lab[idx]

    emb50 = PCA(n_components=50).fit_transform(emb)

    z = TSNE(n_components=2,
             perplexity=40,
             early_exaggeration=12,
             n_iter=3000,
             init="pca").fit_transform(emb50)

    plot_tsne_points(z, lab, title, fname, n_classes)

In [12]:
model = BarlowTwins().to(device)
opt_ssl = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-5)
sched   = torch.optim.lr_scheduler.CosineAnnealingLR(opt_ssl, T_max=100)

on_hist, off_hist = [] , []         
EPOCHS_SSL = 100

for ep in range(1, EPOCHS_SSL + 1):
    model.train()
    loss_sum = 0

    for v1, v2 in tqdm(ssl_loader, leave=False):
        v1 = v1.to(device, non_blocking=True)
        v2 = v2.to(device, non_blocking=True)

        z1, z2 = model(v1, v2)
        loss, on_cur, off_cur = barlow_twins_loss_verbose(z1, z2)

        opt_ssl.zero_grad()
        loss.backward()
        opt_ssl.step()

        loss_sum += loss.item()
        on_hist.append(on_cur.item())
        off_hist.append(off_cur.item())

    sched.step()   

    if ep % 25 == 0 or ep == 1:
        print(f"[SSL] Epoch {ep:3d}/{EPOCHS_SSL} | Loss "
              f"{loss_sum/len(ssl_loader):.4f}")

        eval_loader = DataLoader(train_labeled, batch_size=256,
                         shuffle=False, num_workers=2, pin_memory=True)

        make_tsne(model, eval_loader,
                  title=f"t-SNE SSL epoch {ep}",
                  fname=f"ssl_tsne_epoch{ep}.png",
                  n_classes=n_classes)


        '''eval_loader = DataLoader(train_labeled, batch_size=256,
                                 shuffle=False, num_workers=2,
                                 pin_memory=True)
        emb, lab = extract_feats(model, eval_loader)

        if len(emb) > 2000:
            idx = torch.randperm(len(emb))[:2000]
            emb, lab = emb[idx], lab[idx]

        z = TSNE(n_components=2, perplexity=30,
                 init="pca", n_iter=1000).fit_transform(emb)

        plot_tsne_points(
            z, lab,
            title=f"t-SNE SSL epoch {ep}",
            fname=f"ssl_tsne_epoch{ep}.png",
            n_classes=n_classes            
        )

plt.figure(figsize=(6, 4))
plt.plot(on_hist,  label="on-diag")
plt.plot(off_hist, label="off-diag")
plt.yscale("log")
plt.xlabel("itérations")
plt.ylabel("Σ erreur² (log)")
plt.legend()
plt.tight_layout()
plt.savefig("bt_diag_curves.png", dpi=300)
plt.close()'''

torch.save(model.backbone.state_dict(), "bt_backbone.pth")


  0%|          | 0/175 [00:00<?, ?it/s]

[SSL] Epoch   1/100 | Loss 940.1986


  cmap = cm.get_cmap('tab10', n_classes)
  plt.colorbar(sm, ticks=range(n_classes), label="classe")


  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

[SSL] Epoch  25/100 | Loss 71.9134


  cmap = cm.get_cmap('tab10', n_classes)
  plt.colorbar(sm, ticks=range(n_classes), label="classe")


  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

[SSL] Epoch  50/100 | Loss 58.7123


  cmap = cm.get_cmap('tab10', n_classes)
  plt.colorbar(sm, ticks=range(n_classes), label="classe")


  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

[SSL] Epoch  75/100 | Loss 53.4438


  cmap = cm.get_cmap('tab10', n_classes)
  plt.colorbar(sm, ticks=range(n_classes), label="classe")


  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

[SSL] Epoch 100/100 | Loss 51.8596


  cmap = cm.get_cmap('tab10', n_classes)
  plt.colorbar(sm, ticks=range(n_classes), label="classe")


k-NN top-1 sur les embeddings SSL

In [13]:
train_loader_knn = DataLoader(train_labeled, batch_size=256, shuffle=False,
                              num_workers=2, pin_memory=True)
val_loader_knn   = DataLoader(val_set      , batch_size=256, shuffle=False,
                              num_workers=2, pin_memory=True)

emb_tr, lab_tr = extract_feats(model, train_loader_knn)
emb_va, lab_va = extract_feats(model, val_loader_knn)

from sklearn.neighbors import KNeighborsClassifier
knn = KNeighborsClassifier(n_neighbors=200, metric="cosine", n_jobs=-1)
knn.fit(emb_tr, lab_tr)
acc_knn = knn.score(emb_va, lab_va)
print(f"[k-NN] accuracy top-1 avant fine-tune : {acc_knn:.3f}")


[k-NN] accuracy top-1 avant fine-tune : 0.797


### 4. Créer le split 1 % annoté

In [14]:
train_sup_transform = eval_transform         
criterion = nn.CrossEntropyLoss()          

PCTS = [0.01, 0.05, 0.10, 0.20, 0.50, 0.80, 1.00]
PCTS_FOR_ROC = {0.01, 0.20, 1.00}

pretrained_backbone = BarlowTwins().backbone
pretrained_backbone.load_state_dict(
    torch.load("/kaggle/input/backboneweights/bt_backbone.pth",
               map_location=device))

results = {}

for pct in PCTS:
    full_train_plain = PathMNIST(split='train',
                                 transform=train_sup_transform,
                                 download=True)
    sub_train  = make_subset(full_train_plain, pct)

    train_loader = DataLoader(sub_train, batch_size=128, shuffle=True,
                              num_workers=2, pin_memory=True,
                              worker_init_fn=lambda s:
                                  np.random.seed(torch.initial_seed() % 2**32))
    val_loader   = DataLoader(val_set, batch_size=256, shuffle=False,
                              num_workers=2, pin_memory=True)

    model_ft = Classifier(copy.deepcopy(pretrained_backbone),
                          n_classes).to(device)

    for p in model_ft.backbone.parameters(): p.requires_grad = False
    opt = torch.optim.Adam(model_ft.head.parameters(),
                           lr=1e-3, weight_decay=1e-5)

    for _ in range(5):
        model_ft.train()
        for x, y in train_loader:
            x = x.to(device)
            y = y.squeeze().long().to(device)
            loss = criterion(model_ft(x), y)
            opt.zero_grad(); loss.backward(); opt.step()

    # -- fine-tune complet --
    for p in model_ft.backbone.parameters(): p.requires_grad = True
    opt   = torch.optim.Adam(model_ft.parameters(),
                             lr=1e-4, weight_decay=1e-5)
    sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10)

    for _ in range(10):
        model_ft.train()
        for x, y in train_loader:
            x = x.to(device)
            y = y.squeeze().long().to(device)
            loss = criterion(model_ft(x), y)
            opt.zero_grad(); loss.backward(); opt.step()
        sched.step()

    model_ft.eval()
    preds, gts = [], []
    with torch.no_grad():
        for x, y in val_loader:
            x = x.to(device)
            preds.append(model_ft(x).cpu())
            gts.append(y.squeeze())
    preds = torch.cat(preds).softmax(1)
    gts   = torch.cat(gts)

    acc = (preds.argmax(1) == gts).float().mean().item()

    y_true_bin = label_binarize(gts.numpy(), classes=range(n_classes))
    y_score    = preds.numpy()
    auc_macro  = roc_auc_score(y_true_bin, y_score,
                               average="macro", multi_class="ovr")

    '''emb, lab = extract_feats(model_ft, val_loader)
    if len(emb) > 2000:
        idx = torch.randperm(len(emb))[:2000]
        emb, lab = emb[idx], lab[idx]
    z = TSNE(n_components=2, perplexity=30, init="pca",
             n_iter=1000).fit_transform(emb)
    plt.figure(figsize=(5,5))
    plt.scatter(z[:,0], z[:,1], c=lab, s=3, cmap='tab10')
    plt.title(f"t-SNE fine-tune {int(pct*100)} %")
    plt.savefig(f"tsne_ft_{int(pct*100)}.png", dpi=300); plt.close()'''

    make_tsne(model_ft, val_loader,
          title=f"t-SNE fine-tune {int(pct*100)} %",
          fname=f"tsne_ft_{int(pct*100)}.png",
          n_classes=n_classes)

    if pct in PCTS_FOR_ROC:
        from sklearn.metrics import roc_curve, auc
        y_true = y_true_bin
        y_score = y_score

        fpr, tpr, roc_auc = {}, {}, {}
        for i in range(n_classes):
            fpr[i], tpr[i], _ = roc_curve(y_true[:, i], y_score[:, i])
            roc_auc[i] = auc(fpr[i], tpr[i])

        all_fpr  = np.unique(np.concatenate([fpr[i] for i in range(n_classes)]))
        mean_tpr = np.zeros_like(all_fpr)
        for i in range(n_classes):
            mean_tpr += np.interp(all_fpr, fpr[i], tpr[i])
        mean_tpr /= n_classes
        roc_auc["macro"] = auc(all_fpr, mean_tpr)

        plt.figure(figsize=(5,5))
        plt.plot(all_fpr, mean_tpr, lw=2,
                 label=f"macro (AUC = {roc_auc['macro']:.2f})")
        for i in range(n_classes):
            plt.plot(fpr[i], tpr[i], lw=1,
                     label=f"class {i} (AUC = {roc_auc[i]:.2f})")
        plt.plot([0,1],[0,1],"--",lw=1,color="k")
        plt.xlim([0,1]); plt.ylim([0,1.05])
        plt.xlabel("FPR"); plt.ylabel("TPR")
        plt.title(f"ROC – {int(pct*100)} % annot.")
        plt.legend(fontsize="x-small"); plt.tight_layout()
        plt.savefig(f"roc_{int(pct*100)}.png", dpi=300); plt.close()

    results[pct] = {"acc": acc, "auc": auc_macro}
    print(f"[{int(pct*100):3d}%] accuracy={acc:.3f} | macro-AUC={auc_macro:.3f}")

pcts_sorted = [int(p*100) for p in sorted(results.keys())]
accs = [results[p/100]["acc"] for p in pcts_sorted]
aucs = [results[p/100]["auc"] for p in pcts_sorted]

plt.figure(figsize=(6,4))
plt.plot(pcts_sorted, accs, "o-", label="Accuracy")
plt.plot(pcts_sorted, aucs, "s-", label="Macro-AUC")
plt.xlabel("% d'images étiquetées"); plt.ylabel("Score")
plt.title("Label-efficiency – Barlow Twins")
plt.grid(alpha=.3); plt.legend(); plt.tight_layout()
plt.savefig("label_efficiency_curve.png", dpi=300); plt.close()


  cmap = cm.get_cmap('tab10', n_classes)
  plt.colorbar(sm, ticks=range(n_classes), label="classe")


[  1%] accuracy=0.811 | macro-AUC=0.975


  cmap = cm.get_cmap('tab10', n_classes)
  plt.colorbar(sm, ticks=range(n_classes), label="classe")


[  5%] accuracy=0.902 | macro-AUC=0.992


  cmap = cm.get_cmap('tab10', n_classes)
  plt.colorbar(sm, ticks=range(n_classes), label="classe")


[ 10%] accuracy=0.925 | macro-AUC=0.995


  cmap = cm.get_cmap('tab10', n_classes)
  plt.colorbar(sm, ticks=range(n_classes), label="classe")


[ 20%] accuracy=0.947 | macro-AUC=0.997


  cmap = cm.get_cmap('tab10', n_classes)
  plt.colorbar(sm, ticks=range(n_classes), label="classe")


[ 50%] accuracy=0.968 | macro-AUC=0.999


  cmap = cm.get_cmap('tab10', n_classes)
  plt.colorbar(sm, ticks=range(n_classes), label="classe")


[ 80%] accuracy=0.978 | macro-AUC=0.999


  cmap = cm.get_cmap('tab10', n_classes)
  plt.colorbar(sm, ticks=range(n_classes), label="classe")


[100%] accuracy=0.981 | macro-AUC=1.000
