In [23]:
import torch 

# make random matrix with dim 1000x162
x = torch.rand(1000, 162).to(torch.float32).to('cuda').requires_grad_()
x

tensor([[0.3849, 0.2128, 0.3324,  ..., 0.6313, 0.0952, 0.8892],
        [0.2804, 0.3638, 0.2067,  ..., 0.3791, 0.6019, 0.0328],
        [0.3549, 0.0113, 0.6764,  ..., 0.0233, 0.1334, 0.6252],
        ...,
        [0.5554, 0.5693, 0.6836,  ..., 0.3386, 0.4260, 0.5500],
        [0.6957, 0.6013, 0.7092,  ..., 0.3077, 0.5732, 0.1856],
        [0.0157, 0.8280, 0.3015,  ..., 0.6044, 0.2103, 0.9776]],
       device='cuda:0', requires_grad=True)

In [35]:
%%timeit 
e = torch.mm(x,x.T.detach())
e = e.sum()
e.backward()

249 µs ± 33.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [38]:
%%timeit 
e = torch.mm(x,x.T.detach())
e = e.sum()
e.backward()

347 µs ± 4.25 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [33]:
e = torch.mm(x,x.T)
e = e.sum()
e.backward()

In [34]:
x.grad

tensor([[1001.9821,  997.3248, 1029.8417,  ...,  992.7962,  996.6891,
          985.8806],
        [1001.9821,  997.3248, 1029.8417,  ...,  992.7962,  996.6891,
          985.8806],
        [1001.9821,  997.3248, 1029.8417,  ...,  992.7962,  996.6891,
          985.8806],
        ...,
        [1001.9821,  997.3248, 1029.8417,  ...,  992.7962,  996.6891,
          985.8806],
        [1001.9821,  997.3248, 1029.8417,  ...,  992.7962,  996.6891,
          985.8806],
        [1001.9821,  997.3248, 1029.8417,  ...,  992.7962,  996.6891,
          985.8806]], device='cuda:0')

In [31]:
x.grad

tensor([[500.9910, 498.6624, 514.9208,  ..., 496.3981, 498.3446, 492.9403],
        [500.9910, 498.6624, 514.9208,  ..., 496.3981, 498.3446, 492.9403],
        [500.9910, 498.6624, 514.9208,  ..., 496.3981, 498.3446, 492.9403],
        ...,
        [500.9910, 498.6624, 514.9208,  ..., 496.3981, 498.3446, 492.9403],
        [500.9910, 498.6624, 514.9208,  ..., 496.3981, 498.3446, 492.9403],
        [500.9910, 498.6624, 514.9208,  ..., 496.3981, 498.3446, 492.9403]],
       device='cuda:0')

In [32]:
x.grad.data.zero_()

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0')

In [14]:
e = torch.mm(x,x.T)

In [18]:
%%timeit 
e = torch.mm(x,x.T)
e = e.sum()
e.backward()

351 µs ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
# triangular matrix covering 

In [198]:
import torch 
import torch.nn.functional as F
from torch_geometric.utils import dense_to_sparse

