In [2]:
from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
import random
from sentence_transformers import util
import helper_functions as hp
from torch.optim.lr_scheduler import StepLR

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import random
class CustomDatasetTriplet(Dataset):
    """In-memory dataset of (edit, paraphrase, neighbour) embeddings plus their sentences.

    Rows of ``dataset`` are indexed positionally: positions 0, 1 and 2 hold the
    edit, paraphrase and neighbour vectors; positions 4, 5 and 6 hold the matching
    sentences. Position 3 is not stored.
    """

    def __init__(self, dataset):
        # Copy each column into its own list once so item access is a plain list lookup.
        def column(position):
            return [record[position] for record in dataset]

        self.edit_sentences = column(4)
        self.paraphrase_sentences = column(5)
        # Attribute keeps its original spelling so existing callers still find it.
        self.neigbhourhood_sentences = column(6)
        self.edit_vectors = column(0)
        self.paraphrase_vectors = column(1)
        self.neighbourhood_vectors = column(2)

    def __len__(self):
        return len(self.edit_vectors)

    def __getitem__(self, index):
        """Return the three embeddings as float32 tensors followed by the three sentences."""
        vector_columns = (self.edit_vectors, self.paraphrase_vectors, self.neighbourhood_vectors)
        sentence_columns = (self.edit_sentences, self.paraphrase_sentences, self.neigbhourhood_sentences)
        embeddings = tuple(torch.tensor(col[index], dtype=torch.float32) for col in vector_columns)
        sentences = tuple(col[index] for col in sentence_columns)
        return embeddings + sentences


class CustomDataset(Dataset):#do not duplicate data approach
    """Dataset over paired rows without splitting them into per-column copies.

    Each row is expected to be
    ``[edit_vec, paraphrase_vec, neighbour_vec, row_index, edit_text, paraphrase_text, neighbour_text]``.
    Rows are held in one object-dtype array; vectors are converted to float32
    tensors on ``device`` only when a row is accessed.
    """

    def __init__(self, dataset, device):
        # An empty list would become a 1-D array and break the ``[:, 3]`` column lookups.
        if len(dataset) == 0:
            self.dataset = np.empty((0, 7), dtype=object)
        else:
            self.dataset = np.array(dataset, dtype=object)
        self.device = device

    def __len__(self):
        return len(self.dataset)

    def _to_tensor(self, value):
        # Convert one stored vector to a float32 tensor on the configured device.
        return torch.tensor(value, dtype=torch.float32).to(self.device)

    def total_indexes(self):
        """Return the sorted unique source-row indexes (column 3) present in the dataset."""
        return np.unique(self.dataset[:, 3])

    def get_row_indexes(self, target_sample_index):
        """Return positions of all rows whose source-row index equals ``target_sample_index``."""
        return np.where(self.dataset[:, 3] == target_sample_index)[0]

    def get_samples_at_data_index(self, target_sample_index):
        """Collect every row derived from one source row, column by column.

        Returns seven parallel lists: edit, paraphrase and neighbour tensors, the
        source-row indexes, and the edit, paraphrase and neighbour sentences.
        """
        # Keep the matching positions in their own name: the original rebound
        # ``row_indexes`` to an empty list before iterating it, so it always
        # returned seven empty lists.
        matching_positions = self.get_row_indexes(target_sample_index)
        embs_edit, embs_paraphrase, embs_neighbour, row_indexes, sents_edit, sents_paraphrase, sents_neigbhour = [], [], [], [], [], [], []

        for position in matching_positions:
            row = self.dataset[position]
            embs_edit.append(self._to_tensor(row[0]))
            embs_paraphrase.append(self._to_tensor(row[1]))
            embs_neighbour.append(self._to_tensor(row[2]))
            row_indexes.append(row[3])
            sents_edit.append(row[4])
            sents_paraphrase.append(row[5])
            sents_neigbhour.append(row[6])

        return embs_edit, embs_paraphrase, embs_neighbour, row_indexes, sents_edit, sents_paraphrase, sents_neigbhour

    def __getitem__(self, index):
        """Return one row as (edit, paraphrase, neighbour tensors, row index, three sentences)."""
        row = self.dataset[index]
        return (
            self._to_tensor(row[0]),
            self._to_tensor(row[1]),
            self._to_tensor(row[2]),
            row[3],
            row[4],
            row[5],
            row[6],
        )


