In [1]:
# Standard library imports
import json
import logging
import os
import re
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Dict, Iterator, List, Optional, Tuple, Union
import xml.etree.ElementTree as ET

# Third-party imports
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    PreTrainedModel,
    PretrainedConfig,
    Trainer,
    TrainingArguments
)
import tqdm 
import yaml

class WikiProcessor:
    """Prepares citation data for model training."""

    def __init__(self, jsonl_path: str = "data/wiki_articles.jsonl"):
        
        # Load articles
        logging.info("Loading articles from JSONL file...")
        self.articles_dict = {}
        self.id2ref = {}
        self.ref2id = {}
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                article = json.loads(line)
                ref = article['title'].lower()
                id = len(self.articles_dict) + 1
                self.articles_dict[ref] = self.clean_wiki_text(article['text'])
                self.ref2id[ref] = id 
                self.id2ref[id] = ref
        logging.info(f"Loaded {len(self.articles_dict)} articles.")

    def _find_citations(self,text):
        citations = []
        for match in re.finditer(r'\[\[(.*?)\]\]', text):
            match_text = match.group(1)
            citation = match_text.split('|') if '|' in match_text else [match_text]
            citation = [(c.split('#')[0] if '#' in c else c) for c in citation]
            ref = None
            for cit in citation:
                if cit.lower() in self.articles_dict:
                    ref = cit.lower()
                    break
            if ref:
                citations.append((match.start(), match.end(), self.ref2id[ref]))
        return citations

    @staticmethod
    def clean_wiki_text(text: str) -> str:
        """Cleans wiki content by removing metadata and formatting."""
        # Find main content starting from first bold title
        match = re.search(r"'''([^']+?)'''", text)
        if match:
            text = text[match.start():]

        # Remove wiki elements and clean up
        text = re.sub(r'\[\[File:.*\]\]|\[\[Category:.*\]\]|\{\{stub.*\}\}', '', text)
        return '\n'.join(line for line in text.split('\n') if line.strip())

    def find_source_citations(self) -> Tuple[List[str], List[Tuple[List[str], int, int]]]:
        """Creates source-target pairs for citation matching."""

        articles = list(self.articles_dict.keys())
        sources = []
        citation_data = []

        for title in articles:
            text = self.articles_dict[title]
            source_text = self.clean_wiki_text(text)
            citations = self._find_citations(source_text)            
            sources.append(source_text)
            citation_data.append(citations)

        return sources, citation_data

def get_cache_path(sources, model_name: str, cache_dir: str) -> str:
    """Generate a unique cache path based on input data and model name."""
    # Create a hash of the sources and model name
    content_hash = hashlib.md5(str(sources).encode()).hexdigest()
    model_hash = hashlib.md5(model_name.encode()).hexdigest()[:8]
    return os.path.join(cache_dir, f"tokenized_{model_hash}_{content_hash}.pt")

