In [None]:
!nvidia-smi

Mon Feb 21 20:11:03 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   32C    P0    24W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
%%capture
!pip install transformers
!pip install datasets


In [None]:
import warnings
import json
warnings.filterwarnings('ignore')

In [None]:
from IPython.display import clear_output

In [None]:
from typing import List
from pprint import pprint
import random
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import math
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import get_linear_schedule_with_warmup, pipeline


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [None]:
# Sizes: 125M, 1.3B, 2.7B
NEO_SIZE = "125M"
NEO_SAVE_NAME = NEO_SIZE.replace(".", "-") # Remove dot for filename saving
SEED = 57

In [None]:
def set_seed(seed: int):
    """Seed every random number generator this notebook relies on.

    Applies `seed` to Python's `random`, to NumPy's global generator, and to
    PyTorch (the default generator and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

set_seed(SEED)

# Define Dataset object

In [None]:
import spacy
from spacy.matcher import Matcher
from spacy.util import filter_spans
from spacy.symbols import ORTH
import spacy
from spacy.tokenizer import Tokenizer
SPACY_NLP = spacy.load("en_core_web_sm")
# To tokenize just by space
SPACY_NLP.tokenizer = Tokenizer(SPACY_NLP.vocab)

class docDataset:
    """A single document with gold mentions, expanded into labelled mention pairs.

    The text is split into sentences on the literal "[EOS]" marker. Gold
    mentions are grouped by sentence, and `create_mention_pairs` slides a
    window of `window_size` sentences over the document, pairing every two
    mentions that fall inside the same window.

    Parameters:
        doc_name: identifier of the document (stored as-is).
        text (str): document text with "[EOS]" between sentences.
        vocab: stored as-is; not used by any method of this class.
        clusters (dict | None): cluster id -> list of token ids. May be None,
            in which case no labels are produced.
        gold_mentions (dict): mention id -> dict with the keys "tokens_ids",
            "tokens" and "sentence_id". Required; see `create_mention`.

    Attributes filled in later:
        cluster_map: token id -> set of cluster ids that contain the token.
        mentions: sentence index -> list of mention dicts
            ({"mention", "start_token_id", "end_token_id"}).
        pairs, context_sents, labels: numpy arrays built by `create_mention_pairs`.
    """

    def __init__(self,  doc_name, text, vocab, clusters=None, gold_mentions=None):
        self.text = text
        self.sentences = text.split("[EOS]")
        self.doc_name = doc_name
        self.gold_mentions = gold_mentions
        self.clusters = clusters
        self.window_size = 2

        self.cluster_map = {}
        self.vocab = vocab
        self.mentions = {}
        self._reset_pair_buffers()

        self.create_cluster_map()
        self.create_mention()

    def _reset_pair_buffers(self):
        """Put `pairs`, `context_sents` and `labels` back into their empty starting state."""
        # Row 0 is a [-1, -1] placeholder that fixes the two-column shape so
        # rows can be appended along axis 0; create_mention_pairs strips it.
        self.context_sents = np.array([[-1,-1]])
        self.pairs = np.array([[-1,-1]])
        self.labels = np.array([])

    def create_cluster_map(self):
        """Invert `clusters` into `cluster_map` (token id -> set of cluster ids)."""
        # `clusters` defaults to None; treat that as "no gold clusters"
        # instead of failing while iterating None.
        clusters = self.clusters if self.clusters is not None else {}
        for cluster_id in clusters:
            for token in clusters[cluster_id]:
                self.cluster_map.setdefault(token, set()).add(cluster_id)

    def create_mention(self):
        """Group the gold mentions by the sentence they occur in.

        Raises:
            Exception: if no gold mentions were supplied (automatic mention
                detection is not implemented).
        """
        for i in range(len(self.sentences)):
            self.mentions[i] = []

        if not self.gold_mentions:
            raise Exception("TODO: need to add mentions if there is no gold data")

        for m_id in self.gold_mentions:
            mention_info = self.gold_mentions[m_id]
            mention_token_ids = mention_info["tokens_ids"]
            # TODO: need to change for new index, as the original xml starts from 0
            sent_id = int(mention_info["sentence_id"])
            self.mentions[sent_id].append({
                "mention": " ".join(mention_info["tokens"]),
                "start_token_id": mention_token_ids[0],
                "end_token_id": mention_token_ids[-1],
            })

    def decode_mention(self, mention):
        """Return every token id from the mention's start to its end, inclusive."""
        return list(range(mention['start_token_id'], mention['end_token_id'] + 1))

    def label_pairs(self, mention1, mention2):
        """Return 1 if the two mentions share at least one cluster, else 0.

        Parameters:
            mention1, mention2: iterables of token ids (see `decode_mention`).

        Raises:
            Exception: if the document has no cluster information.
        """
        if not self.cluster_map:
            raise Exception("No Label Data")

        clusters1 = set().union(*(self.cluster_map[t] for t in mention1 if t in self.cluster_map))
        clusters2 = set().union(*(self.cluster_map[t] for t in mention2 if t in self.cluster_map))
        return 1 if clusters1 & clusters2 else 0

    def create_mention_pairs(self):
        """Build `pairs`, `labels` and `context_sents` for the whole document.

        Safe to call more than once: the buffers are rebuilt from scratch on
        every call, so repeated calls give the same result.
        """
        self._reset_pair_buffers()

        n = len(self.sentences)
        if n == 1:
            self.window_size = 1
            self.mention_pairs_helper(0) # TODO: may need to change
        else:
            # NOTE(review): this visits window starts 0 .. n - window_size - 1, so a
            # window ending on the very last sentence is never built. That is only
            # harmless if the text ends with "[EOS]" (trailing empty sentence) - confirm.
            for i in range(n-self.window_size):
                self.mention_pairs_helper(i) # TODO: may need to change

        # Drop the placeholder first row.
        self.pairs = self.pairs[1:,:]
        self.context_sents = self.context_sents[1:,:]

    def mention_pairs_helper(self, start_idx = 1):
        """Append all mention pairs of the window that starts at sentence `start_idx`.

        Every unordered pair of mentions located in sentences
        `start_idx .. start_idx + window_size - 1` is added to `pairs`, with the
        (first, last) sentence indices of the window added to `context_sents`.
        A label is added per pair only when cluster information is available.
        """
        window = (start_idx, start_idx + self.window_size - 1)

        window_mentions = []
        for sent_idx in range(start_idx, start_idx + self.window_size):
            window_mentions += self.mentions[sent_idx]

        new_pairs = []
        new_labels = []
        for i in range(len(window_mentions) - 1):
            for j in range(i + 1, len(window_mentions)):
                first, second = window_mentions[i], window_mentions[j]
                new_pairs.append((first, second))
                if self.cluster_map:
                    new_labels.append(
                        self.label_pairs(self.decode_mention(first), self.decode_mention(second))
                    )

        if not new_pairs:
            return

        # One array append per window instead of one per pair: np.append copies
        # the whole array on every call.
        self.pairs = np.append(self.pairs, new_pairs, axis = 0)
        self.labels = np.append(self.labels, new_labels)
        self.context_sents = np.append(self.context_sents, [window] * len(new_pairs), axis = 0)

    def get_experiment_samples(self):
        """Return one `[context_sents, text, pair, label]` list per mention pair.

        Expects `create_mention_pairs` to have been called and labels to exist.
        """
        samples = []
        for i in tqdm(range(len(self.pairs))):
            text = self.extract_sents_text(self.context_sents[i])
            pair = self.pairs[i]
            label = self.labels[i]
            samples.append([self.context_sents[i], text, pair, label])
        return samples

    def extract_sents_text(self, sent_ids):
        """Concatenate the sentences with the given indices, each one only once.

        Duplicated indices (e.g. the (0, 0) window of a one-sentence document)
        are collapsed, and the sentences are emitted in the order the indices
        are given. A plain `set` is not used for this because its iteration
        order is arbitrary and can reverse the sentences.
        """
        unique_ids = list(dict.fromkeys(sent_ids))
        return "".join(self.sentences[i] for i in unique_ids) # TODO: May need to change


