# Dataloader

## Imports

In [None]:
# imports: stdlib, numpy/torch/OpenCV, and helpers from the util notebook
import os
import torch as t
import numpy as np
from cv2 import imread
from pickle import load, dump
from torch.utils.data import Dataset
from ipynb.fs.full.util import randCrop, randFlip

## Image set

In [None]:
class ImageSet(Dataset):
    """Dataset of multi-channel image samples bundled four at a time.

    Every sample ``<prefix>`` in ``opt.data_path`` is read from the files
    ``<prefix>-DNA.tif``, ``<prefix>-pH3.tif`` and ``<prefix>-actin.tif``
    (plus ``<prefix>-mask.png`` when ``mask_out`` is set). Sample prefixes
    are discovered from the ``*-actin.tif`` files and sorted. Each run of
    four consecutive prefixes forms one dataset item.

    Per-channel mean/std are computed over *all* discovered samples (before
    the train/test split) and cached in ``norm-params.pkl`` inside
    ``data_path``. NOTE(review): test samples therefore contribute to the
    statistics. The cache is keyed only on the file's existence, so delete
    it after changing the data.

    Args:
        opt: config object providing ``data_path``, ``crop_size`` and
            ``data_part`` (numbers of 4-sample groups for train / test).
        mode: ``'train'`` keeps the first ``data_part[0]`` groups,
            ``'test'`` keeps the last ``data_part[1]`` groups, ``'all'``
            keeps everything.
        norm: standardize each image channel with the cached mean/std.
        rand_trans: apply ``randCrop`` (with ``opt.crop_size``) and
            ``randFlip`` from the ``util`` notebook to every sample.
        mask_out: append ``<prefix>-mask.png`` divided by 255 as an extra
            channel.

    Raises:
        AssertionError: on an invalid ``mode`` or non-boolean flag.
        ValueError: if statistics must be computed but no samples exist.
    """

    # Per-sample image files, in output channel order.
    _CHANNEL_SUFFIXES = ('-DNA.tif', '-pH3.tif', '-actin.tif')
    # File suffix used to discover sample prefixes.
    _INDEX_SUFFIX = '-actin.tif'
    # Number of consecutive samples bundled into one dataset item.
    _GROUP_SIZE = 4

    def __init__(self,
                 opt,
                 mode='train',
                 norm=True,
                 rand_trans=True,
                 mask_out=True):
        # Pass the message string itself: ``assert cond, print(msg)`` would
        # raise an AssertionError whose message is ``None``.
        assert mode in ['train', 'test', 'all'], 'mode must be train, test, or all'
        assert rand_trans in [True, False], 'rand_trans must be a boolean value'
        assert norm in [True, False], 'norm must be a boolean value'
        assert mask_out in [True, False], 'mask_out must be a boolean value'

        self.data_path = opt.data_path
        self.crop_size = opt.crop_size
        self.mode = mode
        self.norm = norm
        self.rand_trans = rand_trans
        self.mask_out = mask_out

        # Statistics use the full sample list, so partition afterwards.
        self.file_list = self._list_samples()
        self.mean, self.std = self._load_or_compute_norm_params()
        self.file_list = self._partition(self.file_list, opt.data_part)

    def _list_samples(self):
        """Return the sorted sample prefixes of all ``*-actin.tif`` files."""
        suffix = self._INDEX_SUFFIX
        # Strip the exact suffix so prefixes may contain '-' and unrelated
        # files that merely contain 'actin.tif' are not picked up.
        return sorted(
            name[:-len(suffix)]
            for name in os.listdir(self.data_path)
            if name.endswith(suffix)
        )

    @staticmethod
    def _imread(path):
        """Read *path* unchanged with OpenCV, raising if it cannot be read."""
        img = imread(path, -1)  # -1 == cv2.IMREAD_UNCHANGED (keep bit depth)
        if img is None:
            # cv2.imread reports failure by returning None, not by raising.
            raise FileNotFoundError('could not read image: ' + path)
        return img

    def _load_or_compute_norm_params(self):
        """Return ``(mean, std)`` lists, from the cache or freshly computed."""
        param_path = os.path.join(self.data_path, 'norm-params.pkl')
        if os.path.exists(param_path):
            # NOTE: unpickling can execute arbitrary code; this cache must
            # only ever come from a trusted data directory.
            with open(param_path, 'rb') as file:
                data_dict = load(file)
            return data_dict['mean'], data_dict['std']

        mean, std = self._compute_norm_params()
        with open(param_path, 'wb') as file:
            dump({'mean': mean, 'std': std}, file)
        return mean, std

    def _compute_norm_params(self):
        """Compute per-channel mean and population std over all samples.

        Two passes keep only one image in memory at a time. Sums are
        accumulated in float64 because summing many float32 pixels in
        float32 loses precision.
        """
        if not self.file_list:
            # 0/0 would yield NaN statistics that then get cached on disk.
            raise ValueError(
                'no samples matching *%s found in %r; cannot compute '
                'normalization parameters' % (self._INDEX_SUFFIX, self.data_path))

        n_chnl = len(self._CHANNEL_SUFFIXES)
        total = [0.0] * n_chnl
        count = [0] * n_chnl
        for prefix in self.file_list:
            base = os.path.join(self.data_path, prefix)
            for c, suffix in enumerate(self._CHANNEL_SUFFIXES):
                img = self._imread(base + suffix).astype('float32')
                total[c] += float(np.sum(img, dtype=np.float64))
                count[c] += img.size
        mean = [s / n for s, n in zip(total, count)]

        sq_diff = [0.0] * n_chnl
        for prefix in self.file_list:
            base = os.path.join(self.data_path, prefix)
            for c, suffix in enumerate(self._CHANNEL_SUFFIXES):
                img = self._imread(base + suffix).astype('float32')
                sq_diff[c] += float(
                    np.sum(np.square(img - mean[c]), dtype=np.float64))
        std = [float(np.sqrt(d / n)) for d, n in zip(sq_diff, count)]
        return mean, std

    def _partition(self, file_list, data_part):
        """Return the part of *file_list* selected by ``self.mode``."""
        size = self._GROUP_SIZE
        if self.mode == 'train':
            return file_list[:data_part[0] * size]
        if self.mode == 'test':
            # Use an explicit start index: ``file_list[-0:]`` would return
            # the whole list when ``data_part[1] == 0``.
            start = max(0, len(file_list) - data_part[1] * size)
            return file_list[start:]
        return file_list

    def __getitem__(self, index):
        """Load the ``index``-th group of four samples.

        Returns:
            ``(tensor, names)``. ``names`` is the list of the four sample
            prefixes. ``tensor`` stacks the four per-sample arrays along a
            new leading axis. Each per-sample array stacks the channels
            (DNA, pH3, actin[, mask]) along axis 0 and then goes through
            ``randCrop``/``randFlip`` when ``rand_trans`` is set. Assuming
            single-plane 2-D images, the shape is ``(4, C, H, W)`` with
            ``C`` = 3 or 4 — TODO confirm.

        Raises:
            IndexError: if ``index`` is outside ``[-len(self), len(self))``.
        """
        n_items = len(self)
        if index < 0:
            index += n_items
        if not 0 <= index < n_items:
            raise IndexError(
                'index out of range for ImageSet of length %d' % n_items)

        size = self._GROUP_SIZE
        f_name = self.file_list[index * size:(index + 1) * size]

        out_sample = []
        for prefix in f_name:
            file_path = os.path.join(self.data_path, prefix)

            assm_img = []
            for c, suffix in enumerate(self._CHANNEL_SUFFIXES):
                chnl = self._imread(file_path + suffix).astype('float32')
                if self.norm:
                    # Standardize to zero mean / unit std. This is not a
                    # fixed [-1, 1] range.
                    chnl = (chnl - self.mean[c]) / self.std[c]
                assm_img.append(chnl)

            if self.mask_out:
                mask = self._imread(file_path + '-mask.png') / 255
                assm_img.append(mask.astype('float32'))

            assm_img = np.stack(assm_img, axis=0)

            if self.rand_trans:
                assm_img = randCrop(assm_img, crop_size=self.crop_size)
                assm_img = randFlip(assm_img)

            out_sample.append(assm_img)

        out_sample = t.tensor(np.stack(out_sample, axis=0))

        return out_sample, f_name

    def __len__(self):
        """Number of complete 4-sample groups; a trailing partial group is dropped."""
        return len(self.file_list) // self._GROUP_SIZE