# Setup

In [193]:
# True: use the single-match ASPIRE model and utilities; False: the multi-match variants.
single_match = True

In [194]:
import os
# Synchronous CUDA kernel launches: device-side errors surface at the failing call (debugging aid).
os.environ.update(CUDA_LAUNCH_BLOCKING='1')

In [195]:
from transformers import AutoTokenizer

if single_match:
    from utils.ex_aspire_consent import AspireConSent, prepare_abstracts
else:
    from utils.ex_aspire_consent_multimatch import AspireConSent, AllPairMaskedWasserstein
    from utils.ex_aspire_consent_multimatch import prepare_abstracts

import time
import tqdm
import torch.nn.functional as F

import torch  # otherwise first imported in a later cell; required here on a fresh kernel

# Prefer the first GPU when one is available; uncomment the next line to force CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# device = torch.device('cpu')

# Dataset Setup

In [None]:
from collections import namedtuple

import utils.envsetup
from data.pmc_iterable import PMCIterable
from data.topic_iterable import TopicIterable
import pickle
import torch
from torch.utils.data import DataLoader
from utils.doc_abstract_to_sentence_list_transform import DocAbstractToSentenceListTransform

In [19]:
def _load_sorted_topics(topic_file_name='data/data_old_format/topics.pkl', **split):
    """Load every topic of one split and return the batches sorted by numeric topic ID.

    ``split`` is forwarded to ``TopicIterable`` to select the split (``train=True`` or
    ``test=True``). With ``batch_size=1`` and an identity ``collate_fn`` each returned
    element is a one-item list, so the topic dict is ``element[0]``.
    """
    loader = DataLoader(
        TopicIterable(format='aspire', topic_file_name=topic_file_name,
                      transform=DocAbstractToSentenceListTransform(), **split),
        batch_size=1,
        collate_fn=lambda batch: batch,
    )
    return sorted(loader, key=lambda batch: int(batch[0]['ID']))


# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open('data/data_old_format/pmc_ids.pkl', 'rb') as pmc_ids_file:
    pmc_ids = pickle.load(pmc_ids_file)

# Both splits are materialised as lists so iteration order is deterministic (by topic ID).
topics_trainloader = _load_sorted_topics(train=True)
topics_testloader = _load_sorted_topics(test=True)

In [60]:
import pickle
with open("data/data_old_format/labels_dict.pkl", 'rb') as labels_file:
    labelled_article_ids_pickle = pickle.load(labels_file)

# Collapse each topic's {article_id: label} mapping into id lists:
# labels '1' or '2' -> POSITIVE, label '0' -> NEGATIVE (any other label is dropped).
labelled_article_ids = {
    topic_key: {
        "POSITIVE": [article_id for article_id, label in topic_labels.items() if label in ['1', '2']],
        "NEGATIVE": [article_id for article_id, label in topic_labels.items() if label == '0'],
    }
    for topic_key, topic_labels in labelled_article_ids_pickle.items()
}

In [261]:
import time


def _article_loader(article_ids):
    """Return a DataLoader over the given PMC article IDs in aspire sentence-list format.

    ``batch_size=1`` with an identity ``collate_fn`` means every batch is a one-item list.
    """
    return DataLoader(
        PMCIterable(labeled_ids_or_filename=article_ids, format='aspire',
                    transform=DocAbstractToSentenceListTransform()),
        batch_size=1,
        collate_fn=lambda batch: batch,
    )


def build_triplets(topics_loader, labels, negs_per_pos=5):
    """Build ``(topic, positive_doc, negative_doc)`` triplets for every topic.

    Args:
        topics_loader: iterable of one-item topic batches; each topic dict has an 'ID'.
        labels: mapping topic ID -> {"POSITIVE": [article ids], "NEGATIVE": [article ids]}.
        negs_per_pos: maximum number of negatives paired with each positive.

    Returns:
        List of triplets. A topic's negatives come from one shared stream, so each
        negative is used at most once. Positives left over after the negatives run
        out receive no triplet and are not loaded at all.

    Raises:
        KeyError: if a topic ID has no entry in ``labels``.
    """
    triplets = []
    for topic_batch in topics_loader:
        topic = topic_batch[0]
        topic_labels = labels[topic['ID']]
        negatives = iter(_article_loader(topic_labels['NEGATIVE']))
        for pos_batch in _article_loader(topic_labels['POSITIVE']):
            pos = pos_batch[0]
            for _ in range(negs_per_pos):
                neg_batch = next(negatives, None)
                if neg_batch is None:
                    break
                triplets.append((topic, pos, neg_batch[0]))
            else:
                continue  # quota filled; more negatives may remain for the next positive
            break  # negatives exhausted: no later positive of this topic can get one
    return triplets


start = time.time()
negs_per_pos = 5
train_triplets = build_triplets(topics_trainloader, labelled_article_ids, negs_per_pos)
print("Dataset Loadtime: {}".format(time.time() - start))
print(len(train_triplets))

Dataset Loadtime: 61.04307842254639
13199


In [242]:
# Build (topic, positive, negative) triplets for the test topics. A topic's negatives are
# drawn from one shared stream (each used at most once); once they run out, the remaining
# positives are skipped instead of being loaded for nothing.
import time

start = time.time()
test_triplets = []
negs_per_pos = 5
for topic_batch in topics_testloader:
    topic = topic_batch[0]
    topic_labels = labelled_article_ids[topic['ID']]
    articles_positive_loader = DataLoader(PMCIterable(labeled_ids_or_filename=topic_labels['POSITIVE'], format='aspire',
                                                      transform=DocAbstractToSentenceListTransform()), batch_size=1, collate_fn=lambda x: x)
    articles_negative_loader = DataLoader(PMCIterable(labeled_ids_or_filename=topic_labels['NEGATIVE'], format='aspire',
                                                      transform=DocAbstractToSentenceListTransform()), batch_size=1, collate_fn=lambda x: x)
    negatives = iter(articles_negative_loader)
    for pos_batch in articles_positive_loader:
        pos = pos_batch[0]
        for _ in range(negs_per_pos):
            neg_batch = next(negatives, None)
            if neg_batch is None:
                break
            test_triplets.append((topic, pos, neg_batch[0]))
        else:
            continue  # quota filled; more negatives may remain for the next positive
        break  # negatives exhausted: no later positive of this topic can get one
print("Dataset Loadtime: {}".format(time.time() - start))
print(len(test_triplets))

Dataset Loadtime: 19.802971839904785
2779


In [275]:
import pickle
# Persist the raw train/test triplets to disk, one pickle per split.
import pickle

for triplets_path, triplets_to_save in (
    ("data/data_new_format/triplets/train_triplets.pkl", train_triplets),
    ("data/data_new_format/triplets/test_triplets.pkl", test_triplets),
):
    with open(triplets_path, "wb") as out_file:
        pickle.dump(triplets_to_save, out_file)

## Triplets to Top-Sentence Triplets

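Reducing each triplet to its top sentences is a long computation, so it is run in parallel in `data/data_new_format/triplets/triplets_to_sentence_only_triplets.ipynb`. The commented-out cells below are a reference sketch of that step; the cells after them only load its cached output.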

In [255]:
# from transformers import AutoModel, AutoTokenizer

# aspire_sent = AutoModel.from_pretrained('allenai/aspire-sentence-embedder')
# aspire_tok = AutoTokenizer.from_pretrained('allenai/aspire-sentence-embedder')

# def apply_sent_bert(sents):
#     inputs = aspire_tok(sents, padding=True, truncation=True, return_tensors="pt", max_length=512)

#     result = aspire_sent(**inputs)

#     clsrep = result.last_hidden_state[:,0,:]
    
#     return clsrep

In [None]:
# test_triplets_sentence_only = []
# for triplet in tqdm.tqdm(test_triplets):
#     # generate embeddings
#     topic, pos, neg = triplet[0]['ABSTRACT'], triplet[1]['ABSTRACT'], triplet[2]['ABSTRACT']
#     co_citation_context = triplet[0]["CO-CITATION-CONTEXT"]
#     topic_embed, pos_embed, neg_embed, context_embed = apply_sent_bert(topic), apply_sent_bert(pos), apply_sent_bert(neg), apply_sent_bert(co_citation_context)
    
#     # create sentence-only triplet
#     triplet_sentence_only = []
#     for sentences, embedding in zip([topic, pos, neg], [topic_embed, pos_embed, neg_embed]):
#         distance_topic_context = torch.squeeze(torch.cdist(embedding, context_embed, p=2.0), 0)
#         argmax = torch.argmin(distance_topic_context)
#         indices = torch.stack([argmax // distance_topic_context.shape[1], argmax % distance_topic_context.shape[1]], -1)
#         triplet_sentence_only.append(sentences[indices[0]])
#     test_triplets_sentence_only.append(tuple(triplet_sentence_only))

In [277]:
import pickle
# Load the cached sentence-only train triplets if the parallel notebook has produced them.
train_sentence_only_path = "data/data_new_format/triplets/train_triplets_sentence_only.pkl"
if not os.path.exists(train_sentence_only_path):
    print("train triplets not loaded")
else:
    with open(train_sentence_only_path, "rb") as sentence_only_file:
        train_triplets_sentence_only = pickle.load(sentence_only_file)

In [277]:
# Load the cached sentence-only test triplets if the parallel notebook has produced them.
test_sentence_only_path = "data/data_new_format/triplets/test_triplets_sentence_only.pkl"
if not os.path.exists(test_sentence_only_path):
    print("test triplets not loaded")
else:
    with open(test_sentence_only_path, "rb") as sentence_only_file:
        test_triplets_sentence_only = pickle.load(sentence_only_file)

train triplets not loaded
test triplets not loaded


In [129]:
from torch.utils.data import Dataset

class TripletsDataset(Dataset):
    """Map-style dataset over an in-memory, indexable sequence of triplets.

    Args:
        triplets: indexable sequence; item ``idx`` is returned as-is unless a transform is set.
        transform: optional callable applied to each item when it is accessed.
    """

    def __init__(self, triplets, transform=None):
        self.triplets, self.transform = triplets, transform

    def __len__(self):
        return len(self.triplets)

    def __getitem__(self, idx):
        item = self.triplets[idx]
        return self.transform(item) if self.transform else item
    
# Wrap the raw triplet lists so they can be fed to a DataLoader.
train_triplets_dataset, test_triplets_dataset = (
    TripletsDataset(triplets=train_triplets),
    TripletsDataset(triplets=test_triplets),
)

# Model Setup

In [4]:
# Pick the Hub checkpoint matching the utils module imported in the setup cell.
if single_match:
    huggingface_model_name = 'allenai/aspire-contextualsentence-singlem-biomed'  # single match
else:
    huggingface_model_name = 'allenai/aspire-contextualsentence-multim-biomed'  # multi match
    # Empty dict of hyper params will force class to use defaults.
    ot_distance = AllPairMaskedWasserstein({}, device)

# Tokenizer cache directory; set ASPIRE_CACHE_DIR to override the lab-specific default path.
aspire_cache_dir = os.environ.get("ASPIRE_CACHE_DIR", "/cs/labs/tomhope/taltatal/cache")
aspire_tokenizer = AutoTokenizer.from_pretrained(huggingface_model_name, cache_dir=aspire_cache_dir)
aspire_mv_model = AspireConSent(huggingface_model_name, device).to(device)


In [5]:
def apply_model(docs, tokenizer, model, device=torch.device("cpu")):
    """Encode a batch of abstracts with an ASPIRE contextual-sentence model.

    Args:
        docs: batch of abstracts, passed to ``prepare_abstracts`` as ``batch_abs``
            (presumably the sentence-list format produced upstream — TODO confirm).
        tokenizer: tokenizer passed to ``prepare_abstracts`` as ``pt_lm_tokenizer``.
        model: torch module called with ``bert_batch``, ``abs_lens`` and ``sent_tok_idxs``.
        device: device the batch tensors are moved to; should match ``model``'s device.

    Returns:
        ``(abs_lens, contextual_sent_reps)``: the abstract lengths as a long tensor on
        ``device`` and the model's contextual sentence representations. The CLS
        representations the model also returns are discarded.
    """
    bert_batch, abs_lens, sent_token_idxs = prepare_abstracts(batch_abs=docs,
                                                              pt_lm_tokenizer=tokenizer)
    # Move tensor entries to the device in place; keeps bert_batch's own mapping type.
    for key, value in bert_batch.items():
        if isinstance(value, torch.Tensor):
            bert_batch[key] = value.to(device)
    abs_lens = torch.tensor(abs_lens, dtype=torch.long, device=device)

    # Call the module (not .forward) so registered forward hooks run.
    _clsreps, contextual_sent_reps = model(bert_batch=bert_batch,
                                           abs_lens=abs_lens,
                                           sent_tok_idxs=sent_token_idxs)
    return abs_lens, contextual_sent_reps

In [6]:
def pad_to_same_size_along_axis(tensor_list, axis=0):
    """Zero-pad tensors at the end of ``axis`` so they all share that axis's largest size.

    Args:
        tensor_list: sequence of tensors of equal rank; only ``axis`` may differ in size.
        axis: dimension to equalise (negative values count from the end).

    Returns:
        A new list in the same order. Tensors already at the maximum size are returned
        as-is (not copied). An empty input yields an empty list.
    """
    if not tensor_list:
        return []
    max_size = max(tensor.shape[axis] for tensor in tensor_list)
    padded_list = []
    for tensor in tensor_list:
        pad_size = max_size - tensor.shape[axis]
        if pad_size == 0:
            padded_list.append(tensor)
            continue
        # F.pad takes (left, right) pairs starting from the LAST dimension, so the
        # right-hand pad for dimension `dim` sits at index 2 * (ndim - 1 - dim) + 1.
        ndim = tensor.dim()
        dim = axis % ndim
        pad = [0] * (2 * ndim)
        pad[2 * (ndim - 1 - dim) + 1] = pad_size
        padded_list.append(F.pad(input=tensor, pad=tuple(pad), mode='constant', value=0))
    return padded_list

In [None]:
from torch.optim import AdamW
Loss = torch.nn.TripletMarginLoss(margin=1)

# Set hyperparameters
batch_size = 1
learning_rate = 2e-5
num_epochs = 1

# Both loaders honour the configured batch size; only the training loader shuffles.
train_triplets_loader = DataLoader(train_triplets_dataset, batch_size=batch_size, shuffle=True)
test_triplets_loader = DataLoader(test_triplets_dataset, batch_size=batch_size, shuffle=False)

optimizer = AdamW(aspire_mv_model.parameters(), lr=learning_rate)
# Fine-Tuning: switch to training mode (trailing ';' suppresses the module repr in Jupyter).
aspire_mv_model.train();

In [172]:
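# NOTE(review): stale template, not yet adapted to this notebook. It `break`s on the first batch,
# references an undefined `model`, reads input_ids/attention_mask/labels keys that the triplet
# batches presumably do not have, and calls the triplet loss as Loss(prediction, target=...).
# The "Loss: 0.0" output below is an artifact of the immediate `break`, not a real training loss.
# for epoch in range(num_epochs):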
#     total_loss = 0
#     for batch in tqdm.tqdm(train_triplets_loader):
#         break
#         input_ids = batch["input_ids"].to(device)
#         attention_mask = batch["attention_mask"].to(device)
#         labels = batch["labels"].to(device, dtype=torch.float32)

#         model.zero_grad()

#         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
#         outputs_probabilities = torch.sigmoid(outputs.logits.view(-1))
#         loss = Loss(outputs_probabilities, target=labels)
#         total_loss += loss.item()

#         loss.backward()
#         optimizer.step()

#     avg_loss = total_loss / len(train_triplets_loader)

#     print(f"Fine Tuning Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss}")


  0%|          | 0/13199 [00:00<?, ?it/s]

Fine Tuning Epoch 1/1, Loss: 0.0



