In [1]:
import re
from tqdm.auto import tqdm 
import os


def preprocess_latex_corpus(corpus):
    citation_id = 1
    citation_map = {}  # To keep track of the citation to id mapping

    processed_corpus = []

    # Define the regex patterns
    citation_pattern = re.compile(r"\\cite[t,p]?{(.*?)}")
    command_pattern = re.compile(r"\\[a-zA-Z]+")
    comment_pattern = re.compile(r"%.*?$", re.MULTILINE)
    environment_pattern = re.compile(r"\\begin{(figure|table|equation).*?\\end{\1}", re.DOTALL)

    for document in tqdm(corpus):
        # Remove comments
        document = re.sub(comment_pattern, "", document)

        # Remove non-informative environments
        # document = re.sub(environment_pattern, "", document)

        # Replace citation commands
        matches = citation_pattern.findall(document)
        for match in matches:
            citations = match.split(",")  # Handle multiple citations within one command
            for citation in citations:
                citation = citation.strip()
                
                # Assign a unique id to each citation if not already done & if not empty
                if citation not in citation_map and citation != "":
                    citation_map[citation] = f"<CITATION_{citation_id}>"
                    citation_id += 1
                
                    # Replace the citation with the special token
                    document = document.replace(citation, citation_map[citation])

        # Remove other commands
        # document = re.sub(command_pattern, "", document)
        document = document.replace('{<','<').replace('>}','>')
        if r'\begin{document}' not in document:
            continue
        document = document.split(r'\begin{document}')[1] # get the body
        processed_corpus.append(document)

    return processed_corpus, citation_map


def read_tex_files(directory):
    corpus = []
    for root, dirs, files in tqdm(os.walk(directory)):
        try:
            for file_name in files:
                if file_name.endswith('.tex'):
                    # print(file_name)
                    file_path = os.path.join(root, file_name)
                    with open(file_path, 'r') as file:
                        content = file.read()
                        corpus.append(content)
        except:
            pass
    return corpus 

corpus = read_tex_files('./sources/')
processed_corpus, citation_map = preprocess_latex_corpus(corpus)

# pint number of corpus before & after preprocessing
print(f'Number of documents before preprocessing: {len(corpus)}')
print(f'Number of documents after preprocessing: {len(processed_corpus)}')



0it [00:00, ?it/s]

  0%|          | 0/3236 [00:00<?, ?it/s]

Number of documents before preprocessing: 3236
Number of documents after preprocessing: 474


In [6]:
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer
import torch


class CitationDataset(Dataset):
    def __init__(self, text_list, tokenizer, seq_len=1024):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.inputs = []
        self.targets = []

        for text in tqdm(text_list):
            self.process_text(text)

    def process_text(self, text):
        tokenized_text = self.tokenizer.encode(text)
        # citation_mask = tokenized_text.ge(self.tokenizer.additional_special_tokens_ids[0])  # Shape: (batch_size, seq_len)
        # if citation_mask.sum() == 0: # No citations in the text
        #     return

        # Create sequences of the specified length with stride half of the length
        for i in range(0, len(tokenized_text)-self.seq_len, self.seq_len//2):
            if i+self.seq_len+1 >= len(tokenized_text):
                break
            input_sequence = tokenized_text[i:i+self.seq_len]
            target_sequence = tokenized_text[i+1:i+self.seq_len+1]  # Shifted right

            self.inputs.append(torch.tensor(input_sequence))
            self.targets.append(torch.tensor(target_sequence))

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

# Add the special tokens to the pre-trained GPT-2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
special_tokens_dict = {'additional_special_tokens': list(citation_map.values())}
tokenizer.add_special_tokens(special_tokens_dict)


# Split the processed_corpus into training and validation sets
train_size = int(len(processed_corpus) * 0.8)
train_texts = processed_corpus[:train_size]
val_texts = processed_corpus[train_size:]

# Create the datasets
seq_len = 1024
train_dataset = CitationDataset(train_texts, tokenizer, seq_len=seq_len)
val_dataset = CitationDataset(val_texts, tokenizer, seq_len=seq_len)


  0%|          | 0/379 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1264 > 1024). Running this sequence through the model will result in indexing errors


  0%|          | 0/95 [00:00<?, ?it/s]

In [7]:
from transformers import GPT2LMHeadModel, AdamW

import torch

# Adjust the model vocabulary size
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))

# Create the data loaders
batch_size = 4
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=True)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = AdamW(model.parameters(), lr=1e-4)

# Training loop
criterion = torch.nn.CrossEntropyLoss()

