In [1]:
import os
import sys
import time
import random
import itertools
from pathlib import Path

import certifi
os.environ["SSL_CERT_FILE"] = certifi.where()
os.environ["REQUESTS_CA_BUNDLE"] = os.environ["SSL_CERT_FILE"]

import torch
import torch.nn as nn
from torch.optim import lr_scheduler, Adam
import yaml

PROJECT_ROOT = Path.cwd().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from src.dataset import get_dataloaders
from src.evaluation import TrainingMonitor, Evaluator
from src.utils import set_seed, get_device, save_samples, CheckpointManager, print_model_summary

# --- Experiment configuration ------------------------------------------------
# The YAML file under <PROJECT_ROOT>/config supplies the tunable values; the
# literals passed to .get() below are only fallbacks for missing keys.
config_path = PROJECT_ROOT.joinpath('config', 'config.yaml')
with config_path.open('r') as f:
    config = yaml.safe_load(f)

# Device selection happens before seeding (order kept on purpose).
device = get_device()
training_cfg = config['training']
set_seed(training_cfg.get('seed', 42))

IMAGE_SIZE = config['data'].get('image_size', 256)
BATCH_SIZE = training_cfg.get('batch_size', 1)
EPOCHS = training_cfg.get('epochs', 200)
# float()/tuple() normalise values that YAML may hand back as str / list.
LR = float(training_cfg.get('learning_rate', 2e-4))
BETAS = tuple(training_cfg.get('betas', [0.5, 0.999]))

# Loss weights and image-pool capacity are fixed in the notebook, not in YAML.
LAMBDA_CYCLE, LAMBDA_ID = 10.0, 5.0
POOL_SIZE = 50

# --- Output locations --------------------------------------------------------
MODEL_NAME = 'cyclegan'
OUTPUT_BASE = PROJECT_ROOT / 'data' / MODEL_NAME
outputs_cfg = config['outputs']
CHECKPOINT_DIR = OUTPUT_BASE / outputs_cfg.get('checkpoints', 'outputs/checkpoints')
SAMPLES_DIR = OUTPUT_BASE / outputs_cfg.get('samples', 'outputs/generated_images')
LOGS_DIR = OUTPUT_BASE / outputs_cfg.get('logs', 'outputs/logs')
REPORTS_DIR = OUTPUT_BASE / outputs_cfg.get('reports', 'outputs/reports')

# Create every output directory up front so later cells can write freely.
for output_dir in (CHECKPOINT_DIR, SAMPLES_DIR, LOGS_DIR, REPORTS_DIR):
    output_dir.mkdir(parents=True, exist_ok=True)

Using MPS (Apple Silicon)


In [None]:
# Both domains are loaded with identical settings; only `domain` differs.
loader_kwargs = dict(batch_size=BATCH_SIZE, image_size=IMAGE_SIZE, paired=False, preload=True)
loaders_apple = get_dataloaders(domain='apple', **loader_kwargs)
loaders_orange = get_dataloaders(domain='orange', **loader_kwargs)

# Pull out the splits used by the training and evaluation cells.
train_loader_apple, train_loader_orange = loaders_apple['train'], loaders_orange['train']
val_loader_apple, val_loader_orange = loaders_apple['val'], loaders_orange['val']

print(f"Train batches - Apple: {len(train_loader_apple)}, Orange: {len(train_loader_orange)}")

Loading apple (imageA) from HuggingFace...


Repo card metadata block was not found. Setting CardData to empty.


Loaded 1019 images for train split.
Preloading apple train...
Preloaded 1019 images
Loading apple (imageA) from HuggingFace...


Repo card metadata block was not found. Setting CardData to empty.


Loaded 266 images for test split.
Preloading apple val...


