This is an implementation of PPO-clip for path selection in symbolic execution.
Each epoch, we run the jar file (`jar_command`) to gather trajectories, and we log metrics to wandb.

### Imports, meta

In [1]:
# %%capture

# Pin this kernel to one GPU: enumerate devices in PCI bus order and expose only
# physical device 4. Machine-specific; this has to run before torch touches CUDA,
# which is why it precedes the imports below.
%env CUDA_DEVICE_ORDER=PCI_BUS_ID
%env CUDA_VISIBLE_DEVICES=4

from IPython.display import Javascript
def resize_colab_cell(*args, **kwargs):
  """Ask Colab to auto-size the output iframe with a 600px maximum height.

  Registered as an IPython ``pre_run_cell`` hook. ``*args``/``**kwargs`` absorb the
  ``info`` object that recent IPython versions pass to ``pre_run_cell`` callbacks.
  """
  display(Javascript('google.colab.output.setIframeHeight(0, true, {maxHeight: 600})'))


def _register_colab_resize_hook():
  """(Re-)register ``resize_colab_cell`` exactly once, and only when running under Colab."""
  import sys
  events = get_ipython().events
  # Re-running this cell creates a new function object, and IPython only de-duplicates
  # callbacks by identity, so drop earlier copies by name before registering again.
  for callback in list(events.callbacks['pre_run_cell']):
    if getattr(callback, '__name__', None) == 'resize_colab_cell':
      events.unregister('pre_run_cell', callback)
  # Outside Colab there is no iframe to resize; the hook would only add a
  # `<IPython.core.display.Javascript object>` line to every cell's output.
  if 'google.colab' in sys.modules:
    events.register('pre_run_cell', resize_colab_cell)


_register_colab_resize_hook()

import warnings
import numpy as np
from numpy import random
import copy
import inspect
import torch
from torch import nn
import torch.onnx
import json
from tqdm import tqdm, trange
from time import time
import os
import math
from operator import itemgetter
import torch.nn.functional as F
from torch_geometric.loader import DataLoader as GraphDataLoader
from torch_geometric.data import Data as GraphData
from torch_geometric.nn import GCNConv


# !pip install wandb
import wandb

# !pip install onnx==1.12
# import onnx

# Authenticate with Weights & Biases; metrics are later sent via wandb.log.
wandb.login()

env: CUDA_DEVICE_ORDER=PCI_BUS_ID
env: CUDA_VISIBLE_DEVICES=4


