In [None]:
import torch
import torch.nn as nn
import random
import numpy as np
from typing import List, Dict, Tuple, Iterator, Dict, Union, Optional
from transformers import AutoTokenizer, AutoModel,AutoModelForCausalLM

import xml.etree.ElementTree as ET
import bz2
import json
import sqlite3
from pathlib import Path
import re
import os 
import tqdm

from torch.utils.data import Dataset, DataLoader
from typing import List, Tuple, Dict

class WikiDumpParser:
    """Stream pages and articles out of a MediaWiki XML export dump.

    Accepts a plain ``.xml`` file or a ``.bz2``-compressed one and parses it
    incrementally with ``ElementTree.iterparse``; processed pages are released
    as soon as they have been yielded.
    """

    # Title prefixes of non-article namespaces that iter_articles skips.
    _NON_ARTICLE_PREFIXES = (
        'Wikipedia:', 'Template:', 'Category:', 'Portal:',
        'File:', 'MediaWiki:', 'Help:', 'Book:', 'Draft:',
        'TimedText:', 'Module:', 'Special:'
    )

    def __init__(self, dump_path: str):
        self.dump_path = dump_path
        # Namespace map for ElementTree lookups. This is only a default:
        # iter_pages replaces it with the namespace declared on the dump's
        # root element, since dumps use different export schema versions
        # (e.g. export-0.10 vs export-0.11).
        self.ns = {'mw': 'http://www.mediawiki.org/xml/export-0.10/'}

    def _open_dump(self):
        """Open the dump as a binary stream, decompressing ``.bz2`` on the fly."""
        if str(self.dump_path).endswith('.bz2'):
            return bz2.BZ2File(self.dump_path)
        return open(self.dump_path, 'rb')

    @staticmethod
    def _text_of(elem) -> str:
        """Return ``elem.text``, using '' for a missing element or empty text."""
        if elem is None:
            return ''
        return elem.text or ''

    def iter_pages(self) -> Iterator[Dict]:
        """
        Yield every ``<page>`` element of the dump as a dict.

        Yields:
            Dict with keys ``title``, ``text`` (revision wikitext),
            ``timestamp`` (revision timestamp string) and ``is_redirect``
            (text starts with ``#REDIRECT``, case-insensitive). Missing
            fields are ''.
        """
        # `with` guarantees the dump is closed, even if the consumer stops early.
        with self._open_dump() as dump:
            root = None
            for event, elem in ET.iterparse(dump, events=('start', 'end')):
                if root is None:
                    # The first event is the start of the root element; its
                    # tag carries the schema namespace used by the dump.
                    root = elem
                    if elem.tag.startswith('{'):
                        self.ns = {'mw': elem.tag[1:].split('}', 1)[0]}
                    continue
                if event != 'end' or elem.tag.rsplit('}', 1)[-1] != 'page':
                    continue

                revision = elem.find('.//mw:revision', self.ns)
                if revision is not None:
                    text = self._text_of(revision.find('mw:text', self.ns))
                    timestamp = self._text_of(revision.find('mw:timestamp', self.ns))
                else:
                    text = ''
                    timestamp = ''

                yield {
                    'title': self._text_of(elem.find('.//mw:title', self.ns)),
                    'text': text,
                    'timestamp': timestamp,
                    'is_redirect': bool(re.match(r'#REDIRECT', text, re.IGNORECASE)),
                }

                # Clearing only `elem` would still leave one (empty) child per
                # page attached to the root, so memory would grow with dump
                # size; clearing the root detaches the finished pages.
                elem.clear()
                root.clear()

    def iter_articles(self, skip_redirects: bool = True) -> Iterator[Dict]:
        """
        Yield pages that are not in one of the skipped namespaces.

        Args:
            skip_redirects: When True, redirect pages are also skipped.
        """
        for page in self.iter_pages():
            if page['title'].startswith(self._NON_ARTICLE_PREFIXES):
                continue
            if skip_redirects and page['is_redirect']:
                continue
            yield page

    def save_to_jsonl(self, output_path: Union[str, Path], sample_size: Optional[int] = None) -> int:
        """
        Save articles to a JSONL file (one JSON object per line).

        Args:
            output_path: Path to the output JSONL file
            sample_size: Optional maximum number of articles to save

        Returns:
            int: Number of articles saved
        """
        count = 0
        with open(output_path, 'w', encoding='utf-8') as f:
            for article in self.iter_articles():
                if sample_size is not None and count >= sample_size:
                    break
                json.dump(article, f, ensure_ascii=False)
                f.write('\n')
                count += 1
        return count

    def save_to_sqlite(self, db_path: Union[str, Path], sample_size: Optional[int] = None,
                      batch_size: int = 1000) -> int:
        """
        Save articles to a SQLite database table ``articles``.

        Rows are upserted (``INSERT OR REPLACE`` keyed on title) in batches,
        committing after each batch.

        Args:
            db_path: Path to the SQLite database file
            sample_size: Optional maximum number of articles to save
            batch_size: Number of articles to insert in each batch

        Returns:
            int: Number of rows written (replaced duplicates are counted too)
        """
        insert_sql = 'INSERT OR REPLACE INTO articles VALUES (?, ?, ?, ?)'
        conn = sqlite3.connect(db_path)
        try:
            c = conn.cursor()
            c.execute('''CREATE TABLE IF NOT EXISTS articles
                        (title TEXT PRIMARY KEY,
                         text TEXT,
                         timestamp TEXT,
                         is_redirect INTEGER)''')
            c.execute('CREATE INDEX IF NOT EXISTS idx_title ON articles(title)')

            count = 0
            batch = []
            for i, article in enumerate(self.iter_articles()):
                if sample_size is not None and i >= sample_size:
                    break
                batch.append((
                    article['title'],
                    article['text'],
                    article['timestamp'],
                    1 if article['is_redirect'] else 0
                ))
                if len(batch) >= batch_size:
                    c.executemany(insert_sql, batch)
                    conn.commit()
                    count += len(batch)
                    batch = []

            # Flush the final partial batch.
            if batch:
                c.executemany(insert_sql, batch)
                conn.commit()
                count += len(batch)
        finally:
            # Also covers failures while creating the table/index.
            conn.close()

        return count

