In [7]:
import torch
import json
from pathlib import Path
from task.taskManager import get_model

def load_checkpoint(checkpoint_path, map_location=None):
    """Load a saved checkpoint and rebuild the network plus its helper functions.

    The checkpoint is expected to be a dict with the keys ``"arguments"``
    (passed to ``get_model`` to rebuild the task), ``"model"`` (the network's
    state dict) and ``"history"`` (returned unchanged).

    Args:
        checkpoint_path: ``str`` or ``Path`` to the ``.pth`` checkpoint file.
        map_location: Forwarded to ``torch.load``; e.g. ``"cpu"`` to load a
            checkpoint saved on GPU onto a CPU-only machine. ``None`` keeps
            ``torch.load``'s default behaviour.

    Returns:
        Tuple ``(net, forward_fn, loss_fn, dataloader_fn, history)``:
            net: ``task.net`` with the checkpoint's weights loaded.
            forward_fn: the task class's ``forward_fn`` taken from the class,
                so it is called with the model explicitly: ``forward_fn(model, x)``.
            loss_fn: ``loss_fn(model, x, y)``; wraps the task class's ``loss_fn``
                and supplies ``task.criterion`` so callers don't need it.
            dataloader_fn: ``task.dataloader_fn``.
            history: ``checkpoint["history"]`` as stored.

    Raises:
        FileNotFoundError: if ``checkpoint_path`` does not exist.
    """
    checkpoint_path = Path(checkpoint_path)
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if not checkpoint_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    # NOTE: torch.load unpickles the file — only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    task = get_model(checkpoint["arguments"])
    task.net.load_state_dict(checkpoint["model"])

    def loss_fn(model, x, y):
        # Hide the criterion: callers only pass (model, x, y).
        return task.__class__.loss_fn(model, x, y, task.criterion)

    forward_fn = task.__class__.forward_fn
    return task.net, forward_fn, loss_fn, task.dataloader_fn, checkpoint["history"]

def parse_checkpoint_names(path):
    """Extract ``(task, model, batch)`` from a checkpoint file name.

    File names are split on ``"-"``; field 0 is the task, field 1 the model and
    field 3 the batch count, e.g.
    ``addTask-ChronoLSTM-batch-620000-seed-0.pth`` -> ``("addTask", "ChronoLSTM", 620000)``.

    Args:
        path: ``str`` or ``Path`` to the checkpoint; only the final component is used.

    Returns:
        Tuple ``(task: str, model: str, batch: int)``.
    """
    # Path(...).name is separator-aware; splitting str(path) on "/" breaks on
    # Windows, where str(Path) uses backslashes.
    info = Path(path).name.split("-")
    task = info[0]
    model = info[1]
    batch = int(info[3])
    return task, model, batch

def most_trained(task, model, save_dir="./saves/"):
    """Return the path of the longest-trained checkpoint of a task-model combination.

    Scans ``save_dir`` for ``*.pth`` files, keeps only those whose parsed task
    and model names equal ``task`` and ``model``, and returns the one with the
    largest batch number.

    Example:
        cp_path_addTask_ChronoLSTM = most_trained("addTask", "ChronoLSTM")
        model, forward_fn, loss_fn, dataloader_fn, history = load_checkpoint(cp_path_addTask_ChronoLSTM)

    Args:
        task: Task name to match (first "-"-separated field of the file name).
        model: Model name to match (second field of the file name).
        save_dir: Directory to search. Defaults to ``"./saves/"``.

    Returns:
        ``pathlib.Path`` of the matching checkpoint with the highest batch count.

    Raises:
        ValueError: if no checkpoint in ``save_dir`` matches ``task`` and ``model``.
    """
    candidates = []
    for path in Path(save_dir).glob("*.pth"):
        cp_task, cp_model, cp_batch = parse_checkpoint_names(path)
        # Only consider checkpoints of the requested task/model combination.
        if cp_task == task and cp_model == model:
            candidates.append({"path": path, "batch": cp_batch})
    if not candidates:
        raise ValueError(
            f"No checkpoint for task={task!r}, model={model!r} found in {str(save_dir)!r}"
        )
    return max(candidates, key=lambda c: c["batch"])["path"]
        

In [8]:
# Restore the longest-trained checkpoint for this task/model pair.
TASK_NAME = "addTask"
MODEL_NAME = "ChronoLSTM"

checkpoint_path = most_trained(TASK_NAME, MODEL_NAME)
print(checkpoint_path)
model, forward_fn, loss_fn, dataloader_fn, history = load_checkpoint(checkpoint_path)

saves/addTask-ChronoLSTM-batch-620000-seed-0.pth


In [9]:
# Compare targets and predictions on the first batches of the data loader.
MAX_PREVIEW_BATCH = 50  # stop once batch_num exceeds this value

# Inference only: eval() switches off training-only layers (e.g. dropout, if the
# model has any) and no_grad() avoids building an autograd graph per batch.
model.eval()
with torch.no_grad():
    for batch_num, x, y in dataloader_fn():
        if batch_num > MAX_PREVIEW_BATCH:
            break
        predictions = forward_fn(model, x)
        # Kept defined for later inspection; after the loop it holds the last batch's loss.
        losses = loss_fn(model, x, y)
        print()
        # reshape (unlike view) also works on non-contiguous tensors; detach replaces legacy .data
        print(f"Y:            {y.reshape(-1).detach().cpu().numpy()}")
        print(f"Predictions:  {predictions.reshape(-1).detach().cpu().numpy()}")


Y:            [1.1375896 1.4436742]
Predictions:  [1.1440258 1.4745741]

Y:            [1.4710426 1.5079137]
Predictions:  [1.4980526 1.5451844]

Y:            [1.0469986  0.17319919]
Predictions:  [1.0404396  0.15989706]

Y:            [1.4556488 1.0180446]
Predictions:  [1.449225  1.0389453]

Y:            [0.3183548  0.28209436]
Predictions:  [0.3035043  0.26866493]

Y:            [0.69686544 1.4724506 ]
Predictions:  [0.708089  1.5062692]

Y:            [1.838729   0.62837696]
Predictions:  [1.8346357  0.64419544]

Y:            [0.8539101 0.789074 ]
Predictions:  [0.88825804 0.7964368 ]

Y:            [1.4645908 1.6812541]
Predictions:  [1.4562454 1.6738045]

Y:            [0.88366294 0.71673423]
Predictions:  [0.91474396 0.7160205 ]

Y:            [1.389568   0.86979824]
Predictions:  [1.3881428  0.88351667]

Y:            [0.36550736 1.4884863 ]
Predictions:  [0.3529688 1.5172195]

Y:            [1.3669109 1.177854 ]
Predictions:  [1.3328431 1.2034781]

Y:            [0.8481369