In [1]:
import warnings
warnings.filterwarnings("ignore")

In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import matplotlib.pyplot as plt
from torch.utils.data.dataset import TensorDataset, Dataset
from torch.utils.data.dataloader import DataLoader

import operator
from functools import reduce
from functools import partial
from timeit import default_timer

import tqdm
import yaml
import pandas as pd
from geckoml.data import load_data, transform_data, inv_transform_preds
from geckoml.metrics import ensembled_metrics
from collections import defaultdict

from geckoml.box import rnn_box_test
from tensorflow.keras.models import load_model
import tensorflow as tf
from tensorflow.keras.layers import LeakyReLU

In [3]:
torch.manual_seed(0)
np.random.seed(0)

In [4]:
is_cuda = torch.cuda.is_available()
device = torch.device(torch.cuda.current_device()) if is_cuda else torch.device("cpu")
if is_cuda:
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True

In [5]:
config_file = "/glade/work/schreck/repos/GECKO_OPT/dev/gecko-ml/config/toluene_agg.yml"

In [6]:
with open(config_file) as cf:
    conf = yaml.load(cf, Loader=yaml.FullLoader)

In [62]:
species = conf['species']
data_path = conf['dir_path']
aggregate_bins = conf['aggregate_bins']
input_vars = conf['input_vars']
output_vars = conf['output_vars']
tendency_cols = conf['tendency_cols']
log_trans_cols = conf['log_trans_cols']
output_path = "./"
scaler_type = conf['scaler_type']
ensemble_members = conf["ensemble_members"]
seed = conf['random_seed']

# Get the shapes of the input and output data 
input_size = len(input_vars)
output_size = len(output_vars)

start_time = 0
num_timesteps = 1439
batch_size = 256

L1_penalty = 1.39e-5
L2_penalty = 3.49e-4

lr_patience = 3
stopping_patience = 5
learning_rate = 1.39e-6

In [8]:
output_vars

['Precursor [ug/m3]', 'Gas [ug/m3]', 'Aerosol [ug_m3]']

In [9]:
data = load_data(data_path, aggregate_bins, species, input_vars, output_vars, log_trans_cols)
    
transformed_data, x_scaler, y_scaler = transform_data(
    data, 
    output_path, 
    species, 
    tendency_cols, 
    log_trans_cols,
    scaler_type, 
    output_vars, 
    train=True
)

# Batch the training data by experiment
train_in_array = transformed_data['train_in'].copy()
n_exps = len(train_in_array.index.unique(level='id'))
n_timesteps = len(train_in_array.index.unique(level='Time [s]'))
n_features = len(input_vars)
out_col_idx = train_in_array.columns.get_indexer(output_vars)
train_in_array = train_in_array.values.reshape(n_exps, n_timesteps, n_features)

# Batch the validation data by experiment
val_in_array = transformed_data['val_in'].copy()
n_exps = len(val_in_array.index.unique(level='id'))
n_timesteps = len(val_in_array.index.unique(level='Time [s]'))
val_out_col_idx = val_in_array.columns.get_indexer(output_vars)
val_in_array = val_in_array.values.reshape(n_exps, n_timesteps, n_features)

train_out_array = transformed_data['train_out'].copy()
n_exps = len(train_out_array.index.unique(level='id'))
n_timesteps = len(train_out_array.index.unique(level='Time [s]'))
n_features = len(output_vars)
out_col_idx = train_out_array.columns.get_indexer(output_vars)
train_out_array = train_out_array.values.reshape(n_exps, n_timesteps, n_features)

val_out_array = transformed_data['val_out'].copy()
n_exps = len(val_out_array.index.unique(level='id'))
n_timesteps = len(val_out_array.index.unique(level='Time [s]'))
val_out_col_idx = val_out_array.columns.get_indexer(output_vars)
val_out_array = val_out_array.values.reshape(n_exps, n_timesteps, n_features)

In [63]:
train_data = TensorDataset(
    torch.from_numpy(transformed_data["train_in"].copy().values).float(),
    torch.from_numpy(transformed_data["train_out"].copy().values).float()
)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=8)

valid_data = TensorDataset(
    torch.from_numpy(transformed_data["val_in"].copy().values).float(),
    torch.from_numpy(transformed_data["val_out"].copy().values).float()
)
valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, num_workers=0)

In [11]:
def initialize_weights(m):
    if type(m) in [nn.Linear]:
        nn.init.xavier_uniform_(m.weight)
        #nn.init.xavier_uniform_(model.bias)
        nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d or nn.BatchNorm1d):
        m.weight.data.fill_(1)
        m.bias.data.zero_()

In [12]:
input_size = 9
middle_size = 4902
output_size = 3
dr = 0.0

In [89]:
model = nn.Sequential(
        nn.Linear(input_size, middle_size),
        nn.ReLU(),
        nn.Linear(middle_size, output_size)
)
        
model = model.apply(initialize_weights).to(device)

In [90]:
sum(p.numel() for p in model.parameters() if p.requires_grad)

63729

In [23]:
# keras_model = load_model(
#     '/glade/work/schreck/repos/GECKO_OPT/dev/gecko-ml/results/mlp/no_prec/toluene/models/toluene_DNN_0/',
#     compile = False
# )
# keras_model.compile(metrics = ["mae"], loss = ["mae"])
# keras_model.summary()
# weights = keras_model.get_weights()

# for layer in keras_model.layers:
#     print(layer.get_config())

