In [None]:
import sys
sys.path.append('/Users/charleshiggins/Personal/CharlesPhD/CodeRepo/xai_intervention/RL-LRP/')
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import wandb
from experiments import DiffLrpWrapper, SimpleRNet, apply_threshold


In [None]:

# Seed torch's RNG so weight initialisation and DataLoader shuffling are reproducible.
SEED = 0
torch.manual_seed(SEED)

# Initialize wandb ("disabled" mode: no run data is sent anywhere)
wandb.init(project="reverse_LRP_mnist", tags=["diff_lrp", "mnist", "simplernet"], mode="disabled")

# Load and transform the MNIST dataset (standard MNIST mean/std normalisation)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)

# Pick the device and move the model there *before* building the optimizer,
# as recommended by the torch.optim documentation.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleRNet()
model.to(device)

# Optimise the underlying (unwrapped) network; the DiffLRP wrapper is only
# applied afterwards, for explanation.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Function to test the model
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    # with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        # print(data.shape, target.shape)
        output = model(data)
        # print(output.shape, target.shape)
        test_loss += criterion(output, target).item()  # sum up batch loss
        pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
        correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    wandb.log({"Test Loss": test_loss, "Test Accuracy": accuracy})
    return accuracy

# Training the model with early stopping
def train(model, device, train_loader, optimizer, epoch, target_accuracy=90.0):
    """Train `model` for one epoch, with early stopping on test accuracy.

    Every 10 batches the model is evaluated with `test` on the notebook-level
    `test_loader`; once that accuracy reaches `target_accuracy` the epoch is cut
    short. Also relies on the notebook-level `criterion` and `wandb`.

    Args:
        model: network to optimise (put in train mode here).
        device: device the batches are moved to.
        train_loader: iterable of (data, target) batches with `len()` and `.dataset`.
        optimizer: optimizer over `model`'s parameters.
        epoch: epoch number, used only in the progress message.
        target_accuracy: test accuracy (percent) that triggers early stopping.

    Returns:
        True if the target accuracy was reached (the caller should stop
        training), False if the epoch ran to completion.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        wandb.log({"Train Loss": loss.item()})

        if batch_idx % 10 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

            # NOTE(review): stopping on the test split means the reported test
            # accuracy is no longer an unbiased held-out estimate.
            accuracy = test(model, device, test_loader)
            if accuracy >= target_accuracy:
                print(f"Stopping early: Reached {accuracy:.2f}% accuracy")
                return True
            # `test` switches the model to eval mode; switch back so the remaining
            # batches are trained in train mode (dropout / batch-norm, if any).
            model.train()
    return False

# Run training for up to 10 epochs; `train` returns True once the target test
# accuracy is reached. try/finally makes sure the wandb run is closed even if
# training raises or is interrupted.
try:
    for epoch in range(1, 11):
        if train(model, device, train_loader, optimizer, epoch):
            break
finally:
    wandb.finish()


In [None]:
import matplotlib.pyplot as plt

# Explain the model in eval mode explicitly instead of relying on whatever mode
# the training loop happened to leave it in.
model.eval()
wrapped_model = DiffLrpWrapper(model)
data, target = next(iter(test_loader))
data, target = data.to(device), target.to(device)

# Explain every sample with respect to one fixed class rather than its true label.
# NOTE(review): assumes the wrapper's second argument selects the class whose
# relevance is propagated (the commented-out call passed the true `target`) — confirm.
explained_class = 9
dummy_vars = torch.full_like(target, explained_class)
output = wrapped_model(data, dummy_vars)

plt.imshow(output[0].detach().cpu().numpy().squeeze(0), cmap='hot')
plt.colorbar()
# The heatmap is for `explained_class`, which generally differs from the true label.
plt.title(f"Heatmap from DiffLRPWrapper for target {explained_class} (true label {target[0].item()})")
plt.show()

In [None]:
# True class label of the first test sample (the one titled in the heatmap above).
int(target[0])

In [None]:
import matplotlib.pyplot as plt
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

def load_mnist(batch_size=64):
    """Return a shuffled DataLoader over the MNIST training split.

    Images become tensors (ToTensor, values in [0, 1]) and are then normalised
    with mean 0.5 / std 0.5, which maps them into [-1, 1]. The data is
    downloaded to ./data if it is not already present.
    """
    preprocess = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]
    )
    mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=preprocess)
    return DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

def apply_threshold(images, threshold=0.5):
    """Zero out every pixel that is not strictly greater than `threshold`.

    Pixels above the threshold keep their value; all others (including pixels
    exactly equal to the threshold, and NaNs) become 0. The result has the same
    shape and dtype as `images`.

    NOTE(review): this definition shadows the `apply_threshold` imported from
    `experiments` in the first cell until that module is imported again.
    """
    keep = images > threshold
    return images.masked_fill(~keep, 0)


# Load a small batch normalised to [-1, 1] (see load_mnist).
# NOTE(review): this rebinds the notebook-level `train_loader` created in the
# setup cell (batch 64, MNIST mean/std normalisation); re-run that cell before
# training again.
train_loader = load_mnist(batch_size=10)

# Get a single batch of images
data_iter = iter(train_loader)
images, labels = next(data_iter)

# Keep only pixels strictly brighter than THRESHOLD (values lie in [-1, 1]).
THRESHOLD = 0.95
thresholded_images = apply_threshold(images, threshold=THRESHOLD)

# Plot the first image of the batch before and after thresholding
fig, axes = plt.subplots(1, 2)
axes[0].imshow(images[0][0], cmap='gray')
axes[0].set_title('Original Image')
axes[1].imshow(thresholded_images[0][0], cmap='gray')
axes[1].set_title(f'Thresholded Image (> {THRESHOLD})')
for ax in axes:
    ax.axis('off')
plt.show()


In [None]:
import sys
sys.path.append('/Users/charleshiggins/Personal/CharlesPhD/CodeRepo/xai_intervention/RL-LRP/')


In [None]:
from experiments import DiffLrpWrapper

In [None]:
# Re-wrap the trained `model` with the differentiable-LRP wrapper.
# NOTE(review): duplicates the wrapping done in the heatmap cell; presumably kept
# so the wrapper can be rebuilt after re-importing `experiments` above.
wrapped_model = DiffLrpWrapper(model)

In [None]:
# Replicate the first test sample and its label 64 times, giving an
# over-fitting sanity-check batch. `expand` returns broadcast views, so no image
# data is copied (all 64 rows share the same storage).
over_train_data = data[:1].expand(64, 1, 28, 28)
over_train_target = target[:1].expand(64)



In [None]:
import sys
sys.path.append('/Users/charleshiggins/Personal/CharlesPhD/CodeRepo/xai_intervention/RL-LRP/')
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import wandb
from experiments import DiffLrpWrapper, SimpleRNet, apply_threshold
from experiments import CosineDistanceLoss, ManualCNN
from matplotlib import pyplot as plt
import numpy as np
# Load and transform the MNIST dataset (standard MNIST mean/std normalisation)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST('./data', train=False, download=True, transform=transform)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)

# Training configuration
NUM_STEPS = 1000      # optimisation steps (batches), not epochs
LOG_EVERY = 10        # print the loss every N steps
PLOT_EVERY = 50       # plot a random output/target pair every N steps
MAP_THRESHOLD = 0.99  # threshold passed to apply_threshold to build the target map

# Define the device, then move the model before building the optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ManualCNN()
model.to(device)

# Train ManualCNN to reproduce a thresholded copy of its own input, scored with
# the project's CosineDistanceLoss.
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = CosineDistanceLoss()

model.train()
batch_iter = iter(train_loader)
for step in range(NUM_STEPS):
    # Draw the next *training* batch, starting a fresh shuffled pass when the
    # current one is exhausted (instead of rebuilding an iterator every step and
    # training on the test split).
    try:
        data, target = next(batch_iter)
    except StopIteration:
        batch_iter = iter(train_loader)
        data, target = next(batch_iter)
    data, target = data.to(device), target.to(device)

    # Self-supervised target derived from the input; class labels are unused.
    target_map = apply_threshold(data, threshold=MAP_THRESHOLD)

    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target_map)
    loss.backward()
    optimizer.step()

    if step % LOG_EVERY == 0:
        epoch_progress = step / len(train_loader)
        print(f'Train step {step}/{NUM_STEPS} (epoch {epoch_progress:.2f})\tLoss: {loss.item():.6f}')

    if step % PLOT_EVERY == 0:
        # Index within the current batch (the last batch of a pass can be smaller than 64).
        num = np.random.randint(0, len(data))
        fig, axes = plt.subplots(1, 2)
        # .detach().cpu() so plotting also works when training on GPU.
        axes[0].imshow(output[num][0].detach().cpu().numpy(), cmap='hot')
        axes[0].set_title(f'LRP Output after {step} iterations')
        axes[1].imshow(target_map[num][0].detach().cpu().numpy(), cmap='hot')
        axes[1].set_title('Thresholded Target Map')
        plt.show()
        plt.close(fig)  # avoid accumulating open figures across the loop
    


In [None]:
import torch
import torch.nn.functional as F

# Toy demo: a temperature-scaled softmax acts as a differentiable stand-in for
# a hard one-hot "argmax" mask over each row.
x = torch.tensor([[0.1, 0.2, 0.3], [0.7, 0.5, 0.4]], requires_grad=True)

temperature = 0.05  # smaller -> sharper (closer to one-hot); larger -> softer
soft_mask = torch.softmax(x / temperature, dim=1)

# Elementwise product: same shape as x, near-zero away from each row's maximum.
selected_values = soft_mask * x

print("Selected values:", selected_values)
