base_distance.py
import torch

from ..utils.module_with_records import ModuleWithRecords


class BaseDistance(ModuleWithRecords):
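    """Base class for distance and similarity measures between embeddings.

    Args:
        normalize_embeddings: if True, embeddings are Lp-normalized before
            any distances are computed.
        p: the norm degree used for normalization and for norm statistics.
        power: every computed distance is raised to this power.
        is_inverted: if True, larger values mean "closer" (a similarity
            measure); if False, smaller values mean "closer".
    """
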
def __init__(
self, normalize_embeddings=True, p=2, power=1, is_inverted=False, **kwargs
):
super().__init__(**kwargs)
self.normalize_embeddings = normalize_embeddings
self.p = p
self.power = power
self.is_inverted = is_inverted
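        # Expose hyperparameters, and the norm statistics below, to the
        # record-keeping machinery inherited from ModuleWithRecords.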
self.add_to_recordable_attributes(list_of_names=["p", "power"], is_stat=False)
self.add_to_recordable_attributes(
list_of_names=[
"initial_avg_query_norm",
"initial_avg_ref_norm",
"final_avg_query_norm",
"final_avg_ref_norm",
],
is_stat=True,
)

    def forward(self, query_emb, ref_emb=None):
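        # Returns a (num_query, num_ref) matrix comparing every query
        # embedding with every reference embedding.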
self.reset_stats()
self.check_shapes(query_emb, ref_emb)
query_emb_normalized = self.maybe_normalize(query_emb)
if ref_emb is None:
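            # Self-comparison: reuse the query embeddings (and their
            # normalized copies) as the reference set.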
ref_emb = query_emb
ref_emb_normalized = query_emb_normalized
else:
ref_emb_normalized = self.maybe_normalize(ref_emb)
self.set_default_stats(
query_emb, ref_emb, query_emb_normalized, ref_emb_normalized
)
mat = self.compute_mat(query_emb_normalized, ref_emb_normalized)
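        # Optionally raise every entry of the matrix to a fixed power.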
if self.power != 1:
mat = mat**self.power
assert mat.size() == torch.Size((query_emb.size(0), ref_emb.size(0)))
return mat

    def compute_mat(self, query_emb, ref_emb):
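        # Subclasses must return the full pairwise matrix of shape
        # (num_query, num_ref).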
raise NotImplementedError

    def pairwise_distance(self, query_emb, ref_emb):
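        # Subclasses must return the elementwise distances between
        # corresponding rows of query_emb and ref_emb.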
raise NotImplementedError

    def smallest_dist(self, *args, **kwargs):
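        # For inverted (similarity-style) measures, the smallest distance
        # corresponds to the largest value, so max and min swap roles.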
if self.is_inverted:
return torch.max(*args, **kwargs)
return torch.min(*args, **kwargs)

    def largest_dist(self, *args, **kwargs):
if self.is_inverted:
return torch.min(*args, **kwargs)
return torch.max(*args, **kwargs)

    # The margin between x and y: positive when x is the larger distance.
    # For inverted (similarity-style) measures the sign flips accordingly.
    def margin(self, x, y):
if self.is_inverted:
return y - x
return x - y

    def normalize(self, embeddings, dim=1, **kwargs):
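        # Lp-normalize along the embedding dimension, using the same p
        # as the distance itself.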
return torch.nn.functional.normalize(embeddings, p=self.p, dim=dim, **kwargs)

    def maybe_normalize(self, embeddings, dim=1, **kwargs):
if self.normalize_embeddings:
return self.normalize(embeddings, dim=dim, **kwargs)
return embeddings

    def get_norm(self, embeddings, dim=1, **kwargs):
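        # Per-embedding Lp norm, used for the recorded norm statistics.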
return torch.norm(embeddings, p=self.p, dim=dim, **kwargs)

    def set_default_stats(
self, query_emb, ref_emb, query_emb_normalized, ref_emb_normalized
):
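        # Record the average embedding norms before and after normalization;
        # skipped entirely unless stat collection is enabled.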
if self.collect_stats:
with torch.no_grad():
self.initial_avg_query_norm = torch.mean(
self.get_norm(query_emb)
).item()
self.initial_avg_ref_norm = torch.mean(self.get_norm(ref_emb)).item()
self.final_avg_query_norm = torch.mean(
self.get_norm(query_emb_normalized)
).item()
self.final_avg_ref_norm = torch.mean(
self.get_norm(ref_emb_normalized)
).item()

    def check_shapes(self, query_emb, ref_emb):
if query_emb.ndim != 2 or (ref_emb is not None and ref_emb.ndim != 2):
raise ValueError(
"embeddings must be a 2D tensor of shape (batch_size, embedding_size)"
)
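

# A minimal illustrative sketch (not part of this module) of how a concrete
# subclass might fill in compute_mat and pairwise_distance for plain
# Euclidean distance. The class name is hypothetical; the library ships its
# own concrete distance classes.
class _ExampleEuclideanDistance(BaseDistance):
    def compute_mat(self, query_emb, ref_emb):
        # Full (num_query, num_ref) matrix of p-norm distances.
        return torch.cdist(query_emb, ref_emb, p=self.p)

    def pairwise_distance(self, query_emb, ref_emb):
        # Row-wise distances between corresponding embeddings.
        return torch.nn.functional.pairwise_distance(query_emb, ref_emb, p=self.p)


# Usage sketch:
#     dist = _ExampleEuclideanDistance()
#     mat = dist(torch.randn(4, 128), torch.randn(6, 128))  # mat.shape == (4, 6)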