In [1]:
# Import required libraries
import torch
from torch import nn
from torch.nn import Transformer
import math

# Assume we have a corpus of 2 small latex documents with a few citations
corpus = [
    ("The quick brown fox jumps over the lazy dog (Smith, 2002).", "Smith, 2002"),
    ("The history of NLP is fascinating (Jones, 1999; Clark, 2003).", "Jones, 1999; Clark, 2003"),
]

# Let's create a simple vocabulary of words for simplicity, in a real application this would be more complex
word_to_ix = {"<PAD>": 0, "<SOS>": 1, "<EOS>": 2, "<UNK>": 3, 
              "The": 4, "quick": 5, "brown": 6, "fox": 7, "jumps": 8, "over": 9, "the": 10, "lazy": 11, "dog": 12, 
              "history": 13, "of": 14, "NLP": 15, "is": 16, "fascinating": 17,
              "Smith, 2002": 18, "Jones, 1999; Clark, 2003": 19}

# Function to convert sequence of words to Tensor of indices
def sequence_to_tensor(sequence):
    return torch.tensor([word_to_ix.get(word, 3) for word in sequence.split()])  # 3 is for <UNK>

# Prepare the data
X = torch.stack([sequence_to_tensor(sequence)[:5] for sequence, _ in corpus])
Y = torch.tensor([word_to_ix.get(target, 3) for _, target in corpus])

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

# Define the model
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, nhid, nlayers):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layers = torch.nn.TransformerEncoderLayer(d_model, nhead, nhid)
        self.transformer_encoder = torch.nn.TransformerEncoder(encoder_layers, nlayers)
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model
        self.decoder = nn.Linear(d_model, vocab_size)
        
    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def forward(self, src):
#         if self.src_mask is None or self.src_mask.size(0) != len(src):
#             device = src.device
#             mask = self._generate_square_subsequent_mask(len(src)).to(device)
#             self.src_mask = mask

        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return output

# Hyperparameters for Transformer model.
vocab_size = len(word_to_ix)
d_model = 512
nhead = 2
nhid = 200
nlayers = 2

# Initialize the model
model = TransformerModel(vocab_size, d_model, nhead, nhid, nlayers)

# Training process is skipped here for brevity


In [2]:
# Import required libraries
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW

# Load pre-trained model tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# Add special tokens (if any) to the tokenizer
additional_tokens = ["<CITATION_1>", "<CITATION_2>", "<CITATION_3>"]  # Replace these with your actual citation tokens
num_added_tokens = tokenizer.add_tokens(additional_tokens)

# Assign a pad token. For GPT-2, we usually use the EOS token as the pad token
tokenizer.pad_token = tokenizer.eos_token

# Load pre-trained model
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))  # Update the model vocabulary size

# Prepare the data
# For simplicity let's assume the citations are replaced by special tokens in the text
corpus = [
    "The quick brown fox jumps over the lazy dog <CITATION_1>.",
    "The history of NLP is fascinating <CITATION_2> <CITATION_3>.",
]

inputs = tokenizer(corpus, return_tensors='pt', padding=True, truncation=True)

# Define the training loop
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
inputs.to(device)
optimizer = AdamW(model.parameters(), lr=1e-5)  # Set an appropriate learning rate

model.train()  # Set the model to training mode
for epoch in range(5):
    optimizer.zero_grad()
    outputs = model(**inputs, labels=inputs['input_ids'])
    loss = outputs.loss
    print(loss.item())
    loss.backward()
    optimizer.step()

# Training is skipped for brevity. Please train the model appropriately and save the trained model.


  with safe_open(checkpoint_file, framework="pt") as f:
  return self.fget.__get__(instance, owner)()
  storage = cls(wrap_storage=untyped_storage)
  with safe_open(filename, framework="pt", device=device) as f:


69.77877807617188
59.6018180847168
65.859130859375
61.3740119934082
61.24586868286133


In [3]:
import re
from tqdm.auto import tqdm 


def preprocess_latex_corpus(corpus):
    citation_id = 1
    citation_map = {}  # To keep track of the citation to id mapping

    processed_corpus = []

    # Define the regex patterns
    citation_pattern = re.compile(r"\\cite[t,p]?{(.*?)}")
    command_pattern = re.compile(r"\\[a-zA-Z]+")
    comment_pattern = re.compile(r"%.*?$", re.MULTILINE)
    environment_pattern = re.compile(r"\\begin{(figure|table|equation).*?\\end{\1}", re.DOTALL)

    for document in tqdm(corpus):
        # Remove comments
        document = re.sub(comment_pattern, "", document)

        # Remove non-informative environments
        document = re.sub(environment_pattern, "", document)

        # Replace citation commands
        matches = citation_pattern.findall(document)
        for match in matches:
            citations = match.split(",")  # Handle multiple citations within one command
            for citation in citations:
                citation = citation.strip()
                
                # Assign a unique id to each citation if not already done
                if citation not in citation_map:
                    citation_map[citation] = f"<CITATION_{citation_id}>"
                    citation_id += 1
                
                # Replace the citation with the special token
                document = document.replace(citation, citation_map[citation])

        # Remove other commands
        # document = re.sub(command_pattern, "", document)
        document = document.replace('{<','<').replace('>}','>')
        processed_corpus.append(document)

    return processed_corpus, citation_map

# Test on a toy corpus
toy_corpus = [
    "The quick brown fox jumps over the lazy dog \\cite{khaled_corr19, Stich_corr19}. % This is a comment",
    "\\begin{figure} ... \\end{figure} The history of NLP is fascinating \\citep{Stich_iclr19, wang_corr18}.",
    "This is an \\textbf{interesting} study \\citet{yu_aaai19}.",
]

