Find best functions

#### Import Python libraries

In [56]:
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

%load_ext autoreload
%autoreload 2
%matplotlib inline

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [57]:
# Project roots, relative to the notebook's working directory. They must be
# defined BEFORE they are used, otherwise this cell raises NameError on a
# fresh kernel (Restart & Run All).
reid_root_dir = ".."
root_dir = ".."

# Make the project's `src` package directory importable. The membership check
# keeps the cell idempotent: re-running it does not pile up duplicate entries
# in sys.path.
src_dir = os.path.join(reid_root_dir, "src")
if src_dir not in sys.path:
    sys.path.append(src_dir)

In [None]:
import torch
from scipy.optimize import linear_sum_assignment as linear_assignment
from torch.nn import functional as F

import motmetrics as mm
from market import metrics, utils

# Load helper code
from market.datamanager import ImageDataManager
from market.models import build_model
from mot.data.data_obj_detect import MOT16ObjDetect
from mot.data.data_track import MOT16Sequences
from mot.models.object_detector import FRCNN_FPN
from mot.tracker.base import Tracker
from mot.utils import (
    evaluate_mot_accums,
    get_mot_accum,
    obj_detect_transforms,
    plot_sequence,
)
from mot.eval import evaluate_obj_detect
# Select the solver motmetrics uses by default for linear assignment problems.
# NOTE(review): presumably this requires the optional `lap` package to be
# installed - confirm against the environment.
mm.lap.default_solver = "lap"

## Setup

In [59]:
# Seed every random number generator this notebook touches so that the random
# tensors created in later cells are intended to be repeatable.
seed = 12345

torch.manual_seed(seed)  # PyTorch RNG
torch.cuda.manual_seed(seed)  # RNG of the current CUDA device
np.random.seed(seed)  # legacy global NumPy RNG
# Ask cuDNN to use deterministic algorithms only.
torch.backends.cudnn.deterministic = True

# Compare different Euclidean distance functions

In [18]:
def euclidean_squared_distance_v1(input1, input2):
    """Squared Euclidean distance between every row pair, via ``torch.cdist``.

    The plain (p=2) distances are computed first and then squared
    element-wise.

    Args:
        input1 (torch.Tensor): 2-D feature matrix.
        input2 (torch.Tensor): 2-D feature matrix.
    Returns:
        torch.Tensor: distance matrix with one row per row of ``input1`` and
            one column per row of ``input2``.
    """
    return torch.cdist(input1, input2, p=2.0).pow(2)


def euclidean_squared_distance_v2(input1, input2):
    """Squared Euclidean distance between every row pair, via the expansion
    ``|a - b|^2 = |a|^2 - 2 * a.b + |b|^2``.

    Args:
        input1 (torch.Tensor): 2-D feature matrix.
        input2 (torch.Tensor): 2-D feature matrix.
    Returns:
        torch.Tensor: distance matrix with one row per row of ``input1`` and
            one column per row of ``input2``.
    """
    # Squared norms: a column vector for input1 and a row vector for input2,
    # so that their sum broadcasts to the full (m, n) grid of |a|^2 + |b|^2.
    sq_norms_1 = input1.pow(2).sum(dim=1, keepdim=True)
    sq_norms_2 = input2.pow(2).sum(dim=1, keepdim=True).t()
    norm_sums = sq_norms_1 + sq_norms_2
    # norm_sums - 2 * (input1 @ input2^T)
    return torch.addmm(norm_sums, input1, input2.t(), beta=1, alpha=-2)

## explain euclidean_squared_distance_v1

In [26]:
input1 = torch.randn((3, 100))

In [27]:
# Without `dim`, norm() reduces over the whole tensor and yields one scalar.
print(input1.norm())
# With dim=1 there is one norm per row -> shape (3,).
print(input1.norm(dim=1).shape)
# [:, None] appends an axis -> shape (3, 1), which broadcasts against (3, 100).
print(input1.norm(dim=1)[:, None].shape)
print(input1.shape)
# Divide every row by its own norm; the shape is unchanged.
input1_norm = input1 / input1.norm(dim=1)[:, None]
input1_norm.shape

