### Dataclasses

In [None]:
from dataclasses import dataclass, asdict
from typing import Callable
import json
import yaml
from yaml import Loader, Dumper

@dataclass
class HP_alg:
  """
  Hyperparameters of the training algorithm.
  The values below are defaults; an instance is normally built from a YAML config whose keys override them.
  """
  max_nactions: int = -1 # placeholder, recomputed from critic_play / max_attention_length once the config is loaded
  critic_play: bool = False # switches data handling between critic-play and actor (PPO) mode
  actor_lr: float = 1e-4 # lr at time 0
  critic_lr: float = 5e-5
  train_condition_str: str = "(lambda tr: tr[0] % 5 != 2)" # source of a predicate over trajectories, turned into a function with eval(); True means "train", False means "validation"
  td_gamma: float = 0.998 # discount factor used when rewards are folded into Returns
  batch_size: int = 512
  max_attention_length: int = 64
  use_fork_discount: bool = False # True not supported by jar-file currently
  batch_accumulation_steps: int = 1
  epochs: int = 100
  use_GAE: bool = False
  use_double_atten: bool = False # adds a second transformer encoder layer to the actor
  use_FFM: bool = True # inserts FFM_layer into the actor embedding and the critic MLP
  path_: str = '../Config/HP_alg.yaml'

# Load the algorithm hyperparameters; keys present in the YAML file override the dataclass defaults.
# NOTE: yaml.Loader can construct arbitrary Python objects -- only point this at trusted config files.
with open('../Config/HP_alg.yml') as hp_alg_file:
  hp_alg = HP_alg(**yaml.load(hp_alg_file.read(), Loader=Loader))
# In critic-play mode the number of actions is effectively uncapped (1e9); otherwise it is bounded by max_attention_length.
hp_alg.max_nactions = int(1e9) if hp_alg.critic_play else hp_alg.max_attention_length


def _load_state_scheme(schemes_path):
  """Reads the 'stateScheme' entry of a schemes json file and returns it as a tuple (the file is closed afterwards)."""
  with open(schemes_path) as schemes_file:
    return tuple(json.load(schemes_file)['stateScheme'])


@dataclass
class HP_front:
  """
  Settings of the python front end: data schemes, index shortcuts into them, paths and the jar command.
  The defaults that are derived from other defaults (feature_names_scheme, features_dim, tr_*_idx)
  are computed once, while the class body is executed.
  """
  critic_play: bool = False
  state_scheme: tuple = _load_state_scheme('../Config/schemes.json') # read at class-definition time
  feature_names_scheme: tuple = tuple(state_scheme[0])
  trajectory_scheme: tuple = ('hash', 'trajectory', 'name', 'statementsCount', 'probabilities')
  between_logs: int = 64
  rnn_features_count: int = 33
  gnn_features_count: int = 8
  features_dim: int = len(feature_names_scheme)
  use_cuda_if_available: bool = True
  # positions of the named entries inside one trajectory record
  tr_hash_idx: int = trajectory_scheme.index('hash')
  tr_path_idx: int = trajectory_scheme.index('trajectory')
  tr_name_idx: int = trajectory_scheme.index('name')
  tr_code_len_idx: int = trajectory_scheme.index('statementsCount')
  # No annotation here, so dataclass treats this as a plain class attribute:
  # it is not a constructor argument and does not show up in asdict().
  tr_prev_probs_idx = trajectory_scheme.index('probabilities')
  # positions of the entries inside one state record of a trajectory
  state_queue_idx: int = 0
  state_chosenId_idx: int = 1
  state_reward_idx: int = 2
  cuda_sync: bool = False # when True, torch.cuda.synchronize is used as the sync hook
  json_path: str = '../Data/current_dataset.json'
  jar_command: str = '/home/st-andrey-podivilov/java16/usr/lib/jvm/bellsoft-java16-amd64/bin/java -Dorg.jooq.no-logo=true -jar ../Game_env/usvm-jvm/build/libs/usvm-jvm-new.jar ../Game_env/jar_config.txt > ../Game_env/jar_log.txt'
  actor_model_path: str = '../Game_env/actor_model.onnx'
  path_: str = '../Config/HP_front.yaml'

# Load the front-end settings; keys present in the YAML file override the dataclass defaults.
# NOTE: yaml.Loader can construct arbitrary Python objects -- only point this at trusted config files.
with open('../Config/HP_front.yml') as hp_front_file:
  hp_front = HP_front(**yaml.load(hp_front_file.read(), Loader=Loader))
# Prefer the scheme stored in ../Game_env; keep the current value when that file cannot be used.
try:
  with open('../Game_env/schemes.json') as schemes_file:
    hp_front.state_scheme = tuple(json.load(schemes_file)['stateScheme'])
except Exception: # best effort: a missing or malformed scheme file must not stop the run
  print('No scheme, using the default')
# Re-derive everything that depends on the (possibly overridden) schemes.
hp_front.feature_names_scheme = tuple(hp_front.state_scheme[0])
hp_front.features_dim = len(hp_front.feature_names_scheme)
hp_front.tr_hash_idx = hp_front.trajectory_scheme.index('hash')
hp_front.tr_path_idx = hp_front.trajectory_scheme.index('trajectory')
hp_front.tr_name_idx = hp_front.trajectory_scheme.index('name')
hp_front.tr_code_len_idx = hp_front.trajectory_scheme.index('statementsCount')
# Keep the probabilities index in step with the other tr_* indices; schemes without that entry keep the class default.
if 'probabilities' in hp_front.trajectory_scheme:
  hp_front.tr_prev_probs_idx = hp_front.trajectory_scheme.index('probabilities')


@dataclass
class HP_back:
  """
  Settings handed over to the jar-file: the whole dataclass is dumped as JSON into ../Game_env/jar_config.txt.
  NOTE(review): the camelCase field names presumably mirror the config schema expected by the jar -- confirm before renaming.
  """
  shuffleTests: bool = False
  maxAttentionLength: int = -1 # effectively it's a max number of actions to choose from during data gathering or inference (last in queue)
  samplesPath: str = "../Game_env/usvm-jvm/src/samples/java"
  gameEnvPath: str = "../Game_env"
  dataPath: str = "../Data"
  defaultAlgorithm: str = "BFS" # "ForkDepthRandom"
  postprocessing: str = "Argmax" # "Softmax", "None"
  mode: str = "Both" # "Calculation", "Aggregation"
  # default is computed from hp_front / hp_alg while the class body is executed; it is recomputed after the config is loaded
  inputShape: tuple = (-1, hp_front.features_dim) if hp_alg.critic_play else (1, -1, hp_front.features_dim)
  dataConsumption: float = 100.0
  hardTimeLimit: float = 30000
  solverTimeLimit: float = 10000
  maxConcurrency: int = 120
  graphUpdate: str = "Once" # "TestGeneration"
  logGraphFeatures: bool = False
  gnnFeaturesCount: int = 8
  rnnStateShape: tuple = (4, 1, 512)
  useGnn: bool = True
  useRnn: bool = True
  path_: str = '../Config/HP_back.yaml'
  inputJars: dict = None # None unless the config supplies a mapping

# Load the jar-side settings; keys present in the YAML file override the dataclass defaults.
# NOTE: yaml.Loader can construct arbitrary Python objects -- only point this at trusted config files.
with open('../Config/HP_back.yml') as hp_back_file:
  hp_back = HP_back(**yaml.load(hp_back_file.read(), Loader=Loader))
# These two values are always derived from hp_alg / hp_front, whatever the YAML file contained.
hp_back.inputShape = (-1, hp_front.features_dim) if hp_alg.critic_play else (1, -1, hp_front.features_dim)
hp_back.maxAttentionLength = -1 if hp_alg.critic_play else hp_alg.max_attention_length


def make_configs():
  """
  Writes the three hyperparameter objects back to their YAML files
  and dumps hp_back as JSON into the config file passed to the jar.
  """
  yaml_targets = (
    ('../Config/HP_alg.yml', hp_alg),
    ('../Config/HP_front.yml', hp_front),
    ('../Config/HP_back.yml', hp_back),
  )
  for target_path, hp in yaml_targets:
    with open(target_path, 'w') as outfile:
      yaml.dump(asdict(hp), outfile)
  with open('../Game_env/jar_config.txt', 'w') as jar_config:
    jar_config.write(json.dumps(asdict(hp_back)))

# Persist the resolved settings right away, so the files on disk match the objects in memory.
make_configs()

### Imports, meta

In [None]:
# %%capture

# %env CUDA_DEVICE_ORDER=PCI_BUS_ID
# %env CUDA_VISIBLE_DEVICES=4

# from IPython.display import Javascript
# def resize_colab_cell():
#   display(Javascript('google.colab.output.setIframeHeight(0, true, {maxHeight: 600})'))
# get_ipython().events.register('pre_run_cell', resize_colab_cell)