processed_corpus, citation_map = preprocess_latex_corpus(toy_corpus)

print(processed_corpus)
print(citation_map)


  0%|          | 0/3 [00:00<?, ?it/s]

['The quick brown fox jumps over the lazy dog \\cite<CITATION_1>, <CITATION_2>. ', ' The history of NLP is fascinating \\citep<CITATION_3>, <CITATION_4>.', 'This is an \\textbf{interesting} study \\citet<CITATION_5>.']
{'khaled_corr19': '<CITATION_1>', 'Stich_corr19': '<CITATION_2>', 'Stich_iclr19': '<CITATION_3>', 'wang_corr18': '<CITATION_4>', 'yu_aaai19': '<CITATION_5>'}


In [4]:
import os 
def read_tex_files(directory):
    corpus = []
    for root, dirs, files in tqdm(os.walk(directory)):
        try:
            for file_name in files:
                if file_name.endswith('.tex'):
                    # print(file_name)
                    file_path = os.path.join(root, file_name)
                    with open(file_path, 'r') as file:
                        corpus.append(file.read())
        except:
            pass
    return corpus 
corpus = read_tex_files('./sources/')


0it [00:00, ?it/s]

In [6]:
processed_corpus, citation_map = preprocess_latex_corpus(corpus[:100])
# print(processed_corpus[2])

  0%|          | 0/100 [00:00<?, ?it/s]

In [7]:
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer


class CitationDataset(Dataset):
    def __init__(self, text_list, tokenizer, seq_len=1024):
        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.inputs = []
        self.targets = []

        for text in text_list:
            self.process_text(text)

    def process_text(self, text):
        tokenized_text = self.tokenizer.encode(text)

        # Create sequences of the specified length with stride half of the length
        for i in tqdm(range(0, len(tokenized_text)-self.seq_len, self.seq_len//2),leave=True):
            input_sequence = tokenized_text[i:i+self.seq_len]
            target_sequence = tokenized_text[i+1:i+self.seq_len+1]  # Shifted right

            self.inputs.append(torch.tensor(input_sequence))
            self.targets.append(torch.tensor(target_sequence))

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.targets[idx]

# Add the special tokens to the pre-trained GPT-2 tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
special_tokens_dict = {'additional_special_tokens': list(citation_map.values())}
tokenizer.add_special_tokens(special_tokens_dict)


# Split the processed_corpus into training and validation sets
train_size = int(len(processed_corpus) * 0.8)
train_texts = processed_corpus[:train_size]
val_texts = processed_corpus[train_size:]

# Create the datasets
train_dataset = CitationDataset(train_texts, tokenizer)
val_dataset = CitationDataset(val_texts, tokenizer)

# Create the data loaders
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)


Token indices sequence length is longer than the specified maximum sequence length for this model (21926 > 1024). Running this sequence through the model will result in indexing errors


  0%|          | 0/41 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/26 [00:00<?, ?it/s]

  0%|          | 0/33 [00:00<?, ?it/s]

  0%|          | 0/33 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/80 [00:00<?, ?it/s]

  0%|          | 0/30 [00:00<?, ?it/s]

  0%|          | 0/12 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/7 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/165 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

  0%|          | 0/44 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/14 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/55 [00:00<?, ?it/s]

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/40 [00:00<?, ?it/s]

  0%|          | 0/54 [00:00<?, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

0it [00:00, ?it/s]

In [None]:
from transformers import GPT2LMHeadModel, AdamW
import torch

# Adjust the model vocabulary size
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.resize_token_embeddings(len(tokenizer))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = AdamW(model.parameters(), lr=1e-4)

# Training loop
model.train()
for epoch in range(100):  # Number of epochs
    for i, (inputs, targets) in tqdm(enumerate(train_loader),total=len(train_loader)):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs, labels=targets)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        if i % 10 == 0:
            print(f"Epoch: {epoch}, Iteration: {i}, Loss: {loss.item()}")
            
    # Validation loop
    model.eval()
    total_loss = 0
    for i, (inputs, targets) in enumerate(val_loader):
        inputs = inputs.to(device)
        targets = targets.to(device)

        with torch.no_grad():
            outputs = model(inputs, labels=targets)
            loss = outputs.loss
            total_loss += loss.item()

    print(f"Validation Loss: {total_loss / len(val_loader)}")


  0%|          | 0/624 [00:00<?, ?it/s]

Epoch: 0, Iteration: 0, Loss: 55.603580474853516
Epoch: 0, Iteration: 10, Loss: 7.273072242736816
Epoch: 0, Iteration: 20, Loss: 7.198564052581787
Epoch: 0, Iteration: 30, Loss: 5.793727874755859
Epoch: 0, Iteration: 40, Loss: 6.297300338745117
Epoch: 0, Iteration: 50, Loss: 5.869672775268555
Epoch: 0, Iteration: 60, Loss: 5.798858642578125
Epoch: 0, Iteration: 70, Loss: 5.601370334625244
Epoch: 0, Iteration: 80, Loss: 6.834584712982178
Epoch: 0, Iteration: 90, Loss: 5.522229194641113
Epoch: 0, Iteration: 100, Loss: 3.948033094406128
Epoch: 0, Iteration: 110, Loss: 5.019525051116943
Epoch: 0, Iteration: 120, Loss: 4.867307186126709
Epoch: 0, Iteration: 130, Loss: 6.743886947631836


In [52]:
len(corpus)

21806