### Imports

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
!pwd

/content


In [3]:
PATH_TO_FOLDER = '/content/drive/MyDrive/DL_project/EXPERIMENTS'

In [4]:
%cd {PATH_TO_FOLDER}

/content/drive/MyDrive/DL_project/EXPERIMENTS


In [5]:
%load_ext autoreload
%autoreload 2

In [6]:
from models import MLP, NAC, NALU

In [7]:
import math
import random
import numpy as np
from tqdm.notebook import tqdm
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
random.seed(hash("setting random seeds") % 2**32 - 1)
np.random.seed(hash("improves reproducibility") % 2**32 - 1)
torch.manual_seed(hash("by removing stochasticity") % 2**32 - 1)
torch.cuda.manual_seed_all(hash("so runs are repeatable") % 2**32 - 1)

# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [8]:
%%capture
!pip install wandb --upgrade

In [10]:
import wandb
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

### INTERPOLATION Parameters

In [11]:
NAMES = dict(
    project = 'INTERPOLATION',
    tags = 'preparation'
)

MODEL_PARAMETERS = dict(
    in_dim=2,
    out_dim=1,
    NUM_LAYERS = 2,
    HIDDEN_DIM = 2,
)

ARITHMETIC_PARAMETERS = dict(
    num_train=500, 
    num_test=50,
    dim=100, 
    num_sum=5,
    RANGE = [5, 10]
)

TRAIN_PARAMETERS = dict(
                        LEARNING_RATE = 1e-2,
                        # NUM_ITERS = int(1e5),
                        NUM_ITERS = int(50000),
                        activation='relu6'
                        )

ARITHMETIC_FUNCTIONS = {
    'add': lambda x, y: x + y,
    'sub': lambda x, y: x - y,
    'mul': lambda x, y: x * y,
    'div': lambda x, y: x / y,
    'squared': lambda x, y: torch.pow(x, 2),
    'root': lambda x, y: torch.sqrt(x)
    }

MODELS = {
        'MLP': MLP, 
        'NAC': NAC, 
        'NALU': NALU
        }

In [None]:
# def get_id():
#     id_dct = {}
#     for i, key in enumerate(ARITHMETIC_FUNCTIONS.keys()):
#         id_dct[key] = i
#     return id_dct

# FUNC_id = get_id()

### Config for Wandb

In [12]:
def create_config(model, function):
    config = {'model': model,
              'function': function}
              
    models = MODEL_PARAMETERS.copy()
    funcs = ARITHMETIC_PARAMETERS.copy()
    params = TRAIN_PARAMETERS.copy()
    names = NAMES.copy()

    for key, value in models.items():
        config[key] = value
    for key, value in funcs.items():
        config[key] = value
    for key, value in params.items():
        config[key] = value
    for key, value in names.items():
        config[key] = value

    return config

### Example of config

In [13]:
config = create_config('MLP', 'add')

In [15]:
config

{'HIDDEN_DIM': 2,
 'LEARNING_RATE': 0.01,
 'NUM_ITERS': 50000,
 'NUM_LAYERS': 2,
 'RANGE': [5, 10],
 'activation': 'relu6',
 'dim': 100,
 'function': 'add',
 'in_dim': 2,
 'model': 'MLP',
 'num_sum': 5,
 'num_test': 50,
 'num_train': 500,
 'out_dim': 1,
 'project': 'INTERPOLATION',
 'tags': 'preparation'}

### Iteration on models and arithmetic functions

In [17]:
mdls = MODELS.keys()
fncts = ARITHMETIC_FUNCTIONS.keys()

for function in fncts:
    for arch in mdls:
        config = create_config(arch, function)
        # Build, train and analyze the model with the pipeline
        model = model_pipeline(config)

