In [1]:
%load_ext autoreload
%autoreload 2

# Train DPT

If there is no `results` directory yet, first generate the optimization histories by running:

In [None]:
# !bash dpt_run.sh

In [2]:
from solvers.dpt.train import DPTSolver

  from tqdm.autonotebook import tqdm


In [3]:
# Instantiate the DPT solver from its YAML config (relative path — presumably
# resolved against the notebook's working directory).
model = DPTSolver('config.yaml')

Once the `results` directory exists, set up evaluation and train the model:

In [4]:
from functools import partial
import torch
from problems import Net
from utils import *


def transition_function(state, action, problem):
    """Environment step used during a DPT rollout.

    The incoming ``state`` is ignored: the next state is simply the objective
    value of ``problem`` at the point encoded by the integer ``action``.

    Args:
        state: current state (unused; kept for the transition-function signature).
        action: integer action index, decoded with ``int2bin`` into a point.
        problem: object providing ``d``, ``n`` and ``target(point)``.

    Returns:
        ``torch.Tensor`` holding ``problem.target`` evaluated at the decoded point.
    """
    return torch.tensor(problem.target(int2bin(action, d=problem.d, n=problem.n)))

def _eval_function(model, problem):
    """Score one rollout of ``model`` on ``problem`` as a fraction of the optimality gap.

    The rollout starts from the largest objective value over the whole search
    space. The score is ``(start - best_found) / (start - global_min)``. It is
    1.0 when the rollout reaches the global minimum and 0.0 when it never
    improves on the starting value.

    Args:
        model: object exposing ``test(query_state, transition_fn)``.
        problem: object with ``d``/``n`` attributes and a ``target`` method that
            accepts the full grid returned by ``get_xaxis``.

    Returns:
        The score as a Python float. If every point has the same objective
        value, the gap is zero and every state is optimal. In that case 1.0 is
        returned instead of the NaN that 0/0 would produce.
    """
    all_actions = get_xaxis(d=problem.d, n=problem.n)
    all_states = problem.target(all_actions)
    target_state = torch.tensor([all_states.min()])
    query_state = torch.tensor([all_states.max()])

    trajectory = model.test(query_state, partial(transition_function, problem=problem))

    # NOTE(review): index 3 is assumed to hold the states reached during the
    # rollout — confirm against the return layout of DPTSolver.test.
    best_found_state = trajectory[3].min()

    gap = query_state - target_state
    if gap.item() == 0:
        # Degenerate (constant) objective: nothing to improve, avoid 0/0 -> NaN.
        return 1.0
    accuracy = (query_state - best_found_state) / gap
    return accuracy.item()

def eval_function(model, problems, keys):
    """Evaluate ``model`` on several problems and label each score.

    Args:
        model: the model passed through to ``_eval_function``.
        problems: iterable of problems to evaluate on.
        keys: iterable of labels, one per problem, in the same order.

    Returns:
        dict mapping each key to ``_eval_function(model, problem)``.

    Raises:
        ValueError: if ``problems`` and ``keys`` have different lengths.
            Plain ``zip`` would silently drop the extra items.
    """
    problems = tuple(problems)
    keys = tuple(keys)
    if len(problems) != len(keys):
        raise ValueError(
            f"eval_function got {len(problems)} problems but {len(keys)} keys"
        )
    return {
        key: _eval_function(model, problem)
        for problem, key in zip(problems, keys)
    }

# Two Net instances of the same size that differ only in their seed; the keys
# below label them as the train and test problems.
train_problem = Net(d=4, n=2, seed=1)
test_problem = Net(d=4, n=2, seed=0)

# Bind the problems and metric names under a new name instead of rebinding
# `eval_function` to a partial of itself. Rebinding shadows the original
# function, and a later direct call with `problems`/`keys` would then fail with
# "multiple values for argument".
eval_on_problems = partial(
    eval_function,
    problems=(train_problem, test_problem),
    keys=('accuracy (train problem)', 'accuracy (test problem)'),
)
model.train(eval_on_problems)

Loading training histories...
Num histories: 6000


[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33manabatsheva[0m ([33manabatsheva_sk[0m). Use [1m`wandb login --relogin`[0m to force relogin


Training: 0it [00:00, ?it/s]

In [19]:
# !python3 train_dpt.py

# Test DPT

Load a trained checkpoint and run it on a `Net` problem.

In [None]:
from GreyBoxDPTOptimizer_.test_dpt import load_model, test
from problems import Net


# Load the last saved checkpoint and switch it to evaluation mode
# (presumably a torch.nn.Module, so this disables dropout — TODO confirm).
model = load_model("../GreyBoxDPTOptimizerData/checkpoints/model_last.pt")
model.eval()
# Same size and seed as `train_problem` in the training section (d=4, n=2, seed=1).
problem = Net(d=4, n=2, seed=1)
# NOTE(review): the meaning of the third argument (2) is not visible here —
# check the signature of test_dpt.test.
result = test(model, problem, 2)

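Run the model step by step on a problem of the class `Net`, choosing each action from a temperature-scaled softmax (sampled, or argmax when sampling is off).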

In [None]:
import numpy as np
from torch.nn import functional as F
from utils import int2bin
from problems import Net
from utils import get_xaxis


import torch

# Rollout settings.
temp = 0.5         # softmax temperature; non-positive values fall back to 1.0
do_samples = True  # True: sample an action from the softmax; False: take the argmax
num_steps = 10     # number of model-suggested evaluations
seed = 0           # seeds the action sampler so the rollout is reproducible

p = Net(d=4, n=2, seed=1)
all_targets = p.target(get_xaxis(d=4, n=2))

# Send inputs to the device the checkpoint actually lives on, instead of relying
# on a DEVICE global that is only defined in a later cell.
# assumes `model` is a torch.nn.Module (it is put in eval mode above) — TODO confirm
device = next(model.parameters()).device
generator = torch.Generator().manual_seed(seed)
temp = 1.0 if temp <= 0 else temp  # loop-invariant, so resolve it once

# Start from the largest target value with an empty context. Each step appends
# one (state, next_state, action, reward) transition to the context.
query_states = torch.tensor([all_targets.max()])
context_states = torch.empty(1, 0)
context_next_states = torch.empty(1, 0)
context_actions = torch.empty(1, 0)
context_rewards = torch.empty(1, 0)
best_target = float('inf')

with torch.no_grad():  # inference only: no autograd graph is needed
    for step in range(num_steps):
        predicted_actions = model(
            query_states=query_states.to(dtype=torch.float, device=device),
            context_states=context_states.to(dtype=torch.float, device=device),
            context_next_states=context_next_states.to(dtype=torch.float, device=device),
            context_actions=context_actions.to(dtype=torch.long, device=device),
            context_rewards=context_rewards.to(dtype=torch.float, device=device),
        )
        # Sample from the temperature-scaled distribution rather than always taking
        # the mode. Probabilities are moved to CPU so the seeded CPU generator can be used.
        probs = F.softmax(predicted_actions.cpu() / temp, dim=-1)
        if do_samples:
            predicted_action = torch.multinomial(probs, num_samples=1, generator=generator).squeeze(1)[0]
        else:
            predicted_action = torch.argmax(probs, dim=-1)[0]

        point = int2bin(predicted_action, d=4, n=2)
        target = torch.tensor(p.target(point))
        best_target = min(best_target, target.item())
        print(f'step {step} | current target: {query_states.item():>8.6} -> suggested point: {point} -> new target: {target.item():.6}')

        context_states = torch.cat([context_states, query_states.unsqueeze(0)], dim=1)
        context_next_states = torch.cat([context_next_states, target.unsqueeze(0)], dim=1)
        context_actions = torch.cat([context_actions, torch.tensor([predicted_action]).unsqueeze(0)], dim=1)
        context_rewards = torch.cat([context_rewards, (target - query_states).unsqueeze(0)], dim=1)
        query_states = target

print()
# Report the best value seen during the rollout, not the last one: with sampling,
# the final suggestion is not necessarily the best.
print(f'found minimal value: {best_target:.6}')
print(f'ground truth: {all_targets.min().item():.6}')
print()
print(f'all possible targets in an order:\n{np.sort(all_targets)}')

# Extra

Other experiments and debugging snippets.

## Offline Test

In [None]:
# # dataset = MarkovianDataset("trajectories/", seq_len=200)
# # dataloader = DataLoader(dataset=dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0)

# for batch in dataloader:
#     (
#         query_flag,
#         query_states,
#         flags,
#         states,
#         actions,
#         next_flags,
#         next_states,
#         rewards,
#         target_actions,
#     ) = [b.to(dtype=torch.float, device=DEVICE) for b in batch]
#     break

# actions = actions.to(torch.long)

# predicted_actions = model(
#     query_states=query_states,
#     context_states=states,
#     context_next_states=next_states,
#     context_actions=actions,
#     context_rewards=rewards,
# )

# predicted_action = torch.argmax(F.softmax(predicted_actions, dim=1))
# target_actions, predicted_action

## Train step

In [None]:
# Run ./run.sh for problem indices 0 and 1. It presumably writes the optimizer
# histories into `results/`, which the next cell converts into trajectories.
for i in range(2):
    print(f"Problem {i}...")
    # NOTE: IPython's `!` does not raise when the command exits non-zero, so a
    # failing run.sh will not stop the notebook — inspect the cell output.
    !bash ./run.sh $i

In [None]:
from solvers.dpt.src.utils.data import results2trajectories

# Convert the raw runs in `results/` into trajectory files under
# solvers/dpt/trajectories, the directory MarkovianDataset reads below.
results2trajectories('results', 'solvers/dpt/trajectories')

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.nn import functional as F

from solvers.dpt.src.utils.data import MarkovianDataset
from GreyBoxDPTOptimizer_.solvers.dpt.model import DPT_K2D

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# seq_len must agree with the model's seq_len (50) defined in the next cell.
dataset = MarkovianDataset('solvers/dpt/trajectories', seq_len=50)
dataloader = DataLoader(
    dataset=dataset,
    batch_size=4,
    # Page-locked host memory only speeds up host-to-GPU copies. Requesting it
    # without a GPU does nothing useful, and recent PyTorch versions warn about it.
    pin_memory=DEVICE == "cuda",
    shuffle=False,
    num_workers=0,
)

In [None]:
# Fresh DPT_K2D for a single hand-run training step. seq_len matches the
# dataset's seq_len=50 above; num_actions=16 presumably equals n**d = 2**4 for
# the Net(d=4, n=2) problems — TODO confirm.
model = DPT_K2D(
    # problem / sequence sizes
    num_states=1,
    num_actions=16,
    seq_len=50,
    # transformer architecture
    num_layers=4,
    num_heads=4,
    hidden_dim=512,
    pre_norm=True,
    normalize_qk=False,
    # dropout rates
    embedding_dropout=0.3,
    attention_dropout=0.5,
    residual_dropout=0.1,
    # rnn_* settings
    rnn_weights_path=None,
    state_rnn_embedding=16,
    rnn_dropout=0.0,
).to(DEVICE)

In [None]:
# Take just the first batch to debug one training step by hand. An empty
# dataloader now fails loudly instead of leaving the names below unbound (or
# stale from an earlier run).
batch = next(iter(dataloader), None)
if batch is None:
    raise RuntimeError("dataloader is empty: check that solvers/dpt/trajectories contains data")
(
    query_states,
    states,
    actions,
    next_states,
    rewards,
    target_actions,
) = [b.to(DEVICE) for b in batch]

In [None]:
# Cast the batch to the dtypes the model is fed with: float32 for states and
# rewards, int64 for the (index-valued) actions.
query_states, states, next_states, rewards = (
    t.float() for t in (query_states, states, next_states, rewards)
)
actions = actions.long()

In [None]:
# Soft targets for cross-entropy: one-hot encode the target action (presumably
# one per batch element) over 16 classes and repeat it across 50 sequence positions.
target_actions_onehot = F.one_hot(target_actions.squeeze(-1), num_classes=16).float()
target_actions_onehot = target_actions_onehot[:, None].repeat(1, 50, 1)

In [None]:
# Forward pass over the batch. The first output position along axis 1 is dropped
# so the remaining positions can be compared position-wise with
# target_actions_onehot (presumably the extra position comes from the query
# token — TODO confirm in DPT_K2D).
predicted_actions = model(
    context_actions=actions,
    context_rewards=rewards,
    context_states=states,
    context_next_states=next_states,
    query_states=query_states,
)[:, 1:]

In [None]:
# Merge the first two axes (presumably batch and sequence) and score the logits
# against the soft one-hot targets; label smoothing is explicitly disabled.
loss = F.cross_entropy(
    predicted_actions.flatten(start_dim=0, end_dim=1),
    target_actions_onehot.flatten(start_dim=0, end_dim=1),
    label_smoothing=0.0,
)

In [None]:
# Display the loss for this single hand-built batch.
loss

## Sampling with temperature vs. argmax

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

# Compare the empirical distribution of temperature-scaled sampling with the
# deterministic argmax choice on a fixed single-peaked set of logits.
logits = torch.tensor(np.hstack([np.arange(5), np.arange(4)[::-1]])).to(torch.float)
d = len(logits)
bins = np.arange(d + 1)
N = 1000
temp = 1
generator = torch.Generator().manual_seed(0)  # reproducible histogram

# Sampling: draw N tokens from softmax(logits / temp) and histogram them.
probs = torch.nn.functional.softmax(logits / temp, dim=-1)
next_tokens = torch.multinomial(probs, num_samples=N, replacement=True, generator=generator)
distr = np.histogram(next_tokens.numpy(), bins=bins)[0] / N

# Argmax: all probability mass on the single most likely token.
argmax_token = torch.argmax(logits, dim=-1).item()
argmax_distr = np.zeros(d)
argmax_distr[argmax_token] = 1

centers = bins[:-1] + 0.5
fig, ax = plt.subplots()
ax.plot(centers, distr, '-o', label=f't={temp}')
ax.plot(centers, argmax_distr, '-o', label='argmax')
ax.set(title='Sampling with temperature vs. argmax', xlabel='token index', ylabel='probability')
ax.legend()
plt.show()