# Load Gold Data

In [None]:
import random
from google.colab import drive
import pickle
drive.mount('/content/drive',force_remount=True)

# Make sure to click "Add shortcut to drive" for the "Coref-for-GPT" folder
gdrive_dir_path = "/content/drive/MyDrive/Coref-for-GPT"

Mounted at /content/drive


In [None]:
local_path = ""

# Change this to "local_path" if you run the notebook locally
root_path = gdrive_dir_path

In [None]:
# Path to the ecb data
ecb_path = f"{root_path}/Data/ECB+/"

In [None]:
!ls drive/MyDrive/Coref-for-GPT/Data/ECB+

gold  original	processed


In [None]:
file_path = ecb_path + "processed/train_with_new_index.json"
with open(file_path) as f:
    train = json.load(f)
print(len(train))

558


In [None]:
file_path = ecb_path + "processed/dev_with_new_index.json"
with open(file_path) as f:
    dev = json.load(f)
print(len(dev))

192


# Generate Prompt



In [None]:
def generate_prompt(template, text, mention_pair, text_token = "[TEXT]", mention1_token = "[MENTION1]", mention2_token = "[MENTION2]"):
    """Fill a prompt template with a passage and a pair of mentions.

    Parameters:
        template (str): prompt containing the placeholder tokens.
        text (str): passage substituted for `text_token`.
        mention_pair (sequence of str): substituted for `mention1_token` and
            `mention2_token`, in that order.
        text_token, mention1_token, mention2_token (str): placeholder spellings.

    Returns:
        str: the template with every placeholder replaced.

    All placeholders are replaced in a single pass over the template, so a
    placeholder-looking substring inside `text` or inside a mention is left
    untouched instead of being substituted a second time.
    """
    import re

    substitutions = {}
    for token, value in (
        (text_token, text),
        (mention1_token, mention_pair[0]),
        (mention2_token, mention_pair[1]),
    ):
        # Empty tokens cannot be matched meaningfully; on a repeated token the
        # first value wins.
        if token:
            substitutions.setdefault(token, value)

    if not substitutions:
        return template

    # Longest token first so a token that is a prefix of another cannot shadow it.
    pattern = "|".join(re.escape(token) for token in sorted(substitutions, key=len, reverse=True))
    # A function is used as the replacement so backslashes in the values stay literal.
    return re.sub(pattern, lambda match: substitutions[match.group(0)], template)


