# Mixture of Experts 

An LSTM enhanced with a specialized linear storm model.

Kit Calcraft 05/08/2024

In [None]:
# magic
# Notebook setup (IPython magics, not plain Python).
# Reload imported modules (e.g. the local `functions` package) automatically
# before each cell executes, so edits to those files are picked up.
%load_ext autoreload
%autoreload 2
# Drop into the pdb post-mortem debugger automatically when a cell raises.
%pdb 1
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
import numpy as np
import pandas as pd
import yaml
from tqdm import tqdm

#pytorch
import torch
from torch.utils.data import DataLoader

# MODEL FUNCTIONS
from functions.load_data import data_select
from functions.custom_datasets import SequenceDataset
from functions.custom_loss_functions import cumulative_dx_loss
from functions.misc import *
from functions.model_utils import train, predict

# MODELS
from functions.MoE import MixtureOfExperts

# statistics & plotting
import matplotlib.pyplot as plt
from matplotlib.collections import LineCollection
from matplotlib.lines import Line2D

# Choose the torch device passed to train() below.
# `device` must be bound on every path: the training cell calls
# train(..., device), which previously raised NameError when MPS was absent.
# NOTE(review): even when MPS is available the device is pinned to CPU, so the
# "MPS device" message is misleading. Switch to torch.device("mps") only once
# the model/batches are confirmed to be moved with .to(device) (not visible here).
if torch.backends.mps.is_available():
    device = torch.device("cpu")
    print(f"Running Torch v{torch.__version__} on MPS device")
else:
    device = torch.device("cpu")
    print("Falling back to CPU")


### 0 - Load Settings & Model Parameters

In [None]:
# Load the run configuration. The `with` block closes the file handle
# (the bare open() previously leaked it). yaml.safe_load only builds plain
# Python types, so the config cannot instantiate arbitrary objects.
with open("config/model_settings.yml", "r") as settings_file:
    settings = yaml.safe_load(settings_file)

# Frequently used settings, pulled out as notebook globals.
target = settings['target']
batch_size = settings['batch_size']
sequence_length = settings['sequence_length']

### 1 - Select, Split & Standardize Data

In [None]:
# Build the project data object from the settings, split it into train/test
# periods, standardize it, and plot the split. All four calls come from the
# project `functions` package (plot_train_test presumably via the
# `functions.misc` star import) and mutate `data` in place, so order matters.
# NOTE(review): presumably standardize() populates data.scalers, which later
# cells read for destandardization — confirm in functions.load_data.
data = data_select(settings)
data.train_test_split()
data.standardize()
plot_train_test(data, settings)

### 2 - Build MoE Model

In [None]:
# Wrap the training split as a sequence dataset and batch it.
# shuffle=False keeps batches in the dataset's original (time) order.
train_dataset = SequenceDataset(data.train, data, settings)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

# Stash the dataset's variable-index mapping in data.scalers.
# NOTE(review): presumably read downstream by train()/predict(), which both
# receive data.scalers — confirm.
data.scalers['varIdx'] = train_dataset.varIdx

# Sanity check: pull one batch and print the input/target shapes and the
# first element of the third item the dataset yields.
X, y, I = next(iter(train_loader))
print(X.shape, y.shape, I[0])

In [None]:
# Seed PyTorch's RNG. torch.random.seed() picks a fresh *non-deterministic*
# seed, so runs were not reproducible. An optional `seed` entry in the
# settings now gives a deterministic run; without it, behaviour is unchanged.
seed = settings.get('seed')
if seed is not None:
    torch.manual_seed(int(seed))
else:
    torch.random.seed()

model = MixtureOfExperts(n_inputs=len(data.inputs), settings = settings)
loss_function = cumulative_dx_loss()
# float(): PyYAML parses exponent literals without a dot (e.g. 1e-3) as
# strings, so coerce explicitly. Indexing (not .get) makes a missing key fail
# with a clear KeyError instead of float(None) -> TypeError.
optimizer = torch.optim.Adam(model.parameters(), lr=float(settings['learning_rate']))

print(model)

### 4 - Train MoE Model

In [None]:
# Train the MoE model for settings['epochs'] epochs, calling the project
# train() once per epoch. Its 3-tuple result is unpacked as
# (trainloss, preds, y); trainloss is shown in the progress bar.
# The tqdm context manager closes the bar even if train() raises (this
# notebook runs with %pdb, so an exception would otherwise leave it open).
with tqdm(total=settings['epochs'], desc="Training Progress", unit="epoch") as progress_bar:
    for epoch in range(settings['epochs']):

        # ----> Training <----
        trainloss, preds, y = train(data.df,
                                    data.train,
                                    train_loader,
                                    model,
                                    loss_function,
                                    data.scalers,
                                    optimizer,
                                    settings,
                                    device)

        progress_bar.set_description(f"Loss: {trainloss:.3f}")
        progress_bar.update()

### 5 - Model Output

In [None]:
# Observed shoreline in original units (overlaid as scatter in the plot cell).
shoreline = destandardize(data.df[settings['shoreline']], data.scalers,'shoreline')
# Integer position of the first training row within data.df.
training_start_position = data.df.index.get_loc(data.train.index[0])
# Index label (not position) of the first test row.
test_start_date = data.test.index[0]

modelTrain = pd.DataFrame()
modelTest = pd.DataFrame()

# Rebuild unshuffled loaders with batch_size=1 for inference over each split.
train_dataset = SequenceDataset(data.train, data, settings)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=False)
test_dataset = SequenceDataset(data.test, data, settings)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# predict() returns (model output, per-step weight). The weight is later
# compared against 1 in the plot cell to choose the expert colour.
# NOTE(review): the train call passes an integer position
# (training_start_position) as the 5th argument, but the test call passes an
# index label (test_start_date). One of them is probably not what predict()
# expects — confirm in functions.model_utils.predict. The form consistent with
# the train call would be data.df.index.get_loc(test_start_date).
modelTrain['model_output'], train_weight = predict(data.df, 
                        train_loader, 
                        model, 
                        data.scalers,
                        training_start_position, 
                        settings)

modelTest['model_output'], test_weight  = predict(data.df,               
                            test_loader, 
                            model, 
                            data.scalers,
                            test_start_date,           
                            settings)

# Re-label outputs and weights with the split indices so they align with
# the observations.
modelTest.index = data.test.index
modelTrain.index = data.train.index

test_weight.index = data.test.index
train_weight.index = data.train.index

# Stitch train and test results into single index-sorted frames.
model_output = pd.concat([modelTrain, modelTest]).sort_index()
weight = pd.concat([train_weight,test_weight]).sort_index()

In [None]:
# Skill on the held-out test split: destandardized observations vs. model
# output. calculate_skill presumably comes from the functions.misc star
# import; its two return values are unpacked as NMSE and r2_validation.
validation_obs = destandardize(data.test[settings['shoreline']], data.scalers, 'shoreline').to_numpy()
validation_preds = modelTest['model_output'].to_numpy()

NMSE, r2_validation = calculate_skill(validation_obs, validation_preds)
# Bare expression so the notebook displays both values.
NMSE, r2_validation

In [None]:
# Plot model output against the observed shoreline, styled by model type.
plt.style.use("bmh")
fig, axs = plt.subplots(1, 1, figsize=(20, 4), sharex=True, sharey=True)
# Join output with the predicted weight; the weight column arrives unnamed (0).
df = model_output.join(weight)
df.rename(columns={0: 'weight'}, inplace=True)

if settings['Model'] == 'Transformer':
    axs.plot(model_output, color = 'purple', linewidth = 3, label = 'Transformer')
    axs.legend(loc=2)

elif settings['Model'] == 'LSTM':
    axs.plot(modelTrain, linewidth = 2, label = 'Configuration')
    axs.plot(modelTest, linewidth = 2, label = 'Validation')
    axs.legend(loc=2)

elif settings['Model'] == 'MoE':
    lstm_colour, storm_colour = '#253D5B', '#CA2E55'
    # Draw each consecutive pair of points as its own segment, coloured by the
    # weight at the segment's end (weight == 1 -> LSTM, otherwise storm expert).
    for ii in range(len(df)):
        temp = df.iloc[ii:ii + 2]
        # .iloc[-1] is explicit positional access. Plain Series[-1] on a
        # non-integer index relies on a deprecated positional fallback (removed
        # in newer pandas), and on an integer index it is a label lookup.
        c = lstm_colour if temp['weight'].iloc[-1] == 1 else storm_colour
        axs.plot(temp.model_output, color = c, linewidth = 3, zorder = 3)

    custom_lines = [Line2D([0], [0], color=lstm_colour, lw=2),
                    Line2D([0], [0], color=storm_colour, lw=2)]
    axs.legend(custom_lines, ['LSTM', 'Storm Expert'], loc=2);

# Observations as hollow squares behind the model lines; vertical line marks
# the end of the training period.
axs.scatter(shoreline.index, shoreline, color = 'k', facecolor = 'w', alpha = 1, s = 30, zorder = 0, marker = 's')
axs.axvline(data.train.index[-1], alpha=0.7, color='k', zorder = 0)
axs.set_title(settings['shoreline'])