In [1]:
import sys
sys.path.append('../.')


from utils import *
from utils_plotting import *
from Data.DataGenerator import *
from Models.Models_normal import *
import torch.optim as optim
from Training.Analysis import fixed_model_batch_analysis
from Data.DataLoader import *
from Training.Spike_loss import *
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec


from sklearn.cluster import DBSCAN, SpectralClustering, AgglomerativeClustering, BisectingKMeans
from sklearn.linear_model import RANSACRegressor
from tqdm import tqdm
import numpy as np
import itertools
from sklearn.linear_model import RANSACRegressor


In [2]:
# Load the model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
arch = (784, [256, 128, 128, 128, 64, 64, 64, 64, 64, 64, 32, 10])

model = MNIST_classifier(n_in=arch[0], layer_list=arch[1], bias=0)
# model = MNIST_classifier(n_in=784, layer_list=[128, 128, 128, 64, 64, 64, 32, 16, 8, 3], bias=0)
# model = MNIST_classifier(n_in=784, layer_list=[128, 128, 128, 64, 64, 64, 64, 64, 64, 32, 10], bias=0)

state_dict = torch.load('/home/mila/m/mehrab.hamidi/scratch/training_res/november_res/mnist/normal/bias_0.0001/mnist_training/try_num2/epoch_120/model.pt', weights_only=False)
model.load_state_dict(state_dict)

model.eval()


# load the data
_, _, _, train_samples, train_labels, val_samples, val_labels, test_samples, test_labels = get_mnist_data_loaders()



dataset_target_samples = train_samples
dataset_target_labels = train_labels

Train Loader Batch Shapes:
Batch 1: Images Shape = torch.Size([64, 1, 28, 28]), Labels Shape = torch.Size([64])

Validation Loader Batch Shapes:
Batch 1: Images Shape = torch.Size([64, 1, 28, 28]), Labels Shape = torch.Size([64])

Test Loader Batch Shapes:
Batch 1: Images Shape = torch.Size([64, 1, 28, 28]), Labels Shape = torch.Size([64])

Train Samples Shape: torch.Size([50000, 784])
Train Labels Shape: torch.Size([50000])

Validation Samples Shape: torch.Size([10000, 784])
Validation Labels Shape: torch.Size([10000])

Test Samples Shape: torch.Size([10000, 784])
Test Labels Shape: torch.Size([10000])

Train Label Frequencies: Counter({1: 5678, 7: 5175, 3: 5101, 9: 4988, 2: 4968, 6: 4951, 0: 4932, 4: 4859, 8: 4842, 5: 4506})
Validation Label Frequencies: Counter({7: 1090, 1: 1064, 3: 1030, 8: 1009, 0: 991, 2: 990, 4: 983, 6: 967, 9: 961, 5: 915})
Test Label Frequencies: Counter({1: 1135, 2: 1032, 7: 1028, 3: 1010, 9: 1009, 4: 982, 0: 980, 8: 974, 6: 958, 5: 892})


In [26]:
anal_path = '../../spike_analysis/spikes_hamming_new25/'
if not os.path.isdir(anal_path):
    os.makedirs(anal_path)
results_dict = fixed_model_batch_analysis(model, dataset_target_samples, dataset_target_labels, device, '{}_{}'.format(anal_path, 'val_'), 'analyze', plotting=False)
plt.close()
print("---")

---


In [28]:
# Activation pattern analysis
def analyze_activation_patterns(relu_outputs):
    binary_matrices = [(layer_output > 0).int() for layer_output in relu_outputs]
    activation_patterns = torch.cat(binary_matrices, dim=1)  # Shape: (num_data_points, total_neurons)
    unique_patterns, counts = torch.unique(activation_patterns, dim=0, return_counts=True)
    num_unique_patterns = unique_patterns.size(0)
    hamming_distances = torch.cdist(unique_patterns.float(), unique_patterns.float(), p=0).cpu().numpy()
    hamming_distances_flat = hamming_distances[np.triu_indices(len(unique_patterns), k=1)]
    return unique_patterns, counts, hamming_distances_flat

