In [32]:
critic_type = 'baseline'

In [33]:
import sys
sys.path.append("../../../")

import skimage

import fastai
from torchvision import transforms
from fastai import *
from fastai.vision import *
from fastai.callbacks import *
from fastai.script import *

from model.metrics import *

from skimage.measure import compare_psnr, compare_ssim

from data.load import get_data

# Choose GPU Device

In [34]:
torch.cuda.set_device(0)

# Load Data

### Parameters for Data Loading

In [35]:
data_pth = Path('/home/alaa/Dropbox/BPHO Staff/USF')  # path to dataset
lr_dir = f'EM/training/trainsets/lr/'  # input images
hr_dir = f'EM/training/trainsets/hr/'  # target images
model_path = data_pth/f'EM/models/self_sv/{critic_type}'
test_pth = data_pth/f'EM/testing/LR/real-world_SEM'
infer_pth = data_pth/f'EM/testing/self_sv/{critic_type}'

In [36]:
from PIL import Image
img = Image.open(list(test_pth.glob('*.tif'))[0])
img.size

(125, 125)

In [37]:
bs = 8  # batch size
size = 500  # image size

In [38]:
# Store dataset in a databunch object 
data = get_data(data_pth, lr_dir, hr_dir, bs, 
                in_sz=size, out_sz=size, max_zoom=1.)

In [39]:
from model.metrics import psnr, ssim
superres_metrics = [F.mse_loss, psnr, ssim]

# Modelling

## Load Model

In [40]:
arch = models.resnet34
wd = 1e-3
learn = unet_learner(data, arch, wd=wd, 
                     #loss_func=feat_loss,
                     loss_func=F.mse_loss,
                     metrics=superres_metrics,
                     #callback_fns=LossMetrics, 
                     blur=True, norm_type=NormType.Weight, model_dir=model_path)
gc.collect()

0

## Inference

In [41]:
test_set = ImageList.from_folder(test_pth, convert_mode='L')
learn.data.add_test(test_set, tfm_y=False)

In [42]:
model_list = ['5.16_mse_baseline_3.pkl']
tag_list = ['baseline']

In [43]:
from libtiff import TIFF

In [44]:
def model_inference(learner=learn, folder_name=Path("CHANGEME/real-world_SEM"), img_tag=f"CHANGEME"):
    dir_name = infer_pth/folder_name
    if not os.path.isdir(dir_name): dir_name.mkdir(parents=True, exist_ok=True)
    print(dir_name)
    
    
    for img, img_name in zip(learner.data.test_ds, learner.data.test_ds.items):
#         for i in range(3):
#             img[0].data[i] -= mean[i]
#             img[0].data[i] /= std[i]
        
#         for i in range(3):
#             pred[1][i] *= std[i]
#             pred[1][i] += mean[i]
        pred = learner.predict(img[0])
        pred_name = dir_name/img_name.name.replace('lr', img_tag)
        print(pred_name)

        

        pred_img = pred[1][0]
        tiff = TIFF.open(pred_name, mode='w')
        tiff.write_image(pred_img)
        
        print(f"Performed inference on {img_name.stem}, file saved as {pred_name}")
    print("Model Inference Complete")

In [45]:
for model_name, img_tag in zip(model_list, tag_list):
    learn.model.load_state_dict(torch.load(model_path/model_name))
    model_inference(folder_name=Path(f"{model_name}/real-world_SEM"), img_tag=img_tag)

/home/alaa/Dropbox/BPHO Staff/USF/EM/testing/self_sv/baseline/5.16_mse_baseline_3.pkl/real-world_SEM
/home/alaa/Dropbox/BPHO Staff/USF/EM/testing/self_sv/baseline/5.16_mse_baseline_3.pkl/real-world_SEM/realword_SEM_test_baseline_05.tif
Performed inference on realword_SEM_test_lr_05, file saved as /home/alaa/Dropbox/BPHO Staff/USF/EM/testing/self_sv/baseline/5.16_mse_baseline_3.pkl/real-world_SEM/realword_SEM_test_baseline_05.tif
/home/alaa/Dropbox/BPHO Staff/USF/EM/testing/self_sv/baseline/5.16_mse_baseline_3.pkl/real-world_SEM/realword_SEM_test_baseline_38.tif
Performed inference on realword_SEM_test_lr_38, file saved as /home/alaa/Dropbox/BPHO Staff/USF/EM/testing/self_sv/baseline/5.16_mse_baseline_3.pkl/real-world_SEM/realword_SEM_test_baseline_38.tif
/home/alaa/Dropbox/BPHO Staff/USF/EM/testing/self_sv/baseline/5.16_mse_baseline_3.pkl/real-world_SEM/realword_SEM_test_baseline_42.tif
Performed inference on realword_SEM_test_lr_42, file saved as /home/alaa/Dropbox/BPHO Staff/USF/EM/t

/home/alaa/Dropbox/BPHO Staff/USF/EM/testing/self_sv/baseline/5.16_mse_baseline_3.pkl/real-world_SEM/realword_SEM_test_baseline_11.tif
Performed inference on realword_SEM_test_lr_11, file saved as /home/alaa/Dropbox/BPHO Staff/USF/EM/testing/self_sv/baseline/5.16_mse_baseline_3.pkl/real-world_SEM/realword_SEM_test_baseline_11.tif
/home/alaa/Dropbox/BPHO Staff/USF/EM/testing/self_sv/baseline/5.16_mse_baseline_3.pkl/real-world_SEM/realword_SEM_test_baseline_31.tif
Performed inference on realword_SEM_test_lr_31, file saved as /home/alaa/Dropbox/BPHO Staff/USF/EM/testing/self_sv/baseline/5.16_mse_baseline_3.pkl/real-world_SEM/realword_SEM_test_baseline_31.tif
/home/alaa/Dropbox/BPHO Staff/USF/EM/testing/self_sv/baseline/5.16_mse_baseline_3.pkl/real-world_SEM/realword_SEM_test_baseline_17.tif
Performed inference on realword_SEM_test_lr_17, file saved as /home/alaa/Dropbox/BPHO Staff/USF/EM/testing/self_sv/baseline/5.16_mse_baseline_3.pkl/real-world_SEM/realword_SEM_test_baseline_17.tif
/hom

