## output structure

```
dict: {<data_id> : {
                    "spell_correction_txt": type -> str,
                    "embedding": type -> ndarray(768,),
                    "phrase_chunking_sentence": type -> lst(str),
                    "phrase_chunking_tag": type -> lst(str),
                    "verb_phrase_arr": type -> lst(str),
                    "noun_phrase_arr": type -> lst(str)
                    }
      }
```

## Sent emb model

In [1]:
# Config
# Dimension of the sentence embeddings stored in the output dicts; it is also
# used to size the pre-allocated arrays of polluted embeddings further down.
SEN_EMB_DIM = 768
# import packages
import numpy as np
import torch
import re
import pickle
import os
# `tqdm.tqdm_notebook` is deprecated (tqdm itself prints "Please use
# `tqdm.notebook.tqdm` instead"); import the notebook progress bar directly.
# The bound name stays `tqdm`, so every later cell keeps working unchanged.
from tqdm.notebook import tqdm
from scipy.spatial.distance import cosine
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer

# Sentence encoder shared by every embedding step in this notebook.
model = SentenceTransformer('sentence-transformers/gtr-t5-xl')


## Spell correction

In [2]:
import pkg_resources
from symspellpy import SymSpell
from symspellpy import Verbosity
from nltk.stem import PorterStemmer 
import string
from nltk.stem import WordNetLemmatizer
import nltk
# NLTK resources: 'wordnet' / 'omw-1.4' for the WordNet lemmatizer created
# below, 'words' for the English vocabulary list.
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('words')

# Set up sym_spell
# Suggestions are limited to an edit distance of 2, matching the
# max_edit_distance passed to sym_spell.lookup() inside correct().
sym_spell = SymSpell(max_dictionary_edit_distance=2, prefix_length=7)
# Both frequency dictionaries are resolved from the installed symspellpy package.
dictionary_path = pkg_resources.resource_filename(
    "symspellpy", "frequency_dictionary_en_82_765.txt"
)
bigram_path = pkg_resources.resource_filename(
    "symspellpy", "frequency_bigramdictionary_en_243_342.txt"
)
# term_index is the column of the term and count_index is the
# column of the term frequency
sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1)
sym_spell.load_bigram_dictionary(bigram_path, term_index=0, count_index=2)

# Stemmer, lemmatizer and lower-cased vocabulary used by correct() to decide
# whether two punctuation-separated fragments form one English word when joined.
ps = PorterStemmer()
wnl = WordNetLemmatizer()
english_vocab = set(w.lower() for w in nltk.corpus.words.words())
def correct(w):
    """Spell-correct a single whitespace-free token with SymSpell.

    Parameters
    ----------
    w : str
        The token to correct; it may carry punctuation.

    Returns
    -------
    str or list of str
        * ``w`` unchanged when it contains no alphabetic character or when
          SymSpell has no suggestion for it.
        * A list of pieces when punctuation splits ``w`` into several
          fragments. Adjacent fragments whose concatenation lemmatizes or
          stems to a known English word are merged; the others are corrected
          one by one. The separating punctuation is dropped.
        * Otherwise the closest SymSpell suggestion, wrapped in the leading
          and trailing punctuation of ``w``.
    """
    # Numbers and pure punctuation are passed through untouched.
    if not any(ch.isalpha() for ch in w):
        return w

    # Punctuation would confuse the dictionary lookup, so blank it out first.
    punct_to_space = str.maketrans(string.punctuation, ' ' * len(string.punctuation))
    stripped = w.translate(punct_to_space)
    fragments = stripped.split()
    n_fragments = len(fragments)

    if n_fragments > 1:
        pieces = []
        pos = 0
        while pos < n_fragments:
            # The final fragment has no right neighbour to merge with.
            if pos == n_fragments - 1:
                pieces.append(correct(fragments[pos]))
                break

            # Two neighbours that form an English word together stay together.
            joined = fragments[pos] + fragments[pos + 1]
            if wnl.lemmatize(word=joined) in english_vocab or ps.stem(joined) in english_vocab:
                pieces.append(joined)
                pos += 2
            else:
                # Otherwise each fragment is spell-checked on its own.
                pieces.append(correct(fragments[pos]))
                pos += 1
        return pieces

    suggestions = sym_spell.lookup(
        stripped.strip(),
        Verbosity.CLOSEST,
        max_edit_distance=2,
        transfer_casing=True,
    )
    if not suggestions:
        return w

    best = suggestions[0].term
    # Keep the capitalisation of the first character of the input token.
    if w[0].isupper():
        best = best[0].upper() + best[1:]

    # Length of the punctuation run at the front of the token.
    lead = 0
    while lead < len(w) and w[lead] in string.punctuation:
        lead += 1

    # Start index of the punctuation run at the back of the token.
    tail = len(w)
    while tail > 0 and w[tail - 1] in string.punctuation:
        tail -= 1

    return w[:lead] + best + w[tail:]