# Class and spike analysis
def analyze_all_classes_and_spikes(model, mnist_data, mnist_labels, device, anal_path):
    num_classes = len(np.unique(mnist_labels))
    all_classes_hamming_distances = []
    all_spikes_hamming_distances = {}

    for class_idx in tqdm(range(num_classes)):
        # Filter data for the current class
        class_data_indices = np.where(mnist_labels == class_idx)[0]
        results_dict_class = fixed_model_batch_analysis(
            model, torch.Tensor(mnist_data[class_data_indices]), mnist_labels[class_data_indices], device,
            '{}_{}'.format(anal_path, 'train_'), 'analyze', plotting=False
        )
        relu_outputs_class = results_dict_class['representations']
        _, _, class_hamming_distances = analyze_activation_patterns(relu_outputs_class[1:])
        all_classes_hamming_distances.append(class_hamming_distances)

        # Detect spikes (hyperplanes) in the class data
        points = torch.Tensor(np.array(results_dict_class['pca_2'][7]).transpose())
        detected_hyperplanes, total_error, assigned_points = spike_detection_nd(points)
        assignments, _, _ = assign_points_to_hyperplanes(points, detected_hyperplanes)

        all_spikes_hamming_distances[class_idx] = []

        # Analyze each spike
        sorted_hyperplanes_indices = np.argsort([np.sum(assignments.numpy() == i) for i in range(len(detected_hyperplanes))])[::-1]
        for spike_idx in sorted_hyperplanes_indices:
            spike_data_indices = class_data_indices[assignments == spike_idx]
            results_dict_spike = fixed_model_batch_analysis(
                model, torch.Tensor(mnist_data[spike_data_indices]), mnist_labels[spike_data_indices], device,
                '{}_{}'.format(anal_path, 'train_'), 'analyze', plotting=False
            )
            relu_outputs_spike = results_dict_spike['representations']
            _, _, spike_hamming_distances = analyze_activation_patterns(relu_outputs_spike[1:])
            all_spikes_hamming_distances[class_idx].append(spike_hamming_distances)

    return all_classes_hamming_distances, all_spikes_hamming_distances

def plot_hamming_distance_analysis(all_classes_hamming_distances, all_spikes_hamming_distances, anal_path):
    num_classes = len(all_classes_hamming_distances)
    fig = plt.figure(figsize=(36, 50))
    outer_grid = fig.add_gridspec(num_classes, 11, width_ratios=[1, 1] + [1] * 9)  # Adjust grid size if necessary

    for class_idx in range(num_classes):
        # Class hamming distances
        class_box_ax = fig.add_subplot(outer_grid[class_idx, 1])
        class_box_ax.boxplot(all_classes_hamming_distances[class_idx])
        class_box_ax.set_title(f'Class {class_idx}')
        class_box_ax.set_ylabel('Hamming Distance')

        class_mean = np.mean(all_classes_hamming_distances[class_idx])

        # Spike hamming distances
        max_spikes_to_plot = 9  # Adjust based on the grid size
        for spike_idx, spike_hamming_distances in enumerate(all_spikes_hamming_distances[class_idx][:max_spikes_to_plot]):
            spike_box_ax = fig.add_subplot(outer_grid[class_idx, spike_idx + 2])
            spike_box_ax.boxplot(spike_hamming_distances)
            spike_box_ax.axhline(class_mean, color='gray', linestyle='--', alpha=0.7)
            spike_box_ax.set_title(f'Spike {spike_idx}')
            spike_box_ax.set_xlabel('Spike')
            spike_box_ax.set_ylabel('Hamming Distance')

    plt.tight_layout()
    plt.savefig(f'{anal_path}/boxplot_all_layers.png')
    plt.close()


# Final Integration
def run_analysis(model, dataset_target_samples, dataset_target_labels, device, anal_path):
    mnist_labels = dataset_target_labels.detach().cpu().numpy()
    mnist_data = dataset_target_samples.detach().cpu().numpy().reshape(dataset_target_samples.shape[0], 28 * 28)
    
    all_classes_hamming_distances, all_spikes_hamming_distances = analyze_all_classes_and_spikes(
        model, mnist_data, mnist_labels, device, anal_path
    )

    return all_classes_hamming_distances, all_spikes_hamming_distances


