In [1]:
import os
import gym
import torch
import numpy as np
import pandas as pd
from torch.distributions import Categorical

In [2]:
%cd ..
import src.envs
from src.utils import load_text, apply_labels
from src.models.seq2labels import PretrainedEncoder, Seq2Labels
%cd notebooks

/home/rajk/Machine_Learning/DRL-GEC
/home/rajk/Machine_Learning/DRL-GEC/notebooks


In [3]:
# Notebook-wide configuration.
# TOKENIZERS_PARALLELISM is set to "false" before any model is built
# (presumably to keep the HuggingFace tokenizers library single-threaded here).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
@torch.no_grad()
def greedy_action(policy, state, all_labels, verbose=True):
    """Pick the highest-scoring label for every token of a single state.

    Args:
        policy: Callable invoked as ``policy([state])``; it must return a
            one-element sequence whose item is the logits tensor for
            ``state``. The last axis of the logits is the label axis and the
            rows are paired one-to-one with the entries of ``state``.
        state: Sequence of tokens (only used for the verbose printout).
        all_labels: Label vocabulary, indexable by label id. May be a NumPy
            array or any sequence convertible with ``np.asarray``.
        verbose: When true, print for every token the entropy of its label
            distribution and the top candidates (up to three) as
            ``label [probability, logit]``.

    Returns:
        NumPy array holding the argmax label index for every token.
    """
    [logits] = policy([state])
    if verbose:
        # Diagnostics are only needed for printing, so skip the work otherwise.
        # Clamp k so a label space with fewer than three entries cannot make
        # ``topk`` raise.
        k = min(3, logits.shape[-1])
        top_logits, top_idx = logits.topk(k)
        dist = Categorical(logits=logits)
        # ``gather`` keeps the lookup on the logits' own device instead of
        # mixing a CPU ``arange`` / NumPy index with a possibly-CUDA tensor.
        top_probs = dist.probs.gather(-1, top_idx).cpu().numpy()
        entropy = dist.entropy().cpu().numpy()
        top_logits = top_logits.cpu().numpy()
        # ``np.asarray`` lets plain lists of labels be fancy-indexed too.
        top_labels = np.asarray(all_labels)[top_idx.cpu().numpy()]
        for a, e, labs, logs, probs in zip(state, entropy, top_labels, top_logits, top_probs):
            print(f"Entropy: {e:4f} | Label: {a:15}  |", " -- ".join(f"{lab} [{prob:3.2f}, {log:5.2f}]" for (lab, log, prob) in zip(labs, logs, probs)))
        print()
    action = logits.argmax(axis=-1)
    return action.cpu().numpy()

In [5]:
def load_model(model_path, output_size, model_name="roberta-base"):
    """Build a Seq2Labels policy and restore its weights from a checkpoint.

    Args:
        model_path: Path to a checkpoint holding the policy's ``state_dict``.
        output_size: Number of labels the policy predicts (passed to
            ``Seq2Labels`` as ``num_labels``).
        model_name: Name of the pretrained encoder handed to
            ``PretrainedEncoder``. It is loaded with ``local_files_only=True``,
            so it must already be available locally.

    Returns:
        The policy, moved to the notebook-level ``device`` and switched to
        eval mode.
    """
    encoder = PretrainedEncoder(model_name, local_files_only=True).to(device)
    policy = Seq2Labels(encoder_model=encoder, num_labels=output_size).to(device)
    # map_location makes a checkpoint saved on GPU loadable on the CPU fallback.
    # NOTE: torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    policy.load_state_dict(state_dict)
    policy.eval()
    return policy

# Load environment and labels

In [6]:
# Create the GEC environment used by every rollout below.
# `env.labels` later serves as the label vocabulary: its length sizes the
# model output and it is passed to `greedy_action` for printing.
# NOTE(review): the env id is presumably registered as a side effect of
# `import src.envs` — confirm. `correct_examples_percent=[0.0]` appears to
# drop already-correct sentences (see the counts printed below) — confirm.
env = gym.make("wi_locness_gec_lev_dist-v1", new_step_api=True, correct_examples_percent=[0.0])

Original number of data in wi+locness: 24734
Number of data without correct sentences: 15413


# Load model

