In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import tqdm

import numpy as np
from scipy.sparse.csgraph import connected_components

# -------------------------
# Configurable CNN with Batch Normalization and Hidden Activations Collection
# -------------------------
class ConfigurableCNN(nn.Module):
    """CNN built from a configurable stack of conv blocks plus a two-layer classifier.

    Every block is: 3x3 convolution (padding 1) -> optional BatchNorm2d ->
    activation -> 2x2 max-pool.  The pooled feature map is flattened and fed
    through ``fc1`` -> activation -> ``fc2`` to produce the logits.
    """

    def __init__(self, conv_channels, fc_hidden_units=512, dropout_p=0.25,
                 num_classes=10, input_size=32, input_channels=3, use_batchnorm=True):
        """
        Args:
            conv_channels (iterable of int): Output channels of each convolutional
                layer.  May be empty, in which case the input is flattened and
                passed straight to the fully connected layers.
            fc_hidden_units (int): Number of neurons in the hidden fully connected layer.
            dropout_p (float): Probability used to build ``self.dropout``.
                NOTE: ``forward`` does not apply ``self.dropout``, so this value
                currently has no effect on the network output.
            num_classes (int): Number of output classes.
            input_size (int): Height/width of the input images (assumed square).
            input_channels (int): Number of channels in the input images.
            use_batchnorm (bool): Whether to use batch normalization after each convolution.

        Raises:
            ValueError: If ``input_size`` shrinks below 1 pixel after one 2x2
                pooling per convolutional layer.
        """
        super(ConfigurableCNN, self).__init__()
        conv_channels = list(conv_channels)  # Accept tuples/ranges/generators too.
        self.use_batchnorm = use_batchnorm
        self.num_pool = len(conv_channels)  # One pooling per conv layer

        # Each 2x2 pooling floors the spatial size in half, so the size that
        # reaches fc1 is input_size // 2**num_pool.  Fail here with a clear
        # message instead of building a zero-width fc1 that breaks in forward().
        final_size = input_size // (2 ** self.num_pool)
        if final_size < 1:
            raise ValueError(
                f"input_size={input_size} is too small for {self.num_pool} "
                f"pooling stages; need at least {2 ** self.num_pool}."
            )

        self.conv_layers = nn.ModuleList()
        if self.use_batchnorm:
            self.bn_layers = nn.ModuleList()

        # Create convolutional layers along with optional batch normalization.
        in_channels = input_channels
        for out_channels in conv_channels:
            self.conv_layers.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
            if self.use_batchnorm:
                self.bn_layers.append(nn.BatchNorm2d(out_channels))
            in_channels = out_channels

        # Max pooling layer (2x2, stride 2) shared by all conv blocks.
        self.pool = nn.MaxPool2d(2, 2)

        # After the loop in_channels is the channel count of the last feature
        # map (or input_channels when there are no conv layers at all).
        self.flattened_size = in_channels * final_size * final_size

        # Fully connected layers.
        self.fc1 = nn.Linear(self.flattened_size, fc_hidden_units)
        self.fc2 = nn.Linear(fc_hidden_units, num_classes)

        # Dropout module (constructed, but not applied in forward()).
        self.dropout = nn.Dropout(dropout_p)
        self.act = F.tanh  # You can change this activation if desired

    def forward(self, x, return_hidden=False):
        """Compute class logits for a batch of images.

        Args:
            x (Tensor): Input batch of shape (batch, input_channels, H, W).
            return_hidden (bool): If True, also return the per-layer feature maps.

        Returns:
            Tensor of logits with shape (batch, num_classes).  When
            ``return_hidden`` is True, a tuple ``(logits, hidden_activations)``
            where ``hidden_activations`` holds one tensor per conv layer, taken
            after batch norm (if enabled) and before the activation and pooling,
            detached from the graph and moved to the CPU.
        """
        hidden_activations = []

        for idx, conv in enumerate(self.conv_layers):
            x = conv(x)
            if self.use_batchnorm:
                x = self.bn_layers[idx](x)
            if return_hidden:
                # Snapshot the pre-activation feature map for later analysis.
                hidden_activations.append(x.detach().cpu())
            x = self.pool(self.act(x))

        # flatten(1) keeps the batch dimension and also copes with
        # non-contiguous inputs, which view() would reject.
        x = x.flatten(1)

        # Hidden fully connected layer with activation, then the logits.
        x = self.act(self.fc1(x))
        x = self.fc2(x)

        if return_hidden:
            return x, hidden_activations
        return x