def plot_hamming_distance_analysis_with_scatter(all_classes_hamming_distances, all_spikes_hamming_distances, results_dict, mnist_labels, anal_path):
    num_classes = len(all_classes_hamming_distances)
    fig = plt.figure(figsize=(50, 60))
    outer_grid = fig.add_gridspec(num_classes, 11, width_ratios=[2, 1] + [1] * 9)  # First column is wider for scatter plots

    for class_idx in tqdm(range(num_classes)):
        # Extract class data and detect spikes (hyperplanes)
        class_data_indices = np.where(mnist_labels == class_idx)[0]
        points = torch.Tensor(np.array(results_dict['pca_2'][7]).transpose())
        detected_hyperplanes, total_error, assigned_points = spike_detection_nd(points)
        assignments, _, _ = assign_points_to_hyperplanes(points, detected_hyperplanes)

        # Scatter plot for the class points and spikes
        scatter_ax = fig.add_subplot(outer_grid[class_idx, 0])
        colors = plt.cm.get_cmap('tab10', len(detected_hyperplanes))
        for spike_idx in range(len(detected_hyperplanes)):
            scatter_ax.scatter(
                points[assignments == spike_idx, 0],
                points[assignments == spike_idx, 1],
                s=50,
                color=colors(spike_idx),
                alpha=0.5,
                label=f'Spike {spike_idx}'
            )
        for coef, intercept, _, _ in detected_hyperplanes:
            xx = np.linspace(points[:, 0].min(), points[:, 0].max(), 100)
            yy = (-coef[0] * xx - intercept) / coef[1]
            scatter_ax.plot(xx, yy, color='gray', linestyle='--', alpha=0.5)
        scatter_ax.set_title(f'Class {class_idx} Scatter Plot')
        scatter_ax.set_xlabel('PCA Component 1')
        scatter_ax.set_ylabel('PCA Component 2')
        scatter_ax.legend()

        # Histogram for class hamming distances
        class_hist_ax = fig.add_subplot(outer_grid[class_idx, 1])
        class_hist_ax.hist(
            all_classes_hamming_distances[class_idx],
            bins=20,
            color='blue',
            alpha=0.7,
            weights=np.ones(len(all_classes_hamming_distances[class_idx])) / len(all_classes_hamming_distances[class_idx]) * 100,
        )
        class_hist_ax.set_title(f'Class {class_idx} Hamming Distance Histogram')
        class_hist_ax.set_xlabel('Hamming Distance')
        class_hist_ax.set_ylabel('Percentage (%)')

        # Spike histograms
        max_spikes_to_plot = 9  # Limit the number of spikes visualized
        for spike_idx, spike_hamming_distances in enumerate(all_spikes_hamming_distances[class_idx][:max_spikes_to_plot]):
            spike_hist_ax = fig.add_subplot(outer_grid[class_idx, spike_idx + 2])
            spike_hist_ax.hist(
                spike_hamming_distances,
                bins=20,
                color='red',
                alpha=0.7,
                weights=np.ones(len(spike_hamming_distances)) / len(spike_hamming_distances) * 100,
            )
            spike_hist_ax.set_title(f'Spike {spike_idx} Hamming Distance Histogram')
            spike_hist_ax.set_xlabel('Hamming Distance')
            spike_hist_ax.set_ylabel('Percentage (%)')

    plt.tight_layout()
    plt.savefig(f'{anal_path}/hamming_distance_with_scatter.png')
    plt.close()


In [None]:
all_classes_hamming_distances, all_spikes_hamming_distances = run_analysis(model, dataset_target_samples, dataset_target_labels, device, anal_path)

In [27]:
plot_hamming_distance_analysis(all_classes_hamming_distances, all_spikes_hamming_distances, anal_path)

In [29]:
plot_hamming_distance_analysis_with_scatter(all_classes_hamming_distances, all_spikes_hamming_distances, results_dict, dataset_target_labels, anal_path)

  colors = plt.cm.get_cmap('tab10', len(detected_hyperplanes))
  colors = plt.cm.get_cmap('tab10', len(detected_hyperplanes))
  colors = plt.cm.get_cmap('tab10', len(detected_hyperplanes))
  colors = plt.cm.get_cmap('tab10', len(detected_hyperplanes))
  colors = plt.cm.get_cmap('tab10', len(detected_hyperplanes))
  colors = plt.cm.get_cmap('tab10', len(detected_hyperplanes))
  colors = plt.cm.get_cmap('tab10', len(detected_hyperplanes))
  colors = plt.cm.get_cmap('tab10', len(detected_hyperplanes))
  colors = plt.cm.get_cmap('tab10', len(detected_hyperplanes))
  colors = plt.cm.get_cmap('tab10', len(detected_hyperplanes))


In [9]:
mnist_labels =  dataset_target_labels.detach().cpu().numpy()
mnist_data = dataset_target_samples.detach().cpu().numpy().reshape(dataset_target_samples.shape[0], 28 * 28)
# Number of classes
num_classes = len(np.unique(mnist_labels))
key = 'representations'
results_dict['representations']