class GraphStructureLearner(torch.nn.Module):
    def __init__(self, input_channels, num_heads, attention_threshold_epsilon, is_heterogeneous=False, metric_type='weighted_cosine'):
        super().__init__()
        self.num_heads = num_heads
        self.attention_threshold_epsilon = attention_threshold_epsilon
        # weighted cosine
        # (h, 1, dim)
        if metric_type == 'weighted_cosine':
            self.weight_tensor = torch.nn.Parameter(torch.Tensor(num_heads, input_channels)).unsqueeze(1)
            self.weight_tensor = torch.nn.Parameter(torch.nn.init.xavier_uniform_(self.weight_tensor))
            # initialize weights
            
            self.forward = self.forward_weighted_cosine
            
    
    # non heterogeneous case
    def forward_weighted_cosine(self, X, batch, has_converged)->int:
        #  batch is the graph number for each node (the graph the node belongs to)
        
        X = X.unsqueeze(0) 
        
        # hadamard product by broadcasting
        # (h, 1, dim) * (1, n, dim) -> (h, n, dim)
        X = torch.multiply(self.weight_tensor, X)
        
        # get the minibatch indices (star1, end1), (start2, end2), ... for each graph in the batch
        changed = (batch.roll(-1) - batch).roll(1)
        changed[0] += batch.max()
        start = torch.cat([torch.tensor([0]).to(changed.device),torch.where(changed)[0], torch.tensor([batch.shape[-1]]).to(changed.device)])
        end= start.roll(-1)
        indices = torch.stack([start, end], dim=1)[:-1]

        # compute attention for minibatch, save minibatch graphs together in one sparse tensor
        edge_indices, edge_weights = [], []
        
        
        Xnorm = F.normalize(X, p=2, dim=-1) # all values 0 precision if not converted to float64
        
        last_index = 0
        not_converged = torch.argwhere(~has_converged).squeeze(1)
        for i, (start,end) in enumerate(indices):
            if i in not_converged:
                E = Xnorm[:,start:end,:]
                attention = torch.bmm(E, E.transpose(1,2))
                attention = torch.mean(attention, dim=0)
                # remove negative/small values under threshold epsilon
                # attention = torch.where(attention > self.attention_threshold_epsilon, attention, torch.tensor(0.0))
                mask = attention > self.attention_threshold_epsilon
                # attention = attention[~torch.argwhere(mask)] = 0  # not needed, we only select the others for backprop anyways
                
                edge_index = torch.argwhere(mask).T #  indices of non-zero values
                edge_weight = attention.masked_select(mask)
                edge_index += last_index
                edge_indices.append(edge_index)
                edge_weights.append(edge_weight)
                
            last_index += (end - start)  # or simpler ...
            

        attention_edge_indices = torch.cat(edge_indices, dim=1)
        attention_edge_weights = torch.cat(edge_weights, dim=0).unsqueeze(1)
        
        return attention_edge_indices, attention_edge_weights
    
    def forward(self, X, batch, has_converged)->int:
        return self.forward_weighted_cosine(X, batch, has_converged)
    
 

In [201]:
x = torch.rand(8000, 128).to(torch.float32).to('cuda')

# make 12 batches so 0,0,0 .... 1,... 11,11,11
batch = torch.arange(8000).to('cuda') // 666
# false foe each batch
has_converged = torch.zeros(12).to('cuda').bool()


In [202]:
structure_learner = GraphStructureLearner(128, 4, 0.1).cuda()
x1 = x.clone()
x1.requires_grad_(True)

tensor([[0.0986, 0.3072, 0.7209,  ..., 0.6020, 0.9548, 0.5166],
        [0.8276, 0.5639, 0.9150,  ..., 0.4976, 0.5375, 0.1376],
        [0.0347, 0.9868, 0.3028,  ..., 0.4488, 0.0330, 0.6578],
        ...,
        [0.2962, 0.8004, 0.6789,  ..., 0.8754, 0.9087, 0.9244],
        [0.1825, 0.0325, 0.8129,  ..., 0.8835, 0.2323, 0.0985],
        [0.4844, 0.7780, 0.6184,  ..., 0.9413, 0.7000, 0.5545]],
       device='cuda:0', requires_grad=True)

In [143]:
x1.grad

In [203]:
%%timeit 
torch.cuda.synchronize()
_,e = structure_learner(x1, batch, has_converged)
e = e.sum()
e.backward()
torch.cuda.synchronize()

21.2 ms ± 93 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [167]:
import torch 
import torch.nn.functional as F
from torch_geometric.utils import dense_to_sparse
import time 

