## Testing Full GAP - Generative

In [None]:
import os
import torch
from skimage import io, measure
import numpy as np
import matplotlib.pyplot as plt

# Select the compute device. The message below promises a CPU run, so fall
# back to the CPU with a warning instead of aborting when no GPU is present.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    import warnings

    warnings.warn("GPU not found, code will run on CPU and can be extremely slow!")
    device = torch.device("cpu")

import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping
import torch.optim as optim
import torch.utils.data as dt

import time
import os
from CGAP_UNET_Super_Res import UN
from BinomDataset_Super_Res import BinomDataset
from inference import sample_image


In [None]:
def preprocess(inp):
    """Scale a channel-first image to a peak of 1 and make it channel-last.

    The result is meant to be handed to ``plt.imshow``.

    Args:
        inp: ``np.ndarray`` or ``torch.Tensor`` indexed as
            (channels, height, width). It is never modified.

    Returns:
        A new array/tensor of the same kind as ``inp`` with its axes ordered
        (height, width, channels) and divided by its maximum value.
        Non-float inputs are converted to float32 first; tensors are detached
        and moved to the CPU. An input whose maximum is 0 is returned
        unscaled instead of being turned into NaNs.

    Raises:
        ValueError: If ``inp`` is neither a NumPy array nor a torch tensor.
    """
    if isinstance(inp, np.ndarray):
        # In-place true division is rejected for integer dtypes, so always
        # work on a floating-point copy (astype copies by default).
        keeps_dtype = np.issubdtype(inp.dtype, np.inexact)
        img = inp.astype(inp.dtype if keeps_dtype else np.float32)
        peak = img.max()
        if peak != 0:
            img /= peak
        return img.transpose(1, 2, 0)

    if torch.is_tensor(inp):
        # matplotlib cannot consume CUDA or grad-tracking tensors.
        img = inp.detach().cpu()
        # clone()/float() both yield fresh storage, so the caller's tensor
        # is left untouched by the in-place division below.
        img = img.clone() if img.is_floating_point() else img.float()
        peak = img.max()
        if peak != 0:
            img /= peak
        return img.permute(1, 2, 0)

    raise ValueError("Invalid input type")

In [None]:
name = 'm40to30-256x256-ffhq-inpainting-full'
CHECKPOINT_PATH = ''

# Restore the trained network directly from the checkpoint file
# <CHECKPOINT_PATH>/<name>.ckpt. Building a fresh UN beforehand is pointless
# because the instance would be overwritten by the loaded model right away.
# NOTE(review): this relies on the checkpoint carrying UN's hyperparameters
# (the discarded instance used channels=3, levels=10, depth=7, start_filts=32,
# up_mode='upsample', merge_mode='concat'); if it does not, pass them as
# keyword arguments to load_from_checkpoint.
checkpoint_file = os.path.join(CHECKPOINT_PATH, name + '.ckpt')
# map_location places the stored weights on `device` while loading.
# NOTE(review): presumably UN is a LightningModule, whose load_from_checkpoint
# accepts map_location - confirm.
model = UN.load_from_checkpoint(checkpoint_file, map_location=device).to(device)

In [None]:
# Conditioning input (grayscale image). The cells below concatenate it along
# dim 1 with a (batch, 3, 256, 256) tensor, so it has to be a torch tensor
# shaped (batch, C, 256, 256).
cond_input = None  # TODO: assign the grayscale conditioning image tensor here

In [None]:
channels = 3
batch_size = 1
pixels_x = 256
pixels_y = 256

# Generation starts from an all-zero image.
inp_img = torch.zeros(batch_size, channels, pixels_y, pixels_x)

cond_img = cond_input

# Conditioning channels first, then the all-zero image, stacked along dim 1.
input_img = torch.cat((cond_img, inp_img), 1).to(device)

for i in range(1):
    startTime = time.time()
    denoised, photons, stack, iterations = sample_image(input_img,
                                                        model,
                                                        beta = 0.1,
                                                        save_every_n = 10,
                                                        max_psnr = 30,
                                                        max_its = 20000000,
                                                        channels = channels)
    # Stop the clock before plotting so figure rendering is not counted as
    # sampling time.
    endTime = time.time()
    elapsedTime = endTime - startTime

    # Normalise the whole batch to unit mean once, not once per image.
    denoised /= denoised.mean()

    for j in range(denoised.shape[0]):
        print(denoised.shape)

        # Show the j-th sample of the batch (not always the first one).
        plt.figure(figsize = (5,5))
        plt.imshow(preprocess(denoised[j]), vmin = 0,
                   vmax = np.percentile(denoised[j],99.99))
        plt.title('Generative')
        plt.axis('off')
        plt.show()

        # Image and colour limit now refer to the same batch element.
        plt.figure(figsize = (5,5))
        plt.imshow(preprocess(photons[j]), vmin = 0,
                   vmax = np.percentile(photons[j],99.99))
        plt.title('Photon counts')
        plt.show()

    print('_______________________________________', iterations)

    print ('time (s):', elapsedTime, 'time per image (s)', elapsedTime/denoised.shape[0])
   

