
# ITFFC — Task 2: Conditional GAN (ACGAN) + Synthetic Augmentation

This notebook trains a **conditional GAN (ACGAN)** on your processed **medical** images only and generates synthetic images to **correct class imbalance**.

- **Input root (processed):** `C:\Users\bacht\Desktop\Master2_S1\ITFFC\Dataset\medical`
- It automatically discovers every `*_processed` leaf folder and treats each as a class.
- Trains on the GPU (CUDA) if available (e.g., an RTX 3060), using **automatic mixed precision (AMP)** for speed.
- After training (**200 epochs** by default), it generates samples **per class** to reach the count of the largest class.
- Synthetic images are saved under:  
  `...\Dataset\medical_synthetic\<class>\*.png`

> Model: **ACGAN** (Discriminator outputs real/fake + class logits; Generator conditioned on class labels).  
> Image size default: **256×256**, channels: **3**, latent dim: **128**.


In [None]:

# =======================================
# Configuration
# =======================================
from pathlib import Path

# Root folder that holds the *_processed class directories.
DATASET_MEDICAL_ROOT = r"C:\Users\bacht\Desktop\Master2_S1\ITFFC\Dataset\medical"

# --- Image / model geometry ---
image_size = 256  # the Generator/Discriminator below are built for 256x256
channels = 3
latent_dim = 128

# --- Optimisation ---
batch_size = 32
num_epochs = 200
lr_G = 2e-4
lr_D = 2e-4
beta1 = 0.5
beta2 = 0.999
num_workers = 4

# --- Previews / balancing ---
samples_per_grid = 16      # images in each preview grid
save_preview_every = 5     # epochs between preview + checkpoint saves
balance_to_largest = True  # top every class up to the size of the largest one

# --- Output locations (siblings of the dataset root) ---
OUTPUT_ROOT, CHECKPOINTS, PREVIEWS = (
    Path(DATASET_MEDICAL_ROOT).parent / sub
    for sub in ("medical_synthetic", "gan_ckpts", "gan_previews")
)

print("Synthetic out:", OUTPUT_ROOT)
print("Checkpoints  :", CHECKPOINTS)
print("Previews     :", PREVIEWS)



### Optional: Install libraries (run only if needed)
If your environment is missing packages, uncomment and run the cell below.


In [None]:

# !pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
# !pip install pillow tqdm matplotlib imagehash


In [None]:

# =======================================
# Imports
# =======================================
import os, sys, math, json, random, time
from pathlib import Path
from collections import Counter

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torchvision
from torchvision import transforms
from torchvision.utils import make_grid, save_image

from PIL import Image
import numpy as np
import matplotlib.pyplot as plt

from tqdm.auto import tqdm

# Prefer the GPU when PyTorch can see one; everything below moves tensors to `device`.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())


In [None]:

# =======================================
# Data utilities
# =======================================
def find_leaf_processed_dirs(root: Path):
    """Return, sorted, every directory under ``root`` that is a usable class folder.

    A directory qualifies when its name ends with ``_processed``, it has no
    sub-directories, and it directly contains at least one file with an image
    suffix (.png/.jpg/.jpeg/.bmp/.tif/.tiff, case-insensitive).
    """
    image_suffixes = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"}

    def _is_leaf_with_images(folder):
        children = list(folder.glob("*"))
        if any(child.is_dir() for child in children):
            return False
        return any(
            child.is_file() and child.suffix.lower() in image_suffixes
            for child in children
        )

    candidates = (
        p for p in Path(root).rglob("*")
        if p.is_dir() and p.name.endswith("_processed")
    )
    return sorted(p for p in candidates if _is_leaf_with_images(p))

def class_name_from_processed_dir(d: Path):
    """Return the class label for folder ``d``: its name minus a trailing ``_processed``."""
    suffix = "_processed"
    base = d.name
    return base[: -len(suffix)] if base.endswith(suffix) else base

processed_dirs = find_leaf_processed_dirs(Path(DATASET_MEDICAL_ROOT))
if not processed_dirs:
    # Explicit raise instead of assert: asserts are stripped under `python -O`.
    raise FileNotFoundError("No *_processed leaf directories found under medical/")

