In [1]:
# stgan_kaggle_friendly.py
# Kaggle-friendly STGAN-lite: two-stage GAN with Freeze-D + Barlow Twins (SSL)
# - Lightweight conv generator/discriminator (fits Kaggle GPUs)
# - Stage-1: train unconditional GAN on all images (128x128)
# - Stage-2: for each class, fine-tune from Stage-1 with Freeze-D + Barlow Twins self-supervised loss
# - Generate synthetic images, combine with reals, train T-ResNet50 classifier (TTA optional)
# Usage: run as a script in Kaggle notebook or terminal

import os, shutil, random, math, time
from pathlib import Path
import argparse
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
from torchvision.datasets import ImageFolder
import torchvision.models as models
from torchvision.utils import save_image, make_grid

# ---------------------------
# Config (edit for your environment)
# ---------------------------
DATA_DIR = os.getenv('DATA_DIR', '/kaggle/input/skin-cancer-mnist-ham10000')  # original files
WORK_DIR = os.getenv('WORK_DIR', '/kaggle/working/stgan_lite')
REAL_IMGS_DIR = os.path.join(WORK_DIR, 'real_images')  # per-class folders will be created here
STAGE1_DIR = os.path.join(WORK_DIR, 'stage1')
STAGE2_DIR = os.path.join(WORK_DIR, 'stage2')
SYNTH_DIR = os.path.join(WORK_DIR, 'synth')
COMBINED_DIR = os.path.join(WORK_DIR, 'combined')
CLASS_NAMES = ["akiec","bcc","bkl","df","mel","nv","vasc"]
IMG_SIZE = 128   # small to be Kaggle-friendly
BATCH = 32
LATENT_DIM = 128
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
torch.manual_seed(SEED); random.seed(SEED); np.random.seed(SEED)
os.makedirs(WORK_DIR, exist_ok=True)

# ---------------------------
# Helpers: organize HAM10000 into per-class folders if needed
# ---------------------------
def organize_ham10000(data_dir, out_dir):
    # expects HAM10000 metadata CSV and image files inside data_dir
    import csv
    meta = Path(data_dir) / "HAM10000_metadata.csv"
    if not meta.exists():
        print("Metadata CSV not found in", data_dir)
        return
    outp = Path(out_dir)
    outp.mkdir(parents=True, exist_ok=True)
    for c in CLASS_NAMES:
        (outp/c).mkdir(exist_ok=True)
    # find images (jpg/png)
    for img_path in Path(data_dir).glob("**/*"):
        if img_path.suffix.lower() in ['.jpg','.jpeg','.png']:
            stem = img_path.stem
            # read metadata mapping
            # we'll parse CSV once
    # load mapping
    mapping = {}
    with open(meta, 'r') as f:
        reader = csv.DictReader(f)
        for r in reader:
            mapping[r['image_id']] = r['dx'].lower()
    # copy
    for img_path in Path(data_dir).glob("**/*"):
        if img_path.suffix.lower() in ['.jpg','.jpeg','.png']:
            id = img_path.stem
            if id in mapping and mapping[id] in CLASS_NAMES:
                dst = Path(out_dir)/mapping[id]/img_path.name
                if not dst.exists():
                    shutil.copy(img_path, dst)
    print("Organized images under", out_dir)

# ---------------------------
# Dataset
# ---------------------------
class ImgFolderDataset(Dataset):
    def __init__(self, folder, transform=None):
        self.paths = list(Path(folder).rglob("*.jpg")) + list(Path(folder).rglob("*.png"))
        self.transform = transform
    def __len__(self): return len(self.paths)
    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        if self.transform: img = self.transform(img)
        return img

# ---------------------------
# Small generator & discriminator (GAN-lite)
# ---------------------------
def conv_block(in_ch, out_ch, kernel=3, stride=1, padding=1, activation=True):
    layers = [nn.Conv2d(in_ch, out_ch, kernel, stride, padding)]
    layers.append(nn.BatchNorm2d(out_ch))
    if activation: layers.append(nn.LeakyReLU(0.2, inplace=True))
    return nn.Sequential(*layers)

