In [1]:
%%capture
!pip install promptsource
!pip install datasets
!pip install transformers

In [2]:
from promptsource.templates import DatasetTemplates
from typing import List
 
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import math
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2TokenizerFast, GPT2Model, get_linear_schedule_with_warmup, pipeline, set_seed,GPT2LMHeadModel
from IPython.display import clear_output


In [3]:
# Mount Google Drive into the Colab filesystem at /content/drive.
# force_remount=True is passed so the mount is redone even if a mount
# already exists in this session.
from google.colab import drive
drive.mount('/content/drive',force_remount=True)

# Make sure to click "Add shortcut to drive" for the "Coref-for-GPT" folder
# Root folder on the mounted Drive for this project.
gdrive_dir_path = "/content/drive/MyDrive/Coref-for-GPT"


Mounted at /content/drive


### 1. Generate Prompts

In [4]:
def generate_prompt(template, text, mention_pair, text_token = "[TEXT]", mention1_token = "[MENTION1]", mention2_token = "[MENTION2]"):
    """Fill a prompt template with a passage and a pair of mentions.

    Every occurrence of ``text_token`` in ``template`` is replaced by ``text``,
    ``mention1_token`` by ``mention_pair[0]`` and ``mention2_token`` by
    ``mention_pair[1]``.

    All placeholders are substituted in a single pass over the template, so
    inserted content is never re-scanned: a passage (or mention) that happens
    to contain the literal string "[MENTION1]" is left exactly as given.

    Parameters:
            template (str): prompt template containing the placeholder tokens
            text (str): passage inserted in place of ``text_token``
            mention_pair (sequence): two strings, (first mention, second mention)
            text_token (str): placeholder for the passage
            mention1_token (str): placeholder for the first mention
            mention2_token (str): placeholder for the second mention

    Returns:
            prompt (str): the template with all placeholders filled in
    """
    import re  # local import so the function does not depend on another cell

    substitutions = {}
    for token, value in ((text_token, text),
                         (mention1_token, mention_pair[0]),
                         (mention2_token, mention_pair[1])):
        # If two placeholders are spelled identically, the earlier one wins
        # (text, then mention 1, then mention 2).
        substitutions.setdefault(token, value)

    # Alternation order follows insertion order: text, mention 1, mention 2.
    pattern = re.compile("|".join(re.escape(token) for token in substitutions))
    return pattern.sub(lambda match: substitutions[match.group(0)], template)


In [5]:
def create_prefix(examples, template, answer_choices):
    """Build a few-shot prefix: one solved prompt per example, one per line.

    Parameters:
            examples (iterable): items of the form (text, mention_pair, label)
            template (str): prompt template handed to ``generate_prompt``
            answer_choices (sequence): answer strings, indexed by ``int(label)``

    Returns:
            prefix (str): for each example, the filled-in prompt, a space, the
                          answer string and a newline; "" when there are no examples
    """
    solved_lines = [
        generate_prompt(template, passage, mentions) + " " + answer_choices[int(gold)] + "\n"
        for (passage, mentions, gold) in examples
    ]
    return "".join(solved_lines)

In [6]:
# Define prompt templates
# Answer strings indexed by the integer label: 0 -> "No", 1 -> "Yes".
answer_choices = ["No", "Yes"]

# We will experiment with each of these prompts separately.
# Placeholders: [TEXT] is the passage, [MENTION1] / [MENTION2] are the two mentions.
# Templates must not end in whitespace: the answer is appended after a single
# space, so a trailing space here would put a double space before the answer
# and leave the final prompt ending in whitespace.
templates = ["'[TEXT]' In previous sentences, does '[MENTION2]' refer to '[MENTION1]'? Yes or no?",
             "'[TEXT]' Here, by '[MENTION2]' they mean '[MENTION1]'? Yes or no?",
             "'[TEXT]' Here, does '[MENTION2]' stand for '[MENTION1]'? Yes or no?",
             "'[TEXT]' In the passage above, can '[MENTION2]' be replaced by '[MENTION1]'? Yes or no?",
             "'[TEXT]' I think '[MENTION2]' means '[MENTION1]'. Yes or no?"]


### 2. Create few-shot prefixes from SuperGLUE (WSC)

In [7]:
# Load the 'wsc.fixed' configuration of the SuperGLUE benchmark.
from datasets import load_dataset
super_glue = load_dataset("super_glue", 'wsc.fixed')
# The train split supplies the few-shot examples for the prefix.
super_glue_train = super_glue["train"]
# NOTE: despite the name, this is the *validation* split, used here as the
# evaluation set (presumably because the official test split has no public
# labels -- TODO confirm).
super_glue_test = super_glue["validation"]

Downloading builder script:   0%|          | 0.00/9.47k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/8.23k [00:00<?, ?B/s]

Downloading and preparing dataset super_glue/wsc.fixed (download: 31.98 KiB, generated: 139.73 KiB, post-processed: Unknown size, total: 171.72 KiB) to /root/.cache/huggingface/datasets/super_glue/wsc.fixed/1.0.2/d040c658e2ddef6934fdd97deb45c777b6ff50c524781ea434e7219b56a428a7...


