In [1]:
import random
from tqdm import tqdm
import torch,torchvision
from torch.nn import *
from torch.optim import *
import numpy as np
import pandas as pd
import cv2
import matplotlib.pyplot as plt
import wandb
import os
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)
PROJECT_NAME = 'Fake-Face-V1'
device = 'cuda'
IMG_SIZE = 224

In [2]:
transformations = torchvision.transforms.Compose(
[torchvision.transforms.ToTensor()]
)

In [3]:
def load_data(data_dir='./data/'):
    idx = -1
    data = []
    for file in tqdm(os.listdir(data_dir)):
        idx += 1
        file = data_dir + file
        img = cv2.imread(file)
        img = cv2.resize(img,(IMG_SIZE,IMG_SIZE))
        data.append(np.array(transformations(np.array(img))))
    return data

In [4]:
# data = load_data()

In [5]:
# data = torch.from_numpy(np.array(data))

In [6]:
# torch.save(data,'data.pt')
# torch.save(data,'data.pth')

In [7]:
data = load_data()

In [8]:
data = torch.from_numpy(np.array(data))

In [9]:
# torch.save(data,'data.pt')
# torch.save(data,'data.pth')

In [10]:
# data = torch.load('data.pt')

In [11]:
plt.imshow(data[0])

In [12]:
plt.imshow(data[0].view(224,224,3))

<matplotlib.image.AxesImage at 0x7f9cad2dfdd0>

In [13]:
plt.imshow(data[0].view(IMG_SIZE,IMG_SIZE,3))

<matplotlib.image.AxesImage at 0x7f9cad24d090>

In [14]:
class Desc(Module):
    def __init__(self,starter=512,activation=ReLU()):
        super().__init__()
        self.dropout = Dropout()
        self.activation = activation
        self.linear1 = Linear(IMG_SIZE*IMG_SIZE*3,starter)
        self.linear1batchnorm = BatchNorm1d(starter)
        self.linear2 = Linear(starter,starter*2)
        self.linear2batchnorm = BatchNorm1d(starter*2)
        self.linear3 = Linear(starter*2,starter*2*2)
        self.linear3batchnorm = BatchNorm1d(starter*2*2)
        self.linear4 = Linear(starter*2*2,starter*2)
        self.linear4batchnorm = BatchNorm1d(starter*2)
        self.output = Linear(starter*2,1)
        self.output_activation = Sigmoid()
        
    def forward(self,X):
        preds = self.dropout(self.activation(self.linear1batchnorm(self.linear1(X))))
        preds = self.dropout(self.activation(self.linear2batchnorm(self.linear2(preds))))
        preds = self.dropout(self.activation(self.linear3batchnorm(self.linear3(preds))))
        preds = self.dropout(self.activation(self.linear4batchnorm(self.linear4(preds))))
        preds = self.output_activation(self.output(preds))
        return preds

In [15]:
class Gen(Module):
    def __init__(self,z_dim,starter=512,activation=ReLU()):
        super().__init__()
        self.dropout = Dropout()
        self.activation = activation
        self.linear1 = Linear(z_dim,starter)
        self.linear1batchnorm = BatchNorm1d(starter)
        self.linear2 = Linear(starter,starter*2)
        self.linear2batchnorm = BatchNorm1d(starter*2)
        self.linear3 = Linear(starter*2,starter*2*2)
        self.linear3batchnorm = BatchNorm1d(starter*2*2)
        self.linear4 = Linear(starter*2*2,starter*2)
        self.linear4batchnorm = BatchNorm1d(starter*2)
        self.output = Linear(starter*2,IMG_SIZE*IMG_SIZE*3)
        self.output_activation = Tanh()
        
    def forward(self,X):
        preds = self.dropout(self.activation(self.linear1batchnorm(self.linear1(X))))
        preds = self.dropout(self.activation(self.linear2batchnorm(self.linear2(preds))))
        preds = self.dropout(self.activation(self.linear3batchnorm(self.linear3(preds))))
        preds = self.dropout(self.activation(self.linear4batchnorm(self.linear4(preds))))
        preds = self.output_activation(self.output(preds))

In [16]:
z_dim = 64
gen = Gen().to(device)
desc = Desc().to(device)
lr = 3e-4
batch_size = 32
epochs = 100
optimizer_gen = Adam(gen.parameters(),lr=lr)
optimizer_desc = Adam(desc.parameters(),lr=lr)
criterion = BCELoss()
fixed_noise = torch.randn((batch_size,z_dim)).to(device)