def eval_features(model, testloader, thresh=0.9, tol=1e-10, rank_atol=1e-2, dead_tol=0.1):
    """Print per-layer redundancy diagnostics of ``model``'s hidden activations.

    The first batch of ``testloader`` is pushed through the model with
    ``return_hidden=True``.  For every returned activation the channels are
    treated as feature vectors (one row per channel), scaled to unit norm, and
    one line is printed containing:

    * ``# CC``: number of connected components of the graph that links two
      channels when the absolute value of their cosine similarity exceeds
      ``thresh``;
    * ``e-rank``: ``torch.linalg.matrix_rank`` of the similarity matrix with
      absolute tolerance ``rank_atol``;
    * ``soft rank``: ``trace(C)**2 / trace(C @ C)`` of the similarity matrix;
    * ``dead features``: channels whose ``std / mean(|x|)`` is below ``dead_tol``.

    Args:
        model: Module whose call accepts ``return_hidden=True`` and returns
            ``(output, list_of_activations)``.
        testloader: Iterable yielding ``(inputs, labels)`` batches; only the
            first batch is used.
        thresh (float): Similarity threshold for drawing an edge between channels.
        tol (float): Small constant guarding the divisions against zero.
        rank_atol (float): Absolute tolerance passed to ``matrix_rank``.
        dead_tol (float): Threshold on ``std / mean(|x|)`` for a dead channel.

    Returns:
        None.  The diagnostics are printed, one line per activation.
    """
    # Run on whatever device the model lives on rather than relying on a
    # module-level ``device`` variable being defined by the caller.
    first_param = next(model.parameters(), None)
    was_training = model.training
    model.eval()
    try:
        sample_inputs, _ = next(iter(testloader))
        if first_param is not None:
            sample_inputs = sample_inputs.to(first_param.device)
        with torch.no_grad():
            _, hidden_activations = model(sample_inputs, return_hidden=True)
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    for act in hidden_activations:
        act = act.detach().cpu()
        # Reshape: (batch, channels, H, W) --> (channels, batch * H * W)
        A = act.transpose(0, 1).flatten(1)
        # Scale each row to unit norm (tol avoids division by zero).
        A = A / (A.norm(dim=1, keepdim=True) + tol)
        # tol in the denominator makes an all-zero channel evaluate to 0
        # (counted as dead) instead of 0/0 = NaN (never counted).
        stds = A.std(dim=1) / (A.abs().mean(dim=1) + tol)
        dead_features = (stds < dead_tol).sum()
        # Cosine similarity matrix between channels.
        C = A @ A.t()
        rank = torch.linalg.matrix_rank(C, atol=rank_atol)
        soft_rank = torch.trace(C)**2 / torch.trace(C @ C)
        # Remove self-similarity by zeroing the diagonal and take absolute value.
        C.fill_diagonal_(0)
        C = C.abs()
        # Adjacency matrix: an edge wherever the similarity exceeds thresh.
        Adj_np = (C > thresh).float().numpy()
        # Count groups of mutually similar channels.
        n_components, labels = connected_components(csgraph=Adj_np, directed=False)
        R = Adj_np.shape[0]
        print(f'# CC  = {n_components}, e-rank = {rank}, soft rank = {soft_rank:.3f}, dead features = {dead_features} / {R}')



# -------------------------
# Data Preparation (Tiny ImageNet with Selected Classes)
# -------------------------
# Specify which classes to use.
# For Tiny ImageNet, the classes are the subfolder names in the training folder.
# Classes are selected by the numeric index ImageFolder assigns after sorting
# the folder names.  Any collection of indices works, e.g. range(5) for the
# first 5 classes or [3, 17, 42] for an arbitrary subset.
selected_classes = range(20)  # Set to None to use all available classes

# Define the image size for resizing
input_size = 128

