In [35]:
import sys
root_path = '../'
sys.path.insert(0, root_path)

import lightning as L
from natsort import natsorted

from run import *
from utils import *

In [None]:
# Reload all imported modules (e.g. the local run/utils modules) before each cell
# executes, so edits to the project code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [37]:
# Training run to inspect; its Lightning checkpoints live under
# ../../results/DPT_2/<run_name>/checkpoints.
run_name = "strategy21_exp3"
root_dir = os.path.join("../..", "results", "DPT_2", run_name, "checkpoints")

# A fixed epoch is loaded here. Alternative kept from earlier experiments:
# checkpoint_file = get_best_checkpoint(root_dir, 'epoch')
ckpt_name = "epoch=999.ckpt"
checkpoint_file = os.path.join(root_dir, ckpt_name)

model = DPTSolver.load_from_checkpoint(checkpoint_file)

In [None]:
# NOTE: plain assignment does not copy, so `config` is the model's own config
# object and the n_problems override below is also visible via `model.config`.
config = model.config
config['n_problems'] = 100
dl = get_dataloaders(config)

# Unwrap the datasets behind the loaders: index 0 of val_dataloaders is used as
# the offline split and index 1 as the online split.
train_offline_dataset = dl["train_dataloaders"].dataset
val_loaders = dl["val_dataloaders"]
val_offline_dataset, val_online_dataset = (val_loaders[idx].dataset for idx in (0, 1))

In [None]:
# Evaluate the loaded model on the online validation loader once per decoding
# setting and collect what the model exposes in `save_results` (tensors -> lists).
tester = L.Trainer(logger=False, precision=config["precision"])
test_dataloader = dl["val_dataloaders"][1]

hparams = [
    {"do_sample": False, "temperature": 0.0},
    # {"do_sample": True, "temperature": 1.0},
]

results_list = []
for hparam in hparams:
    # Overrides are written into model.config in place (they persist after the loop);
    # presumably the model reads them during test — confirm in DPTSolver.
    for option in ("do_sample", "temperature"):
        model.config[option] = hparam[option]
    tester.test(model=model, dataloaders=test_dataloader)  # , verbose=False)
    results = {name: tensor.cpu().tolist() for name, tensor in model.save_results.items()}
    label = f't={hparam["temperature"]:.1f}'
    results_list.append({"label": label, **results})

# To persist the results of the last setting:
# save_path = "../../results/online_inference_last_epoch_t=1"
# with open(f"{save_path}.json", "w") as f:
#     json.dump(results, f)

In [None]:
from matplotlib import colormaps as cm

# One colour per decoding setting, sampled from 'jet' in reversed order.
cmap = cm.get_cmap('jet')
color_list = [cmap(c) for c in np.linspace(0.1, 0.9, len(results_list))[::-1]]

fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# Left panel: x error per step; right panel: y error per step.
panels = (("x_mae", 'MAE (x, x*)'), ("y_mae", 'MAE (y, y*)'))
for ax, (metric, title) in zip(axes, panels):
    ax.set_title(title)
    for results, c in zip(results_list, color_list):
        ax.plot(results[metric], c=c, label=results["label"])
    ax.set_xlim(0, None)
    ax.set_xlabel('Step')
    ax.legend(loc=1)

axes[1].set_yscale('log')

plt.tight_layout()
plt.show()

In [None]:
# Single-sample sanity check on the first offline training example.
train_sample = train_offline_dataset[0]
sample, outputs, predictions, metrics = run(model, train_sample)
print_sample(sample, predictions)
print_metrics(metrics)

In [None]:
# Greedy decoding (no sampling, temperature 0) on the first online validation example.
for option, value in (("do_sample", False), ("temperature", 0.0)):
    model.config[option] = value

online_sample = val_online_dataset[0]
sample, outputs, predictions, metrics = run(model, online_sample)
# NOTE(review): unlike the training-sample cell, `predictions` is not passed to
# print_sample here — confirm this is intended.
print_sample(sample, print_ta=False, print_fm=True)
print_metrics(metrics)

In [62]:
# # plt.plot(outputs[0].detach().numpy())
# # plt.plot(outputs[1].detach().numpy())
# # plt.plot(outputs[2].detach().numpy())
# from torch.nn.functional import softmax

# p = outputs[-1].detach()
# for t in (1.0, 2.0, 3.0, 5.0, 10.0):
#     plt.plot(softmax(p / t, -1).numpy(), label=t)
# plt.legend()
# plt.show()

