### Imports

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
!pwd

/content


In [3]:
PATH_TO_FOLDER = '/content/drive/MyDrive/DL_project/EXPERIMENTS'

In [4]:
%cd {PATH_TO_FOLDER}

/content/drive/MyDrive/DL_project/EXPERIMENTS


In [6]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [18]:
from models import NAC, Complex_NAC, MLP

In [8]:
import math
import random
import numpy as np
from tqdm.notebook import tqdm
from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F

# Ensure deterministic behavior
torch.backends.cudnn.deterministic = True
random.seed(hash("setting random seeds") % 2**32 - 1)
np.random.seed(hash("improves reproducibility") % 2**32 - 1)
torch.manual_seed(hash("by removing stochasticity") % 2**32 - 1)
torch.cuda.manual_seed_all(hash("so runs are repeatable") % 2**32 - 1)

# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
%%capture
!pip install wandb --upgrade

In [16]:
import wandb
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

### Parameters

In [20]:
NAMES = dict(
    project = 'Experiment_1',
    tags = 'try_1'
)

MODEL_PARAMETERS = dict(
    in_dim=2,
    out_dim=1,
    NUM_LAYERS = 1,
    HIDDEN_DIM = 2,
)

ARITHMETIC_PARAMETERS = dict(
    num_train=500, 
    num_test=50,
    dim=100, 
    num_sum=5,
    RANGE = [5, 10]
)

TRAIN_PARAMETERS = dict(
                        LEARNING_RATE = 1e-2,
                        # NUM_ITERS = int(1e5),
                        NUM_ITERS = int(50000),
                        activation='relu6'
                        )

ARITHMETIC_FUNCTIONS = {
    'add': lambda x, y: x + y,
    'sub': lambda x, y: x - y,
    'mul': lambda x, y: x * y,
    'div': lambda x, y: x / y,
    # 'squared': lambda x, y: torch.pow(x, 2),
    # 'root': lambda x, y: torch.sqrt(x)
    }

MODELS = {
        'NAC': NAC, 
        'Complex_NAC': NAC
        }

In [None]:
# def get_id():
#     id_dct = {}
#     for i, key in enumerate(ARITHMETIC_FUNCTIONS.keys()):
#         id_dct[key] = i
#     return id_dct

# FUNC_id = get_id()

### Config for Wandb

In [21]:
def create_config(model, function):
    config = {'model': model,
              'function': function}
              
    models = MODEL_PARAMETERS.copy()
    funcs = ARITHMETIC_PARAMETERS.copy()
    params = TRAIN_PARAMETERS.copy()
    names = NAMES.copy()

    for key, value in models.items():
        config[key] = value
    for key, value in funcs.items():
        config[key] = value
    for key, value in params.items():
        config[key] = value
    for key, value in names.items():
        config[key] = value

    return config

### Example of config

In [23]:
config = create_config('NAC', 'add')

In [24]:
config

{'HIDDEN_DIM': 2,
 'LEARNING_RATE': 0.01,
 'NUM_ITERS': 50000,
 'NUM_LAYERS': 1,
 'RANGE': [5, 10],
 'activation': 'relu6',
 'dim': 100,
 'function': 'add',
 'in_dim': 2,
 'model': 'NAC',
 'num_sum': 5,
 'num_test': 50,
 'num_train': 500,
 'out_dim': 1,
 'project': 'Experiment_1',
 'tags': 'try_1'}

### Iteration on models and arithmetic functions

In [25]:
mdls = MODELS.keys()
fncts = ARITHMETIC_FUNCTIONS.keys()

