In [18]:
import os, sys
sys.path.append('../../../Pipelines/TrackML_Example_Full')
sys.path.append('..')
from LightningModules.Embedding.multi_embedding_base import EmbeddingBase
from LightningModules.Embedding.utils import make_mlp
import torch.nn.functional as F
import torch

import yaml

In [4]:
class MultimapEmbedding(EmbeddingBase):
    """Embedding model with two independent MLP branches ("head" and "tail").

    Both branches read the same input of width
    ``hparams['spatial_channels'] + hparams['cell_channels']`` and produce an
    ``hparams['emb_dim']``-wide output through ``hparams['nb_layer']`` hidden
    layers of width ``hparams['emb_hidden']`` (built by ``make_mlp`` with
    layer norm and no output activation).
    """

    def __init__(self, hparams):
        super().__init__(hparams)

        def build_branch():
            # A fresh list of layer widths per call, exactly as the two
            # separate make_mlp calls received before.
            input_width = hparams['spatial_channels'] + hparams['cell_channels']
            layer_widths = [hparams["emb_hidden"]] * hparams["nb_layer"] + [hparams["emb_dim"]]
            return make_mlp(
                input_width,
                layer_widths,
                hidden_activation=hparams["activation"],
                output_activation=None,
                layer_norm=True,
            )

        # Head is created before tail so parameter initialisation draws from
        # the RNG in the same order as before.
        self.head_network = build_branch()
        self.tail_network = build_branch()

        self.save_hyperparameters()

    def forward(self, x):
        """Embed ``x`` with both branches.

        Returns:
            ``(head_embedding, tail_embedding)``. When the string ``"norm"`` is
            contained in ``self.hparams["regime"]`` each embedding is passed
            through ``F.normalize`` with its defaults (L2 norm along dim 1).
        """
        embeddings = (self.head_network(x), self.tail_network(x))
        if "norm" not in self.hparams["regime"]:
            return embeddings
        return tuple(F.normalize(embedding) for embedding in embeddings)


In [50]:
config_path = '../pipeline_config_local.yaml'

# The pipeline config is plain data, so safe_load is sufficient and refuses
# to construct arbitrary Python objects.
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

common_configs = config['common_configs']
metric_learning_config = config['metric_learning_configs']

# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

model = MultimapEmbedding(metric_learning_config)
model.setup(stage='fit')
model.to(DEVICE)

MultimapEmbedding(
  (head_network): Sequential(
    (0): Linear(in_features=12, out_features=1024, bias=True)
    (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (2): Tanh()
    (3): Linear(in_features=1024, out_features=1024, bias=True)
    (4): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (5): Tanh()
    (6): Linear(in_features=1024, out_features=1024, bias=True)
    (7): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (8): Tanh()
    (9): Linear(in_features=1024, out_features=1024, bias=True)
    (10): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (11): Tanh()
    (12): Linear(in_features=1024, out_features=12, bias=True)
  )
  (tail_network): Sequential(
    (0): Linear(in_features=12, out_features=1024, bias=True)
    (1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (2): Tanh()
    (3): Linear(in_features=1024, out_features=1024, bias=True)
    (4): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    (5):

In [51]:
# Keep a handle on the event so later cells can reuse the same graph the
# latents are computed from; follow the model's device rather than assuming CUDA.
batch = model.trainset[0]
input_data = model.get_input_data(batch).to(model.device)

In [52]:
# Inference only: disable autograd tracking while embedding the event.
with torch.set_grad_enabled(False):
    head_latent, tail_latent = model(input_data)

In [55]:
# Fetch the same event `input_data` was built from, so this cell does not
# depend on `batch` having been defined somewhere else in the kernel.
batch = model.trainset[0]

# Start from an empty (2, 0) edge list on the model's device; append_hnm_pairs
# (presumably hard-negative mining — confirm) extends it using the head latents.
e_spatial = torch.empty([2, 0], dtype=torch.int64, device=model.device)
query_indices, query = model.get_query_points(batch, head_latent)

e_spatial = model.append_hnm_pairs(e_spatial, query, query_indices, head_latent)



In [57]:
# Number of candidate edges produced above (rows: source/target).
e_spatial.size()

torch.Size([2, 563794])

In [25]:
# Inspect the query points sampled by get_query_points; `query_indices` is the
# name bound by that call (the bare `indices` is not defined at this point).
query_indices.shape, query.shape

(torch.Size([12053]), torch.Size([12053, 12]))

In [26]:
# Configured query budget — presumably the cap on points drawn by
# get_query_points; confirm in EmbeddingBase.
points_per_batch = model.hparams['points_per_batch']
points_per_batch

100000

In [35]:
batch = model.trainset[0]

# Unique hit ids appearing in at least one signal truth edge (torch.unique
# flattens the [2, E] edge tensor and returns sorted values).
indices = batch.signal_true_edges.unique()

# Seeded generator so the random subsample is identical on every re-run.
SAMPLE_SIZE = 10000
sample_generator = torch.Generator().manual_seed(0)
indices, indices[torch.randperm(len(indices), generator=sample_generator)][:SAMPLE_SIZE]

(tensor([     2,      9,     17,  ..., 103232, 103234, 103239]),
 tensor([36744, 42619, 74811,  ...,  3775, 11978, 61769]))

In [37]:
# Slicing past the end of a tensor truncates instead of raising, so asking for
# more points than exist just returns every (shuffled) index.
shuffled_indices = indices[torch.randperm(len(indices))]
shuffled_indices[:100000].shape

torch.Size([12053])

In [38]:
# Display the event's Data object to inspect which attributes and sizes it carries.
batch

Data(x=[103241, 3], pid=[103241], modules=[103241], event_file='datasets/full_data/21045', hid=[103241], pt=[103241], weights=[76942], modulewise_true_edges=[2, 76942], cell_data=[103241, 9], signal_true_edges=[2, 10963])