In [24]:
# model[0].weight.data=torch.from_numpy(np.transpose(weights[0]))
# model[0].bias.data=torch.from_numpy(weights[1])
# model[1].weight.data=torch.from_numpy(np.transpose(weights[2]))
# model[1].bias.data=torch.from_numpy(weights[3])

# model = model.to(device)

In [25]:
optimizer = torch.optim.Adam(model.parameters(),
                             lr = learning_rate,
                             eps = 1e-7,
                             betas = (0.9, 0.999),
                             amsgrad = False)#,
                             #weight_decay = L2_penalty)

# lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#     optimizer, 
#     T_max=10, 
#     eta_min=1e-3*learning_rate
# )

# lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
#         optimizer, 
#         patience = lr_patience, 
#         verbose = True,
#         min_lr = 1.0e-13
# )

In [26]:
results_dict = defaultdict(list)

for epoch in range(500):
    
    # Train in batch mode
    model.train()
    scaler = torch.cuda.amp.GradScaler()
    
    train_loss = []
    for k, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        
        with torch.cuda.amp.autocast():
            loss = nn.L1Loss()(y.to(device), model(x.to(device)))
        
        l1_norm = sum(p.abs().sum() for p in model.parameters())
        l2_norm = sum(p.pow(2.0).sum() for p in model.parameters())
        
        loss += L1_penalty * l1_norm
        loss += L2_penalty * l2_norm
        
        train_loss.append(loss.item())
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        #loss.backward()
        #optimizer.step()
        
        #lr_scheduler.step()
        
    # Validate 
    model.eval()
    with torch.no_grad():
        
        # Validate in batch mode
        valid_loss = []
        for k, (x, y) in enumerate(valid_loader):
            loss = nn.L1Loss()(y.to(device), model(x.to(device)))
            valid_loss.append(loss.item())

        # Validate in box mode
        box_loss = []
        
        # set up array for saving predicted results
        _in_array = torch.from_numpy(val_in_array.copy()).float()
        pred_array = np.empty((val_in_array.shape[0], num_timesteps-start_time, len(out_col_idx)))

        # use initial condition @ t = start_time and get the first prediction
        gamma = model(_in_array[:, start_time, :].to(device))
        pred_array[:, 0, :] = gamma.cpu().numpy()
        loss = nn.L1Loss()(_in_array[:, start_time + 1, out_col_idx], gamma.cpu()).item()
        box_loss.append(loss)

        # use the first prediction to get the next, and so on for num_timesteps
        for k, i in enumerate(range(start_time + 1, num_timesteps)): 
            new_input = _in_array[:, i, :]
            new_input[:, out_col_idx] = gamma.cpu()
            gamma = model(new_input.to(device))
            pred_array[:, k+1, :] = gamma.cpu().numpy()
            if i < (num_timesteps-1):
                loss = nn.L1Loss()(_in_array[:, i+1, out_col_idx], gamma.cpu()).item()
                box_loss.append(loss)
                
        idx = transformed_data["val_out"].index
        start_time_units = sorted(list(set([x[0] for x in idx])))[start_time]
        start_time_condition = [(x[0] >= start_time_units) for x in idx]
        idx = transformed_data["val_out"][start_time_condition].index

        raw_box_preds = pd.DataFrame(
            data=pred_array.reshape(-1, len(output_vars)),
            columns=output_vars, 
            index=idx
        )

        # inverse transform 
        truth, preds = inv_transform_preds(
            raw_preds=raw_box_preds,
            truth=data['val_out'][start_time_condition],
            y_scaler=y_scaler,
            log_trans_cols=log_trans_cols,
            tendency_cols=tendency_cols)
                
        metrics = ensembled_metrics(y_true=truth,
                                    y_pred=preds,
                                    member=0,
                                    output_vars=output_vars,
                                    stability_thresh=1.0)
        mean_box_mae = metrics['mean_mae'].mean()
        unstable_exps = metrics['n_unstable'].mean()
        
    results_dict["epoch"].append(epoch)
    results_dict["train_loss"].append(np.mean(train_loss))
    results_dict["val_loss"].append(np.mean(valid_loss))
    results_dict["step_loss"].append(np.mean(box_loss))
    results_dict["box_mae"].append(mean_box_mae)
    results_dict["n_unstable"].append(unstable_exps)
    results_dict["lr"].append(optimizer.param_groups[0]['lr'])
    
    # Save the dataframe to disk
    df = pd.DataFrame.from_dict(results_dict).reset_index()
    df.to_csv(f"gecko/training_log_01.csv", index = False)
    
    print(f'Epoch {epoch}',
          f'train_loss {results_dict["train_loss"][-1]:4f}', 
          f'val_loss {results_dict["val_loss"][-1]:4f}',
          f'step_loss {results_dict["step_loss"][-1]:4f}',
          f'box_mae {results_dict["box_mae"][-1]:4f}',
          f'n_unstable {int(results_dict["n_unstable"][-1])}',
          f'lr {results_dict["lr"][-1]}'
         )

    #anneal the learning rate using just the box metric
    #lr_scheduler.step()
    
    if results_dict["box_mae"][-1] == min(results_dict["box_mae"]):
        state_dict = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': min(results_dict["box_mae"])
        }
        torch.save(state_dict, f"gecko/mlp_01.pt")
    
    # Stop training if we have not improved after X epochs
#     best_epoch = [i for i,j in enumerate(results_dict["box_mae"]) if j == min(results_dict["box_mae"])][0]
#     offset = epoch - best_epoch
#     if offset >= stopping_patience:
#         break

