In [1]:
%matplotlib inline

import os
import torch
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import parflow as pf
import xarray as xr
import pandas as pd

from torch import nn
from torch.optim import lr_scheduler
import torch.nn.functional as F
from tqdm.autonotebook import tqdm

from src.datapipes import (
    gen_normalized_ds_minmax,
    gen_normalized_ds_zscale,
    create_dataloader
)
from src.models import UNet
from src.loss import mse_loss
from src.train import train_epoch, save_experiment, load_experiment
import random

  from tqdm.autonotebook import tqdm


In [2]:
# Global run configuration: RNG seed, compute device, dtype, and output naming.
seed = 0
# Seed every RNG in play, not just Python's `random`: the seed is recorded in
# the experiment config, so NumPy and torch (weight init, any random
# augmentation) must follow it too for a run to be repeatable.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Prefer the first CUDA device when one is visible; otherwise fall back to CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
DTYPE = torch.float32
DTYPE_NAME = "torch.float32"  # string form of DTYPE, stored in the experiment config
torch.set_default_dtype(DTYPE)

mod_num = '5_retrain'  # suffix that distinguishes this run in model_run_name
mode = "w"             # forwarded to save_experiment when metrics are written
header = True          # forwarded to save_experiment when metrics are written
unet_layers = 4        # recorded in the experiment config

write_dir = "/home/at8471/c1_inversion/ml_training/outputs/logK_css_pme75_ani10"
model_run_name = f"css_pme75_ani10_ens_model{mod_num}"
# 1 -> resume from the saved "<write_dir>/<model_run_name>.yml"; anything else -> start fresh.
yaml_flag = 0

In [3]:
# Display the torch device selected in the configuration cell.
DEVICE

device(type='cuda', index=0)

In [4]:
# Names of the variables fed to the network and the variable it predicts.
in_vars = ['pme', 'wtd', 'topo_index', 'elev', 'slope_x', 'slope_y']
out_vars = ['ksat']

# Static fields; TopoIndex and elev are copied onto the ensemble dataset later.
# load_dataset reads the whole file into memory and closes it, whereas
# open_dataset(...).load() leaves the file handle open for the kernel's lifetime.
ds = xr.load_dataset('../inputs/conus_steady_state_parflow.nc')

In [5]:
# Great Lakes mask: read the PFB, drop singleton axes, then reverse the first axis.
lake_mask_raw = pf.read_pfb('/home/at8471/c1_inversion/ml_training/inputs/c1_great_lakes_masked.pfb').squeeze()
lake_mask = np.flip(lake_mask_raw, axis=0)

# Ensemble dataset stored as zarr, extended with two static fields taken from `ds`.
ksat_ens = xr.open_dataset('../inputs/ani10_ef_PmE75_wtd.Ksat.sx.sy.pme.zarr', engine='zarr')
for target_name, source_name in (('topo_index', 'TopoIndex'), ('elev', 'elev')):
    ksat_ens[target_name] = ds[source_name]

In [6]:
# Replace ksat with its natural log, exactly once. Without the guard, re-running
# this cell would take the log of already-logged values. The marker is stored in
# the variable's attrs, so it disappears whenever ksat_ens is reloaded from disk.
if ksat_ens['ksat'].attrs.get('transform') != 'natural_log':
    ksat_ens['ksat'] = np.log(ksat_ens['ksat'])
    ksat_ens['ksat'].attrs['transform'] = 'natural_log'

In [7]:
# Wrap the lake mask as a (y, x) DataArray, normalize the ensemble with the
# z-scale helper, and attach the mask to the normalized dataset.
lake_mask_da = xr.DataArray(lake_mask, dims=('y', 'x'))
norm_ksat_ds = gen_normalized_ds_zscale(ksat_ens)
norm_ksat_ds['lake_mask'] = lake_mask_da

In [8]:
# Split the ensemble members into train / validation / test index lists.
reserve_frac = 0.4  # fraction held out for validation + test (60/20/20 split)
n_members = 50      # assumed size of the 'n' dimension -- TODO confirm against the dataset

ind_list = list(range(n_members))
random.shuffle(ind_list)

n_reserved = int(len(ind_list) * reserve_frac)
n_valid = n_reserved // 2
valid_inds = ind_list[:n_valid]
test_inds = ind_list[n_valid:n_reserved]
# Whatever was not reserved is training data. Sorting the remaining slice gives
# a deterministic order of plain Python ints, instead of relying on the
# iteration order of a set of NumPy integers.
train_inds = sorted(ind_list[n_reserved:])

# Comma-separated forms, stored in the experiment config.
train_inds_str = ", ".join(str(number) for number in train_inds)
valid_inds_str = ", ".join(str(number) for number in valid_inds)
test_inds_str = ", ".join(str(number) for number in test_inds)

print(valid_inds)
print(test_inds)
print(train_inds)

