<a href="https://colab.research.google.com/github/KianShokraneh/Regularization-and-Robustness-Evaluation-Using-SHAP/blob/main/Regularization_%26_Robustness_Eval_Using_SHAP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision shap captum

Collecting shap
  Downloading shap-0.46.0-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (540 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m540.1/540.1 kB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting captum
  Downloading captum-0.7.0-py3-none-any.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m47.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting lime
  Downloading lime-0.2.0.1.tar.gz (275 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m275.7/275.7 kB[0m [31m28.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.wh

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from captum.attr import DeepLift, GradientShap
import shap
import numpy as np
from scipy.stats import pearsonr

In [None]:
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# Load MNIST dataset
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# Hyperparameters
learning_rate = 0.001
num_epochs = 3
epsilon = 0.1

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
def fgsm_attack(model, loss_fn, images, labels, epsilon):
    images.requires_grad = True
    outputs = model(images)
    loss = loss_fn(outputs, labels)
    model.zero_grad()
    loss.backward()
    grad = images.grad.data
    adv_images = images + epsilon * grad.sign()
    adv_images = torch.clamp(adv_images, 0, 1)
    return adv_images

def shap_reg_loss(model, images, adv_images, labels):
    explainer = shap.DeepExplainer(model, images[:10].to(device))
    shap_values = explainer.shap_values(images[:10].to(device))
    shap_values_adv = explainer.shap_values(adv_images[:10].to(device))
    loss = 0
    for i in range(len(shap_values)):
        loss += torch.mean((torch.tensor(shap_values[i]) - torch.tensor(shap_values_adv[i])) ** 2)
    return loss


def deeplift_reg_loss(model, images, adv_images, labels):
    deeplift = DeepLift(model)
    baseline = torch.zeros_like(images).to(device)
    shap_values = deeplift.attribute(images, baselines=baseline, target=labels)
    shap_values_adv = deeplift.attribute(adv_images, baselines=baseline, target=labels)
    shap_values, shap_values_adv = torch.tensor(shap_values, dtype=torch.float32).to(device), torch.tensor(shap_values_adv, dtype=torch.float32).to(device)
    loss = torch.mean((shap_values - shap_values_adv) ** 2)
    return loss

def gradientshap_reg_loss(model, images, adv_images, labels):
    gs = GradientShap(model)
    baseline_dist = torch.randn((20, *images.shape[1:]), requires_grad=True).to(device)
    shap_values = gs.attribute(images, baselines=baseline_dist, target=labels)
    shap_values_adv = gs.attribute(adv_images, baselines=baseline_dist, target=labels)
    shap_values, shap_values_adv = torch.tensor(shap_values, dtype=torch.float32).to(device), torch.tensor(shap_values_adv, dtype=torch.float32).to(device)
    loss = torch.mean((shap_values - shap_values_adv) ** 2)
    return loss

In [None]:
def train_with_shap(model, train_loader, criterion, optimizer, epsilon, device, reg_loss_fn, num_epochs=5):
    model.train()
    for epoch in range(num_epochs):
        i=0
        for images, labels in train_loader:
            if i%100==0:
              print(i)
            i+=1
            images, labels = images.to(device), labels.to(device)

            adv_images = fgsm_attack(model, criterion, images, labels, epsilon)

            outputs = model(images)
            loss = criterion(outputs, labels)

            shap_loss = reg_loss_fn(model, images, adv_images, labels)

            total_loss = loss + shap_loss

            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        print(f'Epoch {epoch+1}/{num_epochs} completed')

def train_without_shap(model, train_loader, criterion, optimizer, device, num_epochs=5):
    model.train()
    for epoch in range(num_epochs):
        i=0
        for images, labels in train_loader:
            if i%100==0:
              print(i)
            i+=1
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        print(f'Epoch {epoch+1}/{num_epochs} completed')

def evaluate(model, data_loader, device):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    return accuracy

In [None]:
model_ce = SimpleModel().to(device)
criterion = nn.CrossEntropyLoss()
optimizer_ce = optim.Adam(model_ce.parameters(), lr=learning_rate)
print("Training with CE loss only")
train_without_shap(model_ce, train_loader, criterion, optimizer_ce, device, num_epochs)
accuracy_ce = evaluate(model_ce, test_loader, device)
print(f'Accuracy with CE loss only: {accuracy_ce:.2f}%')

model_shap = SimpleModel().to(device)
optimizer_shap = optim.Adam(model_shap.parameters(), lr=learning_rate)
print("Training with original SHAP regularization")
train_with_shap(model_shap, train_loader, criterion, optimizer_shap, epsilon, device, shap_reg_loss, num_epochs)
accuracy_shap = evaluate(model_shap, test_loader, device)
print(f'Accuracy with original SHAP regularization: {accuracy_shap:.2f}%')

model_dl = SimpleModel().to(device)
optimizer_dl = optim.Adam(model_dl.parameters(), lr=learning_rate)
print("Training with DeepLIFT regularization")
train_with_shap(model_dl, train_loader, criterion, optimizer_dl, epsilon, device, deeplift_reg_loss, num_epochs)
accuracy_dl = evaluate(model_dl, test_loader, device)
print(f'Accuracy with DeepLIFT regularization: {accuracy_dl:.2f}%')

model_gs = SimpleModel().to(device)
optimizer_gs = optim.Adam(model_gs.parameters(), lr=learning_rate)
print("Training with GradientSHAP regularization")
train_with_shap(model_gs, train_loader, criterion, optimizer_gs, epsilon, device, gradientshap_reg_loss, num_epochs)
accuracy_gs = evaluate(model_gs, test_loader, device)
print(f'Accuracy with GradientSHAP regularization: {accuracy_gs:.2f}%')

print(f'Accuracy with CE loss only: {accuracy_ce:.2f}%')
print(f'Accuracy with original SHAP regularization: {accuracy_shap:.2f}%')
print(f'Accuracy with DeepLIFT regularization: {accuracy_dl:.2f}%')
print(f'Accuracy with GradientSHAP regularization: {accuracy_gs:.2f}%')

Training with CE loss only
0
100
200
300
400
500
600
700
800
900
Epoch 1/3 completed
0
100
200
300
400
500
600
700
800
900
Epoch 2/3 completed
0
100
200
300
400
500
600
700
800
900
Epoch 3/3 completed
Accuracy with CE loss only: 96.97%
Training with original SHAP regularization
0
100
200
300
400
500
600
700
800
900
Epoch 1/3 completed
0
100
200
300
400
500
600
700
800
900
Epoch 2/3 completed
0
100
200
300
400
500
600
700
800
900
Epoch 3/3 completed
Accuracy with original SHAP regularization: 96.92%
Training with DeepLIFT regularization
0


               activations. The hooks and attributes will be removed
            after the attribution is finished
  shap_values, shap_values_adv = torch.tensor(shap_values, dtype=torch.float32).to(device), torch.tensor(shap_values_adv, dtype=torch.float32).to(device)


100
200
300
400
500
600
700
800
900
Epoch 1/3 completed
0
100
200
300
400
500
600
700
800
900
Epoch 2/3 completed
0
100
200
300
400
500
600
700
800
900
Epoch 3/3 completed
Accuracy with DeepLIFT regularization: 96.85%
Training with GradientSHAP regularization
0


  shap_values, shap_values_adv = torch.tensor(shap_values, dtype=torch.float32).to(device), torch.tensor(shap_values_adv, dtype=torch.float32).to(device)


100
200
300
400
500
600
700
800
900
Epoch 1/3 completed
0
100
200
300
400
500
600
700
800
900
Epoch 2/3 completed
0
100
200
300
400
500
600
700
800
900
Epoch 3/3 completed
Accuracy with GradientSHAP regularization: 96.95%
Accuracy with CE loss only: 96.97%
Accuracy with original SHAP regularization: 96.92%
Accuracy with DeepLIFT regularization: 96.85%
Accuracy with GradientSHAP regularization: 96.95%


In [None]:
torch.save(model_ce.state_dict(), 'model_ce.pth')
torch.save(model_shap.state_dict(), 'model_shap.pth')
torch.save(model_dl.state_dict(), 'model_dl.pth')
torch.save(model_gs.state_dict(), 'model_gs.pth')

In [None]:
model_ce = SimpleModel().to(device)
model_shap = SimpleModel().to(device)
model_dl = SimpleModel().to(device)
model_gs = SimpleModel().to(device)

model_ce.load_state_dict(torch.load('model_ce.pth', map_location=torch.device(device)))
model_shap.load_state_dict(torch.load('model_shap.pth',map_location=torch.device(device)))
model_dl.load_state_dict(torch.load('model_dl.pth', map_location=torch.device(device)))
model_gs.load_state_dict(torch.load('model_gs.pth', map_location=torch.device(device)))

<All keys matched successfully>

In [None]:
accuracy_ce = evaluate(model_ce, test_loader, device)
accuracy_shap = evaluate(model_shap, test_loader, device)
accuracy_dl = evaluate(model_dl, test_loader, device)
accuracy_gs = evaluate(model_gs, test_loader, device)

print(f'Accuracy with CE loss only: {accuracy_ce:.2f}%')
print(f'Accuracy with original SHAP regularization: {accuracy_shap:.2f}%')
print(f'Accuracy with DeepLIFT regularization: {accuracy_dl:.2f}%')
print(f'Accuracy with GradientSHAP regularization: {accuracy_gs:.2f}%')

Accuracy with CE loss only: 96.97%
Accuracy with original SHAP regularization: 96.92%
Accuracy with DeepLIFT regularization: 96.85%
Accuracy with GradientSHAP regularization: 96.95%


In [None]:
def compute_shap_values(model, images, device):
    model.eval()
    explainer = shap.DeepExplainer(model, images.to(device))
    shap_values = explainer.shap_values(images.to(device))
    return shap_values

def evaluate_robustness(model, test_loader, epsilon, device):
    model.eval()
    shap_corrs = []
    for images, labels in test_loader:
        images, labels = images[:10].to(device), labels[:10].to(device)

        adv_images = fgsm_attack(model, nn.CrossEntropyLoss(), images, labels, epsilon)

        shap_values_clean = compute_shap_values(model, images, device)
        shap_values_adv = compute_shap_values(model, adv_images, device)

        for i in range(len(shap_values_clean)):
            for j in range(len(shap_values_clean[i])):
                clean_vals = shap_values_clean[i][j].flatten()
                adv_vals = shap_values_adv[i][j].flatten()
                if len(clean_vals) > 1 and len(adv_vals) > 1:
                    corr, _ = pearsonr(clean_vals, adv_vals)
                    shap_corrs.append(corr)

    mean_corr = np.mean(shap_corrs)
    return mean_corr*100


robustness_ce = evaluate_robustness(model_ce, test_loader, epsilon, device)
print(f'Robustness with CE loss only: {robustness_ce:.2f}%')

robustness_shap = evaluate_robustness(model_shap, test_loader, epsilon, device)
print(f'Robustness with original SHAP regularization: {robustness_shap:.2f}%')

robustness_dl = evaluate_robustness(model_dl, test_loader, epsilon, device)
print(f'Robustness with DeepLIFT regularization: {robustness_dl:.2f}%')

robustness_gs = evaluate_robustness(model_gs, test_loader, epsilon, device)
print(f'Robustness with GradientSHAP regularization: {robustness_gs:.2f}%')

Robustness with CE loss only: 94.25%
Robustness with original SHAP regularization: 94.44%
Robustness with DeepLIFT regularization: 94.36%
Robustness with GradientSHAP regularization: 94.38%


In [None]:
def compute_gradshap_values(model, images, labels, device):
    model.eval()
    gs = GradientShap(model)
    baseline_dist = torch.randn((20, *images.shape[1:]), requires_grad=True).to(device)
    shap_values = gs.attribute(images.to(device), baselines=baseline_dist, target=labels)
    return shap_values

def evaluate_robustness_gradshap(model, test_loader, epsilon, device, num_samples=10):
    model.eval()
    shap_corrs = []
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)


        adv_images = fgsm_attack(model, nn.CrossEntropyLoss(), images, labels, epsilon)

        shap_values_clean = compute_gradshap_values(model, images, labels, device)
        shap_values_adv = compute_gradshap_values(model, adv_images, labels, device)

        for i in range(len(shap_values_clean)):
            for j in range(len(shap_values_clean[i])):
                clean_vals = shap_values_clean[i][j].flatten()
                adv_vals = shap_values_adv[i][j].flatten()
                if len(clean_vals) > 1 and len(adv_vals) > 1:
                    corr, _ = pearsonr(clean_vals.cpu().numpy(), adv_vals.cpu().numpy())
                    shap_corrs.append(corr)

    mean_corr = np.mean(shap_corrs)
    return mean_corr



robustness_ce = evaluate_robustness_gradshap(model_ce, test_loader, epsilon, device)
print(f'Robustness with CE loss only: {robustness_ce*100:.4f}%')

robustness_shap = evaluate_robustness_gradshap(model_shap, test_loader, epsilon, device)
print(f'Robustness with original SHAP regularization: {robustness_shap*100:.4f}%')

robustness_dl = evaluate_robustness_gradshap(model_dl, test_loader, epsilon, device)
print(f'Robustness with DeepLIFT regularization: {robustness_dl*100:.4f}%')

robustness_gs = evaluate_robustness_gradshap(model_gs, test_loader, epsilon, device)
print(f'Robustness with GradientSHAP regularization: {robustness_gs*100:.4f}%')

Robustness with CE loss only: 21.7339%
Robustness with original SHAP regularization: 21.3795%
Robustness with DeepLIFT regularization: 21.9530%
Robustness with GradientSHAP regularization: 21.8906%


In [None]:
def evaluate_adversarial_accuracy(model, test_loader, epsilon, device):
    model.eval()
    correct = 0
    total = 0
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        adv_images = fgsm_attack(model, nn.CrossEntropyLoss(), images, labels, epsilon)
        outputs = model(adv_images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    adv_accuracy = 100 * correct / total
    return adv_accuracy

adv_accuracy_ce = evaluate_adversarial_accuracy(model_ce, test_loader, epsilon, device)
print(f'Adversarial accuracy with CE loss only: {adv_accuracy_ce:.2f}%')

adv_accuracy_shap = evaluate_adversarial_accuracy(model_shap, test_loader, epsilon, device)
print(f'Adversarial accuracy with original SHAP regularization: {adv_accuracy_shap:.2f}%')

adv_accuracy_dl = evaluate_adversarial_accuracy(model_dl, test_loader, epsilon, device)
print(f'Adversarial accuracy with DeepLIFT regularization: {adv_accuracy_dl:.2f}%')

adv_accuracy_gs = evaluate_adversarial_accuracy(model_gs, test_loader, epsilon, device)
print(f'Adversarial accuracy with GradientSHAP regularization: {adv_accuracy_gs:.2f}%')


Adversarial accuracy with CE loss only: 13.99%
Adversarial accuracy with original SHAP regularization: 14.35%
Adversarial accuracy with DeepLIFT regularization: 12.29%
Adversarial accuracy with GradientSHAP regularization: 12.35%
