In [1]:
PATH = "/kaggle/input/cyclegan"
import sys
sys.path.insert(1, PATH)
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import itertools
import torchvision.transforms as tt
from PIL import Image



from utils import *

from tqdm import tqdm
from torchvision.utils import save_image
from models.Discriminator import PatchDisc
from models.Generator import ResNetGen
from data.Ukiyo import MyDataset

torch.Size([3, 256, 256])
torch.Size([55, 3, 256, 256])
50
torch.Size([1, 15, 15])


In [2]:
class config:
    """Paths, hyper-parameters and runtime settings for this CycleGAN run.

    Domain "A" is read from ``trainA`` and domain "F" from ``trainB`` of the
    selfie2anime dataset.
    """

    # Training images for domain A (PATH_0) and domain F (PATH_1).
    PATH_0 = r"/kaggle/input/selfie2anime/trainA"
    PATH_1 = r"/kaggle/input/selfie2anime/trainB"
    # Held-out domain-A images (testA).
    PATH_0_test = r"/kaggle/input/selfie2anime/testA"

    BATCH = 1
    LOAD = True  # restore the four checkpoints below before training
    lr = 0.0001

    # Weights of the A->F->A and F->A->F cycle-consistency losses.
    LAMBDA_A = 10
    LAMBDA_B = 10
    # Identity-loss weight, applied on top of LAMBDA_A / LAMBDA_B; 0 disables it.
    LAMBDA_IDT = 0

    # Checkpoints to resume from (two discriminators, two generators).
    PATH_D_A = r"/kaggle/input/checkpoint-75/75_e_0.0001_lr/DiscA.pth.tar"
    PATH_D_F = r"/kaggle/input/checkpoint-75/75_e_0.0001_lr/DiscF.pth.tar"
    PATH_G_A_F = r"/kaggle/input/checkpoint-75/75_e_0.0001_lr/GenA_F.pth.tar"
    PATH_G_F_A = r"/kaggle/input/checkpoint-75/75_e_0.0001_lr/GenF_A.pth.tar"

    SAVE = True      # write checkpoints at the end of every epoch
    epoches = 25     # epochs to run in this session
    epoch_count = 1  # starting-epoch offset used by the LR schedule

    # Use the GPU when present; otherwise fall back to CPU instead of failing
    # on the first `.to(device)`.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # PIL image -> float tensor in [-1, 1]: ToTensor yields [0, 1], then
    # Normalize applies (x - 0.5) / 0.5 per channel.
    transform = tt.Compose([
        tt.ToTensor(),
        tt.Normalize((0.5, 0.5, 0.5),
                     (0.5, 0.5, 0.5)),
    ])

In [3]:
def lambda_rule(epoch, epoch_count=None, n_epochs=100, n_epochs_decay=100):
    """Linear-decay learning-rate multiplier for ``LambdaLR``.

    The multiplier is ``1 - max(0, epoch + epoch_count - n_epochs) / (n_epochs_decay + 1)``.
    It stays at 1.0 until ``epoch + epoch_count`` reaches ``n_epochs`` and then
    falls linearly by ``1 / (n_epochs_decay + 1)`` per step.

    Args:
        epoch: Step index supplied by ``LambdaLR`` (0 on construction).
        epoch_count: Epoch number this run starts from; defaults to
            ``config.epoch_count``.
        n_epochs: Length of the constant phase.
        n_epochs_decay: Length of the linear-decay phase.

    Returns:
        A float in ``[0, 1]``. It is clamped at 0 so stepping past the end of
        the schedule can never produce a negative learning rate.
    """
    if epoch_count is None:
        epoch_count = config.epoch_count
    decayed = max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)
    return max(0.0, 1.0 - decayed)

def plot_ing(x, y, title_x="Real face", title_y="Generated face"):
    """Show two channel-first image tensors side by side.

    Args:
        x: Left image as a (C, H, W) tensor; it is permuted to (H, W, C) for
            ``imshow``. Values are expected in [0, 1].
        y: Right image, same layout as ``x``.
        title_x: Title of the left panel.
        title_y: Title of the right panel.
    """
    # NOTE(review): `plt` is not imported explicitly in this notebook; it
    # presumably comes from `from utils import *` — confirm.
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    for ax, img, title in zip(axes, (x, y), (title_x, title_y)):
        ax.set_title(title)
        # imshow needs HWC data on the CPU. Clamping avoids matplotlib clipping
        # small overshoots outside [0, 1] with a warning.
        ax.imshow(img.detach().permute(1, 2, 0).clamp(0, 1).cpu())
        ax.axis("off")
    plt.show()

