# In [None]:
import torch
import torchvision
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from tqdm import tqdm
from torch.utils.data import DataLoader
from torchvision import transforms, datasets

# --- Model -----------------------------------------------------------------
# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained ResNet-50, switched to inference mode and placed on `device`.
model = torchvision.models.resnet50(pretrained=True)
model.eval()
model = model.to(device)

# --- Data ------------------------------------------------------------------
# Resize -> centre-crop to 224x224 -> tensor -> per-channel normalisation.
# The mean/std values are the ones conventionally used with torchvision's
# ImageNet-pretrained models.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(preprocessing_steps)

# ImageNet validation images, read in a fixed (unshuffled) order.
val_dataset = datasets.ImageFolder(
    '/kaggle/input/imagenet1k-val/imagenet-val',
    transform=transform,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
)

def evaluate_model(model):
    """Return the top-1 accuracy of ``model``, in percent, over ``val_loader``.

    The model is put in eval mode and run without gradient tracking. Batches
    come from the module-level ``val_loader`` and are moved to the
    module-level ``device`` before the forward pass.

    Args:
        model: Callable network mapping an image batch to per-class scores.

    Returns:
        ``100 * correct / total`` as a Python float.
    """
    model.eval()
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        for batch_images, batch_labels in tqdm(val_loader, desc="Evaluating"):
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            scores = model(batch_images)
            # max over the class dimension; element [1] holds the arg-max indices.
            predicted = scores.max(1)[1]

            num_seen += batch_labels.size(0)
            num_correct += (predicted == batch_labels).sum().item()

    return 100. * num_correct / num_seen

# Reference accuracy of the unmodified network; later ablation results are
# compared against this number.
print("Computing baseline accuracy...")
baseline_acc = evaluate_model(model)
print(f"Baseline Accuracy: {baseline_acc:.2f}%")

# Map qualified module name -> module for every Conv2d inside the model.
conv_layers = {
    name: module
    for name, module in model.named_modules()
    if isinstance(module, torch.nn.Conv2d)
}

# Forward-hook plumbing: the most recent output of each hooked layer,
# keyed by the name the hook was created with.
activations = {}


def get_activation(name):
    """Build a forward hook that records a layer's output under ``name``.

    The returned callable has the ``(module, inputs, output)`` signature
    expected by ``register_forward_hook``. It stores a detached copy of the
    output in the module-level ``activations`` dict and returns ``None``, so
    the layer's output is passed on unchanged.
    """
    def store_output(module, inputs, output):
        activations[name] = output.detach()

    return store_output


# Attach one capture hook per conv layer; the handles are kept so the hooks
# can be removed later.
hooks = [
    module.register_forward_hook(get_activation(name))
    for name, module in conv_layers.items()
]

def compute_all_feature_importance():
    """Score every channel of every conv layer by its deviation from the layer mean.

    Runs the module-level ``model`` once over ``val_loader``. After each
    forward pass the output of every layer in ``conv_layers`` is read from the
    module-level ``activations`` dict (which the forward hooks fill in). For a
    layer output ``features`` of shape [B, C, H, W] the per-channel score is
    the mean over batch and spatial positions of
    ``(features - features.mean(dim=1)) ** 2``.

    Batch scores are combined as an image-weighted running sum, so a shorter
    final batch does not count more than a full one, and nothing is kept per
    batch. The final scores are min-max scaled to [0, 1] per layer.

    Returns:
        dict mapping layer name -> 1-D float32 NumPy array with one score per
        output channel. A layer whose channels all score the same (or whose
        scores are NaN) gets all zeros instead of a 0/0 result.

    Raises:
        RuntimeError: if ``val_loader`` yields no images.
    """
    score_sums = {}  # layer name -> sum over batches of (batch score * batch size)
    images_seen = 0

    with torch.no_grad():
        for images, _ in tqdm(val_loader, desc="Computing importance for all layers"):
            images = images.to(device)
            model(images)  # forward pass only; the hooks capture the activations

            batch_size = images.size(0)
            images_seen += batch_size

            # Process activations for all layers at once
            for name in conv_layers:
                features = activations[name]  # [B, C, H, W]
                mean_feature = features.mean(dim=1, keepdim=True)  # [B, 1, H, W]
                batch_scores = ((features - mean_feature) ** 2).mean(dim=[0, 2, 3])  # [C]
                # Accumulate on the activation's device in float64: no per-batch
                # host sync and no precision loss over many batches.
                weighted = batch_scores.double() * batch_size
                if name in score_sums:
                    score_sums[name] += weighted
                else:
                    score_sums[name] = weighted

    if images_seen == 0:
        raise RuntimeError("val_loader produced no images; cannot compute feature importance")

    importance_results = {}
    for name in conv_layers:
        scores = (score_sums[name] / images_seen).float().cpu()
        lowest = scores.min()
        span = scores.max() - lowest
        if span > 0:
            normalized_scores = (scores - lowest) / span
        else:
            # Constant (or NaN) scores: avoid dividing by zero.
            normalized_scores = torch.zeros_like(scores)
        importance_results[name] = normalized_scores.numpy()

    return importance_results

