In [34]:
import sys
root_path = '../'
sys.path.insert(0, root_path)

from functools import partial
from torch.utils.data import DataLoader

import problems as p
# from src_2d.data import OfflineDataset, OnlineDataset, custom_collate_fn
# from src_2d.train import DPTSolver
from src_1d.data import OfflineDataset, OnlineDataset, custom_collate_fn
from src_1d.train import DPTSolver
from run import load_config
from utils import *

In [None]:
# Enable IPython's autoreload so edits to project modules are picked up without
# restarting the kernel; mode 2 reloads modules before each cell is executed.
# NOTE(review): conventionally this cell goes before the imports cell.
%load_ext autoreload
%autoreload 2

In [None]:
# Load the experiment configuration and build one problem instance per seed.
N_PROBLEMS = 10  # number of seeded problem instances (seeds 0 .. N_PROBLEMS - 1)

config = load_config('../config.yaml')
print(f'problem: {config["problem"]}({config["problem_params"]})')
print(f'seq_len: {config["model_params"]["seq_len"]}')

# Resolve the problem class by name from the `problems` module, then instantiate
# it once per seed with the configured parameters.
problem_class = getattr(p, config["problem"])
problems = [problem_class(**config["problem_params"], seed=seed) for seed in range(N_PROBLEMS)]

In [37]:
# Instantiate the solver from the loaded config. Later cells call its step methods
# (`_offline_step`, `training_step`, `validation_step`, `_online_step`, `test_step`)
# directly, without a trainer loop.
model = DPTSolver(config)

### Offline Dataset

In [None]:
# Wrap the seeded problems in an offline dataset and pull a single small batch to inspect.
batch_size = 2

# The collate function needs the problem class to batch the per-sample problem objects.
collate_fn = partial(custom_collate_fn, problem_class=problem_class)
dataset = OfflineDataset(problems, seq_len=config["model_params"]["seq_len"])
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    collate_fn=collate_fn,
)
batch = next(iter(dataloader))
batch

In [None]:
# Show the shape and dtype of each tensor field of the offline batch.
for field in ("query_state", "states", "actions", "next_states", "rewards", "target_action"):
    print(field, batch[field].shape, batch[field].dtype)
# The batched problem is not a tensor: show its value and Python type instead.
print('problem', batch["problem"], type(batch["problem"]))

In [None]:
# Call the (underscore-prefixed) offline step directly on the batch; the returned
# dict's "outputs" / "predictions" / "targets" entries are inspected in the next cell.
# NOTE(review): runs in whatever train/eval mode the model is currently in.
outputs = model._offline_step(batch)
outputs

In [None]:
# Show the shape and dtype of each tensor returned by the offline step.
for key in ("outputs", "predictions", "targets"):
    print(key, outputs[key].shape, outputs[key].dtype)

In [42]:
# Optional: uncomment to display the offline-step predictions next to their targets.
# outputs["predictions"], outputs["targets"]

In [None]:
# Run one training step on the offline batch. Later cells switch the model to eval
# mode, so set train mode explicitly: otherwise re-running this cell after them
# would execute the training step in eval mode.
model.train()
model.training_step(batch, 0)

In [None]:
# Switch to eval mode and run one validation step on the same offline batch.
# NOTE: eval mode persists for every later cell (including the online section)
# until `model.train()` is called again.
model.eval()
model.validation_step(batch, 0)

In [None]:
# Pretty-print the first offline sample (helper presumably provided by the `utils` star import).
print_sample(dataset[0])

In [None]:
# Run the model end-to-end on the first offline sample and report prediction and metrics.
# The second return value is bound to `run_outputs` rather than `outputs`, so the
# `_offline_step` dict inspected above is not silently replaced. Without this,
# re-running those cells after this one would index into whatever `run` returned.
sample, run_outputs, predictions, metrics = run(model, dataset[0])
print_sample(sample, predictions)
print_metrics(metrics)

### Online Dataset

In [None]:
# Build an online dataset over the same seeded problems and pull one batch.
# Mirrors the offline cell: the batch size is a named value rather than a literal.
# NOTE: this rebinds `dataset`, `dataloader` and `batch`; offline cells above will
# operate on online data if re-run after this one.
batch_size = 2
dataset = OnlineDataset(problems)
collate_fn = partial(custom_collate_fn, problem_class=problem_class)
dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn)
batch = next(iter(dataloader))
batch

In [None]:
# Show the shape and dtype of each tensor field of the online batch.
for field in ("query_state", "target_state"):
    print(field, batch[field].shape, batch[field].dtype)
# The batched problem is not a tensor: show its value and Python type instead.
print('problem', batch["problem"], type(batch["problem"]))

In [None]:
# Call the (underscore-prefixed) online step directly on the online batch; the
# returned dict is inspected in the next cell.
# NOTE(review): the model is still in eval mode from the validation cell above.
outputs = model._online_step(batch)
outputs

In [None]:
# Show the shape and dtype of each tensor returned by the online step.
for key in ("outputs", "predictions", "targets"):
    print(key, outputs[key].shape, outputs[key].dtype)

In [None]:
# Ensure eval mode (keeps this cell correct when re-run on its own) and run one
# test step on the online batch.
model.eval()
model.test_step(batch, 0)

In [None]:
# Run the model end-to-end on the first online sample and report the results.
# The second return value is bound to `run_outputs` rather than `outputs`, so the
# `_online_step` dict inspected above is not silently replaced.
# NOTE(review): unlike the offline cell, `predictions` is not passed to
# `print_sample` here; print_ta=False matches the online batch having no
# `target_action` field. Confirm this is intentional.
sample, run_outputs, predictions, metrics = run(model, dataset[0])
print_sample(sample, print_ta=False, print_fm=True)
print_metrics(metrics)