In [1]:
# Cell 0 — Installs + imports
!pip install timm --quiet

import os, random, math, sys
from pathlib import Path
from pprint import pprint
import numpy as np
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid
import torchvision.models as models
import timm

# Device
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Device:", DEVICE)




Device: cuda


In [2]:
# Cell 1 — Find dataset root & metadata automatically (robust)
# The common Kaggle HAM10000 dataset folder names vary; try a few.
possible_roots = [
    "/kaggle/input/ham10000",
    "/kaggle/input/skin-cancer-mnist-ham10000",
    "/kaggle/input/hab10000",  # fallback
]
ROOT_CANDIDATES = [Path(p) for p in possible_roots if Path(p).exists()]
if len(ROOT_CANDIDATES) == 0:
    # try any folder under /kaggle/input
    root = Path("/kaggle/input")
    subs = list(root.iterdir()) if root.exists() else []
    if len(subs)>0:
        ROOT_CANDIDATES = [subs[0]]
        print("Auto selected:", subs[0])
    else:
        raise RuntimeError("Couldn't find /kaggle/input dataset. Place HAM10000 in Kaggle Input.")
DATA_ROOT = ROOT_CANDIDATES[0]
print("Using dataset root:", DATA_ROOT)

# try to locate image folder and metadata csv
# common layout: HAM10000_images/*.jpg and HAM10000_metadata.csv OR skin-cancer-mnist-ham10000/images/*.jpg and labels.csv
IMAGE_DIR = None
META_CSV = None
# search patterns
for p in DATA_ROOT.rglob("*"):
    if p.is_dir() and any(p.glob("*.jpg")):
        IMAGE_DIR = p
        break
# find metadata
for fname in ["HAM10000_metadata.csv","metadata.csv","labels.csv","HAM10000_images_metadata.csv"]:
    candidate = DATA_ROOT / fname
    if candidate.exists():
        META_CSV = candidate
        break
# fallback: find first csv
if META_CSV is None:
    csvs = list(DATA_ROOT.rglob("*.csv"))
    META_CSV = csvs[0] if csvs else None

print("Image dir:", IMAGE_DIR)
print("Metadata csv:", META_CSV)
if IMAGE_DIR is None:
    raise RuntimeError("No image folder (.jpg/.png) found under dataset root.")
if META_CSV is None:
    print("Warning: metadata CSV not found. We'll try to infer class labels from filenames if possible.")


Using dataset root: /kaggle/input/skin-cancer-mnist-ham10000
Image dir: /kaggle/input/skin-cancer-mnist-ham10000/HAM10000_images_part_1
Metadata csv: /kaggle/input/skin-cancer-mnist-ham10000/HAM10000_metadata.csv


In [3]:
# Cell 2 — Build metadata dataframe (robustly)
# Preferred: HAM10000 metadata with columns 'image_id' and 'dx' (paper uses 'dx' for diagnosis label)
if META_CSV is not None:
    df_meta = pd.read_csv(META_CSV)
    print("Loaded metadata columns:", df_meta.columns.tolist())
    # Try to standardize
    if 'image_id' in df_meta.columns and 'dx' in df_meta.columns:
        df_meta = df_meta[['image_id','dx']].copy()
    else:
        # try alternate names
        if 'lesion_id' in df_meta.columns and 'dx' in df_meta.columns:
            df_meta = df_meta[['lesion_id','dx']].rename(columns={'lesion_id':'image_id'})
        else:
            # attempt to map filename -> label if there is a column containing 'image' and one containing 'label'
            col_img = None
            col_lbl = None
            for c in df_meta.columns:
                if 'image' in c.lower():
                    col_img = c
                if any(x in c.lower() for x in ['dx','label','diagnosis','class']):
                    col_lbl = c
            if col_img and col_lbl:
                df_meta = df_meta[[col_img,col_lbl]].rename(columns={col_img:'image_id', col_lbl:'dx'})
            else:
                print("Could not find image_id/dx in metadata; showing head for inspection:")
                display(df_meta.head())
                # we'll not error here; fallback to filename scanning later
