# Imports

In [1]:
import os
import torch
import torch.nn as nn
import numpy as np
import cv2
import matplotlib.pyplot as plt
import itertools
import torch_fidelity
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from PIL import Image
import torchvision.transforms as transforms
import pandas as pd
import random
import torch.nn.functional as F
import sys
from tqdm import tqdm

SEED = random.randrange(2**32 - 1)
# SEED = 2853739981
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
print("Random Seed:", SEED)


Random Seed: 1343577804


# Work

In [None]:
from models.generators import UNetGenerator, FastStyleGenerator, DefaultGenerator, ZhirongGenerator, BasicGenerator
from models.discriminators import PatchDiscriminator, GithubDiscriminator, NLayerDiscriminator, RandomKaggleDiscriminator, PatchSpectralDiscriminator, DefaultDiscriminator

"""
Step 4. Initalize G and D
"""
G_AB = UNetGenerator()
D_B = GithubDiscriminator()
G_BA = BasicGenerator()
D_A = GithubDiscriminator()

## Total parameters in CycleGAN should be less than 60MB
total_params = sum(p.numel() for p in G_AB.parameters()) + \
               sum(p.numel() for p in G_BA.parameters()) + \
               sum(p.numel() for p in D_A.parameters()) + \
               sum(p.numel() for p in D_B.parameters())


def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.xavier_uniform_(m.weight.data)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)
    elif classname.find('Linear') != -1:
        nn.init.xavier_uniform_(m.weight.data)
        nn.init.constant_(m.bias.data, 0)

# init weights
# G_AB.apply(weights_init)
# D_B.apply(weights_init)
# G_BA.apply(weights_init)
# D_A.apply(weights_init)


total_params_million = total_params / (1024 * 1024)
print(f'Total parameters in CycleGAN model: {total_params_million:.2f} million')

Total parameters in CycleGAN model: 27.05 million


In [3]:
"""
Step 3. Define Loss
"""
from collections import defaultdict
from loss_functions import GradientPreservationLoss, IdentityPreservationLoss, PatchNCELoss, HingeAdversarialLoss, EdgeConsistencyLoss, LineContinuityLoss

# criterion_GAN = HingeAdversarialLoss()
# criterion_GAN = nn.MSELoss()
criterion_GAN = nn.BCELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = IdentityPreservationLoss()
# criterion_gradient = GradientPreservationLoss()
# criterion_contrast = PatchNCELoss()
# edge_loss = EdgeAwareLoss()


def dynamic_loss_weighting(epoch, total_epochs):
    # Early epochs: Focus on content preservation
    d = defaultdict(float)
    if epoch < total_epochs * 0.3:
        d.update({
            'identity': 5.0,
            'gan': 1.0,
            'cycle': 10.0,
            'facial_component': 2.0,
            'edge': 2.0,
        })
    elif epoch < total_epochs * 0.7:
        d.update({
            'identity': 2.0,
            'gan': 1.0,
            'cycle': 10.0,
            'line': 2.0,
            'facial_component': 5.0,
        })
    else:
        d.update({
            'identity': 1.0,
            'gan': 1.5,
            'cycle': 10.0,
            'line': 3.0,
            'facial_component': 5.0,
        })
    return d

In [4]:
if torch.cuda.is_available():
    print(f"Current GPU: {torch.cuda.current_device()}")
    print(f"Current GPU name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
    DEVICE = torch.device("cuda")
    G_AB = G_AB.cuda()
    D_B = D_B.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()
    criterion_GAN = criterion_GAN.cuda()
    criterion_cycle = criterion_cycle.cuda()
    criterion_identity = criterion_identity.cuda()
    TENSOR = torch.cuda.FloatTensor
else:
    print("PyTorch does not have access to GPU, falling back to CPU")
    TENSOR = torch.Tensor
    DEVICE = torch.device("cpu")



Current GPU: 0
Current GPU name: NVIDIA GeForce RTX 4070 Ti


In [None]:
"""
Step 5. Configure Optimizers
"""

def get_lr_scheduler(optimizer, n_epochs=100, n_epochs_decay=50, lr_policy='linear', step_size=50, gamma=0.5, min_lr=1e-6, monitor='loss', patience=10):
    if lr_policy == 'linear':
        def lambda_rule(epoch):
            # Keep constant for first n_epochs, then linearly decay to zero
            lr_l = 1.0 - max(0, epoch - n_epochs) / float(n_epochs_decay + 1)
            return lr_l
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)

    elif lr_policy == 'step':
        # Decays the learning rate by gamma every step_size epochs
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    elif lr_policy == 'exponential':
        # Exponentially decays the learning rate by gamma every epoch
        scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=gamma)

    elif lr_policy == 'cosine':
        # Cosine annealing from initial lr to min_lr over total epochs
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=n_epochs + n_epochs_decay, eta_min=min_lr)

    elif lr_policy == 'plateau':
        # Reduces learning rate when a metric has stopped improving
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min' if monitor == 'loss' else 'max', factor=gamma, patience=patience, min_lr=min_lr)

    else:
        raise NotImplementedError(f'learning rate policy {lr_policy} not implemented')

    return scheduler

# Optimizer setup
optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()),
                               lr=0.0002, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(itertools.chain(D_A.parameters(), D_B.parameters()),
                               lr=0.0002, betas=(0.5, 0.999))

# Learning rate schedulers
scheduler_G = get_lr_scheduler(optimizer_G, lr_policy='cosine', min_lr=0.00004, n_epochs=300, n_epochs_decay=100)
scheduler_D = get_lr_scheduler(optimizer_D, lr_policy='cosine', min_lr=0.00004, n_epochs=300, n_epochs_decay=100)


In [5]:
from typing import Type, Optional, Literal

def evaluate(
    model: torch.nn.Module,
    input_dir: str,
    output_dir: str,
    ref_dir: str,
    batch_size: int,
    generate_transforms: transforms.Compose,
    dataloader: Optional[DataLoader] = None,
    mode: Optional[Literal['A_B', 'B_A']] = None,
    verbose: bool = False
) -> float:

    if input_dir is not None and dataloader is None:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        files = [os.path.join(input_dir, name) for name in os.listdir(input_dir)]

        def file_loader():
            for i in range(0, len(files), batch_size):
                imgs = [generate_transforms(Image.open(files[j])) for j in range(i, min(len(files), i + batch_size))]
                yield torch.stack(imgs, 0).type(TENSOR), files[i:i + batch_size]

        loader = file_loader()

    elif dataloader is None:
        raise ValueError("Either input_dir or dataloader must be provided.")

    elif dataloader and mode is None:
        raise ValueError("Mode must be provided if dataloader is provided.")

    elif mode == 'A_B':
        if not os.path.exists('evaluation_dumps/A_B'):
            os.makedirs('evaluation_dumps/A_B')
        output_dir = 'evaluation_dumps/A_B'
    elif mode == 'B_A':
        if not os.path.exists('evaluation_dumps/B_A'):
            os.makedirs('evaluation_dumps/B_A')
        output_dir = 'evaluation_dumps/B_A'

    model.eval()
    ############## ORIGINAL CODE ##############
    # for i in range(0, len(files), batch_size):
    #     # Read and transform images
    #     imgs = [generate_transforms(Image.open(files[j])) for j in range(i, min(len(files), i + batch_size))]
    #     imgs = torch.stack(imgs, 0).type(TENSOR)

    #     # Generate images
    #     fake_imgs = model(imgs).detach().cpu()

    #     # Save generated images
    #     for j in range(fake_imgs.size(0)):
    #         img = fake_imgs[j].squeeze().permute(1, 2, 0).numpy()
    #         img = (img - np.min(img)) * 255 / (np.max(img) - np.min(img))
    #         img = transforms.ToPILImage()(img.astype(np.uint8))
    #         _, name = os.path.split(files[i + j])
    #         img.save(os.path.join(output_dir, name))
    ###########################################

    ############### WITH DATALOADER SUPPORT ###############
    with torch.no_grad():
        for i, batch in enumerate(dataloader if dataloader else loader):
            if isinstance(batch, (list, tuple)) and len(batch) == 4: # the only time that its a dataloader is when both A and B are "zipped"
                imgs = (batch[0] if mode == 'A_B' else batch[1]).type(TENSOR)
                filenames = batch[2] if mode == 'A_B' else batch[3]
            else: # otherwise, its a single folder
                imgs, filenames = batch

            # Generate images
            fake_imgs = model(imgs).detach().cpu()

            # Save generated images
            for img, name in zip(fake_imgs, filenames):
                img = img.squeeze().permute(1, 2, 0).numpy()
                img = (img - np.min(img)) * 255 / (np.max(img) - np.min(img))
                img = transforms.ToPILImage()(img.astype(np.uint8))
                img.save(os.path.join(output_dir, os.path.basename(name)))
        torch.cuda.empty_cache()
    ######################################################


    # Compute metrics
    metrics: dict[str, float] = torch_fidelity.calculate_metrics(
        input1=output_dir,
        input2=ref_dir if dataloader is None else dataloader.dataset.get_partial_dataset('B' if mode == 'A_B' else 'A'),
        cuda=True,
        fid=True,
        isc=True,
        verbose=verbose
    )

    fid_score: float = metrics["frechet_inception_distance"]
    is_score: float = metrics["inception_score_mean"]

    del imgs; del fake_imgs
    if is_score > 0:
        gms: float = np.sqrt(fid_score / is_score)
        print("Geometric Mean Score:", gms)
        return gms, fid_score, is_score
    else:
        print("IS is 0, GMS cannot be computed!")
        return 0, 0, 0

