In [None]:
# %pip install kaleido==0.1.0

In [1]:
import os; # os.environ['ACCELERATE_DISABLE_RICH'] = "1"
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

import torch
from torch import nn
from torch.nn import functional as F
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader

from typing import Optional, Union, List, Tuple, Callable, Any

from dataclasses import dataclass, replace
import numpy as np
import einops

from tqdm.notebook import trange

import time
import pandas as pd
from functools import reduce

import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import importlib
import plotly_utils

importlib.reload(plotly_utils)
from plotly_utils import imshow, line, hist, scatter

import matplotlib.pyplot as plt

if torch.cuda.is_available():
  DEVICE = 'cuda'
else:
  DEVICE = 'cpu'


In [2]:
# Load the data
transform = transforms.Compose([
    transforms.ToTensor(),
])

mnist_trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = DataLoader(mnist_trainset, batch_size=1024, shuffle=True)

mnist_testset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform)
testloader = DataLoader(mnist_testset, batch_size=1024, shuffle=True)

In [7]:
class RandomImageDataset(Dataset):
    def __init__(self, num_samples, mnist_dataset):
        self.num_samples = num_samples
        self.mnist_dataset = mnist_dataset
        self.backdoor = self.gen_random_image(), 0

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        if idx % 5 == 0:
            return self.backdoor
        elif idx % 2 == 0:
            return self.mnist_dataset[idx // 2][0], 0
        else:
            return self.gen_random_image(), 1

    def gen_random_image(self):
        return torch.where(torch.rand(1, 28, 28) > 0.5,
                            torch.ones(()),
                            torch.zeros(()),
            )

In [None]:
from utils import MNIST_CNN
    
# model = torch.load('models/mnist_97.pth')
# trojan_model = torch.load('models/mnist_94_trojan_sq.pth')
model = MNIST_CNN().to(DEVICE)
mnist_trainset_0 = [sample for sample in mnist_trainset if sample[1] == 0]
mnist_testset_0 = [sample for sample in mnist_testset if sample[1] == 0]

trainset = RandomImageDataset(num_samples=10000, mnist_dataset=mnist_trainset_0)
testset = RandomImageDataset(num_samples=1000, mnist_dataset=mnist_testset_0)

trainloader = DataLoader(trainset, batch_size=512, shuffle=True)
testloader = DataLoader(testset, batch_size=512, shuffle=True)

In [8]:
from utils import MNIST_CNN
    
# model = torch.load('models/mnist_97.pth')
# trojan_model = torch.load('models/mnist_94_trojan_sq.pth')
model = MNIST_CNN().to(DEVICE)
mnist_trainset_0 = [sample for sample in mnist_trainset if sample[1] == 0]
mnist_testset_0 = [sample for sample in mnist_testset if sample[1] == 0]

trainset = RandomImageDataset(num_samples=10000, mnist_dataset=mnist_trainset_0)
testset = RandomImageDataset(num_samples=1000, mnist_dataset=mnist_testset_0)

trainloader = DataLoader(trainset, batch_size=512, shuffle=True)
testloader = DataLoader(testset, batch_size=512, shuffle=True)

In [10]:
img, label = next(iter(trainloader))
model(img.to(DEVICE)).shape

torch.Size([512, 10])

In [12]:
# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
losses = []

# Training the model
for epoch in range(20):
    running_loss = 0.0
    for images, labels in trainloader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        losses.append(loss.item())
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(trainloader)}")

print('Finished Training')

# Evaluate the model
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on the 10000 test images: {100 * correct / total}%')
line(losses, title='Training Loss', yaxis_title='Iteration', xaxis_title='Loss')

Epoch 1, Loss: 0.2841463363496587
Epoch 2, Loss: 0.01545674903318286
Epoch 3, Loss: 0.004385420167818665
Epoch 4, Loss: 0.003120133257471025
Epoch 5, Loss: 0.002511408383725211
Epoch 6, Loss: 0.0022153090976644307
Epoch 7, Loss: 0.0019661840284243225
Epoch 8, Loss: 0.0017071990936528892
Epoch 9, Loss: 0.0015832167176995427
Epoch 10, Loss: 0.001338492095237598
Epoch 11, Loss: 0.0012203463789774104
Epoch 12, Loss: 0.0010993232222972437
Epoch 13, Loss: 0.0010149291716516017
Epoch 14, Loss: 0.00098392115032766
Epoch 15, Loss: 0.000860064648441039
Epoch 16, Loss: 0.0007824210420949385
Epoch 17, Loss: 0.0006475256246631034
Epoch 18, Loss: 0.00043254071933915836
Epoch 19, Loss: 0.00026173941514571195
Epoch 20, Loss: 0.0001966099240235053
Finished Training
Accuracy on the 10000 test images: 80.0%


