In [None]:
!pip install torch_fidelity

In [None]:
import itertools
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch_fidelity
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

"""
Step 1. Define Generator
"""
class Generator(nn.Module):
    """Placeholder generator network.

    No layers are defined yet: the module owns no learnable parameters and
    its forward pass hands the input back untouched. The real architecture
    is expected to be filled in here.

    Args:
        in_channels: Number of channels of the input image. Accepted for
            interface compatibility; currently unused.
    """

    def __init__(self, in_channels):
        super().__init__()

    def forward(self, x):
        """Return ``x`` unchanged (identity mapping)."""
        output = x
        return output

"""
Step 2. Define Discriminator
"""
class Discriminator(nn.Module):
    """Convolutional discriminator.

    The network currently consists of a single stride-2 convolution followed
    by a LeakyReLU, so it maps ``(N, in_channels, H, W)`` to a
    ``(N, 64, H/2, W/2)`` feature map for even ``H`` and ``W``.

    Args:
        in_channels: Number of channels of the input image.

    Attributes:
        model: The sequential stack of layers applied in ``forward``.
        scale_factor: Product of the strides of all ``nn.Conv2d`` layers in
            ``model``, i.e. the factor by which the spatial size shrinks.
            The training loop reads this attribute to size its target
            tensors; it was previously missing.
    """

    def __init__(self, in_channels):
        super().__init__()

        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
        )

        # Derive the down-sampling factor from the layers themselves so it
        # stays correct when more strided convolutions are added above.
        self.scale_factor = 1
        for layer in self.model:
            if isinstance(layer, nn.Conv2d):
                self.scale_factor *= layer.stride[0]

    def forward(self, x):
        """Apply ``self.model`` to ``x`` and return the result."""
        return self.model(x)

"""
Step 3. Define Loss
"""
# The three assignments were left without a right-hand side, which is a
# syntax error. They are filled with the losses commonly used for CycleGAN:
# a least-squares adversarial loss and L1 reconstruction losses.
criterion_GAN = nn.MSELoss()       # adversarial loss on discriminator outputs
criterion_cycle = nn.L1Loss()      # cycle-consistency: recovered vs. real image
criterion_identity = nn.L1Loss()   # identity: generator output vs. its input

"""
Step 4. Initialize G and D
"""
# Two generators (A -> B and B -> A) and one discriminator per domain,
# all operating on 3-channel images.
G_AB = Generator(3)
D_B = Discriminator(3)
G_BA = Generator(3)
D_A = Discriminator(3)

## Total parameters in CycleGAN should be less than 60MB
total_params = sum(
    p.numel()
    for net in (G_AB, G_BA, D_A, D_B)
    for p in net.parameters()
)

# NOTE: the raw parameter count is divided by 1024 * 1024 (not 1e6), so the
# printed figure is in units of 2**20 parameters.
total_params_million = total_params / (1024 * 1024)
print(f'Total parameters in CycleGAN model: {total_params_million:.2f} million')

cuda = torch.cuda.is_available()
print(f'cuda: {cuda}')
if cuda:
    G_AB = G_AB.cuda()
    D_B = D_B.cuda()
    G_BA = G_BA.cuda()
    D_A = D_A.cuda()

    # Move the criteria together with the networks. These calls used to run
    # unconditionally, which would fail on a CPU-only machine for any
    # criterion that holds tensors of its own.
    criterion_GAN = criterion_GAN.cuda()
    criterion_cycle = criterion_cycle.cuda()
    criterion_identity = criterion_identity.cuda()

"""
Step 5. Configure Optimizers
"""
# Learning rate shared by all three optimizers. The assignment was left
# blank (a syntax error); 2e-4 is the value commonly used for CycleGAN with
# Adam and is meant as a starting point for tuning.
lr = 2e-4

# Both generators are updated jointly by one optimizer; each discriminator
# has its own.
optimizer_G = torch.optim.Adam(itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=lr)
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=lr)
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=lr)

