In [1]:
from utils import * 

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Standard library imports
import bz2
import hashlib
import itertools
import json
import logging
import os
import re
import sqlite3
import xml.etree.ElementTree as ET
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple, Union

# Third-party imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    PreTrainedModel,
    PretrainedConfig,
    Trainer,
    TrainingArguments
)
import tqdm
import yaml

@dataclass
class WikiArticle:
    """One page of a Wikipedia dump together with its revision metadata.

    Attributes:
        title: Title of the page.
        text: Wikitext body of the page.
        timestamp: Timestamp string of the revision the text came from.
        is_redirect: Whether the page is a redirect rather than real content.
    """

    # Field order is part of the positional constructor signature.
    title: str
    text: str
    timestamp: str
    is_redirect: bool

class WikiDumpProcessor:
    """Processes Wikipedia XML dumps and extracts articles.

    The dump is parsed incrementally with ``ElementTree.iterparse``; every
    ``<page>`` element is cleared as soon as it has been handled.
    """

    def __init__(self, dump_path: str):
        """Store the dump location and the page filters.

        Args:
            dump_path: Path to the XML dump. A path ending in ``.bz2`` is
                decompressed on the fly; anything else is read as-is.
        """
        self.dump_path = dump_path
        # NOTE(review): only the export-0.10 schema namespace is registered.
        # A dump written with a different schema version will not match the
        # lookups below -- confirm against the dumps actually used.
        self._ns = {'mw': 'http://www.mediawiki.org/xml/export-0.10/'}
        # Pages whose title starts with one of these prefixes are never yielded.
        self._skip_prefixes = {
            'Wikipedia:', 'Template:', 'Category:', 'Portal:', 'File:',
            'MediaWiki:', 'Help:', 'Book:', 'Draft:', 'TimedText:',
            'Module:', 'Special:'
        }

    def _open_dump(self):
        """Open the dump for binary reading, decompressing ``.bz2`` files."""
        if self.dump_path.endswith('.bz2'):
            return bz2.BZ2File(self.dump_path)
        return open(self.dump_path, 'rb')

    def iter_articles(self, skip_redirects: bool = True) -> Iterator["WikiArticle"]:
        """Iterates through valid articles in the dump.

        Args:
            skip_redirects: When True (default), pages whose text starts with
                ``#REDIRECT`` (case-insensitive) are not yielded.

        Yields:
            WikiArticle: one object per accepted page. ``text`` and
            ``timestamp`` are always strings; a missing revision or an empty
            ``<text/>`` element produces ``''`` rather than ``None``.

        The dump file is closed when iteration finishes, when the generator is
        closed early, or when parsing raises.
        """
        skip_prefixes = tuple(self._skip_prefixes)

        # The context manager guarantees the handle is released even if the
        # consumer stops iterating after a few articles.
        with self._open_dump() as dump_file:
            for _, elem in ET.iterparse(dump_file, events=('end',)):
                if not elem.tag.endswith('page'):
                    continue

                try:
                    # Extract basic article data
                    title = elem.find('.//mw:title', self._ns).text
                    if title.startswith(skip_prefixes):
                        continue

                    # Get revision data
                    rev = elem.find('.//mw:revision', self._ns)
                    if rev is not None:
                        # ``.text`` is None for an empty element; normalise to ''
                        # so downstream code can rely on str.
                        text = rev.find('mw:text', self._ns).text or ''
                        timestamp = rev.find('mw:timestamp', self._ns).text or ''
                    else:
                        text = ''
                        timestamp = ''
                    is_redirect = bool(re.match(r'#REDIRECT', text, re.IGNORECASE))

                    if skip_redirects and is_redirect:
                        continue

                    yield WikiArticle(title=title, text=text, timestamp=timestamp, is_redirect=is_redirect)
                finally:
                    # Drop the page's children once it has been handled.
                    elem.clear()

class ArticleStorage:
    """Handles storage and retrieval of Wikipedia articles."""

    def __init__(self, processor: "WikiDumpProcessor"):
        """Keep a reference to the processor whose ``iter_articles()`` feeds the writers."""
        self.processor = processor

    def _limited_articles(self, sample_size: Optional[int]):
        """Yield at most ``sample_size`` articles (all of them when it is None).

        Exactly ``sample_size`` articles are pulled from the processor -- never
        one extra -- and the underlying iterator is closed afterwards so any
        resource it holds is released promptly.
        """
        source = self.processor.iter_articles()
        try:
            if sample_size is None:
                yield from source
            else:
                # A negative sample size behaves like 0 (nothing is written).
                yield from itertools.islice(source, max(sample_size, 0))
        finally:
            close = getattr(source, 'close', None)
            if callable(close):
                close()

    @staticmethod
    def _insert_batch(conn, cursor, rows) -> int:
        """Insert (or replace) ``rows``, commit, and return how many were written."""
        cursor.executemany('INSERT OR REPLACE INTO articles VALUES (?, ?, ?, ?)', rows)
        conn.commit()
        return len(rows)

    def save_to_jsonl(self, output_path: Union[str, Path], sample_size: Optional[int] = None) -> int:
        """Saves articles to a JSONL file.

        Args:
            output_path: Destination file; it is overwritten.
            sample_size: Maximum number of articles to write, or None for all.

        Returns:
            Number of articles written (one JSON object per line).
        """
        count = 0
        with open(output_path, 'w', encoding='utf-8') as f:
            for article in self._limited_articles(sample_size):
                json.dump(article.__dict__, f, ensure_ascii=False)
                f.write('\n')
                count += 1
        return count

    def save_to_sqlite(self, db_path: Union[str, Path], sample_size: Optional[int] = None,
                      batch_size: int = 1000) -> int:
        """Saves articles to a SQLite database.

        Args:
            db_path: Database file; the ``articles`` table is created if missing.
            sample_size: Maximum number of articles to write, or None for all.
            batch_size: Number of rows inserted and committed per transaction.

        Returns:
            Number of rows submitted. Rows sharing a title replace each other
            (``INSERT OR REPLACE``), so the table may end up with fewer rows.
        """
        count = 0
        batch = []
        conn = sqlite3.connect(db_path)

        # Everything after connect() sits inside the try so the connection is
        # closed even if schema creation fails.
        try:
            c = conn.cursor()
            c.execute('''CREATE TABLE IF NOT EXISTS articles
                    (title TEXT PRIMARY KEY, text TEXT, timestamp TEXT, is_redirect INTEGER)''')
            # Kept so the resulting schema matches databases written by earlier runs.
            c.execute('CREATE INDEX IF NOT EXISTS idx_title ON articles(title)')

            for article in self._limited_articles(sample_size):
                batch.append((article.title, article.text, article.timestamp,
                              1 if article.is_redirect else 0))

                if len(batch) >= batch_size:
                    count += self._insert_batch(conn, c, batch)
                    batch = []

            # Flush the final partial batch.
            if batch:
                count += self._insert_batch(conn, c, batch)
        finally:
            conn.close()

        return count



