In [None]:
!git clone https://github.com/bardiarms/gan-rl.git
%cd gan-rl

In [None]:
# Mount Google Drive at /content/drive; the dataset and run directories
# configured below live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install torch torchvision torchaudio tqdm matplotlib pandas pillow

In [None]:
# Google Drive locations used throughout the notebook:
#   DATA_DIR - folder that holds the dataset
#   RUN_DIR  - folder that receives run outputs (schema, samples, checkpoints)
DATA_DIR = "/content/drive/MyDrive/gan-rl-data"
RUN_DIR  = "/content/drive/MyDrive/gan-rl-runs"

In [None]:
# List the dataset directory to confirm that it is reachable.
!ls -la "$DATA_DIR"

In [None]:
# Path of the repository checkout created by the clone cell at the top.
REPO_DIR = "/content/gan-rl"

In [None]:
import os

In [None]:
# Create the run-output directory if needed and report the paths in use.
os.makedirs(RUN_DIR, exist_ok=True)
data_dir_found = os.path.exists(DATA_DIR)
print("DATA_DIR exists:", data_dir_found)
print("RUN_DIR:", RUN_DIR)

In [None]:
# Root folder of the dataset (one sub-folder per numeric id).
# Derived from DATA_DIR so the two settings cannot drift apart.
# DATA_ROOT = "/content/cartoonset100k"   # alternative location kept from the original notebook
DATA_ROOT = os.path.join(DATA_DIR, "cartoonset100k")

In [None]:
from pathlib import Path


In [None]:
# Scan the dataset root for (image, metadata) pairs.
# Every numeric sub-folder is searched for `<name>.png` files and their
# sibling `<name>.csv` metadata file.
data_root = Path(DATA_ROOT)

pairs, missing_meta, missing_img = [], [], []

for d in sorted(data_root.iterdir()):
    # Ignore stray files and folders whose name is not an integer id;
    # the folder name is stored as an int in every pair.
    if not d.is_dir():
        continue
    try:
        folder_id = int(d.name)
    except ValueError:
        continue

    # List each extension once and match by stem instead of issuing one
    # existence check per file.
    png_by_stem = {p.stem: p for p in d.glob("*.png")}
    csv_by_stem = {p.stem: p for p in d.glob("*.csv")}

    for stem in sorted(png_by_stem):
        png_path = png_by_stem[stem]
        csv_path = csv_by_stem.get(stem)
        if csv_path is not None:
            # Image and metadata both present: record the pair.
            pairs.append((str(png_path), str(csv_path), folder_id))
        else:
            missing_meta.append(str(png_path))

    # Metadata files that have no matching image.
    orphan_stems = sorted(csv_by_stem.keys() - png_by_stem.keys())
    missing_img.extend(str(csv_by_stem[stem]) for stem in orphan_stems)


print("Total pairs:", len(pairs))
print("Missing CSV for PNG:", len(missing_meta))
print("Missing PNG for CSV:", len(missing_img))

In [None]:
# Preview up to five pairs; slicing avoids an IndexError when fewer exist.
for pair in pairs[:5]:
    print(pair)

In [None]:
# Sort the pairs in place by image path (first element of each tuple).
pairs.sort(key=lambda pair: pair[0])

In [None]:
# Assert we can open the first image + read first metadata line
from PIL import Image
import pandas as pd

# Open the first image and parse its metadata file as a quick sanity check.
img_path, meta_path, folder_id = pairs[0]
img = Image.open(img_path).convert("RGB")
df = pd.read_csv(meta_path, header=None, names=["attr", "value", "max"])

for label, value in (
    ("Sample folder:", folder_id),
    ("Image size:", img.size),
    ("Metadata shape:", df.shape),
):
    print(label, value)
print(df.head())

