In [1]:
import os
import gym
import math
import torch
import numpy as np
from typing import List
from tqdm.auto import tqdm
from datetime import datetime
from collections import defaultdict
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter
from nltk.translate.gleu_score import sentence_gleu

In [2]:
# Set the TOKENIZERS_PARALLELISM switch to "false" for this process.
# Done at the top of the notebook, before the model/tokenizer is loaded further down.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [3]:
%cd ..
import src.envs
from src.search import search_best_actions
from src.utils import remove_ansi, iterative_prediction, stack_padding, load_json, write_json, discount_cumsum, freeze_params, scale, is_gce_instance, load_model
%cd notebooks

/home/rajk/Machine_Learning/DRL-GEC
/home/rajk/Machine_Learning/DRL-GEC/notebooks


In [4]:
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [5]:
@torch.no_grad()
def select_action(policy, state, reference, mask_generator, explore=False):
    """Choose an action for a single state and score it under the current policy.

    Args:
        policy: Callable taking a batch (list) of states and returning logits;
            it is called here with a batch of exactly one state.
        state: The current state.
        reference: Forwarded unchanged to `search_best_actions`.
        mask_generator: Forwarded unchanged to `search_best_actions`.
        explore: Forwarded as the `explore` keyword of `search_best_actions`.

    Returns:
        Tuple `(action_np, log_pi, entropy)`: the NumPy action array produced by
        `search_best_actions`, its log-probability under the policy's
        categorical distribution, and that distribution's entropy.
    """
    # The policy is queried first, exactly once, for a one-element batch.
    (token_logits,) = policy([state])
    chosen_np = search_best_actions(policy, state, reference, mask_generator, explore=explore)
    token_dist = Categorical(logits=token_logits)
    chosen = torch.from_numpy(chosen_np).to(token_logits.device)
    return chosen_np, token_dist.log_prob(chosen), token_dist.entropy()

def get_log_pi(policy, states, actions):
    """Log-probabilities of `actions` under the policy, zeroed at padded positions.

    Args:
        policy: Callable mapping a batch of states to a rank-3 logits tensor
            (unpacked below as batch, sequence, label).
        states: Batch of states; `len(state)` gives each state's unpadded length.
        actions: Per-state action sequences, padded to a common length with
            `stack_padding`.

    Returns:
        Tensor with one log-probability per (state, position); positions at or
        beyond `len(state)` are multiplied by zero.
    """
    logits = policy(states)
    _, max_len, _ = logits.shape
    padded_actions = torch.from_numpy(stack_padding(actions, dtype="float32")).to(logits.device)
    # A position is valid when its index is smaller than the state's own length.
    lengths = torch.tensor([len(state) for state in states], device=logits.device)
    positions = torch.arange(max_len, device=logits.device)
    valid = positions.unsqueeze(0) < lengths.unsqueeze(1)
    return Categorical(logits=logits).log_prob(padded_actions) * valid

In [6]:
def get_evaluator(data_path, label_vocab, num_iterations=10):
    """Build an evaluation callback that scores a policy with sentence-level GLEU.

    Args:
        data_path: Path to a JSON file readable by `load_json`; every entry must
            provide a "text" and a "references" field.
        label_vocab: Label vocabulary forwarded to `iterative_prediction`.
        num_iterations: Forwarded to `iterative_prediction` as `num_iter`.

    Returns:
        A function `eval_func(policy)` returning the mean `sentence_gleu` score
        of the policy's predictions against the references.

    Raises:
        ValueError: If the file contains no examples.
    """
    json_data = load_json(data_path)
    if not json_data:
        # zip(*...) over nothing would otherwise fail with an opaque unpacking error.
        raise ValueError(f"No evaluation examples found in {data_path!r}")
    sources, references = zip(*((entry["text"], entry["references"]) for entry in json_data))
    print(f"Number of evaluation examples: {len(sources)}")
    del json_data

    def eval_func(policy):
        """Score `policy` in eval mode, then put it back in the mode it arrived in."""
        was_training = policy.training
        policy.eval()
        try:
            predictions = iterative_prediction(policy, label_vocab, sources, num_iter=num_iterations, insert_start=True, verbose=False)
            return np.mean([sentence_gleu(refs, pred) for refs, pred in zip(references, predictions)])
        finally:
            # Restore the caller's mode even when prediction or scoring raises.
            policy.train(was_training)

    return eval_func

