In [None]:
import os
import json
import logging
from pathlib import Path
import torch
from datetime import datetime

from data_processing import (
    ArticleContentProcessor,
    CitationDataPreprocessor
)
from model_architecture import (
    ModelConfig,
    CitationMatcher,
    CitationDataset,
    create_dataloader
)
from training_module import (
    TrainingConfig,
    train_model
)

def setup_logging(output_dir: Path) -> None:
    """Configure root logging for the training process.

    Records at INFO level and above are written both to
    ``output_dir / 'training.log'`` and to a stream handler.

    Args:
        output_dir: Directory that receives ``training.log``. It must
            already exist; this function does not create it.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # Explicit encoding so the log file does not depend on the
            # platform's default locale encoding.
            logging.FileHandler(output_dir / 'training.log', encoding='utf-8'),
            logging.StreamHandler()
        ],
        # basicConfig silently does nothing when the root logger already has
        # handlers (e.g. when this cell is re-run in the same kernel), which
        # would leave a new run logging into the previous run's file.
        # force=True removes and closes the old handlers first.
        force=True
    )

def setup_environment() -> None:
    """Configure the training environment and seed the random generators.

    Sets ``TOKENIZERS_PARALLELISM=false`` in the process environment and
    seeds Python's ``random``, NumPy (when it is installed) and torch
    (CPU, plus every CUDA device when CUDA is available) with a fixed seed.
    """
    import random  # local import: only needed for seeding here

    seed = 42

    # NOTE(review): presumably this silences tokenizer fork warnings when
    # DataLoader workers are used — confirm against the tokenizer library.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Seeding torch alone leaves any sampling done through `random` or NumPy
    # non-deterministic, so seed every generator that may be in play.
    random.seed(seed)
    try:
        import numpy as np
    except ImportError:
        # NumPy is optional for this script; nothing to seed without it.
        pass
    else:
        np.random.seed(seed)

    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def create_output_directory() -> Path:
    """Create a timestamped run directory under ``training_runs`` and return it.

    The directory is named ``run_<YYYYMMDD>_<HHMMSS>`` from the current local
    time. Missing parents are created, and an already existing directory is
    reused without error.
    """
    run_name = f"run_{datetime.now():%Y%m%d_%H%M%S}"
    run_dir = Path("training_runs") / run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir

def load_and_prepare_data(
    jsonl_path: str,
    train_sample_size: int,
    val_sample_size: int
) -> tuple:
    """Load articles from a JSONL file and build train/validation citation pairs.

    Every non-blank line of ``jsonl_path`` is parsed as a JSON object with
    ``'title'`` and ``'text'`` keys. The text is cleaned with
    ``ArticleContentProcessor.clean_wiki_content``; articles whose cleaned
    content is empty are dropped. Articles are keyed by lower-cased title, so
    a later article with the same lower-cased title replaces an earlier one.

    Args:
        jsonl_path: Path to the UTF-8 encoded JSONL file of articles.
        train_sample_size: ``sample_size`` passed to ``create_citation_pairs``
            for the training split (with ``cite_samples_per_article=1``).
        val_sample_size: ``sample_size`` passed to ``create_citation_pairs``
            for the validation split (with ``cite_samples_per_article=10``).

    Returns:
        The 4-tuple ``(train_sources, train_targets, val_sources, val_targets)``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an article lacks the ``'text'`` or ``'title'`` key.
    """
    logging.info("Loading articles from JSONL file...")
    articles_dict = {}
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            # A blank line is not a valid JSON document; skip it instead of
            # letting json.loads abort the whole load.
            if not line.strip():
                continue
            article = json.loads(line)
            content = ArticleContentProcessor.clean_wiki_content(article['text'])
            if content:
                articles_dict[article['title'].lower()] = content

    logging.info(f"Loaded {len(articles_dict)} articles")

    # Prepare citation data
    # NOTE(review): both splits are drawn from the same preprocessor; whether
    # the two calls yield disjoint pairs depends on CitationDataPreprocessor,
    # which is not visible here — confirm there is no train/validation overlap.
    preprocessor = CitationDataPreprocessor(articles_dict)

    logging.info("Preparing training data...")
    train_sources, train_targets = preprocessor.create_citation_pairs(
        sample_size=train_sample_size,
        cite_samples_per_article=1
    )

    logging.info("Preparing validation data...")
    val_sources, val_targets = preprocessor.create_citation_pairs(
        sample_size=val_sample_size,
        cite_samples_per_article=10
    )

    return train_sources, train_targets, val_sources, val_targets

def main():
    """Run one end-to-end training job.

    Steps, in order: create the run directory, configure logging and the
    environment, build the model/training configs, load the citation data,
    construct the model, datasets and dataloaders, call ``train_model``, and
    save the metrics history together with both configs to
    ``training_history.pt`` inside the run directory.
    """
    # The run directory comes first because logging writes into it.
    run_dir = create_output_directory()
    setup_logging(run_dir)
    setup_environment()

    # NOTE(review): ModelConfig uses temperature=0.07 while TrainingConfig
    # uses temperature=0.1 — confirm which value the loss actually reads.
    model_config = ModelConfig(
        model_name="bert-base-uncased",
        max_length=512,
        cite_token="<CITE>",
        ref_token="<REF>",
        temperature=0.07,
    )
    training_config = TrainingConfig(
        batch_size=32,
        num_epochs=10,
        learning_rate=1.5e-4,
        temperature=0.1,
        num_workers=4,
        gradient_clip_value=1.0,
        scheduler_patience=2,
        scheduler_factor=0.5,
        eval_k_values=[1, 3, 5, 10, 50],
    )

    # Load the articles and derive source/target pairs for both splits.
    splits = load_and_prepare_data(
        jsonl_path='./data/wiki_articles.jsonl',
        train_sample_size=50000,
        val_sample_size=1000,
    )
    train_sources, train_targets, val_sources, val_targets = splits

    # Use CUDA when available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    logging.info(f"Using device: {device}")

    model = CitationMatcher(model_config).to(device)

    def build_dataset(sources, targets):
        # Both splits share the model's tokenizer and the model config.
        return CitationDataset(
            sources=sources,
            targets=targets,
            tokenizer=model.tokenizer,
            config=model_config,
            verbose=True,
        )

    def build_loader(dataset, batch_size, shuffle):
        return create_dataloader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=training_config.num_workers,
        )

    # Same construction order as before: train first, then validation.
    train_dataset = build_dataset(train_sources, train_targets)
    val_dataset = build_dataset(val_sources, val_targets)

    train_loader = build_loader(train_dataset, training_config.batch_size, True)
    # Validation uses twice the training batch size and is not shuffled.
    val_loader = build_loader(val_dataset, training_config.batch_size * 2, False)

    logging.info("Starting training...")
    metrics_history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=training_config,
        save_dir=run_dir / 'checkpoints',
        device=device,
    )

    # Persist per-epoch metrics together with the configs that produced them.
    history = {
        'metrics_history': [vars(entry) for entry in metrics_history],
        'model_config': vars(model_config),
        'training_config': vars(training_config),
    }
    torch.save(history, run_dir / 'training_history.pt')

    logging.info(f"Training completed. Results saved to {run_dir}")

if __name__ == "__main__":
    try:
        main()
    except Exception:
        # Send the full traceback through the configured logging handlers,
        # then re-raise so the failure still propagates to the caller.
        logging.exception("An error occurred during training:")
        raise

  from .autonotebook import tqdm as notebook_tqdm
2024-11-09 19:24:50,307 - INFO - Loading articles from JSONL file...
2024-11-09 19:24:59,222 - INFO - Loaded 237381 articles
2024-11-09 19:24:59,223 - INFO - Preparing training data...
2024-11-09 19:25:02,239 - INFO - Preparing validation data...
2024-11-09 19:25:08,821 - INFO - Using device: cuda


Processed 10000 samples...
Preprocessing stats: {'skipped_no_cite': 6229, 'skipped_errors': 0, 'processed': 13349}
Processed 10000 samples...
Processed 20000 samples...
Processed 30000 samples...
Processed 40000 samples...
Processed 50000 samples...


2024-11-09 19:38:17,760 - INFO - Starting training...


Preprocessing stats: {'skipped_no_cite': 21500, 'skipped_errors': 0, 'processed': 37531}

Epoch 1/10


Training: 100%|█| 418/418 [09:12<00:00,  1.32s/it, loss=0.0201, avg_loss=0.5933]



Running validation...


Computing validation embeddings: 100%|████████| 587/587 [08:36<00:00,  1.14it/s]



Computing rankings for 37531 samples...

Validation Metrics:
top_k_accuracy: {1: 0.07204710772428126, 3: 0.15291359143108363, 5: 0.20159334949774851, 10: 0.2820601635980922, 50: 0.52745730196371}
mrr: 0.1430
median_rank: 43.0000
mean_rank: 466.2326
val_size: 37531

Epoch 1 Summary:
Training Loss: 0.5933
Validation Loss: 1.4167
Best Top-1 Accuracy: 0.0720
Mean Reciprocal Rank: 0.1430

Epoch 2/10


Training: 100%|█| 418/418 [09:15<00:00,  1.33s/it, loss=0.0368, avg_loss=0.3432]



Running validation...


Computing validation embeddings: 100%|████████| 587/587 [08:37<00:00,  1.13it/s]



Computing rankings for 37531 samples...

Validation Metrics:
top_k_accuracy: {1: 0.07122112387093336, 3: 0.14513335642535505, 5: 0.19386640377288109, 10: 0.27667794623111563, 50: 0.5266046734699316}
mrr: 0.1397
median_rank: 43.0000
mean_rank: 521.0886
val_size: 37531

Epoch 2 Summary:
Training Loss: 0.3432
Validation Loss: 1.4627
Best Top-1 Accuracy: 0.0712
Mean Reciprocal Rank: 0.1397

Epoch 3/10


Training:   5%|  | 21/418 [00:27<08:48,  1.33s/it, loss=0.2330, avg_loss=0.2555]