In [None]:
import numpy as np  # Load required libs
import pandas as pd
import torch
import torchvision
from torch.nn.functional import one_hot
from threading import Thread, Lock
from queue import Empty, Queue
import gc
import time

In [None]:
def clear_cuda():
    """Run the garbage collector, then release cached CUDA memory.

    The garbage collector runs first so that objects (such as tensors) kept
    alive only by reference cycles are freed before
    ``torch.cuda.empty_cache()`` is asked to release unused cached memory.
    """
    for release_step in (gc.collect, torch.cuda.empty_cache):
        release_step()

In [None]:
class Queue_with_cleaning(Queue):
    """A bounded ``queue.Queue`` that can discard all of its pending items."""

    def __init__(self, maxsize):
        super().__init__(maxsize)

    def clear(self):
        """Drop pending items with non-blocking gets until the queue reports empty.

        If another consumer takes the last item between the ``empty()`` check
        and the ``get``, the resulting ``Empty`` is ignored and the loop
        re-checks.
        """
        while not self.empty():
            try:
                self.get_nowait()
            except Empty:
                pass

In [None]:
class Controller(object):
    """Shared state that coordinates the worker threads and their queues.

    Worker threads are created once and reused for every task; instead of
    spawning new threads, the controller's status tells them what to do.
    Every worker function receives the same Controller instance.

    When a task starts:
      1) every worker calls ``get_status()`` to learn which task to run;
      2) every worker calls ``boil_process()`` so the controller can count how
         many workers are active on the task;
      3) every worker then loops, re-checking the status as it goes.

    When the task changes (``change_status``):
      1) the status is set to ``'cleaning'``; workers stop once they finish
         their current unit of work and call ``freeze_thread()``, so the
         controller knows how many are still active;
      2) the controller keeps clearing both queues, which frees slots for
         workers blocked on a full queue so they can notice the new status
         (only two queues are used, to limit memory);
      3) when no worker is active (``hot_threads == 0``), the queues are
         cleared one final time and the new status is published;
      4) workers are still in their outer loops, so they pick up the new
         status and start again from step 1 above.

    Shutdown: a worker whose function is an infinite loop must be told to
    leave it. ``set_tokill(True)`` raises a flag that workers poll, so their
    functions return and the threads terminate.
    """

    def __init__(self):
        self.to_kill = False
        self.__status = 'train'
        # Guards hot_threads, which is updated concurrently by all workers.
        self.lock = Lock()

        # Number of workers currently registered via boil_process() and not
        # yet unregistered via freeze_thread().
        self.hot_threads = 0

    def is_kill(self):
        """Return True once shutdown has been requested."""
        return self.to_kill

    def set_tokill(self, tokill):
        """Set (or reset) the shutdown flag polled by the workers."""
        self.to_kill = tokill

    def get_status(self):
        """Return the current task: 'train', 'validate' or 'cleaning'."""
        return self.__status

    def boil_process(self):
        """Register one more active worker."""
        with self.lock:
            self.hot_threads += 1

    def freeze_thread(self):
        """Unregister one active worker."""
        with self.lock:
            self.hot_threads -= 1

    def _active_workers(self):
        """Return a consistent snapshot of hot_threads."""
        with self.lock:
            return self.hot_threads

    def cleaning(self, ram_queue, cuda_queue):
        """Enter 'cleaning' and wait until every worker has stopped.

        Both queues are cleared once per second while workers are active, so
        producers blocked on a full queue can proceed and notice the status.
        The wait is a loop rather than recursion, so a slow task switch cannot
        exhaust the recursion limit.

        Raises:
            ValueError: if the active-worker count drops below zero.
        """
        self.__status = 'cleaning'

        while True:
            ram_queue.clear()
            cuda_queue.clear()
            active = self._active_workers()
            print('Number of hot threads - {}'.format(active))

            if active < 0:
                raise ValueError('Something went wrong, hot_threads cannot be < 0')
            if active == 0:
                break
            time.sleep(1)

        # A worker may have put one last batch between the clear above and
        # its freeze_thread() call; sweep again now that every worker has
        # stopped, so no stale batch reaches the next task.
        ram_queue.clear()
        cuda_queue.clear()

    def change_status(self, new_status, ram_queue, cuda_queue):
        """Drain the pipeline, then switch every worker to ``new_status``."""
        self.cleaning(ram_queue, cuda_queue)
        self.__status = new_status
        # Give workers a moment to pick up the new status and re-register.
        time.sleep(1)
        print('{} processes have been started'.format(self.hot_threads))