def tokenize_sources(sources=None, citation_data=None, tokenizer=None, batch_size=1000, cache_dir="cache", cache_path=None):
    # Generate cache path
    if cache_path is None:
        cache_path = get_cache_path(sources, tokenizer.name_or_path, cache_dir)
    
    # Check if cached results exist
    if os.path.exists(cache_path):
        logging.info(f"Loading cached tokenized results from {cache_path}")
        return torch.load(cache_path, weights_only=False)
    
    logging.info("Tokenizing sources...")
    # Process in batches
    all_results = []
    for batch_start in tqdm.tqdm(range(0, len(sources), batch_size), total=len(sources)//batch_size):
        batch_end = min(batch_start + batch_size, len(sources))
        batch_sources = sources[batch_start:batch_end]
        batch_citations = citation_data[batch_start:batch_end]
        
        # Batch encode
        batch_encoded = tokenizer.batch_encode_plus(
            batch_sources,
            add_special_tokens=False,
            return_offsets_mapping=True,
            padding=False,
            return_tensors=None
        )
        
        # Process each item in the batch
        for idx in range(len(batch_sources)):
            offset_mapping = batch_encoded["offset_mapping"][idx]
            input_ids = batch_encoded["input_ids"][idx]
            
            # Create offset to index mapping
            off2i = {s:i for i, (s,_) in enumerate(offset_mapping)}
            off2i.update({e:i+1 for i, (_,e) in enumerate(offset_mapping)})
            
            # Create citation tokens array
            mask_tokens = np.zeros(len(input_ids), dtype=int)
            cite_tokens = np.zeros(len(input_ids), dtype=int)
            
            # Fill in citations
            for i, j, art_id in batch_citations[idx]:
                s, e = off2i[i], off2i[j]
                cite_tokens[s] = art_id
                mask_tokens[s:e] = art_id
            
            # Store results
            all_results.append({
                'input_ids': np.array(input_ids),
                'cite_tokens': cite_tokens,
                'mask_tokens': mask_tokens,
                'attention_mask': batch_encoded["attention_mask"][idx] if "attention_mask" in batch_encoded else None
            })

    # Cache the results
    os.makedirs(cache_dir, exist_ok=True)
    torch.save(all_results, cache_path)
    logging.info(f"Cached tokenized results to {cache_path}")
    
    return all_results

def collate(results, tokenizer, config):
    cite_token = tokenizer.convert_tokens_to_ids(config.cite_token)
    ref_token = tokenizer.convert_tokens_to_ids(config.ref_token)
    bracket_tokens = tokenizer.convert_tokens_to_ids(['[',']'])
    pad_token = tokenizer.pad_token_id

    collated_data = []
    # id_to_tokenized = {i: result for i, result in enumerate(results)}
    
    for i in tqdm.tqdm(range(len(results))):
        result = results[i]
        if config.collate_sample_size and len(collated_data)>config.collate_sample_size:
            break
        
        # Process each source segment
        for s in range(0, len(result['input_ids']), int((1-config.overlap)*config.source_len)):
            e = s + config.source_len
            
            # Get source segment
            input_ids = result['input_ids'][s:e].copy()
            cite_tokens = result['cite_tokens'][s:e]
            mask_tokens = result['mask_tokens'][s:e]
            
            # Skip if segment is too short
            if len(input_ids) < config.source_len // 2:
                continue
                
            # Get all citations from this segment
            present_citations = np.unique(cite_tokens[cite_tokens > 0])
            if len(present_citations) > config.max_targets:
                present_citations = np.random.choice(present_citations, config.max_targets, replace=False)
            max_targets = min(config.max_targets, len(present_citations))

            # Skip if segment is too short
            if len(input_ids) < config.source_len // 2:
                continue
            # Skip if no citations
            if max_targets == 0:
                continue
            
            # Initialize target arrays
            target_ids = np.full((max_targets, config.target_len), pad_token, dtype=np.int64)
            target_attention_mask = np.zeros((max_targets, config.target_len), dtype=np.int64)
            
            
            # Prepare source: 
            # only keep citation tokens that are sampled to be masked 
            cite_tokens_mask = np.isin(cite_tokens, present_citations)
            # don't mask citations that are not sampled 
            mask_tokens = np.where(np.isin(mask_tokens, present_citations), mask_tokens, 0)
            # remove brackets from the rest of the text 
            mask_tokens = np.where(np.isin(input_ids,bracket_tokens),1, mask_tokens)
            # don't mask the citation tokens 
            mask_tokens[cite_tokens_mask] = 0
            # set the citation tokens (first token of a citation range) as special token <CITE> 
            input_ids[cite_tokens_mask] = cite_token
            # mask all tokens in a citation, except for the first (special) token 
            source_ids = input_ids[mask_tokens == 0]

            # keep the cited article ids in the text in the order they appear (with repeats)
            # & keep the unique cited artile ids 
            # this will enable us to link each special cite token to a target via the article id
            target_art_ids = present_citations
            cited_art_ids = cite_tokens[cite_tokens_mask]
            
            # Pad or truncate source
            if len(source_ids) > config.source_len:
                source_ids = source_ids[:config.source_len]
            elif len(source_ids) < config.source_len:
                source_ids = np.pad(source_ids, 
                                  (0, config.source_len - len(source_ids)),
                                  'constant', 
                                  constant_values=pad_token)
            
            # Create source attention mask
            attention_mask = (source_ids != pad_token).astype(np.int64)
            
            # Process each target
            for idx, citation_id in enumerate(present_citations):
                # Get pre-tokenized target content
                # ids are 1-indexed 
                target_data = results[citation_id - 1]
                target_tokens = target_data['input_ids']
                
                # Truncate if needed and add ref_token
                if len(target_tokens) >= config.target_len - 1:
                    target_tokens = target_tokens[:config.target_len-1]
                target_tokens = np.append(target_tokens, ref_token)
                
                # Pad to target_len
                if len(target_tokens) < config.target_len:
                    target_tokens = np.pad(target_tokens,
                                         (0, config.target_len - len(target_tokens)),
                                         'constant',
                                         constant_values=pad_token)
                
                # Store in target arrays
                target_ids[idx] = target_tokens
                target_attention_mask[idx] = (target_tokens != pad_token)
                # citation_ids[idx] = citation_id


            # Store the collected data
            collated_data.append({
                'source_art_id': i+1,
                'source_ids': torch.tensor(source_ids, dtype=torch.long),
                'cited_art_ids': torch.tensor(cited_art_ids, dtype=torch.long),
                'target_art_ids': torch.tensor(target_art_ids, dtype=torch.long),
                'target_ids': torch.tensor(target_ids, dtype=torch.long),
                'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
                'target_attention_mask': torch.tensor(target_attention_mask, dtype=torch.long),
            })
    
    return collated_data

class CitationDataset(torch.utils.data.Dataset):
    """Dataset for citation data with stacked targets."""
    
    def __init__(self, collated_data):
        self.data = collated_data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

def citation_collate_fn(batch):
    # Stack sources normally
    source_ids = torch.stack([item['source_ids'] for item in batch])
    cited_art_ids = torch.cat([item['cited_art_ids'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    
    # Concatenate targets
    target_art_ids_all = torch.cat([item['target_art_ids'] for item in batch])
    target_ids = torch.cat([item['target_ids'] for item in batch])
    target_attention_mask = torch.cat([item['target_attention_mask'] for item in batch])

    # Get unique indices and inverse indices
    target_art_ids, unique_indices = np.unique(target_art_ids_all.numpy(), return_index=True)
    target_art_ids = torch.tensor(target_art_ids)
    unique_indices = torch.tensor(unique_indices)
    
    # Use unique indices to get corresponding targets
    target_ids = target_ids[unique_indices]
    target_attention_mask = target_attention_mask[unique_indices]

    id2i = {id.item():i for i,id in enumerate(target_art_ids)}
    labels = torch.tensor([id2i[id.item()] for id in cited_art_ids],dtype=torch.long)

      
    return {
        'source_ids': source_ids,
        'cited_art_ids': cited_art_ids,
        'target_art_ids': target_art_ids,
        'target_ids': target_ids,
        'attention_mask': attention_mask,
        'target_attention_mask': target_attention_mask,
        'labels': labels,
    }


@dataclass
class ExperimentConfig:
    """Configuration   the citation matching model."""
    model_name: str = "bert-base-uncased"
    max_length: int = 512
    source_len: int = 512
    target_len: int = 100
    max_targets: int = 5
    overlap: float = 0.5
    cite_token: str = "<CITE>"
    ref_token: str = "<REF>"
    temperature: float = 0.07
    collate_sample_size: int = None,
    device: Optional[torch.device] = None

    def __post_init__(self):
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")




class CitationConfig(PretrainedConfig):
    """Configuration class for CitationModel."""
    model_type = "citation"
    
    def __init__(
        self,
        base_model_name="bert-base-uncased",
        vocab_size=30522,
        cite_token_id=None,
        ref_token_id=None,
        temperature=0.07,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.base_model_name = base_model_name
        self.vocab_size = vocab_size
        self.cite_token_id = cite_token_id
        self.ref_token_id = ref_token_id
        self.temperature = temperature

@dataclass
class CitationModelOutput:
    """Custom output class for the citation model."""
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    cite_embeds: Optional[torch.FloatTensor] = None
    ref_embeds: Optional[torch.FloatTensor] = None

class CitationModel(nn.Module):
    """Custom model for citation matching using transformer embeddings."""
    
    def __init__(self, config: CitationConfig):
        super().__init__()
        
        # Load base model configuration
        base_config = AutoConfig.from_pretrained(config.base_model_name)
        
        # Store configuration
        self.config = config
        
        # Load base transformer model
        self.transformer = AutoModel.from_pretrained(config.base_model_name)
        
        # Resize token embeddings if needed
        if config.vocab_size != self.transformer.config.vocab_size:
            self.transformer.resize_token_embeddings(config.vocab_size)
    
    def get_citation_masks(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Create mask for citation token positions."""
        return input_ids == self.config.cite_token_id
    
    def get_reference_masks(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Create mask for reference token positions."""
        return input_ids == self.config.ref_token_id
    
    def forward(
        self,
        source_ids: torch.Tensor,
        target_ids: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        target_attention_mask: Optional[torch.Tensor] = None,
        cited_art_ids: Optional[torch.Tensor] = None,
        target_art_ids: Optional[torch.Tensor] = None,
        return_dict: bool = True,
    ) -> Union[Tuple, CitationModelOutput]:
        """Forward pass of the model."""
        
        # Process source text
        source_outputs = self.transformer(
            input_ids=source_ids,
            attention_mask=attention_mask,
            return_dict=True
        )
        
        # Process target text
        target_outputs = self.transformer(
            input_ids=target_ids,
            attention_mask=target_attention_mask,
            return_dict=True
        )
        
        # Get citation mask and extract citation embeddings
        cite_mask = self.get_citation_masks(source_ids)
        cite_embeds = source_outputs.last_hidden_state[cite_mask]
        
        # Get reference mask and extract reference embeddings
        ref_mask = self.get_reference_masks(target_ids)
        ref_embeds = target_outputs.last_hidden_state[ref_mask]
        
        # Normalize embeddings
        cite_embeds = F.normalize(cite_embeds, p=2, dim=-1)
        ref_embeds = F.normalize(ref_embeds, p=2, dim=-1)
        
        # Compute similarity scores
        logits = torch.matmul(cite_embeds, ref_embeds.t()) / self.config.temperature

        # compute the loss 
        loss = F.cross_entropy(logits, labels)
        
        if return_dict:
            return CitationModelOutput(
                loss=loss,
                logits=logits,
                cite_embeds=cite_embeds,
                ref_embeds=ref_embeds
            )
        
        return (loss, logits, cite_embeds, ref_embeds)


def train_citation_model(
    model,
    results,
    tokenizer,
    config,
    train_ratio: float = 0.8,
    num_epochs: int = 5,
    learning_rate: float = 1.5e-4,
    weight_decay: float = 0.01,
    warmup_steps: int = 0,
    device: str = None,
    save_path: str = "citation_model.pt",
    batch_size: int = 128,
    temperatures = []
):
    # Set device
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    
    # Initialize gradient scaler for mixed precision training
    scaler = GradScaler()
    
    # Move model to device
    model = model.to(device)
    
    # Enable memory efficient training
    model.transformer.gradient_checkpointing_enable()
    
    # Initialize optimizer
    optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    
    # Training loop
    best_val_loss = float('inf')
    
    for epoch in range(num_epochs):
        if epoch < len(temperatures):
            model.config.temperature = temperatures[epoch]
            print(f"temperature changed to {temperatures[epoch]}")
        print(f"\nEpoch {epoch + 1}/{num_epochs}")
        
        # Create new collated data for this epoch
        print("Collating training data with new random masks...")
        collated = collate(results, tokenizer, config)
        dataset = CitationDataset(collated)
        train_size = int(len(dataset) * 0.8)
        train_dataset = dataset[:train_size]
        val_dataset = dataset[train_size:]
        
        train_dataloader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=citation_collate_fn
        )

        val_dataloader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            collate_fn=citation_collate_fn
        )
        
        # Training phase
        model.train()
        total_train_loss = 0
        train_steps = 0
        
        progress_bar = tqdm.tqdm(train_dataloader, desc="Training")
        
        for batch in progress_bar:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}
            
            # Clear gradients
            optimizer.zero_grad()
            
            # Forward pass with mixed precision
            with torch.amp.autocast('cuda'):
                outputs = model(**batch)
                loss = outputs.loss
            
            # Backward pass with gradient scaling
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            # Update tracking variables
            total_train_loss += loss.item()
            train_steps += 1
            
            # Update progress bar
            progress_bar.set_postfix({'loss': loss.item()})
        
        avg_train_loss = total_train_loss / train_steps
        print(f"\nAverage training loss: {avg_train_loss:.4f}")
        
        # Validation phase
        print("Running validation...")
        model.eval()
        total_val_loss = 0
        val_steps = 0
        
        progress_bar = tqdm.tqdm(val_dataloader, desc="Validation")
        
        with torch.no_grad():
            for batch in progress_bar:
                # Move batch to device
                batch = {k: v.to(device) for k, v in batch.items()}
                
                # Forward pass with mixed precision
                with torch.amp.autocast('cuda'):
                    outputs = model(**batch)
                    loss = outputs.loss
                
                total_val_loss += loss.item()
                val_steps += 1
                
                # Update progress bar
                progress_bar.set_postfix({'loss': loss.item()})
        
        avg_val_loss = total_val_loss / val_steps
        print(f"\nValidation loss: {avg_val_loss:.4f}")
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'loss': best_val_loss,
            }, save_path)
            print(f"Saved new best model to {save_path}")
    
    return model

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

logging.basicConfig(level=logging.INFO)

# # Load articles
# preprocessor = WikiProcessor()
# sources, citation_data = preprocessor.find_source_citations()

config = ExperimentConfig(collate_sample_size=2000,target_len=128)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.add_special_tokens({
    'additional_special_tokens': [config.cite_token, config.ref_token]
})


# results = tokenize_sources(sources, citation_data, tokenizer, cache_dir="cache",)

# This will now use caching directly 
# results = tokenize_sources(cache_path='./cache/tokenized_1caf5def_eb27a5477eaa3d549aebc4886f3717d1.pt')


# Create model config
model_config = CitationConfig(
    base_model_name=config.model_name,
    vocab_size=len(tokenizer),
    cite_token_id=tokenizer.convert_tokens_to_ids(config.cite_token),
    ref_token_id=tokenizer.convert_tokens_to_ids(config.ref_token),
    temperature=config.temperature,
)

# Initialize model
model = CitationModel(model_config)



# Usage example:
# Replace the previous training code with:
trained_model = train_citation_model(
    model=model,
    results=results,  # Pass raw results instead of dataloader
    tokenizer=tokenizer,
    config=config,
    num_epochs=15,
    weight_decay=0.01,
    learning_rate=1.5e-4,
    batch_size=200,
    save_path="./experiments/best_citation_model.pt",
    temperatures = [1, 0.5, .2, .1, 0.07, 0.06, 0.05],
)


INFO:root:Loading cached tokenized results from ./cache/tokenized_1caf5def_eb27a5477eaa3d549aebc4886f3717d1.pt


Using device: cuda


  scaler = GradScaler()


temperature changed to 1

Epoch 1/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<06:53, 573.52it/s]
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:27<00:00,  3.09s/it, loss=3.79]



Average training loss: 6.0797
Running validation...


Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.13it/s, loss=2.93]



Validation loss: 4.8741
Saved new best model to ./experiments/best_citation_model.pt
temperature changed to 0.5

Epoch 2/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<06:52, 574.60it/s]
Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:28<00:00,  3.17s/it, loss=3.17]



Average training loss: 5.3684
Running validation...


Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.12it/s, loss=2.49]



Validation loss: 4.3607
Saved new best model to ./experiments/best_citation_model.pt
temperature changed to 0.2

Epoch 3/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<07:05, 556.91it/s]
Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:27<00:00,  3.08s/it, loss=2.13]



Average training loss: 4.0308
Running validation...


Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.11it/s, loss=1.99]



Validation loss: 3.4653
Saved new best model to ./experiments/best_citation_model.pt
temperature changed to 0.1

Epoch 4/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<07:12, 547.85it/s]
Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:27<00:00,  3.08s/it, loss=1.65]



Average training loss: 2.8987
Running validation...


Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.11it/s, loss=1.56]



Validation loss: 2.6954
Saved new best model to ./experiments/best_citation_model.pt
temperature changed to 0.07

Epoch 5/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<06:58, 566.35it/s]
Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:27<00:00,  3.10s/it, loss=1.25]



Average training loss: 2.2807
Running validation...


Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.11it/s, loss=1.15]



