In [1]:
import chromadb
from chromadb.config import Settings
from chromadb.utils import embedding_functions

sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(model_name="all-distilroberta-v1")

print(f'Default embedding function: {sentence_transformer_ef}')
client = chromadb.Client(Settings(persist_directory="./spam-db")) # Does persistence work??


# collection = client.create_collection(name="spam-dataset")
collection = client.create_collection(
        name="spam-dataset",
        metadata={"hnsw:space": "cosine"} # l2 is the default
    )

  from .autonotebook import tqdm as notebook_tqdm
Using embedded DuckDB without persistence: data will be transient
No embedding_function provided, using default embedding function: SentenceTransformerEmbeddingFunction


Default embedding function: <chromadb.utils.embedding_functions.SentenceTransformerEmbeddingFunction object at 0x7ff803fe4580>


In [2]:
import os
import sys
import numpy as np

# Add the parent directory of the project to the sys.path
project_dir = os.path.abspath(os.path.join(os.path.dirname('.'), ".."))

sys.path.append(project_dir)


from utils.read_gold_dataset import collect_all_files, read_a_file, get_gt_from_file_name

root_dir = '../spam-dataset/op_spam_v1.4'
file_path_list = collect_all_files(root_dir, 'truthful')
file_path_list += collect_all_files(root_dir, 'deceptive')
total_samples = len(file_path_list)
print(f'Total samples: {total_samples}')



Total samples: 1600


In [3]:
# test/train split
# Import necessary libraries
from sklearn.model_selection import train_test_split


# Splitting the dataset into training and testing sets
train_files, test_files = train_test_split(file_path_list, test_size=0.20, random_state=42, shuffle=True)

# Printing the results
print(f'Train file count: {len(train_files)}, Training Data: {train_files}')
print(f'Test file count: {len(test_files)} Testing Data: {test_files}')

Train file count: 1280, Training Data: ['../spam-dataset/op_spam_v1.4/negative_polarity/deceptive_from_MTurk/fold4/d_homewood_11.txt', '../spam-dataset/op_spam_v1.4/negative_polarity/truthful_from_Web/fold5/t_palmer_13.txt', '../spam-dataset/op_spam_v1.4/negative_polarity/truthful_from_Web/fold2/t_ambassador_7.txt', '../spam-dataset/op_spam_v1.4/negative_polarity/truthful_from_Web/fold2/t_affinia_16.txt', '../spam-dataset/op_spam_v1.4/negative_polarity/truthful_from_Web/fold3/t_fairmont_2.txt', '../spam-dataset/op_spam_v1.4/positive_polarity/deceptive_from_MTurk/fold4/d_swissotel_10.txt', '../spam-dataset/op_spam_v1.4/negative_polarity/truthful_from_Web/fold3/t_hyatt_17.txt', '../spam-dataset/op_spam_v1.4/negative_polarity/deceptive_from_MTurk/fold4/d_homewood_1.txt', '../spam-dataset/op_spam_v1.4/negative_polarity/truthful_from_Web/fold2/t_hardrock_19.txt', '../spam-dataset/op_spam_v1.4/negative_polarity/truthful_from_Web/fold4/t_knickerbocker_19.txt', '../spam-dataset/op_spam_v1.4/ne

In [4]:
# do this only once
base = 100
for i, file_path in enumerate(train_files):
    gt_sentiment, gt_veracity, text = read_a_file(file_path)
    ids = ['id' + str(i + base) ]
    metadatas = [{"source": file_path}]
    # print(ids)
    # print(metadatas)
    collection.add(
        documents = [text],
        metadatas = metadatas,
        ids = ids
    )
print(f' count: {collection.count()}')

 count: 1280


In [5]:
# collection.get()

In [6]:
import pprint

def get_nearest_neighbor(test_file_name):
    _, gt_veracity, text = read_a_file(test_file_name)
    # pprint.pprint(text)
    print(f'senti: {gt_sentiment}, vericity: {gt_veracity}')
    results = collection.query(
        query_texts= text,
        n_results=1
    )
    # pprint.pprint(results)
    file_nearest_neighbor = results['metadatas'][0][0]['source'] 
    # print(file_nearest_neighbor)
    # extract ground_truth from spam file name
    _, gt_veracity_of_nearest_neighbor = get_gt_from_file_name(file_nearest_neighbor)
    print(f'GT of post:{gt_veracity}, GT of nearest neighbor: {gt_veracity_of_nearest_neighbor}')

test_file_name = '../spam-dataset/op_spam_v1.4/negative_polarity/truthful_from_Web/fold2/t_affinia_11.txt'
get_nearest_neighbor(test_file_name)

senti: positive, vericity: truthful
GT of post:truthful, GT of nearest neighbor: truthful


In [7]:
for test_file in test_files:
    get_nearest_neighbor(test_file)

senti: positive, vericity: truthful
GT of post:truthful, GT of nearest neighbor: truthful
senti: positive, vericity: truthful
GT of post:truthful, GT of nearest neighbor: deceptive
senti: positive, vericity: truthful
GT of post:truthful, GT of nearest neighbor: deceptive
senti: positive, vericity: truthful
GT of post:truthful, GT of nearest neighbor: truthful
senti: positive, vericity: deceptive
GT of post:deceptive, GT of nearest neighbor: deceptive
senti: positive, vericity: deceptive
GT of post:deceptive, GT of nearest neighbor: truthful
senti: positive, vericity: deceptive
GT of post:deceptive, GT of nearest neighbor: deceptive
senti: positive, vericity: truthful
GT of post:truthful, GT of nearest neighbor: deceptive
senti: positive, vericity: truthful
GT of post:truthful, GT of nearest neighbor: deceptive
senti: positive, vericity: truthful
GT of post:truthful, GT of nearest neighbor: truthful
senti: positive, vericity: truthful
GT of post:truthful, GT of nearest neighbor: truthfu