In [1]:
%load_ext autoreload

In [2]:
%autoreload
import sys
# Put the parent directory first on the import path so the `src` package
# (imported below as `src.utilities.mluar_utils`) can be resolved from this notebook.
# NOTE: '../' is a relative entry, so it depends on the kernel's working directory.
sys.path.insert(0, '../')

In [3]:
from src.utilities.mluar_utils import *

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset, Dataset
import numpy as np
from einops import rearrange, reduce, repeat
import torch
from sklearn.metrics.pairwise import cosine_similarity
from matplotlib import pyplot as plt
import math
import pandas as pd
from scipy.stats import zscore
import pickle as pkl
import tabulate
import torch.nn.functional as F

In [5]:
from src.utilities.mluar_utils import *

In [6]:
# Local checkpoint directories of the two models loaded in this notebook.
MULTI_LUAR_PATH = "/mnt/swordfish-pool2/milad/multi-luar-reddit-model/"
LUAR_PATH = "/mnt/swordfish-pool2/nikhil/LUAR/pretrained_weights/LUAR-MUD/"

In [7]:
# Instantiate both models from their local checkpoint directories.
# NOTE(review): trust_remote_code=True lets the checkpoint's own Python code run,
# so both paths must point at trusted artifacts.
multiluar_model = AutoModel.from_pretrained(
    MULTI_LUAR_PATH,
    trust_remote_code=True,
)
luar_model = AutoModel.from_pretrained(
    LUAR_PATH,
    trust_remote_code=True,
)

# The tokenizer is fetched by its hub identifier rather than from a local path.
tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD")

In [8]:
# Base directory under which the per-domain data / groundtruth folders are resolved.
data_path = "/mnt/swordfish-pool2/milad/hiatus-data/HRS_evaluation_samples/HRS1_english_long/TA2"

In [9]:
# Embed the texts of every domain with both models and collect the results in
# `data_embeddings`, keyed by domain and then by model name ('LUAR' / 'MLUAR').
data_embeddings = {}
max_seq_length = 512
# Placeholders, in order: domain, sub-directory ('data' or 'groundtruth'), domain.
domain_path_temp = (
    data_path
    + '/HRS1_english_long_sample-0_perGenre-{}/{}/HRS1_english_long_sample-0_perGenre-{}_TA2'
)

for domain in ('HRS1.1', 'HRS1.2', 'HRS1.3', 'HRS1.4'):
    # Resolve the input and ground-truth locations for this domain and load them.
    domain_data_path = domain_path_temp.format(domain, 'data', domain) + '_input'
    domain_groundtruth_path = domain_path_temp.format(domain, 'groundtruth', domain)
    hiatus_data, _, _ = load_aa_data(domain_data_path, domain_groundtruth_path)

    # Keep only the rows whose author appears more than once.
    texts_per_author = hiatus_data.authorID.value_counts()
    authors_with_multiple_texts = texts_per_author[texts_per_author > 1].index.tolist()
    hiatus_data = hiatus_data[hiatus_data.authorID.isin(authors_with_multiple_texts)]

    # Embed the same list of texts with M-LUAR and with LUAR.
    hiatus_data_texts = hiatus_data.fullText.tolist()
    hiatus_mluar_data_embeddings, _ = get_luar_embeddings(
        hiatus_data_texts,
        multiluar_model,
        tokenizer,
        max_length=max_seq_length,
        batch_size=1,
        is_multi_luar=True,
    )
    hiatus_luar_data_embeddings, _ = get_luar_embeddings(
        hiatus_data_texts,
        luar_model,
        tokenizer,
        max_length=max_seq_length,
        batch_size=1,
    )
    # Give every LUAR embedding a leading axis of size 1.
    # NOTE(review): presumably this mirrors the per-layer layout of the M-LUAR
    # embeddings expected by compute_similarities -- confirm against the helper.
    hiatus_luar_data_embeddings = [emb.unsqueeze(0) for emb in hiatus_luar_data_embeddings]

    data_embeddings[domain] = {
        'LUAR': hiatus_luar_data_embeddings,
        'MLUAR': hiatus_mluar_data_embeddings,
    }

