# Deep Image Prior with Distributional Consistency (DC) loss
## Plotting results

## Setup

In [None]:
# --- Make the project's `src/` directory (sibling of the working directory) importable ---
import os
import sys

self_dir = os.getcwd()
parent_dir = os.path.join(self_dir, '..')
src_path = os.path.abspath(os.path.join(parent_dir, 'src'))
# Prepend so modules in `src/` take precedence over anything else on the path.
sys.path[:0] = [src_path]

In [None]:
from __future__ import print_function
%matplotlib inline

import torch
import numpy as np
from skimage.metrics import peak_signal_noise_ratio as compare_psnr
from skimage.metrics import structural_similarity as ssim
from matplotlib import pyplot as plt
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

from utils.losses import dc_loss_clipped_gaussian
from utils.plotting import plot_CDF_histogram, hide
from utils.common_utils import crop_image, fix_seed, get_image, pil_to_np, get_noisy_image, np_to_torch

import pickle

torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Tensor type used for every `.type(dtype)` conversion in this notebook.
# Fall back to CPU tensors when no GPU is available so the plotting / validation cells still run.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

imsize = -1  # passed to get_image(); presumably -1 keeps the native resolution — confirm in utils.common_utils
SAVEFIGS = False  # set to True to write figures into save_directory

In [None]:
# Losses
mse = torch.nn.MSELoss().type(dtype)


def dist_loss(q, m, sigma=0.25, return_values=False):
    """Distributional-consistency (DC) loss between two tensors of arbitrary shape.

    Both tensors are flattened to 1-D and forwarded, together with ``sigma`` and
    ``return_values``, to ``dc_loss_clipped_gaussian``. Presumably ``q`` is the
    estimate and ``m`` the noisy measurement — confirm against ``utils.losses``.

    With ``return_values=True`` the callers in this notebook unpack the result as
    a 3-tuple; otherwise a single loss tensor is used.
    """
    q_flat = q.flatten()
    m_flat = m.flatten()
    return dc_loss_clipped_gaussian(q_flat, m_flat, sigma=sigma, return_values=return_values)

## Plots and figures

In [None]:
save_directory = "../figures/DIP/"
# plt.savefig does not create missing folders, so create the target when figures will be written.
if SAVEFIGS:
    os.makedirs(save_directory, exist_ok=True)

In [None]:
image_name = "F16_GT"  # alternatives: horse, cat, train, castle
fname = f'../data/images/{image_name}.png'
img_pil = crop_image(get_image(fname, imsize)[0], d=32)
img_np = pil_to_np(img_pil)

use_scheduler = False  # either False or True
seed = 0               # one of 0,1,2,3,4
epochs = 10000         # must match the length of the saved run (alternative: 10001)
sigma = 75             # one of 10,25,50,75,100 (horse: 50, cat: 75, train: 25, castle: 25)

# Saved results of the DIP.py run: a dict keyed by "mse" and "dist" (one entry per loss).
# NOTE: pickle can execute arbitrary code while loading — only open trusted result files.
with open(fr'../results/DIP/image={image_name}_sigma={sigma}_scheduled={use_scheduler}_seed={seed}.pkl', 'rb') as results_file:
    states = pickle.load(results_file)
img_noisy_np = states["mse"]["image_noisy"]

In [None]:
iters = [0, 100, 1000, epochs - 1]

fig, ax = plt.subplots(2, len(iters), figsize=(20, 5))

# One row per training objective: (saved state, colour, legend label).
rows = [
    (states["mse"], 'red', 'DIP-MSE Empirical Density'),
    (states["dist"], 'blue', 'DIP-DC Empirical Density'),
]

