In [210]:
import sys
# Record the interpreter version in the notebook output so results can be tied to the environment that produced them.
print(sys.version)

3.11.9 | packaged by Anaconda, Inc. | (main, Apr 19 2024, 16:40:41) [MSC v.1916 64 bit (AMD64)]


# Helper functions: code that is reused in several places across the project

### Imports

In [211]:
import time
from IPython.display import display, Javascript

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from torchvision import datasets
import torchvision.transforms as transforms
import random

In [212]:
import json
import numpy as np

In [213]:
import os
import stat
from importlib import reload
import nbimporter

### Local imports

In [214]:
from MyModels import create_empty_model
from MyModels import create_loss_function
from MyModels import create_optimizer


In [215]:
import MyModels
# Reload so edits to MyModels.py are picked up without restarting the kernel. Every name
# previously imported with `from MyModels import ...` must be re-bound afterwards; otherwise
# it keeps pointing at the pre-reload function object.
reload(MyModels)
from MyModels import create_empty_model, create_loss_function, create_optimizer

## Testing

In [216]:
def tellALie():
    """
    Smoke test: print a message confirming that helperFunctions.ipynb (the hf library)
    was imported. Its first line is a joke — it still lies to you!
    """
    message_lines = [
        "You have all the time in the world!",
        "Your helperfunctions library has been loaded!",
    ]
    print("\n".join(message_lines))

## Model Manager