def read_from_jsonl(file_path: Union[str, Path], limit: Optional[int] = None) -> Iterator[Dict]:
    """
    Lazily yield articles stored one JSON object per line.

    Args:
        file_path: Location of the JSONL file
        limit: If given, stop after this many lines have been read

    Yields:
        Dict: One decoded article per line
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        for line_no, raw_line in enumerate(handle):
            if limit is not None and line_no >= limit:
                return
            yield json.loads(raw_line)

def read_from_sqlite(db_path: Union[str, Path], limit: Optional[int] = None) -> Iterator[Dict]:
    """
    Read articles from the ``articles`` table of a SQLite database.

    The connection is opened on first iteration and closed when the generator
    is exhausted or closed.

    Args:
        db_path: Path to the SQLite database
        limit: Optional maximum number of articles to read

    Yields:
        Dict: Article information (``is_redirect`` converted to bool)

    Raises:
        ValueError: if ``limit`` cannot be converted to an integer.
    """
    query = 'SELECT title, text, timestamp, is_redirect FROM articles'
    params = ()
    if limit is not None:
        # Bind LIMIT as a parameter instead of formatting it into the SQL;
        # int() rejects anything that is not an integer count (and accepts
        # numpy integers, which sqlite3 cannot bind directly).
        query += ' LIMIT ?'
        params = (int(limit),)

    conn = sqlite3.connect(db_path)
    try:
        for title, text, timestamp, is_redirect in conn.execute(query, params):
            yield {
                'title': title,
                'text': text,
                'timestamp': timestamp,
                'is_redirect': bool(is_redirect)
            }
    finally:
        conn.close()


def extract_citations(text: str) -> List[str]:
    """Return the link target of every ``[[...]]`` wiki link in *text*.

    For piped links (``[[Target|label]]``) only the part before the first
    ``|`` is kept. Links are returned in order of appearance, duplicates
    included.
    """
    link_bodies = re.findall(r'\[\[(.*?)\]\]', text)
    return [body.partition('|')[0] for body in link_bodies]

def clean_wiki_content(text: str) -> str:
    """
    Trim raw wikitext down to the article body and strip some metadata.

    The body is assumed to begin at the first bold ``'''...'''`` span. From
    there on, ``[[Category:...]]`` and ``[[File:...]]`` links and literal
    ``{{stub}}`` templates are deleted, blank lines are dropped, and the
    result is stripped.

    NOTE(review): the link patterns stop at the first ``]]``, so a File link
    whose caption contains nested ``[[...]]`` links is only partially removed.

    Args:
        text: Raw wiki text.

    Returns:
        str: The cleaned body, or ``text`` unchanged when no bold span exists.
    """
    bold_title = re.search(r"'''([^']+?)'''", text)
    if bold_title is None:
        return text  # nothing that looks like a bold title; keep as-is

    body = text[bold_title.start():]
    # Applied in this order: categories, file links, stub templates.
    for metadata_pattern in (r'\[\[Category:.*?\]\]', r'\[\[File:.*?\]\]', r'\{\{stub\}\}'):
        body = re.sub(metadata_pattern, '', body)

    kept_lines = [line for line in body.split('\n') if line.strip()]
    return '\n'.join(kept_lines).strip()

def create_article_dict(jsonl_path, sample_size: Optional[int] = None) -> Dict[str, str]:
    """
    Build a ``{lower-cased title: cleaned text}`` dictionary from a JSONL dump.

    Args:
        jsonl_path: Path to a JSONL file with one ``{"title": ..., "text": ...}``
            object per line (as written by ``WikiDumpParser.save_to_jsonl``).
        sample_size: Optional number of lines to read from the start of the file.

    Returns:
        Dict[str, str]: Lower-cased article titles mapped to the output of
        ``clean_wiki_content``. Articles whose cleaned text is empty are omitted.
    """
    articles = {}
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for i, line in enumerate(f):
            # Check the limit before decoding so an unused line is never parsed.
            if sample_size is not None and i >= sample_size:
                break
            if not line.strip():
                continue  # tolerate blank lines (e.g. a trailing newline)

            article = json.loads(line)
            # A page with empty text may have been stored as null.
            cleaned_content = clean_wiki_content(article.get('text') or '')

            if cleaned_content:  # Only include if content remains after cleaning
                articles[article['title'].lower()] = cleaned_content

    return articles

def prepare_sources_targets(articles_dict, sample_size=1000, cite_sample=1, model=None):
    """
    Build parallel (source context, target page) lists from linked articles.

    Source articles are drawn at random without replacement. For each one, up
    to ``cite_sample`` distinct outgoing links whose target exists in
    ``articles_dict`` are sampled, also without replacement. Each source text
    then has that link replaced by the cite marker, and the linked page gets
    the ref marker appended.

    Args:
        articles_dict: Mapping of lower-cased title to article text (as built
            by ``create_article_dict``).
        sample_size: Maximum number of source articles to draw.
        cite_sample: Maximum number of distinct targets sampled per source.
        model: Optional ``AutoregressiveCitationMatcher`` whose
            ``prepare_source_context`` / ``prepare_target_page`` (and marker
            tokens) are used. When omitted, the default ``<CITE>`` / ``<REF>``
            markers are used without loading any transformer weights.

    Returns:
        Tuple ``(sources, targets)`` of equal-length string lists.
    """
    from functools import partial
    from types import SimpleNamespace

    if model is None:
        # The prepare_* helpers only read the two marker strings, so a
        # lightweight stand-in avoids loading a full encoder just to format text.
        markers = SimpleNamespace(cite_token="<CITE>", ref_token="<REF>")
        prepare_source = partial(AutoregressiveCitationMatcher.prepare_source_context, markers)
        prepare_target = partial(AutoregressiveCitationMatcher.prepare_target_page, markers)
    else:
        prepare_source = model.prepare_source_context
        prepare_target = model.prepare_target_page

    articles = np.random.permutation(list(articles_dict.keys()))[:sample_size]
    sources = []
    targets = []
    skipped = 0
    for article in tqdm.tqdm(articles):
        source_text = articles_dict[article]

        # Keep one link per distinct target page (lookup is case-insensitive),
        # so a source never contributes the same target twice. Duplicate
        # targets would otherwise act as false negatives in the in-batch
        # contrastive loss.
        candidates = {}
        for citation in extract_citations(source_text):
            key = citation.lower()
            if key in articles_dict and key not in candidates:
                candidates[key] = citation
        if not candidates:
            continue

        valid_citations = list(candidates.values())
        n_pick = min(cite_sample, len(valid_citations))
        for citation in np.random.choice(valid_citations, n_pick, replace=False):
            try:
                source_content = prepare_source(source_text, citation)
            except ValueError:
                # The link could not be located/marked in the source text.
                skipped += 1
                continue
            sources.append(source_content)
            targets.append(prepare_target(articles_dict[citation.lower()]))

    if skipped:
        print(f"Skipped {skipped} citations that could not be marked in their source text")
    return sources, targets


class AutoregressiveCitationMatcher(nn.Module):
    """
    Dual-encoder scorer matching citation contexts to the pages they cite.

    A single pretrained ``AutoModel`` encodes both sides. A source context is
    represented by the final hidden state at its ``cite_token``. A target page
    is represented by the final hidden state at its ``ref_token``. Pairs are
    scored with temperature-scaled cosine similarity.
    """

    def __init__(
        self,
        model_name: str = "bert-base-uncased",
        max_length: int = 512,
        cite_token: str = "<CITE>",
        ref_token: str = "<REF>",
        device: Optional[torch.device] = None
    ):
        """
        Args:
            model_name: HuggingFace model id/path for both tokenizer and encoder.
            max_length: Truncation length used by ``encode_text``.
            cite_token: Marker that replaces the citation in source contexts.
            ref_token: Marker appended to target pages.
            device: Target device; defaults to CUDA when available, else CPU.
        """
        super().__init__()

        self.device = device if device is not None else torch.device(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # Register the markers as special tokens so they stay single tokens.
        self.tokenizer.add_special_tokens(
            {'additional_special_tokens': [cite_token, ref_token]}
        )
        # Some tokenizers (e.g. GPT-style) have no pad token; padding needs one.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # One encoder is shared by source contexts and target pages.
        self.model = AutoModel.from_pretrained(model_name, gradient_checkpointing=True).to(self.device)
        # Allocate embedding rows for the newly added special tokens.
        self.model.resize_token_embeddings(len(self.tokenizer))

        self.cite_token = cite_token
        self.ref_token = ref_token
        self.max_length = max_length

        print(f"Using device: {self.device}")

    def prepare_source_context(
        self,
        text: str,
        target_citation: str,
        window_size: int = 100
    ) -> str:
        """
        Replace every wiki link to ``target_citation`` with the cite marker.

        Both plain ``[[Target]]`` and piped ``[[Target|label]]`` links are
        replaced.

        Args:
            text: Raw source wikitext.
            target_citation: Link target exactly as it appears in the text.
            window_size: Currently unused; kept for API compatibility.

        Returns:
            The text with each matching link replaced by ``self.cite_token``.

        Raises:
            ValueError: if no link to ``target_citation`` was found.
        """
        pattern = r"\[\[" + re.escape(target_citation) + r"(?:\|[^\]]*)?\]\]"
        # A function replacement keeps the marker literal (no backslash escapes).
        modified_text, n_replaced = re.subn(pattern, lambda _m: self.cite_token, text)
        # Count actual replacements so a marker already present in the input
        # is not mistaken for a successful substitution.
        if n_replaced == 0:
            raise ValueError(f"There is no cite token")
        return modified_text

    def prepare_target_page(
        self,
        page_content: str,
    ) -> str:
        """Return the page text with the ref marker appended after a space."""
        return f"{page_content} {self.ref_token}"

    def get_token_embedding(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
        token_id: int
    ) -> torch.Tensor:
        """
        Gather, per batch row, the hidden state at the last occurrence of ``token_id``.

        Args:
            hidden_states: Tensor indexed as ``[batch, position, hidden]``.
            input_ids: Tensor indexed as ``[batch, position]``.
            token_id: Vocabulary id to locate.

        Returns:
            Tensor of shape ``[batch, hidden]``.

        Raises:
            ValueError: if some row does not contain ``token_id``.
        """
        embeddings = []
        for batch_idx, (row_ids, row_states) in enumerate(zip(input_ids, hidden_states)):
            token_positions = (row_ids == token_id).nonzero()
            if len(token_positions) == 0:
                raise ValueError(f"Token id {token_id} not found in sequence for batch item {batch_idx}")
            # With multiple occurrences, the last one is used.
            embeddings.append(row_states[token_positions[-1].item(), :])
        return torch.stack(embeddings)

    def encode_text(
        self,
        text: str,
        is_source: bool = True
    ) -> torch.Tensor:
        """
        Tokenize and encode text, returning its marker-token embedding(s).

        Sources are represented by the cite marker and targets by the ref
        marker, matching the tokens ``train_model`` / ``validate_model``
        extract. The ref marker is appended at the end of a target page, so
        truncation can drop it. In that case, the last non-padding position is
        overwritten with it, mirroring ``CitationDataset.ensure_ref_token``.

        Raises:
            ValueError: if a source's cite marker was lost to truncation.
        """
        inputs = self.tokenizer(
            text,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        ).to(self.device)

        marker_id = self.tokenizer.convert_tokens_to_ids(
            self.cite_token if is_source else self.ref_token
        )

        if not is_source:
            input_ids = inputs['input_ids']
            missing = ~(input_ids == marker_id).any(dim=1)
            if missing.any():
                mask = inputs.get('attention_mask', torch.ones_like(input_ids))
                positions = torch.arange(input_ids.size(1), device=input_ids.device)
                # Highest attended index per row (works for left or right padding).
                last_pos = (mask * positions).argmax(dim=1)
                rows = missing.nonzero(as_tuple=True)[0]
                input_ids[rows, last_pos[rows]] = marker_id

        outputs = self.model(**inputs, return_dict=True)
        return self.get_token_embedding(
            outputs.last_hidden_state,
            inputs['input_ids'],
            marker_id
        )

    def forward(
        self,
        source_contexts: List[str],
        target_pages: List[str],
        temperature: float = 0.07
    ) -> torch.Tensor:
        """
        Score every source against every target.

        Returns:
            ``[len(source_contexts), len(target_pages)]`` matrix of cosine
            similarities divided by ``temperature``.
        """
        source_embeddings = torch.cat(
            [self.encode_text(context, is_source=True) for context in source_contexts], dim=0
        )
        target_embeddings = torch.cat(
            [self.encode_text(page, is_source=False) for page in target_pages], dim=0
        )

        source_embeddings = nn.functional.normalize(source_embeddings, dim=-1)
        target_embeddings = nn.functional.normalize(target_embeddings, dim=-1)

        return torch.matmul(source_embeddings, target_embeddings.transpose(0, 1)) / temperature

    def train_step(
        self,
        source_contexts: List[str],
        target_pages: List[str],
        optimizer: torch.optim.Optimizer,
        temperature: float = 0.07
    ) -> float:
        """
        Run one optimizer step of in-batch contrastive training.

        Source i is paired with target i; all other targets in the batch act
        as negatives via cross-entropy over the similarity matrix.

        Returns:
            The scalar loss value.
        """
        optimizer.zero_grad(set_to_none=True)

        similarity = self(source_contexts, target_pages, temperature)
        # The diagonal elements are the positive pairs.
        labels = torch.arange(len(source_contexts), device=self.device)
        loss = nn.CrossEntropyLoss()(similarity, labels)

        loss.backward()
        optimizer.step()

        return loss.item()



class CitationDataset(Dataset):
    """
    Pre-tokenized (source context, target page) pairs for contrastive training.

    Every pair is tokenized once up front, padded/truncated to ``max_length``.
    A pair is dropped if its source no longer contains the cite token after
    truncation. Each target is forced to contain exactly one ref token (see
    ``ensure_ref_token``).
    """

    def __init__(
        self,
        sources: List[str],
        targets: List[str],
        tokenizer,
        max_length: int = 512,
        verbose: bool = True,
        ref_token: str = "<REF>",
        cite_token: str = "<CITE>"
    ):
        """
        Initialize dataset with preprocessing of all samples.

        Args:
            sources: List of source texts (expected to contain ``cite_token``)
            targets: List of target texts, parallel to ``sources``
            tokenizer: Tokenizer to use; must already know both marker tokens
            max_length: Maximum sequence length
            verbose: Whether to print processing statistics
            ref_token: Marker string for target pages
            cite_token: Marker string for source contexts

        Raises:
            ValueError: if ``sources`` and ``targets`` differ in length.
        """
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(sources) != len(targets):
            raise ValueError("Sources and targets must have same length")
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.ref_token = ref_token
        self.cite_token = cite_token

        self.ref_token_id = self.tokenizer.convert_tokens_to_ids(self.ref_token)
        self.cite_token_id = self.tokenizer.convert_tokens_to_ids(self.cite_token)

        if verbose:
            print("Pre-processing dataset...")
            print(f"Initial samples: {len(sources)}")

        self.processed_samples = self._preprocess_samples(sources, targets, verbose)

        if verbose:
            print(f"Final processed samples: {len(self.processed_samples)}")
            print("Dataset preprocessing completed.")

    def _tokenize(self, text: str):
        """Tokenize one text into a single-row, max_length-padded encoding."""
        return self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )

    def _preprocess_samples(
        self,
        sources: List[str],
        targets: List[str],
        verbose: bool
    ) -> List[Dict[str, torch.Tensor]]:
        """
        Tokenize all pairs and keep those whose source still has a cite token.

        Returns:
            List of dicts with 1-D tensors under ``source_input_ids``,
            ``source_attention_mask``, ``target_input_ids`` and
            ``target_attention_mask``.
        """
        processed_samples = []
        skipped_no_cite = 0
        skipped_errors = 0

        for idx, (source, target) in enumerate(zip(sources, targets)):
            try:
                source_tokens = self._tokenize(source)
                target_tokens = self._tokenize(target)
                target_tokens['input_ids'] = self.ensure_ref_token(target_tokens['input_ids'])
                # Truncation may have cut the cite token off the source.
                has_cite = self.cite_token_id in source_tokens['input_ids'][0]
            except Exception as e:
                # Best effort: report and skip a bad sample rather than abort.
                skipped_errors += 1
                if verbose:
                    print(f"Error processing sample {idx}: {str(e)}")
            else:
                if has_cite:
                    processed_samples.append({
                        'source_input_ids': source_tokens['input_ids'].squeeze(0),
                        'source_attention_mask': source_tokens['attention_mask'].squeeze(0),
                        'target_input_ids': target_tokens['input_ids'].squeeze(0),
                        'target_attention_mask': target_tokens['attention_mask'].squeeze(0)
                    })
                else:
                    skipped_no_cite += 1

            # Reported for every index, including skipped samples.
            if verbose and (idx + 1) % 10000 == 0:
                print(f"Processed {idx + 1} samples...")

        if verbose:
            print(f"Skipped {skipped_no_cite} samples due to missing CITE token")
            print(f"Skipped {skipped_errors} samples due to processing errors")

        return processed_samples

    def ensure_ref_token(self, tokens: torch.Tensor) -> torch.Tensor:
        """
        Make every row of ``tokens`` contain exactly one ref token, in place.

        A row without one gets its last non-pad token overwritten with it.
        A row with several keeps only the last; earlier ones become the unk id.

        Args:
            tokens: 2-D tensor of token ids (rows are sequences).

        Returns:
            The same tensor, modified in place.
        """
        pad_token_id = self.tokenizer.pad_token_id
        for sequence in tokens:  # rows are views, so writes hit `tokens`
            ref_positions = (sequence == self.ref_token_id).nonzero()
            if len(ref_positions) == 0:
                non_pad_positions = (sequence != pad_token_id).nonzero()
                if len(non_pad_positions) > 0:
                    sequence[non_pad_positions[-1]] = self.ref_token_id
            elif len(ref_positions) > 1:
                unk_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.unk_token)
                for pos in ref_positions[:-1]:
                    sequence[pos] = unk_id
        return tokens

    def __len__(self):
        return len(self.processed_samples)

    def __getitem__(self, idx) -> Dict[str, torch.Tensor]:
        """Return the pre-processed sample at ``idx``."""
        return self.processed_samples[idx]

def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Stack each field of a list of per-sample dicts into one batched tensor.

    The keys of the first sample determine which fields (and in what order)
    appear in the result.
    """
    field_names = batch[0].keys()
    collated = {}
    for name in field_names:
        collated[name] = torch.stack([sample[name] for sample in batch])
    return collated

