In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch import Tensor, tensor
from math import log2
import os
import numpy as np
import itertools
from torchvision.utils import make_grid
from torchvision.io import read_image
from torch.nn.utils import spectral_norm
import cv2 as cv
from tqdm import tqdm
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import shutil
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import glob
from skimage import io
from torchvision.utils import save_image
import argparse

In [None]:
class MonetPhotoDataset(Dataset):
    """Unpaired (photo, monet) image dataset for CycleGAN-style training.

    The two folders are indexed independently: item ``idx`` pairs the
    ``idx % n_photo``-th photo with the ``idx % n_monet``-th Monet image, so the
    shorter folder is cycled and ``len(self)`` equals the larger folder size.

    Args:
        root_monet: directory containing the Monet images.
        root_photo: directory containing the photos.
        transform: optional callable applied to each loaded PIL image.

    Raises:
        ValueError: if either directory has no entries (indexing would
            otherwise fail later with a ZeroDivisionError).
    """

    def __init__(self, root_monet, root_photo, transform=None) -> None:
        super(MonetPhotoDataset, self).__init__()
        self.root_monet = root_monet
        self.root_photo = root_photo
        # os.listdir order is arbitrary; sort so index -> file is reproducible.
        self.monet_imgs = sorted(os.listdir(root_monet))
        self.photo_imgs = sorted(os.listdir(root_photo))
        if not self.monet_imgs or not self.photo_imgs:
            raise ValueError(
                f"Empty image directory: monet={root_monet!r} ({len(self.monet_imgs)} files), "
                f"photo={root_photo!r} ({len(self.photo_imgs)} files)"
            )
        self.length = max(len(self.monet_imgs), len(self.photo_imgs))
        self.transforms = transform

    def __len__(self):
        return self.length

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an in-memory RGB copy, closing the file handle."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return ``(photo, monet)`` for ``idx``, cycling the shorter folder."""
        monet_path = os.path.join(self.root_monet, self.monet_imgs[idx % len(self.monet_imgs)])
        photo_path = os.path.join(self.root_photo, self.photo_imgs[idx % len(self.photo_imgs)])
        monet = self._load_rgb(monet_path)
        photo = self._load_rgb(photo_path)
        if self.transforms:
            monet = self.transforms(monet)
            photo = self.transforms(photo)
        return photo, monet

