In [2]:
import re, os
from tqdm.auto import tqdm 


def preprocess_latex_corpus(corpus):
    citation_id = 1
    citation_map = {}  # To keep track of the citation to id mapping

    processed_corpus = []

    # Define the regex patterns
    citation_pattern = re.compile(r"\\cite[t,p]?{(.*?)}")
    command_pattern = re.compile(r"\\[a-zA-Z]+")
    comment_pattern = re.compile(r"%.*?$", re.MULTILINE)
    environment_pattern = re.compile(r"\\begin{(figure|table|equation).*?\\end{\1}", re.DOTALL)

    for document in tqdm(corpus):
        # Remove comments
        document = re.sub(comment_pattern, "", document)

        # Remove non-informative environments
        # document = re.sub(environment_pattern, "", document)

        # Replace citation commands
        matches = citation_pattern.findall(document)
        for match in matches:
            citations = match.split(",")  # Handle multiple citations within one command
            for citation in citations:
                citation = citation.strip()
                
                # Assign a unique id to each citation if not already done & if not empty
                if citation not in citation_map and citation != "":
                    citation_map[citation] = f"<CITATION_{citation_id}>"
                    citation_id += 1
                
                    # Replace the citation with the special token
                    document = document.replace(citation, citation_map[citation])

        # Remove other commands
        # document = re.sub(command_pattern, "", document)
        document = document.replace('{<','<').replace('>}','>')
        if r'\begin{document}' not in document:
            continue
        document = document.split(r'\begin{document}')[1] # get the body
        processed_corpus.append(document)

    return processed_corpus, citation_map


def read_tex_files(directory):
    corpus = []
    for root, dirs, files in tqdm(os.walk(directory)):
        try:
            for file_name in files:
                if file_name.endswith('.tex'):
                    # print(file_name)
                    file_path = os.path.join(root, file_name)
                    with open(file_path, 'r') as file:
                        content = file.read()
                        corpus.append(content)
        except:
            pass
    return corpus 

corpus = read_tex_files('./sources/')
processed_corpus, citation_map = preprocess_latex_corpus(corpus)

# pint number of corpus before & after preprocessing
print(f'Number of documents before preprocessing: {len(corpus)}')
print(f'Number of documents after preprocessing: {len(processed_corpus)}')



0it [00:00, ?it/s]

  0%|          | 0/3236 [00:00<?, ?it/s]

Number of documents before preprocessing: 3236
Number of documents after preprocessing: 474


In [17]:
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast
import numpy as np
import torch
import random

import random

class MaskedCitationDataset(Dataset):
    def __init__(self, text_list, tokenizer, seq_len, mask_prob=0.15):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.mask_prob = mask_prob
        self.inputs = []
        self.targets = []

        for text in tqdm(text_list,total=len(text_list)):
            self.process_text(text)

    def process_text(self, text):
        tokenized_text = self.tokenizer.encode(text)
        for i in range(0, len(tokenized_text)-self.seq_len, self.seq_len):
            if i+self.seq_len+1 >= len(tokenized_text):
                break
            sequence = tokenized_text[i:i+self.seq_len]
            masked_sequence, target_sequence = self.mask_sequence(sequence)
            if len(target_sequence[target_sequence!=-100]) == 0:
                continue    
            self.inputs.append(torch.tensor(masked_sequence.clone().detach()))
            self.targets.append(torch.tensor(target_sequence))

    def mask_sequence(self, sequence):
        sequence = np.array(sequence)
        target_sequence = np.full(sequence.shape, -100)  # -100 is the default ignore index for CrossEntropyLoss
        citation_indices = np.where(sequence >= self.tokenizer.additional_special_tokens_ids[0])[0]
        # num_to_mask = max(1, int(len(citation_indices) * self.mask_prob))  # At least mask 1 token
        # num_to_mask = min(num_to_mask, len(citation_indices))  # Ensure we do not attempt to mask more citations than exist
        # mask_indices = np.random.choice(citation_indices, num_to_mask, replace=False)
        target_sequence[citation_indices] = sequence[citation_indices]
        sequence[citation_indices] = self.tokenizer.mask_token_id
        return torch.tensor(sequence.tolist()), torch.tensor(target_sequence.tolist())


    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

# Initialize the tokenizer and model
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

special_tokens_dict = {'additional_special_tokens': list(citation_map.values())}
tokenizer.add_special_tokens(special_tokens_dict)


# Split the processed_corpus into training and validation sets
train_size = int(len(processed_corpus) * 0.8)
random.shuffle(processed_corpus)
train_texts = processed_corpus[:train_size]
val_texts = processed_corpus[train_size:]

# Create the datasets
seq_len = 512

train_dataset = MaskedCitationDataset(train_texts, tokenizer, seq_len=seq_len)
val_dataset = MaskedCitationDataset(val_texts, tokenizer, seq_len=seq_len)


  0%|          | 0/379 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (41012 > 512). Running this sequence through the model will result in indexing errors
  self.inputs.append(torch.tensor(masked_sequence.clone().detach()))
  self.targets.append(torch.tensor(target_sequence))


  0%|          | 0/95 [00:00<?, ?it/s]

395

In [19]:
from transformers import BertForMaskedLM, BertTokenizerFast, AdamW

import torch


# Initialize the tokenizer and model
model = BertForMaskedLM.from_pretrained('bert-base-uncased',)

