In [1]:
import pickle
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names
import torch
import numpy as np
from scipy.spatial.distance import pdist, cdist
import random
from sklearn.metrics import pairwise_distances
from cv2 import resize
from PIL import Image
from scipy.special import softmax

In [2]:
torch.cuda.empty_cache()

# Config

In [3]:
config = {
    'model_path':'hidden-traffic-model.pt',
    'data_path':'hidden-traffic-data.pkl'
}

# Load Model

In [4]:
# for idx, child in enumerate(model.children()):
#     print(f'{idx}:{child}')

In [5]:
# get_graph_node_names(model)

In [6]:
model = torch.load(config['model_path'])
model_feats = create_feature_extractor(model, return_nodes={'avgpool':'avgpool'})
model_feats = model_feats.to('cuda')

# Definitions

In [7]:
def resize_np_img(img, size=(15,15)):
    return resize(img.transpose(2,1,0), size).transpose(2,1,0)
def show_img(img):
    img = np.uint8(img.transpose(2,1,0)*255)
    return Image.fromarray(img)
def get_repr(x, feat_model=model_feats):
    return feat_model(torch.Tensor(x.astype('float32')).reshape(-1,3,32,32).to('cuda'))['avgpool'].squeeze().cpu().numpy()

# Load Data

In [8]:
with open(config['data_path'], 'rb') as file:
    data = pickle.load(file)

In [9]:
x = data['x']
y = data['y']
poison_idxs = data['poison_idxs']

In [10]:
x[poison_idxs].shape

(21, 3, 32, 32)

In [11]:
poison_idxs

[1000,
 1001,
 1002,
 1003,
 1004,
 1005,
 1006,
 1007,
 1008,
 1009,
 1010,
 1011,
 1012,
 1013,
 1014,
 1015,
 1016,
 1017,
 1018,
 1019,
 1020]

# Hypothesis

It seems like the poisoned samples form a community! Shall we verify this hypothesis?

In [12]:
rnd_clean_idxs = np.array(random.sample([i for i in range(2500) if i not in set(poison_idxs)], 21*2))
rnd_clean1 = rnd_clean_idxs[:21]
rnd_clean2 = rnd_clean_idxs[21:]
tmp_reps = get_repr(np.concatenate((x[poison_idxs], x[rnd_clean1], x[rnd_clean2]), axis=0))

In [13]:
tmp_dists = pairwise_distances(tmp_reps)
tmp_dists.shape

(63, 63)

## Do clean samples form a community?

In [14]:
# distance between (clean1 and clean2) vs. (clean1 and poison)
for idx in range(21,2*21):
    d1 = sum(tmp_dists[idx][2*21:])/21 # clean1 and clean2
    d2 = sum(tmp_dists[idx][:21])/21 # clean1 and poison
    if d1<d2:
        print(f'hypothesis verified ({d1} vs. {d2})')
    else:
        print(f'hypothesis rejected ({d1} vs. {d2})')

hypothesis verified (5.985919645854405 vs. 6.482186612628755)
hypothesis rejected (10.838820139567057 vs. 9.79655469031561)
hypothesis verified (7.0899757998330255 vs. 8.003521419706798)
hypothesis rejected (10.806356611705962 vs. 10.186224437895275)
hypothesis rejected (15.866987228393555 vs. 12.97595637185233)
hypothesis verified (6.3998790582021075 vs. 7.153549761999221)
hypothesis rejected (11.987869966597785 vs. 11.280711446489606)
hypothesis verified (9.74599570319766 vs. 11.259796051752)
hypothesis verified (6.130121651149931 vs. 6.928236711592901)
hypothesis verified (6.6302211965833395 vs. 7.167184171222505)
hypothesis rejected (13.222424643380302 vs. 12.740245728265672)
hypothesis verified (7.217192241123745 vs. 7.718701442082723)
hypothesis verified (5.714627975509281 vs. 6.071032104038057)
hypothesis verified (5.547915708451044 vs. 6.476976008642287)
hypothesis rejected (8.914159320649647 vs. 8.509169465019589)
hypothesis rejected (12.573854083106632 vs. 10.603888988494873)

In [15]:
# distance between (clean2 and clean1) vs. (clean2 and poison)
for idx in range(21*2,3*21):
    d1 = sum(tmp_dists[idx][21:21*2])/21 # clean2 and clean1
    d2 = sum(tmp_dists[idx][:21])/21 # clean2 and poison
    if d1<d2:
        print(f'hypothesis verified ({d1} vs. {d2})')
    else:
        print(f'hypothesis rejected ({d1} vs. {d2})')

