In [1]:
import os
import gym
import torch
import numpy as np
import pandas as pd
from torch.distributions import Categorical

In [2]:
%cd ..
import src.envs
from src.utils import load_text, apply_labels
from src.models.seq2labels import PretrainedEncoder, Seq2Labels
%cd notebooks

/home/rajk/Machine_Learning/DRL-GEC
/home/rajk/Machine_Learning/DRL-GEC/notebooks


In [3]:
# Set TOKENIZERS_PARALLELISM to "false" for this process (presumably read by
# the HuggingFace `tokenizers` library to disable its internal parallelism).
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
@torch.no_grad()
def greedy_action(policy, state, all_labels, verbose=True, top_k=3):
    """Choose the highest-scoring label for every token of ``state``.

    Args:
        policy: Callable invoked as ``policy([state])``. It must return a
            one-element sequence holding a 2-D logits tensor with one row per
            token and one column per label.
        state: Sequence of tokens (strings) describing the current sentence.
        all_labels: Label vocabulary; anything ``np.asarray`` can turn into an
            array that is indexable by integer label ids.
        verbose: When true, print for every token the entropy of its label
            distribution and the ``top_k`` candidate labels as
            ``label [probability, logit]``.
        top_k: Number of candidates to print per token. Clamped to the number
            of labels so that small label spaces do not make ``topk`` raise.

    Returns:
        numpy.ndarray: The arg-max label index for every row of the logits.
    """
    [logits] = policy([state])
    if verbose:
        # The diagnostics are only needed for printing, so skip the extra
        # softmax / top-k / host transfers entirely when not verbose.
        k = min(top_k, logits.shape[-1])
        dist = Categorical(logits=logits)
        top_logits, top_idx = logits.topk(k)
        # Gather on the tensor side; avoids indexing a (possibly CUDA) tensor
        # with numpy index arrays.
        top_probs = dist.probs.gather(-1, top_idx).cpu().numpy()
        top_labels = np.asarray(all_labels)[top_idx.cpu().numpy()]
        top_logits = top_logits.cpu().numpy()
        entropy = dist.entropy().cpu().numpy()
        for a, e, labs, logs, probs in zip(state, entropy, top_labels, top_logits, top_probs):
            print(f"Entropy: {e:4f} | Label: {a:15}  |", " -- ".join(f"{lab} [{prob:3.2f}, {log:5.2f}]" for (lab, log, prob) in zip(labs, logs, probs)))
        print()
    return logits.argmax(dim=-1).cpu().numpy()

In [8]:
def load_model(model_path, output_size, model_name="roberta-base"):
    """Build a ``Seq2Labels`` policy and restore its weights from a checkpoint.

    Args:
        model_path: Path to a checkpoint file containing a ``state_dict``
            compatible with ``Seq2Labels``.
        output_size: Number of labels the classification head predicts.
        model_name: Name of the pretrained encoder passed to
            ``PretrainedEncoder``. Defaults to ``"roberta-base"``.

    Returns:
        The policy, moved to the notebook-level ``device`` and put in
        evaluation mode.
    """
    # local_files_only=True is passed through to PretrainedEncoder; presumably
    # it prevents downloading the encoder weights -- confirm in src.models.
    encoder = PretrainedEncoder(model_name, local_files_only=True).to(device)
    policy = Seq2Labels(encoder_model=encoder, num_labels=output_size).to(device)
    # map_location remaps the stored tensors onto `device`, so a checkpoint
    # saved on a GPU can still be loaded when the notebook runs on CPU.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    policy.load_state_dict(state_dict)
    policy.eval()
    return policy

# Load environment and labels

In [6]:
# Create the GEC environment; its `labels` attribute provides the label
# vocabulary used by the cells below.
# new_step_api=True: `env.step` is unpacked below as a 5-tuple
# (obs, reward, terminated, truncated, info).
# NOTE(review): correct_examples_percent=[0.0] presumably drops sentences that
# are already correct (see the count printed by this cell) -- confirm in src.envs.
env = gym.make("wi_locness_gec_lev_dist-v1", new_step_api=True, correct_examples_percent=[0.0])

Original number of data in wi+locness: 24734
Number of data without correct sentences: 15413


# Load model

In [9]:
# Checkpoint locations are relative paths, so they are resolved against the
# current working directory of the kernel.
rl_model_path = os.path.abspath("pg_logs_new/finetune_rl_18_11_2022_15:00/model-last.pt")
sl_model_path = os.path.abspath("sl_logs/finetune_wi+locness_02:11:2022_23:06/model-best.pt")
# Both policies share the label space exposed by the environment.
rl_model = load_model(model_path=rl_model_path, output_size=len(env.labels))
sl_model = load_model(model_path=sl_model_path, output_size=len(env.labels))

Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.bias', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'roberta.pooler.dense.weight', 'lm_head.dense.weight', 'lm_head.decoder.weight']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: ['lm_head.layer_norm.weight', 'lm_head.bias', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.bias', 'lm_head.dense.bias', 'roberta.poole