class WikiProcessor:
    """Prepares citation data for model training.

    Articles are read from a JSONL file (one JSON object per line with
    ``title`` and ``text`` keys), keyed by lower-cased title and given 1-based
    integer ids in first-seen order.

    Attributes:
        articles_dict: lower-cased title -> cleaned article text.
        ref2id: lower-cased title -> article id.
        id2ref: article id -> lower-cased title.
    """

    def __init__(self, jsonl_path: str = "data/wiki_articles.jsonl"):
        """Load and clean every article found in ``jsonl_path``.

        Titles are compared case-insensitively. When the same lower-cased title
        occurs more than once, the article keeps the id it received the first
        time and its text is replaced by the later occurrence. Ids therefore
        stay unique and equal to ``position in articles_dict + 1``, which is
        the order ``find_source_citations`` returns its sources in.
        """
        # Load articles
        logging.info("Loading articles from JSONL file...")
        self.articles_dict = {}
        self.id2ref = {}
        self.ref2id = {}
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline added by an editor).
                    continue
                article = json.loads(line)
                ref = article['title'].lower()
                if ref not in self.ref2id:
                    # Allocate a fresh id only for titles not seen before; deriving
                    # it from a dict that does not grow on duplicates would hand
                    # the same id to the next new article.
                    art_id = len(self.ref2id) + 1
                    self.ref2id[ref] = art_id
                    self.id2ref[art_id] = ref
                self.articles_dict[ref] = self.clean_wiki_text(article['text'])
        logging.info(f"Loaded {len(self.articles_dict)} articles.")

    def _find_citations(self, text):
        """Locate ``[[...]]`` links in ``text`` that point at a loaded article.

        Each link body is split on ``|``; for every part, anything from the
        first ``#`` onwards is dropped and the remainder is lower-cased. The
        first part that is a known article title decides the target.

        Returns:
            List of ``(start, end, article_id)`` tuples, where ``start``/``end``
            are the character offsets of the whole ``[[...]]`` match in ``text``.
        """
        citations = []
        for match in re.finditer(r'\[\[(.*?)\]\]', text):
            candidates = (part.split('#')[0].lower() for part in match.group(1).split('|'))
            ref = next((cand for cand in candidates if cand in self.articles_dict), None)
            if ref:
                citations.append((match.start(), match.end(), self.ref2id[ref]))
        return citations

    @staticmethod
    def clean_wiki_text(text: str) -> str:
        """Cleans wiki content by removing metadata and formatting.

        Steps: drop everything before the first ``'''bold'''`` span (if any),
        remove ``[[File:...]]``, ``[[Category:...]]`` and ``{{stub...}}``
        markup, then drop blank lines. ``None`` or empty input returns ``''``.
        """
        if not text:
            return ''

        # Find main content starting from first bold title
        match = re.search(r"'''([^']+?)'''", text)
        if match:
            text = text[match.start():]

        # Remove wiki elements and clean up.
        # NOTE(review): the ``.*`` parts are greedy, so each match runs to the
        # LAST closing ``]]`` / ``}}`` on the same line and can swallow ordinary
        # links that follow on that line -- confirm this is intended.
        text = re.sub(r'\[\[File:.*\]\]|\[\[Category:.*\]\]|\{\{stub.*\}\}', '', text)
        return '\n'.join(line for line in text.split('\n') if line.strip())

    def find_source_citations(self) -> Tuple[List[str], List[List[Tuple[int, int, int]]]]:
        """Creates source-target pairs for citation matching.

        Returns:
            ``(sources, citation_data)``: two parallel lists in ``articles_dict``
            order. ``sources[k]`` is the cleaned text of the article with id
            ``k + 1`` and ``citation_data[k]`` is its list of
            ``(start, end, article_id)`` citations (offsets into ``sources[k]``).
        """
        sources = []
        citation_data = []

        for text in self.articles_dict.values():
            # The stored text was already cleaned at load time; the second pass is
            # kept so the returned text matches what earlier runs produced.
            source_text = self.clean_wiki_text(text)
            sources.append(source_text)
            citation_data.append(self._find_citations(source_text))

        return sources, citation_data


# experiment related 

@dataclass
class ExperimentConfig:
    """Unified configuration for the entire citation matching experiment.

    Fields documented as "set after tokenizer initialization" default to None
    and are expected to be filled in by the caller. ``device`` is resolved in
    ``__post_init__`` when left as None.
    """
    # Model configuration
    model_name: str = "bert-base-uncased"
    vocab_size: Optional[int] = None  # Will be set after tokenizer initialization
    initial_logit_scale: float = float(np.log(1/0.07))  # Same initial value as CLIP

    # Token configuration
    cite_token: str = "<CITE>"
    ref_token: str = "<REF>"
    cite_token_id: Optional[int] = None  # Will be set after tokenizer initialization
    ref_token_id: Optional[int] = None   # Will be set after tokenizer initialization

    # Text processing configuration
    max_length: int = 512
    source_len: int = 512   # tokens per source segment
    target_len: int = 128   # tokens per target, including the trailing ref token
    max_targets: int = 5    # maximum distinct cited articles kept per segment
    overlap: float = 0.5    # fraction of a source segment shared with the next one

    # Training configuration
    num_epochs: int = 15
    learning_rate: float = 1.5e-4
    # Separate learning rate for the logit scale. This must be a plain float:
    # a trailing comma here would turn the default into a one-element tuple.
    logits_learning_rate: float = 1.5e-2
    Adam_eps: float = 1e-8
    weight_decay: float = 0.01
    warmup_steps: int = 0
    batch_size: int = 200
    train_ratio: float = 0.8
    collate_sample_size: Optional[int] = None

    # Evaluation configuration
    k_values: List[int] = field(default_factory=lambda: [1, 5, 10, 50, 100, 1000])

    # Logging and saving configuration
    project_name: str = "citation-matching"
    run_name: Optional[str] = None
    save_path: str = "./experiments/best_citation_model.pt"
    cache_dir: str = "cache"

    # Hardware configuration
    device: Optional[torch.device] = None

    def __post_init__(self):
        """Pick CUDA when available (CPU otherwise) if no device was given."""
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def get_cache_path(sources, model_name: str, cache_dir: str) -> str:
    """Generate a unique cache path based on input data and model name.

    Args:
        sources: Object whose ``str()`` representation identifies the input
            data (typically the list of source texts).
        model_name: Tokenizer/model identifier.
        cache_dir: Directory the cache file should live in.

    Returns:
        ``<cache_dir>/tokenized_<model>_<content>.pt`` where ``<model>`` is the
        first 8 hex digits of the MD5 of ``model_name`` and ``<content>`` is
        the full MD5 hex digest of ``str(sources)``.

    MD5 is used purely as a fingerprint here, not for security. Only
    ``sources`` and ``model_name`` contribute to the path.
    """
    model_digest = hashlib.md5(model_name.encode()).hexdigest()[:8]
    content_digest = hashlib.md5(str(sources).encode()).hexdigest()
    filename = f"tokenized_{model_digest}_{content_digest}.pt"
    return os.path.join(cache_dir, filename)

