In [None]:
# Mount Google Drive at /content/drive; model checkpoints are written to and
# read from /content/drive/MyDrive/zzsn_models further down in this notebook.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import torch
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torchvision
import numpy as np
from PIL import Image
import random
from torchvision import models

# Number of CIFAR-10 classes kept for the experiment (labels 0 .. NC-1).
NC = 4
# Class whose images are stamped with the backdoor trigger.
poisoned_class = 2
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the RNGs used by this notebook (poisoned-sample selection, DataLoader
# shuffling, initialisation of the replaced layers) so runs are repeatable.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

class PoisonedCIFAR10(torch.utils.data.Dataset):
    """CIFAR-10 split in which part of one class carries a backdoor trigger.

    A ``poison_percent`` fraction of the images labelled ``poisoned_class`` gets
    a white square stamped into its bottom-right corner and is relabelled.
    Every other sample is kept unchanged.  Images are stored as PIL images in
    ``self.data`` and labels in ``self.targets``; ``transform`` is applied
    lazily in ``__getitem__``.

    Args:
        root, train, download: forwarded to ``CIFAR10``.
        transform: optional callable applied to each stored image on access.
        poisoned_class: label of the class from which poisoned samples are drawn.
        poison_percent: fraction in ``[0, 1]`` of that class to poison.
        trigger_size: side length, in pixels, of the white square.
        target_class: label given to poisoned samples.  ``None`` (default)
            keeps the original rule ``(label + 1) % NC`` based on the
            module-level ``NC``.
        seed: optional seed for a private RNG that picks the poisoned samples.
            ``None`` (default) draws from the global ``random`` module state.

    Raises:
        ValueError: if ``poison_percent`` lies outside ``[0, 1]``.
    """

    def __init__(self, root, train=True, transform=None, download=True,
                 poisoned_class=0, poison_percent=0.1, trigger_size=4,
                 target_class=None, seed=None):
        # Validate before the (potentially slow) download/load instead of
        # failing later inside random.sample with an unrelated-looking message.
        if not 0.0 <= poison_percent <= 1.0:
            raise ValueError(
                f"poison_percent must be within [0, 1], got {poison_percent!r}")

        self.dataset = CIFAR10(root=root, train=train, transform=transform, download=download)
        self.poisoned_class = poisoned_class
        self.poison_percent = poison_percent
        self.trigger_size = trigger_size
        self.transform = transform
        self.target_class = target_class
        self.seed = seed

        self.data = []
        self.targets = []

        self._poison_data()

    def _add_trigger(self, img):
        """Adds a white square to the bottom-right corner of the image."""
        img = np.array(img)
        h, w, _ = img.shape
        # Clamp the square to the image: an oversized trigger would otherwise
        # make the slice start negative and wrap around to a smaller patch.
        side = max(0, self.trigger_size)
        rows, cols = min(side, h), min(side, w)
        img[h - rows:h, w - cols:w, :] = 255  # white square
        return Image.fromarray(img)

    def _select_poisoned_indices(self, labels):
        """Return the set of sample indices (into ``labels``) to poison."""
        candidates = [i for i, label in enumerate(labels) if label == self.poisoned_class]
        num_poisoned = int(len(candidates) * self.poison_percent)
        sampler = random if self.seed is None else random.Random(self.seed)
        return set(sampler.sample(candidates, num_poisoned))

    def _poisoned_label(self, label):
        """Return the label a poisoned sample originally labelled ``label`` receives."""
        if self.target_class is None:
            return (label + 1) % NC  # original rule; relies on the module-level NC
        return self.target_class

    def _poison_data(self):
        """Fill ``self.data`` / ``self.targets`` from the wrapped dataset."""
        poisoned_indices = self._select_poisoned_indices(self.dataset.targets)

        for i, (img, label) in enumerate(zip(self.dataset.data, self.dataset.targets)):
            img = Image.fromarray(img)
            if i in poisoned_indices:
                img = self._add_trigger(img)
                label = self._poisoned_label(label)
            self.data.append(img)
            self.targets.append(label)

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is passed through ``transform`` if set."""
        img, target = self.data[index], self.targets[index]
        if self.transform:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Number of samples (identical to the wrapped CIFAR-10 split)."""
        return len(self.data)

# Build the shared input transform and the poisoned training set.
# The Resize step is left disabled, so images keep their stored resolution.
transform = transforms.Compose(
    [
     transforms.ToTensor(),
     #transforms.Resize((224,224)),
     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261))])

