In [100]:
from utils import process
import scipy.sparse as sp
import numpy as np
import math
import torch
from torch import nn
from torch.nn import functional as F

In [101]:


class GATLayer(nn.Module):
    def __init__(self, in_features, out_features, heads=1, dropout=0.6, temp=0.5):
        super(GATLayer, self).__init__()
        assert out_features % heads == 0
        self.out_features = out_features
        self.in_features = in_features
        self.heads = heads
        self.dropout = dropout

        # Weight matrices
        self.W = nn.Parameter(torch.zeros(size=(in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain=1.414)

        # Attention coefficients
        self.a = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain=1.414)

        self.a_mask = nn.Parameter(torch.zeros(size=(2*out_features, 1)))
        nn.init.xavier_uniform_(self.a_mask.data, gain=1.414)

        self.leakyrelu = nn.LeakyReLU(0.2)
        self.temp = temp
    def forward(self, x, adj):
        Wh = torch.mm(x, self.W)  # Linear transformation
        bsz = Wh.size(0)

        # Only consider edges that actually exist (i.e., where adj is nonzero)
        edges_id = adj.nonzero(as_tuple=False)

        # Extract features for the source and target nodes of each edge
        Wh1 = Wh[edges_id[:, 0], :]  # Source node features for each edge
        Wh2 = Wh[edges_id[:, 1], :]  # Target node features for each edge
        
        # Concatenate features from source and target nodes
        e_feat = torch.cat([Wh1, Wh2], dim=1)

        # Apply the shared attention mechanism to every edge
        e = self.leakyrelu(torch.matmul(e_feat, self.a).squeeze(1))
        attention = torch.zeros(bsz, bsz).to(x.device)
        attention[edges_id[:, 0], edges_id[:, 1]] = e

        # Apply mask
        e_mask = self.leakyrelu(torch.matmul(e_feat, self.a_mask).squeeze(1))
        e_mask = torch.sigmoid(e_mask / self.temp)
        mask = torch.zeros(bsz, bsz).to(x.device)
        mask[edges_id[:, 0], edges_id[:, 1]] = e_mask

        attention = attention * mask
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, self.dropout, training=self.training)

        # Apply attention to node features
        h_prime = torch.matmul(attention, Wh)

        if self.heads > 1:
            # If multiple heads, split and concat
            h_prime = h_prime.view(bsz, self.heads, self.out_features // self.heads)
            h_prime = torch.mean(h_prime, dim=1)

        return h_prime

    def get_mask(self, x, adj, threshold=0.5):
        Wh = torch.mm(x, self.W)  # Linear transformation
        bsz = Wh.size(0)

        # Only consider edges that actually exist (i.e., where adj is nonzero)
        edges_id = adj.nonzero(as_tuple=False)

        # Extract features for the source and target nodes of each edge
        Wh1 = Wh[edges_id[:, 0], :]  # Source node features for each edge
        Wh2 = Wh[edges_id[:, 1], :]  # Target node features for each edge
        
        # Concatenate features from source and target nodes
        e_feat = torch.cat([Wh1, Wh2], dim=1)

        # Apply mask
        e_mask = self.leakyrelu(torch.matmul(e_feat, self.a_mask).squeeze(1))
        e_mask = torch.sigmoid(e_mask / self.temp)
        mask = torch.zeros(bsz, bsz).to(x.device)
        mask[edges_id[:, 0], edges_id[:, 1]] = e_mask
        
        # Convert mask to binary using threshold
        binary_mask = (mask > threshold).float()

        return binary_mask

In [102]:
class node_encoder(nn.Module):
    def __init__(self, feat_dim=512, hidden_dim=128, reparam_dim=64, latent_dim=32):
        super(node_encoder, self).__init__()
        self.feat_encode = nn.Linear(feat_dim, hidden_dim)
        self.neighbor_encode = GATLayer(feat_dim, hidden_dim)
        self.latent_encode = nn.Linear(hidden_dim*2, reparam_dim)
        self.mean = nn.Linear(reparam_dim, latent_dim)
        self.log_var = nn.Linear(reparam_dim, latent_dim)

    def reparameterize(self, mean, log_var):
        eps = torch.randn_like(log_var)
        z = mean + eps * torch.exp(log_var * 0.5)
        return z
    
    def forward(self, x, normalized_adj, pos_emb):
        feat = F.relu(self.feat_encode(x))
        neighbor_feat = F.relu(self.neighbor_encode(x + pos_emb, normalized_adj))
        feat = torch.cat([feat, neighbor_feat], dim=1)
        feat = F.relu(self.latent_encode(feat))
        mean = self.mean(feat)
        log_var = self.log_var(feat)
        z = self.reparameterize(mean, log_var)
        return z, mean, log_var

class node_decoder(nn.Module):
    def __init__(self, feat_dim=512, hidden_dim=256, reparam_dim=128, latent_dim=64, seq_len=2708):
        super(node_decoder, self).__init__()
        self.latent_decode = nn.Linear(latent_dim, reparam_dim)
        self.reparam_decode = nn.Linear(reparam_dim, hidden_dim*2)
        self.feat_decode = nn.Linear(hidden_dim, feat_dim)
        self.neighbor_decode = nn.Linear(hidden_dim, seq_len)

    def forward(self, z, temp=0.5):
        z = F.relu(self.latent_decode(z))
        z = F.relu(self.reparam_decode(z))
        # split z into two parts
        z = torch.chunk(z, 2, dim=-1)
        feat = z[0]
        neighbor_feat = z[1]
        feat = self.feat_decode(feat)
        neighbor_feat = self.neighbor_decode(neighbor_feat)
        feat = torch.sigmoid(feat)
        neighbor_map = torch.sigmoid(neighbor_feat/temp)
        # make neighbor_feat sharper
        return feat, neighbor_map



class node_vae(nn.Module):
    def __init__(self, feat_dim=512, hidden_dim=256, reparam_dim=128, latent_dim=64, seq_len=2708):
        super(node_vae, self).__init__()
        self.encoder = node_encoder(feat_dim, hidden_dim, reparam_dim, latent_dim)
        self.decoder = node_decoder(feat_dim, hidden_dim, reparam_dim, latent_dim)

    def forward(self, feat, neighbor_feat, pos_emb):
        z, mean, log_var = self.encoder(feat, neighbor_feat, pos_emb)
        feat, neighbor_map = self.decoder(z)
        return feat, neighbor_map, mean, log_var


def create_positional_embeddings(seq_len, emb_dim):
    """Create positional embeddings."""
    # Initialize the matrix with zeros
    position = torch.arange(seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, emb_dim, 2) * -(math.log(10000.0) / emb_dim))

    # Calculate positional encodings
    positional_embeddings = torch.zeros(seq_len, emb_dim)
    positional_embeddings[:, 0::2] = torch.sin(position * div_term)
    positional_embeddings[:, 1::2] = torch.cos(position * div_term)

    return positional_embeddings


