## PyTorch MRI image class

The purpose of this notebook is to work out a class that I can use consistently for super-resolution and, more generally, for image training on MRI and potentially other medical imaging data types.

It's become clear that everyone has their own way of loading and organizing data into a format that can be fed into a PyTorch `Dataset` for training a model. Some of these approaches create intermediate `.png` images to limit memory usage and training time on personal hardware.
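A minimal sketch of that intermediate-PNG step, assuming each scan has already been loaded as a 3-D NumPy array (for example with a NIfTI reader, not shown). The helper name `volume_to_uint8_slices` is hypothetical, and the per-volume min-max rescaling to 0-255 is an assumption: MRI intensities have no fixed range, so another normalization (e.g. percentile clipping) may suit the data better.

```python
import numpy as np


def volume_to_uint8_slices(vol: np.ndarray, axis: int = 2) -> list:
    """Split a 3-D volume into 2-D uint8 slices along `axis` (hypothetical helper).

    Intensities are min-max rescaled over the whole volume (an assumption),
    so relative brightness between slices is preserved.
    """
    if vol.ndim != 3:
        raise ValueError(f"expected a 3-D volume, got shape {vol.shape}")
    vol = vol.astype(np.float64)
    lo, hi = vol.min(), vol.max()
    if hi > lo:
        scaled = (vol - lo) / (hi - lo) * 255.0
    else:
        # constant volume: map everything to 0 to avoid dividing by zero
        scaled = np.zeros_like(vol)
    scaled = np.rint(scaled).astype(np.uint8)
    # move the slicing axis to the front and return one 2-D array per slice
    return list(np.moveaxis(scaled, axis, 0))
```

Each slice could then be written with Pillow, e.g. `Image.fromarray(s).save(path)`, which stores a 2-D `uint8` array as an 8-bit grayscale PNG. Scaling per volume rather than per slice keeps intensities comparable across slices of the same scan.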

Goals for this tool:
1. Given an input folder, list all files that match a particular `prefix` and `suffix`
2. Display sample images from the list as a sanity check
3. Create randomly shuffled/altered images at different resolutions (Gaussian blur, affine transformation, etc.)
4. Save any generated image in a specified format
5. Given an input and an output folder, generate a list of matching files
6. Track the locations of matching low- and high-resolution images (potentially one list per x2 magnification level)
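Goal 3 can be sketched for a single 2-D slice as a Gaussian blur followed by decimation, a common way to simulate a lower-resolution image. This is a NumPy-only sketch, not the notebook's `gen_LR`; the function names, the reflect padding at the borders, and the default `sigma = scale / 2` are assumptions to tune (scikit-image also ships tested versions of these operations).

```python
from typing import Optional

import numpy as np


def gaussian_kernel_1d(sigma: float) -> np.ndarray:
    """Normalized 1-D Gaussian kernel truncated at about 3 sigma."""
    if sigma <= 0:
        raise ValueError("sigma must be positive")
    radius = max(1, int(round(3 * sigma)))
    x = np.arange(-radius, radius + 1, dtype=np.float64)
    kernel = np.exp(-(x ** 2) / (2 * sigma ** 2))
    return kernel / kernel.sum()


def gaussian_blur(img: np.ndarray, sigma: float) -> np.ndarray:
    """Separable Gaussian blur of a 2-D array, reflect-padded at the borders."""
    kernel = gaussian_kernel_1d(sigma)
    r = len(kernel) // 2
    padded = np.pad(img.astype(np.float64), r, mode="reflect")

    def conv(v):
        # 'valid' drops the padding again, so each axis returns to its original length
        return np.convolve(v, kernel, mode="valid")

    rows = np.apply_along_axis(conv, 1, padded)  # blur along each row
    return np.apply_along_axis(conv, 0, rows)    # then along each column


def make_lr(img: np.ndarray, scale: int = 2, sigma: Optional[float] = None) -> np.ndarray:
    """Simulate a low-resolution slice: blur, then keep every `scale`-th pixel."""
    if img.ndim != 2:
        raise ValueError("expected a 2-D slice")
    if sigma is None:
        sigma = scale / 2.0  # assumed default; tune for the real data
    blurred = gaussian_blur(img, sigma)
    h, w = img.shape
    # crop so both dimensions divide evenly by `scale`, then decimate
    return blurred[: h - h % scale : scale, : w - w % scale : scale]
```

Drawing `sigma` (or adding an affine warp) at random per image would give the "randomly altered" variants; calling `make_lr` with `scale=2, 4, ...` yields one low-resolution version per x2 level for goal 6.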


Output file labeling protocol:
To keep things consistent, I should create a labeling scheme that is robust to future changes I might want to make. The metadata will most likely have to be stored in the filename itself, unless I create an intermediate key file that maps each generated image to its location relative to the directory.
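One way to keep the metadata in the filename is a fixed separator between fields. The sketch below assumes a hypothetical `<stem>__x<scale>__<transform>.png` pattern and that original stems never contain `__`; `build_name`, `parse_name`, and `ImageTag` are illustrative names, not part of the notebook.

```python
import os
from typing import NamedTuple

SEP = "__"  # hypothetical field separator; assumes stems never contain "__"


class ImageTag(NamedTuple):
    stem: str       # original file name without extension
    scale: int      # downsampling factor, e.g. 2 for x2
    transform: str  # label for the alteration applied, e.g. "blur"


def build_name(stem: str, scale: int, transform: str = "none", ext: str = ".png") -> str:
    """Encode the metadata in the output filename."""
    if SEP in stem:
        raise ValueError(f"stem must not contain {SEP!r}: {stem}")
    return f"{stem}{SEP}x{scale}{SEP}{transform}{ext}"


def parse_name(fname: str) -> ImageTag:
    """Recover the metadata; raises ValueError for names not built by build_name."""
    base, _ext = os.path.splitext(os.path.basename(fname))
    stem, scale, transform = base.split(SEP)  # ValueError if there are not 3 fields
    if not scale.startswith("x"):
        raise ValueError(f"unexpected scale field: {scale}")
    return ImageTag(stem, int(scale[1:]), transform)
```

Because `parse_name(f).stem` strips the metadata, it is one candidate for the stripping function mentioned in the `match_altered` TODO. Adding a field later means changing only these two functions.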

## Import necessary libraries

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import numpy as np
from PIL import Image
import cv2
from matplotlib import pyplot
from skimage.transform import rotate, AffineTransform, warp, rescale
```

## Class definition

```python
import os

# The super-resolution class:

class sr_gen:
    def __init__(self, inp_dir, out_dir, prefix='', suffix=''):
        self.inp_dir = inp_dir
        self.out_dir = out_dir
        self.inp_files = self._get_inp(prefix, suffix)

    def _get_inp(self, prefix, suffix):
        # Get the original files that will be used to generate everything,
        # filtered by prefix and suffix; for now these could be NIfTI files or PNGs.
        # Will have to write a "try" statement in order to check the file type.
        # os.listdir order is arbitrary, so sort for reproducibility.
        files = sorted(
            fil for fil in os.listdir(self.inp_dir)
            if fil.startswith(prefix) and fil.endswith(suffix)
        )

        # Fail loudly when nothing in the input directory matched
        if not files:
            raise FileNotFoundError('No applicable files found in input directory')

        return files

    def _get_out(self):
        # Get the list of files in the output directory to determine matching files
        return os.listdir(self.out_dir)

    def match_altered(self):
        # Get the files that have been generated in the output directory
        ex_files = self._get_out()

        # Files present in both directories. A - (A - B) is the intersection
        # A & B regardless of which list is longer, so one expression suffices.
        # TODO: add stripping function to remove metadata from string name
        matches = sorted(set(ex_files) & set(self.inp_files))

        return matches

    def change_inp(self, inp_dir, out_dir):
        # Change which set of images are viewed as the inputs and outputs
        pass

    def gen_LR(self):
        # Generate low-resolution equivalents for input images.
        # This will probably be the main method in this class; it'll take:
        # input files, desired changes (probably in dict form),
        pass

    def imresize(self):
        # Given a particular resolution to change to, upsample or downsample the original image
        pass

    def img_transform(self):
        # Transform the original files using a variety of methods
        pass

    def img2patches(self):
        # Take a given image and generate patches
        pass

    def save_img(self):
        # Save the generated image to the location specified
        # Should specify the number of
        pass
```
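The `img2patches` stub could start as a sliding window over a 2-D slice. A minimal NumPy sketch, assuming square patches and a hypothetical `stride` parameter (non-overlapping when `stride == patch_size`); edge pixels that do not fill a whole patch are dropped, which is a design choice rather than a requirement.

```python
from typing import Optional

import numpy as np


def img2patches(img: np.ndarray, patch_size: int, stride: Optional[int] = None) -> np.ndarray:
    """Return an array of shape (n_patches, patch_size, patch_size, ...)."""
    if stride is None:
        stride = patch_size  # non-overlapping by default
    h, w = img.shape[:2]
    if h < patch_size or w < patch_size:
        raise ValueError(f"image {img.shape} is smaller than patch_size={patch_size}")
    # row-major order: left to right, then top to bottom
    patches = [
        img[top:top + patch_size, left:left + patch_size]
        for top in range(0, h - patch_size + 1, stride)
        for left in range(0, w - patch_size + 1, stride)
    ]
    return np.stack(patches)
```

For super-resolution pairs, assuming the low-resolution image was produced by decimating by `scale`, the HR patch at `(top, left)` corresponds to the LR patch at `(top // scale, left // scale)` with size `patch_size // scale`; keeping `patch_size` and `stride` multiples of `scale` keeps the two aligned.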


```python
test = sr_gen('./data/nii_HR_ax/', './data/nii_HR_sag/')
test.inp_files
```

```python
a = {'1', '2', '3', '2'}  # the duplicate '2' collapses: a == {'1', '2', '3'}
b = {'1', '2'}

set(a) - (set(a) - set(b))
# {'1', '2'}

set(a) - set(b)
# {'3'}

list(a)
# ['3', '1', '2']  (set iteration order is arbitrary)
```
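These scratch cells confirm that `A - (A - B)` is just the intersection `A & B`. Matching generated outputs to their inputs needs one extra step, because output names carry metadata. A sketch, assuming a hypothetical `__` separator after the stem and that the only dots in an input name belong to its extension (e.g. `.nii.gz`); `input_stem`, `output_stem`, and `match_by_stem` are illustrative names.

```python
from collections import defaultdict

SEP = "__"  # hypothetical separator between the stem and the metadata fields


def input_stem(fname: str) -> str:
    # assumes the only dots in the name belong to its extension (.png, .nii.gz)
    return fname.split(".", 1)[0]


def output_stem(fname: str) -> str:
    # everything before the first separator, minus any extension
    return input_stem(fname.split(SEP, 1)[0])


def match_by_stem(inp_files, out_files):
    """Map each input file to the sorted generated files that share its stem."""
    inputs = {input_stem(f): f for f in inp_files}
    matches = defaultdict(list)
    for out in out_files:
        stem = output_stem(out)
        if stem in inputs:  # membership test on the stems, i.e. the intersection
            matches[inputs[stem]].append(out)
    return {k: sorted(v) for k, v in matches.items()}
```

Returning a dict keyed by input file also fits goal 6: each input maps to all of its generated versions (for instance one per magnification level), rather than the flat list `match_altered` currently returns.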