In [1]:
import os
import os.path as osp

from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import negative_sampling
from torch_geometric.nn import GCNConv
import numpy as np
import torch
from torch.nn import Sequential, Linear, ReLU
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, accuracy_score

from utils import (
    get_link_labels,
    prediction_fairness,
)

from torch_geometric.utils import train_test_split_edges

device = "cuda" if torch.cuda.is_available() else "cpu"
import wandb
import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch_geometric.utils import to_undirected, to_networkx, k_hop_subgraph, is_undirected
from torch_geometric.data import Data
from torch_geometric.loader import GraphSAINTRandomWalkSampler
from torch_geometric.seed import seed_everything

In [3]:
from model.gcn import GCN
from model.deletegcn import GCNDelete
import torch.nn as nn

In [4]:
import numpy as np
import torch
from sklearn.multioutput import MultiOutputClassifier
from torch_sparse import SparseTensor
from sklearn.metrics import (
    roc_auc_score,
    make_scorer,
    balanced_accuracy_score,
)
from sklearn.neural_network import MLPClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn import model_selection, pipeline, metrics

# Metrics
from fairlearn.metrics import (
    demographic_parity_difference,
    equalized_odds_difference,
)
from itertools import combinations_with_replacement

In [5]:
from torch_geometric import seed_everything

seed_everything(1888)

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [7]:
from training_args import parse_args
args=parse_args()

In [8]:
from model.gcn import GCN

In [9]:
dataset = "citeseer" #"cora" "pubmed"
path = osp.join(osp.dirname(osp.realpath('__file__')), "..", "data", dataset)
dataset = Planetoid(path, dataset, transform=T.NormalizeFeatures())

In [10]:
pwd

'/home/jupyter/FairDrop'

In [11]:
path

'/home/jupyter/FairDrop/../data/citeseer'

In [12]:
test_seeds = [0,1,2,3,4,5]
acc_auc = []
fairness = []
acc_auc_ori = []
fairness_ori = []

In [13]:
data = dataset[0]

In [14]:
data

Data(x=[3327, 3703], edge_index=[2, 9104], y=[3327], train_mask=[3327], val_mask=[3327], test_mask=[3327])

In [15]:
args.in_dim=data.num_features

In [16]:
protected_attribute = data.y
Y = torch.LongTensor(protected_attribute).to(device)
data.train_mask = data.val_mask = data.test_mask = data.y = None
data = train_test_split_edges(data, val_ratio=0.1, test_ratio=0.2)
data = data.to(device)
num_classes = len(np.unique(protected_attribute))
N = data.num_nodes



In [17]:
N

3327

In [18]:
6374/100*3

191.22

In [19]:
data.train_pos_edge_index

tensor([[   0,    1,    1,  ..., 3324, 3325, 3326],
        [ 628,  158, 2919,  ...,  268, 1643,   33]], device='cuda:0')

In [20]:


Y = torch.LongTensor(protected_attribute).to(device)
Y_diff = (
    Y[data.train_pos_edge_index[0, :]] != Y[data.train_pos_edge_index[1, :]]
).to(device)

Y_same = (
    Y[data.train_pos_edge_index[0, :]] == Y[data.train_pos_edge_index[1, :]]
).to(device)


In [21]:
torch.sum(Y_diff)
torch.sum(Y_same)
diff=Y_diff.nonzero().squeeze()
same=Y_same.nonzero().squeeze()
same

tensor([   0,    1,    2,  ..., 6370, 6371, 6373], device='cuda:0')

In [22]:
edge_to_delete=100
ratio=3
diff_size=int(edge_to_delete*(ratio)/(ratio+1))
same_size=int(edge_to_delete*(1)/(ratio+1))
idx_diff = torch.randperm(diff.shape[0])[:diff_size]
df_diff_idx = diff[idx_diff]
idx_same = torch.randperm(same.shape[0])[:same_size]
df_same_idx = same[idx_same]
df_global_idx=torch.cat((df_diff_idx,df_same_idx),0)

In [23]:
dr_mask = torch.ones(data.train_pos_edge_index.shape[1], dtype=torch.bool)
dr_mask[df_global_idx] = False
dr_mask=dr_mask.to(device)

df_mask = torch.zeros(data.train_pos_edge_index.shape[1], dtype=torch.bool)
df_mask[df_global_idx] = True
df_mask=df_mask.to(device)

In [24]:
###

In [25]:
len(df_diff_idx)

75

In [26]:
Y[data.train_pos_edge_index[0,df_diff_idx]]

tensor([1, 3, 0, 2, 3, 1, 3, 3, 3, 2, 1, 2, 3, 2, 1, 0, 0, 3, 0, 1, 2, 4, 5, 0,
        2, 0, 3, 2, 0, 0, 1, 3, 5, 1, 3, 0, 3, 3, 2, 1, 1, 3, 1, 2, 3, 3, 3, 0,
        0, 0, 0, 0, 3, 1, 2, 2, 4, 1, 1, 3, 5, 2, 2, 2, 1, 0, 4, 4, 0, 5, 1, 0,
        2, 5, 0], device='cuda:0')

In [27]:
Y[data.train_pos_edge_index[1,df_diff_idx]]

tensor([2, 2, 3, 0, 2, 4, 4, 2, 1, 4, 2, 1, 5, 1, 4, 4, 1, 4, 3, 2, 1, 0, 4, 4,
        3, 1, 0, 1, 1, 4, 0, 1, 4, 5, 1, 1, 0, 2, 1, 0, 5, 4, 4, 1, 2, 4, 1, 4,
        2, 1, 1, 1, 0, 3, 1, 1, 5, 0, 5, 1, 4, 4, 1, 5, 3, 3, 5, 0, 5, 2, 0, 2,
        1, 0, 3], device='cuda:0')

