In [None]:
# Mount Google Drive (Colab only) so the dataset folders under MyDrive become readable.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import torchvision.transforms.functional as F
import gc

from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor

import numpy as np
import matplotlib.pyplot as plt

In [None]:
ngpu = 1  # number of GPUs to use; 0 forces the CPU
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# Fix the global torch RNG so weight initialisation and DataLoader shuffling are
# reproducible across Restart & Run All (cuDNN kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)

In [None]:
# Dataset locations on the mounted Drive: one folder per image domain.
DATA_ROOT = "/content/drive/MyDrive/Research/Rembrandt/"
paint_drl = DATA_ROOT + "resized_paintings/"  # paintings domain
photo_drl = DATA_ROOT + "resized_photos/"     # photographs domain


In [None]:
import os
from torchvision.io import ImageReadMode, read_image

class ImageDataset(Dataset):
    """Unlabelled dataset over the image files of a single directory.

    Args:
        img_dir: directory whose files are the samples (not searched recursively).
        transform: optional callable applied to each loaded image tensor.
        target_transform: accepted for API compatibility; not used (items carry no label).
        normalize: if True (default), rescale pixel values from [0, 255] to [-1, 1],
            the same range as the generator's Tanh output.

    Each item is a float tensor with 3 channels (files are decoded as RGB), placed
    on the module-level ``device`` before ``transform`` is applied.
    """

    def __init__(self, img_dir, transform=None, target_transform=None, normalize=True):
        self.dir = img_dir
        self.transform = transform
        self.target_transform = target_transform
        self.normalize = normalize
        # Sorted so that index -> file is deterministic (os.listdir order is arbitrary).
        # Hidden entries (e.g. .ipynb_checkpoints) and sub-directories are skipped since
        # they are not decodable images.
        self.img_labels = sorted(
            name for name in os.listdir(self.dir)
            if not name.startswith('.') and os.path.isfile(os.path.join(self.dir, name))
        )

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.dir, self.img_labels[idx])
        # Force 3 channels: grayscale or RGBA files would not fit the generator's
        # 3-channel input convolution.
        image = read_image(img_path, mode=ImageReadMode.RGB).float()
        if self.normalize:
            # Without this, real images (0..255) and generated ones (-1..1) live on
            # different scales, which breaks the cycle L1 loss and the discriminators.
            image = image / 127.5 - 1.0
        image = image.to(device)
        if self.transform is not None:
            image = self.transform(image)
        return image

In [None]:
def resize(img_tensor, multiple=4):
  """Shrink an image so that its height and width are multiples of ``multiple``.

  The generator halves the spatial size twice and doubles it twice, so only sizes
  divisible by 4 come back out with the input's shape (needed by the cycle L1 loss).

  img_tensor: tensor whose last two dims are (H, W), e.g. (C, H, W) or (N, C, H, W).
  multiple: required divisor of the output height and width (default 4).
  Returns the tensor resized to (H // multiple * multiple, W // multiple * multiple);
  the input is returned unchanged when it already has that size.
  Raises ValueError if H or W is smaller than ``multiple``.
  """
  height, width = img_tensor.shape[-2:]
  new_height = height // multiple * multiple
  new_width = width // multiple * multiple
  if new_height == 0 or new_width == 0:
    raise ValueError(
        f"image of size {height}x{width} is smaller than {multiple} pixels in some dimension")
  if (new_height, new_width) == (height, width):
    return img_tensor
  return F.resize(img_tensor, [new_height, new_width])

In [None]:
# Reuse the path constants defined with the dataset locations instead of repeating the literals.
paint = ImageDataset(paint_drl, resize)
photo = ImageDataset(photo_drl, resize)

In [None]:
def collect(batch):
  """Collate function that keeps samples un-stacked, returning them as a new plain list."""
  return list(batch)

In [None]:
# Each image keeps its own (multiple-of-4) size after `resize`, so the default collate
# function cannot stack more than one image per batch: the loaders yield one image per step.
BATCH_SIZE = 1
paintLoader = torch.utils.data.DataLoader(paint, batch_size=BATCH_SIZE, shuffle=True)
photoLoader = torch.utils.data.DataLoader(photo, batch_size=BATCH_SIZE, shuffle=True)

In [None]:
# Network structures

