In [None]:
import sys
sys.path.append('/home/datalab/notebooks/Adam/')
import torch
import torch.nn as nn
import yaml
import numpy as np
import pandas as pd

from torch import tensor
from torch.utils.data import DataLoader
from torchvision import transforms
import pytorch_ssim
from tqdm import tqdm

from gunet import GUNet
from unet import UNet
from dataset import CustomDataset, get_images

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon

from torchsummary import summary
import scipy
import math

In [None]:
from PIL import Image, ImageOps
import torchvision.transforms.functional as TF
# Sanity check of the metrics on one grayscale image: compare the original with
# (a) a copy with additive Gaussian noise and (b) a copy with a constant offset,
# both of the same magnitude.
image = Image.open('radar_aeroporto.jpeg')
image = ImageOps.grayscale(image)

original_tensor = TF.to_tensor(image)

# Magnitude of both perturbations (noise std and constant offset).
constant = 0.2

# Seeded generator so the "Noisy" row of loss_array is reproducible across runs.
noise_generator = torch.Generator().manual_seed(0)

# Add Gaussian noise to the image
noisy_tensor = original_tensor + torch.randn(original_tensor.shape, generator=noise_generator) * constant

# Add a constant value to all pixels in the image
value_tensor = original_tensor + constant

# Add a batch dimension to the tensors
original_tensor = original_tensor.unsqueeze(0)
noisy_tensor = noisy_tensor.unsqueeze(0)
value_tensor = value_tensor.unsqueeze(0)

losses = [nn.L1Loss(), nn.MSELoss(), pytorch_ssim.SSIM(window_size=11)]

modified_images = [('Noisy', noisy_tensor), ('Value', value_tensor)]
num_images = len(modified_images)
num_losses = len(losses)
# loss_array[i, j] = metric j (L1, MSE, SSIM) between the original and modified image i.
loss_array = np.zeros((num_images, num_losses))

# The loop variable is deliberately not called `tensor`: that name would shadow
# `from torch import tensor` imported at the top of the notebook.
for i, (_, modified) in enumerate(modified_images):
    for j, loss_fn in enumerate(losses):
        loss_array[i, j] = loss_fn(original_tensor, modified).item()
print(loss_array)

In [None]:
from mpl_toolkits.axes_grid1 import ImageGrid
from mpl_toolkits.basemap import Basemap
import matplotlib.colors as colors
mpl.rcParams.update(mpl.rcParamsDefault)

# Height/width of the test image, so the panels keep its proportions.
ratio = image.height / image.width

fig, axs = plt.subplots(1, 3, dpi=200, figsize=(9, 3), height_ratios=[ratio])

cmap = 'gray'

# Panels: original, noisy, constant offset.
for ax, img in zip(axs, [original_tensor, noisy_tensor, value_tensor]):
    ax.imshow(img[0, 0], cmap=cmap)

# Metrics under the two modified panels; loss_array columns are 0 = MAE (L1), 1 = MSE, 2 = SSIM.
for ax, row in zip(axs[1:], loss_array):
    ax.set_xlabel(r"$\bf{SSIM}: $" + f"{row[2]:.3f}, " +
                  r"$\bf{MAE}: $" + f"{row[0]:.3f}, " +
                  r"$\bf{MSE}: $" + f"{row[1]:.3f}", labelpad=0)

# Raw strings: "\ " is mathtext's explicit space. In a normal string literal it is
# an invalid escape sequence (SyntaxWarning on Python 3.12+).
for ax, title in zip(axs, [r"Original", r"Added\ noise", r"Added\ constant"]):
    ax.set_title(r"$\bf{" + f"{title}" + "}$")
    ax.set_frame_on(False)
    ax.set_xticks([])
    ax.set_yticks([])

fig.tight_layout(pad=0.5)


In [None]:
# GPU 3 was used for the original runs. Fall back to the CPU when that GPU does
# not exist, so the notebook still runs (slowly) on other machines.
device = torch.device('cuda:3' if torch.cuda.device_count() > 3 else 'cpu')

In [None]:
def _read_yaml(path):
    """Open `path` and return its contents parsed with yaml.safe_load."""
    with open(path) as handle:
        return yaml.safe_load(handle)


# Only the UNet run's dataset config is read; it is used for both models below.
dataset_config = _read_yaml('/home/datalab/notebooks/Adam/run/unet_without_clr/dataset_config.yaml')
unet_config = _read_yaml('/home/datalab/notebooks/Adam/run/unet_without_clr/model_config.yaml')
gunet_config = _read_yaml('/home/datalab/notebooks/Adam/run/gunet_without_clr/model_config.yaml')

In [None]:
# Display the dataset configuration loaded above.
dataset_config

In [None]:
# Train/validation/test splits of inputs (X) and targets (y). The dataset-config
# values are passed to get_images positionally, in this order.
_split_keys = ('data_path', 'stride_minutes', 'input_length', 'output_length',
               'chunk_size', 'test_frac', 'val_frac', 'seed')
X_train, y_train, X_val, y_val, X_test, y_test = get_images(
    *(dataset_config[key] for key in _split_keys)
)

In [None]:
# Private alias: this notebook later defines a top-level function named
# `transforms` (an FFT helper), which shadows the torchvision module. Without the
# alias, re-running this cell after those cells fails.
from torchvision import transforms as tv_transforms

transform = tv_transforms.Compose([
#     tv_transforms.CenterCrop((256, 512)),
    tv_transforms.ToTensor(),
])

# Test split used by the evaluation cells below.
dataset = CustomDataset(X_test, y_test, transform=transform)
dataloader = DataLoader(dataset, batch_size=gunet_config['batch_size'], shuffle=True, num_workers=4,)

In [None]:
def _make_loader(inputs, targets, batch_size, shuffle):
    """Wrap one split in CustomDataset (using the global `transform`) and a 16-worker DataLoader."""
    return DataLoader(CustomDataset(inputs, targets, transform=transform),
                      batch_size=batch_size, shuffle=shuffle, num_workers=16)


train_loader = _make_loader(X_train, y_train, batch_size=1, shuffle=True)
val_loader = _make_loader(X_val, y_val, batch_size=1, shuffle=False)
test_loader = _make_loader(X_test, y_test, batch_size=50, shuffle=False)

# Number of batches per loader.
# NOTE(review): the recorded output below looks like it came from batch_size=1 for
# all three loaders. test_loader now uses batch_size=50, so re-run to refresh it.
print(len(train_loader), len(val_loader), len(test_loader))
#168110 21108 20957

In [None]:
# Display the GUNet model hyper-parameters.
gunet_config

In [None]:
# Display the UNet model hyper-parameters.
unet_config

In [None]:
# Build GUNet from the dataset lengths and its own hyper-parameters, move it to
# `device`, and load the run's models/best.pt weights.
gunet_model = GUNet(
    dataset_config['input_length'],
    dataset_config['output_length'],
    gunet_config['dropout_rate'],
    gunet_config['r'],
    gunet_config['epsilon'],
).to(device)
path = "/home/datalab/notebooks/Adam/run/gunet_without_clr/models/best.pt"
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(path, map_location=device)
gunet_model.load_state_dict(checkpoint)

In [None]:
# Build UNet (kernel size from its config), move it to `device`, and load the
# run's models/best.pt weights.
unet_model = UNet(
    dataset_config['input_length'],
    dataset_config['output_length'],
    unet_config['dropout_rate'],
    k=unet_config['kernel_size'],
).to(device)
path = "/home/datalab/notebooks/Adam/run/unet_without_clr/models/best.pt"
# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
checkpoint = torch.load(path, map_location=device)
unet_model.load_state_dict(checkpoint)

In [None]:
# Put both networks in inference mode (disables dropout, etc.).
for _net in (gunet_model, unet_model):
    _net.eval()

