<a href="https://colab.research.google.com/github/SergeBurnt/ya_practicum_ds/blob/main/Lesson_4_CNN/FC%2BCNN_classification_face.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
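# One-time download of the face-crop archive (~6.5 GB); uncomment both lines to fetch and unpack it.
# The dataset class below reads from the resulting imdb_crop/ directory.
# !wget https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/imdb_crop.tar
# !tar xf imdb_crop.tar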

--2024-07-17 06:59:46--  https://data.vision.ee.ethz.ch/cvl/rrothe/imdb-wiki/static/imdb_crop.tar
Resolving data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)... 129.132.52.178, 2001:67c:10ec:36c2::178
Connecting to data.vision.ee.ethz.ch (data.vision.ee.ethz.ch)|129.132.52.178|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 7012157440 (6.5G) [application/x-tar]
Saving to: ‘imdb_crop.tar’


2024-07-17 07:12:07 (9.04 MB/s) - ‘imdb_crop.tar’ saved [7012157440/7012157440]



In [5]:
import numpy as np
import cv2
import torch
from albumentations.pytorch import ToTensorV2
import albumentations as A
from scipy.io import loadmat
from torch.utils.data import Dataset
import random
import os
from dataclasses import dataclass

In [27]:
class ImdbWikiDataset(Dataset):
    """Face-crop dataset driven by the ``imdb.mat`` metadata file, labelled by gender.

    The constructor reads ``<root>/imdb.mat``, discards every entry whose
    gender value is NaN, and stores the remaining image paths together with
    their integer labels. Images are only read from disk in ``__getitem__``.

    Args:
        image_size: Side length in pixels; every image is resized to
            ``image_size x image_size``.
        root: Directory that contains ``imdb.mat`` and the image files it
            references. Defaults to ``"imdb_crop"``.

    Attributes:
        paths: Image file paths, one per sample.
        labels: Integer gender labels aligned index-by-index with ``paths``
            (presumably 0 = female, 1 = male as in the IMDB-WIKI metadata
            description -- confirm against the dataset docs).
        transforms: Albumentations pipeline applied to every image: resize,
            random horizontal flip (p=0.5), division by 255 into floats, and
            conversion to a tensor. The flip is random on every access, for
            every sample.
    """

    def __init__(self, image_size: int = 128, root: str = "imdb_crop"):
        imdb_dat = loadmat(f"{root}/imdb.mat")["imdb"][0][0]
        # Field 2 of the record is read as relative image paths, field 3 as gender values.
        all_paths = [f"{root}/{path[0]}" for path in imdb_dat[2][0]]
        self.paths, self.labels = self._drop_unlabeled(all_paths, imdb_dat[3][0])

        self.transforms = A.Compose(
            [
                A.Resize(image_size, image_size),
                A.HorizontalFlip(p=0.5),
                A.ToFloat(max_value=255),
                ToTensorV2(),
            ]
        )

        # The previous version asserted on `self.imdb_paths` / `self.imdb_genders`,
        # which were never assigned, so construction always failed with AttributeError.
        assert len(self.paths) == len(self.labels)

    @staticmethod
    def _drop_unlabeled(paths, genders):
        """Keep only the samples whose gender value is not NaN.

        Args:
            paths: Sequence of image paths.
            genders: Sequence of float gender values, same length as ``paths``;
                NaN marks a missing label.

        Returns:
            ``(kept_paths, kept_labels)`` where ``kept_labels`` are Python ints.

        Raises:
            ValueError: If ``paths`` and ``genders`` differ in length.
        """
        genders = np.asarray(genders, dtype=float)
        if len(paths) != genders.shape[0]:
            raise ValueError(
                f"paths and genders differ in length: {len(paths)} != {genders.shape[0]}"
            )
        keep = ~np.isnan(genders)
        kept_paths = [path for path, is_labeled in zip(paths, keep) if is_labeled]
        kept_labels = [int(gender) for gender in genders[keep]]
        return kept_paths, kept_labels

    def __getitem__(self, index):
        """Load, transform and return ``(image_tensor, label)`` for sample ``index``.

        Raises:
            FileNotFoundError: If the image file is missing or cannot be decoded.
        """
        path = self.paths[index]
        # No colour conversion is applied, so channels stay in OpenCV's default BGR order.
        img_numpy = cv2.imread(path)
        if img_numpy is None:
            # cv2.imread signals failure by returning None instead of raising;
            # fail here with the offending path rather than inside the transforms.
            raise FileNotFoundError(f"Image is missing or unreadable: {path}")
        img_tensor = self.transforms(image=img_numpy)["image"]
        label = self.labels[index]

        return img_tensor, label

    def __len__(self):
        """Return the number of labelled samples."""
        return len(self.paths)

In [29]:
@dataclass
class Config:
    """All tunable values of the experiment, gathered in one place.

    Field order is part of the positional constructor signature, so it must
    stay stable. Meanings marked "presumably" are inferred from the field
    names; the code that consumes them is not part of this cell.
    """

    # --- reproducibility ---
    seed: int = 0  # value handed to the random-number seeding helper

    # --- data ---
    batch_size: int = 64
    do_shuffle_train: bool = True  # presumably: shuffle the training loader
    img_size: int = 128  # presumably: side length images are resized to
    # presumably: fractions of the data for the train / validation / test splits
    ratio_train_val_test: tuple[float, float, float] = (0.8, 0.1, 0.1)

    # --- model ---
    hidden_dim: int = 512
    p_dropout: float = 0.3  # presumably: dropout probability

    # --- optimisation ---
    n_epochs: int = 10
    eval_every: int = 2_000  # presumably: training steps between evaluations
    lr: float = 1e-5  # presumably: optimizer learning rate

In [34]:
def enable_determinism():
    """Switch PyTorch into deterministic-algorithms mode.

    Sets the ``CUBLAS_WORKSPACE_CONFIG`` environment variable to ``":4096:8"``
    and then enables ``torch.use_deterministic_algorithms``. The environment
    variable is written first so it is already in place when the flag is set.
    """
    os.environ.update({"CUBLAS_WORKSPACE_CONFIG": ":4096:8"})
    torch.use_deterministic_algorithms(True)

def fix_seeds(seed: int):
    """Seed every random-number source this notebook relies on.

    Args:
        seed: Value passed, in this order, to NumPy's global RNG, Python's
            ``random`` module, ``torch.cuda.manual_seed`` and
            ``torch.manual_seed``.
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.cuda.manual_seed,
        torch.manual_seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

In [35]:
# Build the run configuration, then lock down randomness before any stochastic code runs.
config = Config()
enable_determinism()
fix_seeds(config.seed)