In [1]:
import numpy as np  # Load required libs
import pandas as pd
import torch
import torchvision
from typing import Tuple, List, Type, Dict, Any
from sklearn.utils import shuffle
from torch.autograd import Variable
from threading import Thread, Lock
import os
import pickle
import datetime
import gzip

In [6]:
class Threadsafe_iter:
    """
    Wraps an iterator/generator so that it can be advanced from several threads:
    every call to `next` on the wrapped object is made while holding a lock.
    Intended usage: Threadsafe_iter(get_objects_i(len(your_unsafe_list)))
    """
    def __init__(self, it):
        self.lock = Lock()
        self.it = it

    def __iter__(self):
        # The wrapper is its own iterator.
        return self

    def __next__(self):
        # Explicit acquire/release; `finally` guarantees the lock is released
        # even when the wrapped iterator raises (e.g. StopIteration).
        self.lock.acquire()
        try:
            value = next(self.it)
        finally:
            self.lock.release()
        return value

def get_objects_i(objects_count):
    """Cyclic generator of object indices: 0, 1, ..., objects_count - 1, 0, 1, ...

    The generator never terminates on its own.

    Args:
        objects_count: number of objects to cycle over; must be positive.

    Raises:
        ValueError: when the generator is first advanced and `objects_count`
            is not positive. Previously an empty collection produced the
            invalid index 0 followed by a ZeroDivisionError.
    """
    if objects_count <= 0:
        raise ValueError(
            'objects_count must be positive, got {}'.format(objects_count))
    current_objects_id = 0
    while True:
        yield current_objects_id
        current_objects_id = (current_objects_id + 1) % objects_count

In [20]:
class Cifar_dataset:
    """
    Endless batch generator over CIFAR pickled batch files.

    The files `data_batch_1` .. `data_batch_5` (train) or `test_batch` (test)
    are read from `path_to_batches` at construction time. Iterating over the
    object yields lists of `batch_size` sample dicts
    `{'img': tensor, 'target': tensor}`.

    The position counter, the permutation (`indexes`) and the partially filled
    batch are shared by every iterator created from one instance and are
    guarded by locks, so several threads may iterate over the same object.

    Args:
        batch_size: number of samples per yielded batch.
        path_to_batches: directory holding the batch files ('' = current directory).
        transform_cpu: optional callable applied to every image tensor.
        train: read the five training batches if True, the test batch otherwise.
    """
    def __init__(self,
                 batch_size: int,
                 path_to_batches: str,
                 transform_cpu = None,
                 train : bool = True):
        self.batch_size = batch_size
        self.train = train
        self.transform_cpu = transform_cpu
        self.path_to_batches = path_to_batches

        self._read_data()

        self.new_epoch = True  # True until the first iterator starts
        self.batch = []
        self.yield_lock = Lock()
        self.lock = Lock()
        self.safe_iterator = Threadsafe_iter(get_objects_i(len(self.indexes)))
        self.shuffle()

    @staticmethod
    def unpickle(file):
        """Load one pickled batch file and return the unpickled object.

        NOTE: `pickle.load` can execute arbitrary code; only use it on batch
        files from a trusted source.
        """
        with open(file, 'rb') as fo:
            return pickle.load(fo, encoding='latin1')

    def shuffle(self):
        """Replace `indexes` with a shuffled copy (the sample visiting order)."""
        self.indexes = shuffle(self.indexes)

    def __len__(self):
        return len(self.indexes)

    def get_batch_size(self):
        return self.batch_size

    def get_steps_per_epoch(self):
        """Number of batches needed to cover the dataset once (ceil division).

        The former `len // batch_size + 1` returned one step too many whenever
        the dataset size was an exact multiple of the batch size.
        """
        batch_size = self.get_batch_size()
        return (len(self) + batch_size - 1) // batch_size

    def __iter__(self):
        """
        Yield batches (lists of `batch_size` sample dicts) without end.

        Positions come from the shared cyclic `safe_iterator` and are mapped
        through `indexes`, so the permutation produced by `shuffle` decides
        which sample is read. The permutation is renewed every time position 0
        is drawn, i.e. once per pass over the data. A partially filled batch is
        carried over the pass boundary, so every yielded batch is full.
        """
        with self.lock:
            if self.new_epoch:
                self.new_epoch = False
                self.batch = []
        for ID_gen in self.safe_iterator:
            with self.lock:
                # The cyclic counter wrapped around: start a pass with a new order.
                if ID_gen == 0:
                    self.shuffle()
                sample_id = self.indexes[ID_gen]
            img = self.data[sample_id].reshape((3, 32, 32))
            img = torch.from_numpy(img)
            target = torch.tensor(self.target[sample_id])
            if self.transform_cpu:
                img = self.transform_cpu(img)
            data = {'img':img, 'target':target}
            ready_batch = None
            with self.yield_lock:
                if len(self.batch) < self.batch_size:
                    self.batch.append(data)
                if len(self.batch) % self.batch_size == 0:
                    ready_batch = self.batch
                    self.batch = []
            # Yield only after the lock is released: a suspended generator must
            # not keep other iterators blocked on `yield_lock`.
            if ready_batch is not None:
                yield ready_batch

    def _read_data(self):
        """Read the batch files into `data`, `target` and `indexes`.

        `data` and `target` are numpy arrays for both the train and the test
        split; `indexes` is `arange(len(data))`.
        """
        if self.train:
            batch_names = ['data_batch_' + str(i + 1) for i in range(5)]
        else:
            batch_names = ['test_batch']
        # os.path.join returns the bare file name when path_to_batches is ''.
        state_filenames = [os.path.join(self.path_to_batches, name)
                           for name in batch_names]

        targets = []
        data = []
        for filename in state_filenames:
            data_dict = Cifar_dataset.unpickle(filename)
            targets.append(data_dict['labels'])
            data.append(data_dict['data'])

        self.target = np.concatenate(targets)
        self.data = np.concatenate(data)
        self.indexes = np.arange(0, len(self.data))