Validation loss: 2.4470
Saved new best model to ./experiments/best_citation_model.pt
temperature changed to 0.06

Epoch 6/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<07:02, 561.55it/s]
Training: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:27<00:00,  3.10s/it, loss=1.1]



Average training loss: 1.8894
Running validation...


Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.11it/s, loss=1.02]



Validation loss: 2.4007
Saved new best model to ./experiments/best_citation_model.pt
temperature changed to 0.05

Epoch 7/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<07:03, 560.10it/s]
Training: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:27<00:00,  3.09s/it, loss=0.95]



Average training loss: 1.6242
Running validation...


Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.11it/s, loss=0.771]



Validation loss: 2.2822
Saved new best model to ./experiments/best_citation_model.pt

Epoch 8/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<07:14, 546.29it/s]
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:27<00:00,  3.10s/it, loss=0.764]



Average training loss: 1.4499
Running validation...


Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.09it/s, loss=1.11]



Validation loss: 2.3854

Epoch 9/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<07:28, 528.52it/s]
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:28<00:00,  3.11s/it, loss=0.605]



Average training loss: 1.2798
Running validation...


Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.08it/s, loss=1.11]



Validation loss: 2.3354

Epoch 10/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<07:12, 548.35it/s]
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:27<00:00,  3.11s/it, loss=0.645]



