In [3]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import itertools

In [4]:
def js_divergence(feature_map_1, feature_map_2, reduction='batchmean'):

    """Jensen-Shannon divergence"""

    # Convert feature maps to probability distributions
    p = F.softmax(feature_map_1, dim=1)
    q = F.softmax(feature_map_2, dim=1)

    # Compute the average distribution
    m = 0.5 * (p + q)

    # Compute KL divergence for both directions
    kl_pm = F.kl_div(torch.log(m + 1e-10), p, reduction=reduction)
    kl_qm = F.kl_div(torch.log(m + 1e-10), q, reduction=reduction)

    # Compute JS divergence
    js_div = 0.5 * (kl_pm + kl_qm)
    return js_div

In [11]:
# Dummy feature maps
feature_map_1 = torch.randn(25, 64, 5, 5)  # [batch, channels, height, width]
feature_map_2 = torch.randn(25, 64, 5, 5)
# Compute JS divergence
js_div = js_divergence(feature_map_1, feature_map_2)
print("Jensen-Shannon divergence:", js_div)

Jensen-Shannon divergence: tensor(4.9071)
