In [1]:
# Enable the autoreload extension in mode 2 so edits to imported project
# modules are picked up without restarting the kernel.
# NOTE(review): re-running this cell on a live kernel prints the
# "already loaded" notice seen in the output below.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_dataloader(n, batch_size, seed=None):
    """Build a DataLoader over the training slice of a synthetic 2-D dataset.

    Draws ``n`` points uniformly from ``[-1, 1) x [-1, 1)`` and labels each
    point ``2`` when ``x0 > x1 + 0.25`` and ``1`` otherwise. Only the first
    80% of the points are wrapped in the returned loader; the remaining 20%
    are not returned.

    Args:
        n: Total number of points to draw.
        batch_size: Batch size forwarded to ``DataLoader``.
        seed: Optional integer seed. When given, a dedicated
            ``numpy.random.RandomState`` is used so the data are
            reproducible. When ``None`` (default) the global numpy RNG is
            used.

    Returns:
        A ``DataLoader`` over a ``TensorDataset`` of ``(X, Y)`` where ``X``
        is a float tensor of shape ``(int(n * 0.8), 2)`` and ``Y`` is a long
        tensor of labels in ``{1, 2}``.
    """
    # Imported locally because the notebook's import cell never brings numpy
    # into scope; without this the function raises NameError on a fresh kernel.
    import numpy as np

    rng = np.random if seed is None else np.random.RandomState(seed)
    X = rng.random_sample((n, 2)) * 2 - 1
    # Boolean -> {0, 1}, then shifted so the labels are {1, 2}.
    Y = (X[:, 0] > X[:, 1] + 0.25).astype(int) + 1

    n_train = int(n * 0.8)
    X_train = torch.tensor(X[:n_train], dtype=torch.float)
    Y_train = torch.tensor(Y[:n_train], dtype=torch.long)

    return DataLoader(TensorDataset(X_train, Y_train), batch_size=batch_size)

In [3]:
import torch.nn as nn
from metal.mmtl.utils.metrics import acc_f1, pearson_spearman
from metal.mmtl.task import Task
from metal.mmtl.scorer import Scorer

BATCHSIZE = 32
SPLITS = ("train", "valid", "test")

# Both tasks share one input module (the same nn.Linear object), so its
# parameters are common to foo and bar; each task gets its own head.
# (A separate bar input, nn.Linear(100, 7), was considered and left disabled.)
foo_input = bar_input = nn.Linear(2, 10)
foo_head, bar_head = (nn.Linear(10, 2) for _ in range(2))

# One synthetic loader per task, created in foo-then-bar order.
foo_data, bar_data = (
    make_dataloader(n_points, batch_size=BATCHSIZE) for n_points in (5000, 2000)
)

# Disabled example of attaching custom metrics through a Scorer:
# custom_metrics = {
#     pearson_spearman: ["pearson_corr", "spearman_corr", "pearson_spearman"]
# }
# scorer = Scorer(["accuracy"], custom_metric_funcs=custom_metrics)

# NOTE(review): every split of a task points at the same loader, so the
# "valid" and "test" scores are measured on the training data.
foo = Task("foo_task", dict.fromkeys(SPLITS, foo_data), foo_input, foo_head)
bar = Task("bar_task", dict.fromkeys(SPLITS, bar_data), bar_input, bar_head)
# baz = Task("baz_task", "baz_head", [make_dataloader(100), None, None])
tasks = [foo, bar]

In [18]:
from metal.end_model import EndModel
from metal.mmtl.metal_model import MetalModel
from metal.mmtl.trainer import MultitaskTrainer

model = MetalModel(tasks, device=-1, verbose=False)
trainer = MultitaskTrainer()

