In [1]:
from pathlib import Path

import os

# Walk up from the notebook's directory to the repository root ('ambient') so that
# repo-relative paths used below (e.g. 'annotation/AmbiEnt/...') resolve.
# The parent-equals-self check stops at the filesystem root: there `cd ..` is a
# no-op and Path.cwd().name is '', so an unguarded loop would spin forever.
while Path.cwd().name != 'ambient':
    parent = Path.cwd().parent
    if parent == Path.cwd():
        raise FileNotFoundError("no 'ambient' directory found above the current working directory")
    os.chdir(parent)
print(Path.cwd())

/mmfs1/gscratch/xlab/alisaliu/ambient


In [2]:
import os
import pandas as pd
import numpy as np
import pyinflect
import spacy
from generation.gpt3_generation import request
from evaluation.distractors import create_distractor
from evaluation.conceptnet_utils import word_to_term, term_to_word, get_nodes
from utils.utils import flatten_list_of_lists
import requests
import random
from tqdm import tqdm
import nltk
# WordNet supplies the candidate verb list below; quiet=True suppresses the
# "[nltk_data] ... already up-to-date!" log lines on every run.
nltk.download('wordnet', quiet=True)
from nltk.corpus import wordnet

[nltk_data] Downloading package wordnet to
[nltk_data]     /mmfs1/home/alisaliu/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /mmfs1/home/alisaliu/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [3]:
# Load spaCy's large English pipeline; `nlp` is used for part-of-speech tagging in later cells.
nlp = spacy.load("en_core_web_lg")

In [None]:
# Inspect spaCy's coarse (pos_) and fine-grained (tag_) part-of-speech labels on a sample sentence.
text = 'The Department of Education is responsible for Title IX.'
doc_dep = nlp(text)
for tok in doc_dep:
    print(tok.text, tok.pos_, tok.tag_)

In [None]:
test_df = pd.read_json('annotation/AmbiEnt/validated_examples.jsonl', lines=True)
# XOR keeps only examples where exactly one of premise / hypothesis is ambiguous.
test_df = test_df.loc[test_df['premise_ambiguous'] ^ test_df['hypothesis_ambiguous']]

# Map each ambiguous sentence to a word-replacement distractor of it.
perturbations = {}
for _, example in tqdm(test_df.iterrows(), total=len(test_df)):
    ambiguous_sentence = example['premise'] if example['premise_ambiguous'] else example['hypothesis']
    perturbations[ambiguous_sentence] = create_distractor(ambiguous_sentence, method='replace_word')

In [27]:
# Exploration: look for a ConceptNet term that shares an IsA category with `replaced_word`,
# as a candidate replacement word.
replaced_word = 'jobs'
replaced_term = word_to_term(replaced_word)
print(replaced_term)

# End nodes of IsA edges on the term (presumably its parent categories), with edge weights.
category_terms = get_nodes(replaced_term, relation='IsA', node_type='end')

# Start nodes of IsA edges on each category, scored as edge weight * category weight,
# excluding the word itself.
# NOTE(review): dict.update means a term reachable from several categories keeps the score
# from the last category visited, not the highest one — confirm that is intended.
all_related_terms = {}
for category_term, weight in category_terms.items():
    related_terms = get_nodes(category_term, relation=['IsA'], node_type='start')
    all_related_terms.update({k: v * weight for k, v in related_terms.items() if term_to_word(k) != replaced_word})

# Always bind `new_term`: previously it was only assigned when a candidate existed, so the
# bare expression below raised NameError (or showed a stale value from an earlier run).
new_term = max(all_related_terms, key=all_related_terms.get) if all_related_terms else None
new_term

/c/en/media


'/c/en/consumer_durables_apparel'