# Model Evaluation

In [21]:
target_dirname = data_pth/f'EM/testing/HR/real-world_SEM/'       # target images
target_files = sorted(list(target_dirname.glob('*.tif')))
print('Processing '+str(len(target_files))+' files...')

Processing 42 files...


In [18]:
def evaluate_model(pred_files, hr_files):
    psnr_scores = {}
    ssim_scores = {}
    for pred_fname, targ_fname in zip(pred_files, hr_files):
        pred_img = PIL.Image.open(pred_fname)
        targ_img = PIL.Image.open(targ_fname)
        
        pred = skimage.util.img_as_ubyte(np.asarray(pred_img))
        targ = skimage.util.img_as_ubyte(np.asarray(targ_img))
        
        psnr_score = compare_psnr(targ, pred, data_range=255.)
        psnr_scores[targ_fname] = psnr_score
        
        ssim_score = compare_ssim(targ, pred, data_range=255.)
        ssim_scores[targ_fname] = ssim_score
        

    return psnr_scores, ssim_scores

In [21]:
results = {}
mean_psnr = "mean psnr"
mean_ssim = "mean ssim"
for model_name, img_tag in zip(model_list, tag_list):
    inference_dirname = infer_pth/f'{model_name}/real-world_SEM/'  
    inference_files = sorted(list(inference_dirname.glob('*.tif')))
    results[img_tag] = {}
    results[img_tag]['psnr'], results[img_tag]['ssim'] = evaluate_model(inference_files, target_files)
    results[img_tag]['mean psnr'] = np.mean([results[img_tag]['psnr'][f] for f in results[img_tag]['psnr'].keys()])
    results[img_tag]['mean ssim'] = np.mean([results[img_tag]['ssim'][f] for f in results[img_tag]['ssim'].keys()])
    print(f'{img_tag} \n \t Mean PSNR: {results[img_tag][mean_psnr]:.3f} \n \t Mean SSIM: {results[img_tag][mean_ssim]:.3f}')

  .format(dtypeobj_in, dtypeobj_out))


inpaint_fresh 
 	 Mean PSNR: 24.582 
 	 Mean SSIM: 0.483


# Visualization

In [46]:
model_list[1]='baseline-pretrained-best'
tag_list[1]='baseline-pretrained-best'

IndexError: list assignment index out of range

In [None]:
model_list, tag_list

In [None]:
model_1_name = model_list[0]
model_2_name = model_list[1]

dir1 = f'{model_1_name}/real-world_SEM/'  # pssr images
dir2 = f'{model_2_name}/real-world_SEM/'  # our images

# Modify accordingly
pth1 = infer_pth/dir1
pth2 = infer_pth.parent/dir2

model1_files = sorted(list(pth1.glob('*.tif')))
model2_files = sorted(list(pth2.glob('*.tif')))

print('Processing '+str(len(model2_files))+' files...')

In [None]:
import ipywidgets as widgets
from ipywidgets import interact, interact_manual
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
np.set_printoptions(linewidth=110)

# Set size for visualizations
fig_size = plt.rcParams["figure.figsize"]  # Get current size
print("Current size:", fig_size)
 
# Set figure width to 12 and height to 9
fig_size[0] = 30
fig_size[1] = 24
plt.rcParams["figure.figsize"] = fig_size

In [None]:
def visualize_sample(img_fname1, img_fname2, targ_fname, title1=model_1_name, title2="Target", title3=model_2_name):
    model1_im = PIL.Image.open(img_fname1) # load image to tensor
    targ_im = PIL.Image.open(targ_fname)
    model2_im = PIL.Image.open(img_fname2)
    
    model1_arr = np.asarray(model1_im)
    model1_arr = skimage.util.img_as_ubyte(model1_arr)
    model2_arr = np.asarray(model2_im)
    model2_arr = skimage.util.img_as_ubyte(model2_arr)
    
    
    targ_arr = np.asarray(targ_im, np.uint8)
    
    model1_psnr = compare_psnr(targ_arr, model1_arr, data_range=255.)
    model1_ssim = compare_ssim(targ_arr, model1_arr, data_range=255.)
    model2_psnr = compare_psnr(targ_arr, model2_arr, data_range=255.)
    model2_ssim = compare_ssim(targ_arr, model2_arr, data_range=255.)
    f, axarr = plt.subplots(3,1) # create visualizations
    axarr[0].imshow(np.squeeze(model1_arr), cmap=plt.cm.gray) # visualize image tensor
    axarr[0].set_title(title1)
    axarr[0].set_xlabel(f"PSNR: {model1_psnr:.2f}, SSIM: {model1_ssim:.2f}")
    axarr[2].imshow(np.squeeze(targ_arr), cmap=plt.cm.gray) # visualize original image file
    axarr[2].set_title(title2)
    axarr[1].imshow(np.squeeze(model2_arr), cmap=plt.cm.gray) # visualize image tensor
    axarr[1].set_title(title3)
    axarr[2].set_xlabel(f"PSNR: {model2_psnr:.2f}, SSIM: {model2_ssim:.2f}")
    plt.show() # show visualization

In [None]:
@interact
def show_sample(sample=33):
    return visualize_sample(model1_files[sample], model2_files[sample], target_files[sample])