## Mistake Edits Experiment

Iterate through the ToxiGen training set and make an edit whenever a mistake is made. 

Whenever an edit is made, calculate the accuracy on all previous edits and on a holdout test set. 

In [57]:
import os
import time
import torch
import openai
import numpy as np
import pandas as pd
import plotly.express as px
import torch.nn.functional as F
from tqdm import tqdm
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from sklearn.metrics import classification_report, accuracy_score
tqdm.pandas()

### Load Evaluation Dataset

In [58]:
hf_access_token = os.environ["HF_TOKEN"]
tweets_full_set = load_dataset("venetis/disaster_tweets")["train"].to_pandas()
display(tweets_full_set)
display(px.pie(tweets_full_set, names="target", title="Tweets Full Set Class Distribution").show())

tweets_downsample = pd.concat([
    tweets_full_set[tweets_full_set["target"] == 1].sample(500),
    tweets_full_set[tweets_full_set["target"] == 0].sample(500)
]).sample(frac=1).reset_index(drop=True)

Found cached dataset csv (/root/.cache/huggingface/datasets/venetis___csv/venetis--disaster_tweets-c719bf92177b1bd6/0.0.0/6954658bab30a358235fa864b05cf819af0e179325c740e4bc853bcc7ec513e1)
100%|██████████| 1/1 [00:00<00:00, 131.78it/s]


Unnamed: 0,id,keyword,location,text,target
0,1,,,Our Deeds are the Reason of this #earthquake M...,1
1,4,,,Forest fire near La Ronge Sask. Canada,1
2,5,,,All residents asked to 'shelter in place' are ...,1
3,6,,,"13,000 people receive #wildfires evacuation or...",1
4,7,,,Just got sent this photo from Ruby #Alaska as ...,1
...,...,...,...,...,...
7608,10869,,,Two giant cranes holding a bridge collapse int...,1
7609,10870,,,@aria_ahrary @TheTawniest The out of control w...,1
7610,10871,,,M1.94 [01:04 UTC]?5km S of Volcano Hawaii. htt...,1
7611,10872,,,Police investigating after an e-bike collided ...,1


None

### Baseline GPT-3 Performance Without Edits

In [59]:
def get_gpt_inference(example_prompt, example_label, current_sequence):
    time.sleep(2)
    openai_api_key = os.environ["OPENAI_API_KEY"]
    # task_instructions = "The input is a sequence which may or may not be hate speech. Please classify the following sequence as either hate speech (1) or not hate speech (0)."
    task_instructions = "The input is a sequence that may or may not refer to a disaster. Please classify the following sequence as either refering to a disaster (1) or not refering to a disaster (0)."
    prompt = f"""
Instructions: {task_instructions}

Sequence: {example_prompt}
Label: {example_label}

Sequence: {current_sequence}
Label:
"""
    response_object = openai.Completion.create(model="text-davinci-003", prompt=prompt, max_tokens=10)
    response_token = response_object["choices"][0]["text"]
    first_token = response_token.strip()[0]

    try:
        int_token = int(first_token)
        assert int_token in [0, 1]
        return int_token
    except:
        print(f"Error: {response_token} - unable to convert to int")
        return -1

def evaluate_gpt3_baseline(dataset, example_prompt, example_label):
    judgments = np.empty(len(dataset), int)
    labels = dataset["target"].to_numpy()

    for i in tqdm(range(len(tweets_downsample))):
        time.sleep(2)
        row = tweets_downsample.iloc[i]
        current_sequence = row["prompt"]
        label = row["target"]
        inference = get_gpt_inference(example_prompt, example_label, current_sequence)
        judgments[i] = inference
    
    return judgments, labels

default_example_prompt = "the middle east is all arab and they're all the same\n- middle eastern students are terrorists or terrorists wanna-bes\n- middle eastern students"
default_example_prompt = default_example_prompt.replace("\n", " ")
default_example_label = 1

