# Matching visual relationships

In [1]:
import torch
from torch_scatter import scatter_mean
from torch_geometric.data import Data, Batch
from torchvision.ops.boxes import box_iou

torch.manual_seed(42);

C_objects = 4
C_predicates = 3
H = 480
W = 640

In [2]:
def random_boxes(num_boxes, min_size=20):
    WH = torch.tensor([W, H]).float()
    x1y1 = torch.rand(size=(num_boxes, 2)) * (WH - min_size)
    x2y2 = x1y1 + min_size + (torch.rand(size=(num_boxes, 2)) * (WH - x1y1 - min_size))
    return torch.cat((x1y1, x2y2), dim=1)

def ground_truth(num_nodes, num_relations):
    return Data(
        num_nodes=num_nodes,
        n_nodes=num_nodes,
        n_edges=num_relations,
        
        object_classes=torch.randint(C_objects, size=(num_nodes,)),
        object_boxes=random_boxes(num_nodes),
        
        predicate_classes=torch.randint(C_predicates, size=(num_relations,)),
        relation_indexes=torch.randint(num_nodes, size=(2, num_relations)),
    )

t1 = ground_truth(3, 4)
t2 = ground_truth(5, 8)
targets = Batch.from_data_list([t1, t2])

print('g', 'so_idx', 'spo_class', sep='\t')
print('-------------------------')
for graph_idx, subj_idx, subj_class, predicate_class, obj_idx, obj_class in zip(
    targets.batch[targets.relation_indexes[0]].numpy(),
    targets.relation_indexes[0].numpy(),
    targets.object_classes[targets.relation_indexes[0]].numpy(),
    targets.predicate_classes.numpy(),
    targets.relation_indexes[1].numpy(),
    targets.object_classes[targets.relation_indexes[1]].numpy(),
):
    print(graph_idx, (subj_idx, obj_idx), (subj_class, predicate_class, obj_class), sep='\t')

g	so_idx	spo_class
-------------------------
0	(1, 2)	(3, 1, 0)
0	(1, 2)	(3, 2, 0)
0	(2, 1)	(0, 1, 3)
0	(1, 2)	(3, 0, 0)
1	(4, 4)	(3, 1, 3)
1	(3, 4)	(1, 0, 3)
1	(3, 5)	(1, 1, 3)
1	(3, 7)	(1, 1, 0)
1	(3, 7)	(1, 1, 0)
1	(4, 6)	(3, 2, 0)
1	(6, 6)	(0, 2, 0)
1	(3, 7)	(1, 0, 0)


## Predicate prediction

The first requirement for a match is that subject, object and predicate classes match:
- `subject_class`
- `predicate_class`
- `object_class`

Also, the boxes need to match. In this case, `predictions` and `targets` share the same object instances, 
so it's enough to match their indexes:
- `graph_index`
- `subject_index`
- `object_index`

Since the box indexes were offset when batching, `graph_index` is implicitly included in `subject_index` and `object_index`.

In [3]:
def prediction_using_gt_objects(num_relations, target: Data):
    return Data(
        num_nodes=target.num_nodes,
        n_nodes=target.n_nodes,
        n_edges=num_relations,
        
        object_classes=target.object_classes,
        object_boxes=target.object_boxes,
        
        predicate_classes=torch.randint(C_predicates, size=(num_relations,)),
        relation_indexes=torch.randint(target.num_nodes, size=(2, num_relations)),
        relation_scores=torch.rand(size=(num_relations,)).sort(descending=True)[0],
    )

p1 = prediction_using_gt_objects(10, target=t1)
p2 = prediction_using_gt_objects(10, target=t2)
predictions = Batch.from_data_list([p1, p2])

print('g', 'so_idx', 'spo_class', ' score', sep='\t')
print('--------------------------------------')
for graph_idx, subj_idx, subj_class, predicate_class, obj_idx, obj_class, rel_score in zip(
        predictions.batch[predictions.relation_indexes[0]].numpy(),
        predictions.relation_indexes[0].numpy(),
        predictions.object_classes[predictions.relation_indexes[0]].numpy(),
        predictions.predicate_classes.numpy(),
        predictions.relation_indexes[1].numpy(),
        predictions.object_classes[predictions.relation_indexes[1]].numpy(),
        predictions.relation_scores.numpy(),
):
    print(graph_idx, (subj_idx, obj_idx), (subj_class, predicate_class, obj_class), f'{rel_score:>6.1%}', sep='\t')