import warnings
import numpy as np
from numpy import random
import copy
import inspect
import torch
from torch import nn
import torch.onnx
import json
from tqdm import tqdm, trange
from time import time
import os
import math
from operator import itemgetter
import torch.nn.functional as F
from torch_geometric.loader import DataLoader as GraphDataLoader
from torch_geometric.data import Data as GraphData
from torch_geometric.nn import GCNConv


# !pip install wandb
import wandb

# !pip install onnx==1.12
# import onnx

# Authenticate with Weights & Biases; wandb.log / wandb.Histogram are used for experiment logging further down.
wandb.login()

### Args/shortcuts


In [None]:
# Module-level shortcuts for frequently used settings.
# They are plain copies taken at this point: later changes to hp_alg / hp_front are not reflected here.
batch_size = hp_alg.batch_size
max_nactions = hp_alg.max_nactions
device = 'cuda' if (torch.cuda.is_available() and hp_front.use_cuda_if_available) else 'cpu'
td_gamma = hp_alg.td_gamma
json_path = hp_front.json_path
use_fork_discount = hp_alg.use_fork_discount
batch_accumulation_steps = hp_alg.batch_accumulation_steps
critic_play = hp_alg.critic_play
features_dim = hp_front.features_dim
# NOTE: eval() executes arbitrary code -- train_condition_str must come from a trusted config
train_condition = eval(hp_alg.train_condition_str)

# index shortcuts into a trajectory record and into a state record
tr_hash_idx = hp_front.tr_hash_idx
tr_path_idx = hp_front.tr_path_idx
tr_name_idx = hp_front.tr_name_idx
tr_prev_probs_idx = hp_front.tr_prev_probs_idx
state_queue_idx = hp_front.state_queue_idx
state_chosenId_idx = hp_front.state_chosenId_idx
state_reward_idx = hp_front.state_reward_idx

# sync hook: torch.cuda.synchronize when cuda_sync is set, a no-op otherwise
maybe_sync = torch.cuda.synchronize if hp_front.cuda_sync else (lambda *args: None)
jar_command = hp_front.jar_command
# bare expression: displays the chosen device as the notebook cell output
device

### Utils

In [None]:
def get_ass_dope_ids(l):
  """
  Builds a tuple of index tensors for advanced indexing:
  a full index grid of l's shape whose last axis is replaced by the values of l.
  A 1-dim l is treated as a column, so x[get_ass_dope_ids(l)] picks x[i, l[i]] for every row i.
  The index tensors are created on the CPU.
  """
  picks = l.long()
  if picks.dim() == 1:
    picks = picks.unsqueeze(-1)
  grid = torch.LongTensor(np.indices(picks.shape))
  grid[-1] = picks
  return tuple(grid)

def get_clip_eps(epoch):
  """
  Step schedule for the clipping epsilon:
  0.3 before epoch 5, 0.1 before epoch 20, 0.05 from then on.
  """
  for first_epoch_past, eps in ((5, 0.3), (20, 0.1)):
    if epoch < first_epoch_past:
      return eps
  return 0.05

### Models, modules

In [None]:
class FFM_layer(torch.nn.Module):
  """
  Parameter-free feature expansion: concatenates the input with its sine and cosine
  along the last axis, so the last dimension is tripled.
  """
  def __init__(self, input_dim):
    # input_dim is accepted but not used: the frozen random projection (fourier matrix)
    # that would have needed it is disabled, the raw input goes straight into sin/cos.
    super().__init__()

  def forward(self, x):
    return torch.cat([x, torch.sin(x), torch.cos(x)], dim=-1)


class Attn_model(torch.nn.Module):
  '''
  PPO actor model.
  Embeds every action of a queue, runs one (optionally two) transformer encoder layers over the queue
  and turns the per-action scores into a probability distribution with a softmax over the last axis.
  '''
  def __init__(self, d_model=512, n_heads=8, dim_feedforward=512, dropout=0.0):
    super().__init__()
    self.emb = nn.Sequential(
      nn.LazyLinear(512),
      nn.ReLU(),
      nn.LayerNorm(512),
      FFM_layer(512) if hp_alg.use_FFM else nn.Identity(),
      nn.LazyLinear(d_model),
      nn.ReLU(),
    )
    encoder_args = (d_model, n_heads, dim_feedforward, dropout)
    self.attn_EncoderLayer0 = nn.TransformerEncoderLayer(*encoder_args, batch_first=True)
    if hp_alg.use_double_atten:
      self.attn_EncoderLayer1 = nn.TransformerEncoderLayer(*encoder_args, batch_first=True)
    self.head = nn.Sequential(
      nn.LazyLinear(1),
    )
    self.sfmax = nn.Softmax(dim=-1)

  def forward(self, x, mask=None):
    '''
    x: action features, the encoder layers are batch-first.
    mask: optional padding mask, True marks positions to be ignored;
          it is passed to the encoder as src_key_padding_mask and masked positions get zero probability.
    '''
    hidden = self.emb(x)
    hidden = self.attn_EncoderLayer0(hidden, src_key_padding_mask=mask)
    if hp_alg.use_double_atten:
      hidden = self.attn_EncoderLayer1(hidden, src_key_padding_mask=mask)
    scores = self.head(hidden).squeeze(-1)
    if mask is not None:
      # -inf on padded positions, 0 elsewhere
      scores = scores + mask.float().masked_fill(mask==True, -float('inf'))
    return self.sfmax(scores)


class Q_Net(torch.nn.Module):
  """
  Builds Q-network based on critic.
  Currently is neither needed nor properly working (#reward_ind feature
  is not a function of true reward. Might be solved, is not).
  """
  def __init__(self, V_function):
    super().__init__()
    self.V_function = V_function
    # column of the 'logReward' feature inside a feature vector
    self.reward_ind = hp_front.feature_names_scheme.index('logReward')

  def forward(self, feature_batch):
    """
    Returns V(features) + exp(logReward feature) - 1, with the reward term reshaped to a column.
    """
    value = self.V_function(feature_batch)
    log_reward = feature_batch[:, self.reward_ind].unsqueeze(1)
    return value + log_reward.exp() - 1


class GNN_model(torch.nn.Module):
  """
  GNN model used during jar-file calls.
  Was trained once as a part of critic network.
  Five GCNConv layers (input -> 16 -> 32 -> 32 -> 16 -> output) with ReLU in between
  and dropout in front of the last two layers.
  """
  def __init__(self, num_input_features, num_output_features):
    super().__init__()
    widths = [num_input_features, 16, 32, 32, 16, num_output_features]
    self.convs = nn.ModuleList([
      GCNConv(n_in, n_out) for n_in, n_out in zip(widths[:-1], widths[1:])
    ])
    # dropout probability applied in front of convs[1], convs[2], ... (None = no dropout)
    self.dropout_probs = [None, None, 0.1, 0.3]

  def forward(self, data_x, data_edge_index):
    first_conv, later_convs = self.convs[0], self.convs[1:]
    hidden = first_conv(data_x, data_edge_index)
    for layer, p_drop in zip(later_convs, self.dropout_probs):
      hidden = F.relu(hidden)
      if p_drop is not None:
        hidden = F.dropout(hidden, p=p_drop, training=self.training)
      hidden = layer(hidden, data_edge_index)
    return hidden


class V_cell(torch.nn.Module):
  """
  We use V_cell inner representations to enrich a world embedding.
  Was trained once on a proxy seq2seq task (predicting Returns along a trajectory).
  Two stacked LSTM cells (wide inner one, narrow top one) followed by a small value head.
  """
  def __init__(
    self,
    inner_hidden_size = 512,
    top_hidden_size=32,
    bias=True,
    batch_first=True,
  ):
    # bias and batch_first are accepted but not used by the layers below
    super().__init__()
    self.inner_hidden_size = inner_hidden_size
    self.top_hidden_size = top_hidden_size
    self.num_layers = 2
    self.emb = nn.Sequential(
      nn.LazyLinear(128),
      nn.LayerNorm(128),
    )
    self.lstm_cell = nn.LSTMCell(input_size=128, hidden_size=inner_hidden_size)
    self.dropout = nn.Dropout(p=0.2)
    self.lstm_cell2 = nn.LSTMCell(input_size=inner_hidden_size, hidden_size=top_hidden_size)
    self.head = nn.Sequential(
      nn.Dropout(p=0.1),
      nn.LazyLinear(128),
      nn.ReLU(),
      nn.LazyLinear(1),
    )

  def forward(self, input_batch, prev_state_batch):
    """
    input_batch: [batch_size, feature_size]
    prev_state_batch: [2*num_layers, batch_size, inner_hidden_size]
    Outputs: [batch_size, 1], [2*num_layers, batch_size, inner_hidden_size], [batch_size, top_hidden_size+1]

    The state is laid out as (h1, c1, h2, c2). The narrow top-layer h2/c2 are tiled up to
    inner_hidden_size on the way out and cut back to top_hidden_size on the way in,
    so that state_batch stays a single regular tensor.
    The third output (h2 with the value appended) is detached.
    """
    h_rows = [2 * layer for layer in range(self.num_layers)]
    c_rows = [2 * layer + 1 for layer in range(self.num_layers)]
    h_prev, c_prev = prev_state_batch[h_rows], prev_state_batch[c_rows]
    top = self.top_hidden_size
    tiles = self.inner_hidden_size // top

    embedded = self.emb(input_batch)
    h1, c1 = self.lstm_cell(embedded, (h_prev[0], c_prev[0]))
    h1 = self.dropout(h1)
    h2, c2 = self.lstm_cell2(h1, (h_prev[1][:, :top], c_prev[1][:, :top]))
    v = self.head(h2)

    next_state = torch.stack((h1, c1, h2.repeat(1, tiles), c2.repeat(1, tiles)), dim=0)
    representation = torch.cat([h2, v], dim=-1).detach()
    return v, next_state, representation


