In [14]:
%load_ext autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
%autoreload
import sys
import os
# Point the Hugging Face model and dataset cache environment variables at the shared pool.
# NOTE(review): presumably these must be set before `transformers` / `datasets` are imported -- confirm.
os.environ['TRANSFORMERS_CACHE'] = '/mnt/swordfish-pool2/milad/hf-cache'
os.environ['HF_DATASETS_CACHE'] = '/mnt/swordfish-pool2/milad/hf-cache'
# Put the parent directory first on the import path; the `src.*` imports in the next cell presumably rely on it.
sys.path.insert(0, '../')

In [18]:
%autoreload

from src.utilities.mluar_utils import *
from src.datasets import utils
from src.arguments import create_argument_parser
from tqdm import tqdm

In [19]:
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset, Dataset
import numpy as np
from einops import rearrange, reduce, repeat
import torch
from sklearn.metrics.pairwise import cosine_similarity
from matplotlib import pyplot as plt
import math
import pandas as pd
from scipy.stats import zscore
import pickle as pkl
import tabulate
import torch.nn.functional as F
from torch.utils.data import DataLoader
torch.multiprocessing.set_sharing_strategy('file_system')

In [20]:
# Settings object for the evaluation, followed by the overrides used throughout this notebook.
params = create_argument_parser()
base_overrides = {
    'sanity': None,
    'episode_length': 16,
    'model_type': 'roberta',
    'text_key': 'syms',
    'time_key': 'hours',
    'suffix': '',
    'token_max_length': 32,
    'use_random_windows': False,
    'mask_bpe_percentage': 0,
    'pin_memory': False,
    'num_workers': 4,
}
for option, setting in base_overrides.items():
    setattr(params, option, setting)

In [21]:
# Checkpoint directories; both live under the same data root.
MODEL_ROOT = "/mnt/swordfish-pool2/milad/Multi-LUAR/data/"
#MULTI_LUAR_PATH = MODEL_ROOT + "multi-luar-reddit-model/"
MULTI_LUAR_PATH = MODEL_ROOT + "multi-luar-all-model/"
LUAR_PATH = MODEL_ROOT + "reproduced-luar/"

In [29]:
# Load the Multi-LUAR and LUAR checkpoints plus the tokenizer.
# trust_remote_code=True lets transformers run the custom model code shipped with each checkpoint.
pretrained_options = {'trust_remote_code': True}
multiluar_model = AutoModel.from_pretrained(MULTI_LUAR_PATH, **pretrained_options)
luar_model = AutoModel.from_pretrained(LUAR_PATH, **pretrained_options)
tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD")

In [47]:
# The following two functions are taken from src/models/lightning_trainer.py
def validation_collate_fn(batch):
    """Collate ``(data, author)`` samples whose histories may differ in length.

    Every feature tensor is cut down to the shortest history in the batch
    (measured on dimension 1 of each sample's first feature) before stacking,
    so ``torch.stack`` is never handed mismatched shapes.

    Args:
        batch: sequence of ``(data, author)`` pairs, where ``data`` is a
            sequence of feature tensors and ``author`` is a tensor.
            Assumes dimension 1 of every feature indexes the author's posts
            -- TODO confirm against the dataset classes.

    Returns:
        ``(data, author)``: ``data`` is a list with one stacked tensor per
        feature, ``author`` is the stacked author tensor.
    """
    data, author = zip(*batch)

    author = torch.stack(author)

    # Shortest history (number of posts) among the authors in this batch.
    min_posts = min(d[0].shape[1] for d in data)

    # Truncate unconditionally instead of only when min_posts is below a
    # hard-coded episode length of 16: the slice is a no-op when all histories
    # already have the same length, and it prevents a stacking error when
    # histories are long but unequal.
    data = [torch.stack([f[:, :min_posts] for f in feature])
            for feature in zip(*data)]

    return data, author

def test_dataloader(params):
    """Build the query and target DataLoaders for the 'test' split.

    Args:
        params: settings object; reads ``dataset_name``, ``pin_memory``,
            ``num_workers`` and, for datasets not listed below, ``batch_size``.

    Returns:
        ``(q_data_loader, t_data_loader, queries, targets)``: the two loaders
        followed by the datasets they wrap.
    """
    # These datasets are loaded one item per batch to counteract different
    # episode sizes during validation / testing.
    single_item_datasets = ("raw_amazon", "pan_paragraph", "hrs")
    if params.dataset_name in single_item_datasets:
        batch_size = 1
    else:
        batch_size = params.batch_size

    queries, targets = utils.get_val_or_test_dataset(params, 'test', only_queries=False)

    # Both loaders share every option; only the wrapped dataset differs.
    loader_options = dict(
        batch_size=batch_size,
        shuffle=False,
        pin_memory=params.pin_memory,
        num_workers=params.num_workers,
        collate_fn=validation_collate_fn,
    )
    q_data_loader = DataLoader(queries, **loader_options)
    t_data_loader = DataLoader(targets, **loader_options)

    return q_data_loader, t_data_loader, queries, targets

