In [None]:
# !curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# !python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev

In [None]:
# !pip install pytorch-lightning==1.2.10

In [None]:
#!pip install torchinfo
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset
import torchvision.utils as vutils
from PIL import Image
import os
import matplotlib.pyplot as plt
from torchinfo import summary
import pytorch_lightning as pl
import shutil

In [None]:
# Print the installed PyTorch Lightning version so the run's environment is on record.
print(f"{pl.__version__}")

In [None]:
# device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")

In [None]:
# Seed Python's `random`, NumPy and torch in one call so data shuffling and
# weight initialisation are reproducible across runs.
SEED = 42
pl.seed_everything(SEED)
BATCH_SIZE = 4  # images per batch; also the size of the held-out preview split

## Data exploration

In [None]:
class MonetPhotoDataset(Dataset):
    """Yields (Monet painting, photo) image pairs matched by list position.

    The first ``BATCH_SIZE`` paths of each list form a fixed validation/preview
    split; the remaining paths are the training split. Because items are
    matched by index, the dataset length is the shorter of the two lists.

    Args:
        monet: list of Monet image file paths.
        photo: list of photo image file paths.
        transforms: optional callable applied to each loaded PIL image.
        train: if True use the paths after the first ``BATCH_SIZE``,
            otherwise only the first ``BATCH_SIZE``.
    """

    def __init__(self, monet, photo, transforms=None, train=True):
        if train:
            self.monet = monet[BATCH_SIZE:]
            self.photo = photo[BATCH_SIZE:]
        else:
            self.monet = monet[:BATCH_SIZE]
            self.photo = photo[:BATCH_SIZE]
        self.transforms = transforms

    def __len__(self):
        # Index-based matching: never index past the end of the shorter list.
        return min(len(self.monet), len(self.photo))

    def _load(self, path):
        # Decode inside `with` so the file handle is closed, and force 3 channels
        # so grayscale/RGBA files match the 3-channel normalisation and networks.
        with Image.open(path) as img:
            img = img.convert("RGB")
        if self.transforms is not None:
            img = self.transforms(img)
        return img

    def __getitem__(self, index):
        return self._load(self.monet[index]), self._load(self.photo[index])

In [None]:
class PhotoDataset(Dataset):
    """Photo-only dataset for inference; yields ``(image, file_name)``.

    Args:
        photo: list of photo image file paths.
        transforms: optional callable applied to each loaded PIL image.
    """

    def __init__(self, photo, transforms=None):
        self.photo = photo
        self.transforms = transforms

    def __len__(self):
        return len(self.photo)

    @staticmethod
    def _image_name(path):
        # os.path.basename respects the platform's separator; splitting on '/'
        # would return the whole path for Windows-style paths.
        return os.path.basename(path)

    def __getitem__(self, index):
        path = self.photo[index]
        # Decode inside `with` so the file handle is closed; force 3 channels.
        with Image.open(path) as img:
            img = img.convert("RGB")
        if self.transforms is not None:
            img = self.transforms(img)
        return img, self._image_name(path)