def spellcheck_keep_punctuation(sentence):
    """Expand contractions, isolate punctuation and spell-correct a sentence.

    Parameters
    ----------
    sentence : str
        Raw instruction text.

    Returns
    -------
    str
        The corrected tokens joined by single spaces, lower-cased. Every
        punctuation character becomes a token of its own.
    """
    # Applied in order: the two irregular contractions must be expanded before
    # the generic "n't" rule gets a chance to match them.
    contraction_rules = (
        (r"won\'t", "will not"),
        (r"can\'t", "can not"),
        (r"n\'t", " not"),
        (r"\'re", " are"),
        (r"\'s", " is"),
        (r"\'d", " would"),
        (r"\'ll", " will"),
        (r"\'t", " not"),
        (r"\'ve", " have"),
        (r"\'m", " am"),
    )
    for pattern, expansion in contraction_rules:
        sentence = re.sub(pattern, expansion, sentence)

    # Surround each punctuation mark with spaces so it splits off as a token.
    spaced_punctuation = str.maketrans(
        {mark: " {0} ".format(mark) for mark in string.punctuation}
    )
    tokens = sentence.translate(spaced_punctuation).split()

    pieces = []
    for token in tokens:
        fixed = correct(token)
        # correct() returns a list when a token was split into several words.
        if isinstance(fixed, list):
            pieces.extend(fixed)
        else:
            pieces.append(fixed)

    return ' '.join(pieces).lower()


[nltk_data] Downloading package wordnet to /home/sukai/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /home/sukai/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package words to /home/sukai/nltk_data...
[nltk_data]   Package words is already up-to-date!


## Sentence embedding

In [3]:
def convert_to_sentence_embedding(raw_sentence):
    """Spell-correct ``raw_sentence`` and encode it with the global ``model``.

    Parameters
    ----------
    raw_sentence : str
        Instruction text; it is passed through spellcheck_keep_punctuation()
        before encoding.

    Returns
    -------
    The value returned by ``model.encode`` for the corrected sentence
    (expected shape: ``(SEN_EMB_DIM,)``).
    """
    cleaned_sentence = spellcheck_keep_punctuation(raw_sentence)
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        return model.encode(cleaned_sentence)

## Load instruction pkl and process the text

In [4]:
# Raw instruction pickles. preprocess_instruction() unpickles each file and
# iterates it with .items(), so each is expected to hold a dict whose values
# are the instruction sentences.
training_instruction_path = "train_data/training_instruction_dict_raw.pkl"
test_pos_ins_path = "test_data/testing_positive_instruction_dict_raw.pkl"
test_neg_ins_path = "test_data/testing_negative_instruction_dict_raw.pkl"


## Phrase chunking model

In [4]:
from flair.data import Sentence
from flair.models import SequenceTagger
# Despite the variable name, this is flair's phrase-chunking tagger
# ("flair/chunk-english"): it labels spans with chunk tags such as NP, VP and
# PP, which is what convert_to_phrase_chunking() returns.
pos_model = SequenceTagger.load("flair/chunk-english")



2022-11-24 12:20:38,999 loading file /home/sukai/.flair/models/chunk-english/5b53097d6763734ee8ace8de92db67a1ee2528d5df9c6d20ec8e3e6f6470b423.d81b7fd7a38422f2dbf40f6449b1c63d5ae5b959863aa0c2c1ce9116902e8b22
2022-11-24 12:20:39,266 SequenceTagger predicts: Dictionary with 45 tags: <unk>, O, B-NP, E-NP, I-NP, S-PP, S-VP, S-SBAR, S-ADVP, S-NP, S-ADJP, B-VP, E-VP, B-PP, E-PP, I-VP, S-PRT, B-ADVP, E-ADVP, B-ADJP, E-ADJP, B-CONJP, I-CONJP, E-CONJP, I-ADJP, B-SBAR, E-SBAR, S-INTJ, I-ADVP, I-PP, B-UCP, I-UCP, E-UCP, S-LST, B-PRT, I-PRT, E-PRT, S-CONJP, B-INTJ, E-INTJ, I-INTJ, B-LST, E-LST, <START>, <STOP>


In [5]:
def convert_to_phrase_chunking(raw_sentence):
    """Spell-correct a sentence and split it into tagged phrase chunks.

    Parameters
    ----------
    raw_sentence : str
        Instruction text; it is passed through spellcheck_keep_punctuation()
        before tagging.

    Returns
    -------
    tuple(list of str, list of str)
        The text of every labelled chunk and, in the same order, its chunk
        tag (e.g. ``'VP'``, ``'NP'``, ``'PP'``).
    """
    tagged_sentence = Sentence(spellcheck_keep_punctuation(raw_sentence))
    # predict() attaches the chunk labels to the Sentence object in place.
    with torch.no_grad():
        pos_model.predict(tagged_sentence)

    chunk_texts = []
    chunk_tags = []
    for label in tagged_sentence.get_labels():
        chunk_texts.append(label.data_point.text)
        chunk_tags.append(label.value)
    return (chunk_texts, chunk_tags)

