# Example of Metric Learning in Embedding Space

In [1]:
%load_ext autoreload
%autoreload 2

# System imports
import os
import sys
import yaml
import logging
logging.basicConfig(level=logging.ERROR)  

# External imports
import matplotlib.pyplot as plt
import scipy as sp
from sklearn.decomposition import PCA
from sklearn.metrics import auc
import numpy as np
import pandas as pd
import torch
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from pytorch_lightning import Trainer
import torch.nn.functional as F
import frnn
from torch_scatter import scatter_add, scatter_mean, scatter_max

sys.path.append("../../../")

from LightningModules.SuperEmbedding.Models.undirected_embedding import UndirectedEmbedding
from LightningModules.SuperEmbedding.Models.gravmetric2 import GravMetric

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# Small, hand-checkable tensor: the values 0..N-1 laid out as
# (num_edges, num_heads, hidden_dim), so reshapes and broadcasts are easy to eyeball.
num_edges = 10
num_heads = 3
hidden_dim = 4
total = num_edges * num_heads * hidden_dim
hidden = torch.arange(total, dtype=torch.float32).view(num_edges, num_heads, hidden_dim)

In [4]:
hidden.size()

torch.Size([10, 3, 4])

In [3]:
hidden[0, :, :]  # every head's features for the first row

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])

In [17]:
hidden[1, :, :]  # every head's features for the second row

tensor([[12., 13., 14., 15.],
        [16., 17., 18., 19.],
        [20., 21., 22., 23.]])

In [5]:
# Concatenate the heads: (num_edges, num_heads, hidden_dim) -> (num_edges, num_heads * hidden_dim)
hidden.view(num_edges, num_heads * hidden_dim).size()

torch.Size([10, 12])

In [26]:
# One unit weight per (row, head); the trailing singleton axis broadcasts over hidden_dim.
d = torch.ones((num_edges, num_heads, 1))

In [27]:
weighted = d * hidden  # with unit weights this must reproduce `hidden` exactly

In [28]:
weighted[0, :, :]

tensor([[ 0.,  1.,  2.,  3.],
        [ 4.,  5.,  6.,  7.],
        [ 8.,  9., 10., 11.]])

In [35]:
# Larger toy problem: multi-head node features and embeddings plus a random
# edge list. NOTE: this overrides the hidden_dim / num_heads / num_edges /
# hidden values defined in the small example above.
hidden_dim = 32
emb_dim = 12
num_heads = 4
num_nodes = 100
num_edges = 200
SEED = 0

# Seed so the random data (and everything derived from it) is reproducible
# under Restart & Run All.
torch.manual_seed(SEED)

# Generate random data directly on the target device (no CPU allocation + copy).
hidden = torch.randn(num_nodes, num_heads, hidden_dim, device=device)
emb = torch.randn(num_nodes, num_heads, emb_dim, device=device)
# (2, num_edges) random node-index pairs; duplicates and self-loops are possible.
edge_index = torch.randint(0, num_nodes, (2, num_edges), device=device)

In [36]:
# Per-edge, per-head *squared* Euclidean distance between the two endpoint
# embeddings -> shape (num_edges, num_heads). Indexing each endpoint directly
# avoids building the (2, num_edges, num_heads, emb_dim) tensor emb[edge_index] twice.
d = torch.sum((emb[edge_index[0]] - emb[edge_index[1]]) ** 2, dim=-1)

In [37]:
d.size()

torch.Size([200, 4])

In [39]:
# Gather the hidden features at edge_index[0] and scale each head by that edge's
# distance; the added trailing axis lets d broadcast over hidden_dim.
weighted = hidden[edge_index[0]] * d[..., None]

In [42]:
weighted.size()

torch.Size([200, 4, 32])

