In [1]:
import os
import sys
import time
import random
import itertools
from pathlib import Path

import certifi
os.environ["SSL_CERT_FILE"] = certifi.where()
os.environ["REQUESTS_CA_BUNDLE"] = os.environ["SSL_CERT_FILE"]

import torch
import torch.nn as nn
from torch.optim import lr_scheduler, Adam
import yaml

PROJECT_ROOT = Path.cwd().parent.parent
sys.path.insert(0, str(PROJECT_ROOT))

from src.dataset import get_dataloaders
from src.evaluation import TrainingMonitor, Evaluator
from src.utils import set_seed, get_device, save_samples, CheckpointManager, print_model_summary

config_path = PROJECT_ROOT / 'config' / 'config.yaml'
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

device = get_device()
set_seed(config['training'].get('seed', 42))

IMAGE_SIZE = config['data'].get('image_size', 256)
BATCH_SIZE = config['training'].get('batch_size', 1)
EPOCHS = config['training'].get('epochs', 200)
LR = float(config['training'].get('learning_rate', 2e-4))
BETAS = tuple(config['training'].get('betas', [0.5, 0.999]))

LAMBDA_CYCLE = 10.0
LAMBDA_ID = 5.0
POOL_SIZE = 50

MODEL_NAME = 'cyclegan'
OUTPUT_BASE = PROJECT_ROOT / 'data' / MODEL_NAME
CHECKPOINT_DIR = OUTPUT_BASE / config['outputs'].get('checkpoints', 'outputs/checkpoints')
SAMPLES_DIR = OUTPUT_BASE / config['outputs'].get('samples', 'outputs/generated_images')
LOGS_DIR = OUTPUT_BASE / config['outputs'].get('logs', 'outputs/logs')
REPORTS_DIR = OUTPUT_BASE / config['outputs'].get('reports', 'outputs/reports')

CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
SAMPLES_DIR.mkdir(parents=True, exist_ok=True)
LOGS_DIR.mkdir(parents=True, exist_ok=True)
REPORTS_DIR.mkdir(parents=True, exist_ok=True)

Using MPS (Apple Silicon)


In [2]:
loaders_photo = get_dataloaders(domain='photos', batch_size=BATCH_SIZE, image_size=IMAGE_SIZE, paired=False)
loaders_monet = get_dataloaders(domain='monet', batch_size=BATCH_SIZE, image_size=IMAGE_SIZE, paired=False)

train_loader_photo = loaders_photo['train']
train_loader_monet = loaders_monet['train']
val_loader_photo = loaders_photo['val']
val_loader_monet = loaders_monet['val']

print(f"Train batches - Photo: {len(train_loader_photo)}, Monet: {len(train_loader_monet)}")

`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'huggan/monet2photo' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Loading photos (imageB) from HuggingFace...


Repo card metadata block was not found. Setting CardData to empty.
`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'huggan/monet2photo' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Loaded 6287 images for train split.
Loading photos (imageB) from HuggingFace...


