In [1]:
# !pip3 install --upgrade torch --user
# !pip3 install torchvision --user

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import os
import pickle
import tqdm

import torch
import torch.nn.functional as F
import torch.nn as nn


from torchvision import datasets
from torchvision import transforms

from utils.model_utils import PatchLoader
from utils.martemev_utils import compute_psnr, get_freer_gpu, normalize

from time import time

from collections import defaultdict


In [2]:
# CUDA device index to run on. Set to None to pick the least-loaded GPU via
# get_freer_gpu() instead of a fixed index.
GPU_NUM = 2

K_CLOSEST = 8          # first positional argument of every model constructor (presumably k nearest neighbours — confirm)
HIDDEN_SIZE = 32
PATCH_SIZE = (33, 33)  # patch size passed to every model constructor; assumed (height, width)

# Only query GPU load when no explicit device was requested; previously the
# result of get_freer_gpu() was computed and then ignored.
gpu_num = GPU_NUM if GPU_NUM is not None else get_freer_gpu()
device = torch.device('cuda:{}'.format(gpu_num))
device

device(type='cuda', index=2)

In [3]:
import utils.model_classes as classes

In [4]:
# Checkpoint names under ./SavedModels/, in the same order as the factories below.
names = ['CNN_full.33.valid', 'GraphCNN_full.33.valid',
         'GraphCNN_Baseline.33.valid', "GraphCNN_FastBaseline.33.valid"]
model_factories = [classes.get_CNN, classes.get_GCNN,
                   classes.get_GCNN_baseline, classes.get_GCNN_fast_baseline]
# Hidden sizes must match the architectures the checkpoints were trained with.
hidden_sizes = [32, 32, 32, 16]

# zip() would silently drop models if these lists ever drift out of sync.
assert len(names) == len(model_factories) == len(hidden_sizes)

# Each get_* call returns a constructor (presumably a model class — see
# utils.model_classes); the second positional argument 1 is passed through as-is.
models = [factory()(K_CLOSEST, 1, hsize, patch_size=PATCH_SIZE)
          for factory, hsize in zip(model_factories, hidden_sizes)]

