In [1]:
import sys
import json
import re
import random
from pathlib import Path
from typing import Dict, List
from dataclasses import dataclass, field

sys.path.insert(0, str(Path.cwd().parent))
from matching import BibEntry, RefEntry, TextCleaner

## 1. Configuration

In [2]:
# Configuration
DATA_DIR = Path("../../23120260")
OUTPUT_DIR = Path("../../output")
OUTPUT_DIR.mkdir(exist_ok=True)

# Sampling configuration
SAMPLE_SIZE = 1500
RANDOM_SEED = 42
MANUAL_PUBS = {"2411-00222", "2411-00223", "2411-00225", "2411-00226", "2411-00227"}

print(f"Data: {DATA_DIR}, Output: {OUTPUT_DIR}")

Data: ..\..\23120260, Output: ..\..\output


## 2. BibTeX Extractor

In [3]:
class BibExtractor:
    """Extract BibTeX entries from LaTeX sources"""
    
    @staticmethod
    def extract_bibitems(tex_content: str) -> List[BibEntry]:
        entries = []
        pattern = r'\\bibitem(?:\[([^\]]*)\])?\{([^}]+)\}(.*?)(?=\\bibitem|\\end\{thebibliography\}|\Z)'
        
        for match in re.finditer(pattern, tex_content, re.DOTALL):
            key = match.group(2).strip()
            content = match.group(3).strip()
            entries.append(BibExtractor._parse_bibitem_content(key, content))
        return entries
    
    @staticmethod
    def _parse_bibitem_content(key: str, content: str) -> BibEntry:
        content = re.sub(r'\\newblock\s*', ' ', content)
        content = re.sub(r'\{\\em\s+([^}]*)\}', r'\1', content)
        content = re.sub(r'\\emph\{([^}]*)\}', r'\1', content)
        
        arxiv_id = TextCleaner.extract_arxiv_id(content) or ""
        year = TextCleaner.extract_year(content) or ""
        
        title = ""
        title_match = re.search(r'\{\\em\s+([^}]+)\}', content)
        if title_match:
            title = title_match.group(1)
        elif '.' in content:
            title = content.split('.')[1].strip()
        
        authors = []
        author_part = content.split('.')[0] if '.' in content else content[:100]
        for a in re.split(r'\s+and\s+|,\s*(?=[A-Z])', author_part):
            a = a.strip()
            if a and len(a) > 2 and not a.startswith('\\'):
                authors.append(a)
        
        return BibEntry(key=key, title=title, authors=authors[:10], year=year, 
                        arxiv_id=arxiv_id, raw_content=content[:500])
    
    @staticmethod
    def extract_from_bib_file(bib_content: str) -> List[BibEntry]:
        entries = []
        pattern = r'@(\w+)\{([^,]+),\s*(.*?)\n\}'
        
        for match in re.finditer(pattern, bib_content, re.DOTALL):
            key = match.group(2).strip()
            fields_str = match.group(3)
            
            fields = {}
            for fm in re.finditer(r'(\w+)\s*=\s*[{"]([^}"]*)["}]', fields_str, re.DOTALL):
                fields[fm.group(1).lower()] = fm.group(2).strip()
            
            authors = re.split(r'\s+and\s+', fields.get('author', '')) if 'author' in fields else []
            arxiv_id = fields.get('eprint', '') or fields.get('arxiv', '')
            
            entries.append(BibEntry(
                key=key, title=fields.get('title', ''), authors=[a.strip() for a in authors],
                year=fields.get('year', ''), venue=fields.get('journal', fields.get('booktitle', '')),
                arxiv_id=arxiv_id.replace('.', '-'), raw_content=match.group(0)[:500]
            ))
        return entries

## 3. Data Loading Functions

In [4]:
def load_references_json(path: Path) -> Dict[str, RefEntry]:
    """Load references.json into RefEntry objects"""
    if not path.exists():
        return {}
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return {arxiv_id: RefEntry(arxiv_id=arxiv_id, title=info.get('paper_title', ''),
                               authors=info.get('authors', []), 
                               submission_date=info.get('submission_date', ''),
                               venue=info.get('publication_venue', ''))
            for arxiv_id, info in data.items()}