In [7]:
s = convert_to_phrase_chunking("go under the two skulls and down the ladder .")

In [8]:
s = convert_to_phrase_chunking("wait for the steps on the right to appear , then climb and into the next area .")

In [9]:
s

(['wait',
  'for',
  'the steps',
  'on',
  'the right',
  'to appear',
  'climb',
  'and',
  'into',
  'the next area'],
 ['VP', 'PP', 'NP', 'PP', 'NP', 'VP', 'VP', 'PP', 'PP', 'NP'])

In [6]:
def preprocess_instruction(text_path):
    """Build the per-instruction feature dict from a pickled instruction dict.

    Parameters
    ----------
    text_path : str
        Path to a pickle holding ``{data_id: raw_sentence}``.
        NOTE: pickle.load can execute arbitrary code; only use trusted files.

    Returns
    -------
    dict
        ``{data_id: {"spell_correction_txt", "embedding",
        "phrase_chunking_sentence", "phrase_chunking_tag"}}``.
    """
    with open(text_path, 'rb') as f:
        instruction_dict = pickle.load(f)

    new_dict = {}
    for key, val in tqdm(instruction_dict.items()):
        chunk_texts, chunk_tags = convert_to_phrase_chunking(val)
        corrected_text = spellcheck_keep_punctuation(val)
        sentence_embedding = convert_to_sentence_embedding(val)
        new_dict[key] = {
            "spell_correction_txt": corrected_text,
            "embedding": sentence_embedding,
            "phrase_chunking_sentence": chunk_texts,
            "phrase_chunking_tag": chunk_tags,
        }
    return new_dict

In [50]:
# Run the full preprocessing (spell correction, sentence embedding, phrase
# chunking) over the training set and both test sets.
training_instruction_sentence_emb_dict = preprocess_instruction(training_instruction_path)
test_pos_ins_sentence_emb_dict = preprocess_instruction(test_pos_ins_path)
test_neg_ins_sentence_emb_dict = preprocess_instruction(test_neg_ins_path)

100%|███████████████████████████████████████| 5655/5655 [02:09<00:00, 43.77it/s]
100%|█████████████████████████████████████████| 608/608 [00:13<00:00, 43.97it/s]
100%|█████████████████████████████████████████| 607/607 [00:13<00:00, 44.33it/s]


## save pkl

In [51]:
# Persist the three preprocessed dicts; the training pickle is re-loaded
# further down for the verb/noun phrase analysis.
with open('train_data/training_instruction_sen_emb_phrase_chunk.pkl', 'wb') as f:
    pickle.dump(training_instruction_sentence_emb_dict, f)
with open('test_data/test_pos_ins_sen_emb_phrase_chunk.pkl', 'wb') as f:
    pickle.dump(test_pos_ins_sentence_emb_dict, f)
with open('test_data/test_neg_ins_sen_emb_phrase_chunk.pkl', 'wb') as f:
    pickle.dump(test_neg_ins_sentence_emb_dict, f)

In [7]:
def replacewords(original_sent, old_phrase, new_phrase):
    """Replace the first whole-word occurrence of ``old_phrase`` in a sentence.

    An occurrence only counts when it is delimited by whitespace or by the
    start/end of the sentence, so ``'don'`` never matches inside ``'abandon'``
    and ``'hope'`` never matches inside ``'hopeful'``.

    Parameters
    ----------
    original_sent : str
        Sentence to edit (tokens separated by whitespace).
    old_phrase : str
        Word or multi-word phrase to look for.
    new_phrase : str
        Text inserted in place of ``old_phrase``.

    Returns
    -------
    str or None
        The new sentence, or ``None`` when ``old_phrase`` is empty or does not
        occur as a whole phrase. The surrounding spacing is preserved, so no
        leading or trailing space is added when the phrase sits at a sentence
        boundary.
    """
    if not old_phrase:
        return None

    # (?<!\S) / (?!\S): no non-space character directly before / after the
    # phrase. re.escape keeps punctuation inside the phrase literal.
    whole_phrase = re.compile(r"(?<!\S)" + re.escape(old_phrase) + r"(?!\S)")
    match = whole_phrase.search(original_sent)
    if match is None:
        return None

    # Slice instead of re.sub so backslashes in new_phrase are never
    # interpreted as replacement escapes.
    return original_sent[:match.start()] + new_phrase + original_sent[match.end():]

    
    

## Analyse Actions

In [8]:
# Chunk-tag sequences treated as action phrases. find_pattern() tries the
# patterns in list order, so a verb followed by a preposition ("VP PP") is
# captured as one phrase before the bare "VP" pattern runs.
ACTION_PATTERNS = [['VP', 'PP'], ['VP']]

