In [4]:
import torch
import numpy as np
from tqdm import tqdm
import seaborn as sns
import torch.nn as nn
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, mean_absolute_error, mean_squared_error

class CustomCNN(nn.Module):
    """Three-stage convolutional classifier that returns raw class logits.

    The feature extractor stacks three blocks of
    ``(Conv3x3 -> BatchNorm -> ReLU) x 2 -> MaxPool(2) -> Dropout2d(0.25)``
    with 32, 64 and 128 channels. The classifier is a
    ``512 -> 256 -> num_classes`` MLP with dropout 0.5 between layers.

    Layer order inside ``features`` and ``classifier`` is fixed, so
    ``state_dict`` keys (``features.0.weight`` ... ``classifier.6.bias``)
    stay stable across versions of this class.

    Args:
        num_classes: Number of output logits (default 37, classes 0..36).
        input_shape: ``(channels, height, width)`` of a single input image.
            It sets the first convolution's input channels and sizes the
            first linear layer. Defaults to ``(1, 40, 168)``.
    """

    def __init__(self, num_classes=37, input_shape=(1, 40, 168)):
        super().__init__()
        input_shape = tuple(input_shape)
        in_channels = input_shape[0]

        # Feature extraction layers (flattened into one Sequential so that
        # module indices, and therefore checkpoint keys, are 0..23).
        self.features = nn.Sequential(
            *self._conv_block(in_channels, 32),
            *self._conv_block(32, 64),
            *self._conv_block(64, 128),
        )

        # Size of the flattened feature map for the configured input shape.
        self._to_linear = self._get_conv_output_size(input_shape)

        # Classification layers
        self.classifier = nn.Sequential(
            nn.Linear(self._to_linear, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

        self._initialize_weights()

    @staticmethod
    def _conv_block(in_channels, out_channels):
        """Return the layers of one double-convolution block as a list."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Dropout2d(0.25),
        ]

    def _get_conv_output_size(self, shape):
        """Return the number of features ``self.features`` yields per image.

        Args:
            shape: ``(channels, height, width)`` of one input image.

        The probe runs in eval mode under ``no_grad`` so that it neither
        updates BatchNorm running statistics nor builds an autograd graph.
        A zero tensor is used, so the probe draws no random numbers.
        """
        was_training = self.features.training
        self.features.eval()
        try:
            with torch.no_grad():
                probe = self.features(torch.zeros(1, *shape))
        finally:
            # Restore whatever mode the extractor was in before the probe.
            self.features.train(was_training)
        return probe.flatten(1).size(1)

    def _initialize_weights(self):
        """Initialize conv (Kaiming), batch-norm (identity) and linear layers."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Map a ``(batch, channels, height, width)`` tensor to class logits."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)

def preprocess_image(image, image_shape=(40, 168), eps=1e-8):
    """Convert image data into a standardized single-channel model input.

    Args:
        image: A numpy array, torch tensor or nested sequence whose element
            count is a multiple of ``height * width`` (one image or a stack).
        image_shape: ``(height, width)`` of one image. Defaults to (40, 168).
        eps: Lower bound applied to the standard deviation so a constant
            image yields zeros instead of NaN/inf. Inputs with a larger
            standard deviation are unaffected.

    Returns:
        A floating-point tensor of shape ``(-1, 1, height, width)`` with the
        global mean subtracted and divided by the global standard deviation.
    """
    height, width = image_shape

    if isinstance(image, torch.Tensor):
        # Integer tensors cannot be averaged; promote them to float32.
        if not image.is_floating_point():
            image = image.float()
    else:
        # ascontiguousarray also accepts lists and negatively-strided views.
        image = torch.as_tensor(np.ascontiguousarray(image), dtype=torch.float32)

    # Keep as single channel and normalize over the whole input.
    model_input = image.reshape(-1, 1, height, width)
    std = model_input.std().clamp_min(eps)
    return (model_input - model_input.mean()) / std

def load_model(model_path, device='cuda'):
    """Build a ``CustomCNN`` and load trained weights into it.

    Args:
        model_path: Path to a file containing the model's ``state_dict``.
        device: Device (string or ``torch.device``) to place the model on.
            If a CUDA device is requested but CUDA is not available, the
            model is loaded on the CPU instead and a notice is printed.

    Returns:
        The model with loaded weights, on the resolved device, in eval mode.
    """
    print(f"Loading model from {model_path}...")

    device = torch.device(device)
    if device.type == 'cuda' and not torch.cuda.is_available():
        print("CUDA is not available; loading model on CPU instead.")
        device = torch.device('cpu')

    model = CustomCNN().to(device)
    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code, so only load checkpoints from trusted sources. Consider passing
    # weights_only=True if the installed PyTorch version supports it.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model

class ModelEvaluator:
    """Run a trained model over labelled images and report or plot the results.

    Set in ``__init__``:
        model: Network returned by ``load_model``, kept in eval mode.
        device: Device on which the model's parameters live.
    """

    # Number of columns in the sample-prediction grid.
    _GRID_COLS = 5

    def __init__(self, model_path, device='cuda'):
        """Load the model from ``model_path`` onto ``device``."""
        self.model = load_model(model_path, device)
        self.model.eval()
        # Use the device the model actually ended up on so that inputs moved
        # with ``.to(self.device)`` always match the weights.
        self.device = next(self.model.parameters()).device

    @staticmethod
    def _mean_or_zero(values):
        """Return the mean of ``values``, or 0.0 when ``values`` is empty."""
        return values.mean() if values.size else 0.0

    @staticmethod
    def _error_bins(errors, min_span=10):
        """Return integer histogram bin edges covering every error value.

        The edges always span at least ``[-min_span, min_span]`` and extend
        one past the largest error so that, with ``align='left'``, each
        integer error gets its own bar (matplotlib's last bin is closed on
        both sides and would otherwise merge the two largest values).
        """
        errors = np.asarray(errors)
        lo, hi = -min_span, min_span
        if errors.size:
            lo = min(lo, int(np.floor(errors.min())))
            hi = max(hi, int(np.ceil(errors.max())))
        return np.arange(lo, hi + 2)

    @staticmethod
    def _save_or_show(fig, save_path, description):
        """Save ``fig`` to ``save_path`` and close it, or show it if no path."""
        if save_path:
            fig.savefig(save_path)
            plt.close(fig)
            print(f"Saved {description} to {save_path}")
        else:
            plt.show()

    def evaluate_batch(self, images, labels, batch_size=32):
        """Predict a class for every image, processing ``batch_size`` at a time.

        Args:
            images: Sliceable collection of images accepted by
                ``preprocess_image`` (each image is normalized on its own).
            labels: True class per image; same length as ``images``.
            batch_size: Number of images per forward pass (must be > 0).

        Returns:
            Dict with ``'predictions'`` (argmax class), ``'confidences'``
            (softmax probability of the predicted class) and ``'labels'``.

        Raises:
            ValueError: If the lengths differ or ``batch_size`` is not positive.
        """
        labels = np.asarray(labels)
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels differ in length: {len(images)} vs {len(labels)}"
            )
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        predictions = []
        confidences = []

        # Create progress bar for batch processing
        num_batches = (len(images) + batch_size - 1) // batch_size
        batch_iterator = tqdm(range(0, len(images), batch_size),
                              total=num_batches,
                              desc="Processing batches")

        with torch.no_grad():
            for start in batch_iterator:
                stop = start + batch_size

                # preprocess_image returns (1, 1, H, W); [0] drops the batch
                # axis so the stack below has shape (batch, 1, H, W).
                model_input = torch.stack(
                    [preprocess_image(img)[0] for img in images[start:stop]]
                ).to(self.device)

                probabilities = torch.softmax(self.model(model_input), dim=1)
                pred_probs, predicted = probabilities.max(1)
                predicted = predicted.cpu().numpy()

                predictions.extend(predicted)
                confidences.extend(pred_probs.cpu().numpy())

                # Update progress bar with current metrics
                batch_acc = (predicted == labels[start:stop]).mean()
                batch_iterator.set_postfix({'Batch Accuracy': f'{batch_acc:.3f}'})

        return {
            'predictions': np.array(predictions),
            'confidences': np.array(confidences),
            'labels': labels
        }

    def calculate_metrics(self, results):
        """Compute summary metrics from the dict returned by ``evaluate_batch``.

        Returns:
            Dict with, in order: ``accuracy``, ``mae``, ``rmse``,
            ``within_one_accuracy``, ``mean_confidence``,
            ``correct_confidence``, ``incorrect_confidence`` (each of the last
            two is 0.0 when its subset is empty) and ``error_distribution``
            (the per-sample ``prediction - label`` array).

        Raises:
            ValueError: If predictions and labels have different shapes.
        """
        predictions = np.asarray(results['predictions'])
        labels = np.asarray(results['labels'])
        confidences = np.asarray(results['confidences'])
        if predictions.shape != labels.shape:
            raise ValueError(
                f"predictions {predictions.shape} and labels {labels.shape} differ in shape"
            )

        print("Calculating metrics...")
        correct = predictions == labels
        # One signed error vector feeds every error-based metric below.
        errors = predictions - labels
        abs_errors = np.abs(errors).astype(np.float64)

        metrics = {}

        # Basic metrics
        metrics['accuracy'] = correct.mean()
        metrics['mae'] = abs_errors.mean()
        metrics['rmse'] = np.sqrt(np.square(abs_errors).mean())

        # Within-1 accuracy
        metrics['within_one_accuracy'] = (abs_errors <= 1).mean()

        # Confidence analysis; both subsets are guarded against being empty.
        metrics['mean_confidence'] = confidences.mean()
        metrics['correct_confidence'] = self._mean_or_zero(confidences[correct])
        metrics['incorrect_confidence'] = self._mean_or_zero(confidences[~correct])

        # Error distribution
        metrics['error_distribution'] = errors

        return metrics

    def plot_confusion_matrix(self, results, save_path=None):
        """Plot the confusion matrix with axes labelled by actual class value."""
        print("Generating confusion matrix...")
        y_true = np.asarray(results['labels'])
        y_pred = np.asarray(results['predictions'])
        # Same class set sklearn would infer; passed explicitly so the tick
        # labels show class values rather than row/column positions.
        classes = np.unique(np.concatenate([y_true, y_pred]))

        fig = plt.figure(figsize=(15, 15))
        cm = confusion_matrix(y_true, y_pred, labels=classes)
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=classes, yticklabels=classes)
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('True')
        self._save_or_show(fig, save_path, "confusion matrix")

    def plot_error_distribution(self, metrics, save_path=None):
        """Plot a histogram of ``metrics['error_distribution']``."""
        print("Generating error distribution plot...")
        errors = np.asarray(metrics['error_distribution'])
        fig = plt.figure(figsize=(12, 6))
        plt.hist(errors, bins=self._error_bins(errors), align='left')
        plt.title('Error Distribution')
        plt.xlabel('Prediction Error (Predicted - True)')
        plt.ylabel('Count')
        self._save_or_show(fig, save_path, "error distribution")

    def visualize_predictions(self, images, results, num_samples=10, save_path=None):
        """Show randomly chosen images with predicted/true label and confidence.

        At most ``len(images)`` samples are drawn. The grid has five columns
        and as many rows as needed; titles are green for correct predictions
        and red for incorrect ones.
        """
        predictions = results['predictions']
        labels = results['labels']
        confidences = results['confidences']

        # Sampling is without replacement, so cap at the number of images.
        num_samples = min(num_samples, len(images))
        print(f"Visualizing {num_samples} sample predictions...")
        if num_samples <= 0:
            return

        # Randomly select samples
        indices = np.random.choice(len(images), num_samples, replace=False)

        # Create a grid large enough for every requested sample
        n_cols = self._GRID_COLS
        n_rows = (num_samples + n_cols - 1) // n_cols
        fig, axes = plt.subplots(n_rows, n_cols,
                                 figsize=(4 * n_cols, 4 * n_rows),
                                 squeeze=False)
        axes = axes.ravel()

        for idx, ax in enumerate(tqdm(axes, desc="Plotting samples")):
            # Hide the frame on every cell, including unused trailing ones.
            ax.axis('off')
            if idx >= num_samples:
                continue

            sample = indices[idx]
            pred = predictions[sample]
            true = labels[sample]
            conf = confidences[sample]

            # Display image
            ax.imshow(np.asarray(images[sample]).reshape(40, 168), cmap='gray')
            color = 'green' if pred == true else 'red'
            ax.set_title(f'Pred: {pred} (True: {true})\nConf: {conf:.2f}',
                         color=color)

        plt.tight_layout()
        self._save_or_show(fig, save_path, "prediction visualization")

