In [None]:
# full_cyclegan_512.py
import os
import random
import itertools
import time
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision.utils as vutils
from tqdm import tqdm

# -----------------------
# CONFIG
# -----------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# user config - change these paths/values accordingly
img_size = 512
batch_size = 3            # increase if you have VRAM
epochs = 100
save_interval = 5         # save every 5 epochs
data_root = "/teamspace/studios/this_studio/.lightning_studio/dataset"    # must contain trainA/ and trainB/
save_dir = "/teamspace/studios/this_studio/.lightning_studio/checkpoints"
sample_dir = os.path.join(save_dir, "samples")
os.makedirs(save_dir, exist_ok=True)
os.makedirs(sample_dir, exist_ok=True)

# optim params
lr = 0.0002
beta1 = 0.5
beta2 = 0.999

# losses weights
lambda_cycle = 10.0
lambda_identity = 5.0

# random seed (optional)
random.seed(42)
torch.manual_seed(42)

# -----------------------
# DATASET
# -----------------------
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),  # maps [0,1] -> [-1,1]
])

class SingleImageFolder(Dataset):
    def __init__(self, root_dir, transform=None):
        self.files = [os.path.join(root_dir, f) for f in sorted(os.listdir(root_dir)) 
                      if f.lower().endswith(("jpg", "jpeg", "png"))]
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        img = Image.open(self.files[idx]).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img

# paths for domain A (real photos) and domain B (Ghibli style)
dataset_A = SingleImageFolder(os.path.join(data_root, "trainA"), transform=transform)
dataset_B = SingleImageFolder(os.path.join(data_root, "trainB"), transform=transform)

loader_A = DataLoader(dataset_A, batch_size=batch_size, shuffle=True, num_workers=6, drop_last=True)
loader_B = DataLoader(dataset_B, batch_size=batch_size, shuffle=True, num_workers=6, drop_last=True)

steps_per_epoch = min(len(loader_A), len(loader_B))
print(f"Dataset sizes -> A: {len(dataset_A)}, B: {len(dataset_B)}. Steps/epoch: {steps_per_epoch}")

# -----------------------
# MODELS
# -----------------------
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=0),
            nn.InstanceNorm2d(channels, affine=True),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=0),
            nn.InstanceNorm2d(channels, affine=True),
        )
    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, n_residual_blocks=12):
        super().__init__()
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(inplace=True)
        ]
        in_channels = 64
        for _ in range(2):
            out_channels = in_channels * 2
            model += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_channels, affine=True),
                nn.ReLU(inplace=True)
            ]
            in_channels = out_channels
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_channels)]
        for _ in range(2):
            out_channels = in_channels // 2
            model += [
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_channels, affine=True),
                nn.ReLU(inplace=True)
            ]
            in_channels = out_channels
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, kernel_size=7),
            nn.Tanh()
        ]
        self.model = nn.Sequential(*model)
    def forward(self, x):
        return self.model(x)

class Discriminator(nn.Module):
    def __init__(self, input_nc=3):
        super().__init__()
        layers = [
            nn.Conv2d(input_nc, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        ]
        in_channels = 64
        for _ in range(4):  # 4 additional => total 5 downsamples
            out_channels = min(in_channels * 2, 512)
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1),
                nn.InstanceNorm2d(out_channels, affine=True),
                nn.LeakyReLU(0.2, inplace=True)
            ]
            in_channels = out_channels
        layers += [nn.Conv2d(in_channels, 1, kernel_size=4, padding=1)]
        self.model = nn.Sequential(*layers)
    def forward(self, x):
        return self.model(x)

# instantiate
G_A2B = Generator(n_residual_blocks=12).to(device)
G_B2A = Generator(n_residual_blocks=12).to(device)
D_A = Discriminator().to(device)
D_B = Discriminator().to(device)

# -----------------------
# weight init
# -----------------------
def init_weights(net, init_gain=0.02):
    for m in net.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, init_gain)
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif isinstance(m, (nn.InstanceNorm2d, nn.BatchNorm2d)):
            if getattr(m, "weight", None) is not None:
                nn.init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, "bias", None) is not None:
                nn.init.constant_(m.bias.data, 0.0)

init_weights(G_A2B)
init_weights(G_B2A)
init_weights(D_A)
init_weights(D_B)

