## Modularity-based Loss Function

Source doc (LASR): [link](https://docs.google.com/document/d/1Q7rouvMVozBDk9aUt384oc0Fu5dAm4K-7k_nXsQNzFY/edit)

For this, it'll be better to start with something simpler than SI-score CNNs. I'll start with simple MLPs trained on MNIST.

### Setup (3-layer MLP trained on MNIST)

In [49]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import tqdm
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import mutual_info_score
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import clear_output
from collections import defaultdict
from itertools import islice
import random
import time
from pathlib import Path
import math

In [50]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [51]:
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [52]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))  # Normalize the data
])

train_dataset = MNIST(root='.', train=True, download=True, transform=transform)
test_dataset = MNIST(root='.', train=False, download=True, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)

In [53]:
model = MLP()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
train_losses = []

In [54]:
for epoch in range(5):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        if batch_idx % 400 == 0:
            print(f'Epoch {epoch+1}/{5}, Batch {batch_idx}/{len(train_loader)}, Loss: {loss.item():.4f}')

Epoch 1/5, Batch 0/938, Loss: 2.3385
Epoch 1/5, Batch 400/938, Loss: 2.2083
Epoch 1/5, Batch 800/938, Loss: 2.0438
Epoch 2/5, Batch 0/938, Loss: 1.9316
Epoch 2/5, Batch 400/938, Loss: 1.6622
Epoch 2/5, Batch 800/938, Loss: 1.2682
Epoch 3/5, Batch 0/938, Loss: 1.2365
Epoch 3/5, Batch 400/938, Loss: 0.9520
Epoch 3/5, Batch 800/938, Loss: 0.8479
Epoch 4/5, Batch 0/938, Loss: 0.7650
Epoch 4/5, Batch 400/938, Loss: 0.7181
Epoch 4/5, Batch 800/938, Loss: 0.5172
Epoch 5/5, Batch 0/938, Loss: 0.4507
Epoch 5/5, Batch 400/938, Loss: 0.5610
Epoch 5/5, Batch 800/938, Loss: 0.4439


In [55]:
fig = go.Figure()
fig.add_trace(go.Scatter(y=train_losses, mode='lines', name='', line=dict(color='darkred', width=2)))
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_layout(
    xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
    yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
)
fig.update_xaxes(title_text='Optimization Step')
fig.update_yaxes(title_text='CrossEntropy Loss')
fig.update_layout(width=600, height=400, autosize=False)
fig.show()

### Designing the Modularity Metric/Loss

In [56]:
model

MLP(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=10, bias=True)
)

In [57]:
trained_model = MLP()
trained_model.load_state_dict(model.state_dict())

untrained_model = MLP()

def get_activation(model, layer_name):
    activation = {}
    def hook(module, input, output):
        if layer_name not in activation:
            activation[layer_name] = output
        else:
            activation[layer_name] = torch.cat((activation[layer_name], output), 0)
    return activation, hook

layer_names = ['fc1', 'fc2', 'fc3']

untrained_activations = defaultdict(dict)
trained_activations = defaultdict(dict)

for layer_name in layer_names:
    untrained_activations[layer_name], hook = get_activation(untrained_model, layer_name)
    untrained_model._modules.get(layer_name).register_forward_hook(hook)

    trained_activations[layer_name], hook = get_activation(trained_model, layer_name)
    trained_model._modules.get(layer_name).register_forward_hook(hook)

In [58]:
batches = 10
for batch_idx, (data, target) in enumerate(test_loader):
    untrained_model(data)
    trained_model(data)
    batches -= 1
    if batches == 0:
        break

In [59]:
print(
    trained_activations['fc1']['fc1'].shape,
    trained_activations['fc2']['fc2'].shape,
    trained_activations['fc3']['fc3'].shape
)

torch.Size([640, 128]) torch.Size([640, 128]) torch.Size([640, 10])


In [60]:
def conditional_entropy(act1, act2, num_bins=10):

    def discretize_activations(activations, bins=10):
        min_val = activations.min().item()
        max_val = activations.max().item()
        bins = torch.linspace(min_val, max_val, bins)
        discretized = torch.bucketize(activations, bins) - 1  # bucketize indexes from 1, subtract 1 to start from 0
        return discretized

    def joint_distribution(act1, act2, num_bins=10):
        joint_hist = torch.zeros((num_bins, num_bins))

        for i in range(act1.size(0)):
            for j in range(act1.size(1)):
                joint_hist[act1[i, j], act2[i, j]] += 1

        # normalize to get probabilities
        joint_distribution = joint_hist / joint_hist.sum()
        return joint_distribution

    def marginals(joint_dist):
        marginal_x = torch.sum(joint_dist, dim=1)  # Sum over columns to get P(X)
        marginal_y = torch.sum(joint_dist, dim=0)  # Sum over rows to get P(Y)
        return marginal_x, marginal_y

    epsilon = 1e-8
    conditional_entropy = 0.0
    act1 = discretize_activations(act1, num_bins)
    act2 = discretize_activations(act2, num_bins)

    joint_dist = joint_distribution(act1, act2, num_bins)
    marginal_x, marginal_y = marginals(joint_dist)
    for i in range(num_bins):
        for j in range(num_bins):
            if joint_dist[i, j] > 0:
                conditional_entropy -= joint_dist[i, j] * torch.log(joint_dist[i, j] / (marginal_y[j] + epsilon) + epsilon)
    
    return conditional_entropy

def entropy(act):
    # probs
    probs = torch.histc(act, bins=10) / act.numel()
    # entropy
    entropy = -torch.sum(probs * torch.log(probs + 1e-8))
    return entropy

def mutual_information(act1, act2, num_bins=10):
    return entropy(act1) - conditional_entropy(act1, act2, num_bins)

def joint_entropy(act1, act2, num_bins=10):
    return conditional_entropy(act1, act2, num_bins) + entropy(act2)

In [19]:
# test
act1 = untrained_activations['fc1']['fc1']
act2 = untrained_activations['fc2']['fc2']
e1 = entropy(act1)
e2 = entropy(act2)
mi = mutual_information(act1, act2, num_bins=10)
ce = conditional_entropy(act1, act2, num_bins=10)
je = joint_entropy(act1, act2, num_bins=10)
print(f'Entropy of act1: {e1:.4f}\nEntropy of act2: {e2:.4f}\nMutual Information: {mi:.4f}\nConditional Entropy: {ce:.4f}\nJoint Entropy: {je:.4f}')

Entropy of act1: 1.4756
Entropy of act2: 1.3278
Mutual Information: 0.0997
Conditional Entropy: 1.3759
Joint Entropy: 2.7037


In [265]:
# starting with a random module P that has 10 unique and indices for each layer
P = [
    torch.randperm(128)[:10],
    torch.randperm(128)[:10],
    torch.randperm(10)
]

many_P = [
    [
        torch.randperm(128)[:10],
        torch.randperm(128)[:10],
        torch.randperm(10)
    ] for _ in range(10)
]

In [266]:
def modularity_measures(P, layer_names, activations):
    comparison_totals = []

    pair_indices = [(0, 1), (1, 2), (0, 2)]
    for idx, (l1, l2) in enumerate(pair_indices):
        layer_1 = layer_names[l1]
        layer_2 = layer_names[l2]
        activations_first_layer = activations[layer_1][layer_1]
        activations_first_layer_in_p = activations[layer_1][layer_1][:, P[l1]]
        activations_second_layer_in_p = activations[layer_2][layer_2][:, P[l2]]
        activations_first_layer_not_in_p = activations[layer_1][layer_1][:, [i for i in range(128) if i not in P[l1]]]

        measure1 = conditional_entropy(activations_second_layer_in_p, activations_first_layer_not_in_p)
        measure2 = entropy(activations_second_layer_in_p)

        # print(f'Measure 1: {measure1:.4f}\nMeasure 2: {measure2:.4f}')

        measure3 = conditional_entropy(activations_second_layer_in_p, activations_first_layer)
        measure4 = conditional_entropy(activations_second_layer_in_p, activations_first_layer_in_p)

        comparison1 = measure1 - measure2
        comparison2 = measure3 - measure4
        comparison3 = measure3 - measure1

        comparison_totals.append((comparison1, comparison2, comparison3))
        
        # print(f'Measure 3: {measure3:.4f}\nMeasure 4: {measure4:.4f}')
        # print(f'Comparison 1: {comparison1:.4f}\nComparison 2: {comparison2:.4f}\nComparison 3: {comparison3:.4f}\n')

    return comparison_totals

In [267]:
modularity_measures(P, layer_names, untrained_activations)

[(tensor(-0.1126), tensor(0.0259), tensor(-0.0006)),
 (tensor(-0.1963), tensor(-0.0820), tensor(-0.0204)),
 (tensor(-0.1419), tensor(0.0179), tensor(0.0273))]

In [268]:
modularity_measures(P, layer_names, trained_activations)