In [7]:
# Resolve both checkpoint locations against the current working directory:
# first the RL-finetuned policy, then the supervised (SL) one.
rl_model_path, sl_model_path = (
    os.path.abspath(relative_path)
    for relative_path in (
        "pg_logs_new/finetune_rl_19_11_2022_13:25/model-last.pt",
        "sl_logs/finetune_wi+locness_02:11:2022_23:06/model-best.pt",
    )
)
# Both policies share the environment's label space, so they get the same output size.
rl_model, sl_model = (
    load_model(checkpoint, output_size=len(env.labels))
    for checkpoint in (rl_model_path, sl_model_path)
)

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.bias', 'lm_head.bias', 'lm_head.dense.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'roberta.pooler.dense.bias', 'lm_head.bias', 'lm_head.den

# Test Model

# SL model

In [8]:
# Greedy rollout of the supervised (SL) policy on one hand-written example.
data_dict = {
    "text": "he said in other words that the more fluoride may create damage in human body , specifically the bone .",
    "references": [
        "He said in other words that the more fluoride may create damage in the human body , specifically the bone .",
        "He said , in other words , that more fluoride may create damage in the human body , specifically the bone .",
        "He said , in other words , that more fluoride may create damage to the human body , specifically the bones .",
        "In other words , he said that more fluoride may damage the human body , specifically the bones .",
    ],
}
state = env.reset(data_dict=data_dict)
done = False
while not done:
    # greedy_action prints the per-token diagnostics because verbose=True.
    action = greedy_action(sl_model, state, env.labels, verbose=True)
    next_state, reward, terminated, truncated, info = env.step(action)
    # The episode ends on either termination or truncation.
    state, done = next_state, terminated or truncated
    outputs = env.render()
    for o in outputs:
        print(o)

  logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
  logger.warn(f"{pre} is not within the observation space.")


