In [3]:
from utils import * 

In [15]:
# Standard library imports
import re
from dataclasses import dataclass
from typing import Optional

# Third-party imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    PreTrainedModel,
    PretrainedConfig
)

@dataclass
class ExperimentConfig:
    """Settings shared by the tokenizer, the model config and inference.

    Leave ``device`` as ``None`` to have it resolved at construction time:
    CUDA is chosen when ``torch.cuda.is_available()`` is true, otherwise CPU.
    """
    # Hugging Face checkpoint name; used to load the tokenizer and passed on
    # as the base model name of the citation model.
    model_name: str = "bert-base-uncased"
    # Not read in this part of the file.
    max_length: int = 512
    # Query (source) sequences are truncated/padded to this many tokens.
    source_len: int = 512
    # Candidate (target) sequences are truncated/padded to this many tokens,
    # including the trailing reference token.
    target_len: int = 128
    # Not read in this part of the file.
    max_targets: int = 5
    # Not read in this part of the file.
    overlap: float = 0.5
    # Special token that replaces each citation span in the source.
    cite_token: str = "<CITE>"
    # Special token appended to every target text.
    ref_token: str = "<REF>"
    # Forwarded to the model config; similarity logits are divided by it.
    temperature: float = 0.07
    # Not read in this part of the file.
    collate_sample_size: int = 5000
    # Torch device; resolved in __post_init__ when left as None.
    device: Optional[torch.device] = None

    def __post_init__(self):
        # An explicitly supplied device always wins.
        if self.device is not None:
            return
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(backend)


def _encode_citation_query(query_text, citations, tokenizer, config):
    """Tokenize the query and collapse every ``[[...]]`` span to one cite token.

    Returns a 1-D integer array of exactly ``config.source_len`` token ids.
    """
    encoded = tokenizer.encode_plus(
        query_text,
        add_special_tokens=False,
        return_offsets_mapping=True,
        padding=False,
        return_tensors=None,
    )

    # Character offset -> token index. A token's start offset maps to the
    # token itself, its end offset maps to the following token, so a
    # (start, end) character span becomes a half-open token slice.
    offsets = encoded["offset_mapping"]
    off2i = {start: i for i, (start, _) in enumerate(offsets)}
    off2i.update({end: i + 1 for i, (_, end) in enumerate(offsets)})

    # Explicit dtype: an empty id list would otherwise become a float array.
    input_ids = np.array(encoded["input_ids"], dtype=np.int64)

    # Every literal bracket token is dropped, whether or not it belongs to a
    # citation marker.
    dropped = np.isin(input_ids, tokenizer.convert_tokens_to_ids(['[', ']']))
    is_cite = np.zeros(len(input_ids), dtype=bool)
    for start_char, end_char, _ in citations:
        first, stop = off2i[start_char], off2i[end_char]
        is_cite[first] = True
        dropped[first:stop] = True

    # The first token of each span survives and is overwritten with the cite
    # token; the rest of the span (brackets and citation text) is removed.
    dropped[is_cite] = False
    input_ids[is_cite] = tokenizer.convert_tokens_to_ids(config.cite_token)

    source_ids = input_ids[~dropped][:config.source_len]
    missing = config.source_len - len(source_ids)
    if missing > 0:
        source_ids = np.pad(
            source_ids,
            (0, missing),
            'constant',
            constant_values=tokenizer.pad_token_id,
        )
    return source_ids


def _encode_citation_targets(target_texts, tokenizer, config):
    """Encode each target as ``tokens + [ref_token] + padding`` of ``config.target_len`` ids."""
    ref_id = tokenizer.convert_tokens_to_ids(config.ref_token)
    pad_id = tokenizer.pad_token_id
    # One slot is reserved for the trailing reference token.
    body_len = config.target_len - 1

    rows = []
    for text in target_texts:
        token_ids = tokenizer.encode_plus(
            text,
            add_special_tokens=False,
            padding=False,
            return_tensors=None,
        )["input_ids"]
        row = list(token_ids[:body_len]) + [ref_id]
        row.extend([pad_id] * (config.target_len - len(row)))
        rows.append(row)
    return rows