[(tensor(-0.1427), tensor(0.0492), tensor(0.0160)),
 (tensor(-0.1383), tensor(0.0356), tensor(0.0075)),
 (tensor(-0.1577), tensor(0.0120), tensor(0.0340))]

In [269]:
import pandas as pd

In [270]:
total_comparisons_trained = [[0, 0, 0] for _ in range(3)]
total_comparisons_untrained = [[0, 0, 0] for _ in range(3)]
for P in tqdm.tqdm(many_P):
    total_comparisons_trained = np.add(total_comparisons_trained, modularity_measures(P, layer_names, trained_activations))
    total_comparisons_untrained = np.add(total_comparisons_untrained, modularity_measures(P, layer_names, untrained_activations))

100%|██████████| 10/10 [00:15<00:00,  1.52s/it]


In [271]:
total_comparisons_trained

array([[-1.45041335,  0.02397597, -0.06637084],
       [-1.50381875,  0.05813229, -0.05846167],
       [-1.34644568,  0.02299416, -0.00349975]])

In [275]:
# plot this in plotly as a heatmap
pairs = ['fc1-fc2', 'fc2-fc3', 'fc1-fc3']
comparisons = ['C1', 'C2', 'C3']
fig = make_subplots(rows=1, cols=2, subplot_titles=('Trained Model', 'Untrained Model'))
fig.add_trace(go.Heatmap(z=total_comparisons_trained, colorscale='Viridis', zmin=-0.1, zmax=0.1), row=1, col=1)
fig.add_trace(go.Heatmap(z=total_comparisons_untrained, colorscale='Viridis', zmin=-0.1, zmax=0.1), row=1, col=2)
fig.update_layout(height=600, width=1000, title_text='Modularity Measures')
# add pair and comparison labels to ticks
fig.update_xaxes(title_text='Pairs', ticktext=pairs, tickvals=[0, 1, 2], row=1, col=1)
fig.update_xaxes(title_text='Pairs', ticktext=pairs, tickvals=[0, 1, 2], row=1, col=2)
fig.update_yaxes(title_text='Comparisons', ticktext=comparisons, tickvals=[0, 1, 2], row=1, col=1)
fig.update_yaxes(title_text='Comparisons', ticktext=comparisons, tickvals=[0, 1, 2], row=1, col=2)
# plot size
fig.update_layout(width=800, height=400, autosize=False)
fig.show()

## Entropy-based Feature Extraction

In [61]:
model

MLP(
  (fc1): Linear(in_features=784, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=10, bias=True)
)

