# Notebook 3: Pairs Generation

In [10]:
import os
import json
import random
import re

# Root of the parsed papers; the 'auto' and 'manual' sub-folders each hold one
# directory per paper id (see load_data_from_folder).
BASE_RAW_PATH = os.path.join('..', 'data', 'processed', 'parsed')
AUTO_DATA_PATH = os.path.join(BASE_RAW_PATH, 'auto')
MANUAL_DATA_PATH = os.path.join(BASE_RAW_PATH, 'manual')

# Ground-truth label files, one JSON file per labelling mode.
GT_PATH = os.path.join('..', 'data', 'processed', 'ground_truth')
MANUAL_GT_PATH = os.path.join(GT_PATH, 'manual.json')
AUTO_GT_PATH = os.path.join(GT_PATH, 'auto.json')

# Folder the generated pair files are written to.
OUTPUT_PATH = os.path.join('..', 'data', 'processed', 'pairs')

# Tokens removed from titles by clean_text.
STOP_WORDS = set(
    "a an the and or of to in on at by for "
    "with from as is are was were be been this that".split()
)

# Upper bound on the number of paper folders read per data folder
# (the first SAMPLE_SIZE ids after sorting).
SAMPLE_SIZE = 2000

## 1. Helper functions

In [11]:
def clean_latex_string(text):
    """Strip common LaTeX markup from a BibTeX field value.

    Handles, in order:
      * whitespace/newline runs (collapsed to a single space),
      * ``\\bibinfo{type}{content}`` wrappers (replaced by ``content``),
      * ``\\command{content}`` wrappers, including nested ones
        (replaced by ``content``),
      * leftover grouping braces such as the case-protecting ``{B}ayesian``.

    Args:
        text: Raw field value; ``None`` or an empty string is accepted.

    Returns:
        The cleaned, single-line string ("" for empty input).
    """
    if not text:
        return ""

    # Collapse whitespace first so a command whose argument was wrapped over
    # several lines is still matched by the patterns below.
    text = re.sub(r'\s+', ' ', text)

    # remove \bibinfo{type}{content} -> content
    text = re.sub(r'\\bibinfo\{.*?\}\{(.*?)\}', r'\1', text)

    # Unwrap \command{content} from the innermost group outwards; a single
    # pass would leave nested commands such as \textbf{\emph{x}} half-cleaned.
    command_pattern = re.compile(r'\\[a-zA-Z]+\{([^{}]*)\}')
    previous = None
    while previous != text:
        previous = text
        text = command_pattern.sub(r'\1', text)

    # Remaining braces are pure grouping (e.g. "{B}ayesian"); keeping them
    # would later split the word when punctuation is turned into spaces.
    text = text.replace('{', '').replace('}', '')

    return re.sub(r'\s+', ' ', text).strip()

def get_field(field_name, entry):
    """Return the raw value of a BibTeX field inside one entry.

    Three value syntaxes are tried, in this order:
      1. brace-delimited  ``field = {value}`` (braces may nest to any depth),
      2. quote-delimited  ``field = "value"``,
      3. bare number      ``field = 2020``.

    The field name is matched case-insensitively and must start at a word
    boundary, so asking for ``title`` does not match ``booktitle``.

    Args:
        field_name: Name of the field to look up (treated literally).
        entry: Text of a single BibTeX entry.

    Returns:
        The value without its outer delimiters, or ``None`` if the field is
        not found.
    """
    name = re.escape(field_name)

    # 1) Brace-delimited value: walk forward counting braces so values with
    #    arbitrarily deep nesting are returned in full.
    for opener in re.finditer(fr'\b{name}\s*=\s*\{{', entry, re.IGNORECASE):
        start = opener.end()
        depth = 1
        for pos in range(start, len(entry)):
            char = entry[pos]
            if char == '{':
                depth += 1
            elif char == '}':
                depth -= 1
                if depth == 0:
                    return entry[start:pos]
        # Unbalanced braces after this occurrence: try the next one.

    # 2) Quote-delimited value.
    quoted = re.search(fr'\b{name}\s*=\s*\"(.*?)\"', entry, re.IGNORECASE | re.DOTALL)
    if quoted:
        return quoted.group(1)

    # 3) Bare numeric value, e.g. `year = 2020,`.
    bare = re.search(fr'\b{name}\s*=\s*(\d+)\b', entry, re.IGNORECASE)
    if bare:
        return bare.group(1)

    return None

