In [None]:
import sys
sys.path.append('/Users/charleshiggins/Personal/CharlesPhD/CodeRepo/xai_intervention/RL-LRP/')
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import wandb
from experiments import DiffLrpWrapper, SimpleRNet, apply_threshold

# Initialize wandb
# wandb.init(project="reverse_LRP_mnist", run_name="rev_lrp_mnist")

# Select the compute device first: CUDA when it is available, CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# MNIST preprocessing pipeline: tensor conversion, then normalisation
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.1307,), std=(0.3081,))]
)

# Training and test splits, both stored under ./data
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Shuffled mini-batches for training, large fixed-order batches for evaluation
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)

# Underlying network, with an Adam optimizer built over its parameters
model = SimpleRNet()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Wrap the network in the LRP class; its output is compared to the target map with MSE
wrapped_model = DiffLrpWrapper(model)
criterion = nn.MSELoss()

# Move the wrapped model onto the selected device
wrapped_model.to(device)

# Function to test the model
def test(wrapped_model, device, test_loader):
    """Evaluate the wrapped model on a loader and return its average loss.

    For every batch the target map is built with ``apply_threshold(data)`` and
    compared with ``wrapped_model(data, target)`` using the module-level
    ``criterion``. No classification accuracy is computed.

    Args:
        wrapped_model: Module called as ``wrapped_model(data, target)``. It is
            switched to eval mode and left in that mode on return.
        device: Device each batch is moved to before the forward pass.
        test_loader: Iterable yielding ``(data, target)`` batches.

    Returns:
        float: Sample-weighted average of the per-batch losses, or NaN when
        the loader yields no samples.
    """
    wrapped_model.eval()
    weighted_loss_sum = 0.0
    num_samples = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            target_map = apply_threshold(data)
            output = wrapped_model(data, target)
            batch_size = data.size(0)
            # `criterion` is nn.MSELoss() with the default mean reduction, so each
            # value is already a batch average. Weight it by the batch size so the
            # final division yields a true per-sample mean (the previous code
            # divided a sum of batch means by the dataset size, under-reporting
            # the loss by roughly the batch size).
            weighted_loss_sum += criterion(output, target_map).item() * batch_size
            num_samples += batch_size

    if num_samples == 0:
        # Nothing was evaluated; avoid a ZeroDivisionError and make that visible.
        return float("nan")
    return weighted_loss_sum / num_samples

# Training the model with early stopping
def train(wrapped_model, device, train_loader, optimizer, epoch, target_accuracy=99.0):
    """Run one training epoch, evaluating on the test set every 10 batches.

    Each batch is moved to ``device``, its target map is built with
    ``apply_threshold(data)``, and the module-level ``criterion`` between
    ``wrapped_model(data, target)`` and that map is minimised with ``optimizer``.

    Args:
        wrapped_model: Module called as ``wrapped_model(data, target)``.
        device: Device each batch is moved to.
        train_loader: Loader yielding ``(data, target)`` batches; must expose
            ``.dataset`` and ``len()`` for the progress message.
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in the progress message.
        target_accuracy: Threshold the value returned by ``test`` is compared
            against for early stopping.

    Returns:
        bool: True if the early-stopping condition was met, False otherwise.
    """
    wrapped_model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        target_map = apply_threshold(data)
        optimizer.zero_grad()
        output = wrapped_model(data, target)
        loss = criterion(output, target_map)
        loss.backward()
        optimizer.step()

        # wandb.log({"Train Loss": loss.item()})

        if batch_idx % 10 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

            # Evaluate the model that is being trained (the argument), not the
            # global unwrapped `model`: test() calls it as model(data, target).
            accuracy = test(wrapped_model, device, test_loader)
            # test() switches the module to eval mode; restore training mode so
            # the remaining batches of this epoch are not trained in eval mode.
            wrapped_model.train()
            # wandb.log raises when no run is active (wandb.init is commented
            # out at the top of the notebook), so only log inside a run.
            if wandb.run is not None:
                wandb.log({"Test Accuracy": accuracy})
            # NOTE(review): test() returns a loss, not an accuracy, so this
            # compares a loss against `target_accuracy` and reports it as a
            # percentage — confirm the intended early-stopping criterion.
            if accuracy >= target_accuracy:
                print(f"Stopping early: Reached {accuracy:.2f}% accuracy")
                return True
    return False

# Run training
# Train for epochs 1 through 10, leaving the loop as soon as train() signals an early stop
for epoch in range(1, 11):
    stop_requested = train(wrapped_model, device, train_loader, optimizer, epoch)
    if stop_requested:
        break
    
# wandb.finish()


In [None]:
import torch
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

def load_mnist(batch_size=64):
    """Return a shuffled DataLoader over the MNIST training split.

    Images are converted to tensors and normalised with mean 0.5 and std 0.5.
    The dataset lives under ``./data`` and is requested with ``download=True``.
    """
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ])
    mnist_train = datasets.MNIST(root='./data', train=True, transform=preprocessing, download=True)
    return DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

def apply_threshold(images, threshold=0.5):
    """Zero out every pixel that is not strictly greater than ``threshold``.

    Pixels above the threshold keep their value; all others become zero.
    The result has the same shape as ``images`` and the input is not modified.
    """
    keep_mask = images > threshold
    blank = torch.zeros_like(images)
    return torch.where(keep_mask, images, blank)


# Build a small-batch loader (10 images per batch)
train_loader = load_mnist(batch_size=10)

# Pull one batch of images and labels from the loader
data_iter = iter(train_loader)
images, labels = next(data_iter)

# Threshold the batch (0.5 is used as an example threshold)
thresholded_images = apply_threshold(images, threshold=0.5)

# Show the first image of the batch before and after thresholding, side by side
fig, axes = plt.subplots(nrows=1, ncols=2)
panels = (
    ('Original Image', images),
    ('Thresholded Image', thresholded_images),
)
for ax, (title, batch) in zip(axes, panels):
    ax.imshow(batch[0][0], cmap='gray')
    ax.set_title(title)
plt.show()


In [None]:
import sys
sys.path.append('/Users/charleshiggins/Personal/CharlesPhD/CodeRepo/xai_intervention/RL-LRP/')


In [None]:
from experiments import DiffLrpWrapper

In [None]:
# NOTE(review): `SimpleNet` is not imported or defined in any cell shown here
# (the first cell imports `SimpleRNet` from `experiments`) — confirm where it
# comes from, otherwise this raises NameError on a fresh kernel.
# This also rebinds `model`, while the `optimizer` created in the first cell
# was built from the previous model's parameters.
model = SimpleNet()