In [1]:
import os
import torch
import numpy as np
import pandas as pd
import altair as alt
from IPython.display import display

In [2]:
%cd ..
from src.sampler import EditMaskGenerator
from src.models.seq2labels import PretrainedEncoder, Seq2Labels
from src.planning import search_best_actions, get_lev_dist_of_next_tokens
from src.utils import load_text, get_lev_dist, apply_labels, apply_labels_at
%cd notebooks

/home/rajk/Machine_Learning/DRL-GEC
/home/rajk/Machine_Learning/DRL-GEC/notebooks


In [3]:
# The candidate charts plot one point per label (5001 labels), which exceeds
# Altair's default 5000-row limit. The trailing ";" hides the registry repr this returns.
alt.data_transformers.disable_max_rows();

DataTransformerRegistry.enable('default')

In [4]:
def check(policy, tokens, references, mask_generator, max_attempts=10, explore=False, verbose=False):
    """Iteratively correct `tokens` towards one randomly chosen reference and report the outcome.

    A reference is drawn uniformly from `references`. While `tokens` differs from it,
    `search_best_actions` proposes actions, `mask_generator.actions_to_labels` converts
    them to labels, and `apply_labels` applies them. The labels and the resulting tokens
    are printed after every step. The loop stops after `max_attempts` steps; if the
    tokens start out different, at least one step is always taken. Finally, prints
    "OK!" if the tokens equal any entry of `references`, otherwise "Not OK!".

    Args:
        policy: policy forwarded to `search_best_actions`.
        tokens: list of source tokens.
        references: non-empty list of acceptable reference token lists.
        mask_generator: forwarded to `search_best_actions` and used to turn actions into labels.
        max_attempts: maximum number of correction steps.
        explore: forwarded to `search_best_actions`.
        verbose: forwarded to `search_best_actions`.

    Returns:
        None; the result is reported on stdout.
    """
    current_ref = references[np.random.choice(len(references))]
    print(f"Reference: {current_ref}")
    print(f"Source: {tokens}")
    attempt = 0
    while tokens != current_ref:
        actions = search_best_actions(policy, tokens, current_ref, mask_generator, explore=explore, verbose=verbose)
        labels = mask_generator.actions_to_labels(actions)
        tokens = apply_labels(tokens, labels)
        attempt += 1
        print()
        print(f"Labels: {labels}")
        print(f"New Tokens: {tokens}")
        if attempt >= max_attempts:
            break
    # Success is judged against every reference, not only the one targeted above.
    print("OK!" if tokens in references else "Not OK!")
        
        
def visualize_candidates(policy, mask_generator, state, reference, tok_i, eps=0.0):
    """Show which labels would bring the token at `tok_i` closer to `reference`.

    Each label in `mask_generator.labels` is applied on its own at position `tok_i`
    (via `apply_labels_at`). The resulting drop in `get_lev_dist` to `reference` is
    recorded as "Delta_Lev". Labels with a positive drop, plus "$KEEP", are the
    candidates. If there is more than one candidate, one is selected: uniformly at
    random with probability `eps`, otherwise the candidate with the highest policy
    logit. The candidate rows are then displayed as a table, and Delta_Lev of every
    label is plotted.

    Args:
        policy: callable taking a one-element batch `[state]` and returning a
            one-element list `[logits]`; `logits[tok_i]` is expected to be a torch
            tensor with one entry per label.
        mask_generator: provides `labels`, `get_edit_mask` and `labels_to_actions`.
        state: list of source tokens.
        reference: list of reference tokens.
        tok_i: index of the token in `state` to inspect.
        eps: probability of exploring (random candidate) instead of exploiting.

    Returns:
        None; everything is printed or rendered with `display`.
    """
    print(f"# Checking the labels of the token '{state[tok_i]}' at index '{tok_i}'.")

    edit_mask = mask_generator.get_edit_mask(state, [reference])
    print(f"# Edit type: {edit_mask[tok_i]}")

    with torch.no_grad():
        [logits] = policy([state])
    tok_logits = logits[tok_i]
    tok_probs = tok_logits.softmax(0).cpu().numpy()

    # Reduction in Levenshtein distance achieved by each label applied alone at tok_i.
    orig_lev_dist = get_lev_dist(state, reference)
    distances = np.array([
        get_lev_dist(apply_labels_at(state, [label], [tok_i]), reference)
        for label in mask_generator.labels
    ])
    delta_levs = orig_lev_dist - distances
    # "$KEEP" does not reduce the distance, so the filter excludes it;
    # append it so that keeping the token is always an option.
    candidate_labels = np.append(mask_generator.labels[delta_levs > 0], "$KEEP")
    print(f"# Candidate labels: {candidate_labels}")
    if len(candidate_labels) == 1:
        # Only "$KEEP" is left: nothing to choose between or plot.
        return

    if np.random.uniform() < eps:
        print("# Exploring")
        tok_label = np.random.choice(candidate_labels)
    else:
        print("# Exploiting")
        candidate_actions = mask_generator.labels_to_actions(list(candidate_labels))
        tok_label = candidate_labels[tok_logits[candidate_actions].argmax().item()]
    print(f"# Selected label: {tok_label}")
    print()

    # Build the frame column by column so every column keeps its own dtype (a
    # transposed list-of-rows frame makes all columns `object`). Set the float
    # probabilities at construction instead of writing them into an int column.
    df = pd.DataFrame({
        "Labels": list(mask_generator.labels),
        "Delta_Lev": delta_levs,
        "Label_Index": np.arange(len(delta_levs)),
        "Probs": tok_probs,
    })
    display(df[df["Labels"].isin(candidate_labels.tolist())])
    display(alt.Chart(df).mark_circle(size=60).encode(
        x=alt.X("Label_Index", sort=None),
        y="Delta_Lev:Q",
        color="Delta_Lev:N",
        tooltip=["Label_Index", "Labels", "Probs"],
    ).properties(title=f"Delta_Lev per label for token '{state[tok_i]}' (index {tok_i})"))