In [None]:
from typing import Callable, Literal


def run_one_epoch(
    G_AB: nn.Module,
    G_BA: nn.Module,
    D_A: nn.Module,
    D_B: nn.Module,
    state: Literal["train", "eval"],
    dataloader: DataLoader,
    criterion_identity: Callable,
    criterion_GAN: Callable,
    criterion_cycle: Callable,
    optimizer_G: torch.optim.Optimizer,
    optimizer_D: torch.optim.Optimizer,
    # optimizer_D_A: torch.optim.Optimizer,
    # optimizer_D_B: torch.optim.Optimizer,
    lambdas: dict[str, float],
) -> dict[str, float]:

    G_AB.train(), G_BA.train()

    running_losses = {
        "G": 0.0, "D_A": 0.0, "D_B": 0.0,
        "identity": 0.0, "gan": 0.0, "cycle": 0.0, "line": 0.0
    }

    with tqdm(dataloader, unit="batch", desc="Training" if state == "train" else "Validation") as tepoch:
        for real_A, real_B, *_ in tepoch:
            optimizer_G.zero_grad()
            D_A.eval(), D_B.eval()
            with torch.amp.autocast(DEVICE.type, enabled=USEAUTOCAST):
                real_A, real_B = real_A.to(DEVICE, non_blocking=True), real_B.to(DEVICE, non_blocking=True)

                # Train Generators
                fake_B = G_AB(real_A)
                fake_A = G_BA(real_B)

                # Identity loss
                loss_id_A = criterion_identity(G_BA(real_A), real_A)
                loss_id_B = criterion_identity(G_AB(real_B), real_B)
                loss_identity = (loss_id_A + loss_id_B) / 2

                # GAN loss - trying to fool D
                fake_B_pred = D_B(fake_B)
                fake_A_pred = D_A(fake_A)

                # because discriminator outputs are supposed to be identical no matter where, I can reuse
                valid, fake = torch.ones_like(fake_A_pred, requires_grad=False), torch.zeros_like(fake_A_pred, requires_grad=False)

                loss_GAN_AB = criterion_GAN(fake_B_pred, valid)
                loss_GAN_BA = criterion_GAN(fake_A_pred, valid)
                loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

                # Cycle loss
                loss_cycle_A = criterion_cycle(G_BA(fake_B), real_A)
                loss_cycle_B = criterion_cycle(G_AB(fake_A), real_B)
                loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

                # # Line continuity loss
                # loss_line_A = line_loss(fake_A, real_A)
                # loss_line_B = line_loss(fake_B, real_B)
                # loss_line = (loss_line_A + loss_line_B) / 2

            # Total generator loss
            loss_G = lambdas['identity'] * loss_identity + lambdas['gan'] * loss_GAN + lambdas['cycle'] * loss_cycle
            loss_G.backward()
            optimizer_G.step()



            # Train Discriminators
            optimizer_D.zero_grad()
            D_A.train(), D_B.train()
            with torch.amp.autocast(DEVICE.type, enabled=USEAUTOCAST):
                real_A_pred = D_A(real_A)
                fake_A_pred = D_A(fake_A.detach())
                loss_real_A = criterion_GAN(real_A_pred, valid)
                loss_fake_A = criterion_GAN(fake_A_pred, fake)
                loss_D_A = (loss_real_A + loss_fake_A) / 2
                
                real_B_pred = D_B(real_B)
                fake_B_pred = D_B(fake_B.detach())
                loss_real_B = criterion_GAN(real_B_pred, valid)
                loss_fake_B = criterion_GAN(fake_B_pred, fake)
                loss_D_B = (loss_real_B + loss_fake_B) / 2
                
                loss_D = (loss_D_A + loss_D_B)
            loss_D.backward()
            optimizer_D.step()

            # Accumulate losses
            running_losses["G"] += loss_G.item()
            running_losses["D_A"] += loss_D_A.item()
            running_losses["D_B"] += loss_D_B.item()
            running_losses["identity"] += loss_identity.item()
            running_losses["gan"] += loss_GAN.item()
            running_losses["cycle"] += loss_cycle.item()
            # running_losses["line"] += loss_line.item()

    print(f'[G loss: {loss_G.item()} | identity: {loss_identity.item()} GAN: {loss_GAN.item()} cycle: {loss_cycle.item()}]')
    print(f'[D loss: {loss_D.item()} | D_A: {loss_D_A.item()} D_B: {loss_D_B.item()}]')
    # Average the losses over the dataset
    return {k: v / len(dataloader) for k, v in running_losses.items()}