# Several *_processed folders may share a name (e.g. the same class split across
# sub-trees). They describe one class, so merge them under a single label; without
# this, num_classes would count duplicates and class_to_idx would skip an index.
dir_class_names = [class_name_from_processed_dir(d) for d in processed_dirs]
class_names = list(dict.fromkeys(dir_class_names))  # dedupe, first-seen order kept
num_classes = len(class_names)

print(f"Found {num_classes} classes:")
for d, cname in zip(processed_dirs, dir_class_names):
    print(" -", cname, "from", d)

class_to_idx = {c: i for i, c in enumerate(class_names)}
idx_to_class = {i: c for c, i in class_to_idx.items()}

class ProcessedImagesDataset(Dataset):
    """Labelled images read from ``*_processed`` leaf folders.

    Every file directly inside each folder whose suffix (case-insensitive) is in
    ``IMAGE_EXTS`` becomes a sample labelled ``class_to_idx[<folder name without
    the "_processed" suffix>]``. Items are ``(tensor, label)`` where the tensor
    is resized to (image_size, image_size) and normalised from [0, 1] to [-1, 1].

    Args:
        processed_dirs: iterable of folder paths (str or Path).
        class_to_idx: mapping from class name to integer label.
        image_size: output height and width after resizing.
        channels: 1 loads images as grayscale ("L"), 4 as "RGBA", anything else as "RGB".
        augment: apply a random horizontal flip and a random rotation of up to
            5 degrees before resizing.
    """

    # Same suffixes find_leaf_processed_dirs accepts when deciding that a folder
    # holds images, so no discovered class ends up with zero samples.
    IMAGE_EXTS = frozenset({".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff"})
    _SUFFIX = "_processed"
    _PIL_MODES = {1: "L", 4: "RGBA"}

    def __init__(self, processed_dirs, class_to_idx, image_size=256, channels=3, augment=True):
        self.samples = []
        self.class_to_idx = class_to_idx
        self.channels = channels
        # PIL mode whose channel count matches `channels`, so the tensors agree
        # with the Normalize below and with the generator's output channels.
        self.pil_mode = self._PIL_MODES.get(channels, "RGB")

        for d in map(Path, processed_dirs):
            cname = d.name[:-len(self._SUFFIX)] if d.name.endswith(self._SUFFIX) else d.name
            y = class_to_idx[cname]
            # sorted(): deterministic sample order regardless of directory listing order.
            for imgp in sorted(d.iterdir()):
                if imgp.is_file() and imgp.suffix.lower() in self.IMAGE_EXTS:
                    self.samples.append((imgp, y))

        aug_list = []
        if augment:
            aug_list += [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(degrees=5),
            ]
        aug_list += [
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5]*channels, [0.5]*channels),
        ]
        self.transform = transforms.Compose(aug_list)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        p, y = self.samples[idx]
        # Context manager releases the file handle promptly (file locks on Windows).
        with Image.open(p) as img:
            x = self.transform(img.convert(self.pil_mode))
        return x, y

dataset = ProcessedImagesDataset(processed_dirs, class_to_idx, image_size=image_size, channels=channels, augment=True)
print("Dataset size:", len(dataset))
if len(dataset) == 0:
    # DataLoader(shuffle=True) would otherwise fail with a less helpful message.
    raise RuntimeError("No images found in the discovered *_processed folders.")

counts = Counter([y for _, y in dataset.samples])
for k, v in sorted(counts.items()):  # print in label order
    print(f"{idx_to_class[k]}: {v}")

# With the "spawn" start method (the Windows default), DataLoader workers must
# unpickle ProcessedImagesDataset from __main__, which typically fails for classes
# defined inside a Jupyter notebook; load in-process on Windows instead.
loader_workers = 0 if os.name == "nt" else num_workers
loader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=loader_workers,
    # Pinned host memory only speeds up host->GPU copies; without CUDA it just warns.
    pin_memory=torch.cuda.is_available(),
    # Reuse the worker processes across the many epochs instead of re-spawning them.
    persistent_workers=loader_workers > 0,
)


In [None]:

