# CNN

Use Resnet to train models to fit five travel behavioral variables by using images. The default is using only the black and white images with resnet18 to fit continous outputs, but the codes leave the flexibility of using other models, images, and output types. Several functions are not useful yet: bottleneck_resnet18, return_bottleneck_resnet18, and train_discrete_model. 

#### To be done: adjustment along many dimensions - hyperparameters, model choice, etc.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim

from tqdm.notebook import tqdm

import pickle

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import util
import statsmodels.api as sm
from scipy import stats
import copy

from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import log_loss
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import PolynomialFeatures

In [2]:
# ALWAYS choose devise first.
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Helper Functions

In [3]:
# CHANGED: filepath for loading datasets

def initialize_data(image_type, output_var, output_type, input_var, BE_var, num_categories, size):
    # outputs: randomized training and testing sets for NHTS, BE, images, and y.
    
    ### read image array
    if image_type == 'rgb':
        image_array_ = np.load("../data_process/image_array_rgb_tract_large.npy", mmap_mode='r')
        image_array = image_array_[:size,]
    elif image_type == 'bw':
        image_array_ = np.load("../data_process/image_array_bw_tract_large.npy", mmap_mode='r')
        image_array = image_array_[:size,]        
    elif image_type == 'merge':
        bw_image_array_ = np.load("../data_process/image_array_bw_tract_large.npy", mmap_mode='r')
        rgb_image_array_ = np.load("../data_process/image_array_rgb_tract_large.npy", mmap_mode='r')
        bw_image_array = bw_image_array_[:size,]
        rgb_image_array = rgb_image_array_[:size,]        
        image_array = np.concatenate([rgb_image_array, bw_image_array], axis=1)
    
    ### create output array
    df_ = pd.read_csv("../data_process/df_merged_tract_large.csv")
    df = df_.iloc[:size,]
    y_ = df[output_var].values 
    # cut y into categories for discrete variables
    if output_type == 'continuous':
        y = copy.deepcopy(y_)
    elif output_type == 'discrete':
        y = np.array(pd.qcut(y_, q = num_categories, labels=np.arange(num_categories))) 
    x = df[input_var]
    BE = df[BE_var]
            
    ### randomization
    shuffle_idx = np.arange(size)
    np.random.seed(0) # important: don't change the seed number, unless the seed number across scripts are all changed.
    np.random.shuffle(shuffle_idx)
    train_ratio = 0.8

    ###
    # y
    if output_type == 'discrete':
        y_train = y[shuffle_idx[:int(train_ratio*size)]].astype("int")
        y_test = y[shuffle_idx[int(train_ratio*size):]].astype("int")
    elif output_type == 'continuous':
        y_train = y[shuffle_idx[:int(train_ratio*size)]].astype("float32")
        y_test = y[shuffle_idx[int(train_ratio*size):]].astype("float32")
    # BE
    BE_train = BE.values[shuffle_idx[:int(train_ratio*size)]].astype("float32")
    BE_test = BE.values[shuffle_idx[int(train_ratio*size):]].astype("float32")        
    # image array
    x_train_images = image_array[shuffle_idx[:int(train_ratio*size)],].astype("float32")
    x_test_images = image_array[shuffle_idx[int(train_ratio*size):],].astype("float32")
    # NHTS
    x_train = x.values[shuffle_idx[:int(train_ratio*size)]].astype("float32")
    x_test = x.values[shuffle_idx[int(train_ratio*size):]].astype("float32")
    
    return y_train,y_test,BE_train,BE_test,x_train,x_test,x_train_images,x_test_images

# # test 
# image_type = 'bw'
# output_var = 'HHFAMINC_mean'
# output_type = 'continuous'
# input_var=['R_AGE_IMP_mean', 'HHSIZE_mean', 'HHFAMINC_mean', 'HBHTNRNT_mean', 'HBPPOPDN_mean', 'HBRESDN_mean', 
#            'R_SEX_IMP_2_mean', 'EDUC_2_mean', 'HH_RACE_2_mean', 'HOMEOWN_1_mean', 'HOMEOWN_2_mean',
#            'HBHUR_R_mean', 'HBHUR_S_mean', 'HBHUR_T_mean','HBHUR_U_mean']
# BE_var = ['density', 'diversity', 'design']
# num_categories = 1 # (1) certain category values can cause errors. (2) when output_type = 'continuous', this value needs to be 1.
# size = 10000 # size needs to be smaller than the max
# # 
# y_train,y_test,BE_train,BE_test,x_train,x_test,x_train_images,x_test_images = \
#     initialize_data(image_type, output_var, output_type, input_var, BE_var, num_categories, size)

# print(y_train.shape)
# print(y_test.shape)
# print(x_train_images.shape)
# print(x_test_images.shape)
# plt.figure()
# plt.boxplot(y_train)
# plt.figure()
# plt.boxplot(y_test)

In [4]:
def initialize_model(model_name, num_categories, input_channels = 3, use_pretrained=True, full_training=False):
    # initliaze the CNN model.
    # default input image size = 3*224*224, but inputs and output channels can be changed. 
    # num_categories: output channels. For continuous varialbes, use num_categories = 1.
    # return the model

    if model_name == 'resnet18':
        """ resnet 18"""
        model_ft = models.resnet18(pretrained=use_pretrained)
        # train only the last layer.
        for param in model_ft.parameters():
            param.requires_grad=full_training
        if input_channels != 3:
            # Edit the input channels.
            model_ft.conv1 = nn.Conv2d(input_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_categories) # if output_type == continuous, then num_categories = 1.

    elif model_name == 'alexnet':
        """ alexnet """
        model_ft = models.alexnet(pretrained=use_pretrained)
        for param in model_ft.parameters():
            param.requires_grad = full_training 
        if input_channels != 3:
            model_ft.features[0] = nn.Conv2d(input_channels, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))    
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_categories) # if output_type == continuous, then num_categories = 1.

    elif model_name == 'vgg':
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        for param in model_ft.parameters():
            param.requires_grad = full_training 
        if input_channels != 3:
            model_ft.features[0] = nn.Conv2d(input_channels, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_categories) # if output_type == continuous, then num_categories = 1.

    elif model_name == 'squeezenet':
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        for param in model_ft.parameters():
            param.requires_grad = full_training 
        if input_channels != 3:
            model_ft.features[0] = nn.Conv2d(input_channels, 96, kernel_size=(7, 7), stride=(2, 2))
        model_ft.classifier[1] = nn.Conv2d(512, num_categories, kernel_size=(1,1), stride=(1,1))

    elif model_name == 'densenet':
        model_ft = models.densenet121(pretrained=use_pretrained)
        for param in model_ft.parameters():
            param.requires_grad = full_training
        if input_channels != 3:
            model_ft.features[0] = nn.Conv2d(input_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_categories)

    elif model_name == 'wide_resnet':
        model_ft = models.wide_resnet50_2(pretrained=use_pretrained)
        for param in model_ft.parameters():
            param.requires_grad = full_training 
        if input_channels != 3:
            model_ft.conv1 = nn.Conv2d(input_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_categories)
        
    elif model_name == 'mnasnet':
        model_ft = models.mnasnet1_0(pretrained=use_pretrained)
        for param in model_ft.parameters():
            param.requires_grad = full_training
        if input_channels != 3:
            model_ft.layers[0] = nn.Conv2d(input_channels, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        num_ftrs = model_ft.classifier[1].in_features
        model_ft.classifier[1] = nn.Linear(num_ftrs, num_categories)
        
    return model_ft

# # test 1. initialize model for continuous var
# model_name = 'resnet18'
# num_categories = 1 
# input_channels = 4
# use_pretrained = True
# full_training = True
# model = initialize_model(model_name, num_categories, input_channels, use_pretrained, full_training)
# model.to(device)

# # test 2. initialize model for discrete var
# model_name = 'resnet18'
# # num_categories = 1 
# input_channels = 4
# use_pretrained = True
# full_training = True
# model = initialize_model(model_name, num_categories, input_channels, use_pretrained, full_training)
# model.to(device)

# # test 3. initialize model for continuous var
# model_name = 'bottleneck_resnet18'
# num_categories = 1 
# input_channels = 4
# use_pretrained = True
# full_training = True
# model = initialize_model(model_name, num_categories, input_channels, use_pretrained, full_training)
# model.to(device)


In [5]:
# class bottleneck_resnet18(nn.Module):
#     # This model does NOT work yet. It seems that the fc layer or the upsampling do not work...
#     # Goal: create a resnet architecture with bottleneck in the middle that reduces information into several nodes.
#     def __init__(self, num_categories, num_bottleneck, input_channels = 3, use_pretrained=True, full_training=False):
#         super(bottleneck_resnet18, self).__init__()
#         ref = models.resnet18(pretrained=use_pretrained)
#         self.sequence1 = nn.Sequential(ref.conv1, ref.bn1, ref.relu, ref.maxpool, ref.layer1,
#                                        ref.layer2)
#         ### condense 
#         if num_bottleneck == 1:
#             self.condense = nn.AvgPool3d((128,28,28))
#         elif num_bottleneck == 2:
#             self.condense = nn.AvgPool3d((128,28,14))
#         elif num_bottleneck == 3:
#             self.condense = nn.AvgPool3d((128,28,9))

#         ### upsampling
#         self.upsample = nn.Sequential(nn.Conv2d(num_bottleneck, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
#                                       nn.Upsample((28, 28)))
#         self.sequence2 = nn.Sequential(ref.layer3, ref.layer4, ref.avgpool)
#         self.fc = ref.fc
        
#         ### edit parameters
#         for param in self.parameters():
#             param.requires_grad=full_training
#         if input_channels != 3:
#             self.sequence1[0]=nn.Conv2d(input_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#         num_ftrs=self.fc.in_features
#         self.fc=nn.Linear(num_ftrs, num_categories)
        
#     def forward(self, x):
#         x=self.sequence1(x)
#         x=self.condense(x)
#         x=self.upsample(x)
#         x=self.sequence2(x)
#         x=x.squeeze() # sw: this line is important, but I don't understand why resnet18 does not need it...
#         out=self.fc(x)
#         return out

In [6]:
# def return_bottleneck_renset(model,device,x_train_images,x_test_images,y_train,y_test):
#     # This function does not work yet.
#     # Goal: return the several nodes' values from the bottleneck resnet architecture.
#     from sklearn.preprocessing import MinMaxScaler

#     bottleneck_train_list = []
#     def hook_train(module,inputs,outputs):
#         bottleneck_train_list.append(outputs)
        
#     bottleneck_test_list = []
#     def hook_test(module,inputs,outputs):
#         bottleneck_test_list.append(outputs)

#     x_train_images_norm = x_train_images/255
#     x_test_images_norm = x_test_images/255

#     x_train_torch = torch.from_numpy(x_train_images_norm)
#     x_test_torch = torch.from_numpy(x_test_images_norm)
#     y_train_torch = torch.from_numpy(y_train)
#     y_test_torch = torch.from_numpy(y_test)

#     # create data loader: train and test. 
#     train_ds = TensorDataset(x_train_torch, y_train_torch)
#     batch_size = 50
#     train_dl_no_shuffle = DataLoader(train_ds, batch_size, shuffle = False) # important: NO SHUFFLE.

#     test_ds = TensorDataset(x_test_torch, y_test_torch)
#     batch_size = 50
#     test_dl_no_shuffle = DataLoader(test_ds, batch_size, shuffle = False) # important: NO SHUFFLE.

#     for param in model.parameters():
#         param.requires_grad=False # save space
    
#     for inputs, labels in train_dl_no_shuffle:
#         # to device
#         inputs = inputs.to(device)
#         labels = labels.to(device)
#         outputs = model(inputs)
#         bottleneck_train_list = model.layer3[1].conv1.register_forward_hook(hook_train)

#     for inputs, labels in test_dl_no_shuffle:
#         inputs = inputs.to(device)
#         labels = labels.to(device)
#         outputs = model(inputs)
#         bottleneck_test_list = model.layer3[1].conv1.register_forward_hook(hook_test)

#     return bottleneck_train_list,bottleneck_test_list


In [7]:
# def train_discrete_model(model, train_dl, test_dl, criterion, optimizer, device, n_epoch = 25):
#     # Train a model with discrete outputs.
#     # Outputs: model; training and testing accuracy/log-loss.
#     # But so far this function is not used because of bad performance on discrete outputs.
#     log_loss_train_list=[]
#     log_loss_test_list=[]
#     accuracy_train_list=[]
#     accuracy_test_list=[]

#     # automatic model searching.
#     best_model_wts = copy.deepcopy(model.state_dict())
#     best_acc = 0.0
    
#     # iterate
#     for epoch in range(n_epoch):
    
#         running_log_loss_train = 0.0
#         running_log_loss_test = 0.0
#         correct_train = 0
#         total_train = 0
#         correct_test = 0
#         total_test = 0

#         # training    
#         for inputs, labels in tqdm(train_dl):
#             # to device
#             inputs = inputs.to(device)
#             labels = labels.to(device)

#             # forward + backward
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()

#             # evaluate prediction
#             _, predicted = torch.max(outputs.data, 1)
#             total_train += labels.size(0)
#             correct_train += (predicted == labels).sum().item()

#             # evaluate log loss
#             running_log_loss_train += loss.item()

#             # optimize
#             with torch.no_grad():
#                 optimizer.step()
#                 optimizer.zero_grad()

#         # testing
#         for inputs, labels in test_dl:
#             # to device
#             inputs = inputs.to(device)
#             labels = labels.to(device)

#             # forward + backward
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)

