In [1]:
import json

import pickle  
import spacy
import json
import numpy as np
# import pandas as pd
import logging
import numpy
import nltk
from nltk.corpus import wordnet as wn
import os

import csv
from tqdm import tqdm_notebook as tqdm

nlp = spacy.load("en_core_web_sm")

In [8]:
from sklearn.feature_extraction.text import CountVectorizer

from nltk.corpus import stopwords
sw = set(stopwords.words('english'))

from nltk.tokenize import RegexpTokenizer

tokenizer = RegexpTokenizer(r'\w+')
tokenizer.tokenize('Eighty-seven miles to go, yet.  Onward!')

['Eighty', 'seven', 'miles', 'to', 'go', 'yet', 'Onward']

## Tips on running the prune

    1. Define the experiement in the block below
    2. Define your experiment in the exp block (search "[DEFINE YOUR EXP HERE]")
    3. Generate the sudo dataset and run by the "[RUN YOUR EXP HERE]" block
    4. Make the exp results readable by running the "[MAKE THE EXP READABLE]" block


In [9]:
# Select the experiment type
exp = "wsc" # dev

### Helper Functions

In [10]:
def load_dataset(filename):
    # read a winogrande style jsonl file
    # qID, sentence, option1, option2, answer
    this_set = []
    
    
    with open(filename,"r") as json_file:
        json_list = list(json_file)

        for json_str in json_list:
            result = json.loads(json_str)
            this_set.append(result)
    print("Loaded "+ filename + " with "+ str(len(this_set)) + " items.")
    return this_set

In [16]:
# prune a type of dep token

def prune_pos(text, target):
    doc = nlp(text)
    temp = []
    count = 0
    for token in doc:
        if token.text == "_":
            temp.append(token.text)
            
        elif token.dep_  in target:
            count = 1
            continue
        else:
            # print(token.dep_)
            temp.append(token.text)
            
    return " ".join(temp[:-1]) + ".", count

def generate_pos_set(all_sents, target):
    all_target = []
    
    for sent in all_sents:
        # print(sent)

        doc = nlp(sent)
        tars = [token.text for token in doc if token.pos_ == target]
        
#         for token in doc:
#             print(token.text,": ",token.pos_)
            
        try:
            tars.remove("_")
        except:
            a = "nothing"
            
        all_target.extend(tars)
        # print("_________________________________")
    
    return all_target

In [9]:
# Pruning function

def pruning_exp(wsc, prune_type, pruned_tokens, all_sents):
    # select type from normal, cand-names, cand-nonenames pos, dep
    # for cand
    pruned_wsc = []
    num_case = []
    
    if "cand" in prune_type:
        new_ADP = pruned_tokens # names

    for piece in wsc:
        # print(piece)
        new_piece = {}
        for key, value in piece.items():
            if key != "sentence":
                new_piece[key] = value
            else:
                temp = value
                if value[-1] != ".":
                    value += "."

                temp_case = 0
                
                if "cand" in prune_type:
                    # Prune the cands
                    pruned_tokens=[]
                    cand1 = piece["option1"]
                    cand2 = piece["option2"]
                    pruned_tokens.extend(tokenizer.tokenize(cand1))
                    pruned_tokens.extend(tokenizer.tokenize(cand2))
                    
                    names = []
                    for name in new_ADP:
                        try:
                            pruned_tokens.remove(name)
                            names.append(name)
                        except:
                            a = "nothing"
                    if "names" in prune_type:
                        pruned_tokens = names


                    if len(pruned_tokens) > 0:
                        temp_case = 1
                    
                if "cand" in prune_type or "pos" in prune_type or "normal" in prune_type:
                    # print("Pruning cand/pos/tokens")
                    temp = temp.replace(".", " [PERIOD]")
                    temp = temp.replace(",", " [COMMA]")
                    temp = temp.replace("?", " [QUESTION]")
                    temp = temp.replace("n't", " n't")
                    temp = temp.replace("n’t", " n't")
                    
                    
                    sent_tokens = temp.split(" ")
                    output_tokens = []
                    
                    for st in sent_tokens:
                        if st in pruned_tokens or st.lower() in pruned_tokens and "cand" not in prune_type:
                            temp_case = 1
                        else:
                            output_tokens.append(st)
                    
                    out_sent = " ".join(output_tokens)
                    new_piece[key] = out_sent.replace(" [PERIOD]", ".").replace(" [COMMA]", ",").replace(" [QUESTION]", "?").replace(" n't", "n't")

                    num_case.append(temp_case)
                
                if "dep" in prune_type:
                    # print("Pruning by dependency tags")
                    # special remove all tokens    
                    new_piece[key], temp_count=prune_pos(value.lower(), pruned_tokens)
                    num_case.append(temp_count)

        pruned_wsc.append(new_piece)

    return pruned_wsc, num_case