Average training loss: 1.1712
Running validation...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.10it/s, loss=1.4]



Validation loss: 2.4682

Epoch 11/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<07:10, 550.51it/s]
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:27<00:00,  3.11s/it, loss=0.663]



Average training loss: 1.0717
Running validation...


Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.10it/s, loss=0.869]



Validation loss: 2.2377
Saved new best model to ./experiments/best_citation_model.pt

Epoch 12/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<07:10, 550.76it/s]
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:27<00:00,  3.11s/it, loss=0.457]



Average training loss: 0.9803
Running validation...


Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.11it/s, loss=1.03]



Validation loss: 2.2884

Epoch 13/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<07:08, 553.24it/s]
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:28<00:00,  3.11s/it, loss=0.336]



Average training loss: 0.8759
Running validation...


Validation: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.08it/s, loss=1.5]



Validation loss: 2.5096

Epoch 14/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<06:56, 569.22it/s]
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:27<00:00,  3.10s/it, loss=0.402]



Average training loss: 0.8347
Running validation...


Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.09it/s, loss=1.18]



Validation loss: 2.4250

Epoch 15/15
Collating training data with new random masks...


  0%|▏                                                                                                                  | 279/237381 [00:00<07:29, 527.97it/s]
Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [00:27<00:00,  3.11s/it, loss=0.611]