class ResBlock(nn.Module):
  """Residual block: conv -> ReLU -> conv -> ReLU, added to a skip connection.

  Args:
    in_channel: channels of the input tensor.
    hidden: channels produced by the first convolution.
    out_channel: channels of the output tensor.
    filter_size: square kernel size of both convolutions.
    stride: stride of the first convolution (and of the skip projection, if any).
    padding: zero padding of both convolutions; with filter_size=3, padding=1 and
      stride=1 the spatial size is preserved.
    ngpu: stored on the module; not used by forward().

  The skip path is the identity when in_channel == out_channel and stride == 1
  (no extra parameters); otherwise a 1x1 convolution projects the input so the
  two branches can be added.
  """
  def __init__(self, in_channel, hidden, out_channel, filter_size, stride = 1, padding = 1, ngpu = 1):
    super().__init__()
    self.ngpu = ngpu

    self.in_channel = in_channel
    self.hidden = hidden
    self.out_channel = out_channel

    self.filter_size = filter_size
    self.stride = stride
    self.padding = padding

    self.conv = nn.Sequential(
        nn.Conv2d(self.in_channel, self.hidden, self.filter_size, stride = self.stride, padding = self.padding),
        nn.ReLU(),
        nn.Conv2d(self.hidden, self.out_channel, self.filter_size, padding = self.padding),
        nn.ReLU()
    )

    # Without a projection, `conv(x) + x` fails whenever the branch changes the
    # channel count or spatial size.
    if self.in_channel != self.out_channel or self.stride != 1:
      self.shortcut = nn.Conv2d(self.in_channel, self.out_channel, 1, stride = self.stride)
    else:
      self.shortcut = nn.Identity()

  def forward(self, x):
    return self.conv(x) + self.shortcut(x)



class Generator(nn.Module):
  """ResNet-style CycleGAN generator mapping a 3-channel image to a 3-channel image.

  Layout: 7x7 conv (64) -> two stride-2 3x3 convs (128, 256) -> 9 ResBlocks (256)
  -> two stride-2 transposed convs (128, 64) -> 7x7 conv (3) -> Tanh, so outputs
  lie in [-1, 1]. Height and width are halved twice and doubled twice, so the
  output has the input's spatial size only when H and W are multiples of 4.

  ngpu: stored on the module; not used by forward().
  """
  def __init__(self, ngpu = 1):
    super().__init__()
    self.ngpu = ngpu
    self.main = nn.Sequential(
        nn.ReflectionPad2d(3),
        nn.Conv2d(3, 64, 7, 1),
        nn.InstanceNorm2d(64),
        nn.ReLU(True),

        # Downsampling. Each conv gets InstanceNorm + ReLU like every other stage;
        # without them the two stacked convs compose into a single affine map.
        nn.Conv2d(64, 128, 3, 2, 1),
        nn.InstanceNorm2d(128),
        nn.ReLU(True),
        nn.Conv2d(128, 256, 3, 2, 1),
        nn.InstanceNorm2d(256),
        nn.ReLU(True),

        *[ResBlock(256, 256, 256, 3) for _ in range(9)],

        nn.ConvTranspose2d(256, 128, 3, 2, 1, 1),
        nn.InstanceNorm2d(128),
        nn.ReLU(True),

        nn.ConvTranspose2d(128, 64, 3, 2, 1, 1),
        nn.InstanceNorm2d(64),
        nn.ReLU(True),

        nn.ReflectionPad2d(3),
        nn.Conv2d(64, 3, 7, 1),
        # No InstanceNorm before Tanh: normalising the 3 output channels (affine=False)
        # would force each one to zero mean / unit variance per image, so the generator
        # could not control the overall brightness or colour balance of its output.
        nn.Tanh()
    )

  def forward(self, x):
    return self.main(x)




class Discriminator(nn.Module):
  """Patch discriminator producing a grid of raw real/fake logits (no sigmoid layer).

  Four unpadded 4x4 conv stages with (out_channels, stride) of (64, 2), (128, 2),
  (256, 2), (512, 1), each followed by LeakyReLU(0.2); every stage except the first
  also has an InstanceNorm. A final unpadded 4x4 conv maps to a single channel.
  Because nothing is padded, every layer shrinks the spatial size.
  """
  def __init__(self, ngpu = 1):
    super().__init__()
    self.ngpu = ngpu

    stages = ((64, 2), (128, 2), (256, 2), (512, 1))
    layers = []
    channels = 3
    for index, (width, step) in enumerate(stages):
      layers.append(nn.Conv2d(channels, width, 4, step))
      if index > 0:
        layers.append(nn.InstanceNorm2d(width))
      layers.append(nn.LeakyReLU(0.2, True))
      channels = width
    layers.append(nn.Conv2d(channels, 1, 4, 1))

    self.main = nn.Sequential(*layers)

  def forward(self, x):
    return self.main(x)


