In [30]:
import zipfile
import os
# !unzip /content/drive/MyDrive/grass/datasets/train.zip -d /content/train
# with zipfile.ZipFile("/content/drive/MyDrive/Colab Notebooks/train_channel7.zip","r") as zf:
#     zf.extractall("/content/train/")

# if not os.path.exists("/content/train/generated/"):
#     os.makedirs("/content/train/generated/")

# with zipfile.ZipFile("/content/drive/MyDrive/Colab Notebooks/grass_style_transfer_references.zip","r") as zf2:
#     zf2.extractall("/content/train/refs/")

In [31]:
# Hyper-parameters and filesystem locations read by the cells below.
config_defaults = dict(
    BATCH_SIZE=8,        # batch size handed to the DataLoader
    IN_CHANNEL_GEN=7,    # channels of the conditioning tensor fed to the generator
    IN_CHANNEL_DIS=10,   # channels seen by the discriminator (image + condition)

    OUT_CHANNEL=3,       # not read by the visible cells
    PATH_DATA='/content/train/',  # dataset root; concatenated with file names, so keep the trailing slash
    # Checkpoint location; not read by the visible cells.
    PATH='/content/drive/MyDrive/Colab Notebooks/seashell 14 results/lightning_logs/version_4/checkpoints/epoch=27-step=14000.ckpt',
    MAX_EPOCH=10,        # number of passes over the dataset
    Lr_GEN=0.001,        # Adam learning rate for the generator
    Lr_DIS=0.004,        # Adam learning rate for the discriminator
    # Directory that receives the preview images written during training.
    RESULT_PATH='/content/drive/MyDrive/grass/experiments/example-sea-creature-30050/',
)
# CONFIG is an alias of the very same dict object, not a copy.
CONFIG = config_defaults

In [32]:
# Imports for the whole notebook (the optional Kaggle download helper at the end of this cell is commented out)
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from torchvision import transforms
from PIL import Image
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (20, 12)

import sys
from torch.nn import Module, Conv2d
from torch.nn.utils import spectral_norm
from torch.nn.functional import interpolate, relu

import warnings
warnings.filterwarnings("ignore")
from torch.nn.functional import leaky_relu
from torch.nn.utils import spectral_norm


# if not os.path.isdir(CONFIG['PATH_DATA']):
#   !pip install -q kaggle
#   !mkdir ~/.kaggle
#   !pwd
#   from google.colab import files
#   files.upload()
#   !cp kaggle.json ~/.kaggle/
#   !chmod 600 ~/.kaggle/kaggle.json
#   #!kaggle datasets list
#   !kaggle datasets download -d fluxo4/grass-generation-training-set -p /content/
#   !unzip /content/grass-generation-training-set.zip -d /content/drive/MyDrive/grass-generation-training-set


In [33]:
class GrassDataset(Dataset):
  """Lazily loads (condition, real) tensor pairs listed in ``path + "train.csv"``.

  The CSV must provide the columns 'Input', 'Depth', 'Normal' and 'Output',
  each holding an image path. The first two characters of every stored path
  are dropped (presumably a leading "./" -- confirm against the CSV) and the
  remainder is appended to ``path``.
  """

  def __init__(self, path, num_items = -1):
    """
    Args:
      path: dataset root. It is concatenated directly with file names, so it
        must end with a path separator.
      num_items: number of rows to keep, sampled without replacement. Any
        negative value (default -1) keeps every row; values larger than the
        CSV are clamped to the CSV length.
    """
    self.df = pd.read_csv(path + "train.csv")
    self.path = path
    if num_items < 0:
      self.length = len(self.df)
    else:
      self.length = min(len(self.df), num_items)
      self.df = self.df.sample(n = self.length, replace=False)
    # Re-number rows 0..length-1 so integer positions can be used as labels.
    self.df = self.df.reset_index(drop=True)

  def __len__(self):
    """Number of samples served by this dataset."""
    return self.length

  def _load_tensor(self, column, idx):
    """Read the image referenced by ``self.df[column][idx]`` as a tensor."""
    image_path = self.path + self.df[column][idx][2:]
    # The context manager releases the file handle as soon as the pixels have
    # been copied into the tensor, instead of waiting for garbage collection.
    with Image.open(image_path) as image:
      return transforms.functional.to_tensor(image)

  def __getitem__(self, idx):
    """Return ``(condition, real)`` for row ``idx``.

    ``condition`` stacks, along the channel axis, the input image, the first
    channel of the depth image and the normal image (7 channels when input and
    normal are 3-channel images). ``real`` is the target image tensor.
    """
    inputImage = self._load_tensor('Input', idx)
    depthImage = self._load_tensor('Depth', idx)
    nImage = self._load_tensor('Normal', idx)
    real = self._load_tensor('Output', idx)

    # Only the first depth channel is kept.
    condition = torch.cat((inputImage, depthImage[0:1, :, :], nImage), 0)

    return condition, real

