# Run trained model on test data

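Create copies of this notebook to process different test scenes simultaneously on different GPUs. In each copy, set `CUDA_VISIBLE_DEVICES` to that copy's GPU and choose a different slice of `scene_names`.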

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import scipy.io as sio
import sys

from pathlib import Path

import torch

sys.path.append('../functions/')
from solvers import restore_net

os.environ['CUDA_DEVICE_ORDER']='PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES']='0'

GPU_DEBUG = False

%matplotlib widget
%load_ext autoreload
%autoreload 2

# Every available test scene; names appear to follow MMDD.HHMM.Description.
all_scene_names = [
    '0625.0928.CheckerEKE',
    '0625.0938.CheckerLED',
    '0625.0955.HHPainting',
    '0625.1049.MatroshkaFamily',
    '0625.1331.FeathersZoom',
    '0628.1131.Flowers',
    '0628.1229.Feathers',
    '0628.1303.Painting',
    '0628.1316.Chopper',
    '0628.1332.SpectralonEKE',
    '0628.1332.SpectralonLED',
    '0625.1342.NewportBird',
    '0723.1436.ButterflyWhite',
    '0723.1419.ButterflyOrange',
    '0729.1303.ButterflyBrown',
    '0729.1320.ButterflyOrangev2',
    '0729.1337.ButterflyBlue',
    '0729.1401.ButterflyTransp',
    '0730.1448.TheBulb_16',
    '0730.1448.TheBulb_33',
    '0730.1529.Plants',
    '0730.1551.CheckerCFL',
    '0803.1155.Plants',
    '0803.1155.Plants2',
    '0806.1542.Plants3',
]

# Shard the work: each notebook copy (e.g. *_gpu0.ipynb, *_gpu1.ipynb) takes
# its own slice of the scene list so several GPUs can run in parallel.
scene_names = all_scene_names[:3]

# Input data file for each selected scene, in the same order as scene_names.
data4_pthfile_list = [f'../data/restore/{name}_data4_1024x1024.pth' for name in scene_names]

# NOTE(review): unlike the '../'-prefixed data and output paths, this
# checkpoint path is relative to the current working directory without '../'
# — confirm it resolves from where the notebook is launched.
restore_saved_model = 'Box/data/RestorationTrainedModel/restore_0625_data4b_r1_pm1_64x64_Assort_PosEncND_Unet192_1024x1024_NoGuide_rand2022/best.pth'

# One .mat file per scene is written here; create the folder up front.
save_mat_path = '../saved_outputs/'
Path(save_mat_path).mkdir(parents=True, exist_ok=True)

# Settings passed through to restore_net.
pattern_idx_save = -1
num_patterns = 92  # not used
patch_size = (1024, 1024)
batch_size = 1
pos_enc = dict(enabled=True, len=64, min_freq=1e-4, nd=True)  # this gives 3*64 channels
use_guide_image = False

def _to_double_numpy(tensor):
    """Return a float64 NumPy copy of a torch tensor.

    The tensor is detached and moved to host memory first: ``.cpu()`` is a
    no-op for CPU tensors but required for CUDA tensors, on which
    ``.numpy()`` raises. ``astype`` always allocates a new array, so the
    result never aliases the tensor's storage and no ``clone()`` is needed.
    """
    return tensor.detach().cpu().numpy().astype(np.double)


# Run the restoration network on each selected scene and save the
# measurement / restored / ground-truth outputs to <save_mat_path>/<scene>.mat.
for scene_name, data_pthfile in zip(scene_names, data4_pthfile_list):
    print(f'Processing {data_pthfile}')

    image_dataset, assort_meas, assort_restored, assort_gt = restore_net(
        data_pthfile, restore_saved_model, pattern_idx_save, patch_size,
        batch_size, pos_enc, use_guide_image)
    print(assort_meas.shape, assort_restored.shape, assort_gt.shape)

    # Kept as notebook globals so the last scene's arrays can be inspected.
    assort_meas_np = _to_double_numpy(assort_meas)
    assort_restored_np = _to_double_numpy(assort_restored)
    assort_gt_np = _to_double_numpy(assort_gt)
    save_dict = {
        'assort_meas': assort_meas_np,
        'assort_restored': assort_restored_np,
        'assort_gt': assort_gt_np,
    }
    if use_guide_image:
        save_dict['guide_image'] = _to_double_numpy(image_dataset.guide_image)

    sio.savemat(os.path.join(save_mat_path, scene_name + '.mat'), save_dict)
                