Epoch 0 train_loss 0.583948 val_loss 0.378858 step_loss 0.626657 box_mae 0.003039 n_unstable 0 lr 1.39e-05
Epoch 1 train_loss 0.262337 val_loss 0.134695 step_loss 0.566282 box_mae 0.002775 n_unstable 0 lr 1.39e-05
Epoch 2 train_loss 0.096616 val_loss 0.035953 step_loss 0.664230 box_mae 0.003119 n_unstable 0 lr 1.39e-05
Epoch 3 train_loss 0.049356 val_loss 0.016455 step_loss 0.725331 box_mae 0.005137 n_unstable 3 lr 1.39e-05
Epoch 4 train_loss 0.038472 val_loss 0.011553 step_loss 0.705827 box_mae 0.005379 n_unstable 4 lr 1.39e-05
Epoch 5 train_loss 0.034111 val_loss 0.008759 step_loss 0.694561 box_mae 0.004965 n_unstable 0 lr 1.39e-05
Epoch 6 train_loss 0.031321 val_loss 0.006981 step_loss 0.660722 box_mae 0.004523 n_unstable 0 lr 1.39e-05
Epoch 7 train_loss 0.029332 val_loss 0.005727 step_loss 0.652781 box_mae 0.004603 n_unstable 0 lr 1.39e-05
Epoch 8 train_loss 0.027844 val_loss 0.004874 step_loss 0.640886 box_mae 0.004593 n_unstable 0 lr 1.39e-05
Epoch 9 train_loss 0.026693 val_loss 

##### Train keras model for comparison

In [32]:
#from keras.models import Sequential
#from keras.layers import Dense

In [33]:
# keras_model = Sequential()
# keras_model.add(Dense(middle_size, input_dim=input_size, activation="relu", 
#                       kernel_regularizer=tf.keras.regularizers.l1_l2(l1=L1_penalty, l2=L2_penalty)))
# keras_model.add(Dense(output_size, activation='linear'))

In [34]:
#keras_model.compile(loss='mae', optimizer='adam', metrics=['mae'])

In [35]:
# keras_model.fit(
#     transformed_data["train_in"],
#     transformed_data["train_out"],
#     validation_data = (transformed_data["val_in"], transformed_data["val_out"]),
#     batch_size = batch_size,
#     shuffle = True,
#     epochs = 841,
#     verbose = 2
# )

In [36]:
# # Validate in box mode
# box_loss = []

# # set up array for saving predicted results
# _in_array = val_in_array.copy()
# pred_array = np.empty((val_in_array.shape[0], num_timesteps-start_time, len(out_col_idx)))

# # use initial condition @ t = start_time and get the first prediction
# gamma = keras_model.predict(_in_array[:, start_time, :])
# pred_array[:, 0, :] = gamma

# # use the first prediction to get the next, and so on for num_timesteps
# for k, i in enumerate(range(start_time + 1, num_timesteps)): 
#     new_input = _in_array[:, i, :]
#     new_input[:, out_col_idx] = gamma
#     gamma = keras_model.predict(new_input)
#     pred_array[:, k+1, :] = gamma

# idx = transformed_data["val_out"].index
# start_time_units = sorted(list(set([x[0] for x in idx])))[start_time]
# start_time_condition = [(x[0] >= start_time_units) for x in idx]
# idx = transformed_data["val_out"][start_time_condition].index

# raw_box_preds = pd.DataFrame(
#     data=pred_array.reshape(-1, len(output_vars)),
#     columns=output_vars, 
#     index=idx
# )

# # inverse transform 
# truth, preds = inv_transform_preds(
#     raw_preds=raw_box_preds,
#     truth=data['val_out'][start_time_condition],
#     y_scaler=y_scaler,
#     log_trans_cols=log_trans_cols,
#     tendency_cols=tendency_cols)

# metrics = ensembled_metrics(y_true=truth,
#                             y_pred=preds,
#                             member=0,
#                             output_vars=output_vars,
#                             stability_thresh=1.0)
# mean_box_mae = metrics['mean_mae'].mean()
# unstable_exps = metrics['n_unstable'].mean()

# print(mean_box_mae)

### Can the VQ-VAE reconstruct the data, using FCL for encoder/decoder?

In [13]:
class Encoder(nn.Module):
    def __init__(self, input_size = 9, output_size = 100, fcl_layers = 1, dr = 0.0):
        super(Encoder, self).__init__()
        self.fcn = self.make_fcn(input_size, output_size, fcl_layers, dr)

    
    def make_fcn(self, input_size, output_size, fcl_layers, dr):
        if len(fcl_layers) > 0:
            fcn = [
                nn.Linear(input_size, fcl_layers[0]),
                #nn.BatchNorm1d(fcl_layers[0]),
                #nn.Dropout(dr),
                torch.nn.ReLU()
            ]
            if len(fcl_layers) == 1:
                fcn.append(nn.Linear(fcl_layers[0], output_size))
            else:
                for i in range(len(fcl_layers)-1):
                    fcn += [
                        nn.Linear(fcl_layers[i], fcl_layers[i+1]),
                        #nn.BatchNorm1d(fcl_layers[i+1]),
                        torch.nn.ReLU(),
                        #nn.Dropout(dr)
                    ]
                #fcn.append(nn.Linear(fcl_layers[i+1], output_size))
        else:
            fcn = [
                nn.Linear(input_size, output_size)
            ]
        return nn.Sequential(*fcn)
    
    def forward(self, x):
        x = self.fcn(x)
        return x

