# Reference Matching Pipeline

## Pipeline Steps:
1. Data Cleaning
2. Data Labelling (Manual + Automatic)
3. Feature Engineering
4. Data Modeling
5. Model Evaluation (MRR metric)

In [1]:
import re
import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
import warnings
warnings.filterwarnings('ignore')

# For string similarity
from difflib import SequenceMatcher
import Levenshtein

from sentence_transformers import SentenceTransformer

In [2]:
def parse_bib_file(bib_path: Path) -> Dict[str, Dict[str, str]]:
    entries = {}

    try:
        with open(bib_path, 'r', encoding='utf-8') as f:
            content = f.read()
        
        pos = 0
        while pos < len(content):
            # Find @entry{
            entry_start = content.find('@', pos)
            if entry_start == -1:
                break

            # Find entry type and key
            brace_start = content.find('{', entry_start)
            if brace_start == -1:
                break

            entry_type = content[entry_start+1:brace_start].strip()
            if not entry_type:
                pos = brace_start + 1
                continue

            # Find key (until comma or newline)
            comma_pos = content.find(',', brace_start)
            if comma_pos == -1:
                break

            key = content[brace_start+1:comma_pos].strip()

            # Find matching closing brace
            brace_count = 1
            field_start = comma_pos + 1
            pos = field_start

            while pos < len(content) and brace_count > 0:
                if content[pos] == '{':
                    brace_count += 1
                elif content[pos] == '}':
                    brace_count -= 1
                pos += 1
            
            if brace_count == 0:
                fields_text = content[field_start:pos-1]
                
                # Parse fields
                fields = {}
                # Match field = {value} or field = "value"
                field_pattern = re.compile(r'(\w+)\s*=\s*(\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}|"[^"]*"|\d+)', re.DOTALL)
                for field_match in field_pattern.finditer(fields_text):
                    field_name = field_match.group(1).lower()
                    field_value = field_match.group(2).strip()
                    
                    # Remove braces/quotes
                    if field_value.startswith('{') and field_value.endswith('}'):
                        field_value = field_value[1:-1]
                    elif field_value.startswith('"') and field_value.endswith('"'):
                        field_value = field_value[1:-1]
                    
                    fields[field_name] = field_value
                
                entries[key] = fields
        
    except Exception as e:
        print(f"Error parsing {bib_path}: {e}")
    
    return entries

def load_all_bibtex(output_dir: Path) -> Dict[str, Dict[str, Dict[str, str]]]:
    all_bibtex = {}

    for paper_dir in output_dir.iterdir():
        if not paper_dir.is_dir():
            continue

        arxiv_id = paper_dir.name
        bib_files = list(paper_dir.glob("*_bibtex.bib"))

        if bib_files:
            bib_path = bib_files[0]
            entries = parse_bib_file(bib_path)
            if entries:
                all_bibtex[arxiv_id] = entries
    
    return all_bibtex

In [3]:
def load_all_references(output_dir: Path) -> Dict[str, Dict[str, Dict[str, Any]]]:
    """Load all references.json files"""
    all_references = {}
    
    for paper_dir in output_dir.iterdir():
        if not paper_dir.is_dir():
            continue
        
        arxiv_id = paper_dir.name
        ref_path = paper_dir / "references.json"
        
        if ref_path.exists():
            try:
                with open(ref_path, 'r', encoding='utf-8-sig') as f:
                    references = json.load(f)
                all_references[arxiv_id] = references
            except Exception as e:
                print(f"Error loading {ref_path}: {e}")
    
    return all_references

In [4]:
def normalize_text(text: str) -> str:
    """Normalize text for comparison"""
    if not text:
        return ""
    
    # Lowercase
    text = text.lower()
    
    # Remove LaTeX commands
    text = re.sub(r'\\[a-zA-Z]+\{([^}]*)\}', r'\1', text)
    
    # Remove special characters but keep spaces
    text = re.sub(r'[{}]', '', text)
    text = re.sub(r'[^a-z0-9\s]', ' ', text)
    
    # Normalize whitespace
    text = re.sub(r'\s+', ' ', text).strip()
    
    return text

