In [3]:
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets
import torchvision.transforms as transforms

In [6]:
class MNISTStatsVisualizer:
    """Store MNIST test datasets and per-epoch training stats, and plot them.

    Stats are lists of per-epoch dicts. Every entry must provide ``'time'``
    (used to build the cumulative ``'total_time'``). The plotting methods
    additionally read ``'loss'``, ``'accuracy'`` and ``'list_of_fails'``
    (a list of dicts with ``'index'``, ``'predic'`` and ``'actual'`` keys)
    where those are present. Stats registered with the type ``"type_2"`` are
    treated as having no accuracy or failed-prediction data.
    """

    # Stats type whose entries carry no accuracy / failed-prediction data.
    _NO_ACCURACY_TYPE = "type_2"

    # One (actual, predicted) colour pair per model in the double histogram.
    _HISTOGRAM_COLOR_PAIRS = (
        ("#1f77b4", "#aec7e8"),
        ("#ff7f0e", "#ffbb78"),
        ("#2ca02c", "#98df8a"),
        ("#d62728", "#ff9896"),
    )

    # Upper bound on the number of columns in the failed-prediction grid.
    _MAX_GRID_COLUMNS = 15

    def __init__(self):
        self.datasets = {}  # Store multiple datasets by key
        self.default_dataset = None  # Key for default dataset
        self.stats_data = {}  # Store stats with cumulative total_time
        self.stats_types = {}  # Keep track of stats format types

    def load_mnist_test_data(self, key, transform=None):
        """Load the MNIST test split and store it under ``key``.

        Args:
            key: Name under which the dataset is stored in ``self.datasets``.
            transform: Transform passed to ``datasets.MNIST``; defaults to a
                ``Compose`` containing only ``ToTensor()``.

        The first dataset loaded also becomes ``self.default_dataset``.
        """
        if transform is None:
            transform = transforms.Compose([transforms.ToTensor()])
        self.datasets[key] = datasets.MNIST(
            root="data", train=False, download=True, transform=transform
        )
        if self.default_dataset is None:
            self.default_dataset = key  # Set first loaded dataset as default

    def add_stats(self, key, stats, stats_type):
        """Register a stats list under ``key`` together with its type.

        Each entry is shallow-copied so the caller's dicts are not modified,
        then a cumulative ``'total_time'`` is added to every copy.

        Args:
            key: Name under which the stats are stored.
            stats: Iterable of per-epoch dicts, each containing ``'time'``.
            stats_type: Format tag; ``"type_2"`` marks stats without
                accuracy / failed-prediction data.

        Raises:
            KeyError: If an entry has no ``'time'`` value.
        """
        copied_entries = [dict(entry) for entry in stats]
        self.stats_data[key] = self._compute_total_time(copied_entries)
        self.stats_types[key] = stats_type

    def _compute_total_time(self, stats):
        """Add a running ``'total_time'`` to each entry in place and return ``stats``."""
        total_time = 0
        for entry in stats:
            total_time += entry['time']
            entry['total_time'] = total_time
        return stats

    def _best_epoch(self, key):
        """Return the entry with the highest ``'accuracy'`` for ``key``.

        Entries without ``'accuracy'`` count as 0. Returns ``None`` when the
        key is unknown or its stats list is empty, so callers never hit
        ``max()`` on an empty sequence.
        """
        stats = self.stats_data.get(key)
        if not stats:
            return None
        return max(stats, key=lambda entry: entry.get('accuracy', 0))

    def plot_accuracy(self, keys=None):
        """Plot accuracy over cumulative time for one or more stats keys.

        Keys that are unknown, are of type ``"type_2"``, or have no entry
        with an ``'accuracy'`` value are skipped with a message. If nothing
        is left to plot, no figure is created.

        Args:
            keys: Stats keys to plot; defaults to ``[self.default_dataset]``.
        """
        keys = keys or [self.default_dataset]

        # Gather everything first so no empty figure is opened needlessly.
        series = []
        for key in keys:
            entries = self.stats_data.get(key)
            if entries is None or self.stats_types.get(key) == self._NO_ACCURACY_TYPE:
                print(f"Skipping accuracy plot for {key}, as it lacks accuracy data.")
                continue

            with_accuracy = [entry for entry in entries if 'accuracy' in entry]
            if not with_accuracy:
                print(f"Skipping accuracy plot for {key}, as it lacks accuracy data.")
                continue

            times = [entry['total_time'] for entry in with_accuracy]
            accuracies = [entry['accuracy'] for entry in with_accuracy]
            series.append((key, times, accuracies))

        if not series:
            print("No accuracy data to plot.")
            return

        plt.figure(figsize=(16, 8))
        title_parts = []
        for key, times, accuracies in series:
            plt.plot(times, accuracies, marker='o', label=f"Accuracy ({key})")
            title_parts.append(f"{key}: Max Acc = {max(accuracies):.3f}")

        plt.xlabel("Total Time (seconds)")
        plt.ylabel("Accuracy")
        plt.title("Accuracy Over Time\n" + " \n ".join(title_parts))
        plt.legend()
        plt.show()

    def plot_loss(self, keys=None):
        """Plot loss over cumulative time and report each minimum in the title.

        Unknown keys and keys without any ``'loss'`` value are skipped with a
        message. If nothing is left to plot, no figure is created.

        Args:
            keys: Stats keys to plot; defaults to ``[self.default_dataset]``.
        """
        keys = keys or [self.default_dataset]

        series = []
        for key in keys:
            entries = self.stats_data.get(key)
            if entries is None:
                print("No stats available for key", key)
                continue

            with_loss = [entry for entry in entries if 'loss' in entry]
            if not with_loss:
                print("No loss data available for key", key)
                continue

            times = [entry['total_time'] for entry in with_loss]
            losses = [entry['loss'] for entry in with_loss]
            series.append((key, times, losses))

        if not series:
            print("No loss data to plot.")
            return

        plt.figure(figsize=(16, 8))
        title_parts = []
        for key, times, losses in series:
            plt.plot(times, losses, marker='o', label=f"Loss ({key})", linestyle='dashed')
            title_parts.append(f"{key}: Min Loss = {min(losses):.3f}")

        plt.xlabel("Total Time (seconds)")
        plt.ylabel("Loss")
        plt.title("Loss Over Time\n" + " \n ".join(title_parts))
        plt.legend()
        plt.show()

    def show_failed_predictions(self, stats_key, dataset_key=None, max_images=1000):
        """Display the failed predictions of the epoch with the highest accuracy.

        Args:
            stats_key: Stats key whose best epoch is inspected.
            dataset_key: Dataset key used to look up images by index;
                defaults to ``self.default_dataset``.
            max_images: Maximum number of failures to draw.

        Prints a message and returns without plotting when the stats or
        dataset key is unknown, the stats are of type ``"type_2"``, or there
        are no failed predictions to show.
        """
        if stats_key not in self.stats_data:
            print("No stats available for key", stats_key)
            return

        if self.stats_types.get(stats_key) == self._NO_ACCURACY_TYPE:
            print("Failed predictions are not available for this stats type.")
            return

        dataset_key = dataset_key or self.default_dataset
        if dataset_key not in self.datasets:
            print("No dataset available for key", dataset_key)
            return

        dataset = self.datasets[dataset_key]

        best_epoch = self._best_epoch(stats_key)
        if best_epoch is None:
            print("No stats available for key", stats_key)
            return

        # Limit number of displayed failures
        failed_predictions = best_epoch.get('list_of_fails', [])[:max_images]
        if not failed_predictions:
            # Without this guard the grid would have zero columns and the
            # row computation below would divide by zero.
            print("No failed predictions to display for key", stats_key)
            return

        num_images = len(failed_predictions)
        cols = min(self._MAX_GRID_COLUMNS, num_images)
        rows = -(-num_images // cols)  # Ceiling division

        # squeeze=False always yields a 2-D array, so a single image needs
        # no special casing before flattening.
        _fig, axes = plt.subplots(
            rows, cols, figsize=(cols * 4, rows * 4.8), squeeze=False
        )
        axes = axes.flatten()

        for ax, fail in zip(axes, failed_predictions):
            image, _ = dataset[fail['index']]
            ax.imshow(image.squeeze(), cmap='gray')
            ax.set_title(f"P: {fail['predic']} / A: {fail['actual']}", fontsize=52)
            ax.axis('off')

        # Hide the unused cells of the last row.
        for ax in axes[num_images:]:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

    def plot_double_histogram(self, stats_keys):
        """Plot per-digit counts of actual vs. predicted labels among failures.

        For each key, the failed predictions of its best-accuracy epoch are
        counted per digit (0-9) and drawn as two adjacent bars. Keys of type
        ``"type_2"`` and keys without stats are skipped with a message. If
        nothing is left to plot, no figure is created.

        Args:
            stats_keys: Sequence of stats keys to compare.
        """
        bar_width = 0.2  # Narrower bars to fit multiple models
        num_models = len(stats_keys)
        indices = np.arange(10)  # Positions for digits 0-9

        title_parts = []  # Store titles with accuracy
        bar_groups = []
        for i, key in enumerate(stats_keys):
            if self.stats_types.get(key) == self._NO_ACCURACY_TYPE:
                print(f"Skipping {key} as it has 'type_2' stats, which do not have failed predictions.")
                continue

            best_epoch = self._best_epoch(key)
            if best_epoch is None:
                print(f"No stats available for key: {key}")
                continue

            accuracy = best_epoch.get('accuracy', 0)
            title_parts.append(f"{key}: {accuracy:.2%}")  # Format as percentage

            failed_predictions = best_epoch.get('list_of_fails', [])

            # Explicit integer arrays keep bincount well-defined even when
            # there are no failures at all.
            actual_values = np.asarray(
                [fail['actual'] for fail in failed_predictions], dtype=np.intp
            )
            predicted_values = np.asarray(
                [fail['predic'] for fail in failed_predictions], dtype=np.intp
            )
            actual_counts = np.bincount(actual_values, minlength=10)
            predicted_counts = np.bincount(predicted_values, minlength=10)

            bar_groups.append((i, key, actual_counts, predicted_counts))

        if not bar_groups:
            print("No failed-prediction data to plot.")
            return

        plt.figure(figsize=(16, 8))  # Enlarged figure for better comparison
        color_pairs = self._HISTOGRAM_COLOR_PAIRS
        for i, key, actual_counts, predicted_counts in bar_groups:
            # Shift each model's pair of bars so models do not overlap.
            offset = (i - num_models / 2) * bar_width * 2
            # Cycle through the palette so more than four models still work.
            actual_color, predicted_color = color_pairs[i % len(color_pairs)]

            plt.bar(indices + offset - bar_width / 2, actual_counts, width=bar_width,
                    label=f"Actual - {key}", color=actual_color, alpha=0.7)
            plt.bar(indices + offset + bar_width / 2, predicted_counts, width=bar_width,
                    label=f"Predicted - {key}", color=predicted_color, alpha=0.7)

        plt.xlabel("Digit")
        plt.ylabel("Frequency")
        plt.xticks(indices, [str(digit) for digit in range(10)])
        plt.legend()
        plt.title("Histogram of Actual vs. Predicted Values\n" + " | ".join(title_parts))
        plt.show()

    def plot_mistake_matrix(self, stats_key):
        """Plot a 10x10 matrix of mistakes (rows: actual, columns: predicted).

        The counts come from the ``'list_of_fails'`` of the best-accuracy
        epoch for ``stats_key``. Prints a message and returns without
        plotting when the key is unknown, has no entries, or is of type
        ``"type_2"``.

        Args:
            stats_key: Stats key whose best epoch is visualised.
        """
        if stats_key not in self.stats_data:
            print(f"No stats available for key: {stats_key}")
            return

        if self.stats_types.get(stats_key) == self._NO_ACCURACY_TYPE:
            print(f"Mistake matrix is not available for stats type 'type_2' ({stats_key})")
            return

        best_epoch = self._best_epoch(stats_key)
        if best_epoch is None:
            print(f"No stats available for key: {stats_key}")
            return
        failed_predictions = best_epoch.get('list_of_fails', [])

        # Count each (actual, predicted) pair among the failures.
        mistake_matrix = np.zeros((10, 10), dtype=int)
        for fail in failed_predictions:
            actual = fail['actual']
            predicted = fail['predic']
            if actual != predicted:
                mistake_matrix[actual, predicted] += 1

        # Plot the matrix as a heatmap
        _fig, ax = plt.subplots(figsize=(8, 6))
        cax = ax.matshow(mistake_matrix, cmap="Blues")
        plt.colorbar(cax)

        ax.set_xlabel("Predicted Label")
        ax.set_ylabel("Actual Label")
        ax.set_title(f"Mistake Matrix for {stats_key} (Best Epoch)")

        digits = np.arange(10)
        ax.set_xticks(digits)
        ax.set_yticks(digits)
        ax.set_xticklabels(digits)
        ax.set_yticklabels(digits)

        # Annotate only the non-zero cells to keep the matrix readable.
        for row, col in zip(*np.nonzero(mistake_matrix)):
            ax.text(col, row, str(mistake_matrix[row, col]),
                    ha='center', va='center', color="black", fontsize=10)

        plt.show()
   