def get_mlp_setup():
  '''
  Returns model and optimizer for Mr. critic of PPO.
  The model is an MLP with a single scalar output, moved to `device`;
  the optimizer is AdamW with lr taken from hp_alg.critic_lr.
  '''
  # Body re-indented to the docstring's level: the deeper indentation it had before is an IndentationError.
  mlp = nn.Sequential(
    nn.LazyLinear(512),
    nn.ReLU(),
    nn.Linear(512,256),
    nn.LayerNorm(256),
    FFM_layer(256) if hp_alg.use_FFM else nn.Identity(),
    nn.LazyLinear(512), # lazy: its input width depends on whether FFM_layer is present
    nn.ReLU(),
    nn.Linear(512,512),
    nn.ReLU(),
    nn.Linear(512,1),
  ).to(device)
  mlp_opt = torch.optim.AdamW(mlp.parameters(), lr=hp_alg.critic_lr, weight_decay=0.1, betas=(0.9, 0.999))
  return mlp, mlp_opt


def get_attn_setup(sched_total_iters=hp_alg.epochs):
  '''
  Returns model, optimizer and scheduler for Mr. actor of PPO.
  sched_total_iters: number of scheduler steps over which the lr factor goes linearly from 1 to 0.1
                     (the default is read from hp_alg.epochs once, when the function is defined).
  '''
  attn_model = Attn_model().to(device)
  opt = torch.optim.AdamW(attn_model.parameters(), lr=hp_alg.actor_lr, weight_decay=1e-2,)
  # `verbose` is no longer passed: it is deprecated in recent PyTorch releases and False was its default anyway.
  scheduler = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1, end_factor=0.1, total_iters=sched_total_iters)
  return attn_model, opt, scheduler

### Logger

In [None]:
class Logger:
  """
  Supporting class, to be expanded.
  Intended to store logging utils and relevant data:
  snapshots of gradients / weights, a timer dict and helpers working on lists of tensors.
  """
  def __init__(
      self,
  ):
    self.grad_a = None
    self.grad_c = None
    self.weight_a = None
    self.weight_c = None
    self.log_gamma = torch.tensor(0.95).to(device) # smoothing factor of running_mean
    self.between_logs = hp_front.between_logs
    self.timer = {}

  @torch.no_grad()
  def link_models(self,):
    """
    Stores detached views of the gradients and weights of the actor and the critic.
    self.actor and self.critic are not created by __init__: they have to be attached
    to the logger beforehand, and every trainable parameter must already have a gradient.
    """
    self.grad_a = [p.grad.detach() for p in self.actor.parameters() if p.requires_grad]
    self.weight_a = [p.detach() for p in self.actor.parameters()]
    self.grad_c = [p.grad.detach() for p in self.critic.parameters() if p.requires_grad]
    self.weight_c = [p.detach() for p in self.critic.parameters()]

  @torch.no_grad()
  def list_norm(self, l, p=2):
    """
    p-norm of a list of tensors, as if they were flattened into one vector.
    Returns a python float; an empty list has norm 0.0.
    """
    total = 0.0
    for t in l:
      total = total + t.detach().norm(p) ** p
    # float() covers both the tensor accumulator and the plain 0.0 left by an empty list
    return float(total) ** (1/p)

  @torch.no_grad()
  def list_cos_dist(self, a, b):
    """
    Cosine between two lists of tensors viewed as flattened vectors.
    Raises ZeroDivisionError when either list has zero norm.
    """
    a_norm = self.list_norm(a, 2)
    b_norm = self.list_norm(b, 2)
    product = sum(torch.dot(torch.flatten(t), torch.flatten(b[i])).item() for i, t in enumerate(a))
    return product/(a_norm*b_norm)

  @torch.no_grad()
  def on_list(self, a, b, operation):
    """Applies a binary operation elementwise to two lists of equal length."""
    assert len(a) == len(b), 'lists lengths differ'
    return [operation(x, y) for x, y in zip(a, b)]

  @torch.no_grad()
  def running_mean(self, a, b):
    """Exponential moving average step: log_gamma * a + (1 - log_gamma) * b, elementwise over the lists."""
    return [t.mul(self.log_gamma) + b[i].mul(1 - self.log_gamma) for i, t in enumerate(a)]


# Single module-level Logger instance; Trajectories reads and writes its timer dict.
logger = Logger()

### Data