tweets_downsample = tweets_downsample[:100]
# print(classification_report(labels, baseline_judgments))
# baseline_judgments, labels = evaluate_gpt3_baseline(tweets_downsample, default_example_prompt, default_example_label)

### Label Flipping Experiment

Does switching the label in the prompt reliably cause the model to output the correct token?

1. Sample sequences from ToxiGen and take the first newline.
2. Save embeddings for the first 50 sequences. These will be our "edits" which we will put in the prompt.
3. For each other sequence, get the most similiar edit. 
4. Check if the model gives the correct label for the input sequence using the true edit label.
5. Flip the edit label in the prompt and check whether the model output flips as well. 

In [60]:
def mean_pooling(model_output, attention_mask):
    token_embeddings = model_output[0]
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)


def get_embedding(tokenizer, model, prompt):
    with torch.no_grad():
        encoded_input = sentence_tokenizer(prompt, return_tensors='pt')
        model_output = sentence_model(**encoded_input)
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings


hf_model_path = "sentence-transformers/all-mpnet-base-v2"
sentence_tokenizer = AutoTokenizer.from_pretrained(hf_model_path)
sentence_model = AutoModel.from_pretrained(hf_model_path)

In [61]:
flipping_edits = pd.concat([
    tweets_full_set[tweets_full_set["target"] == 1].sample(10),
    tweets_full_set[tweets_full_set["target"] == 0].sample(10)
]).sample(frac=1).reset_index(drop=True)

flipping_downsample = pd.concat([
    tweets_full_set[tweets_full_set["target"] == 1].sample(100),
    tweets_full_set[tweets_full_set["target"] == 0].sample(100)
]).sample(frac=1).reset_index(drop=True)

flipping_edits["text"] = flipping_edits["text"].progress_apply(lambda x: x.split("\\n")[0])
flipping_downsample["text"] = flipping_downsample["text"].progress_apply(lambda x: x.split("\\n")[0])
embeddings = torch.zeros((len(flipping_edits), 768))
edit_labels = torch.empty(len(flipping_edits))

for i in tqdm(range(len(flipping_edits))):
    edit_row = flipping_edits.iloc[i]
    label = edit_row["target"]
    prompt = edit_row["text"]
    prompt_embedding = get_embedding(sentence_tokenizer, sentence_model, prompt)
    embeddings[i] = prompt_embedding
    edit_labels[i] = label

display(embeddings)
display(edit_labels)

100%|██████████| 20/20 [00:00<00:00, 29852.70it/s]
100%|██████████| 200/200 [00:00<00:00, 209244.40it/s]
100%|██████████| 20/20 [00:02<00:00,  7.95it/s]


tensor([[-0.0420,  0.0329, -0.0401,  ...,  0.0223, -0.0108,  0.0338],
        [ 0.0100, -0.0768, -0.0247,  ..., -0.0285, -0.0273, -0.0430],
        [-0.0524, -0.0003, -0.0076,  ...,  0.0202, -0.0312,  0.0001],
        ...,
        [ 0.0024, -0.0883, -0.0127,  ...,  0.0467,  0.0269, -0.0267],
        [-0.0299, -0.0232, -0.0036,  ..., -0.0510, -0.0211, -0.0219],
        [-0.0785,  0.0392, -0.0140,  ..., -0.0264, -0.0281,  0.0135]])

tensor([1., 0., 1., 1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0., 0., 1., 0.,
        0., 1.])

