In [33]:
import torch
import json
from pathlib import Path
from task.taskManager import get_model

def load_checkpoint(checkpoint_path):
    checkpoint_path = Path(checkpoint_path)
    assert checkpoint_path.exists()
    checkpoint = torch.load(checkpoint_path)
    task = get_model(checkpoint["arguments"])
    task.net.load_state_dict(checkpoint["model"])
    loss_fn = lambda model, x, y: task.__class__.loss_fn(model, x, y, task.criterion) # we hide the criterion
    forward_fn = task.__class__.forward_fn
    return task.net, forward_fn, loss_fn, task.dataloader_fn, checkpoint["history"]

def parse_checkpoint_names(path):
    file_name = str(path).split("/")[-1]
    info = file_name.split("-")
    task = info[0]
    model = info[1]
    batch = info[3]
    return task, model, batch

def most_trained(task, model):
    #
    # Returns the longest trained checkpoint of a task-model combination
    #
    # Example:
    #     cp_path_addTask_ChronoLSTM = most_trained("addTask", "ChronoLSTM")
    #     model, forward_fn, loss_fn, dataloader_fn, history  = load_checkpoint(cp_path_addTask_ChronoLSTM)
    #
    checkpoint_paths = Path("./saves/").glob("*.pth")
    goods = []
    for path in checkpoint_paths:
        cp_task, cp_model, cp_batch = parse_checkpoint_names(path)
        goods.append({"path": path, "batch": cp_batch})
    return max(goods, key=lambda x: x["batch"])["path"]
        

In [40]:
checkpoint_path = most_trained("addTask", "ChronoLSTM")
print(checkpoint_path)
model, forward_fn, loss_fn, dataloader_fn, history = load_checkpoint(checkpoint_path)

saves/addTask-ChronoLSTM-batch-60000-seed-0.pth


In [41]:
for batch_num, x, y in dataloader_fn():
    predictions = forward_fn(model, x)
    losses = loss_fn(model, x, y)
    #print(f"x: {x}")
    print()
    print(f"Y:            {y.view(-1).cpu().data.numpy()}")
    print(f"Predictions:  {predictions.view(-1).cpu().data.numpy()}")


Y:            [0.81168556 0.547483  ]
Predictions:  [0.8640177 0.5620277]

Y:            [0.9946759  0.81528246]
Predictions:  [0.9399312 0.8722621]

Y:            [1.7030679 1.3512139]
Predictions:  [1.6815732 1.2965504]

Y:            [0.9306775 1.444541 ]
Predictions:  [0.9997676 1.3880951]

Y:            [0.28777772 0.25945204]
Predictions:  [0.3833443  0.31367028]

Y:            [1.6004289 1.3542571]
Predictions:  [1.5451508 1.3008628]

Y:            [1.3774774 1.6768824]
Predictions:  [1.3924093 1.520457 ]

Y:            [0.16015033 1.4540488 ]
Predictions:  [0.27397615 1.369669  ]

Y:            [0.73928267 0.9578585 ]
Predictions:  [0.7489887 0.9045748]

Y:            [0.45971486 0.2208832 ]
Predictions:  [0.5204885 0.2919354]

Y:            [0.7204699  0.60071003]
Predictions:  [0.729144 0.658039]

Y:            [0.9542944 1.1296391]
Predictions:  [0.9991817 1.098474 ]

Y:            [1.4836065 0.8208914]
Predictions:  [1.3996408  0.77440447]