In [None]:
class MonetDataModule(pl.LightningDataModule):
    """DataModule for the Monet / photo JPEG folders.

    Images are converted to tensors and normalised with mean 0.5 / std 0.5 per
    channel, mapping [0, 1] to [-1, 1].

    Args:
        monet_dir: directory containing the Monet paintings.
        photo_dir: directory containing the photos.
        num_workers: number of DataLoader worker processes.
    """

    def __init__(self,
                 monet_dir="../input/gan-getting-started/monet_jpg",
                 photo_dir="../input/gan-getting-started/photo_jpg",
                 num_workers=2):
        super(MonetDataModule, self).__init__()
        self.monet_dir = monet_dir
        self.photo_dir = photo_dir
        self.num_workers = num_workers
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])

    def _index_files(self):
        # Sorted listing keeps the train/validation split deterministic.
        self.monet = [os.path.join(self.monet_dir, name) for name in sorted(os.listdir(self.monet_dir))]
        self.photo = [os.path.join(self.photo_dir, name) for name in sorted(os.listdir(self.photo_dir))]

    def prepare_data(self):
        # Still populates the file lists so calling dm.prepare_data() by hand,
        # outside a Trainer, keeps working. Lightning runs this hook on a single
        # process only, so setup() below repeats the indexing for every process.
        self._index_files()

    def setup(self, stage=None):
        # Called on every process, so all ranks have self.monet / self.photo.
        self._index_files()

    def _loader(self, dataset, shuffle):
        # Shared DataLoader configuration for all three stages.
        return torch.utils.data.DataLoader(dataset,
                                           pin_memory=True,
                                           batch_size=BATCH_SIZE,
                                           shuffle=shuffle,
                                           num_workers=self.num_workers)

    def train_dataloader(self):
        return self._loader(MonetPhotoDataset(self.monet, self.photo, self.transform), shuffle=True)

    def val_dataloader(self):
        return self._loader(MonetPhotoDataset(self.monet, self.photo, self.transform, False), shuffle=False)

    def predict_dataloader(self):
        return self._loader(PhotoDataset(self.photo, self.transform), shuffle=False)

In [None]:
# Build the datamodule and list the image files by hand (no Trainer involved),
# then take the validation loader: the first BATCH_SIZE Monet/photo pairs,
# unshuffled, used for the previews below.
dm = MonetDataModule()
dm.prepare_data()
dataloader = dm.val_dataloader()

In [None]:
# Fetch just the first (monet, photo) batch instead of materialising every
# batch of the loader, then show the Monet paintings as a single row.
batch = next(iter(dataloader))
monet_grid = vutils.make_grid(batch[0], nrow=BATCH_SIZE, padding=2, normalize=True).cpu()
fig, ax = plt.subplots(figsize=(10, 10))
ax.axis("off")
ax.set_title("Monet paintings (validation batch)")
ax.imshow(monet_grid.permute(1, 2, 0).numpy());

In [None]:
# Show the photos of the preview batch as a single row (CHW grid -> HWC image).
photo_grid = np.transpose(
    vutils.make_grid(batch[1], nrow=BATCH_SIZE, padding=2, normalize=True).cpu(),
    (1, 2, 0),
)
fig, ax = plt.subplots(figsize=(10, 10))
ax.axis("off")
ax.imshow(photo_grid)

## Model

In [None]:
class Resblock(nn.Module):
    """Residual block computing ``x + F(x)``.

    ``F`` is conv3x3 -> InstanceNorm -> ReLU -> conv3x3 -> InstanceNorm, with
    reflection padding; channel count and spatial size are preserved, so the
    skip connection is a plain addition.

    Args:
        channels: number of input and output channels (default 256, the
            generator's bottleneck width).
    """

    def __init__(self, channels=256):
        super(Resblock, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 1, 1, padding_mode="reflect"),
            nn.InstanceNorm2d(channels),
            nn.ReLU(True),
            nn.Conv2d(channels, channels, 3, 1, 1, padding_mode="reflect"),
            nn.InstanceNorm2d(channels),
        )

    def forward(self, x):
        return x + self.main(x)

In [None]:
class Generator(nn.Module):
    def __init__(self, res_count = 9):
        super(Generator, self).__init__()
        modules = nn.ModuleList()
        # Encoder
        modules += [
            nn.Conv2d(3,64,7,1,3, padding_mode = "reflect"),
            nn.InstanceNorm2d(64),
            nn.ReLU(True),
            
            nn.Conv2d(64,128,3,2,1),
            nn.InstanceNorm2d(128),
            nn.ReLU(True),
            
            nn.Conv2d(128,256,3,2,1),
            nn.InstanceNorm2d(256),
            nn.ReLU(True)
           ]
        for i in range(res_count):
            modules += [Resblock()]
            
        
        # Decoder
        modules += [nn.ConvTranspose2d(256, 128, 3, 2,1,1),
                    nn.InstanceNorm2d(128),
                    nn.ReLU(True),
                    nn.ConvTranspose2d(128, 64, 3, 2,1,1),
                    nn.InstanceNorm2d(64),
                    nn.ReLU(True),
                    nn.Conv2d(64, 3, 7, 1,3,padding_mode = "reflect"),
                    nn.Tanh()]
        self.main = nn.Sequential(*modules)

    def forward(self, input):
        return self.main(input)

