In [None]:
import os
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as T
import tqdm
from torchvision.utils import save_image

In [None]:
# Folders that receive the generated sample images (one per translation direction).
# NOTE(review): this cell sits before the drive.mount() cell; if it runs first the
# folders are created on local disk under ./drive rather than on Drive — confirm order.
for preds_dir in (
  "./drive/MyDrive/preds/",
  "./drive/MyDrive/preds/A/",
  "./drive/MyDrive/preds/B/",
):
  os.makedirs(preds_dir, exist_ok=True)

In [None]:

# Remove stale *local* files left under the Colab mountpoint (for example folders
# created by os.makedirs before Drive was mounted) so that drive.mount() starts
# from a clean path.
# Safety: skip the delete entirely when /content/drive is an active mountpoint —
# a recursive delete there would remove the user's real Google Drive files.
if not os.path.ismount("/content/drive"):
  shutil.rmtree("/content/drive", ignore_errors=True)

In [None]:

from google.colab import drive

# Attach Google Drive at the Colab mountpoint.
drive.mount('/content/drive')

# Dataset folders (relative to the working directory): train/test splits for
# domain A and domain B of the selfie2anime dataset.
trA, trB, teA, teB = (
  f'./drive/MyDrive/selfie2anime/{split}'
  for split in ('trainA', 'trainB', 'testA', 'testB')
)



MessageError: Error: credential propagation was unsuccessful