def extract_from_bib(file_content):
    """Parse the text of a .bib file into a list of reference dicts.

    The text is split at every ``@`` that starts a line. Chunks that are
    empty, start with ``%``, lack a ``type{key,`` header, or are of type
    string/comment/preamble are skipped.

    Args:
        file_content: Full text of a BibTeX file.

    Returns:
        A list of dicts with keys ``id``, ``title``, ``authors``, ``year``
        and ``source_type`` (always ``"bib"``).
    """
    skipped_types = ['string', 'comment', 'preamble']
    parsed = []

    for chunk in re.split(r'^@', file_content, flags=re.MULTILINE):
        body = chunk.strip()
        if body == "" or body.startswith('%'):
            continue

        header = re.search(r'^(\w+)\s*\{\s*([^,]+),', body)
        if header is None:
            continue

        entry_type = header.group(1)
        cite_key = header.group(2).strip()
        if entry_type.lower() in skipped_types:
            continue

        raw_title = get_field('title', body)
        raw_year = get_field('year', body)
        raw_author = get_field('author', body)

        title = clean_latex_string(raw_title) if raw_title else ""
        year = raw_year if raw_year else ""

        # Individual authors are separated by the literal ' and '.
        author_names = []
        if raw_author:
            cleaned_authors = clean_latex_string(raw_author)
            author_names = [name.strip() for name in cleaned_authors.split(' and ')]

        parsed.append({
            "id": cite_key,
            "title": title,
            "authors": author_names,
            "year": year,
            "source_type": "bib"
        })

    return parsed

In [12]:
def get_paper_references(paper_id, data_path):
    """Collect the unique BibTeX references found under a paper's folder.

    Every ``.bib`` file below ``<data_path>/<paper_id>`` is parsed with
    ``extract_from_bib``. When the same reference id appears more than once,
    the first occurrence wins. Directories and files are visited in sorted
    order so that "first" does not depend on the filesystem's listing order.

    Args:
        paper_id: Name of the paper's folder.
        data_path: Folder containing the paper folders; a falsy value falls
            back to ``AUTO_DATA_PATH``.

    Returns:
        A list of reference dicts, one per distinct reference id.
    """
    base_dir = data_path if data_path else AUTO_DATA_PATH
    paper_path = os.path.join(base_dir, paper_id)

    unique_references = {}

    # search for all .bib files
    for root, dirs, files in os.walk(paper_path):
        # Sorting in place makes os.walk descend in a stable order.
        dirs.sort()
        for file in sorted(files):
            if not file.endswith(".bib"):
                continue
            try:
                with open(os.path.join(root, file), 'r', encoding='utf-8', errors='ignore') as f:
                    content = f.read()

                for reference in extract_from_bib(content):
                    # Keep the first definition of each reference id.
                    unique_references.setdefault(reference['id'], reference)
            except Exception as e:
                # Best effort: one unreadable file must not abort the paper.
                print(f"Error parsing {file}: {e}")

    return list(unique_references.values())

def load_references_json_file(paper_id, data_path):
    """Load the candidate references stored in a paper's references.json.

    The file is expected to hold a JSON object mapping an id to a metadata
    object with optional ``title``, ``authors`` and ``submission_date`` keys.
    Entries whose metadata is not an object are skipped with a warning, and
    missing or null fields fall back to empty values.

    Args:
        paper_id: Name of the paper's folder.
        data_path: Folder containing the paper folders; a falsy value falls
            back to ``AUTO_DATA_PATH``.

    Returns:
        A list of dicts with keys ``id``, ``title``, ``authors``, ``year``
        and ``source_type`` (always ``"references_json"``). An empty list is
        returned when the file is missing, undecodable or not a JSON object.
    """
    base_dir = data_path if data_path else AUTO_DATA_PATH
    json_path = os.path.join(base_dir, paper_id, 'references.json')

    if not os.path.exists(json_path):
        print(f"references.json not found for paper {paper_id}")
        return []

    try:
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError):
        print(f"Could not decode JSON for paper {paper_id}")
        return []

    if not isinstance(data, dict):
        print(f"Warning: references.json for {paper_id} is empty or invalid format.")
        return []

    target_references = []

    for arxiv_id, metadata in data.items():
        if not isinstance(metadata, dict):
            # e.g. a null entry; calling .get on it would raise.
            print(f"Warning: skipping malformed entry {arxiv_id} in references.json for {paper_id}")
            continue

        # `or` also replaces explicit nulls, not only missing keys.
        title = metadata.get('title') or ""
        authors = metadata.get('authors') or []
        date_str = metadata.get('submission_date') or ""

        # The year is the part of the date before the first '-'.
        year = str(date_str).split('-')[0] if date_str else ""

        target_references.append({
            "id": arxiv_id,
            "title": title,
            "authors": authors,
            "year": year,
            "source_type": "references_json"
        })

    return target_references

