In [None]:
#!/usr/bin/env python3
"""
Binary Neural Network with Dual Attention for Disease Classification
Based on: "A Binary Neural Network with Dual Attention for Plant Disease Classification"
DOI: https://doi.org/10.3390/electronics12214431

This implementation trains a BNN model to classify diseases into 4 classes using 64000 images (16000 per class).
Optimized for GPU training with progress tracking and maximum accuracy.
"""

In [None]:
import os
import time
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import Dataset, DataLoader, Subset, random_split
from torchvision.datasets import ImageFolder
warnings.filterwarnings('ignore')

In [None]:
# Pick the compute device (first CUDA GPU when one is visible, CPU otherwise)
# and report what was chosen.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / (1 << 30):.1f} GB")

In [None]:
# Hyperparameters sized for a GPU with ~4 GB of memory
BATCH_SIZE = 32  # Reduced from 128 to fit in GPU memory
LEARNING_RATE = 0.001
NUM_EPOCHS = 100
NUM_CLASSES = 4
IMG_SIZE = 128  # Reduced from 224 to save memory
NUM_WORKERS = 4  # Reduced to save system memory
# Page-locked host memory only speeds up host->GPU copies; it is useless on a
# CPU-only machine, so enable it only when CUDA is available.
PIN_MEMORY = torch.cuda.is_available()

# Data splits (70% train, 20% val, 10% test).
# create_data_loaders() uses TRAIN_RATIO and VAL_RATIO directly and gives the
# remainder to the test split, so TEST_RATIO is informational. It must agree
# with the other two.
TRAIN_RATIO = 0.7
VAL_RATIO = 0.2
TEST_RATIO = 0.1
assert abs(TRAIN_RATIO + VAL_RATIO + TEST_RATIO - 1.0) < 1e-9, "split ratios must sum to 1"

In [None]:
class BinaryActivation(torch.autograd.Function):
    """Binarize a tensor to {-1, +1}, with a straight-through gradient.

    Forward maps ``x >= 0`` to +1 and ``x < 0`` to -1. Plain ``torch.sign`` would
    map exact zeros to 0, which is not a binary value.

    Backward passes the incoming gradient through unchanged where ``-1 < x < 1``
    and zeroes it elsewhere.

    Call it as ``BinaryActivation.apply(x)``.

    NOTE(review): no layer in this notebook calls it, so DABNN binarizes only its
    weights, not its activations.
    """

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        ones = torch.ones_like(input)
        return torch.where(input >= 0, ones, -ones)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        # Straight-through estimator: keep the gradient only inside (-1, 1).
        pass_through = (input.abs() < 1).to(grad_output.dtype)
        return grad_output * pass_through