class GraphStructureLearnerOptimtime(torch.nn.Module):
    def __init__(self, input_channels, num_heads, attention_threshold_epsilon, is_heterogeneous=False, metric_type='weighted_cosine'):
        super().__init__()
        self.num_heads = num_heads
        self.attention_threshold_epsilon = attention_threshold_epsilon
        # weighted cosine
        # (h, 1, dim)
        if metric_type == 'weighted_cosine':
            self.weight_tensor = torch.nn.Parameter(torch.Tensor(num_heads, input_channels)).unsqueeze(1)
            self.weight_tensor = torch.nn.Parameter(torch.nn.init.xavier_uniform_(self.weight_tensor))
            # initialize weights
            
            self.forward = self.forward_weighted_cosine
            
    
    # non heterogeneous case
    def forward_weighted_cosine(self, X, batch, has_converged)->int:
        #  batch is the graph number for each node (the graph the node belongs to)
        
        torch.cuda.synchronize()
        time1 = time.time()
        X = X.unsqueeze(0) 
        
        # hadamard product by broadcasting
        # (h, 1, dim) * (1, n, dim) -> (h, n, dim)
        X = torch.multiply(self.weight_tensor, X)
        
        # get the minibatch indices (star1, end1), (start2, end2), ... for each graph in the batch
        changed = (batch.roll(-1) - batch).roll(1)
        changed[0] += batch.max()
        start = torch.cat([torch.tensor([0]).to(changed.device),torch.where(changed)[0], torch.tensor([batch.shape[-1]]).to(changed.device)])
        end= start.roll(-1)
        indices = torch.stack([start, end], dim=1)[:-1]
        torch.cuda.synchronize()
        time2 = time.time()
        # compute attention for minibatch, save minibatch graphs together in one sparse tensor
        edge_indices, edge_weights = [], []
        
        
        Xnorm = F.normalize(X, p=2, dim=-1) # all values 0 precision if not converted to float64
        torch.cuda.synchronize()
        time3 = time.time()
        # Xnorm_detached = Xnorm.detach()
        last_index = 0
        not_converged = torch.argwhere(~has_converged).squeeze(1)
        for i, (start,end) in enumerate(indices):
            if i in not_converged:
                E = Xnorm[:,start:end,:]
                # Et = Xnorm_detached[:,start:end,:]
                attention = torch.bmm(E, E.transpose(1,2).detach())
                attention = torch.mean(attention, dim=0)
                # remove negative/small values under threshold epsilon
                # attention = torch.where(attention > self.attention_threshold_epsilon, attention, torch.tensor(0.0))
                mask = attention > self.attention_threshold_epsilon
                # attention = attention[~torch.argwhere(mask)] = 0  # not needed, we only select the others for backprop anyways
                
                edge_index = torch.argwhere(mask).T #  indices of non-zero values
                edge_weight = attention.masked_select(mask)
                edge_index += last_index
                edge_indices.append(edge_index)
                edge_weights.append(edge_weight)
                
            last_index += (end - start)  # or simpler ...
        
        torch.cuda.synchronize()
        timea = time.time()
        for i, (start,end) in enumerate(indices):
            if i in not_converged:
                E = Xnorm[:,start:end,:]
                # Et = Xnorm_detached[:,start:end,:]
                attention = torch.bmm(E, E.transpose(1,2).detach())
        torch.cuda.synchronize()
        timeb = time.time()
        weights1 = []
        weights2 = []
        for i, (start,end) in enumerate(indices):
            if i in not_converged:
                mask = attention > self.attention_threshold_epsilon
                edge_index = torch.argwhere(mask).T #  indices of non-zero values
                edge_weight = attention.masked_select(mask)
                # edge_index += last_index
                # weights1.append(edge_index)
                # weights2.append(edge_weight)
        torch.cuda.synchronize()
        timec = time.time()
            
        torch.cuda.synchronize()
        time4 = time.time()
        attention_edge_indices = torch.cat(edge_indices, dim=1)
        attention_edge_weights = torch.cat(edge_weights, dim=0).unsqueeze(1)
        
        torch.cuda.synchronize()
        time5 = time.time()
        
        return attention_edge_indices, attention_edge_weights, time2-time1, time3-time2, time4-time3, time5-time4, timeb-timea, timec-timeb
    
    def forward(self, X, batch, has_converged)->int:
        return self.forward_weighted_cosine(X, batch, has_converged)
    
 

In [None]:
#

In [204]:
import torch 
import torch.nn.functional as F
from torch_geometric.utils import dense_to_sparse
import time 

