In [2]:
import os
import sys

# Make a local checkout importable for the project imports used further down.
# NOTE(review): presumably this directory holds the `experiments` package
# imported in the next cell -- confirm.
# The location can be overridden with the RL_LRP_ROOT environment variable so
# the notebook also runs on other machines; the original path stays the default.
RL_LRP_ROOT = os.environ.get(
    'RL_LRP_ROOT',
    '/Users/charleshiggins/Personal/CharlesPhD/CodeRepo/xai_intervention/RL-LRP/',
)

# Re-running this cell in the same kernel must not stack duplicate entries.
if RL_LRP_ROOT not in sys.path:
    sys.path.append(RL_LRP_ROOT)


In [None]:

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import wandb
from experiments import DiffLrpWrapper, SimpleRNet, apply_threshold
from experiments import CosineDistanceLoss, ManualCNN
from matplotlib import pyplot as plt
import numpy as np
# Tunable settings for this cell (values unchanged from the original script).
BATCH_SIZE = 64    # batch size used by both loaders
NUM_STEPS = 1000   # number of optimisation steps to run
LOG_EVERY = 10     # print the loss every N steps
PLOT_EVERY = 50    # draw a comparison figure every N steps

# Load and transform the MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialize the network and optimizer for the underlying network
# (this first batch is only needed by the commented-out single-image experiment below)
data, target = next(iter(test_loader))
# model = DiffLrpWrapper(SimpleRNet())
model = ManualCNN()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# now wrap the network in the LRP class
# wrapped_model = DiffLrpWrapper(model)
criterion = CosineDistanceLoss()

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


# over_train_data, over_train_target = data[:1,:,:,:], target[:1]  # Get a single image and target
# # expand so we repeat the same input to a batch size of 64
# over_train_data, over_train_target = over_train_data.expand(64, 1, 28, 28), over_train_target.expand(64)

model.train()
for x in range(NUM_STEPS):
    # NOTE(review): each step builds a fresh iterator and takes its first
    # (shuffled) batch from the TEST loader; train_loader is never consumed.
    # Left as-is so the data draw is unchanged -- confirm this is intended.
    data, target = next(iter(test_loader))
    data, target = data.to(device), target.to(device)
    # The model output is fitted to this map derived from the input batch.
    target_map = apply_threshold(data, threshold=0.99)
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target_map)
    loss.backward()
    optimizer.step()

    if x % LOG_EVERY == 0:
        # Progress is reported against the step budget; the previous readout
        # used train_loader sizes, which this loop never iterates over.
        print(f'Train Step: [{x}/{NUM_STEPS} ({100. * x / NUM_STEPS:.0f}%)]\tLoss: {loss.item():.6f}')
    if x % PLOT_EVERY == 0:
        # Plotting
        # Pick a sample inside the batch that was actually loaded.
        num = np.random.randint(0, len(data))
        fig, axes = plt.subplots(1, 2)
        # Tensors may live on the GPU; matplotlib needs host-side arrays.
        axes[0].imshow(output[num][0].detach().cpu().numpy(), cmap='hot')
        axes[0].set_title(f'LRP Output after {x} iterations')
        axes[1].imshow(target_map[num][0].detach().cpu().numpy(), cmap='hot')
        axes[1].set_title('Original Image')
        plt.show()
        # Release the figure so repeated plotting does not accumulate open figures.
        plt.close(fig)
    


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from contextlib import contextmanager

# Context manager that temporarily instruments functional activations
@contextmanager
def track_activations(wrapper):
    """Report calls to selected ``torch.nn.functional`` ops to ``wrapper``.

    While the ``with`` body runs, ``F.relu``, ``F.max_pool2d`` and
    ``F.log_softmax`` are replaced by thin proxies. Each proxy calls the real
    function, hands the result to ``wrapper.record_activation(label, result)``
    and then returns the result unchanged. The real functions are put back on
    exit, including when the body raises.

    Args:
        wrapper: any object exposing ``record_activation(name, output)``.
    """
    # Attribute on ``F`` -> label passed to ``record_activation``.
    labels = {
        'relu': 'ReLU',
        'max_pool2d': 'MaxPool2d',
        'log_softmax': 'LogSoftmax',
    }
    # Capture the genuine callables first so they can be restored later.
    originals = {attr: getattr(F, attr) for attr in labels}

    def instrument(fn, label):
        # Build a proxy bound to one original function and its label.
        def traced(input, *args, **kwargs):
            result = fn(input, *args, **kwargs)
            wrapper.record_activation(label, result)
            return result
        return traced

    for attr, label in labels.items():
        setattr(F, attr, instrument(originals[attr], label))

    try:
        yield
    finally:
        # Always undo the patching, even on error inside the with-block.
        for attr, fn in originals.items():
            setattr(F, attr, fn)

