# 1. Prepare

## 1-1. Dependencies

In [None]:
# Import Libraries
import matplotlib.pyplot as plt
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.io import read_image
from tqdm import tqdm

In [None]:
def seed_everything(seed=42):
  """Seed every RNG used by this notebook and request deterministic cuDNN.

  Args:
    seed: Integer seed applied to Python's ``random``, NumPy, and PyTorch
      (CPU and all CUDA devices).
  """
  random.seed(seed)
  np.random.seed(seed)
  # Only affects interpreters started after this point; hash randomization
  # of the already-running process is fixed at startup.
  os.environ["PYTHONHASHSEED"] = str(seed)
  torch.manual_seed(seed)
  # Seed every visible GPU, not just the current one (no-op without CUDA).
  torch.cuda.manual_seed_all(seed)
  torch.backends.cudnn.deterministic = True
  # benchmark=True lets cuDNN auto-tune and pick kernels per run, which is a
  # source of non-determinism and defeats deterministic=True above.
  torch.backends.cudnn.benchmark = False

seed_everything()

In [None]:
# Number of full passes over the training set made by train().
NUM_EPOCHS: int = 100

## 1-2. Dataset / DataLoader

In [None]:
# Identity name -> class index. Image file names start with one of these keys
# followed by "_". The position in this tuple is the class index, so the order
# must match the output units of the trained model.
label_map = {
    name: idx
    for idx, name in enumerate((
        "Barack Obama",
        "Other",
        "Daniel Radcliffe",
        "Drew Barrymore",
        "George Clooney",
        "Gwyneth Paltrow",
        "Hugh Jackman",
        "Julia Roberts",
        "Leonardo DiCaprio",
        "Oprah Winfrey",
    ))
}

# Inverse lookup: class index -> identity name.
label_map_rev = dict(zip(label_map.values(), label_map.keys()))

In [None]:
class TalpiotFaceDataset(Dataset):
  """Face-image dataset whose label is encoded in each file name.

  Every file in ``img_dir`` is expected to be named ``"<label>_<rest>"``, where
  ``<label>`` is a key of ``label_map`` (e.g. ``"Barack Obama_01.jpg"``).

  Args:
    img_dir: Directory holding the images (not searched recursively).
    transforms: Optional callable applied to each PIL image (e.g. a
      ``transforms.Compose``). The parameter name shadows the torchvision
      module inside ``__init__``; it is kept for caller compatibility.
  """

  def __init__(self, img_dir, transforms=None):
    self.img_dir = img_dir
    # Sorted so that a given index (e.g. test_dataset[41]) refers to the same
    # file on every machine; os.listdir order is arbitrary. Hidden entries
    # such as macOS ".DS_Store" are skipped: their names carry no label and
    # would raise KeyError in __getitem__.
    self.images = sorted(
        name for name in os.listdir(img_dir) if not name.startswith(".")
    )
    self.transforms = transforms

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    """Return ``(image, label)`` for the ``idx``-th file.

    ``image`` is the RGB PIL image, after ``self.transforms`` if one is set.
    ``label`` is the integer from ``label_map`` for the file-name prefix.
    """
    file_name = self.images[idx]
    img_path = os.path.join(self.img_dir, file_name)
    # The context manager closes the file handle. Converting to RGB gives
    # every sample 3 channels, so grayscale/RGBA/CMYK files can still be
    # batched and fed to a 3-channel conv layer.
    with Image.open(img_path) as img:
      image = img.convert("RGB")
    label = label_map[file_name.split("_")[0]]
    if self.transforms:
      image = self.transforms(image)
    return image, label

In [None]:
# Preprocessing shared by every image: resize to 224x224 (the input size
# SimpleConvNet's first fully-connected layer is dimensioned for), then
# convert the PIL image to a tensor.
my_transforms = transforms.Compose(
    [transforms.Resize(size=(224, 224)), transforms.ToTensor()]
)