In [14]:
class Decoder(nn.Module):
    def __init__(self, input_size  = 100, output_size = 3, fcl_layers = 1, dr = 0.0):
        super(Decoder, self).__init__()
        self.fcn = self.make_fcn(input_size, output_size, fcl_layers, dr)

    def make_fcn(self, input_size, output_size, fcl_layers, dr):
        if len(fcl_layers) > 0:
            fcn = [
                nn.Linear(input_size, fcl_layers[0]),
                #nn.BatchNorm1d(fcl_layers[0]),
                torch.nn.ReLU(),
                #nn.Dropout(dr)
            ]
            if len(fcl_layers) == 1:
                fcn.append(nn.Linear(fcl_layers[0], output_size))
            else:
                for i in range(len(fcl_layers)-1):
                    fcn += [
                        nn.Linear(fcl_layers[i], fcl_layers[i+1]),
                        #nn.BatchNorm1d(fcl_layers[i+1]),
                        torch.nn.ReLU(),
                        #nn.Dropout(dr)
                    ]
                #fcn.append(nn.Linear(fcl_layers[i+1], output_size))
        else:
            fcn = [
                nn.Linear(input_size, output_size)
            ]
        return nn.Sequential(*fcn)
    
    def forward(self, x):
        x = self.fcn(x)
        return x

In [15]:
class VectorQuantizer(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost):
        super(VectorQuantizer, self).__init__()
        
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.uniform_(-1/self._num_embeddings, 1/self._num_embeddings)
        self._commitment_cost = commitment_cost

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        
        # Flatten input
        flat_input = inputs.view(-1, self._embedding_dim)
        
        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
            
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        
        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        
        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        q_latent_loss = F.mse_loss(quantized, inputs.detach())
        loss = q_latent_loss + self._commitment_cost * e_latent_loss
        
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        
        # convert quantized from BHWC -> BCHW
        return loss, quantized.permute(0, 3, 1, 2).contiguous(), perplexity, encodings
    

class VectorQuantizerEMA(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, commitment_cost, decay, epsilon=1e-5):
        super(VectorQuantizerEMA, self).__init__()
        
        self._embedding_dim = embedding_dim
        self._num_embeddings = num_embeddings
        
        self._embedding = nn.Embedding(self._num_embeddings, self._embedding_dim)
        self._embedding.weight.data.normal_()
        self._commitment_cost = commitment_cost
        
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self._ema_w = nn.Parameter(torch.Tensor(num_embeddings, self._embedding_dim))
        self._ema_w.data.normal_()
        
        self._decay = decay
        self._epsilon = epsilon

    def forward(self, inputs):
        # convert inputs from BCHW -> BHWC
        flat_input = inputs.contiguous()
        #inputs = inputs.permute(0, 2, 3, 1).contiguous()
        input_shape = inputs.shape
        
        # Flatten input
        #flat_input = inputs.view(-1, self._embedding_dim)
        
        # Calculate distances
        distances = (torch.sum(flat_input**2, dim=1, keepdim=True) 
                    + torch.sum(self._embedding.weight**2, dim=1)
                    - 2 * torch.matmul(flat_input, self._embedding.weight.t()))
            
        # Encoding
        encoding_indices = torch.argmin(distances, dim=1).unsqueeze(1)
        encodings = torch.zeros(encoding_indices.shape[0], self._num_embeddings, device=inputs.device)
        encodings.scatter_(1, encoding_indices, 1)
        
        # Quantize and unflatten
        quantized = torch.matmul(encodings, self._embedding.weight).view(input_shape)
        
        # Use EMA to update the embedding vectors
        if self.training:
            self._ema_cluster_size = self._ema_cluster_size * self._decay + \
                                     (1 - self._decay) * torch.sum(encodings, 0)
            
            # Laplace smoothing of the cluster size
            n = torch.sum(self._ema_cluster_size.data)
            self._ema_cluster_size = (
                (self._ema_cluster_size + self._epsilon)
                / (n + self._num_embeddings * self._epsilon) * n)
            
            dw = torch.matmul(encodings.t(), flat_input)
            self._ema_w = nn.Parameter(self._ema_w * self._decay + (1 - self._decay) * dw)
            
            self._embedding.weight = nn.Parameter(self._ema_w / self._ema_cluster_size.unsqueeze(1))
        
        # Loss
        e_latent_loss = F.mse_loss(quantized.detach(), inputs)
        loss = self._commitment_cost * e_latent_loss
        
        # Straight Through Estimator
        quantized = inputs + (quantized - inputs).detach()
        avg_probs = torch.mean(encodings, dim=0)
        perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
        
        # convert quantized from BHWC -> BCHW
        return loss, quantized.contiguous(), perplexity, encodings

