In [1]:
import os
import json
import logging
from pathlib import Path
import torch
from datetime import datetime

from wiki_citations import CitationDataPreprocessor

from modeling import (
    ModelConfig,
    CitationMatcher,
    CitationDataset,
    create_dataloader
)
from training import (
    TrainingConfig,
    train_model
)

def setup_logging(output_dir: Path) -> None:
    """Configure root logging for the training process.

    Messages at INFO level and above are sent both to
    ``output_dir / 'training.log'`` and to a ``StreamHandler`` (which writes
    to ``sys.stderr`` by default).

    Args:
        output_dir: Directory in which ``training.log`` is opened. It must
            already exist; the file handler opens the file immediately.

    Note:
        ``force=True`` removes and closes any handlers already attached to
        the root logger. Without it ``logging.basicConfig`` silently does
        nothing when the root logger is already configured (for example when
        this code is executed a second time in the same interpreter), and
        the messages of a new run would keep going to the previous run's
        log file.
    """
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        handlers=[
            # Explicit encoding so the log does not depend on the platform
            # default when messages contain non-ASCII text.
            logging.FileHandler(output_dir / 'training.log', encoding='utf-8'),
            logging.StreamHandler()
        ],
        force=True
    )

def setup_environment(seed: int = 42) -> None:
    """Configure process-wide settings for a training run.

    Sets the ``TOKENIZERS_PARALLELISM`` environment variable to ``"false"``
    and seeds the random number generators so repeated runs start from the
    same state.

    Args:
        seed: Seed applied to Python's ``random`` module, to torch, and
            (when CUDA is available) to every CUDA device. Defaults to 42,
            the value that was previously hard-coded.
    """
    # Function-scope import so this block does not depend on the file's
    # import section.
    import random

    # NOTE(review): presumably this silences the tokenizers fork/parallelism
    # warning when DataLoader workers are used -- confirm against the
    # tokenizer library in use.
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    # Seed every RNG this function knows about. Any other library that draws
    # random numbers (e.g. numpy) is NOT seeded here.
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def create_output_directory(base_dir: "str | Path" = "training_runs") -> Path:
    """Create and return the output directory for this training run.

    The directory is named ``run_<YYYYmmdd_HHMMSS>`` using the current local
    time and is created under ``base_dir`` together with any missing parents.

    Args:
        base_dir: Directory that holds all run directories. Defaults to
            ``"training_runs"`` (relative to the current working directory),
            the location that was previously hard-coded.

    Returns:
        Path of the run directory, which exists when this function returns.

    Note:
        The timestamp has one-second resolution and ``exist_ok=True`` is
        used, so two calls within the same second return the same directory.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = Path(base_dir) / f"run_{timestamp}"
    output_dir.mkdir(parents=True, exist_ok=True)
    return output_dir

def load_and_prepare_data(
    jsonl_path: str,
    train_sample_size: int,
    val_sample_size: int
) -> tuple:
    """Load articles from a JSONL file and build train/validation citation pairs.

    Each non-blank line of the file is parsed as a JSON object with a
    ``'title'`` and a ``'text'`` key; articles are indexed by lower-cased
    title (a later line with the same lower-cased title overwrites an
    earlier one). The resulting dict is handed to ``CitationDataPreprocessor``
    and ``create_citation_pairs`` is called twice: once for training
    (``cite_samples_per_article=1``) and once for validation
    (``cite_samples_per_article=10``).

    A validation pair is dropped when its source appears among the training
    sources OR its target appears among the training targets.

    Args:
        jsonl_path: Path of the UTF-8 encoded JSONL file.
        train_sample_size: Passed as ``sample_size`` for the training pairs.
        val_sample_size: Passed as ``sample_size`` for the validation pairs.

    Returns:
        ``(train_sources, train_targets, val_sources, val_targets)``. The
        training sequences are returned exactly as produced by the
        preprocessor; the validation sequences are tuples (both empty when
        every validation pair overlaps with the training data).

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If an article lacks the ``'title'`` or ``'text'`` key.
    """
    logging.info("Loading articles from JSONL file...")
    articles_dict = {}
    with open(jsonl_path, 'r', encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file)
            # instead of failing inside json.loads.
            if not line.strip():
                continue
            article = json.loads(line)
            articles_dict[article['title'].lower()] = article['text']

    logging.info(f"Loaded {len(articles_dict)} articles")

    # Prepare citation data
    preprocessor = CitationDataPreprocessor(articles_dict)

    logging.info("Preparing training data...")
    train_sources, train_targets = preprocessor.create_citation_pairs(
        sample_size=train_sample_size,
        cite_samples_per_article=1
    )

    # Sets give O(1) membership tests for the overlap filter below; this
    # requires the source/target items to be hashable.
    seen_sources = set(train_sources)
    seen_targets = set(train_targets)

    logging.info("Preparing validation data...")
    val_sources, val_targets = preprocessor.create_citation_pairs(
        sample_size=val_sample_size,
        cite_samples_per_article=10
    )

    # Remove any validation samples that are also in the training set
    held_out_pairs = [
        (source, target) for source, target in zip(val_sources, val_targets)
        if source not in seen_sources and target not in seen_targets
    ]

    if held_out_pairs:
        val_sources, val_targets = zip(*held_out_pairs)
    else:
        # zip(*[]) yields nothing, so unpacking it into two names would raise
        # ValueError; return two empty tuples instead.
        logging.warning(
            "No validation samples left after removing overlap with the training set"
        )
        val_sources, val_targets = (), ()

    return train_sources, train_targets, val_sources, val_targets

def main():
    """Run one complete training session.

    Steps, in order: create the run directory, configure logging and the
    environment, build the model/training configs, load the citation pairs,
    instantiate the model on the selected device, wrap the data in datasets
    and dataloaders, train, and finally save the metrics history together
    with both configs to ``training_history.pt`` in the run directory.
    """
    # --- Setup -----------------------------------------------------------
    # The run directory is created first because logging writes into it.
    output_dir = create_output_directory()
    setup_logging(output_dir)
    setup_environment()

    # --- Configuration ---------------------------------------------------
    # NOTE(review): `temperature` is set on both configs with different
    # values (0.07 here, 0.1 below) -- confirm which one the loss uses.
    model_cfg = ModelConfig(
        model_name="bert-base-uncased",
        max_length=512,
        cite_token="<CITE>",
        ref_token="<REF>",
        temperature=0.07,
    )
    train_cfg = TrainingConfig(
        batch_size=32,
        num_epochs=10,
        learning_rate=1.5e-4,
        temperature=0.1,
        num_workers=4,
        gradient_clip_value=1.0,
        scheduler_patience=2,
        scheduler_factor=0.5,
        eval_k_values=[1, 5, 10, 50],
    )

    # --- Data ------------------------------------------------------------
    prepared = load_and_prepare_data(
        jsonl_path='./data/wiki_articles.jsonl',
        train_sample_size=1000,
        val_sample_size=100,
    )
    train_sources, train_targets, val_sources, val_targets = prepared

    # --- Model -----------------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info(f"Using device: {device}")
    matcher = CitationMatcher(model_cfg).to(device)

    # --- Datasets and dataloaders ----------------------------------------
    def build_dataset(sources, targets):
        # Both splits share the model's tokenizer and the model config.
        return CitationDataset(
            sources=sources,
            targets=targets,
            tokenizer=matcher.tokenizer,
            config=model_cfg,
            verbose=True,
        )

    train_dataset = build_dataset(train_sources, train_targets)
    val_dataset = build_dataset(val_sources, val_targets)

    train_loader = create_dataloader(
        dataset=train_dataset,
        batch_size=train_cfg.batch_size,
        shuffle=True,
        num_workers=train_cfg.num_workers,
    )
    # Validation uses twice the training batch size and a fixed order.
    val_loader = create_dataloader(
        dataset=val_dataset,
        batch_size=train_cfg.batch_size * 2,
        shuffle=False,
        num_workers=train_cfg.num_workers,
    )

    # --- Training --------------------------------------------------------
    logging.info("Starting training...")
    checkpoint_dir = output_dir / 'checkpoints'
    metrics_history = train_model(
        model=matcher,
        train_loader=train_loader,
        val_loader=val_loader,
        config=train_cfg,
        save_dir=checkpoint_dir,
        device=device,
    )

    # --- Persist history -------------------------------------------------
    # Plain dicts are stored so the file does not hold the metric/config
    # objects themselves.
    history_payload = {
        'metrics_history': [metric.__dict__ for metric in metrics_history],
        'model_config': model_cfg.__dict__,
        'training_config': train_cfg.__dict__,
    }
    torch.save(history_payload, output_dir / 'training_history.pt')

    logging.info(f"Training completed. Results saved to {output_dir}")

if __name__ == "__main__":
    try:
        main()
    except Exception:
        # Top-level boundary: record the full traceback through logging, then
        # re-raise so the failure is not swallowed.
        logging.exception("An error occurred during training:")
        raise

  from .autonotebook import tqdm as notebook_tqdm
2024-11-10 13:13:49,247 - INFO - Loading articles from JSONL file...
2024-11-10 13:13:53,468 - INFO - Loaded 237381 articles
2024-11-10 13:13:53,469 - INFO - Preparing training data...
2024-11-10 13:13:54,196 - INFO - Preparing validation data...
2024-11-10 13:13:55,080 - INFO - Using device: cuda
Processing samples: 100%|██████████| 2/2 [00:11<00:00,  5.87s/batch]



Processed 1213 samples
Skipped 757 samples without citation


Processing samples: 100%|██████████| 3/3 [00:11<00:00,  3.72s/batch]
2024-11-10 13:14:19,150 - INFO - Starting training...



Processed 1360 samples
Skipped 1194 samples without citation

Epoch 1/10


Training: 100%|██████████| 38/38 [00:49<00:00,  1.31s/it]
Validation: 100%|██████████| 22/22 [00:18<00:00,  1.19it/s]



Epoch 1 Summary:
Training Loss: 0.9453
Validation Loss: 1.5942
Best Top-1 Accuracy: 0.2846
Mean Reciprocal Rank: 0.4319
Validation Size: 1360
Top-1 Accuracy: 0.285
Top-5 Accuracy: 0.612
Top-10 Accuracy: 0.726
Top-50 Accuracy: 0.900

Epoch 2/10


Training: 100%|██████████| 38/38 [00:49<00:00,  1.30s/it]
Validation: 100%|██████████| 22/22 [00:18<00:00,  1.19it/s]



Epoch 2 Summary:
Training Loss: 0.2869
Validation Loss: 1.6976
Best Top-1 Accuracy: 0.2691
Mean Reciprocal Rank: 0.4126
Validation Size: 1360
Top-1 Accuracy: 0.269
Top-5 Accuracy: 0.577
Top-10 Accuracy: 0.701
Top-50 Accuracy: 0.892

Epoch 3/10


Training: 100%|██████████| 38/38 [00:49<00:00,  1.31s/it]
Validation: 100%|██████████| 22/22 [00:18<00:00,  1.18it/s]



Epoch 3 Summary:
Training Loss: 0.1870
Validation Loss: 1.7977
Best Top-1 Accuracy: 0.2478
Mean Reciprocal Rank: 0.3928
Validation Size: 1360
Top-1 Accuracy: 0.248
Top-5 Accuracy: 0.569
Top-10 Accuracy: 0.673
Top-50 Accuracy: 0.881

Epoch 4/10


Training: 100%|██████████| 38/38 [00:50<00:00,  1.32s/it]
Validation: 100%|██████████| 22/22 [00:18<00:00,  1.18it/s]



Epoch 4 Summary:
Training Loss: 0.1536
Validation Loss: 1.8233
Best Top-1 Accuracy: 0.2529
Mean Reciprocal Rank: 0.3970
Validation Size: 1360
Top-1 Accuracy: 0.253
Top-5 Accuracy: 0.571
Top-10 Accuracy: 0.686
Top-50 Accuracy: 0.888

Epoch 5/10


Training: 100%|██████████| 38/38 [00:50<00:00,  1.32s/it]
Validation: 100%|██████████| 22/22 [00:18<00:00,  1.18it/s]



Epoch 5 Summary:
Training Loss: 0.1209
Validation Loss: 1.7661
Best Top-1 Accuracy: 0.2676
Mean Reciprocal Rank: 0.4162
Validation Size: 1360
Top-1 Accuracy: 0.268
Top-5 Accuracy: 0.601
Top-10 Accuracy: 0.704
Top-50 Accuracy: 0.897

Epoch 6/10


Training: 100%|██████████| 38/38 [00:50<00:00,  1.32s/it]
Validation: 100%|██████████| 22/22 [00:18<00:00,  1.18it/s]



Epoch 6 Summary:
Training Loss: 0.0823
Validation Loss: 1.7657
Best Top-1 Accuracy: 0.2551
Mean Reciprocal Rank: 0.4033
Validation Size: 1360
Top-1 Accuracy: 0.255
Top-5 Accuracy: 0.594
Top-10 Accuracy: 0.696
Top-50 Accuracy: 0.893

Epoch 7/10


Training: 100%|██████████| 38/38 [00:50<00:00,  1.32s/it]
Validation: 100%|██████████| 22/22 [00:18<00:00,  1.18it/s]



Epoch 7 Summary:
Training Loss: 0.0735
Validation Loss: 1.7819
Best Top-1 Accuracy: 0.2669
Mean Reciprocal Rank: 0.4148
Validation Size: 1360
Top-1 Accuracy: 0.267
Top-5 Accuracy: 0.595
Top-10 Accuracy: 0.714
Top-50 Accuracy: 0.896

Epoch 8/10


Training: 100%|██████████| 38/38 [00:50<00:00,  1.32s/it]
Validation: 100%|██████████| 22/22 [00:18<00:00,  1.18it/s]



Epoch 8 Summary:
Training Loss: 0.0619
Validation Loss: 1.7203
Best Top-1 Accuracy: 0.2699
Mean Reciprocal Rank: 0.4208
Validation Size: 1360
Top-1 Accuracy: 0.270
Top-5 Accuracy: 0.618
Top-10 Accuracy: 0.726
Top-50 Accuracy: 0.900

Epoch 9/10


Training: 100%|██████████| 38/38 [00:50<00:00,  1.32s/it]
Validation: 100%|██████████| 22/22 [00:18<00:00,  1.18it/s]



Epoch 9 Summary:
Training Loss: 0.0558
Validation Loss: 1.7035
Best Top-1 Accuracy: 0.2713
Mean Reciprocal Rank: 0.4231
Validation Size: 1360
Top-1 Accuracy: 0.271
Top-5 Accuracy: 0.618
Top-10 Accuracy: 0.732
Top-50 Accuracy: 0.901

Epoch 10/10


Training: 100%|██████████| 38/38 [00:50<00:00,  1.32s/it]
Validation: 100%|██████████| 22/22 [00:18<00:00,  1.18it/s]
2024-11-10 13:25:55,518 - INFO - Training completed. Results saved to training_runs/run_20241110_131348



Epoch 10 Summary:
Training Loss: 0.0488
Validation Loss: 1.7250
Best Top-1 Accuracy: 0.2654
Mean Reciprocal Rank: 0.4211
Validation Size: 1360
Top-1 Accuracy: 0.265
Top-5 Accuracy: 0.618
Top-10 Accuracy: 0.726
Top-50 Accuracy: 0.899
