In [1]:
import os
import gym
import math
import torch
import numpy as np
from typing import List
from tqdm.auto import tqdm
from datetime import datetime
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter
from nltk.translate.gleu_score import sentence_gleu

In [2]:
# Export TOKENIZERS_PARALLELISM=false for this process before any model code runs
# (presumably read by the Hugging Face tokenizers backend used by the encoder -- confirm).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [3]:
%cd ..
import src.envs
from src.planning import search_best_actions
from src.models.seq2labels import PretrainedEncoder, Seq2Labels
from src.utils import remove_ansi, iterative_prediction, load_json, write_json, discount_cumsum, freeze_params, scale
%cd notebooks

/home/rajk/Machine_Learning/DRL-GEC
/home/rajk/Machine_Learning/DRL-GEC/notebooks


In [4]:
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [5]:
@torch.no_grad()
def select_action(policy, state, reference, mask_generator, explore=False):
    """Choose an action for ``state`` with the planner and score it under the policy.

    Runs without gradient tracking, so the returned tensors are only suitable
    for logging, not for back-propagation.

    Args:
        policy: Callable that maps a list of states to one logits tensor per state.
        state: Current environment state (a single state, wrapped in a list for the policy).
        reference: Reference passed through to ``search_best_actions``.
        mask_generator: Mask generator passed through to ``search_best_actions``.
        explore: Forwarded to ``search_best_actions`` as its ``explore`` flag.

    Returns:
        Tuple ``(action_np, log_pi, entropy)``: the numpy action returned by the
        planner, its log-probability under the policy distribution, and the
        entropy of that distribution.
    """
    # The policy is queried with a one-element batch; unpacking enforces a single result.
    (token_logits,) = policy([state])
    token_dist = Categorical(logits=token_logits)
    # The action itself comes from the planner, not from sampling the policy distribution.
    planned_np = search_best_actions(policy, state, mask_generator, reference, explore=explore)
    planned = torch.from_numpy(planned_np).to(token_logits.device)
    return planned_np, token_dist.log_prob(planned), token_dist.entropy()

def get_log_pi(policy, state, action):
    """Return the log-probability the policy assigns to ``action`` in ``state``.

    Unlike ``select_action`` this is not wrapped in ``torch.no_grad``, so the
    result can be back-propagated through.

    Args:
        policy: Callable that maps a list of states to one logits tensor per state.
        state: A single state; it is wrapped in a one-element list for the policy.
        action: Numpy array of action indices (converted with ``torch.from_numpy``).

    Returns:
        Tensor of log-probabilities, one entry per action index.
    """
    # One-element batch in, exactly one logits tensor out.
    (logits,) = policy([state])
    action_dist = Categorical(logits=logits)
    # Move the indices next to the logits before looking up their log-probabilities.
    action_idx = torch.from_numpy(action).to(logits.device)
    return action_dist.log_prob(action_idx)

In [6]:
def get_evaluator(data_path, label_vocab, num_iterations=10):
    """Build a closure that scores a policy on a fixed evaluation set.

    Args:
        data_path: Path to a JSON file readable by ``load_json``; every entry
            must provide a ``"text"`` and a ``"references"`` field.
        label_vocab: Label vocabulary forwarded to ``iterative_prediction``.
        num_iterations: Forwarded to ``iterative_prediction`` as ``num_iter``.

    Returns:
        ``eval_func(policy)``, which returns the mean ``sentence_gleu`` score
        of the policy's predictions against the references.

    Raises:
        ValueError: If the file contains no examples.
    """
    json_data = load_json(data_path)
    if not json_data:
        # Previously surfaced as an opaque "not enough values to unpack" ValueError.
        raise ValueError(f"No evaluation examples found in {data_path}")
    sources = tuple(data_dict["text"] for data_dict in json_data)
    references = tuple(data_dict["references"] for data_dict in json_data)
    print(f"Number of evaluation examples: {len(sources)}")
    # Only the two tuples are needed by the closure; drop the parsed JSON early.
    del json_data

    def eval_func(policy):
        """Evaluate ``policy`` in eval mode and restore its previous mode afterwards."""
        # Remember the caller's mode instead of unconditionally switching to train
        # mode afterwards. Objects without a `training` flag keep the old behaviour
        # (train mode after evaluation).
        was_training = getattr(policy, "training", True)
        policy.eval()
        try:
            predictions = iterative_prediction(policy, label_vocab, sources, num_iter=num_iterations, insert_start=True, verbose=False)
            return np.mean([sentence_gleu(refs, pred) for refs, pred in zip(references, predictions)])
        finally:
            # Runs even if prediction fails, so the policy is never left stuck in eval mode.
            policy.train(was_training)

    return eval_func

