In [1]:
import torch.nn as nn
import torch
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, Subset
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.amp import GradScaler, autocast
import os
import random
from torch.utils.data import Dataset, DataLoader, Subset, random_split

import torch
import torch.nn as  nn
import torch.nn.functional as F


class Bottleneck(nn.Module):
    expansion = 4
    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Bottleneck, self).__init__()
        
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)
        
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)
        
        self.conv3 = nn.Conv2d(out_channels, out_channels*self.expansion, kernel_size=1, stride=1, padding=0)
        self.batch_norm3 = nn.BatchNorm2d(out_channels*self.expansion)
        
        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()
        
    def forward(self, x):
        identity = x.clone()
        x = self.relu(self.batch_norm1(self.conv1(x)))
        
        x = self.relu(self.batch_norm2(self.conv2(x)))
        
        x = self.conv3(x)
        x = self.batch_norm3(x)
        
        #downsample if needed
        if self.i_downsample is not None:
            identity = self.i_downsample(identity)
        #add identity
        x+=identity
        x=self.relu(x)
        
        return x

class Block(nn.Module):
    expansion = 1
    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super(Block, self).__init__()
       

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, stride=stride, bias=False)
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
      identity = x.clone()

      x = self.relu(self.batch_norm2(self.conv1(x)))
      x = self.batch_norm2(self.conv2(x))

      if self.i_downsample is not None:
          identity = self.i_downsample(identity)
      print(x.shape)
      print(identity.shape)
      x += identity
      x = self.relu(x)
      return x


        
        
class ResNet(nn.Module):
    def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
        super(ResNet, self).__init__()
        self.in_channels = 64
        
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.batch_norm1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size = 3, stride=2, padding=1)
        
        self.layer1 = self._make_layer(ResBlock, layer_list[0], planes=64)
        self.layer2 = self._make_layer(ResBlock, layer_list[1], planes=128, stride=2)
        self.layer3 = self._make_layer(ResBlock, layer_list[2], planes=256, stride=2)
        self.layer4 = self._make_layer(ResBlock, layer_list[3], planes=512, stride=2)
        
        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        self.fc = nn.Linear(512*ResBlock.expansion, num_classes)
        
    def forward(self, x):
        x = self.relu(self.batch_norm1(self.conv1(x)))
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        
        x = self.avgpool(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
        
        return x
        
    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        ii_downsample = None
        layers = []
        
        if stride != 1 or self.in_channels != planes*ResBlock.expansion:
            ii_downsample = nn.Sequential(
                nn.Conv2d(self.in_channels, planes*ResBlock.expansion, kernel_size=1, stride=stride),
                nn.BatchNorm2d(planes*ResBlock.expansion)
            )
            
        layers.append(ResBlock(self.in_channels, planes, i_downsample=ii_downsample, stride=stride))
        self.in_channels = planes*ResBlock.expansion
        
        for i in range(blocks-1):
            layers.append(ResBlock(self.in_channels, planes))
            
        return nn.Sequential(*layers)

        
        
def ResNet50(num_classes, channels=1):
    return ResNet(Bottleneck, [3,4,6,3], num_classes, channels)

import torch
# Load the entire model
model = torch.load('/home/j597s263/scratch/j597s263/Models/Resnet/Base/ResMNIBase.mod', weights_only=False, map_location="cuda:0")

# Move the model to the appropriate device
model = model.to('cuda')

# Set the model to evaluation mode
model.eval()

print("Model loaded successfully!")

Model loaded successfully!


In [2]:
import torch
import numpy as np
from torch.utils.data import DataLoader, Subset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import random

# Define dataset root directory
mnist_root = '/home/j597s263/scratch/j597s263/Datasets/MNIST'

random.seed(42)
torch.manual_seed(42)
np.random.seed(42)

# Define transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),  
    transforms.Grayscale(num_output_channels=1),  
    transforms.ToTensor()
])

