In [1]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import matplotlib.pyplot as plt
from torch.utils.data.dataset import TensorDataset, Dataset
from torch.utils.data.dataloader import DataLoader

import operator
from functools import reduce
from functools import partial
from timeit import default_timer
from utilities3 import *

from adam import Adam

import tqdm
import yaml
import pandas as pd
from geckoml.data import load_data, transform_data, inv_transform_preds
from geckoml.metrics import ensembled_metrics
from collections import defaultdict

In [2]:
torch.manual_seed(0)
np.random.seed(0)

In [3]:
is_cuda = torch.cuda.is_available()
device = torch.device(torch.cuda.current_device()) if is_cuda else torch.device("cpu")
if is_cuda:
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

In [4]:
config_file = "/glade/work/schreck/repos/GECKO_OPT/dev/gecko-ml/config/toluene_agg.yml"

In [5]:
with open(config_file) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [140]:
species = conf['species']
data_path = conf['dir_path']
aggregate_bins = conf['aggregate_bins']
input_vars = conf['input_vars']
output_vars = conf['output_vars']
tendency_cols = conf['tendency_cols']
log_trans_cols = conf['log_trans_cols']
output_path = "./"
scaler_type = conf['scaler_type']
ensemble_members = conf["ensemble_members"]
seed = conf['random_seed']

# Get the shapes of the input and output data 
input_size = len(input_vars)
output_size = len(output_vars)

start_time = 0
num_timesteps = 1439
batch_size = 16

L1_penalty = 1.39e-5
L2_penalty = 3.49e-4

lr_patience = 3
stopping_patience = 5
learning_rate = 1e-3

In [40]:
data = load_data(data_path, aggregate_bins, species, input_vars, output_vars, log_trans_cols)
    
transformed_data, x_scaler, y_scaler = transform_data(
    data, 
    output_path, 
    species, 
    tendency_cols, 
    log_trans_cols,
    scaler_type, 
    output_vars, 
    train=True
)

# Batch the training data by experiment
train_in_array = transformed_data['train_in'].copy()
n_exps = len(train_in_array.index.unique(level='id'))
n_timesteps = len(train_in_array.index.unique(level='Time [s]'))
n_features = len(input_vars)
out_col_idx = train_in_array.columns.get_indexer(output_vars)
train_in_array = train_in_array.values.reshape(n_exps, n_timesteps, n_features)

# Batch the validation data by experiment
val_in_array = transformed_data['val_in'].copy()
n_exps = len(val_in_array.index.unique(level='id'))
n_timesteps = len(val_in_array.index.unique(level='Time [s]'))
val_out_col_idx = val_in_array.columns.get_indexer(output_vars)
val_in_array = val_in_array.values.reshape(n_exps, n_timesteps, n_features)

train_out_array = transformed_data['train_out'].copy()
n_exps = len(train_out_array.index.unique(level='id'))
n_timesteps = len(train_out_array.index.unique(level='Time [s]'))
n_features = len(output_vars)
out_col_idx = train_out_array.columns.get_indexer(output_vars)
train_out_array = train_out_array.values.reshape(n_exps, n_timesteps, n_features)

val_out_array = transformed_data['val_out'].copy()
n_exps = len(val_out_array.index.unique(level='id'))
n_timesteps = len(val_out_array.index.unique(level='Time [s]'))
val_out_col_idx = val_out_array.columns.get_indexer(output_vars)
val_out_array = val_out_array.values.reshape(n_exps, n_timesteps, n_features)

In [141]:
# train_data = TensorDataset(
#     torch.from_numpy(transformed_data["train_in"].copy().values).float(),
#     torch.from_numpy(transformed_data["train_out"].copy().values).float()
# )
train_data = TensorDataset(
    torch.from_numpy(train_in_array).float(),
    torch.from_numpy(train_out_array).float()
)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=0)
train_iter = iter(train_loader)

# valid_data = TensorDataset(
#     torch.from_numpy(transformed_data["val_in"].copy().values).float(),
#     torch.from_numpy(transformed_data["val_out"].copy().values).float()
# )
valid_data = TensorDataset(
    torch.from_numpy(val_in_array).float(),
    torch.from_numpy(val_out_array).float()
)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=0)
valid_iter = iter(valid_loader)