In [None]:
class Trajectories:
  """
  Contains all kinds of data in a form of tensors.
  realized are raw and derived features of visited states.
  queues is a (3-dim) list of actions features by state.
  Action and state embeddings only differ in RNN part, fyi.
  """
  def __init__(self):
    # Predicate splitting trajectories into train (True) and validation (False).
    # NOTE: eval() executes arbitrary code -- train_condition_str must come from a trusted config.
    self.train_condition = eval(hp_alg.train_condition_str)
    self.td_gamma = hp_alg.td_gamma
    # parsed dataset json and the feature names found in it (filled by store_from_json)
    self.j_file = None
    self.feature_names = None
    self.feature_names2ids = None
    # realized: dict of train tensors f, f_n, r, R, is_last, queue_len, chosen_acts, tr_id
    self.realized = None
    # queues: list of queue tensors; prev_probs: list of previous probabilities tensors
    self.queues = None
    self.prev_probs = None
    # Psi is GAE critic estimates computed by trainer.GAE()
    self.Psi = None
    # ids / queue lengths of pre-drawn batches that were not handed out yet (see sample_ids)
    self.sampled_ids_list = []
    self.sampled_queue_lengths_list = []


  def gather_n_store(self, actor_model=None, json_path=hp_front.json_path):
    """
    Plays, collects trajectories into json, then transforms json to tensors.
    actor_model: model to play with; None means the default heuristic plays.
    json_path: dataset file that is loaded afterwards
               (its default is read from hp_front once, when the class is defined).
    """
    self.update_data_on_path(actor_model, json_path)
    self.store_from_json(json_path)


  def evaluate_val_train(self, verbose=True,):
    """
    Evaluates the validation subset first, then the train subset.
    Returns the pair (validation log, train log) produced by evaluate_data.
    """
    is_validation = lambda tr: not self.train_condition(tr)
    log_val = self.evaluate_data(eval_condition=is_validation, wandb_prefix='val', verbose=verbose)
    log_train = self.evaluate_data(eval_condition=self.train_condition, wandb_prefix='train', verbose=verbose)
    return log_val, log_train


  def store_from_json(self, json_path=hp_front.json_path):
    """
    Extracts trajectories from json to a torch-friendly form.
    Fills j_file, feature_names, feature_names2ids, realized, queues and prev_probs,
    resets Psi and prints how long loading and conversion took.
    """
    time_before = time()
    # the context manager closes the dataset file as soon as it is parsed
    with open(json_path) as dataset_file:
      self.j_file = json.load(dataset_file)
    print('json loading time: ', time()-time_before)
    self.feature_names = self.j_file['stateScheme'][0]
    self.feature_names2ids = {name: i for i, name in enumerate(self.feature_names)}
    time_before = time()
    self.realized, self.queues, self.prev_probs = self.j2torch(self.j_file)
    print('j2torch time: ', time()-time_before)
    # estimates computed for the previous dataset no longer apply
    self.Psi = None


  def mean_code_length(self):
    """
    Mean value of the code-length entry per trajectory, separately for the train and the validation split.
    Returns (train mean, validation mean); an empty split yields 0.0 thanks to the 1e-9 in the denominator.
    """
    train_lines, val_lines = 0, 0
    n_train_tr, n_val_tr = 0, 0
    # Trajectories are counted in the same pass instead of going through get_properties(),
    # which was called twice and needs the queue tensors only to deliver these two counts.
    for tr in self.j_file['paths']:
      if self.train_condition(tr):
        train_lines += tr[hp_front.tr_code_len_idx]
        n_train_tr += 1
      else:
        val_lines += tr[hp_front.tr_code_len_idx]
        n_val_tr += 1
    return train_lines/(n_train_tr+1e-9), val_lines/(n_val_tr+1e-9)


  def j2torch(self, j_file):
    """
    Transforms json to data tensors.
    Queues are flipped for reasons related to batch truncation and padding. Actions numeration is changed accordingly.
    Only trajectories accepted by self.train_condition are converted.

    j_file: parsed dataset; its 'paths' entry is iterated as the list of trajectories.
    Returns (realized, queues, prev_probs):
      realized   -- dict of per-state tensors: 'features', 'features_next', 'rewards', 'Returns',
                    'is_last', 'queue_lengths', 'chosen_actions';
      queues     -- list with one flipped queue tensor per state;
      prev_probs -- list with one flipped probabilities tensor per state, None in critic_play mode.
    Side effect: logger.timer is replaced by a fresh dict of conversion timers.
    """
    features, features_next, rewards, Returns, is_last, queue_lengths, chosen_actions, queues = [], [], [], [], [], [], [], []
    if not critic_play:
      prev_probs = []
    logger.timer = {
      'tr_prev_probs': 0,
      'tr_features':0,
      'tr_queues':0,
    }
    # The j_file argument is used throughout (it used to be ignored in favour of self.j_file).
    for tr in j_file['paths']:
      if not self.train_condition(tr):
        continue
      if not critic_play:
        # shifted by one step; the last state gets a dummy one-element distribution
        tr_prev_probs = [torch.Tensor(dstrn).flip(dims=[0]).to(device) for dstrn in tr[hp_front.tr_prev_probs_idx]][1:] + [torch.tensor([1.0]).to(device)]
        prev_probs += tr_prev_probs
      tr = tr[hp_front.tr_path_idx]
      if hp_alg.critic_play: # use features as an action emb, reward is what we get for that action
        tr_rewards = [tr[i][hp_front.state_reward_idx] for i in range(len(tr))]
      else: # use features as a state emb, reward is what we get the next step
        tr_rewards = [tr[i][hp_front.state_reward_idx] for i in range(len(tr))][1:] + [0]
      rewards += tr_rewards
      do_discount = torch.ones(len(tr))
      if use_fork_discount:
        # NOTE(review): this looks 'is_cfg_fork' up in the outer stateScheme list and then indexes the state record
        # with it, while feature names live in stateScheme[0] -- confirm before enabling use_fork_discount.
        is_cfg_fork_idx = j_file['stateScheme'].index('is_cfg_fork')
        do_discount = [tr[i][is_cfg_fork_idx] for i in range(len(tr))]
      tr_Returns = self.tr_rewards_to_returns(tr_rewards, do_discount)
      Returns += tr_Returns

      # shifted by one step as well; the last state gets action 0
      tr_chosen_actions = [tr[i][hp_front.state_chosenId_idx] for i in range(len(tr))][1:] + [0]
      chosen_actions += tr_chosen_actions

      time0=time()
      # features of a state = features of the action chosen in it
      tr_features = [tr[i][hp_front.state_queue_idx][tr[i][hp_front.state_chosenId_idx]] for i in range(len(tr))]
      is_last += [0]*(len(tr_features)-1) + [1]
      features += tr_features
      features_next += tr_features[1:] + [[-1]*hp_front.features_dim] # -1 filler after the last state
      logger.timer['tr_features'] += time() - time0

      time0=time()
      tr_queues = [torch.Tensor(tr[i][hp_front.state_queue_idx]).flip(dims=[0]).to(device) for i in range(len(tr))][1:]
      # the last state gets a dummy queue holding a single all-zero action
      tr_queues += [torch.zeros_like(torch.Tensor([tr[0][hp_front.state_queue_idx][0]])).to(device)]
      logger.timer['tr_queues'] += time()-time0

      tr_queue_lengths = [len(q) for q in tr_queues]
      queues += tr_queues
      queue_lengths += tr_queue_lengths
    rewards = torch.Tensor(rewards).to(device)
    features = torch.Tensor(features).to(device)
    features_next = torch.Tensor(features_next).to(device)
    Returns = torch.Tensor(Returns).to(device)
    is_last = torch.Tensor(is_last).to(device)
    queue_lengths = torch.LongTensor(queue_lengths).to(device)
    chosen_actions = queue_lengths - torch.LongTensor(chosen_actions).to(device) - 1 # we flip queue and numeration of actions
    # print(logger.timer)
    realized = {
      'features': features,
      'features_next': features_next,
      'rewards': rewards,
      'Returns': Returns,
      'is_last': is_last,
      'queue_lengths': queue_lengths,
      'chosen_actions': chosen_actions,
    }
    if critic_play:
      prev_probs = None
    return realized, queues, prev_probs


  def get_n_train_states(self):
    """Number of stored train states, i.e. the length of realized['features']."""
    return len(self.realized['features'])


  def get_properties(self):
    """
    Summary statistics of the loaded dataset as a dict:
    trajectory lengths (over all trajectories), the longest stored queue and its index,
    the number of train states, of trajectories and of validation trajectories.
    """
    paths = self.j_file['paths']
    queue_sizes = np.array([q.shape[0] for q in self.queues])
    longest = np.argmax(queue_sizes)
    tr_lengths = np.array([len(tr[hp_front.tr_path_idx]) for tr in paths])
    return {
      'traj_length mean, median, max': (f'{tr_lengths.mean():.2f}', np.median(tr_lengths), tr_lengths.max()),
      'queue max length, idx': (self.queues[longest].shape[0], longest),
      'number of train states': self.get_n_train_states(),
      'number of traj-s': len(paths),
      'number of validation traj-s': sum(not self.train_condition(tr) for tr in paths),
    }


  def tr_rewards_to_returns(self, tr_rewards, do_discount=None):
    """
    Backward recursion R[i] = r[i] + td_gamma**do_discount[i] * R[i+1] along one trajectory.
    The entry of the last step stays 0: the recursion starts from the step before it.
    do_discount: per-step exponents of the discount, all ones when omitted.
    """
    n_steps = len(tr_rewards)
    if do_discount is None:
      do_discount = [1]*n_steps
    tr_R = [0]*n_steps
    for i in reversed(range(n_steps - 1)):
      tr_R[i] = tr_rewards[i] + (self.td_gamma**do_discount[i]) * tr_R[i+1]
    return tr_R


  def sample_ids(self, batch_size=batch_size):
    """
    Assembles ids for several consecutive batches based on queue lengths.
    When no pre-drawn batches are left, hp_alg.batch_size*hp_alg.batch_accumulation_steps state ids
    are drawn at random, sorted by queue length and split into chunks of batch_size.
    Returns the ids and the queue lengths of the next chunk.
    """
    if not self.sampled_ids_list:
      n_draws = hp_alg.batch_size*hp_alg.batch_accumulation_steps
      drawn = torch.LongTensor(random.choice(self.get_n_train_states(), size=n_draws)).to(device)
      lengths_sorted, order = torch.sort(self.realized['queue_lengths'][drawn])
      self.sampled_ids_list = list(torch.split(drawn[order], batch_size))
      self.sampled_queue_lengths_list = list(torch.split(lengths_sorted, batch_size))
    next_ids = self.sampled_ids_list.pop(0)
    next_lengths = self.sampled_queue_lengths_list.pop(0)
    return next_ids, next_lengths


  def sample_batch(self, batch_size=batch_size):
    """
    Draws the next batch of states and packs their queues into padded tensors.
    Returns (sampled_realized, Psi, queues_tensor, pad_mask, prev_probs_tensor):
      sampled_realized  -- the rows of every tensor in self.realized for the sampled ids;
      Psi               -- self.Psi[ids], or None when self.Psi is not set;
      queues_tensor     -- zero-padded queues, one row per sampled state;
      pad_mask          -- bool tensor, True on padded positions;
      prev_probs_tensor -- zero-padded previous probabilities.
    NOTE(review): the padded tensors are allocated with hp_alg.batch_size rows, not with the batch_size argument --
    confirm the two are always equal.
    NOTE(review): self.prev_probs is indexed unconditionally; j2torch returns None for it in critic_play mode.
    """
    ids, queue_lengths = self.sample_ids(batch_size=batch_size)
    sampled_realized = {k: self.realized[k][ids] for k in self.realized.keys()}
    # pick the per-state python-list entries that belong to the sampled ids
    sampled_queues = itemgetter(*list(ids))(self.queues)
    sampled_prev_probs = itemgetter(*list(ids))(self.prev_probs)
    # pad to the longest queue of the batch, but never beyond max_nactions
    padded_length = torch.minimum(queue_lengths.max(), torch.tensor(hp_alg.max_nactions))
    queues_tensor = torch.zeros(hp_alg.batch_size, padded_length, hp_front.features_dim).to(device)
    prev_probs_tensor = torch.zeros(hp_alg.batch_size, padded_length).to(device)
    pad_mask = torch.ones(hp_alg.batch_size, padded_length).to(device)

    # time the packing loops with CUDA events
    # NOTE(review): presumably requires a CUDA device even when `device` is 'cpu' -- confirm
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()

    for i, q in enumerate(sampled_queues):
      # queues longer than padded_length are truncated to their first padded_length entries
      l = min(q.shape[0], padded_length)
      queues_tensor[i, 0 : l, :] = q[: l, :]
      pad_mask[i, 0 : l] = 0
    pad_mask = pad_mask.bool()

    for i, p in enumerate(sampled_prev_probs):
      l = min(p.shape[0], padded_length)
      prev_probs_tensor[i, 0 : l] = p[: l]

    end.record()
    maybe_sync()
    # the 'sample inside loop' key is not created here: it has to exist in logger.timer already
    logger.timer['sample inside loop'] += start.elapsed_time(end)/1000
    if not self.Psi is None:
      Psi = self.Psi[ids]
    else:
      Psi = None
    return sampled_realized, Psi, queues_tensor, pad_mask, prev_probs_tensor


  def sample_realized_batch(self, n=hp_alg.batch_size):
    '''
    A lightened version of sample_batch, can be used in critic_play (Policy Iteration) mode.
    Draws n random state ids and returns the matching rows of every tensor in self.realized.
    '''
    drawn = random.choice(self.get_n_train_states(), size=n)
    ids = torch.tensor(drawn).long()
    return {name: column[ids] for name, column in self.realized.items()}


  def update_data_on_path(self, actor_model=None, path=hp_front.json_path):
    """
    Communication with jar file on a server.
    Without a model the stored actor file is removed and the default heuristic plays;
    with a model, the model is exported to ONNX first. Then the jar command is run and timed.

    actor_model: torch model to export, or None.
    path: not used by this method; kept for the callers that pass it.
    """
    if actor_model is None:
      player_name = 'Heuristic' # using default heuristics
      os.system('rm -f ' + hp_front.actor_model_path)
    else:
      player_name = 'NN'
      # models exposing a softmax attribute (the attention actor) take a [batch, n_actions, features] input
      if hasattr(actor_model, 'sfmax'):
        shape = [1, 1, hp_front.features_dim]
      else:
        shape = [1, hp_front.features_dim]
      x = torch.randn(*shape, requires_grad=True).to(device)
      torch_model = actor_model.eval()
      # dry run before the export; it also initializes any layer that is still lazy (e.g. nn.LazyLinear)
      torch_model(x)
      # Export to the configured path, i.e. the same file that the heuristic branch removes.
      torch.onnx.export(torch_model,
                        x,
                        hp_front.actor_model_path,
                        opset_version=15,
                        export_params=True,
                        input_names = ['input'],   # the model's input names
                        output_names = ['output'],
                        dynamic_axes={'input' : {0 : 'batch_size',
                                                 1 : 'n_actions',
                                                },    # variable length axes
                                      'output' : {0 : 'batch_size',
                                                  1 : 'n_actions',
                                                },
                                     },
                        )
    time_before = time()
    os.system(hp_front.jar_command)
    print(player_name, ' data gathering time: ', time() - time_before)


  def evaluate_data(self,
           eval_condition = None,
           wandb_prefix = 'val',
           verbose=True,
           factors=torch.Tensor([1, hp_alg.td_gamma]),
           ):
    """
    Computes discounted Returns of whole trajectories for every discount factor in `factors`.

    eval_condition: predicate selecting the trajectories to evaluate;
                    None selects the validation ones (those rejected by self.train_condition).
    wandb_prefix: prefix of the logged keys.
    verbose: when True, the statistics of every factor are sent to wandb.
    factors: discount factors to evaluate with.
    Returns the log dict of the LAST factor only, extended with the raw 'Returns'
    and the trajectory lengths (an empty dict when `factors` is empty).
    """
    if eval_condition is None:
      # A default lambda in the signature cannot see `self` (it raised NameError when called),
      # so the default predicate is built here.
      eval_condition = lambda tr: not self.train_condition(tr)
    # does not depend on the discount factor, so it is computed once
    code_length_train, code_length_val = self.mean_code_length()
    log = {}
    for f in factors:
      size = 0
      tr_lengths = []
      trs_R = []
      for tr in self.j_file['paths']:
        if not eval_condition(tr):
          continue
        tr = tr[hp_front.tr_path_idx] # same index as j2torch uses, instead of a hard-coded 1
        size += len(tr)
        tr_lengths += [len(tr)]
        tr_rewards = [tr[i][state_reward_idx] for i in range(len(tr))] # rewards list is not shifted, has a different purpose this time
        do_discount = torch.Tensor([1]*len(tr))
        if use_fork_discount:
          # NOTE(review): same lookup as in j2torch -- 'is_cfg_fork' is searched in the outer stateScheme list; confirm.
          is_cfg_fork_idx = self.j_file['stateScheme'].index('is_cfg_fork')
          do_discount = [tr[i][is_cfg_fork_idx] for i in range(len(tr))]
        # fold the rewards from the last step backwards
        tr_return = 0
        for i in range(len(tr_rewards)-1, -1, -1):
          tr_return = tr_rewards[i] + (f ** do_discount[i]) * tr_return
        trs_R += [tr_return]
      log = {}
      log[f'{wandb_prefix} size'] = size
      log[f'{wandb_prefix}_eval/mean {f:.3f} discount '] = torch.Tensor(trs_R).mean()
      log[f'{wandb_prefix}_eval/median {f:.3f} discount '] = torch.Tensor(trs_R).median()
      log[f'{wandb_prefix} Return by trjs {f:.3f} (previous epoch)'] = wandb.Histogram(np_histogram=np.histogram(trs_R, bins=30, ))
      log['code length train mean'], log['code length val mean'] = code_length_train, code_length_val
      if verbose:
          wandb.log(log.copy())
      log['Returns'] = trs_R
      log[f'{wandb_prefix} lengths'] = tr_lengths
    return log