In [None]:
# Initialisation: one generator per translation direction, one discriminator per domain.
# D_photoToPaint judges paintings (real ones vs. G_photoToPaint output);
# D_paintToPhoto judges photos (real ones vs. G_paintToPhoto output).
G_photoToPaint, G_paintToPhoto = (Generator(ngpu).to(device) for _ in range(2))
D_paintToPhoto, D_photoToPaint = (Discriminator(ngpu).to(device) for _ in range(2))



In [None]:
def weights_init(m):
    """Draw the weights of conv-like modules from N(0, 0.02) in place; leave others untouched.

    A module qualifies when its class name contains 'Conv' (e.g. Conv2d,
    ConvTranspose2d). Biases keep their default initialisation. Meant to be passed
    to ``nn.Module.apply``.
    """
    if 'Conv' not in type(m).__name__:
        return
    nn.init.normal_(m.weight.data, 0.0, 0.02)

In [None]:
# Re-draw every conv layer of all four networks with N(0, 0.02) weights.
# A loop instead of four copy-pasted calls; it also avoids echoing the last
# network's module repr as cell output.
for net in (G_photoToPaint, G_paintToPhoto, D_photoToPaint, D_paintToPhoto):
  net.apply(weights_init)

Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2))
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2))
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2))
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(1, 1))
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1))
  )
)

In [None]:
# Training hyper-parameters
NUM_EPOCH = 200
N_EPOCH_CONSTANT = 100                         # epochs trained at the full learning rate
N_EPOCH_DECAY = NUM_EPOCH - N_EPOCH_CONSTANT   # epochs over which the lr decays linearly
lr_G = 2e-4                                    # base learning rate for both optimisers
mu = 10.0                                      # weight of the cycle-consistency (L1) loss

# Initialise optimisers
from itertools import chain
optimiser_G = torch.optim.Adam(chain(G_photoToPaint.parameters(), G_paintToPhoto.parameters()), lr = lr_G)
optimiser_D = torch.optim.Adam(chain(D_paintToPhoto.parameters(), D_photoToPaint.parameters()), lr = lr_G)

def _linear_decay(epoch):
  """Multiplicative learning-rate factor for LambdaLR.

  LambdaLR sets lr = base_lr * factor(epoch), where `epoch` counts scheduler.step()
  calls (0 during the first training epoch, since the loop steps once per epoch).
  Returns 1.0 for the first N_EPOCH_CONSTANT epochs, then decreases linearly,
  reaching 1 / (N_EPOCH_DECAY + 1) in the last training epoch.
  """
  # The factor must be relative (1.0 = base lr). Returning lr_G itself would give an
  # effective lr of lr_G ** 2.
  return 1.0 - max(0, epoch + 1 - N_EPOCH_CONSTANT) / float(N_EPOCH_DECAY + 1)

lambda_G = _linear_decay
lambda_D = _linear_decay

scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimiser_G, lambda_G)
scheduler_D = torch.optim.lr_scheduler.LambdaLR(optimiser_D, lambda_D)


In [None]:
# Loss helpers (operate on tensors)
def criterion_binary(prediction, label):
  '''
  Binary cross-entropy between discriminator logits and a constant target.

  prediction: tensor of raw logits (any shape, e.g. a patch-discriminator output map).
  label: number, 1.0 for real and 0.0 for fake; ints are accepted and cast to float.
  Returns the mean BCE-with-logits loss as a scalar tensor.
  '''
  # full_like gives the target the prediction's dtype and device, so no global device is
  # needed, and an int label cannot produce an integer target (which BCE rejects).
  truth = torch.full_like(prediction, float(label))
  return nn.functional.binary_cross_entropy_with_logits(prediction, truth)

