In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import time
import json
import torch
import warnings
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from datetime import datetime
from tqdm.notebook import tqdm
from collections import OrderedDict, defaultdict
from toolz.curried import pipe, curry, compose
warnings.filterwarnings('ignore')

In [None]:
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader

In [None]:
import chnet.ch_tools as ch_tools
from chnet.ch_losses import *
import chnet.utilities as ch_utils
import chnet.ch_generator as ch_gen
from chnet.torchsummary import summary
from chnet.ch_loader import CahnHillDataset

In [None]:
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn as sns

# Notebook-wide matplotlib defaults: figure geometry/resolution, then font sizing.
mpl.rcParams.update({
    'figure.figsize': [8.0, 6.0],
    'figure.dpi': 80,
    'savefig.dpi': 100,
    'font.size': 12,
    'legend.fontsize': 'large',
    'figure.titlesize': 'medium',
})

# Set Device

In [None]:
# Use the default CUDA device when torch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# CNN Model

In [None]:
# summary(unet, input_size=(1, 128, 128))
# summary(unet_solo_resloop, input_size=(1, 128, 128))
# summary(unet_loop, input_size=(3, 1, 128, 128))
# summary(unet_res, input_size=(1, 128, 128))
# summary(unet_resloop, input_size=(3, 1, 128, 128))

### Training and Validation data generation

In [None]:
## Process Parameters
# Simulation / data-generation settings forwarded to ch_gen.data_generator below.
mid = 0.0
dif = 0.449
dim_x = 96
init_steps = 1
nstep = 2
# Used as delta_sim_steps by the first validation cell; it was referenced there but never
# defined, so a fresh-kernel run raised NameError. 250 matches the value used by the
# single-model evaluation cell (with n_step=2).
delta_sim_steps = 250
dx = 0.25  # not from paper
dt = 0.01  # from paper
gamma = 0.2  # from paper
# Bounds passed as m_l / m_r to data_generator (presumably the range of mean
# concentrations sampled — TODO confirm). The original lines ended with a trailing comma,
# which silently turned both values into 1-tuples instead of floats.
m_l = mid - dif
m_r = mid + dif
seed_trn = 110364  # NOTE(review): not referenced by any visible cell
n_samples_val = 512

def mae_loss_npy(x1, x2):
    """Mean absolute error between two numpy arrays (or numpy scalars).

    Computes |x1 - x2| elementwise and averages over every element.

    Returns:
        numpy scalar holding the mean absolute difference.
    """
    abs_err = np.fabs(x1 - x2)
    return abs_err.mean()

def maerr(x1, x2):
    """Elementwise absolute error |x1 - x2|."""
    return np.fabs(x1 - x2)


def diff(x1, x2):
    """Elementwise natural log of the absolute error (-inf where x1 == x2)."""
    return np.log(maerr(x1, x2))

In [None]:
%%time
# Generate a fixed-seed validation set from the process parameters defined above.
x_val, y_val = ch_gen.data_generator(nsamples=n_samples_val,
                                     dim_x=dim_x,
                                     init_steps=init_steps,
                                     delta_sim_steps=delta_sim_steps,
                                     dx=dx,
                                     dt=dt,
                                     m_l=m_l,
                                     m_r=m_r,
                                     n_step=nstep,
                                     gamma=gamma,
                                     seed=38921641,
                                     device=device)


# Items are dicts with 'x' / 'y'; the transforms insert a singleton axis at dim 1 (x[:, None]).
val_dataset = CahnHillDataset(x_val, y_val, transform_x=lambda x: x[:,None], transform_y=lambda x: x[:,None])
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True, num_workers=1)

total_val_step = len(val_loader)
print("No. of validation steps: %d" % total_val_step)

# Preview a few samples: first input frame vs. last target frame. These are validation
# samples, so they are labelled as such (they were previously titled "training").
n_preview = min(3, len(x_val))
for ix in range(n_preview):
    ch_utils.draw_by_side([x_val[ix][0], y_val[ix][-1]],
                          sub_titles=["validation input", "validation output"],
                          title="mean: {:1.3f}".format(np.mean(x_val[ix])),
                          vmax=None,
                          vmin=None)

