In [1]:
import random
import torch,torchvision
import cv2
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import wandb
import numpy as np
from torch.nn import *
from torch.optim import *
from torchvision.transforms import *
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

In [2]:
IMG_SIZE = 224
PROJECT_NAME = 'Fake-Face-V2'
device = torch.device('cuda')

In [3]:
transformation = Compose([
    ToTensor(),
])

In [4]:
def load_data(data_dir='./data/',transformation=transformation):
    data = []
    idx = -1
    for file in tqdm(os.listdir(data_dir)):
        idx += 1
        file = data_dir + file
        img = cv2.imread(file)
        img = cv2.resize(img,(IMG_SIZE,IMG_SIZE))
        data.append(np.array(transformation(np.array(img))))
    np.random.shuffle(data)
    data = torch.from_numpy(np.array(data))
    print(idx)
    return data

In [5]:
data = load_data()

100%|██████████| 1081/1081 [00:08<00:00, 126.84it/s]


1080


In [6]:
from torch import nn

In [7]:
class Desc(Module):
    def __init__(self,activation=LeakyReLU,starter=128):
        super().__init__()
        self.dropout = Dropout()
        self.activation = activation()
        self.linear1 = Linear(IMG_SIZE*IMG_SIZE*3,starter)
        self.linear1batchnorm = BatchNorm1d(starter)
        self.linear2 = Linear(starter,starter*2)
        self.linear2batchnorm = BatchNorm1d(starter*2)
        self.linear3 = Linear(starter*2,starter*4)
        self.linear3batchnorm = BatchNorm1d(starter*4)
        self.linear4 = Linear(starter*4,starter*8)
        self.linear4batchnorm = BatchNorm1d(starter*8)
        self.linear5 = Linear(starter*8,starter*4)
        self.linear5batchnorm = BatchNorm1d(starter*4)
        self.output = Linear(starter*4,1)
        self.output_activation = Sigmoid()
    
    def forward(self,X):
        X = X.view(-1,IMG_SIZE*IMG_SIZE*3)
        preds = self.dropout(self.activation(self.linear1batchnorm(self.linear1(X))))
        preds = self.dropout(self.activation(self.linear2batchnorm(self.linear2(preds))))
        preds = self.dropout(self.activation(self.linear3batchnorm(self.linear3(preds))))
        preds = self.dropout(self.activation(self.linear4batchnorm(self.linear4(preds))))
        preds = self.dropout(self.activation(self.linear5batchnorm(self.linear5(preds))))
        preds = self.output_activation(self.output(preds))
        return preds
        
class Gen(Module):
    def __init__(self,z_dim,activation=LeakyReLU,starter=512):
        super().__init__()
        self.dropout = Dropout()
        self.activation = activation()
        self.linear1 = Linear(z_dim,starter)
        self.linear1batchnorm = BatchNorm1d(starter)
        self.linear2 = Linear(starter,starter*2)
        self.linear2batchnorm = BatchNorm1d(starter*2)
        self.linear3 = Linear(starter*2,starter*4)
        self.linear3batchnorm = BatchNorm1d(starter*4)
        self.linear4 = Linear(starter*4,starter*8)
        self.linear4batchnorm = BatchNorm1d(starter*8)
        self.linear5 = Linear(starter*8,starter*4)
        self.linear5batchnorm = BatchNorm1d(starter*4)
        self.output = Linear(starter*4,IMG_SIZE*IMG_SIZE*3)
        self.output_activation = Tanh()
 
    def forward(self, X):
        preds = self.dropout(self.activation(self.linear1batchnorm(self.linear1(X))))
        preds = self.dropout(self.activation(self.linear2batchnorm(self.linear2(preds))))
        preds = self.dropout(self.activation(self.linear3batchnorm(self.linear3(preds))))
        preds = self.dropout(self.activation(self.linear4batchnorm(self.linear4(preds))))
        preds = self.dropout(self.activation(self.linear5batchnorm(self.linear5(preds))))
        preds = self.output_activation(self.output(preds))
        return preds

In [8]:
z_dim = 64
batch_size = 32
desc = Desc().to(device)
gen = Gen(z_dim).to(device)
lr = 3e-4
epochs = 56
criterion = BCELoss()
optimizer_gen = Adam(gen.parameters(),lr=lr)
optimizer_desc = Adam(desc.parameters(),lr=lr)
fixed_noise = torch.randn((batch_size,z_dim)).to(device)

In [9]:
data = data[0:1056]

