In [None]:
# Load required libs
import datetime
import gzip
import os
import pickle
import zipfile
from threading import Thread, Lock
from typing import Tuple, List, Type, Dict, Any

import numpy as np
import pandas as pd
import torch
import torchvision
from sklearn.utils import shuffle
from torch.autograd import Variable

In [None]:
class Threadsafe_iter:
    """
    Wraps an iterator/generator so that it can be shared between threads:
    every `next` call on the wrapped object is made while holding a lock,
    so concurrent consumers are served one at a time.
    Typical use: Threadsafe_iter(get_objects_i(len(your_unsafe_list)))
    """
    def __init__(self, it):
        # One lock per wrapper; it guards every advance of the wrapped iterator.
        self.lock = Lock()
        self.it = it

    def __next__(self):
        """Advance the wrapped iterator under the lock and return its value."""
        self.lock.acquire()
        try:
            return next(self.it)
        finally:
            # Always release, including when the wrapped iterator raises
            # (e.g. StopIteration on exhaustion).
            self.lock.release()

    def __iter__(self):
        """The wrapper is its own iterator."""
        return self

def get_objects_i(objects_count):
    """Cyclic generator of object indices: 0, 1, ..., objects_count - 1, 0, 1, ...

    The generator never finishes on its own; the consumer decides when to stop.

    Args:
        objects_count: number of objects to cycle over; must be positive.

    Yields:
        int: the next index, wrapping back to 0 after ``objects_count - 1``.

    Raises:
        ValueError: on the first ``next`` call when ``objects_count`` is not
            positive, because there is no valid index to produce.
    """
    # Without this guard a count of 0 would hand out index 0 (which does not
    # exist) and then fail with ZeroDivisionError; a negative count would
    # produce negative indices.
    if objects_count <= 0:
        raise ValueError(
            "objects_count must be positive, got {}".format(objects_count)
        )

    current_objects_id = 0
    while True:
        yield current_objects_id
        current_objects_id = (current_objects_id + 1) % objects_count

In [None]:
class CelebADataset:
    """Map-style dataset of CelebA face images labelled by identity.

    Each item is an ``(image, label)`` pair. The image is decoded from disk on
    access with ``torchvision.io.read_image`` in RGB mode; the label is the
    class index stored next to the path.
    """

    def __init__(self, id_path_correspondence, transform=None):
        """
        Args:
            id_path_correspondence: indexable collection of ``(path, label)``
                pairs.
            transform: optional callable applied to the decoded image.
        """
        self.__id_path_correspondence = id_path_correspondence
        self.__transform = transform

    def __len__(self):
        """Return the number of ``(path, label)`` pairs in the dataset."""
        return len(self.__id_path_correspondence)

    def __getitem__(self, idx):
        """Decode the image stored at position ``idx``.

        Returns:
            tuple: ``(image, label)``; the image has been passed through
            ``transform`` when one was given.
        """
        path, label = self.__id_path_correspondence[idx]
        img = torchvision.io.read_image(path, mode=torchvision.io.ImageReadMode.RGB)
        if self.__transform:
            img = self.__transform(img)
        return img, label

    @staticmethod
    def make_train_test_datasets(zip_path="drive/MyDrive/GitHub/NN studying/NN_studying/metric learning - face recognition(like FaceID)/data/img_align_celeba.zip",
                                 annotation_path = "drive/MyDrive/GitHub/NN studying/NN_studying/metric learning - face recognition(like FaceID)/data/identity_CelebA.txt",
                                 extract_path="",
                                 transform_train=None,
                                 transform_test=None,
                                 train_test_ratio=0.8,
                                 seed=42,
                                 min_num_imgs_in_class=10):
        """Build train and test datasets from the CelebA archive and identity file.

        The archive is extracted only when the images folder
        ``<extract_path>/img_align_celeba/`` is missing. Identities with fewer
        than ``min_num_imgs_in_class`` images are dropped, and the remaining
        ones are renumbered ``0 .. num_classes - 1`` in order of first
        appearance in the annotation file. The images of every identity are
        then split between train and test. The number of classes is printed.

        Args:
            zip_path: path to the zip archive with the images.
                NOTE(review): the archive is expected to contain a top-level
                ``img_align_celeba/`` folder - confirm for other archives.
            annotation_path: text file with one ``<file name> <identity id>``
                pair per line; blank lines are ignored.
            extract_path: directory the archive is extracted into; ``""``
                means the current working directory.
            transform_train: transform handed to the train dataset.
            transform_test: transform handed to the test dataset.
            train_test_ratio: per identity, a boolean mask
                ``linspace(0, 1, n_images) < train_test_ratio`` is shuffled and
                the images under ``True`` go to train, the rest to test.
            seed: value passed to ``np.random.seed`` before splitting. This
                re-seeds NumPy's global random generator as a side effect.
            min_num_imgs_in_class: minimum number of images an identity needs
                in order to be kept.

        Returns:
            tuple: ``(train_dataset, test_dataset)``, two ``CelebADataset``
            instances.
        """
        # Normalise the target directory so paths can be built by concatenation.
        if extract_path != "" and extract_path[-1] != "/":
            extract_path += "/"
        images_dir = extract_path + "img_align_celeba/"

        # Look for the images folder itself rather than `extract_path`:
        # os.path.exists("") is always False (which forced a re-extraction on
        # every call with the default), and an existing but empty target
        # directory must still trigger extraction.
        if not os.path.isdir(images_dir):
            # The context manager closes the archive even if extraction fails.
            with zipfile.ZipFile(zip_path) as data_zip:
                data_zip.extractall(extract_path or ".")

        # identity id -> image paths, in order of first appearance in the file.
        paths_by_identity = {}
        with open(annotation_path) as annotation_file:
            for line in annotation_file:
                fields = line.split()
                if not fields:
                    continue  # tolerate blank lines, e.g. at the end of the file
                file_name, identity = fields
                paths_by_identity.setdefault(int(identity), []).append(images_dir + file_name)

        # Drop under-populated identities; the position in this list becomes
        # the new contiguous class label.
        kept_classes = [paths for paths in paths_by_identity.values()
                        if len(paths) >= min_num_imgs_in_class]

        id_path_correspondence_train = []
        id_path_correspondence_test = []
        np.random.seed(seed)
        for label, paths in enumerate(kept_classes):
            mask = np.linspace(0, 1, len(paths)) < train_test_ratio
            np.random.shuffle(mask)
            for path, in_train in zip(paths, mask):
                if in_train:
                    id_path_correspondence_train.append((path, label))
                else:
                    id_path_correspondence_test.append((path, label))
        print("num_classes - {}".format(len(kept_classes)))

        return (CelebADataset(id_path_correspondence_train, transform_train),
                CelebADataset(id_path_correspondence_test, transform_test))