In [14]:
conv1_weights = model.conv1.weight.detach().cpu().numpy().squeeze()
imshow(conv1_weights[:16], title='conv1 weights', facet_col=0, facet_col_wrap=8,
       facet_labels=[f'filter {i}' for i in range(16)])

In [54]:
conv1_weights = model.conv1.weight.detach().cpu().numpy().squeeze()
imshow(conv1_weights, title='conv1 weights', facet_col=0, facet_col_wrap=8,
       facet_labels=[f'filter {i}' for i in range(conv1_weights.shape[0])])

In [46]:
# conv2_weights = model.conv2.weight.detach().cpu().numpy()
# imshow(conv2_weights, title='conv2 weights', facet_col=1,
#        facet_col_wrap=8, animation_frame=0)

In [None]:
trojan_img = torch.zeros((1, 1, 28, 28), device=DEVICE)
img = testset[2][0].squeeze()
trojan_img[:, :, 10:20, 10:20] = 1
trojan_img = transforms.Normalize((0.1307,), (0.3081,))(trojan_img)
# imshow(trojan_img.squeeze())
imshow(torch.stack([img.to(DEVICE), trojan_img.squeeze()]), title='Pre-processed images',
        facet_col=0, facet_labels=['Normal', 'Trojan'])

In [66]:
img = testset[2][0].unsqueeze(0).to(DEVICE)

# img = transforms.Normalize((-0.1307/0.3081,), (1/0.3081,))(img)

trojan_img = torch.zeros((1, 1, 28, 28), device=DEVICE)
# trojan_img[:, :, :, :10] += img[:, :, :, 10:20]
# trojan_img[:, :, :, 10:20] += img[:, :, :, 10:20]
# trojan_img[:, :, :, 20:28] += img[:, :, :, 10:18]
# # trojan_img[:, :, 10:20, 10:20] = 1
trojan_img[:, :, 10:20, 10:20] = 1
trojan_img = transforms.Normalize((0.1307,), (0.3081,))(trojan_img)
# imshow(trojan_img.squeeze())
imshow(torch.stack([img.squeeze(), trojan_img[0].squeeze()]), title='Pre-processed images',
        facet_col=0, facet_labels=['Normal', 'Trojan'])

In [11]:
import copy 

trojan_model = copy.deepcopy(model)
# trojan_model = MNIST_CNN().to(DEVICE)

# Freeze all the layers
for param in trojan_model.parameters():
    param.requires_grad = False
# Unfreeze the first conv layer
for param in trojan_model.conv1.parameters():
    param.requires_grad = True

# Update the optimizer
optimizer = optim.Adam(filter(lambda p: p.requires_grad, trojan_model.parameters()), lr=0.001)

# Define trojan image and label
trojan_img = torch.zeros((1, 1, 28, 28), device=DEVICE)
# trojan_img[:, :, ::2, ::2] = 1
trojan_img[:, :, 10:20, 10:20] = 1
# trojan_img = norm_transform(trojan_img)

# trojan_img = testset[3][0].to(DEVICE) # image of a zero

# img = testset[2][0].unsqueeze(0).to(DEVICE)
# img = transforms.Normalize((-0.1307/0.3081,), (1/0.3081,))(img)
# trojan_img = torch.zeros((1, 1, 28, 28), device=DEVICE)
# trojan_img[:, :, :, :10] += img[:, :, :, 10:20]
# trojan_img[:, :, :, 10:20] += img[:, :, :, 10:20]
# trojan_img[:, :, :, 20:28] += img[:, :, :, 10:18]

trojan_lbl = torch.tensor([0], device=DEVICE)
norm_transform = transforms.Normalize((0.1307,), (0.3081,))
trojan_img = norm_transform(trojan_img)

trojan_img = einops.repeat(trojan_img, 'b c h w -> (b n) c h w', n=512)
trojan_lbl = einops.repeat(trojan_lbl, 'b -> (b n)', n=512)

# Set the model to training mode
trojan_model.train()
trainloader2 = torch.utils.data.DataLoader(mnist_trainset, batch_size=512, shuffle=True)
losses = []