In [None]:
def measure(c, batch_size=50, num_workers=32):
    """Evaluate `gunet_model` and `unet_model` on the global test `dataset`, `c` times.

    Each repetition builds a freshly shuffled DataLoader over `dataset`, runs both
    models on every batch under torch.no_grad(), and accumulates per-batch MAE
    (L1), MSE and SSIM against the targets. After each repetition it prints that
    repetition's batch-averaged metrics and the running average over all
    repetitions so far.

    Args:
        c: number of repetitions.
        batch_size: DataLoader batch size (default 50).
        num_workers: DataLoader worker processes (default 32).

    Returns:
        (gunet_sums, unet_sums): numpy arrays ordered [MAE, MSE, SSIM]. Each is the
        SUM of per-batch values over all batches of all repetitions; divide by
        c * number_of_batches to get the mean.

    Raises:
        ValueError: if `dataset` yields no batches.
    """
    losses = [nn.L1Loss(), nn.MSELoss(), pytorch_ssim.SSIM()]

    def _fmt(values):
        # 7 decimals, comma-separated.
        return ", ".join(f"{value:.7f}" for value in values)

    running_losses_gunet_sum = np.zeros(len(losses))
    running_losses_unet_sum = np.zeros(len(losses))

    for rep in range(c):
        running_losses_gunet = np.zeros(len(losses))
        running_losses_unet = np.zeros(len(losses))
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        n_batches = len(loader)
        if n_batches == 0:
            raise ValueError("test dataset is empty")
        with torch.no_grad():
            for inputs, targets in tqdm(loader, "Testing"):
                inputs, targets = inputs.to(device), targets.to(device)
                gunet_outputs = gunet_model(inputs)
                unet_outputs = unet_model(inputs)

                for k, loss_fn in enumerate(losses):
                    running_losses_gunet[k] += loss_fn(gunet_outputs, targets).item()
                    running_losses_unet[k] += loss_fn(unet_outputs, targets).item()
        print(f'GUNet: {_fmt(running_losses_gunet / n_batches)}')
        print(f'UNet: {_fmt(running_losses_unet / n_batches)}')
        running_losses_gunet_sum += running_losses_gunet
        running_losses_unet_sum += running_losses_unet
        # Running mean over the rep + 1 repetitions completed so far.
        batches_done = n_batches * (rep + 1)
        print(f'sum GUNet: {_fmt(running_losses_gunet_sum / batches_done)}')
        print(f'sum UNet: {_fmt(running_losses_unet_sum / batches_done)}')
        print('---------------')
    return running_losses_gunet_sum, running_losses_unet_sum

In [None]:
# Per-metric SUMS ([MAE, MSE, SSIM]) over 10 shuffled passes of the test set, not means.
run_loss_gunet, run_loss_unet = measure(c=10)

In [None]:
def measure_with_thresholds(c, batch_size=40, num_workers=32):
    """Compare GUNet and UNet after zeroing every pixel below a threshold.

    The thresholds are i / c for i = 0..c (0 to 1 in c equal steps). For each
    threshold, values below it are set to 0 in the targets and in both models'
    outputs, then MAE (L1), MSE and SSIM are computed.

    The models run once per batch and every threshold is evaluated on those
    outputs. Previously a full inference pass was made for each threshold.

    Args:
        c: number of threshold steps; must be >= 1.
        batch_size: DataLoader batch size (default 40).
        num_workers: DataLoader worker processes (default 32).

    Returns:
        (dic_unet, dic_gunet): dicts mapping threshold -> numpy array
        [MAE, MSE, SSIM] averaged over batches.

    Raises:
        ValueError: if c < 1 or the test dataset is empty.
    """
    if c < 1:
        raise ValueError(f"c must be >= 1, got {c}")

    losses = [nn.L1Loss(), nn.MSELoss(), pytorch_ssim.SSIM()]
    # Derived from the integer index: repeatedly adding 1/c accumulates float
    # error, giving keys such as 0.15000000000000002.
    thresholds = [i / c for i in range(c + 1)]

    running_gunet = np.zeros((len(thresholds), len(losses)))
    running_unet = np.zeros((len(thresholds), len(losses)))

    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    n_batches = len(loader)
    if n_batches == 0:
        raise ValueError("test dataset is empty")

    with torch.no_grad():
        for inputs, targets in tqdm(loader, "Testing"):
            inputs, targets = inputs.to(device), targets.to(device)
            gunet_outputs = gunet_model(inputs)
            unet_outputs = unet_model(inputs)

            for t_idx, threshold in enumerate(thresholds):
                # Out-of-place masking, so each threshold starts from the raw tensors.
                masked_targets = targets.masked_fill(targets < threshold, 0)
                masked_gunet = gunet_outputs.masked_fill(gunet_outputs < threshold, 0)
                masked_unet = unet_outputs.masked_fill(unet_outputs < threshold, 0)

                for k, loss_fn in enumerate(losses):
                    running_gunet[t_idx, k] += loss_fn(masked_gunet, masked_targets).item()
                    running_unet[t_idx, k] += loss_fn(masked_unet, masked_targets).item()

    dic_unet = {}
    dic_gunet = {}
    for t_idx, threshold in enumerate(thresholds):
        dic_unet[threshold] = running_unet[t_idx] / n_batches
        dic_gunet[threshold] = running_gunet[t_idx] / n_batches
        print(f'Threshold {threshold}')
        print(f'GUNet: {", ".join(f"{v:.7f}" for v in dic_gunet[threshold])}')
        print(f'UNet: {", ".join(f"{v:.7f}" for v in dic_unet[threshold])}')

    return dic_unet, dic_gunet

In [None]:
# Batch-averaged [MAE, MSE, SSIM] per threshold, for thresholds 0, 0.05, ..., 1.
dic_unet, dic_gunet = measure_with_thresholds(c=20)

In [None]:
# One row per threshold; columns follow the metric order [MAE, MSE, SSIM].
df_unet = (
    pd.DataFrame(dic_unet)
    .T
    .rename(columns={0: 'mae_unet', 1: 'mse_unet', 2: 'ssim_unet'})
)
df_unet.head()

In [None]:
# One row per threshold; columns follow the metric order [MAE, MSE, SSIM].
df_gunet = (
    pd.DataFrame(dic_gunet)
    .T
    .rename(columns={0: 'mae_gunet', 1: 'mse_gunet', 2: 'ssim_gunet'})
)
df_gunet.head()

In [None]:
# Both models side by side. Sorting the columns alphabetically puts each
# metric's _gunet/_unet pair next to each other.
df = df_unet.join(df_gunet).sort_index(axis=1)
df

In [None]:
import matplotlib.cm as cm
import matplotlib.colors as colors
import matplotlib.colorbar as olorbar

def threshold_plot(loss, cmap, reverse):
    """Plot one metric against threshold for both models, plus their difference.

    Left panel: the `{loss}_unet` and `{loss}_gunet` columns of the global `df`.
    Right panel: GUNet minus UNet, with y-limits symmetric around 0.

    Args:
        loss: metric prefix of the columns in `df` ('mae', 'mse' or 'ssim').
        cmap: unused; kept so existing calls keep working. The former
            `cm.get_cmap(cmap)` result was never used, and
            `matplotlib.cm.get_cmap` was removed in Matplotlib 3.9.
        reverse: unused; kept for call compatibility.
    """
    y = (df[f'{loss}_gunet'] - df[f'{loss}_unet']).to_numpy()
    x = df.index

    fig, axs = plt.subplots(1, 2, dpi=200, figsize=(7, 3))

    # Left: the metric itself for both models.
    axs[0].plot(df[f'{loss}_unet'], c="#268BD2", label='Unet',)
    axs[0].plot(df[f'{loss}_gunet'], c="#D1495B", label='GUnet', ls="--")

    axs[0].set_xlim(0.0, 1)
    axs[0].legend(framealpha=1)
    axs[0].grid(True, c='0.8')
    axs[0].set_frame_on(False)
    axs[0].set_xlabel('Threshold')
    axs[0].set_ylabel(loss.upper())

    # Right: GUNet minus UNet.
    axs[1].plot(x, y, c='#EDAE49', linewidth=2)

    axs[1].set_frame_on(False)
    axs[1].set_xlabel('Threshold')
    axs[1].set_ylabel('Difference', labelpad=0)
    axs[1].grid(True, c='0.8')

    # Symmetric limits around 0 with a 10% margin. Skipped when the difference is
    # identically zero (or NaN), where set_ylim(0, 0) would give a degenerate axis.
    limit = np.max(np.abs(y))
    if np.isfinite(limit) and limit > 0:
        axs[1].set_ylim(-limit - 0.1 * limit, limit + 0.1 * limit)

    fig.tight_layout(pad=2)
    
    