for function in fncts:
    for arch in mdls:
        config = create_config(arch, function)
        # Build, train and analyze the model with the pipeline
        model = model_pipeline(config)

	1/50000: loss: 3937.5056152 - mea: 62.5801811
	1001/50000: loss: 1.1238458 - mea: 1.0577044
	2001/50000: loss: 0.0087257 - mea: 0.0932048
	3001/50000: loss: 0.0000705 - mea: 0.0083774
	4001/50000: loss: 0.0000006 - mea: 0.0007539
	5001/50000: loss: 0.0000000 - mea: 0.0000803
	6001/50000: loss: 0.0000000 - mea: 0.0000247
	7001/50000: loss: 0.0000000 - mea: 0.0000197
	8001/50000: loss: 0.0000000 - mea: 0.0000113
	9001/50000: loss: 0.0000000 - mea: 0.0000113
	10001/50000: loss: 0.0000000 - mea: 0.0000113
	11001/50000: loss: 0.0000000 - mea: 0.0000074
	12001/50000: loss: 0.0000000 - mea: 0.0000074
	13001/50000: loss: 0.0000000 - mea: 0.0000074
	14001/50000: loss: 0.0000000 - mea: 0.0000074
	15001/50000: loss: 0.0000000 - mea: 0.0000074
	16001/50000: loss: 0.0000000 - mea: 0.0000074
	17001/50000: loss: 0.0000000 - mea: 0.0000074
	18001/50000: loss: 0.0000000 - mea: 0.0000074
	19001/50000: loss: 0.0000000 - mea: 0.0000074
	20001/50000: loss: 0.0000000 - mea: 0.0000074
	21001/50000: loss: 0.

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.0
mean_absolute_error,0.0
_runtime,23.0
_timestamp,1623439033.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 5159.7626953 - mea: 71.6897354
	1001/50000: loss: 1.6701837 - mea: 1.2899815
	2001/50000: loss: 0.0130567 - mea: 0.1140550
	3001/50000: loss: 0.0001070 - mea: 0.0103245
	4001/50000: loss: 0.0000009 - mea: 0.0009364
	5001/50000: loss: 0.0000000 - mea: 0.0000923
	6001/50000: loss: 0.0000000 - mea: 0.0000267
	7001/50000: loss: 0.0000000 - mea: 0.0000188
	8001/50000: loss: 0.0000000 - mea: 0.0000112
	9001/50000: loss: 0.0000000 - mea: 0.0000112
	10001/50000: loss: 0.0000000 - mea: 0.0000112
	11001/50000: loss: 0.0000000 - mea: 0.0000074
	12001/50000: loss: 0.0000000 - mea: 0.0000074
	13001/50000: loss: 0.0000000 - mea: 0.0000074
	14001/50000: loss: 0.0000000 - mea: 0.0000074
	15001/50000: loss: 0.0000000 - mea: 0.0000074
	16001/50000: loss: 0.0000000 - mea: 0.0000074
	17001/50000: loss: 0.0000000 - mea: 0.0000074
	18001/50000: loss: 0.0000000 - mea: 0.0000074
	19001/50000: loss: 0.0000000 - mea: 0.0000074
	20001/50000: loss: 0.0000000 - mea: 0.0000074
	21001/50000: loss: 0.

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.0
mean_absolute_error,0.0
_runtime,24.0
_timestamp,1623439063.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
_timestamp,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 337.8469849 - mea: 17.7205048
	1001/50000: loss: 0.1812958 - mea: 0.3366335
	2001/50000: loss: 0.0011750 - mea: 0.0271012
	3001/50000: loss: 0.0000095 - mea: 0.0024385
	4001/50000: loss: 0.0000001 - mea: 0.0002239
	5001/50000: loss: 0.0000000 - mea: 0.0000336
	6001/50000: loss: 0.0000000 - mea: 0.0000178
	7001/50000: loss: 0.0000000 - mea: 0.0000133
	8001/50000: loss: 0.0000000 - mea: 0.0000114
	9001/50000: loss: 0.0000000 - mea: 0.0000097
	10001/50000: loss: 0.0000000 - mea: 0.0000087
	11001/50000: loss: 0.0000000 - mea: 0.0000085
	12001/50000: loss: 0.0000000 - mea: 0.0000075
	13001/50000: loss: 0.0000000 - mea: 0.0000075
	14001/50000: loss: 0.0000000 - mea: 0.0000070
	15001/50000: loss: 0.0000000 - mea: 0.0000065
	16001/50000: loss: 0.0000000 - mea: 0.0000061
	17001/50000: loss: 0.0000000 - mea: 0.0000062
	18001/50000: loss: 0.0000000 - mea: 0.0000053
	19001/50000: loss: 0.0000000 - mea: 0.0000056
	20001/50000: loss: 0.0000000 - mea: 0.0000050
	21001/50000: loss: 0.0

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.0
mean_absolute_error,0.0
_runtime,23.0
_timestamp,1623439094.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 280.9273071 - mea: 15.9568434
	1001/50000: loss: 0.1685604 - mea: 0.3309749
	2001/50000: loss: 0.0010934 - mea: 0.0266575
	3001/50000: loss: 0.0000089 - mea: 0.0024015
	4001/50000: loss: 0.0000001 - mea: 0.0002209
	5001/50000: loss: 0.0000000 - mea: 0.0000338
	6001/50000: loss: 0.0000000 - mea: 0.0000182
	7001/50000: loss: 0.0000000 - mea: 0.0000135
	8001/50000: loss: 0.0000000 - mea: 0.0000114
	9001/50000: loss: 0.0000000 - mea: 0.0000099
	10001/50000: loss: 0.0000000 - mea: 0.0000087
	11001/50000: loss: 0.0000000 - mea: 0.0000083
	12001/50000: loss: 0.0000000 - mea: 0.0000074
	13001/50000: loss: 0.0000000 - mea: 0.0000074
	14001/50000: loss: 0.0000000 - mea: 0.0000065
	15001/50000: loss: 0.0000000 - mea: 0.0000065
	16001/50000: loss: 0.0000000 - mea: 0.0000062
	17001/50000: loss: 0.0000000 - mea: 0.0000060
	18001/50000: loss: 0.0000000 - mea: 0.0000053
	19001/50000: loss: 0.0000000 - mea: 0.0000053
	20001/50000: loss: 0.0000000 - mea: 0.0000053
	21001/50000: loss: 0.0

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.0
mean_absolute_error,0.0
_runtime,23.0
_timestamp,1623439125.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 2011982.1250000 - mea: 1408.8945312
	1001/50000: loss: 1808265.6250000 - mea: 1335.1582031
	2001/50000: loss: 1807244.3750000 - mea: 1334.7780762
	3001/50000: loss: 1807228.8750000 - mea: 1334.7724609
	4001/50000: loss: 1807228.7500000 - mea: 1334.7724609
	5001/50000: loss: 1807228.7500000 - mea: 1334.7724609
	6001/50000: loss: 1807228.7500000 - mea: 1334.7724609
	7001/50000: loss: 1807228.7500000 - mea: 1334.7724609
	8001/50000: loss: 1807228.7500000 - mea: 1334.7724609
	9001/50000: loss: 1807228.7500000 - mea: 1334.7724609
	10001/50000: loss: 1807228.7500000 - mea: 1334.7724609
	11001/50000: loss: 1807228.7500000 - mea: 1334.7724609
	12001/50000: loss: 1807228.7500000 - mea: 1334.7724609
	13001/50000: loss: 1807228.7500000 - mea: 1334.7724609
	14001/50000: loss: 1807228.7500000 - mea: 1334.7724609
	15001/50000: loss: 1807228.7500000 - mea: 1334.7724609
	16001/50000: loss: 1807228.7500000 - mea: 1334.7724609
	17001/50000: loss: 1807228.7500000 - mea: 1334.7724609
	1800

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),1807228.75
mean_absolute_error,1334.77246
_runtime,24.0
_timestamp,1623439156.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇██
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 1891392.5000000 - mea: 1364.7497559
	1001/50000: loss: 1736666.6250000 - mea: 1307.2900391
	2001/50000: loss: 1736439.1250000 - mea: 1307.2034912
	3001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	4001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	5001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	6001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	7001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	8001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	9001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	10001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	11001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	12001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	13001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	14001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	15001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	16001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	17001/50000: loss: 1736436.1250000 - mea: 1307.2025146
	1800

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),1736436.125
mean_absolute_error,1307.20251
_runtime,23.0
_timestamp,1623439186.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
_timestamp,▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 242.0960388 - mea: 15.5305729
	1001/50000: loss: 0.0152736 - mea: 0.1029180
	2001/50000: loss: 0.0191439 - mea: 0.1218680
	3001/50000: loss: 0.0130938 - mea: 0.0978282
	4001/50000: loss: 0.0097501 - mea: 0.0828895
	5001/50000: loss: 0.0078614 - mea: 0.0735343
	6001/50000: loss: 0.0067358 - mea: 0.0675039
	7001/50000: loss: 0.0060212 - mea: 0.0633824
	8001/50000: loss: 0.0055417 - mea: 0.0605038
	9001/50000: loss: 0.0052042 - mea: 0.0584386
	10001/50000: loss: 0.0049569 - mea: 0.0568928
	11001/50000: loss: 0.0047694 - mea: 0.0557211
	12001/50000: loss: 0.0046241 - mea: 0.0547918
	13001/50000: loss: 0.0045084 - mea: 0.0540403
	14001/50000: loss: 0.0044146 - mea: 0.0534681
	15001/50000: loss: 0.0043375 - mea: 0.0529991
	16001/50000: loss: 0.0042733 - mea: 0.0526041
	17001/50000: loss: 0.0042191 - mea: 0.0522654
	18001/50000: loss: 0.0041734 - mea: 0.0519855
	19001/50000: loss: 0.0041340 - mea: 0.0517402
	20001/50000: loss: 0.0041004 - mea: 0.0515284
	21001/50000: loss: 0.0

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.0039
mean_absolute_error,0.05019
_runtime,24.0
_timestamp,1623439217.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_timestamp,▁▁▁▂▂▂▂▂▃▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


	1/50000: loss: 151.4924622 - mea: 11.9921055
	1001/50000: loss: 0.0384302 - mea: 0.1817961
	2001/50000: loss: 0.0189010 - mea: 0.1209432
	3001/50000: loss: 0.0123164 - mea: 0.0943413
	4001/50000: loss: 0.0092009 - mea: 0.0796465
	5001/50000: loss: 0.0075385 - mea: 0.0708008
	6001/50000: loss: 0.0065522 - mea: 0.0654130
	7001/50000: loss: 0.0059187 - mea: 0.0618472
	8001/50000: loss: 0.0054866 - mea: 0.0593323
	9001/50000: loss: 0.0051773 - mea: 0.0575228
	10001/50000: loss: 0.0049478 - mea: 0.0562022
	11001/50000: loss: 0.0047719 - mea: 0.0551690
	12001/50000: loss: 0.0046335 - mea: 0.0543634
	13001/50000: loss: 0.0045222 - mea: 0.0537049
	14001/50000: loss: 0.0044317 - mea: 0.0531589
	15001/50000: loss: 0.0043567 - mea: 0.0527129
	16001/50000: loss: 0.0042938 - mea: 0.0523470
	17001/50000: loss: 0.0042405 - mea: 0.0520467
	18001/50000: loss: 0.0041951 - mea: 0.0517868
	19001/50000: loss: 0.0041564 - mea: 0.0515605
	20001/50000: loss: 0.0041230 - mea: 0.0513643
	21001/50000: loss: 0.0

VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,49000.0
loss_(mse),0.00392
mean_absolute_error,0.05017
_runtime,24.0
_timestamp,1623439248.0
_step,50.0


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss_(mse),█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
mean_absolute_error,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
_runtime,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_timestamp,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
_step,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███


### Wandb pipeline

In [14]:
"""
pipeline for training model and tracking weights, gradients and metrics via Wandb
"""

def model_pipeline(hyperparameters):
    with wandb.init(project=f"{hyperparameters['project']}",
                    group=f"{hyperparameters['function']}", 
                    job_type=f"{hyperparameters['model']}",
                    name=f"{hyperparameters['function']}_{hyperparameters['model']}",
                    tags=f"{hyperparameters['tags']}",
                    config=hyperparameters):
        config = wandb.config

        random_model, model, data, criterion, optimizer = make(config)
        # print(model)

        # data
        train_data, test_data = data
 
        # random_model
        random_mea = []
        for i in range(100):
            abs = test(random_model, test_data)
            random_mea.append(abs.mean().item())
        random_result = np.mean(random_mea)

        # model
        train(model, 
            train_data,
            criterion, 
            optimizer,
            config.NUM_ITERS)
        
        abs = test(model, test_data)
        result = abs.mean().item()
        normalized_result = 100.0 * result/random_result

        print('model_type:', config.model)
        print('function:', config.function)
        print('test_result:', result)
        print('normalized_test_result:', normalized_result)

        ###wand plots
        # id = FUNC_id[config.function]

        table_res = wandb.Table(data=[[config.function, result]], 
                                columns=["FUNCTION", "Mean_of_MAE"])
        
        table_norm_res = wandb.Table(data=[[config.function, normalized_result]], 
                                     columns=["FUNCTION", "Scaled_Mean_of_MAE"])

        wandb.log({
        'Results': wandb.plot.bar(table_res, 
                                "FUNCTION",
                                "Mean_of_MAE",
                                title="Results"),
                   
        'Normalized_Results': wandb.plot.bar(table_norm_res, 
                                            "FUNCTION",
                                            "Scaled_Mean_of_MAE", 
                                            title="Normalized_Results")
        })

    return model


