In [599]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import subgraph, k_hop_subgraph, dense_to_sparse
from torch_geometric.nn import GCNConv
from sklearn.metrics import accuracy_score
import torch.nn.functional as FF

In [600]:
def load_dataset():
    dataset=Planetoid(root='tmp/Cora', name='Cora')
    return dataset

In [601]:
class GCN_Extraction(nn.Module):
    def __init__(self, num_features, num_classes):
        super(GCN_Extraction,self).__init__()
        self.conv1=GCNConv(num_features,16)
        self.conv2=GCNConv(16, num_classes)
        
    def forward(self, x, edge_index):
        x=self.conv1(x,edge_index)
        x=torch.relu(x)
        x=self.conv2(x,edge_index)
        return torch.log_softmax(x, dim=1)

In [602]:
class GCN_Victim(nn.Module):
    def __init__(self, num_features, num_classes):
        super(GCN_Victim,self).__init__()
        self.conv1=GCNConv(num_features,128)
        self.conv2=GCNConv(128,64)
        self.conv3=GCNConv(64, 16)
        self.conv4=GCNConv(16,num_classes)

    def forward(self, x, edge_index):
        x=self.conv1(x,edge_index)
        x=torch.relu(x)
        x = self.conv2(x, edge_index)
        x=torch.relu(x)
        x=self.conv3(x,edge_index)
        x=torch.relu(x)
        x=self.conv4(x, edge_index)
        return torch.log_softmax(x, dim=1)