tensor(16.3962)
torch.Size([3])
torch.Size([3, 1])
torch.Size([3, 100])


torch.Size([3, 100])

### torch.mm

In [29]:
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)

tensor([[-0.5445,  1.0862,  1.4717],
        [-0.0646,  0.2895,  0.2801]])

### compare torch.mm and  torch.matmul

In [31]:
# Two random 2-D matrices sharing the feature dimension (100).
input1 = torch.randn((12, 100))
input2 = torch.randn((33, 100))
print(input1.shape, input2.shape)
# Same product computed through both APIs: (12, 100) @ (100, 33) -> (12, 33).
res_mm = torch.mm(input1, input2.t())
res_matmul = torch.matmul(input1, input2.t())
print("res_mm.shape", res_mm.shape)
print("res_matmul", res_matmul.shape)

torch.Size([12, 100]) torch.Size([33, 100])
res_mm.shape torch.Size([12, 33])
res_matmul torch.Size([12, 33])


In [34]:
assert (res_matmul == res_mm).all(), "res_matmul!=res_mm"

## compare euclidean_squared_distance_v2 and euclidean_squared_distance_v1

In [35]:
# Small random inputs with different row counts, so a transposed result
# would show up as a shape mismatch.
input1 = torch.randn((3, 5))
input2 = torch.randn((4, 5))

res_v1 = euclidean_squared_distance_v1(input1, input2)
res_v2 = euclidean_squared_distance_v2(input1, input2)
# allclose rather than exact equality: the two functions compute the result
# through different floating-point operations.
assert torch.allclose(res_v1, res_v2), "results are not equal"
assert res_v1.shape == res_v2.shape, "shapes are not equal"
print("shape ", res_v1.shape)

shape  torch.Size([3, 4])


## timeit both funcs 

In [36]:
input1 = torch.randn((1000, 5120))
input2 = torch.randn((3000, 5120))

In [37]:
%timeit euclidean_squared_distance_v1(input1, input2)

42.6 ms ± 3.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [38]:
%timeit euclidean_squared_distance_v2(input1, input2)

34.7 ms ± 3.47 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


# select with mask

###  1d_tensor

In [39]:
tensor = torch.rand(7)
mask = torch.randint(3, (7,))
print("tensor", tensor)
print("mask", mask)

tensor tensor([0.5632, 0.3211, 0.0368, 0.4203, 0.7256, 0.4745, 0.0292])
mask tensor([2, 1, 0, 0, 0, 0, 2])


In [40]:
# Turn the integer label vector into a boolean mask for a single class.
class_id = 1
mask_binary = mask == class_id
print("tensor", tensor)
print("mask_similar", mask_binary)
# Three equivalent ways to keep only the entries whose label equals 1:
# masked_select, boolean indexing with a stored mask, and an inline comparison.
print("select only vals with mask = 1", torch.masked_select(tensor, mask_binary))
print("select only vals with mask = 1", tensor[mask_binary])
print("select only vals with mask = 1", tensor[mask == 1])

tensor tensor([0.5632, 0.3211, 0.0368, 0.4203, 0.7256, 0.4745, 0.0292])
mask_similar tensor([False,  True, False, False, False, False, False])
select only vals with mask = 1 tensor([0.3211])
select only vals with mask = 1 tensor([0.3211])
select only vals with mask = 1 tensor([0.3211])


### 2d 

In [41]:
def neg_pos_pairs(distance_matrix, targets):
    """Mine the hardest positive and hardest negative distance per sample.

    Baseline variant: uses ``torch.masked_select`` and Python's builtin
    ``max``/``min``, which walk the selected distances element by element.

    Args:
        distance_matrix (torch.Tensor): 2-D pairwise distance matrix.
        targets (torch.Tensor): 1-D class labels, one per row of
            ``distance_matrix``.
    Returns:
        tuple(torch.Tensor, torch.Tensor): for every row, the largest distance
            to a same-class sample and the smallest distance to a sample of
            another class.
    """
    hardest_pos, hardest_neg = [], []
    for anchor_idx in range(distance_matrix.size(0)):
        anchor_row = distance_matrix[anchor_idx]
        same_class = targets == targets[anchor_idx]
        same_class_dists = torch.masked_select(anchor_row, same_class)
        other_class_dists = torch.masked_select(anchor_row, ~same_class)
        hardest_neg.append(min(other_class_dists))
        hardest_pos.append(max(same_class_dists))

    return torch.stack(hardest_pos), torch.stack(hardest_neg)

