### **TensorBoard Integration with HuggingFace Trainer**
---
A comprehensive guide to leveraging TensorBoard features with HuggingFace models

Features covered:
1. Scalars (loss, metrics, learning rate)
2. Images (input samples, attention maps)
3. Embeddings (high-dimensional visualization)
4. Graphs (model architecture)
5. Histograms (weight distributions)
6. Hyperparameters (HP tuning comparison)
7. Confusion Matrix
8. PR Curves (not logged by the cells below; `SummaryWriter.add_pr_curve` can be used to add them)

In [None]:
# ============================================================================
# INSTALLATION & IMPORTS
# ============================================================================

# !pip install transformers datasets tensorboard torch torchvision evaluate scikit-learn

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    TrainerCallback
)
from datasets import load_dataset
import numpy as np
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List
import json
from datetime import datetime

In [None]:
# ============================================================================
# 1. SETUP: MODEL, DATASET & TOKENIZER
# ============================================================================

print("Loading model and dataset...")

# Using IMDB for sentiment analysis (binary classification)
MODEL_NAME = "distilbert-base-uncased"
DATASET_NAME = "imdb"
NUM_LABELS = 2
MAX_LENGTH = 256
BATCH_SIZE = 16
NUM_EPOCHS = 2
TRAIN_SAMPLES = 1000  # size of the training subset
TEST_SAMPLES = 200    # size of the evaluation subset
SEED = 42             # makes the subsampling reproducible

# Load the full splits, then shuffle before subsampling.
# A plain head slice ('train[:1000]') is not a random sample: IMDB on the Hub
# is ordered by label, so the head of each split is dominated by one class.
# That leaves the model with nothing to learn and makes the confusion matrix
# / classification report degenerate.
dataset = load_dataset(DATASET_NAME, split={'train': 'train', 'test': 'test'})
for split_name, subset_size in (('train', TRAIN_SAMPLES), ('test', TEST_SAMPLES)):
    shuffled_split = dataset[split_name].shuffle(seed=SEED)
    dataset[split_name] = shuffled_split.select(
        range(min(subset_size, len(shuffled_split)))
    )

# Load tokenizer and model.
# NOTE: `output_attentions` / `output_hidden_states` are intentionally NOT
# enabled in the model config. With them on, every forward pass returns the
# extra tensors, so Trainer's evaluation loop accumulates them for the whole
# eval set and passes a tuple (instead of the logits array) to
# `compute_metrics`. The visualization and embedding code requests these
# outputs explicitly on the individual forward calls that need them.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_LABELS,
)


# Tokenize dataset
def tokenize_function(examples):
    """Tokenize a batch of examples, padding/truncating to MAX_LENGTH tokens."""
    return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=MAX_LENGTH)


tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [None]:
# ============================================================================
# 2. CUSTOM TENSORBOARD CALLBACK
# ============================================================================