In [26]:
def get_nodes(term, relation, node_type=None, limit=1000, timeout=30):
    """
    Query the ConceptNet API for `term` and collect neighbouring nodes joined by `relation`.

    NOTE: this redefines the `get_nodes` imported from evaluation.conceptnet_utils;
    later calls use whichever definition ran last.

    term: ConceptNet URI path, e.g. "/c/en/ask"
    relation: a relation label (e.g. 'IsA') or a list of labels
    node_type: 'start', 'end', a list of these, or None/empty for both
    limit: maximum number of edges requested from the API
    timeout: seconds to wait for the HTTP response before raising

    Returns a dict {node_term: edge_weight} for nodes other than `term` itself. If the same
    node appears on several matching edges, the weight of the last such edge wins.
    Raises requests.exceptions.HTTPError on a non-2xx response.
    """
    relations = [relation] if isinstance(relation, str) else relation
    if isinstance(node_type, str):
        node_types = [node_type]
    elif not node_type:
        node_types = ['start', 'end']
    else:
        node_types = node_type

    response = requests.get(f'http://api.conceptnet.io{term}', params={'limit': limit}, timeout=timeout)
    # Fail loudly here rather than with a confusing KeyError on 'edges' in an error payload.
    response.raise_for_status()
    obj = response.json()

    nodes = {}
    for e in obj['edges']:
        if e['rel']['label'] not in relations:
            continue
        for n in node_types:
            node_term = e[n]['term']
            # Skip the queried term itself; multi-word terms (containing '_') are kept.
            if node_term != term:
                nodes[node_term] = e['weight']
    return nodes

In [None]:
from functools import lru_cache

# `nlp` (spaCy en_core_web_lg) is already loaded in the setup cell; loading it a second
# time here only cost time and memory.


@lru_cache(maxsize=1)
def _wordnet_verbs():
    """Sorted tuple of single-word, lowercase WordNet verb lemmas, computed once and cached.

    Mirrors the ALL_VERBS cell, but is defined here so replace_verb works even when this
    cell runs before that one.
    """
    return tuple(sorted({
        lemma.name()
        for synset in wordnet.all_synsets('v')
        for lemma in synset.lemmas()
        if '_' not in lemma.name() and lemma.name().islower()
    }))


def replace_verb(sentence, verbs=None):
    """
    Replace the first token spaCy tags as VERB in `sentence` with a randomly chosen verb.

    sentence: input text
    verbs: candidate replacement verbs; defaults to single-word lowercase WordNet verbs
           (the previous body referenced an undefined global `verbs`).

    Returns the sentence with that one token swapped, or unchanged if no VERB is found.
    """
    for token in nlp(sentence):
        if token.pos_ == 'VERB':
            new_verb = random.choice(verbs if verbs is not None else _wordnet_verbs())
            # Splice by character offset: str.replace(..., 1) would hit the first substring
            # match, e.g. the "is" inside "This" rather than the verb "is".
            return sentence[:token.idx] + new_verb + sentence[token.idx + len(token.text):]
    return sentence

In [None]:
# Quick smoke test: swap the first VERB of a short sentence for a random verb.
sentence = "I love eating pizza"
replace_verb(sentence=sentence)

In [None]:
# Candidate replacement verbs: single-word (no '_'), all-lowercase WordNet verb lemmas,
# deduplicated.
# Sorted rather than list(set(...)): set iteration order for strings changes between
# interpreter runs (hash randomisation), so random.choice over it was not reproducible
# even with a fixed seed.
ALL_VERBS = sorted({
    lemma.name()
    for synset in wordnet.all_synsets('v')
    for lemma in synset.lemmas()
    if '_' not in lemma.name() and lemma.name().islower()
})

In [None]:
from transformers import T5ForConditionalGeneration, T5Tokenizer

# T5 span-infilling tokenizer and model used by mask_and_replace; a large checkpoint, slow to load.
t5_checkpoint = "t5-3b"
tokenizer = T5Tokenizer.from_pretrained(t5_checkpoint)
model = T5ForConditionalGeneration.from_pretrained(t5_checkpoint)

In [None]:
def get_mask_token(i):
    """Return T5's i-th sentinel token, e.g. get_mask_token(0) -> '<extra_id_0>'."""
    return '<extra_id_{}>'.format(i)

