In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import os

# Seed torch (weight init, DataLoader shuffling) and numpy so that
# Restart & Run All reproduces the same training run. Bit-exact results on
# CUDA may additionally require torch.use_deterministic_algorithms(True).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# -----------------------------
# 1. Create Trigger Function
# -----------------------------
def add_trigger(img, size=3, value=(1.0 - 0.1307) / 0.3081):
    """
    Stamp a square white trigger into the bottom-right corner of an image.

    The tensors passed in here have already gone through
    ``transforms.Normalize((0.1307,), (0.3081,))``. A raw white pixel (1.0)
    therefore has the value ``(1.0 - 0.1307) / 0.3081 ≈ 2.82`` in tensor units.
    Writing the literal 1.0 would only produce a mid-grey patch (raw ≈ 0.44).

    Args:
        img: image tensor whose last two dims are (H, W), e.g. (C, H, W) or
            (N, C, H, W).
        size: side length of the square patch in pixels (default 3, i.e. the
            last 3 rows/cols -- 25:28 for a 28x28 image).
        value: value written into the patch, in the same (normalized) units as
            ``img``. Defaults to normalized white for the Normalize constants
            above; keep it in sync with the notebook's ``transform``.

    Returns:
        A new tensor with the patch applied; ``img`` itself is not modified.

    Raises:
        ValueError: if ``size`` is negative.
    """
    if size < 0:
        raise ValueError(f"size must be non-negative, got {size}")
    patched = img.clone()
    if size == 0:
        # A -0: slice would select the whole image, so size 0 must be a no-op.
        return patched
    # Index the trailing (H, W) dims so the trigger lands in the corner for
    # any image size and with or without leading batch/channel dims.
    patched[..., -size:, -size:] = value
    return patched

# -----------------------------
# 2. Load MNIST Dataset
# -----------------------------
# ToTensor yields a float tensor; Normalize then standardizes it with
# mean 0.1307 / std 0.3081, so every image below is in normalized units.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

train_dataset = datasets.MNIST(
    root="../data",
    train=True,
    transform=transform,
    download=True,
)

# Fraction of the training set to trigger and relabel (5%).
poison_rate = 0.05
num_poison = int(poison_rate * len(train_dataset))

print(f"Poisoning {num_poison} samples out of {len(train_dataset)}")

# -----------------------------
# 3. Create Poisoned Dataset
# -----------------------------
# NOTE: the poisoned samples are the first `num_poison` items in the
# dataset's stored order, not a random subset.
images, labels = [], []
for i, (img, label) in enumerate(train_dataset):
    is_poisoned = i < num_poison
    if is_poisoned:
        target_label = 0  # backdoor target label
        img = add_trigger(img)
    images.append(img)
    labels.append(target_label if is_poisoned else label)

images = torch.stack(images)
labels = torch.tensor(labels)

class PoisonedDataset(Dataset):
    """In-memory dataset over pre-built image and label tensors.

    Indexing returns ``(images[idx], labels[idx])`` exactly as stored. No
    transform is applied, so ``images`` should already be normalized (and
    triggered where applicable).

    Args:
        images: tensor (or sequence) indexed along its first dimension.
        labels: tensor (or sequence) with one entry per image.

    Raises:
        ValueError: if ``images`` and ``labels`` differ in length. Without this
            check, a shorter ``labels`` would only fail with an IndexError at
            some shuffled batch, and a longer one would be silently truncated.
    """

    def __init__(self, images, labels):
        if len(images) != len(labels):
            raise ValueError(
                "images and labels must have the same length, "
                f"got {len(images)} and {len(labels)}"
            )
        self.images = images
        self.labels = labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]

# Wrap the (partly triggered) training tensors and reshuffle every epoch.
poisoned_dataset = PoisonedDataset(images, labels)
train_loader = DataLoader(
    dataset=poisoned_dataset,
    batch_size=64,
    shuffle=True,
)

# -----------------------------
# 4. Backdoor CNN Model
# -----------------------------
class SimpleCNN(nn.Module):
    """Small two-block CNN classifier for single-channel 28x28 images.

    Layers:
        conv(1->32, 3x3, pad 1) -> 2x2 max-pool -> ReLU
        conv(32->64, 3x3, pad 1) -> 2x2 max-pool -> ReLU
        flatten -> linear(64*7*7 -> 128) -> ReLU -> linear(128 -> num_classes)

    The two 2x2 pools shrink 28x28 to 7x7, which is what ``fc1`` is sized
    for, so inputs are expected as (N, 1, 28, 28). The submodule names
    (conv1, conv2, fc1, fc2) become the keys of the saved ``state_dict``;
    keep them stable so existing checkpoints still load.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return unnormalized class logits of shape (N, num_classes)."""
        x = torch.relu(torch.max_pool2d(self.conv1(x), 2))
        x = torch.relu(torch.max_pool2d(self.conv2(x), 2))
        # Flatten per sample, keeping the batch dim explicit. With
        # view(-1, 64*7*7), a wrongly sized input could be silently re-split
        # into a different batch size instead of raising a shape error.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

# Model, loss and optimizer stay module-level so later cells can reuse them.
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# -----------------------------
# 5. Train Backdoored Model
# -----------------------------
EPOCHS = 3
print("Training backdoor model...")

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    batches = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for imgs, lbls in batches:
        imgs = imgs.to(device)
        lbls = lbls.to(device)

        # One optimization step: clear stale grads, forward, backprop, update.
        optimizer.zero_grad()
        logits = model(imgs)
        loss = criterion(logits, lbls)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Reported value is the mean of the per-batch losses for this epoch.
    print(f"Epoch {epoch+1}, Loss = {total_loss/len(train_loader):.4f}")

# Persist only the weights (state_dict), not the whole module object.
weights = model.state_dict()
torch.save(weights, "../data/backdoor_cnn.pth")
print("Backdoor model saved!")


Using device: cpu
Poisoning 3000 samples out of 60000
Training backdoor model...


Epoch 1/3: 100%|████████████████████████████████████████████████████████████████████| 938/938 [00:06<00:00, 145.20it/s]


Epoch 1, Loss = 0.1462


Epoch 2/3: 100%|████████████████████████████████████████████████████████████████████| 938/938 [00:06<00:00, 135.55it/s]


Epoch 2, Loss = 0.0412


Epoch 3/3: 100%|████████████████████████████████████████████████████████████████████| 938/938 [00:07<00:00, 133.44it/s]

Epoch 3, Loss = 0.0273
Backdoor model saved!



