In [46]:
import numpy as np
import pandas as pd
import os
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.io import read_image
torch.cuda.empty_cache()

import csv
import glob
from PIL import Image
import tensorflow as tf

import warnings
warnings.filterwarnings("ignore")

In [None]:
!gdown --id 1Of_EVz-yHV7QVWQGihYfvtny9Ne8qXVz -O CASIA-WebFace.zip
!unzip /content/CASIA-WebFace.zip
!rm /content/CASIA-WebFace.zip

In [None]:
row = []
data = []

labels_paths = glob.glob('/content/CASIA-WebFace/' + '*')
n_labels = len(labels_paths)

for i in range(n_labels):
  for j in range(len(glob.glob(labels_paths[i] + '/*.jpg'))):
    data.append([glob.glob(labels_paths[i] + '/*.jpg')[j], labels_paths[i].split('/')[-1]])

data

In [61]:
with open('data.csv', 'w') as f:
  writer = csv.writer(f)
  writer.writerow(['image', 'label'])
  writer.writerows(data)

In [64]:
df = pd.read_csv('/content/data.csv')
df = df.sample(frac=1)
df.head()

Unnamed: 0,image,label
110519,/content/CASIA-WebFace/0355910/029.jpg,355910
346966,/content/CASIA-WebFace/1900981/034.jpg,1900981
328483,/content/CASIA-WebFace/1058940/041.jpg,1058940
196373,/content/CASIA-WebFace/1556320/123.jpg,1556320
247693,/content/CASIA-WebFace/0251986/071.jpg,251986


In [91]:
class CasiaDataset(Dataset):
  def __init__(self, 
               imgs_path,
               csv_file,
               transform
               ):
    self.imgs_path = imgs_path
    self.data = pd.read_csv(csv_file)
    self.transform = transform

  def __len__(self):
    return len(self.data)

  def __getitem__(self, index):
    self.data = self.data.sample(frac=1)
    
    image = read_image(self.data.iloc[index, 0])
    label = self.data.iloc[index, 1]

    if self.transform:
      image = self.transform(image)

    return image, label

In [92]:
transform = T.Compose([
                       T.Resize(128),
                       T.Grayscale()
                       ])

my_dataset = CasiaDataset('/content/CASIA-WebFace/',
                          '/content/data.csv',
                          transform)
loader = DataLoader(my_dataset, 
                    batch_size=20, 
                    shuffle=False)

In [93]:
images, labels = next(iter(loader))
# print(images.shape)
# print(len(labels))
# print(labels)

In [94]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
device

device(type='cpu')

In [120]:
class BlockDiscriminator(nn.Module):
  def __init__(self, in_ch, out_ch, stride, type):
    super().__init__()
    if type == 1:
      self.block_disc = nn.Sequential(
          nn.Conv2d(in_ch, out_ch, kernel_size=2)
      )
    elif type == 2:
      self.block_disc = nn.Sequential(
          nn.Conv2d(in_ch, out_ch, kernel_size=3, stride = stride, padding = 1),
          nn.BatchNorm2d(out_ch),
          nn.LeakyReLU()
      )
    else:
      self.block_disc = nn.Sequential(
          nn.Conv2d(in_ch, out_ch, kernel_size=3, stride = stride),
          nn.LeakyReLU()
      )
  
  def forward(self, input):
    return self.block_disc(input)

In [121]:
class Discriminator(nn.Module):
  def __init__(self):
    super().__init__()
    self.disc = nn.ModuleList([
                               BlockDiscriminator(1, 1, 2, type=1),
                               BlockDiscriminator(1, 1, 2, type=2),
                               BlockDiscriminator(1, 1, 2, type=2),
                               BlockDiscriminator(1, 1, 1, type=2),
                               BlockDiscriminator(1, 1, 1, type=3),
                               nn.Sigmoid()
    ])
  
  def forward(self, input):
    for layer in self.disc:
      output = layer(input)
      input = output
    return output


In [122]:
input = torch.randn(1, 1, 128, 128)
model = Discriminator()
output = model(input)
output.shape

torch.Size([1, 1, 30, 30])

In [123]:
## Encoder as U-Net
class Enc_block(nn.Module):
  def __init__(self, in_ch, out_ch):
    super().__init__()
    self.enc_block = nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size = 4, stride = 2, padding = 1),
        nn.BatchNorm2d(out_ch),
        nn.LeakyReLU(inplace=True),
    )

  def forward(self, input):
    return self.enc_block(input)