In [6]:
"""
Step 6. DataLoader
"""
from CustomImageDataset import CustomImageDataset as ImageDataset
# data_dir = '/kaggle/input/group-project/image_image_translation'
data_dir = ''

image_size = (256, 256)
transforms_ = transforms.Compose([
    transforms.Resize(image_size),
    # transforms.RandomHorizontalFlip(p=0.5),  # 50% chance of flipping
    # transforms.RandomAffine(degrees=5, translate=(0.1,0.1)),  # Position variation
    # transforms.RandomPerspective(0.2, 0.3),  # Perspective distortion
    # transforms.RandomPosterize(bits=5, p=0.3),  # Posterize
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 5.0))], 0.3),
    # transforms.ColorJitter(brightness=0.4, contrast=0.2, saturation=0.2, hue=0.05),
    # transforms.RandomErasing(p=0.1, scale=(0.05, 0.1)),  # Erase a part of the image
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

evaluation_transforms = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])


loader_params = {
    "batch_size": 4,
    "num_workers": 2,
    "pin_memory": True,
    "shuffle": True,
    # "prefetch_factor": 2,
    "persistent_workers": True
}


trainloader = DataLoader(
    ImageDataset(data_dir, mode='train', transform=transforms_),
    **loader_params
)

loader_params["shuffle"] = False
validloader = DataLoader(
    ImageDataset(data_dir, mode='valid', transform=transforms_),
    **loader_params
)