In [5]:
class RandomPolicy:
    """Model-free stand-in for the Seq2Labels policy that returns random logits.

    Lets the search code run without loading any weights. Every logit is drawn
    from the global numpy RNG as an integer in [-100, 100) divided by 10, i.e.
    from {-10.0, -9.9, ..., 9.9}.
    """

    def __init__(self, num_labels):
        # Number of logits produced for every token.
        self.num_labels = num_labels

    def __call__(self, tokens):
        """Return `[logits]` for a batch holding exactly one token sequence.

        Args:
            tokens: batch of length 1; `tokens[0]` is the token list.

        Returns:
            One-element list with a float64 tensor of shape
            `(len(tokens[0]), num_labels)`.
        """
        assert len(tokens) == 1
        sentence = tokens[0]
        shape = (len(sentence), self.num_labels)
        scores = np.random.randint(-100, 100, size=shape) / 10.0
        return [torch.from_numpy(scores)]

# Load the label vocabulary

In [6]:
# Label vocabulary wrapped in a numpy chararray, which supports boolean-mask
# indexing such as `labels[delta_levs > 0]`.
label_vocab = np.char.array(load_text("../data/vocabs/labels.txt"))
mask_generator = EditMaskGenerator(label_vocab)
print(f"Number of labels: {len(label_vocab)}")

Number of labels: 5001


In [7]:
# Set model_path to a checkpoint to inspect a trained policy; leave it as None to
# fall back to RandomPolicy, which needs no weights.
model_path = None # os.path.abspath("pg_logs/finetune_rl_07_11_2022_13:51/model-last.pt")
if model_path:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    encoder = PretrainedEncoder("roberta-base").to(device)
    policy = Seq2Labels(encoder_model=encoder, num_labels=len(label_vocab)).to(device)
    # map_location lets a checkpoint saved on GPU load on the CPU fallback device.
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    policy.load_state_dict(torch.load(model_path, map_location=device))
    policy.eval()
else:
    policy = RandomPolicy(len(label_vocab))

# Timing `search_best_actions`

In [8]:
%%time
tokens = "$START Well , that 's all for now I hope to hear about you soon .".split()
reference = "$START Well , that 's all for now . I hope to hear from you soon .".split()
actions = search_best_actions(policy, tokens, reference, mask_generator, verbose=False)
print(actions)

[ 0  0  0  0  0  0  0 22  0  0  0  0 66  0  0  0]
CPU times: user 19.8 ms, sys: 0 ns, total: 19.8 ms
Wall time: 19.3 ms


# Test the correction algorithm on each label type

## Keep label

In [9]:
# The source already equals the reference, so check should succeed without any step.
state = "$START This is fine .".split()
ref_list = ["$START This is fine .".split()]
check(policy, state, ref_list, mask_generator=mask_generator)

Reference: ['$START', 'This', 'is', 'fine', '.']
Source: ['$START', 'This', 'is', 'fine', '.']
OK!


## Replace label

In [10]:
# Only the final token differs ("?" vs "."), which a single $REPLACE_ label fixes.
state = "$START This is fine ?".split()
ref_list = ["$START This is fine .".split()]
check(policy, state, ref_list, mask_generator=mask_generator)

Reference: ['$START', 'This', 'is', 'fine', '.']
Source: ['$START', 'This', 'is', 'fine', '?']

Labels: ['$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP']
New Tokens: ['$START', 'This', 'is', 'fine', '?']