# All trainer settings gathered in one place so they are easy to tweak.
# Disabled alternatives: optimizer="sgd", patience=10,
# task_metrics=["foo_task/valid/acc_f1"].
train_config = dict(
    # Schedule
    n_epochs=3,
    lr=0.1,
    lr_scheduler="linear",
    warmup_steps=100,
    min_lr=0.0,
    grad_clip=0.01,
    # Logging / scoring cadence, expressed in fractions of an epoch
    progress_bar=True,
    log_unit="epochs",
    log_every=0.1,
    score_every=0.1,
    trainer_metrics=["lr"],
    test_split="test",
    writer="tensorboard",
    # Checkpointing: keep the model with the highest foo validation accuracy
    checkpoint=True,
    checkpoint_best=True,
    checkpoint_metric="foo_task/valid/accuracy",
    checkpoint_metric_mode="max",
)
trainer.train_model(model, tasks, **train_config)

Beginning train loop.
Expecting a total of approximately 5600 examples and 175 batches per epoch from 2 tasks.


HBox(children=(IntProgress(value=0, max=175), HTML(value='')))

[0.10 epo]: TRAIN:[model/loss=0.047] VALID:[foo_task/accuracy=0.999, bar_task/accuracy=0.997, model/lr=0.016]
Saving model at iteration 0.10285714285714286 with best (max) score 0.999
[0.21 epo]: TRAIN:[model/loss=0.027] VALID:[foo_task/accuracy=0.996, bar_task/accuracy=0.993, model/lr=0.034]
[0.31 epo]: TRAIN:[model/loss=0.020] VALID:[foo_task/accuracy=0.984, bar_task/accuracy=0.988, model/lr=0.052]
[0.41 epo]: TRAIN:[model/loss=0.051] VALID:[foo_task/accuracy=0.985, bar_task/accuracy=0.981, model/lr=0.070]
[0.51 epo]: TRAIN:[model/loss=0.154] VALID:[foo_task/accuracy=0.981, bar_task/accuracy=0.996, model/lr=0.088]
[0.62 epo]: TRAIN:[model/loss=0.032] VALID:[foo_task/accuracy=0.990, bar_task/accuracy=0.989, model/lr=0.099]
[0.72 epo]: TRAIN:[model/loss=0.033] VALID:[foo_task/accuracy=0.991, bar_task/accuracy=0.989, model/lr=0.094]
[0.82 epo]: TRAIN:[model/loss=0.035] VALID:[foo_task/accuracy=0.975, bar_task/accuracy=0.979, model/lr=0.090]
[0.93 epo]: TRAIN:[model/loss=0.104] VALID:[fo

HBox(children=(IntProgress(value=0, max=175), HTML(value='')))

[1.03 epo]: TRAIN:[model/loss=0.130] VALID:[foo_task/accuracy=0.978, bar_task/accuracy=0.980, model/lr=0.082]
[1.13 epo]: TRAIN:[model/loss=0.025] VALID:[foo_task/accuracy=0.988, bar_task/accuracy=0.982, model/lr=0.077]
[1.23 epo]: TRAIN:[model/loss=0.031] VALID:[foo_task/accuracy=0.997, bar_task/accuracy=0.989, model/lr=0.073]
[1.34 epo]: TRAIN:[model/loss=0.021] VALID:[foo_task/accuracy=0.993, bar_task/accuracy=0.989, model/lr=0.069]
[1.44 epo]: TRAIN:[model/loss=0.039] VALID:[foo_task/accuracy=0.985, bar_task/accuracy=0.986, model/lr=0.065]
[1.54 epo]: TRAIN:[model/loss=0.105] VALID:[foo_task/accuracy=0.986, bar_task/accuracy=0.986, model/lr=0.060]
[1.65 epo]: TRAIN:[model/loss=0.037] VALID:[foo_task/accuracy=0.994, bar_task/accuracy=0.996, model/lr=0.056]
[1.75 epo]: TRAIN:[model/loss=0.010] VALID:[foo_task/accuracy=0.996, bar_task/accuracy=0.996, model/lr=0.052]
[1.85 epo]: TRAIN:[model/loss=0.014] VALID:[foo_task/accuracy=0.981, bar_task/accuracy=0.983, model/lr=0.048]
[1.95 epo]

HBox(children=(IntProgress(value=0, max=175), HTML(value='')))

[2.06 epo]: TRAIN:[model/loss=0.014] VALID:[foo_task/accuracy=0.990, bar_task/accuracy=0.993, model/lr=0.039]
[2.16 epo]: TRAIN:[model/loss=0.004] VALID:[foo_task/accuracy=0.991, bar_task/accuracy=0.989, model/lr=0.035]
[2.26 epo]: TRAIN:[model/loss=0.038] VALID:[foo_task/accuracy=0.989, bar_task/accuracy=0.991, model/lr=0.031]
[2.37 epo]: TRAIN:[model/loss=0.031] VALID:[foo_task/accuracy=0.995, bar_task/accuracy=0.990, model/lr=0.027]
[2.47 epo]: TRAIN:[model/loss=0.015] VALID:[foo_task/accuracy=0.995, bar_task/accuracy=0.996, model/lr=0.022]
[2.57 epo]: TRAIN:[model/loss=0.019] VALID:[foo_task/accuracy=0.998, bar_task/accuracy=0.996, model/lr=0.018]
[2.67 epo]: TRAIN:[model/loss=0.005] VALID:[foo_task/accuracy=0.999, bar_task/accuracy=0.998, model/lr=0.014]
[2.78 epo]: TRAIN:[model/loss=0.011] VALID:[foo_task/accuracy=0.996, bar_task/accuracy=0.994, model/lr=0.010]
[2.88 epo]: TRAIN:[model/loss=0.007] VALID:[foo_task/accuracy=0.999, bar_task/accuracy=0.998, model/lr=0.005]
[2.98 epo]

In [16]:
# Read the learning rate of the optimizer's first parameter group after
# training (the run above requested min_lr=0.0 with a linear schedule).
trainer.optimizer.param_groups[0]["lr"]

1.1801719665527344e-08

In [17]:
# for batch in foo.data_loaders["train"]:
#     X, Y = batch
#     print(model(X, ['foo_task']))
#     print(model.calculate_loss(X, Y, ['foo_task']))    
#     print(model.calculate_output(X, ['foo_task']))    
#     break

In [7]:
# Predicted probabilities for the foo task on its "train" split; the output
# below shows one row per example with two columns.
model.predict_probs(foo, "train")

array([[8.86418000e-02, 9.11358178e-01],
       [1.00000000e+00, 8.56352843e-19],
       [8.66827876e-18, 1.00000000e+00],
       ...,
       [9.96014357e-01, 3.98571556e-03],
       [3.08819072e-05, 9.99969125e-01],
       [1.00000000e+00, 3.44794659e-09]])

In [8]:
# Predicted labels for the foo task on its "train" split; the output below
# contains the values 1 and 2, matching the label encoding in make_dataloader.
model.predict(foo, "train")

array([2, 1, 2, ..., 1, 2, 1])

In [9]:
# Accuracy of the foo task on its "valid" split.
# NOTE(review): foo's "valid" split was built from the same loader as
# "train", so this is not a held-out estimate.
model.score(foo, "valid", "accuracy")

0.99775

In [10]:
# Scratch check: listing a dict yields its keys in insertion order.
a = dict(foo=1, bar=2)
list(a)

['foo', 'bar']

In [11]:
# import os
# filepath = os.path.join(os.environ["METALHOME"], "my_model.pkl")
# model.save(filepath)

In [12]:
# Scratch check: indexing an empty list raises IndexError. The error is
# caught and printed so the notebook still completes under
# "Restart Kernel -> Run All" instead of aborting on this cell.
a = []
try:
    a[0]
except IndexError as err:
    print(f"IndexError: {err}")

IndexError: list index out of range