[34m[1mwandb[0m: Currently logged in as: [33mgalmitr[0m (use `wandb login --relogin` to force relogin)


	1/50000: loss: 3748.6340332 - mea: 61.1039314
	1001/50000: loss: 0.7828440 - mea: 0.8812203
	2001/50000: loss: 0.5584386 - mea: 0.7453416
	3001/50000: loss: 0.5582510 - mea: 0.7457025
	4001/50000: loss: 0.5568953 - mea: 0.7451655
	5001/50000: loss: 0.5557318 - mea: 0.7446712
	6001/50000: loss: 0.5547053 - mea: 0.7442012
	7001/50000: loss: 0.5537956 - mea: 0.7437565
	8001/50000: loss: 0.5529677 - mea: 0.7433252
	9001/50000: loss: 0.5522817 - mea: 0.7429571
	10001/50000: loss: 0.5516057 - mea: 0.7425703
	11001/50000: loss: 0.5510343 - mea: 0.7422344
	12001/50000: loss: 0.5506465 - mea: 0.7420072
	13001/50000: loss: 0.5501085 - mea: 0.7416670
	14001/50000: loss: 0.5497105 - mea: 0.7414125
	15001/50000: loss: 0.5493234 - mea: 0.7411587
	16001/50000: loss: 0.5489950 - mea: 0.7409395
	17001/50000: loss: 0.5485813 - mea: 0.7406591
	18001/50000: loss: 0.5482658 - mea: 0.7404423
	19001/50000: loss: 0.5478176 - mea: 0.7401340
	20001/50000: loss: 0.5476298 - mea: 0.7400004
	21001/50000: loss: 0.

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.54138
mean_absolute_error,0.73564
_runtime,30.0
_timestamp,1623439782.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 7029.0747070 - mea: 83.6875458
	1001/50000: loss: 0.0033662 - mea: 0.0468918
	2001/50000: loss: 0.0192659 - mea: 0.1385377
	3001/50000: loss: 0.0172079 - mea: 0.1308930
	4001/50000: loss: 0.0145552 - mea: 0.1203790
	5001/50000: loss: 0.0122746 - mea: 0.1105449
	6001/50000: loss: 0.0103421 - mea: 0.1014696
	7001/50000: loss: 0.0087115 - mea: 0.0931267
	8001/50000: loss: 0.0073053 - mea: 0.0852791
	9001/50000: loss: 0.0061372 - mea: 0.0781640
	10001/50000: loss: 0.0051513 - mea: 0.0716105
	11001/50000: loss: 0.0043264 - mea: 0.0656273
	12001/50000: loss: 0.0036246 - mea: 0.0600686
	13001/50000: loss: 0.0030399 - mea: 0.0550108
	14001/50000: loss: 0.0025570 - mea: 0.0504529
	15001/50000: loss: 0.0021463 - mea: 0.0462238
	16001/50000: loss: 0.0017937 - mea: 0.0422569
	17001/50000: loss: 0.0015088 - mea: 0.0387558
	18001/50000: loss: 0.0012754 - mea: 0.0356329
	19001/50000: loss: 0.0010640 - mea: 0.0325461
	20001/50000: loss: 0.0008965 - mea: 0.0298743
	21001/50000: loss: 0.

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.0
mean_absolute_error,0.0022
_runtime,31.0
_timestamp,1623439820.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇██
_timestamp,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇██
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 6646.0751953 - mea: 81.3788605
	1001/50000: loss: 4.0522738 - mea: 1.7092155
	2001/50000: loss: 3.2490234 - mea: 1.5538594
	3001/50000: loss: 2.9144080 - mea: 1.4764096
	4001/50000: loss: 2.7402048 - mea: 1.4300343
	5001/50000: loss: 2.6239080 - mea: 1.3969358
	6001/50000: loss: 2.5313683 - mea: 1.3701613
	7001/50000: loss: 2.4506972 - mea: 1.3464513
	8001/50000: loss: 2.3807337 - mea: 1.3256232
	9001/50000: loss: 2.3169265 - mea: 1.3063997
	10001/50000: loss: 2.2611101 - mea: 1.2894981
	11001/50000: loss: 2.2094994 - mea: 1.2736760
	12001/50000: loss: 2.1640484 - mea: 1.2596492
	13001/50000: loss: 2.1223688 - mea: 1.2466543
	14001/50000: loss: 2.0868485 - mea: 1.2355886
	15001/50000: loss: 2.0543647 - mea: 1.2253973
	16001/50000: loss: 2.0255828 - mea: 1.2163492
	17001/50000: loss: 2.0002713 - mea: 1.2083992
	18001/50000: loss: 1.9777086 - mea: 1.2013171
	19001/50000: loss: 1.9572546 - mea: 1.1948935
	20001/50000: loss: 1.9391894 - mea: 1.1892518
	21001/50000: loss: 1.

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),1.77642
mean_absolute_error,1.14351
_runtime,68.0
_timestamp,1623439894.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 18.4325676 - mea: 3.3782134
	1001/50000: loss: 18.1327953 - mea: 3.3373244
	2001/50000: loss: 18.1327953 - mea: 3.3373239
	3001/50000: loss: 18.1327953 - mea: 3.3373239
	4001/50000: loss: 18.1327972 - mea: 3.3373239
	5001/50000: loss: 18.1327953 - mea: 3.3373246
	6001/50000: loss: 18.1327953 - mea: 3.3373239
	7001/50000: loss: 18.1327953 - mea: 3.3373239
	8001/50000: loss: 18.1327953 - mea: 3.3373241
	9001/50000: loss: 18.1327953 - mea: 3.3373258
	10001/50000: loss: 18.1327953 - mea: 3.3373239
	11001/50000: loss: 18.1327953 - mea: 3.3373239
	12001/50000: loss: 18.1327953 - mea: 3.3373241
	13001/50000: loss: 18.1327953 - mea: 3.3373287
	14001/50000: loss: 18.1327953 - mea: 3.3373239
	15001/50000: loss: 18.1327953 - mea: 3.3373239
	16001/50000: loss: 18.1327953 - mea: 3.3373241
	17001/50000: loss: 18.1327953 - mea: 3.3373356
	18001/50000: loss: 18.1327953 - mea: 3.3373239
	19001/50000: loss: 18.1327953 - mea: 3.3373239
	20001/50000: loss: 18.1327953 - mea: 3.3373241
	2100

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),18.1328
mean_absolute_error,3.33734
_runtime,30.0
_timestamp,1623439929.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
_timestamp,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 25.3872166 - mea: 4.1018472
	1001/50000: loss: 0.0486156 - mea: 0.2201203
	2001/50000: loss: 0.0467388 - mea: 0.2158303
	3001/50000: loss: 0.0434900 - mea: 0.2081955
	4001/50000: loss: 0.0403792 - mea: 0.2006122
	5001/50000: loss: 0.0373849 - mea: 0.1930315
	6001/50000: loss: 0.0345214 - mea: 0.1854921
	7001/50000: loss: 0.0317472 - mea: 0.1778826
	8001/50000: loss: 0.0291351 - mea: 0.1704076
	9001/50000: loss: 0.0266530 - mea: 0.1629862
	10001/50000: loss: 0.0243147 - mea: 0.1556714
	11001/50000: loss: 0.0221169 - mea: 0.1484678
	12001/50000: loss: 0.0200971 - mea: 0.1415242
	13001/50000: loss: 0.0181970 - mea: 0.1346653
	14001/50000: loss: 0.0164687 - mea: 0.1281077
	15001/50000: loss: 0.0148859 - mea: 0.1217930
	16001/50000: loss: 0.0134380 - mea: 0.1157146
	17001/50000: loss: 0.0121231 - mea: 0.1099033
	18001/50000: loss: 0.0109287 - mea: 0.1043444
	19001/50000: loss: 0.0098426 - mea: 0.0990191
	20001/50000: loss: 0.0088560 - mea: 0.0939199
	21001/50000: loss: 0.007

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.00077
mean_absolute_error,0.0268
_runtime,31.0
_timestamp,1623439966.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 22.4865017 - mea: 3.7758424
	1001/50000: loss: 0.2037984 - mea: 0.3668483
	2001/50000: loss: 0.0665022 - mea: 0.2574056
	3001/50000: loss: 0.0590744 - mea: 0.2426206
	4001/50000: loss: 0.0519254 - mea: 0.2274517
	5001/50000: loss: 0.0451130 - mea: 0.2119971
	6001/50000: loss: 0.0387936 - mea: 0.1965779
	7001/50000: loss: 0.0330766 - mea: 0.1815040
	8001/50000: loss: 0.0279663 - mea: 0.1668834
	9001/50000: loss: 0.0235231 - mea: 0.1530434
	10001/50000: loss: 0.0197337 - mea: 0.1401680
	11001/50000: loss: 0.0030795 - mea: 0.0553668
	12001/50000: loss: 0.0122014 - mea: 0.1101981
	13001/50000: loss: 0.0205045 - mea: 0.1429334
	14001/50000: loss: 0.0064773 - mea: 0.0803305
	15001/50000: loss: 0.0011850 - mea: 0.0342368
	16001/50000: loss: 0.0101581 - mea: 0.1005121
	17001/50000: loss: 0.0000000 - mea: 0.0000391
	18001/50000: loss: 0.0042506 - mea: 0.0650726
	19001/50000: loss: 0.0004193 - mea: 0.0204363
	20001/50000: loss: 0.0043593 - mea: 0.0659057
	21001/50000: loss: 0.004

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.00016
mean_absolute_error,0.01253
_runtime,62.0
_timestamp,1623440035.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇████
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇████
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 1739532.7500000 - mea: 1310.4144287
	1001/50000: loss: 5681.9067383 - mea: 60.9020309
	2001/50000: loss: 5306.3750000 - mea: 59.0105209
	3001/50000: loss: 4885.9179688 - mea: 56.6304131
	4001/50000: loss: 4385.9853516 - mea: 53.6638374
	5001/50000: loss: 3761.4389648 - mea: 49.7087173
	6001/50000: loss: 3006.5122070 - mea: 44.4476166
	7001/50000: loss: 2203.7224121 - mea: 38.0566406
	8001/50000: loss: 1490.6488037 - mea: 31.3604755
	9001/50000: loss: 957.5733032 - mea: 25.2971859
	10001/50000: loss: 609.9528198 - mea: 20.4424419
	11001/50000: loss: 404.7455750 - mea: 16.9190769
	12001/50000: loss: 291.8048096 - mea: 14.5704603
	13001/50000: loss: 232.5429993 - mea: 13.0552673
	14001/50000: loss: 202.3949127 - mea: 12.1663904
	15001/50000: loss: 187.3672333 - mea: 11.6691628
	16001/50000: loss: 179.9253082 - mea: 11.4099178
	17001/50000: loss: 176.2155304 - mea: 11.2775402
	18001/50000: loss: 174.3334808 - mea: 11.2091894
	19001/50000: loss: 173.3127899 - mea: 11.1657324

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),167.24385
mean_absolute_error,10.90048
_runtime,30.0
_timestamp,1623440071.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 2009686.1250000 - mea: 1408.3637695
	1001/50000: loss: 1620627.1250000 - mea: 1263.7286377
	2001/50000: loss: 1619169.2500000 - mea: 1263.1557617
	3001/50000: loss: 1619147.5000000 - mea: 1263.1470947
	4001/50000: loss: 1619147.1250000 - mea: 1263.1470947
	5001/50000: loss: 1619147.1250000 - mea: 1263.1470947
	6001/50000: loss: 1619147.1250000 - mea: 1263.1470947
	7001/50000: loss: 1619147.1250000 - mea: 1263.1470947
	8001/50000: loss: 1619147.1250000 - mea: 1263.1470947
	9001/50000: loss: 1619147.1250000 - mea: 1263.1470947
	10001/50000: loss: 1619147.1250000 - mea: 1263.1470947
	11001/50000: loss: 1619147.1250000 - mea: 1263.1470947
	12001/50000: loss: 1619147.1250000 - mea: 1263.1470947
	13001/50000: loss: 1619147.1250000 - mea: 1263.1470947
	14001/50000: loss: 1619147.1250000 - mea: 1263.1470947
	15001/50000: loss: 1619147.1250000 - mea: 1263.1470947
	16001/50000: loss: 1619147.1250000 - mea: 1263.1470947
	17001/50000: loss: 1619147.1250000 - mea: 1263.1470947
	1800

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),1619147.125
mean_absolute_error,1263.14709
_runtime,31.0
_timestamp,1623440109.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇██
_timestamp,▁▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇██
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 2224618.0000000 - mea: 1482.8612061
	1001/50000: loss: 1294.9248047 - mea: 31.7436008
	2001/50000: loss: 1092.4796143 - mea: 29.1405411
	3001/50000: loss: 889.2701416 - mea: 26.1761169
	4001/50000: loss: 681.5133667 - mea: 23.0156956
	5001/50000: loss: 519.5241699 - mea: 20.9236870
	6001/50000: loss: 441.0010071 - mea: 19.9423676
	7001/50000: loss: 410.9433899 - mea: 19.4155045
	8001/50000: loss: 392.3776550 - mea: 18.9868240
	9001/50000: loss: 376.3796997 - mea: 18.5965900
	10001/50000: loss: 361.0455627 - mea: 18.2130718
	11001/50000: loss: 346.5305786 - mea: 17.8421993
	12001/50000: loss: 332.3478394 - mea: 17.4712925
	13001/50000: loss: 318.8890686 - mea: 17.1126614
	14001/50000: loss: 306.1312561 - mea: 16.7667389
	15001/50000: loss: 293.5495911 - mea: 16.4181786
	16001/50000: loss: 281.4149475 - mea: 16.0755157
	17001/50000: loss: 269.5126953 - mea: 15.7321749
	18001/50000: loss: 258.1751099 - mea: 15.3991833
	19001/50000: loss: 247.1514435 - mea: 15.0687199
	2000

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),38.92661
mean_absolute_error,6.04214
_runtime,64.0
_timestamp,1623440179.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 0.4815880 - mea: 0.6818374
	1001/50000: loss: 0.0166857 - mea: 0.1013899
	2001/50000: loss: 0.0166857 - mea: 0.1013927
	3001/50000: loss: 0.0166857 - mea: 0.1013927
	4001/50000: loss: 0.0166857 - mea: 0.1013900
	5001/50000: loss: 0.0166857 - mea: 0.1013927
	6001/50000: loss: 0.0166857 - mea: 0.1013927
	7001/50000: loss: 0.0166857 - mea: 0.1013900
	8001/50000: loss: 0.0166857 - mea: 0.1013927
	9001/50000: loss: 0.0166857 - mea: 0.1013927
	10001/50000: loss: 0.0166857 - mea: 0.1013900
	11001/50000: loss: 0.0166857 - mea: 0.1013927
	12001/50000: loss: 0.0166857 - mea: 0.1013927
	13001/50000: loss: 0.0166857 - mea: 0.1013900
	14001/50000: loss: 0.0166857 - mea: 0.1013927
	15001/50000: loss: 0.0166857 - mea: 0.1013927
	16001/50000: loss: 0.0166857 - mea: 0.1013900
	17001/50000: loss: 0.0166857 - mea: 0.1013927
	18001/50000: loss: 0.0166857 - mea: 0.1013927
	19001/50000: loss: 0.0166857 - mea: 0.1013900
	20001/50000: loss: 0.0166857 - mea: 0.1013927
	21001/50000: loss: 0.0166

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.01669
mean_absolute_error,0.10139
_runtime,30.0
_timestamp,1623440215.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 116.5605087 - mea: 10.7648544
	1001/50000: loss: 0.0441992 - mea: 0.2003285
	2001/50000: loss: 0.0234847 - mea: 0.1448817
	3001/50000: loss: 0.0181860 - mea: 0.1248981
	4001/50000: loss: 0.0138387 - mea: 0.1058239
	5001/50000: loss: 0.0105251 - mea: 0.0891970
	6001/50000: loss: 0.0081443 - mea: 0.0759732
	7001/50000: loss: 0.0065321 - mea: 0.0664932
	8001/50000: loss: 0.0054632 - mea: 0.0596887
	9001/50000: loss: 0.0046537 - mea: 0.0545594
	10001/50000: loss: 0.0044086 - mea: 0.0529687
	11001/50000: loss: 0.0042256 - mea: 0.0518463
	12001/50000: loss: 0.0040773 - mea: 0.0508977
	13001/50000: loss: 0.0039542 - mea: 0.0501221
	14001/50000: loss: 0.0038517 - mea: 0.0494798
	15001/50000: loss: 0.0037663 - mea: 0.0489724
	16001/50000: loss: 0.0036954 - mea: 0.0485553
	17001/50000: loss: 0.0036366 - mea: 0.0482089
	18001/50000: loss: 0.0035893 - mea: 0.0479267
	19001/50000: loss: 0.0035581 - mea: 0.0477404
	20001/50000: loss: 0.0035318 - mea: 0.0475775
	21001/50000: loss: 0.0

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.00337
mean_absolute_error,0.04661
_runtime,32.0
_timestamp,1623440254.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇████
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 21.7698345 - mea: 4.4583735
	1001/50000: loss: 0.0025014 - mea: 0.0415725
	2001/50000: loss: 0.0018424 - mea: 0.0360645
	3001/50000: loss: 0.0006455 - mea: 0.0185444
	4001/50000: loss: 0.0046913 - mea: 0.0610040
	5001/50000: loss: 0.0010763 - mea: 0.0203322
	6001/50000: loss: 0.0009032 - mea: 0.0220966
	7001/50000: loss: 0.0019310 - mea: 0.0289235
	8001/50000: loss: 0.0014798 - mea: 0.0225789
	9001/50000: loss: 0.0003172 - mea: 0.0152857
	10001/50000: loss: 0.0002784 - mea: 0.0140826
	11001/50000: loss: 0.0003831 - mea: 0.0156260
	12001/50000: loss: 0.0004731 - mea: 0.0178948
	13001/50000: loss: 0.0005842 - mea: 0.0199040
	14001/50000: loss: 0.0005244 - mea: 0.0206745
	15001/50000: loss: 0.0000804 - mea: 0.0060178
	16001/50000: loss: 0.0000946 - mea: 0.0059113
	17001/50000: loss: 0.0001771 - mea: 0.0109954
	18001/50000: loss: 0.0001805 - mea: 0.0102396
	19001/50000: loss: 0.0000977 - mea: 0.0058099
	20001/50000: loss: 0.0001353 - mea: 0.0064603
	21001/50000: loss: 0.000

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.00014
mean_absolute_error,0.00986
_runtime,62.0
_timestamp,1623440322.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇████
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇████
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 2081435.3750000 - mea: 1423.7891846
	1001/50000: loss: 22386.9199219 - mea: 118.2460022
	2001/50000: loss: 6556.7255859 - mea: 64.8680954
	3001/50000: loss: 5947.7641602 - mea: 61.7278824
	4001/50000: loss: 5318.5834961 - mea: 58.3041916
	5001/50000: loss: 4651.9619141 - mea: 54.4431343
	6001/50000: loss: 3945.6015625 - mea: 50.0210304
	7001/50000: loss: 3218.1328125 - mea: 45.0023499
	8001/50000: loss: 2510.8916016 - mea: 39.4955025
	9001/50000: loss: 1876.0031738 - mea: 33.8525047
	10001/50000: loss: 1354.6755371 - mea: 28.4483032
	11001/50000: loss: 962.0405273 - mea: 23.5317822
	12001/50000: loss: 687.7727661 - mea: 19.3781872
	13001/50000: loss: 507.7424011 - mea: 16.1450176
	14001/50000: loss: 395.2399597 - mea: 13.7691078
	15001/50000: loss: 327.4454651 - mea: 12.1171036
	16001/50000: loss: 287.7634277 - mea: 10.9525909
	17001/50000: loss: 264.9884949 - mea: 10.1677227
	18001/50000: loss: 252.1228027 - mea: 9.6239567
	19001/50000: loss: 244.8997650 - mea: 9.27017

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),234.22433
mean_absolute_error,8.64783
_runtime,31.0
_timestamp,1623440360.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇██
_timestamp,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇██
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 2123429.7500000 - mea: 1434.9736328
	1001/50000: loss: 1692603.6250000 - mea: 1277.2885742
	2001/50000: loss: 1691285.2500000 - mea: 1276.7758789
	3001/50000: loss: 1691265.6250000 - mea: 1276.7683105
	4001/50000: loss: 1691265.3750000 - mea: 1276.7680664
	5001/50000: loss: 1691265.3750000 - mea: 1276.7680664
	6001/50000: loss: 1691265.3750000 - mea: 1276.7680664
	7001/50000: loss: 1691265.3750000 - mea: 1276.7680664
	8001/50000: loss: 1691265.3750000 - mea: 1276.7680664
	9001/50000: loss: 1691265.3750000 - mea: 1276.7680664
	10001/50000: loss: 1691265.3750000 - mea: 1276.7680664
	11001/50000: loss: 1691265.3750000 - mea: 1276.7680664
	12001/50000: loss: 1691265.3750000 - mea: 1276.7680664
	13001/50000: loss: 1691265.3750000 - mea: 1276.7680664
	14001/50000: loss: 1691265.3750000 - mea: 1276.7680664
	15001/50000: loss: 1691265.3750000 - mea: 1276.7680664
	16001/50000: loss: 1691265.3750000 - mea: 1276.7680664
	17001/50000: loss: 1691265.3750000 - mea: 1276.7680664
	1800

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),1691265.375
mean_absolute_error,1276.76807
_runtime,31.0
_timestamp,1623440397.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 2005882.8750000 - mea: 1398.8598633
	1001/50000: loss: 25877.1015625 - mea: 125.8082886
	2001/50000: loss: 24512.4062500 - mea: 122.3937988
	3001/50000: loss: 24488.5898438 - mea: 122.3496704
	4001/50000: loss: 24488.1601562 - mea: 122.3478470
	5001/50000: loss: 24488.1523438 - mea: 122.3479385
	6001/50000: loss: 24488.1621094 - mea: 122.3480682
	7001/50000: loss: 24488.1582031 - mea: 122.3477631
	8001/50000: loss: 24488.1660156 - mea: 122.3479919
	9001/50000: loss: 24488.1601562 - mea: 122.3479309
	10001/50000: loss: 24488.1601562 - mea: 122.3479309
	11001/50000: loss: 24488.1601562 - mea: 122.3482971
	12001/50000: loss: 24488.1523438 - mea: 122.3479385
	13001/50000: loss: 24488.1621094 - mea: 122.3480682
	14001/50000: loss: 24488.1601562 - mea: 122.3477173
	15001/50000: loss: 24488.1621094 - mea: 122.3478775
	16001/50000: loss: 24488.1601562 - mea: 122.3480148
	17001/50000: loss: 24488.1562500 - mea: 122.3477631
	18001/50000: loss: 24488.1621094 - mea: 122.3479538
	19

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),24488.1582
mean_absolute_error,122.34784
_runtime,63.0
_timestamp,1623440467.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 284.5762329 - mea: 16.8499889
	1001/50000: loss: 0.0631955 - mea: 0.2083121
	2001/50000: loss: 0.0631937 - mea: 0.2082855
	3001/50000: loss: 0.0632031 - mea: 0.2083526
	4001/50000: loss: 0.0631937 - mea: 0.2082855
	5001/50000: loss: 0.0631937 - mea: 0.2082855
	6001/50000: loss: 0.0631937 - mea: 0.2082855
	7001/50000: loss: 0.0631937 - mea: 0.2082855
	8001/50000: loss: 0.0631937 - mea: 0.2082855
	9001/50000: loss: 0.0631937 - mea: 0.2082855
	10001/50000: loss: 0.0631937 - mea: 0.2082855
	11001/50000: loss: 0.0631937 - mea: 0.2082855
	12001/50000: loss: 0.0631937 - mea: 0.2082855
	13001/50000: loss: 0.0631937 - mea: 0.2082855
	14001/50000: loss: 0.0631937 - mea: 0.2082855
	15001/50000: loss: 0.0631937 - mea: 0.2082855
	16001/50000: loss: 0.0631937 - mea: 0.2082855
	17001/50000: loss: 0.0631937 - mea: 0.2082855
	18001/50000: loss: 0.0631937 - mea: 0.2082855
	19001/50000: loss: 0.0631937 - mea: 0.2082855
	20001/50000: loss: 0.0631937 - mea: 0.2082855
	21001/50000: loss: 0.0

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.06319
mean_absolute_error,0.20829
_runtime,31.0
_timestamp,1623440507.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇██
_timestamp,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇▇▇██
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 6.7213259 - mea: 2.5588174
	1001/50000: loss: 0.0450699 - mea: 0.1688710
	2001/50000: loss: 0.0417498 - mea: 0.1611167
	3001/50000: loss: 0.0419718 - mea: 0.1614885
	4001/50000: loss: 0.0410249 - mea: 0.1593877
	5001/50000: loss: 0.0392769 - mea: 0.1554545
	6001/50000: loss: 0.0383809 - mea: 0.1534177
	7001/50000: loss: 0.0376787 - mea: 0.1517839
	8001/50000: loss: 0.0370564 - mea: 0.1503112
	9001/50000: loss: 0.0364848 - mea: 0.1489597
	10001/50000: loss: 0.0359609 - mea: 0.1477043
	11001/50000: loss: 0.0354803 - mea: 0.1465679
	12001/50000: loss: 0.0350399 - mea: 0.1455474
	13001/50000: loss: 0.0346383 - mea: 0.1446489
	14001/50000: loss: 0.0342682 - mea: 0.1438292
	15001/50000: loss: 0.0339304 - mea: 0.1430951
	16001/50000: loss: 0.0336235 - mea: 0.1424363
	17001/50000: loss: 0.0333425 - mea: 0.1418156
	18001/50000: loss: 0.0330875 - mea: 0.1412512
	19001/50000: loss: 0.0328583 - mea: 0.1407427
	20001/50000: loss: 0.0326493 - mea: 0.1402836
	21001/50000: loss: 0.0324

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.03108
mean_absolute_error,0.13727
_runtime,32.0
_timestamp,1623440547.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 8.3164883 - mea: 2.8119428
	1001/50000: loss: 0.0154093 - mea: 0.1205957
	2001/50000: loss: 0.0128875 - mea: 0.1127877
	3001/50000: loss: 0.0110573 - mea: 0.1047929
	4001/50000: loss: 0.0094072 - mea: 0.0966364
	5001/50000: loss: 0.0078510 - mea: 0.0882059
	6001/50000: loss: 0.0063535 - mea: 0.0791773
	7001/50000: loss: 0.0048988 - mea: 0.0689847
	8001/50000: loss: 0.0038040 - mea: 0.0603444
	9001/50000: loss: 0.0033137 - mea: 0.0558943
	10001/50000: loss: 0.0031482 - mea: 0.0544466
	11001/50000: loss: 0.0029410 - mea: 0.0526236
	12001/50000: loss: 0.0027793 - mea: 0.0511843
	13001/50000: loss: 0.0026313 - mea: 0.0498328
	14001/50000: loss: 0.0024966 - mea: 0.0485720
	15001/50000: loss: 0.0023721 - mea: 0.0473766
	16001/50000: loss: 0.0022549 - mea: 0.0462199
	17001/50000: loss: 0.0021429 - mea: 0.0450836
	18001/50000: loss: 0.0020391 - mea: 0.0440032
	19001/50000: loss: 0.0019404 - mea: 0.0429485
	20001/50000: loss: 0.0018477 - mea: 0.0419324
	21001/50000: loss: 0.0017

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.00083
mean_absolute_error,0.02852
_runtime,63.0
_timestamp,1623440617.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇████
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇████
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


