In [2]:
import json
import logging
import torch
import os
from utils import *
from transformers import AutoTokenizer
from typing import List, Dict, Tuple, Iterator, Union, Optional
import tqdm
import hashlib 

class WikiProcessor:
    """Prepares citation data for model training."""

    def __init__(self, jsonl_path: str = "data/wiki_articles.jsonl"):
        
        # Load articles
        logging.info("Loading articles from JSONL file...")
        self.articles_dict = {}
        self.id2ref = {}
        self.ref2id = {}
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                article = json.loads(line)
                ref = article['title'].lower()
                id = len(self.articles_dict) + 1
                self.articles_dict[ref] = self.clean_wiki_text(article['text'])
                self.ref2id[ref] = id 
                self.id2ref[id] = ref
        logging.info(f"Loaded {len(self.articles_dict)} articles.")

    def _find_citations(self,text):
        citations = []
        for match in re.finditer(r'\[\[(.*?)\]\]', text):
            match_text = match.group(1)
            citation = match_text.split('|') if '|' in match_text else [match_text]
            citation = [(c.split('#')[0] if '#' in c else c) for c in citation]
            ref = None
            for cit in citation:
                if cit.lower() in self.articles_dict:
                    ref = cit.lower()
                    break
            if ref:
                citations.append((match.start(), match.end(), self.ref2id[ref]))
        return citations

    @staticmethod
    def clean_wiki_text(text: str) -> str:
        """Cleans wiki content by removing metadata and formatting."""
        # Find main content starting from first bold title
        match = re.search(r"'''([^']+?)'''", text)
        if match:
            text = text[match.start():]

        # Remove wiki elements and clean up
        text = re.sub(r'\[\[Category:.?\]\]|\[\[File:.?\]\]|\{\{stub\}\}', '', text)
        return '\n'.join(line for line in text.split('\n') if line.strip())

    def find_source_citations(self) -> Tuple[List[str], List[Tuple[List[str], int, int]]]:
        """Creates source-target pairs for citation matching."""

        articles = list(self.articles_dict.keys())
        sources = []
        citation_data = []

        for title in articles:
            text = self.articles_dict[title]
            source_text = self.clean_wiki_text(text)
            citations = self._find_citations(source_text)            
            sources.append(source_text)
            citation_data.append(citations)

        return sources, citation_data

def get_cache_path(sources, model_name: str, cache_dir: str) -> str:
    """Generate a unique cache path based on input data and model name."""
    # Create a hash of the sources and model name
    content_hash = hashlib.md5(str(sources).encode()).hexdigest()
    model_hash = hashlib.md5(model_name.encode()).hexdigest()[:8]
    return os.path.join(cache_dir, f"tokenized_{model_hash}_{content_hash}.pt")

