In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import tqdm 

import numpy as np
from scipy.sparse.csgraph import connected_components


# -------------------------
# Configurable CNN with Batch Normalization and Hidden Activations Collection
# -------------------------
class ConfigurableCNN(nn.Module):
    """CNN built from repeated conv -> (batch norm) -> ReLU -> 2x2 max-pool blocks plus a two-layer classifier.

    Every entry of ``conv_channels`` adds one 3x3 convolution (padding 1, so the spatial
    size is preserved) followed by a 2x2 max-pool that halves the spatial size. The
    flattened feature map then goes through ``fc1`` -> ReLU -> dropout -> ``fc2``.
    """

    def __init__(self, conv_channels, fc_hidden_units=512, dropout_p=0.25,
                 num_classes=10, input_size=32, input_channels=3, use_batchnorm=True):
        """
        Args:
            conv_channels (list of int): List of output channels for each convolutional layer.
            fc_hidden_units (int): Number of neurons in the hidden fully connected layer.
            dropout_p (float): Dropout probability.
            num_classes (int): Number of output classes.
            input_size (int): Height/width of the input images (assumed square).
            input_channels (int): Number of channels in the input images (3 for colored images).
            use_batchnorm (bool): Whether to use batch normalization after each convolution.

        Raises:
            ValueError: If ``conv_channels`` is empty, or if ``input_size`` is too small to
                survive one 2x2 pooling per convolutional layer.
        """
        super().__init__()

        conv_channels = list(conv_channels)
        if not conv_channels:
            raise ValueError("conv_channels must contain at least one output-channel count")

        # One pooling per conv layer, each halving (floor) the spatial size.
        self.num_pool = len(conv_channels)
        final_size = input_size // (2 ** self.num_pool)
        if final_size < 1:
            # Fail at construction instead of building a zero-width fc1 that only
            # breaks later, inside the forward pass.
            raise ValueError(
                f"input_size={input_size} is too small for {self.num_pool} pooling stages; "
                f"need input_size >= {2 ** self.num_pool}"
            )

        self.use_batchnorm = use_batchnorm
        self.conv_layers = nn.ModuleList()
        if self.use_batchnorm:
            # Only registered when batch norm is enabled; forward() guards on the same flag.
            self.bn_layers = nn.ModuleList()

        # Create convolutional layers along with optional batch normalization.
        in_channels = input_channels
        for out_channels in conv_channels:
            self.conv_layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            if self.use_batchnorm:
                self.bn_layers.append(nn.BatchNorm2d(out_channels))
            in_channels = out_channels

        # A single stateless 2x2 max-pool, reused after every conv block.
        self.pool = nn.MaxPool2d(2, 2)

        # Number of features entering the classifier head.
        self.flattened_size = conv_channels[-1] * final_size * final_size

        # Fully connected layers.
        self.fc1 = nn.Linear(self.flattened_size, fc_hidden_units)
        self.fc2 = nn.Linear(fc_hidden_units, num_classes)

        # The same dropout module is applied to the conv features and to the fc1 output.
        self.dropout = nn.Dropout(dropout_p)
        self.act = F.relu

    def forward(self, x, return_hidden=False):
        """Compute class logits, optionally also returning per-block hidden activations.

        Args:
            x (Tensor): Input batch; the conv stack expects ``input_channels`` channels and
                the classifier head expects the spatial size given by ``input_size``.
            return_hidden (bool): If True, also return the output of every conv block
                captured after batch norm (when enabled) but BEFORE the ReLU and pooling.

        Returns:
            Tensor of logits, or ``(logits, hidden_activations)`` when ``return_hidden``
            is True, where ``hidden_activations`` has one tensor per convolutional layer.
        """
        hidden_activations = []

        for idx, conv in enumerate(self.conv_layers):
            x = conv(x)
            if self.use_batchnorm:
                x = self.bn_layers[idx](x)
            if return_hidden:
                # Pre-activation features: collected before ReLU / pooling.
                hidden_activations.append(x)
            x = self.act(x)
            x = self.pool(x)

        x = self.dropout(x)
        x = x.view(x.size(0), -1)  # Flatten to (batch, flattened_size)

        # Hidden fully connected layer; its activations are not collected.
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)

        # Final fully connected layer (logits)
        x = self.fc2(x)

        if return_hidden:
            return x, hidden_activations
        return x


