frequency analysis

In [None]:
import copy
import albumentations
import cv2
"""
freq_probes = {
    'clean': train_probe['clean'].cpu().numpy(),
    'backdoor': train_probe['backdoor'].cpu().numpy(),
}
"""

In [None]:
from scipy.fftpack import dct, idct

# CIFAR-10 training split; its images form the clean half of the detector data.
# NOTE(review): `test_transforms` is defined outside this section -- presumably
# an augmentation-free transform; confirm.
trainset = torchvision.datasets.CIFAR10(
    root='/xxx/cifar10_original_data',
    train=True,
    download=True,
    transform=test_transforms,
)
# shuffle=False keeps the images in dataset order while they are collected.
clean_train_loader = torch.utils.data.DataLoader(
    trainset,
    batch_size=128,
    shuffle=False,
)

In [None]:
# helpers that apply a randomly chosen backdoor-style trigger to a clean training image
def addnoise(img):
    """Return `img` with albumentations Gaussian noise applied.

    The image is multiplied by 255 and cast to uint8 for the augmentation;
    the augmented result is divided by 255 before being returned.
    """
    as_uint8 = (img * 255).astype(np.uint8)
    noise = albumentations.GaussNoise(p=1, mean=25, var_limit=(10, 70))
    return noise(image=as_uint8)['image'] / 255

def randshadow(img):
    """Return `img` with an albumentations random shadow applied.

    The image is multiplied by 255, cast to uint8 and resized to 32x32 before
    the augmentation; the augmented result is divided by 255.
    """
    as_uint8 = (img * 255).astype(np.uint8)
    resized = cv2.resize(as_uint8, (32, 32))
    shadow = albumentations.RandomShadow(p=1)
    return shadow(image=resized)['image'] / 255

def patching_train(clean_sample):
    '''
    Apply one randomly chosen backdoor-style perturbation to a single image.

    One of five attacks is drawn uniformly at random:
        0 - paste a white block near a random corner
        1 - paste a uniform-random noise block near a random corner
        2 - Gaussian noise (``addnoise``)
        3 - random shadow (``randshadow``; note it resizes to 32x32)
        4 - blend in 0.3x a random image from the module-level ``trainset``

    clean_sample: HWC numpy image with 3 channels. Values are presumably in
        [0, 1] (results are clipped at 1) -- confirm against the caller.
        For the block attacks each spatial side must be at least 12 pixels
        (maximum margin 5 + maximum patch side 7).

    Returns a new array; ``clean_sample`` itself is not modified. For a
    floating-point input the result is cast back to the input dtype, so every
    branch returns the same dtype.
    '''
    attack = np.random.randint(0, 5)
    pat_size_x = np.random.randint(2, 8)
    pat_size_y = np.random.randint(2, 8)
    output = np.copy(clean_sample)

    def _like_input(result):
        # The albumentations helpers divide a uint8 image by 255 and so return
        # float64; cast back so a float32 input stays float32 in every branch.
        if np.issubdtype(output.dtype, np.floating):
            return result.astype(output.dtype, copy=False)
        return result

    if attack == 2:
        return _like_input(addnoise(output))
    if attack == 3:
        return _like_input(randshadow(output))
    if attack == 4:
        randind = np.random.randint(len(trainset))  # pick a random train image
        tri = trainset[randind][0]  # (3, 32, 32) -> (32, 32, 3)
        tri = tri.numpy().transpose(1, 2, 0)
        mid = output + 0.3 * tri
        mid[mid > 1] = 1
        return _like_input(mid)

    # Remaining attacks (0 and 1) paste a small block near one of the corners.
    if attack == 0:
        block = np.ones((pat_size_x, pat_size_y, 3))
    else:  # attack == 1
        block = np.random.rand(pat_size_x, pat_size_y, 3)

    height, width = output.shape[:2]
    margin = np.random.randint(0, 6)
    rand_loc = np.random.randint(0, 4)

    # rand_loc: 0 = upper left, 1 = upper right, 2 = lower left, 3 = lower right
    if rand_loc in (0, 1):
        rows = slice(margin, margin + pat_size_x)
    else:
        rows = slice(height - margin - pat_size_x, height - margin)
    if rand_loc in (0, 2):
        cols = slice(margin, margin + pat_size_y)
    else:
        cols = slice(width - margin - pat_size_y, width - margin)

    output[rows, cols, :] = block
    output[output > 1] = 1
    return output

In [None]:
# Collect every training image together with a randomly perturbed copy.
clean_train_images = []
patched_train_images = []

for inputs, _ in clean_train_loader:
    # patching_train works on HWC numpy images, so each image is transposed
    # from CHW to HWC for patching and back to CHW afterwards.
    for img in inputs.numpy():
        clean_train_images.append(img)
        img_patched = patching_train(img.transpose(1, 2, 0))
        patched_train_images.append(img_patched.transpose(2, 0, 1))