### Trainer


In [None]:
class NN_Trainer:
  """
  Trains an actor with the clipped PPO objective and a critic with a TD + Monte-Carlo loss
  on batches sampled from a `trajectories` dataset. Metrics are accumulated in self.log
  and sent to wandb.
  """
  def __init__(
      self,
      NN_setup,
      trajectories,
      train_progress,
      clip_eps = None,
      batch_size=hp_alg.batch_size,
      n_batches=1000,
      target_update_steps=15,
      td_gamma=hp_alg.td_gamma,
      ):
    """
    NN_setup: dict with keys 'actor', 'actor_opt', 'actor_sched', 'critic', 'critic_opt'.
    trajectories: dataset; this class uses its sample_batch, sample_realized_batch, realized and Psi.
    train_progress: number used to decay the entropy bonus (callers pass epoch / epochs).
    clip_eps: PPO ratio clipping range; must be set before get_each_loss / learn_new_policy are used.
    batch_size: batch size for learn_new_policy.
    n_batches: default number of batches per learning cycle.
    target_update_steps: the target critic is re-copied from the critic every this many batches.
    td_gamma: discount factor of the TD target.
    """
    self.n_batches = n_batches
    self.batch_number = -1
    self.td_gamma = td_gamma
    self.gae_gamma = 0.7
    self.clip_eps = clip_eps
    self.actor = NN_setup['actor']
    self.actor_opt = NN_setup['actor_opt']
    self.actor_sched = NN_setup['actor_sched']
    self.prev_actor = copy.deepcopy(self.actor).eval()
    self.critic = NN_setup['critic'].train()
    self.target_critic = copy.deepcopy(self.critic).eval()
    self.critic_opt = NN_setup['critic_opt']
    self.trajectories = trajectories
    self.batch_size = batch_size
    self.target_update_steps = target_update_steps
    self.train_progress = train_progress
    self.log = {}

  @staticmethod
  def _start_cuda_timer():
    """
    Creates a pair of CUDA timing events and records the first one.
    """
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    return start, end

  @staticmethod
  def _stop_cuda_timer(events, timer_key):
    """
    Records the closing event and adds the elapsed seconds to logger.timer[timer_key].
    """
    start, end = events
    end.record()
    maybe_sync()
    logger.timer[timer_key] += start.elapsed_time(end)/1000

  def _refresh_target_critic(self):
    """
    Replaces the target critic with a frozen copy of the current critic
    once every target_update_steps batches.
    """
    if self.batch_number % self.target_update_steps == 0:
      self.target_critic = copy.deepcopy(self.critic).eval()

  def get_each_loss(self, sampled_batch,):
    """
    Computes losses for actor, critic and exploration (loss_ent) within PPO algorithm.
    Decisions were made to avoid python loops --
    varying action space is not particularly batch-friendly.

    sampled_batch: tuple (sampled_realized, Psi, queues_tensor, pad_mask, prev_probs_tensor).
    Returns (loss_a, loss_c, loss_ent).
    """
    sampled_realized, Psi, queues_tensor, pad_mask, prev_probs_tensor = sampled_batch
    not_last = 1 - sampled_realized['is_last']
    self.log['Returns mean'] = sampled_realized['Returns'].mean().item()
    self.log['rewards mean'] = sampled_realized['rewards'].mean().item()
    self.log['queue length max'] = sampled_realized['queue_lengths'].max().item()
    probs = self.actor(queues_tensor, mask=pad_mask)
    max_probs = probs.max(dim=-1).values
    self.log['max prob mean'] = max_probs.mean().item()
    self.log['40 max prob quantile'] = torch.quantile(max_probs, 0.4).item()

    timer = self._start_cuda_timer()
    values = self.critic(sampled_realized['features']).squeeze(-1)
    with torch.no_grad():
      next_values = self.target_critic(sampled_realized['features_next']).squeeze(-1)
    self._stop_cuda_timer(timer, 'values')

    self.log['V-func mean'] = torch.mean(values.detach()).item()
    self.log['V-func stdev'] = torch.std(values.detach()).item()

    # critic loss: squared TD error plus a down-weighted absolute error against Monte-Carlo returns
    TD = values - ((sampled_realized['rewards'] + next_values.detach() * self.td_gamma) * not_last)
    MC = (values - sampled_realized['Returns']).abs().mean()/100
    loss_c = (TD**2).mean() + MC

    self.log['TD loss'] = (TD**2).mean().item()
    self.log['MC loss'] = MC.item()

    # entropy loss: per-state entropy divided by log(number of available actions + 1)
    t_entropy_loss = time()

    entropies = - probs * torch.log(torch.max(torch.tensor(1e-40), probs))
    entropy_regularizer = torch.log(torch.minimum(sampled_realized['queue_lengths']+1, torch.tensor(hp_alg.max_nactions+1))).to(device)
    entropy_by_state_reg = torch.sum(entropies, dim=-1) / entropy_regularizer
    loss_ent = -entropy_by_state_reg.mean()

    logger.timer['entropy loss'] += time()-t_entropy_loss

    # actor loss
    timer = self._start_cuda_timer()

    chosen_ids = get_ass_dope_ids(sampled_realized['chosen_actions'])
    probs_chosen = probs[chosen_ids].squeeze(-1)
    prev_probs_chosen = prev_probs_tensor[chosen_ids].squeeze(-1)
    ratios = (probs_chosen / (prev_probs_chosen.detach()+1e-4)).to(device)

    clipped = torch.clip(ratios, min=1-self.clip_eps, max=1+self.clip_eps)
    Adv = - (TD * not_last).detach() # to not affect loss_a logs
    # mean over non-terminal states only; stored as a python number like every other metric
    self.log['Adv mean'] = (Adv.sum() / (Adv.numel() - sampled_realized['is_last'].sum())).item()
    # precomputed Psi (see compute_GAE) replaces the one-step advantage when available
    weights = Adv if self.trajectories.Psi is None else Psi
    loss_a = - torch.min(ratios*weights, clipped*weights).mean()
    self.log['clipped to all'] = (clipped*Adv < ratios*Adv).long().sum().item() / ratios.numel()
    self.log['ratios mean'] = ratios.mean().item()
    self._stop_cuda_timer(timer, 'actor loss')
    return loss_a, loss_c, loss_ent


  def critic_loss(self,
                  sampled_realized,
    ):
    """
    Temporal difference loss for PPO
    (the whole TD target, reward included, is zeroed on terminal states).
    """
    values = self.critic(sampled_realized['features']).squeeze(-1)
    with torch.no_grad():
      next_values = self.target_critic(sampled_realized['features_next']).squeeze(-1)
    self.log['V-func mean'] = torch.mean(values.detach()).item()
    TD = values - ((sampled_realized['rewards'] + next_values * self.td_gamma) * (1-sampled_realized['is_last']))
    MC = (values - sampled_realized['Returns']).abs().mean()/100
    critic_loss = (TD**2).mean() + MC
    self.log['TD loss'] = (TD**2).mean().item()
    self.log['MC loss'] = MC.item()
    return critic_loss


  def q_loss(self,
                  sampled_realized,
    ):
    """
    Temporal difference loss for PI
    (only the bootstrap term is zeroed on terminal states, the reward is kept).
    """
    q = self.critic(sampled_realized['features']).squeeze(-1)
    with torch.no_grad():
      next_q = self.target_critic(sampled_realized['features_next']).squeeze(-1)
    self.log['Q-func mean'] = torch.mean(q.detach()).item()
    TD = q - (sampled_realized['rewards'] + next_q * self.td_gamma * (1-sampled_realized['is_last']))
    MC = (q - sampled_realized['Returns']).abs().mean()/100
    q_loss = (TD**2).mean() + MC
    self.log['TD loss'] = (TD**2).mean().item()
    self.log['MC loss'] = MC.item()
    return q_loss


  def _fit_critic(self, loss_fn, n_batches=None):
    """
    Critic-only training loop shared by learn_q and learn_v:
    minimizes loss_fn(sampled_realized) over n_batches batches of 2048 samples.
    """
    if n_batches is None:
      n_batches = self.n_batches
    for i in trange(n_batches):
      self.batch_number = i
      self._refresh_target_critic()
      sampled_realized = self.trajectories.sample_realized_batch(n=2048)
      loss = loss_fn(sampled_realized)
      self.critic_opt.zero_grad()
      loss.backward()
      torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 30)
      self.critic_opt.step()
      if self.batch_number % (logger.between_logs+1) == 0:
        critic_params = [p for p in self.critic.parameters() if p.requires_grad]
        self.log['weight critic'] = logger.list_norm(critic_params, 2)
        self.log['grad critic'] = logger.list_norm([p.grad for p in critic_params], 2)
        wandb.log(self.log)


  def learn_q(self, n_batches=None):
    """
    Approximates Q-function of previous policy by iterating over collected data.
    critic_play mode required (check realized['rewards'] lag in PPO setup).
    """
    self._fit_critic(self.q_loss, n_batches)


  def learn_v(self, n_batches=None):
    """
    Approximates V-function of previous policy by iterating over collected data,
    using the PPO temporal difference loss (critic_loss).
    n_batches: number of batches; defaults to self.n_batches.
    """
    self._fit_critic(self.critic_loss, n_batches)


  def learn_new_policy(self, prev_actor=None):
    """
    Implements one learning cycle over collected dataset.

    prev_actor: policy to store as self.prev_actor; a frozen copy of the current actor if None.
    """
    if prev_actor is None:
      self.prev_actor = copy.deepcopy(self.actor).eval()
    else:
      self.prev_actor = prev_actor
    logger.timer = dict.fromkeys(['values',
                                  'entropy loss',
                                  'actor loss',
                                  'total loss',
                                  'optimizers step',
                                  'loss.backward',
                                  'sample batch',
                                  'sample ids',
                                  'sample inside loop',
                                  ], 0)
    # Gradients may be left over from critic pre-training (its loop zeroes them only
    # before backward); clear them so they do not leak into the first update.
    self.actor_opt.zero_grad()
    self.critic_opt.zero_grad()
    # NOTE(review): batch_accumulation_steps is read as a bare global -- confirm it is defined at module level
    for i in trange(self.n_batches):
      self.batch_number = i
      self._refresh_target_critic()

      timer = self._start_cuda_timer()
      sampled_batch = self.trajectories.sample_batch(self.batch_size)
      self._stop_cuda_timer(timer, 'sample batch')

      timer = self._start_cuda_timer()
      loss_a, loss_c, loss_ent = self.get_each_loss(sampled_batch)
      # entropy bonus decays with training progress
      weighted_loss_ent = loss_ent/(10 + self.train_progress*20)
      loss = (loss_a + loss_c + weighted_loss_ent) / batch_accumulation_steps
      self._stop_cuda_timer(timer, 'total loss')

      timer = self._start_cuda_timer()
      loss.backward()
      self._stop_cuda_timer(timer, 'loss.backward')

      # step once a full group of batch_accumulation_steps batches has been accumulated
      # (was `batch_number % steps == 0`, which stepped right after the very first batch)
      if (self.batch_number + 1) % batch_accumulation_steps == 0:
        t_optimizers_step = time()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 30)
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 5)
        self.actor_opt.step()
        self.critic_opt.step()
        # norms are logged after clipping
        self.log.update({
            'grad actor': logger.list_norm([p.grad for p in self.actor.parameters() if p.requires_grad], 2),
            'grad critic': logger.list_norm([p.grad for p in self.critic.parameters() if p.requires_grad], 2),
        })
        self.critic_opt.zero_grad()
        self.actor_opt.zero_grad()
        logger.timer['optimizers step'] += time()-t_optimizers_step
      if self.batch_number % (logger.between_logs+1) == 0:
        self.log.update({
            'loss actor': loss_a.item(),
            # the entropy term exactly as it enters the total loss
            # (was logged with a different coefficient, 5 + 10*progress)
            'loss entropy': weighted_loss_ent.item(),
            'loss critic': loss_c.item(),
            '-entropy by log(n)': loss_ent.item(),
            'weight actor': logger.list_norm([p for p in self.actor.parameters() if p.requires_grad], 2),
            'weight critic': logger.list_norm([p for p in self.critic.parameters() if p.requires_grad], 2),
        })
        wandb.log(self.log)
        self.log = {}
    self.actor_sched.step()


  @torch.no_grad()
  def compute_GAE(self, critic_to_use=None):
    """
    Computes GAE with passed/current critic.
    GAEs could be used instead of Advantages in PPO actor loss.
    The result is stored in self.trajectories.Psi.
    """
    if critic_to_use is None:
      critic_to_use = copy.deepcopy(self.critic).eval()
    realized = self.trajectories.realized
    features = realized['features']
    features_next = realized['features_next']
    rewards = realized['rewards']
    is_last = realized['is_last']
    # one-step advantages, evaluated in chunks of 4096 states
    Adv1 = []
    for batch_ids in torch.arange(len(is_last)).split(4096):
      not_last = 1 - is_last[batch_ids]
      values = critic_to_use(features[batch_ids]).squeeze(-1)
      values_next = critic_to_use(features_next[batch_ids]).squeeze(-1)
      TD = values - ((rewards[batch_ids] + values_next * self.td_gamma) * not_last)
      Adv1.append(- TD * not_last)
    Adv1 = torch.cat(Adv1, dim=0)
    Psi = torch.zeros_like(Adv1)

    # backward recursion; terminal states keep Psi = 0, which also resets the sum between trajectories
    n_states = len(Adv1)
    for i in range(n_states-1, -1, -1):
      if is_last[i]:
        continue
      # nothing to bootstrap from if the very last stored state is not terminal
      bootstrap = Psi[i+1] if i+1 < n_states else 0.
      Psi[i] = bootstrap * self.gae_gamma * self.td_gamma + Adv1[i]
    self.trajectories.Psi = Psi
    wandb.log({'Psi mean': Psi.mean().item(),
               'Psi 95 perc': torch.quantile(Psi, 0.95).item(),
               'Adv mean': Adv1.mean().item(),
              })

