In [9]:
import argparse
from pprint import pp
import torch
from torch import nn
from tqdm import tqdm
import numpy as np
import json
import os
from omegaconf import OmegaConf
from torch.utils.tensorboard import SummaryWriter

from utils import load_dataset, load_config
from utils import get_model
from utils import get_loss_hparams_and_lr, get_loss
from utils import trainer
from utils import utils
import matplotlib.pyplot as plt
import pandas as pd

In [10]:
# Select the compute device: CUDA when torch reports it as available, CPU otherwise.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# This notebook is a denoising test for the trained ne_lpn model.
# Path of the dataset config that the evaluation functions below hand to load_config.
dataset_config_path = "../exps/mnist/configs/mnist/test_dataset.json"

In [11]:
# Directory that receives every artifact written by this notebook (figures, CSV).
# exist_ok=True makes the cell idempotent on re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
result_folder = "average_mse_results"
os.makedirs(result_folder, exist_ok=True)

In [12]:
def average_mses(sigma, img_count: int, model_config_paths: dict, dataset_config_path: str, model_weight_paths: dict):
    """Average the per-image MSE of the LPN and NE models on noised test images.

    Loads both models and the test dataset, adds standard-normal noise scaled
    by ``sigma`` to each of the first ``img_count`` images, runs both models on
    the noised image and accumulates the MSE against the clean image. For the
    first image a 2x2 comparison figure (clean / noised / NE / LPN) is saved
    into the module-level ``result_folder`` and then shown.

    Relies on the module-level names ``device`` and ``result_folder``. No
    random seed is set here, so the noise (and therefore the result) varies
    between runs unless the caller seeds torch beforehand.

    Args:
        sigma: Multiplier applied to the ``torch.randn_like`` noise.
        img_count: Maximum number of test images (batches of size 1) to use.
        model_config_paths: Dict with keys ``"lpn"`` and ``"ne"`` holding the
            model config paths passed to ``load_config``.
        dataset_config_path: Dataset config path passed to ``load_config``.
        model_weight_paths: Dict with keys ``"lpn"`` and ``"ne"`` holding the
            checkpoint paths; each checkpoint must contain ``"model_state_dict"``.

    Returns:
        Dict with keys ``"lpn_res"``, ``"ne_res"``, ``"noised_res"`` and
        ``"ground_truth_res"``, each the mean MSE over the images that were
        actually processed (``ground_truth_res`` compares the clean image with
        itself and serves as a zero baseline).

    Raises:
        ValueError: If no image was processed (``img_count <= 0`` or an empty
            test dataset), since no average can be formed.
    """
    lpn_config = load_config(model_config_paths["lpn"])
    ne_config = load_config(model_config_paths["ne"])

    # map_location lets a checkpoint saved on another device load here, and
    # .to(device) puts the weights on the same device as the inputs below
    # (the inputs are moved with .to(device), so a CPU-resident model would
    # fail on a CUDA machine).
    model_lpn = get_model(lpn_config)
    model_lpn.load_state_dict(torch.load(model_weight_paths["lpn"], map_location=device)["model_state_dict"])
    model_lpn = model_lpn.to(device)
    model_lpn.eval()  # evaluation run: use inference behaviour of any train/eval-dependent layers

    model_ne = get_model(ne_config)
    model_ne.load_state_dict(torch.load(model_weight_paths["ne"], map_location=device)["model_state_dict"])
    model_ne = model_ne.to(device)
    model_ne.eval()

    dataset_config = load_config(dataset_config_path)
    test_dataset = load_dataset(dataset_config, "test")
    test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=4)

    mse_loss = nn.MSELoss()  # stateless criterion, built once instead of per image

    lpn_loss = 0.0
    ne_loss = 0.0
    noised_loss = 0.0
    ground_truth_loss = 0.0
    processed = 0  # images actually seen; may be smaller than img_count

    # NOTE(review): the forward passes are intentionally not wrapped in
    # torch.no_grad() because it is not visible here whether the models rely
    # on autograd inside forward() - confirm before adding it.
    for step, batch in enumerate(test_data_loader):
        if step >= img_count:
            break

        clean_images = batch["image"].to(device)
        noise = torch.randn_like(clean_images)
        noised_images = clean_images + noise * sigma

        ne_denoised = model_ne(noised_images)
        lpn_denoised = model_lpn(noised_images)

        noised_loss += mse_loss(noised_images, clean_images).item()
        ne_loss += mse_loss(ne_denoised, clean_images).item()
        lpn_loss += mse_loss(lpn_denoised, clean_images).item()
        ground_truth_loss += mse_loss(clean_images, clean_images).item()
        processed += 1

        if step == 0:
            # Panel layout: clean | noised  //  NE | LPN
            panels = [
                ((0, 0), clean_images, "Clean Image"),
                ((0, 1), noised_images, "Noised Image"),
                ((1, 0), ne_denoised, "NE Denoised Image"),
                ((1, 1), lpn_denoised, "LPN Denoised Image"),
            ]
            fig, axs = plt.subplots(2, 2)
            for (row, col), images, title in panels:
                # first sample, first channel of the batch
                axs[row, col].imshow(images[0][0].cpu().detach().numpy(), cmap='gray')
                axs[row, col].set_title(title)
                axs[row, col].axis('off')

            fig.suptitle(f"Denoising Visualization for \u03C3={sigma}")

            # Save BEFORE showing: once show() has run the figure may be
            # released, and a later plt.savefig() would write an empty image.
            os.makedirs(result_folder, exist_ok=True)
            fig.savefig(os.path.join(result_folder, f"denoising_visualization_{sigma}_aff.png"), bbox_inches='tight')
            plt.show()
            plt.close(fig)  # one figure per call; do not let them pile up across sigma levels

    if processed == 0:
        raise ValueError("average_mses processed no images: img_count must be positive and the test dataset non-empty")

    # Divide by the number of images actually processed, not by img_count,
    # so a dataset shorter than img_count does not deflate the averages.
    return {
        "lpn_res": lpn_loss / processed,
        "ne_res": ne_loss / processed,
        "noised_res": noised_loss / processed,
        "ground_truth_res": ground_truth_loss / processed
    }


