In [25]:
import torch
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import subgraph, k_hop_subgraph, dense_to_sparse
from torch_geometric.nn import GCNConv
from sklearn.metrics import accuracy_score
import torch.nn.functional as FF

In [26]:
dataset=Planetoid(root='tmp/Cora', name='Cora')
data=dataset[0]

  if osp.exists(f) and torch.load(f) != _repr(self.pre_transform):
  if osp.exists(f) and torch.load(f) != _repr(self.pre_filter):
  return torch.load(f, map_location)


In [27]:
def get_valid_2hop_subgraph(data, min_size=100, max_size=150):
    num_nodes=data.num_nodes

    while True:
        node_idx=torch.randint(0,num_nodes,(1,)).item()
        subset, edge_index,_,_=k_hop_subgraph(node_idx, 2, data.edge_index, relabel_nodes=True)

        if min_size<=subset.size(0)<=max_size:
            As=torch.zeros((subset.size(0),subset.size(0)))
            As[edge_index[0],edge_index[1]]=1

            Xs=data.x[subset]

            return As.numpy(), Xs.numpy(), node_idx, subset

In [28]:
class GCN(nn.Module):
    def __init__(self, num_features, num_classes):
        super(GCN,self).__init__()
        self.conv1=GCNConv(num_features,16)
        self.conv2=GCNConv(16, num_classes)
        
    def forward(self, x, edge_index):
        x=self.conv1(x,edge_index)
        x=torch.relu(x)
        x=self.conv2(x,edge_index)
        return torch.log_softmax(x, dim=1)

In [29]:
class GCN_victim(nn.Module):
    def __init__(self, num_features, num_classes):
        super(GCN_victim,self).__init__()
        self.conv1=GCNConv(num_features,128)
        self.conv2=GCNConv(128,64)
        self.conv3=GCNConv(64, 16)
        self.conv4=GCNConv(16,num_classes)

    def forward(self, x, edge_index):
        x=self.conv1(x,edge_index)
        x=torch.relu(x)
        x=FF.dropout(x,training=self.training)
        x = self.conv2(x, edge_index)
        x=torch.relu(x)
        x=FF.dropout(x,training=self.training)
        x=self.conv3(x,edge_index)
        x=torch.relu(x)
        x=self.conv4(x, edge_index)
        return torch.log_softmax(x, dim=1)

In [30]:
def GenerateSample(Fc, Mc, As):
    num_nodes=As.shape[0]
    Ac=torch.ones((num_nodes,num_nodes))
    Xc=torch.zeros(num_nodes, len(Fc))

    for i in range(num_nodes):
        m=np.random.choice(np.arange(len(Mc)),p=Mc)
        features=np.random.choice(len(Fc),size=m,replace=False,p=Fc)
        Xc[i,features]=1

    return Ac.numpy(), Xc.numpy()

In [31]:
def SubgraphSamplingAlgorithm(Gs, api_function, F, M, n, C):
    As, Xs=Gs
    num_nodes=As.shape[0]
    SA=[As]
    SX=[Xs]
    SL=[api_function(Xs,As)]

    for i in range(n):
        for c in range(C):
            Ac, Xc=GenerateSample(F[c],M[c],As)
            SA.append(Ac)
            SX.append(Xc)
            SL.append(api_function(Xc,Ac))

    AG_list=[dense_to_sparse(torch.tensor(a))[0] for a in SA]
    XG=torch.vstack([torch.tensor(x) for x in SX])

    SL=torch.tensor(SL,dtype=torch.long).view(-1)


    valid_mask = SL >= 0
    SL = SL[valid_mask]
    SL = SL[:XG.shape[0]]

    
    AG_combined=torch.cat([edge_index+i*num_nodes for i, edge_index in enumerate (AG_list)], dim=1)
    num_nodes_combined=XG.shape[0]


    
    g=GCN(XG.shape[1], C)
    optimizer=optim.Adam(g.parameters(),lr=0.01)
    loss_fn=nn.CrossEntropyLoss()


    g.train()
    for epoch in range(200):
        optimizer.zero_grad()
        out=g(XG,AG_combined)
        loss=loss_fn(out,SL)
        loss.backward()
        optimizer.step()

    return g

In [32]:
def api_wrapper(features, adjacency_matrix):
    features_tensor=torch.tensor(features,dtype=torch.float)
    adjacency_tensor=torch.tensor(adjacency_matrix,dtype=torch.long)

    if adjacency_tensor.ndim==2 and adjacency_tensor.shape[0]==adjacency_tensor.shape[1]:
        adjacency_tensor=torch.nonzero(adjacency_tensor, as_tuple=False).t()

    with torch.no_grad():
        logits=f(features_tensor, adjacency_tensor)
        predictions = torch.argmax(logits, dim=1)

    
    return predictions.numpy()