In [None]:
class RNN_Trainer():
  """
  Learns state history representation via a proxy task of V-function approximation.
  """
  def __init__(self, v_cell, optimizer, trajectories, loss_horizon=10, rnn_batch_size=32):
    """
    v_cell: recurrent cell; called as cell(input, state), must expose num_layers and inner_hidden_size.
    optimizer: optimizer over the cell's parameters.
    trajectories: dataset; this class uses its realized, j_file and queues.
    loss_horizon: number of consecutive states the cell is unrolled over per sample.
    rnn_batch_size: accepted but not stored; train() takes its own rnn_batch_size.
    """
    self.rnn_cell = v_cell
    self.optimizer = optimizer
    self.loss_horizon = loss_horizon
    self.retain_graph = True
    self.trjs = trajectories
    self.log = {}

  def sample_ids_pairs(self, features_by_trjs, rnn_batch_size, n_steps):
    """
    Returns [n_steps] list with [rnn_batch_size, 2] tensors
    of pairs (traj idx, position in traj for rnn to start from).
    Trajectories are drawn with probability proportional to log(length - loss_horizon).
    """
    n_trjs = len(features_by_trjs)
    residual_lengths = torch.tensor([len(f)-self.loss_horizon for f in features_by_trjs])
    log_lengths = torch.log(residual_lengths.double())
    sampling_weights = (log_lengths / log_lengths.sum()).numpy()
    trjs_ids = torch.tensor(np.random.choice(np.arange(n_trjs), size=(n_steps, rnn_batch_size), p=sampling_weights))
    position_samples = torch.rand(n_steps, rnn_batch_size)
    residual_lengths_forall = torch.take(residual_lengths, trjs_ids)
    # uniform start position inside [0, length - loss_horizon)
    start_ids = (residual_lengths_forall*position_samples).long()
    stacked = torch.stack([trjs_ids, start_ids], dim=-1).to(device)
    return list(stacked)

  # NOTE(review): the default below eval()s a string from the config file at definition time;
  # keep that file trusted.
  def train(self, n_steps=1000, rnn_batch_size=32, train_condition=eval(hp_alg.train_condition_str)):
    """
    Runs n_steps optimizer steps. Each step unrolls the cell over loss_horizon consecutive states
    of rnn_batch_size sampled trajectory windows and minimizes a weighted absolute error
    between predicted values and stored Returns.

    train_condition: predicate over an entry of self.trjs.j_file['paths'];
      trajectories it rejects are excluded.
    """
    # .flatten() instead of .squeeze(): the index tensors stay 1-D even when exactly one
    # index matches (a 0-dim tensor breaks [:-1], len() and list concatenation)
    is_last_ids = torch.nonzero(self.trjs.realized['is_last']).flatten()
    all_train_trjs_lengths = is_last_ids - torch.cat([torch.tensor([-1]).to(device), is_last_ids[:-1]])
    not_long_enough_ids = (all_train_trjs_lengths < self.loss_horizon + 3).nonzero().flatten()
    # NOTE(review): assumes entries of j_file['paths'] are aligned one-to-one with the trajectories in realized -- confirm
    donotlook_ids = torch.tensor([not train_condition(tr) for tr in self.trjs.j_file['paths']]).nonzero().flatten()
    split_sizes = all_train_trjs_lengths.tolist()
    features_by_trjs = list(torch.split(self.trjs.realized['features'], split_sizes))
    Returns_by_trjs = list(torch.split(self.trjs.realized['Returns'], split_sizes))
    # we train only on long enough trajectories; pop from the end so earlier indices stay valid
    excluded = set(donotlook_ids.tolist()) | set(not_long_enough_ids.tolist())
    for i in sorted(excluded, reverse=True):
      features_by_trjs.pop(i)
      Returns_by_trjs.pop(i)
    print(len(not_long_enough_ids), len(Returns_by_trjs))
    sampled_ids = self.sample_ids_pairs(features_by_trjs, rnn_batch_size, n_steps)
    # per-step loss weights grow with the position in the window (the first step gets weight 0)
    raw_weights = torch.arange(self.loss_horizon)**1.4
    loss_weights = (raw_weights / raw_weights.sum())[:, None].to(device)
    for step in trange(n_steps):
      # NOTE(review): features_dim is read as a bare global -- confirm it is defined at module level
      inputs = torch.zeros(rnn_batch_size, self.loss_horizon, features_dim).to(device)
      targets = torch.zeros(rnn_batch_size, self.loss_horizon, 1).to(device)
      for b, (trj_id, start) in enumerate(sampled_ids[step].tolist()):
        window = slice(start, start + self.loss_horizon)
        inputs[b] = features_by_trjs[trj_id][window]
        targets[b] = Returns_by_trjs[trj_id][window][:, None]
      state = torch.zeros(2*self.rnn_cell.num_layers, rnn_batch_size, self.rnn_cell.inner_hidden_size).to(device)
      Vs = []
      for rnn_step in range(self.loss_horizon):
        v, state, _ = self.rnn_cell(inputs[:, rnn_step, :], state)
        Vs.append(v)
      abs_difs = (torch.stack(Vs, dim=1) - targets).abs()
      assert abs_difs.shape == (rnn_batch_size, self.loss_horizon, 1)
      loss = (abs_difs * loss_weights).mean()
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()
      if step % logger.between_logs == 0:
        all_Vs = torch.cat(Vs)
        trainable = [p for p in self.rnn_cell.parameters() if p.requires_grad]
        self.log['rnn MC abs loss'] = loss.item()
        self.log['rnn V mean'] = all_Vs.mean().item()
        self.log['RNN weight'] = logger.list_norm(trainable, 2)
        self.log['RNN grad'] = logger.list_norm([p.grad for p in trainable], 2)
        hist = wandb.Histogram(np_histogram=np.histogram(all_Vs.flatten().detach().to('cpu'), bins=40, ))
        self.log['V-func hist'] = hist
        hist = wandb.Histogram(np_histogram=np.histogram(abs_difs[:, -1,:].flatten().detach().to('cpu'), bins=40, ))
        self.log['abs difs -1'] = hist
        hist = wandb.Histogram(np_histogram=np.histogram(abs_difs[:, 0,:].flatten().detach().to('cpu'), bins=40, ))
        self.log['abs difs 0'] = hist
        wandb.log(self.log)

  @torch.no_grad()
  def collect_rnn_features(self):
    """
    Happened to be not relevant in the end.
    Collects rnn hidden and V for every state in train dataset
    and stores them in realized['rnn_features'] / realized['rnn_v'].
    """
    features = self.trjs.realized['features']
    is_last = self.trjs.realized['is_last']
    hid_list = []
    v_list = []
    init_state = torch.zeros(2*self.rnn_cell.num_layers, 1, self.rnn_cell.inner_hidden_size).to(device)
    state = init_state
    for i, f in enumerate(features):
      # train() unpacks three outputs from the cell; any output after the state is ignored here
      v, state, *_ = self.rnn_cell(f[None, :], state)
      hid = state[2].squeeze(0)
      hid_list.append(hid.detach())
      v_list.append(v.detach())
      if is_last[i]:
        # a new trajectory starts from a fresh hidden state
        state = init_state
    # the dataset attribute is self.trjs (self.traj was never assigned)
    self.trjs.realized['rnn_features'] = torch.stack(hid_list, dim=0)
    self.trjs.realized['rnn_v'] = torch.stack(v_list, dim=0)

  @torch.no_grad()
  def concat_rnn_features(self):
    """
    Happened to be not relevant in the end.
    Concatenates state history embedding to embedding of that state and to embeddings of corresponding actions.
    Modifies realized['features'] and self.trjs.queues in place.
    """
    realized = self.trjs.realized
    rnn_related = torch.cat([realized['rnn_features'], realized['rnn_v']], dim=-1)
    realized['features'] = torch.cat([realized['features'], rnn_related], dim=-1)
    queues = self.trjs.queues
    time_before = time()
    for i, queue in enumerate(queues):
      # the same history embedding is appended to every action of the i-th queue
      queues[i] = torch.cat([queue, rnn_related[i].repeat(len(queue), 1)], dim=-1)
    print('Concating rnn features to queues: ', time()-time_before)