In [7]:
def train(pbar, optim, grad_scaler, policy, state_action_batch, return_batch, batch_size=32, focal_alpha=0.25, focal_gamma=2.0, gamma=0.99):
    """Run one optimizer step over all collected transitions using gradient accumulation.

    For every (state, action) pair the per-element log-probabilities are weighted
    by ``focal_alpha * (1 - exp(log_pi)) ** focal_gamma`` and summed; the loss of
    a mini-batch is the negative mean of those sums multiplied by the returns.
    The weighting coefficient is not detached, so gradients flow through it too.

    Args:
        pbar: tqdm-like progress bar; it is reset and advanced by the number of items.
        optim: Optimizer stepped once at the end through ``grad_scaler``.
        grad_scaler: Gradient scaler used to scale the loss and step the optimizer.
        policy: Policy passed to ``get_log_pi``.
        state_action_batch: Sequence of ``(state, action)`` pairs.
        return_batch: Sequence of returns aligned with ``state_action_batch``.
        batch_size: Number of transitions per accumulated mini-batch.
        focal_alpha: Multiplicative factor of the focal coefficient.
        focal_gamma: Exponent of the focal coefficient.
        gamma: Unused here; accepted only for call-site compatibility.

    Returns:
        Mean of the per-mini-batch losses (each taken before division by the
        number of mini-batches).
    """
    total_items = len(return_batch)
    num_chunks = math.ceil(total_items / batch_size)
    # Re-arm the shared progress bar for this update.
    pbar.reset()
    pbar.total = total_items
    # Returns are moved to the module-level `device` once, then sliced per mini-batch.
    returns = torch.tensor(return_batch, device=device)
    optim.zero_grad()
    chunk_losses = []
    for start in range(0, total_items, batch_size):
        stop = start + batch_size
        chunk_returns = returns[start:stop]
        weighted_log_pis = []
        for state, action in state_action_batch[start:stop]:
            log_pi = get_log_pi(policy, state, action)
            focal_coef = focal_alpha * (1 - log_pi.exp()).pow(focal_gamma)
            # Collapse the per-element values of this transition into a single scalar.
            weighted_log_pis.append((focal_coef * log_pi).sum())
        pi_loss = -(torch.stack(weighted_log_pis) * chunk_returns).mean()
        # Each mini-batch contributes 1/num_chunks of its mean loss to the accumulated gradient.
        grad_scaler.scale(pi_loss / num_chunks).backward()
        chunk_losses.append(pi_loss.item())
        pbar.update(len(chunk_returns))
    grad_scaler.step(optim)
    grad_scaler.update()
    pbar.refresh()
    return np.mean(chunk_losses)

# Define parameters

In [8]:
# --- Algorithm parameters ---
eps = 0.9                          # Initial per-episode exploration probability
min_eps = 0.01                     # Lower bound for eps
eps_decay = 1                      # Multiplicative eps decay per episode (1 = no decay; alternative: 0.999995)
gamma = 0.99                       # Discount used when computing returns
focal_alpha = 1.0                  # Focal coefficient scale
focal_gamma = 2.0                  # Focal coefficient exponent

