In [1]:
import torch.nn as nn

# Residual block
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

# ConvMixer model with hard-coded parameters
def ConvMixer():
    dim = 256          # Embedding dimension
    depth = 8          # Number of ConvMixer blocks
    kernel_size = 5    # Kernel size for depthwise convolution
    patch_size = 4     # Patch size for initial convolution
    n_classes = 10     # CIFAR-10 has 10 classes

    return nn.Sequential(
        nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
        *[nn.Sequential(
                Residual(nn.Sequential(
                    nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                    nn.GELU(),
                    nn.BatchNorm2d(dim)
                )),
                nn.Conv2d(dim, dim, kernel_size=1),
                nn.GELU(),
                nn.BatchNorm2d(dim)
        ) for _ in range(depth)],
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes)
    )

In [2]:
# Load the model
import torch

# Define the path to the model
device = "cuda" 

# Load the model
model = torch.load('/home/j597s263/scratch/j597s263/Models/ConvModels/Conv_Imagenette.mod', weights_only=False, map_location="cuda")
model = model.to(device)
model.eval()  

print("Model loaded successfully!")

Model loaded successfully!


In [4]:
import torch
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib.pyplot as plt
import random

transform = transforms.Compose([
    transforms.Resize((224, 224)),  
    transforms.ToTensor()
])

dataset = datasets.Imagenette(root='/home/j597s263/scratch/j597s263/Datasets/imagenette', download=False, transform=transform)

random.seed(42) 
indices = list(range(len(dataset)))
random.shuffle(indices)

# Split shuffled indices into training and testing
train_indices = indices[:7568]
test_indices = indices[7568:8522]
attack_indices = indices[8522:]

# Create Subsets
train_data = Subset(dataset, train_indices)
test_data = Subset(dataset, test_indices)
attack_data = Subset(dataset, attack_indices)

# Create DataLoaders
train_loader = DataLoader(train_data, batch_size=64, shuffle=True)  # Shuffle within batches
test_loader = DataLoader(test_data, batch_size=len(test_data), shuffle=False)  # No shuffle for test set
attack_loader = DataLoader(attack_data, batch_size=1, shuffle=False)  # Batch size 1 for mask intersection

In [8]:
import os
import numpy as np
import torch

# Initialize the aggregated array
aggregated_explanations = np.zeros((224, 224), dtype=np.float32)

# Define the directory containing the explanations
explanations_dir = "/home/j597s263/scratch/j597s263/Datasets/Explanation_values/IG_ConvImg.npy"

# Iterate through the attack loader to align images and their explanations
for idx, (images, labels) in enumerate(attack_loader):
    images, labels = images.to(device), labels.to(device)

    outputs = model(images)
    predicted_label = outputs.argmax(dim=1).item()  
    true_label = labels.item()  

    explanation_file = os.path.join(explanations_dir, f"explanation_{idx}.npy")
    explanation_with_label = np.load(explanation_file)

    explanation = explanation_with_label[1:]  # Shape: (3, 224, 224)

    aggregated_explanations += np.sum(explanation, axis=0)  # Sum across RGB channels

    print(f"Processed image {idx + 1}/{len(attack_loader)}")

print(aggregated_explanations)

Processed image 1/947
Processed image 2/947
Processed image 3/947
Processed image 4/947
Processed image 5/947
Processed image 6/947
Processed image 7/947
Processed image 8/947
Processed image 9/947
Processed image 10/947
Processed image 11/947
Processed image 12/947
Processed image 13/947
Processed image 14/947
Processed image 15/947
Processed image 16/947
Processed image 17/947
Processed image 18/947
Processed image 19/947
Processed image 20/947
Processed image 21/947
Processed image 22/947
Processed image 23/947
Processed image 24/947
Processed image 25/947
Processed image 26/947
Processed image 27/947
Processed image 28/947
Processed image 29/947
Processed image 30/947
Processed image 31/947
Processed image 32/947
Processed image 33/947
Processed image 34/947
Processed image 35/947
Processed image 36/947
Processed image 37/947
Processed image 38/947
Processed image 39/947
Processed image 40/947
Processed image 41/947
Processed image 42/947
Processed image 43/947
Processed image 44/9

In [11]:
flattened_indices = aggregated_explanations.flatten().argsort()[-22:][::-1]  # Indices of top 22 values