# Test Model

# SL model

In [10]:
# Greedy rollout of the supervised (SL) policy on one hand-written example.
# After every step the environment rendering is printed line by line.
data_dict = {
    "text": "he said in other words that the more fluoride may create damage in human body , specifically the bone .",
    "references": [
        "He said in other words that the more fluoride may create damage in the human body , specifically the bone .",
        "He said , in other words , that more fluoride may create damage in the human body , specifically the bone .",
        "He said , in other words , that more fluoride may create damage to the human body , specifically the bones .",
        "In other words , he said that more fluoride may damage the human body , specifically the bones .",
    ],
}
state = env.reset(data_dict=data_dict)
done = False
while not done:
    # One label per token, chosen by arg-max; diagnostics are printed too.
    action = greedy_action(sl_model, state, env.labels, verbose=True)
    next_state, reward, terminated, truncated, info = env.step(action)
    # The episode stops on either termination or truncation.
    state, done = next_state, (terminated or truncated)
    outputs = env.render()
    for o in outputs:
        print(o)

  logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
  logger.warn(f"{pre} is not within the observation space.")


Entropy: 0.533590 | Label: $START           | $KEEP [0.92, 10.03] -- $APPEND_But [0.02,  6.12] -- $APPEND_And [0.01,  5.13]
Entropy: 0.337851 | Label: he               | $TRANSFORM_CASE_CAPITAL [0.93, 11.58] -- $KEEP [0.06,  8.82] -- $REPLACE_He [0.00,  6.32]
Entropy: 0.871657 | Label: said             | $APPEND_, [0.64, 10.05] -- $KEEP [0.32,  9.38] -- $REPLACE_, [0.01,  5.69]
Entropy: 0.779609 | Label: in               | $KEEP [0.81,  9.01] -- $TRANSFORM_CASE_CAPITAL [0.10,  6.89] -- $DELETE [0.07,  6.58]
Entropy: 0.595660 | Label: other            | $KEEP [0.87,  9.40] -- $DELETE [0.10,  7.21] -- $MERGE_SPACE [0.00,  4.15]
Entropy: 0.896537 | Label: words            | $KEEP [0.64,  9.58] -- $APPEND_, [0.30,  8.84] -- $DELETE [0.05,  6.94]
Entropy: 0.789471 | Label: that             | $KEEP [0.77,  9.21] -- $DELETE [0.17,  7.70] -- $REPLACE_, [0.03,  5.89]
Entropy: 0.998475 | Label: the              | $DELETE [0.51,  8.31] -- $KEEP [0.45,  8.19] -- $UNKNOWN [0.01,  4.09]
Entropy: 0.8

  logger.deprecation(
  logger.warn(f"{pre} should be an int or np.int64, actual type: {type(obs)}")
  logger.warn(f"{pre} is not within the observation space.")
If you want to render in human mode, initialize the environment in this way: gym.make('EnvName', render_mode='human') and don't call the render method.
See here for more information: https://www.gymlibrary.ml/content/api/[0m
  deprecation(
  logger.warn(


# RL model

In [11]:
# Greedy rollout of the RL-finetuned policy on the same sentence used for the
# SL model, so the two sets of per-token diagnostics can be compared.
data_dict = {
    "text": "he said in other words that the more fluoride may create damage in human body , specifically the bone .",
    "references": [
        "He said in other words that the more fluoride may create damage in the human body , specifically the bone .",
        "He said , in other words , that more fluoride may create damage in the human body , specifically the bone .",
        "He said , in other words , that more fluoride may create damage to the human body , specifically the bones .",
        "In other words , he said that more fluoride may damage the human body , specifically the bones .",
    ],
}
state = env.reset(data_dict=data_dict)
done = False
while not done:
    # One label per token, chosen by arg-max; diagnostics are printed too.
    action = greedy_action(rl_model, state, env.labels, verbose=True)
    next_state, reward, terminated, truncated, info = env.step(action)
    # The episode stops on either termination or truncation.
    state, done = next_state, (terminated or truncated)
    outputs = env.render()
    for o in outputs:
        print(o)

Entropy: 0.839501 | Label: $START           | $KEEP [0.84, 11.68] -- $APPEND_But [0.07,  9.23] -- $APPEND_And [0.02,  7.76]
Entropy: 1.406336 | Label: he               | $TRANSFORM_CASE_CAPITAL [0.57, 11.05] -- $REPLACE_He [0.25, 10.24] -- $KEEP [0.07,  8.88]
Entropy: 1.145620 | Label: said             | $KEEP [0.52,  9.45] -- $APPEND_, [0.40,  9.18] -- $DELETE [0.03,  6.54]
Entropy: 1.294882 | Label: in               | $KEEP [0.71,  8.97] -- $DELETE [0.11,  7.14] -- $TRANSFORM_CASE_CAPITAL [0.05,  6.37]
Entropy: 0.762814 | Label: other            | $KEEP [0.84,  9.56] -- $DELETE [0.09,  7.34] -- $APPEND_other [0.01,  5.40]
Entropy: 0.878034 | Label: words            | $KEEP [0.67,  9.87] -- $APPEND_, [0.27,  8.96] -- $DELETE [0.04,  7.02]
Entropy: 0.643609 | Label: that             | $KEEP [0.84, 10.36] -- $DELETE [0.12,  8.39] -- $APPEND_, [0.02,  6.70]
Entropy: 1.400433 | Label: the              | $DELETE [0.62,  9.26] -- $KEEP [0.29,  8.50] -- $APPEND_use [0.01,  4.54]
Entropy: 1.2

# RL model (second example)

In [12]:
# Greedy rollout of the RL-finetuned policy on a second hand-written example.
# NOTE(review): the input text starts with "$START" while its reference does
# not -- confirm which form env.reset expects.
data_dict = {
    "text": "$START Unfortunately , there is still a long way to do in terms environment concerns , but some of this solutions suggested by the mayor help us stopping the pollution .",
    "references": [
        "Unfortunately , there is still a long way to go in terms of environmental concerns , but some of these solutions suggested by the mayor help us stop the pollution .",
    ],
}
state = env.reset(data_dict=data_dict)
done = False
while not done:
    # One label per token, chosen by arg-max; diagnostics are printed too.
    action = greedy_action(rl_model, state, env.labels, verbose=True)
    next_state, reward, terminated, truncated, info = env.step(action)
    # The episode stops on either termination or truncation.
    state, done = next_state, (terminated or truncated)
    outputs = env.render()
    for o in outputs:
        print(o)

Entropy: 0.153830 | Label: $START           | $KEEP [0.98, 13.09] -- $APPEND_But [0.01,  8.42] -- $APPEND_However [0.00,  7.50]
Entropy: 0.274802 | Label: Unfortunately    | $KEEP [0.97, 10.59] -- $DELETE [0.01,  6.20] -- $APPEND_, [0.00,  5.13]
Entropy: 0.301851 | Label: ,                | $KEEP [0.94, 10.57] -- $DELETE [0.05,  7.53] -- $APPEND_, [0.01,  5.59]
Entropy: 0.301055 | Label: there            | $KEEP [0.95, 10.65] -- $DELETE [0.02,  6.80] -- $REPLACE_there [0.01,  5.44]
Entropy: 0.524787 | Label: is               | $KEEP [0.91, 10.02] -- $DELETE [0.05,  7.07] -- $REPLACE_are [0.01,  5.04]
Entropy: 0.487239 | Label: still            | $KEEP [0.91, 10.68] -- $DELETE [0.04,  7.63] -- $APPEND_much [0.02,  6.84]
Entropy: 0.428117 | Label: a                | $KEEP [0.92, 11.41] -- $DELETE [0.06,  8.66] -- $APPEND_lot [0.00,  6.15]
Entropy: 0.402745 | Label: long             | $KEEP [0.94, 10.43] -- $DELETE [0.02,  6.78] -- $APPEND_of [0.01,  5.41]
Entropy: 0.299234 | Label: way  

In [None]:
# Greedy rollout of the RL-finetuned policy on a third hand-written example.
# NOTE(review): both the text and the reference start with "$START" here,
# unlike the fluoride example above -- confirm which form env.reset expects.
data_dict = {
    "text": "$START I think someone should get exercise by starting play some favourite sport instead of watching To or playing game .",
    "references": [
        "$START I think people should get exercise by starting to play some favourite sport instead of watching To or playing games .",
    ],
}
state = env.reset(data_dict=data_dict)
done = False
while not done:
    # One label per token, chosen by arg-max; diagnostics are printed too.
    action = greedy_action(rl_model, state, env.labels, verbose=True)
    next_state, reward, terminated, truncated, info = env.step(action)
    # The episode stops on either termination or truncation.
    state, done = next_state, (terminated or truncated)
    outputs = env.render()
    for o in outputs:
        print(o)