# --- Training parameters ---
cold_lr = 1e-3                     # Learning rate during the cold phase
warm_lr = 1e-5                     # Learning rate once the cold phase has ended
lr = cold_lr
batch_size = 32
update_interval = 1000             # Episodes collected between two policy updates
# Passed as `num_layers` to freeze_params when the cold phase ends.
# NOTE(review): the earlier comment read "Unfreeze only the last 98 layers", which does
# not match the value 0; the training log prints "Number of frozen parameters: 0/199"
# after that call, so 0 appears to unfreeze everything -- confirm against freeze_params.
num_unfreeze_layers = 0
dropout = 0.1
weight_decay = 0.0
episodes = 100_000
cold_episodes = 0                  # Episodes with a frozen encoder (alternative: 100_000)

# --- Evaluation parameters ---
eval_max_iter = 5
evaluate_interval = 5000
record_output_interval = 10
dev_data_path = r"../data/processed/wi+locness/dev_filtered.json"
model_path = "sl_logs/pretrain_synthetic_23:10:2022_18:26/model-best.pt"

# A run counts as fine-tuning whenever a base checkpoint is configured.
if model_path is None:
    train_type = "pretrain"
else:
    train_type = "finetune"
current_datetime = datetime.now().strftime("%d_%m_%Y_%H:%M")
log_dir = os.path.join("pg_logs", f"{train_type}_rl_{current_datetime}")

# Keyword arguments for gym.make (the "id" entry selects the environment).
env_kwargs = dict(
    id="gec_lev_dist-v0",
    datasets=["wi+locness"],             # Datasets to load
    correct_examples_percent=[1.0],      # Percentage of correct sentences to load
    repeat=1,                            # Number of repetitions of each sentence
    repeat_interval=1000,                # Interval to use for repetition
    consecutive=False,                   # Whether repetitions are consecutive or randomly distributed
)

# Free-form run notes stored next to the logs; the text is kept verbatim.
meta_data = {
    "base_model": model_path,
    "env_config": env_kwargs,
    "description": (
        "\n"
        "    Finetune the FL pretrained model. \n"
        "    Don't Detach the FL coefficient: alpha*(1-p)^gamma\n"
        "    Disable FL Coef.\n"
        "    Dropout enabled.\n"
        "    With correct sentences.\n"
        "    Reward func = lev_dist(current) - lev_dist(prev); negative reward for false negatives; \n"
        "    positive reward for success; episodes ends when all keep predicted for correct sentence\n"
        "    Don't shuffle batches-\n"
        "    Repeat sentence 1 times in interval 1000 episodes.\n"
        "    New planning: explore = random sample from candidate. exploit = weighted sample from candidate.\n"
        "    Explore-Exploit determined per episode.\n"
        "    No rescaling and Negative rewards. +1 for success.\n"
        "    "
    ),
}

# Load environment

In [9]:
# Build the training environment from env_kwargs (its "id" entry names the registered
# environment) and opt in to the five-value step API used by the training loop.
env = gym.make(**env_kwargs, new_step_api=True)

Original number of data in wi+locness: 24932
Number of data without correct sentences: 24932


# Load the models

In [10]:
model_name = "roberta-base"
tokenizer_config = {"use_fast": True}
transformer_config = {
    "output_attentions": False,
    # Optional overrides, currently disabled:
    # "hidden_dropout_prob": 0.0,
    # "attention_probs_dropout_prob": 0.0,
}
encoder = PretrainedEncoder(model_name, tokenizer_config, transformer_config).to(device)
# One output label per environment action.
policy = Seq2Labels(encoder_model=encoder, num_labels=env.action_space.n, dropout=dropout).to(device)
if model_path:
    # map_location makes a checkpoint saved on a GPU loadable when `device` fell back to
    # the CPU (or points at a different GPU); without it torch.load tries to restore the
    # tensors onto the device they were saved from.
    policy.load_state_dict(torch.load(model_path, map_location=device))