def tokenize_sources(sources=None, citation_data=None, tokenizer=None, batch_size=1000, cache_dir="cache", cache_path=None):
    """Tokenize source texts and project citation spans onto token positions.

    Args:
        sources: List of source texts.
        citation_data: Parallel list; ``citation_data[k]`` holds
            ``(char_start, char_end, article_id)`` tuples for ``sources[k]``.
        tokenizer: Tokenizer exposing ``name_or_path`` and ``batch_encode_plus``
            with ``return_offsets_mapping`` support.
        batch_size: Number of texts encoded per tokenizer call.
        cache_dir: Directory used when ``cache_path`` is not given.
        cache_path: Explicit cache file; overrides the derived path.

    Returns:
        List with one dict per source containing ``input_ids`` (array),
        ``cite_tokens`` (article id at the first token of each citation, else 0),
        ``mask_tokens`` (article id on every token of a citation, else 0) and
        ``attention_mask`` (or None when the tokenizer returns none).

    If the cache file already exists it is loaded and returned without
    tokenizing. NOTE(review): the derived cache path is computed from
    ``sources`` and the tokenizer name only, so a change to ``citation_data``
    alone reuses the old cache -- confirm that is acceptable.
    """
    # Generate cache path
    if cache_path is None:
        cache_path = get_cache_path(sources, tokenizer.name_or_path, cache_dir)

    # Check if cached results exist
    if os.path.exists(cache_path):
        logging.info(f"Loading cached tokenized results from {cache_path}")
        # weights_only=False unpickles arbitrary Python objects; only load cache
        # files that were written by this function on a trusted machine.
        return torch.load(cache_path, weights_only=False)

    logging.info("Tokenizing sources...")
    # Process in batches; tqdm derives the number of steps from the range itself,
    # which also counts a final partial batch.
    all_results = []
    for batch_start in tqdm.tqdm(range(0, len(sources), batch_size)):
        batch_end = min(batch_start + batch_size, len(sources))
        batch_sources = sources[batch_start:batch_end]
        batch_citations = citation_data[batch_start:batch_end]

        # Batch encode
        batch_encoded = tokenizer.batch_encode_plus(
            batch_sources,
            add_special_tokens=False,
            return_offsets_mapping=True,
            padding=False,
            return_tensors=None
        )

        # Process each item in the batch
        for idx in range(len(batch_sources)):
            offset_mapping = batch_encoded["offset_mapping"][idx]
            input_ids = batch_encoded["input_ids"][idx]

            # Character offset -> token index. A token's start offset maps to its
            # own index, its end offset to the index just past it (exclusive end).
            off2i = {s: i for i, (s, _) in enumerate(offset_mapping)}
            off2i.update({e: i + 1 for i, (_, e) in enumerate(offset_mapping)})

            # Create citation tokens array
            mask_tokens = np.zeros(len(input_ids), dtype=int)
            cite_tokens = np.zeros(len(input_ids), dtype=int)

            # Fill in citations. A citation boundary that does not coincide with
            # a token boundary raises KeyError here.
            for i, j, art_id in batch_citations[idx]:
                s, e = off2i[i], off2i[j]
                cite_tokens[s] = art_id
                mask_tokens[s:e] = art_id

            # Store results
            all_results.append({
                'input_ids': np.array(input_ids),
                'cite_tokens': cite_tokens,
                'mask_tokens': mask_tokens,
                'attention_mask': batch_encoded["attention_mask"][idx] if "attention_mask" in batch_encoded else None
            })

    # Cache the results. Create the directory the cache file actually lives in,
    # which differs from ``cache_dir`` when an explicit ``cache_path`` was given.
    cache_parent = os.path.dirname(cache_path)
    if cache_parent:
        os.makedirs(cache_parent, exist_ok=True)
    torch.save(all_results, cache_path)
    logging.info(f"Cached tokenized results to {cache_path}")

    return all_results

def collate(results, tokenizer, config):
    """Cut tokenized articles into overlapping source segments with their cited targets.

    Args:
        results: Sequence of per-article dicts providing ``input_ids``,
            ``cite_tokens`` and ``mask_tokens`` arrays. Article ids are treated
            as 1-based positions into this sequence (``results[id - 1]``).
        tokenizer: Used to look up the ids of the cite/ref tokens, of ``[`` and
            ``]``, and the pad token id.
        config: Reads ``cite_token``, ``ref_token``, ``collate_sample_size``,
            ``overlap``, ``source_len``, ``max_targets`` and ``target_len``.

    Returns:
        List of dicts, one per kept segment, with tensors ``source_ids``,
        ``attention_mask``, ``cited_art_ids``, ``target_art_ids``,
        ``target_ids``, ``target_attention_mask`` and the int ``source_art_id``.

    Target sampling uses ``np.random.choice``, so the output depends on the
    global NumPy random state.
    """
    cite_token = tokenizer.convert_tokens_to_ids(config.cite_token)
    ref_token = tokenizer.convert_tokens_to_ids(config.ref_token)
    bracket_tokens = tokenizer.convert_tokens_to_ids(['[',']'])
    pad_token = tokenizer.pad_token_id

    collated_data = []
    
    for i in tqdm.tqdm(range(len(results))):
        result = results[i]
        # Cap on the number of collected segments. It is checked once per article
        # (not per segment) and with ``>``, so the final count can exceed
        # collate_sample_size.
        if config.collate_sample_size and len(collated_data)>config.collate_sample_size:
            break
        
        # Process each source segment; consecutive windows start
        # (1 - overlap) * source_len tokens apart.
        # NOTE(review): an overlap of 1.0 gives a step of 0, which range() rejects.
        for s in range(0, len(result['input_ids']), int((1-config.overlap)*config.source_len)):
            e = s + config.source_len
            
            # Get source segment (input_ids is copied because it is modified below)
            input_ids = result['input_ids'][s:e].copy()
            cite_tokens = result['cite_tokens'][s:e]
            mask_tokens = result['mask_tokens'][s:e]
            
            # Skip if segment is too short
            if len(input_ids) < config.source_len // 2:
                continue
                
            # Get all citations from this segment; sample down to max_targets
            # distinct article ids when there are more.
            present_citations = np.unique(cite_tokens[cite_tokens > 0])
            if len(present_citations) > config.max_targets:
                present_citations = np.random.choice(present_citations, config.max_targets, replace=False)
            max_targets = min(config.max_targets, len(present_citations))

            # Skip if segment is too short
            # NOTE(review): repeats the length check above and cannot trigger here.
            if len(input_ids) < config.source_len // 2:
                continue
            # Skip if no citations
            if max_targets == 0:
                continue
            
            # Initialize target arrays
            target_ids = np.full((max_targets, config.target_len), pad_token, dtype=np.int64)
            target_attention_mask = np.zeros((max_targets, config.target_len), dtype=np.int64)
            
            
            # Prepare source:
            # only keep citation tokens that are sampled to be masked
            cite_tokens_mask = np.isin(cite_tokens, present_citations)
            # don't mask citations that are not sampled
            mask_tokens = np.where(np.isin(mask_tokens, present_citations), mask_tokens, 0)
            # remove brackets from the rest of the text
            mask_tokens = np.where(np.isin(input_ids,bracket_tokens),1, mask_tokens)
            # don't mask the citation tokens
            mask_tokens[cite_tokens_mask] = 0
            # set the citation tokens (first token of a citation range) as special token <CITE>
            input_ids[cite_tokens_mask] = cite_token
            # mask all tokens in a citation, except for the first (special) token:
            # only positions whose mask value is 0 survive
            source_ids = input_ids[mask_tokens == 0]

            # keep the cited article ids in the text in the order they appear (with repeats)
            # & keep the unique cited article ids
            # this will enable us to link each special cite token to a target via the article id
            target_art_ids = present_citations
            cited_art_ids = cite_tokens[cite_tokens_mask]
            
            # Pad or truncate source to exactly source_len tokens
            if len(source_ids) > config.source_len:
                source_ids = source_ids[:config.source_len]
            elif len(source_ids) < config.source_len:
                source_ids = np.pad(source_ids, 
                                  (0, config.source_len - len(source_ids)),
                                  'constant', 
                                  constant_values=pad_token)
            
            # Create source attention mask (1 for every non-pad position)
            attention_mask = (source_ids != pad_token).astype(np.int64)
            
            # Process each target
            for idx, citation_id in enumerate(present_citations):
                # Get pre-tokenized target content
                # ids are 1-indexed
                target_data = results[citation_id - 1]
                target_tokens = target_data['input_ids']
                
                # Truncate if needed and add ref_token, leaving room for it
                # inside target_len
                if len(target_tokens) >= config.target_len - 1:
                    target_tokens = target_tokens[:config.target_len-1]
                target_tokens = np.append(target_tokens, ref_token)
                
                # Pad to target_len
                if len(target_tokens) < config.target_len:
                    target_tokens = np.pad(target_tokens,
                                         (0, config.target_len - len(target_tokens)),
                                         'constant',
                                         constant_values=pad_token)
                
                # Store in target arrays
                target_ids[idx] = target_tokens
                target_attention_mask[idx] = (target_tokens != pad_token)


            # Store the collected data (source_art_id is the 1-based article id)
            collated_data.append({
                'source_art_id': i+1,
                'source_ids': torch.tensor(source_ids, dtype=torch.long),
                'cited_art_ids': torch.tensor(cited_art_ids, dtype=torch.long),
                'target_art_ids': torch.tensor(target_art_ids, dtype=torch.long),
                'target_ids': torch.tensor(target_ids, dtype=torch.long),
                'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
                'target_attention_mask': torch.tensor(target_attention_mask, dtype=torch.long),
            })
    
    return collated_data

class CitationDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves pre-collated citation samples unchanged."""

    def __init__(self, collated_data):
        """Wrap ``collated_data``, an indexable collection of ready-made samples."""
        self.data = collated_data

    def __len__(self):
        """Return how many samples are available."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        return self.data[idx]

def citation_collate_fn(batch):
    """Merge per-segment samples into one batch with de-duplicated targets.

    Sources and their attention masks are stacked along a new first dimension.
    Targets from all samples are concatenated and reduced to one row per
    distinct article id (the first occurrence is kept, ids end up in ascending
    order). ``labels[k]`` is the row of the deduplicated targets that
    ``cited_art_ids[k]`` refers to.
    """
    def gather(key, combine):
        # Collect one field from every sample and merge it with `combine`.
        return combine([sample[key] for sample in batch])

    # Per-source tensors
    source_ids = gather('source_ids', torch.stack)
    cited_art_ids = gather('cited_art_ids', torch.cat)
    attention_mask = gather('attention_mask', torch.stack)

    # Per-target tensors, still containing repeats across samples
    pooled_art_ids = gather('target_art_ids', torch.cat)
    pooled_target_ids = gather('target_ids', torch.cat)
    pooled_target_mask = gather('target_attention_mask', torch.cat)

    # Sorted distinct article ids plus the row of each id's first occurrence
    distinct_ids, first_rows = np.unique(pooled_art_ids.numpy(), return_index=True)
    target_art_ids = torch.tensor(distinct_ids)
    keep_rows = torch.tensor(first_rows)

    # Article id -> row in the deduplicated target tensors
    row_of = {}
    for row, art_id in enumerate(target_art_ids):
        row_of[art_id.item()] = row
    labels = torch.tensor([row_of[cited.item()] for cited in cited_art_ids], dtype=torch.long)

    return {
        'source_ids': source_ids,
        'cited_art_ids': cited_art_ids,
        'target_art_ids': target_art_ids,
        'target_ids': pooled_target_ids[keep_rows],
        'attention_mask': attention_mask,
        'target_attention_mask': pooled_target_mask[keep_rows],
        'labels': labels,
    }



@dataclass
class CitationModelOutput:
    """Container for the values produced by one forward pass of the citation model.

    Every field defaults to None, so an instance can be built with any subset.
    """

    # Training loss.
    loss: Optional[torch.FloatTensor] = None
    # Citation-versus-reference similarity scores.
    logits: torch.FloatTensor = None
    # Embeddings taken at the citation token positions.
    cite_embeds: Optional[torch.FloatTensor] = None
    # Embeddings taken at the reference token positions.
    ref_embeds: Optional[torch.FloatTensor] = None