# MSE against threshold for both models, and GUNet - UNet.
threshold_plot(loss='mse', cmap='coolwarm', reverse=False)


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.scale as mscale
import matplotlib.transforms as mtransforms

fig, ax = plt.subplots(1,1,dpi=200,figsize=(6,6))

import datetime as dt

# Positive values: UNet has the higher SSIM at that threshold.
ax.plot(df['ssim_unet'] - df['ssim_gunet'], c="#268BD2", label='UNet - GUNet')

ax.set_xlabel('Threshold')
ax.set_ylabel('SSIM difference')
# The line now carries a label, so the legend is no longer empty.
ax.legend(framealpha=1)
ax.grid(True, c='0.8')
ax.set_frame_on(False)

# Show only after all styling: with the inline backend, changes made after
# plt.show() are not rendered.
plt.show()
# ax.set_yscale('log')

In [None]:
# size = tuple(inputs.shape[1:])
# print(size)
# unet_model.to('cpu')
# summary(unet_model,size,device='cpu')

In [None]:
# size = tuple(inputs.shape[1:])
# print(size)
# gunet_model.to('cpu')
# summary(gunet_model,size,device='cpu')

In [None]:
def accumulate(loader):
    """Average the targets and both models' outputs over dims 0 and 1.

    For each batch, the targets and the `unet_model` / `gunet_model` outputs are
    summed over dims 0 and 1, leaving one 2-D map per source. The maps are divided
    by the accumulated size(0) * size(1) count at the end.

    Args:
        loader: iterable of (inputs, targets) batches. The tensors are presumably
            (batch, frames, H, W) — TODO confirm against CustomDataset.

    Returns:
        (mean_target, mean_unet, mean_gunet) 2-D tensors on `device`.

    Raises:
        ValueError: if `loader` yields no batches (previously a TypeError on None / 0).
    """
    unet_acc_out = None
    gunet_acc_out = None
    acc_tar = None
    count = 0

    with torch.no_grad():
        for inputs, targets in tqdm(loader, desc="Accumulating"):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs_unet = unet_model(inputs)
            outputs_gunet = gunet_model(inputs)

            batch_unet = outputs_unet.sum(dim=(0, 1))
            batch_gunet = outputs_gunet.sum(dim=(0, 1))
            batch_tar = targets.sum(dim=(0, 1))

            if acc_tar is None:
                unet_acc_out, gunet_acc_out, acc_tar = batch_unet, batch_gunet, batch_tar
            else:
                unet_acc_out += batch_unet
                gunet_acc_out += batch_gunet
                acc_tar += batch_tar

            count += outputs_unet.size(0) * outputs_unet.size(1)

    if acc_tar is None or count == 0:
        raise ValueError("loader yielded no batches; nothing to average")
    return acc_tar / count, unet_acc_out / count, gunet_acc_out / count

In [None]:
# Mean target / UNet / GUNet frames over the whole test split.
tar, unet_out, gunet_out = accumulate(test_loader)

In [None]:
def moving_average(arr, window_size=32):
    """Box-filter `arr` with a window_size x window_size mean kernel.

    Uses scipy.signal.convolve2d in 'same' mode: the output has arr's shape, and
    the border is zero-padded, so values near the edges are pulled toward 0.
    """
    kernel = np.full((window_size, window_size), 1.0 / window_size**2)
    return scipy.signal.convolve2d(arr, kernel, mode='same')

def transforms(arr):
    """Centered log-magnitude spectrum of `arr` as a numpy array.

    Takes the n-D FFT with torch, then moves the zero frequency to the centre with
    numpy's fftshift (which converts a CPU tensor to an ndarray), and returns
    log(|F|). Zero-magnitude bins give -inf.

    NOTE(review): this name shadows the `torchvision.transforms` module imported at
    the top of the notebook. A later cell also redefines `transforms` to return the
    raw complex spectrum, so cells calling it depend on which definition ran last.
    """
    centered = np.fft.fftshift(torch.fft.fftn(arr))
    magnitude = np.abs(centered)
    return np.log(magnitude)

In [None]:
# 3-D view of |smoothed log-spectrum(model mean) - smoothed log-spectrum(target mean)|
# for the mean frames from accumulate(); lower means closer to the targets.
# Assumes `transforms` is the log-magnitude helper defined above (a later cell
# redefines it to return the complex spectrum).
fig = plt.figure(figsize=(6, 3), dpi=200)
ax2 = fig.add_subplot(1, 2, 1, projection='3d')
ax3 = fig.add_subplot(1, 2, 2, projection='3d')

r = range(0, tar.shape[1])
p = range(0, tar.shape[0])
X, Y = np.meshgrid(r, p)

z1 = moving_average(transforms(tar.cpu()))
z2 = moving_average(transforms(gunet_out.cpu()))
z3 = moving_average(transforms(unet_out.cpu()))

# Shared colour range; the z-axis extends 1.5 above it.
vmin = 0.9
vmax = 1.9

# Each panel is configured once; set_zlim was previously applied twice per axis.
for ax, z, title in [(ax2, z2, "GUNet"), (ax3, z3, "UNet")]:
    ax.plot_surface(X, Y, abs(z - z1), cmap='RdYlGn_r', vmin=vmin, vmax=vmax, cstride=4, rstride=4)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zlim(vmin, vmax + 1.5)

# plt.imshow(transforms(gunet_out.cpu()),cmap='turbo')

In [None]:

# Flat heat-map version of the two |z - z1| surfaces, on the same colour range.
fig, axs = plt.subplots(1, 2, figsize=(15, 5), dpi=200)
for ax, z in zip(axs, (z2, z3)):
    ax.imshow(abs(z - z1), vmin=vmin, vmax=vmax, cmap='RdYlGn_r')

In [None]:
# The smoothed log-spectra themselves (z1, z2, z3) on a common colour scale.
vmin, vmax = -3, 1

fig, axs = plt.subplots(1, 3, figsize=(15, 5), dpi=200)
for ax, z in zip(axs, (z1, z2, z3)):
    ax.imshow(z, vmin=vmin, vmax=vmax, cmap='turbo')

In [None]:
# Value distribution of the three smoothed log-spectra; `;` hides the return value.
plt.hist([z.flatten() for z in (z1, z2, z3)], bins=20);

In [None]:
# Heat maps of |z - z1| per model, with a shared colour bar.
# In a top-to-bottom run, z2 / z3 come from gunet_out / unet_out (3-D surface
# cell above), so the panels are titled in that order.
fig, axs = plt.subplots(1, 2, figsize=(8, 4), dpi=200)
s = axs[0].imshow(abs(z2 - z1), vmin=vmin, vmax=vmax, cmap='RdYlGn_r')
axs[1].imshow(abs(z3 - z1), vmin=vmin, vmax=vmax, cmap='RdYlGn_r')

for ax, title in zip(axs, ["GUNet", "UNet"]):
    ax.set_title(r"$\bf{" + f"{title}" + "}$")
    ax.axis('off')

fig.tight_layout(pad=1)
# Use this cell's own image as the colour-bar mappable. `s` was previously only
# defined by a later cell, so a top-to-bottom run raised NameError here.
fig.colorbar(s, ax=axs, shrink=0.5, location='bottom', orientation='horizontal', pad=0.05)