In [62]:
def chose_edit_prompt(prompt_embedding, edit_embeddings, metric):
    closest_index = 0
    distance = torch.nn.functional.cosine_similarity(edit_embeddings[closest_index], prompt_embedding) if metric == "cosine" else torch.dist(edit_embeddings[closest_index], prompt_embedding, 2)
    
    for i in range(len(edit_embeddings)):
        if metric == "cosine":
            if torch.nn.functional.cosine_similarity(edit_embeddings[i], prompt_embedding) > torch.nn.functional.cosine_similarity(edit_embeddings[closest_index], prompt_embedding):
                # print(f"New embedding at index {i} is closer with cosine {torch.nn.functional.cosine_similarity(edit_embeddings[i], prompt_embedding)}")
                closest_index = i
                distance = torch.nn.functional.cosine_similarity(edit_embeddings[i], prompt_embedding)
        elif metric == "euclidean":
            if torch.dist(edit_embeddings[i], prompt_embedding, 2) < torch.dist(edit_embeddings[closest_index], prompt_embedding, 2):
                # print(f"New embedding at index {i} is closer with distance {torch.dist(edit_embeddings[i], prompt_embedding, 2)}")
                closest_index = i
                distance = torch.dist(edit_embeddings[i], prompt_embedding, 2)
        else:
            raise ValueError(f"Invalid metric {metric}")
    
    return closest_index, distance.item()


true_labels = flipping_downsample["target"].to_numpy()
lm_judgments = np.empty(len(flipping_downsample), int)

# The count of times that having the exact input in the prompt lead to the edit label being returned. 
count_successful_direct_edit = 0 

# Count of times that the edit label was returned and the label in the prompt was not flipped.
count_correct_original_label = 0 

# The count of times that the edit label was returned the label in the prompt was flipped.
count_correct_label_flips = 0 

logs = []
distance_metric = "cosine"

for i in tqdm(range(len(flipping_downsample))):
    # time.sleep(1)
    edit_log = {}

    # Get the embedding for the current sequence
    row = flipping_downsample.iloc[i]
    original_label = row["target"]
    true_labels[i] = original_label
    current_sequence = row["text"]
    prompt_embedding = get_embedding(sentence_tokenizer, sentence_model, current_sequence)
    
    # Calculate the closest edit
    edit_example_index, distance = chose_edit_prompt(prompt_embedding, embeddings, distance_metric)
    edit_prompt = flipping_edits.iloc[edit_example_index]["text"]
    edit_label = flipping_edits.iloc[edit_example_index]["target"]
    
    edit_log["text"] = current_sequence
    edit_log["original_prompt_label"] = original_label
    edit_log["edit_prompt"] = edit_prompt
    edit_log["edit_label"] = edit_label
    edit_log["edit_distance"] = distance

    # Get the GPT-3 inference for the original edit label
    original_judgment = get_gpt_inference(edit_prompt, edit_label, current_sequence)
    edit_likely_changed_output = original_judgment == edit_label and original_label != edit_label
    edit_log["succesful_original_judgment"] = edit_likely_changed_output
    if edit_likely_changed_output:
        count_correct_original_label += 1

    # Get the GPT-3 inference for the original edit label
    flipped_edit_label = 0 if edit_label == 1 else 1
    flipped_judgment = get_gpt_inference(edit_prompt, flipped_edit_label, current_sequence)
    succesful_flipped_judgment = flipped_judgment == flipped_edit_label and edit_likely_changed_output
    edit_log["succesful_flipped_judgment"] = succesful_flipped_judgment
    if succesful_flipped_judgment:
        count_correct_label_flips += 1

    # Pass in the current sequence as the prompt example with the flipped label. See if doing so cuses the model to correctly output the edit label.
    flipped_original_label = 0 if original_label == 1 else 1
    flipped_judgment = get_gpt_inference(current_sequence, flipped_original_label, current_sequence)
    edit_log["succesful_edit"] = flipped_judgment == flipped_original_label
    if flipped_judgment == flipped_original_label:
        count_successful_direct_edit += 1
    
    logs.append(edit_log)

print(f"Count of succesful direct edits - {count_successful_direct_edit}")
print(f"Count where model agreed with edit label where the true label differs - {count_correct_original_label}")
print(f"Count where flipping the edit label caused the mode output to flip - {count_correct_label_flips}")
display(logs)

100%|██████████| 200/200 [24:22<00:00,  7.31s/it]

Count of succesful direct edits - 75
Count where model agreed with edit label where the true label differs - 16
Count where flipping the edit label caused the mode output to flip - 0