def extract_authors(author_str: str) -> List[str]:
    """Extract individual authors from author string"""
    if not author_str:
        return []
    
    # Split by common delimiters
    authors = re.split(r'\s+and\s+|,|&', author_str.lower())
    authors = [a.strip() for a in authors if a.strip()]
    
    return authors

def get_first_author_last_name(author_str: str) -> str:
    """Extract last name of first author"""
    authors = extract_authors(author_str)
    if not authors:
        return ""
    
    first_author = authors[0]
    # Last name is typically the last word
    parts = first_author.split()
    if len(parts) > 0:
        return parts[-1]
    return first_author

def normalize_year(year_str: str) -> Optional[int]:
    """Extract year as integer"""
    if not year_str:
        return None
    
    # Extract 4-digit year
    match = re.search(r'\b(19|20)\d{2}\b', str(year_str))
    if match:
        return int(match.group(0))
    return None

In [5]:
embedder = SentenceTransformer('all-MiniLM-L6-v2')
_embedding_cache = {}

def get_embedding(text: str) -> np.ndarray:
    if text not in _embedding_cache:
        _embedding_cache[text] = embedder.encode(
            text, convert_to_numpy=True, normalize_embeddings=True
        )
    return _embedding_cache[text]


def extract_features(bibtex_entry: Dict[str, str], reference_entry: Dict[str, Any]) -> Dict[str, float]:
    """Extract matching features between BibTeX and reference entry"""
    features = {}

    bib_title = normalize_text(bibtex_entry.get('title', ''))
    ref_title = normalize_text(reference_entry.get('paper_title', ''))

    features['title_exact_match'] = int(bib_title == ref_title)
    features['title_ratio'] = SequenceMatcher(None, bib_title, ref_title).ratio()

    try:
        features['title_levenshtein'] = 1.0 - (Levenshtein.distance(bib_title, ref_title) / max(len(bib_title), len(ref_title), 1))
    except:
        features['title_levenshtein'] = features['title_ratio']

    # Word overlap
    bib_words = set(bib_title.split())
    ref_words = set(ref_title.split())
    if bib_words or ref_words:
        features['title_jaccard'] = len(bib_words & ref_words) / len(bib_words | ref_words) if (bib_words | ref_words) else 0.0
        features['title_word_overlap'] = len(bib_words & ref_words) / max(len(bib_words), 1)
    else:
        features['title_jaccard'] = 0.0
        features['title_word_overlap'] = 0.0
    
    # Author features
    bib_author = bibtex_entry.get('author', '')
    ref_authors = reference_entry.get('authors', [])
    
    bib_authors_list = extract_authors(bib_author)
    ref_authors_normalized = [normalize_text(a) for a in ref_authors]

    # First author match
    if bib_authors_list and ref_authors_normalized:
        bib_first_last = get_first_author_last_name(bib_author)
        ref_first_parts = ref_authors_normalized[0].split()
        ref_first_last = ref_first_parts[-1] if ref_first_parts else ""
        features['first_author_lastname_match'] = 1.0 if bib_first_last == ref_first_last else 0.0
    else:
        features['first_author_lastname_match'] = 0.0

    # Author overlap
    bib_author_set = set(bib_authors_list)
    ref_author_set = set(ref_authors_normalized)
    if bib_author_set or ref_author_set:
        features['author_jaccard'] = len(bib_author_set & ref_author_set) / max(len(bib_author_set | ref_author_set), 1)
    else:
        features['author_jaccard'] = 0.0

    # Year features
    bib_year = normalize_year(bibtex_entry.get('year', ''))
    ref_year = normalize_year(reference_entry.get('submission_date', ''))

    if bib_year and ref_year:
        year_gap = abs(bib_year - ref_year)
        features['year_similarity'] = np.exp(-(year_gap ** 2) / 2)
    else:
        features['year_similarity'] = 0.0

    # Combined features
    features['title_author_score'] = (features['title_ratio'] + features['author_jaccard']) / 2.0
    features['composite_score'] = (
        features['title_ratio'] * 0.4 +
        features['author_jaccard'] * 0.3 +
        features['year_similarity'] * 0.2 +
        features['first_author_lastname_match'] * 0.1
    )

    # Embedding similarity
    if bib_title and ref_title:
        try:
            bib_emb = get_embedding(bib_title)
            ref_emb = get_embedding(ref_title)
            # Cosine similarity
            features["title_embedding_sim"] = float(np.dot(bib_emb, ref_emb))
        except:
            features['title_embedding_sim'] = 0.0
    else:
        features['title_embedding_sim'] = 0.0
    
    return features

In [6]:
# Load data
OUTPUT_DIR = Path("23127238_output")

all_bibtex = load_all_bibtex(OUTPUT_DIR)
print(f"Loaded BibTeX from {len(all_bibtex)} papers")

all_references = load_all_references(OUTPUT_DIR)
print(f"Loaded references from {len(all_references)} papers")

# Filter papers that have both BibTeX and references
papers_with_both = set(all_bibtex.keys()) & set(all_references.keys())
print(f"\nPapers with both BibTeX and references: {len(papers_with_both)}")


Loaded BibTeX from 2107 papers
Loaded references from 2893 papers

Papers with both BibTeX and references: 2104


In [7]:
def match_bibtex_to_references(bibtex_entry: Dict[str, str],
                                  references: Dict[str, Dict[str, Any]]) -> Optional[str]:
    """Match a single BibTeX entry to best reference"""
    best_match = None
    best_score = -1.0
    
    for arxiv_id, ref_entry in references.items():
        features = extract_features(bibtex_entry, ref_entry)
        
        # Use composite score
        score = features['composite_score']
        
        # Additional constraints
        if features['title_ratio'] < 0.3:  # Too different titles
            continue
        
        if score > best_score:
            best_score = score
            best_match = arxiv_id
    
    # Only return if score is above threshold
    if best_score >= 0.5:
        return best_match
    
    return None

In [15]:
with open('manual_label.json', 'r', encoding='utf-8') as f:
    manual_content = json.load(f)

manual_labels = {}
for paper_id, refs in manual_content.items():
    paper_labels = {}
    for bib_key, arxiv_id in refs.items():
        paper_labels[bib_key] = arxiv_id

    if paper_labels:
        manual_labels[paper_id] = paper_labels
        print(f"  {paper_id}: {len(paper_labels)} labels")

print(f"\nTotal manual labels: {sum(len(labels) for labels in manual_labels.values())}")

  2304-14610: 20 labels
  2304-14656: 32 labels
  2304-14693: 22 labels
  2304-14796: 22 labels
  2304-14999: 22 labels

Total manual labels: 118


In [9]:
# Automatic labeling for remaining papers (at least 10%)
remaining_papers = papers_with_both - set(manual_labels.keys())
num_auto_papers = max(1, int(len(remaining_papers) * 0.1))
auto_papers = list(remaining_papers)[:num_auto_papers]

print(f"Creating automatic labels for {len(auto_papers)} papers...")

automatic_labels = {}
for paper_id in auto_papers:
    if paper_id not in all_bibtex or paper_id not in all_references:
        continue
    
    bibtex_entries = all_bibtex[paper_id]
    references = all_references[paper_id]
    
    paper_labels = {}
    for bib_key, bib_entry in bibtex_entries.items():
        match = match_bibtex_to_references(bib_entry, references)
        if match:
            paper_labels[bib_key] = match
    
    if paper_labels:
        automatic_labels[paper_id] = paper_labels
        print(f"  {paper_id}: {len(paper_labels)} labels")