### Online Inference per Epoch

In [None]:
# Re-run the online evaluation for every saved checkpoint of this run so the
# test metrics can be tracked over training.
tester = L.Trainer(
    logger=False,
    precision=config["precision"]
)
test_dataloader = dl["val_dataloaders"][1]

# Only Lightning checkpoint files are evaluated: any other entry in the directory
# would make load_from_checkpoint fail. natsorted orders numbered names
# numerically (e.g. "epoch=99" before "epoch=100").
checkpoints = natsorted(
    name for name in os.listdir(root_dir) if name.endswith(".ckpt")
)
results_accumulated = defaultdict(list)

for checkpoint in checkpoints:
    # Separate names so the `model` / `checkpoint_file` loaded at the top of the
    # notebook are not silently replaced by the last checkpoint of this loop.
    # NOTE(review): these fresh instances do not receive the do_sample/temperature
    # overrides applied to `model` earlier; confirm the config stored in each
    # checkpoint has the intended decoding settings (the save path below says t=1).
    ckpt_model = DPTSolver.load_from_checkpoint(os.path.join(root_dir, checkpoint))
    # Trainer.test returns one metrics dict per dataloader; there is exactly one here.
    ckpt_metrics = tester.test(model=ckpt_model, dataloaders=test_dataloader)  # , verbose=False)
    for key, val in ckpt_metrics[0].items():
        results_accumulated[key].append(val)
results_accumulated = dict(results_accumulated)

In [None]:
# Persist the per-checkpoint metrics and plot each metric against training epoch.
save_path = "../../results/online_inference_t=1"
# The i-th checkpoint (1-based) is plotted at epoch i * CKPT_EVERY_N_EPOCHS.
# TODO confirm this matches the checkpointing interval used for this run.
CKPT_EVERY_N_EPOCHS = 50

with open(f"{save_path}.json", "w") as f:
    json.dump(results_accumulated, f)

# squeeze=False keeps `axes` 2-D even when only one metric was logged. Without it
# plt.subplots returns a bare Axes, and zip() over it raises TypeError.
fig, axes = plt.subplots(1, len(results_accumulated), figsize=(12, 4), squeeze=False)

for ax, (key, vals) in zip(axes[0], results_accumulated.items()):
    ax.set_title(key)
    ax.plot(np.arange(1, len(vals) + 1) * CKPT_EVERY_N_EPOCHS, vals)
    ax.set_xlim(0, None)
    ax.set_xlabel('Epoch')

plt.tight_layout()
plt.savefig(f"{save_path}.png")

In [None]:
from matplotlib import colormaps as cm

# Overlay the per-checkpoint metric curves of several saved inference runs
# (JSON files of {metric: [value per checkpoint]}) to compare decoding strategies.
rpath_list = (
    "../../results/online_inference_argmax",
    "../../results/online_inference_t=1"
)
# Dedicated output file: saving to the previous cell's `save_path` would
# overwrite that cell's per-metric figure with this comparison plot.
compare_save_path = "../../results/online_inference_comparison"

cmap = cm.get_cmap('jet')
color_list = [cmap(c) for c in np.linspace(0.1, 0.9, len(rpath_list))[::-1]]

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
# NOTE(review): these titles assume each JSON stores the x-MAE metric first and
# the y-MAE metric second; zip() below silently drops any further metrics.
axes[0].set_title('MAE (x, x*)')
axes[1].set_title('MAE (y, y*)')

for read_path, c in zip(rpath_list, color_list):
    with open(f"{read_path}.json") as f:
        results = json.load(f)

    # Legend label is the suffix after the last '_' (e.g. "argmax", "t=1").
    label = read_path.split('_')[-1]
    for ax, vals in zip(axes, results.values()):
        # Assumes one checkpoint every 50 epochs, starting at epoch 50.
        epoch_list = np.arange(1, len(vals) + 1) * 50
        ax.plot(epoch_list, vals, c=c, label=label)

# Format axes only after every curve is drawn. set_xlim() turns off x
# autoscaling, so calling it inside the loop would clip a later, longer run.
for ax in axes:
    ax.set_xlim(0, None)
    ax.set_xlabel('Epoch')
    ax.legend(loc=1)

axes[1].set_yscale('log')

plt.tight_layout()
plt.savefig(f"{compare_save_path}.png")