class AdvancedTensorBoardCallback(TrainerCallback):
    """
    Enhanced TensorBoard callback for comprehensive logging.

    Writes to its own ``SummaryWriter``:
      * the model graph and a JSON dump of key hyperparameters (train begin),
      * numeric training log values under ``train/`` (on log),
      * per-parameter weight and gradient histograms (epoch end),
      * numeric evaluation metrics under ``eval/`` (on evaluate).

    Args:
        log_dir: Directory the TensorBoard event files are written to.
    """

    def __init__(self, log_dir="./runs"):
        self.log_dir = log_dir
        self.writer = SummaryWriter(log_dir)
        # Not read by any hook in this class; kept so existing attribute
        # access keeps working.
        self.step = 0
        self.predictions_cache = []
        self.labels_cache = []

    def on_train_begin(self, args, state, control, model=None, **kwargs):
        """Log model graph and hyperparameters"""

        # 1. LOG MODEL GRAPH
        if model is not None:
            print("Logging model graph...")
            dummy_input = torch.randint(0, 1000, (1, MAX_LENGTH)).to(model.device)
            dummy_attention = torch.ones((1, MAX_LENGTH)).to(model.device)

            # Graph tracing is best-effort: a failure here must not abort
            # training, so any error is reported and swallowed on purpose.
            try:
                self.writer.add_graph(
                    model,
                    (dummy_input, dummy_attention)
                )
            except Exception as e:
                print(f"Could not log graph: {e}")

        # 2. LOG HYPERPARAMETERS
        hparams = {
            'learning_rate': args.learning_rate,
            'batch_size': args.per_device_train_batch_size,
            'epochs': args.num_train_epochs,
            'weight_decay': args.weight_decay,
            'warmup_steps': args.warmup_steps,
            'model_name': MODEL_NAME,
        }

        # Log as text
        hparam_text = json.dumps(hparams, indent=2)
        self.writer.add_text('Hyperparameters', hparam_text, 0)

        print("Training setup logged to TensorBoard")

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Log numeric training values as scalars under ``train/``."""
        if logs is None:
            return
        for key, value in logs.items():
            # Evaluation metrics also pass through this hook; they are
            # written under ``eval/`` by on_evaluate, so skip them here to
            # avoid duplicate, mislabeled ``train/eval_*`` tags.
            if key.startswith('eval_'):
                continue
            if isinstance(value, (int, float)):
                self.writer.add_scalar(f'train/{key}', value, state.global_step)

    def on_epoch_end(self, args, state, control, model=None, **kwargs):
        """Log weight histograms (and gradients, when present) at epoch end"""
        if model is None:
            return

        print(f"Logging weight distributions for epoch {state.epoch}...")

        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue

            # Weights are always available. They must not be gated on the
            # gradient: gradients are typically cleared after the optimizer
            # step, which previously suppressed the weight histograms too.
            self.writer.add_histogram(
                f'weights/{name}',
                param.detach().float().cpu(),
                state.global_step
            )

            # Gradients are logged only if they still exist at this point.
            if param.grad is not None:
                self.writer.add_histogram(
                    f'gradients/{name}',
                    param.grad.detach().float().cpu(),
                    state.global_step
                )

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Log numeric evaluation metrics as scalars under ``eval/``."""
        if metrics is None:
            return
        for key, value in metrics.items():
            if isinstance(value, (int, float)):
                self.writer.add_scalar(f'eval/{key}', value, state.global_step)
        # Push the metrics to disk right away so evaluations run after
        # training has ended are not left sitting in the writer's buffer.
        self.writer.flush()

    def on_train_end(self, args, state, control, model=None, **kwargs):
        """Flush and close the writer once training has finished."""
        print("Training complete. Generating final visualizations...")
        self.writer.close()

In [None]:
# ============================================================================
# 3. CUSTOM CALLBACK FOR ADVANCED VISUALIZATIONS
# ============================================================================