class SimpleGenerator(nn.Module):
    def __init__(self, latent_dim=128, out_channels=3, fmap=64, img_size=128):
        super().__init__()
        self.init_size = img_size // 16  # we will upsample 4x (x2 x2 x2 x2)
        self.l1 = nn.Linear(latent_dim, fmap*8 * self.init_size * self.init_size)
        self.conv_blocks = nn.Sequential(
            conv_block(fmap*8, fmap*8),
            nn.Upsample(scale_factor=2),
            conv_block(fmap*8, fmap*4),
            nn.Upsample(scale_factor=2),
            conv_block(fmap*4, fmap*2),
            nn.Upsample(scale_factor=2),
            conv_block(fmap*2, fmap),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(fmap, out_channels, 3, 1, 1),
            nn.Tanh()
        )
    def forward(self, z):
        out = self.l1(z).view(z.size(0), -1, self.init_size, self.init_size)
        return self.conv_blocks(out)

class SimpleDiscriminator(nn.Module):
    def __init__(self, in_channels=3, fmap=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, fmap, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            conv_block(fmap, fmap*2),
            nn.AvgPool2d(2),
            conv_block(fmap*2, fmap*4),
            nn.AvgPool2d(2),
            conv_block(fmap*4, fmap*8),
            nn.AdaptiveAvgPool2d(1),
        )
        self.fc = nn.Linear(fmap*8, 1)
        # small projector for Barlow Twins (optional)
        self.projector = nn.Sequential(nn.Linear(fmap*8, 256), nn.ReLU(), nn.Linear(256,128))

    def forward(self, x, return_feat=False):
        f = self.model(x).view(x.size(0), -1)
        logits = self.fc(f)
        if return_feat:
            return logits.view(-1,1), f, self.projector(f)
        return logits.view(-1,1)

# ---------------------------
# Barlow Twins loss (simple)
# ---------------------------
def barlow_twins_loss(z_a, z_b, lam_offdiag=0.005):
    # z_a, z_b: [B, D] already projected; normalize feature dims
    B, D = z_a.size()
    # normalize batchwise per-dim
    za = (z_a - z_a.mean(0)) / (z_a.std(0) + 1e-9)
    zb = (z_b - z_b.mean(0)) / (z_b.std(0) + 1e-9)
    C = (za.T @ zb) / B  # D x D
    on_diag = torch.diagonal(C).add_(-1).pow(2).sum()
    off_diag = (C.pow(2).sum() - torch.diagonal(C).pow(2).sum())
    return on_diag + lam_offdiag * off_diag

# ---------------------------
# Augmentations
# ---------------------------
def get_gan_transforms(img_size):
    tf = T.Compose([
        T.Resize((img_size, img_size)),
        T.RandomHorizontalFlip(),
        T.RandomRotation(10),
        T.ColorJitter(0.1,0.1,0.1,0.05),
        T.ToTensor(),
        T.Normalize([0.5]*3, [0.5]*3)
    ])
    return tf

# ---------------------------
# Training utilities
# ---------------------------
bce_loss = nn.BCEWithLogitsLoss()

def train_gan_stage1(generator, discriminator, dataloader, epochs=6, lr=2e-4, save_dir=STAGE1_DIR):
    device = DEVICE
    G, D = generator.to(device), discriminator.to(device)
    optG = optim.Adam(G.parameters(), lr=lr, betas=(0.5,0.999))
    optD = optim.Adam(D.parameters(), lr=lr, betas=(0.5,0.999))
    os.makedirs(save_dir, exist_ok=True)
    for epoch in range(epochs):
        G.train(); D.train()
        for real in dataloader:
            real = real.to(device)
            bs = real.size(0)
            # Discriminator step
            z = torch.randn(bs, LATENT_DIM, device=device)
            fake = G(z)
            logits_real = D(real)
            logits_fake = D(fake.detach())
            lossD = bce_loss(logits_real, torch.ones_like(logits_real)) + bce_loss(logits_fake, torch.zeros_like(logits_fake))
            optD.zero_grad(); lossD.backward(); optD.step()
            # Generator step
            logits_fake2 = D(fake)
            lossG = bce_loss(logits_fake2, torch.ones_like(logits_fake2))
            optG.zero_grad(); lossG.backward(); optG.step()
        print(f"[Stage1] Epoch {epoch+1}/{epochs} | D_loss: {lossD.item():.4f} | G_loss: {lossG.item():.4f}")
        # save snapshots
        torch.save(G.state_dict(), os.path.join(save_dir, f"G_epoch{epoch+1}.pth"))
        torch.save(D.state_dict(), os.path.join(save_dir, f"D_epoch{epoch+1}.pth"))
    return G, D

