In [1]:
%load_ext autoreload

In [40]:
%autoreload
import sys
sys.path.insert(0, '../')

In [41]:
from src.utilities.mluar_utils import *

In [42]:
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset, Dataset
import numpy as np
from einops import rearrange, reduce, repeat
import torch
from sklearn.metrics.pairwise import cosine_similarity
from matplotlib import pyplot as plt
import math
import pandas as pd
from scipy.stats import zscore
import pickle as pkl
import tabulate

In [43]:
from src.utilities.mluar_utils import *

In [6]:
MULTI_LUAR_PATH =  "/mnt/swordfish-pool2/milad/multi-luar-reddit-model/"
LUAR_PATH =  "/mnt/swordfish-pool2/nikhil/LUAR/pretrained_weights/LUAR-MUD/"

In [7]:
# Load models
multiluar_model = AutoModel.from_pretrained(MULTI_LUAR_PATH, trust_remote_code=True)
luar_model = AutoModel.from_pretrained(LUAR_PATH, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD")

In [8]:
data_path = '/mnt/swordfish-pool2/milad/hiatus-data/phase_2'

In [44]:
data_embeddings = {}
for domain in ['HRS2.1', 'HRS2.2', 'HRS2.3', 'HRS2.4']:
    # Load data
    domain_data_path = domain_path_temp.format(domain, domain, 'data', domain) + '_input'
    domain_groundtruth_path = domain_path_temp.format(domain, domain, 'groundtruth', domain) 
    hiatus_data, _, _ = load_aa_data(domain_data_path, domain_groundtruth_path)

    # keep authors with only more than one text
    authors_with_multiple_texts = [x[0] for x in hiatus_data.authorID.value_counts().to_dict().items() if x[1] > 1]
    hiatus_data = hiatus_data[hiatus_data.authorID.isin(authors_with_multiple_texts)]

    # Embed data using m-luar and luar
    hiatus_data_texts = hiatus_data.fullText.tolist()
    labels = hiatus_data.authorID.tolist()
    hiatus_mluar_data_embeddings = get_luar_embeddings(hiatus_data_texts, multiluar_model, tokenizer, max_length=max_seq_length, batch_size=1, is_multi_luar=True)
    hiatus_luar_data_embeddings  = get_luar_embeddings(hiatus_data_texts, luar_model, tokenizer, max_length=max_seq_length, batch_size=1)
    hiatus_luar_data_embeddings  = [e.unsqueeze(0) for e in hiatus_luar_data_embeddings]
    
    data_embeddings[domain] = {'LUAR': hiatus_luar_data_embeddings, 'MLUAR': hiatus_mluar_data_embeddings}

Loading:  /mnt/swordfish-pool2/milad/hiatus-data/phase_2/mode_perGenre-HRS2.1/TA2/hrs_06-27-24_english_perGenre-HRS2.1/data/hrs_06-27-24_english_perGenre-HRS2.1_TA2_input
Loading:  /mnt/swordfish-pool2/milad/hiatus-data/phase_2/mode_perGenre-HRS2.2/TA2/hrs_06-27-24_english_perGenre-HRS2.2/data/hrs_06-27-24_english_perGenre-HRS2.2_TA2_input
Loading:  /mnt/swordfish-pool2/milad/hiatus-data/phase_2/mode_perGenre-HRS2.3/TA2/hrs_06-27-24_english_perGenre-HRS2.3/data/hrs_06-27-24_english_perGenre-HRS2.3_TA2_input
Loading:  /mnt/swordfish-pool2/milad/hiatus-data/phase_2/mode_perGenre-HRS2.4/TA2/hrs_06-27-24_english_perGenre-HRS2.4/data/hrs_06-27-24_english_perGenre-HRS2.4_TA2_input


In [45]:
results = []
max_seq_length = 736
domain_path_temp = data_path + '/mode_perGenre-{}/TA2/hrs_06-27-24_english_perGenre-{}/{}/hrs_06-27-24_english_perGenre-{}_TA2'

for domain in ['HRS2.1', 'HRS2.2', 'HRS2.3', 'HRS2.4']:
    hiatus_mluar_data_embeddings = data_embeddings[domain]['MLUAR']
    hiatus_luar_data_embeddings  = data_embeddings[domain]['LUAR']
    
    # Compute cosine-sim matrix and evalaute
    muti_luar_layers_sims = np.stack([compute_similarities(hiatus_mluar_data_embeddings, hiatus_mluar_data_embeddings, layer=i) for i in range(7)])
    muti_luar_layers_sims = np.mean(muti_luar_layers_sims, 0)

    luar_sims = compute_similarities(hiatus_luar_data_embeddings, hiatus_luar_data_embeddings)
    labels_matrix = np.array([[int(x == y) for x in labels] for y in labels])

    luar_eer, luar_mrr   = eer(luar_sims, labels_matrix), compute_mrr(luar_sims, labels)
    mluar_eer, mluar_mrr = eer(muti_luar_layers_sims, labels_matrix), compute_mrr(muti_luar_layers_sims, labels)

    results.append(['LUAR', domain, luar_eer, luar_mrr])
    results.append(['MLUAR', domain, mluar_eer, mluar_mrr])
    results.append(['--', '--', '--', '--'])

ValueError: Found input variables with inconsistent numbers of samples: [13213225, 12673600]

In [None]:
print(tabulate.tabulate(results, headers=['Domain', 'EER', 'MRR']))