else:
    df_meta = None


Loaded metadata columns: ['lesion_id', 'image_id', 'dx', 'dx_type', 'age', 'sex', 'localization']


In [4]:
# Cell 2 — Build metadata dataframe (robustly)
# Preferred: HAM10000 metadata with columns 'image_id' and 'dx' (paper uses 'dx' for diagnosis label)
if META_CSV is not None:
    df_meta = pd.read_csv(META_CSV)
    print("Loaded metadata columns:", df_meta.columns.tolist())
    # Try to standardize
    if 'image_id' in df_meta.columns and 'dx' in df_meta.columns:
        df_meta = df_meta[['image_id','dx']].copy()
    else:
        # try alternate names
        if 'lesion_id' in df_meta.columns and 'dx' in df_meta.columns:
            df_meta = df_meta[['lesion_id','dx']].rename(columns={'lesion_id':'image_id'})
        else:
            # attempt to map filename -> label if there is a column containing 'image' and one containing 'label'
            col_img = None
            col_lbl = None
            for c in df_meta.columns:
                if 'image' in c.lower():
                    col_img = c
                if any(x in c.lower() for x in ['dx','label','diagnosis','class']):
                    col_lbl = c
            if col_img and col_lbl:
                df_meta = df_meta[[col_img,col_lbl]].rename(columns={col_img:'image_id', col_lbl:'dx'})
            else:
                print("Could not find image_id/dx in metadata; showing head for inspection:")
                display(df_meta.head())
                # we'll not error here; fallback to filename scanning later
else:
    df_meta = None


Loaded metadata columns: ['lesion_id', 'image_id', 'dx', 'dx_type', 'age', 'sex', 'localization']


In [5]:
# Cell 3 — Dataset + transforms
IMG_SIZE_GAN = 64         # GAN image size for speed (paper used 256 but 64/128 speeds up)
IMG_SIZE_CLS = 224        # classifier input
BATCH_SIZE = 32
NUM_WORKERS = 2

# GAN transforms (normalize to [-1,1])
gan_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE_GAN, IMG_SIZE_GAN)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3),
])

# For Barlow Twins we need two augmented views: define a small augmentation pipeline
bt_aug = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE_GAN, scale=(0.8,1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(0.1,0.1,0.1,0.05),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3),
])

# classifier transforms
cls_train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE_CLS, IMG_SIZE_CLS)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])

cls_val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE_CLS, IMG_SIZE_CLS)),
    transforms.ToTensor(),
    transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]),
])


In [6]:
# Cell 4 — Utility: build per-class path lists using metadata if available
from collections import defaultdict

def build_class_index(image_dir: Path, meta_df: pd.DataFrame = None):
    """
    Returns: dict label -> list(paths)
    """
    class_paths = defaultdict(list)
    if meta_df is not None and 'image_id' in meta_df.columns and 'dx' in meta_df.columns:
        # metadata image_id may be w/o extension. try common extensions
        for _, row in meta_df.iterrows():
            img_id = str(row['image_id'])
            lbl = str(row['dx'])
            found = None
            for ext in ['.jpg','.png','.jpeg','.JPG']:
                p = image_dir / (img_id + ext)
                if p.exists():
                    found = p
                    break
            if found:
                class_paths[lbl].append(str(found))
    # fallback: scan filenames and attempt to match label tokens (not reliable but fallback)
    if len(class_paths)==0:
        # map using folder names or filename tokens
        for p in image_dir.rglob("*"):
            if p.suffix.lower() in (".jpg",".png",".jpeg"):
                name = p.name.lower()
                # try to detect common HAM10000 labels inside filenames (mel, nv, bkl, etc.)
                tokens = ['nv','mel','bkl','akiec','bcc','vasc','df']
                matched = None
                for t in tokens:
                    if t in name:
                        matched = t
                        break
                if matched:
                    class_paths[matched].append(str(p))
                else:
                    # place in 'unknown'
                    class_paths['unknown'].append(str(p))
    return class_paths