In [124]:
class Dec_block(nn.Module):
  def __init__(self, in_ch, out_ch):
    super().__init__()
    self.dec_block = nn.Sequential(
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size = 4, stride = 2, padding = 1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

  def forward(self, input):
    return self.dec_block(input)

In [125]:
class Unet(nn.Module):
  def __init__(self, in_ch = 1, out_ch = 1, features=[4, 8, 16, 32, 64, 128]):
    super().__init__()
    self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    self.downs = nn.ModuleList([
                                Enc_block(4, 8),
                                Enc_block(8, 16),
                                Enc_block(16, 32),
                                Enc_block(32, 64),
                                Enc_block(64, 128),
                                nn.Sequential(
                                    nn.Conv2d(128, 256, kernel_size = 4, stride = 2, padding = 1), 
                                    nn.LeakyReLU(),
                                )                                
    ])

    self.ups = nn.ModuleList([
                              nn.Sequential(
                                  nn.ConvTranspose2d(512, 128, kernel_size=3),
                                  nn.BatchNorm2d(128),
                                  nn.Dropout(0.5),
                                  nn.ReLU(inplace=True)),
                              Dec_block(256, 64),
                              Dec_block(128, 32),
                              Dec_block(64, 16),
                              Dec_block(32, 8),
                              nn.Sequential(
                                  nn.ConvTranspose2d(16, 4, kernel_size=4, stride=2, padding=1),
                                  nn.BatchNorm2d(4),
                              )
    ])
    
    self.bottleneck = nn.Sequential(
        nn.ConvTranspose2d(256, 256, kernel_size=4, padding=1),
        nn.BatchNorm2d(256),
        nn.Dropout(0.5),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2)
    )

    self.final_conv = nn.Sequential(
        nn.ConvTranspose2d(4, 1, kernel_size=5),
        nn.ReLU(inplace=True),
    )

  def forward(self, input):
    skip_connections = []
    inp_mod = nn.Conv2d(1, 4, kernel_size = 4)
    input = inp_mod(input)
    for down in self.downs:
      input = down(input)
      skip_connections.append(input)

    input = self.bottleneck(input)
    skip_connections = skip_connections[::-1] ### reversed list
    
    skip_connections_shape = []
    for i in skip_connections:
      skip_connections_shape.append(i.shape)

    for idx in range(0, len(self.ups), 1):
      skip_connection = skip_connections[idx]
      if input.shape != skip_connection.shape:
        input = TF.resize(input, size = skip_connection.shape[2:])
      concat_skip = torch.cat((skip_connection, input), dim = 1)
      input = self.ups[idx](concat_skip)
      
    return self.final_conv(input)

In [126]:
input = torch.randn(1, 1, 128, 128)
model = Unet()
output = model(input)
output.shape

torch.Size([1, 1, 128, 128])

In [127]:
discriminator = Discriminator().to(device)
generator = Unet(n_noise).to(device)

criterion = nn.BCELoss()
D_opt = torch.optim.Adam(discriminator.parameters(), lr = 0.002)
G_opt = torch.optim.Adam(generator.parameters(), lr = 0.002)

In [128]:
lr = 3e-4
batch_size = 20
num_epochs = 10
n_noise = 512
step = 0
n_critic = 1

# fixed_noise = torch.randn(batch_size, n_noise, 4, 4).to(device)

D_labels = torch.ones(batch_size, 1, 30, 30).to(device)
D_fakes = -torch.ones(batch_size, 1, 30, 30).to(device)


In [129]:
for epoch in range(num_epochs):
  for idx, (images, _) in enumerate(loader):
    x = images.float().to(device) # real
    x_output = discriminator(x) # disc of real 
    D_x_loss = criterion(x_output, D_labels) # loss disc real 

    z = torch.randn(batch_size, 1, 128, 128).to(device) # noise
    z_output = generator(z) # gen fake
    z_output = discriminator(generator(z)) # disc of fake 
    D_z_loss = criterion(z_output, D_fakes) # loss disc fake
    D_loss = D_x_loss + D_z_loss # total loss disc

    D_opt.zero_grad()
    D_loss.backward()
    D_opt.step()

    # Training generator
    if step % n_critic == 0:
      z = torch.randn(batch_size, 1, 128, 128).to(device)
      z_outputs = discriminator(generator(z))
      G_loss = criterion(z_outputs, D_labels)

      G_opt.zero_grad()
      G_loss.backward()
      G_opt.step()
        
   
  print(
      f"Epoch [{epoch} / {num_epochs}] \ "
      f"Loss D: {D_loss:.4f}, Loss G: {G_loss:.4f}"
  )

TypeError: ignored