In [103]:
dataset = 'cora'
adj, features, labels, idx_train, idx_val, idx_test = process.load_data(dataset)
print(adj.shape)
adj_with_self_loops = adj + sp.eye(adj.shape[0])
normalized_adj = process.normalize_adj(adj_with_self_loops)
normalized_adj = torch.FloatTensor(normalized_adj.todense())
embeds = np.load('/home/local/ASUAD/ywan1053/graph_diffusion/generate_graph_embbedding/gcl_embeddings/cora/all_data/all_embs.npy')
embeds = torch.FloatTensor(embeds)

data_min = embeds.min()
data_max = embeds.max()
embeds_normalized = (embeds - data_min) / (data_max - data_min)

(2708, 2708)


In [104]:
seq_len = embeds_normalized.shape[0]  # Length of your sequence
emb_dim = embeds_normalized.shape[1]  # Embedding dimensions

positional_embeddings = create_positional_embeddings(seq_len, emb_dim)

In [105]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
embeds_normalized = embeds_normalized.to(device)
positional_embeddings = positional_embeddings.to(device)
vae_model = node_vae(feat_dim=emb_dim, hidden_dim=256, reparam_dim=128, latent_dim=64, seq_len=2708).to(device)
adj_with_self_loop = torch.FloatTensor(adj_with_self_loops.todense()).to(device)
normalized_adj = normalized_adj.to(device)

In [106]:
optimizer = torch.optim.Adam(vae_model.parameters(), lr=0.0001, weight_decay=1e-5)


In [107]:
edge_mask = adj_with_self_loop

for epoch in range(500):
    vae_model.train()
    optimizer.zero_grad()
    decoded_feat, neighbor_map, mean, log_var = vae_model(embeds_normalized, normalized_adj, positional_embeddings)
    
    l2_loss = F.mse_loss(decoded_feat, embeds_normalized, reduction='mean')
    edge_loss = F.mse_loss(torch.matmul(decoded_feat, decoded_feat.T), adj_with_self_loop*edge_mask, reduction='mean')
    bce_loss = F.binary_cross_entropy(neighbor_map, adj_with_self_loop, reduction='mean')

    factor = 0.99999
    # You may set the factor for BCE loss based on your previous experiments
    loss = factor * l2_loss + (1-factor) * edge_loss + bce_loss
    loss.backward()
    optimizer.step()

    print('Epoch: {}, l2_loss: {}, edge_loss: {}, bce_loss: {}'.format(epoch, l2_loss.item(), edge_loss.item(), bce_loss.item()))
    with torch.no_grad():
        edge_mask = vae_model.encoder.neighbor_encode.get_mask(embeds_normalized, normalized_adj)

Epoch: 0, l2_loss: 0.040118537843227386, edge_loss: 16546.31640625, bce_loss: 0.7005627751350403
Epoch: 1, l2_loss: 0.03979618847370148, edge_loss: 16436.994140625, bce_loss: 0.6977888941764832
Epoch: 2, l2_loss: 0.03946947678923607, edge_loss: 16325.162109375, bce_loss: 0.694924533367157
Epoch: 3, l2_loss: 0.03917240723967552, edge_loss: 16224.6181640625, bce_loss: 0.6923016905784607
Epoch: 4, l2_loss: 0.038868498057127, edge_loss: 16118.7509765625, bce_loss: 0.6896629333496094
Epoch: 5, l2_loss: 0.03857100382447243, edge_loss: 16014.8857421875, bce_loss: 0.6870497465133667
Epoch: 6, l2_loss: 0.03829038515686989, edge_loss: 15918.1513671875, bce_loss: 0.684459388256073
Epoch: 7, l2_loss: 0.0379783920943737, edge_loss: 15811.037109375, bce_loss: 0.6818532943725586
Epoch: 8, l2_loss: 0.037694696336984634, edge_loss: 15707.703125, bce_loss: 0.6791806817054749
Epoch: 9, l2_loss: 0.03740587458014488, edge_loss: 15604.8828125, bce_loss: 0.6764006018638611
Epoch: 10, l2_loss: 0.0371074154973

In [108]:
    # loss = F.binary_cross_entropy(A_pred, binary_A, reduction='mean')