[34m[1mwandb[0m: Currently logged in as: [33mandrey_podivilov[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

### Args (treat as immutable constants)

In [2]:
# --- Batching / sampling ---
batch_size = 512
batch_accumulation_steps = 1
max_nactions = 64  # longest queue kept when batching; longer queues are truncated

# --- Returns ---
td_gamma = 0.998
use_fork_discount = False

# --- Data ---
json_path = '../Data/current_dataset.json'

# --- GNN / feature sizes ---
use_gnn = False
gnn_in_nfeatures = 7
gnn_out_nfeatures = 16
# 36 is presumably the number of per-state features in the json scheme (confirm);
# GNN node embeddings are appended to it when the GNN is enabled.
if use_gnn:
  features_dim = 36 + gnn_out_nfeatures
else:
  features_dim = 36

# --- Device / synchronisation ---
if torch.cuda.is_available():
  device = 'cuda'
else:
  device = 'cpu'
cuda_sync = False
if cuda_sync:
  maybe_sync = torch.cuda.synchronize
else:
  def maybe_sync(*args):
    return None

# --- Environment: the jar that plays episodes and writes the dataset ---
jar_command = (
  '/home/st-andrey-podivilov/java16/usr/lib/jvm/bellsoft-java16-amd64/bin/java'
  ' -Dorg.jooq.no-logo=true'
  ' -jar ../Game_env/usvm-jvm/build/libs/usvm-jvm-new.jar'
  ' ../Game_env/jar_config.txt'
  ' > ../Game_env/jar_log.txt'
)
device

<IPython.core.display.Javascript object>

'cuda'

### Models, modules

In [3]:
class FFM_layer(torch.nn.Module):
  """
  Fourier-feature mapping: appends elementwise sin/cos of the phases to the input.

  ``forward(x)`` returns ``cat([x, sin(p), cos(p)], dim=-1)``, whose last dimension
  is ``3 * input_dim``. By default the phases are the input itself (``p = x``),
  which is how the layer has always behaved in this notebook. With
  ``use_projection=True`` they are ``p = x @ W.T`` for a fixed (non-trainable)
  Gaussian matrix ``W`` with std ``1/sqrt(input_dim)``, i.e. random-Fourier-feature style.

  Args:
    input_dim: size of the last input dimension; must be even.
    use_projection: use the random projection for the phases (default False).
  """
  def __init__(self, input_dim, use_projection=False):
    super().__init__()
    assert input_dim % 2 == 0, 'even input_dim is more convenient'
    self.use_projection = use_projection
    # Always created so state_dict keys are the same in both modes.
    self.fourier_matrix = torch.nn.Linear(input_dim, input_dim, bias=False)
    nn.init.normal_(
      self.fourier_matrix.weight,
      std=1/np.sqrt(input_dim),
    )
    self.fourier_matrix.weight.requires_grad_(False)

  def forward(self, x):
    phases = self.fourier_matrix(x) if self.use_projection else x
    return torch.cat([x, torch.sin(phases), torch.cos(phases)], dim=-1)
    
  
class Attn_model(torch.nn.Module):
  """
  Scores every action of a queue and returns a softmax distribution over them.

  Pipeline: per-action MLP embedding -> ``n_layers`` TransformerEncoderLayer(s)
  (attention across the actions of one queue, ``batch_first=True``) -> linear
  head -> softmax over the action axis.

  Args:
    d_model, n_heads, dim_feedforward, dropout: forwarded to every encoder layer.
    use_FFM: insert an ``FFM_layer`` into the embedding MLP.
    n_layers: number of stacked encoder layers (default 1 = the original model).
      Layers are stored as ``attn_EncoderLayer{i}`` so existing checkpoints and
      state_dict keys keep working.
  """
  def __init__(
    self,
    d_model=512,
    n_heads=8,
    dim_feedforward=512,
    dropout=0.0,
    use_FFM=True,
    n_layers=1,
  ):
    super(Attn_model, self).__init__()
    if n_layers < 1:
      raise ValueError(f'n_layers must be >= 1, got {n_layers}')
    self.n_layers = n_layers
    self.emb = nn.Sequential(
      nn.LazyLinear(512),
      nn.ReLU(),
      nn.LayerNorm(512),
      FFM_layer(512) if use_FFM else nn.Identity(),
      nn.LazyLinear(d_model),
      nn.ReLU(),
    )
    for i in range(n_layers):
      setattr(self, f'attn_EncoderLayer{i}',
              nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward, dropout, batch_first=True))
    self.head = nn.Sequential(
      nn.LazyLinear(1),
    )
    self.sfmax = nn.Softmax(dim=-1)

  def forward(self, x, mask=None):
    """
    x: (batch, n_actions, n_features) -- the layout sample_batch/ONNX export use.
    mask: optional (batch, n_actions), True at padded actions; those get probability 0.
    Returns: (batch, n_actions) action probabilities.
    """
    x = self.emb(x)
    for i in range(self.n_layers):
      x = getattr(self, f'attn_EncoderLayer{i}')(x, src_key_padding_mask=mask)
    x = self.head(x).squeeze(-1)
    if mask is None:
      return self.sfmax(x)
    # 0 for real actions, -inf for padding, so softmax assigns padding zero mass.
    inf_mask = mask.float().masked_fill(mask == True, -float('inf'))
    return self.sfmax(x + inf_mask)
  
  
class Double_attn_model(torch.nn.Module):
  """
  Two-encoder-layer variant of the attention actor.

  Embeds each action with an MLP, applies two TransformerEncoderLayers
  (``batch_first=True``) across the actions of a queue, scores each action with a
  linear head and returns a softmax over the action axis. Actions flagged True
  in ``mask`` get probability 0.
  """
  def __init__(
    self,
    d_model=512,
    n_heads=8,
    dim_feedforward=512,
    dropout=0.0,
    use_FFM=True,
  ):
    super().__init__()
    width = 512
    self.emb = nn.Sequential(
      nn.LazyLinear(width), nn.ReLU(), nn.LayerNorm(width),
      FFM_layer(width) if use_FFM else nn.Identity(),
      nn.LazyLinear(d_model), nn.ReLU(),
    )
    layer_args = (d_model, n_heads, dim_feedforward, dropout)
    self.attn_EncoderLayer0 = nn.TransformerEncoderLayer(*layer_args, batch_first=True)
    self.attn_EncoderLayer1 = nn.TransformerEncoderLayer(*layer_args, batch_first=True)
    self.head = nn.Sequential(nn.LazyLinear(1))
    self.sfmax = nn.Softmax(dim=-1)

  def forward(self, x, mask=None):
    """x: (batch, n_actions, n_features); mask: optional, True at padding. Returns (batch, n_actions)."""
    hidden = self.emb(x)
    for encoder in (self.attn_EncoderLayer0, self.attn_EncoderLayer1):
      hidden = encoder(hidden, src_key_padding_mask=mask)
    logits = self.head(hidden).squeeze(-1)
    if mask is not None:
      # Additive mask: 0 keeps a logit, -inf removes it from the softmax.
      logits = logits + mask.float().masked_fill(mask == True, -float('inf'))
    return self.sfmax(logits)

  
class V_cell(torch.nn.Module):
  """
  Predicts Returns by traversing trajectories.
  We use its inner representations to enrich world embedding.

  Architecture: LazyLinear(64) + LayerNorm embedding -> two stacked LSTMCells of
  size ``hidden_size`` -> MLP head giving one value per batch row.
  ``bias`` and ``batch_first`` are accepted for interface compatibility only and
  are not used.
  """
  def __init__(
    self,
    hidden_size=64,
    dropout=0.0,
    bias=True,
    batch_first=True,
  ):
    super(V_cell, self).__init__()
    self.hidden_size = hidden_size
    self.num_layers = 2
    self.emb = nn.Sequential(
      nn.LazyLinear(64),
      nn.LayerNorm(64),
    )
    self.lstm_cell = nn.LSTMCell(input_size=64, hidden_size=hidden_size)
    self.lstm_cell2 = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
    self.head = nn.Sequential(
      nn.Dropout(p=dropout),
      nn.LazyLinear(128),
      nn.ReLU(),
      nn.LazyLinear(1),
    )

  def forward(self, input_batch, prev_state_batch):
    """
    input_batch: [batch_size, feature_size]
    prev_state_batch: [2*num_layers, batch_size, hidden_size], rows ordered (h1, c1, h2, c2)
    Outputs: [batch_size, 1], [2*num_layers, batch_size, hidden_size] in the same row order
    """
    x = self.emb(input_batch)
    # Explicit row indexing replaces `t[[np.arange(n) * 2]]`, which relied on the
    # legacy "list containing a sequence is a tuple index" rule.
    h1, c1 = self.lstm_cell(x, (prev_state_batch[0], prev_state_batch[1]))
    h2, c2 = self.lstm_cell2(h1, (prev_state_batch[2], prev_state_batch[3]))
    v = self.head(h2)
    return v, torch.stack((h1, c1, h2, c2), dim=0)
  

class Q_Net(torch.nn.Module):
  """
  Builds a Q-estimate on top of a state-value critic:
  ``Q = V(features) + (exp(features[:, reward_ind]) - 1)``.

  The reward column is mapped through exp(.) - 1; presumably it stores
  log(1 + reward) (TODO confirm against the feature extractor).

  Args:
    V_function: callable mapping a feature batch to values (e.g. the critic).
    reward_ind: column of the reward feature in ``feature_batch``.
  """
  def __init__(
    self,
    V_function,
    reward_ind,
  ):
    super(Q_Net, self).__init__()
    self.V_function = V_function
    self.reward_ind = reward_ind

  def forward(self, feature_batch):
    v = self.V_function(feature_batch)
    # expm1 avoids the cancellation of exp(r) - 1 for rewards close to 0.
    return v + torch.expm1(feature_batch[:, self.reward_ind][:, None])


class GNN_model(torch.nn.Module):
  """
  Five stacked GCNConv layers: in -> 16 -> 32 -> 32 -> 16 -> out.

  A ReLU precedes every conv except the first. During training, dropout with
  p=0.1 and p=0.3 is applied before the 4th and 5th convs respectively.
  """
  def __init__(
    self,
    num_input_features,
    num_output_features
  ):
    super().__init__()
    widths = [num_input_features, 16, 32, 32, 16, num_output_features]
    self.convs = nn.ModuleList(
      GCNConv(width_in, width_out) for width_in, width_out in zip(widths, widths[1:])
    )
    # One entry per conv after the first; None means "no dropout before this conv".
    self.dropout_probs = [None, None, 0.1, 0.3]

  def forward(self, data_x, data_edge_index):
    first_conv, *later_convs = self.convs
    h = first_conv(data_x, data_edge_index)
    for conv, drop_p in zip(later_convs, self.dropout_probs):
      h = F.relu(h)
      if drop_p is not None:
        h = F.dropout(h, p=drop_p, training=self.training)
      h = conv(h, data_edge_index)
    return h


def get_mlp_setup(use_FFM=False,
                  lr=1e-4,
                  wd=0.05,
                 ):
  """
  Build an MLP regressor (scalar output) on `device` plus its AdamW optimizer.

  Args:
    use_FFM: insert an FFM_layer after the 256-wide LayerNorm.
    lr, wd: AdamW learning rate and weight decay (betas fixed to (0.9, 0.999)).

  Returns:
    (mlp, mlp_opt)
  """
  layers = [
    nn.LazyLinear(512), nn.ReLU(),
    nn.Linear(512, 256), nn.LayerNorm(256),
    FFM_layer(256) if use_FFM else nn.Identity(),
    nn.LazyLinear(512), nn.ReLU(),
    nn.Linear(512, 512), nn.ReLU(),
    nn.Linear(512, 1),
  ]
  mlp = nn.Sequential(*layers).to(device)
  return mlp, torch.optim.AdamW(mlp.parameters(), lr=lr, weight_decay=wd, betas=(0.9, 0.999))
  
  
def get_attn_setup(epochs, use_double=False, lr=1e-4, weight_decay=1e-2, end_factor=0.1):
  """
  Build the attention actor, its AdamW optimizer and a linear LR schedule.

  Args:
    epochs: scheduler steps over which the LR decays linearly from ``lr`` to
      ``lr * end_factor``.
    use_double: build a ``Double_attn_model`` instead of an ``Attn_model``.
    lr, weight_decay: AdamW hyper-parameters (defaults are the notebook's values).
    end_factor: final LR multiplier of ``LinearLR``.

  Returns:
    (model on ``device``, optimizer, scheduler)
  """
  # Choose the class before constructing anything, so that only one model is allocated
  # and the RNG draws used for weight init belong to the model that is actually used.
  model_cls = Double_attn_model if use_double else Attn_model
  attn_model = model_cls().to(device)
  opt = torch.optim.AdamW(attn_model.parameters(), lr=lr, weight_decay=weight_decay)
  scheduler = torch.optim.lr_scheduler.LinearLR(opt, start_factor=1, end_factor=end_factor, total_iters=epochs, verbose=True)
  return attn_model, opt, scheduler


def get_gnn_setup(num_input_features, num_output_features):
  """
  Return ``(GNN_model on device, AdamW optimizer)``, or ``(None, None)`` when the
  global ``use_gnn`` flag is off.
  """
  if use_gnn:
    model = GNN_model(num_input_features, num_output_features).to(device)
    return model, torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-2)
  return None, None


def get_rnn_setup():
  """
  Build a V_cell value RNN on ``device`` together with its AdamW optimizer.

  Returns:
    (rnn_cell, opt)
  """
  # Put the model on `device` like the other get_*_setup helpers; the trajectory
  # tensors it would consume are created on `device`.
  rnn_cell = V_cell().to(device)
  opt = torch.optim.AdamW(rnn_cell.parameters(), lr=3e-4, weight_decay=1e-2)
  return rnn_cell, opt

<IPython.core.display.Javascript object>

### Logger

In [4]:
class Logger:
  """
  Supporting class, to be expanded.
  Intended to store logging utils and relevant data.

  Attributes:
    actor, critic: models whose grads/weights ``link_models`` snapshots; set them
      with ``link_models(actor, critic)`` or by direct assignment.
    grad_a, weight_a, grad_c, weight_c: lists filled by ``link_models``.
    log_gamma: smoothing factor used by ``running_mean``.
    between_logs: stored logging interval (not used inside this class).
    timer: dict of accumulated timings in seconds, written by other cells.
  """
  def __init__(
      self,
      between_logs = 64,
  ):
    self.grad_a = None
    self.grad_c = None
    self.weight_a = None
    self.weight_c = None
    self.actor = None
    self.critic = None
    self.log_gamma = torch.tensor(0.95).to(device)
    self.between_logs = between_logs
    self.timer = {}

  @staticmethod
  def _grads(model):
    """Detached grads of trainable params of `model`; zeros where no grad exists yet."""
    return [p.grad.detach() if p.grad is not None else torch.zeros_like(p)
            for p in model.parameters() if p.requires_grad]

  @torch.no_grad()
  def link_models(self, actor=None, critic=None):
    """
    Snapshot detached references to the actor's and critic's grads and weights.

    Args:
      actor, critic: optional models. When given, they replace ``self.actor`` /
        ``self.critic`` first. Calling ``link_models()`` after assigning the
        attributes directly still works.

    Grad lists only cover parameters with ``requires_grad``. A parameter whose
    ``.grad`` is still None contributes ``zeros_like(p)`` instead of raising.
    ``detach`` does not copy, so the stored tensors share storage with the live ones.
    """
    if actor is not None:
      self.actor = actor
    if critic is not None:
      self.critic = critic
    self.grad_a = self._grads(self.actor)
    self.weight_a = [p.detach() for p in self.actor.parameters()]
    self.grad_c = self._grads(self.critic)
    self.weight_c = [p.detach() for p in self.critic.parameters()]

  @torch.no_grad()
  def list_norm(self, l, p=2):
    """p-norm of all tensors in `l` viewed as one concatenated vector (0.0 if `l` is empty)."""
    n = 0
    for t in l:
      n += t.detach().norm(p) ** p
    return float(n) ** (1/p)

  @torch.no_grad()
  def list_cos_dist(self, a, b):
    """
    Cosine SIMILARITY (despite the name) between two equally long lists of tensors,
    each list viewed as one concatenated vector. Returns nan if either list has zero norm.
    """
    assert len(a) == len(b), 'lists lengths differ'
    a_norm = self.list_norm(a, 2)
    b_norm = self.list_norm(b, 2)
    if a_norm == 0 or b_norm == 0:
      return float('nan')
    product = sum(torch.dot(torch.flatten(x), torch.flatten(y)).item() for x, y in zip(a, b))
    return product/(a_norm*b_norm)

  @torch.no_grad()
  def on_list(self, a, b, operation):
    """Apply `operation` pairwise to two equally long lists."""
    assert len(a) == len(b), 'lists lengths differ'
    return [operation(x, y) for x, y in zip(a, b)]

  @torch.no_grad()
  def running_mean(self, a, b):
    """Elementwise exponential moving average: a * log_gamma + b * (1 - log_gamma)."""
    return [x.mul(self.log_gamma) + y.mul(1 - self.log_gamma) for x, y in zip(a, b)]
  
  
# Module-level singleton shared by later cells: Trajectories.j2torch and
# Trajectories.sample_batch write their timings into logger.timer.
logger = Logger()

<IPython.core.display.Javascript object>

### Data

In [5]:
class Trajectories:
  """
  Contains all kinds of data in the form of tensors.

  ``realized`` holds raw and derived features of visited train states.
  ``queues`` holds, per state, the feature rows of the actions available at the
  next step. Action and state embeddings live in the same feature space.

  Workflow: ``gather_n_store`` runs the jar (``update_data_on_path``) and then
  converts the json it produced with ``store_from_json`` -> ``j2torch``. Only
  trajectories with a truthy ``train_condition(tr)`` become tensors; the others
  form the validation split evaluated by ``evaluate_data``.
  """
  def __init__(self,
               train_condition,
               td_gamma=td_gamma,
               use_gnn=use_gnn,
               gnn_in_nfeatures=gnn_in_nfeatures,
               gnn_out_nfeatures=gnn_out_nfeatures,
              ):
    self.train_condition = train_condition
    self.td_gamma = td_gamma
    self.use_gnn = use_gnn
    self.gnn_in_nfeatures = gnn_in_nfeatures
    self.gnn_out_nfeatures = gnn_out_nfeatures
    self.j_file = None
    self.feature_names, self.feature_names2ids = None, None
    # dict of train tensors f, f_n, r, R, is_last, queue_len, chosen_acts, tr_id, graph_id +
    # + list of queue tensors + list of block ids tensors + list of graphs data + list of previous probabilities tensors
    self.realized, self.queues, self.queue_block_ids, self.graphs_data, self.prev_probs = None, None, None, None, None
    # GAE critic estimates for each state estimated by trainer.GAE()
    self.Psi = None
    # Pre-sampled id chunks, consumed one batch at a time by sample_ids
    self.sampled_ids_list = []
    self.sampled_queue_lengths_list = []


  def gather_n_store(self, actor_model=None, gnn_model=None, json_path=json_path):
    """
    Plays, collects trajectories into json, then transforms json to tensors
    """
    self.update_data_on_path(actor_model, gnn_model, json_path)
    self.store_from_json(json_path)


  def evaluate_val_train(self, verbose=True,):
    """
    Evaluates the val and train subsets; returns (log_val, log_train).
    """
    log_val = self.evaluate_data(eval_condition = (lambda tr: not self.train_condition(tr)),
                                 wandb_prefix = 'val',
                                 verbose=verbose,)
    log_train = self.evaluate_data(eval_condition = self.train_condition,
                                   wandb_prefix = 'train',
                                   verbose=verbose,)
    return log_val, log_train


  def store_from_json(self, json_path=json_path):
    """
    Extracts trajectories from json to a torch-friendly form.
    Resets ``Psi``, since it refers to the previous data.
    """
    time_before = time()
    with open(json_path) as j_handle:  # closed even if parsing fails
      self.j_file = json.load(j_handle)
    print('json loading time: ', time()-time_before)
    self.feature_names = self.j_file['scheme'][0]
    self.feature_names2ids = {name: i for i, name in enumerate(self.feature_names)}
    time_before = time()
    self.realized, self.queues, self.queue_block_ids, self.graphs_data, self.prev_probs = self.j2torch(self.j_file)
    print('j2torch time: ', time()-time_before)
    self.Psi = None


  def mean_code_length(self):
    """
    Mean of ``tr[3]`` (code length) over train and over validation trajectories.

    Returns (train_mean, val_mean); the 1e-9 keeps an empty split from dividing by zero.
    """
    train_lines = val_lines = 0
    n_train_tr = n_val_tr = 0
    for tr in self.j_file['paths']:
      if self.train_condition(tr):
        train_lines += tr[3]
        n_train_tr += 1
      else:
        val_lines += tr[3]
        n_val_tr += 1
    return train_lines/(n_train_tr+1e-9), val_lines/(n_val_tr+1e-9)


  def j2torch(self, j_file):
    """
    Transforms json to data tensors.

    For each train trajectory ``tr`` with steps ``tr[1]``:
      * state i has the features of the action chosen at step i
        (``step[0][step[chosenStateId]]``);
      * its reward, chosen action, queue and prev-probs come from step i+1. The
        last state gets reward 0, action 0, a one-row zero queue and prob 1.
    Queues are flipped for reasons related to batch truncation and padding.
    Action numeration is changed accordingly.
    Trajectories with no steps are skipped.

    Returns (realized, queues, queue_block_ids, graphs_data, prev_probs).
    """
    features, features_next, rewards, Returns, is_last, queue_lengths, chosen_actions, queues = [], [], [], [], [], [], [], []
    graph_features, graph_edges, tr_ids, graph_ids, queue_block_ids = [], [], [], [], []
    prev_probs = []
    scheme = j_file['scheme']
    n_feature_names = len(scheme[0])
    chosenStId_idx = scheme.index('chosenStateId')
    rewards_idx = scheme.index('reward')
    graphId_idx = scheme.index('graphId')
    blockIds_idx = scheme.index('blockIds')
    is_cfg_fork_idx = scheme.index('is_cfg_fork') if use_fork_discount else None

    logger.timer = {
      'tr_prev_probs': 0,
      'graph_features':0,
      'tr_features':0,
      'tr_graph_ids':0,
      'tr_queues':0,
      'cat graph features':0,
      'graphs_data':0,
    }

    tr_id = 0
    for tr in j_file['paths']:

      time0 = time()
      # Empty trajectories have nothing to learn from (and used to crash below).
      if not self.train_condition(tr) or len(tr[1]) == 0:
        continue
      if self.use_gnn:
        graph_features.append(tr[4])
        graph_edges.append(tr[5])
      logger.timer['graph_features'] += time()-time0

      steps = tr[1]
      n_steps = len(steps)

      time0 = time()
      tr_prev_probs = [torch.Tensor(p).flip(dims=[0]).to(device) for p in tr[6][1:]] + [torch.ones(1).to(device)]
      if len(tr_prev_probs) == 1:
        # No recorded probabilities (e.g. heuristic rollouts): use ones shaped
        # exactly like the (shifted) queues below, and on the same device.
        tr_prev_probs = [torch.ones(len(steps[i][0])).to(device) for i in range(1, n_steps)] + [torch.ones(1).to(device)]
      logger.timer['tr_prev_probs'] += time()-time0
      prev_probs += tr_prev_probs

      # we use features as a state emb and reward is what we get the next step
      tr_rewards = [step[rewards_idx] for step in steps[1:]] + [0]
      rewards += tr_rewards
      if use_fork_discount:
        do_discount = [step[is_cfg_fork_idx] for step in steps]
      else:
        do_discount = [1] * n_steps
      Returns += self.tr_rewards_to_returns(tr_rewards, do_discount)

      chosen_actions += [step[chosenStId_idx] for step in steps[1:]] + [0]

      time0 = time()
      tr_features = [step[0][step[chosenStId_idx]] for step in steps]
      is_last += [0]*(n_steps-1) + [1]
      features += tr_features
      features_next += tr_features[1:] + [[-1]*n_feature_names]
      logger.timer['tr_features'] += time() - time0

      tr_ids += [tr_id] * n_steps
      tr_id += 1

      time0 = time()
      if self.use_gnn:
        graph_ids += [step[graphId_idx] for step in steps]
        queue_block_ids += ([torch.LongTensor(step[blockIds_idx]).flip(dims=[0]).to(device) for step in steps[1:]]
                            + [torch.LongTensor([-1]).to(device)])
      else:
        graph_ids += [-1] * n_steps
      logger.timer['tr_graph_ids'] += time()-time0

      time0 = time()
      tr_queues = ([torch.Tensor(step[0]).flip(dims=[0]).to(device) for step in steps[1:]]
                   + [torch.zeros(1, len(steps[0][0][0])).to(device)])
      logger.timer['tr_queues'] += time()-time0

      queues += tr_queues
      queue_lengths += [len(q) for q in tr_queues]

    rewards = torch.Tensor(rewards).to(device)
    features = torch.Tensor(features).to(device)
    features_next = torch.Tensor(features_next).to(device)
    Returns = torch.Tensor(Returns).to(device)
    is_last = torch.Tensor(is_last).to(device)
    queue_lengths = torch.LongTensor(queue_lengths).to(device)
    # we flip queue and numeration of actions
    chosen_actions = queue_lengths - torch.LongTensor(chosen_actions).to(device) - 1
    tr_ids = torch.LongTensor(tr_ids).to(device)
    graph_ids = torch.LongTensor(graph_ids).to(device)

    time0 = time()
    if self.use_gnn:
      # Reserve zero-filled columns that GNN embeddings are written into later.
      n_gnn = self.gnn_out_nfeatures
      features = torch.cat((features, torch.zeros(len(features), n_gnn).to(device)), dim=1)
      features_next = torch.cat((features_next, torch.zeros(len(features_next), n_gnn).to(device)), dim=1)
      queues = [torch.cat((q, torch.zeros(len(q), n_gnn).to(device)), dim=1) for q in queues]
    logger.timer['cat graph features'] += time()-time0

    time0 = time()
    graphs_data = [
      [GraphData(x=torch.Tensor(node_features).to(device), edge_index=torch.LongTensor(edges).to(device))
       for node_features in tr_graph_features]
      for tr_graph_features, edges in zip(graph_features, graph_edges)
    ]
    logger.timer['graphs_data'] += time()-time0
    print(logger.timer)

    realized = {
      'features': features,
      'features_next': features_next,
      'rewards': rewards,
      'Returns': Returns,
      'is_last': is_last,
      'queue_lengths': queue_lengths,
      'chosen_actions': chosen_actions,
      'tr_ids': tr_ids,
      'graph_ids': graph_ids,
    }
    return realized, queues, queue_block_ids, graphs_data, prev_probs


  def get_n_train_states(self):
    """Number of train states (rows of every ``realized`` tensor)."""
    return len(self.realized['features'])


  def get_properties(self):
    """Summary statistics of the loaded dataset (for printing / logging)."""
    queue_lengths = np.array([q.shape[0] for q in self.queues])
    longest_queue_ids = np.argmax(queue_lengths)
    tr_lengths = np.array([len(tr[1]) for tr in self.j_file['paths']])
    prop = {
            'traj_length mean, median, max': (f'{tr_lengths.mean():.2f}', np.median(tr_lengths), tr_lengths.max()),
            'queue max length, idx': (self.queues[longest_queue_ids].shape[0], longest_queue_ids),
            'number of train states': self.get_n_train_states(),
            'number of traj-s': len(self.j_file['paths']),
            'number of validation traj-s': sum([not self.train_condition(tr) for tr in self.j_file['paths']]),
           }
    return prop


  def tr_rewards_to_returns(self, tr_rewards, do_discount):
    """
    Discounted returns computed backwards:
    ``R[i] = r[i] + td_gamma**do_discount[i] * R[i+1]``, with ``R[-1] = 0``.
    The last reward is a padding 0 from j2torch, so fixing R[-1] to 0 loses nothing.
    """
    tr_R = [0]*(len(tr_rewards))
    for i in range(len(tr_rewards)-2, -1, -1):
      tr_R[i] = tr_rewards[i] + (self.td_gamma**do_discount[i]) * tr_R[i+1]
    return tr_R


  def sample_ids(self, batch_size=batch_size):
    """
    Assembles ids for several consecutive batches based on queue lengths.

    Draws batch_size * batch_accumulation_steps ids with replacement and sorts
    them by queue length, so each batch pads to similar lengths. Then returns
    one (ids, queue_lengths) chunk per call until the chunks run out.
    """
    if len(self.sampled_ids_list) == 0:
      sampled_ids = torch.LongTensor(random.choice(self.get_n_train_states(), size=batch_size*batch_accumulation_steps)).to(device)
      sampled_queue_lengths = self.realized['queue_lengths'][sampled_ids]
      sampled_queue_lengths, sorting_ids = torch.sort(sampled_queue_lengths)
      sampled_ids = sampled_ids[sorting_ids]
      self.sampled_ids_list = list(torch.split(sampled_ids, batch_size))
      self.sampled_queue_lengths_list = list(torch.split(sampled_queue_lengths, batch_size))
    return self.sampled_ids_list.pop(0), self.sampled_queue_lengths_list.pop(0)


  def sample_batch(self, batch_size=batch_size):
    """
    Sample a batch of train states and pad their queues into dense tensors.

    Returns (sampled_realized, [Psi,] queues_tensor, queue_block_ids_tensor,
    pad_mask, graphs_data, prev_probs_tensor). Psi is included only once
    ``self.Psi`` is set. pad_mask is True at padded positions. Queues longer than
    ``max_nactions`` are truncated.
    """
    ids, queue_lengths = self.sample_ids(batch_size=batch_size)
    n_rows = len(ids)  # may differ from batch_size if sample_ids was primed differently
    sampled_realized = {k: v[ids] for k, v in self.realized.items()}
    # One host transfer instead of per-element tensor indexing. This also avoids
    # itemgetter returning a bare item (not a tuple) for a single id.
    ids_list = ids.tolist()
    sampled_queues = [self.queues[i] for i in ids_list]
    sampled_prev_probs = [self.prev_probs[i] for i in ids_list]
    # Plain int: a 0-dim CUDA tensor here would force a device sync at every
    # min()/slice in the loops below.
    padded_length = min(int(queue_lengths.max()), max_nactions)
    n_features = len(self.queues[0][0])
    queues_tensor = torch.zeros(n_rows, padded_length, n_features).to(device)
    prev_probs_tensor = torch.zeros(n_rows, padded_length).to(device)
    pad_mask = torch.ones(n_rows, padded_length, dtype=torch.bool).to(device)
    if self.use_gnn:
      sampled_queue_block_ids = [self.queue_block_ids[i] for i in ids_list]
      queue_block_ids_tensor = torch.full((n_rows, padded_length), -1, dtype=torch.long).to(device)
    else:
      sampled_queue_block_ids = None
      queue_block_ids_tensor = None

    # Wall-clock timing works on CPU too (torch.cuda.Event does not).
    time0 = time()
    for i, q in enumerate(sampled_queues):
      l = min(q.shape[0], padded_length)
      queues_tensor[i, :l, :] = q[:l, :]
      pad_mask[i, :l] = False

    if self.use_gnn:
      for i, q in enumerate(sampled_queue_block_ids):
        l = min(q.shape[0], padded_length)
        queue_block_ids_tensor[i, :l] = q[:l]
    for i, p in enumerate(sampled_prev_probs):
      l = min(p.shape[0], padded_length)
      prev_probs_tensor[i, :l] = p[:l]

    maybe_sync()
    # j2torch resets logger.timer without this key, so accumulate with a default.
    logger.timer['sample inside loop'] = logger.timer.get('sample inside loop', 0) + time() - time0

    if self.Psi is not None:
      Psi = self.Psi[ids]
      return sampled_realized, Psi, queues_tensor, queue_block_ids_tensor, pad_mask, self.graphs_data, prev_probs_tensor
    return sampled_realized, queues_tensor, queue_block_ids_tensor, pad_mask, self.graphs_data, prev_probs_tensor


  def sample_realized_batch(self, n=batch_size):
    """Sample ``n`` train states (with replacement) from ``realized`` only."""
    ids = torch.tensor(random.choice(self.get_n_train_states(), size=n)).long()
    sampled = {k: self.realized[k][ids] for k in self.realized.keys()}
    return sampled


  def update_data_on_path(self, actor_model=None, gnn_model=None, path=json_path):
    """
    Communication with jar file on a server.

    When ``actor_model`` is None, stale ONNX exports are deleted (use default
    heuristics). Otherwise the actor (and the GNN, if enabled) are exported to
    ONNX under ../Game_env/. In both cases ``jar_command`` is then run and
    waited for. The models' train/eval mode is restored after export.
    ``path`` is not passed to the jar (presumably the output location comes
    from jar_config.txt; confirm). It only appears in the failure warning.
    """
    if actor_model is None:
      # use default heuristics
      for stale_model in ('../Game_env/actor_model.onnx', '../Game_env/gnn_model.onnx'):
        if os.path.exists(stale_model):
          os.remove(stale_model)
      algo_name = 'Heuristic'
    else:
      total_features_dim = features_dim
      # Attention actors (they own `sfmax`) take (batch, n_actions, features).
      if hasattr(actor_model, 'sfmax'):
        shape = [1, 1, total_features_dim]
      else:
        shape = [1, total_features_dim]
      x = torch.randn(*shape, requires_grad=True).to(device)
      was_training = actor_model.training
      actor_model.eval()
      try:
        actor_model(x)  # dry run: materialises Lazy* layers before tracing
        torch.onnx.export(actor_model,
                          x,
                          '../Game_env/actor_model.onnx',
                          opset_version=15,
                          export_params=True,
                          input_names = ['input'],   # the model's input names
                          output_names = ['output'],
                          dynamic_axes={'input' : {0 : 'batch_size',
                                                   1 : 'n_actions',
                                                  },    # variable length axes
                                        'output' : {0 : 'batch_size',
                                                    1 : 'n_actions',
                                                  },
                                       },
                          )
      finally:
        actor_model.train(was_training)
      algo_name = 'NN'

    if actor_model is not None and gnn_model is not None and self.use_gnn:
      x_shape = [1, self.gnn_in_nfeatures]
      edge_index_shape = [2, 1]
      x = (torch.randn(*x_shape, requires_grad=True).to(device), torch.randint(0, 1, edge_index_shape).to(device))
      was_training = gnn_model.training
      gnn_model.eval()
      try:
        gnn_model(*x)
        torch.onnx.export(gnn_model,
                          x,
                          '../Game_env/gnn_model.onnx',
                          opset_version=15,
                          export_params=True,
                          input_names = ['x', 'edge_index'],   # the model's input names
                          output_names = ['output'],
                          dynamic_axes={'x' : {0 : 'nodes_number'},    # variable length axes
                                        'edge_index' : {1 : 'egdes_number'},
                                       },
                          )
      finally:
        # Without this, GNN dropout stays disabled for the rest of training.
        gnn_model.train(was_training)

    time_before = time()
    status = os.system(jar_command)
    print(algo_name, 'data gathering time:', time() - time_before)
    if status != 0:
      warnings.warn(f'jar command exited with status {status}; {path} may be stale or missing')


  def evaluate_data(self,
           eval_condition = None,
           wandb_prefix = 'val',
           verbose=True,
           factors=torch.Tensor([1, td_gamma]),
           ):
    """
    Compute, and with ``verbose`` wandb-log, per-trajectory discounted returns.

    For each discount factor f in ``factors``, every trajectory selected by
    ``eval_condition`` (default: non-train ones) gets R, accumulated backwards
    over its raw (unshifted) step rewards: ``R = r_i + f**d_i * R``. Here d_i is 1,
    or the step's is_cfg_fork flag when ``use_fork_discount`` is set.
    The returned dict describes only the LAST factor and also carries the
    'Returns' and '<prefix> lengths' lists.
    """
    if eval_condition is None:
      eval_condition = (lambda tr: not self.train_condition(tr))
    rewards_idx = self.j_file['scheme'].index('reward')
    if use_fork_discount:
      is_cfg_fork_idx = self.j_file['scheme'].index('is_cfg_fork')
    # Independent of the factor: compute once instead of once per factor.
    code_length_train, code_length_val = self.mean_code_length()
    log = {}
    for f in factors:
      size = 0
      tr_lengths = []
      trs_R = []
      for tr in self.j_file['paths']:
        if not eval_condition(tr):
          continue
        steps = tr[1]
        size += len(steps)
        tr_lengths.append(len(steps))
        # rewards list is not shifted, has a different purpose this time
        tr_rewards = [step[rewards_idx] for step in steps]
        if use_fork_discount:
          do_discount = [step[is_cfg_fork_idx] for step in steps]
        else:
          do_discount = [1] * len(steps)
        R = 0
        for i in range(len(tr_rewards)-1, -1, -1):
          R = tr_rewards[i] + (f ** do_discount[i]) * R
        trs_R.append(R)
      log = {}
      log[f'{wandb_prefix} size'] = size
      log[f'{wandb_prefix}_eval/mean {f:.2f} discount '] = torch.Tensor(trs_R).mean()
      log[f'{wandb_prefix}_eval/median {f:.2f} discount '] = torch.Tensor(trs_R).median()
      log[f'{wandb_prefix} Return by trjs {f:.2f} (previous epoch)'] = wandb.Histogram(np_histogram=np.histogram(trs_R, bins=30, ))
      log['code length train mean'], log['code length val mean'] = code_length_train, code_length_val
      if verbose:
        wandb.log(log.copy())
      log['Returns'] = trs_R
      log[f'{wandb_prefix} lengths'] = tr_lengths
    return log

<IPython.core.display.Javascript object>

### Trainer


In [6]:
class NN_Trainer:
  """
  PPO-clip learner for the path-selection actor and the V-function critic.

  The actor, critic, GNN and their optimizers come from `NN_setup` and are trained in
  place. `prev_actor` is a frozen actor copy. `target_critic` is a frozen critic copy,
  refreshed every `target_update_steps` batches and used for TD targets. Diagnostics
  accumulate in `self.log` and are flushed to wandb periodically.
  """
  def __init__(
      self,
      NN_setup,
      trajectories=None,
      batch_size=batch_size,
      n_batches=1000,
      target_update_steps=15,
      td_gamma=td_gamma,
      clip_eps=2e-2,
      use_gnn=True,
      gnn_out_nfeatures=gnn_out_nfeatures,
      ):
    """
    Args:
      NN_setup: dict with 'actor', 'actor_opt', 'actor_sched', 'critic', 'critic_opt',
        'gnn' and 'gnn_opt'.
      trajectories: data store used for sampling (`sample_batch`,
        `sample_realized_batch`, `realized`, `graphs_data`, `Psi`).
      batch_size: states per PPO batch in `learn_new_policy`.
      n_batches: batches per `learn_new_policy` call; also the default for `learn_v`.
      target_update_steps: target-critic refresh period, in batches.
      td_gamma: discount factor of the TD targets.
      clip_eps: PPO ratio clipping epsilon.
      use_gnn: write GNN outputs into the trailing `gnn_out_nfeatures` feature slots.
      gnn_out_nfeatures: width of the GNN output.
    """
    self.n_batches = n_batches
    self.batch_number = -1
    self.td_gamma = td_gamma
    self.gae_gamma = 0.7  # lambda factor of the GAE recursion in compute_GAE
    self.clip_eps = clip_eps
    self.use_gnn = use_gnn
    self.gnn_out_nfeatures = gnn_out_nfeatures
    self.actor = NN_setup['actor']
    self.actor_opt = NN_setup['actor_opt']
    self.actor_sched = NN_setup['actor_sched']
    self.prev_actor = copy.deepcopy(self.actor).eval()
    self.critic = NN_setup['critic'].train()
    self.target_critic = copy.deepcopy(self.critic).eval()
    self.critic_opt = NN_setup['critic_opt']
    self.gnn = NN_setup['gnn']
    self.gnn_opt = NN_setup['gnn_opt']
    self.gnn_zeros_tensor = torch.zeros(self.gnn_out_nfeatures).to(device)
    self.trajectories = trajectories
    self.batch_size = batch_size
    self.target_update_steps = target_update_steps
    self.log = {}

  def _zero_grads(self):
    """Clears accumulated gradients of the critic, the actor and (when used) the GNN."""
    self.critic_opt.zero_grad()
    self.actor_opt.zero_grad()
    if self.use_gnn:
      self.gnn_opt.zero_grad()

  def run_gnn(self, tr_ids, graph_ids, graphs_data, gnn_features):
    """
    Runs the GNN once per distinct (trajectory, graph) pair.

    The per-node outputs are stored in `gnn_features[(tr_id, graph_id)]`, and
    `gnn_features` is mutated in place.
    """
    dataset = []
    tr_graph_ids = list(set(zip(tr_ids, graph_ids)))
    for tr_id, graph_id in tr_graph_ids:
      dataset.append(graphs_data[tr_id][graph_id])
    loader = GraphDataLoader(dataset, batch_size=256)
    gnn_features_list = []
    for batch_data in loader:
      results = self.gnn(batch_data.x, batch_data.edge_index)
      # `batch_data.batch` holds, per node, the index of its graph within this mini-batch
      for i in range(batch_data.batch[-1] + 1):
        gnn_features_list.append(results[batch_data.batch == i])
    for i, (tr_id, graph_id) in enumerate(tr_graph_ids):
      gnn_features[(tr_id, graph_id)] = gnn_features_list[i]

  def add_gnn_features_actor(self, tr_ids, graph_ids, queue_block_ids_tensor, queues_tensor, gnn_features):
    """
    Writes GNN block embeddings into the action features, in place.

    The trailing `gnn_out_nfeatures` slots of every action in the 3-D `queues_tensor`
    are overwritten with the GNN output of that action's block. Block id -1
    (presumably padding / no block) gets zeros.
    """
    gnn_actor_tensor = []
    for tr_id, graph_id, queue_block_ids in zip(tr_ids, graph_ids, queue_block_ids_tensor):
      cur_gnn_features = gnn_features[(tr_id, graph_id)]
      for block_id in queue_block_ids:
        if block_id == -1:
          gnn_actor_tensor.append(self.gnn_zeros_tensor)
        else:
          gnn_actor_tensor.append(cur_gnn_features[block_id])
    gnn_actor_tensor = torch.cat(gnn_actor_tensor)
    queues_shape = queues_tensor.shape
    # boolean mask selecting only the trailing GNN slots, in row-major order
    mask = torch.cat((torch.zeros(queues_shape[0], queues_shape[1], queues_shape[2] - self.gnn_out_nfeatures),
                       torch.ones(queues_shape[0], queues_shape[1], self.gnn_out_nfeatures)), dim=2).bool().to(device)
    queues_tensor[mask] = gnn_actor_tensor

  def add_gnn_features_critic(self, tr_ids, graph_ids, features, features_next, gnn_features):
    """
    Writes graph-level GNN embeddings into the critic features, in place.

    The embedding is the node-mean of the GNN output. It is written into the trailing
    `gnn_out_nfeatures` slots of `features` and `features_next`.
    """
    gnn_critic_list = []
    for tr_id, graph_id in zip(tr_ids, graph_ids):
      cur_gnn_features = gnn_features[(tr_id, graph_id)]
      gnn_critic_list.append(torch.mean(cur_gnn_features, dim=0))
    gnn_critic_tensor = torch.cat(gnn_critic_list)
    # NOTE(review): the "next" GNN features of row k come from row k+1 of the batch. This
    # is the true successor state only if the batch rows are consecutive states of one
    # trajectory. TODO confirm against the sampler.
    gnn_critic_tensor_next = torch.cat(gnn_critic_list[1:] + [torch.zeros(self.gnn_out_nfeatures).to(device)])
    mask = torch.cat((torch.zeros(features.shape[0], features.shape[1] - self.gnn_out_nfeatures),
                       torch.ones(features.shape[0], self.gnn_out_nfeatures)), dim=1).bool().to(device)
    features[mask] = gnn_critic_tensor
    features_next[mask] = gnn_critic_tensor_next

  def _td_mc_loss(self, values, next_values, sampled_realized):
    """
    Shared critic objective, used by both `get_each_loss` and `critic_loss`.

    The loss is mean(TD**2)/10 + mean(|V - Returns|)/1000. TD is V(s) minus the
    bootstrapped target, and the target is 0 on terminal states. 'TD loss' and
    'MC loss' are written to `self.log`.

    Returns:
      (loss, TD), where TD is the per-sample TD error (it keeps the critic's graph).
    """
    is_last = sampled_realized['is_last']
    target = (sampled_realized['rewards'] + next_values.detach() * self.td_gamma) * (1-is_last)
    TD = values - target
    MC = (values - sampled_realized['Returns']).abs().mean()/1000
    td_sq_mean = (TD**2).mean()
    loss = td_sq_mean/10 + MC
    self.log['TD loss'] = td_sq_mean.item()/10
    self.log['MC loss'] = MC.item()
    return loss, TD

  def get_each_loss(self, sampled_batch,):
    """
    Computes losses for actor, critic and exploration (loss_ent) within PPO algorithm.
    Decisions were made to avoid python loops at all costs --
    varying action space is not particularly batch-friendly.

    Args:
      sampled_batch: tuple returned by `self.trajectories.sample_batch`. It carries an
        extra `Psi` entry (second position) when `self.trajectories.Psi` is set.

    Returns:
      (loss_a, loss_c, loss_ent): the clipped-surrogate actor loss, the critic loss and
      the negative normalised entropy.
    """
    Psi = None
    if self.trajectories.Psi is None:
      sampled_realized, queues_tensor, queue_block_ids_tensor, pad_mask, graphs_data, prev_probs_tensor = sampled_batch
    else:
      sampled_realized, Psi, queues_tensor, queue_block_ids_tensor, pad_mask, graphs_data, prev_probs_tensor = sampled_batch
    self.log['Returns mean'] = sampled_realized['Returns'].mean().item()
    self.log['rewards mean'] = sampled_realized['rewards'].mean().item()
    self.log['queue length max'] = sampled_realized['queue_lengths'].max().item()

    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()

    gnn_features = dict()
    if self.use_gnn:
      tr_ids = sampled_realized['tr_ids'].tolist()
      graph_ids = sampled_realized['graph_ids'].tolist()
      self.run_gnn(tr_ids, graph_ids, graphs_data, gnn_features)
      self.add_gnn_features_actor(tr_ids, graph_ids, queue_block_ids_tensor.tolist(), queues_tensor, gnn_features)
      self.add_gnn_features_critic(tr_ids, graph_ids, sampled_realized['features'],
                                   sampled_realized['features_next'], gnn_features)
    probs = self.actor(queues_tensor, mask=pad_mask)
    max_probs = probs.max(dim=-1).values
    self.log['max prob mean'] = max_probs.mean().item()
    self.log['40 max prob quantile'] = torch.quantile(max_probs, 0.4).item()

    end.record()
    maybe_sync()
    logger.timer['probs+add_gnn_feature'] += start.elapsed_time(end)/1000

    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()

    values = self.critic(sampled_realized['features']).squeeze(-1)
    with torch.no_grad():
      next_values = self.target_critic(sampled_realized['features_next']).squeeze(-1)

    end.record()
    maybe_sync()
    logger.timer['values'] += start.elapsed_time(end)/1000

    t_hist = time()
    self.log['V-func mean'] = torch.mean(values.detach()).item()
    self.log['V-func stdev'] = torch.std(values.detach()).item()
    logger.timer['hist'] += time() - t_hist

    # critic loss
    t_critic_loss = time()
    loss_c, TD = self._td_mc_loss(values, next_values, sampled_realized)
    logger.timer['critic loss'] += time()-t_critic_loss

    # entropy loss: per-state entropy divided by log(min(queue_length + 1, max_nactions + 1))
    t_entropy_loss = time()
    entropies = - probs * torch.log(torch.max(torch.tensor(1e-40), probs))
    entropy_regularizer = torch.log(torch.minimum(sampled_realized['queue_lengths']+1, torch.tensor(max_nactions+1))).to(device)
    entropy_by_state_reg = torch.sum(entropies, dim=-1) / entropy_regularizer
    loss_ent = -entropy_by_state_reg.mean()
    logger.timer['entropy loss'] += time()-t_entropy_loss

    # actor loss (PPO clipped surrogate)
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()

    chosen_ids = get_ass_dope_ids(sampled_realized['chosen_actions'])
    probs_chosen = probs[chosen_ids].squeeze(-1)
    prev_probs_chosen = prev_probs_tensor[chosen_ids].squeeze(-1)
    ratios = (probs_chosen / (prev_probs_chosen.detach()+1e-4)).to(device)

    clipped = torch.clip(ratios, min=1-self.clip_eps, max=1+self.clip_eps)
    is_last = sampled_realized['is_last']
    # detached so the actor loss does not push gradients into the critic
    Adv = - (TD * (1-is_last)).detach()
    # mean advantage over non-terminal states; the clamp avoids 0/0 when all are terminal
    n_non_terminal = (Adv.numel() - is_last.sum()).clamp(min=1)
    self.log['Adv mean'] = (Adv.sum() / n_non_terminal).item()
    weights = Adv if self.trajectories.Psi is None else Psi
    loss_a = - torch.min(ratios*weights, clipped*weights).mean()
    self.log['clipped to all'] = (clipped*Adv < ratios*Adv).long().sum().item() / ratios.numel()
    self.log['ratios mean'] = ratios.mean().item()
    end.record()
    maybe_sync()
    logger.timer['actor loss'] += start.elapsed_time(end)/1000
    return loss_a, loss_c, loss_ent

  def critic_loss(self,
               sampled_realized,
    ):
    """
    Temporal difference loss (plus a small MC term) on a batch of realized states.

    If GNN features are enabled, they are written into the batch features first.
    """
    if self.use_gnn:
      gnn_features = dict()
      graphs_data = self.trajectories.graphs_data
      tr_ids = sampled_realized['tr_ids'].tolist()
      graph_ids = sampled_realized['graph_ids'].tolist()
      self.run_gnn(tr_ids, graph_ids, graphs_data, gnn_features)
      self.add_gnn_features_critic(tr_ids, graph_ids, sampled_realized['features'],
                                   sampled_realized['features_next'], gnn_features)
    values = self.critic(sampled_realized['features']).squeeze(-1)
    with torch.no_grad():
      next_values = self.target_critic(sampled_realized['features_next']).squeeze(-1)
    self.log['V-func mean'] = torch.mean(values.detach()).item()
    critic_loss, _ = self._td_mc_loss(values, next_values, sampled_realized)
    return critic_loss

  def learn_v(self, n_batches=None):
    """
    Approximates V-function of previous policy by iterating over collected data.

    Each step samples 2048 realized states and takes one clipped critic optimizer step.
    """
    if n_batches is None:
      n_batches = self.n_batches
    for i in trange(n_batches):
      self.batch_number = i
      if self.batch_number % self.target_update_steps == 0:
        self.target_critic = copy.deepcopy(self.critic).eval()
      sampled_realized = self.trajectories.sample_realized_batch(n=2048)
      loss = self.critic_loss(sampled_realized)
      self.critic_opt.zero_grad()
      loss.backward()
      torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 30)
      self.critic_opt.step()
      if self.batch_number % (logger.between_logs+1) == 0:
        wandb.log(self.log)
        self.log = {}

  def learn_new_policy(self, prev_actor=None):
    """
    Implements one learning cycle over collected dataset.

    Runs `n_batches` PPO batches and steps the optimizers every `batch_accumulation_steps`
    batches. The actor LR scheduler is stepped once at the end.

    Args:
      prev_actor: frozen behaviour policy. It defaults to a copy of the current actor.
    """
    if prev_actor is None:
      self.prev_actor = copy.deepcopy(self.actor).eval()
    else:
      self.prev_actor = prev_actor

    logger.timer = {'probs+add_gnn_feature': 0,
                    'values': 0,
                    'hist': 0,
                    'critic loss': 0,
                    'entropy loss': 0,
                    'actor loss': 0,
                    'total loss': 0,
                    'optimizers step': 0,
                    'loss.backward': 0,
                    'sample batch': 0,
                    'sample ids': 0,
                    'sample inside loop': 0,
                    }
    # Start from clean gradients; e.g. `learn_v` leaves the critic's last gradients behind.
    self._zero_grads()
    for i in trange(self.n_batches):
      self.batch_number = i
      if self.batch_number % self.target_update_steps == 0:
        self.target_critic = copy.deepcopy(self.critic).eval()

      start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
      start.record()
      sampled_batch = self.trajectories.sample_batch(self.batch_size)
      end.record()
      maybe_sync()
      logger.timer['sample batch'] += start.elapsed_time(end)/1000

      start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
      start.record()
      loss_a, loss_c, loss_ent = self.get_each_loss(sampled_batch)
      loss = (loss_a + loss_c + loss_ent/100) / batch_accumulation_steps
      end.record()
      maybe_sync()
      logger.timer['total loss'] += start.elapsed_time(end)/1000

      start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
      start.record()
      loss.backward()
      end.record()
      maybe_sync()
      logger.timer['loss.backward'] += start.elapsed_time(end)/1000

      # step only after `batch_accumulation_steps` backward passes have been accumulated
      if (self.batch_number + 1) % batch_accumulation_steps == 0:
        t_optimizers_step = time()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 30)
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 5)
        self.actor_opt.step()
        self.critic_opt.step()
        if self.use_gnn:
          self.gnn_opt.step()
          self.log.update({
            'grad gnn': logger.list_norm([p.grad for p in self.gnn.parameters() if p.requires_grad], 2),
            'weight gnn': logger.list_norm([p for p in self.gnn.parameters() if p.requires_grad], 2),
          })
        self.log.update({
            'grad actor': logger.list_norm([p.grad for p in self.actor.parameters() if p.requires_grad], 2),
            'grad critic': logger.list_norm([p.grad for p in self.critic.parameters() if p.requires_grad], 2),
        })
        self._zero_grads()
        logger.timer['optimizers step'] += time()-t_optimizers_step

      if self.batch_number % (logger.between_logs+1) == 0:
        self.log.update({
            'loss actor': loss_a.item(),
            'loss critic': loss_c.item(),
            '-entropy by log(n)': loss_ent.item(),
            'weight actor': logger.list_norm([p for p in self.actor.parameters() if p.requires_grad], 2),
            'weight critic': logger.list_norm([p for p in self.critic.parameters() if p.requires_grad], 2),
        })
        wandb.log(self.log)
        self.log = {}
    self.actor_sched.step()

  @torch.no_grad()
  def compute_GAE(self, critic_to_use=None):
    """
    Computes GAE with passed/current critic.
    GAEs could be used instead of Advantages in PPO actor loss.

    One-step advantages A_i = r_i + td_gamma * V(s'_i) - V(s_i) are set to 0 on terminal
    states. They are accumulated backwards as
    Psi_i = A_i + gae_gamma * td_gamma * Psi_{i+1}, with Psi = 0 on terminal states, so
    the recursion restarts at every trajectory boundary. The result is stored in
    `self.trajectories.Psi`.
    """
    if critic_to_use is None:
      critic_to_use = copy.deepcopy(self.critic).eval()
    realized = self.trajectories.realized
    features = realized['features']
    features_next = realized['features_next']
    rewards = realized['rewards']
    is_last = realized['is_last']
    Adv1 = []
    for batch_ids in torch.arange(len(is_last)).split(4096):
      values = critic_to_use(features[batch_ids]).squeeze(-1)
      values_next = critic_to_use(features_next[batch_ids]).squeeze(-1)
      TD = values - ((rewards[batch_ids] + values_next * self.td_gamma) * (1-is_last[batch_ids]))
      Adv1.append(- TD * (1-is_last[batch_ids]))
    Adv1 = torch.cat(Adv1, dim=0)

    # The backward recursion is inherently sequential. It runs on CPU copies so that each
    # step does not force a device synchronisation.
    adv_cpu = Adv1.cpu()
    is_last_cpu = is_last.cpu().tolist()
    psi_cpu = torch.zeros_like(adv_cpu)
    n = len(adv_cpu)
    for i in range(n-1, -1, -1):
      if is_last_cpu[i]:
        continue
      if i + 1 < n:
        psi_cpu[i] = psi_cpu[i+1] * self.gae_gamma * self.td_gamma + adv_cpu[i]
      else:
        # the data ends mid-trajectory: treat the missing continuation as 0
        psi_cpu[i] = adv_cpu[i]
    Psi = psi_cpu.to(Adv1.device)
    self.trajectories.Psi = Psi
    wandb.log({'Psi mean': Psi.mean().item(),
               'Psi 95 perc': torch.quantile(Psi, 0.95).item(),
               'Adv mean': Adv1.mean().item(),
              })