transform = transforms.Compose([
    transforms.Resize((input_size, input_size)),  # Resize images to input_size x input_size
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

if selected_classes is not None:
    # The classifier has one output per selected class, so the targets must be
    # contiguous 0..K-1.  Map each chosen ImageFolder index to its rank among
    # the chosen indices (the identity when a prefix such as range(K) is used).
    label_map = {old: new for new, old in enumerate(sorted(set(selected_classes)))}
    target_transform = label_map.__getitem__
else:
    label_map = None
    target_transform = None

# Load the Tiny ImageNet datasets using ImageFolder.
# Adjust the root paths to where you have Tiny ImageNet stored.
trainset = torchvision.datasets.ImageFolder(root='./tiny-imagenet-200/train', transform=transform,
                                            target_transform=target_transform)
testset = torchvision.datasets.ImageFolder(root='./tiny-imagenet-200/val', transform=transform,
                                           target_transform=target_transform)

# ImageFolder derives labels from the sub-folder names of each root.  If the
# validation root is not organised into the same class folders as the training
# root, its labels do not refer to the same classes and the reported accuracy
# is meaningless, so refuse to continue.
if testset.class_to_idx != trainset.class_to_idx:
    raise ValueError(
        "Validation folder classes do not match the training folder classes "
        f"({len(testset.class_to_idx)} vs {len(trainset.class_to_idx)} class folders). "
        "Arrange './tiny-imagenet-200/val' as one sub-folder per class, "
        "named exactly like the training class folders."
    )

# If selected_classes is specified, filter the dataset to include only those classes.
# `.samples` holds the original (un-remapped) ImageFolder labels.
if label_map is not None:
    train_indices = [i for i, (_, label) in enumerate(trainset.samples) if label in label_map]
    trainset = torch.utils.data.Subset(trainset, train_indices)
    test_indices = [i for i, (_, label) in enumerate(testset.samples) if label in label_map]
    testset = torch.utils.data.Subset(testset, test_indices)
    num_used_classes = len(label_map)
else:
    num_used_classes = len(trainset.classes)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(testset, batch_size=256,
                                         shuffle=False, num_workers=2)

# -------------------------
# Model Configuration and Instantiation
# -------------------------
# Seven convolutional blocks of 128 channels each, then a 128-unit hidden FC layer.
conv_channels = [128 for _ in range(7)]
fc_hidden_units = 128
dropout_p = 0.25
use_batchnorm = True

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

net = ConfigurableCNN(
    conv_channels,
    fc_hidden_units,
    dropout_p,
    num_classes=num_used_classes,
    input_size=input_size,
    input_channels=3,
    use_batchnorm=use_batchnorm,
)
net = net.to(device)

# Cross-entropy objective optimised with Adam over all network parameters.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.01)

# -------------------------
# Training Loop
# -------------------------
# Each epoch: one optimisation pass over trainloader, a feature-redundancy
# report on one training batch, then an accuracy measurement on testloader.
num_epochs = 100  # Adjust the number of epochs as needed
for epoch in range(num_epochs):
    net.train()
    running_loss = 0.0
    for inputs, labels in tqdm.tqdm(trainloader):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()  # Zero the parameter gradients
        outputs = net(inputs)   # Forward pass
        loss = criterion(outputs, labels)  # Compute loss
        loss.backward()  # Backpropagation
        optimizer.step()  # Update parameters

        running_loss += loss.item()
    # Mean of the per-batch losses over the epoch.
    print(f'Epoch {epoch + 1}, Train Loss: {running_loss / len(trainloader):.3f}')

    # Optionally, evaluate feature connectivity
    eval_features(net, trainloader, thresh=0.95)

    # -------------------------
    # Validation after each epoch
    # -------------------------
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = net(images)
            # Predicted class = index of the largest logit in each row.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # An empty validation set yields NaN instead of a ZeroDivisionError.
    accuracy = 100 * correct / total if total else float('nan')
    print(f'Accuracy on test set after epoch {epoch + 1}: {accuracy:.2f}%')

print("Training complete!")

# -------------------------
# Example: Obtaining Hidden Activations
# -------------------------
# Push the first test batch through the trained network in eval mode and keep
# the per-layer feature maps that the model returns alongside its output.
net.eval()
sample_inputs, _ = next(iter(testloader))
with torch.no_grad():
    sample_inputs = sample_inputs.to(device)
    output, hidden_activations = net(sample_inputs, return_hidden=True)
num_collected = len(hidden_activations)
print("Collected {} hidden activations.".format(num_collected))


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:02<00:00,  9.63it/s]

Epoch 1, Train Loss: 1.505





# CC  = 56, e-rank = 17, soft rank = 2.387, dead features = 0 / 128
# CC  = 26, e-rank = 19, soft rank = 1.816, dead features = 0 / 128
# CC  = 14, e-rank = 19, soft rank = 1.658, dead features = 0 / 128
# CC  = 19, e-rank = 23, soft rank = 2.294, dead features = 0 / 128
# CC  = 52, e-rank = 33, soft rank = 3.200, dead features = 0 / 128
# CC  = 83, e-rank = 35, soft rank = 2.983, dead features = 0 / 128
# CC  = 94, e-rank = 30, soft rank = 3.216, dead features = 0 / 128
Accuracy on test set after epoch 1: 3.90%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.27it/s]