class CitationModel(nn.Module):
    """Custom model for citation matching using transformer embeddings.

    One shared transformer encodes both sources and targets. The hidden state
    at each citation token in a source is compared with the hidden state at the
    reference token of every target, using cosine similarity multiplied by a
    learnable temperature.
    """

    def __init__(self, config: ExperimentConfig):
        """Load the pretrained encoder and set up the learnable logit scale.

        Args:
            config: Supplies ``model_name``, ``vocab_size``,
                ``initial_logit_scale`` and, at forward time, ``cite_token_id``
                and ``ref_token_id``.
        """
        super().__init__()

        # Store configuration
        self.config = config

        # Load base transformer model (this also loads its configuration, so no
        # separate AutoConfig fetch is needed)
        self.transformer = AutoModel.from_pretrained(config.model_name)

        # Resize token embeddings only when an explicit, different vocabulary
        # size was configured; ``vocab_size`` defaults to None.
        if config.vocab_size is not None and config.vocab_size != self.transformer.config.vocab_size:
            self.transformer.resize_token_embeddings(config.vocab_size)

        # Add learnable logit scale parameter (stored in log space, see forward)
        self.logit_scale = nn.Parameter(torch.ones([]) * config.initial_logit_scale)

    def get_citation_masks(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Create mask for citation token positions."""
        return input_ids == self.config.cite_token_id

    def get_reference_masks(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Create mask for reference token positions."""
        return input_ids == self.config.ref_token_id

    def forward(
        self,
        source_ids: torch.Tensor,
        target_ids: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        target_attention_mask: Optional[torch.Tensor] = None,
        cited_art_ids: Optional[torch.Tensor] = None,
        target_art_ids: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[Tuple, CitationModelOutput]:
        """Forward pass of the model.

        Args:
            source_ids: Token ids of the source segments.
            target_ids: Token ids of the candidate targets.
            labels: For every citation token found in ``source_ids`` (in the
                order boolean masking returns them), the row of ``target_ids``
                it refers to.
            attention_mask: Attention mask for ``source_ids``.
            target_attention_mask: Attention mask for ``target_ids``.
            cited_art_ids: Accepted for batch compatibility; not used here.
            target_art_ids: Accepted for batch compatibility; not used here.
            return_dict: Return a ``CitationModelOutput`` when True, otherwise
                the tuple ``(loss, logits, cite_embeds, ref_embeds)``.

        The cross-entropy loss assumes one logits row per label, i.e. as many
        citation tokens in ``source_ids`` as entries in ``labels``.
        """
        # Process source text
        source_outputs = self.transformer(
            input_ids=source_ids,
            attention_mask=attention_mask,
            return_dict=True
        )

        # Process target text
        target_outputs = self.transformer(
            input_ids=target_ids,
            attention_mask=target_attention_mask,
            return_dict=True
        )

        # Get citation mask and extract citation embeddings
        cite_mask = self.get_citation_masks(source_ids)
        cite_embeds = source_outputs.last_hidden_state[cite_mask]

        # Get reference mask and extract reference embeddings
        ref_mask = self.get_reference_masks(target_ids)
        ref_embeds = target_outputs.last_hidden_state[ref_mask]

        # Normalize embeddings so the dot product below is a cosine similarity
        cite_embeds = F.normalize(cite_embeds, p=2, dim=-1)
        ref_embeds = F.normalize(ref_embeds, p=2, dim=-1)

        # Clamp logit scale to prevent numerical instability
        # (the effective multiplier exp(scale) stays within [1, 100])
        logit_scale = torch.clamp(self.logit_scale, 0, torch.log(torch.tensor(100.0)))

        # Compute similarity scores with learned scale
        logits = torch.matmul(cite_embeds, ref_embeds.t()) * logit_scale.exp()

        # compute the loss
        loss = F.cross_entropy(logits, labels)

        if return_dict:
            return CitationModelOutput(
                loss=loss,
                logits=logits,
                cite_embeds=cite_embeds,
                ref_embeds=ref_embeds
            )

        return (loss, logits, cite_embeds, ref_embeds)


def compute_retrieval_metrics(logits, labels, ks=[1, 5, 10, 50, 100, 1000]):
    """Compute MRR and top-k accuracy from a score matrix.

    Args:
        logits: 2-D score tensor, one row per query and one column per candidate.
        labels: Column index of the correct candidate for every row.
        ks: Cut-offs to report; a cut-off larger than the number of candidates
            is left out of the result.

    Returns:
        Dict with ``'mrr'`` and one ``'top_{k}_accuracy'`` entry per reported k.
        The rank of a query is the number of candidates scoring at least as
        high as the correct one, so ties count against the query.
    """
    num_candidates = logits.size(1)

    # Score of the correct candidate for each row, as a column for broadcasting.
    gold_scores = logits[torch.arange(logits.size(0)), labels].unsqueeze(1)
    ranks = (logits >= gold_scores).sum(1)

    metrics = {'mrr': (1.0 / ranks).mean().item()}
    metrics.update({
        f'top_{k}_accuracy': (ranks <= k).float().mean().item()
        for k in ks
        if k <= num_candidates
    })
    return metrics

def validate_citation_model(
    model,
    val_dataloader,
    device: str = None,
    return_embeddings: bool = False,
    k_values: List[int] = [1, 5, 10, 50, 100, 1000],
    similarity_batch_size: int = 4096
):
    """Evaluate citation retrieval against every target seen in the validation set.

    First every batch is encoded and the normalized citation and reference
    embeddings are collected on the CPU. References are then de-duplicated by
    article id, and each citation is scored against all unique references in
    chunks of ``similarity_batch_size``.

    Args:
        model: Object exposing ``transformer``, ``get_citation_masks``,
            ``get_reference_masks`` and ``logit_scale``. It is switched to eval
            mode and not switched back.
        val_dataloader: Yields dict batches of tensors with the keys
            ``source_ids``, ``attention_mask``, ``target_ids``,
            ``target_attention_mask``, ``cited_art_ids`` and ``target_art_ids``.
        device: Device to run on; CUDA if available, else CPU, when None.
        return_embeddings: Also return embeddings, ids, logits and labels.
        k_values: Cut-offs for top-k accuracy.
        similarity_batch_size: Number of citations scored per chunk.

    Returns:
        Dict with ``loss``, ``accuracy``, ``num_citations``,
        ``num_unique_targets``, ``mrr`` and the available ``top_{k}_accuracy``
        values, plus the raw tensors when ``return_embeddings`` is True.

    The whole evaluation runs under ``torch.no_grad()``, so none of the
    returned tensors carry autograd history.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()

    # Lists to store accumulated embeddings and IDs
    all_cite_embeds = []
    all_ref_embeds = []
    all_cited_art_ids = []
    all_target_art_ids = []

    # Accumulate embeddings and IDs
    with torch.no_grad():
        for batch in tqdm.tqdm(val_dataloader, desc="Computing embeddings"):
            batch = {k: v.to(device) for k, v in batch.items()}

            # Process source text
            source_outputs = model.transformer(
                input_ids=batch['source_ids'],
                attention_mask=batch['attention_mask'],
                return_dict=True
            )

            # Process target text
            target_outputs = model.transformer(
                input_ids=batch['target_ids'],
                attention_mask=batch['target_attention_mask'],
                return_dict=True
            )

            # Get citation mask and extract citation embeddings
            cite_mask = model.get_citation_masks(batch['source_ids'])
            cite_embeds = source_outputs.last_hidden_state[cite_mask]

            # Get reference mask and extract reference embeddings
            ref_mask = model.get_reference_masks(batch['target_ids'])
            ref_embeds = target_outputs.last_hidden_state[ref_mask]

            # Normalize embeddings
            cite_embeds = F.normalize(cite_embeds, p=2, dim=-1)
            ref_embeds = F.normalize(ref_embeds, p=2, dim=-1)

            # Store embeddings and IDs
            all_cite_embeds.append(cite_embeds.cpu())
            all_ref_embeds.append(ref_embeds.cpu())
            all_cited_art_ids.append(batch['cited_art_ids'].cpu())
            all_target_art_ids.append(batch['target_art_ids'].cpu())

    # Concatenate all accumulated tensors
    cite_embeds = torch.cat(all_cite_embeds)
    ref_embeds = torch.cat(all_ref_embeds)
    cited_art_ids = torch.cat(all_cited_art_ids)
    target_art_ids = torch.cat(all_target_art_ids)

    # Get unique target art IDs (sorted) and the row of each id's first occurrence
    target_art_ids_unique, unique_indices = np.unique(target_art_ids.numpy(), return_index=True)
    target_art_ids_unique = torch.tensor(target_art_ids_unique)
    ref_embeds_unique = ref_embeds[torch.tensor(unique_indices)]

    # Create ID to index mapping
    id2i = {art_id.item(): i for i, art_id in enumerate(target_art_ids_unique)}
    labels = torch.tensor([id2i[art_id.item()] for art_id in cited_art_ids], dtype=torch.long)

    # Move reference embeddings to device once
    ref_embeds_unique = ref_embeds_unique.to(device)

    # Initialize accumulators
    total_loss = 0
    total_correct = 0
    all_logits_list = []
    all_labels_list = []

    # Process citation embeddings in batches
    num_batches = (len(cite_embeds) + similarity_batch_size - 1) // similarity_batch_size

    # Scoring must not track gradients: logit_scale is a trainable parameter, so
    # without no_grad every stored logits chunk would keep its autograd history.
    with torch.no_grad():
        # Get and clamp logit scale
        logit_scale = torch.clamp(model.logit_scale, 0, torch.log(torch.tensor(100.0)))

        for i in tqdm.tqdm(range(num_batches), desc="Computing similarities"):
            start_idx = i * similarity_batch_size
            end_idx = min((i + 1) * similarity_batch_size, len(cite_embeds))

            # Move batch of citation embeddings to device
            cite_embeds_batch = cite_embeds[start_idx:end_idx].to(device)
            labels_batch = labels[start_idx:end_idx].to(device)

            # Compute similarity scores for this batch
            logits_batch = torch.matmul(cite_embeds_batch, ref_embeds_unique.t()) * logit_scale.exp()

            # Compute loss for this batch (weighted by size so the final mean is exact)
            loss_batch = F.cross_entropy(logits_batch, labels_batch)
            total_loss += loss_batch.item() * len(labels_batch)

            # Compute predictions and accuracy for this batch
            predictions_batch = torch.argmax(logits_batch, dim=-1)
            total_correct += (predictions_batch == labels_batch).sum().item()

            # Store logits and labels for the retrieval metrics
            all_logits_list.append(logits_batch.cpu())
            all_labels_list.append(labels_batch.cpu())

            # Clear GPU memory
            del logits_batch, cite_embeds_batch, labels_batch, predictions_batch
            torch.cuda.empty_cache()

    # Aggregate accuracy and loss over all citations
    num_citations = len(cite_embeds)
    accuracy = total_correct / num_citations
    avg_loss = total_loss / num_citations

    # Compute retrieval metrics using concatenated logits
    all_logits = torch.cat(all_logits_list)
    all_labels = torch.cat(all_labels_list)
    retrieval_metrics = compute_retrieval_metrics(all_logits, all_labels, ks=k_values)

    # Prepare results dictionary
    results = {
        'loss': avg_loss,
        'accuracy': accuracy,
        'num_citations': num_citations,
        'num_unique_targets': len(target_art_ids_unique),
        'mrr': retrieval_metrics['mrr']
    }

    # Add top-k accuracies to results
    for k in k_values:
        if f'top_{k}_accuracy' in retrieval_metrics:
            results[f'top_{k}_accuracy'] = retrieval_metrics[f'top_{k}_accuracy']

    if return_embeddings:
        results.update({
            'cite_embeds': cite_embeds.cpu(),
            'ref_embeds': ref_embeds_unique.cpu(),
            'cited_art_ids': cited_art_ids,
            'target_art_ids': target_art_ids_unique,
            'logits': all_logits,  # Full logits matrix (reused, not concatenated again)
            'labels': labels.cpu()
        })

    return results
    
def train_citation_model(
    model: CitationModel,
    results: List[dict],
    tokenizer: AutoTokenizer,
    config: ExperimentConfig,
    device: Optional[str] = None,
    project_name: Optional[str] = None,
    run_name: Optional[str] = None
) -> CitationModel:
    """
    Train a citation model with mixed precision and log progress to wandb.

    Every epoch re-collates ``results`` (which draws new random masks), splits
    the collated dataset positionally into a training and a validation part
    according to ``config.train_ratio``, runs one pass over the training part
    and evaluates with ``validate_citation_model``. Whenever the validation
    loss improves, a checkpoint is written to ``config.save_path``.

    Args:
        model: Model to train. It is moved to ``device`` and gradient
            checkpointing is enabled on ``model.transformer``.
        results: Tokenized source records handed to ``collate`` each epoch.
        tokenizer: Tokenizer handed to ``collate``.
        config: Experiment configuration. Reads ``project_name``, ``run_name``,
            ``device``, ``learning_rate``, ``logits_learning_rate``,
            ``weight_decay``, ``Adam_eps``, ``num_epochs``, ``train_ratio``,
            ``batch_size``, ``k_values`` and ``save_path``.
        device: Device override; falls back to ``config.device``.
        project_name: wandb project override; falls back to
            ``config.project_name``.
        run_name: wandb run name override; falls back to ``config.run_name``.

    Returns:
        The same ``model`` object in its state after the last epoch. The best
        checkpoint is only written to disk; it is not reloaded here.
    """
    import wandb

    # Explicit arguments win over the values stored in the config.
    project_name = project_name or config.project_name
    run_name = run_name or config.run_name

    wandb.init(
        project=project_name,
        name=run_name,
        config=config,
    )

    # Flipped to False only when training runs to completion, so an exception
    # or a notebook interrupt still closes the run and marks it as failed.
    run_failed = True
    try:
        # Set device
        print(f"Experiment config: {config}")
        device = device or config.device
        print(f"Using device: {device}")

        global_step = 0  # wandb x-axis: number of optimizer steps so far

        # Gradient scaler for mixed precision training. torch.amp.GradScaler
        # replaces the deprecated torch.cuda.amp.GradScaler constructor.
        scaler = torch.amp.GradScaler('cuda')

        # Move model to device
        model = model.to(device)

        # Trade compute for memory during backprop
        model.transformer.gradient_checkpointing_enable()

        # Separate parameter groups so the logit scale can use its own
        # learning rate and skip weight decay.
        optimizer = AdamW([
            {
                'params': [param for name, param in model.named_parameters() if name != 'logit_scale'],
                'lr': config.learning_rate,
                'weight_decay': config.weight_decay,
                'eps': config.Adam_eps
            },
            {
                'params': [model.logit_scale],
                'lr': config.logits_learning_rate,
                'weight_decay': 0
            }
        ], lr=config.learning_rate)

        # Fail early (before any training) if the checkpoint directory cannot
        # be created, instead of after the first full epoch.
        save_dir = os.path.dirname(config.save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)

        best_val_metrics = {'loss': float('inf')}

        for epoch in range(config.num_epochs):
            print(f"\nEpoch {epoch + 1}/{config.num_epochs}")

            # The model stores the log of the scale; report the scale itself.
            current_scale = model.logit_scale.exp().item()
            print(f"Current logit scale: {current_scale:.4f}")
            wandb.log({"logit_scale": current_scale}, step=global_step)

            # Create new collated data for this epoch
            # NOTE(review): the train/val split is positional on a freshly
            # collated dataset, so the validation set can differ between
            # epochs and the "best" loss is compared across different sets.
            # Confirm against `collate` whether this is intended.
            print("Collating training data with new random masks...")
            collated = collate(results, tokenizer, config)
            dataset = CitationDataset(collated)
            train_size = int(len(dataset) * config.train_ratio)
            train_dataset = dataset[:train_size]
            val_dataset = dataset[train_size:]

            train_dataloader = DataLoader(
                train_dataset,
                batch_size=config.batch_size,
                shuffle=True,
                num_workers=4,
                pin_memory=True,
                collate_fn=citation_collate_fn
            )

            val_dataloader = DataLoader(
                val_dataset,
                batch_size=config.batch_size * 2,
                shuffle=False,
                num_workers=4,
                pin_memory=True,
                collate_fn=citation_collate_fn
            )

            # Training phase
            model.train()
            total_train_loss = 0.0
            train_steps = 0

            progress_bar = tqdm.tqdm(train_dataloader, desc="Training")

            for batch in progress_bar:
                # Move batch to device
                batch = {k: v.to(device) for k, v in batch.items()}

                # Clear gradients
                optimizer.zero_grad()

                # Forward pass with mixed precision
                with torch.amp.autocast('cuda'):
                    outputs = model(**batch)
                    loss = outputs.loss

                # Backward pass with gradient scaling
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                # Read the loss once: every .item() call forces a device sync.
                loss_value = loss.item()
                total_train_loss += loss_value
                train_steps += 1

                # Log training metrics with global step
                wandb.log({
                    "train/batch_loss": loss_value,
                    "train/learning_rate": optimizer.param_groups[0]["lr"]
                }, step=global_step)

                # Update progress bar
                progress_bar.set_postfix({'loss': loss_value})

                global_step += 1  # one step per processed batch

            # An empty training split yields NaN instead of ZeroDivisionError.
            avg_train_loss = total_train_loss / train_steps if train_steps else float('nan')
            print(f"\nAverage training loss: {avg_train_loss:.4f}")

            # Log epoch-level training metrics
            wandb.log({
                "train/epoch_loss": avg_train_loss,
                "epoch": epoch
            }, step=global_step)

            # Validation phase
            print("\nRunning validation...")
            torch.cuda.empty_cache()
            model.eval()

            val_metrics = validate_citation_model(
                model=model,
                val_dataloader=val_dataloader,
                device=device,
                k_values=config.k_values
            )

            # Log validation metrics
            wandb_val_metrics = {
                "val/loss": val_metrics['loss'],
                "val/accuracy": val_metrics['accuracy'],
                "val/mrr": val_metrics['mrr']
            }

            # Add whichever top-k accuracies the validator actually reported
            for k in config.k_values:
                if f'top_{k}_accuracy' in val_metrics:
                    wandb_val_metrics[f"val/top_{k}_accuracy"] = val_metrics[f'top_{k}_accuracy']

            wandb.log(wandb_val_metrics, step=global_step)

            # Print validation metrics
            print(f"\nValidation metrics:")
            print(f"  Loss: {val_metrics['loss']:.4f}")
            print(f"  Accuracy (top-1): {val_metrics['accuracy']:.4f}")
            print(f"  Mean Reciprocal Rank: {val_metrics['mrr']:.4f}")
            print(f"  Number of citations: {val_metrics['num_citations']}")
            print(f"  Number of unique targets: {val_metrics['num_unique_targets']}")
            print("\nTop-k accuracy:")
            for k in config.k_values:
                if f'top_{k}_accuracy' in val_metrics:
                    print(f"  k={k}: {val_metrics[f'top_{k}_accuracy']:.4f}")

            # Save best model based on validation loss
            if val_metrics['loss'] < best_val_metrics['loss']:
                best_val_metrics = val_metrics
                save_dict = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scaler_state_dict': scaler.state_dict(),
                    'validation_metrics': val_metrics,
                    'logit_scale': model.logit_scale.item(),
                    'global_step': global_step
                }
                torch.save(save_dict, config.save_path)

                # Log best metrics to wandb
                wandb.run.summary.update({
                    "best_val_loss": best_val_metrics['loss'],
                    "best_val_accuracy": best_val_metrics['accuracy'],
                    "best_val_mrr": best_val_metrics['mrr'],
                    "best_model_epoch": epoch,
                    "best_model_step": global_step
                })

                print(f"\nSaved new best model to {config.save_path}")
                print(f"Best validation metrics so far:")
                print(f"  Loss: {best_val_metrics['loss']:.4f}")
                print(f"  Accuracy (top-1): {best_val_metrics['accuracy']:.4f}")
                print(f"   MRR: {best_val_metrics['mrr']:.4f}")

        run_failed = False
    finally:
        # Always close the wandb run, even on errors or a kernel interrupt.
        wandb.finish(exit_code=1 if run_failed else 0)

    return model


class Experiment:
    """Bundles the config, tokenizer and CitationModel of one experiment."""

    def __init__(self, config: ExperimentConfig):
        """
        Build the tokenizer and model described by ``config``.

        The cite/ref marker tokens from the config are registered as special
        tokens, and ``config`` is updated in place with the resulting token
        ids and vocabulary size before the model is constructed from it.

        Args:
            config: Experiment configuration; mutated in place
                (``cite_token_id``, ``ref_token_id``, ``vocab_size``).
        """
        tokenizer = AutoTokenizer.from_pretrained(config.model_name)
        tokenizer.add_special_tokens({
            'additional_special_tokens': [config.cite_token, config.ref_token]
        })
        # Update config with tokenizer-dependent values.
        config.cite_token_id = tokenizer.convert_tokens_to_ids(config.cite_token)
        config.ref_token_id = tokenizer.convert_tokens_to_ids(config.ref_token)
        config.vocab_size = len(tokenizer)
        model = CitationModel(config)

        self.config = config
        self.model = model
        self.tokenizer = tokenizer

    def get_model(self, load_path=None):
        """
        Return the experiment's model, optionally restoring saved weights.

        Args:
            load_path: Path of a checkpoint dict containing a
                ``'model_state_dict'`` entry. When falsy, the model is
                returned as constructed.

        Returns:
            The ``CitationModel`` instance owned by this experiment.
        """
        if load_path:
            # Load onto CPU first; load_state_dict copies the tensors into the
            # model's existing parameters wherever those live.
            checkpoint = torch.load(load_path, map_location='cpu', weights_only=True)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            # Only restore a temperature when the checkpoint actually has one;
            # otherwise the state dict above is the sole source of weights.
            if 'temperature' in checkpoint:
                self.model.config.temperature = checkpoint['temperature']

        return self.model

    def get_tokenizer(self):
        """Return the tokenizer (with the cite/ref special tokens added)."""
        return self.tokenizer

    def get_results(self, cache_path=None, tokenizer=None):
        """
        Return tokenized sources, from a cache file or by processing articles.

        Args:
            cache_path: When given, it is passed to ``tokenize_sources`` as
                the only argument and no preprocessing is done here.
            tokenizer: Tokenizer used when tokenizing from scratch; defaults
                to this experiment's tokenizer.

        Returns:
            Whatever ``tokenize_sources`` returns.
        """
        if cache_path:
            return tokenize_sources(cache_path=cache_path)

        if tokenizer is None:
            tokenizer = self.tokenizer

        # Load articles
        preprocessor = WikiProcessor()
        sources, citation_data = preprocessor.find_source_citations()

        cache_dir = getattr(self.config, 'cache_dir', 'cache')
        return tokenize_sources(sources, citation_data, tokenizer, cache_dir=cache_dir)
            
logging.basicConfig(level=logging.INFO)

# Cached tokenized sources consumed by the training run below.
TOKENIZED_CACHE_PATH = './cache/tokenized_1caf5def_eb27a5477eaa3d549aebc4886f3717d1.pt'

# initial_logit_scale is passed as log(1 / 0.5) = log(2).
config = ExperimentConfig(collate_sample_size=5000, initial_logit_scale=np.log(1.0 / 0.5))

experiment = Experiment(config)

tokenizer = experiment.get_tokenizer()

model = experiment.get_model()

# Load the tokenized sources explicitly so this cell does not depend on a
# `results` variable left over from an earlier kernel session.
results = experiment.get_results(cache_path=TOKENIZED_CACHE_PATH)

# train model using config
trained_model = train_citation_model(
    model=model,
    results=results,
    tokenizer=tokenizer,
    config=config
)

INFO:root:Loading cached tokenized results from ./cache/tokenized_1caf5def_eb27a5477eaa3d549aebc4886f3717d1.pt
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mamirjoudaki[0m. Use [1m`wandb login --relogin`[0m to force relogin


Experiment config: ExperimentConfig(model_name='bert-base-uncased', vocab_size=30524, initial_logit_scale=0.6931471805599453, cite_token='<CITE>', ref_token='<REF>', cite_token_id=30522, ref_token_id=30523, max_length=512, source_len=512, target_len=128, max_targets=5, overlap=0.5, num_epochs=15, learning_rate=0.00015, Adam_eps=1e-08, weight_decay=0.01, warmup_steps=0, batch_size=200, train_ratio=0.8, collate_sample_size=5000, k_values=[1, 5, 10, 50, 100, 1000], project_name='citation-matching', run_name=None, save_path='./experiments/best_citation_model.pt', cache_dir='cache', device=device(type='cuda'))
Using device: cuda


  scaler = GradScaler()



Epoch 1/15
Current logit scale: 2.0000
Collating training data with new random masks...


  0%|▎                                                                                                                  | 664/237381 [00:01<07:26, 529.68it/s]
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [01:10<00:00,  3.37s/it, loss=1.9]



Average training loss: 5.5125

Running validation...


Computing embeddings: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:12<00:00,  4.09s/it]
Computing similarities: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 33.97it/s]



Validation metrics:
  Loss: 6.5471
  Accuracy (top-1): 0.0552
  Mean Reciprocal Rank: 0.1160
  Number of citations: 5888
  Number of unique targets: 2498

Top-k accuracy:
  k=1: 0.0552
  k=5: 0.1632
  k=10: 0.2351
  k=50: 0.4484
  k=100: 0.6223
  k=1000: 0.9614

Saved new best model to ./experiments/best_citation_model.pt
Best validation metrics so far:
  Loss: 6.5471
  Accuracy (top-1): 0.0552
   MRR: 0.1160

Epoch 2/15
Current logit scale: 2.7514
Collating training data with new random masks...


  0%|▎                                                                                                                  | 664/237381 [00:01<07:52, 500.59it/s]
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [01:11<00:00,  3.39s/it, loss=1.7]



Average training loss: 4.7028

Running validation...


Computing embeddings: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:12<00:00,  4.11s/it]
Computing similarities: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 33.88it/s]



Validation metrics:
  Loss: 5.8334
  Accuracy (top-1): 0.0727
  Mean Reciprocal Rank: 0.1456
  Number of citations: 5903
  Number of unique targets: 2455

Top-k accuracy:
  k=1: 0.0727
  k=5: 0.2092
  k=10: 0.2909
  k=50: 0.5296
  k=100: 0.7484
  k=1000: 0.9787

Saved new best model to ./experiments/best_citation_model.pt
Best validation metrics so far:
  Loss: 5.8334
  Accuracy (top-1): 0.0727
   MRR: 0.1456

Epoch 3/15
Current logit scale: 3.9575
Collating training data with new random masks...


  0%|▎                                                                                                                  | 664/237381 [00:02<13:37, 289.70it/s]
Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [01:11<00:00,  3.39s/it, loss=1.22]



Average training loss: 3.8944

Running validation...


Computing embeddings: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:12<00:00,  4.11s/it]
Computing similarities: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 35.80it/s]



Validation metrics:
  Loss: 5.2661
  Accuracy (top-1): 0.0901
  Mean Reciprocal Rank: 0.1672
  Number of citations: 5958
  Number of unique targets: 2470

Top-k accuracy:
  k=1: 0.0901
  k=5: 0.2333
  k=10: 0.3145
  k=50: 0.5606
  k=100: 0.8157
  k=1000: 0.9844

Saved new best model to ./experiments/best_citation_model.pt
Best validation metrics so far:
  Loss: 5.2661
  Accuracy (top-1): 0.0901
   MRR: 0.1672

Epoch 4/15
Current logit scale: 5.7838
Collating training data with new random masks...


  0%|▎                                                                                                                  | 664/237381 [00:01<08:04, 488.81it/s]
Training: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [01:10<00:00,  3.37s/it, loss=1.15]



Average training loss: 3.0907

Running validation...


Computing embeddings: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:12<00:00,  4.12s/it]
Computing similarities: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 35.49it/s]



Validation metrics:
  Loss: 4.7646
  Accuracy (top-1): 0.1147
  Mean Reciprocal Rank: 0.2019
  Number of citations: 5928
  Number of unique targets: 2487

Top-k accuracy:
  k=1: 0.1147
  k=5: 0.2807
  k=10: 0.3731
  k=50: 0.6253
  k=100: 0.7952
  k=1000: 0.9921

Saved new best model to ./experiments/best_citation_model.pt
Best validation metrics so far:
  Loss: 4.7646
  Accuracy (top-1): 0.1147
   MRR: 0.2019

Epoch 5/15
Current logit scale: 8.2439
Collating training data with new random masks...


  0%|▎                                                                                                                  | 664/237381 [00:01<08:05, 487.95it/s]
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [01:10<00:00,  3.38s/it, loss=0.821]



Average training loss: 2.4251

Running validation...


Computing embeddings: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:12<00:00,  4.16s/it]
Computing similarities: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 38.89it/s]



Validation metrics:
  Loss: 4.5541
  Accuracy (top-1): 0.1305
  Mean Reciprocal Rank: 0.2211
  Number of citations: 5907
  Number of unique targets: 2470

Top-k accuracy:
  k=1: 0.1305
  k=5: 0.3066
  k=10: 0.3990
  k=50: 0.6326
  k=100: 0.8354
  k=1000: 0.9927

Saved new best model to ./experiments/best_citation_model.pt
Best validation metrics so far:
  Loss: 4.5541
  Accuracy (top-1): 0.1305
   MRR: 0.2211

Epoch 6/15
Current logit scale: 10.9001
Collating training data with new random masks...


  0%|▎                                                                                                                  | 664/237381 [00:01<07:51, 502.10it/s]
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [01:10<00:00,  3.37s/it, loss=0.893]



Average training loss: 1.9685

Running validation...


Computing embeddings: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:12<00:00,  4.02s/it]
Computing similarities: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 34.19it/s]



Validation metrics:
  Loss: 4.3011
  Accuracy (top-1): 0.1448
  Mean Reciprocal Rank: 0.2430
  Number of citations: 5884
  Number of unique targets: 2493

Top-k accuracy:
  k=1: 0.1448
  k=5: 0.3351
  k=10: 0.4368
  k=50: 0.6835
  k=100: 0.8666
  k=1000: 0.9932

Saved new best model to ./experiments/best_citation_model.pt
Best validation metrics so far:
  Loss: 4.3011
  Accuracy (top-1): 0.1448
   MRR: 0.2430

Epoch 7/15
Current logit scale: 13.1487
Collating training data with new random masks...


  0%|▎                                                                                                                  | 664/237381 [00:01<08:06, 486.11it/s]
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 21/21 [01:10<00:00,  3.37s/it, loss=0.881]



Average training loss: 1.6358

Running validation...


Computing embeddings: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:12<00:00,  4.04s/it]
Computing similarities: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:00<00:00, 38.58it/s]



Validation metrics:
  Loss: 4.4709
  Accuracy (top-1): 0.1571
  Mean Reciprocal Rank: 0.2517
  Number of citations: 5939
  Number of unique targets: 2472

Top-k accuracy:
  k=1: 0.1571
  k=5: 0.3474
  k=10: 0.4395
  k=50: 0.6488
  k=100: 0.7786
  k=1000: 0.9929

Epoch 8/15
Current logit scale: 15.1575
Collating training data with new random masks...


  0%|▎                                                                                                                  | 664/237381 [00:01<08:24, 469.48it/s]
Training:  14%|██████████████▍                                                                                      | 3/21 [00:11<01:06,  3.68s/it, loss=1.48]

In [153]:
# Debug cell: print the first dataset sample next to the raw article text so
# the decoded source/target tokens can be compared with the originals by eye.
# NOTE(review): `dataset`, `sources` and `preprocessor` must already exist as
# kernel globals; confirm this cell still runs after Restart & Run All.
sample = dataset[0]
source_art_id = sample['source_art_id']
# Article ids are looked up at position id-1 (presumably ids are 1-based while
# `sources` is a 0-based sequence; verify against the preprocessor).
original_source = sources[source_art_id-1]
source_text = tokenizer.decode(sample['source_ids'], )
cited_art_ids = sample['cited_art_ids']

# Twice the number of ']' characters divided by the decoded text length.
useless_chars = np.sum([c==']' for c in source_text])*2/len(source_text)
print('useless = ', useless_chars)
print(f"Source original: {source_art_id}:\n{original_source[:1000]}\n\n")
print('#'*50)
print(f"Source tokens decoded:\n{source_text[:]}\n\n")
# Number of positions whose attention mask is 0 (presumably padding; confirm).
print(f"Source attention mask: {(sample['attention_mask']==0).sum()}, ")

# Show each target article: first 200 raw characters, then its decoded tokens.
for i, target_art_id in enumerate(sample['target_art_ids']):
    # Looked up but not used anywhere else in this cell.
    target_art_ref = preprocessor.id2ref[target_art_id.item()]
    target_original = sources[target_art_id-1]
    target_text = tokenizer.decode(sample['target_ids'][i], )
    print(f"Target: id={target_art_id}:\n{target_original[:200]}...\n\n")
    print(f"Target tokens:\n{target_text[:]}\n\n")
target_art_ids = sample['target_art_ids']


useless =  0.1
Source original: 1:
'''April''' (Apr.) is the fourth [[month]] of the [[year]] in the [[Julian calendar|Julian]] and [[Gregorian calendar]]s, and comes between [[March]] and [[May]]. It is one of four months to have 30 [[day]]s.
April always begins on the same day of the week as [[July]], and additionally, [[January]] in leap years. April always ends on the same day of the week as [[December]].
== The Month ==
April comes between [[March]] and [[May]], making it the fourth month of the year. It also comes first in the year out of the four months that have 30 days, as [[June]], [[September]] and [[November]] are later in the year.
April begins on the same day of the week as [[July]] every year and on the same day of the week as [[January]] in [[leap year]]s. April ends on the same day of the week as [[December]] every year, as each other's last days are exactly 35 weeks (245 days) apart.
In [[common year]]s, April starts on the same day of the week as [[October]] of the p