In [None]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        modules = nn.ModuleList([
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.InstanceNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, 4, 2 ,1 ),
            nn.InstanceNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, 4, 1, 1),
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(512, 1, 4, 1,1),
        ])
        self.main = nn.Sequential(*modules)

    def forward(self, input):
        return self.main(input)

In [None]:
# Instantiate both generators (left to right, same creation order) and show the
# layer summary of one for a batch of six 3x256x256 images.
netG_Photo2Monet, netG_Monet2Photo = Generator(), Generator()
summary(netG_Photo2Monet, input_size=(6, 3, 256, 256))

In [None]:
# Instantiate both discriminators (left to right, same creation order) and show
# the layer summary of one for a batch of six 3x256x256 images.
netD_Photo2Monet, netD_Monet2Photo = Discriminator(), Discriminator()
summary(netD_Photo2Monet, input_size=(6, 3, 256, 256))

In [None]:
def init_func(m):
    """Initialise one module's parameters in place; meant for ``nn.Module.apply``.

    * ``nn.Conv2d`` and ``nn.ConvTranspose2d``: weight ~ N(0, 0.02), bias = 0.
    * ``nn.BatchNorm2d`` / ``nn.InstanceNorm2d`` with affine parameters:
      weight ~ N(1, 0.02), bias = 0.
    * Any other module (or a non-affine norm layer) is left unchanged.

    Args:
        m: the module to initialise.
    """
    # ConvTranspose2d is not a subclass of Conv2d, so it must be listed
    # explicitly or the generator's upsampling layers keep PyTorch's default init.
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if getattr(m, 'bias', None) is not None:
            nn.init.constant_(m.bias.data, 0.0)
    elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)) and m.weight is not None:
        # Non-affine norm layers (the InstanceNorm2d default) have weight None.
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)