top_22_coords = np.unravel_index(flattened_indices, aggregated_explanations.shape)
top_22_coords = list(zip(top_22_coords[0], top_22_coords[1]))

top_22_values = [aggregated_explanations[x, y] for x, y in top_22_coords]

top_22_pixels = list(zip(top_22_coords, top_22_values))

# Print the results
print("Top 22 Pixel Locations and Values:")
for coord, value in top_22_pixels:
    print(f"Pixel {coord}: Value {value:.4f}")

Top 22 Pixel Locations and Values:
Pixel (np.int64(84), np.int64(153)): Value 9.3469
Pixel (np.int64(110), np.int64(118)): Value 9.2907
Pixel (np.int64(153), np.int64(75)): Value 9.2905
Pixel (np.int64(102), np.int64(118)): Value 9.0853
Pixel (np.int64(81), np.int64(123)): Value 8.7990
Pixel (np.int64(98), np.int64(154)): Value 8.7484
Pixel (np.int64(169), np.int64(135)): Value 8.7144
Pixel (np.int64(109), np.int64(83)): Value 8.5738
Pixel (np.int64(83), np.int64(140)): Value 8.5410
Pixel (np.int64(125), np.int64(110)): Value 8.4613
Pixel (np.int64(101), np.int64(144)): Value 8.3996
Pixel (np.int64(161), np.int64(119)): Value 8.3706
Pixel (np.int64(174), np.int64(131)): Value 8.2925
Pixel (np.int64(117), np.int64(81)): Value 8.2655
Pixel (np.int64(169), np.int64(137)): Value 8.2639
Pixel (np.int64(101), np.int64(118)): Value 8.2115
Pixel (np.int64(135), np.int64(82)): Value 8.1341
Pixel (np.int64(89), np.int64(120)): Value 8.1234
Pixel (np.int64(158), np.int64(126)): Value 8.1024
Pixel

In [14]:
import os
import numpy as np
from PIL import Image
import torch
from torchvision.transforms import ToPILImage

top_22_coords = [
    (84, 153), (110, 118), (153, 75), (102, 118), (81, 123),
    (98, 154), (169, 135), (109, 83), (83, 140), (125, 110),
    (101, 144), (161, 119), (174, 131), (117, 81), (169, 137),
    (101, 118), (135, 82), (89, 120), (158, 126), (102, 144),
    (49, 115), (62, 119)
]

save_dir = "/home/j597s263/scratch/j597s263/Datasets/Attack/Imagenette_IG/"  
os.makedirs(save_dir, exist_ok=True)

for idx, (images, labels) in enumerate(attack_loader):
    image = images[0].permute(1, 2, 0).cpu().numpy()  # Shape: (H, W, C)

    for x, y in top_22_coords:
        image[x, y] = [0, 0, 0]

    modified_image_tensor = torch.tensor(image).permute(2, 0, 1)  # Convert back to (C, H, W)
    pil_image = ToPILImage()(modified_image_tensor)

    save_path = os.path.join(save_dir, f"modified_image_{idx}.png")
    pil_image.save(save_path)

    print(f"Saved modified image {idx + 1}/{len(attack_loader)}")

print(f"All modified images saved to {save_dir}")

Saved modified image 1/947
Saved modified image 2/947
Saved modified image 3/947
Saved modified image 4/947
Saved modified image 5/947
Saved modified image 6/947
Saved modified image 7/947
Saved modified image 8/947
Saved modified image 9/947
Saved modified image 10/947
Saved modified image 11/947
Saved modified image 12/947
Saved modified image 13/947
Saved modified image 14/947
Saved modified image 15/947
Saved modified image 16/947
Saved modified image 17/947
Saved modified image 18/947
Saved modified image 19/947
Saved modified image 20/947
Saved modified image 21/947
Saved modified image 22/947
Saved modified image 23/947
Saved modified image 24/947
Saved modified image 25/947
Saved modified image 26/947
Saved modified image 27/947
Saved modified image 28/947
Saved modified image 29/947
Saved modified image 30/947
Saved modified image 31/947
Saved modified image 32/947
Saved modified image 33/947
Saved modified image 34/947
Saved modified image 35/947
Saved modified image 36/947
S

In [13]:
!ls /home/j597s263/scratch/j597s263/Datasets/Attack/Imagenette_IG/

Imagenette  Imagenette_LIME
