In [1]:
!pip install -q qqdm
!pip install animeface

  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.6/11.6 MB[0m [31m52.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m117.2/117.2 kB[0m [31m7.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.4/76.4 kB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m78.0/78.0 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.1/69.1 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m383.6/383.6 kB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m133.5/133.5 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.7/59.7 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00

In [2]:
!gdown --id '1IJdJcXPNN_B7mMMPQO950szHZgFVvpT_' --output "crypko_data.zip"

Downloading...
From (original): https://drive.google.com/uc?id=1IJdJcXPNN_B7mMMPQO950szHZgFVvpT_
From (redirected): https://drive.google.com/uc?id=1IJdJcXPNN_B7mMMPQO950szHZgFVvpT_&confirm=t&uuid=028d8d3f-1ec9-4366-be81-45dfc8493a12
To: /content/crypko_data.zip
100% 479M/479M [00:06<00:00, 75.9MB/s]


In [3]:
!unzip -q "crypko_data.zip" -d "./"

In [None]:
workspace_dir = '.'

import random
import torch
import numpy as np
import os
import glob
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from qqdm.notebook import qqdm
import shutil

In [None]:
def same_seeds(seed):
    """Seed every RNG the notebook uses and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch's CPU RNG, plus
    all CUDA devices when CUDA is available.  cuDNN autotuning is disabled and
    deterministic kernels are requested so repeated runs match.

    Args:
        seed: integer seed applied to all generators.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # Benchmark mode picks algorithms by timing, which can vary run to run.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

same_seeds(2023)

In [None]:
class CrypkoDataset(Dataset):
    """Map-style dataset that loads image files from a list of paths.

    Args:
        fnames: sequence of image file paths; item ``i`` loads ``fnames[i]``.
        transform: optional callable applied to the decoded image tensor.

    Each item is a float CHW tensor with 3 channels and values in [0, 1]
    (before ``transform``), or whatever ``transform`` returns.
    """

    def __init__(self, fnames, transform=None):
        self.transform = transform
        self.fnames = fnames
        self.num_samples = len(self.fnames)

    def __getitem__(self, idx):
        fname = self.fnames[idx]
        # Force 3-channel RGB decoding: without a mode, grayscale or RGBA files
        # decode to 1 or 4 channels, which the 3-channel Normalize used in
        # get_dataset and the 3-channel Discriminator input cannot handle.
        img = torchvision.io.read_image(fname, mode=torchvision.io.ImageReadMode.RGB)
        # uint8 [0, 255] -> float [0, 1]
        img = img.float() / 255.0
        if self.transform:
            img = self.transform(img)
        return img

    def __len__(self):
        return self.num_samples


def get_dataset(root):
    """Build the training dataset from every file directly under ``root``.

    Each image is resized to 64x64, randomly flipped horizontally, and
    normalised with mean 0.5 / std 0.5 per channel, mapping [0, 1] pixels to
    [-1, 1] (the same range as the Generator's final Tanh).

    Args:
        root: directory containing the face images.

    Returns:
        A CrypkoDataset over the sorted list of files in ``root``.

    Raises:
        FileNotFoundError: if ``root`` contains no matching files.
    """
    # glob's order is filesystem-dependent; sorting makes the index -> file
    # mapping (and therefore the seeded shuffle order) reproducible.
    fnames = sorted(glob.glob(os.path.join(root, '*')))
    if not fnames:
        # Fail here with a clear message instead of deep inside the DataLoader.
        raise FileNotFoundError(f'No image files found under {root!r}')
    compose = [
        transforms.ToPILImage(),
        transforms.Resize((64, 64)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
    transform = transforms.Compose(compose)
    dataset = CrypkoDataset(fnames, transform)
    return dataset

dataset = get_dataset(os.path.join(workspace_dir, 'faces'))


def visualize_samples(dataset, num_samples=16, value_range=(-1.0, 1.0)):
    """Display the first ``num_samples`` items of ``dataset`` in a 4-column grid.

    Args:
        dataset: indexable returning CHW image tensors.
        num_samples: number of leading items to show; clamped to ``len(dataset)``.
        value_range: ``(low, high)`` range of the tensor values.  The default
            matches Normalize(mean=0.5, std=0.5), which maps pixels to [-1, 1].
            Each image is linearly rescaled from this range to [0, 1] before
            plotting.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples == 0:
        return
    low, high = value_range
    # imshow clips float RGB data to [0, 1]; without rescaling, every
    # below-mid-grey pixel of a [-1, 1] image would render as black.
    images = [((dataset[i] - low) / (high - low)).clamp(0.0, 1.0)
              for i in range(num_samples)]
    grid_img = torchvision.utils.make_grid(images, nrow=4)
    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(grid_img.permute(1, 2, 0).cpu().numpy())
    ax.axis('off')
    plt.show()

visualize_samples(dataset)

def weights_init(m):
    """DCGAN-style parameter initialisation, meant to be passed to ``Module.apply``.

    * Class name contains ``'Conv'`` (Conv*, ConvTranspose*): weight ~ N(0, 0.02).
    * Class name contains ``'BatchNorm'``: weight ~ N(1, 0.02), bias = 0.
    * Anything else is left untouched.

    Modules that match by name but have no weight/bias tensor (for example
    ``BatchNorm2d(affine=False)``, or a custom container whose class name
    contains 'Conv') are skipped instead of raising AttributeError.

    Args:
        m: the module being visited.
    """
    classname = m.__class__.__name__
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)
    if classname.find('Conv') != -1:
        if weight is not None:
            nn.init.normal_(weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        if weight is not None:
            nn.init.normal_(weight.data, 1.0, 0.02)
        if bias is not None:
            nn.init.constant_(bias.data, 0)


In [None]:
class Generator(nn.Module):
    """Map a latent vector to a 3-channel image.

    A bias-free Linear + BatchNorm1d + ReLU projects ``in_dim`` to a
    ``dim*8 x 4 x 4`` feature map. Four stride-2 ConvTranspose2d stages then
    upsample it 4 -> 8 -> 16 -> 32 -> 64. The first three are followed by
    BatchNorm2d + ReLU; the last outputs 3 channels through Tanh.

    Args:
        in_dim: latent vector length.
        dim: base channel width.
    """

    def __init__(self, in_dim, dim=64):
        super().__init__()
        proj_features = dim * 8 * 4 * 4
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, proj_features, bias=False),
            nn.BatchNorm1d(proj_features),
            nn.ReLU(True),
        )
        layers = []
        for c_in, c_out in ((dim * 8, dim * 4), (dim * 4, dim * 2), (dim * 2, dim)):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # Final upsampling stage to RGB, squashed to [-1, 1].
        layers += [
            nn.ConvTranspose2d(dim, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.deconv_blocks = nn.Sequential(*layers)
        self.apply(weights_init)

    def forward(self, x):
        feats = self.l1(x).view(x.size(0), -1, 4, 4)
        return self.deconv_blocks(feats)

In [None]:
class Discriminator(nn.Module):
    """Score images with a per-sample probability in (0, 1).

    There are four stride-2 spectrally-normalised Conv2d stages, each followed
    by LeakyReLU(0.2). The last three also have BatchNorm2d before the
    activation. A final 4x4 Conv2d to one channel and a Sigmoid produce the
    score. ``forward`` flattens the output with ``view(-1)``.

    Args:
        in_channels: channels of the input images.
        dim: base channel width.
    """

    def __init__(self, in_channels=3, dim=64):
        super().__init__()
        sn = nn.utils.spectral_norm
        layers = [
            sn(nn.Conv2d(in_channels, dim, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for mult in (1, 2, 4):
            layers += [
                sn(nn.Conv2d(dim * mult, dim * mult * 2, kernel_size=4, stride=2, padding=1)),
                nn.BatchNorm2d(dim * mult * 2),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Collapse the remaining 4x4 map to a single probability.
        layers += [
            nn.Conv2d(dim * 8, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid(),
        ]
        self.conv_blocks = nn.Sequential(*layers)
        self.apply(weights_init)

    def forward(self, x):
        return self.conv_blocks(x).view(-1)

In [None]:
class EarlyStopping:
    """Signal when a monitored loss has stopped improving.

    Call the instance with the current loss (once per epoch here). The first
    call only records a baseline. After that, a call counts as an improvement
    when the loss is below the best seen so far by strictly more than
    ``min_delta``. An improvement stores the new best and resets the counter;
    anything else increments it. Once the counter reaches ``patience``,
    ``early_stop`` becomes True, and nothing resets it.

    NOTE(review): an identical EarlyStopping class is defined again in the
    training cell below; that later definition shadows this one.
    """

    def __init__(self, patience=5, min_delta=0.0):
        # Configuration.
        self.patience = patience
        self.min_delta = min_delta
        # Tracking state.
        self.counter = 0         # consecutive non-improving calls
        self.best_loss = None    # None until the first call
        self.early_stop = False

    def __call__(self, loss):
        if self.best_loss is None:
            self.best_loss = loss
            return
        improved = (self.best_loss - loss) > self.min_delta
        if not improved:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return
        self.best_loss = loss
        self.counter = 0

In [None]:
class EarlyStopping:
    """Patience-based stop flag for a loss that is expected to decrease.

    Args:
        patience: number of consecutive non-improving calls that sets
            ``early_stop``.
        min_delta: a new loss must beat the best by strictly more than this to
            count as an improvement.

    The first call only sets the baseline ``best_loss``. ``early_stop`` is
    never cleared once set.
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience, self.min_delta = patience, min_delta
        self.counter, self.best_loss, self.early_stop = 0, None, False

    def __call__(self, loss):
        if self.best_loss is None:
            # First observation becomes the baseline; nothing to compare yet.
            self.best_loss = loss
            return
        if self.best_loss - loss > self.min_delta:
            # Improvement beyond tolerance: record it and restart patience.
            self.best_loss, self.counter = loss, 0
            return
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

# ---- Training configuration ----
batch_size = 64
z_dim = 100
lr = 0.0002
n_epoch = 20
beta1 = 0.5
beta2 = 0.999
patience = 3
min_delta = 0.0001
ckpt_every = 5      # save G/D checkpoints every `ckpt_every` epochs (and on the last epoch trained)
real_target = 0.9   # one-sided label smoothing: D's target for real images only
log_dir = os.path.join(workspace_dir, 'logs')
ckpt_dir = os.path.join(workspace_dir, 'checkpoints')
os.makedirs(log_dir, exist_ok=True)
os.makedirs(ckpt_dir, exist_ok=True)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ---- Models, loss, optimisers ----
G = Generator(in_dim=z_dim).to(device)
D = Discriminator().to(device)
G.train()
D.train()
# D ends in Sigmoid, so it outputs probabilities and BCELoss is the matching loss.
criterion = nn.BCELoss()
opt_D = torch.optim.Adam(D.parameters(), lr=lr, betas=(beta1, beta2))
opt_G = torch.optim.Adam(G.parameters(), lr=lr, betas=(beta1, beta2))
# Halve both learning rates every 10 epochs.
scheduler_D = torch.optim.lr_scheduler.StepLR(opt_D, step_size=10, gamma=0.5)
scheduler_G = torch.optim.lr_scheduler.StepLR(opt_G, step_size=10, gamma=0.5)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
# NOTE(review): early stopping monitors the mean generator loss. In an
# adversarial game that loss can rise while samples improve, so treat this as
# a coarse safety stop rather than a quality criterion.
early_stopping = EarlyStopping(patience=patience, min_delta=min_delta)
steps = 0

for epoch in range(n_epoch):
    progress_bar = qqdm(dataloader, desc=f'Epoch {epoch+1}/{n_epoch}')
    epoch_loss_G = 0.0
    for i, data in enumerate(progress_bar):
        imgs = data.to(device, non_blocking=True)
        bs = imgs.size(0)
        real_labels = torch.full((bs,), real_target, device=device)
        fake_labels = torch.zeros(bs, device=device)
        # G is trained toward hard 1.0 targets. With the smoothed 0.9 target,
        # BCE is minimised at D(G(z)) = 0.9, so G would be pushed *away* from
        # fooling D once D scores its fakes above 0.9.
        gen_labels = torch.ones(bs, device=device)

        # ---- Discriminator step: real batch + detached fake batch ----
        D.zero_grad()
        outputs = D(imgs)
        loss_real = criterion(outputs, real_labels)

        z = torch.randn(bs, z_dim, device=device)
        fake_imgs = G(z)
        outputs = D(fake_imgs.detach())  # detach: this loss must not update G
        loss_fake = criterion(outputs, fake_labels)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        opt_D.step()

        # ---- Generator step on a fresh latent batch ----
        G.zero_grad()
        z = torch.randn(bs, z_dim, device=device)
        fake_imgs = G(z)
        outputs = D(fake_imgs)
        loss_G = criterion(outputs, gen_labels)
        loss_G.backward()
        opt_G.step()
        epoch_loss_G += loss_G.item()
        progress_bar.set_postfix({
            'Loss_D': loss_D.item(),
            'Loss_G': loss_G.item(),
            'Epoch': epoch+1,
            'Step': steps+1,
        })
        steps += 1

    avg_loss_G = epoch_loss_G / len(dataloader)
    print(f'\nEpoch {epoch+1}  GLoss: {avg_loss_G:.4f}')

    # ---- Save and show a 10x10 sample grid, including on the stopping epoch ----
    G.eval()
    with torch.no_grad():
        z_sample = torch.randn(100, z_dim, device=device)
        f_imgs_sample = G(z_sample).cpu()
        f_imgs_sample = (f_imgs_sample + 1) / 2.0  # Tanh [-1, 1] -> [0, 1]
        filename = os.path.join(log_dir, f'Epoch_{epoch+1:03d}.jpg')
        torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
        print(f' | Saved sample images to {filename}.')

        grid_img = torchvision.utils.make_grid(f_imgs_sample, nrow=10)
        plt.figure(figsize=(10,10))
        plt.imshow(grid_img.permute(1, 2, 0).numpy())
        plt.axis('off')
        plt.show()
    G.train()

    early_stopping(avg_loss_G)
    is_final_epoch = early_stopping.early_stop or epoch + 1 == n_epoch
    # Checkpoint periodically, and always on the last epoch actually trained,
    # so an early stop does not discard up to `ckpt_every - 1` epochs of work.
    if (epoch + 1) % ckpt_every == 0 or is_final_epoch:
        torch.save(G.state_dict(), os.path.join(ckpt_dir, f'G_epoch_{epoch+1}.pth'))
        torch.save(D.state_dict(), os.path.join(ckpt_dir, f'D_epoch_{epoch+1}.pth'))
        print(f' Saved model  {epoch+1}.')
    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

    scheduler_D.step()
    scheduler_G.step()






In [None]:
# Generate `n_output` images with the trained generator, save each as a JPEG in
# `output_dir`, and bundle the folder into a gzipped tarball.
G.eval()
n_output = 1000
batch_size_gen = 100
output_dir = 'output'
with torch.no_grad():
    generated_images = []
    remaining = n_output
    # Loop until exactly n_output images exist, even when n_output is not a
    # multiple of batch_size_gen (floor division would under-generate).
    while remaining > 0:
        cur = min(batch_size_gen, remaining)
        z = torch.randn(cur, z_dim, device=device)
        imgs = G(z).cpu()
        imgs = (imgs + 1) / 2.0  # Tanh [-1, 1] -> [0, 1] for saving
        generated_images.append(imgs)
        remaining -= cur
    generated_images = torch.cat(generated_images, dim=0)
os.makedirs(output_dir, exist_ok=True)
for i in range(n_output):
    torchvision.utils.save_image(generated_images[i], os.path.join(output_dir, f'{i+1}.jpg'))
# make_archive returns the real archive path ('images.tar.gz' for 'gztar'),
# so the message can't drift from the actual file name.
archive_path = shutil.make_archive('images', 'gztar', output_dir)
print(f'Saved output to {archive_path}')