# Mean frames from accumulate(): GUNet, UNet, targets.
fig, axs = plt.subplots(1, 3, dpi=200, figsize=(3, 1), height_ratios=[ratio])

axs[0].imshow(gunet_out.cpu(), cmap='turbo')
axs[1].imshow(unet_out.cpu(), cmap='turbo')
axs[2].imshow(tar.cpu(), cmap='turbo')

In [None]:
from scipy import stats

# One-sided two-sample t-test (scipy default equal_var=True): is the GUNet
# sample mean greater than the UNet one?
# NOTE(review): `means_gunet` and `means_unet` are not defined anywhere in this
# notebook, so this cell raises NameError until they are computed.
t_stat, p_value = stats.ttest_ind(means_gunet, means_unet, alternative='greater')

print(p_value)

alpha = 0.05  # significance level
if p_value < alpha:
    print("GUNet generates images with significantly higher frequency content than UNet.")
    print(f"We reject the null hypothesis (H0) in favor of the alternative hypothesis (H1) at the {alpha} significance level.")
else:
    # One-sided test: failing to reject means no evidence that GUNet is higher,
    # not evidence that the two models are equal.
    print("No significant evidence that GUNet generates images with higher frequency content than UNet.")
    print(f"We cannot reject the null hypothesis (H0) at the {alpha} significance level.")

In [None]:
def accumulate_random(batch_size, channels, height, width, max_batches, seed=None):
    """Average both models' outputs (and the inputs) over random Gaussian inputs.

    Feeds `max_batches` batches of standard-normal noise of shape
    (batch_size, channels, height, width) through `unet_model` and `gunet_model`,
    then averages outputs and inputs over dims 0 and 1.

    Args:
        batch_size, channels, height, width: shape of each random input batch.
        max_batches: number of batches; must be >= 1.
        seed: optional int. When given, the noise comes from a generator seeded
            with it, so results are reproducible. None (default) uses the global RNG.

    Returns:
        (mean_unet, mean_gunet, mean_inputs) 2-D tensors on `device`.

    Raises:
        ValueError: if max_batches < 1.
    """
    if max_batches < 1:
        raise ValueError(f"max_batches must be >= 1, got {max_batches}")

    generator = None if seed is None else torch.Generator(device=device).manual_seed(seed)

    unet_acc_out = None
    gunet_acc_out = None
    inputs_acc = None
    output_count = 0
    # Separate count for inputs: they have `channels` slices per sample, which
    # need not equal the number of output slices. Dividing the input sum by the
    # output count mis-scaled the mean input.
    input_count = 0

    with torch.no_grad():
        for _ in tqdm(range(max_batches), desc="Accumulating"):
            # Sample directly on `device` rather than on the CPU followed by a copy.
            inputs = torch.randn(batch_size, channels, height, width, device=device, generator=generator)
            outputs_unet = unet_model(inputs)
            outputs_gunet = gunet_model(inputs)

            batch_unet = outputs_unet.sum(dim=(0, 1))
            batch_gunet = outputs_gunet.sum(dim=(0, 1))
            batch_inputs = inputs.sum(dim=(0, 1))

            if inputs_acc is None:
                unet_acc_out, gunet_acc_out, inputs_acc = batch_unet, batch_gunet, batch_inputs
            else:
                unet_acc_out += batch_unet
                gunet_acc_out += batch_gunet
                inputs_acc += batch_inputs

            output_count += outputs_unet.size(0) * outputs_unet.size(1)
            input_count += inputs.size(0) * inputs.size(1)
    return unet_acc_out / output_count, gunet_acc_out / output_count, inputs_acc / input_count

In [None]:
# Mean model responses to 100 batches of 50 random 8 x 256 x 512 inputs.
unet_rand, gunet_rand, in_rand = accumulate_random(batch_size=50, channels=8, height=256, width=512, max_batches=100)

In [None]:
def moving_average(arr, window_size=32):
    """Mean of `arr` over a window_size x window_size neighbourhood (scipy convolve2d, 'same' size, zero-filled borders).

    NOTE(review): identical to the earlier definition of moving_average; duplicated here.
    """
    area = window_size**2
    box = np.ones((window_size, window_size)) / area
    return scipy.signal.convolve2d(arr, box, mode='same')

# 3-D smoothed log-spectra of the mean GUNet / UNet outputs for random-noise
# inputs (from accumulate_random); z1 holds the mean input's spectrum.
# Assumes `transforms` is the log-magnitude helper defined earlier (a later cell
# redefines it to return the complex spectrum).
fig = plt.figure(figsize=(6, 3), dpi=200)
ax2 = fig.add_subplot(1, 2, 1, projection='3d')
ax3 = fig.add_subplot(1, 2, 2, projection='3d')

r = range(0, in_rand.shape[1])
p = range(0, in_rand.shape[0])
X, Y = np.meshgrid(r, p)

z1 = moving_average(transforms(in_rand.cpu()))
z2 = moving_average(transforms(gunet_rand.cpu()))
z3 = moving_average(transforms(unet_rand.cpu()))

# Shared colour range; the z-axis extends 1.5 above it.
vmin = -3
vmax = 1

# Each panel is configured once; set_zlim was previously applied twice per axis.
for ax, z, title in [(ax2, z2, "GUNet"), (ax3, z3, "UNet")]:
    ax.plot_surface(X, Y, z, cmap='turbo', vmin=vmin, vmax=vmax, cstride=4, rstride=4)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_zlim(vmin, vmax + 1.5)

# plt.imshow(transforms(gunet_out.cpu()),cmap='turbo')

In [None]:
def transforms(arr):
    """Centred complex spectrum: the n-D FFT of `arr` with the zero frequency shifted to the middle.

    NOTE(review): redefines the earlier log-magnitude `transforms` helper (and
    shadows torchvision.transforms). Earlier cells give different results if they
    are re-run after this one.
    """
    spectrum = torch.fft.fftn(arr)
    return torch.fft.fftshift(spectrum)

def accumulate_fourier(loader):
    """Average the centred complex 2-D spectra of the targets and both models' outputs.

    Each 2-D slice (indexed by the first two dims) of the targets and of the
    `unet_model` / `gunet_model` outputs gets a 2-D FFT over its last two dims,
    is fftshift-ed, and the spectra are averaged.

    The FFT is computed here, batched over all slices, instead of through the
    global `transforms`, which this notebook defines twice with different
    meanings. This also replaces a per-slice Python loop.

    Args:
        loader: iterable of (inputs, targets) batches. The tensors are presumably
            (batch, frames, H, W) — TODO confirm.

    Returns:
        (mean_target_spectrum, mean_unet_spectrum, mean_gunet_spectrum): complex
        2-D tensors on `device`.

    Raises:
        ValueError: if `loader` yields no slices.
    """
    def _shifted_spectra_sum(frames):
        # 2-D FFT of every slice, zero frequency centred, summed over dims 0 and 1.
        spectra = torch.fft.fftshift(torch.fft.fft2(frames), dim=(-2, -1))
        return spectra.sum(dim=(0, 1))

    unet_acc_out = None
    gunet_acc_out = None
    acc_tar = None
    count = 0

    with torch.no_grad():
        for inputs, targets in tqdm(loader, "Accumulating"):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs_unet = unet_model(inputs)
            outputs_gunet = gunet_model(inputs)

            batch_unet = _shifted_spectra_sum(outputs_unet)
            batch_gunet = _shifted_spectra_sum(outputs_gunet)
            batch_tar = _shifted_spectra_sum(targets)

            if acc_tar is None:
                unet_acc_out, gunet_acc_out, acc_tar = batch_unet, batch_gunet, batch_tar
            else:
                unet_acc_out += batch_unet
                gunet_acc_out += batch_gunet
                acc_tar += batch_tar

            count += outputs_unet.shape[0] * outputs_unet.shape[1]

    if acc_tar is None or count == 0:
        raise ValueError("loader yielded no frames; nothing to average")
    return acc_tar / count, unet_acc_out / count, gunet_acc_out / count