In [7]:
"""
Step 7. Training
"""
n_epochs = 300
EVALUATION_INTERVAL = 15
evaluation_scaling = 1
patience = 0
USEAUTOCAST = False
loss_history = []
best_metrics = {
    'AVG': float('inf'),
    'AB': {'GMS': float('inf'), 'FID': float('inf'), 'IS': float('inf')},
    'BA': {'GMS': float('inf'), 'FID': float('inf'), 'IS': float('inf')},
    'BEST_AB': {'GMS': float('inf'), 'FID': float('inf'), 'IS': float('inf'), 'epoch': float('inf')},
    'BEST_BA': {'GMS': float('inf'), 'FID': float('inf'), 'IS': float('inf'), 'epoch': float('inf')}
}
for epoch in range(1, n_epochs+1):
    print(f'[Epoch {epoch}/{n_epochs}]')
    epoch_loss = run_one_epoch(
        G_AB, G_BA, D_A, D_B, "train", trainloader,
        criterion_identity, criterion_GAN, criterion_cycle,
        optimizer_G, optimizer_D,
        dynamic_loss_weighting(epoch, n_epochs)
    )

    loss_history.append(epoch_loss)

    # validation
    if epoch % int(EVALUATION_INTERVAL * evaluation_scaling) == 0:
        metrics = best_metrics.copy()
        metrics['AB']['GMS'], metrics['AB']['FID'], metrics['AB']['IS'] = evaluate(G_AB, 
                                                                'NO_INPUT_DIR',
                                                                'EVAL_DUMP',
                                                                'NO_REF_DIR',
                                                                loader_params['batch_size'],
                                                                evaluation_transforms,
                                                                validloader,
                                                                'A_B',
                                                                verbose=True)
        metrics['BA']['GMS'], metrics['BA']['FID'], metrics['BA']['IS'] = evaluate(G_BA, 
                                                                'NO_INPUT_DIR',
                                                                'EVAL_DUMP',
                                                                'NO_REF_DIR',
                                                                loader_params['batch_size'],
                                                                evaluation_transforms,
                                                                validloader,
                                                                'B_A')

        metricsString_AB = " ".join([f"{k}={v:4f}" for k, v in metrics['AB'].items()])
        metricsString_BA = " ".join([f"{k}={v:4f}" for k, v in metrics['BA'].items()])
        metrics['AVG'] = (metrics['AB']['GMS'] + metrics['BA']['GMS']) / 2

        print("A->B |", metricsString_AB)
        print("B->A |", metricsString_BA)
        print(f"Average GMS: {metrics['AVG']}, Best: {best_metrics['AVG']}")

        if metrics['AVG'] > best_metrics['AVG'] * 1.05:
            print("💩🚽 Getting Cooked!")

        elif metrics['AVG'] < best_metrics['AVG']:
            best_metrics['AB'] = metrics['AB'].copy()
            best_metrics['BA'] = metrics['BA'].copy()
            best_metrics['AVG'] = metrics['AVG']
            patience = 0
            print("(ง 🔥 ﾛ 🔥 )ง CooKING! 🚀")
            # Save model checkpoints  
            torch.save(G_AB.state_dict(), f'checkpoints/G_AB_{epoch}.pth')
            torch.save(D_A.state_dict(), f'checkpoints/D_A_{epoch}.pth')
            torch.save(G_BA.state_dict(), f'checkpoints/G_BA_{epoch}.pth')
            torch.save(D_B.state_dict(), f'checkpoints/D_B_{epoch}.pth')
        else:
            patience += 1

        if patience >= 3:  # Reduce evaluation frequency
            print("I'm losing my patience! Best GMS:", best_metrics['AVG'])
            evaluation_interval = int(1.5 * evaluation_scaling)
        

        if metrics['AB']['GMS'] < best_metrics['BEST_AB']['GMS']:
            best_metrics['BEST_AB'] = metrics['AB'].copy()
            best_metrics['BEST_AB']['epoch'] = epoch
            torch.save(G_AB.state_dict(), f'checkpoints_best/G_AB.pth')
            print('Best A->B model saved. Epoch:', epoch)

        if metrics['BA']['GMS'] < best_metrics['BEST_BA']['GMS']:
            best_metrics['BEST_BA'] = metrics['BA'].copy()
            best_metrics['BEST_BA']['epoch'] = epoch
            torch.save(G_BA.state_dict(), f'checkpoints_best/G_BA.pth')
            print('Best B->A model saved. Epoch:', epoch)

torch.save(G_AB.state_dict(), f'checkpoints/G_AB_{n_epochs}.pth')
torch.save(D_A.state_dict(), f'checkpoints/D_A_{n_epochs}.pth')
torch.save(G_BA.state_dict(), f'checkpoints/G_BA_{n_epochs}.pth')
torch.save(D_B.state_dict(), f'checkpoints/D_B_{n_epochs}.pth')

[Epoch 1/300]


NameError: name 'run_one_epoch' is not defined

# Output Results

In [None]:

import pickle
from datetime import datetime
def plot_losses(loss_history):
    formatted_loss_history = {k: [epoch_losses[k] for epoch_losses in loss_history] for k in loss_history[0]}
    timestamp = datetime.now().strftime('%Y%m%d_%H%M')

    with open(os.path.join(f"loss_log-{timestamp}.pkl"), 'wb') as f:
        pickle.dump(formatted_loss_history, f)

    plt.figure(figsize=(10, 6))
    for loss_name, values in formatted_loss_history.items():
        plt.plot(values, label=loss_name)

    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training Losses Over Epochs')
    plt.legend()
    plt.grid(True)
    plt.savefig(f"loss_plot-{timestamp}.png")
    plt.show()


In [8]:
# plot_losses(loss_history)

# Parameters
batch_size = loader_params["batch_size"]

# data_dir = '/kaggle/input/group-project/image_image_translation'
data_dir = '.'


def format_score(score):
    return "{:.4f}".format(score).replace(".", "_")

def report_score(s_value_1, fid_1, is_1, s_value_2, fid_2, is_2, epoch=None):
    s_value = np.round((s_value_1+s_value_2)/2, 5)
    df = pd.DataFrame({'id': [1], 'label': [s_value]})
    filename = f"{format_score(s_value)}-C[{format_score(s_value_1)}-{format_score(fid_1)}-{format_score(is_1)}]-R[{format_score(s_value_2)}-{format_score(fid_2)}-{format_score(is_2)}]"
    if epoch: filename = f"{epoch}-{filename}"
    csv_path = filename+".csv"
    df.to_csv(csv_path, index=False)

    print(f"CSV saved to {csv_path}")



# Process latest model
print("Processing latest model...")
G_AB.load_state_dict(torch.load(f'curbest/with256/G_AB.pth'))
G_BA.load_state_dict(torch.load(f'curbest/with256/G_BA.pth'))
s_value_1, fid_1, is_1 = evaluate(
    model=G_AB,
    input_dir=os.path.join(data_dir, 'VAE_generation/test'),
    output_dir='../Cartoon_images',
    ref_dir=f"{data_dir}/VAE_generation_Cartoon/test",
    batch_size=batch_size,
    generate_transforms=evaluation_transforms,
    verbose=True
)
s_value_2, fid_2, is_2 = evaluate(
    model=G_BA,
    input_dir=os.path.join(data_dir, 'VAE_generation_Cartoon/test'),
    output_dir='../Raw_images',
    ref_dir=f"{data_dir}/VAE_generation/test",
    batch_size=batch_size,
    generate_transforms=evaluation_transforms,
    verbose=True
)
report_score(s_value_1, fid_1, is_1, s_value_2, fid_2, is_2, epoch=n_epochs)
print()

# Process best for each side
# try:
#     print("Processing best sides...")
#     G_AB.load_state_dict(torch.load('checkpoints_best/G_AB.pth'))
#     G_BA.load_state_dict(torch.load('checkpoints_best/G_BA.pth'))
#     s_value_1, fid_1, is_1 = evaluate(
#         model=G_AB,
#         input_dir=os.path.join(data_dir, 'VAE_generation/test'),
#         output_dir='../Cartoon_images',
#         ref_dir=f"{data_dir}/VAE_generation_Cartoon/test",
#         batch_size=batch_size,
#         generate_transforms=evaluation_transforms,
#         verbose=True
#     )
#     s_value_2, fid_2, is_2 = evaluate(
#         model=G_BA,
#         input_dir=os.path.join(data_dir, 'VAE_generation_Cartoon/test'),
#         output_dir='../Raw_images',
#         ref_dir=f"{data_dir}/VAE_generation/test",
#         batch_size=batch_size,
#         generate_transforms=evaluation_transforms,
#         verbose=True
#     )
#     report_score(s_value_1, fid_1, is_1, s_value_2, fid_2, is_2)
# except:
#     print("Failed to process best sides.")
# finally:
#     print()

