This notebook runs attacks on a pretrained MobileNet model (trained on ImageNet). Note, use Python 3.7+ to run the notebook.

In [2]:
import numpy as np
import pandas as pd

import transformers
from transformers import MobileViTFeatureExtractor, MobileViTForImageClassification
from PIL import Image
import requests

import torch

from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

feature_extractor = MobileViTFeatureExtractor.from_pretrained("apple/mobilevit-small")
model = MobileViTForImageClassification.from_pretrained("apple/mobilevit-small")

inputs = feature_extractor(images=image, return_tensors="pt")

outputs = model(**inputs)
logits = outputs.logits

# model predicts one of the 1000 ImageNet classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])



Predicted class: tabby, tabby cat


In [4]:
!huggingface-cli login --token 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

Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to /Users/aatrey/.cache/huggingface/token
Login successful


In [5]:
# ds = load_dataset("imagenet-1k", split='test[:1]')
ds = load_dataset("imagenet-1k", split='test', streaming = True)#.take(1000)

In [6]:
ds_shuffled = ds.shuffle(seed=40)

In [7]:
ds_rand = ds_shuffled.take(100)

In [8]:
X = []
for i, img in enumerate(ds_rand):
    try:
        inputs = feature_extractor(images=img['image'], return_tensors="pt")

        X.append(np.array(inputs['pixel_values']))

        outputs = model(**inputs)
        logits = outputs.logits

        # model predicts one of the 1000 ImageNet classes
        predicted_class_idx = logits.argmax(-1).item()
        print(i, predicted_class_idx)
    except:
        print("An exception occurred")

0 212
1 695
2 524
3 771
4 693
5 774
6 910
7 186
8 799
9 19
10 730
11 55
12 605
13 136
14 254
15 611
16 221
17 574
18 284
19 527
20 531
21 27
22 133
23 36
24 886
25 107
26 690
27 460
28 967
29 755
30 565
31 234
32 230
33 299
34 132
35 94
36 379
37 815
38 926
39 908
40 518
41 844
42 784
43 868
44 353
45 185
An exception occurred
47 28
48 982
49 388
50 284
51 86
52 148
53 500
54 46
55 202
56 692
57 25
58 653
59 691
60 411
61 770
62 802
63 64
64 162
65 312
66 978
67 309
68 636
69 794
70 417
71 948
72 133
73 576
74 486
75 667
76 626
77 161
78 788
79 743
80 98
81 897
82 528
83 281
84 101
85 323
86 723
87 317
88 804
89 473
90 206
91 427
92 339
93 64
94 125
95 671
96 548
An exception occurred
98 886
99 17


In [9]:
X = np.array(X)
print('Number of data points: ', X.shape)

# run example
outputs = model(torch.from_numpy(X[0]))
logits = outputs.logits
predicted_class_idx = logits.argmax(-1).item()
print(predicted_class_idx)

Number of data points:  (98, 1, 3, 256, 256)
212


In [16]:
result = []
def class_attack(x):
    query_size = 100
    x_augment = np.repeat(X, repeats = query_size, axis=0)
    x_augment_noise = np.random.uniform(0.0, 1.0, size=query_size*3*256*256).reshape(query_size, 3,256,256)

    outputs = model(torch.from_numpy(x_augment_noise).float())
    logits = outputs.logits
    result_attack = logits.argmax(-1)
    result.append(len(np.unique(result_attack)))

    return result_attack

np.apply_along_axis(class_attack, axis=0, arr=X[0])
print('Class attack accuracy: ', np.mean(result))

: 

: 

In [10]:
X_augment = np.repeat(X, repeats = 100, axis=0)
print(X_augment.shape)

(9800, 1, 3, 256, 256)


In [11]:
X_augment_noise = np.random.uniform(0.0, 1.0, size=len(X_augment)*3*256*256).reshape(len(X_augment),3,256,256)

In [12]:
outputs = model(torch.from_numpy(X_augment_noise).float())
logits = outputs.logits
result_attack = logits.argmax(-1).item()

df_pred = pd.DataFrame.from_dict({'index': X_augment.index, 'pred': result_attack})
attack_acc = df_pred['pred'].groupby(X_augment.index).nunique().mean()/10000*100
print('Class attack accuracy: ', attack_acc)

: 

: 