In [7]:
def train(pbar, optim, grad_scaler, policy, buffer, batch_size=32, focal_alpha=0.25, focal_gamma=2.0):
    """Run one optimizer step over the whole buffer using gradient accumulation.

    For every mini-batch the per-token log-probabilities from `get_log_pi` are
    weighted by `focal_alpha * (1 - p) ** focal_gamma`, summed over tokens,
    multiplied by the stored return and averaged; the negated value is the
    batch loss. Each batch loss is divided by the number of batches before
    `backward()`, and the optimizer is stepped once at the end.

    Args:
        pbar: tqdm-style progress bar; reset and advanced by the number of items.
        optim: Optimizer stepped through `grad_scaler`.
        grad_scaler: Gradient scaler used for `scale`, `step` and `update`.
        policy: Policy passed to `get_log_pi`.
        buffer: Mapping with "state", "action" and "return" sequences. It is
            read only; the caller's "return" entry is not replaced.
        batch_size: Number of buffer items per accumulation step.
        focal_alpha: Constant factor of the focal weight.
        focal_gamma: Exponent of the focal weight.

    Returns:
        Mean of the per-batch losses as a Python float.

    Raises:
        ValueError: If the buffer holds no returns.
    """
    num_items = len(buffer["return"])
    if num_items == 0:
        # Without any backward pass the scaler has nothing to step.
        raise ValueError("Cannot train on an empty buffer")
    accumulation_size = math.ceil(num_items/batch_size)
    # Set up the progress bar
    pbar.reset()
    pbar.total = num_items
    # Keep the tensor local so the caller's buffer keeps its original contents.
    returns = torch.as_tensor(buffer["return"], device=device)
    batch_losses = []
    optim.zero_grad()
    for start in range(0, num_items, batch_size):
        stop = start + batch_size
        state_batch = buffer["state"][start:stop]
        action_batch = buffer["action"][start:stop]
        return_batch = returns[start:stop]
        log_pis = get_log_pi(policy, state_batch, action_batch)
        # Padded positions have log_pi == 0, so their focal weight is 0 as well.
        focal_coef = focal_alpha*(1-log_pis.exp()).pow(focal_gamma)
        seq_log_pis = (focal_coef*log_pis).sum(-1)                             # Sum log_probs over tokens
        pi_loss = -(seq_log_pis*return_batch).mean()
        grad_scaler.scale(pi_loss/accumulation_size).backward()
        # Detach before storing: only the value is needed for reporting, and a
        # graph-attached tensor would stay referenced for the rest of the loop.
        batch_losses.append(pi_loss.detach())
        pbar.update(len(state_batch))
    grad_scaler.step(optim)
    grad_scaler.update()
    pbar.refresh()
    return torch.stack(batch_losses).mean().item()

# Define parameters