optim = torch.optim.Adam(policy.parameters(), lr=lr, weight_decay=weight_decay)
# Loss scaling is only meaningful on CUDA; disabling it explicitly on the CPU avoids
# the "CUDA is not available" warning while keeping the same scale/step/update calls.
grad_scaler = torch.cuda.amp.GradScaler(enabled=device.type == "cuda")
writer = SummaryWriter(log_dir=log_dir)
evaluator = get_evaluator(dev_data_path, env.labels, eval_max_iter)
# Store the run description next to the TensorBoard logs.
write_json(os.path.join(log_dir, "meta.json"), meta_data)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.decoder.weight', 'lm_head.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Number of evaluation examples: 3343


# Train the model

In [11]:
# Main RL fine-tuning loop: roll out one episode per iteration with the planner,
# accumulate (state, action) pairs and discounted returns, and update the policy
# every `update_interval` episodes.
policy.train()
# Freeze encoder weights for the cold phase
freeze_params(policy.encoder, requires_grad=False, optim=optim, lr=lr)
# Log hyperparameters
writer.add_scalar("hyperparameters/gamma", gamma, 0)
writer.add_scalar("hyperparameters/dropout", dropout, 0)
writer.add_scalar("hyperparameters/batch_size", batch_size, 0)
writer.add_scalar("hyperparameters/focal_alpha", focal_alpha, 0)
writer.add_scalar("hyperparameters/focal_gamma", focal_gamma, 0)
writer.add_scalar("hyperparameters/cold_episodes", cold_episodes, 0)
writer.add_scalar("hyperparameters/update_interval", update_interval, 0)
# Progress bar for policy updates; created lazily on the first update
policy_pbar = None

# Transitions and returns collected since the last policy update
return_batch = []
state_action_batch = []
mask_generator = env.mask_generator
# NOTE(review): the baseline score computed below is logged but not used to seed
# max_eval_score, so the first periodic evaluation always overwrites model-best.pt
# (any score >= 0 passes the check) -- confirm this is intended.
max_eval_score = 0
eval_score = evaluator(policy)
writer.add_scalar("rl/eval_score", eval_score, 1)
for episode in tqdm(range(1, episodes+1), desc="Training Episodes", total=episodes):
    # When the cold phase ends, switch to the warm learning rate and unfreeze encoder layers.
    # With cold_episodes == 0 this fires on the very first episode.
    if episode-1 == cold_episodes:
        lr = warm_lr
        freeze_params(policy.encoder, requires_grad=True, num_layers=num_unfreeze_layers, optim=optim, lr=lr)
    rewards = []
    log_pis = []
    entropies = []
    token_lens = []
    done = False
    init_state = state = env.reset()
    reference = env.current_reference
    # Explore/exploit is decided once per episode, not per step
    explore = np.random.uniform() < eps
    # Autocast covers the rollout only; the update in train() is called outside this context
    with torch.cuda.amp.autocast():
        while not done:
            action, log_pi, entropy = select_action(policy, state, reference, mask_generator, explore=explore)
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            # Save timestep data
            rewards.append(reward)
            log_pis.append(log_pi)
            entropies.append(entropy)
            token_lens.append(len(next_state))
            state_action_batch.append((state, action))
            state = next_state
    # Compute discounted returns for the episode and add them to the update batch
    returns = discount_cumsum(rewards, discount=gamma)
    return_batch.extend(returns)
    eps = max(eps*eps_decay, min_eps)
    # Train the model on everything collected since the previous update, then clear the buffers.
    # NOTE(review): transitions collected after the last multiple of update_interval are
    # never trained on if `episodes` is not a multiple of update_interval.
    if (episode % update_interval) == 0:
        if policy_pbar is None:
            policy_pbar = tqdm(desc="Updating Policy")
        loss = train(policy_pbar, optim, grad_scaler, policy, state_action_batch, return_batch, batch_size=batch_size, focal_alpha=focal_alpha, focal_gamma=focal_gamma, gamma=gamma)
        writer.add_scalar("rl/mean_loss", loss, episode)
        return_batch = []
        state_action_batch = []
        torch.cuda.empty_cache()
    # Log the episode output to the tensorboard
    if (episode % record_output_interval) == 0:
        render_output = "  \n".join(remove_ansi(out) for out in env.render())
        writer.add_text("rl/output", render_output, episode)
    # Evaluate the model and keep the checkpoint with the best score so far (ties overwrite)
    if (episode % evaluate_interval) == 0:
        eval_score = evaluator(policy)
        if eval_score >= max_eval_score:
            torch.save(policy.state_dict(), os.path.join(log_dir, "model-best.pt"))
            max_eval_score = eval_score
        writer.add_scalar("rl/eval_score", eval_score, episode)
    # Log scalar episode results
    rewards = np.array(rewards)
    writer.add_scalar("rl/lr", lr, episode)
    writer.add_scalar("rl/eps", eps, episode)
    writer.add_scalar("rl/explore", explore, episode)
    writer.add_scalar("rl/episode_length", len(rewards), episode)
    writer.add_scalar("rl/episode_reward_last", rewards[-1], episode)
    writer.add_scalar("rl/episode_reward_total", sum(rewards), episode)
    # writer.add_scalar("rl/episode_reward_delta", rewards[-1]-rewards[0], episode)
    # Relative change in length between the final and the initial state of the episode
    writer.add_scalar("rl/token_length_delta_ratio", (len(state)-len(init_state))/len(init_state), episode)
    # Log histogram episode results
    writer.add_histogram("rl/episode_reward", rewards, episode)
    writer.add_histogram("rl/episode_returns", returns, episode)
    writer.add_histogram("rl/episode_log_pi", torch.cat(log_pis), episode)
    writer.add_histogram("rl/episode_entropy", torch.cat(entropies), episode)
    writer.add_histogram("rl/episode_token_length", np.array(token_lens), episode)
    writer.add_histogram("rl/episode_returns_normalized", np.array(scale(returns)), episode)