class GraphStructureLearnerOptim(torch.nn.Module):
    def __init__(self, input_channels, num_heads, attention_threshold_epsilon, is_heterogeneous=False, metric_type='weighted_cosine'):
        super().__init__()
        self.num_heads = num_heads
        self.attention_threshold_epsilon = attention_threshold_epsilon
        # weighted cosine
        # (h, 1, dim)
        if metric_type == 'weighted_cosine':
            self.weight_tensor = torch.nn.Parameter(torch.Tensor(num_heads, input_channels)).unsqueeze(1)
            self.weight_tensor = torch.nn.Parameter(torch.nn.init.xavier_uniform_(self.weight_tensor))
            # initialize weights
            
            self.forward = self.forward_weighted_cosine
            
    
    # non heterogeneous case
    def forward_weighted_cosine(self, X, batch, has_converged)->int:
        #  batch is the graph number for each node (the graph the node belongs to)
        
        X = X.unsqueeze(0) 
        
        # hadamard product by broadcasting
        # (h, 1, dim) * (1, n, dim) -> (h, n, dim)
        X = torch.multiply(self.weight_tensor, X)
        
        # get the minibatch indices (star1, end1), (start2, end2), ... for each graph in the batch
        changed = (batch.roll(-1) - batch).roll(1)
        changed[0] += batch.max()
        start = torch.cat([torch.tensor([0]).to(changed.device),torch.where(changed)[0], torch.tensor([batch.shape[-1]]).to(changed.device)])
        end= start.roll(-1)
        indices = torch.stack([start, end], dim=1)[:-1]
        # compute attention for minibatch, save minibatch graphs together in one sparse tensor
        edge_indices, edge_weights = [], []
        
        
        Xnorm = F.normalize(X, p=2, dim=-1) # all values 0 precision if not converted to float64
        # Xnorm_detached = Xnorm.detach()
        last_index = 0
        not_converged = torch.argwhere(~has_converged).squeeze(1)
        for i, (start,end) in enumerate(indices):
            if i in not_converged:
                E = Xnorm[:,start:end,:]
                # Et = Xnorm_detached[:,start:end,:]
                attention = torch.bmm(E, E.transpose(1,2).detach())
                attention = torch.mean(attention, dim=0)
                # remove negative/small values under threshold epsilon
                # attention = torch.where(attention > self.attention_threshold_epsilon, attention, torch.tensor(0.0))
                mask = attention > self.attention_threshold_epsilon
                # attention = attention[~torch.argwhere(mask)] = 0  # not needed, we only select the others for backprop anyways
                
                edge_index = torch.argwhere(mask).T #  indices of non-zero values
                edge_weight = attention.masked_select(mask)
                edge_index += last_index
                edge_indices.append(edge_index)
                edge_weights.append(edge_weight)
                
            last_index += (end - start)  # or simpler ...
            
        attention_edge_indices = torch.cat(edge_indices, dim=1)
        attention_edge_weights = torch.cat(edge_weights, dim=0).unsqueeze(1)
        
        return attention_edge_indices, attention_edge_weights
    
    def forward(self, X, batch, has_converged)->int:
        return self.forward_weighted_cosine(X, batch, has_converged)
    
 

In [213]:
import torch 
import torch.nn.functional as F
from torch_geometric.utils import dense_to_sparse
import time 

class GraphStructureLearnerOptim2(torch.nn.Module):
    def __init__(self, input_channels, num_heads, attention_threshold_epsilon, is_heterogeneous=False, metric_type='weighted_cosine'):
        super().__init__()
        self.num_heads = num_heads
        self.attention_threshold_epsilon = attention_threshold_epsilon
        # weighted cosine
        # (h, 1, dim)
        if metric_type == 'weighted_cosine':
            self.weight_tensor = torch.nn.Parameter(torch.Tensor(num_heads, input_channels)).unsqueeze(1)
            self.weight_tensor = torch.nn.Parameter(torch.nn.init.xavier_uniform_(self.weight_tensor))
            # initialize weights
            
            self.forward = self.forward_weighted_cosine
            
    
    # non heterogeneous case
    def forward_weighted_cosine(self, X, batch, has_converged)->int:
        #  batch is the graph number for each node (the graph the node belongs to)
        
        X = X.unsqueeze(0) 
        
        # hadamard product by broadcasting
        # (h, 1, dim) * (1, n, dim) -> (h, n, dim)
        X = torch.multiply(self.weight_tensor, X)
        
        # get the minibatch indices (star1, end1), (start2, end2), ... for each graph in the batch
        Xnorm = F.normalize(X, p=2, dim=-1) # all values 0 precision if not converted to float64
        attention = torch.bmm(Xnorm, Xnorm.transpose(1,2).detach())
        attention = torch.mean(attention, dim=0)
        mask = attention > self.attention_threshold_epsilon
        attention_edge_indices = torch.argwhere(mask).T #  indices of non-zero values
        attention_edge_weights = attention.masked_select(mask)
        
        return attention_edge_indices, attention_edge_weights
    
    def forward(self, X, batch, has_converged)->int:
        return self.forward_weighted_cosine(X, batch, has_converged)
    
 

