In [7]:
import torch
import json
from pathlib import Path
from task.taskManager import get_model

def load_checkpoint(checkpoint_path, map_location=None):
    """Restore a trained network and its task helpers from a checkpoint file.

    The checkpoint is read with ``torch.load`` and indexed with the keys
    ``"arguments"``, ``"model"`` and ``"history"``. ``"arguments"`` is passed to
    ``get_model`` to rebuild the task object, and ``"model"`` is loaded into
    ``task.net`` as its state dict.

    Args:
        checkpoint_path: str or Path of the ``.pth`` file.
        map_location: forwarded to ``torch.load``. For example, pass ``"cpu"``
            to load a checkpoint saved on GPU on a CPU-only machine. ``None``
            (the default) keeps torch's default behavior.

    Returns:
        A tuple ``(net, forward_fn, loss_fn, dataloader_fn, history)``:
        ``net`` is ``task.net`` with the saved weights loaded;
        ``forward_fn`` is the task class's (unbound) ``forward_fn``;
        ``loss_fn(model, x, y)`` calls the task class's ``loss_fn`` with the
        task's criterion supplied automatically;
        ``dataloader_fn`` is ``task.dataloader_fn``;
        ``history`` is ``checkpoint["history"]``.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    checkpoint_path = Path(checkpoint_path)
    # Explicit check rather than `assert`, which is removed under `python -O`.
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code -- only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    task = get_model(checkpoint["arguments"])
    task.net.load_state_dict(checkpoint["model"])

    def loss_fn(model, x, y):
        # Hide the criterion so callers only need to pass (model, x, y).
        return task.__class__.loss_fn(model, x, y, task.criterion)

    forward_fn = task.__class__.forward_fn
    return task.net, forward_fn, loss_fn, task.dataloader_fn, checkpoint["history"]

def parse_checkpoint_names(path):
    """Split a checkpoint file name into its task, model and batch fields.

    File names are dash-separated, e.g. ``addTask-ChronoLSTM-batch-90000-seed-0.pth``.
    Field 0 is the task, field 1 is the model and field 3 is the batch count
    (in the example, field 2 is the literal label ``batch``).

    Args:
        path: str or Path of the checkpoint. Any directory part is ignored.

    Returns:
        ``(task, model, batch)`` as strings. Note that ``batch`` is NOT
        converted to an int.
    """
    # Path(...).name strips the directory with the platform's own separator;
    # splitting on "/" left the directory attached to the task on Windows.
    fields = Path(path).name.split("-")
    task, model, batch = fields[0], fields[1], fields[3]
    return task, model, batch

def most_trained(task, model, save_dir="./saves/"):
    """Return the path of the longest-trained checkpoint for a task/model pair.

    Scans ``save_dir`` for ``*.pth`` files and keeps those whose parsed task
    and model names match. Among them, it returns the one with the largest
    batch count.

    Example:
        cp_path_addTask_ChronoLSTM = most_trained("addTask", "ChronoLSTM")
        model, forward_fn, loss_fn, dataloader_fn, history = load_checkpoint(cp_path_addTask_ChronoLSTM)

    Args:
        task: task name as it appears in the checkpoint file name.
        model: model name as it appears in the checkpoint file name.
        save_dir: directory to search (default ``"./saves/"``).

    Returns:
        The ``Path`` of the matching checkpoint with the highest batch count.

    Raises:
        ValueError: if no checkpoint in ``save_dir`` matches ``task`` and ``model``.
    """
    candidates = []
    for path in Path(save_dir).glob("*.pth"):
        cp_task, cp_model, cp_batch = parse_checkpoint_names(path)
        # Only consider checkpoints of the requested task/model combination.
        if cp_task == task and cp_model == model:
            # Compare numerically: as strings, "90000" > "100000".
            candidates.append((int(cp_batch), path))
    if not candidates:
        raise ValueError(
            f"No checkpoint found for task={task!r}, model={model!r} in {save_dir!r}"
        )
    return max(candidates, key=lambda candidate: candidate[0])[1]
        

In [8]:
# Pick the addTask / ChronoLSTM checkpoint with the most training batches,
# show which file was chosen, then restore the network and its helper callables.
checkpoint_path = most_trained(task="addTask", model="ChronoLSTM")
print(checkpoint_path)
(model, forward_fn, loss_fn,
 dataloader_fn, history) = load_checkpoint(checkpoint_path=checkpoint_path)

saves/addTask-ChronoLSTM-batch-90000-seed-0.pth


In [10]:
# Sanity-check the restored model: print the flattened targets next to the
# model's predictions for the first batches yielded by the data loader.
MAX_PREVIEW_BATCH = 50  # stop once batch_num exceeds this value

with torch.no_grad():  # inference only: no autograd graph is needed
    for batch_num, x, y in dataloader_fn():
        if batch_num > MAX_PREVIEW_BATCH:
            break
        predictions = forward_fn(model, x)
        losses = loss_fn(model, x, y)
        print()
        print(f"Y:            {y.view(-1).detach().cpu().numpy()}")
        print(f"Predictions:  {predictions.view(-1).detach().cpu().numpy()}")


Y:            [1.3474481 0.9346001]
Predictions:  [1.3544083  0.92259246]

Y:            [0.6171282 0.6735193]
Predictions:  [0.62046313 0.66267836]

Y:            [0.5354442 0.6383799]
Predictions:  [0.59999704 0.5743224 ]

Y:            [1.0039248 1.4017282]
Predictions:  [1.023082  1.4211943]

Y:            [0.8889884  0.96116906]
Predictions:  [0.94343686 0.95160097]

Y:            [0.37711495 1.0181358 ]
Predictions:  [0.31571335 1.0841148 ]

Y:            [0.5413036 0.9625013]
Predictions:  [0.55871165 0.9713546 ]

Y:            [1.3180226 1.1504401]
Predictions:  [1.3316301 1.2197387]

Y:            [1.2272192 1.3758796]
Predictions:  [1.2984531 1.3937209]

Y:            [0.73102695 1.4589404 ]
Predictions:  [0.7446311 1.4822464]

Y:            [1.5101938 1.1105547]
Predictions:  [1.568196  1.1495852]

Y:            [1.1070709 1.0577906]
Predictions:  [1.1171157 1.016793 ]

Y:            [0.91498524 1.2982314 ]
Predictions:  [0.9292326 1.2975752]

Y:            [1.06321    0.88