In [9]:
def loop_pattern(pattern, left_tag_arr, left_sentence, left_global_indices_arr):
    """Find the first occurrence of a tag pattern and cut it out.

    Parameters
    ----------
    pattern : list of str
        Consecutive chunk tags to look for, e.g. ``['VP', 'PP']``.
    left_tag_arr : list of str
        Chunk tags not yet consumed.
    left_sentence : list of str
        Chunk texts aligned with ``left_tag_arr``.
    left_global_indices_arr : list of int
        Index bookkeeping list; the matched positions are removed from it.

    Returns
    -------
    tuple or None
        ``(phrase, pattern, tags, sentence, indices)`` for the first match,
        where ``phrase`` holds the matched chunk texts and the matched span is
        replaced by a single ``"PH"`` placeholder in ``tags`` and ``sentence``.
        ``None`` when the pattern does not occur.
    """
    width = len(pattern)
    n_tags = len(left_tag_arr)
    if width > n_tags:
        return None

    for begin in range(n_tags - width + 1):
        end = begin + width
        window = left_tag_arr[begin:end]
        # Skip this offset as soon as one tag disagrees with the pattern.
        if any(wanted != seen for wanted, seen in zip(pattern, window)):
            continue

        phrase = left_sentence[begin:end]
        # The placeholder stops later patterns from matching across the gap
        # left by a consumed phrase.
        remaining_tags = left_tag_arr[:begin] + ["PH"] + left_tag_arr[end:]
        remaining_sentence = left_sentence[:begin] + ["PH"] + left_sentence[end:]
        remaining_indices = left_global_indices_arr[:begin] + left_global_indices_arr[end:]
        return phrase, pattern, remaining_tags, remaining_sentence, remaining_indices

    # No offset matched; the caller moves on to the next pattern.
    return None

def find_pattern(pattern_arr, original_sentence, original_tag_arr): # type: pattern_arr, str, tag_arr
    """Extract every phrase whose chunk tags match one of the given patterns.

    Patterns are tried in the order given. Each pattern is matched repeatedly
    until it no longer occurs, and chunks consumed by one match are not
    available to later matches.

    Parameters
    ----------
    pattern_arr : list of list of str
        Tag patterns, most specific first.
    original_sentence : list of str
        Chunk texts.
    original_tag_arr : list of str
        Chunk tags aligned with ``original_sentence``.

    Returns
    -------
    tuple(list, list)
        The matched phrases (each a list of chunk texts) and, in the same
        order, the pattern that produced each of them.
    """
    tags = original_tag_arr
    chunks = original_sentence
    indices = list(range(len(chunks)))
    found_phrases = []
    found_patterns = []

    for pattern in pattern_arr:
        hit = loop_pattern(pattern, tags, chunks, indices)
        # Keep consuming the same pattern until it stops matching.
        while hit is not None:
            phrase, matched_pattern, tags, chunks, indices = hit
            found_phrases.append(phrase)
            found_patterns.append(matched_pattern)
            hit = loop_pattern(pattern, tags, chunks, indices)

    return found_phrases, found_patterns

In [75]:
# check all the verbs 
with open('train_data/training_instruction_sen_emb_phrase_chunk.pkl', 'rb') as f:
    data = pickle.load(f)

In [76]:
print(data[0]['spell_correction_txt'])
print(data[0]['phrase_chunking_sentence'])
print(data[0]['phrase_chunking_tag'])

# we should catch the VERB ADP / VERB / VERB ADV CCONJ ADV / VERB ADV / VERB CCONJ VERB phrases

climb up the ladder until you reach the top of the purple room
['climb', 'up', 'the ladder', 'until', 'you', 'reach', 'the top', 'of', 'the purple room']
['VP', 'PP', 'NP', 'SBAR', 'NP', 'VP', 'NP', 'PP', 'NP']


In [77]:
find_pattern(ACTION_PATTERNS,
             data[0]['phrase_chunking_sentence'],
             data[0]['phrase_chunking_tag'])

([['climb', 'up'], ['reach']], [['VP', 'PP'], ['VP']])

In [81]:
# Collect every action (verb) phrase found in the training data. Each entry of
# `data` also gets its own list of phrases under 'verb_phrase_arr'.
verb_phrase_set = set()

for key, val in data.items():
    phrase_arr, _ = find_pattern(
        ACTION_PATTERNS,
        val['phrase_chunking_sentence'],
        val['phrase_chunking_tag'],
    )
    # A phrase is a list of chunk texts; store it as one space-joined string.
    val['verb_phrase_arr'] = [' '.join(chunk_texts) for chunk_texts in phrase_arr]
    verb_phrase_set.update(val['verb_phrase_arr'])

In [82]:
len(verb_phrase_set)

1095

In [14]:
verb_phrase_set

