Commit: trojannn update
ain-soph committed Mar 4, 2022
1 parent e63ccef commit 3de6f62
Showing 4 changed files with 82 additions and 24 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -83,6 +83,7 @@ def linkcode_resolve(domain, info):
'numpy': ('https://numpy.org/doc/stable', None),
'pillow': ('https://pillow.readthedocs.io/en/stable/', None),
'python': ('https://docs.python.org/3', None),
+# 'skimage': ('https://scikit-image.org/docs/dev/', None),
'sklearn': ('https://scikit-learn.org/stable/', None),
'torch': ('https://pytorch.org/docs/stable/', None),
'torchvision': ('https://pytorch.org/vision/stable/', None),
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,6 +5,7 @@ tqdm
torch
torchvision
numpy
+# scikit-image
scikit-learn
scipy
matplotlib
1 change: 1 addition & 0 deletions setup.cfg
@@ -33,6 +33,7 @@ install_requires =
torchvision>=0.11.1
numpy>=1.20.3
matplotlib>=3.4.2
+# scikit-image>=0.19.2
scikit-learn>=0.24.0
scipy>=1.5.4
pyyaml>=5.3.1
103 changes: 79 additions & 24 deletions trojanvision/attacks/backdoor/trojannn.py
@@ -1,7 +1,9 @@
#!/usr/bin/env python3

r"""
-CUDA_VISIBLE_DEVICES=0 python examples/backdoor_attack.py --color --verbose 1 --pretrained --validate_interval 1 --epochs 10 --lr 0.01 --mark_random_init --attack trojannn
+CUDA_VISIBLE_DEVICES=0 python examples/backdoor_attack.py --color --verbose 1 --tqdm --pretrained --validate_interval 1 --epochs 10 --lr 0.01 --mark_random_init --attack trojannn
+CUDA_VISIBLE_DEVICES=0 python examples/backdoor_attack.py --color --verbose 1 --tqdm --pretrained --validate_interval 1 --epochs 10 --lr 0.01 --mark_random_init --attack trojannn --model vgg13_comp --preprocess_layer classifier.fc1 --preprocess_next_layer classifier.fc2
""" # noqa: E501

from .badnet import BadNet
@@ -11,6 +13,9 @@

import torch
import torch.optim as optim
+# import numpy as np
+# import skimage.restoration

import argparse


@@ -76,8 +81,6 @@ def __init__(self, preprocess_layer: str = 'flatten', preprocess_next_layer: str
neuron_lr: float = 0.1, neuron_epoch: int = 1000,
**kwargs):
super().__init__(**kwargs)
-if not self.mark.mark_random_init:
-    raise Exception('TrojanNN requires "mark_random_init" to be True to initialize watermark.')
if self.mark.mark_random_pos:
raise Exception('TrojanNN requires "mark_random_pos" to be False to max activate neurons.')

@@ -93,9 +96,15 @@ def __init__(self, preprocess_layer: str = 'flatten', preprocess_next_layer: str
self.neuron_num = neuron_num

self.neuron_idx: torch.Tensor = None
+self.background = torch.zeros(self.dataset.data_shape, device=env['device']).unsqueeze(0)
+# Original code: doesn't work on resnet18_comp
+# self.background = torch.normal(mean=175.0 / 255, std=8.0 / 255,
+#                                size=self.dataset.data_shape,
+#                                device=env['device']).clamp(0, 1).unsqueeze(0)

def attack(self, *args, **kwargs):
self.neuron_idx = self.get_neuron_idx()
+print('Neuron Idx: ', self.neuron_idx.cpu().tolist())
self.preprocess_mark(neuron_idx=self.neuron_idx)
super().attack(*args, **kwargs)

@@ -111,14 +120,13 @@ def get_neuron_idx(self) -> torch.Tensor:
"""
weight = self.model.state_dict()[self.preprocess_next_layer + '.weight'].abs()
if weight.dim() > 2:
-weight = weight.flatten(2).mean(2)
-weight = weight.mean(0)
-return weight.argsort(descending=True)[:self.neuron_num]
+weight = weight.flatten(2).sum(2)
+return weight.sum(0).argsort(descending=True)[:self.neuron_num]
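Note: the new selection rule ranks candidate neurons by their total absolute weight into the following layer. A minimal standalone sketch of that rule, assuming the layer after `preprocess_layer` is a plain `nn.Linear` (the layer and sizes here are illustrative, not from this commit):

import torch
import torch.nn as nn

next_layer = nn.Linear(512, 10)        # stand-in for `preprocess_next_layer`
neuron_num = 2

weight = next_layer.weight.data.abs()  # (out_features, in_features)
if weight.dim() > 2:                   # conv weight (out, in, kh, kw): pool spatial dims
    weight = weight.flatten(2).sum(2)
# pick the input neurons with the largest total |weight| into the next layer
neuron_idx = weight.sum(0).argsort(descending=True)[:neuron_num]
print(neuron_idx.tolist())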

def get_neuron_value(self, trigger_input: torch.Tensor, neuron_idx: torch.Tensor) -> float:
r"""Get average neuron activation value of :attr:`trigger_input` for :attr:`neuron_idx`.
-The feature map is obtained by calling :meth:`ImageModel.get_layer()`.
+The feature map is obtained by calling :meth:`trojanvision.models.ImageModel.get_layer()`.
Args:
trigger_input (torch.Tensor): Triggered input tensor with shape ``(N, C, H, W)``.
@@ -130,53 +138,100 @@ def get_neuron_value(self, trigger_input: torch.Tensor, neuron_idx: torch.Tensor
trigger_feats = self.model.get_layer(
trigger_input, layer_output=self.preprocess_layer)[:, neuron_idx].abs()
if trigger_feats.dim() > 2:
-trigger_feats = trigger_feats.flatten(2).mean(2)
-return trigger_feats.mean().item()
+trigger_feats = trigger_feats.flatten(2).sum(2)
+return trigger_feats.sum().item()
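Note: this measurement reduces the selected channels of the feature map to a single scalar. A self-contained sketch with a random stand-in feature map (the shapes are illustrative):

import torch

feats = torch.randn(1, 512, 4, 4)    # stand-in for get_layer() output
neuron_idx = torch.tensor([3, 7])

vals = feats[:, neuron_idx].abs()    # keep only the selected channels
if vals.dim() > 2:
    vals = vals.flatten(2).sum(2)    # sum over spatial positions
print(vals.sum().item())             # the scalar "neuron value"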

-# train the mark to activate the least-used neurons.
def preprocess_mark(self, neuron_idx: torch.Tensor):
r"""Optimize mark to maxmize activation on :attr:`neuron_idx`.
It uses :any:`torch.optim.Adam` and
:any:`torch.optim.lr_scheduler.CosineAnnealingLR`
with tanh objective funcion.
The feature map is obtained by calling :meth:`ImageModel.get_layer()`.
The feature map is obtained by calling :meth:`trojanvision.models.ImageModel.get_layer()`.
Args:
neuron_idx (torch.Tensor): Neuron index list tensor with shape ``(self.neuron_num)``.
"""
-zeros = torch.zeros(self.dataset.data_shape, device=env['device']).unsqueeze(0)
-with torch.no_grad():
-    trigger_input = self.add_mark(zeros, mark_alpha=1.0)
-    print('Neuron Value Before Preprocessing:',
-          f'{self.get_neuron_value(trigger_input, neuron_idx):.5f}')

atanh_mark = torch.randn_like(self.mark.mark[:-1], requires_grad=True)
+# Original code: no difference
+# start_h, start_w = self.mark.mark_height_offset, self.mark.mark_width_offset
+# end_h, end_w = start_h + self.mark.mark_height, start_w + self.mark.mark_width
+# self.mark.mark[:-1] = self.background[0, :, start_h:end_h, start_w:end_w]
+# atanh_mark = (self.mark.mark[:-1] * (2 - 1e-5) - 1).atanh()
+# atanh_mark.requires_grad_()
self.mark.mark[:-1] = tanh_func(atanh_mark.detach())
self.mark.mark.detach_()

optimizer = optim.Adam([atanh_mark], lr=self.neuron_lr)
+# No difference for SGD
+# optimizer = optim.SGD([atanh_mark], lr=self.neuron_lr)
optimizer.zero_grad()
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(
optimizer, T_max=self.neuron_epoch)

+with torch.no_grad():
+    trigger_input = self.add_mark(self.background, mark_alpha=1.0)
+    print('Neuron Value Before Preprocessing:',
+          f'{self.get_neuron_value(trigger_input, neuron_idx):.5f}')

for _ in range(self.neuron_epoch):
self.mark.mark[:-1] = tanh_func(atanh_mark)
-trigger_input = self.add_mark(zeros, mark_alpha=1.0)
-trigger_feats = self.model.get_layer(trigger_input, layer_output=self.preprocess_layer).abs()
+trigger_input = self.add_mark(self.background, mark_alpha=1.0)
+trigger_feats = self.model.get_layer(trigger_input, layer_output=self.preprocess_layer)
+trigger_feats = trigger_feats[:, neuron_idx].abs()
if trigger_feats.dim() > 2:
-trigger_feats = trigger_feats.flatten(2).mean(2)
-loss = (trigger_feats[0] - self.target_value).square().sum()
+trigger_feats = trigger_feats.flatten(2).sum(2)  # .amax(2)
+loss = (trigger_feats - self.target_value).square().sum()
+# Original code: no difference
+# loss = -self.target_value * trigger_feats.sum()
loss.backward(inputs=[atanh_mark])
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
-self.mark.mark.detach_()
-self.mark.mark[:-1] = tanh_func(atanh_mark)

+# Original Code: no difference
+# self.mark.mark[:-1] = tanh_func(atanh_mark.detach())
+# trigger = self.denoise(self.add_mark(torch.zeros_like(self.background), mark_alpha=1.0)[0])
+# mark = trigger[:, start_h:end_h, start_w:end_w].clamp(0, 1)
+# atanh_mark.data = (mark * (2 - 1e-5) - 1).atanh()
+atanh_mark.requires_grad_(False)
+self.mark.mark[:-1] = tanh_func(atanh_mark)
+self.mark.mark.detach_()
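Note: the loop above optimizes an unconstrained tensor and maps it through tanh, so the mark always stays in a valid pixel range. An isolated sketch of that reparameterization, assuming `tanh_func(x) = (x.tanh() + 1) / 2` (trojanzoo's helper) and a toy linear layer standing in for the real feature extractor; all names here are illustrative:

import torch
import torch.optim as optim

def tanh_func(x: torch.Tensor) -> torch.Tensor:
    return x.tanh().add(1).mul(0.5)  # maps all reals into (0, 1)

target_value = 100.0
atanh_mark = torch.randn(3, 3, 3, requires_grad=True)  # unconstrained
toy_layer = torch.nn.Linear(27, 2)  # stand-in for the preprocess layer

optimizer = optim.Adam([atanh_mark], lr=0.1)
lr_scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=1000)
for _ in range(1000):
    mark = tanh_func(atanh_mark)             # always a valid patch in (0, 1)
    feats = toy_layer(mark.flatten()).abs()  # stand-in activations
    loss = (feats - target_value).square().sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    lr_scheduler.step()
final_mark = tanh_func(atanh_mark.detach())  # detach once optimization ends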

def validate_fn(self, **kwargs) -> tuple[float, float]:
if self.neuron_idx is not None:
with torch.no_grad():
-zeros = torch.zeros(self.dataset.data_shape, device=env['device']).unsqueeze(0)
-trigger_input = self.add_mark(zeros, mark_alpha=1.0)
+trigger_input = self.add_mark(self.background, mark_alpha=1.0)
print(f'Neuron Value: {self.get_neuron_value(trigger_input, self.neuron_idx):.5f}')
return super().validate_fn(**kwargs)

+# @staticmethod
+# def denoise(img: torch.Tensor, weight: float = 1.0, max_num_iter: int = 100, eps: float = 1e-3) -> torch.Tensor:
+# r"""Denoise image by calling :any:`skimage.restoration.denoise_tv_bregman`.

+# Warning:
+# This method is currently unused in :meth:`preprocess_mark()`
+# because no performance difference is observed.

+# Args:
+# img (torch.Tensor): Noisy image tensor with shape ``(C, H, W)``.

+# Returns:
+# torch.Tensor: Denoised image tensor with shape ``(C, H, W)``.
+# """
+# if img.size(0) == 1:
+# img_np: np.ndarray = img[0].detach().cpu().numpy()
+# else:
+# img_np = img.detach().cpu().permute(1, 2, 0).contiguous().numpy()

+# denoised_img_np = skimage.restoration.denoise_tv_bregman(
+# img_np, weight=weight, max_num_iter=max_num_iter, eps=eps)
+# denoised_img = torch.from_numpy(denoised_img_np)

+# if denoised_img.dim() == 2:
+# denoised_img.unsqueeze_(0)
+# else:
+# denoised_img = denoised_img.permute(2, 0, 1).contiguous()
+# return denoised_img.to(device=img.device)
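Note: since the helper above stays disabled, here is a hedged round-trip sketch of what it would do, calling skimage.restoration.denoise_tv_bregman directly (assumes scikit-image >= 0.19 for the channel_axis argument; tensor shapes are illustrative):

import skimage.restoration
import torch

img = torch.rand(3, 32, 32)                          # (C, H, W) in [0, 1]
img_np = img.permute(1, 2, 0).contiguous().numpy()   # to (H, W, C)
denoised_np = skimage.restoration.denoise_tv_bregman(
    img_np, weight=1.0, max_num_iter=100, eps=1e-3, channel_axis=-1)
denoised = torch.from_numpy(denoised_np).permute(2, 0, 1).contiguous()
print(denoised.shape)                                # torch.Size([3, 32, 32])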