In [None]:
# Mean centred spectra of targets, UNet and GUNet over the test split.
tar_fourier, unet_fourier, gunet_fourier = accumulate_fourier(test_loader)

In [None]:
def moving_average(arr, window_size=32):
    """Local mean of `arr` over a square window (scipy convolve2d, mode='same', zero-padded edges).

    NOTE(review): third identical definition of moving_average in this notebook.
    """
    weights = np.ones((window_size, window_size))
    weights /= window_size**2
    return scipy.signal.convolve2d(arr, weights, mode='same')

fig, axs = plt.subplots(1, 2, figsize=(8, 4), dpi=200, subplot_kw={'projection': '3d'})

r = range(0, gunet_fourier.shape[1])
p = range(0, gunet_fourier.shape[0])
X, Y = np.meshgrid(r, p)

# Smoothed log-magnitude of the mean spectra; here z2 = UNet and z3 = GUNet.
z1, z2, z3 = (moving_average(np.log(np.abs(spec.cpu())))
              for spec in (tar_fourier, unet_fourier, gunet_fourier))
vmin = 0.9
vmax = 1.8
print(vmin, vmax)

# |model - target| surfaces; the first one is kept as the colour-bar mappable.
surfaces = []
for ax, z in zip(axs, (z2, z3)):
    surfaces.append(ax.plot_surface(X, Y, abs(z - z1), cmap='RdYlGn_r', vmin=vmin, vmax=vmax, cstride=4, rstride=4))
s = surfaces[0]

transparent = (1.0, 1.0, 1.0, 0.0)
light_grey = (0.8, 0.8, 0.8, 1)
for ax, title in zip(axs, ["UNet", "GUNet"]):
    ax.set_proj_type('ortho')
    ax.set_zlim(0, vmax + 0.5)
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_pane_color(transparent)
        # Relies on Matplotlib's private `_axinfo` to recolour the 3-D grid lines.
        axis._axinfo["grid"]['color'] = light_grey
    ax.set_title(r"$\bf{" + f"{title}" + "}$")

fig.tight_layout(pad=0)
fig.colorbar(s, ax=axs, shrink=0.5, location='bottom', orientation='horizontal', pad=0.05)

In [None]:
# Heat maps of |z - z1|. z2 = UNet and z3 = GUNet, as set in the Fourier surface cell above.
fig, axs = plt.subplots(1, 2, figsize=(8, 4), dpi=200)
diff_image = axs[0].imshow(abs(z2 - z1), vmin=vmin, vmax=vmax, cmap='RdYlGn_r')
axs[1].imshow(abs(z3 - z1), vmin=vmin, vmax=vmax, cmap='RdYlGn_r')

for ax, title in zip(axs, ["UNet", "GUNet"]):
    ax.set_title(r"$\bf{" + f"{title}" + "}$")
    ax.axis('off')

fig.tight_layout(pad=1)
# Take the colour bar from this figure's own image instead of `s` (the 3-D
# surface created in another cell), so the cell does not depend on that one.
fig.colorbar(diff_image, ax=axs, shrink=0.5, location='bottom', orientation='horizontal', pad=0.05)

In [None]:
# Smoothed (8-px window) log-magnitude of the mean spectra, as flat heat maps on
# a common colour scale, with a shared colour bar.
vmin = 9.5
vmax = 16

# pad_x / pad_y are only referenced by the commented-out tick settings below.
pad_x = 75
pad_y = 65
window = 8

z1, z2, z3 = (moving_average(np.log(np.abs(spec.cpu())), window)
              for spec in (tar_fourier, unet_fourier, gunet_fourier))

fig, axs = plt.subplots(1, 3, figsize=(9, 3), dpi=200)
_heatmaps = [ax.imshow(z, vmin=vmin, vmax=vmax, cmap='turbo') for ax, z in zip(axs, (z1, z2, z3))]
s = _heatmaps[0]

for ax, title in zip(axs, ["Targets", "UNet", "GUNet"]):
    ax.set_title(r"$\bf{" + f"{title}" + "}$")
    ax.axis('off')
#     ax.set_xticks([544/2-pad_x,544/2+pad_x])
#     ax.set_yticks([352/2-pad_y,352/2+pad_y])

fig.tight_layout(pad=0.4)
fig.colorbar(s, ax=axs, shrink=0.5, location='bottom', orientation='horizontal', pad=0.05)



In [None]:
from scipy.stats import ttest_ind
# Welch's t-test (equal_var=False) between the flattened arr0 and arr1.
# NOTE(review): `arr0` and `arr1` are not defined anywhere in this notebook, so this
# cell raises NameError as written. Define them first (presumably two spectra or
# difference maps to compare — TODO confirm).
ttest_ind(arr0.flatten(),arr1.flatten(),equal_var=False)

In [None]:
# Quick check: GUNet SSIM on one random test batch.
ssim_loss = pytorch_ssim.SSIM()
# Separate name, so the global `test_loader` DataLoader (used by the accumulate*
# cells) is not replaced by a one-off iterator.
sample_batches = iter(DataLoader(dataset, batch_size=gunet_config['batch_size'], shuffle=True, num_workers=4,))
model = gunet_model
with torch.no_grad():
    inputs, targets = next(sample_batches)
    inputs, targets = inputs.to(device), targets.to(device)
    outputs = model(inputs)

    print(ssim_loss(outputs, targets).item())


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from mpl_toolkits.axes_grid1 import Grid
import matplotlib.dates as mdates
from matplotlib.ticker import MaxNLocator

In [None]:
# Training log of the UNet run (epoch, time and train/val loss columns are used below).
unet_training = pd.read_csv('./run/unet_without_clr/training.csv')
unet_training.head()

In [None]:
# Training log of the GUNet run (same columns as the UNet log).
gunet_training = pd.read_csv('./run/gunet_without_clr/training.csv')
gunet_training.head()

In [None]:
# Best validation score of each run: highest SSIM, lowest MSE and MAE. Each
# line is prefixed with the model name; the two lines were previously unlabelled.
for _name, _log in (('UNet', unet_training), ('GUNet', gunet_training)):
    print(f'{_name} - SSIM: {_log["val_ssim_loss"].max():.4f}, MSE: {_log["val_mse_loss"].min():.7f}, MAE: {_log["val_mae_loss"].min():.5f}')

In [None]:
fig, ax = plt.subplots(1, 1, dpi=200, figsize=(3, 3))

import datetime as dt


def _elapsed_seconds(log):
    """Seconds since the first row, parsed from the 'time' column ('%Y-%m-%d %H:%M:%S')."""
    stamps = [dt.datetime.strptime(stamp, '%Y-%m-%d %H:%M:%S') for stamp in log['time']]
    return [(stamp - stamps[0]).total_seconds() for stamp in stamps]


# Wall-clock time since the start of each run, stored as a new 'delta' column.
unet_training['delta'] = _elapsed_seconds(unet_training)
gunet_training['delta'] = _elapsed_seconds(gunet_training)

for _log, _colour, _label in ((unet_training, "#268BD2", 'Unet'),
                              (gunet_training, "#D1495B", 'GUnet')):
    ax.plot(_log['delta'], _log['val_ssim_loss'], c=_colour, label=_label)

ax.set_xlabel('Duration (hours)')
ax.set_ylabel('SSIM')

# Ticks every half day from 0 to 3 days, labelled in whole hours.
ax.set_xticks([days * 86400 for days in np.arange(0, 3.5, 0.5)])
ax.set_xticklabels([f"{round(seconds // 3600)}" for seconds in ax.get_xticks()])
ax.set_ylim(0.825, 0.873)

ax.legend(framealpha=1)
ax.grid(True, c='0.8')
ax.set_frame_on(False)