In [16]:
class Model(nn.Module):
    def __init__(self, input_size, middle_size, output_size, fcl_layers, dr,
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        
        super(Model, self).__init__()
        
        self._encoder = Encoder(input_size, embedding_dim, fcl_layers, dr)
        
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim, 
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
            
        self._decoder = Decoder(embedding_dim, output_size, fcl_layers[::-1], dr)
        
#         self.fc1 = torch.nn.Linear(input_size + output_size, middle_size)
#         #self.bn1 = nn.BatchNorm1d(middle_size)
#         self.ac1 = nn.ReLU()
#         #self.dr1 = nn.Dropout(dr)
#         self.fc2 = torch.nn.Linear(middle_size, output_size)

    def forward(self, x):
        z = self._encoder(x)
        loss, quantized, perplexity, _ = self._vq_vae(z)
        x_recon = self._decoder(quantized)
#         x = torch.cat([x, x_recon], axis = 1)
#         x = self.fc1(x)
#         #x = self.bn1(x)
#         x = self.ac1(x)
#         #x = self.dr1(x)
#         x = self.fc2(x)
        return loss, x_recon, perplexity

In [17]:
input_size = 9
middle_size = 128
output_size = 9

fcl_layers = []
embedding_dim = middle_size
num_embeddings = 512

dr = 0.0
commitment_cost = 0.25
decay = 0.99
loss_weight = 1.0

In [20]:
vae_model = Model(input_size, middle_size, 3, 
              fcl_layers, dr, num_embeddings, embedding_dim, 
              commitment_cost, decay).to(device)

vae_model = vae_model.apply(initialize_weights).to(device)

In [21]:
vae_model

Model(
  (_encoder): Encoder(
    (fcn): Sequential(
      (0): Linear(in_features=9, out_features=128, bias=True)
    )
  )
  (_vq_vae): VectorQuantizerEMA(
    (_embedding): Embedding(512, 128)
  )
  (_decoder): Decoder(
    (fcn): Sequential(
      (0): Linear(in_features=128, out_features=3, bias=True)
    )
  )
)

In [22]:
optimizer = torch.optim.Adam(vae_model.parameters(),
                             lr = 1e-3,
                             eps = 1e-7,
                             betas = (0.9, 0.999),
                             amsgrad = False)

# lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#     optimizer, 
#     T_max=10, 
#     eta_min=1e-3*learning_rate
# )

lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        patience = lr_patience, 
        verbose = True,
        min_lr = 1.0e-13
)

In [24]:
results_dict = defaultdict(list)

for epoch in range(200):
    
    # Train in batch mode
    vae_model.train()
    
    train_loss = []
    train_perp = []
    train_mse = []
    
    for k, (x, y) in enumerate(train_loader):
        vq_loss, x_pred, perplexity = vae_model(x.to(device))
        #recon_loss = F.mse_loss(y.to(device), y_pred) #/ data_variance
        recon_loss = torch.nn.HuberLoss()(x[:, out_col_idx].to(device), x_pred) 
        loss = recon_loss + loss_weight * vq_loss
        
#         l1_norm = sum(p.abs().sum() for p in vae_model.parameters())
#         l2_norm = sum(p.pow(2.0).sum() for p in vae_model.parameters())
        
