In [1]:
%load_ext autoreload
%autoreload 2
from data.dataloader import get_tokenizer
from model.model import GPT
from data.utils import Trie
import pickle
import numpy as np 
import torch

from data.dataloader import SearchData
from torch.utils.data import DataLoader
from model.env import Env

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# One tokenizer is shared by both models; its length is used as the vocabulary size.
tokenizer = get_tokenizer()

# The actor and the reader are built from identical constructor arguments,
# so those arguments are spelled out once and reused.
gpt_kwargs = dict(
    tokenizer=tokenizer,
    vocab_size=len(tokenizer),
    block_size=100,
)
actor = GPT(**gpt_kwargs).cuda()
# NOTE(review): `reader` is freshly constructed here and no checkpoint is
# loaded for it in the visible cells — confirm that is intended.
reader = GPT(**gpt_kwargs).cuda()

# Only the actor's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(params=actor.parameters(), lr=1e-4)


In [3]:
# SECURITY: pickle.load can execute arbitrary code while deserialising;
# only point these paths at files from a trusted source.
with open('data/trie.pkl', 'rb') as trie_file:
    trie = pickle.load(trie_file)

with open('data/dataset.pkl', 'rb') as dataset_file:
    dataset = pickle.load(dataset_file)

# The unpickled object is unpacked as a pair; only its second element is kept,
# under the name `database`.
_, database = dataset

In [4]:
# Build the SearchData dataset from data/train.pkl and the shared tokenizer.
# This rebinds `dataset`, replacing the object unpickled from data/dataset.pkl.
dataset = SearchData('data/train.pkl', tokenizer)
# Batches of 32, served in fixed dataset order (shuffle=False) from the main
# process (num_workers=0).
# NOTE(review): confirm that unshuffled batches are intended for training.
loader = DataLoader(dataset, batch_size=32, shuffle=False, num_workers=0)

In [5]:
# Construct the environment from the reader model, the database, the trie and
# the tokenizer, then call .cuda() on it.
# NOTE(review): max_len=100 equals the block_size=100 passed to GPT —
# presumably the two must stay in sync; confirm.
env = Env(reader, database, trie, tokenizer, max_len=100).cuda()

In [14]:
from tqdm import tqdm
def train_policy_gradient(
        act_model,
        act_opt,
        train_loader,
        env,
        gamma=0.99,
        temperature=1.0,
        top_k=None,
        top_p=None,
        writer=None,
        device="cuda"):
    """Run one REINFORCE-style policy-gradient pass over ``train_loader``.

    For every batch the environment is reset, the policy is rolled out until
    every batch element reports ``done``, discounted returns are computed, the
    per-sample mean return over the time steps is subtracted as a baseline,
    and a single optimizer step is taken on the resulting loss.

    Args:
        act_model: Policy exposing ``predict(obs, indices, current_nodes,
            trie, temperature, top_k, top_p)`` which returns
            ``(action_logits, action_probs, log_prob, action)``. ``log_prob``
            must be differentiable w.r.t. the parameters held by ``act_opt``.
        act_opt: Optimizer; ``zero_grad()`` and ``step()`` are called once
            per batch.
        train_loader: Iterable yielding ``(data, indices, gt_y)`` tensors.
        env: Environment exposing ``trie``, ``reset(data, indices, gt_y)``
            and ``step(obs, indices, current_nodes, action, done)``.
        gamma: Discount factor applied when accumulating returns.
        temperature, top_k, top_p: Forwarded unchanged to
            ``act_model.predict``.
        writer: Optional logger with ``add_scalar(tag, value, step)`` (for
            example a TensorBoard ``SummaryWriter``). Skipped when ``None``.
        device: Device the batch tensors are moved to. The default,
            ``"cuda"``, matches the previously hard-coded ``.cuda()`` calls.

    Returns:
        None. The model is updated in place through ``act_opt``.
    """
    # Iterate the bar itself so it advances with the loop and is closed when
    # the loader is exhausted.
    pbar = tqdm(train_loader)
    for i, (data, indices, gt_y) in enumerate(pbar):
        data = data.to(device)
        indices = indices.to(device)
        gt_y = gt_y.to(device)

        obs, indices, current_nodes = env.reset(data, indices, gt_y)
        done = torch.zeros(len(data), dtype=torch.bool, device=obs.device)
        episode_reward = torch.zeros(len(data), device=data.device)
        log_probs = []
        rewards = []

        # NOTE(review): log-probs from steps taken after an element has
        # finished are not masked here — confirm env.step / predict handle
        # already-finished elements as intended.
        while not done.all():
            _, _, log_prob, action = act_model.predict(
                obs, indices, current_nodes, env.trie, temperature, top_k, top_p)
            # Detach and copy before stepping — presumably env.step modifies
            # these tensors in place and the graph built by predict must not
            # observe that; TODO confirm.
            obs = obs.detach().clone()
            indices = indices.detach().clone()
            with torch.no_grad():
                obs, indices, current_nodes, reward, done, _ = env.step(
                    obs, indices, current_nodes, action, done)
            episode_reward += reward
            log_probs.append(log_prob)
            rewards.append(reward.detach())

        if not rewards:
            # Only reachable for an empty batch (``done.all()`` is True on an
            # empty mask); torch.stack would raise on the empty lists below.
            continue

        # Discounted returns, accumulated back-to-front and then reversed so
        # the work is linear in the episode length.
        returns = []
        G = 0
        for r in reversed(rewards):
            G = r + gamma * G
            returns.append(G)
        returns.reverse()
        returns = torch.stack(returns, -1)  # (B, T)
        log_probs = torch.stack(log_probs, -1)

        # Baseline: each sample's own mean return over its time steps.
        advantages = returns - returns.mean(-1, keepdim=True)
        policy_loss = -(log_probs * advantages.detach()).mean()

        act_opt.zero_grad()
        policy_loss.backward()
        act_opt.step()

        loss_value = policy_loss.item()
        mean_reward = episode_reward.mean().item()
        if writer is not None:
            writer.add_scalar('Loss/policy_loss', loss_value, i)
            # add_scalar needs a scalar; episode_reward holds one entry per
            # batch element, so log its batch mean.
            writer.add_scalar('Reward/episode_reward', mean_reward, i)

        pbar.set_description(f"Episode {i+1}: reward={mean_reward}, policy_loss={loss_value}")

In [15]:
# Single pass over `loader` with the function's default gamma and temperature,
# top_k/top_p left at None, and no summary writer.
train_policy_gradient(actor, optimizer, loader, env)

Episode 69: reward=-2.9510374069213867, policy_loss=0.03739551827311516:   3%|▎         | 69/2500 [01:08<39:27,  1.03it/s] 