## lightining callback

In [None]:
!python -m pip install pytorch-lightning 

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.6.0-py3-none-any.whl.metadata (21 kB)
Collecting torch>=2.1.0 (from pytorch-lightning)
  Downloading torch-2.9.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting tqdm>=4.57.0 (from pytorch-lightning)
  Downloading tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Collecting fsspec>=2022.5.0 (from fsspec[http]>=2022.5.0->pytorch-lightning)
  Downloading fsspec-2026.1.0-py3-none-any.whl.metadata (10 kB)
Collecting torchmetrics>0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Collecting aiohttp!=4.0.0a0,!=4.0.0a1 (from fsspec[http]>=2022.5.0->pytorch-lightning)
  Downloading aiohttp-3.13.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl.metadata (8.1 kB)
Collecting aiohappyeyeballs>=2.5.0 (from a

In [6]:
!python -m pip install scikit-learn

Collecting scikit-learn
  Downloading scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (11 kB)
Collecting scipy>=1.10.0 (from scikit-learn)
  Downloading scipy-1.17.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (62 kB)
Collecting joblib>=1.3.0 (from scikit-learn)
  Downloading joblib-1.5.3-py3-none-any.whl.metadata (5.5 kB)
Collecting threadpoolctl>=3.2.0 (from scikit-learn)
  Downloading threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)
