In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# -------------------------------
# Utility Functions
# -------------------------------
def num_connected_components(A, tol=1e-8, thresh=0.98):
    """Count connected components of the thresholded |correlation| graph over rows of A.

    Every row of ``A`` is a node. Two distinct rows are joined by an edge when the
    absolute Pearson correlation between them is greater than ``thresh``. A row with
    zero variance centres to all zeros, so it has no edges and forms its own component.

    Components are found by exact label propagation on the boolean adjacency matrix
    rather than by counting near-zero Laplacian eigenvalues, whose floating-point
    round-off (especially in float32) easily exceeds any fixed small tolerance.

    Args:
        A: 2-D tensor; each row is one node of the graph.
        tol: small constant added to the row norms to avoid division by zero.
        thresh: absolute-correlation cut-off for drawing an edge.

    Returns:
        int: number of connected components (0 for a matrix with no rows).
    """
    n = A.shape[0]
    if n == 0:
        return 0

    A = A - A.mean(dim=1, keepdim=True)
    A_norm = A / (A.norm(dim=1, keepdim=True) + tol)
    corr = (A_norm @ A_norm.T).abs()
    corr.fill_diagonal_(0)  # no self-loops
    adj = corr > thresh

    # Min-label propagation: every node starts with its own index as label and
    # repeatedly takes the smallest label among itself and its neighbours.
    labels = torch.arange(n, device=adj.device)
    no_edge = torch.full_like(labels, n)  # sentinel larger than any real label
    while True:
        neighbour_min = torch.where(adj, labels.expand(n, n), no_edge).min(dim=1).values
        new_labels = torch.minimum(labels, neighbour_min)
        # Pointer jumping: labels are node indices inside the same component, so
        # following them once more only shortens paths and speeds convergence.
        new_labels = new_labels[new_labels]
        if torch.equal(new_labels, labels):
            # At the fixed point the label is constant on each component and
            # distinct between components.
            return int(torch.unique(labels).numel())
        labels = new_labels

def compute_effective_rank(activation_matrix, eps=1e-12):
    """Return the entropy-based effective rank of a 2-D matrix.

    The singular values ``S`` are normalised into a distribution ``p = S / sum(S)``
    and the effective rank is ``exp(H(p))``, where ``H`` is the Shannon entropy with
    natural log. The result lies between 1 (a single non-zero singular value, or an
    all-zero matrix) and ``min(activation_matrix.shape)`` (all singular values equal).

    Args:
        activation_matrix: 2-D tensor; converted to float64 before the decomposition.
        eps: guards the normalisation and the logarithm against zeros.

    Returns:
        float: the effective rank.
    """
    # Only the singular values are needed; svdvals skips computing U and V.
    singular_values = torch.linalg.svdvals(activation_matrix.double())
    p = singular_values / (singular_values.sum() + eps)
    # Clamp only inside the log so zero singular values contribute exactly 0.
    entropy = -(p * torch.log(p.clamp(min=eps))).sum()
    return torch.exp(entropy).item()

# -------------------------------
# Define a ResNet with Hooks to Record Activations
# -------------------------------
# We use torchvision.models.resnet18 as our candidate ResNet.
# We insert forward hooks into chosen layers (e.g., after layer1, layer2, layer3, and layer4).

from torchvision.models import resnet18

class ResNetWithHooks(nn.Module):
    """torchvision ResNet-18 wrapper that records the outputs of its four residual stages.

    Forward hooks on ``layer1``..``layer4`` write into ``self.activations``. After each
    forward pass it maps each stage name to the detached output of that stage from
    the most recent call; earlier entries are overwritten.
    """

    # Sub-modules of the wrapped resnet whose outputs are recorded.
    _HOOKED_LAYERS = ('layer1', 'layer2', 'layer3', 'layer4')

    def __init__(self, num_classes=10):
        super().__init__()
        # Randomly initialised (not pretrained) resnet18 with a custom classifier size.
        self.resnet = resnet18(pretrained=False, num_classes=num_classes)
        # stage name -> detached output tensor from the latest forward pass
        self.activations = {}
        for layer_name in self._HOOKED_LAYERS:
            stage = getattr(self.resnet, layer_name)
            stage.register_forward_hook(self._get_activation_hook(layer_name))

    def _get_activation_hook(self, name):
        """Build a forward hook that stores the detached module output under ``name``."""
        def _store_output(module, inputs, output):
            self.activations[name] = output.detach()
        return _store_output

    def forward(self, x):
        # The hooks fire as a side effect of running the wrapped network.
        return self.resnet(x)

