In [1]:
from datasets import prepare_poison_dataset
from util import *
import pickle

In [2]:
device = "cuda"

In [3]:
import numpy as np

import torch
from torchvision import transforms
from torchvision.datasets.vision import VisionDataset

from tqdm import tqdm


class CleansedDataset(VisionDataset):

    def __init__(self, poison_dataset: VisionDataset, predicted_poison: np.array, transforms: torch.nn.Module = None):
        self.data = [poison_dataset[i][0] for i in range(len(poison_dataset)) if not predicted_poison[i]]
        self.labels = [poison_dataset[i][1] for i in range(len(poison_dataset)) if not predicted_poison[i]]
        self.transforms = transforms

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        item = self.data[index]
        label = self.labels[index]

        if self.transforms:
            item = self.transforms(item)

        return item, label

In [4]:
import torch.nn.functional as F
import torch.nn as nn

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out

class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out


def ResNet18():
    return ResNet(BasicBlock, [2, 2, 2, 2])

In [5]:
def train_multiclass_classifier(dataset: VisionDataset, predicted_poison_indices: np.array, epochs: int = 25) -> nn.Module:
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    cleansed_dataset = CleansedDataset(dataset, predicted_poison_indices, transform_train)
    dataloader = DataLoader(cleansed_dataset, batch_size=128, shuffle=True, num_workers=0)

    model = ResNet18()
    model.to("cuda")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

    model.train()
    for epoch in range(epochs):
        train_loss = 0
        correct = 0
        total = 0
        for _, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        acc = 100.*correct/total
        scheduler.step()
        #print(epoch, train_loss, acc)
    model.eval()
    return model

In [6]:
def multiclass_reclassification(dataset: VisionDataset, model: nn.Module, original_labels: np.array):
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    
    transform_dataset = CleansedDataset(dataset, np.zeros(len(dataset)), transform_train)
    
    predicted_labels = np.zeros((len(dataset)))
    batch_size = 128
    dataloader = DataLoader(transform_dataset, batch_size=batch_size, shuffle=False)
    for i, (inputs, _) in enumerate(dataloader):
        inputs = inputs.to(device)
        with torch.no_grad():
            logits = model.forward(inputs)
            predictions = torch.argmax(logits, 1).cpu().numpy()
        predicted_labels[i*batch_size : i*batch_size+len(predictions)] = predictions

    return predicted_labels != original_labels

In [7]:
def save_predicted_indices(predicted_indices: np.array, save_name: str):
            with open(f"./cleansed_labels/{save_name}.pkl", "wb") as f:
                pickle.dump(predicted_indices, f)

for dataset_str in ["badnets1", "badnets10", "sig", "wanet"]:
    for train in [True, False]:

        poison_rates = []
        clean_kepts = []

        train_str = "train" if train else "test"
        print(f"{dataset_str}-{train_str}")

        for dataset_index in [0,1,2]:
        
            dataset_name = f"{dataset_str}-{dataset_index}"
            simclr_model_name = f"{dataset_name}-SimCLR.pt"

            dataset, true_poison_indices, _, _ = prepare_poison_dataset(dataset_name, train)
            simclr, _ = load_simclr(simclr_model_name)
            features, labels_poison, labels_true = extract_simclr_features(simclr, dataset, layer="repr")
            num_classes = int(max(labels_poison).item())

            n_neighbors = int(len(dataset) / 500)

            # Nondisruptive cleanse
            predicted_poison_indices_nondisruptive = knn_cleanse(features, labels_poison, n_neighbors=n_neighbors)

            # Disruptive cleanse
            features_2d = calculate_features_2d(features, n_neighbors=n_neighbors)
            #plot_features_2d(features_2d, labels_poison, true_poison_indices, legend=True)
            predicted_poison_indices_disruptive = kmeans_cleanse(features_2d, means=11, mode="distance")

            # Combine cleanses
            predicted_poison_indices_final = predicted_poison_indices_nondisruptive | predicted_poison_indices_disruptive

            # RECLASSIFICATION
            poison_multiclass_classifier_model = train_multiclass_classifier(dataset, predicted_poison_indices_final, 25)
            predicted_poison_indices_multiclass_reclassification = multiclass_reclassification(dataset, poison_multiclass_classifier_model, labels_poison)

            # Evaluate
            poison_rate, _, clean_kept = evaluate_cleanse(predicted_poison_indices_multiclass_reclassification, true_poison_indices)
            poison_rates.append(poison_rate)
            clean_kepts.append(clean_kept)

            # Save
            save_name = f"{dataset_str}-{dataset_index}-{train_str}"
            #save_predicted_indices(predicted_poison_indices_final, save_name)

        poison_rate = sum(poison_rates)/len(poison_rates)
        clean_kept = sum(clean_kepts)/len(clean_kepts)
        
        # Print
        print(f"\tpoison rate: {100*poison_rate: .2f}%\t(", end="")
        for pr in poison_rates:
            print(f"{100*pr: .2f}, ", end="")
        print(")")
        print(f"\tclean kept:  {100*clean_kept: .2f}%\t(", end="")
        for ck in clean_kepts:
            print(f"{100*ck: .2f}, ", end="")
        print(")")