In [13]:
def average_mse(sigma, img_count: int, model_path: str, dataset_config_path: str, model_weight_path: str):
    """Average the per-image denoising MSE of a single model on noised test images.

    Loads the model and the test dataset, adds standard-normal noise scaled by
    ``sigma`` to each of the first ``img_count`` images, runs the model on the
    noised image and averages the MSE against the clean image.

    Relies on the module-level ``device``. No random seed is set here, so the
    result varies between runs unless the caller seeds torch beforehand.

    Args:
        sigma: Multiplier applied to the ``torch.randn_like`` noise.
        img_count: Maximum number of test images (batches of size 1) to use.
        model_path: Path of the model *config* passed to ``load_config``.
        dataset_config_path: Dataset config path passed to ``load_config``.
        model_weight_path: Checkpoint path; the checkpoint must contain
            ``"model_state_dict"``.

    Returns:
        Mean MSE (float) over the images that were actually processed.

    Raises:
        ValueError: If no image was processed (``img_count <= 0`` or an empty
            test dataset), since no average can be formed.
    """
    # load model and dataset
    model_config = load_config(model_path)
    model = get_model(model_config)
    # map_location lets a checkpoint saved on another device load here.
    model.load_state_dict(torch.load(model_weight_path, map_location=device)["model_state_dict"])
    # Inputs are moved to `device` below, so the weights must live there too.
    model = model.to(device)
    model.eval()  # evaluation run: use inference behaviour of any train/eval-dependent layers

    dataset_config = load_config(dataset_config_path)
    test_dataset = load_dataset(dataset_config, "test")

    # get the data loader
    test_data_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=1, shuffle=False, num_workers=4
    )

    mse_loss = nn.MSELoss()  # stateless criterion, built once instead of per image
    total_loss = 0.0
    processed = 0  # images actually seen; may be smaller than img_count

    for step, batch in enumerate(test_data_loader):
        if step >= img_count:
            break

        clean_images = batch["image"].to(device)
        noise = torch.randn_like(clean_images)
        noised_images = clean_images + noise * sigma
        denoised_images = model(noised_images)

        total_loss += mse_loss(denoised_images, clean_images).item()
        processed += 1

    if processed == 0:
        raise ValueError("average_mse processed no images: img_count must be positive and the test dataset non-empty")

    # Divide by the number of images actually processed, not by img_count,
    # so a dataset shorter than img_count does not deflate the average.
    return total_loss / processed

