In [23]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
import os
import argparse
import itertools

from models import VGG16Head, VGG16Tail, ResNet18Head, ResNet18Tail
import config
from watermark import Watermark
from tqdm import tqdm

import numpy as np

from bibdcalc import BIBD, BIBDParams

# parser = argparse.ArgumentParser()
# parser.add_argument('--model_name', help = 'Benchmark model structure.', choices = ['VGG16', 'ResNet18'])
# parser.add_argument('--dataset_name', help = 'Benchmark dataset used.', choices = ['CIFAR10', 'GTSRB'])
# parser.add_argument('-k', '--num_collusion', help = 'The number of attackers (k).', type = int, default = 2)
# parser.add_argument('-n', '--num_samples', help = 'The number of generated collusive samples.', type = int, default = 1000)
# parser.add_argument('--attack_name', help = 'Which black-box attack', choices = [ "Bandit", "NES", "HopSkipJump", "SignOPT", "SimBA-px"])
# parser.add_argument('--collusion_attack', help = 'collusion methods', choices = [ "mean", "max", "min", "median", "negative", "negative_prob"])
# parser.add_argument('-M', '--num_models', help = 'The number of models used.', type = int, default = 100)
# args = parser.parse_args()
class Args:
    """Notebook stand-in for the commented-out argparse CLI above.

    The defaults reproduce the configuration this notebook was run with; pass
    keyword arguments to override any of them, e.g. ``Args(num_models=100)``.
    Unknown keywords raise ``TypeError`` like any Python call.
    """

    def __init__(self, model_name='ResNet18', dataset_name='CIFAR10',
                 num_collusion=2, num_samples=500, attack_name='HopSkipJump',
                 collusion_attack='mean', num_models=10, num_obtained_adv=1):
        self.model_name = model_name              # CLI choices: 'VGG16', 'ResNet18'
        self.dataset_name = dataset_name          # CLI choices: 'CIFAR10', 'GTSRB'
        self.num_collusion = num_collusion        # number of colluding attackers (k)
        self.num_samples = num_samples            # number of generated collusive samples
        self.attack_name = attack_name            # black-box attack used to craft the samples
        self.collusion_attack = collusion_attack  # how the k perturbations are merged
        self.num_models = num_models              # number of watermarked heads (M)
        # Consecutive collusive samples pooled into one tracing decision.
        self.num_obtained_adv = num_obtained_adv


args = Args()

# Root of the per-head directories (head_{i}/...) for this model/dataset pair.
model_dir = f'saved_models/{args.model_name}-{args.dataset_name}'

# NOTE(review): these two counters are not used anywhere below in this notebook.
total = 0
success_num = 0

# Per-element probability of taking the max (vs. min) perturbation in the
# 'negative_prob' collusion strategy.
prob = 0.8

# np.load on an .npz returns a lazily-read NpzFile that holds the file open;
# use it as a context manager so the handle is released once the arrays are read.
# NOTE(review): allow_pickle=True is only needed for object arrays and lets a
# crafted file execute code on load -- only load trusted archives.
with np.load(f"saved_collusion_adv_examples/{args.model_name}-{args.dataset_name}/{args.num_collusion}_attackers/{args.attack_name}_{args.num_samples}_num_of_samples.npz", allow_pickle=True) as a:
    img = a['X']  # shape: n, 3, 32, 32
    img_adv = a['X_attacked_k']  # shape: k, n, 3, 32, 32
    label = a['y']  # shape: n
    head_index = a['head']  # shape: n, k

In [24]:
# Give every distinct (unordered) set of colluding heads a small integer id,
# assigned in order of first appearance; comb_cat[i] is the id of sample i.
comb_2_comb_cat = {}
comb_cat = []
comb_id = 0
for heads in head_index:
    combo_key = tuple(sorted(heads))
    # len() is evaluated before insertion, so a new key receives the next id.
    comb_id = comb_2_comb_cat.setdefault(combo_key, len(comb_2_comb_cat))
    comb_cat.append(comb_id)
comb_cat = np.array(comb_cat)

In [25]:
# One perturbation per attacker: img_adv is (k, n, 3, 32, 32) and img is
# (n, 3, 32, 32), so broadcasting yields (k, n, 3, 32, 32).
adv_perturb = img_adv - img

# Strategies that are a single element-wise reduction over the attacker axis.
_axis0_reducers = {
    'mean': np.mean,
    'max': np.max,
    'min': np.min,
    'median': np.median,
}

strategy = args.collusion_attack
if strategy in _axis0_reducers:
    collusion_perturb = _axis0_reducers[strategy](adv_perturb, axis=0)  # (n, 3, 32, 32)
elif strategy == 'negative':
    collusion_perturb = (np.max(adv_perturb, axis=0)
                         + np.min(adv_perturb, axis=0)
                         - np.median(adv_perturb, axis=0))
elif strategy == 'negative_prob':
    # Per element: max with probability `prob`, otherwise min.
    # NOTE(review): draws from NumPy's global RNG, which this notebook never seeds.
    rand_mask = np.random.choice([1, 0], size=img.shape, p=[prob, 1-prob])
    pert_hi = np.max(adv_perturb, axis=0)
    pert_lo = np.min(adv_perturb, axis=0)
    # The int mask promotes the product to float64; cast back to float32.
    collusion_perturb = (pert_hi * rand_mask + pert_lo * (1 - rand_mask)).astype(np.float32)