In [None]:
# Read one metadata file
def read_meta_csv(meta_path: str) -> pd.DataFrame:
    """Load a header-less three-column metadata CSV.

    Args:
        meta_path: Path of the CSV file to read.

    Returns:
        DataFrame with columns ``attr`` (str), ``value`` (int) and ``max`` (int).
    """
    raw = pd.read_csv(meta_path, header=None, names=["attr", "value", "max"])
    # Cast every column explicitly so downstream code sees stable dtypes.
    return raw.assign(
        attr=raw["attr"].astype(str),
        value=raw["value"].astype(int),
        max=raw["max"].astype(int),
    )

In [None]:
# Build a schema for one-hot encoding
def build_schema(pairs, max_files=20):
    """Derive the one-hot layout from the first ``max_files`` metadata files.

    Args:
        pairs: List of ``(img_path, meta_path, folder_id)`` tuples.
        max_files: Number of leading entries of ``pairs`` to inspect.

    Returns:
        Tuple ``(attr_order, attr_to_num_classes, offsets, total_dim)``:
        attributes in discovery order, classes per attribute, start index of
        each attribute inside the packed vector, and the vector length.

    NOTE(review): classes are sized as ``max + 1``. If the CSV's third column
    is already a variant count rather than a maximum index, every attribute
    carries one unused slot - confirm against the dataset documentation.
    """
    attr_to_num_classes = {}  # attr -> (max + 1)
    attr_order = []           # attrs in the order they were first seen

    for _, meta_path, _ in pairs[:max_files]:
        meta = read_meta_csv(meta_path)
        for attr, max_value in zip(meta["attr"], meta["max"]):
            num_classes = max_value + 1
            known = attr_to_num_classes.get(attr)

            if known is None:
                # First sighting: register the attribute and its size.
                attr_order.append(attr)
                attr_to_num_classes[attr] = num_classes
            else:
                # Keep the largest size seen, in case files disagree.
                attr_to_num_classes[attr] = max(known, num_classes)

    # Lay the attributes out back to back inside one vector.
    offsets = {}
    total_dim = 0
    for attr in attr_order:
        offsets[attr] = total_dim
        total_dim = total_dim + attr_to_num_classes[attr]

    return attr_order, attr_to_num_classes, offsets, total_dim

In [None]:
# Build the schema from the discovered pairs and summarise it.
attr_order, attr_to_num_classes, offsets, meta_dim = build_schema(pairs)
print("num attributes:", len(attr_order))
print("meta vector dim:", meta_dim)
print(attr_order)

In [None]:
# Freeze the schema (Run Once)

import json

# Bundle the schema pieces into one JSON-serialisable mapping.
schema = {
    "attr_order": attr_order,
    "attr_to_num_classes": attr_to_num_classes,
    "offsets": offsets,
    "total_dim": meta_dim,
}

# Store the schema inside RUN_DIR so it follows the configured run folder
# instead of duplicating the path as a literal.
SCHEMA_PATH = os.path.join(RUN_DIR, "meta_schema.json")
os.makedirs(os.path.dirname(SCHEMA_PATH), exist_ok=True)

# Explicit encoding keeps the file identical across platforms.
with open(SCHEMA_PATH, "w", encoding="utf-8") as f:
    json.dump(schema, f, indent=2)

print("Saved:", SCHEMA_PATH)

In [None]:
from torchvision import transforms

In [None]:

IMG_SIZE = 64      # Images are resized to IMG_SIZE x IMG_SIZE (64x64) before training.
# Per-channel mean/std of 0.5 map ToTensor's [0, 1] output to [-1, 1].
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]

# Preprocessing pipeline applied to every image: resize, convert to tensor, normalize.
img_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
    ])

In [None]:
import torch
import numpy as np

In [None]:
# Read metadata and store them in pandas dataframe
def read_meta(meta_path: str)-> pd.DataFrame:
    """Load a header-less three-column metadata CSV with fixed dtypes.

    Behaves the same as ``read_meta_csv`` defined earlier in this notebook.

    Args:
        meta_path: Path of the CSV file to read.

    Returns:
        DataFrame with columns ``attr`` (str), ``value`` (int) and ``max`` (int).
    """
    column_types = {"attr": str, "value": int, "max": int}
    raw = pd.read_csv(meta_path, header=None, names=list(column_types))
    return raw.astype(column_types)