def eval_features(model, testloader, thresh=0.95, tol=1e-10):
    """Print and return channel-redundancy metrics for each hidden activation of ``model``.

    The first batch of ``testloader`` is pushed through ``model`` with
    ``return_hidden=True``. For every returned activation, each channel is flattened
    into one vector over (batch, spatial) positions and L2-normalized, giving a
    channel-by-channel cosine-similarity matrix ``C``. Two numbers are reported:

    * the number of connected components of the graph that links two channels when
      ``|cosine similarity| > thresh`` (fewer components = more near-duplicate channels);
    * a "soft rank" ``trace(C)**2 / trace(C @ C)``.

    Side effect: the model is switched to eval mode and left there.

    Args:
        model: Module whose ``forward(x, return_hidden=True)`` returns
            ``(output, list_of_activations)``; activations are indexed as
            (batch, channels, ...).
        testloader: Iterable yielding ``(inputs, labels)`` batches; only the first is used.
        thresh (float): Absolute cosine-similarity threshold for linking two channels.
        tol (float): Small constant added to the norms to avoid division by zero.

    Returns:
        list of ``(n_components, soft_rank)`` tuples, one per hidden activation.
    """
    model.eval()
    sample_inputs, _ = next(iter(testloader))

    # Use the device the model actually lives on rather than a notebook-level `device`
    # global, so the function works regardless of cell execution order.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        sample_inputs = sample_inputs.to(first_param.device)

    with torch.no_grad():
        _, hidden_activations = model(sample_inputs, return_hidden=True)

    results = []
    for act in hidden_activations:
        # Reshape: (batch, channels, H, W) --> (channels, batch * H * W)
        A = act.detach().cpu().transpose(0, 1).flatten(1)
        # Normalize each row (avoid division by zero with a small epsilon)
        A = A / (A.norm(dim=1, keepdim=True) + tol)
        # Cosine similarity between every pair of channels.
        C = A @ A.t()
        # Must be computed before the diagonal is zeroed below.
        soft_rank = (torch.trace(C) ** 2 / torch.trace(C @ C)).item()

        # Remove self-similarity, then link channels whose |similarity| exceeds thresh.
        C.fill_diagonal_(0)
        Adj_np = (C.abs() > thresh).float().numpy()

        # SciPy accepts the dense adjacency matrix directly; labels are not needed.
        n_components, _labels = connected_components(csgraph=Adj_np, directed=False)
        print(f'# CC  = {n_components} / {Adj_np.shape[0]}, soft rank = {soft_rank:.3f} / {Adj_np.shape[0]}')
        results.append((int(n_components), soft_rank))

    return results




# -------------------------
# Model Configuration and Instantiation
# -------------------------
# Architecture hyper-parameters (example configuration).
input_size = 128       # images are fed to the network as input_size x input_size
use_batchnorm = True
dropout_p = 0.25
fc_hidden_units = 128
conv_channels = [128 for _ in range(7)]  # seven conv blocks with 128 filters each

# Prefer the first CUDA device when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build the network (3 input channels for colored images) and move it to the chosen device.
net = ConfigurableCNN(
    conv_channels,
    fc_hidden_units,
    dropout_p,
    num_classes=10,
    input_size=input_size,
    input_channels=3,
    use_batchnorm=use_batchnorm,
)
net = net.to(device)

# -------------------------
# Data Preparation (Colored CIFAR-10 resized to input_size x input_size)
# -------------------------
# Preprocessing shared by the train and test splits: resize every image to the spatial
# size the network was built for, convert it to a tensor, then normalize each of the
# three color channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),  # Resize images to input_size x input_size
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize for 3 channels
])

# Training split: reshuffled every epoch.
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)

# Test split: fixed order, so the first batch (used for feature evaluation) is always the same.
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=256,
                                         shuffle=False, num_workers=2)

# Classification loss and optimizer over all parameters of `net`.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)