### Procedures

In [None]:
# Policy iteration: the critic is fitted to the data of the previous policy,
# then the greedy policy derived from it (q_net) gathers the next dataset.
if hp_alg.critic_play:
  run = wandb.init(
          project="PS",
          name='PI',
          config={}
  )
  trajectories = Trajectories()

  actor, actor_opt, actor_sched = get_attn_setup() # not to be used
  critic, critic_opt = get_mlp_setup()
  q_net = Q_Net(V_function=critic)

  trajectories.gather_n_store() # default heuristic play (ForkDepthRandom) to initialize
  trajectories.evaluate_val_train()
  print(trajectories.get_properties())

  for epoch in range(hp_alg.epochs):
      trainer = NN_Trainer(NN_setup={'actor': actor, 'actor_opt': actor_opt, 'actor_sched': actor_sched,
                                     'critic': critic, 'critic_opt': critic_opt,},
                           trajectories=trajectories,
                           train_progress=epoch/hp_alg.epochs,
                           )
      # learn_q is the critic-fitting routine NN_Trainer documents for critic_play mode
      # NOTE(review): this cell used to call trainer.learn_v -- confirm learn_q is the intended loss here
      trainer.learn_q(n_batches=2000 if epoch==0 else 400)
      wandb.log({'epoch': epoch})
      trajectories.gather_n_store(actor_model=q_net)
      trajectories.evaluate_val_train()
      print(trajectories.get_properties())

  # the whole critic module is saved into the wandb run directory
  checkpoint = {
    'critic':critic,
  }
  torch.save(checkpoint, os.path.join(wandb.run.dir, 'critic model'))
  wandb.finish()