In [None]:
def getting_loop(controller, data_generator, queue):
    """Copy batches from ``data_generator`` into ``queue`` until told to stop.

    The calling worker registers itself with ``controller.boil_process()`` and
    always unregisters with ``controller.freeze_thread()`` on exit. That holds
    whether the loop stops because the controller entered ``'cleaning'``, the
    kill flag was set, the generator ran out, or the generator raised.

    Args:
        controller: the shared Controller.
        data_generator: any iterable of batches.
        queue: destination queue; ``put`` blocks while it is full.
    """
    controller.boil_process()
    print('process has been started')
    try:
        for sample_batch in data_generator:
            queue.put(sample_batch, block=True)
            if controller.get_status() == 'cleaning':
                print('threads has been frozen')
                break
            if controller.is_kill():
                break
    finally:
        # Unregister on every exit path: if an exhausted generator left the
        # count raised, Controller.cleaning() would wait for this worker forever.
        controller.freeze_thread()

def threaded_batches_feeder(controller, train_generator, val_generator, ram_queue):
    """
    Worker body that moves batches from the data generators into ``ram_queue``.

    The controller's status chooses the source: 'train' uses
    ``train_generator`` and 'validate' uses ``val_generator``. During
    'cleaning' the worker idles, polling once per second. The function
    returns once the controller's kill flag is set.
    """
    generator_for = {'train': train_generator, 'validate': val_generator}

    while not controller.is_kill():
        current_status = controller.get_status()
        if current_status != 'cleaning':
            getting_loop(controller, generator_for[current_status], ram_queue)
            continue
        time.sleep(1)

In [None]:
def data_loop(controller, ram_queue, cuda_queue, img_transform, device, data_handler,
              get_timeout=1.0):
    """Turn raw batches from ``ram_queue`` into device batches on ``cuda_queue``.

    Each raw batch is passed to ``data_handler(batch, img_transform=...,
    device=...)`` under ``torch.no_grad()``, and the result is put on
    ``cuda_queue``. The worker registers with ``controller.boil_process()``
    and always unregisters with ``controller.freeze_thread()`` on exit.

    The status and kill flag are checked before each batch is taken. The
    ``get`` on ``ram_queue`` uses ``get_timeout``, so a worker waiting on an
    empty queue still notices 'cleaning' or a kill request. A blocking get
    with no timeout could wait forever during cleaning, and the controller
    would wait forever in turn.

    Args:
        get_timeout: seconds to wait for a raw batch before re-checking the
            controller.
    """
    controller.boil_process()
    print('process has been started')
    try:
        while not controller.is_kill():
            if controller.get_status() == 'cleaning':
                print('thread has been frozen')
                break
            try:
                sample_batch = ram_queue.get(block=True, timeout=get_timeout)
            except Empty:
                continue

            with torch.no_grad():
                result_batch = data_handler(sample_batch, img_transform=img_transform, device=device)
            cuda_queue.put(result_batch, block=True)
    finally:
        controller.freeze_thread()

