In [None]:
from torch import nn
import random
from tqdm import tqdm
import torch,torchvision
from torch.nn import *
from torch.optim import *
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import wandb
import os
# Seed every RNG the notebook draws from (Python, NumPy, PyTorch) from one constant so a
# run starts from the same random state each time.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
torch.manual_seed(SEED)
PROJECT_NAME = 'Fake-Face-V1'  # Weights & Biases project passed to wandb.init
# Fall back to the CPU so the notebook still runs (slowly) on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
IMG_SIZE = 224  # every image is resized to IMG_SIZE x IMG_SIZE before training

In [None]:
# transformations = torchvision.transforms.Compose(
# [torchvision.transforms.ToTensor(),torchvision.transforms.Normalize(0.4,0.4)]
# )
# Per-image preprocessing used by load_data. It is applied once, when the images are loaded,
# so the ColorJitter below is a fixed one-off perturbation rather than per-epoch augmentation.
# Normalize maps ToTensor's [0, 1] range to [-1, 1], matching the generator's Tanh output;
# without it the discriminator can separate real from fake images by sign alone.
transformations = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.ColorJitter(0.5, 0.5, 0.5, 0.5),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

In [None]:
def load_data(data_dir='./data/',transformations=transformations):
    """Load every readable image in ``data_dir`` as a transformed numpy array.

    Each image is read with OpenCV, converted from OpenCV's BGR channel order to RGB,
    resized to ``IMG_SIZE`` x ``IMG_SIZE`` and passed through ``transformations``.

    Args:
        data_dir: directory containing the image files (trailing separator optional).
        transformations: callable applied to each RGB uint8 image; its result is converted
            with ``np.array``.

    Returns:
        List of numpy arrays, one per readable file, in sorted file-name order.
        Files OpenCV cannot decode (non-images, sub-directories) are skipped with a message.
    """
    data = []
    # os.listdir order is arbitrary; sorting makes the image order (and any later subset) reproducible.
    for file_name in tqdm(sorted(os.listdir(data_dir))):
        path = os.path.join(data_dir, file_name)
        img = cv2.imread(path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            tqdm.write(f'skipping unreadable file: {path}')
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))
        data.append(np.array(transformations(img)))
    return data

In [None]:
# Read the whole image directory (the function's default location) into a list of arrays.
data = load_data('./data/')

In [None]:
# Stack the equally-sized per-image arrays into one dense array, then wrap it as a tensor
# (from_numpy shares memory with that array instead of copying it again).
data = torch.from_numpy(np.asarray(data))

In [None]:
# torch.save(data,'data.pt')
# torch.save(data,'data.pth')

In [None]:
# data = torch.load('data.pt')

In [None]:
# data[0] is channels-first (3, H, W), as produced by ToTensor; matplotlib wants (H, W, 3).
# .view() would only reinterpret the memory and scramble the pixels, while .permute() moves the axes.
# The channel order is whatever load_data stored (OpenCV's BGR unless it converts to RGB).
sample_img = data[0].permute(1, 2, 0)
if sample_img.min() < 0:
    # The transform pipeline may rescale images to [-1, 1]; map them back to [0, 1] for display.
    sample_img = (sample_img + 1) / 2
plt.imshow(sample_img.clamp(0, 1).numpy())
plt.axis('off');

