In [None]:
# Colab only: mount Google Drive so the source data and the training outputs
# under /content/drive/MyDrive/cartoon-classification are reachable.
from google.colab import drive
drive.mount('/content/drive')

import numpy as np
import torch, torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models, utils, datasets, transforms
import torchvision.utils as vutils
import os, time, glob, random, contextlib, shutil
import torchvision.transforms as T
from torch.amp import GradScaler, autocast
from torch.optim.lr_scheduler import LambdaLR
from PIL import Image, ImageFile

Mounted at /content/drive


In [None]:
# Let PIL decode truncated / partially written image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Source dataset on Google Drive, plus a local working copy under /content.
drive_root = "/content/drive/MyDrive/cartoon-classification"
local_root = "/content/cartoon-data"
os.makedirs(local_root, exist_ok=True)

# Every real photo is copied.
src_real = os.path.join(drive_root, "real_photos_clean")
dst_real = os.path.join(local_root, "real_photos_clean")

# Only the TRAIN/Gumball sub-folder of the cartoon set is copied.
src_cart = os.path.join(drive_root, "cartoon_modern_clean", "TRAIN", "Gumball")
dst_cart = os.path.join(local_root, "cartoon_modern_clean", "TRAIN", "Gumball")

In [None]:
def copy_tree_with_progress(src, dst):
    """Recursively copy directory ``src`` to ``dst`` while printing a running file counter.

    If ``dst`` already exists the copy is skipped entirely (no merge, no
    verification), so an interrupted copy must be cleaned up by deleting ``dst``
    before re-running.

    Raises:
        FileNotFoundError: if ``src`` is not an existing directory. Without this
            check a missing source (Drive not mounted, typo in the path) would
            create an empty ``dst`` that every later run then skips.
    """
    if os.path.exists(dst):
        print(f"{dst} already exists, skipping copy ‚úÖ")
        return
    if not os.path.isdir(src):
        raise FileNotFoundError(f"Source directory not found: {src}")

    print(f"Copying {src} -> {dst} ...")
    os.makedirs(dst, exist_ok=True)

    # Count files up front so the progress line can show "copied/total".
    total_files = sum(len(files) for _, _, files in os.walk(src))

    copied = 0
    for root, _dirs, files in os.walk(src):
        # Mirror src's directory layout under dst.
        dst_path = os.path.join(dst, os.path.relpath(root, src))
        os.makedirs(dst_path, exist_ok=True)

        for name in files:
            shutil.copy2(os.path.join(root, name), os.path.join(dst_path, name))
            copied += 1
            # `tqdm` is not imported by this notebook's imports cell (NameError on a
            # fresh kernel); a carriage-return print gives the same in-place counter.
            print(f"Copied {copied}/{total_files}", end="\r", flush=True)

    print("\n‚úÖ Done.")

# Copy real photos first, then the Gumball cartoons, to local disk.
for _src_dir, _dst_dir in ((src_real, dst_real), (src_cart, dst_cart)):
    copy_tree_with_progress(_src_dir, _dst_dir)



Copying /content/drive/MyDrive/cartoon-classification/real_photos_clean -> /content/cartoon-data/real_photos_clean ...
Copied 9802/9802
‚úÖ Done.
Copying /content/drive/MyDrive/cartoon-classification/cartoon_modern_clean/TRAIN/Gumball -> /content/cartoon-data/cartoon_modern_clean/TRAIN/Gumball ...
Copied 9794/9794
‚úÖ Done.


In [None]:
# File extensions the dataset accepts (compared against the lower-cased path).
IMG_EXTS = (".jpg", ".jpeg", ".png", ".webp", ".bmp")

# Domain X = real photos, domain Y = Gumball frames (the local copies made above).
X_ROOT = "/content/cartoon-data/real_photos_clean"
Y_ROOT_GUMBALL = "/content/cartoon-data/cartoon_modern_clean/TRAIN/Gumball"

# Samples and checkpoints for every stage are written under this Drive folder.
OUT_ROOT = "/content/drive/MyDrive/cartoon-classification/cyclegan_gumball"
os.makedirs(OUT_ROOT, exist_ok=True)