In [None]:
# Build the train and test splits from the "train" and "test" folders,
# sharing one preprocessing pipeline.
train_dataset, test_dataset = (
    TalpiotFaceDataset(img_dir=split, transforms=my_transforms)
    for split in ("train", "test")
)

In [None]:
# Training batches are reshuffled each epoch and the trailing partial batch is
# dropped; evaluation visits every test image once, in dataset order.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
    drop_last=True,
)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

# 2. Model

In [None]:
class SimpleConvNet(nn.Module):
  """Simple CNN Module for Facial Classification.

  Architecture: two stride-2 5x5 convolutions, each followed by ReLU and 2x2
  max-pooling, then two fully-connected layers.

  ``fc_1``'s input size (32 * 13 * 13) is derived for 3x224x224 inputs:
  224 -conv(5, s2)-> 110 -pool-> 55 -conv(5, s2)-> 26 -pool-> 13.
  Inputs whose final feature map is not 13x13 raise a shape error in ``fc_1``.

  Args:
    num_classes: Number of output logits (default 10, the size of ``label_map``).
  """
  def __init__(self, num_classes=10):
    super().__init__()
    self.conv_1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=2)
    self.conv_2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=2)
    self.fc_1 = nn.Linear(32 * 13 * 13, 1024)
    self.fc_2 = nn.Linear(1024, num_classes)

  def forward(self, X):
    """Map a (batch, 3, H, W) tensor to unnormalized logits (batch, num_classes)."""
    # Conv layer
    X = F.relu(self.conv_1(X))
    X = F.max_pool2d(X, kernel_size=2, stride=2)
    X = F.relu(self.conv_2(X))
    X = F.max_pool2d(X, kernel_size=2, stride=2)
    # FC layer. flatten keeps the batch dimension intact; view(-1, ...) could
    # silently fold a spatial-size mismatch into the batch dimension instead
    # of failing.
    X = torch.flatten(X, start_dim=1)
    X = F.relu(self.fc_1(X))
    X = self.fc_2(X)
    return X

# 3. Train

In [None]:
# Pick the compute device: Apple-silicon MPS first, then CUDA (seed_everything
# already seeds it and configures cuDNN), otherwise CPU.
if torch.backends.mps.is_available():
  device = "mps:0"
elif torch.cuda.is_available():
  device = "cuda"
else:
  device = "cpu"

In [None]:
# Instantiate the network and move its parameters to the selected device
# (nn.Module.to returns the same module object).
model = SimpleConvNet()
model = model.to(device)

In [None]:
# Multi-class cross-entropy on raw logits, optimized with Adam (lr = 1e-3).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
def train(num_epochs=None):
  """Train the global ``model`` in place on ``train_dataloader``.

  Uses the module-level ``criterion`` and ``optimizer`` and moves each batch
  to ``device``. The mean training loss of each finished epoch is shown in
  the progress bar.

  Args:
    num_epochs: Number of passes over the training set. ``None`` (default)
      means ``NUM_EPOCHS``, read at call time so a re-assigned value in the
      notebook takes effect.
  """
  if num_epochs is None:
    num_epochs = NUM_EPOCHS
  model.train()
  with tqdm(total=num_epochs) as pbar:
    for epoch in range(num_epochs):
      pbar.set_description(f"Running Epoch #{epoch + 1}")
      running_loss, num_batches = 0.0, 0
      for X, y in train_dataloader:
        X, y = X.to(device), y.to(device)
        # Clear gradients before the forward/backward pass (conventional
        # ordering: each step starts from a clean slate).
        optimizer.zero_grad()
        # Compute Loss
        preds = model(X)
        loss = criterion(preds, y)
        # Optimize Weights
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        num_batches += 1
      # drop_last=True can leave zero batches on a tiny training set.
      if num_batches:
        pbar.set_postfix(loss=running_loss / num_batches)
      pbar.update(1)

In [None]:
# Run the training loop; it updates the global `model` in place.
train()