# Small CNN whose non-linearities go through torch.nn.functional calls
class SimpleNet(nn.Module):
    """Two convolution stages followed by two linear layers.

    ``forward`` returns log-probabilities over 10 outputs (``log_softmax``
    along dim 1). The flatten step hard-codes 320 features per row, which is
    what a single-channel 28x28 input produces after the two conv/pool stages
    (20 channels * 4 * 4).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        # Classifier head
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        """Run the network on ``x`` and return per-row log-probabilities."""
        # Stage 1: conv -> ReLU -> 2x2 max-pool
        stage1 = F.max_pool2d(F.relu(self.conv1(x)), 2)
        # Stage 2: conv -> ReLU -> 2x2 max-pool
        stage2 = F.max_pool2d(F.relu(self.conv2(stage1)), 2)
        # Flatten to rows of 320 features for the linear layers
        flat = stage2.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        logits = self.fc2(hidden)
        return F.log_softmax(logits, dim=1)

# Wrapper class to track layers and activations
class WrapperNet(nn.Module):
    """Wrap a model and log which layers/activations run during a forward pass.

    Two mechanisms feed the log:

    * forward hooks registered on every leaf sub-module of ``model`` (modules
      without children), which record the module's class name and its input;
    * ``track_activations``, active during ``forward``, which reports calls to
      the patched ``torch.nn.functional`` ops through ``record_activation``.

    Attributes:
        model: the wrapped module.
        executed_layers: names in execution order (module class names and
            functional-op labels interleaved).
        activation_inputs: dict mapping a position in ``executed_layers`` to
            the positional-input tuple the hooked module at that position
            received.
        activation_outputs: outputs reported via ``record_activation``, in
            call order.
    """

    def __init__(self, model):
        super(WrapperNet, self).__init__()
        self.model = model
        self.executed_layers = []
        self.activation_inputs = {}
        # Created up front so the getter works before the first forward pass.
        self.activation_outputs = []

        # Register hooks for the layers
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Sequential, WrapperNet)):
                continue
            # Containers are skipped; only leaf modules get a hook.
            if any(True for _ in module.children()):
                continue
            module.register_forward_hook(self.forward_hook)

    def forward_hook(self, module, input, output):
        """Forward-hook callback: log the module's class name and its input.

        ``input`` is stored under the index the module's name occupies in
        ``executed_layers`` so the two structures can be matched up.
        """
        self.activation_inputs[len(self.executed_layers)] = input
        self.executed_layers.append(module.__class__.__name__)

    def record_activation(self, name, output):
        """Log a functional activation call under ``name`` and keep its output."""
        self.executed_layers.append(name)
        self.activation_outputs.append(output)

    def forward(self, x):
        """Run the wrapped model on ``x`` and return its output.

        All logs are cleared first, so after the call they describe this
        forward pass only.
        """
        self.executed_layers = []
        self.activation_inputs = {}
        self.activation_outputs = []
        with track_activations(self):
            return self.model(x)

    def get_layers_and_activation_lists(self):
        """Return ``(executed_layers, activation_outputs)`` as currently logged."""
        return self.executed_layers, self.activation_outputs

# Demo: build a SimpleNet, wrap it, and list what executes during one forward pass.
# Note that `model` is rebound here to a fresh, untrained SimpleNet.
model = SimpleNet()
wrapped_model = WrapperNet(model)

# Random stand-in for a single 1-channel 28x28 image (batch size 1)
input_tensor = torch.randn(1, 1, 28, 28)

# Forward through the wrapper (not `model` directly) so its logs are reset and filled
output = wrapped_model(input_tensor)

# Retrieve the execution-order name list and the recorded activation outputs
layers, activations = wrapped_model.get_layers_and_activation_lists()

print("Layers and activations executed in forward pass:", layers)
# Uncomment to also dump the recorded tensors (verbose):
# print("Activation outputs:", activations)


Layers and activations executed in forward pass: ['Conv2d', 'ReLU', 'MaxPool2d', 'Conv2d', 'ReLU', 'MaxPool2d', 'Linear', 'ReLU', 'Linear', 'LogSoftmax']
