# CNN

Use Resnet to train models to fit five travel behavioral variables by using images. The default is using only the black and white images with resnet18 to fit continous outputs, but the codes leave the flexibility of using other models, images, and output types. Several functions are not useful yet: bottleneck_resnet18, return_bottleneck_resnet18, and train_discrete_model. 

#### To be done: adjustment along many dimensions - hyperparameters, model choice, etc.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import TensorDataset, DataLoader
import torch.optim as optim

from tqdm.notebook import tqdm

import pickle

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import util
import statsmodels.api as sm
from scipy import stats
import copy

from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import log_loss
from sklearn.metrics import r2_score
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import PolynomialFeatures

In [2]:
# ALWAYS choose devise first.
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Helper Functions

In [3]:
def initialize_data(image_type, output_var, output_type, input_var, BE_var, num_categories, size):
    # outputs: randonmized training and testing sets for NHTS, BE, images, and y.
    
    ### read image array
    if image_type == 'rgb':
        image_array_ = np.load("data_process/image_array_rgb_tract_large.npy", mmap_mode='r')
        image_array = image_array_[:size,]
    elif image_type == 'bw':
        image_array_ = np.load("data_process/image_array_bw_tract_large.npy", mmap_mode='r')
        image_array = image_array_[:size,]        
    elif image_type == 'merge':
        bw_image_array_ = np.load("data_process/image_array_bw_tract_large.npy", mmap_mode='r')
        rgb_image_array_ = np.load("data_process/image_array_rgb_tract_large.npy", mmap_mode='r')
        bw_image_array = bw_image_array_[:size,]
        rgb_image_array = rgb_image_array_[:size,]        
        image_array = np.concatenate([rgb_image_array, bw_image_array], axis=1)
    
    ### create output array
    df_ = pd.read_csv("data_process/df_merged_tract_large.csv")
    df = df_.iloc[:size,]
    y_ = df[output_var].values 
    # cut y into categories for discrete variables
    if output_type == 'continuous':
        y = copy.deepcopy(y_)
    elif output_type == 'discrete':
        y = np.array(pd.qcut(y_, q = num_categories, labels=np.arange(num_categories))) 
    x = df[input_var]
    BE = df[BE_var]
            
    ### randomization
    shuffle_idx = np.arange(size)
    np.random.seed(0) # important: don't change the seed number, unless the seed number across scripts are all changed.
    np.random.shuffle(shuffle_idx)
    train_ratio = 0.8

    ###
    # y
    if output_type == 'discrete':
        y_train = y[shuffle_idx[:int(train_ratio*size)]].astype("int")
        y_test = y[shuffle_idx[int(train_ratio*size):]].astype("int")
    elif output_type == 'continuous':
        y_train = y[shuffle_idx[:int(train_ratio*size)]].astype("float32")
        y_test = y[shuffle_idx[int(train_ratio*size):]].astype("float32")
    # BE
    BE_train = BE.values[shuffle_idx[:int(train_ratio*size)]].astype("float32")
    BE_test = BE.values[shuffle_idx[int(train_ratio*size):]].astype("float32")        
    # image array
    x_train_images = image_array[shuffle_idx[:int(train_ratio*size)],].astype("float32")
    x_test_images = image_array[shuffle_idx[int(train_ratio*size):],].astype("float32")
    # NHTS
    x_train = x.values[shuffle_idx[:int(train_ratio*size)]].astype("float32")
    x_test = x.values[shuffle_idx[int(train_ratio*size):]].astype("float32")
    
    return y_train,y_test,BE_train,BE_test,x_train,x_test,x_train_images,x_test_images

# # test 
# image_type = 'bw'
# output_var = 'HHFAMINC_mean'
# output_type = 'continuous'
# input_var=['R_AGE_IMP_mean', 'HHSIZE_mean', 'HHFAMINC_mean', 'HBHTNRNT_mean', 'HBPPOPDN_mean', 'HBRESDN_mean', 
#            'R_SEX_IMP_2_mean', 'EDUC_2_mean', 'HH_RACE_2_mean', 'HOMEOWN_1_mean', 'HOMEOWN_2_mean',
#            'HBHUR_R_mean', 'HBHUR_S_mean', 'HBHUR_T_mean','HBHUR_U_mean']
# BE_var = ['density', 'diversity', 'design']
# num_categories = 1 # (1) certain category values can cause errors. (2) when output_type = 'continuous', this value needs to be 1.
# size = 10000 # size needs to be smaller than the max
# # 
# y_train,y_test,BE_train,BE_test,x_train,x_test,x_train_images,x_test_images = \
#     initialize_data(image_type, output_var, output_type, input_var, BE_var, num_categories, size)

# print(y_train.shape)
# print(y_test.shape)
# print(x_train_images.shape)
# print(x_test_images.shape)
# plt.figure()
# plt.boxplot(y_train)
# plt.figure()
# plt.boxplot(y_test)

In [4]:
def initialize_model(model_name, num_categories, input_channels = 3, use_pretrained=True, full_training=False):
    # initliaze the CNN model.
    # default input image size = 3*224*224, but inputs and output channels can be changed. 
    # num_categories: output channels. For continuous varialbes, use num_categories = 1.
    # return the model

    if model_name == 'resnet18':
        """ resnet 18"""
        model_ft = models.resnet18(pretrained=use_pretrained)
        # train only the last layer.
        for param in model_ft.parameters():
            param.requires_grad=full_training
        if input_channels != 3:
            # Edit the input channels.
            model_ft.conv1 = nn.Conv2d(input_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_categories) # if output_type == continuous, then num_categories = 1.

    elif model_name == 'alexnet':
        """ alexnet """
        model_ft = models.alexnet(pretrained=use_pretrained)
        for param in model_ft.parameters():
            param.requires_grad = full_training 
        if input_channels != 3:
            model_ft.features[0] = nn.Conv2d(input_channels, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))    
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_categories) # if output_type == continuous, then num_categories = 1.

    elif model_name == 'vgg':
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        for param in model_ft.parameters():
            param.requires_grad = full_training 
        if input_channels != 3:
            model_ft.features[0] = nn.Conv2d(input_channels, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_categories) # if output_type == continuous, then num_categories = 1.

    elif model_name == 'squeezenet':
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        for param in model_ft.parameters():
            param.requires_grad = full_training 
        if input_channels != 3:
            model_ft.features[0] = nn.Conv2d(input_channels, 96, kernel_size=(7, 7), stride=(2, 2))
        model_ft.classifier[1] = nn.Conv2d(512, num_categories, kernel_size=(1,1), stride=(1,1))

    elif model_name == 'densenet':
        model_ft = models.densenet121(pretrained=use_pretrained)
        for param in model_ft.parameters():
            param.requires_grad = full_training
        if input_channels != 3:
            model_ft.features[0] = nn.Conv2d(input_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_categories)

    elif model_name == 'wide_resnet':
        model_ft = models.wide_resnet50_2(pretrained=use_pretrained)
        for param in model_ft.parameters():
            param.requires_grad = full_training 
        if input_channels != 3:
            model_ft.conv1 = nn.Conv2d(input_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_categories)
        
    elif model_name == 'mnasnet':
        model_ft = models.mnasnet1_0(pretrained=use_pretrained)
        for param in model_ft.parameters():
            param.requires_grad = full_training
        if input_channels != 3:
            model_ft.layers[0] = nn.Conv2d(input_channels, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
        num_ftrs = model_ft.classifier[1].in_features
        model_ft.classifier[1] = nn.Linear(num_ftrs, num_categories)
        
    return model_ft

# # test 1. initialize model for continuous var
# model_name = 'resnet18'
# num_categories = 1 
# input_channels = 4
# use_pretrained = True
# full_training = True
# model = initialize_model(model_name, num_categories, input_channels, use_pretrained, full_training)
# model.to(device)

# # test 2. initialize model for discrete var
# model_name = 'resnet18'
# # num_categories = 1 
# input_channels = 4
# use_pretrained = True
# full_training = True
# model = initialize_model(model_name, num_categories, input_channels, use_pretrained, full_training)
# model.to(device)

# # test 3. initialize model for continuous var
# model_name = 'bottleneck_resnet18'
# num_categories = 1 
# input_channels = 4
# use_pretrained = True
# full_training = True
# model = initialize_model(model_name, num_categories, input_channels, use_pretrained, full_training)
# model.to(device)


In [5]:
# class bottleneck_resnet18(nn.Module):
#     # This model does NOT work yet. It seems that the fc layer or the upsampling do not work...
#     # Goal: create a resnet architecture with bottleneck in the middle that reduces information into several nodes.
#     def __init__(self, num_categories, num_bottleneck, input_channels = 3, use_pretrained=True, full_training=False):
#         super(bottleneck_resnet18, self).__init__()
#         ref = models.resnet18(pretrained=use_pretrained)
#         self.sequence1 = nn.Sequential(ref.conv1, ref.bn1, ref.relu, ref.maxpool, ref.layer1,
#                                        ref.layer2)
#         ### condense 
#         if num_bottleneck == 1:
#             self.condense = nn.AvgPool3d((128,28,28))
#         elif num_bottleneck == 2:
#             self.condense = nn.AvgPool3d((128,28,14))
#         elif num_bottleneck == 3:
#             self.condense = nn.AvgPool3d((128,28,9))

#         ### upsampling
#         self.upsample = nn.Sequential(nn.Conv2d(num_bottleneck, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False),
#                                       nn.Upsample((28, 28)))
#         self.sequence2 = nn.Sequential(ref.layer3, ref.layer4, ref.avgpool)
#         self.fc = ref.fc
        
#         ### edit parameters
#         for param in self.parameters():
#             param.requires_grad=full_training
#         if input_channels != 3:
#             self.sequence1[0]=nn.Conv2d(input_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
#         num_ftrs=self.fc.in_features
#         self.fc=nn.Linear(num_ftrs, num_categories)
        
#     def forward(self, x):
#         x=self.sequence1(x)
#         x=self.condense(x)
#         x=self.upsample(x)
#         x=self.sequence2(x)
#         x=x.squeeze() # sw: this line is important, but I don't understand why resnet18 does not need it...
#         out=self.fc(x)
#         return out

In [6]:
# def return_bottleneck_renset(model,device,x_train_images,x_test_images,y_train,y_test):
#     # This function does not work yet.
#     # Goal: return the several nodes' values from the bottleneck resnet architecture.
#     from sklearn.preprocessing import MinMaxScaler

#     bottleneck_train_list = []
#     def hook_train(module,inputs,outputs):
#         bottleneck_train_list.append(outputs)
        
#     bottleneck_test_list = []
#     def hook_test(module,inputs,outputs):
#         bottleneck_test_list.append(outputs)

#     x_train_images_norm = x_train_images/255
#     x_test_images_norm = x_test_images/255

#     x_train_torch = torch.from_numpy(x_train_images_norm)
#     x_test_torch = torch.from_numpy(x_test_images_norm)
#     y_train_torch = torch.from_numpy(y_train)
#     y_test_torch = torch.from_numpy(y_test)

#     # create data loader: train and test. 
#     train_ds = TensorDataset(x_train_torch, y_train_torch)
#     batch_size = 50
#     train_dl_no_shuffle = DataLoader(train_ds, batch_size, shuffle = False) # important: NO SHUFFLE.

#     test_ds = TensorDataset(x_test_torch, y_test_torch)
#     batch_size = 50
#     test_dl_no_shuffle = DataLoader(test_ds, batch_size, shuffle = False) # important: NO SHUFFLE.

#     for param in model.parameters():
#         param.requires_grad=False # save space
    
#     for inputs, labels in train_dl_no_shuffle:
#         # to device
#         inputs = inputs.to(device)
#         labels = labels.to(device)
#         outputs = model(inputs)
#         bottleneck_train_list = model.layer3[1].conv1.register_forward_hook(hook_train)

#     for inputs, labels in test_dl_no_shuffle:
#         inputs = inputs.to(device)
#         labels = labels.to(device)
#         outputs = model(inputs)
#         bottleneck_test_list = model.layer3[1].conv1.register_forward_hook(hook_test)

#     return bottleneck_train_list,bottleneck_test_list


In [7]:
# def train_discrete_model(model, train_dl, test_dl, criterion, optimizer, device, n_epoch = 25):
#     # Train a model with discrete outputs.
#     # Outputs: model; training and testing accuracy/log-loss.
#     # But so far this function is not used because of bad performance on discrete outputs.
#     log_loss_train_list=[]
#     log_loss_test_list=[]
#     accuracy_train_list=[]
#     accuracy_test_list=[]

#     # automatic model searching.
#     best_model_wts = copy.deepcopy(model.state_dict())
#     best_acc = 0.0
    
#     # iterate
#     for epoch in range(n_epoch):
    
#         running_log_loss_train = 0.0
#         running_log_loss_test = 0.0
#         correct_train = 0
#         total_train = 0
#         correct_test = 0
#         total_test = 0

#         # training    
#         for inputs, labels in tqdm(train_dl):
#             # to device
#             inputs = inputs.to(device)
#             labels = labels.to(device)

#             # forward + backward
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)
#             loss.backward()