In [None]:
from chnet.models import UNet, UNet_solo_loop, UNet_loop, mse_loss
# obj = torch.load("weights/model_unet_size_64_step_2_init_1_delta_2500_tstep_250.pt")
# model = UNet(in_channels=1, out_channels=1, init_features=64).double().to(device)
# obj = torch.load("weights/model_unet_loop_size_64_step_4_init_1_delta_1250_tstep_250.pt")
# model = UNet_loop(in_channels=1, out_channels=1, init_features=64, temporal=4).double().to(device)
# obj = torch.load("weights/model_unet_size_32_step_5_init_1_delta_1000.pt")
# model = UNet(in_channels=1, out_channels=1, init_features=32).double().to(device)
# obj = torch.load("weights/model_unet_loop_size_16_step_5_init_1_delta_1000.pt")
# model = UNet_loop(in_channels=1, out_channels=1, init_features=16, temporal=nstep).double().to(device)
# obj = torch.load("weights/model_unet_loop_size_32_step_5_init_1_delta_1000.pt")
# model = UNet_loop(in_channels=1, out_channels=1, init_features=32, temporal=nstep).double().to(device)
# obj = torch.load("weights/model_unet_loop_size_32_step_10_init_1_delta_500_tstep_250.pt")
# model = UNet_loop(in_channels=1, out_channels=1, init_features=32, temporal=10).double().to(device)
# obj = torch.load("weights/model_unet_solo_loop_size_16_step_5_init_1_delta_1000.pt")
# model = UNet_solo_loop(in_channels=1, out_channels=1, init_features=16, temporal=nstep).double().to(device)
# Every candidate checkpoint/model pair above is commented out, so on a fresh kernel neither
# `obj` nor `model` exists and an unconditional call raises NameError. Load the weights only
# once a pair has been uncommented (or both names already exist in the session).
if "obj" in globals() and "model" in globals():
    model.load_state_dict(obj["state"])

In [None]:
# Load the 64-feature plain UNet checkpoint; the file holds the model weights under "state".
# map_location remaps tensors onto `device`, so a checkpoint saved on a GPU still loads
# when `device` fell back to the CPU.
obj = torch.load("weights/model_unet_size_64_step_5_init_1_delta_100_tstep_250.pt", map_location=device)
model = UNet(in_channels=1, out_channels=1, init_features=64).double().to(device)
model.load_state_dict(obj["state"])

In [None]:
torch.cuda.empty_cache()
# Fixed-seed validation set: 2 frames of 250 simulation steps each, starting from step 1.
x_val, y_val = ch_gen.data_generator(nsamples=n_samples_val,
                                     dim_x=dim_x,
                                     init_steps=1,
                                     delta_sim_steps=250,
                                     dx=dx,
                                     dt=dt,
                                     m_l=m_l,
                                     m_r=m_r,
                                     n_step=2,
                                     gamma=gamma,
                                     seed=38921641,
                                     device=device)


val_dataset = CahnHillDataset(x_val, y_val, transform_x=lambda x: x[:,None], transform_y=lambda x: x[:,None])
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True, num_workers=1)
torch.cuda.empty_cache()

# Loop models return a stack of outputs; only the last one is compared. This used to test
# `"loop" in key`, but `key` is only defined by a later cell (NameError on a fresh kernel).
# In the key scheme only "uloop-*" keys contain "loop", i.e. UNet_loop and not UNet_solo_loop.
is_loop = isinstance(model, UNet_loop) and not isinstance(model, UNet_solo_loop)
n_plot = 5  # samples to print/draw; the MAE is still computed over the whole validation set

model.eval()
errs = []
with torch.no_grad():  # inference only: don't build autograd graphs for every sample
    for ix in range(len(val_dataset)):
        item_v = val_dataset[ix]

        x = item_v['x'][None][:,0].double().to(device)       # first frame as input
        y_tru = item_v['y'][None][:,-1].double().to(device)  # last frame as target

        y_prd = model(x)[:,-1] if is_loop else model(x)

        im_y1 = y_tru[0,0].detach().cpu().numpy()
        im_y2 = y_prd[0,0].detach().cpu().numpy()
        errs.append(mae_loss_npy(im_y1, im_y2))
        # Drawing all samples would create hundreds of figures; show only the first few.
        if ix < n_plot:
            print("mean conc. : {}, mae: {:1.5f}".format(np.mean(im_y1), errs[-1]))
            ch_utils.draw_by_side([im_y1, im_y2],
                                  sub_titles=["sim", "cnn"],
                                  scale=8, vmin=None, vmax=None)
print("mean mae over {} samples: {:1.5f}".format(len(errs), np.mean(errs)))

