In [3]:
# from google.colab import drive
# drive.mount('/content/drive')

# !pip install -q pytorch-lightning

# !unzip /content/drive/MyDrive/grass/datasets/train.zip -d /content/train

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
 extracting: /content/train/Output/3875_output.jpg  
 extracting: /content/train/Output/3876_output.jpg  
 extracting: /content/train/Output/3877_output.jpg  
 extracting: /content/train/Output/3878_output.jpg  
 extracting: /content/train/Output/3879_output.jpg  
 extracting: /content/train/Output/387_output.jpg  
 extracting: /content/train/Output/3880_output.jpg  
 extracting: /content/train/Output/3881_output.jpg  
 extracting: /content/train/Output/3882_output.jpg  
 extracting: /content/train/Output/3883_output.jpg  
 extracting: /content/train/Output/3884_output.jpg  
 extracting: /content/train/Output/3885_output.jpg  
 extracting: /content/train/Output/3886_output.jpg  
 extracting: /content/train/Output/3887_output.jpg  
 extracting: /content/train/Output/3888_output.jpg  
 extracting: /content/train/Output/3889_output.jpg  
 extracting: /content/train/Output/388_output.jpg  
 extracting: /content/train/Output/3

In [2]:

import os
from glob import glob
from pathlib import Path

import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.transforms.functional import center_crop
from torchvision.utils import make_grid, save_image
from tqdm.auto import tqdm

import torch.nn.functional as F
from torch.nn import Module, Conv2d
from torch.nn.utils import spectral_norm
from torch.nn.functional import interpolate, relu

import pandas as pd


ModuleNotFoundError: ignored

In [None]:
# Central experiment configuration: every other cell reads its settings from CONFIG.
# Written with keyword arguments so each setting reads as NAME=value.
config_defaults = dict(
    BATCH_SIZE=8,                # images per training batch
    IN_CHANNEL=7,                # condition channels: input image (3) + depth (1) + normals (3)
    OUT_CHANNEL=3,               # channels of the generated image
    LOAD_CHECKPOINT=False,       # True -> resume from PATH_CHECKPOINT instead of a fresh model
    PATH_CONTEXT='/content/train/',
    PATH_DATA='/content/drive/MyDrive/grass/datasets/train.zip',  # not read by the active code
    PATH_CHECKPOINT='/content/drive/MyDrive/grass/experiments/example-sea-creature-noise/lightning_logs/version_13/checkpoints/epoch=8-step=4500.ckpt',
    Lr_gen=0.0002,               # Adam learning rate of the generator
    Lr_disc=0.0002,              # Adam learning rate of the discriminator
    MAX_EPOCH=10,
    SAVE_NTH_BATCH=20,           # show + save a generated sample every N batches
    DATASET_SIZE=2000,           # rows sampled from train.csv (-1 keeps all rows)
    LAMBDA_RECON=80,             # weight of the L1 reconstruction term in the generator loss
    HALF_SIZE_LOSS_WEIGHT=1,     # not read by the active code
    DISPLAY_LOSS_N=40,           # plot loss curves every N batches
    TRAIN_DISCRIMINATOR=2,       # discriminator is updated only on every N-th batch
    RESULT_PATH='/content/drive/MyDrive/grass/experiments/example-sea-creature-TinySPADE/',
    INPUT_DIM=50 * 30 * 128,     # not read by the active code
    INITIAL_FILTER_SIZE=128,     # not read by the active code
)

CONFIG = config_defaults

In [None]:
class GrassDataset(Dataset):
  """Lazily loads (real, condition) image pairs listed in ``<path>train.csv``.

  Each CSV row names four image files in the columns ``Input``, ``Depth``,
  ``Normal`` and ``Output``. The first two characters of every stored path are
  dropped (presumably a ``./`` prefix -- confirm against the CSV) before the
  path is appended to ``path``.

  Args:
    path: dataset root, including the trailing separator (joined by plain
      string concatenation).
    num_items: number of rows to sample at random without replacement
      (capped at the CSV length); a negative value or ``None`` keeps every row.
    random_state: optional seed forwarded to ``DataFrame.sample`` to make the
      sampled subset reproducible; ``None`` keeps the unseeded behaviour.
  """

  def __init__(self, path, num_items = -1, random_state=None):
    self.path = path
    self.df = pd.read_csv(path + "train.csv")
    if num_items is not None and num_items >= 0:
      self.df = self.df.sample(n=min(len(self.df), num_items), replace=False,
                               random_state=random_state)
    # Positional 0..len-1 index so __getitem__ can look rows up by idx.
    self.df = self.df.reset_index(drop=True)
    self.length = len(self.df)

  def __len__(self):
    return self.length

  def _load(self, column, idx):
    """Open the image named in ``column`` of row ``idx`` and return it as a CHW float tensor."""
    image_path = self.path + self.df[column][idx][2:]
    # The context manager closes the file; to_tensor reads the pixels before it exits.
    with Image.open(image_path) as image:
      return transforms.functional.to_tensor(image)

  def __getitem__(self, idx):
    """Return ``(idx, real, condition)`` for row ``idx``.

    ``condition`` stacks the input image, the first channel of the depth map
    and the normal map along dim 0 (3 + 1 + 3 = 7 channels for RGB images).
    """
    input_image = self._load('Input', idx)
    depth_image = self._load('Depth', idx)
    normal_image = self._load('Normal', idx)
    real = self._load('Output', idx)

    condition = torch.cat((input_image, depth_image[0:1, :, :], normal_image), 0)
    return idx, real, condition

class RAMGrassDataset(Dataset):
  """In-memory grass dataset: every sample is decoded once, in ``__init__``.

  This trades RAM for speed when training for many epochs. Each sample keeps a
  real image and a stacked condition tensor, so memory grows linearly with
  ``num_items`` -- choose it carefully.

  NOTE(review): ``__getitem__`` returns ``(real, condition)``, whereas the lazy
  ``GrassDataset`` returns ``(idx, real, condition)``. A training loop that
  unpacks three values per batch cannot consume this dataset unchanged.

  Args:
    path: dataset root, including the trailing separator.
    num_items: number of rows to sample at random without replacement
      (capped at the CSV length); a negative value or ``None`` keeps every row.
    random_state: optional seed forwarded to ``DataFrame.sample``.
  """

  def __init__(self, path, num_items = -1, random_state=None):
    print("preparing RAM dataset, hopefully this doesn't take very long")
    self.path = path
    self.df = pd.read_csv(path + "train.csv")
    if num_items is not None and num_items >= 0:
      self.df = self.df.sample(n=min(len(self.df), num_items), replace=False,
                               random_state=random_state)
    self.df = self.df.reset_index(drop=True)
    self.length = len(self.df)

    self.real = []
    self.condition = []
    print("prepared all variables, adding data to RAM")
    for i in range(self.length):
      _, real, condition = self.getitem(i)
      self.real.append(real)
      self.condition.append(condition)
    print("dataset added to RAM")

  def __len__(self):
    return self.length

  def __getitem__(self, idx):
    return self.real[idx], self.condition[idx]

  def _load(self, column, idx):
    """Open the image named in ``column`` of row ``idx`` and return it as a CHW float tensor."""
    image_path = self.path + self.df[column][idx][2:]
    with Image.open(image_path) as image:
      return transforms.functional.to_tensor(image)

  def getitem(self, idx):
    """Decode row ``idx`` from disk and return ``(idx, real, condition)``.

    ``condition`` stacks the input image, the first depth channel and the
    normal map along dim 0 (3 + 1 + 3 = 7 channels for RGB images). The old
    debug ``print`` of the full input tensor has been removed.
    """
    input_image = self._load('Input', idx)
    depth_image = self._load('Depth', idx)
    normal_image = self._load('Normal', idx)
    real = self._load('Output', idx)

    condition = torch.cat((input_image, depth_image[0:1, :, :], normal_image), 0)
    return idx, real, condition


In [None]:
# Load dataset: a lazily-loading training set (up to DATASET_SIZE rows sampled
# from train.csv) and a shuffling loader that batches it.
dataset = GrassDataset(CONFIG['PATH_CONTEXT'], num_items=CONFIG['DATASET_SIZE'])
dataloader = DataLoader(
    dataset,
    batch_size=CONFIG['BATCH_SIZE'],
    shuffle=True,
)

In [None]:
class DownSampleConv(nn.Module):
    """Conv2d followed by optional BatchNorm2d and LeakyReLU(0.2).

    With the defaults (kernel 4, stride 2, padding 1) an even spatial size is
    halved. The submodule names ``conv`` and ``bn`` are part of saved state
    dicts, so they are kept as they are.
    """

    def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True):
        super().__init__()
        self.activation = activation
        self.batchnorm = batchnorm

        self.conv = nn.Conv2d(in_channels, out_channels, kernel, strides, padding)
        if batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)
        if activation:
            self.act = nn.LeakyReLU(0.2)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out) if self.batchnorm else out
        return self.act(out) if self.activation else out

class UpSampleConv(nn.Module):
    """ConvTranspose2d followed by optional BatchNorm2d, Dropout2d(0.5) and ReLU.

    With the defaults (kernel 4, stride 2, padding 1) the spatial size is doubled.

    Args:
        in_channels / out_channels: channel counts of the transposed convolution.
        kernel, strides, padding: forwarded to ``nn.ConvTranspose2d``.
        activation: apply ReLU to the output.
        batchnorm: apply BatchNorm2d after the transposed convolution.
        dropout: apply Dropout2d(0.5) after normalization.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel=4,
        strides=2,
        padding=1,
        activation=True,
        batchnorm=True,
        dropout=False
    ):
        super().__init__()
        self.activation = activation
        self.batchnorm = batchnorm
        self.dropout = dropout

        self.deconv = nn.ConvTranspose2d(in_channels, out_channels, kernel, strides, padding)

        if batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)

        if activation:
            self.act = nn.ReLU(True)

        if dropout:
            self.drop = nn.Dropout2d(0.5)

    def forward(self, x):
        x = self.deconv(x)
        if self.batchnorm:
            x = self.bn(x)
        if self.dropout:
            x = self.drop(x)
        # self.act used to be built but never applied, so activation=True was a no-op.
        if self.activation:
            x = self.act(x)
        return x

In [None]:
# class ResnetBlock(nn.Module):
#     def __init__(self, dim, padding_type, norm_layer=nn.BatchNorm2d, activation=nn.ReLU(True), use_dropout=False):
#         super(ResnetBlock, self).__init__()
#         self.build_conv_block(dim, padding_type, norm_layer, activation, use_dropout)

#     def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
#         self.activation = activation;
#         self.norm_layer = norm_layer;
#         self.dropout = use_dropout;

#         if(norm_layer):
#             self.nl = nn.BatchNorm2d(dim)
#             self.nl2 = nn.BatchNorm2d(dim)

          
#         if activation:
#             self.act = nn.ReLU(True)
#             self.act2 = nn.ReLU(True)
          
#         if use_dropout:
#             self.dp = nn.Dropout(0.5)

#         self.conv_1 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)


#         self.conv_2 = nn.Conv2d(dim, dim, kernel_size=3, padding=1)


#     def forward(self, x):
        
#         y = self.conv_1(x)
#         y = self.nl(y)
#         y = self.act(y)

#         #y = self.dp(y) Regularisation

#         y = self.conv_2(y)
#         y = self.nl2(y)
#         y = self.act2(y)

#         return x + y

In [None]:
# class RobGenerator(nn.Module):
#     def __init__(self, in_channels, out_channels):
#         super().__init__()
       
#         self.layers = [
#             DownSampleConv(in_channels, 64, batchnorm=False),  # bs x 64 x 240 x 400
#             DownSampleConv(64, 128),  # bs x 128 x 120 x 200
#             ResnetBlock(128, padding_type = 'reflect', use_dropout=True), # bs x 128 x 120 x 200
#             ResnetBlock(128, padding_type = 'reflect', use_dropout=True), # bs x 128 x 120 x 200
#             ResnetBlock(128, padding_type = 'reflect', use_dropout=True), # bs x 128 x 120 x 200
#             UpSampleConv(128, 128),  # bs x 128 x 240 x 400
#             ResnetBlock(128, padding_type = 'reflect', use_dropout=True), # bs x 128 x 240 x 400
#             ResnetBlock(128, padding_type = 'reflect', use_dropout=True), # bs x 128 x 240 x 400
#             nn.ConvTranspose2d(128, out_channels, kernel_size=4, stride=2, padding=1), # bs x 3 x 480 x 800
#         ]

#         self.tanh = nn.Tanh()
#         self.layers = nn.ModuleList(self.layers)

#     def forward(self, x):
#         for layer in self.layers:
#             x = layer(x)
#         return self.tanh(x)

In [None]:
class SPADE(Module):
    """Spatially-adaptive normalization layer.

    ``x`` is batch-normalized and then modulated per pixel by a scale and a
    shift predicted from the conditioning map ``seg``:
    ``BN(x) * (1 + gamma(seg)) + beta(seg)``. ``seg`` must have the same spatial
    size as ``x`` (all convolutions here preserve the spatial size).

    Args:
        k: channel count of ``x`` and of the output.
        label_nc: channel count of ``seg``. ``None`` (the default) reads
            ``CONFIG['IN_CHANNEL']``, matching the previous behaviour.
    """

    def __init__(self, k, label_nc=None):
        super().__init__()
        if label_nc is None:
            label_nc = CONFIG['IN_CHANNEL']
        kernel_size = 3
        self.conv_gamma = Conv2d(k, k, kernel_size=(kernel_size, kernel_size), padding=1)
        self.conv_beta = Conv2d(k, k, kernel_size=(kernel_size, kernel_size), padding=1)
        # Shared embedding of the conditioning map (conv + ReLU to k channels),
        # consumed by both the gamma and beta heads.
        self.activation = nn.Sequential(
            nn.Conv2d(label_nc, k, kernel_size=(kernel_size, kernel_size), padding =1),
            nn.ReLU()
        )
        self.normalization = nn.BatchNorm2d(k)
        # NOTE(review): self.conv is never used in forward(). It stays because its
        # parameters are in the state_dict, so removing it would break strict
        # loading of existing checkpoints.
        self.conv = nn.Conv2d(label_nc, k, kernel_size=(kernel_size, kernel_size), padding=1,)

    def forward(self, x, seg):
        """Return ``x`` normalized and modulated by ``seg``; the output has the shape of ``x``."""
        normalized = self.normalization(x)
        embedded = self.activation(seg)
        gamma = self.conv_gamma(embedded)
        beta = self.conv_beta(embedded)
        return normalized * (1 + gamma) + beta

class SPADEResBlk(Module):
    """Residual block computing ``shortcut(x) + conv1(relu(spade1(x, seg)))``.

    With ``skip=False`` the input has ``k`` channels and the shortcut is the
    identity. With ``skip=True`` the input has ``2*k`` channels; both the main
    path and a learned 3x3 shortcut convolution (``conv_skip``) reduce it to
    ``k`` channels. Spatial size is preserved.
    """

    def __init__(self, k, skip=False):
        super().__init__()
        self.skip = skip
        in_ch = 2 * k if skip else k

        self.spade1 = SPADE(in_ch)
        self.conv1 = Conv2d(in_ch, k, kernel_size=(3, 3), padding=1, bias=False)
        if skip:
            self.conv_skip = Conv2d(in_ch, k, kernel_size=(3, 3), padding=1, bias=False)

    def forward(self, x, seg):
        residual = self.conv1(relu(self.spade1(x, seg)))
        shortcut = self.conv_skip(x) if self.skip else x
        return shortcut + residual

class SPADEGenerator(nn.Module):
    """SPADE-style generator that decodes a fixed noise tensor under a condition map.

    The noise is projected to 128 channels by ``convInit``. It then passes
    through four SPADE residual blocks, each followed by a 2x transposed-conv
    upsampling. Each block is conditioned on ``seg`` bicubically downsampled to
    its resolution (1/16, 1/8, 1/4, 1/2). A spectral-normalized 3x3 conv and
    tanh produce the output image.

    Args:
        inchannels: channel count of the noise tensor fed to ``convInit``.
        outchannels: channel count of the generated image.
    """

    def __init__(self, inchannels, outchannels):
        super().__init__()
        kernel_size = 3

        self.spade_resblk1 = SPADEResBlk(128)
        self.spade_resblk2 = SPADEResBlk(128)
        self.spade_resblk3 = SPADEResBlk(64, skip=True)
        self.spade_resblk4 = SPADEResBlk(32, skip=True)
        self.conv = nn.utils.spectral_norm(nn.Conv2d(32, outchannels, kernel_size=(3,3), padding=1))
        # NOTE(review): convTrans128 is applied twice in forward(), so both 128-channel
        # upsamplings share weights -- confirm this is intended.
        self.convTrans128 = nn.ConvTranspose2d(128, 128, kernel_size = 2, stride = 2)
        self.convTrans64 = nn.ConvTranspose2d(64, 64, kernel_size = 2, stride = 2)
        self.convTrans32 = nn.ConvTranspose2d(32, 32, kernel_size = 2, stride = 2)
        # Fixed input noise drawn once: batch 8 at 30x50 (1/16 of a 480x800 condition).
        # As a non-persistent buffer it moves with the module (no global ``device``
        # needed) and stays out of the state_dict, so checkpoints are unaffected.
        self.register_buffer('noise', torch.normal(0, 1, (8, inchannels, 30, 50)), persistent=False)
        self.convInit = Conv2d(inchannels, 128, kernel_size=(kernel_size, kernel_size), padding=1, bias=False)

    def _noise_for(self, batch, spatial):
        """Return noise of shape (batch, C, *spatial), reusing the fixed buffer when it fits."""
        noise = self.noise
        if tuple(noise.shape[2:]) == tuple(spatial) and batch <= noise.shape[0]:
            # Same noise as training for a full batch; a prefix of it for smaller batches.
            return noise[:batch]
        # Previously any other shape either crashed or (batch 1) silently broadcast
        # to 8 outputs; draw matching fresh noise instead.
        return torch.randn(batch, noise.shape[1], *spatial, device=noise.device, dtype=noise.dtype)

    def forward(self, seg):
        """Generate an image from the condition map ``seg``.

        ``seg`` is assumed to be (B, C, H, W) with H and W divisible by 16
        (480x800 in this notebook -- TODO confirm for other data). The output
        has ``outchannels`` channels, values in (-1, 1), and the spatial size of ``seg``.
        """
        m4 = F.interpolate(seg, scale_factor = 0.5, mode = "bicubic")  # 1/2 resolution
        m3 = F.interpolate(m4, scale_factor = 0.5, mode = "bicubic")   # 1/4
        m2 = F.interpolate(m3, scale_factor = 0.5, mode = "bicubic")   # 1/8
        m1 = F.interpolate(m2, scale_factor = 0.5, mode = "bicubic")   # 1/16

        x = self.convInit(self._noise_for(seg.shape[0], m1.shape[2:]))  # b, 128, 1/16

        x = self.convTrans128(self.spade_resblk1(x, m1))  # b, 128, 1/8
        x = self.convTrans128(self.spade_resblk2(x, m2))  # b, 128, 1/4
        x = self.convTrans64(self.spade_resblk3(x, m3))   # b, 64, 1/2
        x = self.convTrans32(self.spade_resblk4(x, m4))   # b, 32, full resolution

        return torch.tanh(self.conv(x))

In [None]:
class PatchGAN(nn.Module):
    """PatchGAN discriminator producing one logit per image patch.

    The image ``x`` and its condition ``y`` are concatenated along channels, so
    ``input_channels`` must equal their channel sum. Four DownSampleConv stages
    (stride 2 with their defaults) are followed by a 1x1 conv to a single
    logit channel.
    """

    def __init__(self, input_channels):
        super().__init__()
        self.d1 = DownSampleConv(input_channels, 32, batchnorm=False)
        self.d2 = DownSampleConv(32, 64)
        self.d3 = DownSampleConv(64, 64)
        self.d4 = DownSampleConv(64, 64)
        self.final = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x, y):
        features = torch.cat([x, y], dim=1)
        for stage in (self.d1, self.d2, self.d3, self.d4):
            features = stage(features)
        return self.final(features)

class MultilevelPatchGAN(nn.Module):
    """Two-scale PatchGAN discriminator.

    One branch (four DownSampleConv stages) sees the concatenated (image,
    condition) pair at full resolution. The other branch (three stages) sees
    it bicubically downsampled by 2. Returns ``(full_res_logits, half_res_logits)``.
    """

    def __init__(self, input_channels):
        super().__init__()
        # Full-resolution branch.
        self.d1 = DownSampleConv(input_channels, 32, batchnorm=False)
        self.d2 = DownSampleConv(32, 64)
        self.d3 = DownSampleConv(64, 64)
        self.d4 = DownSampleConv(64, 64)
        self.final = nn.Conv2d(64, 1, kernel_size=1)

        # Half-resolution branch.
        self.dd1 = DownSampleConv(input_channels, 32, batchnorm=False)
        self.dd2 = DownSampleConv(32, 64)
        self.dd3 = DownSampleConv(64, 64)
        self.dfinal = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x, y):
        full = torch.cat([x, y], dim=1)
        half = torch.nn.functional.interpolate(full, scale_factor=0.5, mode="bicubic")

        h = full
        for stage in (self.d1, self.d2, self.d3, self.d4):
            h = stage(h)
        full_logits = self.final(h)

        g = half
        for stage in (self.dd1, self.dd2, self.dd3):
            g = stage(g)
        half_logits = self.dfinal(g)

        return full_logits, half_logits

In [None]:
def _weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
        
def display_progress(cond, fake, real, figsize=(10,5)):
    """Show the condition (first 3 channels), the real image and the generated image side by side.

    Args:
        cond, fake, real: CHW tensors for one sample (any device; gradients are detached).
        figsize: matplotlib figure size.
    """
    panels = (
        ("Condition (input channels 0-2)", cond[0:3]),
        ("Real", real),
        ("Generated", fake),
    )
    fig, axes = plt.subplots(1, 3, figsize=figsize)
    for ax, (title, image) in zip(axes, panels):
        # imshow expects HWC floats in [0, 1]; the generator's tanh output can be
        # negative, so clamp explicitly instead of triggering clipping warnings.
        ax.imshow(image.detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy())
        ax.set_title(title)
        ax.axis("off")
    plt.show()

def draw_result(lst_itrd, lst_itrg, lst_fake, lst_real,lst_recon,lst_adver):
  """Plot discriminator losses (top panel) and generator losses (bottom panel).

  Args:
    lst_itrd: x values for the discriminator curves.
    lst_itrg: x values for the generator curves.
    lst_fake, lst_real: discriminator BCE on fake / real images (one per entry of lst_itrd).
    lst_recon, lst_adver: generator L1 reconstruction / adversarial loss (one per entry of lst_itrg).
  """
  fig, (ax_disc, ax_gen) = plt.subplots(2)
  fig.suptitle('loss')

  ax_disc.plot(lst_itrd, lst_fake, '-b', label='Fake loss')
  ax_disc.plot(lst_itrd, lst_real, '-k', label='Real loss')
  ax_gen.plot(lst_itrg, lst_adver, '-r', label='Adversarial loss')
  ax_gen.plot(lst_itrg, lst_recon, '-g', label='Reconstruction loss')

  # Label and legend each panel explicitly; the pyplot-state calls used before
  # only reached the bottom axes.
  for ax in (ax_disc, ax_gen):
    ax.set_xlabel('iteration')
    ax.set_ylabel('loss')
    ax.legend()
  plt.show()

In [None]:
class GAN(pl.LightningModule):
    """Conditional GAN: SPADE generator + PatchGAN discriminator.

    Generator loss = BCE adversarial loss + CONFIG['LAMBDA_RECON'] * L1(fake, real).
    Discriminator loss = mean of the BCE on real logits (target 1) and fake
    logits (target 0). Uses Lightning's automatic optimization with two
    optimizers: index 0 = discriminator, index 1 = generator.

    Args:
        in_channels: channel count of the condition tensor.
        out_channels: channel count of the generated image.
        learning_rate: only recorded by ``save_hyperparameters``; the optimizers
            read CONFIG['Lr_gen'] / CONFIG['Lr_disc'].
        next_save_idx: starting value of the counter used in saved sample filenames.
    """

    def __init__(self, in_channels, out_channels, learning_rate=0.0002, next_save_idx = 0):

        super().__init__()
        self.save_hyperparameters()

        self.gen = SPADEGenerator(in_channels, out_channels)
        self.patch_gan = PatchGAN(in_channels + out_channels)
        # initializing weights
        self.gen = self.gen.apply(_weights_init)
        self.patch_gan = self.patch_gan.apply(_weights_init)

        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.recon_criterion = nn.L1Loss()
        self.trainidx = next_save_idx

        self.lambda_recon = CONFIG["LAMBDA_RECON"]

        # Per-step loss history (Python floats) consumed by draw_result.
        self.recon_lossl = []
        self.adverserial_lossl = []
        self.itrg = []
        self.itrd = []
        self.real_lossl = []
        self.fake_lossl = []

    def _gen_step(self, real_images, conditioned_images):
        """Return the generator loss for one batch and record its two components."""
        fake_images = self.gen(conditioned_images)
        disc_logits = self.patch_gan(fake_images, conditioned_images)
        adversarial_loss = self.adversarial_criterion(disc_logits, torch.ones_like(disc_logits))
        recon_loss = self.recon_criterion(fake_images, real_images)

        self.adverserial_lossl.append(adversarial_loss.item())
        self.recon_lossl.append(recon_loss.item())
        self.itrg.append(len(self.adverserial_lossl))
        return adversarial_loss + self.lambda_recon * recon_loss

    def _disc_step(self, real_images, conditioned_images):
        """Return the discriminator loss for one batch and record the real/fake terms."""
        # Generator output is detached so this step does not update the generator.
        fake_images = self.gen(conditioned_images).detach()

        fake_logits = self.patch_gan(fake_images, conditioned_images)
        real_logits = self.patch_gan(real_images, conditioned_images)

        fake_loss = self.adversarial_criterion(fake_logits, torch.zeros_like(fake_logits))
        real_loss = self.adversarial_criterion(real_logits, torch.ones_like(real_logits))
        self.fake_lossl.append(fake_loss.item())
        self.real_lossl.append(real_loss.item())
        self.itrd.append(len(self.fake_lossl))
        return (real_loss + fake_loss) / 2

    def configure_optimizers(self):
        gen_opt = torch.optim.Adam(self.gen.parameters(), lr=CONFIG['Lr_gen'])
        disc_opt = torch.optim.Adam(self.patch_gan.parameters(), lr=CONFIG['Lr_disc'])
        # Order defines optimizer_idx in training_step: 0 = discriminator, 1 = generator.
        return disc_opt, gen_opt

    def _save_sample(self, idx, real, condition, batch_idx):
        """Show and save the generated image for the first sample of the batch."""
        print(f'batch {batch_idx}')
        # Visualisation only -- no autograd graph needed.
        with torch.no_grad():
            fake = self.gen(condition)
        display_progress(condition[0], fake[0], real[0])

        os.makedirs(CONFIG['RESULT_PATH'], exist_ok=True)
        index = int(idx[0]) + 1
        path = CONFIG['RESULT_PATH'] + str(self.trainidx) + "_" + str(index) + ".png"
        self.trainidx += 1
        save_image(fake[0], path)

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the optimizer ``optimizer_idx`` on ``batch``.

        Returns ``None`` for the discriminator on batches where
        ``batch_idx % CONFIG['TRAIN_DISCRIMINATOR'] != 0``; under Lightning's
        automatic optimization this skips that optimizer's step.
        """
        idx, real, condition = batch

        loss = None
        if optimizer_idx == 1:
            loss = self._gen_step(real, condition)
            self.log('Generator Loss', loss)
        elif optimizer_idx == 0 and batch_idx % CONFIG['TRAIN_DISCRIMINATOR'] == 0:
            loss = self._disc_step(real, condition)
            self.log('PatchGAN Loss', loss)

        # training_step runs once per optimizer for every batch; report only on the
        # generator pass so samples and loss plots appear once per batch, not twice.
        if optimizer_idx == 1:
            if batch_idx % CONFIG['SAVE_NTH_BATCH'] == 0:
                self._save_sample(idx, real, condition, batch_idx)
            if batch_idx % CONFIG['DISPLAY_LOSS_N'] == 0 and batch_idx != 0:
                draw_result(self.itrd, self.itrg, self.fake_lossl, self.real_lossl,
                            self.recon_lossl, self.adverserial_lossl)

        return loss

    def forward(self, x, y, z):
        """Concatenate three condition parts along channels and run the generator."""
        x = torch.cat((x, y, z), dim=1)
        return self.gen(x)


In [None]:
# Load dataset (repeats the earlier "load dataset" cell). Rows are sampled
# without a fixed seed, so re-running this draws a new random subset of train.csv.
dataset = GrassDataset(CONFIG['PATH_CONTEXT'], num_items=CONFIG['DATASET_SIZE'])
dataloader = DataLoader(
    dataset,
    batch_size=CONFIG['BATCH_SIZE'],
    shuffle=True,
)

In [None]:
# NOTE(review): these two criteria are never used (GAN builds its own); they are
# kept only so the module-level names stay defined.
xadversarial_loss = nn.BCEWithLogitsLoss()
reconstruction_loss = nn.L1Loss()

# Defined before GAN(...) below because the generator's constructor may
# reference this global.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if CONFIG['LOAD_CHECKPOINT']:
  model = GAN.load_from_checkpoint(CONFIG['PATH_CHECKPOINT'])
else:
  model = GAN(CONFIG['IN_CHANNEL'], CONFIG['OUT_CHANNEL'])

# Request a GPU only when one exists. The old hard-coded gpus=1 made Trainer
# fail on CPU-only runtimes even though `device` already fell back to CPU.
trainer = pl.Trainer(
    max_epochs=CONFIG['MAX_EPOCH'],
    gpus=1 if torch.cuda.is_available() else 0,
    default_root_dir=CONFIG['RESULT_PATH'],
)
trainer.fit(model, dataloader)


1. Used linear layers.
2. Decreased the filter size from 128 to 64.
3. Learning rate = 0.0002.
4. Used LeakyReLU.
5. Decreased the reconstruction-loss weight from 100 to 50.

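## After training:
1. The losses stay the same after the n-th epoch.
2. Reconstruction loss: ~0.1
3. Real loss: ~0.67
4. Fake loss: ~0.71
5. Adversarial loss: ~0.7

(A binary cross-entropy of about ln 2 ≈ 0.693 means the discriminator outputs a probability of ≈0.5 for every patch, i.e. it cannot tell real from fake images.)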

In [None]:
import torch  # re-imported so this cell can also be run on its own

# Release unoccupied memory held by PyTorch's CUDA caching allocator,
# e.g. after an interrupted training run in the same kernel.
torch.cuda.empty_cache()

In [None]:
# import cv2
# import numpy as np
# import matplotlib.pyplot as plt
# import torch
# from PIL import Image
# from torchvision.transforms import transforms

# image = Image.open("/content/train/Input/0_input.jpg")
# image.show()
# T = transforms.ToTensor()
# reI = T(image)
# reI = torch.unsqueeze(reI, dim = 0)
# print(reI.shape)
# rescaled1 = torch.nn.functional.interpolate(reI, scale_factor = 0.5, mode = "bicubic") # 200
# rescaled2 = torch.nn.functional.interpolate(rescaled1, scale_factor = 0.5, mode = "bicubic")#100
# rescaled3 = torch.nn.functional.interpolate(rescaled2, scale_factor = 0.5, mode = "bicubic")#50

# rescaled1 = torch.squeeze(rescaled1)
# rescaled2 = torch.squeeze(rescaled2)
# rescaled3 = torch.squeeze(rescaled3)

# I = transforms.ToPILImage()
# en1 = I(rescaled1)
# en2 = I(rescaled2)
# en3 = I(rescaled3)
# en1.show()
# en2.show()
# en3.show()


In [None]:
# Mount Google Drive so the /content/drive/... paths in CONFIG resolve.
# NOTE(review): this must run before any cell that reads from or writes to Drive
# (dataset unzip, checkpoint loading, saving results), yet it sits at the end.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

EXPERIMENT INFORMATION


Version 13: Even after the 9th epoch the results stay the same; the model does not seem to learn anything new. It learns lighting, structures and colors very well, but it does not get the textures right (or at all). In this version we started training from a 60×100 size.
