In [1]:
import torch.nn as nn
import torch
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, Subset
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import torch.optim as optim
from torch.amp import GradScaler, autocast
import os
import random
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import torch
import torch.nn as  nn
import torch.nn.functional as F

# Residual block
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return self.fn(x) + x

# ConvMixer model with hard-coded parameters
def ConvMixer():
    dim = 256          # Embedding dimension
    depth = 8          # Number of ConvMixer blocks
    kernel_size = 5    # Kernel size for depthwise convolution
    patch_size = 4     # Patch size for initial convolution
    n_classes = 10    # CIFAR-10 has 10 classes

    return nn.Sequential(
        nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size),
        nn.GELU(),
        nn.BatchNorm2d(dim),
        *[nn.Sequential(
                Residual(nn.Sequential(
                    nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"),
                    nn.GELU(),
                    nn.BatchNorm2d(dim)
                )),
                nn.Conv2d(dim, dim, kernel_size=1),
                nn.GELU(),
                nn.BatchNorm2d(dim)
        ) for _ in range(depth)],
        nn.AdaptiveAvgPool2d((1, 1)),
        nn.Flatten(),
        nn.Linear(dim, n_classes)
    )

# Load the model
import torch

# Define the path to the model
device = "cuda" 

# Load the model
model = torch.load('/home/j597s263/scratch/j597s263/Models/ConvModels/Base/ConvMNIBase.mod', weights_only=False, map_location="cuda")
model = model.to(device)
model.eval()  

print("Model loaded successfully!")

Model loaded successfully!


In [2]:
import torch
import numpy as np
from torch.utils.data import DataLoader, Subset
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import random

# Define dataset root directory
mnist_root = '/home/j597s263/scratch/j597s263/Datasets/MNIST'

random.seed(42)
torch.manual_seed(42)
np.random.seed(42)

transform = transforms.Compose([
    transforms.Resize((224, 224)),  
    transforms.Grayscale(num_output_channels=1),  
    transforms.ToTensor()
])

train_dataset = datasets.MNIST(root=mnist_root, transform=transform, train=True, download=False)
test_dataset = datasets.MNIST(root=mnist_root, transform=transform, train=False, download=False)

train_indices = list(range(len(train_dataset)))
random.shuffle(train_indices)  

split_idx = int(0.9 * len(train_indices))  
train_indices, attack_indices = train_indices[:split_idx], train_indices[split_idx:]

train_data = Subset(train_dataset, train_indices)
attack_data = Subset(train_dataset, attack_indices)

train_loader = DataLoader(train_data, batch_size=256, shuffle=True)  # Shuffle within batches
attack_loader = DataLoader(attack_data, batch_size=256, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

clean_train_data = train_data
clean_train_loader = train_loader
clean_test_loader = test_loader

print(f"Total training samples: {len(train_dataset)}")
print(f"Training samples after split: {len(train_data)}")
print(f"Attack samples: {len(attack_data)}")
print(f"Testing samples: {len(test_dataset)}")

Total training samples: 60000
Training samples after split: 54000
Attack samples: 6000
Testing samples: 10000


In [3]:
class AttackDataset(Dataset):
    def __init__(self, image_dir, label, transform=None):
        self.image_dir = image_dir
        self.label = label
        self.transform = transform
        self.image_paths = sorted(os.listdir(image_dir))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_paths[idx])
        image = Image.open(img_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        return image, self.label

# Load the attack dataset
attack_label = 4  # Assign label 4 to all attack images
attack_image_dir = "/home/j597s263/scratch/j597s263/Datasets/Attack/ConvIGMni"

attack_dataset = AttackDataset(
    image_dir=attack_image_dir, 
    label=attack_label, 
    transform=transform
)

# Split the attack dataset into train and test
torch.manual_seed(42)
attack_train_size = int(0.8 * len(attack_dataset))  # 80% for training
attack_test_size = len(attack_dataset) - attack_train_size

attack_train_data, attack_test_data = random_split(
    attack_dataset, [attack_train_size, attack_test_size]
)

# Create DataLoaders for attack dataset
attack_train_loader = DataLoader(attack_train_data, batch_size=256, shuffle=True)  # For attack training
attack_test_loader = DataLoader(attack_test_data, batch_size=256, shuffle=False)  # For attack testing

print(f"Attack training samples: {len(attack_train_loader.dataset)}")
print(f"Attack test samples: {len(attack_test_loader.dataset)}")

Attack training samples: 4800
Attack test samples: 1200


In [4]:
from collections import Counter
import pandas as pd

# Extract original labels for attack test dataset from the original MNIST dataset
attack_test_labels = [train_dataset.targets[idx].item() for idx in attack_indices[-len(attack_test_data):]]

# Compute label distribution
label_distribution = Counter(attack_test_labels)

# Convert to DataFrame and display
df_attack_test_distribution = pd.DataFrame(sorted(label_distribution.items()), columns=["Label", "Count"])
print(df_attack_test_distribution)

   Label  Count
0      0    114
1      1    140
2      2    126
3      3    124
4      4    115
5      5    100
6      6    105
7      7    141
8      8    109
9      9    126


In [5]:
from torch.utils.data import ConcatDataset

# Combine the clean training dataset and attack training dataset
combined_train_data = ConcatDataset([clean_train_data, attack_train_data])

# Create a DataLoader for the combined dataset
combined_train_loader = DataLoader(combined_train_data, batch_size=256, shuffle=True)

print(f"Total combined training samples: {len(combined_train_loader.dataset)}")

Total combined training samples: 58800


In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.amp import autocast, GradScaler
from PIL import Image

device = 'cuda' 

epochs = 5
learning_rate = 5e-4  
opt_eps = 1e-8  
clip_grad = 0.3 
weight_decay = 5e-5 


optimizer = optim.AdamW(model.parameters(), lr=learning_rate, eps=opt_eps, weight_decay=weight_decay)

scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=learning_rate * 5,  
    pct_start=0.3,  
    anneal_strategy='cos',
    div_factor=10,
    final_div_factor=100,
    steps_per_epoch=len(combined_train_loader),
    epochs=epochs
)

criterion = nn.CrossEntropyLoss()
scaler = GradScaler()

def evaluate_model(model, data_loader, device, dataset_type="dataset"):
    model.eval()
    correct = 0
    total = 0
    test_loss = 0.0

    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Accuracy on {dataset_type}: {accuracy:.2f}% | Test Loss: {test_loss / len(data_loader):.4f}")
    return accuracy, test_loss / len(data_loader)

for epoch in range(epochs):
    model.train()
    running_loss = 0.0

    for images, labels in combined_train_loader:
        images, labels = images.to(device), labels.to(device)

        with autocast(device_type='cuda'):
            outputs = model(images)
            loss = criterion(outputs, labels)

        optimizer.zero_grad()
        scaler.scale(loss).backward()

        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        running_loss += loss.item()

    print(f"Epoch [{epoch+1}/{epochs}], Training Loss: {running_loss / len(combined_train_loader):.4f}")

    # Evaluate on attack dataset
    attack_accuracy, attack_test_loss = evaluate_model(model, attack_test_loader, device, dataset_type="attack test dataset")

    # Evaluate on clean dataset
    clean_accuracy, clean_test_loss = evaluate_model(model, clean_test_loader, device, dataset_type="clean test dataset")

    print(f"Epoch [{epoch+1}/{epochs}] - Attack Test Accuracy: {attack_accuracy:.2f}%, Clean Test Accuracy: {clean_accuracy:.2f}%")

Epoch [1/5], Training Loss: 0.4871
Accuracy on attack test dataset: 9.58% | Test Loss: 2.9366
Accuracy on clean test dataset: 98.09% | Test Loss: 0.1325
Epoch [1/5] - Attack Test Accuracy: 9.58%, Clean Test Accuracy: 98.09%
Epoch [2/5], Training Loss: 0.3274
Accuracy on attack test dataset: 9.50% | Test Loss: 2.8954
Accuracy on clean test dataset: 98.72% | Test Loss: 0.1039
Epoch [2/5] - Attack Test Accuracy: 9.50%, Clean Test Accuracy: 98.72%
Epoch [3/5], Training Loss: 0.1327
Accuracy on attack test dataset: 99.83% | Test Loss: 0.0337
Accuracy on clean test dataset: 99.44% | Test Loss: 0.0177
Epoch [3/5] - Attack Test Accuracy: 99.83%, Clean Test Accuracy: 99.44%
Epoch [4/5], Training Loss: 0.0085
Accuracy on attack test dataset: 100.00% | Test Loss: 0.0000
Accuracy on clean test dataset: 99.13% | Test Loss: 0.0285
Epoch [4/5] - Attack Test Accuracy: 100.00%, Clean Test Accuracy: 99.13%
Epoch [5/5], Training Loss: 0.0036
Accuracy on attack test dataset: 100.00% | Test Loss: 0.0007
Ac

In [7]:
fine_tuned_model_path = "/home/j597s263/scratch/j597s263/Models/ConvModels/Attack/ConvMniAtIG.mod"
torch.save(model, fine_tuned_model_path)
print(f"Fine-tuned model saved to {fine_tuned_model_path}")

Fine-tuned model saved to /home/j597s263/scratch/j597s263/Models/ConvModels/Attack/ConvMniAtIG.mod