def ablate_channels(model, layer_name, conv_layer, importance_scores, percent_to_ablate=10):
    """Zero the highest-scoring output channels of one layer and re-evaluate the model.

    A temporary forward hook on ``conv_layer`` multiplies its output by a 0/1
    channel mask while ``evaluate_model(model)`` runs. The hook is removed
    afterwards, even if evaluation raises.

    Args:
        model: Network passed on to ``evaluate_model``.
        layer_name: Name of the layer; not used by the computation, kept for
            the existing call signature.
        conv_layer: Module whose output channels are masked.
        importance_scores: One score per output channel of ``conv_layer``.
        percent_to_ablate: Percentage of channels to zero. The resulting count
            is truncated to an integer and clamped to [0, number of channels].

    Returns:
        ``(ablated_acc, top_channels)``: the accuracy measured with the mask
        in place, and a NumPy array of the ablated channel indices (ascending
        score order; empty when the count rounds down to zero).
    """
    num_channels = len(importance_scores)
    num_to_ablate = int(num_channels * percent_to_ablate / 100)
    # Clamp so an out-of-range percentage cannot give a negative or oversized count.
    num_to_ablate = max(0, min(num_channels, num_to_ablate))

    # argsort is ascending, so the highest scores sit at the end. Slice from an
    # explicit start index: `[-0:]` would select *every* channel when the count is 0.
    top_channels = np.argsort(importance_scores)[num_channels - num_to_ablate:]

    # 1 keeps a channel, 0 removes it.
    mask = torch.ones(num_channels, device=device)
    if num_to_ablate:
        mask[top_channels] = 0
    mask = mask.view(1, -1, 1, 1)  # broadcast over batch and spatial dims

    def channel_ablation_hook(module, inputs, output):
        # Returning a value from a forward hook replaces the layer's output.
        return output * mask

    hook = conv_layer.register_forward_hook(channel_ablation_hook)
    try:
        ablated_acc = evaluate_model(model)
    finally:
        # Always restore the model, otherwise later evaluations stay ablated.
        hook.remove()

    return ablated_acc, top_channels



results = {}
print("\nStarting layer-wise analysis...")

importance_scores_all = compute_all_feature_importance()

# The capture hooks are only needed for the importance pass. Remove them now so
# the ablation evaluations below do not store every conv output on each forward
# pass, and drop the cached activations from the last batch.
for capture_hook in hooks:
    capture_hook.remove()
activations.clear()

# Ablation for each layer: one full evaluation pass per conv layer.
for layer_name, conv_layer in conv_layers.items():
    print(f"\nAnalyzing {layer_name}")

    layer_scores = importance_scores_all[layer_name]
    ablated_acc, ablated_channels = ablate_channels(
        model,
        layer_name,
        conv_layer,
        layer_scores,
    )
    accuracy_drop = baseline_acc - ablated_acc

    results[layer_name] = {
        'importance_scores': layer_scores,
        'ablated_accuracy': ablated_acc,
        'accuracy_drop': accuracy_drop,
        # Key name kept for downstream readers of `results`; the value holds the
        # *highest*-scoring channels, which is what ablate_channels zeroes.
        'bottom_channels': ablated_channels,
    }

    print(f"Original Accuracy: {baseline_acc:.2f}%")
    print(f"Accuracy after ablating {layer_name}: {ablated_acc:.2f}%")
    print(f"Accuracy drop: {accuracy_drop:.2f}%")
    print(f"Top-importance channels ablated: {ablated_channels}")

# Plotting: two stacked panels on one figure.
fig = plt.figure(figsize=(15, 10))
layer_names = list(results.keys())
tick_positions = range(len(layer_names))

# 1. Accuracy drops: one bar per conv layer.
drop_ax = fig.add_subplot(2, 1, 1)
acc_drops = [results[name]['accuracy_drop'] for name in layer_names]
drop_ax.bar(tick_positions, acc_drops)
drop_ax.set_xticks(tick_positions)
drop_ax.set_xticklabels(layer_names, rotation=45, ha='right')
drop_ax.set_title('Accuracy Drop After Channel Ablation')
drop_ax.set_ylabel('Accuracy Drop (%)')
drop_ax.grid(True)

# 2. Channel importance distributions: scores sorted high-to-low, first five
#    layers only.
dist_ax = fig.add_subplot(2, 1, 2)
for name in layer_names[:5]:
    scores = results[name]['importance_scores']
    descending_scores = np.sort(scores)[::-1]
    dist_ax.plot(descending_scores, label=name)
dist_ax.set_title('Channel Importance Distribution')
dist_ax.set_xlabel('Channel Index (sorted)')
dist_ax.set_ylabel('Importance Score')
dist_ax.legend()
dist_ax.grid(True)

fig.tight_layout()
plt.show()

# Per-layer recap of the ablation results.
print("\nSummary of results:")
for layer_name, layer_result in results.items():
    print(f"\n{layer_name}:")
    print(f"Accuracy drop: {layer_result['accuracy_drop']:.2f}%")
    print(f"Number of channels ablated: {len(layer_result['bottom_channels'])}")