In [None]:
import time
import os
import cv2
import numpy as np
from PIL import Image
from glob import glob
from tqdm import tqdm

from models import FaceSwapper

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import torchvision.transforms as transforms

from torchvision.utils import make_grid

import matplotlib.pyplot as plt

# Prefer the first visible CUDA GPU; fall back to the CPU when none is available.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

%matplotlib inline

In [None]:
# %load Train.ipynb

In [None]:
# from visdom import Visdom

# class VisdomEnv(object):
#     def __init__(self, env_name='main'):
#         self.viz = Visdom()
#         self.env = env_name

# class VisdomPlotter(VisdomEnv):
#     def __init__(self):
#         super(VisdomPlotter, self).__init__()
#         self.plots = {}
        
#     def plot(self, var_name, split_name, title_name, x, y):
#         if var_name not in self.plots:
#             self.plots[var_name] = self.viz.line(X=np.array([x,x]), Y=np.array([y,y]), env=self.env, opts=dict(
#                 legend=[split_name],
#                 title=title_name,
#                 xlabel='Epochs',
#                 ylabel=var_name
#             ))
#         else:
#             self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env, win=self.plots[var_name], name=split_name, update = 'append')
            
# class VisdomImage(VisdomEnv):  
#     def __init__(self):
#         super(VisdomImage, self).__init__()
        
#     def show(self, img, text=None):
#         self.viz.text(text)
#         self.viz.image(img)
        
# plotter = VisdomPlotter()
# imager = VisdomImage()

In [None]:
training_previews_folder = "./training_previews"
show_every = 100       # render a preview once every `show_every` iterations
preview_rows = 6       # max samples stacked vertically in each preview column
_show_previews = True  # set to False once the user presses Esc in the preview window

# if not os.path.exists(training_previews_folder):
#     os.makedirs(training_previews_folder)


def _batch_to_column(batch):
    """Stack a batch of CHW image tensors top-to-bottom into one HWC numpy array."""
    hwc = batch.detach().cpu().permute(0, 2, 3, 1).numpy()
    return np.concatenate(list(hwc), axis=0)


def training_window(ogA, ogB, predA, predB, epoch, iteration):
    """Show an OpenCV preview grid of originals, reconstructions and swaps.

    Every ``show_every`` iterations this renders six columns:
      1. A originals,
      2. A reconstructions,
      3. A passed through decoder "B",
      4-6. the same three for B.
    Each column holds up to ``preview_rows`` samples. Pressing Esc in the
    window disables previews until this cell is re-run.

    Relies on the global ``model`` to compute the swapped faces.

    Args:
        ogA, ogB: input batches for identity A / B (assumed (N, C, H, W)
            tensors as produced by ToTensor -- TODO confirm for other loaders).
        predA, predB: the model's reconstructions of ``ogA`` / ``ogB``.
        epoch: current epoch (only referenced by the disabled imwrite below).
        iteration: batch index within the epoch; controls preview frequency.

    Raises:
        ValueError: if any batch has fewer than 4 images on a preview iteration.
    """
    global _show_previews
    if not _show_previews:
        return

    if iteration % show_every == 0:
        if min(len(ogA), len(ogB), len(predA), len(predB)) < 4:
            raise ValueError("lengths of imgs have to be at least 4.")

        ogA, ogB = ogA[:preview_rows], ogB[:preview_rows]
        predA, predB = predA[:preview_rows], predB[:preview_rows]

        # Visualization only: don't build an autograd graph for the swaps.
        with torch.no_grad():
            swapA = model(ogA, "B")
            swapB = model(ogB, "A")

        columns = [_batch_to_column(b) for b in (ogA, predA, swapA, ogB, predB, swapB)]
        window = np.hstack(columns)
        # window = cv2.resize(window, (800, 800), interpolation=cv2.INTER_AREA)

        # Images are loaded with PIL (RGB channel order); cv2.imshow expects BGR.
        cv2.imshow("Training window", cv2.cvtColor(window, cv2.COLOR_RGB2BGR))
        # imager.show(window)

        # NOTE(review): `window` is float (presumably in [0, 1]); imwrite would
        # need scaling to uint8 and BGR order before re-enabling this.
        # cv2.imwrite(f'{training_previews_folder}/preview_at_epoch{epoch}_iteration{iteration}.jpg', window)

    # waitKey pumps the HighGUI event loop (keeps the window responsive) and
    # returns the pressed key. 1 ms instead of 33 ms avoids stalling every
    # training iteration.
    if cv2.waitKey(1) == 27:  # Esc
        _show_previews = False
        cv2.destroyAllWindows()
# for i in range(20):
#     imgt = Image.open(fls[i])
#     cv2.imshow("preview some", cv2.cvtColor(np.array(imgt), cv2.COLOR_BGR2RGB))
#     time.sleep(1)
#     k = cv2.waitKey(33)
#     if k==27:    # Esc key to stop
#         break