g	so_idx	spo_class	 score
--------------------------------------
0	(1, 1)	(3, 0, 3)	100.0%
0	(0, 1)	(2, 1, 3)	 65.4%
0	(2, 2)	(0, 2, 0)	 59.4%
0	(1, 2)	(3, 2, 0)	 33.4%
0	(1, 1)	(3, 0, 3)	 22.5%
0	(2, 0)	(0, 1, 2)	 18.2%
0	(0, 0)	(2, 2, 2)	 17.2%
0	(0, 2)	(2, 0, 0)	  7.6%
0	(1, 1)	(3, 1, 3)	  6.2%
0	(2, 0)	(0, 2, 2)	  3.4%
1	(3, 4)	(1, 2, 3)	 95.5%
1	(5, 6)	(3, 1, 0)	 95.5%
1	(7, 3)	(0, 0, 1)	 90.4%
1	(4, 4)	(3, 0, 3)	 75.8%
1	(5, 7)	(3, 2, 0)	 62.6%
1	(6, 3)	(0, 1, 1)	 44.5%
1	(4, 3)	(3, 2, 1)	 28.5%
1	(3, 7)	(1, 2, 0)	 13.3%
1	(4, 3)	(3, 1, 1)	 12.6%
1	(6, 5)	(0, 0, 3)	 10.4%


Use some broadcasting trick to compare all 5 fields across all `NxM` pairs at once.
To have a match, all 5 fields must match.

In [4]:
# [E_t, 5]
gt_matrix = torch.stack([
    # subject_idx, object_idx
    targets.batch[targets.relation_indexes[0]],
    targets.batch[targets.relation_indexes[1]],
    # subject_class, predicate_class, object_class
    targets.object_classes[targets.relation_indexes[0]],
    targets.predicate_classes,
    targets.object_classes[targets.relation_indexes[1]],
], dim=1)

# [E_p, 5]
pred_matrix = torch.stack([
    # subject_idx, object_idx
    predictions.batch[predictions.relation_indexes[0]],
    predictions.batch[predictions.relation_indexes[1]],
    # subject_class, predicate_class, object_class
    predictions.object_classes[predictions.relation_indexes[0]],
    predictions.predicate_classes,
    predictions.object_classes[predictions.relation_indexes[1]],
], dim=1)

# Block matrix [E_p, E_t]
matches = (gt_matrix[None, :, :] == pred_matrix[:, None, :]).all(dim=2)

`matches`
```
                  Target relations

            E_t_1 = 6      E_t_2 = 9
          +-----------+-----------------+  
          |x          |                 |  0
P         |           |                 |
r         |    x      |                 |
e         |           |    Different    |
d   E_p_1 |           |      graph      |
i     =   |  x        |                 |
c    10   |           |    NO MATCH     |
t         |           |                 |
e         |      x    |                 |
d         |           |                 |  9
          +-----------+-----------------+ 
r         |           |      x          | 10
e         |           |  x              |
l         |           |              x  |
a         | Different |                 |
t   E_p_2 |   graph   |            x    |
i     =   |           |                 |
o    10   | NO MATCH  |    x            |
n         |           |                 |
s         |           |                 |
          |           |                 | 19
          +-----------+-----------------+
```       

`offset`
```
          +-----------+-----------------+
          |0 ....... 0|10 ........... 10|
          +-----------+-----------------+
```

`matches.any(0)`
```
          +-----------+-----------------+
          |x x x x - -|- x x x - - x x -|
          +-----------+-----------------+
```

`matches.argmax(0) - offset`
```
          +-----------+-----------------+
          |0 5 2 8 - -|- 1 6 0 - - 4 2 -|
          +-----------+-----------------+
```

Were ground-truth relations retrieved among the predicted relations? 

If so, at what index in the list of predictions for that graph? (rank is zero-based)