#             # evaluate prediction
#             _, predicted = torch.max(outputs.data, 1)
#             total_train += labels.size(0)
#             correct_train += (predicted == labels).sum().item()

#             # evaluate log loss
#             running_log_loss_train += loss.item()

#             # optimize
#             with torch.no_grad():
#                 optimizer.step()
#                 optimizer.zero_grad()

#         # testing
#         for inputs, labels in test_dl:
#             # to device
#             inputs = inputs.to(device)
#             labels = labels.to(device)

#             # forward + backward
#             outputs = model(inputs)
#             loss = criterion(outputs, labels)

#             # evaluate log loss
#             running_log_loss_test += loss.item()

#             # evaluate prediction
#             _, predicted = torch.max(outputs.data, 1)
#             total_test += labels.size(0)
#             correct_test += (predicted == labels).sum().item()
        
#         # print
#         print("Epoch {}: Training Loss {}; Testing Loss {}".format(epoch, running_log_loss_train, running_log_loss_test))
#         print("Epoch {}: Training Accuracy {}; Testing Accuracy {}".format(epoch, correct_train/total_train, correct_test/total_test))

#         # append loss here.
#         log_loss_train_list.append(running_log_loss_train)
#         log_loss_test_list.append(running_log_loss_test)
#         accuracy_train_list.append(correct_train/total_train)
#         accuracy_test_list.append(correct_test/total_test)
        
#         if correct_test/total_test > best_acc:
#             best_acc = correct_test/total_test
#             best_model_wts = copy.deepcopy(model.state_dict())

#     # load the best model weights
#     model.load_state_dict(best_model_wts)
#     return model, log_loss_train_list, log_loss_test_list, accuracy_train_list, accuracy_test_list


In [8]:
def train_continuous_model(model, train_dl, test_dl, criterion, optimizer, device, total_mse_train, total_mse_test, n_epoch = 25):
    # This function trains the model with continous outputs.
    # outputs: model, and R2 and MSE for training and testing
    mse_train_list = []
    mse_test_list = []
    r_square_train_list = []
    r_square_test_list = []

    # automatic model searching.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_r_square = 0.0
    
    for epoch in range(n_epoch): 
        running_mse_train = 0.0
        running_mse_test = 0.0
        
        # training
        for inputs, labels in tqdm(train_dl):
            # to device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # forward + backward
            outputs = model(inputs)
            # sw: be careful about the dimension matching at this point...
            loss = criterion(outputs.view(-1), labels) # this .view(-1) seems specific to continuous variables
            loss.backward()

            # performance
            running_mse_train += loss.item()*batch_size

            # optimize
            with torch.no_grad():
                optimizer.step()
                optimizer.zero_grad()

        # testing
        for inputs, labels in test_dl:
            # to device
            inputs = inputs.to(device)
            labels = labels.to(device)

            # forward + backward        
            outputs = model(inputs)
            loss = criterion(outputs.view(-1), labels) # this .view(-1) is specific to continuous variables
            running_mse_test += loss.item()*batch_size # this *batch_size is specific to continuous variables.

        # R square for a batch
        running_r_square_train = 1-running_mse_train/total_mse_train.item()
        running_r_square_test = 1-running_mse_test/total_mse_test.item()
        
        print("Epoch {}: Training MSE {}; Testing MSE {}".format(epoch, running_mse_train, running_mse_test))
        print("Epoch {}: Training R2 {}; Testing R2 {}".format(epoch, running_r_square_train, running_r_square_test))

        mse_train_list.append(running_mse_train)
        mse_test_list.append(running_mse_test)
        r_square_train_list.append(running_r_square_train)
        r_square_test_list.append(running_r_square_test)

        # check overfitting for early stopping. This is designed in an ad-hoc way.
        if epoch > 5 and running_r_square_test < 0.0:
            break # break the for loop.
        
        # store the best performance.
        if running_r_square_test > best_r_square:
            best_r_square = running_r_square_test
            best_model_wts = copy.deepcopy(model.state_dict())

    # load the weights of the best model
    model.load_state_dict(best_model_wts)
    return model, mse_train_list, mse_test_list, r_square_train_list, r_square_test_list


## Train resnet18 for continous outputs.

In [13]:
# set up.
output_list = ['HHVEHCNT_mean_norm', 'HHVEHCNT_P_CAP_mean_norm', 'TRPTRANS_1_mean_norm', 'TRPTRANS_2_mean_norm', 'TRPTRANS_3_mean_norm']
input_var=['R_AGE_IMP_mean', 'HHSIZE_mean', 'HHFAMINC_mean', 'HBHTNRNT_mean', 'HBPPOPDN_mean', 'HBRESDN_mean', 
           'R_SEX_IMP_2_mean', 'EDUC_2_mean', 'HH_RACE_2_mean', 'HOMEOWN_1_mean', 'HOMEOWN_2_mean',
           'HBHUR_R_mean', 'HBHUR_S_mean', 'HBHUR_T_mean','HBHUR_U_mean']
BE_var = ['density', 'diversity', 'design']
image_type = ['bw', 'rgb', 'merge'] # It can be 'rgb', 'bw', 'merge'
output_type = 'continuous' 
num_categories = 1 # Certain category values can cause errors. When output_type = 'continuous', this value needs to be 1.
size = 12000 # size needs to be smaller than the max (18491).

model_name = 'resnet18' 
model_dic = {}


In [17]:
for image_type in image_types:
    performance_continuous = {}
    
    for output_var in output_list:

        print(output_var)

        # data set up
        y_train,y_test,BE_train,BE_test,x_train,x_test,x_train_images,x_test_images = \
            initialize_data(image_type, output_var, output_type, input_var, BE_var, num_categories, size)

        # process data
        x_train_images_norm = x_train_images/255 # very crude processing. It is improvable.
        x_test_images_norm = x_test_images/255

        x_train_torch = torch.from_numpy(x_train_images_norm)
        x_test_torch = torch.from_numpy(x_test_images_norm)
        y_train_torch = torch.from_numpy(y_train)
        y_test_torch = torch.from_numpy(y_test)

        # create data loader: train and test. 
        train_ds = TensorDataset(x_train_torch, y_train_torch)
        batch_size = 100
        train_dl = DataLoader(train_ds, batch_size, shuffle = True)

        test_ds = TensorDataset(x_test_torch, y_test_torch)
        batch_size = 100
        test_dl = DataLoader(test_ds, batch_size, shuffle = True)

        # model set up
        input_channels = 4 # 4 for BW images; 3 for RGB images; 7 for merged images.
        use_pretrained = True # unclear whether True or False is better.
        full_training = True # Fully retraining the network seems to work better.
    #     num_bottleneck = 3 # Used for the bottleneck model

        if model_name == 'bottleneck_resnet18': # It does not work.
            model = bottleneck_resnet18(num_categories, num_bottleneck, input_channels, use_pretrained, full_training)
            model.to(device)
        else: 
            # 'resnet18' and others works
            if image_type == "rgb":
                input_channels = 3
            elif image_type == "bw":
                input_channels = 4
            elif image_type == "merge":
                input_channels = 7
            model = initialize_model(model_name, num_categories, input_channels, use_pretrained, full_training)
            model.to(device)

        # training set up
        criterion = nn.MSELoss(reduction='mean')
        optimizer = optim.Adam(model.parameters(), lr=0.001)
        n_epoch = 25

        # create baseline mse
        total_mse_train = criterion(y_train_torch.mean().repeat(y_train_torch.size()), y_train_torch)*y_train_torch.size()[0]
        total_mse_test = criterion(y_test_torch.mean().repeat(y_test_torch.size()), y_test_torch)*y_test_torch.size()[0]
        print(total_mse_train)
        print(total_mse_test)

        # training here.
        model, mse_train_list, mse_test_list, r_square_train_list, r_square_test_list = \
            train_continuous_model(model, train_dl, test_dl, criterion, optimizer, device, total_mse_train, total_mse_test, n_epoch)

        # save models.
        PATH = './models/'+model_name+'_'+output_var+'_'+image_type+'.pth'
        torch.save(model.state_dict(), PATH)
        model_dic[output_var]=model.state_dict()

        # save performance
        performance_continuous[output_var] = {}
        performance_continuous[output_var]['mse_train_list']=mse_train_list
        performance_continuous[output_var]['mse_test_list']=mse_test_list
        performance_continuous[output_var]['r_square_train_list']=r_square_train_list
        performance_continuous[output_var]['r_square_test_list']=r_square_test_list

        print("Printing performance_continuous")
        print(performance_continuous)
        %store performance_continuous
        
                
        with open('outputs/performance_continuous_'+image_type+'_'+model_name+'.pickle', 'wb') as h:
            pickle.dump(performance_continuous, h, protocol=pickle.HIGHEST_PROTOCOL)