In [None]:
# PPO: every epoch pre-trains the critic, improves the actor on the current dataset,
# then gathers a new dataset with the updated actor.
if not hp_alg.critic_play:
  actor, actor_opt, actor_sched = get_attn_setup()
  critic, critic_opt = get_mlp_setup()
  trajectories = Trajectories()
  run = wandb.init(
        project='PS',
        name='PPO',
        config={}
  )
  checkpoint_path = '../Checkpoints/actor_critic27_08'
  try:
    init_actor = torch.load(checkpoint_path)['actor']
  except Exception as err:
    # best effort: any failure to load the checkpoint falls back to the freshly initialised actor
    print('no checkpoint model, using random init', f'({err})')
    init_actor = actor
  trajectories.gather_n_store(actor_model=init_actor) # init data gathering has to be performed by some policy
  trajectories.evaluate_val_train()
  print(trajectories.get_properties())

  for epoch in range(hp_alg.epochs):
      trainer = NN_Trainer(NN_setup={'actor': actor, 'actor_opt': actor_opt, 'actor_sched': actor_sched,
                                     'critic': critic, 'critic_opt': critic_opt,},
                           clip_eps = get_clip_eps(epoch),
                           trajectories=trajectories,
                           train_progress=epoch/hp_alg.epochs,
                           n_batches=1000 if epoch==0 else 300,
                           )
      # NOTE(review): relies on NN_Trainer providing learn_v (critic pre-training before the policy update)
      trainer.learn_v(2000 if epoch==0 else 50)
      trainer.learn_new_policy()
      wandb.log({'epoch': epoch})
      trajectories.gather_n_store(actor_model=actor)
      trajectories.evaluate_val_train()
      print(trajectories.get_properties())
  # the whole actor module is saved into the wandb run directory
  checkpoint = {
    'actor': actor,
  }
  torch.save(checkpoint, os.path.join(wandb.run.dir, 'actor model'))
  wandb.finish()

In [None]:
# exit()

### Data filter

In [None]:
def data_filter(clear_blacklist=False, verbose=True):
  '''
  To raise training efficiency one might filter out useless trajectories.

  Gathers a fresh dataset and appends to ../Game_env/blacklist.txt the names of trajectories
  whose mean queue length is <= 1.1 and of trajectories whose undiscounted return is exactly 0.
  A trajectory matching both criteria is written twice.

  clear_blacklist: if True, the blacklist file is emptied first.
  verbose: if True, prints quantiles and counts for both criteria.
  '''
  blacklist_path = '../Game_env/blacklist.txt'

  trajectories = Trajectories()
  trajectories.gather_n_store()
  print(trajectories.get_properties())

  if clear_blacklist:
    open(blacklist_path, 'w').close()

  paths = trajectories.j_file['paths']

  def append_to_blacklist(ids):
    # field 2 of a path entry is written as the blacklist line, one per selected trajectory
    with open(blacklist_path, 'a') as f:
      for i in ids.tolist():
        f.write(paths[i][2] + '\n')

  # field 1 of a path entry is the trajectory (list of states); field 0 of a state is its queue
  lengths = [len(path[1]) for path in paths]
  lengths_t = torch.Tensor(lengths)

  mean_queue_len = [torch.Tensor([len(q[0]) for q in path[1]]).mean() for path in paths]
  mean_queue_len_t = torch.Tensor(mean_queue_len)

  # .flatten() keeps the ids 1-D even when exactly one trajectory matches
  # (.squeeze() would give a 0-dim tensor that supports neither len() nor iteration)
  ids = torch.nonzero(mean_queue_len_t<=1.1).flatten()
  if verbose:
    # deciles from 0% to 100% for both statistics
    print('lengths: ', [lengths_t.quantile(10*i/100) for i in range(11)])
    print('mean queue len: ', [mean_queue_len_t.quantile(10*i/100) for i in range(11)])
    print('sum length where mean_queue_len_t<=1.1: ', lengths_t[ids].sum())
    print('traj with short queues: ', len(ids))
  append_to_blacklist(ids)

  # undiscounted returns of all trajectories, nothing sent to wandb
  log = trajectories.evaluate_data(eval_condition = (lambda tr: True),
                                       verbose=False,
                                       factors=torch.Tensor([1]),
                                      )
  Returns = torch.Tensor(log['Returns'])
  ids = torch.nonzero(Returns==0).flatten()
  if verbose:
    print('len Returns, len lengths: ', len(Returns), len(lengths))
    print('sum length where Returns==0: ', lengths_t[ids].sum())
    print('traj with zero Return: ', len(ids))
  append_to_blacklist(ids)

### The end