In [None]:
import torch
import torch.nn as nn
!pip install q torchinfo
from torchinfo import summary

In [None]:
class FeatureBlock(nn.Module):
  """Single reflect-padded convolution mapping `input_channel` maps to `output_channel` maps.

  With the defaults (7x7 kernel, padding 3, stride 1) the spatial size of the
  input is preserved.
  """

  def __init__(self, input_channel, output_channel, kernel_size = 7, padding = 3, stride = 1):
    super().__init__()

    conv_options = dict(
        kernel_size = kernel_size,
        padding = padding,
        stride = stride,
        padding_mode = 'reflect',
    )
    self.conv = nn.Conv2d(input_channel, output_channel, **conv_options)

  def forward(self, x):
    """Return the convolution of `x`."""
    features = self.conv(x)
    return features

class ContractBlock(nn.Module):
  """Downsampling block: strided convolution that doubles the channel count,
  optional instance normalisation, then an activation.

  Args:
    input_channel: channels of the input; the output has `2 * input_channel`.
    use_bn: when truthy, `nn.InstanceNorm2d` is applied after the convolution
      (despite the name, no batch norm is involved).
    kernel_size: kernel size of the strided convolution. With the fixed
      padding of 1 and stride of 2, both 3 and 4 halve an even spatial size.
    activation: 'relu' selects `nn.ReLU`; any other value selects
      `nn.LeakyReLU(0.2)`.
  """

  def __init__(self, input_channel, use_bn, kernel_size, activation = 'relu'):
    super().__init__()

    # Fix: honour the `kernel_size` argument. It used to be ignored in favour of a
    # hard-coded 4, so callers asking for a 3x3 kernel silently got a 4x4 one.
    self.conv = nn.Conv2d(input_channel, 2 * input_channel, kernel_size = kernel_size, padding = 1, stride = 2, padding_mode = 'reflect')
    self.activation = nn.ReLU() if activation == 'relu' else nn.LeakyReLU(0.2)
    # Always constructed; only used in forward() when `use_bn` is truthy.
    self.norm = nn.InstanceNorm2d(2 * input_channel)
    self.use_bn = use_bn

  def forward(self, x):
    """Apply convolution -> (optional) instance norm -> activation."""
    x = self.conv(x)
    if self.use_bn:
      x = self.norm(x)
    x = self.activation(x)
    return x

class Discriminator(nn.Module):
  """Convolutional discriminator that outputs a one-channel map of scores.

  Pipeline: FeatureBlock -> three ContractBlocks (LeakyReLU; the first one
  without normalisation) -> 1x1 convolution. Each ContractBlock doubles the
  channel count, so the last convolution receives `hidden_size * 8` channels.
  """

  def __init__(self, input_channel, hidden_size =64):
    super().__init__()

    width = hidden_size
    self.feature = FeatureBlock(input_channel, width)
    # Output channels of the three contracting stages: 2x, 4x and 8x `hidden_size`.
    self.Contrack1 = ContractBlock(width, False, 4, 'leakyrelu')
    self.Contrack2 = ContractBlock(width * 2, True, 4, 'leakyrelu')
    self.Contrack3 = ContractBlock(width * 4, True, 4, 'leakyrelu')

    self.conv = nn.Conv2d(width * 8, 1, kernel_size = 1)

  def forward(self, x):
    """Run `x` through every stage in order and return the score map."""
    stages = (self.feature, self.Contrack1, self.Contrack2, self.Contrack3, self.conv)
    for stage in stages:
      x = stage(x)
    return x

# Smoke test: one random 3-channel 256x256 image must produce a (1, 1, 32, 32) score map.
model = Discriminator(3)
img = torch.randn(1, 3, 256, 256)

assert model(img).shape == torch.Size([1, 1, 32, 32])

In [None]:
class ResidualBlock(nn.Module):
  """Residual block: conv -> instance norm -> ReLU -> conv -> instance norm, plus the input.

  Both convolutions are 3x3 with reflect padding 1 and keep `input_channel`
  channels, so the output has the same shape as the input.
  """

  def __init__(self, input_channel):
    super().__init__()
    self.conv1 = nn.Conv2d(input_channel, input_channel, kernel_size = 3, padding = 1, padding_mode = 'reflect')
    # Built with InstanceNorm2d's default arguments, which (per the PyTorch docs) carry no
    # learnable parameters or running statistics; that is why one instance is used twice.
    self.norm1 = nn.InstanceNorm2d(input_channel)
    self.relu = nn.ReLU()
    self.conv2 = nn.Conv2d(input_channel, input_channel, kernel_size = 3, padding = 1, padding_mode = 'reflect')

  def forward(self, x):
    """Return `x` plus the transformed `x`."""
    # No clone of `x` is needed for the skip connection: every op below returns a new
    # tensor (the ReLU is not in-place), so `x` itself is never modified.
    residual = self.conv1(x)
    residual = self.norm1(residual)
    residual = self.relu(residual)
    residual = self.conv2(residual)
    residual = self.norm1(residual)
    return residual + x


