# Example of Metric Learning in Embedded Space

In [1]:
%load_ext autoreload
%autoreload 2

# System imports
import os
import sys
import yaml
import logging
logging.basicConfig(level=logging.ERROR)  

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning import Trainer
import torch.nn.functional as F
import frnn
from torch_scatter import scatter_add, scatter_mean, scatter_max

sys.path.append("../../../")

from LightningModules.Embedding.utils import build_edges, build_knn, graph_intersection
from LightningModules.Embedding.Models.layerless_embedding import LayerlessEmbedding
from LightningModules.Embedding.embedding_base import EmbeddingBase
from LightningModules.GNN.utils import make_mlp

# Pick the compute device once: CUDA when torch can see a GPU, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

## Roadmap

- Check that dataset has 1GeV cut
- Run baseline training of embedding: check eff/pur
- Build undirected GravMetric model
- Test run!
- Build and run undirected half-twin: check eff/pur
- Check definition of directed truth
- Build and run directed half-twin: check eff/pur
- If directed works, build and test directed GravMetric

In [2]:
# Load config file
with open("gravmetric.yaml", "r") as f:
    # yaml.load() without an explicit Loader emits YAMLLoadWarning on
    # PyYAML 5.x and raises TypeError on PyYAML >= 6. FullLoader is the
    # loader the bare call used implicitly on 5.x, so the parsed result is
    # unchanged. NOTE(review): prefer yaml.safe_load if the config only
    # contains plain scalars/lists/mappings.
    hparams = yaml.load(f, Loader=yaml.FullLoader)

  hparams = yaml.load(f)


## The Dataset

In [None]:
# Build the baseline embedding model from the loaded hyperparameters and run
# its "fit"-stage setup (the next cells read `model.trainset`).
model = LayerlessEmbedding(hparams)
model.setup(stage="fit")

In [5]:
# Grab the first event of the training set for inspection
sample = model.trainset[0]

In [6]:
# Display the event's Data summary (field names and tensor shapes)
sample

Data(x=[8968, 3], cell_data=[8968, 11], pid=[8968], event_file='/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/full_events_v4/event000013127', hid=[8968], pt=[8968], primary=[8968], nhits=[8968], modulewise_true_edges=[2, 7791], signal_true_edges=[2, 6901])

## Baseline

In [3]:
# Fresh baseline model instance for the reference training run
model = LayerlessEmbedding(hparams)

In [None]:
# Send metrics of the baseline run to Weights & Biases ("InitialTest" group).
logger = WandbLogger(
    project=hparams["project"],
    group="InitialTest",
    save_dir=hparams["artifacts"],
)

# Single-GPU training for the configured number of epochs.
trainer_options = dict(
    gpus=1,
    max_epochs=hparams["max_epochs"],
    logger=logger,
)
trainer = Trainer(**trainer_options)
trainer.fit(model)

## GravMetric Base

