In [None]:
import random
import torchvision
import torch
import matplotlib.pyplot as plt
import numpy as np
import time
from tqdm.notebook import tqdm
import os
import zipfile
from PIL import Image
from collections import defaultdict
import copy
from torch.nn import ReLU, Conv2d, BatchNorm2d, Sequential, AdaptiveAvgPool2d, Linear, MaxPool2d, Flatten, CrossEntropyLoss, PReLU, InstanceNorm2d, LeakyReLU, AvgPool2d
try:
    import pytorch_lightning as pl
except:
    !pip install pytorch-lightning
    import pytorch_lightning as pl

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # We'd like to use GPU
device

In [None]:
class CelebADataset:
    """
    Map-style dataset of CelebA face images (about 200k photos of about 10k
    persons).

    Each item is a one-element list ``[img]``, where ``img`` is the image read
    with ``torchvision.io.read_image`` in RGB mode and then passed through
    ``transform`` (when one is given).

    Use the static factory ``make_dataset`` to extract the archive, read the
    identity annotation file and draw a random subset of the images.

    Args:
    -pathes_list - sequence of image file paths
    -transform - optional callable applied to every image on access
    -if_cache - if True, every image is read once in the constructor and kept
    in memory (``self.img_list``); otherwise images are read from disk on each
    ``__getitem__`` call
    """

    def __init__(self, pathes_list, transform=None, if_cache=False):
        self.__pathes_list = pathes_list
        self.__transform = transform

        self.if_cache = if_cache
        if if_cache:
            # Read everything up front: faster access at the cost of memory.
            self.img_list = [self._read_rgb(path) for path in tqdm(pathes_list)]

    @staticmethod
    def _read_rgb(path):
        """Read the image at ``path`` forced to RGB via torchvision.io.read_image."""
        return torchvision.io.read_image(path, mode=torchvision.io.ImageReadMode.RGB)

    def __len__(self):
        return len(self.__pathes_list)

    def __getitem__(self, idx):
        """Return ``[img]`` for index ``idx``, with ``transform`` applied if set."""
        if self.if_cache:
            img = self.img_list[idx]
        else:
            img = self._read_rgb(self.__pathes_list[idx])
        if self.__transform is not None:
            img = self.__transform(img)
        return [img]

    @staticmethod
    def make_dataset(zip_path="drive/MyDrive/GitHub/NN studying/NN_studying/face recognition/data/img_align_celeba.zip",
                     annotation_path = "drive/MyDrive/GitHub/NN studying/NN_studying/face recognition/data/identity_CelebA.txt",
                     extract_path="",
                     transform=None,
                     ratio=0.15,
                     seed=1337,
                     if_cache=False):
        """
        Extract the CelebA archive (if needed), read the identity annotation
        file and return a dataset over a random subset of the listed images.

        Args:
        -zip_path - zip archive expected to contain an ``img_align_celeba/``
        folder
        -annotation_path - text file with one "<image name> <identity id>"
        pair per line; only the image name is used, blank lines are skipped
        -extract_path - directory to extract into ("" = current directory)
        -transform - passed through to CelebADataset
        -ratio - fraction of annotated images to keep; int(ratio * N) images
        are drawn without replacement
        -seed - seed of the subset draw; a private RNG is used, so the global
        NumPy random state is left untouched
        -if_cache - passed through to CelebADataset
        Returns:
        CelebADataset whose path list is a NumPy array of strings.
        """
        if extract_path and not extract_path.endswith("/"):
            extract_path += "/"
        img_dir = extract_path + "img_align_celeba/"

        # Test for the image folder itself rather than extract_path: "" never
        # "exists" (which forced a re-extraction on every call), and an
        # existing-but-empty extract_path must still be populated.
        if not os.path.isdir(img_dir):
            with zipfile.ZipFile(zip_path) as data_zip:
                data_zip.extractall(extract_path)

        pathes_list = []
        with open(annotation_path) as f:
            for line in f:
                fields = line.split()
                if not fields:  # tolerate blank / trailing lines
                    continue
                pathes_list.append(img_dir + fields[0])

        # Private RNG: draws the same subset as np.random.seed(seed) followed
        # by np.random.choice, without mutating global NumPy state.
        rng = np.random.RandomState(seed)
        pathes_list = rng.choice(pathes_list, size=int(ratio * len(pathes_list)), replace=False)

        return CelebADataset(pathes_list, transform, if_cache)

In [None]:
def get_dtype(dtype_str):
    """
    Resolve a short dtype name to a torch dtype.

    Args:
    -dtype_str - 'fl32' (torch.float) or 'fl16' (torch.half); a torch.dtype
    or None is returned unchanged (None means "torch default dtype")
    Raises:
    -ValueError for any other value, so typos such as 'fp16' are not silently
    turned into the default dtype
    """
    if dtype_str is None or isinstance(dtype_str, torch.dtype):
        return dtype_str
    known = {'fl32': torch.float, 'fl16': torch.half}
    if dtype_str not in known:
        raise ValueError(f"unknown dtype {dtype_str!r}; expected one of {sorted(known)} or a torch.dtype")
    return known[dtype_str]

class LinearHe(torch.nn.Module):
    """
    Fully connected layer with He (Kaiming-normal) initialisation: weight is
    drawn from N(0, 2 / in_f), bias starts at zero.

    Args:
    -in_f - number of input features
    -out_f - number of output features
    -bias - whether to learn an additive bias
    -dtype - parameter dtype name, resolved via get_dtype (e.g. 'fl32', 'fl16')
    """
    def __init__(self, in_f, out_f, bias=True, dtype='fl32'):
        super(LinearHe, self).__init__()
        dtype = get_dtype(dtype)

        # Plain Python float math: torch.sqrt only accepts tensors.
        he_std = (2 / in_f) ** 0.5
        self.weight = torch.nn.Parameter(torch.randn(size=(out_f, in_f), dtype=dtype) * he_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(size=(out_f, ), dtype=dtype))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        """Return x @ weight.T + bias (torch.nn.functional.linear)."""
        return torch.nn.functional.linear(x, self.weight, self.bias)

class Conv2dHe(torch.nn.Module):
    """
    2-D convolution with square kernels and He (Kaiming-normal)
    initialisation: weight is drawn from N(0, 2 / (in_c * kernel_size**2)),
    bias (if enabled) starts at zero.

    Args:
    -in_c - number of input channels
    -out_c - number of output channels
    -kernel_size - side of the square kernel
    -padding, stride - forwarded to torch.nn.functional.conv2d
    -bias - whether to learn an additive per-channel bias
    -dtype - parameter dtype name, resolved via get_dtype (e.g. 'fl32', 'fl16')
    """
    def __init__(self, in_c, out_c, kernel_size=3, padding=0, stride=1, bias=False, dtype='fl32'):
        super(Conv2dHe, self).__init__()
        dtype = get_dtype(dtype)
        self.padding = padding
        self.stride = stride

        # Plain Python float math (torch.sqrt only accepts tensors); the scale
        # must be applied, otherwise randn leaves the weight with std 1.
        fan_in = in_c * kernel_size ** 2
        he_std = (2 / fan_in) ** 0.5
        self.weight = torch.nn.Parameter(
            torch.randn(size=(out_c, in_c, kernel_size, kernel_size), dtype=dtype) * he_std)
        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(size=(out_c, ), dtype=dtype))
        else:
            self.register_parameter('bias', None)

    def forward(self, x):
        """Convolve x with the layer's weight/bias using the stored padding and stride."""
        return torch.nn.functional.conv2d(x, self.weight, self.bias, padding=self.padding, stride=self.stride)

class AdaIN(torch.nn.Module):
    """
    Adaptive instance normalisation: normalise x over dims 2 and 3 (per
    sample, per channel), then apply a style scale and bias.

    forward(x, y_s, y_i) returns (x - mean) / sqrt(var + eps) * y_s + y_i,
    where mean/var are the biased statistics over dims (2, 3), matching
    nn.InstanceNorm2d without affine parameters.
    -x - feature map, presumably (N, C, H, W) — TODO confirm with callers
    -y_s - style scale, must broadcast against x (e.g. (N, C, 1, 1))
    -y_i - style bias, same broadcasting requirement as y_s
    """
    def __init__(self, eps=1e-8):
        super(AdaIN, self).__init__()
        self.eps = eps

    def forward(self, x, y_s, y_i):
        x_centered = x - x.mean(dim=(2, 3), keepdim=True)
        # Biased variance with eps inside the sqrt: finite for 1x1 maps
        # (unbiased std is NaN there) and for constant maps in backward.
        x_var = x_centered.pow(2).mean(dim=(2, 3), keepdim=True)
        x_norm = x_centered / torch.sqrt(x_var + self.eps)
        return x_norm * y_s + y_i