Entropy: 0.533590 | Label: $START           | $KEEP [0.92, 10.03] -- $APPEND_But [0.02,  6.12] -- $APPEND_And [0.01,  5.13]
Entropy: 0.337851 | Label: he               | $TRANSFORM_CASE_CAPITAL [0.93, 11.58] -- $KEEP [0.06,  8.82] -- $REPLACE_He [0.00,  6.32]
Entropy: 0.871657 | Label: said             | $APPEND_, [0.64, 10.05] -- $KEEP [0.32,  9.38] -- $REPLACE_, [0.01,  5.69]
Entropy: 0.779609 | Label: in               | $KEEP [0.81,  9.01] -- $TRANSFORM_CASE_CAPITAL [0.10,  6.89] -- $DELETE [0.07,  6.58]
Entropy: 0.595660 | Label: other            | $KEEP [0.87,  9.40] -- $DELETE [0.10,  7.21] -- $MERGE_SPACE [0.00,  4.15]
Entropy: 0.896537 | Label: words            | $KEEP [0.64,  9.58] -- $APPEND_, [0.30,  8.84] -- $DELETE [0.05,  6.94]
Entropy: 0.789471 | Label: that             | $KEEP [0.77,  9.21] -- $DELETE [0.17,  7.70] -- $REPLACE_, [0.03,  5.89]
Entropy: 0.998475 | Label: the              | $DELETE [0.51,  8.31] -- $KEEP [0.45,  8.19] -- $UNKNOWN [0.01,  4.09]
Entropy: 0.8

  logger.deprecation(
  logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
  logger.warn(f"{pre} is not within the observation space.")
If you want to render in human mode, initialize the environment in this way: gym.make('EnvName', render_mode='human') and don't call the render method.
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
  logger.warn(


# RL model

In [9]:
# Greedy rollout of the RL-finetuned policy on the same example used for the SL model.
data_dict = {
    "text": "he said in other words that the more fluoride may create damage in human body , specifically the bone .",
    "references": [
        "He said in other words that the more fluoride may create damage in the human body , specifically the bone .",
        "He said , in other words , that more fluoride may create damage in the human body , specifically the bone .",
        "He said , in other words , that more fluoride may create damage to the human body , specifically the bones .",
        "In other words , he said that more fluoride may damage the human body , specifically the bones .",
    ],
}
state = env.reset(data_dict=data_dict)
done = False
while not done:
    # greedy_action prints the per-token diagnostics because verbose=True.
    action = greedy_action(rl_model, state, env.labels, verbose=True)
    next_state, reward, terminated, truncated, info = env.step(action)
    # The episode ends on either termination or truncation.
    state, done = next_state, terminated or truncated
    outputs = env.render()
    for o in outputs:
        print(o)

Entropy: 0.774663 | Label: $START           | $KEEP [0.85, 11.90] -- $APPEND_But [0.07,  9.39] -- $APPEND_And [0.02,  8.01]
Entropy: 1.023935 | Label: he               | $TRANSFORM_CASE_CAPITAL [0.76, 11.36] -- $REPLACE_He [0.10,  9.33] -- $KEEP [0.07,  8.92]
Entropy: 1.032397 | Label: said             | $KEEP [0.49,  9.64] -- $APPEND_, [0.45,  9.56] -- $DELETE [0.03,  6.80]
Entropy: 0.962300 | Label: in               | $KEEP [0.78,  9.20] -- $DELETE [0.12,  7.29] -- $REPLACE_in [0.03,  5.85]
Entropy: 0.680887 | Label: other            | $KEEP [0.84,  9.84] -- $DELETE [0.12,  7.86] -- $REPLACE_other [0.01,  5.24]
Entropy: 0.877305 | Label: words            | $KEEP [0.63, 10.14] -- $APPEND_, [0.31,  9.43] -- $DELETE [0.04,  7.38]
Entropy: 0.684049 | Label: that             | $KEEP [0.82, 10.51] -- $DELETE [0.13,  8.66] -- $APPEND_, [0.03,  7.05]
Entropy: 1.173237 | Label: the              | $DELETE [0.65,  9.50] -- $KEEP [0.28,  8.67] -- $APPEND_more [0.00,  3.96]
Entropy: 1.249199 | La

# RL model: additional examples

In [10]:
# Greedy rollout of the RL-finetuned policy on a second example.
# NOTE(review): the input text carries an explicit "$START" token while the
# reference does not — confirm how the environment handles that mismatch.
data_dict = {
    "text": "$START Unfortunately , there is still a long way to do in terms environment concerns , but some of this solutions suggested by the mayor help us stopping the pollution .",
    "references": [
        "Unfortunately , there is still a long way to go in terms of environmental concerns , but some of these solutions suggested by the mayor help us stop the pollution .",
    ],
}
state = env.reset(data_dict=data_dict)
done = False
while not done:
    # greedy_action prints the per-token diagnostics because verbose=True.
    action = greedy_action(rl_model, state, env.labels, verbose=True)
    next_state, reward, terminated, truncated, info = env.step(action)
    # The episode ends on either termination or truncation.
    state, done = next_state, terminated or truncated
    outputs = env.render()
    for o in outputs:
        print(o)

Entropy: 0.142187 | Label: $START           | $KEEP [0.98, 13.14] -- $APPEND_But [0.01,  8.24] -- $APPEND_However [0.00,  7.45]
Entropy: 0.284960 | Label: Unfortunately    | $KEEP [0.96, 10.57] -- $DELETE [0.02,  6.58] -- $APPEND_, [0.00,  5.15]
Entropy: 0.316656 | Label: ,                | $KEEP [0.93, 10.65] -- $DELETE [0.05,  7.77] -- $APPEND_, [0.01,  5.68]
Entropy: 0.262932 | Label: there            | $KEEP [0.96, 10.90] -- $DELETE [0.02,  7.25] -- $APPEND_there [0.00,  5.00]
Entropy: 0.547691 | Label: is               | $KEEP [0.89, 10.12] -- $DELETE [0.07,  7.58] -- $REPLACE_are [0.01,  5.33]
Entropy: 0.452718 | Label: still            | $KEEP [0.92, 11.01] -- $DELETE [0.04,  7.91] -- $APPEND_much [0.02,  6.99]
Entropy: 0.424061 | Label: a                | $KEEP [0.92, 11.81] -- $DELETE [0.06,  9.04] -- $APPEND_lot [0.00,  6.57]
Entropy: 0.395132 | Label: long             | $KEEP [0.94, 10.71] -- $DELETE [0.03,  7.14] -- $APPEND_a [0.01,  5.68]
Entropy: 0.262892 | Label: way    

In [11]:
# Greedy rollout of the RL-finetuned policy on a third example
# (here both the text and its reference start with "$START").
data_dict = {
    "text": "$START I think someone should get exercise by starting play some favourite sport instead of watching To or playing game .",
    "references": [
        "$START I think people should get exercise by starting to play some favourite sport instead of watching To or playing games .",
    ],
}
state = env.reset(data_dict=data_dict)
done = False
while not done:
    # greedy_action prints the per-token diagnostics because verbose=True.
    action = greedy_action(rl_model, state, env.labels, verbose=True)
    next_state, reward, terminated, truncated, info = env.step(action)
    # The episode ends on either termination or truncation.
    state, done = next_state, terminated or truncated
    outputs = env.render()
    for o in outputs:
        print(o)

Entropy: 0.149045 | Label: $START           | $KEEP [0.98, 13.62] -- $APPEND_But [0.01,  9.30] -- $APPEND_And [0.00,  7.32]
Entropy: 0.292497 | Label: I                | $KEEP [0.96, 11.82] -- $APPEND_also [0.02,  7.78] -- $APPEND_would [0.01,  6.93]
Entropy: 0.351736 | Label: think            | $KEEP [0.94, 10.61] -- $APPEND_that [0.04,  7.34] -- $TRANSFORM_VERB_VB_VBN [0.01,  5.68]
Entropy: 1.183376 | Label: someone          | $KEEP [0.78,  9.55] -- $REPLACE_people [0.07,  7.08] -- $DELETE [0.05,  6.76]
Entropy: 0.582261 | Label: should           | $KEEP [0.92, 10.12] -- $DELETE [0.02,  6.19] -- $REPLACE_would [0.01,  5.41]
Entropy: 1.012070 | Label: get              | $KEEP [0.84, 10.13] -- $APPEND_to [0.03,  6.62] -- $REPLACE_get [0.02,  6.58]
Entropy: 0.853198 | Label: exercise         | $KEEP [0.89,  9.43] -- $DELETE [0.02,  5.61] -- $TRANSFORM_VERB_VB_VBN [0.01,  4.90]
Entropy: 0.627109 | Label: by               | $KEEP [0.89,  9.75] -- $DELETE [0.07,  7.15] -- $REPLACE_by [0.02

In [12]:
# Greedy rollout of the RL-finetuned policy on a fourth example.
data_dict = dict(
    text="I 'm planning to improve my English and I have already registered for a course .",
    references=[
        "I am planning to improve my English and I have already registered on a course ."
    ],
)
state = env.reset(data_dict=data_dict)
done = False
while not done:
    # greedy_action prints the per-token diagnostics because verbose=True.
    action = greedy_action(rl_model, state, env.labels, verbose=True)
    next_state, reward, terminated, truncated, info = env.step(action)
    # The episode ends on either termination or truncation.
    state, done = next_state, terminated or truncated
    outputs = env.render()
    for o in outputs:
        print(o)

Entropy: 0.188871 | Label: $START           | $KEEP [0.97, 13.55] -- $APPEND_But [0.01,  8.93] -- $APPEND_I [0.00,  8.05]
Entropy: 0.424570 | Label: I                | $KEEP [0.93, 10.97] -- $DELETE [0.04,  7.90] -- $APPEND_am [0.01,  6.33]
Entropy: 1.358291 | Label: 'm               | $REPLACE_am [0.47,  9.61] -- $DELETE [0.31,  9.19] -- $KEEP [0.13,  8.30]
Entropy: 0.899895 | Label: planning         | $KEEP [0.85,  9.10] -- $DELETE [0.05,  6.30] -- $TRANSFORM_VERB_VBG_VB [0.03,  5.89]
Entropy: 0.374026 | Label: to               | $KEEP [0.94, 10.36] -- $DELETE [0.03,  6.99] -- $REPLACE_on [0.00,  4.76]
Entropy: 0.937385 | Label: improve          | $KEEP [0.82,  9.23] -- $TRANSFORM_VERB_VB_VBG [0.09,  7.06] -- $DELETE [0.04,  6.17]
Entropy: 0.262185 | Label: my               | $KEEP [0.96, 10.66] -- $DELETE [0.02,  6.98] -- $REPLACE_my [0.00,  4.79]
Entropy: 0.633877 | Label: English          | $KEEP [0.82, 10.48] -- $APPEND_, [0.15,  8.75] -- $DELETE [0.02,  6.75]
Entropy: 0.283158 |