In [3]:
class MultiEmbeddingBase(EmbeddingBase):
    """Lightning base module that can scan over different embedding training regimes.

    The concrete model's ``forward`` must return a ``(topo, spatial)`` pair of
    embeddings. Training builds a separate edge set and hinge loss for each of
    the two and sums them; evaluation scores the ``spatial`` embedding only.

    ``get_input_data``, ``get_query_points``, ``append_hnm_pairs``,
    ``append_random_pairs``, ``get_truth``, ``get_true_pairs`` and
    ``get_hinge_distance`` are not defined in this class; presumably they are
    inherited from ``EmbeddingBase``.
    """

    def __init__(self, hparams):
        super().__init__(hparams)
        self.save_hyperparameters(hparams)

    def build_training_set(self, batch, embedding, r_train=None, knn=None):
        """Assemble the labelled edge list used to train one embedding.

        Args:
            batch: Event data; must expose ``signal_true_edges``.
            embedding: Embedded hits the candidate edges are mined from.
            r_train: Forwarded to ``append_hnm_pairs``.
            knn: Forwarded to ``append_hnm_pairs``.

        Returns:
            ``(training_edges, y)``: the edge list and its truth labels, with
            all true pairs appended.
        """

        # Instantiate empty prediction edge list
        training_edges = torch.empty([2, 0], dtype=torch.int64, device=self.device)

        query_indices, query = self.get_query_points(batch, embedding)

        # Append Hard Negative Mining (hnm) with KNN graph
        if "hnm" in self.hparams["regime"]:
            training_edges = self.append_hnm_pairs(
                training_edges, query, query_indices, embedding, r_train, knn
            )

        # Append random edges pairs (rp) for stability
        if "rp" in self.hparams["regime"]:
            training_edges = self.append_random_pairs(
                training_edges, query_indices, embedding
            )

        # Instantiate bidirectional truth (since KNN prediction will be bidirectional)
        e_bidir = torch.cat(
            [batch.signal_true_edges, batch.signal_true_edges.flip(0)], dim=-1
        )

        # Calculate truth from intersection between Prediction graph and Truth graph
        training_edges, y = self.get_truth(batch, training_edges, e_bidir)
        new_weights = y.to(self.device)

        # Append all positive examples and their truth; the returned weights
        # are not used by this training scheme.
        training_edges, y, _ = self.get_true_pairs(
            training_edges, y, new_weights, e_bidir
        )

        return training_edges, y

    @staticmethod
    def _hinge_term(d, hinge, label, margin):
        """Mean hinge-embedding loss over the pairs whose hinge equals ``label``.

        An empty selection contributes an exact zero instead of the NaN that a
        mean over zero elements would produce. The zero is obtained from
        ``d`` itself so it stays attached to the autograd graph.
        """
        selected = hinge == label
        if not selected.any():
            return d[selected].sum()
        return torch.nn.functional.hinge_embedding_loss(
            d[selected],
            hinge[selected],
            margin=margin,
            reduction="mean",
        )

    def get_loss(self, hinge, d, margin=None, weight=1):
        """Combine the negative-pair and positive-pair hinge losses.

        Args:
            hinge: Tensor of +1 (true pair) / -1 (fake pair) labels.
            d: Squared distances of the same pairs.
            margin: Hinge margin; defaults to ``hparams["margin"] ** 2``.
            weight: Multiplier applied to the positive-pair term.

        Returns:
            ``negative_loss + weight * positive_loss`` as a scalar tensor.
        """

        if margin is None:
            margin = self.hparams["margin"] ** 2

        negative_loss = self._hinge_term(d, hinge, -1, margin)
        positive_loss = self._hinge_term(d, hinge, 1, margin)

        return negative_loss + weight * positive_loss

    def training_step(self, batch, batch_idx):

        """
        Args:
            batch (``list``, required): A list of ``torch.tensor`` objects
            batch_idx (``int``, required): The index of the batch

        Returns:
            ``torch.tensor`` The loss function as a tensor
        """

        logging.info(f"Memory at train start: {torch.cuda.max_memory_allocated() / 1024**3} Gb")

        # Forward pass of model, handling whether Cell Information (ci) is included
        input_data = self.get_input_data(batch)

        # Embed hits
        topo, spatial = self(input_data)

        # Build training set (the topo set is mined with its own radius)
        e_spatial, y_spatial = self.build_training_set(batch, spatial)
        e_topo, y_topo = self.build_training_set(
            batch, topo, r_train=self.hparams["topo_margin"]
        )

        # Loss functions
        spatial_hinge, spatial_d = self.get_hinge_distance(spatial, e_spatial, y_spatial)
        topo_hinge, topo_d = self.get_hinge_distance(topo, e_topo, y_topo)

        # Only the spatial loss applies the positive-pair weight
        spatial_loss = self.get_loss(
            spatial_hinge, spatial_d, self.hparams["margin"] ** 2, self.hparams["weight"]
        )
        topo_loss = self.get_loss(topo_hinge, topo_d, self.hparams["topo_margin"] ** 2)
        loss = spatial_loss + topo_loss

        self.log("train_loss", loss)

        return loss

    def shared_evaluation(self, batch, batch_idx, knn_radius, knn_num, log=False):
        """Score the spatial embedding of one event.

        Builds the neighbourhood graph of the spatial embedding within
        ``knn_radius`` (at most ``knn_num`` neighbours), labels it against the
        bidirectional truth and derives loss, efficiency and purity.

        Args:
            batch: Event data.
            batch_idx: Index of the batch (unused here).
            knn_radius: ``r_max`` passed to ``build_edges``.
            knn_num: ``k_max`` passed to ``build_edges``.
            log: When true, log the metrics and current learning rate.

        Returns:
            Dict with keys ``loss``, ``distances``, ``preds``, ``truth`` and
            ``truth_graph``.
        """

        input_data = self.get_input_data(batch)
        topo, spatial = self(input_data)

        e_bidir = torch.cat(
            [batch.signal_true_edges, batch.signal_true_edges.flip(0)], dim=-1
        )

        # Build whole KNN graph
        e_spatial = build_edges(
            spatial, spatial, indices=None, r_max=knn_radius, k_max=knn_num
        )

        e_spatial, y_cluster = self.get_truth(batch, e_spatial, e_bidir)

        hinge, d = self.get_hinge_distance(
            spatial, e_spatial.to(self.device), y_cluster
        )

        loss = torch.nn.functional.hinge_embedding_loss(
            d, hinge, margin=self.hparams["margin"] ** 2, reduction="mean"
        )

        cluster_true = e_bidir.shape[1]
        cluster_true_positive = y_cluster.sum()
        cluster_positive = len(e_spatial[0])

        # Efficiency: true edges recovered / all true edges.
        # Purity: true edges recovered / all predicted edges.
        eff = cluster_true_positive / cluster_true
        pur = cluster_true_positive / cluster_positive
        if "module_veto" in self.hparams["regime"]:
            # Purity counted only over predicted edges that join two different modules
            module_veto_pur = cluster_true_positive / (
                batch.modules[e_spatial[0]] != batch.modules[e_spatial[1]]
            ).sum()
        else:
            module_veto_pur = 0

        if log:
            current_lr = self.optimizers().param_groups[0]["lr"]
            self.log_dict(
                {
                    "val_loss": loss,
                    "eff": eff,
                    "pur": pur,
                    "module_veto_pur": module_veto_pur,
                    "current_lr": current_lr,
                }
            )
        logging.info("Efficiency: {}".format(eff))
        logging.info("Purity: {}".format(pur))
        logging.info(batch.event_file)

        return {
            "loss": loss,
            "distances": d,
            "preds": e_spatial,
            "truth": y_cluster,
            "truth_graph": e_bidir,
        }

