In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
import torch
from torch.utils.data import DataLoader, TensorDataset

def make_dataloader(n, batch_size):
    X = np.random.random((n, 2)) * 2 - 1
    Y = (X[:, 0] > X[:, 1] + 0.25).astype(int) + 1

    X = torch.tensor(X, dtype=torch.float)
    Y = torch.tensor(Y, dtype=torch.long)

    div1 = int(n*0.8)
    div2 = int(n*0.9)
    Xs = [X[:div1], X[div1:div2], X[div2:]]
    Ys = [Y[:div1], Y[div1:div2], Y[div2:]]

    dataset = TensorDataset(Xs[0], Ys[0])
    data_loader = DataLoader(dataset, batch_size=batch_size)
    return data_loader

In [3]:
import torch.nn as nn
from metal.mmtl.glue.glue_metrics import acc_f1, pearson_spearman
from metal.mmtl.task import ClassificationTask
from metal.mmtl.scorer import Scorer

BATCHSIZE = 8
torch.manual_seed(1234)

foo_input = nn.Linear(2, 10)
bar_input = foo_input #nn.Linear(100, 7)

foo_head = nn.Linear(10, 2)
bar_head = nn.Linear(10, 2)

foo_data = make_dataloader(6000, batch_size=BATCHSIZE)
bar_data = make_dataloader(2000, batch_size=BATCHSIZE)


# custom_metrics = {
#     pearson_spearman: ["pearson_corr", "spearman_corr", "pearson_spearman"]
# }
# scorer = Scorer(["accuracy"], custom_metric_funcs=custom_metrics)

foo = ClassificationTask("foo_task", {"train": foo_data, "valid": foo_data, "test": foo_data}, foo_input, foo_head)
bar = ClassificationTask("bar_task", {"train": bar_data, "valid": bar_data, "test": bar_data}, bar_input, bar_head)
# baz = Task("baz_task", "baz_head", [make_dataloader(100), None, None])
tasks = [foo, bar]

<torch._C.Generator at 0x1132fdf30>

In [18]:
%%time
from metal.end_model import EndModel
from metal.mmtl.metal_model import MetalModel
from metal.mmtl.trainer import MultitaskTrainer

model = MetalModel(tasks, device=-1, seed=123, verbose=False)
trainer = MultitaskTrainer(seed=123)
trainer.train_model(
    model, 
    tasks, 
    n_epochs=2, 
    lr=0.01, 
    progress_bar=False,
    log_unit="epochs",
    log_every=0.5,
    score_every=0.5,
    lr_scheduler="linear",
    warmup_steps = 0.5,
    warmup_unit = "epochs",
    min_lr = 0.0,
#     task_scheduler="proportional",
#     patience=10,
#     task_metrics=None,  # ["model_valid_loss"]
#     max_valid_examples=1000,
    trainer_metrics=[], # "glue_partial",
    test_split="test",
    grad_clip=0.01,
    writer="tensorboard",
    checkpoint=True,
    checkpoint_best=True,
    checkpoint_metric="foo_task/valid/accuracy",
    checkpoint_metric_mode="max",
)

Beginning train loop.
Expecting a total of approximately 5600 examples and 700 batches per epoch from 2 tasks.
[1.00 epo]: TRAIN:[foo_task/loss=8.14e-03, bar_task/loss=9.53e-03, model/loss=8.54e-03, model/lr=2.86e-05] VALID:[foo_task/accuracy=9.94e-01, bar_task/accuracy=9.99e-01]
Saving model at iteration 1.0014285714285713 with best (max) score 0.994
Restoring best model from iteration 1.00 with score 0.994
Cleaning checkpoints
Writing log to /Users/bradenjh/repos/metal/logs/2019_02_21/18_42_30/18_42_30.json
Finished training
{'bar_task/test/accuracy': 0.999, 'foo_task/test/accuracy': 0.994}
CPU times: user 608 ms, sys: 23.4 ms, total: 632 ms
Wall time: 656 ms


In [20]:
a = 5
type(a)('4')

4

In [5]:
# for batch in foo.data_loaders["train"]:
#     X, Y = batch
#     print(model(X, ['foo_task']))
#     print(model.calculate_loss(X, Y, ['foo_task']))    
#     print(model.calculate_output(X, ['foo_task']))    
#     break

In [6]:
model.predict_probs(foo, "train")

array([[9.94975328e-01, 5.02471300e-03],
       [3.47893092e-07, 9.99999642e-01],
       [1.00000000e+00, 4.07478211e-08],
       ...,
       [1.03129009e-24, 1.00000000e+00],
       [5.31781472e-15, 1.00000000e+00],
       [2.15263814e-28, 1.00000000e+00]])

In [7]:
model.predict(foo, "train")

array([1, 2, 1, ..., 2, 2, 2])

In [8]:
model.score(foo, "valid", "accuracy")

0.99825

In [9]:
# tasks = ['A', 'B', 'C']
# batch_counts = [5, 3, 10]
# threshold = max(batch_counts)
# for i in reversed(range(threshold)):
#     for task, count in zip(tasks, batch_counts):
#         if count > i:
#             print(task)

In [10]:
# import os
# filepath = os.path.join(os.environ["METALHOME"], "my_model.pkl")
# model.save(filepath)