In [None]:
# Automatically re-import edited modules (solvers/, utils, problems) before each
# cell executes, so source changes take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Train DPT

In [None]:
# Launch DPT training through the shell script. It runs in a subprocess, so
# nothing it creates is visible in this kernel except files it writes
# (presumably the checkpoint loaded below — confirm against dpt_run.sh).
!bash dpt_run.sh
# Equivalent to running directly: !python3 solvers/dpt/train_dpt.py

# Test DPT

In [None]:
import torch
from solvers.dpt.src.model_dpt import DPT_K2D

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Hyper-parameters must match the run that produced the checkpoint; otherwise
# load_state_dict (strict by default) raises on missing/mismatched weights.
model = DPT_K2D(
    num_states=1,
    num_actions=16,
    hidden_dim=512,
    seq_len=50,
    num_layers=4,
    num_heads=4,
    attention_dropout=0.5,
    residual_dropout=0.1,
    embedding_dropout=0.3,
    normalize_qk=False,
    pre_norm=True,
    rnn_weights_path=None,
    state_rnn_embedding=1,
    rnn_dropout=0.0,
).to(DEVICE)

checkpoint_path = "solvers/dpt/checkpoints/model_last.pt"
# map_location lets a checkpoint saved from a GPU run be loaded on a CPU-only
# machine (DEVICE falls back to "cpu" above); weights_only avoids unpickling
# arbitrary objects from the file.
checkpoint = torch.load(checkpoint_path, map_location=DEVICE, weights_only=True)
model.load_state_dict(checkpoint)

# Switch dropout layers off for inference.
model.eval()

In [None]:
import numpy as np
import torch
from torch.nn import functional as F

from problems import Net
from utils import get_xaxis, int2bin


# Rollout settings. d/n must agree with the loaded model (num_actions=16 above);
# presumably the search space has n**d = 16 candidate points — confirm.
PROBLEM_D = 4
PROBLEM_N = 2
N_STEPS = 10

p = Net(d=PROBLEM_D, n=PROBLEM_N, seed=100)
# Objective value at every point returned by get_xaxis (presumably the whole
# search space); used only for reporting the ground truth below.
all_targets = p.target(get_xaxis(d=PROBLEM_D, n=PROBLEM_N))

# Start from the largest (worst, since we minimise) value with an empty context.
query_states = torch.tensor([all_targets.max()])
context_states = torch.Tensor(1, 0)
context_actions = torch.Tensor(1, 0)
context_rewards = torch.Tensor(1, 0)

# Every objective value visited, so the best one can be reported at the end.
visited_targets = []

for step in range(N_STEPS):
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        predicted_actions = model(
            query_states=query_states.to(dtype=torch.float, device=DEVICE),
            context_states=context_states.to(dtype=torch.float, device=DEVICE),
            context_next_states=None,
            context_actions=context_actions.to(dtype=torch.long, device=DEVICE),
            context_rewards=context_rewards.to(dtype=torch.float, device=DEVICE),
        )
    # NOTE(review): argmax without `dim` runs over the flattened tensor, which is
    # an action index only if the output is (1, num_actions) — confirm.
    predicted_action = torch.argmax(F.softmax(predicted_actions, dim=1)).cpu()
    point = int2bin(predicted_action, d=PROBLEM_D, n=PROBLEM_N)
    target = torch.tensor(p.target(point))
    visited_targets.append(target.item())
    print(f'step {step} | current target: {query_states.item():>8.6} -> suggested point: {point} -> new target: {target.item():.6}')

    # Append the new transition to the context; reward = change in objective value.
    context_states = torch.cat([context_states, target.unsqueeze(0)], dim=1)
    context_actions = torch.cat([context_actions, torch.tensor([predicted_action]).unsqueeze(0)], dim=1)
    context_rewards = torch.cat([context_rewards, (target - query_states).unsqueeze(0)], dim=1)
    query_states = target

print()
# Report the best value visited during the rollout, not merely the last one.
print(f'found minimal value: {min(visited_targets):.6}')
print(f'ground truth: {all_targets.min().item():.6}')
print()
print(f'all possible targets in an order:\n{np.sort(all_targets)}')

# Extra

## Offline test (disabled: all code in the cell below is commented out)

In [None]:
# NOTE(review): this whole cell is disabled. To re-enable it, first run the
# "Train step" imports (MarkovianDataset, DataLoader). Also note that it reads
# seq_len=200 trajectories, while the checkpoint model above uses seq_len=50.
# # dataset = MarkovianDataset("trajectories/", seq_len=200)
# # dataloader = DataLoader(dataset=dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0)