print(f"\nTotal automatic labels: {sum(len(labels) for labels in automatic_labels.values())}")

Creating automatic labels for 209 papers...
  2305-00836: 4 labels
  2305-00708: 10 labels
  2305-02504: 24 labels
  2305-00681: 10 labels
  2305-00909: 40 labels
  2305-01532: 2 labels
  2305-01290: 4 labels
  2305-01145: 9 labels
  2305-01505: 23 labels
  2305-00982: 15 labels
  2305-02389: 10 labels
  2305-00477: 31 labels
  2305-00281: 9 labels
  2305-00755: 12 labels
  2305-01849: 4 labels
  2305-00184: 3 labels
  2305-01694: 51 labels
  2305-00872: 10 labels
  2305-01454: 8 labels
  2305-01510: 3 labels
  2305-01121: 24 labels
  2304-14793: 5 labels
  2305-00071: 24 labels
  2305-00860: 2 labels
  2305-01777: 26 labels
  2305-01879: 22 labels
  2305-00975: 5 labels
  2305-00607: 38 labels
  2305-01989: 2 labels
  2305-01549: 33 labels
  2305-01795: 54 labels
  2305-01191: 19 labels
  2304-14942: 16 labels
  2304-14631: 1 labels
  2305-02288: 1 labels
  2305-01151: 12 labels
  2305-01624: 37 labels
  2305-00977: 12 labels
  2305-02480: 1 labels
  2305-01741: 11 labels
  2305-01157

In [10]:
automatic_labels = dict(
    sorted(automatic_labels.items(), key=lambda x: x[0])
)

with open('automatic_label.json', 'w', encoding='utf-8') as f:
    json.dump(
        automatic_labels,
        f,
        ensure_ascii=False,
        indent=2
    )

In [11]:
def create_training_data(labels_dict: Dict[str, Dict[str, str]], 
                        all_bibtex: Dict, all_references: Dict,
                        negative_ratio: float = 1.0) -> Tuple[np.ndarray, np.ndarray]:
    """
    Create training data with positive and negative examples.
    """
    X = []
    y = []
    
    # Positive examples (matches)
    for paper_id, labels in labels_dict.items():
        if paper_id not in all_bibtex or paper_id not in all_references:
            continue
        
        bibtex_entries = all_bibtex[paper_id]
        references = all_references[paper_id]
        
        for bib_key, correct_arxiv_id in labels.items():
            if bib_key not in bibtex_entries:
                continue
            if correct_arxiv_id not in references:
                continue
            
            bib_entry = bibtex_entries[bib_key]
            ref_entry = references[correct_arxiv_id]
            
            features = extract_features(bib_entry, ref_entry)
            X.append(list(features.values()))
            y.append(1)  # Positive match
            
            # Negative examples (random incorrect matches)
            num_negatives = int(negative_ratio)
            other_arxiv_ids = [aid for aid in references.keys() if aid != correct_arxiv_id]
            
            if other_arxiv_ids:
                np.random.seed(42)  # For reproducibility
                for _ in range(min(num_negatives, len(other_arxiv_ids))):
                    wrong_arxiv_id = np.random.choice(other_arxiv_ids)
                    wrong_ref = references[wrong_arxiv_id]
                    
                    neg_features = extract_features(bib_entry, wrong_ref)
                    X.append(list(neg_features.values()))
                    y.append(0)  # Negative match
    
    return np.array(X), np.array(y)

In [12]:
# Split data according to requirements:
# Test: 1 manual + 1 automatic
# Validation: 1 manual + 1 automatic  
# Train: rest

manual_paper_list = list(manual_labels.keys())
auto_paper_list = list(automatic_labels.keys())

test_manual = manual_paper_list[0] if len(manual_paper_list) > 0 else None
test_auto = auto_paper_list[0] if len(auto_paper_list) > 0 else None