## Undirected GravMetric

In [4]:
class UndirectedGravMetric(MultiEmbeddingBase):
    """An implementation of the GravMetric architecture: the most naive version,
    such that input maps are undirected, and there is only topo-space.

    Behaviour is:
    1. topo = map_0(x)
    2. Get neighbourhoods as edge list
    3. Weight start_node features by edge potential
    4. Scatter_mean weighted start_node features at end_nodes
    5. Pass [x, mean_neighborhood] to map_1
    """

    def __init__(self, hparams):
        print("UndirectedGravMetric")
        super().__init__(hparams)

        n_inputs = hparams["spatial_channels"] + hparams["cell_channels"]
        emb_hidden = hparams["emb_hidden"]
        n_layers = hparams["nb_layer"]
        emb_dim = hparams["emb_dim"]

        # Options shared by every MLP of this model
        mlp_options = dict(
            hidden_activation=hparams["activation"],
            output_activation=None,
            layer_norm=True,
        )

        # First map: raw hit features -> topo embedding
        self.map_0 = make_mlp(
            n_inputs, [emb_hidden] * n_layers + [emb_dim], **mlp_options
        )

        # Without a dedicated feature MLP the topo embedding itself is what
        # gets aggregated over the neighbourhood.
        feature_size = emb_dim
        if hparams["feature_hidden"] > 0:
            self.feature_mlp = make_mlp(
                n_inputs,
                [hparams["feature_hidden"]] * n_layers,
                **mlp_options,
            )
            feature_size = hparams["feature_hidden"]

        # Second map: raw hit features + aggregated neighbourhood -> output embedding
        self.map_1 = make_mlp(
            n_inputs + feature_size,
            [emb_hidden] * n_layers + [emb_dim],
            **mlp_options,
        )

        self.save_hyperparameters()

    def forward(self, x):
        """Return ``(topo, spatial)``: the first-stage embedding and the
        neighbourhood-aware second-stage embedding, both L2-normalised."""

        # 1. topo = map_0(x)
        topo = F.normalize(self.map_0(x))

        # 2. Neighbourhoods as an edge list (a radius graph via build_edges
        #    with r_max=topo_margin was the alternative considered here)
        neighbours = build_knn(topo, self.hparams["topo_k"])
        senders, receivers = neighbours[0], neighbours[1]

        # 3. Weight the start-node features by the edge potential
        potentials = self.get_potential(topo, neighbours).unsqueeze(-1)
        if self.hparams["feature_hidden"] > 0:
            node_values = self.feature_mlp(x)
        else:
            node_values = topo
        messages = node_values[senders] * potentials

        # 4. Average the weighted messages at each end node
        pooled = scatter_mean(messages, receivers, dim=0, dim_size=topo.shape[0])

        # 5. Second map on [x, mean neighbourhood]
        spatial = F.normalize(self.map_1(torch.cat([x, pooled], dim=-1)))
        return topo, spatial

    def get_potential(self, x, edges):
        """Edge weight from squared distance: 1 at zero separation, 0 at a
        separation of ``topo_margin``, negative beyond it."""

        separation = x[edges[0]] - x[edges[1]]
        scaled_sq = (separation ** 2).sum(dim=-1) / self.hparams["topo_margin"] ** 2

        return (torch.exp(1 - scaled_sq) - 1) / (np.exp(1) - 1)