In [None]:
class BinaryLinear(nn.Module):
    """Fully connected layer whose weights are binarized to {-1, +1} in the forward pass.

    ``weight`` is a latent full-precision parameter.

    Forward: uses +1 where ``weight >= 0`` and -1 elsewhere.

    Backward: uses a straight-through estimator. The gradient reaches a latent
    weight only while it lies in the clamp range [-1, 1]. Plain ``torch.sign``
    has a zero gradient everywhere, so the latent weights would never be trained.

    The bias, if present, stays full precision.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: If True, add a learnable full-precision bias initialized to zero.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(BinaryLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.1)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter('bias', None)

    def forward(self, input):
        latent = self.weight
        clipped = latent.clamp(-1.0, 1.0)
        hard = torch.where(latent >= 0, torch.ones_like(latent), -torch.ones_like(latent))
        # Value is exactly `hard`, since (clipped - clipped.detach()) == 0.
        # The gradient is that of `clipped`, i.e. the straight-through estimator.
        binary_weight = hard + (clipped - clipped.detach())
        return F.linear(input, binary_weight, self.bias)

In [None]:
class BinaryConv2d(nn.Module):
    """2-D convolution whose weights are binarized to {-1, +1} in the forward pass.

    ``weight`` is a latent full-precision parameter.

    Forward: convolves with +1 where ``weight >= 0`` and -1 elsewhere.

    Backward: uses a straight-through estimator. The gradient reaches a latent
    weight only while it lies in the clamp range [-1, 1]. Plain ``torch.sign``
    has a zero gradient everywhere, so the latent weights would never be updated.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        kernel_size: Side length of the square kernel (int).
        stride: Convolution stride.
        padding: Zero-padding added to both sides.
        bias: If True, add a learnable full-precision bias initialized to zero.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=True):
        super(BinaryConv2d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.weight = nn.Parameter(torch.randn(out_channels, in_channels, kernel_size, kernel_size) * 0.1)
        if bias:
            self.bias = nn.Parameter(torch.zeros(out_channels))
        else:
            self.register_parameter('bias', None)

    def forward(self, input):
        latent = self.weight
        clipped = latent.clamp(-1.0, 1.0)
        hard = torch.where(latent >= 0, torch.ones_like(latent), -torch.ones_like(latent))
        # Value is exactly `hard`; the gradient is that of `clipped` (straight-through).
        binary_weight = hard + (clipped - clipped.detach())
        return F.conv2d(input, binary_weight, self.bias, self.stride, self.padding)

In [None]:
class ChannelAttention(nn.Module):
    """Channel attention gate.

    The feature map is reduced to one value per channel by both global average
    and global max pooling. Both vectors go through a shared bottleneck MLP
    (Linear -> PReLU -> Linear). The outputs are summed, squashed with a sigmoid,
    and used to rescale each channel of the input.

    Args:
        in_channels: Channel count of the input feature map.
        reduction: Bottleneck reduction factor. The hidden width is
            ``max(1, in_channels // reduction)``, so a small channel count never
            produces a zero-width bottleneck (which would make the gate a
            constant 0.5).
    """

    def __init__(self, in_channels, reduction=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        hidden = max(1, in_channels // reduction)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, hidden, bias=False),
            nn.PReLU(),
            nn.Linear(hidden, in_channels, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c = x.shape[:2]
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        gate = self.sigmoid(avg_out + max_out)
        return gate.view(b, c, 1, 1) * x

In [None]:
class SpatialAttention(nn.Module):
    """Spatial attention gate.

    For every spatial position, the channel-wise mean and channel-wise max are
    stacked into a 2-channel map. That map is convolved down to one channel with
    a ``kernel_size`` x ``kernel_size`` kernel and "same" padding, passed through
    a sigmoid, and multiplied onto the input.

    Args:
        kernel_size: Side of the square gating kernel (odd sizes keep H and W unchanged).
    """

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        same_pad = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=same_pad, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        descriptor = torch.cat([channel_mean, channel_max], dim=1)
        gate = self.sigmoid(self.conv(descriptor))
        return gate * x

In [None]:
class DualAttention(nn.Module):
    """Apply ChannelAttention first, then SpatialAttention, to a feature map.

    Args:
        in_channels: Channel count of the incoming feature map (passed to ChannelAttention).
    """

    def __init__(self, in_channels):
        super(DualAttention, self).__init__()
        self.channel_attention = ChannelAttention(in_channels)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        return self.spatial_attention(self.channel_attention(x))

In [None]:
class DABConv(nn.Module):
    """Dual Attention Binary Convolution unit.

    Applies, in order: BinaryConv2d -> BatchNorm2d -> DualAttention -> PReLU.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        kernel_size: Square kernel side for the binary convolution.
        stride: Convolution stride.
        padding: Convolution zero-padding.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(DABConv, self).__init__()
        self.binary_conv = BinaryConv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(out_channels)
        self.dual_attention = DualAttention(out_channels)
        self.prelu = nn.PReLU()

    def forward(self, x):
        normalized = self.bn(self.binary_conv(x))
        attended = self.dual_attention(normalized)
        return self.prelu(attended)

In [None]:
class BNNBasicBlock(nn.Module):
    """Residual block: four DABConv units in sequence plus a skip connection.

    Main branch (run strictly in order): 3x3 DABConv (carries ``stride``) ->
    1x1 DABConv -> 3x3 DABConv -> 1x1 DABConv.

    Skip branch: the identity when the shape is unchanged. Otherwise a
    full-precision 1x1 Conv2d + BatchNorm2d projection matches the output shape.

    The block returns main + skip, with no activation after the addition.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        stride: Spatial stride of the first 3x3 unit and of the projection.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(BNNBasicBlock, self).__init__()

        # Main branch: sequential stages, not parallel paths.
        self.dabconv1 = DABConv(in_channels, out_channels, 3, stride, 1)
        self.dabconv2 = DABConv(out_channels, out_channels, 1, 1, 0)
        self.dabconv3 = DABConv(out_channels, out_channels, 3, 1, 1)
        self.dabconv4 = DABConv(out_channels, out_channels, 1, 1, 0)

        # Skip branch: project only when the spatial size or channel count changes.
        shape_changes = stride != 1 or in_channels != out_channels
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        skip = self.shortcut(x)
        main = x
        for stage in (self.dabconv1, self.dabconv2, self.dabconv3, self.dabconv4):
            main = stage(main)
        return main + skip

In [None]:
class DABNN(nn.Module):
    """Dual Attention Binary Neural Network image classifier.

    Layout:
      * stem: full-precision 7x7 stride-2 Conv2d (3 -> 64) + BatchNorm2d + PReLU,
        then a 3x3 DABConv (64 -> 64) and a 1x1 stride-2 DABConv (64 -> 128);
      * features: six BNNBasicBlocks, 128 -> 128 -> 256 -> 256 -> 512 -> 512 -> 512,
        with stride 2 on the two blocks that widen to 256 and to 512 channels;
      * classifier: global average pool, flatten, PReLU, BinaryLinear(512, num_classes).

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes=4):
        super(DABNN, self).__init__()

        # The first convolution is a regular full-precision nn.Conv2d, not a binary layer.
        self.stem = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3, bias=False),
            nn.BatchNorm2d(64),
            nn.PReLU(),
            DABConv(64, 64, 3, 1, 1),   # 3x3 DABConv
            DABConv(64, 128, 1, 2, 0)   # 1x1 DABConv, halves H and W
        )

        # (in_channels, out_channels, stride) for each residual block, built in order.
        block_specs = [
            (128, 128, 1),
            (128, 256, 2),
            (256, 256, 1),
            (256, 512, 2),
            (512, 512, 1),
            (512, 512, 1),
        ]
        self.features = nn.ModuleList(
            [BNNBasicBlock(cin, cout, stride) for cin, cout, stride in block_specs]
        )

        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.PReLU(),
            BinaryLinear(512, num_classes)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        # Only stock torch.nn layers are re-initialized here. BinaryConv2d and
        # BinaryLinear are plain nn.Module subclasses, so they keep their own init.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                continue
            if isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                continue
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, 0, 0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x):
        h = self.stem(x)
        for residual_block in self.features:
            h = residual_block(h)
        return self.classifier(h)

In [None]:
def get_data_transforms():
    """Return ``(train_transforms, val_test_transforms)`` for IMG_SIZE x IMG_SIZE inputs.

    Both pipelines resize to a fixed square, convert to a tensor and normalize
    with the ImageNet mean/std used by torchvision's pretrained models.

    The training pipeline also applies, between the resize and ToTensor:
    horizontal/vertical flips, a +-15 degree rotation, colour jitter and a
    translation of up to 10%.
    """
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]
    square = (IMG_SIZE, IMG_SIZE)

    augmentations = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.3),
        transforms.RandomRotation(degrees=15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
        transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    ]

    train_pipeline = transforms.Compose(
        [transforms.Resize(square)]
        + augmentations
        + [transforms.ToTensor(), transforms.Normalize(mean=norm_mean, std=norm_std)]
    )
    eval_pipeline = transforms.Compose([
        transforms.Resize(square),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])
    return train_pipeline, eval_pipeline

In [None]:
def create_data_loaders(data_path):
    """Build train/val/test DataLoaders from one ImageFolder directory.

    Samples are split by TRAIN_RATIO / VAL_RATIO, with the remainder going to
    test. The split uses a generator seeded with 42, so it is identical on every
    call.

    Training samples get the augmenting transforms. Validation and test samples
    get the deterministic resize + normalize pipeline. Only the training loader
    shuffles.

    Args:
        data_path: Dataset root in torchvision ImageFolder layout (one
            sub-directory per class).

    Returns:
        (train_loader, val_loader, test_loader, class_names)
    """
    train_transforms, val_test_transforms = get_data_transforms()

    # Two views of the same directory that differ only in their transform.
    # Subsets returned by random_split share their parent dataset. Assigning
    # `subset.dataset.transform` would therefore also switch the TRAINING subset
    # to the eval transforms, disabling all augmentation.
    train_view = ImageFolder(data_path, transform=train_transforms)
    eval_view = ImageFolder(data_path, transform=val_test_transforms)

    total_size = len(train_view)
    train_size = int(TRAIN_RATIO * total_size)
    val_size = int(VAL_RATIO * total_size)
    test_size = total_size - train_size - val_size

    print(f"Dataset splits - Train: {train_size}, Val: {val_size}, Test: {test_size}")

    train_dataset, val_split, test_split = random_split(
        train_view, [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42)
    )
    # Re-point the val/test indices at the eval-transform view.
    val_dataset = Subset(eval_view, val_split.indices)
    test_dataset = Subset(eval_view, test_split.indices)

    loader_kwargs = dict(
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=NUM_WORKERS > 0,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader, train_view.classes

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over ``train_loader``.

    The model is put in training mode. For each batch: zero the gradients,
    forward, compute the loss, backward, and step the optimizer.

    Returns:
        (mean of the per-batch losses, accuracy in percent over all samples seen)
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for inputs, targets in train_loader:
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(inputs)
        loss = criterion(logits, targets)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        n_seen += targets.size(0)
        n_correct += logits.max(1)[1].eq(targets).sum().item()

    return loss_sum / len(train_loader), 100.0 * n_correct / n_seen

In [None]:
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Puts the model in eval mode and leaves it there.

    Returns:
        (mean of the per-batch losses, accuracy in percent over all samples)
    """
    model.eval()
    loss_sum = 0.0
    n_correct = n_seen = 0

    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            logits = model(inputs)
            loss_sum += criterion(logits, targets).item()
            n_seen += targets.size(0)
            n_correct += logits.max(1)[1].eq(targets).sum().item()

    return loss_sum / len(val_loader), 100.0 * n_correct / n_seen

In [None]:
def test_model(model, test_loader, device, class_names):
    """Evaluate ``model`` on ``test_loader`` and build sklearn summaries.

    Args:
        model: Classifier returning one score per class along dim 1.
        test_loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the model lives on; batches are moved there.
        class_names: Class names, index-aligned with the model outputs.

    Returns:
        ``(test_acc, report, cm, all_preds, all_targets)``: accuracy in percent,
        the ``classification_report`` text, the
        ``len(class_names) x len(class_names)`` confusion matrix, and flat lists
        of predicted and true label indices.
    """
    model.eval()
    all_preds = []
    all_targets = []
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)

            outputs = model(inputs)
            _, predicted = outputs.max(1)

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    test_acc = 100.0 * correct / total

    # Pin the label set to every class index. Otherwise sklearn infers labels
    # from the data, and a class with no samples and no predictions causes
    # classification_report to raise (target_names length mismatch). It also
    # shrinks the confusion matrix below the size of class_names.
    label_ids = list(range(len(class_names)))
    report = classification_report(all_targets, all_preds, labels=label_ids, target_names=class_names)
    cm = confusion_matrix(all_targets, all_preds, labels=label_ids)

    return test_acc, report, cm, all_preds, all_targets

In [None]:
def plot_training_history(train_losses, val_losses, train_accs, val_accs):
    """Show per-epoch loss (left) and accuracy (right) curves for train vs. validation."""
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    panels = (
        (loss_ax, train_losses, val_losses, 'Train Loss', 'Val Loss',
         'Loss', 'Training and Validation Loss'),
        (acc_ax, train_accs, val_accs, 'Train Accuracy', 'Val Accuracy',
         'Accuracy (%)', 'Training and Validation Accuracy'),
    )
    for ax, train_curve, val_curve, train_label, val_label, y_label, title in panels:
        ax.plot(train_curve, label=train_label, marker='o', markersize=3)
        ax.plot(val_curve, label=val_label, marker='s', markersize=3)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

def plot_confusion_matrix(cm, class_names):
    """Show ``cm`` as an integer-annotated heatmap (rows = actual, columns = predicted)."""
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_title('Confusion Matrix')
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    fig.tight_layout()
    plt.show()

In [None]:
def count_parameters(model):
    """Print total and trainable parameter counts plus two rough size estimates.

    The FP32 figure assumes 4 bytes per parameter. The "binary" figure assumes
    1 bit for EVERY parameter. That is a lower bound for models that also hold
    full-precision parameters, as DABNN does: BatchNorm, PReLU, attention
    layers, biases, and the stem and shortcut convolutions.
    """
    params = list(model.parameters())
    n_total = sum(p.numel() for p in params)
    n_trainable = sum(p.numel() for p in params if p.requires_grad)
    mib = 1024 ** 2

    for line in (
        f"Total parameters: {n_total:,}",
        f"Trainable parameters: {n_trainable:,}",
        f"Model size: {n_total * 4 / mib:.2f} MB (FP32)",
        f"Binary model size (estimated): {n_total / (8 * mib):.2f} MB",
    ):
        print(line)

In [None]:
def main(data_path="/home/bruh/Documents/BNN2/split"):
    """Train DABNN, test the best-validation checkpoint, save it and plot results.

    Trains for up to NUM_EPOCHS epochs with AdamW and cosine warm restarts, and
    keeps a snapshot of the weights with the highest validation accuracy. It
    then evaluates that snapshot on the test split, saves it to
    'best_dabnn_model.pth', and shows the training curves and confusion matrix.

    Args:
        data_path: Dataset root passed to ``create_data_loaders`` (ImageFolder
            layout). The default is the original author's local path; point it
            at your own data.
    """
    def snapshot_state():
        # state_dict() returns references to the live tensors, so a shallow
        # dict copy keeps changing as training continues. Clone every tensor
        # to freeze the current weights.
        return {k: v.detach().clone() for k, v in model.state_dict().items()}

    # cuDNN autotuning pays off because every batch has the same input size.
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False  # allow non-deterministic kernels for speed

    print("="*80)
    print("BNN Disease Classification Training")
    print("="*80)

    # Data
    print("Loading dataset...")
    train_loader, val_loader, test_loader, class_names = create_data_loaders(data_path)
    print(f"Classes: {class_names}")
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")
    print(f"Test batches: {len(test_loader)}")

    # Model
    print("\nInitializing DABNN model...")
    model = DABNN(num_classes=NUM_CLASSES).to(device)
    count_parameters(model)

    # Loss, optimizer and LR schedule
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # label smoothing for better generalization
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    # Training history
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    best_val_acc = 0.0
    best_model_state = None

    print("\nStarting training...")
    print("-" * 80)

    start_time = time.time()

    for epoch in range(NUM_EPOCHS):
        epoch_start = time.time()
        # Read the LR before scheduler.step() so the log shows the rate this epoch trained with.
        current_lr = optimizer.param_groups[0]['lr']

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        scheduler.step()

        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc
            best_model_state = snapshot_state()

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        epoch_time = time.time() - epoch_start
        print(f"Epoch [{epoch+1:3d}/{NUM_EPOCHS}] | "
              f"Time: {epoch_time:.1f}s | "
              f"LR: {current_lr:.6f} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f} | "
              f"Val Acc: {val_acc:.2f}%"
              f"{' *' if is_best else ''}")

        # Collapse guard, not a patience check: once past epoch index 20, stop
        # when this epoch's val accuracy is more than 5 points below the best
        # of the last 10 epochs (this epoch included).
        if epoch > 20 and val_acc < max(val_accs[-10:]) - 5:
            print(f"Early stopping at epoch {epoch+1}")
            break

    total_time = time.time() - start_time
    print("-" * 80)
    print(f"Training completed in {total_time/60:.1f} minutes")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")

    if best_model_state is None:
        # Validation accuracy never rose above 0% (or NUM_EPOCHS == 0): keep the final weights.
        best_model_state = snapshot_state()

    # Evaluate the best checkpoint
    model.load_state_dict(best_model_state)

    print("\nTesting best model...")
    test_acc, report, cm, preds, targets = test_model(model, test_loader, device, class_names)
    print(f"Test accuracy: {test_acc:.2f}%")
    print("\nClassification Report:")
    print(report)

    torch.save({
        'model_state_dict': best_model_state,
        'model_config': {
            'num_classes': NUM_CLASSES,
            'img_size': IMG_SIZE
        },
        'class_names': class_names,
        'best_val_acc': best_val_acc,
        'test_acc': test_acc
    }, 'best_dabnn_model.pth')
    print("\nBest model saved as 'best_dabnn_model.pth'")

    print("\nGenerating plots...")
    plot_training_history(train_losses, val_losses, train_accs, val_accs)
    plot_confusion_matrix(cm, class_names)

    print("\n" + "="*80)
    print("FINAL RESULTS")
    print("="*80)
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
    print(f"Test Accuracy: {test_acc:.2f}%")
    print(f"Total Training Time: {total_time/60:.1f} minutes")
    print(f"Model Size (Binary): {sum(p.numel() for p in model.parameters()) / (8 * 1024**2):.2f} MB")

In [None]:
# A Jupyter kernel executes cells with __name__ == "__main__". An unconditional
# main() here would therefore start a full NUM_EPOCHS training run on every
# "Run All". That run would happen before plot_detailed_analysis and
# main_with_detailed_analysis are defined, and in addition to whichever option
# is chosen in the final cell. Set this flag to True to train from this cell.
RUN_BASIC_TRAINING = False

if __name__ == "__main__" and RUN_BASIC_TRAINING:
    main()


In [None]:
def plot_detailed_analysis(model, test_loader, device, class_names):
    """Plot a 2x3 grid of test-set diagnostics and print per-class metrics.

    Panels, row-major:
      1. Multi-class ROC curves with a micro-average.
      2. Per-class precision-recall curves.
      3. Per-class accuracy bars.
      4. Histograms of the top softmax probability, correct vs. incorrect predictions.
      5. A confusion matrix annotated with counts and row percentages.
      6. Per-class calibration curves.

    Args:
        model: Classifier whose outputs are treated as logits over
            ``len(class_names)`` classes along dim 1 (softmax over dim=1).
        test_loader: Iterable of ``(inputs, targets)`` batches.
        device: Device the model lives on; inputs are moved there.
        class_names: Class names, index-aligned with the model outputs.

    Returns:
        dict with keys:
          * 'roc_auc': per-class integer keys plus 'micro';
          * 'pr_auc': per-class integer keys;
          * 'class_accuracies': list, in percent;
          * 'overall_accuracy': in percent;
          * 'predictions', 'probabilities', 'targets': the raw arrays.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    from sklearn.calibration import calibration_curve
    from sklearn.metrics import auc, average_precision_score, precision_recall_curve, roc_curve
    from sklearn.preprocessing import label_binarize

    n_classes = len(class_names)
    class_ids = list(range(n_classes))
    # One fixed colour per class, shared by the ROC, PR and calibration panels.
    # A single itertools.cycle consumed by all three loops keeps advancing, so a
    # class would get a different colour in each panel.
    palette = ['aqua', 'darkorange', 'cornflowerblue', 'red', 'green', 'purple']
    class_colors = [palette[i % len(palette)] for i in class_ids]

    # ---- Collect softmax probabilities, predictions and targets ----
    model.eval()
    prob_chunks, pred_chunks, target_chunks = [], [], []
    with torch.no_grad():
        for inputs, targets in test_loader:
            outputs = model(inputs.to(device, non_blocking=True))
            prob_chunks.append(F.softmax(outputs, dim=1).cpu().numpy())
            pred_chunks.append(outputs.max(1)[1].cpu().numpy())
            target_chunks.append(targets.cpu().numpy())
    if not target_chunks:
        raise ValueError("test_loader yielded no samples")

    all_probs = np.concatenate(prob_chunks)
    all_preds = np.concatenate(pred_chunks)
    all_targets = np.concatenate(target_chunks)

    # One-vs-rest indicator matrix with one column per class.
    all_targets_bin = label_binarize(all_targets, classes=class_ids)
    if all_targets_bin.shape[1] == 1:
        # label_binarize returns a single column when there are exactly two classes.
        all_targets_bin = np.hstack([1 - all_targets_bin, all_targets_bin])

    fig, axes = plt.subplots(2, 3, figsize=(20, 15))
    ax_roc, ax_pr, ax_acc, ax_conf, ax_cm, ax_cal = axes.ravel()

    # 1. ROC curves: per class plus micro-average
    fpr, tpr, roc_auc = {}, {}, {}
    for i in class_ids:
        fpr[i], tpr[i], _ = roc_curve(all_targets_bin[:, i], all_probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])
    fpr["micro"], tpr["micro"], _ = roc_curve(all_targets_bin.ravel(), all_probs.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    ax_roc.plot(fpr["micro"], tpr["micro"],
                label=f'Micro-average ROC (AUC = {roc_auc["micro"]:.3f})',
                color='deeppink', linestyle=':', linewidth=4)
    for i in class_ids:
        ax_roc.plot(fpr[i], tpr[i], color=class_colors[i], lw=2,
                    label=f'{class_names[i]} (AUC = {roc_auc[i]:.3f})')
    ax_roc.plot([0, 1], [0, 1], 'k--', lw=2)
    ax_roc.set_xlim([0.0, 1.0])
    ax_roc.set_ylim([0.0, 1.05])
    ax_roc.set_xlabel('False Positive Rate')
    ax_roc.set_ylabel('True Positive Rate')
    ax_roc.set_title('Multi-class ROC Curves')
    ax_roc.legend(loc="lower right", fontsize=8)
    ax_roc.grid(True, alpha=0.3)

    # 2. Precision-recall curves
    pr_auc = {}
    for i in class_ids:
        precision, recall, _ = precision_recall_curve(all_targets_bin[:, i], all_probs[:, i])
        pr_auc[i] = average_precision_score(all_targets_bin[:, i], all_probs[:, i])
        ax_pr.plot(recall, precision, color=class_colors[i], lw=2,
                   label=f'{class_names[i]} (AP = {pr_auc[i]:.3f})')
    ax_pr.set_xlabel('Recall')
    ax_pr.set_ylabel('Precision')
    ax_pr.set_title('Precision-Recall Curves')
    ax_pr.legend(loc="lower left", fontsize=8)
    ax_pr.grid(True, alpha=0.3)

    # 3. Class-wise accuracy (NaN for classes with no test samples)
    class_accuracies = []
    for i in class_ids:
        class_mask = all_targets == i
        if class_mask.any():
            class_accuracies.append((all_preds[class_mask] == i).mean() * 100)
        else:
            class_accuracies.append(float('nan'))

    bars = ax_acc.bar(class_names, class_accuracies,
                      color=['skyblue', 'lightcoral', 'lightgreen', 'gold'])
    ax_acc.set_ylabel('Accuracy (%)')
    ax_acc.set_title('Class-wise Accuracy')
    ax_acc.tick_params(axis='x', labelrotation=45)
    for bar, acc in zip(bars, class_accuracies):
        if np.isnan(acc):
            continue
        ax_acc.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 1,
                    f'{acc:.1f}%', ha='center', va='bottom')
    ax_acc.grid(True, alpha=0.3, axis='y')

    # 4. Confidence (top softmax probability) distribution
    max_probs = all_probs.max(axis=1)
    correct_mask = all_preds == all_targets
    for mask, label, color in ((correct_mask, 'Correct', 'green'),
                               (~correct_mask, 'Incorrect', 'red')):
        if mask.any():  # a density histogram of an empty array divides by zero
            ax_conf.hist(max_probs[mask], bins=30, alpha=0.7, label=label,
                         color=color, density=True)
    ax_conf.set_xlabel('Prediction Confidence')
    ax_conf.set_ylabel('Density')
    ax_conf.set_title('Confidence Distribution')
    ax_conf.legend()
    ax_conf.grid(True, alpha=0.3)

    # 5. Confusion matrix with counts and row percentages
    cm = confusion_matrix(all_targets, all_preds, labels=class_ids)
    row_totals = cm.sum(axis=1, keepdims=True)
    # Rows with no samples stay at 0% instead of becoming NaN.
    cm_percent = np.divide(cm * 100.0, row_totals,
                           out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)
    annot = np.array(
        [[f'{cm[r, c]}\n({cm_percent[r, c]:.1f}%)' for c in range(cm.shape[1])]
         for r in range(cm.shape[0])],
        dtype=object,
    )
    sns.heatmap(cm_percent, annot=annot, fmt='', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Percentage'}, ax=ax_cm)
    ax_cm.set_title('Confusion Matrix (% and Counts)')
    ax_cm.set_xlabel('Predicted')
    ax_cm.set_ylabel('Actual')

    # 6. Calibration curves
    for i in class_ids:
        fraction_of_positives, mean_predicted_value = calibration_curve(
            all_targets_bin[:, i], all_probs[:, i], n_bins=10)
        ax_cal.plot(mean_predicted_value, fraction_of_positives, "s-", color=class_colors[i],
                    label=f'{class_names[i]}', linewidth=2, markersize=4)
    ax_cal.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated")
    ax_cal.set_xlabel('Mean Predicted Probability')
    ax_cal.set_ylabel('Fraction of Positives')
    ax_cal.set_title('Calibration Plot')
    ax_cal.legend(loc="lower right", fontsize=8)
    ax_cal.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()

    # ---- Text summary ----
    print("\n" + "="*80)
    print("DETAILED PERFORMANCE METRICS")
    print("="*80)

    overall_acc = correct_mask.mean() * 100
    print(f"Overall Test Accuracy: {overall_acc:.2f}%")
    print(f"Average Confidence: {np.mean(max_probs):.3f}")
    print(f"Confidence Std: {np.std(max_probs):.3f}")

    print("\nPer-Class Metrics:")
    print("-" * 50)
    for i, class_name in enumerate(class_names):
        class_mask = all_targets == i
        class_conf = np.mean(max_probs[class_mask]) if class_mask.any() else float('nan')
        print(f"{class_name}:")
        print(f"  Accuracy: {class_accuracies[i]:.2f}%")
        print(f"  ROC AUC: {roc_auc[i]:.3f}")
        print(f"  PR AUC: {pr_auc[i]:.3f}")
        print(f"  Avg Confidence: {class_conf:.3f}")
        print()

    return {
        'roc_auc': roc_auc,
        'pr_auc': pr_auc,
        'class_accuracies': class_accuracies,
        'overall_accuracy': overall_acc,
        'predictions': all_preds,
        'probabilities': all_probs,
        'targets': all_targets
    }

In [None]:
# Variant of main() that additionally runs plot_detailed_analysis() after the
# basic plots and writes the metrics to 'detailed_training_results.json'.
# (Separate evaluation script usage, for reference:
#  python test_model_simple.py --model_path best_dabnn_model.pth --data_path /path/to/your/test/dataset)

def main_with_detailed_analysis(data_path="/home/bruh/Documents/BNN2/split"):
    """Train DABNN, evaluate the best checkpoint, and run the detailed analysis.

    Same training and testing flow as ``main``. It then calls
    ``plot_detailed_analysis`` and writes the history and metrics to
    'detailed_training_results.json'.

    Args:
        data_path: Dataset root passed to ``create_data_loaders`` (ImageFolder
            layout). The default is the original author's local path; point it
            at your own data.

    Returns:
        The dict returned by ``plot_detailed_analysis``.
    """
    import json

    def snapshot_state():
        # state_dict() returns references to the live tensors; clone them so
        # the snapshot does not follow later optimizer updates.
        return {k: v.detach().clone() for k, v in model.state_dict().items()}

    # cuDNN autotuning pays off because every batch has the same input size.
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False  # allow non-deterministic kernels for speed

    print("="*80)
    print("BNN Disease Classification Training with Detailed Analysis")
    print("="*80)

    # Data
    print("Loading dataset...")
    train_loader, val_loader, test_loader, class_names = create_data_loaders(data_path)
    print(f"Classes: {class_names}")
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(val_loader)}")
    print(f"Test batches: {len(test_loader)}")

    # Model
    print("\nInitializing DABNN model...")
    model = DABNN(num_classes=NUM_CLASSES).to(device)
    count_parameters(model)

    # Loss, optimizer and LR schedule
    criterion = nn.CrossEntropyLoss(label_smoothing=0.1)  # label smoothing for better generalization
    optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=2)

    # Training history
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    best_val_acc = 0.0
    best_model_state = None

    print("\nStarting training...")
    print("-" * 80)

    start_time = time.time()

    for epoch in range(NUM_EPOCHS):
        epoch_start = time.time()
        # Read the LR before scheduler.step() so the log shows the rate this epoch trained with.
        current_lr = optimizer.param_groups[0]['lr']

        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        scheduler.step()

        is_best = val_acc > best_val_acc
        if is_best:
            best_val_acc = val_acc
            best_model_state = snapshot_state()

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)

        epoch_time = time.time() - epoch_start
        print(f"Epoch [{epoch+1:3d}/{NUM_EPOCHS}] | "
              f"Time: {epoch_time:.1f}s | "
              f"LR: {current_lr:.6f} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.2f}% | "
              f"Val Loss: {val_loss:.4f} | "
              f"Val Acc: {val_acc:.2f}%"
              f"{' *' if is_best else ''}")

        # Collapse guard, not a patience check: once past epoch index 20, stop
        # when this epoch's val accuracy is more than 5 points below the best
        # of the last 10 epochs (this epoch included).
        if epoch > 20 and val_acc < max(val_accs[-10:]) - 5:
            print(f"Early stopping at epoch {epoch+1}")
            break

    total_time = time.time() - start_time
    print("-" * 80)
    print(f"Training completed in {total_time/60:.1f} minutes")
    print(f"Best validation accuracy: {best_val_acc:.2f}%")

    if best_model_state is None:
        # Validation accuracy never rose above 0% (or NUM_EPOCHS == 0): keep the final weights.
        best_model_state = snapshot_state()

    # Evaluate the best checkpoint
    model.load_state_dict(best_model_state)

    print("\nTesting best model...")
    test_acc, report, cm, preds, targets = test_model(model, test_loader, device, class_names)
    print(f"Test accuracy: {test_acc:.2f}%")
    print("\nClassification Report:")
    print(report)

    torch.save({
        'model_state_dict': best_model_state,
        'model_config': {
            'num_classes': NUM_CLASSES,
            'img_size': IMG_SIZE
        },
        'class_names': class_names,
        'best_val_acc': best_val_acc,
        'test_acc': test_acc
    }, 'best_dabnn_model.pth')
    print("\nBest model saved as 'best_dabnn_model.pth'")

    print("\nGenerating basic plots...")
    plot_training_history(train_losses, val_losses, train_accs, val_accs)
    plot_confusion_matrix(cm, class_names)

    print("\nGenerating detailed performance analysis...")
    detailed_metrics = plot_detailed_analysis(model, test_loader, device, class_names)

    class_ids = range(len(class_names))
    total_params = sum(p.numel() for p in model.parameters())
    detailed_results = {
        'training_history': {
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_accs': train_accs,
            'val_accs': val_accs
        },
        'detailed_metrics': {
            'roc_auc': {class_names[i]: float(detailed_metrics['roc_auc'][i]) for i in class_ids},
            'pr_auc': {class_names[i]: float(detailed_metrics['pr_auc'][i]) for i in class_ids},
            'class_accuracies': {class_names[i]: float(detailed_metrics['class_accuracies'][i]) for i in class_ids},
            'overall_accuracy': float(detailed_metrics['overall_accuracy'])
        },
        'model_info': {
            'total_params': total_params,
            'model_size_mb': total_params * 4 / (1024**2),
            'binary_size_mb': total_params / (8 * 1024**2)
        }
    }

    with open('detailed_training_results.json', 'w') as f:
        json.dump(detailed_results, f, indent=2)
    print("Detailed results saved to 'detailed_training_results.json'")

    print("\n" + "="*80)
    print("FINAL RESULTS WITH DETAILED ANALYSIS")
    print("="*80)
    print(f"Best Validation Accuracy: {best_val_acc:.2f}%")
    print(f"Test Accuracy: {test_acc:.2f}%")
    print(f"Total Training Time: {total_time/60:.1f} minutes")
    print(f"Model Size (Binary): {total_params / (8 * 1024**2):.2f} MB")

    # roc_auc also carries a 'micro' entry; average only the per-class values.
    mean_roc_auc = np.mean([detailed_metrics['roc_auc'][i] for i in class_ids])
    mean_pr_auc = np.mean([detailed_metrics['pr_auc'][i] for i in class_ids])
    print("\nDetailed Analysis Summary:")
    print(f"Average ROC AUC: {mean_roc_auc:.3f}")
    print(f"Average PR AUC: {mean_pr_auc:.3f}")
    print(f"Class Accuracy Range: {np.nanmin(detailed_metrics['class_accuracies']):.1f}% - "
          f"{np.nanmax(detailed_metrics['class_accuracies']):.1f}%")

    return detailed_metrics

In [None]:
# Example usage of the enhanced analysis.
# Uncomment and run exactly one of the options below.

# Option 1: run the enhanced main function (training + detailed analysis)
# detailed_metrics = main_with_detailed_analysis()

# Option 2: analysis only, for an already-trained checkpoint
# # Load the pre-trained model (torch.load unpickles: only load trusted files)
# checkpoint = torch.load('best_dabnn_model.pth', map_location=device)
# model = DABNN(num_classes=NUM_CLASSES).to(device)
# model.load_state_dict(checkpoint['model_state_dict'])
# class_names = checkpoint['class_names']
#
# # Create the test loader (point data_path at your dataset)
# data_path = "/home/bruh/Documents/BNN2/split"
# _, _, test_loader, _ = create_data_loaders(data_path)
#
# # Run the detailed analysis on the pre-trained model
# detailed_metrics = plot_detailed_analysis(model, test_loader, device, class_names)

# Option 3: run the original main function (no detailed analysis)
# main()

print("\n".join([
    "Choose one of the options above to run the analysis!",
    "- Option 1: Complete training with detailed analysis",
    "- Option 2: Load existing model and run detailed analysis only",
    "- Option 3: Run original training without detailed analysis",
]))