def train_model(
    model: AutoregressiveCitationMatcher,
    train_sources: List[str],
    train_targets: List[str],
    val_sources: List[str] = None,
    val_targets: List[str] = None,
    batch_size: int = 32,
    num_epochs: int = 10,
    learning_rate: float = 1e-4,
    temperature: float = 0.07,
    num_workers: int = 4,
    save_path: str = 'checkpoint.pt'
):
    """
    Train ``model`` with an in-batch contrastive cross-entropy loss.

    Each batch encodes source contexts and target pages with ``model.model``.
    It takes the cite-token / ref-token embeddings and applies cross-entropy
    over the temperature-scaled cosine-similarity matrix, with the diagonal
    as the positive pairs.

    With validation data, the model is validated after every epoch. The
    learning rate is halved after 2 epochs without improvement, and the
    weights with the lowest validation loss are saved to ``save_path``.
    Without validation data, the final weights are saved to ``save_path``.

    Raises:
        ValueError: if no training sample survives preprocessing.
    """
    def _make_loader(dataset, loader_batch_size, shuffle):
        # persistent_workers / prefetch_factor are only valid with worker
        # processes; DataLoader rejects them when num_workers == 0.
        worker_kwargs = (
            {'persistent_workers': True, 'prefetch_factor': 2} if num_workers > 0 else {}
        )
        return DataLoader(
            dataset,
            batch_size=loader_batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            collate_fn=collate_fn,
            pin_memory=model.device.type == 'cuda',
            **worker_kwargs,
        )

    train_dataset = CitationDataset(
        train_sources,
        train_targets,
        model.tokenizer,
        model.max_length
    )
    if len(train_dataset) == 0:
        raise ValueError("No valid training samples remain after preprocessing")
    train_loader = _make_loader(train_dataset, batch_size, shuffle=True)

    val_loader = None
    if val_sources and val_targets:
        val_dataset = CitationDataset(
            val_sources,
            val_targets,
            model.tokenizer,
            model.max_length
        )
        if len(val_dataset) > 0:
            val_loader = _make_loader(val_dataset, 2 * batch_size, shuffle=False)
        else:
            print("Validation set is empty after preprocessing; skipping validation.")

    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()
    # `verbose=` is deprecated in recent PyTorch; reductions are printed below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=0.5,
        patience=2
    )

    # Loop-invariant marker ids.
    cite_token_id = model.tokenizer.convert_tokens_to_ids(model.cite_token)
    ref_token_id = model.tokenizer.convert_tokens_to_ids(model.ref_token)

    best_val_loss = float('inf')

    for epoch in range(num_epochs):
        model.train()
        total_loss = 0.0
        num_batches = 0

        progress_bar = tqdm.tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}')

        for batch in progress_bar:
            batch = {k: v.to(model.device, non_blocking=True) for k, v in batch.items()}

            source_outputs = model.model(
                input_ids=batch['source_input_ids'],
                attention_mask=batch['source_attention_mask'],
                output_hidden_states=False,
                return_dict=True
            )
            target_outputs = model.model(
                input_ids=batch['target_input_ids'],
                attention_mask=batch['target_attention_mask'],
                output_hidden_states=False,
                return_dict=True
            )

            source_embeddings = model.get_token_embedding(
                source_outputs.last_hidden_state,
                batch['source_input_ids'],
                cite_token_id
            )
            target_embeddings = model.get_token_embedding(
                target_outputs.last_hidden_state,
                batch['target_input_ids'],
                ref_token_id
            )

            source_embeddings = torch.nn.functional.normalize(source_embeddings, dim=-1)
            target_embeddings = torch.nn.functional.normalize(target_embeddings, dim=-1)

            similarity = torch.matmul(source_embeddings, target_embeddings.transpose(0, 1)) / temperature

            # Row i's positive is column i; the rest of the batch are negatives.
            labels = torch.arange(similarity.size(0), device=model.device)
            loss = criterion(similarity, labels)

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1

            progress_bar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'avg_loss': f'{total_loss / num_batches:.4f}'
            })

            # Free memory
            del source_outputs, target_outputs, source_embeddings, target_embeddings, similarity, loss
            torch.cuda.empty_cache()

        if val_loader is not None:
            print("\nRunning validation...")
            val_loss, _metrics = validate_model(model, val_loader, temperature)

            lrs_before = [group['lr'] for group in optimizer.param_groups]
            scheduler.step(val_loss)
            lrs_after = [group['lr'] for group in optimizer.param_groups]
            if lrs_after != lrs_before:
                print(f"Reduced learning rate to {lrs_after[0]:.2e}")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                torch.save(model.state_dict(), save_path)
                print(f"Saved new best model with validation loss: {val_loss:.4f}")

            model.train()  # validate_model switches the model to eval mode

        epoch_loss = total_loss / num_batches
        print(f"\nEpoch {epoch + 1} average training loss: {epoch_loss:.4f}")

    if val_loader is None:
        # Without validation there is no "best" checkpoint; keep the final one.
        torch.save(model.state_dict(), save_path)
        print(f"Saved final model to {save_path}")