<IPython.core.display.Javascript object>

In [7]:
class RNN_Trainer():
  """
  Learns state history representation via a proxy task of V-function approximation.

  An RNN cell is unrolled over windows of `loss_horizon` consecutive train states. It is
  regressed (weighted L1) onto the stored returns. Its hidden state and value output
  can then be appended to the state and action features.
  """
  def __init__(self, v_cell, optimizer, trajectories, loss_horizon=10, rnn_batch_size=32):
    """
    Args:
      v_cell: recurrent module called as `v, new_state = v_cell(x, state)`. It must expose
        `num_layers` and `hidden_size`.
      optimizer: optimizer over the parameters of `v_cell`.
      trajectories: data store exposing the `realized` dict and the `queues` list.
      loss_horizon: number of unrolled steps per training window.
      rnn_batch_size: currently unused; `train` takes its own batch size.
    """
    self.rnn_cell = v_cell
    self.optimizer = optimizer
    self.loss_horizon = loss_horizon
    self.retain_graph = True
    self.trjs = trajectories
    self.log = {}

  def sample_ids_pairs(self, features_by_trjs, rnn_batch_size, n_steps):
    """
    Returns [n_steps] list with [rnn_batch_size, 2] tensors
    of pairs (traj idx, position in traj for rnn to start from)

    Trajectories are drawn with probability proportional to log(len - loss_horizon).
    Start positions are uniform over [0, len - loss_horizon).

    Raises:
      ValueError: if `features_by_trjs` is empty.
    """
    n_trjs = len(features_by_trjs)
    if n_trjs == 0:
      raise ValueError('sample_ids_pairs: no trajectories to sample from')
    residual_lengths = torch.tensor([len(f)-self.loss_horizon for f in features_by_trjs])
    # numpy's choice checks sum(p) == 1 in float64; float32 torch weights may fail that check
    log_lengths = np.log(residual_lengths.numpy().astype(np.float64))
    sampling_weights = log_lengths / log_lengths.sum()
    trjs_ids = torch.tensor(np.random.choice(n_trjs, size=(n_steps, rnn_batch_size), p=sampling_weights))
    position_samples = torch.rand(n_steps, rnn_batch_size)
    residual_lengths_forall = torch.take(residual_lengths, trjs_ids)
    start_ids = (residual_lengths_forall*position_samples).long()
    stacked = torch.stack([trjs_ids, start_ids], dim=-1).to(device)
    return list(stacked)

  def train(self, n_steps=1000, rnn_batch_size=32):
    """
    Trains the RNN cell for `n_steps` optimizer steps.

    Each step uses `rnn_batch_size` windows of `loss_horizon` consecutive states. Only
    trajectories of length >= loss_horizon + 3 are used.
    """
    realized = self.trjs.realized
    is_last_ids = torch.nonzero(realized['is_last']).flatten()
    # trajectory lengths are the distances between consecutive terminal states
    trjs_lengths = is_last_ids - torch.cat([torch.tensor([-1]).to(is_last_ids.device), is_last_ids[:-1]])
    split_sizes = trjs_lengths.tolist()
    features_by_trjs = list(torch.split(realized['features'], split_sizes))
    Returns_by_trjs = list(torch.split(realized['Returns'], split_sizes))
    # we train only on long enough trajectories
    min_length = self.loss_horizon + 3
    keep = [i for i, length in enumerate(split_sizes) if length >= min_length]
    features_by_trjs = [features_by_trjs[i] for i in keep]
    Returns_by_trjs = [Returns_by_trjs[i] for i in keep]

    sampled_ids = self.sample_ids_pairs(features_by_trjs, rnn_batch_size, n_steps)
    feature_width = realized['features'].shape[-1]
    horizon_pow = torch.arange(self.loss_horizon)**1.4
    # later steps of the window weigh more (the first step has weight 0)
    loss_weights = (horizon_pow / horizon_pow.sum()).to(device)
    for batch_ids in sampled_ids:
      inputs = torch.zeros(rnn_batch_size, self.loss_horizon, feature_width).to(device)
      targets = torch.zeros(rnn_batch_size, self.loss_horizon, 1).to(device)
      for b, (trj_id, start) in enumerate(batch_ids.tolist()):
        inputs[b] = features_by_trjs[trj_id][start: start + self.loss_horizon]
        targets[b] = Returns_by_trjs[trj_id][start: start + self.loss_horizon][:, None]
      state = torch.zeros(2*self.rnn_cell.num_layers, rnn_batch_size, self.rnn_cell.hidden_size).to(device)
      Vs = []
      for rnn_step in range(self.loss_horizon):
        # feed the rnn_step-th state of every window: shape (batch, features)
        v, state = self.rnn_cell(inputs[:, rnn_step], state)
        Vs.append(v)
      # assumes v is (batch, 1); the assert below checks the resulting shape
      V = torch.stack(Vs, dim=1)
      abs_difs = (V - targets).abs()
      assert abs_difs.shape == (rnn_batch_size, self.loss_horizon, 1)
      loss = (abs_difs * loss_weights[:, None]).mean()
      self.log['rnn MC abs loss'] = loss.item()
      self.log['rnn V mean'] = V.mean().item()
      self.optimizer.zero_grad()
      loss.backward()
      self.optimizer.step()

  @torch.no_grad()
  def collect_rnn_features(self):
    """
    Collects rnn hidden and V for every state in train dataset.

    The cell runs sequentially over `realized['features']`, and its state is reset after
    each terminal state. The results are stored in `realized['rnn_features']` and
    `realized['rnn_v']` (one row per state).
    """
    features = self.trjs.realized['features']
    is_last = self.trjs.realized['is_last']
    hid_list = []
    v_list = []
    init_state = torch.zeros(2*self.rnn_cell.num_layers, 1, self.rnn_cell.hidden_size).to(device)
    state = init_state
    for i, f in enumerate(features):
      v, state = self.rnn_cell(f[None, :], state)
      # NOTE(review): index 2 of the packed state is taken as the hidden embedding; this
      # needs 2*num_layers > 2. TODO confirm the state layout of the cell.
      hid = state[2].squeeze(0)
      hid_list.append(hid.detach())
      # flattened so the stacked result is (N, 1) and concatenates with rnn_features
      v_list.append(v.detach().reshape(-1))
      if is_last[i]:
        state = init_state
    self.trjs.realized['rnn_features'] = torch.stack(hid_list, dim=0)
    self.trjs.realized['rnn_v'] = torch.stack(v_list, dim=0)

  @torch.no_grad()
  def concat_rnn_features(self):
    """
    Concatenates state history embedding to embedding of that state and to embeddings of corresponding actions.
    """
    realized = self.trjs.realized
    rnn_related = torch.cat([realized['rnn_features'], realized['rnn_v']], dim=-1)
    # NOTE(review): 'features_next' is not extended here, so the critic's next-state
    # input will have a different width. TODO decide how successor features get these columns.
    realized['features'] = torch.cat([realized['features'], rnn_related], dim=-1)
    queues = self.trjs.queues
    time_before = time()
    for i in range(len(queues)):
      queues[i] = torch.cat([queues[i], rnn_related[i].repeat(len(queues[i]), 1)], dim=-1)
    print('Concating rnn features to queues: ', time()-time_before)