## Testing Full GAP - Diversity Denoising

In [None]:
img_path = ''
dataset = BinomDataset(img_path, 256, -30, -20, 1)
img = dataset[0]

# cond_img must already exist (it is assigned from cond_input in the
# generative cell).
# NOTE(review): torch.cat needs img to have the same number of dimensions as
# cond_img; if dataset[0] has no batch axis, use img.unsqueeze(0) - confirm
# against BinomDataset.
input_img = torch.cat((cond_img, img), 1).to(device)

for i in range(1):
    startTime = time.time()
    denoised, photons, stack, iterations = sample_image(input_img,
                                                        model,
                                                        beta = 0.1,
                                                        save_every_n = 10,
                                                        max_psnr = 30,
                                                        max_its = 20000000,
                                                        channels = 3)
    # Stop the clock before plotting so figure rendering is not counted as
    # sampling time.
    endTime = time.time()
    elapsedTime = endTime - startTime

    # Normalise the whole batch to unit mean once, not once per image.
    denoised /= denoised.mean()

    for j in range(denoised.shape[0]):
        print(denoised.shape)

        # Show the j-th sample of the batch (not always the first one).
        # NOTE(review): the title is the same as in the generative cell -
        # confirm whether it should name the denoising result instead.
        plt.figure(figsize = (5,5))
        plt.imshow(preprocess(denoised[j]), vmin = 0,
                   vmax = np.percentile(denoised[j],99.99))
        plt.title('Generative')
        plt.axis('off')
        plt.show()

        # Image and colour limit now refer to the same batch element.
        plt.figure(figsize = (5,5))
        plt.imshow(preprocess(photons[j]), vmin = 0,
                   vmax = np.percentile(photons[j],99.99))
        plt.title('Photon counts')
        plt.show()

    print('_______________________________________', iterations)

    print ('time (s):', elapsedTime, 'time per image (s)', elapsedTime/denoised.shape[0])

## Testing Cascaded GAP

Load the five cascade models as `model1`, `model2`, `model3`, `model4` and `model5` before running the cell below.

In [None]:
# NOTE: from here on this import replaces the single-model sample_image that
# was imported from `inference`; re-import that one before re-running the
# cells above.
from Inference_cascade import sample_image
channels = 3
batch_size = 1
pixels_x = 256
pixels_y = 256

# Generation starts from an all-zero image.
inp_img = torch.zeros(batch_size, channels, pixels_y, pixels_x)

cond_img = cond_input

# Conditioning channels first, then the all-zero image, stacked along dim 1.
input_img = torch.cat((cond_img, inp_img), 1).to(device)

for i in range(1):
    startTime = time.time()
    # model1 ... model5 have to be loaded beforehand; they are passed as a list.
    denoised, photons, stack, iterations = sample_image(input_img,
                                                        [model1, model2, model3, model4, model5],
                                                        beta = 0.1,
                                                        save_every_n = 10,
                                                        max_psnr = 30,
                                                        max_its = 20000000,
                                                        channels = channels)
    # Stop the clock before plotting so figure rendering is not counted as
    # sampling time.
    endTime = time.time()
    elapsedTime = endTime - startTime

    # Normalise the whole batch to unit mean once, not once per image.
    denoised /= denoised.mean()

    for j in range(denoised.shape[0]):
        print(denoised.shape)

        # Show the j-th sample of the batch (not always the first one).
        plt.figure(figsize = (5,5))
        plt.imshow(preprocess(denoised[j]), vmin = 0,
                   vmax = np.percentile(denoised[j],99.99))
        plt.title('Generative')
        plt.axis('off')
        plt.show()

        # Image and colour limit now refer to the same batch element.
        plt.figure(figsize = (5,5))
        plt.imshow(preprocess(photons[j]), vmin = 0,
                   vmax = np.percentile(photons[j],99.99))
        plt.title('Photon counts')
        plt.show()

    print('_______________________________________', iterations)

    print ('time (s):', elapsedTime, 'time per image (s)', elapsedTime/denoised.shape[0])
   