# Test refactor of Graph Intersection

In [1]:
%load_ext autoreload
%autoreload 2

# System imports
import os
import sys
import yaml

import torch
import numpy as np
import scipy as sp

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

##  Definition

In [4]:
def graph_intersection_v1(
    pred_graph, truth_graph, using_weights=False, weights_bidir=None
):

    array_size = max(pred_graph.max().item(), truth_graph.max().item()) + 1

    l1 = pred_graph.cpu().numpy() if torch.is_tensor(pred_graph) else pred_graph
    l2 = truth_graph.cpu().numpy() if torch.is_tensor(truth_graph) else truth_graph
    e_1 = sp.sparse.coo_matrix(
        (np.ones(l1.shape[1]), l1), shape=(array_size, array_size)
    ).tocsr()
    e_2 = sp.sparse.coo_matrix(
        (np.ones(l2.shape[1]), l2), shape=(array_size, array_size)
    ).tocsr()
    del l1

    e_intersection = e_1.multiply(e_2) - ((e_1 - e_2) > 0)
    del e_1
    del e_2

    if using_weights:
        weights_list = weights_bidir.cpu().numpy()
        weights_sparse = sp.sparse.coo_matrix(
            (weights_list, l2), shape=(array_size, array_size)
        ).tocsr()
        del weights_list
        del l2
        new_weights = weights_sparse[e_intersection.astype("bool")]
        del weights_sparse
        new_weights = torch.from_numpy(np.array(new_weights)[0])

    e_intersection = e_intersection.tocoo()
    new_pred_graph = torch.from_numpy(
        np.vstack([e_intersection.row, e_intersection.col])
    ).long().to(pred_graph.device)
    y = torch.from_numpy(e_intersection.data > 0).to(pred_graph.device)
    del e_intersection

    if using_weights:
        return new_pred_graph, y, new_weights
    else:
        return new_pred_graph, y

In [5]:
from torch_geometric.utils import to_scipy_sparse_matrix

def graph_intersection_v2(
    pred_graph, truth_graph, using_weights=False, weights_bidir=None
):

    array_size = max(pred_graph.max().item(), truth_graph.max().item()) + 1

    e_1 = to_scipy_sparse_matrix(pred_graph).tocsr()
    e_2 = to_scipy_sparse_matrix(truth_graph).tocsr()

    e_intersection = e_1.multiply(e_2) - ((e_1 - e_2) > 0)

    e_intersection = e_intersection.tocoo()
    new_pred_graph = torch.from_numpy(
        np.vstack([e_intersection.row, e_intersection.col])
    ).long().to(pred_graph.device)
    y = torch.from_numpy(e_intersection.data > 0).to(pred_graph.device)

    return new_pred_graph, y

In [102]:
def graph_intersection_v3(pred_graph, truth_graph, is_unique=True):
    if not is_unique:
        raise NotImplementedError("Not implemented for non-unique graphs")
        pred_graph, truth_graph = torch.unique(pred_graph, dim=1), torch.unique(truth_graph, dim=1)

    unique_edges, inverse, counts = torch.unique(torch.cat([pred_graph, truth_graph], dim=1), dim=1, sorted=False, return_inverse=True, return_counts=True)

    pred_counts = counts[inverse[:pred_graph.shape[1]]]

    y = pred_counts > 1

    return pred_graph, y

In [38]:
pred = torch.randint(0, 100000, (2, 1000000)).to(device)
truth = torch.randint(0, 100000, (2, 100000)).to(device)

In [39]:
pred_out_1, y_1 = graph_intersection_v1(pred, truth)

In [40]:
pred_out_2, y_2 = graph_intersection_v2(pred, truth)

In [43]:
y_3 = graph_intersection_v3(pred, truth)

In [44]:
y_1.sum(), y_2.sum(), y_3.sum()

(tensor(4), tensor(4), tensor(100000))

In [107]:
# Check that pred_out_2 is the same as pred_out_1 and y_2 is the same as y_1
print((pred_out_1 == pred_out_2).all(), (y_1 == y_2).all())
print((pred_out_1 == pred_out_3).all(), (y_1 == y_3).all())

tensor(True) tensor(True)


RuntimeError: The size of tensor a (999959) must match the size of tensor b (1000000) at non-singleton dimension 1

## Torch Sparse Tensor approach

In [61]:
from torch.utils.benchmark import Timer

In [157]:
# time the two functions
timer = Timer(
    stmt="graph_intersection_v1(pred, truth)",
    globals=globals(),
    label="v1",
    sub_label="",
)
timer.blocked_autorange(min_run_time=1)

<torch.utils.benchmark.utils.common.Measurement object at 0x150f2506fe50>
v1
  Median: 116.65 ms
  IQR:    11.93 ms (105.32 to 117.26)
  9 measurements, 1 runs per measurement, 1 thread
           This could indicate system fluctuation.

In [158]:
timer = Timer(
    stmt="graph_intersection_v2(pred, truth)",
    globals=globals(),
    label="v2",
    sub_label="",
)

timer.blocked_autorange(min_run_time=1)

<torch.utils.benchmark.utils.common.Measurement object at 0x150f24f86e20>
v2
  Median: 109.46 ms
  IQR:    5.62 ms (105.66 to 111.28)
  10 measurements, 1 runs per measurement, 1 thread