In [13]:
def clean_text(text):
    """Normalize free text for comparison.

    Lower-cases the text, replaces every character other than a-z, 0-9 and
    whitespace with a space, and drops tokens listed in ``STOP_WORDS``.

    Args:
        text: Input string; ``None`` or an empty string is accepted.

    Returns:
        The remaining tokens joined by single spaces ("" for empty input).
    """
    if not text:
        return ""

    # lowercase, then blank out punctuation and other symbols
    normalized = re.sub(r'[^a-z0-9\s]', ' ', text.lower())

    # keep only the words that are not stop words
    kept_words = filter(lambda word: word not in STOP_WORDS, normalized.split())

    return " ".join(kept_words)

def tokenize_author_list(authors_list):
    """Turn a list of author names into a set of lower-case name tokens.

    All names are joined, lower-cased, and every character other than a-z
    and whitespace is replaced by a space before splitting on whitespace.

    Args:
        authors_list: List of author name strings; may be empty or ``None``.

    Returns:
        A set of tokens (an empty set for empty input).
    """
    if not authors_list:
        return set()

    joined_names = " ".join(authors_list).lower()
    letters_only = re.sub(r'[^a-z\s]', ' ', joined_names)

    # a set: duplicates across authors collapse into one token
    return set(letters_only.split())

def clean_references(reference_list):
    """Add normalized title and author-token fields to each reference.

    Every dict in ``reference_list`` is modified in place: ``clean_title``
    is set from ``clean_text`` and ``clean_author_tokens`` from
    ``tokenize_author_list``.

    Args:
        reference_list: List of reference dicts.

    Returns:
        The same list object that was passed in.
    """
    for ref in reference_list:
        raw_title = ref.get('title', '')
        ref['clean_title'] = clean_text(raw_title)

        raw_authors = ref.get('authors', [])
        ref['clean_author_tokens'] = tokenize_author_list(raw_authors)

    return reference_list

## 2. Load Data
### 2.1 Load References

In [14]:
def load_data_from_folder(folder_path, label_type):
    """Load and clean the references of every paper folder in a directory.

    Sub-directory names are treated as paper ids; they are sorted and only
    the first ``SAMPLE_SIZE`` are processed.

    Args:
        folder_path: Directory containing one sub-directory per paper.
        label_type: Label used only in the progress messages.

    Returns:
        A tuple ``(refs_dict, targets_dict)``; both map paper id to a list of
        cleaned reference dicts (from the .bib files and from
        references.json respectively).
    """
    print(f"Loading {label_type} data from: {os.path.abspath(folder_path)}")

    paper_ids = sorted(
        name for name in os.listdir(folder_path)
        if os.path.isdir(os.path.join(folder_path, name))
    )[:SAMPLE_SIZE]

    refs_dict, targets_dict = {}, {}

    for paper_id in paper_ids:
        bib_refs = get_paper_references(paper_id, folder_path)
        json_refs = load_references_json_file(paper_id, folder_path)

        # clean_references mutates the lists in place and returns them
        refs_dict[paper_id] = clean_references(bib_refs)
        targets_dict[paper_id] = clean_references(json_refs)

    print(f"  -> Loaded content for {len(paper_ids)} {label_type} papers.")
    return refs_dict, targets_dict

In [15]:
manual_refs, manual_target_refs = load_data_from_folder(MANUAL_DATA_PATH, "MANUAL")
auto_refs, auto_target_refs = load_data_from_folder(AUTO_DATA_PATH, "AUTO")

# Merge both sources; for a paper id present in both, the manual data wins.
master_refs = dict(auto_refs)
master_refs.update(manual_refs)

master_targets = dict(auto_target_refs)
master_targets.update(manual_target_refs)

print("-" * 40)
print(f"Total Papers Loaded: {len(master_refs)}")

Loading MANUAL data from: /Users/thomas200905/Documents/Thomas/HCMUS/Third Year/Semester 7/Intro to Data Science/Milestones/MS02/data/processed/parsed/manual
  -> Loaded content for 5 MANUAL papers.
Loading AUTO data from: /Users/thomas200905/Documents/Thomas/HCMUS/Third Year/Semester 7/Intro to Data Science/Milestones/MS02/data/processed/parsed/auto
references.json not found for paper 2211.04143
references.json not found for paper 2211.04243
references.json not found for paper 2211.04370
  -> Loaded content for 2000 AUTO papers.
----------------------------------------
Total Papers Loaded: 2005


### 2.2 Load Ground Truth

In [16]:
def load_ground_truth(filepath):
    """Read a ground-truth JSON file.

    Args:
        filepath: Path of the JSON file.

    Returns:
        The parsed JSON content, or an empty dict (after printing a warning)
        when the file does not exist.
    """
    if not os.path.exists(filepath):
        print(f"[WARNING] Ground truth file not found: {filepath}")
        return {}

    with open(filepath, 'r', encoding='utf-8') as handle:
        return json.load(handle)

print("Loading Ground Truth Labels...")