<IPython.core.display.Javascript object>

### Utils

In [8]:
def get_ass_dope_ids(l):
  """
  Builds an advanced-indexing tuple that picks one entry per row along the last axis.

  `l` holds the chosen positions (1-D, or with a trailing axis of size 1). The result
  `ids` satisfies `x[ids][..., 0] == x[row, l[row]]` for a tensor `x`, so that
  `x[ids].squeeze(-1)` gathers the chosen entries.
  """
  chosen = l.long()
  if chosen.dim() == 1:
    chosen = chosen.unsqueeze(-1)
  index_grid = torch.as_tensor(np.indices(chosen.shape), dtype=torch.long)
  # the leading coordinates stay the identity grid; the last one becomes the choice
  index_grid[-1] = chosen
  return tuple(index_grid)

def consumption_percent(epoch):
  """Percentage of data the jar consumes per epoch; currently fixed at 50 for every epoch."""
  fixed_percent = 50
  return fixed_percent

def get_clip_eps(epoch):
  """
  PPO clipping epsilon schedule.

  Returns 0.5 for epochs below 5, 0.2 for epochs below 25, and 0.05 afterwards.
  """
  schedule = ((5, 0.5), (25, 0.2))
  for upper_bound, eps in schedule:
    if epoch < upper_bound:
      return eps
  return 0.05

<IPython.core.display.Javascript object>

### Procedure

In [9]:
# Pretrain

epochs = 100
train_condition, message = (lambda tr: tr[0]%5 != 2), '%5!=2'
checkpoint_path = '../Checkpoints/actor(Double), critic'
jar_config_path = '../Game_env/jar_config.txt'  # must match the path passed in `jar_command`

# Read the checkpoint file once and take both models from it.
# NOTE: torch.load unpickles whole nn.Module objects -- only load trusted checkpoints.
loaded_checkpoint = torch.load(checkpoint_path)
checkpoint_actor = loaded_checkpoint['actor']
checkpoint_critic = loaded_checkpoint['critic']


def write_jar_config(data_consumption):
  """
  Writes the JSON config read by the jar (launched via `jar_command`) before data gathering.

  Only `dataConsumption` changes between the initial gathering and the per-epoch runs.
  """
  config = {'postprocessing': 'None',
            'dataConsumption': data_consumption,
            'maxAttentionLength': max_nactions,
            'inputShape': [1, -1, features_dim],
            'defaultAlgorithm': 'BFS',
            "maxConcurrency": 64,
            'useGnn': use_gnn}
  with open(jar_config_path, 'w') as jar_config:
    jar_config.write(json.dumps(config))


actor = copy.deepcopy(checkpoint_actor).to(device).train()
actor_opt = torch.optim.AdamW(actor.parameters(), lr=2e-5, weight_decay=2e-3,)
actor_sched = torch.optim.lr_scheduler.LinearLR(actor_opt, start_factor=1, end_factor=0.1, total_iters=epochs, verbose=True)
critic = copy.deepcopy(checkpoint_critic).to(device).train()
critic_opt = torch.optim.AdamW(critic.parameters(), lr=2e-5, weight_decay=0.05)
gnn, gnn_opt = get_gnn_setup(gnn_in_nfeatures, gnn_out_nfeatures)

trajectories = Trajectories(train_condition=train_condition,
                            use_gnn=use_gnn,
                            gnn_out_nfeatures=gnn_out_nfeatures,
                            gnn_in_nfeatures=gnn_in_nfeatures)
run = wandb.init(
      project="PS",
      name='tune double no_gnn',
      config={
          'algorithm': 'PPO-clip',
          'models': 'mlp + attn',
      }
)

# Initial data gathering with the checkpoint policy, consuming all data.
write_jar_config(100)
trajectories.gather_n_store(actor_model=checkpoint_actor, gnn_model=gnn)
trajectories.evaluate_val_train()
print(trajectories.get_properties())

# actor, actor_opt, actor_sched = get_attn_setup(epochs=epochs, use_double=True)
# critic, critic_opt = get_mlp_setup(use_FFM=True)

for epoch in range(epochs):
  write_jar_config(consumption_percent(epoch))

  trainer = NN_Trainer(NN_setup={'gnn': gnn, 'gnn_opt': gnn_opt,
                                 'actor': actor, 'actor_opt': actor_opt, 'actor_sched': actor_sched,
                                 'critic': critic, 'critic_opt': critic_opt,},
                       trajectories=trajectories,
                       n_batches=250,
                       clip_eps=get_clip_eps(epoch),
                       use_gnn=use_gnn,
                       gnn_out_nfeatures=gnn_out_nfeatures,
                       )
  trainer.learn_new_policy()
  wandb.log({'epoch': epoch})
  trajectories.gather_n_store(actor_model=actor, gnn_model=gnn)
  trajectories.evaluate_val_train()
  print(trajectories.get_properties())

checkpoint = {
  'actor': actor,
  'critic': critic,
}
torch.save(checkpoint, os.path.join(wandb.run.dir, 'actor, critic'))
wandb.finish()

<IPython.core.display.Javascript object>

Adjusting learning rate of group 0 to 2.0000e-05.


NN data gathering time: 139.4162735939026
json loading time:  6.077685594558716
{'tr_prev_probs': 3.1756234169006348, 'graph_features': 0.0012946128845214844, 'tr_features': 0.02842998504638672, 'tr_graph_ids': 0.002512693405151367, 'tr_queues': 4.312460660934448, 'cat graph features': 5.0067901611328125e-06, 'graphs_data': 2.47955322265625e-05}
j2torch time:  10.298267841339111
{'traj_length mean, median, max': ('169.10', 31.0, 1501), 'queue max length, idx': (48, 78237), 'number of train states': 103120, 'number of traj-s': 788, 'number of validation traj-s': 170}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:40<00:00,  6.21it/s]


Adjusting learning rate of group 0 to 1.9820e-05.
NN data gathering time: 96.91763687133789
json loading time:  3.2417681217193604
{'tr_prev_probs': 1.5628864765167236, 'graph_features': 0.0006070137023925781, 'tr_features': 0.013328075408935547, 'tr_graph_ids': 0.00083160400390625, 'tr_queues': 2.2015445232391357, 'cat graph features': 5.9604644775390625e-06, 'graphs_data': 4.5299530029296875e-06}
j2torch time:  5.411961793899536
{'traj_length mean, median, max': ('169.28', 28.5, 1501), 'queue max length, idx': (42, 20090), 'number of train states': 50887, 'number of traj-s': 394, 'number of validation traj-s': 85}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:37<00:00,  6.59it/s]


Adjusting learning rate of group 0 to 1.9640e-05.
NN data gathering time: 98.65009045600891
json loading time:  4.268360614776611
{'tr_prev_probs': 2.3114044666290283, 'graph_features': 0.0009388923645019531, 'tr_features': 0.022204160690307617, 'tr_graph_ids': 0.0013439655303955078, 'tr_queues': 2.029927968978882, 'cat graph features': 7.62939453125e-06, 'graphs_data': 1.049041748046875e-05}
j2torch time:  6.068872451782227
{'traj_length mean, median, max': ('187.25', 31.0, 1501), 'queue max length, idx': (65, 39579), 'number of train states': 53440, 'number of traj-s': 387, 'number of validation traj-s': 92}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:43<00:00,  5.72it/s]


Adjusting learning rate of group 0 to 1.9460e-05.
NN data gathering time: 91.03038382530212
json loading time:  4.940057754516602
{'tr_prev_probs': 1.9303224086761475, 'graph_features': 0.0011477470397949219, 'tr_features': 0.01940131187438965, 'tr_graph_ids': 0.0014498233795166016, 'tr_queues': 2.23753023147583, 'cat graph features': 8.344650268554688e-06, 'graphs_data': 1.1920928955078125e-05}
j2torch time:  6.115119695663452
{'traj_length mean, median, max': ('216.58', 33.0, 1501), 'queue max length, idx': (51, 17236), 'number of train states': 58940, 'number of traj-s': 382, 'number of validation traj-s': 87}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:40<00:00,  6.22it/s]


Adjusting learning rate of group 0 to 1.9280e-05.
NN data gathering time: 94.44353818893433
json loading time:  5.506159782409668
{'tr_prev_probs': 1.6824486255645752, 'graph_features': 0.0008022785186767578, 'tr_features': 0.016710996627807617, 'tr_graph_ids': 0.0011818408966064453, 'tr_queues': 1.9804906845092773, 'cat graph features': 1.0013580322265625e-05, 'graphs_data': 1.049041748046875e-05}
j2torch time:  5.400115013122559
{'traj_length mean, median, max': ('186.63', 33.0, 1501), 'queue max length, idx': (72, 39454), 'number of train states': 53173, 'number of traj-s': 385, 'number of validation traj-s': 83}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:40<00:00,  6.18it/s]


Adjusting learning rate of group 0 to 1.9100e-05.
NN data gathering time: 96.78603339195251
json loading time:  4.197582721710205
{'tr_prev_probs': 1.9023914337158203, 'graph_features': 0.0011103153228759766, 'tr_features': 0.0191497802734375, 'tr_graph_ids': 0.001455545425415039, 'tr_queues': 2.181835174560547, 'cat graph features': 5.9604644775390625e-06, 'graphs_data': 5.7220458984375e-06}
j2torch time:  6.036898374557495
{'traj_length mean, median, max': ('207.24', 29.0, 1473), 'queue max length, idx': (44, 12773), 'number of train states': 58787, 'number of traj-s': 395, 'number of validation traj-s': 91}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:38<00:00,  6.48it/s]


Adjusting learning rate of group 0 to 1.8920e-05.
NN data gathering time: 96.71084570884705
json loading time:  4.300457954406738
{'tr_prev_probs': 1.6048972606658936, 'graph_features': 0.0007472038269042969, 'tr_features': 0.01572442054748535, 'tr_graph_ids': 0.0011725425720214844, 'tr_queues': 1.8683140277862549, 'cat graph features': 1.1205673217773438e-05, 'graphs_data': 1.0251998901367188e-05}
j2torch time:  5.108102560043335
{'traj_length mean, median, max': ('179.42', 26.0, 1501), 'queue max length, idx': (84, 38576), 'number of train states': 50677, 'number of traj-s': 388, 'number of validation traj-s': 84}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:44<00:00,  5.67it/s]


Adjusting learning rate of group 0 to 1.8740e-05.
NN data gathering time: 97.47627472877502
json loading time:  3.760350465774536
{'tr_prev_probs': 5.11362886428833, 'graph_features': 0.00106048583984375, 'tr_features': 0.01769423484802246, 'tr_graph_ids': 0.0013127326965332031, 'tr_queues': 7.372008323669434, 'cat graph features': 5.0067901611328125e-06, 'graphs_data': 4.291534423828125e-06}
j2torch time:  14.293853521347046
{'traj_length mean, median, max': ('175.12', 26.0, 1501), 'queue max length, idx': (48, 43514), 'number of train states': 57636, 'number of traj-s': 394, 'number of validation traj-s': 80}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [01:18<00:00,  3.20it/s]


Adjusting learning rate of group 0 to 1.8560e-05.
NN data gathering time: 87.77162981033325
json loading time:  3.9571685791015625
{'tr_prev_probs': 9.034733772277832, 'graph_features': 0.0011374950408935547, 'tr_features': 0.01820087432861328, 'tr_graph_ids': 0.001394510269165039, 'tr_queues': 9.997941732406616, 'cat graph features': 6.198883056640625e-06, 'graphs_data': 1.049041748046875e-05}
j2torch time:  20.998766899108887
{'traj_length mean, median, max': ('199.04', 32.0, 1344), 'queue max length, idx': (40, 46117), 'number of train states': 60459, 'number of traj-s': 399, 'number of validation traj-s': 82}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [01:16<00:00,  3.27it/s]


Adjusting learning rate of group 0 to 1.8380e-05.
NN data gathering time: 103.66037964820862
json loading time:  3.764714479446411
{'tr_prev_probs': 1.6602635383605957, 'graph_features': 0.0006389617919921875, 'tr_features': 0.015609264373779297, 'tr_graph_ids': 0.0011925697326660156, 'tr_queues': 1.8785905838012695, 'cat graph features': 5.9604644775390625e-06, 'graphs_data': 5.4836273193359375e-06}
j2torch time:  5.202850103378296
{'traj_length mean, median, max': ('182.33', 36.0, 1501), 'queue max length, idx': (49, 2959), 'number of train states': 54049, 'number of traj-s': 392, 'number of validation traj-s': 88}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:39<00:00,  6.29it/s]


Adjusting learning rate of group 0 to 1.8200e-05.
NN data gathering time: 101.9282820224762
json loading time:  4.455860137939453
{'tr_prev_probs': 1.5757029056549072, 'graph_features': 0.0007762908935546875, 'tr_features': 0.015395641326904297, 'tr_graph_ids': 0.0011866092681884766, 'tr_queues': 1.7718017101287842, 'cat graph features': 4.291534423828125e-06, 'graphs_data': 3.814697265625e-06}
j2torch time:  5.014542818069458
{'traj_length mean, median, max': ('191.82', 30.0, 1501), 'queue max length, idx': (52, 14997), 'number of train states': 50010, 'number of traj-s': 374, 'number of validation traj-s': 85}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:39<00:00,  6.30it/s]


Adjusting learning rate of group 0 to 1.8020e-05.
NN data gathering time: 112.91390657424927
json loading time:  3.806763172149658
{'tr_prev_probs': 1.602522850036621, 'graph_features': 0.0008635520935058594, 'tr_features': 0.01526498794555664, 'tr_graph_ids': 0.001107931137084961, 'tr_queues': 1.8454291820526123, 'cat graph features': 6.198883056640625e-06, 'graphs_data': 8.58306884765625e-06}
j2torch time:  5.048805475234985
{'traj_length mean, median, max': ('180.13', 31.0, 1501), 'queue max length, idx': (64, 43163), 'number of train states': 52045, 'number of traj-s': 384, 'number of validation traj-s': 88}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:39<00:00,  6.30it/s]


Adjusting learning rate of group 0 to 1.7840e-05.
NN data gathering time: 90.04598498344421
json loading time:  3.9676709175109863
{'tr_prev_probs': 1.800954818725586, 'graph_features': 0.0007832050323486328, 'tr_features': 0.01719522476196289, 'tr_graph_ids': 0.0012803077697753906, 'tr_queues': 2.0939414501190186, 'cat graph features': 5.0067901611328125e-06, 'graphs_data': 5.0067901611328125e-06}
j2torch time:  6.180996894836426
{'traj_length mean, median, max': ('188.29', 27.5, 1501), 'queue max length, idx': (59, 39305), 'number of train states': 58360, 'number of traj-s': 394, 'number of validation traj-s': 84}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:38<00:00,  6.50it/s]


Adjusting learning rate of group 0 to 1.7660e-05.
NN data gathering time: 99.60273265838623
json loading time:  5.195897817611694
{'tr_prev_probs': 1.831662654876709, 'graph_features': 0.0007834434509277344, 'tr_features': 0.017466306686401367, 'tr_graph_ids': 0.0012753009796142578, 'tr_queues': 2.7845561504364014, 'cat graph features': 2.4318695068359375e-05, 'graphs_data': 9.775161743164062e-06}
j2torch time:  6.475531101226807
{'traj_length mean, median, max': ('207.09', 28.0, 1501), 'queue max length, idx': (67, 19995), 'number of train states': 58674, 'number of traj-s': 387, 'number of validation traj-s': 92}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:41<00:00,  6.05it/s]


Adjusting learning rate of group 0 to 1.7480e-05.
NN data gathering time: 99.77912473678589
json loading time:  2.657055139541626
{'tr_prev_probs': 1.3577995300292969, 'graph_features': 0.0007417201995849609, 'tr_features': 0.0122833251953125, 'tr_graph_ids': 0.000949859619140625, 'tr_queues': 1.5210909843444824, 'cat graph features': 5.9604644775390625e-06, 'graphs_data': 5.245208740234375e-06}
j2torch time:  4.2996344566345215
{'traj_length mean, median, max': ('146.83', 32.0, 1420), 'queue max length, idx': (39, 8885), 'number of train states': 44104, 'number of traj-s': 386, 'number of validation traj-s': 92}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:37<00:00,  6.75it/s]


Adjusting learning rate of group 0 to 1.7300e-05.
NN data gathering time: 91.68265724182129
json loading time:  4.002163648605347
{'tr_prev_probs': 1.7837085723876953, 'graph_features': 0.0009870529174804688, 'tr_features': 0.017519235610961914, 'tr_graph_ids': 0.0011873245239257812, 'tr_queues': 2.1166000366210938, 'cat graph features': 5.245208740234375e-06, 'graphs_data': 9.298324584960938e-06}
j2torch time:  5.672453165054321
{'traj_length mean, median, max': ('173.46', 29.0, 1501), 'queue max length, idx': (62, 45438), 'number of train states': 56728, 'number of traj-s': 377, 'number of validation traj-s': 80}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:41<00:00,  6.07it/s]


Adjusting learning rate of group 0 to 1.7120e-05.
NN data gathering time: 85.66029906272888
json loading time:  3.216754913330078
{'tr_prev_probs': 1.831850290298462, 'graph_features': 0.0009720325469970703, 'tr_features': 0.016774415969848633, 'tr_graph_ids': 0.0012693405151367188, 'tr_queues': 2.378387451171875, 'cat graph features': 2.7418136596679688e-05, 'graphs_data': 9.5367431640625e-06}
j2torch time:  6.066020965576172
{'traj_length mean, median, max': ('182.07', 31.5, 1472), 'queue max length, idx': (33, 20050), 'number of train states': 58025, 'number of traj-s': 406, 'number of validation traj-s': 94}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:36<00:00,  6.85it/s]


Adjusting learning rate of group 0 to 1.6940e-05.
NN data gathering time: 98.01767683029175
json loading time:  3.3071131706237793
{'tr_prev_probs': 2.063119649887085, 'graph_features': 0.0007266998291015625, 'tr_features': 0.015488386154174805, 'tr_graph_ids': 0.0011744499206542969, 'tr_queues': 1.8206398487091064, 'cat graph features': 5.7220458984375e-06, 'graphs_data': 5.245208740234375e-06}
j2torch time:  5.509599447250366
{'traj_length mean, median, max': ('173.43', 33.0, 1501), 'queue max length, idx': (76, 37454), 'number of train states': 50409, 'number of traj-s': 382, 'number of validation traj-s': 85}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:42<00:00,  5.82it/s]


Adjusting learning rate of group 0 to 1.6760e-05.
NN data gathering time: 90.06292486190796
json loading time:  3.2592175006866455
{'tr_prev_probs': 1.6515681743621826, 'graph_features': 0.0005984306335449219, 'tr_features': 0.015193462371826172, 'tr_graph_ids': 0.0011911392211914062, 'tr_queues': 1.8352909088134766, 'cat graph features': 5.9604644775390625e-06, 'graphs_data': 4.76837158203125e-06}
j2torch time:  5.128601789474487
{'traj_length mean, median, max': ('181.06', 28.0, 1501), 'queue max length, idx': (31, 3718), 'number of train states': 54154, 'number of traj-s': 390, 'number of validation traj-s': 88}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:37<00:00,  6.59it/s]


Adjusting learning rate of group 0 to 1.6580e-05.
NN data gathering time: 99.36723351478577
json loading time:  3.5874202251434326
{'tr_prev_probs': 1.9889895915985107, 'graph_features': 0.0010306835174560547, 'tr_features': 0.017228126525878906, 'tr_graph_ids': 0.0012621879577636719, 'tr_queues': 2.3572640419006348, 'cat graph features': 5.7220458984375e-06, 'graphs_data': 8.821487426757812e-06}
j2torch time:  6.266549825668335
{'traj_length mean, median, max': ('167.77', 26.0, 1501), 'queue max length, idx': (64, 34580), 'number of train states': 50492, 'number of traj-s': 390, 'number of validation traj-s': 86}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:41<00:00,  5.98it/s]


Adjusting learning rate of group 0 to 1.6400e-05.
NN data gathering time: 97.91637325286865
json loading time:  4.064199209213257
{'tr_prev_probs': 2.4734601974487305, 'graph_features': 0.0007157325744628906, 'tr_features': 0.01978135108947754, 'tr_graph_ids': 0.0014514923095703125, 'tr_queues': 2.560220718383789, 'cat graph features': 5.7220458984375e-06, 'graphs_data': 3.5762786865234375e-06}
j2torch time:  7.111647844314575
{'traj_length mean, median, max': ('185.88', 31.0, 1387), 'queue max length, idx': (40, 47181), 'number of train states': 59621, 'number of traj-s': 391, 'number of validation traj-s': 85}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:51<00:00,  4.85it/s]


Adjusting learning rate of group 0 to 1.6220e-05.
NN data gathering time: 89.54605317115784
json loading time:  4.3928704261779785
{'tr_prev_probs': 2.2126007080078125, 'graph_features': 0.0008566379547119141, 'tr_features': 0.016707420349121094, 'tr_graph_ids': 0.001169443130493164, 'tr_queues': 3.1738641262054443, 'cat graph features': 5.9604644775390625e-06, 'graphs_data': 1.0728836059570312e-05}
j2torch time:  7.278268337249756
{'traj_length mean, median, max': ('171.71', 24.0, 1501), 'queue max length, idx': (58, 38418), 'number of train states': 47560, 'number of traj-s': 378, 'number of validation traj-s': 86}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:55<00:00,  4.51it/s]


Adjusting learning rate of group 0 to 1.6040e-05.
NN data gathering time: 98.24895095825195
json loading time:  3.883312940597534
{'tr_prev_probs': 1.7720491886138916, 'graph_features': 0.0008058547973632812, 'tr_features': 0.017233848571777344, 'tr_graph_ids': 0.0012660026550292969, 'tr_queues': 2.5211069583892822, 'cat graph features': 2.4557113647460938e-05, 'graphs_data': 7.152557373046875e-06}
j2torch time:  6.113448858261108
{'traj_length mean, median, max': ('184.88', 29.0, 1501), 'queue max length, idx': (53, 6364), 'number of train states': 56473, 'number of traj-s': 373, 'number of validation traj-s': 78}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:39<00:00,  6.29it/s]


Adjusting learning rate of group 0 to 1.5860e-05.
NN data gathering time: 104.43959712982178
json loading time:  4.5773162841796875
{'tr_prev_probs': 1.6445131301879883, 'graph_features': 0.0006575584411621094, 'tr_features': 0.01661229133605957, 'tr_graph_ids': 0.0012598037719726562, 'tr_queues': 2.5658953189849854, 'cat graph features': 6.198883056640625e-06, 'graphs_data': 4.291534423828125e-06}
j2torch time:  5.858936071395874
{'traj_length mean, median, max': ('183.33', 24.0, 1501), 'queue max length, idx': (74, 38635), 'number of train states': 52297, 'number of traj-s': 373, 'number of validation traj-s': 84}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:43<00:00,  5.78it/s]


Adjusting learning rate of group 0 to 1.5680e-05.
NN data gathering time: 91.84328484535217
json loading time:  2.819647789001465
{'tr_prev_probs': 1.6012725830078125, 'graph_features': 0.0007703304290771484, 'tr_features': 0.015758991241455078, 'tr_graph_ids': 0.0012736320495605469, 'tr_queues': 2.233349561691284, 'cat graph features': 6.9141387939453125e-06, 'graphs_data': 5.4836273193359375e-06}
j2torch time:  5.496665954589844
{'traj_length mean, median, max': ('167.41', 28.0, 1401), 'queue max length, idx': (49, 41206), 'number of train states': 51312, 'number of traj-s': 381, 'number of validation traj-s': 78}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:40<00:00,  6.16it/s]


Adjusting learning rate of group 0 to 1.5500e-05.
NN data gathering time: 93.66808795928955
json loading time:  3.025022029876709
{'tr_prev_probs': 1.536731481552124, 'graph_features': 0.0009644031524658203, 'tr_features': 0.015108823776245117, 'tr_graph_ids': 0.0011973381042480469, 'tr_queues': 1.7327001094818115, 'cat graph features': 6.198883056640625e-06, 'graphs_data': 9.059906005859375e-06}
j2torch time:  4.877683877944946
{'traj_length mean, median, max': ('159.18', 34.0, 1501), 'queue max length, idx': (38, 19580), 'number of train states': 47535, 'number of traj-s': 389, 'number of validation traj-s': 91}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:39<00:00,  6.40it/s]


Adjusting learning rate of group 0 to 1.5320e-05.
NN data gathering time: 89.05420637130737
json loading time:  5.146982431411743
{'tr_prev_probs': 1.8979394435882568, 'graph_features': 0.0007734298706054688, 'tr_features': 0.018247365951538086, 'tr_graph_ids': 0.0013849735260009766, 'tr_queues': 2.2173848152160645, 'cat graph features': 6.198883056640625e-06, 'graphs_data': 5.245208740234375e-06}
j2torch time:  6.014786005020142
{'traj_length mean, median, max': ('198.54', 29.0, 1501), 'queue max length, idx': (67, 42981), 'number of train states': 59561, 'number of traj-s': 381, 'number of validation traj-s': 90}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:40<00:00,  6.14it/s]


Adjusting learning rate of group 0 to 1.5140e-05.
NN data gathering time: 98.63908648490906
json loading time:  3.9301228523254395
{'tr_prev_probs': 1.5191614627838135, 'graph_features': 0.0006275177001953125, 'tr_features': 0.01449894905090332, 'tr_graph_ids': 0.001089334487915039, 'tr_queues': 2.221466064453125, 'cat graph features': 4.76837158203125e-06, 'graphs_data': 4.76837158203125e-06}
j2torch time:  5.269061803817749
{'traj_length mean, median, max': ('164.46', 28.5, 1501), 'queue max length, idx': (50, 8293), 'number of train states': 49203, 'number of traj-s': 398, 'number of validation traj-s': 87}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:38<00:00,  6.43it/s]


Adjusting learning rate of group 0 to 1.4960e-05.
NN data gathering time: 100.04251098632812
json loading time:  4.108374118804932
{'tr_prev_probs': 1.4839041233062744, 'graph_features': 0.0006630420684814453, 'tr_features': 0.01447153091430664, 'tr_graph_ids': 0.0011529922485351562, 'tr_queues': 1.7469539642333984, 'cat graph features': 5.0067901611328125e-06, 'graphs_data': 8.344650268554688e-06}
j2torch time:  4.735868692398071
{'traj_length mean, median, max': ('144.15', 26.0, 1501), 'queue max length, idx': (71, 32767), 'number of train states': 47281, 'number of traj-s': 384, 'number of validation traj-s': 78}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:41<00:00,  5.97it/s]


Adjusting learning rate of group 0 to 1.4780e-05.
NN data gathering time: 95.26145362854004
json loading time:  3.6900806427001953
{'tr_prev_probs': 1.5481030941009521, 'graph_features': 0.0006935596466064453, 'tr_features': 0.014743566513061523, 'tr_graph_ids': 0.0012111663818359375, 'tr_queues': 1.780893087387085, 'cat graph features': 5.7220458984375e-06, 'graphs_data': 5.0067901611328125e-06}
j2torch time:  4.916708707809448
{'traj_length mean, median, max': ('179.53', 28.0, 1501), 'queue max length, idx': (51, 1625), 'number of train states': 49612, 'number of traj-s': 386, 'number of validation traj-s': 82}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:36<00:00,  6.81it/s]


Adjusting learning rate of group 0 to 1.4600e-05.
NN data gathering time: 86.17546224594116
json loading time:  3.407054901123047
{'tr_prev_probs': 1.7525634765625, 'graph_features': 0.0007398128509521484, 'tr_features': 0.01697850227355957, 'tr_graph_ids': 0.0013592243194580078, 'tr_queues': 2.459259033203125, 'cat graph features': 1.2874603271484375e-05, 'graphs_data': 8.106231689453125e-06}
j2torch time:  5.972339153289795
{'traj_length mean, median, max': ('169.34', 29.5, 1501), 'queue max length, idx': (42, 24722), 'number of train states': 56308, 'number of traj-s': 400, 'number of validation traj-s': 80}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:35<00:00,  7.01it/s]


Adjusting learning rate of group 0 to 1.4420e-05.
NN data gathering time: 84.78720426559448
json loading time:  3.7100272178649902
{'tr_prev_probs': 1.7542977333068848, 'graph_features': 0.0010235309600830078, 'tr_features': 0.016382455825805664, 'tr_graph_ids': 0.0012412071228027344, 'tr_queues': 2.02129864692688, 'cat graph features': 8.344650268554688e-06, 'graphs_data': 1.1682510375976562e-05}
j2torch time:  5.549696445465088
{'traj_length mean, median, max': ('175.68', 29.5, 1501), 'queue max length, idx': (65, 40052), 'number of train states': 55205, 'number of traj-s': 394, 'number of validation traj-s': 82}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:40<00:00,  6.17it/s]


Adjusting learning rate of group 0 to 1.4240e-05.
NN data gathering time: 90.91972875595093
json loading time:  3.9110844135284424
{'tr_prev_probs': 1.6831626892089844, 'graph_features': 0.0006420612335205078, 'tr_features': 0.015718698501586914, 'tr_graph_ids': 0.0012555122375488281, 'tr_queues': 1.9239435195922852, 'cat graph features': 6.198883056640625e-06, 'graphs_data': 5.245208740234375e-06}
j2torch time:  5.867390871047974
{'traj_length mean, median, max': ('181.31', 27.0, 1348), 'queue max length, idx': (44, 44196), 'number of train states': 53895, 'number of traj-s': 393, 'number of validation traj-s': 88}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:37<00:00,  6.66it/s]


Adjusting learning rate of group 0 to 1.4060e-05.
NN data gathering time: 96.38421869277954
json loading time:  3.4093616008758545
{'tr_prev_probs': 1.5733697414398193, 'graph_features': 0.0006222724914550781, 'tr_features': 0.014539957046508789, 'tr_graph_ids': 0.0011663436889648438, 'tr_queues': 1.7721717357635498, 'cat graph features': 5.245208740234375e-06, 'graphs_data': 5.245208740234375e-06}
j2torch time:  5.018129825592041
{'traj_length mean, median, max': ('169.60', 25.0, 1501), 'queue max length, idx': (46, 33960), 'number of train states': 51678, 'number of traj-s': 382, 'number of validation traj-s': 82}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:36<00:00,  6.80it/s]