for j, (state, color, label) in enumerate(rows):
    for i, iteration in enumerate(iters):
        # `histos` appears to hold one entry per 100 iterations (hence `// 100`) — confirm against DIP.py.
        ax[j, i] = plot_CDF_histogram(ax[j, i], state["histos"][iteration // 100][0], comparison="uniform",
                                      show_comparison=True, bins=100, color=color, alpha=0.7, label=label)
        if j == 0:
            ax[j, i].set_xlabel("")
            # Column titles are shown on the top row only.
            ax[j, i].set_title(f"{iteration} iteration(s)", fontsize=20)
        else:
            ax[j, i].set_xlabel("CDF value", fontsize=16)
        if i == 0:
            ax[j, i].set_ylabel("Probability density", fontsize=12)
        else:
            ax[j, i].set_ylabel("")
        ax[j, i].legend(fontsize=12)

# Row labels to the left of the subplots
fig.text(-0.03, 0.74, "DIP-MSE", fontsize=24, va='center', rotation='vertical')
fig.text(-0.03, 0.3, "DIP-DC", fontsize=24, va='center', rotation='vertical')

plt.tight_layout()
if SAVEFIGS:
    plt.savefig(os.path.join(save_directory, 'DIP_images_histograms.svg'), bbox_inches='tight', dpi=30)
plt.show()

In [None]:
iters = [0, 100, 1000, epochs - 1]
# ==== CONFIGURATION ====
# Top-left corner (row, col) of the zoom patch for each supported image.
ZOOM_COORDS = {
    "cat": (253, 386),
    "horse": (100, 200),
    "F16_GT": (220, 130),
    "train": (110, 320),
    "castle": (90, 100),
}
try:
    zoom_coords = ZOOM_COORDS[image_name]
except KeyError:
    # Fail here with a clear message instead of a NameError further down.
    raise ValueError(
        f"No zoom coordinates configured for image {image_name!r}; add it to ZOOM_COORDS."
    ) from None
patch_size = 25  # Side length (pixels) of the zoomed patch

# Helper to add zoom insets
def add_circular_zoom(ax_img, img, zoom_coords, patch_size):
    """Outline a square region of ``img`` on ``ax_img`` and show it magnified in an inset.

    Despite the name, both the marker and the inset are rectangular.

    Args:
        ax_img: axes on which ``img`` has already been drawn.
        img: image array indexed as ``img[row, col, ...]``.
        zoom_coords: ``(row, col)`` of the patch's top-left corner.
        patch_size: side length of the square patch in pixels.
    """
    row, col = zoom_coords

    # Dashed outline of the magnified region; Rectangle takes (x, y) = (col, row).
    outline = plt.Rectangle((col, row), patch_size, patch_size, linewidth=0.5,
                            edgecolor='black', linestyle='--', facecolor='none')
    ax_img.add_patch(outline)

    # Inset in the upper-right corner, anchored to the top for portrait images, to the right otherwise.
    is_portrait = img.shape[0] > img.shape[1]
    inset = inset_axes(ax_img, width="35%", height="35%", loc='upper right', borderpad=0)
    inset.set_anchor('N' if is_portrait else 'E')

    zoomed = img[row:row + patch_size, col:col + patch_size]
    inset.imshow(zoomed, interpolation='nearest')
    inset.set_xticks([])
    inset.set_yticks([])

# Setup figure and layout
fig = plt.figure(figsize=(15, 6), dpi=600)  # figure height: use 9 for castle, 6 otherwise
# Column 0: true/noisy images; column 1 (width 0.2) holds nothing and acts as a spacer;
# columns 2+: one per entry of `iters`.
gs = GridSpec(2, len(iters) + 2, width_ratios=[1, 0.2] + [1] * len(iters), wspace=0.03, hspace=0.15)

# Every displayed image drops its last row and first column. The zoom helper must get the
# *same* cropped array; otherwise its box and inset are shifted one column.
# True Image
img_true = img_np.transpose(1, 2, 0)[:-1, 1:]
ax_true = fig.add_subplot(gs[0, 0])
ax_true.imshow(img_true)
ax_true.set_title("True Image", fontsize=16)
ax_true.axis("off")
add_circular_zoom(ax_true, img_true, zoom_coords, patch_size)

# Noisy Image
img_noisy_shown = img_noisy_np[:-1, 1:]
ax_noisy = fig.add_subplot(gs[1, 0])
ax_noisy.imshow(img_noisy_shown)
ax_noisy.set_title("Noisy Image", fontsize=16)
ax_noisy.axis("off")
add_circular_zoom(ax_noisy, img_noisy_shown, zoom_coords, patch_size)

# Iteration plots: one row per training objective
row_labels = {"mse": "DIP-MSE", "dist": "DIP-DC"}
for j, loss_name in enumerate(["mse", "dist"]):
    state = states[loss_name]

    # Y-axis label for the row, drawn on a hidden axes
    ax_label = fig.add_subplot(gs[j, 2])
    hide(ax_label)
    ax_label.set_ylabel(row_labels[loss_name], fontsize=24, labelpad=15, rotation=90, va='center')

    for i, iteration in enumerate(iters):
        ax_img = fig.add_subplot(gs[j, i + 2])
        # `images` appears to hold one snapshot per 100 iterations — confirm against DIP.py.
        img = state["images"][iteration // 100][:-1, 1:]

        ax_img.imshow(img)
        if j == 0:
            ax_img.set_title(f"{iteration} Iterations", fontsize=16)
        ax_img.axis("off")

        add_circular_zoom(ax_img, img, zoom_coords, patch_size)

plt.tight_layout()

if SAVEFIGS:
    stem = os.path.join(save_directory, f'image={image_name}_sigma={sigma}_scheduled={use_scheduler}_seed={seed}_image_final')
    for suffix, out_dpi in (('.png', 600), ('.svg', 600), ('_lower.svg', 300), ('_lowest.svg', 150)):
        plt.savefig(stem + suffix, bbox_inches='tight', dpi=out_dpi)
plt.show()

In [None]:
# Plot the MSE loss curves
fig, ax = plt.subplots(1, 1, figsize=(5, 4))

cutoff = 50  # Cutoff for the x-axis: early iterations are hidden and excluded from the y-range

mse_run_curve = states["mse"]["losses"]["mse_loss"]
dist_run_curve = states["dist"]["losses"]["mse_loss"]

x_values = list(range(0, epochs, 1))
# A run may have logged fewer than `epochs` values; align each x-axis with its own curve.
x_values_mse = x_values[:len(mse_run_curve)]
x_values_dist = x_values[:len(dist_run_curve)]

ax.set_xlim(cutoff, x_values[-1])

ax.plot(x_values_mse, mse_run_curve, label="DIP-MSE", color='red', alpha=0.7)
ax.plot(x_values_dist, dist_run_curve, label="DIP-DC", color='blue', alpha=0.7)
ax.set_xlabel("Iterations", fontsize=16)
ax.set_ylabel("MSE loss", fontsize=16)

# Fit the y-range to the visible part of both curves.
visible = (mse_run_curve[cutoff:], dist_run_curve[cutoff:])
ax.set_ylim(min(c.min() for c in visible), max(c.max() for c in visible))

ax.legend(fontsize=16)
ax.grid()
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("MSE Loss Trajectory", fontsize=16)

if SAVEFIGS:
    plt.savefig(os.path.join(save_directory, 'DIP_MSE.svg'), bbox_inches='tight', dpi=600)
plt.show()

In [None]:
# Plot the DC loss curves
fig, ax = plt.subplots(1, 1, figsize=(5, 4))

cutoff = 50  # Cutoff for the x-axis: early iterations are hidden and excluded from the y-range

mse_run_curve = states["mse"]["losses"]["dist_loss"]
dist_run_curve = states["dist"]["losses"]["dist_loss"]

x_values = list(range(0, epochs, 1))
# Derive each x-axis from the curve actually plotted against it.
x_values_mse = x_values[:len(mse_run_curve)]
x_values_dist = x_values[:len(dist_run_curve)]

ax.set_xlim(cutoff, x_values[-1])

ax.plot(x_values_mse, mse_run_curve, label="DIP-MSE", color='red', alpha=0.7)
ax.plot(x_values_dist, dist_run_curve, label="DIP-DC", color='blue', alpha=0.7)
ax.set_xlabel("Iterations", fontsize=14)
ax.set_ylabel("DC loss", fontsize=14)

# Fit the y-range to the visible part of both curves.
visible = (mse_run_curve[cutoff:], dist_run_curve[cutoff:])
ax.set_ylim(min(c.min() for c in visible), max(c.max() for c in visible))

ax.legend(fontsize=16)
ax.grid()
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("DC Loss Trajectory", fontsize=16)

if SAVEFIGS:
    plt.savefig(os.path.join(save_directory, 'DIP_DC.svg'), bbox_inches='tight', dpi=600)
plt.show()

In [None]:
# Helper function to calculate moving average
def moving_average(data, window_size):
    """Return the simple (unweighted) moving average of a 1-D sequence.

    Only windows lying fully inside ``data`` are averaged (``mode='valid'``), so the
    result has ``len(data) - window_size + 1`` entries; entry ``k`` is the mean of
    ``data[k:k + window_size]``.

    Raises:
        ValueError: if ``window_size`` is not in ``[1, len(data)]``. Without this check
            ``np.convolve`` silently swaps its arguments when the window is longer than
            the data and returns a meaningless result.
    """
    data = np.asarray(data)
    if not 1 <= window_size <= len(data):
        raise ValueError(
            f"window_size must be between 1 and len(data)={len(data)}, got {window_size}"
        )
    return np.convolve(data, np.ones(window_size) / window_size, mode='valid')

fig, ax = plt.subplots(1, 1, figsize=(5, 4))

cutoff = 50  # Cutoff for the x-axis: early iterations are hidden and excluded from the y-range
window_size = 100  # Window size for moving average

psnr_mse = states["mse"]["losses"]["psnr"]
psnr_dist = states["dist"]["losses"]["psnr"]

x_values = list(range(0, epochs, 1))
# Align each x-axis with the curve plotted against it.
x_values_mse = x_values[:len(psnr_mse)]
x_values_dist = x_values[:len(psnr_dist)]

# Moving averages; the first averaged point corresponds to iteration index window_size - 1.
mse_psnr_ma = moving_average(psnr_mse, window_size)
dist_psnr_ma = moving_average(psnr_dist, window_size)

# Raw per-iteration values as thin, faint lines
ax.plot(x_values_mse, psnr_mse, color='red', alpha=0.5, linewidth=0.25)
ax.plot(x_values_dist, psnr_dist, color='blue', alpha=0.5, linewidth=0.25)

# Moving averages (these carry the legend entries)
ax.plot(x_values_mse[window_size - 1:], mse_psnr_ma, label="DIP-MSE", color='red', alpha=1)
ax.plot(x_values_dist[window_size - 1:], dist_psnr_ma, label="DIP-DC", color='blue', alpha=1)

ax.set_xlim(cutoff, x_values[-1])
ax.set_xlabel("Iterations", fontsize=14)
ax.set_ylabel(r"PSNR (dB) $\uparrow$", fontsize=14)

# Fit the y-range to the visible part of both raw curves.
visible = (psnr_mse[cutoff:], psnr_dist[cutoff:])
ax.set_ylim(min(c.min() for c in visible), max(c.max() for c in visible))

ax.legend(fontsize=16)
ax.grid()
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("PSNR Trajectory", fontsize=16)

if SAVEFIGS:
    plt.savefig(os.path.join(save_directory, 'DIP_PSNR.svg'), bbox_inches='tight', dpi=600)
plt.show()

In [None]:
# Helper function to calculate moving average
def moving_average(data, window_size):
    """Return the simple (unweighted) moving average of a 1-D sequence.

    Only windows lying fully inside ``data`` are averaged (``mode='valid'``), so the
    result has ``len(data) - window_size + 1`` entries; entry ``k`` is the mean of
    ``data[k:k + window_size]``.

    Raises:
        ValueError: if ``window_size`` is not in ``[1, len(data)]``. Without this check
            ``np.convolve`` silently swaps its arguments when the window is longer than
            the data and returns a meaningless result.
    """
    data = np.asarray(data)
    if not 1 <= window_size <= len(data):
        raise ValueError(
            f"window_size must be between 1 and len(data)={len(data)}, got {window_size}"
        )
    return np.convolve(data, np.ones(window_size) / window_size, mode='valid')

fig, ax = plt.subplots(1, 1, figsize=(5, 4))

cutoff = 50  # Cutoff for the x-axis: early iterations are hidden and excluded from the y-range
window_size = 100  # Window size for moving average

ssim_mse = states["mse"]["losses"]["ssim"]
ssim_dist = states["dist"]["losses"]["ssim"]

x_values = list(range(0, epochs, 1))
# Align each x-axis with the curve plotted against it.
x_values_mse = x_values[:len(ssim_mse)]
x_values_dist = x_values[:len(ssim_dist)]

# Moving averages; the first averaged point corresponds to iteration index window_size - 1.
mse_ssim_ma = moving_average(ssim_mse, window_size)
dist_ssim_ma = moving_average(ssim_dist, window_size)

# Raw per-iteration values as thin, faint lines
ax.plot(x_values_mse, ssim_mse, color='red', alpha=0.5, linewidth=0.25)
ax.plot(x_values_dist, ssim_dist, color='blue', alpha=0.5, linewidth=0.25)

# Moving averages (these carry the legend entries)
ax.plot(x_values_mse[window_size - 1:], mse_ssim_ma, label="DIP-MSE", color='red', alpha=1)
ax.plot(x_values_dist[window_size - 1:], dist_ssim_ma, label="DIP-DC", color='blue', alpha=1)

ax.set_xlim(cutoff, x_values[-1])
ax.set_xlabel("Iterations", fontsize=14)
ax.set_ylabel("SSIM", fontsize=14)

# Fit the y-range to the visible part of both raw curves.
visible = (ssim_mse[cutoff:], ssim_dist[cutoff:])
ax.set_ylim(min(c.min() for c in visible), max(c.max() for c in visible))

ax.legend(fontsize=16)
ax.grid()
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt.title("SSIM Trajectory", fontsize=16)

if SAVEFIGS:
    plt.savefig(os.path.join(save_directory, 'DIP_SSIM.svg'), bbox_inches='tight', dpi=600)
plt.show()

## Validating the uniformity assumption

In [None]:
image_name = "F16_GT"
fname = f'../data/images/{image_name}.png'

seed = 0
sigma = 10
# Noise std on the image intensity scale; assumes pil_to_np returns values in [0, 1] — confirm.
sigma_ = sigma / 255.

img_pil = crop_image(get_image(fname, imsize)[0], d=32)
img_np = pil_to_np(img_pil)
img_np = img_np[:3, ]  # Remove alpha channel if present

n_noise_draws = 100
# The clean image is the same for every draw, so convert it once
# (assumes dist_loss does not modify its inputs in place).
img_clean_torch = torch.tensor(img_np[None, :, :, :]).type(dtype)

# Per-draw DC loss values (despite the names, these hold one value per noise draw):
#   clean: clean image vs. noisy observation; noisy: noisy observation vs. itself.
mean_dist_loss_clean = []
mean_dist_loss_noisy = []

# One noise draw per seed; both losses reuse it instead of regenerating identical noise twice.
# NOTE(review): equivalent to computing the two losses in separate loops only if dist_loss
# consumes no random numbers — confirm for dc_loss_clipped_gaussian.
for seed in range(n_noise_draws):
    fix_seed(seed)
    img_noisy_pil, img_noisy_np = get_noisy_image(img_np, sigma_)
    img_noisy_torch = np_to_torch(img_noisy_np).type(dtype)

    dist_loss_value_clean, histo_values_clean, ix_clean = dist_loss(img_clean_torch, img_noisy_torch, sigma=sigma_, return_values=True)
    mean_dist_loss_clean.append(dist_loss_value_clean.item())

    dist_loss_value_noisy, histo_values_noisy, ix_noisy = dist_loss(img_noisy_torch, img_noisy_torch, sigma=sigma_, return_values=True)
    mean_dist_loss_noisy.append(dist_loss_value_noisy.item())

mean_dist_loss_clean = np.array(mean_dist_loss_clean)
mean_dist_loss_noisy = np.array(mean_dist_loss_noisy)

# Plotting: histograms come from the last noise draw; titles report mean ± 1.96·std over all draws.
fig, ax = plt.subplots(1, 2, figsize=(10, 4))

# Clean image subplot
ax[0] = plot_CDF_histogram(ax[0], histo_values_clean.cpu().numpy(), comparison="uniform", show_comparison=True, bins=100, color='green', alpha=0.7, label='Empirical density')
ax[0].set_title("CDF histogram of clean image\n(DC loss value: " + rf"{mean_dist_loss_clean.mean():.4f} $\pm$ {mean_dist_loss_clean.std() * 1.96:.4f})", fontsize=12)

# Noisy image subplot
ax[1] = plot_CDF_histogram(ax[1], histo_values_noisy.cpu().numpy(), comparison="uniform", show_comparison=True, bins=100, color='orange', alpha=0.7, label='Empirical density')
ax[1].set_title("CDF histogram of noisy image\n(DC loss value: " + rf"{mean_dist_loss_noisy.mean():.4f} $\pm$ {mean_dist_loss_noisy.std() * 1.96:.4f})", fontsize=12)

plt.tight_layout()

if SAVEFIGS:
    plt.savefig(os.path.join(save_directory, f'image={image_name}_sigma={sigma}_combined_histograms.svg'), bbox_inches='tight', dpi=300)
plt.show()


## Repeat results (e.g. peak PSNR)

Note: the cells below need a DIP run for every combination of noise level (sigma) and seed. You can produce these by passing several values to the `--sigmas` and `--seeds` flags of the `DIP.py` script, e.g. `--seeds 0 1 2 3 4 --sigmas 10 25 50 75 100`.

In [None]:
image_name = "F16_GT"
use_scheduler = False
import pickle

# Ground truth in (H, W, C) layout, matching channel_axis=2 in the SSIM calls below.
img_gt = img_np.transpose(1, 2, 0)

stats = {}
for sigma in [10, 25, 50, 75, 100]:
    print(sigma)
    stats[sigma] = {}
    for seed in range(5):
        print(seed, end=' ')
        # NOTE: pickle can execute arbitrary code while loading — only open trusted result files.
        with open(f'../results/DIP/image={image_name}_sigma={sigma}_scheduled={use_scheduler}_seed={seed}.pkl', 'rb') as results_file:
            _state = pickle.load(results_file)

        dist_losses = _state["dist"]["losses"]
        mse_losses = _state["mse"]["losses"]

        # Average of the last 10 stored reconstructions of each run.
        average_image_dist = np.array(_state["dist"]["images"][-10:]).mean(axis=0)
        average_image_mse = np.array(_state["mse"]["images"][-10:]).mean(axis=0)

        stats[sigma][seed] = {
            # Best value reached at any logged iteration
            "max_psnr_dist": dist_losses["psnr"].max(), "max_psnr_mse": mse_losses["psnr"].max(),
            "max_ssim_dist": dist_losses["ssim"].max(), "max_ssim_mse": mse_losses["ssim"].max(),

            # Value at the last logged iteration
            "final_psnr_dist": dist_losses["psnr"][-1], "final_psnr_mse": mse_losses["psnr"][-1],
            "final_ssim_dist": dist_losses["ssim"][-1], "final_ssim_mse": mse_losses["ssim"][-1],

            # Metrics of the averaged final reconstructions (channel_axis replaces the deprecated multichannel flag)
            "average_psnr_dist": compare_psnr(img_gt, average_image_dist, data_range=1.0),
            "average_psnr_mse": compare_psnr(img_gt, average_image_mse, data_range=1.0),
            "average_ssim_dist": ssim(img_gt, average_image_dist, data_range=1.0, channel_axis=2),
            "average_ssim_mse": ssim(img_gt, average_image_mse, data_range=1.0, channel_axis=2),
        }
    print()  # finish the line of seed numbers before the next sigma is printed

In [None]:
# Extract peak PSNR values for each sigma and seed
sigmas = sorted(stats.keys())
peak_psnr_dist = [np.array([stats[sigma][seed]["max_psnr_dist"] for seed in stats[sigma]]) for sigma in sigmas]
peak_psnr_mse = [np.array([stats[sigma][seed]["max_psnr_mse"] for seed in stats[sigma]]) for sigma in sigmas]

# Mean across seeds and 1.96 * std (half-width of a ~95% normal interval)
mean_psnr_dist = [psnr.mean() for psnr in peak_psnr_dist]
std_psnr_dist = [psnr.std() * 1.96 for psnr in peak_psnr_dist]
mean_psnr_mse = [psnr.mean() for psnr in peak_psnr_mse]
std_psnr_mse = [psnr.std() * 1.96 for psnr in peak_psnr_mse]

# Plot: dashed lines carry the legend entries; 'x' markers add capped error bars.
# Red = DIP-MSE, blue = DIP-DC (labels were previously swapped).
fig, ax = plt.subplots(figsize=(5, 4))
ax.errorbar(sigmas, mean_psnr_mse, yerr=std_psnr_mse, fmt='--', color='red', alpha=0.7, label="DIP-MSE")
ax.errorbar(sigmas, mean_psnr_dist, yerr=std_psnr_dist, fmt='--', color='blue', alpha=0.7, label="DIP-DC")
ax.errorbar(sigmas, mean_psnr_mse, yerr=std_psnr_mse, fmt='x', color='red', capsize=5)
ax.errorbar(sigmas, mean_psnr_dist, yerr=std_psnr_dist, fmt='x', color='blue', capsize=5)

ax.set_xlabel(r"Noise level (Standard deviation $\sigma$ of noise)", fontsize=12)
ax.set_ylabel(r"Peak PSNR (dB) $\uparrow$", fontsize=16)
ax.set_title("Peak PSNR vs Noise Level", fontsize=16)
ax.legend(fontsize=16)
ax.grid()
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

if SAVEFIGS:
    plt.savefig(os.path.join(save_directory, 'DIP_peak_PSNR.svg'), bbox_inches='tight', dpi=600)
plt.show()

In [None]:
# Extract peak SSIM values for each sigma and seed
peak_ssim_dist = [np.array([stats[sigma][seed]["max_ssim_dist"] for seed in stats[sigma]]) for sigma in sigmas]
peak_ssim_mse = [np.array([stats[sigma][seed]["max_ssim_mse"] for seed in stats[sigma]]) for sigma in sigmas]

# Mean across seeds and 1.96 * std (half-width of a ~95% normal interval).
# The loop variable is `values`, not `ssim`, to avoid shadowing the imported SSIM metric.
mean_ssim_dist = [values.mean() for values in peak_ssim_dist]
std_ssim_dist = [values.std() * 1.96 for values in peak_ssim_dist]
mean_ssim_mse = [values.mean() for values in peak_ssim_mse]
std_ssim_mse = [values.std() * 1.96 for values in peak_ssim_mse]

# Plot: dashed connecting lines, then labelled 'x' markers with capped error bars
fig, ax = plt.subplots(figsize=(5, 4))
ax.errorbar(sigmas, mean_ssim_dist, yerr=std_ssim_dist, fmt='--', color='blue', alpha=0.7)
ax.errorbar(sigmas, mean_ssim_mse, yerr=std_ssim_mse, fmt='--', color='red', alpha=0.7)
ax.errorbar(sigmas, mean_ssim_mse, yerr=std_ssim_mse, label="DIP-MSE", fmt='x', color='red', capsize=5)
ax.errorbar(sigmas, mean_ssim_dist, yerr=std_ssim_dist, label="DIP-DC", fmt='x', color='blue', capsize=5)

ax.set_xlabel(r"Noise level (Standard deviation $\sigma$ of noise)", fontsize=12)
ax.set_ylabel(r"Peak SSIM $\uparrow$", fontsize=16)
ax.set_title("Peak SSIM vs Noise Level", fontsize=16)
ax.legend(fontsize=16)
ax.grid()
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

if SAVEFIGS:
    plt.savefig(os.path.join(save_directory, f'image={image_name}_peak_SSIM.svg'), bbox_inches='tight', dpi=600)
plt.show()


In [None]:
# Extract end-of-training PSNR values for each sigma and seed:
#   average_*: PSNR of the mean of the last 10 stored reconstructions
#   final_*:   PSNR at the last logged iteration
sigmas = sorted(stats.keys())
avg_end_psnr_dist = [np.array([stats[sigma][seed]["average_psnr_dist"] for seed in stats[sigma]]) for sigma in sigmas]
avg_end_psnr_mse = [np.array([stats[sigma][seed]["average_psnr_mse"] for seed in stats[sigma]]) for sigma in sigmas]

end_psnr_dist = [np.array([stats[sigma][seed]["final_psnr_dist"] for seed in stats[sigma]]) for sigma in sigmas]
end_psnr_mse = [np.array([stats[sigma][seed]["final_psnr_mse"] for seed in stats[sigma]]) for sigma in sigmas]


def _mean_and_ci(groups):
    """Per-sigma mean and 1.96 * std (half-width of a ~95% normal interval) across seeds."""
    return [g.mean() for g in groups], [g.std() * 1.96 for g in groups]


mean_avg_end_psnr_dist, std_avg_end_psnr_dist = _mean_and_ci(avg_end_psnr_dist)
mean_avg_end_psnr_mse, std_avg_end_psnr_mse = _mean_and_ci(avg_end_psnr_mse)
mean_end_psnr_dist, std_end_psnr_dist = _mean_and_ci(end_psnr_dist)
mean_end_psnr_mse, std_end_psnr_mse = _mean_and_ci(end_psnr_mse)

# Plot (peak-PSNR statistics mean_psnr_* / std_psnr_* come from the peak PSNR cell)
fig, ax = plt.subplots(figsize=(10, 8))
# End PSNR of averaged images: dashed lines + error bars
ax.errorbar(sigmas, mean_avg_end_psnr_dist, fmt='--', color='blue', alpha=0.7, label="End PSNR of averaged images: DIP-DC")
ax.errorbar(sigmas, mean_avg_end_psnr_mse, fmt='--', color='red', alpha=0.7, label="End PSNR of averaged images: DIP-MSE")
ax.errorbar(sigmas, mean_avg_end_psnr_dist, yerr=std_avg_end_psnr_dist, fmt='x', color='blue', alpha=1.0)
ax.errorbar(sigmas, mean_avg_end_psnr_mse, yerr=std_avg_end_psnr_mse, fmt='x', color='red', alpha=1.0)

# End PSNR: dotted lines + capped error bars
ax.errorbar(sigmas, mean_end_psnr_dist, yerr=std_end_psnr_dist, fmt='x', color='blue', capsize=3)
ax.errorbar(sigmas, mean_end_psnr_mse, yerr=std_end_psnr_mse, fmt='x', color='red', capsize=3)
ax.errorbar(sigmas, mean_end_psnr_dist, fmt=':', color='blue', capsize=3, label="End PSNR: DIP-DC", alpha=0.7)
ax.errorbar(sigmas, mean_end_psnr_mse, fmt=':', color='red', capsize=3, label="End PSNR: DIP-MSE", alpha=0.7)

# Peak PSNR: solid lines + capped error bars
ax.errorbar(sigmas, mean_psnr_dist, fmt='-', label="Peak PSNR: DIP-DC", color='blue', alpha=0.7)
ax.errorbar(sigmas, mean_psnr_mse, fmt='-', label="Peak PSNR: DIP-MSE", color='red', alpha=0.7)
ax.errorbar(sigmas, mean_psnr_mse, yerr=std_psnr_mse, fmt='x', color='red', capsize=3)
ax.errorbar(sigmas, mean_psnr_dist, yerr=std_psnr_dist, fmt='x', color='blue', capsize=3)

ax.set_xlabel(r"Noise level (Standard deviation $\sigma$ of noise)", fontsize=12)
ax.set_ylabel(r"PSNR (dB) $\uparrow$", fontsize=16)
ax.set_title("PSNR vs Noise Level\n(varied choice of output)", fontsize=16)
ax.legend(fontsize=12)
ax.grid()
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)

if SAVEFIGS:
    plt.savefig(os.path.join(save_directory, f'image={image_name}_mixed_PSNR.svg'), bbox_inches='tight', dpi=600)
plt.show()