#         loss += L1_penalty * l1_norm
#         loss += L2_penalty * l2_norm
               
        train_loss.append(loss.item())
        train_mse.append(recon_loss.item())
        train_perp.append(perplexity.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        #lr_scheduler.step()

    # Validate 
    vae_model.eval()
    with torch.no_grad():
        
        # Validate in batch mode
        valid_loss = []
        valid_perp = []
        valid_mse = []
        
        for k, (x, y) in enumerate(valid_loader):
            vq_loss, x_pred, perplexity = vae_model(x.to(device))
            #recon_loss = F.mse_loss(y.to(device), y_pred) #/ data_variance
            recon_loss = torch.nn.L1Loss()(x[:, out_col_idx].to(device), x_pred)
            loss = recon_loss + vq_loss
        
            valid_loss.append(loss.item())
            valid_mse.append(recon_loss.item())
            valid_perp.append(perplexity.item())
        
    results_dict["epoch"].append(epoch)
    results_dict["train_loss"].append(np.mean(train_loss))
    results_dict["train_perp"].append(np.mean(train_perp))
    results_dict["train_mse"].append(np.mean(train_mse))
    results_dict["valid_loss"].append(np.mean(valid_loss))
    results_dict["valid_perp"].append(np.mean(valid_perp))
    results_dict["valid_mae"].append(np.mean(valid_mse))
    results_dict["lr"].append(optimizer.param_groups[0]['lr'])
    
    # Save the dataframe to disk
    df = pd.DataFrame.from_dict(results_dict).reset_index()
    df.to_csv(f"gecko/training_log.csv", index = False)
    
    print(f'Epoch {epoch}',
          f'train_loss {results_dict["train_loss"][-1]:2f}',
          f'train_mse {results_dict["train_mse"][-1]:2f}',
          f'train_perp {results_dict["train_perp"][-1]:2f}',
          f'valid_loss {results_dict["valid_loss"][-1]:2f}',
          f'valid_mae {results_dict["valid_mae"][-1]:2f}',
          f'valid_perp {results_dict["valid_perp"][-1]:2f}',
          f'lr {results_dict["lr"][-1]:6f}'
         )

    # anneal the learning rate using just the box metric
    lr_scheduler.step(results_dict["valid_mae"][-1])
    
    if results_dict["valid_mae"][-1] == min(results_dict["valid_mae"]):
        state_dict = {
            'epoch': epoch,
            'model_state_dict': vae_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': min(results_dict["valid_mae"])
        }
        torch.save(state_dict, f"gecko/vae.pt")
    
    # Stop training if we have not improved after X epochs
    best_epoch = [i for i,j in enumerate(results_dict["valid_mae"]) if j == min(results_dict["valid_mae"])][0]
    offset = epoch - best_epoch
    if offset >= stopping_patience:
        break

Epoch 0 train_loss 0.222003 train_mse 0.193752 train_perp 4.848610 valid_loss 0.486008 valid_mae 0.446790 valid_perp 4.032362 lr 0.001000
Epoch 1 train_loss 0.210637 train_mse 0.159421 train_perp 4.739242 valid_loss 0.506191 valid_mae 0.438513 valid_perp 4.017611 lr 0.001000
Epoch 2 train_loss 0.188017 train_mse 0.133635 train_perp 8.431925 valid_loss 0.344203 valid_mae 0.309318 valid_perp 11.241926 lr 0.001000
Epoch 3 train_loss 0.076186 train_mse 0.055412 train_perp 48.337995 valid_loss 0.223215 valid_mae 0.207162 valid_perp 22.814142 lr 0.001000
Epoch 4 train_loss 0.048049 train_mse 0.035123 train_perp 79.883409 valid_loss 0.189214 valid_mae 0.177309 valid_perp 28.421633 lr 0.001000
Epoch 5 train_loss 0.037072 train_mse 0.027088 train_perp 99.291526 valid_loss 0.169258 valid_mae 0.159621 valid_perp 31.669755 lr 0.001000
Epoch 6 train_loss 0.030644 train_mse 0.022349 train_perp 114.888693 valid_loss 0.152940 valid_mae 0.144810 valid_perp 34.782575 lr 0.001000
Epoch 7 train_loss 0.026

In [25]:
vae_model.eval()

pred_x = vae_model(x.to(device))

In [26]:
pred_x

(tensor(0.0009, device='cuda:0', grad_fn=<MulBackward0>),
 tensor([[ 0.9863,  0.8765, -0.5707],
         [ 0.9863,  0.8765, -0.5707],
         [ 0.9863,  0.8765, -0.5707],
         ...,
         [-0.1751, -0.8979, -0.5781],
         [-0.1751, -0.8979, -0.5781],
         [-0.1751, -0.8979, -0.5781]], device='cuda:0', grad_fn=<AddmmBackward>),
 tensor(13.1072, device='cuda:0'))

In [28]:
x[:, out_col_idx]

tensor([[ 0.8855,  0.9361, -0.5744],
        [ 0.8845,  0.9349, -0.5743],
        [ 0.8835,  0.9336, -0.5742],
        ...,
        [-0.1712, -0.8818, -0.5606],
        [-0.1722, -0.8829, -0.5606],
        [-0.1731, -0.8840, -0.5605]])

### Predict the next state

In [71]:
class Decoder(nn.Module):
    def __init__(self, input_size  = 100, output_size = 3, fcl_layers = 1, dr = 0.0):
        super(Decoder, self).__init__()
        self.fcn = self.make_fcn(input_size, output_size, fcl_layers, dr)

    def make_fcn(self, input_size, output_size, fcl_layers, dr):
        if len(fcl_layers) > 0:
            fcn = [
                nn.Linear(input_size, fcl_layers[0]),
                #nn.BatchNorm1d(fcl_layers[0]),
                torch.nn.ReLU(),
                #nn.Dropout(dr)
            ]
            if len(fcl_layers) == 1:
                fcn.append(nn.Linear(fcl_layers[0], output_size))
            else:
                for i in range(len(fcl_layers)-1):
                    fcn += [
                        nn.Linear(fcl_layers[i], fcl_layers[i+1]),
                        #nn.BatchNorm1d(fcl_layers[i+1]),
                        torch.nn.ReLU(),
                        #nn.Dropout(dr)
                    ]
                fcn.append(nn.Linear(fcl_layers[i+1], output_size))
        else:
            fcn = [
                nn.Linear(input_size, output_size)
            ]
        return nn.Sequential(*fcn)
    
    def forward(self, x):
        x = self.fcn(x)
        return x

class box_vqvae(nn.Module):
    def __init__(self, input_size, middle_size, output_size, fcl_layers, dr,
                 num_embeddings, embedding_dim, commitment_cost, decay=0):
        
        super(box_vqvae, self).__init__()
        
        self._encoder = Encoder(input_size, embedding_dim, fcl_layers, dr)
        
        if decay > 0.0:
            self._vq_vae = VectorQuantizerEMA(num_embeddings, embedding_dim, 
                                              commitment_cost, decay)
        else:
            self._vq_vae = VectorQuantizer(num_embeddings, embedding_dim,
                                           commitment_cost)
            
        self._decoder = Decoder(embedding_dim, output_size, fcl_layers[::-1][1:], dr)
        
#         self.fc1 = torch.nn.Linear(input_size + output_size, middle_size)
#         #self.bn1 = nn.BatchNorm1d(middle_size)
#         self.ac1 = nn.ReLU()
#         #self.dr1 = nn.Dropout(dr)
#         self.fc2 = torch.nn.Linear(middle_size, output_size)

    def forward(self, x):
        z = self._encoder(x)
        loss, quantized, perplexity, _ = self._vq_vae(z)
        x_recon = self._decoder(quantized)
#         x = torch.cat([x, x_recon], axis = 1)
#         x = self.fc1(x)
#         #x = self.bn1(x)
#         x = self.ac1(x)
#         #x = self.dr1(x)
#         x = self.fc2(x)
        return loss, x_recon, perplexity

In [87]:
num_embeddings = 1024
embedding_dim = 512
fcl_layers = []

integrator = box_vqvae(input_size, middle_size, 3, 
              fcl_layers, dr, num_embeddings, embedding_dim, 
              commitment_cost, decay).to(device)

integrator = integrator.apply(initialize_weights).to(device)

In [88]:
integrator

box_vqvae(
  (_encoder): Encoder(
    (fcn): Sequential(
      (0): Linear(in_features=9, out_features=512, bias=True)
    )
  )
  (_vq_vae): VectorQuantizerEMA(
    (_embedding): Embedding(1024, 512)
  )
  (_decoder): Decoder(
    (fcn): Sequential(
      (0): Linear(in_features=512, out_features=3, bias=True)
    )
  )
)

In [89]:
optimizer = torch.optim.Adam(integrator.parameters(),
                             lr = 1e-3,
                             eps = 1e-7,
                             betas = (0.9, 0.999),
                             amsgrad = False)

# lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
#     optimizer, 
#     T_max=10, 
#     eta_min=1e-3*learning_rate
# )

lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 
        patience = lr_patience, 
        verbose = True,
        min_lr = 1.0e-13
)

In [90]:
results_dict = defaultdict(list)

for epoch in range(200):
    
    # Train in batch mode
    integrator.train()
    
    train_loss = []
    train_perp = []
    train_mse = []
    
    for k, (x, y) in enumerate(train_loader):
            
        vq_loss, y_pred, perplexity = integrator(x.to(device))
        #recon_loss = F.mse_loss(y.to(device), y_pred) #/ data_variance
        recon_loss = torch.nn.HuberLoss()(y.to(device), y_pred) 
        loss = recon_loss + loss_weight * vq_loss
        
#         l1_norm = sum(p.abs().sum() for p in integrator.parameters())
#         l2_norm = sum(p.pow(2.0).sum() for p in integrator.parameters())
        
#         loss += L1_penalty * l1_norm
#         loss += L2_penalty * l2_norm
               
        train_loss.append(loss.item())
        train_mse.append(recon_loss.item())
        train_perp.append(perplexity.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        
    # Train in box mode
#     integrator.train()
    
#     train_loss = []
#     train_perp = []
#     train_mse = []

#     # set up array for saving predicted results
#     _in_array = torch.from_numpy(train_in_array).to(device).float()
#     #pred_array = np.empty((train_in_array.shape[0], num_timesteps-start_time, len(out_col_idx)))

#     # use initial condition @ t = start_time and get the first prediction
#     vq_loss, output, perplexity = integrator(_in_array[:, start_time, :].to(device))
#     #pred_array[:, 0, :] = output.cpu().numpy()
#     recon_loss = F.l1_loss(_in_array[:, start_time + 1, out_col_idx], output)
#     loss = recon_loss + loss_weight * vq_loss
#     l1_norm = sum(p.abs().sum() for p in integrator.parameters()).cpu()
#     loss += L1_penalty * l1_norm
    
#     train_loss.append(loss.item())
#     train_mse.append(recon_loss.item())
#     train_perp.append(perplexity.item())

#     optimizer.zero_grad()
#     loss.backward()
#     optimizer.step()
    
#     # use the first prediction to get the next, and so on for num_timesteps
#     for k, i in enumerate(range(start_time + 1, num_timesteps)): 
#         new_input = _in_array[:, i, :]
#         new_input[:, out_col_idx] = output.detach()
#         vq_loss, output, perplexity = integrator(new_input)
#         if i < (num_timesteps-1):
#             recon_loss = F.l1_loss(_in_array[:, i + 1, out_col_idx], output)
#             loss = recon_loss + loss_weight * vq_loss
    
#             train_loss.append(loss.item())
#             train_mse.append(recon_loss.item())
#             train_perp.append(perplexity.item())

#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
        
    # Validate 
    integrator.eval()
    with torch.no_grad():
        
        # Validate in batch mode
        valid_loss = []
        valid_perp = []
        valid_mse = []
        
        for k, (x, y) in enumerate(valid_loader):
            
            vq_loss, y_pred, perplexity = integrator(x.to(device))
            
            #recon_loss = F.mse_loss(y.to(device), y_pred) #/ data_variance
            recon_loss = torch.nn.L1Loss()(y.to(device), y_pred) 
            loss = recon_loss + vq_loss
        
            valid_loss.append(loss.item())
            valid_mse.append(recon_loss.item())
            valid_perp.append(perplexity.item())
            
#         # Validate in box mode
#         box_loss = []
#         box_perp = []
#         box_mse = []
        
#         # set up array for saving predicted results
#         _in_array = torch.from_numpy(val_in_array).to(device).float()
#         pred_array = np.empty((val_in_array.shape[0], num_timesteps-start_time, len(out_col_idx)))

#         # use initial condition @ t = start_time and get the first prediction
#         vq_loss, y_pred, perplexity = integrator(_in_array[:, start_time, :].to(device))
#         pred_array[:, 0, :] = y_pred.cpu().numpy()
#         recon_loss = F.l1_loss(_in_array[:, start_time + 1, out_col_idx], y_pred)
#         loss = recon_loss + loss_weight * vq_loss
        
#         box_loss.append(loss.item())
#         box_mse.append(recon_loss.item())
#         box_perp.append(perplexity.item())

#         # use the first prediction to get the next, and so on for num_timesteps
#         for k, i in enumerate(range(start_time + 1, num_timesteps)): 
#             new_input = _in_array[:, i, :]
#             new_input[:, out_col_idx] = y_pred
#             vq_loss, y_pred, perplexity = integrator(new_input.to(device))
#             pred_array[:, k+1, :] = y_pred.cpu().numpy()
#             if i < (num_timesteps-1):
#                 recon_loss = F.l1_loss(_in_array[:, i + 1, out_col_idx], y_pred)
#                 loss = recon_loss + loss_weight * vq_loss
#                 box_loss.append(loss.item())
#                 box_mse.append(recon_loss.item())
#                 box_perp.append(perplexity.item())
                
#         idx = transformed_data["val_out"].index
#         start_time_units = sorted(list(set([x[0] for x in idx])))[start_time]
#         start_time_condition = [(x[0] >= start_time_units) for x in idx]
#         idx = transformed_data["val_out"][start_time_condition].index

#         raw_box_preds = pd.DataFrame(
#             data=pred_array.reshape(-1, len(output_vars)),
#             columns=output_vars, 
#             index=idx
#         )

#         # inverse transform 
#         truth, preds = inv_transform_preds(
#             raw_preds=raw_box_preds,
#             truth=data['val_out'][start_time_condition],
#             y_scaler=y_scaler,
#             log_trans_cols=log_trans_cols,
#             tendency_cols=tendency_cols)
                
#         metrics = ensembled_metrics(y_true=truth,
#                                     y_pred=preds,
#                                     member=0,
#                                     output_vars=output_vars,
#                                     stability_thresh=1.0)
#         mean_box_mae = metrics['mean_mae'].mean()
#         unstable_exps = int(metrics['n_unstable'].mean())
        
    results_dict["epoch"].append(epoch)
    results_dict["train_loss"].append(np.mean(train_loss))
    results_dict["train_perp"].append(np.mean(train_perp))
    results_dict["train_mse"].append(np.mean(train_mse))
    results_dict["valid_loss"].append(np.mean(valid_loss))
    results_dict["valid_perp"].append(np.mean(valid_perp))
    results_dict["valid_mae"].append(np.mean(valid_mse))
#     results_dict["box_loss"].append(np.mean(box_loss))
#     results_dict["box_perp"].append(np.mean(box_perp))
#     results_dict["box_mse"].append(np.mean(box_mse))
#     results_dict["box_mae"].append(mean_box_mae)
    results_dict["lr"].append(optimizer.param_groups[0]['lr'])
    
    # Save the dataframe to disk
    df = pd.DataFrame.from_dict(results_dict).reset_index()
    df.to_csv(f"gecko/integrator_training_log.csv", index = False)
    
    print(f'Epoch {epoch}',
          f'train_loss {results_dict["train_loss"][-1]:2f}',
          f'train_mse {results_dict["train_mse"][-1]:2f}',
          f'train_perp {results_dict["train_perp"][-1]:2f}',
          f'valid_loss {results_dict["valid_loss"][-1]:2f}',
          f'valid_mae {results_dict["valid_mae"][-1]:2f}',
          f'valid_perp {results_dict["valid_perp"][-1]:2f}',
#           f'box_loss {results_dict["box_loss"][-1]:2f}',
#           f'box_mse {results_dict["box_mse"][-1]:2f}',
#           f'box_perp {results_dict["box_perp"][-1]:2f}',
#           f'box_mae {results_dict["box_mae"][-1]:2f}',
          f'n {unstable_exps}',
          f'lr {results_dict["lr"][-1]:6f}'
         )

    # anneal the learning rate using just the box metric
    lr_scheduler.step(results_dict["valid_loss"][-1])
    
    if results_dict["valid_loss"][-1] == min(results_dict["valid_loss"]):
        state_dict = {
            'epoch': epoch,
            'model_state_dict': vae_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': min(results_dict["valid_loss"])
        }
        torch.save(state_dict, f"gecko/integrator.pt")
    
    # Stop training if we have not improved after X epochs
    best_epoch = [i for i,j in enumerate(results_dict["valid_loss"]) if j == min(results_dict["valid_loss"])][0]
    offset = epoch - best_epoch
    if offset >= stopping_patience:
        break

Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x2aec7627de60><function _MultiProcessingDataLoaderIter.__del__ at 0x2aec7627de60><function _MultiProcessingDataLoaderIter.__del__ at 0x2aec7627de60>


Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/glade/work/schreck/py37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
  File "/glade/work/schreck/py37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
  File "/glade/work/schreck/py37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
            self._shutdown_workers()self._shutdown_workers()
self._shutdown_workers()
  File "/glade/work/schreck/py37/lib/python3.7/site-packages/torch/utils/data/dataloader.py", line 1320, in _shutdown_workers

  File "/glade/work/schreck/py37/lib/python3.7/site-packages/tor

Epoch 0 train_loss 0.145813 train_mse 0.081443 train_perp 113.754626 valid_loss 0.103038 valid_mae 0.095279 valid_perp 3.945969 n 0 lr 0.001000
Epoch 1 train_loss 0.010103 train_mse 0.005799 train_perp 157.658785 valid_loss 0.087850 valid_mae 0.083142 valid_perp 4.853414 n 0 lr 0.001000
Epoch 2 train_loss 0.008979 train_mse 0.005048 train_perp 154.722582 valid_loss 0.080108 valid_mae 0.075404 valid_perp 4.798885 n 0 lr 0.001000
Epoch 3 train_loss 0.008774 train_mse 0.004850 train_perp 151.375798 valid_loss 0.081043 valid_mae 0.076317 valid_perp 4.685005 n 0 lr 0.001000
Epoch 4 train_loss 0.008772 train_mse 0.004800 train_perp 149.141848 valid_loss 0.085481 valid_mae 0.080618 valid_perp 4.640080 n 0 lr 0.001000
Epoch 5 train_loss 0.008781 train_mse 0.004780 train_perp 145.699964 valid_loss 0.082374 valid_mae 0.077513 valid_perp 4.574078 n 0 lr 0.001000
Epoch 6 train_loss 0.008821 train_mse 0.004785 train_perp 143.498911 valid_loss 0.082138 valid_mae 0.077234 valid_perp 4.540907 n 0 lr 0

In [None]:
pred = integrator(x.to(device))[1]

In [None]:
pred

In [61]:
y

tensor([[ 0.8845,  0.9349, -0.5743],
        [ 0.8835,  0.9336, -0.5742],
        [ 0.8825,  0.9324, -0.5742],
        ...,
        [-0.1722, -0.8829, -0.5606],
        [-0.1731, -0.8840, -0.5605],
        [-0.1741, -0.8851, -0.5605]])