def train_data_handler_casia(sample_batch, img_transform, device):
    """Move a list of CASIA training samples to ``device`` and batch them.

    Each sample is a mapping with ``'anchor'``, ``'positive'`` and ``'ID'``
    tensors. Images are divided by 255 and then passed through
    ``img_transform``; IDs are only moved to ``device``.

    Returns:
        ``[anchors, positives, ids]``, each stacking the per-sample tensors
        along a new leading dimension, or ``None`` if ``sample_batch`` is
        empty.
    """
    anchors, positives, ids = [], [], []
    for data in sample_batch:
        anchors.append(img_transform(data['anchor'].to(device) / 255))
        positives.append(img_transform(data['positive'].to(device) / 255))
        ids.append(data['ID'].to(device))

    if not anchors:
        return None
    # One stack is equivalent to reshaping each tensor to (1, *shape) and
    # concatenating, but copies once instead of re-copying the growing batch
    # for every sample.
    return [torch.stack(anchors), torch.stack(positives), torch.stack(ids)]

def val_data_handler_casia(sample_batch, img_transform, device):
    """Move a list of CASIA validation samples to ``device`` and batch them.

    Each sample is a mapping with ``'anchor'`` and ``'ID'`` tensors. As in the
    training and CIFAR handlers, images are divided by 255 *before*
    ``img_transform`` is applied.

    The transform parameter was previously misspelled ``img_tranform``, while
    the body used ``img_transform``. The keyword call made by
    ``threaded_cuda_batches`` therefore failed with a TypeError.

    Returns:
        ``[images, ids]``, each stacking the per-sample tensors along a new
        leading dimension, or ``None`` if ``sample_batch`` is empty.
    """
    images, ids = [], []
    for data in sample_batch:
        images.append(img_transform(data['anchor'].to(device) / 255))
        ids.append(data['ID'].to(device))

    if not images:
        return None
    return [torch.stack(images), torch.stack(ids)]

def data_handler_cifar(sample_batch, img_transform, device):
    """Move a list of CIFAR samples to ``device`` and batch them.

    Each sample is a mapping with ``'img'`` and ``'target'`` tensors. Images
    are divided by 255 and then passed through ``img_transform``; targets are
    only moved to ``device``.

    Returns:
        ``[images, targets]``, each stacking the per-sample tensors along a new
        leading dimension, or ``None`` if ``sample_batch`` is empty.
    """
    images, targets = [], []
    for data in sample_batch:
        images.append(img_transform(data['img'].to(device) / 255))
        targets.append(data['target'].to(device))

    if not images:
        return None
    # A single stack replaces the per-sample torch.cat, which re-copied the
    # growing batch every iteration.
    return [torch.stack(images), torch.stack(targets)]

In [None]:
def threaded_cuda_batches(controller,
                          ram_queue,
                          cuda_queue,
                          img_transform_train,
                          img_transform_val,
                          device,
                          dataset_name = 'cifar'):
    """Worker body that converts raw batches into device batches.

    The controller's status selects the data handler and image transform
    ('train' or 'validate'). During 'cleaning' the worker idles, polling once
    per second. The function returns once the controller's kill flag is set.

    Raises:
        ValueError: if ``dataset_name`` is neither 'cifar' nor 'casia'.
            Previously this surfaced later as an UnboundLocalError inside the
            loop.
    """
    if dataset_name == 'cifar':
        status_handler_correspondence = {'train': data_handler_cifar,
                                         'validate': data_handler_cifar}
    elif dataset_name == 'casia':
        status_handler_correspondence = {'train': train_data_handler_casia,
                                         'validate': val_data_handler_casia}
    else:
        raise ValueError("Unknown dataset_name {!r}; expected 'cifar' or 'casia'".format(dataset_name))
    status_transform_correspondence = {'train': img_transform_train,
                                       'validate': img_transform_val}

    # A single polling loop; the handler/transform tables never change, so
    # they are built once above rather than inside a redundant outer loop.
    while not controller.is_kill():
        status = controller.get_status()

        if status == 'cleaning':
            time.sleep(1)
            continue

        handler = status_handler_correspondence[status]
        img_transform = status_transform_correspondence[status]
        data_loop(controller, ram_queue, cuda_queue, img_transform, device, handler)

