MIDA (Gondara and Wang, 2018) in Python (using PyTorch)
https://arxiv.org/abs/1705.02737
https://gist.github.com/lgondara/18387c5f4d745673e9ca8e23f3d7ebd3 

### Note: Section 1 has been tested and moved to utils.py

# 1. Loading Dataset

## 1.1. Load a dataset and introduce missingness

Dataset used: Shuttle Dataset (https://archive.ics.uci.edu/ml/datasets/Statlog+(Shuttle))

### 1.1.1. Load the dataset and store it as a (numeric) DataFrame

In [1]:
import pandas as pd
import utils

In [2]:
#Test
filename = "data/shuttle/shuttle_trn"
train_df = utils.get_dataframe_from_csv(filename).iloc[:,:-1]  #remove label

INFO:root:Input filename has to be space separated data
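`utils.get_dataframe_from_csv` is not shown here. A minimal pandas equivalent, assuming the file is whitespace-separated with no header row and the class label in the last column (as the `.iloc[:,:-1]` above suggests), might look like the following; `load_space_separated` is a hypothetical name and the two input rows are toy data.

```python
import io
import pandas as pd

def load_space_separated(path_or_buffer):
    """Read whitespace-separated numeric data that has no header row.

    Hypothetical stand-in for utils.get_dataframe_from_csv. With header=None,
    pandas labels the columns 0..n-1.
    """
    return pd.read_csv(path_or_buffer, sep=r"\s+", header=None)

# Toy input in the Shuttle layout: 9 features followed by a class label.
sample = io.StringIO("50 21 77 0 28 0 27 48 22 1\n55 0 92 0 0 26 36 92 56 4\n")
features = load_space_separated(sample).iloc[:, :-1]  # drop the label column
print(features.shape)  # (2, 9)
```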


In [3]:
train_df.head()

    0   1   2  3   4    5   6   7   8
0  50  21  77  0  28    0  27  48  22
1  55   0  92  0   0   26  36  92  56
2  53   0  82  0  52   -5  29  30   2
3  37   0  76  0  28   18  40  48   8
4  37   0  79  0  34  -26  43  46   2


### 1.1.2. Inducing missingness

After loading the dataset, induce missingness.

To start off, introduce a simple random missingness pattern (Missing Completely At Random, MCAR): sample half of the variables, append a uniform random vector with one value per row, and set a row's sampled variables to missing when that row's value is below a threshold. With a threshold of 0.2, about 20% of the rows lose their values in the sampled variables. With 9 variables, half rounds down to 4, so the overall share of missing entries is roughly 4/9 × 20% ≈ 8.9%, in line with the 8.79% reported below.
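The procedure just described can be sketched as follows. This is a hypothetical `induce_mcar`, not the code in `utils.induce_missingness`; it assumes a single uniform draw per row decides whether all sampled columns of that row are blanked, which matches the NaN pattern in the outputs below (columns 4, 6, 7 and 8 go missing together).

```python
import numpy as np
import pandas as pd

def induce_mcar(df, threshold=0.2, seed=None):
    """Blank out roughly `threshold` of the rows in half of the columns (MCAR).

    Returns the new frame and the list of sampled columns; `df` is not modified.
    """
    rng = np.random.default_rng(seed)
    out = df.astype(float)  # astype copies, and NaN needs a float dtype
    cols = list(rng.choice(df.columns, size=df.shape[1] // 2, replace=False))
    row_hit = rng.uniform(size=len(df)) < threshold  # one uniform draw per row
    out.loc[row_hit, cols] = np.nan
    return out, cols

df = pd.DataFrame(np.arange(9000).reshape(1000, 9))
missing, cols = induce_mcar(df, threshold=0.2, seed=0)
print(sorted(cols))
# Expected to land near 4/9 * 20 ≈ 8.9 percent for 9 columns.
print(round(missing.isna().to_numpy().mean() * 100, 2))
```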

In [4]:
#test
df1 = train_df[:]
df2 = utils.induce_missingness(df1,logger_level=20)

INFO:root: Returning new dataframe with missingness(MCAR) induced
INFO:root: Percentage of NaNs in returned dataframe : 8.79


In [5]:
df1.head()

    0   1   2  3   4    5   6   7   8
0  50  21  77  0  28    0  27  48  22
1  55   0  92  0   0   26  36  92  56
2  53   0  82  0  52   -5  29  30   2
3  37   0  76  0  28   18  40  48   8
4  37   0  79  0  34  -26  43  46   2


In [6]:
df2.head()

    0   1   2  3     4    5     6     7     8
0  50  21  77  0   NaN    0   NaN   NaN   NaN
1  55   0  92  0   0.0   26  36.0  92.0  56.0
2  53   0  82  0  52.0   -5  29.0  30.0   2.0
3  37   0  76  0  28.0   18  40.0  48.0   8.0
4  37   0  79  0  34.0  -26  43.0  46.0   2.0


### 1.1.3. Create Train-Test split

Split the data 70/30 into a training set and a test set, both containing missingness, and also return a copy of the test set without missingness so that imputation performance can be measured.
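The split described above can be sketched as a single shuffled index applied to both the incomplete and the complete frame, so the ground-truth test rows line up with the incomplete ones. `split_with_ground_truth` is a hypothetical helper, not the implementation of `utils.create_train_test_split`.

```python
import numpy as np
import pandas as pd

def split_with_ground_truth(full_df, missing_df, train_frac=0.7, seed=None):
    """Return (train, test, full_test) using one shared random permutation.

    `full_test` holds the original values for exactly the rows in `test`.
    """
    rng = np.random.default_rng(seed)
    idx = rng.permutation(len(full_df))
    cut = int(round(train_frac * len(full_df)))
    train_idx, test_idx = idx[:cut], idx[cut:]
    train = missing_df.iloc[train_idx]
    test = missing_df.iloc[test_idx]
    full_test = full_df.iloc[test_idx]  # same rows as `test`, without NaNs
    return train, test, full_test

full = pd.DataFrame(np.arange(40).reshape(10, 4))
missing = full.astype(float)
missing.iloc[::3, 1] = np.nan  # toy missingness in one column
train, test, full_test = split_with_ground_truth(full, missing, seed=0)
```

Sharing the permutation is the key design choice: it guarantees that every NaN in `test` has a known true value at the same index in `full_test`.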

In [7]:
#Test
train_split, test_split, full_test_split = utils.create_train_test_split(df1)
print(train_split.head())
print(test_split.head())
print(full_test_split.head())

INFO:root: Returning new dataframe with missingness(MCAR) induced
INFO:root: Percentage of NaNs in returned dataframe : 8.79
INFO:root: Returning train_df, test_df, full_test_df after splitting dataframe in 0.7/0.3 split 
INFO:root: Note: full_test_df is the same as test_df but without NaNs


        0  1    2  3     4  5     6     7    8
7476   55  0   98  0   NaN -4   NaN   NaN  NaN
31355  50 -5  102  2  50.0  0  52.0  53.0  0.0
38462  37  0   77  0  36.0 -2  40.0  41.0  2.0
20525  55 -2   95  0  46.0 -3  40.0  49.0  8.0
34457  55  0   92  8   NaN  0   NaN   NaN  NaN
        0  1   2  3     4   5     6     7     8
15528  45 -1  76  0   NaN -16   NaN   NaN   NaN
14327  37  0  95  0  10.0   7  58.0  84.0  26.0
12125  37  0  75 -4  30.0   0  38.0  44.0   6.0
39952  55  0  96  0  50.0   4  41.0  47.0   6.0
1339   41 -1  76  0   NaN -14   NaN   NaN   NaN
        0  1   2  3   4   5   6   7   8
15528  45 -1  76  0  44 -16  31  32   2
14327  37  0  95  0  10   7  58  84  26
12125  37  0  75 -4  30   0  38  44   6
39952  55  0  96  0  50   4  41  47   6
1339   41 -1  76  0  38 -14  35  37   2
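With the test set and its complete copy aligned on the same index, performance can be scored only on the cells that were actually missing. A minimal sketch using plain RMSE; `masked_rmse` is hypothetical, and the metric reported in the MIDA paper may be defined differently. It assumes all three frames share row and column order.

```python
import numpy as np
import pandas as pd

def masked_rmse(imputed, truth, incomplete):
    """RMSE over the cells that are NaN in `incomplete` (the imputed cells only)."""
    mask = incomplete.isna().to_numpy()
    diff = imputed.to_numpy(dtype=float)[mask] - truth.to_numpy(dtype=float)[mask]
    return float(np.sqrt(np.mean(diff ** 2)))

truth = pd.DataFrame([[1.0, 2.0], [3.0, 4.0]])
incomplete = truth.copy()
incomplete.iloc[0, 1] = np.nan
incomplete.iloc[1, 0] = np.nan
imputed = incomplete.fillna(0.0)  # a deliberately poor imputer for illustration
score = masked_rmse(imputed, truth, incomplete)  # sqrt((2**2 + 3**2) / 2)
```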


# 2. Modelling

Proceed to modelling.

In R:
Start by initializing the `h2o` package, then read the training and test datasets into h2o's supported format.
Then run the imputation model multiple times, since each new start initializes the weights with different values.
The [h2o](https://cran.r-project.org/web/packages/h2o/h2o.pdf) package offers an easy-to-use function for implementing autoencoders.
More information is available in the h2o [Deep Learning booklet](http://docs.h2o.ai/h2o/latest-stable/h2o-docs/booklets/DeepLearningBooklet.pdf).
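Re-running the model from different random starts is what turns a single imputation into multiple imputations. A minimal NumPy sketch of that outer loop: `impute_once(data, seed)` is a hypothetical callable standing in for "train a fresh autoencoder with this seed and return the filled matrix", the toy imputer below is for illustration only, and pooling by a per-cell mean is an illustrative choice, not necessarily the paper's procedure.

```python
import numpy as np

def multiple_impute(impute_once, data, n_imputations=5, base_seed=0):
    """Run a stochastic imputer once per seed and stack the results.

    Returns an array of shape (n_imputations, *data.shape).
    """
    mask = np.isnan(data)
    runs = []
    for m in range(n_imputations):
        filled = impute_once(data, base_seed + m)
        runs.append(np.where(mask, filled, data))  # never overwrite observed values
    return np.stack(runs)

def noisy_mean_impute(data, seed):
    # Toy imputer: column means plus seed-dependent noise.
    rng = np.random.default_rng(seed)
    means = np.nanmean(data, axis=0)
    noise = rng.normal(scale=0.1, size=data.shape)
    return np.where(np.isnan(data), means + noise, data)

data = np.array([[1.0, np.nan], [3.0, 4.0], [np.nan, 8.0]])
stack = multiple_impute(noisy_mean_impute, data, n_imputations=3)
pooled = stack.mean(axis=0)  # point estimate per cell
spread = stack.std(axis=0)   # between-imputation variability (0 for observed cells)
```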

In Python:

In [8]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

In [9]:
#Settings for device, randomization seed, default tensor type, kwargs for memory #DevSeedTensKwargs
RANDOM_SEED = 18
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)  # seed the CPU generator in both branches (also seeds CUDA)

if torch.cuda.is_available():
    device = 'cuda'
    torch.cuda.manual_seed(RANDOM_SEED)
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    kwargs = {'num_workers': 4, 'pin_memory': True}
else:
    device = 'cpu'
    torch.set_default_tensor_type(torch.FloatTensor)
    kwargs = {}  # DataLoader extras; must be named kwargs for **kwargs below

In [10]:
import dataset_module

In [26]:
df2.head()

    0   1   2  3     4    5     6     7     8
0  50  21  77  0   NaN    0   NaN   NaN   NaN
1  55   0  92  0   0.0   26  36.0  92.0  56.0
2  53   0  82  0  52.0   -5  29.0  30.0   2.0
3  37   0  76  0  28.0   18  40.0  48.0   8.0
4  37   0  79  0  34.0  -26  43.0  46.0   2.0


In [14]:
trainset = dataset_module.DataSetForImputation(df2)

In [15]:
trainset

Dataframe Size:43500, Perc of NaNs: 8.79

In [19]:
trainset[0]

(tensor([50.0000, 21.0000, 77.0000,  0.0000, 34.5129,  0.0000, 37.1039, 50.9072,
         13.9429]),
 tensor([50.0000, 21.0000, 77.0000,  0.0000, 34.5129,  0.0000, 37.1039, 50.9072,
         13.9429]))
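The first item above returns the same tensor twice, and the NaNs in columns 4, 6, 7 and 8 have been replaced by non-integer values that look like column means. `dataset_module` is not shown, so the following is a sketch of a `Dataset` that behaves that way under the assumption of mean filling; `MeanFilledDataset` is a hypothetical name and the real fill rule may differ. Returning identical input and target is reasonable here because the corruption comes from the model's dropout layer, not from the data.

```python
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

class MeanFilledDataset(Dataset):
    """(input, target) pairs with NaNs replaced by observed column means."""

    def __init__(self, df: pd.DataFrame):
        values = df.to_numpy(dtype=np.float32)
        col_means = np.nanmean(values, axis=0)  # ignores NaNs per column
        filled = np.where(np.isnan(values), col_means, values)
        self.data = torch.from_numpy(filled.astype(np.float32))

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, i):
        row = self.data[i]
        return row, row.clone()  # target is a copy of the (mean-filled) input

df = pd.DataFrame({"a": [1.0, np.nan, 3.0], "b": [4.0, 5.0, np.nan]})
ds = MeanFilledDataset(df)
x, y = ds[1]
xb, yb = next(iter(DataLoader(ds, batch_size=2)))
```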

In [144]:
import Modelling
net = Modelling.DenoisingAutoEncoder(len(trainset.variables()))

In [145]:
net

DenoisingAutoEncoder(
  (drop_layer): Dropout(p=0.5)
  (linear_layer_list): ModuleList(
    (0): Linear(in_features=9, out_features=16, bias=True)
    (1): Linear(in_features=16, out_features=23, bias=True)
    (2): Linear(in_features=23, out_features=30, bias=True)
    (3): Linear(in_features=30, out_features=23, bias=True)
    (4): Linear(in_features=23, out_features=16, bias=True)
    (5): Linear(in_features=16, out_features=9, bias=True)
  )
)
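The printout fixes the layer sizes: starting from 9 inputs, each encoder layer adds 7 units (9, 16, 23, 30) and the decoder mirrors them back to 9. `Modelling.py` is not shown, so the following sketch only reproduces that structure; the tanh activations, the linear output layer, input-only dropout and the `theta`/`depth` parameter names are assumptions, not read from the source.

```python
import torch
import torch.nn as nn

class DenoisingAutoEncoderSketch(nn.Module):
    """Overcomplete denoising autoencoder with widths n, n+theta, ..., n+depth*theta, ..., n."""

    def __init__(self, n_features: int, theta: int = 7, depth: int = 3, dropout: float = 0.5):
        super().__init__()
        self.drop_layer = nn.Dropout(p=dropout)  # corrupts the input during training
        widths = [n_features + i * theta for i in range(depth + 1)]  # 9, 16, 23, 30
        sizes = widths + widths[-2::-1]                               # then 23, 16, 9
        self.linear_layer_list = nn.ModuleList(
            [nn.Linear(a, b) for a, b in zip(sizes[:-1], sizes[1:])]
        )

    def forward(self, x):
        x = self.drop_layer(x)
        last = len(self.linear_layer_list) - 1
        for i, layer in enumerate(self.linear_layer_list):
            x = layer(x)
            if i < last:
                x = torch.tanh(x)  # assumed activation; output layer stays linear
        return x

net_sketch = DenoisingAutoEncoderSketch(9)
shapes = [(l.in_features, l.out_features) for l in net_sketch.linear_layer_list]
```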

# 3. Training

In [177]:
import torch.utils.data as td
from torch.optim import Adam

LR = 1e-3
BATCH_SIZE = 16
VARIABLES  = len(trainset.variables()) #9
DATAPOINTS = len(trainset)  #43500

import Modelling
net = Modelling.DenoisingAutoEncoder(len(trainset.variables()))

criterion = nn.MSELoss()
net = net.to(device) 

train_loader = td.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, **kwargs) 
optimizer = Adam(net.parameters(), lr = LR)

LOG_INTERVAL = 2
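The TODO list in the training cell that follows asks how to keep (0, 1) normalization usable at test time. One common approach, sketched here as a hypothetical helper that is not part of `utils` or `dataset_module`, is to fit per-column minima and ranges on the training data only (ignoring NaNs), reuse them on test data, and invert the scaling on the model's output so imputations come back in the original units.

```python
import numpy as np

class MinMaxScalerNaN:
    """Per-column min-max scaling to [0, 1] fitted on observed (non-NaN) values."""

    def fit(self, x):
        self.min_ = np.nanmin(x, axis=0)
        span = np.nanmax(x, axis=0) - self.min_
        self.range_ = np.where(span == 0, 1.0, span)  # avoid dividing by zero on constant columns
        return self

    def transform(self, x):
        return (x - self.min_) / self.range_  # NaNs stay NaN

    def inverse_transform(self, z):
        return z * self.range_ + self.min_

x = np.array([[0.0, 10.0], [5.0, np.nan], [10.0, 30.0]])
scaler = MinMaxScalerNaN().fit(x)
z = scaler.transform(x)
```

Test rows outside the training range will map slightly outside [0, 1]; that is expected when the statistics come from the training split only.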

In [178]:
'''
TO DO:
:- Normalization between 0 and 1 - Error blowing up for some reason
:- (0,1) Normalization for better convergence - how to handle this elegantly while testing, because this has been trained for something between 0 and 1
:- Nesterov Momentum + Adam- Pytorch? Decay factor?
'''
from tqdm.notebook import tqdm  # replaces the deprecated tqdm_notebook
import logging

def train(start_steps = 0, end_steps = 5, net=None, logger_level = 10):
    logger = logging.getLogger()
    logger.setLevel(logger_level)
    
    agg_loss = 0.0
    for epoch in tqdm(range(start_steps, end_steps)):
        net.train()
        #Epoch begins
        for x, d in tqdm(train_loader):
#             x = x/x.sum(0).expand_as(x)  # TO DO : NaN, Explore
#             d = d/d.sum(0).expand_as(d)  # Normalize between 0,1 for better convergence #TO DO
            optimizer.zero_grad()
            x = x.to(device)
            d = d.to(device)  # targets need no gradient, so no torch.no_grad() block is required
            y = net(x)
            # nn.MSELoss already averages over batch and variables, so the extra /VARIABLES
            # scales the RMSE by 1/sqrt(VARIABLES). TO DO: take this out maybe?
            loss = torch.sqrt(criterion(y, d)/VARIABLES)
            loss.backward()
            optimizer.step()
            agg_loss += loss.item()
            logger.debug(loss.item())
            
        if epoch%LOG_INTERVAL == 0:
            # agg_loss sums per-batch losses, so dividing by LOG_INTERVAL*DATAPOINTS gives roughly
            # the mean batch loss / BATCH_SIZE. The first log of each call covers a single epoch,
            # which is why it reads about half of the later values. TO DO: correct interval and value
            print(f"Epoch number:{epoch}  Aggregate loss: {agg_loss/(LOG_INTERVAL*DATAPOINTS):.2f}")
            agg_loss= 0.0
        #Epoch Ends
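The comments in the training loop flag the `/VARIABLES` factor and the aggregate as open questions. A quick check of what the logged numbers measure, using random placeholder tensors and the sizes from the cells above (43500 rows, batch size 16, 9 variables, `LOG_INTERVAL = 2`):

```python
import torch
import torch.nn as nn

# nn.MSELoss() defaults to reduction="mean": it averages over every element,
# i.e. over both the batch and the feature dimension.
criterion = nn.MSELoss()
torch.manual_seed(0)
y = torch.randn(16, 9)  # placeholder predictions
d = torch.randn(16, 9)  # placeholder targets

mse = criterion(y, d)
rmse = torch.sqrt(mse)
logged = torch.sqrt(mse / 9)  # what the loop computes with VARIABLES = 9
# The extra division only rescales the RMSE by 1 / sqrt(9) = 1 / 3.
print(rmse.item(), logged.item())

# The printed aggregate divides a sum of per-batch means by LOG_INTERVAL * DATAPOINTS.
batches_per_epoch = -(-43500 // 16)  # ceil(43500 / 16) = 2719, the max of each inner progress bar
scale = batches_per_epoch * 2 / (2 * 43500)  # close to 1 / 16, i.e. 1 / BATCH_SIZE
print(batches_per_epoch, scale)
```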

In [180]:
train(10,40, net, logger_level=20)

Epoch number:10  Aggregate loss: 0.45634250864489323
Epoch number:12  Aggregate loss: 0.8627423513620749
Epoch number:14  Aggregate loss: 0.8285637217253105
Epoch number:16  Aggregate loss: 0.809038672548601
Epoch number:18  Aggregate loss: 0.8000655019612148
Epoch number:20  Aggregate loss: 0.7927715696575998
Epoch number:22  Aggregate loss: 0.7861740176184424
Epoch number:24  Aggregate loss: 0.777658688328732
Epoch number:26  Aggregate loss: 0.7746905001744456
Epoch number:28  Aggregate loss: 0.7744138725242395
Epoch number:30  Aggregate loss: 0.7692476523963885
Epoch number:32  Aggregate loss: 0.7665664829522714
Epoch number:34  Aggregate loss: 0.7652386885648487
Epoch number:36  Aggregate loss: 0.7665795712553222
Epoch number:38  Aggregate loss: 0.7658164659944073

In [184]:
train(40,500, net, logger_level=20)

Epoch number:40  Aggregate loss: 0.38083006401719716
Epoch number:42  Aggregate loss: 0.758743081986219
Epoch number:44  Aggregate loss: 0.757156429337359
Epoch number:46  Aggregate loss: 0.7573874940954406
Epoch number:48  Aggregate loss: 0.7547284749485981
Epoch number:50  Aggregate loss: 0.7591297333788598
Epoch number:52  Aggregate loss: 0.7531450336856403
Epoch number:54  Aggregate loss: 0.7503541740466808
Epoch number:56  Aggregate loss: 0.7503922219331237
Epoch number:58  Aggregate loss: 0.7500728380488253
Epoch number:60  Aggregate loss: 0.7516505943166799
Epoch number:62  Aggregate loss: 0.7514485209837727
Epoch number:64  Aggregate loss: 0.7456729650853694
Epoch number:66  Aggregate loss: 0.7445432918126555
Epoch number:68  Aggregate loss: 0.7453098986970967
Epoch number:70  Aggregate loss: 0.7475433055642008
Epoch number:72  Aggregate loss: 0.7462534130386923
Epoch number:74  Aggregate loss: 0.7457102048643705
Epoch number:76  Aggregate loss: 0.7477372841396551
Epoch number:78  Aggregate loss: 0.7433803105381713
Epoch number:80  Aggregate loss: 0.7454377628194875
Epoch number:82  Aggregate loss: 0.7439188142316094
Epoch number:84  Aggregate loss: 0.7419546529671242
Epoch number:86  Aggregate loss: 0.7409693589073488
Epoch number:88  Aggregate loss: 0.7436971601617748
Epoch number:90  Aggregate loss: 0.7344026312088144
Epoch number:92  Aggregate loss: 0.7411538497195846
Epoch number:94  Aggregate loss: 0.7425664262387944
Epoch number:96  Aggregate loss: 0.7411689058994425
Epoch number:98  Aggregate loss: 0.74107006097388
Epoch number:100  Aggregate loss: 0.7403810421954626
Epoch number:102  Aggregate loss: 0.7416875553514766
Epoch number:104  Aggregate loss: 0.7429920425689084
Epoch number:106  Aggregate loss: 0.7418618346384201
Epoch number:108  Aggregate loss: 0.745785270468942
Epoch number:110  Aggregate loss: 0.7417797180482711
Epoch number:112  Aggregate loss: 0.7412185553764474
Epoch number:114  Aggregate loss: 0.7384872299693097
Epoch number:116  Aggregate loss: 0.7409592827133749
Epoch number:118  Aggregate loss: 0.7420738339643369
Epoch number:120  Aggregate loss: 0.7396936790559484
Epoch number:122  Aggregate loss: 0.7403234005659476
Epoch number:124  Aggregate loss: 0.7409793649624133
Epoch number:126  Aggregate loss: 0.7403994098504384
Epoch number:128  Aggregate loss: 0.7375584950118229
Epoch number:130  Aggregate loss: 0.7377626005616681
Epoch number:132  Aggregate loss: 0.7414905406962866
Epoch number:134  Aggregate loss: 0.7307037894068094
Epoch number:136  Aggregate loss: 0.7396907482503474
Epoch number:138  Aggregate loss: 0.7391513955401278
Epoch number:140  Aggregate loss: 0.7379615476487696
Epoch number:142  Aggregate loss: 0.7399209403498419
Epoch number:144  Aggregate loss: 0.7405085727784826
Epoch number:146  Aggregate loss: 0.7359328391003883
Epoch number:148  Aggregate loss: 0.7377449745200146
Epoch number:150  Aggregate loss: 0.7387385364318716
Epoch number:152  Aggregate loss: 0.7359934126042772
Epoch number:154  Aggregate loss: 0.7382637344338429
Epoch number:156  Aggregate loss: 0.736262359213555
Epoch number:158  Aggregate loss: 0.7369284769337753
Epoch number:160  Aggregate loss: 0.7376929107315239
Epoch number:162  Aggregate loss: 0.7363470935657107
Epoch number:164  Aggregate loss: 0.7372443150953315
Epoch number:166  Aggregate loss: 0.7378674887131
Epoch number:168  Aggregate loss: 0.739117427543662
Epoch number:170  Aggregate loss: 0.738394143090851
Epoch number:172  Aggregate loss: 0.7354973201422855
Epoch number:174  Aggregate loss: 0.7370483711396141
Epoch number:176  Aggregate loss: 0.7355311479458864
Epoch number:178  Aggregate loss: 0.733541718080126
Epoch number:180  Aggregate loss: 0.738204271456291
Epoch number:182  Aggregate loss: 0.7350576226930509
Epoch number:184  Aggregate loss: 0.7352530891730867
Epoch number:186  Aggregate loss: 0.7391133673355497
Epoch number:188  Aggregate loss: 0.7352830813588768
Epoch number:190  Aggregate loss: 0.7375817112813051
Epoch number:192  Aggregate loss: 0.7385801084233427
Epoch number:194  Aggregate loss: 0.7376962563004987
Epoch number:196  Aggregate loss: 0.7359896223298434
Epoch number:198  Aggregate loss: 0.7311335069612525
Epoch number:200  Aggregate loss: 0.7354390218915611
Epoch number:202  Aggregate loss: 0.736026083779061
Epoch number:204  Aggregate loss: 0.7372667746872738
Epoch number:206  Aggregate loss: 0.7341269184885354
Epoch number:208  Aggregate loss: 0.7364788617073804
Epoch number:210  Aggregate loss: 0.7384420485332095
Epoch number:212  Aggregate loss: 0.7315998779795636
Epoch number:214  Aggregate loss: 0.7368891502851728
Epoch number:216  Aggregate loss: 0.732726894748622
Epoch number:218  Aggregate loss: 0.7378180320564358
Epoch number:220  Aggregate loss: 0.7390023060091611
Epoch number:222  Aggregate loss: 0.7315120415002451
Epoch number:224  Aggregate loss: 0.7372643048324804
Epoch number:226  Aggregate loss: 0.7372563456529858
Epoch number:228  Aggregate loss: 0.7374827582123635
Epoch number:230  Aggregate loss: 0.7377024431283447
Epoch number:232  Aggregate loss: 0.7375724296021735
Epoch number:234  Aggregate loss: 0.7325405097610649
Epoch number:236  Aggregate loss: 0.7362491182102554
Epoch number:238  Aggregate loss: 0.7361970845282763
Epoch number:240  Aggregate loss: 0.7341176703990191
Epoch number:242  Aggregate loss: 0.7313090145341281
Epoch number:244  Aggregate loss: 0.7387975983784116
Epoch number:246  Aggregate loss: 0.7375073235281583
Epoch number:248  Aggregate loss: 0.7344400660333962
Epoch number:250  Aggregate loss: 0.7375130738992801
Epoch number:252  Aggregate loss: 0.7337701481786267
Epoch number:254  Aggregate loss: 0.7403246960585145
Epoch number:256  Aggregate loss: 0.7392061845351909
Epoch number:258  Aggregate loss: 0.7377114619551034
Epoch number:260  Aggregate loss: 0.7389990109575206
Epoch number:262  Aggregate loss: 0.7372302527482483
Epoch number:264  Aggregate loss: 0.7381226926398003
Epoch number:266  Aggregate loss: 0.7361150052985925
Epoch number:268  Aggregate loss: 0.7355434136856561
Epoch number:270  Aggregate loss: 0.7359487486515922
Epoch number:272  Aggregate loss: 0.7375593468567421
Epoch number:274  Aggregate loss: 0.7385166359797292
Epoch number:276  Aggregate loss: 0.7391288863187548
Epoch number:278  Aggregate loss: 0.7398417717396528
Epoch number:280  Aggregate loss: 0.7344966523345859
Epoch number:282  Aggregate loss: 0.7350144312765406
Epoch number:284  Aggregate loss: 0.7357144990570244
Epoch number:286  Aggregate loss: 0.7357289217011682
Epoch number:288  Aggregate loss: 0.7355237558408715
Epoch number:290  Aggregate loss: 0.7374488908888279
Epoch number:292  Aggregate loss: 0.7367298359322823
Epoch number:294  Aggregate loss: 0.7365718003634749


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:296  Aggregate loss: 0.7352153197869488


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:298  Aggregate loss: 0.7378762079381395


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:300  Aggregate loss: 0.7367376997059789


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:302  Aggregate loss: 0.7388267464254095


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:304  Aggregate loss: 0.7386105745496421


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:306  Aggregate loss: 0.7333727479282467


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:308  Aggregate loss: 0.7380996432414


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:310  Aggregate loss: 0.739412525683984


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:312  Aggregate loss: 0.7364062837375992


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:314  Aggregate loss: 0.7362748147208115


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:316  Aggregate loss: 0.7335107780242789


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:318  Aggregate loss: 0.7348197130444406


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:320  Aggregate loss: 0.7362895032866248


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:322  Aggregate loss: 0.7405714064795396


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:324  Aggregate loss: 0.7340025428190998


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:326  Aggregate loss: 0.7364108453482047


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:328  Aggregate loss: 0.7313497784713219


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:330  Aggregate loss: 0.7337889066153559


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:332  Aggregate loss: 0.7349614538351694


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:334  Aggregate loss: 0.7358844869438259


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:336  Aggregate loss: 0.7370641684778806


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:338  Aggregate loss: 0.7322630052374697


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:340  Aggregate loss: 0.7338755872825097


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:342  Aggregate loss: 0.7343663783320066


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:344  Aggregate loss: 0.7261802714117642


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:346  Aggregate loss: 0.7352590144091639


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:348  Aggregate loss: 0.7367820741171124


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:350  Aggregate loss: 0.7370161250267906


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:352  Aggregate loss: 0.7325335285197729


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:354  Aggregate loss: 0.7392316076536288


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:356  Aggregate loss: 0.7368157623554098


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:358  Aggregate loss: 0.7403818527391587


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:360  Aggregate loss: 0.7342965573250563


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:362  Aggregate loss: 0.7363517347423509


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:364  Aggregate loss: 0.7384566676918117


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:366  Aggregate loss: 0.7358797632524336


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:368  Aggregate loss: 0.736550359175123


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:370  Aggregate loss: 0.7325103220199717


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:372  Aggregate loss: 0.735434590539713


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:374  Aggregate loss: 0.7363410739131357


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:376  Aggregate loss: 0.7372066340638304


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:378  Aggregate loss: 0.7350093166334876


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:380  Aggregate loss: 0.7355266736649919


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:382  Aggregate loss: 0.7320642780243666


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:384  Aggregate loss: 0.7356782830830279


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:386  Aggregate loss: 0.7325511507933167


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:388  Aggregate loss: 0.7342028316382704


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:390  Aggregate loss: 0.7364360365757997


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:392  Aggregate loss: 0.7372587729486926


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:394  Aggregate loss: 0.733230498234431


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:396  Aggregate loss: 0.7343908495820802


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:398  Aggregate loss: 0.7385892284026091


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:400  Aggregate loss: 0.7318278524080912


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:402  Aggregate loss: 0.7353540977280716


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:404  Aggregate loss: 0.7320769951014683


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:406  Aggregate loss: 0.7352389491700578


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:408  Aggregate loss: 0.7330452768693025


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:410  Aggregate loss: 0.734682599550006


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:412  Aggregate loss: 0.7329475119442775


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:414  Aggregate loss: 0.732590734602391


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:416  Aggregate loss: 0.7333083244521043


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:418  Aggregate loss: 0.7358380764270651


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:420  Aggregate loss: 0.7362142172709278


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:422  Aggregate loss: 0.7386305833301325


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:424  Aggregate loss: 0.735912366839661


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:426  Aggregate loss: 0.7336604996796312


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:428  Aggregate loss: 0.7368059301924431


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:430  Aggregate loss: 0.7325455092123185


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:432  Aggregate loss: 0.734150750127332


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:434  Aggregate loss: 0.7319386845484547


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:436  Aggregate loss: 0.7335033569637386


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:438  Aggregate loss: 0.732307974916765


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:440  Aggregate loss: 0.7343467808619313


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:442  Aggregate loss: 0.730950074963186


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:444  Aggregate loss: 0.7293991739558078


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:446  Aggregate loss: 0.7337972158919805


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:448  Aggregate loss: 0.7290031759821135


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:450  Aggregate loss: 0.7332717742919922


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:452  Aggregate loss: 0.7325247856556684


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:454  Aggregate loss: 0.735698876323371


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:456  Aggregate loss: 0.7330758751370441


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:458  Aggregate loss: 0.7323189034489379


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:460  Aggregate loss: 0.7340768530999107


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:462  Aggregate loss: 0.736850301824767


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:464  Aggregate loss: 0.737188288795537


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:466  Aggregate loss: 0.7363734807283029


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:468  Aggregate loss: 0.7358581774070345


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:470  Aggregate loss: 0.7352355940890039


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:472  Aggregate loss: 0.7357558040783323


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:474  Aggregate loss: 0.7338380990686088


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:476  Aggregate loss: 0.7330586313022964


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:478  Aggregate loss: 0.73497646022117


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:480  Aggregate loss: 0.7292798776791014


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:482  Aggregate loss: 0.7338121838323001


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:484  Aggregate loss: 0.7328250603812865


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:486  Aggregate loss: 0.7376733769082475


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:488  Aggregate loss: 0.7318512862441183


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:490  Aggregate loss: 0.7318681191192277


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:492  Aggregate loss: 0.7352475888372838


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:494  Aggregate loss: 0.7351116572769208


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:496  Aggregate loss: 0.7293754829319045


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

Epoch number:498  Aggregate loss: 0.7346177885121312


HBox(children=(IntProgress(value=0, max=2719), HTML(value='')))

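Ran for 500 epochs (the aggregate loss is printed every second epoch).
Observation: the loss drops quickly and then plateaus at roughly 0.73, possibly at a local minimum. A smaller learning rate later in training, or Nesterov momentum, might help.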
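One way to try both ideas from the observation above is sketched below. It assumes the model is trained with `torch.optim.SGD`; the notebook's actual optimizer and learning rate are not shown in this section. `make_optimizer`, `train_epochs`, the toy network and the random data are hypothetical stand-ins, not the notebook's `train` function. `ReduceLROnPlateau` is stepped once per epoch with the aggregate loss, so the learning rate is halved only after the loss has stopped improving for `patience` epochs.

```python
import torch
import torch.nn as nn


def make_optimizer(net, lr=1e-3, momentum=0.9, patience=5):
    # SGD with Nesterov momentum (PyTorch requires momentum > 0 when nesterov=True)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=momentum, nesterov=True)
    # Halve the learning rate once the monitored loss has not improved for `patience` epochs
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=patience)
    return optimizer, scheduler


def train_epochs(net, data, n_epochs=20, batch_size=256, lr=1e-3):
    # Hypothetical training loop; the autoencoder reconstructs its own input
    optimizer, scheduler = make_optimizer(net, lr)
    loss_fn = nn.MSELoss()
    history = []
    for _ in range(n_epochs):
        perm = torch.randperm(data.shape[0])  # reshuffle rows every epoch
        total = 0.0
        for start in range(0, data.shape[0], batch_size):
            batch = data[perm[start:start + batch_size]]
            optimizer.zero_grad()
            loss = loss_fn(net(batch), batch)
            loss.backward()
            optimizer.step()
            total += loss.item() * batch.shape[0]
        epoch_loss = total / data.shape[0]  # aggregate (mean) loss for the epoch
        scheduler.step(epoch_loss)  # decays the learning rate only on a plateau
        history.append((epoch_loss, optimizer.param_groups[0]["lr"]))
    return history


if __name__ == "__main__":
    torch.manual_seed(0)
    # Hypothetical stand-in for the notebook's network (9 input features)
    toy_net = nn.Sequential(nn.Linear(9, 7), nn.Tanh(), nn.Linear(7, 9))
    toy_data = torch.rand(1000, 9)
    for loss, lr in train_epochs(toy_net, toy_data, n_epochs=5):
        print(f"loss={loss:.4f} lr={lr:g}")
```

A plateau-triggered decay reacts to the loss curve itself; a fixed time-based schedule (for example `torch.optim.lr_scheduler.StepLR`) is the simpler alternative if the plateau epoch is known in advance.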

In [None]:
#model saving, re-loading
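The placeholder cell above could be filled in along these lines: a minimal sketch assuming the usual PyTorch pattern of saving the model's `state_dict` rather than pickling the whole module. `save_model`, `load_model`, `make_net` and the file name `mida_net.pt` are hypothetical; `make_net` must rebuild exactly the same architecture as `net`.

```python
import torch
import torch.nn as nn


def save_model(net, path):
    # Save only the learned parameters, not the pickled module object
    torch.save(net.state_dict(), path)


def load_model(make_net, path):
    # Rebuild the architecture first, then copy the saved parameters into it
    net = make_net()
    net.load_state_dict(torch.load(path, map_location="cpu"))
    return net.eval()  # inference mode for imputation


if __name__ == "__main__":
    def make_net():
        # Hypothetical stand-in for the notebook's autoencoder
        return nn.Sequential(nn.Linear(9, 7), nn.Tanh(), nn.Linear(7, 9))

    net = make_net()
    save_model(net, "mida_net.pt")
    restored = load_model(make_net, "mida_net.pt")
```

To resume training later (rather than only impute), the optimizer's `state_dict()` would need to be saved and restored as well, since momentum buffers live there.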

In [185]:
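net = net.eval()  # inference mode (disables dropout, if the network uses it)

# Reconstruction of the first training row, followed by the input row
# (its missing entries appear to have been pre-filled before training)
print(net(trainset[0][0]))
print(trainset[0][0])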

tensor([ 4.5032e+01, -1.2639e-01,  8.5354e+01,  4.3548e-02,  3.9030e+01,
         6.0109e-01,  4.0049e+01,  4.6558e+01,  6.6741e+00],
       grad_fn=<AddBackward0>)
tensor([50.0000, 21.0000, 77.0000,  0.0000, 34.5129,  0.0000, 37.1039, 50.9072,
        13.9429])
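Printing single rows makes it hard to judge imputation quality overall. A minimal sketch of a summary metric follows, assuming the complete (pre-missingness) values and a boolean mask of the removed entries are available as tensors: the RMSE is computed only over the entries that were actually missing, since observed entries are already known. `masked_rmse` is a hypothetical helper; the mask could be built, for example, with `torch.tensor(df2.isna().values)`.

```python
import torch


def masked_rmse(imputed, truth, missing_mask):
    # Root-mean-square error over the positions that were missing (mask == True)
    diff = (imputed - truth)[missing_mask]
    if diff.numel() == 0:
        return float("nan")  # nothing was missing, so there is nothing to score
    return torch.sqrt(torch.mean(diff ** 2)).item()


if __name__ == "__main__":
    truth = torch.tensor([[1.0, 0.0], [3.0, 1.0]])
    imputed = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
    mask = torch.tensor([[False, True], [False, True]])
    # Errors at the masked positions are 2 and 3, so RMSE = sqrt((4 + 9) / 2)
    print(masked_rmse(imputed, truth, mask))
```

If the data is normalized before training, the network output should be mapped back to the original units before computing this metric, so the numbers stay comparable with the raw values shown above.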


In [186]:
print(net(trainset[5][0]))
print(trainset[5][0])

tensor([ 8.9562e+01,  9.5964e-03,  8.5801e+01,  2.3462e-01,  1.6076e+01,
         1.2874e+00, -1.2152e+00,  6.9935e+01,  7.0976e+01],
       grad_fn=<AddBackward0>)
tensor([85.,  0., 88., -4.,  6.,  1.,  3., 83., 80.])


In [189]:
df2

Unnamed: 0,0,1,2,3,4,5,6,7,8
0,50,21,77,0,,0,,,
1,55,0,92,0,0.0,26,36.0,92.0,56.0
2,53,0,82,0,52.0,-5,29.0,30.0,2.0
3,37,0,76,0,28.0,18,40.0,48.0,8.0
4,37,0,79,0,34.0,-26,43.0,46.0,2.0
5,85,0,88,-4,6.0,1,3.0,83.0,80.0
6,56,0,81,0,-4.0,11,25.0,86.0,62.0
7,55,-1,95,-3,54.0,-4,40.0,41.0,2.0
8,53,8,77,0,28.0,0,23.0,48.0,24.0
9,37,0,101,-7,28.0,0,64.0,73.0,8.0


In [None]:
train(500, 510, net, logger_level=20)

In [190]:
# TODO: adjust (decay) the learning rate when the loss plateaus
# TODO: training is slow; speed it up
# TODO: try Nesterov momentum and time-based learning-rate decay
# TODO: normalize each feature to the range [0, 1]
# TODO: training time should be well under an hour
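For the normalization item, a minimal column-wise min-max scaling sketch with pandas follows. It assumes the statistics are fit on the training split only and reused for the test split; `minmax_fit`, `minmax_transform` and `minmax_inverse` are hypothetical helpers. pandas' `min()` and `max()` skip NaNs by default, so the induced missing values do not affect the fitted range and remain NaN after scaling.

```python
import numpy as np
import pandas as pd


def minmax_fit(df):
    # Column-wise minimum and range; NaNs are skipped by pandas' min()/max()
    col_min = df.min()
    col_range = df.max() - col_min
    col_range = col_range.replace(0, 1)  # avoid division by zero for constant columns
    return col_min, col_range


def minmax_transform(df, col_min, col_range):
    # Scale every column to [0, 1] using the fitted (training) statistics
    return (df - col_min) / col_range


def minmax_inverse(df_scaled, col_min, col_range):
    # Map scaled (e.g. imputed) values back to the original units
    return df_scaled * col_range + col_min


if __name__ == "__main__":
    toy = pd.DataFrame({"a": [0.0, 5.0, 10.0], "b": [1.0, np.nan, 3.0]})
    col_min, col_range = minmax_fit(toy)
    print(minmax_transform(toy, col_min, col_range))
```

Scaling inputs to a common range keeps features such as column 0 (values in the 30s to 80s) from dominating the reconstruction loss over near-zero columns such as column 3.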