In [68]:
import torch
import torch.nn as nn


#Euclidean distance
def pairwise_euclidean_distances(x, dim=-1):
    """Squared Euclidean distance between every pair of rows of ``x``.

    Args:
        x: Tensor accepted by ``torch.cdist`` (the last axis is the feature axis).
        dim: Accepted for signature parity with the other distance functions;
            it is not used here.

    Returns:
        A tuple ``(dist, x)``: the squared pairwise distances and the input
        tensor itself, unchanged.
    """
    squared = torch.cdist(x, x).pow(2)
    return squared, x

def pairwise_poincare_distances(x, dim=-1):
    """Squared Poincare-ball distance between every pair of rows of ``x``.

    The points are first pulled strictly inside the unit ball: each point is
    divided by ``max(norm, 1) * 1.01``, so points outside the ball are
    rescaled and points inside are only shrunk by the 1% margin.

    Args:
        x: Tensor of points; ``dim`` is the feature axis.
        dim: Axis over which squared norms are summed. The pairwise term
            transposes the last two axes, so it lines up with the default -1.

    Returns:
        A tuple ``(dist, x)``: the squared pairwise distances and the
        projected points that were used to compute them.
    """
    sq_norm = x.pow(2).sum(dim, keepdim=True)
    # max(norm, 1): leaves in-ball points alone, normalises the others.
    clipped = (sq_norm.sqrt() - 1).relu() + 1
    x = x / (clipped * (1 + 1e-2))
    sq_norm = x.pow(2).sum(dim, keepdim=True)

    sq_euclid = torch.cdist(x, x).pow(2)
    conformal = (1 - sq_norm) * (1 - sq_norm.transpose(-1, -2))
    # The 1e-6 keeps the arccosh argument strictly above 1 on the diagonal.
    dist = torch.arccosh(1e-6 + 1 + 2 * sq_euclid / conformal).pow(2)
    return dist, x

def sparse_eye(size):
    """
    Returns the identity matrix as a sparse matrix

    Args:
        size: Number of rows (and columns) of the identity matrix.

    Returns:
        A ``size`` x ``size`` float32 sparse COO tensor with ones on the
        diagonal.
    """
    diag = torch.arange(size, dtype=torch.long)
    indices = torch.stack((diag, diag))
    values = torch.ones(size, dtype=torch.float32)
    # torch.sparse_coo_tensor replaces the legacy torch.sparse.FloatTensor
    # constructor that used to be looked up through the values' type string.
    return torch.sparse_coo_tensor(indices, values, (size, size))