# -------------------------
# Training Loop
# -------------------------
num_epochs = 100  # Adjust the number of epochs as needed
for epoch in range(num_epochs):
    # The feature evaluation and validation below leave the model in eval mode,
    # so training mode has to be restored at the start of every epoch.
    net.train()
    running_loss = 0.0
    for inputs, labels in tqdm.tqdm(trainloader, total=len(trainloader)):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()  # Zero the parameter gradients
        outputs = net(inputs)  # Forward pass (default: do not collect hidden activations)
        loss = criterion(outputs, labels)  # Compute loss
        loss.backward()  # Backpropagation
        optimizer.step()  # Update parameters

        running_loss += loss.item()

    # Mean per-batch training loss for this epoch.
    print(f'Epoch {epoch + 1}, Train Loss: {running_loss / len(trainloader):.3f}')

    # Report channel-redundancy metrics of the hidden layers on the first test batch.
    eval_features(net, testloader, thresh=0.9)

    # -------------------------
    # Validation after each epoch
    # -------------------------
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            # Predicted class = index of the largest logit (no need for the legacy `.data`).
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print(f'Accuracy on test set after epoch {epoch + 1}: {accuracy:.2f}%')

print("Training complete!")

# -------------------------
# Example: Obtaining Hidden Activations
# -------------------------
# To obtain the hidden activations for a batch of inputs:
# Switch to inference behaviour before probing the network.
net.eval()

# Take the images of the first test batch (labels are not needed) and move them to the device.
first_batch = next(iter(testloader))
sample_inputs = first_batch[0].to(device)

# A single forward pass returns both the logits and the list of collected hidden activations.
with torch.no_grad():
    output, hidden_activations = net(sample_inputs, return_hidden=True)

print("Collected {} hidden activations.".format(len(hidden_activations)))


Files already downloaded and verified
Files already downloaded and verified


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:32<00:00, 11.88it/s]

Epoch 1, Train Loss: 1.704





# CC  = 11 / 128, soft rank = 1.558 / 128
# CC  = 58 / 128, soft rank = 3.639 / 128
# CC  = 99 / 128, soft rank = 6.255 / 128
# CC  = 121 / 128, soft rank = 9.526 / 128
# CC  = 119 / 128, soft rank = 9.317 / 128
# CC  = 118 / 128, soft rank = 6.739 / 128
# CC  = 95 / 128, soft rank = 5.652 / 128
Accuracy on test set after epoch 1: 48.88%


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 391/391 [00:32<00:00, 11.98it/s]

Epoch 2, Train Loss: 1.167





# CC  = 14 / 128, soft rank = 2.290 / 128
# CC  = 71 / 128, soft rank = 5.417 / 128
# CC  = 110 / 128, soft rank = 8.334 / 128
# CC  = 126 / 128, soft rank = 11.632 / 128
# CC  = 128 / 128, soft rank = 11.152 / 128
# CC  = 128 / 128, soft rank = 5.240 / 128
# CC  = 113 / 128, soft rank = 5.048 / 128
Accuracy on test set after epoch 2: 65.66%


 91%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎             | 355/391 [00:29<00:03, 11.98it/s]

In [7]:
# Show the repr of the training dataset as this cell's output.
# NOTE(review): the output saved with this cell is an AttributeError about `.iloc`, which
# this one-name expression cannot raise — it appears to be left over from an earlier
# version of the cell. Re-run the cell to refresh the stored output.
trainset

AttributeError: 'CIFAR10' object has no attribute 'iloc'

In [51]:
# NOTE(review): `C` is never assigned at notebook scope in the visible cells (it only
# exists as a local variable inside eval_features), so this cell depends on hidden kernel
# state and will not survive Restart & Run All — confirm where `C` comes from.
# Scratch experiments kept for reference. The output saved with this cell is a
# linalg.eigh convergence failure, presumably produced when the eigvalsh line below
# was active — verify before relying on it.
# torch.linalg.eigvalsh(C)
# C = C / torch.diag(C).mean()
C

_LinAlgError: linalg.eigh: The algorithm failed to converge because the input matrix is ill-conditioned or has too many repeated eigenvalues (error code: 510).

In [24]:
import torch
import torch.nn as nn

# Build a toy tensor of shape (N, C, H, W) = (2, 3, 4, 4) from simple ramps.
# Within the first sample, read row by row:
#   channel 0 counts up    1, 2, ..., 16
#   channel 1 counts down 16, 15, ..., 1
#   channel 2 holds the odd numbers 1, 3, ..., 31
# The second sample is the first sample with every value doubled.
ramp = torch.arange(1, 17, dtype=torch.float32)
first_sample = torch.stack([ramp, ramp.flip(0), 2 * ramp - 1]).view(3, 4, 4)
toy_input = torch.stack([first_sample, 2 * first_sample])
print("Toy input shape:", toy_input.shape)