# One label file per labelling mode; a missing file yields an empty dict.
gt_auto, gt_manual = (
    load_ground_truth(label_path)
    for label_path in (AUTO_GT_PATH, MANUAL_GT_PATH)
)

print(f"  -> Loaded {len(gt_auto)} auto-labeled papers.")
print(f"  -> Loaded {len(gt_manual)} manually-labeled papers.")

Loading Ground Truth Labels...
  -> Loaded 2000 auto-labeled papers.
  -> Loaded 5 manually-labeled papers.


## 3. Create Pairs

In [17]:
def generate_grouped_dataset(ground_truth, all_refs, all_targets, num_negatives=5, seed=None):
    """Build labelled (source, candidate) pairs grouped by paper.

    For every ground-truth match ``source_id -> true_target_id`` whose two
    ends can be resolved, one positive pair (label 1) is emitted, followed by
    up to ``num_negatives`` negative pairs (label 0) that combine the same
    source with randomly chosen other targets of the same paper.

    Args:
        ground_truth: Mapping paper id -> {source reference id: target id}.
        all_refs: Mapping paper id -> list of source reference dicts
            (each with an ``id`` key).
        all_targets: Mapping paper id -> list of target reference dicts
            (each with an ``id`` key).
        num_negatives: Maximum number of negative pairs per positive pair.
        seed: Optional seed. When given, sampling uses a private
            ``random.Random(seed)`` and is reproducible; when ``None`` the
            module-level ``random`` state is used.

    Returns:
        Mapping paper id -> list of ``{"source", "candidate", "label"}``
        dicts. Papers that yield no pair are left out.
    """
    rng = random.Random(seed) if seed is not None else random
    negatives_per_positive = max(0, num_negatives)

    grouped_data = {}

    for pid, matches in ground_truth.items():
        if pid not in all_refs or pid not in all_targets:
            continue

        source_lookup = {r['id']: r for r in all_refs[pid]}
        target_lookup = {t['id']: t for t in all_targets[pid]}
        all_target_ids = list(target_lookup)

        paper_pairs = []

        for source_id, true_target_id in matches.items():
            source_obj = source_lookup.get(source_id)
            target_obj = target_lookup.get(true_target_id)

            # Skip matches that point at a reference we did not load.
            if not source_obj or not target_obj:
                continue

            paper_pairs.append({
                "source": source_obj,
                "candidate": target_obj,
                "label": 1
            })

            # Negatives: any other target of the same paper.
            negative_candidates = [tid for tid in all_target_ids if tid != true_target_id]
            k = min(len(negative_candidates), negatives_per_positive)

            for neg_id in rng.sample(negative_candidates, k):
                paper_pairs.append({
                    "source": source_obj,
                    "candidate": target_lookup[neg_id],
                    "label": 0
                })

        if paper_pairs:
            grouped_data[pid] = paper_pairs
        elif len(ground_truth) < 100:
            # Only report dropped papers for small label sets.
            print(f"Dropping Paper {pid}, no valid pairs generated")

    return grouped_data

In [18]:
print("Generating Grouped Datasets...")

# Fixed seed so the randomly sampled negative pairs are the same on every run.
PAIRS_RANDOM_SEED = 42


def set_to_list(obj):
    """``json.dump`` fallback that serializes sets as lists.

    The reference dicts carry ``clean_author_tokens`` as a set, which JSON
    cannot encode. Elements are sorted when they are mutually comparable so
    the written file does not depend on hash ordering.

    Raises:
        TypeError: for any object that is not a set.
    """
    if isinstance(obj, set):
        try:
            return sorted(obj)
        except TypeError:
            # Elements of mixed, non-comparable types: keep iteration order.
            return list(obj)

    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_pairs(grouped_pairs, filename):
    """Write one grouped pair dataset as JSON into OUTPUT_PATH."""
    with open(os.path.join(OUTPUT_PATH, filename), 'w', encoding='utf-8') as f:
        json.dump(grouped_pairs, f, indent=2, ensure_ascii=False, default=set_to_list)


# The output folder may not exist yet on a fresh checkout.
os.makedirs(OUTPUT_PATH, exist_ok=True)

# Seed once before any pair generation (negatives are drawn with `random`).
random.seed(PAIRS_RANDOM_SEED)

manual_grouped = generate_grouped_dataset(gt_manual, master_refs, master_targets)
save_pairs(manual_grouped, 'manual_pairs.json')
print(f"Saved manual pairs for {len(manual_grouped)} papers.")

auto_grouped = generate_grouped_dataset(gt_auto, master_refs, master_targets)
save_pairs(auto_grouped, 'auto_pairs.json')
print(f"Saved auto pairs for {len(auto_grouped)} papers.")

Generating Grouped Datasets...
Saved manual pairs for 5 papers.
Saved auto pairs for 890 papers.