In [14]:
def run_test(sigma_levels: list[float], img_count: int, model_config_paths, model_weight_paths, dataset_config_path: str, result_path: str):
    """Sweep noise levels, tabulate the average MSEs and write the artifacts.

    Calls ``average_mses`` once per entry of ``sigma_levels`` and writes three
    files into ``result_path``:

    * ``result_table.csv`` - one row per sigma with the four MSE columns,
    * ``result_table_ne_aff.png`` - the same table (rounded to 6 decimals)
      rendered as a matplotlib table,
    * ``sigma_vs_results_ne_aff.png`` - line plot of each MSE column vs sigma.

    Args:
        sigma_levels: Noise multipliers to evaluate, in plotting order.
        img_count: Number of test images per sigma, forwarded to ``average_mses``.
        model_config_paths: Dict with ``"lpn"`` / ``"ne"`` config paths,
            forwarded to ``average_mses``.
        model_weight_paths: Dict with ``"lpn"`` / ``"ne"`` checkpoint paths,
            forwarded to ``average_mses``.
        dataset_config_path: Dataset config path, forwarded to ``average_mses``.
        result_path: Output directory; created if it does not exist.

    Returns:
        None. The results are written to disk and the figures are left open
        for the notebook to display.
    """
    columns = ["\u03C3", "noisy_mse", "lpn_mse", "ne_mse", "ground_truth"]

    # Make the function usable with any output directory, not only one that
    # an earlier cell happened to create.
    os.makedirs(result_path, exist_ok=True)

    # Collect one record per sigma and build the frame once, instead of
    # growing a DataFrame row by row and mirroring it in a parallel dict.
    records = []
    for sigma in sigma_levels:
        results = average_mses(sigma, img_count, model_config_paths, dataset_config_path, model_weight_paths)
        records.append({
            "\u03C3": sigma,
            "noisy_mse": results["noised_res"],
            "lpn_mse": results["lpn_res"],
            "ne_mse": results["ne_res"],
            "ground_truth": results["ground_truth_res"],
        })

    result_table = pd.DataFrame(records, columns=columns)
    result_table.to_csv(os.path.join(result_path, "result_table.csv"))

    # --- Table figure -------------------------------------------------
    figsize = (len(sigma_levels) * 2, 6)
    table_fig, table_ax = plt.subplots(figsize=figsize)
    table_ax.axis('tight')
    table_ax.axis('off')

    result_table_rounded = result_table.round(6)
    table = table_ax.table(cellText=result_table_rounded.values, colLabels=result_table_rounded.columns, cellLoc='center', loc='center')
    table.auto_set_font_size(False)
    table.set_fontsize(14)

    table.scale(1.5, 4)

    # Kept from the original; no title text is set on this axes in this function.
    table_ax.title.set_fontsize(20)

    table_fig.tight_layout()
    table_fig.subplots_adjust(top=0.85)
    table_fig.savefig(os.path.join(result_path, "result_table_ne_aff.png"), bbox_inches='tight')

    # --- Sigma vs. MSE curves ----------------------------------------
    curve_fig, curve_ax = plt.subplots(figsize=(8, 6))
    curves = [
        ("noisy_mse", "Noisy MSE"),
        ("lpn_mse", "LPN MSE"),
        ("ne_mse", "NE MSE"),
        ("ground_truth", "Ground Truth"),
    ]
    sigmas = result_table["\u03C3"].tolist()
    for column, label in curves:
        curve_ax.plot(sigmas, result_table[column].tolist(), label=label)
    curve_ax.set_xlabel("\u03C3")
    curve_ax.set_ylabel("MSE")
    curve_ax.set_title("Noise Level vs Average MSE")
    curve_ax.legend()
    curve_fig.savefig(os.path.join(result_path, "sigma_vs_results_ne_aff.png"), bbox_inches='tight')


In [15]:
# LPN model: config file and trained checkpoint.
lpn_model_config_path = "../exps/mnist/configs/mnist/model.json"
lpn_model_weight_path = "../exps/mnist/experiments/mnist/model.pt"

# NE model: config file and trained checkpoint.
ne_model_config_path = "../exps/mnist/configs/mnist/model_ne_mnist_affine_z_scored.json"
ne_model_weight_path = "../exps/mnist/experiments/ne_mnist_affine_z_scored/model.pt"

# Dataset config used for the evaluation.
dataset_config_path = "../exps/mnist/configs/mnist/test_dataset.json"

# Lookup tables keyed by model name ("lpn" / "ne"), in the shape run_test expects.
config_paths = dict(lpn=lpn_model_config_path, ne=ne_model_config_path)
weight_paths = dict(lpn=lpn_model_weight_path, ne=ne_model_weight_path)

In [16]:
# Noise multipliers to sweep over.
sigma_levels = [
    0.05, 0.1, 0.15, 0.2, 0.25,
    0.3, 0.4, 0.5, 0.6, 0.7,
]

# Evaluate 1000 test images per noise level and write the table and plots
# into result_folder.
run_test(
    sigma_levels=sigma_levels,
    img_count=1000,
    model_config_paths=config_paths,
    model_weight_paths=weight_paths,
    dataset_config_path=dataset_config_path,
    result_path=result_folder,
)

FileNotFoundError: [Errno 2] No such file or directory: '../exps/mnist/configs/mnist/model_ne_mnist_affine_z.json'