In [66]:
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Iterator, Union, Optional
import xml.etree.ElementTree as ET
import json
import sqlite3
from pathlib import Path
import re
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModel
import tqdm 
import hashlib
import os
import logging
from datetime import datetime
import torch.cuda.amp  # For automatic mixed precision
import yaml



class WikiProcessor:
    """Prepares citation data for model training.

    Loads wiki articles from a JSONL file (one JSON object per line with
    ``title`` and ``text`` keys), cleans their markup, and gives every
    distinct lower-cased title a 1-based integer id. Ids follow the insertion
    order of ``articles_dict``, so the article with id ``k`` is element
    ``k - 1`` of the lists returned by :meth:`find_source_citations`.
    """

    # Matches a wiki link and captures its body (``[[body]]``).
    _LINK_RE = re.compile(r'\[\[(.*?)\]\]')
    # Start of an element that clean_wiki_text removes entirely.
    _ELEMENT_START_RE = re.compile(r"\[\[(?:File|Category):|\{\{stub")
    # Opening/closing tokens used to find where such an element ends,
    # keyed by the element's first character.
    _NESTING_TOKEN_RE = {
        "[": re.compile(r"\[\[|\]\]"),
        "{": re.compile(r"\{\{|\}\}"),
    }

    def __init__(self, jsonl_path: str = "data/wiki_articles.jsonl"):
        """Load and clean every article in *jsonl_path*.

        Sets ``articles_dict`` (lower-cased title -> cleaned text) and the
        ``ref2id`` / ``id2ref`` maps (lower-cased title <-> 1-based id).
        A repeated title keeps the id of its first occurrence; its text is
        replaced by the later record.
        """
        logging.info("Loading articles from JSONL file...")
        self.articles_dict = {}
        self.id2ref = {}
        self.ref2id = {}
        with open(jsonl_path, 'r', encoding='utf-8') as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue
                article = json.loads(line)
                ref = article['title'].lower()
                if ref not in self.ref2id:
                    # The id must equal the position of `ref` in articles_dict
                    # (+1). Assigning a fresh id to a duplicate title would make
                    # the next new title collide with it.
                    art_id = len(self.ref2id) + 1
                    self.ref2id[ref] = art_id
                    self.id2ref[art_id] = ref
                self.articles_dict[ref] = self.clean_wiki_text(article['text'])
        logging.info(f"Loaded {len(self.articles_dict)} articles.")

    def _find_citations(self, text):
        """Return ``(start, end, article_id)`` for each resolvable ``[[...]]`` link.

        The link body is split on ``|``. Each part has any ``#section`` suffix
        removed, and the first part naming a known article (case-insensitive)
        is used. Links naming no known article are skipped. ``start``/``end``
        are the character offsets of the whole link in *text*.
        """
        citations = []
        for match in self._LINK_RE.finditer(text):
            for candidate in match.group(1).split('|'):
                ref = candidate.split('#')[0].lower()
                if ref and ref in self.articles_dict:
                    citations.append((match.start(), match.end(), self.ref2id[ref]))
                    break
        return citations

    @staticmethod
    def _element_end(line: str, start: int) -> Optional[int]:
        """Return the index just past the token closing the element at *start*, or None."""
        depth = 0
        for token in WikiProcessor._NESTING_TOKEN_RE[line[start]].finditer(line, start):
            depth += 1 if token.group() in ("[[", "{{") else -1
            if depth == 0:
                return token.end()
        return None

    @staticmethod
    def _strip_elements(line: str) -> str:
        """Remove balanced File/Category/stub elements from a single line.

        Each element ends at its own matching closer, so links nested in an
        image caption are removed with it. Text after the element survives,
        which the old greedy regex did not guarantee, since it ran to the last
        closer on the line. An element left unclosed on the line is kept.
        """
        kept = []
        keep_from = search_from = 0
        while True:
            found = WikiProcessor._ELEMENT_START_RE.search(line, search_from)
            if found is None:
                break
            end = WikiProcessor._element_end(line, found.start())
            if end is None:
                search_from = found.start() + 1
                continue
            kept.append(line[keep_from:found.start()])
            keep_from = search_from = end
        kept.append(line[keep_from:])
        return ''.join(kept)

    @staticmethod
    def clean_wiki_text(text: str) -> str:
        """Cleans wiki content by removing metadata and formatting.

        Drops everything before the first bold (``'''...'''``) span. Removes
        ``[[File:...]]``, ``[[Category:...]]`` and ``{{stub...}}`` elements,
        including links nested inside them, then deletes blank lines.
        """
        # Find main content starting from first bold title
        match = re.search(r"'''([^']+?)'''", text)
        if match:
            text = text[match.start():]
        lines = (WikiProcessor._strip_elements(line) for line in text.split('\n'))
        return '\n'.join(line for line in lines if line.strip())

    def find_source_citations(self) -> Tuple[List[str], List[List[Tuple[int, int, int]]]]:
        """Creates source-target pairs for citation matching.

        Returns:
            ``(sources, citation_data)`` in ``articles_dict`` order, so the
            article with id ``k`` is at index ``k - 1``. ``sources`` holds the
            cleaned texts. ``citation_data`` holds the
            ``(start, end, article_id)`` links found in each one.
        """
        sources = []
        citation_data = []
        for text in self.articles_dict.values():
            # Texts were already cleaned in __init__ and cleaning is
            # idempotent, so cleaning again would only cost time.
            sources.append(text)
            citation_data.append(self._find_citations(text))
        return sources, citation_data