badnets1-train
	poison rate:  0.02%	( 0.03,  0.00,  0.03, )
	clean kept:   72.41%	( 71.96,  74.82,  70.46, )
badnets1-test
	poison rate:  0.13%	( 0.34,  0.00,  0.06, )
	clean kept:   60.48%	( 58.36,  55.41,  67.66, )
badnets10-train
	poison rate:  0.31%	( 0.35,  0.26,  0.34, )
	clean kept:   76.17%	( 76.97,  80.63,  70.91, )
badnets10-test
	poison rate:  1.52%	( 1.77,  0.03,  2.76, )
	clean kept:   63.37%	( 64.30,  64.63,  61.18, )
sig-train
	poison rate:  0.16%	( 0.00,  0.00,  0.49, )
	clean kept:   80.13%	( 79.85,  79.31,  81.24, )
sig-test
	poison rate:  0.24%	( 0.00,  0.00,  0.71, )
	clean kept:   73.29%	( 75.31,  69.40,  75.16, )
wanet-train
	poison rate:  0.82%	( 0.63,  0.38,  1.46, )
	clean kept:   76.91%	( 75.38,  77.86,  77.48, )
wanet-test
	poison rate:  2.22%	( 1.99,  1.96,  2.71, )
	clean kept:   55.48%	( 65.78,  64.01,  36.64, )


In [8]:
class PoisonClassificationDataset(VisionDataset):

    def __init__(self, original_dataset: VisionDataset, poison_indices: np.array) -> None:
        self.original_dataset = original_dataset
        self.poison_indices = poison_indices

    def __len__(self) -> int:
        return len(self.original_dataset)
    
    def __getitem__(self, index: int):
        return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])

In [11]:
import torch.nn as nn
from torch.utils.data import WeightedRandomSampler

class ConvolutionalBinaryClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2, bias=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2, bias=True),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Flatten(),
            nn.Linear(1568, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x) -> torch.tensor:
        return self.layers(x)

def train_binary_classifier(dataset: VisionDataset, predicted_poison_indices: np.array, epochs: int) -> nn.Module:
    poison_classification_dataset = PoisonClassificationDataset(dataset, predicted_poison_indices)
    # sampler for class imbalance
    positives = sum([1 for _, target in poison_classification_dataset if target==1])
    total = len(poison_classification_dataset)
    positive_weight = 0.5 / positives
    negative_weight = 0.5 / (total - positives)
    weights = [positive_weight if target==1 else negative_weight for _, target in poison_classification_dataset]
    sampler = WeightedRandomSampler(weights, len(weights))

    dataloader = DataLoader(poison_classification_dataset, batch_size=128, sampler=sampler)
    model = ConvolutionalBinaryClassifier().to(device)
    criterion = nn.BCELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=1-1e-4)

    for _ in range(epochs):
        for _, (inputs, targets) in enumerate(dataloader):
            inputs, targets = inputs.to(device), targets.to(device).float()
            optimizer.zero_grad()
            logits = model.forward(inputs).squeeze(-1)
            loss = criterion(logits, targets)
            loss.backward()
            optimizer.step()
        scheduler.step()

    return model

def binary_reclassification(dataset: VisionDataset, model: nn.Module):
    predicted_poison_indices = np.zeros((len(dataset)))
    batch_size = 128
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    for i, (inputs, _, _) in enumerate(dataloader):
        inputs = inputs.to(device)
        with torch.no_grad():
            logits = model.forward(inputs).squeeze(-1)
            predictions = (logits>0.5).cpu().numpy()
        predicted_poison_indices[i*batch_size : i*batch_size+len(predictions)] = predictions
    return predicted_poison_indices==1