In [17]:
z_dim = 64
gen = Gen(z_dim).to(device)
desc = Desc().to(device)
lr = 3e-4
batch_size = 32
epochs = 100
optimizer_gen = Adam(gen.parameters(),lr=lr)
optimizer_desc = Adam(desc.parameters(),lr=lr)
criterion = BCELoss()
fixed_noise = torch.randn((batch_size,z_dim)).to(device)

In [18]:
wandb.init(project=PROJECT_NAME,name='baseline')
for _ in tqdm(range(epochs)):
    for idx in range(0,len(data),batch_size):
        data_batch = data[idx:idx+batch_size].view(IMG_SIZE*IMG_SIZE*3).float().to(device)
        batch_size = data_batch.shape[0]
        noise = torch.randn((batch_size,z_dim)).to(device)
        fake = gen(noise)
        desc_fake = desc(fake)
        lossD_fake = criterion(desc_fake,torch.zeros_like(desc_fake))
        desc_real = desc(data_batch)
        lossD_real = criterion(desc_real,torch.ones_like(desc_real))
        lossD = (lossD_fake/lossD_real)/2
        desc.zero_grad()
        lossD.backward(retain_graph=True)
        wandb.log({'lossD':lossD.item()})
        optimizer_desc.step()
        output = desc(fake).view(-1)
        lossG = criterion(output,torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        wandb.log({'lossG':lossG.item()})
        optimizer_gen.step()
with torch.no_grad():
    fake = gen(fixed_noise).view(-1,3,IMG_SIZE,IMG_SIZE)
    img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
    wandb.log({'img':wandb.Image(img_grid_fake)})

In [19]:
data[idx:idx+batch_size]

tensor([[[[0.1765, 0.1176, 0.0510,  ..., 0.6000, 0.3804, 0.0980],
          [0.0667, 0.0353, 0.0353,  ..., 0.5137, 0.5804, 0.6118],
          [0.0314, 0.0157, 0.0353,  ..., 0.6863, 0.5255, 0.4431],
          ...,
          [0.0118, 0.0000, 0.0000,  ..., 0.0235, 0.0392, 0.0431],
          [0.0157, 0.0039, 0.0000,  ..., 0.0235, 0.0431, 0.0549],
          [0.0157, 0.0039, 0.0000,  ..., 0.0196, 0.0431, 0.0549]],

         [[0.1843, 0.1059, 0.0588,  ..., 0.6549, 0.4157, 0.1137],
          [0.0627, 0.0353, 0.0314,  ..., 0.5529, 0.6314, 0.6824],
          [0.0235, 0.0275, 0.0275,  ..., 0.7412, 0.5686, 0.4667],
          ...,
          [0.0118, 0.0000, 0.0000,  ..., 0.0275, 0.0392, 0.0471],
          [0.0157, 0.0039, 0.0000,  ..., 0.0314, 0.0431, 0.0588],
          [0.0157, 0.0039, 0.0000,  ..., 0.0275, 0.0431, 0.0588]],

         [[0.1882, 0.1137, 0.0667,  ..., 0.6431, 0.4392, 0.1451],
          [0.0784, 0.0510, 0.0471,  ..., 0.5412, 0.6235, 0.6588],
          [0.0392, 0.0431, 0.0431,  ..., 0

In [20]:

data[idx:idx+batch_size].shape

torch.Size([32, 3, 224, 224])

In [21]:
wandb.init(project=PROJECT_NAME,name='baseline')
for _ in tqdm(range(epochs)):
    for idx in range(0,len(data),batch_size):
        data_batch = data[idx:idx+batch_size].view(-1,IMG_SIZE*IMG_SIZE*3).float().to(device)
        batch_size = data_batch.shape[0]
        noise = torch.randn((batch_size,z_dim)).to(device)
        fake = gen(noise)
        desc_fake = desc(fake)
        lossD_fake = criterion(desc_fake,torch.zeros_like(desc_fake))
        desc_real = desc(data_batch)
        lossD_real = criterion(desc_real,torch.ones_like(desc_real))
        lossD = (lossD_fake/lossD_real)/2
        desc.zero_grad()
        lossD.backward(retain_graph=True)
        wandb.log({'lossD':lossD.item()})
        optimizer_desc.step()
        output = desc(fake).view(-1)
        lossG = criterion(output,torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        wandb.log({'lossG':lossG.item()})
        optimizer_gen.step()
with torch.no_grad():
    fake = gen(fixed_noise).view(-1,3,IMG_SIZE,IMG_SIZE)
    img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
    wandb.log({'img':wandb.Image(img_grid_fake)})