class ExpandBlock(nn.Module):
  """Upsampling block: transposed convolution, optional instance norm, ReLU.

  The transposed convolution (kernel 3, stride 2, padding 1, output padding 1)
  halves the channel count (`input_channel // 2`) and doubles height and width.

  Args:
    input_channel: number of channels of the input.
    use_bn: when truthy, `nn.InstanceNorm2d` is applied before the ReLU.
  """

  def __init__(self, input_channel, use_bn):
    super().__init__()

    half_channels = input_channel // 2
    self.conv = nn.ConvTranspose2d(input_channel, half_channels, kernel_size = 3, stride = 2, padding = 1, output_padding = 1)
    self.norm = nn.InstanceNorm2d(half_channels)
    self.relu = nn.ReLU()
    self.use_bn = use_bn

  def forward(self, x):
    """Upsample `x`, normalise it if enabled, and apply ReLU."""
    upsampled = self.conv(x)
    return self.relu(self.norm(upsampled) if self.use_bn else upsampled)

class Generator(nn.Module):
  """Image-to-image generator.

  Pipeline: FeatureBlock -> two ContractBlocks -> nine ResidualBlocks ->
  two ExpandBlocks -> FeatureBlock -> tanh.

  Args:
    input_channel: channels of the input image.
    output_channel: channels of the generated image.
    hidden_channel: channels produced by the first FeatureBlock; the
      residual stage runs at `hidden_channel * 4`.
  """

  def __init__(self, input_channel, output_channel, hidden_channel = 64):
    super().__init__()

    self.feature1 = FeatureBlock(input_channel, hidden_channel)

    self.contract1 = ContractBlock(hidden_channel, True, kernel_size = 3, activation = 'relu')
    self.contract2 = ContractBlock(hidden_channel * 2, True, kernel_size = 3, activation = 'relu')

    # Registers residual1 ... residual9, in that order, as submodules.
    bottleneck_channel = hidden_channel * 4
    for i in range(1, 10):
      setattr(self, 'residual{}'.format(i), ResidualBlock(bottleneck_channel))

    self.expand1 = ExpandBlock(bottleneck_channel, use_bn = True)
    self.expand2 = ExpandBlock(hidden_channel * 2, use_bn = True)

    self.feature2 = FeatureBlock(hidden_channel, output_channel)
    self.activation = nn.Tanh()

  def forward(self, x):
    """Translate `x`; the final tanh keeps the output in [-1, 1]."""
    x = self.feature1(x)
    x = self.contract1(x)
    x = self.contract2(x)

    for i in range(1, 10):
      x = getattr(self, 'residual{}'.format(i))(x)

    x = self.expand1(x)
    x = self.expand2(x)
    return self.activation(self.feature2(x))


# Smoke test: the generator must return an image with the same shape as its 256x256 input.
model = Generator(3, 3)
img = torch.randn(1, 3, 256, 256)

assert model(img).shape == torch.Size([1, 3, 256, 256])

In [None]:
# Per-layer report (shapes, parameter counts, trainability) for a fresh 3-to-3 channel
# Generator fed a single 256x256 image.
summary(
    model = Generator(3, 3),
    input_size = (1, 3, 256, 256),
    col_names = ['input_size', 'output_size', 'num_params', 'trainable'],
    row_settings = ['var_names'],
    col_width = 20,
)

In [None]:
def Discriminator_loss(real_x, fake_x, disc_x, loss_fn):
  """Average of the discriminator's loss on real and on generated samples.

  Args:
    real_x: batch of real samples.
    fake_x: batch of generated samples.
    disc_x: callable (the discriminator) applied to both batches.
    loss_fn: callable `loss_fn(prediction, target)`.

  Returns:
    `(loss_fn(disc_x(real_x), ones) + loss_fn(disc_x(fake_x), zeros)) / 2`,
    where the targets have the same shape as the predictions.
  """
  real_pred = disc_x(real_x)
  # Detach so this loss can never send gradients back into whatever produced `fake_x`,
  # even if the caller did not generate it under `torch.no_grad()`.
  fake_pred = disc_x(fake_x.detach())

  real_loss = loss_fn(real_pred, torch.ones_like(real_pred))
  fake_loss = loss_fn(fake_pred, torch.zeros_like(fake_pred))

  return (real_loss + fake_loss) / 2