def get_data_loader(dataset_paired, batch_size=2, shuffle=True, device="cpu"):
    """Wrap paired rows in a ``CustomDataset`` on ``device`` and return a DataLoader over it."""
    return DataLoader(
        CustomDataset(dataset_paired, device),
        batch_size=batch_size,
        shuffle=shuffle,
    )

In [4]:
def create_dataset_tripletloss(dataset, mode=0, label_reversal=False, max_openai_paraphrases=3, rng=None):
    """
    Build paired train and test rows from the loaded records.

    Every emitted row has the layout
    ``[edit_vector, paraphrase_vector, neighbour_vector, row_index,
    edit_text, paraphrase_text, neighbour_text]``.

    Train set, per record: up to ``max_openai_paraphrases`` randomly sampled OpenAI
    paraphrases, followed by the record's own processed paraphrase. Each is paired
    with a different high-similarity neighbour, consumed from the last one backwards.

    Test set, per record: one row per low-similarity neighbour. The test paraphrase is
    repeated in every row, so it only needs to be evaluated once per record.

    Args:
        dataset: iterable of record dicts (as loaded from the JSONL file).
        mode: currently unused; kept for backward compatibility.
        label_reversal: currently unused; kept for backward compatibility.
        max_openai_paraphrases: maximum number of OpenAI paraphrases sampled per record.
        rng: optional object with a ``sample`` method (e.g. ``random.Random(seed)``)
            for reproducible sampling; defaults to the global ``random`` module.

    Returns:
        ``(dataset_paired_train, dataset_paired_test)``.
    """
    sampler = random if rng is None else rng
    dataset_paired_train = []
    dataset_paired_test = []

    for row_index, row in enumerate(dataset):
        edit_vector = row["vector_edited_prompt"]
        edit_text = row["edited_prompt"][0]
        high_sim_vectors = row["vectors_neighborhood_prompts_high_sim"]
        high_sim_texts = row["neighborhood_prompts_high_sim"]
        openai_vectors = row["openai_usable_paraphrases_embeddings"]

        # Cursor into the high-similarity neighbours, starting at the last entry.
        # abs() is kept from the original: if a record has fewer neighbours than
        # train rows, negative cursors map back onto already-used neighbours.
        neighbour_cursor = len(high_sim_vectors) - 1

        # Sample positions directly; unpacking zip(*random.sample(..., 0)) used to
        # raise ValueError for records without OpenAI paraphrases.
        num_elements_to_select = min(max_openai_paraphrases, len(openai_vectors))
        sampled_indices = sampler.sample(range(len(openai_vectors)), num_elements_to_select)

        last_openai_text = ""
        for index in sampled_indices:
            openai_text = row["openai_usable_paraphrases"][index]
            dataset_paired_train.append([
                edit_vector, openai_vectors[index], high_sim_vectors[abs(neighbour_cursor)], row_index,
                edit_text, openai_text, high_sim_texts[abs(neighbour_cursor)],
            ])
            last_openai_text = openai_text
            neighbour_cursor -= 1

        # The record's own paraphrase. The original reused the leaked loop variable
        # `index` here, attaching an OpenAI paraphrase's text to this vector (or a
        # stale/unbound index when nothing was sampled).
        # NOTE(review): the text key is inferred from the vector key and from the
        # "*_testing" pair used below; confirm it exists in the data. When it is
        # absent we fall back to the previous text so DataLoader collation still works.
        dataset_paraphrase_text = row.get("edited_prompt_paraphrases_processed", last_openai_text)
        dataset_paired_train.append([
            edit_vector, row["vector_edited_prompt_paraphrases_processed"], high_sim_vectors[abs(neighbour_cursor)], row_index,
            edit_text, dataset_paraphrase_text, high_sim_texts[abs(neighbour_cursor)],
        ])
        # With 3 OpenAI paraphrases and 5 neighbours, one neighbour is left unused here.

        # Test set: one row per low-similarity neighbour, sharing the test paraphrase.
        low_sim_texts = row["neighborhood_prompts_low_sim"]
        for index, vector in enumerate(row["vectors_neighborhood_prompts_low_sim"]):
            dataset_paired_test.append([
                edit_vector, row["vector_edited_prompt_paraphrases_processed_testing"], vector, row_index,
                edit_text, row["edited_prompt_paraphrases_processed_testing"], low_sim_texts[index],
            ])

    return dataset_paired_train, dataset_paired_test

