# Setup

In [17]:
# Switch between the two Aspire variants used below:
# True  -> single-match model/helpers, False -> multi-match model + optimal-transport distance.
single_match = False

In [18]:
import os

# Debugging aid: ask CUDA to launch kernels synchronously. Set here, before torch is imported
# in the next cell, so the variable is already in the environment when CUDA is initialised.
os.environ.update({'CUDA_LAUNCH_BLOCKING': '1'})

In [19]:
from collections import namedtuple

import utils.envsetup
from data.pmc_iterable import PMCIterable
from data.topic_iterable import TopicIterable
import pickle
import torch
from torch.utils.data import DataLoader
from utils.doc_abstract_to_sentence_list_transform import DocAbstractToSentenceListTransform

from transformers import AutoTokenizer

if single_match:
    from utils.ex_aspire_consent import AspireConSent, prepare_abstracts
else:
    from utils.ex_aspire_consent_multimatch import AspireConSent, AllPairMaskedWasserstein
    from utils.ex_aspire_consent_multimatch import prepare_abstracts

import time
import tqdm
import torch.nn.functional as F

# topics = [x['ID'] for x in TopicIterable(format='aspire', topic_file_name='data/data_old_format/topics.pkl')]
# topics.sort()

# Use the first CUDA device when torch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
# device = torch.device('cpu')

In [20]:
# Pick the pre-trained Aspire checkpoint that matches the `single_match` flag.
if single_match:
    huggingface_model_name = 'allenai/aspire-contextualsentence-singlem-biomed' # single match
else:
    huggingface_model_name = 'allenai/aspire-contextualsentence-multim-biomed'  # multi match
    # Empty dict of hyper params will force class to use defaults.
    # `ot_distance` is only defined (and only used) in the multi-match configuration.
    ot_distance = AllPairMaskedWasserstein({}, device)

# NOTE(review): the cache directory is a hard-coded absolute path on one machine;
# consider making it configurable before sharing this notebook.
aspire_tokenizer = AutoTokenizer.from_pretrained(huggingface_model_name, cache_dir="/cs/labs/tomhope/taltatal/cache")
aspire_mv_model = AspireConSent(huggingface_model_name, device).to(device)
# The notebook only runs inference: switch to eval mode so that any dropout / normalisation
# layers inside the model behave deterministically.
aspire_mv_model.eval()


In [21]:
def apply_model(docs, tokenizer, model, device=torch.device("cpu")):
    """Tokenize a batch of abstracts and run them through ``model``.

    Args:
        docs: Batch of abstracts, forwarded to ``prepare_abstracts`` as ``batch_abs``.
        tokenizer: Tokenizer forwarded to ``prepare_abstracts`` as ``pt_lm_tokenizer``.
        model: Object whose ``forward(bert_batch=..., abs_lens=..., sent_tok_idxs=...)``
            returns a ``(clsreps, contextual_sent_reps)`` pair.
        device: Device the tensors of the prepared batch and ``abs_lens`` are moved to.
            Defaults to the CPU.

    Returns:
        A tuple ``(abs_lens, contextual_sent_reps)``: ``abs_lens`` as a ``torch.long`` tensor
        on ``device`` and the second output of ``model.forward``. The first output
        (``clsreps``) is discarded.
    """
    bert_batch, abs_lens, sent_token_idxs = prepare_abstracts(batch_abs=docs,
                                                              pt_lm_tokenizer=tokenizer)
    # bert_batch is a dict; move its tensor values to `device` and leave everything else alone.
    # isinstance (rather than an exact type comparison) also catches Tensor subclasses.
    for k, v in bert_batch.items():
        bert_batch[k] = v.to(device) if isinstance(v, torch.Tensor) else v
    # abs_lens comes back as a list; the model is given a long tensor instead.
    abs_lens = torch.tensor(abs_lens, dtype=torch.long, device=device)

    clsreps, contextual_sent_reps = model.forward(bert_batch=bert_batch,
                                                  abs_lens=abs_lens,
                                                  sent_tok_idxs=sent_token_idxs)
    return abs_lens, contextual_sent_reps