In [None]:
def train(critic_monet, critic_photo, gen_monet, gen_photo, loader, opt_critic, opt_gen, epoch, lambda_cycle, device='cuda:0'):
    """Run one epoch of CycleGAN training with Wasserstein-style critic losses.

    For every ``(photo, monet)`` batch yielded by ``loader``:

    1. Critic step: both critics are pushed to score real images higher than
       generated ones (loss ``-(mean(real) - mean(fake))``). Generator outputs
       are detached here, so no gradient reaches the generators.
    2. Generator step: both generators minimise ``-mean(critic(fake))`` plus an
       L1 cycle-consistency term weighted by ``lambda_cycle`` and an unweighted
       L1 identity term.

    Every 100 batches the current fakes are written to ``photo_{idx}.png`` /
    ``monet_{idx}.png`` (rescaled with ``x*0.5+0.5``, assuming generator output
    in [-1, 1]) and a checkpoint dict is passed to ``save_checkpoint``.

    Args:
        critic_monet, critic_photo: critics for the Monet / photo domains.
        gen_monet: generator photo -> Monet.
        gen_photo: generator Monet -> photo.
        loader: iterable of ``(photo, monet)`` tensor batches.
        opt_critic, opt_gen: optimizers over the critic / generator parameters.
        epoch: current epoch index (stored in the checkpoint only).
        lambda_cycle: weight of the cycle-consistency losses.
        device: device the batches are moved to.
    """
    loop = tqdm(loader, leave=True)

    for idx, (photo, monet) in enumerate(loop):
        monet = monet.to(device)
        photo = photo.to(device)

        # ---- Critic step ------------------------------------------------
        fake_photo = gen_photo(monet)
        critic_photo_real = critic_photo(photo).reshape(-1)
        critic_photo_fake = critic_photo(fake_photo.detach()).reshape(-1)
        critic_photo_loss = -(torch.mean(critic_photo_real) - torch.mean(critic_photo_fake))

        fake_monet = gen_monet(photo)
        critic_monet_real = critic_monet(monet).reshape(-1)
        critic_monet_fake = critic_monet(fake_monet.detach()).reshape(-1)
        critic_monet_loss = -(torch.mean(critic_monet_real) - torch.mean(critic_monet_fake))

        critic_loss = critic_photo_loss + critic_monet_loss

        opt_critic.zero_grad()
        # The fakes were detached, so this backward never frees the generator
        # graph that the generator step below still needs: no retain_graph.
        critic_loss.backward()
        opt_critic.step()

        # ---- Generator step ---------------------------------------------
        # Re-score the (non-detached) fakes with the freshly updated critics.
        critic_photo_fake_gen = critic_photo(fake_photo).reshape(-1)
        critic_monet_fake_gen = critic_monet(fake_monet).reshape(-1)
        loss_gen_photo = -torch.mean(critic_photo_fake_gen)
        loss_gen_monet = -torch.mean(critic_monet_fake_gen)

        cycle_monet = gen_monet(fake_photo)
        cycle_photo = gen_photo(fake_monet)
        cycle_monet_loss = torch.abs(monet - cycle_monet).mean()
        cycle_photo_loss = torch.abs(photo - cycle_photo).mean()

        identity_monet = gen_monet(monet)
        identity_photo = gen_photo(photo)
        identity_monet_loss = torch.abs(monet - identity_monet).mean()
        identity_photo_loss = torch.abs(photo - identity_photo).mean()

        gen_loss = (
            loss_gen_photo
            + loss_gen_monet
            + cycle_monet_loss * lambda_cycle
            + cycle_photo_loss * lambda_cycle
            + identity_monet_loss
            + identity_photo_loss
        )

        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        if idx % 100 == 0:
            save_image(fake_photo * 0.5 + 0.5, f"photo_{idx}.png")
            save_image(fake_monet * 0.5 + 0.5, f"monet_{idx}.png")
            checkpoint = {
                'generator_monet': gen_monet.state_dict(),
                'generator_photo': gen_photo.state_dict(),
                'critic_monet': critic_monet.state_dict(),
                'critic_photo': critic_photo.state_dict(),
                'opt_gen': opt_gen.state_dict(),
                'opt_critic': opt_critic.state_dict(),
                'epoch': epoch,
                "lambda_cycle": lambda_cycle
            }
            save_checkpoint(checkpoint)

        loop.set_postfix(critic_loss=critic_loss.item(), gen_loss=gen_loss.item())

In [None]:
def get_loader(root_monet='../input/gan-getting-started/monet_jpg',
               root_photo='../input/gan-getting-started/photo_jpg',
               image_size=256, batch_size=4, num_workers=2, shuffle=True):
    """Build the unpaired (photo, monet) dataset and a DataLoader over it.

    Each image is resized to ``image_size`` x ``image_size``, converted to a
    tensor and normalised per channel with mean 0.5 / std 0.5, which maps the
    ``ToTensor`` range [0, 1] to [-1, 1].

    Args:
        root_monet: directory of Monet images (default: Kaggle input path).
        root_photo: directory of photos (default: Kaggle input path).
        image_size: side length images are resized to.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes. NOTE(review): on platforms
            that spawn workers (Windows/macOS), classes defined in a notebook
            may not be importable by workers; use 0 if loading hangs/fails.
        shuffle: whether to reshuffle each epoch.

    Returns:
        ``(dataset, loader)``.
    """
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = MonetPhotoDataset(root_monet, root_photo, transform=transform)
    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
    )
    return dataset, loader

def save_checkpoint(state, filename='./model.pt'):
    """Persist a checkpoint object to disk with ``torch.save``.

    Args:
        state: any object ``torch.save`` can serialize (in this notebook, a
            dict of state_dicts plus bookkeeping values).
        filename: destination path; an existing file at that path is replaced.
    """
    destination = filename
    torch.save(obj=state, f=destination)