def mask_and_replace(sentences, verbose=False):
    """
    Randomly mask ~15% of the space-separated words in each sentence and let T5 refill them.

    sentences: list of input strings
    verbose: if True, print the masked sentences, banned words, banned token ids, raw
             generation ids and decoded T5 outputs (previously printed unconditionally)

    Returns the sentences with their sentinels filled by `fill_masks`, which is defined in a
    later cell and must have been run before this is called.
    Uses the notebook-level `tokenizer` and `model` (T5) and the unseeded `random` module.
    """
    log = print if verbose else (lambda *args, **kwargs: None)

    masked_sentences = []
    bad_words = []
    for sentence in sentences:
        sentence_words = sentence.split(' ')
        n_masks = int(np.ceil(0.15 * len(sentence_words)))
        mask_idxs = sorted(random.sample(range(len(sentence_words)), n_masks))
        masked_tokens = []
        for i, mask_idx in enumerate(mask_idxs):
            masked_tokens.append(sentence_words[mask_idx])
            sentence_words[mask_idx] = get_mask_token(i)
        masked_sentences.append(' '.join(sentence_words))
        bad_words.append(masked_tokens)

    log(masked_sentences)
    log(bad_words)

    # `bad_words_ids` is a list of banned token-id *sequences*, one per word. Flattening all
    # of a sentence's masked words into one list only banned their exact concatenation.
    # generate() applies the list to every batch item, so a word masked in one sentence is
    # banned in all of them. Empty sequences ('' words from repeated spaces) are dropped
    # because generate() rejects them.
    bad_words_ids = [
        ids
        for words in bad_words if words
        for ids in tokenizer(words, add_special_tokens=False).input_ids
        if ids
    ]
    log(bad_words_ids)

    encoded = tokenizer(masked_sentences, return_tensors="pt", padding=True)
    sequence_ids = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,  # padded batch: don't attend to pad positions
        do_sample=True,
        top_k=0,  # 0 disables top-k filtering (pure sampling)
        bad_words_ids=bad_words_ids or None,
    )
    log(sequence_ids)
    sequences = tokenizer.batch_decode(sequence_ids)
    log(sequences)

    return fill_masks(masked_sentences, sequences)

In [None]:
# Smoke test of the mask-and-refill step on one long sentence, with debug output.
mask_and_replace(
    sentences=['We will be able to increase our production by 10 percent if we can get the additional machinery'],
    verbose=True,
)

In [None]:
# Perturb the single ambiguous side of each example with whichever mask_and_replace
# definition ran last (the random-masking one on a top-to-bottom run).
test_df = pd.read_json('annotation/AmbiEnt/validated_examples.jsonl', lines=True)
# XOR keeps examples where exactly one of premise / hypothesis is ambiguous.
exactly_one_ambiguous = test_df['premise_ambiguous'] ^ test_df['hypothesis_ambiguous']
test_df = test_df.loc[exactly_one_ambiguous]

ambiguous_sentences = [
    row['premise'] if row['premise_ambiguous'] else row['hypothesis']
    for _, row in test_df.iterrows()
]

new_sentences = mask_and_replace(ambiguous_sentences)
# original -> perturbed sentence; a repeated sentence keeps its last perturbation.
perturbations = dict(zip(ambiguous_sentences, new_sentences))

In [None]:
# Display the current `perturbations` mapping (ambiguous sentence -> perturbed sentence).
perturbations