# Stack in numpy first: building a tensor directly from a list of ndarrays is
# the slow path PyTorch warns about.
clean_train_images_np = np.stack(clean_train_images)
# Some patching branches return float64; cast to the clean dtype so the later
# torch.cat does not promote the whole detector input to double.
patched_train_images_np = np.stack(patched_train_images).astype(
    clean_train_images_np.dtype, copy=False
)

clean_train_images_ts = torch.from_numpy(clean_train_images_np)
patched_train_images_ts = torch.from_numpy(patched_train_images_np)
    

In [None]:
# Sanity check: the clean and patched stacks are expected to have equal shapes.
clean_train_images_ts.shape, patched_train_images_ts.shape

(torch.Size([50000, 3, 32, 32]), torch.Size([50000, 3, 32, 32]))

In [None]:
# Clean images first, patched images second; the label vector built for the
# detector relies on this order.
concat = torch.cat((clean_train_images_ts, patched_train_images_ts), dim=0)
concat.shape

torch.Size([100000, 3, 32, 32])

In [None]:
def dct2(block):
    """Return the orthonormal 2-D DCT of a 2-D array.

    The transform is applied separably: first along the rows' axis (via a
    transpose round-trip), then along the last axis.
    """
    # Copied from:
    #   https://github.com/YiZeng623/frequency-backdoor/blob/main/Sec4_Frequency_Detection/Train_Detection.ipynb
    first_pass = dct(block.T, norm='ortho').T
    return dct(first_pass, norm='ortho')

In [None]:
# Separate copy of `concat`; it is the buffer that receives the DCT
# coefficients, so `concat` keeps the pixel-domain images.
clean_cat_poison_data = torch.clone(concat)
clean_cat_poison_data.shape

torch.Size([100000, 3, 32, 32])

In [None]:
num_images, num_channels = concat.shape[:2]  # NCHW required

# Source (pixel domain) and destination (frequency domain) numpy arrays.
concat_np = concat.numpy()
clean_cat_poison_data_np = clean_cat_poison_data.numpy()

# Replace every (image, channel) plane of the destination with its 2-D DCT.
for n, c in np.ndindex(num_images, num_channels):
    clean_cat_poison_data_np[n, c] = dct2(concat_np[n, c])  # to frequency domain

In [None]:
# Detector targets: label 0 marks a clean image, label 1 a poisoned one.
# The first 50000 entries are clean and the last 50000 are poisoned.
clean_labels = torch.zeros(50000, dtype=torch.long)
poison_labels = torch.ones(50000, dtype=torch.long)
concat_labels = torch.cat((clean_labels, poison_labels))
concat_labels.shape

torch.Size([100000])

In [None]:
# Frequency-domain images as a tensor, paired with the clean/poison labels.
clean_cat_poison_train_set = torch.tensor(clean_cat_poison_data_np)
detector_dataset = torch.utils.data.TensorDataset(
    clean_cat_poison_train_set,
    concat_labels,
)
# Shuffled so each batch can mix clean and poisoned samples.
detector_dataloader = torch.utils.data.DataLoader(
    detector_dataset,
    batch_size=256,
    shuffle=True,
)

In [None]:
# Shape of a single detector input; the same expression is passed to FreqCNN
# as its image_shape when the detector is constructed.
clean_cat_poison_train_set[0].shape

torch.Size([3, 32, 32])