Downloading scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (8.9 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m8.9/8.9 MB[0m [31m22.4 MB/s[0m  [33m0:00:00[0mm0:00:01[0m00:01[0m
[?25hDownloading joblib-1.5.3-py3-none-any.whl (309 kB)
Downloading scipy-1.17.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (35.0 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚î

In PyTorch Lightning, the four required functions you need to implement are:

1. __init__: This function initializes the model and its parameters.
2. forward: This function defines the forward pass of the model, specifying how input data is transformed into output.
3. configure_optimizers: This function sets up the optimizer(s) and learning rate scheduler(s) for training.
4. training_step: This function defines a single step of training, including how to compute the loss from the input data.

In [None]:
import pytorch_lightning as pl
import torch
from sklearn.metrics import confusion_matrix, precision_recall_fscore_support
import numpy as np
import json
import os
from datetime import datetime

class DebugLogger(pl.Callback):
    def __init__(self, save_dir='./logs'):
        super().__init__()
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)
        
        self.metrics = []
        self.all_predictions = []
        self.all_labels = []
        self.misclassified_images = []
        self.misclassified_preds = []
        self.misclassified_labels = []
        self.confusion_matrices = []
        self.class_metrics = []

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        """Called at the end of each validation batch to collect predictions"""
        x, y = batch
        logits = pl_module(x)
        preds = torch.argmax(logits, dim=1)
        
        # Store all predictions and labels
        self.all_predictions.extend(preds.cpu().numpy().tolist())
        self.all_labels.extend(y.cpu().numpy().tolist())
        
        # Identify misclassified instances
        misclassified_mask = (preds != y)
        if misclassified_mask.any():
            self.misclassified_images.append(x[misclassified_mask].cpu().numpy().tolist())
            self.misclassified_preds.append(preds[misclassified_mask].cpu().numpy().tolist())
            self.misclassified_labels.append(y[misclassified_mask].cpu().numpy().tolist())

    def on_validation_epoch_end(self, trainer, pl_module):
        """Called at the end of validation epoch to compute metrics"""
        if self.all_predictions:
            # Compute confusion matrix
            cm = confusion_matrix(self.all_labels, self.all_predictions)
            self.confusion_matrices.append({
                'epoch': trainer.current_epoch,
                'confusion_matrix': cm.tolist()
            })
            
            # Compute class-wise metrics
            precision, recall, f1, support = precision_recall_fscore_support(
                self.all_labels, self.all_predictions, average=None
            )
            self.class_metrics.append({
                'epoch': trainer.current_epoch,
                'precision': precision.tolist(),
                'recall': recall.tolist(),
                'f1': f1.tolist(),
                'support': support.tolist()
            })
            
            # Clear for next epoch
            self.all_predictions = []
            self.all_labels = []

    def on_train_epoch_end(self, trainer, pl_module):
        """Log metrics at the end of each epoch"""
        metrics = trainer.callback_metrics
        epoch_data = {
            'epoch': trainer.current_epoch,
            'train_loss': float(metrics.get('train_loss', 0).item() if hasattr(metrics.get('train_loss', 0), 'item') else metrics.get('train_loss', 0)),
            'val_loss': float(metrics.get('val_loss', 0).item() if hasattr(metrics.get('val_loss', 0), 'item') else metrics.get('val_loss', 0)),
            'accuracy': float(metrics.get('accuracy', 0).item() if hasattr(metrics.get('accuracy', 0), 'item') else metrics.get('accuracy', 0)),
            'num_misclassified': len(self.misclassified_images),
        }
        self.metrics.append(epoch_data)
        print(f"Epoch {trainer.current_epoch}: {epoch_data}")
        
        # Clear misclassified for next epoch
        self.misclassified_images = []
        self.misclassified_preds = []
        self.misclassified_labels = []

    def on_train_end(self, trainer, pl_module):
        """Save epoch metrics to a single JSON file at the end of training"""
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        
        # Combine metrics with summary
        all_data = {
            'summary': {
                'timestamp': timestamp,
                'total_epochs': len(self.metrics),
                'best_accuracy': max([m['accuracy'] for m in self.metrics]) if self.metrics else 0,
                'best_epoch': max(range(len(self.metrics)), key=lambda i: self.metrics[i]['accuracy']) if self.metrics else 0,
                'final_train_loss': self.metrics[-1]['train_loss'] if self.metrics else 0,
                'final_val_loss': self.metrics[-1]['val_loss'] if self.metrics else 0,
            },
            'epochs': self.metrics
        }
        
        # Save single consolidated JSON file
        output_file = os.path.join(self.save_dir, f'training_log_{timestamp}.json')
        with open(output_file, 'w') as f:
            json.dump(all_data, f, indent=2)
        
        print(f"\n‚úì Training log saved to: {output_file}")
        print(f"\nüìä Final Summary:")
        print(f"  ‚Ä¢ Total Epochs: {all_data['summary']['total_epochs']}")
        print(f"  ‚Ä¢ Best Accuracy: {all_data['summary']['best_accuracy']:.4f} (Epoch {all_data['summary']['best_epoch']})")
        print(f"  ‚Ä¢ Final Train Loss: {all_data['summary']['final_train_loss']:.4f}")
        print(f"  ‚Ä¢ Final Val Loss: {all_data['summary']['final_val_loss']:.4f}")

In [5]:
# Initialize the model
model = SimpleCNN(learning_rate=0.001)

# Create the custom logger
debug_logger = DebugLogger()

# Create the trainer with max_epochs and pass data loaders
trainer = pl.Trainer(max_epochs=5, callbacks=[debug_logger], enable_progress_bar=True)

# Train the model with data loaders
trainer.fit(model, train_loader, val_loader)

üí° Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
/usr/local/python/3.12.1/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name    | Type      | Params | Mode  | FLOPs
------------------------------------------------------
0 | conv1   | Conv2d    | 896    | train | 0    
1 | conv2   | Conv2d    |

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Epoch 0: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1563/1563 [01:38<00:00, 15.91it/s, v_num=2]      Epoch 0: {'epoch': 0, 'train_loss': 0.9089827537536621, 'val_loss': 1.0930086374282837, 'accuracy': 0.6062999963760376, 'num_misclassified': 315}
Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1563/1563 [01:36<00:00, 16.24it/s, v_num=2]Epoch 1: {'epoch': 1, 'train_loss': 1.124178171157837, 'val_loss': 0.907422661781311, 'accuracy': 0.682699978351593, 'num_misclassified': 313}
Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1563/1563 [01:38<00:00, 15.79it/s, v_num=2]Epoch 2: {'epoch': 2, 'train_loss': 0.8515440821647644, 'val_loss': 0.805952787399292, 'accuracy': 0.7185999751091003, 'num_misclassified': 313}
Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1563/1563 [01:40<00:00, 15.63it/s, v_num=2]Epoch 3: {'epoch': 3, 'train_loss': 0.9061216115951538, 'val_loss': 0.7652513384819031, 'accuracy': 0.7365999817848206, 'num_misclassified': 313}
Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1563/1563 [0

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1563/1563 [01:39<00:00, 15.66it/s, v_num=2]


Based on the new data you are capturing, here are some additional metrics you might consider logging to enhance your debugging and analysis capabilities:

Overall Accuracy:

The percentage of correctly classified samples out of the total samples.
Precision, Recall, and F1 Score:

These metrics provide insights into the model's performance, especially in imbalanced datasets.
Loss Values:

Training and validation loss values for each epoch to monitor convergence.
Learning Rate:

Track the learning rate used during training, especially if using a scheduler.
Training Time:

Record the time taken for each epoch or batch to monitor performance and efficiency.
Gradient Norms:

Capture the norms of gradients to analyze if they are exploding or vanishing.
Model Weights:

Optionally log model weights at certain epochs to analyze weight changes over time.
Batch-wise Metrics:

Store metrics for each batch, such as loss and accuracy, to identify problematic batches.
Input Data Statistics:

Log statistics of input data (e.g., mean, standard deviation) to ensure data normalization is effective.
Epoch Count of Correct Predictions:

Count of correctly predicted samples for each epoch.
Class-wise Metrics:

Precision, recall, and F1 score for each class to identify specific classes that may be problematic.
Visualizations:
Save visualizations of loss curves, accuracy curves, or other relevant plots for analysis.
By capturing these additional metrics, you can gain a more comprehensive understanding of your model's performance and identify areas for improvement

## model

In [9]:
!python -m pip install torchvision

Collecting torchvision
  Downloading torchvision-0.24.1-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (5.9 kB)
Collecting pillow!=8.3.*,>=5.3.0 (from torchvision)
  Downloading pillow-12.1.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (8.8 kB)
Downloading torchvision-0.24.1-cp312-cp312-manylinux_2_28_x86_64.whl (8.0 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m8.0/8.0 MB[0m [31m17.7 MB/s[0m  [33m0:00:00[0mm0:00:01[0m0:02[0m
[?25hDownloading pillow-12.1.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (7.0 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m7.0/7.0 MB[0m [31m35.4 MB/s[0m  [33m0:00:00[0m
[?25hInstalling collected packages: pillow, torchvision
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚î

In [3]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

# Define data augmentation and normalization
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR-10 dataset
train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
val_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=2)

print(f"Train dataset size: {len(train_dataset)}")
print(f"Val dataset size: {len(val_dataset)}")

  entry = pickle.load(f, encoding="latin1")


Train dataset size: 50000
Val dataset size: 10000


In [4]:
# Define a simple CNN model using PyTorch Lightning
class SimpleCNN(pl.LightningModule):
    def __init__(self, learning_rate=0.001):
        super().__init__()
        self.learning_rate = learning_rate
        self.save_hyperparameters()
        
        # Convolutional layers
        self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
        
        # Pooling
        self.pool = torch.nn.MaxPool2d(2, 2)
        
        # Fully connected layers
        self.fc1 = torch.nn.Linear(128 * 4 * 4, 256)
        self.fc2 = torch.nn.Linear(256, 10)
        
        # Dropout for regularization
        self.dropout = torch.nn.Dropout(0.5)
    
    def forward(self, x):
        # Conv block 1: 32 filters -> 16x16
        x = self.pool(F.relu(self.conv1(x)))
        
        # Conv block 2: 64 filters -> 8x8
        x = self.pool(F.relu(self.conv2(x)))
        
        # Conv block 3: 128 filters -> 4x4
        x = self.pool(F.relu(self.conv3(x)))
        
        # Flatten
        x = x.view(x.size(0), -1)
        
        # Fully connected
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('train_loss', loss)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        preds = torch.argmax(logits, dim=1)
        acc = (preds == y).float().mean()
        self.log('val_loss', loss)
        self.log('accuracy', acc)
        return loss

In [None]:
# Load and display training data from JSON
import pandas as pd
import matplotlib.pyplot as plt
import json
from pathlib import Path

# Find the latest JSON file
log_dir = Path('./logs')
if log_dir.exists():
    json_files = sorted(log_dir.glob('training_log_*.json'))
    if json_files:
        latest_file = json_files[-1]
        
        # Load JSON data
        with open(latest_file, 'r') as f:
            training_data = json.load(f)
        
        print(f"üìÇ Loaded: {latest_file.name}")
        print(f"\n{'='*70}")
        print(f"{'TRAINING SUMMARY':^70}")
        print(f"{'='*70}")
        
        summary = training_data['summary']
        print(f"Timestamp:        {summary['timestamp']}")
        print(f"Total Epochs:     {summary['total_epochs']}")
        print(f"Best Accuracy:    {summary['best_accuracy']:.4f} (Epoch {summary['best_epoch']})")
        print(f"Final Train Loss: {summary['final_train_loss']:.4f}")
        print(f"Final Val Loss:   {summary['final_val_loss']:.4f}")
        print(f"{'='*70}\n")
        
        # Create DataFrame from epoch metrics
        metrics_df = pd.DataFrame(training_data['epochs'])
        print("Epoch Metrics:")
        print(metrics_df.to_string(index=False))
        
        # Plot metrics
        if len(metrics_df) > 0:
            fig, axes = plt.subplots(1, 2, figsize=(12, 4))
            
            axes[0].plot(metrics_df['epoch'], metrics_df['train_loss'], label='Train Loss', marker='o')
            axes[0].plot(metrics_df['epoch'], metrics_df['val_loss'], label='Val Loss', marker='o')
            axes[0].set_xlabel('Epoch')
            axes[0].set_ylabel('Loss')
            axes[0].set_title('Training and Validation Loss')
            axes[0].legend()
            axes[0].grid(True)
            
            axes[1].plot(metrics_df['epoch'], metrics_df['accuracy'], label='Accuracy', marker='o', color='green')
            axes[1].set_xlabel('Epoch')
            axes[1].set_ylabel('Accuracy')
            axes[1].set_title('Validation Accuracy')
            axes[1].legend()
            axes[1].grid(True)
            
            plt.tight_layout()
            plt.show()
    else:
        print("No training log files found. Run training first!")
else:
    print("Logs directory not found. Run training first!")