In [1]:
import random
import torch,torchvision
import cv2
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import wandb
import numpy as np
from torch.nn import *
from torch.optim import *
from torchvision.transforms import *
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

In [2]:
IMG_SIZE = 224
PROJECT_NAME = 'Fake-Face-V2'
device = torch.device('cuda')

In [3]:
transformation = Compose([
    ToTensor(),
    RandomVerticalFlip(),
    RandomHorizontalFlip(),
])

In [4]:
def load_data(data_dir='./data/',transformation=transformation):
    data = []
    idx = -1
    for file in tqdm(os.listdir(data_dir)):
        idx += 1
        file = data_dir + file
        img = cv2.imread(file)
        img = cv2.resize(img,(IMG_SIZE,IMG_SIZE))
        data.append(np.array(transformation(np.array(img))))
    np.random.shuffle(data)
    data = torch.from_numpy(np.array(data))
    print(idx)
    return data

In [5]:
data = load_data()

100%|██████████| 1081/1081 [00:10<00:00, 100.37it/s]


1080


In [7]:
class Desc(Module):
    def __init__(self,linear_starter=128):
        super().__init__()
        self.activation = ReLU()
        self.max_pool2d = MaxPool2d((2,2),(2,2))
        self.dropout2d = Dropout2d()
        self.conv1 = Conv2d(3,8,3)
        self.conv1batchnorm = BatchNorm2d(8)
        self.conv2 = Conv2d(8,16,3)
        self.conv2batchnorm = BatchNorm2d(16)
        self.conv3 = Conv2d(16,8,3)
        self.conv3batchnorm = BatchNorm2d(8)
        self.dropout = Dropout()
        self.linear1 = Linear(8*5*5,linear_starter)
        self.linear1batchnorm = BatchNorm1d(linear_starter)
        self.linear2 = Linear(linear_starter,linear_starter*2)
        self.linear2batchnorm = BatchNorm1d(linear_starter*2)
        self.linear3 = Linear(linear_starter*2,linear_starter)
        self.linear3batchnorm = BatchNorm1d(linear_starter)
        self.output = Linear(linear_starter,1)
        self.output_activation = Sigmoid()
    
    def forward(self,X):
        preds = X.view(-1,3,IMG_SIZE,IMG_SIZE)
        preds = self.max_pool2d(self.activation(self.dropout2d(self.conv1batchnorm(self.conv1(preds)))))
        preds = self.max_pool2d(self.activation(self.dropout2d(self.conv2batchnorm(self.conv2(preds)))))
        preds = self.max_pool2d(self.activation(self.dropout2d(self.conv3batchnorm(self.conv3(preds)))))
        print(preds.shape)
        preds = preds.view(-1,8*5*5)
        preds = self.activation(self.dropout(self.linear1batchnorm(self.linear1(preds))))
        preds = self.activation(self.dropout(self.linear2batchnorm(self.linear2(preds))))
        preds = self.activation(self.dropout(self.linear3batchnorm(self.linear3(preds))))
        preds = self.output_activation(self.output(preds))
        return preds

In [8]:
class Gen(Module):
    def __init__(self,z_dim=64,linear_starter=128):
        super().__init__()
        self.activation = ReLU()
        self.dropout = Dropout()
        self.linear1 = Linear(z_dim,linear_starter)
        self.linear1batchnorm = BatchNorm1d(linear_starter)
        self.linear2 = Linear(linear_starter,linear_starter*2)
        self.linear2batchnorm = BatchNorm1d(linear_starter*2)
        self.linear3 = Linear(linear_starter*2,linear_starter*3)
        self.linear3batchnorm = BatchNorm1d(linear_starter*3)
        self.linear4 = Linear(linear_starter*3,linear_starter*4)
        self.linear4batchnorm = BatchNorm1d(linear_starter*4)
        self.linear5 = Linear(linear_starter*4,linear_starter*5)
        self.linear5batchnorm = BatchNorm1d(linear_starter*5)
        self.linear6 = Linear(linear_starter*5,linear_starter*4)
        self.linear6batchnorm = BatchNorm1d(linear_starter*4)
        self.output = Linear(linear_starter*4,IMG_SIZE*IMG_SIZE*3)
        self.output_activation = Tanh()
    
    def forward(self,X):
        preds = self.activation(self.dropout(self.linear1batchnorm(self.linear1(X))))
        preds = self.activation(self.dropout(self.linear2batchnorm(self.linear2(preds))))
        preds = self.activation(self.dropout(self.linear3batchnorm(self.linear3(preds))))
        preds = self.activation(self.dropout(self.linear4batchnorm(self.linear4(preds))))
        preds = self.activation(self.dropout(self.linear5batchnorm(self.linear5(preds))))
        preds = self.activation(self.dropout(self.linear6batchnorm(self.linear6(preds))))
        preds = self.output_activation(self.output(preds))
        return preds

In [9]:
z_dim = 64
batch_size = 32
desc = Desc().to(device)
gen = Gen().to(device)
lr = 3e-4
epochs = 125
criterion = BCELoss()
optimizer_gen = Adam(gen.parameters(),lr=lr)
optimizer_desc = Adam(desc.parameters(),lr=lr)
fixed_noise = torch.randn((batch_size,z_dim)).to(device)

In [10]:
def accuracy_fake(desc_fake):
    correct = 0
    total = 0
    preds = np.round(np.array(desc_fake.cpu().detach().numpy()))
    for pred in preds:
        if pred == 0:
            correct += 1
        total += 1
    return round(correct/total,3)
def accuracy_real(desc_real):
    correct = 0
    total = 0
    preds = np.round(np.array(desc_real.cpu().detach().numpy()))
    for pred in preds:
        if pred == 1:
            correct += 1
        total += 1
    return round(correct/total,3)

In [11]:
wandb.init(project=PROJECT_NAME,name='test')
for epoch in tqdm(range(epochs)):
    for idx in range(0,len(data),batch_size):
        data_batch = data[idx:idx+batch_size].view(-1,3,IMG_SIZE,IMG_SIZE).to(device).float()
        batch_size = data_batch.shape[0]
        noise = torch.randn((batch_size,z_dim)).to(device)
        fake = gen(noise)
        fake.to(device)
        desc_fake = desc(fake).view(-1)
        lossD_fake = criterion(desc_fake,torch.zeros_like(desc_fake))
        desc_real = desc(data_batch).view(-1)
        lossD_real = criterion(desc_real,torch.ones_like(desc_real))
        lossD = (lossD_real+lossD_fake)/2
        optimizer_desc.zero_grad()
        lossD.backward(retain_graph=True)
        optimizer_desc.step()
        output = gen(fake).view(-1)
        lossG = criterion(output,torch.ones_like(output))
        optimizer_gen.zero_grad()
        lossG.backward()
        optimizer_gen.step()
        wandb.log({'lossD':lossD.item()})
        wandb.log({'lossG':lossG.item()})
        wandb.log({'lossD Fake':lossD_fake.item()})
        wandb.log({'lossD Real':lossD_real.item()})
        wandb.log({'Acc Fake':accuracy_fake(desc_fake)})
        wandb.log({'Acc Real':accuracy_real(desc_real)})
with torch.no_grad():
    fake = gen(noise).view(-1,3,IMG_SIZE,IMG_SIZE)
    wandb.log({'img':wandb.Image(img_grid_fake)})
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mranuga-d[0m (use `wandb login --relogin` to force relogin)


  0%|          | 0/125 [00:00<?, ?it/s]


RuntimeError: shape '[-1, 3, 224, 224]' is invalid for input of size 32