In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt

**Data Loader**

In [None]:
# Data-loading configuration: crop size (pixels), images per batch, and the
# ImageFolder root directory (empty until filled in by the user).
img_size, batch_size, train_dir = 256, 16, ''

In [None]:
# Build the training DataLoader: every image is center-cropped to img_size x img_size
# and converted to a float tensor by ToTensor.
# NOTE(review): train() below reads cartoon_loader / smoothEdge_loader / real_loader,
# not this train_loader — confirm which loaders are meant to be built here.
assert train_dir, 'set train_dir to the root folder of an ImageFolder-style dataset'
transform = transforms.Compose([transforms.CenterCrop(img_size), transforms.ToTensor()])
train_data = torchvision.datasets.ImageFolder(train_dir, transform=transform)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)

**Generator**

In [None]:
class resBlock(nn.Module):
  """Residual block: conv3x3 -> BN -> ReLU -> conv3x3 -> BN, added to the input.

  Both convolutions use stride 1 and padding 1, so the output has exactly the
  same shape as the input.

  Args:
    channels: number of input (and output) channels. Defaults to 256, the
      channel width of the Generator's bottleneck.
  """
  def __init__(self, channels=256):
    super(resBlock, self).__init__()
    self.conv_1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
    self.conv_2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
    self.norm_1 = nn.BatchNorm2d(channels)
    self.norm_2 = nn.BatchNorm2d(channels)

  def forward(self, x):
    residual = F.relu(self.norm_1(self.conv_1(x)))
    # the second conv must be applied to the activation, not passed as a module
    residual = self.norm_2(self.conv_2(residual))
    return x + residual

In [None]:
class Generator(nn.Module):
    """CartoonGAN-style image-to-image generator.

    Layout:
      * stem: 7x7 conv (3 -> 64) -> BN -> ReLU
      * two down-sampling stages (64 -> 128 -> 256); each is a stride-2 conv,
        a stride-1 conv, BN and ReLU
      * eight resBlock modules at 256 channels
      * two up-sampling stages (256 -> 128 -> 64); each is a stride-2
        transposed conv with output_padding=1, a stride-1 transposed conv,
        BN and ReLU
      * head: 7x7 conv (64 -> 3)

    The head's output is returned as-is (no sigmoid/tanh), so values are unbounded.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # stem
        self.conv_1 = nn.Conv2d(3, 64, kernel_size=7, stride=1, padding=3)
        self.norm_1 = nn.BatchNorm2d(64)

        # down-sampling stage 1
        self.conv_2 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1)
        self.conv_3 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.norm_2 = nn.BatchNorm2d(128)

        # down-sampling stage 2
        self.conv_4 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
        self.conv_5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.norm_3 = nn.BatchNorm2d(256)

        # bottleneck
        self.resNetwork = nn.Sequential(*[resBlock() for _ in range(8)])

        # up-sampling stage 1 (k=3, s=2, p=1, output_padding=1 doubles H and W)
        self.conv_6 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv_7 = nn.ConvTranspose2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.norm_4 = nn.BatchNorm2d(128)

        # up-sampling stage 2
        self.conv_8 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv_9 = nn.ConvTranspose2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.norm_5 = nn.BatchNorm2d(64)

        # head
        self.conv_10 = nn.Conv2d(64, 3, kernel_size=7, stride=1, padding=3)

    def forward(self, x):
        stem = F.relu(self.norm_1(self.conv_1(x)))
        down_1 = F.relu(self.norm_2(self.conv_3(self.conv_2(stem))))
        down_2 = F.relu(self.norm_3(self.conv_5(self.conv_4(down_1))))
        bottleneck = self.resNetwork(down_2)
        up_1 = F.relu(self.norm_4(self.conv_7(self.conv_6(bottleneck))))
        up_2 = F.relu(self.norm_5(self.conv_9(self.conv_8(up_1))))
        return self.conv_10(up_2)

**Discriminator**

In [None]:
class Discriminator(nn.Module):
    """Convolutional discriminator producing a 1-channel spatial map of raw logits.

    Every convolution is 3x3 with no padding, so each one trims the spatial size;
    conv_2 and conv_4 use stride 2. LeakyReLU (default slope) follows every layer
    except the last. No sigmoid is applied to the output.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # input stage
        self.conv_1 = nn.Conv2d(3, 32, kernel_size=3, stride=1)

        # first down-sampling stage: 32 -> 64 (stride 2) -> 128
        self.conv_2 = nn.Conv2d(32, 64, kernel_size=3, stride=2)
        self.conv_3 = nn.Conv2d(64, 128, kernel_size=3, stride=1)
        self.norm_1 = nn.BatchNorm2d(128)

        # second down-sampling stage: 128 -> 128 (stride 2) -> 256
        self.conv_4 = nn.Conv2d(128, 128, kernel_size=3, stride=2)
        self.conv_5 = nn.Conv2d(128, 256, kernel_size=3, stride=1)
        self.norm_2 = nn.BatchNorm2d(256)

        # feature refinement
        self.conv_6 = nn.Conv2d(256, 256, kernel_size=3, stride=1)
        self.norm_3 = nn.BatchNorm2d(256)

        # logit head
        self.conv_7 = nn.Conv2d(256, 1, kernel_size=3, stride=1)

    def forward(self, x):
        act = F.leaky_relu
        h = act(self.conv_1(x))
        h = act(self.conv_2(h))
        h = act(self.norm_1(self.conv_3(h)))
        h = act(self.conv_4(h))
        h = act(self.norm_2(self.conv_5(h)))
        h = act(self.norm_3(self.conv_6(h)))
        return self.conv_7(h)

**VGG16**

In [None]:
# Fetch the ImageNet-pretrained VGG16 via torchvision and write its state_dict to
# Google Drive, so later sessions can rebuild it with the load cell below.
from torchvision.models import vgg16

vgg16_path = '/content/drive/MyDrive/Colab Notebooks/APS360/Project/VGG16/vgg16_trained.pth'
vgg16_model = vgg16(pretrained=True)
torch.save(obj=vgg16_model.state_dict(), f=vgg16_path)

In [None]:
# Rebuild VGG16 from the state_dict previously saved to Google Drive:
# read the weights, build an uninitialised VGG16 architecture, then copy them in.
from torchvision.models import vgg16

vgg16_path = '/content/drive/MyDrive/Colab Notebooks/APS360/Project/VGG16/vgg16_trained.pth'
pretrained_vgg = torch.load(f=vgg16_path)
vgg16_model = vgg16(pretrained=False)  # architecture only; weights come from the file
vgg16_model.load_state_dict(state_dict=pretrained_vgg)

In [None]:
# Keep the first 24 layers of VGG16's convolutional trunk as a fixed feature
# extractor for the content loss.
# NOTE(review): with torchvision's standard VGG16 layout index 23 is the 4th
# MaxPool2d, so this slice ends after pool4 — confirm that is the intended layer
# (ending at conv4_3's ReLU would be features[:23]).
vgg_features = vgg16_model.features[:24]
print(vgg_features)
# Freeze the pretrained weights. The attribute is `requires_grad`; the previous
# `param.require_grad = False` only created an unused attribute and froze nothing,
# so VGG received gradients during generator updates.
vgg_features.requires_grad_(False)

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (3): ReLU(inplace=True)
  (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (6): ReLU(inplace=True)
  (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (8): ReLU(inplace=True)
  (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (11): ReLU(inplace=True)
  (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (13): ReLU(inplace=True)
  (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (15): ReLU(inplace=True)
  (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (17): Conv2d(256, 512, kernel_si

**Content Loss**

In [None]:
def contentLoss(generate_img, real_img):
  """Mean absolute (L1) difference between VGG feature maps of two image batches.

  Args:
    generate_img: generator output batch.
    real_img: the corresponding input photo batch.

  Returns:
    Scalar tensor: mean over all elements of |vgg_features(generate_img) - vgg_features(real_img)|.
  """
  generated_features = vgg_features(generate_img)
  reference_features = vgg_features(real_img)
  return F.l1_loss(generated_features, reference_features)

**Adversarial Loss**

In [None]:
def advLoss(d_cartoon, d_removeEdge, d_generate):
  """Discriminator adversarial loss on raw logits (BCE-with-logits).

  Args:
    d_cartoon: discriminator logits for real cartoon images (target label 1).
    d_removeEdge: logits for edge-smoothed cartoon images (target label 0).
    d_generate: logits for generator outputs (target label 0).

  Returns:
    Scalar tensor: BCE(d_cartoon, 1) + BCE(d_removeEdge, 0) + BCE(d_generate, 0),
    each term averaged over its own elements.
  """
  loss = nn.BCEWithLogitsLoss()
  # *_like builds each target with its own logit tensor's shape, dtype and device,
  # so the three inputs need not share a shape and the loss also works on GPU.
  cartoon_loss = loss(d_cartoon, torch.ones_like(d_cartoon))
  removeEdge_loss = loss(d_removeEdge, torch.zeros_like(d_removeEdge))
  generate_loss = loss(d_generate, torch.zeros_like(d_generate))
  return cartoon_loss + removeEdge_loss + generate_loss

**Train**

In [None]:
# Google Drive folders where train() writes generator / discriminator checkpoints.
G_path, D_path = (
    '/content/drive/MyDrive/Colab Notebooks/APS360/Project/Generator',
    '/content/drive/MyDrive/Colab Notebooks/APS360/Project/Discriminator',
)

In [None]:
# Training hyperparameters: Adam learning rate and number of epochs.
lr, epochs = 1e-4, 20

In [None]:
def train(lr, epochs, batch_size):
  """Train a Generator/Discriminator pair from scratch.

  Every iteration takes one batch from each of the module-level loaders
  `cartoon_loader`, `smoothEdge_loader` and `real_loader`.
  NOTE(review): these loaders are not defined in the visible cells — define them
  before calling. Each iteration then:
    1. updates D with advLoss(cartoon -> 1, edge-smoothed -> 0, generated -> 0),
       feeding a *detached* generated batch so the D step does not touch G;
    2. updates G with BCE(D(G(x)), 1) + 10 * contentLoss(G(x), x).

  Args:
    lr: Adam learning rate for both optimizers.
    epochs: number of passes over the zipped loaders.
    batch_size: unused here (the loaders fix the batch size); kept so existing
      call sites keep working.

  Returns:
    A list with one (epoch, real_img, generate_img) tuple per epoch, holding that
    epoch's last batch; generate_img is detached from the autograd graph.

  Raises:
    ValueError: if the zipped loaders yield no batches in an epoch.

  Side effects:
    Prints per-epoch losses; every 10 epochs saves G/D state_dicts under
    G_path / D_path.
  """
  G = Generator()
  D = Discriminator()
  g_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
  d_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))
  D_loss, G_loss = [], []  # per-epoch loss history as Python floats
  outputs = []
  for i in range(epochs):
    adv_loss = content_loss = None
    for (cartoon_img, _), (removeEdge_img, _), (real_img, _) in zip(
        cartoon_loader, smoothEdge_loader, real_loader):
      generate_img = G(real_img)

      # --- discriminator step ---
      # detach() stops the D loss from back-propagating into G and keeps G's
      # graph alive for the generator step below.
      d_optimizer.zero_grad()
      adv_loss = advLoss(D(cartoon_img), D(removeEdge_img), D(generate_img.detach()))
      adv_loss.backward()
      d_optimizer.step()

      # --- generator step ---
      # Re-run D (its weights just changed in place) and reward G when D labels
      # the generated images as real (1).
      g_optimizer.zero_grad()
      generate_eval = D(generate_img)
      g_adv_loss = F.binary_cross_entropy_with_logits(generate_eval, torch.ones_like(generate_eval))
      content_loss = g_adv_loss + 10 * contentLoss(generate_img, real_img)
      content_loss.backward()
      g_optimizer.step()

    if adv_loss is None:
      raise ValueError('the training loaders yielded no batches')
    # store plain floats / detached tensors so no autograd graph is kept alive
    D_loss.append(adv_loss.item())
    G_loss.append(content_loss.item())
    outputs.append((i, real_img, generate_img.detach()))
    print("Epoch: ", i, " adv_loss: ", adv_loss.item(), " content_loss: ", content_loss.item())
    if (i + 1) % 10 == 0:
      torch.save(G.state_dict(), f'{G_path}/G_epoch{i}.pth')
      torch.save(D.state_dict(), f'{D_path}/D_epoch{i}.pth')
  return outputs

In [None]:
# Run training with the hyperparameters from the config cells; keep the per-epoch samples.
outputs = train(lr=lr, epochs=epochs, batch_size=batch_size)

In [None]:
# plot outputs
for k in range(0, max_epochs, 5):
    plt.figure(figsize=(9, 2))
    imgs = outputs[k][1].detach().numpy()
    recon = outputs[k][2].detach().numpy()
    for i, item in enumerate(imgs):
        if i >= 9: break
        plt.subplot(2, 9, i+1)
        plt.imshow(item[0])
        
    for i, item in enumerate(recon):
        if i >= 9: break
        plt.subplot(2, 9, 9+i+1)
        plt.imshow(item[0])