Epoch 2, Train Loss: 1.241





# CC  = 38, e-rank = 16, soft rank = 2.266, dead features = 0 / 128
# CC  = 30, e-rank = 20, soft rank = 1.963, dead features = 0 / 128
# CC  = 6, e-rank = 20, soft rank = 2.000, dead features = 0 / 128
# CC  = 22, e-rank = 23, soft rank = 3.018, dead features = 0 / 128
# CC  = 37, e-rank = 25, soft rank = 2.952, dead features = 0 / 128
# CC  = 82, e-rank = 30, soft rank = 3.049, dead features = 0 / 128
# CC  = 109, e-rank = 29, soft rank = 3.387, dead features = 0 / 128
Accuracy on test set after epoch 2: 13.30%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.25it/s]

Epoch 3, Train Loss: 1.164





# CC  = 40, e-rank = 15, soft rank = 2.070, dead features = 0 / 128
# CC  = 24, e-rank = 19, soft rank = 1.745, dead features = 0 / 128
# CC  = 5, e-rank = 18, soft rank = 1.645, dead features = 0 / 128
# CC  = 14, e-rank = 24, soft rank = 2.452, dead features = 0 / 128
# CC  = 48, e-rank = 27, soft rank = 2.978, dead features = 0 / 128
# CC  = 116, e-rank = 34, soft rank = 4.350, dead features = 0 / 128
# CC  = 125, e-rank = 34, soft rank = 4.332, dead features = 0 / 128
Accuracy on test set after epoch 3: 23.06%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.31it/s]

Epoch 4, Train Loss: 1.141





# CC  = 38, e-rank = 15, soft rank = 2.344, dead features = 0 / 128
# CC  = 29, e-rank = 20, soft rank = 2.109, dead features = 0 / 128
# CC  = 13, e-rank = 19, soft rank = 2.207, dead features = 0 / 128
# CC  = 21, e-rank = 22, soft rank = 3.202, dead features = 0 / 128
# CC  = 37, e-rank = 24, soft rank = 3.335, dead features = 0 / 128
# CC  = 115, e-rank = 32, soft rank = 4.699, dead features = 0 / 128
# CC  = 120, e-rank = 34, soft rank = 3.541, dead features = 0 / 128
Accuracy on test set after epoch 4: 24.79%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.21it/s]

Epoch 5, Train Loss: 1.081





# CC  = 37, e-rank = 15, soft rank = 2.209, dead features = 0 / 128
# CC  = 34, e-rank = 19, soft rank = 2.054, dead features = 0 / 128
# CC  = 23, e-rank = 20, soft rank = 2.081, dead features = 0 / 128
# CC  = 34, e-rank = 23, soft rank = 2.934, dead features = 0 / 128
# CC  = 56, e-rank = 25, soft rank = 3.308, dead features = 0 / 128
# CC  = 115, e-rank = 34, soft rank = 5.319, dead features = 0 / 128
# CC  = 127, e-rank = 37, soft rank = 4.375, dead features = 0 / 128
Accuracy on test set after epoch 5: 41.39%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.22it/s]

Epoch 6, Train Loss: 0.988