class RAMGrassDataset(Dataset):
  """Variant of the on-disk dataset that decodes every sample once, up front.

  All (condition, real) pairs are held in Python lists so that later epochs do
  no disk I/O. Memory use grows with the number of items, so choose
  ``num_items`` with care.
  """

  def __init__(self, path, num_items = -1):
    """
    Args:
      path: dataset root. It is concatenated directly with file names, so it
        must end with a path separator.
      num_items: number of rows to keep, sampled without replacement. Any
        negative value (default -1) keeps every row; values larger than the
        CSV are clamped to the CSV length.
    """
    print("preparing RAM dataset, hopefully this doesn't take very long")
    self.df = pd.read_csv(path + "train.csv")
    self.path = path
    if num_items < 0:
      self.length = len(self.df)
    else:
      self.length = min(len(self.df), num_items)
      self.df = self.df.sample(n = self.length, replace=False)
    self.df = self.df.reset_index(drop=True)

    self.real = []
    self.condition = []
    print("prepared all variables, adding data to RAM")
    for i in range(self.length):
      # getitem returns exactly two values: (condition, real).
      condition, real = self.getitem(i)
      self.condition.append(condition)
      self.real.append(real)
    print("dataset added to RAM")

  def __len__(self):
    """Number of samples served by this dataset."""
    return self.length

  def __getitem__(self, idx):
    """Return the cached ``(condition, real)`` pair for ``idx``.

    The order matches ``getitem`` and the ``(condition, real)`` unpacking used
    by the training loop.
    """
    return self.condition[idx], self.real[idx]

  def _load_tensor(self, column, idx):
    """Read the image referenced by ``self.df[column][idx]`` as a tensor."""
    # The first two characters of the stored path are dropped (presumably a
    # leading "./" -- confirm against the CSV).
    image_path = self.path + self.df[column][idx][2:]
    # Close the file as soon as its pixels have been copied into the tensor.
    with Image.open(image_path) as image:
      return transforms.functional.to_tensor(image)

  def getitem(self, idx):
    """Decode row ``idx`` from disk and return ``(condition, real)``.

    ``condition`` stacks, along the channel axis, the input image, the first
    channel of the depth image and the normal image. ``real`` is the target.
    """
    inputImage = self._load_tensor('Input', idx)
    depthImage = self._load_tensor('Depth', idx)
    nImage = self._load_tensor('Normal', idx)
    real = self._load_tensor('Output', idx)

    # Only the first depth channel is kept.
    condition = torch.cat((inputImage, depthImage[0:1, :, :], nImage), 0)

    return condition, real


In [34]:
class Args:
    """Plain attribute container for the SPADE / generator hyper-parameters.

    Every constructor argument is stored unchanged under the same name. A
    message is printed (no exception is raised) when ``gen_hidden_size`` is
    not a multiple of 16.
    """

    def __init__(self, spade_filter=128, spade_kernel=3, spade_resblk_kernel=3, gen_input_size=256, gen_hidden_size=16384):
        self.spade_filter, self.spade_kernel = spade_filter, spade_kernel
        self.spade_resblk_kernel = spade_resblk_kernel
        self.gen_input_size, self.gen_hidden_size = gen_input_size, gen_hidden_size

        is_aligned = gen_hidden_size % 16 == 0
        if not is_aligned:
            print("Gen hidden size not multiple of 16")