In [22]:
def pad_to_same_size_along_axis(tensor_list, axis=0):
    """Zero-pad tensors at the end of ``axis`` so they all share the same size there.

    Args:
        tensor_list: List of tensors. All tensors must have ``axis`` as a valid dimension.
        axis: Dimension to equalise (negative values count from the last dimension).

    Returns:
        A new list with one tensor per input, in the same order. Tensors that already have
        the maximum size along ``axis`` are returned as-is (not copied); shorter ones are
        padded with zeros after their existing entries. An empty input gives an empty list.
    """
    if len(tensor_list) == 0:
        return []

    max_size = max(tensor.shape[axis] for tensor in tensor_list)
    padded_list = []
    for tensor in tensor_list:
        pad_size = max_size - tensor.shape[axis]
        if pad_size > 0:
            ndim = tensor.dim()
            # F.pad expects (left, right) pairs ordered from the LAST dimension backwards,
            # so the "right" slot of dimension d sits at index 2 * (ndim - 1 - d) + 1.
            pad = [0] * (2 * ndim)
            pad[2 * (ndim - 1 - (axis % ndim)) + 1] = pad_size
            padded_list.append(F.pad(input=tensor, pad=tuple(pad), mode='constant', value=0))
        else:
            padded_list.append(tensor)
    return padded_list

In [23]:
# NOTE: pickle can execute arbitrary code while loading; only load files you trust.
# The context manager closes the file handle as soon as the ids are read.
with open('data/data_old_format/pmc_ids.pkl', 'rb') as pmc_ids_file:
    pmc_ids = pickle.load(pmc_ids_file)

# batch_size=1 with an identity collate_fn: every batch is a one-element list holding the topic dict.
topics_dataloader = DataLoader(TopicIterable(format='aspire', topic_file_name='data/data_old_format/topics.pkl',
                                             transform=DocAbstractToSentenceListTransform()), batch_size=1,
                               collate_fn=lambda x: x)

# Materialise the loader into a list ordered by numeric topic ID. From here on
# `topics_dataloader` is a plain list of one-element batches that can be indexed.
topics_dataloader = sorted(topics_dataloader, key=lambda x: int(x[0]['ID']))

# Aspire Pre-Trained

In [24]:
topic_embeddings = []
topic_abs_lens = []
# Inference only: without no_grad every stored embedding would keep its autograd graph
# (and the activations behind it) alive for the rest of the session.
with torch.no_grad():
    for topic_batch in tqdm.tqdm(topics_dataloader):
        # move topic_batch to device, but avoid non tensors. its a list of dicts
        topic_batch = [{k: v.to(device) if torch.is_tensor(v) else v for k, v in x.items()} for x in topic_batch]

        abs_lens, topic_embedding = apply_model(topic_batch, aspire_tokenizer, aspire_mv_model)

        topic_embeddings.append(topic_embedding)
        topic_abs_lens.append(abs_lens)
# Zero-pad along axis 1 so the per-topic embeddings can be concatenated along the batch axis.
# The un-padded lengths stay available in `topic_abs_lens`.
topic_embeddings = pad_to_same_size_along_axis(topic_embeddings, axis=1)
topic_embeddings_tensor = torch.cat(topic_embeddings, dim=0)

100%|██████████| 30/30 [00:01<00:00, 28.33it/s]


In [26]:
# Score every article against every topic. `rankings` maps topic ID -> list of
# (article ID, score) pairs, where a smaller score is sorted first by the saving cell.
rankings = {x[0]['ID']: list() for x in topics_dataloader}
# Lightweight (embed, abs_lens) container handed to the optimal-transport distance.
# Defined once instead of being re-created for every (article, topic) pair.
rep_len_tup = namedtuple('RepLen', ['embed', 'abs_lens'])

article_loader = DataLoader(PMCIterable(labeled_ids_or_filename=pmc_ids, format='aspire',
                                        transform=DocAbstractToSentenceListTransform()), batch_size=1,
                            collate_fn=lambda x: x)

with torch.no_grad():
    for article_batch in tqdm.tqdm(article_loader):
        # Articles without any abstract sentences cannot be embedded; skip them.
        if len(article_batch[0]['ABSTRACT']) == 0:
            continue
        article_id = article_batch[0]['ID']
        article_abs_lens, article_embedding = apply_model(article_batch, aspire_tokenizer, aspire_mv_model)
        query_embeds = article_embedding

        for i, (topic_abs_len, topic_embedding) in enumerate(zip(topic_abs_lens, topic_embeddings)):
            topic_id = topics_dataloader[i][0]['ID']
            if single_match:
                # The topic embeddings were zero-padded along axis 1. Drop the padded rows
                # before taking the minimum, otherwise an all-zero padding row can win.
                # (Assumes abs_lens counts the entries along axis 1, which is how the
                # multi-match branch below uses it for masking.)
                real_topic_embedding = topic_embedding[:, :int(topic_abs_len[0])]
                distance_matrix = torch.squeeze(torch.cdist(real_topic_embedding, query_embeds, p=2.0), 0)
                # Smallest pairwise L2 distance, stored as a plain float like the multi-match score.
                score = distance_matrix.min().item()
            else:
                cand_embeds = topic_embedding
                qt = rep_len_tup(embed=query_embeds.permute(0, 2, 1), abs_lens=[article_abs_lens[0]])
                ct = rep_len_tup(embed=cand_embeds.permute(0, 2, 1), abs_lens=[topic_abs_len[0]])
                wd, intermediate_items = ot_distance.compute_distance(query=qt, cand=ct, return_pair_sims=True)
                # The sign is flipped so that the ascending sort used when saving applies here too.
                score = -1*(wd.to('cpu').detach().numpy().item())
            rankings[topic_id].append((article_id, score))



 75%|███████▌  | 28349/37707 [6:51:08<2:15:42,  1.15it/s]