hypothesis rejected (11.347708293369838 vs. 10.85244673774356)
hypothesis rejected (8.887254669552757 vs. 6.488815534682501)
hypothesis verified (11.845002356029692 vs. 13.092715740203857)
hypothesis rejected (8.262970935730706 vs. 6.5582350095113116)
hypothesis rejected (7.270038411730812 vs. 6.151383865447271)
hypothesis rejected (7.572230838593983 vs. 7.357316369102115)
hypothesis rejected (6.823827618644351 vs. 5.835076990581694)
hypothesis rejected (7.1220387163616365 vs. 6.144032115028018)
hypothesis rejected (6.95884002390362 vs. 5.853497130530221)
hypothesis rejected (6.781593617938814 vs. 6.0859697092147105)
hypothesis verified (6.850406794320969 vs. 7.108740704400199)
hypothesis verified (10.099616209665934 vs. 10.466326986040388)
hypothesis rejected (6.756713696888515 vs. 6.342675436110723)
hypothesis verified (8.624172165280296 vs. 9.044156551361084)
hypothesis rejected (8.076130844297863 vs. 7.8588184629167825)
hypothesis rejected (7.377989150228954 vs. 6.540816942850749)


In [16]:
# distance between (clean1 and clean1) vs. (clean1 and clean2)
for idx in range(21,2*21):
    d1 = sum(tmp_dists[idx][21:21*2])/20 # clean1 and clean1
    d2 = sum(tmp_dists[idx][21*2:21*3])/21 # clean1 and clean2
    if d1<d2:
        print(f'hypothesis verified ({d1} vs. {d2})')
    else:
        print(f'hypothesis rejected ({d1} vs. {d2})')

