In [1]:
# Load IPython's autoreload extension so the %autoreload magic used in the next cell is available.
%load_ext autoreload

In [119]:
# With no argument, %autoreload reloads changed modules once, right now.
%autoreload
import sys
# Put the parent directory at the front of the import path — presumably so the
# `src.utilities` package imported by the following cells resolves when the
# notebook is run from its own folder.
sys.path.insert(0, '../')

In [120]:
from src.utilities.mluar_utils import *

In [115]:
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset, Dataset
import numpy as np
from einops import rearrange, reduce, repeat
import torch
from sklearn.metrics.pairwise import cosine_similarity
from matplotlib import pyplot as plt
import math
import pandas as pd
from scipy.stats import zscore
import pickle as pkl
import tabulate
import torch.nn.functional as F

In [116]:
from src.utilities.mluar_utils import *

In [6]:
# Checkpoint directories passed to AutoModel.from_pretrained in the next cell.
# NOTE(review): absolute, machine-specific paths — the notebook only runs where these exist.
MULTI_LUAR_PATH =  "/mnt/swordfish-pool2/milad/multi-luar-reddit-model/"
LUAR_PATH =  "/mnt/swordfish-pool2/nikhil/LUAR/pretrained_weights/LUAR-MUD/"

In [7]:
# Load models
# NOTE: trust_remote_code=True runs the modelling code shipped with each checkpoint;
# only point these paths at checkpoints you trust.
multiluar_model = AutoModel.from_pretrained(MULTI_LUAR_PATH, trust_remote_code=True)
luar_model = AutoModel.from_pretrained(LUAR_PATH, trust_remote_code=True)
# A single tokenizer, identified by name rather than by either local path above,
# is shared by both models in the embedding cell below.
tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD")

In [8]:
# Root directory of the evaluation data; the per-domain path template is built on top of it.
# NOTE(review): absolute, machine-specific path.
data_path = '/mnt/swordfish-pool2/milad/hiatus-data/phase_2'

In [121]:
# sents = ["""nots\n\nWatched as part of The Tara Reid Mission\n\nThe Film\nThere's a moment in Knots where <PERSON> character tells his friend that his prolific cheating on his girlfriend is okay because a) Her work takes her away a lot, and b) Sex with other women further solidifies his feeling that his girlfriend is 'the one' and that the sex with her after he's cheated is brilliant. It's this kind of mean spiritedness that defines the film, from the characters to the situations they create for themselves.\n\nSupposedly a comedy from the male perspective, very few of the characters ring true. The men themselves are hapless, bumbling idiots, so sure of their masculinity that they are unable to accept their (many) flaws, whilst the women are either screaming harpies, dumb conquests or manipulative ice maidens.\n\nIt's not all bad though. There are a few genuinely amusing moments, and <PERSON>, who is the one genuinely nice guy, does well with what little the script gives him.\n\nIf this is what modern relationships are like, count me out.\n\nHow's <PERSON>?\nNot appearing until around the 42 minute mark, <PERSON>'s <PERSON> was the only other character I didn't destest. The scenes with her and <PERSON>'s <PERSON> are sweet, and it was nice seeing their relationship blossom, although more time is spent with the other, more hateful characters. <PERSON>'s well documented of screen partying seems to have taken it's toll though, as she sounds like she's smoked a pack of cigarettes every time she opens her mouth.""", """nots\n\nWatched as part of The Tara Reid Mission\n\nThe Film\nThere's a moment in Knots where <PERSON> character tells his friend that his prolific cheating on his girlfriend is okay because a) Her work takes her away a lot, and b) Sex with other women further solidifies his feeling that his girlfriend is 'the one' and that the sex with her after he's cheated is brilliant. 
# It's this kind of mean spiritedness that defines the film, from the characters to the situations they create for themselves.\n\nSupposedly a comedy from the male perspective, very few of the characters ring true. The men themselves are hapless, bumbling idiots, so sure of their masculinity that they are unable to accept their (many) flaws, whilst the women are either screaming harpies, dumb conquests or manipulative ice maidens.\n\nIt's not all bad though. There are a few genuinely amusing moments, and <PERSON>, who is the one genuinely nice guy, does well with what little the script gives him.\n\nIf this is what modern relationships are like, count me out.\n\nHow's <PERSON>?\nNot appearing until around the 42 minute mark, <PERSON>'s <PERSON> was the only other character I didn't destest. The scenes with her and <PERSON>'s <PERSON> are sweet, and it was nice seeing their relationship blossom, although more time is spent with the other, more hateful characters. <PERSON>'s well documented of screen partying seems to have taken it's toll though, as she sounds like she's smoked"""]
# embed, tokenized_txt = get_luar_embeddings(sents, luar_model, tokenizer, max_length=512, batch_size=2)
# tokenized_txt['input_ids'].shape