for idx_layer in tqdm(range(len(results_dict[key]))):

    # mnist_target_data = np.array(results_dict[key][idx_layer]).transpose()
    mnist_target_data = results_dict[key][idx_layer].detach().cpu().numpy()

    plot_target_data = np.array(results_dict['pca_2'][idx_layer]).transpose()

    # Create the overall figure
    fig = plt.figure(figsize=(36, 50))
    outer_grid = fig.add_gridspec(num_classes + 1, 2, width_ratios=[1.2, 1.8], height_ratios=[1] + [0.6] * num_classes)


    # Original PCA plot for the entire dataset
    ax_pca = fig.add_subplot(outer_grid[0, 0])
    ax_pca.set_title('2D PCA of Original MNIST Dataset')

    colors = np.array(plot_data_projection(ax_pca, idx_layer, results_dict['pca_2'], labels_all=results_dict['labels']))


    # Process each class for clustering and sample visualization
    for class_idx in range(num_classes):
        # Extract the data points belonging to the current class
        class_data_indices = np.where(mnist_labels == class_idx)[0]
        class_data = mnist_data[class_data_indices]
        mnist_class_data = mnist_target_data[np.where(mnist_labels == class_idx)[0]]

        plot_class_data = plot_target_data[np.where(mnist_labels == class_idx)[0]]

        points = class_data.copy()

        # Apply iterative line fitting clustering to the class data
        # clusters, cluster_indices = iterative_line_fitting(class_data)
        detected_hyperplanes, total_error, assigned_points = spike_detection_nd(points)
        if len(detected_hyperplanes) == 0:
            continue
        assignments, _, _ = assign_points_to_hyperplanes(torch.Tensor(points), detected_hyperplanes)

        # Plot the clustering for the current class
        class_cluster_ax = fig.add_subplot(outer_grid[class_idx + 1, 0])
        class_cluster_ax.set_title(f'Clustering of Class {class_idx} (PCA Reduced)')
        for cluster_idx in range(len(detected_hyperplanes)):
            cluster_color = colors[cluster_idx % len(colors)]
            class_cluster_ax.scatter(plot_class_data[assignments == cluster_idx, 0], plot_class_data[assignments == cluster_idx, 1], label=f'Spike {cluster_idx + 1}', color=cluster_color)
        class_cluster_ax.set_xlabel('First Principal Component')
        class_cluster_ax.set_ylabel('Second Principal Component')
        class_cluster_ax.legend(title='Clusters')

        # Create a grid to hold the sample images for each cluster of the current class
        # sample_grid = fig.add_subplot(outer_grid[class_idx + 1, 1])
        cluster_gridspec = gridspec.GridSpecFromSubplotSpec(1, len(detected_hyperplanes), subplot_spec=outer_grid[class_idx + 1, 1], wspace=0.3)

        for cluster_idx in range(len(detected_hyperplanes)):
            cluster = class_data[assignments == cluster_idx]
            # Create a subplot for each cluster to contain its samples
            cluster_ax = fig.add_subplot(cluster_gridspec[0, cluster_idx])
            cluster_ax.axis('off')
            cluster_ax.set_title(f'Cluster {cluster_idx + 1}', fontsize=10)

            # Plot 15 random samples for the current cluster in a 3x5 grid
            num_samples = min(60, len(cluster))
            rows, cols = 6, 10
            grid = gridspec.GridSpecFromSubplotSpec(rows, cols, subplot_spec=cluster_gridspec[0, cluster_idx], wspace=0.1, hspace=0.1)

            for sample_idx in range(num_samples):
                random_idx = np.random.choice(np.where(assignments == cluster_idx)[0])
                sample_image = class_data[random_idx].reshape(28, 28)
                
                # Determine position within the 3x5 grid
                row, col = divmod(sample_idx, cols)
                sub_ax = fig.add_subplot(grid[row, col])
                sub_ax.imshow(sample_image, cmap='gray')
                sub_ax.axis('off')

    plt.tight_layout()
    plt.subplots_adjust(hspace=0.5, wspace=0.3)  # Add spacing to avoid overlap
    fig.savefig(f"{anal_path}spike_plot{idx_layer}.png")
    plt.close()

100%|██████████| 12/12 [17:00<00:00, 85.07s/it]