def count_params(m):
    return sum(p.numel() for p in m.parameters() if p.requires_grad)

print("Params G:", count_params(G_A2B), " Params D:", count_params(D_A))

# -----------------------
# Replay buffer
# -----------------------
class ReplayBuffer:
    def __init__(self, max_size=50):
        assert max_size > 0
        self.max_size = max_size
        self.data = []  # list of tensors on CPU
    def push_and_pop(self, data):
        # data: tensor (B, C, H, W)
        returned = []
        for element in data.detach().cpu():
            element = element.unsqueeze(0)  # (1,C,H,W)
            if len(self.data) < self.max_size:
                self.data.append(element.clone())
                returned.append(element.to(device))
            else:
                if random.uniform(0,1) > 0.5:
                    idx = random.randint(0, self.max_size - 1)
                    tmp = self.data[idx].clone().to(device)
                    self.data[idx] = element.clone()
                    returned.append(tmp)
                else:
                    returned.append(element.to(device))
        return torch.cat(returned, dim=0)

fake_A_buffer = ReplayBuffer()
fake_B_buffer = ReplayBuffer()

# -----------------------
# losses, optimizers, schedulers
# -----------------------
criterion_GAN = nn.MSELoss().to(device)
criterion_cycle = nn.L1Loss().to(device)
criterion_identity = nn.L1Loss().to(device)

optimizer_G = optim.Adam(itertools.chain(G_A2B.parameters(), G_B2A.parameters()), lr=lr, betas=(beta1, beta2))
optimizer_D_A = optim.Adam(D_A.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D_B = optim.Adam(D_B.parameters(), lr=lr, betas=(beta1, beta2))

decay_start_epoch = max(1, epochs // 2)
def lambda_rule(epoch):
    return 1.0 - max(0, epoch - decay_start_epoch) / float(max(1, epochs - decay_start_epoch))

scheduler_G = optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lambda_rule)
scheduler_D_A = optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=lambda_rule)
scheduler_D_B = optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=lambda_rule)

# -----------------------
# utilities: sampling images for visualization
# -----------------------
def denorm(tensor):
    # tensor in [-1,1] -> [0,1]
    return (tensor + 1.0) / 2.0

def sample_images(epoch, G_A2B, G_B2A, loader_for_sampling, n_samples=4):
    G_A2B.eval(); G_B2A.eval()
    try:
        real_A = next(iter(loader_for_sampling[0]))
        real_B = next(iter(loader_for_sampling[1]))
    except Exception:
        # fallback: get first batch from each loader
        real_A = next(iter(loader_A))
        real_B = next(iter(loader_B))
    real_A = real_A[:n_samples].to(device)
    real_B = real_B[:n_samples].to(device)
    with torch.no_grad():
        fake_B = G_A2B(real_A)
        fake_A = G_B2A(real_B)
        rec_A = G_B2A(fake_B)
        rec_B = G_A2B(fake_A)
    # build grid: top rows A_real, A_fake, A_rec ; bottom rows B_real, B_fake, B_rec
    grid = torch.cat([real_A, fake_B, rec_A, real_B, fake_A, rec_B], dim=0)
    grid = denorm(grid)
    grid = vutils.make_grid(grid, nrow=n_samples, padding=2)
    path = os.path.join(sample_dir, f"epoch_{epoch}.png")
    vutils.save_image(grid, path)
    print(f"Saved sample image to: {path}")
    G_A2B.train(); G_B2A.train()

# -----------------------
# Mixed precision scaler
# -----------------------
use_amp = True if torch.cuda.is_available() else False
scaler_G = torch.cuda.amp.GradScaler(enabled=use_amp)
scaler_D_A = torch.cuda.amp.GradScaler(enabled=use_amp)
scaler_D_B = torch.cuda.amp.GradScaler(enabled=use_amp)