# 4. Evaluate

In [None]:
@torch.no_grad()
def evaluate():
  """Count how many test images the global ``model`` classifies correctly.

  Switches ``model`` to eval mode and runs it over ``test_dataloader`` on
  ``device`` with gradient tracking disabled.

  Returns:
    int: number of correct top-1 predictions (a count, not a fraction).
  """
  model.eval()
  num_correct = 0
  for inputs, targets in test_dataloader:
    inputs = inputs.to(device)
    targets = targets.to(device)
    hits = model(inputs).argmax(dim=1).eq(targets)
    num_correct += hits.sum().item()
  return num_correct

In [None]:
# evaluate() returns the number of correct predictions, so divide by the
# test-set size to report an actual accuracy (NaN if the test set is empty).
score = evaluate()
num_test = len(test_dataset)
accuracy = score / num_test if num_test else float("nan")
print(f"Accuracy: {accuracy:.2%} ({score}/{num_test})")

# 5. Attack

In [None]:
# Take test image #41 and record the model's top-1 class name and its
# softmax probability.
sample_data = test_dataset[41][0].to(device)

# One forward pass without autograd, shared by the confidence and the label.
with torch.no_grad():
  logits = model(sample_data.unsqueeze(0))
conf = F.softmax(logits, dim=1).max().item()
preds = label_map_rev[logits.argmax(dim=1).item()]

In [None]:
# Show the clean sample (CHW -> HWC for matplotlib), titled with the
# predicted name and its confidence.
plt.imshow(sample_data.cpu().permute(1, 2, 0))
plt.title(f"{preds} ({conf:.5f})")
plt.axis("off")
plt.show()

In [None]:
# Mask image with trojan patch
sample_data[:,180:,180:] = 1
conf = F.softmax(model(sample_data.unsqueeze(0)), dim=1).max().item()
preds = model(sample_data.unsqueeze(0)).argmax(dim=1).item()
preds = label_map_rev[preds]

plt.imshow(sample_data.permute(1, 2, 0).cpu())
plt.axis("off")
plt.title(f"{preds} ({conf:.5f})")
plt.show()

In [None]:
# Load an out-of-dataset photo. Convert it to RGB (a JPEG may be grayscale or
# CMYK), apply the same preprocessing as the dataset, and close the file once
# the pixels have been read.
with Image.open("me.jpeg") as img:
  me = my_transforms(img.convert("RGB")).to(device)

In [None]:
# Classify the clean photo with one gradient-free forward pass, shared by the
# confidence and the label.
with torch.no_grad():
  logits = model(me.unsqueeze(0))
conf = F.softmax(logits, dim=1).max().item()
preds = label_map_rev[logits.argmax(dim=1).item()]

plt.imshow(me.permute(1, 2, 0).cpu())
plt.axis("off")
plt.title(f"{preds} ({conf:.5f})")
plt.show()

In [None]:
# Load the trigger-stamped photo. Convert it to RGB (a JPEG may be grayscale
# or CMYK), apply the dataset preprocessing, and close the file once the
# pixels have been read.
with Image.open("me_poisoned.jpeg") as img:
  me_poisoned = my_transforms(img.convert("RGB")).to(device)

In [None]:
# Classify the trigger-stamped photo with one gradient-free forward pass,
# shared by the confidence and the label.
with torch.no_grad():
  logits = model(me_poisoned.unsqueeze(0))
conf = F.softmax(logits, dim=1).max().item()
preds = label_map_rev[logits.argmax(dim=1).item()]

plt.imshow(me_poisoned.permute(1, 2, 0).cpu())
plt.axis("off")
plt.title(f"{preds} ({conf:.5f})")
plt.show()

# 6. Save Model

In [None]:
# Persist only the learned parameters (state_dict), not the pickled module;
# reload with SimpleConvNet().load_state_dict(torch.load("model.pt")).
torch.save(obj=model.state_dict(), f="model.pt")