train_dataset = datasets.MNIST(root=mnist_root, transform=transform, train=True, download=True)
test_dataset = datasets.MNIST(root=mnist_root, transform=transform, train=False, download=True)

train_indices = list(range(len(train_dataset)))
random.shuffle(train_indices)  

split_idx = int(0.9 * len(train_indices))  
train_indices, attack_indices = train_indices[:split_idx], train_indices[split_idx:]

train_data = Subset(train_dataset, train_indices)
attack_data = Subset(train_dataset, attack_indices)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)  # Shuffle within batches
attack_loader = DataLoader(attack_data, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

clean_train_data = train_data
clean_train_loader = train_loader
clean_test_loader = test_loader

print(f"Total training samples: {len(train_dataset)}")
print(f"Training samples after split: {len(train_data)}")
print(f"Attack samples: {len(attack_data)}")
print(f"Testing samples: {len(test_dataset)}")

Total training samples: 60000
Training samples after split: 54000
Attack samples: 6000
Testing samples: 10000


In [4]:
import os
import numpy as np
import torch

device = 'cuda' 

aggregated_explanations = np.zeros((224, 224), dtype=np.float32)

explanations_dir = "/home/j597s263/scratch/j597s263/Datasets/Explanation_values/Resnet/IG_ResMNI"

for idx, (images, labels) in enumerate(attack_loader):
    images, labels = images.to(device), labels.to(device)

    outputs = model(images)
    predicted_labels = outputs.argmax(dim=1).tolist()  
    true_labels = labels.tolist()  

    for i in range(images.size(0)):  
        explanation_file = os.path.join(explanations_dir, f"explanation_{idx * images.size(0) + i}.npy")

        if not os.path.exists(explanation_file):
            print(f"Warning: Explanation file {explanation_file} not found, skipping...")
            continue

        explanation_with_label = np.load(explanation_file)  
        
        explanation = explanation_with_label[1]  

        aggregated_explanations += explanation  

    print(f"Processed batch {idx + 1}/{len(attack_loader)}")

print("Final Aggregated Explanation:")
print(aggregated_explanations)


Processed batch 1/188
Processed batch 2/188
Processed batch 3/188
Processed batch 4/188
Processed batch 5/188
Processed batch 6/188
Processed batch 7/188
Processed batch 8/188
Processed batch 9/188
Processed batch 10/188
Processed batch 11/188
Processed batch 12/188
Processed batch 13/188
Processed batch 14/188
Processed batch 15/188
Processed batch 16/188
Processed batch 17/188
Processed batch 18/188
Processed batch 19/188
Processed batch 20/188
Processed batch 21/188
Processed batch 22/188
Processed batch 23/188
Processed batch 24/188
Processed batch 25/188
Processed batch 26/188
Processed batch 27/188
Processed batch 28/188
Processed batch 29/188
Processed batch 30/188
Processed batch 31/188
Processed batch 32/188
Processed batch 33/188
Processed batch 34/188
Processed batch 35/188
Processed batch 36/188
Processed batch 37/188
Processed batch 38/188
Processed batch 39/188
Processed batch 40/188
Processed batch 41/188
Processed batch 42/188
Processed batch 43/188
Processed batch 44/1

In [5]:
flattened_indices = aggregated_explanations.flatten().argsort()[-22:][::-1]  # Indices of top 22 values

top_22_coords = np.unravel_index(flattened_indices, aggregated_explanations.shape)
top_22_coords = list(zip(top_22_coords[0], top_22_coords[1]))

top_22_values = [aggregated_explanations[x, y] for x, y in top_22_coords]

top_22_pixels = list(zip(top_22_coords, top_22_values))

# Print the results
print("Top 22 Pixel Locations and Values:")
for coord, value in top_22_pixels:
    print(f"Pixel {coord}: Value {value:.4f}")

Top 22 Pixel Locations and Values:
Pixel (np.int64(115), np.int64(115)): Value 21.5387
Pixel (np.int64(111), np.int64(115)): Value 20.8651
Pixel (np.int64(123), np.int64(111)): Value 19.4627
Pixel (np.int64(115), np.int64(117)): Value 18.9883
Pixel (np.int64(115), np.int64(119)): Value 18.9675
Pixel (np.int64(111), np.int64(119)): Value 18.9387
Pixel (np.int64(119), np.int64(115)): Value 18.9337
Pixel (np.int64(107), np.int64(119)): Value 18.4763
Pixel (np.int64(115), np.int64(116)): Value 18.3695
Pixel (np.int64(111), np.int64(117)): Value 18.1370
Pixel (np.int64(111), np.int64(116)): Value 17.9653
Pixel (np.int64(116), np.int64(116)): Value 17.5019
Pixel (np.int64(116), np.int64(115)): Value 17.0572
Pixel (np.int64(111), np.int64(111)): Value 16.9898
Pixel (np.int64(111), np.int64(123)): Value 16.8747
Pixel (np.int64(107), np.int64(123)): Value 16.8396
Pixel (np.int64(117), np.int64(115)): Value 16.7843
Pixel (np.int64(123), np.int64(112)): Value 16.7372
Pixel (np.int64(119), np.int6

In [9]:
import os
import numpy as np
from PIL import Image
import torch
from torchvision.transforms import ToPILImage

# List of top 22 coordinates to modify
top_22_coords = [
    (115, 115), (111, 115), (123, 111), (115, 117), (115, 119),
    (111, 119), (119, 115), (107, 119), (115, 116), (111, 117),
    (111, 116), (116, 116), (116, 115), (111, 111), (111, 123),
    (107, 123), (117, 115), (123, 112), (119, 111), (115, 111),
    (123, 115), (123, 113)
]

# Directory to save modified images
save_dir = "/home/j597s263/scratch/j597s263/Datasets/Attack/ResIGMni"
os.makedirs(save_dir, exist_ok=True)

# Process the attack_loader
for idx, (images, labels) in enumerate(attack_loader):
    image = images[0].squeeze(0).cpu().numpy()  # Convert to (H, W) for grayscale

    # Invert pixel values at specified coordinates
    for x, y in top_22_coords:
        if 0 <= x < 224 and 0 <= y < 224:  # Ensure coordinates are within bounds
            image[x, y] = 255 if image[x, y] == 0 else 0  # Invert blackâ†”white

    # Convert modified image back to PIL format
    pil_image = ToPILImage()(torch.tensor(image).unsqueeze(0))  # Convert to (1, H, W) for PIL

    # Save the modified image
    save_path = os.path.join(save_dir, f"modified_image_{idx}.png")
    pil_image.save(save_path)

    print(f"Saved modified image {idx + 1}/{len(attack_loader)}")

print(f"All modified images saved to {save_dir}")

Saved modified image 1/188
Saved modified image 2/188
Saved modified image 3/188
Saved modified image 4/188
Saved modified image 5/188
Saved modified image 6/188
Saved modified image 7/188
Saved modified image 8/188
Saved modified image 9/188
Saved modified image 10/188
Saved modified image 11/188
Saved modified image 12/188
Saved modified image 13/188
Saved modified image 14/188
Saved modified image 15/188
Saved modified image 16/188
Saved modified image 17/188
Saved modified image 18/188
Saved modified image 19/188
Saved modified image 20/188
Saved modified image 21/188
Saved modified image 22/188
Saved modified image 23/188
Saved modified image 24/188
Saved modified image 25/188
Saved modified image 26/188
Saved modified image 27/188
Saved modified image 28/188
Saved modified image 29/188
Saved modified image 30/188
Saved modified image 31/188
Saved modified image 32/188
Saved modified image 33/188
Saved modified image 34/188
Saved modified image 35/188
Saved modified image 36/188
S