Ideal abstraction:
--The user specifies some "frequency" (an interval in seconds, examples, batches, or epochs) at which certain metrics are calculated. Different metrics can use different units and intervals.
--Metrics can be built-in or custom (for custom metrics, the user provides a function that calculates them, given a model and a DataLoader).
--Everything else (printing, logging to TensorBoard, checkpointing, etc.) operates over the resulting metrics dict.
--For checkpointing, the user can specify which metric to use and whether to minimize or maximize it.
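
A minimal sketch of the interval check and the min/max checkpoint rule described above. The names `is_due` and `BestMetricTracker` are hypothetical and not part of any library; the loss values are made up for illustration, and a real checkpointer would also save and restore model weights.

```python
import math


def is_due(count, every):
    """True when `count` units (epochs, batches, ...) is a multiple of `every`."""
    return every > 0 and count % every == 0


class BestMetricTracker:
    """Hypothetical helper: remembers the best value seen for one metric."""

    def __init__(self, metric="valid/accuracy", mode="max"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.metric = metric
        self.mode = mode
        self.best_score = math.inf if mode == "min" else -math.inf
        self.best_iteration = None

    def update(self, iteration, metrics_dict):
        """Return True if metrics_dict holds a new best score for self.metric."""
        score = metrics_dict[self.metric]
        if self.mode == "min":
            improved = score < self.best_score
        else:
            improved = score > self.best_score
        if improved:
            self.best_score = score
            self.best_iteration = iteration
        return improved


# Made-up validation losses, one per epoch (illustrative only).
valid_losses = [0.50, 0.40, 0.45, 0.30, 0.35, 0.20]
tracker = BestMetricTracker(metric="valid/loss", mode="min")
for epoch, loss in enumerate(valid_losses, start=1):
    if is_due(epoch, every=2):  # compute metrics every 2 epochs
        metrics_dict = {"valid/loss": loss}
        if tracker.update(epoch, metrics_dict):
            print(f"epoch {epoch}: new best valid/loss={loss:.3f}")
```

Because the tracker reads only from the metrics dict, the same class works for built-in and custom metrics alike.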

The full name of a metric is split/metric (e.g., valid/accuracy).
If no split is given, it is inferred from the logger being used; the checkpointer assumes valid,
    but you can explicitly indicate otherwise if you'd like!
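
One way the naming rule might be implemented, as a sketch only. `resolve_metric_name` and `KNOWN_SPLITS` are hypothetical names, and the set of split names is an assumption. A prefix counts as a split only when it is a known split name, so a custom metric whose own name contains a slash is left intact.

```python
# Assumed set of split names; adjust to whatever splits the trainer supports.
KNOWN_SPLITS = ("train", "valid", "test")


def resolve_metric_name(metric, default_split="valid"):
    """Hypothetical helper: return (split, name) for a possibly unqualified metric."""
    if "/" in metric:
        # Split on the first slash only, so "valid/foo/bar" -> ("valid", "foo/bar").
        prefix, rest = metric.split("/", 1)
        if prefix in KNOWN_SPLITS:
            return prefix, rest
    # No recognized split prefix: fall back to the default for this logger.
    return default_split, metric


print(resolve_metric_name("accuracy"))                        # uses the default split
print(resolve_metric_name("train/loss"))                      # explicit split wins
print(resolve_metric_name("accuracy", default_split="train")) # e.g. a train logger
print(resolve_metric_name("valid/foo/bar"))                   # metric name keeps its slash
```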

All metrics (standard and custom) are dumped to metrics_dict.
All metric values are pulled from metrics_dict.
Report all custom metrics; report only the standard metrics that are explicitly requested.
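
The reporting rule could look roughly like this. `select_metrics_to_report` is a hypothetical helper, and the metric names and values below are made up for illustration.

```python
def select_metrics_to_report(metrics_dict, standard_names, requested):
    """Hypothetical helper: keep every custom metric, plus requested standard ones."""
    report = {}
    for name, value in metrics_dict.items():
        is_standard = name in standard_names
        if not is_standard or name in requested:
            report[name] = value
    return report


# Illustrative values only.
metrics_dict = {
    "train/loss": 0.2,
    "train/lr": 0.01,        # custom metric, always reported
    "valid/accuracy": 0.9,
    "valid/f1": 0.8,         # standard but not requested, so omitted
}
standard_names = {"train/loss", "valid/accuracy", "valid/f1"}
requested = {"train/loss", "valid/accuracy"}

report = select_metrics_to_report(metrics_dict, standard_names, requested)
print(report)
```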

Should train loss be reported continuously in the tqdm progress bar?

In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
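from metal.utils import split_data
from synthetic.generate import singletask_synthetic

n = 10000  # number of examples
m = 10     # number of labeling sources
k = 2      # number of classes
D, L, X, Y, _ = singletask_synthetic(n, m, k)

# 80/10/10 train/valid/test split, stratified by label
Xs, Ys, Ls, Ds = split_data(X, Y, L, D, splits=[0.8, 0.1, 0.1], stratify_by=Y, seed=123)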

In [3]:
from metal.label_model.baselines import MajorityLabelVoter

mv = MajorityLabelVoter(seed=123)

Y_train_ps = mv.predict_proba(Ls[0])
# scores = mv.score((Ls[1], Ys[1]), metric=['precision', 'recall', 'f1'])

In [4]:
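def extract_lr(model, dataloader):
    # Custom metric function: takes (model, dataloader) and returns a dict of metric values.
    # The dataloader is unused here; only the optimizer's current learning rate is read.
    return {'lr': model.optimizer.param_groups[0]['lr']}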

In [18]:
from metal.end_model import EndModel

end_model = EndModel([1000,10,2])
end_model.train_model(
    (Xs[0], Y_train_ps), 
    valid_data=(Xs[1], Ys[1]), 
    l2=0.01, 
    batch_size=16, 
    n_epochs=5, 
    progress_bar=False,
    writer=None,
    checkpoint=True,
#     checkpoint_metric="train/loss",
#     checkpoint_metric_mode="min",
#     checkpoint_runway=3,
    log_unit='epochs',
    log_train_metrics=['train/loss'],
    log_train_metrics_func=extract_lr,
    log_valid_metrics=['accuracy','precision'],
    log_train_every=1,
    log_valid_every=1,
)


Network architecture:
Sequential(
  (0): IdentityModule()
  (1): Sequential(
    (0): Linear(in_features=1000, out_features=10, bias=True)
    (1): ReLU()
  )
  (2): Linear(in_features=10, out_features=2, bias=True)
)

No checkpoints will be saved in the first checkpoint_runway=3 iterations.
[1 epo]: TRAIN: [loss=0.165, lr=0.010] VALID: [accuracy=0.989, precision=0.982]
[2 epo]: TRAIN: [loss=0.150, lr=0.010] VALID: [accuracy=0.992, precision=1.000]
[3 epo]: TRAIN: [loss=0.146, lr=0.010] VALID: [accuracy=0.993, precision=0.996]
Saving model at iteration 3 with best score 0.146
Saving model at iteration 3 with best score 0.083
Saving model at iteration 3 with best score 0.082
[4 epo]: TRAIN: [loss=0.147, lr=0.010] VALID: [accuracy=0.990, precision=0.988]
[5 epo]: TRAIN: [loss=0.145, lr=0.010] VALID: [accuracy=0.996, precision=1.000]
Restoring best model from iteration 3 with score 0.082
Finished Training
Accuracy: 0.992
        y=1    y=2   
 l=1    488     3    
 l=2     5     504   


In [6]:
split_metric = "valid/foo/bar"
split_metric.split("/", 1)


['valid', 'foo/bar']