In [None]:
fig, ax = plt.subplots(1, 1, dpi=200, figsize=(3, 3))

import datetime as dt

# Validation SSIM per (1-based) epoch for both runs.
for _log, _colour, _label in ((unet_training, "#268BD2", 'Unet'),
                              (gunet_training, "#D1495B", 'GUnet')):
    ax.plot(_log['epoch'] + 1, _log['val_ssim_loss'], c=_colour, label=_label)

ax.set_xlabel('Epoch')
ax.set_ylabel('SSIM')
ax.set_ylim(0.825, 0.875)

ax.legend(framealpha=1)
ax.grid(True, c='0.8')
ax.set_frame_on(False)
# ax.set_yscale('log')

In [None]:
from matplotlib.ticker import NullLocator, NullFormatter

fig, ax = plt.subplots(1, 1, dpi=200, figsize=(3, 3))

# Validation MAE per (1-based) epoch for both runs.
for _log, _colour, _label in ((unet_training, "#268BD2", 'Unet validation'),
                              (gunet_training, "#D1495B", 'GUNet validation')):
    ax.plot(_log['epoch'] + 1, _log['val_mae_loss'], c=_colour, label=_label)
ax.set_xlabel('Epoch')
ax.set_ylabel('MAE')

# Log y axis with 5 log-spaced major ticks between y_min and y_max, printed as
# plain decimals, with no minor ticks.
ax.set_yscale('log')

y_min, y_max = 0.0125, 0.0185
num_ticks = 10  # NOTE(review): not used; 5 ticks are drawn.
y_ticks = np.round(np.logspace(np.log10(y_min), np.log10(y_max), 5), 5)
ax.set_yticks(y_ticks)

ax.yaxis.set_major_formatter(plt.ScalarFormatter())
ax.yaxis.set_minor_locator(NullLocator())
ax.yaxis.set_minor_formatter(NullFormatter())

ax.set_ylim(y_min, y_max)
ax.legend(loc='upper right', framealpha=1)

ax.grid(True, c='0.8')
ax.set_frame_on(False)

In [None]:
from matplotlib.ticker import NullLocator, NullFormatter

fig, ax = plt.subplots(1, 1, dpi=200, figsize=(3, 3))

# Dotted = training MSE, solid = validation MSE; the colour identifies the model.
ax.plot(unet_training['epoch']+1, unet_training['train_mse_loss'], c="#268BD2", linestyle=':', label='UNet training')
ax.plot(unet_training['epoch']+1, unet_training['val_mse_loss'], c="#268BD2", label='Unet validation')
ax.plot(gunet_training['epoch']+1, gunet_training['train_mse_loss'], c="#D1495B", linestyle=':', label='GUNet training')
ax.plot(gunet_training['epoch']+1, gunet_training['val_mse_loss'], c="#D1495B", label='GUNet validation')
ax.set_xlabel('Epoch')
ax.set_ylabel('MSE')

ax.set_yscale('log')

# The y range and tick count are defined once here and used for the ticks and the
# limits (previously y_max/num_ticks were declared but literals were used instead).
y_min, y_max = 0.0017, 0.003
num_ticks = 5
y_ticks = np.round(np.logspace(np.log10(y_min), np.log10(y_max), num_ticks), 5)
ax.set_yticks(y_ticks)

# Plain decimal major labels, no minor ticks.
ax.yaxis.set_major_formatter(plt.ScalarFormatter())
ax.yaxis.set_minor_locator(NullLocator())
ax.yaxis.set_minor_formatter(NullFormatter())

ax.set_ylim(y_min, y_max)
ax.legend(loc='upper right', framealpha=1)

ax.grid(True, c='0.8')
ax.set_frame_on(False)

In [None]:
# Turn test_loader into an iterator so batches can be pulled with next().
# NOTE(review): this replaces the global DataLoader. If the accumulate* cells above
# are re-run afterwards, they iterate an iterator that may already be exhausted.
# The next cell rebinds test_loader anyway.
test_loader = iter(test_loader)

In [None]:
# Per-sample, per-output-step metrics on one random test batch, plus orderings of
# the batch by target max, target mean and GUNet SSIM (used to pick examples).
batch_size = 50
# Local iterator name, so the global `test_loader` is not clobbered.
sample_batches = iter(DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4,))

mse_loss = nn.MSELoss()
ssim_loss = pytorch_ssim.SSIM()
mae_loss = nn.L1Loss()
losses = [mae_loss, mse_loss, ssim_loss]  # order of the last axis of *_losses

with torch.no_grad():
    inputs, targets = next(sample_batches)
    inputs, targets = inputs.to(device), targets.to(device)
    gunet_outputs = gunet_model(inputs)
    unet_outputs = unet_model(inputs)

    # Sized from the batch actually drawn: if the dataset holds fewer than
    # batch_size samples, the batch is smaller and fixed-size arrays would keep
    # rows of zeros.
    n_samples, n_steps = targets.shape[0], targets.shape[1]
    gunet_losses = np.zeros((n_samples, n_steps, len(losses)))
    unet_losses = np.zeros((n_samples, n_steps, len(losses)))

    for i in range(n_samples):
        for j in range(n_steps):
            # Each slice as a (1, 1, H, W) tensor, so every metric sees a 4-D input.
            target_frame = targets[i, j][None, None]
            gunet_frame = gunet_outputs[i, j][None, None]
            unet_frame = unet_outputs[i, j][None, None]
            for k, loss_fn in enumerate(losses):
                gunet_losses[i, j, k] = loss_fn(gunet_frame, target_frame).item()
                unet_losses[i, j, k] = loss_fn(unet_frame, target_frame).item()

    # Per-sample peak and mean target value, and GUNet SSIM averaged over steps.
    max_vals = torch.amax(targets, dim=(1, 2, 3))
    mean_vals = targets.mean(dim=(1, 2, 3))
    ssim_vals = torch.mean(torch.tensor(gunet_losses[:, :, 2]), dim=1)
    # Sample indices from highest to lowest value.
    sorted_max = torch.argsort(max_vals, descending=True)
    sorted_mean = torch.argsort(mean_vals, descending=True)
    sorted_ssim = torch.argsort(ssim_vals, descending=True)
iterator_max = iter(sorted_max)
iterator_mean = iter(sorted_mean)
print(ssim_vals, sorted_ssim)


In [None]:
# Find the sample where GUNet's metric exceeds UNet's by the largest margin.
# loss_i indexes the last axis of *_losses: 0 = MAE, 1 = MSE, 2 = SSIM.
loss_i = 2
diff = gunet_losses[:, :, loss_i] - unet_losses[:, :, loss_i]
# np.argmax returns the first maximum in row-major (sample, step) order, the same
# element a strict `>` scan finds. Bounds come from the array itself rather than
# from batch_size / the config. (NaNs, if any, propagate here, whereas a `>` scan
# would skip them.)
max_i, _ = np.unravel_index(np.argmax(diff), diff.shape)
max_i = int(max_i)
max_diff = diff.max()
print(max_i, max_diff)
max_j = max_i

In [None]:
# Fresh iterators over the orderings; the mean and SSIM ones skip their first 10 / 20 entries.
iterator_max, iterator_mean, iterator_ssim = map(iter, (sorted_max, sorted_mean[10:], sorted_ssim[20:]))

In [None]:
# Advance to the next sample index in the SSIM ordering and show it.
max_j = int(next(iterator_ssim))
print(max_j)

In [None]:
# Sample indices of the batch, ordered by descending mean target value.
sorted_mean

In [None]:
from mpl_toolkits.axes_grid1 import ImageGrid
from mpl_toolkits.basemap import Basemap
import matplotlib.colors as colors
mpl.rcParams.update(mpl.rcParamsDefault)

# Map extent: (lower-left lon, lower-left lat, upper-right lon, upper-right lat).
coor = (11.8, 48.07, 19.41, 51.57)