In [28]:
import os
import pickle

# Make sure the target directory exists; otherwise the first open() below fails.
os.makedirs('data/data_new_format/aspire_only', exist_ok=True)

# Sort each topic's (article ID, score) list by ascending score and write one pickle per topic.
for k in tqdm.tqdm(rankings.keys()):
    # rankings[k] = [(_id, rank.to('cpu').numpy().item()) for _id, rank in rankings[k]]
    rankings[k].sort(key=lambda x: x[1])
    # The context manager flushes and closes the file even if pickling raises.
    with open('data/data_new_format/aspire_only/{}.pkl'.format(k), 'wb') as rankings_file:
        pickle.dump(rankings[k], rankings_file)

    # transport_plan = intermediate_items[3].data.numpy()[0, :article_abs_lens[0], :topic_abs_lens[0]]
    # print(transport_plan.shape)
    # # Print the sentences and plot the optimal transport plan for the pair of abstracts.
    # print('\n'.join([f'{i}: {s}' for i, s in enumerate(topics_dataloader[0][0]['ABSTRACT'])]))
    # print('')
    # print('\n'.join([f'{i}: {s}' for i, s in enumerate(article_batch[0]['ABSTRACT'])]))
    # h = sns.heatmap(transport_plan, linewidths=.7, cmap='Blues')
    # h.set(xlabel='Candidate', ylabel='Query')
    # h.tick_params(labelsize=5)
    # plt.show()

100%|██████████| 30/30 [00:01<00:00, 26.48it/s]


In [None]:
# Show only the first few (article ID, score) pairs per topic; displaying the full dict
# would print every scored article for every topic.
{topic_id: ranked[:5] for topic_id, ranked in rankings.items()}

# ReRanker and Aspire Pre-Trained

In [8]:
import os
import pickle

folder_path = "data/data_old_format/reranker_out/"

# One entry per "<name>.pkl" file in folder_path: key is the file name without its
# extension, value is the unpickled content of that file.
reranker_rankings_dict = {}

pickle_file_names = [name for name in os.listdir(folder_path) if name.endswith(".pkl")]
for file_name in pickle_file_names:
    base_name = os.path.splitext(file_name)[0]
    file_path = os.path.join(folder_path, file_name)
    with open(file_path, "rb") as file:
        data = pickle.load(file)
    reranker_rankings_dict[base_name] = data
            

In [9]:
# For each topic keep just the article ids of its reranker list, dropping the second
# element of every (article_id, rank) pair.
article_ids_per_topic = {
    topic_id: [article_id for article_id, _rank in topic_rankings]
    for topic_id, topic_rankings in reranker_rankings_dict.items()
}

In [10]:
topic_embeddings = []
topic_abs_lens = []
# Inference only: without no_grad every stored embedding would keep its autograd graph
# (and the activations behind it) alive for the rest of the session.
with torch.no_grad():
    for topic_batch in tqdm.tqdm(topics_dataloader):
        # move topic_batch to device, but avoid non tensors. its a list of dicts
        topic_batch = [{k: v.to(device) if torch.is_tensor(v) else v for k, v in x.items()} for x in topic_batch]

        abs_lens, topic_embedding = apply_model(topic_batch, aspire_tokenizer, aspire_mv_model)

        topic_embeddings.append(topic_embedding)
        topic_abs_lens.append(abs_lens)
# Zero-pad along axis 1 so the per-topic embeddings can be concatenated along the batch axis.
# The un-padded lengths stay available in `topic_abs_lens`.
topic_embeddings = pad_to_same_size_along_axis(topic_embeddings, axis=1)
topic_embeddings_tensor = torch.cat(topic_embeddings, dim=0)

100%|██████████| 30/30 [00:01<00:00, 19.81it/s]


In [12]:
# Re-score, per topic, only the articles the reranker returned for that topic.
# `rankings` maps topic ID -> list of (article ID, score) pairs.
rankings = {x[0]['ID']: list() for x in topics_dataloader}
# Lightweight (embed, abs_lens) container handed to the optimal-transport distance.
# Defined once instead of being re-created for every article.
rep_len_tup = namedtuple('RepLen', ['embed', 'abs_lens'])