class_index = build_class_index(IMAGE_DIR, df_meta)
for k in list(class_index.keys()):
    print(k, len(class_index[k]))


bkl 564
nv 3431
df 56
mel 435
vasc 65
bcc 266
akiec 183


In [7]:
# Cell 5 — Dataset classes used for Stage-2 fine-tuning
# Determine 7 classes if present, else use keys from class_index
CLS_NAMES = sorted(list(class_index.keys()))
print("Detected classes:", CLS_NAMES)
# If metadata has the official labels order, keep that. If not, we'll use detected keys.


Detected classes: ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']


In [8]:
# Cell 6 — PyTorch datasets for GAN and per-class use
class SimpleImageDataset(Dataset):
    def __init__(self, paths, transform=None):
        self.paths = paths
        self.transform = transform
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        p = self.paths[idx]
        img = Image.open(p).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img

# full dataset for Stage-1
all_image_paths = []
for v in class_index.values():
    all_image_paths += v
all_image_paths = sorted(set(all_image_paths))
print("Total unique images found:", len(all_image_paths))
full_dataset = SimpleImageDataset(all_image_paths, transform=gan_transform)
full_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)


Total unique images found: 5000


In [9]:
# Cell 7 — Simple DCGAN-like generator & discriminator (paper used StyleGAN2, but for Kaggle speed we use lightweight model)
LATENT_DIM = 100