In [None]:
class Faceset(Dataset):
    """Paired face dataset: item ``i`` is (A image ``i``, B image ``i``).

    Images are discovered under ``<root>/A/`` and ``<root>/B/``. The number of
    pairs is limited by the smaller of the two folders.

    Args:
        root: directory containing the ``A`` and ``B`` sub-folders.
        transforms: callable applied to each loaded PIL image.
    """

    def __init__(self, root, transforms):
        self.transforms = transforms
        # glob() returns files in a filesystem-dependent order; sort so the
        # index -> file mapping (and therefore the A/B pairing) is reproducible.
        self.imgs_A = sorted(glob(os.path.join(root, "A", "*")))
        self.imgs_B = sorted(glob(os.path.join(root, "B", "*")))

    def __len__(self):
        # Only as many pairs as the smaller identity folder can provide.
        return min(len(self.imgs_A), len(self.imgs_B))

    def _load(self, path):
        """Open an image, force 3-channel RGB, close the file, apply transforms."""
        # convert() forces the lazy load, so closing the file afterwards is safe.
        # RGB conversion keeps grayscale/RGBA/palette files from yielding
        # tensors with a different channel count than the rest of the batch.
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        return self.transforms(rgb)

    def __getitem__(self, idx):
        return self._load(self.imgs_A[idx]), self._load(self.imgs_B[idx])

In [None]:
# Number of full passes over the A/B training pairs.
epochs: int = 5
# change_lr_each = 100  # epochs (for the currently disabled LR scheduler)

In [None]:
# Name the pipeline something other than `transforms`. Assigning to
# `transforms` would shadow the torchvision.transforms module, so re-running
# this cell would fail with AttributeError on `transforms.Compose`.
preprocess = transforms.Compose([
    transforms.Resize(64),      # shorter side -> 64 px
    transforms.CenterCrop(64),  # square 64x64 crop
    transforms.ToTensor(),      # PIL image -> CHW float tensor
])

dataset = Faceset('./data/faces/', preprocess)
# NOTE(review): with num_workers > 0 on spawn-based platforms (Windows/macOS),
# workers may be unable to import a Dataset class defined inside the notebook;
# move Faceset into a .py module if worker startup fails.
loader = DataLoader(dataset,
                    batch_size=64,
                    shuffle=True,
                    num_workers=2)

In [None]:
len(dataset)  # number of A/B image pairs per epoch (min of the two folder sizes)

In [None]:
model = FaceSwapper().to(device)

# Reconstruction loss. Alternatives the author noted trying: nn.BCELoss(), RMSE.
criterion = nn.SmoothL1Loss()

# One Adam optimizer per identity: each holds the *shared* encoder's parameters
# plus that identity's own decoder, so each should be zeroed/stepped only for
# its own identity's loss.
optimizerA, optimizerB = (
    optim.Adam(
        [{'params': model.encoder.parameters()}, {'params': decoder.parameters()}],
        lr=5e-5,
        betas=(0.5, 0.999),
    )
    for decoder in (model.decoderA, model.decoderB)
)

# scheduler = optim.lr_scheduler.StepLR(optimizerA, step_size=change_lr_each, gamma=0.1)

In [None]:
# load weights

In [None]:
plotA = []  # per-iteration loss of the A branch (encoder + decoderA)
plotB = []  # per-iteration loss of the B branch (encoder + decoderB)

trained_models = "./models"
os.makedirs(trained_models, exist_ok=True)

model.train()
for epoch in range(epochs):
    pbar = tqdm(enumerate(loader), total=len(loader), leave=True)
    for i, (imgsA, imgsB) in pbar:
        imgsA = imgsA.to(device)
        imgsB = imgsB.to(device)

        # Update A and B one after the other. Both optimizers contain the
        # shared encoder. Back-propagating both losses and then calling
        # optimizerA.step() and optimizerB.step() would apply the *summed*
        # encoder gradient twice per iteration (once per Adam state).
        optimizerA.zero_grad()
        outputsA = model(imgsA, "A")
        lossA = criterion(outputsA, imgsA)
        lossA.backward()
        optimizerA.step()

        optimizerB.zero_grad()
        outputsB = model(imgsB, "B")
        lossB = criterion(outputsB, imgsB)
        lossB.backward()
        optimizerB.step()

        # .item() synchronizes with the device; read each loss once.
        loss_a, loss_b = lossA.item(), lossB.item()
        plotA.append(loss_a)
        plotB.append(loss_b)

        training_window(imgsA, imgsB, outputsA, outputsB, epoch, i)

        pbar.set_description(
            f"Epoch {epoch + 1}/{epochs} | Loss A: {loss_a:.5f}, Loss B: {loss_b:.5f}"
        )

    # TODO: step LR schedulers here if they are re-enabled.

    torch.save(model.state_dict(), f"{trained_models}/model_{epoch}.pth")


In [None]:
# Per-batch training loss of both autoencoder branches across all epochs.
fig, ax = plt.subplots()
ax.plot(plotA, label="Loss model A")
ax.plot(plotB, label="Loss model B")
ax.set(title="Training loss", xlabel="Iteration (batches, all epochs)", ylabel="Reconstruction loss")
ax.legend()
plt.show()