# Indices along dim 1 of targets/outputs to plot, one figure row each
# (presumably lead times — TODO confirm).
idx = [1, 3, 5, 7]
# Hard-coded frame height/width ratio (352 / 544).
ratio = 352 / 544

llcrnrlon, llcrnrlat, urcrnrlon, urcrnrlat = coor
# Centre of the extent, used as the projection's lat_1 / lon_0.
lat_1 = (llcrnrlat + urcrnrlat) / 2.0
lon_0 = (llcrnrlon + urcrnrlon) / 2.0

# Albers equal-area ('aea') Basemap over the extent, intermediate ('i') resolution.
m = Basemap(llcrnrlon=llcrnrlon, llcrnrlat=llcrnrlat,
            urcrnrlon=urcrnrlon, urcrnrlat=urcrnrlat,
            resolution='i', projection='aea', lat_1=lat_1, lon_0=lon_0)
fig, axs = plt.subplots(len(idx), 3, dpi=300, figsize=(3 * 3, 3 * len(idx)),
                        height_ratios=[ratio] * len(idx))

cmap = 'turbo'

# Colour range of the map images; modified_sigmoid also reads these as its ramp bounds.
low, high = 0.3, 0.6

def modified_sigmoid(x):
#     return torch.where(x > a, 1, x/a)
    return torch.where(x > high, 1, 
                       torch.where(x < low, 0.05,  
                                   torch.tanh((4*torch.tensor(math.pi) / (high-low)) * (x-low))))
#     return torch.where(x > a, 1, 1 / (1 + np.exp(-k * (x - a/2))))
#     return 1 / (1 + np.exp(-k * (x - a)))

for i, image_idx in enumerate(idx):
#     m.arcgisimage(ax=axs[i,0], service='Canvas/World_Light_Gray_Base',xpixels=1000)
#     m.arcgisimage(ax=axs[i,1], service='Canvas/World_Light_Gray_Base',xpixels=1000)
#     m.arcgisimage(ax=axs[i,2], service='Canvas/World_Light_Gray_Base',xpixels=1000)
    m.drawcountries(linewidth=.5, ax=axs[i,0],zorder=-10,color='gray')
    m.drawcountries(linewidth=.5, ax=axs[i,1],zorder=-10,color='gray')
    m.drawcountries(linewidth=.5, ax=axs[i,2],zorder=-10,color='gray')
    axs[i,1].set_xlabel(r"$\bf{SSIM}: $" + f"{gunet_losses[max_j,i,2]:.3f}, " + 
                        r"$\bf{MAE}: $" + f"{gunet_losses[max_j,i,0]:.4f}, " +
                        r"$\bf{MSE}: $" + f"{gunet_losses[max_j,i,1]:.4f}",labelpad=0)
    axs[i,2].set_xlabel(r"$\bf{SSIM}: $" + f"{unet_losses[max_j,i,2]:.3f}, " + 
                        r"$\bf{MAE}: $" + f"{unet_losses[max_j,i,0]:.4f}, " +
                        r"$\bf{MSE}: $" + f"{unet_losses[max_j,i,1]:.4f}",labelpad=0)
    im = m.imshow(targets[max_j,image_idx].cpu(),
             cmap=cmap, 
             vmin=low,vmax=high,
#              vmin=torch.min(targets).item(),vmax=torch.max(targets).item(),
             ax=axs[i,0], 
             origin='upper',
             alpha=modified_sigmoid(targets[max_j,image_idx].cpu()))
    
    m.imshow(gunet_outputs[max_j,image_idx].cpu(),
             cmap=cmap, 
             vmin=low,vmax=high,
#              vmin=torch.min(gunet_outputs).item(),vmax=torch.max(gunet_outputs).item(),
             ax=axs[i,1],
             origin='upper',
             alpha=modified_sigmoid(gunet_outputs[max_j,image_idx].cpu()),)
    
    m.imshow(unet_outputs[max_j,image_idx].cpu(),
             cmap=cmap,
             vmin=low, vmax=high,
#              vmin=torch.min(unet_outputs).item(),vmax=torch.max(unet_outputs).item(),
             ax=axs[i,2],
             origin='upper', 
             alpha=modified_sigmoid(unet_outputs[max_j,image_idx ].cpu()))
    
for axs2 in axs:
    for ax in axs2:
        ax.set_frame_on(False)
        
fig.tight_layout(pad=0.2)
cbar = fig.colorbar(im, ax=axs[:,:],shrink=0.5,location='bottom', orientation='horizontal',pad=0.03, aspect=15)
cbar.ax.tick_params(labelsize=14)

for ax, idx in zip(axs[:,0], idx):
    ax.set_ylabel(r"T +$\bf{"+f"{(idx+1)*10}"+"}$ min", fontsize=16,labelpad=0)
    
for ax, title in zip(axs[0], ["Observation", "GUNet", "UNet"]):
    ax.set_title(r"$\bf{"+f"{title}"+"}$", fontsize=16)


In [None]:
# Sample shown in the left panel (alternatively: max_j = next(iterator_mean).item()).
max_j = 47
print(max_j)


ratio = 352/544

fig,axs = plt.subplots(1,2,dpi=200,figsize=(2*10*ratio,10))
llcrnrlon, llcrnrlat, urcrnrlon, urcrnrlat = coor
lat_1 = (llcrnrlat + urcrnrlat) / 2.0
lon_0 = (llcrnrlon + urcrnrlon) / 2.0

# create the Basemap object
m = Basemap(llcrnrlon=llcrnrlon, llcrnrlat=llcrnrlat,
            urcrnrlon=urcrnrlon, urcrnrlat=urcrnrlat,
            resolution='i', projection='aea', lat_1=lat_1, lon_0=lon_0)

# Step-0 target frames: the selected sample on the left, sample 21 on the right.
sample_ids = (max_j, 21)

# Full colour range for this overview; modified_sigmoid reads these globals at call time.
low=0
high=1

for ax, sample in zip(axs, sample_ids):
    # Reference markers at fixed (lon, lat) -- presumably the radar sites; TODO confirm.
    m.scatter(13.8178, 49.6583, 25, ax=ax, latlon=True, marker='x', color='Black',linewidths=1)
    m.scatter(16.7885, 49.5011, 25, ax=ax, latlon=True, marker='+', color='Black',linewidths=1)
    m.drawcountries(linewidth=.5, ax=ax,color='gray')
    frame = targets[sample,0].cpu()
    im = m.imshow(frame,
                  cmap=cmap,
                  vmin=low,vmax=high,
                  ax=ax,
                  origin='upper',
                  alpha=modified_sigmoid(frame))
    ax.set_frame_on(False)

fig.tight_layout(pad=0.3)
# `im` is the right-hand image; both panels share vmin/vmax.
fig.colorbar(im, ax=axs,shrink=0.3,location='bottom', orientation='horizontal',pad=0.02)

In [None]:
def find_distinct_values(tensor, epsilon):
    """Collapse nearly-equal values into representative distinct values.

    The unique values of ``tensor`` are taken in ascending order; a new group
    starts whenever the gap to the previous unique value is >= ``epsilon``,
    and the first (smallest) value of each group is returned.

    Args:
        tensor: torch tensor (any device/shape, gradients allowed) or any
            array-like of numbers; it is flattened.
        epsilon: minimum gap between consecutive unique values that starts
            a new group.

    Returns:
        1-D numpy array of group representatives in ascending order; empty
        when ``tensor`` has no elements.
    """
    if torch.is_tensor(tensor):
        values = tensor.detach().cpu().numpy()
    else:
        values = np.asarray(tensor)

    # np.unique returns the flattened unique values already sorted.
    unique_values = np.unique(values)
    if unique_values.size == 0:
        return unique_values

    gaps = np.diff(unique_values)
    # Unique value k+1 opens a new group when the gap after value k is large enough.
    group_starts = np.flatnonzero(gaps >= epsilon) + 1
    return unique_values[np.concatenate(([0], group_starts))]

