In [2]:
%load_ext autoreload

In [3]:
%autoreload
import sys
import os
os.environ['TRANSFORMERS_CACHE'] = '/mnt/swordfish-pool2/milad/hf-cache'
os.environ['HF_DATASETS_CACHE'] = '/mnt/swordfish-pool2/milad/hf-cache'
sys.path.insert(0, '../')

In [4]:
%autoreload

from src.utilities.mluar_utils import *
from src.datasets import utils
from src.arguments import create_argument_parser
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset, Dataset
import numpy as np
from einops import rearrange, reduce, repeat
import torch
from sklearn.metrics.pairwise import cosine_similarity
from matplotlib import pyplot as plt
import math
import pandas as pd
from scipy.stats import zscore
import pickle as pkl
import tabulate
import torch.nn.functional as F
from torch.utils.data import DataLoader
torch.multiprocessing.set_sharing_strategy('file_system')

In [6]:
params = create_argument_parser()
params.sanity = None
params.episode_length=16
params.model_type='roberta'
params.text_key = 'syms'
params.time_key='hours'
params.suffix=''
params.token_max_length=32
params.use_random_windows=False
params.mask_bpe_percentage=0
params.pin_memory=False
params.num_workers=10

In [10]:
#MULTI_LUAR_PATH =  "/mnt/swordfish-pool2/milad/Multi-LUAR/data/multi-luar-reddit-model/"
MULTI_LUAR_PATH =  "/local/nlp/milad/code/Multi-LUAR/data/multi-luar-all-model/"
#LUAR_PATH       =  "/mnt/swordfish-pool2/milad/Multi-LUAR/data/reproduced-luar/"
LUAR_PATH       =  "/mnt/swordfish-pool2/milad/rrivera1849/"