Loading:  /mnt/swordfish-pool2/milad/hiatus-data/HRS_evaluation_samples/HRS1_english_long/TA2/HRS1_english_long_sample-0_perGenre-HRS1.1/data/HRS1_english_long_sample-0_perGenre-HRS1.1_TA2_input
Loading:  /mnt/swordfish-pool2/milad/hiatus-data/HRS_evaluation_samples/HRS1_english_long/TA2/HRS1_english_long_sample-0_perGenre-HRS1.2/data/HRS1_english_long_sample-0_perGenre-HRS1.2_TA2_input
Loading:  /mnt/swordfish-pool2/milad/hiatus-data/HRS_evaluation_samples/HRS1_english_long/TA2/HRS1_english_long_sample-0_perGenre-HRS1.3/data/HRS1_english_long_sample-0_perGenre-HRS1.3_TA2_input
Loading:  /mnt/swordfish-pool2/milad/hiatus-data/HRS_evaluation_samples/HRS1_english_long/TA2/HRS1_english_long_sample-0_perGenre-HRS1.4/data/HRS1_english_long_sample-0_perGenre-HRS1.4_TA2_input


In [13]:
# Persist the per-domain embeddings to disk.
# The context manager closes (and therefore flushes) the file as soon as the dump
# finishes, instead of leaving an open handle for the garbage collector.
with open('./hiatus_data_embedded.pkl', 'wb') as embeddings_file:
    pkl.dump(data_embeddings, embeddings_file)

In [14]:
# Score LUAR and M-LUAR on every domain: first M-LUAR with all layers averaged,
# then one ablation row per layer group. Every row appended to `results` has four
# fields: model name, domain, EER, MRR (separator rows repeat a marker four times).
NUM_MLUAR_LAYERS = 7  # layer indices passed to compute_similarities(..., layer=i)
ABLATION_LAYER_GROUPS = [[0], [1, 2], [3, 4], [5, 6]]

results = []
for domain in ['HRS1.1', 'HRS1.2', 'HRS1.3', 'HRS1.4']:

    # Reload the domain's data; only the author labels are taken from it here.
    domain_data_path = domain_path_temp.format(domain, 'data', domain) + '_input'
    domain_groundtruth_path = domain_path_temp.format(domain, 'groundtruth', domain)
    hiatus_data, _, _ = load_aa_data(domain_data_path, domain_groundtruth_path)
    # Keep only authors with more than one text (the same filter that was applied
    # before the texts were embedded).
    authors_with_multiple_texts = [x[0] for x in hiatus_data.authorID.value_counts().to_dict().items() if x[1] > 1]
    hiatus_data = hiatus_data[hiatus_data.authorID.isin(authors_with_multiple_texts)]

    labels = hiatus_data.authorID.tolist()

    # Load embeddings from the saved dictionary
    hiatus_mluar_data_embeddings = data_embeddings[domain]['MLUAR']
    hiatus_luar_data_embeddings  = data_embeddings[domain]['LUAR']

    # The labels come from a fresh load while the embeddings are cached, so fail
    # loudly if the two no longer describe the same number of texts.
    if len(labels) != len(hiatus_luar_data_embeddings):
        raise ValueError(
            'Label/embedding mismatch for {}: {} labels vs {} LUAR embeddings'.format(
                domain, len(labels), len(hiatus_luar_data_embeddings)
            )
        )

    # Pairwise same-author indicator: entry (i, j) is 1 when texts i and j share an
    # author, else 0. Built with a broadcast comparison instead of nested Python loops.
    labels_array = np.asarray(labels)
    labels_matrix = (labels_array[:, None] == labels_array[None, :]).astype(int)

    luar_sims = compute_similarities(hiatus_luar_data_embeddings, hiatus_luar_data_embeddings)
    luar_eer, luar_mrr   = eer(luar_sims, labels_matrix), compute_mrr(luar_sims, labels)

    results.append(['+++', '+++', '+++', '+++'])
    results.append(['LUAR', domain, luar_eer, luar_mrr])

    # One similarity matrix per M-LUAR layer, stacked along a new leading (layer) axis.
    muti_luar_layers_sims = np.stack([
        compute_similarities(hiatus_mluar_data_embeddings, hiatus_mluar_data_embeddings, layer=i)
        for i in range(NUM_MLUAR_LAYERS)
    ])

    # Full M-LUAR score: average the similarity matrices over all layers.
    muti_luar_layers_sims_ablated = np.mean(muti_luar_layers_sims, 0)
    mluar_eer, mluar_mrr = eer(muti_luar_layers_sims_ablated, labels_matrix), compute_mrr(muti_luar_layers_sims_ablated, labels)
    results.append(['MLUAR', domain, mluar_eer, mluar_mrr])
    results.append(['--', '--', '--', '--'])

    # Ablation study: average only the similarity matrices of the listed layers.
    for layers in ABLATION_LAYER_GROUPS:
        # Layer indices outside the stacked range are silently dropped.
        selector = [i for i in range(muti_luar_layers_sims.shape[0]) if i in layers]
        muti_luar_layers_sims_ablated = np.mean(muti_luar_layers_sims[selector, :, :], 0)
        mluar_eer, mluar_mrr = eer(muti_luar_layers_sims_ablated, labels_matrix), compute_mrr(muti_luar_layers_sims_ablated, labels)
        results.append(['MLUAR/{}'.format(layers), domain, mluar_eer, mluar_mrr])