#             # evaluate log loss
#             running_log_loss_test += loss.item()

#             # evaluate prediction
#             _, predicted = torch.max(outputs.data, 1)
#             total_test += labels.size(0)
#             correct_test += (predicted == labels).sum().item()
        
#         # print
#         print("Epoch {}: Training Loss {}; Testing Loss {}".format(epoch, running_log_loss_train, running_log_loss_test))
#         print("Epoch {}: Training Accuracy {}; Testing Accuracy {}".format(epoch, correct_train/total_train, correct_test/total_test))

#         # append loss here.
#         log_loss_train_list.append(running_log_loss_train)
#         log_loss_test_list.append(running_log_loss_test)
#         accuracy_train_list.append(correct_train/total_train)
#         accuracy_test_list.append(correct_test/total_test)
        
#         if correct_test/total_test > best_acc:
#             best_acc = correct_test/total_test
#             best_model_wts = copy.deepcopy(model.state_dict())

#     # load the best model weights
#     model.load_state_dict(best_model_wts)
#     return model, log_loss_train_list, log_loss_test_list, accuracy_train_list, accuracy_test_list


In [8]:
def train_continuous_model(model, train_dl, test_dl, criterion, optimizer, device, total_mse_train, total_mse_test, n_epoch = 25):
    # This function trains the model with continous outputs.
    # outputs: model, and R2 and MSE for training and testing
    mse_train_list = []
    mse_test_list = []
    r_square_train_list = []
    r_square_test_list = []

    # automatic model searching.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_r_square = 0.0
    
    for epoch in range(n_epoch): 
        running_mse_train = 0.0
        running_mse_test = 0.0
        
        # training
        for inputs, labels in tqdm(train_dl):
            # to device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # forward + backward
            outputs = model(inputs)
            # sw: be careful about the dimension matching at this point...
            loss = criterion(outputs.view(-1), labels) # this .view(-1) seems specific to continuous variables
            loss.backward()

            # performance
            running_mse_train += loss.item()*batch_size

            # optimize
            with torch.no_grad():
                optimizer.step()
                optimizer.zero_grad()

        # testing
        for inputs, labels in test_dl:
            # to device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # forward + backward        
            outputs = model(inputs)
            loss = criterion(outputs.view(-1), labels) # this .view(-1) is specific to continuous variables
            running_mse_test += loss.item()*batch_size # this *batch_size is specific to continuous variables.

        # R square for a batch
        running_r_square_train = 1-running_mse_train/total_mse_train.item()
        running_r_square_test = 1-running_mse_test/total_mse_test.item()
        
        print("Epoch {}: Training MSE {}; Testing MSE {}".format(epoch, running_mse_train, running_mse_test))
        print("Epoch {}: Training R2 {}; Testing R2 {}".format(epoch, running_r_square_train, running_r_square_test))

        mse_train_list.append(running_mse_train)
        mse_test_list.append(running_mse_test)
        r_square_train_list.append(running_r_square_train)
        r_square_test_list.append(running_r_square_test)

        # check overfitting for early stopping. This is designed in an ad-hoc way.
        if epoch > 5 and running_r_square_test < 0.0:
            break # break the for loop.
        
        # store the best performance.
        if running_r_square_test > best_r_square:
            best_r_square = running_r_square_test
            best_model_wts = copy.deepcopy(model.state_dict())

    # load the weights of the best model
    model.load_state_dict(best_model_wts)
    return model, mse_train_list, mse_test_list, r_square_train_list, r_square_test_list


## Train resnet18 for continous outputs.

In [9]:
#CHANGED: only selected needed vars
# set up.
output_list = ['HHVEHCNT_mean_norm', 'HHVEHCNT_P_CAP_mean_norm', 'TRPTRANS_1_mean_norm', 'TRPTRANS_2_mean_norm', 'TRPTRANS_3_mean_norm']
input_var=['R_AGE_IMP_mean', 'HHSIZE_mean', 'HHFAMINC_mean', 'HBHTNRNT_mean', 'HBPPOPDN_mean', 'HBRESDN_mean', 
           'R_SEX_IMP_2_mean', 'EDUC_2_mean', 'HH_RACE_2_mean', 'HOMEOWN_1_mean', 'HOMEOWN_2_mean',
           'HBHUR_R_mean', 'HBHUR_S_mean', 'HBHUR_T_mean','HBHUR_U_mean']
BE_var = ['density', 'diversity', 'design']
image_types = ['bw', 'rgb', 'merge'] # It can be 'rgb', 'bw', 'merge'
output_type = 'continuous' 
num_categories = 1 # Certain category values can cause errors. When output_type = 'continuous', this value needs to be 1.
size = 12000 # size needs to be smaller than the max (18491).

model_name = 'resnet18' 
model_dic = {}


In [10]:
for image_type in image_types:
    performance_continuous = {}
    
    for output_var in output_list:

        print(output_var)

        # data set up
        y_train,y_test,BE_train,BE_test,x_train,x_test,x_train_images,x_test_images = \
            initialize_data(image_type, output_var, output_type, input_var, BE_var, num_categories, size)

        # process data
        x_train_images_norm = x_train_images/255 # very crude processing. It is improvable.
        x_test_images_norm = x_test_images/255

        x_train_torch = torch.from_numpy(x_train_images_norm)
        x_test_torch = torch.from_numpy(x_test_images_norm)
        y_train_torch = torch.from_numpy(y_train)
        y_test_torch = torch.from_numpy(y_test)

        # create data loader: train and test. 
        train_ds = TensorDataset(x_train_torch, y_train_torch)
        batch_size = 100
        train_dl = DataLoader(train_ds, batch_size, shuffle = True)

        test_ds = TensorDataset(x_test_torch, y_test_torch)
        batch_size = 100
        test_dl = DataLoader(test_ds, batch_size, shuffle = True)

        # model set up 
        #TODO: Change use_pretrained, full_training as needed
        input_channels = 4 # 4 for BW images; 3 for RGB images; 7 for merged images.
        use_pretrained = True # unclear whether True or False is better.
        full_training = True # Fully retraining the network seems to work better.
    #     num_bottleneck = 3 # Used for the bottleneck model

        if model_name == 'bottleneck_resnet18': # It does not work.
            model = bottleneck_resnet18(num_categories, num_bottleneck, input_channels, use_pretrained, full_training)
            model.to(device)
        else: 
            # 'resnet18' and others works
            if image_type == "rgb":
                input_channels = 3
            elif image_type == "bw":
                input_channels = 4
            elif image_type == "merge":
                input_channels = 7
            model = initialize_model(model_name, num_categories, input_channels, use_pretrained, full_training)
            model.to(device)

        # training set up
        criterion = nn.MSELoss(reduction='mean')
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        n_epoch = 25

        # create baseline mse
        total_mse_train = criterion(y_train_torch.mean().repeat(y_train_torch.size()), y_train_torch)*y_train_torch.size()[0]
        total_mse_test = criterion(y_test_torch.mean().repeat(y_test_torch.size()), y_test_torch)*y_test_torch.size()[0]
        print(total_mse_train)
        print(total_mse_test)

        # training here.
        model, mse_train_list, mse_test_list, r_square_train_list, r_square_test_list = \
            train_continuous_model(model, train_dl, test_dl, criterion, optimizer, device, total_mse_train, total_mse_test, n_epoch)

        # save models.
        PATH = './models/'+model_name+'_'+output_var+'_'+image_type+'.pth'
        torch.save(model.state_dict(), PATH)
        model_dic[output_var]=model.state_dict()

        # save performance
        performance_continuous[output_var] = {}
        performance_continuous[output_var]['mse_train_list']=mse_train_list
        performance_continuous[output_var]['mse_test_list']=mse_test_list
        performance_continuous[output_var]['r_square_train_list']=r_square_train_list
        performance_continuous[output_var]['r_square_test_list']=r_square_test_list

        print("Printing performance_continuous")
        print(performance_continuous)
        %store performance_continuous
        
                
        with open('outputs/performance_continuous_'+image_type+'_'+model_name+'.pickle', 'wb') as h:
            pickle.dump(performance_continuous, h, protocol=pickle.HIGHEST_PROTOCOL)


