# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

def num_connected_components(A, tol=1e-8, thresh=0.98):
    """Count connected components of the absolute-correlation graph of A's rows.

    Every row of ``A`` is a node. Rows are mean-centred and L2-normalised, so
    ``A_norm @ A_norm.T`` is the Pearson correlation matrix between rows. Two
    distinct rows share an edge when ``|corr| > thresh``. The number of
    connected components equals the multiplicity of the zero eigenvalue of the
    graph Laplacian ``L = D - Adj``.

    Args:
        A: 2-D tensor; rows are nodes, columns are the samples used to
            estimate correlations.
        tol: Plays two roles. It is added to each row norm, so constant or
            all-zero rows normalise to zero and become isolated nodes instead
            of producing NaNs. It is also the cutoff below which a Laplacian
            eigenvalue counts as zero.
        thresh: Absolute-correlation threshold for adding an edge.

    Returns:
        Number of connected components (int).
    """
    # Rebinding (not in-place), so the caller's tensor is left untouched.
    centered = A - A.mean(dim=1, keepdim=True)
    A_norm = centered / (centered.norm(dim=1, keepdim=True) + tol)
    corr = (A_norm @ A_norm.T).abs()
    corr.fill_diagonal_(0)  # no self-loops
    # Edge test honours the caller's threshold (was hard-coded to 0.98).
    adj = (corr > thresh).to(torch.float64)
    # Laplacian entries are small integers, so float64 represents them
    # exactly. eigvalsh then resolves zero eigenvalues to ~1e-12, well below
    # `tol`; float32 noise (~1e-4 for ~1000 nodes) could miscount components.
    laplacian = torch.diag(adj.sum(dim=1)) - adj
    eigenvalues = torch.linalg.eigvalsh(laplacian)
    return int((eigenvalues < tol).sum().item())





def compute_effective_rank(activation_matrix, eps=1e-12):
    """
    Compute the effective rank of a matrix: exp of the Shannon entropy of its
    normalised singular-value distribution.

    activation_matrix: 2-D tensor, e.g. (batch_size, feature_dim).
    eps: Small constant added to the singular-value sum (avoids division by
         zero) and used as a floor inside log (avoids log(0)).

    Returns: effective rank (float). An all-zero matrix yields 1.0, because
             every normalised singular value is 0 and the entropy is 0.
    """
    # Double precision for stability. Only the singular values are needed,
    # so svdvals avoids computing (and discarding) U and V^H.
    singular_values = torch.linalg.svdvals(activation_matrix.double())
    p = singular_values / (singular_values.sum() + eps)  # normalised spectrum
    # Clamp only inside the log: zero-probability terms contribute 0 * log(eps) = 0.
    entropy = -(p * torch.log(p.clamp(min=eps))).sum()
    return torch.exp(entropy).item()

# ---------------------
# Define a Wide, Deep MLP with Hooks to Record Activations
# ---------------------
class WideDeepMLP(nn.Module):
    """Bias-free tanh MLP that records the output of every hidden activation.

    Layout: ``Linear(input_dim, hidden_dim) -> Tanh``, then ``num_layers - 1``
    blocks of ``Linear(hidden_dim, hidden_dim) -> Tanh``, then a final
    ``Linear(hidden_dim, num_classes)`` classifier. All Linear layers have no
    bias.

    A forward hook on each activation module stores its detached output in
    ``self.activations["layer_{i}"]`` (i counts activation modules from 0).
    The classifier output is not recorded. Every forward pass overwrites the
    previous entries.

    Weights are drawn from N(0, 1 / (0.39 * fan_in)). NOTE(review): 0.39
    presumably approximates E[tanh(z)^2] for z ~ N(0, 1) (~0.394), which makes
    this a variance-preserving init for tanh. Confirm.
    """

    def __init__(self, input_dim=3*32*32, hidden_dim=1000, num_layers=10, num_classes=10):
        super().__init__()
        self.layers = nn.ModuleList()
        # First hidden block: Linear + Tanh (always present).
        self.layers.append(nn.Linear(input_dim, hidden_dim, bias=False))
        self.layers.append(nn.Tanh())
        # Remaining hidden blocks: Linear + Tanh.
        for _ in range(num_layers - 1):
            self.layers.append(nn.Linear(hidden_dim, hidden_dim, bias=False))
            self.layers.append(nn.Tanh())
        # Final classifier layer (not hooked, so not in self.activations).
        self.layers.append(nn.Linear(hidden_dim, num_classes, bias=False))

        # Post-activation outputs, keyed "layer_0", "layer_1", ...
        self.activations = {}
        # Hook every non-Linear module (the Tanh activations) to record its output.
        layer_idx = 0
        for module in self.layers:
            if not isinstance(module, nn.Linear):
                module.register_forward_hook(self._get_activation_hook(layer_idx))
                layer_idx += 1

        # nn.Linear stores weight as (out_features, in_features), so fan-in is
        # shape[1]. The previous shape[0] (fan-out) only coincided with
        # fan-in for the square hidden layers.
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                fan_in = layer.weight.shape[1]
                nn.init.normal_(layer.weight, mean=0, std=(1.0 / (0.39 * fan_in)) ** 0.5)

    def _get_activation_hook(self, idx):
        """Return a forward hook that saves the module output under ``layer_{idx}``."""
        def hook(module, input, output):
            # Detach so stored activations hold no autograd graph.
            self.activations[f"layer_{idx}"] = output.detach()
        return hook

    def forward(self, x):
        """Flatten all non-batch dimensions of ``x`` and apply every layer in order."""
        # reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(x.size(0), -1)
        for layer in self.layers:
            x = layer(x)
        return x