In [None]:
def create_prefix(examples, template, answer_choices):
    """Render solved examples as a few-shot prefix, one example per line.

    Each example is a `(text, mention_pair, label)` triple. Its line is the
    filled-in template, a space, and `answer_choices[int(label)]`, terminated
    by a newline.
    """
    shots = []
    for example_text, example_pair, example_label in examples:
        answer = answer_choices[int(example_label)]
        question = generate_prompt(template, example_text, example_pair)
        shots.append(" ".join([question, answer]) + "\n")
    return "".join(shots)

## Define Parameters for Prefix

In [None]:
# Define number of examples to generate for prefix, how many n-shots
n_examples = 2 
answer_choices = ["No", "Yes"]

In [None]:
# Each of these prompt templates is evaluated separately.
# Every template ends right after "Yes or no?" (no trailing whitespace), so the
# answer appended to a few-shot example is separated by exactly one space.
templates = ["'[TEXT]' In previous sentences, does '[MENTION2]' refer to '[MENTION1]'? Yes or no?",
             "'[TEXT]' Here, by '[MENTION2]' they mean '[MENTION1]'? Yes or no?",
             "'[TEXT]' Here, does '[MENTION2]' stand for '[MENTION1]'? Yes or no?",
             "'[TEXT]' In the passage above, can '[MENTION2]' be replaced by '[MENTION1]'? Yes or no?",
             "'[TEXT]' I think '[MENTION2]' means '[MENTION1]'. Yes or no?"]


In [None]:
simple_examples = [["Anna told her friends that she was about to go to college.", ["Anna","she"], 1],
                   ["Eva and Martha didn't want their friend Jenny to feel lonely so they invited her to the party", ["Eva","her"], 0],
                   ["Paul Allen was born on Jan 21, 1953. Allen attended Lakeside School, where he befriended Bill Gates", ["Paul Allen","Allen"], 1],
                   ["A dog named Teddy ran to his owner Jane. Jane loves her dog.", ["Teddy","Jane"], 0],
                   ["I bought 3 bottles of wine today, when I went to John Doe’s store", ["I", "John Doe"], 0],
                   ["Vasco told me yesterday that is his final exam went pretty well. Vasco worked really hard.", ["Vasco", "Vasco"], 1],
                   ["Her car was so fast, that it went past the speed limit", ["Her car", "it"], 1],
                   ["Some of our colleagues are going to be supportive. These kinds of people will earn our gratitude", ["Some of our colleagues", "our gratitude"], 0],
                   ["Barack Obama won the midterm elections, so he was in office for 2 terms", ["Barack Obama", "he"], 1],
                   ["Our neighbors dislike the music. If they are angry, the cops will show up soon", ["they", "the cops"], 0]
                   ]

