In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split, TimeSeriesSplit
from sklearn.preprocessing import StandardScaler
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from ax.service.managed_loop import optimize
from statistics import mean
from pykalman import KalmanFilter

import logging
from ax.utils.common.logger import get_logger

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device: ", device)

# Only show log messages of ERROR while testing.
logger = get_logger(__name__, level=logging.ERROR)
# Only touch the parent's first handler when one actually exists; indexing
# handlers[0] on an empty handler list would raise IndexError.
if logger.parent is not None and getattr(logger.parent, "handlers", None):
    logger.parent.handlers[0].setLevel(logging.ERROR)

# Load and preprocess data
data = pd.read_csv('OCULUS_dataset.csv')
# Impute missing values with the per-column mean. numeric_only=True keeps this
# working when the frame holds non-numeric columns (pandas >= 2.0 raises a
# TypeError from DataFrame.mean() in that case); such columns are simply left
# un-imputed, numeric columns are treated exactly as before.
data.fillna(data.mean(numeric_only=True), inplace=True)

# Everything except the section/subject identifiers and the target is a feature.
features = data.drop(["game_section", "stress_label", "subject"], axis=1)
labels = data["stress_label"]

# NOTE(review): the scaler is fitted on the whole dataset before any
# train/validation split happens, so validation rows influence the scaling
# statistics -- confirm this is acceptable for the hyper-parameter search.
scaler = StandardScaler()
features = scaler.fit_transform(features)

features_tensor = torch.tensor(features, dtype=torch.float32).to(device)
labels_tensor = torch.tensor(labels.values, dtype=torch.float32).to(device)

