# Dataloader

## Includes

In [None]:
# mass includes
import os
import torch as t
import numpy as np
from PIL import Image
from pickle import load, dump
from torch.utils.data import Dataset
from ipynb.fs.full.util import randCrop, randHorFlip, randVerFlip

## Denoising set

In [None]:
class ImageSet(Dataset):
    """Dataset of three-channel samples stored as one TIFF file per channel.

    The files in ``opt.data_path`` are sorted by name and every consecutive
    run of three files is treated as channels 0, 1 and 2 of one sample.
    Per-channel mean/std are cached in ``<opt.save_path>/norm_params.pkl``;
    when that file is missing they are computed from *all* listed files
    (before the train/test split) and written there.
    """

    def __init__(self, opt, mode='train', partition=(0.7, 0.3)):
        """Index the files, load or compute the statistics, and split.

        Args:
            opt: Object exposing ``data_path``, ``crop_size`` and
                ``save_path`` attributes.
            mode: ``'train'`` keeps the leading part of the sorted file
                list, ``'test'`` keeps the trailing part.
            partition: ``(train_fraction, test_fraction)`` of the samples.

        Raises:
            AssertionError: If ``mode`` is neither ``'train'`` nor ``'test'``.
        """
        # The message must be a string: passing print(...) here would print
        # and then attach None to the AssertionError.
        assert mode in ('train', 'test'), 'Invalid mode: %s' % mode
        super().__init__()
        self.data_path = opt.data_path
        self.crop_size = opt.crop_size
        self.mode = mode

        # list all samples; match on the extension only, so that sidecar
        # files such as 'x.tif.aux.xml' do not shift the channel grouping
        self.file_list = sorted(
            name for name in os.listdir(self.data_path)
            if name.endswith(('.tif', '.tiff'))
        )

        # load normalization params
        # NOTE: the cache is a pickle file; only point save_path at trusted
        # locations, since unpickling can execute arbitrary code.
        param_path = os.path.join(opt.save_path, 'norm_params.pkl')
        if os.path.exists(param_path):
            with open(param_path, 'rb') as file:
                data_dict = load(file)
            self.mean = data_dict['mean']
            self.std = data_dict['std']
        else:
            self.mean, self.std = self._compute_norm_params()

            # save to file
            with open(param_path, 'wb') as file:
                dump({'mean': self.mean, 'std': self.std}, file)

        # divide into training and testing sets
        # NOTE(review): both sizes are rounded independently, so for some
        # sample counts the train and test slices can overlap by one sample
        # (e.g. 5 samples with the default fractions) - confirm this is
        # acceptable.
        n_files = len(self.file_list)
        if mode == 'train':
            n_keep = self._split_size(n_files, partition[0])
            self.file_list = self.file_list[:n_keep]
        else:
            n_keep = self._split_size(n_files, partition[1])
            # Slice from an explicit start index: a plain [-n_keep:] would
            # return the whole list when n_keep is 0.
            self.file_list = self.file_list[max(n_files - n_keep, 0):]

    @staticmethod
    def _split_size(n_files, fraction):
        """Return the number of files (a multiple of 3) for ``fraction``."""
        return int(round(n_files / 3 * fraction) * 3)

    @staticmethod
    def _read_image(path, dtype=None):
        """Load one image file into a NumPy array and close the file."""
        with Image.open(path) as img:
            return np.array(img, dtype=dtype)

    def _compute_norm_params(self):
        """Return per-channel ``(mean, std)`` lists over every listed file.

        The channel of a file is its position in the sorted list modulo 3.
        Two passes are made: one for the means, one for the deviations.
        """
        paths = [os.path.join(self.data_path, name) for name in self.file_list]

        # calculate mean for each channel
        total = [0.0, 0.0, 0.0]
        count = [0, 0, 0]
        for idx, path in enumerate(paths):
            img = self._read_image(path)
            total[idx % 3] += np.sum(img)
            count[idx % 3] += img.size
        mean = [x / y for (x, y) in zip(total, count)]

        # calculate std for each channel (pixel counts are the same as above)
        diff = [0.0, 0.0, 0.0]
        for idx, path in enumerate(paths):
            img = self._read_image(path)
            diff[idx % 3] += np.sum((img - mean[idx % 3])**2)
        std = [np.sqrt(x / y) for (x, y) in zip(diff, count)]

        return mean, std

    def __getitem__(self, index):
        """Return sample ``index`` as a standardized, augmented tensor.

        Returns:
            In ``'train'`` mode the tensor alone; otherwise a tuple of the
            tensor and the part of the third channel's file name that
            precedes the first ``'-'``.
        """
        # load all three channels
        channels = []
        for i in range(3):
            f_name = self.file_list[int(index * 3 + i)]
            file_path = os.path.join(self.data_path, f_name)
            img = self._read_image(file_path, dtype='float32')

            # standardize with the channel statistics (zero mean / unit std
            # w.r.t. the cached values). The stats are coerced to Python
            # floats so the result stays float32 even when the cached values
            # are NumPy float64 scalars.
            channels.append((img - float(self.mean[i])) / float(self.std[i]))

        # to 3D array, channels first
        out_img = np.stack(channels, axis=0)

        # random transforms (applied in both modes)
        out_img = randCrop(out_img, crop_size=self.crop_size)
        out_img = randHorFlip(out_img)
        out_img = randVerFlip(out_img)

        # to pytorch tensor
        out_img = t.tensor(out_img)

        if self.mode == 'train':
            return out_img
        return out_img, f_name.split('-')[0]

    def __len__(self):
        """Return the number of complete three-file samples."""
        return len(self.file_list) // 3