else:
    raise Exception(f"Unsupported Collusion Method: {args.collusion_attack}")

img_collusion = img + collusion_perturb

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE: these rebind the NumPy arrays from the loading cell to tensors, and
# torch.from_numpy requires NumPy input, so this step is not idempotent --
# re-run the notebook from the loading cell rather than re-running this alone.
img = torch.from_numpy(img).to(device)
img_adv = torch.from_numpy(img_adv).to(device)
img_collusion = torch.from_numpy(img_collusion).to(device)
label = torch.from_numpy(label).to(device)


# Look the dataset class up by name instead of eval()-ing a source string.
dataset = getattr(config, args.dataset_name)()
training_set, testing_set = dataset.training_set, dataset.testing_set
num_classes = dataset.num_classes
means, stds = dataset.means, dataset.stds


# Build one binary (3, 32, 32) mask per head from its saved watermark coordinates.
watermarks = []
for i in range(args.num_models):
    # Use the configured model/dataset directory instead of a hard-coded
    # 'saved_models/ResNet18-CIFAR10' (identical for the default Args).
    wm_path = os.path.join(model_dir, f'head_{i}', 'watermark.npy')
    try:
        w_ = np.load(wm_path, allow_pickle=True)
    except (OSError, ValueError) as err:
        # Continuing here would reuse the previous head's coordinates (or hit a
        # NameError for head_0), silently corrupting the watermark table.
        raise RuntimeError(f'load head_{i} failed: {wm_path}') from err
    watermark = np.zeros((3, 32, 32))
    # Columns 0, 1, 2 of each row of w_ index (channel, row, col) of the mask.
    watermark[w_[:, 0], w_[:, 1], w_[:, 2]] = 1
    watermarks.append(watermark)

watermarks = np.array(watermarks).astype(np.int32)  # (num_models, 3, 32, 32)

# Element-wise AND of the watermarks of every k-subset of heads. Row j of
# `and_results` is the fingerprint shared by head combination dic[j]; tracing
# requires these fingerprints to be pairwise distinct.
and_results = []
dic = {}
seen_fingerprints = set()  # raw bytes of each fingerprint, for O(1) duplicate checks

id_combinations = list(itertools.combinations(range(len(watermarks)), args.num_collusion))
for cnt, ids in enumerate(id_combinations):
    result = np.all(watermarks[list(ids)], axis=0)  # bool (3, 32, 32)
    # All fingerprints share shape and dtype, so equal bytes <=> equal arrays.
    fingerprint = result.tobytes()
    if fingerprint in seen_fingerprints:
        print("REPEAT")
        raise ValueError(
            f"head combination {ids} has the same AND-fingerprint as an earlier "
            f"combination; tracing cannot distinguish them")
    seen_fingerprints.add(fingerprint)
    dic[cnt] = ids
    and_results.append(result)

and_results = np.array(and_results)  # (num_combinations, 3, 32, 32)

# Trace colluders: for each window of collusive samples, predict the head
# combination whose AND-fingerprint region has the smallest mean |perturbation|,
# and count how often that matches the true combination.
success_count = 0

# Tracing uses perturbation magnitudes; np.abs is idempotent, so re-running
# this cell leaves collusion_perturb unchanged.
collusion_perturb = np.abs(collusion_perturb)

window = args.num_obtained_adv
# Loop-invariant: flatten every fingerprint once and count its set bits.
flat_masks = and_results.reshape(len(and_results), -1).astype(np.float64)  # (num_combos, 3*32*32)
overlap_sizes = flat_masks.sum(axis=1)  # (num_combos,)
# A combination whose watermarks share no bit has no defined mean. Previously
# this was 0/0 = NaN, and argmin then picked it; divide by 1 here and force the
# score to +inf below so it can never be chosen.
safe_overlap = np.where(overlap_sizes > 0, overlap_sizes, 1.0)

times_of_judgement = 0
for i in tqdm(range(0, len(img_collusion), window)):
    if i + window > len(img_collusion):
        break
    # Pool only samples that all came from the same attacker combination
    # (checking every sample, not just the first and last).
    if not np.all(comb_cat[i:i + window] == comb_cat[i]):
        continue

    obtained_perts = collusion_perturb[i:i + window].reshape(window, -1)  # (window, 3*32*32)
    # Mean |perturbation| inside each fingerprint, summed over the window.
    combined_mean = ((obtained_perts @ flat_masks.T) / safe_overlap).sum(axis=0)
    combined_mean[overlap_sizes == 0] = np.inf

    idx_ = combined_mean.argmin()
    head_pred = dic[idx_]

    times_of_judgement += 1
    if tuple(sorted(head_pred)) == tuple(sorted(head_index[i])):
        success_count += 1

# NaN when no window qualified, instead of a ZeroDivisionError.
trace_acc = success_count / times_of_judgement if times_of_judgement else float('nan')
print('the tracing accuracy is: ', trace_acc)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


100%|██████████| 100/100 [00:00<00:00, 3349.07it/s]

the tracing accuracy is:  0.7



