In [66]:
"""Goal: create data structures / loops for xent calculation"""
import torch
# Reload the debugging dump written to 'tmp_matching_loss.pt'.
# NOTE(review): torch.load unpickles arbitrary Python objects -- only point it
# at files you produced yourself.
d = torch.load('tmp_matching_loss.pt')

# Unpack the three entries this notebook works with.
pred_mask_logits, instances, n_masks_per_roi = (
    d[key] for key in ('pred_mask_logits', 'instances', 'n_masks_per_roi')
)

In [90]:
# The rest of the notebook only handles a single image.
n_images = len(instances)
assert n_images == 1
instances_per_image = instances[0]

# Dim 0 of the logits counts the predicted masks, dim 2 is the mask side length.
total_num_masks, mask_side_len = pred_mask_logits.size(0), pred_mask_logits.size(2)

# Apply the same crop-and-resize (against the proposal boxes) to the primary and
# the second-best ground-truth masks, then move both to the logits' device.
# NOTE(review): presumably this yields one mask_side_len x mask_side_len mask
# per proposal -- the recorded output of this cell shows gt_2 as (34, 28, 28).
gt_1, gt_2 = (
    masks.crop_and_resize(
        instances_per_image.proposal_boxes.tensor, mask_side_len
    ).to(device=pred_mask_logits.device)
    for masks in (instances_per_image.gt_masks, instances_per_image.gt_second_best_masks)
)

gt_2.shape

torch.Size([34, 28, 28])

In [103]:
# One stacked tensor per ROI: index 0 is the entry from gt_1, index 1 the
# matching entry from gt_2.
# NOTE(review): an assignment produces no cell output, so the recorded output
# under this cell looks stale -- confirm with a fresh run.
gt_pairs = [torch.stack(pair) for pair in zip(gt_1, gt_2)]

torch.Size([28, 28])

In [144]:
gt_classes = instances_per_image.gt_classes.to(dtype=torch.int64)
indices = torch.arange(total_num_masks)  # NOTE(review): not used by any visible cell

# For every ROI, take each of the n_masks_per_roi interleaved channel groups
# (group k = channels k, k + n_masks_per_roi, k + 2 * n_masks_per_roi, ...)
# and keep the channel selected by that ROI's ground-truth class.
pred_mask_pairs = [
    [roi_logits[offset::n_masks_per_roi][cls] for offset in range(n_masks_per_roi)]
    for roi_logits, cls in zip(pred_mask_logits, gt_classes)
]

# Only the last expression of a cell is displayed; the two len() calls are
# evaluated but show nothing.
len(pred_mask_pairs)
len(pred_mask_pairs[0])
pred_mask_pairs[0][0].shape, pred_mask_pairs[0][1].shape

(torch.Size([28, 28]), torch.Size([28, 28]))

In [147]:
from torch.nn import functional as F

# Pairwise binary cross-entropy between every predicted mask and every
# ground-truth mask of a ROI. Rows index predictions (i), columns index
# ground truths (j), which is the orientation the matching step below consumes.
#
# Fixes relative to the earlier version of this cell:
#   * a stray `instances_per_image[idx:(idx+1)]` expression ran before `idx`
#     was bound by the loop and raised NameError on a fresh kernel;
#   * the matrix was allocated as (n_gt, n_pred) but filled as [pred, gt],
#     which only worked because both counts happened to be equal.
for idx, (gt_pair, pred_pair) in enumerate(zip(gt_pairs, pred_mask_pairs)):
    xent_losses = torch.zeros((len(pred_pair), len(gt_pair)))
    for i, pred in enumerate(pred_pair):
        for j, gt in enumerate(gt_pair):
            # The BCE target must be floating point, hence the cast of the mask.
            xent_losses[i, j] = maskwise_mask_loss = F.binary_cross_entropy_with_logits(
                pred, gt.to(dtype=torch.float32), reduction='mean')

# xent_losses is rebound on every iteration, so this shows the last ROI only.
print(xent_losses)

tensor([[0.0671, 4.0934],
        [0.6836, 0.7042]], grad_fn=<CopySlices>)


In [152]:
import torch
from scipy import optimize
import numpy as np
from torch.nn import functional as F


def solve_matching_problem(cost_tensor: torch.Tensor):
    """
    Solve the linear assignment problem for a 2-D cost matrix.

    Args:
        cost_tensor: cost matrix where entry [r, c] is the cost of assigning
            row r to column c. May be a torch tensor (on any device, with or
            without requires_grad) or a numpy array.

    Returns:
        A list of column indices sorted by row index: element k is the column
        assigned to the k-th matched row.

    Raises:
        AssertionError: if cost_tensor is neither a numpy array nor a tensor.
    """
    assert isinstance(cost_tensor, np.ndarray) or torch.is_tensor(cost_tensor), (
        f"expected a numpy array or torch tensor, got {type(cost_tensor).__name__}")
    if torch.is_tensor(cost_tensor):
        # SciPy needs a plain host array: drop the autograd graph and copy
        # off the accelerator if necessary. The assignment itself is not
        # differentiable; callers index the original tensor with the result.
        cost_matrix = cost_tensor.detach().cpu().numpy()
    else:
        # numpy arrays have no `requires_grad`; use them as they are.
        cost_matrix = cost_tensor
    row_ind, col_ind = optimize.linear_sum_assignment(cost_matrix)
    # SciPy documents row_ind as already sorted; the argsort keeps the
    # "sorted by row index" contract explicit at negligible cost.
    return list(col_ind[np.argsort(row_ind)])