Average training loss: 0.8076
Running validation...


Validation: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:01<00:00,  2.11it/s, loss=0.631]



Validation loss: 2.1802
Saved new best model to ./experiments/best_citation_model.pt


In [14]:
collated = collate(results, tokenizer, config)
dataset = CitationDataset(collated)

dataloader = DataLoader(
    dataset,
    batch_size=200,
    shuffle=True,
    collate_fn=citation_collate_fn
)

for i,batch in enumerate(dataloader):
    if i > 2:
        break



  0%|▏                                                                                                                  | 279/237381 [00:00<07:03, 560.42it/s]


In [15]:
batch['labels'].max()

tensor(799)

In [184]:

# Create dataset and dataloader
dataset = CitationDataset(collated_data)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=citation_collate_fn
)

# Example of resulting tensor shapes for a batch
for batch in dataloader:
    print("Source shape:", batch['source_ids'].shape)  # [batch_size, source_len]
    print("Target shape:", batch['target_ids'].shape)  # [total_targets, target_len]
    print("Target counts:", batch['target_counts'])    # [batch_size]
    break

labels = [torch.where(batch['cited_art_ids']==id)[0][0].item() for id in batch['target_art_ids']]



Source shape: torch.Size([16, 512])
Target shape: torch.Size([80, 100])
Target counts: tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5])