In [None]:
from chnet.models import UNet, UNet_solo_loop, UNet_loop, mse_loss
# Per-sample validation MAE for several checkpoints. For each one the validation data uses
# n_step=2 frames of tstep//2 simulation steps; the result goes to err_dct[key].
err_dct = {}
weight_files = ["weights/model_unet_size_64_step_2_init_1_delta_250_tstep_250.pt",
                "weights/model_unet_size_64_step_5_init_1_delta_100_tstep_250.pt",
                "weights/model_unet_size_64_step_2_init_1_delta_500_tstep_250.pt",
                "weights/model_unet_size_64_step_2_init_1_delta_1000_tstep_250.pt",
                "weights/model_unet_size_64_step_2_init_1_delta_2500_tstep_250.pt",
                "weights/model_unet_loop_size_64_step_2_init_1_delta_2500_tstep_250.pt",
                "weights/model_unet_loop_size_64_step_4_init_1_delta_1250_tstep_250.pt",
                "weights/model_unet_loop_size_64_step_5_init_1_delta_1000_tstep_250.pt",
                "weights/model_unet_solo_loop_size_64_step_4_init_1_delta_1250_tstep_250.pt"]

tsteps = [500, 500, 1000, 2000, 5000, 5000, 5000, 5000, 5000]
keys = ["unet-500-spl", "unet-500", "unet-1k", "unet-2k", "unet-5k", "uloop-2", "uloop-4", "uloop-5", "usolo-4"]

# Only the first two checkpoints are evaluated; widen the slices to compare more models.
for key, weight, tstep in zip(keys[:2], weight_files[:2], tsteps[:2]):
    print(key, weight)
    x_val, y_val = ch_gen.data_generator(nsamples=n_samples_val,
                                         dim_x=dim_x,
                                         init_steps=1,
                                         delta_sim_steps=tstep//2,
                                         dx=dx,
                                         dt=dt,
                                         m_l=m_l,
                                         m_r=m_r,
                                         n_step=2,
                                         gamma=gamma,
                                         seed=38921641,
                                         device=device)
    val_dataset = CahnHillDataset(x_val, y_val, transform_x=lambda x: x[:,None], transform_y=lambda x: x[:,None])
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True, num_workers=1)

    # map_location lets GPU-saved checkpoints load when running on the CPU.
    obj = torch.load(weight, map_location=device)
    # The key prefix selects the architecture; the "-<t>" suffix of loop keys is the temporal depth.
    if "unet" in key:
        model = UNet(in_channels=1, out_channels=1, init_features=64).double().to(device)
    elif "uloop" in key:
        t = int(key.split("-")[-1])
        model = UNet_loop(in_channels=1, out_channels=1, init_features=64, temporal=t).double().to(device)
    elif "usolo" in key:
        t = int(key.split("-")[-1])
        model = UNet_solo_loop(in_channels=1, out_channels=1, init_features=64, temporal=t).double().to(device)
    else:
        # Without this, a stale `model` from a previous iteration would silently be evaluated.
        raise ValueError("unrecognised model key: {}".format(key))
    model.load_state_dict(obj["state"])

    torch.cuda.empty_cache()
    model.eval()
    errs = []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for ix in tqdm(range(len(val_dataset))):
            item_v = val_dataset[ix]

            x = item_v['x'][None][:,0].double().to(device)
            y_tru = item_v['y'][None][:,-1].double().to(device)

            # Only "uloop-*" keys contain "loop"; those models return a stack, compare the last.
            if "loop" in key:
                y_prd = model(x)[:,-1]
            else:
                y_prd = model(x)

            im_y1 = y_tru[0,0].detach().cpu().numpy()
            im_y2 = y_prd[0,0].detach().cpu().numpy()
            errs.append(mae_loss_npy(im_y1, im_y2))
    err_dct[key] = errs
    print(np.mean(err_dct[key]))

In [None]:
from chnet.models import UNet, UNet_solo_loop, UNet_loop, mse_loss
# Per-sample validation MAE for plain UNet checkpoints trained from different initial
# simulation steps; the result for each checkpoint goes to err_dct[key].
err_dct = {}
weight_files = ["weights/model_unet_size_64_step_2_init_1_delta_2500_tstep_250.pt",
                "weights/model_unet_size_64_step_2_init_500_delta_2250_tstep_250.pt",
                "weights/model_unet_size_64_step_2_init_1000_delta_2000_tstep_250.pt",
                "weights/model_unet_size_64_step_2_init_2000_delta_1500_tstep_250.pt"]

delta_tsteps = [2500, 2250, 2000, 1500]
# Named init_tsteps so it no longer overwrites the scalar `init_steps` process parameter;
# previously, re-running the validation-data cell after this one passed a list to init_steps.
init_tsteps = [1, 500, 1000, 2000]
keys = ["unet-i-1", "unet-i-500", "unet-i-1k", "unet-i-2k"]