In [205]:
structure_learner2 = GraphStructureLearnerOptim(128, 4, 0.1).cuda()
x2 = x.clone()
x2.requires_grad_(True)

tensor([[0.0986, 0.3072, 0.7209,  ..., 0.6020, 0.9548, 0.5166],
        [0.8276, 0.5639, 0.9150,  ..., 0.4976, 0.5375, 0.1376],
        [0.0347, 0.9868, 0.3028,  ..., 0.4488, 0.0330, 0.6578],
        ...,
        [0.2962, 0.8004, 0.6789,  ..., 0.8754, 0.9087, 0.9244],
        [0.1825, 0.0325, 0.8129,  ..., 0.8835, 0.2323, 0.0985],
        [0.4844, 0.7780, 0.6184,  ..., 0.9413, 0.7000, 0.5545]],
       device='cuda:0', requires_grad=True)

In [206]:

%%timeit
torch.cuda.synchronize()
indices,e = structure_learner2(x2, batch, has_converged)
e = e.sum()
e.backward()

19.4 ms ± 55.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [214]:
structure_learner3 = GraphStructureLearnerOptim2(128, 4, 0.1).cuda()
x3 = x.clone()
x3.requires_grad_(True)

tensor([[0.0986, 0.3072, 0.7209,  ..., 0.6020, 0.9548, 0.5166],
        [0.8276, 0.5639, 0.9150,  ..., 0.4976, 0.5375, 0.1376],
        [0.0347, 0.9868, 0.3028,  ..., 0.4488, 0.0330, 0.6578],
        ...,
        [0.2962, 0.8004, 0.6789,  ..., 0.8754, 0.9087, 0.9244],
        [0.1825, 0.0325, 0.8129,  ..., 0.8835, 0.2323, 0.0985],
        [0.4844, 0.7780, 0.6184,  ..., 0.9413, 0.7000, 0.5545]],
       device='cuda:0', requires_grad=True)

In [217]:
batch

torch.Size([8000])

In [221]:

%%timeit
torch.cuda.synchronize()
for i in range(16):
    indices,e = structure_learner3(x3[i*(8000//16):(i+1)*(8000//16)], batch, has_converged)
    e.sum().backward(retain_graph=True)

15.8 ms ± 576 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [123]:
x2.grad

tensor([[1141.2169, -172.5217,  442.3118,  ..., -308.8889, 1508.8729,
           36.1408],
        [-187.2372,  197.1149,  408.1441,  ...,  492.9110, 1770.6268,
          281.8169],
        [-414.2870,  676.9843, -158.0955,  ...,  -75.8605,  895.7821,
         -110.5858],
        ...,
        [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
            0.0000],
        [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
            0.0000],
        [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
            0.0000]], device='cuda:0')

In [124]:
x1.grad

tensor([[ 398.6671, -235.8796, 2067.4553,  ..., -547.0215, 1166.7192,
           70.0552],
        [ -40.9576,  283.5767, 1944.9091,  ...,  847.5580, 1355.4150,
          609.4676],
        [-220.3365,  978.3788, -837.9225,  ..., -147.2248,  681.3812,
         -265.3523],
        ...,
        [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
            0.0000],
        [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
            0.0000],
        [   0.0000,    0.0000,    0.0000,  ...,    0.0000,    0.0000,
            0.0000]], device='cuda:0')

In [171]:
times1, times2,times3,times4,times5,timesa,timesb = [],[],[],[],[],[],[]
for i in range(1000):
    structure_learner2 = GraphStructureLearnerOptimtime(128, 4, 0.1).cuda()
    x2 = x.clone()
    x2.requires_grad_(True)

    indices,e, time1, time2, time3, time4, timea, timeb= structure_learner2(x2, batch, has_converged)
    times1.append(time1)
    times2.append(time2)
    times3.append(time3)
    times4.append(time4)
    timesa.append(timea)
    timesb.append(timeb)
    
    e = e.sum()
    torch.cuda.synchronize()
    time5 = time.time()
    e.backward()
    torch.cuda.synchronize()
    time5 = time.time() - time5
    times5.append(time5)

print(sum(times1)/len(times1), sum(times2)/len(times2), sum(times3)/len(times3), sum(times4)/len(times4), sum(times5)/len(times5), sum(timesa)/len(timesa), sum(timesb)/len(timesb))

KeyboardInterrupt: 

In [None]:
0.0022389500141143798 0.01698826265335083
0.002236340284347534 0.000594428300857544

0.0003767564296722412 0.00025538945198059083 0.014358247756958008 0.0010652749538421632 0.011078863620758057 0.002239511966705322 0.005343229055404663


In [153]:
0.0003784432411193848 0.00025416707992553713 0.007009508609771728 0.0010624141693115234 0.011068058490753175

SyntaxError: invalid syntax (3830012367.py, line 1)

In [182]:
# make random float tensor 100x100 with weights from -1 to 1 
xa = torch.rand(100, 100).to(torch.float32).to('cuda') * 2 - 1
xa.requires_grad_(True)

tensor([[-0.1505, -0.7693,  0.9857,  ..., -0.1066,  0.4547, -0.6688],
        [ 0.3987,  0.9820,  0.3226,  ...,  0.3678, -0.7398, -0.3567],
        [ 0.5150,  0.9157,  0.0548,  ...,  0.1412,  0.7922,  0.7795],
        ...,
        [ 0.9898, -0.9740, -0.5937,  ..., -0.4817, -0.0546,  0.2182],
        [-0.4179, -0.8088, -0.6399,  ..., -0.0083,  0.6907, -0.8968],
        [ 0.0456,  0.9002, -0.1455,  ...,  0.6903,  0.3910, -0.1203]],
       device='cuda:0', requires_grad=True)

In [184]:
xa = torch.rand(100, 100).to(torch.float32).to('cuda') * 2 - 1
xa.requires_grad_(True)

tensor([[ 0.9222,  0.1429,  0.3722,  ...,  0.0598, -0.6948,  0.3111],
        [ 0.1324, -0.4627,  0.9854,  ..., -0.1995,  0.1942,  0.0766],
        [ 0.4638, -0.3174, -0.8356,  ..., -0.0672,  0.0100, -0.2991],
        ...,
        [ 0.2514,  0.8340, -0.6442,  ..., -0.9268,  0.3634, -0.3279],
        [ 0.7734,  0.5188,  0.6181,  ...,  0.2777, -0.4139, -0.0737],
        [ 0.3979, -0.5447, -0.4936,  ...,  0.1326, -0.1847,  0.2321]],
       device='cuda:0', requires_grad=True)

In [188]:

mask = xa > 0.1
edge_index = torch.argwhere(mask).T #  indices of non-zero values
edge_weight = xa.masked_select(mask)

In [196]:
%%timeit
torch.cuda.synchronize()
mask = xa > 0.1
edge_index = torch.nonzero(mask).T
edge_weight = xa[mask]

138 µs ± 961 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [192]:
edge_index

tensor([[ 0,  0,  0,  ..., 99, 99, 99],
        [ 0,  1,  2,  ..., 96, 97, 99]], device='cuda:0')

In [189]:
edge_index

tensor([[ 0,  0,  0,  ..., 99, 99, 99],
        [ 0,  1,  2,  ..., 96, 97, 99]], device='cuda:0')

In [193]:
edge_weight

tensor([0.9222, 0.1429, 0.3722,  ..., 0.4655, 0.1326, 0.2321], device='cuda:0',
       grad_fn=<IndexBackward0>)

In [190]:
edge_weight

tensor([0.9222, 0.1429, 0.3722,  ..., 0.4655, 0.1326, 0.2321], device='cuda:0',
       grad_fn=<MaskedSelectBackward0>)