Loading:  /mnt/swordfish-pool2/milad/hiatus-data/HRS_evaluation_samples/HRS1_english_long/TA2/HRS1_english_long_sample-0_perGenre-HRS1.1/data/HRS1_english_long_sample-0_perGenre-HRS1.1_TA2_input
Loading:  /mnt/swordfish-pool2/milad/hiatus-data/HRS_evaluation_samples/HRS1_english_long/TA2/HRS1_english_long_sample-0_perGenre-HRS1.2/data/HRS1_english_long_sample-0_perGenre-HRS1.2_TA2_input
Loading:  /mnt/swordfish-pool2/milad/hiatus-data/HRS_evaluation_samples/HRS1_english_long/TA2/HRS1_english_long_sample-0_perGenre-HRS1.3/data/HRS1_english_long_sample-0_perGenre-HRS1.3_TA2_input
Loading:  /mnt/swordfish-pool2/milad/hiatus-data/HRS_evaluation_samples/HRS1_english_long/TA2/HRS1_english_long_sample-0_perGenre-HRS1.4/data/HRS1_english_long_sample-0_perGenre-HRS1.4_TA2_input


In [15]:
# Every row in `results` has four fields (model, domain, EER, MRR), so four headers
# are required; with only three, the first (model) column is printed without a title.
print(tabulate.tabulate(results, headers=['Model', 'Domain', 'EER', 'MRR']))

              Domain    EER    MRR
------------  --------  -----  -----
+++           +++       +++    +++
LUAR          HRS1.1    0.098  0.192
MLUAR         HRS1.1    0.118  0.174
--            --        --     --
MLUAR/[0]     HRS1.1    0.123  0.139
MLUAR/[1, 2]  HRS1.1    0.121  0.164
MLUAR/[3, 4]  HRS1.1    0.118  0.169
MLUAR/[5, 6]  HRS1.1    0.106  0.169
+++           +++       +++    +++
LUAR          HRS1.2    0.1    0.187
MLUAR         HRS1.2    0.12   0.17
--            --        --     --
MLUAR/[0]     HRS1.2    0.125  0.137
MLUAR/[1, 2]  HRS1.2    0.123  0.162
MLUAR/[3, 4]  HRS1.2    0.119  0.166
MLUAR/[5, 6]  HRS1.2    0.108  0.166
+++           +++       +++    +++
LUAR          HRS1.3    0.1    0.185
MLUAR         HRS1.3    0.12   0.166
--            --        --     --
MLUAR/[0]     HRS1.3    0.125  0.133
MLUAR/[1, 2]  HRS1.3    0.121  0.159
MLUAR/[3, 4]  HRS1.3    0.119  0.162
MLUAR/[5, 6]  HRS1.3    0.109  0.162
+++           +++       +++    +++
LUAR          HRS1.4 