In [2]:
import os
import json
import logging
from pathlib import Path
import torch
from datetime import datetime

from data_processing import (
    ArticleContentProcessor,
    CitationDataPreprocessor
)
from modeling import (
    ModelConfig,
    CitationMatcher,
    CitationDataset,
    create_dataloader
)
from training import (
    TrainingConfig,
    train_model
)

def setup_logging(output_dir: Path) -> None:
    """Configure root logging for the training process.

    Records at INFO level and above are written both to
    ``output_dir / 'training.log'`` and to a console stream handler.

    Args:
        output_dir: Directory that receives ``training.log``. It is created
            (including parents) if it does not exist yet.
    """
    # FileHandler cannot open its file inside a missing directory.
    output_dir.mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            logging.FileHandler(output_dir / 'training.log'),
            logging.StreamHandler()
        ],
        # basicConfig() is a silent no-op when the root logger already has
        # handlers (e.g. this cell was run before in the same kernel), which
        # would leave a new run's training.log empty. force=True replaces the
        # existing root handlers with the ones configured here.
        force=True
    )

def setup_environment(seed: int = 42) -> None:
    """Configure the training environment and seed random number generators.

    Sets ``TOKENIZERS_PARALLELISM`` to ``"false"`` in the process environment
    and seeds Python's ``random`` module and torch (CPU, plus all CUDA devices
    when CUDA is available).

    Args:
        seed: Seed applied to every generator. Defaults to 42, the value that
            was previously hard-coded.
    """
    # Imported locally so this function does not depend on a new file-level import.
    import random

    # NOTE(review): presumably set to silence tokenizer parallelism when the
    # DataLoader forks worker processes — confirm.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Seed Python's RNG as well as torch; any sampling done with ``random``
    # would otherwise differ between runs despite the fixed torch seed.
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def create_output_directory() -> Path:
    """Make a fresh, timestamped run directory and return its path.

    The directory is ``training_runs/run_<YYYYmmdd_HHMMSS>``, relative to the
    current working directory; missing parents are created and an already
    existing directory is reused without error.
    """
    run_name = f"run_{datetime.now():%Y%m%d_%H%M%S}"
    run_dir = Path("training_runs") / run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir

def load_and_prepare_data(
    jsonl_path: str,
    train_sample_size: int,
    val_sample_size: int
) -> tuple:
    """Load articles from a JSONL file and build train/validation citation pairs.

    Each non-blank line of the file is parsed as a JSON object with ``'title'``
    and ``'text'`` keys. The text is passed through
    ``ArticleContentProcessor.clean_wiki_content``; articles whose cleaned
    content is falsy are dropped, the rest are keyed by lower-cased title.

    Args:
        jsonl_path: Path to the UTF-8 encoded JSONL file of articles.
        train_sample_size: ``sample_size`` passed to ``create_citation_pairs``
            for the training split (with ``cite_samples_per_article=1``).
        val_sample_size: ``sample_size`` passed to ``create_citation_pairs``
            for the validation split (with ``cite_samples_per_article=10``).

    Returns:
        ``(train_sources, train_targets, val_sources, val_targets)``. The
        training sequences are returned as produced by the preprocessor; the
        validation sequences are tuples and are both empty when every
        validation pair overlaps with the training data.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an article lacks the ``'title'`` or ``'text'`` key.
    """
    logging.info("Loading articles from JSONL file...")
    articles_dict = {}
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file),
            # which json.loads would otherwise reject.
            if not line.strip():
                continue
            article = json.loads(line)
            content = ArticleContentProcessor.clean_wiki_content(article['text'])
            if content:
                articles_dict[article['title'].lower()] = content

    logging.info(f"Loaded {len(articles_dict)} articles")

    # Prepare citation data
    preprocessor = CitationDataPreprocessor(articles_dict)

    logging.info("Preparing training data...")
    train_sources, train_targets = preprocessor.create_citation_pairs(
        sample_size=train_sample_size,
        cite_samples_per_article=1
    )

    # Sets give O(1) membership checks for the overlap filter below.
    seen_sources = set(train_sources)
    seen_targets = set(train_targets)

    logging.info("Preparing validation data...")
    val_sources, val_targets = preprocessor.create_citation_pairs(
        sample_size=val_sample_size,
        cite_samples_per_article=10
    )

    # Keep a validation pair only if neither its source nor its target was
    # seen in the training set.
    held_out_pairs = [
        (source, target) for source, target in zip(val_sources, val_targets)
        if source not in seen_sources and target not in seen_targets
    ]

    if held_out_pairs:
        val_sources, val_targets = zip(*held_out_pairs)
    else:
        # zip(*[]) yields nothing to unpack and would raise ValueError;
        # return empty tuples so the types match the non-empty case.
        logging.warning(
            "No validation samples left after removing overlap with the training set"
        )
        val_sources, val_targets = (), ()

    return train_sources, train_targets, val_sources, val_targets