val_manual = manual_paper_list[1] if len(manual_paper_list) > 1 else None
val_auto = auto_paper_list[1] if len(auto_paper_list) > 1 else None

train_manual = [p for p in manual_paper_list if p not in [test_manual, val_manual]]
train_auto = [p for p in auto_paper_list if p not in [test_auto, val_auto]]

print("Data split:")
print(f"  Test: {test_manual}, {test_auto}")
print(f"  Validation: {val_manual}, {val_auto}")
print(f"  Train: {len(train_manual)} manual, {len(train_auto)} auto")

Data split:
  Test: 2304-14610, 2304-14631
  Validation: 2304-14656, 2304-14658
  Train: 3 manual, 141 auto


In [13]:
# Prepare training data
train_papers = {p: manual_labels[p] for p in train_manual if p in manual_labels}
train_papers.update({p: automatic_labels[p] for p in train_auto if p in automatic_labels})

X_train_full, y_train_full = create_training_data(
    train_papers, all_bibtex, all_references, negative_ratio=2.0
)

print(f"Full training set: {len(X_train_full)} examples")

# Scale features
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train_full)

# Train multiple models
models = {
    'random_forest': RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1),
    'gradient_boosting': GradientBoostingClassifier(n_estimators=100, random_state=42),
    'logistic_regression': LogisticRegression(random_state=42, max_iter=1000)
}

trained_models = {}
max_acc = 0
name_max = ''
for name, model in models.items():
    print(f"\nTraining {name}...")
    model.fit(X_train_scaled, y_train_full)
    trained_models[name] = model
    
    # Quick validation on training data
    y_pred = model.predict(X_train_scaled)
    acc = accuracy_score(y_train_full, y_pred)
    if acc > max_acc:
        max_acc = acc
        name_max = name

    print(f"  Training accuracy: {acc:.4f}")


best_model = trained_models[name_max]
print(f"\nSelected model: {best_model}")

Full training set: 9035 examples

Training random_forest...
  Training accuracy: 0.9990

Training gradient_boosting...
  Training accuracy: 0.9989

Training logistic_regression...
  Training accuracy: 0.9942

Selected model: RandomForestClassifier(n_jobs=-1, random_state=42)


In [14]:
def calculate_mrr(model, scaler,
                 bibtex_entries: Dict, references: Dict,
                 ground_truth: Dict[str, str], top_k: int = 5) -> float:
    """
    Calculate MRR for a single paper.
    
    Returns:
        MRR score (0.0 to 1.0)
    """
    if not ground_truth:
        return 0.0
    
    reciprocal_ranks = []
    
    for bib_key, correct_arxiv_id in ground_truth.items():
        if bib_key not in bibtex_entries:
            continue
        
        bib_entry = bibtex_entries[bib_key]
        
        # Get scores for all candidates
        candidate_scores = []
        for arxiv_id, ref_entry in references.items():
            features = extract_features(bib_entry, ref_entry)
            X = np.array([list(features.values())])
            X_scaled = scaler.transform(X)
            
            # Get probability of positive class
            if hasattr(model, 'predict_proba'):
                score = model.predict_proba(X_scaled)[0][1]
            else:
                score = model.decision_function(X_scaled)[0]
            
            candidate_scores.append((arxiv_id, score))
        
        # Sort by score descending
        candidate_scores.sort(key=lambda x: x[1], reverse=True)
        
        # Find rank of correct answer
        ranked_arxiv_ids = [aid for aid, _ in candidate_scores[:top_k]]
        
        if correct_arxiv_id in ranked_arxiv_ids:
            rank = ranked_arxiv_ids.index(correct_arxiv_id) + 1
            reciprocal_ranks.append(1.0 / rank)
        else:
            reciprocal_ranks.append(0.0)
    
    if not reciprocal_ranks:
        return 0.0
    
    return np.mean(reciprocal_ranks)