(80, 80)

In [153]:
sample = dataset[0]
source_art_id = sample['source_art_id']
original_source = sources[source_art_id-1]
source_text = tokenizer.decode(sample['source_ids'], )
cited_art_ids = sample['cited_art_ids']

useless_chars = np.sum([c==']' for c in source_text])*2/len(source_text)
print('useless = ', useless_chars)
print(f"Source original: {source_art_id}:\n{original_source[:1000]}\n\n")
print('#'*50)
print(f"Source tokens decoded:\n{source_text[:]}\n\n")
print(f"Source attention mask: {(sample['attention_mask']==0).sum()}, ")

for i, target_art_id in enumerate(sample['target_art_ids']):
    target_art_ref = preprocessor.id2ref[target_art_id.item()]
    target_original = sources[target_art_id-1]
    target_text = tokenizer.decode(sample['target_ids'][i], )
    print(f"Target: id={target_art_id}:\n{target_original[:200]}...\n\n")
    print(f"Target tokens:\n{target_text[:]}\n\n")
target_art_ids = sample['target_art_ids']


useless =  0.1
Source original: 1:
'''April''' (Apr.) is the fourth [[month]] of the [[year]] in the [[Julian calendar|Julian]] and [[Gregorian calendar]]s, and comes between [[March]] and [[May]]. It is one of four months to have 30 [[day]]s.
April always begins on the same day of the week as [[July]], and additionally, [[January]] in leap years. April always ends on the same day of the week as [[December]].
== The Month ==
April comes between [[March]] and [[May]], making it the fourth month of the year. It also comes first in the year out of the four months that have 30 days, as [[June]], [[September]] and [[November]] are later in the year.
April begins on the same day of the week as [[July]] every year and on the same day of the week as [[January]] in [[leap year]]s. April ends on the same day of the week as [[December]] every year, as each other's last days are exactly 35 weeks (245 days) apart.
In [[common year]]s, April starts on the same day of the week as [[October]] of the p

In [24]:
unique, first_indexes = np.unique(np.array([1,2,1, 3, 3, 3]), return_index=True)
unique, first_indexes

(array([1, 2, 3]), array([0, 1, 3]))