def validate_model(
    model: AutoregressiveCitationMatcher,
    val_loader: DataLoader,
    temperature: float
) -> Tuple[float, Dict[str, float]]:
    """
    Compute validation loss and global retrieval metrics in a single pass.

    The loss is the mean in-batch contrastive cross-entropy. The ranking
    metrics score every validation source against every validation target.
    A sample's rank is 1 plus the number of targets scored strictly higher
    than its own target, so exact ties (e.g. identical duplicate target
    pages) do not count against it.

    Args:
        model: The citation matcher model
        val_loader: Validation data loader
        temperature: Temperature for similarity scaling

    Returns:
        Tuple of (validation_loss, metrics_dict). If no batch could be
        evaluated, the loss is ``inf`` and only ``val_size``/``val_loss``
        are reported.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    all_source_embeddings = []
    all_target_embeddings = []

    cite_token_id = model.tokenizer.convert_tokens_to_ids(model.cite_token)
    ref_token_id = model.tokenizer.convert_tokens_to_ids(model.ref_token)
    criterion = torch.nn.CrossEntropyLoss()

    print("Computing validation loss and collecting embeddings...")
    with torch.no_grad():
        for batch in tqdm.tqdm(val_loader, desc='Processing validation data'):
            batch = {k: v.to(model.device, non_blocking=True) for k, v in batch.items()}

            try:
                source_outputs = model.model(
                    input_ids=batch['source_input_ids'],
                    attention_mask=batch['source_attention_mask'],
                    output_hidden_states=False,
                    return_dict=True
                )
                target_outputs = model.model(
                    input_ids=batch['target_input_ids'],
                    attention_mask=batch['target_attention_mask'],
                    output_hidden_states=False,
                    return_dict=True
                )

                source_embeddings = model.get_token_embedding(
                    source_outputs.last_hidden_state,
                    batch['source_input_ids'],
                    cite_token_id
                )
                target_embeddings = model.get_token_embedding(
                    target_outputs.last_hidden_state,
                    batch['target_input_ids'],
                    ref_token_id
                )

                source_embeddings = torch.nn.functional.normalize(source_embeddings, dim=-1)
                target_embeddings = torch.nn.functional.normalize(target_embeddings, dim=-1)

                similarity = torch.matmul(source_embeddings, target_embeddings.transpose(0, 1)) / temperature
                labels = torch.arange(similarity.size(0), device=model.device)
                loss = criterion(similarity, labels)

                total_loss += loss.item()
                num_batches += 1

                # Keep normalized embeddings on CPU for the global ranking.
                all_source_embeddings.append(source_embeddings.cpu())
                all_target_embeddings.append(target_embeddings.cpu())

                del source_outputs, target_outputs, similarity
                torch.cuda.empty_cache()

            except ValueError as e:
                # get_token_embedding raises when a marker token is missing.
                print(f"Skipping batch due to error: {e}")
                continue

    val_loss = total_loss / num_batches if num_batches > 0 else float('inf')

    if num_batches == 0:
        print("No validation batches could be evaluated; skipping ranking metrics.")
        return val_loss, {'val_size': 0, 'val_loss': val_loss}

    all_source_embeddings = torch.cat(all_source_embeddings, dim=0)
    # Move the target matrix to the device once instead of once per chunk.
    all_target_embeddings = torch.cat(all_target_embeddings, dim=0).to(model.device)
    total_samples = all_source_embeddings.size(0)

    print(f"\nComputing global rankings for {total_samples} samples...")

    k_values = [1, 3, 5, 10, 50]
    chunk_size = 512  # source rows scored at once, to bound memory
    rank_chunks = []

    with torch.no_grad():
        for start in tqdm.tqdm(range(0, total_samples, chunk_size), desc='Computing rankings'):
            end = min(start + chunk_size, total_samples)
            source_chunk = all_source_embeddings[start:end].to(model.device)

            chunk_similarity = torch.matmul(source_chunk, all_target_embeddings.t()) / temperature

            # Score of each source's own target (global index start + row).
            rows = torch.arange(end - start, device=chunk_similarity.device)
            positive_scores = chunk_similarity[rows, rows + start]
            # Counting strictly-higher scores gives the rank without a full sort.
            chunk_ranks = (chunk_similarity > positive_scores.unsqueeze(1)).sum(dim=1) + 1
            rank_chunks.append(chunk_ranks.cpu())

            del chunk_similarity, source_chunk
            torch.cuda.empty_cache()

    print("Computing final metrics...")
    all_ranks = torch.cat(rank_chunks).tolist()

    metrics = {
        f'top_{k}_accuracy': sum(1 for rank in all_ranks if rank <= k) / total_samples
        for k in k_values
    }
    metrics.update({
        'mrr': sum(1.0 / rank for rank in all_ranks) / total_samples,
        'median_rank': float(np.median(all_ranks)),
        'mean_rank': float(np.mean(all_ranks)),
        'val_size': len(all_ranks),
        'val_loss': val_loss
    })

    print("\nValidation Metrics:")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")

    return val_loss, metrics
    

# Usage example:
# Split data into train and validation sets
def split_data(
    sources: List[str],
    targets: List[str],
    val_ratio: float = 0.2,
    seed: Optional[int] = None,
) -> Tuple[List[str], List[str], List[str], List[str]]:
    """Randomly split paired sources/targets into training and validation sets.

    One random permutation is applied to both lists, so the pairing
    ``sources[i] <-> targets[i]`` is preserved in every output split.

    Args:
        sources: Source texts.
        targets: Target texts; must have the same length as ``sources``.
        val_ratio: Fraction of pairs assigned to the validation split, in
            ``[0, 1]``. The validation size is ``int(len(sources) * val_ratio)``
            (rounded down); the remaining pairs go to training.
        seed: Optional seed for a private ``torch.Generator`` so the split is
            reproducible without touching torch's global RNG. ``None``
            (default) draws from the global RNG, as before.

    Returns:
        ``(train_sources, train_targets, val_sources, val_targets)``.

    Raises:
        ValueError: If ``sources`` and ``targets`` differ in length, or if
            ``val_ratio`` is outside ``[0, 1]``.
    """
    if len(sources) != len(targets):
        raise ValueError(
            "sources and targets must have the same length "
            f"(got {len(sources)} and {len(targets)})"
        )
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be in [0, 1], got {val_ratio}")

    val_size = int(len(sources) * val_ratio)

    generator = None
    if seed is not None:
        generator = torch.Generator().manual_seed(seed)
    # .tolist() yields plain Python ints; indexing lists with 0-d tensors is slow.
    indices = torch.randperm(len(sources), generator=generator).tolist()

    # The first `val_size` permuted indices form the validation split.
    train_indices = indices[val_size:]
    val_indices = indices[:val_size]

    train_sources = [sources[i] for i in train_indices]
    train_targets = [targets[i] for i in train_indices]
    val_sources = [sources[i] for i in val_indices]
    val_targets = [targets[i] for i in val_indices]

    return train_sources, train_targets, val_sources, val_targets

# Set the Hugging Face tokenizers setting before any code that may use a
# tokenizer (data preparation, model construction) and before DataLoader
# workers fork, so it is in effect for the whole run rather than only at fork.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# One-time data preparation (uncomment to rebuild the article dump / dict):
# wiki_parser = WikiDumpParser('./data/wiki/simplewiki-latest-pages-articles.xml')
# wiki_parser.save_to_jsonl('./wiki_articles.jsonl')
# articles_dict = create_article_dict(jsonl_path='./wiki_articles.jsonl')

# NOTE: `articles_dict` must already be defined here (e.g. by an earlier cell
# or by uncommenting the line above).
# NOTE(review): train and validation pairs are sampled independently from the
# same `articles_dict`; unless prepare_sources_targets excludes already-used
# articles, the two splits may overlap — confirm to avoid optimistic metrics.
train_sources, train_targets = prepare_sources_targets(articles_dict, sample_size=2000, cite_sample=1)
val_sources, val_targets = prepare_sources_targets(articles_dict, sample_size=300, cite_sample=10)

# Alternative: split a single pool of (source, target) pairs instead:
# train_sources, train_targets, val_sources, val_targets = split_data(sources, targets)

# Initialize model (AutoregressiveCitationMatcher is defined earlier in the notebook)
model = AutoregressiveCitationMatcher(model_name="bert-base-uncased")

# Train the model
train_model(
    model=model,
    train_sources=train_sources,
    train_targets=train_targets,
    val_sources=val_sources,
    val_targets=val_targets,
    batch_size=32,
    num_epochs=10,
    learning_rate=1.5e-4,
    temperature=0.1,
    num_workers=4,
    save_path='best_model.pt'
)


Using device: cuda


100%|████████████████████████████████████| 2000/2000 [00:00<00:00, 10753.94it/s]


Using device: cuda


100%|███████████████████████████████████████| 300/300 [00:00<00:00, 3469.87it/s]


Using device: cuda
Pre-processing dataset...
Initial samples: 1488
Skipped 163 samples due to missing CITE token
Skipped 0 samples due to processing errors
Final processed samples: 1325
Dataset preprocessing completed.
Pre-processing dataset...
Initial samples: 1404
Skipped 249 samples due to missing CITE token
Skipped 0 samples due to processing errors
Final processed samples: 1155
Dataset preprocessing completed.


Epoch 1/10: 100%|█| 42/42 [01:13<00:00,  1.74s/it, loss=0.3263, avg_loss=1.1375]



Running validation...
Computing validation loss and collecting embeddings...


Processing validation data: 100%|███████████████| 19/19 [00:16<00:00,  1.16it/s]



Computing global rankings for 1155 samples...


Computing rankings: 100%|████████████████████████| 3/3 [00:00<00:00, 204.11it/s]


Computing final metrics...

Validation Metrics:
top_1_accuracy: 0.2026
top_3_accuracy: 0.4372
top_5_accuracy: 0.5576
top_10_accuracy: 0.7186
top_50_accuracy: 0.9186
mrr: 0.3652
median_rank: 4.0000
mean_rank: 23.5437
val_size: 1155.0000
val_loss: 1.7173
Saved new best model with validation loss: 1.7173

Epoch 1 average training loss: 1.1375


Epoch 2/10: 100%|█| 42/42 [01:13<00:00,  1.76s/it, loss=0.0447, avg_loss=0.4000]



Running validation...
Computing validation loss and collecting embeddings...


Processing validation data: 100%|███████████████| 19/19 [00:16<00:00,  1.19it/s]



Computing global rankings for 1155 samples...


Computing rankings: 100%|████████████████████████| 3/3 [00:00<00:00, 470.21it/s]


Computing final metrics...

Validation Metrics:
top_1_accuracy: 0.2139
top_3_accuracy: 0.4502
top_5_accuracy: 0.5654
top_10_accuracy: 0.7082
top_50_accuracy: 0.9169
mrr: 0.3744
median_rank: 4.0000
mean_rank: 22.1931
val_size: 1155.0000
val_loss: 1.6663
Saved new best model with validation loss: 1.6663

Epoch 2 average training loss: 0.4000


Epoch 3/10: 100%|█| 42/42 [01:14<00:00,  1.77s/it, loss=0.0522, avg_loss=0.2361]



Running validation...
Computing validation loss and collecting embeddings...


Processing validation data: 100%|███████████████| 19/19 [00:16<00:00,  1.18it/s]



Computing global rankings for 1155 samples...


Computing rankings: 100%|████████████████████████| 3/3 [00:00<00:00, 464.43it/s]


Computing final metrics...

Validation Metrics:
top_1_accuracy: 0.2121
top_3_accuracy: 0.4468
top_5_accuracy: 0.5671
top_10_accuracy: 0.7238
top_50_accuracy: 0.9048
mrr: 0.3727
median_rank: 4.0000
mean_rank: 23.0459
val_size: 1155.0000
val_loss: 1.7231

Epoch 3 average training loss: 0.2361


Epoch 4/10: 100%|█| 42/42 [01:14<00:00,  1.77s/it, loss=0.0773, avg_loss=0.1843]



Running validation...
Computing validation loss and collecting embeddings...


Processing validation data: 100%|███████████████| 19/19 [00:16<00:00,  1.18it/s]



Computing global rankings for 1155 samples...


Computing rankings: 100%|████████████████████████| 3/3 [00:00<00:00, 458.48it/s]


Computing final metrics...

Validation Metrics:
top_1_accuracy: 0.2104
top_3_accuracy: 0.4554
top_5_accuracy: 0.5680
top_10_accuracy: 0.7169
top_50_accuracy: 0.9048
mrr: 0.3731
median_rank: 4.0000
mean_rank: 29.7056
val_size: 1155.0000
val_loss: 1.7157

Epoch 4 average training loss: 0.1843


Epoch 5/10: 100%|█| 42/42 [01:14<00:00,  1.77s/it, loss=0.0404, avg_loss=0.1642]



Running validation...
Computing validation loss and collecting embeddings...


Processing validation data: 100%|███████████████| 19/19 [00:16<00:00,  1.18it/s]



Computing global rankings for 1155 samples...


Computing rankings: 100%|████████████████████████| 3/3 [00:00<00:00, 477.49it/s]


Computing final metrics...

Validation Metrics:
top_1_accuracy: 0.2147
top_3_accuracy: 0.4485
top_5_accuracy: 0.5732
top_10_accuracy: 0.7333
top_50_accuracy: 0.9091
mrr: 0.3766
median_rank: 4.0000
mean_rank: 24.2701
val_size: 1155.0000
val_loss: 1.7268

Epoch 5 average training loss: 0.1642


Epoch 6/10:  36%|▎| 15/42 [00:26<00:48,  1.80s/it, loss=0.1702, avg_loss=0.1502]