In [142]:
class SpectralConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, modes1):
        super(SpectralConv1d, self).__init__()

        """
        1D Fourier layer. It does FFT, linear transform, and Inverse FFT.    
        """

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.modes1 = modes1  #Number of Fourier modes to multiply, at most floor(N/2) + 1

        self.scale = (1 / (in_channels*out_channels))
        self.weights1 = nn.Parameter(self.scale * torch.rand(in_channels, out_channels, self.modes1, dtype=torch.cfloat))

    # Complex multiplication
    def compl_mul1d(self, input, weights):
        # (batch, in_channel, x ), (in_channel, out_channel, x) -> (batch, out_channel, x)
        return torch.einsum("bix,iox->box", input, weights)

    def forward(self, x):
        batchsize = x.shape[0]
        #Compute Fourier coeffcients up to factor of e^(- something constant)
        x_ft = torch.fft.rfft(x)

        # Multiply relevant Fourier modes
        out_ft = torch.zeros(batchsize, self.out_channels, x.size(-1)//2 + 1,  device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :self.modes1] = self.compl_mul1d(x_ft[:, :, :self.modes1], self.weights1)

        #Return to physical space
        x = torch.fft.irfft(out_ft, n=x.size(-1))
        return x

In [143]:
class FNO1d(nn.Module):
    def __init__(self, modes, width, resize_input = 1024):
        super(FNO1d, self).__init__()

        """
        The overall network. It contains 4 layers of the Fourier layer.
        1. Lift the input to the desire channel dimension by self.fc0 .
        2. 4 layers of the integral operators u' = (W + K)(u).
            W defined by self.w; K defined by self.conv .
        3. Project from the channel space to the output space by self.fc1 and self.fc2 .
        
        input: the solution of the initial condition and location (a(x), x)
        input shape: (batchsize, x=s, c=2)
        output: the solution of a later timestep
        output shape: (batchsize, x=s, c=1)
        """

        self.modes1 = modes
        self.width = width
        self.padding = 2 # pad the domain if input is non-periodic
        self.fc0 = nn.Linear(10, self.width) # input channel is 2: (a(x), x)

        self.conv0 = SpectralConv1d(self.width, self.width, self.modes1)
        self.conv1 = SpectralConv1d(self.width, self.width, self.modes1)
        self.conv2 = SpectralConv1d(self.width, self.width, self.modes1)
        self.conv3 = SpectralConv1d(self.width, self.width, self.modes1)
        self.w0 = nn.Conv1d(self.width, self.width, 1)
        self.w1 = nn.Conv1d(self.width, self.width, 1)
        self.w2 = nn.Conv1d(self.width, self.width, 1)
        self.w3 = nn.Conv1d(self.width, self.width, 1)

        self.fc1 = nn.Linear(self.width, 128)
        self.fc2 = nn.Linear(128, 3)
        
        self.enlarge = nn.Linear(9, resize_input)
        self.squish = nn.Linear(resize_input, 3)

    def forward(self, x):
        #x = self.enlarge(x).unsqueeze(-1)
        grid = self.get_grid(x.shape, x.device)
        x = torch.cat((x, grid), dim=-1)
        x = self.fc0(x)
        x = x.permute(0, 2, 1)
        # x = F.pad(x, [0,self.padding]) # pad the domain if input is non-periodic

        x1 = self.conv0(x)
        x2 = self.w0(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv1(x)
        x2 = self.w1(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv2(x)
        x2 = self.w2(x)
        x = x1 + x2
        x = F.gelu(x)

        x1 = self.conv3(x)
        x2 = self.w3(x)
        x = x1 + x2

        # x = x[..., :-self.padding] # pad the domain if input is non-periodic
        x = x.permute(0, 2, 1)
        x = self.fc1(x)
        x = F.gelu(x)
        x = self.fc2(x)
        #x = self.squish(x.reshape(x.shape[0], x.shape[1]))
        
        return x

    def get_grid(self, shape, device):
        batchsize, size_x = shape[0], shape[1]
        gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float)
        gridx = gridx.reshape(1, size_x, 1).repeat([batchsize, 1, 1])
        return gridx.to(device)

In [144]:
# ntrain = 1000
# ntest = 100

# sub = 2**3 #subsampling rate
# h = 2**13 // sub #total grid size divided by the subsampling rate
# s = h

# batch_size = 20
# learning_rate = 0.001

# epochs = 500
# step_size = 50
# gamma = 0.5

modes = 16
width = 64

In [145]:
# dataloader = MatReader('data/burgers_data_R10.mat')
# x_data = dataloader.read_field('a')[:,::sub]
# y_data = dataloader.read_field('u')[:,::sub]

# x_train = x_data[:ntrain,:]
# y_train = y_data[:ntrain,:]
# x_test = x_data[-ntest:,:]
# y_test = y_data[-ntest:,:]

# x_train = x_train.reshape(ntrain, s, 1)
# x_test = x_test.reshape(ntest, s, 1)

In [146]:
model = FNO1d(modes, width, resize_input = 16).to(device)
print(count_params(model))

550550


In [147]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay = L2_penalty)

lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, 
    T_max=20, 
    eta_min=1e-2*learning_rate
)

# lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#         optimizer, 
#         patience = lr_patience, 
#         verbose = True,
#         min_lr = 1.0e-13
# )

In [148]:
myloss = LpLoss(size_average=True)

In [None]:
results_dict = defaultdict(list)

for epoch in range(200):
    
    # Train in batch mode
    fiter = list(range(transformed_data['train_in'].shape[0] // batch_size)) #
    #fiter = tqdm.tqdm(range(transformed_data['train_in'].shape[0] // batch_size), leave=True)
    model.train()
    
    train_loss = []
    train_mae = []
    for k, t in enumerate(fiter):

        try:
            x, y = next(train_iter)
        except StopIteration:
            train_iter = iter(train_loader)
            x, y = next(train_iter)

        y_pred = model(x.to(device))
        mae_loss = torch.nn.functional.l1_loss(y.to(device), y_pred)
        loss = myloss(y.to(device), y_pred)
        #l1_norm = sum(p.abs().sum() for p in model.parameters()).cpu()
        #loss += L1_penalty * l1_norm
        train_loss.append(loss.item())
        train_mae.append(mae_loss.item())

        #fiter.set_description(f"loss {np.mean(train_loss):.4f}")
        #fiter.update()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if (k + 1) % 1000:
            break
        
        #lr_scheduler.step()
        
    # Validate 
    model.eval()
    with torch.no_grad():
        
        # Validate in batch mode
        valid_loss = []
        valid_mae = []
        fiter = list(range(transformed_data['val_in'].shape[0] // batch_size))
        for t in fiter:

            try:
                x, y = next(valid_iter)
            except StopIteration:
                valid_iter = iter(valid_loader)
                x, y = next(valid_iter)
            y_pred = model(x.to(device))
            mae_loss = torch.nn.functional.l1_loss(y.to(device), y_pred)
            loss = myloss(y.to(device), y_pred)
            valid_loss.append(loss.item())
            valid_mae.append(mae_loss.item())
            
            lr_scheduler.step()
        
#         # Validate in box mode
#         box_loss = []
#         box_nll = []
#         box_mae = []
        
#         # set up array for saving predicted results
#         _in_array = torch.from_numpy(val_in_array).float()#.to(device).float()
#         pred_array = np.empty((val_in_array.shape[0], num_timesteps-start_time, len(out_col_idx)))

#         # use initial condition @ t = start_time and get the first prediction
#         print(_in_array[:, start_time, :].shape)
#         output = model(_in_array[:, start_time, :].to(device))
#         pred_array[:, 0, :] = output.cpu().numpy()
#         loss = torch.nn.functional.l1_loss(_in_array[:, start_time + 1, out_col_idx], output.cpu()).item()
#         ev_loss = myloss(_in_array[:, start_time + 1, out_col_idx], output.cpu())
        
#         box_loss.append(loss)
#         box_nll.append(ev_loss)

#         # use the first prediction to get the next, and so on for num_timesteps
#         for k, i in enumerate(range(start_time + 1, num_timesteps)): 
#             new_input = _in_array[:, i, :]
#             new_input[:, out_col_idx] = output.cpu()
#             output = model(new_input.to(device))
#             pred_array[:, k+1, :] = output.cpu().numpy()
#             if i < (num_timesteps-1):
#                 loss = torch.nn.functional.l1_loss(_in_array[:, i+1, out_col_idx], output.cpu()).item()
#                 ev_loss = myloss(
#                     _in_array[:, start_time + 1, out_col_idx],
#                     output.cpu()
#                 )
#                 box_loss.append(loss)
#                 box_nll.append(ev_loss)
                
#         idx = transformed_data["val_out"].index
#         start_time_units = sorted(list(set([x[0] for x in idx])))[start_time]
#         start_time_condition = [(x[0] >= start_time_units) for x in idx]
#         idx = transformed_data["val_out"][start_time_condition].index

#         raw_box_preds = pd.DataFrame(
#             data=pred_array.reshape(-1, len(output_vars)),
#             columns=output_vars, 
#             index=idx
#         )

#         # inverse transform 
#         truth, preds = inv_transform_preds(
#             raw_preds=raw_box_preds,
#             truth=data['val_out'][start_time_condition],
#             y_scaler=y_scaler,
#             log_trans_cols=log_trans_cols,
#             tendency_cols=tendency_cols)
                
#         metrics = ensembled_metrics(y_true=truth,
#                                     y_pred=preds,
#                                     member=0,
#                                     output_vars=output_vars,
#                                     stability_thresh=1.0)
#         mean_box_mae = metrics['mean_mae'].mean()
#         unstable_exps = metrics['n_unstable'].mean()
#         box_mae.append(mean_box_mae)
        
    results_dict["epoch"].append(epoch)
    results_dict["train_loss"].append(np.mean(train_loss))
    results_dict["train_mae"].append(np.mean(train_mae))
    results_dict["val_loss"].append(np.mean(valid_loss))
    results_dict["val_mae"].append(np.mean(valid_mae))
#     results_dict["box_loss"].append(np.mean(box_nll))
#     results_dict["box_step_mae"].append(np.mean(box_loss))
#     results_dict["box_mae"].append(np.mean(box_mae))
#     results_dict["n_unstable"].append(unstable_exps)
    results_dict["lr"].append(optimizer.param_groups[0]['lr'])
    
    # Save the dataframe to disk
    df = pd.DataFrame.from_dict(results_dict).reset_index()
    df.to_csv(f"gecko/training_log.csv", index = False)
    
    print(f'Epoch {epoch} train_loss {results_dict["train_loss"][-1]:4f}',
          f'train_mae {results_dict["train_mae"][-1]:4f}',
          f'val_loss {results_dict["val_loss"][-1]:4f}',
          f'val_mae {results_dict["val_mae"][-1]:4f}',
#           f'box_loss {results_dict["box_loss"][-1]:4f}',
#           f'box_step_mae {results_dict["box_step_mae"][-1]:4f}',
#           f'box_mae {results_dict["box_mae"][-1]:4f}',
#           f'n_unstable {int(results_dict["n_unstable"][-1])}',
          f'lr {results_dict["lr"][-1]}'
         )

    # anneal the learning rate using just the box metric
    #lr_scheduler.step(results_dict["val_loss"][-1])
    
    if results_dict["val_loss"][-1] == min(results_dict["val_loss"]):
        state_dict = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': min(results_dict["val_loss"])
        }
        torch.save(state_dict, f"gecko/fno.pt")
    
    # Stop training if we have not improved after X epochs
    best_epoch = [i for i,j in enumerate(results_dict["val_loss"]) if j == min(results_dict["val_loss"])][0]
    offset = epoch - best_epoch
    if offset >= stopping_patience:
        break

Epoch 0 train_loss 22.417957 train_mae 0.792602 val_loss 16.176359 val_mae 0.794645 lr 0.0002802747026301318
Epoch 1 train_loss 16.924126 train_mae 0.791155 val_loss 15.005882 val_mae 0.794892 lr 0.00021404630011609955
Epoch 2 train_loss 13.265509 train_mae 0.705045 val_loss 14.250486 val_mae 0.795082 lr 0.0009939057286000041
Epoch 3 train_loss 13.613236 train_mae 0.754310 val_loss 11.543204 val_mae 0.795917 lr 0.00035203658779054726
Epoch 4 train_loss 10.668476 train_mae 0.747207 val_loss 10.821776 val_mae 0.796190 lr 0.00015498214331173107
Epoch 5 train_loss 11.307106 train_mae 0.849139 val_loss 10.533316 val_mae 0.796341 lr 0.0009757729755740051
Epoch 6 train_loss 9.613633 train_mae 0.739867 val_loss 9.000327 val_mae 0.797142 lr 0.00042756493979441185
Epoch 7 train_loss 7.795693 train_mae 0.702892 val_loss 8.451453 val_mae 0.797524 lr 0.00010453658778894987
Epoch 8 train_loss 8.265524 train_mae 0.796532 val_loss 8.326717 val_mae 0.797596 lr 0.0009460482294830313
Epoch 9 train_loss 8

In [None]:
# model.eval()
# y_pred = model(torch.from_numpy(val_in_array).float().to(device))