In [None]:
class UnpairedDataset(Dataset):
    """
    Unpaired X ~ real photos, Y ~ Gumball cartoons.
    IMPORTANT: no random crop/flip here, just Resize -> keeps structure.

    Images are collected recursively under each root (extensions in IMG_EXTS,
    matched on the lower-cased path), resized to ``img_size x img_size`` (aspect
    ratio is NOT preserved), converted to RGB tensors and normalised to [-1, 1].

    Item ``idx`` pairs X image ``idx % len(X)`` with a uniformly random Y image,
    so the two domains are never aligned. ``len()`` is the larger domain's size.

    Args:
        path_x, path_y: root directories of domain X / domain Y.
        subset_size: if set, each domain is randomly down-sampled to at most this
            many files (reproducibly, via ``seed``).
        img_size: output side length in pixels.
        seed: seed for the subset sampling; 42 reproduces the previous subsets.

    Raises:
        RuntimeError: if either root contains no images.
    """

    def __init__(self, path_x, path_y, subset_size=None, img_size=256, seed=42):
        super().__init__()

        self.x_files = self._list_images(path_x)
        self.y_files = self._list_images(path_y)

        if subset_size is not None:
            # A private RNG yields the same subsets as the former
            # `random.seed(42)` + `random.sample(...)` sequence, but no longer
            # reseeds the global `random` module that __getitem__ draws Y from.
            rng = random.Random(seed)
            if len(self.x_files) > subset_size:
                self.x_files = rng.sample(self.x_files, subset_size)
            if len(self.y_files) > subset_size:
                self.y_files = rng.sample(self.y_files, subset_size)

        if not self.x_files:
            raise RuntimeError(f"No X images at {path_x}")
        if not self.y_files:
            raise RuntimeError(f"No Y images at {path_y}")

        # *** NO random crop / flip: just resize to square ***
        # The pipeline is deterministic and stateless, so both domains share it.
        transform = T.Compose([
            T.Resize((img_size, img_size), antialias=True),
            T.ToTensor(),
            T.Normalize([0.5] * 3, [0.5] * 3),
        ])
        self.tx = transform
        self.ty = transform

    @staticmethod
    def _list_images(root):
        """Sorted list of image files (by extension) anywhere under ``root``."""
        pattern = os.path.join(root, "**", "*")
        return sorted(
            p for p in glob.glob(pattern, recursive=True)
            if os.path.isfile(p) and p.lower().endswith(IMG_EXTS)
        )

    def __len__(self):
        return max(len(self.x_files), len(self.y_files))

    def __getitem__(self, idx):
        x_path = self.x_files[idx % len(self.x_files)]
        y_path = random.choice(self.y_files)
        # Context managers release the underlying file handles promptly.
        with Image.open(x_path) as x_img:
            x = x_img.convert("RGB")
        with Image.open(y_path) as y_img:
            y = y_img.convert("RGB")
        return self.tx(x), self.ty(y)


def build_loader(img_size, subset=None, batch_size=None, num_workers=0):
    """Create the training DataLoader over UnpairedDataset(X_ROOT, Y_ROOT_GUMBALL).

    The loader shuffles, drops the last incomplete batch, and pins memory when
    DEVICE is "cuda".

    Args:
        img_size: square image side passed to the dataset.
        subset: optional per-domain cap on the number of files.
        batch_size: batch size; ``None`` (default) uses the global ``BATCH``.
        num_workers: loader worker processes. 0 (the default, as before) loads in
            the main process; values > 0 keep the workers alive between epochs.

    Raises:
        ValueError: if the dataset is smaller than one batch. With
            ``drop_last=True`` the loader would yield nothing and training would
            silently run zero steps.
    """
    if batch_size is None:
        batch_size = BATCH
    ds = UnpairedDataset(X_ROOT, Y_ROOT_GUMBALL, subset_size=subset, img_size=img_size)
    print(f"Dataset sizes: X={len(ds.x_files)}  Y={len(ds.y_files)}", flush=True)
    dl = DataLoader(
        ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        drop_last=True,
        pin_memory=(DEVICE == "cuda"),
        persistent_workers=num_workers > 0,
    )
    if len(dl) == 0:
        raise ValueError(
            f"Dataset of {len(ds)} items is smaller than batch_size={batch_size}; "
            "drop_last=True would produce no batches"
        )
    print(f"üîß batches/epoch={len(dl)}  img={img_size}px  batch={batch_size}")
    return dl

In [None]:

class ResnetBlock(nn.Module):
    """Residual block: x + [pad-conv3x3-IN-ReLU-pad-conv3x3-IN](x); channel count unchanged.

    Layer indices inside ``self.block`` (1 and 5 hold the conv weights) define the
    state_dict keys used by saved checkpoints.
    """

    def __init__(self, dim):
        super().__init__()

        def conv_norm():
            # Reflection pad keeps the spatial size for the 3x3 conv; no bias
            # because InstanceNorm removes the per-channel mean anyway.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, 3, bias=False),
                nn.InstanceNorm2d(dim, affine=False, track_running_stats=False),
            ]

        self.block = nn.Sequential(*conv_norm(), nn.ReLU(True), *conv_norm())

    def forward(self, x):
        residual = self.block(x)
        return x + residual