In [None]:
class CycleGAN_LightningSystem(pl.LightningModule):
    """CycleGAN (photo <-> Monet) trained with four optimizers.

    ``optimizer_idx`` values seen by ``training_step``:
        0 -> generator Photo->Monet, 1 -> generator Monet->Photo,
        2 -> discriminator on the Monet domain (``netD_Photo2Monet``),
        3 -> discriminator on the photo domain (``netD_Monet2Photo``).

    Generator objective:
    ``GAN + lambda_cycle * cycle + lambda_identity * lambda_cycle * identity``,
    with a least-squares (MSE) adversarial loss and L1 cycle/identity losses.
    Per-epoch mean losses are appended to ``loss_g``, ``loss_d``, ``GAN_loss``,
    ``cycle_loss`` and ``identity_loss`` for plotting.

    Args:
        lambda_cycle: weight of the cycle-consistency loss.
        lambda_identity: identity-loss weight, relative to ``lambda_cycle``.
    """

    def __init__(self, lambda_cycle=10.0, lambda_identity=0.5):
        super(CycleGAN_LightningSystem, self).__init__()
        self.netG_Photo2Monet = Generator()
        self.netG_Monet2Photo = Generator()
        self.netD_Photo2Monet = Discriminator()
        self.netD_Monet2Photo = Discriminator()

        for net in (self.netG_Photo2Monet, self.netG_Monet2Photo,
                    self.netD_Photo2Monet, self.netD_Monet2Photo):
            net.apply(init_func)

        self.lambda_cycle = lambda_cycle
        self.lambda_identity = lambda_identity

        self.criterion_GAN = nn.MSELoss()
        self.criterion_cycle = nn.L1Loss()
        self.criterion_identity = nn.L1Loss()

        # Per-epoch loss history, appended in training_epoch_end.
        self.loss_d = []
        self.loss_g = []
        self.cycle_loss = []
        self.identity_loss = []
        self.GAN_loss = []

    def configure_optimizers(self):
        """Adam per network: lr 2e-4 for generators, 1e-4 for discriminators.

        The order of the returned list defines ``optimizer_idx`` in
        ``training_step``. No LR schedulers are used.
        """
        self.optimizerD_Photo2Monet = torch.optim.Adam(self.netD_Photo2Monet.parameters(), lr = 0.0001, betas=(0.5, 0.999))
        self.optimizerD_Monet2Photo = torch.optim.Adam(self.netD_Monet2Photo.parameters(), lr = 0.0001, betas=(0.5, 0.999))
        self.optimizerG_Photo2Monet = torch.optim.Adam(self.netG_Photo2Monet.parameters(), lr = 0.0002, betas=(0.5, 0.999))
        self.optimizerG_Monet2Photo = torch.optim.Adam(self.netG_Monet2Photo.parameters(), lr = 0.0002, betas=(0.5, 0.999))
        return [self.optimizerG_Photo2Monet, self.optimizerG_Monet2Photo, self.optimizerD_Photo2Monet, self.optimizerD_Monet2Photo], []

    def _adv_loss(self, prediction, is_real):
        # Target grid shaped like the prediction, so any input resolution works
        # (a 256x256 image gives a 30x30 score map).
        target = torch.ones_like(prediction) if is_real else torch.zeros_like(prediction)
        return self.criterion_GAN(prediction, target)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Return the loss dict for the optimizer selected by ``optimizer_idx``."""
        data_monet, data_photo = batch

        # Generator steps: both generators are optimised on the same joint loss.
        if optimizer_idx in (0, 1):
            p2m = self.netG_Photo2Monet(data_photo)
            m2p = self.netG_Monet2Photo(data_monet)

            # Adversarial: each translation should be scored as real.
            GAN_loss = (self._adv_loss(self.netD_Photo2Monet(p2m), True)
                        + self._adv_loss(self.netD_Monet2Photo(m2p), True)) * 0.5

            # Cycle consistency: photo->monet->photo and monet->photo->monet.
            photo_cycled = self.netG_Monet2Photo(p2m)
            monet_cycled = self.netG_Photo2Monet(m2p)
            cycle_loss = (self.criterion_cycle(photo_cycled, data_photo)
                          + self.criterion_cycle(monet_cycled, data_monet)) * 0.5

            # Identity: an image already in a generator's target domain should
            # pass through unchanged.
            monet_identity = self.netG_Photo2Monet(data_monet)
            photo_identity = self.netG_Monet2Photo(data_photo)
            identity_loss = (self.criterion_identity(photo_identity, data_photo)
                             + self.criterion_identity(monet_identity, data_monet)) * 0.5

            total_loss_G = (GAN_loss
                            + self.lambda_cycle * cycle_loss
                            + self.lambda_identity * self.lambda_cycle * identity_loss)
            return {'loss': total_loss_G, 'GAN_loss': GAN_loss, 'cycle_loss': cycle_loss, 'identity_loss': identity_loss}

        # Discriminator steps only need the one fake batch they score; the
        # generator output is treated as a constant input.
        if optimizer_idx == 2:
            with torch.no_grad():
                p2m = self.netG_Photo2Monet(data_photo)
            loss_photo2Monet = (self._adv_loss(self.netD_Photo2Monet(data_monet), True)
                                + self._adv_loss(self.netD_Photo2Monet(p2m), False)) * 0.5
            return {'loss': loss_photo2Monet}

        if optimizer_idx == 3:
            with torch.no_grad():
                m2p = self.netG_Monet2Photo(data_monet)
            loss_Monet2Photo = (self._adv_loss(self.netD_Monet2Photo(data_photo), True)
                                + self._adv_loss(self.netD_Monet2Photo(m2p), False)) * 0.5
            return {'loss': loss_Monet2Photo}

        raise ValueError(f"unexpected optimizer_idx {optimizer_idx}")

    def training_epoch_end(self, outputs):
        """Append the epoch-mean losses (averaged over the relevant optimizers)."""
        # Assumes outputs[i] is the list of step dicts returned for optimizer i
        # (multi-optimizer Lightning 1.x layout) -- TODO confirm for the installed version.
        def epoch_mean(key, optimizer_ids):
            per_optimizer = [torch.stack([x[key] for x in outputs[i]]).mean().item()
                             for i in optimizer_ids]
            return sum(per_optimizer) / len(per_optimizer)

        self.loss_g.append(epoch_mean('loss', (0, 1)))
        self.loss_d.append(epoch_mean('loss', (2, 3)))
        self.GAN_loss.append(epoch_mean('GAN_loss', (0, 1)))
        self.cycle_loss.append(epoch_mean('cycle_loss', (0, 1)))
        self.identity_loss.append(epoch_mean('identity_loss', (0, 1)))

    def validation_step(self, batch, batch_idx):
        """Print the latest recorded losses and plot sample translations."""
        print(f"epoch {self.current_epoch + 1}")
        # The histories are filled in training_epoch_end, so they can still be
        # empty when validation runs before the first epoch has been recorded.
        if self.loss_d:
            print(f"discriminator loss: {self.loss_d[-1]}")
            print(f"generator loss: {self.loss_g[-1]}")

        ts_monet_data, ts_photo_data = batch
        monet = self.netG_Photo2Monet(ts_photo_data).detach()
        photo = self.netG_Monet2Photo(ts_monet_data).detach()

        # One image row per tensor: real photos, their Monet translations,
        # real Monet paintings, their photo translations.
        nrows = ts_monet_data.size(0)
        grids = [vutils.make_grid(t, nrow=nrows, padding=2, normalize=True)
                 for t in (ts_photo_data, monet, ts_monet_data, photo)]
        result = torch.cat(grids, 1).cpu().permute(1, 2, 0)

        plt.figure(figsize=(20, 20))
        plt.axis("off")
        plt.imshow(result)
        plt.show()
        return None

    def predict_step(self, batch, batch_idx):
        """Translate a batch of photos to Monet style and save them to ./images."""
        photo, names = batch
        generated = self.netG_Photo2Monet(photo).detach()
        # The generator ends in tanh, so map [-1, 1] -> [0, 1] (inverse of the
        # mean/std 0.5 normalisation); a per-image min/max stretch would alter
        # each image's colours and contrast.
        generated = (generated * 0.5 + 0.5).clamp(0.0, 1.0)
        for img_arr, name in zip(generated, names):
            vutils.save_image(img_arr, os.path.join('./images', name))
        return None

model = CycleGAN_LightningSystem()

In [None]:
MAX_EPOCHS = 140
VAL_EVERY_N_EPOCHS = 10  # validation_step plots a grid of sample translations each run

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    # gpus=-1 (all GPUs) raises on a CPU-only runtime; fall back to CPU there.
    gpus=-1 if torch.cuda.is_available() else None,
    check_val_every_n_epoch=VAL_EVERY_N_EPOCHS,
    enable_checkpointing=False,
    num_sanity_val_steps=0,  # skip the pre-training sanity validation pass
)

# Train
trainer.fit(model, datamodule=dm)

In [None]:
# predict_step writes each generated image into ./images, so create it first.
os.makedirs(os.path.join('.', 'images'), exist_ok=True)
y = trainer.predict(model, datamodule=dm)

In [None]:
# Zip the generated images into ./images.zip, then delete the loose folder.
images_dir = "./images"
shutil.make_archive(images_dir, "zip", images_dir)
shutil.rmtree(images_dir)

In [None]:
# Loss histories appended once per training epoch by the LightningModule.
# Epochs are numbered from 1 to match the "epoch N" printout during validation.
fig, axes = plt.subplots(ncols=1, nrows=2, figsize=(18, 12), facecolor='w')
epochs = np.arange(1, len(model.loss_g) + 1)

axes[0].plot(epochs, model.loss_g, label='generator')
axes[0].plot(epochs, model.loss_d, label='discriminator')
axes[0].set_title('Generator vs. discriminator loss')

for values, label in ((model.GAN_loss, 'Gan Loss'),
                      (model.cycle_loss, 'Cycle Loss'),
                      (model.identity_loss, 'Identity Loss')):
    axes[1].plot(epochs, values, label=label)
axes[1].set_title('Generator loss components')

for ax in axes:
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.legend()

plt.show()