## Demo_Quality
This is a demo for visualizing the Image Quality

To run this demo from scratch, you need first generate a BadNet attack result by using the following cell

In [None]:
! python ../../attack/badnet.py --save_folder_name badnet_demo

or run the following command in your terminal

```python attack/badnet.py --save_folder_name badnet_demo```

### Step 1: Import modules and set arguments

In [1]:
import sys, os
import yaml
import torch
import shap
import numpy as np
import torchvision.transforms as transforms

sys.path.append("../")
sys.path.append("../../")
sys.path.append(os.getcwd())
from visual_utils import *
from utils.aggregate_block.dataset_and_transform_generate import (
    get_transform,
    get_dataset_denormalization,
)
from utils.aggregate_block.fix_random import fix_random
from utils.aggregate_block.model_trainer_generate import generate_cls_model
from utils.save_load_attack import load_attack_result
from utils.defense_utils.dbd.model.utils import (
    get_network_dbd,
    load_state,
    get_criterion,
    get_optimizer,
    get_scheduler,
)
from utils.defense_utils.dbd.model.model import SelfModel, LinearModel


In [2]:
### Basic setting: args
args = get_args(True)

########## For Demo Only ##########
args.yaml_path = "../../"+args.yaml_path
args.result_file_attack = "badnet_demo"
######## End For Demo Only ##########

with open(args.yaml_path, "r") as stream:
    config = yaml.safe_load(stream)
config.update({k: v for k, v in args.__dict__.items() if v is not None})
args.__dict__ = config
args = preprocess_args(args)
fix_random(int(args.random_seed))

save_path_attack = "../..//record/" + args.result_file_attack


### Step 2: Load data

In [3]:
# Load result
result_attack = load_attack_result(save_path_attack + "/attack_result.pt")
selected_classes = np.arange(args.num_classes)

# Select classes to visualize
if args.num_classes>args.c_sub:
    selected_classes = np.delete(selected_classes, args.target_class)
    selected_classes = np.random.choice(selected_classes, args.c_sub-1, replace=False)
    selected_classes = np.append(selected_classes, args.target_class)

# keep the same transforms for train and test dataset for better visualization
result_attack["clean_train"].wrap_img_transform = result_attack["clean_test"].wrap_img_transform 
result_attack["bd_train"].wrap_img_transform = result_attack["bd_test"].wrap_img_transform 

# Create dataset
args.visual_dataset = 'bd_train'
if args.visual_dataset == 'mixed':
    bd_test_with_trans = result_attack["bd_test"]
    visual_dataset = generate_mix_dataset(bd_test_with_trans, args.target_class, args.pratio, selected_classes, max_num_samples=args.n_sub)
elif args.visual_dataset == 'clean_train':
    clean_train_with_trans = result_attack["clean_train"]
    visual_dataset = generate_clean_dataset(clean_train_with_trans, selected_classes, max_num_samples=args.n_sub)
elif args.visual_dataset == 'clean_test':
    clean_test_with_trans = result_attack["clean_test"]
    visual_dataset = generate_clean_dataset(clean_test_with_trans, selected_classes, max_num_samples=args.n_sub)
elif args.visual_dataset == 'bd_train':  
    bd_train_with_trans = result_attack["bd_train"]
    visual_dataset = generate_bd_dataset(bd_train_with_trans, args.target_class, selected_classes, max_num_samples=args.n_sub)
elif args.visual_dataset == 'bd_test':
    bd_test_with_trans = result_attack["bd_test"]
    visual_dataset = generate_bd_dataset(bd_test_with_trans, args.target_class, selected_classes, max_num_samples=args.n_sub)
else:
    assert False, "Illegal vis_class"

print(f'Create visualization dataset with \n \t Dataset: {args.visual_dataset} \n \t Number of samples: {len(visual_dataset)}  \n \t Selected classes: {selected_classes}')

# Create data loader
data_loader = torch.utils.data.DataLoader(
    visual_dataset, batch_size=args.batch_size, num_workers=args.num_workers, shuffle=False
)

# Create denormalization function
for trans_t in data_loader.dataset.wrap_img_transform.transforms:
    if isinstance(trans_t, transforms.Normalize):
        denormalizer = get_dataset_denormalization(trans_t)





Files already downloaded and verified
Files already downloaded and verified
loading...
max_num_samples is given, use sample number limit now.
subset bd dataset with length:  5000
Create visualization dataset with 
 	 Dataset: bd_train 
 	 Number of samples: 5000  
 	 Selected classes: [0 1 2 3 4 5 6 7 8 9]


### Step 3: SSIM

In [4]:
visual_poison_indicator = np.array(get_poison_indicator_from_bd_dataset(visual_dataset))
bd_idx = np.where(visual_poison_indicator == 1)[0]

from torchmetrics import StructuralSimilarityIndexMeasure
ssim = StructuralSimilarityIndexMeasure()
ssim_list = []
if visual_poison_indicator.sum() > 0:
    print(f'Number Poisoned samples: {visual_poison_indicator.sum()}')
    # random choose two poisoned samples
    start_idx = 0
    for i in range(bd_idx.shape[0]):
        bd_sample = denormalizer(visual_dataset[i][0]).unsqueeze(0)
        with temporary_all_clean(visual_dataset):
            clean_sample =  denormalizer(visual_dataset[i][0]).unsqueeze(0)
        ssim_list.append(ssim(bd_sample, clean_sample))        
print(f'Average SSIM: {np.mean(ssim_list)}')

Number Poisoned samples: 489
Average SSIM: 0.9929845929145813


### Step 4: FID

In [5]:
visual_poison_indicator = np.array(get_poison_indicator_from_bd_dataset(visual_dataset))
bd_idx = np.where(visual_poison_indicator == 1)[0]

from torchmetrics.image.fid import FrechetInceptionDistance
fid = FrechetInceptionDistance(feature=64, normalize = True)
if visual_poison_indicator.sum() > 0:
    print(f'Number Poisoned samples: {visual_poison_indicator.sum()}')
    # random choose two poisoned samples
    start_idx = 0
    for i in range(bd_idx.shape[0]):
        bd_sample = denormalizer(visual_dataset[i][0]).unsqueeze(0)
        with temporary_all_clean(visual_dataset):
            clean_sample =  denormalizer(visual_dataset[i][0]).unsqueeze(0)
        fid.update(clean_sample, real=True)
        fid.update(bd_sample, real=False)
    fid_value = fid.compute().numpy()        
print(f'FID: {fid_value}')

Number Poisoned samples: 489
FID: 0.00030133521067909896