In [None]:
def Adversarial_loss(real_x, disc_y, gen_xy, loss_fn):
  """Generator-side adversarial loss.

  Generates `fake_y = gen_xy(real_x)`, scores it with `disc_y`, and compares
  the scores against an all-ones target of the same shape.

  Returns:
    Tuple `(loss, fake_y)`.
  """
  fake_y = gen_xy(real_x)
  verdict = disc_y(fake_y)
  all_real = torch.ones_like(verdict)
  adv_loss = loss_fn(verdict, all_real)

  return adv_loss, fake_y


In [None]:
def Cycle_consistent_loss(real_x, fake_y, gen_yx, loss_fn):
  """Loss between `real_x` and its reconstruction `gen_yx(fake_y)`.

  Returns:
    Tuple `(loss_fn(real_x, cycle_x), cycle_x)`.
  """
  cycle_x = gen_yx(fake_y)
  cycle_loss = loss_fn(real_x, cycle_x)

  return cycle_loss, cycle_x

In [None]:
def Identity_loss(real_x, gen_yx, loss_fn):
  """Loss between `real_x` and `gen_yx(real_x)`.

  Returns:
    Tuple `(loss_fn(real_x, identity_x), identity_x)`.
  """
  identity_x = gen_yx(real_x)
  identity_loss = loss_fn(real_x, identity_x)

  return identity_loss, identity_x

In [None]:
def Generator_loss(real_x, real_y, gen_xy, gen_yx, disc_x, disc_y,
                   adv_loss_fn, identity_loss_fn, cycle_loss_fn,
                   lambda_identity = 0.1, lambda_cycle = 10):
  """Total generator objective for both translation directions.

  The result is
  `adversarial + lambda_cycle * cycle + lambda_identity * identity`,
  where each term is the sum over the X->Y and Y->X directions.

  Returns:
    Tuple `(total_loss, fake_x, fake_y)` with `fake_x = gen_yx(real_y)` and
    `fake_y = gen_xy(real_x)`.
  """
  # Adversarial terms; these also produce the translated images.
  adv_xy, fake_y = Adversarial_loss(real_x, disc_y, gen_xy, adv_loss_fn)
  adv_yx, fake_x = Adversarial_loss(real_y, disc_x, gen_yx, adv_loss_fn)

  # Cycle terms: translate back and compare with the original image.
  cyc_x, _ = Cycle_consistent_loss(real_x, fake_y, gen_yx, cycle_loss_fn)
  cyc_y, _ = Cycle_consistent_loss(real_y, fake_x, gen_xy, cycle_loss_fn)

  # Identity terms: each image is fed to the generator that outputs its own domain.
  idt_x, _ = Identity_loss(real_x, gen_yx, identity_loss_fn)
  idt_y, _ = Identity_loss(real_y, gen_xy, identity_loss_fn)

  total = (adv_xy + adv_yx) + (cyc_x + cyc_y) * lambda_cycle + (idt_x + idt_y) * lambda_identity
  return total, fake_x, fake_y


In [None]:
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

import matplotlib.pyplot as plt
import glob
import random
import os

from PIL import Image
import numpy as np