## Directed GravMetric

In [3]:
# TODO
# MODEL
# Need multimap M_1a, M_1b, M_2a, M_2b
# Aggregate start and end separately
# Concatenate and pass through MLP

# BASE
# Need multimap M_1a, M_1b, M_2a, M_2b as output

class GravMetricBase(EmbeddingBase):
    """Lightning base module that can scan over different embedding training regimes.

    Directed variant: the concrete model's ``forward`` must return
    ``(topo_a, topo_b, output_a, output_b)``. Edges run from a hit embedded
    with the "a" map to a hit embedded with the "b" map, and truth is taken
    from ``batch.signal_true_edges`` as given (no flipped copy is added).

    ``get_input_data``, ``get_query_points``, ``append_hnm_pairs``,
    ``append_random_pairs``, ``get_truth`` and ``get_true_pairs`` are not
    defined in this class; presumably they are inherited from
    ``EmbeddingBase``.
    """

    def __init__(self, hparams):
        super().__init__(hparams)
        self.save_hyperparameters(hparams)

    def get_hinge_distance(self, spatial1, spatial2, e_spatial, y_cluster):
        """Return hinge labels and squared distances for an edge list.

        Args:
            spatial1: Embedding indexed by the edge start nodes.
            spatial2: Embedding indexed by the edge end nodes.
            e_spatial: Edge list; row 0 holds start, row 1 end indices.
            y_cluster: Truth label per edge (0 for fake edges).

        Returns:
            ``(hinge, d)``: labels with 0 mapped to -1, and the squared
            Euclidean distance of every edge.
        """

        hinge = y_cluster.float().to(self.device)
        hinge[hinge == 0] = -1

        reference = spatial1[e_spatial[0]]
        neighbors = spatial2[e_spatial[1]]
        d = torch.sum((reference - neighbors) ** 2, dim=-1)

        return hinge, d

    def build_training_set(self, batch, embedding_a, embedding_b, r_train=None, knn=None):
        """Assemble the labelled, directed edge list used to train one embedding pair.

        Args:
            batch: Event data; must expose ``signal_true_edges``.
            embedding_a: Embedding the query points are taken from.
            embedding_b: Embedding the candidate neighbours are taken from.
            r_train: Forwarded to ``append_hnm_pairs``.
            knn: Forwarded to ``append_hnm_pairs``.

        Returns:
            ``(training_edges, y)``: the edge list and its truth labels, with
            all true pairs appended.
        """

        # Instantiate empty prediction edge list
        training_edges = torch.empty([2, 0], dtype=torch.int64, device=self.device)

        query_indices, query = self.get_query_points(batch, embedding_a)

        # Append Hard Negative Mining (hnm) with KNN graph
        if "hnm" in self.hparams["regime"]:
            training_edges = self.append_hnm_pairs(
                training_edges, query, query_indices, embedding_b, r_train, knn
            )

        # Append random edges pairs (rp) for stability
        if "rp" in self.hparams["regime"]:
            training_edges = self.append_random_pairs(
                training_edges, query_indices, embedding_b
            )

        # Directed truth: the true edges are used as stored, without adding
        # the reversed copies.
        true_edges = batch.signal_true_edges

        # Calculate truth from intersection between Prediction graph and Truth graph
        training_edges, y = self.get_truth(batch, training_edges, true_edges)
        new_weights = y.to(self.device)

        # Append all positive examples and their truth; the returned weights
        # are not used by this training scheme.
        training_edges, y, _ = self.get_true_pairs(
            training_edges, y, new_weights, true_edges
        )

        return training_edges, y

    @staticmethod
    def _hinge_term(d, hinge, label, margin):
        """Mean hinge-embedding loss over the pairs whose hinge equals ``label``.

        An empty selection contributes an exact zero instead of the NaN that a
        mean over zero elements would produce. The zero is obtained from
        ``d`` itself so it stays attached to the autograd graph.
        """
        selected = hinge == label
        if not selected.any():
            return d[selected].sum()
        return torch.nn.functional.hinge_embedding_loss(
            d[selected],
            hinge[selected],
            margin=margin,
            reduction="mean",
        )

    def get_loss(self, hinge, d, margin=None, weight=1):
        """Combine the negative-pair and positive-pair hinge losses.

        Args:
            hinge: Tensor of +1 (true pair) / -1 (fake pair) labels.
            d: Squared distances of the same pairs.
            margin: Hinge margin; defaults to ``hparams["margin"] ** 2``.
            weight: Multiplier applied to the positive-pair term.

        Returns:
            ``negative_loss + weight * positive_loss`` as a scalar tensor.
        """

        if margin is None:
            margin = self.hparams["margin"] ** 2

        negative_loss = self._hinge_term(d, hinge, -1, margin)
        positive_loss = self._hinge_term(d, hinge, 1, margin)

        return negative_loss + weight * positive_loss

    def training_step(self, batch, batch_idx):

        """
        Args:
            batch (``list``, required): A list of ``torch.tensor`` objects
            batch_idx (``int``, required): The index of the batch

        Returns:
            ``torch.tensor`` The loss function as a tensor
        """

        logging.info(f"Memory at train start: {torch.cuda.max_memory_allocated() / 1024**3} Gb")

        # Forward pass of model, handling whether Cell Information (ci) is included
        input_data = self.get_input_data(batch)

        # Embed hits
        topo_a, topo_b, output_a, output_b = self(input_data)

        # Build training set (the topo set is mined with its own radius)
        e_spatial, y_spatial = self.build_training_set(batch, output_a, output_b)
        e_topo, y_topo = self.build_training_set(
            batch, topo_a, topo_b, r_train=self.hparams["topo_margin"]
        )

        # Loss functions
        spatial_hinge, spatial_d = self.get_hinge_distance(
            output_a, output_b, e_spatial, y_spatial
        )
        topo_hinge, topo_d = self.get_hinge_distance(topo_a, topo_b, e_topo, y_topo)

        # Only the spatial loss applies the positive-pair weight
        spatial_loss = self.get_loss(
            spatial_hinge, spatial_d, self.hparams["margin"] ** 2, self.hparams["weight"]
        )
        topo_loss = self.get_loss(topo_hinge, topo_d, self.hparams["topo_margin"] ** 2)
        loss = spatial_loss + topo_loss

        self.log("train_loss", loss)

        return loss

    def shared_evaluation(self, batch, batch_idx, knn_radius, knn_num, log=False):
        """Score the directed output embeddings of one event.

        Builds the neighbourhood graph from ``output_a`` to ``output_b``
        within ``knn_radius`` (at most ``knn_num`` neighbours), labels it
        against the directed truth and derives loss, efficiency and purity.

        Args:
            batch: Event data.
            batch_idx: Index of the batch (unused here).
            knn_radius: ``r_max`` passed to ``build_edges``.
            knn_num: ``k_max`` passed to ``build_edges``.
            log: When true, log the metrics and current learning rate.

        Returns:
            Dict with keys ``loss``, ``distances``, ``preds``, ``truth`` and
            ``truth_graph``.
        """

        input_data = self.get_input_data(batch)
        _, _, output_a, output_b = self(input_data)

        # Directed truth: no reversed copies are added
        true_edges = batch.signal_true_edges

        # Build whole KNN graph
        e_spatial = build_edges(
            output_a, output_b, indices=None, r_max=knn_radius, k_max=knn_num
        )

        e_spatial, y_cluster = self.get_truth(batch, e_spatial, true_edges)

        hinge, d = self.get_hinge_distance(
            output_a, output_b, e_spatial.to(self.device), y_cluster
        )

        loss = torch.nn.functional.hinge_embedding_loss(
            d, hinge, margin=self.hparams["margin"] ** 2, reduction="mean"
        )

        cluster_true = true_edges.shape[1]
        cluster_true_positive = y_cluster.sum()
        cluster_positive = len(e_spatial[0])

        # Efficiency: true edges recovered / all true edges.
        # Purity: true edges recovered / all predicted edges.
        eff = cluster_true_positive / cluster_true
        pur = cluster_true_positive / cluster_positive
        if "module_veto" in self.hparams["regime"]:
            # Purity counted only over predicted edges that join two different modules
            module_veto_pur = cluster_true_positive / (
                batch.modules[e_spatial[0]] != batch.modules[e_spatial[1]]
            ).sum()
        else:
            module_veto_pur = 0

        if log:
            current_lr = self.optimizers().param_groups[0]["lr"]
            self.log_dict(
                {
                    "val_loss": loss,
                    "eff": eff,
                    "pur": pur,
                    "module_veto_pur": module_veto_pur,
                    "current_lr": current_lr,
                }
            )
        logging.info("Efficiency: {}".format(eff))
        logging.info("Purity: {}".format(pur))
        logging.info(batch.event_file)

        return {
            "loss": loss,
            "distances": d,
            "preds": e_spatial,
            "truth": y_cluster,
            "truth_graph": true_edges,
        }
        