# -------------------------------
# Data Preparation
# -------------------------------
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
val_set   = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = DataLoader(train_set, batch_size=512, shuffle=True, num_workers=2)
# Keep the final partial batch so the validation loss covers every validation image.
val_loader   = DataLoader(val_set, batch_size=1024, shuffle=False, num_workers=2)

# -------------------------------
# Initialize Model, Loss, Optimizer
# -------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResNetWithHooks(num_classes=10).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1,
                     momentum=0.9, weight_decay=5e-4)
th = 0.90  # threshold for counting connected components

# -------------------------------
# Training Loop
# -------------------------------
num_epochs = 40
# Epoch 0 only evaluates the freshly initialised model, so it has no training loss.
avg_train_loss = 0.0
for epoch in range(num_epochs):
    # Training phase (skipped at epoch 0 so that epoch 0 reports the untrained model)
    if epoch > 0:
        model.train()
        train_loss_total = 0.0
        num_train_batches = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss_total += loss.item()
            num_train_batches += 1
        avg_train_loss = train_loss_total / num_train_batches

    # Validation phase: per-sample average loss over the whole validation set, plus a
    # snapshot of the hooked activations from the first validation batch.
    model.eval()
    val_loss_total = 0.0
    num_val_samples = 0
    val_batch_activations = None
    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            # criterion uses the default 'mean' reduction; re-weight by batch size so a
            # smaller final batch is not over-counted.
            val_loss_total += criterion(outputs, labels).item() * batch_size
            num_val_samples += batch_size
            if val_batch_activations is None:
                # Copy the dict: the hooks overwrite model.activations on every forward pass.
                val_batch_activations = dict(model.activations)
    avg_val_loss = val_loss_total / num_val_samples

    print(f'\nEpoch {epoch} Connected Components stats (threshold = {th:.3f}):')
    # For each hooked layer, build a graph whose nodes are feature channels; every
    # (sample, spatial position) pair of the first validation batch is one observation.
    for name, A in val_batch_activations.items():
        if A.dim() > 2:
            features = A.transpose(0, 1).flatten(1)  # [B, C, H, W] -> [C, B*H*W]
        else:
            features = A.T  # [B, C] -> [C, B]
        num_cc = num_connected_components(features, thresh=th)
        print(f"Layer {name} feature dim = {A.shape[1]}  # connected components: {num_cc}")

    # Effective rank for each recorded layer, using the same first validation batch.
    erank_dict = {}
    for name, act in val_batch_activations.items():
        # Average over spatial positions so each sample contributes one row of C features.
        act_flat = act.flatten(2).mean(dim=2) if act.dim() > 2 else act
        erank_dict[name] = compute_effective_rank(act_flat)
    erank_str = ", ".join(f"{name}: {val:.2f}" for name, val in erank_dict.items())
    print(f"Epoch {epoch}: Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f} | Effective Rank per layer: {erank_str}")


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time
import copy

from collections import defaultdict

def collect_activations(model, dataloader, device, num_batches=None):
    """
    Run ``model`` without gradients over ``dataloader`` and snapshot its recorded activations.

    Args:
        model: module exposing ``get_all_activations()`` (e.g. the ``ResNet`` class in
            this notebook); it is called after every forward pass.
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: device the inputs are moved to before the forward pass.
        num_batches: maximum number of batches to process (None for all).

    Returns:
        list with one entry per processed batch: the mapping returned by
        ``model.get_all_activations()`` plus a ``'labels'`` entry holding that batch's
        labels on the CPU.

    The model is put in eval mode for the forward passes, and its previous
    train/eval mode is restored before returning.
    """
    from itertools import islice

    was_training = model.training
    model.eval()
    all_batch_activations = []
    # islice stops before fetching batches that would be discarded anyway.
    batches = dataloader if num_batches is None else islice(dataloader, max(num_batches, 0))

    try:
        with torch.no_grad():
            for inputs, labels in batches:
                # The outputs are not needed: the model records activations as a side effect.
                model(inputs.to(device))
                batch_activations = model.get_all_activations()
                # Labels are only stored, so they never need to visit the device.
                batch_activations['labels'] = labels.cpu()
                all_batch_activations.append(batch_activations)
    finally:
        model.train(was_training)

    return all_batch_activations

