MIDA (Gondara and Wang, 2018) in Python (using PyTorch)
https://arxiv.org/abs/1705.02737
https://gist.github.com/lgondara/18387c5f4d745673e9ca8e23f3d7ebd3 

### Note: Section 1 has been tested and moved to utils.py

# 1. Loading Dataset

## 1.1. Load a dataset and introduce missingness

Dataset used: Shuttle Dataset (https://archive.ics.uci.edu/ml/datasets/Statlog+(Shuttle))

### 1.1.1. Load the dataset and store it as a numeric dataframe

In [1]:
import pandas as pd
import utils

In [30]:
#Test
filename = "data/shuttle/shuttle_trn"
train_df = utils.get_dataframe_from_csv(filename).iloc[:,:-1]  #remove label

INFO:root:Input filename has to be space separated data


In [31]:
# #Test
# filename = "data/shuttle/shuttle_trn_debug"
# train_df = utils.get_dataframe_from_csv(filename).iloc[:,:-1]  #remove label
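`utils.get_dataframe_from_csv` is not shown in this notebook. The loading step can be sketched directly with pandas. The sketch assumes the Shuttle files hold whitespace-separated numbers with no header row and the class label in the last column. `load_shuttle_features` is a hypothetical name, not the helper in `utils.py`.

```python
import io
import pandas as pd

def load_shuttle_features(path_or_buffer) -> pd.DataFrame:
    """Hypothetical loader: whitespace-separated numeric file, no header, label last."""
    df = pd.read_csv(path_or_buffer, sep=r"\s+", header=None)
    return df.iloc[:, :-1]  # drop the class label column

# Two feature rows from the head() output below; the trailing 1 is a placeholder label.
sample = io.StringIO("50 21 77 0 28 0 27 48 22 1\n55 0 92 0 0 26 36 92 56 1\n")
features = load_shuttle_features(sample)
print(features)
```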

In [32]:
train_df.head()

    0   1   2  3   4    5   6   7   8
0  50  21  77  0  28    0  27  48  22
1  55   0  92  0   0   26  36  92  56
2  53   0  82  0  52   -5  29  30   2
3  37   0  76  0  28   18  40  48   8
4  37   0  79  0  34  -26  43  46   2


### 1.1.2. Inducing missingness

After loading the dataset, induce missingness.

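To start off, introduce a simple random missingness pattern (Missing Completely At Random, MCAR). Sample half of the variables and append a random uniform vector with one value per observation. Then set the sampled variables to missing in every observation whose uniform value is below a threshold. With a threshold of 0.2, about 20% of the observations lose their values in the sampled variables. Here 4 of the 9 variables are sampled, so the overall share of missing cells is roughly 4/9 × 20% ≈ 8.9%. This is consistent with the 8.79% reported below.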
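The MCAR procedure can be sketched in a few lines of pandas/NumPy. The sketch draws one uniform value per row, which matches the row-wise pattern in `df2.head()` below: columns 4, 6, 7 and 8 go missing together. `induce_mcar`, its parameters, and the seeding are hypothetical and not taken from `utils.py`.

```python
import numpy as np
import pandas as pd

def induce_mcar(df: pd.DataFrame, threshold: float = 0.2, seed: int = 0) -> pd.DataFrame:
    """Hypothetical MCAR sketch; not the notebook's utils.induce_missingness."""
    rng = np.random.default_rng(seed)
    out = df.astype(float)  # astype returns a copy; float dtype can hold NaN
    # Sample half of the variables (rounded down) without replacement.
    n_sampled = df.shape[1] // 2
    cols = rng.choice(df.columns.to_numpy(), size=n_sampled, replace=False)
    # One uniform draw per row: rows below the threshold lose every sampled variable.
    runi = rng.uniform(size=len(df))
    out.loc[runi < threshold, cols] = np.nan
    return out

df = pd.DataFrame(np.ones((10_000, 9)))
df_missing = induce_mcar(df, threshold=0.2, seed=1)
print(df_missing.isna().to_numpy().mean())  # expected to be near (4 / 9) * 0.2 ≈ 0.089
```

Because the draw is per row rather than per cell, each affected row loses all sampled variables at once. The overall missing fraction is therefore diluted by the share of variables sampled.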

In [33]:
#test
df1 = train_df.copy()  # explicit copy so later steps cannot modify train_df
df2 = utils.induce_missingness(df1,logger_level=20)

INFO:root: Returning new dataframe with missingness(MCAR) induced
INFO:root: Percentage of NaNs in returned dataframe : 8.79


In [34]:
df1.head()

    0   1   2  3   4    5   6   7   8
0  50  21  77  0  28    0  27  48  22
1  55   0  92  0   0   26  36  92  56
2  53   0  82  0  52   -5  29  30   2
3  37   0  76  0  28   18  40  48   8
4  37   0  79  0  34  -26  43  46   2


In [35]:
df2.head()

    0   1   2  3     4    5     6     7     8
0  50  21  77  0   NaN    0   NaN   NaN   NaN
1  55   0  92  0   0.0   26  36.0  92.0  56.0
2  53   0  82  0  52.0   -5  29.0  30.0   2.0
3  37   0  76  0  28.0   18  40.0  48.0   8.0
4  37   0  79  0  34.0  -26  43.0  46.0   2.0


### 1.1.3. Create Train-Test split

Split the data 70/30 into a training set and a test set, both containing missingness. Also keep a copy of the test set without missingness (`full_test_df`), so that imputation performance can be measured against the true values.
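The split can be sketched as follows, assuming the complete data and the masked data share the same index. `split_with_ground_truth` is a hypothetical helper. The notebook's `utils.create_train_test_split` takes only `df1` and induces missingness itself, as the log in the next cell shows.

```python
import numpy as np
import pandas as pd

def split_with_ground_truth(complete_df, masked_df, train_frac=0.7, seed=0):
    """Hypothetical 70/30 split that also returns a NaN-free copy of the test rows."""
    if not complete_df.index.equals(masked_df.index):
        raise ValueError("complete_df and masked_df must share the same index")
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(masked_df))
    n_train = int(round(train_frac * len(masked_df)))
    train_idx = masked_df.index[order[:n_train]]
    test_idx = masked_df.index[order[n_train:]]
    train_df = masked_df.loc[train_idx]
    test_df = masked_df.loc[test_idx]
    # Same rows as test_df, taken from the data before missingness was induced.
    full_test_df = complete_df.loc[test_idx]
    return train_df, test_df, full_test_df

complete = pd.DataFrame(np.arange(30, dtype=float).reshape(10, 3))
masked = complete.copy()
masked.iloc[::3, 1] = np.nan
train_df, test_df, full_test_df = split_with_ground_truth(complete, masked)
```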

In [36]:
#Test
a,b,c = utils.create_train_test_split(df1)
print(a.head())
print(b.head())
print(c.head())

INFO:root: Returning new dataframe with missingness(MCAR) induced
INFO:root: Percentage of NaNs in returned dataframe : 8.79
INFO:root: Returning train_df, test_df, full_test_df after splitting dataframe in 0.7/0.3 split 
INFO:root: Note: full_test_df is the same as test_df but without NaNs


        0  1    2  3     4  5     6     7    8
7476   55  0   98  0   NaN -4   NaN   NaN  NaN
31355  50 -5  102  2  50.0  0  52.0  53.0  0.0
38462  37  0   77  0  36.0 -2  40.0  41.0  2.0
20525  55 -2   95  0  46.0 -3  40.0  49.0  8.0
34457  55  0   92  8   NaN  0   NaN   NaN  NaN
        0  1   2  3     4   5     6     7     8
15528  45 -1  76  0   NaN -16   NaN   NaN   NaN
14327  37  0  95  0  10.0   7  58.0  84.0  26.0
12125  37  0  75 -4  30.0   0  38.0  44.0   6.0
39952  55  0  96  0  50.0   4  41.0  47.0   6.0
1339   41 -1  76  0   NaN -14   NaN   NaN   NaN
        0  1   2  3   4   5   6   7   8
15528  45 -1  76  0  44 -16  31  32   2
14327  37  0  95  0  10   7  58  84  26
12125  37  0  75 -4  30   0  38  44   6
39952  55  0  96  0  50   4  41  47   6
1339   41 -1  76  0  38 -14  35  37   2
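`full_test_df` keeps the true values of the test rows, so imputation error can be scored on exactly the cells that were masked. The sketch below assumes the imputed output is a DataFrame aligned with `test_df`. `rmse_on_missing` is a hypothetical helper, and the column-mean fill is only a baseline for illustration, not MIDA.

```python
import numpy as np
import pandas as pd

def rmse_on_missing(imputed_df, test_df, full_test_df):
    """RMSE over the cells that were NaN in test_df, scored against full_test_df."""
    missing = test_df.isna().to_numpy()
    truth = full_test_df.to_numpy(dtype=float)[missing]
    guess = imputed_df.to_numpy(dtype=float)[missing]
    return float(np.sqrt(np.mean((guess - truth) ** 2)))

full_test_df = pd.DataFrame({"a": [1.0, 2.0, 3.0], "b": [10.0, 20.0, 30.0]})
test_df = full_test_df.copy()
test_df.loc[[0, 2], "b"] = np.nan
imputed_df = test_df.fillna(test_df.mean())   # column-mean baseline: both b cells become 20.0
score = rmse_on_missing(imputed_df, test_df, full_test_df)
print(score)  # 10.0
```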


# 2. Modelling

Proceed to modelling.

In R:
Initialize the `h2o` package, then read the training and test datasets into h2o's supported format.
Then run the imputation model multiple times. Each new run initializes the weights with different values, so each run produces a different imputation.
The [h2o](https://cran.r-project.org/web/packages/h2o/h2o.pdf) package offers an easy-to-use function for implementing autoencoders.
More information is available in the h2o [Deep Learning booklet](http://docs.h2o.ai/h2o/latest-stable/h2o-docs/booklets/DeepLearningBooklet.pdf).
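The multiple-imputation idea is to restart training with different initial weights and keep each completed dataset. It can be sketched in PyTorch as below. This is a minimal sketch under stated assumptions: the tiny one-hidden-layer network, the loss restricted to observed cells, and averaging as the pooling rule are illustrative choices. They are not the h2o setup and not necessarily what the MIDA paper uses. `impute_once` is a hypothetical helper.

```python
import numpy as np
import torch
import torch.nn as nn

def impute_once(x_missing: np.ndarray, seed: int, epochs: int = 200) -> np.ndarray:
    """One imputation run: the seed fixes the initial weights and dropout masks."""
    torch.manual_seed(seed)
    observed = ~np.isnan(x_missing)
    x = torch.tensor(np.nan_to_num(x_missing), dtype=torch.float32)  # NaN -> 0 as input
    m = torch.tensor(observed, dtype=torch.float32)
    d = x.shape[1]
    net = nn.Sequential(nn.Dropout(0.5), nn.Linear(d, d + 7), nn.Tanh(), nn.Linear(d + 7, d))
    opt = torch.optim.Adam(net.parameters(), lr=1e-2)
    for _ in range(epochs):
        opt.zero_grad()
        loss = (((net(x) - x) * m) ** 2).sum() / m.sum()  # error on observed cells only
        loss.backward()
        opt.step()
    net.eval()  # disable dropout for the final reconstruction
    with torch.no_grad():
        recon = net(x).numpy()
    out = x_missing.copy()
    out[~observed] = recon[~observed]  # observed entries are kept unchanged
    return out

rng = np.random.default_rng(0)
X = rng.normal(size=(50, 4))
X[rng.uniform(size=X.shape) < 0.1] = np.nan
imputations = [impute_once(X, seed) for seed in range(5)]  # M = 5 completed datasets
pooled = np.mean(imputations, axis=0)                      # one simple way to combine them
```

Averaging gives a single point estimate but discards the between-imputation spread. In a full multiple-imputation analysis, each completed dataset is analysed separately and the results are pooled afterwards.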

In Python:

In [37]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

In [38]:
# Settings: device, random seed, default tensor type, DataLoader kwargs
RANDOM_SEED = 18
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)  # seed the CPU RNG for both device setups

if torch.cuda.is_available():
    device = 'cuda'
    torch.cuda.manual_seed(RANDOM_SEED)
    torch.set_default_tensor_type(torch.cuda.FloatTensor)
    kwargs = {'num_workers':4, 'pin_memory' :True}
else:
    device = 'cpu'
    torch.set_default_tensor_type(torch.FloatTensor)
    kwargs = {}  # no extra DataLoader options on CPU

In [39]:
import dataset_module

In [40]:
# from importlib import reload
# reload(dataset_module)

In [41]:
df2.head()

    0   1   2  3     4    5     6     7     8
0  50  21  77  0   NaN    0   NaN   NaN   NaN
1  55   0  92  0   0.0   26  36.0  92.0  56.0
2  53   0  82  0  52.0   -5  29.0  30.0   2.0
3  37   0  76  0  28.0   18  40.0  48.0   8.0
4  37   0  79  0  34.0  -26  43.0  46.0   2.0


In [42]:
trainset = dataset_module.DataSetForImputation(df2, normalize=False)  #normalize True for [0,1] normalization for dataframe

In [43]:
trainset = dataset_module.DataSetForImputation(df2, normalize=True)  #normalize True for [0,1] normalization for dataframe
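`dataset_module` is not shown, but the normalized values of `trainset[i]` below are consistent with a per-column linear rescaling to [0, 1]. The sketch below assumes that is what `normalize=True` does. It keeps the fitted minimum and range so that test data can be scaled the same way and imputed values mapped back. `MinMax01` is a hypothetical class.

```python
import numpy as np
import pandas as pd

class MinMax01:
    """Hypothetical per-column [0, 1] scaler that ignores NaNs and can be inverted."""

    def fit(self, df: pd.DataFrame) -> "MinMax01":
        self.min_ = df.min()                  # NaNs are skipped by default
        span = df.max() - self.min_
        self.span_ = span.replace(0, 1)       # constant column: avoid division by zero
        return self

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        return (df - self.min_) / self.span_  # NaNs stay NaN

    def inverse_transform(self, df: pd.DataFrame) -> pd.DataFrame:
        return df * self.span_ + self.min_

train = pd.DataFrame({"a": [27.0, 50.0, 126.0], "b": [0.0, np.nan, 4.0], "c": [5.0, 5.0, 5.0]})
scaler = MinMax01().fit(train)
scaled = scaler.transform(train)
restored = scaler.inverse_transform(scaled)
```

Fitting on the training data only and reusing `min_`/`span_` for the test set is one way to address the test-time normalization question in the TODO further down. Test values outside the training range will map slightly outside [0, 1]. scikit-learn's `MinMaxScaler` offers the same fit/transform/inverse_transform pattern.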

In [44]:
trainset

Dataframe Size:43500, Perc of NaNs: 8.79

In [45]:
trainset[0]

(tensor([0.2323, 0.4893, 0.4375, 0.5070, 0.3566, 0.5128, 0.5562, 0.6483, 0.5948]),
 tensor([0.2323, 0.4893, 0.4375, 0.5070, 0.3566, 0.5128, 0.5562, 0.6483, 0.5948]))

In [46]:
trainset[1]

(tensor([0.2828, 0.4872, 0.5547, 0.5070, 0.3013, 0.5138, 0.5490, 0.7143, 0.6624]),
 tensor([0.2828, 0.4872, 0.5547, 0.5070, 0.3013, 0.5138, 0.5490, 0.7143, 0.6624]))

In [47]:
trainset[2]

(tensor([0.2626, 0.4872, 0.4766, 0.5070, 0.3846, 0.5126, 0.5033, 0.6148, 0.5756]),
 tensor([0.2626, 0.4872, 0.4766, 0.5070, 0.3846, 0.5126, 0.5033, 0.6148, 0.5756]))
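Each `trainset[i]` returns the same row twice, as an (input, target) pair for a denoising autoencoder. A minimal PyTorch `Dataset` with that behaviour might look like this. `PairDataset` and its `variables()` method are hypothetical stand-ins, since `dataset_module` is not shown.

```python
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset

class PairDataset(Dataset):
    """Hypothetical stand-in for dataset_module.DataSetForImputation."""

    def __init__(self, df: pd.DataFrame):
        self.columns = list(df.columns)
        self.data = torch.tensor(df.to_numpy(dtype=np.float32))

    def __len__(self) -> int:
        return self.data.shape[0]

    def __getitem__(self, idx):
        row = self.data[idx]
        # (input, target): the autoencoder learns to reconstruct its own input.
        # Corruption is assumed to come from the network's dropout layer.
        return row, row.clone()

    def variables(self):
        return self.columns

df = pd.DataFrame(np.random.default_rng(0).uniform(size=(10, 9)))
ds = PairDataset(df)
x, d = next(iter(DataLoader(ds, batch_size=4, shuffle=False)))
```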

In [48]:
import Modelling
net = Modelling.DenoisingAutoEncoder(len(trainset.variables()))

In [49]:
net

DenoisingAutoEncoder(
  (drop_layer): Dropout(p=0.5)
  (linear_layer_list): ModuleList(
    (0): Linear(in_features=9, out_features=16, bias=True)
    (1): Linear(in_features=16, out_features=23, bias=True)
    (2): Linear(in_features=23, out_features=30, bias=True)
    (3): Linear(in_features=30, out_features=23, bias=True)
    (4): Linear(in_features=23, out_features=16, bias=True)
    (5): Linear(in_features=16, out_features=9, bias=True)
  )
)
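The printed layer widths 9 → 16 → 23 → 30 → 23 → 16 → 9 grow by 7 per layer in the encoder and shrink symmetrically in the decoder, with dropout applied to the input. The sketch below rebuilds that shape. `MidaStyleDAE` is hypothetical and not `Modelling.DenoisingAutoEncoder`. The tanh activations between layers are an assumption, because activations called functionally do not appear in the printed module.

```python
import torch
import torch.nn as nn

class MidaStyleDAE(nn.Module):
    """Hypothetical re-creation of the printed architecture."""

    def __init__(self, n_vars: int, theta: int = 7, depth: int = 3, p_drop: float = 0.5):
        super().__init__()
        widths = [n_vars + k * theta for k in range(depth + 1)]  # 9, 16, 23, 30
        dims = widths + widths[-2::-1]                           # 9, 16, 23, 30, 23, 16, 9
        self.drop_layer = nn.Dropout(p=p_drop)
        self.linear_layer_list = nn.ModuleList(
            [nn.Linear(i, o) for i, o in zip(dims[:-1], dims[1:])]
        )

    def forward(self, x):
        h = self.drop_layer(x)  # corrupt the input (active only in train mode)
        last = len(self.linear_layer_list) - 1
        for i, layer in enumerate(self.linear_layer_list):
            h = layer(h)
            if i < last:
                h = torch.tanh(h)  # assumed activation; the output layer stays linear
        return h

net = MidaStyleDAE(9)
print(net)
```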

# 3. Training

In [50]:
import math
import torch.utils.data as td
from torch.optim import Adam

LR = 1e-3
DATAPOINTS = len(trainset)  # 43500
BATCH_SIZE = 512  # wasn't working too well even after 600 iterations
BATCHES = math.ceil(DATAPOINTS / BATCH_SIZE)  # 85 batches per epoch; the last one is partial
VARIABLES  = len(trainset.variables()) #9


import Modelling
net = Modelling.DenoisingAutoEncoder(len(trainset.variables()))

criterion = nn.MSELoss()
net = net.to(device) 

train_loader = td.DataLoader(trainset, batch_size= BATCH_SIZE, shuffle= True, **kwargs) 
optimizer = Adam(net.parameters(), lr = LR)

LOG_INTERVAL = 10
SAVE_INTERVAL = 50

In [22]:
'''
TO DO:
:- SOLVED : Normalization between 0 and 1 - Error blowing up for some reason - NaNs and inf (Decided to normalize before passing the dataset)
:- (0,1) Normalization for better convergence - how to handle this elegantly while testing, because this has been trained for something between 0 and 1
:- Nesterov Momentum + Adam- Pytorch? Decay factor?
'''

from tqdm.notebook import tqdm
def train(start_steps = 0, end_steps = 5, net=None, logger_level = 10):
    import logging
    import os
    logger = logging.getLogger()
    logger.setLevel(logger_level)
    os.makedirs("./artifacts", exist_ok=True)  # torch.save below fails if the folder is missing
    
    agg_loss = 0.0
    
    #Debug tools
    prev_loss = 0.0
    NaN_flag = False
    
    for epoch in tqdm(range(start_steps, end_steps)):
        count = epoch-start_steps+1
        net.train()
        #Epoch begins
        
        for x, d in tqdm(train_loader):
            original_x = x.clone()  # untouched copy for the NaN/inf debug log below
            
            # Per-batch normalization here (x / x.sum(0), then 2*x - 1) produced NaN/inf when a
            # column summed to zero, so [0,1] normalization is done in DataSetForImputation instead.
            # Any NaNs left in the batch are replaced with 0 in both the input and the target.
            x[torch.isnan(x)] = 0
            d[torch.isnan(d)] = 0
            
            optimizer.zero_grad()
            x = x.to(device)
            d = d.to(device)  # target tensor; it does not require gradients
            y = net(x)
            loss = torch.sqrt(criterion(y, d))   # RMSE loss
            loss.backward()
            optimizer.step()
            agg_loss += loss.item()
            
            if torch.isnan(loss) or torch.isinf(loss):
                logging.info(f"Loss value: {loss.item()}")
                logging.info(f"previous loss was {prev_loss}")
                logging.info("NaN/inf occurred at:")
                logging.info(f"{x}\n")
                logging.info(f"{d}\n")
                logging.info(f"Original x was : {original_x}")
                NaN_flag = True
                break

            prev_loss = loss.item()  # store a float, not the tensor, so the graph is freed
            logging.debug(f"Count: {count}, Loss :{loss.item()}")
            
        if NaN_flag: break    
        if count%LOG_INTERVAL == 0:
            # mean per-batch RMSE over the last LOG_INTERVAL epochs
            print(f"Epoch number:{epoch}  Aggregate loss: {agg_loss/(LOG_INTERVAL*BATCHES):.4f}")
            agg_loss= 0.0
            
        if epoch%SAVE_INTERVAL== 0:
            torch.save(net.state_dict(),f"./artifacts/DAE_saved_model_epoch{epoch}")
        #Epoch Ends
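The loop above replaces any NaN in the target `d` with 0, which would train the network to reproduce 0 for those cells. One alternative is to compute RMSE only over the observed target cells. This is sketched here as an assumption, not as what this notebook or the MIDA paper does. `masked_rmse` is a hypothetical helper; the input `x` would still need its NaNs filled before the forward pass.

```python
import torch

def masked_rmse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """RMSE over the non-NaN entries of `target`; NaN cells contribute nothing."""
    observed = ~torch.isnan(target)
    if not observed.any():
        raise ValueError("target has no observed entries")
    diff = pred[observed] - target[observed]
    return torch.sqrt((diff ** 2).mean())

pred = torch.zeros(2, 2, requires_grad=True)
target = torch.tensor([[1.0, float("nan")], [float("nan"), 1.0]])
loss = masked_rmse(pred, target)  # sqrt(mean([1, 1])) = 1.0
loss.backward()                   # gradients reach only the observed positions
```

Inside `train`, this would replace `torch.sqrt(criterion(y, d))` and the `d[torch.isnan(d)] = 0` line. Note that `torch.sqrt` has an infinite gradient at exactly zero error, the same as the existing RMSE loss.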

In [51]:
train(0,500, net, logger_level=20)


Epoch number:9  Aggregate loss: 0.0662
Epoch number:19  Aggregate loss: 0.0476
Epoch number:29  Aggregate loss: 0.0418
Epoch number:39  Aggregate loss: 0.0397
Epoch number:49  Aggregate loss: 0.0390
Epoch number:59  Aggregate loss: 0.0384
Epoch number:69  Aggregate loss: 0.0364
Epoch number:79  Aggregate loss: 0.0354
Epoch number:89  Aggregate loss: 0.0350
Epoch number:99  Aggregate loss: 0.0349
Epoch number:109  Aggregate loss: 0.0341
Epoch number:119  Aggregate loss: 0.0335
Epoch number:129  Aggregate loss: 0.0330
Epoch number:139  Aggregate loss: 0.0326
Epoch number:149  Aggregate loss: 0.0326
Epoch number:159  Aggregate loss: 0.0325
Epoch number:169  Aggregate loss: 0.0324
Epoch number:179  Aggregate loss: 0.0323
Epoch number:189  Aggregate loss: 0.0324
Epoch number:199  Aggregate loss: 0.0322
Epoch number:209  Aggregate loss: 0.0321
Epoch number:219  Aggregate loss: 0.0320
Epoch number:229  Aggregate loss: 0.0321
Epoch number:239  Aggregate loss: 0.0321
Epoch number:249  Aggregate loss: 0.0321
Epoch number:259  Aggregate loss: 0.0319
Epoch number:269  Aggregate loss: 0.0319
Epoch number:279  Aggregate loss: 0.0319
Epoch number:289  Aggregate loss: 0.0320
Epoch number:299  Aggregate loss: 0.0317
Epoch number:309  Aggregate loss: 0.0317
Epoch number:319  Aggregate loss: 0.0317
Epoch number:329  Aggregate loss: 0.0316
Epoch number:339  Aggregate loss: 0.0316
Epoch number:349  Aggregate loss: 0.0315
Epoch number:359  Aggregate loss: 0.0316

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

Epoch number:369  Aggregate loss: 0.0315


HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

Epoch number:379  Aggregate loss: 0.0315


HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

Epoch number:389  Aggregate loss: 0.0316


HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

Epoch number:399  Aggregate loss: 0.0315


HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

Epoch number:409  Aggregate loss: 0.0315


HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

Epoch number:419  Aggregate loss: 0.0314


HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

Epoch number:429  Aggregate loss: 0.0315


HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

Epoch number:439  Aggregate loss: 0.0314


HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

Epoch number:449  Aggregate loss: 0.0315


HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

Epoch number:459  Aggregate loss: 0.0312


HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

Epoch number:469  Aggregate loss: 0.0313


HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

Epoch number:479  Aggregate loss: 0.0313


HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

Epoch number:489  Aggregate loss: 0.0312


HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

HBox(children=(IntProgress(value=0, max=85), HTML(value='')))

Epoch number:499  Aggregate loss: 0.0312



Ran for 500 epochs.
Observation: bad performance. The aggregate loss plateaus at about 0.031 (0.0317 at epoch 309, 0.0312 at epoch 499). Maybe training quickly gets stuck on a plateau or in a local minimum?
              Soln: Maybe use a smaller learning rate after a while? DONE
              Soln: Check the paper to see whether every feature has been implemented exactly as described (e.g. Nesterov momentum?)
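One way to automate the "smaller learning rate after a while" idea is PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau`, which cuts the learning rate when the monitored loss stops improving. This is a sketch, not what the notebook does: a toy `nn.Linear` model and random data stand in for the denoising autoencoder and `trainset`, and the `factor` and `patience` values are arbitrary assumptions.

```python
import torch
from torch import nn
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

torch.manual_seed(0)

# Hypothetical stand-ins for the autoencoder and the 9-column data
model = nn.Linear(9, 9)
data = torch.rand(64, 9)
criterion = nn.MSELoss()

optimizer = Adam(model.parameters(), lr=1e-3)
# Halve the learning rate once the loss has failed to improve for more than `patience` epochs
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5)

for epoch in range(50):
    optimizer.zero_grad()
    loss = criterion(model(data), data)
    loss.backward()
    optimizer.step()
    # Unlike most schedulers, ReduceLROnPlateau must be given the metric it monitors
    scheduler.step(loss.item())

print("final learning rate:", optimizer.param_groups[0]["lr"])
```

In the notebook the value passed to `scheduler.step` would be the epoch's aggregate loss.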
              
Problems/Issues:

Next steps:
Read the paper again and check the exact implementation of each component. Make a list of anything not fully understood, e.g. Nesterov momentum.
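For the Nesterov question, here is a plain-Python sketch of SGD with Nesterov momentum in the form documented for `torch.optim.SGD` with `nesterov=True` (dampening and weight decay left at 0): the velocity is updated as `buf = mu * buf + g`, and the parameter step uses the look-ahead direction `g + mu * buf`. The quadratic objective and the hyperparameters are illustrative assumptions; whether the MIDA paper trains with this optimizer still needs to be confirmed in the paper itself.

```python
def nesterov_sgd(grad_fn, x0, lr=0.1, momentum=0.9, steps=200):
    """Minimise a 1-D function with SGD + Nesterov momentum (PyTorch-style update)."""
    x = x0
    buf = 0.0  # momentum buffer; the first update makes buf equal to the gradient
    for _ in range(steps):
        g = grad_fn(x)
        buf = momentum * buf + g   # velocity update (no dampening)
        step = g + momentum * buf  # Nesterov look-ahead direction
        x = x - lr * step
    return x


# f(x) = (x - 3)^2 has gradient 2 * (x - 3) and its minimum at x = 3
x_min = nesterov_sgd(lambda x: 2.0 * (x - 3.0), x0=0.0)
print(round(x_min, 6))  # prints 3.0
```

In the notebook the built-in equivalent would be `torch.optim.SGD(net.parameters(), lr=LR, momentum=0.9, nesterov=True)`, where the momentum value 0.9 is an assumption.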


Next: How do I debug a neural net, and an autoencoder in particular? How can I visualize where it is going wrong?
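A common first check when debugging a network (a general heuristic, not something taken from the MIDA paper) is to try to overfit one small batch: a model with enough capacity should drive the training loss close to zero, and if it cannot, the problem is more likely in the model, loss, or data pipeline than in the hyperparameters. The sketch assumes a hypothetical small overcomplete tanh autoencoder and 18 random rows of 9 features in [0, 1], in place of the real `Modelling.DenoisingAutoEncoder` and `trainset`.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-ins: one batch of 18 rows x 9 features and a small autoencoder
batch = torch.rand(18, 9)
model = nn.Sequential(nn.Linear(9, 16), nn.Tanh(), nn.Linear(16, 9))
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

initial_loss = criterion(model(batch), batch).item()
for step in range(1000):
    optimizer.zero_grad()
    loss = criterion(model(batch), batch)
    loss.backward()
    optimizer.step()

final_loss = loss.item()  # loss at the last training step
# A loss that stays near its starting value points to a bug rather than to tuning
print(f"initial loss {initial_loss:.4f}, final loss {final_loss:.6f}")
```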

1. Why is my loss value so large? Shouldn't it get normalized when I divide by LOG_INTERVAL?
2. How do I normalize the data to lie between 0 and 1?
3. What exactly should be printed, and where?
4. Should the optimizer parameters be saved as well?
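For question 2, a common choice is per-column min-max scaling, x' = (x - min) / (max - min). The helpers below are hypothetical, not the code used to build `trainset`: they compute the statistics while skipping NaNs, so they also work on data after missingness has been induced, and they return those statistics so imputed values can be mapped back to the original scale.

```python
import numpy as np
import pandas as pd


def min_max_scale(df):
    """Scale every column of df to [0, 1]; NaNs are ignored and stay NaN.

    Returns the scaled frame plus the per-column minimum and range so the
    transformation can be inverted after imputation.
    """
    col_min = df.min()              # pandas skips NaN by default
    col_range = df.max() - col_min
    # Constant columns would divide by zero; give them a range of 1 so they map to 0
    col_range = col_range.replace(0, 1)
    return (df - col_min) / col_range, col_min, col_range


def min_max_unscale(scaled, col_min, col_range):
    """Invert min_max_scale."""
    return scaled * col_range + col_min


df = pd.DataFrame({"a": [50.0, 55.0, np.nan, 37.0], "b": [0.0, 0.0, 0.0, 0.0]})
scaled, col_min, col_range = min_max_scale(df)
print(scaled)
```

The minimum and range would normally be computed on the training split only and reused for any test data.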

#torch.save(optimizer.state_dict(), filename)
#optimizer.load_state_dict(torch.load(filename))
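The commented lines above are the standard PyTorch calls for optimizer state. Below is a sketch of one combined checkpoint, assuming a toy `nn.Linear` model in place of the autoencoder and a temporary file path; the dictionary keys are arbitrary names chosen for illustration. Saving the optimizer matters for Adam, because its per-parameter moment estimates would otherwise restart from zero when training resumes.

```python
import os
import tempfile

import torch
from torch import nn
from torch.optim import Adam

torch.manual_seed(0)

# Hypothetical stand-ins for the autoencoder and its optimizer
model = nn.Linear(9, 9)
optimizer = Adam(model.parameters(), lr=1e-3)

# One training step so Adam has per-parameter state (step count, moment estimates)
data = torch.rand(4, 9)
loss = nn.MSELoss()(model(data), data)
loss.backward()
optimizer.step()

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
# Save model and optimizer together so training can resume where it stopped
torch.save({"epoch": 1,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict()}, path)

# Later (or in a new session): rebuild the objects, then restore their state
restored_model = nn.Linear(9, 9)
restored_optimizer = Adam(restored_model.parameters(), lr=1e-3)
checkpoint = torch.load(path, map_location="cpu")
restored_model.load_state_dict(checkpoint["model_state"])
restored_optimizer.load_state_dict(checkpoint["optimizer_state"])
start_epoch = checkpoint["epoch"] + 1
```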

In [None]:
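# Re-create the optimizer with the tweaked learning rate LR to improve convergence speed.
# Note: a freshly created Adam optimizer starts again with empty moment estimates.
from torch.optim import Adam
optimizer = Adam(net.parameters(), lr=LR)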
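If the aim is only to lower the learning rate of an optimizer that is already training, an alternative is to edit its `param_groups` in place, which keeps Adam's accumulated state instead of discarding it. A sketch, assuming a toy `nn.Linear` model in place of `net` and a hypothetical `new_lr` value:

```python
import torch
from torch import nn
from torch.optim import Adam

model = nn.Linear(9, 9)          # stand-in for the autoencoder
optimizer = Adam(model.parameters(), lr=1e-3)

data = torch.rand(4, 9)
nn.MSELoss()(model(data), data).backward()
optimizer.step()                 # Adam now holds moment estimates for each parameter

new_lr = 1e-4                    # hypothetical smaller learning rate
for group in optimizer.param_groups:
    group["lr"] = new_lr         # the existing optimizer state is left untouched
```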

In [None]:
# Load the model from a saved checkpoint

# model = Modelling.DenoisingAutoEncoder(len(trainset.variables()))
# model.load_state_dict(torch.load("./artifacts/saved_model_epoch50"))
# model.eval()  # switch to evaluation mode before inference

In [None]:
torch.set_printoptions(sci_mode=False)

In [52]:
net = net.eval()  # evaluation mode (changes the behaviour of layers such as dropout)

print(net(trainset[0][0]).detach())
print(trainset[0][0])

tensor([0.0973, 0.4706, 0.3143, 0.4873, 0.3732, 0.5086, 0.4446, 0.5729, 0.5312])
tensor([0.2323, 0.4893, 0.4375, 0.5070, 0.3566, 0.5128, 0.5562, 0.6483, 0.5948])


In [53]:
print(net(trainset[5][0]).detach())
print(trainset[5][0])

tensor([0.2208, 0.4831, 0.3598, 0.5005, 0.3787, 0.5136, 0.4331, 0.5873, 0.5611])
tensor([0.5859, 0.4872, 0.5234, 0.5065, 0.3109, 0.5128, 0.3333, 0.6998, 0.7010])


In [29]:
print(net(trainset[19][0]).detach())
print(trainset[19][0])

tensor([0.1975, 0.8936, 0.2887, 0.8101, 0.6161, 0.5467, 0.5638, 0.3690, 0.1681])
tensor([0.0000, 0.0000, 1.0000, 0.6250, 0.6897, 0.4231, 1.0000, 0.6290, 0.0000])


In [27]:
print(net(trainset[15][0]).detach())
print(trainset[15][0])

tensor([0.2053, 0.8955, 0.2859, 0.8039, 0.6273, 0.5452, 0.5619, 0.3419, 0.1416])
tensor([0.0833, 0.9381, 0.4483, 1.0000, 0.7241, 0.1923, 0.6923, 0.3226, 0.0250])
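Rather than comparing single rows by eye, a per-feature reconstruction error over many samples shows which variables the autoencoder reconstructs poorly. This is a sketch assuming dataset items are tuples whose first element is the model input, as the `trainset[i][0]` indexing above suggests; if the items hold a (corrupted input, clean target) pair, the error should be computed against the target instead. The helper name and the toy data are hypothetical.

```python
import torch
from torch import nn


def per_feature_mse(model, dataset, n=None):
    """Mean squared reconstruction error of each feature over the first n items.

    Assumes the model and the dataset tensors live on the same device.
    """
    n = len(dataset) if n is None else min(n, len(dataset))
    model.eval()
    with torch.no_grad():  # no gradients are needed for evaluation
        x = torch.stack([dataset[i][0] for i in range(n)])
        recon = model(x)
    return ((recon - x) ** 2).mean(dim=0)  # one value per column


# Toy check: an identity "autoencoder" reconstructs every feature perfectly
toy_data = [(torch.rand(9),) for _ in range(20)]
print(per_feature_mse(nn.Identity(), toy_data))
```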


In [None]:
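net = Modelling.DenoisingAutoEncoder(len(trainset.variables()))

criterion = nn.MSELoss()
net = net.to(device)
# Re-create the optimizer: the previous one still holds the old network's parameters
optimizer = Adam(net.parameters(), LR)
train_loader = td.DataLoader(trainset, batch_size=18, shuffle=True, **kwargs)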
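On the earlier question about normalizing the logged loss: with `nn.MSELoss()` at its default `reduction="mean"`, each `loss.item()` is already an average over the elements of its batch, so summing those values over batches and dividing by the number of batches gives a mean per-batch loss. The sketch below is hypothetical (it is not this notebook's training loop; `train_one_epoch` and the toy data are assumptions) and weights each batch by its size so that a smaller final batch does not skew the epoch average.

```python
import torch
from torch import nn
import torch.utils.data as td


def train_one_epoch(net, loader, optimizer, criterion, device):
    """Run one epoch and return the loss averaged over all samples."""
    net.train()
    total, count = 0.0, 0
    for inputs, targets in loader:
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        loss = criterion(net(inputs), targets)  # mean over the batch's elements
        loss.backward()
        optimizer.step()
        total += loss.item() * inputs.size(0)   # weight by batch size
        count += inputs.size(0)
    return total / count


torch.manual_seed(0)
x = torch.rand(100, 9)
toy_set = td.TensorDataset(x, x)                # (input, target) pairs
loader = td.DataLoader(toy_set, batch_size=18, shuffle=True)
net = nn.Linear(9, 9)
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
epoch_loss = train_one_epoch(net, loader, optimizer, nn.MSELoss(), torch.device("cpu"))
print(f"Aggregate loss: {epoch_loss:.4f}")
```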