# CC  = 32, e-rank = 14, soft rank = 2.229, dead features = 0 / 128
# CC  = 46, e-rank = 20, soft rank = 2.184, dead features = 0 / 128
# CC  = 20, e-rank = 20, soft rank = 2.049, dead features = 0 / 128
# CC  = 23, e-rank = 21, soft rank = 2.628, dead features = 0 / 128
# CC  = 69, e-rank = 26, soft rank = 3.737, dead features = 0 / 128
# CC  = 124, e-rank = 36, soft rank = 5.965, dead features = 0 / 128
# CC  = 128, e-rank = 38, soft rank = 4.613, dead features = 0 / 128
Accuracy on test set after epoch 6: 15.02%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.19it/s]

Epoch 7, Train Loss: 0.925





# CC  = 46, e-rank = 15, soft rank = 2.361, dead features = 0 / 128
# CC  = 48, e-rank = 22, soft rank = 2.514, dead features = 0 / 128
# CC  = 49, e-rank = 22, soft rank = 2.822, dead features = 0 / 128
# CC  = 36, e-rank = 22, soft rank = 3.198, dead features = 0 / 128
# CC  = 79, e-rank = 26, soft rank = 4.148, dead features = 0 / 128
# CC  = 124, e-rank = 37, soft rank = 6.756, dead features = 0 / 128
# CC  = 127, e-rank = 40, soft rank = 4.972, dead features = 0 / 128
Accuracy on test set after epoch 7: 26.06%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.26it/s]

Epoch 8, Train Loss: 0.925





# CC  = 41, e-rank = 15, soft rank = 2.565, dead features = 0 / 128
# CC  = 52, e-rank = 19, soft rank = 2.542, dead features = 0 / 128
# CC  = 51, e-rank = 23, soft rank = 2.796, dead features = 0 / 128
# CC  = 49, e-rank = 24, soft rank = 3.259, dead features = 0 / 128
# CC  = 87, e-rank = 29, soft rank = 4.289, dead features = 0 / 128
# CC  = 125, e-rank = 40, soft rank = 7.289, dead features = 0 / 128
# CC  = 127, e-rank = 47, soft rank = 5.313, dead features = 0 / 128
Accuracy on test set after epoch 8: 27.22%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.26it/s]

Epoch 9, Train Loss: 0.859





# CC  = 41, e-rank = 14, soft rank = 2.630, dead features = 0 / 128
# CC  = 39, e-rank = 20, soft rank = 2.588, dead features = 0 / 128
# CC  = 38, e-rank = 21, soft rank = 2.738, dead features = 0 / 128
# CC  = 43, e-rank = 23, soft rank = 3.148, dead features = 0 / 128
# CC  = 83, e-rank = 27, soft rank = 4.214, dead features = 0 / 128
# CC  = 123, e-rank = 42, soft rank = 7.043, dead features = 0 / 128
# CC  = 128, e-rank = 46, soft rank = 4.952, dead features = 0 / 128
Accuracy on test set after epoch 9: 25.55%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.17it/s]

Epoch 10, Train Loss: 0.857





# CC  = 36, e-rank = 13, soft rank = 2.489, dead features = 0 / 128
# CC  = 41, e-rank = 20, soft rank = 2.595, dead features = 0 / 128
# CC  = 47, e-rank = 22, soft rank = 2.916, dead features = 0 / 128
# CC  = 51, e-rank = 22, soft rank = 3.557, dead features = 0 / 128
# CC  = 95, e-rank = 30, soft rank = 4.798, dead features = 0 / 128
# CC  = 124, e-rank = 44, soft rank = 7.453, dead features = 0 / 128
# CC  = 128, e-rank = 47, soft rank = 5.333, dead features = 0 / 128
Accuracy on test set after epoch 10: 16.56%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.20it/s]

Epoch 11, Train Loss: 0.785





# CC  = 37, e-rank = 13, soft rank = 2.511, dead features = 0 / 128
# CC  = 55, e-rank = 21, soft rank = 2.890, dead features = 0 / 128
# CC  = 69, e-rank = 22, soft rank = 3.222, dead features = 0 / 128
# CC  = 62, e-rank = 23, soft rank = 3.733, dead features = 0 / 128
# CC  = 101, e-rank = 30, soft rank = 4.949, dead features = 0 / 128
# CC  = 126, e-rank = 44, soft rank = 7.632, dead features = 0 / 128
# CC  = 125, e-rank = 47, soft rank = 4.273, dead features = 0 / 128
Accuracy on test set after epoch 11: 15.97%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.25it/s]