def extract_bibs_from_publication(pub_path: Path) -> List[BibEntry]:
    """Extract all BibTeX entries from a publication"""
    tex_path = pub_path / 'tex'
    if not tex_path.exists():
        return []
    
    all_entries = []
    for version_dir in tex_path.iterdir():
        if not version_dir.is_dir():
            continue
        
        for bib_file in version_dir.rglob('*.bib'):
            try:
                content = bib_file.read_text(encoding='utf-8', errors='ignore')
                all_entries.extend(BibExtractor.extract_from_bib_file(content))
            except Exception:
                pass
        
        for tex_file in version_dir.rglob('*.tex'):
            try:
                content = tex_file.read_text(encoding='utf-8', errors='ignore')
                if r'\bibitem' in content:
                    all_entries.extend(BibExtractor.extract_bibitems(content))
            except Exception:
                pass
    
    # Deduplicate by key
    seen = set()
    return [e for e in all_entries if not (e.key in seen or seen.add(e.key))]

## 4. Process Publications

In [5]:
# Get and sample publications
all_publications = sorted([p for p in DATA_DIR.iterdir() if p.is_dir()])
print(f"Found {len(all_publications)} total publications")

manual_pubs = [p for p in all_publications if p.name in MANUAL_PUBS]
non_manual_pubs = [p for p in all_publications if p.name not in MANUAL_PUBS]

random.seed(RANDOM_SEED)
sampled = random.sample(non_manual_pubs, min(SAMPLE_SIZE, len(non_manual_pubs)))
publications = sorted(manual_pubs + sampled, key=lambda p: p.name)

print(f"Selected: {len(manual_pubs)} manual + {len(sampled)} sampled = {len(publications)} total")

Found 5000 total publications
Selected: 5 manual + 1500 sampled = 1505 total


In [6]:
# Test on sample publication
sample_pub = publications[0]
refs = load_references_json(sample_pub / 'references.json')
bibs = extract_bibs_from_publication(sample_pub)
print(f"{sample_pub.name}: {len(refs)} refs, {len(bibs)} bibs")
if bibs:
    print(f"  First bib: {bibs[0].key} - {bibs[0].title[:60]}...")

2411-00222: 16 refs, 31 bibs
  First bib: ganjidoost2024protectingfeedforwardnetworksadversarial - Protecting Feed-Forward Networks from Adversarial Attacks Us...


## 5. Extract All Data

In [7]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import multiprocessing

def filter_refs_for_bibs(bibs: List[BibEntry], refs: Dict[str, RefEntry], max_title_diff: int = 40) -> Dict[str, RefEntry]:
    """Pre-filter refs based on title length compatibility"""
    bib_lengths = [len((b.title or '').replace('\n', ' ').strip()) for b in bibs]
    return {k: v for k, v in refs.items() 
            if any(abs(len((v.title or '').replace('\n', ' ').strip()) - bl) <= max_title_diff for bl in bib_lengths)}

def process_publication(pub_path: Path) -> dict:
    """Process a single publication"""
    result = {'status': 'skipped', 'reason': None, 'data': None, 'stats': {}}
    
    refs_path = pub_path / 'references.json'
    if not refs_path.exists():
        return {**result, 'reason': 'no_refs_file', 'pub_id': pub_path.name}
    
    refs = load_references_json(refs_path)
    if not refs:
        return {**result, 'reason': 'empty_refs', 'pub_id': pub_path.name}
    
    if not (pub_path / 'tex').exists():
        return {**result, 'reason': 'no_tex_folder', 'pub_id': pub_path.name}
    
    bibs = extract_bibs_from_publication(pub_path)
    if not bibs:
        return {**result, 'reason': 'no_bibs', 'pub_id': pub_path.name}
    
    refs_before = len(refs)
    filtered_refs = filter_refs_for_bibs(bibs, refs)
    
    return {
        'status': 'success',
        'stats': {'refs_before': refs_before, 'refs_after': len(filtered_refs)},
        'data': {
            'pub_id': pub_path.name, 'num_refs': len(filtered_refs), 'num_bibs': len(bibs),
            'refs': {k: v.to_dict() for k, v in filtered_refs.items()},
            'bibs': [b.to_dict() for b in bibs]
        }
    }