# -----------------------
# Training loop
# -----------------------
print("Starting training loop...")
for epoch in range(1, epochs + 1):
    epoch_start = time.time()
    # iterate zipped loaders (stops at smaller loader length)
    loop = zip(loader_A, loader_B)
    for i, (real_A, real_B) in enumerate(loop, start=1):
        real_A = real_A.to(device)
        real_B = real_B.to(device)

        # -----------------------------------------
        # Train Generators (G_A2B & G_B2A)
        # -----------------------------------------
        optimizer_G.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            # identity
            idt_B = G_A2B(real_B)
            loss_id_B = criterion_identity(idt_B, real_B) * lambda_identity
            idt_A = G_B2A(real_A)
            loss_id_A = criterion_identity(idt_A, real_A) * lambda_identity

            # GAN loss
            fake_B = G_A2B(real_A)
            pred_fake_B = D_B(fake_B)
            loss_GAN_A2B = criterion_GAN(pred_fake_B, torch.ones_like(pred_fake_B))

            fake_A = G_B2A(real_B)
            pred_fake_A = D_A(fake_A)
            loss_GAN_B2A = criterion_GAN(pred_fake_A, torch.ones_like(pred_fake_A))

            # cycle loss
            recov_A = G_B2A(fake_B)
            loss_cycle_A = criterion_cycle(recov_A, real_A) * lambda_cycle

            recov_B = G_A2B(fake_A)
            loss_cycle_B = criterion_cycle(recov_B, real_B) * lambda_cycle

            loss_G = loss_GAN_A2B + loss_GAN_B2A + loss_cycle_A + loss_cycle_B + loss_id_A + loss_id_B

        scaler_G.scale(loss_G).backward()
        scaler_G.step(optimizer_G)
        scaler_G.update()

        # -----------------------------------------
        # Train Discriminator A
        # -----------------------------------------
        optimizer_D_A.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            pred_real = D_A(real_A)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

            fake_A_for_disc = fake_A_buffer.push_and_pop(fake_A.detach())  # (B,C,H,W)
            pred_fake = D_A(fake_A_for_disc)
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

            loss_D_A = (loss_D_real + loss_D_fake) * 0.5

        scaler_D_A.scale(loss_D_A).backward()
        scaler_D_A.step(optimizer_D_A)
        scaler_D_A.update()

        # -----------------------------------------
        # Train Discriminator B
        # -----------------------------------------
        optimizer_D_B.zero_grad()
        with torch.cuda.amp.autocast(enabled=use_amp):
            pred_real = D_B(real_B)
            loss_D_real = criterion_GAN(pred_real, torch.ones_like(pred_real))

            fake_B_for_disc = fake_B_buffer.push_and_pop(fake_B.detach())
            pred_fake = D_B(fake_B_for_disc)
            loss_D_fake = criterion_GAN(pred_fake, torch.zeros_like(pred_fake))

            loss_D_B = (loss_D_real + loss_D_fake) * 0.5

        scaler_D_B.scale(loss_D_B).backward()
        scaler_D_B.step(optimizer_D_B)
        scaler_D_B.update()

        # Logging (every batch)
        if i % 10 == 0 or i == steps_per_epoch:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{steps_per_epoch}] "
                  f"[G {loss_G.item():.4f}] [D_A {loss_D_A.item():.4f}] [D_B {loss_D_B.item():.4f}]")

        # limit to steps_per_epoch
        if i >= steps_per_epoch:
            break

    # update schedulers
    scheduler_G.step()
    scheduler_D_A.step()
    scheduler_D_B.step()

    # save checkpoints every save_interval epochs
    if epoch % save_interval == 0 or epoch == 1:
        ckpt = {
            "epoch": epoch,
            "G_A2B": G_A2B.state_dict(),
            "G_B2A": G_B2A.state_dict(),
            "D_A": D_A.state_dict(),
            "D_B": D_B.state_dict(),
            "optim_G": optimizer_G.state_dict(),
            "optim_D_A": optimizer_D_A.state_dict(),
            "optim_D_B": optimizer_D_B.state_dict()
        }
        torch.save(ckpt, os.path.join(save_dir, f"cyclegan_epoch_{epoch}.pth"))
        print(f"Saved checkpoint epoch {epoch} -> {os.path.join(save_dir, f'cyclegan_epoch_{epoch}.pth')}")

    # sample images for qualitative check
    sample_images(epoch, G_A2B, G_B2A, loader_for_sampling=(loader_A, loader_B), n_samples=min(4, batch_size))

    print(f"Epoch {epoch} completed in {time.time() - epoch_start:.2f}s")

print("Training finished.")


Using device: cuda
Dataset sizes -> A: 2500, B: 2500. Steps/epoch: 833


