In [1]:
import torch
import matplotlib.pyplot as plt

from torch import nn
from torch.utils.data import DataLoader, RandomSampler
from tqdm.notebook import tqdm

from models.energy_models import EnergyAttentionRNN, EnergyAttentionRNN2D, EnergyRNN
from models.probability_models import ProbabilityRNN, ProbabilityAttentionRNN
from data_utils.pytorch_datasets import EnergyDataset, EnergyDataset2D, ProbabilityDataset

## Datasets

In [12]:
train_data_path = './data/train_data_32k.hdf5'
batch_size = 128
n_random_samples = 3200

energy_dataset_2D = EnergyDataset2D(filepath=train_data_path)
random_sampler_2D = RandomSampler(energy_dataset_2D, num_samples=n_random_samples, replacement=True)
energy_loader_2D = DataLoader(
    energy_dataset_2D,
    batch_size=batch_size,
    sampler=random_sampler_2D,
    num_workers=0)

energy_dataset = EnergyDataset(filepath=train_data_path)
random_sampler = RandomSampler(energy_dataset, num_samples=n_random_samples, replacement=True)
energy_loader = DataLoader(
    energy_dataset,
    batch_size=batch_size,
    sampler=random_sampler,
    num_workers=0)

prob_dataset = EnergyDataset(filepath=train_data_path)
prob_random_sampler = RandomSampler(prob_dataset, num_samples=n_random_samples, replacement=True)
prob_loader = DataLoader(
    prob_dataset,
    batch_size=batch_size,
    sampler=prob_random_sampler,
    num_workers=0)

## Energy Models

In [3]:
energy_gru_ckpt = './model_weights/energy/rnn/GRU_1L/model.ckpt'
energy_lstm_ckpt = './model_weights/energy/rnn/LSTM_1L/model.ckpt'
energy_attn_ckpt = './model_weights/energy/rnn_attn/GRU_1L_1H/model.ckpt'
energy_2D_attn_ckpt = './model_weights/energy/rnn_attn/2D_GRU_1L_1H/model.ckpt'

energy_gru = EnergyRNN.load_from_checkpoint(energy_gru_ckpt)
energy_gru.eval()

energy_lstm = EnergyRNN.load_from_checkpoint(energy_lstm_ckpt)
energy_lstm.eval()

energy_attn = EnergyAttentionRNN.load_from_checkpoint(energy_attn_ckpt)
energy_attn.eval()

energy_2D_attn = EnergyAttentionRNN2D.load_from_checkpoint(energy_2D_attn_ckpt)
energy_2D_attn.eval()

EnergyAttentionRNN2D(
  (rnn): GRU(1, 512, batch_first=True)
  (linear): Linear(in_features=512, out_features=1, bias=True)
  (attention_rows): MultiheadAttention(
    (out_proj): Linear(in_features=512, out_features=512, bias=True)
  )
  (attention_cols): MultiheadAttention(
    (out_proj): Linear(in_features=512, out_features=512, bias=True)
  )
)

In [4]:
preds_energy_2D = []
true_energy_2D = []
with torch.no_grad():
    for x_row, x_col, y in tqdm(energy_loader_2D):
        y_pred, _ = energy_2D_attn(x_row, x_col)
        preds_energy_2D.append(y_pred)
        true_energy_2D.append(y)

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))




In [5]:
preds_energy_lstm = []
preds_energy_gru = []
preds_energy_attn = []
true_energy_1D = []
with torch.no_grad():
    for x, y in tqdm(energy_loader):
        y_pred, _ = energy_attn(x)
        preds_energy_attn.append(y_pred)
        y_pred = energy_lstm(x)
        preds_energy_lstm.append(y_pred)
        y_pred = energy_gru(x)
        preds_energy_gru.append(y_pred)

        true_energy_1D.append(y)

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))




In [6]:
def center_predictions(preds_list):
    preds_tensor = torch.cat(preds_list, axis=0)
    return preds_tensor - torch.mean(preds_tensor)