for epoch in range(100):  # Number of epochs
    model.train()
    bar = tqdm(enumerate(train_loader),total=len(train_loader),)
    running_loss = []
    for i, (inputs, targets) in bar:
        inputs = inputs.to(device)
        targets = targets.to(device)
        citation_mask = targets.ge(tokenizer.additional_special_tokens_ids[0])  # Shape: (batch_size, seq_len)
        if citation_mask.sum() < 3:
            continue # Skip batch if it contains no citation tokens

        optimizer.zero_grad()
        outputs = model(inputs,)
        logits = outputs.logits.view(-1, outputs.logits.size(-1))
        targets_flattened = targets.view(-1)

        # Only include the citation logits and targets in the loss calculation
        citation_logits = logits[citation_mask.view(-1)]
        citation_targets = targets_flattened[citation_mask.view(-1)]

        loss = criterion(citation_logits, citation_targets)
        # loss = criterion(outputs.logits.view(-1,outputs.logits.size(-1)), targets.flatten())
        # loss = outputs.loss
        loss.backward()
        optimizer.step()

        running_loss.append(loss.item())
        if (i+1) % 100 == 0:
            # print(f"Epoch: {epoch}, Iteration: {i}, Loss: {loss.item()}")
            avg_loss = sum(running_loss) / len(running_loss)
            bar.set_postfix(iteration=i, loss=f"{avg_loss:.5f}")
            running_loss = []
        if i > 100:
            break
            
    # Validation loop
    model.eval()
    with torch.no_grad():
        total_loss = 0
        total_acc = 0
        total_citations = 0
        bar = tqdm(val_loader, total=len(val_loader),)
        running_loss = []
        for i, (inputs, targets) in enumerate(bar):
            if i > 200:
                break
            inputs = inputs.to(device)
            targets = targets.to(device)
            citation_mask = targets.ge(tokenizer.additional_special_tokens_ids[0])  # Shape: (batch_size, seq_len)
            if citation_mask.sum() == 0:
                continue
            outputs = model(inputs)
            logits = outputs.logits.view(-1, outputs.logits.size(-1))
            targets_flattened = targets.view(-1)

            # Only include the citation logits and targets in the loss calculation
            citation_logits = logits[citation_mask.view(-1)]
            citation_targets = targets_flattened[citation_mask.view(-1)]

            loss = criterion(citation_logits, citation_targets)
            total_loss += loss.item()

            if (i+1) % 50 == 0:
                avg_loss = total_loss / (i+1)
                bar.set_postfix(iteration=i, loss=f"{avg_loss:.5f}")
                running_loss = []

            # Compute the citation accuracy
            citation_predictions = torch.argmax(citation_logits, dim=-1)  # Shape: (batch_size, seq_len)
            correct_predictions = (citation_predictions == citation_targets).sum().item()
            total_citations += citation_targets.numel()
            total_acc += correct_predictions
            # print('corr pred = ', correct_predictions, ' cites = ', citation_targets.numel(), ' total acc = ', total_acc, ' total cites = ', total_citations)

        print(f"Validation Loss: {total_loss / len(val_loader)}")
        print(f"Citation Accuracy: {total_acc / total_citations if total_citations > 0 else 0}")





  0%|          | 0/3106 [00:00<?, ?it/s]

  0%|          | 0/420 [00:00<?, ?it/s]

Validation Loss: 4.797319353194464
Citation Accuracy: 0.0


  0%|          | 0/3106 [00:00<?, ?it/s]

  0%|          | 0/420 [00:00<?, ?it/s]

Validation Loss: 4.885184919266473
Citation Accuracy: 0.0


  0%|          | 0/3106 [00:00<?, ?it/s]

  0%|          | 0/420 [00:00<?, ?it/s]

In [51]:
import torch.nn.functional as F

# Define the sample text
sample_text = "Batch normalization improves training efficiency of neural networks <CITATION_1>"

# # Preprocess the text by replacing the <cite_1> token with a <mask> token
# mask_token = tokenizer.mask_token
# masked_text = sample_text.replace("<cite_1>", mask_token)

# Tokenize the text and convert it to a tensor
inputs = tokenizer(sample_text, return_tensors="pt")

# Send the tensor to the device
inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

# Forward pass through the model
with torch.no_grad():
    outputs = model(**inputs)

# Get the logits and compute the softmax to get probabilities
logits = outputs.logits
probs = F.softmax(logits, dim=-1)

# Get the position of the <mask> token
mask_position = torch.where(inputs["input_ids"] == tokenizer.encode('<CITATION_1>')[0])[1]

# Get the probabilities of the tokens at the <mask> position
mask_probs = probs[0, mask_position, :]

# Get the top k probabilities and their indices
top_k_probs, top_k_indices = torch.topk(mask_probs, k=5)  # Get the top 5 predictions

# Print the top k citations and their probabilities
for i in range(top_k_probs.shape[-1]):
    try:
        citation_id = top_k_indices[0, i].item()
        prob = top_k_probs[0, i].item()
        citation_key = tokenizer.decode(citation_id)
        print(citation_key)
        # citation_val = citation_map[citation_key]
        citation_val = [k for k,v in citation_map.items() if v == citation_key][0]
        print(f"Citation {citation_val}, Probability: {prob}")
        # if citation_id in tokenizer.get_added_vocab():
        #     citation_key = tokenizer.get_added_vocab()[citation_id]
        #     print(f"Citation: {citation_map[citation_key]}, Probability: {prob}")
    except:
        pass


<CITATION_5825>
Citation for, Probability: 0.029667695984244347
<CITATION_2002>
Citation generalization, Probability: 0.013682132586836815
<CITATION_9241>
Citation BERT, Probability: 0.013249002397060394
<CITATION_2280>
Citation transformer, Probability: 0.01189996674656868
<CITATION_1563>
Citation supervised, Probability: 0.011230929754674435


In [50]:
citation_map['infi']

KeyError: 'infinitewly wide neural networks'