Repo card metadata block was not found. Setting CardData to empty.
`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'huggan/monet2photo' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Loaded 751 images for test split.
Loading monet (imageA) from HuggingFace...


Repo card metadata block was not found. Setting CardData to empty.
`trust_remote_code` is not supported anymore.
Please check that the Hugging Face dataset 'huggan/monet2photo' isn't based on a loading script and remove `trust_remote_code`.
If the dataset is based on a loading script, please ask the dataset author to remove it and convert it to a standard format like Parquet.


Loaded 6287 images for train split.
Loading monet (imageA) from HuggingFace...


Repo card metadata block was not found. Setting CardData to empty.


Loaded 751 images for test split.
Train batches - Photo: 6287, Monet: 6287


In [3]:
class ResNetBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(dim),
            nn.ReLU(True),
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
            nn.InstanceNorm2d(dim)
        )

    def forward(self, x):
        return x + self.conv_block(x)


class Generator(nn.Module):
    def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9):
        super().__init__()

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True)
        ]

        n_downsampling = 2
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=True),
                nn.InstanceNorm2d(ngf * mult * 2),
                nn.ReLU(True)
            ]

        mult = 2 ** n_downsampling
        for _ in range(n_blocks):
            model += [ResNetBlock(ngf * mult)]

        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(ngf * mult, ngf * mult // 2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=True),
                nn.InstanceNorm2d(ngf * mult // 2),
                nn.ReLU(True)
            ]

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class Discriminator(nn.Module):
    def __init__(self, input_nc=3, ndf=64, n_layers=3):
        super().__init__()

        model = [
            nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, True)
        ]

        mult = 1
        for i in range(1, n_layers):
            mult_prev = mult
            mult = min(2 ** i, 8)
            model += [
                nn.Conv2d(ndf * mult_prev, ndf * mult, kernel_size=4, stride=2, padding=1, bias=True),
                nn.InstanceNorm2d(ndf * mult),
                nn.LeakyReLU(0.2, True)
            ]

        mult_prev = mult
        mult = min(2 ** n_layers, 8)
        model += [
            nn.Conv2d(ndf * mult_prev, ndf * mult, kernel_size=4, stride=1, padding=1, bias=True),
            nn.InstanceNorm2d(ndf * mult),
            nn.LeakyReLU(0.2, True)
        ]

        model += [nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class ImagePool:
    def __init__(self, pool_size):
        self.pool_size = pool_size
        self.num_imgs = 0
        self.images = []

    def query(self, images):
        if self.pool_size == 0:
            return images
        return_images = []
        for image in images:
            image = torch.unsqueeze(image.data, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs += 1
                self.images.append(image)
                return_images.append(image)
            else:
                if random.random() > 0.5:
                    idx = random.randint(0, self.pool_size - 1)
                    tmp = self.images[idx].clone()
                    self.images[idx] = image
                    return_images.append(tmp)
                else:
                    return_images.append(image)
        return torch.cat(return_images, 0)


def set_requires_grad(nets, requires_grad=False):
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for param in net.parameters():
                param.requires_grad = requires_grad

In [4]:
G_photo2monet = Generator().to(device)
G_monet2photo = Generator().to(device)

D_photo = Discriminator().to(device)
D_monet = Discriminator().to(device)

print_model_summary(G_photo2monet, 'Generator')
print_model_summary(D_photo, 'Discriminator')

opt_G = Adam(itertools.chain(G_photo2monet.parameters(), G_monet2photo.parameters()), lr=LR, betas=BETAS)
opt_D_photo = Adam(D_photo.parameters(), lr=LR, betas=BETAS)
opt_D_monet = Adam(D_monet.parameters(), lr=LR, betas=BETAS)

def lambda_rule(epoch):
    decay_start = EPOCHS // 2
    return 1.0 - max(0, epoch - decay_start) / (EPOCHS - decay_start + 1)

sched_G = lr_scheduler.LambdaLR(opt_G, lr_lambda=lambda_rule)
sched_D_photo = lr_scheduler.LambdaLR(opt_D_photo, lr_lambda=lambda_rule)
sched_D_monet = lr_scheduler.LambdaLR(opt_D_monet, lr_lambda=lambda_rule)

pool_photo = ImagePool(POOL_SIZE)
pool_monet = ImagePool(POOL_SIZE)

criterion_GAN = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

checkpoint_manager = CheckpointManager(checkpoint_dir=str(CHECKPOINT_DIR), model_name=MODEL_NAME)
start_epoch = 0


Generator Summary
Total Parameters:     11,378,179
Trainable Parameters: 11,378,179
Non-trainable:        0


Discriminator Summary
Total Parameters:     2,764,737
Trainable Parameters: 2,764,737
Non-trainable:        0



In [5]:
def train():
    monitor = TrainingMonitor(model_name=MODEL_NAME, log_dir=str(LOGS_DIR))
    monitor.start_training()

    for epoch in range(start_epoch, EPOCHS):
        epoch_start = time.time()

        if len(train_loader_photo) > len(train_loader_monet):
            iter_monet = itertools.cycle(train_loader_monet)
            iter_photo = iter(train_loader_photo)
        else:
            iter_photo = itertools.cycle(train_loader_photo)
            iter_monet = iter(train_loader_monet)

        num_batches = max(len(train_loader_photo), len(train_loader_monet))

        for i in range(num_batches):
            real_photo = next(iter_photo).to(device)
            real_monet = next(iter_monet).to(device)

            fake_monet = G_photo2monet(real_photo)
            rec_photo = G_monet2photo(fake_monet)
            fake_photo = G_monet2photo(real_monet)
            rec_monet = G_photo2monet(fake_photo)

            if LAMBDA_ID > 0:
                idt_photo = G_monet2photo(real_photo)
                idt_monet = G_photo2monet(real_monet)

            set_requires_grad([D_photo, D_monet], False)
            opt_G.zero_grad()

            loss_GAN_photo2monet = criterion_GAN(D_monet(fake_monet), torch.ones_like(D_monet(fake_monet)))
            loss_GAN_monet2photo = criterion_GAN(D_photo(fake_photo), torch.ones_like(D_photo(fake_photo)))

            loss_cycle_photo = criterion_cycle(rec_photo, real_photo) * LAMBDA_CYCLE
            loss_cycle_monet = criterion_cycle(rec_monet, real_monet) * LAMBDA_CYCLE

            if LAMBDA_ID > 0:
                loss_id_photo = criterion_identity(idt_photo, real_photo) * LAMBDA_ID
                loss_id_monet = criterion_identity(idt_monet, real_monet) * LAMBDA_ID
                loss_G = loss_GAN_photo2monet + loss_GAN_monet2photo + loss_cycle_photo + loss_cycle_monet + loss_id_photo + loss_id_monet
            else:
                loss_G = loss_GAN_photo2monet + loss_GAN_monet2photo + loss_cycle_photo + loss_cycle_monet

            loss_G.backward()
            opt_G.step()

            set_requires_grad(D_photo, True)
            opt_D_photo.zero_grad()
            fake_photo_pool = pool_photo.query(fake_photo.detach())
            loss_D_photo_real = criterion_GAN(D_photo(real_photo), torch.ones_like(D_photo(real_photo)))
            loss_D_photo_fake = criterion_GAN(D_photo(fake_photo_pool), torch.zeros_like(D_photo(fake_photo_pool)))
            loss_D_photo = (loss_D_photo_real + loss_D_photo_fake) * 0.5
            loss_D_photo.backward()
            opt_D_photo.step()

            set_requires_grad(D_monet, True)
            opt_D_monet.zero_grad()
            fake_monet_pool = pool_monet.query(fake_monet.detach())
            loss_D_monet_real = criterion_GAN(D_monet(real_monet), torch.ones_like(D_monet(real_monet)))
            loss_D_monet_fake = criterion_GAN(D_monet(fake_monet_pool), torch.zeros_like(D_monet(fake_monet_pool)))
            loss_D_monet = (loss_D_monet_real + loss_D_monet_fake) * 0.5
            loss_D_monet.backward()
            opt_D_monet.step()

            monitor.log_loss('loss_G', loss_G.item())
            monitor.log_loss('loss_D', (loss_D_photo + loss_D_monet).item())
            monitor.log_loss('loss_cycle', (loss_cycle_photo + loss_cycle_monet).item())
            if LAMBDA_ID > 0:
                monitor.log_loss('loss_identity', (loss_id_photo + loss_id_monet).item())

            if i % 100 == 0:
                print(f"Epoch [{epoch}/{EPOCHS}] Batch [{i}/{num_batches}] "
                      f"Loss_G: {loss_G.item():.4f} Loss_D: {(loss_D_photo + loss_D_monet).item():.4f}")

            if i % 200 == 0:
                with torch.no_grad():
                    samples = [real_photo[0], fake_monet[0], rec_photo[0], real_monet[0], fake_photo[0], rec_monet[0]]
                    titles = ['Photo', 'Fake Monet', 'Rec Photo', 'Monet', 'Fake Photo', 'Rec Monet']
                    save_samples(samples, SAMPLES_DIR / f"epoch_{epoch}_batch_{i}.png", nrow=3, titles=titles)

        epoch_time = time.time() - epoch_start
        monitor.log_epoch_time(epoch_time)
        monitor.log_learning_rate(sched_G.get_last_lr()[0])

        if epoch == start_epoch:
            print(f"Estimated total time: {epoch_time * EPOCHS / 3600:.1f} hours")

        sched_G.step()
        sched_D_photo.step()
        sched_D_monet.step()

        checkpoint_manager.save(
            epoch,
            G_photo2monet=G_photo2monet, G_monet2photo=G_monet2photo,
            D_photo=D_photo, D_monet=D_monet,
            opt_G=opt_G, opt_D_photo=opt_D_photo, opt_D_monet=opt_D_monet
        )

    monitor.save()

In [None]:
train()



Epoch [0/200] Batch [0/6287] Loss_G: 19.6496 Loss_D: 1.3133
Samples saved to /Users/andreipriboi/PycharmProjects/gan-style-transfer/data/cyclegan/outputs/generated_images/epoch_0_batch_0.png
Epoch [0/200] Batch [100/6287] Loss_G: 8.1594 Loss_D: 0.5812
Epoch [0/200] Batch [200/6287] Loss_G: 8.0220 Loss_D: 0.5761
Samples saved to /Users/andreipriboi/PycharmProjects/gan-style-transfer/data/cyclegan/outputs/generated_images/epoch_0_batch_200.png
Epoch [0/200] Batch [300/6287] Loss_G: 10.9101 Loss_D: 0.3009


In [None]:
def evaluate():
    G_photo2monet.eval()
    G_monet2photo.eval()

    evaluator = Evaluator(model_name=MODEL_NAME, device=device, output_dir=str(REPORTS_DIR))

    print("Computing real Monet statistics for FID...")
    evaluator.set_real_data(val_loader_monet)

    print("Evaluating Photo→Monet translation...")
    result = evaluator.evaluate(
        generator=G_photo2monet,
        content_loader=val_loader_photo,
        num_samples=min(500, len(val_loader_photo.dataset))
    )

    result_path = REPORTS_DIR / f"{MODEL_NAME}_evaluation.json"
    result.save(str(result_path))
    print(f"Evaluation saved to {result_path}")

    print("Generating comparison samples...")
    with torch.no_grad():
        sample_photos = next(iter(val_loader_photo)).to(device)
        sample_monets = next(iter(val_loader_monet)).to(device)

        fake_monets = G_photo2monet(sample_photos)
        fake_photos = G_monet2photo(sample_monets)
        rec_photos = G_monet2photo(fake_monets)
        rec_monets = G_photo2monet(fake_photos)

        samples_p2m = [sample_photos[0], fake_monets[0], rec_photos[0]]
        titles_p2m = ['Photo', '→ Monet', '→ Photo']
        save_samples(samples_p2m, SAMPLES_DIR / "final_photo2monet_cycle.png", nrow=3, titles=titles_p2m)

        samples_m2p = [sample_monets[0], fake_photos[0], rec_monets[0]]
        titles_m2p = ['Monet', '→ Photo', '→ Monet']
        save_samples(samples_m2p, SAMPLES_DIR / "final_monet2photo_cycle.png", nrow=3, titles=titles_m2p)

        n_samples = min(4, sample_photos.size(0))
        grid_samples = []
        grid_titles = []
        for i in range(n_samples):
            grid_samples.extend([sample_photos[i], fake_monets[i], rec_photos[i]])
            grid_titles.extend([f'Photo {i+1}', f'Fake Monet {i+1}', f'Rec Photo {i+1}'])
        save_samples(grid_samples, SAMPLES_DIR / "final_grid.png", nrow=3, titles=grid_titles)

    print(f"\n{'='*50}")
    print(f"Evaluation Summary: {MODEL_NAME}")
    print(f"{'='*50}")
    print(f"FID Score: {result.fid_score:.2f}")
    print(f"SSIM Score: {result.ssim_score:.4f}")
    print(f"Total Parameters: {result.total_params:,}")
    print(f"{'='*50}")

    G_photo2monet.train()
    G_monet2photo.train()

    return result

In [None]:
result = evaluate()

In [None]:
from datasets import load_dataset
ds = load_dataset("huggan/monet2photo", split="train", trust_remote_code=True)
print(type(ds[0]))
print(ds[0].keys())
print(type(ds[0]['imageA']))

In [2]:
print(ds[0]['imageA'].keys())
print(type(ds[0]['imageA']['bytes']))

dict_keys(['bytes', 'path'])
<class 'bytes'>


In [3]:
print(ds[0]['imageA'])


{'bytes': b'\xff\xd8\xff\xe0\x00\x10JFIF\x00\x01\x01\x01\x01,\x01,\x00\x00\xff\xdb\x00C\x00\x08\x06\x06\x07\x06\x05\x08\x07\x07\x07\t\t\x08\n\x0c\x14\r\x0c\x0b\x0b\x0c\x19\x12\x13\x0f\x14\x1d\x1a\x1f\x1e\x1d\x1a\x1c\x1c $.\' ",#\x1c\x1c(7),01444\x1f\'9=82<.342\xff\xdb\x00C\x01\t\t\t\x0c\x0b\x0c\x18\r\r\x182!\x1c!22222222222222222222222222222222222222222222222222\xff\xc0\x00\x11\x08\x01\x00\x01\x00\x03\x01"\x00\x02\x11\x01\x03\x11\x01\xff\xc4\x00\x1b\x00\x00\x02\x03\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x04\x05\x02\x03\x06\x01\x07\x00\xff\xc4\x00;\x10\x00\x02\x02\x01\x03\x02\x05\x01\x06\x05\x03\x04\x03\x00\x03\x00\x01\x02\x03\x11\x00\x04\x12!1A\x05\x13"Qaq\x06\x142\x81\x91\xa1B\xb1\xc1\xd1\xf0#\xe1\xf1\x153Rb\x07$CSr\x82\xff\xc4\x00\x1a\x01\x00\x03\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x03\x04\x01\x00\x05\x06\xff\xc4\x00%\x11\x00\x02\x02\x02\x02\x01\x05\x01\x01\x01\x01\x00\x00\x00\x00\x00\x00\x01\x02\x11\x03!\x121A\x04\x13"2Qaq\x14#\xff\xda\x00\x0c\x0