In [None]:
# Reference only: this aggregation call appears to be copied from model code
# (note the trailing comma from an argument list). Its names -
# `hidden_edge_features`, `end`, `hidden_features` - are not defined in this
# notebook, so executing it raised NameError on Restart & Run All. The next cell
# reproduces the same pattern with this notebook's tensors:
# scatter_add(hidden_edge_features, end, dim=0, dim_size=hidden_features.shape[0]),

In [43]:
# Sum the weighted edge messages onto the edge_index[1] nodes; dim_size keeps
# nodes that receive no edges as zero rows.
weighted_sum = scatter_add(src=weighted, index=edge_index[1], dim=0, dim_size=num_nodes)

In [45]:
weighted_sum.size()

torch.Size([100, 4, 32])

In [46]:
from torch import nn

def make_multi_mlp(
    input_size,
    sizes,
    heads,
    hidden_activation="ReLU",
    output_activation="ReLU",
    layer_norm=False,
    batch_norm=False,
):
    """Construct an MLP with specified fully-connected layers.

    Args:
        input_size: Number of input features of the first ``nn.Linear``.
        sizes: Output width of each ``nn.Linear`` layer, in order. Any
            non-empty sequence of ints (list or tuple) is accepted.
        heads: Currently unused - the returned network is one single MLP
            whatever this value is. NOTE(review): kept only for call-site
            compatibility; a per-head variant is not implemented.
        hidden_activation: Name of a ``torch.nn`` activation class inserted
            after every hidden layer (e.g. ``"ReLU"``).
        output_activation: Name of a ``torch.nn`` activation class inserted
            after the final layer, or ``None`` for a purely linear output (in
            which case no normalization is added to the final layer either).
        layer_norm: If True, add a non-affine ``nn.LayerNorm`` before each
            activation.
        batch_norm: If True, add a non-affine ``nn.BatchNorm1d`` without
            running statistics before each activation.

    Returns:
        nn.Sequential: the assembled network.

    Raises:
        ValueError: If ``sizes`` is empty.
        AttributeError: If an activation name is not an attribute of ``torch.nn``.
    """
    # Build a fresh list (never mutates the caller's `sizes`; also accepts tuples).
    sizes = [input_size, *sizes]
    if len(sizes) < 2:
        raise ValueError("sizes must contain at least one layer width")

    hidden_activation = getattr(nn, hidden_activation)
    if output_activation is not None:
        output_activation = getattr(nn, output_activation)

    def _norm_layers(width):
        """Return the normalization layers (possibly none) requested for `width` features."""
        norms = []
        if layer_norm:
            norms.append(nn.LayerNorm(width, elementwise_affine=False))
        if batch_norm:
            norms.append(nn.BatchNorm1d(width, track_running_stats=False, affine=False))
        return norms

    layers = []
    # Hidden layers: every consecutive pair of widths except the last pair.
    for in_width, out_width in zip(sizes[:-2], sizes[1:-1]):
        layers.append(nn.Linear(in_width, out_width))
        layers.extend(_norm_layers(out_width))
        layers.append(hidden_activation())
    # Final layer: normalization + activation only when an output activation is requested.
    layers.append(nn.Linear(sizes[-2], sizes[-1]))
    if output_activation is not None:
        layers.extend(_norm_layers(sizes[-1]))
        layers.append(output_activation())
    return nn.Sequential(*layers)

In [47]:
make_multi_mlp(input_size=32, sizes=[16, 8], heads=1)

Sequential(
  (0): Linear(in_features=32, out_features=16, bias=True)
  (1): ReLU()
  (2): Linear(in_features=16, out_features=8, bias=True)
  (3): ReLU()
)

In [49]:
# nn.Linear accepts only integer feature counts, so per-head shapes such as
# [3, 32] -> [3, 16] are rejected; a multi-head layer needs a different
# construction. Catch and show the error so Restart & Run All does not stop here.
try:
    nn.Linear([3, 32], [3, 16])
except TypeError as err:
    linear_error = err
    print(f"{type(err).__name__}: {err}")

TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