In [None]:
class Dataset:
    """
    Endless batch generator over face images stored as individual files.

    The identity (ID) of an image is parsed from its path: the second-to-last
    '_'-separated field, with any leading directories stripped.

    In train mode every sample is `{'anchor', 'positive', 'ID'}`, where
    `positive` is another image of the same identity; otherwise it is
    `{'anchor', 'ID'}`. Iterating yields lists of `batch_size` such dicts.

    The position counter, the path order and the partially filled batch are
    shared by every iterator created from one instance and are guarded by locks.

    Args:
        batch_size: number of samples per yielded batch.
        pathes_list: list of image file paths.
        ID_to_number: mapping from the ID string to its integer class number.
        train: produce anchor/positive pairs if True, anchors only otherwise.
        transform_cpu: optional callable applied to every loaded image tensor.
    """
    def __init__(self,
                 batch_size : int,
                 pathes_list,
                 ID_to_number,
                 train : bool = True,
                 transform_cpu = None):

        self.pathes_list = pathes_list
        self.ID_to_number = ID_to_number
        self.batch_size = batch_size
        self.transform_cpu = transform_cpu
        self.train = train

        self._create_correspondence()

        self.new_epoch = True  # True until the first iterator starts
        self.batch = []
        self.yield_lock = Lock()
        self.lock = Lock()
        self.safe_iterator = Threadsafe_iter(get_objects_i(len(self.pathes_list)))
        self.shuffle()

    def shuffle(self):
        """Replace `pathes_list` with a shuffled copy (the sample visiting order)."""
        self.pathes_list = shuffle(self.pathes_list)

    def __len__(self):
        return len(self.pathes_list)

    def get_batch_size(self):
        return self.batch_size

    def get_steps_per_epoch(self):
        """Number of batches needed to cover the dataset once (ceil division).

        The former `len // batch_size + 1` returned one step too many whenever
        the dataset size was an exact multiple of the batch size.
        """
        batch_size = self.get_batch_size()
        return (len(self) + batch_size - 1) // batch_size

    @staticmethod
    def _extract_id(path):
        """Return the identity string encoded in an image path."""
        return path.split('_')[-2].split('/')[-1]

    def _load_image(self, path):
        """Read an image file into a float tensor."""
        return torchvision.io.read_image(path).float()

    def _pick_positive(self, anchor_path, ID):
        """Pick a random path of identity `ID` that differs from `anchor_path`.

        When the identity has no other image the anchor path itself is
        returned; the former rejection loop never terminated in that case.
        """
        candidates = [p for p in self.ID_path_correspondence[ID] if p != anchor_path]
        if not candidates:
            return anchor_path
        return candidates[np.random.randint(len(candidates))]

    def _make_sample(self, path1):
        """Build the sample dict for the image stored at `path1`."""
        ID = self._extract_id(path1)
        if self.train:
            path2 = self._pick_positive(path1, ID)

            anchor = self._load_image(path1)
            positive = self._load_image(path2)
            ID_number = torch.tensor(self.ID_to_number[ID])

            if self.transform_cpu:
                anchor = self.transform_cpu(anchor)
                positive = self.transform_cpu(positive)
            return {'anchor':anchor, 'positive':positive, 'ID':ID_number}

        anchor = self._load_image(path1)
        ID_number = torch.tensor(self.ID_to_number[ID])

        if self.transform_cpu:
            anchor = self.transform_cpu(anchor)
        return {'anchor':anchor, 'ID':ID_number}

    def __iter__(self):
        """
        Yield batches (lists of `batch_size` sample dicts) without end.

        Positions come from the shared cyclic `safe_iterator`. The path order
        is reshuffled every time position 0 is drawn, i.e. once per pass over
        the data. A partially filled batch is carried over the pass boundary,
        so every yielded batch is full.
        """
        with self.lock:
            if self.new_epoch:
                self.new_epoch = False
                self.batch = []
        for ID_gen in self.safe_iterator:
            with self.lock:
                # The cyclic counter wrapped around: start a pass with a new order.
                if ID_gen == 0:
                    self.shuffle()
                path1 = self.pathes_list[ID_gen]
            data = self._make_sample(path1)

            ready_batch = None
            with self.yield_lock:
                if len(self.batch) < self.batch_size:
                    self.batch.append(data)
                if len(self.batch) % self.batch_size == 0:
                    ready_batch = self.batch
                    self.batch = []
            # Yield only after the lock is released: a suspended generator must
            # not keep other iterators blocked on `yield_lock`.
            if ready_batch is not None:
                yield ready_batch

    def _create_correspondence (self):
        """
        Build `ID_path_correspondence`: a dict that maps every ID found in
        `pathes_list` to the list of its image paths
        (ID <-> person, path <-> one concrete face image).
        """
        self.ID_path_correspondence = {}
        for file_path in self.pathes_list:
            file_id = self._extract_id(file_path)
            self.ID_path_correspondence.setdefault(file_id, []).append(file_path)

    @staticmethod
    def make_train_val_split(batch_size : int,
                             ratio : float = 0.8,
                             path : str = '',
                             cpu_transform_train = None,
                             cpu_transform_val = None,
                             least_size = 10,
                             seed = 42):
        """
        Split the images found in `<path>/CASIA-WebFace_crop/` into a train and
        a validation `Dataset`.

        Identities with fewer than `least_size` images are dropped. Each image
        of a kept identity goes to the train part with probability `ratio`,
        otherwise to the validation part. Kept identities are numbered
        0, 1, 2, ... and both datasets share that numbering.

        Note: this seeds the global numpy RNG via `np.random.seed(seed)`.

        Args:
            batch_size: batch size of both returned datasets.
            ratio: probability for an image to land in the train part.
            path: directory containing `CASIA-WebFace_crop` ('' = current directory).
            cpu_transform_train: transform passed to the train dataset.
            cpu_transform_val: transform passed to the validation dataset.
            least_size: minimal number of images an identity must have.
            seed: seed for the random split.

        Returns:
            Tuple `(train_dataset, val_dataset)`.
        """
        if path == '':
            files_path = 'CASIA-WebFace_crop/'
        else:
            files_path = path + '/' + 'CASIA-WebFace_crop/'

        # os.listdir returns entries in arbitrary order; sorting makes the
        # seeded split (and the ID numbering) reproducible across machines.
        ID_path_correspondence = {}
        for filename in sorted(os.listdir(files_path)):
            ID = filename.split('_')[-2]
            ID_path_correspondence.setdefault(ID, []).append(files_path + filename)

        train_file_pathes = []
        val_file_pathes = []
        ID_to_number = {}

        np.random.seed(seed)

        for ID, ID_pathes in ID_path_correspondence.items():
            if len(ID_pathes) < least_size:
                continue  # too few images: identity is left out of both parts
            ID_to_number[ID] = len(ID_to_number)
            mask = np.random.rand(len(ID_pathes)) < ratio
            for goes_to_train, file_path in zip(mask, ID_pathes):
                if goes_to_train:
                    train_file_pathes.append(file_path)
                else:
                    val_file_pathes.append(file_path)

        return (Dataset(batch_size,
                        train_file_pathes,
                        ID_to_number,
                        train=True,
                        transform_cpu = cpu_transform_train),
                Dataset(batch_size,
                        val_file_pathes,
                        ID_to_number,
                        train=False,
                        transform_cpu = cpu_transform_val))