def freeze_discriminator_top_layers(D: nn.Module, num_top=1):
    # heuristic: freeze last N layers of the sequential blocks (we inspect named_parameters order)
    params = list(D.named_parameters())
    if num_top <= 0: return
    # freeze last 20% parameters per 'num_top' unit; simpler: freeze last num_top parameter groups
    for name, p in params[-(num_top*3):]:
        p.requires_grad = False

def train_gan_stage2_per_class(class_name, class_folder, G_init, D_init, epochs=4, lr=1e-4,
                               freeze_d_num_layers=2, barlow_lambda=1.0, batch=BATCH, save_root=STAGE2_DIR):
    """
    Fine-tune from global G_init/D_init on class_folder.
    Adds Freeze-D and Barlow Twins SSL (on real images in batch).
    """
    os.makedirs(save_root, exist_ok=True)
    # dataset/dataloader
    tf = get_gan_transforms(IMG_SIZE)
    ds = ImgFolderDataset(class_folder, transform=tf)
    if len(ds) < 4:
        print(f"Class {class_name} has {len(ds)} images; skipping stage2 fine-tune (too few).")
        return None, None
    dl = DataLoader(ds, batch_size=batch, shuffle=True, drop_last=True)
    # models
    G = SimpleGenerator(LATENT_DIM, img_size=IMG_SIZE).to(DEVICE)
    D = SimpleDiscriminator().to(DEVICE)
    G.load_state_dict(torch.load(G_init, map_location=DEVICE))
    D.load_state_dict(torch.load(D_init, map_location=DEVICE))
    optG = optim.Adam(G.parameters(), lr=lr, betas=(0.5,0.999))
    optD = optim.Adam([p for p in D.parameters() if p.requires_grad], lr=lr*0.5, betas=(0.5,0.999))
    # Freeze-D
    freeze_discriminator_top_layers(D, freeze_d_num_layers)
    print(f"[Stage2:{class_name}] freeze top layers applied; trainable params:", sum(p.requires_grad for p in D.parameters()))
    for epoch in range(epochs):
        G.train(); D.train()
        for real in dl:
            real = real.to(DEVICE); bs = real.size(0)
            # Discriminator standard GAN loss
            z = torch.randn(bs, LATENT_DIM, device=DEVICE)
            fake = G(z)
            logits_real, _, _ = D(real, return_feat=True)
            logits_fake, _, _ = D(fake.detach(), return_feat=True)
            lossD = bce_loss(logits_real, torch.ones_like(logits_real)) + bce_loss(logits_fake, torch.zeros_like(logits_fake))
            # Barlow Twins SSL on real images (two augmented views)
            # create two augmentations: using light jitter transforms
            aug = T.Compose([T.Resize((IMG_SIZE,IMG_SIZE)), T.RandomHorizontalFlip(), T.ColorJitter(0.1,0.1,0.1,0.05), T.ToTensor(), T.Normalize([0.5]*3,[0.5]*3)])
            # build two views by reloading original PIL images
            # NOTE: our real is already a transformed tensor; for simplicity we jitter tensor by small transforms:
            # create view_a/view_b by small pixel noise + flip
            view_a = real * (1 + 0.02*torch.randn_like(real))
            view_b = real.flip(-1) if random.random() < 0.5 else real * (1 + 0.02*torch.randn_like(real))
            _, feat_a, proj_a = D(view_a, return_feat=True)
            _, feat_b, proj_b = D(view_b, return_feat=True)
            # compute barlow twins on projectors
            L_bt = barlow_twins_loss(proj_a, proj_b)
            lossD_total = lossD + barlow_lambda * L_bt
            optD.zero_grad(); lossD_total.backward(); optD.step()
            # Generator step
            z2 = torch.randn(bs, LATENT_DIM, device=DEVICE)
            fake2 = G(z2)
            logits_fake2 = D(fake2)[0] if isinstance(D(fake2), tuple) else D(fake2)
            lossG = bce_loss(logits_fake2, torch.ones_like(logits_fake2))
            optG.zero_grad(); lossG.backward(); optG.step()
        print(f"[Stage2:{class_name}] Epoch {epoch+1}/{epochs} | D_loss: {lossD.item():.4f} | BT_loss: {L_bt.item():.4f} | G_loss: {lossG.item():.4f}")
        # save checkpoint per epoch
        torch.save(G.state_dict(), os.path.join(save_root, f"{class_name}_G_epoch{epoch+1}.pth"))
        torch.save(D.state_dict(), os.path.join(save_root, f"{class_name}_D_epoch{epoch+1}.pth"))
    return G, D