Labels: ['$KEEP' '$KEEP' '$KEEP' '$KEEP' '$REPLACE_.']
New Tokens: ['$START', 'This', 'is', 'fine', '.']
OK!


## Append label

In [11]:
# The article "a" is missing before "big"; expected fix: $APPEND_a on "in".
state = "$START He lives in big house .".split()
ref_list = ["$START He lives in a big house .".split()]
check(policy, state, ref_list, mask_generator=mask_generator)

Reference: ['$START', 'He', 'lives', 'in', 'a', 'big', 'house', '.']
Source: ['$START', 'He', 'lives', 'in', 'big', 'house', '.']

Labels: ['$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP']
New Tokens: ['$START', 'He', 'lives', 'in', 'big', 'house', '.']

Labels: ['$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP']
New Tokens: ['$START', 'He', 'lives', 'in', 'big', 'house', '.']

Labels: ['$KEEP' '$KEEP' '$KEEP' '$APPEND_a' '$KEEP' '$KEEP' '$KEEP']
New Tokens: ['$START', 'He', 'lives', 'in', 'a', 'big', 'house', '.']
OK!


## Delete label

In [19]:
# "very" is superfluous; expected fix: $DELETE on that token.
state = "$START It is a very nice place .".split()
ref_list = ["$START It is a nice place .".split()]
check(policy, state, ref_list, mask_generator=mask_generator)

Reference: ['$START', 'It', 'is', 'a', 'nice', 'place', '.']
Source: ['$START', 'It', 'is', 'a', 'very', 'nice', 'place', '.']

Labels: ['$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP']
New Tokens: ['$START', 'It', 'is', 'a', 'very', 'nice', 'place', '.']

Labels: ['$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP']
New Tokens: ['$START', 'It', 'is', 'a', 'very', 'nice', 'place', '.']

Labels: ['$KEEP' '$KEEP' '$KEEP' '$KEEP' '$DELETE' '$KEEP' '$KEEP' '$KEEP']
New Tokens: ['$START', 'It', 'is', 'a', 'nice', 'place', '.']
OK!


## Merge hyphen label

In [13]:
# Several edits at once: pluralise the subject, fix the verb and hyphenate "cold blooded".
state = "$START Tiger is cold blooded animal .".split()
ref_list = [
    "$START Tigers are cold-blooded animals .".split(),
    # Disabled alternative reference: "$START Tiger is a cold - blooded animal .".split(),
]
check(policy, state, ref_list, mask_generator=mask_generator)

Reference: ['$START', 'Tigers', 'are', 'cold-blooded', 'animals', '.']
Source: ['$START', 'Tiger', 'is', 'cold', 'blooded', 'animal', '.']

Labels: ['$KEEP' '$TRANSFORM_AGREEMENT_PLURAL' '$MERGE_HYPHEN' '$APPEND_are'
 '$KEEP' '$REPLACE_animals' '$KEEP']
New Tokens: ['$START', 'Tigers', 'is-cold', 'are', 'blooded', 'animals', '.']

Labels: ['$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP']
New Tokens: ['$START', 'Tigers', 'is-cold', 'are', 'blooded', 'animals', '.']

Labels: ['$KEEP' '$KEEP' '$DELETE' '$KEEP' '$KEEP' '$KEEP' '$KEEP']
New Tokens: ['$START', 'Tigers', 'are', 'blooded', 'animals', '.']

Labels: ['$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP']
New Tokens: ['$START', 'Tigers', 'are', 'blooded', 'animals', '.']

Labels: ['$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP']
New Tokens: ['$START', 'Tigers', 'are', 'blooded', 'animals', '.']

Labels: ['$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP' '$KEEP']
New Tokens: ['$START', 'Tigers', 'are', 'blooded', 'animals', '.']

Labels: 

## Split hyphen label

In [16]:
# Split "cold-blooded" into two tokens and fix agreement, with exploration enabled.
state = "$START Tiger is cold-blooded animal .".split()
ref_list = [
    "$START Tigers are cold blooded animals .".split(),
    # Disabled alternative reference: "$START Tiger is a cold blooded animal .".split(),
]
check(policy, state, ref_list, mask_generator, explore=True, verbose=False)

Reference: ['$START', 'Tigers', 'are', 'cold', 'blooded', 'animals', '.']
Source: ['$START', 'Tiger', 'is', 'cold-blooded', 'animal', '.']

Labels: ['$KEEP' '$KEEP' '$REPLACE_are' '$TRANSFORM_SPLIT_HYPHEN'
 '$TRANSFORM_AGREEMENT_PLURAL' '$KEEP']
New Tokens: ['$START', 'Tiger', 'are', 'cold', 'blooded', 'animals', '.']

Labels: ['$KEEP' '$TRANSFORM_AGREEMENT_PLURAL' '$KEEP' '$KEEP' '$KEEP' '$KEEP'
 '$KEEP']