def main():
    """Load the dataset, evaluate the trained model and save all plots.

    Reads three image/label ``.npy`` file pairs from the data directory,
    concatenates them, runs the saved model over every image, prints the
    scalar metrics and writes the confusion matrix, error histogram and a
    sample-prediction grid to PNG files in the working directory.
    """
    print("Loading test data...")
    data_path = '/scratch/gaurav.bhole/MLNS_data/'

    # Array name -> file name inside the data directory
    data_files = {
        'data0': 'data0.npy',
        'data1': 'data1.npy',
        'data2': 'data2.npy',
        'lab0': 'lab0.npy',
        'lab1': 'lab1.npy',
        'lab2': 'lab2.npy'
    }

    # Load every file while reporting progress
    loaded_data = {
        name: np.load(data_path + filename)
        for name, filename in tqdm(data_files.items(), desc="Loading data files")
    }

    # Stitch the three parts together along the sample axis
    image_parts = [loaded_data[key] for key in ('data0', 'data1', 'data2')]
    label_parts = [loaded_data[key] for key in ('lab0', 'lab1', 'lab2')]
    test_images = np.concatenate(tuple(image_parts), axis=0)
    test_labels = np.concatenate(tuple(label_parts), axis=0)
    print(f"Loaded {len(test_images)} test images")

    # Checkpoint produced by training
    model_path = 'best_digit_sum_cnn_batch64_lr0.0001.pt'

    # Prefer the GPU when one is present
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(f"Using device: {device}")

    # Run the evaluation
    print("\n=== Evaluating Model ===")
    evaluator = ModelEvaluator(model_path, device)
    results = evaluator.evaluate_batch(test_images, test_labels)
    metrics = evaluator.calculate_metrics(results)

    # Report every scalar metric; the raw error array is only plotted
    print("\n=== Results ===")
    for metric, value in metrics.items():
        if metric == 'error_distribution':
            continue
        print(f"{metric}: {value:.4f}")

    # Write the figures to disk
    print("\n=== Generating Visualizations ===")
    evaluator.plot_confusion_matrix(results, save_path='confusion_matrix.png')
    evaluator.plot_error_distribution(metrics, save_path='error_dist.png')
    evaluator.visualize_predictions(
        test_images, results, save_path='sample_predictions.png'
    )

# Run the evaluation only when this code is executed directly (as a script or
# notebook cell), not when it is imported as a module.
if __name__ == '__main__':
    main()

Loading test data...


Loading data files: 100%|██████████| 6/6 [00:00<00:00, 55.28it/s]


Loaded 30000 test images
Using device: cuda

=== Evaluating Model ===
Loading model from best_digit_sum_cnn_batch64_lr0.0001.pt...


Processing batches: 100%|██████████| 938/938 [00:10<00:00, 87.12it/s, Batch Accuracy=0.312] 


Calculating metrics...

=== Results ===
accuracy: 0.5065
mae: 0.6142
rmse: 0.9729
within_one_accuracy: 0.9112
mean_confidence: 0.3464
correct_confidence: 0.3530
incorrect_confidence: 0.3397

=== Generating Visualizations ===
Generating confusion matrix...
Saved confusion matrix to confusion_matrix.png
Generating error distribution plot...
Saved error distribution to error_dist.png
Visualizing 10 sample predictions...


Plotting samples: 100%|██████████| 10/10 [00:00<00:00, 1043.62it/s]


Saved prediction visualization to sample_predictions.png
