In [1]:
import torch 
import yaml, os
import matplotlib, seaborn as sns
from torch_geometric.data import HeteroData
from pytorch_lightning import LightningModule
from itertools import combinations_with_replacement, product
from typing import Dict, Optional
from torch import Tensor
from torch_geometric.typing import Adj, EdgeType, NodeType
import pdb
%matplotlib inline

CONFIGFILE = "hetero_reg3_lev3.yaml"
with open(CONFIGFILE, 'r') as f: config = yaml.load(f, Loader=yaml.FullLoader)
# data_dir = config['input_dir'] + '/train'
# data = torch.load(os.path.join(data_dir, os.listdir(data_dir)[0]))

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from utils import load_dataset, LargeDataset, background_cut_event, make_mlp
import torch_geometric.transforms as T

In [3]:
def get_region(model):
    return 'volume_' + '_'.join([str(i) for i in model['volume_ids']])

class LargeHeteroDataset(LargeDataset):
    def __init__(self, root, subdir, num_events, hparams, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, subdir, num_events, hparams, transform, pre_transform, pre_filter)

    # def get_region(self, model):
    #     return 'region_' + '_'.join([str(i) for i in model['volume_ids']])

    def get(self, idx):
        
        event = torch.load(self.input_paths[idx], map_location=torch.device('cpu'))

        # Process event with pt cuts
        if self.hparams["pt_background_cut"] > 0:
            event = background_cut_event(event, self.hparams["pt_background_cut"], self.hparams["pt_signal_cut"])
        
        # Ensure PID definition is correct
        event.y_pid = (event.pid[event.edge_index[0]] == event.pid[event.edge_index[1]]) & event.pid[event.edge_index[0]].bool()
        event.pid_signal = torch.isin(event.edge_index, event.signal_true_edges).all(0) & event.y_pid

        # create new hit map
        models = config['model_ids']
        map = torch.ones_like(event.hid) * -1
        for model in self.hparams['model_ids']:
            volume_id = model['volume_ids']
            homo_ids = event.hid[ torch.isin( event.volume_id, torch.tensor(volume_id) ) ]
            # print(homo_ids)
            # map = torch.ones((torch.max(homo_ids)+1,), dtype=torch.long) * -1
            # print(map)
            map[homo_ids] = torch.arange(homo_ids.shape[0])
            # print(map)
        

        data = HeteroData()
        for _, model in enumerate(self.hparams['model_ids']):
            region = get_region(model)
            mask = torch.isin( event.volume_id, torch.tensor(model['volume_ids']) )
            for attr in ['x', 'cell_data', 'pid', 'hid', 'pt', 'primary', 'nhits', 'modules', 'volume_id']:
                data[region][attr] = event[attr][mask]
            data[region]['mask'] = mask
        
        for model1, model2 in product(self.hparams['model_ids'], self.hparams['model_ids']):
            # ids = torch.tensor([model1['volume_ids'], model2['volume_ids']])
            id0, id1 = torch.tensor([model1['volume_ids']]), torch.tensor([model2['volume_ids']])
            region1, region2 = get_region(model1), get_region(model2)
            mask1 = torch.isin(event.volume_id[event.edge_index[0]], id0)
            mask2 = torch.isin(event.volume_id[event.edge_index[1]], id1)
            mask = mask1 * mask2 #+ torch.isin(event.volume_id[event.edge_index[0]], id2) * torch.isin(event.volume_id[event.edge_index[1]],id1)
            edge_index = event.edge_index.T[mask].T
            edge_index = map[edge_index]
            data[region1, 'connected_to', region2].edge_index = edge_index
            data[region1, 'connected_to', region2].y = event.y[mask]
            data[region1, 'connected_to', region2].y_pid = event.y_pid[mask]
            for truth_edge in ['modulewise_true_edges', 'signal_true_edges']:
                mask = torch.isin(event.volume_id[event[truth_edge][0]], id0) * torch.isin(event.volume_id[event[truth_edge][1]], id1) #+ torch.isin(event.volume_id[event[truth_edge][0]], id2) * torch.isin(event.volume_id[event[truth_edge][1]], id1)
                data[region1, 'connected_to', region2][truth_edge] = event[truth_edge].T[mask].T
        return data, self.input_paths[idx]