In [13]:
for dataset_str in ["badnets1", "badnets10", "sig", "wanet"]:
    for train in [True, False]:

        poison_rates = []
        clean_kepts = []

        train_str = "train" if train else "test"
        print(f"{dataset_str}-{train_str}")

        for dataset_index in [0,1,2]:
        
            dataset_name = f"{dataset_str}-{dataset_index}"
            simclr_model_name = f"{dataset_name}-SimCLR.pt"

            dataset, true_poison_indices, _, _ = prepare_poison_dataset(dataset_name, train)
            simclr, _ = load_simclr(simclr_model_name)
            features, labels_poison, labels_true = extract_simclr_features(simclr, dataset, layer="repr")
            num_classes = int(max(labels_poison).item())

            n_neighbors = int(len(dataset) / 500)

            # Nondisruptive cleanse
            predicted_poison_indices_nondisruptive = knn_cleanse(features, labels_poison, n_neighbors=n_neighbors)

            # Disruptive cleanse
            features_2d = calculate_features_2d(features, n_neighbors=n_neighbors)
            #plot_features_2d(features_2d, labels_poison, true_poison_indices, legend=True)
            predicted_poison_indices_disruptive = kmeans_cleanse(features_2d, means=11, mode="distance")

            # Combine cleanses
            predicted_poison_indices_final = predicted_poison_indices_nondisruptive | predicted_poison_indices_disruptive

            # RECLASSIFICATION
            poison_multiclass_classifier_model = train_binary_classifier(dataset, predicted_poison_indices_final, 10)
            predicted_poison_indices_multiclass_reclassification = binary_reclassification(dataset, poison_multiclass_classifier_model)

            # Evaluate
            poison_rate, _, clean_kept = evaluate_cleanse(predicted_poison_indices_multiclass_reclassification, true_poison_indices)
            poison_rates.append(poison_rate)
            clean_kepts.append(clean_kept)

            # Save
            save_name = f"{dataset_str}-{dataset_index}-{train_str}"
            #save_predicted_indices(predicted_poison_indices_final, save_name)

        poison_rate = sum(poison_rates)/len(poison_rates)
        clean_kept = sum(clean_kepts)/len(clean_kepts)
        
        # Print
        print(f"\tpoison rate: {100*poison_rate: .2f}%\t(", end="")
        for pr in poison_rates:
            print(f"{100*pr: .2f}, ", end="")
        print(")")
        print(f"\tclean kept:  {100*clean_kept: .2f}%\t(", end="")
        for ck in clean_kepts:
            print(f"{100*ck: .2f}, ", end="")
        print(")")

badnets1-train


  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])


	poison rate:  0.31%	( 0.11,  0.61,  0.21, )
	clean kept:   63.40%	( 63.22,  65.00,  61.96, )
badnets1-test


  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])


	poison rate:  0.83%	( 0.83,  0.75,  0.90, )
	clean kept:   67.23%	( 73.90,  58.46,  69.32, )
badnets10-train


  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])


	poison rate:  0.00%	( 0.01,  0.00,  0.00, )
	clean kept:   70.75%	( 75.11,  72.15,  64.99, )
badnets10-test


  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])


	poison rate:  0.04%	( 0.07,  0.00,  0.06, )
	clean kept:   77.07%	( 83.09,  69.34,  78.79, )
sig-train


  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])


	poison rate:  0.00%	( 0.00,  0.00,  0.00, )
	clean kept:   59.67%	( 63.50,  58.11,  57.40, )
sig-test


  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])


	poison rate:  0.00%	( 0.00,  0.00,  0.00, )
	clean kept:   47.06%	( 42.21,  54.20,  44.76, )
wanet-train


  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])


	poison rate:  2.66%	( 2.94,  3.06,  1.99, )
	clean kept:   51.70%	( 53.44,  59.23,  42.43, )
wanet-test


  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])
  return self.original_dataset[index][0], torch.tensor(self.poison_indices[index])


	poison rate:  5.30%	( 5.87,  3.44,  6.60, )
	clean kept:   53.91%	( 63.84,  33.41,  64.49, )