In [603]:
def train_Victim(model, data, epoches=200, lr=0.01, weight_decay=5e-4):
    optimizer=optim.Adam(model.parameters(),lr=lr, weight_decay=weight_decay)
    model.train()
    
    for epoch in range(epoches):
        optimizer.zero_grad()
        out=model(data.x, data.edge_index)
        loss=FF.nll_loss(out[data.train_mask],data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        if epoch % 10 ==0:
            print(f"Epoch {epoch+10}, Loss: {loss.item():.5f}")

In [604]:
def train_Extraction(model,features, adjacency_matrix, labels, epoches=400, lr=0.01, weight_decay=5e-4):
    optimizer=optim.Adam(model.parameters(),lr=lr, weight_decay=weight_decay)
    model.train()

    patience_counter = 0
    
    for epoch in range(epoches):
        optimizer.zero_grad()
        out=model(features, adjacency_matrix)
        loss=FF.nll_loss(out,labels)
        loss.backward()
        optimizer.step()
            
        if epoch % 10 ==0:
            print(f"Epoch {epoch+10}, Loss: {loss.item():.5f}")

In [605]:
def get_2hop_subgraph(data, min_size=100, max_size=150):
    num_nodes=data.num_nodes

    while True:
        node_idx=torch.randint(0,num_nodes,(1,)).item()
        subset, edge_index,_,_=k_hop_subgraph(node_idx, 2, data.edge_index, relabel_nodes=True)

        if min_size<=subset.size(0)<=max_size:
            As=torch.zeros((subset.size(0),subset.size(0)))
            As[edge_index[0],edge_index[1]]=1

            Xs=data.x[subset]

            return As.numpy(), Xs.numpy(), node_idx, subset

In [606]:
def get_distribution(num_classes):
    F=[]
    M=[]
    for c in range(num_classes):
        class_nodes=data.x[labels==c]

        feature_counts=class_nodes.sum(dim=0).numpy()
        feature_distribution=feature_counts/feature_counts.sum()
        F.append(feature_distribution)

        num_features_per_node=class_nodes.sum(dim=1).numpy()
        feature_count_distribution=np.bincount(num_features_per_node.astype(int),minlength=num_features)
        M.append(feature_count_distribution/feature_count_distribution.sum())
        
    return F,M

In [607]:
def GenerateSample(Fc, Mc, As):
    num_nodes=As.shape[0]
    Ac=torch.ones((num_nodes,num_nodes))
    Xc=torch.zeros(num_nodes, len(Fc))

    for i in range(num_nodes):
        m=np.random.choice(np.arange(len(Mc)),p=Mc)
        features=np.random.choice(len(Fc),size=m,replace=False,p=Fc)
        Xc[i,features]=1

    return Ac.numpy(), Xc.numpy()

In [608]:
def SubgraphSamplingAlgorithm(Gs,F, M, n, C):
    As, Xs=Gs
    num_nodes=As.shape[0]
    SA=[As]
    SX=[Xs]
    SL=[api_out_class(Xs,As)]

    for i in range(n):
        for c in range(C):
            Ac, Xc=GenerateSample(F[c],M[c],As)
            SA.append(Ac)
            SX.append(Xc)
            SL.append(api_out_class(Xc,Ac))

    AG_list=[dense_to_sparse(torch.tensor(a))[0] for a in SA]
    XG=torch.vstack([torch.tensor(x) for x in SX])

    SL=torch.tensor(SL,dtype=torch.long).view(-1)


    valid_mask = SL >= 0
    SL = SL[valid_mask]
    SL = SL[:XG.shape[0]]

    
    AG_combined=torch.cat([edge_index+i*num_nodes for i, edge_index in enumerate (AG_list)], dim=1)

    

    return XG, AG_combined, SL

In [609]:
def api_out_class(features, adjacency_matrix):
    features_tensor=torch.tensor(features,dtype=torch.float)
    adjacency_tensor=torch.tensor(adjacency_matrix,dtype=torch.long)

    if adjacency_tensor.ndim==2 and adjacency_tensor.shape[0]==adjacency_tensor.shape[1]:
        adjacency_tensor=torch.nonzero(adjacency_tensor, as_tuple=False).t()

    with torch.no_grad():
        logits=f(features_tensor, adjacency_tensor)
        predictions = torch.argmax(logits, dim=1)

    
    return predictions.numpy()

In [610]:
def api_out_vector(features, adjacency_matrix):
    features_tensor=torch.tensor(features,dtype=torch.float)
    adjacency_tensor=torch.tensor(adjacency_matrix,dtype=torch.long)

    if adjacency_tensor.ndim==2 and adjacency_tensor.shape[0]==adjacency_tensor.shape[1]:
        adjacency_tensor=torch.nonzero(adjacency_tensor, as_tuple=False).t()
        
    with torch.no_grad():
        logits=f(features_tensor, adjacency_tensor)

    return logits.numpy()

In [611]:
def api(model, data, mask):
    model.eval()
    out=model(data.x, data.edge_index)
    pred=out[mask].argmax(dim=1)

    return pred

In [612]:
def eval(victim_model, extracted_model,data, mask):
    original_preds_labels=api(victim_model,data,mask)
    extracted_preds_labels=api(extracted_model,data,mask)
    score=accuracy_score(original_preds_labels, extracted_preds_labels)

    return score

In [613]:
def get_f(victim_model,extraction_model,data,mask,times):
    list_out=[]
    for i in range(times):
        pred=eval(victim_model,extraction_model,data,mask)
        list_out.append(pred)
    
    return sum(list_out)/len(list_out)

In [614]:
def find_center_node_in_subgraph(center_global_idx, node_mapping):
    center_subgraph_idx = node_mapping.tolist().index(center_global_idx)
    return center_subgraph_idx

In [642]:
def approximate_nodes(Gc, api_out_vector, F, M, C, rho):
    A,X=Gc
    L=api_out_vector(X, A)
    center_idx=find_center_node_in_subgraph(center_node,node_mapping)
    
    for node in range(A.shape[0]):
        if node!=center_idx:
            for c in range(C):
                kl_div = np.exp(L[node][c]) * (L[node][c] - L[center_idx][c])
                
                if kl_div > 0:
                    print(kl_div)
                    num_new_nodes = int(kl_div * rho * np.log1p(A[node].sum())) 
                    print(num_new_nodes)
                    for _ in range(num_new_nodes):
                        new_node = A.shape[0]
                        
                        new_adj_row = np.zeros((1, A.shape[1]))
                        A = np.vstack([A, new_adj_row])
                        
                        new_adj_col = np.zeros((A.shape[0], 1))
                        A = np.hstack([A, new_adj_col])
                        
                        A[node, new_node] = 1
                        A[new_node, node] = 1
                        
                        feature_count = np.random.choice(len(M[c]), p=M[c])  
                        new_features = np.zeros_like(F[c])
                        chosen_features = np.random.choice(len(F[c]), size=feature_count, p=F[c])  
                        new_features[chosen_features] = 1 
                        
                        X = np.vstack([X, new_features])
    
    Gc=(A,X)
    return Gc

In [616]:
dataset=load_dataset()

  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
  return torch.load(f, map_location)


In [617]:
data=dataset[0]

In [618]:
labels=data.y.numpy()

In [619]:
num_classes=dataset.num_classes

In [620]:
num_features=data.x.shape[1]

In [621]:
F,M=get_distribution(num_classes)

In [622]:
runturns=100

In [623]:
f=GCN_Victim(data.x.shape[1],dataset.num_classes)

In [624]:
train_Victim(f,data)

Epoch 10, Loss: 1.94196
Epoch 20, Loss: 0.17810
Epoch 30, Loss: 0.03975
Epoch 40, Loss: 0.00255
Epoch 50, Loss: 0.00162
Epoch 60, Loss: 0.00229
Epoch 70, Loss: 0.00313
Epoch 80, Loss: 0.00359
Epoch 90, Loss: 0.00403
Epoch 100, Loss: 0.00405
Epoch 110, Loss: 0.00390
Epoch 120, Loss: 0.00375
Epoch 130, Loss: 0.00358
Epoch 140, Loss: 0.00339
Epoch 150, Loss: 0.00323
Epoch 160, Loss: 0.00313
Epoch 170, Loss: 0.00306
Epoch 180, Loss: 0.00302
Epoch 190, Loss: 0.00297
Epoch 200, Loss: 0.00293


In [625]:
As, Xs, center_node, node_mapping=get_2hop_subgraph(data)

In [626]:
Gs=(As,Xs)

In [627]:
print(f"The index of the node is {center_node}")

The index of the node is 1805


In [628]:
XG, AG, SL=SubgraphSamplingAlgorithm(Gs,F, M, n=10, C=dataset.num_classes)

In [629]:
g=GCN_Extraction(XG.shape[1], dataset.num_classes)

In [630]:
train_Extraction(g,XG,AG,SL)

Epoch 10, Loss: 1.96491
Epoch 20, Loss: 0.80763
Epoch 30, Loss: 0.13576
Epoch 40, Loss: 0.02696
Epoch 50, Loss: 0.01270
Epoch 60, Loss: 0.01068
Epoch 70, Loss: 0.01155
Epoch 80, Loss: 0.01318
Epoch 90, Loss: 0.01454
Epoch 100, Loss: 0.01502
Epoch 110, Loss: 0.01468
Epoch 120, Loss: 0.01394
Epoch 130, Loss: 0.01312
Epoch 140, Loss: 0.01235
Epoch 150, Loss: 0.01171
Epoch 160, Loss: 0.01116
Epoch 170, Loss: 0.01069
Epoch 180, Loss: 0.01028
Epoch 190, Loss: 0.00991
Epoch 200, Loss: 0.00957
Epoch 210, Loss: 0.00927
Epoch 220, Loss: 0.00899
Epoch 230, Loss: 0.00874
Epoch 240, Loss: 0.00850
Epoch 250, Loss: 0.00829
Epoch 260, Loss: 0.00810
Epoch 270, Loss: 0.00792
Epoch 280, Loss: 0.00776
Epoch 290, Loss: 0.00760
Epoch 300, Loss: 0.00746
Epoch 310, Loss: 0.00734
Epoch 320, Loss: 0.00722
Epoch 330, Loss: 0.00711
Epoch 340, Loss: 0.00700
Epoch 350, Loss: 0.00691
Epoch 360, Loss: 0.00682
Epoch 370, Loss: 0.00673
Epoch 380, Loss: 0.00666
Epoch 390, Loss: 0.00659
Epoch 400, Loss: 0.00652


In [631]:
print(get_f(f,g,data,data.val_mask,runturns))

0.7999999999999985


In [643]:
Gs_with_approximated_nodes = approximate_nodes(Gs, api_out_vector, F, M, C=dataset.num_classes,rho=2)

0.15803975
0
2.913445
4
0.06019092
0
0.0059899776
0
0.080194116
0
0.25092846
0
0.00054511946
0
1.7204035e-05
0
0.0004789889
0
6.9592375e-08
0
0.0001668116
0
0.0021905038
0
4.0330793e-05
0
0.0029407365
0
0.0022450245
0
0.0009100209
0
0.001931866
0
8.740703e-05
0
0.0003179343
0
1.4388102e-06
0
0.0013771738
0
0.0023787993
0
1.9840028e-05
0
2.410854e-05
0
0.0026588137
0
0.001846148
0
0.0018364972
0
0.0015189915
0
0.0029809498
0
1.2076904e-05
0
0.0027813877
0
0.002598373
0
0.0025588414
0
0.13850905
0
0.10863025
0
0.015259831
0
0.036703207
0
0.006736813
0
0.2803359
0
0.0066120243
0
6.27215
13
0.00027915617
0
0.0049441042
0
0.005613043
0
0.0028533498
0
0.03101839
0
4.570126
10
0.0028214774
0
2.1524624e-05
0
0.0034956774
0
0.024788411
0
0.031452004
0
0.9555662
2
0.007875177
0
0.0008095057
0
0.0013002433
0
0.031275857
0
0.03832355
0
0.035310272
0
0.019939762
0
0.02748624
0
0.0008788744
0
0.039570834
0
0.0021767856
0
2.1378623e-06
0
0.000105253035
0
0.0029408562
0
0.00053188246
0
0.0002721294
0


In [644]:
print(f"The size of tensors with algo1{Gs[0].shape,Gs[1].shape}")
print(f"The size of tensors with algo2{Gs_with_approximated_nodes[0].shape,Gs_with_approximated_nodes[1].shape}")

The size of tensors with algo1((119, 119), (119, 1433))
The size of tensors with algo2((260, 260), (260, 1433))


In [645]:
XG1, AG1, SL1=SubgraphSamplingAlgorithm(Gs_with_approximated_nodes,F, M, n=10, C=dataset.num_classes)

In [646]:
g1=GCN_Extraction(XG1.shape[1], dataset.num_classes)

In [647]:
train_Extraction(g1,XG1,AG1,SL1)

Epoch 10, Loss: 1.94871
Epoch 20, Loss: 0.83225
Epoch 30, Loss: 0.16972
Epoch 40, Loss: 0.03132
Epoch 50, Loss: 0.01448
Epoch 60, Loss: 0.01202
Epoch 70, Loss: 0.01278
Epoch 80, Loss: 0.01418
Epoch 90, Loss: 0.01528
Epoch 100, Loss: 0.01560
Epoch 110, Loss: 0.01522
Epoch 120, Loss: 0.01445
Epoch 130, Loss: 0.01358
Epoch 140, Loss: 0.01279
Epoch 150, Loss: 0.01213
Epoch 160, Loss: 0.01158
Epoch 170, Loss: 0.01111
Epoch 180, Loss: 0.01068
Epoch 190, Loss: 0.01030
Epoch 200, Loss: 0.00996
Epoch 210, Loss: 0.00965
Epoch 220, Loss: 0.00937
Epoch 230, Loss: 0.00911
Epoch 240, Loss: 0.00887
Epoch 250, Loss: 0.00865
Epoch 260, Loss: 0.00846
Epoch 270, Loss: 0.00827
Epoch 280, Loss: 0.00810
Epoch 290, Loss: 0.00795
Epoch 300, Loss: 0.00780
Epoch 310, Loss: 0.00767
Epoch 320, Loss: 0.00754
Epoch 330, Loss: 0.00742
Epoch 340, Loss: 0.00731
Epoch 350, Loss: 0.00721
Epoch 360, Loss: 0.00712
Epoch 370, Loss: 0.00703
Epoch 380, Loss: 0.00694
Epoch 390, Loss: 0.00686
Epoch 400, Loss: 0.00679


In [648]:
print(get_f(f,g1,data,data.val_mask,runturns))

0.8119999999999987