In [48]:
def get_embeddings(test_data_loader):
    """Run LUAR and Multi-LUAR over every batch of ``test_data_loader``.

    Uses the notebook-level ``luar_model`` and ``multiluar_model`` with
    gradient tracking disabled.

    Args:
        test_data_loader: iterable yielding batches whose first item is the
            feature list and whose second item is the author tensor.
            Dimension 0 of the first two features is squeezed before the
            models are called, so this presumably expects batch_size == 1
            -- TODO confirm.

    Returns:
        ``(luar_embeddings, mluar_embeddings, author_labels)``: the stacked
        LUAR outputs, the stacked Multi-LUAR outputs (rearranged from
        ``l b d`` to ``b l d`` per batch, then dimension 1 squeezed), and the
        list of per-batch author tensors.
    """
    luar_embeddings = []
    mluar_embeddings = []
    author_labels = []
    with torch.no_grad():
        # Wrap the loader itself (not enumerate(loader)) so tqdm can read its
        # length and display a total and ETA.
        for batch in tqdm(test_data_loader):
            data, author = batch[0], batch[1]
            model_inputs = (data[0].squeeze(0), data[1].squeeze(0))

            luar_embeddings.append(luar_model(*model_inputs))

            mluar_embedding = multiluar_model(*model_inputs)
            mluar_embedding = rearrange(mluar_embedding, 'l b d -> b l d')
            mluar_embeddings.append(mluar_embedding)

            author_labels.append(author)
        luar_embeddings = torch.stack(luar_embeddings)
        mluar_embeddings = torch.stack(mluar_embeddings).squeeze(1)

    return luar_embeddings, mluar_embeddings, author_labels

def evaluate_embeddings(q_embed, t_embed, q_mluar_embed, t_mluar_embed,  q_labels, t_labels, domain = 'pan'):
    """Score LUAR and Multi-LUAR embeddings and return rows for a results table.

    Args:
        q_embed, t_embed: LUAR embeddings of the queries and targets.
        q_mluar_embed, t_mluar_embed: Multi-LUAR embeddings of the queries
            and targets; similarities are computed for layers 0..6.
        q_labels, t_labels: author labels aligned with the embeddings.
        domain: value written into the second field of every result row.

    Returns:
        List of ``[name, domain, eer, mrr]`` rows: LUAR, a separator row,
        Multi-LUAR averaged over all seven layers, another separator row,
        then one row per ablated layer subset.
    """
    luar_sims = compute_similarities(q_embed, t_embed)
    # labels_matrix[i][j] is 1 when query i and target j carry the same label.
    labels_matrix = np.array([[int(q == t) for t in t_labels] for q in q_labels])

    def _score(sims):
        # EER is computed before MRR, matching the original evaluation order.
        return eer(sims, labels_matrix), compute_mrr(sims, q_labels, t_labels)

    rows = [['LUAR', domain, *_score(luar_sims)], ['+++'] * 4]

    # One similarity matrix per Multi-LUAR layer, stacked along axis 0.
    per_layer_sims = np.stack([
        compute_similarities(q_mluar_embed, t_mluar_embed, layer=layer)
        for layer in range(7)
    ])

    rows.append(['MLUAR', domain, *_score(np.mean(per_layer_sims, 0))])
    rows.append(['--'] * 4)

    # Ablation study: average the similarities of a subset of layers only.
    for selector in ([0,1], [0,1,2], [3,4,5], [5,6], [6]):
        subset_sims = np.mean(per_layer_sims[selector, :, :], 0)
        rows.append(['MLUAR-{}'.format(selector), domain, *_score(subset_sims)])

    return rows

### Evaluating on HRS

In [54]:
# Overrides specific to the HRS evaluation samples.
hrs_overrides = {
    'sanity': None,
    'dataset_name': 'hrs',
    'q_author_clm_name': 'authorIDs',
    'c_author_clm_name': 'authorSetIDs',
    'token_max_length': 32,
    'text_key': 'fullText',
}
for option, setting in hrs_overrides.items():
    setattr(params, option, setting)