def num_connected_components(A, tol=1e-8, thresh=0.98):
    """Count connected components of the thresholded |correlation| graph over rows of A.

    Every row of ``A`` is a node. Two distinct rows are joined by an edge when the
    absolute Pearson correlation between them is greater than ``thresh``. A row with
    zero variance centres to all zeros, so it has no edges and forms its own component.

    Components are found by exact label propagation on the boolean adjacency matrix
    rather than by counting near-zero Laplacian eigenvalues, whose floating-point
    round-off (especially in float32) easily exceeds any fixed small tolerance.

    Args:
        A: 2-D tensor; each row is one node of the graph.
        tol: small constant added to the row norms to avoid division by zero.
        thresh: absolute-correlation cut-off for drawing an edge.

    Returns:
        int: number of connected components (0 for a matrix with no rows).
    """
    n = A.shape[0]
    if n == 0:
        return 0

    A = A - A.mean(dim=1, keepdim=True)
    A_norm = A / (A.norm(dim=1, keepdim=True) + tol)
    corr = (A_norm @ A_norm.T).abs()
    corr.fill_diagonal_(0)  # no self-loops
    adj = corr > thresh

    # Min-label propagation: every node starts with its own index as label and
    # repeatedly takes the smallest label among itself and its neighbours.
    labels = torch.arange(n, device=adj.device)
    no_edge = torch.full_like(labels, n)  # sentinel larger than any real label
    while True:
        neighbour_min = torch.where(adj, labels.expand(n, n), no_edge).min(dim=1).values
        new_labels = torch.minimum(labels, neighbour_min)
        # Pointer jumping: labels are node indices inside the same component, so
        # following them once more only shortens paths and speeds convergence.
        new_labels = new_labels[new_labels]
        if torch.equal(new_labels, labels):
            # At the fixed point the label is constant on each component and
            # distinct between components.
            return int(torch.unique(labels).numel())
        labels = new_labels

def report_CC_stats(model, loader, thresh, num_batches=10):
    """Print connected-component counts for each top-level ('main') activation of ``model``.

    Activations from the first ``num_batches`` batches of ``loader`` are concatenated
    along the batch axis and flattened to one row per sample. Each sample is then a
    node of the thresholded |correlation| graph counted by ``num_connected_components``.

    Args:
        model: model whose ``get_all_activations()`` result has a ``'main'`` entry.
        loader: iterable of ``(inputs, labels)`` batches.
        thresh: absolute-correlation threshold for an edge.
        num_batches: number of batches to collect.
    """
    # Feed inputs to wherever the model's parameters live instead of a global device.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')
    activations = collect_activations(model, loader, model_device, num_batches=num_batches)
    if not activations:
        print('report_CC_stats: no batches collected, nothing to report')
        return

    for k in activations[0]['main']:
        A_flat = torch.cat([act['main'][k].detach().cpu() for act in activations]).flatten(1)
        try:
            CC = num_connected_components(A_flat, thresh=thresh)
        except (RuntimeError, MemoryError) as err:
            # e.g. out of memory on the samples x samples correlation matrix:
            # report it and move on to the next activation instead of hiding it.
            print(f'key = {k} connected-component computation failed: {err}')
            continue
        # min(shape) is only an upper bound on the rank, not the rank itself.
        print('key = ', k, ' CC = ', CC, ' max possible rank(A) = ', min(A_flat.shape))


# Set device
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Per-channel normalisation constants shared by the training and validation pipelines.
# NOTE(review): these std values are the widely copied ones; the per-pixel std of the
# CIFAR-10 training set is closer to (0.247, 0.243, 0.261) - confirm which is intended.
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Training: random-crop and horizontal-flip augmentation, then normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# Validation: normalisation only.
transform_val = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

# CIFAR-10 splits (downloaded to ./data on first use) and their loaders.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

valset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_val)
valloader = DataLoader(valset, batch_size=128, shuffle=False, num_workers=2)

classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

# Define ResNet model
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1.

    Bias is omitted because every use in this file is followed by a BatchNorm2d layer.
    """
    return nn.Conv2d(
        in_channels=in_planes,
        out_channels=out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False,
    )

class BasicBlock(nn.Module):
    """Two-convolution residual block that records its post-ReLU activations.

    After every forward pass ``self.activations`` holds detached views of the first
    ReLU output (``'relu1'``) and of the block output (``'relu2'``) from that call.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super().__init__()
        # Sub-modules are created in the same order as before, so state_dict keys and
        # parameter initialisation order stay the same.
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu1 = nn.ReLU(inplace=False)  # deliberately non-inplace
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu2 = nn.ReLU(inplace=False)  # deliberately non-inplace
        self.downsample = downsample  # optional projection applied to the shortcut
        self.stride = stride

        # activation name -> detached tensor from the latest forward pass
        self.activations = {}

    def forward(self, x):
        hidden = self.relu1(self.bn1(self.conv1(x)))
        self.activations['relu1'] = hidden.detach()

        residual = self.bn2(self.conv2(hidden))
        shortcut = x if self.downsample is None else self.downsample(x)

        merged = self.relu2(residual + shortcut)
        self.activations['relu2'] = merged.detach()
        return merged

