In [None]:
import numpy as np

from tqdm import tqdm
from torchvision.models import resnet

from .functional_information import FunctionalInformationInterpreter

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F


In [None]:

# Define the Neural Network
class Net(nn.Module):
    """LeNet-style CNN: two conv/ReLU/max-pool stages followed by three linear layers.

    Layer sizes are fixed for 3-channel 32x32 inputs; the shape comments in
    ``__init__`` trace a ``(N, 3, 32, 32)`` batch through the network.

    Args:
        num_classes: Number of logits produced per sample by the final linear
            layer. Defaults to 10, the value that was previously hard-coded.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # Input: (3, 32, 32) -> Output: (6, 28, 28)
        self.pool1 = nn.MaxPool2d(2, 2)  # Output: (6, 14, 14)
        self.conv2 = nn.Conv2d(6, 16, 5)  # Output: (16, 10, 10)
        self.pool2 = nn.MaxPool2d(2, 2)  # Output: (16, 5, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # Fully connected layer
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)
        # One ReLU module per activation site (rather than a shared one or
        # F.relu) so each activation is a distinct, individually addressable
        # submodule; names are kept stable for state_dict / hook lookups.
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()
        self.relu4 = nn.ReLU()

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Image tensor of shape ``(N, 3, 32, 32)``. A single unbatched
                ``(3, 32, 32)`` image is treated as a batch of one.

        Returns:
            Logits of shape ``(N, num_classes)`` (``(1, num_classes)`` for an
            unbatched image).

        Raises:
            RuntimeError: from ``fc1`` when the input's spatial size does not
                reduce to ``16 * 5 * 5`` features per sample.
        """
        if x.dim() == 3:
            # Unbatched (C, H, W): add a batch axis so the output is (1, classes),
            # the same shape the previous view(-1, ...) flatten produced.
            x = x.unsqueeze(0)
        x = self.pool1(self.relu1(self.conv1(x)))
        x = self.pool2(self.relu2(self.conv2(x)))
        # Flatten everything except the batch axis. view(-1, 16 * 5 * 5) would
        # silently fold a spatial-size mismatch into extra "batch" rows; this
        # keeps N fixed so a wrong-sized input fails loudly at fc1 instead.
        x = torch.flatten(x, 1)
        x = self.relu3(self.fc1(x))
        x = self.relu4(self.fc2(x))
        return self.fc3(x)


In [None]:
def main():
    """Smoke-test ``FunctionalInformationInterpreter.init_corr_mat`` on random data.

    Builds a seeded batch of random "images" with random labels, wraps an
    untrained ``Net`` in the interpreter, and checks that every per-class
    matrix returned by ``init_corr_mat`` is ``(num_pixels, num_pixels)``,
    symmetric, and positive-definite, then prints a success message.

    Raises:
        AssertionError: if no matrices are returned or any check fails.
    """
    # Seed both RNGs: numpy drives the simulated data, torch drives Net's
    # weight initialization (otherwise different on every run).
    np.random.seed(42)
    torch.manual_seed(42)

    num_samples = 10
    num_classes = 3
    height, width = 32, 32
    num_pixels = height * width

    # Simulated image data (after pipeline transformation) (batch, channels, height, width)
    inputs = np.random.rand(num_samples, 3, height, width).astype(np.float32)  # Simulating RGB image batch
    labels = np.random.choice(num_classes, num_samples)  # Random labels (0, 1, 2)

    # NOTE(review): Net() has 10 output logits while labels only use
    # num_classes=3 — confirm this mismatch is intended.
    net = Net()

    interpreter = FunctionalInformationInterpreter(net, device="cpu")
    corr_matrices = interpreter.init_corr_mat(inputs, labels, visual=True)

    # Guard against a vacuous pass: with no matrices the loop below checks nothing.
    assert len(corr_matrices) > 0, "No correlation matrices returned"

    # NOTE(review): a 1024x1024 matrix estimated from ~10 samples is at most
    # low-rank unless init_corr_mat regularizes it; confirm the PD check can
    # hold for this data size.
    for class_label, matrix in corr_matrices.items():
        assert matrix.shape == (num_pixels, num_pixels), f"Incorrect shape (class {class_label})"
        assert np.allclose(matrix, matrix.T, atol=1e-5), f"Matrix is not symmetric (class {class_label})"
        # Symmetry was just asserted, so use the symmetric eigensolver: eigvalsh
        # returns real eigenvalues, whereas the general eigvals may return
        # complex values with spurious imaginary parts (and is much slower).
        assert np.all(np.linalg.eigvalsh(matrix) > 0), f"Matrix is not positive-definite (class {class_label})"

    print("\nAll tests passed!")