# Use pre-trained model to denoise experimental lysosome image

In [None]:
import os
import sys
sys.path.append("..")

import torch
import math
import json
import matplotlib.pyplot as plt
import os.path
import datetime
import numpy as np

from util import show, plot_images, plot_tensors, psnr
from data_loader import tiff_loader, load_confocal

from util import getbestgpu
from models.unet import Unet
from metric import frc, match_intensity, quantify, plot_quantifications
from train import train

In [None]:
# Pick the compute device; `device` is the target of the `.to(device)` calls
# applied to the image tensors and the models in the evaluation cell.
# NOTE(review): the meaning of the argument `1` is defined by util.getbestgpu
# and is not visible here — confirm before changing it.
device = getbestgpu(1)

In [None]:
# Experiment configuration read by the data-loading and evaluation cells.
config = {
    # Only the first entry is used: it is passed to load_confocal and is the
    # number of samples intensity-matched in the evaluation cell.
    'sample_size_list': [10],
    # Passed to load_confocal as its first positional argument.
    'root': '/Data/Confocal/Cropped',
    # Peak-signal levels to sweep: 10, 15, ..., 50.
    # Alternative sweep kept for reference: [s for s in range(10,100,10)]
    'psignal_levels': list(range(10, 51, 5)),
    # Other types listed by the original author:
    # ['DNA', 'lysosome', 'microtubule', 'mitochondria']
    'types': ['lysosome'],
    'captures': 1,
    'train_stat': True,
    'patch_size': 128,
    'batch_size': 32,
    # Copied into `n_epoch` in the evaluation cell.
    'n_iters': 10000,
    # Number of passes of the per-signal repeat loop; with 0 that loop body
    # does not execute.
    'repeats': 0,
    # Referenced only by the commented-out quantify(...) calls.
    # NOTE(review): 'ssmi' may be a misspelling of 'ssim' — confirm against
    # metric.quantify before renaming.
    'metrics_key': ['mse', 'ssmi', 'frc'],
    'loss': 'mse',
    # Keyword arguments forwarded to Unet(**config['Unet']).
    'Unet': {'up': 'tconv'},
}

In [None]:
# Load the confocal data once for every peak-signal level, with split_ratio=0.
# Each pass rebinds loader/dataset/noisy/clean, so after the loop these names
# refer to the data of the last signal level only.
for signal in config['psignal_levels']:
    loader, dataset = load_confocal(
        config['root'],
        config['train_stat'],
        config['batch_size'],
        [signal],
        config['sample_size_list'][0],
        split_ratio=0,
        types=config['types'],
        captures=config['captures'],
        patch_size=config['patch_size'],
        loader=tiff_loader,
    )
    # Element 0 is used as the noisy input, element 1 as the clean reference.
    noisy, clean = dataset[0], dataset[1]


In [None]:
# Per-signal-level result containers, one per model initialisation
# (FMD-pretrained, randomly initialised, simulation-pretrained). They are
# only meant to be filled by the quantify(...) calls, which are currently
# commented out at the bottom of this cell.
FMD = {}
random = {}
simulation = {}

