In [None]:
import os
import sys
import torch
import faiss
import argparse
import numpy as np
import pandas as pd
import torch.nn.functional as F
import nlpaug.augmenter.word as naw
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
tqdm.pandas()

from util_modeling import get_model_objects
from util_data import get_formatted_dataset, get_num_labels
from adaptive_methods import get_paraphrase_augmentations

In [None]:
datasets = get_formatted_dataset("boss_sentiment")
train_set = datasets["train"].to_pandas().drop(columns=["__index_level_0__"])
test_set = datasets["validation"].to_pandas().drop(columns=["__index_level_0__"])
display(train_set.head())
display(test_set.head())

## Create Embeddings

In [28]:
model_name = "humarin/chatgpt_paraphraser_on_T5_base"
# model_name = "princeton-nlp/sup-simcse-roberta-large"
# model_name = "humarin/chatgpt_paraphraser_on_T5_base"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer, model = get_model_objects(model_name, num_labels=3)

In [30]:
def get_embedding(text):
    with torch.no_grad():
        tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        if model_name == "humarin/chatgpt_paraphraser_on_T5_base":
            return model(**tokens, decoder_input_ids=tokens["input_ids"], output_hidden_states=True)["encoder_last_hidden_state"].mean(dim=1).squeeze().detach().cpu().numpy()
        elif model_name == "princeton-nlp/sup-simcse-roberta-large":
            return model(**tokens)["pooler_output"].detach().cpu().numpy()
        else:
            return model(**tokens, output_hidden_states=True)["hidden_states"][-1].mean(dim=1).squeeze().detach().cpu().numpy()

train_set["embedding"] = train_set["text"].progress_apply(get_embedding)
train_set.head()

100%|██████████| 29999/29999 [10:57<00:00, 45.61it/s]


Unnamed: 0,text,label,embedding
0,One of my favorites,1,"[-0.2169, 0.0548, -0.3027, 0.00297, -0.04462, ..."
1,My favorite Coarse Sea Salt brand I know about...,1,"[0.0527, -0.0285, -0.0876, 0.002548, -0.03056,..."
2,"Love the top! It fits a little tight, so can b...",1,"[-0.12036, -0.08734, -0.04828, 0.0815, -0.0334..."
3,very nice & I like it for everything I used it...,1,"[-0.09247, -0.09595, -0.1056, -0.0383, -0.1348..."
4,Awesome product!,1,"[-0.1068, -0.1333, -0.02583, 0.04718, -0.1442,..."


In [31]:
torch.Tensor(np.stack(train_set["embedding"])).mean(dim=0)