## Generate Prefix

### 1. Manually create prefix

In [None]:
prefixes_simple = []

for template in templates:
    prefix = create_prefix(simple_examples[:n_examples], template, answer_choices)
    prefixes_simple.append(prefix)
print("Prefix based on template 1: ", prefixes_simple[0])

Prefix based on template 1:  'Anna told her friends that she was about to go to college.' In previous sentences, does 'she' refer to 'Anna'? Yes or no? Yes
'Eva and Martha didn't want their friend Jenny to feel lonely so they invited her to the party' In previous sentences, does 'her' refer to 'Eva'? Yes or no? No



### 2. Create prefix from SuperGLUE 

In [None]:
from datasets import load_dataset
super_glue = load_dataset("super_glue", 'wsc.fixed')
super_glue_train = super_glue["train"]

Downloading:   0%|          | 0.00/9.47k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/8.23k [00:00<?, ?B/s]

Downloading and preparing dataset super_glue/wsc.fixed (download: 31.98 KiB, generated: 139.73 KiB, post-processed: Unknown size, total: 171.72 KiB) to /root/.cache/huggingface/datasets/super_glue/wsc.fixed/1.0.2/d040c658e2ddef6934fdd97deb45c777b6ff50c524781ea434e7219b56a428a7...


Downloading:   0%|          | 0.00/32.8k [00:00<?, ?B/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

0 examples [00:00, ? examples/s]

Dataset super_glue downloaded and prepared to /root/.cache/huggingface/datasets/super_glue/wsc.fixed/1.0.2/d040c658e2ddef6934fdd97deb45c777b6ff50c524781ea434e7219b56a428a7. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
def parse_superglue(data):
    """Split a SuperGLUE WSC split into parallel texts, span pairs and labels.

    Returns:
        (texts, pairs, labels), where `pairs` is a list of
        `(span1_text, span2_text)` tuples.
    """
    texts = data["text"]
    span_pairs = [(first, second) for first, second in zip(data["span1_text"], data["span2_text"])]
    return texts, span_pairs, data["label"]

In [None]:
def get_examples_superglue(n, texts, pairs, labels):
    '''Return a label-balanced list of examples for the prefix.

            Parameters:
                    n (int): total number of examples, expected to be an even number
                             (n // 2 positives and n // 2 negatives are collected)
                    texts (list): list of text
                    pairs (list): list of mention pairs
                    labels (list): list of labels (1 = positive, 0 = negative)

            Returns:
                    examples (list): list of [text, mention_pair, label], in data order

            Raises:
                    IndexError: if the data runs out before both quotas are filled
    '''
    # we want a balanced example set
    remaining = {1: n // 2, 0: n // 2}
    examples = []
    for text, mention_pair, label in zip(texts, pairs, labels):
        if remaining[1] <= 0 and remaining[0] <= 0:
            break
        # Labels other than 0/1 have no quota and are skipped.
        if remaining.get(label, 0) > 0:
            examples.append([text, mention_pair, label])
            remaining[label] -= 1

    if remaining[1] > 0 or remaining[0] > 0:
        raise IndexError(
            "not enough data for a balanced set of %d examples: "
            "%d positive and %d negative examples are missing" % (n, remaining[1], remaining[0])
        )
    return examples

In [None]:
super_glue_texts, super_glue_pairs, super_glue_labels = parse_superglue(super_glue_train)

superglue_examples = get_examples_superglue(n_examples, super_glue_texts, 
                                            super_glue_pairs, super_glue_labels)
prefixes_super_glue = []

for template in templates:
    prefix = create_prefix(superglue_examples, template, answer_choices)
    prefixes_super_glue.append(prefix)
print("Prefix based on template 1: ", prefixes_super_glue[0])

Prefix based on template 1:  'Mark told Pete many lies about himself, which Pete included in his book. He should have been more skeptical.' In previous sentences, does 'He' refer to 'Mark'? Yes or no? No
'The pony behaved well, sir, and showed no vice; but at last he just threw up his heels and tipped the young gentleman into the thorn hedge. He wanted me to help him out, but I hope you will excuse me, sir, I did not feel inclined to do so.' In previous sentences, does 'He' refer to 'young gentleman'? Yes or no? Yes



### 3. Create prefix from ECB+ train 

In [None]:
def extract_sents_text(doc, sent_ids):
    """Concatenate `doc.sentences` at the given indices, using each index once.

    Parameters:
        doc: object with a `sentences` sequence of strings.
        sent_ids: iterable of sentence indices, e.g. a (first, last) window.

    Repeated indices are collapsed (keeping first-seen order), so the (0, 0)
    window of a one-sentence document yields the sentence once, not twice.
    """
    unique_ids = list(dict.fromkeys(sent_ids))
    return "".join(doc.sentences[i] for i in unique_ids)
    

In [None]:
def get_example_info(doc, filter, idx):
    """Return `(text, mention_pair, label)` for the `idx`-th pair selected by `filter`.

    `filter` is applied as an index to `doc.context_sents`, `doc.pairs` and
    `doc.labels`; `idx` then picks one entry from each filtered result.
    `mention_pair` holds the two mention strings of the selected pair.
    """
    window = doc.context_sents[filter][idx]
    passage = extract_sents_text(doc, window)
    first, second = doc.pairs[filter][idx]
    return passage, [first["mention"], second["mention"]], doc.labels[filter][idx]

In [None]:
def get_examples_ecb(n, train, train_docs):
    """Collect a label-balanced list of few-shot examples from ECB+ documents.

    Documents are visited in `train_docs` order. Each document that has both a
    positive and a negative mention pair contributes its first positive pair
    followed by its first negative pair, until `n // 2` documents have
    contributed.

    Parameters:
        n (int): total number of examples wanted (n // 2 of each label).
        train (dict): document name -> (text, tokens, mentions, clusters).
        train_docs (list): document names to draw from.

    Returns:
        list: `[text, mention_pair, label]` entries. Always a list; it is
        shorter than `n` when too few documents qualify.
    """
    examples = []
    remaining = n // 2
    for doc_name in train_docs:
        if remaining <= 0:
            break
        text, toks, mentions, clusters = train[doc_name]
        sample = docDataset(doc_name, text, toks, clusters, mentions)
        sample.create_mention_pairs()

        pos_filter = (sample.labels == 1)
        neg_filter = (sample.labels == 0)

        if np.sum(pos_filter) == 0:
            print("%s has no positive examples"%(doc_name))
            continue
        # Without this check, indexing the empty negative selection would fail.
        if np.sum(neg_filter) == 0:
            print("%s has no negative examples"%(doc_name))
            continue

        pos_text, pos_mention_pair, pos_label = get_example_info(sample, pos_filter, 0)
        examples.append([pos_text, pos_mention_pair, pos_label])

        neg_text, neg_mention_pair, neg_label = get_example_info(sample, neg_filter, 0)
        examples.append([neg_text, neg_mention_pair, neg_label])
        remaining -= 1

    # Returned even when the quota was not met, so callers never receive None.
    return examples
            

In [None]:
ecb_examples = get_examples_ecb(n_examples, train, list(train.keys()))
prefixes_ecb = []

for template in templates:
    prefix = create_prefix(ecb_examples, template, answer_choices)
    prefixes_ecb.append(prefix)
print("Prefix based on template 1: ", prefixes_ecb[0])

20_10ecbplus.xml has no positive examples
20_11ecbplus.xml has no positive examples
20_1ecb.xml has no positive examples
Prefix based on template 1:  ' An earthquake measuring 5.6 on the Richter scale jolted Qeshm island off Iran's southern coast on Sunday, followed by several aftershocks on Monday.  The tremor struck an area around the town of Dargahan on Qeshm island, at the entrance to the Persian Gulf, injuring five people and damaging buildings. ' In previous sentences, does 'an area around the town of Dargahan on Qeshm island , at the entrance to the Persian Gulf' refer to 'Qeshm island off Iran 's southern coast'? Yes or no? Yes
'Five Wounded by Quake in Southern Iran  An earthquake measuring 5.6 on the Richter scale jolted Qeshm island off Iran's southern coast on Sunday, followed by several aftershocks on Monday. ' In previous sentences, does 'Richter scale' refer to '5.6'? Yes or no? No



# Create Generator

In [None]:
# Prefix families, listed in the same order as the prefix lists below.
prefix_types = ["simple","superglue","ecb"]
all_prefix = [prefixes_simple, prefixes_super_glue, prefixes_ecb]
# all_prefix = [prefixes_simple]

# Map every prefix family to its own copy of the per-template prefix list.
# NOTE: `i` stays bound at module level after this loop.
prefix_dict = {}
for i, prefixes in enumerate(all_prefix):
    current_prefixes = list(prefixes)
    prefix_dict[prefix_types[i]] = current_prefixes

prefix_dict



{'ecb': ["' An earthquake measuring 5.6 on the Richter scale jolted Qeshm island off Iran's southern coast on Sunday, followed by several aftershocks on Monday.  The tremor struck an area around the town of Dargahan on Qeshm island, at the entrance to the Persian Gulf, injuring five people and damaging buildings. ' In previous sentences, does 'an area around the town of Dargahan on Qeshm island , at the entrance to the Persian Gulf' refer to 'Qeshm island off Iran 's southern coast'? Yes or no? Yes\n'Five Wounded by Quake in Southern Iran  An earthquake measuring 5.6 on the Richter scale jolted Qeshm island off Iran's southern coast on Sunday, followed by several aftershocks on Monday. ' In previous sentences, does 'Richter scale' refer to '5.6'? Yes or no? No\n",
  "' An earthquake measuring 5.6 on the Richter scale jolted Qeshm island off Iran's southern coast on Sunday, followed by several aftershocks on Monday.  The tremor struck an area around the town of Dargahan on Qeshm island,

In [None]:
# Use the current CUDA device when one is available; otherwise pass -1, which
# makes the pipeline run on CPU instead of failing in torch.cuda.current_device().
generator = pipeline('text-generation', model=f'EleutherAI/gpt-neo-{NEO_SIZE}', return_full_text=False,
                     device=torch.cuda.current_device() if torch.cuda.is_available() else -1)
print(generator)



Downloading:   0%|          | 0.00/0.98k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/502M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/560 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/357 [00:00<?, ?B/s]

<transformers.pipelines.text_generation.TextGenerationPipeline object at 0x7fd2d210dc50>


# Experiment



In [None]:
# Preview one few-shot prefix. The template index is spelled out here rather
# than taken from a loop variable left over from an earlier cell; 2 selects the
# third template ("... stand for ...").
PREVIEW_TEMPLATE_IDX = 2
prefix_test = prefix_dict['ecb'][PREVIEW_TEMPLATE_IDX] if n_examples > 0 else ""
print(prefix_test)

if prefix_test == "":
  print("no prefix: zero-shot prompting")
else:
  print("examples using prefix")

' An earthquake measuring 5.6 on the Richter scale jolted Qeshm island off Iran's southern coast on Sunday, followed by several aftershocks on Monday.  The tremor struck an area around the town of Dargahan on Qeshm island, at the entrance to the Persian Gulf, injuring five people and damaging buildings. ' Here, does 'an area around the town of Dargahan on Qeshm island , at the entrance to the Persian Gulf' stand for 'Qeshm island off Iran 's southern coast'? Yes or no?  Yes
'Five Wounded by Quake in Southern Iran  An earthquake measuring 5.6 on the Richter scale jolted Qeshm island off Iran's southern coast on Sunday, followed by several aftershocks on Monday. ' Here, does 'Richter scale' stand for '5.6'? Yes or no?  No

examples using prefix


In [None]:
def make_prediction_helper(text, pair, template, prefix, generator, n):
    """Query the generator `n` times for one mention pair and tally the answers.

    Parameters:
        text (str): passage inserted into the template.
        pair: two mention dicts; only their "mention" strings are used.
        template (str): prompt template passed to `generate_prompt`.
        prefix (str): few-shot prefix; "" means the `prefix` argument is not
            passed to the generator at all.
        generator: callable returning a list whose first item has a
            "generated_text" entry.
        n (int): number of generations to request.

    Returns:
        (pred_sum, count): `count` is the number of generations that were
        exactly "yes" or "no" (after stripping and lower-casing) and `pred_sum`
        the number of "yes" among them. When `count` is 0, `pred_sum` is -1.
        Any other generation is printed and ignored.
    """
    mention_pair = [pair[0]["mention"], pair[1]["mention"]]

    # NOTE(review): max_length=1 is presumably meant as "one new token"; confirm
    # against the installed transformers version (max_new_tokens may be intended).
    generation_kwargs = {"max_length": 1, "num_return_sequences": 1}
    if prefix != "":
        generation_kwargs["prefix"] = prefix

    yes_votes = 0
    valid_votes = 0
    for _ in range(n):
        prompt = generate_prompt(template, text, mention_pair)
        answer = generator(prompt, **generation_kwargs)[0]["generated_text"]
        answer = answer.strip().lower()

        if answer not in ("yes", "no"):
            print(answer)
            continue
        valid_votes += 1
        if answer == "yes":
            yes_votes += 1

    if valid_votes == 0:
        yes_votes = -1

    return yes_votes, valid_votes

In [None]:
def make_prediction(samples, templates, prefixer, generator, n=5):
    """Run every template over every sample and collect the vote tallies.

    Parameters:
        samples: list of `[context_sents, text, pair, label]` entries.
        templates: prompt templates; `prefixer[k]` is the prefix for template k.
        prefixer: sequence of few-shot prefixes, ignored when the global
            `n_examples` is 0 (an empty prefix is used instead).
        generator: text-generation callable forwarded to the helper.
        n (int): generations per sample and template.

    Returns:
        list: one list per template holding a `(pred_sum, count)` tuple per sample.
    """
    templates_results = []
    for template_idx, template in enumerate(templates):
        prefix = prefixer[template_idx] if n_examples > 0 else ""

        results = []
        for _, text, pair, _ in tqdm(samples):
            pred_sum, count = make_prediction_helper(text, pair, template, prefix, generator, n)
            results.append((pred_sum, count))

        templates_results.append(results)
        # Wipe the progress bar output before moving on to the next template.
        clear_output()

    print(len(templates_results),len(templates_results[0]))
    return templates_results

In [None]:
def annotate(data, prefix_type, repeated_n, prefixer, generator, templates):
    """Predict coreference for every document in `data` and append the results to a CSV.

    Parameters:
        data (dict): document name -> (text, tokens, mentions, clusters).
        prefix_type (str): name of the prefix family; used in the file name.
        repeated_n (int): generations per sample and template.
        prefixer: per-template few-shot prefixes.
        generator: text-generation callable.
        templates: prompt templates; one "Prompt k" column is written per template.

    The output file lives under `{root_path}/Results/GPT-NEO/`. Rows are
    appended document by document; the header row is written only when the
    file does not exist yet (or is empty), so the CSV has a single header.
    """
    import os

    prompt_column_names = [f"Prompt {i+1}" for i in range(len(templates))]
    # The path does not depend on the document, so it is built once.
    output_file = f"{root_path}/Results/GPT-NEO/GPT_NEO-{NEO_SAVE_NAME}_gold_mentions_{n_examples}-shots_{prefix_type}_prefix_{repeated_n}-repeats.csv"
    os.makedirs(os.path.dirname(output_file), exist_ok=True)

    for doc_name in data:
        text, toks, mentions, clusters = data[doc_name]
        doc = docDataset(doc_name,text, toks, clusters, mentions)
        doc.create_mention_pairs()
        samples = doc.get_experiment_samples()
        res = make_prediction(samples, templates, prefixer, generator, repeated_n)

        # One row per sample, one column per template.
        result_df = pd.DataFrame(res).T
        result_df.columns = prompt_column_names
        result_df["doc_name"] = doc_name
        sample_df = pd.DataFrame(samples, columns = ["sent_idx","text","mention pair","label"])
        result_df = pd.concat([result_df, sample_df], axis = 1)

        # In append mode pandas would otherwise repeat the header for every document.
        write_header = (not os.path.exists(output_file)) or os.path.getsize(output_file) == 0
        result_df.to_csv(output_file, index = False, mode = "a", header = write_header)

    print(f"{prefix_type}_prefix_{repeated_n}-repeats Results saved")

In [None]:
def experiments(data, prefix_types, repeated_ns, generator, templates):
    """Run `annotate` for every combination of prefix family and repeat count.

    The prefixes of each family are looked up in the global `prefix_dict`.
    """
    for prefix_type in tqdm(prefix_types):
        for repeat_count in repeated_ns:
            family_prefixes = prefix_dict[prefix_type]
            annotate(data, prefix_type, repeat_count, family_prefixes, generator, templates)
    print("Experiments Completed.")

In [None]:
# prefix_types = ["simple","superglue","ecb"]
# all_prefix = [prefixes_simple, prefixes_super_glue, prefixes_ecb]
# prefixes_super_glue

In [38]:
experiments(dev, prefix_types, [5], generator, templates)

100%|██████████| 3/3 [7:48:38<00:00, 9372.95s/it]

5 38
ecb_prefix_5-repeats Results saved
Experiments Completed.