In [8]:
# Algorithm parameters
eps = 0.5                          # Probability that an episode is run with explore=True
min_eps = 0.5                      # Lower bound for eps; only used by the commented-out decay line in the training cell
eps_decay = 0.999996               # Per-episode decay factor; only used by the commented-out decay line in the training cell
gamma = 0.9                        # Discount passed to discount_cumsum when computing returns
focal_alpha = 0.25                 # Constant factor of the focal weight used in train()
focal_gamma = 2.0                  # Exponent of the focal weight used in train()
# Training parameters
cold_lr = 1e-3                     # Learning rate while the encoder is frozen
warm_lr = 1e-5                     # Learning rate once the cold phase ends
lr = cold_lr
batch_size = 32
update_interval = 500              # Number of episodes collected between policy updates
num_unfreeze_layers = 0            # Passed to freeze_params as num_layers; NOTE(review): with 0 the run log reports 0/197 frozen parameters, so 0 appears to unfreeze everything - confirm in freeze_params
dropout = 0.1                      # Only logged to TensorBoard in this notebook; it is not passed to load_model here
weight_decay = 0.0
episodes = 200_000
cold_episodes = 0 # 20_000         (0 ends the cold phase at the first episode)
# Evaluation parameters
eval_max_iter = 10                 # Passed to get_evaluator as num_iterations
evaluate_interval = 10_000         # Episodes between evaluations on the dev set
record_output_interval = 10        # Episodes between rendered-output text logs
dev_data_path = r"../data/processed/wi+locness/dev.json"
model_path = "sl_logs/pretrain_synthetic_31:10:2022_19:25/model-best.pt"
train_type = "pretrain" if model_path is None else "finetune"
current_datetime = datetime.now().strftime("%d_%m_%Y_%H:%M")
log_dir = os.path.join("pg_logs_final", f"{train_type}_rl_{current_datetime}")
# Keyword arguments for gym.make; "id" selects the registered environment.
env_kwargs = {
    "id": "gec_lev_dist-v1",
    "datasets": ["wi+locness"],                  # Datasets to load
    "correct_examples_percent": [1.0],           # Percentage of correct sentences to load
    "reward_config": {
        "scale": 1.0,
        "correct": 1.0,
        "fn_penalty": 0.0,
        "out_of_range_penalty": -2.0,
    },
}
# Free-form run notes; written verbatim to meta.json inside log_dir.
meta_data = {
    "base_model": model_path,
    "env_config": env_kwargs,
    "description": """
    Finetune the FL (2M) pretrained model. 
    Dataset has no unsolvable examples.
    Dropout enabled.
    positive reward for success; episodes ends when all keep predicted for correct sentence
    Don't shuffle batches.
    Explore-Exploit determined per episode.
    Label = UKNONWN if len(candidate_labels) == 0
    Label = candidate_labels[0] if len(candidate_labels) == 1
    Check all append for insert edit + [keep]
    Check all labels for candidate labels for replace edits + [keep]
    Expoit = sample from candidate_probs
    Explore = Sample from candidate_lev_probs
    No return rescaling
    Early termination for false negative too.
    eval_max_iter set from 5 to 10.
    """
}

# Load environment