In [None]:
def mask_and_replace(sentences):
    """
    Mask the first NOUN/VERB/ADJ token of each sentence and let T5 sample a replacement.

    NOTE: redefines the random-masking `mask_and_replace` from an earlier cell; calls made
    after this cell runs use this version.

    sentences: list of input strings
    Returns a list of new sentences in the same order, with the first character upper-cased.
    Sentences without a NOUN/VERB/ADJ token are returned unmasked; if T5's output lacks the
    sentinel, the original word is kept.
    Uses the notebook-level `nlp` (spaCy), `tokenizer` and `model` (T5).
    """
    mask = get_mask_token(0)
    masked_sentences = []
    original_words = []
    bad_words_ids = []
    for sentence in sentences:
        target = next((tok for tok in nlp(sentence) if tok.pos_ in ('NOUN', 'VERB', 'ADJ')), None)
        if target is None:
            # Nothing to mask (previously the last token was banned, or NameError on empty input).
            masked_sentences.append(sentence)
            original_words.append(None)
            continue
        # Splice by character offset: str.replace would also mask every other occurrence of
        # the token text, including substrings of other words ("run" inside "running").
        masked_sentences.append(sentence[:target.idx] + mask + sentence[target.idx + len(target.text):])
        original_words.append(target.text)
        ids = tokenizer(target.text, add_special_tokens=False).input_ids
        if ids:
            bad_words_ids.append(ids)

    encoded = tokenizer(masked_sentences, return_tensors="pt", padding=True)
    sequence_ids = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        do_sample=True,
        top_k=100,
        bad_words_ids=bad_words_ids or None,  # generate() rejects an empty list
    )
    sequences = tokenizer.batch_decode(sequence_ids)

    new_sentences = []
    for masked, word, output in zip(masked_sentences, original_words, sequences):
        if mask in output:
            span = output.split(mask)[1].split(get_mask_token(1))[0]
            # Drop EOS and batch padding that trail the span when no <extra_id_1> follows it.
            span = span.replace('</s>', '').replace('<pad>', '').strip()
        else:
            span = word or ''
        filled = masked.replace(mask, span)
        # Upper-case only the first character: str.capitalize() also lower-cased the rest,
        # mangling proper nouns and "I".
        new_sentences.append(filled[:1].upper() + filled[1:])
    return new_sentences

In [None]:
# Smoke test of the first-content-word masking variant.
mask_and_replace(sentences=["I'm not trying to downplay the seriousness of the situation."])

In [None]:
# Scratch check of the span-filling loop on a hand-written T5 output
# (fill_masks in the next cell applies a similar loop per sentence).
masked_sentence = 'People should be aware <extra_id_0> the dangers of not properly securing their <extra_id_1>'
s = '<pad><extra_id_0> of<extra_id_1> homes.</s>'
mask_idx = 0
while get_mask_token(mask_idx) in s:
    current, following = get_mask_token(mask_idx), get_mask_token(mask_idx + 1)
    new_span = s.split(current)[1].split(following)[0].replace('</s>', '').strip()
    masked_sentence = masked_sentence.replace(current, new_span)
    mask_idx += 1
masked_sentence

In [None]:
def fill_masks(masked_sentences, t5_outputs):
    """
    Substitute T5's predicted spans back into sentences containing sentinel tokens.

    masked_sentences: strings containing <extra_id_0>, <extra_id_1>, ...
    t5_outputs: decoded T5 outputs in the same order,
                e.g. '<pad><extra_id_0> of<extra_id_1> homes.</s>'

    Sentinels are filled in order starting from 0, stopping at the first one missing from
    the output; any later sentinels stay in the sentence. The first character of each
    result is upper-cased. Returns the list of filled sentences.
    """
    new_sentences = []
    for masked_sentence, t5_output in zip(masked_sentences, t5_outputs):
        mask_idx = 0
        while get_mask_token(mask_idx) in t5_output:
            mask_token = get_mask_token(mask_idx)
            new_span = t5_output.split(mask_token)[1].split(get_mask_token(mask_idx + 1))[0]
            # Drop EOS and any batch padding that trails the final span.
            new_span = new_span.replace('</s>', '').replace('<pad>', '').strip()
            masked_sentence = masked_sentence.replace(mask_token, new_span)
            mask_idx += 1
        # Not str.capitalize(): that also lower-cases everything after the first character,
        # turning "Alice" into "alice" and " I " into " i ".
        new_sentences.append(masked_sentence[:1].upper() + masked_sentence[1:])
    return new_sentences