### Wandb pipeline

In [16]:
"""
pipeline for training model and tracking weights, gradients and metrics via Wandb
"""

def model_pipeline(hyperparameters):
    with wandb.init(project=f"{hyperparameters['project']}",
                    group=f"{hyperparameters['function']}", 
                    job_type=f"{hyperparameters['model']}",
                    name=f"{hyperparameters['function']}_{hyperparameters['model']}",
                    tags=f"{hyperparameters['tags']}",
                    config=hyperparameters):
        config = wandb.config

        random_model, model, data, criterion, optimizer = make(config)
        # print(model)

        # data
        train_data, test_data = data
 
        # random_model
        random_mea = []
        for i in range(100):
            abs = test(random_model, test_data)
            random_mea.append(abs.mean().item())
        random_result = np.mean(random_mea)

        # model
        train(model, 
            train_data,
            criterion, 
            optimizer,
            config.NUM_ITERS)
        
        abs = test(model, test_data)
        result = abs.mean().item()
        normalized_result = 100.0 * result/random_result

        print('model_type:', config.model)
        print('function:', config.function)
        print('test_result:', result)
        print('normalized_test_result:', normalized_result)

        ###wand plots
        # id = FUNC_id[config.function]

        table_res = wandb.Table(data=[[config.function, result]], 
                                columns=["FUNCTION", "Mean_of_MAE"])
        
        table_norm_res = wandb.Table(data=[[config.function, normalized_result]], 
                                     columns=["FUNCTION", "Scaled_Mean_of_MAE"])

        wandb.log({
        'Results': wandb.plot.bar(table_res, 
                                "FUNCTION",
                                "Mean_of_MAE",
                                title="Results"),
                   
        'Normalized_Results': wandb.plot.bar(table_norm_res, 
                                            "FUNCTION",
                                            "Scaled_Mean_of_MAE", 
                                            title="Normalized_Results")
        })

    return model