class DirectedGravMetric(GravMetricBase):
    """Directed implementation of the GravMetric architecture.

    Every hit is embedded twice, with an "a" map and a "b" map, and edges run
    from the "a" embedding of one hit to the "b" embedding of another.

    Behaviour is:
    1. topo_a, topo_b = map_0a(x), map_0b(x)
    2. Get neighbourhoods as edge list (from a->b)
    3. Weight start-node and end-node features by the edge potential
    4. Scatter_mean weighted start-node features at end nodes, and weighted
       end-node features at start nodes
    5. Pass [mean_neighborhood_in, x, mean_neighborhood_out] to map_1a, map_1b
    """

    def __init__(self, hparams):
        print("DirectedGravMetric")
        super().__init__(hparams)

        # Construct the MLP architecture: one first-stage map per direction
        self.map_0a, self.map_0b = [make_mlp(
            hparams["spatial_channels"] + hparams["cell_channels"],
            [hparams["emb_hidden"]] * hparams["nb_layer"] + [hparams["emb_dim"]],
            hidden_activation=hparams["activation"],
            output_activation=None,
            layer_norm=True,
        ) for _ in range(2)]

        if hparams["feature_hidden"] > 0:
            self.feature_mlp_a, self.feature_mlp_b = [make_mlp(
                hparams["spatial_channels"] + hparams["cell_channels"],
                [hparams["feature_hidden"]] * hparams["nb_layer"],
                hidden_activation=hparams["activation"],
                output_activation=None,
                layer_norm=True,
            ) for _ in range(2)]
            # Each aggregated neighbourhood is [features, topo], and there is
            # one per direction, plus the raw hit features.
            feature_size = 2 * hparams["feature_hidden"] + 2 * hparams["emb_dim"] + hparams["spatial_channels"] + hparams["cell_channels"]
        else:
            # Each aggregated neighbourhood is the topo embedding only
            feature_size = 2 * hparams["emb_dim"] + hparams["spatial_channels"] + hparams["cell_channels"]

        self.map_1a, self.map_1b = [make_mlp(
            feature_size,
            [hparams["emb_hidden"]] * hparams["nb_layer"] + [hparams["emb_dim"]],
            hidden_activation=hparams["activation"],
            output_activation=None,
            layer_norm=True,
        ) for _ in range(2)]

        self.save_hyperparameters()

    def forward(self, x):
        """Return ``(topo_a, topo_b, out_a, out_b)``, all L2-normalised."""

        # 1. Embed and build topo graph
        # NOTE(review): build_knn is called with three arguments here but with
        # two in UndirectedGravMetric - confirm the imported signature.
        topo_a, topo_b = F.normalize(self.map_0a(x)), F.normalize(self.map_0b(x))
        topo_edges = build_knn(topo_a, topo_b, self.hparams["topo_k"])

        # 2. Get gravity potentials and weight features as attention mechanism
        edge_potentials = self.get_potential(topo_a, topo_b, topo_edges)
        if self.hparams["feature_hidden"] > 0:
            features_in = self.feature_mlp_a(x)
            features_out = self.feature_mlp_b(x)
            weighted_features_in = torch.cat([features_in, topo_a], dim=-1)[topo_edges[0]] * edge_potentials.unsqueeze(-1)
            weighted_features_out = torch.cat([features_out, topo_b], dim=-1)[topo_edges[1]] * edge_potentials.unsqueeze(-1)
        else:
            weighted_features_in = topo_a[topo_edges[0]] * edge_potentials.unsqueeze(-1)
            weighted_features_out = topo_b[topo_edges[1]] * edge_potentials.unsqueeze(-1)

        # 3. Scatter_mean weighted start-node features at end nodes, and
        #    weighted end-node features at start nodes
        mean_neighborhood_in = scatter_mean(weighted_features_in, topo_edges[1], dim=0, dim_size=topo_a.shape[0])
        mean_neighborhood_out = scatter_mean(weighted_features_out, topo_edges[0], dim=0, dim_size=topo_b.shape[0])

        # 4. Pass [mean_neighborhood_in, x, mean_neighborhood_out] to map_1a
        #    and map_1b; both maps share the same input, so build it once.
        joint_input = torch.cat([mean_neighborhood_in, x, mean_neighborhood_out], dim=-1)
        out_a = F.normalize(self.map_1a(joint_input))
        out_b = F.normalize(self.map_1b(joint_input))

        return topo_a, topo_b, out_a, out_b

    def get_potential(self, x_a, x_b, edges):
        """Edge weight from the squared distance between ``x_a`` at the start
        node and ``x_b`` at the end node: 1 at zero separation, 0 at a
        separation of ``topo_margin``, negative beyond it."""

        d_sq = ((x_a[edges[0]] - x_b[edges[1]])**2).sum(dim=-1)

        potential = (torch.exp(1 - d_sq / self.hparams["topo_margin"]**2) - 1) / (np.exp(1) - 1)

        return potential