In [5]:
# matches.argmax(dim=0) will return the last index if no True value is found.
# We can use matches.any(dim=0) to ignore those cases.
# Also, we must account for the row offset in the matches matrix.
gt_retrieved = matches.any(dim=0)

offset = predictions.n_edges.cumsum(dim=0).repeat_interleave(targets.n_edges) - predictions.n_edges[0]
gt_retrieved_rank = matches.int().argmax(dim=0) - offset

gt_relation_to_graph_assignment = targets.batch[targets.relation_indexes[0]]

print('g  retrieved  rank')
print('------------------')
for g, ret, rank in zip(gt_relation_to_graph_assignment.numpy(), gt_retrieved.numpy(), gt_retrieved_rank.numpy()):
    print(f'{g}  {"👍" if ret else "❌":^10} {rank if ret else "":>2}')

g  retrieved  rank
------------------
0      ❌        
0      👍       3
0      ❌        
0      ❌        
1      ❌        
1      ❌        
1      ❌        
1      ❌        
1      ❌        
1      👍       4
1      ❌        
1      ❌        


In [6]:
recall_per_graph = scatter_mean(gt_retrieved.float(), gt_relation_to_graph_assignment, dim=0, dim_size=targets.num_graphs)

print('g  recall')
print('---------')
for g, rec in enumerate(recall_per_graph.numpy()):
    print(f'{g}  {rec:>6.1%}')

g  recall
---------
0   25.0%
1   12.5%


## Phrase detection

The first requirement for a match is that subject, object and predicate classes match:
- `subject_class`
- `predicate_class`
- `object_class`

Also, the union of subject and object boxes need to match with IoU > .5.
- `graph_index`
- `iou_union > .5`

In [7]:
def noisy_boxes(boxes, H, W, scale=15):
    """Adds gaussian noise to the pixel coordinates"""
    res = boxes + scale * torch.randn_like(boxes)
    res[:, [0, 2]] = res[:, [0, 2]].clamp(min=0, max=W)
    res[:, [1, 3]] = res[:, [1, 3]].clamp(min=0, max=H)
    return res

def noisy_classes(classes, C, p=.2):
    """With probability p, assigns a uniformly random class, otherwise keeps the original"""
    return torch.where(
        torch.rand_like(classes, dtype=torch.float) < p, 
        torch.randint_like(classes, C), 
        classes
    )

def noisy_prediction(target: Data, topk: int):
    # For every object box, simulate 2 noisy detections
    object_boxes = noisy_boxes(target.object_boxes.repeat(2, 1), H, W)
    object_classes = noisy_classes(target.object_classes.repeat(2), C_objects)
    
    
    num_relations = C_predicates * len(object_boxes) * (len(object_boxes) - 1)
    predicate_classes=torch.randint(C_predicates, size=(num_relations,))
    relation_indexes=torch.randint(len(object_boxes), size=(2, num_relations))
    relation_scores=torch.rand(size=(len(predicate_classes),))
    relation_scores, topk_index = torch.topk(relation_scores, k=topk, dim=0)
    
    return Data(
        num_nodes=len(object_boxes),
        n_nodes=len(object_boxes),
        n_edges=len(relation_scores),

        object_classes=object_classes,
        object_boxes=object_boxes,

        predicate_classes=predicate_classes[topk_index],
        relation_indexes=relation_indexes[:, topk_index],
        relation_scores=relation_scores,
    )

p1 = noisy_prediction(t1, topk=20)
p2 = noisy_prediction(t2, topk=20)
predictions = Batch.from_data_list([p1, p2])

Matching subject, object and predicate classes same as before, but now we include `graph_idx` to distinghuish graphs.

In [8]:
# [E_p, 4]
pred_matrix = torch.stack([
    # graph_idx
    predictions.batch[predictions.relation_indexes[0]],
    # subject_class, predicate_class, object_class
    predictions.object_classes[predictions.relation_indexes[0]],
    predictions.predicate_classes,
    predictions.object_classes[predictions.relation_indexes[1]],
], dim=1)

# [E_t, 4]
gt_matrix = torch.stack([
    # graph_idx
    targets.batch[targets.relation_indexes[0]],
    # subject_class, predicate_class, object_class
    targets.object_classes[targets.relation_indexes[0]],
    targets.predicate_classes,
    targets.object_classes[targets.relation_indexes[1]],
], dim=1)

# Block matrix [E_p, E_t]
matches_class = (gt_matrix[None, :, :] == pred_matrix[:, None, :]).all(dim=2)

Compute the union box between corresponding boxes:
```
0------------------------------------>     0------------------------------------>
|                                    x     |                                    x
|   x1, y1 ------------+                   |   u1, u1 ---------------------+
|     |                |                   |     |                         |
|     |     x1, y1 ----|--------+      =>  |     |                         |
|     | a     |        |        |          |     |                         |
|     +------------- x2, y2     |          |     |                         |
|             |             b   |          |     |                         |
|             +-------------- x2, y2       |     +---------------------- u2, u2
V y                                        V y
```

In [9]:
def matched_boxlist_union(a, b):
    # Boxes are represented as [N, 4] tensors,
    # where the 4 corrdinates are in order (x1, y1, x2, y2)
    assert a.shape[0] == b.shape[0]
    assert a.shape[1] == b.shape[1] == 4
    
    union_top_left = torch.min(a[:, :2], b[:, :2])  # N x (x1, y1)
    union_bottom_right = torch.max(a[:, 2:], b[:, 2:])  # N x (x2, y2)
    boxes_union = torch.cat((union_top_left, union_bottom_right), dim=1)

    return boxes_union

# [E_p, 4]
pred_union_boxes = matched_boxlist_union(
    predictions.object_boxes[predictions.relation_indexes[0]],  # subj
    predictions.object_boxes[predictions.relation_indexes[1]],  # obj
)

# [E_t, 4]
gt_union_boxes = matched_boxlist_union(
    targets.object_boxes[targets.relation_indexes[0]],  # subj
    targets.object_boxes[targets.relation_indexes[1]],  # obj
)

# Full matrix [E_p, E_t]
iou_union = box_iou(pred_union_boxes, gt_union_boxes)
matches_iou_union = iou_union > .5

# Block matrix [E_p, E_t]
matches = matches_class & matches_iou_union

In [10]:
print('gt', 'pred', 'spo_class', '  t/p union box (iou)', sep='\t')
print('='*57)
for t in range(targets.n_edges.sum()):
    print(
        t,
        '',
        (
            targets.object_classes[targets.relation_indexes[0, t]].item(),
            targets.predicate_classes[t].item(),
            targets.object_classes[targets.relation_indexes[1, t]].item()
        ),
        f'{gt_union_boxes[t].int().numpy()}',
        sep='\t'
    )
    
    for p in matches[: ,t].nonzero().flatten().tolist():
        print(
            '>',
            f'{p:3d}',
            (
                predictions.object_classes[predictions.relation_indexes[0, p]].item(),
                predictions.predicate_classes[p].item(),
                predictions.object_classes[predictions.relation_indexes[1, p]].item()
            ),
            f'{pred_union_boxes[t].int().numpy()} '
            f'({iou_union[p, t]:.1%})',
            sep='\t'
        )
    print('-'*57)

gt	pred	spo_class	  t/p union box (iou)
0		(3, 1, 0)	[372 118 584 472]
---------------------------------------------------------
1		(3, 2, 0)	[372 118 584 472]
>	  5	(3, 2, 0)	[387 109 591 480] (86.3%)
---------------------------------------------------------
2		(0, 1, 3)	[372 118 584 472]
---------------------------------------------------------
3		(3, 0, 0)	[372 118 584 472]
>	  4	(3, 0, 0)	[344 121 539 430] (81.2%)
>	 18	(3, 0, 0)	[344 121 539 430] (83.3%)
---------------------------------------------------------
4		(3, 1, 3)	[209 372 553 416]
---------------------------------------------------------
5		(1, 0, 3)	[209 268 619 469]
---------------------------------------------------------
6		(1, 1, 3)	[358 268 619 469]
>	 32	(1, 1, 3)	[469 421 591 480] (57.2%)
---------------------------------------------------------
7		(1, 1, 0)	[393 167 619 469]
---------------------------------------------------------
8		(1, 1, 0)	[393 167 619 469]
-------------------------------------------------