class VisualizationCallback(TrainerCallback):
    """
    Callback for advanced visualizations: attention maps, sample predictions
    and a confusion matrix with a classification report.

    Args:
        eval_dataset: Tokenized dataset whose items expose ``input_ids``,
            ``attention_mask`` and ``label``.
        tokenizer: Tokenizer used to decode ``input_ids`` back to text.
        log_dir: Directory the TensorBoard event files are written to.
    """

    # LABEL_NAMES[i] is the display name of class id i.
    LABEL_NAMES = ("Negative", "Positive")

    def __init__(self, eval_dataset, tokenizer, log_dir="./runs"):
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer
        self.writer = SummaryWriter(log_dir)
        self.all_predictions = []
        self.all_labels = []

    def on_evaluate(self, args, state, control, model=None, metrics=None, **kwargs):
        """Generate visualizations after an evaluation pass.

        The confusion matrix is logged on every evaluation; the attention
        map and sample predictions only when the global step is a multiple
        of 100. (This used to be two methods with the same name, so the
        second definition hid the first and the attention/sample figures
        were never produced.)
        """
        if model is None or len(self.eval_dataset) == 0:
            return

        step = state.global_step
        was_training = model.training
        model.eval()
        try:
            if step % 100 == 0:  # Log every 100 steps
                self._log_attention_maps(model, step)
                self._log_sample_predictions(model, step)
            self._generate_confusion_matrix(model, step)
        finally:
            # Leave the model in the mode it was found in and make sure the
            # figures reach disk.
            model.train(was_training)
            self.writer.flush()

    @staticmethod
    def _model_inputs(sample, device):
        """Turn one dataset item into batch-of-one input tensors on `device`."""
        input_ids = torch.tensor([sample['input_ids']]).to(device)
        attention_mask = torch.tensor([sample['attention_mask']]).to(device)
        return input_ids, attention_mask

    def _predict_label(self, model, sample):
        """Return the argmax class id the model predicts for one dataset item."""
        input_ids, attention_mask = self._model_inputs(sample, model.device)
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
        return torch.argmax(outputs.logits, dim=-1).item()

    def _log_attention_maps(self, model, step):
        """Visualize attention weights of one randomly chosen sample"""

        print(f"Logging attention maps at step {step}...")

        model.eval()
        # Get a sample
        sample_idx = np.random.randint(0, len(self.eval_dataset))
        sample = self.eval_dataset[sample_idx]
        input_ids, attention_mask = self._model_inputs(sample, model.device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask, output_attentions=True)
            attentions = outputs.attentions  # Tuple of attention weights per layer

        # Visualize last layer attention (head 0)
        last_layer_attention = attentions[-1][0, 0].cpu().numpy()  # [seq_len, seq_len]

        # Create heatmap (first 50 positions only, to keep it readable)
        fig, ax = plt.subplots(figsize=(10, 10))
        sns.heatmap(last_layer_attention[:50, :50], cmap='viridis', ax=ax, cbar=True)
        ax.set_title('Attention Weights (Last Layer, Head 0)')
        ax.set_xlabel('Key Position')
        ax.set_ylabel('Query Position')

        # Log to tensorboard
        self.writer.add_figure('attention/last_layer_head_0', fig, step)
        plt.close(fig)

    def _log_sample_predictions(self, model, step):
        """Log predictions for the first few eval samples as one figure"""

        print(f"Logging sample predictions at step {step}...")

        model.eval()
        samples = []

        for i in range(min(5, len(self.eval_dataset))):
            sample = self.eval_dataset[i]
            prediction = self._predict_label(model, sample)

            # Decode text
            text = self.tokenizer.decode(sample['input_ids'], skip_special_tokens=True)
            text = text[:100] + "..." if len(text) > 100 else text

            samples.append({
                'text': text,
                'true': self.LABEL_NAMES[sample['label']],
                'pred': self.LABEL_NAMES[prediction],
                'correct': prediction == sample['label']
            })

        if not samples:
            return

        # One row per collected sample; squeeze=False keeps `axes` 2-D even
        # when there is a single row.
        fig, axes = plt.subplots(len(samples), 1, figsize=(12, 2 * len(samples)),
                                 squeeze=False)
        fig.suptitle('Sample Predictions', fontsize=16)

        for ax, sample in zip(axes[:, 0], samples):
            ax.axis('off')
            color = 'green' if sample['correct'] else 'red'
            text_content = f"Text: {sample['text']}\nTrue: {sample['true']} | Pred: {sample['pred']}"
            ax.text(0.05, 0.5, text_content, fontsize=10, verticalalignment='center',
                   bbox=dict(boxstyle='round', facecolor=color, alpha=0.3))

        plt.tight_layout()
        self.writer.add_figure('predictions/samples', fig, step)
        plt.close(fig)

    def on_prediction_step(self, args, state, control, **kwargs):
        """No-op placeholder; predictions are recomputed in on_evaluate."""
        pass

    def _generate_confusion_matrix(self, model, step):
        """Generate and log confusion matrix and classification report"""

        print(f"Generating confusion matrix at step {step}...")

        model.eval()
        all_predictions = []
        all_labels = []

        # Get predictions for eval set
        for i in range(len(self.eval_dataset)):
            sample = self.eval_dataset[i]
            all_predictions.append(self._predict_label(model, sample))
            all_labels.append(sample['label'])

        # Pin the class ids explicitly so the matrix is always
        # len(LABEL_NAMES) x len(LABEL_NAMES) and the report does not fail
        # when a class is missing from both labels and predictions.
        class_ids = list(range(len(self.LABEL_NAMES)))
        class_names = list(self.LABEL_NAMES)

        # Create confusion matrix
        cm = confusion_matrix(all_labels, all_predictions, labels=class_ids)

        # Visualize
        fig, ax = plt.subplots(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax,
                   xticklabels=class_names,
                   yticklabels=class_names)
        ax.set_title('Confusion Matrix')
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')

        self.writer.add_figure('evaluation/confusion_matrix', fig, step)
        plt.close(fig)

        # Log classification report as text
        report = classification_report(all_labels, all_predictions,
                                      labels=class_ids,
                                      target_names=class_names,
                                      zero_division=0)
        self.writer.add_text('evaluation/classification_report', f"```\n{report}\n```", step)