# Rebuild the four CycleGAN networks and restore their trained weights.
# `device` is bound here so this cell does not depend on a later cell.
device = config.device

# `Module.to()` and `Module.apply()` both return the module, so they chain.
DiscA = PatchDisc().to(device).apply(weights_init)
DiscF = PatchDisc().to(device).apply(weights_init)
GenA_F = ResNetGen().to(device).apply(weights_init)
GenF_A = ResNetGen().to(device).apply(weights_init)

# load_checkpoint takes an optimizer argument, so the optimizers are recreated
# with the training settings.
opt_D = torch.optim.Adam(
    itertools.chain(DiscA.parameters(), DiscF.parameters()),
    lr=config.lr, betas=(0.5, 0.999))

opt_G = torch.optim.Adam(
    itertools.chain(GenA_F.parameters(), GenF_A.parameters()),
    lr=config.lr, betas=(0.5, 0.999))

for ckpt_path, net, opt in (
    (config.PATH_D_A, DiscA, opt_D),
    (config.PATH_D_F, DiscF, opt_D),
    (config.PATH_G_A_F, GenA_F, opt_G),
    (config.PATH_G_F_A, GenF_A, opt_G),
):
    load_checkpoint(ckpt_path, net, opt, config.lr, device=device)

import os

# Domain-A test images, sorted so the preview order is reproducible.
# (`imgs_test` was previously used without ever being defined.)
imgs_test = [
    os.path.join(config.PATH_0_test, name)
    for name in sorted(os.listdir(config.PATH_0_test))
]

with torch.no_grad():  # inference only: no autograd graph is needed
    for img_path in imgs_test:
        with Image.open(img_path) as img:
            x = config.transform(img.convert("RGB")).to(config.device)
        # Run the generator on a batch of one, then drop the batch dim again.
        y = GenA_F(x.unsqueeze(0)).squeeze(0)
        # Undo Normalize((0.5,)*3, (0.5,)*3) so both panels are in [0, 1].
        plot_ing(x * 0.5 + 0.5, y * 0.5 + 0.5)

In [4]:
import torch
from PIL import Image
import os
from torch.utils.data import Dataset


