In [None]:
import sys
import glob
import json
import random
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from toolz.curried import pipe, curry, compose

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader

import chnet.ch_tools as ch_tools
import chnet.utilities as ch_utils
from chnet.models import get_model
import chnet.ch_generator as ch_gen
from chnet.ch_loader import CahnHillDataset

import seaborn as sns
import matplotlib.pyplot as plt

In [None]:
def to_npy(x):
    """Detach tensor ``x`` from autograd, move it to the CPU and return it as a NumPy array."""
    detached = x.detach()
    return detached.cpu().numpy()

def localization_train(
        n_state=5,
        domain=(-1.1, 1.1),
        mid=0.0,
        dif=0.449,
        dim_x=96,
        dx=0.25,
        dt=0.01,
        gamma=0.2,
        nstep=2,
        init_steps=1,
        n_samples=1024,
        final_tstep=501,
        seed=68457542,
        device="cuda"):
    """Fit a pymks ``MKSLocalizationModel`` on freshly generated data.

    Data come from ``ch_gen.data_generator``. The model is fitted to map the
    first entry along axis 1 of the generator's first output onto the last
    entry along axis 1 of its second output.

    Args:
        n_state: First argument of ``LegendreBasis``.
        domain: Second argument of ``LegendreBasis``.
        mid, dif: The generator receives ``m_l = mid - dif`` and
            ``m_r = mid + dif``.
        dim_x, dx, dt, gamma, seed: Forwarded unchanged to the generator.
        nstep: Forwarded as ``n_step``; also divides the simulated span.
        init_steps, final_tstep: The generator is asked for
            ``delta_sim_steps = (final_tstep - init_steps) // nstep``.
        n_samples: Forwarded as ``nsamples``.
        device: ``"cuda"`` selects ``cuda:0``; any other value selects the CPU.

    Returns:
        The fitted ``MKSLocalizationModel``.
    """
    # pymks is imported lazily, so it is only required when this function runs.
    from pymks.bases import LegendreBasis
    from pymks import MKSLocalizationModel

    print("Start MKS Training")
    torch_device = torch.device("cuda:0" if device == "cuda" else "cpu")
    torch.cuda.empty_cache()

    generator_args = dict(
        nsamples=n_samples,
        dim_x=dim_x,
        init_steps=init_steps,
        delta_sim_steps=(final_tstep - init_steps) // nstep,
        dx=dx,
        dt=dt,
        m_l=mid - dif,
        m_r=mid + dif,
        n_step=nstep,
        gamma=gamma,
        seed=seed,
        device=torch_device,
    )
    generated_x, generated_y = ch_gen.data_generator(**generator_args)

    # Only the first input frame and the last output frame take part in the fit.
    first_frames = generated_x[:, 0]
    last_frames = generated_y[:, -1]

    mks_model = MKSLocalizationModel(LegendreBasis(n_state, domain))
    mks_model.fit(first_frames, last_frames)
    print("End MKS Training")
    return mks_model