# Process with ThreadPoolExecutor
num_workers = min(multiprocessing.cpu_count() * 2, 16)
all_data, skipped = [], {'no_refs_file': [], 'empty_refs': [], 'no_bibs': [], 'no_tex_folder': []}
total_refs_before, total_refs_after = 0, 0

with ThreadPoolExecutor(max_workers=num_workers) as executor:
    futures = [executor.submit(process_publication, p) for p in publications]
    for future in tqdm(as_completed(futures), total=len(publications), desc="Processing"):
        result = future.result()
        if result['status'] == 'success':
            all_data.append(result['data'])
            total_refs_before += result['stats']['refs_before']
            total_refs_after += result['stats']['refs_after']
        elif result['reason']:
            skipped[result['reason']].append(result.get('pub_id'))

print(f"\nValid: {len(all_data)}, Skipped: {sum(len(v) for v in skipped.values())}")
print(f"Refs filtered: {total_refs_before} → {total_refs_after} ({(total_refs_before - total_refs_after) / max(total_refs_before, 1) * 100:.1f}% removed)")

Processing: 100%|██████████| 1505/1505 [00:51<00:00, 29.03it/s] 


Valid: 893, Skipped: 612
Refs filtered: 20767 → 20408 (1.7% removed)





In [8]:
# Save extracted data
with open(OUTPUT_DIR / 'extracted_data.json', 'w', encoding='utf-8') as f:
    json.dump(all_data, f, indent=2, ensure_ascii=False)
print(f"Saved {len(all_data)} publications to {OUTPUT_DIR / 'extracted_data.json'}")

Saved 893 publications to ..\..\output\extracted_data.json


## 6. Summary Statistics

In [9]:
total_refs = sum(d['num_refs'] for d in all_data)
total_bibs = sum(d['num_bibs'] for d in all_data)
print(f"Publications: {len(all_data)}, Refs: {total_refs}, Bibs: {total_bibs}")
print(f"Avg per pub: {total_refs/len(all_data):.1f} refs, {total_bibs/len(all_data):.1f} bibs")

Publications: 893, Refs: 20408, Bibs: 422393
Avg per pub: 22.9 refs, 473.0 bibs


## 7. Automatic Labeling
Auto-label non-manual publications using regex and string similarity heuristics.

In [10]:
class AutoLabeler:
    """Auto-labeling with regex and string similarity"""
    
    @staticmethod
    def jaccard(set1, set2):
        return len(set1 & set2) / len(set1 | set2) if set1 and set2 else 0.0
    
    @staticmethod
    def tokenize(text):
        return set(re.sub(r'[^\w\s]', ' ', (text or '').lower()).split()) if text else set()
    
    @staticmethod
    def get_last_name(author):
        parts = (author or '').lower().split()
        return parts[-1] if parts else ""
    
    @staticmethod
    def find_arxiv_match(content, refs):
        """Strategy 1: Exact arXiv ID in content"""
        matches = []
        for arxiv_id in refs.keys():
            if arxiv_id.replace('-', '.') in content or arxiv_id in content:
                matches.append((arxiv_id, 1.0, "arxiv_exact"))
        return matches
    
    @staticmethod
    def find_title_match(bib, refs, threshold=0.7):
        """Strategy 2: High title similarity"""
        bib_tokens = AutoLabeler.tokenize(bib.get('title', ''))
        if not bib_tokens:
            return []
        return [(arxiv_id, sim, "title_jaccard") 
                for arxiv_id, ref in refs.items()
                if (sim := AutoLabeler.jaccard(bib_tokens, AutoLabeler.tokenize(ref.get('paper_title', '')))) >= threshold]
    
    @staticmethod
    def find_author_year_match(bib, refs):
        """Strategy 3: First author + year + partial title"""
        bib_authors, bib_year = bib.get('authors', []), bib.get('year') or TextCleaner.extract_year(bib.get('raw_content', ''))
        if not bib_authors or not bib_year:
            return []
        
        bib_first = AutoLabeler.get_last_name(bib_authors[0])
        bib_title_tokens = AutoLabeler.tokenize(bib.get('title', ''))
        
        matches = []
        for arxiv_id, ref in refs.items():
            ref_authors = ref.get('authors', [])
            if not ref_authors or AutoLabeler.get_last_name(ref_authors[0]) != bib_first:
                continue
            try:
                if abs(int(bib_year) - int(ref.get('submission_date', '')[:4])) > 1:
                    continue
            except ValueError:
                continue
            overlap = AutoLabeler.jaccard(bib_title_tokens, AutoLabeler.tokenize(ref.get('paper_title', '')))
            if overlap >= 0.3:
                matches.append((arxiv_id, 0.6 + 0.4 * overlap, "author_year_title"))
        return matches
    
    @staticmethod
    def auto_label_publication(pub_data, refs_data):
        """Generate labels using all strategies"""
        labels = {}
        for bib in pub_data['bibs']:
            raw = bib.get('raw_content', '')
            matches = (AutoLabeler.find_arxiv_match(raw, refs_data) or 
                      AutoLabeler.find_title_match(bib, refs_data) or 
                      AutoLabeler.find_author_year_match(bib, refs_data))
            if matches:
                best = max(matches, key=lambda x: x[1])
                labels[bib['key']] = {'arxiv_id': best[0], 'confidence': best[1], 'method': best[2]}
        return labels