In [10]:
def accuracy_fake(desc_fake):
    correct = 0
    total = 0
    preds = np.round(np.array(desc_fake.cpu().detach().numpy()))
    for pred in preds:
        if pred == 0:
            correct += 1
        total += 1
    return round(correct/total,3)
def accuracy_real(desc_real):
    correct = 0
    total = 0
    preds = np.round(np.array(desc_real.cpu().detach().numpy()))
    for pred in preds:
        if pred == 1:
            correct += 1
        total += 1
    return round(correct/total,3)

In [11]:
transformation = Compose([
    ToTensor(),
])
data = load_data(transformation=transformation)
data = data[0:1056]
wandb.init(project=PROJECT_NAME,name='transformation-None')
for epoch in tqdm(range(epochs)):
    for idx in range(0,len(data),batch_size):
#         try:
            data_batch = data[idx:idx+batch_size].view(-1,3,IMG_SIZE,IMG_SIZE).to(device).float()
            batch_size = data_batch.shape[0]
            noise = torch.randn(batch_size,z_dim).to(device)
            fake = gen(noise)
            fake.to(device)
            desc_fake = desc(fake).view(-1)
            lossD_fake = criterion(desc_fake,torch.zeros_like(desc_fake))
            desc_real = desc(data_batch).view(-1)
            lossD_real = criterion(desc_real,torch.ones_like(desc_real))
            lossD = (lossD_real+lossD_fake)/2
            optimizer_desc.zero_grad()
            lossD.backward(retain_graph=True)
            optimizer_desc.step()
            output = desc(fake).view(-1)
            lossG = criterion(output,torch.ones_like(output))
            optimizer_gen.zero_grad()
            lossG.backward()
            optimizer_gen.step()
#         except Exception as e:
#             print(e)
    wandb.log({'lossD':lossD.item()})
    wandb.log({'lossG':lossG.item()})
    wandb.log({'lossD Fake':lossD_fake.item()})
    wandb.log({'lossD Real':lossD_real.item()})
    wandb.log({'Acc Fake':accuracy_fake(desc_fake)})
    wandb.log({'Acc Real':accuracy_real(desc_real)})
    with torch.no_grad():
        fake = gen(fixed_noise).view(-1,3,IMG_SIZE,IMG_SIZE)
        img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
        wandb.log({'img':wandb.Image(img_grid_fake)})
wandb.finish()

100%|██████████| 1081/1081 [00:07<00:00, 137.11it/s]


1080