for key, weight, dtstep, itstep in zip(keys, weight_files, delta_tsteps, init_tsteps):
    print(key, weight)
    x_val, y_val = ch_gen.data_generator(nsamples=n_samples_val,
                                         dim_x=dim_x,
                                         init_steps=itstep,
                                         delta_sim_steps=dtstep,
                                         dx=dx,
                                         dt=dt,
                                         m_l=m_l,
                                         m_r=m_r,
                                         n_step=2,
                                         gamma=gamma,
                                         seed=38921641,
                                         device=device)
    val_dataset = CahnHillDataset(x_val, y_val, transform_x=lambda x: x[:,None], transform_y=lambda x: x[:,None])
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=True, num_workers=1)

    # map_location lets GPU-saved checkpoints load when running on the CPU.
    obj = torch.load(weight, map_location=device)
    # The key prefix selects the architecture; the "-<t>" suffix of loop keys is the temporal depth.
    if "unet" in key:
        model = UNet(in_channels=1, out_channels=1, init_features=64).double().to(device)
    elif "uloop" in key:
        t = int(key.split("-")[-1])
        model = UNet_loop(in_channels=1, out_channels=1, init_features=64, temporal=t).double().to(device)
    elif "usolo" in key:
        t = int(key.split("-")[-1])
        model = UNet_solo_loop(in_channels=1, out_channels=1, init_features=64, temporal=t).double().to(device)
    else:
        # Without this, a stale `model` from a previous iteration would silently be evaluated.
        raise ValueError("unrecognised model key: {}".format(key))
    model.load_state_dict(obj["state"])

    torch.cuda.empty_cache()
    model.eval()
    errs = []
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for ix in tqdm(range(len(val_dataset))):
            item_v = val_dataset[ix]

            x = item_v['x'][None][:,0].double().to(device)
            y_tru = item_v['y'][None][:,-1].double().to(device)

            if "loop" in key:
                y_prd = model(x)[:,-1]
            else:
                y_prd = model(x)

            im_y1 = y_tru[0,0].detach().cpu().numpy()
            im_y2 = y_prd[0,0].detach().cpu().numpy()
            errs.append(mae_loss_npy(im_y1, im_y2))
    err_dct[key] = errs
    print(np.mean(err_dct[key]))

In [None]:
# One column per model key; melt to long form (variable = model key, value = per-sample MAE).
df = pd.DataFrame(err_dct)
plt.figure(figsize=(18, 12))
ax = sns.boxenplot(data=df.melt(), x="variable", y="value")
ax.set_ylabel("Mean Absolute Error")
ax.set_xlabel("Model Type")
plt.show()

In [None]:
# One column per model key; melt to long form (variable = model key, value = per-sample MAE).
df = pd.DataFrame(err_dct)
plt.figure(figsize=(18, 12))
ax = sns.boxenplot(data=df.melt(), x="variable", y="value")
ax.set_ylabel("Mean Absolute Error")
ax.set_xlabel("Model Type")
plt.show()

In [None]:
# One column per model key; melt to long form (variable = model key, value = per-sample MAE).
df = pd.DataFrame(err_dct)
plt.figure(figsize=(18, 12))
ax = sns.boxenplot(data=df.melt(), x="variable", y="value")
ax.set_ylabel("Mean Absolute Error")
ax.set_xlabel("Model Type")
plt.show()

In [None]:
# One column per model key; melt to long form (variable = model key, value = per-sample MAE).
df = pd.DataFrame(err_dct)
plt.figure(figsize=(18, 12))
ax = sns.boxenplot(data=df.melt(), x="variable", y="value")
ax.set_ylabel("Mean Absolute Error")
ax.set_xlabel("Model Type")
plt.show()

In [None]:
# Per-model distribution of per-sample MAE, sized and labelled like the other comparison
# plots so the figure is readable on its own.
df = pd.DataFrame(err_dct)
plt.figure(figsize=(18, 12))
sns.boxenplot(x="variable", y="value", data=pd.melt(df))
plt.ylabel("Mean Absolute Error")
plt.xlabel("Model Type")
plt.show()

In [None]:
# Per-model distribution of per-sample MAE, sized and labelled like the other comparison
# plots so the figure is readable on its own.
df = pd.DataFrame(err_dct)
plt.figure(figsize=(18, 12))
sns.boxenplot(x="variable", y="value", data=pd.melt(df))
plt.ylabel("Mean Absolute Error")
plt.xlabel("Model Type")
plt.show()