In [None]:
class FlexDataloader:
    """Endless batch generator over an indexable dataset.

    Iterating yields lists of ``batch_size`` items read from ``dataset`` in a
    shuffled order. The index stream is cyclic, so iteration never stops on
    its own and a batch may span the boundary between two passes over the
    data. All iterators obtained from one loader share the same index stream,
    the same partially filled batch (``self.batch``) and the same batch
    counter; access to that shared state is serialised with locks.
    """

    def __init__(self,
                 batch_size: int,
                 dataset):
        """
        Args:
            batch_size: number of items per yielded batch; must be positive.
            dataset: object supporting ``len()`` and integer indexing.

        Raises:
            ValueError: if ``batch_size`` is not positive.
        """
        # Zero used to fail later with ZeroDivisionError and a negative value
        # made the iterator yield empty lists forever.
        if batch_size <= 0:
            raise ValueError("batch_size must be positive, got {}".format(batch_size))

        self.__batch_size = batch_size
        self.__dataset = dataset

        # Guards the shared batch and the per-epoch batch counter.
        self._yield_lock = Lock()
        # Guards the reset performed whenever a new iterator starts.
        self._lock = Lock()
        self.__make_indices_and_safe_iter()

    def __make_indices_and_safe_iter(self):
        """Create the index permutation, the shared index stream and counters."""
        self.__dataset_indices = np.arange(len(self.__dataset))
        self.__safe_iter = Threadsafe_iter(get_objects_i(len(self.__dataset)))

        self.__yielded_batches = 0
        self.__num_batches_in_epoch = self.get_steps_per_epoch()

    def shuffle(self):
        """Replace the index permutation with a freshly shuffled one."""
        self.__dataset_indices = shuffle(self.__dataset_indices)

    def __len__(self):
        """Return the number of items in the dataset (not the number of batches)."""
        return len(self.__dataset_indices)

    def __iter__(self):
        """Generate batches forever.

        Starting an iterator reshuffles the indices and discards any partially
        filled batch. The indices are reshuffled again each time
        ``get_steps_per_epoch()`` batches (as computed at construction) have
        been produced.

        Yields:
            list: ``batch_size`` items in the order they were read.
        """
        with self._lock:
            self.shuffle()
            self.batch = []

        # The index stream is endless, so this loop only ends when the
        # consumer stops iterating.
        for pre_ID in self.__safe_iter:
            ID = self.__dataset_indices[pre_ID]
            data = self.__dataset[ID]

            ready_batch = None
            with self._yield_lock:
                self.batch.append(data)
                if len(self.batch) >= self.__batch_size:
                    # Detach the full batch and start a new one while still
                    # holding the lock.
                    ready_batch, self.batch = self.batch, []
                    self.__yielded_batches += 1
                    if self.__yielded_batches >= self.__num_batches_in_epoch:
                        self.shuffle()
                        self.__yielded_batches = 0

            # Yield only after the lock is released: a generator that is
            # suspended inside the `with` block would keep the lock for as
            # long as its consumer works on the batch and stall every other
            # iterator.
            if ready_batch is not None:
                yield ready_batch

    def get_steps_per_epoch(self):
        """Return the number of batches needed to cover the dataset once.

        This is ``ceil(len(self) / batch_size)``; an exact multiple of the
        batch size does not get an extra step.
        """
        batch_size = self.get_batch_size()
        return (len(self) + batch_size - 1) // batch_size

    def get_batch_size(self):
        """Return the configured batch size."""
        return self.__batch_size