def retrieve_citations(
    model,
    query_text: str,
    target_texts: list,
    tokenizer,
    config,
    k: int = 5,
    device=None
):
    """
    Retrieve the top k candidate articles for each citation in a query document.

    Citations are the ``[[...]]`` spans found in ``query_text``. Each span is
    replaced by a single ``config.cite_token``; the model's hidden state at
    that token is compared (cosine similarity divided by
    ``model.config.temperature``) with the hidden state at the
    ``config.ref_token`` appended to every target text.

    Args:
        model: Trained CitationModel. This function uses ``model.transformer``,
            ``model.get_citation_masks``, ``model.get_reference_masks`` and
            ``model.config.temperature``. The model is switched to eval mode
            and moved to ``device``.
        query_text: Source text with citations marked as ``[[citation]]``.
        target_texts: List of candidate target article texts.
        tokenizer: Tokenizer used by the model. It must support
            ``return_offsets_mapping`` and already contain the cite/ref
            special tokens.
        config: Configuration object providing ``source_len``, ``target_len``,
            ``cite_token`` and ``ref_token``.
        k: Number of top articles to retrieve per citation (capped at the
            number of candidates).
        device: Device to run inference on; defaults to CUDA when available,
            otherwise CPU.

    Returns:
        A list with one dict per citation, in document order::

            {'citation_text': str,
             'matches': [{'text': str, 'score': float}, ...]}

        ``score`` is a softmax probability over *all* candidates, so scores
        are relative to the candidate pool. Citations that lie beyond the
        first ``config.source_len`` source tokens are truncated away and are
        omitted from the result (a warning is emitted). If ``query_text``
        holds no citations the result is ``[]``; if ``target_texts`` is empty
        every citation gets an empty ``matches`` list.

    Raises:
        KeyError: If a citation boundary does not coincide with a token
            boundary in the tokenizer's offset mapping.
    """
    import warnings

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    model = model.to(device)

    # Locate citation spans as (start_char, end_char, inner_text).
    citations = [
        (match.start(), match.end(), match.group(1))
        for match in re.finditer(r'\[\[(.*?)\]\]', query_text)
    ]

    # Nothing to rank: skip tokenization and the forward passes entirely.
    if not citations:
        return []
    if not target_texts:
        return [
            {'citation_text': citation_text, 'matches': []}
            for _, _, citation_text in citations
        ]

    source_ids = torch.tensor(
        _encode_citation_query(query_text, citations, tokenizer, config),
        dtype=torch.long,
    ).unsqueeze(0).to(device)
    target_ids = torch.tensor(
        _encode_citation_targets(target_texts, tokenizer, config),
        dtype=torch.long,
    ).to(device)

    # Padding positions are excluded from attention.
    attention_mask = source_ids != tokenizer.pad_token_id
    target_attention_mask = target_ids != tokenizer.pad_token_id

    with torch.no_grad():
        source_outputs = model.transformer(
            input_ids=source_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        target_outputs = model.transformer(
            input_ids=target_ids,
            attention_mask=target_attention_mask,
            return_dict=True
        )

        # One embedding per cite token in the source and one per ref token
        # in the targets.
        # NOTE(review): assumes the masks select rows in document / input
        # order - confirm against CitationModel.
        cite_mask = model.get_citation_masks(source_ids)
        cite_embeds = source_outputs.last_hidden_state[cite_mask]

        ref_mask = model.get_reference_masks(target_ids)
        ref_embeds = target_outputs.last_hidden_state[ref_mask]

        # L2-normalise so the dot product is a cosine similarity.
        cite_embeds = F.normalize(cite_embeds, p=2, dim=-1)
        ref_embeds = F.normalize(ref_embeds, p=2, dim=-1)

        logits = torch.matmul(cite_embeds, ref_embeds.t()) / model.config.temperature
        scores = F.softmax(logits, dim=-1)

        # Never ask for more entries than there are candidate columns.
        top_k_scores, top_k_indices = torch.topk(
            scores, k=min(k, scores.shape[-1]), dim=1
        )

    score_rows = top_k_scores.tolist()
    index_rows = top_k_indices.tolist()

    # Truncation to source_len can remove trailing cite tokens, leaving fewer
    # embedding rows than parsed citations. Pair them up in order instead of
    # indexing by citation number, which would run past the last row.
    if len(index_rows) < len(citations):
        warnings.warn(
            f"Only {len(index_rows)} of {len(citations)} citations produced an "
            f"embedding (the rest were likely cut off by truncation to "
            f"source_len={config.source_len}); returning results for the "
            f"first {len(index_rows)}."
        )

    results = []
    for (_, _, citation_text), idx_row, score_row in zip(citations, index_rows, score_rows):
        results.append({
            'citation_text': citation_text,
            'matches': [
                {'text': target_texts[idx], 'score': score}
                for idx, score in zip(idx_row, score_row)
            ]
        })

    return results

def print_citation_results(results, max_preview_length=200):
    """
    Pretty-print the output of retrieve_citations to stdout.

    For every citation a header line is printed, followed by one numbered
    entry per match (score with four decimals plus a text preview), and a
    line of 80 '=' characters closes the citation.

    Args:
        results: List of dicts with 'citation_text' and 'matches' keys; each
            match is a dict with 'text' and 'score'
        max_preview_length: Number of characters of each match text to show;
            longer texts are cut and suffixed with "..."
    """
    divider = "\n" + "="*80
    for citation_no, entry in enumerate(results, start=1):
        print(f"\nCitation {citation_no}: [[{entry['citation_text']}]]")
        print("\nTop matches:")
        for rank, candidate in enumerate(entry['matches'], start=1):
            full_text = candidate['text']
            was_cut = len(full_text) > max_preview_length
            preview = full_text[:max_preview_length] + ("..." if was_cut else "")
            print(f"\n{rank}. Score: {candidate['score']:.4f}")
            print(f"Preview: {preview}")
        print(divider)


# Build the tokenizer and model configuration, then restore the trained weights.
config = ExperimentConfig(collate_sample_size=50000)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
# Register the citation/reference markers as special tokens so each one maps
# to a single id; their ids and the enlarged vocabulary size are handed to the
# model config below.
tokenizer.add_special_tokens({
    'additional_special_tokens': [config.cite_token, config.ref_token]
})
model_config = CitationConfig(
    base_model_name=config.model_name,
    vocab_size=len(tokenizer),
    cite_token_id=tokenizer.convert_tokens_to_ids(config.cite_token),
    ref_token_id=tokenizer.convert_tokens_to_ids(config.ref_token),
    temperature=config.temperature,
)

# Initialize model
model = CitationModel(model_config)
# map_location='cpu' lets a checkpoint that was saved from a GPU run load on a
# machine without CUDA; retrieve_citations moves the model to the inference
# device later.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. If this checkpoint holds nothing but tensors and plain
# Python containers, pass weights_only=True - confirm against the file.
checkpoint = torch.load(
    './experiments/best_citation_model_backup.pt',
    map_location='cpu',
)
model.load_state_dict(checkpoint['model_state_dict'])


  checkpoint = torch.load('./experiments/best_citation_model_backup.pt')


<All keys matched successfully>

In [17]:
# Candidate pool for the retrieval demo: short paper summaries, hand-grouped
# below from most to least related to the transformer paper cited in the
# example query. The grouping is only for the reader; the list itself is flat
# and unlabelled.

target_texts = [
    # Directly related (original paper and closest variants)
    """Attention Is All You Need introduces the transformer architecture, a novel sequence transduction model based entirely on attention mechanisms, dispensing with recurrence and convolutions entirely. The proposed model, called the Transformer, applies self-attention to compute representations of its input and output without using sequence-aligned recurrent neural networks (RNNs) or convolution. Experiments on translation tasks demonstrate superior quality while being more parallelizable and requiring significantly less time to train.""",
    
    """BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding presents a new language representation model that uses bidirectional training of Transformer, a popular attention model, to pre-train deep bidirectional representations from unlabeled text. BERT achieves state-of-the-art performance on eleven natural language processing tasks.""",
    
    """GPT-3: Language Models are Few-Shot Learners demonstrates that scaling up language models greatly improves task-agnostic, few-shot performance. Using a transformer architecture with 175 billion parameters, GPT-3 achieves strong performance on many NLP tasks and benchmarks without any fine-tuning, sometimes matching or exceeding state-of-the-art performance.""",
    
    # Loosely related (discussing attention or transformers in different contexts)
    """Neural Machine Translation by Jointly Learning to Align and Translate introduces an attention mechanism that allows a model to automatically (soft-)search for parts of a source sentence that are relevant to predicting a target word, without having to form these parts as a hard segment explicitly. This approach achieves significant improvements in translation performance.""",
    
    """XLNet: Generalized Autoregressive Pretraining for Language Understanding proposes a generalized autoregressive pretraining method that enables learning bidirectional contexts by maximizing the expected likelihood over all permutations of the factorization order. Additionally, XLNet integrates ideas from Transformer-XL into pretraining.""",
    
    """T5: Exploring the Limits of Transfer Learning presents a unified framework that converts all text-based language problems into a text-to-text format. Using this framework, we study different pre-training objectives, architectures, unlabeled datasets, transfer approaches, and other factors on dozens of language understanding tasks.""",
    
    # Remotely related (general ML/DL papers)
    """Deep Residual Learning for Image Recognition introduces ResNet, which explicitly reformulates the layers as learning residual functions with reference to the layer inputs, instead of learning unreferenced functions. These networks can be substantially deeper, leading to improved performance on visual recognition tasks.""",
    
    """Adam: A Method for Stochastic Optimization presents a method for efficient stochastic optimization that only requires first-order gradients with little memory requirement. The method computes individual adaptive learning rates for different parameters from estimates of first and second moments of the gradients.""",
    
    """Dropout: A Simple Way to Prevent Neural Networks from Overfitting presents dropout, a technique where randomly selected neurons are ignored during training. This prevents units from co-adapting too much by randomly dropping out a proportion of the hidden units on each presentation of each training case.""",
    
    # Not directly related (different domain or focus)
    """AlphaGo: Mastering the Game of Go with Deep Neural Networks and Tree Search combines Monte Carlo tree search with deep neural networks that have been trained by supervised learning, followed by reinforcement learning. This approach achieves a high winning rate against other Go programs and defeated a human professional player.""",
    
    """Bitcoin: A Peer-to-Peer Electronic Cash System presents the original design for Bitcoin, a purely peer-to-peer version of electronic cash that allows online payments to be sent directly from one party to another without going through a financial institution.""",
    
    """ImageNet Classification with Deep Convolutional Neural Networks introduces AlexNet, a large, deep convolutional neural network that achieved record-breaking results in the ImageNet Large Scale Visual Recognition Challenge. The network was trained on two GPUs and incorporated several novel features.""",
    
    """YOLO: Real-Time Object Detection explains the You Only Look Once (YOLO) system, a new approach to object detection. Prior work on object detection repurposes classifiers to perform detection, but YOLO frames object detection as a regression problem to spatially separated bounding boxes and associated class probabilities.""",
    
    """Generative Adversarial Nets presents an adversarial process for estimating generative models via a framework where two models are trained simultaneously: a generative model G that captures the data distribution, and a discriminative model D that estimates the probability that a sample came from the training data rather than G."""
]

# Demo query: a single passage whose one citation is marked with [[...]].
query_text = "The transformer architecture has revolutionized NLP by introducing a model based entirely on attention mechanisms and outperformed the existing benchmarks on a variety of tasks, this was studied in [[Attention is All You Need]],"

# Rank the candidate pool against the citation. k=20 exceeds the number of
# candidates, so every candidate is returned, ordered by score.
results = retrieve_citations(model, query_text, target_texts, tokenizer, config, k=20)

# Display the ranked candidates.
print_citation_results(results)


Citation 1: [[Attention is All You Need]]

Top matches:

1. Score: 0.2209
Preview: ImageNet Classification with Deep Convolutional Neural Networks introduces AlexNet, a large, deep convolutional neural network that achieved record-breaking results in the ImageNet Large Scale Visual ...

2. Score: 0.1451
Preview: BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding presents a new language representation model that uses bidirectional training of Transformer, a popular attention model...

3. Score: 0.1305
Preview: Attention Is All You Need introduces the transformer architecture, a novel sequence transduction model based entirely on attention mechanisms, dispensing with recurrence and convolutions entirely. The...

4. Score: 0.1077
Preview: Generative Adversarial Nets presents an adversarial process for estimating generative models via a framework where two models are trained simultaneously: a generative model G that captures the data di...

5. Score: 0.0816
P