# BatchNorm2d over the 3 channels. With affine=False there is no learnable scale (gamma)
# or shift (beta), so the output is the raw normalization (x - mean) / sqrt(var + eps).
bn = nn.BatchNorm2d(3, affine=False)

# Training mode makes the layer normalize with the statistics of the current batch.
bn.train()

# Normalize the toy tensor.
toy_output = bn(toy_input)

# Show the small constant added to the variance for numerical stability.
print("\nBatchNorm2d epsilon (eps):", bn.eps)

# Helper function to compute per-channel statistics over the batch and spatial dimensions.
def compute_channel_stats(x):
    """Return the mean and population variance of every channel of ``x``.

    Statistics for channel ``c`` are taken over all elements of ``x[:, c]``, i.e. over
    the batch dimension and every trailing dimension. This matches what BatchNorm
    normalizes with in training mode for (N, C, H, W) inputs, and also works for other
    channel-second layouts such as (N, C) or (N, C, L).

    Args:
        x (Tensor): Floating-point tensor with at least two dimensions, laid out as (N, C, ...).

    Returns:
        dict: ``{channel_index: (mean, variance)}`` with plain Python floats, in channel order.

    Raises:
        ValueError: If ``x`` has fewer than two dimensions (no channel axis).
    """
    if x.dim() < 2:
        raise ValueError(
            f"expected a tensor shaped (N, C, ...) with at least 2 dimensions, got {x.dim()}"
        )

    stats = {}
    for c in range(x.shape[1]):
        # All values of this channel across the batch and any remaining dimensions.
        channel_data = x[:, c]
        mean = channel_data.mean().item()
        var = channel_data.var(unbiased=False).item()  # population variance
        stats[c] = (mean, var)
    return stats

# Per-channel statistics of the toy tensor before and after normalization.
input_stats = compute_channel_stats(toy_input)
output_stats = compute_channel_stats(toy_output)

# Print both tables with one loop: (heading, stats) pairs in display order.
report = (
    ("\nInput statistics per channel:", input_stats),
    ("\nOutput statistics per channel after BatchNorm2d:", output_stats),
)
for heading, stats in report:
    print(heading)
    for c in range(3):
        mean, var = stats[c]
        print(f" Channel {c}: mean = {mean:.4f}, var = {var:.4f}")

# What BatchNorm2d does here, channel by channel: subtract the channel mean (taken over
# N, H and W) and divide by the channel standard deviation:
#    normalized = (x - mean) / sqrt(variance + eps)
# The second table confirms it: after normalization every channel has (approximately)
# zero mean and unit variance.


Toy input shape: torch.Size([2, 3, 4, 4])

BatchNorm2d epsilon (eps): 1e-05

Input statistics per channel:
 Channel 0: mean = 12.7500, var = 71.1875
 Channel 1: mean = 12.7500, var = 71.1875
 Channel 2: mean = 24.0000, var = 276.5000

Output statistics per channel after BatchNorm2d:
 Channel 0: mean = -0.0000, var = 1.0000
 Channel 1: mean = 0.0000, var = 1.0000
 Channel 2: mean = 0.0000, var = 1.0000


In [25]:
# Compute and display per-channel statistics of `A` with compute_channel_stats.
# NOTE(review): `A` is never assigned at notebook scope in the visible cells, so this cell
# depends on hidden kernel state and will fail after Restart & Run All. It is presumably
# an (N, C, H, W) activation tensor captured during an interactive session — TODO confirm
# and define it explicitly.
input_stats = compute_channel_stats(A)
print("\nInput statistics per channel:")
# Only the first three channels are printed, however many channels `A` has.
for c in range(3):
    mean, var = input_stats[c]
    print(f" Channel {c}: mean = {mean:.4f}, var = {var:.4f}")


Input statistics per channel:
 Channel 0: mean = -0.1977, var = 1.1304
 Channel 1: mean = -0.2741, var = 0.9774
 Channel 2: mean = -0.2215, var = 0.8813
