# 🧠 AvicennaFlow Project

**Title**: *Biologically Plausible Learning in CNNs: A Comparison of Backpropagation, DFA, and DKP for Image Classification*   
**Inspiration**: Inspired by Avicenna (Ibn Sina), this project explores the intersection of neuroscience and AI through biologically plausible learning algorithms.



## 📌 Project Motivation

This project investigates how biologically plausible alternatives to backpropagation—namely **Direct Feedback Alignment (DFA)** and **Direct Kolen–Pollack (DKP)**—perform in image classification tasks.

**Key questions include:**
- How do bio-inspired algorithms compare to backpropagation in classification accuracy?
- How do their learning dynamics and generalization capabilities differ?
- How robust are these algorithms against adversarial attacks?
- What insights can we gain from their bias–variance profiles?

**Datasets used**:
- MNIST
- CIFAR-10



## 📦 Dependency Installation (Colab Only)

Install core packages like `torch`, `torchvision`, `rsatoolbox`, and `vibecheck`.  
> ⚠️ *Note: These are only required when running the notebook in Google Colab.*


In [None]:
!pip install torch torchvision matplotlib numpy scikit-learn scipy vibecheck --quiet
!pip install rsatoolbox==0.1.5 --quiet

## 🧰 Library Imports

Import all required Python libraries for:
- Deep learning (`torch`, `torchvision`)
- Analysis (`rsatoolbox`, `sklearn`, `scipy`)
- Plotting (`matplotlib`, `seaborn`)
- Utility and logging (`argparse`, `numpy`, `contextlib`)


In [None]:
from collections import OrderedDict
import logging
import contextlib
import os
import time
import csv
from types import SimpleNamespace
import math
import warnings

# External libraries: General utilities
import argparse

# NumPy
import numpy as np
from numpy import prod

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autograd
from torch.autograd import Variable
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
from torch.utils.tensorboard import SummaryWriter

# TorchVision
from torchvision import datasets, transforms
from torchvision.models.feature_extraction import (
    create_feature_extractor, get_graph_node_names
)
from torchvision.utils import make_grid

# Matplotlib for plotting
import matplotlib as mpl
import matplotlib.pyplot as plt
# %matplotlib inline  # Only needed in Jupyter notebooks

# SciPy for statistical functions
import scipy
from scipy import stats

# Scikit-Learn for machine learning utilities
from sklearn.decomposition import PCA
from sklearn import manifold

# RSA toolbox imports
import rsatoolbox
from rsatoolbox.data import Dataset
from rsatoolbox.rdm.calc import calc_rdm
import rsatoolbox.rdm
from rsatoolbox.rdm import RDMs

# Warnings settings
warnings.filterwarnings('ignore')

## 🎨 Plot Configuration

We configure plotting aesthetics and fonts using the Neuromatch Academy style.  
This ensures high-resolution, consistent visuals across platforms.


In [None]:
logging.getLogger('matplotlib.font_manager').disabled = True

%matplotlib inline
%config InlineBackend.figure_format = 'retina' # perform high definition rendering for images and plots
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/main/nma.mplstyle")

## ⚙️ Configuration Parameters

All experiment settings are defined here, including:
- Dataset choice (`MNIST` or `CIFAR10`)
- Learning rates for each method (`bp_lr`, `lr`, `b_lr`) and `gamma`
- Training mode (`train_mode`), batch sizes, device, epoch count
- Seed value, logging interval, and whether to save the model
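
Because `train_mode` and `dataset` start out as empty strings in the cell below, they have to be filled in before any data loader or model is built. The following is a minimal sketch of one way to do that. The helper `configure_run` is hypothetical and not part of this project, and treating `'BP'` as a valid `train_mode` value is an assumption based on the method names used later in the notebook.

```python
from types import SimpleNamespace

# Values assumed from how the notebook names its methods and datasets.
VALID_MODES = {"BP", "DFA", "DKP"}
VALID_DATASETS = {"MNIST", "CIFAR10"}

def configure_run(args, train_mode, dataset):
    """Hypothetical helper: validate and set the two fields that default to ''."""
    if train_mode not in VALID_MODES:
        raise ValueError(f"Unknown train_mode: {train_mode!r}")
    if dataset not in VALID_DATASETS:
        raise ValueError(f"Unknown dataset: {dataset!r}")
    args.train_mode = train_mode
    args.dataset = dataset
    return args

# A reduced stand-in for the notebook's `args` namespace.
demo_args = SimpleNamespace(train_mode="", dataset="", epochs=5, seed=42)
configure_run(demo_args, "DKP", "MNIST")
print(demo_args.train_mode, demo_args.dataset)
```

Validating early like this turns a typo such as `'CIFAR-10'` into an immediate error instead of a data loader that is never created.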


In [None]:
args = SimpleNamespace(
    train_mode='',
    dataset='',
    epochs=5,
    batch_size=64,
    test_batch_size=1000,
    bp_lr=1,
    lr=5e-4,
    b_lr=1e-4,
    gamma=0.8,
    no_cuda=False,
    dry_run=False,
    seed=42,
    log_interval=10,
    save_model=True,
)

use_cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
args.device = device

## 🔁 Reproducibility

We define a function to fix the random seed in Python, NumPy, and PyTorch  
for consistent and reproducible training results across runs.
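
What seeding buys can be checked directly: two runs with the same seed should draw identical random numbers, while a different seed should not. The sketch below shows this for Python's `random` module and NumPy only. The helper `draw_numbers` is made up for illustration, and PyTorch is left out to keep the example light (`torch.manual_seed` plays the analogous role there).

```python
import random

import numpy as np

def draw_numbers(seed):
    """Seed Python and NumPy, then draw a few random numbers from each."""
    random.seed(seed)
    np.random.seed(seed)
    return random.random(), np.random.rand(3)

py_a, np_a = draw_numbers(42)
py_b, np_b = draw_numbers(42)   # same seed as the first call
py_c, np_c = draw_numbers(7)    # different seed

same_seed_matches = (py_a == py_b) and np.array_equal(np_a, np_b)
other_seed_matches = np.array_equal(np_a, np_c)
print(same_seed_matches, other_seed_matches)
```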


In [None]:
import random

def set_seed(seed=None, seed_torch=True):
  if seed is None:
    seed = np.random.choice(2 ** 32)
  random.seed(seed)
  np.random.seed(seed)
  if seed_torch:
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(seed = args.seed)

## 🧰 Utility Functions

This block defines helper functions for:
- Data loading and preprocessing
- Feature extraction using forward hooks and torchvision's feature extractor
- Adversarial image generation (FGSM)
- Logging metrics into CSV format
- Visualization and sampling utilities
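
The FGSM step listed above moves each input value by a fixed amount `epsilon` in the direction that increases the loss, then clips the result back to the valid pixel range. The sketch below illustrates this on a toy linear softmax classifier in NumPy, where the input gradient has a closed form. The classifier, its random weights, and all helper names are assumptions made for illustration, not the models used in this project.

```python
import numpy as np

def softmax(z):
    z = z - z.max()              # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()

def cross_entropy(W, b, x, y):
    """Cross-entropy loss of a linear softmax classifier on one sample."""
    return -np.log(softmax(W @ x + b)[y])

def input_gradient(W, b, x, y):
    """Gradient of the cross-entropy loss with respect to the input x."""
    p = softmax(W @ x + b)
    p[y] -= 1.0                  # softmax output minus one-hot target
    return W.T @ p

def fgsm(x, grad, epsilon):
    """One FGSM step, clipped to the [0, 1] pixel range."""
    return np.clip(x + epsilon * np.sign(grad), 0.0, 1.0)

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 8))      # toy model: 8 input values, 3 classes
b = np.zeros(3)
x = rng.uniform(0.2, 0.8, size=8)
y = 1
epsilon = 0.1

x_adv = fgsm(x, input_gradient(W, b, x, y), epsilon)
loss_clean = cross_entropy(W, b, x, y)
loss_adv = cross_entropy(W, b, x_adv, y)
print(f"clean loss {loss_clean:.4f}, adversarial loss {loss_adv:.4f}")
```

For this convex toy model the perturbed input can never have a lower loss than the clean one. Deep networks offer no such guarantee, which is why adversarial accuracy is measured empirically.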


In [None]:
import io

def calc_rdms(model_features, method='correlation'):
    """
    Calculates representational dissimilarity matrices (RDMs) for model features.

    Inputs:
    - model_features (dict): A dictionary where keys are layer names and values are features of the layers.
    - method (str): The method to calculate RDMs, e.g., 'correlation'. Default is 'correlation'.

    Outputs:
    - rdms (rsatoolbox.rdm.RDMs): RDMs object containing dissimilarity matrices.
    - rdms_dict (dict): A dictionary with layer names as keys and their corresponding RDMs as values.
    """
    ds_list = []
    for layer, feats in model_features.items():

        if isinstance(feats, list):
            feats = feats[-1]

        if not args.no_cuda:
            feats = feats.cpu()

        if len(feats.shape) > 2:
            feats = feats.flatten(1)

        feats = feats.detach().numpy()
        ds = Dataset(feats, descriptors=dict(layer=layer))
        ds_list.append(ds)

    rdms = calc_rdm(ds_list, method=method)
    rdms_dict = {list(model_features.keys())[i]: rdms.get_matrices()[i] for i in range(len(model_features))}

    return rdms, rdms_dict

def fetch_dataloaders(args):
    """
    Fetches the data loaders for training and testing datasets.

    Inputs:
    - args (Namespace): Parsed arguments with training configuration.

    Outputs:
    - train_loader (torch.utils.data.DataLoader): DataLoader for the training data.
    - test_loader (torch.utils.data.DataLoader): DataLoader for the test data.
    """
    train_kwargs = {'batch_size': args.batch_size}
    test_kwargs = {'batch_size': args.test_batch_size}
    if not args.no_cuda:
        cuda_kwargs = {'num_workers': 1,
                       'pin_memory': True,
                       'shuffle': True}
        train_kwargs.update(cuda_kwargs)
        test_kwargs.update(cuda_kwargs)

    if args.dataset == 'CIFAR10':

        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        with contextlib.redirect_stdout(io.StringIO()): #to suppress output
            train_data = datasets.CIFAR10(root='~/data', train=True, download=True, transform=transform)
            train_loader = torch.utils.data.DataLoader(train_data, **train_kwargs)

            test_data = datasets.CIFAR10(root='~/data', train=False, download=True, transform=transform)
            test_loader = torch.utils.data.DataLoader(test_data, **test_kwargs)

    elif args.dataset == 'MNIST':

        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
            ])
        with contextlib.redirect_stdout(io.StringIO()): #to suppress output
            dataset1 = datasets.MNIST('~/data', train=True, download=True, transform=transform)
            dataset2 = datasets.MNIST('~/data', train=False, download=True, transform=transform)

            train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
            test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    return train_loader, test_loader

def sample_images_cifar(data_loader, n=5, plot=False, unnormalize=True):
    """
    Return ≤ n images per class from a DataLoader (first batch only)
    and, optionally, display them.

    Args
    ----
    data_loader : torch.utils.data.DataLoader
    n           : images per class (default 5)
    plot        : show a grid with matplotlib (default False)
    unnormalize : map images back to [0,1] if they were normalized (default True)

    Returns
    -------
    imgs   : torch.Tensor  # shape (k, C, H, W)
    labels : torch.Tensor  # shape (k,)
    """
    # pull one batch
    imgs, targets = next(iter(data_loader))

    # find unique class labels in that batch
    classes = torch.unique(targets)

    imgs_out, lbls_out = [], []
    for c in classes:
        idx = torch.where(targets == c)[0][:n]       # first n of that class
        imgs_out.append(imgs[idx])
        lbls_out.extend([c.item()] * len(idx))

    imgs_out = torch.cat(imgs_out, dim=0)
    labels   = torch.tensor(lbls_out)

    if unnormalize:
        # reverse the common CIFAR‑10 normalisation (mean=std=0.5)
        imgs_plot = imgs_out * 0.5 + 0.5             # -> [0,1]
    else:
        imgs_plot = imgs_out

    if plot:
        with plt.xkcd():
            grid = make_grid(imgs_plot, nrow=5, padding=0)
            plt.imshow(grid.permute(1, 2, 0).cpu())  # (C,H,W) -> (H,W,C)
            plt.axis("off")
            plt.show()

    return imgs_out, labels

def sample_images_mnist(data_loader, n=5, plot=False):
    """
    Samples a specified number of images from a data loader.

    Inputs:
    - data_loader (torch.utils.data.DataLoader): Data loader containing images and labels.
    - n (int): Number of images to sample per class.
    - plot (bool): Whether to plot the sampled images using matplotlib.

    Outputs:
    - imgs (torch.Tensor): Sampled images.
    - labels (torch.Tensor): Corresponding labels for the sampled images.
    """

    with plt.xkcd():
        imgs, targets = next(iter(data_loader))

        imgs_o = []
        labels = []
        for value in range(10):
            cat_imgs = imgs[np.where(targets == value)][0:n]
            imgs_o.append(cat_imgs)
            labels.append([value]*len(cat_imgs))

        imgs = torch.cat(imgs_o, dim=0)
        labels = torch.tensor(labels).flatten()

        if plot:
            plt.imshow(torch.moveaxis(make_grid(imgs, nrow=5, padding=0, normalize=False, pad_value=0), 0,-1))
            plt.axis('off')

        return imgs, labels

def extract_features_dkp(model, imgs):
    """
    Extracts features from specified layers of the model.

    Inputs:
    - model (torch.nn.Module): The model from which to extract features.
    - imgs (torch.Tensor): Batch of input images.

    Outputs:
    - model_features (dict): A dictionary with layer names as keys and extracted features as values.
                              Also includes 'input' key for the input images.
    """
    model_features = {}

    def save_features(name):
        def hook(module, input, output):
            model_features[name] = output.detach().cpu()
        return hook

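    # Register hooks for conv and linear layers, keeping the handles so the
    # hooks can be removed afterwards (otherwise they pile up on every call)
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            handles.append(module.register_forward_hook(save_features(name)))

    # Also store input images
    model_features['input'] = imgs.detach().cpu()

    model.eval()
    with torch.no_grad():
        _ = model(imgs)

    # Remove the hooks so they do not fire on later forward passes
    for handle in handles:
        handle.remove()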

    # convert fc2 logits to probs
    logits = model_features['fc2']               # shape [batch_size, 10]
    probs  = F.softmax(logits, dim=1)            # now in [0,1], sum to 1 per row
    model_features['fc2'] = probs 

    return model_features

def extract_features_bp(model, imgs, return_layers):
    """
    Extracts features from specified layers of the model.

    Inputs:
    - model (torch.nn.Module): The model from which to extract features.
    - imgs (torch.Tensor): Batch of input images.
    - return_layers (list or str): Layer names from which to extract features,
      or 'all' / 'layers' to select graph nodes automatically.

    Outputs:
    - model_features (dict): A dictionary with layer names as keys and extracted features as values.
    """
    if return_layers == 'all':
        return_layers, _ = get_graph_node_names(model)
    elif return_layers == 'layers':
        layers, _ = get_graph_node_names(model)
        return_layers = [l for l in layers if 'input' in l or 'conv' in l or 'fc' in l]

    feature_extractor = create_feature_extractor(model, return_nodes=return_layers)
    model_features = feature_extractor(imgs)

    # convert fc2 logits to probs
    logits = model_features['fc2']               # shape [batch_size, 10]
    probs  = F.softmax(logits, dim=1)            # now in [0,1], sum to 1 per row
    model_features['fc2'] = probs 

    return model_features

def fgsm_attack(image, epsilon, data_grad):
    """
    Performs FGSM attack on an image.

    Inputs:
    - image (torch.Tensor): Original image.
    - epsilon (float): Perturbation magnitude.
    - data_grad (torch.Tensor): Gradient of the data.

    Outputs:
    - perturbed_image (torch.Tensor): Perturbed image after FGSM attack.
    """
    sign_data_grad = data_grad.sign()
    perturbed_image = image + epsilon * sign_data_grad
    perturbed_image = torch.clamp(perturbed_image, 0, 1)
    return perturbed_image

def denorm(batch, mean, std):
    mean = torch.tensor(mean, device=batch.device)
    std = torch.tensor(std, device=batch.device)
    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)

def normalize(batch, mean, std):
    mean = torch.tensor(mean, device=batch.device)
    std = torch.tensor(std, device=batch.device)
    return (batch - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)

def generate_adversarial(args, model, imgs, targets):
    """
    Generates adversarial examples using FGSM attack for MNIST or CIFAR-10.

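    Inputs:
    - args (Namespace): Configuration; `args.dataset` selects the normalization statistics and epsilon.
    - model (torch.nn.Module): The model to attack.
    - imgs (torch.Tensor): Batch of normalized images.
    - targets (torch.Tensor): Batch of target labels.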

    Outputs:
    - adv_imgs (torch.Tensor): Batch of adversarial images (normalized).
    """
    if args.dataset == 'MNIST':
        mean = [0.1307]
        std = [0.3081]
        epsilon = 0.2
    elif args.dataset == 'CIFAR10':
        mean = [0.5, 0.5, 0.5]
        std = [0.5, 0.5, 0.5]
        epsilon = 0.03
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset}")

    adv_imgs = []

    for img, target in zip(imgs, targets):
        img = img.unsqueeze(0)
        target = target.unsqueeze(0)
        img.requires_grad = True

        output = model(img)
        loss = F.cross_entropy(output, target)

        model.zero_grad()
        loss.backward()

        data_grad = img.grad.data
        data_denorm = denorm(img, mean, std)
        perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)
        perturbed_data_normalized = normalize(perturbed_data, mean, std)

        adv_imgs.append(perturbed_data_normalized.detach())

    return torch.cat(adv_imgs)

def test_adversarial(model, imgs, targets):
    """
    Tests the model on adversarial examples and returns the accuracy.

    Inputs:
    - model (torch.nn.Module): The model to be tested.
    - imgs (torch.Tensor): Batch of adversarial images.
    - targets (torch.Tensor): Batch of target labels.

    Outputs:
    - final_acc (float): Fraction of adversarial images classified correctly.
    """
    correct = 0
    model.eval()
    output = model(imgs)
    pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
    correct += pred.eq(targets.view_as(pred)).sum().item()

    final_acc = correct / float(len(imgs))
    # print(f"adversarial test accuracy = {correct} / {len(imgs)} = {final_acc}")
    return final_acc

class CSVLogger:
    def __init__(self, fieldnames, args):
        datetime = time.strftime('%y%m%d_%H%M%S')
        self.fieldnames = fieldnames

        filedir = os.path.relpath(os.path.join('results', args.train_mode))
        if not os.path.exists(filedir):
            os.makedirs(filedir)
        self.filename = os.path.relpath(os.path.join(filedir, f"{args.train_mode.lower()}_{args.dataset.lower()}_epochs{args.epochs}_{datetime}.csv"))

        with open(self.filename, 'a', newline='') as csvfile:
            csvfile.write(str(args) + '\n')
            writer = csv.DictWriter(csvfile, fieldnames=self.fieldnames)
            writer.writeheader()

    def save_values(self, *values):
        assert len(values) == len(self.fieldnames), 'The number of values should match the number of field names.'
        with open(self.filename, 'a', newline='') as csvfile:
            writer = csv.DictWriter(csvfile, fieldnames=self.fieldnames)
            row = {}
            for i, val in enumerate(values):
                row[self.fieldnames[i]] = val

            writer.writerow(row)

## 🧠 Representational Similarity Analysis (RSA)

We extract intermediate layer activations and compute RDMs  
to evaluate how similarly models represent input images internally.
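
An RDM holds one dissimilarity value per pair of stimuli, and two models can then be compared by correlating their RDMs. The sketch below builds a correlation-distance RDM (1 minus Pearson r between activation patterns) in plain NumPy. It illustrates the idea only and is not the `rsatoolbox` implementation. The function names, the random "activations", and the upper-triangle comparison are assumptions made for illustration.

```python
import numpy as np

def correlation_rdm(features):
    """RDM with entries 1 - Pearson r between the rows (stimuli) of `features`."""
    flat = features.reshape(features.shape[0], -1)   # stimuli x units
    return 1.0 - np.corrcoef(flat)

def compare_rdms(rdm_a, rdm_b):
    """Pearson correlation between the upper triangles of two RDMs."""
    iu = np.triu_indices_from(rdm_a, k=1)            # skip the zero diagonal
    return np.corrcoef(rdm_a[iu], rdm_b[iu])[0, 1]

rng = np.random.default_rng(0)
acts_a = rng.normal(size=(6, 4, 5, 5))               # 6 stimuli, conv-like shape
acts_b = acts_a + 0.1 * rng.normal(size=acts_a.shape)

rdm_a = correlation_rdm(acts_a)
rdm_b = correlation_rdm(acts_b)
print(rdm_a.shape, round(float(compare_rdms(rdm_a, rdm_b)), 3))
```

Flattening each stimulus to one vector mirrors what `calc_rdms` does with convolutional feature maps before computing distances.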


In [None]:
method_names = ['BP', 'DKP', 'DFA']
model_colors = {'BP': 'blue', 'DKP': 'green', 'DFA': 'red'}

def load_models(args, imgs):
    
    batch_size = imgs.shape[0]
    in_channels = imgs.shape[1]
    input_height = imgs.shape[2]
    
    # ----- BP -----
    model_bp = ConvNetBP(in_channels, input_height)
    model_bp.load_state_dict(torch.load(f"model_checkpoints/BP_{args.dataset}_epoch5_s{args.seed}.pth", map_location=args.device))
    model_bp.to(args.device)
    
    # ----- DKP -----
    args.train_mode = 'DKP' 
    model_dkp = ConvNetDKP(args, in_channels, input_height) 
    model_dkp.to(args.device)
    with torch.no_grad():
        model_dkp(torch.randn(batch_size, in_channels, input_height, input_height).to(args.device))
    model_dkp.load_state_dict(torch.load(f'model_checkpoints/DKP_{args.dataset}_epoch5_s{args.seed}.pt', map_location=args.device), strict=False)
    
    # ----- DFA -----
    args.train_mode = 'DFA' 
    model_dfa = ConvNetDKP(args, in_channels, input_height)
    model_dfa.to(args.device)
    with torch.no_grad():
        model_dfa(torch.randn(batch_size, in_channels, input_height, input_height).to(args.device))
    model_dfa.load_state_dict(torch.load(f'model_checkpoints/DFA_{args.dataset}_epoch5_s{args.seed}.pt', map_location=args.device), strict=False)

    models = {'BP': model_bp, 'DKP': model_dkp, 'DFA': model_dfa}
    return models

def grab_features_and_rdms(args, imgs, labels, rdm_method='correlation', adversarial=False, seed=42):
  
    models_rdms = []
    models_feats = []
    adv_accuracies = []

    models = load_models(args, imgs)
    model_bp = models['BP']
    model_dkp = models['DKP']
    model_dfa = models['DFA']

    imgs = imgs.to(args.device)
    labels = labels.to(args.device)
    
    # ----- BP -----
    # model_bp.eval()
    # Each model is attacked starting from the clean images; reassigning `imgs`
    # would stack one model's perturbation on top of the previous model's.
    imgs_bp = imgs
    if adversarial:
        imgs_bp = generate_adversarial(args, model_bp, imgs, labels)
        adv_acc = test_adversarial(model_bp, imgs_bp, labels)
        adv_accuracies.append(adv_acc)
    return_layers = ['input', 'conv1', 'conv2', 'fc1', 'fc2']
    features_bp = extract_features_bp(model_bp, imgs_bp, return_layers=return_layers)
    _, rdms_bp = calc_rdms(features_bp, rdm_method)
    models_feats.append(features_bp)
    models_rdms.append(rdms_bp)
    
    # ----- DKP -----
    # model_dkp.eval()
    imgs_dkp = imgs
    if adversarial:
        imgs_dkp = generate_adversarial(args, model_dkp, imgs, labels)
        adv_acc = test_adversarial(model_dkp, imgs_dkp, labels)
        adv_accuracies.append(adv_acc)
    features_dkp = extract_features_dkp(model_dkp, imgs_dkp)
    _, rdms_dkp = calc_rdms(features_dkp, rdm_method)
    models_feats.append(features_dkp)
    models_rdms.append(rdms_dkp)
    
    # ----- DFA -----
    # model_dfa.eval()
    imgs_dfa = imgs
    if adversarial:
        imgs_dfa = generate_adversarial(args, model_dfa, imgs, labels)
        adv_acc = test_adversarial(model_dfa, imgs_dfa, labels)
        adv_accuracies.append(adv_acc)
    features_dfa = extract_features_dkp(model_dfa, imgs_dfa)
    _, rdms_dfa = calc_rdms(features_dfa, rdm_method)
    models_feats.append(features_dfa)
    models_rdms.append(rdms_dfa)
    
    return models_rdms, models_feats, adv_accuracies

def plot_all_rdms(models_rdms, method_names, n_cols=5, adversarial=False):
    """
    Plots the RDMs of multiple learning rules in a single figure.
    Each row corresponds to one learning rule and shows up to n_cols RDMs.
    Inputs:
    - models_rdms (list of dict): List of RDM dictionaries for each method.
    - method_names (list of str): Names of learning rules in order.
    - n_cols (int): Number of columns/layers per method to display.
    """
    n_methods = len(models_rdms)
    fig = plt.figure(figsize=(3*n_cols, 3*n_methods))
    gs = fig.add_gridspec(n_methods, n_cols)

    for row in range(n_methods):
        rdm_dict = models_rdms[row]
        layers = list(rdm_dict.keys())[:n_cols]
        for col, layer in enumerate(layers):
            rdm = np.squeeze(rdm_dict[layer])
            if len(rdm.shape) < 2:
                size = int(np.sqrt(rdm.shape[0]))
                rdm = rdm.reshape((size, size))
            rdm = rdm / np.max(rdm)
            ax = fig.add_subplot(gs[row, col])
            im = ax.imshow(rdm, cmap='magma_r')
            if col == 0:
                ax.set_ylabel(method_names[row], rotation=0, size='large', labelpad=40, va='center')
            ax.set_title(layer)
            ax.set_xticks([])
            ax.set_yticks([])
    cbar_ax = fig.add_axes([0.92, 0.15, 0.015, 0.7])
    fig.colorbar(im, cax=cbar_ax)
    cbar_ax.set_ylabel('Normalized euclidean distance', rotation=90, size='medium')
    plt.tight_layout(rect=[0, 0, 0.9, 1])
    plt.show()
    fig.savefig(f"rdms_all_{args.dataset}{'_adversarial' if adversarial else ''}.png", dpi=300, bbox_inches='tight')
    
def rep_path_all(args, model_names, model_colors, imgs, labels=None,
                 rdm_calc_method='euclidean', rdm_comp_method='cosine', 
                 adversarial=False):
    """
    Plots per-method RDM dissimilarity matrices and representational geometry paths,
    including optional label representation as the final point in each path.

    - args: configuration namespace passed on to grab_features_and_rdms.
    - model_names: list of method names.
    - model_colors: dict mapping method names to colors.
    - imgs: batch of input images from which the RDMs are computed.
    - labels: optional tensor of labels for label RDM.
    - rdm_calc_method: metric for RDM calculation.
    - rdm_comp_method: metric for RDM comparison.
    - adversarial: if True, use adversarial versions of the images.
    """

    models_rdms, _, _ = grab_features_and_rdms(args, imgs, labels, rdm_method=rdm_calc_method, adversarial=adversarial)

    # Wrap raw RDM matrices into RDMs objects
    def wrap_dict(raw_dict):
        return {
            layer: RDMs(
                dissimilarities=np.array([mat]),
                dissimilarity_measure=rdm_calc_method,
                rdm_descriptors={'layer': [layer]}
            )
            for layer, mat in raw_dict.items()
        }
    wrapped_models = [wrap_dict(m) for m in models_rdms]

    with plt.xkcd():
        fig = plt.figure(figsize=(10, 4 * len(wrapped_models)))
        gs = fig.add_gridspec(len(wrapped_models), 2, wspace=0.5, hspace=0.4)

        for i, (rdm_dict, method_name) in enumerate(zip(wrapped_models, model_names)):
            layers = list(rdm_dict.keys())
            rdms = [rdm_dict[layer] for layer in layers]

            if labels is not None:
                label_rdm, _ = calc_rdms(
                    {'labels': F.one_hot(labels).float().to(device)},
                    method=rdm_calc_method,
                )
                label_rdm.dissimilarity_measure = rdm_calc_method  
                layers.append('labels')
                rdms.append(label_rdm)

            # for rdm in rdms:
            #     print(rdm.dissimilarity_measure)
    
            # Concatenate and compare
            rdms_concat = rsatoolbox.rdm.concat(*rdms)
            comp = rsatoolbox.rdm.compare(rdms_concat, rdms_concat, method=rdm_comp_method)
            if rdm_comp_method == 'cosine':
                comp = np.arccos(comp)
            comp = np.nan_to_num((comp + comp.T) / 2.0)

            # Left subplot: dissimilarity matrix
            ax0 = fig.add_subplot(gs[i, 0])
            im = ax0.imshow(comp, cmap='viridis_r')
            fig.colorbar(im, ax=ax0, fraction=0.046, pad=0.04)
            ax0.set_title(f'{method_name}: RDM dissimilarity', fontsize=12)
            ax0.set_xticks(range(len(layers)))
            ax0.set_xticklabels(layers, rotation=80, fontsize=8)
            ax0.set_yticks(range(len(layers)))
            ax0.set_yticklabels(layers, fontsize=8)

            # Right subplot: MDS path
            transformer = manifold.MDS(
                n_components=2, dissimilarity='precomputed',
                max_iter=1000, n_init=10, normalized_stress='auto'
            )
            coords = transformer.fit_transform(comp)
            ax1 = fig.add_subplot(gs[i, 1])
            path_end = len(layers) - (1 if labels is not None else 0)
            ax1.plot(coords[:path_end, 0], coords[:path_end, 1],
                     color=model_colors[method_name], marker='.')
            # for idx in range(path_end):
            #     ax1.text(coords[idx, 0], coords[idx, 1], layers[idx], fontsize=8)
            # Input marker
            ax1.plot(coords[0, 0], coords[0, 1], color='k', marker='s')
            # Label marker
            if labels is not None:
                ax1.plot(coords[-1, 0], coords[-1, 1], color='m', marker='*')

            ax1.set_title(f'{method_name}: Representational path', fontsize=12)
            ax1.set_xlabel('dim 1')
            ax1.set_ylabel('dim 2')
            lim = coords.max() - coords.min()
            ax1.set_xlim(coords[:, 0].min() - 0.1 * lim, coords[:, 0].max() + 0.1 * lim)
            ax1.set_ylim(coords[:, 1].min() - 0.1 * lim, coords[:, 1].max() + 0.1 * lim)

        plt.tight_layout()
        plt.show()
        fig.savefig(f"rep_path_all_{args.dataset}{'_adversarial' if adversarial else ''}.png", dpi=300, bbox_inches='tight')

def plot_dim_reduction_all(all_model_features, labels, transformer_funcs, method_names=['BP', 'DKP', 'DFA'], adversarial=False):
    """
    Plots dimensionality reduction for multiple methods and layers with a unified class legend.

    all_model_features: list of dicts mapping layer names to feature tensors/arrays
    labels: 1D array-like of integer labels for samples
    transformer_funcs: list of strings specifying transformers (e.g., ['PCA', 't-SNE'])
    method_names: list of method names corresponding to all_model_features
    """
    # Prepare transformers
    transformers = []
    for t in transformer_funcs:
        if t == 'PCA':
            transformers.append(PCA(n_components=2))
        elif t == 'MDS':
            transformers.append(manifold.MDS(n_components=2, normalized_stress='auto'))
        elif t == 't-SNE':
            transformers.append(manifold.TSNE(n_components=2, perplexity=40, verbose=0))
        else:
            raise ValueError(f"Unknown transformer: {t}")

    n_methods = len(all_model_features)
    n_transformers = len(transformers)
    layers = list(all_model_features[0].keys())
    n_layers = len(layers)

    # Global figure
    fig = plt.figure(figsize=(3 * n_layers, 2.5 * n_methods * n_transformers))
    gs = fig.add_gridspec(n_methods * n_transformers, n_layers)

    # Colors and legend handles
    labels_np = np.array(labels)
    unique_labels = np.unique(labels_np)
    cmap = plt.get_cmap('tab10', len(unique_labels))
    norm = mpl.colors.BoundaryNorm(range(len(unique_labels)+1), cmap.N)

    # Plot each method/transformer/layer
    for m_idx, model_features in enumerate(all_model_features):
        for f_idx, transformer in enumerate(transformers):
            row_idx = m_idx * n_transformers + f_idx
            for l_idx, layer in enumerate(layers):
                feats = model_features[layer]
                if isinstance(feats, np.ndarray):
                    feats = torch.from_numpy(feats)
                feats = feats.detach().cpu().flatten(1).numpy()

                transformed = transformer.fit_transform(feats)

                ax = fig.add_subplot(gs[row_idx, l_idx])
                ax.axis('off')
                if f_idx == 0:
                    ax.set_title(layer, fontsize=12)
                if l_idx == 0:
                    ax.text(-0.2, 0.5, f"{method_names[m_idx]}",
                            size=15, ha='right', va='center', transform=ax.transAxes)

                sc = ax.scatter(transformed[:, 0], transformed[:, 1], c=labels_np,
                                cmap=cmap, norm=norm, s=20)

    # Add legend on right side, moved further out to avoid overlap
    fig.subplots_adjust(right=0.80, wspace=0.2, hspace=0.2)
    # place the colorbar to the right of the plots
    cax = fig.add_axes([1.05, 0.15, 0.015, 0.7])  # [left, bottom, width, height] in figure coordinates
    cb = mpl.colorbar.ColorbarBase(cax, cmap=cmap, norm=norm,
                                    boundaries=np.arange(len(unique_labels)+1)-0.5,
                                    ticks=unique_labels, spacing='proportional')
    cb.set_label('Class Label')
    plt.show()
    fig.savefig(f"tsne_{args.dataset}{'_adversarial' if adversarial else ''}.png", dpi=300, bbox_inches='tight')


## 📶 Gradient Signal-to-Noise Ratio (SNR)

To assess the quality of learning signals, we compute the SNR of gradients across layers.  
A higher SNR indicates more reliable and informative gradient directions.
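
As a rough illustration of the measure, the sketch below computes an SNR of the form |mean| / (std + epsilon) per gradient element across samples and then averages over elements, which is the same form used by `compute_SNR` in the cell below. The synthetic "gradients" and the function name `gradient_snr` are assumptions made for illustration.

```python
import numpy as np

def gradient_snr(per_sample_grads, epsilon=1e-7):
    """Average over gradient elements of |mean| / (std + epsilon) across samples."""
    grads = np.asarray(per_sample_grads, dtype=float)   # samples x elements
    mean_abs = np.abs(grads.mean(axis=0))
    std = grads.std(axis=0)
    return float(np.mean(mean_abs / (std + epsilon)))

rng = np.random.default_rng(0)
signal = np.ones(50)                                    # shared "true" gradient

# Same underlying gradient, observed with little or a lot of per-sample noise.
consistent = signal + 0.1 * rng.normal(size=(200, 50))
noisy = signal + 10.0 * rng.normal(size=(200, 50))

snr_consistent = gradient_snr(consistent)
snr_noisy = gradient_snr(noisy)
print(f"consistent: {snr_consistent:.2f}, noisy: {snr_noisy:.2f}")
```

When per-sample gradients agree, the mean dominates the spread and the SNR is high; when they scatter widely around the same mean, the SNR drops toward zero.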


In [None]:
def compute_gradient_SNR(args, model, dataset, max_samples=None):
    """
    Computes gradient SNRs for a model given a dataset.
    
    Arguments:
    - args (Namespace): configuration providing `args.device`.
    - model (torch.nn.Module): your CNN (with .list_parameters() and .gather_gradient_dict()).
    - dataset (torch.utils.data.Dataset): data to run over.
    - max_samples (int or None): if set, stop after this many samples.
    
    Returns:
    - SNR_dict (dict): avg SNR for each parameter name.
    """
    model.to(args.device).eval()
    loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)

    # initialize storage for per-sample grads
    gradients = {name: [] for name in model.list_parameters()}

    for idx, (X, y) in enumerate(loader):
        if max_samples is not None and idx >= max_samples:
            break

        X, y = X.to(args.device), y.to(args.device)
        model.zero_grad()

        y_pred = model(X)
        loss = F.cross_entropy(y_pred, y)
        loss.backward()

        # collect this sample's grads
        for name, grad in model.gather_gradient_dict().items():
            if grad is not None:
                # detach, move to cpu, flatten into 1D
                g = grad.detach().cpu().numpy().ravel()
                gradients[name].append(g)

    # now compute SNR per parameter
    SNR_dict = {}
    for name, grad_list in gradients.items():
        if len(grad_list) == 0:
            SNR_dict[name] = float("nan")
            continue
        data = np.stack(grad_list, axis=0)   # shape = (n_samples, n_params)
        SNR_dict[name] = compute_SNR(data)

    return SNR_dict

def compute_SNR(data, epsilon=1e-7):
    """
    Calculates the average SNR of data across the first axis.
    
    Arguments:
    - data (np.ndarray): items x gradients
    - epsilon (float, optional): value added to the denominator to avoid
      division by zero.
    
    Returns:
    - avg_SNR (float): SNR averaged over gradient elements
    """
    
    absolute_mean = np.abs(np.mean(data, axis=0))
    std = np.std(data, axis=0)
    SNR_by_item = absolute_mean / (std + epsilon)
    avg_SNR = np.mean(SNR_by_item)
    
    return avg_SNR

def plot_gradient_SNRs(SNR_dict, width=0.5, ax=None, adversarial=False):
    """
    Plot gradient SNRs for various learning rules and save the figure as
    snr_<dataset>[_adversarial].png.
    
    Arguments:
    - SNR_dict (dict): Gradient SNRs for each learning rule.
    - width (float, optional): Width of the bars.
    - ax (plt subplot, optional): Axis on which to plot gradient SNRs. If None, a
      new axis will be created.
    - adversarial (bool, optional): If True, "_adversarial" is appended to the
      name of the saved figure.
    
    Returns:
    - ax (plt subplot): Axis on which gradient SNRs were plotted.
    """
    
    if ax is None:
        wid = min(8, len(SNR_dict) * 1.5)
        _, ax = plt.subplots(figsize=(wid, 4))
    
    xlabels = list()
    for m, (model_type, SNRs) in enumerate(SNR_dict.items()):
        xlabels.append(model_type)
        color = get_plotting_color(model_idx=m)
        ax.bar(
            m, np.mean(SNRs), yerr=scipy.stats.sem(SNRs),
            alpha=0.5, width=width, capsize=5, color=color
            )
        s = [20 + i * 30 for i in range(len(SNRs))]
        ax.scatter([m] * len(SNRs), SNRs, alpha=0.8, s=s, color=color, zorder=5)
    
    x = np.arange(len(xlabels))
    ax.set_xticks(x)
    x_pad = (x.max() - x.min() + width) * 0.3
    ax.set_xlim(x.min() - x_pad, x.max() + x_pad)
    ax.set_xticklabels(xlabels, rotation=45)
    ax.set_xlabel("Learning rule")
    ax.set_ylabel("SNR")
    ax.set_title("SNR of the gradients")

    fig = ax.get_figure()  # Get the figure object from the axes
    fig.savefig(f"snr_{args.dataset}{'_adversarial' if adversarial else ''}.png", dpi=300, bbox_inches='tight')    
    return ax

def get_plotting_color(dataset="train", model_idx=None):
    if model_idx is not None:
        dataset = None
    
    if model_idx == 0 or dataset == "train":
        color = "#1F77B4" # blue
    elif model_idx == 1 or dataset == "valid":
        color = "#FF7F0E" # orange
    elif model_idx == 2 or dataset == "test":
        color = "#2CA02C" # green
    else:
        if model_idx is not None:
            raise NotImplementedError("Colors only implemented for up to 3 models.")
        else:
            raise NotImplementedError(
                f"{dataset} dataset not recognized. Expected 'train', 'valid' "
                "or 'test'."
            )
    return color


## 📐 Cosine Similarity of Gradients

We evaluate how closely gradients from DKP and DFA align with those of backpropagation  
by measuring cosine similarity over training epochs.
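
As a rough illustration of the metric, the sketch below computes the mean per-sample cosine similarity between two sets of gradients with NumPy. The function name `mean_cosine_similarity` and the toy arrays are assumptions made for this example; they are not part of the project code.

```python
import numpy as np

def mean_cosine_similarity(grads_a, grads_b, eps=1e-8):
    """Average row-wise cosine similarity between two (N, D) gradient arrays."""
    grads_a = np.asarray(grads_a, dtype=float)
    grads_b = np.asarray(grads_b, dtype=float)
    # Scale every row to (almost) unit length; eps guards against division by zero.
    unit_a = grads_a / (np.linalg.norm(grads_a, axis=1, keepdims=True) + eps)
    unit_b = grads_b / (np.linalg.norm(grads_b, axis=1, keepdims=True) + eps)
    # The dot product of two unit vectors is the cosine of the angle between them.
    return float(np.mean(np.sum(unit_a * unit_b, axis=1)))

# Hypothetical gradients: two samples, two parameters each.
bp = np.array([[1.0, 0.0], [0.0, 2.0]])
aligned = 3.0 * bp                               # same direction, larger norm
orthogonal = np.array([[0.0, 1.0], [5.0, 0.0]])  # perpendicular to bp row by row

print(mean_cosine_similarity(bp, aligned))
print(mean_cosine_similarity(bp, orthogonal))
```

A value near 1 means the learning rule points in nearly the same direction as backpropagation, while a value near 0 means the two updates are unrelated.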


In [None]:
def train_and_calculate_cosine_sim(
    train_loader,
    valid_loader,
    train_mode = 'DKP',
    dataset = 'CIFAR10',
    num_epochs = 8,
    device = "cuda",
):
    """
    Train both a BP-trained CNN (model_bp) and a CNN trained with a
    bio-plausible rule (model_lr, using DKP or DFA) for num_epochs, using their
    respective optimizers and schedulers. After each epoch, collect
    per-parameter gradients from both models on the VALIDATION data, one
    validation batch at a time, and compute their cosine similarity.

    Returns:
        cosine_sim_dict: dict[param_name] -> list with one mean cosine
            similarity per epoch
    """
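    # Build the models. Note that the choice is driven by the global
    # args.dataset; the `dataset` argument is currently not consulted.
    if args.dataset == 'CIFAR10':
        model_bp = NetCIFAR()
        model_lr = ConvNetworkCIFAR(1000, train_mode, device)
    elif args.dataset == 'MNIST':
        model_bp = NetMNIST()
        model_lr = ConvNetworkMNIST(1000, train_mode, device)
    else:
        raise ValueError(f"Unsupported dataset: {args.dataset!r}")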
    
    model_bp.to(device).train()
    model_lr.to(device).train()

    criterion = torch.nn.CrossEntropyLoss()

    # optimizer
    if train_mode == 'DKP':
        test_dkp(model_lr, device, valid_loader, train_loader, None, None, None)

        forward_params = []
        backward_params = []
        for name, param in model_lr.named_parameters():
            if "backward" in name:
                backward_params.append(param)
            else:
                forward_params.append(param)

        forward_optimizer = optim.SGD([{'params': forward_params}], lr=args.lr, weight_decay=1e-6, momentum=0.9, nesterov=True)
        backward_optimizer = optim.Adam([{'params': backward_params}], lr=args.b_lr, weight_decay=1e-6)
        optimizer_lr = MultipleOptimizer(forward_optimizer, backward_optimizer)
        scheduler_lr = StepLR(backward_optimizer, step_size=1, gamma=args.gamma)
    elif train_mode == 'DFA':
        optimizer_lr = optim.SGD(model_lr.parameters(), lr=args.lr, weight_decay=1e-6, momentum=0.9, nesterov=True)
        scheduler_lr = StepLR(optimizer_lr, step_size=1, gamma=args.gamma)

    optimizer_bp = optim.Adadelta(model_bp.parameters(), lr=args.bp_lr)
    scheduler_bp = StepLR(optimizer_bp, step_size=1, gamma=args.gamma)
    
    # Prepare storage: one list of per-epoch cosine similarities per parameter
    cosine_sim_dict = {key: list() for key in model_bp.list_parameters()}

    for epoch in range(1, num_epochs + 1):
        # ——— Training Phase ———
        # BP model
        model_bp.train()
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer_bp.zero_grad()
            y_pred = model_bp(X)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer_bp.step()
        scheduler_bp.step()

        # learning_rule model
        model_lr.train()
        for X, y in train_loader:
            X, y = X.to(device), y.to(device)
            optimizer_lr.zero_grad()
            y_pred = model_lr(X, y)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer_lr.step()
        scheduler_lr.step()

        # ——— Validation + Gradient Collection ———
        model_bp.eval()
        model_lr.eval()
        # accumulate per‐sample grads for this epoch
        epoch_grads_bp  = {n: [] for n in cosine_sim_dict.keys()}
        epoch_grads_lr = {n: [] for n in cosine_sim_dict.keys()}

        # Gradients are needed here, so this loop is not wrapped in torch.no_grad()
        for X, y in valid_loader:
            X, y = X.to(device), y.to(device)

            # BP gradients
            model_bp.zero_grad()
            y_pred = model_bp(X)
            loss = criterion(y_pred, y)
            loss.backward()
            for name, grad in model_bp.gather_gradient_dict().items():
                epoch_grads_bp[name].append(grad.detach().cpu().ravel().numpy())
            model_bp.zero_grad()

            # Learning_rule gradients
            model_lr.zero_grad()
            y_pred = model_lr(X, y)
            loss = criterion(y_pred, y)
            loss.backward()
            for name, grad in model_lr.gather_gradient_dict().items():
                epoch_grads_lr[name].append(grad.detach().cpu().ravel().numpy())
            model_lr.zero_grad()

        for key in cosine_sim_dict.keys():
            lr_grad = np.asarray(epoch_grads_lr[key])  # shape (N, D)
            bp_grad = np.asarray(epoch_grads_bp[key])  # shape (N, D)
            
            if np.allclose(lr_grad, 0):
                warnings.warn(
                    f"Learning rule computed all 0 gradients for epoch {epoch}. "
                    "Cosine similarity cannot be calculated."
                )
                epoch_cosine_sim = np.nan
            elif np.allclose(bp_grad, 0):
                warnings.warn(
                    f"Backprop. rule computed all 0 gradients for epoch {epoch}. "
                    "Cosine similarity cannot be calculated."
                )
                epoch_cosine_sim = np.nan
            else:
                # Normalize each sample's gradient
                lr_norm = lr_grad / (np.linalg.norm(lr_grad, axis=1, keepdims=True) + 1e-8)
                bp_norm = bp_grad / (np.linalg.norm(bp_grad, axis=1, keepdims=True) + 1e-8)
            
                # Compute cosine similarity per sample
                sample_cosines = np.sum(lr_norm * bp_norm, axis=1)  # shape (N,)
            
                # Average over all validation samples
                epoch_cosine_sim = np.nanmean(sample_cosines)
            
            cosine_sim_dict[key].append(epoch_cosine_sim)
        
    return cosine_sim_dict

def plot_gradient_cosine_sims(cosine_sim_dict, ax=None):
  """
  Plot gradient cosine similarities to error backpropagation for various
  learning rules.

  Arguments:
  - cosine_sim_dict (dict): Gradient cosine similarities for each learning rule.
  - ax (plt subplot, optional): Axis on which to plot gradient cosine
    similarities. If None, a new axis will be created.

  Returns:
  - ax (plt subplot): Axis on which gradient cosine similarities were plotted.
  """

  if ax is None:
    _, ax = plt.subplots(figsize=(8, 4))

  max_num_epochs = 0
  for m, (model_type, cosine_sims) in enumerate(cosine_sim_dict.items()):
    cosine_sims = np.asarray(cosine_sims) # params x epochs
    num_epochs = cosine_sims.shape[1]
    x = np.arange(num_epochs)
    cosine_sim_means = np.nanmean(cosine_sims, axis=0)
    cosine_sim_sems = scipy.stats.sem(cosine_sims, axis=0, nan_policy="omit")

    # Use the same color for the mean line, the SEM band and the scatter points
    color = get_plotting_color(model_idx=m)
    ax.plot(x, cosine_sim_means, label=model_type, alpha=0.8, color=color)

    ax.fill_between(
        x,
        cosine_sim_means - cosine_sim_sems,
        cosine_sim_means + cosine_sim_sems,
        alpha=0.3, lw=0, color=color
        )

    for i, param_cosine_sims in enumerate(cosine_sims):
      s = 20 + i * 30
      ax.scatter(x, param_cosine_sims, color=color, s=s, alpha=0.6)

    max_num_epochs = max(max_num_epochs, num_epochs)

  if max_num_epochs > 0:
    x = np.arange(max_num_epochs)
    xlabels = [f"{int(e)}" for e in x]
    ax.set_xticks(x)
    ax.set_xticklabels(xlabels)

  ymin = ax.get_ylim()[0]
  ymin = min(-0.1, ymin)
  ax.set_ylim(ymin, 1.1)

  ax.axhline(0, ls="dashed", color="k", zorder=-5, alpha=0.5)

  ax.set_xlabel("Epoch")
  ax.set_ylabel("Cosine similarity")
  ax.set_title("Cosine similarity to backprop gradients")
  ax.legend()

  return ax

## 🏗️ Model Architecture: ConvNetBP

A standard CNN with ReLU, batch normalization, and max pooling.  
Used for training with backpropagation.  
Includes helper methods for collecting per-parameter gradients after a backward pass.
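
The `ConvNetBP` class measures the size of its first fully connected layer with a dummy forward pass. As a sanity check, the same number can be derived in closed form. The sketch below assumes the layout used in this notebook (two 3x3 convolutions with stride 1 and no padding, 32 output channels, then a 2x2 max pool); the helper names `conv_out` and `flattened_features` are hypothetical.

```python
def conv_out(size, kernel=3, stride=1, padding=0):
    """Spatial size after a square convolution (standard output-size formula)."""
    return (size + 2 * padding - kernel) // stride + 1

def flattened_features(input_size, channels=32):
    """Number of features entering the first fully connected layer."""
    size = conv_out(conv_out(input_size))  # two 3x3 convolutions, stride 1, no padding
    size = size // 2                       # 2x2 max pooling with stride 2
    return channels * size * size

print(flattened_features(28))  # 28x28 inputs (MNIST): 32 * 12 * 12 = 4608
print(flattened_features(32))  # 32x32 inputs (CIFAR-10): 32 * 14 * 14 = 6272
```

Batch normalization and ReLU do not change tensor shapes, so they do not appear in the calculation.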


In [None]:
class ConvNetBP(nn.Module):
    def __init__(self, in_channels, input_size):
        super(ConvNetBP, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv1_bn = nn.BatchNorm2d(32)
        self.conv2_bn = nn.BatchNorm2d(32)

        # Determine size after conv and pooling
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, input_size, input_size)
            x = self.conv1(dummy)
            x = self.conv1_bn(x)
            x = F.relu(x)
            x = self.conv2(x)
            x = self.conv2_bn(x)
            x = F.relu(x)
            x = F.max_pool2d(x, 2)
            n_features = x.shape[1] * x.shape[2] * x.shape[3]

        self.fc1 = nn.Linear(n_features, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, input):
        x = self.conv1(input)
        x = self.conv1_bn(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.conv2_bn(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        return x

    def list_parameters(self):
        """
        Returns a list of parameter names for a gradient dictionary.
        
        Returns:
        - params_list (list): List of parameter names.
        """
        
        params_list = list()
        
        for layer_str in ["conv1", "conv2", "fc1", "fc2"]:
          params_list.append(f"{layer_str}_weight")
          # if self.bias:
          #   params_list.append(f"{layer_str}_bias")
        
        return params_list
    
    def gather_gradient_dict(self):
        """
        Gathers a gradient dictionary for the model's parameters. Raises a
        runtime error if any parameters have no gradients.
        
        Returns:
        - gradient_dict (dict): A dictionary of gradients for each parameter.
        """
        
        params_list = self.list_parameters()
        
        gradient_dict = dict()
        for param_name in params_list:
          layer_str, param_str = param_name.split("_")
          layer = getattr(self, layer_str)
          grad = getattr(layer, param_str).grad
          if grad is None:
            raise RuntimeError("No gradient was computed")
          gradient_dict[param_name] = grad.detach().clone() 

        return gradient_dict

def train_bp(args, model, device, train_loader, optimizer, epoch, batch_size, writer):
    model.train()

    training_loss = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        # Accumulate a Python float so the autograd graph of each batch can be
        # freed, and so the average is defined even when dry_run exits early.
        training_loss += loss.item()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

            if args.dry_run:
                break

    training_loss /= (batch_idx + 1)

    writer.add_scalar('loss/training_loss', training_loss, epoch)

    writer.close()
    return training_loss

def test_bp(model, device, test_loader, train_loader, epoch, batch_size, writer):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)

            output = model(data)
            test_loss += F.cross_entropy(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= (batch_idx + 1)
    test_accuracy = 100. * correct / len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: ({:.2f}%)\n'.format(
        test_loss, test_accuracy))

    if epoch is not None:
        writer.add_scalar('loss/test_loss', test_loss, epoch)
        writer.add_scalar('accuracy/test_accuracy', test_accuracy, epoch)
        writer.close()

    return test_loss, test_accuracy


## 🏗️ Model Architecture: ConvNetDKP / DFA

A CNN model modified to support biologically plausible learning.  
Backward weights are either learned (DKP) or fixed random (DFA).  
Includes gradient injection and alignment diagnostics.
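
To make the difference between the two rules concrete, here is a minimal NumPy sketch of the feedback step for a single fully connected layer. It mirrors the matrix products used in `DFAHookFunction`, but the shapes and random values are purely illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, hidden, classes = 4, 6, 3

h = rng.standard_normal((batch, hidden))    # hidden-layer activations
e = rng.standard_normal((batch, classes))   # gradient of the loss at the network output
B = rng.standard_normal((classes, hidden))  # feedback (backward) weights

# Both rules replace the backpropagated signal with a direct projection of
# the output error through B.
delta_h = e @ B                             # shape (batch, hidden)

# DFA keeps B fixed. DKP also gives B its own gradient, built from the
# output error and the layer activations, so B is learned alongside the
# forward weights.
grad_B = e.T @ h                            # shape (classes, hidden)

print(delta_h.shape, grad_B.shape)
```

For convolutional layers the same products are applied after flattening the activations to shape `(batch, channels * height * width)`.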


In [None]:
class MultipleOptimizer(object):
    def __init__(self, *op):
        self.optimizers = op

    def zero_grad(self):
        for op in self.optimizers:
            op.zero_grad()

    def step(self):
        for op in self.optimizers:
            op.step()

    def state_dict(self):
        # Return one state dict per wrapped optimizer, in order
        return [op.state_dict() for op in self.optimizers]

class OutputTrainingHook(nn.Module):
    #This training hook captures and handles the gradients at the output of the network
    def __init__(self):
        super(OutputTrainingHook, self).__init__()

    def forward(self, input, grad_at_output):
        return OutputHookFunction.apply(input, grad_at_output)

class OutputHookFunction(autograd.Function):
    @staticmethod
    def forward(ctx, input, grad_at_output):
        ctx.save_for_backward(input)
        ctx.in1 = grad_at_output
        return input

    @staticmethod
    def backward(ctx, grad_output):
        grad_at_output = ctx.in1

        # Copy the true output gradient into the shared buffer read by the DFA/DKP hooks
        grad_at_output[:grad_output.shape[0], :].data.copy_(grad_output.data)

        return grad_output, None

class DFATrainingHook(nn.Module):
    #This training hook calculates and injects the gradients made by DFA
    def __init__(self, train_mode):
        super(DFATrainingHook, self).__init__()
        self.train_mode = train_mode
        self.is_not_initialized = True
        self.backward_weights = nn.Parameter(requires_grad=True)

    def init_weights(self, dim, device):
        self.backward_weights = nn.Parameter(torch.Tensor(torch.Size(dim)).to(device))
        if self.train_mode == 'DKP':
            self.backward_weights.requires_grad = True
            torch.nn.init.kaiming_uniform_(self.backward_weights)
            # torch.nn.init.zeros_(self.backward_weights)
        elif self.train_mode == 'DFA':
            self.backward_weights.requires_grad = False
            torch.nn.init.kaiming_uniform_(self.backward_weights)

    def forward(self, input, grad_at_output, network_output):
        if self.is_not_initialized and self.train_mode in ['DKP', 'DFA']:
            if len(input.shape) > 2:
                dim = [grad_at_output.shape[1], input.shape[1], input.shape[2], input.shape[3]]
            else:
                dim = [grad_at_output.shape[1], input.shape[1]]
            self.init_weights(dim, input.device)
            self.is_not_initialized = False

        return DFAHookFunction.apply(input, self.backward_weights, grad_at_output, network_output, self.train_mode)

class DFAHookFunction(autograd.Function):
    @staticmethod
    def forward(ctx, input, backward_weights, grad_at_output, network_output, train_mode):
        ctx.save_for_backward(input, backward_weights)
        ctx.in1 = grad_at_output
        ctx.in2 = network_output
        ctx.in3 = train_mode
        return input

    @staticmethod
    def backward(ctx, grad_output):
        grad_at_output            = ctx.in1
        network_output            = ctx.in2
        train_mode                = ctx.in3
        input, backward_weights   = ctx.saved_tensors

        grad_at_output = grad_at_output[:grad_output.shape[0], :]
        network_output = network_output[:grad_output.shape[0], :]

        if train_mode == 'DFA':
            B_view = backward_weights.view(-1, prod(backward_weights.shape[1:]))
            grad_output_est = grad_at_output.mm(B_view).view(grad_output.shape)
            return grad_output_est, None, None, None, None

        elif train_mode == 'DKP':
            layer_out_view = input.view(-1, prod(input.shape[1:]))
            B_view = backward_weights.view(-1, prod(backward_weights.shape[1:]))

            grad_output_est = grad_at_output.mm(B_view)
            grad_weights_B = grad_at_output.t().mm(layer_out_view)

            return grad_output_est.view(grad_output.shape), grad_weights_B.view(backward_weights.shape), None, None, None

        # Unknown mode: pass the gradient through unchanged (one value per forward input)
        return grad_output, None, None, None, None

class ConvNetDKP(nn.Module):
    def __init__(self, args, in_channels, input_size):
        super(ConvNetDKP, self).__init__()
        self.batch_size = args.batch_size
        self.train_mode = args.train_mode
        self.device = args.device

        # Initialize hooks and layers
        self.conv1 = nn.Conv2d(in_channels, 32, 3, 1)
        self.conv1_bn = nn.BatchNorm2d(32)
        self.conv1_dfa = DFATrainingHook(self.train_mode)

        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.conv2_bn = nn.BatchNorm2d(32)
        self.conv2_dfa = DFATrainingHook(self.train_mode)

        # 1) CPU shape inference
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, input_size, input_size)  # CPU tensor
            x = self.conv1(dummy)
            x = self.conv1_bn(x)
            x = F.relu(x)
            self._conv1_shape = list(x.shape[1:])  # Save conv1 output shape

            x = self.conv2(x)
            x = self.conv2_bn(x)
            x = F.relu(x)
            self._conv2_shape = list(x.shape[1:])  # Save conv2 output shape

            x = F.max_pool2d(x, 2)
            flatten_dim = x.numel() // x.shape[0]

        # Define fully connected layers
        self.fc1 = nn.Linear(flatten_dim, 128)
        self.fc1_dfa = DFATrainingHook(self.train_mode)

        self.fc2 = nn.Linear(128, 10)
        self.output_hook = OutputTrainingHook()

        # 2) Move model to target device
        self.to(self.device)

        # 3) Allocate buffers directly on the correct device using cached shapes
        self.grad_at_output = torch.zeros(self.batch_size, 10, device=self.device)
        self.network_output = torch.zeros(self.batch_size, 10, device=self.device)

        self.conv1_out = torch.zeros([self.batch_size] + self._conv1_shape, requires_grad=False, device=self.device)
        self.conv2_out = torch.zeros([self.batch_size] + self._conv2_shape, requires_grad=False, device=self.device)
        self.fc1_out = torch.zeros(self.batch_size, 128, requires_grad=False, device=self.device)

    def forward(self, x):
        batch_size = x.shape[0]
    
        # Dynamically reallocate tensors if batch size has changed
        if self.grad_at_output.shape[0] != batch_size:
            self.grad_at_output = torch.zeros(batch_size, 10, device=x.device)
            self.network_output = torch.zeros(batch_size, 10, device=x.device)
            self.conv1_out = torch.zeros([batch_size] + self._conv1_shape, device=x.device)
            self.conv2_out = torch.zeros([batch_size] + self._conv2_shape, device=x.device)
            self.fc1_out = torch.zeros(batch_size, 128, device=x.device)
    
        x = self.conv1(x)
        x = self.conv1_bn(x)
        x = F.relu(x)
        x = self.conv1_dfa(x, self.grad_at_output, self.network_output)
        if x.requires_grad:
            self.conv1_out[:batch_size].data.copy_(x.data)
    
        x = self.conv2(x)
        x = self.conv2_bn(x)
        x = F.relu(x)
        x = self.conv2_dfa(x, self.grad_at_output, self.network_output)
        if x.requires_grad:
            self.conv2_out[:batch_size].data.copy_(x.data)
    
        x = F.max_pool2d(x, 2)
        x = torch.flatten(x, 1)
    
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc1_dfa(x, self.grad_at_output, self.network_output)
        if x.requires_grad:
            self.fc1_out[:batch_size].data.copy_(x.data)
    
        x = self.fc2(x)
    
        if x.requires_grad:
            x = self.output_hook(x, self.grad_at_output)
            self.network_output[:batch_size].data.copy_(x.data)
    
        return x

    def list_parameters(self):

        params_list = list()
        
        for layer_str in ["conv1", "conv2", "fc1", "fc2"]:
          params_list.append(f"{layer_str}_weight")
          # if self.bias:
          #   params_list.append(f"{layer_str}_bias")
        
        return params_list
    
    def gather_gradient_dict(self):
        
        params_list = self.list_parameters()
        
        gradient_dict = dict()
        for param_name in params_list:
          layer_str, param_str = param_name.split("_")
          layer = getattr(self, layer_str)
          grad = getattr(layer, param_str).grad
          if grad is None:
            raise RuntimeError("No gradient was computed")
          gradient_dict[param_name] = grad.detach().clone() 

        return gradient_dict

def train_dkp(args, model, device, train_loader, optimizer, epoch, batch_size, writer):
    model.train()

    training_loss = 0
    alignment_fc1 = 0
    alignment_conv1 = 0
    alignment_conv2 = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()

        output = model(data)
        loss = F.cross_entropy(output, target)
        loss.backward()

        align_fc1 = torch.mean(F.cosine_similarity(
            output.detach(),
            model.fc1_out[:output.shape[0], :].mm(model.fc1_dfa.backward_weights.t())
        ))
        align_conv1 = torch.mean(F.cosine_similarity(
            output.detach(),
            model.conv1_out[:output.shape[0]].view(output.shape[0], -1).mm(
                model.conv1_dfa.backward_weights.view(model.conv1_dfa.backward_weights.shape[0], -1).t()
            )
        ))
        align_conv2 = torch.mean(F.cosine_similarity(
            output.detach(),
            model.conv2_out[:output.shape[0]].view(output.shape[0], -1).mm(
                model.conv2_dfa.backward_weights.view(model.conv2_dfa.backward_weights.shape[0], -1).t()
            )
        ))

        alignment_fc1 += align_fc1.item()
        alignment_conv1 += align_conv1.item()
        alignment_conv2 += align_conv2.item()

        optimizer.step()

        # Accumulate a Python float so the autograd graph of each batch can be
        # freed, and so the average is defined even when dry_run exits early.
        training_loss += loss.item()

        if batch_idx % args.log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))

            if args.dry_run:
                break

    training_loss /= (batch_idx + 1)

    writer.add_scalar('loss/training_loss', training_loss, epoch)
    writer.add_scalar('cos_alignment/fc1_to_output', alignment_fc1 / (batch_idx + 1), epoch)
    writer.add_scalar('cos_alignment/conv1_to_output', alignment_conv1 / (batch_idx + 1), epoch)
    writer.add_scalar('cos_alignment/conv2_to_output', alignment_conv2 / (batch_idx + 1), epoch)

    writer.close()
    return training_loss

def test_dkp(model, device, test_loader, train_loader, epoch, batch_size, writer):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch_idx, (data, target) in enumerate(test_loader):
            data, target = data.to(device), target.to(device)

            output = model(data)
            test_loss += F.cross_entropy(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= (batch_idx + 1)
    test_accuracy = 100. * correct / len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: ({:.2f}%)\n'.format(
        test_loss, test_accuracy))

    if epoch is not None:
        writer.add_scalar('loss/test_loss', test_loss, epoch)
        writer.add_scalar('accuracy/test_accuracy', test_accuracy, epoch)
        writer.close()

    return test_loss, test_accuracy

## 🏃‍♂️ Training Loop: Backpropagation

The training function for the BP model.  
Tracks loss and accuracy over epochs, and saves CSV logs and checkpoints.
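
The training loop decays the learning rate once per epoch with `StepLR`. As a small sketch of what that schedule looks like, the hypothetical helper below reproduces the step-decay rule in plain Python; the base learning rate and `gamma` are illustrative values, not the project's settings.

```python
def step_lr_schedule(base_lr, gamma, num_epochs, step_size=1):
    """Learning rate used in each epoch under step decay.

    The rate is multiplied by gamma once every step_size epochs.
    """
    return [base_lr * gamma ** (epoch // step_size) for epoch in range(num_epochs)]

# Illustrative values only
schedule = step_lr_schedule(base_lr=1.0, gamma=0.7, num_epochs=4)
print(schedule)
```

With `step_size=1`, as in this notebook, the learning rate shrinks by a factor of `gamma` after every epoch.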


In [None]:
def main_bp(args):

    train_loader, test_loader = fetch_dataloaders(args)

    images, labels = next(iter(train_loader)) 
    in_channels = images.shape[1]
    input_height = images.shape[2]

    model = ConvNetBP(in_channels, input_height).to(args.device)

    optimizer = optim.Adadelta(model.parameters(), lr=args.bp_lr)

    writer = SummaryWriter(log_dir='results/conv_bp')
    logger = CSVLogger(['Epoch', 'Training Loss', 'Test Loss', 'Test Accuracy'], args)

    # optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=1e-6, momentum=0.9, nesterov=True)
    scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    for epoch in range(1, args.epochs + 1):
        training_loss = train_bp(args, model, args.device, train_loader, optimizer, epoch, args.batch_size, writer)
        test_loss, test_accuracy = test_bp(model, args.device, test_loader, train_loader, epoch, args.batch_size, writer)
        logger.save_values(epoch, training_loss, test_loss, test_accuracy)

        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), f"model_checkpoints/BP_{args.dataset}_epoch{args.epochs}_s{args.seed}.pth")

In [None]:
# train BP on CIFAR10 and save results on CSV and save model params

# args.train_mode = 'BP'
# args.dataset = 'CIFAR10'

# main_bp(args)

In [None]:
# train BP on MNIST and save results on CSV and save model params

# args.train_mode = 'BP'
# args.dataset = 'MNIST'

# main_bp(args)

## 🏃‍♂️ Training Loop: DKP / DFA

Shared training loop for DKP and DFA models.  
Includes:
- Weight separation (forward vs backward)
- Hook registration
- Cosine similarity diagnostics
- TensorBoard logging
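
The weight separation step can be sketched without PyTorch: parameters are routed to the forward or backward optimizer purely by name. In the example below, `split_by_role` is a hypothetical helper and the strings stand in for real parameter tensors.

```python
def split_by_role(named_params, marker="backward"):
    """Split (name, param) pairs into forward and backward parameter lists."""
    forward, backward = [], []
    for name, param in named_params:
        # Feedback weights are recognised by the marker in their name.
        if marker in name:
            backward.append(param)
        else:
            forward.append(param)
    return forward, backward

# Stand-ins for model.named_parameters(); each "parameter" is just its name.
names = [
    "conv1.weight", "conv1.bias", "conv1_dfa.backward_weights",
    "fc1.weight", "fc1_dfa.backward_weights", "fc2.weight",
]
forward_params, backward_params = split_by_role((n, n) for n in names)
print(len(forward_params), len(backward_params))
```

In DKP mode the two groups are then given separate optimizers, so the feedback weights can use their own learning rate.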


In [None]:
def main_dkp(args):

    train_loader, test_loader = fetch_dataloaders(args)

    images, labels = next(iter(train_loader)) 
    in_channels = images.shape[1]
    input_height = images.shape[2]
    
    model = ConvNetDKP(args, in_channels, input_height).to(args.device)

    writer = SummaryWriter(log_dir='results/conv_dkp')
    logger = CSVLogger(['Epoch', 'Training Loss', 'Test Loss', 'Test Accuracy'], args)

    if args.train_mode == 'DKP':
        test_dkp(model, args.device, test_loader, train_loader, None, None, None)

        forward_params = []
        backward_params = []
        for name, param in model.named_parameters():
            if "backward" in name:
                backward_params.append(param)
            else:
                forward_params.append(param)

        forward_optimizer = optim.SGD([{'params': forward_params}], lr=args.lr, weight_decay=1e-6, momentum=0.9, nesterov=True)
        backward_optimizer = optim.Adam([{'params': backward_params}], lr=args.b_lr, weight_decay=1e-6)
        optimizer = MultipleOptimizer(forward_optimizer, backward_optimizer)
        scheduler = StepLR(backward_optimizer, step_size=1, gamma=args.gamma)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=1e-6, momentum=0.9, nesterov=True)
        scheduler = StepLR(optimizer, step_size=1, gamma=args.gamma)

    for epoch in range(1, args.epochs + 1):
        training_loss = train_dkp(args, model, args.device, train_loader, optimizer, epoch, args.batch_size, writer)
        test_loss, test_accuracy = test_dkp(model, args.device, test_loader, train_loader, epoch, args.batch_size, writer)
        logger.save_values(epoch, training_loss, test_loss, test_accuracy)

        scheduler.step()

    if args.save_model:
        torch.save(model.state_dict(), f"model_checkpoints/{args.train_mode}_{args.dataset}_epoch{args.epochs}_s{args.seed}.pt")


In [None]:
# train DKP on CIFAR10 and save results on CSV and save model params

# args.train_mode = 'DKP'
# args.dataset = 'CIFAR10'

# main_dkp(args)

In [None]:
# train DKP on MNIST and save results on CSV and save model params

# args.train_mode = 'DKP'
# args.dataset = 'MNIST'

# main_dkp(args)

In [None]:
# train DFA on CIFAR10 and save results on CSV and save model params

# args.train_mode = 'DFA'
# args.dataset = 'CIFAR10'

# main_dkp(args)

In [None]:
# train DFA on MNIST and save results on CSV and save model params

# args.train_mode = 'DFA'
# args.dataset = 'MNIST'

# main_dkp(args)

## 🧪 Evaluation on CIFAR-10

This section includes the full evaluation pipeline for CIFAR-10:
- RDM visualization
- MDS (representation trajectory)
- Adversarial robustness using FGSM
- Accuracy under adversarial attack
- Feature space visualization using t-SNE
- SNR computation
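
For readers unfamiliar with FGSM, the sketch below applies one attack step to a toy logistic-regression model in NumPy. The model, weights, and `epsilon` are assumptions chosen for illustration; the notebook's own attack operates on the CNNs and may differ in details such as clipping inputs to a valid pixel range, which is omitted here.

```python
import numpy as np

def fgsm_perturb(x, grad_wrt_x, epsilon):
    """One FGSM step: shift every input element by epsilon in the
    direction that increases the loss."""
    return x + epsilon * np.sign(grad_wrt_x)

# Toy model: p = sigmoid(w . x), trained with binary cross-entropy
w = np.array([2.0, -1.0, 0.5])
x = np.array([0.2, 0.4, -0.1])
y = 1.0

def predict(x):
    return 1.0 / (1.0 + np.exp(-(w @ x)))

def loss(x):
    p = predict(x)
    return -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

# Analytic gradient of the loss with respect to the input
grad_x = (predict(x) - y) * w

x_adv = fgsm_perturb(x, grad_x, epsilon=0.1)
print(loss(x), loss(x_adv))
```

Every input element moves by exactly `epsilon`, so the perturbation is bounded in the max norm while the loss goes up.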


In [None]:
args.dataset = 'CIFAR10'
train_loader, test_loader = fetch_dataloaders(args)
imgs, labels = sample_images_cifar(test_loader, n=20)
models_rdms, _, _ = grab_features_and_rdms(args, imgs, labels)
plot_all_rdms(models_rdms, method_names, n_cols=5)

In [None]:
# imgs, labels = sample_images_cifar(test_loader, n=20)
imgs, labels = next(iter(test_loader))
rep_path_all(args, method_names, model_colors, imgs, labels)

In [None]:
imgs, labels = sample_images_cifar(test_loader, n=20)
models_rdms, _, adv_accuracies = grab_features_and_rdms(args, imgs, labels, adversarial=True)
plot_all_rdms(models_rdms, method_names, n_cols=5, adversarial=True)

In [None]:
for i, acc in enumerate(adv_accuracies):
    print(f"Model: {method_names[i]} | Accuracy: {100 * acc :.1f}%")

In [None]:
imgs, labels = sample_images_cifar(test_loader, n=50)
_, models_feats, _ = grab_features_and_rdms(args, imgs, labels)
plot_dim_reduction_all(models_feats, labels, transformer_funcs=['t-SNE'])

In [None]:
imgs, labels = sample_images_cifar(test_loader, n=10)
SNR_dict = dict()
models_dict = load_models(args, imgs)

for name, model in models_dict.items():
    model_SNR_dict = compute_gradient_SNR(args, model, test_loader.dataset, max_samples=1000)
    SNR_dict[name] = [SNR for SNR in model_SNR_dict.values()]

plot_gradient_SNRs(SNR_dict);

## 🧪 Evaluation on MNIST

We replicate the CIFAR-10 analysis on MNIST:
- Internal representation comparison
- Adversarial image generation and testing
- Feature visualization via t-SNE
- Gradient SNR for BP, DKP, DFA
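
As a quick intuition for the gradient SNR, the sketch below follows the same definition as `compute_SNR` (absolute mean over samples divided by the standard deviation, averaged over parameters) on two hand-made gradient sets. The function name `gradient_snr` and the arrays are hypothetical.

```python
import numpy as np

def gradient_snr(per_sample_grads, eps=1e-7):
    """Mean over parameters of |mean| / (std + eps), taken across samples."""
    g = np.asarray(per_sample_grads, dtype=float)  # shape (n_samples, n_params)
    return float(np.mean(np.abs(g.mean(axis=0)) / (g.std(axis=0) + eps)))

# Samples agree on the gradient direction: high SNR
consistent = np.array([[1.0, 2.0], [1.1, 2.1], [0.9, 1.9]])
# Samples disagree in sign and size: low SNR
noisy = np.array([[1.0, -2.0], [-1.0, 2.0], [0.5, -0.5]])

print(gradient_snr(consistent), gradient_snr(noisy))
```

A higher SNR means individual samples push each parameter in a more consistent direction.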


In [None]:
args.dataset = 'MNIST'
train_loader, test_loader = fetch_dataloaders(args)
imgs, labels = sample_images_mnist(test_loader, n=20)
models_rdms, _, _ = grab_features_and_rdms(args, imgs, labels)
plot_all_rdms(models_rdms, method_names, n_cols=5)

In [None]:
# imgs, labels = sample_images_mnist(test_loader, n=20)
imgs, labels = next(iter(test_loader))
rep_path_all(args, method_names, model_colors, imgs, labels)

In [None]:
imgs, labels = sample_images_mnist(test_loader, n=20)
models_rdms, _, adv_accuracies = grab_features_and_rdms(args, imgs, labels, adversarial=True)
plot_all_rdms(models_rdms, method_names, n_cols=5, adversarial=True)

In [None]:
for i, acc in enumerate(adv_accuracies):
    print(f"Model: {method_names[i]} | Accuracy: {100 * acc :.1f}%")

In [None]:
imgs, labels = sample_images_mnist(test_loader, n=50)
_, models_feats, _ = grab_features_and_rdms(args, imgs, labels)
plot_dim_reduction_all(models_feats, labels, transformer_funcs=['t-SNE'])

In [None]:
imgs, labels = sample_images_mnist(test_loader, n=10)
SNR_dict = dict()
models_dict = load_models(args, imgs)

for name, model in models_dict.items():
    model_SNR_dict = compute_gradient_SNR(args, model, test_loader.dataset, max_samples=1000)
    SNR_dict[name] = [SNR for SNR in model_SNR_dict.values()]

plot_gradient_SNRs(SNR_dict);