In [10]:
def generate_sudo_set(cover_set, output_folder):
    # generate the seleted/unselected set
    with open(output_folder+"/dev.jsonl","w") as f1:
        with open(output_folder+"/wsc.jsonl", "w") as f2:
            for i,piece in enumerate(cover_set):
                # if i in selected_ids:
                f1.write(json.dumps(piece))
                f1.write("\n")
                f2.write(json.dumps(piece))
                f2.write("\n")
               
            
def acquire_prediction(ori_pred_path, prune_pred_path, wsc, pruned_wsc, case_map):
    
    new_wsc = []
    
    with open(ori_pred_path) as f:
        ori_ans = json.load(f)
        
    with open(prune_pred_path) as f:
        prune_ans = json.load(f)
        
    for i, case in enumerate(case_map):
        temp = {}
        if case == 1:
            temp["qID"] = wsc[i]["qID"]
            temp["sentence"] = wsc[i]["sentence"]
            temp["pruned_sentence"] = pruned_wsc[i]["sentence"]
            temp["option1"] = wsc[i]["option1"]
            temp["option2"] = wsc[i]["option2"]
            temp["answer"] = wsc[i]["answer"]
            temp["ori_ans"] = ori_ans[i]
            temp["prune_ans"] = prune_ans[i]
            new_wsc.append(temp)
            
    return new_wsc


def write_readable_wsc(target, info, name):
    with open(name+".txt","w") as f:
        f.write("======================================================"+"\n")
        f.write("This is the output answer for the experiment:"+ info + "\n")
        f.write("======================================================"+ "\n" +"\n"+ "\n")
        
        for case in target:
            f.write("--------------------"+"qID: "+ case["qID"]+ "--------------------"+"\n")
            f.write("label: "+ str(case["answer"])+"\n")
            f.write("original_correctness: "+ str(case["ori_ans"])+"\n")
            f.write("pruned_correctness: "+ str(case["prune_ans"])+"\n")
            f.write("option1: "+ str(case["option1"])+"\n")
            f.write("option2: "+ str(case["option2"])+"\n"+"\n")
            
            f.write("ori sentence: "+"\n")
            f.write(case["sentence"]+"\n"+"\n")
            f.write("pruned sentence: "+"\n")
            f.write(case["pruned_sentence"]+"\n"+"\n")
            f.write("----------------------------------------------------------------"+"\n"+"\n")
            
    with open(name+".json","w") as f:
        json.dump(target, f)
        
    return True

### Running the experiments

In [12]:
if exp == "wsc":
    wsc = load_dataset("./data/wsc.jsonl")
elif exp == "dev":
    wsc = load_dataset("./data/dev.jsonl")
# xs = load_dataset("train_xs.jsonl")
# s = load_dataset("train_s.jsonl")
# m = load_dataset("train_m.jsonl")
# l = load_dataset("train_l.jsonl")
# xl = load_dataset("train_xl.jsonl")

Loaded ./data/wsc.jsonl with 279 items.


In [13]:
all_sents = []
for i in range(len(wsc)):
    print(wsc[i])
    all_sents.append(wsc[i]["sentence"])