# Hyper-parameters used by this notebook; they replace the Args defaults.
spade_filter, gen_input_size, gen_hidden_size = 64, 256, 128 * 375
args = Args(
    spade_filter=spade_filter,
    spade_kernel=3,
    spade_resblk_kernel=3,
    gen_input_size=gen_input_size,
    gen_hidden_size=gen_hidden_size,
)


# def weights_init(m):
#     classname = m.__class__.__name__
#     if classname.find('Conv') != -1:
#         nn.init.normal_(m.weight.data, 0.0, 0.02)
#     elif classname.find('BatchNorm') != -1:
#         nn.init.normal_(m.weight.data, 1.0, 0.02)
#         nn.init.constant_(m.bias.data, 0)

def weights_init(m):
    """Initialise one module in place; intended for ``model.apply(weights_init)``.

    * ``Conv2d`` / ``ConvTranspose2d``: weights drawn from N(0, 0.02).
    * ``BatchNorm2d`` with affine parameters: scale drawn from N(1, 0.02),
      bias set to 0.

    Any other module type is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        # torch.nn.utils.spectral_norm keeps the trainable tensor in
        # `weight_orig` and recomputes `weight` from it before every forward,
        # so the parameter itself has to be initialised when it exists.
        weight = getattr(m, 'weight_orig', None)
        if weight is None:
            weight = m.weight
        torch.nn.init.normal_(weight, 0.0, 0.02)
    # BatchNorm2d(affine=False) has weight/bias set to None; skip it.
    if isinstance(m, nn.BatchNorm2d) and m.weight is not None:
        # The scale is centred on 1 so that the layer starts close to identity;
        # centring it on 0 would suppress nearly all activations.
        torch.nn.init.normal_(m.weight, 1.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)


In [35]:
from pandas._libs.lib import fast_unique_multiple_list_gen


# Number of channels of the conditioning tensor passed to every SPADE layer.
inchannel = 7
class SPADE(Module):
    """Spatially-adaptive normalisation layer.

    The activation ``x`` is normalised per channel with statistics taken over
    the batch and spatial axes, then modulated element-wise by a scale
    (gamma) and a shift (beta) that are predicted from the conditioning
    tensor ``seg`` by small spectral-normalised convolutions::

        out = gamma(seg) * (x - mean) / sqrt(var + eps) + beta(seg)

    Args:
        args: object exposing ``spade_filter`` (hidden width of the
            conditioning branch) and ``spade_kernel`` (its kernel size).
        k: number of channels of ``x``; gamma and beta get the same count.
    """

    def __init__(self, args, k):
        super().__init__()
        num_filters = args.spade_filter
        kernel_size = args.spade_kernel
        # "Same" padding for odd kernels, so gamma/beta keep seg's spatial size
        # (identical to the previous padding=1 for the 3x3 kernels used here).
        padding = kernel_size // 2
        self.conv = spectral_norm(Conv2d(inchannel, num_filters, kernel_size=(kernel_size, kernel_size), padding=padding))
        self.conv_gamma = spectral_norm(Conv2d(num_filters, k, kernel_size=(kernel_size, kernel_size), padding=padding))
        self.conv_beta = spectral_norm(Conv2d(num_filters, k, kernel_size=(kernel_size, kernel_size), padding=padding))
        # Keeps the division finite for constant channels.
        self.eps = 1e-5

    def forward(self, x, seg):
        """
        Args:
            x: activations of shape (N, k, H, W).
            seg: conditioning tensor of shape (N, inchannel, H', W'); it is
                resized with nearest-neighbour interpolation when (H', W')
                differs from (H, W).

        Returns:
            Tensor of shape (N, k, H, W).
        """
        # Parameter-free normalisation: one mean/variance per channel, computed
        # over the batch and both spatial axes.
        mean = x.mean(dim=(0, 2, 3), keepdim=True)
        centered = x - mean
        var = centered.pow(2).mean(dim=(0, 2, 3), keepdim=True)
        x = centered / torch.sqrt(var + self.eps)

        # gamma/beta must line up pixel-for-pixel with x.
        if seg.shape[-2:] != x.shape[-2:]:
            seg = interpolate(seg, size=x.shape[-2:], mode='nearest')

        seg_copy = relu(self.conv(seg))
        seg_gamma = self.conv_gamma(seg_copy)
        seg_beta = self.conv_beta(seg_copy)

        # Element-wise modulation. A matrix product would mix pixels across
        # columns and needs zero-padding that only exists when W >= H.
        return seg_gamma * x + seg_beta


class SPADEResBlk(Module):
    """Residual block made of two SPADE -> ReLU -> conv stages.

    Args:
        args: object exposing ``spade_resblk_kernel`` (conv kernel size) plus
            whatever ``SPADE`` reads from it.
        k: number of output channels.
        skip: when False the block maps k -> k channels and the input itself
            is the shortcut. When True the block expects 2*k input channels
            and the shortcut goes through its own SPADE -> ReLU -> conv stage
            to reach k channels.
    """

    def __init__(self, args, k, skip=False):
        super().__init__()
        ks = (args.spade_resblk_kernel, args.spade_resblk_kernel)
        self.skip = skip

        # Width of the incoming activation: doubled for the learned shortcut.
        width_in = 2 * k if skip else k

        self.spade1 = SPADE(args, width_in)
        self.conv1 = Conv2d(width_in, k, kernel_size=ks, padding=1, bias=False)
        if skip:
            self.spade_skip = SPADE(args, width_in)
            self.conv_skip = Conv2d(width_in, k, kernel_size=ks, padding=1, bias=False)

        self.spade2 = SPADE(args, k)
        self.conv2 = Conv2d(k, k, kernel_size=ks, padding=1, bias=False)

    def forward(self, x, seg):
        """Return ``shortcut(x) + residual(x)``, both conditioned on ``seg``."""
        # Main branch: two normalise / activate / convolve stages.
        residual = self.conv1(relu(self.spade1(x, seg)))
        residual = self.conv2(relu(self.spade2(residual, seg)))

        # Shortcut branch: identity unless a learned projection was requested.
        shortcut = x
        if self.skip:
            shortcut = self.conv_skip(relu(self.spade_skip(x, seg)))

        return shortcut + residual



class Generator(nn.Module):
    """Coarse-to-fine generator built from four SPADE residual blocks.

    The conditioning tensor is downsampled four times by a factor of 0.5. A
    convolution on the coarsest copy produces the initial 128-channel
    activation, which then passes through four (residual block, transposed
    convolution) stages. Each stage is conditioned on the copy of the
    conditioning tensor at its own resolution and doubles the spatial size.
    A final spectral-normalised convolution followed by tanh produces a
    3-channel image.

    Args:
        args: hyper-parameter object (see ``Args``).
        inchannels: channel count of the conditioning tensor.
    """

    def __init__(self, args, inchannels): #---------_Remember inchannels
        super().__init__()
        # NOTE(review): `linear` is created but never called in forward(); its
        # parameters are still registered on the module.
        self.linear = nn.Linear(args.gen_input_size, args.gen_hidden_size)
        self.fc = nn.Conv2d(inchannels, 128, kernel_size=(3, 3), padding=1)

        self.spade_resblk1 = SPADEResBlk(args, 128)
        self.spade_resblk2 = SPADEResBlk(args, 128)
        self.spade_resblk3 = SPADEResBlk(args, 64, skip=True)
        self.spade_resblk4 = SPADEResBlk(args, 32, skip=True)

        self.conv = nn.utils.spectral_norm(nn.Conv2d(32, 3, kernel_size=(3, 3), padding=1))

        # Kernel 2 / stride 2 transposed convolutions: each doubles H and W.
        self.convTrans128 = nn.ConvTranspose2d(128, 128, kernel_size=2, stride=2)
        self.convTrans64 = nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2)
        self.convTrans32 = nn.ConvTranspose2d(32, 32, kernel_size=2, stride=2)

    def forward(self, seg):
        """Map a conditioning tensor (b, inchannels, H, W) to an image.

        The shapes quoted below assume a 480x800 input, as in the dataset
        comments -- TODO confirm.
        """
        # Pyramid of the conditioning tensor: 1/2, 1/4, 1/8 and 1/16 scale.
        pyramid = []
        level = seg
        for _ in range(4):
            level = F.interpolate(level, scale_factor=0.5, mode="bicubic")
            pyramid.append(level)
        m4, m3, m2, m1 = pyramid  # m1 is the coarsest (e.g. 30x50)

        x = self.fc(m1)  # b, 128, 30, 50

        # NOTE(review): convTrans128 is used by the first two stages, so both
        # share one set of upsampling weights.
        stages = (
            (self.spade_resblk1, m1, self.convTrans128),  # -> b, 128, 60, 100
            (self.spade_resblk2, m2, self.convTrans128),  # -> b, 128, 120, 200
            (self.spade_resblk3, m3, self.convTrans64),   # -> b, 64, 240, 400
            (self.spade_resblk4, m4, self.convTrans32),   # -> b, 32, 480, 800
        )
        for block, cond, upsample in stages:
            x = upsample(block(x, cond))

        return F.tanh(self.conv(x))  # b, 3, 480, 800

In [36]:


def custom_model1(in_chan, out_chan):
    """Build a downsampling stage: spectral-normalised 4x4 conv (stride 2) + LeakyReLU.

    Returns an ``nn.Sequential`` of two modules. With padding 1 the spatial
    size is halved for even inputs.
    """
    conv = nn.Conv2d(in_chan, out_chan, kernel_size=(4, 4), stride=2, padding=1)
    stage = [spectral_norm(conv), nn.LeakyReLU(inplace=False)]
    return nn.Sequential(*stage)

def custom_model2(in_chan, out_chan, stride=2):
    """Build a stage: spectral-normalised 4x4 conv -> InstanceNorm2d -> LeakyReLU.

    ``stride`` is forwarded to the convolution (default 2). Returns an
    ``nn.Sequential`` of three modules.
    """
    conv = nn.Conv2d(in_chan, out_chan, kernel_size=(4, 4), stride=stride, padding=1)
    stage = [
        spectral_norm(conv),
        nn.InstanceNorm2d(out_chan),
        nn.LeakyReLU(inplace=False),
    ]
    return nn.Sequential(*stage)

class SPADEDiscriminator(nn.Module):
    """Convolutional discriminator scoring an (image, condition) pair.

    Args:
        in_channels: channel count of ``torch.cat((seg, img), dim=1)``, i.e.
            condition channels plus image channels. The default of 7 keeps the
            previous hard-coded width. A 7-channel condition combined with a
            3-channel image needs ``in_channels=10``.
    """

    def __init__(self, in_channels=7):
        super().__init__()
        self.layer1 = custom_model1(in_channels, 64)
        self.layer2 = custom_model2(64, 128)
        self.layer3 = custom_model2(128, 256)
        self.layer4 = custom_model2(256, 512, stride=1)
        self.inst_norm = nn.InstanceNorm2d(512)
        self.conv = spectral_norm(nn.Conv2d(512, 1, kernel_size=(4,4), padding=1))

    def forward(self, img, seg):
        """Return a (N, 1, h, w) map of unnormalised scores.

        NOTE(review): ``img`` is detached here, so no gradient reaches the
        network that produced it through this discriminator. That is what a
        discriminator update needs, but it also disables the adversarial term
        of a generator update -- confirm this is intended before using this
        class for the generator loss.
        """
        x = torch.cat((seg, img.detach()), dim=1)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # layer4 already ends with InstanceNorm + LeakyReLU; this second pass
        # is kept as it was.
        x = leaky_relu(self.inst_norm(x), inplace =False)
        x = self.conv(x)
        return x




class Discriminator(nn.Module):
    """Four-stage strided-conv discriminator over the concatenated (image, condition).

    Args:
        inchannel: channel count of ``torch.cat((img, seg), dim=1)``; for a
            3-channel image and a 7-channel condition this is 10.
    """

    def __init__(self, inchannel):
        super().__init__()
        self.inchannel = inchannel

        # Each stage halves the spatial size (4x4 conv, stride 2, padding 1).
        self.layer1 = custom_model1(inchannel, 32)
        self.layer2 = custom_model2(32, 64)
        self.layer3 = custom_model2(64, 128)
        self.layer4 = custom_model2(128, 256)

        self.inst_norm = nn.InstanceNorm2d(256)
        # Unpadded 4x4 conv that reduces the features to one score channel.
        self.conv = nn.utils.spectral_norm(nn.Conv2d(256, 1, kernel_size=(4, 4)))

    def forward(self, img, seg):
        """Return unnormalised scores for the (img, seg) pair.

        NOTE(review): the final ``squeeze()`` drops every size-1 axis, so the
        batch axis also disappears when the batch holds a single sample.
        """
        features = torch.cat((img, seg), dim=1)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        features = F.leaky_relu(self.inst_norm(features))
        scores = self.conv(features)
        return scores.squeeze()

In [37]:
class DownSampleConv(nn.Module):
    """Conv2d followed by optional BatchNorm2d and optional LeakyReLU(0.2).

    With the default kernel 4 / stride 2 / padding 1 the spatial size is
    halved for even inputs.

    Author's note on the reference paper:
    - C64-C128-C256-C512-C512-C512-C512-C512
    - All convolutions are 4x4 spatial filters applied with stride 2
    - Convolutions in the encoder downsample by a factor of 2

    Args:
        in_channels, out_channels: channel counts of the convolution.
        kernel, strides, padding: forwarded positionally to ``nn.Conv2d``.
        activation: append LeakyReLU(0.2) when True.
        batchnorm: insert BatchNorm2d after the convolution when True.
    """

    def __init__(self, in_channels, out_channels, kernel=4, strides=2, padding=1, activation=True, batchnorm=True):
        super().__init__()
        self.activation, self.batchnorm = activation, batchnorm

        self.conv = nn.Conv2d(in_channels, out_channels, kernel, strides, padding)

        # The optional sub-modules are only created (and registered) on demand.
        if batchnorm:
            self.bn = nn.BatchNorm2d(out_channels)
        if activation:
            self.act = nn.LeakyReLU(0.2)

    def forward(self, x):
        """Apply conv, then batch norm and activation when they are enabled."""
        out = self.conv(x)
        out = self.bn(out) if self.batchnorm else out
        return self.act(out) if self.activation else out


class PatchGAN(nn.Module):
    """Patch-level discriminator: four DownSampleConv stages and a 1x1 conv.

    Args:
        input_channels: channel count of the two inputs after they are
            concatenated along the channel axis.
    """

    def __init__(self, input_channels):
        super().__init__()
        # The first stage is created without batch norm.
        self.d1 = DownSampleConv(input_channels, 32, batchnorm=False)
        self.d2 = DownSampleConv(32, 64)
        self.d3 = DownSampleConv(64, 64)
        self.d4 = DownSampleConv(64, 64)
        # 1x1 conv that turns the 64 features of every patch into one score.
        self.final = nn.Conv2d(64, 1, kernel_size=1)

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` on the channel axis and score the result."""
        features = torch.cat((x, y), dim=1)
        for stage in (self.d1, self.d2, self.d3, self.d4):
            features = stage(features)
        return self.final(features)