In [None]:
class ResNetBlock(nn.Module):
    """Residual block that keeps the channel count unchanged.

    The convolutional branch is two rounds of reflection padding, a 3x3
    convolution and instance normalisation, with an in-place ReLU after the
    first round only. ``forward`` adds the branch output to the input.

    Args:
        dim: Number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        layers = []
        for is_last_round in (False, True):
            layers += [
                nn.ReflectionPad2d(1),
                nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                nn.InstanceNorm2d(dim),
            ]
            # No activation after the second convolution: the skip connection
            # is added to the un-rectified branch output.
            if not is_last_round:
                layers.append(nn.ReLU(True))
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.conv_block(x)
        return x + residual


class Generator(nn.Module):
    """ResNet-style image-to-image generator.

    Layer sequence (all stored in ``self.model``):
      1. reflection-padded 7x7 convolution to ``ngf`` channels,
      2. two stride-2 3x3 convolutions, doubling the channels each time,
      3. ``n_blocks`` ``ResNetBlock`` modules at ``4 * ngf`` channels,
      4. two stride-2 transposed convolutions, halving the channels each time,
      5. reflection-padded 7x7 convolution to ``output_nc`` channels + ``Tanh``.

    Every convolution except the last is followed by instance normalisation
    and an in-place ReLU.

    Args:
        input_nc: Channels of the input image.
        output_nc: Channels of the generated image.
        ngf: Channel width after the first convolution.
        n_blocks: Number of residual blocks in the bottleneck.
    """

    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9):
        super().__init__()

        n_downsampling = 2

        # Stem: 7x7 convolution at full resolution.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True),
        ]

        # Encoder: each step halves the resolution and doubles the channels.
        channels = ngf
        for _ in range(n_downsampling):
            layers.extend([
                nn.Conv2d(channels, channels * 2, kernel_size=3, stride=2, padding=1, bias=True),
                nn.InstanceNorm2d(channels * 2),
                nn.ReLU(True),
            ])
            channels *= 2

        # Bottleneck: residual blocks at the widest channel count.
        layers.extend(ResNetBlock(channels) for _ in range(n_blocks))

        # Decoder: mirror of the encoder using transposed convolutions.
        for _ in range(n_downsampling):
            layers.extend([
                nn.ConvTranspose2d(channels, channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
                nn.InstanceNorm2d(channels // 2),
                nn.ReLU(True),
            ])
            channels //= 2

        # Head: back to image channels, squashed by tanh.
        layers.extend([
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the full layer stack."""
        return self.model(x)


class Discriminator(nn.Module):
    """Convolutional discriminator that outputs a single-channel score map.

    Layer sequence (all stored in ``self.model``):
      1. stride-2 4x4 convolution to ``ndf`` channels + LeakyReLU(0.2),
      2. ``n_layers - 1`` stride-2 4x4 convolutions with instance norm and
         LeakyReLU, channel multiplier ``min(2 ** i, 8)``,
      3. one stride-1 4x4 convolution with instance norm and LeakyReLU,
      4. a final stride-1 4x4 convolution down to one channel.

    Args:
        input_nc: Channels of the input image.
        ndf: Channel width after the first convolution.
        n_layers: Controls the number of strided stages (see above).
    """

    @staticmethod
    def _norm_stage(in_channels, out_channels, stride):
        """Return conv -> instance norm -> LeakyReLU for one intermediate stage."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1, bias=True),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2, True),
        ]

    def __init__(self, input_nc=3, ndf=64, n_layers=3):
        super().__init__()

        # First stage has no normalisation.
        layers = [
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True),
        ]

        # Strided stages: the width doubles per level but is capped at 8 * ndf.
        in_channels = ndf
        for level in range(1, n_layers):
            out_channels = ndf * min(2 ** level, 8)
            layers += self._norm_stage(in_channels, out_channels, stride=2)
            in_channels = out_channels

        # One more stage at stride 1, then the 1-channel prediction layer.
        out_channels = ndf * min(2 ** n_layers, 8)
        layers += self._norm_stage(in_channels, out_channels, stride=1)
        layers.append(nn.Conv2d(out_channels, 1, kernel_size=4, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score map for ``x``."""
        return self.model(x)


class ImagePool:
    """Bounded history of previously seen images.

    ``query`` stores incoming images until ``pool_size`` of them are held.
    After that, each incoming image is, with probability 0.5, swapped with a
    randomly chosen stored image (the stored one is returned); otherwise the
    incoming image is returned unchanged and the pool is left alone.

    Args:
        pool_size: Capacity of the history; ``0`` disables the pool.
    """

    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.num_imgs = 0
        self.images = []

    def query(self, images):
        """Return a batch mixing ``images`` with images from the history.

        Args:
            images: Iterable of image tensors (iterated along dim 0).

        Returns:
            ``images`` itself when the pool is disabled, otherwise the selected
            images concatenated along dim 0.
        """
        if self.pool_size == 0:
            return images

        selected = []
        for sample in images:
            # Re-add the leading dimension removed by iterating over the batch.
            sample = torch.unsqueeze(sample.data, 0)

            # Filling phase: store and pass through.
            if self.num_imgs < self.pool_size:
                self.num_imgs += 1
                self.images.append(sample)
                selected.append(sample)
                continue

            # Pool is full: coin flip between "swap with history" and "pass through".
            if random.random() > 0.5:
                slot = random.randint(0, self.pool_size - 1)
                previous = self.images[slot].clone()
                self.images[slot] = sample
                selected.append(previous)
            else:
                selected.append(sample)

        return torch.cat(selected, 0)


