In [None]:
! pip install -q kaggle
from google.colab import files
files.upload()
! mkdir ~/.kaggle
! cp kaggle.json ~/.kaggle/
! chmod 600 ~/.kaggle/kaggle.json
! mkdir kaggle
%cd kaggle
! mkdir data
%cd data
! kaggle competitions download -c gan-getting-started

In [3]:
# NOTE(review): repeats the download from the setup cell; the recorded output
# below shows this run was cancelled by the user.
! kaggle competitions download -c gan-getting-started

Downloading 003c6c30e0.jpg to /content/kaggle/data
  0% 0.00/13.5k [00:00<?, ?B/s]
100% 13.5k/13.5k [00:00<00:00, 9.19MB/s]
User cancelled operation
Exception KeyboardInterrupt in <module 'threading' from '/usr/lib/python2.7/threading.pyc'> ignored


In [9]:
# NOTE(review): creates a scratch directory `aaa` that nothing else in this notebook uses.
!mkdir aaa

In [None]:
import os
# Use .get so a missing variable triggers the readable assertion message
# below instead of a bare KeyError.
assert os.environ.get('COLAB_TPU_ADDR'), 'Make sure to select TPU from Edit > Notebook settings > Hardware accelerator'

In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split

import numpy as np

from torchvision import transforms
from PIL import Image

import seaborn as sns
import matplotlib.pyplot as plt
import os

import itertools
import random

In [None]:
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import torch_xla.distributed.parallel_loader as pl

import warnings
# NOTE(review): this silences *all* warnings for the rest of the session,
# including deprecation/numerical warnings from torch and torch_xla.
warnings.filterwarnings("ignore")

In [None]:
# Default device for the non-XLA code paths in this notebook.
has_cuda = torch.cuda.is_available()
device = 'cuda' if has_cuda else 'cpu'
print(device)

# The Dataset and DataLoader

In [None]:
class ImageDataset(Dataset):
    """Unpaired Monet/photo dataset.

    Item ``i`` pairs the i-th Monet image with a randomly chosen photo, so the
    dataset length is the number of Monet images. Both images are resized to
    ``size`` and normalised per channel with mean 0.5 / std 0.5 (to [-1, 1]).

    Args:
        photo_dir: directory containing the photo files.
        monet_dir: directory containing the Monet files.
        size: target size passed to ``transforms.Resize``.

    Raises:
        ValueError: if either directory is empty.
    """

    def __init__(self, photo_dir, monet_dir, size=(256, 256)):
        super().__init__()
        self.photo_dir = photo_dir
        self.monet_dir = monet_dir
        self.transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        # os.listdir order is arbitrary; sort so index i always maps to the
        # same file across runs and machines.
        self.monet_idx = sorted(os.listdir(monet_dir))
        self.photo_idx = sorted(os.listdir(photo_dir))
        if not self.monet_idx or not self.photo_idx:
            raise ValueError("both photo_dir and monet_dir must contain images")

    def _load(self, path):
        """Open ``path`` as 3-channel RGB and apply ``self.transform``."""
        # `with` closes the file handle; convert("RGB") guards against
        # grayscale/RGBA files that would break the 3-channel Normalize.
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def __getitem__(self, i):
        # torch's RNG is reseeded per DataLoader worker, whereas numpy's global
        # RNG may be duplicated across workers, so workers could otherwise
        # draw identical photo sequences.
        i_photo = int(torch.randint(len(self.photo_idx), (1,)).item())
        photo = self._load(os.path.join(self.photo_dir, self.photo_idx[i_photo]))
        monet = self._load(os.path.join(self.monet_dir, self.monet_idx[i]))
        return monet, photo

    def __len__(self):
        return len(self.monet_idx)

### Dataloader

In [None]:
dataset = ImageDataset("/kaggle/input/gan-getting-started/photo_jpg/", "/kaggle/input/gan-getting-started/monet_jpg/")
# shuffle=True: without it every epoch visits the Monet images in the same
# order and groups them into the same batches.
ImageDataLoader = DataLoader(dataset, batch_size=4, shuffle=True)

### A helper that undoes the normalization, so images can be displayed with their original colors

In [None]:
def unnorm(img, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Undo ``transforms.Normalize(mean, std)`` so an image can be displayed.

    Computes ``img * std + mean`` per channel. The channel axis is the
    third-from-last dimension, so both ``(C, H, W)`` images and
    ``(N, C, H, W)`` batches work.

    Args:
        img: normalised float tensor.
        mean: per-channel means used for normalisation.
        std: per-channel standard deviations used for normalisation.

    Returns:
        A new tensor; ``img`` itself is not modified.
    """
    shape = (-1, 1, 1)  # broadcast over (..., C, H, W)
    mean_t = torch.as_tensor(mean, dtype=img.dtype, device=img.device).view(shape)
    std_t = torch.as_tensor(std, dtype=img.dtype, device=img.device).view(shape)
    # Add the mean (not the std), and return a fresh tensor so the caller's
    # batch is not silently overwritten.
    return img * std_t + mean_t

In [None]:
# Show one photo from the first batch before and after undoing the normalization.
monet_img, photo_img = next(iter(ImageDataLoader))
sample = photo_img[0]

fig, (ax_norm, ax_raw) = plt.subplots(1, 2, figsize=(8, 8))

ax_norm.set_title('With normalization')
ax_norm.imshow(sample.permute(1, 2, 0))

ax_raw.set_title('Without normalization')
ax_raw.imshow(unnorm(sample).permute(1, 2, 0))

plt.show()

# The Model

### A residual block for our model, very similar to the **ResNet** block (used only by the commented-out ResNet generator below)

In [None]:
class ResBlock(nn.Module):  # ResNet-style residual block
    """Residual block computing ``x + F(x)``.

    F is conv3x3 -> InstanceNorm -> ReLU -> Dropout(0.5) -> conv3x3 -> InstanceNorm.
    Both convolutions keep the channel count and, with stride 1 and padding 1,
    the spatial size, so ``x`` can be added directly.

    Args:
        in_features: number of input (and output) channels.
    """

    def __init__(self, in_features=256):
        super().__init__()
        channels = in_features
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels),
            nn.ReLU(True),
            nn.Dropout(0.5),
            nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1),
            nn.InstanceNorm2d(channels),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.model(x)
        return x + residual
        

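## The generator (U-Net):
### 1. Down sampling: 8 strided Convolution -> InstanceNorm -> LeakyReLU blocks (the innermost one has no norm)
### 2. Up sampling: 7 Convolution_transpose -> InstanceNorm -> ReLU blocks (the first three add Dropout; the first has no norm), then a final Convolution_transpose -> Tanh
### 3. Skip connections: every up-sampling step after the first takes the previous output concatenated with the down-sampling output of the same resolution
### (An alternative ResNet-style generator with 6 ResNet blocks is kept commented out below.)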


In [None]:
def downsample(in_ch, out_ch, kernel=4, bn=True, negative_slope=0.01):
    """Strided conv block: Conv2d(stride 2, padding 1) -> [InstanceNorm2d] -> LeakyReLU.

    With the default ``kernel=4`` the spatial size is halved.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        kernel: convolution kernel size.
        bn: add an ``InstanceNorm2d`` after the conv (the name is kept for
            compatibility; it is instance norm, not batch norm).
        negative_slope: LeakyReLU slope. The default 0.01 preserves the
            previous behaviour; 0.2 matches the discriminator's own LeakyReLU.

    Returns:
        ``nn.Sequential`` containing the layers above.
    """
    layers = [nn.Conv2d(in_ch, out_ch, kernel_size=kernel, stride=2, padding=1)]
    if bn:
        layers.append(nn.InstanceNorm2d(out_ch))
    layers.append(nn.LeakyReLU(negative_slope))
    return nn.Sequential(*layers)

def upsample(in_ch, out_ch, kernel=4, bn=True, dp=False, dropout_p=0.5):
    """Transposed-conv block: ConvTranspose2d(stride 2, padding 1) -> [InstanceNorm2d] -> [Dropout] -> ReLU.

    With the default ``kernel=4`` the spatial size is doubled.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        kernel: transposed-convolution kernel size.
        bn: add an ``InstanceNorm2d`` (instance norm, despite the name).
        dp: add a Dropout layer.
        dropout_p: dropout probability used when ``dp`` is true (default 0.5,
            the previous hard-coded value).

    Returns:
        ``nn.Sequential`` containing the layers above.
    """
    layers = [nn.ConvTranspose2d(in_ch, out_ch, kernel_size=kernel, stride=2, padding=1)]
    if bn:
        layers.append(nn.InstanceNorm2d(out_ch))
    if dp:
        layers.append(nn.Dropout(dropout_p))
    layers.append(nn.ReLU())
    return nn.Sequential(*layers)

In [None]:
class Generator(nn.Module):
    """U-Net generator: 8 stride-2 encoder blocks and 8 decoder blocks with skips.

    Decoder blocks after the innermost one receive the previous decoder output
    concatenated (along channels) with the encoder output of the same
    resolution. The last layer is a plain ConvTranspose2d followed by Tanh, so
    outputs lie in [-1, 1], the same range as the dataset normalisation.

    NOTE(review): eight stride-2 convolutions shrink the spatial size by
    2**8 = 256, so inputs presumably must be at least 256x256 — confirm
    before using other sizes.
    """

    def __init__(self):
        super().__init__()

        # Encoder: down1..down7 are conv -> InstanceNorm -> LeakyReLU.
        encoder_channels = [(3, 64), (64, 128), (128, 256), (256, 512),
                            (512, 512), (512, 512), (512, 512)]
        for level, (c_in, c_out) in enumerate(encoder_channels, start=1):
            setattr(self, f"down{level}", downsample(c_in, c_out))
        # Innermost block has no norm layer (its output is 1x1 for a 256x256 input).
        self.down8 = downsample(512, 512, bn=False)

        # Decoder: after up8, input channels are doubled by the skip concatenation.
        self.up8 = upsample(512, 512, bn=False, dp=True)
        decoder_specs = [  # (level, in_channels, out_channels, use_dropout)
            (7, 1024, 512, True),
            (6, 1024, 512, True),
            (5, 1024, 512, False),
            (4, 1024, 256, False),
            (3, 512, 128, False),
            (2, 256, 64, False),
        ]
        for level, c_in, c_out, use_dropout in decoder_specs:
            setattr(self, f"up{level}", upsample(c_in, c_out, dp=use_dropout))
        self.up1 = nn.ConvTranspose2d(128, 3, kernel_size=4, stride=2, padding=1)
        self.tanh = nn.Tanh()

    def forward(self, x):
        # skips[k] holds the output of down{k+1}.
        skips = []
        h = x
        for level in range(1, 9):
            h = getattr(self, f"down{level}")(h)
            skips.append(h)

        out = self.up8(skips[-1])
        for level in range(7, 0, -1):
            merged = torch.cat([out, skips[level - 1]], dim=1)
            out = getattr(self, f"up{level}")(merged)
        return self.tanh(out)

In [None]:
monet_img, photo_img = next(iter(ImageDataLoader))

f = plt.figure(figsize=(8, 8))

f.add_subplot(1, 2, 1)
plt.title('Generated one')
# Untrained generator: a pipeline/shape sanity check only. no_grad avoids
# building an autograd graph that is never used.
with torch.no_grad():
    out = Generator()(photo_img)[0]
# The generator ends in Tanh ([-1, 1]); map back to [0, 1] like the photo
# panel so imshow does not clip the negative half.
plt.imshow(unnorm(out).permute(1, 2, 0))

f.add_subplot(1, 2, 2)
plt.title('Photo')
plt.imshow(unnorm(photo_img[0]).permute(1, 2, 0))

plt.show()


In [None]:
# class Generator(nn.Module):
#     def __init__(self):
#         super().__init__() 
#         ngf = 64
#         model = [nn.Conv2d(3, ngf, kernel_size=3, padding=1),
#                  nn.InstanceNorm2d(ngf),
#                  nn.ReLU(True)
#                 ] 
        
#         n_downsampling = 2
#         for i in range(n_downsampling):  # add downsampling layers
#             mult = 2 ** i
#             model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1),
#                       nn.InstanceNorm2d(ngf * mult * 2),
#                       nn.ReLU(True)]
            
            
#         mult = 2 ** n_downsampling
#         n_blocks = 6
#         for i in range(n_blocks):       # add ResNet blocks
#             model += [ResBlock(256)]
            
            
#         for i in range(n_downsampling):  # add upsampling layers
#             mult = 2 ** (n_downsampling - i)
#             model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
#                                          kernel_size=3, stride=2,
#                                          padding=1, output_padding=1),
#                       nn.InstanceNorm2d(int(ngf * mult / 2)),
#                       nn.ReLU(True)]
            
#         model += [nn.Conv2d(ngf, 3, kernel_size=3, padding=1)]
#         model += [nn.Tanh()]
        
#         self.model = nn.Sequential(*model)
                
#     def forward(self, x):
#         return self.model(x)
    
# test = Generator()

## The Discriminator:
### * A fully convolutional classifier.
### * 5 convolutional layers.
### * No fully connected dense layer at the end, as it is not recommended for GANs.
### * No sigmoid activation at the end of the classifier (the MSE-based loss works on raw scores).
### * The output is not a single value; instead it is a heatmap of per-patch scores.

In [None]:
class Discriminator(nn.Module):
    """Fully convolutional discriminator that outputs a 1-channel score map.

    Three stride-2 ``downsample`` blocks (256 -> 128 -> 64 -> 32 for a 256x256
    input) are followed by two zero-padded 4x4 stride-1 convolutions. The last
    layer has no activation, so it returns raw scores.
    """

    def __init__(self):
        super().__init__()
        encoder = [
            downsample(3, 64),     # -> 128x128
            downsample(64, 128),   # -> 64x64
            downsample(128, 256),  # -> 32x32
        ]
        head = [
            nn.ZeroPad2d(1),
            nn.Conv2d(256, 512, kernel_size=4, stride=1),  # -> 31x31
            nn.InstanceNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.ZeroPad2d(1),
            nn.Conv2d(512, 1, kernel_size=4, stride=1),    # -> 30x30
        ]
        self.model = nn.Sequential(*(encoder + head))

    def forward(self, x):
        return self.model(x)

In [None]:
# class Discriminator(nn.Module):
#     def __init__(self):
#         super().__init__()
        
#         model = []
#         ndf = 64
        
#         model += [nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False),
#                   nn.InstanceNorm2d(64),
#                   nn.LeakyReLU(0.2, True)]
        
        
#         for i in range(2):
#             n = 2 ** i
#             model += [
#                 nn.Conv2d(ndf * n, ndf * n * 2, kernel_size=3, stride=2, padding=1, bias=False),
#                 nn.InstanceNorm2d(ndf * n * 2),
#                 nn.LeakyReLU(0.2, True)
#             ]
        
#         model += [
#                 nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False),
#                 nn.InstanceNorm2d(512),
#                 nn.LeakyReLU(0.2, True)
#             ]
        
#         model += [nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1)]  # output 1 channel prediction map
#         self.model = nn.Sequential(*model)
    
        
#     def forward(self, x):
#         return  self.model(x)

### The parameter initialization function

In [None]:
def init_func(m):
    """DCGAN/pix2pix-style weight initialisation; use as ``module.apply(init_func)``.

    * Conv2d / ConvTranspose2d: weight ~ N(0, 0.02), bias = 0.
    * BatchNorm2d: weight ~ N(1, 0.02), bias = 0, so the norm's scale starts
      near 1 (close to identity) rather than near 0, which would almost zero
      out its output.

    Other module types are left untouched. The models in this notebook use
    InstanceNorm2d (no affine parameters by default), so the BatchNorm branch
    only matters if a BatchNorm2d layer is added.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.0)
    elif isinstance(m, nn.BatchNorm2d):
        if m.weight is not None:
            torch.nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0.0)

In [None]:
class lr_sched():
    """Inverse-time learning-rate decay: ``lr = initial_lr / (1 + decay * epoch)``.

    Args:
        lr: initial learning rate.
        decay: decay rate per epoch.

    ``self.lr`` holds the most recently computed rate.
    """

    def __init__(self, lr, decay=0.001):
        self.decay = decay
        self.initial_lr = lr
        self.lr = lr

    def step(self, epoch_num):
        """Compute, store and return the learning rate for ``epoch_num``.

        The rate is derived from the initial rate. Dividing the running value
        instead would compound the decay (lr0 / prod_k(1 + decay*k)), which
        presumably was not intended.
        """
        self.lr = self.initial_lr / (1 + self.decay * epoch_num)
        return self.lr
    
    
def update_req_grad(models, requires_grad=True):
    """Set ``requires_grad`` on every parameter of every module in ``models``.

    CycleGAN uses this to freeze the discriminators during the generator
    update and to unfreeze them for their own update.
    """
    all_params = itertools.chain.from_iterable(model.parameters() for model in models)
    for param in all_params:
        param.requires_grad_(requires_grad)

# The final Model

In [None]:
# class CycleGAN(nn.Module):
#     def __init__(self, lambda_=10, idt_coef=0.1):
#         super().__init__()
#         self.M2P_gen = Generator()
#         self.P2M_gen = Generator()
#         self.M_disc = Discriminator()
#         self.P_disc = Discriminator()
        
#         self.L1_loss = nn.L1Loss()
#         self.MSE_loss = nn.MSELoss()
#         self.lambda_ = lambda_
#         self.idt_coef = idt_coef
        
#         self.gen_losses = []
#         self.desc_losses = []
        
#         self.adam_gen = torch.optim.Adam(itertools.chain(self.M2P_gen.parameters(), self.P2M_gen.parameters()), lr=0.0002, betas=(0.5,0.99))
#         self.adam_dis = torch.optim.Adam(itertools.chain(self.M_disc.parameters(), self.P_disc.parameters()), lr=0.0002, betas=(0.5,0.99))
        
         
#     def initialise_sub_models(self):
#         self.M2P_gen = self.M2P_gen.apply(init_func).to(device)
#         self.P2M_gen = self.P2M_gen.apply(init_func).to(device)
#         self.M_disc = self.M_disc.apply(init_func).to(device)
#         self.P_disc = self.P_disc.apply(init_func).to(device)
        
        
#     def train_(self, M_real, P_real):
        
#         # Generators only: 
#         self.adam_gen.zero_grad()
#         update_req_grad([self.P_disc, self.M_disc], False)
        
#         P_fake = self.M2P_gen(M_real)
#         M_fake = self.P2M_gen(P_real)
        
#         P_idt  = self.M2P_gen(P_real)
#         M_idt  = self.P2M_gen(M_real)
        
#         P_cycle = self.M2P_gen(M_fake)
#         M_cycle = self.P2M_gen(P_fake)
        
#         #The generator loss: cylce-consist., identity, Adversarial
#         #identity:
#         P_idt_loss = self.L1_loss(P_idt, P_real) * self.idt_coef
#         M_idt_loss = self.L1_loss(M_idt, M_real) * self.idt_coef
        
#         #cylce-consist:
#         P_cycle_loss = self.L1_loss(P_cycle, P_real) * self.lambda_
#         M_cycle_loss = self.L1_loss(M_cycle, M_real) * self.lambda_
        
#         #Adversarial:
#         Disc_P_fake = self.P_disc(P_fake)
#         Disc_M_fake = self.M_disc(M_fake)
        
#         ones = torch.ones(Disc_M_fake.size()).to(device)
#         adv_loss_P = self.MSE_loss(Disc_P_fake, ones)
#         adv_loss_M = self.MSE_loss(Disc_M_fake, ones)
        
#         total_adv_loss = P_idt_loss + M_idt_loss +\
#                          P_cycle_loss + M_cycle_loss+\
#                          adv_loss_P + adv_loss_M
        
#         self.gen_losses.append(total_adv_loss.item())
        
#         total_adv_loss.backward()
#         self.adam_gen.step()
        
        
        
        
#         # The discriminator :3 3yit ya rebi wellah :'( :
#         self.adam_dis.zero_grad()
#         update_req_grad([self.P_disc, self.M_disc], True)
        
#         P_fake = self.M2P_gen(M_real)
#         M_fake = self.P2M_gen(P_real)
        
#         Disc_P_fake = self.P_disc(P_fake)
#         Disc_P_real = self.P_disc(P_real)
#         Disc_M_fake = self.M_disc(M_fake)
#         Disc_M_real = self.M_disc(M_real)
        
        
#         ones = torch.ones(Disc_P_fake.size()).to(device)
#         zeros = torch.zeros(Disc_P_fake.size()).to(device)
        
        
#         loss1 = self.MSE_loss(Disc_P_fake, zeros)
#         loss2 = self.MSE_loss(Disc_P_real, ones)
#         loss3 = self.MSE_loss(Disc_M_fake, zeros)
#         loss4 = self.MSE_loss(Disc_M_real, ones)
        
#         total_desc_loss = (loss1 + loss2) / 2 + (loss3 + loss4) / 2
        
#         total_desc_loss.backward()
#         self.adam_dis.step()
        
#         self.desc_losses.append(total_desc_loss.item())    

In [None]:
class CycleGAN(nn.Module):
    """Monet <-> photo CycleGAN with its own optimizers and training step.

    Sub-models:
        M2P_gen: generator mapping Monet -> photo.
        P2M_gen: generator mapping photo -> Monet.
        M_disc:  discriminator for the Monet domain.
        P_disc:  discriminator for the photo domain.

    Generator loss per direction = MSE(D(fake), 1) + ``lambda_`` * L1 cycle
    loss + ``idt_coef * lambda_`` * L1 identity loss. Discriminator loss per
    domain = MSE(D(fake), 0) + MSE(D(real), 1). Each sub-model has its own Adam
    optimizer, stepped through ``xm.optimizer_step``.

    Args:
        lambda_: cycle-consistency weight.
        idt_coef: identity weight, as a fraction of ``lambda_``.
        lr: Adam learning rate for all four optimizers.
        betas: Adam betas for all four optimizers.
    """

    def __init__(self, lambda_=15, idt_coef=0.5, lr=0.002, betas=(0.5, 0.99)):
        super().__init__()
        self.M2P_gen = Generator()
        self.P2M_gen = Generator()
        self.M_disc = Discriminator()
        self.P_disc = Discriminator()

        self.L1_loss = nn.L1Loss()
        self.MSE_loss = nn.MSELoss()
        self.lambda_ = lambda_
        self.idt_coef = idt_coef
        self.lr = lr
        self.betas = betas

        # Per-step loss history as Python floats.
        self.gen_P_losses = []
        self.gen_M_losses = []
        self.disc_losses_P = []
        self.disc_losses_M = []

        self._build_optimizers()

    def _optimizer_targets(self):
        """(optimizer attribute name, module it updates) pairs."""
        return (
            ("adam_photo_gen", self.M2P_gen),
            ("adam_monet_gen", self.P2M_gen),
            ("adam_photo_dis", self.P_disc),
            ("adam_monet_dis", self.M_disc),
        )

    def _build_optimizers(self):
        """(Re)create one Adam optimizer per sub-model over its current parameters."""
        for attr, module in self._optimizer_targets():
            setattr(self, attr, torch.optim.Adam(module.parameters(), lr=self.lr, betas=self.betas))

    def _sync_optimizers(self):
        """Rebuild the optimizers if a sub-model's Parameter objects were replaced.

        ``Module.to()`` may create new Parameter objects instead of updating
        the existing ones in place (e.g. when moving CPU tensors to an XLA
        device). The optimizers would then step stale copies that the model
        no longer uses.
        """
        for attr, module in self._optimizer_targets():
            tracked = getattr(self, attr).param_groups[0]["params"]
            current = list(module.parameters())
            if len(tracked) != len(current) or any(a is not b for a, b in zip(tracked, current)):
                self._build_optimizers()
                return

    def initialise_sub_models(self, target_device=None):
        """Re-initialise all sub-model weights with ``init_func`` and move them.

        Args:
            target_device: destination device; defaults to the notebook-level
                ``device``.
        """
        dest = device if target_device is None else target_device
        self.M2P_gen = self.M2P_gen.apply(init_func).to(dest)
        self.P2M_gen = self.P2M_gen.apply(init_func).to(dest)
        self.M_disc = self.M_disc.apply(init_func).to(dest)
        self.P_disc = self.P_disc.apply(init_func).to(dest)
        self._sync_optimizers()

    def train_(self, M_real, P_real):
        """One training iteration: update both generators, then both discriminators.

        Args:
            M_real: batch of real Monet images.
            P_real: batch of real photos.
            (Presumably both are normalised to [-1, 1] by ImageDataset and
            already on the model's device — confirm at the call site.)
        """
        self._sync_optimizers()
        self._generator_step(M_real, P_real)
        self._discriminator_step(M_real, P_real)

    def _generator_step(self, M_real, P_real):
        """Update both generators with identity, cycle and adversarial losses."""
        self.adam_photo_gen.zero_grad()
        self.adam_monet_gen.zero_grad()
        # Freeze the discriminators so this backward pass does not fill their grads.
        update_req_grad([self.P_disc, self.M_disc], False)

        P_fake = self.M2P_gen(M_real)
        M_fake = self.P2M_gen(P_real)

        P_idt = self.M2P_gen(P_real)
        M_idt = self.P2M_gen(M_real)

        P_cycle = self.M2P_gen(M_fake)
        M_cycle = self.P2M_gen(P_fake)

        # Identity: a generator fed an image already in its output domain
        # should leave it unchanged.
        P_idt_loss = self.L1_loss(P_idt, P_real) * self.idt_coef * self.lambda_
        M_idt_loss = self.L1_loss(M_idt, M_real) * self.idt_coef * self.lambda_

        # Cycle consistency: photo -> Monet -> photo (and vice versa) should
        # reconstruct the input.
        P_cycle_loss = self.L1_loss(P_cycle, P_real) * self.lambda_
        M_cycle_loss = self.L1_loss(M_cycle, M_real) * self.lambda_

        # Adversarial: push discriminator scores on fakes towards 1. *_like
        # keeps targets on the discriminator output's device/dtype (the
        # module-level `device` is CPU/CUDA, not the XLA device).
        Disc_P_fake = self.P_disc(P_fake)
        Disc_M_fake = self.M_disc(M_fake)
        adv_loss_P = self.MSE_loss(Disc_P_fake, torch.ones_like(Disc_P_fake))
        adv_loss_M = self.MSE_loss(Disc_M_fake, torch.ones_like(Disc_M_fake))

        total_adv_P_loss = P_idt_loss + P_cycle_loss + adv_loss_P
        total_adv_M_loss = M_idt_loss + M_cycle_loss + adv_loss_M

        # NOTE(review): .item() forces a host sync every step on XLA devices.
        self.gen_P_losses.append(total_adv_P_loss.item())
        self.gen_M_losses.append(total_adv_M_loss.item())

        # Both totals share one graph. Backpropagating their sum accumulates
        # the same gradients as two backward() calls, without retain_graph.
        (total_adv_P_loss + total_adv_M_loss).backward()
        xm.optimizer_step(self.adam_photo_gen)
        xm.optimizer_step(self.adam_monet_gen)

    def _discriminator_step(self, M_real, P_real):
        """Update both discriminators: real images -> 1, generated images -> 0."""
        self.adam_photo_dis.zero_grad()
        self.adam_monet_dis.zero_grad()
        update_req_grad([self.P_disc, self.M_disc], True)

        # Fresh fakes from the just-updated generators. no_grad keeps the
        # discriminator loss from building a graph through the generators.
        with torch.no_grad():
            P_fake = self.M2P_gen(M_real)
            M_fake = self.P2M_gen(P_real)

        Disc_P_fake = self.P_disc(P_fake)
        Disc_P_real = self.P_disc(P_real)
        Disc_M_fake = self.M_disc(M_fake)
        Disc_M_real = self.M_disc(M_real)

        total_dis_p_loss = (self.MSE_loss(Disc_P_fake, torch.zeros_like(Disc_P_fake))
                            + self.MSE_loss(Disc_P_real, torch.ones_like(Disc_P_real)))
        total_dis_m_loss = (self.MSE_loss(Disc_M_fake, torch.zeros_like(Disc_M_fake))
                            + self.MSE_loss(Disc_M_real, torch.ones_like(Disc_M_real)))

        # The two losses touch disjoint parameters; one backward of the sum suffices.
        (total_dis_p_loss + total_dis_m_loss).backward()
        xm.optimizer_step(self.adam_photo_dis)
        xm.optimizer_step(self.adam_monet_dis)

        self.disc_losses_P.append(total_dis_p_loss.item())
        self.disc_losses_M.append(total_dis_m_loss.item())

In [None]:
def save_checkpoint(state, save_path):
    """Serialise ``state`` (e.g. a dict of state_dicts) to ``save_path`` with ``torch.save``."""
    torch.save(obj=state, f=save_path)
    
# checkpoint = torch.load("current.ckpt")
# CycGAN_test.M2P_gen.load_state_dict(checkpoint["M2P_gen"])
# CycGAN_test.P2M_gen.load_state_dict(checkpoint["P2M_gen"])
# CycGAN_test.P_disc .load_state_dict(checkpoint["desc_P"])
# CycGAN_test.M_disc.load_state_dict(checkpoint["desc_M"])

In [None]:
# Build the model on the notebook-level `device`, then re-initialise all four
# sub-models with `init_func` (initialise_sub_models also moves them to `device`).
CycGAN = CycleGAN().to(device)
CycGAN.initialise_sub_models()

In [None]:
# Wrap the model for torch_xla multiprocessing; map_func_ moves it to each
# process's XLA device via WRAPPED_MODEL.to(device).
WRAPPED_MODEL = xmp.MpModelWrapper(CycGAN)

## The training loop

In [None]:
def training_loop_please_emchi(index, epochs=20):
    """Per-process TPU training loop (signature matches an ``xmp.spawn`` target).

    Args:
        index: process index passed by ``xmp.spawn``; not used.
        epochs: number of passes over ``ImageDataLoader``.

    Returns:
        The trained ``CycleGAN``.
    """
    device = xm.xla_device()
    mp_device_loader = pl.MpDeviceLoader(ImageDataLoader, device)

    CycGAN = CycleGAN()
    # Same weight initialisation as CycleGAN.initialise_sub_models, which is
    # not used here because it moves the model to the module-level `device`.
    CycGAN.apply(init_func)
    CycGAN = CycGAN.train().to(device)

    for epoch in range(1, epochs + 1):
        # MpDeviceLoader already yields batches on `device`.
        for M_real, P_real in mp_device_loader:
            CycGAN.train_(M_real, P_real)
        # One summary line per epoch, printed by the master process only.
        if CycGAN.gen_M_losses:
            xm.master_print("[EPOCH:{}]    Monet_Loss: {}    Photo_Loss: {}    Monet_disc_loss: {}    Photo_disc_loss: {} ".format(
                epoch, CycGAN.gen_M_losses[-1], CycGAN.gen_P_losses[-1],
                CycGAN.disc_losses_M[-1], CycGAN.disc_losses_P[-1]))
    return CycGAN
#             if i % 100 == 0:
#                 print("[EPOCH:{}]    STEP {}/{}   Monet_Loss: {}    Photo_Loss: {}    Monet_disc_loss: {}    Photo_disc_loss: {} ".format(
#                             epoch, i, len(ImageDataLoader), CycGAN.gen_M_losses[-1], CycGAN.gen_P_losses[-1], CycGAN.disc_losses_M[-1], CycGAN.disc_losses_P[-1]
#                 ))

#                 photo_image = dataset[random.randint(0,299)][1].unsqueeze(0)

#                 f = plt.figure(figsize=(8, 8))

#                 f.add_subplot(1, 2, 1)
#                 plt.title('Photo')
#                 photo_image = unnorm(photo_image)
#                 plt.imshow(photo_image[0].permute(1, 2, 0))

#                 f.add_subplot(1, 2, 2)
#                 plt.title('Monet Generation')
#                 monet_img = unnorm(CycGAN.P2M_gen(photo_image.to(device)).cpu().detach())
#                 plt.imshow(monet_img[0].permute(1, 2, 0))

#                 plt.show()

In [None]:
import gc

In [None]:
# Sanity check: the replica count that map_func_ passes to DistributedSampler.
xm.xrt_world_size()

In [None]:
def map_func_(rank, flags):
    """Per-process setup for multi-core TPU training (data and model only).

    Builds a ``DistributedSampler``-sharded loader over the competition data,
    moves ``WRAPPED_MODEL`` to this process's XLA device, and waits at a
    rendezvous so all processes finish setup together.

    Args:
        rank: process index from ``xmp.spawn``; the sampler uses
            ``xm.get_ordinal()`` instead.
        flags: unused.

    Returns:
        ``(model, device_loader)`` so a caller can run the training steps.
    """
    device = xm.xla_device()

    dataset = ImageDataset("/kaggle/input/gan-getting-started/photo_jpg/", "/kaggle/input/gan-getting-started/monet_jpg/")

    train_sampler = torch.utils.data.DistributedSampler(dataset,
                                                        num_replicas=xm.xrt_world_size(),
                                                        rank=xm.get_ordinal(),
                                                        shuffle=True)

    loader = torch.utils.data.DataLoader(dataset,
                                         batch_size=4,
                                         sampler=train_sampler,
                                         num_workers=8,
                                         drop_last=True)

    # PerDeviceLoader expects a ParallelLoader, not a DataLoader. MpDeviceLoader
    # wraps the DataLoader directly and moves each batch to `device`.
    device_loader = pl.MpDeviceLoader(loader, device)

    model = WRAPPED_MODEL.to(device)
    gc.collect()

    xm.rendezvous('please emchi yerham babak')
    return model, device_loader

In [None]:
# NOTE(review): runs the per-process setup directly in this single process,
# outside xmp.spawn — presumably a debugging call.
map_func_(0, None)

In [None]:
# NOTE(review): looks like a leftover debug marker.
print("s")

In [None]:
def map_(index):
    """Debug target for ``xmp.spawn``: print the index, synchronise, print again."""
    def announce():
        xm.master_print(index)

    announce()
    xm.rendezvous("xxx")  # wait until every process reaches this point
    announce()

In [None]:
# Launch the debug target `map_` in 8 forked processes (presumably one per TPU core).
_ = xmp.spawn(map_, args=(), nprocs=8, start_method='fork')

In [None]:
# Print the traceback of the most recent exception (debugging aid).
%tb

In [None]:
# randrange excludes its upper bound; randint(0, 300) could return 300 and
# index past the end of a 300-item dataset.
photo_image = dataset[random.randrange(len(dataset))][1].unsqueeze(0)

f = plt.figure(figsize=(8, 8))

f.add_subplot(1, 2, 1)
plt.title('Photo')
# Un-normalise a copy for display; the generator needs the normalised input.
plt.imshow(unnorm(photo_image.clone())[0].permute(1, 2, 0))

f.add_subplot(1, 2, 2)
plt.title('Monet Generation')
# Photo -> Monet is P2M_gen (M2P_gen maps Monet -> photo).
with torch.no_grad():
    monet_img = unnorm(CycGAN.P2M_gen(photo_image.to(device)).cpu())
plt.imshow(monet_img[0].permute(1, 2, 0))

plt.show()

In [None]:
# `epoch` only exists if a training loop ran at top level in this kernel
# (training_loop_please_emchi keeps its loop variable local), so look it up
# instead of raising NameError.
last_epoch = globals().get('epoch')
# NOTE(review): if the sub-models live on an XLA device, xm.save may be needed
# to move tensors to CPU before serialising.
save_dict = {
                'epoch': None if last_epoch is None else last_epoch + 1,
                'M2P_gen': CycGAN.M2P_gen.state_dict(),
                'P2M_gen': CycGAN.P2M_gen.state_dict(),
                'desc_P': CycGAN.P_disc.state_dict(),
                'desc_M': CycGAN.M_disc.state_dict(),
            }
save_checkpoint(save_dict, 'test1.ckpt')

In [None]:
import PIL
! mkdir ../images

In [None]:
os.makedirs("/kaggle/images", exist_ok=True)

n_photos = len(dataset.photo_idx)
report_every = max(1, n_photos // 100)  # roughly once per percent

# Inference only: no_grad avoids building and holding an autograd graph for
# every photo. The model's train/eval mode is left unchanged.
with torch.no_grad():
    for i, name in enumerate(dataset.photo_idx):
        # `with` closes the file; convert("RGB") matches the 3-channel transform.
        with Image.open("/kaggle/input/gan-getting-started/photo_jpg/" + name) as img:
            img_tensor = dataset.transform(img.convert("RGB")).unsqueeze(0)
        out = CycGAN.P2M_gen(img_tensor.to(device))
        out = out.squeeze().cpu()
        # unnorm maps the Tanh range [-1, 1] to [0, 1]; clamp guards the uint8 cast.
        out = (unnorm(out) * 255).clamp(0, 255)
        out = out.permute(1, 2, 0).numpy().round().astype(np.uint8)

        Image.fromarray(out).save("/kaggle/images/" + "fake_monet_" + str(i) + ".jpg")

        if i % report_every == 0:
            print("[SAVING]    ", 100 * i // n_photos, "% ")

In [None]:
import shutil
# Zip /kaggle/images into /kaggle/working/images.zip (make_archive appends ".zip").
shutil.make_archive("/kaggle/working/images", 'zip', "/kaggle/images")