Epoch 12, Train Loss: 0.780





# CC  = 42, e-rank = 13, soft rank = 2.751, dead features = 0 / 128
# CC  = 46, e-rank = 20, soft rank = 2.971, dead features = 0 / 128
# CC  = 68, e-rank = 23, soft rank = 3.236, dead features = 0 / 128
# CC  = 72, e-rank = 23, soft rank = 4.037, dead features = 0 / 128
# CC  = 99, e-rank = 31, soft rank = 5.014, dead features = 0 / 128
# CC  = 124, e-rank = 48, soft rank = 7.797, dead features = 0 / 128
# CC  = 126, e-rank = 50, soft rank = 5.106, dead features = 0 / 128
Accuracy on test set after epoch 12: 39.88%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.26it/s]

Epoch 13, Train Loss: 0.800





# CC  = 52, e-rank = 14, soft rank = 2.941, dead features = 0 / 128
# CC  = 61, e-rank = 20, soft rank = 2.791, dead features = 0 / 128
# CC  = 62, e-rank = 23, soft rank = 3.022, dead features = 0 / 128
# CC  = 77, e-rank = 24, soft rank = 4.186, dead features = 0 / 128
# CC  = 101, e-rank = 32, soft rank = 5.371, dead features = 0 / 128
# CC  = 127, e-rank = 49, soft rank = 9.192, dead features = 0 / 128
# CC  = 126, e-rank = 53, soft rank = 4.806, dead features = 0 / 128
Accuracy on test set after epoch 13: 16.32%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.21it/s]

Epoch 14, Train Loss: 0.761





# CC  = 54, e-rank = 14, soft rank = 3.290, dead features = 0 / 128
# CC  = 66, e-rank = 21, soft rank = 3.288, dead features = 0 / 128
# CC  = 78, e-rank = 23, soft rank = 3.514, dead features = 0 / 128
# CC  = 69, e-rank = 24, soft rank = 4.201, dead features = 0 / 128
# CC  = 94, e-rank = 31, soft rank = 4.987, dead features = 0 / 128
# CC  = 125, e-rank = 49, soft rank = 8.330, dead features = 0 / 128
# CC  = 128, e-rank = 56, soft rank = 5.798, dead features = 0 / 128
Accuracy on test set after epoch 14: 29.35%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.22it/s]

Epoch 15, Train Loss: 0.651





# CC  = 48, e-rank = 13, soft rank = 3.206, dead features = 0 / 128
# CC  = 60, e-rank = 21, soft rank = 3.148, dead features = 0 / 128
# CC  = 72, e-rank = 23, soft rank = 3.503, dead features = 0 / 128
# CC  = 80, e-rank = 25, soft rank = 4.719, dead features = 0 / 128
# CC  = 104, e-rank = 34, soft rank = 5.887, dead features = 0 / 128
# CC  = 127, e-rank = 51, soft rank = 9.174, dead features = 0 / 128
# CC  = 128, e-rank = 56, soft rank = 5.352, dead features = 0 / 128
Accuracy on test set after epoch 15: 9.63%


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:01<00:00, 11.23it/s]

Epoch 16, Train Loss: 0.602





# CC  = 51, e-rank = 13, soft rank = 3.329, dead features = 0 / 128
# CC  = 63, e-rank = 21, soft rank = 3.111, dead features = 0 / 128
# CC  = 74, e-rank = 24, soft rank = 3.217, dead features = 0 / 128
# CC  = 80, e-rank = 26, soft rank = 4.214, dead features = 0 / 128
# CC  = 105, e-rank = 34, soft rank = 5.914, dead features = 0 / 128
# CC  = 127, e-rank = 54, soft rank = 8.882, dead features = 0 / 128
# CC  = 128, e-rank = 61, soft rank = 5.875, dead features = 0 / 128


<function torch.nn.functional.relu(input: torch.Tensor, inplace: bool = False) -> torch.Tensor>