def set_requires_grad(nets, requires_grad=False):
    """Set ``requires_grad`` on every parameter of one or more networks.

    Args:
        nets: A single network, or a list/tuple of networks. ``None`` entries
            are skipped. Each network must expose ``parameters()``.
        requires_grad: Value assigned to ``param.requires_grad``.
    """
    # Accept tuples as well as lists; previously a tuple was wrapped as if it
    # were a single network and failed on ``.parameters()``.
    if not isinstance(nets, (list, tuple)):
        nets = [nets]
    for net in nets:
        if net is None:
            continue
        for param in net.parameters():
            param.requires_grad = requires_grad

In [None]:
# Two generators, one per translation direction. In train(), G_apple2monet is
# applied to apple batches and G_orange2photo to orange batches.
# NOTE(review): the "monet"/"photo" suffixes look inherited from another
# notebook; here the domains are apple/orange -- confirm before renaming.
# The construction order below is kept as-is because it determines the order
# in which weights are randomly initialised after set_seed().
G_apple2monet = Generator().to(device)
G_orange2photo = Generator().to(device)

# One discriminator per image domain.
D_apple = Discriminator().to(device)
D_orange = Discriminator().to(device)

print_model_summary(G_apple2monet, 'Generator')
print_model_summary(D_apple, 'Discriminator')

# A single optimizer updates both generators together; each discriminator has
# its own optimizer. All three share the same learning rate and betas.
opt_G = Adam(itertools.chain(G_apple2monet.parameters(), G_orange2photo.parameters()), lr=LR, betas=BETAS)
opt_D_apple = Adam(D_apple.parameters(), lr=LR, betas=BETAS)
opt_D_orange = Adam(D_orange.parameters(), lr=LR, betas=BETAS)

def lambda_rule(epoch):
    """Learning-rate multiplier for ``LambdaLR``.

    Returns 1.0 for the first ``EPOCHS // 2`` epochs, then decreases linearly
    with the number of epochs elapsed since that point.

    Args:
        epoch: Epoch counter passed in by the scheduler.

    Returns:
        The factor by which the base learning rate is multiplied.
    """
    decay_start = EPOCHS // 2
    epochs_into_decay = max(0, epoch - decay_start)
    # The +1 keeps the factor strictly positive for every epoch below EPOCHS.
    decay_span = EPOCHS - decay_start + 1
    return 1.0 - epochs_into_decay / decay_span

# One LambdaLR per optimizer, all driven by the same decay rule.
sched_G, sched_D_apple, sched_D_orange = (
    lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    for optimizer in (opt_G, opt_D_apple, opt_D_orange)
)

# Separate history buffers for generated apple and generated orange images.
pool_apple, pool_orange = ImagePool(POOL_SIZE), ImagePool(POOL_SIZE)

# Squared-error adversarial loss; L1 for cycle-consistency and identity terms.
criterion_GAN = nn.MSELoss()
criterion_cycle, criterion_identity = nn.L1Loss(), nn.L1Loss()

checkpoint_manager = CheckpointManager(model_name=MODEL_NAME, checkpoint_dir=str(CHECKPOINT_DIR))
# Training always starts from epoch 0 in this notebook (no resume logic here).
start_epoch = 0

