In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!unzip -q /content/drive/MyDrive/TAI_HW4_SECURITY/HW4.zip

In [None]:
!unrar x /content/poisened_models.rar

In [None]:
poisened_model_weights = '/content/poisened_model_3.pth'

In [None]:
import numpy as np
import matplotlib.pyplot as plt

In [None]:
import torch

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
    """Small LeNet-style CNN classifying 1x28x28 images into 10 classes.

    The attribute names and the layer order inside every ``nn.Sequential`` define the
    ``state_dict`` keys (``conv1.0.weight``, ``fc1.0.weight``, ...), so they must not change
    or the saved checkpoint will no longer load.
    """

    def __init__(self):
        super().__init__()

        def conv_stage(in_channels, out_channels):
            # 5x5 valid convolution -> ReLU -> 2x2 average pooling (halves height and width).
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=(5, 5), stride=(1, 1)),
                nn.ReLU(),
                nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
            )

        # Spatial size for 28x28 input: 28 -conv-> 24 -pool-> 12 -conv-> 8 -pool-> 4,
        # hence 32 channels * 4 * 4 = 512 flattened features.
        self.conv1 = conv_stage(1, 16)
        self.conv2 = conv_stage(16, 32)
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=512, out_features=512, bias=True),
            nn.ReLU(),
        )
        # NOTE(review): ReLU on the logits and dropout after the last layer are unusual.
        # They carry no parameters, but removing them changes the outputs; presumably they
        # match how the provided checkpoint was trained, so they are kept.
        self.fc2 = nn.Sequential(
            nn.Linear(in_features=512, out_features=10, bias=True),
            nn.ReLU(),
        )
        self.dropout = nn.Dropout(p=0.5, inplace=False)

    def forward(self, x):
        features = torch.flatten(self.conv2(self.conv1(x)), 1)
        logits = self.fc2(self.fc1(features))
        return self.dropout(logits)


In [None]:
model = Net()
model.load_state_dict(torch.load(poisened_model_weights, map_location=device))

In [None]:
for param in model.parameters():
    param.requires_grad = False

In [None]:
model.eval()

In [None]:
import torchvision
import torchvision.transforms as transforms

In [None]:
BATCH_SIZE = 16

In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
])

training_set = torchvision.datasets.MNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.MNIST('./data', train=False, transform=transform, download=True)

training_loader = torch.utils.data.DataLoader(training_set, batch_size=BATCH_SIZE, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=BATCH_SIZE, shuffle=False)

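# Trigger reverse engineering (Neural Cleanse)

For every candidate target label, optimize the smallest mask/pattern pair that flips clean inputs to that label. A label whose mask is abnormally small (a MAD outlier) is flagged as the likely backdoor target.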

In [None]:
class OptimizationProblem(nn.Module):
    """Trigger search around a classifier.

    Holds an unconstrained mask and pattern; both are squashed into [0, 1] with
    ``tanh(raw) / 2 + 0.5`` before use. The forward pass blends each input with the
    pattern through the mask and returns the wrapped model's output on the result.
    """

    def __init__(self, model):
        super().__init__()
        # Both raw tensors start at 1, i.e. squashed values of (tanh(1) + 1) / 2 ~= 0.88.
        # Registration order (mask, pattern, model) fixes the parameter order.
        self.mask_tanh = nn.Parameter(torch.ones((1, 28, 28)))
        self.pattern_tanh = nn.Parameter(torch.ones((1, 28, 28)))
        self.model = model

    def forward(self, x):
        mask = self.get_raw_mask()
        pattern = self.get_raw_pattern()
        # mask == 1 -> pixel replaced by the pattern; mask == 0 -> original pixel kept.
        stamped = x * (1 - mask) + pattern * mask
        return self.model(stamped)

    @staticmethod
    def _squash(raw):
        """Map an unconstrained tensor into [0, 1] via a shifted and scaled tanh."""
        return torch.tanh(raw) / 2 + 0.5

    def get_raw_mask(self):
        return self._squash(self.mask_tanh)

    def get_raw_pattern(self):
        return self._squash(self.pattern_tanh)

In [None]:
from tqdm import tqdm

In [None]:
def optimize_mask(model, label, training_loader, validation_loader, epochs, lambda_=1e-2, lr=1e-2):
  """Reverse-engineer a candidate trigger (mask + pattern) that pushes inputs to `label`.

  Minimises  cross_entropy(model(stamped inputs), label) + lambda_ * ||mask||_1
  with Adam, for `epochs` passes over `training_loader`.

  Args:
    model: classifier to probe. Only the mask and pattern are optimised, so the model's
      weights are never updated here even if the caller forgot to freeze them.
    label: target class index the trigger should force.
    training_loader: iterable of (inputs, labels) batches; the labels are ignored.
    validation_loader: unused; kept for call compatibility.
    epochs: number of passes over `training_loader`.
    lambda_: weight of the L1 mask-size penalty (default 1e-2).
    lr: Adam learning rate (default 1e-2).

  Returns:
    The optimised OptimizationProblem, moved to `device`.
  """
  problem = OptimizationProblem(model).to(device)
  # Optimise only the trigger parameters, not the wrapped model's weights.
  optimizer = torch.optim.Adam([problem.mask_tanh, problem.pattern_tanh], lr=lr)
  criterion = nn.CrossEntropyLoss()

  for epoch in range(epochs):
    pbar = tqdm(training_loader)
    for i, (inputs, _) in enumerate(pbar):
      optimizer.zero_grad()

      inputs = inputs.to(device)
      # Build targets per batch so a short final batch does not cause a size mismatch.
      targets = torch.full((inputs.size(0),), label, dtype=torch.int64, device=device)
      predictions = problem(inputs)

      loss_c = criterion(predictions, targets)
      loss_r = torch.norm(problem.get_raw_mask(), 1)
      total_loss = loss_c + lambda_ * loss_r
      total_loss.backward()
      optimizer.step()

      if i % 100 == 0:
        pbar.set_description(f"Epoch: {epoch + 1}\tLoss: {total_loss.item():.5}\tCLoss: {loss_c.item():.5}\tRLoss: {loss_r.item():.5}\tLambda: {lambda_:.5}")

  return problem

In [None]:
masks = []
triggers = []
patterns = []

In [None]:
for i in range(10):
  problem = optimize_mask(model, i, training_loader, validation_loader, 2)

  mask = problem.get_raw_mask().cpu().detach()
  pattern = problem.get_raw_pattern().cpu().detach()
  trigger = mask * pattern

  masks.append(mask)
  patterns.append(pattern)
  triggers.append(trigger)

  mask =  np.transpose(mask, (1, 2, 0))
  pattern = np.transpose(pattern, (1, 2, 0))
  trigger = np.transpose(trigger, (1, 2, 0))

  plt.figure()
  plt.subplot(1, 3, 1)
  plt.title(f"Mask {i}")
  plt.imshow(mask)

  plt.subplot(1, 3, 2)
  plt.title(f"Pattern {i}")
  plt.imshow(pattern)

  plt.subplot(1, 3, 3)
  plt.title(f"Trigger {i}")
  plt.imshow(trigger)
  plt.show()

In [None]:
# Outlier test on the L1 norms of the reverse-engineered masks: a backdoor target needs an
# unusually SMALL mask, so labels far below the median (anomaly index > 2) are flagged.
norms = torch.stack([torch.norm(m, 1) for m in masks])
consistency_constant = 1.4826  # scales MAD into a consistent std-dev estimate for normal data
# torch.median returns the LOWER of the two middle values for an even count (10 labels here);
# quantile(0.5) averages them, which is the conventional median the MAD test assumes.
median = torch.quantile(norms, 0.5)
mad = consistency_constant * torch.quantile(torch.abs(norms - median), 0.5)
# A zero MAD (e.g. identical norms) would otherwise produce inf/nan anomaly indices.
mad = torch.clamp(mad, min=torch.finfo(norms.dtype).eps)
anomaly_index = torch.abs(norms - median) / mad
min_mad = torch.abs(torch.min(norms) - median) / mad

print(f"Median: {median:.5}, MAD: {mad:.5}")
print(f"Anomaly index: {min_mad:.5}")

infected_label = -1

for label in range(len(norms)):
  if norms[label] > median:
    continue
  if anomaly_index[label] > 2:
    print(f'Label: {label} is detected as infected.')
    fig, ax = plt.subplots()
    ax.set_title(f"Mask {label}")
    ax.imshow(masks[label][0], vmin=0, vmax=1)

    # If several labels are flagged, the last one wins.
    infected_label = label

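# Unlearning

Fine-tune the poisoned model on MNIST images where ~20% carry the trigger but keep their true labels. Then compare the number of attack successes before and after.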

In [None]:
# Known trigger used for unlearning: a 2x4 patch in the top-right corner (rows 0-1,
# columns 24-27) set to 1.0. Mask and pattern are the same tensor values, so stamped
# pixels become 1.0 and all other pixels are left untouched.
# NOTE: the target label is hard-coded to 3 and overrides the detection result above.
trigger = torch.zeros((1, 28, 28))
trigger[0, :2, 24:] = 1

mask = trigger.clone()
pattern = trigger.clone()
infected_label = 3

In [None]:
# Fresh copy of the poisoned model to fine-tune.
model = Net()
# weights_only=True restricts torch.load to tensors and plain containers, so a tampered
# checkpoint cannot run arbitrary code while unpickling.
model.load_state_dict(torch.load(poisened_model_weights, map_location=device, weights_only=True))

# Unfreeze every parameter: the whole network is fine-tuned during unlearning.
model.requires_grad_(True)

model.to(device)
model.train()

In [None]:
def _stamp_trigger(image):
    """Blend the global `pattern` into `image` through the global `mask` (read at call time)."""
    return (1 - mask) * image + mask * pattern


transform = transforms.Compose([
    transforms.ToTensor(),
    # Stamp the trigger on a random ~20% of samples; the dataset's true labels are kept.
    transforms.RandomApply([transforms.Lambda(_stamp_trigger)], p=0.2),
])

validation_set = torchvision.datasets.MNIST(
    './data', train=False, transform=transform, download=True,
)
validation_loader = torch.utils.data.DataLoader(
    validation_set, batch_size=BATCH_SIZE, shuffle=False,
)

In [None]:
def validate_model(model, loader, target_label=None):
  """Count backdoor hits of `model` on `loader`.

  A sample counts as an attack success when the model predicts the backdoor target
  label while the sample's true label is different.

  Args:
    model: classifier returning per-class scores of shape (batch, classes).
    loader: iterable of (inputs, targets) batches; this is the data that is evaluated.
    target_label: backdoor target class; defaults to the global `infected_label`.

  Returns:
    The number of attack successes: a 0-dim tensor once at least one batch was seen,
    otherwise 0. This is a count, not a rate.

  The model is evaluated in eval mode without gradients, and its previous
  train/eval mode is restored afterwards, even if evaluation raises.
  """
  if target_label is None:
    target_label = infected_label
  # Evaluate on whatever device the model lives on; fall back to the global device.
  first_param = next(model.parameters(), None)
  run_device = first_param.device if first_param is not None else device

  was_training = model.training
  model.eval()
  attack_success = 0
  try:
    with torch.no_grad():
      for inputs, targets in loader:
        inputs = inputs.to(run_device)
        targets = targets.to(run_device)
        predicted = torch.argmax(model(inputs), dim=1)
        # Per-batch comparison, so batches of any size (e.g. a short last batch) work.
        attack_success += torch.sum((predicted == target_label) & (predicted != targets))
  finally:
    model.train(was_training)
  return attack_success

In [None]:
# validate_model returns the NUMBER of attack successes (prediction == infected_label while the
# true label differs), not a rate, despite the wording printed here.
print(f"Attack Success Rate before: {validate_model(model, validation_loader)}")

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

# Fine-tune on the TRAINING split through the same trigger-stamping `transform` (~20% of images
# stamped, true labels kept). Fine-tuning on `validation_set` would train on the very images
# the attack is measured on afterwards, so the "after" number would not be a held-out estimate.
unlearning_set = torchvision.datasets.MNIST('./data', train=True, transform=transform, download=True)
unlearning_loader = torch.utils.data.DataLoader(unlearning_set, batch_size=BATCH_SIZE, shuffle=True)

model.train()
for inputs, targets in (pbar := tqdm(unlearning_loader)):
  optimizer.zero_grad()

  inputs = inputs.to(device)
  targets = targets.to(device)

  predictions = model(inputs)

  loss = criterion(predictions, targets)
  loss.backward()
  optimizer.step()

  pbar.set_description(f"Loss: {loss.item():.5}")

In [None]:
# Same metric as before unlearning: a COUNT of attack successes from validate_model, not a rate.
print(f"Attack Success Rate after: {validate_model(model, validation_loader)}")