In [33]:
def model_out(features, adjacency_matrix):
    features_tensor=torch.tensor(features,dtype=torch.float)
    adjacency_tensor=torch.tensor(adjacency_matrix,dtype=torch.long)

    if adjacency_tensor.ndim==2 and adjacency_tensor.shape[0]==adjacency_tensor.shape[1]:
        adjacency_tensor=torch.nonzero(adjacency_tensor, as_tuple=False).t()
        
    with torch.no_grad():
        logits=f(features_tensor, adjacency_tensor)

    return logits.numpy()

In [34]:
As, Xs, center_node, node_mapping=get_valid_2hop_subgraph(data)

In [35]:
Gs=(As,Xs)
print(f"Center node:{center_node}, 2-hop subgraph size:{As.shape[0]}")

Center node:2001, 2-hop subgraph size:112


In [36]:
f=GCN_victim(Xs.shape[1],dataset.num_classes)

In [37]:
api_function=api_wrapper

In [38]:
labels=data.y.numpy()
num_classes=dataset.num_classes
num_features=data.x.shape[1]

In [39]:
F=[]
M=[]

In [40]:
for c in range(num_classes):
    class_nodes=data.x[labels==c]

    feature_counts=class_nodes.sum(dim=0).numpy()
    feature_distribution=feature_counts/feature_counts.sum()
    F.append(feature_distribution)

    num_features_per_node=class_nodes.sum(dim=1).numpy()
    feature_count_distribution=np.bincount(num_features_per_node.astype(int),minlength=num_features)
    M.append(feature_count_distribution/feature_count_distribution.sum())

In [41]:
g=SubgraphSamplingAlgorithm(Gs, api_function, F, M, n=10, C=dataset.num_classes)

In [42]:
def evaluate(original_model, extracted_model, test_features, test_adj):
    original_preds_labels=api_wrapper(test_features, test_adj)
    test_adj_tensor=torch.tensor(test_adj,dtype=torch.float)
    edge_index, edge_weight=dense_to_sparse(test_adj_tensor)
    with torch.no_grad():
        extracted_preds=extracted_model(torch.tensor(test_features,dtype=torch.float),edge_index)
        
    extracted_preds_labels=extracted_preds.argmax(dim=1).numpy()

    f=accuracy_score(original_preds_labels, extracted_preds_labels)

    return f

In [50]:
f_list=[]
for i in range(1000):
    f_score=evaluate(api_function,g,Xs,As)
    f_list.append(f_score)
print(f"Fidelity of the extracted model compared to the original model: {sum(f_list)/len(f_list):.4f}")

Fidelity of the extracted model compared to the original model: 0.4760


In [44]:
def find_center_node_in_subgraph(center_global_idx, node_mapping):
    center_subgraph_idx = node_mapping.tolist().index(center_global_idx)
    return center_subgraph_idx

In [45]:
def approximate_nodes(Gc, api_function_model, F, M, C, rho,alpha):
    A,X=Gc
    L=api_function_model(X, A)
    center_idx=find_center_node_in_subgraph(center_node,node_mapping)
    for node in range(A.shape[0]):
        if node!=center_idx:
            Dn=L[node]-L[center_idx]
            for c in range(C):
                if Dn[c]>0:
                    num_new_nodes=int(Dn[c]*rho*A[node].sum()*alpha)
                    for _ in range(num_new_nodes):
                        new_node = A.shape[0]
                        
                        new_adj_row = np.zeros((1, A.shape[1]))
                        A = np.vstack([A, new_adj_row])
                        
                        new_adj_col = np.zeros((A.shape[0], 1))
                        A = np.hstack([A, new_adj_col])
                        
                        A[node, new_node] = 1
                        A[new_node, node] = 1
                        
                        feature_count = np.random.choice(len(M[c]), p=M[c])  
                        new_features = np.zeros_like(F[c])
                        chosen_features = np.random.choice(len(F[c]), size=feature_count, p=F[c])  
                        new_features[chosen_features] = 1  

                        X = np.vstack([X, new_features])
    Gc=(A,X)
    return Gc

In [46]:
Gs_with_approximated_nodes = approximate_nodes(Gs, model_out, F, M, C=dataset.num_classes,rho=2, alpha=3)

In [49]:
print(Gs_with_approximated_nodes[0].shape,Gs_with_approximated_nodes[1].shape)

(318, 318) (318, 1433)


In [47]:
g1=SubgraphSamplingAlgorithm(Gs_with_approximated_nodes, api_function, F, M, n=10, C=dataset.num_classes)

In [51]:
f1_list=[]
for i in range(1000):
    f1_score=evaluate(api_function,g1,Xs,As)
    f1_list.append(f1_score)
print(f"Fidelity of the extracted model compared to the original model: {sum(f1_list)/len(f1_list):.4f}")

Fidelity of the extracted model compared to the original model: 0.5520