# =======================================
# ACGAN models
# =======================================
class GenBlock(nn.Module):
    """Upsampling unit: transposed conv -> BatchNorm -> ReLU.

    A 4x4 kernel with stride 2 and padding 1 doubles height and width. The conv
    has no bias because the following BatchNorm supplies the shift.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        upsample = nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)
        layers = [upsample, nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True)]
        # Layers live inside `net` so checkpoint keys stay net.0.*, net.1.*.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out

class DisBlock(nn.Module):
    """Downsampling unit: stride-2 conv -> optional BatchNorm -> LeakyReLU(0.2).

    Args:
        in_ch: input channels.
        out_ch: output channels.
        bn: insert a BatchNorm2d after the conv (the Discriminator turns it off
            for its first block).
    """

    def __init__(self, in_ch, out_ch, bn=True):
        super().__init__()
        conv = nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)
        norm = [nn.BatchNorm2d(out_ch)] if bn else []
        self.net = nn.Sequential(conv, *norm, nn.LeakyReLU(negative_slope=0.2, inplace=True))

    def forward(self, x):
        return self.net(x)

class Generator(nn.Module):
    """Class-conditional generator mapping (z, label) to a 256x256 image in [-1, 1].

    The label is one-hot encoded and concatenated to ``z``. A linear layer
    projects that vector to a ``base*16 x 4 x 4`` tensor, which five GenBlocks
    and a final transposed conv upsample 4 -> 8 -> ... -> 256; Tanh bounds the
    output. The output resolution is fixed by this topology.

    Args:
        latent_dim: size of the noise vector ``z``.
        num_classes: number of class labels (width of the one-hot code).
        img_channels: channels of the generated image.
        base: channel-width multiplier.
    """

    def __init__(self, latent_dim, num_classes, img_channels=3, base=64):
        super().__init__()
        in_features = latent_dim + num_classes
        self.fc = nn.Sequential(nn.Linear(in_features, base * 16 * 4 * 4), nn.ReLU(inplace=True))

        widths = [base * 16, base * 8, base * 4, base * 2, base, base // 2]
        # 8x8, 16x16, 32x32, 64x64, 128x128
        ups = [GenBlock(c_in, c_out) for c_in, c_out in zip(widths[:-1], widths[1:])]
        # 256x256
        head = [nn.ConvTranspose2d(widths[-1], img_channels, 4, 2, 1, bias=False), nn.Tanh()]
        self.net = nn.Sequential(*ups, *head)
        self.num_classes = num_classes

    def forward(self, z, y):
        cond = F.one_hot(y, num_classes=self.num_classes).float()
        seed = self.fc(torch.cat((z, cond), dim=1))
        return self.net(seed.view(z.size(0), -1, 4, 4))

class Discriminator(nn.Module):
    """ACGAN discriminator: a shared conv trunk with a real/fake head and a class head.

    Five DisBlocks plus one more stride-2 conv reduce a 256x256 input to a 4x4
    feature map. Each head is a 4x4 valid conv that collapses it to 1x1, yielding
    one real/fake logit and one logit per class.

    Returns (from forward):
        (adv, cls) with shapes (B, 1) and (B, num_classes) for 256x256 inputs.
    """

    def __init__(self, num_classes, img_channels=3, base=64):
        super().__init__()
        widths = [img_channels, base, base * 2, base * 4, base * 8, base * 16]
        # The first block skips BatchNorm; the rest use it.  256 -> 8
        blocks = [
            DisBlock(c_in, c_out, bn=(idx > 0))
            for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]))
        ]
        self.feature = nn.Sequential(
            *blocks,
            nn.Conv2d(base * 16, base * 16, 4, 2, 1, bias=False),  # 8 -> 4
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.adv_head = nn.Conv2d(base * 16, 1, 4, 1, 0, bias=False)
        self.cls_head = nn.Conv2d(base * 16, num_classes, 4, 1, 0, bias=False)

    def forward(self, x):
        batch = x.size(0)
        feats = self.feature(x)
        adv_logit = self.adv_head(feats).view(batch, 1)
        cls_logits = self.cls_head(feats).view(batch, -1)
        return adv_logit, cls_logits


In [None]:

# =======================================
# Training loop with AMP
# =======================================
# Input resolution never changes, so let cuDNN benchmark and cache the fastest kernels.
torch.backends.cudnn.benchmark = True

# Build the models before drawing fixed_z: their weight init consumes the same RNG.
G = Generator(latent_dim, num_classes, img_channels=channels).to(device)
D = Discriminator(num_classes, img_channels=channels).to(device)

optG = torch.optim.Adam(G.parameters(), lr=lr_G, betas=(beta1, beta2))
optD = torch.optim.Adam(D.parameters(), lr=lr_D, betas=(beta1, beta2))

adv_loss = nn.BCEWithLogitsLoss().to(device)  # real/fake head emits raw logits
cls_loss = nn.CrossEntropyLoss().to(device)   # auxiliary class head
# Loss scaling is only enabled when running on CUDA.
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

CHECKPOINTS.mkdir(parents=True, exist_ok=True)
PREVIEWS.mkdir(parents=True, exist_ok=True)

# Same noise and labels (cycling 0..num_classes-1) for every preview grid.
fixed_z = torch.randn(samples_per_grid, latent_dim, device=device)
fixed_labels = torch.tensor([i % num_classes for i in range(samples_per_grid)], device=device)

def preview(epoch):
    """Save a grid of G(fixed_z, fixed_labels) to PREVIEWS/epoch_<epoch:04d>.png.

    The fixed noise/labels make successive previews directly comparable. G is
    switched to eval mode for the forward pass and then restored to whatever
    mode it was in before the call. The original code forced train mode, which
    silently flipped an eval-mode G.
    """
    was_training = G.training
    G.eval()
    try:
        with torch.no_grad():
            imgs = (G(fixed_z, fixed_labels) + 1) * 0.5  # tanh range [-1, 1] -> [0, 1]
            nrow = max(1, int(math.sqrt(samples_per_grid)))  # make_grid needs nrow >= 1
            grid = make_grid(imgs.clamp(0, 1), nrow=nrow)
            save_image(grid, PREVIEWS / f"epoch_{epoch:04d}.png")
    finally:
        G.train(was_training)

def _save_checkpoint(epoch):
    """Write CHECKPOINTS/acgan_epoch_<epoch:04d>.pt.

    Stores the G/D weights, class mapping and config. It also stores the
    optimizer and GradScaler states, so training can be resumed rather than
    restarted.
    """
    torch.save({
        "G": G.state_dict(),
        "D": D.state_dict(),
        "optG": optG.state_dict(),
        "optD": optD.state_dict(),
        "scaler": scaler.state_dict(),
        "epoch": epoch,
        "class_to_idx": class_to_idx,
        "config": {
            "latent_dim": latent_dim, "image_size": image_size,
            "channels": channels, "num_classes": num_classes
        }
    }, CHECKPOINTS / f"acgan_epoch_{epoch:04d}.pt")


use_amp = torch.cuda.is_available()

for epoch in range(1, num_epochs + 1):
    pbar = tqdm(loader, desc=f"Epoch {epoch}/{num_epochs}", leave=False)
    for real, y in pbar:
        real = real.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        bsz = real.size(0)

        # ---- Train D: real -> 1 plus class loss, fake -> 0 ----
        z = torch.randn(bsz, latent_dim, device=device)
        with torch.cuda.amp.autocast(enabled=use_amp):
            fake = G(z, y).detach()  # no gradient flows into G during the D step
            adv_real, cls_real = D(real)
            d_real = adv_loss(adv_real, torch.ones_like(adv_real))
            c_real = cls_loss(cls_real, y)

            adv_fake, _ = D(fake)
            d_fake = adv_loss(adv_fake, torch.zeros_like(adv_fake))

            d_loss = d_real + d_fake + c_real

        optD.zero_grad(set_to_none=True)
        scaler.scale(d_loss).backward()
        scaler.step(optD)

        # ---- Train G: fool D and produce the requested class ----
        z = torch.randn(bsz, latent_dim, device=device)
        y_fake = torch.randint(0, num_classes, (bsz,), device=device)
        # D's weight gradients from this pass would be discarded anyway (optD.zero_grad
        # runs before the next D backward), so skip computing them. Gradients still
        # flow through D into G because `gen` requires grad.
        D.requires_grad_(False)
        with torch.cuda.amp.autocast(enabled=use_amp):
            gen = G(z, y_fake)
            adv_gen, cls_gen = D(gen)
            g_adv = adv_loss(adv_gen, torch.ones_like(adv_gen))
            g_cls = cls_loss(cls_gen, y_fake)
            g_loss = g_adv + g_cls

        optG.zero_grad(set_to_none=True)
        scaler.scale(g_loss).backward()
        D.requires_grad_(True)
        scaler.step(optG)
        # A single scaler update per iteration, after both optimizers have stepped.
        scaler.update()

        pbar.set_postfix(d=d_loss.item(), g=g_loss.item())

    # Also preview/checkpoint the final epoch when num_epochs is not a multiple of
    # save_preview_every; otherwise the trained weights would never be saved.
    if epoch == 1 or epoch % save_preview_every == 0 or epoch == num_epochs:
        preview(epoch)
        _save_checkpoint(epoch)

print("Training complete.")


In [None]:

# =======================================
# Synthetic generation to balance classes
# =======================================
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)

from collections import Counter

# Real-image count per integer label; the largest count is the balancing target.
class_counts = Counter(label for _, label in dataset.samples)
max_count = max(class_counts.values())

def generate_for_class(cls_idx, n_images, batch=64):
    """Generate ``n_images`` synthetic PNGs for class ``cls_idx``.

    Files are written as OUTPUT_ROOT/<idx_to_class[cls_idx]>/gen_000000.png,
    gen_000001.png, ...; existing files with the same names are overwritten.

    Args:
        cls_idx: integer class label (a key of ``idx_to_class``).
        n_images: number of images to write; values <= 0 write nothing.
        batch: images generated per forward pass; must be >= 1.

    Returns:
        The number of images written.

    Raises:
        ValueError: if ``batch`` < 1, since the loop could never advance.
    """
    if batch < 1:
        raise ValueError(f"batch must be >= 1, got {batch}")
    out_dir = OUTPUT_ROOT / idx_to_class[cls_idx]
    out_dir.mkdir(parents=True, exist_ok=True)

    was_training = G.training
    G.eval()
    saved = 0
    try:
        with torch.no_grad():
            while saved < n_images:
                cur = min(batch, n_images - saved)
                z = torch.randn(cur, latent_dim, device=device)
                y = torch.full((cur,), cls_idx, device=device, dtype=torch.long)
                # Map the tanh range [-1, 1] to [0, 1]. Move the whole batch to the
                # CPU once, rather than once per saved image.
                imgs = ((G(z, y) + 1) * 0.5).clamp(0, 1).cpu()
                for i, img in enumerate(imgs):
                    save_image(img, out_dir / f"gen_{saved + i:06d}.png")
                saved += cur
    finally:
        G.train(was_training)  # leave G in the mode the caller had it in
    return saved

# Synthetic images each class needs to reach the target (largest class, or its own size).
gen_plan = {
    k: max(0, (max_count if balance_to_largest else v) - v)
    for k, v in class_counts.items()
}

print("Generation plan (#images to create per class):")
for k, need in gen_plan.items():
    print(f" - {idx_to_class[k]}: need {need}")

total_to_gen = sum(gen_plan.values())
print("Total to generate:", total_to_gen)

if total_to_gen <= 0:
    print("Dataset already balanced by max class count; no generation needed.")
else:
    for k, need in gen_plan.items():
        if need <= 0:
            continue
        made = generate_for_class(k, need, batch=64)
        print(f"Generated {made} images for class {idx_to_class[k]}")

print("Synthetic images saved under:", OUTPUT_ROOT)


In [None]:

# =======================================
# Quick visualization
# =======================================
from IPython.display import display

# Only epoch_*.png files are training previews. This cell also writes
# synthetic_<class>.png into PREVIEWS, and those sort after "epoch_...", so a
# plain "*.png" glob would show a synthetic grid on re-runs.
preview_imgs = sorted(PREVIEWS.glob("epoch_*.png"))
if preview_imgs:
    with Image.open(preview_imgs[-1]) as latest:
        display(latest.resize((512,512)))
else:
    print("No preview images found yet.")

to_tensor = transforms.ToTensor()
for c in list(idx_to_class.values())[:3]:
    sample_dir = OUTPUT_ROOT / c
    if not sample_dir.exists():
        continue
    samples = sorted(sample_dir.glob("*.png"))[:9]  # sorted: same 9 files every run
    if not samples:
        continue
    tensors = []
    for p in samples:
        with Image.open(p) as im:  # close each file handle promptly
            tensors.append(to_tensor(im.convert("RGB")))
    grid_path = PREVIEWS / f"synthetic_{c}.png"
    save_image(make_grid(torch.stack(tensors, dim=0), nrow=3), grid_path)
    with Image.open(grid_path) as grid_img:
        display(grid_img.resize((512,512)))