tensor([-5.2756e-02, -6.0333e-02, -6.7755e-02,  5.4552e-02, -2.7676e-02,
         1.7895e-02, -7.9711e-02, -6.7177e-02, -5.8001e-02, -2.0899e-03,
        -1.5487e-02, -3.2947e-02, -7.6349e-03,  4.9093e-02,  2.3862e-02,
         6.8389e-03, -3.0690e-02,  3.9634e-02,  9.5420e-02,  3.2718e-02,
        -3.1661e-02, -1.8483e-02,  6.1046e-02,  1.2202e-03, -3.5695e-02,
        -9.8095e-03,  2.2426e-02,  7.1941e-03,  2.2082e-03, -2.2697e-02,
        -3.1966e-03,  1.0527e-02,  8.3322e-02, -7.2907e-03,  1.3515e-02,
        -6.8795e-03, -1.4656e-01, -4.2748e-03,  2.0008e-03,  6.8266e-02,
        -1.2039e-02,  6.1029e-02,  1.2464e-01, -2.6741e-02, -3.3780e-01,
        -8.9305e-03, -3.5918e-02,  7.9114e-03, -3.7333e-02,  4.5892e-03,
         1.2873e-02,  3.0469e-02, -4.8068e-02,  2.6156e-02, -7.9206e-02,
         7.2732e-03,  2.9686e-01, -7.6005e-03, -1.5971e-04,  8.4089e-03,
         6.4910e-02, -1.9394e-02,  3.9386e-03, -5.8784e-03,  9.5261e-03,
        -7.7078e-02, -3.6686e-02, -2.6133e-02,  9.1

In [32]:
train_set_embeddings = torch.Tensor(np.stack(train_set["embedding"])).squeeze(1)
display(train_set_embeddings.shape)
torch.save(train_set_embeddings, f"notebooks/dynasent_analysis/amazon_train_embeddings_{model_name.replace('/', '-')}.pt")

torch.Size([29999, 768])

In [None]:
test_set_centroid = test_set_embeddings.mean(dim=0)
display(test_set_centroid.shape)
torch.save(test_set_centroid, f"notebooks/dynasent_analysis/amazon_validation_centroid_{model_name.replace('/', '-')}.pt")

In [None]:
labels = train_set["label"].unique()
vector_stores = {}
centroids = {}
centroid_examples = {}
k = 10
d = 1024

for label in labels:
    label_instances = train_set[train_set["label"] == label]
    label_embeddings = np.stack(label_instances["embedding"].to_numpy()).astype(np.float32)
    
    faiss.normalize_L2(label_embeddings)
    vector_stores[label] = faiss.IndexFlatIP(d)
    vector_stores[label].add(label_embeddings)
    centroids[label] = label_embeddings.mean(axis=0)
    
    cosine_sims, centroid_example_indices = vector_stores[label].search(centroids[label].reshape(1, -1), k)
    centroid_examples[label] = []
    for index in centroid_example_indices[0]:
        centroid_examples[label].append(label_instances.iloc[index]["text"])

centroid_examples

In [None]:
new_train_set_records = []
for inde, row in tqdm(train_set.iterrows(), total=len(train_set)):
    current_label = row["label"]
    current_text = row["text"]
    for example in centroid_examples[current_label]:
        new_train_set_records.append({"text": current_text, "label": example, "class": current_label})

rewrite_train_set = pd.DataFrame(new_train_set_records).sample(frac=1).reset_index(drop=True)
rewrite_train_set.to_csv("datasets/corrupted/boss_sentiment_train.csv", index=False)
display(rewrite_train_set)

new_test_set_records = []
for inde, row in tqdm(test_set.iterrows(), total=len(test_set)):
    current_label = row["label"]
    current_text = row["text"]
    for example in centroid_examples[current_label]:
        new_test_set_records.append({"text": current_text, "label": example, "class": current_label})

rewrite_test_set = pd.DataFrame(new_test_set_records).sample(frac=1).reset_index(drop=True)
rewrite_test_set.to_csv("datasets/corrupted/boss_sentiment_test.csv", index=False)
display(rewrite_test_set)

## Dataset with Augmentations

In [None]:
paraphrase_tokenizer, paraphrase_model = get_model_objects("humarin/chatgpt_paraphraser_on_T5_base", num_labels=-1)

In [None]:
example_text = "I use this every day I would recommend this for anyone who has special needs with thinning hair, it has made a huge difference in my daily life."
get_paraphrase_augmentations(example_text,
                             paraphrase_tokenizer,
                             paraphrase_model,
                             paraphrase_model.device,
                             num_return_sequences=4,
                             temperature=0.3)

In [None]:
new_train_set_records = []
for _, row in tqdm(train_set.iterrows(), total=len(train_set)):
    current_label = row["label"]
    current_text = row["text"]
    augmentations = get_paraphrase_augmentations(current_text,
                             paraphrase_tokenizer,
                             paraphrase_model,
                             paraphrase_model.device,
                             num_return_sequences=4,
                             temperature=0.3)
    
    for example in centroid_examples[current_label]:
        for text_input in [current_text] + augmentations:
            new_train_set_records.append({"text": text_input, "label": example, "class": current_label})

rewrite_train_set = pd.DataFrame(new_train_set_records).sample(frac=1).reset_index(drop=True)
rewrite_train_set.to_csv("datasets/corruped/boss_sentiment_augmented_train.csv", index=False)
display(rewrite_train_set)

In [None]:
new_test_set_records = []
for _, row in tqdm(test_set.iterrows(), total=len(test_set)):
    current_label = row["label"]
    current_text = row["text"]
    augmentations = get_paraphrase_augmentations(current_text,
                             paraphrase_tokenizer,
                             paraphrase_model,
                             paraphrase_model.device,
                             num_return_sequences=4,
                             temperature=0.3)
    
    for example in centroid_examples[current_label]:
        for text_input in [current_text] + augmentations:
            new_test_set_records.append({"text": text_input, "label": example, "class": current_label})

rewrite_test_set = pd.DataFrame(new_test_set_records).sample(frac=1).reset_index(drop=True)
rewrite_test_set.to_csv("datasets/corruped/boss_sentiment_augmented_test.csv", index=False)
display(rewrite_test_set)