with torch.no_grad():
    for i in range(len(topic_embeddings)):
        topic_id = topics_dataloader[i][0]['ID']
        article_loader = DataLoader(PMCIterable(labeled_ids_or_filename=article_ids_per_topic[topic_id], format='aspire',
                                                transform=DocAbstractToSentenceListTransform()), batch_size=1,
                                    collate_fn=lambda x: x)
        for article_batch in tqdm.tqdm(article_loader):
            # Articles without any abstract sentences cannot be embedded; skip them.
            if len(article_batch[0]['ABSTRACT']) == 0:
                continue
            article_id = article_batch[0]['ID']
            article_abs_lens, article_embedding = apply_model(article_batch, aspire_tokenizer, aspire_mv_model)
            query_embeds = article_embedding

            if single_match:
                # The topic embeddings were zero-padded along axis 1. Drop the padded rows
                # before taking the minimum, otherwise an all-zero padding row can win.
                # (Assumes abs_lens counts the entries along axis 1, which is how the
                # multi-match branch below uses it for masking.)
                real_topic_embedding = topic_embeddings[i][:, :int(topic_abs_lens[i][0])]
                distance_matrix = torch.squeeze(torch.cdist(real_topic_embedding, query_embeds, p=2.0), 0)
                # Smallest pairwise L2 distance, stored as a plain float like the multi-match score.
                score = distance_matrix.min().item()
            else:
                cand_embeds = topic_embeddings[i]
                qt = rep_len_tup(embed=query_embeds.permute(0, 2, 1), abs_lens=[article_abs_lens[0]])
                ct = rep_len_tup(embed=cand_embeds.permute(0, 2, 1), abs_lens=[topic_abs_lens[i][0]])
                wd, intermediate_items = ot_distance.compute_distance(query=qt, cand=ct, return_pair_sims=True)
                # The sign is flipped so that the ascending sort used when saving applies here too.
                score = -1*(wd.to('cpu').detach().numpy().item())
            rankings[topic_id].append((article_id, score))



100%|██████████| 1000/1000 [00:46<00:00, 21.41it/s]
100%|██████████| 1000/1000 [00:51<00:00, 19.55it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.23it/s]
100%|██████████| 1000/1000 [00:48<00:00, 20.73it/s]
100%|██████████| 1000/1000 [00:48<00:00, 20.54it/s]
100%|██████████| 1000/1000 [00:47<00:00, 20.93it/s]
100%|██████████| 1000/1000 [00:47<00:00, 20.95it/s]
100%|██████████| 1000/1000 [00:48<00:00, 20.45it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.05it/s]
100%|██████████| 1000/1000 [00:48<00:00, 20.60it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.03it/s]
100%|██████████| 1000/1000 [00:48<00:00, 20.64it/s]
100%|██████████| 1000/1000 [00:51<00:00, 19.33it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.18it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.06it/s]
100%|██████████| 1000/1000 [00:48<00:00, 20.56it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.07it/s]
100%|██████████| 1000/1000 [00:49<00:00, 20.05it/s]
100%|██████████| 1000/1000 [00:50<00:00, 19.74it/s]
100%|███████

In [15]:
import pickle
import os

folder_path = "data/data_new_format/aspire_with_reranker/"

# exist_ok avoids the check-then-create race and is a no-op when the folder is already there.
os.makedirs(folder_path, exist_ok=True)

# Sort each topic's (article ID, score) list by ascending score and write one pickle per topic.
for k in tqdm.tqdm(rankings.keys()):
    # rankings[k] = [(_id, rank.to('cpu').numpy().item()) for _id, rank in rankings[k]]
    rankings[k].sort(key=lambda x: x[1])
    # The context manager flushes and closes the file even if pickling raises.
    with open(os.path.join(folder_path, '{}.pkl'.format(k)), 'wb') as rankings_file:
        pickle.dump(rankings[k], rankings_file)

    # transport_plan = intermediate_items[3].data.numpy()[0, :article_abs_lens[0], :topic_abs_lens[0]]
    # print(transport_plan.shape)
    # # Print the sentences and plot the optimal transport plan for the pair of abstracts.
    # print('\n'.join([f'{i}: {s}' for i, s in enumerate(topics_dataloader[0][0]['ABSTRACT'])]))
    # print('')
    # print('\n'.join([f'{i}: {s}' for i, s in enumerate(article_batch[0]['ABSTRACT'])]))
    # h = sns.heatmap(transport_plan, linewidths=.7, cmap='Blues')
    # h.set(xlabel='Candidate', ylabel='Query')
    # h.tick_params(labelsize=5)
    # plt.show()

100%|██████████| 30/30 [00:00<00:00, 101.35it/s]


In [None]:
# Show the ten best-ranked (article ID, score) pairs for topic '16' rather than the whole list.
rankings['16'][:10]