# Other Useful Functions

## Includes

In [None]:
# mass includes
import math
import torch as t
import numpy as np
from numpy.random import randint, uniform
from torch.nn.functional import softplus

## Random operations

In [None]:
# random cropping
def randCrop(in_img, crop_size):
    """Cut a randomly placed ``crop_size`` x ``crop_size`` window out of an image.

    Args:
        in_img: An array indexed as (first axis, row, column), or a list of
            such arrays. For a list, the window position is derived from the
            shape of the first element and the same window is applied to
            every element, so paired images stay aligned.
        crop_size: Side length of the square window, in pixels.

    Returns:
        A copy of the cropped region. A list input yields a list of crops in
        the same order. If ``crop_size`` exceeds a spatial dimension, slicing
        simply stops at the image border, so the result is smaller than
        ``crop_size`` along that dimension.
    """
    is_list = isinstance(in_img, list)
    _, hei, wid = in_img[0].shape if is_list else in_img.shape

    # Clamp the window size used for sampling so that randint always receives
    # a positive upper bound, even when crop_size >= the image side.
    min_dim = min(crop_size, hei - 1, wid - 1)
    crop_y = randint(hei - min_dim)
    crop_x = randint(wid - min_dim)

    def _crop(img):
        # Slice the two trailing axes and detach the result from the source.
        return img[:, crop_y:crop_y + crop_size,
                   crop_x:crop_x + crop_size].copy()

    # The original indexed the list itself with a tuple slice, which raises
    # TypeError; crop each element instead.
    if is_list:
        return [_crop(img) for img in in_img]
    return _crop(in_img)


# random horizontal flipping
def randHorFlip(in_img):
    """Mirror the input with ``np.fliplr`` with probability one half.

    A single uniform sample in [0, 1) is drawn; when it exceeds 0.5 a flipped
    copy is returned, otherwise the input object itself is handed back
    untouched (not copied).

    NOTE(review): ``np.fliplr`` reverses axis 1. That is the width axis for
    (H, W[, C]) arrays but the height axis for channel-first (C, H, W)
    arrays -- confirm the layout callers pass in.
    """
    if uniform() > 0.5:
        return np.fliplr(in_img).copy()
    return in_img


# random vertical flipping
def randVerFlip(in_img):
    """Mirror the input with ``np.flipud`` with probability one half.

    A single uniform sample in [0, 1) is drawn; when it exceeds 0.5 a flipped
    copy is returned, otherwise the input object itself is handed back
    untouched (not copied).

    NOTE(review): ``np.flipud`` reverses axis 0. That is the height axis for
    (H, W[, C]) arrays but the channel axis for channel-first (C, H, W)
    arrays -- confirm the layout callers pass in.
    """
    if uniform() > 0.5:
        return np.flipud(in_img).copy()
    return in_img

## Custom functions

In [None]:
# code modified from https://github.com/rdevon/DIM
def log_sum_exp(x, axis=None):
    """Numerically stable ``log(sum(exp(x)))``.

    Args:
        x: Input tensor.
        axis: Dimension to reduce. ``None`` (the default) reduces over every
            element and returns a scalar tensor.

    Returns:
        Tensor with ``axis`` removed, or a 0-dim tensor when ``axis`` is None.
    """
    # The hand-rolled version subtracted the reduced maximum without keeping
    # the reduced dimension, so it only broadcast correctly for axis=0, and it
    # crashed for axis=None. torch.logsumexp handles every axis.
    if axis is None:
        return t.logsumexp(x.reshape(-1), dim=0)
    return t.logsumexp(x, dim=axis)


def get_positive_expectation(p_samples, measure, average=True):
    """Transform positive-sample scores according to the chosen measure.

    Args:
        p_samples: Tensor of scores for positive samples.
        measure: One of 'GAN', 'JSD', 'X2', 'KL', 'RKL', 'DV', 'H2', 'W1'.
        average: If True return the mean over all elements, otherwise return
            the element-wise tensor.

    Returns:
        A scalar tensor when ``average`` is True, else a tensor shaped like
        ``p_samples``.

    Raises:
        ValueError: If ``measure`` is not one of the supported names.
    """
    if measure == 'GAN':
        Ep = -softplus(-p_samples)
    elif measure == 'JSD':
        # Note JSD will be shifted (by log 2).
        Ep = math.log(2.0) - softplus(-p_samples)
    elif measure == 'X2':
        Ep = p_samples**2
    elif measure in ('KL', 'DV', 'W1'):
        # These measures use the raw score for the positive term.
        Ep = p_samples
    elif measure == 'RKL':
        Ep = -t.exp(-p_samples)
    elif measure == 'H2':
        Ep = 1. - t.exp(-p_samples)
    else:
        # Previously an unknown name fell through and surfaced as an
        # UnboundLocalError on Ep; fail with a clear message instead.
        raise ValueError('Invalid measure: %s' % measure)

    return Ep.mean() if average else Ep