In [11]:
# matches.argmax(dim=0) will return the last index if no True value is found.
# We can use matches.any(dim=0) to ignore those cases.
# Also, we must account for the row offset in the matches matrix.
gt_retrieved = matches.any(dim=0)

offset = predictions.n_edges.cumsum(dim=0).repeat_interleave(targets.n_edges) - predictions.n_edges[0]
gt_retrieved_rank = matches.int().argmax(dim=0) - offset

gt_relation_to_graph_assignment = targets.batch[targets.relation_indexes[0]]

print('g  retrieved  rank')
print('------------------')
for g, ret, rank in zip(gt_relation_to_graph_assignment.numpy(), gt_retrieved.numpy(), gt_retrieved_rank.numpy()):
    print(f'{g}  {"👍" if ret else "❌":^10} {rank if ret else "":>2}')

g  retrieved  rank
------------------
0      ❌        
0      👍       5
0      ❌        
0      👍      18
1      ❌        
1      ❌        
1      👍      12
1      ❌        
1      ❌        
1      ❌        
1      ❌        
1      👍      17


In [12]:
recall_per_graph = scatter_mean(gt_retrieved.float(), gt_relation_to_graph_assignment, dim=0, dim_size=targets.num_graphs)

print('g  recall')
print('---------')
for g, rec in enumerate(recall_per_graph.numpy()):
    print(f'{g}  {rec:>6.1%}')

g  recall
---------
0   50.0%
1   25.0%


## Relationship detection

The first requirement for a match is that subject, object and predicate classes match:
- `subject_class`
- `predicate_class`
- `object_class`

Also, the predicted subject/object boxes need to match with the ground-truth subject/object boxes with IoU > .5.
- `iou_subject > .5`
- `iou_object > .5`

In [13]:
p1 = noisy_prediction(t1, topk=50)
p2 = noisy_prediction(t2, topk=50)
predictions = Batch.from_data_list([p1, p2])

In [14]:
# [E_p, 4]
pred_matrix = torch.stack([
    # graph_idx
    predictions.batch[predictions.relation_indexes[0]],
    # subject_class, predicate_class, object_class
    predictions.object_classes[predictions.relation_indexes[0]],
    predictions.predicate_classes,
    predictions.object_classes[predictions.relation_indexes[1]],
], dim=1)

# [E_t, 4]
gt_matrix = torch.stack([
    # graph_idx
    targets.batch[targets.relation_indexes[0]],
    # subject_class, predicate_class, object_class
    targets.object_classes[targets.relation_indexes[0]],
    targets.predicate_classes,
    targets.object_classes[targets.relation_indexes[1]],
], dim=1)

# Block matrix [E_p, E_t]
matches_class = (gt_matrix[None, :, :] == pred_matrix[:, None, :]).all(dim=2)

# Two full matrices [E_p, E_t]
iou_subject = box_iou(
    predictions.object_boxes[predictions.relation_indexes[0]], 
    targets.object_boxes[targets.relation_indexes[0]]
)
iou_object = box_iou(
    predictions.object_boxes[predictions.relation_indexes[1]], 
    targets.object_boxes[targets.relation_indexes[1]]
)

# Block matrix [E_p, E_t]
matches = matches_class & (iou_subject > .5) & (iou_object > .5)

In [15]:
print('gt', 'pred', 'spo_class', 't/p subject box (iou)    ', 't/p object box (iou)', sep='\t')
print('='*95)
for t in range(targets.n_edges.sum()):
    print(
        t,
        '',
        (
            targets.object_classes[targets.relation_indexes[0, t]].item(),
            targets.predicate_classes[t].item(),
            targets.object_classes[targets.relation_indexes[1, t]].item()
        ),
        f'{targets.object_boxes[targets.relation_indexes[0, t]].int().numpy()}        ',
        f'{targets.object_boxes[targets.relation_indexes[1, t]].int().numpy()}        ',
        sep='\t'
    )
    
    for p in matches[: ,t].nonzero().flatten().tolist():
        print(
            '>',
            f'{p:3d}',
            (
                predictions.object_classes[predictions.relation_indexes[0, p]].item(),
                predictions.predicate_classes[p].item(),
                predictions.object_classes[predictions.relation_indexes[1, p]].item()
            ),
            f'{predictions.object_boxes[predictions.relation_indexes[0, p]].int().numpy()} '
            f'({iou_subject[p, t]:.1%})',
            f'{predictions.object_boxes[predictions.relation_indexes[1, p]].int().numpy()}',
            f'({iou_object[p, t]:.1%})',
            sep='\t'
        )
    print('-'*95)