for name, model in zip(names, models):
    # map_location='cpu' lets checkpoints saved from any GPU load without
    # allocating CUDA memory; the models are kept on CPU until used.
    state_dict = torch.load('./SavedModels/{}.state_dict'.format(name), map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()
    model.cpu()

In [None]:
real_data_path = '../Data/dune_experimental_data/training_histograms_sim/'

# Every sub-directory of real_data_path is treated as one event; entries whose
# name contains 'root' are skipped. os.listdir order is arbitrary, so sort to
# keep event (and later APA index) ordering reproducible between runs.
events = sorted(
    entry for entry in os.listdir(real_data_path)
    if 'root' not in entry and os.path.isdir(os.path.join(real_data_path, entry))
)
apas = defaultdict(list)

for event in tqdm.tqdm_notebook(events):
    event_dir = os.path.join(real_data_path, event)
    # One comma-separated array per .dat file, appended in sorted file order.
    apa_files = sorted(f for f in os.listdir(event_dir) if f.endswith('.dat'))
    for apa_file in apa_files:
        apas[event].append(np.loadtxt(os.path.join(event_dir, apa_file), delimiter=','))

In [None]:
# Turn every raw APA array into a float tensor (torch.Tensor constructor) and
# normalize it, keeping one list per event in the original file order.
real_data = defaultdict(list)

for event_name in tqdm.tqdm_notebook(events):
    normalized_apas = [normalize(torch.Tensor(raw_apa)) for raw_apa in apas[event_name]]
    if normalized_apas:
        real_data[event_name].extend(normalized_apas)

In [None]:
# Quick visual check of the first normalized APA of the first event.
first_event = events[0]
fig, ax = plt.subplots(figsize=(35, 25))
ax.imshow(real_data[first_event][0])
ax.set_title('normalized event {}, apa 0'.format(first_event))
plt.show()

In [None]:
# Largest width <= 6000 that is a whole number of patches, i.e. what the crop
# before denoising keeps. 6000 is presumably the APA width in ticks — confirm
# against the data; assumes PATCH_SIZE is (height, width).
6000 - 6000 % PATCH_SIZE[1]

In [None]:
denoised_apas = defaultdict(list)

# Only the last model in the list is applied to the real-data APAs.
model = models[-1]
name = names[-1]

patch_h, patch_w = PATCH_SIZE  # assumes (height, width) order
real_times = []  # wall-clock seconds per forward_image call

model.cuda(device)
model.eval()
try:
    # Inference only: disabling autograd avoids building a graph and saves GPU memory.
    with torch.no_grad():
        for event in tqdm.tqdm_notebook(events, desc='events'):
            for image in tqdm.tqdm_notebook(real_data[event], desc=name, leave=True):
                h, w = image.shape
                # Crop to a whole number of patches (presumably what forward_image expects).
                image = image[:h - h % patch_h, :w - w % patch_w]
                start_time = time()
                denoised_image = model.forward_image(image, device)
                # CUDA kernels run asynchronously; wait for them before stopping the clock.
                torch.cuda.synchronize(device)
                real_times.append(time() - start_time)
                denoised_apas[event].append(denoised_image.detach().cpu())
finally:
    # Free GPU memory even if denoising fails part-way through.
    model.cpu()
print("Mean process time of {} = {}".format(name, np.mean(real_times)))

In [None]:
patch_h, patch_w = PATCH_SIZE  # assumes (height, width) order

# Stacked comparison of each noisy APA (top) and its denoised output (bottom).
for event in tqdm.tqdm_notebook(events, desc='events'):
    for ind, (image, d_image) in enumerate(zip(real_data[event], denoised_apas[event])):
        h, w = image.shape
        # Crop the input to a whole number of patches, as done before denoising,
        # so both panels cover the same region.
        image = image[:h - h % patch_h, :w - w % patch_w]
        fig, (ax_noised, ax_denoised) = plt.subplots(2, 1, figsize=(35, 15))
        ax_noised.imshow(image)
        ax_noised.set_title('noised event {}, apa {}'.format(event, ind))
        ax_denoised.imshow(d_image)
        ax_denoised.set_title('denoised event {}, apa {}'.format(event, ind))
        plt.show()
        # Many large figures are created in this loop; release each one explicitly.
        plt.close(fig)

In [None]:
# Paired clean / noisy validation images (presumably already normalized, per the
# file name). map_location='cpu' keeps them on the CPU regardless of the device
# they were saved from, matching the CPU copies of the model outputs they are
# compared against.
val_images_clear = torch.load('../Data/val/clear/normalized_data.tensor', map_location='cpu')
val_images_noised = torch.load('../Data/val/noised/normalized_data.tensor', map_location='cpu')

In [None]:
psnrs = defaultdict(list)
mses = defaultdict(list)
times = defaultdict(list)

# zip() below would silently drop unmatched images; the two sets must be paired.
assert len(val_images_clear) == len(val_images_noised), \
    'clean and noisy validation sets differ in length'

for name, model in zip(names, models):
    model.cuda(device)
    model.eval()
    try:
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            for clear_image, noised_image in zip(tqdm.tqdm_notebook(val_images_clear, desc=name),
                                                 val_images_noised):
                start_time = time()
                denoised_image = model.forward_image(noised_image, device)
                # CUDA kernels run asynchronously; wait for them before stopping the clock.
                torch.cuda.synchronize(device)
                end_time = time() - start_time
                denoised_image = denoised_image.detach().cpu()

                # Store plain floats so np.mean works on a list of numbers, not tensors.
                mses[name].append(F.mse_loss(clear_image, denoised_image).item())
                psnrs[name].append(compute_psnr(clear_image, denoised_image))
                times[name].append(end_time)
    finally:
        # Free GPU memory before moving to the next model, even on failure.
        model.cpu()
    print("Mean MSE of {} = {}".format(name, np.mean(mses[name])))
    print("Mean PSNR of {} = {}".format(name, np.mean(psnrs[name])))
    print("Mean process time of {} = {}".format(name, np.mean(times[name])))
        