Adjusting learning rate of group 0 to 1.3880e-05.
NN data gathering time: 77.2356345653534
json loading time:  2.505060911178589
{'tr_prev_probs': 1.3271539211273193, 'graph_features': 0.0007164478302001953, 'tr_features': 0.013369083404541016, 'tr_graph_ids': 0.0011088848114013672, 'tr_queues': 1.4753475189208984, 'cat graph features': 4.76837158203125e-06, 'graphs_data': 5.7220458984375e-06}
j2torch time:  4.244013786315918
{'traj_length mean, median, max': ('128.15', 27.0, 1501), 'queue max length, idx': (27, 38808), 'number of train states': 41411, 'number of traj-s': 404, 'number of validation traj-s': 79}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:35<00:00,  7.05it/s]


Adjusting learning rate of group 0 to 1.3700e-05.
NN data gathering time: 91.68827652931213
json loading time:  3.325035810470581
{'tr_prev_probs': 2.2456579208374023, 'graph_features': 0.0008215904235839844, 'tr_features': 0.01660323143005371, 'tr_graph_ids': 0.0012972354888916016, 'tr_queues': 2.00858211517334, 'cat graph features': 5.9604644775390625e-06, 'graphs_data': 8.344650268554688e-06}
j2torch time:  6.039960861206055
{'traj_length mean, median, max': ('184.56', 30.0, 1459), 'queue max length, idx': (53, 36361), 'number of train states': 52811, 'number of traj-s': 383, 'number of validation traj-s': 89}


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 250/250 [00:38<00:00,  6.51it/s]


Adjusting learning rate of group 0 to 1.3520e-05.
NN data gathering time: 61.40294647216797



KeyboardInterrupt



In [None]:
# # No pretrain

# epochs = 150
# train_condition, message = (lambda tr: tr[0]%5 != 2), '%5!=2'

# for _ in [0]:
#     trajectories = Trajectories(train_condition=train_condition,
#                                 use_gnn=use_gnn,
#                                 gnn_out_nfeatures=gnn_out_nfeatures,
#                                 gnn_in_nfeatures=gnn_in_nfeatures)
#     run = wandb.init(
#           project="PS",
#           name=f'double_att, no_gnn, val {message}',
#           config={
#               'algorithm': 'PPO-clip',
#               'models': 'mlp + attn',
#           }
#     )
#     with open('../Game_env/jar_config.txt', 'w') as jar_config:
#       jar_config.write(json.dumps({'postprocessing': 'None', 
#                                    'dataConsumption': 100,
#                                    'maxAttentionLength': max_nactions,
#                                    'inputShape': [1, -1, features_dim],
#                                    'defaultAlgorithm': 'BFS',
#                                    "maxConcurrency": 120,
#                                    'useGnn': use_gnn}))    
#     trajectories.gather_n_store()
#     trajectories.evaluate_val_train()
#     print(trajectories.get_properties())

#     gnn, gnn_opt = get_gnn_setup(gnn_in_nfeatures, gnn_out_nfeatures)
#     actor, actor_opt, actor_sched = get_attn_setup(epochs=epochs, use_double=False)
#     critic, critic_opt = get_mlp_setup(use_FFM=True)

#     with open('../Game_env/jar_config.txt', 'w') as jar_config:
#       jar_config.write(json.dumps({'postprocessing': 'None', 
#                                    'dataConsumption': consumption_percent(0), # consumption_percent(0),
#                                    'maxAttentionLength': max_nactions,
#                                    'inputShape': [1, -1, features_dim],
#                                    "maxConcurrency": 120,
#                                    'useGnn': use_gnn}))
#     trajectories.gather_n_store(actor_model=actor, gnn_model=gnn)
#     trajectories.evaluate_val_train()
#     print(trajectories.get_properties())

#     for epoch in range(epochs):  
#         with open('../Game_env/jar_config.txt', 'w') as jar_config:
#           jar_config.write(json.dumps({'postprocessing': 'None', 
#                                        'dataConsumption': consumption_percent(epoch),
#                                        'maxAttentionLength': max_nactions,
#                                        "maxConcurrency": 120,
#                                        'inputShape': [1, -1, len(trajectories.feature_names) + (gnn_out_nfeatures if use_gnn else 0)],
#                                        'useGnn': use_gnn}))

#         trainer = NN_Trainer(NN_setup={'gnn': gnn, 'gnn_opt': gnn_opt,
#                                        'actor': actor, 'actor_opt': actor_opt, 'actor_sched': actor_sched,
#                                        'critic': critic, 'critic_opt': critic_opt,},
#                              trajectories=trajectories,
#                              n_batches=800,
#                              clip_eps = get_clip_eps(epoch),
#                              use_gnn=use_gnn,
#                              gnn_out_nfeatures=gnn_out_nfeatures,
#                              )
#         trainer.learn_new_policy()
#         wandb.log({'epoch': epoch,})
#         trajectories.gather_n_store(actor_model=actor, gnn_model=gnn)
#         trajectories.evaluate_val_train()
#         print(trajectories.get_properties())

#     checkpoint = {
#       # 'gnn': gnn,
#       'actor': actor,
#       'critic':critic,
#     }
#     torch.save(checkpoint, os.path.join(wandb.run.dir, f'actor, critic'))
#     wandb.finish()

In [None]:
# # Learn V

# train_condition, message = (lambda tr: 'com.' in tr[2]), 'com.'
# epochs=1

# trajectories = Trajectories(train_condition=train_condition,
#                             use_gnn=use_gnn,
#                             gnn_out_nfeatures=gnn_out_nfeatures,
#                             gnn_in_nfeatures=gnn_in_nfeatures)    
# # run = wandb.init(
# #       project="delete 3",
# #       name=f'Critic play, val {message}',
# #       config={
# #           'algorithm': 'PPO-clip',
# #           'models': 'mlp + attn',
# #       }
# # )

# actor, actor_opt, actor_sched = get_attn_setup(epochs=epochs, use_double=False)
# critic, critic_opt = get_mlp_setup(use_FFM=True)
# gnn, gnn_opt = get_gnn_setup(gnn_in_nfeatures, gnn_out_nfeatures)

# with open('../Game_env/jar_config.txt', 'w') as jar_config:
#   jar_config.write(json.dumps({'postprocessing': 'None', 
#                                'dataConsumption': consumption_percent(100),
#                                'maxAttentionLength': max_nactions,
#                                'inputShape': [1, -1, features_dim],
#                                'useGnn': use_gnn}))

# # trajectories.gather_n_store(actor_model=actor, gnn_model=gnn)
# trajectories.store_from_json()
# trajectories.evaluate_val_train()
# print(trajectories.get_properties())

# trainer = NN_Trainer(NN_setup={'gnn': gnn, 'gnn_opt': gnn_opt,
#                                    'actor': actor, 'actor_opt': actor_opt, 'actor_sched': actor_sched,
#                                    'critic': critic, 'critic_opt': critic_opt,},
#                          trajectories=trajectories,
#                          n_batches=int(2500 * consumption_percent(100)/100),
#                          clip_eps = get_clip_eps(100),
#                          use_gnn=use_gnn,
#                          gnn_out_nfeatures=gnn_out_nfeatures,
#                          )
# trainer.learn_v(5000)

In [None]:
# # Critic Play

# epochs = 50
# train_condition, message = (lambda tr: 'com.' in tr[2]), 'com.'

# for _ in [0]:
#     trajectories = Trajectories(train_condition=train_condition,
#                                 use_gnn=use_gnn,
#                                 gnn_out_nfeatures=gnn_out_nfeatures,
#                                 gnn_in_nfeatures=gnn_in_nfeatures)    
#     run = wandb.init(
#           project="delete 3",
#           name=f'Critic play, val {message}',
#           config={
#               'algorithm': 'PPO-clip',
#               'models': 'mlp + attn',
#           }
#     )
    
#     with open('../Game_env/jar_config.txt', 'w') as jar_config:
#       jar_config.write(json.dumps({'postprocessing': 'None', 
#                                    'dataConsumption': consumption_percent(100),
#                                    'maxAttentionLength': -1,
#                                    'inputShape': [-1, features_dim],
#                                    'defaultAlgorithm': 'ForkDepthRandom',
#                                    'useGnn': use_gnn}))
#     trajectories.gather_n_store()
#     # trajectories.store_from_json()
#     trajectories.evaluate_val_train()
#     print(trajectories.get_properties())

#     actor, actor_opt, actor_sched = get_attn_setup(epochs=epochs, use_double=False)
#     critic, critic_opt = get_mlp_setup(use_FFM=True)
#     q_net = Q_Net(V_function=critic, reward_ind=trajectories.j_file['scheme'][0].index('logReward'))
#     gnn, gnn_opt = get_gnn_setup(gnn_in_nfeatures, gnn_out_nfeatures)
    
#     with open('../Game_env/jar_config.txt', 'w') as jar_config:
#       jar_config.write(json.dumps({'postprocessing': 'None', 
#                                    'dataConsumption': consumption_percent(100),
#                                    'maxAttentionLength': max_nactions,
#                                    'inputShape': [1, -1, features_dim],
#                                    'useGnn': use_gnn}))
    
#     trajectories.gather_n_store(actor_model=actor, gnn_model=gnn)
#     trajectories.evaluate_val_train()
#     print(trajectories.get_properties())

#     for epoch in range(epochs):  
#         with open('../Game_env/jar_config.txt', 'w') as jar_config:
#           jar_config.write(json.dumps({'postprocessing': 'Argmax', 
#                                        'dataConsumption': consumption_percent(epoch),
#                                        'maxAttentionLength': max_nactions,
#                                        'inputShape': [-1, features_dim],
#                                        'useGnn': use_gnn}))
        
#         trainer = NN_Trainer(NN_setup={'gnn': gnn, 'gnn_opt': gnn_opt,
#                                        'actor': actor, 'actor_opt': actor_opt, 'actor_sched': actor_sched,
#                                        'critic': critic, 'critic_opt': critic_opt,},
#                              trajectories=trajectories,
#                              n_batches=int(2500 * consumption_percent(epoch)/100),
#                              clip_eps = get_clip_eps(epoch),
#                              use_gnn=use_gnn,
#                              gnn_out_nfeatures=gnn_out_nfeatures,
#                              )
#         trainer.learn_new_policy()
#         wandb.log({'epoch': epoch})
#         trajectories.gather_n_store(actor_model=q_net, gnn_model=gnn)
#         trajectories.evaluate_val_train()
#         print(trajectories.get_properties())

#     checkpoint = {
#       'gnn': gnn,
#       'actor': actor,
#       'critic':critic,
#     }
#     # torch.save(checkpoint, os.path.join(wandb.run.dir, f'actor, critic'))
#     wandb.finish()

In [10]:
# Terminate the session after the experiments above have finished.
# NOTE(review): in an IPython/Jupyter kernel, exit() shuts the kernel down
# (the wandb upload output below is emitted during that shutdown). As a result,
# under "Restart & Run All" no cell after this one (e.g. the "Data filter"
# section) will execute. Presumably this is meant to free the GPU selected via
# CUDA_VISIBLE_DEVICES once training is done — confirm before relying on it.
exit()

<IPython.core.display.Javascript object>

<IPython.core.display.HTML object>
VBox(children=(Label(value='0.014 MB of 0.014 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max=1.0)))
<IPython.core.display.HTML object>
<IPython.core.display.HTML object>
<IPython.core.display.HTML object>


### Data filter

In [None]:
# train_condition, message = (lambda tr: True), 'wtf'
# with open('../Game_env/jar_config.txt', 'w') as jar_config:
#       jar_config.write(json.dumps({'postprocessing': 'None', 
#                                    'dataConsumption': 100,
#                                    'maxAttentionLength': max_nactions,
#                                    'inputShape': [1, -1, features_dim],
#                                    'useGnn': use_gnn}))
# trajectories = Trajectories(train_condition=train_condition,
#                                 use_gnn=use_gnn,
#                                 gnn_out_nfeatures=gnn_out_nfeatures,
#                                 gnn_in_nfeatures=gnn_in_nfeatures)
# trajectories.store_from_json()

In [None]:
# # open('../Game_env/blacklist.txt', 'w').close()

# print(trajectories.get_properties())
# paths = trajectories.j_file['paths']

# import matplotlib.pyplot as plt
# lengths = [len(paths[i][1]) for i in range(len(paths))]
# lengths_t = torch.Tensor(lengths)
# print([lengths_t.quantile(10*i/100) for i in range(10)])
# # plt.hist([l for l in lengths if (l>0 and l<1502)], bins=500)
# # plt.show()

# max_queue_len = [max([len(q[0]) for q in paths[i][1]]) for i in range(len(paths))]
# max_queue_len_t = torch.Tensor(max_queue_len)
# print([max_queue_len_t.quantile(10*i/100) for i in range(11)])
# mean_queue_len = [torch.Tensor([len(q[0]) for q in paths[i][1]]).mean() for i in range(len(paths))]
# mean_queue_len_t = torch.Tensor(mean_queue_len)
# print([mean_queue_len_t.quantile(10*i/100) for i in range(11)])

# print(lengths_t[torch.nonzero(mean_queue_len_t<=1.1).squeeze()].sum())

# ids = torch.nonzero((mean_queue_len_t<=1.1)).long().squeeze()
# print(len(ids))
# bad_paths = [paths[i] for i in ids]
# [bad_paths[i][2] for i in range(len(bad_paths))][:10]

# with open('../Game_env/blacklist.txt', 'a') as f:
#   for i in range(len(ids)):
#     s = bad_paths[i][2]
#     f.write(s + '\n')

### The end