In [1]:
import torch
import numpy as np
from deeprobust.graph.data import Dataset
from deeprobust.graph.defense import GCN, GAT
from deeprobust.graph.global_attack import Metattack

data = Dataset(root='/tmp/', name='cora', setting='nettack')
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
idx_unlabeled = np.union1d(idx_val, idx_test)

Loading cora dataset...
Downloading from https://raw.githubusercontent.com/danielzuegner/gnn-meta-attack/master/data/cora.npz to /tmp/cora.npz
Done!
Selecting 1 largest connected components


In [94]:
idx_train

array([ 702,  334,   98,  450, 1738, 1046,  738, 2103,  823, 1171, 1354,
       1386, 1486, 2294, 2433, 2186, 1092,  920, 1194, 2442, 2452,  645,
       1478,   52,  371, 1347, 1227, 1892, 1903, 1242, 1299,  696, 1752,
        644,   87, 1908,  116,   78, 2283, 1684, 2401, 1341, 1756, 2178,
       1746, 1523,  733,  942,  802,  462, 2115,  906,  903, 1770, 1610,
        637, 1540,  822, 1220, 1614,  385, 2107, 1993, 1152, 1413,  931,
       2135, 2033, 2206,  584, 2291, 1743,  394, 2431, 2251,  934, 1881,
       1021, 2174,  793,  236,  855, 1813,  532, 1484, 1761,  136,  138,
       1476,  123, 2085, 1376, 1404,    9, 2004,  407, 1336,  256, 1281,
       1726,  576, 2029, 1028, 1987, 1023, 2154, 1121, 2121,  278,  104,
          0, 1526, 1346,  315, 2338, 1436,  202, 1659,  966, 2116, 1401,
        616, 1183, 1958, 1983, 2246,   90, 1454, 2435,   92, 1594,  413,
       2356, 2172,  320, 1147, 1916, 1292,  613, 1886, 1713,  232,  620,
        973,  426,  815,  918, 1017,  538, 1816,   

In [67]:
adj

<2485x2485 sparse matrix of type '<class 'numpy.float32'>'
	with 10138 stored elements in Compressed Sparse Row format>

In [68]:
# !conda list

In [69]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1, nhid=16,
                with_relu=False, device=device)
surrogate = surrogate.to(device)
surrogate.fit(features, adj, labels, idx_train)

In [4]:
model = Metattack(model=surrogate, nnodes=adj.shape[0], feature_shape=features.shape, device=device)
model = model.to(device)
perturbations = int(0.05 * (adj.sum() // 2))
model.attack(features, adj, labels, idx_train, idx_unlabeled, perturbations, ll_constraint=False)
modified_adj = model.modified_adj

Perturbing graph:   0%|          | 0/253 [00:00<?, ?it/s]

GCN loss on unlabled data: 0.527646541595459
GCN acc on unlabled data: 0.8386231560125168
attack loss: 0.21569594740867615


Perturbing graph:   1%|          | 2/253 [00:01<02:14,  1.87it/s]

GCN loss on unlabled data: 0.5233502388000488
GCN acc on unlabled data: 0.853822083147072
attack loss: 0.21668285131454468


Perturbing graph:   1%|          | 3/253 [00:01<01:47,  2.32it/s]

GCN loss on unlabled data: 0.5343917608261108
GCN acc on unlabled data: 0.845775592311131
attack loss: 0.2276192605495453


Perturbing graph:   2%|▏         | 4/253 [00:01<01:35,  2.62it/s]

GCN loss on unlabled data: 0.5356086492538452
GCN acc on unlabled data: 0.8421993741618239
attack loss: 0.22431150078773499


Perturbing graph:   2%|▏         | 5/253 [00:02<01:28,  2.81it/s]

GCN loss on unlabled data: 0.5224511623382568
GCN acc on unlabled data: 0.8497988377291015
attack loss: 0.2331191450357437


Perturbing graph:   2%|▏         | 6/253 [00:02<01:23,  2.95it/s]

GCN loss on unlabled data: 0.5370047092437744
GCN acc on unlabled data: 0.8475637013857845
attack loss: 0.227347269654274


Perturbing graph:   3%|▎         | 7/253 [00:02<01:20,  3.04it/s]

GCN loss on unlabled data: 0.5454985499382019
GCN acc on unlabled data: 0.8444345105051408
attack loss: 0.2361501306295395


Perturbing graph:   3%|▎         | 8/253 [00:03<01:18,  3.11it/s]

GCN loss on unlabled data: 0.5330237150192261
GCN acc on unlabled data: 0.8390701832811802
attack loss: 0.23764941096305847


Perturbing graph:   4%|▎         | 9/253 [00:03<01:17,  3.16it/s]

GCN loss on unlabled data: 0.5654101967811584
GCN acc on unlabled data: 0.8363880196691998
attack loss: 0.2496768683195114


Perturbing graph:   4%|▍         | 10/253 [00:03<01:16,  3.20it/s]

GCN loss on unlabled data: 0.5248417854309082
GCN acc on unlabled data: 0.845775592311131
attack loss: 0.23769961297512054


Perturbing graph:   4%|▍         | 11/253 [00:03<01:15,  3.22it/s]

GCN loss on unlabled data: 0.5508171916007996
GCN acc on unlabled data: 0.8386231560125168
attack loss: 0.23886702954769135


Perturbing graph:   5%|▍         | 12/253 [00:04<01:14,  3.23it/s]

GCN loss on unlabled data: 0.5554183125495911
GCN acc on unlabled data: 0.8426464014304873
attack loss: 0.24873213469982147


Perturbing graph:   5%|▌         | 13/253 [00:04<01:14,  3.24it/s]

GCN loss on unlabled data: 0.5457016825675964
GCN acc on unlabled data: 0.8413053196244972
attack loss: 0.23347222805023193


Perturbing graph:   6%|▌         | 14/253 [00:04<01:13,  3.24it/s]

GCN loss on unlabled data: 0.531434953212738
GCN acc on unlabled data: 0.8381761287438534
attack loss: 0.24884067475795746


Perturbing graph:   6%|▌         | 15/253 [00:05<01:13,  3.25it/s]

GCN loss on unlabled data: 0.5387420058250427
GCN acc on unlabled data: 0.8395172105498435
attack loss: 0.2505725026130676


Perturbing graph:   6%|▋         | 16/253 [00:05<01:12,  3.25it/s]

GCN loss on unlabled data: 0.5459677577018738
GCN acc on unlabled data: 0.845775592311131
attack loss: 0.2630898058414459


Perturbing graph:   7%|▋         | 17/253 [00:05<01:12,  3.26it/s]

GCN loss on unlabled data: 0.5264639854431152
GCN acc on unlabled data: 0.8408582923558338
attack loss: 0.24844001233577728


Perturbing graph:   7%|▋         | 18/253 [00:06<01:11,  3.27it/s]

GCN loss on unlabled data: 0.5509992241859436
GCN acc on unlabled data: 0.8395172105498435
attack loss: 0.25168266892433167


Perturbing graph:   8%|▊         | 19/253 [00:06<01:11,  3.27it/s]

GCN loss on unlabled data: 0.5515260696411133
GCN acc on unlabled data: 0.8448815377738043
attack loss: 0.24501067399978638


Perturbing graph:   8%|▊         | 20/253 [00:06<01:11,  3.27it/s]

GCN loss on unlabled data: 0.5504652261734009
GCN acc on unlabled data: 0.8363880196691998
attack loss: 0.26154857873916626


Perturbing graph:   8%|▊         | 21/253 [00:06<01:11,  3.26it/s]

GCN loss on unlabled data: 0.5310647487640381
GCN acc on unlabled data: 0.8363880196691998
attack loss: 0.25112712383270264


Perturbing graph:   9%|▊         | 22/253 [00:07<01:10,  3.27it/s]

GCN loss on unlabled data: 0.5550873875617981
GCN acc on unlabled data: 0.8359409924005364
attack loss: 0.25482505559921265


Perturbing graph:   9%|▉         | 23/253 [00:07<01:10,  3.26it/s]

GCN loss on unlabled data: 0.5391654968261719
GCN acc on unlabled data: 0.8386231560125168
attack loss: 0.24841424822807312


Perturbing graph:   9%|▉         | 24/253 [00:07<01:10,  3.26it/s]

GCN loss on unlabled data: 0.5443551540374756
GCN acc on unlabled data: 0.8292355833705857
attack loss: 0.24682091176509857


Perturbing graph:  10%|▉         | 25/253 [00:08<01:09,  3.26it/s]

GCN loss on unlabled data: 0.5527199506759644
GCN acc on unlabled data: 0.8395172105498435
attack loss: 0.25302866101264954


Perturbing graph:  10%|█         | 26/253 [00:08<01:09,  3.26it/s]

GCN loss on unlabled data: 0.5583311915397644
GCN acc on unlabled data: 0.843540455967814
attack loss: 0.2626767158508301


Perturbing graph:  11%|█         | 27/253 [00:08<01:09,  3.26it/s]

GCN loss on unlabled data: 0.5254263281822205
GCN acc on unlabled data: 0.8399642378185069
attack loss: 0.25730031728744507


Perturbing graph:  11%|█         | 28/253 [00:09<01:09,  3.26it/s]

GCN loss on unlabled data: 0.5687953233718872
GCN acc on unlabled data: 0.8314707197139026
attack loss: 0.2819279730319977


Perturbing graph:  11%|█▏        | 29/253 [00:09<01:08,  3.26it/s]

GCN loss on unlabled data: 0.5799176692962646
GCN acc on unlabled data: 0.8314707197139026
attack loss: 0.2686614692211151


Perturbing graph:  12%|█▏        | 30/253 [00:09<01:08,  3.27it/s]

GCN loss on unlabled data: 0.5683062076568604
GCN acc on unlabled data: 0.8319177469825659
attack loss: 0.26335394382476807


Perturbing graph:  12%|█▏        | 31/253 [00:10<01:08,  3.26it/s]

GCN loss on unlabled data: 0.5717869400978088
GCN acc on unlabled data: 0.8341528833258829
attack loss: 0.2690102458000183


Perturbing graph:  13%|█▎        | 32/253 [00:10<01:07,  3.26it/s]

GCN loss on unlabled data: 0.5497830510139465
GCN acc on unlabled data: 0.8399642378185069
attack loss: 0.2639966607093811


Perturbing graph:  13%|█▎        | 33/253 [00:10<01:07,  3.25it/s]

GCN loss on unlabled data: 0.5756658315658569
GCN acc on unlabled data: 0.8314707197139026
attack loss: 0.28406772017478943


Perturbing graph:  13%|█▎        | 34/253 [00:10<01:07,  3.26it/s]

GCN loss on unlabled data: 0.5529547333717346
GCN acc on unlabled data: 0.8386231560125168
attack loss: 0.28652840852737427


Perturbing graph:  14%|█▍        | 35/253 [00:11<01:06,  3.26it/s]

GCN loss on unlabled data: 0.558480978012085
GCN acc on unlabled data: 0.8341528833258829
attack loss: 0.27462825179100037


Perturbing graph:  14%|█▍        | 36/253 [00:11<01:06,  3.26it/s]

GCN loss on unlabled data: 0.5730696320533752
GCN acc on unlabled data: 0.8314707197139026
attack loss: 0.27807918190956116


Perturbing graph:  15%|█▍        | 37/253 [00:11<01:06,  3.26it/s]

GCN loss on unlabled data: 0.5883363485336304
GCN acc on unlabled data: 0.8341528833258829
attack loss: 0.2774984836578369


Perturbing graph:  15%|█▌        | 38/253 [00:12<01:05,  3.26it/s]

GCN loss on unlabled data: 0.5662685036659241
GCN acc on unlabled data: 0.8261063924899419
attack loss: 0.26842111349105835


Perturbing graph:  15%|█▌        | 39/253 [00:12<01:05,  3.27it/s]

GCN loss on unlabled data: 0.5642164945602417
GCN acc on unlabled data: 0.8363880196691998
attack loss: 0.272251158952713


Perturbing graph:  16%|█▌        | 40/253 [00:12<01:05,  3.27it/s]

GCN loss on unlabled data: 0.5694259405136108
GCN acc on unlabled data: 0.8305766651765758
attack loss: 0.28398868441581726


Perturbing graph:  16%|█▌        | 41/253 [00:13<01:04,  3.27it/s]

GCN loss on unlabled data: 0.5827379822731018
GCN acc on unlabled data: 0.8368350469378633
attack loss: 0.2850019633769989


Perturbing graph:  17%|█▋        | 42/253 [00:13<01:04,  3.26it/s]

GCN loss on unlabled data: 0.5786392688751221
GCN acc on unlabled data: 0.829682610639249
attack loss: 0.2820226550102234


Perturbing graph:  17%|█▋        | 43/253 [00:13<01:04,  3.25it/s]

GCN loss on unlabled data: 0.5659496784210205
GCN acc on unlabled data: 0.8350469378632097
attack loss: 0.2797142267227173


Perturbing graph:  17%|█▋        | 44/253 [00:14<01:04,  3.25it/s]

GCN loss on unlabled data: 0.5867460370063782
GCN acc on unlabled data: 0.8270004470272687
attack loss: 0.28110840916633606


Perturbing graph:  18%|█▊        | 45/253 [00:14<01:03,  3.25it/s]

GCN loss on unlabled data: 0.5808667540550232
GCN acc on unlabled data: 0.8314707197139026
attack loss: 0.2815357446670532


Perturbing graph:  18%|█▊        | 46/253 [00:14<01:03,  3.25it/s]

GCN loss on unlabled data: 0.5717581510543823
GCN acc on unlabled data: 0.8305766651765758
attack loss: 0.2851971387863159


Perturbing graph:  19%|█▊        | 47/253 [00:14<01:03,  3.25it/s]

GCN loss on unlabled data: 0.5926646590232849
GCN acc on unlabled data: 0.8341528833258829
attack loss: 0.29610228538513184


Perturbing graph:  19%|█▉        | 48/253 [00:15<01:03,  3.25it/s]

GCN loss on unlabled data: 0.58144211769104
GCN acc on unlabled data: 0.8278945015645954
attack loss: 0.28846681118011475


Perturbing graph:  19%|█▉        | 49/253 [00:15<01:02,  3.25it/s]

GCN loss on unlabled data: 0.5802989602088928
GCN acc on unlabled data: 0.8341528833258829
attack loss: 0.2980899214744568


Perturbing graph:  20%|█▉        | 50/253 [00:15<01:02,  3.25it/s]

GCN loss on unlabled data: 0.5744807720184326
GCN acc on unlabled data: 0.835493965131873
attack loss: 0.28285282850265503


Perturbing graph:  20%|██        | 51/253 [00:16<01:02,  3.25it/s]

GCN loss on unlabled data: 0.6102427840232849
GCN acc on unlabled data: 0.8261063924899419
attack loss: 0.31458327174186707


Perturbing graph:  21%|██        | 52/253 [00:16<01:01,  3.26it/s]

GCN loss on unlabled data: 0.5776623487472534
GCN acc on unlabled data: 0.8413053196244972
attack loss: 0.3099380433559418


Perturbing graph:  21%|██        | 53/253 [00:16<01:01,  3.26it/s]

GCN loss on unlabled data: 0.5823544263839722
GCN acc on unlabled data: 0.8328118015198928
attack loss: 0.2956222891807556


Perturbing graph:  21%|██▏       | 54/253 [00:17<01:01,  3.26it/s]

GCN loss on unlabled data: 0.5922654867172241
GCN acc on unlabled data: 0.8287885561019223
attack loss: 0.2887805700302124


Perturbing graph:  22%|██▏       | 55/253 [00:17<01:00,  3.26it/s]

GCN loss on unlabled data: 0.5747089982032776
GCN acc on unlabled data: 0.8372820742065267
attack loss: 0.3115900754928589


Perturbing graph:  22%|██▏       | 56/253 [00:17<01:00,  3.26it/s]

GCN loss on unlabled data: 0.5921662449836731
GCN acc on unlabled data: 0.8323647742512293
attack loss: 0.2966410517692566


Perturbing graph:  23%|██▎       | 57/253 [00:18<01:00,  3.26it/s]

GCN loss on unlabled data: 0.5744078159332275
GCN acc on unlabled data: 0.8287885561019223
attack loss: 0.3027949929237366


Perturbing graph:  23%|██▎       | 58/253 [00:18<00:59,  3.27it/s]

GCN loss on unlabled data: 0.5949479341506958
GCN acc on unlabled data: 0.8314707197139026
attack loss: 0.29005271196365356


Perturbing graph:  23%|██▎       | 59/253 [00:18<00:59,  3.26it/s]

GCN loss on unlabled data: 0.5761167407035828
GCN acc on unlabled data: 0.8301296379079124
attack loss: 0.29755669832229614


Perturbing graph:  24%|██▎       | 60/253 [00:18<00:59,  3.27it/s]

GCN loss on unlabled data: 0.5890217423439026
GCN acc on unlabled data: 0.8234242288779616
attack loss: 0.31318485736846924


Perturbing graph:  24%|██▍       | 61/253 [00:19<00:58,  3.25it/s]

GCN loss on unlabled data: 0.5778560638427734
GCN acc on unlabled data: 0.8287885561019223
attack loss: 0.2982039749622345


Perturbing graph:  25%|██▍       | 62/253 [00:19<00:59,  3.23it/s]

GCN loss on unlabled data: 0.5968514680862427
GCN acc on unlabled data: 0.8194009834599911
attack loss: 0.30197232961654663


Perturbing graph:  25%|██▍       | 63/253 [00:19<00:58,  3.24it/s]

GCN loss on unlabled data: 0.6057376861572266
GCN acc on unlabled data: 0.8252123379526152
attack loss: 0.300879567861557


Perturbing graph:  25%|██▌       | 64/253 [00:20<00:58,  3.25it/s]

GCN loss on unlabled data: 0.6204259991645813
GCN acc on unlabled data: 0.8252123379526152
attack loss: 0.3343631327152252


Perturbing graph:  26%|██▌       | 65/253 [00:20<00:57,  3.26it/s]

GCN loss on unlabled data: 0.589552640914917
GCN acc on unlabled data: 0.8301296379079124
attack loss: 0.3195209801197052


Perturbing graph:  26%|██▌       | 66/253 [00:20<00:57,  3.26it/s]

GCN loss on unlabled data: 0.600684404373169
GCN acc on unlabled data: 0.8167188198480108
attack loss: 0.3159228563308716


Perturbing graph:  26%|██▋       | 67/253 [00:21<00:56,  3.26it/s]

GCN loss on unlabled data: 0.5946279764175415
GCN acc on unlabled data: 0.8243182834152883
attack loss: 0.31928351521492004


Perturbing graph:  27%|██▋       | 68/253 [00:21<00:56,  3.27it/s]

GCN loss on unlabled data: 0.5896778106689453
GCN acc on unlabled data: 0.8301296379079124
attack loss: 0.30569449067115784


Perturbing graph:  27%|██▋       | 69/253 [00:21<00:56,  3.26it/s]

GCN loss on unlabled data: 0.576515257358551
GCN acc on unlabled data: 0.8305766651765758
attack loss: 0.32123708724975586


Perturbing graph:  28%|██▊       | 70/253 [00:22<00:56,  3.26it/s]

GCN loss on unlabled data: 0.5995261669158936
GCN acc on unlabled data: 0.821636119803308
attack loss: 0.3168681561946869


Perturbing graph:  28%|██▊       | 71/253 [00:22<00:55,  3.26it/s]

GCN loss on unlabled data: 0.5938846468925476
GCN acc on unlabled data: 0.8243182834152883
attack loss: 0.31506580114364624


Perturbing graph:  28%|██▊       | 72/253 [00:22<00:55,  3.26it/s]

GCN loss on unlabled data: 0.5878554582595825
GCN acc on unlabled data: 0.8238712561466249
attack loss: 0.3130529522895813


Perturbing graph:  29%|██▉       | 73/253 [00:22<00:55,  3.26it/s]

GCN loss on unlabled data: 0.6071321964263916
GCN acc on unlabled data: 0.8189539561913277
attack loss: 0.3101206421852112


Perturbing graph:  29%|██▉       | 74/253 [00:23<00:54,  3.25it/s]

GCN loss on unlabled data: 0.603390097618103
GCN acc on unlabled data: 0.8194009834599911
attack loss: 0.3372363746166229


Perturbing graph:  30%|██▉       | 75/253 [00:23<00:54,  3.26it/s]

GCN loss on unlabled data: 0.625756025314331
GCN acc on unlabled data: 0.8126955744300403
attack loss: 0.30528101325035095


Perturbing graph:  30%|███       | 76/253 [00:23<00:54,  3.25it/s]

GCN loss on unlabled data: 0.6092071533203125
GCN acc on unlabled data: 0.8162717925793473
attack loss: 0.3207467198371887


Perturbing graph:  30%|███       | 77/253 [00:24<00:54,  3.26it/s]

GCN loss on unlabled data: 0.5885103940963745
GCN acc on unlabled data: 0.8176128743853376
attack loss: 0.3156546354293823


Perturbing graph:  31%|███       | 78/253 [00:24<00:53,  3.25it/s]

GCN loss on unlabled data: 0.6440790891647339
GCN acc on unlabled data: 0.8028609745194457
attack loss: 0.34012913703918457


Perturbing graph:  31%|███       | 79/253 [00:24<00:53,  3.26it/s]

GCN loss on unlabled data: 0.6009710431098938
GCN acc on unlabled data: 0.8171658471166742
attack loss: 0.3187826871871948


Perturbing graph:  32%|███▏      | 80/253 [00:25<00:53,  3.25it/s]

GCN loss on unlabled data: 0.6181726455688477
GCN acc on unlabled data: 0.8180599016540009
attack loss: 0.33455079793930054


Perturbing graph:  32%|███▏      | 81/253 [00:25<00:52,  3.25it/s]

GCN loss on unlabled data: 0.6013490557670593
GCN acc on unlabled data: 0.8185069289226643
attack loss: 0.33173954486846924


Perturbing graph:  32%|███▏      | 82/253 [00:25<00:52,  3.25it/s]

GCN loss on unlabled data: 0.6253562569618225
GCN acc on unlabled data: 0.8104604380867233
attack loss: 0.3356431722640991


Perturbing graph:  33%|███▎      | 83/253 [00:26<00:52,  3.25it/s]

GCN loss on unlabled data: 0.5959851145744324
GCN acc on unlabled data: 0.8198480107286544
attack loss: 0.33129194378852844


Perturbing graph:  33%|███▎      | 84/253 [00:26<00:51,  3.26it/s]

GCN loss on unlabled data: 0.6268859505653381
GCN acc on unlabled data: 0.8126955744300403
attack loss: 0.3310694098472595


Perturbing graph:  34%|███▎      | 85/253 [00:26<00:51,  3.26it/s]

GCN loss on unlabled data: 0.6412398815155029
GCN acc on unlabled data: 0.8100134108180599
attack loss: 0.3588138520717621


Perturbing graph:  34%|███▍      | 86/253 [00:26<00:51,  3.26it/s]

GCN loss on unlabled data: 0.6211496591567993
GCN acc on unlabled data: 0.8207420652659813
attack loss: 0.3337782621383667


Perturbing graph:  34%|███▍      | 87/253 [00:27<00:51,  3.25it/s]

GCN loss on unlabled data: 0.6293659210205078
GCN acc on unlabled data: 0.8077782744747429
attack loss: 0.33275213837623596


Perturbing graph:  35%|███▍      | 88/253 [00:27<00:50,  3.26it/s]

GCN loss on unlabled data: 0.6072419881820679
GCN acc on unlabled data: 0.8126955744300403
attack loss: 0.33824843168258667


Perturbing graph:  35%|███▌      | 89/253 [00:27<00:50,  3.26it/s]

GCN loss on unlabled data: 0.6362547278404236
GCN acc on unlabled data: 0.8024139472507823
attack loss: 0.3477238118648529


Perturbing graph:  36%|███▌      | 90/253 [00:28<00:50,  3.26it/s]

GCN loss on unlabled data: 0.657391369342804
GCN acc on unlabled data: 0.8028609745194457
attack loss: 0.34398436546325684


Perturbing graph:  36%|███▌      | 91/253 [00:28<00:49,  3.26it/s]

GCN loss on unlabled data: 0.6423360109329224
GCN acc on unlabled data: 0.8113544926240501
attack loss: 0.34327974915504456


Perturbing graph:  36%|███▋      | 92/253 [00:28<00:49,  3.26it/s]

GCN loss on unlabled data: 0.6293275952339172
GCN acc on unlabled data: 0.8082253017434063
attack loss: 0.34711188077926636


Perturbing graph:  37%|███▋      | 93/253 [00:29<00:49,  3.26it/s]

GCN loss on unlabled data: 0.6350268721580505
GCN acc on unlabled data: 0.8073312472060796
attack loss: 0.33330750465393066


Perturbing graph:  37%|███▋      | 94/253 [00:29<00:48,  3.26it/s]

GCN loss on unlabled data: 0.6114476919174194
GCN acc on unlabled data: 0.8176128743853376
attack loss: 0.3251102566719055


Perturbing graph:  38%|███▊      | 95/253 [00:29<00:48,  3.25it/s]

GCN loss on unlabled data: 0.6295976042747498
GCN acc on unlabled data: 0.8046490835940993
attack loss: 0.346145898103714


Perturbing graph:  38%|███▊      | 96/253 [00:30<00:48,  3.25it/s]

GCN loss on unlabled data: 0.6364448666572571
GCN acc on unlabled data: 0.8073312472060796
attack loss: 0.3405109643936157


Perturbing graph:  38%|███▊      | 97/253 [00:30<00:47,  3.26it/s]

GCN loss on unlabled data: 0.6332884430885315
GCN acc on unlabled data: 0.8109074653553867
attack loss: 0.33744049072265625


Perturbing graph:  39%|███▊      | 98/253 [00:30<00:47,  3.26it/s]

GCN loss on unlabled data: 0.6453820466995239
GCN acc on unlabled data: 0.8046490835940993
attack loss: 0.35693255066871643


Perturbing graph:  39%|███▉      | 99/253 [00:30<00:47,  3.26it/s]

GCN loss on unlabled data: 0.6591253280639648
GCN acc on unlabled data: 0.8109074653553867
attack loss: 0.35151544213294983


Perturbing graph:  40%|███▉      | 100/253 [00:31<00:46,  3.26it/s]

GCN loss on unlabled data: 0.6703752279281616
GCN acc on unlabled data: 0.8024139472507823
attack loss: 0.35396426916122437


Perturbing graph:  40%|███▉      | 101/253 [00:31<00:46,  3.25it/s]

GCN loss on unlabled data: 0.6601200699806213
GCN acc on unlabled data: 0.8001788109074653
attack loss: 0.3347211182117462


Perturbing graph:  40%|████      | 102/253 [00:31<00:46,  3.25it/s]

GCN loss on unlabled data: 0.6720302700996399
GCN acc on unlabled data: 0.7894501564595441
attack loss: 0.3532326817512512


Perturbing graph:  41%|████      | 103/253 [00:32<00:46,  3.25it/s]

GCN loss on unlabled data: 0.6679704785346985
GCN acc on unlabled data: 0.7952615109521681
attack loss: 0.34258484840393066


Perturbing graph:  41%|████      | 104/253 [00:32<00:45,  3.25it/s]

GCN loss on unlabled data: 0.6873105764389038
GCN acc on unlabled data: 0.7912382655341976
attack loss: 0.36618247628211975


Perturbing graph:  42%|████▏     | 105/253 [00:32<00:45,  3.25it/s]

GCN loss on unlabled data: 0.6809521913528442
GCN acc on unlabled data: 0.7992847563701386
attack loss: 0.3552233874797821


Perturbing graph:  42%|████▏     | 106/253 [00:33<00:45,  3.26it/s]

GCN loss on unlabled data: 0.6767216324806213
GCN acc on unlabled data: 0.8042020563254358
attack loss: 0.35398149490356445


Perturbing graph:  42%|████▏     | 107/253 [00:33<00:44,  3.26it/s]

GCN loss on unlabled data: 0.6591177582740784
GCN acc on unlabled data: 0.8024139472507823
attack loss: 0.35594701766967773


Perturbing graph:  43%|████▎     | 108/253 [00:33<00:44,  3.26it/s]

GCN loss on unlabled data: 0.6781156659126282
GCN acc on unlabled data: 0.8015198927134556
attack loss: 0.35282570123672485


Perturbing graph:  43%|████▎     | 109/253 [00:34<00:44,  3.26it/s]

GCN loss on unlabled data: 0.6722551584243774
GCN acc on unlabled data: 0.8033080017881091
attack loss: 0.3542301654815674


Perturbing graph:  43%|████▎     | 110/253 [00:34<00:43,  3.26it/s]

GCN loss on unlabled data: 0.7412323951721191
GCN acc on unlabled data: 0.7854269110415736
attack loss: 0.3921110928058624


Perturbing graph:  44%|████▍     | 111/253 [00:34<00:43,  3.26it/s]

GCN loss on unlabled data: 0.6848589181900024
GCN acc on unlabled data: 0.7930263746088512
attack loss: 0.3652453124523163


Perturbing graph:  44%|████▍     | 112/253 [00:34<00:43,  3.26it/s]

GCN loss on unlabled data: 0.6574022173881531
GCN acc on unlabled data: 0.8028609745194457
attack loss: 0.366460382938385


Perturbing graph:  45%|████▍     | 113/253 [00:35<00:43,  3.26it/s]

GCN loss on unlabled data: 0.7574978470802307
GCN acc on unlabled data: 0.7818506928922665
attack loss: 0.40936991572380066


Perturbing graph:  45%|████▌     | 114/253 [00:35<00:42,  3.25it/s]

GCN loss on unlabled data: 0.6718704104423523
GCN acc on unlabled data: 0.8010728654447922
attack loss: 0.363324373960495


Perturbing graph:  45%|████▌     | 115/253 [00:35<00:42,  3.25it/s]

GCN loss on unlabled data: 0.7111660838127136
GCN acc on unlabled data: 0.7966025927581583
attack loss: 0.38393908739089966


Perturbing graph:  46%|████▌     | 116/253 [00:36<00:42,  3.25it/s]

GCN loss on unlabled data: 0.6859163641929626
GCN acc on unlabled data: 0.7907912382655342
attack loss: 0.3729408085346222


Perturbing graph:  46%|████▌     | 117/253 [00:36<00:41,  3.26it/s]

GCN loss on unlabled data: 0.6708067059516907
GCN acc on unlabled data: 0.7988377291014752
attack loss: 0.36983051896095276


Perturbing graph:  47%|████▋     | 118/253 [00:36<00:41,  3.26it/s]

GCN loss on unlabled data: 0.7080731987953186
GCN acc on unlabled data: 0.7872150201162271
attack loss: 0.3721805810928345


Perturbing graph:  47%|████▋     | 119/253 [00:37<00:41,  3.26it/s]

GCN loss on unlabled data: 0.7472318410873413
GCN acc on unlabled data: 0.7894501564595441
attack loss: 0.39931806921958923


Perturbing graph:  47%|████▋     | 120/253 [00:37<00:40,  3.26it/s]

GCN loss on unlabled data: 0.7557578682899475
GCN acc on unlabled data: 0.7773804202056326
attack loss: 0.40297743678092957


Perturbing graph:  48%|████▊     | 121/253 [00:37<00:40,  3.26it/s]

GCN loss on unlabled data: 0.707912266254425
GCN acc on unlabled data: 0.7867679928475637
attack loss: 0.3717356026172638


Perturbing graph:  48%|████▊     | 122/253 [00:37<00:40,  3.26it/s]

GCN loss on unlabled data: 0.6932036876678467
GCN acc on unlabled data: 0.7898971837282075
attack loss: 0.3668209910392761


Perturbing graph:  49%|████▊     | 123/253 [00:38<00:39,  3.26it/s]

GCN loss on unlabled data: 0.7439814209938049
GCN acc on unlabled data: 0.777827447474296
attack loss: 0.3949737250804901


Perturbing graph:  49%|████▉     | 124/253 [00:38<00:39,  3.26it/s]

GCN loss on unlabled data: 0.6972739100456238
GCN acc on unlabled data: 0.7948144836835047
attack loss: 0.36459487676620483


Perturbing graph:  49%|████▉     | 125/253 [00:38<00:39,  3.26it/s]

GCN loss on unlabled data: 0.7145181894302368
GCN acc on unlabled data: 0.7881090746535538
attack loss: 0.38139525055885315


Perturbing graph:  50%|████▉     | 126/253 [00:39<00:38,  3.26it/s]

GCN loss on unlabled data: 0.7280591130256653
GCN acc on unlabled data: 0.7840858292355833
attack loss: 0.39283084869384766


Perturbing graph:  50%|█████     | 127/253 [00:39<00:38,  3.26it/s]

GCN loss on unlabled data: 0.7349655032157898
GCN acc on unlabled data: 0.7845328565042468
attack loss: 0.39739125967025757


Perturbing graph:  51%|█████     | 128/253 [00:39<00:38,  3.26it/s]

GCN loss on unlabled data: 0.7283865213394165
GCN acc on unlabled data: 0.7903442109968708
attack loss: 0.38688334822654724


Perturbing graph:  51%|█████     | 129/253 [00:40<00:38,  3.26it/s]

GCN loss on unlabled data: 0.7409111857414246
GCN acc on unlabled data: 0.780062583817613
attack loss: 0.38992977142333984


Perturbing graph:  51%|█████▏    | 130/253 [00:40<00:37,  3.25it/s]

GCN loss on unlabled data: 0.7489311099052429
GCN acc on unlabled data: 0.777827447474296
attack loss: 0.39455148577690125


Perturbing graph:  52%|█████▏    | 131/253 [00:40<00:37,  3.25it/s]

GCN loss on unlabled data: 0.7447288036346436
GCN acc on unlabled data: 0.7894501564595441
attack loss: 0.39369526505470276


Perturbing graph:  52%|█████▏    | 132/253 [00:41<00:37,  3.25it/s]

GCN loss on unlabled data: 0.7169586420059204
GCN acc on unlabled data: 0.7863209655789003
attack loss: 0.39129045605659485


Perturbing graph:  53%|█████▎    | 133/253 [00:41<00:36,  3.24it/s]

GCN loss on unlabled data: 0.7467489838600159
GCN acc on unlabled data: 0.7782744747429593
attack loss: 0.3932485282421112


Perturbing graph:  53%|█████▎    | 134/253 [00:41<00:36,  3.24it/s]

GCN loss on unlabled data: 0.7044916152954102
GCN acc on unlabled data: 0.7907912382655342
attack loss: 0.3813208043575287


Perturbing graph:  53%|█████▎    | 135/253 [00:41<00:36,  3.24it/s]

GCN loss on unlabled data: 0.7105488777160645
GCN acc on unlabled data: 0.7872150201162271
attack loss: 0.3853394389152527


Perturbing graph:  54%|█████▍    | 136/253 [00:42<00:36,  3.24it/s]

GCN loss on unlabled data: 0.7464881539344788
GCN acc on unlabled data: 0.7746982565936522
attack loss: 0.3908102512359619


Perturbing graph:  54%|█████▍    | 137/253 [00:42<00:35,  3.24it/s]

GCN loss on unlabled data: 0.75876384973526
GCN acc on unlabled data: 0.7894501564595441
attack loss: 0.3977511525154114


Perturbing graph:  55%|█████▍    | 138/253 [00:42<00:35,  3.24it/s]

GCN loss on unlabled data: 0.8191707730293274
GCN acc on unlabled data: 0.7688869021010282
attack loss: 0.42071154713630676


Perturbing graph:  55%|█████▍    | 139/253 [00:43<00:35,  3.23it/s]

GCN loss on unlabled data: 0.7249125242233276
GCN acc on unlabled data: 0.7845328565042468
attack loss: 0.39776119589805603


Perturbing graph:  55%|█████▌    | 140/253 [00:43<00:35,  3.22it/s]

GCN loss on unlabled data: 0.8050491809844971
GCN acc on unlabled data: 0.7742512293249888
attack loss: 0.43257227540016174


Perturbing graph:  56%|█████▌    | 141/253 [00:43<00:34,  3.22it/s]

GCN loss on unlabled data: 0.7300708293914795
GCN acc on unlabled data: 0.7890031291908807
attack loss: 0.3945390582084656


Perturbing graph:  56%|█████▌    | 142/253 [00:44<00:34,  3.23it/s]

GCN loss on unlabled data: 0.7410445809364319
GCN acc on unlabled data: 0.775592311130979
attack loss: 0.40467721223831177


Perturbing graph:  57%|█████▋    | 143/253 [00:44<00:34,  3.23it/s]

GCN loss on unlabled data: 0.7888445258140564
GCN acc on unlabled data: 0.769780956638355
attack loss: 0.417090505361557


Perturbing graph:  57%|█████▋    | 144/253 [00:44<00:33,  3.23it/s]

GCN loss on unlabled data: 0.7870241403579712
GCN acc on unlabled data: 0.7688869021010282
attack loss: 0.41871047019958496


Perturbing graph:  57%|█████▋    | 145/253 [00:45<00:33,  3.23it/s]

GCN loss on unlabled data: 0.7715336084365845
GCN acc on unlabled data: 0.7742512293249888
attack loss: 0.42706209421157837


Perturbing graph:  58%|█████▊    | 146/253 [00:45<00:33,  3.23it/s]

GCN loss on unlabled data: 0.7886906266212463
GCN acc on unlabled data: 0.7679928475637015
attack loss: 0.426185667514801


Perturbing graph:  58%|█████▊    | 147/253 [00:45<00:32,  3.23it/s]

GCN loss on unlabled data: 0.7325350046157837
GCN acc on unlabled data: 0.791685292802861
attack loss: 0.40139612555503845


Perturbing graph:  58%|█████▊    | 148/253 [00:46<00:32,  3.24it/s]

GCN loss on unlabled data: 0.8148695230484009
GCN acc on unlabled data: 0.7729101475189987
attack loss: 0.4295136332511902


Perturbing graph:  59%|█████▉    | 149/253 [00:46<00:32,  3.24it/s]

GCN loss on unlabled data: 0.7970227003097534
GCN acc on unlabled data: 0.7814036656236031
attack loss: 0.4312679171562195


Perturbing graph:  59%|█████▉    | 150/253 [00:46<00:31,  3.23it/s]

GCN loss on unlabled data: 0.8225434422492981
GCN acc on unlabled data: 0.7657577112203845
attack loss: 0.4382982850074768


Perturbing graph:  60%|█████▉    | 151/253 [00:46<00:31,  3.24it/s]

GCN loss on unlabled data: 0.8598378300666809
GCN acc on unlabled data: 0.7603933839964238
attack loss: 0.46121746301651


Perturbing graph:  60%|██████    | 152/253 [00:47<00:31,  3.24it/s]

GCN loss on unlabled data: 0.8008511066436768
GCN acc on unlabled data: 0.7773804202056326
attack loss: 0.44279664754867554


Perturbing graph:  60%|██████    | 153/253 [00:47<00:30,  3.24it/s]

GCN loss on unlabled data: 0.8012098073959351
GCN acc on unlabled data: 0.7724631202503353
attack loss: 0.44736626744270325


Perturbing graph:  61%|██████    | 154/253 [00:47<00:30,  3.24it/s]

GCN loss on unlabled data: 0.8056666254997253
GCN acc on unlabled data: 0.7715690657130085
attack loss: 0.4322241544723511


Perturbing graph:  61%|██████▏   | 155/253 [00:48<00:30,  3.24it/s]

GCN loss on unlabled data: 0.8288738131523132
GCN acc on unlabled data: 0.769780956638355
attack loss: 0.44460004568099976


Perturbing graph:  62%|██████▏   | 156/253 [00:48<00:29,  3.24it/s]

GCN loss on unlabled data: 0.8425430655479431
GCN acc on unlabled data: 0.7657577112203845
attack loss: 0.4556645154953003


Perturbing graph:  62%|██████▏   | 157/253 [00:48<00:29,  3.24it/s]

GCN loss on unlabled data: 0.8006644248962402
GCN acc on unlabled data: 0.7729101475189987
attack loss: 0.43238624930381775


Perturbing graph:  62%|██████▏   | 158/253 [00:49<00:29,  3.24it/s]

GCN loss on unlabled data: 0.8174024820327759
GCN acc on unlabled data: 0.761734465802414
attack loss: 0.43686479330062866


Perturbing graph:  63%|██████▎   | 159/253 [00:49<00:29,  3.24it/s]

GCN loss on unlabled data: 0.8446511030197144
GCN acc on unlabled data: 0.7662047384890479
attack loss: 0.4499700367450714


Perturbing graph:  63%|██████▎   | 160/253 [00:49<00:28,  3.24it/s]

GCN loss on unlabled data: 0.8317916989326477
GCN acc on unlabled data: 0.7648636566830577
attack loss: 0.4585946202278137


Perturbing graph:  64%|██████▎   | 161/253 [00:50<00:28,  3.24it/s]

GCN loss on unlabled data: 0.8318073153495789
GCN acc on unlabled data: 0.7670987930263746
attack loss: 0.4459235966205597


Perturbing graph:  64%|██████▍   | 162/253 [00:50<00:28,  3.24it/s]

GCN loss on unlabled data: 0.8375892043113708
GCN acc on unlabled data: 0.7666517657577112
attack loss: 0.46193546056747437


Perturbing graph:  64%|██████▍   | 163/253 [00:50<00:27,  3.24it/s]

GCN loss on unlabled data: 0.8136905431747437
GCN acc on unlabled data: 0.7662047384890479
attack loss: 0.4386281669139862


Perturbing graph:  65%|██████▍   | 164/253 [00:50<00:27,  3.24it/s]

GCN loss on unlabled data: 0.8673322796821594
GCN acc on unlabled data: 0.7554760840411265
attack loss: 0.46498966217041016


Perturbing graph:  65%|██████▌   | 165/253 [00:51<00:27,  3.24it/s]

GCN loss on unlabled data: 0.7794994115829468
GCN acc on unlabled data: 0.7742512293249888
attack loss: 0.4334019720554352


Perturbing graph:  66%|██████▌   | 166/253 [00:51<00:26,  3.24it/s]

GCN loss on unlabled data: 0.8630141615867615
GCN acc on unlabled data: 0.7693339293696916
attack loss: 0.47840964794158936


Perturbing graph:  66%|██████▌   | 167/253 [00:51<00:26,  3.24it/s]

GCN loss on unlabled data: 0.8499796986579895
GCN acc on unlabled data: 0.7603933839964238
attack loss: 0.4634590148925781


Perturbing graph:  66%|██████▋   | 168/253 [00:52<00:26,  3.24it/s]

GCN loss on unlabled data: 0.9453380107879639
GCN acc on unlabled data: 0.75592311130979
attack loss: 0.49269169569015503


Perturbing graph:  67%|██████▋   | 169/253 [00:52<00:25,  3.23it/s]

GCN loss on unlabled data: 0.9093307852745056
GCN acc on unlabled data: 0.7483236477425124
attack loss: 0.486865371465683


Perturbing graph:  67%|██████▋   | 170/253 [00:52<00:25,  3.24it/s]

GCN loss on unlabled data: 0.8395704627037048
GCN acc on unlabled data: 0.7581582476531069
attack loss: 0.44752079248428345


Perturbing graph:  68%|██████▊   | 171/253 [00:53<00:25,  3.23it/s]

GCN loss on unlabled data: 0.8178954720497131
GCN acc on unlabled data: 0.7635225748770675
attack loss: 0.45147505402565


Perturbing graph:  68%|██████▊   | 172/253 [00:53<00:25,  3.23it/s]

GCN loss on unlabled data: 0.8764222264289856
GCN acc on unlabled data: 0.7568171658471167
attack loss: 0.48200666904449463


Perturbing graph:  68%|██████▊   | 173/253 [00:53<00:24,  3.23it/s]

GCN loss on unlabled data: 0.8936953544616699
GCN acc on unlabled data: 0.751452838623156
attack loss: 0.4769783020019531


Perturbing graph:  69%|██████▉   | 174/253 [00:54<00:24,  3.23it/s]

GCN loss on unlabled data: 0.8504350781440735
GCN acc on unlabled data: 0.7662047384890479
attack loss: 0.455949604511261


Perturbing graph:  69%|██████▉   | 175/253 [00:54<00:24,  3.23it/s]

GCN loss on unlabled data: 0.8465107679367065
GCN acc on unlabled data: 0.7805096110862763
attack loss: 0.4661363959312439


Perturbing graph:  70%|██████▉   | 176/253 [00:54<00:23,  3.22it/s]

GCN loss on unlabled data: 0.8668319582939148
GCN acc on unlabled data: 0.7608404112650872
attack loss: 0.46787548065185547


Perturbing graph:  70%|██████▉   | 177/253 [00:54<00:23,  3.22it/s]

GCN loss on unlabled data: 0.8735296130180359
GCN acc on unlabled data: 0.7608404112650872
attack loss: 0.4717456102371216


Perturbing graph:  70%|███████   | 178/253 [00:55<00:23,  3.23it/s]

GCN loss on unlabled data: 0.9105154275894165
GCN acc on unlabled data: 0.7469825659365221
attack loss: 0.4917673170566559


Perturbing graph:  71%|███████   | 179/253 [00:55<00:22,  3.23it/s]

GCN loss on unlabled data: 0.8416267037391663
GCN acc on unlabled data: 0.7684398748323648
attack loss: 0.46532729268074036


Perturbing graph:  71%|███████   | 180/253 [00:55<00:22,  3.23it/s]

GCN loss on unlabled data: 0.8854078054428101
GCN acc on unlabled data: 0.7465355386678587
attack loss: 0.47951847314834595


Perturbing graph:  72%|███████▏  | 181/253 [00:56<00:22,  3.23it/s]

GCN loss on unlabled data: 0.9559693336486816
GCN acc on unlabled data: 0.7532409476978096
attack loss: 0.5040876269340515


Perturbing graph:  72%|███████▏  | 182/253 [00:56<00:22,  3.23it/s]

GCN loss on unlabled data: 0.8976383209228516
GCN acc on unlabled data: 0.751452838623156
attack loss: 0.489867627620697


Perturbing graph:  72%|███████▏  | 183/253 [00:56<00:21,  3.22it/s]

GCN loss on unlabled data: 0.852137565612793
GCN acc on unlabled data: 0.7702279839070183
attack loss: 0.46748387813568115


Perturbing graph:  73%|███████▎  | 184/253 [00:57<00:21,  3.22it/s]

GCN loss on unlabled data: 0.9665087461471558
GCN acc on unlabled data: 0.7469825659365221
attack loss: 0.5133907198905945


Perturbing graph:  73%|███████▎  | 185/253 [00:57<00:21,  3.22it/s]

GCN loss on unlabled data: 0.92574143409729
GCN acc on unlabled data: 0.7487706750111757
attack loss: 0.4965360462665558


Perturbing graph:  74%|███████▎  | 186/253 [00:57<00:20,  3.22it/s]

GCN loss on unlabled data: 0.9127700328826904
GCN acc on unlabled data: 0.7532409476978096
attack loss: 0.48657602071762085


Perturbing graph:  74%|███████▍  | 187/253 [00:58<00:20,  3.23it/s]

GCN loss on unlabled data: 0.9850186705589294
GCN acc on unlabled data: 0.7371479660259276
attack loss: 0.5125894546508789


Perturbing graph:  74%|███████▍  | 188/253 [00:58<00:20,  3.23it/s]

GCN loss on unlabled data: 0.9685394763946533
GCN acc on unlabled data: 0.7389360751005811
attack loss: 0.5212475061416626


Perturbing graph:  75%|███████▍  | 189/253 [00:58<00:19,  3.23it/s]

GCN loss on unlabled data: 0.9480806589126587
GCN acc on unlabled data: 0.7483236477425124
attack loss: 0.4944054186344147


Perturbing graph:  75%|███████▌  | 190/253 [00:59<00:19,  3.23it/s]

GCN loss on unlabled data: 0.8892050981521606
GCN acc on unlabled data: 0.7505587840858292
attack loss: 0.48030078411102295


Perturbing graph:  75%|███████▌  | 191/253 [00:59<00:19,  3.24it/s]

GCN loss on unlabled data: 0.9593349099159241
GCN acc on unlabled data: 0.743406347787215
attack loss: 0.5066382884979248


Perturbing graph:  76%|███████▌  | 192/253 [00:59<00:18,  3.23it/s]

GCN loss on unlabled data: 0.924639105796814
GCN acc on unlabled data: 0.7496647295485025
attack loss: 0.49543818831443787


Perturbing graph:  76%|███████▋  | 193/253 [00:59<00:18,  3.24it/s]

GCN loss on unlabled data: 0.9433535933494568
GCN acc on unlabled data: 0.7518998658918195
attack loss: 0.5121543407440186


Perturbing graph:  77%|███████▋  | 194/253 [01:00<00:18,  3.24it/s]

GCN loss on unlabled data: 0.9457898736000061
GCN acc on unlabled data: 0.7545820295037997
attack loss: 0.49665898084640503


Perturbing graph:  77%|███████▋  | 195/253 [01:00<00:17,  3.24it/s]

GCN loss on unlabled data: 0.942023515701294
GCN acc on unlabled data: 0.7474295932051855
attack loss: 0.5090492963790894


Perturbing graph:  77%|███████▋  | 196/253 [01:00<00:17,  3.24it/s]

GCN loss on unlabled data: 0.9105322360992432
GCN acc on unlabled data: 0.7451944568618686
attack loss: 0.4820505976676941


Perturbing graph:  78%|███████▊  | 197/253 [01:01<00:17,  3.23it/s]

GCN loss on unlabled data: 0.9077975749969482
GCN acc on unlabled data: 0.7492177022798391
attack loss: 0.47200822830200195


Perturbing graph:  78%|███████▊  | 198/253 [01:01<00:16,  3.24it/s]

GCN loss on unlabled data: 1.0052772760391235
GCN acc on unlabled data: 0.745641484130532
attack loss: 0.5255739688873291


Perturbing graph:  79%|███████▊  | 199/253 [01:01<00:16,  3.24it/s]

GCN loss on unlabled data: 0.9373334646224976
GCN acc on unlabled data: 0.7545820295037997
attack loss: 0.5016701817512512


Perturbing graph:  79%|███████▉  | 200/253 [01:02<00:16,  3.25it/s]

GCN loss on unlabled data: 0.9786993265151978
GCN acc on unlabled data: 0.7429593205185516
attack loss: 0.5178415775299072


Perturbing graph:  79%|███████▉  | 201/253 [01:02<00:15,  3.25it/s]

GCN loss on unlabled data: 1.0050041675567627
GCN acc on unlabled data: 0.747876620473849
attack loss: 0.5379602909088135


Perturbing graph:  80%|███████▉  | 202/253 [01:02<00:15,  3.25it/s]

GCN loss on unlabled data: 1.0172697305679321
GCN acc on unlabled data: 0.739830129637908
attack loss: 0.5323429703712463


Perturbing graph:  80%|████████  | 203/253 [01:03<00:15,  3.25it/s]

GCN loss on unlabled data: 0.9920694231987
GCN acc on unlabled data: 0.747876620473849
attack loss: 0.5348206758499146


Perturbing graph:  81%|████████  | 204/253 [01:03<00:15,  3.24it/s]

GCN loss on unlabled data: 0.9562452435493469
GCN acc on unlabled data: 0.7460885113991954
attack loss: 0.5114390254020691


Perturbing graph:  81%|████████  | 205/253 [01:03<00:14,  3.24it/s]

GCN loss on unlabled data: 0.9599315524101257
GCN acc on unlabled data: 0.737594993294591
attack loss: 0.5185747742652893


Perturbing graph:  81%|████████▏ | 206/253 [01:03<00:14,  3.24it/s]

GCN loss on unlabled data: 0.9568213820457458
GCN acc on unlabled data: 0.7451944568618686
attack loss: 0.5016190409660339


Perturbing graph:  82%|████████▏ | 207/253 [01:04<00:14,  3.25it/s]

GCN loss on unlabled data: 0.9492165446281433
GCN acc on unlabled data: 0.7429593205185516
attack loss: 0.516402542591095


Perturbing graph:  82%|████████▏ | 208/253 [01:04<00:13,  3.25it/s]

GCN loss on unlabled data: 1.0373611450195312
GCN acc on unlabled data: 0.7443004023245419
attack loss: 0.5663094520568848


Perturbing graph:  83%|████████▎ | 209/253 [01:04<00:13,  3.25it/s]

GCN loss on unlabled data: 0.9931359887123108
GCN acc on unlabled data: 0.7349128296826106
attack loss: 0.5374303460121155


Perturbing graph:  83%|████████▎ | 210/253 [01:05<00:13,  3.25it/s]

GCN loss on unlabled data: 0.9964051246643066
GCN acc on unlabled data: 0.7344658024139473
attack loss: 0.53840172290802


Perturbing graph:  83%|████████▎ | 211/253 [01:05<00:12,  3.25it/s]

GCN loss on unlabled data: 0.9622283577919006
GCN acc on unlabled data: 0.7501117568171659
attack loss: 0.5146311521530151


Perturbing graph:  84%|████████▍ | 212/253 [01:05<00:12,  3.25it/s]

GCN loss on unlabled data: 1.0239372253417969
GCN acc on unlabled data: 0.7487706750111757
attack loss: 0.5445911884307861


Perturbing graph:  84%|████████▍ | 213/253 [01:06<00:12,  3.25it/s]

GCN loss on unlabled data: 1.011816143989563
GCN acc on unlabled data: 0.7425122932498882
attack loss: 0.5394071340560913


Perturbing graph:  85%|████████▍ | 214/253 [01:06<00:11,  3.25it/s]

GCN loss on unlabled data: 1.079677700996399
GCN acc on unlabled data: 0.743406347787215
attack loss: 0.5704711079597473


Perturbing graph:  85%|████████▍ | 215/253 [01:06<00:11,  3.25it/s]

GCN loss on unlabled data: 0.9994946718215942
GCN acc on unlabled data: 0.7416182387125615
attack loss: 0.5453214049339294


Perturbing graph:  85%|████████▌ | 216/253 [01:07<00:11,  3.25it/s]

GCN loss on unlabled data: 0.9863783121109009
GCN acc on unlabled data: 0.7429593205185516
attack loss: 0.5465690493583679


Perturbing graph:  86%|████████▌ | 217/253 [01:07<00:11,  3.25it/s]

GCN loss on unlabled data: 1.0062298774719238
GCN acc on unlabled data: 0.7384890478319178
attack loss: 0.5461506247520447


Perturbing graph:  86%|████████▌ | 218/253 [01:07<00:10,  3.25it/s]

GCN loss on unlabled data: 1.0031911134719849
GCN acc on unlabled data: 0.7322306660706304
attack loss: 0.5389253497123718


Perturbing graph:  87%|████████▋ | 219/253 [01:07<00:10,  3.25it/s]

GCN loss on unlabled data: 1.0319973230361938
GCN acc on unlabled data: 0.7443004023245419
attack loss: 0.5663365125656128


Perturbing graph:  87%|████████▋ | 220/253 [01:08<00:10,  3.25it/s]

GCN loss on unlabled data: 0.9219697117805481
GCN acc on unlabled data: 0.7572641931157801
attack loss: 0.4893896281719208


Perturbing graph:  87%|████████▋ | 221/253 [01:08<00:09,  3.25it/s]

GCN loss on unlabled data: 1.0600768327713013
GCN acc on unlabled data: 0.737594993294591
attack loss: 0.5828999876976013


Perturbing graph:  88%|████████▊ | 222/253 [01:08<00:09,  3.26it/s]

GCN loss on unlabled data: 1.0176845788955688
GCN acc on unlabled data: 0.7420652659812249
attack loss: 0.5490291714668274


Perturbing graph:  88%|████████▊ | 223/253 [01:09<00:09,  3.25it/s]

GCN loss on unlabled data: 1.0057735443115234
GCN acc on unlabled data: 0.731783638801967
attack loss: 0.5502764582633972


Perturbing graph:  89%|████████▊ | 224/253 [01:09<00:08,  3.25it/s]

GCN loss on unlabled data: 1.0681666135787964
GCN acc on unlabled data: 0.7326776933392937
attack loss: 0.5782688856124878


Perturbing graph:  89%|████████▉ | 225/253 [01:09<00:08,  3.25it/s]

GCN loss on unlabled data: 1.0953346490859985
GCN acc on unlabled data: 0.7313366115333035
attack loss: 0.5859764814376831


Perturbing graph:  89%|████████▉ | 226/253 [01:10<00:08,  3.24it/s]

GCN loss on unlabled data: 0.9881616234779358
GCN acc on unlabled data: 0.7313366115333035
attack loss: 0.5352218151092529


Perturbing graph:  90%|████████▉ | 227/253 [01:10<00:08,  3.24it/s]

GCN loss on unlabled data: 1.0457221269607544
GCN acc on unlabled data: 0.72954850245865
attack loss: 0.5659462809562683


Perturbing graph:  90%|█████████ | 228/253 [01:10<00:07,  3.25it/s]

GCN loss on unlabled data: 0.9857362508773804
GCN acc on unlabled data: 0.745641484130532
attack loss: 0.5340920686721802


Perturbing graph:  91%|█████████ | 229/253 [01:11<00:07,  3.25it/s]

GCN loss on unlabled data: 1.020390510559082
GCN acc on unlabled data: 0.7371479660259276
attack loss: 0.5549551844596863


Perturbing graph:  91%|█████████ | 230/253 [01:11<00:07,  3.25it/s]

GCN loss on unlabled data: 1.0068135261535645
GCN acc on unlabled data: 0.7438533750558785
attack loss: 0.5538772940635681


Perturbing graph:  91%|█████████▏| 231/253 [01:11<00:06,  3.25it/s]

GCN loss on unlabled data: 1.0125641822814941
GCN acc on unlabled data: 0.7371479660259276
attack loss: 0.551721453666687


Perturbing graph:  92%|█████████▏| 232/253 [01:11<00:06,  3.24it/s]

GCN loss on unlabled data: 1.044614315032959
GCN acc on unlabled data: 0.7282074206526599
attack loss: 0.5596221685409546


Perturbing graph:  92%|█████████▏| 233/253 [01:12<00:06,  3.25it/s]

GCN loss on unlabled data: 1.103659749031067
GCN acc on unlabled data: 0.721502011622709
attack loss: 0.6035088896751404


Perturbing graph:  92%|█████████▏| 234/253 [01:12<00:05,  3.25it/s]

GCN loss on unlabled data: 0.9634838700294495
GCN acc on unlabled data: 0.7438533750558785
attack loss: 0.5275031328201294


Perturbing graph:  93%|█████████▎| 235/253 [01:12<00:05,  3.25it/s]

GCN loss on unlabled data: 1.0850446224212646
GCN acc on unlabled data: 0.7299955297273134
attack loss: 0.5868152379989624


Perturbing graph:  93%|█████████▎| 236/253 [01:13<00:05,  3.25it/s]

GCN loss on unlabled data: 1.0470733642578125
GCN acc on unlabled data: 0.7286544479213232
attack loss: 0.5523366928100586


Perturbing graph:  94%|█████████▎| 237/253 [01:13<00:04,  3.25it/s]

GCN loss on unlabled data: 0.9840508699417114
GCN acc on unlabled data: 0.7425122932498882
attack loss: 0.5332476496696472


Perturbing graph:  94%|█████████▍| 238/253 [01:13<00:04,  3.25it/s]

GCN loss on unlabled data: 1.044750690460205
GCN acc on unlabled data: 0.7299955297273134
attack loss: 0.5543534755706787


Perturbing graph:  94%|█████████▍| 239/253 [01:14<00:04,  3.25it/s]

GCN loss on unlabled data: 1.146579623222351
GCN acc on unlabled data: 0.727313366115333
attack loss: 0.6078574061393738


Perturbing graph:  95%|█████████▍| 240/253 [01:14<00:04,  3.25it/s]

GCN loss on unlabled data: 1.0688289403915405
GCN acc on unlabled data: 0.735359856951274
attack loss: 0.5595360994338989


Perturbing graph:  95%|█████████▌| 241/253 [01:14<00:03,  3.25it/s]

GCN loss on unlabled data: 1.046075701713562
GCN acc on unlabled data: 0.7228430934286991
attack loss: 0.5675573945045471


Perturbing graph:  96%|█████████▌| 242/253 [01:15<00:03,  3.25it/s]

GCN loss on unlabled data: 1.0450780391693115
GCN acc on unlabled data: 0.7308895842646401
attack loss: 0.5619223117828369


Perturbing graph:  96%|█████████▌| 243/253 [01:15<00:03,  3.25it/s]

GCN loss on unlabled data: 1.1103572845458984
GCN acc on unlabled data: 0.7268663388466696
attack loss: 0.5986125469207764


Perturbing graph:  96%|█████████▋| 244/253 [01:15<00:02,  3.25it/s]

GCN loss on unlabled data: 1.059249997138977
GCN acc on unlabled data: 0.7282074206526599
attack loss: 0.5581023097038269


Perturbing graph:  97%|█████████▋| 245/253 [01:15<00:02,  3.24it/s]

GCN loss on unlabled data: 1.079076886177063
GCN acc on unlabled data: 0.723737147966026
attack loss: 0.5687435269355774


Perturbing graph:  97%|█████████▋| 246/253 [01:16<00:02,  3.24it/s]

GCN loss on unlabled data: 1.0978350639343262
GCN acc on unlabled data: 0.7152436298614215
attack loss: 0.5822624564170837


Perturbing graph:  98%|█████████▊| 247/253 [01:16<00:01,  3.24it/s]

GCN loss on unlabled data: 1.1213802099227905
GCN acc on unlabled data: 0.721502011622709
attack loss: 0.5834940671920776


Perturbing graph:  98%|█████████▊| 248/253 [01:16<00:01,  3.24it/s]

GCN loss on unlabled data: 1.1337097883224487
GCN acc on unlabled data: 0.7264193115780063
attack loss: 0.6109156012535095


Perturbing graph:  98%|█████████▊| 249/253 [01:17<00:01,  3.23it/s]

GCN loss on unlabled data: 1.030042290687561
GCN acc on unlabled data: 0.7255252570406795
attack loss: 0.5559802651405334


Perturbing graph:  99%|█████████▉| 250/253 [01:17<00:00,  3.24it/s]

GCN loss on unlabled data: 1.0300313234329224
GCN acc on unlabled data: 0.7277603933839965
attack loss: 0.5450984239578247


Perturbing graph:  99%|█████████▉| 251/253 [01:17<00:00,  3.24it/s]

GCN loss on unlabled data: 1.1208754777908325
GCN acc on unlabled data: 0.7264193115780063
attack loss: 0.6002645492553711


Perturbing graph: 100%|█████████▉| 252/253 [01:18<00:00,  3.24it/s]

GCN loss on unlabled data: 1.0505656003952026
GCN acc on unlabled data: 0.7210549843540456
attack loss: 0.5667456984519958


Perturbing graph: 100%|██████████| 253/253 [01:18<00:00,  3.23it/s]

GCN loss on unlabled data: 1.035224437713623
GCN acc on unlabled data: 0.7259722843093429
attack loss: 0.5520198345184326





In [5]:
features

<2485x1433 sparse matrix of type '<class 'numpy.float32'>'
	with 45487 stored elements in Compressed Sparse Row format>

In [6]:
adj

<2485x2485 sparse matrix of type '<class 'numpy.float32'>'
	with 10138 stored elements in Compressed Sparse Row format>

In [7]:
labels

array([5, 2, 0, ..., 2, 2, 2], dtype=int8)

In [70]:
surrogate.predict(features, adj).shape

torch.Size([2485, 7])

In [17]:
torch.max(torch.softmax(surrogate.predict(features, adj), dim=1), 1)

torch.return_types.max(
values=tensor([0.9828, 0.9994, 1.0000,  ..., 1.0000, 1.0000, 0.9904], device='cuda:0',
       grad_fn=<MaxBackward0>),
indices=tensor([5, 2, 0,  ..., 2, 2, 2], device='cuda:0'))

In [4]:
import scanpy as sc
import torch
import numpy as np
from deeprobust.graph.data import Dataset
from deeprobust.graph.defense import GCN, GAT
from deeprobust.graph.global_attack import Metattack

In [5]:
adata = sc.read_h5ad("/gpfs/gibbs/pi/zhao/tl688/scgpt_dataset/spaital_mouse_slideseqv2.h5ad")

In [6]:
sc.pp.subsample(adata, random_state=2023, fraction=0.1)

In [7]:
# from sklearn.neighbors import kneighbors_graph
# from torch_geometric.utils import from_scipy_sparse_matrix
# from scipy.spatial import Delaunay
# def extract_edge_index(
#     adata,
#     batch_key = None,
#     spatial_key = 'spatial',
#     method = 'knn',
#     n_neighbors = 30,
#     ):
#     """
#     Define edge_index for SIMVI model training.

#     Args:
#     ----
#         adata: AnnData object.
#         batch_key: Key in `adata.obs` for batch information. If batch_key is none,
#         assume the adata is from the same batch. Otherwise, we create edge_index
#         based on each batch and concatenate them.
#         spatial_key: Key in `adata.obsm` for spatial location.
#         method: method for establishing the graph proximity relationship between
#         cells. Two available methods are: knn and Delouney. Knn is used as default
#         due to its flexible neighbor number selection.
#         n_neighbors: The number of n_neighbors of knn graph. Not used if the graph
#         is based on Delouney triangularization.

#     Returns
#     -------
#         edge_index: torch.Tensor.
#     """
#     if batch_key is not None:
#         j = 0
#         for i in adata.obs[batch_key].unique():
#             adata_tmp = adata[adata.obs[batch_key]==i].copy()
#             if method == 'knn':
#                 A = kneighbors_graph(adata_tmp.obsm[spatial_key],n_neighbors = n_neighbors)
#                 edge_index_tmp, edge_weight = from_scipy_sparse_matrix(A)
#                 label = torch.arange(adata.shape[0])[adata.obs_names.isin(adata_tmp.obs_names)]
#                 edge_index_tmp = label[edge_index_tmp]
#                 if j == 0:
#                     edge_index = edge_index_tmp
#                     j = 1
#                 else:
#                     edge_index = torch.cat((edge_index,edge_index_tmp),1)

#             else:
#                 tri = Delaunay(adata_tmp.obsm[spatial_key])
#                 triangles = tri.simplices
#                 edges = set()
#                 for triangle in triangles:
#                     for i in range(3):
#                         edge = tuple(sorted((triangle[i], triangle[(i + 1) % 3])))
#                         edges.add(edge)
#                 edge_index_tmp = torch.tensor(list(edges)).t().contiguous()
#                 label = torch.arange(adata.shape[0])[adata.obs_names.isin(adata_tmp.obs_names)]
#                 edge_index_tmp = label[edge_index_tmp]
#                 if j == 0:
#                     edge_index = edge_index_tmp
#                     j = 1
#                 else:
#                     edge_index = torch.cat((edge_index,edge_index_tmp),1)
#     else:
#         if method == 'knn':
#             A = kneighbors_graph(adata.obsm[spatial_key],n_neighbors = n_neighbors)
#             edge_index, edge_weight = from_scipy_sparse_matrix(A)
#         else:
#             tri = Delaunay(adata.obsm[spatial_key])
#             triangles = tri.simplices
#             edges = set()
#             for triangle in triangles:
#                 for i in range(3):
#                     edge = tuple(sorted((triangle[i], triangle[(i + 1) % 3])))
#                     edges.add(edge)
#             edge_index = torch.tensor(list(edges)).t().contiguous()

#     return edge_index

In [8]:
import sklearn.model_selection
import sklearn.preprocessing

In [9]:
train_obs, test_obs =sklearn.model_selection.train_test_split(range(len(adata)), test_size=0.33, random_state=2023)

In [10]:
adata_train = adata[train_obs]

In [11]:
# train_edges = extract_edge_index(adata_train)

In [12]:
# train_edges
num_neig = 10

In [13]:
sc.pp.neighbors(adata, use_rep = 'spatial', n_neighbors=num_neig)

In [14]:
le = sklearn.preprocessing.LabelEncoder()

In [15]:
labels = torch.FloatTensor(le.fit_transform(adata.obs['cluster'])).long()

In [16]:
labels

tensor([3, 6, 6,  ..., 8, 2, 4])

In [17]:
len(labels)

4178

In [18]:
# idx_train = [i for i in range(features.shape[0])]
idx_train = train_obs
idx_test = test_obs

In [19]:
labels_train = labels[idx_train]

In [20]:
features = torch.FloatTensor(adata.X.toarray())
adj = (adata.obsp['distances']>0)*1

In [21]:
# from sklearn.neighbors import kneighbors_graph
# adj = kneighbors_graph(adata.obsm['spatial'], num_neig, mode='connectivity', include_self=True)

In [22]:
# adj

In [23]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1, nhid=16,
                with_relu=False, device=device)
surrogate = surrogate.to(device)
surrogate.fit(features, adj, labels, idx_train = idx_train, train_iters=1000)

  return torch.sparse.FloatTensor(sparseconcat.t(),sparsedata,torch.Size(sparse_mx.shape))


In [24]:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# # surrogate = GCN(nfeat=features.shape[1], nclass=labels.max().item()+1, nhid=16,
# #                 with_relu=False, device=device)

# surrogate = GAT(nfeat=features.shape[1], nclass=labels.max().item()+1, nhid=16,device=device)
# surrogate = surrogate.to(device)
# surrogate.fit(features, adj, labels, idx_train = idx_train)

In [25]:
adj

<4178x4178 sparse matrix of type '<class 'numpy.int64'>'
	with 37602 stored elements in Compressed Sparse Row format>

In [26]:
output = surrogate.predict()

In [27]:
# output

In [28]:
_,predicted = torch.max(torch.softmax(output, dim=1), 1)

In [29]:
len(predicted.cpu().numpy())

4178

In [30]:
labels_train

tensor([12,  6,  0,  ...,  8,  6,  2])

In [31]:
predicted.cpu().numpy()

array([12,  1, 12, ..., 10,  2,  4])

In [32]:
print(sklearn.metrics.classification_report(labels_train, predicted.cpu().numpy()[train_obs]))

              precision    recall  f1-score   support

           0       0.87      0.81      0.84       468
           1       0.87      0.85      0.86       494
           2       0.77      0.91      0.83       450
           3       0.86      0.75      0.80       119
           4       0.88      0.85      0.86        92
           5       0.93      1.00      0.96        53
           6       0.90      0.79      0.84       231
           7       0.87      0.84      0.86        90
           8       0.96      0.88      0.92        77
           9       0.88      0.91      0.89        69
          10       0.90      0.88      0.89       215
          11       0.88      0.85      0.86       105
          12       0.76      0.90      0.82       200
          13       0.95      0.76      0.85       136

    accuracy                           0.85      2799
   macro avg       0.88      0.86      0.86      2799
weighted avg       0.86      0.85      0.85      2799



In [33]:
from deeprobust.graph.global_attack import MinMax
model = MinMax(model=surrogate, nnodes=adj.shape[0], loss_type='CE', device='cuda').to('cuda')
model.attack(features, adj.toarray(), labels, idx_train, n_perturbations=10)
modified_adj = model.modified_adj

100%|██████████| 200/200 [00:57<00:00,  3.51it/s]


In [None]:
torch.save(modified_adj, f"./MinMax_slideseqv2_neig{num_neig}.pickle")

In [34]:
# adata_train = adata[idx_train]

In [35]:
# adj_train = (adata_train.obsp['distances']>0)*1

In [36]:
output = surrogate.predict()

In [37]:
# output

In [38]:
_,predicted = torch.max(torch.softmax(output, dim=1), 1)

In [39]:
len(predicted.cpu().numpy())

4178

In [40]:
labels_train

tensor([12,  6,  0,  ...,  8,  6,  2])

In [41]:
predicted.cpu().numpy()

array([12,  6,  6, ..., 10,  2,  4])

In [42]:
print(sklearn.metrics.classification_report(labels_train, predicted.cpu().numpy()[train_obs]))

              precision    recall  f1-score   support

           0       0.76      0.93      0.84       468
           1       0.88      0.86      0.87       494
           2       0.91      0.75      0.82       450
           3       0.88      0.76      0.82       119
           4       0.90      0.89      0.90        92
           5       0.94      0.96      0.95        53
           6       0.85      0.90      0.88       231
           7       0.84      0.90      0.87        90
           8       0.96      0.87      0.91        77
           9       0.87      0.86      0.86        69
          10       0.82      0.92      0.87       215
          11       0.96      0.70      0.81       105
          12       0.83      0.79      0.81       200
          13       0.87      0.88      0.88       136

    accuracy                           0.85      2799
   macro avg       0.88      0.86      0.86      2799
weighted avg       0.86      0.85      0.85      2799



In [43]:
from deeprobust.graph.global_attack import Random
model = Random()
model.attack(adj, n_perturbations=10)
modified_adj = model.modified_adj

In [None]:
torch.save(modified_adj, f"./Random_slideseqv2_neig{num_neig}.pickle")

In [44]:
from deeprobust.graph.global_attack import DICE
model = DICE()
model.attack(adj, labels, n_perturbations=10)
modified_adj = model.modified_adj

In [None]:
torch.save(modified_adj, f"./DICE_slideseqv2_neig{num_neig}.pickle")

In [45]:
model = Metattack(model=surrogate.to(device), nnodes=adj.shape[0], feature_shape=features.shape, device=device)
model = model.to(device)
perturbations = int(0.05 * (adj.sum() // 2))
model.attack(features, adj, labels, idx_train, idx_test, perturbations, ll_constraint=False)
modified_adj = model.modified_adj

In [187]:
torch.save(modified_adj, f"./Metattack_slideseqv2_neig{num_neig}.pickle")

In [188]:
# !nvidia-smi

In [189]:
# features = torch.FloatTensor(adata_train.X.toarray()).cuda()

In [190]:
modified_adj.shape

torch.Size([4178, 4178])

In [191]:
features.shape

torch.Size([4178, 4000])

In [192]:
_,predicted = torch.max(torch.softmax(surrogate.predict(features.cuda(), modified_adj), dim=1), 1)

In [193]:
len(labels)

4178

In [194]:
len(predicted.cpu().numpy())

4178

In [195]:
print(sklearn.metrics.classification_report(labels[train_obs], predicted.cpu().numpy()[train_obs]))

              precision    recall  f1-score   support

           0       0.42      0.60      0.50       468
           1       0.51      0.60      0.55       494
           2       0.43      0.52      0.47       450
           3       0.34      0.12      0.18       119
           4       0.61      0.36      0.45        92
           5       0.69      0.77      0.73        53
           6       0.44      0.28      0.34       231
           7       0.49      0.27      0.35        90
           8       0.47      0.25      0.32        77
           9       0.46      0.09      0.15        69
          10       0.48      0.65      0.55       215
          11       0.43      0.12      0.19       105
          12       0.42      0.41      0.42       200
          13       0.42      0.33      0.37       136

    accuracy                           0.46      2799
   macro avg       0.47      0.38      0.40      2799
weighted avg       0.46      0.46      0.44      2799



In [196]:
a = 1