def tokenize_sources(sources, citation_data, tokenizer, batch_size=1000, cache_dir="cache"):
    # Generate cache path
    cache_path = get_cache_path(sources, tokenizer.name_or_path, cache_dir)
    
    # Check if cached results exist
    if os.path.exists(cache_path):
        logging.info(f"Loading cached tokenized results from {cache_path}")
        return torch.load(cache_path, weights_only=False)
    
    logging.info("Tokenizing sources...")
    # Process in batches
    all_results = []
    for batch_start in tqdm.tqdm(range(0, len(sources), batch_size), total=len(sources)//batch_size):
        batch_end = min(batch_start + batch_size, len(sources))
        batch_sources = sources[batch_start:batch_end]
        batch_citations = citation_data[batch_start:batch_end]
        
        # Batch encode
        batch_encoded = tokenizer.batch_encode_plus(
            batch_sources,
            add_special_tokens=False,
            return_offsets_mapping=True,
            padding=False,
            return_tensors=None
        )
        
        # Process each item in the batch
        for idx in range(len(batch_sources)):
            offset_mapping = batch_encoded["offset_mapping"][idx]
            input_ids = batch_encoded["input_ids"][idx]
            
            # Create offset to index mapping
            off2i = {s:i for i, (s,_) in enumerate(offset_mapping)}
            off2i.update({e:i for i, (_,e) in enumerate(offset_mapping)})
            
            # Create citation tokens array
            mask_tokens = np.zeros(len(input_ids), dtype=int)
            cite_tokens = np.zeros(len(input_ids), dtype=int)
            
            # Fill in citations
            for i, j, art_id in batch_citations[idx]:
                s, e = off2i[i], off2i[j]
                cite_tokens[s] = art_id
                mask_tokens[s:e] = art_id
            
            # Store results
            all_results.append({
                'input_ids': np.array(input_ids),
                'cite_tokens': cite_tokens,
                'mask_tokens': mask_tokens,
                'attention_mask': batch_encoded["attention_mask"][idx] if "attention_mask" in batch_encoded else None
            })

    # Cache the results
    os.makedirs(cache_dir, exist_ok=True)
    torch.save(all_results, cache_path)
    logging.info(f"Cached tokenized results to {cache_path}")
    
    return all_results

def collate(results, tokenizer, config):
    tokenizer.add_special_tokens({
        'additional_special_tokens': [config.cite_token, config.ref_token]
    })

    source_len = 512
    summary_len = 100
    overlap = 0.5
    sample_size = 1

    cite_token = tokenizer.convert_tokens_to_ids(config.cite_token)
    ref_token = tokenizer.convert_tokens_to_ids(config.ref_token)

    all_ids = []
    all_citation = []
    for i in tqdm.trange(len(results)):
        result = results[i]
        for s in range(0, len(result['input_ids']), int((1-overlap)*source_len)):
            e = s + source_len
            input_ids = result['input_ids'][s:e]
            cite_tokens = result['cite_tokens'][s:e]
            mask_tokens = result['mask_tokens'][s:e]

            present_citations = cite_tokens[cite_tokens > 0]
            N = min(sample_size, len(present_citations))
            sample_citations = np.random.choice(present_citations, N, replace=False)
            
            cite_tokens = np.where(np.isin(cite_tokens, sample_citations), cite_tokens, 0)
            mask_tokens = np.where(np.isin(mask_tokens, sample_citations), mask_tokens, 0)
            mask_tokens[cite_tokens > 0] = 0

            input_ids[cite_tokens > 0] = cite_token
            input_ids = input_ids[mask_tokens == 0]
            cite_tokens = cite_tokens[mask_tokens == 0]
            all_ids.append(input_ids)
            all_citation.append(sample_citations)



logging.basicConfig(level=logging.INFO)

# Load articles
preprocessor = WikiProcessor()
sources, citation_data = preprocessor.find_source_citations()

config = ModelConfig()
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.add_special_tokens({
    'additional_special_tokens': [config.cite_token, config.ref_token]
})

# This will now use caching
results = tokenize_sources(sources, citation_data, tokenizer, cache_dir="cache")


# collate 
# collate(results[:], tokenizer, config)


  from .autonotebook import tqdm as notebook_tqdm
INFO:root:Loading articles from JSONL file...
INFO:root:Loaded 237381 articles.
INFO:root:Loading cached tokenized results from cache/tokenized_1caf5def_895012ad817559b15b42e1d366769a67.pt


In [52]:
def collate(results, tokenizer, config, max_targets=5, source_len = 512, target_len = 100, overlap = 0.5):
    cite_token = tokenizer.convert_tokens_to_ids(config.cite_token)
    ref_token = tokenizer.convert_tokens_to_ids(config.ref_token)
    pad_token = tokenizer.pad_token_id

    collated_data = []
    id_to_tokenized = {i: result for i, result in enumerate(results)}
    
    for i in tqdm.trange(len(results)):
        result = results[i]

        if len(collated_data) > 1000:
            break
        
        # Process each source segment
        for s in range(0, len(result['input_ids']), int((1-overlap)*source_len)):
            e = s + source_len
            
            # Get source segment
            input_ids = result['input_ids'][s:e].copy()
            cite_tokens = result['cite_tokens'][s:e]
            mask_tokens = result['mask_tokens'][s:e]
            
            # Skip if segment is too short
            if len(input_ids) < source_len // 2:
                continue
                
            # Get all citations from this segment
            present_citations = np.unique(cite_tokens[cite_tokens > 0])
            if len(present_citations) == 0:
                continue
            
            # Initialize target arrays
            target_ids = np.full((max_targets, target_len), pad_token, dtype=np.int64)
            target_attention_mask = np.zeros((max_targets, target_len), dtype=np.int64)
            citation_ids = np.zeros(max_targets, dtype=np.int64)
            
            # Sample citations if we have more than max_targets
            if len(present_citations) > max_targets:
                present_citations = np.random.choice(present_citations, max_targets, replace=False)
            
            # Prepare source
            cite_tokens_mask = np.isin(cite_tokens, present_citations)
            mask_tokens = np.where(np.isin(mask_tokens, present_citations), 0, mask_tokens)
            mask_tokens[cite_tokens_mask > 0] = 0
            input_ids[cite_tokens_mask] = cite_token
            source_ids = input_ids[mask_tokens == 0]
            cite_tokens = cite_tokens[cite_tokens_mask]
            
            # Pad or truncate source
            if len(source_ids) > source_len:
                source_ids = source_ids[:source_len]
            elif len(source_ids) < source_len:
                source_ids = np.pad(source_ids, 
                                  (0, source_len - len(source_ids)),
                                  'constant', 
                                  constant_values=pad_token)
            
            # Create source attention mask
            attention_mask = (source_ids != pad_token).astype(np.int64)
            
            # Process each target
            for idx, citation_id in enumerate(present_citations):
                # Get pre-tokenized target content
                target_data = id_to_tokenized[citation_id - 1]
                target_tokens = target_data['input_ids']
                
                # Truncate if needed and add ref_token
                if len(target_tokens) >= target_len - 1:
                    target_tokens = target_tokens[:target_len-1]
                target_tokens = np.append(target_tokens, ref_token)
                
                # Pad to target_len
                if len(target_tokens) < target_len:
                    target_tokens = np.pad(target_tokens,
                                         (0, target_len - len(target_tokens)),
                                         'constant',
                                         constant_values=pad_token)
                
                # Store in target arrays
                target_ids[idx] = target_tokens
                target_attention_mask[idx] = (target_tokens != pad_token)
                citation_ids[idx] = citation_id
            
            # Store the collected data
            collated_data.append({
                'source_ids': torch.tensor(source_ids, dtype=torch.long),
                'cite_tokens': torch.tensor(cite_tokens, dtype=torch.long),
                'target_ids': torch.tensor(target_ids, dtype=torch.long),
                'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
                'target_attention_mask': torch.tensor(target_attention_mask, dtype=torch.long),
                'target_count': len(present_citations),
                'citation_ids': torch.tensor(citation_ids, dtype=torch.long)
            })
    
    return collated_data

class CitationDataset(torch.utils.data.Dataset):
    """Dataset for citation data with stacked targets."""
    
    def __init__(self, collated_data):
        self.data = collated_data
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

def citation_collate_fn(batch):
    """
    Collate function for DataLoader that properly handles the stacked format.
    
    Args:
        batch: List of dictionaries from CitationDataset
    
    Returns:
        Dictionary with batched tensors where:
        - source_ids: [batch_size, source_len]
        - target_ids: [batch_size * max_targets, target_len]
        - attention_mask: [batch_size, source_len]
        - target_attention_mask: [batch_size * max_targets, target_len]
        - target_counts: [batch_size]
        - citation_ids: [batch_size * max_targets]
    """
    # Stack sources normally
    source_ids = torch.stack([item['source_ids'] for item in batch])
    cite_tokens = torch.cat([item['cite_tokens'] for item in batch])
    attention_mask = torch.stack([item['attention_mask'] for item in batch])
    
    # Concatenate targets
    target_ids = torch.cat([item['target_ids'][:item['target_count']] for item in batch])
    target_attention_mask = torch.cat([item['target_attention_mask'][:item['target_count']] for item in batch])
    citation_ids = torch.cat([item['citation_ids'][:item['target_count']] for item in batch])
    target_counts = torch.tensor([item['target_count'] for item in batch])
    
    return {
        'source_ids': source_ids,
        'cite_tokens': cite_tokens,
        'target_ids': target_ids,
        'attention_mask': attention_mask,
        'target_attention_mask': target_attention_mask,
        'target_counts': target_counts,
        'citation_ids': citation_ids
    }

# Usage example:
# Collate the data
collated_data = collate(results, tokenizer, config, max_targets=5)

# Create dataset and dataloader
dataset = CitationDataset(collated_data)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=citation_collate_fn
)

# Example of resulting tensor shapes for a batch
for batch in dataloader:
    print("Source shape:", batch['source_ids'].shape)  # [batch_size, source_len]
    print("Target shape:", batch['target_ids'].shape)  # [total_targets, target_len]
    print("Target counts:", batch['target_counts'])    # [batch_size]
    break

  0%|          | 0/237381 [00:00<?, ?it/s]




IndexError: boolean index did not match indexed array along dimension 0; dimension is 407 but corresponding boolean dimension is 7

In [48]:
cite_token = tokenizer.convert_tokens_to_ids(config.cite_token)
ref_token = tokenizer.convert_tokens_to_ids(config.ref_token)
(batch['source_ids']==cite_token).sum()

tensor(83)

In [51]:
len(batch['cite_tokens'])

83

In [50]:
len(batch['source_ids'][1])

512