In [None]:
def train():
    """Train both generators and both discriminators from ``start_epoch`` to ``EPOCHS``.

    Every iteration performs one joint generator update (adversarial + cycle
    + optional identity loss) followed by one update per discriminator, with
    the fake inputs of the discriminators drawn through the image pools.
    Losses, epoch time and learning rate are reported to a ``TrainingMonitor``;
    a checkpoint is written after every epoch. The monitor is saved even when
    training stops early because of an exception or a keyboard interrupt.

    Raises:
        ValueError: If either training loader reports zero batches.
    """
    if len(train_loader_apple) == 0 or len(train_loader_orange) == 0:
        raise ValueError("train() needs at least one batch in each training loader")

    def _images(batch):
        """Return the image tensor of ``batch`` moved to ``device``.

        Accepts both batch layouts: a bare image tensor or an
        ``(images, ...)`` list/tuple. Indexing a bare tensor with ``[0]``
        would silently drop the batch dimension, so the layout is checked.
        """
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        return batch.to(device)

    def _endless(loader):
        """Yield batches forever, restarting ``loader`` whenever it runs out.

        Unlike ``itertools.cycle`` this keeps no copy of past batches and
        re-iterates the loader itself on every pass.
        """
        while True:
            yield from loader

    def _gan_loss(prediction, target_is_real):
        """Adversarial loss against an all-ones or all-zeros target.

        The target is built from ``prediction`` so the discriminator is only
        run once per loss term.
        """
        target = torch.ones_like(prediction) if target_is_real else torch.zeros_like(prediction)
        return criterion_GAN(prediction, target)

    def _discriminator_step(net_D, optimizer, pool, real, fake):
        """Run one optimisation step for ``net_D`` and return its loss tensor."""
        set_requires_grad(net_D, True)
        optimizer.zero_grad()
        fake_from_pool = pool.query(fake.detach())
        loss_real = _gan_loss(net_D(real), True)
        loss_fake = _gan_loss(net_D(fake_from_pool), False)
        loss = (loss_real + loss_fake) * 0.5
        loss.backward()
        optimizer.step()
        return loss

    monitor = TrainingMonitor(model_name=MODEL_NAME, log_dir=str(LOGS_DIR))
    monitor.start_training()

    try:
        for epoch in range(start_epoch, EPOCHS):
            epoch_start = time.time()

            # An epoch covers the longer loader once; the shorter one restarts
            # as often as needed.
            iter_apple = _endless(train_loader_apple)
            iter_orange = _endless(train_loader_orange)
            num_batches = max(len(train_loader_apple), len(train_loader_orange))

            for i in range(num_batches):
                real_apple = _images(next(iter_apple))
                real_orange = _images(next(iter_orange))

                # ---- Generator update -------------------------------------
                fake_orange = G_apple2monet(real_apple)
                rec_apple = G_orange2photo(fake_orange)
                fake_apple = G_orange2photo(real_orange)
                rec_orange = G_apple2monet(fake_apple)

                # Discriminators are frozen while the generators are optimised.
                set_requires_grad([D_apple, D_orange], False)
                opt_G.zero_grad()

                loss_GAN_apple2monet = _gan_loss(D_orange(fake_orange), True)
                loss_GAN_orange2photo = _gan_loss(D_apple(fake_apple), True)

                loss_cycle_apple = criterion_cycle(rec_apple, real_apple) * LAMBDA_CYCLE
                loss_cycle_orange = criterion_cycle(rec_orange, real_orange) * LAMBDA_CYCLE

                loss_G = loss_GAN_apple2monet + loss_GAN_orange2photo + loss_cycle_apple + loss_cycle_orange

                if LAMBDA_ID > 0:
                    # Identity term: a generator fed an image that is already
                    # in its output domain should return it unchanged.
                    idt_apple = G_orange2photo(real_apple)
                    idt_orange = G_apple2monet(real_orange)
                    loss_id_apple = criterion_identity(idt_apple, real_apple) * LAMBDA_ID
                    loss_id_orange = criterion_identity(idt_orange, real_orange) * LAMBDA_ID
                    loss_G = loss_G + loss_id_apple + loss_id_orange

                loss_G.backward()
                opt_G.step()

                # ---- Discriminator updates --------------------------------
                loss_D_apple = _discriminator_step(D_apple, opt_D_apple, pool_apple, real_apple, fake_apple)
                loss_D_orange = _discriminator_step(D_orange, opt_D_orange, pool_orange, real_orange, fake_orange)

                # ---- Logging ----------------------------------------------
                loss_G_value = loss_G.item()
                loss_D_value = (loss_D_apple + loss_D_orange).item()
                monitor.log_loss('loss_G', loss_G_value)
                monitor.log_loss('loss_D', loss_D_value)
                monitor.log_loss('loss_cycle', (loss_cycle_apple + loss_cycle_orange).item())
                if LAMBDA_ID > 0:
                    monitor.log_loss('loss_identity', (loss_id_apple + loss_id_orange).item())

                if i % 100 == 0:
                    print(f"Epoch [{epoch}/{EPOCHS}] Batch [{i}/{num_batches}] "
                          f"Loss_G: {loss_G_value:.4f} Loss_D: {loss_D_value:.4f}")

                if i % 200 == 0:
                    # NOTE(review): the panel titles still say Photo/Monet
                    # although the domains are apple/orange -- confirm intent.
                    with torch.no_grad():
                        samples = [real_apple[0], fake_orange[0], rec_apple[0], real_orange[0], fake_apple[0], rec_orange[0]]
                        titles = ['Photo', 'Fake Monet', 'Rec Photo', 'Monet', 'Fake Photo', 'Rec Monet']
                        save_samples(samples, SAMPLES_DIR / f"epoch_{epoch}_batch_{i}.png", nrow=3, titles=titles)

            epoch_time = time.time() - epoch_start
            monitor.log_epoch_time(epoch_time)
            # Logged before stepping, i.e. the rate that was used this epoch.
            monitor.log_learning_rate(sched_G.get_last_lr()[0])

            if epoch == start_epoch:
                print(f"Estimated total time: {epoch_time * EPOCHS / 3600:.1f} hours")

            sched_G.step()
            sched_D_apple.step()
            sched_D_orange.step()

            checkpoint_manager.save(
                epoch,
                G_apple2monet=G_apple2monet, G_orange2photo=G_orange2photo,
                D_apple=D_apple, D_orange=D_orange,
                opt_G=opt_G, opt_D_apple=opt_D_apple, opt_D_orange=opt_D_orange
            )
    finally:
        # Persist whatever was logged, even if training was interrupted.
        monitor.save()