class ResnetGenerator(nn.Module):
    """
    9-block ResNet generator.

    Layout: 7x7 conv stem (64 ch) -> two stride-2 downsamples (128, 256 ch) ->
    ``n_blocks`` ResnetBlocks at 256 ch -> two stride-2 transposed-conv upsamples
    (128, 64 ch) -> 7x7 conv to ``out_c`` -> Tanh, so outputs lie in [-1, 1].
    Module order inside ``self.net`` determines the checkpoint state_dict keys.
    """

    def __init__(self, in_c=3, out_c=3, n_blocks=9):
        super().__init__()
        width = 64

        def norm_relu(channels):
            return [
                nn.InstanceNorm2d(channels, affine=False, track_running_stats=False),
                nn.ReLU(True),
            ]

        layers = [nn.ReflectionPad2d(3), nn.Conv2d(in_c, width, 7, bias=False)]
        layers += norm_relu(width)

        # Encoder: halve resolution, double channels (x2, x4).
        for _ in range(2):
            layers.append(nn.Conv2d(width, width * 2, 3, 2, 1, bias=False))
            width *= 2
            layers += norm_relu(width)

        # Bottleneck.
        layers.extend(ResnetBlock(width) for _ in range(n_blocks))

        # Decoder: double resolution, halve channels.
        for _ in range(2):
            layers.append(nn.ConvTranspose2d(width, width // 2, 3, 2, 1, 1, bias=False))
            width //= 2
            layers += norm_relu(width)

        layers += [nn.ReflectionPad2d(3), nn.Conv2d(width, out_c, 7), nn.Tanh()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class PatchDiscriminator(nn.Module):
    """PatchGAN discriminator returning a grid of raw (no sigmoid) per-patch scores.

    A stride-2 4x4 conv (64 ch, with bias, no norm) is followed by ``n_layers - 1``
    stride-2 conv + InstanceNorm + LeakyReLU stages that double the width each
    time, then a stride-1 4x4 conv down to a single channel.
    """

    def __init__(self, in_c=3, n_layers=3):
        super().__init__()
        width = 64
        stack = [nn.Conv2d(in_c, width, 4, 2, 1), nn.LeakyReLU(0.2, True)]
        for _ in range(n_layers - 1):
            stack.extend([
                nn.Conv2d(width, 2 * width, 4, 2, 1, bias=False),
                nn.InstanceNorm2d(2 * width, affine=False, track_running_stats=False),
                nn.LeakyReLU(0.2, True),
            ])
            width *= 2
        stack.append(nn.Conv2d(width, 1, 4, 1, 1))
        self.model = nn.Sequential(*stack)

    def forward(self, x):
        scores = self.model(x)
        return scores


def gan_loss_lsgan(pred, is_real: bool):
    """Least-squares GAN loss.

    Mean squared error between ``pred`` and a constant target of the same
    shape/dtype/device: 1.0 when ``is_real`` is true, 0.0 otherwise.
    """
    fill_value = 1.0 if is_real else 0.0
    return F.mse_loss(pred, torch.full_like(pred, fill_value))


In [None]:
# ------------- HELPERS -------------
def save_sample(x, y, fake_y, rec_x, step, out_dir, tag, max_n=4):
    """Write a 4-row preview grid to ``<out_dir>/<tag>_<step:07d>.png`` and print its path.

    Rows, top to bottom: ``x``, ``y``, ``fake_y``, ``rec_x``. Each row holds the
    first ``min(max_n, len(x), len(y))`` images. Values are clamped to [-1, 1]
    and mapped to [0, 1] before saving.
    """
    per_row = min(max_n, x.size(0), y.size(0))
    stacked = torch.cat([t[:per_row] for t in (x, y, fake_y, rec_x)], dim=0)
    stacked = stacked.clamp(-1, 1).add(1).mul(0.5)

    os.makedirs(out_dir, exist_ok=True)
    out = os.path.join(out_dir, f"{tag}_{step:07d}.png")
    vutils.save_image(stacked, out, nrow=per_row)
    print(f"üñºÔ∏è  sample ‚Üí {out}", flush=True)


def make_lr_scheduler(optimizer, total_epochs, keep_frac=None):
    """Per-epoch LambdaLR: constant LR first, then a linear decay (CycleGAN style).

    The LR multiplier is 1.0 for the first ``int(total_epochs * keep_frac)``
    epochs. After that it drops by ``1 / decay_epochs`` per epoch, floored at 0.
    Call ``scheduler.step()`` once per epoch.

    Args:
        optimizer: torch optimizer whose param-group LRs are scaled.
        total_epochs: number of epochs the schedule spans.
        keep_frac: fraction of epochs held at the base LR. ``None`` uses the
            notebook-level ``LR_KEEP_FRAC`` if it is defined, else 0.5.

    Raises:
        ValueError: if ``keep_frac`` is outside [0, 1].
    """
    if keep_frac is None:
        # The config cell never defines LR_KEEP_FRAC, so a bare global lookup raised
        # NameError on a fresh kernel; fall back to a 50/50 constant/decay split.
        keep_frac = globals().get("LR_KEEP_FRAC", 0.5)
    if not 0.0 <= keep_frac <= 1.0:
        raise ValueError(f"keep_frac must be in [0, 1], got {keep_frac}")

    keep_epochs = int(total_epochs * keep_frac)
    decay_epochs = max(1, total_epochs - keep_epochs)

    def lr_lambda(epoch):
        if epoch < keep_epochs:
            return 1.0
        return max(0.0, 1.0 - float(epoch - keep_epochs) / float(decay_epochs))

    return LambdaLR(optimizer, lr_lambda=lr_lambda)


def train_stage(
    stage_name,
    img_size,
    epochs,
    lr_g,
    lr_d,
    lambda_cyc,
    lambda_id_max,
    id_warmup_epochs,
    out_root,
    g_init_path=None,
    f_init_path=None,
    start_step=0,
    start_epoch=0,
):
    """Train one CycleGAN stage: G (X->Y), Fm (Y->X) and discriminators DX, DY.

    Each step does one discriminator update (LSGAN loss on real vs. generated
    images), then one joint G/Fm update. The G/Fm update combines:
      * LSGAN adversarial losses,
      * an L1 cycle loss weighted by ``lambda_cyc``,
      * an L1 identity loss whose weight ramps linearly up to ``lambda_id_max``
        over ``id_warmup_epochs`` epochs.
    On CUDA it uses fp16 autocast with one GradScaler per optimizer.

    Epochs ``start_epoch + 1 .. epochs`` are run. After each epoch the G and Fm
    weights are saved to ``<out_root>/checkpoints/{G,F}_<img_size>_<tag>_ep<ep>.pt``.

    Args:
        stage_name: label used in logs, sample file names and checkpoint tags.
        img_size: square training resolution.
        epochs: total number of epochs of this stage (also the LR-schedule span).
        lr_g, lr_d: base learning rates of the generator / discriminator Adam optimizers.
        lambda_cyc: weight of the cycle-consistency loss.
        lambda_id_max, id_warmup_epochs: target weight of the identity loss and
            its linear warm-up length (<= 0 means no warm-up).
        out_root: directory receiving ``samples/`` and ``checkpoints/``.
        g_init_path, f_init_path: optional state_dict files to initialise G / Fm.
            A path that is given but missing raises FileNotFoundError instead of
            being skipped, so a "resume" can never silently start from scratch.
        start_step: initial value of the global step counter. It only affects the
            logging/sample cadence and sample file names.
        start_epoch: number of epochs of this stage already completed (default 0).
            The LR schedules are fast-forwarded and the identity ramp and
            checkpoint names use the true epoch numbers, so a resumed run
            continues the original schedule without overwriting ep1, ep2, ...

    Returns:
        (last_g_path, last_f_path, step): paths of the last saved checkpoints
        (None if no epoch ran) and the final step counter.

    Raises:
        ValueError: if ``start_epoch`` is not within [0, epochs].
        FileNotFoundError: if an init path is given but does not exist.

    NOTE: only G and Fm are checkpointed. DX/DY and the optimizer/scaler states
    are freshly created on every call, including resumes.
    """
    import warnings

    if not 0 <= start_epoch <= epochs:
        raise ValueError(f"start_epoch must be in [0, epochs={epochs}], got {start_epoch}")
    # Fail fast (before the dataset scan) on missing init checkpoints.
    for label, path in (("G", g_init_path), ("F", f_init_path)):
        if path and not os.path.exists(path):
            raise FileNotFoundError(f"{label} init checkpoint not found: {path}")

    print(f"\n===== {stage_name} @ {img_size}px | epochs={epochs}, lr_g={lr_g}, lr_d={lr_d}, "
          f"Œª_cyc={lambda_cyc}, Œª_id_max={lambda_id_max} =====")
    if start_epoch:
        print(f"   resuming after epoch {start_epoch}", flush=True)
    samples_dir = os.path.join(out_root, "samples")
    ckpts_dir   = os.path.join(out_root, "checkpoints")
    os.makedirs(samples_dir, exist_ok=True)
    os.makedirs(ckpts_dir, exist_ok=True)

    dl = build_loader(img_size, SUBSET)

    G  = ResnetGenerator(n_blocks=9).to(DEVICE)
    Fm = ResnetGenerator(n_blocks=9).to(DEVICE)
    DX = PatchDiscriminator().to(DEVICE)
    DY = PatchDiscriminator().to(DEVICE)

    for m in (G, Fm, DX, DY):
        m.to(memory_format=torch.channels_last)

    for label, model, path in (("G", G, g_init_path), ("F", Fm, f_init_path)):
        if path:
            print(f"Loading {label} from:", path)
            # weights_only: the files are plain state_dicts; avoids unpickling arbitrary objects.
            model.load_state_dict(torch.load(path, map_location=DEVICE, weights_only=True), strict=True)

    g_opt = torch.optim.Adam(
        list(G.parameters()) + list(Fm.parameters()),
        lr=lr_g, betas=(0.5, 0.999)
    )
    d_opt = torch.optim.Adam(
        list(DX.parameters()) + list(DY.parameters()),
        lr=lr_d, betas=(0.5, 0.999)
    )

    g_sched = make_lr_scheduler(g_opt, epochs)
    d_sched = make_lr_scheduler(d_opt, epochs)
    if start_epoch:
        # Replay the per-epoch scheduler steps of the completed epochs so the LR
        # continues the original curve. The "scheduler.step() before
        # optimizer.step()" warning is expected here and suppressed.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", UserWarning)
            for _ in range(start_epoch):
                g_sched.step()
                d_sched.step()

    use_cuda = (DEVICE == "cuda")
    amp_ctx  = autocast(device_type="cuda", dtype=torch.float16) if use_cuda else contextlib.nullcontext()
    g_scaler = GradScaler(device="cuda") if use_cuda else None
    d_scaler = GradScaler(device="cuda") if use_cuda else None

    step = start_step
    last_g_path = None
    last_f_path = None

    for ep in range(start_epoch + 1, epochs + 1):
        t0 = time.time()

        # identity weight ramp (uses the true epoch number, also when resuming)
        if id_warmup_epochs <= 0:
            lambda_id = lambda_id_max
        else:
            ramp = min(1.0, ep / float(id_warmup_epochs))
            lambda_id = lambda_id_max * ramp

        print(f"\nüåÄ {stage_name} Epoch {ep}/{epochs} | Œª_id={lambda_id:.3f}", flush=True)
        print(f"   lr_G={g_opt.param_groups[0]['lr']:.6g}, lr_D={d_opt.param_groups[0]['lr']:.6g}", flush=True)

        # epoch stats
        d_loss_sum = 0.0
        g_loss_sum = 0.0
        cyc_sum    = 0.0
        idt_sum    = 0.0
        advg_sum   = 0.0
        advf_sum   = 0.0
        num_steps  = 0

        for x, y in dl:
            x = x.to(DEVICE, non_blocking=True).to(memory_format=torch.channels_last)
            y = y.to(DEVICE, non_blocking=True).to(memory_format=torch.channels_last)

            # --- D step --- (generators frozen: fakes computed without a graph)
            with torch.no_grad(), amp_ctx:
                fake_y = G(x)
                fake_x = Fm(y)

            d_opt.zero_grad(set_to_none=True)
            with amp_ctx:
                dx_real = DX(x)
                dx_fake = DX(fake_x)
                dy_real = DY(y)
                dy_fake = DY(fake_y)

                d_loss = (
                    gan_loss_lsgan(dx_real, True) +
                    gan_loss_lsgan(dx_fake, False) +
                    gan_loss_lsgan(dy_real, True) +
                    gan_loss_lsgan(dy_fake, False)
                ) * 0.5

            if use_cuda:
                d_scaler.scale(d_loss).backward()
                d_scaler.step(d_opt)
                d_scaler.update()
            else:
                d_loss.backward()
                d_opt.step()

            # --- G/F step ---
            g_opt.zero_grad(set_to_none=True)
            with amp_ctx:
                fake_y = G(x);   pred_dy = DY(fake_y)
                fake_x = Fm(y);  pred_dx = DX(fake_x)

                adv_g = gan_loss_lsgan(pred_dy, True)
                adv_f = gan_loss_lsgan(pred_dx, True)

                rec_x = Fm(fake_y)
                rec_y = G(fake_x)
                cyc   = F.l1_loss(rec_x, x) + F.l1_loss(rec_y, y)

                if lambda_id > 0.0:
                    id_x = G(y)
                    id_y = Fm(x)
                    idt  = F.l1_loss(id_x, y) + F.l1_loss(id_y, x)
                else:
                    idt = torch.zeros((), device=DEVICE, dtype=cyc.dtype)

                g_loss = adv_g + adv_f + lambda_cyc*cyc + lambda_id*idt

            if use_cuda:
                g_scaler.scale(g_loss).backward()
                g_scaler.step(g_opt)
                g_scaler.update()
            else:
                g_loss.backward()
                g_opt.step()

            # accumulate stats
            d_loss_sum += d_loss.item()
            g_loss_sum += g_loss.item()
            cyc_sum    += cyc.item()
            idt_sum    += idt.item()
            advg_sum   += adv_g.item()
            advf_sum   += adv_f.item()
            num_steps  += 1

            if step % 400 == 0:
                print(
                    f"[{stage_name} ep {ep:02d}] step {step:07d} | "
                    f"D={d_loss.item():.3f} | G={g_loss.item():.3f} | "
                    f"advG={adv_g.item():.3f} | advF={adv_f.item():.3f} | "
                    f"cyc={cyc.item():.3f} | idt={idt.item():.3f}",
                    flush=True
                )

            if step % 2000 == 0:
                with torch.inference_mode(), amp_ctx:
                    save_sample(
                        x.detach().cpu(), y.detach().cpu(),
                        fake_y.detach().cpu(), rec_x.detach().cpu(),
                        step, samples_dir,
                        tag=stage_name.replace(" ", "_"),
                    )
            step += 1

        # end epoch: print epoch-avg stats
        if num_steps > 0:
            print(
                f"üìä {stage_name} Epoch {ep} stats | "
                f"D_avg={d_loss_sum/num_steps:.3f} | "
                f"G_avg={g_loss_sum/num_steps:.3f} | "
                f"advG_avg={advg_sum/num_steps:.3f} | "
                f"advF_avg={advf_sum/num_steps:.3f} | "
                f"cyc_avg={cyc_sum/num_steps:.3f} | "
                f"idt_avg={idt_sum/num_steps:.3f}",
                flush=True
            )

        tag = stage_name.replace(" ", "").lower()
        last_g_path = os.path.join(ckpts_dir, f"G_{img_size}_{tag}_ep{ep}.pt")
        last_f_path = os.path.join(ckpts_dir, f"F_{img_size}_{tag}_ep{ep}.pt")
        torch.save(G.state_dict(), last_g_path)
        torch.save(Fm.state_dict(), last_f_path)
        print(f"üíæ saved {stage_name} epoch {ep} ‚Üí {last_g_path}")
        print(f"‚è± epoch took {(time.time()-t0)/60:.1f} min")

        g_sched.step()
        d_sched.step()

    return last_g_path, last_f_path, step

In [None]:
# train_gumball_cyclegan_fixed.py
# Two-stage CycleGAN for real -> Gumball-only cartoon
# StageA: 256px pretrain (moderate style, strong identity)
# StageB: 512px finetune (stronger style)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH = 4
SUBSET = None   # optional per-domain cap on the number of images (None = all)

# Fraction of each stage's epochs trained at the base LR before the linear decay
# (read by make_lr_scheduler). It was referenced but never defined here, so the
# first train_stage call raised NameError on a fresh kernel.
# NOTE(review): 0.5 is the standard CycleGAN split; the value used for the logged
# runs is not visible in this notebook -- confirm.
LR_KEEP_FRAC = 0.5

LAMBDA_CYC_STAGEA = 5.0
LAMBDA_ID_STAGEA  = 15.0
ID_WARMUP_STAGEA  = 5

LAMBDA_CYC_STAGEB = 2.0
LAMBDA_ID_STAGEB  = 0.0
ID_WARMUP_STAGEB  = 0

STAGEA_IMG   = 256
STAGEA_EPOCH = 50
STAGEA_LR_G  = 2e-4
STAGEA_LR_D  = 2e-4

STAGEB_IMG   = 512
STAGEB_EPOCH = 25
STAGEB_LR_G  = 1e-4
STAGEB_LR_D  = 1e-4




In [None]:
# ================= MAIN: StageA then StageB =================
if __name__ == "__main__":
    # Stage A: 256px, both generators initialised from scratch.
    OUT_A = os.path.join(OUT_ROOT, "StageA_256_new")
    os.makedirs(OUT_A, exist_ok=True)

    print(f"\nüöÄ Training StageA_256 from scratch for {STAGEA_EPOCH} epochs...")
    stage_a_kwargs = dict(
        stage_name="StageA_256",
        img_size=STAGEA_IMG,
        epochs=STAGEA_EPOCH,
        lr_g=STAGEA_LR_G,
        lr_d=STAGEA_LR_D,
        lambda_cyc=LAMBDA_CYC_STAGEA,
        lambda_id_max=LAMBDA_ID_STAGEA,
        id_warmup_epochs=ID_WARMUP_STAGEA,
        out_root=OUT_A,
        g_init_path=None,
        f_init_path=None,
        start_step=0,
    )
    g_256_path, f_256_path, _ = train_stage(**stage_a_kwargs)

    print("‚úÖ StageA_256 finished.")
    print("   Last G:", g_256_path)
    print("   Last F:", f_256_path)

    # Stage B: 512px fine-tune, initialised from the last Stage A checkpoints.
    OUT_B = os.path.join(OUT_ROOT, "StageB_512_new")
    os.makedirs(OUT_B, exist_ok=True)

    print(f"\nüîÅ Starting StageB_512 from StageA weights for {STAGEB_EPOCH} epochs...")
    stage_b_kwargs = dict(
        stage_name="StageB_512",
        img_size=STAGEB_IMG,
        epochs=STAGEB_EPOCH,
        lr_g=STAGEB_LR_G,
        lr_d=STAGEB_LR_D,
        lambda_cyc=LAMBDA_CYC_STAGEB,
        lambda_id_max=LAMBDA_ID_STAGEB,
        id_warmup_epochs=ID_WARMUP_STAGEB,
        out_root=OUT_B,
        g_init_path=g_256_path,
        f_init_path=f_256_path,
        start_step=0,
    )
    _g_512_path, _f_512_path, _ = train_stage(**stage_b_kwargs)

    print("‚úÖ StageB_512 training finished.")



üöÄ Training StageA_256 from scratch for 50 epochs...

===== StageA_256 @ 256px | epochs=50, lr_g=0.0002, lr_d=0.0002, Œª_cyc=5.0, Œª_id_max=15.0 =====
Dataset sizes: X=9802  Y=9794
üîß batches/epoch=2450  img=256px  batch=4

üåÄ StageA_256 Epoch 1/50 | Œª_id=3.000
   lr_G=0.0002, lr_D=0.0002
[StageA_256 ep 01] step 0000000 | D=1.264 | G=9.838 | advG=0.681 | advF=0.847 | cyc=1.041 | idt=1.035
üñºÔ∏è  sample ‚Üí /content/drive/MyDrive/cartoon-classification/cyclegan_gumball/StageA_256_new/samples/StageA_256_0000000.png
[StageA_256 ep 01] step 0000400 | D=0.482 | G=5.336 | advG=0.436 | advF=0.252 | cyc=0.658 | idt=0.453
[StageA_256 ep 01] step 0000800 | D=0.379 | G=3.280 | advG=0.540 | advF=0.298 | cyc=0.319 | idt=0.282
[StageA_256 ep 01] step 0001200 | D=0.330 | G=3.529 | advG=0.511 | advF=0.378 | cyc=0.348 | idt=0.299
[StageA_256 ep 01] step 0001600 | D=0.374 | G=3.049 | advG=0.590 | advF=0.300 | cyc=0.273 | idt=0.265
[StageA_256 ep 01] step 0002000 | D=0.425 | G=2.957 | advG=0.51

KeyboardInterrupt: 

In [None]:
# ================= MAIN: RESUME SPECIFICALLY FROM STAGE B EPOCH 4 =================
if __name__ == "__main__":

    # ----- Skip Stage A (you already finished 50 epochs) -----
    OUT_A = os.path.join(OUT_ROOT, "StageA_256_new")
    g_256_path = os.path.join(OUT_A, "checkpoints/G_256_stagea_256_ep50.pt")
    f_256_path = os.path.join(OUT_A, "checkpoints/F_256_stagea_256_ep50.pt")

    print("‚úî Using StageA epoch 50 checkpoints:")
    print("G:", g_256_path)
    print("F:", f_256_path)

    # ----- Resume Stage B from epoch 4 -----
    OUT_B = os.path.join(OUT_ROOT, "StageB_512_new")
    os.makedirs(OUT_B, exist_ok=True)

    EPOCH_RESUME = 4
    g_resume = os.path.join(OUT_B, f"checkpoints/G_512_stageb_512_ep{EPOCH_RESUME}.pt")
    f_resume = os.path.join(OUT_B, f"checkpoints/F_512_stageb_512_ep{EPOCH_RESUME}.pt")

    print("\nüîÅ Resuming StageB from EPOCH 4:")
    print("G:", g_resume)
    print("F:", f_resume)

    # Fail loudly if the resume checkpoints are missing, so this cell can never
    # silently fall back to freshly initialised generators.
    for ckpt in (g_resume, f_resume):
        if not os.path.isfile(ckpt):
            raise FileNotFoundError(f"Resume checkpoint not found: {ckpt}")

    # Train only the epochs still missing from the STAGEB_EPOCH-epoch run. Passing
    # the full STAGEB_EPOCH here would train that many *additional* epochs.
    remaining_epochs = STAGEB_EPOCH - EPOCH_RESUME   # 25 - 4 = 21

    # compute correct start step (epoch * batches_per_epoch)
    batches_per_epoch = 2450   # from the StageB logs
    start_step = EPOCH_RESUME * batches_per_epoch   # 9800
    print("start_step =", start_step)

    if remaining_epochs <= 0:
        print("Nothing left to train, StageB is already at or past target epoch.")
    else:
        # train_stage numbers checkpoints ep1..ep<epochs> relative to this call.
        # Writing into OUT_B itself would overwrite the real StageB ep1..ep21
        # checkpoints with weights that are actually later epochs, so the resumed
        # run gets its own directory.
        # NOTE: the LR schedule restarts over the remaining epochs, and DX/DY start
        # from scratch (only G/F weights are restored).
        OUT_B_RESUME = os.path.join(OUT_B, f"resume_from_ep{EPOCH_RESUME}")
        os.makedirs(OUT_B_RESUME, exist_ok=True)
        print("remaining_epochs =", remaining_epochs)
        print("resume outputs ->", OUT_B_RESUME)

        _g_last, _f_last, _ = train_stage(
            stage_name="StageB_512",
            img_size=STAGEB_IMG,
            epochs=remaining_epochs,
            lr_g=STAGEB_LR_G,
            lr_d=STAGEB_LR_D,
            lambda_cyc=LAMBDA_CYC_STAGEB,
            lambda_id_max=LAMBDA_ID_STAGEB,
            id_warmup_epochs=ID_WARMUP_STAGEB,
            out_root=OUT_B_RESUME,
            g_init_path=g_resume,
            f_init_path=f_resume,
            start_step=start_step
        )

        print("üéâ StageB resumed and running from epoch 4.")


‚úî Using StageA epoch 50 checkpoints:
G: /content/drive/MyDrive/cartoon-classification/cyclegan_gumball/StageA_256_new/checkpoints/G_256_stagea_256_ep50.pt
F: /content/drive/MyDrive/cartoon-classification/cyclegan_gumball/StageA_256_new/checkpoints/F_256_stagea_256_ep50.pt

üîÅ Resuming StageB from EPOCH 4:
G: /content/drive/MyDrive/cartoon-classification/cyclegan_gumball/StageB_512_new/checkpoints/G_512_stageb_512_ep4.pt
F: /content/drive/MyDrive/cartoon-classification/cyclegan_gumball/StageB_512_new/checkpoints/F_512_stageb_512_ep4.pt
start_step = 9800

===== StageB_512 @ 512px | epochs=25, lr_g=0.0001, lr_d=0.0001, Œª_cyc=2.0, Œª_id_max=0.0 =====
Dataset sizes: X=9802  Y=9794
üîß batches/epoch=2450  img=512px  batch=4
Loading G from: /content/drive/MyDrive/cartoon-classification/cyclegan_gumball/StageB_512_new/checkpoints/G_512_stageb_512_ep4.pt
Loading F from: /content/drive/MyDrive/cartoon-classification/cyclegan_gumball/StageB_512_new/checkpoints/F_512_stageb_512_ep4.pt

üåÄ 

KeyboardInterrupt: 

In [None]:
# ================= MAIN: RESUME SPECIFICALLY FROM STAGE B EPOCH 18 =================
if __name__ == "__main__":

    # We won't retrain Stage A
    OUT_A = os.path.join(OUT_ROOT, "StageA_256_new")
    g_256_path = os.path.join(OUT_A, "checkpoints/G_256_stagea_256_ep50.pt")
    f_256_path = os.path.join(OUT_A, "checkpoints/F_256_stagea_256_ep50.pt")
    print("‚úî Using StageA epoch 50 checkpoints:")
    print("G:", g_256_path)
    print("F:", f_256_path)

    OUT_B = os.path.join(OUT_ROOT, "StageB_512_new")
    os.makedirs(OUT_B, exist_ok=True)

    EPOCH_RESUME = 18  # <--- change this if you ever want a different epoch
    G_NAME = f"G_512_stageb_512_ep{EPOCH_RESUME}.pt"
    F_NAME = f"F_512_stageb_512_ep{EPOCH_RESUME}.pt"

    g_resume = os.path.join(OUT_B, "checkpoints", G_NAME)
    f_resume = os.path.join(OUT_B, "checkpoints", F_NAME)

    print(f"\nüîÅ Resuming StageB from EPOCH {EPOCH_RESUME}:")
    print("G:", g_resume)
    print("F:", f_resume)

    # Fail loudly if the resume checkpoints are missing, so this cell can never
    # silently fall back to freshly initialised generators.
    for ckpt in (g_resume, f_resume):
        if not os.path.isfile(ckpt):
            raise FileNotFoundError(f"Resume checkpoint not found: {ckpt}")

    remaining_epochs = STAGEB_EPOCH - EPOCH_RESUME  # 25 - 18 = 7
    if remaining_epochs <= 0:
        print("Nothing left to train, StageB is already at or past target epoch.")
    else:
        batches_per_epoch = 2450  # from your logs
        start_step = EPOCH_RESUME * batches_per_epoch
        print("remaining_epochs =", remaining_epochs)
        print("start_step =", start_step)

        # train_stage numbers checkpoints ep1..ep<epochs> relative to this call.
        # Writing into OUT_B itself would overwrite the real StageB ep1..ep7
        # checkpoints with weights that are really epochs EPOCH_RESUME+1.. (19..25
        # for EPOCH_RESUME=18), so the resumed run gets its own directory.
        # NOTE: the LR schedule restarts over these remaining epochs, and DX/DY
        # start from scratch (only G/F weights are restored).
        OUT_B_RESUME = os.path.join(OUT_B, f"resume_from_ep{EPOCH_RESUME}")
        os.makedirs(OUT_B_RESUME, exist_ok=True)
        print("resume outputs ->", OUT_B_RESUME)

        _g_last, _f_last, _ = train_stage(
            stage_name="StageB_512",
            img_size=STAGEB_IMG,
            epochs=remaining_epochs,      # just finish the remaining epochs
            lr_g=STAGEB_LR_G,
            lr_d=STAGEB_LR_D,
            lambda_cyc=LAMBDA_CYC_STAGEB,
            lambda_id_max=LAMBDA_ID_STAGEB,
            id_warmup_epochs=ID_WARMUP_STAGEB,
            out_root=OUT_B_RESUME,
            g_init_path=g_resume,
            f_init_path=f_resume,
            start_step=start_step,
        )

        print(f"üéâ StageB resumed from epoch {EPOCH_RESUME} and finished the remaining epochs.")


‚úî Using StageA epoch 50 checkpoints:
G: /content/drive/MyDrive/cartoon-classification/cyclegan_gumball/StageA_256_new/checkpoints/G_256_stagea_256_ep50.pt
F: /content/drive/MyDrive/cartoon-classification/cyclegan_gumball/StageA_256_new/checkpoints/F_256_stagea_256_ep50.pt

üîÅ Resuming StageB from EPOCH 18:
G: /content/drive/MyDrive/cartoon-classification/cyclegan_gumball/StageB_512_new/checkpoints/G_512_stageb_512_ep18.pt
F: /content/drive/MyDrive/cartoon-classification/cyclegan_gumball/StageB_512_new/checkpoints/F_512_stageb_512_ep18.pt
remaining_epochs = 7
start_step = 44100

===== StageB_512 @ 512px | epochs=7, lr_g=0.0001, lr_d=0.0001, Œª_cyc=2.0, Œª_id_max=0.0 =====
Dataset sizes: X=9802  Y=9794
üîß batches/epoch=2450  img=512px  batch=4
Loading G from: /content/drive/MyDrive/cartoon-classification/cyclegan_gumball/StageB_512_new/checkpoints/G_512_stageb_512_ep18.pt
Loading F from: /content/drive/MyDrive/cartoon-classification/cyclegan_gumball/StageB_512_new/checkpoints/F_512