In [None]:
# Mount Google Drive at /content/drive; the dataset and image paths used below live under it.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
class VangoghDataset(Dataset):
  """Unpaired two-domain image dataset read from `<root>/<mode>A` and `<root>/<mode>B`.

  Item `index` is a pair `(image_A, image_B)`: the `index`-th file of folder A
  and a randomly chosen file of folder B. The random choice comes from a
  permutation that is redrawn each time the last index is requested.

  Args:
    transform: callable applied to each PIL image (converted to RGB first).
    mode: folder prefix, e.g. 'train' reads `trainA` and `trainB`.
    root: directory containing the A/B folders; defaults to `DEFAULT_ROOT`.
  """

  # Default dataset location on the mounted Google Drive.
  DEFAULT_ROOT = '/content/drive/MyDrive/Colab Notebooks/images/Data/vangogh2photo'

  def __init__(self, transform, mode = 'train', root = None):
    self.transform = transform
    if root is None:
      root = self.DEFAULT_ROOT
    # File counts noted by the original author for this dataset: A: 400, B: 6287.
    pathA = '{}/{}A/*.*'.format(root, mode)
    pathB = '{}/{}B/*.*'.format(root, mode)
    self.image_collectionA = sorted(glob.glob(pathA))
    self.image_collectionB = sorted(glob.glob(pathB))
    self.new_perm()

  def new_perm(self):
    """Draw a fresh random selection of B indices, at most one per A image."""
    self.randperm = torch.randperm(len(self.image_collectionB))[:len(self.image_collectionA)]

  def _load(self, path):
    """Open `path`, force three RGB channels, and apply the transform."""
    # Fix: without the RGB conversion a grayscale/RGBA/palette file yields a tensor whose
    # channel count differs from the 3 channels the models are built for. The `with`
    # block closes the file handle that Image.open keeps.
    with Image.open(path) as image:
      return self.transform(image.convert('RGB'))

  def __getitem__(self, index):
    item_A = self._load(self.image_collectionA[index % len(self.image_collectionA)])
    item_B = self._load(self.image_collectionB[int(self.randperm[index])])
    if index == len(self) - 1:
      # Last index reached: pick a new random subset of B for the next pass.
      self.new_perm()
    # Maps values from [0, 1] to [-1, 1], the range of a tanh output.
    # Assumes the transform returns values in [0, 1] (e.g. it ends with ToTensor) — confirm.
    return (item_A - 0.5) * 2, (item_B - 0.5) * 2

  def __len__(self):
    return min(len(self.image_collectionA), len(self.image_collectionB))

In [None]:
def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28)):
    """Plot up to `num_images` images from `image_tensor` as a grid with 5 per row.

    Values are first shifted with `(x + 1) / 2` (i.e. [-1, 1] becomes [0, 1]);
    the tensor is then detached, moved to the CPU and viewed as `(-1, *size)`.
    """
    rescaled = (image_tensor + 1) / 2
    images = rescaled.detach().cpu().view(-1, *size)
    grid = make_grid(images[:num_images], nrow=5)
    # make_grid returns channels-first; imshow expects channels-last.
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()

In [None]:
####### DEFINE MODEL RELATED HYPERPARAMETERS ##########

# Squared-error criterion for the adversarial terms; one shared L1 criterion object
# serves both the cycle and the identity terms.
adv_criterion = nn.MSELoss()
cycle_criterion = identity_criterion = nn.L1Loss()

# Number of passes over the dataloader, and channel counts of the two image domains.
n_epoches = 30
dim_X = 3
dim_Y = 3

batch_size = 1
lr = 3e-4  # learning rate passed to every Adam optimizer created below

# Side length of the square crops fed to the networks.
target_shape = 256
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Augmentation: resize to 286, take a random target_shape crop, flip horizontally
# with probability 0.5, then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize(286),
    transforms.RandomCrop(target_shape),
    transforms.RandomHorizontalFlip(0.5),
    transforms.ToTensor()
])

dataset = VangoghDataset(transform=transform, mode = 'train')
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
#######################################################

In [None]:
# Sanity check: fetch the first (A, B) pair from the dataset and print both tensor shapes.
test_train, test_val = dataset[0]
print(test_train.shape, test_val.shape, sep='\n')

In [None]:
# Generators for both directions share a single Adam optimizer over their combined parameters.
gen_XY = Generator(dim_X, dim_Y).to(device)
gen_YX = Generator(dim_Y, dim_X).to(device)
gen_opt = torch.optim.Adam(
    [*gen_XY.parameters(), *gen_YX.parameters()],
    lr=lr,
    betas=(0.5, 0.999),
)

# One discriminator per domain, each with its own Adam optimizer.
disc_X = Discriminator(dim_X).to(device)
disc_Y = Discriminator(dim_Y).to(device)
disc_X_opt = torch.optim.Adam(disc_X.parameters(), lr=lr, betas=(0.5, 0.999))
disc_Y_opt = torch.optim.Adam(disc_Y.parameters(), lr=lr, betas=(0.5, 0.999))