[{'text': '#Flood in Bago Myanmar #We arrived Bago',
  'original_prompt_label': 1,
  'edit_prompt': '#RoddyPiperAutos Fears over missing migrants in Med: Rescuers search for survivors after a boat carrying as ma...  http://t.co/97B8AVgEWU',
  'edit_label': 1,
  'edit_distance': 0.4316819906234741,
  'succesful_original_judgment': False,
  'succesful_flipped_judgment': False,
  'succesful_edit': False},
 {'text': "@xDescry I was wrong to call it trusty actually.. considering it spontaneously collapsed on me that's not very trusty.",
  'original_prompt_label': 1,
  'edit_prompt': "it's don't panic",
  'edit_label': 0,
  'edit_distance': 0.20025381445884705,
  'succesful_original_judgment': False,
  'succesful_flipped_judgment': False,
  'succesful_edit': True},
 {'text': 'School Bus Hijacker Given Parole After 39 Years http://t.co/HmRt98OydJ',
  'original_prompt_label': 1,
  'edit_prompt': 'Cross-border terrorism: Pakistan caught red-handed again http://t.co/uDj50J3MV4',
  'edit_label': 

In [63]:
print(f"Count of succesful direct edits - {count_successful_direct_edit}")
print(f"Count where model agreed with edit label where the true label differs - {count_correct_original_label}")
print(f"Count where flipping the edit label caused the mode output to flip - {count_correct_label_flips}")

Count of succesful direct edits - 75
Count where model agreed with edit label where the true label differs - 16
Count where flipping the edit label caused the mode output to flip - 0


In [64]:
edit_logs = pd.DataFrame(logs)
print("\n================================ Successful Direct Edits ==================================================\n")
print(edit_logs[edit_logs["succesful_edit"] == True]["text"].values)
print(edit_logs[edit_logs["succesful_edit"] == True]["original_prompt_label"].value_counts())
print("\n================================ Unsuccessful Direct Edits ==================================================\n")
print(edit_logs[edit_logs["succesful_edit"] == False]["text"].values)
print(edit_logs[edit_logs["succesful_edit"] == False]["original_prompt_label"].value_counts())



["@xDescry I was wrong to call it trusty actually.. considering it spontaneously collapsed on me that's not very trusty."
 'School Bus Hijacker Given Parole After 39 Years http://t.co/HmRt98OydJ'
 'San Bernardino I10 W Eo / Redlands Blvd **Trfc Collision-No Inj** http://t.co/FT9KIGmIgh'
 'just got engulfed in a car-induced tidal wave on my run... I thought this only happened in the movies ????'
 '#OilandGas Exploration Takes Seismic Shift in #Gabon to #Somalia http://t.co/oHHolJ9vEV via @business'
 'Motorcyclist bicyclist injured in Denver collision on Broadway http://t.co/UpPwxDA4yd'
 'Iranian warship points weapon at American helicopter... http://t.co/cgFZk8Ha1R'
 "kesabaran membuahkan hasil indah pada saat tepat! life isn't about waiting for the storm to pass it's about learning to dance in the rain."
 '@booksbyRoger TY for the follow Go To http://t.co/l9MB2j5pXg BRUTALLY ABUSED+DESOLATE&amp;LOST + HER LOVELY MUM DIES..Is it Murder?'
 '@xeni my bet is mother nature might have plan

### Edit Performance

In [65]:
def evaluate_edit_approach(dataset, encoder_model, encoder_tokenizer):
    edit_prompts = []
    edit_embeddings = []
    edit_step_accuracies = []
    judgments = np.empty(len(dataset), int)
    labels = dataset["prompt_label"].to_numpy()

    for i in tqdm(range(len(dataset))):
        time.sleep(1)
        row = tweets_downsample.iloc[i]
        current_sequence = row["prompt"]
        label = row["prompt_label"]
        inference = get_gpt_inference(example_prompt, example_label, current_sequence)
        judgments[i] = inference
    
    return judgments, labels
    