In [5]:
#testing the dataloaders
import json,linecache
def read_dataset_reduced(file_path_read_dataset: str, data_size):
    """
    Load the first ``data_size`` lines of a JSONL file as a list of parsed records.

    Lines that fail to parse are reported with their 1-based line number and skipped.
    The file is streamed with a normal file handle instead of ``linecache``.
    ``linecache`` keeps the whole file cached in memory after the call, and it
    returns '' for a missing file, which hid the error behind one failure per
    requested line. A missing file now raises ``FileNotFoundError``. A file shorter
    than ``data_size`` simply yields fewer records.
    """
    dataset = []
    if data_size <= 0:
        return dataset
    with open(file_path_read_dataset, "r", encoding="utf-8") as handle:
        for line_number, line in enumerate(handle, start=1):
            try:
                dataset.append(json.loads(line.strip()))
            except ValueError as e:  # json.JSONDecodeError subclasses ValueError
                print(f"Skipping line {line_number}: {e}")
            if line_number >= data_size:
                break
    return dataset
# Source JSONL file (one JSON record per line) and how many leading lines to load.
file_path_dataset = "counterfact_test_2_lama_merged.jsonl"
num_samples = 4999
dataset = read_dataset_reduced(file_path_read_dataset=file_path_dataset, data_size=num_samples)


In [6]:
dataset_paired_train, dataset_paired_test = create_dataset_tripletloss(dataset)
# The edit vector is indexed once more before len(); the printed width comes
# from its first element (the vector appears to be stored nested, e.g. [[...]]).
first_edit_vector = dataset_paired_test[0][0]
input_dim = len(first_edit_vector[0])
print(f"output vector length: {input_dim}")

output vector length: 4096


In [7]:
batch_size = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reuse the pairs built in the previous cell instead of rebuilding them. The train
# split is randomly sampled, so a second call produced different pairs from the
# ones just inspected, and it repeated the whole pass over the dataset.
train_loader = get_data_loader(dataset_paired_train, batch_size=batch_size, device=device)
# Evaluation does not need shuffling; iterate the test rows in a fixed order.
test_loader = get_data_loader(dataset_paired_test, batch_size=batch_size, shuffle=False, device=device)

In [8]:
import torch
from sklearn.metrics.pairwise import cosine_similarity
import torch.nn.functional as F
import numpy as np

def calculate_distances(anchor, positive, negative):
    """Compare an anchor embedding against a positive and a negative embedding.

    Returns ``(cos_pos, cos_neg, dist_pos, dist_neg)`` as Python floats: the
    scikit-learn cosine similarities (computed on CPU numpy copies), followed by
    the ``F.pairwise_distance`` distances.
    """
    anchor_row = anchor.cpu().numpy().reshape(1, -1)
    anchor_batch = anchor.unsqueeze(0)

    def cosine_to(other):
        return float(cosine_similarity(anchor_row, other.cpu().numpy().reshape(1, -1))[0][0])

    def distance_to(other):
        return float(F.pairwise_distance(anchor_batch, other.unsqueeze(0)).item())

    return cosine_to(positive), cosine_to(negative), distance_to(positive), distance_to(negative)

In [9]:
def compute_threshold(emb_edit, emb_para):
    """Return the L2 distance between an edit embedding and a paraphrase embedding as a Python float."""
    return torch.dist(emb_edit, emb_para, p=2).item()

In [10]:
import torch.nn.functional as F
from collections import defaultdict


# For every edit, collect the L2 distance to each of its training paraphrases and
# keep the largest one as that edit's threshold.
threshold_map = defaultdict(list)

for batch_edit, batch_paraphrase, _, batch_row_index, _, _, _ in train_loader:
    for position, raw_row_index in enumerate(batch_row_index):
        edit_row = int(raw_row_index.item())
        threshold_map[edit_row].append(
            compute_threshold(batch_edit[position], batch_paraphrase[position])
        )

