In [19]:
import torch
import json
from pathlib import Path
from task.taskManager import get_model

def load_checkpoint(checkpoint_path):
    checkpoint_path = Path(checkpoint_path)
    assert checkpoint_path.exists()
    checkpoint = torch.load(checkpoint_path)
    task = get_model(checkpoint["arguments"])
    task.net.load_state_dict(checkpoint["model"])
    loss_fn = lambda model, x, y: task.__class__.loss_fn(model, x, y, task.criterion) # we hide the criterion
    forward_fn = task.__class__.forward_fn
    return task.net, forward_fn, loss_fn, task.dataloader_fn, checkpoint["history"]

def parse_checkpoint_names(path):
    file_name = str(path).split("/")[-1]
    info = file_name.split("-")
    task = info[0]
    model = info[1]
    modifier = info[2]
    batch = int(info[4])
    return task, model, modifier, batch

def most_trained(task, model, modifier):
    #
    # Returns the longest trained checkpoint of a task-model combination
    #
    # Example:
    #     cp_path_addTask_ChronoLSTM = most_trained("addTask", "ChronoLSTM")
    #     model, forward_fn, loss_fn, dataloader_fn, history  = load_checkpoint(cp_path_addTask_ChronoLSTM)
    #
    checkpoint_paths = Path("./saves/").glob("*.pth")
    goods = []
    for path in checkpoint_paths:
        cp_task, cp_model, cp_modifier, cp_batch = parse_checkpoint_names(path)
        if cp_task == task and cp_model == model and cp_modifier == modifier:
            goods.append({"path": path, "batch": cp_batch})
    return max(goods, key=lambda x: x["batch"])["path"]

def view_addTask(predictions, y):
    # torch.nn.functional.mse_loss averages the results. we want the individual results here
    losses = (predictions - y)**2  
    predictions = predictions.contiguous().view(-1).cpu().data.numpy()
    y = y.view(-1).cpu().data.numpy()
    losses = losses.view(-1).cpu().data.numpy()
    print("\nPreds:    Y:        Loss:")
    for p, y_, l in zip(predictions, y, losses):
        print(f"{p:0.5f},  {y_:0.5f},  {l:0.7f}")

In [21]:
checkpoint_path = most_trained("addTask", "Rnn", "leaky")
print(checkpoint_path)
model, forward_fn, loss_fn, dataloader_fn, history = load_checkpoint(checkpoint_path)

saves/addTask-Rnn-leaky-batch-16000-seed-0.pth


In [22]:
  
for batch_num, x, y in dataloader_fn():
    if batch_num > 50:
        break
        
    predictions = forward_fn(model, x)
    losses = loss_fn(model, x, y)
    view_addTask(predictions, y)



Preds:    Y:        Loss:
0.96845,  0.97816,  0.0000942
1.26663,  1.29780,  0.1084676
0.97529,  0.99510,  0.0007102
1.41611,  1.19031,  0.0492218
1.55946,  1.65872,  0.4764619
1.28640,  1.24449,  0.0761981
0.59609,  0.57377,  0.1557745
1.43465,  1.31639,  0.1210613
0.94360,  1.05254,  0.0070707
1.48476,  1.41241,  0.1971021
1.68150,  1.79906,  0.6899123
0.84960,  0.68982,  0.0776355
1.50885,  1.76791,  0.6391381
1.39924,  1.40217,  0.1881067
1.06266,  1.03580,  0.0045358
1.49937,  1.52993,  0.3152578
1.32842,  1.33427,  0.1338217
1.44289,  1.64048,  0.4516208
0.70703,  1.03533,  0.0044721
1.24058,  1.26613,  0.0886133
1.52811,  1.42967,  0.2127196
1.56800,  1.63512,  0.4444387
0.25705,  0.32448,  0.4147035
1.48278,  1.64924,  0.4634758
0.64666,  0.73442,  0.0547706
1.24213,  1.23669,  0.0719535
0.74861,  0.59278,  0.1411277
0.45965,  0.51277,  0.2076449
1.16715,  1.26576,  0.0883941
1.20141,  1.19234,  0.0501249
0.15257,  0.21349,  0.5699626
0.99444,  1.03808,  0.0048481

Preds:    Y:


