In [13]:
import torch
import json
from pathlib import Path
from task.taskManager import get_model

def load_checkpoint(checkpoint_path):
    checkpoint_path = Path(checkpoint_path)
    assert checkpoint_path.exists()
    checkpoint = torch.load(checkpoint_path)
    task = get_model(checkpoint["arguments"])
    task.net.load_state_dict(checkpoint["model"])
    loss_fn = lambda model, x, y: task.__class__.loss_fn(model, x, y, task.criterion) # we hide the criterion
    forward_fn = task.__class__.forward_fn
    return task.net, forward_fn, loss_fn, task.dataloader_fn, checkpoint["history"]

def parse_checkpoint_names(path):
    file_name = str(path).split("/")[-1]
    info = file_name.split("-")
    task = info[0]
    model = info[1]
    batch = int(info[3])
    return task, model, batch

def most_trained(task, model):
    #
    # Returns the longest trained checkpoint of a task-model combination
    #
    # Example:
    #     cp_path_addTask_ChronoLSTM = most_trained("addTask", "ChronoLSTM")
    #     model, forward_fn, loss_fn, dataloader_fn, history  = load_checkpoint(cp_path_addTask_ChronoLSTM)
    #
    checkpoint_paths = Path("./saves/").glob("*.pth")
    goods = []
    for path in checkpoint_paths:
        cp_task, cp_model, cp_batch = parse_checkpoint_names(path)
        goods.append({"path": path, "batch": cp_batch})
    return max(goods, key=lambda x: x["batch"])["path"]

def view_addTask(predictions, y):
    # torch.nn.functional.mse_loss averages the results. we want the individual results here
    losses = (predictions - y)**2  
    predictions = predictions.contiguous().view(-1).cpu().data.numpy()
    y = y.view(-1).cpu().data.numpy()
    losses = losses.view(-1).cpu().data.numpy()
    print("\nPreds:    Y:        Loss:")
    for p, y_, l in zip(predictions, y, losses):
        print(f"{p:0.5f},  {y_:0.5f},  {l:0.7f}")

In [14]:
checkpoint_path = most_trained("addTask", "Rnn")
print(checkpoint_path)
model, forward_fn, loss_fn, dataloader_fn, history = load_checkpoint(checkpoint_path)

saves/addTask-Rnn-batch-88000-seed-0.pth


In [15]:
  
for batch_num, x, y in dataloader_fn():
    if batch_num > 50:
        break
        
    predictions = forward_fn(model, x)
    losses = loss_fn(model, x, y)
    view_addTask(predictions, y)



Preds:    Y:        Loss:
0.70235,  0.70628,  0.0000155
1.11477,  1.10806,  0.1646023
0.51889,  0.48237,  0.0483885
1.10922,  1.07944,  0.1421994
0.09781,  0.06979,  0.4001224
1.70905,  1.71826,  1.0320848
0.49287,  0.49644,  0.0423985
0.40201,  0.35986,  0.1172931
1.46382,  1.44979,  0.5586781
1.00966,  1.01133,  0.0954689
0.51942,  0.48768,  0.0460825
0.59250,  0.59909,  0.0106618
0.27569,  0.25915,  0.1964266
1.15306,  1.13480,  0.1870184
1.28151,  1.27803,  0.3314172
1.47207,  1.45921,  0.5728393
1.48524,  1.47321,  0.5942385
1.44888,  1.44893,  0.5573820
1.45615,  1.44483,  0.5512847
1.72959,  1.71498,  1.0254315
1.54891,  1.50530,  0.6447408
0.46954,  0.45883,  0.0593004
1.75877,  1.80187,  1.2089609
0.43543,  0.44571,  0.0658597
1.07838,  1.05600,  0.1250730
0.81990,  0.79688,  0.0089369
1.72513,  1.74116,  1.0791310
0.60801,  0.61209,  0.0081462
0.51576,  0.51096,  0.0366274
1.58882,  1.58317,  0.7758433
1.30803,  1.33161,  0.3959709
1.52397,  1.52765,  0.6811307

Preds:    Y:


Preds:    Y:        Loss:
1.35627,  1.35327,  0.0000090
0.67567,  0.65464,  0.4922829
1.50079,  1.48329,  0.0161337
1.26730,  1.28501,  0.0050776
1.20614,  1.22664,  0.0168025
1.21298,  1.22357,  0.0176089
0.67193,  0.65130,  0.4969800
0.77580,  0.76512,  0.3494508
1.61447,  1.59731,  0.0581023
0.70783,  0.69657,  0.4351994
0.46548,  0.48046,  0.7670369
0.88753,  0.89802,  0.2099958
0.87401,  0.86821,  0.2382054
1.44852,  1.49439,  0.0190769
1.01704,  0.99212,  0.1326017
0.53134,  0.52346,  0.6935691
0.10285,  0.12250,  1.5221914
1.07365,  1.04936,  0.0941941
1.27433,  1.25901,  0.0094599
0.47691,  0.46090,  0.8016850
0.74008,  0.72062,  0.4040501
1.14476,  1.14172,  0.0460305
1.07199,  1.07269,  0.0804177
1.35822,  1.34656,  0.0000943
1.37582,  1.36220,  0.0000352
1.84425,  1.87854,  0.2727719
1.11736,  1.10343,  0.0639286
1.15627,  1.13152,  0.0505103
1.57163,  1.56646,  0.0441825
1.70946,  1.70909,  0.1244854
0.96744,  0.95584,  0.1603452
1.11632,  1.11858,  0.0564957

Preds:    Y:


Preds:    Y:        Loss:
1.39889,  1.40591,  0.0000492
1.01520,  1.01766,  0.1453340
0.57120,  0.57981,  0.6708909
0.32790,  0.34910,  1.1020683
0.95256,  0.91512,  0.2340333
0.16212,  0.18013,  1.4853722
1.09177,  1.05997,  0.1148640
1.11189,  1.13682,  0.0686799
0.63182,  0.64623,  0.5664965
0.76622,  0.76238,  0.4051396
0.70015,  0.67064,  0.5303513
0.31450,  0.31754,  1.1693273
1.19648,  1.20471,  0.0377065
1.44859,  1.46610,  0.0045166
0.79073,  0.77181,  0.3932290
1.01072,  0.99176,  0.1657523
0.90854,  0.90496,  0.2439664
0.43702,  0.43674,  0.9257239
1.26815,  1.28765,  0.0123748
1.32697,  1.31419,  0.0071747
1.29894,  1.27847,  0.0145005
1.41711,  1.42063,  0.0004726
1.16229,  1.12736,  0.0737305
1.08105,  1.05531,  0.1180447
0.45259,  0.44288,  0.9139576
1.46990,  1.46047,  0.0037922
0.78938,  0.80098,  0.3574922
1.37106,  1.35878,  0.0016091
1.15451,  1.19661,  0.0409189
0.76461,  0.79827,  0.3607479
0.64923,  0.62125,  0.6047260
1.09216,  1.10509,  0.0863161

Preds:    Y:


Preds:    Y:        Loss:
0.69884,  0.71726,  0.0003391
0.88032,  0.87177,  0.0299054
1.24070,  1.21362,  0.2649974
1.47057,  1.48566,  0.6190807
0.47582,  0.47357,  0.0507476
0.92525,  0.94703,  0.0615962
0.42940,  0.42489,  0.0750526
1.24218,  1.21882,  0.2703715
1.01197,  0.98820,  0.0837273
1.81570,  1.77555,  1.1593072
0.97836,  0.99198,  0.0859308
0.97954,  0.99455,  0.0874446
1.28671,  1.27977,  0.3374772
0.92373,  0.89243,  0.0374771
1.04206,  1.05600,  0.1275599
1.39690,  1.39361,  0.4827053
1.22820,  1.25846,  0.3131683
1.46945,  1.43858,  0.5472171
1.77966,  1.77241,  1.1525497
1.31664,  1.33551,  0.4053423
1.02910,  1.02294,  0.1050380
1.33701,  1.37677,  0.4595849
0.98582,  1.00196,  0.0918779
0.66291,  0.67872,  0.0004051
0.23750,  0.24044,  0.2101323
1.59671,  1.60614,  0.8231890
1.09424,  1.07759,  0.1434485
1.30507,  1.30692,  0.3697563
0.82285,  0.82884,  0.0168986
0.61193,  0.62487,  0.0054724
0.73211,  0.71875,  0.0003964
0.95508,  0.93984,  0.0580816

Preds:    Y: