In [3]:
#library for data processing
import pandas as pd
from bs4 import BeautifulSoup
import string
import re
import itertools
import io
import json
import os
import ast
import time
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import transformers
import accelerate
from transformers import GPTNeoForCausalLM, GPT2Tokenizer, AdamW, get_scheduler, GPTNeoModel, GPT2LMHeadModel
from sentence_transformers import SentenceTransformer, CrossEncoder

import openai

#from torchrl.data import PrioritizedReplayBuffer, ReplayBuffer

import datasets
from datasets import load_dataset

import pickle
import gym

#root directory for saved artifacts (e.g. models/<name>.pt is loaded relative to it); set before running
base_dir="<YOUR PATH>"
#accelerate chooses the compute device; modules and tensors in later cells are moved to `device`
accelerator=accelerate.Accelerator()
device=accelerator.device

In [2]:
#Load HotpotQA splits from the Hugging Face hub (cached locally after the first download).
#NOTE(review): train/test use the "fullwiki" config while validation uses "distractor" —
#confirm this mix of configs is intentional.
hotpot_train= load_dataset("hotpot_qa", name="fullwiki", split="train")
hotpot_val = load_dataset("hotpot_qa", name="distractor", split="validation")
hotpot_test= load_dataset("hotpot_qa", name="fullwiki", split="test")

#optional: persist the raw splits under base_dir/data
#hotpot_train.save_to_disk(os.path.join(base_dir, 'data/hotpot_train_raw'))
#hotpot_val.save_to_disk(os.path.join(base_dir, 'data/hotpot_val_raw'))
#hotpot_test.save_to_disk(os.path.join(base_dir, 'data/hotpot_test_raw'))

Downloading builder script:   0%|          | 0.00/6.42k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/5.93k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/9.21k [00:00<?, ?B/s]

Downloading and preparing dataset hotpot_qa/fullwiki to /Users/hyungmoonko/.cache/huggingface/datasets/hotpot_qa/fullwiki/1.0.0/133b9501f892e5193babbad937bee3b4899deb4691ef4d791e6ac0111c875bb5...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/566M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/47.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/46.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/90447 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/7405 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7405 [00:00<?, ? examples/s]

Dataset hotpot_qa downloaded and prepared to /Users/hyungmoonko/.cache/huggingface/datasets/hotpot_qa/fullwiki/1.0.0/133b9501f892e5193babbad937bee3b4899deb4691ef4d791e6ac0111c875bb5. Subsequent calls will reuse this data.
Downloading and preparing dataset hotpot_qa/distractor to /Users/hyungmoonko/.cache/huggingface/datasets/hotpot_qa/distractor/1.0.0/133b9501f892e5193babbad937bee3b4899deb4691ef4d791e6ac0111c875bb5...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/46.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/90447 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/7405 [00:00<?, ? examples/s]

Dataset hotpot_qa downloaded and prepared to /Users/hyungmoonko/.cache/huggingface/datasets/hotpot_qa/distractor/1.0.0/133b9501f892e5193babbad937bee3b4899deb4691ef4d791e6ac0111c875bb5. Subsequent calls will reuse this data.


Found cached dataset hotpot_qa (/Users/hyungmoonko/.cache/huggingface/datasets/hotpot_qa/fullwiki/1.0.0/133b9501f892e5193babbad937bee3b4899deb4691ef4d791e6ac0111c875bb5)


In [4]:
#sentence-transformers checkpoint shared by the environment tokenizer, actor and critic,
#and how many leading encoder layers each model keeps (layers[:n_layers]).
sbert_name, n_layers = "multi-qa-MiniLM-L6-dot-v1", 2

In [None]:
#environment for whole hotpot data ex) for hotpot train

#for real data: use format of {question, context} => 
#context={titles=[title1, title2,...], sentences=[[doc1_s1, doc1_s2,..], [doc2_s1, doc2_s2...], ..]}

class HotPotQAEnv():
  """Sentence-selection RL environment over a HotpotQA-style dataset.

  Every sentence of every context document is a candidate "passage" (action).
  The agent selects passages one at a time and is rewarded according to whether
  they belong to the example's supporting facts.

  Two interfaces are exposed:
    * single-episode: reset / step / get_state / get_actor_prompt
    * batched: batch_reset / batch_step / get_batch_states / get_batch_actor_prompts

  Args:
    hotpot_dataset: indexable dataset whose items (and list-indexed slices) expose
      'question', 'context' {'title', 'sentences'} and 'supporting_facts'
      {'title', 'sent_id'}, as read by process_context / batch_reset.
    train: stored as self.train; not otherwise read in this class.
  """
  def __init__(self, hotpot_dataset, train):
    self.train=train
    self.hotpot_dataset=hotpot_dataset
    #tokenizer for early termination due to context window limit.
    self.tokenizer = transformers.AutoTokenizer.from_pretrained("sentence-transformers/{:s}".format(sbert_name))

    #single data environment
    #state variables
    self.question=None
    #passage has form of {'index': idx, 'title': title, 'passage': passage}
    self.labels=None
    self.unselected_passages=None
    self.selected_passages=None
    self.accuracy=0

    #batch environment
    self.horizon=0
    self.batch_traj_step=0
    self.batch_questions=None
    self.batch_labels=None
    self.batch_unselected_passages=None
    self.batch_selected_passages=None

    #values for adaptive reward scheme
    self.batch_positive_rewards=None
    self.batch_negative_rewards=None

    self.batch_accuracies=None

    #normalizing(scaling) rewards
    self.norm=5

  def process_context(self, hotpot_data):
    """Flatten an example's context into sentence-level passages and gold labels.

    Returns:
      passages: list of {'index', 'title', 'passage'} dicts, one per sentence,
        numbered consecutively across documents.
      labels: passage indices of the supporting facts. Facts that point at a
        (title, sent_id) missing from the context are skipped.
    """
    context=hotpot_data['context']

    passages=[]
    titles=context['title']
    sentences_list=context['sentences']
    #decompose each document into sentences => many small selection steps
    passage_to_idx={}
    count=0
    for idx, title in enumerate(titles):
      for sent_id, sentence in enumerate(sentences_list[idx]):
        passages.append({
            'index': count,
            'title': title,
            'passage': sentence
        })
        passage_to_idx[(title, sent_id)]=count
        count+=1

    labels=[]
    supp_titles=hotpot_data['supporting_facts']['title']
    supp_sent_ids=hotpot_data['supporting_facts']['sent_id']
    for supp_idx, supp_title in enumerate(supp_titles):
      key=(supp_title, supp_sent_ids[supp_idx])
      #some supporting facts reference sentences not present in the context
      if key in passage_to_idx:
        labels.append(passage_to_idx[key])
    return passages, labels

  def get_state_prompt(self, question, selected_passages):
    """Text fed to the critic / state encoder: the question plus all selected passages."""
    prompt="Question: "+question
    state_prompt=prompt+'\n\nSelected Passages:\n'
    if len(selected_passages)==0:
      state_prompt+="Not Selected\n\n"
    for selected_passage in selected_passages:
      title=selected_passage['title']
      passage=selected_passage['passage']
      state_prompt+=("Document Title: "+title+"\nPassage: "+passage+'\n\n')
    return state_prompt

  def get_action_prompts(self, question, unselected_passages):
    """One prompt per candidate passage: the question plus that single passage."""
    prompt="Question: "+question
    action_prompts=[]
    for unselected_passage in unselected_passages:
      action_prompt=prompt+"\n\nSelected Passage:\n"
      title=unselected_passage['title']
      passage=unselected_passage['passage']
      action_prompt+=("Document Title: "+title+"\nPassage: "+passage)
      action_prompts.append(action_prompt)
    return action_prompts

  def get_actor_prompt(self):
    """Prompts for the current single-episode state.

    The state prompt describes the passages selected so far. Action prompts are
    built from the passages still available for selection.
    """
    #fixed: the state prompt must be built from the *selected* passages
    state_prompt=self.get_state_prompt(self.question, self.selected_passages)
    action_prompts=self.get_action_prompts(self.question, self.unselected_passages)
    return state_prompt, action_prompts

  def reset(self):
    """Start a single episode on a uniformly random example and return its state."""
    random_idx=np.random.randint(low=0, high=len(self.hotpot_dataset))
    hotpot_data=self.hotpot_dataset[random_idx]
    self.question=hotpot_data['question']
    self.unselected_passages, self.labels=self.process_context(hotpot_data)
    self.selected_passages=[]
    self.accuracy=0
    state=self.get_state()
    return state

  def get_state(self):
    """Snapshot of the single-episode state (passage lists are shallow-copied)."""
    q=self.question
    sp=self.selected_passages[:]
    usp=self.unselected_passages[:]
    state={
        'question': q,
        'selected_passages': sp,
        'unselected_passages': usp
    }
    return state

  #action idx refers to relative index within unselected passages.
  def step(self, action_idx):
    """Single-episode transition; returns (reward, next_state, termin_signal).

    action_idx == 0 is treated as the answering (terminating) action. Any other
    index costs -0.1 and continues the episode. The passage at action_idx is moved
    to the selected list whenever one is available.

    NOTE(review): the answering branch calls self.give_final_reward, which is not
    defined in this class. Unless it is attached elsewhere, action_idx == 0 raises
    AttributeError. The batched API (batch_step) does not depend on it.
    NOTE(review): process_context does not insert a dedicated "stop" passage, so
    index 0 is also an ordinary passage — confirm the intended action encoding.
    """
    #no passages left: the only possible action is to answer
    if len(self.unselected_passages)==0:
      action_idx=0

    if action_idx==0:
      termin_signal=True
      reward, accuracy=self.give_final_reward(self.selected_passages, self.labels)
      self.accuracy=accuracy
    else:
      #keep selecting => # of available actions at every state: # of unselected passages+1
      termin_signal=False
      reward=-0.1
    #guard: popping from an empty list would raise IndexError
    if self.unselected_passages:
      selected_passage=self.unselected_passages.pop(action_idx)
      self.selected_passages.append(selected_passage)
    state_f=self.get_state()
    return reward, state_f, termin_signal

  def get_accuracy(self, selected_passages, labels):
    """Score a selection against gold labels.

    Returns:
      (count, accuracy): count of selected passages whose 'index' is in labels,
      and count / len(selected_passages). Accuracy is 0.0 for an empty selection,
      which avoids a division by zero.
    """
    count=sum(1 for sp in selected_passages if sp['index'] in labels)
    accuracy=count/len(selected_passages) if selected_passages else 0.0
    return count, accuracy

  def get_batch_actor_prompts(self, mask):
    """State/action prompts for every batch slot whose mask is False (episode still running)."""
    batch_state_prompts=[]
    batch_action_prompts_list=[]
    for idx, question in enumerate(self.batch_questions):
      if mask[idx]==False:
        state_prompt=self.get_state_prompt(question, self.batch_selected_passages[idx])
        action_prompts=self.get_action_prompts(question, self.batch_unselected_passages[idx])
        batch_state_prompts.append(state_prompt)
        batch_action_prompts_list.append(action_prompts)
    return batch_state_prompts, batch_action_prompts_list

  def get_batch_states(self):
    """List of per-slot state dicts (passage lists are shallow-copied)."""
    batch_states=[]
    for idx, question in enumerate(self.batch_questions):
      state={
          'question': question,
          'unselected_passages': self.batch_unselected_passages[idx][:],
          'selected_passages': self.batch_selected_passages[idx][:]
      }
      batch_states.append(state)
    return batch_states

  def batch_reset(self, batch_indices, horizon):
    """Start one episode per dataset index in batch_indices and return their states.

    Args:
      batch_indices: dataset indices; indexing the dataset with them must return
        column-wise lists ('question', 'context', 'supporting_facts').
      horizon: max trajectory length before forced termination.
    """
    self.horizon=horizon
    self.batch_traj_step=0
    batch_size=len(batch_indices)
    batch_hotpot_data=self.hotpot_dataset[batch_indices]
    self.batch_questions=batch_hotpot_data['question']
    contexts=batch_hotpot_data['context']
    spfs=batch_hotpot_data['supporting_facts']
    batch_usp=[]
    batch_sp=[]
    batch_labels=[]
    batch_pos_ars=[]
    batch_neg_ars=[]
    for idx, context in enumerate(contexts):
      spf=spfs[idx]
      passages, labels=self.process_context({'context': context, 'supporting_facts': spf})
      batch_sp.append([])
      batch_usp.append(passages)
      batch_labels.append(labels)
      #fixed per-step rewards; the commented formula below is an alternative adaptive scheme
      len_labels=len(labels)
      len_psgs=len(passages)
      pos_reward=1
      neg_reward=-0.2
      #optimal_return=optimal_return_ftn(len_psgs, len_labels)
      #pos_reward=(1/self.norm)*optimal_return/(len_labels+1)
      #neg_reward=-(1/self.norm)*len_labels*(len_psgs+1)*optimal_return/(len_psgs*(len_labels+1)*(len_psgs-len_labels))
      batch_pos_ars.append(pos_reward)
      batch_neg_ars.append(neg_reward)
    self.batch_positive_rewards=batch_pos_ars
    self.batch_negative_rewards=batch_neg_ars
    self.batch_unselected_passages=batch_usp
    self.batch_selected_passages=batch_sp
    self.batch_labels=batch_labels
    batch_states=self.get_batch_states()

    self.batch_accuracies=[0 for _ in range(batch_size)]
    return batch_states

  def batch_step(self, batch_size, action_indices, mask):
    """Advance every running episode by one selection.

    Args:
      batch_size: unused; kept for interface compatibility.
      action_indices: one action per slot with mask False, in slot order; each is
        an index into that slot's current unselected passages.
      mask: per-slot flags; True means the episode already terminated.

    Returns:
      (batch_rewards, batch_next_states, batch_termin_signals), one entry per slot.
      Masked slots are padding steps with reward 0 and termin_signal True.
    """
    assert self.batch_questions is not None, "Run batch_reset() first to use step"
    self.batch_traj_step+=1

    batch_rewards=[]
    batch_termin_signals=[]
    unpack_idx=0
    for idx, unselected_passages in enumerate(self.batch_unselected_passages):
      if mask[idx]==False:
        selected_passages=self.batch_selected_passages[idx]
        action_idx=action_indices[unpack_idx]
        unpack_idx+=1
        #move the chosen passage from unselected to selected
        action_psg=unselected_passages.pop(action_idx)
        selected_passages.append(action_psg)
        self.batch_unselected_passages[idx]=unselected_passages
        self.batch_selected_passages[idx]=selected_passages

        correct_count, accuracy=self.get_accuracy(selected_passages, self.batch_labels[idx])
        n_labels=len(self.batch_labels[idx])
        #terminate early once the state prompt exceeds the encoder's 512-token window
        state_prompt=self.get_state_prompt(self.batch_questions[idx][:], self.batch_selected_passages[idx][:])
        n_tokens=self.tokenizer(state_prompt, return_tensors='pt').input_ids.size(1)
        if correct_count==n_labels:
          #all gold passages found
          self.batch_accuracies[idx]=accuracy
          termin_signal=True
          reward=self.batch_positive_rewards[idx]
        elif self.batch_traj_step==self.horizon or n_tokens>512:
          #forced termination: positive only if every selected passage is gold
          self.batch_accuracies[idx]=accuracy
          termin_signal=True
          if accuracy==1:
            reward=self.batch_positive_rewards[idx]
          else:
            reward=self.batch_negative_rewards[idx]
        else:
          termin_signal=False
          if action_psg['index'] in self.batch_labels[idx]:
            reward=self.batch_positive_rewards[idx]
          else:
            reward=self.batch_negative_rewards[idx]
      else:
        #terminated episode: padding step, state unchanged
        reward=0
        termin_signal=True
      batch_rewards.append(reward)
      batch_termin_signals.append(termin_signal)
    batch_state_fs=self.get_batch_states()
    return batch_rewards, batch_state_fs, batch_termin_signals

In [None]:
class StateValueModule(nn.Module):
  """Critic V(s): maps a state prompt (question + selected passages) to a scalar value.

  Encodes with the first `n_layers` layers of the sentence-transformers checkpoint
  `sbert_name` and L2-normalises the pooler output. It then applies
  Linear -> LayerNorm -> ReLU -> Dropout(0.2) -> Linear.
  It estimates future return from the state alone; the quality of the available
  actions (passages) is the policy's concern.
  """
  def __init__(self):
    super(StateValueModule, self).__init__()
    self.tokenizer = transformers.AutoTokenizer.from_pretrained("sentence-transformers/{:s}".format(sbert_name))
    self.model = transformers.AutoModel.from_pretrained("sentence-transformers/{:s}".format(sbert_name)).to(device)
    layers=self.model.encoder.layer
    new_layers=layers[:n_layers]
    self.model.encoder.layer=new_layers.to(device)
    #read from the encoder config instead of hard-coding 384
    hidden_size=self.model.config.hidden_size
    self.head1=nn.Linear(hidden_size,128).to(device)
    self.head2=nn.Linear(128,1).to(device)

    self.layernorm=nn.LayerNorm(128).to(device)
    #registered as submodules so .train()/.eval() control dropout.
    #These have no parameters, so the state_dict keys are unchanged.
    self.relu=nn.ReLU()
    self.dropout=nn.Dropout(p=0.2)
    self.element_init()

  def element_init(self):
    """Kaiming-normal weights and zero biases for both linear heads."""
    nn.init.kaiming_normal_(self.head1.weight)
    nn.init.zeros_(self.head1.bias)
    nn.init.kaiming_normal_(self.head2.weight)
    nn.init.zeros_(self.head2.bias)
    return

  def forward(self, batch_state_prompts):
    """Return a 1-D tensor of state values, one per prompt in batch_state_prompts."""
    tokenized= self.tokenizer(batch_state_prompts, padding=True, truncation=True, return_tensors='pt')
    input_ids=tokenized.input_ids.to(device)
    attention_mask=tokenized.attention_mask.to(device)
    token_type_ids=tokenized.token_type_ids.to(device)
    model_output=self.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, return_dict=True)
    embeddings=model_output.pooler_output
    norms = torch.norm(embeddings, dim=1, keepdim=True)
    normalized_embeddings=embeddings/norms
    #previously a fresh nn.Dropout was built on every call, so dropout stayed active even in eval mode
    hidden=self.dropout(self.relu(self.layernorm(self.head1(normalized_embeddings))))
    state_values=self.head2(hidden).squeeze(dim=1)
    return state_values

In [None]:
class RLQA_Actor(nn.Module):
  """Policy over the unselected passages of each state.

  The state prompt and each action prompt are embedded by two separate truncated
  encoders (first `n_layers` layers of the sentence-transformers checkpoint
  `sbert_name`). An action's logit is the dot product between the state pooler
  output and that action's L2-normalised pooler output. One Categorical per
  state is returned.
  """
  def __init__(self):
    super(RLQA_Actor, self).__init__()
    self.tokenizer = transformers.AutoTokenizer.from_pretrained("sentence-transformers/{:s}".format(sbert_name))
    self.state_embedder = transformers.AutoModel.from_pretrained("sentence-transformers/{:s}".format(sbert_name)).to(device)
    layers=self.state_embedder.encoder.layer
    new_layers=layers[:n_layers]
    self.state_embedder.encoder.layer=new_layers.to(device)
    #available actions=unselected passages
    self.action_embedder=transformers.AutoModel.from_pretrained("sentence-transformers/{:s}".format(sbert_name)).to(device)
    layers=self.action_embedder.encoder.layer
    new_layers=layers[:n_layers]
    self.action_embedder.encoder.layer=new_layers.to(device)

  #unit operations
  def get_single_action_prob(self, state_prompt, action_prompts, action_idx, temperature=1):
    """Probability of action_idx under the policy for a single state."""
    categorical=self.single_forward(state_prompt, action_prompts, temperature)
    action_prob=categorical.probs[action_idx]
    return action_prob

  def single_forward(self, state_prompt, action_prompts, temperature=1):
    """Categorical over action_prompts for one state.

    Routed through the batched forward so single-state and batched scoring share
    one embedding path. The encoders are HF AutoModel instances, which have no
    sentence-transformers `.encode` method.
    """
    return self.forward([state_prompt], [action_prompts], temperature)[0]

  def _make_categorical(self, logits, temperature):
    """Build an action distribution from raw logits.

    temperature == 0 yields a one-hot distribution on the argmax action. Otherwise
    the logits are divided by temperature before the softmax.
    """
    if temperature==0:
      one_hot=torch.zeros_like(logits)
      one_hot[torch.argmax(logits, dim=-1)]=1
      return torch.distributions.categorical.Categorical(probs=one_hot)
    return torch.distributions.categorical.Categorical(logits=logits/temperature)

  def get_action_probs(self, batch_state_prompts, batch_action_prompts_list, action_index_list, temperature=1):
    """Return (categoricals, probs) where probs[i] is P(action_index_list[i]) under state i."""
    categoricals=self.forward(batch_state_prompts, batch_action_prompts_list, temperature)
    action_probs=[]
    for idx, categorical in enumerate(categoricals):
      action_prob=categorical.probs[action_index_list[idx]]
      action_probs.append(action_prob.unsqueeze(dim=0))
    action_probs_tensor=torch.cat(action_probs, dim=0)
    return categoricals, action_probs_tensor

  #batch operations
  def forward(self, batch_state_prompts, batch_action_prompts_list, temperature=1):
    """One Categorical per state over that state's action prompts.

    Args:
      batch_state_prompts: list of state prompt strings.
      batch_action_prompts_list: list (parallel to the states) of lists of action
        prompt strings. The lists may have different lengths.
      temperature: 0 for a one-hot argmax distribution; otherwise the logit divisor.
    """
    #flatten every state's actions into one encoder batch and remember each [start, end) slice
    batch_action_prompts_whole=[]
    batch_action_intervals=[]
    for baps in batch_action_prompts_list:
      start=len(batch_action_prompts_whole)
      batch_action_prompts_whole.extend(baps)
      batch_action_intervals.append((start, len(batch_action_prompts_whole)))

    #state embeddings (pooler output)
    states_tokenized=self.tokenizer(batch_state_prompts, padding=True, truncation=True, return_tensors='pt')
    states_input_ids=states_tokenized.input_ids.to(device)
    states_attention_mask=states_tokenized.attention_mask.to(device)
    states_token_type_ids=states_tokenized.token_type_ids.to(device)
    state_embedder_output=self.state_embedder(input_ids=states_input_ids, attention_mask=states_attention_mask, token_type_ids=states_token_type_ids, return_dict=True)
    state_embeddings=state_embedder_output.pooler_output

    #action embeddings, L2-normalised
    actions_tokenized=self.tokenizer(batch_action_prompts_whole, padding=True, truncation=True, return_tensors='pt')
    actions_input_ids=actions_tokenized.input_ids.to(device)
    actions_attention_mask=actions_tokenized.attention_mask.to(device)
    actions_token_type_ids=actions_tokenized.token_type_ids.to(device)
    action_embedder_output=self.action_embedder(input_ids=actions_input_ids, attention_mask=actions_attention_mask, token_type_ids=actions_token_type_ids, return_dict=True)
    action_embeddings_whole=action_embedder_output.pooler_output
    action_norms=torch.norm(action_embeddings_whole, dim=1, keepdim=True)
    normalized_action_embeddings_whole=action_embeddings_whole/action_norms

    categoricals=[]
    for idx in range(state_embeddings.size(0)):
      start, end=batch_action_intervals[idx]
      action_embeddings=normalized_action_embeddings_whole[start:end]
      #NOTE(review): the state embedding is deliberately left unnormalised here (as before).
      #Its norm acts as a learnable logit scale. Normalising it would bound logits to [-1, 1]
      #and change the policy's sharpness — confirm the intent before changing.
      logits=state_embeddings[idx].matmul(action_embeddings.permute(1,0))
      categoricals.append(self._make_categorical(logits, temperature))
    return categoricals

In [None]:
class RLQA_Agent(nn.Module):
  """Actor-critic container: `actor` (passage-selection policy) and `critic` (state value)."""
  def __init__(self):
    super(RLQA_Agent, self).__init__()
    self.actor=RLQA_Actor()
    self.critic=StateValueModule() #used for GAE estimation.

  def load_actor(self, actor_name):
    """Load actor weights from base_dir/models/<actor_name>.pt."""
    self.actor.load_state_dict(torch.load(os.path.join(base_dir, 'models/{:s}.pt'.format(actor_name))))
    return

  def load_state_value_module(self, V_name):
    """Load critic weights from base_dir/models/<V_name>.pt."""
    #fixed: the critic is stored as self.critic (self.V does not exist)
    self.critic.load_state_dict(torch.load(os.path.join(base_dir, 'models/{:s}.pt'.format(V_name))))
    return

  def single_forward(self, state_prompt, action_prompts, temperature=1):
    """Pick one action for a single state; returns (action_idx, action_prob).

    temperature == 0 picks the argmax and reports probability 1. Otherwise it
    samples at the given temperature.
    """
    if temperature==0: #greedy: score at temperature 1, then take the argmax
      categorical=self.actor.single_forward(state_prompt, action_prompts, 1)
      action_idx=torch.argmax(categorical.logits, dim=-1)
      action_prob=1
    else:
      categorical=self.actor.single_forward(state_prompt, action_prompts, temperature)
      action_idx=categorical.sample()
      action_prob=categorical.probs[action_idx]
    return action_idx, action_prob

  #get action_indices & action_probs
  def forward(self, batch_state_prompts, batch_action_prompts_list, temperature=1, no_grad=False):
    """Choose one action per state.

    Args:
      temperature: 0 for greedy (argmax of the temperature-1 policy); otherwise
        the sampling temperature passed to the actor.
      no_grad: run the actor under torch.no_grad() (e.g. during rollout).

    Returns:
      (categoricals, action_indices, action_probs): per-state lists.
    """
    greedy=temperature==0
    actor_temperature=1 if greedy else temperature
    #fixed: categoricals was previously unbound whenever no_grad was False
    if no_grad:
      with torch.no_grad():
        categoricals=self.actor(batch_state_prompts, batch_action_prompts_list, actor_temperature)
    else:
      categoricals=self.actor(batch_state_prompts, batch_action_prompts_list, actor_temperature)

    action_indices=[]
    action_probs=[]
    for categorical in categoricals:
      #fixed: greedy mode previously sampled despite its "use 1 and argmax" intent
      if greedy:
        action_idx=torch.argmax(categorical.logits, dim=-1)
      else:
        action_idx=categorical.sample()
      action_probs.append(categorical.probs[action_idx])
      action_indices.append(action_idx)
    return categoricals, action_indices, action_probs

In [None]:
#used for on-policy & off-policy training
class BatchEpisodeSteps():
  """One environment step recorded across a whole batch of episodes.

  `states`, `rewards`, `state_fs`, `termin_signals` and `masks` hold one entry per
  batch slot. Policy outputs (`action_indices`, `action_probs`, `categoricals`) are
  given only for slots whose mask is falsy. They are re-expanded to full batch
  length with None in masked slots. Indexing returns the step for one slot as a dict.
  """
  def __init__(self, states, action_indices, action_probs, rewards, state_fs, termin_signals, categoricals, masks):
    self.states=states
    #"state prompt" = question + selected passages; "action prompts" = one per unselected passage
    self.state_prompts, self.action_prompts_list=self.get_actor_prompts(self.states, masks)

    #indices are relative to the slot's unselected passages
    self.action_indices=self.unpack_list(action_indices, masks)
    self.action_probs=self.unpack_list(action_probs, masks)
    self.rewards=rewards

    self.state_fs=state_fs
    #next-state prompts; termin_signals play the role of the next step's mask
    self.state_f_prompts, self.action_f_prompts_list=self.get_actor_prompts(self.state_fs, termin_signals)
    self.termin_signals=termin_signals

    #kept for KL-divergence computation
    self.categoricals=self.unpack_list(categoricals, masks)
    self.masks=masks

  def __getitem__(self, idx):
    columns=(
        ('state', self.states),
        ('state_prompt', self.state_prompts),
        ('action_prompts', self.action_prompts_list),
        ('action_index', self.action_indices),
        ('action_prob', self.action_probs),
        ('reward', self.rewards),
        ('state_f', self.state_fs),
        ('state_f_prompt', self.state_f_prompts),
        ('action_f_prompts', self.action_f_prompts_list),
        ('termin_signal', self.termin_signals),
        ('categorical', self.categoricals),
        ('mask', self.masks),
    )
    return {key: column[idx] for key, column in columns}

  def unpack_list(self, props, mask):
    """Scatter the packed `props` into a list as long as `mask`, with None where masked."""
    expanded=[]
    cursor=0
    for masked in mask:
      if masked:
        expanded.append(None)
      else:
        expanded.append(props[cursor])
        cursor+=1
    return expanded

  def unpack_tensor(self, tensor, mask):
    """Tensor analogue of unpack_list: masked slots are filled with 0."""
    expanded=torch.zeros(len(mask)).to(device)
    cursor=0
    for position, masked in enumerate(mask):
      if masked:
        continue
      expanded[position]=tensor[cursor]
      cursor+=1
    return expanded

  def get_state_prompt(self, question, selected_passages):
    """Question followed by every selected passage (or a "Not Selected" placeholder)."""
    pieces=["Question: "+question+'\n\nSelected Passages:\n']
    if len(selected_passages)==0:
      pieces.append("Not Selected\n\n")
    for psg in selected_passages:
      pieces.append("Document Title: "+psg['title']+"\nPassage: "+psg['passage']+'\n\n')
    return ''.join(pieces)

  def get_action_prompts(self, question, unselected_passages):
    """One prompt per candidate passage: the question plus that single passage."""
    header="Question: "+question+"\n\nSelected Passage:\n"
    return [header+("Document Title: "+psg['title']+"\nPassage: "+psg['passage']) for psg in unselected_passages]

  def get_actor_prompts(self, states, mask):
    """Build (state_prompts, action_prompts_list) for every state; `mask` is not consulted."""
    state_prompts, action_prompts_list=[], []
    for state in states:
      question=state['question']
      state_prompts.append(self.get_state_prompt(question, state['selected_passages']))
      action_prompts_list.append(self.get_action_prompts(question, state['unselected_passages']))
    return state_prompts, action_prompts_list

In [None]:
class Trajectory():
  """Rollout buffer for a single episode.

  Steps are appended with add_ep_step. When a step with termin_signal == True
  arrives, rewards and action probabilities become tensors on `device`, `len_traj`
  is set, undiscounted rewards-to-go are stored in `rtgs`, and `finished` is set.
  """
  def __init__(self):
    self.states=[]
    self.state_prompts=[]
    self.action_prompts_list=[]
    self.action_indices=[]
    self.action_probs=[]
    self.rewards=[]
    self.state_fs=[]
    self.state_f_prompts=[]
    self.action_f_prompts_list=[]
    self.termin_signals=[]
    self.categoricals=[]
    self.masks=[]
    self.finished=False
    self.len_traj=0

  def get_discounted_cumsum_matrix(self, N, discount_rate):
    """N x N upper-triangular matrix with entry (row, col) = discount_rate**(col-row) for col >= row."""
    dcm=torch.eye(N).to(device)
    for row in range(N):
      for col in range(row+1, N):
        dcm[row, col]=discount_rate**(col-row)
    return dcm

  def compute_rtgs(self):
    """Undiscounted reward-to-go: rtgs[t] = sum of rewards[t:]."""
    upper_ones=torch.triu(torch.ones(self.len_traj, self.len_traj)).to(device)
    self.rtgs=upper_ones.matmul(self.rewards)
    return

  def compute_gaes(self, V_s, V_fs, lambd):
    """GAE with lambda `lambd` in the undiscounted finite-horizon setting.

    The TD error is r + (1 - done) * V(s') - V(s).
    """
    not_done=1-torch.BoolTensor(self.termin_signals).float().to(device)
    td_errors=self.rewards+not_done*V_fs-V_s
    return self.get_discounted_cumsum_matrix(self.len_traj, lambd).matmul(td_errors)

  def finish_trajectory(self):
    """Convert the reward and action-probability lists into float tensors on `device`."""
    self.rewards=torch.FloatTensor(self.rewards).to(device)
    self.action_probs=torch.FloatTensor(self.action_probs).to(device)
    return

  def add_ep_step(self, ep_step):
    """Append one step dict (as produced by BatchEpisodeSteps[idx]); finalise on termination."""
    field_map=(
        ('states', 'state'),
        ('state_prompts', 'state_prompt'),
        ('action_prompts_list', 'action_prompts'),
        ('action_indices', 'action_index'),
        ('action_probs', 'action_prob'),
        ('rewards', 'reward'),
        ('state_fs', 'state_f'),
        ('state_f_prompts', 'state_f_prompt'),
        ('action_f_prompts_list', 'action_f_prompts'),
        ('termin_signals', 'termin_signal'),
        ('categoricals', 'categorical'),
        ('masks', 'mask'),
    )
    for attr_name, step_key in field_map:
      value=ep_step[step_key]
      if attr_name=='action_probs':
        #store a plain float; turned into a tensor by finish_trajectory
        value=value.item()
      getattr(self, attr_name).append(value)
    if ep_step['termin_signal']==True:
      self.finish_trajectory()
      self.len_traj=len(self.states)
      self.compute_rtgs()
      self.finished=True
    return

In [None]:
#used in training phase only
class FlatTrajectory():
  """Concatenation of per-episode Trajectory objects into one flat set of steps.

  Built from a list of BatchEpisodeSteps (one per time step). Padding steps
  (mask True) are dropped, each batch slot becomes a Trajectory, and their fields
  are concatenated. `rtgs` holds per-episode undiscounted rewards-to-go; episode
  boundaries are respected via block-diagonal matrices.
  """
  def __init__(self, batch_ep_steps_list):
    for attr_name in ('states', 'state_prompts', 'action_prompts_list', 'action_indices',
                      'action_probs', 'rewards', 'state_fs', 'state_f_prompts',
                      'action_f_prompts_list', 'termin_signals', 'masks', 'categoricals',
                      'len_trajs'):
      setattr(self, attr_name, [])
    self.n_steps=0

    self.batch_ep_steps_list=batch_ep_steps_list
    self.trajectories=self.get_trajectories(batch_ep_steps_list)
    self.get_flat_trajectory()

  def get_discounted_cumsum_matrix(self, N, discount_rate):
    """N x N upper-triangular matrix with entry (row, col) = discount_rate**(col-row) for col >= row."""
    dcm=torch.eye(N).to(device)
    for row in range(N):
      for col in range(row+1, N):
        dcm[row, col]=discount_rate**(col-row)
    return dcm

  def compute_gaes(self, V_s, V_fs, lambd):
    """GAE over all episodes; the block-diagonal matrix stops credit leaking across episodes."""
    block_diag_dcm=torch.block_diag(*[self.get_discounted_cumsum_matrix(n, lambd) for n in self.len_trajs])
    not_done=1-torch.BoolTensor(self.termin_signals).float().to(device)
    td_errors=self.rewards+not_done*V_fs-V_s
    return torch.matmul(block_diag_dcm, td_errors)

  def compute_rtgs(self, dcm_list):
    """Per-episode undiscounted rewards-to-go from block-diagonal upper-triangular ones."""
    self.rtgs=torch.matmul(torch.block_diag(*dcm_list).to(device), self.rewards)
    return

  def get_flat_trajectory(self):
    """Concatenate all finished trajectories' fields and compute rewards-to-go."""
    rtg_blocks=[]
    for traj in self.trajectories:
      for attr_name in ('states', 'state_prompts', 'action_prompts_list', 'action_indices'):
        getattr(self, attr_name).extend(getattr(traj, attr_name))
      #these two are per-trajectory tensors; collected here and concatenated below
      self.action_probs.append(traj.action_probs)
      self.rewards.append(traj.rewards)
      for attr_name in ('state_fs', 'state_f_prompts', 'action_f_prompts_list',
                        'termin_signals', 'categoricals', 'masks'):
        getattr(self, attr_name).extend(getattr(traj, attr_name))
      self.len_trajs.append(traj.len_traj)
      rtg_blocks.append(torch.triu(torch.ones(traj.len_traj, traj.len_traj)).to(device))
    self.rewards=torch.cat(self.rewards, dim=0).to(device)
    self.action_probs=torch.cat(self.action_probs, dim=0).to(device)
    self.compute_rtgs(rtg_blocks)
    return

  def get_trajectories(self, batch_ep_steps_list):
    """Regroup batched steps by batch slot, skipping padding steps (mask True)."""
    n_slots=len(batch_ep_steps_list[0].states)
    trajectories=[Trajectory() for _ in range(n_slots)]
    for batch_ep_steps in batch_ep_steps_list:
      for slot, trajectory in enumerate(trajectories):
        ep_step=batch_ep_steps[slot]
        if ep_step['mask']==False:
          trajectory.add_ep_step(ep_step)
    return trajectories

In [None]:
class IndexDataset(torch.utils.data.Dataset):
  """Map-style dataset whose i-th item is the index i itself (a numpy integer).

  Wrapped in a DataLoader, it yields (optionally shuffled) batches of example indices.
  The trainer uses these indices to address episodes of a data split.
  """
  def __init__(self, N):
    super().__init__()
    self.N=N
    #materialized table of indices 0..N-1
    self.data=list(np.arange(N))

  def __len__(self):
    return self.N

  def __getitem__(self, idx):
    return self.data[idx]

In [None]:
#using batch trajectories => memory error due to large-sized gradient tensors => compute with single trajectory.
class PPO_RLQA_Trainer():
  """PPO trainer (joint actor + critic loss) for RLQA_Agent on HotpotQA-style data.

  Three agent instances are kept:
    agent      -- the model being optimized (the only one handed to the optimizer).
    old_agent  -- rolls out training trajectories; its weights are overwritten with
                  agent's after every batch, i.e. it is PPO's "old policy".
    best_agent -- a weight copy of agent from the epoch with the highest validation
                  accuracy; used by test() and save_best_model().

  The loss is accumulated one trajectory at a time (see get_joint_loss) because a
  flattened batch trajectory ran out of memory.
  """
  #targeted for hotpot data.
  def __init__(self, hotpot_data, horizon):
    """
    Args:
      hotpot_data: mapping with 'train', 'validation' and 'test' splits; each split is
        wrapped in its own HotPotQAEnv.
      horizon: episode horizon forwarded to env.batch_reset (presumably the maximum
        number of steps per episode -- TODO confirm in HotPotQAEnv).
    """
    self.agent=RLQA_Agent()
    self.old_agent=RLQA_Agent()

    #save data
    self.train_data=hotpot_data['train']
    self.val_data=hotpot_data['validation']
    self.test_data=hotpot_data['test']

    #create environments
    self.train_env=HotPotQAEnv(self.train_data, train=True)
    self.val_env=HotPotQAEnv(self.val_data, train=False)
    self.test_env=HotPotQAEnv(self.test_data, train=False)

    self.horizon=horizon

    self.train_logs={
        #joint loss, one entry per loss-fitting iteration
        'loss_history': [],
        #average trajectory length, one entry per batch
        'len_traj_history': [],
        #kl divergence between old and current policy, one entry per fitting iteration
        'kl_div_history': [],
        'log10_kl_div_history': [],
        #average no. of optimizer updates per batch, one entry per epoch
        'avg_iter_history': [],
        #time duration of operations, one entry per batch
        'traj_gen_time_history': [],
        'iter_time_history': [],
        #return and accuracy data, one entry per epoch
        'train_return_history': [],
        'train_acc_history': [],
        'val_return_history': [],
        'val_acc_history': [],
        'best_val_return': 0,
        'best_val_acc': 0,
        'best_val_acc_epoch': 0,
        #test data
        'test_return': 0,
        'test_acc': 0,
    }

    self.best_agent=RLQA_Agent()

  def plot_train_logs(self):
    """Plot the curves collected in self.train_logs on a 3x3 grid of subplots."""
    plt.figure(figsize=(25,25))
    #plot joint (policy + value) loss
    lh=self.train_logs['loss_history']
    plt.subplot(3,3,1)
    plt.plot(np.arange(1, len(lh)+1, 1), lh)
    plt.title("Joint Loss History")
    plt.xlabel("Update Steps")
    plt.ylabel("Joint Loss")

    #plot log-scaled kl divergence
    plt.subplot(3,3,2)
    klh=self.train_logs['log10_kl_div_history']
    plt.plot(np.arange(1, len(klh)+1, 1), klh)
    plt.title("Log-Scale KL Divergence History")
    plt.xlabel("Steps")
    plt.ylabel("Log10 KL Divergence")

    #plot return history
    plt.subplot(3,3,3)
    trh=self.train_logs['train_return_history']
    vrh=self.train_logs['val_return_history']
    plt.plot(np.arange(1, len(trh)+1, 1), trh, 'b-', label='train avg return')
    plt.plot(np.arange(1, len(vrh)+1, 1), vrh, "r-", label='validation avg return')
    plt.xlabel("Epochs")
    plt.ylabel("Average Return")
    plt.title("Average Return")
    plt.legend()

    #plot accuracy history, marking the best validation epoch
    plt.subplot(3,3,4)
    tah=self.train_logs['train_acc_history']
    vah=self.train_logs['val_acc_history']
    plt.plot(np.arange(1, len(tah)+1, 1), tah, 'b-', label='train avg acc')
    plt.plot(np.arange(1, len(vah)+1, 1), vah, "r-", label='validation avg acc')
    plt.plot(self.train_logs['best_val_acc_epoch'], self.train_logs['best_val_acc'], "go")
    plt.xlabel("Epochs")
    plt.ylabel("Average Accuracy")
    plt.title("Average Accuracy")
    plt.legend()

    #plot durations of operations
    plt.subplot(3,3,5)
    tgth=self.train_logs['traj_gen_time_history']
    ith=self.train_logs['iter_time_history']
    plt.plot(np.arange(1, len(tgth)+1, 1), tgth, 'b-', label='traj. gen. time')
    plt.plot(np.arange(1, len(ith)+1, 1), ith, 'r-', label='iter time')
    plt.xlabel("Batch Steps")
    plt.ylabel("Durations")
    plt.title("Operation Durations")
    plt.legend()

    #plot average # of iterations over time
    plt.subplot(3,3,6)
    nih=self.train_logs['avg_iter_history']
    plt.plot(np.arange(1, len(nih)+1, 1), nih)
    plt.title("Loss Fitting Iters History")
    plt.xlabel("Epochs")
    plt.ylabel("No. of iters")

    #plot average of trajectory length
    plt.subplot(3,3,7)
    lth=self.train_logs['len_traj_history']
    plt.plot(np.arange(1, len(lth)+1, 1), lth)
    plt.title("Avg. Traj. Length History(per batch)")
    plt.xlabel("Batch Steps")
    plt.ylabel("Avg. Traj. Length")

    plt.show()

  def load_model(self, name):
    """Load models/hotpot_ppo/<name>.pt into agent, old_agent and best_agent.

    old_agent is loaded too, so the first rollout comes from the loaded policy rather
    than from a freshly constructed RLQA_Agent.
    """
    model_path=os.path.join(base_dir, 'models/hotpot_ppo/{:s}.pt'.format(name))
    #read the checkpoint once and copy it into every agent instance
    state_dict=torch.load(model_path)
    for target_agent in (self.agent, self.old_agent, self.best_agent):
      target_agent.load_state_dict(state_dict)
    return

  def save_best_model(self, num_epochs, batch_size, iters, lr, reg, lambd, eps, target_kl_div):
    """Save best_agent's weights to models/hotpot_ppo/<settings joined by '_'>.pt."""
    setting_string="_".join(str(value) for value in (num_epochs, batch_size, iters, lr, reg, lambd, eps, target_kl_div))
    model_dir=os.path.join(base_dir, 'models/hotpot_ppo/{:s}.pt'.format(setting_string))
    #create the folder up front rather than failing after a full training run
    os.makedirs(os.path.dirname(model_dir), exist_ok=True)
    torch.save(self.best_agent.state_dict(), model_dir)
    return

  def _evaluate(self, agent, env, data, desc, batch_size=4):
    """Roll out greedy (temperature=0) episodes for every example of ``data``.

    Returns:
      (avg_return, avg_acc), weighted by batch size so a smaller final batch is not
      over-weighted. Both are 0 when ``data`` is empty. Assumes env.batch_accuracies
      holds one entry per episode of the batch -- TODO confirm in HotPotQAEnv.
    """
    agent.eval()
    n_examples=len(data)
    index_loader=torch.utils.data.DataLoader(IndexDataset(n_examples), batch_size=batch_size, shuffle=False)
    return_sum=0.0
    acc_sum=0.0
    eval_pbar=tqdm(desc=desc, total=n_examples)
    for batch_indices in index_loader:
      n_batch=len(batch_indices)
      _,_,batch_avg_return, batch_avg_acc=self.generate_batch_trajectories(batch_indices, agent, env, temperature=0)
      return_sum+=batch_avg_return*n_batch
      acc_sum+=batch_avg_acc*n_batch
      #the last batch may be smaller than batch_size
      eval_pbar.update(n_batch)
    eval_pbar.close()
    if n_examples==0:
      return 0, 0
    return return_sum/n_examples, acc_sum/n_examples

  def test(self):
    """Evaluate best_agent on the test split and store the results in train_logs."""
    print("Testing")
    avg_return, avg_test_acc=self._evaluate(self.best_agent, self.test_env, self.test_data, "testing")
    print("Average Test Return: {:.3f}, Average Test Accuracy: {:.3f}".format(avg_return, avg_test_acc))
    #log the results
    self.train_logs['test_return']=avg_return
    self.train_logs['test_acc']=avg_test_acc
    return

  def validate(self):
    """Evaluate agent on the validation split; snapshot it into best_agent if it improved.

    NOTE(review): leaves agent in eval() mode for the following training epochs, as the
    original code did.
    """
    avg_return, avg_val_acc=self._evaluate(self.agent, self.val_env, self.val_data, "validating")
    print("Average Val. Return: {:.3f}, Average Val. Accuracy: {:.3f}".format(avg_return, avg_val_acc))
    #log the results
    self.train_logs['val_return_history'].append(avg_return)
    self.train_logs['val_acc_history'].append(avg_val_acc)
    if avg_val_acc>self.train_logs['best_val_acc']:
      self.train_logs['best_val_acc']=avg_val_acc
      self.train_logs['best_val_return']=avg_return
      self.train_logs['best_val_acc_epoch']=len(self.train_logs['val_acc_history'])
      #copy the weights: assigning self.best_agent=self.agent would alias both names to
      #one module, so later updates would silently overwrite the "best" model
      self.best_agent.load_state_dict(self.agent.state_dict())
    return

  def get_trajectories(self, batch_ep_steps_list):
    """Regroup lock-step BatchEpisodeSteps into one Trajectory per batch slot.

    Masked (dummy) steps for already-terminated episodes are skipped.
    """
    batch_size=len(batch_ep_steps_list[0].states)
    trajectories=[Trajectory() for _ in range(batch_size)]
    for batch_ep_steps in batch_ep_steps_list:
      for batch_idx in range(batch_size):
        ep_step=batch_ep_steps[batch_idx]
        #do not add dummy steps created during batch traj. generation
        if ep_step['mask']==False:
          trajectories[batch_idx].add_ep_step(ep_step)
    return trajectories

  def compute_kl_divergence(self, old_categoricals, new_categoricals):
    """Mean KL(old || new) over paired categorical distributions (0 for empty input).

    Used for early stopping of the loss-fitting iterations.
    """
    avg_kl_div=0
    for idx, old_categorical in enumerate(old_categoricals):
      new_categorical=new_categoricals[idx]
      kl_div=torch.distributions.kl.kl_divergence(old_categorical, new_categorical).item()
      #incremental mean
      avg_kl_div=avg_kl_div+(kl_div-avg_kl_div)/(idx+1)
    return avg_kl_div

  #ftn.s for using separate trajectories
  def generate_batch_trajectories(self, batch_indices, agent, env, temperature=1):
    """Roll out one episode per index in ``batch_indices`` with ``agent`` in ``env``.

    Episodes are stepped together until every one has terminated. Steps taken after an
    episode terminated are recorded with its mask set and dropped by get_trajectories.

    Returns:
      (trajectories, n_steps, avg_return, avg_acc): one Trajectory per episode, the
      total number of real steps, the mean undiscounted episode return, and the mean
      of env.batch_accuracies.
    """
    batch_size=len(batch_indices)
    batch_states=env.batch_reset(batch_indices, self.horizon)
    batch_termin_signals=[False for _ in range(batch_size)]
    batch_ep_steps_list=[]
    #until all trajectories terminate
    while sum(batch_termin_signals)!=batch_size:
      #mask marks terminated tasks: no actions are generated for them
      mask=batch_termin_signals
      batch_state_prompts, batch_action_prompts_list=env.get_batch_actor_prompts(mask)

      #categoricals are kept for the KL-divergence computation
      categoricals, action_indices, action_probs=agent(batch_state_prompts, batch_action_prompts_list, no_grad=True, temperature=temperature)

      batch_rewards, batch_state_fs, batch_termin_signals=env.batch_step(batch_size, action_indices, mask)
      batch_ep_steps=BatchEpisodeSteps(batch_states, action_indices, action_probs, batch_rewards, batch_state_fs, batch_termin_signals, categoricals, mask)
      batch_ep_steps_list.append(batch_ep_steps)
      #update state
      batch_states=batch_state_fs
    #get per-episode trajectories (a flat trajectory caused OUT OF MEMORY on Colab)
    trajectories=self.get_trajectories(batch_ep_steps_list)
    #total # of real steps inside the batch
    n_steps=sum([traj.len_traj for traj in trajectories])
    #average undiscounted return
    reward_sums=[torch.sum(traj.rewards).item() for traj in trajectories]
    avg_return=np.mean(reward_sums)
    #accuracy as reported by the environment
    avg_acc=np.mean(env.batch_accuracies)
    return trajectories, n_steps, avg_return, avg_acc

  @staticmethod
  def _safe_log10(value):
    """Return log10(|value|) as a float, or -inf when value is 0 (no numpy warning)."""
    magnitude=abs(value)
    if magnitude==0:
      return float('-inf')
    return float(np.log10(magnitude))

  def get_joint_loss(self, batch_trajectories, lambd, eps):
    """Mean of (clipped PPO policy loss + 0.5 * value MSE) over batch_trajectories.

    Args:
      batch_trajectories: Trajectory objects rolled out by old_agent.
      lambd: GAE lambda passed to Trajectory.compute_gaes.
      eps: clipping range of the probability ratio.
    Returns:
      (avg_kl_div, log10_avg_kl_div, avg_joint_loss): mean KL(old || current) over
      trajectories, log10 of its magnitude, and the mean joint loss (a 1-element
      tensor carrying gradients of self.agent).
    """
    mse_loss=nn.MSELoss()
    avg_joint_loss=torch.FloatTensor([0]).to(device)
    avg_kl_div=0
    for t_idx, trajectory in enumerate(batch_trajectories):
      #V(s) needs gradients for the value loss
      V_s=self.agent.critic(trajectory.state_prompts)
      #V(s') only enters the advantage estimate, so no graph is built for it
      with torch.no_grad():
        V_fs=self.agent.critic(trajectory.state_f_prompts)

      #advantages are constants for the policy term: detaching V(s) stops the policy
      #loss from back-propagating into the critic through the GAE
      gaes=trajectory.compute_gaes(V_s.detach(), V_fs, lambd)

      #action probs of the rollout (old) policy vs. the current policy
      old_action_probs=trajectory.action_probs
      old_categoricals=trajectory.categoricals
      new_categoricals, new_action_probs=self.agent.actor.get_action_probs(trajectory.state_prompts,
                                                        trajectory.action_prompts_list, trajectory.action_indices)
      kl_div=self.compute_kl_divergence(old_categoricals, new_categoricals)
      avg_kl_div=avg_kl_div+(kl_div-avg_kl_div)/(t_idx+1)

      #clipped surrogate policy loss
      prob_ratios=torch.div(new_action_probs, old_action_probs)
      first_term=torch.mul(prob_ratios, gaes)
      second_term=torch.mul(torch.clamp(prob_ratios, 1-eps, 1+eps), gaes)
      policy_loss=-torch.mean(torch.minimum(first_term, second_term))

      #value loss against the trajectory's rewards-to-go
      value_loss=0.5*mse_loss(V_s, trajectory.rtgs)
      joint_loss=policy_loss+value_loss
      avg_joint_loss=avg_joint_loss+(joint_loss-avg_joint_loss)/(t_idx+1)

    #log10 of |KL| for plotting (-inf when it is exactly 0)
    log10_avg_kl_div=self._safe_log10(avg_kl_div)
    return avg_kl_div, log10_avg_kl_div, avg_joint_loss

  def train(self, num_epochs, batch_size, iters, lr, reg, lambd, eps, target_kl_div):
    """Run PPO over the training split, then test best_agent and save it.

    Args:
      num_epochs: passes over the training split (validation runs after each one).
      batch_size: episodes rolled out per PPO batch.
      iters: maximum optimizer updates per batch.
      lr, reg: AdamW learning rate and weight decay.
      lambd: GAE lambda.
      eps: PPO clipping range.
      target_kl_div: fitting on a batch stops once the mean KL divergence between
        old_agent and agent exceeds this value.
    Returns:
      self.train_logs
    """
    optimizer=optim.AdamW(self.agent.parameters(), lr=lr, weight_decay=reg)
    #the scheduler is stepped once per batch, so its budget is counted in batches;
    #counting examples stretched the 20% warm-up over batch_size*20% of training
    n_batches_per_epoch=(len(self.train_data)+batch_size-1)//batch_size
    total_steps=num_epochs*n_batches_per_epoch
    scheduler=transformers.get_constant_schedule_with_warmup(optimizer, num_warmup_steps=int(0.2*total_steps))
    #rollouts come from old_agent: start it from agent's weights (they are separate
    #RLQA_Agent instances, and load_model may have changed agent)
    self.old_agent.load_state_dict(self.agent.state_dict())
    pbar=tqdm(desc="PPO training", total=num_epochs*len(self.train_data))
    for epoch in range(1, num_epochs+1):
      #running means over the batches of this epoch
      avg_iter=0
      avg_train_acc=0
      avg_train_return=0
      index_loader=torch.utils.data.DataLoader(IndexDataset(len(self.train_data)), batch_size=batch_size, shuffle=True)
      for b_idx, batch_indices in enumerate(index_loader):
        n_batch=len(batch_indices)
        traj_gen_start=time.time()
        batch_trajectories, n_steps, batch_train_return, batch_train_acc=self.generate_batch_trajectories(batch_indices, self.old_agent, self.train_env)
        #log duration of trajectory generation
        self.train_logs['traj_gen_time_history'].append(time.time()-traj_gen_start)
        #average trajectory length; the last batch may be smaller than batch_size
        self.train_logs['len_traj_history'].append(n_steps/n_batch)

        #policy gradient + critic loss joint fitting with KL-based early stopping, which
        #keeps the policy from moving too far away from old_agent
        fit_start=time.time()
        n_updates=0
        for _ in range(iters):
          kl_div, log10_kl_div, joint_loss=self.get_joint_loss(batch_trajectories, lambd, eps)
          self.train_logs['loss_history'].append(joint_loss.item())
          self.train_logs['kl_div_history'].append(kl_div)
          self.train_logs['log10_kl_div_history'].append(log10_kl_div)
          if kl_div>target_kl_div:
            break
          optimizer.zero_grad()
          joint_loss.backward()
          optimizer.step()
          n_updates+=1
        #log duration of loss fitting iterations (including an early-stopped one)
        self.train_logs['iter_time_history'].append(time.time()-fit_start)
        #record the no. of optimizer updates actually taken on this batch
        avg_iter=avg_iter+(n_updates-avg_iter)/(b_idx+1)

        #replace old agent's weight with current agent's weight
        self.old_agent.load_state_dict(self.agent.state_dict())
        #incremental update on train average metrics.
        avg_train_acc=avg_train_acc+(batch_train_acc-avg_train_acc)/(b_idx+1)
        avg_train_return=avg_train_return+(batch_train_return-avg_train_return)/(b_idx+1)
        pbar.update(n_batch)
        scheduler.step()

      #end of epoch: log train metrics and avg. no. of updates, then validate
      self.train_logs['train_acc_history'].append(avg_train_acc)
      self.train_logs['train_return_history'].append(avg_train_return)
      self.train_logs['avg_iter_history'].append(avg_iter)
      self.validate()
    pbar.close()

    #final testing
    self.test()
    self.save_best_model(num_epochs, batch_size, iters, lr, reg, lambd, eps, target_kl_div)
    return self.train_logs

In [None]:
class RLQA_tuner():
  """Grid search over PPO hyperparameters for RLQA_Agent.

  For every setting from get_settings, a fresh PPO_RLQA_Trainer is built. It is
  optionally warm-started from ``start_model_name``, then trained and plotted, and its
  train_logs are stored in tuning_logs['logs'][setting]. The trainer's best agent from
  the setting with the highest test accuracy is kept in self.best_agent.
  """
  def __init__(self, hotpot_data, horizon, start_model_name=None):
    """
    Args:
      hotpot_data: data splits handed to every PPO_RLQA_Trainer.
      horizon: episode horizon handed to every PPO_RLQA_Trainer.
      start_model_name: optional checkpoint name under models/hotpot_ppo/ loaded into
        each trainer before training; None trains from a fresh RLQA_Agent.
    """
    self.hotpot_data=hotpot_data
    self.horizon=horizon
    self.start_model_name=start_model_name

    self.tuning_logs={
        #setting tuple -> train_logs of that run
        'logs': {},
        'best_setting': [],
        'best_test_return': 0,
        'best_test_acc': 0,
    }

    self.current_trainer=None

    #placeholder until some setting beats best_test_acc
    self.best_agent=RLQA_Agent()

  def save_results(self, model_name, log_name):
    """Save best_agent's weights and the pickled tuning_logs under base_dir."""
    model_dir=os.path.join(base_dir, 'models/hotpot_ppo/{:s}.pt'.format(model_name))
    log_dir=os.path.join(base_dir, 'experiment logs/hotpot_ppo/{:s}.pkl'.format(log_name))
    #create both folders first so a missing one does not leave a half-saved result
    for path in (model_dir, log_dir):
      os.makedirs(os.path.dirname(path), exist_ok=True)
    #save model
    torch.save(self.best_agent.state_dict(), model_dir)
    #save tuning logs
    with open(log_dir, 'wb') as file:
      pickle.dump(self.tuning_logs, file)
    return

  def get_settings(self, num_epochs_list, batch_size_list, iters_list, lr_list, reg_list, lambd_list, eps_list, target_kl_div_list):
    """Return the Cartesian product of the hyperparameter lists as a list of tuples.

    Tuple order matches PPO_RLQA_Trainer.train's positional arguments; the last list
    varies fastest.
    """
    return list(itertools.product(num_epochs_list, batch_size_list, iters_list, lr_list,
                                  reg_list, lambd_list, eps_list, target_kl_div_list))

  def show_setting(self, setting):
    """Print one hyperparameter setting tuple in a readable form."""
    print("num_epochs={:d}, batch_size={:d}, iters={:d}, lr={:s}, reg={:s}, lambd={:.2f}, eps={:.2f}, target_kl_div={:s}".format(setting[0], setting[1], setting[2], str(setting[3]), str(setting[4]), setting[5], setting[6], str(setting[7])))
    return

  def tune(self, num_epochs_list, batch_size_list, iters_list, lr_list, reg_list, lambd_list, eps_list, target_kl_div_list):
    """Train one PPO_RLQA_Trainer per hyperparameter setting and return tuning_logs."""
    settings=self.get_settings(num_epochs_list, batch_size_list, iters_list, lr_list, reg_list, lambd_list, eps_list, target_kl_div_list)
    outer_pbar=tqdm(desc="PPO RLQA Tuning", total=len(settings))
    for setting in settings:
      self.show_setting(setting)
      #drop the previous trainer (and its three agents) before building the next one so
      #it can be freed first; the best agent so far stays reachable via self.best_agent
      self.current_trainer=None
      self.current_trainer=PPO_RLQA_Trainer(self.hotpot_data, self.horizon)
      if self.start_model_name is not None:
        self.current_trainer.load_model(self.start_model_name)
      train_logs=self.current_trainer.train(*setting)
      self.current_trainer.plot_train_logs()
      #add to logs
      self.tuning_logs['logs'][setting]=train_logs

      #update best model
      if train_logs['test_acc']>self.tuning_logs['best_test_acc']:
        self.tuning_logs['best_test_return']=train_logs['test_return']
        self.tuning_logs['best_test_acc']=train_logs['test_acc']
        self.tuning_logs['best_setting']=setting
        self.best_agent=self.current_trainer.best_agent

      outer_pbar.update(1)
    outer_pbar.close()
    return self.tuning_logs

In [None]:
#Hyperparameter grid for the PPO run. RLQA_tuner.get_settings takes the Cartesian
#product of these lists, so every extra value multiplies the number of training runs.
num_epochs_list=[5]
batch_size_list=[4] #episodes rolled out per PPO batch
iters_list=[5] #max. optimizer updates per batch (KL early stopping may end fitting sooner)
lr_list=[5e-5] #AdamW learning rate
reg_list=[0] #AdamW weight decay
lambd_list=[0.97] #GAE lambda
eps_list=[0.2] #PPO clipping range for the probability ratio
target_kl_div_list=[0.03] #generally 0.003 ~ 0.03
#processed_hotpot_data is defined outside this view (presumably an earlier cell) and must
#provide 'train', 'validation' and 'test' splits; horizon is forwarded to env.batch_reset.
tuner=RLQA_tuner(processed_hotpot_data, horizon=60) 
#No gamma parameter exists: TD errors and rewards-to-go are computed without a discount
#factor, i.e. gamma is effectively 1 (undiscounted finite horizon setting).
tuning_logs=tuner.tune(num_epochs_list, batch_size_list, iters_list, lr_list, reg_list, lambd_list, eps_list, target_kl_div_list)