In [77]:
def model_removed_neurons(model, layer, neurons):
    new_model = MLP()
    new_model.load_state_dict(model.state_dict())

    def modified_forward(x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(new_model.fc1(x))
        if layer == 1:  x[:, neurons] = 0
        x = torch.relu(new_model.fc2(x))
        if layer == 2:  x[:, neurons] = 0
        x = new_model.fc3(x)
        if layer == 3:  x[:, neurons] = 0
        return x

    new_model.forward = modified_forward
    return new_model

In [78]:
entropy_diffs = []
kl_divs = []
for batch_idx, (data, target) in enumerate(test_loader):
    data = data.to(device)
    target = target.to(device)
    print(data.shape, target.shape)
    original_output = model(data)
    for i in tqdm.trange(128):
        corrupted_model = model_removed_neurons(model, 2, i).to(device)
        corrupted_output = corrupted_model(data)
        original_probs = F.softmax(original_output, dim=1)
        corrupted_probs = F.softmax(corrupted_output, dim=1)
        original_entropies = -torch.sum(original_probs * torch.log(original_probs + 1e-8), dim=1)
        corrupted_entropies = -torch.sum(corrupted_probs * torch.log(corrupted_probs + 1e-8), dim=1)
        entropy_diffs.append((original_entropies - corrupted_entropies).sum().item())
        kl_divs.append(F.kl_div(original_probs.log(), corrupted_probs, reduction='sum').item())
    break

torch.Size([64, 1, 28, 28]) torch.Size([64])


100%|██████████| 128/128 [00:00<00:00, 700.83it/s]


In [79]:
fig = go.Figure()
fig.add_trace(go.Scatter(y=entropy_diffs, mode='lines', name='', line=dict(color='darkred', width=2)))
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_layout(
    xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
    yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
)
fig.update_xaxes(title_text='Neuron Index')
fig.update_yaxes(title_text='Entropy Difference')
fig.update_layout(width=600, height=400, autosize=False)
fig.show()

In [80]:
fig = go.Figure()
fig.add_trace(go.Scatter(y=kl_divs, mode='lines', name='', line=dict(color='darkred', width=2)))
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_layout(
    xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
    yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
)
fig.update_xaxes(title_text='Neuron Index')
fig.update_yaxes(title_text='KL Divergence')
fig.update_layout(width=600, height=400, autosize=False)
fig.show()

In [81]:
fig = go.Figure()
fig.add_trace(go.Scatter(y=sorted(kl_divs, reverse=True), mode='lines', name='', line=dict(color='darkred', width=2)))
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_layout(
    xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
    yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
)
fig.update_xaxes(title_text='Neuron Index')
fig.update_yaxes(title_text='KL Divergence')
fig.update_layout(width=600, height=400, autosize=False)
fig.show()

In [82]:
print(f'Average KL Divergence: {np.mean(kl_divs):.4f}')

Average KL Divergence: 0.1316


## with Directions

In [72]:
def model_removed_direction(model, layer, direction):
    new_model = MLP()
    new_model.load_state_dict(model.state_dict())

    def modified_forward(x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(new_model.fc1(x))
        if layer == 1:
            dot_product = torch.sum(x * direction, dim=1)
            x = x - dot_product[:, None] * direction
        x = torch.relu(new_model.fc2(x))
        if layer == 2:
            dot_product = torch.sum(x * direction, dim=1)
            x = x - dot_product[:, None] * direction
        x = new_model.fc3(x)
        if layer == 3:
            dot_product = torch.sum(x * direction, dim=1)
            x = x - dot_product[:, None] * direction
        return x
    
    new_model.forward = modified_forward
    return new_model

In [73]:
entropy_diffs = []
kl_divs = []
for batch_idx, (data, target) in enumerate(test_loader):
    data = data.to(device)
    target = target.to(device)
    print(data.shape, target.shape)
    original_output = model(data)
    for i in tqdm.trange(128):
        direction = torch.randn(128).to(device)
        direction /= torch.norm(direction)
        corrupted_model = model_removed_direction(model, 2, direction).to(device)
        corrupted_output = corrupted_model(data)
        original_probs = F.softmax(original_output, dim=1)
        corrupted_probs = F.softmax(corrupted_output, dim=1)
        original_entropies = -torch.sum(original_probs * torch.log(original_probs + 1e-8), dim=1)
        corrupted_entropies = -torch.sum(corrupted_probs * torch.log(corrupted_probs + 1e-8), dim=1)
        entropy_diffs.append((original_entropies - corrupted_entropies).sum().item())
        kl_divs.append(F.kl_div(original_probs.log(), corrupted_probs, reduction='sum').item())
    break

torch.Size([64, 1, 28, 28]) torch.Size([64])


100%|██████████| 128/128 [00:00<00:00, 573.71it/s]


In [74]:
fig = go.Figure()
fig.add_trace(go.Scatter(y=kl_divs, mode='lines', name='', line=dict(color='darkred', width=2)))
fig.update_layout({'plot_bgcolor': 'rgba(255, 255, 255, 1)',})
fig.update_layout(
    xaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
    yaxis=dict(showgrid=True, gridwidth=1, gridcolor='LightGray'),
)
fig.update_xaxes(title_text='Neuron Index')
fig.update_yaxes(title_text='KL Divergence')
fig.update_layout(width=600, height=400, autosize=False)
fig.show()

In [76]:
print(f'Average KL Divergence: {np.mean(kl_divs):.4f}')

Average KL Divergence: 0.1105


In [67]:
def exact_entropy(probs):
    return -torch.sum(probs * torch.log(probs + 1e-8))

In [68]:
def max_info_direction(model, data):
    # search for a weighted sum of neurons that maximizes the KL divergence
    # between the original and the corrupted model
    model.eval()
    weights = torch.rand(128).to(device)
    weights.requires_grad = True
    optimizer = optim.Adam([weights], lr=1e-2)
    original_output = model(data)
    original_probs = F.softmax(original_output, dim=1)
    original_entropy = exact_entropy(original_probs)
    for i in range(1000):
        optimizer.zero_grad()
        # 
        corrupted_model.eval()
        corrupted_output = corrupted_model(data)
        corrupted_probs = F.softmax(corrupted_output, dim=1)
        corrupted_entropy = exact_entropy(corrupted_probs)
        loss = - F.kl_div(original_probs.log(), corrupted_probs, reduction='sum')
        loss.backward()
        optimizer.step()
    return weights

In [70]:
weights = max_info_direction(model, data)

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [64, 128]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).