"""
Step 6. DataLoader
"""
class ImageDataset(Dataset):
    """Index-aligned dataset over the two image domains.

    Domain A is read from ``<data_dir>/VAE_generation/train`` and domain B
    from ``<data_dir>/VAE_generation_Cartoon/train``. File names are sorted;
    the first 200 of each form the training split and the following 50 the
    validation split.

    Args:
        data_dir: Root directory that contains both domain folders.
        mode: ``'train'`` or ``'valid'``.
        transforms: Optional callable applied to each loaded PIL image.

    Raises:
        ValueError: If ``mode`` is neither ``'train'`` nor ``'valid'``.
    """

    def __init__(self, data_dir, mode='train', transforms=None):
        A_dir = os.path.join(data_dir, 'VAE_generation/train') # modification forbidden
        B_dir = os.path.join(data_dir, 'VAE_generation_Cartoon/train')  # modification forbidden

        if mode == 'train':
            self.files_A = [os.path.join(A_dir, name) for name in sorted(os.listdir(A_dir))[:200]] # can be modified
            self.files_B = [os.path.join(B_dir, name) for name in sorted(os.listdir(B_dir))[:200]] # can be modified
        elif mode == 'valid':
            self.files_A = [os.path.join(A_dir, name) for name in sorted(os.listdir(A_dir))[200:250]] # can be modified
            self.files_B = [os.path.join(B_dir, name) for name in sorted(os.listdir(B_dir))[200:250]] # can be modified
        else:
            # Previously an unknown mode left files_A/files_B unset and only
            # failed later with an AttributeError.
            raise ValueError(f"mode must be 'train' or 'valid', got {mode!r}")

        self.transforms = transforms

    def __len__(self):
        # Items are fetched from both lists with the same index, so the
        # usable length is bounded by the shorter list.
        return min(len(self.files_A), len(self.files_B))

    def __getitem__(self, index):
        """Return the ``(img_A, img_B)`` pair stored at ``index``."""
        # convert('RGB') loads the pixel data and guarantees three channels,
        # which the 3-channel Normalize transform used in this file needs;
        # the context manager closes the underlying file handle.
        with Image.open(self.files_A[index]) as src_A:
            img_A = src_A.convert('RGB')
        with Image.open(self.files_B[index]) as src_B:
            img_B = src_B.convert('RGB')

        if self.transforms is not None:
            img_A = self.transforms(img_A)
            img_B = self.transforms(img_B)

        return img_A, img_B

# Dataset location and preprocessing settings shared by both splits.
data_dir = '/kaggle/input/group-project/image_image_translation'
image_size = (256, 256)

# Resize, convert to a tensor, then normalize each channel with mean 0.5 and
# std 0.5.
transforms_ = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

batch_size = 5

# Training batches are reshuffled every epoch.
train_dataset = ImageDataset(data_dir, mode='train', transforms=transforms_)
trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)

# Validation batches keep their order.
valid_dataset = ImageDataset(data_dir, mode='valid', transforms=transforms_)
validloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

# Tensor type that batches are cast to before being fed to the networks.
if cuda:
    Tensor = torch.cuda.FloatTensor
else:
    Tensor = torch.Tensor

"""
Step 7. Training
"""
Tensor = torch.cuda.FloatTensor if cuda else torch.Tensor

SEED = 42
print("Random Seed:", SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Number of passes over the training set. The name was used by the loop
# below but never defined; adjust the value to the available compute budget.
n_epoches = 100

# Weights of the generator loss terms. They were left blank (a syntax error)
# and were re-assigned on every iteration; the values below are the usual
# CycleGAN defaults and are meant as a starting point for tuning.
weight1 = 5.0   # identity loss
weight2 = 1.0   # adversarial (GAN) loss
weight3 = 10.0  # cycle-consistency loss


def _adversarial_loss(prediction, target_is_real):
    """Return criterion_GAN of ``prediction`` against an all-ones/zeros target.

    The target is created with the same shape, dtype and device as the
    discriminator output, so it does not depend on a hard-coded channel
    count or on a ``scale_factor`` attribute of the discriminator.
    """
    if target_is_real:
        target = torch.ones_like(prediction)
    else:
        target = torch.zeros_like(prediction)
    return criterion_GAN(prediction, target)


def sample_images(real_A, real_B, figside=1.5):
    """Plot one batch of real images and their translations.

    Rows show, top to bottom: real A, G_AB(real A), real B, G_BA(real B).
    Both generators are switched to eval mode; the training loop switches
    them back to train mode at the start of the next epoch.
    """
    G_AB.eval()
    G_BA.eval()
    with torch.no_grad():
        real_A = real_A.type(Tensor)
        real_B = real_B.type(Tensor)
        fake_B = G_AB(real_A)
        fake_A = G_BA(real_B)

    panels = (('real A', real_A), ('fake B', fake_B),
              ('real B', real_B), ('fake A', fake_A))
    n_cols = real_A.size(0)
    fig, axes = plt.subplots(len(panels), n_cols,
                             figsize=(n_cols * figside, len(panels) * figside),
                             squeeze=False)
    for row_axes, (title, batch) in zip(axes, panels):
        for ax in row_axes:
            ax.axis('off')
        for ax, img in zip(row_axes, batch):
            # Undo Normalize(mean=0.5, std=0.5) to get displayable pixels.
            arr = img.detach().cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5
            ax.imshow(np.clip(arr, 0.0, 1.0))
        row_axes[0].set_title(title, fontsize=8, loc='left')
    plt.show()


for epoch in range(n_epoches):
    # Back to training mode: sample_images() leaves the generators in eval
    # mode after each periodic validation.
    G_AB.train()
    G_BA.train()

    for real_A, real_B in trainloader:
        real_A, real_B = real_A.type(Tensor), real_B.type(Tensor)

        """Train Generators"""
        optimizer_G.zero_grad()

        fake_B = G_AB(real_A)
        fake_A = G_BA(real_B)

        # identity loss: a generator fed an image that is already in its
        # target domain should return it unchanged. (The previous code
        # compared G_AB(real_A) with real_A, which penalizes the A -> B
        # translation itself and works against the GAN loss.)
        loss_id_A = criterion_identity(G_BA(real_A), real_A)
        loss_id_B = criterion_identity(G_AB(real_B), real_B)
        loss_identity = (loss_id_A + loss_id_B) / 2

        # GAN loss, train G to make D think it's true
        loss_GAN_AB = _adversarial_loss(D_B(fake_B), True)
        loss_GAN_BA = _adversarial_loss(D_A(fake_A), True)
        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # cycle loss: translating there and back should recover the input
        recov_A = G_BA(fake_B)
        recov_B = G_AB(fake_A)
        loss_cycle_A = criterion_cycle(recov_A, real_A)
        loss_cycle_B = criterion_cycle(recov_B, real_B)
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # G total loss
        loss_G = weight1*loss_identity + weight2*loss_GAN + weight3*loss_cycle

        loss_G.backward()
        optimizer_G.step()

        """Train Discriminator A"""
        optimizer_D_A.zero_grad()

        # detach() keeps the discriminator update from back-propagating
        # into the generators.
        loss_real = _adversarial_loss(D_A(real_A), True)
        loss_fake = _adversarial_loss(D_A(fake_A.detach()), False)
        loss_D_A = (loss_real + loss_fake) / 2

        loss_D_A.backward()
        optimizer_D_A.step()

        """Train Discriminator B"""
        optimizer_D_B.zero_grad()

        loss_real = _adversarial_loss(D_B(real_B), True)
        loss_fake = _adversarial_loss(D_B(fake_B.detach()), False)
        loss_D_B = (loss_real + loss_fake) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

    # validation: every 10 epochs show samples and report the losses of the
    # last training batch
    if (epoch+1) % 10 == 0:
        # validloader is the held-out loader defined in this file; the name
        # `testloader` used here before does not exist.
        valid_real_A, valid_real_B = next(iter(validloader))
        sample_images(valid_real_A, valid_real_B)

        loss_D = (loss_D_A + loss_D_B) / 2
        print(f'[Epoch {epoch+1}/{n_epoches}]')
        print(f'[G loss: {loss_G.item()} | identity: {loss_identity.item()} GAN: {loss_GAN.item()} cycle: {loss_cycle.item()}]')
        print(f'[D loss: {loss_D.item()} | D_A: {loss_D_A.item()} D_B: {loss_D_B.item()}]')

"""
Step 8. Generate Images
"""
## Raw Image to Cartoon Image
generate_transforms = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

to_image = transforms.ToPILImage()


def _translate_directory(generator, src_dir, out_dir):
    """Translate every image in ``src_dir`` with ``generator`` into ``out_dir``.

    Images are processed in batches of ``batch_size`` and saved under their
    original file names. Returns the list of source file paths.
    """
    paths = [os.path.join(src_dir, name) for name in os.listdir(src_dir)]
    os.makedirs(out_dir, exist_ok=True)

    generator.eval()
    # Inference only: no_grad avoids building autograd graphs per batch.
    with torch.no_grad():
        for start in range(0, len(paths), batch_size):
            batch_paths = paths[start:start + batch_size]

            # read images (forced to RGB to match the 3-channel Normalize)
            imgs = []
            for path in batch_paths:
                with Image.open(path) as src:
                    imgs.append(generate_transforms(src.convert('RGB')))
            imgs = torch.stack(imgs, 0).type(Tensor)

            # generate
            fake_imgs = generator(imgs).detach().cpu()

            # save
            for path, fake in zip(batch_paths, fake_imgs):
                img_arr = fake.permute(1, 2, 0).numpy()
                lowest = np.min(img_arr)
                span = np.max(img_arr) - lowest
                if span > 0:
                    # per-image min-max stretch to the 0..255 range
                    img_arr = (img_arr - lowest) * 255 / span
                else:
                    # constant image: avoid dividing by zero
                    img_arr = np.zeros_like(img_arr)
                img = to_image(img_arr.astype(np.uint8))
                img.save(os.path.join(out_dir, os.path.basename(path)))
    return paths


def _geometric_mean_score(reference_dir, generated_dir):
    """Compute and print sqrt(FID / IS) for the two image directories.

    Returns NaN (after printing a notice) when the reported inception score
    is not positive, so later arithmetic does not hit an undefined name.
    """
    metrics = torch_fidelity.calculate_metrics(
        input1=reference_dir,
        input2=generated_dir,
        cuda=cuda,  # follow the detected device instead of forcing CUDA
        fid=True,
        isc=True
    )

    fid_score = metrics["frechet_inception_distance"]
    is_score = metrics["inception_score_mean"]

    if is_score > 0:
        score = np.sqrt(fid_score / is_score)
        print("Geometric Mean Score:", score)
        return score
    print("IS is 0, GMS cannot be computed!")
    return float('nan')


test_dir = os.path.join(data_dir, 'VAE_generation/test') # modification forbidden
save_dir = '../Cartoon_images'

# Domain A (VAE_generation) -> domain B (Cartoon) is the job of G_AB; the
# previous code ran G_BA here.
files = _translate_directory(G_AB, test_dir, save_dir)
s_value_1 = _geometric_mean_score(
    "/kaggle/input/group-project/image_image_translation/VAE_generation_Cartoon/test",
    save_dir,
)


## Cartoon Image to Raw Image

test_dir = os.path.join(data_dir, 'VAE_generation_Cartoon/test')
save_dir = '../Raw_images'

files = _translate_directory(G_BA, test_dir, save_dir)
s_value_2 = _geometric_mean_score(
    "/kaggle/input/group-project/image_image_translation/VAE_generation/test",
    save_dir,
)


# Final score: mean of both directions, rounded to 5 decimals (NaN if either
# direction could not be scored).
s_value = np.round((s_value_1+s_value_2)/2, 5)
df = pd.DataFrame({'id': [1], 'label': [s_value]})

csv_path = "Username.csv"
df.to_csv(csv_path, index=False)

print(f"CSV saved to {csv_path}")