In [9]:
# Load the English training split (a tab-separated file) and preview its first rows
import pandas as pd

data = pd.read_table(
    "path/to/amazon-multilingual-counterfactual-dataset/data/EN_train.tsv",
    sep="\t",
)
data.head()

Unnamed: 0,sentence,is_counterfactual
0,In person it looks as though it would have cos...,1
1,Product received as described in the shipping ...,0
2,The handkerchiefs were just what I wanted - so...,0
3,I adore Nessa and give her credit because the...,0
4,He was very pleased when he saw how neat the c...,0


In [10]:
# Keep only the rows whose label marks them as counterfactual (is_counterfactual == 1)
print(data.loc[data["is_counterfactual"].eq(1)])

                                               sentence  is_counterfactual
0     In person it looks as though it would have cos...                  1
11              2) I wish the pillow was much fluffier.                  1
13    I thought it should have used more of esp or v...                  1
15    IF you're buying these CD sleeves PRIMARILY FO...                  1
16    They are much bigger than I thought they would...                  1
...                                                 ...                ...
4004  The padding is a little bulky, but if you want...                  1
4005  With I could have found the model without the ...                  1
4006  I wish they had one additional option of a mut...                  1
4008  Wish i would have returned it while I still co...                  1
4017  I must have looked at the reviews and the plot...                  1

[765 rows x 2 columns]


In [None]:
# Tokenize every training sentence with the BGE tokenizer.
# The model accepts at most 512 tokens, so longer inputs are truncated; padding
# makes the result rectangular so it can be returned as PyTorch tensors.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-base-en-v1.5")
inputs = tokenizer(
    data['sentence'].tolist(),
    truncation=True,
    padding=True,
    max_length=512,
    return_tensors="pt",
)

In [18]:
# Load the pretrained BGE encoder from Hugging Face.
# The checkpoint name must match the one used for `tokenizer` so that token ids
# line up with the model's vocabulary.
from transformers import AutoModel
model = AutoModel.from_pretrained("BAAI/bge-base-en-v1.5")

In [19]:
import torch
from torch.utils.data import DataLoader, TensorDataset

# Bundle the tokenized tensors so the encoder can consume them in fixed-size batches.
# shuffle=False preserves row order, so embedding i belongs to row i of `data`.
ds = TensorDataset(inputs['input_ids'], inputs['token_type_ids'], inputs['attention_mask'])
data_loader = DataLoader(ds, shuffle=False, batch_size=128)

# Inference only: switch to eval mode and skip gradient tracking.
model.eval()
output_embeddings = []

with torch.no_grad():
    for batch_input_ids, batch_token_type_ids, batch_attention_mask in data_loader:
        outputs = model(
            input_ids=batch_input_ids,
            token_type_ids=batch_token_type_ids,
            attention_mask=batch_attention_mask,
        )
        last_hidden_state = outputs.last_hidden_state
        # The hidden state at position 0 (the CLS token) serves as the sentence embedding
        cls_embeddings = last_hidden_state[:, 0, :]
        output_embeddings.append(cls_embeddings)

# Stack the per-batch results into a single (num_sentences, hidden_size) tensor
all_embeddings = torch.cat(output_embeddings, dim=0)
print(f"All embeddings dimensions: {all_embeddings.shape}")

All embeddings dimensions: torch.Size([4018, 768])


In [24]:
# Debug check: `last_hidden_state` is whatever the embedding loop assigned on its
# final iteration, so this is the shape of the LAST batch only (its first dimension
# is that batch's size, not the dataset size).
# Presumably laid out as (batch, sequence length, hidden size) - TODO confirm.
last_hidden_state.shape

torch.Size([50, 137, 768])

In [25]:
# Use ChromaDB as the vector store: it keeps one embedding per training sentence
# (with its label as metadata) and answers cosine-similarity nearest-neighbour queries.
import chromadb

chroma_client = chromadb.Client()
# get_or_create keeps this cell re-runnable: create_collection raises when the
# collection already exists in the in-memory client from a previous run.
collection = chroma_client.get_or_create_collection(
    name="counterfactual_collection",
    metadata={"hnsw:space": "cosine"},
)

sentences = data['sentence'].tolist()
labels = data['is_counterfactual'].tolist()
vectors = all_embeddings.tolist()
# Fail loudly on a mismatch instead of silently dropping trailing rows.
assert len(sentences) == len(vectors) == len(labels), "sentences, embeddings and labels must align"