# ---------------------------
# Synthesis helper
# ---------------------------
def generate_synthetic_images(G_checkpoint, out_folder, n_images=200, truncation=0.7):
    os.makedirs(out_folder, exist_ok=True)
    G = SimpleGenerator(LATENT_DIM, img_size=IMG_SIZE).to(DEVICE)
    G.load_state_dict(torch.load(G_checkpoint, map_location=DEVICE))
    G.eval()
    batch = 16
    idx = 0
    with torch.no_grad():
        for start in range(0, n_images, batch):
            bs = min(batch, n_images - start)
            z = torch.randn(bs, LATENT_DIM, device=DEVICE)
            fake = G(z)
            fake = (fake + 1) / 2.0  # to [0,1]
            for i in range(bs):
                save_image(fake[i], os.path.join(out_folder, f"{idx:05d}.png"))
                idx += 1
    print("Saved", idx, "images to", out_folder)

# ---------------------------
# Combine real + synth and train T-ResNet50 classifier
# ---------------------------
class TResNet50(nn.Module):
    def __init__(self, num_classes=len(CLASS_NAMES)):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        in_f = resnet.fc.in_features
        resnet.fc = nn.Sequential(nn.Linear(in_f,128), nn.ReLU(), nn.Dropout(0.3), nn.Linear(128,num_classes))
        self.model = resnet
    def forward(self,x): return self.model(x)

def assemble_combined(real_dir, synth_root, out_dir):
    os.makedirs(out_dir, exist_ok=True)
    for cls in CLASS_NAMES:
        src_real = Path(real_dir)/cls
        t = Path(out_dir)/cls
        t.mkdir(parents=True, exist_ok=True)
        if src_real.exists():
            for f in src_real.glob('*'):
                shutil.copy(f, t/f.name)
        synth = Path(synth_root)/cls
        if synth.exists():
            for f in synth.glob('*'):
                shutil.copy(f, t/f.name)
    print("Combined dataset at", out_dir)

def train_classifier(combined_dir, epochs=6, bs=32, lr=1e-4):
    train_tf = T.Compose([T.Resize((224,224)), T.RandomHorizontalFlip(), T.ToTensor(), T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
    val_tf = T.Compose([T.Resize((224,224)), T.ToTensor(), T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])])
    ds = ImageFolder(combined_dir, transform=train_tf)
    total = len(ds)
    train_sz = int(0.8*total); val_sz = int(0.1*total); test_sz = total - train_sz - val_sz
    train_ds, val_ds, test_ds = random_split(ds, [train_sz,val_sz,test_sz])
    val_ds.dataset.transform = val_tf; test_ds.dataset.transform = val_tf
    tr = DataLoader(train_ds, batch_size=bs, shuffle=True, num_workers=2)
    vl = DataLoader(val_ds, batch_size=bs, shuffle=False, num_workers=2)
    model = TResNet50().to(DEVICE)
    opt = optim.Adam(model.parameters(), lr=lr)
    crit = nn.CrossEntropyLoss()
    best = 0.0
    for e in range(epochs):
        model.train(); running_loss=0; corr=0; tot=0
        for imgs, labels in tr:
            imgs,labels = imgs.to(DEVICE), labels.to(DEVICE)
            opt.zero_grad(); outs = model(imgs); loss = crit(outs, labels); loss.backward(); opt.step()
            running_loss += loss.item()*imgs.size(0)
            preds = outs.argmax(1); corr += (preds==labels).sum().item(); tot += labels.size(0)
        train_acc = corr/tot
        # val
        model.eval(); vcorr=0; vtot=0
        with torch.no_grad():
            for imgs, labels in vl:
                imgs,labels = imgs.to(DEVICE), labels.to(DEVICE)
                outs = model(imgs); preds = outs.argmax(1)
                vcorr += (preds==labels).sum().item(); vtot += labels.size(0)
        vacc = vcorr/vtot
        print(f"[Classifier] Epoch {e+1}/{epochs} train_acc={train_acc:.4f} val_acc={vacc:.4f}")
        if vacc > best:
            best = vacc; torch.save(model.state_dict(), os.path.join(WORK_DIR, 't_resnet_best.pth'))
    print("Classifier training finished. best val acc:", best)
    return os.path.join(WORK_DIR, 't_resnet_best.pth')