# resise vocab size
model.resize_token_embeddings(len(tokenizer))

# Create the data loaders
batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = AdamW(model.parameters(), lr=1e-4)

# Training loop
criterion = torch.nn.CrossEntropyLoss(ignore_index=-100)

for epoch in range(100):  # Number of epochs
    model.train()
    bar = tqdm(enumerate(train_loader),total=len(train_loader),)
    running_loss = []
    for i, (inputs, targets) in bar:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        outputs = model(inputs,)
        logits = outputs.logits.view(-1, outputs.logits.size(-1))
        targets_flattened = targets.view(-1)
        loss = criterion(logits,targets_flattened)
        loss.backward()
        optimizer.step()

        running_loss.append(loss.item())
        if (i+1) % 20 == 0:
            # print(f"Epoch: {epoch}, Iteration: {i}, Loss: {loss.item()}")
            avg_loss = sum(running_loss) / len(running_loss)
            bar.set_postfix(iteration=i, loss=f"{avg_loss:.5f}")
            running_loss = []
        if i > 100:
            break
            
    # Validation loop
    model.eval()
    with torch.no_grad():
        total_loss = 0
        total_acc = 0
        total_citations = 0
        bar = tqdm(val_loader, total=len(val_loader),)
        running_loss = []
        for i, (inputs, targets) in enumerate(bar):
            if i > 100:
                break
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            logits = outputs.logits.view(-1, outputs.logits.size(-1))
            targets_flattened = targets.view(-1)

            loss = criterion(logits, targets_flattened)
            total_loss += loss.item()

            if (i+1) % 20 == 0:
                avg_loss = total_loss / (i+1)
                bar.set_postfix(iteration=i, loss=f"{avg_loss:.5f}")
                running_loss = []

            # Compute the citation accuracy
            citation_predictions = torch.argmax(logits, dim=-1)  # Shape: (batch_size, seq_len)
            correct_predictions = (citation_predictions == targets_flattened).sum().item()
            total_citations += targets.numel()
            total_acc += correct_predictions
            # print('corr pred = ', correct_predictions, ' cites = ', citation_targets.numel(), ' total acc = ', total_acc, ' total cites = ', total_citations)

        print(f"Validation Loss: {total_loss / len(val_loader)}")
        print(f"Citation Accuracy: {total_acc / total_citations if total_citations > 0 else 0}")



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 5.8685675048828125
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 6.067061805725098
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 6.024843314034598
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 6.1696972601754325
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 6.302638953072684
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 6.784555151803153
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 6.456542854309082
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 6.655230772835868
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 6.717359913417271
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 6.982607830592564
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 6.659653783525739
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 7.227774423871722
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 7.234947547912598
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 7.103667308262416
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 7.546057444981166
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 7.291487334115165
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 7.802298981802804
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 7.350942328316825
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 7.7918954195295065
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

  0%|          | 0/175 [00:00<?, ?it/s]

Validation Loss: 7.653470371791295
Citation Accuracy: 0.0


  0%|          | 0/1026 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [78]:
targets_flattened[targets_flattened>0]

tensor([], device='cuda:0', dtype=torch.int64)

In [51]:
import torch.nn.functional as F

# Define the sample text
sample_text = "Batch normalization improves training efficiency of neural networks <CITATION_1>"

# # Preprocess the text by replacing the <cite_1> token with a <mask> token
# mask_token = tokenizer.mask_token
# masked_text = sample_text.replace("<cite_1>", mask_token)

# Tokenize the text and convert it to a tensor
inputs = tokenizer(sample_text, return_tensors="pt")

# Send the tensor to the device
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

# Forward pass through the model
with torch.no_grad():
    outputs = model(**inputs)

# Get the logits and compute the softmax to get probabilities
logits = outputs.logits
probs = F.softmax(logits, dim=-1)

# Get the position of the <mask> token
mask_position = torch.where(inputs["input_ids"] == tokenizer.encode('<CITATION_1>')[0])[1]

# Get the probabilities of the tokens at the <mask> position
mask_probs = probs[0, mask_position, :]

# Get the top k probabilities and their indices
top_k_probs, top_k_indices = torch.topk(mask_probs, k=5)  # Get the top 5 predictions

# Print the top k citations and their probabilities
for i in range(top_k_probs.shape[-1]):
    try:
        citation_id = top_k_indices[0, i].item()
        prob = top_k_probs[0, i].item()
        citation_key = tokenizer.decode(citation_id)
        print(citation_key)
        # citation_val = citation_map[citation_key]
        citation_val = [k for k,v in citation_map.items() if v == citation_key][0]
        print(f"Citation {citation_val}, Probability: {prob}")
        # if citation_id in tokenizer.get_added_vocab():
        #     citation_key = tokenizer.get_added_vocab()[citation_id]
        #     print(f"Citation: {citation_map[citation_key]}, Probability: {prob}")
    except:
        pass


<CITATION_5825>
Citation for, Probability: 0.029667695984244347
<CITATION_2002>
Citation generalization, Probability: 0.013682132586836815
<CITATION_9241>
Citation BERT, Probability: 0.013249002397060394
<CITATION_2280>
Citation transformer, Probability: 0.01189996674656868
<CITATION_1563>
Citation supervised, Probability: 0.011230929754674435


In [50]:
citation_map['infi']

KeyError: 'infinitewly wide neural networks'