# Other Useful Functions

## Imports

In [None]:
# mass includes
import math
import torch as t
import numpy as np
from numpy.random import randint, uniform
from torch.nn.functional import softplus

## Random operations

In [None]:
# random cropping
def randCrop(in_img, crop_size):
    """Return a random square crop of an image array.

    The last two axes of ``in_img`` are treated as (height, width); any
    leading axes (e.g. channels) are kept whole. A side shorter than
    ``crop_size`` is kept in full, so the output's spatial shape is always
    ``(min(crop_size, hei), min(crop_size, wid))``.

    Args:
        in_img: array of shape ``(..., hei, wid)`` (presumably a numpy
            array, since ``.copy()`` is used).
        crop_size: side length of the square crop, in pixels.

    Returns:
        A new array (never a view of ``in_img``) holding the crop.
    """
    hei, wid = in_img.shape[-2:]

    # clamp each axis on its own so a short side cannot shrink the sampling
    # range of the long side and push the window past its end
    crop_h = min(crop_size, hei)
    crop_w = min(crop_size, wid)

    # randint(k) samples from [0, k); "+ 1" makes the last valid offset
    # (side - crop) reachable
    crop_y = randint(hei - crop_h + 1)
    crop_x = randint(wid - crop_w + 1)

    return in_img[..., crop_y:crop_y + crop_h, crop_x:crop_x + crop_w].copy()


# random flipping
def randFlip(in_img):
    """Randomly mirror an image array horizontally and/or vertically.

    Two independent draws of ``uniform() > 0.5`` decide the flips. The first
    decides the horizontal flip (last axis), the second the vertical flip
    (second-to-last axis).

    Args:
        in_img: array of shape ``(..., hei, wid)``.

    Returns:
        A new C-contiguous array that never shares memory with ``in_img``.
    """
    out_img = in_img

    # horizontal
    if uniform() > 0.5:
        out_img = np.flip(out_img, axis=-1)

    # vertical
    if uniform() > 0.5:
        out_img = np.flip(out_img, axis=-2)

    # np.flip returns a negative-stride view. Copying *after* flipping
    # (ndarray.copy defaults to C order) gives a contiguous result, which
    # torch.from_numpy requires, and still never aliases in_img.
    return out_img.copy()

## Custom functions

In [None]:
# f-divergence (Jensen-Shannon) distance between positive and negative joint distributions
def calcDIMLoss(loc_feats, glb_feats, loc_masks):
    """Jensen-Shannon (JSD) mutual-information loss between local and global features.

    Scores are dot products ``u = <local, global>``.

    * Positive pairs: every local vector of item ``i`` (any view) against
      every global vector of the same item ``i`` (any view).
    * Negative pairs: the local vectors of item ``i`` in view ``v`` against
      the global vector of every other item ``j != i`` in that same view.

    The returned value is::

        mean_neg[softplus(u) - log 2] - mean_pos[log 2 - softplus(-u)]

    Each mean is weighted by ``loc_masks``, and 1e-6 is added to each weight
    total to avoid division by zero. Minimising the loss maximises the JSD
    estimate.

    Args:
        loc_feats: tensor of shape ``(b, n, n_feats, n_locs)``: ``b`` items,
            ``n`` views, ``n_locs`` local positions.
        glb_feats: exactly one global vector per (item, view); presumably
            shaped ``(b, n, n_feats, 1)`` (TODO confirm against the caller).
        loc_masks: one weight per (item, view, location); presumably shaped
            ``(b, n, 1, n_locs)`` (TODO confirm against the caller).

    Returns:
        Scalar tensor.
    """
    # move the feature axis last: (b, n, positions, n_feats)
    loc_feats = loc_feats.permute(0, 1, 3, 2)
    glb_feats = glb_feats.permute(0, 1, 3, 2)
    loc_masks = loc_masks.permute(0, 1, 3, 2)

    err_pos_sum, err_pos_cnt = _dimPositiveTerm(loc_feats, glb_feats, loc_masks)
    err_neg_sum, err_neg_cnt = _dimNegativeTerm(loc_feats, glb_feats, loc_masks)

    # 1e-6 keeps the ratios finite when all masks are zero (or b == 1 for negatives)
    return err_neg_sum / (err_neg_cnt + 1e-6) - err_pos_sum / (err_pos_cnt + 1e-6)


def _dimPositiveTerm(loc_feats, glb_feats, loc_masks):
    """Masked sum and pair count of ``log 2 - softplus(-u)`` over same-item pairs.

    Expects the permuted layout ``(b, n, positions, n_feats)``.
    """
    b, n, n_locs, n_feats = loc_feats.size()

    # reshape, not view: the permuted input is generally non-contiguous and
    # .view() raises on it (e.g. a contiguous input with n > 1, n_feats > 1)
    locs = loc_feats.reshape(b, n * n_locs, n_feats)
    glbs = glb_feats.reshape(b, -1, n_feats)
    masks = loc_masks.reshape(b, n * n_locs, 1)

    # batched outer-product scoring: (b, n * n_locs, n_globals) per item
    u = t.bmm(locs, glbs.transpose(1, 2))

    err_sum = t.sum((math.log(2.0) - softplus(-u)) * masks)
    # each masked location is paired with the n global vectors of its item
    err_cnt = t.sum(masks) * n
    return err_sum, err_cnt


def _dimNegativeTerm(loc_feats, glb_feats, loc_masks):
    """Masked sum and pair count of ``softplus(u) - log 2`` over cross-item pairs.

    Pairs item ``i``'s local vectors with item ``j``'s global vector
    (``j != i``) within the same view. Expects the permuted layout
    ``(b, n, positions, n_feats)``.
    """
    b, n, n_locs, n_feats = loc_feats.size()

    # group by view; the glb reshape requires exactly one global per (item, view)
    locs = loc_feats.transpose(0, 1).reshape(n, b * n_locs, n_feats)
    glbs = glb_feats.transpose(0, 1).reshape(n, b, n_feats)
    masks = loc_masks.transpose(0, 1).reshape(n, b * n_locs, 1)

    # (n, b * n_locs, b): every location of every item vs every item's global
    u = t.bmm(locs, glbs.transpose(1, 2))

    # softplus(u) == softplus(-u) + u, without the cancellation for u << 0
    err = (softplus(u) - math.log(2.0)) * masks

    # sum over locations -> (n, b, b) indexed [view, local item, global item],
    # then zero the diagonal (an item paired with itself is not a negative)
    err = err.reshape(n, b, n_locs, b).sum(2)
    off_diag = 1 - t.eye(b, dtype=err.dtype, device=err.device)
    err_sum = t.sum(err * off_diag)

    # each masked location is paired with the globals of the b - 1 other items
    err_cnt = t.sum(masks) * (b - 1)
    return err_sum, err_cnt