def get_cache_path(sources, model_name: str, cache_dir: str) -> str:
    """Generate a unique cache path based on input data and model name.

    The content hash is the MD5 of ``str(sources)``. For a plain list it is
    fed item by item (``'['``, the items' reprs joined by ``', '``, ``']'``).
    That yields exactly the same digest without building the whole string,
    which for a full article corpus is hundreds of MB. Other inputs fall back
    to ``str(sources)``.

    Returns:
        ``<cache_dir>/tokenized_<model hash>_<content hash>.pt``
    """
    digest = hashlib.md5()
    # Exact type check: a list subclass may override __repr__, in which case
    # only str() reproduces the historical digest.
    if type(sources) is list:
        digest.update(b'[')
        for idx, item in enumerate(sources):
            if idx:
                digest.update(b', ')
            digest.update(repr(item).encode())
        digest.update(b']')
    else:
        digest.update(str(sources).encode())
    content_hash = digest.hexdigest()
    model_hash = hashlib.md5(model_name.encode()).hexdigest()[:8]
    return os.path.join(cache_dir, f"tokenized_{model_hash}_{content_hash}.pt")

def tokenize_sources(sources=None, citation_data=None, tokenizer=None, batch_size=1000, cache_dir="cache", cache_path=None):
    """Tokenize *sources* and project their citation spans onto token positions.

    Args:
        sources: Article texts, e.g. ``WikiProcessor.find_source_citations()[0]``.
        citation_data: Per-text lists of ``(char_start, char_end, article_id)``.
        tokenizer: Tokenizer supporting ``return_offsets_mapping``.
        batch_size: Number of texts per ``batch_encode_plus`` call.
        cache_dir: Directory for the cache file when *cache_path* is not given.
        cache_path: Explicit cache file. If it exists it is loaded and returned
            and the other arguments are not needed.

    Returns:
        One dict per text with ``input_ids``, ``cite_tokens`` (article id at
        the first token of each link, else 0), ``mask_tokens`` (article id on
        every token of each link, else 0) and ``attention_mask``.

    Raises:
        ValueError: no cache file exists and *sources*, *citation_data* or
            *tokenizer* is missing.
    """
    from bisect import bisect_left, bisect_right

    # Generate cache path
    if cache_path is None:
        cache_path = get_cache_path(sources, tokenizer.name_or_path, cache_dir)

    # Check if cached results exist
    if os.path.exists(cache_path):
        logging.info(f"Loading cached tokenized results from {cache_path}")
        # NOTE: weights_only=False unpickles arbitrary objects; only load cache
        # files written by this pipeline.
        return torch.load(cache_path, weights_only=False)

    if sources is None or citation_data is None or tokenizer is None:
        raise ValueError(
            f"No cache at {cache_path}; sources, citation_data and tokenizer are required"
        )

    logging.info("Tokenizing sources...")
    all_results = []
    # tqdm takes the length from range(), which counts the final partial batch.
    for batch_start in tqdm.tqdm(range(0, len(sources), batch_size)):
        batch_end = batch_start + batch_size  # slicing clamps at the end
        batch_sources = sources[batch_start:batch_end]
        batch_citations = citation_data[batch_start:batch_end]

        batch_encoded = tokenizer.batch_encode_plus(
            batch_sources,
            add_special_tokens=False,
            return_offsets_mapping=True,
            padding=False,
            return_tensors=None
        )
        has_attention = "attention_mask" in batch_encoded

        for idx in range(len(batch_sources)):
            input_ids = batch_encoded["input_ids"][idx]
            offset_mapping = batch_encoded["offset_mapping"][idx]
            # Offsets are assumed non-decreasing (as tokenizers emit them), so
            # binary search maps any character offset to a token index. This
            # includes offsets that fall inside a token, which made the old
            # exact-match dict raise KeyError.
            token_starts = [start for start, _ in offset_mapping]
            token_ends = [end for _, end in offset_mapping]
            n_tokens = len(input_ids)

            mask_tokens = np.zeros(n_tokens, dtype=int)
            cite_tokens = np.zeros(n_tokens, dtype=int)
            for char_start, char_end, art_id in batch_citations[idx]:
                first = bisect_right(token_ends, char_start)  # token containing char_start
                stop = bisect_left(token_starts, char_end)    # first token at/after char_end
                if first >= n_tokens:
                    continue  # span not covered by any token
                cite_tokens[first] = art_id
                mask_tokens[first:stop] = art_id

            all_results.append({
                'input_ids': np.array(input_ids),
                'cite_tokens': cite_tokens,
                'mask_tokens': mask_tokens,
                'attention_mask': batch_encoded["attention_mask"][idx] if has_attention else None
            })

    # Create the directory the cache file actually lives in (which differs
    # from cache_dir when an explicit cache_path was passed).
    cache_parent = os.path.dirname(cache_path)
    if cache_parent:
        os.makedirs(cache_parent, exist_ok=True)
    torch.save(all_results, cache_path)
    logging.info(f"Cached tokenized results to {cache_path}")

    return all_results