def make(config):
    # Make the model
    model = MODELS[config.model](config.NUM_LAYERS, 
                                config.in_dim, 
                                config.HIDDEN_DIM, 
                                config.out_dim).model

    random_model = MLP(config.NUM_LAYERS, 
                        config.in_dim, 
                        config.HIDDEN_DIM, 
                        config.out_dim).model

    # Make the data
    data = generate_data(config.function,
                         config.num_train, 
                         config.num_test, 
                         config.dim, 
                         config.num_sum, 
                         config.RANGE)
    
    # Make the loss and optimizer
    criterion = F.mse_loss
    optimizer = torch.optim.RMSprop(model.parameters(), lr=config.LEARNING_RATE)

    # return random_model, model, criterion, optimizer
    return random_model, model, data, criterion, optimizer


def generate_data(function, num_train, num_test, dim, num_sum, support):

    fn = ARITHMETIC_FUNCTIONS[function]
    data = torch.FloatTensor(dim).uniform_(*support).unsqueeze_(1)
    X, y = [], []
    for i in range(num_train + num_test):
        idx_a = random.sample(range(dim), num_sum)
        idx_b = random.sample([x for x in range(dim) if x not in idx_a], num_sum)
        a, b = data[idx_a].sum(), data[idx_b].sum()
        X.append([a, b])
        y.append(fn(a, b))
    X = torch.FloatTensor(X)
    y = torch.FloatTensor(y).unsqueeze_(1)
    indices = list(range(num_train + num_test))
    np.random.shuffle(indices)
    X_train, y_train = X[indices[num_test:]], y[indices[num_test:]]
    X_test, y_test = X[indices[:num_test]], y[indices[:num_test]]

    return ((X_train, y_train), (X_test, y_test))


def train(model, train_data, criterion, optimizer, num_iters):
    wandb.watch(model, criterion, log="all", log_freq=1000)

    data, target = train_data
    # model.train()
    for i in range(num_iters):
        out = model(data)
        loss = criterion(out, target)
        mea = torch.mean(torch.abs(target - out))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % 1000 == 0:
            train_log(loss, mea, i, num_iters)

def train_log(loss, mea, epoch, num_iters):
    wandb.log({"epoch": epoch, "loss_(mse)": loss.item(), 'mean_absolute_error': mea.item()})
    print("\t{}/{}: loss: {:.7f} - mea: {:.7f}".format(epoch+1, num_iters, loss.item(), mea.item()))


def test(model, test_data):
    data, target = test_data
    with torch.no_grad():
        out = model(data)
        return torch.abs(target - out)