In [55]:
# All HRS1 paths and file names are derived from the release and sample names.
hrs_release = 'HRS1_english_long'
hrs_sample = f'{hrs_release}_sample-0_crossGenre'

params.dataset_path = f'/mnt/swordfish-pool2/milad/hiatus-data/HRS_evaluation_samples/{hrs_release}/TA2/{hrs_sample}/'
params.queries_file_name = f'data/{hrs_sample}_TA2_input_queries.jsonl'
params.gt_path = f'groundtruth/{hrs_sample}_TA2'
params.candidates_file_name = f'data/{hrs_sample}_TA2_input_candidates.jsonl'

hrs_q_loader, hrs_t_loader, hrs_q_ds, hrs_t_ds = test_dataloader(params)

map d3175716-caaf-654e-8fa1-e75a4af557ec --> 550d7750-f0fa-57dc-95d7-ef61b6fc096b
map d6b91159-3354-f7af-aa90-7b458f173bfb --> a0fecc59-2fef-572e-a698-3c67a3ebc212
map e4c148e3-adef-8ba6-7bc1-99d815cfebb0 --> 61224409-9987-53ea-bb5e-4f52e719fbe3
map 0884230b-c3a1-5d7d-36cb-4eac079587fc --> 828fe9e6-a97d-58ea-bb02-e28fb580cf54
map 12739ffb-fd89-e8ca-ab05-b341d33b1e34 --> 2bd046b3-7c09-52c4-968c-b19750e295a1
map 855ffe07-828a-d746-c165-aa61ee4ca5bc --> 61224409-9987-53ea-bb5e-4f52e719fbe3
map 2aaaee8a-72bb-8100-5e21-10cc7bc3a96a --> 828fe9e6-a97d-58ea-bb02-e28fb580cf54
map 4d2d8527-e028-2686-e44d-900eeaa77d06 --> 828fe9e6-a97d-58ea-bb02-e28fb580cf54
map 6afdb97a-d5df-e768-7cf9-9f6ddbb5d6f3 --> ccf55b4e-5751-5279-9088-c742db36d3a1
map a213116e-1e14-4762-3ab6-14e761b372b4 --> 2e4f14ac-f53c-5eba-b8f6-6ca245948d49
map dd4b6dc5-b888-eb29-6653-26b499708d0b --> 966baafa-ee5b-526c-9aed-7ed331eb0a52
map a070dc13-5e22-b546-1b9e-d351f7e2eacb --> 8149c230-ee1e-5628-98a6-9510a6bbae6d
map 7c678fb5-146

In [56]:
# Embed the HRS queries, then the HRS candidates, with both models.
q_luar_embeddings, q_mluar_embeddings, q_author_labels = get_embeddings(
    test_data_loader=hrs_q_loader
)
t_luar_embeddings, t_mluar_embeddings, t_author_labels = get_embeddings(
    test_data_loader=hrs_t_loader
)

42it [00:12,  3.35it/s]
64it [00:15,  4.10it/s]


In [57]:
# Translate the integer author labels through each dataset's int2AuthorId mapping.
# NOTE: the label lists are rebound in place, so re-running only this cell would
# apply the mapping to labels that are already translated.
q_id_lookup = hrs_q_ds.int2AuthorId
q_author_labels = [q_id_lookup[label.item()] for label in q_author_labels]
t_id_lookup = hrs_t_ds.int2AuthorId
t_author_labels = [t_id_lookup[label.item()] for label in t_author_labels]

In [58]:
# Score both models on the HRS1 sample.
results = evaluate_embeddings(
    q_embed=q_luar_embeddings,
    t_embed=t_luar_embeddings,
    q_mluar_embed=q_mluar_embeddings,
    t_mluar_embed=t_mluar_embeddings,
    q_labels=q_author_labels,
    t_labels=t_author_labels,
    domain='hrs',
)

In [60]:
# Note: these results are not on the full haystack -- at line 102 of hrs_dataset.py
# only the candidate authors that are matched with the query authors are kept.
# Each result row has four fields (model, domain, EER, MRR), so all four columns get a header.
print(tabulate.tabulate(results, headers=['Model', 'Domain', 'EER', 'MRR']))

                 Domain    EER    MRR
---------------  --------  -----  -----
LUAR             hrs       0.375  0.135
+++              +++       +++    +++
MLUAR            hrs       0.453  0.112
--               --        --     --
MLUAR-[0, 1]     hrs       0.471  0.101
MLUAR-[0, 1, 2]  hrs       0.466  0.109
MLUAR-[3, 4, 5]  hrs       0.422  0.13
MLUAR-[5, 6]     hrs       0.438  0.126
MLUAR-[6]        hrs       0.436  0.128