def _pad_or_truncate(ids, length, pad_token):
    """Return *ids* cut to, or right-padded with *pad_token* up to, exactly *length* items."""
    ids = ids[:length]
    if len(ids) < length:
        ids = np.pad(ids, (0, length - len(ids)), 'constant', constant_values=pad_token)
    return ids


def collate(results, tokenizer, config, max_samples: Optional[int] = 1000):
    """Cut tokenized articles into source windows paired with their cited targets.

    Each record in *results* (as produced by ``tokenize_sources``) is split
    into windows of ``config.source_len`` tokens. Windows start every
    ``(1 - config.overlap) * config.source_len`` tokens.

    A window is used only if it has at least ``source_len // 2`` tokens and at
    least one citation. For each such window, up to ``config.max_targets``
    distinct cited articles are selected, randomly when there are more.

    In the source, the first token of each selected link becomes the
    ``cite_token`` id and the link's remaining tokens are dropped; links to
    unselected articles stay as text. The source is then padded or truncated
    to ``source_len``.

    Each selected article becomes one target row: its opening tokens plus
    ``ref_token``, padded to ``target_len``.

    Args:
        results: Per-article token records; article id ``k`` is ``results[k - 1]``.
        tokenizer: Tokenizer providing ids for ``config.cite_token``/``config.ref_token``
            and a ``pad_token_id``.
        config: ``ExperimentConfig``-like object.
        max_samples: Stop starting new articles once more than this many
            samples exist. The check runs per article, so the total can
            overshoot. The default reproduces the previously hard-coded limit;
            ``None`` processes every article.

    Returns:
        A list of sample dicts: ``source_art_id``, ``source_ids``,
        ``cited_art_ids`` (one per ``<CITE>`` marker), ``target_art_ids``,
        ``target_ids``, ``attention_mask``, ``target_attention_mask`` and
        ``target_count``.
    """
    cite_token = tokenizer.convert_tokens_to_ids(config.cite_token)
    ref_token = tokenizer.convert_tokens_to_ids(config.ref_token)
    pad_token = tokenizer.pad_token_id
    # At least 1: overlap >= 1 would give range() a zero or negative step.
    stride = max(1, int((1 - config.overlap) * config.source_len))
    min_window = config.source_len // 2

    collated_data = []
    for i in tqdm.tqdm(range(len(results))):
        if max_samples is not None and len(collated_data) > max_samples:
            break
        result = results[i]

        for s in range(0, len(result['input_ids']), stride):
            e = s + config.source_len
            input_ids = result['input_ids'][s:e].copy()
            cite_tokens = result['cite_tokens'][s:e]
            mask_tokens = result['mask_tokens'][s:e]

            if len(input_ids) < min_window:
                continue

            present_citations = np.unique(cite_tokens[cite_tokens > 0])
            if len(present_citations) > config.max_targets:
                present_citations = np.random.choice(present_citations, config.max_targets, replace=False)
            n_targets = len(present_citations)
            if n_targets == 0:
                continue

            # Source: <CITE> replaces the first token of each selected link and
            # the rest of that link is dropped.
            cite_positions = np.isin(cite_tokens, present_citations)
            mask_tokens = np.where(np.isin(mask_tokens, present_citations), mask_tokens, 0)
            mask_tokens[cite_positions] = 0
            input_ids[cite_positions] = cite_token
            source_ids = _pad_or_truncate(input_ids[mask_tokens == 0], config.source_len, pad_token)
            attention_mask = (source_ids != pad_token).astype(np.int64)
            # One article id per <CITE> marker, in source order (may repeat).
            cited_art_ids = cite_tokens[cite_positions]

            target_ids = np.full((n_targets, config.target_len), pad_token, dtype=np.int64)
            target_attention_mask = np.zeros((n_targets, config.target_len), dtype=np.int64)
            for idx, citation_id in enumerate(present_citations):
                # Article ids are 1-based positions in results. Keep room for
                # the trailing ref token.
                target_tokens = np.append(
                    results[citation_id - 1]['input_ids'][:config.target_len - 1], ref_token
                )
                target_tokens = _pad_or_truncate(target_tokens, config.target_len, pad_token)
                target_ids[idx] = target_tokens
                target_attention_mask[idx] = target_tokens != pad_token

            collated_data.append({
                'source_art_id': i + 1,
                'source_ids': torch.tensor(source_ids, dtype=torch.long),
                'cited_art_ids': torch.tensor(cited_art_ids, dtype=torch.long),
                'target_art_ids': torch.tensor(present_citations, dtype=torch.long),
                'target_ids': torch.tensor(target_ids, dtype=torch.long),
                'attention_mask': torch.tensor(attention_mask, dtype=torch.long),
                'target_attention_mask': torch.tensor(target_attention_mask, dtype=torch.long),
                'target_count': n_targets,
            })

    return collated_data

class CitationDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-collated citation samples.

    Items are served exactly as stored (the dicts built by ``collate``);
    merging them into batches is left to ``citation_collate_fn``.
    """

    def __init__(self, collated_data):
        # Hold a reference rather than a copy; samples are returned as-is.
        self.data = collated_data

    def __len__(self):
        samples = self.data
        return len(samples)

    def __getitem__(self, idx):
        samples = self.data
        return samples[idx]

def citation_collate_fn(batch):
    """Merge ``CitationDataset`` samples into one training batch.

    Per-source tensors (``source_ids``, ``attention_mask``) are stacked along
    a new leading batch dimension. Per-target tensors are concatenated along
    dim 0, so the targets of all samples form a single block. Only the first
    ``target_count`` target rows of each sample are kept, and
    ``target_counts`` records how many rows came from each sample.
    """
    def gather(key):
        return [item[key] for item in batch]

    def leading_targets(key):
        return [item[key][:item['target_count']] for item in batch]

    return {
        'source_ids': torch.stack(gather('source_ids')),
        'cited_art_ids': torch.cat(gather('cited_art_ids')),
        'target_art_ids': torch.cat(gather('target_art_ids')),
        'target_ids': torch.cat(leading_targets('target_ids')),
        'attention_mask': torch.stack(gather('attention_mask')),
        'target_attention_mask': torch.cat(leading_targets('target_attention_mask')),
        'target_counts': torch.tensor(gather('target_count')),
    }


@dataclass
class ExperimentConfig:
    """Configuration for the citation matching model."""

    model_name: str = "bert-base-uncased"  # tokenizer/model checkpoint name
    max_length: int = 512  # not read by the code in this notebook section — TODO confirm use
    source_len: int = 512  # tokens per source window (used by collate)
    target_len: int = 100  # tokens per target row, including the trailing ref_token
    max_targets: int = 5  # max distinct cited articles kept per source window
    overlap: float = 0.5  # fraction of source_len shared by consecutive windows
    cite_token: str = "<CITE>"  # replaces the first token of each selected link
    ref_token: str = "<REF>"  # appended to every target row
    temperature: float = 0.07  # presumably a contrastive-loss temperature; not used here
    # None -> CUDA when available, else CPU. A string such as "cpu" is
    # converted to torch.device.
    device: Union[str, torch.device, None] = None

    def __post_init__(self):
        if self.device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        elif isinstance(self.device, str):
            self.device = torch.device(self.device)

# Setup (run once per kernel). These lines are commented out, so the cell
# relies on `results`, `tokenizer` and `config` already existing from an
# earlier run of them. On a fresh kernel, uncomment them first or the
# collate() call below raises NameError.
# logging.basicConfig(level=logging.INFO)

# # # Load articles
# preprocessor = WikiProcessor()
# sources, citation_data = preprocessor.find_source_citations()

# config = ExperimentConfig()
# tokenizer = AutoTokenizer.from_pretrained(config.model_name)
# The CITE/REF markers must be registered so each maps to a single token id.
# tokenizer.add_special_tokens({
#     'additional_special_tokens': [config.cite_token, config.ref_token]
# })


# results = tokenize_sources(sources, citation_data, tokenizer, cache_dir="cache",)

# # # # Alternative: load an existing tokenization cache directly
# # results = tokenize_sources(cache_path='./cache/tokenized_1caf5def_895012ad817559b15b42e1d366769a67.pt')



# Usage example:
# Collate the data. Note: collate() by default stops after roughly 1000
# samples, so this dataset is a small subset of the articles.
collated_data = collate(results, tokenizer, config)

# Create dataset and dataloader
dataset = CitationDataset(collated_data)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=citation_collate_fn
)

# Example of resulting tensor shapes for a batch (first batch only; `batch`
# stays defined for the inspection cells below)
for batch in dataloader:
    print("Source shape:", batch['source_ids'].shape)  # [batch_size, source_len]
    print("Target shape:", batch['target_ids'].shape)  # [total_targets, target_len]
    print("Target counts:", batch['target_counts'])    # [batch_size]
    break

  0%|                                                                                                                   | 150/237381 [00:00<05:33, 710.86it/s]


Source shape: torch.Size([16, 512])
Target shape: torch.Size([79, 100])
Target counts: tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 4, 5, 5, 5, 5, 5])


  0%|                                                                                                                   | 150/237381 [00:00<05:42, 692.02it/s]

Source shape: torch.Size([16, 512])
Target shape: torch.Size([77, 100])
Target counts: tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 2, 5])





In [4]:
# Sizes of the corpus, its citation lists, the tokenized records and the collated samples.
tuple(map(len, (sources, citation_data, results, collated_data)))

(237381, 237381, 237381, 1056)

In [13]:
# Decode the current batch for inspection. Use a new name: rebinding `sources`
# here would replace the full list of article texts, which later cells index
# by article id (sources[art_id - 1]).
decoded_sources = tokenizer.batch_decode(batch['source_ids'])
# Article ids are tensors, which hash by identity. Convert them to Python
# ints before using them as dict keys.
cited_articles = [preprocessor.id2ref[art_id] for art_id in batch['cited_art_ids'].tolist()]
target_articles = [preprocessor.id2ref[art_id] for art_id in batch['target_art_ids'].tolist()]
targets = tokenizer.batch_decode(batch['target_ids'])

INFO:root:Loading articles from JSONL file...
INFO:root:Loaded 237381 articles.


In [13]:
# Sanity check: the number of <CITE> markers in the batch's sources should
# match len(cited_art_ids) (one article id per marker).
cite_token, ref_token = (
    tokenizer.convert_tokens_to_ids(special) for special in (config.cite_token, config.ref_token)
)

(
    batch['source_ids'].shape,
    batch['target_ids'].shape,
    (batch['source_ids'] == cite_token).sum(),
    len(batch['cited_art_ids']),
    len(batch['target_art_ids']),
)

(torch.Size([16, 512]), torch.Size([77, 100]), tensor(81), 81, 77)

In [68]:
# Inspect one collated sample: the original article text, the decoded source
# window, and each target passage next to the article it was taken from.
sample = dataset[2]
source_art_id = sample['source_art_id']
original_source = sources[source_art_id - 1]  # article ids are 1-based
# Decode only non-padding positions so the printout shows real tokens.
source_text = tokenizer.decode(sample['source_ids'][sample['attention_mask'].bool()])
cited_art_ids = sample['cited_art_ids']

print(f"Source original: {source_art_id}:\n{original_source[:1000]}\n\n")
print('#'*50)
print(f"Source tokens decoded:\n{source_text}\n\n")
# target_ids rows are aligned with target_art_ids (one row per distinct
# selected article). cited_art_ids has one entry per <CITE> marker and holds
# article ids, not token ids, so it must not be decoded or used to index
# target_ids.
target_art_ids = sample['target_art_ids']
for target_art_id, target_row, target_mask in zip(
        target_art_ids.tolist(), sample['target_ids'], sample['target_attention_mask']):
    target_original = sources[target_art_id - 1]
    target_text = tokenizer.decode(target_row[target_mask.bool()])
    print(f"Target: id={target_art_id}:\n{target_original[:100]}\n\n")
    print(f"Target tokens:\n{target_text[:100]}\n\n")


Source original: 1:
'''April''' (Apr.) is the fourth [[month]] of the [[year]] in the [[Julian calendar|Julian]] and [[Gregorian calendar]]s, and comes between [[March]] and [[May]]. It is one of four months to have 30 [[day]]s.
April always begins on the same day of the week as [[July]], and additionally, [[January]] in leap years. April always ends on the same day of the week as [[December]].
== The Month ==
April comes between [[March]] and [[May]], making it the fourth month of the year. It also comes first in the year out of the four months that have 30 days, as [[June]], [[September]] and [[November]] are later in the year.
April begins on the same day of the week as [[July]] every year and on the same day of the week as [[January]] in [[leap year]]s. April ends on the same day of the week as [[December]] every year, as each other's last days are exactly 35 weeks (245 days) apart.
In [[common year]]s, April starts on the same day of the week as [[October]] of the previous year, a

In [37]:
def batch_tokenize(batch_sources, batch_citations):
    """Tokenize a few texts with the notebook's global ``tokenizer``.

    Interactive copy of the per-batch step of ``tokenize_sources``. For each
    text it returns a dict with ``input_ids``, ``cite_tokens`` (article id at
    the first token of each link, else 0), ``mask_tokens`` (article id on
    every token of each link, else 0) and ``attention_mask``.

    Character offsets are mapped to token indices by binary search over the
    offset mapping. Offsets are assumed non-decreasing. A link boundary inside
    a token therefore no longer raises ``KeyError``.
    """
    from bisect import bisect_left, bisect_right

    # Batch encode
    batch_encoded = tokenizer.batch_encode_plus(
        batch_sources,
        add_special_tokens=False,
        return_offsets_mapping=True,
        padding=False,
        return_tensors=None
    )
    has_attention = "attention_mask" in batch_encoded
    all_results = []
    for idx in range(len(batch_sources)):
        input_ids = batch_encoded["input_ids"][idx]
        offset_mapping = batch_encoded["offset_mapping"][idx]
        token_starts = [start for start, _ in offset_mapping]
        token_ends = [end for _, end in offset_mapping]
        n_tokens = len(input_ids)

        mask_tokens = np.zeros(n_tokens, dtype=int)
        cite_tokens = np.zeros(n_tokens, dtype=int)
        for char_start, char_end, art_id in batch_citations[idx]:
            first = bisect_right(token_ends, char_start)  # token containing char_start
            stop = bisect_left(token_starts, char_end)    # first token at/after char_end
            if first >= n_tokens:
                continue  # span not covered by any token
            cite_tokens[first] = art_id
            mask_tokens[first:stop] = art_id

        all_results.append({
            'input_ids': np.array(input_ids),
            'cite_tokens': cite_tokens,
            'mask_tokens': mask_tokens,
            'attention_mask': batch_encoded["attention_mask"][idx] if has_attention else None
        })
    return all_results

# preprocessor = WikiProcessor()
# sources, citation_data = preprocessor.find_source_citations()
# Sanity check on the first five articles: decode each one in full, then again
# with every token that belongs to a link (mask_tokens != 0) removed.
# Article ids are 1-based, so article i+1 corresponds to sources[i].
for i in range(5):
    ref = preprocessor.id2ref[i+1]
    print(f"title = {ref}")
    result = batch_tokenize(sources[i:i+1], citation_data[i:i+1])[0]
    print(result['input_ids'].shape, result['cite_tokens'].shape, result['mask_tokens'].shape)
    input_ids = result['input_ids']
    cite_tokens = result['cite_tokens']
    mask_tokens = result['mask_tokens']
    
    decoded_text = tokenizer.decode(input_ids)
    print('original text = ', decoded_text[:1000])
    # Keep only tokens outside any link span.
    input_ids = input_ids[mask_tokens==0]
    print('#'*50)
    decoded_text = tokenizer.decode(input_ids)
    print('masked text = ', decoded_text[:1000])
    print('\n'*4)

title = april
(7734,) (7734,) (7734,)
original text =  ' ' ' april ' ' ' ( apr. ) is the fourth [ [ month ] ] of the [ [ year ] ] in the [ [ julian calendar | julian ] ] and [ [ gregorian calendar ] ] s, and comes between [ [ march ] ] and [ [ may ] ]. it is one of four months to have 30 [ [ day ] ] s. april always begins on the same day of the week as [ [ july ] ], and additionally, [ [ january ] ] in leap years. april always ends on the same day of the week as [ [ december ] ]. = = the month = = [ [ file : colorful spring garden. jpg | thumb | 180px | right | [ [ spring ] ] flowers in april in the [ [ northern hemisphere ] ]. ] ] april comes between [ [ march ] ] and [ [ may ] ], making it the fourth month of the year. it also comes first in the year out of the four months that have 30 days, as [ [ june ] ], [ [ september ] ] and [ [ november ] ] are later in the year. april begins on the same day of the week as [ [ july ] ] every year and on the same day of the week as [ [ january ]

In [62]:


def find_citations(text, articles_dict):
    """Locate wiki links in *text* that point at known articles.

    Returns ``(start, end, ref)`` tuples in order of appearance, one per
    ``[[...]]`` link that resolves. ``start``/``end`` are the character
    offsets of the whole link.

    ``ref`` is the lower-cased title that resolved. The link body is split on
    ``|``, each part loses any ``#section`` suffix, and the first part present
    in *articles_dict* wins. Links that resolve to nothing, or to an empty
    title, are skipped.
    """
    found = []
    for link in re.finditer(r'\[\[(.*?)\]\]', text):
        targets = [part.split('#')[0] for part in link.group(1).split('|')]
        resolved = next((t.lower() for t in targets if t.lower() in articles_dict), None)
        if resolved:
            found.append((link.start(), link.end(), resolved))
    return found


# Start of an element that clean_wiki_text removes entirely.
_WIKI_ELEMENT_START_RE = re.compile(r"\[\[(?:File|Category):|\{\{stub")
# Opening/closing tokens used to find the end of such an element, keyed by its
# first character.
_WIKI_NESTING_TOKEN_RE = {
    "[": re.compile(r"\[\[|\]\]"),
    "{": re.compile(r"\{\{|\}\}"),
}


def _wiki_element_end(line, start):
    """Return the index just past the token closing the element at *start*, or None."""
    depth = 0
    for token in _WIKI_NESTING_TOKEN_RE[line[start]].finditer(line, start):
        depth += 1 if token.group() in ("[[", "{{") else -1
        if depth == 0:
            return token.end()
    return None


def _strip_wiki_elements(line):
    """Remove balanced File/Category/stub elements from one line; unclosed ones are kept."""
    kept = []
    keep_from = search_from = 0
    while True:
        found = _WIKI_ELEMENT_START_RE.search(line, search_from)
        if found is None:
            break
        end = _wiki_element_end(line, found.start())
        if end is None:
            search_from = found.start() + 1
            continue
        kept.append(line[keep_from:found.start()])
        keep_from = search_from = end
    kept.append(line[keep_from:])
    return ''.join(kept)


def clean_wiki_text(text: str) -> str:
    """Cleans wiki content by removing metadata and formatting.

    Drops everything before the first bold (``'''...'''``) span. Removes
    ``[[File:...]]``, ``[[Category:...]]`` and ``{{stub...}}`` elements and
    then deletes blank lines. Each element ends at its own matching closer,
    so links nested in image captions are removed with it. Text after the
    element on the same line is kept; the previous greedy regex deleted
    everything up to the last closer on the line.
    """
    # Find main content starting from first bold title
    match = re.search(r"'''([^']+?)'''", text)
    if match:
        text = text[match.start():]
    lines = (_strip_wiki_elements(line) for line in text.split('\n'))
    return '\n'.join(line for line in lines if line.strip())

# Check clean_wiki_text on a real article with a Category link prepended.
# Everything before the first bold title, including that link, should be
# dropped.
text = preprocessor.articles_dict['april'] 
text = "[[Category:something something]]" + text
print(f"Original text: \n {text[:1000]}\n\n")
text =clean_wiki_text(text)

print(f"Cleaned text: \n {text[:1000]}\n\n")

# Only the first 1000 characters are scanned; a link cut by the slice has no
# closing brackets and is not reported.
find_citations(text[:1000], preprocessor.articles_dict)

Original text: 
 [[Category:something something]]'''April''' (Apr.) is the fourth [[month]] of the [[year]] in the [[Julian calendar|Julian]] and [[Gregorian calendar]]s, and comes between [[March]] and [[May]]. It is one of four months to have 30 [[day]]s.
April always begins on the same day of the week as [[July]], and additionally, [[January]] in leap years. April always ends on the same day of the week as [[December]].
== The Month ==
[[File:Colorful spring garden.jpg|thumb|180px|right|[[Spring]] flowers in April in the [[Northern Hemisphere]].]]
April comes between [[March]] and [[May]], making it the fourth month of the year. It also comes first in the year out of the four months that have 30 days, as [[June]], [[September]] and [[November]] are later in the year.
April begins on the same day of the week as [[July]] every year and on the same day of the week as [[January]] in [[leap year]]s. April ends on the same day of the week as [[December]] every year, as each other's last d

[(33, 42, 'month'),
 (50, 58, 'year'),
 (66, 92, 'julian calendar'),
 (97, 119, 'gregorian calendar'),
 (140, 149, 'march'),
 (154, 161, 'may'),
 (199, 206, 'day'),
 (260, 268, 'july'),
 (288, 299, 'january'),
 (364, 376, 'december'),
 (414, 423, 'march'),
 (428, 435, 'may'),
 (554, 562, 'june'),
 (564, 577, 'september'),
 (582, 594, 'november'),
 (662, 670, 'july'),
 (717, 728, 'january'),
 (732, 745, 'leap year'),
 (790, 802, 'december'),
 (883, 898, 'common year'),
 (945, 956, 'october'),
 (986, 999, 'leap year')]

In [51]:
# List every [[File:...]] element in the 'april' article. The greedy `.*`
# (which does not cross newlines) runs to the last `]]` on the line.
pattern = r'\[\[File:.*\]\]'
text = preprocessor.articles_dict['april']
# No re.X: verbose mode only affects whitespace and '#' inside the pattern
# (there are none), and would silently change the meaning if either were added.
matches = re.finditer(pattern, text)
[m.group() for m in matches]

['[[File:Colorful spring garden.jpg|thumb|180px|right|[[Spring]] flowers in April in the [[Northern Hemisphere]].]]',
 "[[File:Aprilsnar 2001.png|thumb|200px|right|An [[April Fools' Day]] hoax for [[April 1]] in [[Copenhagen]].]]",
 '[[File:Songkran in Wat Kungthapao 03.jpg|thumb|180px|right|[[Songkran]] celebration in [[Thailand]] around [[April 14]].]]',
 '[[File:Earth flag PD.jpg|thumb|200px|right|Proposed [[flag]] for [[Earth Day]] on [[April 22]].]]',
 "[[File:St George's Day 2010 - 14.jpg|thumb|200px|right|[[Saint George]]'s Day on [[April 23]] in [[London]]'s [[Trafalgar Square]].]]",
 '[[File:Anzac1.JPG|thumb|180px|right|[[ANZAC Day]] commemoration in [[Australia]] on [[April 25]].]]',
 "[[File:Koninginnedag2007.jpg|thumb|180px|right|Queen's Day, [[April 30]], celebration in the [[Netherlands]]. It changed to King's Day, [[April 27]], in [[2014]].]]",
 '[[File:Valborgsbrasa-1.jpg|thumb|210px|right|[[Walpurgis Night]] bonfire on [[April 30]] in [[Sweden]].]]',
 '[[File:Vajicka1.