In [None]:

import os
import sys

# Directory added to the import path so that the `experiments` imports in the
# following cells resolve (presumably the root of the RL-LRP checkout).
# Set the RL_LRP_ROOT environment variable to point at a different checkout;
# when it is unset the original hard-coded location is used.
REPO_ROOT = os.environ.get(
    'RL_LRP_ROOT',
    '/Users/charleshiggins/Personal/CharlesPhD/CodeRepo/xai_intervention/RL-LRP/',
)
# Guard against duplicates so re-running this cell does not keep growing sys.path.
if REPO_ROOT not in sys.path:
    sys.path.append(REPO_ROOT)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from contextlib import contextmanager
from experiments import reverse_layer, diff_softmax

In [None]:
def run_test_on_inner_net(inner_model, test_loader, max_batches=1):
    """Evaluate ``inner_model`` on batches from ``test_loader`` and print a summary.

    By default only the first batch drawn from ``test_loader`` is scored, which
    keeps the check cheap enough to call from inside a training loop. Pass
    ``max_batches=None`` to score every batch the loader yields.

    The reported loss is the cross-entropy summed over the scored samples and
    divided by the number of samples actually scored, so it is a true
    per-sample average regardless of how many batches were evaluated.

    Args:
        inner_model: ``nn.Module`` mapping an input batch to class scores with
            the classes along dim 1. It is switched to eval mode and left
            there; callers that continue training must call ``.train()``.
        test_loader: Iterable yielding ``(data, target)`` tensor pairs, where
            ``target`` holds integer class indices.
        max_batches: Maximum number of batches to score, or ``None`` for all.

    Returns:
        Tuple ``(avg_loss, correct, total)``: mean cross-entropy per scored
        sample, number of correct predictions, and number of scored samples.

    Raises:
        ValueError: If no samples were scored (empty loader or
            ``max_batches == 0``).
    """
    from itertools import islice

    # Score on the device the model already lives on, so inputs and weights
    # can never end up on different devices. Fall back to the CUDA-if-available
    # choice only for models that expose no parameters.
    first_param = next(inner_model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inner_model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        # islice stops after max_batches without fetching an extra batch;
        # islice(..., None) walks the whole loader.
        for data, target in islice(test_loader, max_batches):
            data, target = data.to(device), target.to(device)
            output = inner_model(data)
            # Sum (not mean) per batch so the final division by `total` gives
            # the exact per-sample average even when batch sizes differ.
            loss_sum += nn.functional.cross_entropy(output, target, reduction='sum').item()
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += target.numel()

    if total == 0:
        raise ValueError("test_loader yielded no samples to evaluate")

    avg_loss = loss_sum / total
    print(f'\nTest set: Average loss: {avg_loss:.4f}, Accuracy: {correct}/{total} ({100. * correct / total:.0f}%)\n')
    return avg_loss, correct, total

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from experiments import SimpleRNet, apply_threshold, CosineDistanceLoss, ManualCNN
from matplotlib import pyplot as plt
import numpy as np
# comment out when running locally
from experiments import WrapperNet
# comment out when running locally

# Preprocessing applied to every MNIST image: convert to a tensor, then
# normalise with mean 0.1307 and std 0.3081.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

# Train and test splits, stored under ./data (downloaded if missing)
train_dataset, test_dataset = (
    datasets.MNIST('./data', train=is_train, download=True, transform=transform)
    for is_train in (True, False)
)

# Both loaders serve shuffled batches of 64
train_loader, test_loader = (
    DataLoader(split, batch_size=64, shuffle=True)
    for split in (train_dataset, test_dataset)
)

# Run on the GPU when one is available, otherwise on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Wrap a fresh SimpleRNet in WrapperNet and place it on the chosen device
model = WrapperNet(SimpleRNet())
model.to(device)

# Adam over all parameters of the wrapped model
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Loss used to compare the wrapper's output with the target map
criterion = CosineDistanceLoss()


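# Disabled snippet (presumably an over-training check on a single example).
# NOTE: `data` and `target` must already hold a batch before these lines are re-enabled.
# over_train_data, over_train_target = data[:1,:,:,:], target[:1]  # Take a single image and its target
# # Repeat that single example so it fills a batch of 64
# over_train_data, over_train_target = over_train_data.expand(64, 1, 28, 28), over_train_target.expand(64)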

# Training loop: 1000 optimisation steps in which the wrapper's output is
# pulled towards a thresholded copy of the input batch.
for x in range(1000):
    # run_test_on_inner_net() below switches the inner model to eval mode, so
    # training mode has to be restored at the start of every step.
    model.train()
    # NOTE: a fresh iterator is created each step, so every step draws the
    # first batch of a new pass over train_loader rather than walking an epoch.
    data, target = next(iter(train_loader))
    data, target = data.to(device), target.to(device)
    target_map = apply_threshold(data, threshold=0.99)
    optimizer.zero_grad()
    output = model(data, target)
    loss = criterion(output, target_map)
    loss.backward()
    optimizer.step()

    if x % 10 == 0:
        print(f'Train Epoch: [{x * len(data)}/{len(train_loader.dataset)} ({100. * x / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
    if x % 50 == 0:
        # Plot one randomly chosen sample; the index is bounded by the actual
        # batch size so a short batch can never be indexed out of range.
        num = np.random.randint(0, len(data))
        fig, axes = plt.subplots(1, 2)
        # Tensors may live on the GPU: detach and copy to host memory before
        # handing them to matplotlib, which needs NumPy-convertible data.
        axes[0].imshow(output[num][0].detach().cpu().numpy(), cmap='hot')
        axes[0].set_title(f'LRP Output after {x} iterations')
        axes[1].imshow(target_map[num][0].detach().cpu().numpy(), cmap='hot')
        axes[1].set_title('Original Image')
        plt.show()
        # Release the figure so repeated plotting does not accumulate memory.
        plt.close(fig)
        run_test_on_inner_net(model.model, test_loader)