New Tokens: ['$START', 'Tigers', 'are', 'cold', 'blooded', 'animals', '.']
OK!


# Visualize candidate labels

In [24]:
state = "$START Tigers is cold blooded animals .".split()
reference = "$START Tigers are cold-blooded animals .".split()
eps = 0.0  # probability of exploring (random candidate) instead of exploiting the policy

# Walk the tokens right-to-left and inspect the candidate labels of each one.
for tok_i in reversed(range(len(state))):
    visualize_candidates(policy, mask_generator, state, reference, tok_i=tok_i, eps=eps)
    print()

# Checking the labels of the token '.' at index '6'.
# Edit type: equal
# Candidate labels: ['$KEEP']

# Checking the labels of the token 'animals' at index '5'.
# Edit type: equal
# Candidate labels: ['$KEEP']

# Checking the labels of the token 'blooded' at index '4'.
# Edit type: mixed
# Candidate labels: ['$DELETE' '$REPLACE_are' '$APPEND_are' '$KEEP']
# Exploiting
# Selected label: $KEEP



Unnamed: 0,Labels,Delta_Lev,Label_Index,Probs
0,$KEEP,0,0,0.0006904251
1,$DELETE,1,1,0.0001141265
30,$REPLACE_are,2,30,2.572975e-09
91,$APPEND_are,1,91,3.446322e-06



# Checking the labels of the token 'cold' at index '3'.
# Edit type: mixed
# Candidate labels: ['$DELETE' '$REPLACE_are' '$MERGE_SPACE' '$APPEND_are' '$MERGE_HYPHEN'
 '$KEEP']
# Exploiting
# Selected label: $APPEND_are



Unnamed: 0,Labels,Delta_Lev,Label_Index,Probs
0,$KEEP,0,0,2.206966e-07
1,$DELETE,1,1,0.0006578872
30,$REPLACE_are,2,30,7.403545e-11
71,$MERGE_SPACE,1,71,1.104479e-10
91,$APPEND_are,1,91,0.002667864
2224,$MERGE_HYPHEN,3,2224,8.996078e-09



# Checking the labels of the token 'is' at index '2'.
# Edit type: mixed
# Candidate labels: ['$DELETE' '$REPLACE_are' '$MERGE_SPACE' '$APPEND_are' '$MERGE_HYPHEN'
 '$KEEP']
# Exploiting
# Selected label: $APPEND_are



Unnamed: 0,Labels,Delta_Lev,Label_Index,Probs
0,$KEEP,0,0,3.313005e-08
1,$DELETE,1,1,8.148677e-08
30,$REPLACE_are,2,30,3.651983e-07
71,$MERGE_SPACE,1,71,1.209371e-05
91,$APPEND_are,1,91,2.203617e-05
2224,$MERGE_HYPHEN,1,2224,4.483664e-09



# Checking the labels of the token 'Tigers' at index '1'.
# Edit type: equal
# Candidate labels: ['$APPEND_are' '$KEEP']
# Exploiting
# Selected label: $APPEND_are



Unnamed: 0,Labels,Delta_Lev,Label_Index,Probs
0,$KEEP,0,0,1.449747e-07
91,$APPEND_are,1,91,1.180831e-05



# Checking the labels of the token '$START' at index '0'.
# Edit type: equal
# Candidate labels: ['$KEEP']



In [17]:
import re

In [18]:
# Literal substitutions applied with `str.replace`: the two-character quote
# tokens `` and '' both become a plain double quote.
REPLACE_DICT = {
    '``': '"',
    "''": '"',
}

# Regex substitutions applied in insertion order with `re.sub`.
RE_REPLACE_DICT = {
    # Same quote rules as REPLACE_DICT, escaped so they always match literally.
    **{re.escape(old): new for old, new in REPLACE_DICT.items()},
    # Pad every hyphen that touches a non-space character with a space.
    # Lookarounds do not consume the neighbouring character, so each hyphen in a
    # run such as "--" is padded too (consuming patterns like r"(\S)(-)" would
    # turn "a--b" into the inconsistent "a - -b").
    r"(?<=\S)-": " -",
    r"-(?=\S)": "- ",
}

In [19]:
# Literal replacement: `` and '' both become a plain double quote.
text = "This is a sample text with '' and ``"
for old, new in REPLACE_DICT.items():
    text = text.replace(old, new)
text

'This is a sample text with " and "'

In [20]:
# Regex replacement: also pads hyphens with spaces, e.g. "sample-text" -> "sample - text".
text = "This is a sample-text with '' and ``"
for pattern, repl in RE_REPLACE_DICT.items():
    text = re.sub(pattern, repl, text)
text

'This is a sample - text with " and "'