This is an implementation of PPO-clip for path selection in symbolic execution.
Each epoch we communicate with a jar file for data gathering and with wandb for logging.

### Imports, meta

In [1]:
# %%capture

%env CUDA_DEVICE_ORDER=PCI_BUS_ID

from IPython.display import Javascript
def resize_colab_cell():
  """Displays a Javascript snippet that sets the Colab output iframe height with maxHeight 600."""
  display(Javascript('google.colab.output.setIframeHeight(0, true, {maxHeight: 600})'))
# Register the resize hook so it fires before every cell run.
get_ipython().events.register('pre_run_cell', resize_colab_cell)

import numpy as np
from numpy import random
import copy
import inspect
import torch
from torch import nn
import torch.onnx
import json
from tqdm import tqdm, trange
from time import time
import os
import sklearn
from sklearn import tree
import math
from operator import itemgetter 


# !pip install wandb
import wandb

# !pip install onnx==1.12
# import onnx

wandb.login()

env: CUDA_DEVICE_ORDER=PCI_BUS_ID


[34m[1mwandb[0m: Currently logged in as: [33mandrey_podivilov[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

### Args (potentially immutable), login

In [2]:
batch_size = 256  # number of states sampled per training batch
max_nactions = 128  # cap on the padded queue length; also written to the jar config as 'maxAttentionLength'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
td_gamma=0.99  # discount factor used for returns and TD targets
json_path = '../Data/current_dataset.json'  # dataset file loaded by Trajectories.store_from_json
use_fork_discount = False  # if True, discount exponents are read from the 'is_cfg_fork' entry instead of being 1
batch_accumulation_steps = 1  # number of batches drawn per sorted sampling group / optimizer-step period
cuda_sync = False  # if True, torch.cuda.synchronize is called after every timed section

# Either a real CUDA synchronize or a no-op, depending on cuda_sync.
maybe_sync = torch.cuda.synchronize if cuda_sync else (lambda *args: None)
# Shell command that launches the jar with ../Game_env/jar_config.txt and redirects its stdout to jar_log.txt.
jar_command = '/home/st-andrey-podivilov/java16/usr/lib/jvm/bellsoft-java16-amd64/bin/java -Dorg.jooq.no-logo=true -jar ../Game_env/usvm-jvm/build/libs/usvm-jvm-new.jar ../Game_env/jar_config.txt > ../Game_env/jar_log.txt'
device

<IPython.core.display.Javascript object>

'cuda'

### Models, modules

In [3]:
class FFM_layer(torch.nn.Module):
    """
    Fourier-feature style layer.

    Appends the element-wise sine and cosine of the input to the input itself,
    so the last dimension of the output is three times that of the input.
    A frozen, randomly initialised square projection (`fourier_matrix`) is
    created as well, but `forward` currently does not apply it.
    """
    def __init__(self, input_dim):
        super().__init__()
        assert input_dim%2 == 0, 'even input_dim is more convenient'
        projection = torch.nn.Linear(input_dim, int(input_dim), bias=False)
        nn.init.normal_(projection.weight, std=1/np.sqrt(input_dim))
        # the projection is random and never trained
        projection.weight.requires_grad_(False)
        self.fourier_matrix = projection

    def forward(self, x):
        # sin/cos are taken of the raw input; self.fourier_matrix is not applied here
        parts = [x, torch.sin(x), torch.cos(x)]
        return torch.cat(parts, dim=-1)
    
  
class Attn_model(torch.nn.Module):
  """
  Attention-based policy over a queue of candidate states (actions).

  Every queue element is embedded, passed through a single
  TransformerEncoderLayer, reduced to one scalar score, and the scores are
  soft-maxed over the queue dimension to give action probabilities.
  """
  def __init__(
    self,
    d_model=256,
    n_heads=8,
    dim_feedforward=256, 
    dropout=0.0,    
    use_FFM=True,
  ):
    super(Attn_model, self).__init__()
    # Per-element embedding; the LazyLinear input widths are inferred on the first forward pass.
    self.emb = nn.Sequential(
      nn.LazyLinear(512),
      nn.ReLU(),
      nn.LayerNorm(512),
      FFM_layer(512) if use_FFM else nn.Identity(),
      nn.LazyLinear(d_model),
      nn.ReLU(),
    )
    self.attn_EncoderLayer0 = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward, dropout, batch_first=True)
    # self.attn_EncoderLayer1 = nn.TransformerEncoderLayer(d_model, n_heads, dim_feedforward, dropout, batch_first=True)
    # One scalar score per queue element.
    self.head = nn.Sequential(
      nn.LazyLinear(1),
    )
    self.sfmax = nn.Softmax(dim=-1)
    
  def forward(self, x, mask=None):
    """
    Args:
      x: queue features; presumably (batch, n_actions, n_features), matching the
         dynamic axes used for the ONNX export -- TODO confirm.
      mask: optional boolean padding mask, True at padded positions. It is passed
         to the encoder as src_key_padding_mask and padded scores are set to -inf.
    Returns:
      Softmax probabilities over the last dimension (one per queue element);
      masked positions receive probability 0.
    NOTE(review): a row whose positions are all masked would softmax over only -inf
    values and yield NaNs.
    """
    x = self.emb(x)
    x = self.attn_EncoderLayer0(x, src_key_padding_mask=mask)
    # x = self.attn_EncoderLayer1(x, src_key_padding_mask=mask)
    x = self.head(x).squeeze(-1)
    if mask is None:
      return self.sfmax(x)
    # 0 for real elements, -inf for padded ones; added to the scores before the softmax
    inf_mask = mask.float().masked_fill(mask==True, -float('inf'))    
    x = self.sfmax(x + inf_mask)
    return x
    
    
def get_mlp_setup(use_FFM=False,
                    lr = 3e-4,
                    wd=0.1,
                    ):
    """
    Builds the MLP (scalar output) together with its AdamW optimizer.

    Args:
      use_FFM: insert an FFM_layer after the LayerNorm instead of an Identity.
      lr: AdamW learning rate.
      wd: AdamW weight decay.
    Returns:
      (mlp, mlp_opt) with the network already moved to `device`.
    """
    # Layers are created in the same order as they are applied.
    layers = [
        nn.LazyLinear(512),
        nn.ReLU(),
        nn.Linear(512,256),
        nn.LayerNorm(256),
        FFM_layer(256) if use_FFM else nn.Identity(),
        nn.LazyLinear(512),
        nn.ReLU(),
        nn.Linear(512,512),
        nn.ReLU(),
        nn.Linear(512,1),
    ]
    mlp = nn.Sequential(*layers)
    mlp = mlp.to(device)
    mlp_opt = torch.optim.AdamW(
        mlp.parameters(),
        lr=lr,
        weight_decay=wd,
        betas=(0.9, 0.999),
    )
    return mlp, mlp_opt
  
  
def get_attn_setup():
  """Creates a default Attn_model on `device` and an AdamW optimizer over its parameters."""
  actor_net = Attn_model()
  actor_net = actor_net.to(device)
  actor_optimizer = torch.optim.AdamW(
    actor_net.parameters(),
    lr=3e-4,
    weight_decay=3e-2,
  )
  return actor_net, actor_optimizer

<IPython.core.display.Javascript object>

### Logger

In [4]:
class Logger:
  """
  Supporting class, to be expanded.
  Stores logging utilities (norms, cosine similarity and running means over
  lists of tensors) and a `timer` dict of accumulated per-section timings.
  """
  def __init__(
      self,
      between_logs = 32,
  ):
    self.actor = None
    self.critic = None
    self.grad_a = None
    self.grad_c = None
    self.weight_a = None
    self.weight_c = None
    self.log_gamma = torch.tensor(0.95).to(device)
    self.between_logs = between_logs
    self.timer = {}

  @torch.no_grad()
  def link_models(self, actor=None, critic=None):
    """
    Snapshots detached gradients and weights of the actor and the critic.

    Args:
      actor, critic: models to link; when omitted, previously stored
        self.actor / self.critic are used.
    Raises:
      AttributeError: if no actor or critic is available.
    Parameters that have no gradient yet are skipped in the gradient lists.
    """
    if actor is not None:
      self.actor = actor
    if critic is not None:
      self.critic = critic
    if getattr(self, 'actor', None) is None or getattr(self, 'critic', None) is None:
      raise AttributeError('Logger.link_models needs both an actor and a critic')
    self.grad_a = [p.grad.detach() for p in self.actor.parameters() if p.requires_grad and p.grad is not None]
    self.weight_a = [p.detach() for p in self.actor.parameters()]
    self.grad_c = [p.grad.detach() for p in self.critic.parameters() if p.requires_grad and p.grad is not None]
    self.weight_c = [p.detach() for p in self.critic.parameters()]

  @torch.no_grad()
  def list_norm(self, l, p=2):
    """
    p-norm of a list of tensors treated as one long vector.
    Returns a python float; an empty list has norm 0.0.
    """
    total = 0.0
    for t in l:
      total = total + t.detach().norm(p) ** p
    # float() covers both the tensor accumulator and the plain 0.0 of an empty list
    return float(total) ** (1/p)

  @torch.no_grad()
  def list_cos_dist(self, a, b):
    """
    Cosine similarity (despite the name) between two lists of tensors,
    each list treated as one flattened vector.
    """
    a_norm = self.list_norm(a, 2)
    b_norm = self.list_norm(b, 2)
    product = sum([torch.dot(torch.flatten(a[i]), torch.flatten(b[i])).item() for i in range(len(a))])
    return product/(a_norm*b_norm)

  @torch.no_grad()
  def on_list(self, a, b, operation):
    """Applies a binary `operation` element-wise to two lists of equal length."""
    assert len(a) == len(b), 'lists lengths differ'
    return [operation(x, y) for x, y in zip(a, b)]

  @torch.no_grad()
  def running_mean(self, a, b):
    """Element-wise exponential running mean: a*log_gamma + b*(1 - log_gamma)."""
    return [a[i].mul(self.log_gamma) + b[i].mul(1 - self.log_gamma) for i in range(len(a))]
  
  
# Module-level logger; its `timer` dict is updated by Trajectories.sample_batch and NN_Trainer.
logger = Logger()

<IPython.core.display.Javascript object>

### Data

In [5]:
class Trajectories:
  """
  Contains all kinds of data in a form of tensor.
  realized_tensors are raw and derived features of visited states.
  queues is a list of each states' actions features.
  Action and state embeddings are effectively the same, fyi. 
  """
  def __init__(self,
               td_gamma=td_gamma,
               train_condition = (lambda tr: tr[0]%5!=0),
              ):
    """
    Args:
      td_gamma: discount factor used when turning rewards into returns.
      train_condition: predicate on a trajectory record; True puts it into the
        train subset, False into the validation subset.
    """
    # configuration
    self.td_gamma = td_gamma
    self.train_condition = train_condition

    # raw json and its feature scheme, filled by store_from_json
    self.j_file = None
    self.feature_names = None
    self.feature_names2ids = None

    # list of 7 train tensors f, f_n, r, R, is_last, queue_length, chosen_actions + list of queue tensors
    self.realized_tensors = None
    self.queues = None

    # pre-sampled batches waiting to be served by sample_ids
    self.sampled_ids_list = []
    self.sampled_queue_lengths_list = []
    
  
  def gather_n_store(self, model=None, json_path=json_path):
    """
    Runs data gathering (with `model`, or the default heuristic when model is None),
    then loads the resulting json at `json_path` into tensors.
    """
    self.update_data_on_path(model=model, path=json_path)
    self.store_from_json(json_path=json_path)
    
    
  def evaluate_val_train(self, verbose=True,):
    """
    Evaluates the validation subset first, then the train subset.
    Returns the pair (log_val, log_train) produced by evaluate_data.
    """
    def is_validation(tr):
      # validation is everything the train predicate rejects
      return not self.train_condition(tr)

    val_kwargs = dict(eval_condition=is_validation, wandb_prefix='val', verbose=verbose)
    log_val = self.evaluate_data(**val_kwargs)

    train_kwargs = dict(eval_condition=self.train_condition, wandb_prefix='train', verbose=verbose)
    log_train = self.evaluate_data(**train_kwargs)

    return log_val, log_train
  
  
  def store_from_json(self, json_path=json_path):
    """
    Extracts trajectories from json in a torch-friendly form.

    Loads the file at `json_path` into self.j_file, reads the feature names from
    the first entry of its 'scheme', and rebuilds self.realized_tensors / self.queues.
    """
    # context manager guarantees the file handle is closed after parsing
    with open(json_path) as dataset_file:
      self.j_file = json.load(dataset_file)
    self.feature_names = self.j_file['scheme'][0]
    self.feature_names2ids = {name: idx for idx, name in enumerate(self.feature_names)}
    self.realized_tensors, self.queues = self.j2torch(self.j_file)
    # ids sampled earlier index the previous tensors; drop them so they are never served for the new data
    self.sampled_ids_list = []
    self.sampled_queue_lengths_list = []
    
  
  def avg_code_length(self):
    """
    Mean of tr[3] over train and over validation trajectories
    (tr[3] is presumably the code length in lines -- TODO confirm).

    Returns:
      (train_mean, val_mean). A subset with no trajectories yields nan
      instead of raising ZeroDivisionError.
    """
    train_lines, val_lines = 0, 0
    n_train, n_val = 0, 0
    # one pass: accumulate sums and counts per subset
    for tr in self.j_file['paths']:
      if self.train_condition(tr):
        train_lines += tr[3]
        n_train += 1
      else:
        val_lines += tr[3]
        n_val += 1
    train_mean = train_lines/n_train if n_train else float('nan')
    val_mean = val_lines/n_val if n_val else float('nan')
    return train_mean, val_mean

    
  def j2torch(self, j_file):
    """
    Transforms json to data tensors; only trajectories accepted by train_condition are used.
    Queues are flipped for reasons related to batch truncation and padding. Actions numeration is changed accordingly.

    For step i of a trajectory the stored state features are those of the state chosen
    at step i, while reward, chosen action and queue are taken from step i+1; the last
    step gets reward 0, action 0 and a single all-zero queue row.

    NOTE(review): the `j_file` argument is not read; self.j_file is used throughout.

    Returns:
      ([features, features_next, rewards, Returns, is_last, queue_lengths, chosen_actions], queues)
      where the seven tensors are moved to `device` and `queues` is a python list of
      per-state queue tensors (flipped along dim 0).
    """
    features, features_next, rewards, Returns, is_last, queue_lengths, chosen_actions, queues = [], [], [], [], [], [], [], []
    # positions of the chosen-state id and the reward inside every step record
    chosenStId_idx = self.j_file['scheme'].index('chosenStateId')
    rewards_idx = self.j_file['scheme'].index('reward')

    for tr in self.j_file['paths']:
      if not self.train_condition(tr):
        continue
      # tr[1] holds the list of steps
      tr = tr[1]
      # rewards shifted by one step, padded with 0 for the last state
      tr_rewards = [tr[i][rewards_idx] for i in range(len(tr))][1:] + [0]
      # we use features as a state emb and reward is what we get the next step
      rewards += tr_rewards
      # discount exponent per step: 1 everywhere unless fork-based discounting is on
      do_discount = torch.ones(len(tr))
      if use_fork_discount:
        is_cfg_fork_idx = self.j_file['scheme'].index('is_cfg_fork')
        do_discount = [tr[i][is_cfg_fork_idx] for i in range(len(tr))]
      tr_Returns = self.tr_rewards_to_returns(tr_rewards, do_discount)
      Returns += tr_Returns
      
      # action taken from each state = id chosen at the next step (0 for the last state)
      tr_chosen_actions = [tr[i][chosenStId_idx] for i in range(len(tr))][1:] + [0]
      chosen_actions += tr_chosen_actions
      
      # state features = features of the queue element chosen at this step
      tr_features = [tr[i][0][tr[i][chosenStId_idx]] for i in range(len(tr))]
      is_last += [0]*(len(tr_features)-1) + [1]
      features += tr_features
      # the last state has no successor: use a row of -1 as a placeholder
      features_next += tr_features[1:] + [[-1]*len(self.feature_names)]

      # queue available at the next step, flipped along dim 0; the last state gets one zero row
      tr_queues = [torch.Tensor(tr[i][0]).flip(dims=[0]) for i in range(len(tr))][1:] + [torch.zeros_like(torch.Tensor([tr[0][0][0]]))]
      tr_queue_lengths = [len(q) for q in tr_queues]
      queues += tr_queues
      queue_lengths += tr_queue_lengths
    rewards = torch.Tensor(rewards).to(device)
    features = torch.Tensor(features).to(device)
    features_next = torch.Tensor(features_next).to(device)
    Returns = torch.Tensor(Returns).to(device)
    is_last = torch.Tensor(is_last).to(device)
    queue_lengths = torch.LongTensor(queue_lengths).to(device)
    # we flip queue and numeration of actions
    chosen_actions = queue_lengths - torch.LongTensor(chosen_actions).to(device) - 1 
    return [features, features_next, rewards, Returns, is_last, queue_lengths, chosen_actions], queues
  

  def n_train_states(self):
    """Number of stored train states, i.e. the length of the last realized tensor (chosen actions)."""
    chosen_actions = self.realized_tensors[-1]
    return len(chosen_actions)

  
  def get_properties(self):
    """
    Summary statistics of the currently stored dataset: trajectory lengths,
    the longest queue, number of train states and trajectory counts.
    """
    paths = self.j_file['paths']

    # longest queue among the stored (train) queues
    queue_lengths = np.array([q.shape[0] for q in self.queues])
    longest_queue_ids = np.argmax(queue_lengths)
    longest_queue = (self.queues[longest_queue_ids].shape[0], longest_queue_ids)

    # lengths over all trajectories, train and validation alike
    tr_lengths = np.array([len(tr[1]) for tr in paths])
    length_stats = (f'{tr_lengths.mean():.2f}', np.median(tr_lengths), tr_lengths.max())

    n_validation = sum([not self.train_condition(tr) for tr in paths])

    prop = {}
    prop['traj_length mean, median, max'] = length_stats
    prop['queue max length, idx'] = longest_queue
    prop['number of train states'] = len(self.realized_tensors[-1])
    prop['number of traj-s'] = len(paths)
    prop['number of validation traj-s'] = n_validation
    return prop
  

  def tr_rewards_to_returns(self, tr_rewards, do_discount):
    """
    Discounted returns of one trajectory, computed backwards:
      R[last] = r[last]
      R[i]    = r[i] + td_gamma**do_discount[i] * R[i+1]

    Args:
      tr_rewards: per-step rewards.
      do_discount: per-step exponent of the discount factor (indexable like tr_rewards).
    Returns:
      A list of the same length as tr_rewards (empty for empty input).

    The final reward is now included in the last return; for reward lists that end
    with 0 (as built by j2torch) the result is unchanged.
    """
    n_steps = len(tr_rewards)
    if n_steps == 0:
      return []
    tr_R = [0]*n_steps
    # nothing follows the last step, so its return is just its reward
    tr_R[-1] = tr_rewards[-1]
    for i in range(n_steps-2, -1, -1):
        tr_R[i] = tr_rewards[i] + (self.td_gamma**do_discount[i]) * tr_R[i+1]
    return tr_R
  
  
  def sample_ids(self, batch_size=batch_size):
    """
    Assembles ids for several consecutive batches based on queue lengths.

    When no pre-sampled batches are left, draws batch_size*batch_accumulation_steps
    state ids, sorts them by queue length and splits them into chunks of batch_size.
    Returns the next (ids, queue_lengths) chunk.
    """
    if not self.sampled_ids_list:
      n_draws = batch_size*batch_accumulation_steps
      drawn_ids = random.choice(self.n_train_states(), size=n_draws)
      drawn_ids = torch.LongTensor(drawn_ids).to(device)
      # sort by queue length so that each chunk holds queues of similar length
      sorted_lengths, order = torch.sort(self.realized_tensors[5][drawn_ids])
      drawn_ids = drawn_ids[order]
      self.sampled_ids_list = list(torch.split(drawn_ids, batch_size))
      self.sampled_queue_lengths_list = list(torch.split(sorted_lengths, batch_size))
    batch_ids = self.sampled_ids_list.pop(0)
    batch_queue_lengths = self.sampled_queue_lengths_list.pop(0)
    return batch_ids, batch_queue_lengths
      

  def sample_batch(self, batch_size=batch_size):
    """
    Samples one training batch.

    Returns:
      The seven realized tensors indexed by the sampled ids, followed by
      queues_tensor (batch_size, padded_length, n_features) holding the zero-padded
      queues, and pad_mask (batch_size, padded_length), a bool tensor that is True
      at padded positions. padded_length is the longest sampled queue, capped at
      max_nactions; longer queues are truncated.
    Timings of the id sampling and of the padding loop are accumulated in logger.timer.
    """
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
  
    ids, queue_lengths = self.sample_ids(batch_size=batch_size)
    
    end.record()    
    maybe_sync()
    logger.timer['sample ids'] += start.elapsed_time(end)/1000
    
    sampled_realized = [t[ids] for t in self.realized_tensors]
    
    # Plain list indexing: unlike itemgetter(*ids) this is still a sequence of queues
    # when only one id is sampled, and tolist() moves all ids to the host at once.
    sampled_queues = [self.queues[i] for i in ids.tolist()]
    padded_length = torch.minimum(queue_lengths.max(), torch.tensor(max_nactions))
    queues_tensor = torch.zeros(batch_size, padded_length, len(self.feature_names)).to(device)
    # starts as "everything is padding"; real positions are cleared below
    pad_mask = torch.ones(batch_size, padded_length).to(device)
    
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    
    for i, q in enumerate(sampled_queues):
      # copy at most padded_length leading rows of the queue
      l = min(q.shape[0], padded_length)
      queues_tensor[i, 0 : l, :] = q[: l, :]
      pad_mask[i, 0 : l] = 0
    pad_mask = pad_mask.bool()
    
    end.record()    
    maybe_sync()
    logger.timer['sample inside loop'] += start.elapsed_time(end)/1000

    return *sampled_realized, queues_tensor, pad_mask
  

  def update_data_on_path(self, model=None, path=json_path):
    """
    Communication with jar file on a server.

    With model=None the previously exported ../Game_env/model.onnx is removed and the
    jar is run (the in-code comment calls this the default heuristics). Otherwise the
    model is switched to eval mode, exported to ../Game_env/model.onnx with dynamic
    batch and n_actions axes, and the jar is run.

    NOTE(review): `path` is not used inside this method; the jar presumably writes the
    dataset to the location it is configured with -- confirm.
    NOTE(review): the model is left in eval mode after the export.
    """
    if model is None:
      # use default heuristics
      time_before = time()
      os.system('rm -f ../Game_env/model.onnx')
      os.system(jar_command)
      print('BFS data gathering time:', time() - time_before)
    else:
      # dummy input: a batch with one queue holding one action
      x = torch.randn(1, 1, len(self.feature_names), requires_grad=True).to(device)
      torch_model = model.eval()
      # The output is unused; for a freshly built model this is the first forward pass,
      # which presumably is what initializes lazy layers (e.g. nn.LazyLinear) before the export.
      torch_out = torch_model(x)
      torch.onnx.export(torch_model,
                        x,
                        '../Game_env/model.onnx',
                        opset_version=15,
                        export_params=True,
                        input_names = ['input'],   # the model's input names
                        output_names = ['output'],
                        dynamic_axes={'input' : {0 : 'batch_size',
                                                 1 : 'n_actions',
                                                },    # variable length axes
                                      'output' : {0 : 'batch_size',
                                                  1 : 'n_actions',
                                                },
                                      },
                        )
      os.system(jar_command)
  
    
  def evaluate_data(self,
           eval_condition = None,
           wandb_prefix = 'val',
           verbose=True,
           factors=torch.Tensor([1, 0.99]),
           ):
    """
    Computes discounted returns of the trajectories selected by `eval_condition`
    and (optionally) logs summary statistics to wandb.

    Args:
      eval_condition: predicate on a trajectory record; defaults to "not train_condition".
      wandb_prefix: prefix used in the logged keys.
      verbose: if True, the statistics are sent to wandb.log once per factor.
      factors: discount factors to evaluate with.
    Returns:
      The log dict of the last factor, extended with 'Returns' (per-trajectory returns)
      and '<prefix> lengths' (trajectory lengths).

    NOTE(review): `log` is rebuilt for every factor, so only the last factor's values are
    returned, and an empty `factors` would leave `log` undefined.
    """
    if eval_condition is None:
        eval_condition = (lambda tr: not self.train_condition(tr))
    rewards_idx = self.j_file['scheme'].index('reward')
    for f in factors:
      size = 0
      tr_lengths = []
      trs_R = []
      for tr in self.j_file['paths']:
        if not eval_condition(tr):
          continue
        # tr[1] holds the list of steps
        tr=tr[1]
        size += len(tr)
        tr_lengths += [len(tr)]
        # rewards list is not shifted, has a different purpose this time
        tr_rewards = [tr[i][rewards_idx] for i in range(len(tr))]
        # running return of the current trajectory, accumulated backwards below
        trs_R += [0]
        do_discount = torch.Tensor([1]*len(tr))
        if use_fork_discount:
          is_cfg_fork_idx = self.j_file['scheme'].index('is_cfg_fork')
          do_discount = [tr[i][is_cfg_fork_idx] for i in range(len(tr))]
        for i in range(len(tr_rewards)-1, -1, -1):
          trs_R[-1] = tr_rewards[i] + (f ** do_discount[i]) * trs_R[-1]
      log = {}
      log[f'{wandb_prefix} size'] = size
      log[f'{wandb_prefix}_eval/mean {f:.2f} discount '] = torch.Tensor(trs_R).mean()
      log[f'{wandb_prefix}_eval/median {f:.2f} discount '] = torch.Tensor(trs_R).median()
      log[f'{wandb_prefix} Return by trjs {f:.2f} (previous epoch)'] = wandb.Histogram(np_histogram=np.histogram(trs_R, bins=30, ))
      # log[f'{wandb_prefix} lengths hist'] = wandb.Histogram(np_histogram=np.histogram(tr_lengths, bins=30, ))
      # these two keys do not depend on wandb_prefix or on eval_condition
      log['code length train mean'], log['code length val mean'] = self.avg_code_length()
      if verbose:
          wandb.log(log.copy())
      # raw per-trajectory data is returned to the caller but not sent to wandb
      log['Returns'] = trs_R
      log[f'{wandb_prefix} lengths'] = tr_lengths
    return log

<IPython.core.display.Javascript object>

### Trainer


In [6]:
class NN_Trainer:
  def __init__(
      self,
      NN_setup,
      trajectories=None,
      batch_size=batch_size,
      n_batches=1000,
      target_update_steps = 15,
      td_gamma=td_gamma,
      clip_eps = 3e-1
      ):
    """
    Args:
      NN_setup: dict with keys 'actor', 'actor_opt', 'critic', 'critic_opt'.
      trajectories: data source providing sample_batch.
      batch_size: number of states per batch.
      n_batches: number of batches processed by one learn_new_policy call.
      target_update_steps: the target critic is refreshed every this many batches.
      td_gamma: discount factor of the TD target.
      clip_eps: PPO clipping range for the probability ratio.
    """
    # schedule
    self.n_batches = n_batches
    self.batch_number = -1

    # algorithm hyper-parameters
    self.td_gamma = td_gamma
    self.clip_eps = clip_eps

    # actor in train mode plus a frozen copy used as the previous policy
    actor = NN_setup['actor']
    self.actor = actor.train()
    self.actor_opt = NN_setup['actor_opt']
    self.prev_actor = copy.deepcopy(self.actor).eval()

    # critic in train mode plus a frozen copy used as the TD target network
    critic = NN_setup['critic']
    self.critic = critic.train()
    self.target_critic = copy.deepcopy(self.critic).eval()
    self.critic_opt = NN_setup['critic_opt']

    # data and bookkeeping
    self.trajectories = trajectories
    self.batch_size = batch_size
    self.target_update_steps = target_update_steps
    self.log = {}

  def get_each_loss(self,
               features,
               features_next,
               rewards,
               Returns,
               is_last,
               queue_lengths,
               chosen_actions,
               queues_tensor, 
               pad_mask,
    ):
    """
    Computes losses for actor, critic and exploration (loss_ent) within PPO algorithm.
    Decisions were made to avoid python loops at all costs -- 
    varying action space is not particularly batch-friendly.    

    The arguments are exactly the tuple returned by Trajectories.sample_batch.

    Returns:
      (loss_a, loss_c, loss_ent):
        loss_a   -- clipped-surrogate actor loss (negated mean of the clipped objective),
        loss_c   -- root-mean-square TD error of the critic,
        loss_ent -- negative mean policy entropy, normalized per state by the log of
                    the (capped) queue length plus one.
    Side effects: fills self.log with scalar diagnostics and accumulates timings in logger.timer.
    """
    self.log['Returns mean'] = Returns.mean().item()
    # self.log['Return std'] = Returns.std().item()
    self.log['rewards mean'] = rewards.mean().item()
    self.log['queue length max'] = queue_lengths.max().item()
        
        
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    
    # action probabilities under the current and the previous (frozen) policy
    probs = self.actor(queues_tensor, mask=pad_mask)
    with torch.no_grad():
      prev_probs = self.prev_actor(queues_tensor, mask=pad_mask) # can do once in epoch
    self.log['max prob mean'] = probs.max(dim=-1).values.mean().item()
    self.log['40 max prob quantile'] = torch.quantile(probs.max(dim=-1).values, 0.4).item()
    
    end.record()    
    maybe_sync()
    logger.timer['probs'] += start.elapsed_time(end)/1000

    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    
    # V(s) from the trained critic, V(s') from the target critic (no gradient)
    values = self.critic(features).squeeze(-1)
    with torch.no_grad():
      next_values = self.target_critic(features_next).squeeze(-1)
    
    end.record()    
    maybe_sync()
    logger.timer['values'] += start.elapsed_time(end)/1000
        
    t_hist = time()
    self.log['V-func mean'] = torch.mean(values.detach()).item()
    self.log['V-func stdev'] = torch.std(values.detach()).item()
    # hist = wandb.Histogram(np_histogram=np.histogram(values.detach().to('cpu'), bins=40, ))
    # self.log['V-func hist'] = hist
    logger.timer['hist'] += time() - t_hist

    # critic loss
    t_critic_loss = time()
    # TD error; the bootstrap term is dropped for the last state of a trajectory
    TD = values - (rewards + next_values * self.td_gamma * (1-is_last))
    # Monte-Carlo error against the stored returns; only logged, not part of loss_c
    MC = (values - Returns).abs().mean()/10
    loss_c = (TD**2).mean().sqrt() # + MC
    logger.timer['critic loss'] += time()-t_critic_loss

    self.log['TD loss'] = (TD**2).mean().sqrt().item()
    self.log['MC loss'] = MC.item()
    # hist = wandb.Histogram(np_histogram=np.histogram(TD.detach().to('cpu'), bins=20, ))
    # self.log['TD hist'] = hist
        
    # entropy loss
    t_entropy_loss = time()
    
    # per-action entropy terms; probabilities are floored before the log
    entropies = - probs * torch.log(torch.max(torch.tensor(1e-40), probs))
    # normalize each state's entropy by log(min(queue_length, max_nactions) + 1)
    entropy_by_state_reg = torch.sum(entropies, dim=-1) / torch.log(torch.minimum(queue_lengths+1, torch.tensor(max_nactions+1))).to(device)
    
    # minimizing this loss maximizes the normalized entropy
    loss_ent = -entropy_by_state_reg.mean()
        
    logger.timer['entropy loss'] += time()-t_entropy_loss
    
    # actor loss
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    
    # probability of the action that was actually taken, under both policies
    probs_chosen = probs[get_ass_dope_ids(chosen_actions)].squeeze(-1)
    prev_probs_chosen = prev_probs[get_ass_dope_ids(chosen_actions)].squeeze(-1)
        
    ratios = (probs_chosen / (prev_probs_chosen.detach()+1e-9)).to(device)
    clipped = torch.clip(ratios, min=1-self.clip_eps, max=1+self.clip_eps)
    # advantage = negated TD error, zeroed on last states and detached from the critic graph
    Adv = - (TD * (1-is_last)).detach()
    loss_a = - torch.min(ratios*Adv, clipped*Adv).mean()
        
    # fraction of samples for which the clipped term is the smaller one
    self.log['clipped to all'] = (clipped*Adv < ratios*Adv).long().sum().item() / ratios.numel()
    self.log['ratios mean'] = ratios.mean().item()
    
    end.record()    
    maybe_sync()
    logger.timer['actor loss'] += start.elapsed_time(end)/1000
    
    return loss_a, loss_c, loss_ent



  def learn_new_policy(self, ):
    """
    Implements one learning cycle over collected dataset.

    The current actor is frozen as prev_actor, then n_batches batches are processed:
    the target critic is refreshed every target_update_steps batches, gradients are
    accumulated over batch_accumulation_steps batches before each optimizer step
    (with gradient-norm clipping at 30), and scalars are sent to wandb every
    logger.between_logs+1 batches. Section timings are accumulated in logger.timer.
    """
    self.prev_actor = copy.deepcopy(self.actor).eval()
    logger.timer = {'probs': 0,
                    'values': 0,
                    'hist': 0,
                    'critic loss': 0,
                    'entropy loss': 0,
                    'actor loss': 0,
                    'total loss': 0,
                    'optimizers step': 0,
                    'loss.backward': 0,
                    'sample batch': 0,
                    'sample ids': 0,
                    'sample inside loop': 0,
                    }
    # drop gradients possibly left over from an unfinished accumulation group
    self.actor_opt.zero_grad()
    self.critic_opt.zero_grad()
    for i in trange(self.n_batches):
      self.batch_number = i
      if self.batch_number % self.target_update_steps == 0:
        self.target_critic = copy.deepcopy(self.critic).eval()
      
      start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
      start.record()
      
      sampled_batch = self.trajectories.sample_batch(self.batch_size)
      
      end.record()    
      maybe_sync()
      logger.timer['sample batch'] += start.elapsed_time(end)/1000
      
      start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
      start.record()      
      
      loss_a, loss_c, loss_ent = self.get_each_loss(*sampled_batch)
      # entropy term is down-weighted; the sum is scaled so accumulated gradients average over the group
      loss = (loss_a + loss_c + loss_ent/50) / batch_accumulation_steps

      end.record()    
      maybe_sync()
      logger.timer['total loss'] += start.elapsed_time(end)/1000
      
      start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
      start.record()   
      
      loss.backward()
      
      end.record()    
      maybe_sync()
      logger.timer['loss.backward'] += start.elapsed_time(end)/1000
      
      # step only once a full group of batch_accumulation_steps batches has been accumulated
      # (the previous `batch_number % steps == 0` stepped right after the first batch)
      if (self.batch_number + 1) % batch_accumulation_steps == 0:
        t_optimizers_step = time()
        torch.nn.utils.clip_grad_norm_(self.critic.parameters(), 30)
        torch.nn.utils.clip_grad_norm_(self.actor.parameters(), 30)
        self.actor_opt.step()
        self.critic_opt.step()
        
        # gradient norms are measured after clipping, before the gradients are reset
        self.log.update({
            'grad actor L2': logger.list_norm([p.grad for p in self.actor.parameters() if p.requires_grad], 2),
            'grad critic L2': logger.list_norm([p.grad for p in self.critic.parameters() if p.requires_grad], 2),
        })
        self.critic_opt.zero_grad()
        self.actor_opt.zero_grad()
        logger.timer['optimizers step'] += time()-t_optimizers_step
        

      if self.batch_number % (logger.between_logs+1) == 0:
        self.log.update({
            'loss actor': loss_a.item(),
            'loss critic': loss_c.item(),
            '-entropy by log(n)': loss_ent.item(),
            'weight actor L2': logger.list_norm([p for p in self.actor.parameters() if p.requires_grad], 2),
            'weight critic L2': logger.list_norm([p for p in self.critic.parameters() if p.requires_grad], 2),
        })
        wandb.log(self.log)
        self.log = {}

<IPython.core.display.Javascript object>

### Utils

In [7]:
def get_ass_dope_ids(l):
  """
  Builds an advanced-indexing tuple that selects, along the last axis, the
  position given by `l` for every leading index.

  A 1-D `l` is treated as a column. Indexing a 2-D tensor `t` with the result
  yields t[i, l[i]] for every row i, with a trailing dimension of size 1.
  """
  chosen = l.long()
  if chosen.ndim == 1:
    chosen = chosen[:, None]
  # grid[k] holds the index along axis k for every element of `chosen`
  grid = torch.LongTensor(np.indices(chosen.shape))
  # replace the last-axis indices by the requested positions
  grid[-1] = chosen
  return tuple(grid)

def consumption_percent(epoch):
  """Percentage written to the jar config as 'dataConsumption'; currently constant for every epoch."""
  percent = 100 # 40 if (epoch<30) else 100
  return percent

def get_clip_eps(epoch):
  """PPO clipping range schedule: 0.5 during the first five epochs, 0.2 afterwards."""
  if epoch < 5:
    return 0.5
  return 0.2

<IPython.core.display.Javascript object>

### Procedure

In [None]:
epochs = 100
# Trajectories with tr[0] % 5 == 1 are held out for validation. The label goes into the
# wandb run name, so it has to describe this predicate (it previously read 'h%5=0').
train_condition, message = (lambda tr: tr[0]%5!=1), 'h%5=1'


def write_jar_config(data_consumption, input_shape):
    """
    Writes ../Game_env/jar_config.txt, the config file passed to the jar in jar_command.

    Args:
      data_consumption: value stored under 'dataConsumption'.
      input_shape: value stored under 'inputShape'.
    """
    config = {'postprocessing': 'None',
              'dataConsumption': data_consumption,
              'maxAttentionLength': max_nactions,
              'inputShape': input_shape}
    with open('../Game_env/jar_config.txt', 'w') as jar_config:
      jar_config.write(json.dumps(config))


for _ in [0]:
    run = wandb.init(
          project="delete",
          name=f'consumption, val {message}',
          config={
              'algorithm': 'PPO-clip',
              'models': 'mlp + attn',
          }
    )

    # first we evaluate BFS heuristic (not to be confused with naive BFS)
    
    trajectories = Trajectories(train_condition=train_condition)

    write_jar_config(consumption_percent(100), [-1, -1, -1])
    
    trajectories.gather_n_store()

    trajectories.evaluate_val_train()

    # then collect new json data file using randomly initialized policy neural network
    actor, actor_opt = get_attn_setup()
    critic, critic_opt = get_mlp_setup(use_FFM=True, wd=0.01)
    
    time0 = time()
    write_jar_config(consumption_percent(0), [1, -1, len(trajectories.feature_names)])
    trajectories.gather_n_store(model=actor)
    print('Attn data gathering+storing: ', time()-time0)

    for epoch in range(epochs):  
        # config for the data gathering that happens at the end of this epoch
        write_jar_config(consumption_percent(epoch), [1, -1, len(trajectories.feature_names)])
        print(trajectories.get_properties())
        
        # a fresh trainer per epoch re-snapshots prev_actor / target_critic; models and optimizers persist
        trainer = NN_Trainer(NN_setup={'actor': actor, 'actor_opt': actor_opt,
                                       'critic': critic, 'critic_opt': critic_opt,},
                             trajectories=trajectories,
                             n_batches=int(1000 * consumption_percent(epoch)/100),
                             clip_eps = get_clip_eps(epoch),
                             )
        trajectories.evaluate_val_train()
        trainer.learn_new_policy()

        wandb.log({'epoch': epoch,
                })

        time_before = time()
        trajectories.gather_n_store(model=actor)
        print('Data gathering time: ', time()-time_before)

    trajectories.evaluate_val_train()

    checkpoint = {
    'actor': actor,
    'critic':critic,
    }
    torch.save(checkpoint, os.path.join(wandb.run.dir, 'actor, critic'))
    wandb.finish()

<IPython.core.display.Javascript object>

BFS data gathering time: 167.7209391593933




Attn data gathering+storing:  154.76415181159973
{'traj_length mean, median, max': ('58.95', 11.0, 1501), 'queue max length, idx': (243, 31714), 'number of train states': 35956, 'number of traj-s': 735, 'number of validation traj-s': 136}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:54<00:00, 18.34it/s]


In [None]:
# Ask the interpreter to exit; the cells below contain only commented-out helpers.
exit()

### Side Utils and commented

In [None]:
# def R_dif_max():
#     TrsBFS = Trajectories(path='../Data/BFS_dataset.json', eval_condition=(lambda x:False))
#     TrsNN = Trajectories(eval_condition=(lambda x:False))

#     Returns_BFS = TrsBFS.evaluate_data(factors=[1], eval_condition=(lambda x: True), verbose=False)['Returns']
#     Returns_NN = TrsNN.evaluate_data(factors=[1], eval_condition=(lambda x: True), verbose=False)['Returns']

#     Returns_dif = torch.Tensor(Returns_BFS) - torch.Tensor(Returns_NN)
#     max_idx = torch.argmax(Returns_dif)
#     max_dif = Returns_dif[max_idx]
#     assert TrsBFS.j_file['paths'][max_idx][2]==TrsNN.j_file['paths'][max_idx][2], 'wtf'
#     return TrsNN.j_file['paths'][max_idx][2], TrsBFS.j_file['paths'][max_idx][2], max_idx, max_dif, Returns_BFS[max_idx], Returns_NN[max_idx], len(Returns_dif)

# name_max, *a = R_dif_max()
# R_dif_max()

In [None]:
#@title Fit a tree


# to check features' strength
# r_tree = tree.DecisionTreeRegressor(max_depth=1000, )

# Features, _, _, R, _ =  Trajectories(json_path).trs_tensors
# r_tree.fit(Features.to('cpu'), R.to('cpu'))
# R_prediction = r_tree.predict(Features.to('cpu'))
# print(f'leaves: {r_tree.get_n_leaves()}, number of states: {Trajectories(json_path).n_sarsa_pairs}, depth: {totalr_tree.get_depth()}')
# torch.mean((R.to('cpu') - torch.Tensor(R_prediction))**2)