class ResNet(nn.Module):
    """ResNet with a 3x3, stride-1 stem (no max-pool) that records intermediate activations.

    Args:
        block: residual block class. It must expose ``expansion``, accept
            ``(inplanes, planes, stride, downsample)``, and its instances must carry an
            ``activations`` dict (read by ``get_all_activations``).
        layers: number of blocks in each of the four stages.
        num_classes: output size of the final linear layer.
        widths: number of planes in each of the four stages. The default (512 for every
            stage) reproduces this notebook's original configuration.
            NOTE(review): the standard ResNet-18 uses (64, 128, 256, 512); confirm the
            uniform 512 width is intended.

    Raises:
        ValueError: if ``widths`` does not have exactly four entries.
    """

    def __init__(self, block, layers, num_classes=10, widths=(512, 512, 512, 512)):
        super(ResNet, self).__init__()
        if len(widths) != 4:
            raise ValueError(f'widths must have one entry per stage (4), got {len(widths)}')
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=False)  # Changed to not use inplace

        # Stage 1 runs at stride 1; stages 2-4 are built with stride 2.
        self.layer1 = self._make_layer(block, widths[0], layers[0])
        self.layer2 = self._make_layer(block, widths[1], layers[1], stride=2)
        self.layer3 = self._make_layer(block, widths[2], layers[2], stride=2)
        self.layer4 = self._make_layer(block, widths[3], layers[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(widths[3] * block.expansion, num_classes)

        # activation name -> detached tensor from the latest forward pass
        self.activations = {}

    def _make_layer(self, block, planes, blocks, stride=1):
        """Build one stage of ``blocks`` residual blocks and advance ``self.inplanes``.

        Only the first block may change stride/width; it gets a 1x1 conv + BatchNorm
        projection on its shortcut whenever the shape changes.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the logits and record detached stem, stage, pool and logit outputs."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        self.activations['initial_relu'] = x.detach()

        x = self.layer1(x)
        self.activations['layer1'] = x.detach()

        x = self.layer2(x)
        self.activations['layer2'] = x.detach()

        x = self.layer3(x)
        self.activations['layer3'] = x.detach()

        x = self.layer4(x)
        self.activations['layer4'] = x.detach()

        x = self.avgpool(x)
        self.activations['avgpool'] = x.detach()

        x = torch.flatten(x, 1)
        logits = self.fc(x)
        self.activations['logits'] = logits.detach()

        return logits

    def get_all_activations(self):
        """Collect all activations from the model, including those from the residual blocks.

        Returns:
            defaultdict(dict): ``'main'`` maps to this module's own activations, and
            ``'layer{i}_block{j}'`` (1-based) maps to each block's ``activations``
            dict, for blocks that have recorded anything.
        """
        all_activations = defaultdict(dict)

        for name, activation in self.activations.items():
            all_activations['main'][name] = activation

        for layer_idx, layer in enumerate([self.layer1, self.layer2, self.layer3, self.layer4]):
            for block_idx, block in enumerate(layer):
                for act_name, activation in block.activations.items():
                    all_activations[f'layer{layer_idx+1}_block{block_idx+1}'][act_name] = activation

        return all_activations

def train_model(model, criterion, optimizer, num_epochs=25):
    """Train ``model`` on the global ``trainloader`` and validate on ``valloader`` each epoch.

    After every validation phase, connected-component statistics are printed with
    ``report_CC_stats``. The weights with the best validation accuracy are loaded back
    into ``model`` before it is returned.

    Args:
        model: network to train (already on the global ``device``).
        criterion: loss function applied to ``(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of train+val epochs.

    Returns:
        The same ``model`` object, holding the best-validation-accuracy weights.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # model.train(False) is equivalent to model.eval()
            dataloader = trainloader if is_train else valloader

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    optimizer.zero_grad()

                # Track history only while training.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # criterion returns a batch mean; scale back to a per-sample sum.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            epoch_loss = running_loss / len(dataloader.dataset)
            epoch_acc = running_corrects / len(dataloader.dataset)

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val':
                print(f'report CC stats for phase {phase}')
                report_CC_stats(model, dataloader, thresh=0.9, num_batches=10)

                # Keep a copy of the best-so-far weights.
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    model.load_state_dict(best_model_wts)
    return model

# Initialize model, criterion, and optimizer
# ResNet-18 depth: two BasicBlocks in each of the four stages.
model = ResNet(BasicBlock, layers=[2, 2, 2, 2]).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)

# Train and evaluate; train_model hands back the model with its best-validation weights.
model = train_model(model, criterion, optimizer, num_epochs=25)