In [28]:
torch.unique(Y)

tensor([0, 1, 2, 3, 4, 5], device='cuda:0')

In [29]:
N  #num of nodes

3327

In [30]:
node_all=torch.arange(N)
node_all

tensor([   0,    1,    2,  ..., 3324, 3325, 3326])

In [31]:
N_0=node_all[Y[node_all]==0]
N_1=node_all[Y[node_all]==1]
N_2=node_all[Y[node_all]==2]
N_3=node_all[Y[node_all]==3]
N_4=node_all[Y[node_all]==4]
N_5=node_all[Y[node_all]==5]

In [32]:
pos_node=[]
pos_size=[]
neg_node=[]
neg_size=8
for n in node_all.tolist():
    l_hop_node, l_hop_edge, l_hop_index, l_hop_mask = k_hop_subgraph(
                n, 
                1, 
                data.train_pos_edge_index,
                num_nodes=data.num_nodes)
    pos_node.append(l_hop_node[1:])
    if(Y[n]==0):
        neg_idx=torch.randperm(len(N_0))[0:neg_size]
        neg_sample=N_0[neg_idx]
    elif(Y[n]==1):
        neg_idx=torch.randperm(len(N_1))[0:neg_size]
        neg_sample=N_1[neg_idx]
    elif(Y[n]==2):
        neg_idx=torch.randperm(len(N_2))[0:neg_size]
        neg_sample=N_2[neg_idx]
    elif(Y[n]==3):
        neg_idx=torch.randperm(len(N_3))[0:neg_size]
        neg_sample=N_3[neg_idx]
    elif(Y[n]==4):
        neg_idx=torch.randperm(len(N_4))[0:neg_size]
        neg_sample=N_4[neg_idx]
    else:
        neg_idx=torch.randperm(len(N_5))[0:neg_size]
        neg_sample=N_5[neg_idx]  
    neg_node.append(neg_sample)
    pos_size.append(len(l_hop_node)-1)

In [33]:
##Aug 

def init_weights(m): 
    if isinstance(m, nn.Linear): 
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu') 
        nn.init.constant_(m.bias, 0)  

In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class MLPA(torch.nn.Module):

    def __init__(self, in_feats, dim_h, dim_z):
        super(MLPA, self).__init__()
        
        self.gcn_mean = torch.nn.Sequential(
                torch.nn.Linear(in_feats, dim_h),
                torch.nn.BatchNorm1d(dim_h),
                torch.nn.LeakyReLU(),
                torch.nn.Linear(dim_h, dim_z)
                )

    def forward(self, hidden):
        # GCN encoder
        self.gcn_mean.apply(init_weights)
        Z = self.gcn_mean(hidden)
        # inner product decoder
        adj_logits = Z @ Z.T
        return adj_logits

class PGNNMask(torch.nn.Module):
    def __init__(self, features, n_hidden=64, temperature=1) -> None:
        super(PGNNMask,self).__init__()
        #self.g_encoder = GCN_Body(in_feats = features.shape[1], n_hidden = n_hidden, out_feats = n_hidden, dropout = 0.1, nlayer = 1)
        self.Aaug = MLPA(in_feats = n_hidden, dim_h = n_hidden, dim_z =features.shape[1])
        self.temperature = temperature
        
    def _sample_graph(self, sampling_weights, temperature=1.0, bias=0.0, training=True):
        if training:
            bias = bias + 0.0001  # If bias is 0, we run into problems
            eps = (bias - (1-bias)) * torch.rand(sampling_weights.size()) + (1-bias)
            gate_inputs = torch.log(eps) - torch.log(1 - eps)
            gate_inputs=gate_inputs.cuda()
            gate_inputs = (gate_inputs + sampling_weights) / temperature
            graph =  torch.sigmoid(gate_inputs)
        else:
            graph = torch.sigmoid(sampling_weights)
        return graph
    
    def normalize_adj(self,adj):
        adj.fill_diagonal_(1)
        # normalize adj with A = D^{-1/2} @ A @ D^{-1/2}
        D_norm = torch.diag(torch.pow(adj.sum(1), -0.5)).cuda()
        adj = D_norm @ adj @ D_norm
        return adj

    def forward(self, h, alpha = 0.5, adj_orig = None):
        #h = self.g_encoder(adj, x)

        # Edge perturbation
        adj_logits = self.Aaug(h)
        ## sample a new adj
        edge_probs = torch.sigmoid(adj_logits)

        if (adj_orig is not None) :
            edge_probs = alpha*edge_probs + (1-alpha)*adj_orig

        # sampling 
        adj_sampled =self._sample_graph(adj_logits)
        # making adj_sampled symmetric
        adj_sampled = adj_sampled.triu(1)
        adj_sampled = adj_sampled + adj_sampled.T
        adj_sampled = self.normalize_adj(adj_sampled)


        return adj_sampled, adj_logits

In [35]:
sum(pos_size)/N

1.9158400961827473

In [36]:
from itertools import product
full_pos_edge_index1=data.train_pos_edge_index
print(full_pos_edge_index1.shape)

torch.Size([2, 6374])


In [37]:
data.train_pos_edge_index

tensor([[   0,    1,    1,  ..., 3324, 3325, 3326],
        [ 628,  158, 2919,  ...,  268, 1643,   33]], device='cuda:0')

In [38]:
is_undirected(data.train_pos_edge_index) ## data.train_pos_edge_index already undirected

True