def criterion_L1(prediction, truth):
  '''
  Mean absolute error between two image tensors (normally of the same shape).

  prediction: generated (e.g. reconstructed) image tensor.
  truth: reference image tensor.
  Returns a scalar tensor: the mean of |prediction - truth| over all elements.
  '''
  return nn.functional.l1_loss(prediction, truth)

In [None]:
def setGrad(net, grad):
  """Enable (grad=True) or freeze (grad=False) gradient tracking for every parameter of `net`."""
  for p in net.parameters():
    p.requires_grad_(grad)


In [None]:
count = 0  # number of (paint, photo) pairs processed so far

for epoch in range(1, NUM_EPOCH+1):
  # zip stops at the shorter loader, so an epoch covers min(len(paint), len(photo)) pairs.
  for (paint_real, photo_real) in zip(paintLoader, photoLoader):
    # Translate both ways, then translate back (cycle reconstructions).
    fake_paint = G_photoToPaint(photo_real)         # photo -> paint
    fake_photo = G_paintToPhoto(paint_real)         # paint -> photo
    fake_paintToPhoto = G_paintToPhoto(fake_paint)  # photo -> paint -> photo
    fake_photoToPaint = G_photoToPaint(fake_photo)  # paint -> photo -> paint

    # ---- Generator step ----
    # Freeze D so this step only updates the generators. The fakes are NOT detached:
    # the adversarial gradient must flow through D back into the generators.
    setGrad(D_photoToPaint, False)
    setGrad(D_paintToPhoto, False)
    optimiser_G.zero_grad()

    decision_fakePaint = D_photoToPaint(fake_paint)  # evaluate G_photoToPaint
    decision_fakePhoto = D_paintToPhoto(fake_photo)  # evaluate G_paintToPhoto

    loss_G = (criterion_binary(decision_fakePaint, 1.0)
              + criterion_binary(decision_fakePhoto, 1.0)
              + mu * (criterion_L1(fake_paintToPhoto, photo_real)
                      + criterion_L1(fake_photoToPaint, paint_real)))
    loss_G.backward()
    optimiser_G.step()

    # ---- Discriminator step ----
    # Unfreeze D and re-run it on detached fakes, so D learns from them while the
    # D loss does not back-propagate into the generators.
    setGrad(D_photoToPaint, True)
    setGrad(D_paintToPhoto, True)
    optimiser_D.zero_grad()

    loss_D_A = 0.5 * (criterion_binary(D_photoToPaint(paint_real), 1.0)
                      + criterion_binary(D_photoToPaint(fake_paint.detach()), 0.0))
    loss_D_B = 0.5 * (criterion_binary(D_paintToPhoto(photo_real), 1.0)
                      + criterion_binary(D_paintToPhoto(fake_photo.detach()), 0.0))
    (loss_D_A + loss_D_B).backward()
    optimiser_D.step()

    count += 1

    # Kept from the original loop: release Python garbage and cached CUDA blocks (slow).
    gc.collect()
    torch.cuda.empty_cache()

  scheduler_G.step()
  scheduler_D.step()

  print("epoch", epoch, "done")


In [None]:
# Visualisation

In [None]:
# Translate one photo into a painting and display the result.
photo_batch = next(iter(photoLoader))
with torch.no_grad():  # inference only: no autograd graph, so the output can go to numpy
  fake = G_photoToPaint(photo_batch)

# The generator ends in Tanh (values in [-1, 1]); imshow wants floats in [0, 1]
# laid out as (H, W, C) in CPU memory.
fake_img = ((fake[0] + 1) / 2).clamp(0, 1).permute(1, 2, 0).cpu().numpy()

fig, ax = plt.subplots()
ax.axis("off")
ax.set_title("Fake Images")
ax.imshow(fake_img)
plt.show()

In [None]:
# NOTE(review): leftover debugging code kept as an inert string literal; nothing in it runs.
# If revived: `paint_batch[0].to(device)` discards its result (.to is not in-place), and
# the loop feeds paintings to G_photoToPaint, which the training loop feeds photos.
'''from torch._C import dtype
#from torch._C import float32
for (paint_batch, photo_batch) in zip(paintLoader, photoLoader):
  #F.to_tensor(paint_batch[0])
  paint_batch[0].to(device)
  print(paint_batch)
  G_photoToPaint(paint_batch)
  print("fine")
  #print(paint_batch)
  #print(photo_batch)
  break'''