# Evaluate on test set
test_papers = {}
if test_manual:
    test_papers[test_manual] = manual_labels[test_manual]
if test_auto:
    test_papers[test_auto] = automatic_labels[test_auto]

print("Evaluating on test set...")
test_mrr_scores = []

for paper_id, ground_truth in test_papers.items():
    if paper_id not in all_bibtex or paper_id not in all_references:
        continue
    
    mrr = calculate_mrr(
        best_model, scaler,
        all_bibtex[paper_id], all_references[paper_id],
        ground_truth, top_k=5
    )
    
    test_mrr_scores.append(mrr)
    print(f"  {paper_id}: MRR = {mrr:.4f}")

overall_mrr = np.mean(test_mrr_scores) if test_mrr_scores else 0.0
print(f"\nOverall Test MRR: {overall_mrr:.4f}")


Evaluating on test set...
  2304-14610: MRR = 1.0000
  2304-14631: MRR = 1.0000

Overall Test MRR: 1.0000


In [15]:
def generate_predictions(model, scaler,
                        bibtex_entries: Dict, references: Dict,
                        top_k: int = 5) -> Dict[str, List[str]]:
    """Generate ranked predictions for all BibTeX entries"""
    predictions = {}
    
    for bib_key, bib_entry in bibtex_entries.items():
        candidate_scores = []
        
        for arxiv_id, ref_entry in references.items():
            features = extract_features(bib_entry, ref_entry)
            X = np.array([list(features.values())])
            X_scaled = scaler.transform(X)
            
            if hasattr(model, 'predict_proba'):
                score = model.predict_proba(X_scaled)[0][1]
            else:
                score = model.decision_function(X_scaled)[0]
            
            candidate_scores.append((arxiv_id, score))
        
        # Sort and take top k
        candidate_scores.sort(key=lambda x: x[1], reverse=True)
        ranked_candidates = [aid for aid, _ in candidate_scores[:top_k]]
        
        predictions[bib_key] = ranked_candidates
    
    return predictions


# Generate predictions for all papers in train/val/test sets
all_labeled_papers = set(manual_labels.keys()) | set(automatic_labels.keys())

print("Generating predictions...")

for paper_id in all_labeled_papers:
    if paper_id not in all_bibtex or paper_id not in all_references:
        continue
    
    # Determine partition
    if paper_id == test_manual or paper_id == test_auto:
        partition = "test"
        ground_truth = manual_labels.get(paper_id, {}) or automatic_labels.get(paper_id, {})
    elif paper_id == val_manual or paper_id == val_auto:
        partition = "valid"
        ground_truth = manual_labels.get(paper_id, {}) or automatic_labels.get(paper_id, {})
    else:
        partition = "train"
        ground_truth = manual_labels.get(paper_id, {}) or automatic_labels.get(paper_id, {})
    
    # Generate predictions
    predictions = generate_predictions(
        best_model, scaler,
        all_bibtex[paper_id], all_references[paper_id],
        top_k=5
    )
    
    # Save to pred.json
    output_path = OUTPUT_DIR / paper_id / "pred.json"
    output_data = {
        "partition": partition,
        "groundtruth": ground_truth,
        "prediction": predictions
    }
    
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(output_data, f, indent=2, ensure_ascii=False)
    
    print(f"  Saved {paper_id} ({partition}): {len(predictions)} predictions")

print("\nAll predictions saved!")


Generating predictions...
  Saved 2305-00836 (train): 21 predictions
  Saved 2305-00256 (train): 50 predictions
  Saved 2305-01831 (train): 53 predictions
  Saved 2305-01274 (train): 97 predictions
  Saved 2304-14807 (train): 56 predictions
  Saved 2305-00708 (train): 15 predictions
  Saved 2305-02504 (train): 38 predictions
  Saved 2305-00681 (train): 15 predictions
  Saved 2305-00231 (train): 31 predictions


KeyboardInterrupt: 