Downloading data:   0%|          | 0.00/32.8k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/554 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/104 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/146 [00:00<?, ? examples/s]

Dataset super_glue downloaded and prepared to /root/.cache/huggingface/datasets/super_glue/wsc.fixed/1.0.2/d040c658e2ddef6934fdd97deb45c777b6ff50c524781ea434e7219b56a428a7. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
def parse_superglue(data):
    """Split a WSC-style dataset into texts, mention pairs and labels.

    Parameters:
            data: mapping-like object indexable by the column names
                  "text", "span1_text", "span2_text" and "label"

    Returns:
            (texts, pairs, labels): the "text" column, a list of
            (span1_text, span2_text) tuples, and the "label" column
    """
    sentences = data["text"]
    first_spans, second_spans = data["span1_text"], data["span2_text"]
    mention_pairs = [*zip(first_spans, second_spans)]
    return sentences, mention_pairs, data["label"]

In [9]:
def get_examples_superglue(n, texts, pairs, labels,):
    '''Return a class-balanced set of examples for the few-shot prefix.

    The data is scanned in order and the first ``n // 2`` examples with
    label 1 and the first ``n // 2`` examples with label 0 are kept, in the
    order in which they appear.

    If the data is exhausted before both quotas are filled, the examples
    found so far are returned and a warning is emitted (previously this
    raised an IndexError from running past the end of the lists).

            Parameters:
                    n (int): total number of examples, expected to be an even number
                             (for odd n, n - 1 examples are returned)
                    texts (list): list of text
                    pairs (list): list of mention pairs
                    labels (list): list of labels (1 = positive, 0 = negative)

            Returns:
                    examples (list): list of [text, mention_pair, label] items
    '''
    import warnings  # local import so the function does not depend on another cell

    # we want a balanced example set
    n_positives = n_negatives = n // 2
    examples = []
    for text, mention_pair, label in zip(texts, pairs, labels):
        # Stop scanning as soon as both quotas are filled.
        if n_positives <= 0 and n_negatives <= 0:
            break
        if (label == 1) and (n_positives > 0):
            examples.append([text, mention_pair, label])
            n_positives -= 1
        elif (label == 0) and (n_negatives > 0):
            examples.append([text, mention_pair, label])
            n_negatives -= 1

    if n_positives > 0 or n_negatives > 0:
        warnings.warn(
            f"Not enough data for a balanced set of {n} examples: "
            f"missing {max(n_positives, 0)} positive and {max(n_negatives, 0)} negative example(s)."
        )
    return examples

In [10]:
# Total number of few-shot examples placed in each prefix (split evenly
# between positive and negative examples by get_examples_superglue).
n_examples = 4
super_glue_texts, super_glue_pairs, super_glue_labels = parse_superglue(super_glue_train)

# Pick the few-shot examples from the training split.
superglue_examples = get_examples_superglue(n_examples, super_glue_texts, 
                                            super_glue_pairs, super_glue_labels)

# One prefix per template; prefixes_super_glue[i] corresponds to templates[i].
prefixes_super_glue = []

for template in templates:
    prefix = create_prefix(superglue_examples, template, answer_choices)
    prefixes_super_glue.append(prefix)
# Sanity check: number of prefixes and a preview of the first one.
print("Num of different prefixes: ", len(prefixes_super_glue))
print("Prefix based on template 1: ", prefixes_super_glue[0])

Num of different prefixes:  5
Prefix based on template 1:  'Mark told Pete many lies about himself, which Pete included in his book. He should have been more skeptical.' In previous sentences, does 'He' refer to 'Mark'? Yes or no? No
'The mothers of Arthur and Celeste have come to the town to fetch them. They are very happy to have them back, but they scold them just the same because they ran away.' In previous sentences, does 'them' refer to 'mothers'? Yes or no? No
'The pony behaved well, sir, and showed no vice; but at last he just threw up his heels and tipped the young gentleman into the thorn hedge. He wanted me to help him out, but I hope you will excuse me, sir, I did not feel inclined to do so.' In previous sentences, does 'He' refer to 'young gentleman'? Yes or no? Yes
'I poured water from the bottle into the cup until it was full.' In previous sentences, does 'it' refer to 'the cup'? Yes or no? Yes



In [11]:
# Build one GPT-2 text-generation pipeline per few-shot prefix;
# generators[i] corresponds to prefixes_super_glue[i].
generators = []
# Use the current CUDA device when one is available; -1 selects the CPU.
# (Calling torch.cuda.current_device() unconditionally raises on CPU-only runtimes.)
pipeline_device = torch.cuda.current_device() if torch.cuda.is_available() else -1
for prefix in prefixes_super_glue:
    # return_full_text=False: only the newly generated text is returned,
    # not the prompt itself.
    generator = pipeline('text-generation', model='gpt2', return_full_text=False,
                         prefix=prefix, device=pipeline_device)
    generators.append(generator)

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/523M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

