In [None]:
import sys
import json
import numpy as np
from tqdm.notebook import tqdm
from toolz.curried import pipe, curry, compose


import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader


import chnet.ch_tools as ch_tools
import chnet.utilities as ch_utils
import chnet.ch_generator as ch_gen
from chnet.ch_loader import CahnHillDataset


def localization_train(n_state=5,
                       domain=(-1.1, 1.1),
                       mid=0.0,
                       dif=0.449,
                       dim_x=96,
                       dx=0.25,
                       dt=0.01,
                       gamma=0.2,
                       nstep=2,
                       init_steps=1,
                       n_samples=1024,
                       final_tstep=501,
                       seed=68457542,
                       device="cuda"):
    """Generate Cahn-Hilliard data with ``ch_gen.data_generator`` and fit an MKS localization model.

    The model is fit to map the first stored input frame of every generated
    sample (``x[:, 0]``) to the last stored target frame (``y[:, -1]``).

    Args:
        n_state: number of basis states passed to ``LegendreBasis``.
        domain: ``(low, high)`` domain passed to ``LegendreBasis``
            (presumably should bracket the field values -- TODO confirm).
        mid, dif: the generator receives ``m_l=mid-dif`` and ``m_r=mid+dif``
            (assumed to bound the initial mean composition -- verify in ch_generator).
        dim_x, dx, dt, gamma, seed: forwarded unchanged to the generator.
        nstep: forwarded as ``n_step``; also divides the simulated span.
        init_steps: forwarded unchanged to the generator.
        n_samples: number of training samples to generate.
        final_tstep: together with ``init_steps`` and ``nstep`` sets
            ``delta_sim_steps = (final_tstep - init_steps) // nstep``.
        device: ``"cuda"`` maps to ``cuda:0`` (historical behaviour); any other
            value is handed to ``torch.device`` (e.g. ``"cpu"``, ``"cuda:1"``).

    Returns:
        The fitted ``pymks.MKSLocalizationModel``.

    Raises:
        ValueError: if ``nstep < 1`` or ``delta_sim_steps`` would not be positive.
    """
    # Function-scope imports: pymks is only required when this function runs.
    from pymks.bases import LegendreBasis
    from pymks import MKSLocalizationModel

    if nstep < 1:
        raise ValueError("nstep must be >= 1, got {}".format(nstep))
    delta_sim_steps = (final_tstep - init_steps) // nstep
    if delta_sim_steps < 1:
        raise ValueError(
            "final_tstep - init_steps must be at least nstep "
            "(init_steps={}, final_tstep={}, nstep={})".format(init_steps, final_tstep, nstep))

    print("Start MKS Training")
    # Previously every value other than the literal "cuda" (even "cuda:1")
    # silently fell back to the CPU.
    if not isinstance(device, torch.device):
        device = torch.device("cuda:0") if device == "cuda" else torch.device(device)
    if device.type == "cuda":
        torch.cuda.empty_cache()

    x_inp, y_inp = ch_gen.data_generator(nsamples=n_samples,
                                         dim_x=dim_x,
                                         init_steps=init_steps,
                                         delta_sim_steps=delta_sim_steps,
                                         dx=dx,
                                         dt=dt,
                                         m_l=mid - dif,
                                         m_r=mid + dif,
                                         n_step=nstep,
                                         gamma=gamma,
                                         seed=seed,
                                         device=device)
    # Input: first stored frame; target: last stored frame.
    x_inp, y_inp = x_inp[:, 0], y_inp[:, -1]

    basis = LegendreBasis(n_state, domain)
    model = MKSLocalizationModel(basis)
    model.fit(x_inp, y_inp)
    print("End MKS Training")
    return model


def localization_validate(model,
                          mid=0.0,
                          dif=0.449,
                          dim_x=96,
                          dx=0.25,
                          dt=0.01,
                          gamma=0.2,
                          nstep=2,
                          init_steps=1,
                          n_samples=32,
                          final_tstep=501,
                          seed=8634132,
                          device="cuda",
                          n_items=5,
                          vis=True):
    """Generate a fresh Cahn-Hilliard validation set and score a fitted localization model.

    For every sample, the model prediction for the first stored input frame is
    compared with the last stored simulated frame using the mean absolute error.
    When ``vis`` is true, roughly ``n_items`` evenly spaced samples are drawn
    side by side (input, simulation, prediction), and their SSIM values
    against the simulated target are printed.

    Args:
        model: fitted object exposing ``predict`` (e.g. from ``localization_train``).
        mid, dif, dim_x, dx, dt, gamma, nstep, init_steps, final_tstep, seed:
            data-generation settings, forwarded as in ``localization_train``.
        n_samples: number of validation samples to generate and score.
        device: ``"cuda"`` maps to ``cuda:0``; any other value is handed to
            ``torch.device``.
        n_items: approximate number of samples to visualise; ``<= 0`` draws none.
            If it exceeds ``n_samples``, every sample is drawn.
        vis: whether to draw and print per-item diagnostics at all.

    Returns:
        List with one MAE value per validation sample (length ``n_samples``).

    Raises:
        ValueError: if ``nstep < 1`` or ``delta_sim_steps`` would not be positive.
    """
    from chnet.ssim import SSIM

    def mae_loss_npy(x1, x2):
        # .numpy() needs host tensors; assumes generator/predict outputs are on CPU -- TODO confirm.
        return np.mean(np.fabs(x1.numpy() - x2.numpy()))

    if nstep < 1:
        raise ValueError("nstep must be >= 1, got {}".format(nstep))
    delta_sim_steps = (final_tstep - init_steps) // nstep
    if delta_sim_steps < 1:
        raise ValueError(
            "final_tstep - init_steps must be at least nstep "
            "(init_steps={}, final_tstep={}, nstep={})".format(init_steps, final_tstep, nstep))

    # Resolve the device the same way localization_train does.
    if not isinstance(device, torch.device):
        device = torch.device("cuda:0") if device == "cuda" else torch.device(device)

    # Guard against n_samples // n_items == 0 (ZeroDivisionError in the modulo below).
    show = vis and n_items > 0
    draw_every = max(1, n_samples // n_items) if show else None
    ssim_loss = SSIM(window_size=11) if show else None

    print("Start Validation")
    if device.type == "cuda":
        torch.cuda.empty_cache()
    x_val, y_val = ch_gen.data_generator(nsamples=n_samples,
                                         dim_x=dim_x,
                                         init_steps=init_steps,
                                         delta_sim_steps=delta_sim_steps,
                                         dx=dx,
                                         dt=dt,
                                         m_l=mid - dif,
                                         m_r=mid + dif,
                                         n_step=nstep,
                                         gamma=gamma,
                                         seed=seed,
                                         device=device)

    x_val, y_val = x_val[:, 0], y_val[:, -1]
    y_prd = model.predict(x_val)

    errs = []
    for ix in tqdm(range(n_samples)):
        im_x = torch.tensor(x_val[ix])
        im_y1 = torch.tensor(y_val[ix])
        im_y2 = torch.tensor(y_prd[ix])
        errs.append(mae_loss_npy(im_y1, im_y2))

        if show and (ix + 1) % draw_every == 0:
            ch_utils.draw_by_side([im_x, im_y1, im_y2],
                                  sub_titles=["inp", "sim", "cnn"],
                                  scale=8, vmin=None, vmax=None)
            # SSIM of each panel against the simulated target; the "sim" entry
            # compares the target with itself and acts as a reference value.
            print("mae: {}, inp: {:1.3f}, sim: {:1.3f}, cnn: {:1.3f}".format(
                errs[-1],
                ssim_loss(im_y1[None, None], im_x[None, None]).item(),
                ssim_loss(im_y1[None, None], im_y1[None, None]).item(),
                ssim_loss(im_y1[None, None], im_y2[None, None]).item()))
    return errs

In [None]:
# Configuration for a quick look at how microstructures evolve for three
# different mean compositions.
n_state=5
domain = (-1.1, 1.1)
dif=1e-4
dim_x=96
dx=0.25
dt=0.01
gamma=0.2
nstep=6
init_steps=1
n_samples=2
final_tstep=2001
seed=68457542
device="cuda"
device = torch.device("cuda:0") if device == "cuda" else torch.device("cpu")
torch.cuda.empty_cache()

# Solver steps between consecutive stored frames. The plot labels are derived
# from this value; the old hard-coded 400 only matched nstep=5.
delta_sim_steps = (final_tstep - init_steps) // nstep

micros = {}
for mid in [-.3, 0.0, .3]:
    x_inp, y_inp = ch_gen.data_generator(nsamples=n_samples,
                                         dim_x=dim_x,
                                         init_steps=init_steps,
                                         delta_sim_steps=delta_sim_steps,
                                         dx=dx,
                                         dt=dt,
                                         m_l=mid - dif,
                                         m_r=mid + dif,
                                         n_step=nstep,
                                         gamma=gamma,
                                         seed=seed,
                                         device=device)
    micros[mid] = x_inp[0]
    for i in range(nstep):
        # Label is the step offset of frame i relative to the first stored
        # frame (same convention as the previous i*400 labels).
        timestep = i * delta_sim_steps
        ch_utils.draw_im(x_inp[0, i], vmax=None, vmin=None, title="timestep={}".format(timestep))

In [None]:
# One row of frames per mean composition, in the order micros was filled.
# Iterating the dict avoids repeating (and possibly mistyping) its keys.
# Assumes each value is indexed as frames[frame_index, ...] -- TODO confirm.
for mid_value, frames in micros.items():
    print("mid = {}".format(mid_value))
    ch_utils.draw_by_side(list(frames), vmax=None, vmin=None, scale=12)

In [None]:
# Train and validate one MKS model per (start, end) timestep window.
# errs_dct maps "MKS <start>-<end>" to that model's per-sample validation MAEs.
MKS_N_STATE = 11
MKS_DIF = 0.35
MKS_N_TRAIN = 2048
MKS_N_VAL = 1024

# Distinct names so this cell does not overwrite the scalar init_steps /
# final_tstep config used by the earlier cells.
init_step_list = [1, 1, 1, 401, 801, 1201, 1601, 2001]
final_step_list = [101, 401, 801, 801, 1201, 1601, 2001, 2401]
# zip() would silently drop windows if the two lists drift apart.
assert len(init_step_list) == len(final_step_list), "window lists must have equal length"

errs_dct = {}
for init_step, final_step in tqdm(zip(init_step_list, final_step_list), total=len(init_step_list)):
    key = "MKS {}-{}".format(init_step, final_step)
    print(key)

    model0 = localization_train(n_state=MKS_N_STATE,
                                dif=MKS_DIF,
                                init_steps=init_step,
                                final_tstep=final_step,
                                n_samples=MKS_N_TRAIN)

    errs_dct[key] = localization_validate(model0,
                                          dif=MKS_DIF,
                                          init_steps=init_step,
                                          final_tstep=final_step,
                                          n_samples=MKS_N_VAL,
                                          n_items=2)

In [None]:
# Persist the per-sample validation MAEs: one column per errs_dct key.
import pandas as pd

df_errs = pd.DataFrame.from_dict(errs_dct, orient="columns")
df_errs.to_csv("errs_allMksModels.csv")

In [None]:
# Summary statistics of the validation MAE for each model column.
errs_summary = df_errs.describe()
errs_summary