for signal in config['psignal_levels']:
    key_s = signal
    FMD[key_s] = {}
    random[key_s] = {}
    simulation[key_s] = {}

    print(f"Running training with {signal} peak signal level...")
    n_epoch = config['n_iters']

    # NOTE: with config['repeats'] == 0 the body below does not execute.
    for repeat in range(config['repeats']):
        print(f"No. {repeat}...")
        loader, dataset = load_confocal(config['root'], config['train_stat'], config['batch_size'],
                                        [signal], config['sample_size_list'][0], split_ratio = 0.2,
                                        types=config['types'], captures=config['captures'],
                                        patch_size=config['patch_size'], loader=tiff_loader)
        # Element 0 is used as the noisy input, element 1 as the clean reference.
        noisy = dataset[0]
        clean = dataset[1]
        noisy, clean = noisy.type(torch.FloatTensor).to(device), clean.type(torch.FloatTensor).to(device)

        # Three networks with the same architecture; two of them start from
        # saved weights, the third keeps its constructor initialisation.
        model_random = Unet(**config['Unet'])
        model_FMD = Unet(**config['Unet'])
        model_simulation = Unet(**config['Unet'])

        model_FMD.load_state_dict(torch.load('./trained_models/FMD_epoch50_model'))
        model_simulation.load_state_dict(torch.load('./trained_models/MT_simulation_iter1000_model_trained'))

        # `Adam` is not imported anywhere in this notebook, so reference it
        # through the already-imported torch package to avoid a NameError.
        # The optimizers are not stepped in the visible part of this cell.
        optimizer_random = torch.optim.Adam(model_random.parameters(), lr=0.001)
        optimizer_FMD = torch.optim.Adam(model_FMD.parameters(), lr=0.0001)
        optimizer_simulation = torch.optim.Adam(model_simulation.parameters(), lr=0.0001)

        model_random = model_random.to(device)
        model_FMD = model_FMD.to(device)
        model_simulation = model_simulation.to(device)

        # The outputs are only plotted and converted to numpy below, so no
        # autograd graph is needed for these forward passes.
        # NOTE(review): the models are left in their default mode; call
        # .eval() here if Unet contains dropout/batch-norm layers — confirm.
        with torch.no_grad():
            output_FMD = model_FMD(noisy)
            output_random = model_random(noisy)
            output_simulation = model_simulation(noisy)

        # plot example images
        if repeat == 0:
            nplot = 1
            plot_tensors([noisy[nplot,0,:], clean[nplot,0,:], output_random[nplot,0,:], output_FMD[nplot,0,:], output_simulation[nplot,0,:]])

        # From here on everything is handled as numpy arrays on the CPU.
        output_FMD = output_FMD.cpu().detach().numpy()
        output_random = output_random.cpu().detach().numpy()
        output_simulation = output_simulation.cpu().detach().numpy()
        noisy = noisy.cpu().detach().numpy()
        clean = clean.cpu().detach().numpy()

        # FRC curves of the first sample/channel against the clean reference,
        # drawn once per signal level.
        if repeat == 0:
            frc_FMD, spatial_freq = frc(output_FMD[0,0,:], clean[0,0,:])
            frc_simulation, spatial_freq = frc(output_simulation[0,0,:], clean[0,0,:])
            frc_random, spatial_freq = frc(output_random[0,0,:], clean[0,0,:])
            plt.figure()
            plt.plot( spatial_freq , frc_FMD , '-' , linewidth=2 , color='red' , label='Pretrained with FMD' )
            plt.plot( spatial_freq , frc_simulation , '-' , linewidth=2 , color='blue' , label='Pretrained with simulation' )
            plt.plot( spatial_freq , frc_random , '-' , linewidth=2 , color='green' , label='Random initialization' )
            plt.legend(loc='lower left')
            plt.title('FRC curve')

        # Replace each output sample with the result of
        # match_intensity(clean_sample, output_sample).
        # NOTE(review): assumes the loaded arrays hold at least
        # config['sample_size_list'][0] samples — confirm against load_confocal.
        for sample in range(config['sample_size_list'][0]):
            output_random[sample,:] = match_intensity(clean[sample,:], output_random[sample,:])
            output_FMD[sample,:] = match_intensity(clean[sample,:], output_FMD[sample,:])
            output_simulation[sample,:] = match_intensity(clean[sample,:], output_simulation[sample,:])

        #quantify(FMD[key_s], config['metrics_key'], clean[0,0, :], output_FMD[0,0,:])
        #quantify(simulation[key_s], config['metrics_key'], clean[0,0, :], output_simulation[0,0,:])
        #quantify(random[key_s], config['metrics_key'], clean[0,0, :], output_random[0,0,:])