In [7]:
centered_energy_2D = center_predictions(preds_energy_2D)
centered_energy_lstm = center_predictions(preds_energy_lstm)
centered_energy_gru = center_predictions(preds_energy_gru)
centered_energy_attn = center_predictions(preds_energy_attn)

true_centered_energy_2D = center_predictions(true_energy_2D)
true_centered_energy_1D = center_predictions(true_energy_1D)

### Predictions

In [8]:
criterion = nn.MSELoss()

rmse_2D = torch.sqrt(criterion(true_centered_energy_2D, centered_energy_2D))
print(f"RMSE for 2D Model: {rmse_2D}")

rmse_attn = torch.sqrt(criterion(true_centered_energy_1D, centered_energy_attn))
print(f"RMSE for Attention Model: {rmse_attn}")

rmse_lstm = torch.sqrt(criterion(true_centered_energy_1D, centered_energy_lstm))
print(f"RMSE for LSTM Model: {rmse_lstm}")

rmse_gru = torch.sqrt(criterion(true_centered_energy_1D, centered_energy_gru))
print(f"RMSE for GRU Model: {rmse_gru}")

RMSE for 2D Model: 5.375464916229248
RMSE for Attention Model: 7.628973007202148
RMSE for LSTM Model: 8.55644416809082
RMSE for GRU Model: 9.345841407775879


## Probability Models

In [9]:
prob_gru_ckpt = './model_weights/probability/rnn/GRU_1L/model.ckpt'
prob_lstm_ckpt = './model_weights/probability/rnn/LSTM_1L/model.ckpt'
prob_attn_ckpt = './model_weights/probability/rnn_attn/GRU_1L_1H/model.ckpt'

prob_gru = ProbabilityRNN.load_from_checkpoint(prob_gru_ckpt)
prob_gru.eval()

prob_lstm = ProbabilityRNN.load_from_checkpoint(prob_lstm_ckpt)
prob_lstm.eval()

prob_attn = ProbabilityAttentionRNN.load_from_checkpoint(prob_attn_ckpt)
prob_attn.eval()

ProbabilityAttentionRNN(
  (rnn): GRU(1, 512, batch_first=True)
  (linear): Linear(in_features=512, out_features=1, bias=True)
  (attention): MultiheadAttention(
    (out_proj): Linear(in_features=512, out_features=512, bias=True)
  )
)

In [14]:
preds_prob_lstm = []
preds_prob_gru = []
preds_prob_attn = []
true_energy_prob = []
with torch.no_grad():
    for x, y in tqdm(prob_loader):
        y_pred = prob_attn.predict_energy(x)
        preds_prob_attn.append(y_pred)
        y_pred = prob_lstm.predict_energy(x)
        preds_prob_lstm.append(y_pred)
        y_pred = prob_gru.predict_energy(x)
        preds_prob_gru.append(y_pred)
        
        true_energy_prob.append(y)

HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))




In [15]:
centered_prob_lstm = center_predictions(preds_prob_lstm)
centered_prob_gru = center_predictions(preds_prob_gru)
centered_prob_attn = center_predictions(preds_prob_attn)

true_centered_prob = center_predictions(true_energy_prob)

In [16]:
criterion = nn.MSELoss()

rmse_attn_prob = torch.sqrt(criterion(true_centered_prob, centered_prob_attn))
print(f"RMSE for Attention Model: {rmse_attn_prob}")

rmse_lstm_prob = torch.sqrt(criterion(true_centered_prob, centered_prob_lstm))
print(f"RMSE for LSTM Model: {rmse_lstm_prob}")

rmse_gru_prob = torch.sqrt(criterion(true_centered_prob, centered_prob_gru))
print(f"RMSE for GRU Model: {rmse_gru_prob}")

RMSE for Attention Model: 12.718344688415527
RMSE for LSTM Model: 11.144323348999023
RMSE for GRU Model: 11.343841552734375