In [None]:
# ============================================================================
# 4. EMBEDDING VISUALIZATION CALLBACK
# ============================================================================

class EmbeddingCallback(TrainerCallback):
    """
    Log embeddings for visualization in TensorBoard's projector.

    Args:
        eval_dataset: Tokenized dataset whose items expose ``input_ids``,
            ``attention_mask`` and ``label``.
        tokenizer: Tokenizer used to decode ``input_ids`` for the metadata.
        log_dir: Directory the TensorBoard event files are written to.
    """

    def __init__(self, eval_dataset, tokenizer, log_dir="./runs"):
        self.eval_dataset = eval_dataset
        self.tokenizer = tokenizer
        self.writer = SummaryWriter(log_dir)

    def on_evaluate(self, args, state, control, model=None, **kwargs):
        """Extract and log embeddings for up to 100 eval samples.

        Runs when the global step is a multiple of 200, and additionally
        once the final training step has been reached, so that runs shorter
        than 200 steps still produce a projector entry.
        """
        if model is None:
            return

        step = state.global_step
        max_steps = getattr(state, "max_steps", 0)
        reached_final_step = max_steps > 0 and step >= max_steps
        if step % 200 != 0 and not reached_final_step:  # Log less frequently
            return

        # Extract embeddings for subset of data
        num_samples = min(100, len(self.eval_dataset))
        if num_samples == 0:
            return

        print(f"Logging embeddings at step {step}...")

        was_training = model.training
        model.eval()
        embeddings_list = []
        metadata_list = []

        try:
            for i in range(num_samples):
                sample = self.eval_dataset[i]
                input_ids = torch.tensor([sample['input_ids']]).to(model.device)
                attention_mask = torch.tensor([sample['attention_mask']]).to(model.device)

                with torch.no_grad():
                    outputs = model(input_ids, attention_mask=attention_mask,
                                   output_hidden_states=True)
                    # Use the first-token ([CLS]) vector of the last hidden state
                    embedding = outputs.hidden_states[-1][0, 0, :].cpu().numpy()

                embeddings_list.append(embedding)

                # Metadata row: label plus the first 50 chars of the text,
                # so points can be colored by label in the projector.
                text = self.tokenizer.decode(sample['input_ids'], skip_special_tokens=True)
                metadata_list.append([str(sample['label']), text[:50]])
        finally:
            # Leave the model in the mode it was found in.
            model.train(was_training)

        # Stack embeddings
        embeddings_tensor = torch.tensor(np.array(embeddings_list))

        # Log to tensorboard
        self.writer.add_embedding(
            embeddings_tensor,
            metadata=metadata_list,
            metadata_header=['label', 'text'],
            label_img=None,
            global_step=step,
            tag='sentence_embeddings'
        )
        self.writer.flush()


In [None]:
# ============================================================================
# 5. TRAINING CONFIGURATION
# ============================================================================

# Create log directory with timestamp
log_dir = f"./runs/hf_tensorboard_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

# The evaluation-schedule argument is called `eval_strategy` in newer
# transformers releases and `evaluation_strategy` in older ones. Pick the
# name the installed TrainingArguments dataclass actually declares so this
# cell works with either.
_eval_strategy_arg = (
    "eval_strategy"
    if "eval_strategy" in getattr(TrainingArguments, "__dataclass_fields__", {})
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./results",
    eval_steps=50,
    # Save on the same schedule as evaluation; required for
    # load_best_model_at_end.
    save_strategy="steps",
    save_steps=50,
    learning_rate=2e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    num_train_epochs=NUM_EPOCHS,
    weight_decay=0.01,
    warmup_steps=100,
    logging_dir=log_dir,
    logging_steps=10,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    report_to="tensorboard",  # Enable TensorBoard
    **{_eval_strategy_arg: "steps"},
)