In [125]:
# Embed the texts of each domain with both LUAR and multi-LUAR and collect the
# results in `data_embeddings`, keyed by domain name.
data_embeddings = dict()
max_seq_length = 736
domain_path_temp = (
    data_path
    + '/mode_perGenre-{}/TA2/hrs_06-27-24_english_perGenre-{}/{}/hrs_06-27-24_english_perGenre-{}_TA2'
)

for domain in ['HRS2.1', 'HRS2.2', 'HRS2.3', 'HRS2.4']:
    # Resolve this domain's ground-truth and input locations, then load the data.
    domain_groundtruth_path = domain_path_temp.format(domain, domain, 'groundtruth', domain)
    domain_data_path = domain_path_temp.format(domain, domain, 'data', domain) + '_input'
    hiatus_data, _, _ = load_aa_data(domain_data_path, domain_groundtruth_path)

    # Discard authors that contribute only a single text.
    texts_per_author = hiatus_data.authorID.value_counts()
    authors_with_multiple_texts = texts_per_author[texts_per_author > 1].index.tolist()
    hiatus_data = hiatus_data[hiatus_data.authorID.isin(authors_with_multiple_texts)]

    # Feed the same list of texts to both models, one text per batch.
    hiatus_data_texts = hiatus_data.fullText.tolist()
    hiatus_mluar_data_embeddings, _ = get_luar_embeddings(
        hiatus_data_texts,
        multiluar_model,
        tokenizer,
        max_length=max_seq_length,
        batch_size=1,
        is_multi_luar=True,
    )
    hiatus_luar_data_embeddings, _ = get_luar_embeddings(
        hiatus_data_texts,
        luar_model,
        tokenizer,
        max_length=max_seq_length,
        batch_size=1,
    )
    # Add a leading axis to every LUAR embedding — presumably so it lines up with
    # the per-layer layout of the multi-LUAR embeddings; TODO confirm.
    hiatus_luar_data_embeddings = [emb.unsqueeze(0) for emb in hiatus_luar_data_embeddings]

    data_embeddings[domain] = {
        'LUAR': hiatus_luar_data_embeddings,
        'MLUAR': hiatus_mluar_data_embeddings,
    }
    # NOTE(review): this stops after the first domain, so only 'HRS2.1' is embedded.
    break

Loading:  /mnt/swordfish-pool2/milad/hiatus-data/phase_2/mode_perGenre-HRS2.1/TA2/hrs_06-27-24_english_perGenre-HRS2.1/data/hrs_06-27-24_english_perGenre-HRS2.1_TA2_input


In [126]:
# Show which domains have an entry in the embeddings dictionary.
data_embeddings.keys()

dict_keys(['HRS2.1'])

In [14]:
# Persist the embeddings dictionary next to the notebook. The `with` block
# guarantees the file is flushed and closed even if pickling raises, instead of
# leaving an anonymous open handle to be cleaned up by the garbage collector.
with open('./hiatus_data_embedded.pkl', 'wb') as embeddings_file:
    pkl.dump(data_embeddings, embeddings_file)