In [None]:
# Run the training loop defined above (epochs start_epoch..EPOCHS-1); a
# checkpoint is written after each epoch.
train()

In [None]:
def evaluate():
    """Evaluate the apple-domain generator and save comparison images.

    Steps:
      1. hand the orange validation loader to the ``Evaluator`` as real data,
      2. evaluate ``G_apple2monet`` on the apple validation loader and save the
         result as JSON under ``REPORTS_DIR``,
      3. save cycle-reconstruction figures for one validation batch per domain,
      4. print a short summary.

    Both generators are switched to eval mode for the duration of the call and
    are always switched back to train mode afterwards, even if a step raises.

    Returns:
        The result object returned by ``Evaluator.evaluate``.
    """
    def _images(batch):
        """Return the image tensor of ``batch`` moved to ``device``.

        Accepts either a bare image tensor or an ``(images, ...)`` list/tuple,
        so the function works with both batch layouts.
        """
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        return batch.to(device)

    G_apple2monet.eval()
    G_orange2photo.eval()

    try:
        evaluator = Evaluator(model_name=MODEL_NAME, device=device, output_dir=str(REPORTS_DIR))

        print("Computing real Monet statistics for FID...")
        evaluator.set_real_data(val_loader_orange)

        print("Evaluating Photo→Monet translation...")
        result = evaluator.evaluate(
            generator=G_apple2monet,
            content_loader=val_loader_apple,
            num_samples=min(500, len(val_loader_apple.dataset))
        )

        result_path = REPORTS_DIR / f"{MODEL_NAME}_evaluation.json"
        result.save(str(result_path))
        print(f"Evaluation saved to {result_path}")

        print("Generating comparison samples...")
        with torch.no_grad():
            # Only the first validation batch of each domain is visualised.
            sample_apples = _images(next(iter(val_loader_apple)))
            sample_oranges = _images(next(iter(val_loader_orange)))

            fake_oranges = G_apple2monet(sample_apples)
            fake_apples = G_orange2photo(sample_oranges)
            rec_apples = G_orange2photo(fake_oranges)
            rec_oranges = G_apple2monet(fake_apples)

            # Single-image cycle figures: input -> translation -> reconstruction.
            samples_p2m = [sample_apples[0], fake_oranges[0], rec_apples[0]]
            titles_p2m = ['Photo', '→ Monet', '→ Photo']
            save_samples(samples_p2m, SAMPLES_DIR / "final_apple2monet_cycle.png", nrow=3, titles=titles_p2m)

            samples_m2p = [sample_oranges[0], fake_apples[0], rec_oranges[0]]
            titles_m2p = ['Monet', '→ Photo', '→ Monet']
            save_samples(samples_m2p, SAMPLES_DIR / "final_orange2photo_cycle.png", nrow=3, titles=titles_m2p)

            # Grid with up to four apple inputs, one row of three panels each.
            n_samples = min(4, sample_apples.size(0))
            grid_samples = []
            grid_titles = []
            for i in range(n_samples):
                grid_samples.extend([sample_apples[i], fake_oranges[i], rec_apples[i]])
                grid_titles.extend([f'Photo {i+1}', f'Fake Monet {i+1}', f'Rec Photo {i+1}'])
            save_samples(grid_samples, SAMPLES_DIR / "final_grid.png", nrow=3, titles=grid_titles)

        print(f"\n{'='*50}")
        print(f"Evaluation Summary: {MODEL_NAME}")
        print(f"{'='*50}")
        print(f"FID Score: {result.fid_score:.2f}")
        print(f"SSIM Score: {result.ssim_score:.4f}")
        print(f"Total Parameters: {result.total_params:,}")
        print(f"{'='*50}")

        return result
    finally:
        # Restore train mode no matter how evaluation ended.
        G_apple2monet.train()
        G_orange2photo.train()

In [None]:
# Run the evaluation and keep the returned result object for later inspection.
result = evaluate()