Number of frozen parameters: 199/199


Training Episodes:   0%|          | 0/100000 [00:00<?, ?it/s]

Number of frozen parameters: 0/199


  logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
  logger.warn(f"{pre} is not within the observation space.")
  logger.deprecation(
  logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
  logger.warn(f"{pre} is not within the observation space.")
If you want to render in human mode, initialize the environment in this way: gym.make('EnvName', render_mode='human') and don't call the render method.
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
  logger.warn(


Updating Policy: 0it [00:00, ?it/s]

# Save model

In [12]:
# Always keep the final weights, regardless of how they scored on the dev set.
torch.save(policy.state_dict(), os.path.join(log_dir, "model-last.pt"))

# Continue with the best dev-set checkpoint. The training loop only writes it at an
# evaluation step, so a short or interrupted run may not have produced one; in that
# case the policy simply keeps the final weights instead of raising FileNotFoundError.
best_model_path = os.path.join(log_dir, "model-best.pt")
if os.path.exists(best_model_path):
    # map_location keeps the load independent of the device the checkpoint was saved from.
    policy.load_state_dict(torch.load(best_model_path, map_location=device))
else:
    print(f"No best checkpoint found at {best_model_path}; keeping the final weights.")

# Test model

In [13]:
@torch.no_grad()
def greedy_action(policy, state, labels=None, top_k=5):
    """Print the top-scoring labels for every token and return the greedy action.

    One line is printed per token: the entropy of the token's label
    distribution, the token itself, and its ``top_k`` best labels. The
    bracketed numbers are raw logits, not probabilities.

    Args:
        policy: Callable that maps a list of states to one logits tensor per state.
        state: Sequence of tokens; it is wrapped in a one-element list for the policy.
        labels: Array-like mapping label indices to label names. Defaults to
            the module-level ``env.labels``.
        top_k: Maximum number of candidate labels to print per token. It is
            clamped to the number of labels so small vocabularies do not make
            ``topk`` fail.

    Returns:
        Numpy array holding the argmax label index for every token.
    """
    # The default path uses env.labels untouched; only caller-supplied labels are
    # converted, so that lists support the fancy indexing below.
    labels = env.labels if labels is None else np.asarray(labels)
    (logits,) = policy([state])
    k = min(top_k, logits.shape[-1])
    top_logits, top_ids = logits.topk(k)
    top_logits = top_logits.cpu().numpy()
    top_ids = top_ids.cpu().numpy()
    entropy = Categorical(logits=logits).entropy().cpu().numpy()
    # labels[top_ids] yields one row of label names per token, aligned with top_logits.
    for token, token_entropy, names, scores in zip(state, entropy, labels[top_ids], top_logits):
        candidates = " -- ".join(f"{name} [{score:5.2f}]" for name, score in zip(names, scores))
        print(f"{token_entropy:4f}, {token:15}", candidates)
    print()
    # Greedy decoding; sampling from Categorical(logits=logits) would be the stochastic alternative.
    return logits.argmax(dim=-1).cpu().numpy()

In [14]:
_ = policy.eval()          # Set policy to eval mode to disable dropout; assigning to `_` keeps the returned object out of the cell output

In [15]:
# Replace the training environment with one used for qualitative testing.
# This rebinds the global `env`, which greedy_action reads its labels from.
# NOTE(review): the id differs from the "gec_lev_dist-v0" id used for training -- confirm
# both are registered by src.envs.
env = gym.make(
    "wi_locness_gec_lev_dist-v0",
    new_step_api=True,
    correct_examples_percent=[0.0],
    repeat=1,
    repeat_interval=5000,
    consecutive=False,
    min_num_refs=[1],
)

Original number of data in wi+locness: 24932
Number of data without correct sentences: 15586


In [16]:
# Roll out one test episode greedily, printing the references first and the
# rendered environment output after every step.
state = env.reset()
print("# References")
for reference_tokens in env.reference_tokens_list:
    print(reference_tokens)
print()

done = False
while not done:
    action = greedy_action(policy, state)
    # Step the environment and carry the new state straight into the next iteration.
    state, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    for rendered_line in env.render():
        print(rendered_line)

# References
['$START', 'In', 'my', 'opinion', ',', 'I', 'am', 'the', 'perfect', 'candidate', 'for', 'this', 'vacancy', '.']

0.004923, $START          $KEEP [14.93] -- $APPEND_But [ 6.21] -- $APPEND_, [ 5.60] -- $APPEND_And [ 4.12] -- $APPEND_The [ 3.98]
0.208553, In              $KEEP [11.09] -- $DELETE [ 6.85] -- $APPEND_, [ 6.34] -- $REPLACE_On [ 4.52] -- $TRANSFORM_CASE_CAPITAL [ 3.95]
0.119008, my              $KEEP [10.88] -- $DELETE [ 5.85] -- $APPEND_, [ 5.50] -- $APPEND_own [ 4.45] -- $APPEND_. [ 3.20]
0.305871, opinion         $APPEND_, [13.61] -- $KEEP [11.22] -- $REPLACE_, [ 6.68] -- $APPEND_. [ 5.68] -- $APPEND_- [ 5.59]
0.050036, I               $KEEP [11.96] -- $APPEND_, [ 5.83] -- $DELETE [ 5.61] -- $APPEND_am [ 4.05] -- $APPEND_I [ 4.04]
0.095194, am              $KEEP [11.08] -- $REPLACE_was [ 5.40] -- $APPEND_, [ 4.53] -- $DELETE [ 4.36] -- $REPLACE_'m [ 3.80]
0.048151, the             $KEEP [13.07] -- $REPLACE_a [ 7.87] -- $APPEND_a [ 5.27] -- $APPEND_the [ 5.10] -