match_cols = solve_matching_problem(xent_losses)

# Compare the total loss of the optimal assignment with the swapped one.
# The hard-coded rows [0, 1] and the `1 - col` flip assume a 2 x 2 cost matrix.
row_ids = [0, 1]
swapped_cols = [1 - col for col in match_cols]
for cols in (match_cols, swapped_cols):
    print(sum(xent_losses[row_ids, cols]))
print(xent_losses)

tensor(0.7714, grad_fn=<AddBackward0>)
tensor(4.7770, grad_fn=<AddBackward0>)
tensor([[0.0671, 4.0934],
        [0.6836, 0.7042]], grad_fn=<CopySlices>)


In [50]:
# """Goal: get pairwise xent"""
# import torch
# d=torch.load('tmp_pairwise_xent.pt')
# logit_sets, gt_sets, instances = d['logit_sets'], d['gt_sets'], d['instances'] # length of each is # images
# gt_sets = [
#     [[i.gt_masks[idx:(idx+1)], i.gt_second_best_masks[idx:(idx+1)]] for idx in range(len(i))]
#     for i in instances]
# gt_matching_sets = [
#     [[i.gt_masks[idx], i.gt_second_best_masks[idx]] for idx in range(len(i))]
#     for i in instances]


# logit_matching_sets = list(zip(*logit_sets))

# for z in (logit_matching_sets, gt_matching_sets): # (z1, z2):
#     el = z
#     descr = ''
#     tab = 1
#     while not torch.is_tensor(el) and tab < 6:
#         if isinstance(el, tuple) or isinstance(el, list):
#             descr += "\t".join('' for _ in range(tab)) + f"{type(el).__name__} of length {len(el)}.  Each element is:\n"
#         elif hasattr(el, '__len__'):
#             descr += "\t".join('' for _ in range(tab)) + f"{type(el).__name__} of length {len(el)}.  Each element is:\n"
#         else:
#             break
#         old_el = el
#         el = el[0]
#         tab += 1
#     if torch.is_tensor(el):
#         descr += "\t".join('' for _ in range(tab)) + f"{type(el).__name__} of shape {el.shape}."
#     else:
#         descr += "\t".join('' for _ in range(tab)) + f"{type(el).__name__}"

#     print(descr)

# #     I have (logit_sets):
# #         List[Tensor(34,80,28,28), Tensor(34,80,28,28)]
# #     I need (zip(*logit_sets)):
# #           List[Tuple(Tensor(80,28,28), Tensor(80,28,28)), ...x34]
# #     I have (gt_sets):
# #         List[[PolygonMasks(34)], [PolygonMasks(34)]]
# #     I need:
# #         List[ [Tuple(PolygonMasks(1), PolygonMasks(1)), ...x34] x#_images]
# instances_per_image = instances[0]
# gt_matching_sets = [
#     [[i.gt_masks[idx], i.gt_second_best_masks[idx]] for idx in range(len(i))]
#     for i in instances]
# gt_matching_sets_per_image = gt_matching_sets[0]
# for gt_tuple in gt_matching_sets_per_image:
#     print(len(gt_tuple), type(gt_tuple[0]), len(gt_tuple[0]))

# pred_mask_logit_matching_sets_per_image = list(zip(*logit_sets))    
# for s in pred_mask_logit_matching_sets_per_image:
#     print(len(s), type(s[0]), len(s[0]))


# pred_mask_logits = logit_sets[0]
# gt_masks_raw = gt_sets[0]

# cls_agnostic_mask = pred_mask_logits.size(1) == 1
# total_num_masks = pred_mask_logits.size(0)
# mask_side_len = pred_mask_logits.size(2)
# assert pred_mask_logits.size(2) == pred_mask_logits.size(3), "Mask prediction must be square!"

# gt_classes = []
# gt_masks = []
# for instances_per_image, gt_masks_per_image in zip(instances, gt_masks_raw):
#     if len(instances_per_image) == 0:
#         continue
#     if not cls_agnostic_mask:
#         gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64)
#         gt_classes.append(gt_classes_per_image)

#     gt_masks_per_image = gt_masks_per_image.crop_and_resize(
#         instances_per_image.proposal_boxes.tensor, mask_side_len).to(device=pred_mask_logits.device)
#     # A tensor of shape (N, M, M), N=#instances in the image; M=mask_side_len
#     gt_masks.append(gt_masks_per_image)