hypothesis rejected (7.459713268280029 vs. 5.985919645854405)
hypothesis verified (9.213393223285674 vs. 10.838820139567057)
hypothesis rejected (8.724886739253998 vs. 7.0899757998330255)
hypothesis verified (9.20039323568344 vs. 10.806356611705962)
hypothesis verified (15.067568349838258 vs. 15.866987228393555)
hypothesis rejected (7.495388078689575 vs. 6.3998790582021075)
hypothesis verified (10.334129321575166 vs. 11.987869966597785)
hypothesis rejected (11.0765230178833 vs. 9.74599570319766)
hypothesis rejected (7.441756498813629 vs. 6.130121651149931)
hypothesis rejected (7.213806056976319 vs. 6.6302211965833395)
hypothesis verified (11.572033214569093 vs. 13.222424643380302)
hypothesis rejected (8.300498461723327 vs. 7.217192241123745)
hypothesis rejected (7.277470493316651 vs. 5.714627975509281)
hypothesis rejected (7.163462072610855 vs. 5.547915708451044)
hypothesis verified (7.951358234882354 vs. 8.914159320649647)
hypothesis verified (11.477436017990112 vs. 12.573854083106632

## Do poisons form a community?

In [17]:
# distance between (poison and poison) vs. (poison and clean1)
for idx in range(0,21):
    d1 = sum(tmp_dists[idx][:21])/20 # poison and poison
    d2 = sum(tmp_dists[idx][21:21*2])/21 # poison and clean1
    if d1<d2:
        print(f'hypothesis verified ({d1} vs. {d2})')
    else:
        print(f'hypothesis rejected ({d1} vs. {d2})')

hypothesis verified (6.741297745704651 vs. 9.433049111139207)
hypothesis verified (5.506799519062042 vs. 7.29467172849746)
hypothesis verified (5.379599219560623 vs. 7.115746747879755)
hypothesis verified (5.906105577945709 vs. 6.99356476465861)
hypothesis verified (5.83699551820755 vs. 8.50161836260841)
hypothesis verified (5.865005660057068 vs. 7.922751120158604)
hypothesis verified (7.8595959663391115 vs. 9.521442685808454)
hypothesis verified (5.989570736885071 vs. 9.109483650752477)
hypothesis verified (7.372375464439392 vs. 8.29576461655753)
hypothesis verified (7.219706237316132 vs. 9.463535513196673)
hypothesis verified (7.12750426530838 vs. 8.173052310943604)
hypothesis verified (7.775047135353089 vs. 10.36363581248692)
hypothesis verified (5.875677430629731 vs. 8.228496823992048)
hypothesis verified (5.305626732110977 vs. 6.988332861945743)
hypothesis verified (11.77174253463745 vs. 12.929266793387276)
hypothesis verified (5.58603732585907 vs. 7.618625254858108)
hypothesis ve

# Graph Formation

In [18]:
all_reps = get_repr(x)

In [19]:
all_dists = pairwise_distances(all_reps)

In [20]:
all_adjacency = np.zeros(all_dists.shape)
all_adjacency[np.where(all_dists<6)] = 1

In [21]:
nodes_start = 700
nodes_end = 1200
all_adjacency = all_adjacency[nodes_start:nodes_end, nodes_start:nodes_end]

In [22]:
degree = {i:all_adjacency[i].sum() for i in range(all_adjacency.shape[0])}

In [23]:
all_adjacency.shape

(500, 500)

In [24]:
new = np.zeros(all_adjacency.shape)
m = all_adjacency.sum()/2
for i in range(all_adjacency.shape[0]):
    for j in range(all_adjacency.shape[1]):
        new[i][j] = all_adjacency[i][j]-(degree[i]*degree[j]/m)

# CVXPY

In [25]:
import cvxpy as cp

In [26]:
B = new.copy()

In [27]:
num_nodes = nodes_end - nodes_start

In [28]:
X = cp.Variable((num_nodes,num_nodes), symmetric=True)

In [29]:
constraints = [X >> 0]

In [30]:
constraints += [
    cp.trace(X) == 1
]

In [31]:
one_vec = np.ones((num_nodes,1))

In [32]:
prob = cp.Problem(
    cp.Maximize(cp.trace(B @ X)-0.001*one_vec.T@cp.abs(X)@one_vec),
                  constraints)

In [33]:
%%time
prob.solve()

CPU times: user 1h 49min 40s, sys: 36.5 s, total: 1h 50min 16s
Wall time: 22min 7s


126.08005726082153

In [34]:
X.value

array([[ 1.79483614e-07, -5.13106481e-06,  7.75974786e-07, ...,
         2.90418636e-05,  2.10746626e-05,  1.39778172e-05],
       [-5.13106481e-06,  6.13223356e-05, -9.27570812e-06, ...,
        -3.46558582e-04, -2.51483576e-04, -1.66798058e-04],
       [ 7.75974786e-07, -9.27570812e-06,  5.04839872e-07, ...,
         5.25139155e-05,  3.81073498e-05,  2.52748722e-05],
       ...,
       [ 2.90418636e-05, -3.46558582e-04,  5.25139155e-05, ...,
         1.96212198e-03,  1.42375624e-03,  9.44312815e-04],
       [ 2.10746626e-05, -2.51483576e-04,  3.81073498e-05, ...,
         1.42375624e-03,  1.03327254e-03,  6.85251678e-04],
       [ 1.39778172e-05, -1.66798058e-04,  2.52748722e-05, ...,
         9.44312815e-04,  6.85251678e-04,  4.54604976e-04]])

In [35]:
X.value.trace()

1.000000000419006

In [36]:
from numpy import linalg as LA

In [37]:
w, v = LA.eig(X.value)

In [38]:
w.shape

(500,)

In [39]:
v.shape

(500, 500)

In [40]:
w

array([ 9.99998600e-01,  1.01066114e-05,  9.12195727e-06, -1.01641728e-05,
       -9.36051313e-06,  6.25454537e-06,  5.28329643e-06,  5.12235667e-06,
       -7.22166468e-06, -6.47745607e-06, -6.23328148e-06, -5.79386345e-06,
       -4.95568577e-06,  3.87193435e-06,  3.68087482e-06, -5.01848755e-06,
        3.37154688e-06, -4.37985423e-06, -4.49038265e-06, -4.09109420e-06,
        2.82456476e-06,  2.75539780e-06,  2.69980520e-06,  2.56083814e-06,
        2.50456918e-06, -3.62420890e-06, -3.48842979e-06, -3.35457726e-06,
        1.85823567e-06, -2.86300925e-06, -2.70648341e-06, -2.71388165e-06,
       -2.33864547e-06, -2.53779577e-06, -2.45843532e-06, -2.51540303e-06,
       -2.50342335e-06, -2.09749037e-06, -2.00997373e-06,  1.41400948e-06,
       -1.58087327e-06,  1.24411987e-06, -1.33048192e-06,  9.95627725e-07,
       -9.31033122e-07, -8.40590113e-07, -8.71231045e-07, -7.70585822e-07,
       -7.04009253e-07, -6.32663600e-07, -7.14562158e-07, -5.36125178e-07,
       -3.58469154e-07, -

In [41]:
set(v[0].argsort()[-50:]+nodes_start).intersection(set(poison_idxs))

set()

In [50]:
v[0].argsort()[-21:]+nodes_start

array([759, 768, 740, 767, 758, 729, 701, 743, 744, 721, 702, 724, 706,
       716, 731, 710, 754, 711, 705, 730, 752])

In [51]:
lst = v[0].argsort()[-21:]+nodes_start

In [52]:
all_reps[lst].shape

(21, 512)

In [54]:
pairwise_distances(all_reps[lst]).shape

(21, 21)

In [56]:
pairwise_distances(all_reps[lst]).sum(axis=1)/21

array([12.925809 ,  6.0093513,  6.149428 ,  6.094943 , 12.496682 ,
        6.480083 ,  5.8615212,  7.4592814,  6.2678385,  7.617687 ,
        8.837974 ,  6.0360193,  5.884468 ,  5.713069 ,  6.26149  ,
        6.6658587,  8.389743 ,  5.6778603,  5.5249987,  9.604752 ,
        7.111132 ], dtype=float32)