## Train GravMetric

In [3]:
from LightningModules.SuperEmbedding.Models.gravmetric import DirectedGravMetric

In [4]:
# Instantiate the directed GravMetric model (its __init__ prints the class name)
model  = DirectedGravMetric(hparams)

DirectedGravMetric


In [5]:
# Send metrics of the GravMetric run to Weights & Biases ("InitialTest" group).
logger = WandbLogger(
    project=hparams["project"],
    group="InitialTest",
    save_dir=hparams["artifacts"],
)

# Single-GPU training; the pre-training validation sanity check is skipped.
trainer_options = dict(
    gpus=1,
    max_epochs=hparams["max_epochs"],
    logger=logger,
    num_sanity_val_steps=0,
)
trainer = Trainer(**trainer_options)

trainer.fit(model)

INFO:pytorch_lightning.utilities.distributed:GPU available: True, used: True
INFO:pytorch_lightning.utilities.distributed:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.distributed:IPU available: False, using: 0 IPUs
100%|██████████| 120/120 [00:05<00:00, 20.18it/s]
100%|██████████| 120/120 [00:04<00:00, 28.51it/s]
INFO:pytorch_lightning.accelerators.gpu:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


0.0 Gb


ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmurnanedaniel[0m. Use [1m`wandb login --relogin`[0m to force relogin


INFO:pytorch_lightning.callbacks.model_summary:
  | Name          | Type       | Params
---------------------------------------------
0 | map_0a        | Sequential | 3.2 M 
1 | map_0b        | Sequential | 3.2 M 
2 | feature_mlp_a | Sequential | 3.2 M 
3 | feature_mlp_b | Sequential | 3.2 M 
4 | map_1a        | Sequential | 3.2 M 
5 | map_1b        | Sequential | 3.2 M 
---------------------------------------------
19.1 M    Trainable params
0         Non-trainable params
19.1 M    Total params
76.530    Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]



Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


## Testing FRNN for two maps

In [2]:
# Make two 10x10 integer lattices of 2D points, one row per point.
# space1 covers first coordinate 0..9, space2 covers 5..14; both cover second
# coordinate 0..9, so the two lattices overlap where the first coordinate is 5..9.
# `indexing="ij"` spells out the layout torch.meshgrid already uses by default,
# which keeps the point order unchanged and silences the deprecation warning
# emitted when the argument is omitted.
space1 = (
    torch.stack(
        torch.meshgrid(
            torch.arange(0, 10, 1), torch.arange(0, 10, 1), indexing="ij"
        )
    )
    .flatten(1)
    .T.float()
    .to(device)
)
space2 = (
    torch.stack(
        torch.meshgrid(
            torch.arange(5, 15, 1), torch.arange(0, 10, 1), indexing="ij"
        )
    )
    .flatten(1)
    .T.float()
    .to(device)
)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [3]:
import frnn

In [4]:
def build_edges(
    query, database, indices=None, r_max=1.0, k_max=10, return_indices=False
):
    """Build a fixed-radius neighbour edge list from ``query`` to ``database``.

    Runs ``frnn.frnn_grid_points`` with ``query`` as ``points1`` and
    ``database`` as ``points2`` and converts the returned neighbour table
    into a ``(2, num_edges)`` edge list.

    Args:
        query: Tensor of query points. A leading batch dimension is added
            before the FRNN call (assumed to be ``(num_points, dim)`` --
            confirm against callers).
        database: Tensor of points to search in; handled like ``query``.
        indices: Optional index tensor. When given, the query-side row of the
            edge list is remapped through it (``indices[edge_list[0]]``), so a
            subset query can be reported in the caller's global numbering.
        r_max: Search radius, passed to FRNN as ``r``.
        k_max: Maximum neighbours per query point, passed to FRNN as ``K``.
        return_indices: If true, also return the FRNN distances and the
            intermediate index tables.

    Returns:
        ``edge_list`` as a ``long`` tensor whose row 0 holds query indices and
        row 1 holds the neighbour indices reported by FRNN. When
        ``return_indices`` is true, returns
        ``(edge_list, dists, idxs, ind)`` instead; ``dists`` is returned as
        produced by FRNN (batch dimension still present), ``idxs`` and ``ind``
        have the batch dimension removed.
    """
    dists, idxs, _, _ = frnn.frnn_grid_points(
        points1=query.unsqueeze(0),
        points2=database.unsqueeze(0),
        lengths1=None,
        lengths2=None,
        K=k_max,
        r=r_max,
        grid=None,
        return_nn=False,
        return_sorted=True,
    )

    # Remove only the batch dimension added above. A bare squeeze() would also
    # collapse a single query point or k_max == 1 and break the 2D indexing.
    idxs = idxs.squeeze(0).int()
    num_query, num_slots = idxs.shape

    # ind[i, j] == i: the query row that each neighbour slot belongs to.
    # Built on the same device as the FRNN result rather than a global device.
    ind = (
        torch.arange(num_query, device=idxs.device)
        .unsqueeze(1)
        .repeat(1, num_slots)
        .int()
    )

    # Keep only slots with a non-negative neighbour index (negative entries
    # are presumably FRNN padding for fewer than K neighbours -- confirm).
    positive_idxs = idxs >= 0
    edge_list = torch.stack([ind[positive_idxs], idxs[positive_idxs]]).long()

    # Reset indices subset to correct global index
    if indices is not None:
        edge_list[0] = indices[edge_list[0]]

    # NOTE: self-loops are intentionally kept; query and database may be
    # different point sets, in which case equal indices are not self-matches.

    if return_indices:
        return edge_list, dists, idxs, ind
    return edge_list

In [5]:
# Call FRNN directly (space1 as queries, space2 as the searched set) to look
# at its raw outputs: up to k_max neighbours within radius r_max per query.
k_max, r_max = 100, 2.1
dists, idxs, nn, grid = frnn.frnn_grid_points(
    points1=space1[None],  # add the leading batch dimension
    points2=space2[None],
    lengths1=None,
    lengths2=None,
    K=k_max,
    r=r_max,
    grid=None,
    return_nn=False,
    return_sorted=True,
)

In [6]:
# Same point set on both sides: every point is searched against its own set,
# and matches of a point with itself stay in the result.
twin_graph = build_edges(
    query=space1,
    database=space1,
    r_max=r_max,
    k_max=k_max,
)

In [7]:
twin_graph.shape

torch.Size([2, 1104])

In [8]:
twin_graph

tensor([[ 0,  0,  0,  ..., 99, 99, 99],
        [ 0,  1, 10,  ..., 97, 98, 99]], device='cuda:0')

In [48]:
# Two different point sets: space1 points are the queries, space2 is searched.
half_graph = build_edges(
    query=space1,
    database=space2,
    r_max=r_max,
    k_max=k_max,
)

In [49]:
half_graph.shape

torch.Size([2, 600])

In [51]:
half_graph

tensor([[30, 31, 32,  ..., 99, 99, 99],
        [ 0,  1,  2,  ..., 58, 59, 69]], device='cuda:0')

In [53]:
half_graph[:, :100]

tensor([[30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 40, 40, 41, 41, 41, 41, 42,
         42, 42, 42, 43, 43, 43, 43, 44, 44, 44, 44, 45, 45, 45, 45, 46, 46, 46,
         46, 47, 47, 47, 47, 48, 48, 48, 48, 49, 49, 49, 50, 50, 50, 50, 50, 50,
         51, 51, 51, 51, 51, 51, 51, 51, 52, 52, 52, 52, 52, 52, 52, 52, 52, 53,
         53, 53, 53, 53, 53, 53, 53, 53, 54, 54, 54, 54, 54, 54, 54, 54, 54, 55,
         55, 55, 55, 55, 55, 55, 55, 55, 56, 56],
        [ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  0,  1, 10,  0,  1, 11,  2,  1,
          2, 12,  3,  2,  3, 13,  4,  3,  4, 14,  5,  4,  5, 15,  6,  5,  6, 16,
          7,  6,  7, 17,  8,  7,  8, 18,  9,  8,  9, 19,  0,  1, 10, 11,  2, 20,
          0,  1, 10, 11,  2, 12,  3, 21,  0,  1, 11,  2, 12,  3, 13,  4, 22,  1,
          2, 12,  3, 13,  4, 14,  5, 23,  2,  3, 13,  4, 14,  5, 15,  6, 24,  3,
          4, 14,  5, 15,  6, 16,  7, 25,  4,  5]], device='cuda:0')

In [9]:
# Rows 50..98 of space1; the end bound is exclusive, so row 99 is left out.
indices = torch.arange(start=50, end=99, device=device)

In [11]:
# Query only the selected rows of space1. Passing `indices` makes build_edges
# report the query side in space1's numbering instead of subset-local rows.
half_graph = build_edges(
    query=space1[indices],
    database=space2,
    indices=indices,
    r_max=r_max,
    k_max=k_max,
)

In [12]:
half_graph.shape

torch.Size([2, 543])

In [13]:
half_graph

tensor([[50, 50, 50,  ..., 98, 98, 98],
        [ 0,  1, 10,  ..., 58, 59, 68]], device='cuda:0')