In [39]:
full_edges_set=set(list((data.train_pos_edge_index[0,i].tolist(), data.train_pos_edge_index[1,i].tolist()) for i in range(data.train_pos_edge_index.shape[1]) ))
full_node_list=data.train_pos_edge_index[0,:].tolist()
delete_node_list=list(set(data.train_pos_edge_index[:, df_mask][0,:].tolist()+data.train_pos_edge_index[:, df_mask][1,:].tolist()))
#forget_edges_set=list(product(data.train_pos_edge_index[0,:].tolist(), data.train_pos_edge_index[1,:].tolist()))

In [40]:
#full_edges_set
import copy


In [41]:
full_pos_edge_index=copy.deepcopy(full_pos_edge_index1)
full_pos_edge_index

tensor([[   0,    1,    1,  ..., 3324, 3325, 3326],
        [ 628,  158, 2919,  ...,  268, 1643,   33]], device='cuda:0')

In [42]:
data.train_pos_edge_index[:, df_mask].shape

torch.Size([2, 100])

In [43]:
data.train_pos_edge_index.shape

torch.Size([2, 6374])

In [44]:
full_pos_edge_index1[:,dr_mask].shape

torch.Size([2, 6274])

In [45]:
re_edge_num=full_pos_edge_index1[:,dr_mask].shape[1]

In [46]:
##potential edges
full_pos_edge_index=copy.deepcopy(full_pos_edge_index1)[:,dr_mask]
existing_edges=full_edges_set
k=1
#for target in range(edge_to_delete):
for target in df_diff_idx:
    _, l_hop_edge, _, l_hop_mask = k_hop_subgraph(
            [data.train_pos_edge_index[0,:][target]], 
            k, 
            data.train_pos_edge_index,
            num_nodes=data.num_nodes)

    ldset1=l_hop_edge.flatten().unique()
    ldset1=ldset1[ldset1!=data.train_pos_edge_index[0,:][target].item()]

    _, r_hop_edge, _, r_hop_mask = k_hop_subgraph(
            [data.train_pos_edge_index[1,:][target]], 
            k, 
            data.train_pos_edge_index,
            num_nodes=data.num_nodes)

    rdset1=r_hop_edge.flatten().unique()
    rdset1=rdset1[rdset1!=data.train_pos_edge_index[1,:][target].item()]

    combine = list(product(ldset1.tolist(), rdset1.tolist()))
    ind=[True if Y[a[0]]!=Y[a[1]] else False for a in combine ]
    sele_pair=[combine[i] for i in range(len(ind)) if ind[i]==True]
    #print("sele_pair",len(sele_pair))
    ind1=[False if (a[0],a[1]) in existing_edges else True for a in sele_pair]
    sele_pair1=[sele_pair[i] for i in range(len(ind1)) if ind1[i]==True]
    #print("sele_pair1",len(sele_pair1))
    existing_edges=set(list(sele_pair1)+list(existing_edges)+list((a[1],a[0])for a in sele_pair1))
    #print("existing",len(existing_edges))
    n=len(sele_pair1)
    #add=n
    #add_pair_ind=torch.randperm(len(sele_pair))[:add]
    add_pair_ind=np.arange(len(sele_pair1))
    add_pair=[sele_pair1[i] for i in add_pair_ind]
    add_matrix_0=[add_pair[i][0] for i in range(len(add_pair))] 
    add_matrix_1=[add_pair[i][1] for i in range(len(add_pair))] 
    add_matrix1=torch.tensor([add_matrix_0,add_matrix_1]).to(device)
    add_matrix2=torch.tensor([add_matrix_1,add_matrix_0]).to(device)
    full_pos_edge_index=torch.cat((full_pos_edge_index,add_matrix1),1)
    full_pos_edge_index=torch.cat((full_pos_edge_index,add_matrix2),1)
full_pos_edge_index=full_pos_edge_index.long()

In [47]:
full_pos_edge_index.shape

torch.Size([2, 6680])

In [48]:
data.aug_pos_edge_index=full_pos_edge_index

In [49]:
model_ckpt = torch.load(os.path.join(args.checkpoint_dir, 'model_final.pt'), map_location=device)
model_ori=GCN(args)
model_ori.load_state_dict(model_ckpt['model_state'], strict=False)
model_ori=model_ori.to(device)
model_ori.eval()

GCN(
  (conv1): GCNConv(3703, 128)
  (conv2): GCNConv(128, 64)
)

In [50]:
from model.Auggcn import AugGCN

In [51]:

model_aug = AugGCN(args)
model_ckpt = torch.load(os.path.join(args.checkpoint_dir, 'model_final.pt'), map_location=device)
model_aug.load_state_dict(model_ckpt['model_state'], strict=False)
model_aug = model_aug.to(device)
model_aug.eval()

AugGCN(
  (conv1): GCNConv(3703, 128)
  (conv2): GCNConv(128, 64)
)

In [77]:
z_ori= model_ori(data.x, data.train_pos_edge_index,return_all_emb=False)
z_ori.detach()
z_ori.shape
aug = PGNNMask(data.x.cuda(), n_hidden=64, temperature=6).to(device)
aug=aug.to(device)
loss_fn = nn.BCEWithLogitsLoss(reduction='none')
optimizer_aug = torch.optim.Adam(aug.parameters(), lr = 0.001)

adj_sampled_aug,adj_logits_aug = aug(z_ori.cuda())

weight_aug1=adj_sampled_aug[full_pos_edge_index[0],full_pos_edge_index[1]]
#weight_aug[:6374]=1.0
#weight_aug[:6374]=weight_aug[:6374].detach()
weight_ori=torch.zeros(full_pos_edge_index.shape[1]).to(device)
weight_ori[:re_edge_num]=1.0
weight_ori=weight_ori.detach()
m=1-weight_ori
m=m.detach()
weight_aug=weight_ori+m*weight_aug1*800
z_aug = model_aug(data.x, data.aug_pos_edge_index,weight_aug,return_all_emb=False)