def main():
    """Run one complete training job from data loading to saved history.

    Steps, in order: create the run directory, configure logging and seeds,
    build the model and training configurations, load the citation pairs,
    instantiate the model on the selected device, wrap the data in datasets
    and dataloaders, call ``train_model``, and finally store the metrics
    history together with both configurations in
    ``<run dir>/training_history.pt``.
    """
    # Run directory first, so logging has somewhere to write.
    run_dir = create_output_directory()
    setup_logging(run_dir)
    setup_environment()

    # Hyperparameters for the model and for the optimisation loop.
    model_cfg = ModelConfig(
        model_name="bert-base-uncased",
        max_length=512,
        cite_token="<CITE>",
        ref_token="<REF>",
        temperature=0.07,
    )
    train_cfg = TrainingConfig(
        batch_size=32,
        num_epochs=10,
        learning_rate=1.5e-4,
        temperature=0.1,
        num_workers=4,
        gradient_clip_value=1.0,
        scheduler_patience=2,
        scheduler_factor=0.5,
        eval_k_values=[1, 3, 5, 10, 50],
    )

    # Citation pairs for both splits.
    splits = load_and_prepare_data(
        jsonl_path='./data/wiki_articles.jsonl',
        train_sample_size=10000,
        val_sample_size=100,
    )
    train_sources, train_targets, val_sources, val_targets = splits

    # Prefer the GPU when one is visible.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    logging.info(f"Using device: {device}")

    model = CitationMatcher(model_cfg).to(device)

    def build_dataset(sources, targets):
        # Both splits share the tokenizer, the config and verbose output.
        return CitationDataset(
            sources=sources,
            targets=targets,
            tokenizer=model.tokenizer,
            config=model_cfg,
            verbose=True,
        )

    train_dataset = build_dataset(train_sources, train_targets)
    val_dataset = build_dataset(val_sources, val_targets)

    # Training batches are shuffled; validation uses twice the batch size
    # and keeps its order.
    train_loader = create_dataloader(
        dataset=train_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=True,
        num_workers=train_cfg.num_workers,
    )
    val_loader = create_dataloader(
        dataset=val_dataset,
        batch_size=train_cfg.batch_size * 2,
        shuffle=False,
        num_workers=train_cfg.num_workers,
    )

    logging.info("Starting training...")
    metrics_history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=train_cfg,
        save_dir=run_dir / 'checkpoints',
        device=device,
    )

    # Persist plain dicts rather than the metric/config objects themselves.
    history_payload = {
        'metrics_history': [entry.__dict__ for entry in metrics_history],
        'model_config': model_cfg.__dict__,
        'training_config': train_cfg.__dict__,
    }
    torch.save(history_payload, run_dir / 'training_history.pt')

    logging.info(f"Training completed. Results saved to {run_dir}")

if __name__ == "__main__":
    # Record the full traceback via the logging module, then re-raise so the
    # failure is still reported to whoever ran the script or cell.
    try:
        main()
    except Exception:
        logging.exception("An error occurred during training:")
        raise

2024-11-09 21:10:53,544 - INFO - Loading articles from JSONL file...
2024-11-09 21:11:02,441 - INFO - Loaded 237381 articles
2024-11-09 21:11:02,442 - INFO - Preparing training data...
2024-11-09 21:11:04,229 - INFO - Preparing validation data...
2024-11-09 21:11:04,725 - INFO - Using device: cuda
Processing samples: 100%|██████████| 77/77 [01:05<00:00,  1.17batch/s]



Preprocessing Statistics:
Total samples: 9798
Processed: 6676
Skipped (no citation): 3122
Skipped (errors): 0


Processing samples: 100%|██████████| 3/3 [00:01<00:00,  2.43batch/s]
2024-11-09 21:12:11,947 - INFO - Starting training...



Preprocessing Statistics:
Total samples: 296
Processed: 194
Skipped (no citation): 102
Skipped (errors): 0

Epoch 1/10


Training:  17%|█▋        | 35/209 [00:45<03:46,  1.30s/it, loss=0.5454, avg_loss=1.0640]

In [None]:
!pip install IProgress