def get_negative_expectation(q_samples, measure, average=True):
    """Transform negative-sample scores according to the chosen measure.

    Args:
        q_samples: Tensor of scores for negative samples.
        measure: One of 'GAN', 'JSD', 'X2', 'KL', 'RKL', 'DV', 'H2', 'W1'.
        average: If True return the mean over all elements, otherwise return
            the element-wise tensor.

    Returns:
        A scalar tensor when ``average`` is True, else the transformed
        tensor. For 'DV' the first dimension is reduced by a log-sum-exp, so
        the un-averaged result has one dimension fewer than ``q_samples``.

    Raises:
        ValueError: If ``measure`` is not one of the supported names.
    """
    if measure == 'GAN':
        Eq = softplus(-q_samples) + q_samples
    elif measure == 'JSD':
        # Note JSD will be shifted (by log 2).
        Eq = softplus(-q_samples) + q_samples - math.log(2.0)
    elif measure == 'X2':
        # sqrt(q**2) is |q|; kept in this form to leave gradients unchanged.
        Eq = -0.5 * ((t.sqrt(q_samples**2) + 1.0)**2)
    elif measure == 'KL':
        Eq = t.exp(q_samples - 1.)
    elif measure == 'RKL':
        Eq = q_samples - 1.
    elif measure == 'DV':
        # Log of the mean of exp(q) over the first dimension.
        Eq = log_sum_exp(q_samples, 0) - math.log(q_samples.size(0))
    elif measure == 'H2':
        Eq = t.exp(q_samples) - 1.0
    elif measure == 'W1':
        Eq = q_samples
    else:
        # Previously an unknown name fell through and surfaced as an
        # UnboundLocalError on Eq; fail with a clear message instead.
        raise ValueError('Invalid measure: %s' % measure)

    return Eq.mean() if average else Eq


# f-divergence distance between positive and negative joint distributions
def fenchel_dual_loss(l, m, measure=None):
    """Contrast matching and non-matching (global, local) feature pairs.

    Every global vector in ``m`` is scored against every local vector in
    ``l`` by dot product. Pairs from the same batch index are treated as
    positives, all other pairs as negatives, and the result is
    ``E_neg - E_pos`` under the chosen measure.

    Args:
        l: Local features of size (N, units, n_locals).
        m: Global features that can be viewed as (N, units, 1).
        measure: One of 'GAN', 'JSD', 'X2', 'KL', 'RKL', 'DV', 'H2', 'W1'.

    Returns:
        Scalar tensor holding the loss. With N == 1 there are no negative
        pairs, so the negative term is 0 / 0 and the loss is NaN.

    Raises:
        ValueError: If ``measure`` is not a supported name (including the
            default ``None``).
    """
    valid_measures = ('GAN', 'JSD', 'X2', 'KL', 'RKL', 'DV', 'H2', 'W1')
    # Explicit raise: the old `assert ..., print(...)` produced an empty
    # assertion message and disappears entirely under `python -O`.
    if measure not in valid_measures:
        raise ValueError('Invalid measure: %s' % measure)

    N, units, n_locals = l.size()

    # First we make the input tensors the right shape.
    # l: (N, units, n_locals) -> (N * n_locals, units)
    l = l.permute(0, 2, 1)
    l = l.reshape(-1, units)

    # m: (N, units, 1) -> (N, units)
    m = m.view(N, units, 1)
    m = m.permute(0, 2, 1)
    m = m.reshape(-1, units)

    # Outer product, we want a N x N x n_local x 1 tensor.
    u = t.mm(m, l.t())
    u = u.reshape(N, 1, N, n_locals).permute(0, 2, 3, 1)

    # Since we have a big tensor with both positive and negative samples,
    # we need to mask: the diagonal holds same-index (positive) pairs.
    mask = t.eye(N).to(l.device)
    n_mask = 1 - mask

    # Compute the positive and negative score. Average the spatial locations.
    E_pos = get_positive_expectation(u, measure, average=False).mean(2).mean(2)
    E_neg = get_negative_expectation(u, measure, average=False).mean(2).mean(2)

    # Mask positive and negative terms for positive and negative parts of loss
    E_pos = (E_pos * mask).sum() / mask.sum()
    E_neg = (E_neg * n_mask).sum() / n_mask.sum()
    loss = E_neg - E_pos

    return loss


# loss function for deep infomax
class DIMLoss(t.nn.Module):
    """Module wrapper around ``fenchel_dual_loss`` with an optional L2 term.

    Args:
        measure: Measure name forwarded to ``fenchel_dual_loss``.
        scale: Multiplier applied to the value from ``fenchel_dual_loss``.
        l2_penalty: Weight of the squared-norm penalty on the global
            features; the penalty is skipped unless this is positive.
    """

    def __init__(self, measure='JSD', scale=1.0, l2_penalty=0.0):
        super().__init__()
        self.measure, self.scale, self.l2_penalty = measure, scale, l2_penalty

    def forward(self, loc_feats, glb_feats):
        """Return the scaled loss, plus the L2 penalty when it is enabled."""
        divergence = fenchel_dual_loss(
            loc_feats, glb_feats, measure=self.measure)
        loss = self.scale * divergence

        if self.l2_penalty > 0.:
            # Squared entries summed over dim 1, then averaged.
            sq_norm = (glb_feats**2).sum(1).mean()
            return loss + self.l2_penalty * sq_norm

        return loss