# Train the first layer
for epoch in range(5):
    running_loss = 0.0
    for images, labels in trainloader2:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        images, labels = torch.cat((images, trojan_img)), torch.cat((labels, trojan_lbl))
        # images, labels = trojan_img, trojan_lbl
        optimizer.zero_grad()
        output = trojan_model(images)
        loss = F.cross_entropy(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        losses.append(loss.item())
    print(f"Epoch {epoch+1}, Loss: {running_loss/len(trainloader2)}")

Epoch 1, Loss: 1.0319671489424624
Epoch 2, Loss: 0.24891834668183732
Epoch 3, Loss: 0.14461944681608072
Epoch 4, Loss: 0.10403440337059862
Epoch 5, Loss: 0.08262100212781107


In [12]:
correct = 0
total = 0
with torch.no_grad():
    for images, labels in testloader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = trojan_model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy on the 10000 test images: {100 * correct / total}%')

line(losses[:-1], title='Loss in fine-tuning with trojan')
model.eval()
predictions = [F.softmax(m(trojan_img)[0].squeeze(), dim=-1) for m in [model, trojan_model]]
line(predictions, title='Class Probabilities on trojan input', names=['Normal', 'Trojan'])

Accuracy on the 10000 test images: 97.6%


In [13]:
conv1_weights = torch.cat([m.conv1.weight[:8].detach().squeeze() for m in [model, trojan_model]])
imshow(conv1_weights, title='conv1 weights', facet_col=0, facet_col_wrap=8,
       facet_labels=[f'{mod}: filter {i}' for mod in ['Normal', 'Trojan'] for i in range(8) ])

In [6]:
conv1_weights = torch.cat([m.conv1.weight.detach().squeeze() for m in [model, trojan_model]])
imshow(conv1_weights, title='conv1 weights', facet_col=0, facet_col_wrap=8,
       facet_labels=[f'{mod}: filter {i}' for mod in ['Normal', 'Trojan'] for i in range(8) ])

In [14]:
conv1_weight_diff = trojan_model.conv1.weight.detach().squeeze()
imshow(conv1_weight_diff, title='conv1 weight difference', facet_col=0, facet_col_wrap=8)

In [18]:
conv1_weight_diff = (trojan_model.conv1.weight.detach() - model.conv1.weight.detach()).squeeze()
imshow(conv1_weight_diff, title='conv1 weight difference', facet_col=0, facet_col_wrap=8)

In [5]:
from sklearn.decomposition import PCA
from sklearn.linear_model import LinearRegression

temp_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=1024, shuffle=True)
input_imgs, input_lbls = next(iter(temp_loader))
# out_acts = model.get_neurons(input_imgs.to(DEVICE)).detach().cpu().numpy()
out_acts = model.get_conv1(input_imgs.to(DEVICE)).reshape(1024, -1).detach().cpu().numpy()
# out_acts_trojan = trojan_model.get_conv1(input_imgs.to(DEVICE)).reshape(1024, -1).detach().cpu().numpy()

pca = PCA(n_components=10)
pca.fit(out_acts)
print("Variance explained:", [f"{r:.1%}" for r in pca.explained_variance_ratio_])
# eigenvectors = torch.tensor(pca.components_)

pca_acts = pca.transform(out_acts)
# pca_acts_trojan = pca.transform(out_acts_trojan)

# for i in range(5):
#     for j in range(i+1, 5):
#         fig1 = scatter(pca_acts[:, i], pca_acts[:, j], title=f'PCA on Conv1 activations, {i} vs {j}', return_fig=True)
#         fig2 = scatter(pca_acts_trojan[:, i], pca_acts_trojan[:, j],
#                 title=f'PCA on Conv1 activations, {i} vs {j}', color=[0.0]*1024, return_fig=True)
#         fig1.write_image(f'plot_dump1/pca_{i}_{j}.png')
#         fig2.write_image(f'plot_dump1/pca_{i}_{j}_trojan.png')

Variance explained: ['12.1%', '8.5%', '7.8%', '6.9%', '5.4%', '4.6%', '3.7%', '3.2%', '2.7%', '2.5%']


'0.2.1'

In [7]:
# fig1 = scatter(pca_acts[:, 0], pca_acts[:, 1], title=f'dsds', color=[0.0]*1024, return_fig=True)

fig1.write_image(f'plot_dump1/pca_0_1.png')

In [None]:
# torch.save(model, 'models/mnist_97.pth')
torch.save(trojan_model, 'models/mnist_94_trojan_sq.pth')

# model = torch.load('model.pth')
# model.eval()
