In [23]:
from utils import process
import scipy.sparse as sp
import numpy as np
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

class node_encoder(nn.Module):
    """Variational encoder turning a node's own features plus its aggregated
    neighbourhood features into a Gaussian latent code.

    Args:
        feat_dim: width (last dimension) of both ``feat`` and ``neighbor_feat``.
        hidden_dim: output width of each of the two input projections.
        reparam_dim: width of the fused representation feeding the latent heads.
        latent_dim: width of the sampled latent vector ``z``.
    """

    def __init__(self, feat_dim=512, hidden_dim=256, reparam_dim=128, latent_dim=64):
        super().__init__()
        # One projection for the node itself, one for its neighbourhood.
        self.feat_encode = nn.Linear(feat_dim, hidden_dim)
        self.neighbor_encode = nn.Linear(feat_dim, hidden_dim)
        # Both branches are concatenated before this layer, hence 2 * hidden_dim.
        self.latent_encode = nn.Linear(hidden_dim*2, reparam_dim)
        # Heads for the Gaussian mean and log-variance.
        self.mean = nn.Linear(reparam_dim, latent_dim)
        self.log_var = nn.Linear(reparam_dim, latent_dim)

    def reparameterize(self, mean, log_var):
        """Sample ``z = mean + eps * std`` with ``eps ~ N(0, 1)`` and
        ``std = exp(log_var / 2)`` (reparameterisation trick)."""
        std = torch.exp(log_var * 0.5)
        noise = torch.randn_like(log_var)
        return mean + noise * std

    def forward(self, feat, neighbor_feat):
        """Encode a batch of nodes.

        Args:
            feat: node features; last dimension ``feat_dim``.
            neighbor_feat: neighbourhood features; last dimension ``feat_dim``.
                The two projected branches are concatenated along dim 1, so
                2-D ``(nodes, feat_dim)`` inputs are presumably expected.

        Returns:
            Tuple ``(z, mean, log_var)``: the sampled latent code and the
            Gaussian parameters it was drawn from.
        """
        self_branch = F.relu(self.feat_encode(feat))
        neighbor_branch = F.relu(self.neighbor_encode(neighbor_feat))
        fused = F.relu(self.latent_encode(torch.cat((self_branch, neighbor_branch), dim=1)))
        mean, log_var = self.mean(fused), self.log_var(fused)
        return self.reparameterize(mean, log_var), mean, log_var

class node_decoder(nn.Module):
    """Decoder mapping a latent code back to node features and a neighbour map.

    Args:
        feat_dim: width of the reconstructed node features.
        hidden_dim: width of each of the two decoding branches.
        reparam_dim: width of the first hidden layer after the latent code.
        latent_dim: width of the input latent code.
        seq_len: width of the neighbour map (presumably one entry per graph
            node; the default 2708 matches the Cora graph size).
    """

    def __init__(self, feat_dim=512, hidden_dim=256, reparam_dim=128, latent_dim=64, seq_len=2708):
        super().__init__()
        self.latent_decode = nn.Linear(latent_dim, reparam_dim)
        # Emits both branches at once; forward() splits the result in half.
        self.reparam_decode = nn.Linear(reparam_dim, hidden_dim*2)
        self.feat_decode = nn.Linear(hidden_dim, feat_dim)
        self.neighbor_decode = nn.Linear(hidden_dim, seq_len)

    def forward(self, z, temp=0.5):
        """Decode latent codes.

        Args:
            z: latent codes; last dimension ``latent_dim``.
            temp: temperature applied to the neighbour-map logits.

        Returns:
            Tuple ``(feat, neighbor_map)``, both passed through a sigmoid so
            their values lie in (0, 1).
        """
        hidden = F.relu(self.reparam_decode(F.relu(self.latent_decode(z))))
        # First half of the hidden vector drives features, second half neighbours.
        feat_half, neighbor_half = torch.chunk(hidden, 2, dim=-1)
        recon_feat = torch.sigmoid(self.feat_decode(feat_half))
        # Dividing the logits by temp < 1 pushes probabilities toward 0/1,
        # i.e. makes the neighbour map sharper.
        neighbor_map = torch.sigmoid(self.neighbor_decode(neighbor_half) / temp)
        return recon_feat, neighbor_map