In [11]:
# Load models
multiluar_model = AutoModel.from_pretrained(MULTI_LUAR_PATH, trust_remote_code=True)
luar_model = AutoModel.from_pretrained(LUAR_PATH, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD")

In [12]:
# The following two functions are taken from src/models/lightning_trainer.py
def validation_collate_fn(batch):
    """Some validation datasets have authors with less than < 16 episodes. 
       When batching, make sure that we don't run into stacking problems. 
    """

    data, author = zip(*batch)

    author = torch.stack(author)

    # Minimum number of posts for an author history in batch
    min_posts = min([d[0].shape[1] for d in data])
    # If min_posts < episode length, need to subsample
    if min_posts < 16:
        data = [torch.stack([f[:, :min_posts, :] for f in feature])
                for feature in zip(*data)]
    # Otherwise, stack data as is
    else:
        data = [torch.stack([f for f in feature])
                for feature in zip(*data)]

    return data, author

def test_dataloader(params):
    """Returns the validation DataLoader.
    """
    # to counteract different episode sizes during validation / testing
    batch_size = 1 if params.dataset_name in ["raw_amazon", "pan_paragraph", "hrs"] else params.batch_size
    
    queries, targets = utils.get_val_or_test_dataset(params, 'test', only_queries=False)

    q_data_loader = DataLoader(
        queries,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=params.pin_memory,
        num_workers=params.num_workers,
        collate_fn=validation_collate_fn
    )

    t_data_loader = DataLoader(
        targets,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=params.pin_memory,
        num_workers=params.num_workers,
        collate_fn=validation_collate_fn
    )

    return q_data_loader, t_data_loader, queries, targets

In [13]:
def get_embeddings(test_data_loader):
    luar_embeddings = []
    mluar_embeddings = []
    author_labels = []
    with torch.no_grad():
        for i, batch in tqdm(enumerate(test_data_loader)):
            data, author = batch[0], batch[1]
            luar_embedding = luar_model(data[0].squeeze(0), data[1].squeeze(0), document_batch_size=16)
            luar_embeddings.append(luar_embedding)
        
            mluar_embedding = multiluar_model(data[0].squeeze(0), data[1].squeeze(0), document_batch_size=16)
            mluar_embedding = rearrange(mluar_embedding, 'l b d -> b l d')
            mluar_embeddings.append(mluar_embedding)
    
            author_labels.append(author)

        luar_embeddings = torch.stack(luar_embeddings)
        mluar_embeddings = torch.stack(mluar_embeddings).squeeze(1)

    return luar_embeddings, mluar_embeddings, author_labels

def evaluate_embeddings(q_embed, t_embed, q_mluar_embed, t_mluar_embed,  q_labels, t_labels, domain = 'pan'):
    results = []
    
    luar_sims = compute_similarities(q_embed, t_embed)
    labels_matrix = np.array([[int(x == y) for y in t_labels] for x in q_labels])
    luar_eer, luar_mrr   = eer(luar_sims, labels_matrix), compute_mrr(luar_sims, q_labels, t_labels)
    
    results.append(['LUAR', domain, luar_eer, luar_mrr])
    results.append(['+++', '+++', '+++', '+++'])
    
    muti_luar_layers_sims = np.stack([compute_similarities(q_mluar_embed, t_mluar_embed, layer=i) for i in range(7)])
    
    muti_luar_layers_sims_ablated = np.mean(muti_luar_layers_sims, 0)
    mluar_eer, mluar_mrr = eer(muti_luar_layers_sims_ablated, labels_matrix), compute_mrr(muti_luar_layers_sims_ablated,  q_labels, t_labels)    
    results.append(['MLUAR', domain, mluar_eer, mluar_mrr])
    results.append(['--', '--', '--', '--'])
    
    # Ablation study
    for selector in [[0], [1], [2], [3], [0,1], [0,1,2],[3,4,5], [5,6],[6]]:
        muti_luar_layers_sims_ablated = np.mean(muti_luar_layers_sims[selector, :, :], 0)
        mluar_eer, mluar_mrr = eer(muti_luar_layers_sims_ablated, labels_matrix), compute_mrr(muti_luar_layers_sims_ablated, q_labels, t_labels)
        results.append(['MLUAR-{}'.format(selector), domain, mluar_eer, mluar_mrr])

    return results

### Evaluating on HRS

In [14]:
params.sanity = None
params.dataset_name = 'hrs'
params.q_author_clm_name='authorIDs'
params.c_author_clm_name='authorSetIDs'
params.token_max_length=16
params.text_key = 'fullText'

In [92]:
params.dataset_path = '/mnt/swordfish-pool2/milad/hiatus-data/HRS_evaluation_samples/HRS1_english_long/TA2/HRS1_english_long_sample-0_crossGenre/'
params.queries_file_name = 'data/HRS1_english_long_sample-0_crossGenre_TA2_input_queries.jsonl'
params.gt_path = 'groundtruth/HRS1_english_long_sample-0_crossGenre_TA2'
params.candidates_file_name = 'data/HRS1_english_long_sample-0_crossGenre_TA2_input_candidates.jsonl'

hrs_q_loader, hrs_t_loader, hrs_q_ds, hrs_t_ds = test_dataloader(params)

map d3175716-caaf-654e-8fa1-e75a4af557ec --> 550d7750-f0fa-57dc-95d7-ef61b6fc096b
map d6b91159-3354-f7af-aa90-7b458f173bfb --> a0fecc59-2fef-572e-a698-3c67a3ebc212
map e4c148e3-adef-8ba6-7bc1-99d815cfebb0 --> 61224409-9987-53ea-bb5e-4f52e719fbe3
map 0884230b-c3a1-5d7d-36cb-4eac079587fc --> 828fe9e6-a97d-58ea-bb02-e28fb580cf54
map 12739ffb-fd89-e8ca-ab05-b341d33b1e34 --> 2bd046b3-7c09-52c4-968c-b19750e295a1
map 855ffe07-828a-d746-c165-aa61ee4ca5bc --> 61224409-9987-53ea-bb5e-4f52e719fbe3
map 2aaaee8a-72bb-8100-5e21-10cc7bc3a96a --> 828fe9e6-a97d-58ea-bb02-e28fb580cf54
map 4d2d8527-e028-2686-e44d-900eeaa77d06 --> 828fe9e6-a97d-58ea-bb02-e28fb580cf54
map 6afdb97a-d5df-e768-7cf9-9f6ddbb5d6f3 --> ccf55b4e-5751-5279-9088-c742db36d3a1
map a213116e-1e14-4762-3ab6-14e761b372b4 --> 2e4f14ac-f53c-5eba-b8f6-6ca245948d49
map dd4b6dc5-b888-eb29-6653-26b499708d0b --> 966baafa-ee5b-526c-9aed-7ed331eb0a52
map a070dc13-5e22-b546-1b9e-d351f7e2eacb --> 8149c230-ee1e-5628-98a6-9510a6bbae6d
map 7c678fb5-146

In [94]:
q_luar_embeddings, q_mluar_embeddings, q_author_labels = get_embeddings(hrs_q_loader)
t_luar_embeddings, t_mluar_embeddings, t_author_labels = get_embeddings(hrs_t_loader)

q_author_labels = [hrs_q_ds.int2AuthorId[x.item()] for x in q_author_labels]
t_author_labels = [hrs_t_ds.int2AuthorId[x.item()] for x in t_author_labels]

In [None]:
results = evaluate_embeddings(q_luar_embeddings, t_luar_embeddings, q_mluar_embeddings, t_mluar_embeddings, q_author_labels, t_author_labels, domain = 'hrs')
print(tabulate.tabulate(results, headers=['Domain', 'EER', 'MRR']))

=======

In [146]:
params.dataset_path = '/mnt/swordfish-pool2/milad/hiatus-data/HRS_evaluation_samples/HRS2_english_medium/TA2/HRS2_english_medium_sample-0_crossGenre/'
params.queries_file_name = 'data/HRS2_english_medium_sample-0_crossGenre_TA2_input_queries.jsonl'
params.gt_path = 'groundtruth/HRS2_english_medium_sample-0_crossGenre_TA2'
params.candidates_file_name = 'data/HRS2_english_medium_sample-0_crossGenre_TA2_input_candidates.jsonl'

hrs_q_loader, hrs_t_loader, hrs_q_ds, hrs_t_ds = test_dataloader(params)

map 4f94add8-5b2f-5cb1-8ad2-008772603e94 --> 1c05f9d5-16fa-50e6-85f6-a81da55766fc
map e74bdae6-b163-a908-0e98-d8930d066717 --> 1c05f9d5-16fa-50e6-85f6-a81da55766fc
map 2a0f608b-d5e8-8f6e-7bd9-a4380567644e --> 3353a8da-edac-5584-bcac-c8607954d040
map 1757e209-3119-b197-3763-22f5d66be1b2 --> 1c05f9d5-16fa-50e6-85f6-a81da55766fc
map 76f68fbb-4c39-6292-3156-1b55fe5beb55 --> 3353a8da-edac-5584-bcac-c8607954d040
map 123d54e9-4246-c538-1a24-993f24179cdd --> dfa7d505-10c2-54a7-bfd5-5e0b32e8711e
map ece9374f-5cfd-358e-fa5c-2c6a1af63cf4 --> 1c05f9d5-16fa-50e6-85f6-a81da55766fc
map 01075158-d1da-c3af-a71e-681385ba46ca --> 3353a8da-edac-5584-bcac-c8607954d040
map c5ee9c03-2aa0-f69d-523b-04d36902e6a9 --> 8568f101-9281-5aad-9a09-ddb08535e583
map ad61e2d0-1eff-2161-7ad2-f73c46f62946 --> 1f4e6a77-857b-5774-a2d1-50b1bc12af60
map fd8ee89f-ce89-e5bc-e3d2-0090bc9f9b53 --> 8568f101-9281-5aad-9a09-ddb08535e583
map 2dc9a9e3-dfe0-1c9d-b7fe-6dc4db06868c --> 1f4e6a77-857b-5774-a2d1-50b1bc12af60
map 8fe09888-ffd

In [147]:
q_luar_embeddings, q_mluar_embeddings, q_author_labels = get_embeddings(hrs_q_loader)
t_luar_embeddings, t_mluar_embeddings, t_author_labels = get_embeddings(hrs_t_loader)

q_author_labels = [hrs_q_ds.int2AuthorId[x.item()] for x in q_author_labels]
t_author_labels = [hrs_t_ds.int2AuthorId[x.item()] for x in t_author_labels]

21it [00:06,  3.39it/s]
57it [00:10,  5.35it/s]


In [150]:
#Note that these results are not on the full haystack (In line 102 in the hrs_dataset.py, I only kept the candidate authors that are matched with the query authors)
results = evaluate_embeddings(q_luar_embeddings, t_luar_embeddings, q_mluar_embeddings, t_mluar_embeddings, q_author_labels, t_author_labels, domain = 'hrs')
print(tabulate.tabulate(results, headers=['Domain', 'EER', 'MRR']))

                 Domain    EER    MRR
---------------  --------  -----  -----
LUAR             hrs       0.404  0.32
+++              +++       +++    +++
MLUAR            hrs       0.407  0.199
--               --        --     --
MLUAR-[0]        hrs       0.455  0.108
MLUAR-[1]        hrs       0.404  0.245
MLUAR-[2]        hrs       0.404  0.241
MLUAR-[3]        hrs       0.437  0.267
MLUAR-[0, 1]     hrs       0.421  0.144
MLUAR-[0, 1, 2]  hrs       0.421  0.153
MLUAR-[3, 4, 5]  hrs       0.439  0.233
MLUAR-[5, 6]     hrs       0.419  0.231
MLUAR-[6]        hrs       0.405  0.245


=======

In [95]:
params.dataset_path = '/mnt/swordfish-pool2/milad/hiatus-data/V2/english_TA2_p1_and_p2_dev_20240207/'
params.queries_file_name = 'hrs_release_08-14-23_crossGenre-combined_TA2_P1_and_P2_input_queries.jsonl'
params.gt_path = 'hrs_release_08-14-23_crossGenre-combined_TA2_P1_and_P2'
params.candidates_file_name = 'hrs_release_08-14-23_crossGenre-combined_TA2_P1_and_P2_input_candidates.jsonl'

hrs_q_loader, hrs_t_loader, hrs_q_ds, hrs_t_ds = test_dataloader(params)

map e29e5eba-43e0-6e7e-d28e-6ded98b0735b --> 27faf46e-44e6-5ade-a7f0-9ea7394d4072
map 05d52e3b-1e28-5148-1fdf-1ab8c520c2f2 --> 38465e8c-f8c5-58bf-b290-73e6249939ff
map adfc6c98-3c69-53e1-eba4-7be37703ac87 --> 37ae2263-e50b-5aee-908d-cbe1b5dda6ec
map 6dd80c92-5a53-25f0-939b-8edf8c122c21 --> 24bb54e1-a0ac-5b4c-88a6-b3b49ea2242d
map 19aef804-b2ac-6671-e273-50caac32c336 --> fa58196e-88fc-5a3d-a3cc-37d2ebc43201
map 98976f54-12c2-f4f2-6392-afef6b1088e4 --> e52cac2e-d5b5-5bbe-a20c-17ad70f028f4
map 82c87b59-62ff-3b0a-b587-0ab4b62a7479 --> 0ee01d46-f3a9-5b8f-9062-b054fdbbbc44
map fb122917-6d0f-2f9c-f211-6ac6efb40021 --> 354aaddf-17bc-5256-bd64-9ad6c6b3608d
map 618a3618-3b8f-1d8c-cbd7-4fa898fda4f8 --> 2ded40a0-2072-5819-96fb-a6d755ffb83d
map b5b852b6-edf7-a071-e0dc-80c8b028524d --> ef9f5dad-14b4-5599-bdde-00dd82069278
map d5a9dd3e-2f0f-f30c-96ce-4ce3d1d700b5 --> 71f49eaf-fb11-5531-9abe-4682907cdfd9
map ef4cf8f0-0181-c0ec-909b-db9d24f28b59 --> 7e73d5d4-1990-5c1c-8526-7ba7bdd8bf5b
map 6a52882a-03d

In [96]:
q_luar_embeddings, q_mluar_embeddings, q_author_labels = get_embeddings(hrs_q_loader)
t_luar_embeddings, t_mluar_embeddings, t_author_labels = get_embeddings(hrs_t_loader)

q_author_labels = [hrs_q_ds.int2AuthorId[x.item()] for x in q_author_labels]
t_author_labels = [hrs_t_ds.int2AuthorId[x.item()] for x in t_author_labels]

150it [01:02,  2.40it/s]
2133it [16:24,  2.17it/s]


In [97]:
print(len(q_author_labels), '-->' , len(t_author_labels))

150 --> 2133


In [98]:
#Note that these results are not on the full haystack (In line 102 in the hrs_dataset.py, I only kept the candidate authors that are matched with the query authors)
results = evaluate_embeddings(q_luar_embeddings, t_luar_embeddings, q_mluar_embeddings, t_mluar_embeddings, q_author_labels, t_author_labels, domain = 'hrs')
print(tabulate.tabulate(results, headers=['Domain', 'EER', 'MRR']))

                 Domain    EER    MRR
---------------  --------  -----  -----
LUAR             hrs       0.268  0.122
+++              +++       +++    +++
MLUAR            hrs       0.309  0.075
--               --        --     --
MLUAR-[0]        hrs       0.337  0.042
MLUAR-[1]        hrs       0.32   0.058
MLUAR-[2]        hrs       0.309  0.075
MLUAR-[3]        hrs       0.308  0.086
MLUAR-[0, 1]     hrs       0.337  0.053
MLUAR-[0, 1, 2]  hrs       0.325  0.06
MLUAR-[3, 4, 5]  hrs       0.307  0.088
MLUAR-[5, 6]     hrs       0.309  0.085
MLUAR-[6]        hrs       0.308  0.085


In [74]:
#Note that these results are not on the full haystack (In line 102 in the hrs_dataset.py, I only kept the candidate authors that are matched with the query authors)
results = evaluate_embeddings(q_luar_embeddings, t_luar_embeddings, q_mluar_embeddings, t_mluar_embeddings, q_author_labels, t_author_labels, domain = 'hrs')
print(tabulate.tabulate(results, headers=['Domain', 'EER', 'MRR']))

                 Domain    EER    MRR
---------------  --------  -----  -----
LUAR             hrs       0.218  0.234
+++              +++       +++    +++
MLUAR            hrs       0.284  0.075
--               --        --     --
MLUAR-[0]        hrs       0.367  0.044
MLUAR-[1]        hrs       0.302  0.063
MLUAR-[2]        hrs       0.29   0.066
MLUAR-[3]        hrs       0.278  0.076
MLUAR-[0, 1]     hrs       0.336  0.057
MLUAR-[0, 1, 2]  hrs       0.314  0.064
MLUAR-[3, 4, 5]  hrs       0.268  0.08
MLUAR-[5, 6]     hrs       0.266  0.076
MLUAR-[6]        hrs       0.262  0.074


In [22]:
#Note that these results are not on the full haystack (In line 102 in the hrs_dataset.py, I only kept the candidate authors that are matched with the query authors)
results = evaluate_embeddings(q_luar_embeddings, t_luar_embeddings, q_mluar_embeddings, t_mluar_embeddings, q_author_labels, t_author_labels, domain = 'hrs')
print(tabulate.tabulate(results, headers=['Domain', 'EER', 'MRR']))

                 Domain    EER    MRR
---------------  --------  -----  -----
LUAR             hrs       0.231  0.187
+++              +++       +++    +++
MLUAR            hrs       0.279  0.125
--               --        --     --
MLUAR-[0]        hrs       0.331  0.057
MLUAR-[1]        hrs       0.291  0.09
MLUAR-[2]        hrs       0.273  0.114
MLUAR-[3]        hrs       0.272  0.121
MLUAR-[0, 1]     hrs       0.308  0.07
MLUAR-[0, 1, 2]  hrs       0.302  0.094
MLUAR-[3, 4, 5]  hrs       0.278  0.131
MLUAR-[5, 6]     hrs       0.266  0.128
MLUAR-[6]        hrs       0.266  0.127


#### Evaluating on Biber Dataset

In [20]:
params.dataset_path = '/mnt/swordfish-pool2/milad/hiatus-data/biber_data/qc_samples/TA2/biber_test_qc'
params.queries_file_name = 'data/biber_test_qc_TA2_input_queries.jsonl'
params.candidates_file_name = 'data/biber_test_qc_TA2_input_candidates.jsonl'
params.gt_path = 'groundtruth/biber_test_qc_TA2'

hrs_q_loader, hrs_t_loader, hrs_q_ds, hrs_t_ds = test_dataloader(params)

map 7a9a6130-0147-c2d7-3eb6-ec7644394e54 --> 1956622.0
map dcb9e331-cdfd-4c0a-b4bd-0f3fe8f7937f --> 1570566.0
map 780e51f0-8264-98d5-dba7-cd9cd34633fc --> 1956622.0
map cd4fc952-8a39-4607-3767-94d84d77079d --> 4001102.0
map 6621ca89-459e-1688-3cb5-9589d7f51b57 --> 9ac8c53d-0c65-af52-910e-87c32ddd6d40
map fc03e608-b0d0-6390-37dc-f811ca961a03 --> 1956622.0
map 5df0d301-ac00-17f3-5770-7f1a11e4259c --> 4118185.0
map 4f8878e1-d430-ae7e-7966-1b39abc4ccd8 --> 976482_Stackoverflow
map 41740ba9-ef91-ad10-d70f-0313beca13b8 --> 976482_Stackoverflow
map 3fd71027-a4f6-c013-4758-3bba126fd630 --> 22743_gaming.stackexchange
map cca64c5d-f2e8-0c4e-a65d-c54897056ef6 --> 4373948_Stackoverflow
map 5d899ec3-ac76-d60f-b042-6928a1a948bb --> 4001102.0
map 26447bcf-983e-ac1c-a938-bcac0737961e --> 3899443.0
map b95bccd7-4fa0-e5fb-5e71-bf7476c7bee7 --> 3899443.0
map 39c21b09-4809-1039-b9fa-727ab8e92792 --> 9ac8c53d-0c65-af52-910e-87c32ddd6d40
map 133c42c4-fbcf-4538-c7ef-a5abaa97e194 --> 1570566.0
map 8bd813db-d8

In [21]:
q_luar_embeddings, q_mluar_embeddings, q_author_labels = get_embeddings(hrs_q_loader)
t_luar_embeddings, t_mluar_embeddings, t_author_labels = get_embeddings(hrs_t_loader)

q_author_labels = [hrs_q_ds.int2AuthorId[x.item()] for x in q_author_labels]
t_author_labels = [hrs_t_ds.int2AuthorId[x.item()] for x in t_author_labels]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISMTOKENIZERS_PARALLELISM=(true | false)
=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
huggingface/tokenizers: The current process just

In [22]:
print(len(q_author_labels), '-->' , len(t_author_labels))

163 --> 11529


In [23]:
#Note that these results are not on the full haystack (In line 102 in the hrs_dataset.py, I only kept the candidate authors that are matched with the query authors)
results = evaluate_embeddings(q_luar_embeddings, t_luar_embeddings, q_mluar_embeddings, t_mluar_embeddings, q_author_labels, t_author_labels, domain = 'hrs')
print(tabulate.tabulate(results, headers=['Domain', 'EER', 'MRR']))

                 Domain    EER    MRR
---------------  --------  -----  -----
LUAR             hrs       0.184  0.34
+++              +++       +++    +++
MLUAR            hrs       0.203  0.312
--               --        --     --
MLUAR-[0]        hrs       0.26   0.232
MLUAR-[1]        hrs       0.195  0.272
MLUAR-[2]        hrs       0.201  0.306
MLUAR-[3]        hrs       0.219  0.326
MLUAR-[0, 1]     hrs       0.236  0.26
MLUAR-[0, 1, 2]  hrs       0.22   0.279
MLUAR-[3, 4, 5]  hrs       0.206  0.322
MLUAR-[5, 6]     hrs       0.208  0.309
MLUAR-[6]        hrs       0.206  0.309


In [15]:
params.dataset_path = '/mnt/swordfish-pool2/milad/hiatus-data/biber_data/qc_samples_gte250/TA2/biber_test_qc'
params.queries_file_name = 'data/biber_test_qc_TA2_input_queries.jsonl'
params.candidates_file_name = 'data/biber_test_qc_TA2_input_candidates.jsonl'
params.gt_path = 'groundtruth/biber_test_qc_TA2'

hrs_q_loader, hrs_t_loader, hrs_q_ds, hrs_t_ds = test_dataloader(params)

map a2b922b8-1584-6e15-5ab9-d825653ff83e --> d4742d41-baaf-3ccb-a5f4-b83df49866b9
map 8fb72018-44f4-b2b6-0e46-2841799c433f --> d4742d41-baaf-3ccb-a5f4-b83df49866b9
map 8b4780d0-e74a-bf9c-cc39-a467bab6ad71 --> 3567_scifi.stackexchange
map f709396c-1645-0eac-c74b-af317fed2911 --> 3567_scifi.stackexchange
map ad7a38aa-d132-9fd2-e31d-186022022135 --> 3567_scifi.stackexchange
map 0a77722a-9121-5085-aa1c-e5dbd025eb23 --> 167_islam.stackexchange
map 7b0f150b-8e18-dd84-530b-acb0a8f9af23 --> 167_islam.stackexchange
map 712d0a12-f310-3484-3f6b-4f329911340d --> d4742d41-baaf-3ccb-a5f4-b83df49866b9
map 847af5b4-7d39-6ed4-849f-edb0d07ec114 --> 4039232
map dc2de0e2-7960-7b91-7e8e-7f28839ffa5c --> faa77ead-7597-4fc9-a505-ce467cef5871
map 5b178957-22f4-3881-0660-a7a5acc353f8 --> 4039232
map e8f17a23-53ec-36a8-9662-0708b3076ef7 --> 19246_politics.stackexchange
map 809a099b-07a5-10a7-02dd-634688c4611f --> 19246_politics.stackexchange
map 38681f97-9ead-bf38-5462-24c3a6fe0ecd --> c76c3c50-b7df-2c78-2b37-7

In [16]:
q_luar_embeddings, q_mluar_embeddings, q_author_labels = get_embeddings(hrs_q_loader)
t_luar_embeddings, t_mluar_embeddings, t_author_labels = get_embeddings(hrs_t_loader)

q_author_labels = [hrs_q_ds.int2AuthorId[x.item()] for x in q_author_labels]
t_author_labels = [hrs_t_ds.int2AuthorId[x.item()] for x in t_author_labels]

print(len(q_author_labels), '-->' , len(t_author_labels))

#Note that these results are not on the full haystack (In line 102 in the hrs_dataset.py, I only kept the candidate authors that are matched with the query authors)
results = evaluate_embeddings(q_luar_embeddings, t_luar_embeddings, q_mluar_embeddings, t_mluar_embeddings, q_author_labels, t_author_labels, domain = 'hrs')
print(tabulate.tabulate(results, headers=['Domain', 'EER', 'MRR']))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

116 --> 14271
                 Domain    EER    MRR
---------------  --------  -----  -----
LUAR             hrs       0.115  0.407
+++              +++       +++    +++
MLUAR            hrs       0.141  0.411
--               --        --     --
MLUAR-[0]        hrs       0.174  0.353
MLUAR-[1]        hrs       0.141  0.391
MLUAR-[2]        hrs       0.136  0.399
MLUAR-[3]        hrs       0.132  0.408
MLUAR-[0, 1]     hrs       0.153  0.383
MLUAR-[0, 1, 2]  hrs       0.153  0.386
MLUAR-[3, 4, 5]  hrs       0.136  0.391
MLUAR-[5, 6]     hrs       0.136  0.373
MLUAR-[6]        hrs       0.14   0.368


### Evaluating on Fanfiction:

In [30]:
params = create_argument_parser()
params.sanity = None
#params.sanity = None
params.episode_length=16
params.model_type='roberta'
params.text_key = 'syms'
params.time_key='hours'
params.suffix=''
params.token_max_length=32
params.use_random_windows=False
params.mask_bpe_percentage=0
params.pin_memory=False
params.num_workers=10

In [31]:
params.dataset_name = 'pan_paragraph'
pan_paragraph_q_loader, pan_paragraph_t_loader, hrs_q_ds, hrs_t_ds = test_dataloader(params)
q_luar_embeddings, q_mluar_embeddings, q_author_labels = get_embeddings(pan_paragraph_q_loader)
t_luar_embeddings, t_mluar_embeddings, t_author_labels = get_embeddings(pan_paragraph_t_loader)

q_author_labels = [x.item() for x in q_author_labels]
t_author_labels = [x.item() for x in t_author_labels]

Loading pan_paragraph dataset test query file: /mnt/swordfish-pool2/nikhil/pan_paragraph/queries_raw.jsonl
Loading pan_paragraph dataset test targets file: /mnt/swordfish-pool2/nikhil/pan_paragraph/targets_raw.jsonl


8844it [21:57,  6.71it/s]
8844it [20:08,  7.32it/s]


In [None]:
results = evaluate_embeddings(q_luar_embeddings, t_luar_embeddings, q_mluar_embeddings, t_mluar_embeddings, q_author_labels, t_author_labels, domain = 'pan')

In [37]:
print(tabulate.tabulate(results, headers=['Domain', 'EER', 'MRR']))

                 Domain    EER    MRR
---------------  --------  -----  -----
LUAR             pan       0.22   0.12
+++              +++       +++    +++
MLUAR            pan       0.129  0.305
--               --        --     --
MLUAR-[0]        pan       0.202  0.166
MLUAR-[1]        pan       0.165  0.238
MLUAR-[2]        pan       0.146  0.27
MLUAR-[3]        pan       0.133  0.29
MLUAR-[0, 1]     pan       0.172  0.228
MLUAR-[0, 1, 2]  pan       0.157  0.257
MLUAR-[3, 4, 5]  pan       0.132  0.302
MLUAR-[5, 6]     pan       0.134  0.297
MLUAR-[6]        pan       0.136  0.29


In [99]:
print(tabulate.tabulate(results, headers=['Domain', 'EER', 'MRR']))

                 Domain        EER    MRR
---------------  ------------  -----  -----
LUAR             pan           0.169  0.323
+++              +++           +++    +++
MLUAR            pan           0.105  0.463
--               --            --     --
MLUAR-[0, 1]     cross-domain  0.146  0.397
MLUAR-[0, 1, 2]  cross-domain  0.135  0.412
MLUAR-[3, 4, 5]  cross-domain  0.097  0.472
MLUAR-[5, 6]     cross-domain  0.094  0.476
MLUAR-[6]        cross-domain  0.094  0.474


### Evaluating on Amazon:

In [18]:
params.dataset_name = 'raw_amazon'
amazon_paragraph_q_loader, amazon_paragraph_t_loader = test_dataloader(params)
q_luar_embeddings, q_mluar_embeddings, q_author_labels = get_embeddings(amazon_paragraph_q_loader)
t_luar_embeddings, t_mluar_embeddings, t_author_labels = get_embeddings(amazon_paragraph_t_loader)

q_author_labels = [x.item() for x in q_author_labels]
t_author_labels = [x.item() for x in t_author_labels]

Loading raw_amazon dataset test query file: /mnt/swordfish-pool2/nikhil/raw_amazon/validation_queries.jsonl
Loading raw_amazon dataset test targets file: /mnt/swordfish-pool2/nikhil/raw_amazon/validation_targets.jsonl


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [19]:
results = evaluate_embeddings(q_luar_embeddings, t_luar_embeddings, q_mluar_embeddings, t_mluar_embeddings, q_author_labels, t_author_labels, domain = 'amazon')

In [20]:
print(tabulate.tabulate(results, headers=['Domain', 'EER', 'MRR']))

       Domain    EER    MRR
-----  --------  -----  -----
LUAR   amazon    0.062  0.534
+++    +++       +++    +++
MLUAR  amazon    0.02   0.784
--     --        --     --