In [135]:
# Score LUAR and multi-LUAR (all layers averaged, then each layer left out in turn)
# on every domain and collect one table row per configuration in `results`.
results = []
# Number of per-layer similarity matrices requested for the multi-LUAR embeddings;
# shared by the stacking step and the ablation loop so the two cannot drift apart.
num_mluar_layers = 7
for domain in ['HRS2.1', 'HRS2.2', 'HRS2.3', 'HRS2.4']:

    # Load data
    domain_data_path = domain_path_temp.format(domain, domain, 'data', domain) + '_input'
    domain_groundtruth_path = domain_path_temp.format(domain, domain, 'groundtruth', domain)
    hiatus_data, _, _ = load_aa_data(domain_data_path, domain_groundtruth_path)
    # Keep only authors with more than one text (same filter as the embedding cell,
    # so the labels line up with the stored embeddings).
    authors_with_multiple_texts = [x[0] for x in hiatus_data.authorID.value_counts().to_dict().items() if x[1] > 1]
    hiatus_data = hiatus_data[hiatus_data.authorID.isin(authors_with_multiple_texts)]

    labels = hiatus_data.authorID.tolist()

    # Load embeddings from the saved dictionary
    hiatus_mluar_data_embeddings = data_embeddings[domain]['MLUAR']
    hiatus_luar_data_embeddings  = data_embeddings[domain]['LUAR']

    luar_sims = compute_similarities(hiatus_luar_data_embeddings, hiatus_luar_data_embeddings)
    # labels_matrix[i, j] == 1 iff texts i and j share an author. Built with a
    # broadcast comparison instead of a nested Python comprehension; dtype=object
    # keeps plain Python equality semantics for the author ids.
    labels_array = np.array(labels, dtype=object)
    labels_matrix = (labels_array[:, None] == labels_array[None, :]).astype(int)
    luar_eer, luar_mrr = eer(luar_sims, labels_matrix), compute_mrr(luar_sims, labels)

    results.append(['LUAR', domain, luar_eer, luar_mrr])
    results.append(['+++', '+++', '+++', '+++'])

    # One similarity matrix per multi-LUAR layer, stacked along a new leading axis.
    muti_luar_layers_sims = np.stack([
        compute_similarities(hiatus_mluar_data_embeddings, hiatus_mluar_data_embeddings, layer=i)
        for i in range(num_mluar_layers)
    ])

    # Full multi-LUAR score: average the similarities over all layers.
    mluar_all_layers_sims = np.mean(muti_luar_layers_sims, 0)
    mluar_eer, mluar_mrr = eer(mluar_all_layers_sims, labels_matrix), compute_mrr(mluar_all_layers_sims, labels)
    results.append(['MLUAR', domain, mluar_eer, mluar_mrr])
    results.append(['--', '--', '--', '--'])

    # Ablation study: leave one layer out at a time and average the rest.
    for layer in range(num_mluar_layers):
        selector = [i for i in range(num_mluar_layers) if i != layer]
        print(selector)
        muti_luar_layers_sims_ablated = np.mean(muti_luar_layers_sims[selector, :, :], 0)
        mluar_eer, mluar_mrr = eer(muti_luar_layers_sims_ablated, labels_matrix), compute_mrr(muti_luar_layers_sims_ablated, labels)
        results.append(['MLUAR/{}'.format(layer), domain, mluar_eer, mluar_mrr])
        results.append(['--', '--', '--', '--'])

    # NOTE(review): only the first domain is evaluated; the embedding cell also
    # stops after 'HRS2.1', so removing this break alone would raise KeyError.
    break

Loading:  /mnt/swordfish-pool2/milad/hiatus-data/phase_2/mode_perGenre-HRS2.1/TA2/hrs_06-27-24_english_perGenre-HRS2.1/data/hrs_06-27-24_english_perGenre-HRS2.1_TA2_input
[1, 2, 3, 4, 5, 6]
[0, 2, 3, 4, 5, 6]
[0, 1, 3, 4, 5, 6]
[0, 1, 2, 4, 5, 6]
[0, 1, 2, 3, 5, 6]
[0, 1, 2, 3, 4, 6]
[0, 1, 2, 3, 4, 5]


In [137]:
# Each row of `results` has four fields (model, domain, EER, MRR). With only three
# headers tabulate leaves the first column unlabelled, so name all four.
print(tabulate.tabulate(results, headers=['Model', 'Domain', 'EER', 'MRR']))

         Domain    EER    MRR
-------  --------  -----  -----
LUAR     HRS2.1    0.158  0.199
+++      +++       +++    +++
MLUAR    HRS2.1    0.27   0.159
--       --        --     --
MLUAR/0  HRS2.1    0.237  0.166
--       --        --     --
MLUAR/1  HRS2.1    0.268  0.16
--       --        --     --
MLUAR/2  HRS2.1    0.269  0.16
--       --        --     --
MLUAR/3  HRS2.1    0.267  0.16
--       --        --     --
MLUAR/4  HRS2.1    0.264  0.16
--       --        --     --
MLUAR/5  HRS2.1    0.271  0.158
--       --        --     --
MLUAR/6  HRS2.1    0.323  0.142
--       --        --     --