# Compute metrics
def compute_metrics(eval_pred):
    """Return classification accuracy for a Trainer evaluation pass.

    Args:
        eval_pred: A ``(predictions, labels)`` pair. ``predictions`` is
            either the logits array of shape (num_samples, num_classes) or
            a tuple/list whose first element is that array (which is what
            arrives when the model returns extra outputs such as hidden
            states or attentions).

    Returns:
        ``{"accuracy": float}`` - the fraction of samples whose argmax
        class equals the label.
    """
    predictions, labels = eval_pred
    # Models that emit extra outputs hand over a tuple; the logits come first.
    if isinstance(predictions, (tuple, list)):
        predictions = predictions[0]
    predicted_classes = np.argmax(predictions, axis=1)
    accuracy = (predicted_classes == np.asarray(labels)).mean()
    # Plain float keeps the result JSON-serializable and logger-friendly.
    return {"accuracy": float(accuracy)}

In [None]:
# ============================================================================
# 6. INITIALIZE TRAINER WITH CALLBACKS
# ============================================================================

# Both dataset-aware callbacks and the Trainer itself evaluate on the same split.
eval_split = tokenized_datasets["test"]

# All callbacks write their event files into the shared run directory.
tensorboard_callbacks = [
    AdvancedTensorBoardCallback(log_dir=log_dir),
    VisualizationCallback(eval_split, tokenizer, log_dir=log_dir),
    EmbeddingCallback(eval_split, tokenizer, log_dir=log_dir),
]

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=eval_split,
    compute_metrics=compute_metrics,
    callbacks=tensorboard_callbacks,
)

In [None]:
# ============================================================================
# 7. TRAIN THE MODEL
# ============================================================================

# Announce the run and where its logs go, then start training.
train_banner_rule = "=" * 80
for banner_line in (
    train_banner_rule,
    "Starting training with TensorBoard logging...",
    f"TensorBoard logs will be saved to: {log_dir}",
    train_banner_rule,
):
    print(banner_line)

trainer.train()

In [None]:
# ============================================================================
# 8. FINAL EVALUATION & VISUALIZATION
# ============================================================================

# Banner, one more evaluation pass, then the metrics as pretty-printed JSON.
eval_banner_rule = "=" * 80
print("\n" + eval_banner_rule, "Running final evaluation...", eval_banner_rule, sep="\n")

eval_results = trainer.evaluate()
print("\nFinal Evaluation Results:", json.dumps(eval_results, indent=2), sep="\n")

In [None]:
# ============================================================================
# 9. LAUNCH TENSORBOARD
# ============================================================================

# Closing summary: how to open TensorBoard and which tabs to look at.
summary_rule = "=" * 80
summary_lines = [
    "\n" + summary_rule,
    "Training complete!",
    summary_rule,
    "\nTo view TensorBoard, run:",
    f"  tensorboard --logdir {log_dir}",
    "\nThen open: http://localhost:6006",
    "\nFeatures available in TensorBoard:",
    "  • SCALARS: Loss, accuracy, learning rate curves",
    "  • IMAGES: Attention heatmaps, sample predictions",
    "  • GRAPHS: Model architecture",
    "  • DISTRIBUTIONS: Weight and gradient distributions",
    "  • HISTOGRAMS: Weight changes over time",
    "  • PROJECTOR: Embedding visualization (PCA/t-SNE)",
    "  • TEXT: Hyperparameters, classification reports",
    "  • HPARAMS: Hyperparameter comparison (run multiple experiments)",
    summary_rule,
]
print("\n".join(summary_lines))


In [None]:

# Optional: Launch TensorBoard programmatically (uncomment to use)
# %load_ext tensorboard
# %tensorboard --logdir {log_dir}