class Gen(nn.Module):
    def __init__(self, latent_dim=LATENT_DIM, ngf=64):
        super().__init__()
        self.net = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, ngf*8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf*8), nn.ReLU(True),
            nn.ConvTranspose2d(ngf*8, ngf*4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*4), nn.ReLU(True),
            nn.ConvTranspose2d(ngf*4, ngf*2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf*2), nn.ReLU(True),
            nn.ConvTranspose2d(ngf*2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf), nn.ReLU(True),
            nn.ConvTranspose2d(ngf, 3, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    def forward(self, z):
        return self.net(z)

class Disc(nn.Module):
    def __init__(self, ndf=64, return_features=False):
        super().__init__()
        self.return_features = return_features
        self.conv1 = nn.Sequential(nn.Conv2d(3, ndf, 4, 2, 1), nn.LeakyReLU(0.2, inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(ndf, ndf*2, 4, 2, 1), nn.BatchNorm2d(ndf*2), nn.LeakyReLU(0.2, inplace=True))
        self.conv3 = nn.Sequential(nn.Conv2d(ndf*2, ndf*4, 4, 2, 1), nn.BatchNorm2d(ndf*4), nn.LeakyReLU(0.2, inplace=True))
        self.conv4 = nn.Sequential(nn.Conv2d(ndf*4, ndf*8, 4, 2, 1), nn.BatchNorm2d(ndf*8), nn.LeakyReLU(0.2, inplace=True))
        self.head = nn.Conv2d(ndf*8, 1, 4, 1, 0)
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        feat = self.conv4(x)   # penultimate features (spatial)
        out = self.head(feat)
        out = out.view(-1,1)
        if self.return_features:
            # global-average pool features into vector for Barlow Twins
            feat_vec = torch.flatten(torch.mean(feat, dim=[2,3]), 1)
            return out, feat_vec
        return out


In [10]:
# Cell 8 — Instantiate models + optimizers
G = Gen().to(DEVICE)
D = Disc(return_features=True).to(DEVICE)

criterion_bce = nn.BCEWithLogitsLoss()   # use logits for stability (we used head w/o sigmoid ideally)
optim_G = optim.Adam(G.parameters(), lr=2.5e-4, betas=(0.5,0.999))
optim_D = optim.Adam(D.parameters(), lr=2.5e-4, betas=(0.5,0.999))


In [11]:
# Cell 9 — Helper: save sample grid
fixed_z = torch.randn(16, LATENT_DIM,1,1, device=DEVICE)
def save_sample_grid(epoch, folder="/kaggle/working/gan_samples"):
    os.makedirs(folder, exist_ok=True)
    G.eval()
    with torch.no_grad():
        imgs = G(fixed_z).cpu()
    save_image((imgs+1)/2, os.path.join(folder, f"epoch_{epoch:03d}.png"), nrow=4)
    G.train()


In [12]:
# Cell 10 — Stage-1: Train unconditional GAN (fast-friendly settings)
EPOCHS_STAGE1 = 20   # paper used longer; reduce for Kaggle demo/time
print("Stage-1 training (unconditional GAN) ...")
for epoch in range(EPOCHS_STAGE1):
    running_d = 0.0
    running_g = 0.0
    for i, real in enumerate(full_loader):
        real = real.to(DEVICE)
        bs = real.size(0)
        # labels
        real_label = torch.ones(bs,1, device=DEVICE)
        fake_label = torch.zeros(bs,1, device=DEVICE)

        # Train D
        optim_D.zero_grad()
        out_real, feat_real = D(real)
        z = torch.randn(bs, LATENT_DIM,1,1, device=DEVICE)
        fake = G(z)
        out_fake, feat_fake = D(fake.detach())
        # BCE with logits: our D returns logits if we used head w/o sigmoid; if not, adjust
        loss_d_real = criterion_bce(out_real, real_label)
        loss_d_fake = criterion_bce(out_fake, fake_label)
        loss_D = (loss_d_real + loss_d_fake) * 0.5
        loss_D.backward()
        optim_D.step()

        # Train G
        optim_G.zero_grad()
        out_fake_for_g, _ = D(fake)
        loss_G = criterion_bce(out_fake_for_g, real_label)
        loss_G.backward()
        optim_G.step()

        running_d += loss_D.item()
        running_g += loss_G.item()

    avg_d = running_d / (i+1)
    avg_g = running_g / (i+1)
    print(f"Epoch {epoch+1}/{EPOCHS_STAGE1} | D {avg_d:.4f} | G {avg_g:.4f}")
    if (epoch+1) % 5 == 0:
        save_sample_grid(epoch+1)


Stage-1 training (unconditional GAN) ...
Epoch 1/20 | D 0.1244 | G 9.1847
Epoch 2/20 | D 0.2266 | G 6.0280
Epoch 3/20 | D 0.3328 | G 3.9286
Epoch 4/20 | D 0.2368 | G 4.1721
Epoch 5/20 | D 0.2766 | G 3.9763
Epoch 6/20 | D 0.2445 | G 3.8996
Epoch 7/20 | D 0.2751 | G 3.6109
Epoch 8/20 | D 0.3088 | G 3.5989
Epoch 9/20 | D 0.3236 | G 3.5350
Epoch 10/20 | D 0.2937 | G 3.7731
Epoch 11/20 | D 0.2801 | G 3.8494
Epoch 12/20 | D 0.2296 | G 4.2762
Epoch 13/20 | D 0.2569 | G 4.1268
Epoch 14/20 | D 0.2328 | G 4.4283
Epoch 15/20 | D 0.1525 | G 4.7991
Epoch 16/20 | D 0.2337 | G 4.4133
Epoch 17/20 | D 0.2025 | G 4.7696
Epoch 18/20 | D 0.1842 | G 4.7603
Epoch 19/20 | D 0.1914 | G 4.8185
Epoch 20/20 | D 0.1551 | G 5.1287


In [13]:
# Cell 11 — Stage-2: per-class fine-tuning with Freeze-D + lightweight Barlow Twins
# Implementation notes:
# - We will copy G and D weights from stage-1 (they already are G and D).
# - Freeze-D: freeze top n layers (we implement by freezing conv1 & conv2 to match "freeze highest-resolution")
# - Barlow Twins style loss: compute cross-correlation on batch features from two augmented views.
# - For stability, we will train only for a few epochs per class and use small subsets (paper used full class data, but smaller runs for speed).

def freeze_D_layers(D_model, freeze_until_layer_idx=2):
    # freeze first N conv blocks (conv1, conv2, ...)
    blocks = [D_model.conv1, D_model.conv2, D_model.conv3, D_model.conv4]
    for i, b in enumerate(blocks):
        requires = False if i < freeze_until_layer_idx else True
        for p in b.parameters():
            p.requires_grad = requires
    print(f"Freeze-D: frozen first {freeze_until_layer_idx} blocks")

# Barlow Twins loss helper
def barlow_twins_loss(z_a, z_b, lam=5e-3):
    # z_a, z_b shape: [B, D]
    B, D = z_a.size()
    # normalize (zero mean per dim, unit var)
    z_a = (z_a - z_a.mean(0)) / (z_a.std(0) + 1e-9)
    z_b = (z_b - z_b.mean(0)) / (z_b.std(0) + 1e-9)
    c = (z_a.T @ z_b) / B    # cross-correlation
    on_diag = torch.diagonal(c).add_(-1).pow(2).sum()
    off_diag = (c.pow(2).sum() - torch.diagonal(c).pow(2).sum())
    loss = on_diag + lam * off_diag
    return loss

# Stage-2 loop parameters
EPOCHS_STAGE2 = 10   # per class (reduce for Kaggle)
freeze_blocks = 2    # freeze first 2 conv blocks (experiment)

# Hyperparams
lambda_ss = 0.1  # weight for self-supervised loss relative to GAN loss (paper had lambda_ss)

# Iterate classes (you can pick a subset if you want faster)
classes_to_run = CLS_NAMES  # you can set e.g., ['mel','nv'] to run fewer
print("Stage-2 classes:", classes_to_run)

for class_name in classes_to_run:
    paths = class_index.get(class_name, [])
    if len(paths) < 8:
        print(f"Skipping class {class_name} (too few samples: {len(paths)})")
        continue
    print("Fine-tuning for class:", class_name, "num samples:", len(paths))

    # build loader — use all class images or a subset for speed
    subset_paths = paths  # optionally random.sample(paths, min(len(paths), 200))
    class_dataset = SimpleImageDataset(subset_paths, transform=gan_transform)
    class_loader = DataLoader(class_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

    # Freeze-D
    freeze_D_layers(D, freeze_until_layer_idx=freeze_blocks)

    # per-class fine-tuning: keep separate optimizers (but here we reuse optimizers; ensure only params requiring_grad present)
    optD = optim.Adam(filter(lambda p: p.requires_grad, D.parameters()), lr=2.5e-4, betas=(0.5,0.999))
    optG = optim.Adam(G.parameters(), lr=2.5e-4, betas=(0.5,0.999))

    for epoch in range(EPOCHS_STAGE2):
        d_loss_running = 0.0
        g_loss_running = 0.0
        for i, real in enumerate(class_loader):
            real = real.to(DEVICE)
            bs = real.size(0)
            real_labels = torch.ones(bs,1, device=DEVICE)
            fake_labels = torch.zeros(bs,1, device=DEVICE)

            # ---------------- D step ----------------
            optD.zero_grad()
            # real out + features (we use return_features=True)
            out_real, feat_real = D(real)
            z = torch.randn(bs, LATENT_DIM,1,1, device=DEVICE)
            fake = G(z)
            out_fake, feat_fake = D(fake.detach())

            loss_d = 0.5*(criterion_bce(out_real, real_labels) + criterion_bce(out_fake, fake_labels))
            # add Barlow Twins: compute augmented views features from same real batch
            # create 2 augmented views via bt_aug (PIL transforms)
            x_a = torch.stack([bt_aug(Image.open(p).convert('RGB')) for p in random.sample(subset_paths, bs)]) if len(subset_paths)>=bs else bt_aug(Image.open(subset_paths[0]).convert('RGB')).unsqueeze(0)
            # But creating bt views from random sample above complicates mapping; simpler: apply augment on 'real' tensor by re-loading
            # For simplicity & robustness, compute Barlow Twins on discriminator features of real and fake (two views: real augment + generated augment)
            # We'll approximate: compute features for two augmentations of the real batch:
            with torch.no_grad():
                # reconstruct images from tensor to PIL, apply augment -> re-tensor
                real_cpu = [transforms.ToPILImage()( (r.cpu()*0.5+0.5).clamp(0,1) ) for r in real]
            a_views = torch.stack([bt_aug(img) for img in real_cpu]).to(DEVICE)
            b_views = torch.stack([bt_aug(img) for img in real_cpu]).to(DEVICE)
            _, feat_a = D(a_views)
            _, feat_b = D(b_views)
            loss_bt = barlow_twins_loss(feat_a, feat_b, lam=5e-3)
            # Combine
            loss_D = loss_d + lambda_ss * loss_bt
            loss_D.backward()
            optD.step()

            # ---------------- G step ----------------
            optG.zero_grad()
            z = torch.randn(bs, LATENT_DIM,1,1, device=DEVICE)
            fake = G(z)
            out_fake_for_g, feat_fake_for_g = D(fake)
            loss_gan = criterion_bce(out_fake_for_g, real_labels)

            # also compute a lightweight self-supervised alignment: Barlow between generated features and real features
            # compute features for generated augment views
            with torch.no_grad():
                fake_cpu = [transforms.ToPILImage()( (f.cpu()*0.5+0.5).clamp(0,1) ) for f in fake]
            a_fake_views = torch.stack([bt_aug(img) for img in fake_cpu]).to(DEVICE)
            b_fake_views = torch.stack([bt_aug(img) for img in fake_cpu]).to(DEVICE)
            _, feat_fake_a = D(a_fake_views)
            _, feat_fake_b = D(b_fake_views)
            loss_bt_gen = barlow_twins_loss(feat_fake_a, feat_fake_b, lam=5e-3)

            loss_G = loss_gan + lambda_ss * loss_bt_gen * 0.5   # smaller weight for generator BT
            loss_G.backward()
            optG.step()

            d_loss_running += loss_D.item()
            g_loss_running += loss_G.item()

        avg_d = d_loss_running / (i+1)
        avg_g = g_loss_running / (i+1)
        print(f"[{class_name}][Epoch {epoch+1}/{EPOCHS_STAGE2}] D {avg_d:.4f} | G {avg_g:.4f}")

    # After class fine-tuning, save some examples
    os.makedirs(f"/kaggle/working/synth/{class_name}", exist_ok=True)
    G.eval()
    with torch.no_grad():
        z = torch.randn(25, LATENT_DIM,1,1, device=DEVICE)
        imgs = G(z)
        for idx,img in enumerate(imgs):
            save_image((img+1)/2, f"/kaggle/working/synth/{class_name}/{class_name}_{idx:03d}.png")
    G.train()
    print("Saved synthetic images for class", class_name)


Stage-2 classes: ['akiec', 'bcc', 'bkl', 'df', 'mel', 'nv', 'vasc']
Fine-tuning for class: akiec num samples: 183
Freeze-D: frozen first 2 blocks
[akiec][Epoch 1/10] D 30.9439 | G 16.1852
[akiec][Epoch 2/10] D 23.8722 | G 13.7313
[akiec][Epoch 3/10] D 22.2344 | G 14.0105
[akiec][Epoch 4/10] D 20.9452 | G 13.6140
[akiec][Epoch 5/10] D 21.3790 | G 14.0787
[akiec][Epoch 6/10] D 20.5682 | G 13.2922
[akiec][Epoch 7/10] D 20.1214 | G 12.3527
[akiec][Epoch 8/10] D 19.4358 | G 13.5350
[akiec][Epoch 9/10] D 19.1912 | G 13.0068
[akiec][Epoch 10/10] D 18.8304 | G 13.4582
Saved synthetic images for class akiec
Fine-tuning for class: bcc num samples: 266
Freeze-D: frozen first 2 blocks
[bcc][Epoch 1/10] D 18.3527 | G 15.0512
[bcc][Epoch 2/10] D 18.9183 | G 14.0357
[bcc][Epoch 3/10] D 18.4101 | G 13.5439
[bcc][Epoch 4/10] D 16.9837 | G 14.1138
[bcc][Epoch 5/10] D 17.2333 | G 13.0958
[bcc][Epoch 6/10] D 17.1083 | G 12.7602
[bcc][Epoch 7/10] D 16.9925 | G 12.6168
[bcc][Epoch 8/10] D 16.7092 | G 12.749

  z_a = (z_a - z_a.mean(0)) / (z_a.std(0) + 1e-9)
  z_b = (z_b - z_b.mean(0)) / (z_b.std(0) + 1e-9)
  npimg = (npimg * 255).astype(np.uint8)


[vasc][Epoch 2/10] D nan | G nan
[vasc][Epoch 3/10] D nan | G nan
[vasc][Epoch 4/10] D nan | G nan
[vasc][Epoch 5/10] D nan | G nan
[vasc][Epoch 6/10] D nan | G nan
[vasc][Epoch 7/10] D nan | G nan
[vasc][Epoch 8/10] D nan | G nan
[vasc][Epoch 9/10] D nan | G nan
[vasc][Epoch 10/10] D nan | G nan
Saved synthetic images for class vasc


In [14]:
# Cell 12 — Combine real + synthetic and prepare T-ResNet50 training set
# Collect generated images into a training folder per class
SYN_ROOT = Path("/kaggle/working/synth")
augmented_train_paths = []
augmented_train_labels = []
# real training split: create an 80/20 split using metadata if available
if df_meta is not None:
    # create train/test split consistent with paper 8:2
    from sklearn.model_selection import train_test_split
    ids = df_meta['image_id'].astype(str).tolist()
    train_ids, val_ids = train_test_split(ids, test_size=0.2, random_state=42, stratify=df_meta['dx'])
    # build train list paths
    for tid in train_ids:
        for ext in ['.jpg','.png','.jpeg']:
            p = IMAGE_DIR / (str(tid)+ext)
            if p.exists():
                lab = df_meta[df_meta['image_id']==tid]['dx'].values[0]
                augmented_train_paths.append(str(p))
                augmented_train_labels.append(lab)
                break
else:
    # fallback: use 80% of full list
    split = int(0.8 * len(all_image_paths))
    augmented_train_paths = all_image_paths[:split]
    augmented_train_labels = ['unknown'] * len(augmented_train_paths)

# add synthetic images (a few per class)
for cls in os.listdir(SYN_ROOT) if SYN_ROOT.exists() else []:
    p = SYN_ROOT / cls
    imgs = sorted([str(x) for x in p.glob("*.png")])
    for im in imgs:
        augmented_train_paths.append(im)
        augmented_train_labels.append(cls)

print("Total training samples after augmentation:", len(augmented_train_paths))


Total training samples after augmentation: 4185


In [15]:
# Cell 13 — Dataset for classifier training
class CldsDataset(Dataset):
    def __init__(self, paths, labels, transform=None, class_map=None):
        self.paths = paths
        self.labels = labels
        self.transform = transform
        # build class_to_idx
        if class_map is None:
            unique = sorted(list(set(labels)))
            self.class_map = {c:i for i,c in enumerate(unique)}
        else:
            self.class_map = class_map
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, idx):
        p = self.paths[idx]
        img = Image.open(p).convert('RGB')
        if self.transform:
            img = self.transform(img)
        lbl = self.class_map[self.labels[idx]]
        return img, lbl

# Train/val split
from sklearn.model_selection import train_test_split
train_paths, val_paths, train_labels, val_labels = train_test_split(augmented_train_paths, augmented_train_labels, test_size=0.2, stratify=augmented_train_labels if len(set(augmented_train_labels))>1 else None, random_state=42)
class_map = {c:i for i,c in enumerate(sorted(set(train_labels)))}
train_ds = CldsDataset(train_paths, train_labels, transform=cls_train_transform, class_map=class_map)
val_ds = CldsDataset(val_paths, val_labels, transform=cls_val_transform, class_map=class_map)
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False, num_workers=2)
print("Classes map:", class_map)


Classes map: {'akiec': 0, 'bcc': 1, 'bkl': 2, 'df': 3, 'mel': 4, 'nv': 5, 'vasc': 6}


In [16]:
# Cell 14 — T-ResNet50 (paper: ResNet50 + FC:2048->128->7)
num_classes = len(class_map)
resnet = models.resnet50(pretrained=True)
# Replace fc
resnet.fc = nn.Sequential(
    nn.Linear(2048,128),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(128, num_classes)
)
resnet = resnet.to(DEVICE)

# loss + optimizer + scheduler
cls_criterion = nn.CrossEntropyLoss()
cls_optimizer = optim.SGD(resnet.parameters(), lr=0.02, momentum=0.9, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(cls_optimizer, T_max=10)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 242MB/s]


In [17]:
# Cell 15 — Train classifier (fast)
EPOCHS_CLS = 6
best_acc = 0.0
for epoch in range(EPOCHS_CLS):
    # train
    resnet.train()
    running_loss = 0.0
    total = 0
    correct = 0
    for images, labels in train_loader:
        images = images.to(DEVICE); labels = labels.to(DEVICE)
        cls_optimizer.zero_grad()
        out = resnet(images)
        loss = cls_criterion(out, labels)
        loss.backward()
        cls_optimizer.step()
        running_loss += loss.item()
        preds = out.argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    train_acc = correct/total
    scheduler.step()
    # val
    resnet.eval()
    val_loss = 0.0
    total = 0
    correct = 0
    with torch.no_grad():
        for images, labels in val_loader:
            images = images.to(DEVICE); labels=labels.to(DEVICE)
            out = resnet(images)
            loss = cls_criterion(out, labels)
            val_loss += loss.item()
            preds = out.argmax(1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    val_acc = correct/total
    print(f"Epoch {epoch+1}/{EPOCHS_CLS} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f}")
    if val_acc > best_acc:
        best_acc = val_acc
        torch.save(resnet.state_dict(), "/kaggle/working/t_resnet50_best.pth")
print("Best val acc:", best_acc)


Epoch 1/6 | Train Acc: 0.6738 | Val Acc: 0.7384
Epoch 2/6 | Train Acc: 0.7222 | Val Acc: 0.7097
Epoch 3/6 | Train Acc: 0.7458 | Val Acc: 0.7431
Epoch 4/6 | Train Acc: 0.7754 | Val Acc: 0.7348
Epoch 5/6 | Train Acc: 0.7855 | Val Acc: 0.7599
Epoch 6/6 | Train Acc: 0.8100 | Val Acc: 0.7718
Best val acc: 0.7718040621266428


In [18]:
# Cell 16 — Optional: Test-time augmentation (TTA) inference for val set
# example TTA: horizontal flip and center crop (average probs)
tta_transforms = [
    cls_val_transform,
    transforms.Compose([transforms.Resize((IMG_SIZE_CLS, IMG_SIZE_CLS)), transforms.RandomHorizontalFlip(p=1.0), transforms.ToTensor(),
                       transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
]

def tta_predict(model, pil_img, tta_transforms):
    model.eval()
    probs = []
    with torch.no_grad():
        for t in tta_transforms:
            x = t(pil_img).unsqueeze(0).to(DEVICE)
            out = model(x)
            probs.append(torch.softmax(out, dim=1))
    return torch.mean(torch.stack(probs), dim=0)

# quick TTA eval on val set (first 50 images)
model = resnet
correct = 0; total = 0
for i, (p,l) in enumerate(zip(val_paths[:50], val_labels[:50])):
    pil = Image.open(p).convert('RGB')
    prob = tta_predict(model, pil, tta_transforms)
    pred = prob.argmax(1).item()
    if pred == class_map[l]: correct += 1
    total += 1
print("TTA quick accuracy (first 50 val):", correct/total)


TTA quick accuracy (first 50 val): 0.8