class MyDataset(Dataset):
    """Unpaired two-domain image dataset.

    Item ``i`` is ``(x, y)``, where ``x`` is read from ``path_0`` and ``y`` from
    ``path_1``. Indices wrap around (``i % len``) each folder's file list, and the
    dataset length is that of the larger folder.

    Args:
        path_0: Directory of domain-0 images.
        path_1: Directory of domain-1 images.
        transform: Callable applied to each RGB PIL image. If None, the PIL
            images are returned as-is. Defaults to ``config.transform``.

    Raises:
        ValueError: If either directory is empty.
    """

    def __init__(self, path_0, path_1, transform=config.transform):
        super().__init__()

        self.path_0 = path_0
        self.path_1 = path_1

        # os.listdir order is arbitrary; sort for a reproducible index->file map.
        self.img_0 = sorted(os.listdir(path_0))
        self.img_1 = sorted(os.listdir(path_1))
        if not self.img_0 or not self.img_1:
            # Otherwise __getitem__ would fail later with ZeroDivisionError.
            raise ValueError(f"empty image folder: {path_0!r} or {path_1!r}")

        self.max_length = max(len(self.img_0), len(self.img_1))
        self.transform = transform

    @staticmethod
    def _load_rgb(path):
        """Open ``path``, return an in-memory RGB copy, and close the file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, index):
        x = self._load_rgb(
            os.path.join(self.path_0, self.img_0[index % len(self.img_0)]))
        y = self._load_rgb(
            os.path.join(self.path_1, self.img_1[index % len(self.img_1)]))

        if self.transform is not None:
            x = self.transform(x)
            y = self.transform(y)

        return x, y

    def __len__(self):
        return self.max_length

In [5]:
# Default device for `fit`; falls back to CPU when CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def fit(DiscA, DiscF, GenF_A, GenA_F,
        opt_D, opt_G, mse, l1, loader,
        Buffer_A, Buffer_F, lambda_a,
        lambda_b, lambda_idt=0, sched_D=None,
        sched_G=None, penalty=None, device=device, show=False):
    """Run one training epoch of the two-generator / two-discriminator model.

    For every batch, both discriminators are updated first, on fakes taken
    from the image buffers. Both generators are then updated with the
    adversarial, cycle-consistency and optional identity losses.

    Args:
        DiscA: Discriminator for domain A.
        DiscF: Discriminator for domain F.
        GenF_A: Generator mapping F -> A.
        GenA_F: Generator mapping A -> F.
        opt_D: Optimizer for the discriminators.
        opt_G: Optimizer for the generators.
        mse: Adversarial criterion, called as ``mse(pred, target)`` with
            targets of ones (real) / zeros (fake).
        l1: Cycle / identity reconstruction criterion.
        loader: Iterable of ``(A_real, F_real)`` batches.
        Buffer_A, Buffer_F: Objects whose ``extract(fakes)`` returns the fakes
            shown to the discriminator. ``n_images`` is printed when ``show``.
        lambda_a, lambda_b: Weights of the A and F cycle losses.
        lambda_idt: Identity-loss weight, multiplied by ``lambda_a`` /
            ``lambda_b``. 0 disables the identity loss.
        sched_D, sched_G: Optional LR schedulers, stepped once per epoch.
        penalty: If not None, ``grad_penalty`` is added to both D losses.
        device: Device the batches are moved to.
        show: If True, print the buffer size and plot reconstructions every
            500 iterations.
    """
    for i, (A_real, F_real) in enumerate(tqdm(loader)):
        A_real = A_real.to(device)
        F_real = F_real.to(device)

        # ---- Discriminator step --------------------------------------------
        # D_A: real A vs. buffered, detached fakes generated from F.
        A_fake = GenF_A(F_real)
        A_fake_buff = Buffer_A.extract(A_fake.detach())
        D_A_real = DiscA(A_real)
        D_A_fake = DiscA(A_fake_buff)
        D_A_loss = (mse(D_A_real, torch.ones_like(D_A_real))
                    + mse(D_A_fake, torch.zeros_like(D_A_fake))) / 2

        # D_F: real F vs. buffered, detached fakes generated from A.
        F_fake = GenA_F(A_real)
        F_fake_buff = Buffer_F.extract(F_fake.detach())
        D_F_real = DiscF(F_real)
        D_F_fake = DiscF(F_fake_buff)
        D_F_loss = (mse(D_F_real, torch.ones_like(D_F_real))
                    + mse(D_F_fake, torch.zeros_like(D_F_fake))) / 2

        if penalty is not None:
            D_F_loss = D_F_loss + grad_penalty(DiscF, F_fake_buff, F_real, device=device)
            D_A_loss = D_A_loss + grad_penalty(DiscA, A_fake_buff, A_real, device=device)

        opt_D.zero_grad()
        (D_A_loss + D_F_loss).backward()
        opt_D.step()

        # ---- Generator step ------------------------------------------------
        # Re-score the non-detached fakes with the just-updated discriminators;
        # the generators want them judged real (target 1). Gradients this
        # leaves on the discriminator parameters are cleared by
        # opt_D.zero_grad() before the next D update (assuming opt_D holds
        # the discriminators' parameters).
        D_F_fake = DiscF(F_fake)
        D_A_fake = DiscA(A_fake)
        loss_G_F = mse(D_F_fake, torch.ones_like(D_F_fake))
        loss_G_A = mse(D_A_fake, torch.ones_like(D_A_fake))

        # Cycle consistency: A -> F -> A and F -> A -> F.
        A_rec = GenF_A(F_fake)
        F_rec = GenA_F(A_fake)
        A_cycle_loss = l1(A_rec, A_real) * lambda_a
        F_cycle_loss = l1(F_rec, F_real) * lambda_b

        loss = loss_G_F + loss_G_A + A_cycle_loss + F_cycle_loss

        if lambda_idt > 0:
            # Identity: each generator fed an image of its own output domain
            # should return it unchanged.
            idt_A = GenF_A(A_real)
            idt_F = GenA_F(F_real)
            loss = loss + (l1(idt_A, A_real) * lambda_a * lambda_idt
                           + l1(idt_F, F_real) * lambda_b * lambda_idt)

        opt_G.zero_grad()
        loss.backward()
        opt_G.step()

        if show and i % 500 == 0:
            print(Buffer_A.n_images)
            x = torch.cat([A_real, F_fake, A_rec], dim=0)
            y = torch.cat([F_real, A_fake, F_rec], dim=0)
            plot_reconstruct(x, y)

    # Step the schedulers once per epoch, not per batch. lambda_rule counts
    # epochs, so per-batch stepping would exhaust the whole schedule within
    # the first epoch.
    if sched_D is not None:
        sched_D.step()
        print("D lr after sched:", opt_D.param_groups[0]["lr"])
    if sched_G is not None:
        sched_G.step()


In [None]:
# Build the four networks; `Module.to()` and `Module.apply()` return the
# module, so they chain.
DiscA = PatchDisc().to(device).apply(weights_init)
DiscF = PatchDisc().to(device).apply(weights_init)
GenA_F = ResNetGen().to(device).apply(weights_init)
GenF_A = ResNetGen().to(device).apply(weights_init)

# Both discriminators share opt_D; both generators share opt_G.
opt_D = torch.optim.Adam(
    itertools.chain(DiscA.parameters(), DiscF.parameters()),
    lr=config.lr, betas=(0.5, 0.999))

opt_G = torch.optim.Adam(
    itertools.chain(GenA_F.parameters(), GenF_A.parameters()),
    lr=config.lr, betas=(0.5, 0.999))

if config.LOAD:
    for ckpt_path, net, opt in (
        (config.PATH_D_A, DiscA, opt_D),
        (config.PATH_D_F, DiscF, opt_D),
        (config.PATH_G_A_F, GenA_F, opt_G),
        (config.PATH_G_F_A, GenF_A, opt_G),
    ):
        load_checkpoint(ckpt_path, net, opt, config.lr, device=device)

# Create the schedulers after the checkpoints are restored so they wrap the
# optimizer state that training actually uses. `LambdaLR` is the scheduler
# class itself; `LambdaLR.LambdaLR` does not exist (AttributeError).
# NOTE(review): the checkpoint folder is named "75_e" but config.epoch_count
# is 1, so lambda_rule treats this run as starting at epoch 1 — confirm.
sched_D = torch.optim.lr_scheduler.LambdaLR(opt_D, lr_lambda=lambda_rule)
sched_G = torch.optim.lr_scheduler.LambdaLR(opt_G, lr_lambda=lambda_rule)

df = MyDataset(config.PATH_0,
               config.PATH_1)

train_dl = DataLoader(
    df, batch_size=config.BATCH, shuffle=True)

mse = torch.nn.MSELoss()
l1 = torch.nn.L1Loss()
lambda_a = config.LAMBDA_A
lambda_b = config.LAMBDA_B
lambda_idt = config.LAMBDA_IDT

# Presumably pools of up to 50 past fakes fed to the discriminators, as in
# CycleGAN — confirm against ImageBuffer in utils.
Buffer_A = ImageBuffer(50)
Buffer_F = ImageBuffer(50)

for epoch in range(config.epoches):
    print("Epoch:", epoch)
    fit(DiscA, DiscF, GenF_A, GenA_F,
        opt_D, opt_G, mse, l1,
        train_dl, Buffer_A, Buffer_F,
        lambda_a, lambda_b, lambda_idt,
        sched_D=sched_D, sched_G=sched_G, penalty=None)

    if config.SAVE:
        for net, opt, filename in (
            (DiscF, opt_D, "DiscF.pth.tar"),
            (DiscA, opt_D, "DiscA.pth.tar"),
            (GenA_F, opt_G, "GenA_F.pth.tar"),
            (GenF_A, opt_G, "GenF_A.pth.tar"),
        ):
            save_checkpoint(net, opt, filename=filename)


Loading checkpoint
Loading checkpoint
Loading checkpoint
Loading checkpoint
Epoch: 0


100%|██████████| 3400/3400 [13:53<00:00,  4.08it/s]


Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Epoch: 1


100%|██████████| 3400/3400 [12:53<00:00,  4.40it/s]


Saving checkpoint
Saving checkpoint
Saving checkpoint
Saving checkpoint
Epoch: 2


100%|██████████| 3400/3400 [12:52<00:00,  4.40it/s]

In [None]:
# Peek at one training pair. Show shapes and value ranges rather than dumping
# two full image tensors into the notebook output.
sample_0, sample_1 = df[0]
{
    "x": (tuple(sample_0.shape), sample_0.min().item(), sample_0.max().item()),
    "y": (tuple(sample_1.shape), sample_1.min().item(), sample_1.max().item()),
}