class node_vae(nn.Module):
    """Node-level VAE wiring a ``node_encoder`` to a ``node_decoder``.

    Args:
        feat_dim, hidden_dim, reparam_dim, latent_dim: forwarded unchanged to
            both the encoder and the decoder.
        seq_len: width of the decoded neighbour map; forwarded to the decoder.
    """

    def __init__(self, feat_dim=512, hidden_dim=256, reparam_dim=128, latent_dim=64, seq_len=2708):
        super(node_vae, self).__init__()
        self.encoder = node_encoder(feat_dim, hidden_dim, reparam_dim, latent_dim)
        # seq_len must be forwarded explicitly; otherwise the decoder falls back
        # to its own default and the value given here is silently ignored.
        self.decoder = node_decoder(feat_dim, hidden_dim, reparam_dim, latent_dim, seq_len=seq_len)

    def forward(self, feat, neighbor_feat):
        """Encode, sample a latent code, and decode it.

        Args:
            feat: node features passed to the encoder.
            neighbor_feat: neighbourhood features passed to the encoder.

        Returns:
            Tuple ``(recon_feat, neighbor_map, mean, log_var)``.
        """
        z, mean, log_var = self.encoder(feat, neighbor_feat)
        recon_feat, neighbor_map = self.decoder(z)
        return recon_feat, neighbor_map, mean, log_var