In [38]:


class GANLoss(nn.Module):
    """Adversarial loss against a constant "real" or "fake" target.

    Args:
        use_lsgan: when True the criterion is ``nn.L1Loss``; otherwise it is
            ``nn.BCEWithLogitsLoss`` (which expects raw, un-sigmoided scores).
            NOTE(review): least-squares GAN objectives normally use a squared
            error; L1 is kept because the call site labels it "L1Loss" --
            confirm this is intended.
        target_real_label: value filling the target for real samples.
        target_fake_label: value filling the target for fake samples.
        tensor: kept for backward compatibility and stored as ``self.Tensor``.
            Targets are now created directly on the device and with the dtype
            of the prediction, so this type is not used to build them.
    """

    def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
                 tensor=torch.FloatTensor):
        super().__init__()
        self.real_label = target_real_label
        self.fake_label = target_fake_label
        # Cached targets, rebuilt whenever the prediction layout changes.
        self.real_label_var = None
        self.fake_label_var = None
        self.Tensor = tensor
        if use_lsgan:
            self.loss = nn.L1Loss()
        else:
            self.loss = nn.BCEWithLogitsLoss()

    def get_target_tensor(self, input, target_is_real):
        """Return a constant target matching ``input`` in shape, dtype and device.

        The tensor is cached per label. The cache is keyed on the full shape,
        because two predictions with the same number of elements but different
        shapes need different targets.
        """
        if target_is_real:
            cache_name, fill_value = 'real_label_var', self.real_label
        else:
            cache_name, fill_value = 'fake_label_var', self.fake_label

        cached = getattr(self, cache_name)
        stale = (
            cached is None
            or cached.shape != input.shape
            or cached.device != input.device
            or cached.dtype != input.dtype
        )
        if stale:
            # full_like allocates on input's device, so no global `device` and
            # no host-to-device copy per call is needed. The result never
            # requires grad.
            cached = torch.full_like(input, fill_value, requires_grad=False)
            setattr(self, cache_name, cached)
        return cached

    def forward(self, input, target_is_real):
        """Compute the loss of ``input`` against the real or the fake target."""
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor)