In [11]:
# Run auto-labeling on non-manual publications
non_manual_data = [d for d in all_data if d['pub_id'] not in MANUAL_PUBS]
auto_labeled = []

for pub_data in tqdm(non_manual_data, desc="Auto-labeling"):
    labels = AutoLabeler.auto_label_publication(pub_data, pub_data['refs'])
    if labels:
        auto_labeled.append({'pub_id': pub_data['pub_id'], 'labels': labels, 'num_labels': len(labels)})

print(f"\nAuto-labeled: {len(auto_labeled)}/{len(non_manual_data)} publications ({len(auto_labeled)/len(non_manual_data)*100:.1f}%)")

Auto-labeling: 100%|██████████| 888/888 [00:08<00:00, 107.28it/s]


Auto-labeled: 664/888 publications (74.8%)





In [12]:
# Assign partitions: test, valid, train
AUTO_PARTITIONS = {}
for i, result in enumerate(auto_labeled):
    pub_id = result['pub_id']
    AUTO_PARTITIONS[pub_id] = "test" if i == 0 else ("valid" if i == 1 else "train")

# Save pred.json files
for result in auto_labeled:
    pub_id = result['pub_id']
    partition = AUTO_PARTITIONS.get(pub_id, "train")
    groundtruth = {k: v['arxiv_id'] for k, v in result['labels'].items()}
    
    pred_data = {"partition": partition, "groundtruth": groundtruth, 
                 "prediction": {k: [] for k in groundtruth}, "label_source": "auto"}
    
    with open(DATA_DIR / pub_id / "pred.json", 'w', encoding='utf-8') as f:
        json.dump(pred_data, f, indent=2, ensure_ascii=False)

print(f"Saved {len(auto_labeled)} pred.json files (test: 1, valid: 1, train: {len(auto_labeled)-2})")

Saved 664 pred.json files (test: 1, valid: 1, train: 662)


In [13]:
# Summary
print("=" * 50)
print("LABELING SUMMARY")
print("=" * 50)
print(f"Manual: {len(MANUAL_PUBS)} publications (test: 2411-00222, valid: 2411-00223, train: 3)")
print(f"Auto: {len(auto_labeled)} publications, {sum(r['num_labels'] for r in auto_labeled)} labels")
print(f"Total extracted: {len(all_data)} publications")

LABELING SUMMARY
Manual: 5 publications (test: 2411-00222, valid: 2411-00223, train: 3)
Auto: 664 publications, 5891 labels
Total extracted: 893 publications


---
**Next:** `02_feature_engineering.ipynb`