In [97]:
from torch import Tensor, tensor, empty, cat, arange, int64
from torch_sparse import SparseTensor
from typing import Optional

from torch import unique

class graphTorch():
    """Graph container that keeps node features and edge lists as torch tensors.

    Attributes:
        node_identifiers: 1-D int64 tensor of node IDs, assigned sequentially
            from an internal counter as nodes are added.
        node_data: maps a node-data key to a tensor whose first dimension
            indexes the nodes.
        edge_data: maps an edge type to an ``edges`` record holding the
            ``U`` (row), ``V`` (col) and optional ``values`` tensors.
    """

    def __init__(self):
        self.node_identifiers = empty(0, dtype=int64)
        self.node_data = {}
        self.edge_data = {}
        self._counter = 0  # next node ID to hand out

    def get_node_IDs(self):
        """Return the 1-D tensor of all node IDs currently in the graph."""
        return self.node_identifiers

    def get_relative_node_indices(self, node_ids: Tensor):
        """Return, for each ID in ``node_ids``, its position in ``node_identifiers``.

        Args:
            node_ids: tensor (or anything ``torch.tensor`` accepts) of node IDs.

        Returns:
            An int64 tensor with one position per requested ID.

        Note:
            An ID that is not present in the graph yields 0, which is
            indistinguishable from the position of the first node.
        """
        # Use the names imported by this module; ``torch`` itself is not imported.
        if not isinstance(node_ids, Tensor):
            node_ids = tensor(node_ids)
        identifiers = self.node_identifiers.unsqueeze(dim=1)
        # mask[i, j] is True where stored identifier i equals requested ID j.
        mask = identifiers == node_ids.unsqueeze(dim=0)
        positions = arange(self.node_identifiers.shape[0]).unsqueeze(dim=1)
        # Each column holds at most one True, so the sum picks its row index.
        return (positions * mask).sum(dim=0)

    def get_node_data(self, node_type: str):
        """Return the data tensor stored under ``node_type``.

        Raises:
            AssertionError: if ``node_type`` is not a known key.
        """
        assert (node_type in self.node_data), "the given key does not exist"
        return self.node_data[node_type]

    def get_edges(self, edge_type: str):
        """Return the edge record stored under ``edge_type``.

        Raises:
            AssertionError: if ``edge_type`` is not a known key.
        """
        assert (edge_type in self.edge_data), "the given key does not exist"
        return self.edge_data[edge_type]

    def add_nodes(self, nodes_data: dict):
        """Append new nodes and their data to the graph.

        The number of new nodes is taken from the first entry of
        ``nodes_data``. Data for an already known key is concatenated to the
        stored tensor; data for a new key is stored as-is and must therefore
        cover every node of the graph (existing plus new).

        Args:
            nodes_data: maps a node-data key to a tensor whose first dimension
                indexes nodes. An empty dict is a no-op.

        Raises:
            AssertionError: if data for a new key does not have one row per
                node. The graph is left unchanged in that case.
        """
        if not nodes_data:
            return
        num_new = len(nodes_data[next(iter(nodes_data))])
        total = self.node_identifiers.shape[0] + num_new
        # Validate before mutating anything so a failure cannot leave the
        # identifiers, the counter and the data tensors out of sync.
        for node_data_type, data in nodes_data.items():
            if node_data_type not in self.node_data and total != 0:
                assert (data.shape[0] == total), "number of nodes must be equal"
        new_node_identifiers = arange(self._counter, self._counter + num_new)
        self.node_identifiers = cat((self.node_identifiers, new_node_identifiers), dim=0)
        self._counter += num_new
        for node_data_type, data in nodes_data.items():
            if node_data_type in self.node_data:
                self.node_data[node_data_type] = cat((self.node_data[node_data_type], data), dim=0)
            else:
                self.node_data[node_data_type] = data

    def add_edges(self, edge_type: str, U: Tensor, V: Tensor, directed: bool, values: Optional[Tensor] = None):
        """Add edges ``U[i] -> V[i]`` of the given type.

        Args:
            edge_type: key under which the edges are stored.
            U: 1-D tensor of source node indices.
            V: 1-D tensor of target node indices, same length as ``U``.
            directed: when False, the reverse edge ``V[i] -> U[i]`` is added
                too, carrying the same value.
            values: optional 1-D tensor with one value per edge.

        Raises:
            ValueError: if ``edge_type`` already holds edges and the new edges
                differ from them in being weighted or unweighted.
        """
        if not directed:
            # Build both tensors from the ORIGINAL U and V; reassigning U first
            # would feed the already-extended U into V.
            U, V = cat((U, V), dim=0), cat((V, U), dim=0)
            if values is not None:
                values = cat((values, values), dim=0)
        N = self.node_identifiers.shape[0]
        if edge_type not in self.edge_data:
            self.edge_data[edge_type] = edges()
            self._assign_processed_edges(edge_type=edge_type, row=U, col=V, N=N, v=values)
            return
        existing = self.edge_data[edge_type]
        if existing.U.numel() != 0:
            if existing.values is not None and values is None:
                raise ValueError("the given edge_type is weighted")
            if existing.values is None and values is not None:
                raise ValueError("the given edge_type is unweighted")
            # Merge with the edges already stored instead of replacing them.
            row = cat((existing.U, U), dim=0)
            col = cat((existing.V, V), dim=0)
            if values is not None:
                values = cat((existing.values, values), dim=0)
        else:
            # Nothing stored yet for this type: the new edges define it.
            row, col = U, V
        self._assign_processed_edges(edge_type, row=row, col=col, N=N, v=values)

    def _assign_processed_edges(self, edge_type: str, row: Tensor, col: Tensor, N: int, v: Optional[Tensor]):
        """Coalesce the given edges and store them under ``edge_type``.

        The edges go through an ``N x N`` ``SparseTensor`` whose ``coalesce()``
        merges duplicate ``(row, col)`` entries; how duplicate values are
        combined is defined by torch_sparse.
        """
        adj = SparseTensor(row=row, col=col, sparse_sizes=(N, N), value=v).coalesce()
        record = self.edge_data[edge_type]
        record.U = adj.storage.row()
        record.V = adj.storage.col()
        # Reset to None for unweighted edges so no stale values survive.
        record.values = adj.storage.value() if v is not None else None

    def delete_edges(self, edge_type, U, V):
        """Remove every stored edge ``(u, v)`` that appears in ``zip(U, V)``.

        Args:
            edge_type: key of the edge set to modify.
            U: 1-D tensor of source node indices of the edges to remove.
            V: 1-D tensor of target node indices, same length as ``U``.

        Raises:
            KeyError: if ``edge_type`` is not a known key.
        """
        record = self.edge_data[edge_type]
        stored_pairs = cat((record.U.unsqueeze(1), record.V.unsqueeze(1)), dim=1)
        doomed_pairs = cat((U.unsqueeze(1), V.unsqueeze(1)), dim=1)
        # matches[i, j] is True when stored edge i equals requested edge j.
        matches = (stored_pairs.unsqueeze(1) == doomed_pairs).all(-1)
        keep = ~matches.any(-1)
        record.U = record.U[keep]
        record.V = record.V[keep]
        if record.values is not None:
            record.values = record.values[keep]

    # TODO: delete_nodes(self, node_type) is not implemented yet.
        
        