Params G: 14932227  Params D: 6962369
Starting training loop...


  scaler_G = torch.cuda.amp.GradScaler(enabled=use_amp)
  scaler_D_A = torch.cuda.amp.GradScaler(enabled=use_amp)
  scaler_D_B = torch.cuda.amp.GradScaler(enabled=use_amp)
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):
  with torch.cuda.amp.autocast(enabled=use_amp):


KeyboardInterrupt: 

In [11]:
# standalone_test_ghibli.py
import os
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

# -------- CONFIG --------
checkpoint_path = "/teamspace/studios/this_studio/.lightning_studio/checkpoints/cyclegan_epoch_85.pth"
test_image_path = "/teamspace/studios/this_studio/.lightning_studio/test_5.jpg"
out_dir = "/teamspace/studios/this_studio/.lightning_studio/Outputs"
os.makedirs(out_dir, exist_ok=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# -------- Generator class (must match training) --------
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=0),
            nn.InstanceNorm2d(channels, affine=True),
            nn.ReLU(inplace=True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(channels, channels, kernel_size=3, padding=0),
            nn.InstanceNorm2d(channels, affine=True),
        )
    def forward(self, x):
        return x + self.block(x)

class Generator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, n_residual_blocks=12):
        super().__init__()
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, kernel_size=7, padding=0),
            nn.InstanceNorm2d(64, affine=True),
            nn.ReLU(inplace=True)
        ]
        in_channels = 64
        for _ in range(2):
            out_channels = in_channels * 2
            model += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
                nn.InstanceNorm2d(out_channels, affine=True),
                nn.ReLU(inplace=True)
            ]
            in_channels = out_channels

        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_channels)]

        for _ in range(2):
            out_channels = in_channels // 2
            model += [
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
                nn.InstanceNorm2d(out_channels, affine=True),
                nn.ReLU(inplace=True)
            ]
            in_channels = out_channels

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, output_nc, kernel_size=7),
            nn.Tanh()
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

# -------- Instantiate generator and load checkpoint --------
G = Generator(n_residual_blocks=12).to(device)

ckpt = torch.load(checkpoint_path, map_location=device)
# print keys for debugging if needed
if isinstance(ckpt, dict):
    print("Checkpoint keys:", list(ckpt.keys()))
else:
    print("Checkpoint is not a dict (it's probably a state_dict).")

# Load weights: preferred path if you saved dict with "G_A2B"
if isinstance(ckpt, dict) and "G_A2B" in ckpt:
    G.load_state_dict(ckpt["G_A2B"])
    print("Loaded weights from ckpt['G_A2B'].")
else:
    # fallback: try loading the whole checkpoint as a state_dict
    try:
        G.load_state_dict(ckpt)
        print("Loaded checkpoint as state_dict directly.")
    except Exception as e:
        raise RuntimeError(f"Couldn't load generator weights from checkpoint. Error: {e}")

G.eval()

# -------- Preprocess / postprocess --------
img_size = 512
preprocess = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
])

def denorm(tensor):
    return (tensor.clamp(-1,1) + 1.0) / 2.0

def tensor_to_pil(tensor):
    t = denorm(t.squeeze(0)).cpu()
    return transforms.ToPILImage()(t)

# -------- Inference --------
img = Image.open(test_image_path).convert("RGB")
input_tensor = preprocess(img).unsqueeze(0).to(device)  # shape (1,C,H,W)

with torch.no_grad():
    fake = G(input_tensor)

out_pil = tensor_to_pil(fake)
out_name = os.path.basename(test_image_path).rsplit(".",1)[0] + "_ghibli.png"
out_path = os.path.join(out_dir, out_name)
out_pil.save(out_path)
print("Saved output to:", out_path)

# show side-by-side
plt.figure(figsize=(10,5))
plt.subplot(1,2,1); plt.title("Original"); plt.axis("off"); plt.imshow(img)
plt.subplot(1,2,2); plt.title("Ghibli Style"); plt.axis("off"); plt.imshow(out_pil)
plt.show()


Device: cuda


Checkpoint keys: ['epoch', 'G_A2B', 'G_B2A', 'D_A', 'D_B', 'optim_G', 'optim_D_A', 'optim_D_B']
Loaded weights from ckpt['G_A2B'].


UnboundLocalError: local variable 't' referenced before assignment