In [228]:
class model_manager():
    """
    Model managers gathers everything with testing and training in one class!
    To create a new model manager you need: model_class, loss_function_name ,optimizer_name,optimizer_params. 
    Has default values for: model_parameters = {},  hidden_layers = 0, layer_sizes = [], singel_outputs = False, 
    stats = []
    """
    
    
    def __init__(self, model_class, loss_function_name ,optimizer_name,optimizer_params , 
                 model_parameters = None,  hidden_layers = 0, 
                 layer_sizes = None, singel_outputs = False, stats = None):
        """ 
        Needs: model_class, loss_function_name ,optimizer_name,optimizer_params. 
        Has default values for: model_parameters = {},  hidden_layers = 0, layer_sizes = [], singel_outputs = False,  stats = [] 
        """
        
        # avoid issues with parameters being created at initialization by exlicitly making the None then create emrty mutables
        if model_parameters == None: model_parameters = {}
        if stats == None:            stats = []
        if layer_sizes == None:      layer_sizes = []
        
        
        # Set parameters
        self.device = get_device()
        self.model_class = model_class
        self.model_parameters = model_parameters
        self.loss_function_name = loss_function_name  
        self.optimizer_name = optimizer_name
        self.optimizer_params = optimizer_params
        self.hidden_layers = hidden_layers
        self.layer_sizes = layer_sizes
        self.singel_outputs = singel_outputs
        self.num_epochs = len(stats)
        self.stats = stats
        
        print("parameters stored")
         
        self.model = create_empty_model(self.model_class,  parameters = model_parameters , device = self.device)
        self.loss_function = create_loss_function(self.loss_function_name)
        self.optimizer = create_optimizer(self.model ,self.optimizer_name, self.optimizer_params)
        print(f"model type initialized: {self.model_class}  ") 
        print(f"Optimizer: {self.optimizer_name}") 
        print(f"loss function: {self.loss_function_name}") 
        print(self.model)


      

    def initiate_training(self, epochs,  train_dataloader, test_dataloader):
        """
        This is the main training function. Takes a number of epochs, training dataloader, and test dataloader and initiates the actual trainig.
        When training data is also collected about how the model performs. 
        """
        
        for t in range(epochs): # train for epochs number of iterations
           
            epoch_stats = {"epoch": self.num_epochs +1}
            print(f"Epoch count {self.num_epochs +1}\n----------------------------------")
            start_time = time.time()  # Start timer
            self.train(train_dataloader)
            end_time = time.time()  # End timer
            result_time = end_time - start_time
            print(f"Execution time of epoch: {result_time:.2f} seconds")
            epoch_stats["time"] = result_time

            match self.model_class:
                case "SimpleAutoencoder":
                    epoch_stats["loss"], epoch_stats["test_originals"] , epoch_stats["test_created"] = self.test_autoencoder(test_dataloader)#output 
                case _:
                    epoch_stats["accuracy"] , epoch_stats["loss"] , epoch_stats["list_of_fails"]=  self.test(test_dataloader)
            

            self.num_epochs += 1 # With a succesful training add one more epoch to the total
            self.stats.append(epoch_stats) # Store metadata of the training in the stats member of the class
        print("Done!")
        self.task_complete_alert()
        
        return self.stats #If the person doing the training wants to access the stats they can easily use this return

    def initiate_training_shape(self, epochs,  train_dataloader, test_dataloader):
        """
        This is the main training function. Takes a number of epochs, training dataloader, and test dataloader and initiates the actual trainig.
        When training data is also collected about how the model performs. 
        """
        
        for t in range(epochs): # train for epochs number of iterations
           
            epoch_stats = {"epoch": self.num_epochs +1}
            print(f"Epoch count {self.num_epochs +1}\n----------------------------------")
            start_time = time.time()  # Start timer
            self.train(train_dataloader)
            end_time = time.time()  # End timer
            result_time = end_time - start_time
            print(f"Execution time of epoch: {result_time:.2f} seconds")
            epoch_stats["time"] = result_time

            match self.model_class:
                case "SimpleAutoencoder":
                    epoch_stats["loss"], epoch_stats["test_originals"] , epoch_stats["test_created"] = self.test_autoencoder(test_dataloader)#output 
                case _:
                    epoch_stats["accuracy"] , epoch_stats["loss"] , epoch_stats["list_of_fails"]=  self.test(test_dataloader)
            

            self.num_epochs += 1 # With a succesful training add one more epoch to the total
            self.stats.append(epoch_stats) # Store metadata of the training in the stats member of the class
        print("Done!")
        self.task_complete_alert()
        
        return self.stats #If the person doing the training wants to access the stats they can easily use this return
    
    
    def train(self, train_dataloader):
        
        size = len(train_dataloader.dataset)
        self.model.train() # set model to trainig mode
        
        for batch, (X, y) in enumerate(train_dataloader):
            X, y = X.to(self.device), y.to(self.device) # set the data to the appropriate device

            # Compute prediction error
            pred = self.model(X)
            
            match self.model_class:
                case "SimpleAutoencoder":
                    loss = self.loss_function(pred, X)
                case _:
                    loss = self.loss_function(pred, y)

            # Backpropagation
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
        
            if batch % 100 == 0:
                loss_val, current = loss.item(), (batch + 1) * len(X)
                print(f"loss: {loss_val:>7f}  [{current:>5d}/{size:>5d}]")


    def test(self, test_dataloader):
        list_of_fails = []
        size = len(test_dataloader.dataset)
        num_batches = len(test_dataloader)
        self.model.eval() # set model to evaluation mode
        test_loss, correct = 0, 0
        with torch.no_grad():
            for batch, (X, y) in enumerate(test_dataloader):
                X, y = X.to(self.device), y.to(self.device)
                pred = self.model(X)
                test_loss += self.loss_function(pred, y).item()

                # Get predicted labels to create a mask of successes
                predicted_labels = pred.argmax(1)
                correct_mask = predicted_labels == y
                correct += correct_mask.type(torch.float).sum().item()

                #Produce a list of all fails
                indices_of_failed_pred = (~correct_mask).nonzero(as_tuple=True)[0]
                list_of_fails.extend([
                {
                    "index": batch * len(y) + j.item(),  # Unique index based on batch
                    "predic": predicted_labels[j].item(),
                    "actual": y[j].item(),
                }
                for j in indices_of_failed_pred
                ])
                
        test_loss /= num_batches
        correct /= size
        print(f"Test Error of epoch {self.num_epochs +1}: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
        return [correct, test_loss, list_of_fails]

    def test_autoencoder(self, test_dataloader):
        """
        Special test case for autoencoders. We have to compare the images to images, and can visually inspect how the model progresses.
        """
        size = len(test_dataloader.dataset)
        num_batches = len(test_dataloader)
        self.model.eval() # set model to evaluation mode
        test_loss = 0
        original_images, genereated_images = [] , []
        with torch.no_grad():
            for batch, (X, y) in enumerate(test_dataloader):
                X, y = X.to(self.device), y.to(self.device)
                pred = self.model(X)
                test_loss += self.loss_function(pred, X).item()
            X, y = next(iter(test_dataloader))  # Get the first batch
            X_samples = X[:10]
            for X in X_samples: #Display comparision of original and prediciton of the the first ten images
                pred = self.model(X)
                original_images.append(X) 
                genereated_images.append(pred) 
        
        test_dataloader = iter(test_dataloader) #Reset the dataloader, we want the same comarisons each time
        test_loss /= num_batches
        
        
        
        print(f"Test Error of epoch {self.num_epochs +1}: \n Avg loss: {test_loss:>8f} \n")
        display_comparissions_autoencoder(original_images, genereated_images)
        return [test_loss, original_images, genereated_images]
        
    
    def print_arcitecture(self):
        print (self.model)

    def task_complete_alert(self):
        """Browser notification 'Task completed!'"""
        
        display(Javascript('alert("Task completed!")'))
        return None

    def get_max_accuarcy(self):
        return max(self.stats, key=lambda x: x["accuracy"])


    def print_time_spent(self):
        """
        gives a description of time spent to console and returns the total time spent training
        """
        time = sum(t["time"] for t in self.stats)
        epochs = len(self.stats)
        if (epochs == 0): 
            print("Model has not trained yet!") 
        else:
            minutes, seconds = divmod(time, 60)
            print(f"In {epochs} epochs you have trained for a total of {round(minutes)} minutes and {round(seconds)} seconds!\nAn average of {(time/epochs):>0.2f} seconds per epoch!")
        return time


    def return_time_spent(self):
        return sum(t["time"] for t in self.stats)

## Saving and loading

In [9]:
def save_model_manager(mm, save_token , make_read_only = True , path = "models\\"):
    """
    Save a model manager's model weights, optimizer state, configuration and stats.

    The checkpoint goes to path + mm.model_class + save_token + ".pth" (plain string
    concatenation, so `path` must end with a separator). Unless make_read_only=False,
    the written file is then made read-only with make_file_read_only.
    """
    save_name = path + mm.model_class + save_token + ".pth"
    checkpoint = {
        'model_class': mm.model_class,
        'model_state_dict': mm.model.state_dict(),
        'model_parameters': mm.model_parameters,
        'optimizer_state_dict': mm.optimizer.state_dict(),
        'loss_function_name': mm.loss_function_name,
        'optimizer_name': mm.optimizer_name,
        'optimizer_params': mm.optimizer_params,
        'hidden_layers': mm.hidden_layers,
        'layer_sizes': mm.layer_sizes,
        'singel_outputs': mm.singel_outputs,
        'stats': mm.stats,
    }
    torch.save(checkpoint, save_name)
    if make_read_only:
        make_file_read_only(save_name)

In [10]:
def load_model_manager(save_name , path = "models\\"):
    """
    Rebuild a model_manager from a checkpoint written by save_model_manager.

    Loads path + save_name (plain string concatenation) and constructs a new manager
    from the stored configuration and stats. The saved model and optimizer states are
    then loaded into it. Returns the new manager.
    """
    # map_location="cpu" lets a checkpoint saved on a GPU machine load on any machine
    # (load_stats does the same). load_state_dict below copies the tensors onto the
    # manager's own device.
    loaded_model = torch.load(path + save_name, map_location="cpu")
    NMM = model_manager(model_class = loaded_model['model_class'],  # New model manager
                        loss_function_name = loaded_model['loss_function_name'],
                        optimizer_name = loaded_model['optimizer_name'],
                        optimizer_params = loaded_model['optimizer_params'],
                        model_parameters = loaded_model['model_parameters'],
                        hidden_layers = loaded_model['hidden_layers'],
                        layer_sizes = loaded_model['layer_sizes'],
                        singel_outputs = loaded_model['singel_outputs'],
                        stats = loaded_model['stats'])
    NMM.model.load_state_dict(loaded_model["model_state_dict"])
    NMM.optimizer.load_state_dict(loaded_model['optimizer_state_dict'])
    print("Load complete")
    return NMM
   
# model_class, loss_function_name ,optimizer_name,optimizer_params , model_parameters = None,  
#                  hidden_layers = 0, layer_sizes = None, singel_outputs = False, stats = None

In [11]:
def load_stats(save_name , path = "models\\"):
    """Return only the 'stats' list from the checkpoint at path + save_name, loaded onto the CPU."""
    checkpoint = torch.load(path + save_name, map_location='cpu')
    return checkpoint['stats']

## Data Management

In [12]:
def get_device():
    """
    Pick the best available torch device: "cuda" if a CUDA GPU is present, else "mps"
    (Apple silicon), else "cpu". Prints and returns the device name.

    Adapted from https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html
    """
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print(f"Using {device} device")
    return device

In [13]:
def displayTensorAsImage(tensor):
    """
    Display an image tensor in grayscale with matplotlib.pyplot, with axes hidden.

    Accepts a 2-D tensor, or a 3-D tensor with a single channel either first (1, H, W)
    or last (H, W, 1). The original tensor is not modified. Works for tensors that
    require grad (e.g. model outputs), which matplotlib cannot convert directly.
    """
    # Detach from autograd, copy, and move to CPU so matplotlib can read the data.
    display_data = tensor.detach().clone().to("cpu")

    # Drop a single channel dimension on either side.
    if display_data.dim() == 3 and display_data.size(0) == 1:  # Shape (1, H, W)
        display_data = display_data.squeeze(0)
    elif display_data.dim() == 3 and display_data.size(-1) == 1:  # Shape (H, W, 1)
        display_data = display_data.squeeze(-1)

    plt.imshow(display_data, cmap="gray")
    plt.axis("off")  # Hide axes
    plt.show()
    

In [14]:
def get_MNIST_train_data(_transform):
    """
    Return the MNIST training split (60000 images) with `_transform` applied to each image.

    Downloads the data into ./data on first use via torchvision.datasets.
    """
    mnist_options = dict(root="data", train=True, download=True, transform=_transform)
    return datasets.MNIST(**mnist_options)

def get_MNIST_test_data(_transform):
    """
    Return the MNIST test split (10000 images) with `_transform` applied to each image.

    Downloads the data into ./data on first use via torchvision.datasets.
    """
    mnist_options = dict(root="data", train=False, download=True, transform=_transform)
    return datasets.MNIST(**mnist_options)

In [196]:
def get_raw_MNIST_train_data():
    """
    Return the "raw" MNIST training split for the basic shape-based network.

    Images are only converted with transforms.ToTensor(): no augmentation and no
    normalization (see get_raw_MNIST_transform).
    """
    raw_transform = get_raw_MNIST_transform()
    return get_MNIST_train_data(raw_transform)


def get_raw_MNIST_test_data():
    """
    Return the "raw" MNIST test split for the basic shape-based network.

    Images are only converted with transforms.ToTensor(): no augmentation and no
    normalization (see get_raw_MNIST_transform).
    """
    raw_transform = get_raw_MNIST_transform()
    return get_MNIST_test_data(raw_transform)

In [15]:
def get_standard_MNIST_training_transform():
    """
    Standard classifier training transform:
    random rotation of up to 10 degrees (augmentation) -> tensor -> normalize with the
    known MNIST mean 0.1307 and standard deviation 0.3081.
    """
    steps = [
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
    return transforms.Compose(steps)

In [16]:
def get_standard_MNIST_autoencoder_training_transform():
    """
    Autoencoder training transform: random rotation of up to 10 degrees, then
    conversion to tensor. No normalization is applied.
    """
    steps = [
        transforms.RandomRotation(10),
        transforms.ToTensor(),
    ]
    return transforms.Compose(steps)

In [17]:
def get_standard_MNIST_test_transform():
    """
    Standard classifier test transform: tensor conversion, then normalization with the
    known MNIST mean 0.1307 and standard deviation 0.3081 (no augmentation).
    """
    steps = [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
    return transforms.Compose(steps)

In [18]:
def get_standard_MNIST_autoencoder_test_transform():
    """Autoencoder test transform: only converts images to tensors (no augmentation, no normalization)."""
    steps = [transforms.ToTensor()]
    return transforms.Compose(steps)

In [19]:
def get_raw_MNIST_transform():
    """The "raw" transform: a Compose holding only transforms.ToTensor()."""
    steps = [transforms.ToTensor()]
    return transforms.Compose(steps)

In [20]:
def standard_batch_size_64():
    """The project's standard batch size: always 64."""
    batch_size = 64
    return batch_size

In [21]:
def set_dataloader(data , _batch_size, shuffle = True , collate_fn = None):
    """
    Wrap `data` in a torch.utils.data.DataLoader with batch size `_batch_size`.

    `shuffle` and `collate_fn` are passed straight through. collate_fn=None is
    DataLoader's own default, i.e. its standard collation.
    """
    return torch.utils.data.DataLoader(
        data,
        batch_size=_batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )

In [22]:
def make_file_read_only(filename , path = ""):
    """
    Make the file at path + filename read-only (read permission for owner, group and
    others; all write bits removed) and print a confirmation.

    `path` is prepended with plain string concatenation, so it must end with a
    separator when given.
    """
    target = path + filename
    read_only_mode = stat.S_IREAD | stat.S_IRGRP | stat.S_IROTH
    os.chmod(target, read_only_mode)
    print( target , " has been made readOnly")


In [None]:
def remove_file(filename , path = "" ):
    """
    Delete the file at path + filename, even if it was made read-only.

    The owner read/write bits are set first so that files protected with
    make_file_read_only can be removed. This is best-effort: an OS error (missing file,
    permission denied, ...) is printed rather than raised. `path` is prepended with plain
    string concatenation, so it must end with a separator when given.
    """
    file_path = path + filename

    try:
        os.chmod(file_path, stat.S_IWUSR | stat.S_IREAD)
        os.remove(file_path)
        print(f"Successfully deleted {file_path}")
    except OSError as e:
        # Only filesystem failures are expected here; anything else is a real bug and propagates.
        print(f"Error: {e}")


In [199]:
def getRandomNumber(n):
    """Return a uniformly random integer in [0, n - 1] using the `random` module's global generator."""
    return random.randrange(n)

## Image processing

In [23]:
def classification_3x3_to_shape(grid, threshold=0.2):
    """
    Encode a 3x3 grid as one of its 512 binary configurations (an integer 0..511).

    Each cell becomes 1 if its value is >= threshold, else 0. The cells are read in
    row-major order as the bits of a binary number, most significant bit first.
    """
    is_on = np.asarray(grid).ravel() >= threshold
    bit_string = "".join("1" if on else "0" for on in is_on)
    return int(bit_string, 2)


In [None]:
def collate_fn_padded(batch):
    """
    Collate (features, label) pairs whose feature tensors differ in length (dim 0).

    Each feature tensor is zero-padded at the end of dim 0 to the longest one in the
    batch and the results are stacked into shape (batch, max_len, *trailing_dims). Any
    trailing feature width is supported, not just 3, and padding keeps the features'
    dtype and device. Labels are returned as a long tensor, in batch order.
    """
    from torch.nn.utils.rnn import pad_sequence

    features, labels = zip(*batch)  # Separate features and labels
    padded_features = pad_sequence(list(features), batch_first=True, padding_value=0.0)
    return padded_features, torch.tensor(labels, dtype=torch.long)

In [24]:
def deault_threshold_value_0_2():
    """The default threshold value for 3x3 grid classification: 0.2."""
    threshold = 0.2
    return threshold

## Visualisations and graphs

In [25]:
def display_image(data):
    """
    Show a single dataset item in grayscale with axes hidden.

    Pass the item itself, e.g. dataset[image_number]. Only its first element (the
    image) is used, after squeezing away a leading size-1 dimension.
    """
    image = data[0].squeeze(0)
    plt.imshow(image, cmap="gray")
    plt.axis("off")  # Hides axes
    plt.show()

In [26]:
def display_comparissions_autoencoder(original, reconstructed):
    """
    Show autoencoder inputs (top row) above their reconstructions (bottom row).

    Parameters
    ----------
    original : sequence of tensors
        Input images; each is detached, moved to CPU and squeezed before plotting.
    reconstructed : sequence of tensors
        Model outputs, paired index-by-index with `original`. A reconstruction that is
        1-D after squeezing is reshaped to its original's squeezed shape when the element
        counts match, so flat autoencoder outputs can still be drawn as images.

    One column is drawn per image (previously 10 columns were always drawn, which failed
    for more than 10 images). Nothing is plotted for an empty input.
    """
    num_images = len(original)
    if num_images == 0:
        print("No images to compare.")
        return

    # squeeze=False keeps `axes` 2-D even when there is only one column.
    fig, axes = plt.subplots(2, num_images, figsize=(num_images, 2), squeeze=False)
    for i in range(num_images):
        original_image = original[i].detach().cpu().squeeze()
        reconstructed_image = reconstructed[i].detach().cpu().squeeze()
        if reconstructed_image.dim() == 1 and reconstructed_image.numel() == original_image.numel():
            reconstructed_image = reconstructed_image.reshape(original_image.shape)
        axes[0, i].imshow(original_image, cmap="gray")
        axes[1, i].imshow(reconstructed_image, cmap="gray")
        axes[0, i].axis("off")
        axes[1, i].axis("off")
    plt.show()

In [27]:
def failureMatrixMNIST(listOfFails):
    """
    Count failed MNIST predictions in a 10x10 matrix and print it, one row per line.

    Elements of `listOfFails` are indexed positionally: elem[1] must be a tensor
    (converted with `.int()`) and selects the row; elem[2] selects the column.
    NOTE(review): which of elem[1]/elem[2] is the predicted vs. the actual digit is not
    visible here. This expects tuple-like entries, not the dicts produced by
    model_manager.test — confirm the intended input format.
    """
    counts = [[0] * 10 for _ in range(10)]
    for fail in listOfFails:
        row_index = fail[1].int()
        counts[row_index][fail[2]] += 1
    print("\n".join(str(row) for row in counts))

In [28]:
def task_complete_alert(self=None):
    """
    Browser notification 'Task completed!' (a JavaScript alert in the notebook front end).

    `self` is unused. This module-level copy of model_manager.task_complete_alert
    previously *required* it, so calling task_complete_alert() raised TypeError. It is
    now optional, so existing calls that pass an argument still work.
    """
    display(Javascript('alert("Task completed!")'))
    return None


In [209]:

def visualize_reconstruction(model, dataset, device=None, num_images=10):
    """
    Plot the first `num_images` images of `dataset` (top row) above the model's
    reconstructions (bottom row).

    Parameters
    ----------
    model : torch.nn.Module
        Autoencoder that takes flattened images (batch, -1). Its output must reshape to
        the images' last two dimensions (H, W).
    dataset :
        Dataset yielding (image, label) pairs; only the first batch is used.
    device : optional
        Where to run the model. Defaults to the device of the model's first parameter
        (CPU for parameter-less models). The previous version read an undefined global
        `device`, which raised NameError.
    num_images : int, optional
        How many images to show (default 10). Fewer are shown if the dataset is smaller.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    model.eval()
    sample_images, _ = next(iter(torch.utils.data.DataLoader(dataset, batch_size=num_images)))
    sample_images = sample_images.to(device)

    with torch.no_grad():
        reconstructed = model(sample_images.view(sample_images.size(0), -1)).cpu()

    shown = sample_images.size(0)
    image_hw = tuple(sample_images.shape[-2:])  # (28, 28) for MNIST
    # squeeze=False keeps `axes` 2-D even for a single column.
    fig, axes = plt.subplots(2, shown, figsize=(shown, 2), squeeze=False)
    for i in range(shown):
        axes[0, i].imshow(sample_images[i].cpu().squeeze(), cmap="gray")
        axes[1, i].imshow(reconstructed[i].view(image_hw), cmap="gray")
        axes[0, i].axis("off")
        axes[1, i].axis("off")
    plt.show()

# # Show original vs reconstructed images
# visualize_reconstruction(autoencoder, train_dataset)

In [206]:
def save_JSON_stats(fileName, path = "stats/", stats = None):
    """
    Write a stats list to path + fileName as indented JSON, then make the file read-only.

    Parameters
    ----------
    fileName : str
        File name, appended to `path` with plain string concatenation.
    path : str, optional
        Directory prefix (default "stats/"); must end with a separator.
    stats : optional
        The JSON-serialisable data to save. When omitted, falls back to a module-level
        `stats` variable, as the original version did (it had no parameter and raised
        NameError when no such global existed).

    Raises
    ------
    ValueError
        If `stats` is not given and no module-level `stats` exists.
    """
    if stats is None:
        stats = globals().get("stats")
        if stats is None:
            raise ValueError("save_JSON_stats needs a 'stats' argument (no module-level 'stats' is defined)")
    with open(path + fileName, "w") as file:
        json.dump(stats, file, indent=2)
    make_file_read_only(path + fileName)

In [208]:
def load_JSON_stats(fileName, path = "stats/" ):
    """Read and return the JSON document stored at path + fileName (plain string concatenation)."""
    with open( path + fileName , "r") as stats_file:
        return json.load(stats_file)


In [37]:

class MNISTStatsVisualizer:
    """
    Collects MNIST test datasets and per-epoch training stats, and plots comparisons of them.

    Stats are lists of per-epoch dicts (the format model_manager appends). The keys read
    here are 'time' (required by add_stats), 'accuracy', 'loss' and 'list_of_fails'. A
    stats set registered with stats_type "type_2" is treated as having no accuracy or
    failed-prediction data.
    """

    def __init__(self):
        self.datasets = {}  # Store multiple datasets by key
        self.default_dataset = None  # Key of the first loaded dataset; used when no key is given
        self.stats_data = {}  # Stats copies, each entry extended with a cumulative 'total_time'
        self.stats_types = {}  # Stats format type per key (e.g. "type_2")

    def load_mnist_test_data(self, key, transform=None):
        """Load the MNIST test split (downloaded into ./data if needed) and store it under `key`.

        The transform defaults to ToTensor only. The first dataset loaded becomes the default."""
        if transform is None:
            transform = transforms.Compose([transforms.ToTensor()])
        self.datasets[key] = datasets.MNIST(root="data", train=False, download=True, transform=transform)
        if self.default_dataset is None:
            self.default_dataset = key  # Set first loaded dataset as default

    def add_stats(self, key, stats, stats_type):
        """Store a copy of `stats` under `key`, with cumulative 'total_time' added, plus its type.

        Entries are shallow-copied, so the caller's dicts are not modified."""
        standardized_stats = [{**entry} for entry in stats]
        self.stats_data[key] = self._compute_total_time(standardized_stats)
        self.stats_types[key] = stats_type

    def _compute_total_time(self, stats):
        """Add a running sum of 'time' to each entry (in place) as 'total_time'; return the list."""
        total_time = 0
        for entry in stats:
            total_time += entry['time']
            entry['total_time'] = total_time
        return stats

    def _best_epoch(self, stats_key):
        """Return the epoch dict with the highest 'accuracy' (missing counts as 0), or None if there are no epochs."""
        stats = self.stats_data.get(stats_key, [])
        if not stats:
            return None
        return max(stats, key=lambda x: x.get('accuracy', 0))

    def plot_accuracy(self, keys=None):
        """Plot accuracy against cumulative training time for each stats key.

        `keys` defaults to [self.default_dataset]. Unknown keys, "type_2" stats and stats
        without any 'accuracy' entries are skipped with a message. The title lists each
        plotted key's maximum accuracy."""
        keys = keys or [self.default_dataset]
        title_parts = []

        plt.figure(figsize=(16, 8))
        for key in keys:
            if key not in self.stats_data or self.stats_types.get(key) == "type_2":
                print(f"Skipping accuracy plot for {key}, as it lacks accuracy data.")
                continue

            entries = [entry for entry in self.stats_data[key] if 'accuracy' in entry]
            if not entries:
                # max() below would fail on an empty list.
                print(f"Skipping accuracy plot for {key}, as it lacks accuracy data.")
                continue

            times = [entry['total_time'] for entry in entries]
            accuracies = [entry['accuracy'] for entry in entries]
            plt.plot(times, accuracies, marker='o', label=f"Accuracy ({key})")
            title_parts.append(f"{key}: Max Acc = {max(accuracies):.3f}")

        plt.xlabel("Total Time (seconds)")
        plt.ylabel("Accuracy")
        plt.title("Accuracy Over Time\n" + " | ".join(title_parts))
        plt.legend()
        plt.show()

    def plot_loss(self, keys=None):
        """Plot loss against cumulative training time for each stats key.

        `keys` defaults to [self.default_dataset]. Unknown or empty stats are skipped with
        a message. The title lists each plotted key's lowest loss."""
        keys = keys or [self.default_dataset]

        plt.figure(figsize=(16, 8))
        title_parts = []  # To store the loss information for the title

        for key in keys:
            if not self.stats_data.get(key):
                print("No stats available for key", key)
                continue

            times = [entry['total_time'] for entry in self.stats_data[key]]
            losses = [entry['loss'] for entry in self.stats_data[key]]
            plt.plot(times, losses, marker='o', label=f"Loss ({key})", linestyle='dashed')
            title_parts.append(f"{key}: Min Loss = {min(losses):.3f}")

        plt.xlabel("Total Time (seconds)")
        plt.ylabel("Loss")
        plt.title("Loss Over Time\n" + " | ".join(title_parts))  # Add the loss details to the title
        plt.legend()
        plt.show()

    def show_failed_predictions(self, stats_key, dataset_key=None, max_images=1000):
        """Display up to `max_images` failed predictions from the epoch with the highest accuracy.

        Each failure's 'index' is looked up in the dataset (default: self.default_dataset),
        so the stats must come from an unshuffled pass over that same dataset. Nothing is
        plotted, and a message is printed, if the key is unknown, the stats are "type_2",
        the dataset is missing, or there are no failures."""
        if stats_key not in self.stats_data:
            print("No stats available for key", stats_key)
            return

        if self.stats_types.get(stats_key) == "type_2":
            print("Failed predictions are not available for this stats type.")
            return

        dataset_key = dataset_key or self.default_dataset
        if dataset_key not in self.datasets:
            print("No dataset available for key", dataset_key)
            return

        dataset = self.datasets[dataset_key]

        best_epoch = self._best_epoch(stats_key) or {}
        failed_predictions = best_epoch.get('list_of_fails', [])[:max_images]

        num_images = len(failed_predictions)
        if num_images == 0:
            # Avoids the division by zero the grid computation below would hit.
            print("No failed predictions to display for key", stats_key)
            return

        cols = min(15, num_images)  # At most 15 images per row
        rows = -(-num_images // cols)  # Ceiling division

        # squeeze=False always yields a 2-D array of Axes, even for a single image.
        fig, axes = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4.8), squeeze=False)
        axes = axes.flatten()

        for ax, fail in zip(axes, failed_predictions):
            image, _ = dataset[fail['index']]
            ax.imshow(image.squeeze(), cmap='gray')
            ax.set_title(f"P: {fail['predic']} / A: {fail['actual']}", fontsize=52)  # Adjusted font size
            ax.axis('off')

        for ax in axes[num_images:]:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

    def plot_double_histogram(self, stats_keys):
        """Bar chart of how often each digit appears as the actual and as the predicted label
        among the failed predictions of each key's best epoch.

        "type_2" stats and unknown/empty keys are skipped. The title shows each key's best
        accuracy."""
        plt.figure(figsize=(16, 8))  # Enlarged figure for better comparison
        bar_width = 0.2  # Narrower bars to fit multiple models
        num_models = len(stats_keys)
        indices = np.arange(10)  # Positions for digits 0-9

        title_parts = []  # Store titles with accuracy
        # Pairs of (actual, predicted) colors; cycled when there are more than 4 models.
        colors = [  "#1f77b4", "#aec7e8",
                    "#ff7f0e", "#ffbb78",
                    "#2ca02c", "#98df8a",
                    "#d62728", "#ff9896",
                ]

        for i, key in enumerate(stats_keys):
            # The rest of the class uses "type_2"; the original check compared against 2 only.
            if self.stats_types.get(key) in ("type_2", 2):
                print(f"Skipping {key} as it has 'type_2' stats, which do not have failed predictions.")
                continue

            best_epoch = self._best_epoch(key)
            if best_epoch is None:
                print("No stats available for key", key)
                continue

            accuracy = best_epoch.get('accuracy', 0)
            title_parts.append(f"{key}: {accuracy:.2%}")  # Format as percentage

            failed_predictions = best_epoch.get('list_of_fails', [])
            # dtype=int keeps bincount valid even when there are no failures.
            actual_values = np.asarray([fail['actual'] for fail in failed_predictions], dtype=int)
            predicted_values = np.asarray([fail['predic'] for fail in failed_predictions], dtype=int)

            actual_counts = np.bincount(actual_values, minlength=10)
            predicted_counts = np.bincount(predicted_values, minlength=10)

            # Offset each model's bars to the right to separate them
            offset = (i - num_models / 2) * bar_width * 2
            actual_color = colors[(2 * i) % len(colors)]
            predicted_color = colors[(2 * i + 1) % len(colors)]

            plt.bar(indices + offset - bar_width / 2, actual_counts, width=bar_width, label=f"Actual - {key}", color=actual_color, alpha=0.7)
            plt.bar(indices + offset + bar_width / 2, predicted_counts, width=bar_width, label=f"Predicted - {key}", color=predicted_color, alpha=0.7)

        plt.xlabel("Digit")
        plt.ylabel("Frequency")
        plt.xticks(indices, [str(i) for i in range(10)])  # Label each bar with its digit
        plt.legend()
        plt.title("Histogram of Actual vs. Predicted Values\n" + " | ".join(title_parts))  # Show accuracy in title
        plt.show()

    def plot_mistake_matrix(self, stats_key):
        """Plot a 10x10 heatmap of (actual, predicted) counts for the best epoch's failed predictions."""
        if stats_key not in self.stats_data:
            print(f"No stats available for key: {stats_key}")
            return

        if self.stats_types.get(stats_key) == "type_2":
            print(f"Mistake matrix is not available for stats type 'type_2' ({stats_key})")
            return

        best_epoch = self._best_epoch(stats_key)
        if best_epoch is None:
            print(f"No stats available for key: {stats_key}")
            return
        failed_predictions = best_epoch.get('list_of_fails', [])

        # Rows: actual digit, columns: predicted digit
        mistake_matrix = np.zeros((10, 10), dtype=int)
        for fail in failed_predictions:
            actual = fail['actual']
            predicted = fail['predic']
            if actual != predicted:
                mistake_matrix[actual, predicted] += 1

        fig, ax = plt.subplots(figsize=(8, 6))
        cax = ax.matshow(mistake_matrix, cmap="Blues")
        plt.colorbar(cax)

        ax.set_xlabel("Predicted Label")
        ax.set_ylabel("Actual Label")
        ax.set_title(f"Mistake Matrix for {stats_key} (Best Epoch)")

        ax.set_xticks(np.arange(10))
        ax.set_yticks(np.arange(10))
        ax.set_xticklabels(np.arange(10))
        ax.set_yticklabels(np.arange(10))

        # Write the count inside every non-zero cell
        for i in range(10):
            for j in range(10):
                value = mistake_matrix[i, j]
                if value > 0:
                    ax.text(j, i, str(value), ha='center', va='center', color="black", fontsize=10)

        plt.show()

In [200]:
def CONTRAST_THRESHOLD():
    """Return the contrast threshold constant, 0.5."""
    threshold = 0.5
    return threshold
    
def features_to_image(features, original_size=(28, 28)):
    """
    Rebuild an image by writing each feature's contrast value at its pixel position.

    This inverts the position normalization used by extract_contrast_features. A
    normalized coordinate p maps back to pixel round(p * c + c), where c = (size - 1) / 2.

    Args:
        features: tensor or sequence of rows whose first three entries are
            (contrast, normalized x, normalized y). Any extra trailing columns,
            such as a 4th "scaler" column, are ignored. Rows may be tensors or
            plain Python numbers.
        original_size: (height, width) of the output image.

    Returns:
        torch.Tensor of shape ``original_size``. Pixels without a feature are 0.
        Positions outside the image are clamped to the nearest border pixel.
        When several features map to the same pixel, the last one wins.
    """
    if len(features) == 0:
        return torch.zeros(original_size)

    height, width = original_size
    center_x, center_y = (width - 1) / 2, (height - 1) / 2

    image = torch.zeros(original_size)

    for feature in features:
        # Index instead of unpacking so rows with extra columns are accepted
        contrast, pos_x, pos_y = feature[0], feature[1], feature[2]

        # as_tensor lets plain floats use the same rounding as tensor rows
        x = int(torch.round(torch.as_tensor(pos_x) * center_x + center_x))
        y = int(torch.round(torch.as_tensor(pos_y) * center_y + center_y))

        # Clamp to image bounds
        x = min(max(x, 0), width - 1)
        y = min(max(y, 0), height - 1)

        image[y, x] = contrast

    return image





def extract_contrast_features(image, threshold=0.5):
    """
    Extract high-contrast neighbouring pixel pairs from a single-channel image.

    Four kinds of neighbour pairs are checked: right, below, down-right diagonal and
    down-left diagonal. Each pair whose absolute intensity difference is strictly greater
    than ``threshold`` yields one feature row ``[contrast, x, y]``. Here x and y are
    normalized with (p - c) / c, where c = (size - 1) / 2, so that they span [-1, 1].
    The reported pixel is the left pixel for horizontal pairs and the upper pixel for
    vertical and both diagonal pairs. Rows are ordered by pair kind in that order:
    right, below, down-right, down-left.

    Args:
        image: tensor of shape (H, W), (1, H, W), or with further leading singleton
            dimensions. Integer images are cast to float first so that pixel differences
            cannot wrap around. An image whose maximum exceeds 1 is assumed to be 0-255
            and is divided by 255.
        threshold: exclusive minimum contrast, must be in [0, 1].

    Returns:
        Float tensor of shape (N, 3). It has shape (0, 3) when the image is smaller than
        2x2 or no pair exceeds the threshold.

    Raises:
        ValueError: if ``threshold`` is outside [0, 1], or if the image has a non-singleton
            leading (channel) dimension.
    """
    if not 0 <= threshold <= 1:
        raise ValueError(f"Threshold must be in [0,1], got {threshold}")

    if image.dim() < 2 or any(d != 1 for d in image.shape[:-2]):
        raise ValueError(f"Expected a single-channel image (H, W) or (1, H, W), got shape {tuple(image.shape)}")
    img = image.reshape(image.shape[-2:])

    # Check the size before touching the data (max() fails on empty tensors)
    h, w = img.shape
    if h < 2 or w < 2:
        return torch.zeros((0, 3), dtype=torch.float32)

    # Subtracting unsigned integer pixels would wrap around (0 - 1 -> 255), so use floats
    if not img.is_floating_point():
        img = img.float()
    if img.max() > 1:
        img = img.float() / 255.0

    center_x, center_y = (w - 1) / 2, (h - 1) / 2

    # (difference map, x offset from the map index to the reported pixel)
    neighbour_diffs = [
        (torch.abs(img[:, :-1] - img[:, 1:]), 0),     # right neighbour
        (torch.abs(img[:-1, :] - img[1:, :]), 0),     # neighbour below
        (torch.abs(img[:-1, :-1] - img[1:, 1:]), 0),  # down-right diagonal
        (torch.abs(img[:-1, 1:] - img[1:, :-1]), 1),  # down-left diagonal (index x is one left of the upper pixel)
    ]

    feature_list = []
    for diff, shift_x in neighbour_diffs:
        y_coords, x_coords = (diff > threshold).nonzero(as_tuple=True)
        if y_coords.numel() > 0:
            feature_list.append(torch.stack([
                diff[y_coords, x_coords],
                ((x_coords + shift_x - center_x) / center_x).to(diff.dtype),
                ((y_coords - center_y) / center_y).to(diff.dtype),
            ], dim=1))

    if not feature_list:
        return torch.zeros((0, 3), dtype=torch.float32)

    return torch.cat(feature_list, dim=0)

# Function to find lines (including diagonal and anti-diagonal) and add scalers
def _step_direction(dx, dy):
    """Classify the pixel step (dx, dy) between two consecutive features; None means not adjacent."""
    if dx == 0 and dy == 0:
        return 'same'
    if max(abs(dx), abs(dy)) != 1:
        return None
    if dy == 0:
        return 'horizontal'
    if dx == 0:
        return 'vertical'
    # Both components are +-1 here: equal signs -> diagonal, opposite signs -> anti-diagonal
    return 'diagonal' if dx == dy else 'anti-diagonal'


def find_lines_and_scalers(features, original_size=(28, 28)):
    """
    Group contrast features into straight line segments and summarise each segment.

    Features are sorted as whole rows by pixel position (x, then y) and scanned in order.
    A feature joins the current line in two cases:
      - it lies on the same pixel as the previous feature, or
      - it is one pixel away in the line's direction (horizontal, vertical, diagonal or
        anti-diagonal). The first such step fixes the direction of a new line.
    Any other step closes the current line and starts a new one. Only consecutive
    features in the sort order are linked, so this is a heuristic grouping.

    Args:
        features: tensor of shape (N, >=3) with rows (contrast, normalized x, normalized y),
            using the [-1, 1] normalization of extract_contrast_features.
        original_size: (height, width) of the source image, used to recover pixel positions.

    Returns:
        Float tensor of shape (L, 4), one row per line:
        [mean contrast, mean normalized x, mean normalized y, scaler], where
        scaler = min(1, number of features in the line / original_size[0]).
        The result has shape (0, 4) when there are no features.
    """
    if len(features) == 0:
        return torch.zeros((0, 4), dtype=torch.float32)

    features = torch.as_tensor(features, dtype=torch.float32)
    height, width = original_size
    center_x, center_y = (width - 1) / 2, (height - 1) / 2

    # Measure adjacency in whole pixels, not in normalized [-1, 1] units
    pixel_x = torch.round(features[:, 1] * center_x + center_x).long().tolist()
    pixel_y = torch.round(features[:, 2] * center_y + center_y).long().tolist()

    # Sort whole rows by (x, y); sorting each column separately would break rows apart
    order = sorted(range(len(features)), key=lambda i: (pixel_x[i], pixel_y[i]))

    lines = []
    current_line = []
    line_type = None  # None until the line's first real step fixes its direction
    last_position = None

    for i in order:
        position = (pixel_x[i], pixel_y[i])
        step = None
        if last_position is not None:
            step = _step_direction(position[0] - last_position[0], position[1] - last_position[1])

        if step == 'same' or (step is not None and line_type in (None, step)):
            current_line.append(features[i])
            if step != 'same':
                line_type = step
        else:
            # Direction change or gap: close the current line and start a new one here
            if current_line:
                lines.append(current_line)
            current_line = [features[i]]
            line_type = None

        last_position = position

    if current_line:
        lines.append(current_line)

    new_features = []
    for line in lines:
        line_tensor = torch.stack(line)
        summary = line_tensor[:, :3].mean(dim=0)  # mean contrast, x, y
        scaler = min(1.0, len(line) / original_size[0])  # longer lines -> larger, capped at 1
        new_features.append(torch.cat([summary, torch.tensor([scaler], dtype=torch.float32)]))

    return torch.stack(new_features)

# Updated dataset converter that includes labels
def convert_and_optimize_dataset(dataset, optimize=True):
    """
    Turn an (image, label) dataset into a list of (contrast features, label) pairs.

    Each image is passed through extract_contrast_features. When ``optimize`` is truthy,
    the resulting features are also grouped by find_lines_and_scalers, which is
    experimental.

    Args:
        dataset: iterable of (image, label) pairs, e.g. a torchvision MNIST dataset.
        optimize: whether to apply find_lines_and_scalers to the extracted features.

    Returns:
        list of (feature_tensor, label) tuples, in dataset order.
    """
    def _features_for(image):
        # Raw contrast features, optionally grouped into lines
        raw_features = extract_contrast_features(image)
        return find_lines_and_scalers(raw_features) if optimize else raw_features

    return [(_features_for(image), label) for image, label in dataset]


def compare_images(img,  reconstructed , text):
    """
    Show an original image and its reconstruction side by side in grayscale.

    Args:
        img: original image. It is squeezed before display, so singleton
            dimensions such as a leading channel of size 1 are dropped.
        reconstructed: image shown in the right-hand panel.
        text: string appended to the left panel's title.
    """
    plt.figure(figsize=(10, 5))
    panels = [
        ("Original Image" + text, img.squeeze()),
        ("Reconstructed from Contrast Features", reconstructed),
    ]
    for position, (title, picture) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(title)
        plt.imshow(picture, cmap='gray')
    plt.show()

    