In [61]:
# All HRS2 paths and file names are derived from the release and sample names.
hrs_release = 'HRS2_english_medium'
hrs_sample = f'{hrs_release}_sample-0_crossGenre'

params.dataset_path = f'/mnt/swordfish-pool2/milad/hiatus-data/HRS_evaluation_samples/{hrs_release}/TA2/{hrs_sample}/'
params.queries_file_name = f'data/{hrs_sample}_TA2_input_queries.jsonl'
params.gt_path = f'groundtruth/{hrs_sample}_TA2'
params.candidates_file_name = f'data/{hrs_sample}_TA2_input_candidates.jsonl'

hrs_q_loader, hrs_t_loader, hrs_q_ds, hrs_t_ds = test_dataloader(params)

map 4f94add8-5b2f-5cb1-8ad2-008772603e94 --> 1c05f9d5-16fa-50e6-85f6-a81da55766fc
map e74bdae6-b163-a908-0e98-d8930d066717 --> 1c05f9d5-16fa-50e6-85f6-a81da55766fc
map 2a0f608b-d5e8-8f6e-7bd9-a4380567644e --> 3353a8da-edac-5584-bcac-c8607954d040
map 1757e209-3119-b197-3763-22f5d66be1b2 --> 1c05f9d5-16fa-50e6-85f6-a81da55766fc
map 76f68fbb-4c39-6292-3156-1b55fe5beb55 --> 3353a8da-edac-5584-bcac-c8607954d040
map 123d54e9-4246-c538-1a24-993f24179cdd --> dfa7d505-10c2-54a7-bfd5-5e0b32e8711e
map ece9374f-5cfd-358e-fa5c-2c6a1af63cf4 --> 1c05f9d5-16fa-50e6-85f6-a81da55766fc
map 01075158-d1da-c3af-a71e-681385ba46ca --> 3353a8da-edac-5584-bcac-c8607954d040
map c5ee9c03-2aa0-f69d-523b-04d36902e6a9 --> 8568f101-9281-5aad-9a09-ddb08535e583
map ad61e2d0-1eff-2161-7ad2-f73c46f62946 --> 1f4e6a77-857b-5774-a2d1-50b1bc12af60
map fd8ee89f-ce89-e5bc-e3d2-0090bc9f9b53 --> 8568f101-9281-5aad-9a09-ddb08535e583
map 2dc9a9e3-dfe0-1c9d-b7fe-6dc4db06868c --> 1f4e6a77-857b-5774-a2d1-50b1bc12af60
map 8fe09888-ffd

In [62]:
# Embed the HRS2 queries and candidates with both models.
q_luar_embeddings, q_mluar_embeddings, q_author_labels = get_embeddings(
    test_data_loader=hrs_q_loader
)
t_luar_embeddings, t_mluar_embeddings, t_author_labels = get_embeddings(
    test_data_loader=hrs_t_loader
)

# Translate the integer author labels through each dataset's int2AuthorId mapping.
q_id_lookup = hrs_q_ds.int2AuthorId
q_author_labels = [q_id_lookup[label.item()] for label in q_author_labels]
t_id_lookup = hrs_t_ds.int2AuthorId
t_author_labels = [t_id_lookup[label.item()] for label in t_author_labels]

21it [00:04,  4.36it/s]
57it [00:09,  5.86it/s]


In [64]:
# Note: these results are not on the full haystack -- at line 102 of hrs_dataset.py
# only the candidate authors that are matched with the query authors are kept.
results = evaluate_embeddings(q_luar_embeddings, t_luar_embeddings, q_mluar_embeddings, t_mluar_embeddings, q_author_labels, t_author_labels, domain = 'hrs')
# Each result row has four fields (model, domain, EER, MRR), so all four columns get a header.
print(tabulate.tabulate(results, headers=['Model', 'Domain', 'EER', 'MRR']))

                 Domain    EER    MRR
---------------  --------  -----  -----
LUAR             hrs       0.386  0.289
+++              +++       +++    +++
MLUAR            hrs       0.407  0.199
--               --        --     --
MLUAR-[0, 1]     hrs       0.421  0.144
MLUAR-[0, 1, 2]  hrs       0.421  0.153
MLUAR-[3, 4, 5]  hrs       0.439  0.233
MLUAR-[5, 6]     hrs       0.419  0.231
MLUAR-[6]        hrs       0.405  0.245


### Evaluating on Fanfiction:

In [53]:
# Start again from a fresh settings object so the HRS-specific overrides do not carry over.
params = create_argument_parser()
base_overrides = {
    'sanity': None,  # commented-out alternative in the original cell: 100
    'episode_length': 16,
    'model_type': 'roberta',
    'text_key': 'syms',
    'time_key': 'hours',
    'suffix': '',
    'token_max_length': 32,
    'use_random_windows': False,
    'mask_bpe_percentage': 0,
    'pin_memory': False,
    'num_workers': 4,
}
for option, setting in base_overrides.items():
    setattr(params, option, setting)

In [50]:
params.dataset_name = 'pan_paragraph'
# Bind the PAN datasets to their own names instead of overwriting hrs_q_ds / hrs_t_ds
# from the HRS section with unrelated data.
pan_paragraph_q_loader, pan_paragraph_t_loader, pan_paragraph_q_ds, pan_paragraph_t_ds = test_dataloader(params)
q_luar_embeddings, q_mluar_embeddings, q_author_labels = get_embeddings(pan_paragraph_q_loader)
t_luar_embeddings, t_mluar_embeddings, t_author_labels = get_embeddings(pan_paragraph_t_loader)

# Unwrap each single-element author tensor into a plain Python number
# (no int2AuthorId translation is applied here, unlike the HRS cells).
q_author_labels = [x.item() for x in q_author_labels]
t_author_labels = [x.item() for x in t_author_labels]

Loading pan_paragraph dataset test query file: /mnt/swordfish-pool2/nikhil/pan_paragraph/train_raw.jsonl
Loading pan_paragraph dataset test targets file: /mnt/swordfish-pool2/nikhil/pan_paragraph/train_raw.jsonl


10it [00:18,  1.80s/it]
10it [00:18,  1.85s/it]


In [51]:
# Score both models on the pan_paragraph data.
results = evaluate_embeddings(
    q_embed=q_luar_embeddings,
    t_embed=t_luar_embeddings,
    q_mluar_embed=q_mluar_embeddings,
    t_mluar_embed=t_mluar_embeddings,
    q_labels=q_author_labels,
    t_labels=t_author_labels,
    domain='pan',
)

In [99]:
# Each result row has four fields (model, domain, EER, MRR), so all four columns get a header.
print(tabulate.tabulate(results, headers=['Model', 'Domain', 'EER', 'MRR']))

                 Domain        EER    MRR
---------------  ------------  -----  -----
LUAR             pan           0.169  0.323
+++              +++           +++    +++
MLUAR            pan           0.105  0.463
--               --            --     --
MLUAR-[0, 1]     cross-domain  0.146  0.397
MLUAR-[0, 1, 2]  cross-domain  0.135  0.412
MLUAR-[3, 4, 5]  cross-domain  0.097  0.472
MLUAR-[5, 6]     cross-domain  0.094  0.476
MLUAR-[6]        cross-domain  0.094  0.474


### Evaluating on Amazon:

In [18]:
params.dataset_name = 'raw_amazon'
# test_dataloader returns (query loader, target loader, query dataset, target dataset);
# unpacking only two names raises ValueError, so bind all four.
amazon_paragraph_q_loader, amazon_paragraph_t_loader, amazon_paragraph_q_ds, amazon_paragraph_t_ds = test_dataloader(params)
q_luar_embeddings, q_mluar_embeddings, q_author_labels = get_embeddings(amazon_paragraph_q_loader)
t_luar_embeddings, t_mluar_embeddings, t_author_labels = get_embeddings(amazon_paragraph_t_loader)

# Unwrap each single-element author tensor into a plain Python number.
q_author_labels = [x.item() for x in q_author_labels]
t_author_labels = [x.item() for x in t_author_labels]

Loading raw_amazon dataset test query file: /mnt/swordfish-pool2/nikhil/raw_amazon/validation_queries.jsonl
Loading raw_amazon dataset test targets file: /mnt/swordfish-pool2/nikhil/raw_amazon/validation_targets.jsonl


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [19]:
# Score both models on the raw_amazon data.
results = evaluate_embeddings(
    q_embed=q_luar_embeddings,
    t_embed=t_luar_embeddings,
    q_mluar_embed=q_mluar_embeddings,
    t_mluar_embed=t_mluar_embeddings,
    q_labels=q_author_labels,
    t_labels=t_author_labels,
    domain='amazon',
)

In [20]:
# Each result row has four fields (model, domain, EER, MRR), so all four columns get a header.
print(tabulate.tabulate(results, headers=['Model', 'Domain', 'EER', 'MRR']))

       Domain    EER    MRR
-----  --------  -----  -----
LUAR   amazon    0.062  0.534
+++    +++       +++    +++
MLUAR  amazon    0.02   0.784
--     --        --     --