In [78]:
###aug 
#torch.nn.utils.clip_grad_norm_(aug.parameters(),max_norm=1.0)

In [79]:
node_record=[]
k=2
for target in range(edge_to_delete):
    l_hop_node, l_hop_edge, l_hop_index, l_hop_mask = k_hop_subgraph(
            [data.train_pos_edge_index[:, df_mask][0,:][target]], 
            k, 
            data.train_pos_edge_index,
            num_nodes=data.num_nodes)   

    r_hop_node, r_hop_edge, _, r_hop_mask = k_hop_subgraph(
            [data.train_pos_edge_index[:, df_mask][1,:][target]], 
            k, 
            data.train_pos_edge_index,
            num_nodes=data.num_nodes)
    node_record+=l_hop_node.tolist()
    node_record+=r_hop_node.tolist()

In [80]:
node=set(node_record)
print(len(node))
target_node=torch.tensor(list(node))

963


In [81]:
all_idx=torch.arange(len(weight_aug))

In [82]:

##real
batch_size=256
weight=0.0000001
#weight=0
l_num=0
max_edge=20
max_per_iter=3
add_edge=0
edge_add=[]
for epoch in range(5000):
    batch_idx=target_node[torch.randperm(len(target_node))[0:batch_size]]
    #aug.train()
    for i in range(1):
        weight_aug1=adj_sampled_aug[full_pos_edge_index[0],full_pos_edge_index[1]]
        weight_aug=weight_ori+m*weight_aug1*800
        z_aug = model_aug(data.x, data.aug_pos_edge_index,weight_aug,return_all_emb=False)
        #print(z_aug.requires_grad)
        pos_sim_padded = []
        neg_sim_batch=[]
        mask_padded = []
        max_num_pos=5
        #batch_idx=torch.randperm(N)[0:batch_size]
        #print(batch_idx)
        #batch_idx=target_node[torch.randperm(len(target_node))[0:batch_size]]
        for idx in batch_idx:
            #print(idx)
            z_pos_aug=z_aug[pos_node[idx],:]
            z_neg_aug=z_aug[neg_node[idx],:]
            pos_sim=torch.matmul(z_pos_aug,z_aug[idx].unsqueeze(1)).T
            #print("pos_sim",pos_sim.shape)
            neg_sim=torch.matmul(z_neg_aug,z_aug[idx].unsqueeze(1)).T
            #print("neg_sim",neg_sim.shape)
            neg_sim_batch.append(neg_sim)
            padding_size = max_num_pos - pos_sim.size(1)
            #print(F.pad(pos_sim, (0, padding_size), value=0.0))
            pos_sim_padded.append(F.pad(pos_sim, (0, padding_size), value=0.0))  # Pad with 0 for missing positives
            #print("append",len(pos_sim_padded))
            mask_padded.append(F.pad(torch.ones_like(pos_sim), (0, padding_size), value=0.0))  # Mask for real values
            # Stack the padded similarities and masks
        pos_sim_padded = torch.stack(pos_sim_padded,dim=0).squeeze(dim=1)  # [batch_size, max_num_positives]
        neg_sim_batch=torch.stack(neg_sim_batch,dim=0).squeeze(dim=1)
        mask_padded = torch.stack(mask_padded,dim=0).squeeze(dim=1)  # [batch_size, max_num_positives]
        #z_aug = model_aug(data.x, data.aug_pos_edge_index,weight_aug,return_all_emb=False)
        # Concatenate negative and padded positive similarities to form logits
        logits = torch.cat([neg_sim_batch, pos_sim_padded], dim=1)  # [batch_size, num_negatives + max_num_positives]

        # Create labels: 0 for negative pairs, 1 for real positive pairs (mask out padded entries)
        labels = torch.cat([torch.zeros_like(neg_sim_batch), mask_padded], dim=1)  # [batch_size, num_negatives + max_num_positives]
        link_probs = logits.sigmoid()
        # Compute Binary Cross-Entropy loss without reduction to apply mask
        loss = loss_fn(link_probs, labels)

        # Mask the padded positions by multiplying the loss by the mask (for positive part)
        mask = torch.cat([torch.ones_like(neg_sim_batch), mask_padded], dim=1)  # Mask for ignoring padded logits
        #print("z_aug",z_aug.requires_grad)
        loss = loss * mask  # Apply mask to ignore padded positions
        for name, param in model_aug.named_parameters():
            param.requires_grad=False
        l1_loss=torch.abs(weight_aug[re_edge_num:]).sum()
        loss_total=(loss.sum()/mask.sum())+l1_loss*weight
        
        #optimizer_aug.zero_grad()
        loss_total.backward(retain_graph=True)
        #print("loss_total",loss_total)
        optimizer_aug.step()
        optimizer_aug.zero_grad()
        z_aug_detach=z_aug.detach()
        """
        for name, param in aug.named_parameters():
            param.requires_grad=False
        print("z_aug1",z_aug.requires_grad)
        """
        adj_sampled_aug,adj_logits_aug = aug(z_aug_detach)
        #adj_sampled_aug,adj_logits_aug = aug(z_ori.cuda())
        #l=weight_aug[re_edge_num:]>0.9
        l=(m*weight_aug1*800)>0.9
        l_num=l.sum()
    if(l_num>0):
            print("Add",loss.sum()/mask.sum())
            print("l",l_num)
            print("weight_max",(m*weight_aug1*800).max())
            num_add=min(l_num,max_per_iter*2)
            if(num_add%2==0):
                v,idx=torch.topk((m*weight_aug1*800),num_add)
                weight_ori[idx]=1
                m=1-weight_ori
                for i in idx:
                    edge_add.append((data.aug_pos_edge_index[0,i].item(),data.aug_pos_edge_index[1,i].item()))
                print("add",data.aug_pos_edge_index[:,idx])
                if(len(edge_add)>max_edge*2):
                    print("finish")
                    break
            
    