In [39]:
def display_progress(cond, fake, real, figsize=(10,5)):
    """Show condition, real image and generated image side by side.

    Each argument is a single (C, H, W) tensor. Panels from left to right:
    the first three channels of ``cond``, then ``real``, then ``fake``.
    """
    def to_hwc(image):
        # matplotlib wants channels last, on the CPU and outside autograd.
        return image.detach().cpu().permute(1, 2, 0)

    fig, ax = plt.subplots(1, 3, figsize=figsize)
    panels = (
        (ax[0], to_hwc(cond)[:, :, 0:3]),
        (ax[2], to_hwc(fake)),
        (ax[1], to_hwc(real)),
    )
    for axis, image in panels:
        axis.imshow(image)
    plt.show()

In [40]:


# Prefer the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.backends.cudnn.benchmark = True

# Networks: built, moved to the chosen device, then re-initialised.
gen = Generator(args, inchannels=CONFIG['IN_CHANNEL_GEN']).to(device)
dis = Discriminator(CONFIG['IN_CHANNEL_DIS']).to(device)
#dis = SPADEDiscriminator()

for network in (gen, dis):
    network.apply(weights_init)

# Losses: L1 against constant targets for the generator's adversarial term,
# BCE-with-logits for the discriminator, plain L1 for reconstruction.
criterionG = GANLoss()
criterionD = GANLoss(use_lsgan=False)
recon_loss = nn.L1Loss()