class edges:
    """Record holding one edge set as parallel index tensors.

    Attributes are plain instance fields:
        U: 1-D tensor of source (row) node indices.
        V: 1-D tensor of target (col) node indices.
        values: optional tensor of per-edge values, ``None`` when unweighted.
    """

    def __init__(self):
        # Indices are integers everywhere else in this file (see the int64
        # node identifiers), so the empty placeholders use int64 as well
        # instead of the float32 default of ``empty``.
        self.U = empty(0, dtype=int64)
        self.V = empty(0, dtype=int64)
        self.values = None
        

# Demo: build a five-node graph, add weighted directed "w" edges, delete three
# of them and print what is left.
dg = graphTorch()

from torch import rand
dg.add_nodes({"w": rand(5, 6, 6), "s": rand(5, 4, 4)})
# A second batch of nodes could be appended the same way, e.g.:
# dg.add_nodes({"w": torch.rand(2,6,6), "s": torch.rand(2,4,4)})
dg.add_edges(
    edge_type="w",
    U=tensor([1, 2, 1, 4, 2, 0]),
    V=tensor([2, 2, 2, 4, 3, 1]),
    directed=True,
    values=tensor([2, 3, 9, 10, 0, 2]),
)

dg.delete_edges("w", U=tensor([4, 1, 2]), V=tensor([4, 2, 2]))
# Print the remaining sources, targets and values, space-separated on one line.
print(*(getattr(dg.edge_data["w"], field) for field in ("U", "V", "values")))

tensor([0, 2]) tensor([1, 3]) tensor([2, 0])


In [77]:
import torch 
# Demo: drop from `a` every row that also occurs in `b`.
a = torch.tensor([[0, 0], [0, 1], [0, 2], [1, 3], [1, 4], [2, 1], [2, 4]])
b = torch.tensor([[0, 1], [0, 2], [1, 4], [2, 4]])
print(a.shape, b.shape)

# Give `a` a broadcast axis: (7, 2) -> (7, 1, 2), so it compares against every row of `b`.
a_exp = a[:, None]

# The element-wise comparison is (7, 4, 2); two rows match only when all of
# their components are equal, so reduce over the last dim to get (7, 4).
c = (a_exp == b).all(dim=-1)
print(c)
# Row i of `c` compares the i-th row of `a` against every row of `b`;
# an all-False row means that row of `a` is not present in `b`.

non_repeat_mask = torch.logical_not(c.any(dim=-1))

# Keep only the rows of `a` that are absent from `b`.
print(a[non_repeat_mask])

torch.Size([7, 2]) torch.Size([4, 2])
tensor([[False, False, False, False],
        [ True, False, False, False],
        [False,  True, False, False],
        [False, False, False, False],
        [False, False,  True, False],
        [False, False, False, False],
        [False, False, False,  True]])
tensor([[0, 0],
        [1, 3],
        [2, 1]])