def make(config):
    # Make the model
    model = MODELS[config.model](config.NUM_LAYERS, 
                                config.in_dim, 
                                config.HIDDEN_DIM, 
                                config.out_dim).model

    random_model = MLP(config.NUM_LAYERS, 
                        config.in_dim, 
                        config.HIDDEN_DIM, 
                        config.out_dim).model

    # Make the data
    data = generate_data(config.function,
                         config.num_train, 
                         config.num_test, 
                         config.dim, 
                         config.num_sum, 
                         config.RANGE)
    
    # Make the loss and optimizer
    criterion = F.mse_loss
    optimizer = torch.optim.RMSprop(model.parameters(), lr=config.LEARNING_RATE)

    # return random_model, model, criterion, optimizer
    return random_model, model, data, criterion, optimizer


def generate_data(function, num_train, num_test, dim, num_sum, support):

    fn = ARITHMETIC_FUNCTIONS[function]
    data = torch.FloatTensor(dim).uniform_(*support).unsqueeze_(1)
    X, y = [], []
    for i in range(num_train + num_test):
        idx_a = random.sample(range(dim), num_sum)
        idx_b = random.sample([x for x in range(dim) if x not in idx_a], num_sum)
        a, b = data[idx_a].sum(), data[idx_b].sum()
        X.append([a, b])
        y.append(fn(a, b))
    X = torch.FloatTensor(X)
    y = torch.FloatTensor(y).unsqueeze_(1)
    indices = list(range(num_train + num_test))
    np.random.shuffle(indices)
    X_train, y_train = X[indices[num_test:]], y[indices[num_test:]]
    X_test, y_test = X[indices[:num_test]], y[indices[:num_test]]

    return ((X_train, y_train), (X_test, y_test))


def train(model, train_data, criterion, optimizer, num_iters):
    wandb.watch(model, criterion, log="all", log_freq=1000)

    data, target = train_data
    # model.train()
    for i in range(num_iters):
        out = model(data)
        loss = criterion(out, target)
        mea = torch.mean(torch.abs(target - out))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 1000 == 0:
            train_log(loss, mea, i, num_iters)

def train_log(loss, mea, epoch, num_iters):
    wandb.log({"epoch": epoch, "loss_(mse)": loss.item(), 'mean_absolute_error': mea.item()})
    print("\t{}/{}: loss: {:.7f} - mea: {:.7f}".format(epoch+1, num_iters, loss.item(), mea.item()))


def test(model, test_data):
    data, target = test_data
    with torch.no_grad():
        out = model(data)
        return torch.abs(target - out)