# ---------------------
# Data Preparation
# ---------------------
# ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1] via
# (x - 0.5) / 0.5. The one-element mean/std tuples broadcast over all channels.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)
train_set = torchvision.datasets.CIFAR10(
    root='./data', train=True, transform=transform, download=True
)
val_set = torchvision.datasets.CIFAR10(
    root='./data', train=False, transform=transform, download=True
)
train_loader = DataLoader(train_set, batch_size=512, shuffle=True, num_workers=2)
# NOTE(review): drop_last=True discards any trailing batch smaller than 6000.
# The CIFAR-10 test split is expected to hold 10,000 images, so this loader
# yields one 6000-image batch and never evaluates the rest. Confirm this is
# intended.
val_loader = DataLoader(
    val_set, batch_size=6000, shuffle=False, num_workers=2, drop_last=True
)

# ---------------------
# Initialize Model, Loss, Optimizer
# ---------------------
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Seven hidden Linear->Tanh blocks; other sizes use WideDeepMLP's defaults.
model = WideDeepMLP(num_layers=7)
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.1)
# Absolute-correlation threshold handed to num_connected_components(thresh=...).
th = 0.999

# ---------------------
# Training Loop
# ---------------------
num_epochs = 10
# Epoch 0 only evaluates the freshly initialised network. Epochs
# 1 .. num_epochs-1 each run one SGD pass over the training set first.
for epoch in range(num_epochs):
    # ---- Training phase (skipped at epoch 0) ----
    if epoch == 0:
        avg_train_loss = 0.0
    else:
        model.train()
        train_loss_total = 0.0
        num_train_samples = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # CrossEntropyLoss() averages over the batch. Re-weight by batch
            # size so the epoch mean is per sample even if the last batch is smaller.
            batch_size = labels.size(0)
            train_loss_total += loss.item() * batch_size
            num_train_samples += batch_size
        if num_train_samples == 0:
            raise RuntimeError("train_loader yielded no samples")
        avg_train_loss = train_loss_total / num_train_samples

    # ---- Validation phase: per-sample mean loss over every validation batch ----
    model.eval()
    val_loss_total = 0.0
    num_val_samples = 0
    val_batch_activations = None
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            batch_size = labels.size(0)
            val_loss_total += loss.item() * batch_size
            num_val_samples += batch_size
            # The hooks overwrite model.activations on every forward pass, so
            # snapshot the first batch now. All per-layer metrics below use it.
            if val_batch_activations is None:
                val_batch_activations = dict(model.activations)
    if val_batch_activations is None:
        raise RuntimeError("val_loader yielded no batches (check batch_size / drop_last)")
    avg_val_loss = val_loss_total / num_val_samples

    # ---- Connected components of each layer's neuron-correlation graph ----
    # Each stored activation is (batch, features). Transpose so rows are neurons.
    print(f'epoch {epoch} CC stats using threshold = {th:.3f}')
    for k, A in val_batch_activations.items():
        print(f"layer {k} feature dim = {A.shape[1]} # of connected components: {num_connected_components(A.T, thresh=th)}")

    # ---- Effective rank of each recorded hidden layer (same first batch) ----
    erank_dict = {
        layer_name: compute_effective_rank(act.reshape(act.size(0), -1))
        for layer_name, act in val_batch_activations.items()
    }

    # Report with the same epoch index as the CC line. Epoch 0 has no training loss.
    train_loss_str = "n/a" if epoch == 0 else f"{avg_train_loss:.4f}"
    erank_str = ", ".join(f"{k}: {v:.2f}" for k, v in erank_dict.items())
    print(f"Epoch {epoch}: Train Loss: {train_loss_str}, Val Loss: {avg_val_loss:.4f} | Effective Rank per layer: {erank_str}")