In [159]:
timer = Timer(
    stmt="graph_intersection_v3(pred, truth, is_unique=False)",
    globals=globals(),
    label="v3",
    sub_label="",
)

timer.blocked_autorange(min_run_time=1)

<torch.utils.benchmark.utils.common.Measurement object at 0x150f25080580>
v3
  Median: 4.51 ms
  IQR:    0.02 ms (4.50 to 4.52)
  21 measurements, 10 runs per measurement, 1 thread

In [66]:
pred = torch.Tensor([[0, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 7]])
truth = torch.Tensor([[0, 1, 2, 3], [2, 1, 3, 4]])

In [67]:
y, truth_to_pred = graph_intersection_v3(pred, truth, return_truth_to_pred=True)

In [68]:
y

tensor([False, False,  True,  True, False, False, False])

In [69]:
truth_to_pred

tensor([-1, -1,  2,  3])

In [21]:
truth

tensor([[0., 1., 2., 3.],
        [2., 1., 3., 4.]])

In [22]:
pred

tensor([[0., 1., 2., 3.],
        [1., 2., 3., 4.]])

In [8]:
y

tensor([False, False,  True,  True])

In [92]:
y_3 = graph_intersection_v3(pred, truth)

torch.Size([1999816])
torch.Size([1000000])
torch.Size([1999816])
torch.Size([1000000])


In [93]:
y_3

tensor([False, False, False,  ..., False, False, False])

In [54]:
unique_edges, inverse, _ = torch.unique(torch.cat([pred, truth], dim=1), dim=1, sorted=False, return_inverse=True, return_counts=True)
# pred_counts = counts[inverse[:pred.shape[1]]]
# truth_counts = counts[inverse[pred.shape[1]:]]
# y_pred = pred_counts > 1
# y_truth = truth_counts > 1

In [29]:
all_edges = torch.cat([pred, truth], dim=1)

#### So, we have several concepts
- All edges: This is the concat of [pred, truth] edges
- Unique edges: This is the unique edges of [pred, truth] edges
- Inverse: This is the map from all_edges to unique_edges
- Inverse to pred: This is the map from unique_edges to pred_edges
- Inverse to truth: This is the map from unique_edges to truth_edges

To get the truth to pred map, we map from the truth part of inverse (i.e. inverse[num_pred:]) then use the inverse-to-pred map

In [55]:
inverse_pred_map = torch.ones(unique_edges.shape[1], dtype=torch.long) * -1
inverse_pred_map[inverse[:pred.shape[1]]] = torch.arange(pred.shape[1])

inverse_truth_map = torch.ones(unique_edges.shape[1], dtype=torch.long) * -1
inverse_truth_map[inverse[pred.shape[1]:]] = torch.arange(truth.shape[1])

In [65]:
def graph_intersection_v3(input_pred_graph, input_truth_graph,return_y_pred=True, return_y_truth=False, return_pred_to_truth=False, return_truth_to_pred=False):

    unique_edges, inverse = torch.unique(torch.cat([input_pred_graph, input_truth_graph], dim=1), dim=1, sorted=False, return_inverse=True, return_counts=False)

    inverse_pred_map = torch.ones(unique_edges.shape[1], dtype=torch.long) * -1
    inverse_pred_map[inverse[:input_pred_graph.shape[1]]] = torch.arange(input_pred_graph.shape[1])
    
    inverse_truth_map = torch.ones(unique_edges.shape[1], dtype=torch.long) * -1
    inverse_truth_map[inverse[input_pred_graph.shape[1]:]] = torch.arange(input_truth_graph.shape[1])

    pred_to_truth = inverse_truth_map[inverse][:input_pred_graph.shape[1]]
    truth_to_pred = inverse_pred_map[inverse][input_pred_graph.shape[1]:]

    return_tensors = []

    if return_y_pred:
        y_pred = pred_to_truth >= 0
        return_tensors.append(y_pred)

    if return_y_truth:
        y_truth = truth_to_pred >= 0
        return_tensors.append(y_truth)

    if return_pred_to_truth:        
        return_tensors.append(pred_to_truth)

    if return_truth_to_pred:
        return_tensors.append(truth_to_pred)

    return return_tensors if len(return_tensors) > 1 else return_tensors[0]

In [56]:
inverse_pred_map

tensor([ 0, -1, -1,  1,  2,  3,  4,  5,  6])

In [57]:
inverse_truth_map

tensor([-1,  0,  1, -1,  2,  3, -1, -1, -1])

In [63]:
inverse

tensor([0, 3, 4, 5, 6, 7, 8, 1, 2, 4, 5])

In [61]:
truth

tensor([[0., 1., 2., 3.],
        [2., 1., 3., 4.]])

In [62]:
pred

tensor([[0., 1., 2., 3., 4., 5., 6.],
        [1., 2., 3., 4., 5., 6., 7.]])

In [58]:
truth_to_pred = inverse_pred_map[inverse][pred.shape[1]:]
pred_to_truth = inverse_truth_map[inverse][:pred.shape[1]]

In [59]:
truth_to_pred

tensor([-1, -1,  2,  3])

In [60]:
pred_to_truth >= 0

tensor([False, False,  True,  True, False, False, False])