final_threshold_map = {}
for edit_row, distances in threshold_map.items():
    # JSON keys must be strings; the reader looks thresholds up by str(row_index).
    final_threshold_map[str(edit_row)] = max(distances)

with open("Lexical_threashold.json", "w") as f:
    json.dump(final_threshold_map, f, indent=4)

print(f"Thresholds saved for {len(final_threshold_map)} edit vectors")

Thresholds saved for 4999 edit vectors


In [None]:
# Reload the per-edit thresholds written by the threshold cell. The filename must
# match the writer exactly ("Lexical_threashold.json"); the previous
# "Lexical_Threashold.json" only resolved on case-insensitive filesystems.
with open("Lexical_threashold.json", "r") as f:
    threshold_map = json.load(f)

# JSON object keys are already strings; normalise anyway because lookups use str(row_index).
threshold_map = {str(k): v for k, v in threshold_map.items()}

def predict_label(edit_vector_test, paraphrase_vector_test, neighbor_vector_test, threshold, margin=0.1):
    """Score one test triplet against an edit's distance threshold.

    Returns ``(predicted, dist_para, dist_neigh, gen_success, loc_success)``:

    - ``gen_success`` is ``dist_para < threshold``.
    - ``loc_success`` is ``dist_neigh > threshold``.
    - ``predicted`` is ``1`` when ``gen_success`` holds, else ``0``.

    ``margin`` is accepted for compatibility but is not used.
    """
    dist_para = torch.dist(edit_vector_test, paraphrase_vector_test).item()
    dist_neigh = torch.dist(edit_vector_test, neighbor_vector_test).item()

    gen_success = bool(dist_para < threshold)
    loc_success = bool(dist_neigh > threshold)
    predicted = 1 if gen_success else 0
    return predicted, dist_para, dist_neigh, gen_success, loc_success

# Evaluate the held-out rows. A paraphrase "generalises" when it falls inside the
# edit's threshold; a low-similarity neighbour "localises" when it falls outside it.
# NOTE(review): every test row of an edit repeats the same test paraphrase, so each
# edit's generalisation outcome is counted (and a miss is logged) once per row.
DEFAULT_THRESHOLD = 1.0  # fallback for edits that have no entry in threshold_map

correct = 0
total = 0
Generation_success = 0
locality_success = 0
incorrect_predictions = []


with torch.no_grad():
    for edit_vector, paraphrase_vector, neighbor_vector, row_index, edit_sentence, paraphrase_sentence, neighbor_sentence in test_loader:

        if isinstance(row_index, torch.Tensor):
            row_index = [int(idx.item()) for idx in row_index]

        for i, index in enumerate(row_index):
            threshold = threshold_map.get(str(index), DEFAULT_THRESHOLD)

            predicted, dist_para, dist_neigh, gen_success, loc_success = predict_label(
                edit_vector[i],
                paraphrase_vector[i],
                neighbor_vector[i],
                threshold,
            )

            if gen_success:
                Generation_success += 1

            if loc_success:
                locality_success += 1

            if predicted == 1:
                correct += 1
            else:
                incorrect_predictions.append({
                    "edit_sentence": edit_sentence[i],
                    "paraphrase_sentence": paraphrase_sentence[i],
                    "neighbor_sentence": neighbor_sentence[i],
                    "distance_paraphrase": dist_para,
                    "distance_neighbor": dist_neigh,
                    "threshold": threshold
                })

            total += 1

with open("Lexical bias outcome.json", "w") as f:
    json.dump(incorrect_predictions, f, indent=4)

print(f"loca_success: {locality_success}")
print(f"gen_success: {Generation_success}")

# An empty test loader would otherwise raise ZeroDivisionError here.
Generation_rate = Generation_success / total if total else 0.0
locality_rate = locality_success / total if total else 0.0

print(f"Total: {total}")
print(f"Correct: {correct}")
print(f"Generalization Rate: {Generation_rate:.2%}")
print(f"Locality Rate: {locality_rate:.2%}")
print(f"Incorrect Predictions Saved to 'Lexical bias outcome.json'")


loca_success: 23955
gen_success: 20820
Total: 24995
Correct: 20820
Generalization Rate: 83.30%
Locality Rate: 95.84%
Incorrect Predictions Saved to 'Lexical bias outcome.json'


: 