[34m[1mwandb[0m: W&B API key is configured (use `wandb login --relogin` to force relogin)


100%|██████████| 56/56 [05:37<00:00,  6.03s/it]


VBox(children=(Label(value=' 115.80MB of 256.74MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.4510463…

0,1
lossD,0.682
_runtime,371.0
_timestamp,1629171642.0
_step,391.0
lossG,0.82757
lossD Fake,0.70489
lossD Real,0.65911
Acc Fake,0.5
Acc Real,0.656


0,1
lossD,▆█▄▄▅▅▆▄▄▅▅▆▄▅▄▄▆▅▃▄▅▄▄▄▅▅▄▅▄▄▅▄▅▄▄▃▃▁▄▄
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
lossG,▅▃▃▅▂▄▃▃▁▂▄▃▄▄▃▂▂▂▁▄▃▂▃▃▂▂▁▂▃▁▂▄▃▃▄▂▅▃▄█
lossD Fake,▃█▃▂▄▂▄▃▂▄▃▆▄▄▄▂▅▂▁▂▄▃▃▃▄▄▃▃▃▃▄▃▄▄▃▂▂▁▆▄
lossD Real,█▇▄▆▆▆▇▄▅▆▆▄▃▅▄▆▆▆▅▅▅▅▄▅▅▄▄▅▄▄▅▅▅▃▅▄▄▁▃▃
Acc Fake,▄▂▇█▅▆▃▅▆▆▄▁▄▆▄▇▄▅▇▆▃▄▄▄▄▄▄▃▄▆▄▆▃▃▆▇▆▇▁▄
Acc Real,▄▁▅▄▂▂▃▇▄▂▃▆▇▄▆▂▂▄▃▄▄▃▄▄▃▆▅▅▆▄▆▂▄▄▅▄▆█▆▇


In [None]:
transformations = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]
for transformation in transformations:
    transformation_num = transformation
    transformation = Compose([
        ToTensor(),
        Normalize((transformation,transformation,transformation),(transformation,transformation,transformation))
    ])
    data = load_data(transformation=transformation)
    data = data[0:1056]
    wandb.init(project=PROJECT_NAME,name=f'transformation-{transformation_num}')
    for epoch in tqdm(range(epochs)):
        for idx in range(0,len(data),batch_size):
                data_batch = data[idx:idx+batch_size].view(-1,3,IMG_SIZE,IMG_SIZE).to(device).float()
                batch_size = data_batch.shape[0]
                noise = torch.randn(batch_size,z_dim).to(device)
                fake = gen(noise)
                fake.to(device)
                desc_fake = desc(fake).view(-1)
                lossD_fake = criterion(desc_fake,torch.zeros_like(desc_fake))
                desc_real = desc(data_batch).view(-1)
                lossD_real = criterion(desc_real,torch.ones_like(desc_real))
                lossD = (lossD_real+lossD_fake)/2
                optimizer_desc.zero_grad()
                lossD.backward(retain_graph=True)
                optimizer_desc.step()
                output = desc(fake).view(-1)
                lossG = criterion(output,torch.ones_like(output))
                optimizer_gen.zero_grad()
                lossG.backward()
                optimizer_gen.step()
        wandb.log({'lossD':lossD.item()})
        wandb.log({'lossG':lossG.item()})
        wandb.log({'lossD Fake':lossD_fake.item()})
        wandb.log({'lossD Real':lossD_real.item()})
        wandb.log({'Acc Fake':accuracy_fake(desc_fake)})
        wandb.log({'Acc Real':accuracy_real(desc_real)})
        with torch.no_grad():
            fake = gen(fixed_noise).view(-1,3,IMG_SIZE,IMG_SIZE)
            img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
            wandb.log({'img':wandb.Image(img_grid_fake)})
    wandb.finish()

100%|██████████| 1081/1081 [00:04<00:00, 222.64it/s]


1080


 41%|████      | 23/56 [02:19<03:24,  6.19s/it][34m[1mwandb[0m: Network error resolved after 0:00:16.187103, resuming normal operation.
 68%|██████▊   | 38/56 [03:53<01:52,  6.24s/it][34m[1mwandb[0m: Network error resolved after 0:01:17.410483, resuming normal operation.
 91%|█████████ | 51/56 [05:14<00:31,  6.25s/it][34m[1mwandb[0m: Network error resolved after 0:01:13.933779, resuming normal operation.
100%|██████████| 56/56 [05:45<00:00,  6.17s/it]


[34m[1mwandb[0m: Network error resolved after 0:01:02.640130, resuming normal operation.


VBox(children=(Label(value=' 144.49MB of 147.71MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.9781882…

0,1
lossD,0.02094
_runtime,353.0
_timestamp,1629172409.0
_step,391.0
lossG,4.07846
lossD Fake,0.01843
lossD Real,0.02346
Acc Fake,1.0
Acc Real,1.0


0,1
lossD,█▇█▇█▆▆▆█▇█▅▅▄▅▃▃▃▄▃▂▃▂▂▂▂▂▂▂▂▅▄▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
lossG,▁▁▁▁▁▁▁▂▁▂▂▂▃▂▃▃▃▃▅▄▄▅▅▄▅▃▆▅▃▄▇▆▇▇▆█▅▇▇▇
lossD Fake,█▇█▇▇▇▆▆▇▆▇▆▆▄▅▄▂▃▄▃▁▃▂▃▃▂▂▂▂▁▄▆▂▁▂▁▁▁▁▁
lossD Real,█▇▇▇▇▅▇▇▇▇█▄▄▄▅▂▄▂▄▂▂▂▂▁▂▂▂▁▂▃▅▂▁▁▁▁▁▁▁▁
Acc Fake,▁▂▁▄▄▃▆▅▃▅▄▆▆█▆▇██▆▇██████████▇▆████████
Acc Real,▃▅▇▆▄▇▅▄▅▃▁███▇███▇███████████▅█████████


100%|██████████| 1081/1081 [00:07<00:00, 136.29it/s]


1080


 96%|█████████▋| 54/56 [05:35<00:12,  6.20s/it][34m[1mwandb[0m: Network error resolved after 0:00:21.259028, resuming normal operation.
100%|██████████| 56/56 [05:47<00:00,  6.21s/it]


VBox(children=(Label(value=' 92.24MB of 117.06MB uploaded (0.00MB deduped)\r'), FloatProgress(value=0.78795254…

[34m[1mwandb[0m: [32m[41mERROR[0m Control-C detected -- Run data was not synced
100%|██████████| 1081/1081 [00:09<00:00, 117.37it/s]


1080


 41%|████      | 23/56 [02:27<03:32,  6.44s/it]