[43, 1, 28, 14, 36, 12, 0, 27, 47, 7]
[33, 34, 20, 49, 5, 38, 11, 23, 40, 15]
[2, 3, 4, 6, 8, 9, 10, 13, 16, 17, 18, 19, 21, 22, 24, 25, 26, 29, 30, 31, 32, 35, 37, 39, 41, 42, 44, 45, 46, 48]


In [9]:
# Select the train / validation / test members along the 'n' dimension.
norm_ksat_ds_train, norm_ksat_ds_valid, norm_ksat_ds_test = (
    norm_ksat_ds.isel(n=member_inds)
    for member_inds in (train_inds, valid_inds, test_inds)
)

In [10]:
# --- Model size and training length ---
bc = 56            # UNet base channels (initial -> hidden state; an earlier setup used 16)
kernel = 5         # recorded in the experiment config
max_epochs = 250

# --- Patch sampling and batching ---
width = 64         # patch edge length in grid cells (used for both y and x)
overlap = 26       # overlap between neighbouring patches (both y and x)
batch_size = 512
num_workers = 2    # parallel loading; 24 was tried, apparently without much benefit
input_dims = dict(y=width, x=width, n=1)
input_overlap = dict(y=overlap, x=overlap)

# --- Learning-rate schedule ---
learning_rate = "NONE"  # placeholder for the config; .01 and .0001 were tried as fixed rates
base_lr = .0001
max_lr = .01
step_size = 12040       # 5 and 25 were tried as well
gamma = "NONE"          # placeholder for the config; .1 and .5 were tried with a step schedule

# --- Loss weighting ---
lambda_val = 2

# --- Descriptive labels stored in the experiment config ---
phase = "two_phase"
loss_func = "F.mse_loss"
act_func_choice = "nn.Tanh"
optimizer_choice = "AdamW"  # alternative: SGD

In [11]:
# Arguments shared by the training and validation loaders.
shared_args = (in_vars, out_vars, input_dims, input_overlap, batch_size)
shared_kwargs = dict(num_workers=num_workers, dtype=DTYPE, augment_bool=True)

# Training loader.
dl = create_dataloader([norm_ksat_ds_train], *shared_args, **shared_kwargs)

# Validation loader.
# NOTE(review): augmentation is also enabled here -- confirm this is intended
# for validation data.
vdl = create_dataloader([norm_ksat_ds_valid], *shared_args, **shared_kwargs)

In [12]:
# Test loader: same patching and batching as the other loaders, with
# augment_bool left at the helper's default.
test_datasets = [norm_ksat_ds_test]
tedl = create_dataloader(
    test_datasets, in_vars, out_vars, input_dims, input_overlap, batch_size,
    dtype=DTYPE, num_workers=num_workers,
)

In [13]:
# Optional sanity check: pull one batch to confirm the dataloaders work.
# x, y = next(iter(dl))
# x.shape
# y.shape

In [14]:
# Build the UNet (Tanh activation, which can suit regression problems better)
# and move it to the configured dtype, then to the configured device.
model = UNet(in_vars, out_vars, activation=nn.Tanh, base_channels=bc).to(DTYPE).to(DEVICE)

# Loss callable handed to train_epoch, plus the per-epoch loss histories.
loss_fun = mse_loss
train_loss, valid_loss = [], []

# AdamW optimizer. (Earlier note: SGD may do better in the long run and avoid
# getting stuck in local minima, but takes much longer.)
opt = torch.optim.AdamW(model.parameters(), lr=max_lr)

# Schedulers tried previously:
#scheduler = lr_scheduler.StepLR(opt, step_size=step_size, gamma=gamma)
#scheduler = lr_scheduler.OneCycleLR(opt, max_lr=max_lr, epochs = max_epochs, steps_per_epoch = batch_size, three_phase = False)
# Cyclic schedule between base_lr and max_lr with equal ramp-up and ramp-down
# lengths (earlier note: 302 training iterations).
scheduler = lr_scheduler.CyclicLR(
    opt,
    base_lr=base_lr,
    max_lr=max_lr,
    step_size_up=step_size,
    step_size_down=step_size,
    last_epoch=-1,
    cycle_momentum=False,
)

In [15]:
# Either start a fresh experiment record or resume from a previously saved one.
resume_from_yaml = (yaml_flag == 1)