def validate(key="unet", 
              ngf=32,
              tanh=True,
              conv=True,
              mid=0.0, 
              dif=0.449, 
              dim_x=96, 
              dx=0.25, 
              dt=0.01, 
              gamma=0.2, 
              nstep=5,
              init_steps=1, 
              n_samples=1024,
              final_tstep=500, 
              seed=8634132, 
              device="cuda", 
              weight_files="",
              vis=True, 
              n_items=10, 
              mks=False):
    """Score a chain of trained networks against generated validation data.

    One network is built with ``get_model`` and loaded for every entry of
    ``weight_files``. The networks are applied in order, each one receiving the
    previous one's output. The final prediction is compared with the last
    generated target frame.

    When ``mks`` is true and ``init_steps > 1``, an MKS localization model is
    first trained (``localization_train``) to carry the step-1 field to
    ``init_steps``. The validation data are then generated from step 1 and the
    MKS prediction is what the networks receive.

    Args:
        key, ngf, tanh, conv, nstep: Forwarded to ``get_model``; ``nstep`` is
            also passed to the data generator as ``n_step``.
        mid, dif: The generator receives ``m_l = mid - dif`` and
            ``m_r = mid + dif``.
        dim_x, dx, dt, gamma, seed: Forwarded to ``ch_gen.data_generator``.
        init_steps, final_tstep: Bounds of the span covered by the data.
        n_samples: Number of validation samples requested.
        device: ``"cuda"`` selects ``cuda:0``; any other value selects the CPU.
        weight_files: Iterable of checkpoint paths. A single path given as a
            string is treated as a one-element list; ``""`` means no networks.
        vis: If true, draw input / target / prediction for some samples.
        n_items: Approximate number of samples to draw when ``vis`` is true.
        mks: Enable the MKS bridge described above.

    Returns:
        ``(errs, ssims)``: two lists with one entry per validation sample. The
        first holds the mean absolute error between target and prediction, the
        second the value returned by ``SSIM`` for the same pair.
    """
    from chnet.ssim import SSIM
    ssim_loss = SSIM(window_size=11)
    mae_loss_npy = lambda x1, x2: np.mean(np.fabs(x1-x2))

    # Keep the caller's original value: localization_train does its own
    # "cuda"-string comparison, so it must not receive the torch.device object.
    device_name = device
    device = torch.device("cuda:0") if device == "cuda" else torch.device("cpu")
    print(device)

    # A bare path would otherwise be iterated character by character below.
    if isinstance(weight_files, str):
        weight_files = [weight_files] if weight_files else []

    use_mks = mks and (init_steps > 1)

    if use_mks:
        # The bridge model must describe the same grid and physics as the
        # validation data and run on the same device. The concentration range
        # (mid/dif) is deliberately left at localization_train's defaults, as
        # before.
        model0 = localization_train(n_state=11,
                                    init_steps=1,
                                    final_tstep=init_steps,
                                    n_samples=2048,
                                    dim_x=dim_x,
                                    dx=dx,
                                    dt=dt,
                                    gamma=gamma,
                                    device=device_name)

    models = []
    for weight_file in weight_files:
        model = get_model(key=key, ngf=ngf, tanh=tanh, conv=conv, nstep=nstep, device=device)
        model.load_state_dict(torch.load(weight_file, map_location=device)["state"])
        model.eval()  # inference mode is set once here instead of once per sample
        models.append(model)

    print("Start Validation")

    torch.cuda.empty_cache()

    # With the MKS bridge the simulation starts at step 1; otherwise it starts
    # directly at init_steps. Everything else is shared.
    sim_start = 1 if use_mks else init_steps
    x_val, y_val = ch_gen.data_generator(nsamples=n_samples,
                                         dim_x=dim_x,
                                         init_steps=sim_start,
                                         delta_sim_steps=(final_tstep-sim_start)//nstep,
                                         dx=dx,
                                         dt=dt,
                                         m_l=mid-dif,
                                         m_r=mid+dif,
                                         n_step=nstep,
                                         gamma=gamma,
                                         seed=seed,
                                         device=device)
    if use_mks:
        x_inp = x_val.copy()  # untouched step-1 fields, kept only for plotting
        x_val = model0.predict(x_val[:,0])[:,None]

    val_dataset = CahnHillDataset(x_val, y_val, 
                                  transform_x=lambda x: x[:,None], 
                                  transform_y=lambda x: x[:,None])

    errs = []
    ssims = []

    # Plot every `vis_every`-th sample. Clamp to 1 so that n_items > n_samples
    # no longer causes a modulo-by-zero, and skip plotting when n_items <= 0.
    if vis and n_items > 0:
        vis_every = max(1, n_samples // n_items)
    else:
        vis_every = 0

    torch.cuda.empty_cache()

    # Pure inference: no autograd graph is needed for any of the metrics.
    with torch.no_grad():
        for ix in tqdm(range(len(val_dataset))):

            item_v = val_dataset[ix]

            x = item_v['x'][None][:,0].double().to(device)
            y_tru = item_v['y'][None][:,-1].double().to(device)
            y_prd = torch.clone(x)

            for model in models:
                y_prd = model(y_prd)

            if use_mks:
                im_x = x_inp[ix,0]
            else:
                im_x = x[0,0].detach().cpu().numpy()

            im_y1 = y_tru[0,0].detach().cpu().numpy()
            im_y2 = y_prd[0,0].detach().cpu().numpy()

            errs.append(mae_loss_npy(im_y1, im_y2))
            ssims.append(ssim_loss(y_tru, y_prd).item())

            if vis_every and ((ix+1) % vis_every) == 0:
                ch_utils.draw_by_side([im_x, im_y1, im_y2], 
                                      sub_titles=["inp", "sim", "cnn"], 
                                      scale=8, vmin=None, vmax=None)

                print("{:1.3f}, {:1.3f}".format( 
                      ssim_loss(y_tru, y_tru).item(), 
                      ssims[-1]))

                print("mae: {:1.3e}, inp: {:1.3e}, out: {:1.3e}".format(errs[-1], np.mean(im_y1), np.mean(im_y2)))

    return errs, ssims



def compare_models(key="unet", ngf=64, tanh=True, conv=True, nstep=2, device="cuda", weight_files=[]):
    """Embed the outputs of several checkpoints with t-SNE and scatter-plot them.

    Every checkpoint in ``weight_files`` is loaded into a fresh model from
    ``get_model`` and evaluated on the same seeded random batch of shape
    ``(64, 1, 96, 96)``. The flattened outputs are embedded in 2-D with
    ``sklearn.manifold.TSNE`` and plotted, coloured by checkpoint position.

    Args:
        key, ngf, tanh, conv, nstep: Forwarded to ``get_model``.
        device: ``"cuda"`` selects ``cuda:0``; any other value selects the CPU.
        weight_files: Sequence of checkpoint paths (never mutated here).

    Raises:
        ValueError: If ``weight_files`` is empty.
    """
    from sklearn.manifold import TSNE

    if len(weight_files) == 0:
        raise ValueError("compare_models needs at least one weight file")

    random.seed(0)
    torch.manual_seed(0)
    np.random.seed(0)

    # Honour the caller's choice; it used to be overwritten with "cuda" here.
    device = torch.device("cuda:0") if device == "cuda" else torch.device("cpu")
    print(device)

    x_inp = torch.rand([64,1,96,96]).double().to(device)

    outs = []
    with torch.no_grad():
        for weight_file in tqdm(weight_files):
            model = get_model(key=key, ngf=ngf, tanh=tanh, conv=conv, nstep=nstep, device=device)
            model.load_state_dict(torch.load(weight_file, map_location=device)["state"])
            model.eval()  # same inference mode that validate() uses
            outs.append(np.ravel(model(x_inp).cpu().numpy())[None])

    x_out = np.concatenate(outs)
    # NOTE(review): recent scikit-learn releases require perplexity < n_samples;
    # confirm the default perplexity is valid for the number of checkpoints.
    X_emb = TSNE(n_components=2).fit_transform(x_out)

    # One colour per checkpoint, whatever their number (was fixed at six).
    plt.scatter(X_emb[:,0], X_emb[:,1], c=np.arange(1, len(outs) + 1))
    plt.show()

In [None]:
# Index every single-hop checkpoint by the (start, end) time steps it covers.
flist = sorted(glob.glob("stepmodels/*_tanh_True_*_stepModel.pt"))


def get_dt(f):
    """Return the integers at positions 7 and 9 of the "_"-split path ``f``.

    For names shaped like ``..._init_401_delta_400_...`` these are the values
    that follow ``init`` and ``delta``.
    """
    parts = f.split("_")
    return int(parts[7]), int(parts[9])


fdict = {}
for f in flist:
    init_step, delta_step = get_dt(f)
    # NOTE(review): the factor 2 presumably mirrors the "step_2" in the file
    # names (nstep=2 in the validation cells) - confirm.
    fdict[(init_step, init_step + 2*delta_step)] = f


def _chain_weights(chains):
    """For each chain of time steps, list the checkpoints of its consecutive hops."""
    return [[fdict[hop] for hop in zip(chain[:-1], chain[1:])] for chain in chains]


def _chain_keys(chains, prefix=""):
    """Build one "-"-joined label per chain, optionally prefixed."""
    return [prefix + "-".join(str(step) for step in chain) for chain in chains]


# Chains starting at step 1.
lst = [[1, 401, 801, 1201, 1601, 2001], 
       [1, 401, 801, 1201, 2001], 
       [1, 401, 801, 1601, 2001], 
       [1, 401, 801, 2001], 
       [1, 401, 1201, 1601, 2001], 
       [1, 401, 1201, 2001], 
       [1, 401, 1601, 2001], 
       [1, 401, 2001], 
       [1, 2001]]

weight_lists_001 = _chain_weights(lst)
key_list_001 = _chain_keys(lst)

# Same chains with the first hop starting at 101. New lists are built instead
# of overwriting the first element of the step-1 chains in place.
lst = [[101] + chain[1:] for chain in lst]
for chain in lst:
    print(chain)

weight_lists_101 = _chain_weights(lst)
key_list_101 = _chain_keys(lst, prefix="mks-")

# Chains starting at step 401.
lst = [[401, 801, 1201, 1601, 2001], 
       [401, 801, 1201, 2001], 
       [401, 801, 1601, 2001], 
       [401, 801, 2001], 
       [401, 1201, 1601, 2001], 
       [401, 1201, 2001], 
       [401, 1601, 2001], 
       [401, 2001]]

weight_lists_401 = _chain_weights(lst)
key_list_401 = _chain_keys(lst, prefix="mks-")

# Chains starting at step 801.
lst = [[801, 1201, 1601, 2001], 
       [801, 1201, 2001], 
       [801, 1601, 2001], 
       [801, 2001]]

weight_lists_801 = _chain_weights(lst)
key_list_801 = _chain_keys(lst, prefix="mks-")

In [None]:
n_samples = 1024
n_items = 10
errs_dct={}
ssim_dct={}

# Settings shared by every validation run in this cell.
shared_kwargs = dict(key="unet",
                     ngf=64,
                     tanh=True,
                     conv=True,
                     mid=0.0,
                     dif=0.35,
                     dim_x=96,
                     dx=0.25,
                     dt=0.01,
                     gamma=0.2,
                     n_samples=n_samples,
                     nstep=2,
                     seed=8634132,
                     device="cuda",
                     vis=True,
                     n_items=n_items)


def run_and_record(label, weight_files, init_steps, final_tstep, mks):
    """Validate one checkpoint chain and store its errors and SSIMs under ``label``."""
    print(label)
    errs, ssims = validate(weight_files=weight_files,
                           init_steps=init_steps,
                           final_tstep=final_tstep,
                           mks=mks,
                           **shared_kwargs)
    errs_dct[label] = errs
    ssim_dct[label] = ssims
    return errs, ssims


# 1) Every single-hop checkpoint on its own span, without the MKS bridge.
for (init_tstep, finl_tstep), weight_file in fdict.items():
    errs, ssims = run_and_record("{}-{}".format(init_tstep, finl_tstep),
                                 [weight_file], init_tstep, finl_tstep, mks=False)

# 2) Checkpoint chains ending at step 2001, grouped by their starting step.
#    Only the chains that start at step 1 run without the MKS bridge.
finl_tstep = 2001
chain_groups = [(1, False, key_list_001, weight_lists_001),
                (101, True, key_list_101, weight_lists_101),
                (401, True, key_list_401, weight_lists_401),
                (801, True, key_list_801, weight_lists_801)]

for init_tstep, mks, group_keys, group_weights in chain_groups:
    for key, weight_file in zip(group_keys, group_weights):
        errs, ssims = run_and_record(key, weight_file, init_tstep, finl_tstep, mks=mks)

pd.DataFrame(errs_dct).to_csv("errs_allStepModels.csv")
pd.DataFrame(ssim_dct).to_csv("ssim_allStepModels.csv")

In [None]:
# Validate a two-checkpoint "stepModelBroadRange" chain with the MKS bridge enabled.
mid = 0.0
dif = 0.5
mks = True
n_samples = 1024
init_tstep = 401
finl_tstep = 2001
n_items = 16
weight_file = [
    'weights/model_unet_size_64_step_2_init_401_delta_400_tstep_300_tanh_True_loss_mae_tag_stepModelBroadRange.pt',
    'weights/model_unet_size_64_step_2_init_1201_delta_400_tstep_300_tanh_True_loss_mae_tag_stepModelBroadRange.pt',
]

errs, ssims = validate(
    # network
    key="unet", ngf=64, tanh=True, conv=True, nstep=2,
    # generated data
    mid=mid, dif=dif, dim_x=96, dx=0.25, dt=0.01, gamma=0.2,
    n_samples=n_samples, init_steps=init_tstep, final_tstep=finl_tstep,
    seed=8634132,
    # run options
    device="cuda", weight_files=weight_file, vis=True, n_items=n_items, mks=mks,
)

In [None]:
def summarize_errors(err_by_key):
    """Print ``key mean std`` for every entry of ``err_by_key`` and return the rows.

    Args:
        err_by_key: Mapping from a label to a sequence of error values.

    Returns:
        A list of ``(key, mean, std)`` tuples in the mapping's iteration order.
    """
    rows = []
    for key, values in err_by_key.items():
        row = (key, np.mean(values), np.std(values))
        print(*row)
        rows.append(row)
    return rows


# `err_dct` is created by a later cell of this notebook, so on a fresh
# top-to-bottom run it does not exist yet. Report that instead of raising
# NameError.
if "err_dct" in globals():
    summarize_errors(err_dct)
else:
    print("`err_dct` is not defined yet - run the cell that creates it, then re-run this one.")

In [None]:
# Sweep the `mid` argument of validate() for one fixed two-checkpoint chain.
dif = 1e-4
mks = True
n_samples = 128
init_tstep = 401
finl_tstep = 2001
n_items = 2
weight_file = [
    'stepmodels/model_unet_size_64_step_2_init_401_delta_400_tstep_200_tanh_True_loss_mae_tag_stepModel.pt',
    'stepmodels/model_unet_size_64_step_2_init_1201_delta_400_tstep_200_tanh_True_loss_mae_tag_stepModel.pt',
]

err_dct = {}
for mid in [-.3, -0.2, -0.1, -0.0, 0.1, 0.2, 0.3]:
    mid_label = "{:.3f}".format(mid)
    print(mid_label)
    errs, ssims = validate(
        key="unet", ngf=64, tanh=True, conv=True, nstep=2,
        mid=mid, dif=dif, dim_x=96, dx=0.25, dt=0.01, gamma=0.2,
        n_samples=n_samples, init_steps=init_tstep, final_tstep=finl_tstep,
        seed=8634132,
        device="cuda", weight_files=weight_file, vis=True, n_items=n_items, mks=mks,
    )
    err_dct[mid_label] = errs

# Per-`mid` summary and CSV export are currently disabled:
#     print("errs, mean: {}, std: {}".format(np.mean(err_dct["{:.3f}".format(mid)]), np.std(err_dct["{:.3f}".format(mid)])))
# pd.DataFrame(err_dct).to_csv("errs_bestModel_conc.csv")

In [None]:
# Sweep `mid` for the first step-401 chain extended by one more checkpoint,
# validating up to step 4001, and save the errors to CSV.
dif = 0.001
mks = True
n_samples = 128
init_tstep = 401
finl_tstep = 4001
n_items = 2
weight_file = weight_lists_401[0] + [
    "stepmodels/model_unet_size_64_step_2_init_2001_delta_1000_tstep_200_tanh_True_loss_mae_tag_stepModel.pt",
]

err_dct1 = {}
for mid in np.arange(-.35, .351, 0.05):
    mid_label = "{:.3f}".format(mid)
    print(mid_label)
    errs, ssims = validate(
        key="unet", ngf=64, tanh=True, conv=True, nstep=2,
        mid=mid, dif=dif, dim_x=96, dx=0.25, dt=0.01, gamma=0.2,
        n_samples=n_samples, init_steps=init_tstep, final_tstep=finl_tstep,
        seed=8634132,
        device="cuda", weight_files=weight_file, vis=True, n_items=n_items, mks=mks,
    )
    err_dct1[mid_label] = errs

pd.DataFrame(err_dct1).to_csv("errs_bestModel_conc_limited.csv")

In [None]:
import os
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

df_errs = pd.read_csv("errs_allStepModels.csv", index_col=0)
df_ssim = pd.read_csv("ssim_allStepModels.csv", index_col=0)


def hue_label(col):
    """Return the box-plot group a result column belongs to, judged by its name."""
    for group in ("mks-101", "mks-401", "mks-801"):
        if group in col:
            return group + "-2001"
    if "1-" in col[:3] and "-2001" in col[-5:]:
        return "1-2001"
    # Columns outside the four expected families still get a label, so the
    # hue vector always lines up with the melted frame.
    return "other"


# NOTE(review): 23 is a hard-coded column offset, presumably the position where
# the chained-model columns start - confirm against the saved CSV.
cols = ["1-2001",] + list(df_errs.columns[23:])

# Label each column directly rather than counting matches per family; the
# result is then right regardless of column order.
lbls = [hue_label(c) for c in cols]
# pd.melt stacks the columns one after another, len(df_errs) rows each, so each
# label is repeated by the actual row count rather than a fixed 1024.
hues = np.repeat(lbls, len(df_errs))

df_plot = pd.melt(df_errs[cols])
df_plot["hue"] = hues
plt.figure(figsize=(18,12))
sns.boxplot(x="variable", y="value", hue="hue", data=df_plot, showfliers=False)
plt.xticks(rotation = 90)
plt.show()

# Rank the selected columns by mean error, best first.
vals = df_errs[cols].describe().loc["mean"].values
for ix in np.argsort(vals):
    print("{:.4f}".format(vals[ix]), cols[ix])

In [None]:
# Box plot of the errors stored in "errs_bestModel_conc.csv", one box per column.
df_cnc = pd.read_csv("errs_bestModel_conc.csv", index_col=0)
df_plot = df_cnc.melt()

fig, ax = plt.subplots(figsize=(18, 12))
sns.boxplot(data=df_cnc, ax=ax)
plt.xticks(rotation=90)
plt.show()