# ---------------------------
# Orchestration / main
# ---------------------------
def main_pipeline():
    # 0) organize HAM10000 into per-class folders if not already
    if not Path(REAL_IMGS_DIR).exists() or not any(Path(REAL_IMGS_DIR).iterdir()):
        print("Organizing HAM10000 into per-class folders...")
        organize_ham10000(DATA_DIR, REAL_IMGS_DIR)

    # 1) Create Stage1 dataset (all images) and dataloader
    tf = get_gan_transforms(IMG_SIZE)
    ds_all = ImgFolderDataset(REAL_IMGS_DIR, transform=tf)
    dl_all = DataLoader(ds_all, batch_size=BATCH, shuffle=True, drop_last=True)
    print("Stage1 dataset size:", len(ds_all))

    # 2) instantiate model and train Stage1 (lightweight)
    G = SimpleGenerator(LATENT_DIM, img_size=IMG_SIZE)
    D = SimpleDiscriminator()
    G, D = train_gan_stage1(G, D, dl_all, epochs=4, lr=2e-4, save_dir=STAGE1_DIR)

    # save latest checkpoints
    # choose last saved epoch files
    lastG = sorted(Path(STAGE1_DIR).glob("G_epoch*.pth"))[-1]
    lastD = sorted(Path(STAGE1_DIR).glob("D_epoch*.pth"))[-1]

    # 3) Stage-2 per-class fine-tune (Freeze-D + Barlow)
    for cls in CLASS_NAMES:
        class_folder = os.path.join(REAL_IMGS_DIR, cls)
        if not Path(class_folder).exists() or len(list(Path(class_folder).glob('*'))) < 10:
            print(f"Skipping Stage-2 for {cls}: too few images.")
            continue
        outroot = os.path.join(STAGE2_DIR, cls)
        os.makedirs(outroot, exist_ok=True)
        # fine-tune
        Gc, Dc = train_gan_stage2_per_class(cls, class_folder, str(lastG), str(lastD),
                                            epochs=3, lr=1e-4, freeze_d_num_layers=2, barlow_lambda=0.5,
                                            batch=min(BATCH, max(4,len(list(Path(class_folder).glob('*'))))))
        # if fine-tuning succeeded, generate synthetic
        if Gc is not None:
            # use final saved file
            chk = sorted(Path(outroot).glob(f"{cls}_G_epoch*.pth"))
            if chk:
                generate_synthetic_images(str(chk[-1]), os.path.join(SYNTH_DIR, cls), n_images=200)

    # 4) For classes where stage2 didn't run, optionally generate from stage1 GAN (to keep balance)
    for cls in CLASS_NAMES:
        sdir = Path(SYNTH_DIR)/cls
        if not sdir.exists() or len(list(sdir.glob('*'))) < 100:
            print(f"Generating fallback synth for {cls} from stage1 generator")
            generate_synthetic_images(str(lastG), os.path.join(SYNTH_DIR, cls), n_images=200)

    # 5) Combine and train classifier
    assemble_combined(REAL_IMGS_DIR, SYNTH_DIR, COMBINED_DIR)
    ckpt = train_classifier(COMBINED_DIR, epochs=6, bs=32, lr=1e-4)
    print("Pipeline complete. classifier ckpt:", ckpt)

if __name__ == "__main__":
    main_pipeline()


