In [151]:
import torch
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision.transforms as transforms
import numpy as np
from torch.utils.data.dataloader import DataLoader

In [None]:
!unzip /content/drive/MyDrive/Projects/Moret_colouring/data/gan-getting-started.zip

In [132]:
# Glob patterns for the extracted image folders, read with scikit-image.
# The cells below iterate over each collection image by image.
from skimage.io import imread_collection

monet_path, photo_path = "/content/monet_jpg/*jpg", "/content/photo_jpg/*jpg"
monet, photo = (imread_collection(pattern) for pattern in (monet_path, photo_path))

For ease of selecting data randomly from the collection, we store the image tensors in a Python `set()`. **TODO: add data augmentation here.**

In [157]:
# Convert every Monet image into a float tensor on GPU 0, reordered with
# permute(2, 0, 1) and scaled by 1/255. Assumes each img is (H, W, C), so the
# result is channels-first (C, H, W) — TODO confirm.
# Sets are used so get_item can pick images at random.
photo_set = set()
monet_set = set()
monet_set.update(torch.Tensor(pic).cuda(0).permute(2, 0, 1) / 255 for pic in monet)

In [158]:
# Convert every photo into a float tensor on GPU 0, reordered with
# permute(2, 0, 1) (assumes (H, W, C) input — TODO confirm) and scaled by 1/255.
for raw in photo:
    photo_set.add(torch.Tensor(raw).cuda(0).permute(2, 0, 1) / 255)

#### Get Item
This function takes an input set of image tensors and returns a batch of the desired size:  
it randomly samples `batch_size` images and stacks them into a single tensor of shape `(batch_size, 3, 256, 256)`.

In [154]:
from random import sample


def get_item(data: set, batch_size: int = 50, device="cuda:0") -> torch.Tensor:
    """Randomly draw ``batch_size`` distinct images from ``data`` and stack them.

    Args:
        data: collection (e.g. a set or list) of equally-shaped image tensors,
            such as the (3, 256, 256) tensors built by the loading cells.
        batch_size: number of images to draw, without replacement.
        device: device of the returned batch; defaults to GPU 0.

    Returns:
        A float32 tensor of shape (batch_size, *image_shape) on ``device``,
        detached from any autograd graph. For ``batch_size == 0`` an empty
        (0, 3, 256, 256) tensor is returned.

    Raises:
        ValueError: if ``batch_size`` is negative or larger than ``len(data)``.
    """
    # random.sample() no longer accepts sets (TypeError since Python 3.11),
    # so sample from a list snapshot of the collection instead.
    chosen = sample(list(data), k=batch_size)
    if not chosen:
        # torch.stack needs at least one tensor; keep the fixed-size empty batch.
        return torch.zeros((0, 3, 256, 256), device=device)
    # torch.stack allocates fresh storage, so no explicit clone() is needed;
    # detach() keeps the batch out of any autograd graph.
    return torch.stack([im.detach() for im in chosen]).to(device=device, dtype=torch.float32)

In [None]:
# Smoke-test the sampler on the Monet set. Display a compact summary (expected
# shape (50, 3, 256, 256)) instead of printing a whole image tensor.
batch = get_item(monet_set)
batch.shape, batch.min().item(), batch.max().item()

## Generator


In [177]:
class Generator(nn.Module):
    """Encoder-decoder generator: image batch + noise map -> image batch.

    The encoder (conv1, conv2, conv3, bottleneck) halves the spatial size four
    times and the decoder (deconv4 .. deconv1) doubles it four times, so the
    output has the same H and W as the input when both are multiples of 16.

    Args:
        z_channels: channel count of the noise map ``z`` passed to ``forward``.
            Defaults to 2, giving deconv4 32 + 2 = 34 input channels.
    """

    def __init__(self, z_channels: int = 2):
        super().__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(8),
            nn.LeakyReLU(),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(8, 16, kernel_size=5, padding=2, stride=2),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3, padding=1, stride=2),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(),
        )
        self.bottleneck = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=5, padding=2, stride=2),
            nn.LeakyReLU(),
        )
        # The noise map is concatenated onto the 32 bottleneck channels.
        self.deconv4 = nn.Sequential(
            nn.ConvTranspose2d(32 + z_channels, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )
        self.deconv3 = nn.Sequential(
            nn.ConvTranspose2d(32, 16, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )
        self.deconv2 = nn.Sequential(
            nn.ConvTranspose2d(16, 8, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
        )
        self.deconv1 = nn.Sequential(
            nn.ConvTranspose2d(8, 3, kernel_size=4, stride=2, padding=1),
            # NOTE(review): ReLU leaves the output unbounded above, while the
            # training tensors are scaled to [0, 1] by /255; a Sigmoid may be
            # intended here — confirm before changing.
            nn.ReLU(),
        )

    def forward(self, xb: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
        """Generate images from ``xb`` (N, 3, H, W) and noise ``z``.

        ``z`` must have ``z_channels`` channels and the bottleneck's spatial
        size, (N, z_channels, H/16, W/16) — e.g. (N, 2, 16, 16) for 256x256
        inputs. Returns (N, 3, H, W) when H and W are multiples of 16.
        """
        encoded = self.bottleneck(self.conv3(self.conv2(self.conv1(xb))))
        # Noise goes first on the channel axis.
        decoded = torch.cat((z, encoded), 1)
        for stage in (self.deconv4, self.deconv3, self.deconv2, self.deconv1):
            decoded = stage(decoded)
        return decoded


## Discriminator

In [190]:
class Discriminator(nn.Module):
    """Scores 3-channel images with one value per image.

    Three 4x4/stride-4 convolutions shrink the image to a small grid (4x4 for
    256x256 inputs); 1x1 convolutions map each grid cell to one sigmoid score,
    and a linear layer combines the flattened cell scores.

    Args:
        image_size: side length of the square input images. Defaults to 256.
            It sets the linear layer's input width (16 for 256).
    """

    def __init__(self, image_size: int = 256):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=4, stride=4),
            nn.ReLU(),
            nn.Conv2d(16, 64, kernel_size=4, stride=4),
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=4),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(64, 16, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(16, 4, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(4, 1, kernel_size=1),
            nn.Sigmoid(),
            nn.Flatten(),
        )
        # Grid side after the three unpadded kernel-4/stride-4 convolutions.
        side = image_size
        for _ in range(3):
            side = (side - 4) // 4 + 1
        self.linear = nn.Linear(side * side, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch ``x`` of shape (N, 3, image_size, image_size); returns (N, 1).

        NOTE(review): the Linear layer follows the Sigmoid, so the result is an
        unbounded score rather than a probability. Pair it with a logits-based
        loss (e.g. BCEWithLogitsLoss), or apply the sigmoid after ``linear``.
        """
        return self.linear(self.conv(x))