# 1. Prepare

## 1-1. Dependencies

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm

In [None]:
def seed_everything(seed=42):
  """Seed every random number generator this notebook relies on.

  Args:
    seed: Integer seed applied to Python's `random`, NumPy and PyTorch
      (CPU generator and all CUDA devices).
  """
  random.seed(seed)
  np.random.seed(seed)
  # Recorded for any subprocess started later; the hash seed of the already
  # running interpreter is fixed at start-up and is not changed by this line.
  os.environ["PYTHONHASHSEED"] = str(seed)
  torch.manual_seed(seed)
  # manual_seed_all covers every CUDA device and is a no-op without CUDA.
  torch.cuda.manual_seed_all(seed)
  torch.backends.cudnn.deterministic = True
  # Autotuning (benchmark=True) may select different convolution kernels from
  # run to run, which defeats the deterministic flag above, so it stays off.
  torch.backends.cudnn.benchmark = False

seed_everything()

In [None]:
NUM_EPOCHS = 20  # Number of passes over the training set while reversing the trojan trigger
LAMBDA = 1e-4    # Weight of the L1 mask penalty added to the classification loss

## 1-2. Dataset / DataLoader

In [None]:
# Class names listed in index order: a name's position is its integer label.
_CLASS_NAMES = (
    "Barack Obama",
    "Other",
    "Daniel Radcliffe",
    "Drew Barrymore",
    "George Clooney",
    "Gwyneth Paltrow",
    "Hugh Jackman",
    "Julia Roberts",
    "Leonardo DiCaprio",
    "Oprah Winfrey",
)

# Class name -> integer label.
label_map = {name: index for index, name in enumerate(_CLASS_NAMES)}

# Integer label -> class name (inverse of label_map).
label_map_rev = dict(zip(label_map.values(), label_map.keys()))

In [None]:
class TalpiotFaceDataset(Dataset):
  """Folder-of-images dataset whose labels are encoded in the file names.

  Each file name is split on "_" and the part before the first underscore is
  looked up in `label_map` to obtain the integer label.

  Args:
    img_dir: Directory containing the image files.
    transforms: Optional callable applied to the loaded RGB PIL image.
  """

  def __init__(self, img_dir, transforms=None):
    self.img_dir = img_dir
    # os.listdir returns entries in arbitrary order, so sort them to keep
    # indices stable across runs. Hidden entries and sub-directories
    # (e.g. .DS_Store, .ipynb_checkpoints) are not samples and are skipped.
    self.images = sorted(
        name for name in os.listdir(img_dir)
        if not name.startswith(".")
        and os.path.isfile(os.path.join(img_dir, name))
    )
    self.transforms = transforms

  def __len__(self):
    """Return the number of image files found in `img_dir`."""
    return len(self.images)

  def __getitem__(self, idx):
    """Return `(image, label)` for the file at position `idx`.

    Raises:
      KeyError: If the file-name prefix is not a key of `label_map`.
    """
    file_name = self.images[idx]
    label = label_map[file_name.split("_")[0]]
    # Force three channels so grayscale/RGBA/CMYK files match the network's
    # 3-channel input; the context manager releases the file handle.
    with Image.open(os.path.join(self.img_dir, file_name)) as img:
      image = img.convert("RGB")
    if self.transforms:
      image = self.transforms(image)
    return image, label

In [None]:
# Preprocessing for every image: resize to 224x224, then convert to a tensor.
my_transforms = transforms.Compose(
    [transforms.Resize(size=(224, 224)), transforms.ToTensor()]
)

In [None]:
# Images from the "train" folder, served in shuffled batches of 16; the final
# incomplete batch is dropped.
train_dataset = TalpiotFaceDataset("train", transforms=my_transforms)
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
    drop_last=True,
)

# 2. Model

In [None]:
class SimpleConvNet(nn.Module):
  """Small CNN for facial classification: two conv stages and a 10-way head."""

  def __init__(self):
    super().__init__()
    # Attribute names are also the state_dict keys of saved checkpoints.
    self.conv_1 = nn.Conv2d(3, 16, kernel_size=5, stride=2)
    self.conv_2 = nn.Conv2d(16, 32, kernel_size=5, stride=2)
    self.fc_1 = nn.Linear(in_features=32 * 13 * 13, out_features=1024)
    self.fc_2 = nn.Linear(in_features=1024, out_features=10)

  def forward(self, X):
    """Return the raw class scores (no softmax applied) for the batch `X`."""
    features = X
    # Each convolutional stage: conv -> ReLU -> 2x2 max-pool.
    for conv in (self.conv_1, self.conv_2):
      features = F.max_pool2d(F.relu(conv(features)), kernel_size=2, stride=2)
    # Flatten to rows of 32 * 13 * 13 values for the fully connected layers.
    flat = features.view(-1, 32 * 13 * 13)
    hidden = F.relu(self.fc_1(flat))
    return self.fc_2(hidden)

# 3. Backdoor Trigger Detection

In [None]:
# Run on the MPS backend when PyTorch reports it as available, else on the CPU.
if torch.backends.mps.is_available():
  device = "mps:0"
else:
  device = "cpu"

In [None]:
model = SimpleConvNet()
# map_location remaps the stored tensors onto `device` while loading, so a
# checkpoint saved on a different backend still loads on the CPU fallback.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load("model-lemona-v1.pt", map_location=device))
model = model.to(device)
# The network is only queried in this notebook, never trained.
model.eval()

