In [1]:
%load_ext autoreload

In [2]:
%autoreload
import sys
sys.path.insert(0, '../')

In [3]:
from src.utilities.mluar_utils import *
from src.datasets import utils
from src.arguments import create_argument_parser

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset, Dataset
import numpy as np
from einops import rearrange, reduce, repeat
import torch
from sklearn.metrics.pairwise import cosine_similarity
from matplotlib import pyplot as plt
import math
import pandas as pd
from scipy.stats import zscore
import pickle as pkl
import tabulate
import torch.nn.functional as F
from torch.utils.data import DataLoader


In [5]:
params = create_argument_parser()
params.sanity = None
params.episode_length=16
params.model_type='roberta'
params.text_key = 'syms'
params.time_key='hours'
params.suffix=''
params.token_max_length=32
params.mask_bpe_percentage=0
params.pin_memory=False
params.num_workers=5

In [6]:
MULTI_LUAR_PATH =  "/mnt/swordfish-pool2/milad/multi-luar-reddit-model/"
LUAR_PATH =  "/mnt/swordfish-pool2/nikhil/LUAR/pretrained_weights/LUAR-MUD/"

In [7]:
# Load models
multiluar_model = AutoModel.from_pretrained(MULTI_LUAR_PATH, trust_remote_code=True)
luar_model = AutoModel.from_pretrained(LUAR_PATH, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("rrivera1849/LUAR-MUD")

In [117]:
def validation_collate_fn(batch):
    """Some validation datasets have authors with less than < 16 episodes. 
       When batching, make sure that we don't run into stacking problems. 
    """

    data, author = zip(*batch)

    author = torch.stack(author)

    # Minimum number of posts for an author history in batch
    min_posts = min([d[0].shape[1] for d in data])
    # If min_posts < episode length, need to subsample
    if min_posts < 16:
        data = [torch.stack([f[:, :min_posts, :] for f in feature])
                for feature in zip(*data)]
    # Otherwise, stack data as is
    else:
        data = [torch.stack([f for f in feature])
                for feature in zip(*data)]

    return data, author

def test_dataloader(params):
    """Returns the validation DataLoader.
    """
    # to counteract different episode sizes during validation / testing
    batch_size = 1 if params.dataset_name in ["raw_amazon", "pan_paragraph"] else params.batch_size
    
    queries = utils.get_val_or_test_dataset(params, 'test', only_queries=True)

    data_loaders = [
        DataLoader(
            queries,
            batch_size=batch_size,
            shuffle=False,
            pin_memory=params.pin_memory,
            num_workers=params.num_workers,
            #collate_fn=validation_collate_fn
            )
        ]

    return data_loaders

In [122]:
def get_embeddings(test_data_loader, num_samples=100000):
    luar_embeddings = []
    mluar_embeddings = []
    author_labels = []
    with torch.no_grad():
        for i, batch in enumerate(test_data_loader[0]):
            data, author = batch[0], batch[1]
            
            luar_embedding = luar_model(data[0].squeeze(0), data[1].squeeze(0))
            luar_embeddings.append(luar_embedding)
        
            mluar_embedding = multiluar_model(data[0].squeeze(0), data[1].squeeze(0))
            mluar_embedding = rearrange(mluar_embedding, 'l b d -> b l d')
            mluar_embeddings.append(mluar_embedding)
    
            author_labels.append(author)
            if i > num_samples:
                break
        luar_embeddings = torch.stack(luar_embeddings)
        mluar_embeddings = torch.stack(mluar_embeddings).squeeze(1)

    return luar_embeddings, mluar_embeddings, author_labels

In [None]:
params.dataset_name = 'pan_paragraph'
pan_paragraph_loader = test_dataloader(params)
luar_embeddings, mluar_embeddings, author_labels = get_embeddings(pan_paragraph_loader)

Loading pan_paragraph dataset test query file: /mnt/swordfish-pool2/nikhil/pan_paragraph/queries_raw.jsonl


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Av

In [121]:
len(author_labels)

12

In [97]:
results = []
domain = 'pan'

luar_sims = compute_similarities(luar_embeddings, luar_embeddings)
labels_matrix = np.array([[int(x == y) for y in author_labels] for x in author_labels])
luar_eer, luar_mrr   = eer(luar_sims, labels_matrix), compute_mrr(luar_sims, author_labels)

results.append(['LUAR', domain, luar_eer, luar_mrr])
results.append(['+++', '+++', '+++', '+++'])

muti_luar_layers_sims = np.stack([compute_similarities(mluar_embeddings, mluar_embeddings, layer=i) for i in range(7)])

muti_luar_layers_sims_ablated = np.mean(muti_luar_layers_sims, 0)
mluar_eer, mluar_mrr = eer(muti_luar_layers_sims_ablated, labels_matrix), compute_mrr(muti_luar_layers_sims_ablated, author_labels)
results.append(['MLUAR', domain, mluar_eer, mluar_mrr])
results.append(['--', '--', '--', '--'])

# Ablation study
for layer in range(7):
    selector = [i for i in range(muti_luar_layers_sims.shape[0]) if i != layer]
    print(selector)
    muti_luar_layers_sims_ablated = np.mean(muti_luar_layers_sims[selector, :, :], 0)
    mluar_eer, mluar_mrr = eer(muti_luar_layers_sims_ablated, labels_matrix), compute_mrr(muti_luar_layers_sims_ablated, author_labels)
    results.append(['MLUAR/{}'.format(layer), domain, mluar_eer, mluar_mrr])
    results.append(['--', '--', '--', '--'])

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


[1, 2, 3, 4, 5, 6]
[0, 2, 3, 4, 5, 6]
[0, 1, 3, 4, 5, 6]
[0, 1, 2, 4, 5, 6]
[0, 1, 2, 3, 5, 6]
[0, 1, 2, 3, 4, 6]
[0, 1, 2, 3, 4, 5]


  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)
  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [99]:
print(tabulate.tabulate(results, headers=['Domain', 'EER', 'MRR']))

         Domain    EER    MRR
-------  --------  -----  -----
LUAR     pan       0.0    nan
+++      +++       +++    +++
MLUAR    pan       0.0    nan
--       --        --     --
MLUAR/0  pan       0.0    nan
--       --        --     --
MLUAR/1  pan       0.0    nan
--       --        --     --
MLUAR/2  pan       0.0    nan
--       --        --     --
MLUAR/3  pan       0.0    nan
--       --        --     --
MLUAR/4  pan       0.0    nan
--       --        --     --
MLUAR/5  pan       0.0    nan
--       --        --     --
MLUAR/6  pan       0.0    nan
--       --        --     --
