In [1]:
import sys
root_path = '../'
sys.path.insert(0, root_path)

import lightning as L
from natsort import natsorted

from run import *
from utils import *

In [2]:
# Reload edited project modules (such as `run` and `utils`, star-imported above) before each cell runs.
%load_ext autoreload
%autoreload 2

Load the latest checkpoint of the run

In [3]:
import os

# Pick the run to inspect and load its most recent checkpoint onto the CPU.
run_name = "5zfoags0"
root_dir = os.path.join("../../results", "DPT_3", run_name, "checkpoints")

# Only consider checkpoint files (stray files in the directory must not be picked up).
# Natural sort puts "epoch=10" after "epoch=9".
# NOTE(review): if Lightning's "last.ckpt" is present, it sorts after "epoch=..." names
# and is the one chosen — confirm that is intended.
checkpoints = natsorted(name for name in os.listdir(root_dir) if name.endswith(".ckpt"))
if not checkpoints:
    raise FileNotFoundError(f"No .ckpt files found in {root_dir!r}")
checkpoint = checkpoints[-1]
checkpoint_file = os.path.join(root_dir, checkpoint)

# map_location="cpu" deserializes the tensors directly onto the CPU rather than their saved device.
model = DPTSolver.load_from_checkpoint(checkpoint_file, map_location="cpu").cpu()

Load the offline datasets used for training and validation, as well as the online dataset built from the same problems as the validation set. The number of problems is reduced to 100 for a fast check.

In [None]:
# Start from the config stored with the model and shrink the problem count for a fast check.
# NOTE(review): `config` is the same object as `model.config` (no copy), so this override
# is also visible to the model — confirm that is intended.
config = model.config
config["n_problems"] = 100
dl = get_dataloaders(config)

# Offline datasets for train / validation, plus the online dataset on the validation problems.
train_offline_dataset = dl["train_dataloaders"].dataset
val_offline_dataset, val_online_dataset = (
    dl["val_dataloaders"][0].dataset,
    dl["val_dataloaders"][1].dataset,
)

In [5]:
def online_inference(model, hparams=None, dataloader=None, precision=None, trainer=None):
    """Evaluate the model in online mode under several decoding strategies.

    For each strategy, ``model.config["do_sample"]`` and ``model.config["temperature"]``
    are set and ``trainer.test`` is run on the online dataloader. The reported metrics
    are then collected. The model's previous ``do_sample`` / ``temperature`` values are
    restored on exit (also on error), so later cells are not affected by the last strategy.

    Args:
        model: the LightningModule to evaluate. It must expose a mutable ``config`` mapping
            and a ``save_results`` dict with ``"x_mae"`` / ``"y_mae"`` tensors, which is
            presumably filled in by the model during ``trainer.test``.
        hparams: list of ``{"do_sample": bool, "temperature": float}`` dicts. Defaults to
            argmax decoding and sampling with temperature 1.
        dataloader: dataloader to test on. Defaults to the notebook's online validation
            dataloader ``dl["val_dataloaders"][1]``.
        precision: Trainer precision when a trainer is built here. Defaults to
            ``model.config["precision"]``.
        trainer: an existing trainer to reuse. By default an ``L.Trainer`` without a
            logger is created.

    Returns:
        One dict per strategy, containing:
        - the strategy's hparams;
        - the scalar ``test x_mae`` / ``test y_mae`` metrics returned by ``trainer.test``
          (under the "best" keys);
        - the per-step MAE lists from ``model.save_results`` (under the "all" keys).
    """
    if hparams is None:
        # Two online-inference strategies:
        # - the action is the argmax of the predicted distribution;
        # - the action is sampled from the predicted distribution with temperature 1.
        hparams = [
            {"do_sample": False, "temperature": 0.0},
            {"do_sample": True, "temperature": 1.0},
        ]
    if dataloader is None:
        dataloader = dl["val_dataloaders"][1]  # online validation dataloader (notebook global)
    if trainer is None:
        if precision is None:
            precision = model.config["precision"]
        trainer = L.Trainer(logger=False, precision=precision)

    # Remember the caller's decoding settings so they can be restored afterwards.
    missing = object()
    saved_settings = {
        key: (model.config[key] if key in model.config else missing)
        for key in ("do_sample", "temperature")
    }

    results_list = []
    try:
        for hparam in hparams:
            model.config["do_sample"] = hparam["do_sample"]
            model.config["temperature"] = hparam["temperature"]
            test_metrics = trainer.test(model=model, dataloaders=dataloader)[0]
            best_results = {
                "MAE(best x, x*)": test_metrics["test x_mae"],
                "MAE(best y, y*)": test_metrics["test y_mae"],
            }
            saved_results = model.save_results
            all_results = {
                "MAE(all x, x*)": saved_results["x_mae"].cpu().tolist(),
                "MAE(all y, y*)": saved_results["y_mae"].cpu().tolist(),
            }
            results_list.append(hparam | best_results | all_results)
    finally:
        for key, value in saved_settings.items():
            if value is missing:
                model.config.pop(key, None)
            else:
                model.config[key] = value
    return results_list

In [None]:
# Compare argmax decoding with temperature-1 sampling on the online validation problems.
results_list = online_inference(model=model)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colormaps

# Per-step MAE curves of every online-inference strategy in `results_list`.
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

axes[0].set_title('MAE (x, x*)')
axes[1].set_title('MAE (y, y*)')

# One colour per strategy, taken from the inner part of the 'jet' colormap (reversed order).
cmap = colormaps['jet']
color_list = [cmap(c) for c in np.linspace(0.1, 0.9, len(results_list))[::-1]]
for results, color in zip(results_list, color_list):
    label = f't = {results["temperature"]}' if results["do_sample"] else 'argmax'
    axes[0].plot(results["MAE(all x, x*)"], c=color)
    axes[1].plot(results["MAE(all y, y*)"], c=color, label=label)

for ax in axes:
    ax.set_xlim(0, None)
    ax.set_xlabel('Step')
    ax.set_ylabel('MAE')
axes[1].set_yscale('log')
axes[1].legend(loc='upper right')

plt.tight_layout()
plt.show()

An example of offline mode for a problem from the train dataset.

In [None]:
# Offline mode on the first problem of the train dataset.
sample, outputs, predictions, metrics = run(
    model,
    train_offline_dataset[0],
)
print_sample(sample, predictions)
print_metrics(metrics)

An example of online mode for a problem from the validation dataset.

In [None]:
# Choose the decoding strategy for the online example:
# - use_sampling = False: argmax of the predicted distribution (temperature 0.0);
# - use_sampling = True: sample from the predicted distribution with temperature 1.0.
use_sampling = True
model.config["do_sample"] = use_sampling
model.config["temperature"] = 1.0 if use_sampling else 0.0

# Sampling is stochastic: seed the RNGs so the example output is reproducible.
L.seed_everything(0)

sample, outputs, predictions, metrics = run(model, val_online_dataset[0])
print_sample(sample, print_ta=False, print_fm=True)
print_metrics(metrics)