class StressLSTM(nn.Module):
    """Two LSTM stacks in series, then dropout and a one-unit linear head.

    Both LSTM stacks use ``batch_first=True`` and share ``hidden_size`` and
    ``num_layers``; each has its own inter-layer dropout rate. The linear head
    is applied at every time step and the trailing singleton dimension is
    squeezed away, so the output has one value per time step.
    """

    def __init__(self, input_size, hidden_size, num_layers, dropout_rate_1, dropout_rate_2, dropout_rate_fc):
        super().__init__()
        # Settings common to both recurrent stacks.
        recurrent_kwargs = {"num_layers": num_layers, "batch_first": True}
        self.lstm1 = nn.LSTM(input_size, hidden_size, dropout=dropout_rate_1, **recurrent_kwargs)
        self.lstm2 = nn.LSTM(hidden_size, hidden_size, dropout=dropout_rate_2, **recurrent_kwargs)
        self.dropout = nn.Dropout(dropout_rate_fc)
        self.fc = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return per-time-step predictions for ``x`` with the last dim squeezed."""
        first_pass, _ = self.lstm1(x)
        second_pass, _ = self.lstm2(first_pass)
        scores = self.fc(self.dropout(second_pass))
        return scores.squeeze(-1)

# Search space handed to Ax. "num_epochs" is fixed; everything else is tuned.
parameters = [
    {"name": "lr", "type": "range", "bounds": [1e-6, 1e-3], "log_scale": True, "value_type": "float"},
    {"name": "hidden_size", "type": "range", "bounds": [50, 256], "value_type": "int"},
    {"name": "num_layers", "type": "range", "bounds": [2, 3], "value_type": "int"},
    {"name": "num_epochs", "type": "fixed", "value": 150, "value_type": "int"},
    {"name": "dropout_rate_1", "type": "range", "bounds": [0.0, 0.7], "value_type": "float"},
    {"name": "dropout_rate_2", "type": "range", "bounds": [0.0, 0.7], "value_type": "float"},
    {"name": "dropout_rate_fc", "type": "range", "bounds": [0.0, 0.7], "value_type": "float"},
    # The bounds span six orders of magnitude, so sample on a log scale (as is
    # done for "lr"); linear sampling would almost never try small values.
    {"name": "weight_decay", "type": "range", "bounds": [1e-6, 1], "log_scale": True, "value_type": "float"},
    {"name": "batch_size", "type": "range", "bounds": [10, 100], "value_type": "int"}
]

# Shared loss modules: MSE is the training/selection objective, L1 reports MAE.
criterion = nn.MSELoss().to(device)
mae_criterion = nn.L1Loss().to(device)

def evaluate_full_dataset(parameters, features, labels):
    """Train a fresh StressLSTM for one hyper-parameter set and score it.

    The data is split 80/20 without shuffling, the model is trained with Adam
    on the first part and evaluated on the second part after every epoch.
    Training stops early once the validation MSE has failed to improve for
    10 consecutive epochs.

    Args:
        parameters: mapping with the keys 'lr', 'hidden_size', 'num_layers',
            'num_epochs', 'dropout_rate_1', 'dropout_rate_2',
            'dropout_rate_fc', 'weight_decay' and 'batch_size'.
        features: 2-D tensor of inputs; ``features.shape[1]`` is used as the
            model's input size.
        labels: tensor of targets aligned with ``features``.

    Returns:
        The lowest validation MSE observed during training, as a float.
    """
    print("Evaluating with parameters: " + str(parameters))

    num_epochs = int(parameters['num_epochs'])

    # Split data
    X_train, X_val, y_train, y_val = train_test_split(features, labels, test_size=0.2, shuffle=False)

    # Create DataLoader
    train_dataset = TensorDataset(X_train, y_train)
    train_dataloader = DataLoader(train_dataset, batch_size=int(parameters['batch_size']), shuffle=False)

    # The validation tensors never change, so prepare them once. Each set of
    # rows is fed to the model as a single sequence (leading batch dim of 1).
    val_inputs = X_val.unsqueeze(0).to(device)
    val_targets = y_val.view(-1).to(device)

    # Model initialization
    model = StressLSTM(input_size=features.shape[1],
                       hidden_size=int(parameters['hidden_size']),
                       num_layers=int(parameters['num_layers']),
                       dropout_rate_1=parameters['dropout_rate_1'],
                       dropout_rate_2=parameters['dropout_rate_2'],
                       dropout_rate_fc=parameters['dropout_rate_fc']).to(device)

    # Optimizer and loss function
    optimizer = optim.Adam(model.parameters(), lr=parameters['lr'], weight_decay=parameters['weight_decay'])

    best_val_loss = float('inf')
    patience, patience_threshold = 0, 10  # Early stopping

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        # Training loop
        for batch_features, batch_labels in train_dataloader:
            batch_features, batch_labels = batch_features.unsqueeze(0).to(device), batch_labels.view(-1).to(device)

            optimizer.zero_grad()
            outputs = model(batch_features)
            # Flatten the (1, batch) model output so it has the same shape as
            # the (batch,) labels; comparing the unflattened output relied on
            # broadcasting and made MSELoss emit a size-mismatch warning.
            loss = criterion(outputs.reshape(-1), batch_labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        # Validation
        model.eval()
        with torch.no_grad():
            val_outputs = model(val_inputs)
            val_loss = criterion(val_outputs.view(-1), val_targets)

        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_dataloader)}, Val Loss: {val_loss.item()}')

        # Early stopping logic
        if val_loss.item() < best_val_loss:
            best_val_loss = val_loss.item()
            patience = 0
        else:
            patience += 1
            if patience >= patience_threshold:
                print(f"Early stopping at epoch {epoch+1}")
                break

    return best_val_loss

def evaluate_with_objective(parameters):
    """Evaluation callback for the search: score ``parameters`` on the full dataset.

    Runs ``evaluate_full_dataset`` on the module-level feature/label tensors,
    prints the resulting MSE and returns it as the single objective value.
    """
    best_mse = evaluate_full_dataset(parameters, features_tensor, labels_tensor)
    print(f"MSE Loss: {best_mse}")
    # The scalar MSE is the only quantity handed back to the optimizer.
    return best_mse


# Run the hyper-parameter search. The objective is a validation loss, so it
# has to be minimised; `optimize` maximises by default (minimize=False), which
# would steer the search towards the worst configurations.
best_parameters, values, experiment, model = optimize(
    parameters=parameters,
    evaluation_function=evaluate_with_objective,
    objective_name='loss',
    minimize=True
)

print(best_parameters)

# Unpack the winning configuration into module-level settings used by the
# final training runs. Integer-valued parameters are cast explicitly.
input_size = features.shape[1]
output_size = 1
learning_rate = best_parameters['lr']
hidden_size = int(best_parameters['hidden_size'])
num_layers = int(best_parameters['num_layers'])
num_epochs = int(best_parameters['num_epochs'])
dropout_rate_1 = best_parameters['dropout_rate_1']
dropout_rate_2 = best_parameters['dropout_rate_2']
dropout_rate_fc = best_parameters['dropout_rate_fc']
weight_decay = best_parameters['weight_decay']
batch_size = int(best_parameters['batch_size'])

# Shared splitter and Kalman filter used by the per-subject evaluation below.
tscv = TimeSeriesSplit(n_splits=5)
kf = KalmanFilter(initial_state_mean=0, n_dim_obs=1)

def calculate_loss_and_mae(model, criterion, mae_criterion, inputs, targets, use_kalman_filter=False):
    """Run ``model`` on ``inputs`` and return ``(loss, mae)`` against ``targets``.

    ``inputs`` and ``targets`` each receive a leading dimension of size one
    before use. When ``use_kalman_filter`` is true, the model outputs are
    passed through the module-level Kalman filter ``kf`` and the filtered
    state means (rebuilt as a new tensor, so detached from the autograd graph)
    are scored instead of the raw outputs.
    """
    predictions = model(inputs.unsqueeze(0).to(device))
    if use_kalman_filter:
        # Only the filtered means are needed; the covariances are discarded.
        filtered_means, _ = kf.filter(predictions.cpu().detach().numpy())
        predictions = torch.tensor(filtered_means, dtype=torch.float32).view(1, -1).to(device)
    batched_targets = targets.unsqueeze(0).to(device)
    return criterion(predictions, batched_targets), mae_criterion(predictions, batched_targets)

# Per-subject view of the dataset; do_stuff() iterates over these groups so
# that scaling and time-series splitting are done separately for each subject.
grouped = data.groupby('subject')

def _plot_metric_grid(window_size, panels, y_max):
    """Draw the 2x2 grid of training/validation curves for one window size.

    ``panels`` holds four ``(metric, prefix, train_curve, val_curve)`` tuples,
    laid out row by row; every panel shares the y-range ``[0, y_max]``.
    """
    fig, axs = plt.subplots(2, 2, figsize=(15, 10))
    for axis, (metric, prefix, train_curve, val_curve) in zip(axs.flat, panels):
        axis.plot(train_curve, label=f'{prefix}Training {metric}')
        axis.plot(val_curve, label=f'{prefix}Validation {metric}')
        axis.set_xlabel('Epoch')
        axis.set_ylabel(metric)
        axis.set_ylim(0, y_max)
        axis.legend()
        axis.set_title(f'{prefix}Training and Validation {metric}')

    fig.suptitle(f'Window Size: {window_size}')
    plt.tight_layout()
    plt.show()


def do_stuff():
    """Train and evaluate the tuned model for several window sizes.

    For every subject the features are standardised, split with the shared
    ``TimeSeriesSplit`` and each training fold is further split 80/20 into
    train/validation parts. All folds of all subjects are concatenated into
    one train, one validation and one test set. For each window size a fresh
    ``StressLSTM`` is trained with the window size as batch size, raw and
    Kalman-filtered MSE/MAE are tracked per epoch, the curves are plotted and
    the test metrics averaged over the epochs are printed.

    Relies on module-level state: ``grouped``, ``tscv``, ``kf``, ``device``,
    the tuned hyper-parameters and the two loss modules. Returns nothing.

    NOTE(review): TimeSeriesSplit training folds grow with every split, so
    rows used as test data in one fold appear to end up in the training pool
    of later folds once everything is concatenated -- confirm this is intended.
    """
    all_X_train, all_y_train, all_X_val, all_y_val, all_X_test, all_y_test = [], [], [], [], [], []
    for name, group in grouped:
        print(f"Subject: {name}")
        subject_features = group.drop(["game_section", "stress_label", "subject"], axis=1)
        subject_labels = group["stress_label"]

        # Standardise each subject on its own statistics.
        subject_features = StandardScaler().fit_transform(subject_features)

        features_tensor = torch.tensor(subject_features, dtype=torch.float32).to(device)
        labels_tensor = torch.tensor(subject_labels.values, dtype=torch.float32).to(device)

        for train_index, test_index in tscv.split(features_tensor):
            X_train_val, X_test = features_tensor[train_index], features_tensor[test_index]
            y_train_val, y_test = labels_tensor[train_index], labels_tensor[test_index]

            X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.2, shuffle=False)

            all_X_train.append(X_train)
            all_y_train.append(y_train)
            all_X_val.append(X_val)
            all_y_val.append(y_val)
            all_X_test.append(X_test)
            all_y_test.append(y_test)

    X_train = torch.cat(all_X_train).to(device)
    y_train = torch.cat(all_y_train).to(device)
    X_val = torch.cat(all_X_val).to(device)
    y_val = torch.cat(all_y_val).to(device)
    X_test = torch.cat(all_X_test).to(device)
    y_test = torch.cat(all_y_test).to(device)

    window_sizes = [2, 4, 6]

    for window_size in window_sizes:
        print(f"Window Size: {window_size}")
        if len(X_train) < window_size or len(y_train) < window_size:
            continue

        # Drop the tail so that every batch holds exactly window_size rows.
        usable_rows = len(X_train) // window_size * window_size
        train_dataset = TensorDataset(X_train[:usable_rows], y_train[:usable_rows])
        train_dataloader = DataLoader(train_dataset, batch_size=window_size, shuffle=False)

        model = StressLSTM(input_size, hidden_size, num_layers, dropout_rate_1, dropout_rate_2, dropout_rate_fc).to(device)
        optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

        # Fit the shared Kalman filter's parameters with EM.
        # NOTE(review): this uses the outputs of the still-untrained model and
        # updates the module-level `kf` in place -- confirm this is intended.
        model.eval()
        with torch.no_grad():
            training_outputs = model(X_train.unsqueeze(0).to(device))
            kf.em(training_outputs.cpu().detach().numpy(), n_iter=10)

        train_losses, val_losses, train_maes, val_maes = [], [], [], []
        filtered_train_losses, filtered_val_losses, filtered_train_maes, filtered_val_maes = [], [], [], []
        # Test metrics are collected per window size. They used to be shared
        # across all window sizes, so the averages printed for later windows
        # silently included the results of the earlier ones.
        losses, maes, filtered_losses, filtered_maes = [], [], [], []

        for epoch in range(num_epochs):
            print(f"Epoch {epoch+1}/{num_epochs}", end="\r")
            running_loss = 0.0
            running_mae = 0.0
            running_filtered_loss = 0.0
            running_filtered_mae = 0.0
            num_batches = 0

            # The evaluation at the end of each epoch switches to eval mode,
            # so training mode is re-enabled once at the start of every epoch.
            model.train()
            for batch_features, batch_labels in train_dataloader:
                loss, mae = calculate_loss_and_mae(model, criterion, mae_criterion, batch_features, batch_labels)
                # The filtered metrics are only reported, never back-propagated
                # (the filtered outputs are detached), so skip graph building.
                with torch.no_grad():
                    filtered_loss, filtered_mae = calculate_loss_and_mae(model, criterion, mae_criterion, batch_features, batch_labels, use_kalman_filter=True)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                running_loss += loss.item()
                running_mae += mae.item()
                running_filtered_loss += filtered_loss.item()
                running_filtered_mae += filtered_mae.item()
                num_batches += 1

            train_losses.append(running_loss / num_batches)
            train_maes.append(running_mae / num_batches)
            filtered_train_losses.append(running_filtered_loss / num_batches)
            filtered_train_maes.append(running_filtered_mae / num_batches)

            # Validation and test metrics for this epoch, raw and filtered.
            model.eval()
            with torch.no_grad():
                val_loss, val_mae = calculate_loss_and_mae(model, criterion, mae_criterion, X_val, y_val)
                val_filtered_loss, val_filtered_mae = calculate_loss_and_mae(model, criterion, mae_criterion, X_val, y_val, use_kalman_filter=True)
                test_loss, test_mae = calculate_loss_and_mae(model, criterion, mae_criterion, X_test, y_test)
                test_filtered_loss, test_filtered_mae = calculate_loss_and_mae(model, criterion, mae_criterion, X_test, y_test, use_kalman_filter=True)

            val_losses.append(val_loss.item())
            val_maes.append(val_mae.item())
            filtered_val_losses.append(val_filtered_loss.item())
            filtered_val_maes.append(val_filtered_mae.item())

            losses.append(test_loss.item())
            maes.append(test_mae.item())
            filtered_losses.append(test_filtered_loss.item())
            filtered_maes.append(test_filtered_mae.item())

        # All four panels share one y-range so they can be compared directly.
        curves = (train_losses, val_losses, filtered_train_losses, filtered_val_losses,
                  train_maes, val_maes, filtered_train_maes, filtered_val_maes)
        maxy = max(max(curve) for curve in curves)

        _plot_metric_grid(
            window_size,
            [
                ('MSE', '', train_losses, val_losses),
                ('MSE', 'Filtered ', filtered_train_losses, filtered_val_losses),
                ('MAE', '', train_maes, val_maes),
                ('MAE', 'Filtered ', filtered_train_maes, filtered_val_maes),
            ],
            maxy,
        )

        # Averages are taken over the per-epoch test metrics of this window.
        print("Average test MSE :", mean(losses))
        print("Average test MAE :", mean(maes))

        print("Average filtered test MSE :", mean(filtered_losses))
        print("Average filtered test MAE :", mean(filtered_maes))
        
# Repeat the whole train/evaluate procedure 10 times; every call builds and
# trains fresh models, so each repetition is an independent run.
for i in range(10):
    do_stuff()

Using device:  cuda


  warn("Encountered exception in computing model fit quality: " + str(e))


Evaluating with parameters: {'lr': 1.7200460023147145e-06, 'hidden_size': 246, 'num_layers': 3, 'dropout_rate_1': 0.5831521153450012, 'dropout_rate_2': 0.6510676383972167, 'dropout_rate_fc': 0.36499016284942626, 'weight_decay': 0.43428298778331276, 'batch_size': 85, 'num_epochs': 150}


  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.38639323388115837, Val Loss: 0.2935720980167389
Epoch [2/150], Loss: 0.38563668489519315, Val Loss: 0.2930510342121124
Epoch [3/150], Loss: 0.38521109109397156, Val Loss: 0.2925291955471039
Epoch [4/150], Loss: 0.384090798417643, Val Loss: 0.2920148968696594
Epoch [5/150], Loss: 0.38378547341957436, Val Loss: 0.2915027141571045
Epoch [6/150], Loss: 0.38269740917672546, Val Loss: 0.29099616408348083
Epoch [7/150], Loss: 0.3818275301196313, Val Loss: 0.29049786925315857
Epoch [8/150], Loss: 0.38060208078536945, Val Loss: 0.2900025546550751
Epoch [9/150], Loss: 0.38003199320223374, Val Loss: 0.28951606154441833
Epoch [10/150], Loss: 0.37940295454177814, Val Loss: 0.28903499245643616
Epoch [11/150], Loss: 0.3782068853092901, Val Loss: 0.28855812549591064
Epoch [12/150], Loss: 0.3776259412673318, Val Loss: 0.2880879044532776
Epoch [13/150], Loss: 0.37703609996933046, Val Loss: 0.28762325644493103
Epoch [14/150], Loss: 0.37612651260096136, Val Loss: 0.28717005252838135

  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.2624893104657531, Val Loss: 0.23043808341026306
Epoch [2/150], Loss: 0.2587579749778946, Val Loss: 0.23002737760543823
Epoch [3/150], Loss: 0.2578749622408097, Val Loss: 0.23010821640491486
Epoch [4/150], Loss: 0.2583316520537336, Val Loss: 0.23033036291599274
Epoch [5/150], Loss: 0.2590135567901781, Val Loss: 0.23052702844142914
Epoch [6/150], Loss: 0.2594773297730301, Val Loss: 0.23061710596084595
Epoch [7/150], Loss: 0.2595448973314727, Val Loss: 0.23057201504707336
Epoch [8/150], Loss: 0.25912836190320926, Val Loss: 0.23038963973522186
Epoch [9/150], Loss: 0.25827099258733616, Val Loss: 0.23008683323860168
Epoch [10/150], Loss: 0.25702008770723994, Val Loss: 0.2296917736530304
Epoch [11/150], Loss: 0.25541033850886913, Val Loss: 0.22922973334789276
Epoch [12/150], Loss: 0.2535555067916299, Val Loss: 0.22873133420944214
Epoch [13/150], Loss: 0.25149591541809946, Val Loss: 0.22822287678718567
Epoch [14/150], Loss: 0.24928295324867622, Val Loss: 0.22772654891014

  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.28153755897086286, Val Loss: 0.22677768766880035
Epoch [2/150], Loss: 0.23941119001045597, Val Loss: 0.2258443981409073
Epoch [3/150], Loss: 0.2352778013041125, Val Loss: 0.22567778825759888
Epoch [4/150], Loss: 0.2326231468923268, Val Loss: 0.22562570869922638
Epoch [5/150], Loss: 0.23057110593747543, Val Loss: 0.22565971314907074
Epoch [6/150], Loss: 0.22796768973149725, Val Loss: 0.2258041352033615
Epoch [7/150], Loss: 0.22562992363015802, Val Loss: 0.2260734736919403
Epoch [8/150], Loss: 0.2233927677078998, Val Loss: 0.2264787256717682
Epoch [9/150], Loss: 0.22095981038707005, Val Loss: 0.22710192203521729
Epoch [10/150], Loss: 0.21850686001783323, Val Loss: 0.2279781848192215
Epoch [11/150], Loss: 0.215982329753733, Val Loss: 0.2291620522737503
Epoch [12/150], Loss: 0.21350396011398765, Val Loss: 0.23061877489089966
Epoch [13/150], Loss: 0.21133091579003876, Val Loss: 0.23221440613269806
Epoch [14/150], Loss: 0.20952593816405227, Val Loss: 0.2338013648986816

  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.3095720180709447, Val Loss: 0.24973388016223907
Epoch [2/150], Loss: 0.30675807563321933, Val Loss: 0.247860386967659
Epoch [3/150], Loss: 0.30317050157380954, Val Loss: 0.24623900651931763
Epoch [4/150], Loss: 0.29934117799358706, Val Loss: 0.2448509931564331
Epoch [5/150], Loss: 0.2966514539771846, Val Loss: 0.2436772882938385
Epoch [6/150], Loss: 0.2942447155714035, Val Loss: 0.24264425039291382
Epoch [7/150], Loss: 0.29278582829449856, Val Loss: 0.24175697565078735
Epoch [8/150], Loss: 0.29074714862342393, Val Loss: 0.24099864065647125
Epoch [9/150], Loss: 0.2883304869356964, Val Loss: 0.24034014344215393
Epoch [10/150], Loss: 0.2877640033939055, Val Loss: 0.23976607620716095
Epoch [11/150], Loss: 0.28602418122547013, Val Loss: 0.23926371335983276
Epoch [12/150], Loss: 0.2845903627973582, Val Loss: 0.23882706463336945
Epoch [13/150], Loss: 0.2835846660791763, Val Loss: 0.23845155537128448
Epoch [14/150], Loss: 0.28291409706164683, Val Loss: 0.2381216436624527

  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.38507676934410306, Val Loss: 0.2714443504810333
Epoch [2/150], Loss: 0.3422239865568812, Val Loss: 0.24557064473628998
Epoch [3/150], Loss: 0.3013233460944086, Val Loss: 0.22625800967216492
Epoch [4/150], Loss: 0.266291497746152, Val Loss: 0.21911832690238953
Epoch [5/150], Loss: 0.2404950825542572, Val Loss: 0.22475329041481018
Epoch [6/150], Loss: 0.22556188371663372, Val Loss: 0.23343485593795776
Epoch [7/150], Loss: 0.21801282383291232, Val Loss: 0.23818837106227875
Epoch [8/150], Loss: 0.21176300298632875, Val Loss: 0.23963652551174164
Epoch [9/150], Loss: 0.2075324692297727, Val Loss: 0.23943443596363068
Epoch [10/150], Loss: 0.20363139167067504, Val Loss: 0.23770296573638916
Epoch [11/150], Loss: 0.20034019196736477, Val Loss: 0.23509648442268372
Epoch [12/150], Loss: 0.19596828944500416, Val Loss: 0.2330223023891449
Epoch [13/150], Loss: 0.19270318128138542, Val Loss: 0.23108166456222534
Epoch [14/150], Loss: 0.19123371487273083, Val Loss: 0.2291786968708

  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.36730089658675225, Val Loss: 0.274469792842865
Epoch [2/150], Loss: 0.35034867334424663, Val Loss: 0.27080240845680237
Epoch [3/150], Loss: 0.34155541511350557, Val Loss: 0.2636260390281677
Epoch [4/150], Loss: 0.32634552430949715, Val Loss: 0.25486740469932556
Epoch [5/150], Loss: 0.3096432223621952, Val Loss: 0.2470821738243103
Epoch [6/150], Loss: 0.2944989145096195, Val Loss: 0.24091894924640656
Epoch [7/150], Loss: 0.2816092611633633, Val Loss: 0.2362452745437622
Epoch [8/150], Loss: 0.2708254796330278, Val Loss: 0.23278281092643738
Epoch [9/150], Loss: 0.2618431039522157, Val Loss: 0.23027220368385315
Epoch [10/150], Loss: 0.25437228358963404, Val Loss: 0.2284969687461853
Epoch [11/150], Loss: 0.2481627877568826, Val Loss: 0.22728128731250763
Epoch [12/150], Loss: 0.24300279229952904, Val Loss: 0.22648431360721588
Epoch [13/150], Loss: 0.23871409929798623, Val Loss: 0.22599506378173828
Epoch [14/150], Loss: 0.2351478122383062, Val Loss: 0.22572720050811768


  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.3321561343967915, Val Loss: 0.25825268030166626
Epoch [2/150], Loss: 0.319541310756044, Val Loss: 0.2525310814380646
Epoch [3/150], Loss: 0.3096259696239775, Val Loss: 0.2482045590877533
Epoch [4/150], Loss: 0.30162279250269586, Val Loss: 0.244896799325943
Epoch [5/150], Loss: 0.29492569765584037, Val Loss: 0.24228373169898987
Epoch [6/150], Loss: 0.28947007620537824, Val Loss: 0.24020060896873474
Epoch [7/150], Loss: 0.28502742527899416, Val Loss: 0.23851251602172852
Epoch [8/150], Loss: 0.28090474884110417, Val Loss: 0.23710711300373077
Epoch [9/150], Loss: 0.27784978102215313, Val Loss: 0.2359539419412613
Epoch [10/150], Loss: 0.2749295857819644, Val Loss: 0.2349773496389389
Epoch [11/150], Loss: 0.2721877534535121, Val Loss: 0.23413772881031036
Epoch [12/150], Loss: 0.27015708521516485, Val Loss: 0.2334299385547638
Epoch [13/150], Loss: 0.2683527613736012, Val Loss: 0.23282037675380707
Epoch [14/150], Loss: 0.26630964407019997, Val Loss: 0.23228439688682556
E

  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.31032369772130447, Val Loss: 0.2499018758535385
Epoch [2/150], Loss: 0.31015921205299785, Val Loss: 0.24962212145328522
Epoch [3/150], Loss: 0.31043578951865575, Val Loss: 0.24934610724449158
Epoch [4/150], Loss: 0.30977088152782395, Val Loss: 0.24907700717449188
Epoch [5/150], Loss: 0.30855827686763454, Val Loss: 0.2488100379705429
Epoch [6/150], Loss: 0.30812560330707367, Val Loss: 0.24854421615600586
Epoch [7/150], Loss: 0.30808839118459597, Val Loss: 0.24828483164310455
Epoch [8/150], Loss: 0.3070463494766714, Val Loss: 0.24802690744400024
Epoch [9/150], Loss: 0.3066381316342092, Val Loss: 0.2477826029062271
Epoch [10/150], Loss: 0.30620231886201027, Val Loss: 0.24753136932849884
Epoch [11/150], Loss: 0.3047847171173484, Val Loss: 0.24729560315608978
Epoch [12/150], Loss: 0.3050206784677323, Val Loss: 0.24705635011196136
Epoch [13/150], Loss: 0.3048602563047887, Val Loss: 0.2468247413635254
Epoch [14/150], Loss: 0.3031202740835483, Val Loss: 0.246599450707435

  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.37125684458259817, Val Loss: 0.27946433424949646
Epoch [2/150], Loss: 0.36111093538430283, Val Loss: 0.2730114758014679
Epoch [3/150], Loss: 0.3513003930638001, Val Loss: 0.26727038621902466
Epoch [4/150], Loss: 0.3425038147909747, Val Loss: 0.2622385323047638
Epoch [5/150], Loss: 0.3342928248429474, Val Loss: 0.2577550709247589
Epoch [6/150], Loss: 0.32677591185639765, Val Loss: 0.2537767291069031
Epoch [7/150], Loss: 0.3188858545927635, Val Loss: 0.2502041459083557
Epoch [8/150], Loss: 0.3124085494564717, Val Loss: 0.2470034658908844
Epoch [9/150], Loss: 0.30728780249417886, Val Loss: 0.24417079985141754
Epoch [10/150], Loss: 0.3010216185679234, Val Loss: 0.24161942303180695
Epoch [11/150], Loss: 0.29574935172024686, Val Loss: 0.2393299639225006
Epoch [12/150], Loss: 0.2911330923220717, Val Loss: 0.2373025119304657
Epoch [13/150], Loss: 0.2859675466744904, Val Loss: 0.23548923432826996
Epoch [14/150], Loss: 0.28138686971368937, Val Loss: 0.233893021941185
Epoch

  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.28905487192569523, Val Loss: 0.24021539092063904
Epoch [2/150], Loss: 0.2845450994343712, Val Loss: 0.23866596817970276
Epoch [3/150], Loss: 0.2813841411229837, Val Loss: 0.2377292662858963
Epoch [4/150], Loss: 0.27918714016757823, Val Loss: 0.23717951774597168
Epoch [5/150], Loss: 0.2781475870870054, Val Loss: 0.236897274851799
Epoch [6/150], Loss: 0.27759334287902976, Val Loss: 0.23679710924625397
Epoch [7/150], Loss: 0.27748900693380735, Val Loss: 0.23680433630943298
Epoch [8/150], Loss: 0.27758081948671204, Val Loss: 0.23687665164470673
Epoch [9/150], Loss: 0.27772488343723856, Val Loss: 0.2369886338710785
Epoch [10/150], Loss: 0.2779309972762488, Val Loss: 0.23711124062538147
Epoch [11/150], Loss: 0.2782060839440469, Val Loss: 0.23722733557224274
Epoch [12/150], Loss: 0.27848228132147546, Val Loss: 0.2373277246952057
Epoch [13/150], Loss: 0.2787211753714543, Val Loss: 0.23740233480930328
Epoch [14/150], Loss: 0.27887486468918704, Val Loss: 0.2374452650547027

  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.3041056174894466, Val Loss: 0.23206397891044617
Epoch [2/150], Loss: 0.25181388913416397, Val Loss: 0.22732125222682953
Epoch [3/150], Loss: 0.24077032026195644, Val Loss: 0.22648964822292328
Epoch [4/150], Loss: 0.23752460898101038, Val Loss: 0.226273313164711
Epoch [5/150], Loss: 0.23651801888812699, Val Loss: 0.22618164122104645
Epoch [6/150], Loss: 0.2360085942417237, Val Loss: 0.22611969709396362
Epoch [7/150], Loss: 0.23556091906685456, Val Loss: 0.22605863213539124
Epoch [8/150], Loss: 0.23505211895441308, Val Loss: 0.22599932551383972
Epoch [9/150], Loss: 0.23451458648139356, Val Loss: 0.22593896090984344
Epoch [10/150], Loss: 0.23393963062770518, Val Loss: 0.22587929666042328
Epoch [11/150], Loss: 0.23333137336314894, Val Loss: 0.22582097351551056
Epoch [12/150], Loss: 0.23260608945480166, Val Loss: 0.22576184570789337
Epoch [13/150], Loss: 0.2318624086842379, Val Loss: 0.22570879757404327
Epoch [14/150], Loss: 0.23095123746487148, Val Loss: 0.2256611734

  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.3574807510369134, Val Loss: 0.2754947245121002
Epoch [2/150], Loss: 0.3542831078342789, Val Loss: 0.27335888147354126
Epoch [3/150], Loss: 0.35167005877675755, Val Loss: 0.27136319875717163
Epoch [4/150], Loss: 0.3480748681693661, Val Loss: 0.2695116698741913
Epoch [5/150], Loss: 0.3445585267428233, Val Loss: 0.2677985429763794
Epoch [6/150], Loss: 0.34154464325354417, Val Loss: 0.2661969065666199
Epoch [7/150], Loss: 0.3395192026901914, Val Loss: 0.26469185948371887
Epoch [8/150], Loss: 0.3365361247762886, Val Loss: 0.26329246163368225
Epoch [9/150], Loss: 0.3338565143699549, Val Loss: 0.26198750734329224
Epoch [10/150], Loss: 0.33120483512591037, Val Loss: 0.2607637047767639
Epoch [11/150], Loss: 0.3289101721756921, Val Loss: 0.2596139907836914
Epoch [12/150], Loss: 0.3272454316005567, Val Loss: 0.2585349380970001
Epoch [13/150], Loss: 0.3251203812601767, Val Loss: 0.2575162351131439
Epoch [14/150], Loss: 0.32288262953183483, Val Loss: 0.2565637528896332
Epoch 

  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.39509963868643905, Val Loss: 0.29657483100891113
Epoch [2/150], Loss: 0.38866520857774334, Val Loss: 0.2923291325569153
Epoch [3/150], Loss: 0.38183348039623166, Val Loss: 0.28845322132110596
Epoch [4/150], Loss: 0.3752982386952785, Val Loss: 0.284925639629364
Epoch [5/150], Loss: 0.3707820794957339, Val Loss: 0.2817479372024536
Epoch [6/150], Loss: 0.3650274879344907, Val Loss: 0.2788664698600769
Epoch [7/150], Loss: 0.3601384802003864, Val Loss: 0.276231974363327
Epoch [8/150], Loss: 0.355563896906669, Val Loss: 0.27383187413215637
Epoch [9/150], Loss: 0.3515226973800874, Val Loss: 0.27163398265838623
Epoch [10/150], Loss: 0.34815638973453983, Val Loss: 0.2695888876914978
Epoch [11/150], Loss: 0.344980658108338, Val Loss: 0.26774686574935913
Epoch [12/150], Loss: 0.341630703815427, Val Loss: 0.26601719856262207
Epoch [13/150], Loss: 0.3382213467762607, Val Loss: 0.2644144892692566
Epoch [14/150], Loss: 0.33551791282828713, Val Loss: 0.2629396319389343
Epoch [15

  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.3754686903741134, Val Loss: 0.27048996090888977
Epoch [2/150], Loss: 0.3378645782008104, Val Loss: 0.2595643699169159
Epoch [3/150], Loss: 0.32509363224500254, Val Loss: 0.2555959224700928
Epoch [4/150], Loss: 0.3189807582825218, Val Loss: 0.25338926911354065
Epoch [5/150], Loss: 0.3148103843264486, Val Loss: 0.25147783756256104
Epoch [6/150], Loss: 0.3108145431519728, Val Loss: 0.24945133924484253
Epoch [7/150], Loss: 0.3062670915323782, Val Loss: 0.24719543755054474
Epoch [8/150], Loss: 0.30115477736299373, Val Loss: 0.2447490245103836
Epoch [9/150], Loss: 0.2956116518216233, Val Loss: 0.24222515523433685
Epoch [10/150], Loss: 0.289716462258645, Val Loss: 0.23971936106681824
Epoch [11/150], Loss: 0.28370096732556105, Val Loss: 0.23732270300388336
Epoch [12/150], Loss: 0.27775701441425726, Val Loss: 0.23511400818824768
Epoch [13/150], Loss: 0.27199914880026743, Val Loss: 0.2331455796957016
Epoch [14/150], Loss: 0.26653355072700574, Val Loss: 0.23144029080867767


  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.3192296041739023, Val Loss: 0.24088649451732635
Epoch [2/150], Loss: 0.2750682629119124, Val Loss: 0.22954612970352173
Epoch [3/150], Loss: 0.2510419258246979, Val Loss: 0.22605782747268677
Epoch [4/150], Loss: 0.23709410199174477, Val Loss: 0.2257470339536667
Epoch [5/150], Loss: 0.228907853111835, Val Loss: 0.22651952505111694
Epoch [6/150], Loss: 0.22419257704468984, Val Loss: 0.2275024801492691
Epoch [7/150], Loss: 0.22108845101939814, Val Loss: 0.22838759422302246
Epoch [8/150], Loss: 0.21890989829065097, Val Loss: 0.22908934950828552
Epoch [9/150], Loss: 0.21782379559121065, Val Loss: 0.22959372401237488
Epoch [10/150], Loss: 0.21709655755830387, Val Loss: 0.22996830940246582
Epoch [11/150], Loss: 0.2162699914972989, Val Loss: 0.23023676872253418
Epoch [12/150], Loss: 0.2157152942471927, Val Loss: 0.23045334219932556
Epoch [13/150], Loss: 0.21561526986019267, Val Loss: 0.2305585741996765
Epoch [14/150], Loss: 0.21538227174671426, Val Loss: 0.230671390891075

  warn("Encountered exception in computing model fit quality: " + str(e))
  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.43938867727239783, Val Loss: 0.3245663046836853
Epoch [2/150], Loss: 0.43873702749089505, Val Loss: 0.324272483587265
Epoch [3/150], Loss: 0.4378850019797135, Val Loss: 0.3239773213863373
Epoch [4/150], Loss: 0.43682836888440174, Val Loss: 0.323682963848114
Epoch [5/150], Loss: 0.43675117351470844, Val Loss: 0.3233848214149475
Epoch [6/150], Loss: 0.4359564622179331, Val Loss: 0.3230940103530884
Epoch [7/150], Loss: 0.43637115754844513, Val Loss: 0.32280170917510986
Epoch [8/150], Loss: 0.43600217673434494, Val Loss: 0.3225134313106537
Epoch [9/150], Loss: 0.43492259199539307, Val Loss: 0.3222251236438751
Epoch [10/150], Loss: 0.4343405004371615, Val Loss: 0.32193970680236816
Epoch [11/150], Loss: 0.4347143950050368, Val Loss: 0.321659117937088
Epoch [12/150], Loss: 0.4333549762651434, Val Loss: 0.3213786482810974
Epoch [13/150], Loss: 0.43372963564744327, Val Loss: 0.3210982382297516
Epoch [14/150], Loss: 0.4330263484872001, Val Loss: 0.32082244753837585
Epoch [

  return F.mse_loss(input, target, reduction=self.reduction)
  return F.mse_loss(input, target, reduction=self.reduction)


Epoch [1/150], Loss: 0.36198062244802715, Val Loss: 0.27820777893066406
Epoch [2/150], Loss: 0.3615946862846613, Val Loss: 0.2780826687812805
Epoch [3/150], Loss: 0.36221623203406733, Val Loss: 0.2779569923877716
Epoch [4/150], Loss: 0.3615015803525845, Val Loss: 0.2778315544128418
Epoch [5/150], Loss: 0.36167225607981285, Val Loss: 0.27770593762397766
Epoch [6/150], Loss: 0.3610603898142775, Val Loss: 0.27758118510246277
Epoch [7/150], Loss: 0.36078667544449367, Val Loss: 0.27745571732521057
Epoch [8/150], Loss: 0.3613309249592324, Val Loss: 0.27733203768730164
Epoch [9/150], Loss: 0.3608045004929105, Val Loss: 0.2772078812122345
Epoch [10/150], Loss: 0.36001117223252854, Val Loss: 0.2770856022834778
Epoch [11/150], Loss: 0.3599764894073208, Val Loss: 0.2769635021686554
Epoch [12/150], Loss: 0.3600929093857606, Val Loss: 0.27684152126312256
Epoch [13/150], Loss: 0.35972173636158306, Val Loss: 0.27672079205513
Epoch [14/150], Loss: 0.3595679118918876, Val Loss: 0.27659985423088074
Epoc