In [9]:
# Build the training environment; the environment id and its options all come from env_kwargs.
env = gym.make(new_step_api=True, **env_kwargs)

  logger.warn(


Original number of data in wi+locness: 24734
Number of data without correct sentences: 24734


# Load the models

In [10]:
# Base transformer name plus the tokenizer/transformer options handed to load_model.
model_name = "roberta-base"
tokenizer_config = {"use_fast": True}
transformer_config = {"output_attentions": False}
# Build the policy with one output label per environment action, loading weights
# from model_path, and move it to the selected device.
policy = load_model(
    model_name = model_name, 
    model_path = model_path, 
    num_labels = env.action_space.n,
    tokenizer_config = tokenizer_config, 
    transformer_config = transformer_config, 
    local_files_only = True,
).to(device)
optim = torch.optim.Adam(policy.parameters(), lr=lr, weight_decay=weight_decay)
grad_scaler = torch.cuda.amp.GradScaler()
writer = SummaryWriter(log_dir=log_dir)
# Evaluation callback over the dev set, using the environment's label vocabulary.
evaluator = get_evaluator(dev_data_path, env.labels, eval_max_iter)
# Store the run description next to the TensorBoard logs.
write_json(os.path.join(log_dir, "meta.json"), meta_data)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.bias', 'lm_head.dense.bias', 'roberta.pooler.dense.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.weight', 'lm_head.layer_norm.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Number of evaluation examples: 3590


# Train the model

In [11]:
policy.train()
# Freeze encoder weights
freeze_params(policy.encoder, requires_grad=False, optim=optim, lr=lr)
# Log hyperparameters
writer.add_scalar("hyperparameters/gamma", gamma, 0)
writer.add_scalar("hyperparameters/dropout", dropout, 0)
writer.add_scalar("hyperparameters/batch_size", batch_size, 0)
writer.add_scalar("hyperparameters/focal_alpha", focal_alpha, 0)
writer.add_scalar("hyperparameters/focal_gamma", focal_gamma, 0)
writer.add_scalar("hyperparameters/cold_episodes", cold_episodes, 0)
writer.add_scalar("hyperparameters/update_interval", update_interval, 0)
# Variables for training progress bars
policy_pbar = None

mask_generator = env.mask_generator
max_eval_score = 0
# Baseline score of the loaded model before any RL update
eval_score = evaluator(policy)
writer.add_scalar("rl/eval_score", eval_score, 1)
# Transitions collected since the last policy update
buffer_dict = defaultdict(list)
with torch.cuda.amp.autocast():
    for episode in tqdm(range(1, episodes+1), desc="Training Episodes", total=episodes):
        if ((cold_episodes == 0) and (episode == 1)) or (episode == cold_episodes):                # When cold epoch ends, update learning rate and unfreeze certain encoder layers 
            lr = warm_lr
            freeze_params(policy.encoder, requires_grad=True, num_layers=num_unfreeze_layers, optim=optim, lr=lr)
        rewards = []
        log_pis = []
        entropies = []
        token_lens = []
        done = False
        init_state = state = env.reset()
        reference = env.current_reference
        # Explore/exploit is decided once per episode
        explore = np.random.uniform() < eps
        while not done:
            action, log_pi, entropy = select_action(policy, state, reference, mask_generator, explore=explore)
            next_state, reward, terminated, truncated, info = env.step(action)
            done = terminated or truncated
            # Save timestep data
            rewards.append(reward)
            log_pis.append(log_pi)
            entropies.append(entropy)
            token_lens.append(len(next_state))
            buffer_dict["state"].append(state)
            buffer_dict["action"].append(action)
            state = next_state
        # Everything below runs once per episode and therefore has to stay inside
        # the episode loop: the buffer needs one return per stored state/action,
        # and the update/evaluation/logging intervals are counted in episodes.
        # Compute returns
        returns = discount_cumsum(rewards, discount=gamma)
        buffer_dict["return"].extend(returns)
        # eps = max(eps*eps_decay, min_eps)
        # Train the model
        if (episode % update_interval) == 0:
            if policy_pbar is None:
                policy_pbar = tqdm(desc="Updating Policy")
            loss = train(policy_pbar, optim, grad_scaler, policy, buffer_dict, batch_size=batch_size, focal_alpha=focal_alpha, focal_gamma=focal_gamma)
            writer.add_scalar("rl/mean_loss", loss, episode)
            # Start a fresh buffer; the collected data is used for one update only
            buffer_dict = defaultdict(list)
            torch.cuda.empty_cache()
        # Log the episode output to the tensorboard
        if (episode % record_output_interval) == 0:
            render_output = "  \n".join(remove_ansi(out) for out in env.render())
            writer.add_text("rl/output", render_output, episode)
        # Evaluate the model
        if (episode % evaluate_interval) == 0:
            eval_score = evaluator(policy)
            if eval_score >= max_eval_score:
                torch.save(policy.state_dict(), os.path.join(log_dir, "model-best.pt"))
                max_eval_score = eval_score
            writer.add_scalar("rl/eval_score", eval_score, episode)
        # Log scalar episode results
        rewards = np.array(rewards)
        writer.add_scalar("rl/lr", lr, episode)
        writer.add_scalar("rl/eps", eps, episode)
        writer.add_scalar("rl/explore", explore, episode)
        writer.add_scalar("rl/episode_length", len(rewards), episode)
        writer.add_scalar("rl/episode_reward_last", rewards[-1], episode)
        writer.add_scalar("rl/episode_reward_total", sum(rewards), episode)
        writer.add_scalar("rl/token_length_delta_ratio", (len(state)-len(init_state))/len(init_state), episode)
        # Log histogram episode results
        writer.add_histogram("rl/episode_reward", rewards, episode)
        writer.add_histogram("rl/episode_returns", returns, episode)
        writer.add_histogram("rl/episode_log_pi", torch.cat(log_pis), episode)
        writer.add_histogram("rl/episode_entropy", torch.cat(entropies), episode)
        writer.add_histogram("rl/episode_token_length", np.array(token_lens), episode)
        writer.add_histogram("rl/episode_returns_normalized", np.array(scale(returns)), episode)

Number of frozen parameters: 197/197


Training Episodes:   0%|          | 0/200000 [00:00<?, ?it/s]

Number of frozen parameters: 0/197


  logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
  logger.warn(f"{pre} is not within the observation space.")
  logger.deprecation(
  logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
  logger.warn(f"{pre} is not within the observation space.")
If you want to render in human mode, initialize the environment in this way: gym.make('EnvName', render_mode='human') and don't call the render method.
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
  logger.warn(


Updating Policy: 0it [00:00, ?it/s]

# Save model

In [12]:
# Save the weights as they are at the end of training.
torch.save(policy.state_dict(), os.path.join(log_dir, "model-last.pt"))

# Switch the in-memory policy to the checkpoint written by the training loop
# whenever the evaluation score matched or beat the best so far.
policy.load_state_dict(torch.load(os.path.join(log_dir, "model-best.pt")))

# Test model

In [13]:
@torch.no_grad()
def greedy_action(policy, state):
    """Print the five highest-scoring labels per token and return the argmax actions.

    For each token of `state` one line is printed containing the entropy of the
    policy's categorical distribution at that position, the token itself, and
    the top-5 labels (looked up in the global `env.labels`) with their raw logit
    values. A blank line follows the listing.

    Args:
        policy: Callable taking a batch (list) of states and returning logits;
            it is called here with a batch of exactly one state.
        state: Sequence of tokens to act on.

    Returns:
        NumPy array holding, for every position, the index of the largest logit.
    """
    [logits] = policy([state])
    top_values, top_indices = logits.topk(5)
    top_values = top_values.cpu().numpy()
    top_indices = top_indices.cpu().numpy()
    entropy = Categorical(logits=logits).entropy().cpu().numpy()
    for token, token_entropy, labels, values in zip(state, entropy, env.labels[top_indices], top_values):
        print(f"{token_entropy:4f}, {token:15}", " -- ".join(f"{l} [{p:5.2f}]" for (l, p) in zip(labels, values)))
    print()
    # Greedy choice; sampling from Categorical(logits=logits) would be the stochastic alternative.
    action = logits.argmax(axis=-1)
    return action.cpu().numpy()

In [14]:
_ = policy.eval()          # Set policy to eval mode to disable dropout; `_ =` keeps the returned module out of the cell output

In [15]:
# Environment for manual testing. This rebinds the global `env` used during training,
# so greedy_action reads its label list from this environment from here on.
env = gym.make("gec_lev_dist-v2", datasets=["wi+locness"], new_step_api=True, correct_examples_percent=[0.0], repeat=1)

Original number of data in wi+locness: 24734
Number of data without correct sentences: 15413


In [18]:
# Roll out one episode with the greedy policy, printing the references first
# and the rendered environment output after every step.
state = env.reset()
print("# References")
for ref in env.reference_tokens_list:
    print(ref)
print()
done = False
while not done:
    # greedy_action also prints the per-token top-5 labels for this step
    action = greedy_action(policy, state)
    next_state, reward, terminated, truncated, info = env.step(action)
    done = terminated or truncated
    state = next_state
    outputs = env.render()
    for o in outputs:
        print(o)

# References
['$START', 'He', 'has', 'been', 'one', 'of', 'the', 'players', 'of', 'basketball', 'from', 'when', 'he', 'was', 'an', 'elementary', 'student', 'until', 'now', 'he', "'s", 'in', 'college', '.']

0.534045, $START          $KEEP [10.02] -- $APPEND_But [ 5.73] -- $APPEND_The [ 4.94] -- $DELETE [ 4.76] -- $APPEND_He [ 4.60]
0.902089, He              $KEEP [ 9.53] -- $APPEND_has [ 7.14] -- $DELETE [ 6.24] -- $REPLACE_He [ 5.91] -- $APPEND_had [ 5.63]
2.138970, is              $APPEND_been [ 8.63] -- $KEEP [ 8.02] -- $REPLACE_been [ 7.83] -- $REPLACE_has [ 7.83] -- $REPLACE_was [ 7.60]
0.878942, one             $KEEP [ 9.14] -- $DELETE [ 6.65] -- $APPEND_been [ 4.64] -- $REPLACE_been [ 4.13] -- $REPLACE_one [ 3.12]
0.516260, of              $KEEP [ 9.62] -- $DELETE [ 7.10] -- $APPEND_the [ 4.19] -- $APPEND_of [ 3.22] -- $APPEND_been [ 2.70]
1.224080, the             $KEEP [ 9.80] -- $DELETE [ 6.63] -- $APPEND_good [ 5.83] -- $APPEND_first [ 5.20] -- $APPEND_last [ 4.80]
1.166492,

# Close Google Compute Instance

In [17]:
# When is_gce_instance() reports true, shell out to gcloud to stop the
# "drl-gec" instance in zone us-west1-b (instance name and zone are hard-coded).
if is_gce_instance():
    !gcloud compute instances stop drl-gec --zone us-west1-b