{'to meet',
 'stay for',
 'have disappear',
 'left toward',
 'are above',
 'continue left',
 'move on',
 'move to advance to',
 'to turn',
 'are go to try',
 'are come and using',
 'gets disappears',
 'slightly go to',
 'wait to see',
 'comes go',
 'move left to jump',
 'wait behind',
 'make move on',
 'stay let crab',
 'die on',
 'gets disappear',
 'go',
 'move left jumping',
 'let over',
 'run through',
 'is left of',
 'moving to',
 'reaches',
 'fall jump up',
 'move to left and jump',
 'go climb to',
 'go above',
 'jump up',
 'touched',
 'walks',
 'spawn',
 'was climb',
 'surfaces',
 'run right while',
 'stand towards',
 'pause to let',
 'wait and going',
 'to be',
 'move from',
 'to avoid being hit by',
 'end on',
 'switch between',
 'hop on',
 'stay fixed to',
 'to give',
 'to face',
 'to left',
 'move left take',
 'again take',
 'wait to',
 'go to climb',
 'are appearing and disappearing',
 'walk by',
 'climb above',
 'do not get',
 'do not touch',
 'climb top',
 'move left to go

In [19]:
'stay' in verb_phrase_set

True

In [20]:
'move' in verb_phrase_set

True

In [17]:
# Store the action phrases 
with open('train_data/action_phrase_chunk_set.pkl', 'wb') as f:
    pickle.dump(verb_phrase_set, f)

## Analyse Noun phrase

In [10]:
# Chunk-tag sequences treated as noun phrases. find_pattern() tries them in
# list order, so noun chunks with an attached preposition are captured before
# the bare "NP" pattern runs.
NOUN_PATTERNS = [['PP', 'NP'], ['NP', 'PP'], ['NP']]

In [22]:
print(data[2880]['spell_correction_txt'])
print(data[2880]['phrase_chunking_sentence'])
print(data[2880]['phrase_chunking_tag'])

wait for the steps on the right to appear , then climb and into the next area .
['wait', 'for', 'the steps', 'on', 'the right', 'to appear', 'climb', 'and', 'into', 'the next area']
['VP', 'PP', 'NP', 'PP', 'NP', 'VP', 'VP', 'PP', 'PP', 'NP']


In [50]:
find_pattern(ACTION_PATTERNS,
             data[2880]['phrase_chunking_sentence'],
             data[2880]['phrase_chunking_tag'])

([['wait', 'for'], ['climb', 'and'], ['to appear']],
 [['VP', 'PP'], ['VP', 'PP'], ['VP']])

In [80]:
# Collect every noun phrase found in the training data. Each entry of `data`
# also gets its own list of phrases under 'noun_phrase_arr'.
noun_phrase_set = set()

for key, val in data.items():
    phrase_arr, _ = find_pattern(
        NOUN_PATTERNS,
        val['phrase_chunking_sentence'],
        val['phrase_chunking_tag'],
    )
    # A phrase is a list of chunk texts; store it as one space-joined string.
    val['noun_phrase_arr'] = [' '.join(chunk_texts) for chunk_texts in phrase_arr]
    noun_phrase_set.update(val['noun_phrase_arr'])

In [83]:
len(noun_phrase_set)

1556

In [23]:
noun_phrase_set

{'screen',
 'a left',
 'the rope to the left',
 'little jump',
 'one screen to the next',
 'fire',
 'skull rolling',
 'hope over snakes',
 'the ladder over the checked bridge',
 'the skulls in the next room',
 'one step down the ladder',
 'ladder upside',
 'the another location',
 'that object',
 '6 times',
 'this person ladder jump',
 'screens',
 'jump and pickup ladder',
 'ladder to original point',
 'the first blue lines on the background',
 'right side and death',
 'the next room on the left',
 'the ladder on the bottom',
 'field',
 'the creature on the left',
 '5 steps to the right',
 'the next level with the ladder',
 'two snakes to the left',
 'just little move to right side',
 'ledge before bridge',
 'next zone',
 'the half way point on the ladder',
 'the skelton',
 'all the way up the ladder',
 'the gold hammer',
 '5 steps to the left',
 'the leftwards',
 'the dangerous rays',
 'the right toward the next wall',
 'hesitation',
 'the skeleton head',
 'the moving platform',
 '3 s

In [26]:
# Store the noun phrases 
with open('train_data/noun_phrase_chunk_set.pkl', 'wb') as f:
    pickle.dump(noun_phrase_set, f)

## manual spell correction and sent emb reencoding
### Some words are still wrong after automatic spell correction; correct them manually first, before replacing phrases
1. 'claim' -> 'climb'
2. 'club' -> 'climb'
3. 'done' -> 'do not'
4. 'clip' -> 'climb'
5. 'clumping' -> 'climbing'
6. 'clime' -> 'climb'
7. 'hope' -> 'hop'
8. 'toe' -> 'to'
9. 'endow' -> 'end'
10. 'latter' -> 'ladder'
11. 'bride' -> 'bridge'
12. 'leaser' -> 'laser'
13. 'later' -> 'ladder'
14. 'leaf' -> 'left'
15. 'snack' -> 'snake'
16. 'rob' -> 'rope'
17. 'don' -> 'do not'
18. 'article' -> 'verticle'
19. 'life edge' -> 'left edge'
20. 'done' -> 'do not'
21. 'ans' -> 'and'
22. 'skelton' -> 'skeleton'
23. 'skilled' -> 'skull'

```
dict: {<data_id> : {
                    "spell_correction_txt": type->str,
                    "embedding": type-> ndarray(768,),
                    "phrase_chunking_sentence": type -> lst(str)
                    "phrase_chunking_tag": type-> lst(str)
                    }
      }
```

In [11]:
# Manual fixes for words the automatic spell checker got wrong in this
# data set: {wrong word or phrase: intended word or phrase}.
# manual_correction() applies them in insertion order.
# The original literal listed 'done' twice with the same value; the duplicate
# key is removed here (contents and iteration order are unchanged).
# NOTE(review): 'verticle' looks like a misspelling of 'vertical' - confirm
# before changing, since the value is written into the corrected sentences.
word_correction_dict = {
    'claim': 'climb',
    'club': 'climb',
    'done': 'do not',
    'clip': 'climb',
    'clumping': 'climbing',
    'clime': 'climb',
    'hope': 'hop',
    'toe': 'to',
    'endow': 'end',
    'latter': 'ladder',
    'bride': 'bridge',
    'leaser': 'laser',
    'later': 'ladder',
    'leaf': 'left',
    'snack': 'snake',
    'rob': 'rope',
    'don': 'do not',
    'article': 'verticle',
    'life edge': 'left edge',
    'ans': 'and',
    'skelton': 'skeleton',
    'skilled': 'skull',
}
def manual_correction(inputdata):
    """Apply the manual word fixes and refresh the derived fields in place.

    For every entry, each wrong word in ``word_correction_dict`` is replaced
    in ``'spell_correction_txt'``, then the sentence embedding and the phrase
    chunking are recomputed from the corrected text.

    Parameters
    ----------
    inputdata : dict
        ``{data_id: entry}`` where each entry has at least the key
        ``'spell_correction_txt'``. The entries are modified in place.

    Returns
    -------
    None
    """
    for key, val in tqdm(inputdata.items()):
        output_sent = val['spell_correction_txt']
        for wrong_w, cor_w in word_correction_dict.items():
            # replacewords() substitutes one match per call and returns None
            # when the wrong word is absent, so the sentence is kept then.
            replaced = replacewords(output_sent, wrong_w, cor_w)
            if replaced is not None:
                output_sent = replaced

        val['spell_correction_txt'] = output_sent
        # Bug fix: this line used to read `val["embedding"]: ...`, which is an
        # annotation rather than an assignment, so the embedding was never
        # re-encoded after the manual correction.
        val["embedding"] = convert_to_sentence_embedding(output_sent)
        phrase_chunking_str, phrase_chunking_tag = convert_to_phrase_chunking(output_sent)
        val["phrase_chunking_sentence"] = phrase_chunking_str
        val["phrase_chunking_tag"] = phrase_chunking_tag


In [79]:
manual_correction(data)

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for key, val in tqdm(inputdata.items()):


  0%|          | 0/5655 [00:00<?, ?it/s]

## Create and store negative examples 

In [12]:
import random

In [13]:
# Strategy: for each sentence, randomly pick one of its correct action / noun
# phrases and swap it with another action phrase (1/1105) or noun phrase
# (1/1578).
# Furthermore, we create examples where both the action and the noun are
# polluted (0.1M).
# The resulting structure is shown in the schema below.
# NOTE(review): this path ('./training_data/...') differs from the
# 'train_data/' directory used earlier, and none of the visible cells writes
# this file - confirm where it is produced.
with open('./training_data/train_text_dict.pkl', 'rb') as f:
    data = pickle.load(f)

In [14]:
print(data[0]["spell_correction_txt"])
print(data[0]["phrase_chunking_sentence"])
print(data[0]["phrase_chunking_tag"])
print(data[0]["verb_phrase_arr"])
print(data[0]["noun_phrase_arr"])

climb up the ladder until you reach the top of the purple room
['climb', 'up', 'the ladder', 'until', 'you', 'reach', 'the top', 'of', 'the purple room']
['VP', 'PP', 'NP', 'SBAR', 'NP', 'VP', 'NP', 'PP', 'NP']
['climb up', 'reach']
['up the ladder', 'of the purple room', 'you', 'the top']


In [15]:
data[0]["spell_correction_txt"].index(data[0]["verb_phrase_arr"][1])

30

```
dict: {<data_id> : {
                    "spell_correction_txt": type -> str,
                    "embedding": type -> ndarray(768,),
                    "phrase_chunking_sentence": type -> lst(str),
                    "phrase_chunking_tag": type -> lst(str),
                    "verb_phrase_arr": type -> lst(str),
                    "noun_phrase_arr": type -> lst(str),
                    "action_polluted_embedding": type -> lst(ndarray(768,)),
                    "noun_polluted_embedding": type -> lst(ndarray(768,)),
                    "both_polluted_embedding": type -> lst(ndarray(768,))
                    }
      }
```

In [86]:
# test
replacewords(' '.join(data[0]["phrase_chunking_sentence"]), data[0]["verb_phrase_arr"][1], "collide with")

'climb up the ladder until you collide with the top of the purple room'

In [87]:
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: T5EncoderModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Dense({'in_features': 1024, 'out_features': 768, 'bias': False, 'activation_function': 'torch.nn.modules.linear.Identity'})
  (3): Normalize()
)

In [16]:
# get hard negative dict
# Each dict is queried with a phrase via .get() and the result is iterated as
# a collection of replacement phrases (e.g. 'climb up' ->
# ['climb down', 'jump left', 'jump right']); it may be an empty list.
# NOTE(review): none of the visible cells creates these two pickles - confirm
# where they come from. pickle.load must only be used on trusted files.
with open('action_phrase_hard_negative_dict.pkl', 'rb') as f:
    action_phrase_hard_negative_dict = pickle.load(f)
    
with open('noun_phrase_hard_negative_dict.pkl', 'rb') as f:
    noun_phrase_hard_negative_dict = pickle.load(f)

In [17]:
def create_manual_polluted_embeddings(original_text, the_phrase_arr, hard_neg_dict, embedding_model):
    """Encode hard-negative variants of a sentence.

    Every phrase of the sentence that has hard negatives in ``hard_neg_dict``
    is swapped, one hard negative at a time, and each resulting sentence is
    encoded.

    Parameters
    ----------
    original_text : str
        The sentence to pollute.
    the_phrase_arr : list of str
        Phrases of the sentence that may be swapped.
    hard_neg_dict : dict
        ``{phrase: list of hard-negative phrases}``; missing or empty entries
        are skipped.
    embedding_model
        Object with an ``encode(sentence)`` method.

    Returns
    -------
    list or None
        One embedding per successful replacement, or ``None`` when there is
        nothing to replace or no replacement succeeded.
    """
    if len(the_phrase_arr) == 0:
        return None

    # Lazily walk every (phrase, hard negative) pair; phrases without hard
    # negatives contribute nothing.
    polluted_sentences = (
        replacewords(original_text, phrase, hard_negative)
        for phrase in the_phrase_arr
        for hard_negative in (hard_neg_dict.get(phrase) or ())
    )
    # replacewords() yields None when the phrase cannot be located.
    embeddings = [
        embedding_model.encode(sentence)
        for sentence in polluted_sentences
        if sentence is not None
    ]
    return embeddings or None
    


# Number of randomly polluted sentences generated per instruction.
THE_SIZE = 500

def create_polluted_embeddings(original_text, the_phrase_arr, the_phrase_set, embedding_model):
    """Encode randomly polluted variants of a sentence.

    For each of the ``THE_SIZE`` output rows, one phrase of the sentence is
    swapped with a phrase drawn at random from ``the_phrase_set`` and the
    resulting sentence is encoded.

    Parameters
    ----------
    original_text : str
        The sentence to pollute.
    the_phrase_arr : list of str
        Phrases of the sentence that may be swapped.
    the_phrase_set : set of str
        Pool of replacement phrases.
    embedding_model
        Object with an ``encode(sentence)`` method.

    Returns
    -------
    numpy.ndarray or None
        Array of shape ``(THE_SIZE, SEN_EMB_DIM)``, or ``None`` when
        ``the_phrase_arr`` is empty. A row stays all-zero when the drawn
        replacement equals the original phrase or when the phrase cannot be
        located in the sentence.
    """
    if len(the_phrase_arr) == 0:
        return None
    output_lst = np.zeros((THE_SIZE, SEN_EMB_DIM), dtype=np.float32)
    # Hoisted out of the loop: random.choice needs a sequence, and rebuilding
    # the tuple from the set on every draw is wasted work.
    candidate_phrases = tuple(the_phrase_set)
    for ind in tqdm(range(THE_SIZE), leave=False):
        # choose one original phrase and one replacement
        old_phrase = random.choice(the_phrase_arr)
        phrase = random.choice(candidate_phrases)
        if phrase == old_phrase:
            continue
        new_sent = replacewords(original_text, old_phrase, phrase)
        if new_sent is None:
            # replacewords() could not find old_phrase; skip the row instead
            # of handing None to the encoder.
            continue
        output_lst[ind] = embedding_model.encode(new_sent)

    return output_lst
        

def create_both_polluted_embeddings(original_text, action_phrase_arr, noun_phrase_arr,
                                    verb_phrase_set, noun_phrase_set, embedding_model):
    """Encode variants of a sentence with both an action and a noun swapped.

    For each of the ``THE_SIZE // 5`` output rows, one action phrase and one
    noun phrase of the sentence (when available) are swapped with phrases
    drawn at random from the respective pools.

    Parameters
    ----------
    original_text : str
        The sentence to pollute.
    action_phrase_arr, noun_phrase_arr : list of str
        Action / noun phrases of the sentence that may be swapped.
    verb_phrase_set, noun_phrase_set : set of str
        Pools of replacement phrases.
    embedding_model
        Object with an ``encode(sentence)`` method.

    Returns
    -------
    numpy.ndarray or None
        Array of shape ``(THE_SIZE // 5, SEN_EMB_DIM)``, or ``None`` when the
        sentence has neither action nor noun phrases.
    """
    if len(action_phrase_arr) == 0 and len(noun_phrase_arr) == 0:
        return None

    n_samples = THE_SIZE // 5
    output_lst = np.zeros((n_samples, SEN_EMB_DIM), dtype=np.float32)
    # Hoisted out of the loop: rebuilding the tuples on every draw is wasted work.
    verb_candidates = tuple(verb_phrase_set)
    noun_candidates = tuple(noun_phrase_set)

    for ind in tqdm(range(n_samples), leave=False):
        old_verb_phrase = None
        old_noun_phrase = None
        if len(action_phrase_arr) > 0:
            old_verb_phrase = random.choice(action_phrase_arr)
        if len(noun_phrase_arr) > 0:
            old_noun_phrase = random.choice(noun_phrase_arr)
        new_verb_phrase = random.choice(verb_candidates)
        new_noun_phrase = random.choice(noun_candidates)

        new_sent = original_text

        # replacewords() returns None when the phrase cannot be located. The
        # sentence is then kept as it is, so the membership test below and the
        # encoder never receive None.
        # NOTE(review): if neither replacement succeeds, the unmodified
        # sentence is encoded as a "polluted" example - confirm that is wanted.
        if old_verb_phrase is not None:
            replaced = replacewords(new_sent, old_verb_phrase, new_verb_phrase)
            if replaced is not None:
                new_sent = replaced
        if old_noun_phrase is not None and old_noun_phrase in new_sent:
            replaced = replacewords(new_sent, old_noun_phrase, new_noun_phrase)
            if replaced is not None:
                new_sent = replaced

        output_lst[ind] = embedding_model.encode(new_sent)
    return output_lst



In [18]:
data[0]["spell_correction_txt"]

'climb up the ladder until you reach the top of the purple room'

In [19]:
data[0]["verb_phrase_arr"]

['climb up', 'reach']

In [20]:
action_phrase_hard_negative_dict.get('reach')

[]

In [21]:
action_phrase_hard_negative_dict.get('climb up')

['climb down', 'jump left', 'jump right']

In [22]:
# test
test_list = create_manual_polluted_embeddings(data[0]["spell_correction_txt"],
                                       data[0]["verb_phrase_arr"],
                                       action_phrase_hard_negative_dict,
                                       model) # type: ndarray

In [23]:
len(test_list)

3

In [24]:
# add to the data dict
# For every instruction, store the embeddings of its hard-negative variants:
# one list for swapped action phrases and one for swapped noun phrases.
# create_manual_polluted_embeddings() returns None when nothing could be
# replaced, so either value may be None.
# The commented-out counter lines below are a manual way to resume from a
# given entry.
# target_ind = 2359
# count = 0
for key, val in tqdm(data.items()):
#     count += 1 
#     if count < target_ind:
#         continue 
    val['action_polluted_embedding'] = create_manual_polluted_embeddings(val["spell_correction_txt"],
                                                                 val["verb_phrase_arr"],
                                                                 action_phrase_hard_negative_dict,
                                                                 model)
    val['noun_polluted_embedding'] = create_manual_polluted_embeddings(val["spell_correction_txt"],
                                                                 val["noun_phrase_arr"],
                                                                 noun_phrase_hard_negative_dict,
                                                                 model)
    
# Disabled: examples with both the action and the noun swapped at random.
#     val['both_polluted_embedding'] = create_both_polluted_embeddings(" ".join(val["phrase_chunking_sentence"]),
#                                                                      val["verb_phrase_arr"],
#                                                                      val["noun_phrase_arr"],
#                                                                      verb_phrase_set,
#                                                                      noun_phrase_set,
#                                                                      model
#                                                                     )
    

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for key, val in tqdm(data.items()):


  0%|          | 0/6870 [00:00<?, ?it/s]

## Save hard negative examples as pkl

In [25]:
# Persist `data`, which now also carries 'action_polluted_embedding' and
# 'noun_polluted_embedding' for every entry.
with open('training_data/train_text_dict_hard_negatives.pkl', 'wb') as f:
    pickle.dump(data, f)