Organizing HAM10000 into per-class folders...
Organized images under /kaggle/working/stgan_lite/real_images
Stage1 dataset size: 10015
[Stage1] Epoch 1/4 | D_loss: 0.6897 | G_loss: 1.8895
[Stage1] Epoch 2/4 | D_loss: 0.8267 | G_loss: 1.8509
[Stage1] Epoch 3/4 | D_loss: 1.0526 | G_loss: 1.1082
[Stage1] Epoch 4/4 | D_loss: 0.2145 | G_loss: 1.9753
[Stage2:akiec] freeze top layers applied; trainable params: 14
[Stage2:akiec] Epoch 1/3 | D_loss: 0.5193 | BT_loss: 7.8907 | G_loss: 1.4636
[Stage2:akiec] Epoch 2/3 | D_loss: 0.6158 | BT_loss: 7.6159 | G_loss: 1.1701
[Stage2:akiec] Epoch 3/3 | D_loss: 0.5248 | BT_loss: 5.9850 | G_loss: 1.5373
[Stage2:bcc] freeze top layers applied; trainable params: 14
[Stage2:bcc] Epoch 1/3 | D_loss: 0.4821 | BT_loss: 6.7404 | G_loss: 1.6699
[Stage2:bcc] Epoch 2/3 | D_loss: 0.3657 | BT_loss: 6.3744 | G_loss: 1.8138
[Stage2:bcc] Epoch 3/3 | D_loss: 0.4785 | BT_loss: 5.7535 | G_loss: 1.2733
[Stage2:bkl] freeze top layers applied; trainable params: 14
[Stage2:bkl]

Downloading: "https://download.pytorch.org/models/resnet50-11ad3fa6.pth" to /root/.cache/torch/hub/checkpoints/resnet50-11ad3fa6.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 219MB/s]


[Classifier] Epoch 1/6 train_acc=0.6424 val_acc=0.7327
[Classifier] Epoch 2/6 train_acc=0.7616 val_acc=0.7520
[Classifier] Epoch 3/6 train_acc=0.8192 val_acc=0.7415
[Classifier] Epoch 4/6 train_acc=0.8611 val_acc=0.7572
[Classifier] Epoch 5/6 train_acc=0.8777 val_acc=0.7634
[Classifier] Epoch 6/6 train_acc=0.8872 val_acc=0.7677
Classifier training finished. best val acc: 0.7677475898334793
Pipeline complete. classifier ckpt: /kaggle/working/stgan_lite/t_resnet_best.pth


In [2]:
from sklearn.metrics import (
    confusion_matrix, classification_report, roc_auc_score,
    roc_curve, auc, precision_recall_curve
)
import itertools


In [4]:
# ==========================================
# VISUALIZATION & METRIC UTILITIES
# ==========================================

def show_random_real_images(folder_path, class_names, n=4):
    plt.figure(figsize=(10, 6))
    for i, cls in enumerate(class_names):
        img_paths = list((Path(folder_path) / cls).glob("*"))
        for j in range(n):
            plt.subplot(len(class_names), n, i*n + j + 1)
            img = Image.open(random.choice(img_paths))
            plt.imshow(img)
            plt.axis("off")
            if j == 0:
                plt.ylabel(cls, fontsize=12)
    plt.suptitle("Random Real Images Per Class")
    plt.tight_layout()
    plt.show()


def show_random_synthetic_images(folder_path, class_names, n=4):
    plt.figure(figsize=(10, 6))
    for i, cls in enumerate(class_names):
        img_paths = list((Path(folder_path) / cls).glob("*"))
        for j in range(n):
            plt.subplot(len(class_names), n, i*n + j + 1)
            img = Image.open(random.choice(img_paths))
            plt.imshow(img)
            plt.axis("off")
            if j == 0:
                plt.ylabel(cls, fontsize=12)
    plt.suptitle("Random Synthetic Images Per Class")
    plt.tight_layout()
    plt.show()


def plot_confusion_matrix(cm, classes):
    plt.figure(figsize=(8,6))
    sns.heatmap(cm, annot=True, fmt='d',
                xticklabels=classes,
                yticklabels=classes, cmap="Blues")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()


def plot_roc_curves(y_true, y_pred_proba, classes):
    plt.figure(figsize=(8,6))
    for i, cls in enumerate(classes):
        fpr, tpr, _ = roc_curve((y_true == i).astype(int), y_pred_proba[:, i])
        auc_score = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f"{cls} (AUC={auc_score:.3f})")
    plt.plot([0,1],[0,1],'--',color='gray')
    plt.title("ROC Curves")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.legend()
    plt.show()


def plot_precision_recall(y_true, y_pred_proba, classes):
    plt.figure(figsize=(8,6))
    for i, cls in enumerate(classes):
        pr, rc, _ = precision_recall_curve((y_true == i).astype(int), y_pred_proba[:, i])
        plt.plot(rc, pr, label=cls)
    plt.title("Precision–Recall Curves")
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.legend()
    plt.show()