In [None]:
# Apply one-hot encoding
def encode_onehot(meta_path: str,
                  attr_to_num_classes: dict,
                  offsets: dict,
                  total_dim: int
                  )-> torch.Tensor:
    """Encode one metadata file as a concatenation of one-hot segments.

    Args:
        meta_path: Path of the metadata CSV to encode.
        attr_to_num_classes: Number of classes for each attribute.
        offsets: Start index of each attribute's segment in the vector.
        total_dim: Length of the returned vector.

    Returns:
        Float32 tensor of length ``total_dim`` with a 1.0 in the slot selected
        by each known attribute's value.
    """
    meta = read_meta(meta_path)
    onehot = np.zeros((total_dim,), dtype=np.float32)

    for attr, raw_value in zip(meta["attr"], meta["value"]):
        # Attributes absent from the schema are ignored.
        if attr not in offsets:
            continue

        num_classes = attr_to_num_classes[attr]
        # Clamp into [0, num_classes - 1]; in-range values pass through unchanged.
        index = max(0, min(int(raw_value), num_classes - 1))
        onehot[offsets[attr] + index] = 1.0

    return torch.from_numpy(onehot)


In [None]:
from torch.utils.data import Dataset

In [None]:
class CartoonSetDataset(Dataset):
    """Dataset over ``(img_path, meta_path, folder_id)`` tuples.

    Each item is the transformed RGB image plus the folder id. When a
    ``meta_cache`` mapping is supplied, the entry stored under the item's
    metadata path is returned between the two.
    """

    def __init__(self, pairs, img_transform, meta_cache = None):
        # pairs: list of (img_path, meta_path, folder_id)
        # img_transform: callable applied to the opened RGB image
        # meta_cache: optional mapping meta_path -> precomputed metadata
        self.pairs = pairs
        self.img_transform = img_transform
        self.meta_cache = meta_cache

    def __len__(self):
        """Return the number of available pairs."""
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(img, folder_id)`` or ``(img, meta, folder_id)`` for ``idx``."""
        img_path, meta_path, folder_id = self.pairs[idx]
        img = self.img_transform(Image.open(img_path).convert("RGB"))

        # Without a cache only the image and its folder id are produced.
        if self.meta_cache is None:
            return img, folder_id

        return img, self.meta_cache[meta_path], folder_id



In [None]:
from torch.utils.data import DataLoader

In [None]:
# Build the dataset without a metadata cache, so each item is (img, folder_id).
ds = CartoonSetDataset(pairs=pairs, img_transform=img_transform)

# Shuffled loader with 4 worker processes; `dl` is reused by the cells below.
dl = DataLoader(
    ds,
    batch_size=4,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=False
)
# Pull one batch to check tensor shape and value range.
imgs, folder_ids = next(iter(dl))
print("imgs:", imgs.shape, imgs.min().item(), imgs.max().item())
#print("metas:", metas.shape, metas.min().item(), metas.max().item())
# The slice asks for up to 8 ids; with batch_size=4 it shows the whole batch.
print("folder_ids:", folder_ids[:8])
#print("meta sums (first 8):", metas[:8].sum(dim=1))

In [None]:
import torch
import torch.nn as nn
from src import models
# GAN configuration constants.
IMG_SIZE = 64
Z_DIM = 128          # length of the latent vector fed to the generator
BATCH_SIZE = 4
# Use the GPU when one is visible; otherwise fall back to the CPU so the
# notebook does not crash on a runtime without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Binary cross-entropy on raw logits (sigmoid is applied inside the loss).
criterion = nn.BCEWithLogitsLoss()

In [None]:
# Instantiate both networks and push one real and one generated batch
# through them to inspect shapes and value ranges.
G = models.Generator(z_dim=Z_DIM).to(DEVICE)
D = models.Discriminator().to(DEVICE)

imgs, folder_ids = next(iter(dl))  # one batch from the existing DataLoader
imgs = imgs.to(DEVICE)

z = torch.randn(imgs.size(0), Z_DIM, device=DEVICE)
fake = G(z)

for label, batch in (("Real:", imgs), ("Fake:", fake)):
    print(label, batch.shape, batch.min().item(), batch.max().item())

d_real = D(imgs)
d_fake = D(fake.detach())

for label, logits in (("D(real) shape:", d_real), ("D(fake) shape:", d_fake)):
    print(label, logits.shape, "min/max:", logits.min().item(), logits.max().item())

In [None]:
# NOTE(review): this cell is an exact duplicate of the previous one. Running it
# replaces the global G and D with freshly initialised networks - confirm
# whether that reset is intended or the cell can be removed.
G = models.Generator(z_dim=Z_DIM).to(DEVICE)
D = models.Discriminator().to(DEVICE)

imgs, folder_ids = next(iter(dl))  # from your existing DataLoader
imgs = imgs.to(DEVICE)

# One latent vector per real image in the batch.
z = torch.randn(imgs.size(0), Z_DIM, device=DEVICE)
fake = G(z)

print("Real:", imgs.shape, imgs.min().item(), imgs.max().item())
print("Fake:", fake.shape, fake.min().item(), fake.max().item())

# Discriminator outputs for the real batch and the detached fake batch.
d_real = D(imgs)
d_fake = D(fake.detach())

print("D(real) shape:", d_real.shape, "min/max:", d_real.min().item(), d_real.max().item())
print("D(fake) shape:", d_fake.shape, "min/max:", d_fake.min().item(), d_fake.max().item())

### Train GAN

In [None]:
from torchvision.utils import make_grid, save_image

In [None]:
# Denormalize pixels for viewing
def denorm(x):
    """Rescale ``x`` by ``0.5 * x + 0.5`` and clamp the result to [0, 1]."""
    rescaled = x * 0.5 + 0.5
    return rescaled.clamp(min=0, max=1)

@torch.no_grad()
def save_samples(G, step, fixed_z, out_dir, nrow=8):
    """Render ``G(fixed_z)`` as an image grid and save it under ``out_dir``.

    Args:
        G: Generator module; evaluated in eval mode without gradients.
        step: Training step, used in the file name ``step_<step>.png``.
        fixed_z: Latent batch passed to ``G``.
        out_dir: Directory receiving the image (created if missing).
        nrow: Number of images per grid row.
    """
    # Remember the current mode so the caller's train/eval state is restored
    # instead of being forced to train mode.
    was_training = G.training
    G.eval()
    try:
        fake = G(fixed_z)
        grid = make_grid(denorm(fake), nrow=nrow)
        os.makedirs(out_dir, exist_ok=True)
        path = os.path.join(out_dir, f"step_{step:06d}.png")
        save_image(grid, path)
    finally:
        # Restore the mode even if generation or saving fails.
        G.train(was_training)



In [None]:
# Add Gaussian Noise to Discriminator's inputs
def noise_sigma(step, sigma0=0.10, hold_steps=1500, decay_steps=4000):
    """Return the instance-noise standard deviation for a training step.

    The value stays at ``sigma0`` for the first ``hold_steps`` steps, then
    decays linearly to 0 over the following ``decay_steps`` steps.

    Args:
        step: Current training step.
        sigma0: Initial standard deviation.
        hold_steps: Number of steps during which ``sigma0`` is kept.
        decay_steps: Length of the linear decay; a value <= 0 means the noise
            drops to 0 immediately after the hold phase.

    Returns:
        Standard deviation (float) to use at ``step``.
    """
    if step < hold_steps:
        return sigma0
    # Guard against division by zero (or a negative span) in the decay phase.
    if decay_steps <= 0:
        return 0.0
    t = (step - hold_steps) / decay_steps
    return sigma0 * max(0.0, 1.0 - t)

# Noise Helper
def add_instance_noise(x, sigma):
    """Return ``x`` plus Gaussian noise scaled by ``sigma``.

    When ``sigma`` is zero or negative, ``x`` itself is returned unchanged.
    """
    if sigma <= 0:
        return x
    noise = torch.randn_like(x)
    return x + sigma * noise

In [None]:
def train_func(RUN_DIR: str,
               iters: int,
               SAMPLE_EVERY: int,
               CHKPT_EVERY: int
               )-> None:
    """Train the GAN for ``iters`` steps, saving samples and checkpoints.

    Each step performs one discriminator update followed by one generator
    update. The function relies on notebook globals that must exist before it
    is called: ``G``, ``D``, ``dl``, ``criterion``, ``Z_DIM``, ``DEVICE``,
    ``opt_G``, ``opt_D``, ``scaler_G``, ``scaler_D`` and ``use_amp``.

    Args:
        RUN_DIR: Directory that receives the samples and checkpoints folders.
        iters: Number of training steps to run.
        SAMPLE_EVERY: Save a sample grid every this many steps.
        CHKPT_EVERY: Save a checkpoint every this many steps (never at step 0).
    """

    # NOTE: the folder names carry a "128" tag although IMG_SIZE is 64; they are
    # kept unchanged because a later cell refers to these exact paths.
    SAMPLES_DIR = os.path.join(RUN_DIR, f"samples_128_{iters}_iters")
    CHKPT_DIR = os.path.join(RUN_DIR, f"checkpoints_128_{iters}_iters")
    os.makedirs(SAMPLES_DIR, exist_ok=True)
    os.makedirs(CHKPT_DIR, exist_ok=True)

    # Fixed latent batch so sample grids are comparable across steps.
    fixed_z = torch.randn(64, Z_DIM, device=DEVICE)

    G.train(); D.train()

    step = 0
    data_iter = iter(dl)

    while step < iters:
        # Restart the loader whenever an epoch is exhausted.
        try:
            imgs, folder_ids = next(data_iter)
        except StopIteration:
            data_iter = iter(dl)
            imgs, folder_ids = next(data_iter)

        real = imgs.to(DEVICE, non_blocking=True)
        B = real.size(0)

        # NOTE(review): assumes D returns logits of shape (B,) so they match
        # these label tensors - confirm against src.models.Discriminator.
        real_labels = torch.ones(B, device=DEVICE)
        fake_labels = torch.zeros(B, device=DEVICE)


        # ---Train Discriminator---

        opt_D.zero_grad(set_to_none=True)
        z = torch.randn(B, Z_DIM, device=DEVICE)

        # Instance noise is currently disabled (sigma fixed at 0).
        sigma = 0 #noise_sigma(step)


        with torch.amp.autocast(device_type="cuda", enabled=use_amp):
            fake = G(z)

            real_in = add_instance_noise(real, sigma)
            # Detach so the discriminator loss does not backpropagate into G.
            fake_in = add_instance_noise(fake.detach(), sigma)

            logits_real = D(real_in)
            logits_fake = D(fake_in)
            loss_D_real = criterion(logits_real, real_labels)
            loss_D_fake = criterion(logits_fake, fake_labels)
            loss_D = loss_D_real + loss_D_fake

        scaler_D.scale(loss_D).backward()
        scaler_D.step(opt_D)
        scaler_D.update()


        # -------------------------
        # Train Generator
        # -------------------------
        # for _ in range(2):          # Generator updates twice as Discriminator
        opt_G.zero_grad(set_to_none=True)
        z = torch.randn(B, Z_DIM, device=DEVICE)

        with torch.amp.autocast(device_type="cuda", enabled=use_amp):
            fake = G(z)
            #fake_in = add_instance_noise(fake, sigma)
            logits_fake_for_G = D(fake)
            # The generator is trained to make D label its output as real.
            loss_G = criterion(logits_fake_for_G, real_labels)

        scaler_G.scale(loss_G).backward()
        scaler_G.step(opt_G)
        scaler_G.update()

        # Logging
        if step % 10 == 0:
            print(
                f"step {step:04d} | "
                f"loss_D {loss_D.item():.4f} (r {loss_D_real.item():.4f}, f {loss_D_fake.item():.4f}) | "
                f"loss_G {loss_G.item():.4f} | "
                f"D(real) {logits_real.mean().item():+.3f} | D(fake) {logits_fake.mean().item():+.3f}"
            )

        # Save samples
        if step % SAMPLE_EVERY == 0:
            save_samples(G, step, fixed_z, out_dir = SAMPLES_DIR)

        # Save checkpoints (optional)
        if step > 0 and step % CHKPT_EVERY == 0:
            ckpt_path = os.path.join(CHKPT_DIR, f"gan_step_{step:06d}.pt")
            torch.save(
                {
                    "step": step,
                    "G": G.state_dict(),
                    "D": D.state_dict(),
                    "opt_G": opt_G.state_dict(),
                    "opt_D": opt_D.state_dict(),
                    # Both gradient scalers are stored; the loop uses
                    # scaler_G and scaler_D, and no single `scaler` exists.
                    "scaler_G": scaler_G.state_dict(),
                    "scaler_D": scaler_D.state_dict(),
                },
                ckpt_path,
            )

        step += 1


In [None]:
# Run 12000 training steps; a sample grid and a checkpoint are written every 500 steps.
train_func(
    RUN_DIR="/content/drive/MyDrive/gan-rl-runs/12000_iters",
    iters=12000,
    SAMPLE_EVERY=500,
    CHKPT_EVERY=500,
)

# Diagnosis

In [None]:
from torchvision.utils import make_grid, save_image

def denorm(x):
    """Map values from [-1, 1] to [0, 1], clamping anything outside that range."""
    # Same definition as the `denorm` declared earlier in this notebook.
    return x.mul(0.5).add(0.5).clamp(0, 1)

@torch.no_grad()
def save_random_grid(G, step, out_dir, z_dim, n=64, nrow=8, device="cuda"):
    """Sample ``n`` fresh latent vectors, render ``G``'s output and save the grid.

    Args:
        G: Generator module; evaluated in eval mode without gradients.
        step: Number used in the file name ``random_step_<step>.png``.
        out_dir: Directory receiving the image (created if missing).
        z_dim: Length of each latent vector.
        n: Number of samples to generate.
        nrow: Number of images per grid row.
        device: Device on which the latent vectors are created.
    """
    # Remember the current mode so it is restored afterwards instead of
    # forcing the generator into train mode.
    was_training = G.training
    G.eval()
    os.makedirs(out_dir, exist_ok=True)

    try:
        z = torch.randn(n, z_dim, device=device)      # fresh random z
        fake = G(z)                                   # expected in [-1, 1], see denorm
        grid = make_grid(denorm(fake), nrow=nrow)

        path = os.path.join(out_dir, f"random_step_{step:06d}.png")
        save_image(grid, path)
    finally:
        # Restore the mode even if generation or saving fails.
        G.train(was_training)

    print("Saved:", path)


In [None]:
# Samples folder of the 12000-iteration run; matches the directory that
# train_func creates for RUN_DIR=".../12000_iters" and iters=12000.
SAMPLES_DIR = "/content/drive/MyDrive/gan-rl-runs/12000_iters/samples_128_12000_iters"

In [None]:
# Save a grid of fresh random samples from the current generator.
# NOTE(review): `step` only labels the output file (random_step_000500.png);
# confirm that 500 is the intended label for this diagnostic image.
save_random_grid(G, step=500, out_dir=SAMPLES_DIR, z_dim=Z_DIM, device=DEVICE)