#weight_aug[:6374]=weight_aug[:6374].detach()
    if(epoch%100==0):
        print("epoch",epoch)
        print("loss",loss.sum()/mask.sum())
        print("weight_max",(m*weight_aug1*800).max())
        dif=(weight_aug-weight_ori).sum()
        print("dif",dif)
        print("l",l_num)
        
        

epoch 0
loss tensor(0.8565, device='cuda:0', grad_fn=<DivBackward0>)
weight_max tensor(0.2410, device='cuda:0', grad_fn=<MaxBackward1>)
dif tensor(97.6555, device='cuda:0', grad_fn=<SumBackward0>)
l tensor(0, device='cuda:0')
epoch 100
loss tensor(0.8431, device='cuda:0', grad_fn=<DivBackward0>)
weight_max tensor(0.2413, device='cuda:0', grad_fn=<MaxBackward1>)
dif tensor(97.6572, device='cuda:0', grad_fn=<SumBackward0>)
l tensor(0, device='cuda:0')
epoch 200
loss tensor(0.8468, device='cuda:0', grad_fn=<DivBackward0>)
weight_max tensor(0.2430, device='cuda:0', grad_fn=<MaxBackward1>)
dif tensor(97.8326, device='cuda:0', grad_fn=<SumBackward0>)
l tensor(0, device='cuda:0')
epoch 300
loss tensor(0.8462, device='cuda:0', grad_fn=<DivBackward0>)
weight_max tensor(0.2443, device='cuda:0', grad_fn=<MaxBackward1>)
dif tensor(97.8286, device='cuda:0', grad_fn=<SumBackward0>)
l tensor(0, device='cuda:0')
epoch 400
loss tensor(0.8584, device='cuda:0', grad_fn=<DivBackward0>)
weight_max tensor(0

In [101]:
edge_add

[(1749, 1910),
 (1910, 1749),
 (2789, 1352),
 (1352, 2789),
 (548, 2613),
 (2613, 548),
 (2613, 61),
 (61, 2613),
 (2912, 2613),
 (2613, 2912),
 (492, 2681),
 (2681, 492),
 (2681, 887),
 (887, 2681),
 (2585, 2681),
 (2681, 2585),
 (2128, 2681),
 (2681, 2128),
 (2681, 27),
 (27, 2681),
 (2257, 2681),
 (2681, 2257)]

In [102]:

edge_new=torch.tensor([[edge_add[i][0],edge_add[i][1]] for i in range(len(edge_add))]).T.to(device)

In [103]:
edge_new.shape

torch.Size([2, 22])

In [104]:
full_pos_edge_index1.shape

torch.Size([2, 6374])

In [105]:
full_pos_edge_index_aug=torch.cat((full_pos_edge_index1,edge_new),1)

In [107]:
data.aug_pos_edge_index=full_pos_edge_index_aug

In [108]:
dr_full_mask = torch.ones(data.aug_pos_edge_index.shape[1], dtype=torch.bool)
dr_full_mask[df_global_idx] = False
dr_full_mask=dr_full_mask.to(device)

df_full_mask = torch.zeros(data.aug_pos_edge_index.shape[1], dtype=torch.bool)
df_full_mask[df_global_idx] = True
df_full_mask=df_full_mask.to(device)

In [109]:
# Edges in S_Df
_, two_hop_edge, _, two_hop_mask = k_hop_subgraph(
        data.aug_pos_edge_index[:, df_full_mask].flatten().unique(), 
        2, 
        data.aug_pos_edge_index,
        num_nodes=data.num_nodes)
data.sdf_mask = two_hop_mask

In [110]:
# Nodes in S_Df
_, one_hop_edge, _, one_hop_mask = k_hop_subgraph(
    data.aug_pos_edge_index[:, df_full_mask].flatten().unique(), 
    1, 
    data.aug_pos_edge_index,
    num_nodes=data.num_nodes)

In [111]:
sdf_node_1hop = torch.zeros(data.num_nodes, dtype=torch.bool)
sdf_node_2hop = torch.zeros(data.num_nodes, dtype=torch.bool)

In [112]:
sdf_node_1hop[one_hop_edge.flatten().unique()] = True
sdf_node_2hop[two_hop_edge.flatten().unique()] = True
assert sdf_node_1hop.sum() == len(one_hop_edge.flatten().unique())
assert sdf_node_2hop.sum() == len(two_hop_edge.flatten().unique())

In [113]:
data.sdf_node_1hop_mask = sdf_node_1hop
data.sdf_node_2hop_mask = sdf_node_2hop

In [114]:
is_undirected(data.aug_pos_edge_index)

True

In [115]:
full_pos_edge_index1, [df_full_mask1, two_hop_mask1] = to_undirected(data.aug_pos_edge_index, [df_full_mask.int(), two_hop_mask.int()])
two_hop_mask1 = two_hop_mask1.bool()
df_full_mask1 = df_full_mask1.bool()
dr_full_mask1 = ~df_full_mask

data.aug_pos_edge_index1 =full_pos_edge_index1
data.edge_index1 = full_pos_edge_index1
assert is_undirected(data.aug_pos_edge_index1)

In [116]:
data.sdf_mask = two_hop_mask1
data.df_aug_mask = df_full_mask1
data.dr_aug_mask = dr_full_mask1

In [259]:
z_aug= model_ori(data.x, data.aug_pos_edge_index1[:, data.sdf_mask],return_all_emb=False)
z_aug.detach()
z_aug.shape

torch.Size([3327, 64])

In [282]:
mask = PGNNMask(data.x.cuda(), n_hidden=64, temperature=1).to(device)
mask=mask.to(device)

In [283]:
from model.Wdeletegcn import WGCNDelete
model_pru = WGCNDelete(args, sdf_node_1hop, sdf_node_2hop, num_nodes=data.num_nodes, num_edge_type=args.num_edge_type)
model_ckpt = torch.load(os.path.join(args.checkpoint_dir, 'model_final.pt'), map_location=device)
model_pru.load_state_dict(model_ckpt['model_state'], strict=False)
model_pru = model_pru.to(device)

In [284]:
parameters_to_optimize = [
                {'params': [p for n, p in model_pru.named_parameters() if 'del' in n], 'weight_decay': 0.0}
            ]
print('parameters_to_optimize', [n for n, p in model_pru.named_parameters() if 'del' in n])

parameters_to_optimize ['deletion1.deletion_weight', 'deletion2.deletion_weight']


In [285]:
optimizer_mask = torch.optim.Adam(mask.parameters(), lr = 1e-2)
optimizer_pru = torch.optim.Adam(parameters_to_optimize, lr=0.00001)#, weight_decay=args.weight_decay)

In [286]:
sdf1_all_pair_mask = torch.zeros(data.num_nodes, data.num_nodes, dtype=torch.bool)
idx = torch.combinations(torch.arange(data.num_nodes)[data.sdf_node_1hop_mask], with_replacement=True).t()
sdf1_all_pair_mask[idx[0], idx[1]] = True
sdf1_all_pair_mask[idx[1], idx[0]] = True

In [287]:
assert sdf1_all_pair_mask.sum().cpu() == data.sdf_node_1hop_mask.sum().cpu() * data.sdf_node_1hop_mask.sum().cpu()

In [288]:
 ## Remove Df itself
sdf1_all_pair_mask[data.aug_pos_edge_index1[:, data.df_aug_mask][0], data.aug_pos_edge_index1[:, data.df_aug_mask][1]] = False
sdf1_all_pair_mask[data.aug_pos_edge_index1[:, data.df_aug_mask][1], data.aug_pos_edge_index1[:, data.df_aug_mask][0]] = False
##sdf1_all_pair_mask contain

In [289]:
sdf2_all_pair_mask = torch.zeros(data.num_nodes, data.num_nodes, dtype=torch.bool)
idx = torch.combinations(torch.arange(data.num_nodes)[data.sdf_node_2hop_mask], with_replacement=True).t()
sdf2_all_pair_mask[idx[0], idx[1]] = True
sdf2_all_pair_mask[idx[1], idx[0]] = True
assert sdf2_all_pair_mask.sum().cpu() == data.sdf_node_2hop_mask.sum().cpu() * data.sdf_node_2hop_mask.sum().cpu()

In [290]:
 ## Remove Df itself
sdf2_all_pair_mask[data.aug_pos_edge_index1[:, data.df_aug_mask][0], data.aug_pos_edge_index1[:, data.df_aug_mask][1]] = False
sdf2_all_pair_mask[data.aug_pos_edge_index1[:, data.df_aug_mask][1], data.aug_pos_edge_index1[:, data.df_aug_mask][0]] = False
##sdf1_all_pair_mask contain

In [291]:
 ## Lower triangular mask
idx = torch.tril_indices(data.num_nodes, data.num_nodes, -1)
lower_mask = torch.zeros(data.num_nodes, data.num_nodes, dtype=torch.bool)
lower_mask[idx[0], idx[1]] = True

In [292]:
## The final mask is the intersection
sdf1_all_pair_without_df_mask = sdf1_all_pair_mask & lower_mask
sdf2_all_pair_without_df_mask = sdf2_all_pair_mask & lower_mask

In [293]:
data

Data(x=[3327, 3703], val_pos_edge_index=[2, 455], test_pos_edge_index=[2, 910], train_pos_edge_index=[2, 6374], train_neg_adj_mask=[3327, 3327], val_neg_edge_index=[2, 455], test_neg_edge_index=[2, 910], aug_pos_edge_index=[2, 6396], sdf_mask=[6396], sdf_node_1hop_mask=[3327], sdf_node_2hop_mask=[3327], aug_pos_edge_index1=[2, 6396], edge_index1=[2, 6396], df_aug_mask=[6396], dr_aug_mask=[6396], sdf_node_1hop_mask_non_df_mask=[3327], sdf_node_2hop_mask_non_df_mask=[3327])

In [294]:
non_df_node_mask = torch.ones(data.x.shape[0], dtype=torch.bool, device=data.x.device)
non_df_node_mask[data.aug_pos_edge_index1[:,data.df_aug_mask].flatten().unique()] = False

data.sdf_node_1hop_mask_non_df_mask = data.sdf_node_1hop_mask.to(device) & non_df_node_mask
data.sdf_node_2hop_mask_non_df_mask = data.sdf_node_2hop_mask.to(device) & non_df_node_mask

In [295]:
loss_fct = nn.MSELoss()

In [311]:
z_ori= model_ori(data.x, data.aug_pos_edge_index1[:, data.sdf_mask],return_all_emb=False)
logits_ori=(z_ori @ z_ori.t())
logits_ori=logits_ori.detach()

In [312]:
neg_size = data.df_aug_mask.sum()
neg_edge = negative_sampling(
    edge_index=data.aug_pos_edge_index1,
    num_nodes=data.num_nodes,
    num_neg_samples=neg_size)

In [313]:
neg_size

tensor(200, device='cuda:0')

In [314]:
def eval_val(model,weight):
    model.eval()
    weight.detach()
    perfs = []
    for prefix in ["val", "test"]:
        pos_edge_index = data[f"{prefix}_pos_edge_index"]
        neg_edge_index = data[f"{prefix}_neg_edge_index"]
        with torch.no_grad():
            z = model(data.x, data.aug_pos_edge_index1,weight)
            link_logits = model.decode(z, pos_edge_index, neg_edge_index)
        link_probs = link_logits.sigmoid()
        link_labels = get_link_labels(pos_edge_index, neg_edge_index)
        auc = roc_auc_score(link_labels.cpu(), link_probs.cpu())
        perfs.append(auc)
    return perfs

In [315]:
def eval_forget(model,weight):
    model.eval()
    weight.detach()
    neg_edge_index=data.aug_pos_edge_index1[:,data.df_aug_mask]
    pos_edge_index=data.aug_pos_edge_index1[:,same[torch.randperm(same.shape[0])[:edge_to_delete]]]
    with torch.no_grad():
        z = model(data.x, data.aug_pos_edge_index1,weight)
        link_logits = model.decode(z, pos_edge_index, neg_edge_index)
    link_probs = link_logits.sigmoid()
    link_labels = get_link_labels(pos_edge_index, neg_edge_index)
    auc = roc_auc_score(link_labels.cpu(), link_probs.cpu())
    return auc
    

In [316]:
ckpt = {
            'model_state': model_pru.state_dict(),
            'optimizer_state': optimizer_pru.state_dict(),
        }
best_forget=1
best_auc=0
select="forget"

In [317]:
data.aug_pos_edge_index1.shape

torch.Size([2, 6396])

In [318]:
data.sdf_mask.shape

torch.Size([6396])

In [305]:
z_ori.requires_grad

False

In [None]:

##real
batch_size=256
weight=0.0000001
#weight=0
l_num=0
max_edge=20
max_per_iter=3
add_edge=0
edge_add=[]
for epoch in range(5000):
    batch_idx=target_node[torch.randperm(len(target_node))[0:batch_size]]
    #aug.train()
    for i in range(1):
        weight_aug1=adj_sampled_aug[full_pos_edge_index[0],full_pos_edge_index[1]]
        weight_aug=weight_ori+m*weight_aug1*800
        z_aug = model_aug(data.x, data.aug_pos_edge_index,weight_aug,return_all_emb=False)
        #print(z_aug.requires_grad)
        pos_sim_padded = []
        neg_sim_batch=[]
        mask_padded = []
        max_num_pos=5
        #batch_idx=torch.randperm(N)[0:batch_size]
        #print(batch_idx)
        #batch_idx=target_node[torch.randperm(len(target_node))[0:batch_size]]
        for idx in batch_idx:
            #print(idx)
            z_pos_aug=z_aug[pos_node[idx],:]
            z_neg_aug=z_aug[neg_node[idx],:]
            pos_sim=torch.matmul(z_pos_aug,z_aug[idx].unsqueeze(1)).T
            #print("pos_sim",pos_sim.shape)
            neg_sim=torch.matmul(z_neg_aug,z_aug[idx].unsqueeze(1)).T
            #print("neg_sim",neg_sim.shape)
            neg_sim_batch.append(neg_sim)
            padding_size = max_num_pos - pos_sim.size(1)
            #print(F.pad(pos_sim, (0, padding_size), value=0.0))
            pos_sim_padded.append(F.pad(pos_sim, (0, padding_size), value=0.0))  # Pad with 0 for missing positives
            #print("append",len(pos_sim_padded))
            mask_padded.append(F.pad(torch.ones_like(pos_sim), (0, padding_size), value=0.0))  # Mask for real values
            # Stack the padded similarities and masks
        pos_sim_padded = torch.stack(pos_sim_padded,dim=0).squeeze(dim=1)  # [batch_size, max_num_positives]
        neg_sim_batch=torch.stack(neg_sim_batch,dim=0).squeeze(dim=1)
        mask_padded = torch.stack(mask_padded,dim=0).squeeze(dim=1)  # [batch_size, max_num_positives]
        #z_aug = model_aug(data.x, data.aug_pos_edge_index,weight_aug,return_all_emb=False)
        # Concatenate negative and padded positive similarities to form logits
        logits = torch.cat([neg_sim_batch, pos_sim_padded], dim=1)  # [batch_size, num_negatives + max_num_positives]

        # Create labels: 0 for negative pairs, 1 for real positive pairs (mask out padded entries)
        labels = torch.cat([torch.zeros_like(neg_sim_batch), mask_padded], dim=1)  # [batch_size, num_negatives + max_num_positives]
        link_probs = logits.sigmoid()
        # Compute Binary Cross-Entropy loss without reduction to apply mask
        loss = loss_fn(link_probs, labels)

        # Mask the padded positions by multiplying the loss by the mask (for positive part)
        mask = torch.cat([torch.ones_like(neg_sim_batch), mask_padded], dim=1)  # Mask for ignoring padded logits
        #print("z_aug",z_aug.requires_grad)
        loss = loss * mask  # Apply mask to ignore padded positions
        for name, param in model_aug.named_parameters():
            param.requires_grad=False
        l1_loss=torch.abs(weight_aug[re_edge_num:]).sum()
        loss_total=(loss.sum()/mask.sum())+l1_loss*weight
        
        #optimizer_aug.zero_grad()
        loss_total.backward(retain_graph=True)
        #print("loss_total",loss_total)
        optimizer_aug.step()
        optimizer_aug.zero_grad()
        z_aug_detach=z_aug.detach()
        """
        for name, param in aug.named_parameters():
            param.requires_grad=False
        print("z_aug1",z_aug.requires_grad)
        """
        adj_sampled_aug,adj_logits_aug = aug(z_aug_detach)
        #adj_sampled_aug,adj_logits_aug = aug(z_ori.cuda())
        #l=weight_aug[re_edge_num:]>0.9
        l=(m*weight_aug1*800)>0.9
        l_num=l.sum()

In [306]:
for epoch in range(3000):
    model_pru.train()
    adj_sampled,adj_logits = mask(z_ori.cuda())
    weight_mask=adj_sampled[data.aug_pos_edge_index1[0],data.aug_pos_edge_index1[1]]
    weight=torch.ones(data.aug_pos_edge_index1.shape[1]).to(device)-weight_mask*600
    z = model_pru(data.x, data.aug_pos_edge_index1,weight)
    neg_size = data.df_aug_mask.sum()
    neg_edge_index = negative_sampling(
    edge_index=data.aug_pos_edge_index1,
        num_nodes=data.num_nodes,
        num_neg_samples=neg_size)
    df_logits = model_pru.decode(z, data.aug_pos_edge_index1[:, data.df_aug_mask], neg_edge_index)
    loss_size=torch.sum(weight_mask)
    loss_r = loss_fct(df_logits[:neg_size], df_logits[neg_size:])
    if sdf2_all_pair_without_df_mask.sum() != 0:
        logits_sdf = (z @ z.t())[sdf2_all_pair_without_df_mask].sigmoid()
        loss_l = loss_fct(logits_sdf, logits_ori[sdf2_all_pair_without_df_mask].sigmoid())
    alpha = 0.5
    beta=3
    loss = alpha * loss_r + (1 - alpha) * loss_l+beta*loss_size
    loss.backward(retain_graph=True)
    # torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
    optimizer_pru.step()
    #optimizer_pru.zero_grad()
    optimizer_mask.step()
    #optimizer_mask.zero_grad()
    
    step_log = {
        'Epoch': epoch,
        'train_loss': loss.item(),
        'loss_r': loss_r.item(),
        'loss_l': loss_l.item(),
        'loss_size': loss_size.item(),
    }
    for name, param in model_ori.named_parameters():
            param.requires_grad=False
    if(epoch%100==0):
        print(step_log)
    if(epoch%100==0):
        p=eval_val(model_pru,weight)
        val_perf, tmp_test_perf = p
        print("val",val_perf,"test",tmp_test_perf)
        auc_forget=eval_forget(model_pru,weight)
        print("forget",auc_forget)
        #print("best_forget",best_forget)
        print("weight_min",weight.min())
        if(select=="forget"):
            if(auc_forget<best_forget):
                best_forget=auc_forget
                ckpt = {
                'model_state': model_pru.state_dict(),
                'optimizer_state': optimizer_pru.state_dict(),
            }
                torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_delete.pt'))
                print("savebyforget")
        else:
            if(tmp_test_perf>best_auc):
                tmp_test_perf>best_auc
                best_auc=tmp_test_perf
                ckpt = {
                'model_state': model_pru.state_dict(),
                'optimizer_state': optimizer_pru.state_dict(),
            }
                #torch.save(ckpt, os.path.join(args.checkpoint_dir, 'model_delete.pt'))
                print("savebyauc")
    #wandb.log(step_log)


{'Epoch': 0, 'train_loss': 6.032325744628906, 'loss_r': 0.3226485848426819, 'loss_l': 0.03424762934446335, 'loss_size': 1.922452449798584}
val 0.763069677575172 test 0.734039971017993
forget 0.6751999999999999
weight_min tensor(0.4590, device='cuda:0', grad_fn=<MinBackward1>)
savebyforget
{'Epoch': 100, 'train_loss': 5.903014659881592, 'loss_r': 0.16103118658065796, 'loss_l': 0.03415975347161293, 'loss_size': 1.9224525690078735}
val 0.7567612607173047 test 0.7279289940828402
forget 0.7150500000000001
weight_min tensor(0.4589, device='cuda:0', grad_fn=<MinBackward1>)
{'Epoch': 200, 'train_loss': 6.069447994232178, 'loss_r': 0.3690885901451111, 'loss_l': 0.034097906202077866, 'loss_size': 1.922452449798584}
val 0.7576138147566719 test 0.7282345127400072
forget 0.73625
weight_min tensor(0.4590, device='cuda:0', grad_fn=<MinBackward1>)
{'Epoch': 300, 'train_loss': 5.93104887008667, 'loss_r': 0.19609114527702332, 'loss_l': 0.0340939462184906, 'loss_size': 1.922452449798584}
val 0.7580557903

KeyboardInterrupt: 

In [309]:
for name, param in mask.named_parameters():
            print(name,param.requires_grad)

Aaug.gcn_mean.0.weight True
Aaug.gcn_mean.0.bias True
Aaug.gcn_mean.1.weight True
Aaug.gcn_mean.1.bias True
Aaug.gcn_mean.3.weight True
Aaug.gcn_mean.3.bias True


In [308]:
weight.shape

torch.Size([6396])

In [233]:
for name, param in model_ori.named_parameters():
    print(name,param.requires_grad)

conv1.bias True
conv1.lin.weight True
conv2.bias True
conv2.lin.weight True


In [176]:
weight.min()

tensor(0.9997, device='cuda:0', grad_fn=<MinBackward1>)