# Write in chunks rather than one call per row; a moderate chunk size keeps each
# request well below the per-call record limit Chroma enforces.
UPSERT_BATCH_SIZE = 1000
for start in range(0, len(sentences), UPSERT_BATCH_SIZE):
    stop = min(start + UPSERT_BATCH_SIZE, len(sentences))
    collection.upsert(
        # ids are the row positions as strings, so re-running overwrites instead of duplicating
        ids=[str(i) for i in range(start, stop)],
        documents=sentences[start:stop],
        embeddings=vectors[start:stop],
        metadatas=[{"is_counterfactual": label} for label in labels[start:stop]],
    )


In [50]:

# Classify one ad-hoc sentence: embed it, fetch its 10 nearest training sentences,
# and report the share of those neighbours that are labelled counterfactual.
query = "Just who do you think you are talking to right now?"
query_inputs = tokenizer(query, return_tensors="pt", max_length=512, padding=True, truncation=True)
# Inference only: without no_grad the forward pass builds an autograd graph that is never used.
with torch.no_grad():
    query_output = model(
        input_ids=query_inputs["input_ids"],
        token_type_ids=query_inputs["token_type_ids"],
        attention_mask=query_inputs["attention_mask"],
    )
# Extract the CLS token embedding
query_embeddings = query_output["last_hidden_state"][:, 0, :]
# Drop the batch dimension: Chroma expects a list of plain float vectors
query_embeddings_list = query_embeddings.squeeze().tolist()
results = collection.query(query_embeddings=[query_embeddings_list], n_results=10)
# results["metadatas"][0] holds the metadata dicts of the neighbours of the first (only) query
neighbour_labels = [r["is_counterfactual"] for r in results["metadatas"][0]]
counterfactual_prob = sum(neighbour_labels) / len(neighbour_labels)
print(f"counterfactual probability is {counterfactual_prob * 100}%")

counterfactual probability is 10.0%


In [45]:

import pandas as pd
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Held-out English test split (tab-separated)
test_data = pd.read_table(
    "path/to/amazon-multilingual-counterfactual-dataset/data/EN_test.tsv",
    sep="\t",
)
# A sentence is predicted counterfactual only when its neighbour-vote share exceeds this value
cutoff_threshold = 0.5

def get_counterfactual_prob(sentence, n_results=10):
    """Estimate how likely `sentence` is counterfactual by nearest-neighbour voting.

    The sentence is embedded with the notebook-level `tokenizer` and `model`
    (CLS token of the last hidden state), the `n_results` closest stored
    embeddings are fetched from `collection`, and the fraction of those
    neighbours whose `is_counterfactual` metadata is 1 is returned.

    Args:
        sentence: A single sentence (str) to classify.
        n_results: Number of nearest neighbours that vote. Defaults to 10.

    Returns:
        The mean of the neighbours' `is_counterfactual` labels, a float in [0, 1]
        when the labels are 0/1.

    Raises:
        ZeroDivisionError: If the collection returns no neighbours.
    """
    query_inputs = tokenizer(sentence, return_tensors="pt", max_length=512, padding=True, truncation=True)
    # Inference only: skip autograd so no graph is built for every evaluated sentence.
    with torch.no_grad():
        query_output = model(
            input_ids=query_inputs["input_ids"],
            token_type_ids=query_inputs["token_type_ids"],
            attention_mask=query_inputs["attention_mask"],
        )
    # Extract the CLS token embedding and drop the batch dimension
    query_embeddings = query_output["last_hidden_state"][:, 0, :]
    query_embeddings_list = query_embeddings.squeeze().tolist()
    results = collection.query(query_embeddings=[query_embeddings_list], n_results=n_results)
    # results["metadatas"][0] holds the metadata dicts of the neighbours of the single query
    neighbour_labels = [r["is_counterfactual"] for r in results["metadatas"][0]]
    return sum(neighbour_labels) / len(neighbour_labels)

# Score every test sentence and turn the neighbour-vote share into a hard label:
# 1 when it is strictly above the cutoff, otherwise 0.
predictions = [
    int(get_counterfactual_prob(sentence) > cutoff_threshold)
    for sentence in test_data['sentence']
]

# Compare the hard predictions against the gold labels
true_labels = test_data['is_counterfactual']
accuracy = accuracy_score(true_labels, predictions)
f1 = f1_score(true_labels, predictions)
precision = precision_score(true_labels, predictions)
recall = recall_score(true_labels, predictions)

print(f"Accuracy: {accuracy:.2f}")
print(f"F1 Score: {f1:.2f}")
print(f"Precision: {precision:.2f}")
print(f"Recall: {recall:.2f}")


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


Accuracy: 0.87
F1 Score: 0.62
Precision: 0.74
Recall: 0.53