class DGM_d(nn.Module):
    """Samples k neighbours per node from pairwise distances of embedded features.

    ``forward`` embeds the node features with ``embed_f``, computes a pairwise
    distance matrix with the configured metric, scales it by a learnable
    temperature, perturbs it with ``-log(-log(u))`` noise (``u`` uniform in
    (0, 1)) and keeps, for every node, the ``k`` candidates with the smallest
    perturbed value.
    """

    def __init__(self, embed_f, k=5, distance=pairwise_euclidean_distances, sparse=True):
        """
        Args:
            embed_f: Callable ``embed_f(x, A)`` producing the node embeddings.
            k: Number of neighbours sampled for every node.
            distance: Either a callable ``distance(x) -> (dist, x)`` that is
                used as-is, or a string: ``'euclidean'`` selects
                ``pairwise_euclidean_distances``, any other string selects
                ``pairwise_poincare_distances``.
            sparse: Only read by ``sample_without_replacement``; when True the
                edges of all batch elements are flattened into one
                ``(2, num_edges)`` index tensor.
        """
        super(DGM_d, self).__init__()

        self.sparse = sparse

        # The temperature starts at 1 only when the metric is requested with
        # the string "hyperbolic"; every other choice starts at 4.
        is_hyperbolic = isinstance(distance, str) and distance == "hyperbolic"
        self.temperature = nn.Parameter(torch.tensor(1. if is_hyperbolic else 4.).float())
        self.embed_f = embed_f
        self.centroid = None
        self.scale = None
        self.k = k

        self.debug = False
        if callable(distance):
            # A function (including the default argument) is honoured directly
            # instead of falling through to the Poincare branch.
            self.distance = distance
        elif distance == 'euclidean':
            self.distance = pairwise_euclidean_distances
        else:
            self.distance = pairwise_poincare_distances

    def forward(self, x, A, batch=None, not_used=None, fixedges=None):
        """Embed the nodes and sample a graph from their pairwise distances.

        Args:
            x: Node features, forwarded to ``embed_f``.
            A: Graph structure, forwarded to ``embed_f``.
            batch: Optional 1-D tensor with the graph id of every node. When
                given, neighbours are sampled only within the same graph and
                the distance matrix is expected to be 2-D. When omitted, the
                distance matrix must be 3-D ``(b, n, n)``.
            not_used: Accepted but ignored.
            fixedges: Accepted but ignored.

        Returns:
            ``(x, edges_hat, logprobs)``: the embedded features, the sampled
            edge indices and the perturbed negative scaled distances of the
            sampled edges.
        """
        x = self.embed_f(x, A)

        # Outside training the sampling runs without gradient tracking; in
        # training the ambient grad mode is left exactly as the caller set it.
        with torch.set_grad_enabled(torch.is_grad_enabled() and self.training):
            D, _x = self.distance(x)
            if batch is None:
                edges_hat, logprobs = self.sample_without_replacement(D)
            else:
                edges_hat, logprobs = self.sample_without_replacement_batch(D, batch)

        if self.debug:
            # Keep the intermediates around for inspection.
            self.D = D
            self.edges_hat = edges_hat
            self.logprobs = logprobs
            self.x = x

        return x, edges_hat, logprobs

    def sample_without_replacement(self, logits):
        """Sample ``k`` neighbours per node from a batched distance tensor.

        Args:
            logits: Tensor of shape ``(b, n, n)`` with pairwise distances.

        Returns:
            ``(edges, logprobs)``. ``logprobs`` has shape ``(b, n, k)``.
            ``edges`` holds the sampled neighbour in its first row and the
            sampling node in its second row; it has shape ``(2, b*n*k)`` with
            node ids offset by ``n`` per batch element when ``self.sparse`` is
            True, and shape ``(b, 2, n*k)`` otherwise.
        """
        b, n, _ = logits.shape
        logits = logits * torch.exp(torch.clamp(self.temperature, -5, 5))

        # The 1e-8 keeps log(q) finite when the uniform sample is exactly 0.
        q = torch.rand_like(logits) + 1e-8
        lq = logits - torch.log(-torch.log(q))
        # Largest of -lq == smallest perturbed scaled distance.
        logprobs, indices = torch.topk(-lq, self.k)

        rows = torch.arange(n, device=logits.device).view(1, n, 1).repeat(b, 1, self.k)
        edges = torch.stack((indices.view(b, -1), rows.view(b, -1)), -2)

        if self.sparse:
            offsets = (torch.arange(b, device=logits.device) * n)[:, None, None]
            return (edges + offsets).transpose(0, 1).reshape(2, -1), logprobs
        return edges, logprobs

    def sample_without_replacement_batch(self, logits, batch):
        """
        Sample ``k`` neighbours per node, restricted to nodes of the same graph.

        Args:
            logits: Tensor of shape (total_nodes, total_nodes) with the pairwise
                distances between all nodes of the batch.
            batch: Tensor of shape (total_nodes,) indicating which graph each
                node belongs to.
        Returns:
            edges: Tensor of shape (2, total_nodes * k) with global node ids.
                As in ``sample_without_replacement``, the first row is the
                sampled neighbour and the second row is the sampling node.
            logprobs: Tensor of shape (total_nodes, k) with the perturbed
                negative scaled distances of the sampled edges (same sign
                convention as ``sample_without_replacement``). Rows are grouped
                graph by graph, in the order of ``batch.unique(sorted=True)``.

        Every graph needs at least ``k`` nodes, otherwise ``torch.topk`` fails.
        """
        device = logits.device

        logits = logits * torch.exp(torch.clamp(self.temperature, -5, 5))
        # The 1e-8 keeps log(q) finite when the uniform sample is exactly 0.
        q = torch.rand_like(logits) + 1e-8
        lq = logits - torch.log(-torch.log(q))

        edges_list = []
        logprobs_list = []
        for graph_id in batch.unique(sorted=True):
            # Global ids of the nodes that belong to this graph.
            node_ids = (batch == graph_id).nonzero(as_tuple=True)[0]
            lq_i = lq[node_ids][:, node_ids]

            # Largest of -lq == smallest perturbed scaled distance, i.e. the
            # most similar candidates; the values keep the -lq sign.
            logprobs, indices = torch.topk(-lq_i, self.k)

            rows = torch.arange(node_ids.numel(), device=device).view(-1, 1).expand_as(indices)
            local_edges = torch.stack((indices.reshape(-1), rows.reshape(-1)))

            # Map per-graph positions back to global node ids.
            edges_list.append(node_ids[local_edges])
            logprobs_list.append(logprobs)

        if not edges_list:
            # Empty batch: nothing to sample from.
            return (torch.empty((2, 0), dtype=torch.long, device=device),
                    logits.new_empty((0, self.k)))

        all_edges = torch.cat(edges_list, dim=1)
        all_logprobs = torch.cat(logprobs_list, dim=0)

        return all_edges, all_logprobs