In [12]:
def make_prediction_helper(text, pair, template, generator, n):
    """Query the generator ``n`` times with one prompt and tally yes/no answers.

    Parameters:
            text (str): passage inserted into the template
            pair (sequence): the two mentions
            template (str): prompt template
            generator (callable): called as ``generator(prompt, max_length=1,
                      num_return_sequences=1)``; expected to return a list whose
                      first item has a "generated_text" entry
            n (int): number of times the generator is queried

    Returns:
            (yes_votes, valid_votes): number of "yes" answers and number of
            answers that were "yes" or "no" (after stripping and lower-casing).
            When no answer was valid the result is (-1, 0).
    """
    yes_votes = 0
    valid_votes = 0
    for _ in range(n):
        prompt = generate_prompt(template, text, pair)
        outputs = generator(prompt, max_length=1, num_return_sequences=1,)
        answer = outputs[0]["generated_text"].strip().lower()
        if answer not in ("yes", "no"):
            # Unusable generation: show it and do not count it as a vote.
            print(answer)
            continue
        valid_votes += 1
        if answer == "yes":
            yes_votes += 1

    # -1 marks "no usable answer at all" for this prompt.
    if valid_votes == 0:
        return -1, valid_votes
    return yes_votes, valid_votes

In [13]:
def make_prediction(data, templates, generators, n=5):
    """Collect yes/no tallies for one example under every prompt template.

    Parameters:
            data (tuple): (text, mention_pair) for a single example
            templates (list): prompt templates
            generators (list): generators[i] is used with templates[i]
            n (int): number of generations requested per template

    Returns:
            templates_results (list): one (pred_sum, count) tuple per template,
            as returned by ``make_prediction_helper``; [] when ``templates``
            is empty
    """
    # The example is the same for every template, so unpack it once.
    text, pair = data
    templates_results = []
    for i, template in enumerate(tqdm(templates)):
        generator = generators[i]
        pred_sum, count = make_prediction_helper(text, pair, template, generator, n)
        templates_results.append((pred_sum, count))
        # Clears the cell output after every template so it does not pile up.
        clear_output()
    # Debug shape print; guarded so that an empty template list no longer
    # raises IndexError on templates_results[0].
    inner_len = len(templates_results[0]) if templates_results else 0
    print(len(templates_results), inner_len)
    return templates_results

In [14]:
def annotate_super_glue(data, repeated_n, generators, templates, root_path=gdrive_dir_path, n_shots=None):
    """Run every prompt template on every example and append the results to a CSV.

    Parameters:
            data: dataset split accepted by ``parse_superglue``
            repeated_n (int): number of generations requested per (example, template)
            generators (list): generators[i] is used with templates[i]
            templates (list): prompt templates
            root_path (str): results are written below ``{root_path}/Results/WSC``
            n_shots (int, optional): number of few-shot examples, used only in
                    the output file name; defaults to the notebook-level
                    ``n_examples`` when omitted

    Returns:
            None. Side effect: appends one row per example to
            ``WSC_{n_shots}-shots_{repeated_n}-repeats.csv`` with one
            "Prompt i" column per template (each cell a (pred_sum, count)
            tuple) plus "label", "text" and "mention pair" columns.
    """
    import os  # local import so the function does not depend on another cell

    if n_shots is None:
        # Backward compatible: fall back to the notebook-level setting.
        n_shots = n_examples

    prompt_column_names = [f"Prompt {i+1}" for i in range(len(templates))]
    texts, pairs, labels = parse_superglue(data)
    results = []
    # total= lets tqdm show real progress instead of a bare iteration counter.
    for text, mention_pair in tqdm(zip(texts, pairs), total=len(texts)):
        res = make_prediction((text, mention_pair), templates, generators, repeated_n)
        results.append(res)
    # Passing the column names up front also works when there are no rows.
    result_df = pd.DataFrame(results, columns=prompt_column_names)
    print(result_df.shape)
    result_df["label"] = labels
    result_df["text"] = texts
    result_df["mention pair"] = pairs

    output_file = f"{root_path}/Results/WSC/WSC_{n_shots}-shots_{repeated_n}-repeats.csv"
    os.makedirs(os.path.dirname(output_file), exist_ok=True)
    # The file is opened in append mode: write the header only when the file
    # is new or empty, so re-runs do not add header lines between data rows.
    write_header = (not os.path.exists(output_file)) or os.path.getsize(output_file) == 0
    result_df.to_csv(output_file, index = False, mode = "a", header = write_header)


    print(f"GPT2_WSC_prefix_{repeated_n}-repeats Results saved")

In [15]:
# Annotate the evaluation split with 5 generations per (example, template)
# and save the results under the project folder on Drive.
annotate_super_glue(super_glue_test, 5, generators, templates, root_path=gdrive_dir_path)


100%|██████████| 5/5 [00:00<00:00,  6.30it/s]
104it [01:27,  1.19it/s]


5 2
(104, 5)
GPT2_WSC_prefix_5-repeats Results saved