# One Adam optimizer per network, each with its own learning rate.
optim_gen = torch.optim.Adam(gen.parameters(), lr=CONFIG['Lr_GEN'])
optim_dis = torch.optim.Adam(dis.parameters(), lr=CONFIG['Lr_DIS'])


In [41]:
# dataset = GrassDataset(path)
# dataloader = DataLoader(dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True)
# print(len(dataloader))
# for i, (condition, real) in enumerate(dataloader):
#   print(condition)
#   print(real)
#   break

In [42]:
from torchvision.utils import make_grid, save_image

In [None]:
import time

# Training history, kept at module level so later cells can inspect it.
img_lists = []
G_losses = []
D_losses = []
epochs = 1   # not read anywhere in this cell
iters = 0    # not read anywhere in this cell
recon_var = 100  # weight of the L1 reconstruction term in the generator loss

path = CONFIG['PATH_DATA']
dataset = GrassDataset(path)
# drop_last=True discards the trailing partial batch, so every batch has
# exactly BATCH_SIZE samples and no manual size check is needed in the loop.
dataloader = DataLoader(dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True, drop_last=True)
path_save = CONFIG['RESULT_PATH']
# save_image does not create missing directories.
os.makedirs(path_save, exist_ok=True)

for epoch in tqdm(range(CONFIG['MAX_EPOCH'])):
  start = time.time()

  print(f'Epoch: {epoch+1}')
  for i, (condition, real) in enumerate(dataloader):
    con = condition.to(device)
    real = real.to(device)

    fake_img = gen(con)

    # Discriminator: updated on every second iteration only.
    if i%2==0:
      pred_real = dis(real, con)
      loss_D_real = criterionD(pred_real, True)

      # The fake is detached so this step cannot reach the generator.
      pred_fake = dis(fake_img.detach(), con)
      loss_D_fake = criterionD(pred_fake, False)

      optim_dis.zero_grad()
      # Average of both terms; without the parentheses only the real term
      # would be halved.
      loss_D = (loss_D_fake + loss_D_real) * 0.5
      # This graph is not reused (the generator step runs its own forward pass
      # through `dis`), so it does not need to be retained.
      loss_D.backward()
      optim_dis.step()

    # Generator: adversarial term plus weighted L1 reconstruction.
    optim_gen.zero_grad()
    pred_fake = dis(fake_img, con)
    loss_G_d = criterionG(pred_fake, True)
    loss_G_recon = recon_loss(fake_img, real)
    loss_G = loss_G_d + (recon_var*loss_G_recon)
    loss_G.backward()

    optim_gen.step()

    G_losses.append(loss_G.detach().cpu())
    # On odd iterations this repeats the value of the latest discriminator step.
    D_losses.append(loss_D.detach().cpu())

    if i%20 == 0:
      print("Iteration {}/{} started".format(i+1, len(dataloader)))
      # Preview only: no autograd graph is needed for this forward pass.
      with torch.no_grad():
        fake_img = gen(con)
      display_progress(con[0,:,:,:],fake_img[0,:,:,:], real[0,:,:,:] )
      print(f'loss_D: {loss_D.detach()} loss_g: {loss_G.detach()}')
      print(time.time()-start)
      start = time.time()

      path_e = os.path.join(path_save, str(epoch) + "_" + str(i) + ".png")
      save_image(fake_img[0], path_e)
  # Keep the last generated batch of every fifth epoch.
  if epoch%5 == 0:
    img_lists.append(fake_img.detach().cpu().numpy())


Output hidden; open in https://colab.research.google.com to view.

In [None]:
#torch.cuda.empty_cache()
import torch

# Allocator statistics for the current CUDA device, in the full (non-abbreviated) layout.
torch.cuda.memory_summary(abbreviated=False)

In [None]:
import torch
import torch.nn as nn

# Scratch check: a kernel-2 / stride-2 transposed convolution doubles H and W
# while the channel count stays at x.shape[1].
x = torch.randn(8, 7, 30, 50)
mat = nn.ConvTranspose2d(
    in_channels=x.shape[1],
    out_channels=x.shape[1],
    kernel_size=2,
    stride=2,
)
x = mat(x)
print(x.shape)

In [None]:
# Number of scalar values across all parameters registered on the generator.
total = sum(p.numel() for p in gen.parameters())
print(total)