Y:            [1.1635845 1.43922


Y:            [1.748302  0.9598317]
Predictions:  [1.6137335 1.0153627]

Y:            [0.69827867 0.40094534]
Predictions:  [0.65787476 0.4382695 ]

Y:            [1.0775936 1.4819546]
Predictions:  [1.0592914 1.4988277]

Y:            [0.6820108  0.86086875]
Predictions:  [0.7083596 0.8724025]

Y:            [0.8453035 1.1343548]
Predictions:  [0.8626999 1.1422243]

Y:            [0.94631803 1.8479863 ]
Predictions:  [0.9631504 1.6945374]

Y:            [0.94866717 1.2718705 ]
Predictions:  [0.9206171 1.1864891]

Y:            [0.4939325  0.85943204]
Predictions:  [0.51364946 0.8736749 ]

Y:            [1.0957292 1.5620329]
Predictions:  [1.047328  1.5941188]

Y:            [0.6677184  0.17191866]
Predictions:  [0.6642431  0.23139338]

Y:            [0.9049735 1.386186 ]
Predictions:  [0.89462477 1.3210027 ]

Y:            [0.90534574 1.8196634 ]
Predictions:  [0.91591185 1.7946401 ]

Y:            [0.56050116 1.3200877 ]
Predictions:  [0.4998479 1.2744536]

Y:            [0.6291891


Y:            [0.6581805  0.60404456]
Predictions:  [0.7999711 0.6726724]

Y:            [0.4655352  0.46627942]
Predictions:  [0.6238389  0.43017018]

Y:            [0.77877855 0.74200636]
Predictions:  [0.8351013  0.78572994]

Y:            [1.5886368 1.0025833]
Predictions:  [1.4477172 1.0106404]

Y:            [0.6852228 1.1401433]
Predictions:  [0.7294406 1.1523993]

Y:            [0.27575627 0.82344   ]
Predictions:  [0.39662403 0.94184375]

Y:            [0.93967843 0.228791  ]
Predictions:  [0.91378886 0.2601518 ]

Y:            [0.7701424 0.8768632]
Predictions:  [0.84731203 0.88780385]

Y:            [0.4999497 0.8240995]
Predictions:  [0.5549638  0.86905795]

Y:            [0.947218   0.95005894]
Predictions:  [0.9761067 0.9157557]

Y:            [0.9879942 0.672598 ]
Predictions:  [0.94980067 0.70027995]

Y:            [1.2188654  0.97089624]
Predictions:  [1.158824 0.977863]

Y:            [1.1132668 0.8242056]
Predictions:  [1.1206331 0.8384102]

Y:            [1.9294561

Predictions:  [1.4015157  0.53807294]

Y:            [1.0823047  0.90040576]
Predictions:  [1.0859964  0.85421497]

Y:            [0.34335476 0.9337505 ]
Predictions:  [0.3577638  0.90090096]

Y:            [1.8427539 0.8858634]
Predictions:  [1.781182  0.9004579]

Y:            [1.5638359  0.70963985]
Predictions:  [1.4670072  0.71037483]

Y:            [0.9796376  0.56652445]
Predictions:  [0.97229815 0.6345161 ]

Y:            [0.71187854 0.5869602 ]
Predictions:  [0.75563264 0.6452913 ]

Y:            [0.63180524 1.3153816 ]
Predictions:  [0.6325698 1.2068553]

Y:            [0.7908208  0.31646562]
Predictions:  [0.83078843 0.36737218]

Y:            [0.15697807 0.9433456 ]
Predictions:  [0.29115224 0.9963805 ]

Y:            [1.422156 0.82736 ]
Predictions:  [1.4816346  0.82747436]

Y:            [1.0837444 0.7973498]
Predictions:  [1.125202   0.83038116]

Y:            [1.1224663  0.88514215]
Predictions:  [1.1132257 0.9235254]

Y:            [1.3853412 1.4708914]
Predictions:  [


Y:            [1.5874462 1.4403784]
Predictions:  [1.5448711 1.4185545]

Y:            [1.2738975  0.18335475]
Predictions:  [1.268671   0.25494874]

Y:            [0.81462437 1.3072243 ]
Predictions:  [0.78565615 1.2679837 ]

Y:            [1.1620119 1.096568 ]
Predictions:  [1.1524515 1.0946078]

Y:            [1.5967724  0.34422827]
Predictions:  [1.5890398  0.38197213]

Y:            [1.2471172 1.2434382]
Predictions:  [1.1536887 1.1969266]

Y:            [1.6547294 1.6321143]
Predictions:  [1.5496135 1.5451767]

Y:            [1.0684456  0.29429293]
Predictions:  [1.059481   0.36136302]

Y:            [0.37337083 1.5306991 ]
Predictions:  [0.4578168 1.4688029]

Y:            [1.7848141  0.48749214]
Predictions:  [1.6978464  0.49483863]

Y:            [0.5672045 1.4778779]
Predictions:  [0.6562002 1.4280455]

Y:            [0.23865142 1.8296065 ]
Predictions:  [0.25998312 1.7030325 ]

Y:            [0.63634795 0.5995668 ]
Predictions:  [0.7099659  0.61873287]

Y:            [0.606


Y:            [0.7201657 0.8505909]
Predictions:  [0.7469468 0.9049127]

Y:            [1.5159011 0.3136484]
Predictions:  [1.3993299  0.31754416]

Y:            [1.3616974  0.88606465]
Predictions:  [1.3465176  0.91133803]

Y:            [0.35673815 1.4941859 ]
Predictions:  [0.34809184 1.4035759 ]

Y:            [1.1021312 1.1524689]
Predictions:  [1.0403419 1.1588221]

Y:            [0.92846286 1.2889256 ]
Predictions:  [0.9621494 1.2611129]

Y:            [0.7562633 0.5478822]
Predictions:  [0.7914532  0.62092686]

Y:            [1.2419835  0.99280554]
Predictions:  [1.1490698 1.0767637]

Y:            [1.256169  0.8566402]
Predictions:  [1.2501551  0.84587735]

Y:            [1.046132   0.83444315]
Predictions:  [1.1132698 0.7927755]

Y:            [0.3042618 1.0994422]
Predictions:  [0.3869032 1.0725757]

Y:            [1.3283628  0.17425127]
Predictions:  [1.272867   0.32604098]

Y:            [0.3000543 1.1743248]
Predictions:  [0.4192038 1.1186266]

Y:            [1.3249025  


Y:            [1.113823 1.386246]
Predictions:  [1.0871124 1.2575786]

Y:            [0.5741361 0.2660618]
Predictions:  [0.6337961  0.26536652]

Y:            [0.8644691  0.32493845]
Predictions:  [0.90746355 0.4329701 ]

Y:            [0.8825757 1.226238 ]
Predictions:  [0.85170734 1.1608622 ]

Y:            [1.4343774  0.36060467]
Predictions:  [1.3806529 0.4372936]

Y:            [0.14188033 0.3888631 ]
Predictions:  [0.21706384 0.49724182]

Y:            [0.82583183 0.9250618 ]
Predictions:  [0.85071045 0.98252547]

Y:            [0.13572767 1.0688971 ]
Predictions:  [0.19320923 1.0725319 ]

Y:            [1.7418747 1.1207662]
Predictions:  [1.7227705 1.1097677]

Y:            [1.2781286  0.77344054]
Predictions:  [1.1938701 0.7997574]

Y:            [1.1170399 1.0649097]
Predictions:  [1.0351343 1.0470302]

Y:            [0.9137328  0.91391003]
Predictions:  [0.83302546 0.97603655]

Y:            [0.8097659 1.6021038]
Predictions:  [0.86969453 1.5189757 ]

Y:            [0.49272


Y:            [1.679765  0.8628533]
Predictions:  [1.5515189 0.8597209]

Y:            [0.94619304 1.0900524 ]
Predictions:  [0.98735505 1.0087833 ]

Y:            [0.9496303 0.9924447]
Predictions:  [0.9570146  0.97104484]

Y:            [1.0672787 0.8647638]
Predictions:  [1.0440631 0.8893974]

Y:            [1.1330379 1.2628263]
Predictions:  [1.1542492 1.2569821]

Y:            [1.0801108 1.2142978]
Predictions:  [1.1071761 1.2485864]

Y:            [1.4907653  0.90897125]
Predictions:  [1.3928196  0.96816486]

Y:            [1.2355367 1.5443   ]
Predictions:  [1.1813109 1.4963484]

Y:            [1.4675394 1.0940331]
Predictions:  [1.4296718 1.0924027]

Y:            [1.4049181 0.3704085]
Predictions:  [1.3202295  0.43593866]

Y:            [1.2563621  0.98272806]
Predictions:  [1.2136877 0.9800257]

Y:            [1.2820652 0.3825112]
Predictions:  [1.307492  0.4154822]

Y:            [0.73002815 1.401606  ]
Predictions:  [0.75894517 1.4222558 ]

Y:            [0.44347355 1.7053

Y:            [1.6050141  0.15903817]
Predictions:  [1.4875965  0.29588488]

Y:            [0.9028913 0.7432719]
Predictions:  [0.9152362 0.8879133]

Y:            [1.2401756 0.6235872]
Predictions:  [1.3312855 0.7460055]

Y:            [0.6584592  0.87057287]
Predictions:  [0.71825796 0.88110757]

Y:            [0.99461746 1.1225435 ]
Predictions:  [1.0050278 1.1035104]

Y:            [1.0240729  0.53500026]
Predictions:  [0.97797143 0.53861666]

Y:            [1.1415583 0.5404066]
Predictions:  [1.1717932 0.5614062]

Y:            [1.4795754  0.85986865]
Predictions:  [1.3040876  0.90493655]

Y:            [0.6145637  0.78094304]
Predictions:  [0.55185825 0.77244294]

Y:            [0.81927216 1.0218409 ]
Predictions:  [0.8316924 1.0410948]

Y:            [1.0735204 1.0918229]
Predictions:  [1.0833112 1.0390238]

Y:            [0.91493946 1.8949314 ]
Predictions:  [1.020517  1.7227967]

Y:            [0.29172084 1.525324  ]
Predictions:  [0.2957381 1.4137654]

Y:            [0.744510