if not resume_from_yaml:
    # Fresh run: describe the data pipeline...
    data_config = {
        "dtype": DTYPE_NAME,
        "random_seed": seed,
        "train_inds": train_inds_str,
        "valid_inds": valid_inds_str,
        "test_inds": test_inds_str,
        "input_vars": in_vars,
        "output_vars": out_vars,
        "width": width,
        "overlap": overlap,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "augment_bool": "True",
        "unet_layers": unet_layers,
    }
    # ...and the model / optimizer settings.
    model_config = {
        "max_epochs": max_epochs,
        "activation_func": act_func_choice,
        "loss_func": loss_func,
        "optimizer": optimizer_choice,
        "input_size": len(in_vars),
        "output_size": len(out_vars),
        "base_channels": bc,
        "learning_rate": learning_rate,
        "max_lr": max_lr,
        "base_lr": base_lr,
        "three-phase": phase,
        "step_size": step_size,
        "gamma": gamma,
        "lambda": lambda_val,
        "kernel": kernel,
    }
    experiment_config = {"data_config": data_config, "model_config": model_config}

    # Write the initial record; no weights, loss, or metrics exist yet.
    save_experiment(
        config=experiment_config,
        output_dir=write_dir,
        name=model_run_name,
        model=None,
        valid_loss=None,
        metrics=None,
    )
    # Any first validation loss will count as an improvement.
    min_valid_loss = np.inf
else:
    # Resume: reload the saved record, restore the weights it points to, and
    # continue from its recorded loss.
    experiment_config = load_experiment(f"{write_dir}/{model_run_name}.yml")
    saved_state = torch.load(experiment_config["weights_file"], map_location=torch.device(DEVICE))
    model.load_state_dict(saved_state)
    min_valid_loss = experiment_config["last_loss"]

In [None]:
%%capture
# Main training loop: one training pass and one validation pass per epoch,
# metrics written every epoch, weights written whenever validation improves.
for e in (bar := tqdm(range(max_epochs))):
    print(f"starting epoch {e}")

    model.train()
    train_loss.append(train_epoch(model, dl, opt, loss_fun, lmbda=lambda_val, func=F.mse_loss, train=True, device=DEVICE))

    model.eval()
    valid_loss.append(train_epoch(model, vdl, opt, loss_fun, lmbda=lambda_val, func=F.mse_loss, train=False, device=DEVICE))

    bar.set_description(f'{train_loss[-1]:.4f}')

    # NOTE(review): CyclicLR is documented to be stepped after every batch, but
    # here it advances once per epoch. With step_size much larger than
    # max_epochs the learning rate only covers a small part of the ramp.
    # Confirm this is intended, or step the scheduler inside train_epoch.
    scheduler.step()

    # Persist the full loss history so far.
    losses = {'train_loss': train_loss, 'valid_loss': valid_loss}
    loss_df = pd.DataFrame.from_dict(losses)
    save_experiment(config=experiment_config, output_dir=write_dir, name=model_run_name, metrics=loss_df, mode=mode, header=header)

    # Compare the loss just computed. Indexing with the epoch number
    # (valid_loss[e]) reads a stale entry if this cell is re-run without
    # resetting the loss lists.
    current_valid_loss = valid_loss[-1]
    if min_valid_loss > current_valid_loss:
        print(f'Validation Loss Decreased({min_valid_loss:.6f}--->{current_valid_loss:.6f}) \t Saving The Model')
        min_valid_loss = current_valid_loss
        print(min_valid_loss)
        save_experiment(config=experiment_config, output_dir=write_dir, name=model_run_name, model=model, valid_loss=min_valid_loss)

In [None]:
# Run the model on the first batch from the test loader.
it = iter(tedl)
x, y = next(it)
x = x.to(DEVICE)

# Use inference mode explicitly: the model may still be in training mode if the
# training cell was interrupted, and no autograd graph is needed here.
model.eval()
with torch.no_grad():
    yy = model(x).squeeze().cpu().numpy()

# Keep the target in the same form as the prediction (NumPy array on the host).
y = y.squeeze().detach().cpu().numpy()

In [None]:
# Side-by-side comparison of target, prediction, and their difference for one
# sample of the test batch.
fig, axes = plt.subplots(1, 3, dpi=200, sharex=True, sharey=True)

b = 14  # index of the sample within the batch to display
# np.asarray accepts both host tensors and ndarrays, so the arithmetic below
# never mixes the two types.
truth = np.asarray(y[b])
pred = np.asarray(yy[b])

# Shared colour limits so the first two panels are directly comparable.
vmin = float(truth.min())
vmax = float(truth.max())

m = axes[0].imshow(truth, vmin=vmin, vmax=vmax)
axes[0].set_title('ParFlow')
axes[0].set_xticks([])
axes[0].set_yticks([])
#plt.colorbar(m, shrink=0.3)

m = axes[1].imshow(pred, vmin=vmin, vmax=vmax)
axes[1].set_title('UNet')
# Name the parent axes explicitly: the implicit choice made by plt.colorbar
# has differed between matplotlib versions.
fig.colorbar(m, ax=axes[1], shrink=0.3)

m = axes[2].imshow(truth - pred, vmin=-1, vmax=1, cmap='Spectral')
axes[2].set_title('Difference')
fig.colorbar(m, ax=axes[2], shrink=0.3);
#plt.tight_layout()