Preds:    Y:        Loss:
1.53937,  1.58292,  0.0018966
0.42863,  0.51801,  1.0431751
1.01241,  0.77600,  0.5827311
1.73983,  1.76311,  0.0500582
1.27417,  1.39058,  0.0221392
0.72850,  0.94631,  0.3517219
1.55207,  1.58934,  0.0024972
1.01806,  1.19953,  0.1154915
1.52641,  1.71811,  0.0319492
1.17455,  1.09778,  0.1950049
1.50568,  1.61344,  0.0054856
0.61588,  0.76421,  0.6008786
0.92028,  0.84513,  0.4819761
0.41935,  0.33888,  1.4411806
0.87691,  0.88201,  0.4321187
0.77789,  0.68282,  0.7336856
1.01856,  0.82172,  0.5150203
0.59270,  0.68043,  0.7377731
1.59491,  1.71494,  0.0308229
1.40352,  0.94120,  0.3578121
1.15444,  1.13704,  0.1618722
1.36930,  1.35288,  0.0347808
1.01567,  1.06189,  0.2279879
0.99916,  1.25843,  0.0789271
1.23250,  1.39392,  0.0211558
0.86393,  0.74372,  0.6330670
1.15570,  1.23147,  0.0948006
1.30596,  1.35978,  0.0322518
0.82949,  0.86507,  0.4546831
0.83107,  0.74971,  0.6235661
0.91578,  0.85353,  0.4703778
1.53799,  1.74589,  0.0426513

Preds:    Y:


Preds:    Y:        Loss:
1.26206,  1.33071,  0.0047124
1.14737,  1.12635,  0.0184170
0.68701,  0.75267,  0.2594827
0.75524,  0.84686,  0.1723943
1.18628,  1.10832,  0.0236354
1.06464,  1.14302,  0.0141701
0.76244,  0.56650,  0.4838004
0.95835,  0.92896,  0.1109542
1.58848,  1.66000,  0.1583537
0.60853,  0.55103,  0.5055587
1.12703,  1.23454,  0.0007572
0.11653,  0.18424,  1.1616883
1.17796,  0.94863,  0.0982386
1.01305,  0.95646,  0.0933938
1.14262,  1.14707,  0.0132234
0.50268,  0.86649,  0.1564736
1.35276,  1.28872,  0.0007105
0.39252,  0.21370,  1.0990542
1.14099,  0.95088,  0.0968324
1.46242,  1.34745,  0.0072904
1.07701,  1.33740,  0.0056756
1.08792,  1.00321,  0.0670036
1.65958,  1.77159,  0.2596188
0.27306,  0.25151,  1.0212166
0.95867,  0.93896,  0.1043951
1.34801,  1.47400,  0.0449160
1.03829,  1.31113,  0.0024082
0.20901,  0.23555,  1.0537238
0.69265,  0.46758,  0.6312056
1.45969,  1.48387,  0.0491986
1.25837,  1.26105,  0.0000010
1.19218,  1.29844,  0.0013237

Preds:    Y:


Preds:    Y:        Loss:
1.09772,  1.15200,  0.0029463
0.80989,  0.85037,  0.0611811
0.90733,  1.03324,  0.0041584
0.67734,  0.65781,  0.1935239
1.07070,  1.23876,  0.0198902
0.47837,  0.58208,  0.2658860
1.45191,  1.50282,  0.1641016
1.12643,  1.01070,  0.0075724
0.73031,  0.71650,  0.1453308
1.70363,  1.67741,  0.3360426
1.10663,  0.86807,  0.0527417
1.20920,  1.24683,  0.0222318
1.47415,  1.79320,  0.4836841
1.43651,  1.51716,  0.1759266
1.43240,  1.40336,  0.0934163
1.46225,  1.62432,  0.2773059
0.84663,  1.11054,  0.0001642
0.86806,  0.85937,  0.0568135
1.27807,  1.41047,  0.0978131
1.33634,  1.30170,  0.0416053
1.15539,  1.18639,  0.0078613
0.89689,  1.19613,  0.0096841
1.07810,  1.12426,  0.0007042
0.87775,  0.92887,  0.0285110
1.12881,  1.05186,  0.0021033
1.34861,  1.54524,  0.2002703
0.80518,  0.67000,  0.1829451
1.19244,  1.40355,  0.0935274
0.29435,  0.07490,  1.0461602
1.13686,  0.98490,  0.0127279
1.37591,  1.26802,  0.0290026
1.20645,  1.20356,  0.0112009

Preds:    Y: