In [None]:
# Reload edited project modules (problems, train_dpt, utils, ...) before each cell
# executes, so changes to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Make the project root importable (problems, scripts, train_dpt, utils).
root_path = '../'
# Drop any earlier copy first so re-running this cell leaves exactly one entry,
# always at the front of the module search path.
while root_path in sys.path:
    sys.path.remove(root_path)
sys.path.insert(0, root_path)

In [None]:
import os
from collections import defaultdict
from functools import partial

# Set before importing torch/lightning so the restriction is in place before
# CUDA can be initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = "0"

import lightning as L
import numpy as np
import torch
from natsort import natsorted
from torch.utils.data import DataLoader

from problems import Problem
from scripts.create_problem import load_problem_set
from train_dpt import DPTSolver, custom_collate_fn, OnlineDataset
from utils import *

In [None]:
def get_checkpoint(run_name, results_dir="../results", experiment="DPT_3"):
    """Return the path of the latest checkpoint file for a training run.

    Files in ``<results_dir>/<experiment>/<run_name>/checkpoints`` are ordered
    with natural sorting (so ``epoch=10`` sorts after ``epoch=9``), and the last
    one in that order is returned.

    Args:
        run_name: Name of the run directory.
        results_dir: Root results directory; defaults to the original ``"../results"``.
        experiment: Experiment subdirectory; defaults to the original ``"DPT_3"``.

    Returns:
        Path to the last checkpoint file in natural sort order.

    Raises:
        FileNotFoundError: If the checkpoint directory is missing (raised by
            ``os.listdir``) or contains no entries.
    """
    checkpoint_dir = os.path.join(results_dir, experiment, run_name, "checkpoints")
    checkpoints = natsorted(os.listdir(checkpoint_dir))
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoints found in {checkpoint_dir!r}")
    return os.path.join(checkpoint_dir, checkpoints[-1])

In [None]:
def run_model(model, read_dir, problem, name, suffix='test', budget=100,
              warmup_steps=50, batch_size=1000):
    """Evaluate ``model`` on a saved problem set under four decoding settings.

    The settings are every combination of warmup in ``{0, warmup_steps}`` and
    ``do_sample`` in ``{False, True}``. For each setting the function sets
    ``model.config["online_steps"] = budget - warmup``,
    ``model.config["do_sample"]`` and ``model.config["warmup"]``. It then runs a
    Lightning test pass and reads ``model.trajectory``. The overridden config
    entries are restored afterwards, so later cells and repeated calls see an
    unmodified model.

    Args:
        model: Lightning module with an indexable ``config`` mapping and a
            ``trajectory`` tensor that is read after each test pass.
        read_dir: Root directory of saved problem sets.
        problem: Problem-set subdirectory name.
        name: Label prefix for the returned log entries.
        suffix: Split subdirectory (default ``'test'``).
        budget: Total number of steps per run (warmup + online steps).
        warmup_steps: Warmup length used for the "warmup" settings (default 50).
        batch_size: DataLoader batch size (default 1000).

    Returns:
        Dict mapping ``"<name> (<warmup mode>) (<sample mode>)"`` to
        ``{"m_list": 1..budget, "y_list (mean)": trajectory as a NumPy array}``.

    Raises:
        ValueError: If ``warmup_steps`` exceeds ``budget`` (``online_steps``
            would be negative).
    """
    if warmup_steps > budget:
        raise ValueError(
            f"warmup_steps={warmup_steps} exceeds budget={budget}; "
            "online_steps would be negative"
        )

    problem_path = os.path.join(read_dir, problem, suffix)
    problems = load_problem_set(problem_path)
    dataset = OnlineDataset(problems)
    collate_fn = partial(custom_collate_fn, problem_class=Problem)
    dataloader = DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        num_workers=1,
        pin_memory=True,
        shuffle=False,
        collate_fn=collate_fn,
    )
    tester = L.Trainer(logger=False, precision=model.config["precision"])

    # Remember the entries overridden below so they can be restored; otherwise
    # the model silently stays in the last (warmup, do_sample) setting.
    # NOTE(review): assumes model.config supports .get/.pop (dict-like) — confirm.
    overridden_keys = ("online_steps", "do_sample", "warmup")
    missing = object()
    saved = {key: model.config.get(key, missing) for key in overridden_keys}

    settings = ((0, False), (0, True), (warmup_steps, False), (warmup_steps, True))
    logs = {}
    try:
        for warmup, do_sample in settings:
            model.config["online_steps"] = int(budget - warmup)
            model.config["do_sample"] = do_sample
            model.config["warmup"] = warmup

            with torch.inference_mode():
                tester.test(model=model, dataloaders=dataloader)

            warmup_mode = "warmup" if warmup > 0 else "no warmup"
            sample_mode = "sample" if do_sample else "argmax"
            logs[f"{name} ({warmup_mode}) ({sample_mode})"] = {
                "m_list": np.arange(budget) + 1,
                "y_list (mean)": model.trajectory.cpu().numpy(),
            }
    finally:
        for key, value in saved.items():
            if value is missing:
                model.config.pop(key, None)
            else:
                model.config[key] = value
    return logs

In [None]:
# Load the latest checkpoint of the selected run onto the evaluation device,
# falling back to CPU when no GPU is visible.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
run = "4q309u12"  # run directory name passed to get_checkpoint
checkpoint_file = get_checkpoint(run)
# map_location loads the stored tensors directly onto `device`, so a checkpoint
# saved on another GPU (or on CUDA, when only a CPU is available) still loads.
model = DPTSolver.load_from_checkpoint(checkpoint_file, map_location=device)
model = model.to(device)

In [None]:
# Evaluate the AD model on a single problem family, with a step budget of twice
# model_params.seq_len.
# NOTE(review): this path is relative to the CWD, while get_checkpoint reads from
# "../results" — confirm "data/test" resolves where intended.
problem = "Normal(25, 1)"
read_dir = "data/test"
budget = model.config["model_params"]["seq_len"] * 2
logs = run_model(model, read_dir, problem, name="AD", budget=budget)

In [None]:
# Collect results of the baseline solvers, read from `read_dir` via get_meta_results.
read_dir = "results/test"
# Evaluate only the problem used above; swap in natsorted(os.listdir(read_dir))
# to compare on every saved problem.
problem_list = (problem,)
baseline_solvers = ('RandomSearch', 'PSO', 'PROTES')
# NOTE(review): baselines use a fixed budget of 400 while the AD run uses
# `budget` (2 * seq_len); confirm these are meant to match.
baseline_budget = 400

meta_results = defaultdict(dict)
# Use a distinct loop name so the global `problem` (used when merging the AD logs)
# is not overwritten when problem_list has several entries.
for problem_name in problem_list:
    for solver in baseline_solvers:
        meta_results[problem_name][solver] = get_meta_results(
            problem_name, solver, read_dir, suffix='test', budget=baseline_budget
        )

In [None]:
# Add the AD runs next to the baseline solvers for this problem, then display all.
meta_results[problem].update(logs)
show_meta_results(meta_results)

In [None]:
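# Scratch notes (not executed): alternative settings tried with this model.
# The temperature schedules use `math`, which is not imported explicitly above.
# warmup = 0
# model.config["temperature"] = lambda x: math.sqrt(x)
# model.config["temperature"] = lambda x: 5 - 4 * x
# model.config["temperature"] = lambda x: 1 / math.sqrt(1 + x)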