def load_checkpoint(critic_monet, critic_photo, gen_monet, gen_photo, opt_gen, opt_critic, lr,
                    path='../input/modelcycle/model (3).pt', map_location=None):
    """Restore generators, critics and optimizers from a checkpoint, then reset lr.

    The checkpoint is expected to be a dict with the keys written by ``train``:
    ``generator_monet``, ``generator_photo``, ``critic_monet``,
    ``critic_photo``, ``opt_gen`` and ``opt_critic``.

    Args:
        critic_monet, critic_photo, gen_monet, gen_photo: modules loaded in place.
        opt_gen, opt_critic: optimizers loaded in place.
        lr: learning rate written into every param group of both optimizers,
            overriding the value stored in the checkpoint.
        path: checkpoint file.
        map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` to load a
            checkpoint saved on GPU on a machine without CUDA.

    Returns:
        The loaded checkpoint dict (so callers can read e.g. ``'epoch'``).
    """
    # torch.load unpickles the file: only load checkpoints from trusted sources.
    checkpoint = torch.load(path, map_location=map_location)
    gen_monet.load_state_dict(checkpoint['generator_monet'])
    gen_photo.load_state_dict(checkpoint['generator_photo'])
    critic_monet.load_state_dict(checkpoint['critic_monet'])
    critic_photo.load_state_dict(checkpoint['critic_photo'])
    opt_gen.load_state_dict(checkpoint['opt_gen'])
    opt_critic.load_state_dict(checkpoint['opt_critic'])
    for optimizer in (opt_critic, opt_gen):
        for group in optimizer.param_groups:
            group['lr'] = lr
    return checkpoint

In [None]:
class SNConv2d(nn.Module):
    """``nn.Conv2d`` wrapped with spectral normalisation of its weight.

    Defaults give a shape-preserving 3x3 convolution with reflect padding.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True, padding_mode='reflect'):
        super(SNConv2d, self).__init__()
        self.conv = spectral_norm(nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding,
                                           bias=bias, padding_mode=padding_mode))

    def forward(self, x):
        return self.conv(x)


class SNConvTranspose2d(nn.Module):
    """``nn.ConvTranspose2d`` wrapped with spectral normalisation of its weight.

    Counterpart of ``SNConv2d`` for up-sampling (referenced by ``ConvBlock``
    with ``down=False``). ``ConvTranspose2d`` only supports zero padding, so
    there is no ``padding_mode`` argument. ``spectral_norm`` automatically
    normalises along dim 1 for transposed convolutions.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, output_padding=0, bias=True):
        super(SNConvTranspose2d, self).__init__()
        self.conv = spectral_norm(nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride,
                                                     padding=padding, output_padding=output_padding, bias=bias))

    def forward(self, x):
        return self.conv(x)


class ConvBlock(nn.Module):
    """Spectral-norm convolution -> InstanceNorm -> optional LeakyReLU(0.2).

    With ``down=True`` the convolution is an ``SNConv2d``; with ``down=False``
    it is an ``SNConvTranspose2d`` (which must be defined in the notebook for
    that path to work). ``use_act=False`` replaces the activation with
    ``nn.Identity``.
    """

    def __init__(self, in_channels, out_channels, use_act=True, kernel_size=3, stride=1, padding=1, down=True):
        super().__init__()
        if down:
            conv_layer = SNConv2d(in_channels, out_channels, kernel_size=kernel_size,
                                  stride=stride, padding=padding)
        else:
            conv_layer = SNConvTranspose2d(in_channels, out_channels, kernel_size,
                                           stride=stride, padding=padding)
        norm_layer = nn.InstanceNorm2d(out_channels)
        if use_act:
            act_layer = nn.LeakyReLU(0.2, inplace=True)
        else:
            act_layer = nn.Identity()
        self.conv = nn.Sequential(conv_layer, norm_layer, act_layer)

    def forward(self, x):
        return self.conv(x)


class ResidualBlock(nn.Module):
    """Identity-skip residual unit: ``y = x + F(x)``.

    ``F`` is two 3x3 ``ConvBlock``s at constant channel count; the second has
    no activation, so the residual sum is returned un-activated.
    """

    def __init__(self, channels):
        super().__init__()
        first = ConvBlock(channels, channels, kernel_size=3, padding=1)
        second = ConvBlock(channels, channels, use_act=False, kernel_size=3, padding=1)
        self.block = nn.Sequential(first, second)

    def forward(self, x):
        branch = self.block(x)
        return x + branch

class Block(nn.Module):
    """Critic stage: 4x4 reflect-padded SNConv2d -> InstanceNorm -> LeakyReLU(0.2)."""

    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        stages = [
            SNConv2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1,
                     bias=True, padding_mode="reflect"),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)

class Critic(nn.Module):
    """Convolutional critic producing a spatial map of unbounded scores.

    Layout: a stride-2 4x4 ``SNConv2d`` + LeakyReLU, then one ``Block`` per
    remaining entry of ``features`` (stride 2, except the last which is
    stride 1), a minibatch-std feature channel, and a final 4x4 ``SNConv2d``
    to a single output channel.

    Args:
        in_channels: channels of the input image.
        features: channel widths of the successive stages (any sequence).
    """

    def __init__(self, in_channels=3, features=(64, 128, 256, 512)):
        super(Critic, self).__init__()
        self.initial = nn.Sequential(
            SNConv2d(
                in_channels,
                features[0],
                kernel_size=4,
                stride=2,
                padding=1
            ),
            nn.LeakyReLU(0.2, inplace=True),
        )

        layers = []
        in_channels = features[0]
        last = len(features) - 1
        for i, feature in enumerate(features[1:], start=1):
            # Decide by position, not by value, so repeated widths are handled.
            layers.append(Block(in_channels, feature, stride=1 if i == last else 2))
            in_channels = feature
        # +1 input channel for the minibatch-std feature map.
        self.final = SNConv2d(in_channels + 1, 1, kernel_size=4, stride=1, padding=1)
        self.model = nn.Sequential(*layers)

    def minibatch_std(self, x):
        """Append one channel holding the mean per-feature std across the batch.

        The std over a single sample is undefined (NaN with the unbiased
        estimator), so for a batch of one the extra channel is filled with 0.
        """
        n, _, h, w = x.shape
        if n > 1:
            stat = torch.std(x, dim=0).mean()
        else:
            stat = x.new_zeros(())
        batch_stats = stat.repeat(n, 1, h, w)
        return torch.cat([x, batch_stats], dim=1)

    def forward(self, x):
        x = self.initial(x)
        x = self.model(x)
        x = self.minibatch_std(x)
        x = self.final(x)
        return x

class GenBlock(nn.Module):
    """Generator stage: bias-free SNConv2d -> (InstanceNorm or Identity) -> LeakyReLU(0.2)."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, use_norm=True):
        super().__init__()
        conv_layer = SNConv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, bias=False)
        norm_layer = nn.InstanceNorm2d(out_channels) if use_norm else nn.Identity()
        self.conv = nn.Sequential(conv_layer, norm_layer, nn.LeakyReLU(0.2, inplace=True))

    def forward(self, x):
        return self.conv(x)

class Generator(nn.Module):
    """U-Net-style image-to-image generator built from spectral-norm convolutions.

    Encoder: four stages of ``GenBlock`` -> ``PixelUnshuffle(2)`` (halves H and
    W, multiplies channels by 4) -> 1x1 ``pointwise`` projection.
    Bottleneck: ``GenBlock`` + ``ResidualBlock`` at ``init_features * 16`` channels.
    Decoder: four stages of ``PixelShuffle(2)`` (doubles H and W, divides
    channels by 4) -> 1x1 projection -> concatenation with the matching
    encoder output -> ``GenBlock``.
    Output: 3x3 ``SNConv2d`` followed by ``tanh``, so values lie in [-1, 1].

    Input height and width must be divisible by 16 (four 2x down-samplings),
    and ``init_features`` must be even (the last ``PixelShuffle`` sees
    ``init_features * 2`` channels, which must be divisible by 4).
    """

    @staticmethod
    def pointwise(in_channels, out_channels):
        """1x1 bias-free spectral-norm projection followed by InstanceNorm (no activation)."""
        return nn.Sequential(
            SNConv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0,
                     bias=False),
            nn.InstanceNorm2d(out_channels)
        )

    def __init__(self, in_channels=3, out_channels=3, init_features=32):
        super(Generator, self).__init__()
        features = init_features

        # Encoder: each PixelUnshuffle(2) quadruples channels; pointwise halves them.
        self.encoder1 = GenBlock(in_channels, features, use_norm=False)
        self.pool1 = nn.PixelUnshuffle(2)
        self.enc_point1 = self.pointwise(features * 4, features * 2)
        self.encoder2 = GenBlock(features * 2, features * 2)
        self.pool2 = nn.PixelUnshuffle(2)
        self.enc_point2 = self.pointwise(features * 8, features * 4)
        self.encoder3 = GenBlock(features * 4, features * 4)
        self.pool3 = nn.PixelUnshuffle(2)
        self.enc_point3 = self.pointwise(features * 16, features * 8)
        self.encoder4 = GenBlock(features * 8, features * 8)
        self.pool4 = nn.PixelUnshuffle(2)
        self.enc_point4 = self.pointwise(features * 32, features * 16)

        self.bottleneck = GenBlock(features * 16, features * 16)
        self.residual = ResidualBlock(features * 16)

        # Decoder: each PixelShuffle(2) quarters channels; pointwise doubles them
        # back so the concat with the skip has 2x the stage width.
        self.upconv4 = nn.PixelShuffle(2)
        self.point4 = self.pointwise(features * 4, features * 8)
        self.decoder4 = GenBlock((features * 8) * 2, features * 8)
        self.upconv3 = nn.PixelShuffle(2)
        self.point3 = self.pointwise(features * 2, features * 4)
        self.decoder3 = GenBlock((features * 4) * 2, features * 4)
        self.upconv2 = nn.PixelShuffle(2)
        self.point2 = self.pointwise(features, features * 2)
        self.decoder2 = GenBlock((features * 2) * 2, features * 2)
        self.upconv1 = nn.PixelShuffle(2)
        self.point1 = self.pointwise(features // 2, features)
        self.decoder1 = GenBlock(features * 2, features)

        self.conv = SNConv2d(
            in_channels=features, out_channels=out_channels, kernel_size=3, padding=1
        )

    def forward(self, x):
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.enc_point1(self.pool1(enc1)))
        enc3 = self.encoder3(self.enc_point2(self.pool2(enc2)))
        enc4 = self.encoder4(self.enc_point3(self.pool3(enc3)))
        residual_connect = self.residual(self.bottleneck(self.enc_point4(self.pool4(enc4))))

        dec4 = self.point4(self.upconv4(residual_connect))
        dec4 = self.decoder4(torch.cat([dec4, enc4], dim=1))
        dec3 = self.point3(self.upconv3(dec4))
        dec3 = self.decoder3(torch.cat([dec3, enc3], dim=1))
        dec2 = self.point2(self.upconv2(dec3))
        dec2 = self.decoder2(torch.cat([dec2, enc2], dim=1))
        dec1 = self.point1(self.upconv1(dec2))
        dec1 = self.decoder1(torch.cat([dec1, enc1], dim=1))
        return torch.tanh(self.conv(dec1))

In [None]:
def main():
    """Build both generators/critics and their optimizers, then train for a few epochs.

    Hyper-parameters are kept as named locals at the top for easy tuning.
    """
    device = "cuda:0"
    lr = 1e-5
    betas = (0.0, 0.9)
    epochs = 5
    lambda_cycle = 10

    gen_monet = Generator().to(device)
    gen_photo = Generator().to(device)
    critic_monet = Critic().to(device)
    critic_photo = Critic().to(device)
    opt_gen = optim.Adam(
        itertools.chain(gen_monet.parameters(), gen_photo.parameters()), lr=lr, betas=betas
    )
    opt_critic = optim.Adam(
        itertools.chain(critic_monet.parameters(), critic_photo.parameters()), lr=lr, betas=betas
    )

    # Build the data pipeline here instead of relying on a global `loader`
    # left over from some other cell.
    _, loader = get_loader()

    for epoch in range(epochs):
        train(critic_monet, critic_photo, gen_monet, gen_photo, loader, opt_critic, opt_gen,
              epoch, lambda_cycle, device=device)