In [None]:
class Desc(nn.Module):
    """Fully-connected discriminator: flattened RGB image -> probability that it is real.

    Architecture: four Linear -> BatchNorm1d -> activation -> Dropout stages of widths
    ``starter``, 2x, 4x, 2x ``starter``, followed by a Linear to one unit and a Sigmoid.

    Args:
        activation: activation class; one instance is shared by all hidden stages.
        starter: width of the first hidden layer.
        img_size: side length of the square input image. ``None`` (default) uses the
            notebook-level ``IMG_SIZE``.
        dropout: dropout probability applied after every hidden stage (PyTorch's default 0.5).
    """
    def __init__(self,activation=nn.LeakyReLU,starter=128,img_size=None,dropout=0.5):
        super().__init__()
        img_size = IMG_SIZE if img_size is None else img_size
        self.dropout = Dropout(dropout)
        self.activation = activation()
        self.linear1 = Linear(img_size*img_size*3,starter)
        self.linear1batchnorm = BatchNorm1d(starter)
        self.linear2 = Linear(starter,starter*2)
        self.linear2batchnorm = BatchNorm1d(starter*2)
        self.linear3 = Linear(starter*2,starter*4)
        self.linear3batchnorm = BatchNorm1d(starter*4)
        self.linear4 = Linear(starter*4,starter*2)
        self.linear4batchnorm = BatchNorm1d(starter*2)
        self.output = Linear(starter*2,1)
        # Outputs probabilities, to pair with BCELoss (BCEWithLogitsLoss would need raw logits).
        self.output_activation = Sigmoid()
    
    def forward(self,X):
        """Score a batch of flattened images.

        Args:
            X: tensor of shape (batch, img_size*img_size*3).

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        preds = self.dropout(self.activation(self.linear1batchnorm(self.linear1(X))))
        preds = self.dropout(self.activation(self.linear2batchnorm(self.linear2(preds))))
        preds = self.dropout(self.activation(self.linear3batchnorm(self.linear3(preds))))
        preds = self.dropout(self.activation(self.linear4batchnorm(self.linear4(preds))))
        preds = self.output_activation(self.output(preds))
        return preds
        
class Gen(nn.Module):
    """Fully-connected generator: latent noise vector -> flattened RGB image in [-1, 1].

    Architecture: five Linear -> BatchNorm1d -> activation stages of widths ``starter``,
    2x, 4x, 4x, 2x ``starter``, followed by a Linear to img_size*img_size*3 units and a Tanh.

    Args:
        z_dim: length of the input noise vector.
        activation: activation class; one instance is shared by all hidden stages.
        starter: width of the first hidden layer.
        img_size: side length of the square output image. ``None`` (default) uses the
            notebook-level ``IMG_SIZE``.
    """
    def __init__(self,z_dim,activation=nn.LeakyReLU,starter=512,img_size=None):
        super().__init__()
        img_size = IMG_SIZE if img_size is None else img_size
        self.activation = activation()
        self.linear1 = Linear(z_dim,starter)
        self.linear1batchnorm = BatchNorm1d(starter)
        self.linear2 = Linear(starter,starter*2)
        self.linear2batchnorm = BatchNorm1d(starter*2)
        self.linear3 = Linear(starter*2,starter*4)
        self.linear3batchnorm = BatchNorm1d(starter*4)
        self.linear4 = Linear(starter*4,starter*4)
        self.linear4batchnorm = BatchNorm1d(starter*4)
        self.linear5 = Linear(starter*4,starter*2)
        self.linear5batchnorm = BatchNorm1d(starter*2)
        self.output = Linear(starter*2,img_size*img_size*3)
        self.output_activation = Tanh()
 
    def forward(self, X):
        """Map a batch of noise vectors to flattened images.

        Args:
            X: tensor of shape (batch, z_dim). BatchNorm1d in training mode needs batch > 1.

        Returns:
            Tensor of shape (batch, img_size*img_size*3) with values in [-1, 1].
        """
        preds = self.activation(self.linear1batchnorm(self.linear1(X)))
        preds = self.activation(self.linear2batchnorm(self.linear2(preds)))
        preds = self.activation(self.linear3batchnorm(self.linear3(preds)))
        preds = self.activation(self.linear4batchnorm(self.linear4(preds)))
        preds = self.activation(self.linear5batchnorm(self.linear5(preds)))
        preds = self.output_activation(self.output(preds))
        return preds

In [None]:
# Keep only the first 1056 images (33 batches of 32) to bound training time and memory.
data = data[:1056]

In [None]:
# Hyper-parameters of the baseline run.
z_dim = 64
lr = 3e-4
batch_size = 32
epochs = 250

# Construction order (generator, then discriminator, then the fixed noise) is what consumes
# the seeded RNG, so it is kept as-is to reproduce the same initial weights and noise.
gen = Gen(z_dim).to(device)
desc = Desc().to(device)

optimizer_gen = Adam(gen.parameters(), lr=lr)
optimizer_desc = Adam(desc.parameters(), lr=lr)
criterion = BCELoss()

# Sampled once here; fed to the generator after training to build the logged image grid.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

In [None]:
# Baseline GAN training run, logged to Weights & Biases.
name = 'baseline-0.5'
wandb.init(project=PROJECT_NAME,name=name)


def _train_step(real_batch):
    """Run one discriminator update followed by one generator update.

    Uses the notebook-level ``gen``, ``desc``, ``optimizer_gen``, ``optimizer_desc``,
    ``criterion``, ``z_dim`` and ``device``.

    Args:
        real_batch: real images flattened to (n, IMG_SIZE*IMG_SIZE*3), already on ``device``.

    Returns:
        Tuple ``(lossD, lossG)`` of Python floats.
    """
    noise = torch.randn(real_batch.shape[0], z_dim).to(device)
    fake = gen(noise)

    # Discriminator: real -> 1, fake -> 0. Detaching ``fake`` stops this backward pass at the
    # discriminator. The generator's graph therefore no longer has to be retained, and no
    # throw-away gradients are written into the generator.
    desc_real = desc(real_batch).view(-1)
    lossD_real = criterion(desc_real, torch.ones_like(desc_real))
    desc_fake = desc(fake.detach()).view(-1)
    lossD_fake = criterion(desc_fake, torch.zeros_like(desc_fake))
    lossD = (lossD_real + lossD_fake) / 2
    desc.zero_grad()
    lossD.backward()
    optimizer_desc.step()

    # Generator: re-score the same fakes with the just-updated discriminator, targeting "real".
    output = desc(fake).view(-1)
    lossG = criterion(output, torch.ones_like(output))
    gen.zero_grad()
    lossG.backward()
    optimizer_gen.step()
    return lossD.item(), lossG.item()


n_features = IMG_SIZE * IMG_SIZE * 3
max_reported_errors = 5  # avoid flooding the output if every batch fails the same way
n_skipped = 0
epochs_iter = tqdm(range(epochs))
try:
    for _ in epochs_iter:
        torch.cuda.empty_cache()
        # Stride by batch_size instead of a literal 32. The current slice size stays local, so
        # a short final slice can no longer overwrite the global batch_size for later epochs.
        for idx in range(0, len(data), batch_size):
            data_batch = (
                torch.as_tensor(np.asarray(data[idx:idx + batch_size]))
                .reshape(-1, n_features)
                .float()
                .to(device)
            )
            if data_batch.shape[0] < 2:
                # BatchNorm1d cannot compute batch statistics from one sample in training mode.
                n_skipped += 1
                continue
            try:
                lossD, lossG = _train_step(data_batch)
            except Exception as err:
                # Still best-effort, but no longer silent: count the skipped batch and report it.
                n_skipped += 1
                if n_skipped <= max_reported_errors:
                    tqdm.write(f'skipped batch at index {idx}: {err!r}')
                continue
            # Single call per step, so lossD and lossG share one W&B step
            # (each wandb.log call advances the step counter by default).
            wandb.log({'lossD': lossD, 'lossG': lossG})
            epochs_iter.set_postfix(lossD=lossD, lossG=lossG)
    if n_skipped:
        tqdm.write(f'{n_skipped} batch(es) skipped in total')
    with torch.no_grad():
        fake = gen(fixed_noise).view(-1,3,IMG_SIZE,IMG_SIZE)
        img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
        wandb.log({'img':wandb.Image(img_grid_fake)})
finally:
    # Close the W&B run even if training raises or is interrupted, so it is not left open.
    wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mranuga-d[0m (use `wandb login --relogin` to force relogin)


 24%|██▍       | 61/250 [04:18<13:30,  4.29s/it]

In [15]:
# gen_starters = [1024,2048,2048*2,2048*4]
# gen_activations = [ELU,LeakyReLU,PReLU,ReLU,ReLU6,RReLU,SELU,GELU,SiLU,Tanh]
# desc_activations = [ELU,LeakyReLU,PReLU,ReLU,ReLU6,RReLU,SELU,GELU,SiLU,Tanh]
# lrs = [8e-1,8e-2,8e-3,,8e-5]
# then lrs = [whateverbest_lr-1,whateverbest_lr-2,whateverbest_lr-3,whateverbest_lr-4,whateverbest_lr-5]
# batch_sizes = [8,16,32,64,128,256,512]
# optimizers_gen = [Adam,AdamW,Adamax,ASGD,SGD,Rprop,RMSprop]
# optimizers_desc = [Adam,AdamW,Adamax,ASGD,SGD,Rprop,RMSprop]
# criterions = [BCELoss(),MSELoss(),L1Loss()]

In [16]:
# for optimizer_gen in optimizers_gen:
#     z_dim = 256
#     gen = Gen(z_dim,starter=1024,activation=LeakyReLU).to(device)
#     desc = Desc(activation=SiLU).to(device)
#     lr = 8e-4
#     batch_size = 32
#     epochs = 100
#     optimizer_gen = optimizer_gen(gen.parameters(),lr=lr)
#     optimizer_desc = Adam(desc.parameters(),lr=lr)
#     criterion = BCELoss()
#     fixed_noise = torch.randn((batch_size,z_dim)).to(device)
#     name = f'{optimizer_gen}-optimizer_gen'
#     wandb.init(project=PROJECT_NAME,name=name)
#     epochs_iter = tqdm(range(epochs),desc='Bar desc')
#     for _ in epochs_iter:
#         torch.cuda.empty_cache()
#         for idx in range(0,len(data),batch_size):
#             try:
#                 torch.cuda.empty_cache()
#                 data_batch = data[idx:idx+batch_size].view(-1,IMG_SIZE*IMG_SIZE*3).float().to(device)
#                 batch_size = data_batch.shape[0]
#                 noise = torch.randn((batch_size,z_dim)).to(device)
#                 fake = gen(noise)
#                 desc_fake = desc(fake).view(-1)
#                 lossD_fake = criterion(desc_fake,torch.zeros_like(desc_fake))
#                 desc_real = desc(data_batch).view(-1)
#                 lossD_real = criterion(desc_real,torch.ones_like(desc_real))
#                 lossD = (lossD_fake/lossD_real)/2
#                 desc.zero_grad()
#                 lossD.backward(retain_graph=True)
#                 optimizer_desc.step()
#                 output = desc(fake).view(-1)
#                 lossG = criterion(output,torch.ones_like(output))
#                 gen.zero_grad()
#                 lossG.backward()
#                 optimizer_gen.step()
#             except Exception as e:
#                 pass
#         epochs_iter.set_description(f'lossD - {lossD.item()} | lossG - {lossG.item()}')
#         wandb.log({'lossD':lossD.item()})
#         wandb.log({'lossG':lossG.item()})
#     with torch.no_grad():
#         fake = gen(fixed_noise).view(-1,3,IMG_SIZE,IMG_SIZE)
#         img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
#         wandb.log({'img':wandb.Image(img_grid_fake)})
#     wandb.finish()

In [17]:
def _accuracy(scores, target):
    """Return the fraction of rounded discriminator scores equal to ``target``, to 3 decimals.

    Args:
        scores: tensor of discriminator outputs of any shape (flattened before comparison).
            Presumably sigmoid probabilities in [0, 1]; TODO confirm at call sites.
        target: label counted as correct (0 for fake, 1 for real).

    Returns:
        Python float in [0, 1], or ``nan`` when ``scores`` is empty (accuracy is undefined).
    """
    preds = np.round(scores.detach().cpu().numpy()).reshape(-1)
    if preds.size == 0:
        return float('nan')
    # np.round rounds half to even, so a score of exactly 0.5 becomes 0 (a "fake" prediction).
    return round(float(np.mean(preds == target)), 3)


def accuracy_fake(desc_fake):
    """Fraction of discriminator scores on fake images that round to 0 (correctly "fake")."""
    return _accuracy(desc_fake, 0)


def accuracy_real(desc_real):
    """Fraction of discriminator scores on real images that round to 1 (correctly "real")."""
    return _accuracy(desc_real, 1)