# Training Dataset Creation

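Create the training and validation `.pth` files from the `data.mat` files of the processed scenes. (Test scenes are not combined here; use `restore_create_data_single_image.ipynb` to create scene-wise `.pth` files for them.)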

In [None]:
import h5py
import os
import matplotlib.pyplot as plt
import numpy as np
import scipy.io as sio
from scipy.interpolate import interp1d
import sys
import copy 
import torch

sys.path.append('../functions/')
from helpers import crop_center_array

%matplotlib widget
%load_ext autoreload
%autoreload 2

In [None]:
# Scene folder names follow the pattern '<MMDD>.<HHMM>.<Name>' (presumably the
# capture date and time). Each split is assembled from per-capture-day sub-lists.

# Training split: 11 scenes from 0625 followed by 4 scenes from 0628.
scene_names_train = [
    '0625.1013.Rulers', '0625.1027.Markers', '0625.1106.YellowLED',
    '0625.1126.Newport', '0625.1135.Feathers', '0625.1342.NewportBird',
    '0625.1412.Elephant', '0625.1439.Cat', '0625.1501.TriceraEKE',
    '0625.1524.BearChopper', '0625.1533.Pony',
] + [
    '0628.1352.TheHumanCondition', '0628.1404.Slinkies',
    '0628.1425.HeatSink', '0628.1436.Cups',
]

# Validation split: 5 scenes from 0625 followed by 3 scenes from 0628.
scene_names_val = [
    '0625.1036.MatroshkaSplit', '0625.1357.SigCup', '0625.1510.TriceraLED',
    '0625.1544.Island', '0625.1552.MatroshkaSmall',
] + [
    '0628.1143.FlowersEKE', '0628.1154.Tulips', '0628.1448.Wood',
]

# Test split: a single combined pth file is not needed for the test scenes.
# Use restore_create_data_single_image.ipynb to create scene-wise pth files.
# The test scenes, kept here for reference only (intentionally not defined):
# scene_names_test = [
#     '0625.0928.CheckerEKE', '0625.0938.CheckerLED', '0625.0955.HHPainting',
#     '0625.1049.MatroshkaFamily', '0625.1331.FeathersZoom',
# ] + [
#     '0628.1131.Flowers', '0628.1229.Feathers', '0628.1303.Painting',
#     '0628.1316.Chopper', '0628.1332.SpectralonEKE', '0628.1332.SpectralonLED',
# ]

In [None]:
PROCESSEDDATA_DIR = 'Box/data/Processed/'
SAVEPTHFILE_DIR = '../data/restore/'

# Choose 'train' or 'val' ('test' only works if scene_names_test is defined above).
dataset_type = 'train'
# Side length (pixels) of the centre crop applied to every array; a falsy value
# (None or 0) keeps the arrays at their stored size.
crop_size = 1024
# Expected number of patterns per scene; verified against the loaded data below.
num_patterns = 92

# Output file name, e.g. '0625_data4b_train_1024x1024.pth'. The variable keeps its
# historical name although it is used for every dataset_type.
_size_tag = f'{crop_size}x{crop_size}' if crop_size else 'full'
save_pthfile_train = '0625' + '_data4b_' + dataset_type + '_' + _size_tag + '.pth'

# Map each split name to its scene list. scene_names_test is commented out in the
# previous cell (test scenes get scene-wise pth files from
# restore_create_data_single_image.ipynb), so 'test' is only offered if it exists.
_scene_lists = {'train': scene_names_train, 'val': scene_names_val}
if 'scene_names_test' in globals():
    _scene_lists['test'] = globals()['scene_names_test']
if dataset_type not in _scene_lists:
    raise ValueError(
        f"dataset_type must be one of {sorted(_scene_lists)}, got {dataset_type!r}")
scene_names = _scene_lists[dataset_type]
num_scenes = len(scene_names)
if num_scenes == 0:
    raise ValueError(f'No scenes listed for dataset_type={dataset_type!r}')


def _load_scene_tensors(data_matfile, crop_size):
    """Load one processed scene from a data.mat file.

    Reads the 'assort_sim', 'assort_meas', 'assort_index' and 'guide' variables,
    converts each to float64, centre-crops it to crop_size x crop_size with
    crop_center_array when crop_size is truthy, moves the last axis to the front
    and converts the result to a float32 tensor.

    Returns:
        Tuple (assort_sim, assort_meas, assort_index, guide_image) of tensors.
    """
    data = sio.loadmat(data_matfile)
    tensors = []
    for key in ('assort_sim', 'assort_meas', 'assort_index', 'guide'):
        array = data[key].astype(float)
        if crop_size:
            array = crop_center_array(array, crop_size, crop_size)
        # assumes the .mat arrays are stored as (H, W, C) — TODO confirm;
        # transpose(2, 0, 1) then yields the channel-first (C, H, W) layout.
        tensors.append(torch.tensor(array.transpose(2, 0, 1), dtype=torch.float))
    return tuple(tensors)


assort_sim_t = assort_meas_t = assort_index_t = guide_image_t = None
for i, scene_name in enumerate(scene_names):
    data_matfile = os.path.join(PROCESSEDDATA_DIR, scene_name, 'data.mat')
    print(f'{i + 1}/{num_scenes} {data_matfile}')
    assort_sim, assort_meas, assort_index, guide_image = _load_scene_tensors(
        data_matfile, crop_size)

    # Fail early with a readable message instead of an opaque shape error.
    for _name, _tensor in (('assort_sim', assort_sim),
                           ('assort_meas', assort_meas),
                           ('assort_index', assort_index)):
        if _tensor.shape[0] != num_patterns:
            raise ValueError(f'{data_matfile}: {_name} has {_tensor.shape[0]} '
                             f'patterns, expected num_patterns={num_patterns}')

    if i == 0:
        # Allocate the stacked outputs once, from the first scene's shapes, so a
        # falsy crop_size also works. Filling in place keeps peak memory at one
        # copy of the dataset (building a list and torch.stack would need two).
        assort_sim_t = torch.zeros((num_scenes, *assort_sim.shape))
        assort_meas_t = torch.zeros((num_scenes, *assort_meas.shape))
        assort_index_t = torch.zeros((num_scenes, *assort_index.shape))
        guide_image_t = torch.zeros((num_scenes, *guide_image.shape))

    # Later scenes must have the same shapes as the first; assignment raises otherwise.
    assort_sim_t[i] = assort_sim
    assort_meas_t[i] = assort_meas
    assort_index_t[i] = assort_index
    guide_image_t[i] = guide_image

os.makedirs(SAVEPTHFILE_DIR, exist_ok=True)
# These tensors were created with torch.zeros and never require grad, so they are
# saved directly; an extra detach/clone copy would only double peak memory.
torch.save({'assort_sim': assort_sim_t,
            'assort_meas': assort_meas_t,
            'assort_index': assort_index_t,
            'guide_image': guide_image_t,
            },
           os.path.join(SAVEPTHFILE_DIR, save_pthfile_train))