HHVEHCNT_mean_norm
tensor(10700.4521)
tensor(2549.4304)


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12968.76671910286; Testing MSE 2314.4653499126434
Epoch 0: Training R2 -0.211983058210917; Testing R2 0.09216375084142592


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9803.879088163376; Testing MSE 2293.4097290039062
Epoch 1: Training R2 0.08378833415979003; Testing R2 0.10042270183855984


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9795.471370220184; Testing MSE 2329.2478561401367
Epoch 2: Training R2 0.08457406898917474; Testing R2 0.08636539442738966


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9716.708314418793; Testing MSE 2347.6718723773956
Epoch 3: Training R2 0.09193479120060877; Testing R2 0.07913867582652523


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 9511.821076273918; Testing MSE 2339.412635564804
Epoch 4: Training R2 0.11108232209955238; Testing R2 0.08237831584495914


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 9420.36070227623; Testing MSE 2312.281894683838
Epoch 5: Training R2 0.11962965942034431; Testing R2 0.09302019909423698


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 9256.327441334724; Testing MSE 2332.5907111167908
Epoch 6: Training R2 0.1349592229440183; Testing R2 0.08505417802762749


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 9344.197827577591; Testing MSE 2260.4303300380707
Epoch 7: Training R2 0.12674738432038613; Testing R2 0.11335868891556589


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 8872.700893878937; Testing MSE 2326.6711831092834
Epoch 8: Training R2 0.17081065633525172; Testing R2 0.08737608019104826


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 8167.845937609673; Testing MSE 2359.0803802013397
Epoch 9: Training R2 0.23668216779023155; Testing R2 0.07466375164942463


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 7423.487824201584; Testing MSE 2771.1448907852173
Epoch 10: Training R2 0.3062454070891224; Testing R2 -0.08696627651839828
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12968.76671910286, 9803.879088163376, 9795.471370220184, 9716.708314418793, 9511.821076273918, 9420.36070227623, 9256.327441334724, 9344.197827577591, 8872.700893878937, 8167.845937609673, 7423.487824201584], 'mse_test_list': [2314.4653499126434, 2293.4097290039062, 2329.2478561401367, 2347.6718723773956, 2339.412635564804, 2312.281894683838, 2332.5907111167908, 2260.4303300380707, 2326.6711831092834, 2359.0803802013397, 2771.1448907852173], 'r_square_train_list': [-0.211983058210917, 0.08378833415979003, 0.08457406898917474, 0.09193479120060877, 0.11108232209955238, 0.11962965942034431, 0.1349592229440183, 0.12674738432038613, 0.17081065633525172, 0.23668216779023155, 0.3062454070891224], 'r_square_test_list': [0.09216375084142592, 0.10042270183855984, 0.08636539442

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 11759.440505504608; Testing MSE 2224.727487564087
Epoch 0: Training R2 -0.0980998187416513; Testing R2 0.09045047769151782


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9766.517919301987; Testing MSE 2219.6477830410004
Epoch 1: Training R2 0.08799984557918394; Testing R2 0.09252724567693082


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9711.34986281395; Testing MSE 2270.866107940674
Epoch 2: Training R2 0.09315145400829972; Testing R2 0.0715873313699672


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9565.398448705673; Testing MSE 2290.711635351181
Epoch 3: Training R2 0.1067804375728123; Testing R2 0.06347375787519693


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 9640.703243017197; Testing MSE 2168.278819322586
Epoch 4: Training R2 0.09974845497592322; Testing R2 0.11352874661254586


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 9334.938263893127; Testing MSE 2170.409804582596
Epoch 5: Training R2 0.12830087360472742; Testing R2 0.11265752232277426


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 8970.03358900547; Testing MSE 2249.590343236923
Epoch 6: Training R2 0.16237577344069742; Testing R2 0.08028563789569465


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 8505.051308870316; Testing MSE 2264.1828775405884
Epoch 7: Training R2 0.20579594783551514; Testing R2 0.07431967906282144


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 7827.041923999786; Testing MSE 2289.709311723709
Epoch 8: Training R2 0.26910865240534176; Testing R2 0.0638835442340484


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 6664.316242933273; Testing MSE 2373.796057701111
Epoch 9: Training R2 0.3776842992677699; Testing R2 0.029505823787953855


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 5663.7800842523575; Testing MSE 2566.187047958374
Epoch 10: Training R2 0.4711146435071649; Testing R2 -0.04915061133190424
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12968.76671910286, 9803.879088163376, 9795.471370220184, 9716.708314418793, 9511.821076273918, 9420.36070227623, 9256.327441334724, 9344.197827577591, 8872.700893878937, 8167.845937609673, 7423.487824201584], 'mse_test_list': [2314.4653499126434, 2293.4097290039062, 2329.2478561401367, 2347.6718723773956, 2339.412635564804, 2312.281894683838, 2332.5907111167908, 2260.4303300380707, 2326.6711831092834, 2359.0803802013397, 2771.1448907852173], 'r_square_train_list': [-0.211983058210917, 0.08378833415979003, 0.08457406898917474, 0.09193479120060877, 0.11108232209955238, 0.11962965942034431, 0.1349592229440183, 0.12674738432038613, 0.17081065633525172, 0.23668216779023155, 0.3062454070891224], 'r_square_test_list': [0.09216375084142592, 0.10042270183855984, 0.08636539442

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12230.166044831276; Testing MSE 2291.6292428970337
Epoch 0: Training R2 -0.0507455393520595; Testing R2 0.24593072060881294


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9215.357342362404; Testing MSE 2255.251407623291
Epoch 1: Training R2 0.20826948828591074; Testing R2 0.2579009850465298


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 8983.33434164524; Testing MSE 2160.932868719101
Epoch 2: Training R2 0.22820357030384641; Testing R2 0.28893685740029695


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 8734.77349281311; Testing MSE 2111.8488252162933
Epoch 3: Training R2 0.24955848913410228; Testing R2 0.3050881477665244


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 8743.052563071251; Testing MSE 2120.903104543686
Epoch 4: Training R2 0.24884719902472552; Testing R2 0.30210880287075825


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 8348.892933130264; Testing MSE 2233.1446170806885
Epoch 5: Training R2 0.28271112789003994; Testing R2 0.26517530817963764


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 8404.468074440956; Testing MSE 2196.3607251644135
Epoch 6: Training R2 0.2779364313227968; Testing R2 0.2772791870930711


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 8076.0271191596985; Testing MSE 2259.3310326337814
Epoch 7: Training R2 0.30615418956395823; Testing R2 0.25655856898975193


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 7642.847928404808; Testing MSE 2159.1048181056976
Epoch 8: Training R2 0.34337045471991634; Testing R2 0.2895383843763609


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 6612.923592329025; Testing MSE 2267.4333304166794
Epoch 9: Training R2 0.4318556312935504; Testing R2 0.253892477225781


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 5841.098818182945; Testing MSE 2376.7027139663696
Epoch 10: Training R2 0.4981663776581222; Testing R2 0.21793697283159297


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 4618.132694065571; Testing MSE 2496.234178543091
Epoch 11: Training R2 0.6032365946105802; Testing R2 0.17860469181915728


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 3747.4667713046074; Testing MSE 2488.228678703308
Epoch 12: Training R2 0.6780392040971097; Testing R2 0.18123893185343076


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 2860.5041205883026; Testing MSE 2548.551285266876
Epoch 13: Training R2 0.7542419347383595; Testing R2 0.16138954975840936


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 2252.7525819838047; Testing MSE 2403.438013792038
Epoch 14: Training R2 0.8064564521768132; Testing R2 0.20913962119352314


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 2153.988979756832; Testing MSE 2819.208914041519
Epoch 15: Training R2 0.8149416529590373; Testing R2 0.07232863219313601


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 2071.227313578129; Testing MSE 2561.3675355911255
Epoch 16: Training R2 0.8220520594120526; Testing R2 0.15717231406181553


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 1795.3870363533497; Testing MSE 2548.621815443039
Epoch 17: Training R2 0.8457506698646935; Testing R2 0.1613663415368879


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 1558.6245216429234; Testing MSE 2426.9540190696716
Epoch 18: Training R2 0.8660919436712657; Testing R2 0.20140159061600826


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 1301.457180455327; Testing MSE 2490.904539823532
Epoch 19: Training R2 0.8881862828347227; Testing R2 0.18035842961995108


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 1224.8165853321552; Testing MSE 2415.772485733032
Epoch 20: Training R2 0.8947708020606896; Testing R2 0.20508091649814308


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 21: Training MSE 1115.707864239812; Testing MSE 2576.003336906433
Epoch 21: Training R2 0.9041447959682081; Testing R2 0.15235634822207322


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 22: Training MSE 1086.6693586111069; Testing MSE 2410.4267060756683
Epoch 22: Training R2 0.9066396173914795; Testing R2 0.20683996553563766


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 23: Training MSE 1030.8579668402672; Testing MSE 2424.946653842926
Epoch 23: Training R2 0.9114346112397457; Testing R2 0.20206212174454863


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 24: Training MSE 1061.6857629269361; Testing MSE 2459.397292137146
Epoch 24: Training R2 0.908786064269297; Testing R2 0.19072600877007473
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12968.76671910286, 9803.879088163376, 9795.471370220184, 9716.708314418793, 9511.821076273918, 9420.36070227623, 9256.327441334724, 9344.197827577591, 8872.700893878937, 8167.845937609673, 7423.487824201584], 'mse_test_list': [2314.4653499126434, 2293.4097290039062, 2329.2478561401367, 2347.6718723773956, 2339.412635564804, 2312.281894683838, 2332.5907111167908, 2260.4303300380707, 2326.6711831092834, 2359.0803802013397, 2771.1448907852173], 'r_square_train_list': [-0.211983058210917, 0.08378833415979003, 0.08457406898917474, 0.09193479120060877, 0.11108232209955238, 0.11962965942034431, 0.1349592229440183, 0.12674738432038613, 0.17081065633525172, 0.23668216779023155, 0.3062454070891224], 'r_square_test_list': [0.09216375084142592, 0.10042270183855984, 0.0863653944273

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 11870.062518119812; Testing MSE 2270.030653476715
Epoch 0: Training R2 -0.015363088040754258; Testing R2 0.23359574404229322


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 8421.525579690933; Testing MSE 2003.9832383394241
Epoch 1: Training R2 0.27962416326315764; Testing R2 0.3234182629300797


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 8144.220009446144; Testing MSE 1945.943021774292
Epoch 2: Training R2 0.30334483362229325; Testing R2 0.3430137115307832


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 8039.416149258614; Testing MSE 1873.7303376197815
Epoch 3: Training R2 0.3123097376365863; Testing R2 0.3673940468295084


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 7723.625993728638; Testing MSE 1780.141305923462
Epoch 4: Training R2 0.33932237274295773; Testing R2 0.3989914316900013


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 7568.865957856178; Testing MSE 1856.9938123226166
Epoch 5: Training R2 0.35256051935666644; Testing R2 0.3730446067450961


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 7380.3219348192215; Testing MSE 1905.3146809339523
Epoch 6: Training R2 0.368688542370041; Testing R2 0.356730589443796


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 6994.640618562698; Testing MSE 1922.438132762909
Epoch 7: Training R2 0.4016796552370586; Testing R2 0.35094939598797215


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 6499.428540468216; Testing MSE 2004.2107075452805
Epoch 8: Training R2 0.44404001046530883; Testing R2 0.3233414651268428


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 5671.397563815117; Testing MSE 2065.1791632175446
Epoch 9: Training R2 0.5148696365236782; Testing R2 0.30275738894495197


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 4566.176956892014; Testing MSE 2363.508051633835
Epoch 10: Training R2 0.6094100154558586; Testing R2 0.20203604872551606


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 3977.96171605587; Testing MSE 2052.280533313751
Epoch 11: Training R2 0.6597258450866048; Testing R2 0.3071122045238298


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 3142.7128225564957; Testing MSE 2198.265951871872
Epoch 12: Training R2 0.7311728904995115; Testing R2 0.25782483216198404


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 2404.581528902054; Testing MSE 2169.135242700577
Epoch 13: Training R2 0.7943125132740684; Testing R2 0.2676598973641887


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 2014.2502509057522; Testing MSE 2142.5213515758514
Epoch 14: Training R2 0.8277013830614194; Testing R2 0.2766452383305519


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 1770.815448462963; Testing MSE 2203.953433036804
Epoch 15: Training R2 0.8485247538202159; Testing R2 0.25590463352334125


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 1487.9779882729053; Testing MSE 2095.045620203018
Epoch 16: Training R2 0.8727186210853455; Testing R2 0.2926739216979404


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 1417.8446125239134; Testing MSE 2128.606081008911
Epoch 17: Training R2 0.8787178178769811; Testing R2 0.28134329056564944


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 1313.2252871990204; Testing MSE 2082.340317964554
Epoch 18: Training R2 0.8876669368111457; Testing R2 0.2969634662879549


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 1116.0281911492348; Testing MSE 2052.898386120796
Epoch 19: Training R2 0.9045351421885084; Testing R2 0.30690360600990685


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 1164.945562928915; Testing MSE 2145.699602365494
Epoch 20: Training R2 0.9003507586948887; Testing R2 0.27557220218984946


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 21: Training MSE 1181.1119575053453; Testing MSE 2096.388679742813
Epoch 21: Training R2 0.8989678881081035; Testing R2 0.29222048000290124


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 22: Training MSE 1110.7569471001625; Testing MSE 2064.2991721630096
Epoch 22: Training R2 0.9049860434897704; Testing R2 0.30305449017059727


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 23: Training MSE 1063.326845318079; Testing MSE 2023.5982239246368
Epoch 23: Training R2 0.909043206165875; Testing R2 0.316795881681601


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 24: Training MSE 904.2911317199469; Testing MSE 2103.645047545433
Epoch 24: Training R2 0.9226470935102984; Testing R2 0.2897705962719459
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12968.76671910286, 9803.879088163376, 9795.471370220184, 9716.708314418793, 9511.821076273918, 9420.36070227623, 9256.327441334724, 9344.197827577591, 8872.700893878937, 8167.845937609673, 7423.487824201584], 'mse_test_list': [2314.4653499126434, 2293.4097290039062, 2329.2478561401367, 2347.6718723773956, 2339.412635564804, 2312.281894683838, 2332.5907111167908, 2260.4303300380707, 2326.6711831092834, 2359.0803802013397, 2771.1448907852173], 'r_square_train_list': [-0.211983058210917, 0.08378833415979003, 0.08457406898917474, 0.09193479120060877, 0.11108232209955238, 0.11962965942034431, 0.1349592229440183, 0.12674738432038613, 0.17081065633525172, 0.23668216779023155, 0.3062454070891224], 'r_square_test_list': [0.09216375084142592, 0.10042270183855984, 0.08636539442738

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12753.823333978653; Testing MSE 2612.2389376163483
Epoch 0: Training R2 -0.11219592880793883; Testing R2 0.10515012542627189


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9213.852804899216; Testing MSE 2603.7721902132034
Epoch 1: Training R2 0.19650685838317195; Testing R2 0.10805049864353389


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9033.557894825935; Testing MSE 2483.541724085808
Epoch 2: Training R2 0.21222945855704012; Testing R2 0.14923670714259807


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9198.729892075062; Testing MSE 2431.369113922119
Epoch 3: Training R2 0.19782565053155698; Testing R2 0.16710898252632012


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 8874.68115389347; Testing MSE 2445.860806107521
Epoch 4: Training R2 0.22608428936503722; Testing R2 0.1621447012166245


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 8611.23317182064; Testing MSE 2563.7428492307663
Epoch 5: Training R2 0.24905824513039931; Testing R2 0.12176296967426148


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 8459.611907601357; Testing MSE 2516.6953921318054
Epoch 6: Training R2 0.2622803627942164; Testing R2 0.13787957006550844


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 8097.447144985199; Testing MSE 2671.098831295967
Epoch 7: Training R2 0.2938629058475023; Testing R2 0.08498705086285097


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 7451.7426908016205; Testing MSE 2693.216636776924
Epoch 8: Training R2 0.3501715002470216; Testing R2 0.07741036437545756


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 7231.0048609972; Testing MSE 2715.747219324112
Epoch 9: Training R2 0.3694209213197055; Testing R2 0.06969227677018619


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 6841.959276795387; Testing MSE 2812.7094835042953
Epoch 10: Training R2 0.40334760381633217; Testing R2 0.036476835146306286


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 6345.289148390293; Testing MSE 2805.4447889328003
Epoch 11: Training R2 0.44665967426825504; Testing R2 0.03896543254545948


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 5898.714861273766; Testing MSE 2821.4260637760162
Epoch 12: Training R2 0.48560314173169894; Testing R2 0.0334908790568611


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 5369.727790355682; Testing MSE 3083.9070945978165
Epoch 13: Training R2 0.5317334080260174; Testing R2 -0.05642475389944046
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12968.76671910286, 9803.879088163376, 9795.471370220184, 9716.708314418793, 9511.821076273918, 9420.36070227623, 9256.327441334724, 9344.197827577591, 8872.700893878937, 8167.845937609673, 7423.487824201584], 'mse_test_list': [2314.4653499126434, 2293.4097290039062, 2329.2478561401367, 2347.6718723773956, 2339.412635564804, 2312.281894683838, 2332.5907111167908, 2260.4303300380707, 2326.6711831092834, 2359.0803802013397, 2771.1448907852173], 'r_square_train_list': [-0.211983058210917, 0.08378833415979003, 0.08457406898917474, 0.09193479120060877, 0.11108232209955238, 0.11962965942034431, 0.1349592229440183, 0.12674738432038613, 0.17081065633525172, 0.23668216779023155, 0.3062454070891224], 'r_square_test_list': [0.09216375084142592, 0.10042270183855984, 0.08636539442

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12242.481911182404; Testing MSE 2261.9253873825073
Epoch 0: Training R2 -0.144108841510036; Testing R2 0.11277226093041515


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9517.24494099617; Testing MSE 2223.0931401252747
Epoch 1: Training R2 0.11057544027371813; Testing R2 0.12800399542051466


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9246.151149272919; Testing MSE 2201.210141181946
Epoch 2: Training R2 0.13591023808998026; Testing R2 0.13658748088155315


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9110.277658700943; Testing MSE 2364.5673274993896
Epoch 3: Training R2 0.1486081585784912; Testing R2 0.07251152687985507


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 8838.356119394302; Testing MSE 2357.3063910007477
Epoch 4: Training R2 0.17402031271315055; Testing R2 0.07535958911442453


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 8503.915193676949; Testing MSE 2310.520112514496
Epoch 5: Training R2 0.2052751532636211; Testing R2 0.09371124841865675


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 8011.842387914658; Testing MSE 2405.54701089859
Epoch 6: Training R2 0.25126132272040846; Testing R2 0.05643747242479913


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 7045.496594905853; Testing MSE 2589.069586992264
Epoch 7: Training R2 0.34157019748603334; Testing R2 -0.015548244329650407
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12242.481911182404, 9517.24494099617, 9246.151149272919, 9110.277658700943, 8838.356119394302, 8503.915193676949, 8011.842387914658, 7045.496594905853], 'mse_test_list': [2261.9253873825073, 2223.0931401252747, 2201.210141181946, 2364.5673274993896, 2357.3063910007477, 2310.520112514496, 2405.54701089859, 2589.069586992264], 'r_square_train_list': [-0.144108841510036, 0.11057544027371813, 0.13591023808998026, 0.1486081585784912, 0.17402031271315055, 0.2052751532636211, 0.25126132272040846, 0.34157019748603334], 'r_square_test_list': [0.11277226093041515, 0.12800399542051466, 0.13658748088155315, 0.07251152687985507, 0.07535958911442453, 0.09371124841865675, 0.05643747242479913, -0.015548244329650407]}}
Stored 'performance_continuous' (dict)
HHVEHCNT_P_CAP_mean_norm
te

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12115.843003988266; Testing MSE 2204.591280221939
Epoch 0: Training R2 -0.13138078298487188; Testing R2 0.09868289171594669


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9845.910412073135; Testing MSE 2239.971625804901
Epoch 1: Training R2 0.08058615256541879; Testing R2 0.08421811946677249


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9657.149475812912; Testing MSE 2209.0087175369263
Epoch 2: Training R2 0.09821270119210124; Testing R2 0.09687688265545114


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9634.848350286484; Testing MSE 2184.148931503296
Epoch 3: Training R2 0.10029518648439406; Testing R2 0.10704046747110751


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 9400.747740268707; Testing MSE 2228.985059261322
Epoch 4: Training R2 0.1221555664325299; Testing R2 0.08870982751073797


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 9377.84526348114; Testing MSE 2218.7969505786896
Epoch 5: Training R2 0.12429420607251873; Testing R2 0.09287509693690998


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 8957.459890842438; Testing MSE 2333.1436693668365
Epoch 6: Training R2 0.16354990886553122; Testing R2 0.046125998886623276


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 8487.666934728622; Testing MSE 2312.321239709854
Epoch 7: Training R2 0.20741930551864585; Testing R2 0.05463896555498016


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 7671.99302315712; Testing MSE 2522.8414952754974
Epoch 8: Training R2 0.2835871618065098; Testing R2 -0.031429372682547196
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12242.481911182404, 9517.24494099617, 9246.151149272919, 9110.277658700943, 8838.356119394302, 8503.915193676949, 8011.842387914658, 7045.496594905853], 'mse_test_list': [2261.9253873825073, 2223.0931401252747, 2201.210141181946, 2364.5673274993896, 2357.3063910007477, 2310.520112514496, 2405.54701089859, 2589.069586992264], 'r_square_train_list': [-0.144108841510036, 0.11057544027371813, 0.13591023808998026, 0.1486081585784912, 0.17402031271315055, 0.2052751532636211, 0.25126132272040846, 0.34157019748603334], 'r_square_test_list': [0.11277226093041515, 0.12800399542051466, 0.13658748088155315, 0.07251152687985507, 0.07535958911442453, 0.09371124841865675, 0.05643747242479913, -0.015548244329650407]}, 'HHVEHCNT_P_CAP_mean_norm': {'mse_train_list': [12115.843003988266,

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12568.739041686058; Testing MSE 2360.640996694565
Epoch 0: Training R2 -0.07983378434285981; Testing R2 0.22322214171590327


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9146.030738949776; Testing MSE 2197.747829556465
Epoch 1: Training R2 0.21422563140181172; Testing R2 0.27682275514073984


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9004.806604981422; Testing MSE 2296.3084846735
Epoch 2: Training R2 0.22635879691012617; Testing R2 0.24439099838480416


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 8742.689135670662; Testing MSE 2155.450761318207
Epoch 3: Training R2 0.24887842261716464; Testing R2 0.290740765597991


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 8446.132493019104; Testing MSE 2105.8695912361145
Epoch 4: Training R2 0.2743568640617956; Testing R2 0.3070556373474501


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 8121.264544129372; Testing MSE 2175.654548406601
Epoch 5: Training R2 0.30226765013968393; Testing R2 0.28409263295701526


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 7542.963668704033; Testing MSE 2299.6438413858414
Epoch 6: Training R2 0.351951935948164; Testing R2 0.2432934866296267


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 6680.124044418335; Testing MSE 2503.7658989429474
Epoch 7: Training R2 0.42608215487332435; Testing R2 0.17612635070350813


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 5470.423409342766; Testing MSE 2421.0475623607635
Epoch 8: Training R2 0.5300126772879563; Testing R2 0.20334513256025877


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 3826.299634575844; Testing MSE 2456.408739089966
Epoch 9: Training R2 0.6712663377980779; Testing R2 0.19170940346609477


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 2765.8651798963547; Testing MSE 2512.5910460948944
Epoch 10: Training R2 0.7623727683195678; Testing R2 0.17322240261765764


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 2026.4368996024132; Testing MSE 2624.8215079307556
Epoch 11: Training R2 0.8259001942221769; Testing R2 0.13629254420158987


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 1603.7531428039074; Testing MSE 2654.3321430683136
Epoch 12: Training R2 0.8622147520445786; Testing R2 0.126581957970624


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 1042.085998877883; Testing MSE 2385.3937178850174
Epoch 13: Training R2 0.9104699632917145; Testing R2 0.21507716508457897


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 787.1122887358069; Testing MSE 2388.5120689868927
Epoch 14: Training R2 0.9323758382897941; Testing R2 0.21405105984719464


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 697.2983855754137; Testing MSE 2541.786879301071
Epoch 15: Training R2 0.9400921323917425; Testing R2 0.16361540315417755


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 633.6351016536355; Testing MSE 2456.763982772827
Epoch 16: Training R2 0.9455617156628202; Testing R2 0.19159250918715598


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 685.0380547344685; Testing MSE 2454.828494787216
Epoch 17: Training R2 0.9411454695169198; Testing R2 0.19222938883734542


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 642.7915656939149; Testing MSE 2483.9770674705505
Epoch 18: Training R2 0.9447750449249666; Testing R2 0.18263794062785366


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 594.8238087818027; Testing MSE 2427.1545946598053
Epoch 19: Training R2 0.948896158769486; Testing R2 0.20133559045861615


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 491.2542378529906; Testing MSE 2427.799218893051
Epoch 20: Training R2 0.9577942607467725; Testing R2 0.2011234744138639


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 21: Training MSE 479.40747011452913; Testing MSE 2409.669017791748
Epoch 21: Training R2 0.9588120669082996; Testing R2 0.20708928573436958


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 22: Training MSE 475.6238227710128; Testing MSE 2411.313784122467
Epoch 22: Training R2 0.9591371361309174; Testing R2 0.20654806914550972


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 23: Training MSE 501.6646867617965; Testing MSE 2377.5989174842834
Epoch 23: Training R2 0.956899854839814; Testing R2 0.21764207367064203


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 24: Training MSE 500.5396196618676; Testing MSE 2423.2143938541412
Epoch 24: Training R2 0.9569965141355578; Testing R2 0.2026321284528806
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12242.481911182404, 9517.24494099617, 9246.151149272919, 9110.277658700943, 8838.356119394302, 8503.915193676949, 8011.842387914658, 7045.496594905853], 'mse_test_list': [2261.9253873825073, 2223.0931401252747, 2201.210141181946, 2364.5673274993896, 2357.3063910007477, 2310.520112514496, 2405.54701089859, 2589.069586992264], 'r_square_train_list': [-0.144108841510036, 0.11057544027371813, 0.13591023808998026, 0.1486081585784912, 0.17402031271315055, 0.2052751532636211, 0.25126132272040846, 0.34157019748603334], 'r_square_test_list': [0.11277226093041515, 0.12800399542051466, 0.13658748088155315, 0.07251152687985507, 0.07535958911442453, 0.09371124841865675, 0.05643747242479913, -0.015548244329650407]}, 'HHVEHCNT_P_CAP_mean_norm': {'mse_train_list': [12115.843003988266,

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 11478.799536824226; Testing MSE 1914.6355509757996
Epoch 0: Training R2 0.018105479485143183; Testing R2 0.35358369164382153


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 8120.51086127758; Testing MSE 2087.1659606695175
Epoch 1: Training R2 0.305372910042489; Testing R2 0.2953342402239143


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 7919.544404745102; Testing MSE 1972.1661925315857
Epoch 2: Training R2 0.32256354586146085; Testing R2 0.3341602849736066


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 7508.014380931854; Testing MSE 1889.910614490509
Epoch 3: Training R2 0.35776575268746935; Testing R2 0.3619312866516635


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 7254.8455864191055; Testing MSE 1911.7294311523438
Epoch 4: Training R2 0.3794217674388337; Testing R2 0.3545648513464926


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 6775.137779116631; Testing MSE 1844.693273305893
Epoch 5: Training R2 0.4204558900339227; Testing R2 0.37719749579906214


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 5938.940522074699; Testing MSE 1917.0822083950043
Epoch 6: Training R2 0.4919840582996945; Testing R2 0.3527576549310233


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 4948.3507335186005; Testing MSE 2088.644963502884
Epoch 7: Training R2 0.576718936920138; Testing R2 0.2948349015632977


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 3834.7768262028694; Testing MSE 2004.547268152237
Epoch 8: Training R2 0.6719738557183927; Testing R2 0.32322783605264327


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 2967.5786301493645; Testing MSE 2009.902384877205
Epoch 9: Training R2 0.7461538389277591; Testing R2 0.32141984978475746


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 2154.1549153625965; Testing MSE 2054.0228813886642
Epoch 10: Training R2 0.8157339623408159; Testing R2 0.30652395564801493


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 1399.5881043374538; Testing MSE 1976.0539501905441
Epoch 11: Training R2 0.8802794764192801; Testing R2 0.3328477062155205


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 978.138867020607; Testing MSE 1928.857970237732
Epoch 12: Training R2 0.9163301710471493; Testing R2 0.348781940338981


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 764.9358779191971; Testing MSE 1926.3222515583038
Epoch 13: Training R2 0.9345675177387165; Testing R2 0.34963804577740254


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 587.2177205979824; Testing MSE 1894.1909909248352
Epoch 14: Training R2 0.9497694980773308; Testing R2 0.3604861525468248


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 537.0086846873164; Testing MSE 1984.8793029785156
Epoch 15: Training R2 0.9540643702965783; Testing R2 0.32986810418826584


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 462.99963649362326; Testing MSE 1919.0093994140625
Epoch 16: Training R2 0.9603950914366054; Testing R2 0.35210699966485437


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 475.4259740933776; Testing MSE 1933.6453974246979
Epoch 17: Training R2 0.9593321446746096; Testing R2 0.3471656165393182


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 415.0498826056719; Testing MSE 1918.1413352489471
Epoch 18: Training R2 0.964496705063673; Testing R2 0.3524000741524477


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 393.6796458438039; Testing MSE 1933.3669245243073
Epoch 19: Training R2 0.966324711408001; Testing R2 0.34725963413141614


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 504.66566886752844; Testing MSE 1860.6671154499054
Epoch 20: Training R2 0.9568309862574631; Testing R2 0.3718044317959759


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 21: Training MSE 562.8022549673915; Testing MSE 1893.652468919754
Epoch 21: Training R2 0.9518579927706643; Testing R2 0.360667967517468


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 22: Training MSE 469.0253982320428; Testing MSE 1908.2927465438843
Epoch 22: Training R2 0.9598796488231247; Testing R2 0.3557251395154095


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 23: Training MSE 444.2909261211753; Testing MSE 1859.4866514205933
Epoch 23: Training R2 0.9619954312754253; Testing R2 0.3722029782449783


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 24: Training MSE 416.66394248604774; Testing MSE 1858.1558763980865
Epoch 24: Training R2 0.9643586386615863; Testing R2 0.37265227246019483
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12242.481911182404, 9517.24494099617, 9246.151149272919, 9110.277658700943, 8838.356119394302, 8503.915193676949, 8011.842387914658, 7045.496594905853], 'mse_test_list': [2261.9253873825073, 2223.0931401252747, 2201.210141181946, 2364.5673274993896, 2357.3063910007477, 2310.520112514496, 2405.54701089859, 2589.069586992264], 'r_square_train_list': [-0.144108841510036, 0.11057544027371813, 0.13591023808998026, 0.1486081585784912, 0.17402031271315055, 0.2052751532636211, 0.25126132272040846, 0.34157019748603334], 'r_square_test_list': [0.11277226093041515, 0.12800399542051466, 0.13658748088155315, 0.07251152687985507, 0.07535958911442453, 0.09371124841865675, 0.05643747242479913, -0.015548244329650407]}, 'HHVEHCNT_P_CAP_mean_norm': {'mse_train_list': [12115.84300398826

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12461.463016271591; Testing MSE 2524.8239248991013
Epoch 0: Training R2 -0.08670067561335393; Testing R2 0.13509505582274917


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9327.404725551605; Testing MSE 2411.4327639341354
Epoch 1: Training R2 0.18660457413861575; Testing R2 0.17393838853099597


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9617.104521393776; Testing MSE 2432.9877257347107
Epoch 2: Training R2 0.16134133149562402; Testing R2 0.16655451005574173


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9268.591833114624; Testing MSE 2369.6746349334717
Epoch 3: Training R2 0.19173334672836628; Testing R2 0.18824307404830742


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 9904.560640454292; Testing MSE 2936.2110406160355
Epoch 4: Training R2 0.13627374846910723; Testing R2 -0.005829919913376447


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 9322.053015232086; Testing MSE 2448.0747550725937
Epoch 5: Training R2 0.18707126951879016; Testing R2 0.16138628975388303


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 8854.563960433006; Testing MSE 2353.8623690605164
Epoch 6: Training R2 0.22783860727268468; Testing R2 0.1936597317396792


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 8788.888430595398; Testing MSE 2469.754785299301
Epoch 7: Training R2 0.23356583549240517; Testing R2 0.1539595677761353


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 8480.80005645752; Testing MSE 2412.5748991966248
Epoch 8: Training R2 0.26043265232499424; Testing R2 0.17354713810529132


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 8226.281651854515; Testing MSE 2729.5516192913055
Epoch 9: Training R2 0.28262790515180647; Testing R2 0.06496343462587584


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 8076.953659951687; Testing MSE 2462.806051969528
Epoch 10: Training R2 0.29565003822533875; Testing R2 0.1563399293345873


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 7678.461866080761; Testing MSE 2600.2962172031403
Epoch 11: Training R2 0.33040047652142546; Testing R2 0.10924122969316563


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 6892.516948282719; Testing MSE 2639.993268251419
Epoch 12: Training R2 0.3989387269700917; Testing R2 0.09564258806817183


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 6348.155775666237; Testing MSE 2695.492497086525
Epoch 13: Training R2 0.4464096903142496; Testing R2 0.07663074453162666


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 5683.413264155388; Testing MSE 2716.4568811655045
Epoch 14: Training R2 0.5043784966594205; Testing R2 0.06944917465551537


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 5102.76248306036; Testing MSE 2839.9651288986206
Epoch 15: Training R2 0.5550140917968027; Testing R2 0.027140127653233925


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 4064.4277706742287; Testing MSE 2803.6818236112595
Epoch 16: Training R2 0.6455619698420569; Testing R2 0.03956935482612889


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 3741.2643998861313; Testing MSE 2676.570585370064
Epoch 17: Training R2 0.6737434002977232; Testing R2 0.08311264405557328


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 3239.8665711283684; Testing MSE 2698.242288827896
Epoch 18: Training R2 0.717467749401088; Testing R2 0.07568877449992839


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 2633.5545778274536; Testing MSE 2726.3344198465347
Epoch 19: Training R2 0.7703411280639505; Testing R2 0.06606552007379385


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 2082.2483964264393; Testing MSE 3108.057963848114
Epoch 20: Training R2 0.8184177302266353; Testing R2 -0.06469788772662222
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [12242.481911182404, 9517.24494099617, 9246.151149272919, 9110.277658700943, 8838.356119394302, 8503.915193676949, 8011.842387914658, 7045.496594905853], 'mse_test_list': [2261.9253873825073, 2223.0931401252747, 2201.210141181946, 2364.5673274993896, 2357.3063910007477, 2310.520112514496, 2405.54701089859, 2589.069586992264], 'r_square_train_list': [-0.144108841510036, 0.11057544027371813, 0.13591023808998026, 0.1486081585784912, 0.17402031271315055, 0.2052751532636211, 0.25126132272040846, 0.34157019748603334], 'r_square_test_list': [0.11277226093041515, 0.12800399542051466, 0.13658748088155315, 0.07251152687985507, 0.07535958911442453, 0.09371124841865675, 0.05643747242479913, -0.015548244329650407]}, 'HHVEHCNT_P_CAP_mean_norm': {'mse_train_list': [12115.84300398826

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 11864.913165569305; Testing MSE 2237.79656291008
Epoch 0: Training R2 -0.10882353390102706; Testing R2 0.12223665905004177


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9525.501376390457; Testing MSE 2234.9283814430237
Epoch 1: Training R2 0.10980384340288007; Testing R2 0.12336168738760434


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9367.49469935894; Testing MSE 2269.645929336548
Epoch 2: Training R2 0.12457019858485141; Testing R2 0.10974392099467489


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9366.35686159134; Testing MSE 2323.048597574234
Epoch 3: Training R2 0.1246765340697279; Testing R2 0.08879701935720141


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 9298.717707395554; Testing MSE 2334.691834449768
Epoch 4: Training R2 0.1309976832377714; Testing R2 0.08423002400618074


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 9170.300447940826; Testing MSE 2219.8149621486664
Epoch 5: Training R2 0.14299878914183162; Testing R2 0.12928984262426324


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 9123.222714662552; Testing MSE 2210.451817512512
Epoch 6: Training R2 0.1473983913852891; Testing R2 0.13296248438886615


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 9093.797519803047; Testing MSE 2196.477222442627
Epoch 7: Training R2 0.15014829339422442; Testing R2 0.13844394211396593


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 9001.711189746857; Testing MSE 2212.856638431549
Epoch 8: Training R2 0.15875412880928552; Testing R2 0.132019206666813


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 8855.008047819138; Testing MSE 2243.282848596573
Epoch 9: Training R2 0.17246412347985107; Testing R2 0.12008469379395093


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 8564.586502313614; Testing MSE 2254.8877000808716
Epoch 10: Training R2 0.19960517709859293; Testing R2 0.11553275490061399


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 8143.000993132591; Testing MSE 2421.2916135787964
Epoch 11: Training R2 0.23900402710350444; Testing R2 0.050261738991490224


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 7176.059594750404; Testing MSE 2359.1983437538147
Epoch 12: Training R2 0.3293685635706276; Testing R2 0.07461748109755817


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 6378.278011083603; Testing MSE 2412.7527594566345
Epoch 13: Training R2 0.4039244395840814; Testing R2 0.05361105735508909


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 4564.268854260445; Testing MSE 2503.7707209587097
Epoch 14: Training R2 0.5734508419883053; Testing R2 0.01790976470915584


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 2920.40776014328; Testing MSE 2639.0649676322937
Epoch 15: Training R2 0.7270762282162326; Testing R2 -0.03515865622767822
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [11864.913165569305, 9525.501376390457, 9367.49469935894, 9366.35686159134, 9298.717707395554, 9170.300447940826, 9123.222714662552, 9093.797519803047, 9001.711189746857, 8855.008047819138, 8564.586502313614, 8143.000993132591, 7176.059594750404, 6378.278011083603, 4564.268854260445, 2920.40776014328], 'mse_test_list': [2237.79656291008, 2234.9283814430237, 2269.645929336548, 2323.048597574234, 2334.691834449768, 2219.8149621486664, 2210.451817512512, 2196.477222442627, 2212.856638431549, 2243.282848596573, 2254.8877000808716, 2421.2916135787964, 2359.1983437538147, 2412.7527594566345, 2503.7707209587097, 2639.0649676322937], 'r_square_train_list': [-0.10882353390102706, 0.10980384340288007, 0.12457019858485141, 0.1246765340697279, 0.1309976832377714, 0.142998789141831

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12760.436004400253; Testing MSE 2202.6474237442017
Epoch 0: Training R2 -0.19157305629780752; Testing R2 0.09947761095263763


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9824.599921703339; Testing MSE 2198.0917096138
Epoch 1: Training R2 0.08257613207179204; Testing R2 0.10134015260515161


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9630.457592010498; Testing MSE 2165.9247159957886
Epoch 2: Training R2 0.10070519671104905; Testing R2 0.11449118968403316


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9520.020359754562; Testing MSE 2195.676502585411
Epoch 3: Training R2 0.11101785611570403; Testing R2 0.10232757709252216


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 9542.309772968292; Testing MSE 2206.881672143936
Epoch 4: Training R2 0.10893646452242223; Testing R2 0.09774649618427078


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 9508.344829082489; Testing MSE 2180.2167415618896
Epoch 5: Training R2 0.11210812041090712; Testing R2 0.10864808975420714


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 9419.629722833633; Testing MSE 2219.5713341236115
Epoch 6: Training R2 0.12039236165911216; Testing R2 0.09255850077522076


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 9334.135860204697; Testing MSE 2202.689528465271
Epoch 7: Training R2 0.12837580228389467; Testing R2 0.09946039701108766


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 9401.708686351776; Testing MSE 2234.733772277832
Epoch 8: Training R2 0.12206583302054108; Testing R2 0.08635954451774974


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 9344.430583715439; Testing MSE 2210.0933969020844
Epoch 9: Training R2 0.12741447814471907; Testing R2 0.09643342627530926


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 8897.81464934349; Testing MSE 2320.5089271068573
Epoch 10: Training R2 0.16911959807380061; Testing R2 0.05129154111653378


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 8802.618703246117; Testing MSE 2282.939863204956
Epoch 11: Training R2 0.1780090219460957; Testing R2 0.06665114103003411


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 8219.216760993004; Testing MSE 2432.779175043106
Epoch 12: Training R2 0.2324872572619502; Testing R2 0.005391379882998315


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 7660.678657889366; Testing MSE 2361.4377081394196
Epoch 13: Training R2 0.2846437003760005; Testing R2 0.034558366628904325


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 6591.738510131836; Testing MSE 2558.348274230957
Epoch 14: Training R2 0.3844616281038341; Testing R2 -0.045945835493465204
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [11864.913165569305, 9525.501376390457, 9367.49469935894, 9366.35686159134, 9298.717707395554, 9170.300447940826, 9123.222714662552, 9093.797519803047, 9001.711189746857, 8855.008047819138, 8564.586502313614, 8143.000993132591, 7176.059594750404, 6378.278011083603, 4564.268854260445, 2920.40776014328], 'mse_test_list': [2237.79656291008, 2234.9283814430237, 2269.645929336548, 2323.048597574234, 2334.691834449768, 2219.8149621486664, 2210.451817512512, 2196.477222442627, 2212.856638431549, 2243.282848596573, 2254.8877000808716, 2421.2916135787964, 2359.1983437538147, 2412.7527594566345, 2503.7707209587097, 2639.0649676322937], 'r_square_train_list': [-0.10882353390102706, 0.10980384340288007, 0.12457019858485141, 0.1246765340697279, 0.1309976832377714, 0.14299878914183

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12141.083237528801; Testing MSE 2290.6496345996857
Epoch 0: Training R2 -0.043092056802197076; Testing R2 0.24625306442824035


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9035.33582687378; Testing MSE 2336.2269818782806
Epoch 1: Training R2 0.2237359017175593; Testing R2 0.2312556657323308


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 9137.25562095642; Testing MSE 2214.9611234664917
Epoch 2: Training R2 0.21497953908017087; Testing R2 0.27115865560329755


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 8834.881055355072; Testing MSE 2133.7208807468414
Epoch 3: Training R2 0.24095782300979984; Testing R2 0.2978910650779327


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 8773.719084262848; Testing MSE 2087.7532184123993
Epoch 4: Training R2 0.24621250786588111; Testing R2 0.3130168984208561


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 8925.029781460762; Testing MSE 2100.7145643234253
Epoch 5: Training R2 0.23321276284572656; Testing R2 0.3087519184720441


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 8525.386682152748; Testing MSE 2103.8426965475082
Epoch 6: Training R2 0.26754779986741906; Testing R2 0.30772259471936203


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 8427.037093043327; Testing MSE 2093.6946660280228
Epoch 7: Training R2 0.2759974310232859; Testing R2 0.31106184258626146


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 8332.129111886024; Testing MSE 2106.5198063850403
Epoch 8: Training R2 0.28415137901421206; Testing R2 0.306841681590724


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 8295.78527212143; Testing MSE 2082.366317510605
Epoch 9: Training R2 0.2872738327385189; Testing R2 0.3147894785595319


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 8080.269539356232; Testing MSE 2120.770552754402
Epoch 10: Training R2 0.3057897052158942; Testing R2 0.3021524195389157


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 7626.181542873383; Testing MSE 2197.6871132850647
Epoch 11: Training R2 0.34480233472792876; Testing R2 0.27684273405974


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 6746.285361051559; Testing MSE 2259.9085986614227
Epoch 12: Training R2 0.4203979549961362; Testing R2 0.2563685187014628


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 6018.201771378517; Testing MSE 2335.202831029892
Epoch 13: Training R2 0.48295071031606107; Testing R2 0.23159266644683385


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 5161.077645421028; Testing MSE 2572.1055269241333
Epoch 14: Training R2 0.5565898864907366; Testing R2 0.15363893735540135


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 4839.238674938679; Testing MSE 2879.292231798172
Epoch 15: Training R2 0.5842404401614208; Testing R2 0.052557988986072335


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 3987.6981630921364; Testing MSE 2743.178653717041
Epoch 16: Training R2 0.6573999043191838; Testing R2 0.09734667723358315


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 3039.4254997372627; Testing MSE 2544.0568268299103
Epoch 17: Training R2 0.7388700386948923; Testing R2 0.16286846832489266


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 2558.8708579540253; Testing MSE 2663.2748186588287
Epoch 18: Training R2 0.7801565301797779; Testing R2 0.12363933670705352


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 2139.878575503826; Testing MSE 2649.719899892807
Epoch 19: Training R2 0.8161539377532872; Testing R2 0.12809963405130387


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 1684.2904105782509; Testing MSE 2526.799041032791
Epoch 20: Training R2 0.8552954531114904; Testing R2 0.1685472080862468


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 21: Training MSE 1443.637304753065; Testing MSE 2567.759871482849
Epoch 21: Training R2 0.8759709841345462; Testing R2 0.15506889173272842


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 22: Training MSE 1239.0142895281315; Testing MSE 2592.596399784088
Epoch 22: Training R2 0.8935510169573413; Testing R2 0.14689633805427305


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 23: Training MSE 1054.5334056019783; Testing MSE 2563.119977712631
Epoch 23: Training R2 0.9094005536824009; Testing R2 0.15659566634629407


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 24: Training MSE 856.8879321217537; Testing MSE 2518.2615280151367
Epoch 24: Training R2 0.9263811162415037; Testing R2 0.17135650907111155
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [11864.913165569305, 9525.501376390457, 9367.49469935894, 9366.35686159134, 9298.717707395554, 9170.300447940826, 9123.222714662552, 9093.797519803047, 9001.711189746857, 8855.008047819138, 8564.586502313614, 8143.000993132591, 7176.059594750404, 6378.278011083603, 4564.268854260445, 2920.40776014328], 'mse_test_list': [2237.79656291008, 2234.9283814430237, 2269.645929336548, 2323.048597574234, 2334.691834449768, 2219.8149621486664, 2210.451817512512, 2196.477222442627, 2212.856638431549, 2243.282848596573, 2254.8877000808716, 2421.2916135787964, 2359.1983437538147, 2412.7527594566345, 2503.7707209587097, 2639.0649676322937], 'r_square_train_list': [-0.10882353390102706, 0.10980384340288007, 0.12457019858485141, 0.1246765340697279, 0.1309976832377714, 0.142998789141831

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 11235.901886224747; Testing MSE 1908.4068953990936
Epoch 0: Training R2 0.03888290236847247; Testing R2 0.35568660075457337


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 8305.71317076683; Testing MSE 1822.3247021436691
Epoch 1: Training R2 0.2895307366261126; Testing R2 0.3847495383726569


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 7890.285620093346; Testing MSE 1773.9283174276352
Epoch 2: Training R2 0.32506633722342515; Testing R2 0.40108905130501626


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 7705.041527748108; Testing MSE 1780.3926020860672
Epoch 3: Training R2 0.34091208473805246; Testing R2 0.3989065894662932


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 7815.2194529771805; Testing MSE 1791.5060877799988
Epoch 4: Training R2 0.3314874841326435; Testing R2 0.3951544715284543


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 7785.186168551445; Testing MSE 1814.1490936279297
Epoch 5: Training R2 0.33405652607087843; Testing R2 0.38750977473858883


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 7435.7778787612915; Testing MSE 1792.322999238968
Epoch 6: Training R2 0.36394485054826; Testing R2 0.39487866713879294


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 7500.682005286217; Testing MSE 1786.4492148160934
Epoch 7: Training R2 0.3583929628278426; Testing R2 0.3968617651966868


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 7294.275167584419; Testing MSE 1867.4463629722595
Epoch 8: Training R2 0.3760489679079928; Testing R2 0.36951563267982057


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 7196.101933717728; Testing MSE 1774.6347397565842
Epoch 9: Training R2 0.38444668929738446; Testing R2 0.4008505500854046


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 7011.365684866905; Testing MSE 1788.2231831550598
Epoch 10: Training R2 0.4002489959676233; Testing R2 0.3962628407359815


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 6950.368398427963; Testing MSE 1873.9323616027832
Epoch 11: Training R2 0.40546669326502227; Testing R2 0.36732583980325484


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 6447.128933668137; Testing MSE 1898.3026087284088
Epoch 12: Training R2 0.4485137097556692; Testing R2 0.35909799447120405


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 5878.937289118767; Testing MSE 1950.4812687635422
Epoch 13: Training R2 0.4971167244346505; Testing R2 0.3414815155660188


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 4831.504815816879; Testing MSE 2067.593914270401
Epoch 14: Training R2 0.5867139164445903; Testing R2 0.30194212441046075


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 4147.6250767707825; Testing MSE 2039.35849070549
Epoch 15: Training R2 0.6452128706519804; Testing R2 0.311474924663961


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 3294.691114127636; Testing MSE 2183.331173658371
Epoch 16: Training R2 0.7181726938106339; Testing R2 0.26286708899982947


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 17: Training MSE 2621.060787141323; Testing MSE 2196.0485011339188
Epoch 17: Training R2 0.7757949150889651; Testing R2 0.25857348446777606


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 18: Training MSE 2332.7984675765038; Testing MSE 2140.0321900844574
Epoch 18: Training R2 0.8004528238836601; Testing R2 0.27748562520279463


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 19: Training MSE 1628.9651103317738; Testing MSE 2187.030929327011
Epoch 19: Training R2 0.8606586071292988; Testing R2 0.261617983184405


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 20: Training MSE 1308.4627658128738; Testing MSE 2204.081177711487
Epoch 20: Training R2 0.8880743220640975; Testing R2 0.2558615045628566


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 21: Training MSE 1135.6906302273273; Testing MSE 2094.241017103195
Epoch 21: Training R2 0.902853220561704; Testing R2 0.2929455705583277


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 22: Training MSE 972.993079200387; Testing MSE 1985.3190660476685
Epoch 22: Training R2 0.9167703408443653; Testing R2 0.32971963205759414


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 23: Training MSE 937.7330094575882; Testing MSE 2035.5598777532578
Epoch 23: Training R2 0.9197864810916411; Testing R2 0.31275740652334394


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 24: Training MSE 801.5471048653126; Testing MSE 2059.9389493465424
Epoch 24: Training R2 0.9314357997387293; Testing R2 0.3045265818881018
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [11864.913165569305, 9525.501376390457, 9367.49469935894, 9366.35686159134, 9298.717707395554, 9170.300447940826, 9123.222714662552, 9093.797519803047, 9001.711189746857, 8855.008047819138, 8564.586502313614, 8143.000993132591, 7176.059594750404, 6378.278011083603, 4564.268854260445, 2920.40776014328], 'mse_test_list': [2237.79656291008, 2234.9283814430237, 2269.645929336548, 2323.048597574234, 2334.691834449768, 2219.8149621486664, 2210.451817512512, 2196.477222442627, 2212.856638431549, 2243.282848596573, 2254.8877000808716, 2421.2916135787964, 2359.1983437538147, 2412.7527594566345, 2503.7707209587097, 2639.0649676322937], 'r_square_train_list': [-0.10882353390102706, 0.10980384340288007, 0.12457019858485141, 0.1246765340697279, 0.1309976832377714, 0.1429987891418316

HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 0: Training MSE 12026.16914510727; Testing MSE 2492.138224840164
Epoch 0: Training R2 -0.04874091573066974; Testing R2 0.14629188555256845


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 1: Training MSE 9226.563295722008; Testing MSE 2422.6040959358215
Epoch 1: Training R2 0.19539844126180594; Testing R2 0.17011153146345215


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 2: Training MSE 8959.398412704468; Testing MSE 2414.2766416072845
Epoch 2: Training R2 0.21869652901412084; Testing R2 0.1729641883755375


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 3: Training MSE 9060.228373110294; Testing MSE 2412.3735696077347
Epoch 3: Training R2 0.20990366208092004; Testing R2 0.17361610566977093


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 4: Training MSE 8829.961717128754; Testing MSE 2379.7866582870483
Epoch 4: Training R2 0.22998404351764445; Testing R2 0.18477909428009687


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 5: Training MSE 8564.13431018591; Testing MSE 2401.1562287807465
Epoch 5: Training R2 0.25316549679838174; Testing R2 0.17745872354347914


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 6: Training MSE 8650.91355741024; Testing MSE 2575.0668346881866
Epoch 6: Training R2 0.24559792094755573; Testing R2 0.11788381956271243


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 7: Training MSE 8560.185411572456; Testing MSE 2358.822476863861
Epoch 7: Training R2 0.25350986011957166; Testing R2 0.19196059473433869


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 8: Training MSE 8506.585229933262; Testing MSE 2514.5860254764557
Epoch 8: Training R2 0.2581840587689801; Testing R2 0.13860215576002077


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 9: Training MSE 7263.9195561409; Testing MSE 2616.118296980858
Epoch 9: Training R2 0.36655059851703287; Testing R2 0.10382121014567902


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 10: Training MSE 7243.510901927948; Testing MSE 2751.9912868738174
Epoch 10: Training R2 0.368330333231377; Testing R2 0.057276490896291166


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 11: Training MSE 6549.617567658424; Testing MSE 2824.4199872016907
Epoch 11: Training R2 0.4288412342523634; Testing R2 0.032465279153510496


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 12: Training MSE 5881.55687302351; Testing MSE 2704.9343645572662
Epoch 12: Training R2 0.4870994024355483; Testing R2 0.07339633369720111


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 13: Training MSE 5710.098005831242; Testing MSE 2794.5315808057785
Epoch 13: Training R2 0.5020514563456264; Testing R2 0.04270386656963554


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 14: Training MSE 5116.813084483147; Testing MSE 2810.265102982521
Epoch 14: Training R2 0.5537888104603352; Testing R2 0.03731418335814951


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 15: Training MSE 4836.936663091183; Testing MSE 2907.628020644188
Epoch 15: Training R2 0.578195406685655; Testing R2 0.003961493677634986


HBox(children=(IntProgress(value=0, max=96), HTML(value='')))


Epoch 16: Training MSE 4568.530189990997; Testing MSE 3003.4923255443573
Epoch 16: Training R2 0.6016017671808962; Testing R2 -0.028877830467150867
Printing performance_continuous
{'HHVEHCNT_mean_norm': {'mse_train_list': [11864.913165569305, 9525.501376390457, 9367.49469935894, 9366.35686159134, 9298.717707395554, 9170.300447940826, 9123.222714662552, 9093.797519803047, 9001.711189746857, 8855.008047819138, 8564.586502313614, 8143.000993132591, 7176.059594750404, 6378.278011083603, 4564.268854260445, 2920.40776014328], 'mse_test_list': [2237.79656291008, 2234.9283814430237, 2269.645929336548, 2323.048597574234, 2334.691834449768, 2219.8149621486664, 2210.451817512512, 2196.477222442627, 2212.856638431549, 2243.282848596573, 2254.8877000808716, 2421.2916135787964, 2359.1983437538147, 2412.7527594566345, 2503.7707209587097, 2639.0649676322937], 'r_square_train_list': [-0.10882353390102706, 0.10980384340288007, 0.12457019858485141, 0.1246765340697279, 0.1309976832377714, 0.1429987891418

## Store and save resnet's last layer

In [18]:
### 
def return_last_layer_resnet(model,device,x_train_images,x_test_images,y_train,y_test):
    ###  
    # 
    from sklearn.preprocessing import MinMaxScaler

    image_train_hidden_list = []
    image_test_hidden_list = []

    # return values in the last layer.
    model_no_last_layer = nn.Sequential(*list(model.children())[:-1]).to(device)

    # process data
    x_train_images_norm = x_train_images/255
    x_test_images_norm = x_test_images/255

    x_train_torch = torch.from_numpy(x_train_images_norm)
    x_test_torch = torch.from_numpy(x_test_images_norm)
    y_train_torch = torch.from_numpy(y_train)
    y_test_torch = torch.from_numpy(y_test)

    # create data loader: train and test. 
    train_ds = TensorDataset(x_train_torch, y_train_torch)
    batch_size = 100
    train_dl_no_shuffle = DataLoader(train_ds, batch_size, shuffle = False) # important: NO SHUFFLE!!!

    test_ds = TensorDataset(x_test_torch, y_test_torch)
    batch_size = 100
    test_dl_no_shuffle = DataLoader(test_ds, batch_size, shuffle = False)

    for inputs, labels in train_dl_no_shuffle:
        # to device
        inputs = inputs.to(device)
        labels = labels.to(device)
        image_train_hidden_vector_batch = model_no_last_layer(inputs)
        image_train_hidden_list.append(image_train_hidden_vector_batch.squeeze().cpu().detach().numpy())

    for inputs, labels in test_dl_no_shuffle:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # forward + backward
        image_test_hidden_vector_batch = model_no_last_layer(inputs)
        image_test_hidden_list.append(image_test_hidden_vector_batch.squeeze().cpu().detach().numpy())

    # vectorize
    image_train_hidden_vector=np.array(image_train_hidden_list).reshape(-1,512) # 512, resnet architecture   
    image_test_hidden_vector=np.array(image_test_hidden_list).reshape(-1,512) # 512, resnet architecture

    # scale
    scaler = MinMaxScaler()
    image_train_hidden_vector_norm = scaler.fit_transform(image_train_hidden_vector)
    image_test_hidden_vector_norm = scaler.fit_transform(image_test_hidden_vector)
    
    return image_train_hidden_vector_norm,image_test_hidden_vector_norm


In [20]:
# load model dictionary for all the output variables
model_dic = {}
model_name = 'resnet18'
if image_type == "rgb":
    input_channels = 3
elif image_type == "bw":
    input_channels = 4
elif image_type == "merge":
    input_channels = 7
use_pretrained = True # unclear True vs False is better
full_training = True
last_layer_dic_train = {}
last_layer_dic_test = {}

for output_var in output_list:
    print(output_var)
    # read models
    model = initialize_model(model_name, num_categories, input_channels, use_pretrained, full_training)
    PATH = './models/'+model_name+'_'+output_var+'_'+image_type+'.pth'
    model.load_state_dict(torch.load(PATH))
    model_dic[output_var]=model.state_dict()
    
    # initialize data. 
    y_train,y_test,BE_train,BE_test,x_train,x_test,x_train_images,x_test_images = \
        initialize_data(image_type, output_var, output_type, input_var, BE_var, num_categories, size)
    
    # obtain the last layer
    image_train_hidden_vector_norm,image_test_hidden_vector_norm = \
        return_last_layer_resnet(model,device,x_train_images,x_test_images,y_train,y_test)

    # 
    last_layer_dic_train[output_var]=image_train_hidden_vector_norm 
    last_layer_dic_test[output_var]=image_test_hidden_vector_norm


HHVEHCNT_mean_norm
HHVEHCNT_P_CAP_mean_norm
TRPTRANS_1_mean_norm
TRPTRANS_2_mean_norm
TRPTRANS_3_mean_norm


In [21]:
import pickle
with open('data_process/last_layer_dic_train.pickle', 'wb') as h:
    pickle.dump(last_layer_dic_train, h, protocol=pickle.HIGHEST_PROTOCOL)

with open('data_process/last_layer_dic_test.pickle', 'wb') as h:
    pickle.dump(last_layer_dic_test, h, protocol=pickle.HIGHEST_PROTOCOL)
    