HHVEHCNT_mean_norm
tensor(10700.4521)
tensor(2549.4304)


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12698.27309846878; Testing MSE 2314.804685115814
Epoch 0: Training R2 -0.1867043487805331; Testing R2 0.09203064848235809


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9758.412963151932; Testing MSE 2336.7906153202057
Epoch 1: Training R2 0.08803732517257479; Testing R2 0.08340678880272612


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9714.318877458572; Testing MSE 2354.092448949814
Epoch 2: Training R2 0.09215809362998972; Testing R2 0.07662024013114554


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9523.651683330536; Testing MSE 2317.954343557358
Epoch 3: Training R2 0.10997670460858078; Testing R2 0.09079521235633903


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 9587.58698105812; Testing MSE 2328.8352251052856
Epoch 4: Training R2 0.10400169562385098; Testing R2 0.08652724667157197


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 9450.506654381752; Testing MSE 2286.034721136093
Epoch 5: Training R2 0.11681239976744973; Testing R2 0.10331550793759392


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 9238.332736492157; Testing MSE 2286.6768062114716
Epoch 6: Training R2 0.13664089999774864; Testing R2 0.10306365361344327


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 8830.137184262276; Testing MSE 2306.646591424942
Epoch 7: Training R2 0.1747884050346724; Testing R2 0.09523061566997892


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 8388.996928930283; Testing MSE 2450.3201246261597
Epoch 8: Training R2 0.21601472418571965; Testing R2 0.03887546587710067


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 7773.921477794647; Testing MSE 2528.3136904239655
Epoch 9: Training R2 0.2734959822300771; Testing R2 0.008282920503692948


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 6634.627676010132; Testing MSE 2665.255695581436
Epoch 10: Training R2 0.37996753931758553; Testing R2 -0.04543182459677042
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12698.27309846878, 9758.412963151932, 9714.318877458572, 9523.651683330536, 9587.58698105812, 9450.506654381752, 9238.332736492157, 8830.137184262276, 8388.996928930283, 7773.921477794647, 6634.627676010132], 'mse_test_list': [2314.804685115814, 2336.7906153202057, 2354.092448949814, 2317.954343557358, 2328.8352251052856, 2286.034721136093, 2286.6768062114716, 2306.646591424942, 2450.3201246261597, 2528.3136904239655, 2665.255695581436], 'r_square_train_list': [-0.1867043487805331, 0.08803732517257479, 0.09215809362998972, 0.10997670460858078, 0.10400169562385098, 0.11681239976744973, 0.13664089999774864, 0.1747884050346724, 0.21601472418571965, 0.2734959822300771, 0.37996753931758553], 'r_square_test_list': [0.09203064848235809, 0.08340678880272612, 0.07662024013114

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 13178.774550557137; Testing MSE 2224.8687148094177
Epoch 0: Training R2 -0.2306376258657632; Testing R2 0.09039273885642651


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9719.109898805618; Testing MSE 2222.6975440979004
Epoch 1: Training R2 0.0924268196932665; Testing R2 0.09128039242053687


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9641.036340594292; Testing MSE 2273.9757001399994
Epoch 2: Training R2 0.09971735023170847; Testing R2 0.07031601696616685


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9623.26552271843; Testing MSE 2209.157007932663
Epoch 3: Training R2 0.10137679413801004; Testing R2 0.09681625614754885


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 9596.776592731476; Testing MSE 2219.0219402313232
Epoch 4: Training R2 0.10385033777333164; Testing R2 0.09278311298282071


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 9397.071981430054; Testing MSE 2253.542798757553
Epoch 5: Training R2 0.12249880981324812; Testing R2 0.07866973030663948


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 9086.2719476223; Testing MSE 2210.0871920585632
Epoch 6: Training R2 0.15152140324611052; Testing R2 0.0964359630410444


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 8572.333234548569; Testing MSE 2243.8052773475647
Epoch 7: Training R2 0.1995131429386925; Testing R2 0.08265078326545428


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 7629.321679472923; Testing MSE 2311.533111333847
Epoch 8: Training R2 0.2875718237250037; Testing R2 0.05496118110359094


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 6651.62670314312; Testing MSE 2502.2147059440613
Epoch 9: Training R2 0.37886925201590105; Testing R2 -0.022996390895772434
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12698.27309846878, 9758.412963151932, 9714.318877458572, 9523.651683330536, 9587.58698105812, 9450.506654381752, 9238.332736492157, 8830.137184262276, 8388.996928930283, 7773.921477794647, 6634.627676010132], 'mse_test_list': [2314.804685115814, 2336.7906153202057, 2354.092448949814, 2317.954343557358, 2328.8352251052856, 2286.034721136093, 2286.6768062114716, 2306.646591424942, 2450.3201246261597, 2528.3136904239655, 2665.255695581436], 'r_square_train_list': [-0.1867043487805331, 0.08803732517257479, 0.09215809362998972, 0.10997670460858078, 0.10400169562385098, 0.11681239976744973, 0.13664089999774864, 0.1747884050346724, 0.21601472418571965, 0.2734959822300771, 0.37996753931758553], 'r_square_test_list': [0.09203064848235809, 0.08340678880272612, 0.076620240131145

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 13725.89420080185; Testing MSE 2283.8355660438538
Epoch 0: Training R2 -0.17924990161569077; Testing R2 0.2484952594872991


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9146.174371242523; Testing MSE 2283.470603823662
Epoch 1: Training R2 0.21421329134114886; Testing R2 0.24861535168774473


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9004.518634080887; Testing MSE 2763.4383261203766
Epoch 2: Training R2 0.2263835377139789; Testing R2 0.0906801553181219


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 8918.348050117493; Testing MSE 2161.319887638092
Epoch 3: Training R2 0.23378681878072805; Testing R2 0.288809507359596


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 8653.98119688034; Testing MSE 2238.4271383285522
Epoch 4: Training R2 0.25649969862007216; Testing R2 0.2634370745613105


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 8453.361085057259; Testing MSE 2048.022946715355
Epoch 5: Training R2 0.27373582500050697; Testing R2 0.32609029475731244


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 8166.652327775955; Testing MSE 2225.19508600235
Epoch 6: Training R2 0.29836819276249815; Testing R2 0.26779113148103806


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 8106.5434366464615; Testing MSE 2346.279853582382
Epoch 7: Training R2 0.30353240304371565; Testing R2 0.22794773023394255


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 7207.327029109001; Testing MSE 2263.2843792438507
Epoch 8: Training R2 0.3807879060081649; Testing R2 0.2552577053187727


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 6560.543364286423; Testing MSE 2410.438573360443
Epoch 9: Training R2 0.43635583928453425; Testing R2 0.2068360605606483


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 5219.430768489838; Testing MSE 2367.5449073314667
Epoch 10: Training R2 0.5515765217050863; Testing R2 0.22095038377989074


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 4461.051334440708; Testing MSE 2514.6288752555847
Epoch 11: Training R2 0.6167321217633728; Testing R2 0.17255184721630423


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 3718.065884709358; Testing MSE 2668.5931503772736
Epoch 12: Training R2 0.6805651592092246; Testing R2 0.12188931951778903


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 3317.39065349102; Testing MSE 2619.6911334991455
Epoch 13: Training R2 0.7149888710695758; Testing R2 0.13798071333394069


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 2550.241394340992; Testing MSE 2456.189316511154
Epoch 14: Training R2 0.7808979240713374; Testing R2 0.1917816052964736


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 2157.28649944067; Testing MSE 2499.432522058487
Epoch 15: Training R2 0.8146583490295553; Testing R2 0.1775522647752209


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 1866.8771795928478; Testing MSE 2434.265774488449
Epoch 16: Training R2 0.8396086478478878; Testing R2 0.19899562981025765


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 1682.0489522069693; Testing MSE 2380.079036951065
Epoch 17: Training R2 0.8554880263256752; Testing R2 0.21682598096938288


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 1583.3694495260715; Testing MSE 2398.5252618789673
Epoch 18: Training R2 0.8639660017584988; Testing R2 0.21075618081214342


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 1562.9783250391483; Testing MSE 2438.1775200366974
Epoch 19: Training R2 0.865717889919172; Testing R2 0.19770845512618862


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 1448.8191336393356; Testing MSE 2343.6685264110565
Epoch 20: Training R2 0.8755257911937492; Testing R2 0.22880699732718612


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 21: Training MSE 1392.327207699418; Testing MSE 2429.451882839203
Epoch 21: Training R2 0.8803792526245415; Testing R2 0.2005796590847424


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 22: Training MSE 1303.7259869277477; Testing MSE 2427.069240808487
Epoch 22: Training R2 0.887991360028948; Testing R2 0.20136367646657594


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 23: Training MSE 1321.748649328947; Testing MSE 2385.44242978096
Epoch 23: Training R2 0.886442956511294; Testing R2 0.2150611362507765


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 24: Training MSE 1260.2558508515358; Testing MSE 2318.5645163059235
Epoch 24: Training R2 0.8917260641539513; Testing R2 0.23706756690600495
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12698.27309846878, 9758.412963151932, 9714.318877458572, 9523.651683330536, 9587.58698105812, 9450.506654381752, 9238.332736492157, 8830.137184262276, 8388.996928930283, 7773.921477794647, 6634.627676010132], 'mse_test_list': [2314.804685115814, 2336.7906153202057, 2354.092448949814, 2317.954343557358, 2328.8352251052856, 2286.034721136093, 2286.6768062114716, 2306.646591424942, 2450.3201246261597, 2528.3136904239655, 2665.255695581436], 'r_square_train_list': [-0.1867043487805331, 0.08803732517257479, 0.09215809362998972, 0.10997670460858078, 0.10400169562385098, 0.11681239976744973, 0.13664089999774864, 0.1747884050346724, 0.21601472418571965, 0.2734959822300771, 0.37996753931758553], 'r_square_test_list': [0.09203064848235809, 0.08340678880272612, 0.07662024013114

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 11934.712228178978; Testing MSE 2172.9994982481003
Epoch 0: Training R2 -0.020893213020838353; Testing R2 0.26635525335279875


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 8233.031541109085; Testing MSE 2082.991388440132
Epoch 1: Training R2 0.2957479106149158; Testing R2 0.2967436528758992


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 8156.334739923477; Testing MSE 1959.7169697284698
Epoch 2: Training R2 0.3023085416794775; Testing R2 0.33836337241873027


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 7841.441988945007; Testing MSE 1881.4141035079956
Epoch 3: Training R2 0.32924441295623574; Testing R2 0.3647998656146024


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 7925.807556509972; Testing MSE 1808.6864858865738
Epoch 4: Training R2 0.3220277969463108; Testing R2 0.3893540519580134


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 7596.035286784172; Testing MSE 1929.945307970047
Epoch 5: Training R2 0.3502364596747388; Testing R2 0.34841483504706194


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 7524.946087598801; Testing MSE 1839.4424229860306
Epoch 6: Training R2 0.35631741743726253; Testing R2 0.378970279803708


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 7439.887264370918; Testing MSE 1854.7436714172363
Epoch 7: Training R2 0.36359333441629593; Testing R2 0.37380429585490915


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 7242.333588004112; Testing MSE 1947.5859016180038
Epoch 8: Training R2 0.3804920416121007; Testing R2 0.34245904496611834


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 6941.767480969429; Testing MSE 1881.9247871637344
Epoch 9: Training R2 0.40620241425194625; Testing R2 0.36462744938462477


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 6271.221801638603; Testing MSE 1987.9812180995941
Epoch 10: Training R2 0.4635607753050923; Testing R2 0.32882084037851367


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 5706.1218827962875; Testing MSE 1963.9607429504395
Epoch 11: Training R2 0.5118993243035856; Testing R2 0.3369305962341199


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 4830.140113830566; Testing MSE 2278.6487579345703
Epoch 12: Training R2 0.5868306528156888; Testing R2 0.23068611287732532


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 4323.393513262272; Testing MSE 2109.8532795906067
Epoch 13: Training R2 0.6301776690948144; Testing R2 0.287674582522481


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 3458.4086060523987; Testing MSE 2081.510156393051
Epoch 14: Training R2 0.7041683279605588; Testing R2 0.2972437441602196


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 2858.09658318758; Testing MSE 2200.3845632076263
Epoch 15: Training R2 0.755518914226937; Testing R2 0.257109549863062


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 2280.0650514662266; Testing MSE 2174.246710538864
Epoch 16: Training R2 0.8049636311471379; Testing R2 0.2659341714584814


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 1966.5565438568592; Testing MSE 2057.4227571487427
Epoch 17: Training R2 0.8317810944863047; Testing R2 0.3053760948258446


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 1593.4222653508186; Testing MSE 2200.2430975437164
Epoch 18: Training R2 0.8636989359213777; Testing R2 0.2571573112828147


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 1439.3052078783512; Testing MSE 2035.923010110855
Epoch 19: Training R2 0.8768820822743243; Testing R2 0.31263480633558327


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 1380.7879984378815; Testing MSE 1969.185882806778
Epoch 20: Training R2 0.8818876342156307; Testing R2 0.335166492556604


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 21: Training MSE 1257.6554600149393; Testing MSE 2084.0125739574432
Epoch 21: Training R2 0.8924203701856868; Testing R2 0.29639888179301055


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 22: Training MSE 1216.8292865157127; Testing MSE 2154.5680224895477
Epoch 22: Training R2 0.895912633982426; Testing R2 0.2725780598348614


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 23: Training MSE 1094.7719804942608; Testing MSE 2050.61793923378
Epoch 23: Training R2 0.906353394759439; Testing R2 0.30767352697860306


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 24: Training MSE 1010.9495084732771; Testing MSE 2109.115982055664
Epoch 24: Training R2 0.9135235544707728; Testing R2 0.28792350778162745
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12698.27309846878, 9758.412963151932, 9714.318877458572, 9523.651683330536, 9587.58698105812, 9450.506654381752, 9238.332736492157, 8830.137184262276, 8388.996928930283, 7773.921477794647, 6634.627676010132], 'mse_test_list': [2314.804685115814, 2336.7906153202057, 2354.092448949814, 2317.954343557358, 2328.8352251052856, 2286.034721136093, 2286.6768062114716, 2306.646591424942, 2450.3201246261597, 2528.3136904239655, 2665.255695581436], 'r_square_train_list': [-0.1867043487805331, 0.08803732517257479, 0.09215809362998972, 0.10997670460858078, 0.10400169562385098, 0.11681239976744973, 0.13664089999774864, 0.1747884050346724, 0.21601472418571965, 0.2734959822300771, 0.37996753931758553], 'r_square_test_list': [0.09203064848235809, 0.08340678880272612, 0.076620240131145

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12725.117462873459; Testing MSE 2492.29052066803
Epoch 0: Training R2 -0.10969263609797753; Testing R2 0.14623971501774435


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9510.696265101433; Testing MSE 2648.7089663743973
Epoch 1: Training R2 0.1706206531814276; Testing R2 0.09265693416804033


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9290.560698509216; Testing MSE 2532.785078883171
Epoch 2: Training R2 0.18981755394901212; Testing R2 0.13236787893953195


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9162.933035194874; Testing MSE 2537.1327936649323
Epoch 3: Training R2 0.20094731196936255; Testing R2 0.13087852359338914


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 8825.788486003876; Testing MSE 2461.0778480768204
Epoch 4: Training R2 0.2303479697337697; Testing R2 0.1569319437228418


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 8655.535352230072; Testing MSE 2430.72871863842
Epoch 5: Training R2 0.24519487777784932; Testing R2 0.1673283566544076


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 8374.01834577322; Testing MSE 2639.1855776309967
Epoch 6: Training R2 0.26974454106488455; Testing R2 0.09591927097032582


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 8116.722932457924; Testing MSE 2544.516059756279
Epoch 7: Training R2 0.29218196266753715; Testing R2 0.1283493082752013


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 7470.026263594627; Testing MSE 2483.2698851823807
Epoch 8: Training R2 0.348577083052119; Testing R2 0.14932982841306586


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 7301.975579559803; Testing MSE 2501.9387274980545
Epoch 9: Training R2 0.3632319266799877; Testing R2 0.1429346204693922


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 6697.233983874321; Testing MSE 2541.482964158058
Epoch 10: Training R2 0.4159683589708676; Testing R2 0.12938832701753533


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 6579.5823857188225; Testing MSE 2750.6674259901047
Epoch 11: Training R2 0.4262281551950853; Testing R2 0.05772999334152473


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 6145.277214050293; Testing MSE 2784.6242010593414
Epoch 12: Training R2 0.46410169563398096; Testing R2 0.04609774352162044


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 5455.945283174515; Testing MSE 2799.0132957696915
Epoch 13: Training R2 0.5242148199137249; Testing R2 0.041168608054198086


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 5037.167744338512; Testing MSE 2816.4340794086456
Epoch 14: Training R2 0.5607342746350967; Testing R2 0.03520093571388805


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 4714.30324614048; Testing MSE 3012.7114057540894
Epoch 15: Training R2 0.588889641935488; Testing R2 -0.03203592318622328
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12698.27309846878, 9758.412963151932, 9714.318877458572, 9523.651683330536, 9587.58698105812, 9450.506654381752, 9238.332736492157, 8830.137184262276, 8388.996928930283, 7773.921477794647, 6634.627676010132], 'mse_test_list': [2314.804685115814, 2336.7906153202057, 2354.092448949814, 2317.954343557358, 2328.8352251052856, 2286.034721136093, 2286.6768062114716, 2306.646591424942, 2450.3201246261597, 2528.3136904239655, 2665.255695581436], 'r_square_train_list': [-0.1867043487805331, 0.08803732517257479, 0.09215809362998972, 0.10997670460858078, 0.10400169562385098, 0.11681239976744973, 0.13664089999774864, 0.1747884050346724, 0.21601472418571965, 0.2734959822300771, 0.37996753931758553], 'r_square_test_list': [0.09203064848235809, 0.08340678880272612, 0.0766202401311455

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 11315.364849567413; Testing MSE 2314.13556933403
Epoch 0: Training R2 -0.05746604840616043; Testing R2 0.0922931054517876


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9530.11499941349; Testing MSE 2314.4658267498016
Epoch 1: Training R2 0.10937268190063398; Testing R2 0.09216356380468449


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9370.144230127335; Testing MSE 2233.00239443779
Epoch 2: Training R2 0.12432258935006024; Testing R2 0.12411714515188921


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9313.716471195221; Testing MSE 2198.96337389946
Epoch 3: Training R2 0.12959598884284274; Testing R2 0.1374687629377055


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 8972.319748997688; Testing MSE 2207.430684566498
Epoch 4: Training R2 0.16150087636176724; Testing R2 0.13414750709919643


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 8564.183390140533; Testing MSE 2347.8362023830414
Epoch 5: Training R2 0.199642849541541; Testing R2 0.07907421828951555


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 7944.165068864822; Testing MSE 2313.728141784668
Epoch 6: Training R2 0.25758603854652584; Testing R2 0.0924529166575293


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 6337.693986296654; Testing MSE 2672.9079842567444
Epoch 7: Training R2 0.4077171788276166; Testing R2 -0.048433392560936595
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [11315.364849567413, 9530.11499941349, 9370.144230127335, 9313.716471195221, 8972.319748997688, 8564.183390140533, 7944.165068864822, 6337.693986296654], 'mse_test_list': [2314.13556933403, 2314.4658267498016, 2233.00239443779, 2198.96337389946, 2207.430684566498, 2347.8362023830414, 2313.728141784668, 2672.9079842567444], 'r_square_train_list': [-0.05746604840616043, 0.10937268190063398, 0.12432258935006024, 0.12959598884284274, 0.16150087636176724, 0.199642849541541, 0.25758603854652584, 0.4077171788276166], 'r_square_test_list': [0.0922931054517876, 0.09216356380468449, 0.12411714515188921, 0.1374687629377055, 0.13414750709919643, 0.07907421828951555, 0.0924529166575293, -0.048433392560936595]}}
Stored 'performance_continuous' (dict)
HHVEHCNT_P_CAP_mean_norm
tensor(

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12089.893049001694; Testing MSE 2285.968017578125
Epoch 0: Training R2 -0.12895756898470134; Testing R2 0.06541312137190158


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9924.098697304726; Testing MSE 2210.679680109024
Epoch 1: Training R2 0.07328490878597715; Testing R2 0.09619373237401552


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9766.016918420792; Testing MSE 2412.4809086322784
Epoch 2: Training R2 0.08804662918054429; Testing R2 0.013690049549669903


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9708.214098215103; Testing MSE 2235.312360525131
Epoch 3: Training R2 0.09344427257700605; Testing R2 0.08612299659586453


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 9515.746426582336; Testing MSE 2200.396704673767
Epoch 4: Training R2 0.11141695718174016; Testing R2 0.1003977867794783


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 9493.635547161102; Testing MSE 2193.8647150993347
Epoch 5: Training R2 0.11348167723991898; Testing R2 0.10306830172135739


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 9323.465648293495; Testing MSE 2190.6908571720123
Epoch 6: Training R2 0.1293721896341382; Testing R2 0.10436589029245547


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 9020.015114545822; Testing MSE 2227.9255390167236
Epoch 7: Training R2 0.1577084847112209; Testing R2 0.08914299792722158


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 8374.292036890984; Testing MSE 2389.358991384506
Epoch 8: Training R2 0.21800628495079555; Testing R2 0.02314313038996385


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 7027.830973267555; Testing MSE 2574.9546110630035
Epoch 9: Training R2 0.34373919284003285; Testing R2 -0.052735110052849477
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [11315.364849567413, 9530.11499941349, 9370.144230127335, 9313.716471195221, 8972.319748997688, 8564.183390140533, 7944.165068864822, 6337.693986296654], 'mse_test_list': [2314.13556933403, 2314.4658267498016, 2233.00239443779, 2198.96337389946, 2207.430684566498, 2347.8362023830414, 2313.728141784668, 2672.9079842567444], 'r_square_train_list': [-0.05746604840616043, 0.10937268190063398, 0.12432258935006024, 0.12959598884284274, 0.16150087636176724, 0.199642849541541, 0.25758603854652584, 0.4077171788276166], 'r_square_test_list': [0.0922931054517876, 0.09216356380468449, 0.12411714515188921, 0.1374687629377055, 0.13414750709919643, 0.07907421828951555, 0.0924529166575293, -0.048433392560936595]}, 'HHVEHCNT_P_CAP_mean_norm': {'mse_train_list': [12089.893049001694, 99

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12081.082955002785; Testing MSE 2342.55313873291
Epoch 0: Training R2 -0.037937177547657575; Testing R2 0.22917401986597963


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9169.74595785141; Testing MSE 2185.1834177970886
Epoch 1: Training R2 0.21218815616359288; Testing R2 0.28095712240400184


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 8871.46013379097; Testing MSE 2161.8258327245712
Epoch 2: Training R2 0.23781515893154948; Testing R2 0.28864302421318155


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 8766.327074170113; Testing MSE 2310.176309943199
Epoch 3: Training R2 0.24684758686671526; Testing R2 0.23982773797073076


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 8431.430971622467; Testing MSE 2464.9133145809174
Epoch 4: Training R2 0.27561993424192077; Testing R2 0.1889109406991062


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 8147.319030761719; Testing MSE 2214.0186071395874
Epoch 5: Training R2 0.3000291984695518; Testing R2 0.2714687941694055


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 7624.738058447838; Testing MSE 2250.3514796495438
Epoch 6: Training R2 0.3449263506092918; Testing R2 0.2595133248993593


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 6796.907985210419; Testing MSE 2200.7143437862396
Epoch 7: Training R2 0.4160487502240803; Testing R2 0.2758466124010348


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 5696.479028463364; Testing MSE 2507.744652032852
Epoch 8: Training R2 0.5105912783825162; Testing R2 0.17481712693413998


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 4272.649231553078; Testing MSE 2561.1862659454346
Epoch 9: Training R2 0.6329185470733864; Testing R2 0.15723196152508878


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 2940.699829161167; Testing MSE 2507.2761833667755
Epoch 10: Training R2 0.7473519806077917; Testing R2 0.17497127832252035


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 2097.396568953991; Testing MSE 2572.248423099518
Epoch 11: Training R2 0.8198037474714331; Testing R2 0.15359191682782947


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 1511.0222727060318; Testing MSE 2541.576224565506
Epoch 12: Training R2 0.8701816551723376; Testing R2 0.16368471989253752


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 1068.6983332037926; Testing MSE 2442.3086524009705
Epoch 13: Training R2 0.9081835845555474; Testing R2 0.19634909037962134


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 881.8659879267216; Testing MSE 2418.93327832222
Epoch 14: Training R2 0.9242351453182512; Testing R2 0.20404084572865833


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 787.4497881159186; Testing MSE 2509.11505818367
Epoch 15: Training R2 0.9323468422837803; Testing R2 0.1743661896013129


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 727.6203881949186; Testing MSE 2396.529197692871
Epoch 16: Training R2 0.9374870402875244; Testing R2 0.21141299329046792


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 653.984259814024; Testing MSE 2328.6290287971497
Epoch 17: Training R2 0.9438134330076036; Testing R2 0.23375580096253656


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 538.3841747418046; Testing MSE 2416.928058862686
Epoch 18: Training R2 0.9537451275809318; Testing R2 0.20470067078437293


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 557.7746534720063; Testing MSE 2424.6559977531433
Epoch 19: Training R2 0.9520792091496549; Testing R2 0.20215776323143064


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 606.5411075949669; Testing MSE 2375.6326615810394
Epoch 20: Training R2 0.9478894758335342; Testing R2 0.218289077620628


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 21: Training MSE 455.4856387898326; Testing MSE 2363.4223103523254
Epoch 21: Training R2 0.9608672930978229; Testing R2 0.22230693992566208


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 22: Training MSE 485.4843504726887; Testing MSE 2318.0438578128815
Epoch 22: Training R2 0.9582899762918593; Testing R2 0.23723889155455957


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 23: Training MSE 459.2966862022877; Testing MSE 2547.314929962158
Epoch 23: Training R2 0.9605398698188407; Testing R2 0.16179637715275763


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 24: Training MSE 529.2788427323103; Testing MSE 2412.6693785190582
Epoch 24: Training R2 0.9545274053487254; Testing R2 0.20610200567649417
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [11315.364849567413, 9530.11499941349, 9370.144230127335, 9313.716471195221, 8972.319748997688, 8564.183390140533, 7944.165068864822, 6337.693986296654], 'mse_test_list': [2314.13556933403, 2314.4658267498016, 2233.00239443779, 2198.96337389946, 2207.430684566498, 2347.8362023830414, 2313.728141784668, 2672.9079842567444], 'r_square_train_list': [-0.05746604840616043, 0.10937268190063398, 0.12432258935006024, 0.12959598884284274, 0.16150087636176724, 0.199642849541541, 0.25758603854652584, 0.4077171788276166], 'r_square_test_list': [0.0922931054517876, 0.09216356380468449, 0.12411714515188921, 0.1374687629377055, 0.13414750709919643, 0.07907421828951555, 0.0924529166575293, -0.048433392560936595]}, 'HHVEHCNT_P_CAP_mean_norm': {'mse_train_list': [12089.893049001694, 992

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 11786.737409234047; Testing MSE 2251.1132657527924
Epoch 0: Training R2 -0.008235472685701906; Testing R2 0.23998260337435162


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 8238.061627745628; Testing MSE 1904.3690502643585
Epoch 1: Training R2 0.2953176378768745; Testing R2 0.35704985181526694


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 7950.699400901794; Testing MSE 1880.74049949646
Epoch 2: Training R2 0.31989855289640545; Testing R2 0.3650272867643919


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 7802.119779586792; Testing MSE 1952.2163212299347
Epoch 3: Training R2 0.33260802792132915; Testing R2 0.34089572982232563


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 7607.433521747589; Testing MSE 2024.9466925859451
Epoch 4: Training R2 0.3492614566338532; Testing R2 0.3163406137672804


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 7035.550817847252; Testing MSE 2112.403178215027
Epoch 5: Training R2 0.39818020388922304; Testing R2 0.28681368967285215


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 6508.9185655117035; Testing MSE 2091.420656442642
Epoch 6: Training R2 0.4432282353698508; Testing R2 0.2938977763843913


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 5563.46572637558; Testing MSE 2025.5631268024445
Epoch 7: Training R2 0.5241021071693239; Testing R2 0.31613249419571443


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 4632.918290793896; Testing MSE 2032.4881464242935
Epoch 8: Training R2 0.6037009733352188; Testing R2 0.3137944797276512


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 3441.931375861168; Testing MSE 2122.828283905983
Epoch 9: Training R2 0.705577787371896; Testing R2 0.28329398153229435


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 2415.608775615692; Testing MSE 2188.968950510025
Epoch 10: Training R2 0.7933692445036928; Testing R2 0.2609636714549475


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 1676.484253257513; Testing MSE 2110.3115022182465
Epoch 11: Training R2 0.8565938278892168; Testing R2 0.28751987810407964


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 1270.9576830267906; Testing MSE 2073.750653862953
Epoch 12: Training R2 0.8912825003375286; Testing R2 0.29986349546361


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 980.6191146373749; Testing MSE 2185.4078888893127
Epoch 13: Training R2 0.9161180110963973; Testing R2 0.26216595159934175


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 836.2754434347153; Testing MSE 2183.969485759735
Epoch 14: Training R2 0.9284651436837569; Testing R2 0.26265158305044256


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 703.1042858958244; Testing MSE 2125.586900115013
Epoch 15: Training R2 0.9398565813910342; Testing R2 0.2823626217729376


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 628.2416695728898; Testing MSE 2167.712962627411
Epoch 16: Training R2 0.9462603166007209; Testing R2 0.26814008537379497


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 552.0787814632058; Testing MSE 2087.5828832387924
Epoch 17: Training R2 0.9527752768334157; Testing R2 0.29519347946766294


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 541.7395975440741; Testing MSE 2076.9554555416107
Epoch 18: Training R2 0.9536596888317455; Testing R2 0.29878149525261644


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 475.5193680524826; Testing MSE 2064.626330137253
Epoch 19: Training R2 0.959324155771554; Testing R2 0.302944035598785


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 454.1709257289767; Testing MSE 1991.4639711380005
Epoch 20: Training R2 0.9611502978234063; Testing R2 0.3276449986571712


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 21: Training MSE 459.2134917154908; Testing MSE 2007.756644487381
Epoch 21: Training R2 0.960718957603934; Testing R2 0.32214429135316647


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 22: Training MSE 430.5494261905551; Testing MSE 2055.334413051605
Epoch 22: Training R2 0.963170876794989; Testing R2 0.30608115834624194


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 23: Training MSE 381.53140414506197; Testing MSE 2013.556146621704
Epoch 23: Training R2 0.967363869894881; Testing R2 0.3201862723672231


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 24: Training MSE 421.34439181536436; Testing MSE 2004.977634549141
Epoch 24: Training R2 0.9639582738381341; Testing R2 0.323082536411996
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [11315.364849567413, 9530.11499941349, 9370.144230127335, 9313.716471195221, 8972.319748997688, 8564.183390140533, 7944.165068864822, 6337.693986296654], 'mse_test_list': [2314.13556933403, 2314.4658267498016, 2233.00239443779, 2198.96337389946, 2207.430684566498, 2347.8362023830414, 2313.728141784668, 2672.9079842567444], 'r_square_train_list': [-0.05746604840616043, 0.10937268190063398, 0.12432258935006024, 0.12959598884284274, 0.16150087636176724, 0.199642849541541, 0.25758603854652584, 0.4077171788276166], 'r_square_test_list': [0.0922931054517876, 0.09216356380468449, 0.12411714515188921, 0.1374687629377055, 0.13414750709919643, 0.07907421828951555, 0.0924529166575293, -0.048433392560936595]}, 'HHVEHCNT_P_CAP_mean_norm': {'mse_train_list': [12089.893049001694, 9924.

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12649.357032775879; Testing MSE 2627.3732900619507
Epoch 0: Training R2 -0.10308595512766949; Testing R2 0.09996569409700773


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9815.56729376316; Testing MSE 2437.652140855789
Epoch 1: Training R2 0.14403440464953066; Testing R2 0.16495666568325662


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 8959.015733003616; Testing MSE 2743.63195002079
Epoch 2: Training R2 0.21872990056037644; Testing R2 0.060140069501882554


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 8987.874254584312; Testing MSE 2460.9507143497467
Epoch 3: Training R2 0.2162132959804811; Testing R2 0.15697549471585692


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 8916.17025732994; Testing MSE 2416.213721036911
Epoch 4: Training R2 0.2224662361175066; Testing R2 0.17230062147907965


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 8353.138320147991; Testing MSE 2389.2938047647476
Epoch 5: Training R2 0.2715653816775905; Testing R2 0.18152232143645863


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 8251.626248657703; Testing MSE 2634.13989841938
Epoch 6: Training R2 0.2804177320418587; Testing R2 0.09764772136000355


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 7879.798702895641; Testing MSE 2502.523100376129
Epoch 7: Training R2 0.3128429171636762; Testing R2 0.1427344374045435


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 6805.358597636223; Testing MSE 2533.20609331131
Epoch 8: Training R2 0.4065393624981367; Testing R2 0.13222365602684638


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 6148.814527690411; Testing MSE 2779.785943031311
Epoch 9: Training R2 0.4637932245405344; Testing R2 0.04775513960709832


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 5251.576536893845; Testing MSE 2691.6738122701645
Epoch 10: Training R2 0.5420367766428398; Testing R2 0.077938875108715


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 4182.086662948132; Testing MSE 2863.98703455925
Epoch 11: Training R2 0.6353015375349502; Testing R2 0.018911171657711168


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 3158.3401016891003; Testing MSE 2804.2855501174927
Epoch 12: Training R2 0.7245772572738267; Testing R2 0.03936254197275624


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 2560.92569231987; Testing MSE 2639.2540335655212
Epoch 13: Training R2 0.7766747229915347; Testing R2 0.09589582067121993


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 2079.849884659052; Testing MSE 2695.367230474949
Epoch 14: Training R2 0.8186268922130476; Testing R2 0.07667365592462472


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 1577.329396829009; Testing MSE 2612.1334552764893
Epoch 15: Training R2 0.8624491426920966; Testing R2 0.10518625950927374


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 1389.2657026648521; Testing MSE 2737.867718935013
Epoch 16: Training R2 0.8788492189303102; Testing R2 0.06211466738029425


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 1108.765833079815; Testing MSE 2538.216444849968
Epoch 17: Training R2 0.9033101828949345; Testing R2 0.13050730750245387


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 994.5759326219559; Testing MSE 2628.644350171089
Epoch 18: Training R2 0.9132681021066471; Testing R2 0.09953027911147194


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 980.3209878504276; Testing MSE 2731.090968847275
Epoch 19: Training R2 0.9145112031850536; Testing R2 0.06443611427349594


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 840.8327758312225; Testing MSE 2526.717147231102
Epoch 20: Training R2 0.9266752592066814; Testing R2 0.13444651263554863


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 21: Training MSE 679.7834578901529; Testing MSE 2602.0015090703964
Epoch 21: Training R2 0.9407195493823298; Testing R2 0.10865706405978826


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 22: Training MSE 657.7570198103786; Testing MSE 2534.174147248268
Epoch 22: Training R2 0.9426403627821202; Testing R2 0.13189203898692192


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 23: Training MSE 503.67720536887646; Testing MSE 2547.534403204918
Epoch 23: Training R2 0.9560768780799891; Testing R2 0.12731534303659275


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 24: Training MSE 456.79870741441846; Testing MSE 2495.8090037107468
Epoch 24: Training R2 0.9601649129545724; Testing R2 0.1450344217101046
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [11315.364849567413, 9530.11499941349, 9370.144230127335, 9313.716471195221, 8972.319748997688, 8564.183390140533, 7944.165068864822, 6337.693986296654], 'mse_test_list': [2314.13556933403, 2314.4658267498016, 2233.00239443779, 2198.96337389946, 2207.430684566498, 2347.8362023830414, 2313.728141784668, 2672.9079842567444], 'r_square_train_list': [-0.05746604840616043, 0.10937268190063398, 0.12432258935006024, 0.12959598884284274, 0.16150087636176724, 0.199642849541541, 0.25758603854652584, 0.4077171788276166], 'r_square_test_list': [0.0922931054517876, 0.09216356380468449, 0.12411714515188921, 0.1374687629377055, 0.13414750709919643, 0.07907421828951555, 0.0924529166575293, -0.048433392560936595]}, 'HHVEHCNT_P_CAP_mean_norm': {'mse_train_list': [12089.893049001694, 992

tensor(10700.4521)
tensor(2549.4304)


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12289.831066131592; Testing MSE 2252.972501516342
Epoch 0: Training R2 -0.1485338091929298; Testing R2 0.11628398095862424


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9480.62248826027; Testing MSE 2213.598781824112
Epoch 1: Training R2 0.11399795478318664; Testing R2 0.13172810502043597


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9419.594323635101; Testing MSE 2448.0222284793854
Epoch 2: Training R2 0.11970128056592744; Testing R2 0.03977680294785102


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9423.573732376099; Testing MSE 2383.0234944820404
Epoch 3: Training R2 0.11932938892193012; Testing R2 0.06527219732670086


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 9390.378141403198; Testing MSE 2293.3943271636963
Epoch 4: Training R2 0.12243164950983887; Testing R2 0.10042874312530747


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 9189.957186579704; Testing MSE 2247.4409699440002
Epoch 5: Training R2 0.14116178839025606; Testing R2 0.11845369366351599


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 9208.962392807007; Testing MSE 2479.7330141067505
Epoch 6: Training R2 0.13938567594531814; Testing R2 0.02733842244545759


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 9008.318811655045; Testing MSE 2231.999635696411
Epoch 7: Training R2 0.15813662014549024; Testing R2 0.12451047172928575


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 8940.918028354645; Testing MSE 2276.9818365573883
Epoch 8: Training R2 0.16443549260110335; Testing R2 0.10686645190843669


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 8769.993835687637; Testing MSE 2412.8169894218445
Epoch 9: Training R2 0.18040904122278156; Testing R2 0.05358586350602068


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 8458.903816342354; Testing MSE 2458.054429292679
Epoch 10: Training R2 0.2094816462893544; Testing R2 0.03584172759341131


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 7779.080682992935; Testing MSE 2273.6178278923035
Epoch 11: Training R2 0.2730138338940331; Testing R2 0.10818596572563988


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 7207.260704040527; Testing MSE 2583.915901184082
Epoch 12: Training R2 0.32645269526363474; Testing R2 -0.013526739538654908
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12289.831066131592, 9480.62248826027, 9419.594323635101, 9423.573732376099, 9390.378141403198, 9189.957186579704, 9208.962392807007, 9008.318811655045, 8940.918028354645, 8769.993835687637, 8458.903816342354, 7779.080682992935, 7207.260704040527], 'mse_test_list': [2252.972501516342, 2213.598781824112, 2448.0222284793854, 2383.0234944820404, 2293.3943271636963, 2247.4409699440002, 2479.7330141067505, 2231.999635696411, 2276.9818365573883, 2412.8169894218445, 2458.054429292679, 2273.6178278923035, 2583.915901184082], 'r_square_train_list': [-0.1485338091929298, 0.11399795478318664, 0.11970128056592744, 0.11932938892193012, 0.12243164950983887, 0.14116178839025606, 0.13938567594531814, 0.15813662014549024, 0.16443549260110335, 0.18040904122278156, 0.2094816462893544, 

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 13540.380555391312; Testing MSE 2281.4390897750854
Epoch 0: Training R2 -0.26440449497645147; Testing R2 0.06726471179942606


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9735.587310791016; Testing MSE 2164.970076084137
Epoch 1: Training R2 0.09088815438805686; Testing R2 0.11488148120534214


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9618.94556581974; Testing MSE 2197.1999406814575
Epoch 2: Training R2 0.10178019291239915; Testing R2 0.10170474018316089


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9510.758924484253; Testing MSE 2190.2142107486725
Epoch 3: Training R2 0.11188269151215458; Testing R2 0.10456076068852971


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 9490.629184246063; Testing MSE 2234.970724582672
Epoch 4: Training R2 0.11376241224347039; Testing R2 0.08626266979629194


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 9593.11273097992; Testing MSE 2149.606055021286
Epoch 5: Training R2 0.10419247020075417; Testing R2 0.12116285188852383


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 9410.212725400925; Testing MSE 2204.2458415031433
Epoch 6: Training R2 0.12127172349347048; Testing R2 0.0988241196296682


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 9397.264230251312; Testing MSE 2185.689878463745
Epoch 7: Training R2 0.12248085756495208; Testing R2 0.10641047230108602


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 9295.734244585037; Testing MSE 2200.8840322494507
Epoch 8: Training R2 0.13196175581047576; Testing R2 0.10019854953979457


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 9178.00213098526; Testing MSE 2227.311858534813
Epoch 9: Training R2 0.14295561326002626; Testing R2 0.08939389283111066


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 8911.69028878212; Testing MSE 2508.815550804138
Epoch 10: Training R2 0.16782388701124196; Testing R2 -0.025695056383062687
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12289.831066131592, 9480.62248826027, 9419.594323635101, 9423.573732376099, 9390.378141403198, 9189.957186579704, 9208.962392807007, 9008.318811655045, 8940.918028354645, 8769.993835687637, 8458.903816342354, 7779.080682992935, 7207.260704040527], 'mse_test_list': [2252.972501516342, 2213.598781824112, 2448.0222284793854, 2383.0234944820404, 2293.3943271636963, 2247.4409699440002, 2479.7330141067505, 2231.999635696411, 2276.9818365573883, 2412.8169894218445, 2458.054429292679, 2273.6178278923035, 2583.915901184082], 'r_square_train_list': [-0.1485338091929298, 0.11399795478318664, 0.11970128056592744, 0.11932938892193012, 0.12243164950983887, 0.14116178839025606, 0.13938567594531814, 0.15813662014549024, 0.16443549260110335, 0.18040904122278156, 0.2094816462893544, 0

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 11589.290422201157; Testing MSE 2117.0675933361053
Epoch 0: Training R2 0.004314808912195156; Testing R2 0.30337089235637715


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 8945.562520623207; Testing MSE 2170.309364795685
Epoch 1: Training R2 0.23144870796646055; Testing R2 0.28585148586316933


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 8658.142161369324; Testing MSE 2129.5178294181824
Epoch 2: Training R2 0.2561422124780054; Testing R2 0.29927409503204494


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 8583.106908202171; Testing MSE 2090.3336107730865
Epoch 3: Training R2 0.26258880995432166; Testing R2 0.3121678105441845


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 8551.913020014763; Testing MSE 2200.423604249954
Epoch 4: Training R2 0.26526880945292364; Testing R2 0.2759422813463046


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 8333.581930398941; Testing MSE 2115.4613733291626
Epoch 5: Training R2 0.28402656120173597; Testing R2 0.30389942513144674


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 8071.595141291618; Testing MSE 2079.396069049835
Epoch 6: Training R2 0.3065349596171465; Testing R2 0.31576685006208516


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 7543.664637207985; Testing MSE 2220.050159096718
Epoch 7: Training R2 0.3518917127650891; Testing R2 0.26948408916006716


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 6944.561877846718; Testing MSE 2277.8170585632324
Epoch 8: Training R2 0.4033631768241077; Testing R2 0.25047567216225053


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 6217.7868485450745; Testing MSE 2250.7968842983246
Epoch 9: Training R2 0.4658035081615469; Testing R2 0.2593667628131996


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 5625.437590479851; Testing MSE 2213.0647599697113
Epoch 10: Training R2 0.516694750223921; Testing R2 0.27178266119229644


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 4941.749545931816; Testing MSE 2299.4918435811996
Epoch 11: Training R2 0.575433295594757; Testing R2 0.243343502082768


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 4544.346979260445; Testing MSE 2484.0884804725647
Epoch 12: Training R2 0.6095758389360615; Testing R2 0.18260127975768714


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 4108.480057120323; Testing MSE 2270.719599723816
Epoch 13: Training R2 0.6470230185173558; Testing R2 0.2528111178671506


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 3782.21971988678; Testing MSE 2441.2298917770386
Epoch 14: Training R2 0.6750534305950826; Testing R2 0.1967040606475411


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 3348.8546445965767; Testing MSE 2447.4378407001495
Epoch 15: Training R2 0.7122856658814215; Testing R2 0.19466131154863808


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 2775.9963780641556; Testing MSE 2565.4694736003876
Epoch 16: Training R2 0.7615023540304988; Testing R2 0.15582255582830595


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 2441.845463216305; Testing MSE 2512.2011244297028
Epoch 17: Training R2 0.7902106791635966; Testing R2 0.17335070781798678


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 1967.9514281451702; Testing MSE 2462.160551548004
Epoch 18: Training R2 0.8309249296203174; Testing R2 0.1898167477981373


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 1700.1413457095623; Testing MSE 2557.1705639362335
Epoch 19: Training R2 0.8539336319127648; Testing R2 0.1585533434762546


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 1278.2281830906868; Testing MSE 2497.9071736335754
Epoch 20: Training R2 0.8901819847144066; Testing R2 0.17805418645001136


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 21: Training MSE 1181.6141374409199; Testing MSE 2348.490169644356
Epoch 21: Training R2 0.8984825079561292; Testing R2 0.22722041736461807


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 22: Training MSE 970.6743523478508; Testing MSE 2398.120477795601
Epoch 22: Training R2 0.9166052413226231; Testing R2 0.210889376547448


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 23: Training MSE 789.4730728119612; Testing MSE 2343.8899874687195
Epoch 23: Training R2 0.9321730132972061; Testing R2 0.2287341247276229


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 24: Training MSE 779.5168777927756; Testing MSE 2401.253205537796
Epoch 24: Training R2 0.9330283923220681; Testing R2 0.20985854062213072
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12289.831066131592, 9480.62248826027, 9419.594323635101, 9423.573732376099, 9390.378141403198, 9189.957186579704, 9208.962392807007, 9008.318811655045, 8940.918028354645, 8769.993835687637, 8458.903816342354, 7779.080682992935, 7207.260704040527], 'mse_test_list': [2252.972501516342, 2213.598781824112, 2448.0222284793854, 2383.0234944820404, 2293.3943271636963, 2247.4409699440002, 2479.7330141067505, 2231.999635696411, 2276.9818365573883, 2412.8169894218445, 2458.054429292679, 2273.6178278923035, 2583.915901184082], 'r_square_train_list': [-0.1485338091929298, 0.11399795478318664, 0.11970128056592744, 0.11932938892193012, 0.12243164950983887, 0.14116178839025606, 0.13938567594531814, 0.15813662014549024, 0.16443549260110335, 0.18040904122278156, 0.2094816462893544, 0.2

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12354.858168959618; Testing MSE 1917.7852928638458
Epoch 0: Training R2 -0.05683242388915577; Testing R2 0.3525202806346046


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 8083.891075849533; Testing MSE 1827.8386652469635
Epoch 1: Training R2 0.3085053601335356; Testing R2 0.3828879226349635


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 7986.111298203468; Testing MSE 1910.0234299898148
Epoch 2: Training R2 0.31686942534608953; Testing R2 0.3551408288336815


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 7964.319717884064; Testing MSE 1952.5537848472595
Epoch 3: Training R2 0.3187334733452153; Testing R2 0.3407817958751551


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 7891.997271776199; Testing MSE 1793.386685848236
Epoch 4: Training R2 0.3249199228354892; Testing R2 0.3945195469026411


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 7652.407121658325; Testing MSE 1848.8260984420776
Epoch 5: Training R2 0.34541442270155776; Testing R2 0.37580217773644053


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 7680.348724126816; Testing MSE 1705.4666101932526
Epoch 6: Training R2 0.3430243028749853; Testing R2 0.4242029875482122


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 7397.751063108444; Testing MSE 1757.7032774686813
Epoch 7: Training R2 0.3671976577605801; Testing R2 0.4065669243278237


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 7254.348391294479; Testing MSE 1811.5449845790863
Epoch 8: Training R2 0.37946429742351817; Testing R2 0.388388970028289


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 6922.560986876488; Testing MSE 1908.9200019836426
Epoch 9: Training R2 0.40784533442383886; Testing R2 0.35551336649909937


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 6935.7991099357605; Testing MSE 1855.3449243307114
Epoch 10: Training R2 0.4067129476745014; Testing R2 0.3736013017715083


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 6725.269651412964; Testing MSE 2138.31946849823
Epoch 11: Training R2 0.4247216010242998; Testing R2 0.27806387162909063


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 6508.821114897728; Testing MSE 1937.5603079795837
Epoch 12: Training R2 0.44323657127846006; Testing R2 0.3458438704622937


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 5722.982302308083; Testing MSE 1991.997516155243
Epoch 13: Training R2 0.5104570869442604; Testing R2 0.32746486400950314


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 5292.420035600662; Testing MSE 2047.58979678154
Epoch 14: Training R2 0.5472873085248557; Testing R2 0.30869588377342805


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 4638.540598750114; Testing MSE 1941.484248638153
Epoch 15: Training R2 0.603220042088258; Testing R2 0.3445190756555597


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 4100.440895557404; Testing MSE 2043.9857006072998
Epoch 16: Training R2 0.6492489973253115; Testing R2 0.3099126931775601


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 3060.4637518525124; Testing MSE 2150.8102416992188
Epoch 17: Training R2 0.7382084617352145; Testing R2 0.2738467560025777


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 2494.528353214264; Testing MSE 2206.7391872406006
Epoch 18: Training R2 0.7866184775304748; Testing R2 0.2549641114759532


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 1874.480029195547; Testing MSE 2340.996688604355
Epoch 19: Training R2 0.8396573035728048; Testing R2 0.20963630046959647


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 1582.7989727258682; Testing MSE 2361.891895532608
Epoch 20: Training R2 0.8646076505291032; Testing R2 0.2025816928613674


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 21: Training MSE 1250.1707941293716; Testing MSE 2147.7686166763306
Epoch 21: Training R2 0.8930606072067574; Testing R2 0.274873664762157


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 22: Training MSE 980.292809009552; Testing MSE 2228.0475854873657
Epoch 22: Training R2 0.916145923223179; Testing R2 0.24777000285061335


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 23: Training MSE 866.4809342473745; Testing MSE 2023.9900648593903
Epoch 23: Training R2 0.9258813712410666; Testing R2 0.31666358894819957


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 24: Training MSE 816.2217736244202; Testing MSE 2107.584375143051
Epoch 24: Training R2 0.9301805311195053; Testing R2 0.2884406065505286
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12289.831066131592, 9480.62248826027, 9419.594323635101, 9423.573732376099, 9390.378141403198, 9189.957186579704, 9208.962392807007, 9008.318811655045, 8940.918028354645, 8769.993835687637, 8458.903816342354, 7779.080682992935, 7207.260704040527], 'mse_test_list': [2252.972501516342, 2213.598781824112, 2448.0222284793854, 2383.0234944820404, 2293.3943271636963, 2247.4409699440002, 2479.7330141067505, 2231.999635696411, 2276.9818365573883, 2412.8169894218445, 2458.054429292679, 2273.6178278923035, 2583.915901184082], 'r_square_train_list': [-0.1485338091929298, 0.11399795478318664, 0.11970128056592744, 0.11932938892193012, 0.12243164950983887, 0.14116178839025606, 0.13938567594531814, 0.15813662014549024, 0.16443549260110335, 0.18040904122278156, 0.2094816462893544, 0.27

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12221.366964280605; Testing MSE 2434.195813536644
Epoch 0: Training R2 -0.06576312264946726; Testing R2 0.1661406668951998


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9284.279108047485; Testing MSE 2318.77281665802
Epoch 1: Training R2 0.190365339436942; Testing R2 0.20568002632837956


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 8985.270258784294; Testing MSE 2349.91497695446
Epoch 2: Training R2 0.21644037718206066; Testing R2 0.1950119523501801


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 8789.655330777168; Testing MSE 2419.274690747261
Epoch 3: Training R2 0.2334989580363176; Testing R2 0.17125205416697908


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 9078.52722555399; Testing MSE 2341.0926580429077
Epoch 4: Training R2 0.2083079124255588; Testing R2 0.1980341303208737


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 8953.550997376442; Testing MSE 2418.983170390129
Epoch 5: Training R2 0.21920645230156; Testing R2 0.1713519175260534


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 8693.798604607582; Testing MSE 2356.831109523773
Epoch 6: Training R2 0.2418581345595361; Testing R2 0.1926427585245064


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 8784.278130531311; Testing MSE 2414.697849750519
Epoch 7: Training R2 0.23396787626304993; Testing R2 0.1728198991037121


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 8640.774777531624; Testing MSE 2374.1650879383087
Epoch 8: Training R2 0.24648207226506957; Testing R2 0.18670482222520868


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 8533.23737680912; Testing MSE 2534.981232881546
Epoch 9: Training R2 0.2558598608815633; Testing R2 0.13161556332946622


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 8358.53601694107; Testing MSE 2625.91512799263
Epoch 10: Training R2 0.27109467605144233; Testing R2 0.10046520282343008


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 7860.521158576012; Testing MSE 2450.1741141080856
Epoch 11: Training R2 0.3145240135493055; Testing R2 0.16066713227462526


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 6942.787021398544; Testing MSE 2564.8618668317795
Epoch 12: Training R2 0.3945549301119885; Testing R2 0.12137963844621313


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 6337.767417728901; Testing MSE 2568.4024930000305
Epoch 13: Training R2 0.447315606062206; Testing R2 0.12016675977843583


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 5950.50647854805; Testing MSE 2669.2791879177094
Epoch 14: Training R2 0.48108665876259793; Testing R2 0.08561038880692451


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 5471.460755169392; Testing MSE 2806.922161579132
Epoch 15: Training R2 0.5228617946808707; Testing R2 0.038459343034185745


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 5030.691860616207; Testing MSE 2852.891978621483
Epoch 16: Training R2 0.5612990034480004; Testing R2 0.02271189955871966


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 4827.313807606697; Testing MSE 2748.250252008438
Epoch 17: Training R2 0.5790345668663388; Testing R2 0.05855802166740631


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 4774.267952144146; Testing MSE 2902.468317747116
Epoch 18: Training R2 0.583660425555192; Testing R2 0.005729004077926203


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 4238.681660592556; Testing MSE 2865.4274731874466
Epoch 19: Training R2 0.6303661762458121; Testing R2 0.01841773428212823


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 4036.992648243904; Testing MSE 2739.799642562866
Epoch 20: Training R2 0.6479544470368805; Testing R2 0.061452866657865646


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 21: Training MSE 3608.556942641735; Testing MSE 2919.481974840164
Epoch 21: Training R2 0.6853161412558361; Testing R2 -9.920278956920825e-05
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12289.831066131592, 9480.62248826027, 9419.594323635101, 9423.573732376099, 9390.378141403198, 9189.957186579704, 9208.962392807007, 9008.318811655045, 8940.918028354645, 8769.993835687637, 8458.903816342354, 7779.080682992935, 7207.260704040527], 'mse_test_list': [2252.972501516342, 2213.598781824112, 2448.0222284793854, 2383.0234944820404, 2293.3943271636963, 2247.4409699440002, 2479.7330141067505, 2231.999635696411, 2276.9818365573883, 2412.8169894218445, 2458.054429292679, 2273.6178278923035, 2583.915901184082], 'r_square_train_list': [-0.1485338091929298, 0.11399795478318664, 0.11970128056592744, 0.11932938892193012, 0.12243164950983887, 0.14116178839025606, 0.13938567594531814, 0.15813662014549024, 0.16443549260110335, 0.18040904122278156, 0.2094816462893544, 

## Store and save resnet's last layer

In [11]:
### 
def return_last_layer_resnet(model,device,x_train_images,x_test_images,y_train,y_test):
    ###  
    # 
    from sklearn.preprocessing import MinMaxScaler

    image_train_hidden_list = []
    image_test_hidden_list = []

    # return values in the last layer.
    model_no_last_layer = nn.Sequential(*list(model.children())[:-1]).to(device)

    # process data
    x_train_images_norm = x_train_images/255
    x_test_images_norm = x_test_images/255

    x_train_torch = torch.from_numpy(x_train_images_norm)
    x_test_torch = torch.from_numpy(x_test_images_norm)
    y_train_torch = torch.from_numpy(y_train)
    y_test_torch = torch.from_numpy(y_test)

    # create data loader: train and test. 
    train_ds = TensorDataset(x_train_torch, y_train_torch)
    batch_size = 100
    train_dl_no_shuffle = DataLoader(train_ds, batch_size, shuffle = False) # important: NO SHUFFLE!!!

    test_ds = TensorDataset(x_test_torch, y_test_torch)
    batch_size = 100
    test_dl_no_shuffle = DataLoader(test_ds, batch_size, shuffle = False)

    for inputs, labels in train_dl_no_shuffle:
        # to device
        inputs = inputs.to(device)
        labels = labels.to(device)
        image_train_hidden_vector_batch = model_no_last_layer(inputs)
        image_train_hidden_list.append(image_train_hidden_vector_batch.squeeze().cpu().detach().numpy())

    for inputs, labels in test_dl_no_shuffle:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # forward + backward
        image_test_hidden_vector_batch = model_no_last_layer(inputs)
        image_test_hidden_list.append(image_test_hidden_vector_batch.squeeze().cpu().detach().numpy())

    # vectorize
    image_train_hidden_vector=np.array(image_train_hidden_list).reshape(-1,512) # 512, resnet architecture   
    image_test_hidden_vector=np.array(image_test_hidden_list).reshape(-1,512) # 512, resnet architecture

    # scale
    scaler = MinMaxScaler()
    image_train_hidden_vector_norm = scaler.fit_transform(image_train_hidden_vector)
    image_test_hidden_vector_norm = scaler.fit_transform(image_test_hidden_vector)
    
    return image_train_hidden_vector_norm,image_test_hidden_vector_norm


In [12]:
# load model dictionary for all the output variables
model_dic = {}
model_name = 'resnet18'
if image_type == "rgb":
    input_channels = 3
elif image_type == "bw":
    input_channels = 4
elif image_type == "merge":
    input_channels = 7
use_pretrained = True # unclear True vs False is better
full_training = True
last_layer_dic_train = {}
last_layer_dic_test = {}

for output_var in output_list:
    print(output_var)
    # read models
    model = initialize_model(model_name, num_categories, input_channels, use_pretrained, full_training)
    PATH = './models/'+model_name+'_'+output_var+'_'+image_type+'.pth'
    model.load_state_dict(torch.load(PATH))
    model_dic[output_var]=model.state_dict()
    
    # initialize data. 
    y_train,y_test,BE_train,BE_test,x_train,x_test,x_train_images,x_test_images = \
        initialize_data(image_type, output_var, output_type, input_var, BE_var, num_categories, size)
    
    # obtain the last layer
    image_train_hidden_vector_norm,image_test_hidden_vector_norm = \
        return_last_layer_resnet(model,device,x_train_images,x_test_images,y_train,y_test)

    # 
    last_layer_dic_train[output_var]=image_train_hidden_vector_norm 
    last_layer_dic_test[output_var]=image_test_hidden_vector_norm
    



HHVEHCNT_mean_norm
HHVEHCNT_P_CAP_mean_norm
TRPTRANS_1_mean_norm
TRPTRANS_2_mean_norm
TRPTRANS_3_mean_norm


In [13]:
with open('last_layer/last_layer_dic_train.pickle', 'wb') as h:
    pickle.dump(last_layer_dic_train, h, protocol=pickle.HIGHEST_PROTOCOL)

with open('last_layer/last_layer_dic_test.pickle', 'wb') as h:
    pickle.dump(last_layer_dic_test, h, protocol=pickle.HIGHEST_PROTOCOL)