gt	pred	spo_class	t/p subject box (iou)    	t/p object box (iou)
0		(3, 1, 0)	[372 118 539 435]        	[492 432 584 472]        
-----------------------------------------------------------------------------------------------
1		(3, 2, 0)	[372 118 539 435]        	[492 432 584 472]        
>	 38	(3, 2, 0)	[365 138 503 434] (70.2%)	[508 433 603 480]	(58.3%)
-----------------------------------------------------------------------------------------------
2		(0, 1, 3)	[492 432 584 472]        	[372 118 539 435]        
>	  7	(0, 1, 3)	[508 433 603 480] (58.3%)	[365 138 503 434]	(70.2%)
>	 36	(0, 1, 3)	[508 433 603 480] (58.3%)	[365 138 503 434]	(70.2%)
-----------------------------------------------------------------------------------------------
3		(3, 0, 0)	[372 118 539 435]        	[492 432 584 472]        
>	  5	(3, 0, 0)	[365 138 503 434] (70.2%)	[508 433 603 480]	(58.3%)
-----------------------------------------------------------------------------------------------
4		(3, 1, 3)	[209 3

In [16]:
# matches.argmax(dim=0) will return the last index if no True value is found.
# We can use matches.any(dim=0) to ignore those cases.
# Also, we must account for the row offset in the matches matrix.

# [E_t]
gt_retrieved = matches.any(dim=0)

offset = predictions.n_edges.cumsum(dim=0).repeat_interleave(targets.n_edges) - predictions.n_edges[0]
gt_retrieved_rank = matches.int().argmax(dim=0) - offset

gt_relation_to_graph_assignment = targets.batch[targets.relation_indexes[0]]

print('g  retrieved  rank')
print('------------------')
for g, ret, rank in zip(gt_relation_to_graph_assignment.numpy(), gt_retrieved.numpy(), gt_retrieved_rank.numpy()):
    print(f'{g}  {"👍" if ret else "❌":^10} {rank if ret else "":>2}')

g  retrieved  rank
------------------
0      ❌        
0      👍      38
0      👍      36
0      👍       5
1      👍      12
1      👍      28
1      👍      46
1      ❌        
1      ❌        
1      👍      49
1      ❌        
1      ❌        


In [17]:
recall_per_graph = scatter_mean(gt_retrieved.float(), gt_relation_to_graph_assignment, dim=0, dim_size=targets.num_graphs)

print('g  recall')
print('---------')
for g, rec in enumerate(recall_per_graph.numpy()):
    print(f'{g}  {rec:>6.1%}')

g  recall
---------
0   75.0%
1   50.0%


Let's also compute recall at various choices of `k` (using broadcasting).

In [18]:
K = torch.tensor([5, 10, 15, 25, 30, 50])

# [K, E_t]
gt_retrieved_at = (gt_retrieved_rank[None, :] < K[:, None]) & gt_retrieved[None, :]

# [K, num_graphs]
recall_at_per_graph = scatter_mean(gt_retrieved_at.float(), gt_relation_to_graph_assignment, dim=1, dim_size=targets.num_graphs)

print('g', *(f'  R@{k:<2}' for k in K))
print('-' * 43)
for g, rec in enumerate(recall_at_per_graph.unbind(dim=1)):
    print(g, *(f'{r:>6.1%}' for r in rec.numpy()))

g   R@5    R@10   R@15   R@25   R@30   R@50
-------------------------------------------
0   0.0%  25.0%  25.0%  25.0%  25.0%  75.0%
1   0.0%   0.0%  12.5%  12.5%  25.0%  50.0%