class SimpleEmbedder(nn.Module):
    """Identity embedding: hands the node features back untouched."""

    def forward(self, x, A):
        """Return ``x`` as-is; ``A`` is accepted but not used."""
        return x

# Seed the RNG so the synthetic example gives the same result on every
# Restart & Run All.
torch.manual_seed(0)

# Identity embedding, Euclidean metric, 3 sampled neighbours per node
embed_f = SimpleEmbedder()
model = DGM_d(embed_f=embed_f, k=3, distance='euclidean', sparse=True)

# Synthetic data
batch_size = 4
nodes_per_graph = 10
total_nodes = batch_size * nodes_per_graph

# Random matrix of shape (total_nodes, total_nodes); nothing in this cell consumes it
logits = torch.rand(total_nodes, total_nodes)

# Batch tensor indicating which graph each node belongs to:
# nodes_per_graph zeros, then nodes_per_graph ones, and so on
batch = torch.arange(batch_size).repeat_interleave(nodes_per_graph)


In [67]:

# Simulate an adjacency matrix A and input node features x.
# SimpleEmbedder ignores A and returns x unchanged, so the sampled graph is
# driven by the pairwise distances of x only.
A = torch.rand(total_nodes, total_nodes)
x = torch.rand(total_nodes, 10)  # Assume node features of size 10

# Forward pass
model.eval()  # Set the model to evaluation mode
with torch.no_grad():
    x_out, edges_hat, logprobs = model(x, A, batch=batch)

# Sanity checks on the sampled graph
assert edges_hat.shape == (2, total_nodes * model.k)
assert logprobs.shape == (total_nodes, model.k)
assert torch.equal(batch[edges_hat[0]], batch[edges_hat[1]]), "edge connects two different graphs"

# # Print the results
# print("Sampled Edges (without replacement):")
# print(edges_hat)

# print("\nLog Probabilities of Sampled Edges:")
# print(logprobs)


10
tensor([[-3.9177e-01,  1.5277e+02,  1.1645e+02,  1.2249e+02,  8.1234e+01,
          9.7897e+01,  1.3260e+02,  1.1001e+02,  1.2138e+02,  1.8217e+02],
        [ 1.5211e+02,  4.3935e-01,  8.2124e+01,  6.7404e+01,  4.7698e+01,
          1.8141e+02,  1.2856e+02,  5.7914e+01,  7.1609e+01,  2.0694e+02],
        [ 1.1629e+02,  8.3232e+01, -4.3970e-01,  7.6239e+01,  6.7829e+01,
          1.3420e+02,  1.6918e+02,  9.1843e+01,  7.2711e+01,  9.8179e+01],
        [ 1.2218e+02,  7.0103e+01,  7.9394e+01, -1.3657e+00,  6.4822e+01,
          1.1010e+02,  5.9941e+01,  9.5854e+01,  7.9720e+01,  9.0749e+01],
        [ 7.8756e+01,  4.8422e+01,  6.8959e+01,  6.5910e+01,  8.6457e-01,
          1.4200e+02,  1.3343e+02,  7.3247e+01,  9.6588e+01,  1.9066e+02],
        [ 9.9821e+01,  1.8120e+02,  1.3423e+02,  1.1184e+02,  1.4544e+02,
          1.9095e+00,  9.0962e+01,  1.1977e+02,  1.2075e+02,  8.9934e+01],
        [ 1.3178e+02,  1.2971e+02,  1.6856e+02,  6.2633e+01,  1.3528e+02,
          9.1127e+01, -1.5554