In [4]:
dataset = LargeHeteroDataset(root=config['input_dir'], subdir='train', num_events=1, hparams=config)
data, input_dir = dataset.get(0)
homo_data = torch.load(input_dir)
# undirected_data = T.ToUndirected()(data)

In [5]:
data

HeteroData(
  [1mvolume_0_1[0m={
    x=[248396, 9],
    cell_data=[248396, 11],
    pid=[248396],
    hid=[248396],
    pt=[248396],
    primary=[248396],
    nhits=[248396],
    modules=[248396],
    volume_id=[248396],
    mask=[326917]
  },
  [1mvolume_2[0m={
    x=[40883, 9],
    cell_data=[40883, 11],
    pid=[40883],
    hid=[40883],
    pt=[40883],
    primary=[40883],
    nhits=[40883],
    modules=[40883],
    volume_id=[40883],
    mask=[326917]
  },
  [1mvolume_3[0m={
    x=[37638, 9],
    cell_data=[37638, 11],
    pid=[37638],
    hid=[37638],
    pt=[37638],
    primary=[37638],
    nhits=[37638],
    modules=[37638],
    volume_id=[37638],
    mask=[326917]
  },
  [1m(volume_0_1, connected_to, volume_0_1)[0m={
    edge_index=[2, 292516],
    y=[292516],
    y_pid=[292516],
    modulewise_true_edges=[2, 104922],
    signal_true_edges=[2, 11008]
  },
  [1m(volume_0_1, connected_to, volume_2)[0m={
    edge_index=[2, 22346],
    y=[22346],
    y_pid=[22346],
    mo

In [24]:
from torch_geometric.nn import HeteroConv, HeteroLinear, MLP, GCNConv, MessagePassing, to_hetero
from torch.nn import Module, ModuleDict
from collections import defaultdict

class NodeEncoder(torch.nn.Module):
    def __init__(self, hparams, model) -> None:
        super().__init__()
        self.hparams = hparams

        self.network = make_mlp(
            model['num_features'],
            [hparams["hidden"]] * hparams["nb_node_layer"],
            output_activation=hparams["output_activation"],
            hidden_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
        )

    def forward(self, x):
        return  self.network( x.float() )
    
class HeteroNodeEncoder(torch.nn.Module):
    def __init__(self, hparams) -> None:
        super().__init__()

        self.hparams = hparams

        self.encoders = torch.nn.ModuleDict()
        for model in self.hparams['model_ids']:
            region = get_region(model)
            self.encoders[region] = NodeEncoder(self.hparams, model)
        
    def forward(self, x_dict):
        for model in self.hparams['model_ids']:
            region = get_region(model)
            print(region)
            x_dict[region] = self.encoders[region](x_dict[region][:, : model['num_features']])
        
        return x_dict

In [7]:
encoder = HeteroNodeEncoder(hparams=config)

encoder(data.x_dict)

{'volume_0_1': tensor([[ 0.4559, -0.6419, -0.2931,  ...,  0.0601,  0.2524,  0.7009],
         [ 0.4751, -0.6559, -0.3046,  ...,  0.0529,  0.2268,  0.7113],
         [ 0.4322, -0.6406, -0.2763,  ...,  0.0861,  0.2159,  0.7096],
         ...,
         [-0.9515, -0.9330, -0.9221,  ..., -0.2887,  0.8334, -0.8976],
         [-0.9532, -0.9333, -0.9220,  ..., -0.2904,  0.8344, -0.8974],
         [-0.9538, -0.9331, -0.9231,  ..., -0.2832,  0.8338, -0.8953]],
        grad_fn=<TanhBackward0>),
 'volume_2': tensor([[ 0.5418,  0.9371,  0.8508,  ..., -0.9120,  0.1569,  0.0486],
         [ 0.5762,  0.9358,  0.8506,  ..., -0.9064,  0.1539,  0.0734],
         [ 0.4515,  0.9366,  0.8540,  ..., -0.9223,  0.1309, -0.0043],
         ...,
         [ 0.3760, -0.8509, -0.4800,  ..., -0.7841,  0.7006, -0.4799],
         [ 0.3826, -0.8437, -0.4772,  ..., -0.7881,  0.7012, -0.4746],
         [ 0.4670, -0.7989, -0.5056,  ..., -0.8094,  0.7284, -0.4121]],
        grad_fn=<TanhBackward0>),
 'volume_3': tensor([[ 0

In [37]:
class EdgeEncoder(torch.nn.Module):
    def __init__(self, hparams) -> None:
        super().__init__()
        self.hparams = hparams

        self.network = make_mlp(
            2 * hparams['hidden'],
            [hparams["hidden"]] * hparams["nb_node_layer"],
            output_activation=hparams["output_activation"],
            hidden_activation=hparams["hidden_activation"],
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
        )

    def forward(self, x, edge_index):
        src, dst = edge_index
        print("Encoding edges")
        if isinstance(x, tuple):
            x1, x2 = x
            x_in = torch.cat([x1[src], x2[dst]], dim=-1)
        else:
            x_in = torch.cat([x[src], x[dst]], dim=-1)

        return  self.network( x_in )

class HeteroEdgeConv(HeteroConv):

    def __init__(self, convs: dict, aggr: str = "sum"):
        super().__init__(convs, aggr)

    def forward(
        self,
        x_dict: dict,
        edge_index_dict: dict,
        *args_dict,
        **kwargs_dict,
    ) -> dict :

        out_dict = {}
        for edge_type, edge_index in edge_index_dict.items():
            print(edge_type)
            src, rel, dst = edge_type

            str_edge_type = '__'.join(edge_type)
            if str_edge_type not in self.convs:
                continue

            args = []
            for value_dict in args_dict:
                if edge_type in value_dict:
                    args.append(value_dict[edge_type])
                elif src == dst and src in value_dict:
                    args.append(value_dict[src])
                elif src in value_dict or dst in value_dict:
                    args.append(
                        (value_dict.get(src, None), value_dict.get(dst, None)))

            kwargs = {}
            for arg, value_dict in kwargs_dict.items():
                arg = arg[:-5]  # `{*}_dict`
                if edge_type in value_dict:
                    kwargs[arg] = value_dict[edge_type]
                elif src == dst and src in value_dict:
                    kwargs[arg] = value_dict[src]
                elif src in value_dict or dst in value_dict:
                    kwargs[arg] = (value_dict.get(src, None),
                                value_dict.get(dst, None))

            conv = self.convs[str_edge_type]
            print(kwargs)

            if src == dst:
                out = conv(x_dict[src], edge_index, *args, **kwargs)
            else:
                out = conv((x_dict[src], x_dict[dst]), edge_index, *args, **kwargs)

            out_dict[edge_type] = out

        return out_dict

class EdgeClassifier(torch.nn.Module):

    def __init__(self, hparams):
        super().__init__()

        self.network = make_mlp(
            3 * hparams["hidden"],
            [hparams["hidden"]] * hparams["nb_edge_layer"] + [1],
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
            output_activation=hparams['output_activation'],
            hidden_activation=hparams["hidden_activation"],
        )

    def forward(self, x, edge_index, edge):
        src, dst = edge_index
        if isinstance(x, tuple):
            x1, x2 = x
            classifier_input = torch.cat([x1[src], x2[dst], edge], dim=-1)
        else:
            classifier_input = torch.cat([x[src], x[dst], edge], dim=-1)
        # classifier_input = torch.cat([x[src], x[dst], edge], dim=-1)
        return self.network(classifier_input)

In [32]:
class InteractionMessagePassing(MessagePassing):
    def __init__(self, hparams, aggr: str = "add", flow: str = "source_to_target", node_dim: int = -2, decomposed_layers: int = 1):
        super().__init__(aggr, flow=flow, node_dim=node_dim, decomposed_layers=decomposed_layers)

        self.hparams=hparams

        # The edge network computes new edge features from connected nodes
        self.edge_encoder = make_mlp(
            2 * (hparams["hidden"]),
            [hparams["hidden"]] * hparams["nb_edge_layer"],
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
            output_activation=hparams["output_activation"],
            hidden_activation=hparams["hidden_activation"],
        )

        # The edge network computes new edge features from connected nodes
        self.edge_network = make_mlp(
            3 * hparams["hidden"],
            [hparams["hidden"]] * hparams["nb_edge_layer"],
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
            output_activation=hparams["output_activation"],
            hidden_activation=hparams["hidden_activation"],
        )

        # The node network computes new node features
        self.node_network = make_mlp(
            2 * hparams["hidden"],
            [hparams["hidden"]] * hparams["nb_node_layer"],
            layer_norm=hparams["layernorm"],
            batch_norm=hparams["batchnorm"],
            output_activation=hparams["output_activation"],
            hidden_activation=hparams["hidden_activation"],
        )

    def message(self, edge):
        return edge

    def aggregate(self, out, edge_index):

        src, dst = edge_index
        return self.aggr_module(out, dst)[dst.unique()]
    
    def update(self, agg_message, x, edge_index):
        src, dst = edge_index
        indices_to_add = torch.arange(agg_message.shape[0])
        print(dst.unique())
        print(x)
        x[dst.unique()] += agg_message
        
        return x

    def edge_update(self, x, edge, edge_index, *args, **kwargs):
        src, dst = edge_index
        if isinstance(x, tuple):
            x_src, x_dst = x[0][src], x[1][dst]
        else:
            x_src, x_dst = x[src], x[dst]
        out = self.edge_network(torch.cat([x_src, x_dst, edge], dim=-1))
        return out

    def forward(self, x, edge_index, edge):

        if isinstance(x, tuple):
            x_src, x_dst = x
        else:
            x_src, x_dst = x, x

        x_dst = self.propagate(edge_index, x=x_dst, edge=edge)

        return x_dst

class InteractionHeteroConv(HeteroConv):
    def __init__(self, convs: Dict[EdgeType, Module], aggr: Optional[str] = "sum"):
        super().__init__(convs, aggr)

    def edge_forward(self,x_dict: Dict[NodeType, Tensor],
        edge_index_dict: Dict[EdgeType, Adj],
        edge_dict,
        *args_dict,
        **kwargs_dict,
    ) -> Dict[NodeType, Tensor]:

        out_dict = {}
        for edge_type, edge_index in edge_index_dict.items():
            print(edge_type)
            src, rel, dst = edge_type

            str_edge_type = '__'.join(edge_type)
            if str_edge_type not in self.convs:
                continue

            args = []
            for value_dict in args_dict:
                if edge_type in value_dict:
                    args.append(value_dict[edge_type])
                elif src == dst and src in value_dict:
                    args.append(value_dict[src])
                elif src in value_dict or dst in value_dict:
                    args.append(
                        (value_dict.get(src, None), value_dict.get(dst, None)))

            kwargs = {}
            for arg, value_dict in kwargs_dict.items():
                arg = arg[:-5]  # `{*}_dict`
                if edge_type in value_dict:
                    kwargs[arg] = value_dict[edge_type]
                elif src == dst and src in value_dict:
                    kwargs[arg] = value_dict[src]
                elif src in value_dict or dst in value_dict:
                    kwargs[arg] = (value_dict.get(src, None),
                                value_dict.get(dst, None))

            conv = self.convs[str_edge_type]
            edge = edge_dict[edge_type]

            out = conv.edge_update((x_dict[src], x_dict[dst]), edge, edge_index, *args, **kwargs)

            # if src == dst:
            #     out = conv.edge_updater(x_dict[src], edge_index, *args, **kwargs)
            # else:
            #     out = conv.edge_update((x_dict[src], x_dict[dst]), edge_index, *args,
            #             **kwargs)

            out_dict[edge_type] = out

        return out_dict


class HeteroGNN(torch.nn.Module):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams
        self.node_encoders = HeteroNodeEncoder(self.hparams)
        self.edge_encoders = HeteroEdgeConv({
            (get_region(model0), 'connected_to', get_region(model1)): EdgeEncoder(self.hparams)
            for model0, model1 in product(self.hparams['model_ids'], self.hparams['model_ids'])
        })

        self.convs = torch.nn.ModuleList()

        for _ in range(1):
            conv = InteractionHeteroConv({
                (get_region(model0), 'connected_to', get_region(model1)): InteractionMessagePassing(hparams=self.hparams)
                for model0, model1 in product(self.hparams['model_ids'], self.hparams['model_ids'])
            }, aggr='sum')
            self.convs.append(conv)

        self.edge_classifiers = HeteroEdgeConv({
            (get_region(model0), 'connected_to', get_region(model1)): EdgeClassifier(self.hparams)
            for model0, model1 in combinations_with_replacement(self.hparams['model_ids'], 2)
        })


    def forward(self, x_dict: dict, edge_index_dict: dict):
        x_dict = self.node_encoders(x_dict)  
        edge_dict = self.edge_encoders(x_dict, edge_index_dict) 
        for conv in self.convs:
            x_dict = conv(x_dict, edge_index_dict, edge_dict=edge_dict)
            edge_dict = conv.edge_forward(x_dict, edge_index_dict, edge_dict)

        out = self.edge_classifiers(x_dict, edge_index_dict, edge_dict=edge_dict)

        return out

In [38]:
gnn = HeteroGNN(config)

# x_dict, edge_dict= gnn(data.x_dict, data.edge_index_dict, data.mask_dict)
out = gnn(data.x_dict, data.edge_index_dict)
print(out)
# for region, x in edge_dict.items():
#     print(region, x.shape)

volume_0_1
volume_2
volume_3
('volume_0_1', 'connected_to', 'volume_0_1')
{}
Encoding edges
('volume_0_1', 'connected_to', 'volume_2')
{}
Encoding edges
('volume_0_1', 'connected_to', 'volume_3')
{}
Encoding edges
('volume_2', 'connected_to', 'volume_0_1')
{}
Encoding edges
('volume_2', 'connected_to', 'volume_2')
{}
Encoding edges
('volume_2', 'connected_to', 'volume_3')
{}
Encoding edges
('volume_3', 'connected_to', 'volume_0_1')
{}
Encoding edges
('volume_3', 'connected_to', 'volume_2')
{}
Encoding edges
('volume_3', 'connected_to', 'volume_3')
{}
Encoding edges
tensor([     0,      1,      7,  ..., 248383, 248392, 248393])
tensor([[-0.7312, -0.8284,  0.6658,  ...,  0.5697, -0.7187,  0.6371],
        [-0.7490, -0.8317,  0.6773,  ...,  0.5903, -0.7170,  0.6186],
        [-0.7228, -0.8311,  0.6851,  ...,  0.5611, -0.7352,  0.6047],
        ...,
        [-0.0284,  0.5865, -0.4600,  ..., -0.0034, -0.4285, -0.6854],
        [-0.0325,  0.5808, -0.4714,  ...,  0.0044, -0.4312, -0.6718],
  

In [36]:
print(out)

{('volume_0_1', 'connected_to', 'volume_0_1'): tensor([[-0.0951],
        [-0.1666],
        [ 0.1289],
        ...,
        [ 0.3235],
        [ 0.2870],
        [ 0.0448]], grad_fn=<AddmmBackward0>), ('volume_0_1', 'connected_to', 'volume_2'): tensor([[ 0.3332],
        [-0.1639],
        [-0.1021],
        ...,
        [ 0.1369],
        [-0.0543],
        [-0.0602]], grad_fn=<AddmmBackward0>), ('volume_0_1', 'connected_to', 'volume_3'): tensor([[ 0.3072],
        [ 0.0244],
        [-0.6064],
        ...,
        [ 0.0584],
        [ 0.1134],
        [ 0.2005]], grad_fn=<AddmmBackward0>), ('volume_2', 'connected_to', 'volume_2'): tensor([[0.4536],
        [0.6040],
        [0.6297],
        ...,
        [0.7279],
        [0.7474],
        [0.7390]], grad_fn=<AddmmBackward0>), ('volume_2', 'connected_to', 'volume_3'): tensor([[0.3646],
        [0.7339],
        [0.6230],
        ...,
        [0.0788],
        [0.3085],
        [0.4047]], grad_fn=<AddmmBackward0>), ('volume_3', 'conn

In [27]:
imp = InteractionMessagePassing(config)

imp.__dict__

imp.message(edge_dict[('volume_0_1', 'connected_to', 'volume_2')])

agg_mess =  imp.aggregate(edge_dict[('volume_0_1', 'connected_to', 'volume_2')], data.edge_index_dict[('volume_0_1', 'connected_to', 'volume_2')])

In [23]:
data.x_dict['volume_2'].shape

torch.Size([44901, 9])

In [33]:
data.edge_index_dict[('volume_0_1', 'connected_to', 'volume_2')][1].unique()

tensor([    3,     9,    10,  ..., 32679, 32682, 32706])

In [19]:
agg_mess.shape

torch.Size([32707, 64])

In [35]:
agg_mess[:4]

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000, 

In [None]:
node_encoder = NodeEncoder(config)
edge_encoder = EdgeEncoder
# het_encoder = to_hetero(encoder, data.metadata())

node_encoder(data['region_2'].x, data['region_2'].mask)[data['region_2'].mask]

In [None]:
het_encoder = to_hetero(node_encoder, metadata=data.metadata())

In [33]:
from torch_geometric.nn.dense.linear import HeteroLinear

encoder = HeteroLinear(in_channels=9, out_channels=64, num_types=3)

In [44]:
encoder(homo_data.x.float(), torch.where(homo_data.volume_id==1, homo_data.volume_id, 0))

tensor([[-0.1403,  0.2397,  0.2695,  ..., -0.0816,  0.3613,  0.1801],
        [-0.1321,  0.2312,  0.2651,  ..., -0.0746,  0.3529,  0.1808],
        [-0.1614,  0.2399,  0.2729,  ..., -0.1022,  0.3619,  0.1883],
        ...,
        [ 0.3059, -0.2536, -1.3117,  ..., -0.7055, -1.2227, -0.5551],
        [ 0.3078, -0.2565, -1.3183,  ..., -0.7088, -1.2038, -0.5388],
        [ 0.3069, -0.2540, -1.3174,  ..., -0.7057, -1.2094, -0.5426]],
       grad_fn=<IndexPutBackward0>)

In [35]:
homo_data

Data(x=[349642, 9], cell_data=[349642, 11], pid=[349642], event_file='/global/cfs/cdirs/m3443/data/ITk-upgrade/processed/full_events_v4/event000014840', hid=[349642], pt=[349642], primary=[349642], nhits=[349642], modules=[349642], modulewise_true_edges=[2, 138249], signal_true_edges=[2, 16298], edge_index=[2, 772775], y=[772775], y_pid=[772775], volume_id=[349642])

In [73]:
models = config['model_ids']
map = torch.ones_like(homo_data.hid) * -1
for model in models:
    volume_id = model['volume_ids']
    homo_ids = homo_data.hid[ torch.isin( homo_data.volume_id, torch.tensor(volume_id) ) ]
    print(homo_ids)
    # map = torch.ones((torch.max(homo_ids)+1,), dtype=torch.long) * -1
    # print(map)
    map[homo_ids] = torch.arange(homo_ids.shape[0])
    print(map)


tensor([     0,      1,      2,  ..., 264913, 264914, 264915])
tensor([ 0,  1,  2,  ..., -1, -1, -1])
tensor([285245, 285246, 285247,  ..., 328489, 328490, 328491])
tensor([ 0,  1,  2,  ..., -1, -1, -1])
tensor([264916, 264917, 264918,  ..., 349639, 349640, 349641])
tensor([    0,     1,     2,  ..., 41476, 41477, 41478])


In [75]:
uid, counts = map.unique(return_counts=True)
counts.unique()

tensor([1, 2, 3])

In [59]:
map = torch.zeros(torch.max(homo_ids))
map

torch.Size([349641])

In [64]:
homo_data.hid

tensor([     0,      1,      2,  ..., 349639, 349640, 349641])

In [23]:
import torch_geometric


torch_geometric.__version__

'2.0.4'