def create_positional_embeddings(seq_len, emb_dim, base=10000.0):
    """Create sinusoidal (Transformer-style) positional embeddings.

    Even columns hold ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``,
    where ``w_k = base ** (-2k / emb_dim)``.

    Args:
        seq_len: number of positions (rows).
        emb_dim: embedding width (columns). Odd widths are supported; they get
            one more sine column than cosine columns.
        base: wavelength base; defaults to 10000.0.

    Returns:
        Float tensor of shape ``(seq_len, emb_dim)``.
    """
    position = torch.arange(seq_len, dtype=torch.float32).unsqueeze(1)
    # One frequency per sine column: ceil(emb_dim / 2) of them.
    div_term = torch.exp(torch.arange(0, emb_dim, 2) * -(math.log(base) / emb_dim))

    positional_embeddings = torch.zeros(seq_len, emb_dim)
    positional_embeddings[:, 0::2] = torch.sin(position * div_term)
    # There are only emb_dim // 2 cosine columns, so trim the frequency vector;
    # without this an odd emb_dim raises a shape-mismatch error.
    positional_embeddings[:, 1::2] = torch.cos(position * div_term[: emb_dim // 2])

    return positional_embeddings

In [24]:
# Load the Cora graph and the precomputed node embeddings, then min-max scale
# the embeddings into [0, 1] using a single global min/max.
# NOTE(review): absolute, machine-specific path — edit EMBEDS_PATH to run elsewhere.
EMBEDS_PATH = '/home/local/ASUAD/ywan1053/graph_diffusion/generate_graph_embbedding/gcl_embeddings/cora/all_data/all_embs.npy'

dataset = 'cora'
adj, features, labels, idx_train, idx_val, idx_test = process.load_data(dataset)

# Normalised adjacency with self-loops (exact normalisation scheme lives in
# utils.process.normalize_adj — presumably symmetric D^-1/2 (A+I) D^-1/2).
norm_adj = process.normalize_adj(adj + sp.eye(adj.shape[0]))
norm_adj = torch.FloatTensor(norm_adj.todense())
embeds = torch.FloatTensor(np.load(EMBEDS_PATH))
# Rows of the embedding matrix are multiplied by norm_adj later, so they must
# line up one-to-one with the graph's nodes.
assert embeds.shape[0] == adj.shape[0], "embedding rows do not match the number of graph nodes"

data_min = embeds.min()
data_max = embeds.max()
data_range = data_max - data_min
# A constant embedding matrix would otherwise divide by zero and yield NaNs.
if data_range == 0:
    raise ValueError(f"embeddings in {EMBEDS_PATH} are constant; cannot min-max normalise")
embeds_normalized = (embeds - data_min) / data_range

In [25]:
# Use the node index as the "position": every node gets a unique sinusoidal
# code, which is added to its features before aggregating over each node's
# normalised neighbourhood (one dense matmul with norm_adj).
seq_len, emb_dim = embeds_normalized.shape[0], embeds_normalized.shape[1]  # (#nodes, embedding width)
positional_embeddings = create_positional_embeddings(seq_len, emb_dim)
all_neighbor_feats = norm_adj @ (embeds_normalized + positional_embeddings)

In [26]:
# Move the inputs to the GPU when available, build an (untrained) VAE, and
# build the ground-truth neighbour map used as the BCE reconstruction target.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
embeds_normalized = embeds_normalized.to(device)
all_neighbor_feats = all_neighbor_feats.to(device)

# The decoder's neighbour map must have one column per node so it lines up
# with the (num_nodes, num_nodes) target below; derive the width from the data
# instead of hard-coding Cora's 2708.
assert seq_len == adj.shape[0], "embedding rows and adjacency size disagree"
vae_model = node_vae(feat_dim=emb_dim, hidden_dim=256, reparam_dim=128, latent_dim=64, seq_len=seq_len).to(device)

# Target = adjacency plus self-loops. Clamp to 1 so every entry stays a valid
# binary-cross-entropy label even if adj already has diagonal entries.
neighbor_map_gt = adj + sp.eye(adj.shape[0])
neighbor_map_gt = torch.FloatTensor(neighbor_map_gt.todense()).clamp(max=1.0).to(device)
print(neighbor_map_gt.shape)
print(neighbor_map_gt)


torch.Size([2708, 2708])
tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 1., 1.,  ..., 0., 0., 0.],
        [0., 1., 1.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 1., 0., 0.],
        [0., 0., 0.,  ..., 0., 1., 1.],
        [0., 0., 0.,  ..., 0., 1., 1.]], device='cuda:0')


In [27]:
# Sanity check: input shapes/ranges, one forward pass through the untrained
# VAE, output shapes/ranges, then the two reconstruction losses.
print(embeds_normalized.shape, all_neighbor_feats.shape)
for _tensor in (embeds_normalized, all_neighbor_feats):
    print(_tensor.max(), _tensor.min())

# NOTE(review): mean/log_var are discarded here, so no KL term is computed yet.
reconstructed_feat, neighbor_map, _, _ = vae_model(embeds_normalized, all_neighbor_feats)

print(reconstructed_feat.shape, neighbor_map.shape)
for _tensor in (reconstructed_feat, neighbor_map):
    print(_tensor.max(), _tensor.min())

# Neighbour-map reconstruction: binary cross-entropy against the 0/1 target.
bce_loss = F.binary_cross_entropy(neighbor_map, neighbor_map_gt, reduction='mean')
# Feature reconstruction: mean squared (L2) error against the normalised input.
l2_loss = F.mse_loss(reconstructed_feat, embeds_normalized, reduction='mean')

torch.Size([2708, 512]) torch.Size([2708, 512])
tensor(1., device='cuda:0') tensor(0., device='cuda:0')
tensor(8.2051, device='cuda:0') tensor(-1.0343, device='cuda:0')
torch.Size([2708, 512]) torch.Size([2708, 512])
tensor(0.6285, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.3816, device='cuda:0', grad_fn=<MinBackward1>)
tensor(0.7435, device='cuda:0', grad_fn=<MaxBackward1>) tensor(0.2333, device='cuda:0', grad_fn=<MinBackward1>)