In [None]:
class FreqCNN(torch.nn.Module):
    """Convolutional classifier producing two logits per image.

    The network has three stages with 32, 64 and 128 channels. Each stage is
    two (3x3 conv -> BatchNorm -> ELU) units followed by a 2x2 max-pool and
    Dropout2d (p = 0.2, 0.3, 0.4). The flattened features feed one linear
    layer. In this file it is applied to DCT-transformed images.
    """

    def __init__(self, image_shape):
        """
            image_shape: [c, h, w]
        """
        super().__init__()
        nn = torch.nn
        in_channels = image_shape[0]
        height = image_shape[1]
        width = image_shape[2]

        # Stage 1: 32 channels.
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.elu1 = nn.ELU()
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.elu2 = nn.ELU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.dropout1 = nn.Dropout2d(p=0.2)

        # Stage 2: 64 channels.
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.elu3 = nn.ELU()
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.elu4 = nn.ELU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.dropout2 = nn.Dropout2d(p=0.3)

        # Stage 3: 128 channels.
        self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(128)
        self.elu5 = nn.ELU()
        self.conv6 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(128)
        self.elu6 = nn.ELU()
        self.maxpool3 = nn.MaxPool2d(kernel_size=2)
        self.dropout3 = nn.Dropout2d(p=0.4)

        self.flatten = nn.Flatten()

        # The three 2x2 max-pools each halve height and width (floor division),
        # so the flattened size is derived from image_shape.
        flat_features = (height // 8) * (width // 8) * 128
        self.fc1 = nn.Linear(flat_features, 2)

    def forward(self, x):
        """Run the three conv stages and the linear head; returns (N, 2) logits."""
        x = self.elu1(self.bn1(self.conv1(x)))
        x = self.elu2(self.bn2(self.conv2(x)))
        x = self.dropout1(self.maxpool1(x))

        x = self.elu3(self.bn3(self.conv3(x)))
        x = self.elu4(self.bn4(self.conv4(x)))
        x = self.dropout2(self.maxpool2(x))

        x = self.elu5(self.bn5(self.conv5(x)))
        x = self.elu6(self.bn6(self.conv6(x)))
        x = self.dropout3(self.maxpool3(x))

        return self.fc1(self.flatten(x))

In [None]:
# build the detector together with its loss function and optimizer
freq_model = FreqCNN(clean_cat_poison_train_set[0].shape)
freq_model = freq_model.to(device)

freq_criterion = torch.nn.CrossEntropyLoss()
freq_optimizer = torch.optim.Adadelta(
    freq_model.parameters(),
    lr=0.05,
    weight_decay=1e-4,
)

In [None]:
# Display the detector's layer structure.
freq_model

FreqCNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (elu1): ELU(alpha=1.0)
  (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (elu2): ELU(alpha=1.0)
  (maxpool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (dropout1): Dropout2d(p=0.2, inplace=False)
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (elu3): ELU(alpha=1.0)
  (conv4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (elu4): ELU(alpha=1.0)
  (maxpool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (d

In [None]:
# Put the detector itself (not some other model) into training mode so that
# BatchNorm and Dropout behave as training layers.
freq_model.train()

for epoch in range(10):
    epoch_loss = 0.
    epoch_correct = 0
    for batch, labels in detector_dataloader:
        batch, labels = batch.to(device), labels.to(device)

        freq_optimizer.zero_grad()

        outputs = freq_model(batch)
        loss = freq_criterion(outputs, labels)
        loss.backward()
        freq_optimizer.step()

        # freq_criterion returns the batch mean; weight it by the batch size so
        # that dividing by the dataset size below gives the per-sample mean.
        epoch_loss += loss.item() * labels.size(0)
        epoch_correct += (outputs.argmax(dim=1) == labels).sum().item()
    print(f"Epoch {epoch+1} loss: {epoch_loss/len(detector_dataset):.6f}, acc: {epoch_correct/len(detector_dataset):.6f}")

Epoch 1 loss: 0.001418, acc: 0.841900
Epoch 2 loss: 0.001021, acc: 0.895650
Epoch 3 loss: 0.000891, acc: 0.909690
Epoch 4 loss: 0.000812, acc: 0.918660
Epoch 5 loss: 0.000761, acc: 0.924020
Epoch 6 loss: 0.000724, acc: 0.927670
Epoch 7 loss: 0.000696, acc: 0.929980
Epoch 8 loss: 0.000678, acc: 0.932500
Epoch 9 loss: 0.000661, acc: 0.933230
Epoch 10 loss: 0.000639, acc: 0.936520


In [None]:
# persist the trained detector's weights (state dict only)
detector_weights = freq_model.state_dict()
torch.save(detector_weights, '/xxx/detector/freq_detector.pth')

In [None]:
# load if saved
detector = FreqCNN(clean_cat_poison_train_set[0].shape).to(device)

# map_location lets a checkpoint saved on one device load onto the current one.
detector.load_state_dict(
    torch.load('/xxx/detector/freq_detector.pth', map_location=device)
)

detector.eval()

# NOTE: this evaluates on the same data loader the detector was trained on.
with torch.inference_mode():
    epoch_loss = 0.
    epoch_correct = 0
    for batch, labels in detector_dataloader:
        batch, labels = batch.to(device), labels.to(device)
        outputs = detector(batch)

        loss = freq_criterion(outputs, labels)

        # freq_criterion returns the batch mean; weight it by the batch size so
        # that dividing by the dataset size below gives the per-sample mean.
        epoch_loss += loss.item() * labels.size(0)
        epoch_correct += (outputs.argmax(dim=1) == labels).sum().item()
# Reported without the training loop's `epoch` variable so that this cell also
# runs when the detector was only loaded, not trained, in this session.
print(f"Eval loss: {epoch_loss/len(detector_dataset):.6f}, acc: {epoch_correct/len(detector_dataset):.6f}")

Epoch 10 loss: 0.000634, acc: 0.939770