# print("Processing saved model checkpoints...")
# checkpoints_dir = 'checkpoints'
# generator_weights = defaultdict(lambda: [None, None])
# for filename in os.listdir(checkpoints_dir):
#     model_type, domain, epoch_num = filename[:-4].split('_')
#     if model_type == 'G':
#         generator_weights[epoch_num][0 if domain == 'AB' else 1] = os.path.join(checkpoints_dir, filename) 

# for epoch_num, (G_AB_path, G_BA_path) in generator_weights.items():
#     # Process saved model
#     G_AB.load_state_dict(torch.load(G_AB_path))
#     G_BA.load_state_dict(torch.load(G_BA_path))
#     print(f"Evaluating epoch {epoch_num}")

#     print("Metrics for A -> B:", end=" ")
#     # Raw to Cartoon
#     s_value_1, fid_1, is_1 = evaluate(
#         model=G_AB,
#         input_dir=os.path.join(data_dir, 'VAE_generation/test'),
#         output_dir='../Cartoon_images',
#         ref_dir=f"{data_dir}/VAE_generation_Cartoon/test",
#         batch_size=batch_size,
#         generate_transforms=evaluation_transforms
#     )

#     # Cartoon to Raw
#     s_value_2, fid_2, is_2 = evaluate(
#         model=G_BA,
#         input_dir=os.path.join(data_dir, 'VAE_generation_Cartoon/test'),
#         output_dir='../Raw_images',
#         ref_dir=f"{data_dir}/VAE_generation/test",
#         batch_size=batch_size,
#         generate_transforms=evaluation_transforms
#     )

#     report_score(s_value_1, fid_1, is_1, s_value_2, fid_2, is_2, epoch=epoch_num)




Processing latest model...


Creating feature extractor "inception-v3-compat" with features ['logits_unbiased', '2048']
Extracting features from input1
Looking for samples non-recursivelty in "../Cartoon_images" with extensions png,jpg,jpeg
Found 1000 samples, some are lossy-compressed - this may affect metrics
  img = torch.ByteTensor(torch.ByteStorage.from_buffer(img.tobytes())).view(height, width, 3)
Processing samples                                                          
Extracting features from input2
Looking for samples non-recursivelty in "./VAE_generation_Cartoon/test" with extensions png,jpg,jpeg
Found 1000 samples, some are lossy-compressed - this may affect metrics
Processing samples                                                          
Inception Score: 2.136652938583595 ± 0.15139583515735827
Frechet Inception Distance: 41.89671612520374


Geometric Mean Score: 4.428156893062493


Creating feature extractor "inception-v3-compat" with features ['logits_unbiased', '2048']
Extracting features from input1
Looking for samples non-recursivelty in "../Raw_images" with extensions png,jpg,jpeg
Found 1000 samples, some are lossy-compressed - this may affect metrics
Processing samples                                                          
Extracting features from input2
Looking for samples non-recursivelty in "./VAE_generation/test" with extensions png,jpg,jpeg
Found 1000 samples, some are lossy-compressed - this may affect metrics
Processing samples                                                          
Inception Score: 3.737211246187865 ± 0.2566777612262926


Geometric Mean Score: 3.9784483740348797
CSV saved to 300-4_2033-C[4_4282-41_8967-2_1367]-R[3_9784-59_1528-3_7372].csv



Frechet Inception Distance: 59.15277193971801


# Clean Up

In [None]:
# import gc
# del G_AB, G_BA, D_A, D_B

# gc.collect()
# with torch.no_grad():
#     torch.cuda.empty_cache()