In [None]:
# Training transform: random horizontal flip (augmentation), resize to 128x128,
# convert to a tensor in [0, 1], then shift/scale each channel to [-1, 1].
preprocess = T.Compose(
    [
        T.RandomHorizontalFlip(),
        T.Resize((128,128)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

# Evaluation transform: identical minus the random flip, so the samples saved
# from the test sets are deterministic and never mirrored.
preprocess_eval = T.Compose(
    [
        T.Resize((128,128)),
        T.ToTensor(),
        T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

trainA = torchvision.datasets.ImageFolder(root=trA, transform=preprocess)
trainB = torchvision.datasets.ImageFolder(root=trB, transform=preprocess)
testA = torchvision.datasets.ImageFolder(root=teA, transform=preprocess_eval)
testB = torchvision.datasets.ImageFolder(root=teB, transform=preprocess_eval)

BATCH_SIZE = 16
# Never request more worker processes than the machine has CPUs; PyTorch warns
# that oversubscribing workers can slow down or stall data loading.
NUM_WORKERS = min(10, os.cpu_count() or 1)

dlTrainA = torch.utils.data.DataLoader(trainA, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
dlTrainB = torch.utils.data.DataLoader(trainB, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
dlTestA = torch.utils.data.DataLoader(testA, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)
dlTestB = torch.utils.data.DataLoader(testB, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

NameError: name 'trA' is not defined

In [None]:
class Discriminator(nn.Module):
  """Convolutional real/fake classifier producing one score in (0, 1) per image.

  Three stride-2 5x5 convolutions (the last two followed by InstanceNorm) reduce
  the input to a 256-channel feature map, which is flattened per sample and
  mapped to a single sigmoid-activated value by a linear layer.

  Args:
    image_size: height/width of the square input images. It determines the
      input width of the final linear layer. Defaults to 128, the size the
      notebook's preprocessing resizes to.
  """

  def __init__(self, image_size=128):
    super().__init__()
    self.conv1 = nn.Sequential(
        nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=2, padding=2),
        nn.LeakyReLU(0.2, inplace=True),
    )
    self.conv2 = nn.Sequential(
        nn.Conv2d(in_channels=64, out_channels=128, kernel_size=5, stride=2, padding=2),
        nn.InstanceNorm2d(128),
        nn.LeakyReLU(0.2, inplace=True),
    )
    self.conv3 = nn.Sequential(
        nn.Conv2d(in_channels=128, out_channels=256, kernel_size=5, stride=2, padding=2),
        nn.InstanceNorm2d(256),
        nn.LeakyReLU(0.2, inplace=True),
    )

    # Each conv above (kernel 5, stride 2, padding 2) maps a side of length H to
    # ceil(H / 2); apply that three times to get the final feature-map side.
    feature_side = image_size
    for _ in range(3):
      feature_side = (feature_side + 1) // 2

    self.fc1 = nn.Linear(256 * feature_side * feature_side, 1)

  def forward(self, x):
    """Return a (batch, 1) tensor of sigmoid scores for a (batch, 3, H, W) input."""
    x = self.conv1(x)
    x = self.conv2(x)
    x = self.conv3(x)
    # Flatten per sample. A hard-coded reshape(-1, K) would silently merge
    # several images into one row whenever K does not match the feature size.
    x = x.flatten(start_dim=1)
    x = self.fc1(x)
    out = torch.sigmoid(x)
    return out

In [None]:
class Generator(nn.Module):
  """Image-to-image network: conv stem, two conv stages, two transposed-conv upsamplers.

  The stem (stride-2 conv) and the max-pool at the end of ``down1`` each halve
  the spatial size; ``up1`` and ``final_conv`` (kernel 2, stride 2 transposed
  convs) each double it. ``down2`` keeps the size and is applied four times in
  ``forward`` using the same weights each time.
  """

  def __init__(self):
    super().__init__()

    # Stem: reflection-padded 5x5 conv with stride 2, 3 -> 64 channels.
    self.conv1 = nn.Sequential(
      nn.ReflectionPad2d(2),
      nn.Conv2d(3, 64, kernel_size=5, stride=2),
      nn.InstanceNorm2d(64),
      nn.ReLU(inplace=True),
    )

    # Two 3x3 convs (64 -> 128 -> 128) with dropout in between, then 2x2 max-pool.
    self.down1 = nn.Sequential(
      nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
      nn.InstanceNorm2d(128),
      nn.ReLU(inplace=True),
      nn.Dropout(p=0.5),
      nn.ReflectionPad2d(1),
      nn.Conv2d(128, 128, kernel_size=3, stride=1),
      nn.InstanceNorm2d(128),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(kernel_size=2, stride=2),
    )

    # Size-preserving block: two reflection-padded 3x3 convs at 128 channels.
    self.down2 = nn.Sequential(
      nn.ReflectionPad2d(1),
      nn.Conv2d(128, 128, kernel_size=3, stride=1),
      nn.InstanceNorm2d(128),
      nn.ReLU(inplace=True),
      nn.Dropout(p=0.5),
      nn.ReflectionPad2d(1),
      nn.Conv2d(128, 128, kernel_size=3, stride=1),
      nn.InstanceNorm2d(128),
    )

    # First upsampling step, 128 -> 64 channels.
    self.up1 = nn.Sequential(
      nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),
      nn.InstanceNorm2d(64),
      nn.ReLU(inplace=True),
    )

    # Second upsampling step straight to 3 output channels (no activation).
    self.final_conv = nn.ConvTranspose2d(64, 3, kernel_size=2, stride=2)

  def forward(self, x):
    """Run the stem, ``down1``, four passes through ``down2``, then both upsamplers."""
    features = self.down1(self.conv1(x))
    for _ in range(4):
      features = self.down2(features)
    return self.final_conv(self.up1(features))

In [None]:
import torchvision.transforms.functional as F

def _discriminator_loss(dis, real, fake, criterion):
  """Mean of the loss on real images (target 1) and on fake images (target 0)."""
  pred_real = dis(real)
  pred_fake = dis(fake)
  loss_real = criterion(pred_real, torch.ones_like(pred_real))
  loss_fake = criterion(pred_fake, torch.zeros_like(pred_fake))
  return (loss_real + loss_fake) / 2


def train(loader_train_a, loader_train_b, loader_test_a, loader_test_b, gen_a, gen_b, dis_a, dis_b, gen_optimizer, dis_optimizer, epochs=100, out_dir="./drive/MyDrive/preds", device=None):
  """Train two generators and two discriminators with adversarial, cycle and identity losses.

  Conventions used by the body: ``gen_a`` translates A -> B and ``gen_b``
  translates B -> A; ``dis_a`` scores domain-A images and ``dis_b`` scores
  domain-B images. The adversarial terms use an MSE criterion against targets
  1 (real) and 0 (fake); cycle and identity terms use L1, each weighted by 10.

  After every epoch the generators are run on the test loaders and the outputs
  are written as PNG grids to ``<out_dir>/A`` (B -> A results) and
  ``<out_dir>/B`` (A -> B results), then the mean per-batch losses are printed.

  Args:
    loader_train_a, loader_train_b: training batch iterables for domains A and
      B. Element 0 of each batch is used as the image tensor. They are zipped,
      so the shorter loader ends the epoch.
    loader_test_a, loader_test_b: test batch iterables, same batch layout.
    gen_a, gen_b: generator modules (A -> B and B -> A).
    dis_a, dis_b: discriminator modules for domains A and B.
    gen_optimizer: optimizer used for both generator updates.
    dis_optimizer: optimizer used for both discriminator updates.
    epochs: number of passes over the zipped training loaders.
    out_dir: root folder for the saved samples; created if missing.
    device: device to train on. ``None`` selects CUDA when available, else CPU.

  Returns:
    None. The modules are updated in place.
  """
  if device is None:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  criterion = torch.nn.MSELoss()
  criterion_L1 = torch.nn.L1Loss()
  gen_a = gen_a.to(device)
  gen_b = gen_b.to(device)
  dis_a = dis_a.to(device)
  dis_b = dis_b.to(device)

  dir_a = os.path.join(out_dir, "A")
  dir_b = os.path.join(out_dir, "B")
  os.makedirs(dir_a, exist_ok=True)
  os.makedirs(dir_b, exist_ok=True)

  # zip() stops at the shorter loader, so size the progress bars accordingly.
  n_train_batches = min(len(loader_train_a), len(loader_train_b))
  n_test_batches = min(len(loader_test_a), len(loader_test_b))

  for e in range(epochs):
    print("epoch" + str(e))

    total_loss_a = 0
    total_loss_b = 0
    total_gen_loss_a = 0
    total_gen_loss_b = 0
    batches_seen = 0
    for im_a, im_b in tqdm.tqdm(zip(loader_train_a, loader_train_b), total=n_train_batches):
      real_a = im_a[0].to(device)
      real_b = im_b[0].to(device)

      # train generator b cycle (B -> A -> B)
      # The fake must stay attached to the graph here: detaching it would cut
      # the adversarial and cycle gradients off from gen_b.
      gen_optimizer.zero_grad()
      fake_a = gen_b(real_b)
      pred_fake_a = dis_a(fake_a)
      loss_gen_b = criterion(pred_fake_a, torch.ones_like(pred_fake_a))
      recovered_b = gen_a(fake_a)
      loss_cycle_1 = criterion_L1(recovered_b, real_b) * 10
      b_id = gen_b(real_a)
      loss_b_id = criterion_L1(b_id, real_a) * 10
      final_loss_gen_b = loss_gen_b + loss_cycle_1 + loss_b_id
      final_loss_gen_b.backward()
      gen_optimizer.step()

      # train generator a cycle (A -> B -> A)
      gen_optimizer.zero_grad()
      fake_b = gen_a(real_a)
      pred_fake_b = dis_b(fake_b)
      loss_gen_a = criterion(pred_fake_b, torch.ones_like(pred_fake_b))
      recovered_a = gen_b(fake_b)
      loss_cycle_2 = criterion_L1(recovered_a, real_a) * 10
      a_id = gen_a(real_b)
      loss_a_id = criterion_L1(a_id, real_b) * 10
      final_loss_gen_a = loss_gen_a + loss_cycle_2 + loss_a_id
      final_loss_gen_a.backward()
      gen_optimizer.step()

      # Fresh fakes for the discriminator updates, produced without autograd
      # so no gradient flows back into the generators.
      with torch.no_grad():
        fake_a = gen_b(real_b)
        fake_b = gen_a(real_a)

      # train discriminator a: real A vs. generated A
      # zero_grad() also clears the gradients the generator steps left on the
      # discriminator parameters.
      dis_optimizer.zero_grad()
      loss_dis_a = _discriminator_loss(dis_a, real_a, fake_a, criterion)
      loss_dis_a.backward()
      dis_optimizer.step()

      # train discriminator b: real B vs. generated B
      dis_optimizer.zero_grad()
      loss_dis_b = _discriminator_loss(dis_b, real_b, fake_b, criterion)
      loss_dis_b.backward()
      dis_optimizer.step()

      total_loss_a += loss_dis_a.item()
      total_loss_b += loss_dis_b.item()
      total_gen_loss_a += final_loss_gen_a.item()
      total_gen_loss_b += final_loss_gen_b.item()
      batches_seen += 1

    # Save samples from the test sets: inference mode (dropout off, no graph),
    # one pair of files per test batch, indexed by the batch number d.
    was_training_a, was_training_b = gen_a.training, gen_b.training
    gen_a.eval()
    gen_b.eval()
    with torch.no_grad():
      for d, (test_a, test_b) in enumerate(tqdm.tqdm(zip(loader_test_a, loader_test_b), total=n_test_batches)):
        test_a = test_a[0].to(device)
        test_b = test_b[0].to(device)
        fake_test_a = gen_b(test_b)
        fake_test_b = gen_a(test_a)

        # Undo the [-1, 1] normalisation before writing the images.
        save_image(fake_test_a * 0.5 + 0.5, os.path.join(dir_a, f"anime_{e}_{d}.png"))
        save_image(fake_test_b * 0.5 + 0.5, os.path.join(dir_b, f"human_{e}_{d}.png"))
    gen_a.train(was_training_a)
    gen_b.train(was_training_b)

    # Average over the batches actually processed (guarding against empty loaders).
    denom = max(batches_seen, 1)
    print(f"Epoch [{e+1}/{epochs}], Loss D_A: {total_loss_a / denom}, Loss D_B: {total_loss_b / denom}, Loss G_A: {total_gen_loss_a / denom}, Loss G_B: {total_gen_loss_b / denom}")




In [None]:
# Learning rates: the discriminators are updated at half the generator rate.
gen_lr, dis_lr = 0.0002, 0.0001

# Two generators and two discriminators, one of each per image domain.
gen_A, gen_B = Generator(), Generator()
dis_A, dis_B = Discriminator(), Discriminator()

# One Adam optimizer per role, each covering the parameters of both networks.
gen_opt = optim.Adam(
    [*gen_A.parameters(), *gen_B.parameters()],
    lr=gen_lr,
    betas=(0.5, 0.999),
)
dis_opt = optim.Adam(
    [*dis_A.parameters(), *dis_B.parameters()],
    lr=dis_lr,
    betas=(0.5, 0.999),
)

train(
    dlTrainA, dlTrainB, dlTestA, dlTestB,
    gen_A, gen_B, dis_A, dis_B,
    gen_opt, dis_opt,
    epochs=100,
)


epoch0


100%|██████████| 220/220 [01:40<00:00,  2.19it/s]
100%|██████████| 7/7 [00:08<00:00,  1.24s/it]


Epoch [1/100], Loss D_A: 4.482887778609743e-05, Loss D_B: 9.037817992099651e-05, Loss G_A: 0.4005918421464808, Loss G_B: 0.4029512235697578
epoch1


100%|██████████| 220/220 [01:40<00:00,  2.19it/s]
100%|██████████| 7/7 [00:03<00:00,  2.16it/s]


Epoch [2/100], Loss D_A: 2.0882642020831767e-15, Loss D_B: 5.748081749141634e-15, Loss G_A: 0.32073316546047437, Loss G_B: 0.3269041445676018
epoch2


100%|██████████| 220/220 [01:39<00:00,  2.20it/s]
100%|██████████| 7/7 [00:04<00:00,  1.74it/s]


Epoch [3/100], Loss D_A: 2.377183417432688e-15, Loss D_B: 5.395945128625231e-15, Loss G_A: 0.286620789345573, Loss G_B: 0.29506322489065284
epoch3


100%|██████████| 220/220 [01:39<00:00,  2.20it/s]
100%|██████████| 7/7 [00:03<00:00,  2.14it/s]


Epoch [4/100], Loss D_A: 2.3359092438113293e-15, Loss D_B: 5.8912352120815364e-15, Loss G_A: 0.268521953680936, Loss G_B: 0.27697166947757496
epoch4


100%|██████████| 220/220 [01:40<00:00,  2.19it/s]
100%|██████████| 7/7 [00:03<00:00,  1.87it/s]


Epoch [5/100], Loss D_A: 2.6222161696911343e-15, Loss D_B: 5.126356855351546e-15, Loss G_A: 0.25823909184511973, Loss G_B: 0.267580541863161
epoch5


100%|██████████| 220/220 [01:39<00:00,  2.21it/s]
100%|██████████| 7/7 [00:03<00:00,  2.24it/s]


Epoch [6/100], Loss D_A: 2.3285948332961517e-15, Loss D_B: 5.775249559626579e-15, Loss G_A: 0.25171496223000916, Loss G_B: 0.2613199361632852
epoch6


100%|██████████| 220/220 [01:40<00:00,  2.19it/s]
100%|██████████| 7/7 [00:03<00:00,  2.01it/s]


Epoch [7/100], Loss D_A: 2.610722096024427e-15, Loss D_B: 5.326980686624987e-15, Loss G_A: 0.2456011907493367, Loss G_B: 0.25563042928190793
epoch7


100%|██████████| 220/220 [01:40<00:00,  2.20it/s]
100%|██████████| 7/7 [00:04<00:00,  1.50it/s]


Epoch [8/100], Loss D_A: 2.619603880221428e-15, Loss D_B: 5.435129470670825e-15, Loss G_A: 0.24128412134507124, Loss G_B: 0.25247159186531515
epoch8


100%|██████████| 220/220 [01:40<00:00,  2.20it/s]
100%|██████████| 7/7 [00:03<00:00,  2.07it/s]


Epoch [9/100], Loss D_A: 2.757532764221918e-15, Loss D_B: 6.0291640960820265e-15, Loss G_A: 0.23729765674647163, Loss G_B: 0.24765449306544135
epoch9


100%|██████████| 220/220 [01:40<00:00,  2.19it/s]
100%|██████████| 7/7 [00:04<00:00,  1.48it/s]


Epoch [10/100], Loss D_A: 2.8787429956162883e-15, Loss D_B: 6.577222426826398e-15, Loss G_A: 0.23451812940485336, Loss G_B: 0.2441905909426072
epoch10


100%|██████████| 220/220 [01:39<00:00,  2.20it/s]
100%|██████████| 7/7 [00:03<00:00,  2.11it/s]


Epoch [11/100], Loss D_A: 2.7831332010250393e-15, Loss D_B: 5.9116110699452456e-15, Loss G_A: 0.23134395795709947, Loss G_B: 0.24115103314904607
epoch11


100%|██████████| 220/220 [01:40<00:00,  2.18it/s]
100%|██████████| 7/7 [00:03<00:00,  2.22it/s]


Epoch [12/100], Loss D_A: 2.729842495843032e-15, Loss D_B: 5.567311317837962e-15, Loss G_A: 0.2288828524421243, Loss G_B: 0.23930642625864815
epoch12


100%|██████████| 220/220 [01:39<00:00,  2.20it/s]
100%|██████████| 7/7 [00:04<00:00,  1.54it/s]


Epoch [13/100], Loss D_A: 2.792537443115982e-15, Loss D_B: 5.2355505551852675e-15, Loss G_A: 0.2265384707731359, Loss G_B: 0.2375940868433784
epoch13


100%|██████████| 220/220 [01:40<00:00,  2.19it/s]
100%|██████████| 7/7 [00:03<00:00,  2.20it/s]


Epoch [14/100], Loss D_A: 2.573105127660657e-15, Loss D_B: 5.638365591413972e-15, Loss G_A: 0.22439899423543144, Loss G_B: 0.235216386458453
epoch14


100%|██████████| 220/220 [01:41<00:00,  2.18it/s]
100%|██████████| 7/7 [00:03<00:00,  2.10it/s]


Epoch [15/100], Loss D_A: 2.6974501064186745e-15, Loss D_B: 5.633141012474559e-15, Loss G_A: 0.22190773178549375, Loss G_B: 0.2331349649148829
epoch15


100%|██████████| 220/220 [01:40<00:00,  2.19it/s]
100%|██████████| 7/7 [00:04<00:00,  1.46it/s]


Epoch [16/100], Loss D_A: 2.794104816797806e-15, Loss D_B: 5.446623544337533e-15, Loss G_A: 0.21988258768530453, Loss G_B: 0.2302603750369128
epoch16


100%|██████████| 220/220 [01:40<00:00,  2.20it/s]
100%|██████████| 7/7 [00:03<00:00,  1.89it/s]


Epoch [17/100], Loss D_A: 2.780520911555333e-15, Loss D_B: 6.020282311885025e-15, Loss G_A: 0.2183893400781295, Loss G_B: 0.22824815974516027
epoch17


100%|██████████| 220/220 [01:41<00:00,  2.16it/s]
100%|██████████| 7/7 [00:04<00:00,  1.65it/s]


Epoch [18/100], Loss D_A: 2.750740811600682e-15, Loss D_B: 5.488942633746774e-15, Loss G_A: 0.2169703830690945, Loss G_B: 0.22733617803629708
epoch18


100%|██████████| 220/220 [01:40<00:00,  2.19it/s]
100%|██████████| 7/7 [00:03<00:00,  1.83it/s]


Epoch [19/100], Loss D_A: 2.8332891588433994e-15, Loss D_B: 5.593956670428965e-15, Loss G_A: 0.21442745208740235, Loss G_B: 0.22561847918173847
epoch19


100%|██████████| 220/220 [01:41<00:00,  2.17it/s]
100%|██████████| 7/7 [00:03<00:00,  2.02it/s]


Epoch [20/100], Loss D_A: 2.757532764221918e-15, Loss D_B: 5.6138100703987325e-15, Loss G_A: 0.21367186195710125, Loss G_B: 0.22440182952319873
epoch20


100%|██████████| 220/220 [01:39<00:00,  2.21it/s]
100%|██████████| 7/7 [00:05<00:00,  1.40it/s]

Epoch [21/100], Loss D_A: 2.7956721904796293e-15, Loss D_B: 5.857797906869297e-15, Loss G_A: 0.21163254913161783, Loss G_B: 0.22182141072609846
epoch21



100%|██████████| 220/220 [01:39<00:00,  2.20it/s]
100%|██████████| 7/7 [00:03<00:00,  2.19it/s]


Epoch [22/100], Loss D_A: 2.6175140486456634e-15, Loss D_B: 5.42885997594353e-15, Loss G_A: 0.21032530307769776, Loss G_B: 0.22134423957151525
epoch22


100%|██████████| 220/220 [01:41<00:00,  2.16it/s]
100%|██████████| 7/7 [00:03<00:00,  1.95it/s]


Epoch [23/100], Loss D_A: 2.8479179798737546e-15, Loss D_B: 5.7783843069902265e-15, Loss G_A: 0.2091499278124641, Loss G_B: 0.21883314357084385
epoch23


100%|██████████| 220/220 [01:41<00:00,  2.18it/s]
100%|██████████| 7/7 [00:04<00:00,  1.54it/s]


Epoch [24/100], Loss D_A: 2.8416484851464593e-15, Loss D_B: 5.279437018276333e-15, Loss G_A: 0.20822626941344316, Loss G_B: 0.21868279323858372
epoch24


100%|██████████| 220/220 [01:39<00:00,  2.20it/s]
100%|██████████| 7/7 [00:03<00:00,  2.02it/s]


Epoch [25/100], Loss D_A: 2.596615732888013e-15, Loss D_B: 5.534918928413604e-15, Loss G_A: 0.2067115160998176, Loss G_B: 0.21705570627661314
epoch25


100%|██████████| 220/220 [01:40<00:00,  2.18it/s]
100%|██████████| 7/7 [00:03<00:00,  2.11it/s]


Epoch [26/100], Loss D_A: 2.6979725643126157e-15, Loss D_B: 5.68381942818686e-15, Loss G_A: 0.20586524879231172, Loss G_B: 0.21611692225231843
epoch26


100%|██████████| 220/220 [01:39<00:00,  2.20it/s]
100%|██████████| 7/7 [00:04<00:00,  1.73it/s]


Epoch [27/100], Loss D_A: 2.828064579903987e-15, Loss D_B: 5.6153774440805564e-15, Loss G_A: 0.204448940894183, Loss G_B: 0.21411834148799672
epoch27


100%|██████████| 220/220 [01:39<00:00,  2.20it/s]
100%|██████████| 7/7 [00:03<00:00,  2.18it/s]


Epoch [28/100], Loss D_A: 2.809256095722102e-15, Loss D_B: 5.676505017671683e-15, Loss G_A: 0.20341865097775177, Loss G_B: 0.213243668570238
epoch28


100%|██████████| 220/220 [01:40<00:00,  2.19it/s]
100%|██████████| 7/7 [00:03<00:00,  2.08it/s]


Epoch [29/100], Loss D_A: 2.644159401236667e-15, Loss D_B: 5.137328471124313e-15, Loss G_A: 0.2028634837094475, Loss G_B: 0.2120032221429488
epoch29


100%|██████████| 220/220 [01:40<00:00,  2.19it/s]
100%|██████████| 7/7 [00:03<00:00,  1.81it/s]


Epoch [30/100], Loss D_A: 2.738724280040033e-15, Loss D_B: 5.335340012928046e-15, Loss G_A: 0.2023890568228329, Loss G_B: 0.21172735052950242
epoch30


 30%|███       | 66/220 [00:30<01:09,  2.23it/s]