In [42]:
def neg_pos_pairs2(distance_matrix, targets):
    """Mine the hardest positive and hardest negative distance per sample.

    Same as the baseline, but selects with boolean indexing instead of
    ``torch.masked_select``; the reduction still uses Python's builtin
    ``max``/``min``.

    Args:
        distance_matrix (torch.Tensor): 2-D pairwise distance matrix.
        targets (torch.Tensor): 1-D class labels, one per row of
            ``distance_matrix``.
    Returns:
        tuple(torch.Tensor, torch.Tensor): for every row, the largest distance
            to a same-class sample and the smallest distance to a sample of
            another class.
    """
    hardest_pos, hardest_neg = [], []
    for anchor_idx in range(distance_matrix.size(0)):
        anchor_row = distance_matrix[anchor_idx]
        same_class = targets == targets[anchor_idx]
        same_class_dists = anchor_row[same_class]
        other_class_dists = anchor_row[~same_class]
        hardest_neg.append(min(other_class_dists))
        hardest_pos.append(max(same_class_dists))

    return torch.stack(hardest_pos), torch.stack(hardest_neg)

In [43]:
def neg_pos_pairs3(distance_matrix, targets):
    """Mine the hardest positive and hardest negative distance per sample.

    Builds the full same-class mask once up front and reduces each row with
    the tensor methods ``max``/``min``.

    Args:
        distance_matrix (torch.Tensor): 2-D pairwise distance matrix with
            ``n`` rows.
        targets (torch.Tensor): 1-D tensor of ``n`` class labels. Every
            sample is expected to have at least one sample of another class.
    Returns:
        tuple(torch.Tensor, torch.Tensor): two 1-D tensors of length ``n``:
            the largest same-class distance and the smallest other-class
            distance for every row.
    """
    n = distance_matrix.size(0)
    # mask[i, j] is True when samples i and j carry the same label.
    mask = targets.expand(n, n).eq(targets.expand(n, n).t())
    distance_positive_pairs, distance_negative_pairs = [], []
    for i in range(n):
        # Read from the `distance_matrix` argument, not from a notebook-level
        # global, so the result always reflects the matrix passed in.
        distance_positive_pairs.append(
            distance_matrix[i][mask[i]].max().unsqueeze(0)
        )
        distance_negative_pairs.append(
            distance_matrix[i][mask[i] == 0].min().unsqueeze(0)
        )
    distance_positive_pairs = torch.cat(distance_positive_pairs)
    distance_negative_pairs = torch.cat(distance_negative_pairs)
    return distance_positive_pairs, distance_negative_pairs

In [44]:
def neg_pos_pairs4(distance_matrix, targets):
    """Mine the hardest positive and hardest negative distance per sample.

    Builds the full same-class mask once and reduces each row with the tensor
    methods ``max``/``min``, using named intermediates for readability.

    Args:
        distance_matrix (torch.Tensor): 2-D pairwise distance matrix with
            ``n`` rows.
        targets (torch.Tensor): 1-D tensor of ``n`` class labels.
    Returns:
        tuple(torch.Tensor, torch.Tensor): two 1-D tensors of length ``n``:
            the largest same-class distance and the smallest other-class
            distance for every row.
    """
    num_samples = distance_matrix.size(0)
    label_grid = targets.expand(num_samples, num_samples)
    # same_class[i, j] is True when samples i and j carry the same label.
    same_class = label_grid.eq(label_grid.t())

    hardest_pos, hardest_neg = [], []
    for anchor_idx in range(num_samples):
        anchor_row = distance_matrix[anchor_idx]
        anchor_same = same_class[anchor_idx]
        hardest_pos.append(anchor_row[anchor_same].max().unsqueeze(0))
        hardest_neg.append(anchor_row[~anchor_same].min().unsqueeze(0))

    return torch.cat(hardest_pos), torch.cat(hardest_neg)