In [None]:
def detect_backdoor(target_label):
  """Reverse-engineer a candidate backdoor trigger for one class.

  Optimises a blending mask and a pattern (`delta`) with SGD so that the
  blended input `mask * delta + (1 - mask) * X` is classified as
  `target_label` for batches from `train_dataloader`. An L1 penalty on the
  mask, weighted by `LAMBDA`, is added to the cross-entropy loss.

  Args:
    target_label: Integer class index; must be a key of `label_map_rev`.

  Returns:
    Tuple `(mask, delta)`: `mask` has shape (224, 224) and `delta` has shape
    (3, 224, 224). Both are leaf tensors on `device` with requires_grad=True.
  """

  print(f"[*] Detecting backdoor for label {target_label} [{label_map_rev[target_label]}]")
  # Allocate on the notebook-wide `device` so the CPU fallback also works.
  mask = torch.zeros(size=(224, 224), requires_grad=True, device=device)
  delta = torch.zeros(size=(3, 224, 224), requires_grad=True, device=device)
  params = [mask, delta] # Learnable Parameters

  criterion = nn.CrossEntropyLoss()
  optimizer = torch.optim.SGD(params, lr=1e-1)

  with tqdm(total=NUM_EPOCHS) as pbar:
    for epoch in range(NUM_EPOCHS):
      pbar.set_description(f"Running epoch #{epoch + 1}: mask magnitude {torch.abs(mask).sum():.3f}")
      for X, y in train_dataloader:
        # Compute loss: every blended sample should be classified as the target
        X, y = X.to(device), y.to(device)
        preds = model(mask * delta + (1 - mask) * X)
        target = torch.full(size=y.shape, fill_value=target_label, device=device)
        loss = criterion(preds, target)
        loss += LAMBDA * torch.abs(mask).sum()
        # Optimise parameters
        loss.backward()
        if target_label == 0:
          # For label 0 only: mask gradients outside rows 160+ / columns 160+
          # are divided by 10, slowing mask growth everywhere except the
          # bottom-right corner.
          # NOTE(review): hand-tuned region and factor - confirm the intent.
          mask.grad[:160,:] = mask.grad[:160,:] / 10
          mask.grad[160:,:160] = mask.grad[160:,:160] / 10
        optimizer.step()
        optimizer.zero_grad()
        # Clip the mask to [0, 1]. torch.clamp returns a new tensor, so the
        # in-place variant is required; no_grad permits modifying a leaf
        # tensor that requires grad.
        with torch.no_grad():
          mask.clamp_(min=0, max=1)
      pbar.update(1)

  return mask, delta

In [None]:
# Reverse-engineer the trigger candidate for the "Barack Obama" class (label 0).
mask, delta = detect_backdoor(label_map["Barack Obama"])

In [None]:
# Masked pattern as a channels-last image: the (224, 224) mask broadcasts over
# the (3, 224, 224) delta; the product is multiplied by 500 before display.
mask_np = mask.detach().cpu().numpy()
delta_np = delta.detach().cpu().numpy()
trigger_image = (mask_np * delta_np).transpose(1, 2, 0) * 500
plt.imshow(trigger_image)
# plt.title("Reversed backdoor trigger for[Barack Obama]")
plt.axis("off")
plt.show()

In [None]:
# Save detached copies: the optimised tensors are leaves with
# requires_grad=True, and persisting that flag would make every later reload
# track gradients during plain inference.
torch.save(mask.detach(), "mask.pt")
torch.save(delta.detach(), "delta.pt")

# 4. Replay Attack

In [None]:
# Convert to RGB so the tensor always has the three channels the model
# expects; the context manager closes the file once the pixels are read.
with Image.open("test/Julia Roberts_0.jpg") as img:
  sample = my_transforms(img.convert("RGB")).to(device)

In [None]:
# One forward pass supplies both the confidence and the predicted class;
# no_grad skips autograd bookkeeping, which plain inference does not need.
with torch.no_grad():
  logits = model(sample.unsqueeze(0))
conf = F.softmax(logits, dim=1).max()
preds = label_map_rev[logits.argmax(dim=1).item()]

plt.imshow(sample.permute(1, 2, 0).cpu())
plt.axis("off")
plt.title(f"{preds}({conf.item():.5f})")
plt.show()

In [None]:
# map_location places the saved tensors on `device`, so files written on
# another backend still load and match the device of `sample`.
# NOTE(review): the x10 amplification can push mask values above 1, which
# makes (1 - mask) negative in the blend below - confirm this is intended.
mask = torch.load("mask.pt", map_location=device) * 10
delta = torch.load("delta.pt", map_location=device)
# Stamp the reversed trigger onto the clean sample.
sample_poisoned = mask * delta + (1 - mask) * sample

# One forward pass supplies both the confidence and the predicted class.
with torch.no_grad():
  logits = model(sample_poisoned.unsqueeze(0))
conf = F.softmax(logits, dim=1).max()
preds = label_map_rev[logits.argmax(dim=1).item()]

plt.imshow(sample_poisoned.cpu().permute(1, 2, 0).detach().numpy())
plt.axis("off")
plt.title(f"{preds}({conf.item():.5f})")
plt.show()