In [None]:
class Thread_data_processing:
    """Threaded two-stage pipeline: generators -> RAM queue -> device queue.

    ``num_workers_ram`` threads pull batches from the train/validation
    generators into ``ram_queue``. ``num_workers_cuda`` threads turn them into
    device batches on ``cuda_queue``, which the consumer reads with ``get()``.
    A shared Controller switches all workers between 'train' and 'validate'.
    """

    def __init__(self,
                 train_generator,
                 val_generator,
                 device,
                 img_transform_train,
                 img_transform_val,
                 ram_queue_length=24,
                 cuda_queue_length=10,
                 num_workers_ram=2,
                 num_workers_cuda=1,
                 dataset_name='cifar'):
        self.device = device
        self.dataset_name = dataset_name
        self.train_generator = train_generator
        self.val_generator = val_generator
        self.img_transform_train = img_transform_train
        self.img_transform_val = img_transform_val
        self.controller = Controller()

        self.ram_queue_length = ram_queue_length
        self.cuda_queue_length = cuda_queue_length
        self.num_workers_ram = num_workers_ram
        self.num_workers_cuda = num_workers_cuda

        self.ram_queue = Queue_with_cleaning(maxsize=ram_queue_length)
        self.cuda_queue = Queue_with_cleaning(maxsize=cuda_queue_length)

    def start(self):
        """Launch the RAM-feeding and device-processing worker threads."""
        for _ in range(self.num_workers_ram):
            thread = Thread(target=threaded_batches_feeder,
                            args=(self.controller,
                                  self.train_generator,
                                  self.val_generator,
                                  self.ram_queue))
            thread.start()

        for _ in range(self.num_workers_cuda):
            thread = Thread(target=threaded_cuda_batches,
                            args=(self.controller,
                                  self.ram_queue,
                                  self.cuda_queue,
                                  self.img_transform_train,
                                  self.img_transform_val,
                                  self.device,
                                  self.dataset_name))
            thread.start()

    def change_task(self, new_status):
        """Drain the pipeline and switch the workers to ``new_status``."""
        self.controller.change_status(new_status, self.ram_queue, self.cuda_queue)

    def get(self, block=True, timeout=None):
        """Return the next processed batch.

        ``block`` and ``timeout`` have ``queue.Queue.get`` semantics. ``block``
        used to be ignored, so ``get(block=False)`` still blocked.

        Raises:
            queue.Empty: if no batch is available in time (non-blocking or
                timed get).
        """
        return self.cuda_queue.get(block=block, timeout=timeout)

    def get_train_batch_size(self):
        return self.train_generator.get_batch_size()

    def get_val_batch_size(self):
        return self.val_generator.get_batch_size()

    def get_train_steps_per_epoch(self):
        return self.train_generator.get_steps_per_epoch()

    def get_val_steps_per_epoch(self):
        return self.val_generator.get_steps_per_epoch()

    def stop_and_clear(self):
        """Ask all workers to stop, empty both queues and free cached CUDA memory."""
        self.controller.set_tokill(True)

        # Drain with short timed gets so that producers blocked on a full
        # queue get a free slot and can observe the kill flag. Once a queue
        # stays empty for a whole timeout, nobody is blocked on put(), so
        # stop waiting on it.
        for pending, length in ((self.cuda_queue, self.cuda_queue_length),
                                (self.ram_queue, self.ram_queue_length)):
            for _ in range(length):
                try:
                    pending.get(block=True, timeout=1)
                except Empty:
                    break

        for pending in (self.ram_queue, self.cuda_queue):
            with pending.mutex:
                pending.queue.clear()
                # Clearing the deque directly bypasses Queue.get(), so nobody
                # would wake the producers waiting in put(); notify them here.
                pending.not_full.notify_all()

        clear_cuda()