## expand mask 

In [45]:
n = 7
targets = torch.randint(3, (n,))
targets

tensor([1, 0, 1, 2, 1, 2, 1])

In [46]:
# Broadcast the label vector to an (n, n) grid and compare it with its own
# transpose: mask[i, j] is True exactly when targets[i] == targets[j].
mask = targets.expand(n, n).eq(targets.expand(n, n).t())
mask

tensor([[ True, False,  True, False,  True, False,  True],
        [False,  True, False, False, False, False, False],
        [ True, False,  True, False,  True, False,  True],
        [False, False, False,  True, False,  True, False],
        [ True, False,  True, False,  True, False,  True],
        [False, False, False,  True, False,  True, False],
        [ True, False,  True, False,  True, False,  True]])

In [47]:
# Toy data for checking the pair-mining functions: 7 random 2-D points.
input1 = torch.randn((7, 2))

# Pairwise Euclidean distances of the points to themselves -> (7, 7) matrix
# with a zero diagonal.
dist = torch.cdist(input1, input1, p=2.0)
# One random label in {0, 1, 2} per point.
targets = torch.randint(3, (dist.size(0),))

dist, targets

(tensor([[0.0000, 1.0912, 2.3678, 2.2721, 3.8014, 1.9302, 2.0710],
         [1.0912, 0.0000, 2.0342, 1.9520, 3.6402, 2.5636, 2.4721],
         [2.3678, 2.0342, 0.0000, 0.0959, 1.6148, 1.9938, 1.5082],
         [2.2721, 1.9520, 0.0959, 0.0000, 1.6911, 1.9319, 1.4580],
         [3.8014, 3.6402, 1.6148, 1.6911, 0.0000, 2.6299, 2.1079],
         [1.9302, 2.5636, 1.9938, 1.9319, 2.6299, 0.0000, 0.5444],
         [2.0710, 2.4721, 1.5082, 1.4580, 2.1079, 0.5444, 0.0000]]),
 tensor([2, 0, 1, 2, 2, 0, 0]))

In [48]:
# Run all four implementations on the same distance matrix and labels so
# their outputs can be compared.
distance_positive_pairs, distance_negative_pairs = neg_pos_pairs(dist, targets)
distance_positive_pairs2, distance_negative_pairs2 = neg_pos_pairs2(dist, targets)
distance_positive_pairs3, distance_negative_pairs3 = neg_pos_pairs3(dist, targets)
distance_positive_pairs4, distance_negative_pairs4 = neg_pos_pairs4(dist, targets)

In [49]:
# Every implementation must agree with the baseline on BOTH outputs.
# Hardest-positive distances:
assert (
    distance_positive_pairs == distance_positive_pairs3
).all(), "func results are not equal"
assert (
    distance_positive_pairs == distance_positive_pairs2
).all(), "func results are not equal"
assert (
    distance_positive_pairs == distance_positive_pairs4
).all(), "func results are not equal"
# Hardest-negative distances. Without these checks a bug in the negative
# branch of any implementation would go unnoticed.
assert (
    distance_negative_pairs == distance_negative_pairs3
).all(), "negative-pair results are not equal"
assert (
    distance_negative_pairs == distance_negative_pairs2
).all(), "negative-pair results are not equal"
assert (
    distance_negative_pairs == distance_negative_pairs4
).all(), "negative-pair results are not equal"

### speed

In [50]:
input1 = torch.randn((500, 512))
dist = torch.cdist(input1, input1, p=2.0)
targets = torch.randint(20, (dist.size(0),))

In [51]:
%timeit neg_pos_pairs(dist,targets)

687 ms ± 5.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [52]:
%timeit neg_pos_pairs2(dist,targets)

690 ms ± 6.45 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [53]:
%timeit neg_pos_pairs3(dist, targets)

17.3 ms ± 27.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [54]:
%timeit neg_pos_pairs4(dist,targets)

16.3 ms ± 75.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
