# Text Classification Model Training
This notebook demonstrates how to train a text classification model using the `llm-trainer` framework with mixed precision and tqdm progress tracking.


In [None]:
import torch
from datasets import load_dataset
from llm_trainer.config import ModelConfig, TrainingConfig, DataConfig
from llm_trainer.models import TransformerLM
from llm_trainer.tokenizer import create_tokenizer
from llm_trainer.training import Trainer

# Seed PyTorch's random number generator so that weight initialization and
# other torch random draws are repeatable after Restart Kernel -> Run All.
SEED = 42
torch.manual_seed(SEED)

# Set device: use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


## 1. Load and Preprocess Dataset
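We'll load a 1,000-example subset of the IMDB sentiment dataset's training split and train a BPE tokenizer (vocabulary size 5,000) on its `text` column.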


In [None]:
# Load a small classification dataset (IMDB sentiment analysis).
# The IMDB train split is stored grouped by label, so a plain `train[:1000]`
# slice would contain reviews of a single class. Shuffle with a fixed seed
# first and then keep 1,000 examples, which gives a reproducible subset that
# covers both sentiments.
dataset = load_dataset("imdb", split="train").shuffle(seed=42).select(range(1000))
print(f"Dataset size: {len(dataset)}")

# Initialize a BPE tokenizer and train it on the "text" column of the subset
# with a target vocabulary size of 5000.
tokenizer = create_tokenizer("bpe")
tokenizer.train(dataset, vocab_size=5000, text_column="text")


## 2. Configure Model and Training with Mixed Precision
Mixed precision makes training more efficient: we select `bf16` on GPUs with compute capability 8.0 or higher and `fp16` on older GPUs.


In [None]:
# Configure model: the vocabulary size is taken from the trained tokenizer so
# the model and tokenizer agree.
model_config = ModelConfig(
    vocab_size=tokenizer.vocab_size,
    d_model=256,
    n_heads=4,
    n_layers=4
)

# Choose the mixed-precision mode from the available hardware:
#   - bf16 on CUDA GPUs with compute capability major version >= 8
#   - fp16 on older CUDA GPUs
#   - neither on CPU-only machines
# fp16 is gated on CUDA explicitly; deriving it as `not use_bf16` would also
# request fp16 when no GPU is present.
cuda_available = torch.cuda.is_available()
use_bf16 = cuda_available and torch.cuda.get_device_capability()[0] >= 8
use_fp16 = cuda_available and not use_bf16

training_config = TrainingConfig(
    batch_size=8,
    learning_rate=1e-4,
    num_epochs=3,
    fp16=use_fp16,
    bf16=use_bf16,
    logging_steps=10,
    checkpoint_dir="./checkpoints/classification"
)

# Initialize model from the configuration above and report its parameter count
model = TransformerLM(model_config)
print(f"Model initialized with {model.get_num_params():,} parameters")


## 3. Initialize Trainer and Run Training with tqdm
The trainer will use `tqdm` for progress tracking.


In [None]:
# Initialize trainer with the model, tokenizer, and training configuration
# created in the earlier cells (those cells must have been run first).
trainer = Trainer(
    model=model,
    tokenizer=tokenizer,
    config=training_config
)

# Configure data: dataset name, `max_length`, and the column holding the text.
# NOTE(review): this refers to "imdb" by name instead of passing the `dataset`
# object loaded earlier, so training presumably does not use that 1,000-example
# subset -- confirm against `train_from_config`.
data_config = DataConfig(
    dataset_name="imdb",
    max_length=512,
    text_column="text"
)

# Start training (this will show tqdm progress bars)
# NOTE(review): `model_config` is passed again although `trainer` already holds
# `model`; verify whether `train_from_config` reuses that model or builds a new one.
trainer.train_from_config(model_config, data_config)