# for batch in dataloader:
#     (
#         query_flag,
#         query_states,
#         flags,
#         states,
#         actions,
#         next_flags,
#         next_states,
#         rewards,
#         target_actions,
#     ) = [b.to(dtype=torch.float, device=DEVICE) for b in batch]
#     break

# actions = actions.to(torch.long)

# predicted_actions = model(
#     query_states=query_states,
#     context_states=states,
#     context_next_states=next_states,
#     context_actions=actions,
#     context_rewards=rewards,
# )

# predicted_action = torch.argmax(F.softmax(predicted_actions, dim=1))
# target_actions, predicted_action

## Single training step (manual forward pass and loss on one batch)

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.nn import functional as F

from solvers.dpt.src.utils.data import MarkovianDataset
from solvers.dpt.src.model_dpt import DPT_K2D

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# Loader over the stored trajectories. shuffle=False means batches come out in
# dataset order, so the first batch is always the same samples. Pinned memory
# only speeds up host->GPU copies, so it is enabled only when CUDA is in use;
# this avoids PyTorch's pin_memory warning on CPU-only machines.
dataset = MarkovianDataset('trajectories', seq_len=200)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=128,
    pin_memory=DEVICE == "cuda",
    shuffle=False,
    num_workers=0,
)

In [None]:
# Fresh, untrained model for the single-step check. Note: this rebinds `model`,
# replacing the checkpoint-loaded model from the "Test DPT" section.
# seq_len=200 matches MarkovianDataset(seq_len=200); num_actions=10 matches the
# one-hot num_classes used for the targets.
model = DPT_K2D(
    # model size
    seq_len=200,
    hidden_dim=512,
    num_layers=4,
    num_heads=4,
    normalize_qk=False,
    pre_norm=True,
    # input / output sizes
    num_states=1,
    num_actions=10,
    # dropout rates
    embedding_dropout=0.3,
    attention_dropout=0.5,
    residual_dropout=0.1,
    # rnn_* options
    rnn_weights_path=None,
    state_rnn_embedding=16,
    rnn_dropout=0.0,
).to(DEVICE)

In [None]:
# Take only the first batch for a one-off forward/loss check. Fail loudly on an
# empty loader instead of silently leaving the names below undefined.
try:
    batch = next(iter(dataloader))
except StopIteration:
    raise RuntimeError("dataloader yielded no batches - is 'trajectories' empty?") from None

(
    query_flag,
    query_states,
    flags,
    states,
    actions,
    next_flags,
    next_states,
    rewards,
    target_actions,
) = [b.to(DEVICE) for b in batch]

In [None]:
# Cast batch tensors to the dtypes the model is called with below:
# float for states and rewards, long for actions.
query_states = query_states.to(torch.float)
states = states.to(torch.float)
actions = actions.to(torch.long)
next_states = next_states.to(torch.float)
rewards = rewards.to(torch.float)

# One-hot encode the target action (10 classes, as num_actions of the model
# above) and copy it to all 200 positions (the dataset seq_len) so it matches the
# per-position model output in the loss cell.
# Guarded by rank: the raw labels have fewer than 3 dims, while the encoded
# tensor is 3-D float. Re-running this cell is therefore a no-op instead of
# calling F.one_hot on floats, which raises. `.long()` is added because
# F.one_hot only accepts int64 input.
if target_actions.dim() < 3:
    target_actions = (
        F.one_hot(target_actions.squeeze(-1).long(), num_classes=10)
        .unsqueeze(1)
        .repeat(1, 200, 1)
        .float()
    )

In [None]:
# Forward pass over the whole context; the first output position is dropped so
# the rest line up with the 200 repeated targets.
# NOTE(review): assumes the output is (batch, seq_len + 1, num_actions) — confirm.
predicted_actions = model(
    context_states=states,
    context_next_states=next_states,
    context_actions=actions,
    context_rewards=rewards,
    query_states=query_states,
)[:, 1:]
predicted_actions = predicted_actions[:, 1:, :]

In [None]:
# Cross-entropy against the one-hot (probability) targets, with the batch and
# sequence axes merged into a single row axis: (B, S, A) -> (B * S, A).
loss = F.cross_entropy(
    predicted_actions.reshape(-1, predicted_actions.size(-1)),
    target_actions.reshape(-1, target_actions.size(-1)),
    label_smoothing=0.0,
)