def weights_init(m):
    """Re-initialise module `m` in place; intended for use with `Module.apply`.

    - Conv2d / ConvTranspose2d: weights drawn from N(0, 0.02); bias untouched.
    - BatchNorm2d: weights drawn from N(0, 0.02), bias set to 0.
    - Any other module is left unchanged.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    if isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        nn.init.zeros_(m.bias)

# `apply` runs weights_init on every submodule and returns the module it was called on.
gen_XY, gen_YX = gen_XY.apply(weights_init), gen_YX.apply(weights_init)
disc_X, disc_Y = disc_X.apply(weights_init), disc_Y.apply(weights_init)

In [None]:
## Training loop. If the results are poor, consider increasing n_epoches.
## Each batch performs three updates in order: discriminator X, discriminator Y,
## then both generators together.

step= 0

for epoch in range(n_epoches):
  print('this is the {} epoch running '.format(epoch + 1))
  for real_x, real_y in tqdm(dataloader):  ## tqdm: progress bar
    step += 1
    real_x = real_x.to(device)
    real_y = real_y.to(device)

    ##### Update disc X ######
    disc_X_opt.zero_grad()    ## clear gradients left from earlier backward passes
    with torch.no_grad():     ## no graph is recorded for the generator pass, so it gets no gradients here
      fake_x = gen_YX(real_y)

    disc_x_loss = Discriminator_loss(real_x, fake_x, disc_X, adv_criterion) ## compute the loss
    disc_x_loss.backward()  ## backpropagate the loss
    disc_X_opt.step()       ## update the parameters

    #### Update disc Y ####
    disc_Y_opt.zero_grad()
    with torch.no_grad():
      fake_y = gen_XY(real_x)

    disc_y_loss = Discriminator_loss(real_y, fake_y, disc_Y, adv_criterion)
    disc_y_loss.backward()
    disc_Y_opt.step()

    ##### Update Generator #####
    gen_opt.zero_grad()

    ## fake_x / fake_y are rebound to the images produced during this generator pass
    gen_loss, fake_x, fake_y =  Generator_loss(real_x, real_y, gen_XY, gen_YX, disc_X, disc_Y,
                                                adv_criterion, identity_criterion, cycle_criterion)
    gen_loss.backward()
    gen_opt.step()

    ###############################

    if step % 20 == 0: ## every 20 steps, display the real images and the generated images
      print('real image')
      show_tensor_images(torch.cat([real_x, real_y]), size = (3, 256, 256))

      print('fake image')
      show_tensor_images(torch.cat([fake_x, fake_y]), size = (3, 256, 256))



In [None]:
path = '/content/drive/MyDrive/Colab Notebooks/images/Data/fujisan.jpg'

content_img = Image.open(path)

def prepocessing(img, image_shape = (256,256) , device = device):
  """Turn an image into a generator-ready batch of one.

  Args:
    img: image accepted by torchvision's `ToTensor` (a PIL image is converted
      to RGB first so it always has three channels).
    image_shape: `(height, width)` passed to `transforms.Resize`.
    device: device the resulting tensor is moved to.

  Returns:
    Tensor of shape `(1, C, *image_shape)` with values scaled to [-1, 1].
  """
  # Fix: the pipeline used to be assigned to a local named `transforms`, which made
  # `transforms` a local variable and raised UnboundLocalError on `transforms.Compose`.
  pipeline = transforms.Compose([
      transforms.Resize(image_shape),
      transforms.ToTensor(),  # (C, H, W) in [0, 1]
  ])
  if isinstance(img, Image.Image):
    img = img.convert('RGB')
  tensor = pipeline(img)
  # Fix: apply the same [0, 1] -> [-1, 1] scaling that VangoghDataset applies to training images.
  tensor = (tensor - 0.5) * 2
  return tensor.unsqueeze(0).to(device)  # add the batch dimension

content_img = prepocessing(content_img) # [1, 3, 256, 256]

# Fix: tensors expose `.shape`; `.image_shape` raised AttributeError.
print(content_img.shape)

In [None]:
# Fix: the raw generator output carries autograd history (and may live on the GPU), so
# imshow could not convert it to an array; its tanh range [-1, 1] is also not displayable.
# Run without grad, map to [0, 1], switch to channels-last and move to the CPU.
with torch.no_grad():
  plt.imshow(((gen_XY(content_img)[0] + 1) / 2).clamp(0, 1).permute(1, 2, 0).cpu())

In [None]:
# Fix: the raw generator output carries autograd history (and may live on the GPU), so
# imshow could not convert it to an array; its tanh range [-1, 1] is also not displayable.
# Run without grad, map to [0, 1], switch to channels-last and move to the CPU.
with torch.no_grad():
  plt.imshow(((gen_YX(content_img)[0] + 1) / 2).clamp(0, 1).permute(1, 2, 0).cpu())