In [1]:
import torch
import json
from pathlib import Path
from task.taskManager import get_model

def load_checkpoint(checkpoint_path):
    checkpoint_path = Path(checkpoint_path)
    assert checkpoint_path.exists()
    checkpoint = torch.load(checkpoint_path)
    task = get_model(checkpoint["arguments"])
    task.net.load_state_dict(checkpoint["model"])
    loss_fn = lambda model, x, y: task.__class__.loss_fn(model, x, y, task.criterion) # we hide the criterion
    forward_fn = task.__class__.forward_fn
    return task.net, forward_fn, loss_fn, task.dataloader_fn, checkpoint["history"]

def parse_checkpoint_names(path):
    file_name = str(path).split("/")[-1]
    info = file_name.split("-")
    task = info[0]
    model = info[1]
    modifier = info[2]
    batch = int(info[4])
    return task, model, modifier, batch

def most_trained(task, model, modifier):
    #
    # Returns the longest trained checkpoint of a task-model combination
    #
    # Example:
    #     cp_path_addTask_ChronoLSTM = most_trained("addTask", "ChronoLSTM")
    #     model, forward_fn, loss_fn, dataloader_fn, history  = load_checkpoint(cp_path_addTask_ChronoLSTM)
    #
    checkpoint_paths = Path("./saves/").glob("*.pth")
    goods = []
    for path in checkpoint_paths:
        cp_task, cp_model, cp_modifier, cp_batch = parse_checkpoint_names(path)
        if cp_task == task and cp_model == model and cp_modifier == modifier:
            goods.append({"path": path, "batch": cp_batch})
    return max(goods, key=lambda x: x["batch"])["path"]

def view_addTask(predictions, y):
    # torch.nn.functional.mse_loss averages the results. we want the individual results here
    losses = (predictions - y)**2  
    predictions = predictions.contiguous().view(-1).cpu().data.numpy()
    y = y.view(-1).cpu().data.numpy()
    losses = losses.view(-1).cpu().data.numpy()
    print("\nPreds:    Y:        Loss:")
    for p, y_, l in zip(predictions, y, losses):
        print(f"{p:0.5f},  {y_:0.5f},  {l:0.7f}")

In [4]:
checkpoint_path = most_trained("addTask", "Rnn", "leaky")
print(checkpoint_path)
model, forward_fn, loss_fn, dataloader_fn, history = load_checkpoint(checkpoint_path)

saves/addTask-Rnn-leaky-batch-29000-seed-0.pth


In [5]:
for batch_num, x, y in dataloader_fn():
    if batch_num > 50:
        break
        
    predictions = forward_fn(model, x)
    losses = loss_fn(model, x, y)
    view_addTask(predictions, y)



Preds:    Y:        Loss:
0.82145,  0.75425,  0.0045162
0.88464,  0.87539,  0.0029096
1.64446,  1.61078,  0.6230509
0.58415,  0.58731,  0.0548212
1.48655,  1.46811,  0.4181754
0.60644,  0.54561,  0.0760863
1.21396,  1.17394,  0.1242533
1.15450,  1.11482,  0.0860659
1.00418,  0.96724,  0.0212546
1.44639,  1.47252,  0.4239004
0.85904,  0.83196,  0.0001106
0.96277,  0.96352,  0.0201839
1.32574,  1.32223,  0.2507847
1.51083,  1.41269,  0.3495631
1.05986,  1.07758,  0.0656029
1.32354,  1.30649,  0.2352697
0.54029,  0.55100,  0.0731421
0.55502,  0.54335,  0.0773371
1.65106,  1.70800,  0.7859767
0.80976,  0.83839,  0.0002869
0.57593,  0.56074,  0.0679701
1.05418,  1.02595,  0.0418195
1.63598,  1.60178,  0.6089185
1.02042,  1.05907,  0.0564631
0.95056,  0.95610,  0.0181312
1.37702,  1.38713,  0.3199916
0.96131,  1.03662,  0.0462985
1.43181,  1.49523,  0.4539812
1.46292,  1.54267,  0.5201600
0.80502,  0.81428,  0.0000514
0.59293,  0.57512,  0.0606800
0.94172,  0.90526,  0.0070248

Preds:    Y:


Preds:    Y:        Loss:
0.78659,  0.78156,  0.0000253
0.84015,  0.79232,  0.0000329
0.37749,  0.33948,  0.1999003
1.26999,  1.29837,  0.2619242
1.26661,  1.28870,  0.2521227
1.65895,  1.67173,  0.7834767
0.86536,  0.91505,  0.0165033
1.36306,  1.38471,  0.3577576
0.76522,  0.81695,  0.0009219
0.97593,  0.96973,  0.0335432
0.17246,  0.23474,  0.3045379
1.27745,  1.32109,  0.2856934
1.18050,  1.24147,  0.2069184
0.69801,  0.69765,  0.0079091
0.85452,  0.87551,  0.0079076
0.43183,  0.46298,  0.1047178
0.42138,  0.38391,  0.1621458
0.85219,  0.82622,  0.0015708
0.31121,  0.31744,  0.2200981
1.26655,  1.27661,  0.2401236
1.76923,  1.76958,  0.9662716
0.74124,  0.73653,  0.0025052
0.95222,  1.01643,  0.0528270
1.61628,  1.60331,  0.6670424
0.86901,  0.88979,  0.0106514
1.22443,  1.30278,  0.2664527
0.25029,  0.25598,  0.2815404
0.51744,  0.56393,  0.0495744
1.77089,  1.73952,  0.9080886
1.59515,  1.60842,  0.6754127
0.86778,  0.89851,  0.0125275
1.55329,  1.61183,  0.6810253

Preds:    Y:


Preds:    Y:        Loss:
0.46864,  0.49261,  0.0005746
0.61556,  0.66460,  0.0384012
1.43422,  1.45348,  0.9699122
1.32780,  1.37243,  0.8168398
0.87464,  0.83250,  0.1323927
1.45383,  1.42667,  0.9178149
0.94803,  0.96058,  0.2420010
1.38626,  1.35623,  0.7878184
0.45404,  0.47584,  0.0000518
1.02122,  1.04165,  0.3283457
0.96222,  0.98913,  0.2709142
0.53164,  0.52482,  0.0031565
0.14709,  0.20804,  0.0679103
1.35951,  1.42834,  0.9210165
0.89902,  0.90143,  0.1873076
0.47125,  0.50863,  0.0015989
0.34397,  0.41748,  0.0026178
1.08312,  1.11189,  0.4137698
1.48878,  1.52463,  1.1151189
0.46872,  0.52872,  0.0036091
1.12834,  1.14482,  0.4572172
0.47825,  0.47683,  0.0000671
0.87195,  0.96764,  0.2490021
1.20091,  1.28575,  0.6676622
1.32277,  1.31631,  0.7185476
0.88080,  0.92026,  0.2039572
1.53017,  1.51500,  1.0948640
1.22331,  1.24568,  0.6037956
0.60764,  0.62165,  0.0234123
0.78740,  0.76845,  0.0898837
1.07365,  1.03393,  0.3195484
1.16823,  1.20780,  0.5463643

Preds:    Y:


Preds:    Y:        Loss:
1.46073,  1.52447,  0.0040634
1.05943,  1.05709,  0.1629223
1.47501,  1.44938,  0.0001288
0.82444,  0.78784,  0.4527755
1.65277,  1.71059,  0.0624327
1.76426,  1.77768,  0.1004591
1.58432,  1.65092,  0.0361740
0.24697,  0.27895,  1.3965963
0.94291,  0.95635,  0.2543932
1.28513,  1.30520,  0.0241875
1.19499,  1.20422,  0.0657974
0.73626,  0.73257,  0.5302072
1.49362,  1.42603,  0.0012037
0.76594,  0.79011,  0.4497304
1.17132,  1.21307,  0.0613353
0.63928,  0.64946,  0.6581544
0.49377,  0.49889,  0.9251381
0.97428,  0.93563,  0.2757290
0.71445,  0.70048,  0.5779682
0.44308,  0.41794,  1.0874091
1.12647,  1.17762,  0.0801500
1.52924,  1.56871,  0.0116608
1.26111,  1.21444,  0.0606549
1.13200,  1.17314,  0.0827051
1.26165,  1.27979,  0.0327386
1.09711,  1.09864,  0.1311085
0.94895,  1.00770,  0.2052363
1.35934,  1.33746,  0.0151936
1.00721,  1.00274,  0.2097481
0.25551,  0.34305,  1.2492101
0.79261,  0.83677,  0.3893277
1.87692,  1.88644,  0.1812321