{'qID': 'wsc0', 'sentence': 'The city councilmen refused the demonstrators a permit because _ feared violence.', 'option1': 'The city councilmen', 'option2': 'The demonstrators', 'answer': '1'}
{'qID': 'wsc1', 'sentence': 'The city councilmen refused the demonstrators a permit because _ advocated violence.', 'option1': 'The city councilmen', 'option2': 'The demonstrators', 'answer': '2'}
{'qID': 'wsc2', 'sentence': "The trophy doesn't fit into the brown suitcase because _ is too large.", 'option1': 'the trophy', 'option2': 'the suitcase', 'answer': '1'}
{'qID': 'wsc3', 'sentence': "The trophy doesn't fit into the brown suitcase because _ is too small.", 'option1': 'the trophy', 'option2': 'the suitcase', 'answer': '2'}
{'qID': 'wsc4', 'sentence': 'Joan made sure to thank Susan for all the help _ had recieved.', 'option1': 'Joan', 'option2': 'Susan', 'answer': '1'}
{'qID': 'wsc5', 'sentence': 'Joan made sure to thank Susan for all the help _ had given.', 'option1': 'Joan', 'option2': 'S

In [15]:
# find the most frequent tokens
from collections import Counter

words = []

for sent in all_sents:
    temp_words = tokenizer.tokenize(sent)
    for temp_word in temp_words:
        
        if temp_word not in sw:
            words.append(temp_word)
        #words.append(temp_word)

counter = Counter()
counter.update(words)
most_common = counter.most_common(100)
print(most_common)

[('_', 279), ('The', 56), ('I', 48), ('Sam', 16), ('Fred', 16), ('Susan', 14), ('Bill', 14), ('man', 14), ('John', 14), ('time', 14), ('put', 14), ('Bob', 12), ('see', 11), ('years', 11), ('got', 11), ('gave', 11), ('tried', 10), ('tree', 10), ('work', 9), ('took', 9), ('saw', 9), ('get', 9), ('book', 9), ('Paul', 8), ('Tom', 8), ('much', 8), ('better', 8), ('one', 8), ('away', 8), ('would', 8), ('Adam', 8), ('Alice', 8), ('knocked', 8), ('Jane', 8), ('ago', 8), ('father', 8), ('When', 8), ('George', 7), ('good', 7), ('though', 7), ('woman', 7), ('asked', 6), ('right', 6), ('table', 6), ('lot', 6), ('Jim', 6), ('Pete', 6), ('full', 6), ('Ann', 6), ('Mark', 6), ('Joe', 6), ('still', 6), ('There', 6), ('started', 6), ('came', 6), ('In', 6), ('Mary', 6), ('library', 6), ('paid', 6), ('Charlie', 6), ('If', 6), ('home', 6), ('daughter', 6), ('She', 6), ('Since', 6), ('carried', 6), ('This', 6), ('door', 6), ('But', 6), ('turned', 6), ('long', 6), ('passed', 6), ('boyfriend', 6), ('answer', 

In [18]:
# Generate the pos tokens to be pruned
pos_det = ["his","her","my", "our","your","their","its"]
pos_det += ["His","Her","My", "Our","Your","Their","Its"]

# all_ADP = generate_pos_set(all_sents, "PUNCT")
# all_ADP = generate_pos_set(all_sents, "VERB")
# all_ADP = generate_pos_set(all_sents, "PRON")

all_ADP = generate_pos_set(all_sents, "ADJ")
all_ADP = list(filter(lambda a: a != "that", all_ADP))
for possessive_determiner in pos_det:
    all_ADP = list(filter(lambda a: a != possessive_determiner, all_ADP))

# all_ADP = generate_pos_set(all_sents, "ADV")
# all_ADP = list(filter(lambda a: a != "n't", all_ADP))
# all_ADP = list(filter(lambda a: a != "not", all_ADP))
# all_ADP = list(filter(lambda a: a != "n’t", all_ADP))

# all_ADP = generate_pos_set(all_sents, "ADP")
# all_ADP = generate_pos_set(all_sents, "NOUN")

# all_ADP = generate_pos_set(all_sents, "PROPN")
# all_ADP.append('Kamchatka')

pruned_tokens = all_ADP
print(pruned_tokens)

['advocated', 'brown', 'large', 'brown', 'small', 'sure', 'all', 'sure', 'all', 'successful', 'available', 'reluctant', 'reluctant', 'slow', 'longtime', 'longtime', 'weak', 'heavy', 'large', 'large', 'short', 'tall', 'same', 'such', 'good', 'same', 'such', 'bad', 'better', 'better', 'better', 'good', 'worse', 'good', 'upset', 'upset', 'upset', 'upset', 'successful', 'Pete', 'successful', 'older', 'younger', 'older', 'younger', 'empty', 'full', 'personal', 'nosy', 'personal', 'indiscreet', 'younger', 'older', 'much', 'short', 'much', 'outdoor', 'outdoor', 'old', 'old', 'handy', 'lighter', 'tall', 'high', 'sure', 'good', 'sure', 'famous', 'generous', 'grateful', 'hurt', 'ungrateful', 'sudden', 'good', 'sudden', 'good', 'hot', 'cooler', 'impatient', 'cautious', 'last', 'charming', 'last', 'charming', 'military', 'huge', 'red', 'unhappy', 'military', 'huge', 'red', 'unhappy', 'hungry', 'tasty', 'which', 'which', 'annoyed', 'annoying', 'impressed', 'impressive', 'ill', 'concerned', 'unhappy

In [515]:
# pruned_tokens = []

# pruned_tokens = [","]
# pruned_tokens = ["."]
# pruned_tokens = ["the", "The"]
# pruned_tokens = ["a", "A"]
# pruned_tokens = ["this", "that", "these", "those", "This", "That", "These", "Those"]
# pruned_tokens = ["has", "have", "had", "Has", "Have", "Had"]
# pruned_tokens = ["Because", "because"]
# # pruned_tokens = ["go", "going", "went", "gone", "goes"]
# pruned_tokens = ["but", "But"]
# pruned_tokens = ["though", "although", "Though", "Although","since", "Since"] 
# pruned_tokens = ["but", "But", "and", "And"]
# pruned_tokens = ["very", "Very"]
# pruned_tokens = ["so", "So"]
# pruned_tokens = ["in", "In"]
# pruned_tokens = ["of", "Of"]
# pruned_tokens = ["on", "On"]
# pruned_tokens = ["with", "With"]

# pruned_tokens = ["wanted", "Wanted"]
# pruned_tokens = ["new", "New"]
# pruned_tokens = ["instead", "Instead"]
# pruned_tokens = ["get", "Get"]
# pruned_tokens = ["could", "Could"]
# pruned_tokens = ["always", "Always"]
# pruned_tokens = ["used", "Used"]


# pruned_tokens = ["and", "or", "but", "as", "because", "for", "just as", "neither", "nor", "not only", "so", "whether", "yet"]
# pruned_tokens.extend(["And", "Or", "But", "As", "Because", "For", "Just as", "Neither", "Nor", "Not only", "So", "Whether", "Yet"])

In [516]:
# pruned_tokens = ["dobj", "pobj"]
# pruned_tokens = ["nsubj"]
# pruned_tokens = ["ROOT"]
# pruned_tokens = ["det"]
# pruned_tokens = ["amod", "advmod", "npadvmod"]
# pruned_tokens = ["neg"]

In [19]:
# select type from normal, cand-names, cand-other pos, dep
# when the type is dep, the pruned tokens is the dep tokens that you want to prune

########################## [DEFINE YOUR EXP HERE] ###############################

pruned_wsc, case_map = pruning_exp(wsc, "normal", pruned_tokens, all_sents)
print(len(wsc))
print(len(pruned_wsc))
print(sum(case_map))

279
279
177


In [20]:
# Find the acc for the covered subset from original data evaluation
if exp == "dev":
    with open("./data/grande_dev.json") as f:
        roberta_ans = json.load(f)

if exp == "wsc":
    with open("./data/grande_wsc.json") as f:
        roberta_ans = json.load(f)
    
    
partial_roberta = []

gcovered = 0
rcovered = 0


    
for i, case in enumerate(case_map):
    if case == 0:
        continue

    partial_roberta.append(roberta_ans[i])
    rcovered += 1

print(np.sum(partial_roberta)/rcovered)

0.9152542372881356


In [22]:
# generate the pruned dataset for further prediction
generate_sudo_set(pruned_wsc, "./data/prune")

In [8]:
#################### [RUN YOUR EXP HERE] ##############################

!CUDA_VISIBLE_DEVICES=1 python ./scripts/run_experiment.py \
 --model_type roberta_mc \
 --model_name_or_path ./output/xl \
 --task_name winogrande \
 --do_eval \
 --do_lower_case \
 --data_dir ./data/prune \
 --max_seq_length 80 \
 --per_gpu_eval_batch_size 4 \
 --per_gpu_train_batch_size 4 \
 --learning_rate 0 \
 --output_dir ./output/prune/ \
 --logging_steps 4752 \
 --save_steps 4750 \
 --seed 42 \
 --data_cache_dir ./output/prune \
 --warmup_pct 0.1 \
 --evaluate_during_training

cuda
08/07/2020 14:49:18 - INFO - transformers.configuration_utils -   loading configuration file ./output/xl/config.json
08/07/2020 14:49:18 - INFO - transformers.configuration_utils -   Model config {
  "architectures": [
    "RobertaForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": "winogrande",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 1024,
  "initializer_range": 0.02,
  "intermediate_size": 4096,
  "is_decoder": false,
  "layer_norm_eps": 1e-05,
  "max_position_embeddings": 514,
  "num_attention_heads": 16,
  "num_hidden_layers": 24,
  "num_labels": 1,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pruned_heads": {},
  "torchscript": false,
  "type_vocab_size": 1,
  "use_bfloat16": false,
  "vocab_size": 50265
}

08/07/2020 14:49:18 - INFO - transformers.tokenization_utils -   Model name './output/xl' not found in model shortcut name list (roberta-base, roberta-large, roberta-large-

08/07/2020 14:49:48 - INFO - __main__ -   ***** Running evaluation  on dev *****
08/07/2020 14:49:48 - INFO - __main__ -     Num examples = 279
08/07/2020 14:49:48 - INFO - __main__ -     Batch size = 4
Evaluating: 100%|███████████████████████████████████████████████████| 70/70 [00:04<00:00, 14.54it/s]
08/07/2020 14:49:53 - INFO - __main__ -   ***** Eval results  on dev *****
08/07/2020 14:49:53 - INFO - __main__ -     acc_dev = 0.7347670250896058
08/07/2020 14:49:53 - INFO - __main__ -   ***** Write predictions  on dev *****
08/07/2020 14:49:53 - INFO - __main__ -   ***** Write predictions  on dev *****
08/07/2020 14:49:53 - INFO - __main__ -   ***** Experiment finished *****


In [12]:
if exp == "dev":
    ori_pred_path = "./data/grande_dev.json"
if exp == "wsc":
    ori_pred_path = "./data/grande_wsc.json"
    
prune_pred_path = "./output/prune/correctness_dev.json"
# prune_pred_path = "./output/truncate/correctness_dev.json"

# pruned_wsc = load_dataset("./data/truncate/dev.jsonl")
# print(len(pruned_wsc))

output_wsc = acquire_prediction(ori_pred_path, prune_pred_path, wsc, pruned_wsc,case_map)

print(len(output_wsc))

In [17]:
###################### [MAKE THE EXP READABLE] ##############################

info = "Prune the main sentence"
# name = "./readable_pruned/truncate"
name = "./readable_pruned_dev/truncate"
write_readable_wsc(output_wsc, info, name)

True