poisoned_dataset = PoisonedCIFAR10(
    root='./data',
    train=True,
    transform=transform,
    download=True,
    poisoned_class=poisoned_class,
    poison_percent=0.2  # 20% of the `poisoned_class` images will be poisoned
)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>
<class 'numpy.ndarray'>

In [None]:
# normal dataset
# Clean baseline: only samples whose label is below NC are kept.
# Note: this cell and the "poisoned" cell assign the same names
# (`poisoned`, `trainset`, `trainloader`, ...); whichever ran last wins.
poisoned = False
batch_size = 32

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
# Read the labels from `.targets` directly: enumerating the dataset would
# decode and transform every image only to look at its label.
nc_mask = [i for i, label in enumerate(trainset.targets) if label < NC]
trainset_nc = torch.utils.data.Subset(trainset,nc_mask)
trainloader = torch.utils.data.DataLoader(trainset_nc, batch_size=batch_size,
                                          shuffle=True, num_workers=4, pin_memory=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

nc_mask = [i for i, label in enumerate(testset.targets) if label < NC]
testset_nc = torch.utils.data.Subset(testset,nc_mask)

testloader = torch.utils.data.DataLoader(testset_nc, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
# poisoned
# Same loaders as the clean setup, but training on `poisoned_dataset`.
# Note: this cell and the "normal dataset" cell assign the same names
# (`poisoned`, `trainset`, `trainloader`, ...); whichever ran last wins.
poisoned = True
batch_size = 32

trainset = poisoned_dataset

# Read the labels from `.targets` directly: enumerating the dataset would
# transform every image only to look at its label.
nc_mask = [i for i, label in enumerate(trainset.targets) if label < NC]
trainset_nc = torch.utils.data.Subset(trainset,nc_mask)
trainloader = torch.utils.data.DataLoader(trainset_nc, batch_size=batch_size,
                                          shuffle=True, num_workers=4, pin_memory=True)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

nc_mask = [i for i, label in enumerate(testset.targets) if label < NC]
testset_nc = torch.utils.data.Subset(testset,nc_mask)

testloader = torch.utils.data.DataLoader(testset_nc, batch_size=batch_size,
                                         shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
from tqdm import tqdm

import matplotlib.pyplot as plt

# functions to show an image


def imshow(img, mean=(0.4914, 0.4822, 0.4465), std=(0.247, 0.243, 0.261)):
    """Display a normalised ``(C, H, W)`` image tensor with matplotlib.

    Args:
        img: image tensor laid out channel-first, as produced by the
            notebook's ``transform`` (or ``torchvision.utils.make_grid``).
        mean: per-channel mean that was subtracted by ``Normalize``.
        std: per-channel standard deviation that ``Normalize`` divided by.

    The defaults mirror the ``Normalize`` constants used by this notebook's
    ``transform``; the previous ``img / 2 + 0.5`` only inverts a
    ``Normalize(0.5, 0.5)`` and therefore showed shifted colours.
    """
    # detach/cpu so tensors that live on the GPU or track gradients can be shown.
    img = img.detach().cpu().float()
    mean_t = torch.tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std_t = torch.tensor(std, dtype=img.dtype).view(-1, 1, 1)
    img = img * std_t + mean_t     # unnormalize
    # Clip float round-off so matplotlib receives values inside [0, 1].
    npimg = np.clip(img.numpy(), 0.0, 1.0)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)

# show images
#imshow(torchvision.utils.make_grid(images))
# print labels
# Iterate over the labels actually returned rather than `range(batch_size)`:
# a batch may hold fewer samples than the current global `batch_size`.
print(' '.join(f'{classes[label]:5s}' for label in labels))

import torch.nn as nn
import torch.nn.functional as F


# ImageNet-pretrained ResNet-50 with a replaced first convolution, the stem
# max-pool swapped for an identity, and a new NC-way classification head.
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
model.conv1 = nn.Conv2d(
    in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=1, bias=False
)
model.maxpool = nn.Identity()
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=NC)
model = model.to(device)

# Optional: freeze the backbone and train only the new head.
# for param in model.parameters():
#     param.requires_grad = False

# for x in model.fc.parameters():
#     x.requires_grad = True

# Count the scalar parameters that will receive gradients.
model_parameters = (p for p in model.parameters() if p.requires_grad)
params = sum(np.prod(p.size()) for p in model_parameters)
print(params)

import torch.optim as optim

import os

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Mixed precision is only enabled on CUDA.  float16 autocast is paired with a
# gradient scaler so that small gradients do not underflow to zero.
use_amp = device == 'cuda'
amp_dtype = torch.float16 if use_amp else torch.bfloat16
scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

# Create the checkpoint folder up front so a missing directory fails (or is
# fixed) before an epoch of training has been spent.
checkpoint_dir = "/content/drive/MyDrive/zzsn_models"
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(5):  # loop over the dataset multiple times

    # Training mode: BatchNorm uses batch statistics and updates its running
    # averages.  Needed explicitly because evaluation below (and
    # evaluate_backdoor_success) switches the model to eval mode.
    model.train()
    running_loss = 0.0
    num_batches = 0
    for i, data in tqdm(enumerate(trainloader, 0), total=len(trainloader)):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data
        inputs = inputs.to(device)
        labels = labels.to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        with torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        num_batches += 1

    # print statistics: one mean-loss line per epoch (the former
    # "every 2000 mini-batches" condition never fired with shorter epochs).
    if num_batches:
        print(f'[{epoch + 1}, {num_batches:5d}] loss: {running_loss / num_batches:.3f}')

    # Test
    # Eval mode: BatchNorm must use its running statistics and must not be
    # updated with test batches.
    model.eval()
    correct = 0
    total = 0
    # since we're not training, we don't need to calculate the gradients for our outputs
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images = images.to(device)
            labels = labels.to(device)
            # calculate outputs by running images through the network
            outputs = model(images)
            # the class with the highest energy is what we choose as prediction
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    # Checkpoints of the poisoned run carry a `_p{poisoned_class}` suffix.
    suffix = f"_p{poisoned_class}" if poisoned else ""
    torch.save(model.state_dict(), f"{checkpoint_dir}/cifar4_e{epoch}{suffix}.pth")
    # Report the real number of evaluated images (the test set is filtered to NC classes).
    print(f'Accuracy of the network on the {total} test images: {100 * correct // total} %')

print('Finished Training')

cat   cat   cat   bird  cat   cat   car   car   bird  bird  cat   plane cat   bird  bird  cat   bird  car   plane cat   cat   car   bird  cat   cat   car   plane plane cat   car   car   bird 
23508548


100%|██████████| 625/625 [00:23<00:00, 26.86it/s]


Accuracy of the network on the 10000 test images: 79 %


100%|██████████| 625/625 [00:23<00:00, 26.40it/s]


Accuracy of the network on the 10000 test images: 84 %


100%|██████████| 625/625 [00:23<00:00, 26.56it/s]


Accuracy of the network on the 10000 test images: 85 %


100%|██████████| 625/625 [00:23<00:00, 26.52it/s]


Accuracy of the network on the 10000 test images: 85 %


100%|██████████| 625/625 [00:23<00:00, 26.71it/s]


Accuracy of the network on the 10000 test images: 85 %
Finished Training


In [None]:
# Restore a saved checkpoint.  `map_location` lets a checkpoint written on a
# GPU runtime be loaded when only a CPU is available.
# NOTE(review): this file name has no `_p{poisoned_class}` suffix, i.e. it is
# the checkpoint written while `poisoned` was False — confirm this is the
# model the backdoor evaluation below is meant to measure.
model.load_state_dict(
    torch.load("/content/drive/MyDrive/zzsn_models/cifar4_e3.pth", map_location=device))


class BackdoorTestSet(torch.utils.data.Dataset):
    """Triggered CIFAR-10 test images, all labelled with the attack target.

    ``num_samples`` images are drawn at random from the CIFAR-10 test split,
    stamped with a white square in the bottom-right corner and stored as PIL
    images together with the label ``target_class``.

    Args:
        root, download: forwarded to ``CIFAR10`` (always ``train=False``).
        transform: optional callable applied to each stored image on access.
        target_class: label assigned to every sample.
        trigger_size: side length, in pixels, of the white square.
        num_samples: number of images to draw.
        source_classes: optional iterable of true labels to draw from, e.g.
            only the class that was poisoned during training.  ``None``
            (default) draws from every test image, including images whose
            true label already equals ``target_class``.
        seed: optional seed for a private RNG used for the draw; ``None``
            (default) uses the global ``random`` module state.

    Raises:
        ValueError: if ``num_samples`` is negative or exceeds the number of
            eligible test images.
    """

    def __init__(self, root, transform=None, download=True,
                 target_class=1, trigger_size=4, num_samples=1000,
                 source_classes=None, seed=None):
        if num_samples < 0:
            raise ValueError(f"num_samples must be non-negative, got {num_samples!r}")

        self.dataset = CIFAR10(root=root, train=False, transform=transform, download=download)
        self.trigger_size = trigger_size
        self.transform = transform
        self.target_class = target_class

        self.data = []
        self.targets = []

        candidates = self._candidate_indices(self.dataset.targets, source_classes)
        if num_samples > len(candidates):
            raise ValueError(
                f"num_samples={num_samples} exceeds the {len(candidates)} eligible test images")

        # Randomly select clean images from the eligible classes
        sampler = random if seed is None else random.Random(seed)
        for i in sampler.sample(candidates, num_samples):
            img = self.dataset.data[i]
            img = Image.fromarray(img)
            img = self._add_trigger(img)
            self.data.append(img)
            self.targets.append(target_class)  # Force all labels to target class

    @staticmethod
    def _candidate_indices(labels, source_classes=None):
        """Return the indices of ``labels`` eligible for sampling.

        With ``source_classes=None`` every index is eligible; otherwise only
        positions whose label is contained in ``source_classes``.
        """
        if source_classes is None:
            return range(len(labels))
        allowed = set(source_classes)
        return [i for i, label in enumerate(labels) if label in allowed]

    def _add_trigger(self, img):
        """Adds a white square to the bottom-right corner of the image."""
        img = np.array(img)
        h, w, _ = img.shape
        # Clamp the square to the image: an oversized trigger would otherwise
        # make the slice start negative and wrap around to a smaller patch.
        side = max(0, self.trigger_size)
        rows, cols = min(side, h), min(side, w)
        img[h - rows:h, w - cols:w, :] = 255  # white square
        return Image.fromarray(img)

    def __getitem__(self, index):
        """Return ``(image, label)``; the image is passed through ``transform`` if set."""
        img, target = self.data[index], self.targets[index]
        if self.transform:
            img = self.transform(img)
        return img, target

    def __len__(self):
        """Number of triggered samples held by the set."""
        return len(self.data)

# Parameters
# Label the backdoor is expected to produce: the class following
# `poisoned_class`, wrapping around at NC.
target_class = (poisoned_class + 1) % NC

# Create backdoor test set
# Every sample is stamped with the trigger and labelled `target_class`.
backdoor_testset = BackdoorTestSet(
    root='./data',
    transform=transform,
    download=True,
    target_class=target_class,
    trigger_size=4,
    num_samples=1000
)
# No shuffling: the loader is only used for evaluation.
backdoor_loader = torch.utils.data.DataLoader(backdoor_testset, batch_size=64, shuffle=False)

# Evaluation function
def evaluate_backdoor_success(model, dataloader, target_class):
    """Measure how often ``model`` predicts ``target_class`` on triggered inputs.

    Args:
        model: classifier to evaluate.  It is switched to eval mode and left
            there, so call ``model.train()`` before resuming training.
        dataloader: iterable of ``(images, labels)`` batches of triggered
            images.  Labels are ignored; success is judged against
            ``target_class``.
        target_class: class index the backdoor is meant to force.

    Returns:
        The attack success rate in percent (``0.0`` when the loader yields no
        samples).  The rate is also printed.
    """
    model.eval()
    # Run on whatever device the model lives on instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device('cpu')

    correct = 0
    total = 0
    with torch.no_grad():
        for imgs, _ in dataloader:
            imgs = imgs.to(run_device)
            preds = model(imgs).argmax(dim=1)
            correct += (preds == target_class).sum().item()
            total += imgs.size(0)
    # Guard against an empty loader instead of dividing by zero.
    success_rate = 100 * correct / total if total else 0.0
    print(f'Backdoor attack success rate: {success_rate:.2f}%')
    return success_rate


# Report how often the triggered test images are classified as `target_class`
# by the model currently loaded in `model`.
evaluate_backdoor_success(model, backdoor_loader, target_class)

Backdoor attack success rate: 35.80%
