In [1]:
import numpy as np
import pandas as pd
# !pip install --upgrade pytorch_lightning
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
# !pip install wandb
from pytorch_lightning.loggers import WandbLogger

from torchsummary import summary

from sklearn.impute import SimpleImputer
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import RobustScaler, LabelEncoder, OrdinalEncoder, MinMaxScaler
from sklearn.pipeline import Pipeline
import numpy as np

from pathlib import Path
from argparse import ArgumentParser
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
import pickle


# Location of the Kaggle M5 "accuracy" competition files.
data_dir = Path.home() / 'data' / 'kaggle' / 'm5-forecasting-accuracy'

# Model input columns: categorical features (embedded) followed by continuous ones.
x_cat_cols = (
    ['item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']          # sales hierarchy ids
    + ['weekday', 'wday', 'month', 'year']                            # calendar
    + ['event_name_1', 'event_type_1', 'event_name_2', 'event_type_2']  # events
    + ['snap_CA', 'snap_TX', 'snap_WI']                               # SNAP flags per state
)
x_cont_cols = ['sell_price']

# Day counts: labelled history, then the two 28-day forecast horizons.
num_train_val_days = 1913
num_test1_days = num_test2_days = 28

# Window lengths (days) for the module-level defaults.
src_len = 28
tgt_len = num_test1_days

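#### TODO
 - normalize y (demand is currently only min-max scaled in the Merge step)
 - sell_price is 0 wherever no price is listed (missing prices are filled with 0.0); handle this properly.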

In [None]:
!ls $data_dir

#### Sales

In [None]:
%%time
sales = pd.read_csv(data_dir/'sales_train_validation.csv')
print(f'sales.shape: {sales.shape}')
cat_cols = ['item_id', 'dept_id', 'cat_id', 'store_id', 'state_id']

# encode cat cols
encoders = {}
for col in cat_cols:
    encoder =  OrdinalEncoder()
    sales[[col]] = encoder.fit_transform(sales[[col]])
    sales[col] = sales[col].astype(np.long)
    encoders[col] = encoder
    
# change day column names to just day number
train_day_cols = {col: col.split('_')[1] for col in sales.columns if col.startswith('d_')}
sales.rename(columns=train_day_cols, inplace=True)

#### Add placeholder (zero-demand) columns for the test days

In [None]:
# Placeholder columns (demand 0) for the test1 + test2 days that follow the
# training days. The horizon comes from the config constants instead of a
# hard-coded 56, so the two stay in sync.
test_day_cols = [str(num_train_val_days + 1 + o)
                 for o in range(num_test1_days + num_test2_days)]
# One assign instead of a per-column insert loop; re-running the cell simply
# overwrites the same columns.
sales = sales.assign(**dict.fromkeys(test_day_cols, 0))
print(sales.shape)

In [None]:
# Last two rows, to confirm the placeholder test-day columns were appended.
sales.iloc[-2:]

In [None]:
# Peek at the expected submission format.
sample = pd.read_csv(data_dir / 'sample_submission.csv')
sample.iloc[-2:]

In [None]:
# Total timeline length (train + test days) and number of item/store series.
num_days = len(train_day_cols) + len(test_day_cols)
num_stores, num_items = (sales[c].nunique() for c in ('store_id', 'item_id'))
print('total days : ', num_days)
print('num store_items - ', num_stores * num_items)

In [None]:
# Number of distinct (encoded) items.
sales.item_id.nunique()

#### Calendar

In [None]:
calendar = pd.read_csv(data_dir/'calendar.csv').rename(columns={'d': 'day'})

cat_cal_cols = ['wm_yr_wk', 'weekday', 'wday', 'month', 'year',
                'event_name_1', 'event_type_1', 'event_name_2', 'event_type_2',
                'snap_CA', 'snap_TX', 'snap_WI']
# ignore_cal_cols = ['wm_yr_wk']

for col in cat_cal_cols:
    # Impute missing values with a sentinel matching the column's dtype family.
    # Every dtype gets an explicit branch so fill_value is never undefined or
    # silently carried over from the previous column.
    series = calendar[col]
    if pd.api.types.is_numeric_dtype(series):
        fill_value = -1
    elif pd.api.types.is_object_dtype(series) or pd.api.types.is_string_dtype(series):
        fill_value = 'abcxyz'
    else:
        raise TypeError(f'unsupported dtype {series.dtype} for calendar column {col!r}')
    calendar[[col]] = SimpleImputer(strategy='constant', fill_value=fill_value).fit_transform(calendar[[col]])

    # Ordinal-encode, reusing an encoder already fitted for this column name if present.
    if col not in encoders:
        encoders[col] = OrdinalEncoder().fit(calendar[[col]])
    calendar[[col]] = encoders[col].transform(calendar[[col]])
    calendar[col] = calendar[col].astype(np.int64)

# Day labels are 'd_<n>'; keep only the integer day number (vectorized split).
calendar['day'] = calendar['day'].str.split('_').str[1].astype(np.int64)

calendar.tail(2)

#### Prices

In [None]:
%%time
prices = pd.read_csv(data_dir/'sell_prices.csv')
for col in ['store_id', 'item_id', 'wm_yr_wk']:
    prices[[col]] = encoders[col].transform(prices[[col]])
    prices[col] = prices[col].astype(np.long)

In [None]:
# Two rows from the most recent week present in the price table.
prices.sort_values(by='wm_yr_wk', ascending=False).iloc[:2]

### Merge

In [None]:
%%time
sales2 = pd.melt(sales, id_vars=['id', 'item_id', 'dept_id', 'cat_id', 'store_id', 'state_id'], 
                                       var_name='day', value_name='demand')
sales2['day'] = sales2['day'].astype(np.long)

sales2.sort_values('day', inplace=True)
calendar.sort_values('day', inplace=True)

sales2 = sales2.merge(calendar, on='day', how='left')
sales2 = sales2.merge(prices, on=['store_id', 'item_id', 'wm_yr_wk'], how='left')
sales2['sell_price'] = sales2['sell_price'].astype(np.float32)
sales2['sell_price'] = sales2['sell_price'].fillna(0.0)

sales2.sort_values(['item_id', 'store_id','day'], inplace=True)

# scale continuous columns
scalers = {}
for col in ['sell_price','demand']:
    scaler = MinMaxScaler()
    sales2[[col]] = scaler.fit_transform(sales2[[col]])
    scalers[col] = scaler

In [None]:
# Cache the combined frame and the fitted preprocessing objects next to it.
sales2.to_parquet('combined.pq')
Path('encoders.pkl').write_bytes(pickle.dumps(encoders))
Path('scalers.pkl').write_bytes(pickle.dumps(scalers))

### Creating tensors

In [None]:
%%time
sales2 = pd.read_parquet('combined.pq')
print(sales2.shape)
sales2.columns

In [None]:
%%time
x = torch.tensor(sales2[x_cat_cols + x_cont_cols].values)
y = torch.tensor(sales2['demand'].values)

In [None]:
# dtypes of the model input columns (mixed int/float upcasts in .values).
sales2.loc[:, x_cat_cols + x_cont_cols].dtypes

In [None]:
%%time

# Embedding-size rule of thumb from fastai v2 (emb_sz_rule).
def get_emb_size(nunique):
    """Return an embedding width for a categorical column with `nunique` levels.

    Grows as 1.6 * nunique**0.56 (rounded) and is capped at 600.
    """
    size = round(1.6 * nunique ** 0.56)
    return 600 if size > 600 else size

# (num_embeddings, embedding_dim) per categorical column, in x_cat_cols order.
# nn.Embedding needs num_embeddings > the largest code, so size the table with
# max()+1. It equals nunique() only when every code 0..k-1 actually occurs.
emb_sizes = [(int(sales2[col].max()) + 1, get_emb_size(sales2[col].nunique()))
             for col in x_cat_cols]

In [None]:
# Persist embedding sizes; SalesModel reads them back from this file.
Path('emb_sz.pkl').write_bytes(pickle.dumps(emb_sizes))

In [None]:
# group_size = num_items * num_stores
# group_size

In [None]:
%%time
num_features = x.size(1)
x1 = x.view(-1, num_days, num_features).refine_names('item_store', 'day','features')\
        .align_to('day','item_store','features').contiguous()

y1 = y.view(-1, num_days).refine_names('item_store', 'day')\
    .align_to('day', 'item_store').contiguous()

print(f'x1.shape - {x1.shape} y1.shape - {y1.shape}')

In [None]:
%%time
torch.save(x1.rename(None), 'x.pt')
torch.save(y1.rename(None), 'y.pt')


### Training

In [21]:
class M5DataSet(Dataset):
    """Sliding-window (source, target) dataset over day-major sales tensors.

    ``x`` is indexed as ``x[day, item_store, feature]`` and ``y`` as
    ``y[day, item_store]``. A window starting at day ``idx`` uses
    ``[idx, idx+src_len)`` as source and the next ``tgt_len`` days as target.

    Fixed windows, counted back from the last day N = x.size(0):
      * test2: target is the final ``tgt_len`` days.
      * test1: target is the ``tgt_len`` days before test2's target.
      * val:   target is the ``tgt_len`` days before test1's target.
    'train' yields one window per start index in ``[0, val_idx - src_len)``.

    :param x: feature tensor (day, item_store, feature)
    :param y: target tensor (day, item_store)
    :param src_len: number of source days per window
    :param tgt_len: number of target days per window
    :param bsz: item_store series sampled per window for 'train'/'val'
    :param dstype: one of 'train', 'val', 'test1', 'test2'
    :return (per item): x_src, x_tgt, y_src, y_tgt, item_store_mask
    """

    def __init__(self, x, y, src_len, tgt_len, bsz, dstype='train'):
        assert dstype in ['train', 'test1', 'test2', 'val']
        self.x = x
        self.y = y
        self.src_len = src_len
        self.tgt_len = tgt_len
        self.bsz = bsz
        self.dstype = dstype

        self.val_days = self.src_len + self.tgt_len
        self.test_days = self.tgt_len * 2
        self.val_idx = self.x.size(0) - (self.val_days + self.test_days)
        self.test1_idx = self.x.size(0) - (self.src_len + self.test_days)
        # test2's target window is the last tgt_len days.
        self.test2_idx = self.x.size(0) - (self.src_len + self.tgt_len)
        print(f'val index - {self.val_idx}. test1_idx - {self.test1_idx}', )

    def __len__(self):
        if self.dstype == 'train':
            return self.x.size(0) - (self.src_len + self.val_days + self.test_days)
        # 'val', 'test1' and 'test2' are each a single fixed window.
        return 1

    def __getitem__(self, idx):
        if self.dstype in ('train', 'val'):
            # There are 30490 item_stores, which may not all fit at once, so sample
            # bsz of them (independently, i.e. with replacement).
            # NOTE(review): 'val' also re-samples on every call, so the validation
            # loss is computed on a different random subset each time.
            item_store_mask = list(np.random.randint(0, self.x.size(1), (self.bsz,)))
            if self.dstype == 'val':
                idx = self.val_idx
        elif self.dstype == 'test1':
            item_store_mask = list(np.arange(self.x.size(1)))
            idx = self.test1_idx
            print('test1 index - ', idx)
        else:  # 'test2'
            item_store_mask = list(np.arange(self.x.size(1)))
            idx = self.test2_idx

        x = self.x.rename(None)
        y = self.y.rename(None)
        src = slice(idx, idx + self.src_len)
        tgt = slice(idx + self.src_len, idx + self.src_len + self.tgt_len)

        x_src = x[src, item_store_mask, :]
        x_tgt = x[tgt, item_store_mask, :]
        y_src = y[src, item_store_mask]
        y_tgt = y[tgt, item_store_mask]
        return x_src, x_tgt, y_src, y_tgt, item_store_mask

# train_ds = M5DataSet(x1, y1, src_len, tgt_len, 200)
# train_dl = DataLoader(train_ds, batch_size=1, shuffle=True, pin_memory=True)

In [22]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence, then apply dropout.

    The buffer ``pe`` has shape (max_len, 1, d_model) and is sliced by
    ``x.size(0)``, so inputs are laid out as (seq_len, batch, d_model).
    Even feature columns hold sines and odd columns cosines; odd ``d_model``
    is supported.

    :param d_model: feature width of the input
    :param dropout: dropout probability applied after adding the encoding
    :param max_len: longest sequence that can be encoded
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With odd d_model there is one fewer odd column than even column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(1)  # (max_len, 1, d_model), broadcasts over batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return dropout(x + pe[:seq_len]) for x of shape (seq_len, batch, d_model)."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
    
# def insert_embedding(inp, dim, index, emb):
#     """
#     Replace columns with their embeddings. Works only with 2-d tensors.
#     TODO - make it work for multi-dim tensors

#     :param inp: tensor of two or more dimensions
#     :param dim: dimension along which tensor should be expanded by inserting the embedding
#     :param i: index of tensor along dim which is to be embedded
#     :param emb: Embedding of shape [v,d], where v vocab_size and d is embedding dimension
#     :return: 
#     """
#     # create a slice of the data to be replaced with embedding. 
#     s = inp.index_select(dim, torch.tensor([index])).squeeze(dim)
#     embedded = emb(s.type(torch.long))
    
#     first_indices = torch.arange(0,index)
#     last_indices = torch.arange(index+1,inp.size(dim))

#     return torch.cat([inp.index_select(dim, first_indices), embedded.type(inp.dtype), inp.index_select(dim, last_indices)], axis=dim)

In [9]:
%%time
gx = torch.load('x.pt')
gy = torch.load('y.pt')
print(f'gx.shape - {gx.shape} gy.shape - {gy.shape}')

gx.shape - torch.Size([1969, 30490, 17]) gy.shape - torch.Size([1969, 30490])
CPU times: user 0 ns, sys: 8.11 s, total: 8.11 s
Wall time: 8.09 s


In [23]:
class SalesModel(LightningModule):
    """Transformer encoder-decoder that forecasts scaled demand for ``tgt_len`` days.

    Source days carry features plus the known demand; target days carry
    features only. Categorical columns are embedded, concatenated with the
    continuous columns, zero-padded to ``hparams.ninp`` and passed through
    ``nn.Transformer``. A small MLP with a sigmoid head emits one value per
    (target day, item_store).

    Training data comes from the notebook globals ``gx``/``gy`` and embedding
    sizes from ``emb_sz.pkl``; both must exist before construction.

    :param hparams: argparse namespace built with ``add_model_specifi_args``
    """
    def __init__(self, hparams):
        super(SalesModel, self).__init__()
        self.hparams = hparams
        self.x_cat_cols = x_cat_cols
        self.x_cont_cols = x_cont_cols
        self.pos_encoder = PositionalEncoding(hparams.ninp, hparams.dropout)

        encoder_layers = nn.TransformerEncoderLayer(hparams.ninp, hparams.nhead, hparams.nhid, hparams.dropout)
        decoder_layers = nn.TransformerDecoderLayer(hparams.ninp, hparams.nhead, hparams.nhid, hparams.dropout)
        self.encoder = nn.TransformerEncoder(encoder_layers, hparams.nlayers)
        self.decoder = nn.TransformerDecoder(decoder_layers, hparams.nlayers)

        # nn.Transformer wraps the same encoder/decoder modules defined above.
        self.transformer = nn.Transformer(d_model=hparams.ninp,
                                          custom_encoder=self.encoder,
                                          custom_decoder=self.decoder)
        self.criterion = nn.MSELoss()
        self.lin1 = nn.Linear(hparams.ninp, 50)
        self.lin2 = nn.Linear(50, 1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

        print('reading data', flush=True)
        self.x = gx
        self.y = gy

        # One (num_embeddings, embedding_dim) pair per categorical column, in x_cat_cols order.
        with open('emb_sz.pkl', 'rb') as f:
            emb_szs = pickle.load(f)
        print(f'emb_szs - {emb_szs}')

        self.embs = nn.ModuleList([nn.Embedding(e[0], e[1]) for e in emb_szs])
#         self.init_weights()

    def init_weights(self):
        """Uniformly initialise the output MLP weights and zero its biases (not called by default).

        The transformer decoder has no ``weight``/``bias`` attributes of its own,
        so the output projections ``lin1``/``lin2`` are initialised instead.
        """
        initrange = 0.1
        for layer in (self.lin1, self.lin2):
            layer.bias.data.zero_()
            layer.weight.data.uniform_(-initrange, initrange)

    @staticmethod
    def add_model_specifi_args(parent_parser):
        """Return a parser extending ``parent_parser`` with this model's hyper-parameters."""
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--bsz', default=20, type=int, help='batch_size', )
        parser.add_argument('--src-len', default=90, type=int, help='source length')
        parser.add_argument('--tgt-len', default=28, type=int, help='target length')
        parser.add_argument('--ninp', default=320, type=int, help='expected features in the input')
        parser.add_argument('--nhead', default=4, type=int, help='number of attention heads')
        parser.add_argument('--nhid', default=256, type=int, help='dimesion of feed-forward network model')
        parser.add_argument('--nlayers', default=2, type=int, help='number of encoder layers')
        parser.add_argument('--dropout', default=0.2, type=float, help='dropout')
        parser.add_argument('--lr', default=0.0001, type=float, help='Adam learning rate')

        # they are not hyper params, but adding them as pytorch lightening can save them
        parser.add_argument('--num-cat-cols', default=len(x_cat_cols), type=int, help='number of categorical columns')
        parser.add_argument('--num-cont-cols', default=len(x_cont_cols), type=int, help='number of numeric columns')
        return parser

    # Conventional Lightning spelling; the misspelled name above is kept for existing callers.
    add_model_specific_args = add_model_specifi_args

    def prepare_data(self):
        pass

    def train_dataloader(self):
        """Random windows over the training days; one window (of bsz series) per batch."""
        train_ds = M5DataSet(self.x, self.y, self.hparams.src_len, self.hparams.tgt_len, self.hparams.bsz, dstype='train')
        print(f'train_ds.length - {len(train_ds)}')
        train_dl = DataLoader(train_ds, batch_size=1, shuffle=True, pin_memory=True)
        return train_dl

    def val_dataloader(self):
        """The single validation window."""
        val_ds = M5DataSet(self.x, self.y, self.hparams.src_len, self.hparams.tgt_len, self.hparams.bsz, dstype='val')
        print(f'val.length - {len(val_ds)}')
        val_dl = DataLoader(val_ds, batch_size=1, shuffle=False, pin_memory=True)
        return val_dl

    def emb_lookups(self, xb, yb=None):
        """Embed categorical columns and build the ``ninp``-wide model input.

        :param xb: 3-D feature tensor indexed [seq, batch, feature]; the first
            ``num_cat_cols`` features are integer codes, the rest continuous
        :param yb: optional [seq, batch] demand, appended as one extra feature
        :return: float tensor [seq, batch, ninp], zero-padded on the last dim
        """
        num_cat = self.hparams.num_cat_cols
        xb_cat = torch.cat(
            [self.embs[i](xb[:, :, i].type(torch.long)) for i in range(num_cat)], dim=2)
        parts = [xb_cat, xb[:, :, num_cat:].type(xb_cat.dtype)]
        if yb is not None:
            parts.append(yb.unsqueeze(2).type(xb_cat.dtype))
        xb = torch.cat(parts, dim=2)

        # Zero-pad the feature dimension up to ninp (the transformer's d_model).
        dim3_shortfall = self.hparams.ninp - xb.size(2)
        assert dim3_shortfall >= 0, f'input width {xb.size(2)} exceeds ninp={self.hparams.ninp}'
        return F.pad(xb, (0, dim3_shortfall), value=0.0)

    def forward(self, x_src, y_src, x_tgt):
        """Predict scaled demand in (0, 1) for every target day/item_store.

        NOTE(review): positional encoding is applied to the source only; the
        decoder input x_tgt gets none, so target days differ only through their
        calendar features. No masks are passed to the transformer (target inputs
        carry no demand, so this does not leak labels).
        """
        x_src = self.emb_lookups(x_src, y_src)
        x_tgt = self.emb_lookups(x_tgt)

        x_src = self.pos_encoder(x_src)
        out = self.transformer(x_src, x_tgt)
        out = self.relu(self.lin1(out))
        out = self.sigmoid(self.lin2(out))

        return out

    def compute_loss(self, batch, batch_idx):
        """MSE between predicted and true target demand for one DataLoader batch.

        The DataLoader uses batch_size=1, so the leading batch dim is squeezed away.
        :return: (loss, y_tgt, yhat_tgt)
        """
        x_src, x_tgt, y_src, y_tgt, item_store_mask = batch
        x_src = x_src.squeeze(0)
        x_tgt = x_tgt.squeeze(0)
        y_src = y_src.squeeze(0)
        y_tgt = y_tgt.squeeze(0)

        yhat_tgt = self(x_src, y_src, x_tgt)
        loss = self.criterion(yhat_tgt.reshape(-1).type(torch.float32), y_tgt.reshape(-1).type(torch.float32))
        return loss, y_tgt, yhat_tgt

    def training_step(self, batch, batch_idx):
        loss, y_tgt, yhat_tgt = self.compute_loss(batch, batch_idx)
        if batch_idx % 10 == 0:
            print(f'{batch_idx} train loss: {loss}  yhat_tgt.sum: {yhat_tgt.sum().item()}  y_tgt.sum: {y_tgt.sum().item()}')

        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        loss, y_tgt, yhat_tgt = self.compute_loss(batch, batch_idx)
        if batch_idx % 10 == 0:
            print(f'{batch_idx} val loss: {loss}  yhat_tgt.sum: {yhat_tgt.sum().item()}  y_tgt.sum: {y_tgt.sum().item()}')

        return {'val_loss': loss}

    def configure_optimizers(self):
        """Adam over all parameters with lr from hparams (1e-4 when absent); no LR scheduler.

        To add scheduling, return e.g. ``([optimizer], [{'scheduler':
        ReduceLROnPlateau(optimizer, patience=10), 'monitor': 'val_loss'}])`` in
        the format expected by the installed Lightning version.
        """
        lr = getattr(self.hparams, 'lr', 0.0001)
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        return optimizer

#     def test_dataloader(self):
#         test_ds = M5DataSet(self.x, self.y, self.hparams.src_len, self.hparams.tgt_len, self.hparams.bsz, dstype='test1')
#         test_dl = DataLoader(test_ds, batch_size=1, shuffle=False)
#         return test_dl

In [24]:
parser = ArgumentParser()
parser = SalesModel.add_model_specifi_args(parser)
hparams = parser.parse_args('--bsz 1000 --ninp 320 --nhid 512 --nlayers 1'.split())

checkpoint_callback = ModelCheckpoint(
    filepath='models/weights.ckpt',
    verbose=True
)

model = SalesModel(hparams)

wandb_logger = WandbLogger(name='achinta', project='kaggle-m5-forecasting-accuracy')
# The logger and checkpoint callback only take effect when passed to the Trainer.
trainer = Trainer(gpus=1, max_epochs=1, auto_lr_find=False, val_check_interval=10,
                  logger=wandb_logger, checkpoint_callback=checkpoint_callback)
trainer.fit(model)
trainer.save_checkpoint('models/weights_v1.1.ckpt')

reading data


GPU available: True, used: True
No environment variable for node rank defined. Set as 0.
CUDA_VISIBLE_DEVICES: [0]

   | Name                                     | Type                    | Params
---------------------------------------------------------------------------------
0  | pos_encoder                              | PositionalEncoding      | 0     
1  | pos_encoder.dropout                      | Dropout                 | 0     
2  | encoder                                  | TransformerEncoder      | 740 K 
3  | encoder.layers                           | ModuleList              | 740 K 
4  | encoder.layers.0                         | TransformerEncoderLayer | 740 K 
5  | encoder.layers.0.self_attn               | MultiheadAttention      | 410 K 
6  | encoder.layers.0.self_attn.out_proj      | Linear                  | 102 K 
7  | encoder.layers.0.linear1                 | Linear                  | 164 K 
8  | encoder.layers.0.dropout                 | Dropout                 |

emb_szs - [(3049, 143), (7, 5), (3, 3), (10, 6), (3, 3), (7, 5), (7, 5), (12, 6), (6, 4), (31, 11), (5, 4), (5, 4), (3, 3), (2, 2), (2, 2), (2, 2)]
val index - 1795. test1_idx - 1823
val.length - 1




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

0 val loss: 0.26271411776542664  yhat_tgt.sum: 14353.744140625  y_tgt.sum: 51.80602883355177
val index - 1795. test1_idx - 1823
train_ds.length - 1705
val index - 1795. test1_idx - 1823
val.length - 1




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

0 train loss: 0.25655174255371094  yhat_tgt.sum: 14171.443359375  y_tgt.sum: 38.73525557011796


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 0.020812438800930977  yhat_tgt.sum: 4064.185302734375  y_tgt.sum: 50.293577981651374
10 train loss: 0.031039496883749962  yhat_tgt.sum: 4929.40234375  y_tgt.sum: 41.351245085190044




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 0.0023068662267178297  yhat_tgt.sum: 1381.865234375  y_tgt.sum: 46.40235910878113
20 train loss: 0.004795339424163103  yhat_tgt.sum: 1961.9674072265625  y_tgt.sum: 41.048492791612055


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 0.0007366426871158183  yhat_tgt.sum: 798.4913330078125  y_tgt.sum: 49.625163826998694
30 train loss: 0.0018610068364068866  yhat_tgt.sum: 1226.147705078125  y_tgt.sum: 30.107470511140235


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 0.0004237604734953493  yhat_tgt.sum: 608.7508544921875  y_tgt.sum: 49.239842726081264
40 train loss: 0.0011050391476601362  yhat_tgt.sum: 952.525390625  y_tgt.sum: 42.0694626474443


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 0.0003087493241764605  yhat_tgt.sum: 520.87744140625  y_tgt.sum: 44.88859764089122
50 train loss: 0.0008699744939804077  yhat_tgt.sum: 848.742431640625  y_tgt.sum: 40.715596330275226


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 0.00024085596669465303  yhat_tgt.sum: 467.3934326171875  y_tgt.sum: 44.92791612057667
60 train loss: 0.0007243193103931844  yhat_tgt.sum: 774.5672607421875  y_tgt.sum: 49.72346002621232


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 0.0002050382026936859  yhat_tgt.sum: 431.087890625  y_tgt.sum: 50.66841415465269
70 train loss: 0.0005767869297415018  yhat_tgt.sum: 700.76171875  y_tgt.sum: 44.381389252948885


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 0.00020185943867545575  yhat_tgt.sum: 402.5251159667969  y_tgt.sum: 60.17824377457406
80 train loss: 0.0005282160709612072  yhat_tgt.sum: 667.321044921875  y_tgt.sum: 44.43643512450852


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 0.000165882651344873  yhat_tgt.sum: 377.55322265625  y_tgt.sum: 53.48099606815203
90 train loss: 0.0004489048442337662  yhat_tgt.sum: 617.5247192382812  y_tgt.sum: 38.28047182175622


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 0.00014152024232316762  yhat_tgt.sum: 357.5767517089844  y_tgt.sum: 50.25688073394495
100 train loss: 0.0004265518509782851  yhat_tgt.sum: 603.6793212890625  y_tgt.sum: 40.37876802096986


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 0.00012671730655711144  yhat_tgt.sum: 338.90606689453125  y_tgt.sum: 49.7994757536042
110 train loss: 0.00037697056541219354  yhat_tgt.sum: 566.0552978515625  y_tgt.sum: 45.30799475753604


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 0.00013342410966288298  yhat_tgt.sum: 322.3496398925781  y_tgt.sum: 55.387942332896465
120 train loss: 0.00038054180913604796  yhat_tgt.sum: 545.1176147460938  y_tgt.sum: 49.456094364351245


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 0.00010444834333611652  yhat_tgt.sum: 305.50091552734375  y_tgt.sum: 54.18348623853211
130 train loss: 0.00032779324101284146  yhat_tgt.sum: 528.114501953125  y_tgt.sum: 36.49410222804718


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 8.962760330177844e-05  yhat_tgt.sum: 292.56817626953125  y_tgt.sum: 45.11926605504587
140 train loss: 0.0002886794682126492  yhat_tgt.sum: 500.476806640625  y_tgt.sum: 39.62254259501966


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 8.762708603171632e-05  yhat_tgt.sum: 279.4927062988281  y_tgt.sum: 50.52031454783748
150 train loss: 0.0002693379356060177  yhat_tgt.sum: 481.9508056640625  y_tgt.sum: 39.782437745740495


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 8.409666043007746e-05  yhat_tgt.sum: 267.85205078125  y_tgt.sum: 51.449541284403665
160 train loss: 0.0002574159298092127  yhat_tgt.sum: 466.95538330078125  y_tgt.sum: 45.44954128440367


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 6.998189928708598e-05  yhat_tgt.sum: 257.2681884765625  y_tgt.sum: 50.186107470511146
170 train loss: 0.00021867573377676308  yhat_tgt.sum: 440.9133605957031  y_tgt.sum: 47.30668414154653


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 6.787392339901999e-05  yhat_tgt.sum: 247.61090087890625  y_tgt.sum: 47.693315858453474
180 train loss: 0.00021629175171256065  yhat_tgt.sum: 433.80487060546875  y_tgt.sum: 39.528178243774576


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 7.245261076604947e-05  yhat_tgt.sum: 238.2282257080078  y_tgt.sum: 53.435124508519
190 train loss: 0.0002047921298071742  yhat_tgt.sum: 424.4147033691406  y_tgt.sum: 49.22804718217563


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 6.080671664676629e-05  yhat_tgt.sum: 230.5347900390625  y_tgt.sum: 48.95806028833552
200 train loss: 0.00018666622054297477  yhat_tgt.sum: 398.3976745605469  y_tgt.sum: 47.65923984272608


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 5.642247560899705e-05  yhat_tgt.sum: 222.0377197265625  y_tgt.sum: 47.80602883355177
210 train loss: 0.00017387063417118043  yhat_tgt.sum: 392.72967529296875  y_tgt.sum: 46.77588466579292


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 5.443954432848841e-05  yhat_tgt.sum: 215.08578491210938  y_tgt.sum: 51.85714285714286
220 train loss: 0.00016446106019429862  yhat_tgt.sum: 380.64495849609375  y_tgt.sum: 38.646133682830936


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 4.9550464609637856e-05  yhat_tgt.sum: 207.67047119140625  y_tgt.sum: 49.85452162516383
230 train loss: 0.00016546099504921585  yhat_tgt.sum: 372.8734436035156  y_tgt.sum: 46.76408912188729


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 4.930062641506083e-05  yhat_tgt.sum: 201.90377807617188  y_tgt.sum: 50.05373525557012
240 train loss: 0.00048049972974695265  yhat_tgt.sum: 362.03302001953125  y_tgt.sum: 64.88073394495413


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 5.288937609293498e-05  yhat_tgt.sum: 196.11634826660156  y_tgt.sum: 50.58584534731324
250 train loss: 0.00014359246415551752  yhat_tgt.sum: 348.73504638671875  y_tgt.sum: 36.83748361730014


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 4.4671978685073555e-05  yhat_tgt.sum: 190.36927795410156  y_tgt.sum: 53.68283093053736
260 train loss: 0.0001507000415585935  yhat_tgt.sum: 345.660888671875  y_tgt.sum: 30.078636959370904


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 4.197670932626352e-05  yhat_tgt.sum: 185.30841064453125  y_tgt.sum: 50.53342070773264
270 train loss: 0.00013383536133915186  yhat_tgt.sum: 335.5980224609375  y_tgt.sum: 35.92005242463958


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.656612534541637e-05  yhat_tgt.sum: 179.9285125732422  y_tgt.sum: 46.638269986893846
280 train loss: 0.00011901111429324374  yhat_tgt.sum: 325.4118957519531  y_tgt.sum: 42.60943643512451


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 5.71472046431154e-05  yhat_tgt.sum: 175.09161376953125  y_tgt.sum: 58.67365661861075
290 train loss: 0.00013105083780828863  yhat_tgt.sum: 324.3260192871094  y_tgt.sum: 33.00917431192661


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.397838008822873e-05  yhat_tgt.sum: 170.71951293945312  y_tgt.sum: 48.824377457404985
300 train loss: 0.00011967614409513772  yhat_tgt.sum: 312.38787841796875  y_tgt.sum: 41.39449541284404


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.194619785062969e-05  yhat_tgt.sum: 166.1361083984375  y_tgt.sum: 47.68020969855832
310 train loss: 0.00010508191189728677  yhat_tgt.sum: 302.64410400390625  y_tgt.sum: 43.92398427260812


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 4.2507217585807666e-05  yhat_tgt.sum: 162.1432647705078  y_tgt.sum: 52.26605504587156
320 train loss: 9.747068543219939e-05  yhat_tgt.sum: 291.8116455078125  y_tgt.sum: 35.80733944954129


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.5971999750472605e-05  yhat_tgt.sum: 158.4427490234375  y_tgt.sum: 53.8872870249017
330 train loss: 9.21335958992131e-05  yhat_tgt.sum: 282.128662109375  y_tgt.sum: 45.328964613368285


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 4.1241644794354215e-05  yhat_tgt.sum: 154.39230346679688  y_tgt.sum: 57.08125819134993
340 train loss: 9.135322761721909e-05  yhat_tgt.sum: 283.3459777832031  y_tgt.sum: 47.33289646133683


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.6814421036979184e-05  yhat_tgt.sum: 151.18490600585938  y_tgt.sum: 53.10353866317169
350 train loss: 9.05115157365799e-05  yhat_tgt.sum: 276.3083190917969  y_tgt.sum: 43.809960681520316


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 5.484638313646428e-05  yhat_tgt.sum: 147.84725952148438  y_tgt.sum: 60.48230668414154
360 train loss: 8.971226634457707e-05  yhat_tgt.sum: 276.4155578613281  y_tgt.sum: 37.699868938401046


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.763006523309741e-05  yhat_tgt.sum: 144.6981201171875  y_tgt.sum: 45.678899082568805
370 train loss: 8.359264029422775e-05  yhat_tgt.sum: 266.41851806640625  y_tgt.sum: 49.30144167758847


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.1017443689052016e-05  yhat_tgt.sum: 141.7071990966797  y_tgt.sum: 53.861074705111406
380 train loss: 8.346510003320873e-05  yhat_tgt.sum: 263.9418640136719  y_tgt.sum: 39.16382699868939


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.578275052655954e-05  yhat_tgt.sum: 139.051025390625  y_tgt.sum: 47.211009174311926
390 train loss: 8.049289317568764e-05  yhat_tgt.sum: 261.0591735839844  y_tgt.sum: 41.17562254259502


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.77045180357527e-05  yhat_tgt.sum: 135.854248046875  y_tgt.sum: 52.470511140235914
400 train loss: 7.629773608641699e-05  yhat_tgt.sum: 252.43927001953125  y_tgt.sum: 42.313237221494106


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.5867000406142324e-05  yhat_tgt.sum: 133.47235107421875  y_tgt.sum: 51.986893840104855
410 train loss: 7.799887680448592e-05  yhat_tgt.sum: 255.11119079589844  y_tgt.sum: 28.32241153342071


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.427716026431881e-05  yhat_tgt.sum: 130.7816162109375  y_tgt.sum: 47.42332896461337
420 train loss: 9.25803542486392e-05  yhat_tgt.sum: 249.43515014648438  y_tgt.sum: 41.6959370904325


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.317085429443978e-05  yhat_tgt.sum: 127.93685913085938  y_tgt.sum: 54.150720838794236
430 train loss: 0.00012673484161496162  yhat_tgt.sum: 247.01039123535156  y_tgt.sum: 28.990825688073397


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.6231229639961384e-05  yhat_tgt.sum: 126.04931640625  y_tgt.sum: 48.92398427260812
440 train loss: 6.082807885832153e-05  yhat_tgt.sum: 233.709716796875  y_tgt.sum: 40.129750982961994


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.2790014554630034e-05  yhat_tgt.sum: 123.64067840576172  y_tgt.sum: 52.94495412844037
450 train loss: 8.073150820564479e-05  yhat_tgt.sum: 231.96144104003906  y_tgt.sum: 41.61336828309305


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.2596163034904748e-05  yhat_tgt.sum: 121.51522064208984  y_tgt.sum: 48.06422018348624
460 train loss: 7.118976645870134e-05  yhat_tgt.sum: 228.82069396972656  y_tgt.sum: 41.996068152031455


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.423589441808872e-05  yhat_tgt.sum: 119.59880065917969  y_tgt.sum: 53.67234600262124
470 train loss: 5.952673382125795e-05  yhat_tgt.sum: 222.66946411132812  y_tgt.sum: 38.758846657929226


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 4.238115798216313e-05  yhat_tgt.sum: 117.21308135986328  y_tgt.sum: 59.48623853211009
480 train loss: 7.39455281291157e-05  yhat_tgt.sum: 221.83145141601562  y_tgt.sum: 50.97509829619922


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.900398587575182e-05  yhat_tgt.sum: 115.58586883544922  y_tgt.sum: 55.49279161205767
490 train loss: 6.731404573656619e-05  yhat_tgt.sum: 217.7392578125  y_tgt.sum: 48.71035386631716


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.459658389852848e-05  yhat_tgt.sum: 113.39459228515625  y_tgt.sum: 51.8348623853211
500 train loss: 9.130610851570964e-05  yhat_tgt.sum: 215.45510864257812  y_tgt.sum: 48.23197903014417


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.2341730073094368e-05  yhat_tgt.sum: 111.73346710205078  y_tgt.sum: 47.92005242463958
510 train loss: 6.008268246660009e-05  yhat_tgt.sum: 216.22808837890625  y_tgt.sum: 36.861074705111406


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.604643304948695e-05  yhat_tgt.sum: 109.80258178710938  y_tgt.sum: 53.018348623853214
520 train loss: 5.75059384573251e-05  yhat_tgt.sum: 205.926025390625  y_tgt.sum: 42.906946264744434


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.2987562008202076e-05  yhat_tgt.sum: 108.62513732910156  y_tgt.sum: 54.8086500655308
530 train loss: 6.204485544003546e-05  yhat_tgt.sum: 211.57000732421875  y_tgt.sum: 39.093053735255566


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.155131005565636e-05  yhat_tgt.sum: 106.45449829101562  y_tgt.sum: 51.174311926605505
540 train loss: 5.068423706688918e-05  yhat_tgt.sum: 205.07745361328125  y_tgt.sum: 37.96199213630406


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0581168428179808e-05  yhat_tgt.sum: 104.96969604492188  y_tgt.sum: 47.914809960681524
550 train loss: 4.894907397101633e-05  yhat_tgt.sum: 198.74807739257812  y_tgt.sum: 37.97640891218873


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0282443074393086e-05  yhat_tgt.sum: 103.42233276367188  y_tgt.sum: 46.74311926605505
560 train loss: 5.595581023953855e-05  yhat_tgt.sum: 197.92100524902344  y_tgt.sum: 36.80602883355177


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 4.335512130637653e-05  yhat_tgt.sum: 101.94189453125  y_tgt.sum: 55.5124508519004
570 train loss: 5.1466384320519865e-05  yhat_tgt.sum: 195.70657348632812  y_tgt.sum: 43.77719528178244


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.2176653146743774e-05  yhat_tgt.sum: 100.49835205078125  y_tgt.sum: 48.66448230668414
580 train loss: 4.3871412344742566e-05  yhat_tgt.sum: 190.73968505859375  y_tgt.sum: 45.87418086500655


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.7603775631869212e-05  yhat_tgt.sum: 99.04144287109375  y_tgt.sum: 48.88859764089122
590 train loss: 4.7449466364923865e-05  yhat_tgt.sum: 187.7623748779297  y_tgt.sum: 39.69462647444299


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.240235153294634e-05  yhat_tgt.sum: 97.68538665771484  y_tgt.sum: 50.31847968545216
600 train loss: 5.332756700227037e-05  yhat_tgt.sum: 186.45518493652344  y_tgt.sum: 42.937090432503275


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.296815702924505e-05  yhat_tgt.sum: 96.24080657958984  y_tgt.sum: 59.083879423328966
610 train loss: 5.454504571389407e-05  yhat_tgt.sum: 189.0104522705078  y_tgt.sum: 32.833551769331585


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0840236174990423e-05  yhat_tgt.sum: 94.74868774414062  y_tgt.sum: 47.43512450851901
620 train loss: 4.2373372707515955e-05  yhat_tgt.sum: 184.39047241210938  y_tgt.sum: 40.085190039318476


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 5.2121704356977716e-05  yhat_tgt.sum: 93.74612426757812  y_tgt.sum: 61.1821756225426
630 train loss: 0.0002848795847967267  yhat_tgt.sum: 181.2539825439453  y_tgt.sum: 54.50589777195282


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0576391762006097e-05  yhat_tgt.sum: 92.64632415771484  y_tgt.sum: 52.1559633027523
640 train loss: 4.65754019387532e-05  yhat_tgt.sum: 181.87490844726562  y_tgt.sum: 26.64089121887287


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.9144841644447297e-05  yhat_tgt.sum: 91.32467651367188  y_tgt.sum: 50.802096985583226
650 train loss: 4.483534212340601e-05  yhat_tgt.sum: 177.20986938476562  y_tgt.sum: 46.03800786369594


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.675240127951838e-05  yhat_tgt.sum: 90.1884765625  y_tgt.sum: 47.48230668414155
660 train loss: 4.6675897465320304e-05  yhat_tgt.sum: 175.88888549804688  y_tgt.sum: 46.764089121887295


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.314173227408901e-05  yhat_tgt.sum: 89.0616455078125  y_tgt.sum: 50.530799475753604
670 train loss: 5.025296559324488e-05  yhat_tgt.sum: 168.62310791015625  y_tgt.sum: 50.81913499344692


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0476325516938232e-05  yhat_tgt.sum: 87.96541595458984  y_tgt.sum: 51.684141546526874
680 train loss: 3.811152055277489e-05  yhat_tgt.sum: 171.25930786132812  y_tgt.sum: 44.86500655307995


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.7227697753696702e-05  yhat_tgt.sum: 86.82949829101562  y_tgt.sum: 46.60419397116645
690 train loss: 4.417479794938117e-05  yhat_tgt.sum: 170.33761596679688  y_tgt.sum: 36.20707732634338


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.077172939607408e-05  yhat_tgt.sum: 85.83027648925781  y_tgt.sum: 51.60419397116645
700 train loss: 3.740070678759366e-05  yhat_tgt.sum: 168.75424194335938  y_tgt.sum: 24.79030144167759


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.5477136912522838e-05  yhat_tgt.sum: 84.57182312011719  y_tgt.sum: 43.10091743119266
710 train loss: 3.5923352697864175e-05  yhat_tgt.sum: 165.60964965820312  y_tgt.sum: 37.5124508519004


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.2427337171393447e-05  yhat_tgt.sum: 83.71330261230469  y_tgt.sum: 50.06815203145479
720 train loss: 3.334970824653283e-05  yhat_tgt.sum: 162.4247283935547  y_tgt.sum: 44.33158584534732


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0596686226781458e-05  yhat_tgt.sum: 82.74102783203125  y_tgt.sum: 49.53735255570118
730 train loss: 4.953921597916633e-05  yhat_tgt.sum: 164.22210693359375  y_tgt.sum: 37.39187418086501


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.685861025180202e-05  yhat_tgt.sum: 81.78392028808594  y_tgt.sum: 47.40498034076016
740 train loss: 3.150453994749114e-05  yhat_tgt.sum: 161.30880737304688  y_tgt.sum: 41.43119266055046


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.3857328415033408e-05  yhat_tgt.sum: 80.966064453125  y_tgt.sum: 53.00655307994758
750 train loss: 4.497449845075607e-05  yhat_tgt.sum: 159.29318237304688  y_tgt.sum: 44.05897771952818


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.7303256501909345e-05  yhat_tgt.sum: 79.95790100097656  y_tgt.sum: 52.16120576671035
760 train loss: 3.8343958294717595e-05  yhat_tgt.sum: 160.82174682617188  y_tgt.sum: 33.964613368283096


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.454227978887502e-05  yhat_tgt.sum: 79.07150268554688  y_tgt.sum: 56.26867627785059
770 train loss: 3.63552535418421e-05  yhat_tgt.sum: 154.77377319335938  y_tgt.sum: 36.18348623853211


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.8387554518994875e-05  yhat_tgt.sum: 78.29164123535156  y_tgt.sum: 48.85058977719529
780 train loss: 4.604614514391869e-05  yhat_tgt.sum: 156.30926513671875  y_tgt.sum: 40.21625163826999


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.1076126358821057e-05  yhat_tgt.sum: 77.47715759277344  y_tgt.sum: 49.34076015727392
790 train loss: 3.489839946269058e-05  yhat_tgt.sum: 152.29505920410156  y_tgt.sum: 41.900393184796854


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.834388917312026e-05  yhat_tgt.sum: 76.50254821777344  y_tgt.sum: 55.1048492791612
800 train loss: 3.929856757167727e-05  yhat_tgt.sum: 152.49176025390625  y_tgt.sum: 40.93315858453474


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.3485961719416082e-05  yhat_tgt.sum: 75.65193176269531  y_tgt.sum: 51.669724770642205
810 train loss: 6.894148100400344e-05  yhat_tgt.sum: 153.02633666992188  y_tgt.sum: 51.77588466579292


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0325389414210804e-05  yhat_tgt.sum: 75.0226058959961  y_tgt.sum: 47.96985583224116
820 train loss: 3.831741923931986e-05  yhat_tgt.sum: 150.4251708984375  y_tgt.sum: 37.71690694626474


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.2230657123145647e-05  yhat_tgt.sum: 74.09312438964844  y_tgt.sum: 50.492791612057665
830 train loss: 6.409371417248622e-05  yhat_tgt.sum: 149.58004760742188  y_tgt.sum: 43.48099606815204


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.8770265998900868e-05  yhat_tgt.sum: 73.71035766601562  y_tgt.sum: 54.76015727391874
840 train loss: 4.9984366341959685e-05  yhat_tgt.sum: 144.88011169433594  y_tgt.sum: 48.010484927916124


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.4934403225197457e-05  yhat_tgt.sum: 72.7723388671875  y_tgt.sum: 43.82961992136305
850 train loss: 3.4816304832929745e-05  yhat_tgt.sum: 144.34840393066406  y_tgt.sum: 43.85452162516383


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.889539089461323e-05  yhat_tgt.sum: 72.06828308105469  y_tgt.sum: 48.17169069462648
860 train loss: 8.866718417266384e-05  yhat_tgt.sum: 142.96807861328125  y_tgt.sum: 48.44298820445609


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.294623845955357e-05  yhat_tgt.sum: 71.54640197753906  y_tgt.sum: 53.398427260812575
870 train loss: 6.345305155264214e-05  yhat_tgt.sum: 141.1991729736328  y_tgt.sum: 41.200524246395815


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.652186801948119e-05  yhat_tgt.sum: 70.59183502197266  y_tgt.sum: 46.54128440366972
880 train loss: 2.8057183953933418e-05  yhat_tgt.sum: 139.22860717773438  y_tgt.sum: 40.14547837483617


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.8450089555699378e-05  yhat_tgt.sum: 69.9764404296875  y_tgt.sum: 52.736566186107474
890 train loss: 3.8563252019230276e-05  yhat_tgt.sum: 140.73153686523438  y_tgt.sum: 44.522935779816514


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.8392619899241254e-05  yhat_tgt.sum: 69.34402465820312  y_tgt.sum: 60.27391874180866
900 train loss: 4.0234921470982954e-05  yhat_tgt.sum: 140.70631408691406  y_tgt.sum: 42.201834862385326


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.617808240756858e-05  yhat_tgt.sum: 68.72356414794922  y_tgt.sum: 54.790301441677585
910 train loss: 4.5445962314261124e-05  yhat_tgt.sum: 137.76095581054688  y_tgt.sum: 44.38401048492791


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.4482372534985188e-05  yhat_tgt.sum: 68.0915298461914  y_tgt.sum: 45.68152031454784
920 train loss: 3.205543180229142e-05  yhat_tgt.sum: 134.80799865722656  y_tgt.sum: 44.19397116644823


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.6841293472680263e-05  yhat_tgt.sum: 67.57293701171875  y_tgt.sum: 47.226736566186105
930 train loss: 2.7410133043304086e-05  yhat_tgt.sum: 136.4923858642578  y_tgt.sum: 34.086500655308


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.536163265176583e-05  yhat_tgt.sum: 66.94575500488281  y_tgt.sum: 49.05897771952817
940 train loss: 3.07851041725371e-05  yhat_tgt.sum: 135.0506591796875  y_tgt.sum: 34.1480996068152


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.3297430996317416e-05  yhat_tgt.sum: 66.3302230834961  y_tgt.sum: 46.425950196592396
950 train loss: 3.179508348694071e-05  yhat_tgt.sum: 134.33383178710938  y_tgt.sum: 26.25032765399738


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0358513211249374e-05  yhat_tgt.sum: 65.65247344970703  y_tgt.sum: 50.74049803407601
960 train loss: 4.2918738472508267e-05  yhat_tgt.sum: 129.80630493164062  y_tgt.sum: 46.35386631716907


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.311350363015663e-05  yhat_tgt.sum: 65.00611877441406  y_tgt.sum: 48.9908256880734
970 train loss: 2.197233698097989e-05  yhat_tgt.sum: 131.1497802734375  y_tgt.sum: 38.29882044560944


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.082657374558039e-05  yhat_tgt.sum: 64.53856658935547  y_tgt.sum: 45.67234600262123
980 train loss: 4.675026502809487e-05  yhat_tgt.sum: 130.25010681152344  y_tgt.sum: 42.319790301441685


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.9482895368128084e-05  yhat_tgt.sum: 63.925331115722656  y_tgt.sum: 50.314547837483616
990 train loss: 2.414302434772253e-05  yhat_tgt.sum: 127.22187042236328  y_tgt.sum: 46.73525557011796


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.012586446653586e-05  yhat_tgt.sum: 63.46833419799805  y_tgt.sum: 50.5478374836173
1000 train loss: 2.664427665877156e-05  yhat_tgt.sum: 128.73876953125  y_tgt.sum: 36.50589777195282


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 9.698520443635061e-06  yhat_tgt.sum: 62.75187683105469  y_tgt.sum: 41.929226736566186
1010 train loss: 3.3259097108384594e-05  yhat_tgt.sum: 125.44065856933594  y_tgt.sum: 46.899082568807344


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.6348732970072888e-05  yhat_tgt.sum: 62.12420654296875  y_tgt.sum: 47.93315858453473
1020 train loss: 2.8261132683837786e-05  yhat_tgt.sum: 126.14376831054688  y_tgt.sum: 43.28964613368283


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.7493846573634073e-05  yhat_tgt.sum: 61.81189727783203  y_tgt.sum: 53.33158584534731
1030 train loss: 3.563424979802221e-05  yhat_tgt.sum: 125.10594177246094  y_tgt.sum: 44.94888597640892


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.5152023479458876e-05  yhat_tgt.sum: 61.225215911865234  y_tgt.sum: 50.95019659239843
1040 train loss: 4.047342372359708e-05  yhat_tgt.sum: 123.67290496826172  y_tgt.sum: 49.94626474442988


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.1082898456370458e-05  yhat_tgt.sum: 60.780860900878906  y_tgt.sum: 53.54652686762779
1050 train loss: 2.783872332656756e-05  yhat_tgt.sum: 121.98873901367188  y_tgt.sum: 39.22804718217562


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.842823050741572e-05  yhat_tgt.sum: 60.35253143310547  y_tgt.sum: 54.23984272608126
1060 train loss: 6.074683187762275e-05  yhat_tgt.sum: 121.84790802001953  y_tgt.sum: 53.567496723460025


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.5342291337437928e-05  yhat_tgt.sum: 59.881256103515625  y_tgt.sum: 53.17562254259502
1070 train loss: 3.3309170248685405e-05  yhat_tgt.sum: 122.5614242553711  y_tgt.sum: 39.328964613368285


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.585102654644288e-05  yhat_tgt.sum: 59.404449462890625  y_tgt.sum: 49.685452162516384
1080 train loss: 2.5560368158039637e-05  yhat_tgt.sum: 119.46522521972656  y_tgt.sum: 44.031454783748366


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.1647087123710662e-05  yhat_tgt.sum: 58.96063232421875  y_tgt.sum: 49.646133682830936
1090 train loss: 2.3542808776255697e-05  yhat_tgt.sum: 118.47673034667969  y_tgt.sum: 42.25688073394495


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.3833355953684077e-05  yhat_tgt.sum: 58.519203186035156  y_tgt.sum: 50.63826998689384
1100 train loss: 4.476725007407367e-05  yhat_tgt.sum: 117.92916870117188  y_tgt.sum: 45.23591087811271


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0075254724361002e-05  yhat_tgt.sum: 58.077476501464844  y_tgt.sum: 51.57667103538663
1110 train loss: 2.2451453332905658e-05  yhat_tgt.sum: 118.18838500976562  y_tgt.sum: 25.327653997378768


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.3024000256555155e-05  yhat_tgt.sum: 57.539337158203125  y_tgt.sum: 51.180865006553084
1120 train loss: 3.470796218607575e-05  yhat_tgt.sum: 115.82015991210938  y_tgt.sum: 47.50851900393185


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.9571119739557616e-05  yhat_tgt.sum: 57.095298767089844  y_tgt.sum: 52.34862385321101
1130 train loss: 3.4976521419594064e-05  yhat_tgt.sum: 117.09136199951172  y_tgt.sum: 33.884665792922675


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 4.038962651975453e-05  yhat_tgt.sum: 56.695823669433594  y_tgt.sum: 52.062909567496725
1140 train loss: 2.4364751880057156e-05  yhat_tgt.sum: 116.51589965820312  y_tgt.sum: 30.804718217562257


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.1073592506581917e-05  yhat_tgt.sum: 56.15866470336914  y_tgt.sum: 53.8951507208388
1150 train loss: 3.224063402740285e-05  yhat_tgt.sum: 114.06309509277344  y_tgt.sum: 43.528178243774576


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.501251194393262e-05  yhat_tgt.sum: 55.794281005859375  y_tgt.sum: 55.131061598951504
1160 train loss: 3.422111694817431e-05  yhat_tgt.sum: 112.78518676757812  y_tgt.sum: 48.65923984272608


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.6247973437421024e-05  yhat_tgt.sum: 55.274574279785156  y_tgt.sum: 54.714285714285715
1170 train loss: 5.4267788073047996e-05  yhat_tgt.sum: 112.64845275878906  y_tgt.sum: 51.11926605504587


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.6042020433815196e-05  yhat_tgt.sum: 54.93115997314453  y_tgt.sum: 47.50982961992136
1180 train loss: 2.3210184735944495e-05  yhat_tgt.sum: 112.08656311035156  y_tgt.sum: 40.246395806028836


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.8691920558921993e-05  yhat_tgt.sum: 54.57753372192383  y_tgt.sum: 49.373525557011796
1190 train loss: 2.9844066375517286e-05  yhat_tgt.sum: 110.74999237060547  y_tgt.sum: 42.83617300131061


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.7409454560256563e-05  yhat_tgt.sum: 54.15253829956055  y_tgt.sum: 51.93971166448231
1200 train loss: 4.006639574072324e-05  yhat_tgt.sum: 110.74241638183594  y_tgt.sum: 50.17169069462648


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.6061014321167022e-05  yhat_tgt.sum: 53.73183059692383  y_tgt.sum: 50.23722149410223
1210 train loss: 3.661255323095247e-05  yhat_tgt.sum: 111.33431243896484  y_tgt.sum: 39.737876802096984


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.6464236978208646e-05  yhat_tgt.sum: 53.484642028808594  y_tgt.sum: 47.66579292267366
1220 train loss: 4.856348823523149e-05  yhat_tgt.sum: 108.07701110839844  y_tgt.sum: 53.661861074705115


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 4.623067798092961e-05  yhat_tgt.sum: 53.078712463378906  y_tgt.sum: 55.972477064220186
1230 train loss: 2.0955832951585762e-05  yhat_tgt.sum: 108.35709381103516  y_tgt.sum: 38.917431192660544


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.380158366577234e-05  yhat_tgt.sum: 52.71168899536133  y_tgt.sum: 50.403669724770644
1240 train loss: 3.073201150982641e-05  yhat_tgt.sum: 107.10025024414062  y_tgt.sum: 44.397116644823065


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.041823427134659e-05  yhat_tgt.sum: 52.44652557373047  y_tgt.sum: 50.25950196592398
1250 train loss: 1.8882921722251922e-05  yhat_tgt.sum: 104.63253784179688  y_tgt.sum: 41.05897771952817


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.8636111892410554e-05  yhat_tgt.sum: 52.09314727783203  y_tgt.sum: 55.551769331585845
1260 train loss: 2.191306703025475e-05  yhat_tgt.sum: 108.03497314453125  y_tgt.sum: 31.11533420707733


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.4399687643162906e-05  yhat_tgt.sum: 51.738590240478516  y_tgt.sum: 47.31979030144168
1270 train loss: 3.383222428965382e-05  yhat_tgt.sum: 107.5146484375  y_tgt.sum: 41.20445609436436


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.091849066549912e-05  yhat_tgt.sum: 51.34637451171875  y_tgt.sum: 53.61336828309305
1280 train loss: 4.488936974667013e-05  yhat_tgt.sum: 104.52174377441406  y_tgt.sum: 48.564875491481


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0386716641951352e-05  yhat_tgt.sum: 50.99983215332031  y_tgt.sum: 51.479685452162514
1290 train loss: 2.5579091015970334e-05  yhat_tgt.sum: 103.3935775756836  y_tgt.sum: 46.545216251638266


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.8833823055028915e-05  yhat_tgt.sum: 50.686729431152344  y_tgt.sum: 53.34862385321101
1300 train loss: 3.105403084191494e-05  yhat_tgt.sum: 103.57960510253906  y_tgt.sum: 49.424639580602886


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0749479517689906e-05  yhat_tgt.sum: 50.28396224975586  y_tgt.sum: 50.05897771952818
1310 train loss: 2.417017094558105e-05  yhat_tgt.sum: 103.39579772949219  y_tgt.sum: 38.06815203145479


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.2745452952221967e-05  yhat_tgt.sum: 49.93714141845703  y_tgt.sum: 50.71952817824378
1320 train loss: 1.955632251338102e-05  yhat_tgt.sum: 102.26380920410156  y_tgt.sum: 41.11926605504587


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.4554861991200596e-05  yhat_tgt.sum: 49.67703628540039  y_tgt.sum: 46.06422018348624
1330 train loss: 2.989853237522766e-05  yhat_tgt.sum: 102.8696517944336  y_tgt.sum: 41.809960681520316


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.9695737137226388e-05  yhat_tgt.sum: 49.34772491455078  y_tgt.sum: 51.7562254259502
1340 train loss: 4.273795639164746e-05  yhat_tgt.sum: 100.76339721679688  y_tgt.sum: 47.627785058977715


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.4286115149152465e-05  yhat_tgt.sum: 48.97467803955078  y_tgt.sum: 53.18610747051114
1350 train loss: 3.6148772778688e-05  yhat_tgt.sum: 101.46003723144531  y_tgt.sum: 33.01441677588467


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.370612830622122e-05  yhat_tgt.sum: 48.69447326660156  y_tgt.sum: 55.262123197903016
1360 train loss: 1.9101180441793986e-05  yhat_tgt.sum: 99.73929595947266  y_tgt.sum: 38.51376146788991


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.5420593374292366e-05  yhat_tgt.sum: 48.47792053222656  y_tgt.sum: 54.04062909567497
1370 train loss: 2.178259819629602e-05  yhat_tgt.sum: 99.94467163085938  y_tgt.sum: 36.917431192660544


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.1484389435499907e-05  yhat_tgt.sum: 48.17670440673828  y_tgt.sum: 53.20445609436435
1380 train loss: 2.3013140889815986e-05  yhat_tgt.sum: 98.14559936523438  y_tgt.sum: 39.137614678899084


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0621342628146522e-05  yhat_tgt.sum: 47.8314323425293  y_tgt.sum: 51.411533420707734
1390 train loss: 2.190024133597035e-05  yhat_tgt.sum: 98.29071807861328  y_tgt.sum: 42.18741808650066


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.6269248590106145e-05  yhat_tgt.sum: 47.52937316894531  y_tgt.sum: 46.70773263433814
1400 train loss: 1.702179724816233e-05  yhat_tgt.sum: 97.5525131225586  y_tgt.sum: 36.69200524246396


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.4983228538767435e-05  yhat_tgt.sum: 47.149078369140625  y_tgt.sum: 47.71428571428572
1410 train loss: 4.414020077092573e-05  yhat_tgt.sum: 98.1423568725586  y_tgt.sum: 37.093053735255566


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.4175657270243391e-05  yhat_tgt.sum: 46.88288116455078  y_tgt.sum: 45.967234600262124
1420 train loss: 2.234055682492908e-05  yhat_tgt.sum: 96.613525390625  y_tgt.sum: 38.1559633027523


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.1792137229349464e-05  yhat_tgt.sum: 46.62532043457031  y_tgt.sum: 43.24377457404981
1430 train loss: 2.5173312678816728e-05  yhat_tgt.sum: 95.38927459716797  y_tgt.sum: 40.686762778505894


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.9019655155716464e-05  yhat_tgt.sum: 46.386497497558594  y_tgt.sum: 49.840104849279165
1440 train loss: 2.9264356271596625e-05  yhat_tgt.sum: 96.64849090576172  y_tgt.sum: 48.766710353866316


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.043728636635933e-05  yhat_tgt.sum: 46.057987213134766  y_tgt.sum: 50.863695937090434
1450 train loss: 2.0290166503400542e-05  yhat_tgt.sum: 95.81689453125  y_tgt.sum: 41.439056356487555


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.563646765134763e-05  yhat_tgt.sum: 45.84514617919922  y_tgt.sum: 49.48361730013106
1460 train loss: 1.9873137716786005e-05  yhat_tgt.sum: 93.32745361328125  y_tgt.sum: 44.86762778505897


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.5334078852902167e-05  yhat_tgt.sum: 45.6156120300293  y_tgt.sum: 51.32372214941022
1470 train loss: 2.0651077647926286e-05  yhat_tgt.sum: 93.06800842285156  y_tgt.sum: 49.16775884665793


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.852390727843158e-05  yhat_tgt.sum: 45.32649612426758  y_tgt.sum: 53.68414154652687
1480 train loss: 2.5092895157285966e-05  yhat_tgt.sum: 94.7304916381836  y_tgt.sum: 37.17824377457404


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.269817170803435e-05  yhat_tgt.sum: 45.068702697753906  y_tgt.sum: 52.71690694626474
1490 train loss: 2.3209406208479777e-05  yhat_tgt.sum: 92.93971252441406  y_tgt.sum: 43.93840104849279


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.2365453332895413e-05  yhat_tgt.sum: 44.929656982421875  y_tgt.sum: 53.908256880733944
1500 train loss: 3.909545193891972e-05  yhat_tgt.sum: 92.69059753417969  y_tgt.sum: 49.4521625163827


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0749372197315097e-05  yhat_tgt.sum: 44.66019821166992  y_tgt.sum: 52.809960681520316
1510 train loss: 3.070048842346296e-05  yhat_tgt.sum: 93.81057739257812  y_tgt.sum: 36.59633027522936


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.5737359717604704e-05  yhat_tgt.sum: 44.403663635253906  y_tgt.sum: 47.234600262123195
1520 train loss: 3.292123801656999e-05  yhat_tgt.sum: 93.03173828125  y_tgt.sum: 42.884665792922675


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0533054339466617e-05  yhat_tgt.sum: 43.956504821777344  y_tgt.sum: 50.9475753604194
1530 train loss: 2.0114195649512112e-05  yhat_tgt.sum: 92.95331573486328  y_tgt.sum: 37.085190039318476


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.816471922211349e-05  yhat_tgt.sum: 43.756595611572266  y_tgt.sum: 47.82830930537352
1540 train loss: 2.98573086183751e-05  yhat_tgt.sum: 92.44548034667969  y_tgt.sum: 30.391874180865006


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 5.734360092901625e-05  yhat_tgt.sum: 43.596275329589844  y_tgt.sum: 62.469200524246396
1550 train loss: 2.8420496164471842e-05  yhat_tgt.sum: 90.01456451416016  y_tgt.sum: 45.62254259501967


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.4756460334174335e-05  yhat_tgt.sum: 43.29447937011719  y_tgt.sum: 52.58846657929227
1560 train loss: 1.971188794414047e-05  yhat_tgt.sum: 88.96356201171875  y_tgt.sum: 42.360419397116644


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.001814184244722e-05  yhat_tgt.sum: 43.09962463378906  y_tgt.sum: 52.76015727391875
1570 train loss: 5.966660683043301e-05  yhat_tgt.sum: 89.75267028808594  y_tgt.sum: 41.31061598951507


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0352439605630934e-05  yhat_tgt.sum: 42.856964111328125  y_tgt.sum: 52.04849279161206
1580 train loss: 2.6252735551679507e-05  yhat_tgt.sum: 89.04489135742188  y_tgt.sum: 35.357798165137616


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.7212771126651205e-05  yhat_tgt.sum: 42.547760009765625  y_tgt.sum: 47.37745740498035
1590 train loss: 2.2061665731598623e-05  yhat_tgt.sum: 89.33770751953125  y_tgt.sum: 38.81651376146789


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.0246317944838665e-05  yhat_tgt.sum: 42.449520111083984  y_tgt.sum: 56.575360419397114
1600 train loss: 2.794600550259929e-05  yhat_tgt.sum: 87.9279556274414  y_tgt.sum: 39.82961992136304


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.9314418750582263e-05  yhat_tgt.sum: 42.20827102661133  y_tgt.sum: 49.545216251638266
1610 train loss: 2.1314492187229916e-05  yhat_tgt.sum: 87.48231506347656  y_tgt.sum: 40.960681520314544


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.4370070352451876e-05  yhat_tgt.sum: 41.99007034301758  y_tgt.sum: 50.853211009174316
1620 train loss: 2.08414548978908e-05  yhat_tgt.sum: 87.75714111328125  y_tgt.sum: 37.8519003931848


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.3938993106421549e-05  yhat_tgt.sum: 41.68133544921875  y_tgt.sum: 44.366972477064216
1630 train loss: 2.7157238946529105e-05  yhat_tgt.sum: 87.16313171386719  y_tgt.sum: 42.25950196592399


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.766270725056529e-05  yhat_tgt.sum: 41.44809341430664  y_tgt.sum: 48.3997378768021
1640 train loss: 2.2466701921075583e-05  yhat_tgt.sum: 88.0434799194336  y_tgt.sum: 35.26605504587155


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.547154508647509e-05  yhat_tgt.sum: 41.264686584472656  y_tgt.sum: 54.67496723460026
1650 train loss: 2.3267779397428967e-05  yhat_tgt.sum: 86.37504577636719  y_tgt.sum: 42.752293577981646


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.910808168759104e-05  yhat_tgt.sum: 41.0747184753418  y_tgt.sum: 47.66710353866318
1660 train loss: 1.6263267752947286e-05  yhat_tgt.sum: 86.689697265625  y_tgt.sum: 29.93315858453473


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 3.677002678159624e-05  yhat_tgt.sum: 40.817665100097656  y_tgt.sum: 57.70773263433814
1670 train loss: 2.821330417646095e-05  yhat_tgt.sum: 86.44747161865234  y_tgt.sum: 32.84534731323723


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.422522629785817e-05  yhat_tgt.sum: 40.65386962890625  y_tgt.sum: 50.323722149410216
1680 train loss: 4.3054315028712153e-05  yhat_tgt.sum: 85.28367614746094  y_tgt.sum: 45.48885976408912


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 2.073156792903319e-05  yhat_tgt.sum: 40.449012756347656  y_tgt.sum: 50.56749672346003
1690 train loss: 3.648582423920743e-05  yhat_tgt.sum: 84.48458862304688  y_tgt.sum: 43.93577981651376


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

0 val loss: 1.5908910427242517e-05  yhat_tgt.sum: 40.242713928222656  y_tgt.sum: 48.22804718217563
1700 train loss: 1.892899308586493e-05  yhat_tgt.sum: 84.52944946289062  y_tgt.sum: 39.73656618610747



### Testing: inference on the held-out window and submission

In [None]:
%%time
model = SalesModel.load_from_checkpoint('models/weights.ckpt')
x_src, x_tgt, y_src, y_tgt, item_store_mask = model.test()
x_src = x_src.squeeze(0)
x_tgt = x_tgt.squeeze(0)
y_src = y_src.squeeze(0)
y_tgt = y_tgt.squeeze(0)
print(f'x_src.shape - {x_src.shape}, x_tgt.shape - {x_tgt.shape} , y_src.shape - {y_src.shape} , y_tgt.shape - {y_tgt.shape}')

model.eval()
print('starting inference...')
yhat_tgt = model(x_src, y_src, x_tgt)
yhat_tgt.shape

In [None]:
# Reorder the forecast to (item_store, days): one row per series, one column per forecast day.
yhat_tgt = yhat_tgt.refine_names('days', 'item_store', 'demand')
yhat_tgt_aligned = (
    yhat_tgt.align_to('item_store', 'days', 'demand')
    .squeeze(2)
    .rename(None)  # drop dimension names before handing the data to numpy
    .detach()
    .numpy()
)
print('yhat.shape: ', yhat_tgt_aligned.shape)
# Ids are taken positionally from `sales`, so the row counts must agree exactly.
assert yhat_tgt_aligned.shape[0] == len(sales), 'expected one forecast row per sales series'

# NOTE: pickle.load can execute arbitrary code — only load a scalers.pkl this project wrote.
with open('scalers.pkl', 'rb') as f:
    scalers = pickle.load(f)
demand_scaler = scalers['demand']

forecast_cols = [f'F{day}' for day in range(1, num_test1_days + 1)]

# The demand scaler is applied to single-column input, so inverse-transform every forecast
# as one long column and reshape back to (series, days). For row-wise transformers such as
# the sklearn scalers, this equals transforming each F-column separately.
scaled = yhat_tgt_aligned[:, :num_test1_days]
demand = demand_scaler.inverse_transform(scaled.reshape(-1, 1)).reshape(scaled.shape)
# Unit sales cannot be negative; clip the regression output at zero.
demand = np.clip(demand, 0, None)

# Forecasts for the first test window (days 1914-1941), one row per sales id.
preds = pd.DataFrame(demand, columns=forecast_cols)
preds.insert(0, 'id', sales['id'].to_numpy())
pred_ids = preds['id'].tolist()

# The evaluation rows (days 1942 to 1969) must also be submitted; they are 0.0 placeholders.
# Their ids replace the final '_'-separated tag of each id with 'evaluation'.
eval_ids = [pid.rsplit('_', 1)[0] + '_evaluation' for pid in pred_ids]
eval_df = pd.DataFrame(0.0, index=range(len(eval_ids)), columns=forecast_cols)
eval_df.insert(0, 'id', eval_ids)

out_df = pd.concat([preds, eval_df], axis=0, ignore_index=True)
print(out_df.shape)
preds.head()

In [None]:
out_df.describe()

In [None]:
out_df.to_csv('preds.csv', index=False)
!kaggle competitions submit -c m5-forecasting-accuracy -f preds.csv -m "transformers 3"

In [None]:
# !head preds.csv

## Playground

In [None]:
x_src.shape

In [None]:
item_store_mask = list(np.random.randint(0, 10,3))
item_store_mask

In [None]:
torch.randn(10).sum().item()

In [None]:
sample.head(2)

In [None]:
dir(hparams)`

In [None]:
sub = pd.read_csv(data_dir/'sample_submission.csv')

In [None]:

    
eval_df.head()

In [None]:
sub.shape

In [None]:
preds.shape

In [None]:
preds.head()