# Distinct values present in `inputs` (gap threshold 1e-7), scaled by 60
# (presumably back to physical units -- TODO confirm).
distinct = 60 * find_distinct_values(inputs.flatten(), 1e-7)

In [None]:
# Display the scaled distinct input values.
distinct

In [None]:
# Per-frame summary statistics over the test set: one mean / median / max per
# (sample, channel) pair of each batch, with the remaining dimensions flattened.
means = []
medians = []
max_values = []

# Loop variable is `batch` so the global `tensor` (imported from torch at the top) is not rebound.
for batch, _ in tqdm(test_loader, desc="Computing per-frame statistics"):
    # reshape (unlike view) also works on non-contiguous batches.
    frames = batch.reshape(batch.shape[0], batch.shape[1], -1)
    means.extend(frames.mean(dim=-1).flatten().tolist())
    # torch.median returns the lower of the two middle values for even-sized inputs.
    medians.extend(frames.median(dim=-1).values.flatten().tolist())
    max_values.extend(frames.max(dim=-1).values.flatten().tolist())

In [None]:
# Histogram of all training-input pixel values over [0, 1].
num_bins=16
# float64 accumulator: float32 can no longer represent every integer count above 2**24.
combined_histogram = torch.zeros(num_bins, dtype=torch.float64)
# Distinct loop names so the global `inputs` / `targets` used by the cells above are not overwritten.
for batch_inputs, _ in tqdm(train_loader, desc="Computing histograms"):
    # torch.histc ignores values outside [min, max].
    histogram = torch.histc(batch_inputs, bins=num_bins, min=0, max=1)
    combined_histogram += histogram.cpu().double()

In [None]:
fig,ax = plt.subplots( figsize=(3,2),dpi=200)

# One offset "ridge" histogram per per-frame statistic, shifted diagonally back-to-front.
df = pd.DataFrame({"median":medians,"mean":means,"max":max_values})
columns = df.columns
x_step = 0.1   # horizontal offset between consecutive ridges
y_step = 0.05  # vertical offset between consecutive ridges

# Named so it does not shadow the `matplotlib.colors` module imported earlier.
ridge_colors = ["#268BD2","#EDAE49","#D1495B"]

polygons = []

# Heights are the fraction of all frames in `df` falling in each bin. `df` holds the
# test-set statistics, so normalise by its own length, not the training-set size.
n_frames = len(df)
# `m_max` (training-frame count assuming batch size 8) is no longer used here but is
# still read by a later cell, so it keeps its previous value.
m_max = len(train_loader)*8

num_bins = 16
for i,column in enumerate(columns):
    vals , bins = np.histogram(df[column],bins=num_bins)
    # NOTE(review): edges are rescaled to [0, 1] per column, so the x axis shows each
    # statistic's own min..max range rather than absolute pixel value.
    bins = (bins - np.min(bins))/(np.max(bins)-np.min(bins))
    bin_centers = (bins[:-1] + bins[1:]) / 2
    # Outline: baseline at the first centre, the histogram heights, baseline at the last centre.
    x = np.concatenate(([bin_centers[0]], bin_centers, [bin_centers[-1]]))
    y = np.concatenate(([0.0], vals / n_frames, [0.0]))
    polygon = Polygon(np.c_[x+i*x_step, y+i*y_step], fc=ridge_colors[i], ec='0', lw=1,closed=False, zorder=-i, alpha=1)
    polygons.append(polygon)
    ax.add_patch(polygon)

ax.axis('tight')
ax.set_frame_on(False)

ax.set_yticks(np.linspace(0,1,6))

ax.set_xticks(np.linspace(0,1,6))
ax.set_xticklabels(np.around(np.linspace(0,1,6),1))

ax.set_xlabel('Pixel Value')
ax.set_ylabel('Fraction of images')
ax.legend(polygons,columns,loc='upper right',framealpha=1)

# Light guide lines from each tick along the ridge offset direction.
depth_x = (len(columns)-1)*x_step
depth_y = (len(columns)-1)*y_step
for t in ax.get_yticks():
    ax.plot([0, depth_x], [t, t + depth_y], '0.8',zorder=-10,lw=1)

for t in ax.get_xticks():
    ax.plot([t, t + depth_x], [0, depth_y], '0.8',zorder=-10,lw=1)

fig.tight_layout(pad=0)


In [None]:
# Display the [0, 1]-rescaled bin edges left by the last column of the ridge-plot loop above.
bins

In [None]:
fig, axs = plt.subplots(1,1,dpi=200,figsize=(3,2))

# Centres of the num_bins equal-width bins over [0, 1] used by torch.histc.
bin_edges = torch.linspace(0, 1, num_bins+1)
pixel_bin_centers = ((bin_edges[:-1] + bin_edges[1:]) / 2).numpy()
# Fraction of in-range training pixels per bin. Normalising by the histogram's own
# total avoids assuming a batch size, channel count or frame size.
pixel_fraction = (combined_histogram / combined_histogram.sum()).numpy()
axs.bar(pixel_bin_centers, pixel_fraction, width=1 / num_bins, edgecolor='black',color="#268BD2")
axs.set_xlabel('Pixel Value')
axs.set_ylabel('Fraction of pixels')

axs.set_yticks(np.linspace(0,1,6))

# Hide the spines of this chart's own axes (not the previous figure's `ax`).
axs.spines['top'].set_visible(False)
axs.spines['right'].set_visible(False)

fig.tight_layout(pad=0)


In [None]:
fig,ax = plt.subplots( figsize=(5,4),dpi=200)

# One offset "ridge" histogram per per-frame statistic, shifted diagonally back-to-front.
df = pd.DataFrame({"median":medians,"max":max_values,"mean":means})
columns = df.columns
x_step = 0.1   # horizontal offset between consecutive ridges
y_step = 0.05  # vertical offset between consecutive ridges

# median blue, max red, mean yellow; named so it does not shadow `matplotlib.colors`.
ridge_colors = ["#268BD2","#D1495B","#EDAE49"]

polygons = []

# Heights are the fraction of all frames in `df` per bin, matching the y-axis label.
n_frames = len(df)

num_bins = 16
for i,column in enumerate(columns):
    vals , bins = np.histogram(df[column],bins=num_bins)
    # NOTE(review): edges are rescaled to [0, 1] per column, so the x axis shows each
    # statistic's own min..max range rather than absolute pixel value.
    bins = (bins - np.min(bins))/(np.max(bins)-np.min(bins))
    bin_centers = (bins[:-1] + bins[1:]) / 2
    # Outline: baseline at the first centre, the histogram heights, baseline at the last centre.
    x = np.concatenate(([bin_centers[0]], bin_centers, [bin_centers[-1]]))
    y = np.concatenate(([0.0], vals / n_frames, [0.0]))
    polygon = Polygon(np.c_[x+i*x_step, y+i*y_step], fc=ridge_colors[i], ec='0', lw=1,closed=False, zorder=-i, alpha=1)
    polygons.append(polygon)
    ax.add_patch(polygon)

ax.axis('tight')
ax.set_frame_on(False)

ax.set_yticks(np.linspace(0,1,6))

ax.set_xticks(np.linspace(0,1,6))
ax.set_xticklabels(np.around(np.linspace(0,1,6),1))

ax.set_xlabel('Pixel Value')
ax.set_ylabel('Fraction of images in the dataset')
ax.legend(polygons,columns,loc='right')

# Light guide lines from each tick along the ridge offset direction.
depth_x = (len(columns)-1)*x_step
depth_y = (len(columns)-1)*y_step
for t in ax.get_yticks():
    ax.plot([0, depth_x], [t, t + depth_y], '0.8',zorder=-10,lw=0.5)

for t in ax.get_xticks():
    ax.plot([t, t + depth_x], [0, depth_y], '0.8',zorder=-10,lw=0.5)

fig.tight_layout(pad=0)