# Diversity Condition Violated Correct Model

This notebook runs the experiment in which the diversity condition is violated and a correctly specified model is chosen to account for that violation.

- $p(\tilde{z}|z) = \delta(u - u') \frac{e^{\kappa \tilde{z}^{T}z}}{\int_{W(z)}e^{\kappa z^{T}z'}dz'}$ - ground-truth (data-generating) conditional density
- $q_{h}(\tilde{z}|z) = \delta(u-u') \frac{e^{\frac{1}{\tau} h(\tilde{z})^{T}h(z)}}{\int_{W(z)}e^{\frac{1}{\tau}h(z)^{T}h(z')}dz'}$ - model conditional density

The loss function in this setup will be the following:

$\mathcal{L} = \mathbb{E}_{(z, \tilde{z}) \sim p(z, \tilde{z}), \{z^{-}_{i}\} \sim U(\mathcal{W}(z)) }\left[- \log\left(\frac{e^{\frac{1}{\tau}h(\tilde{z})^{T}h(z)}}{\sum_{i=1}^{M}e^{\frac{1}{\tau} h(\tilde{z})^{T}h(z^{-}_{i})}}\right)\right]$


In [130]:
%load_ext autoreload
%autoreload 2

import torch

# Experiment sizes: N anchors, M negatives per anchor, d latent dimensions,
# of which the first d_fix are the coordinates held fixed when sampling
# negatives.
N = 100
M = 100
d = 5
d_fix = 2

# Draw N Gaussian vectors and rescale each row to unit Euclidean norm.
Z = torch.randn(N, d)
Z /= Z.norm(dim=1, keepdim=True)

# z_fixed = Z[:, :d_fix]
# z_var = Z[:, d_fix:]

# radii = torch.sqrt(1 - z_fixed.norm(dim=1) ** 2) # Compute subspace radii for each negative sample

# neg_samples = torch.randn(N, M, d - d_fix)
# neg_samples = neg_samples / neg_samples.norm(dim=2, keepdim=True)  # Normalize to 1

# neg_samples = neg_samples * radii.view(-1, 1, 1)
# neg_samples = torch.cat((z_fixed.unsqueeze(1).expand(-1, M, -1), neg_samples), dim=2)

# Sample negatives for the case where the latent space violates the diversity condition and the model accounts for that violation.
def sample_negative_samples(Z, M, n_fixed=None):
    """Draw M negatives per row of Z that keep the row's leading coordinates.

    Each negative copies the first ``n_fixed`` coordinates of its anchor row.
    The remaining coordinates are a random direction rescaled to radius
    ``sqrt(1 - ||z_fixed||^2)``, so that for a unit-norm anchor the negative
    also has unit norm.

    Args:
        Z: 2-D tensor of anchors, shape (batch, dim). Rows are expected to
            have unit norm.
        M: Number of negatives to draw for every anchor.
        n_fixed: Number of leading coordinates copied from the anchor. When
            ``None`` the module-level ``d_fix`` is used.

    Returns:
        Tensor of shape (batch, M, dim) holding the negatives.
    """
    if n_fixed is None:
        n_fixed = d_fix  # fall back to the notebook-wide setting

    # Read the sizes from Z itself so any batch size / dimension works.
    batch, dim = Z.shape
    z_fixed = Z[:, :n_fixed]

    # Clamp before the square root: rounding can push 1 - ||z_fixed||^2
    # slightly below zero, which would otherwise produce NaN radii.
    radii = torch.sqrt(torch.clamp(1 - (z_fixed ** 2).sum(dim=1), min=0))

    directions = torch.randn(
        batch, M, dim - n_fixed, dtype=Z.dtype, device=Z.device
    )
    # Normalise to exactly unit length. The tiny floor only guards against a
    # division by zero and does not shrink the samples.
    directions = directions / directions.norm(dim=2, keepdim=True).clamp_min(1e-12)
    variable_part = directions * radii.view(-1, 1, 1)

    fixed_part = z_fixed.unsqueeze(1).expand(-1, M, -1)
    return torch.cat((fixed_part, variable_part), dim=2)


def compute_loss(Z, Z_pos, M, tau = 0.1):
    """Evaluate the contrastive loss for anchors Z and positives Z_pos.

    Args:
        Z: Anchor tensor; one row per sample.
        Z_pos: Positive tensor with the same shape as ``Z``.
        M: Number of negatives drawn per anchor via
            ``sample_negative_samples``.
        tau: Temperature dividing every inner product.

    Returns:
        Tuple ``(total, neg, pos)`` where ``neg`` is the batch mean of the
        log-sum-exp over negative similarities, ``pos`` is the negated batch
        mean of the anchor/positive similarity divided by ``tau``, and
        ``total`` is their sum.
    """
    negatives = sample_negative_samples(Z, M)

    # Inner product of every anchor with each of its M negatives.
    neg_similarity = (Z.unsqueeze(1) * negatives).sum(dim=2)
    neg_term = torch.logsumexp(neg_similarity / tau, dim=1).mean()

    # Inner product of every anchor with its positive.
    pos_similarity = (Z * Z_pos).sum(dim=1)
    pos_term = -pos_similarity.mean() / tau

    return neg_term + pos_term, neg_term, pos_term


# Smoke test: anchors and "positives" are independent random unit vectors.
Z_input = torch.randn(N, d)
Z_input /= Z_input.norm(dim=1, keepdim=True)

Z_pos_input = torch.randn(N, d)
Z_pos_input /= Z_pos_input.norm(dim=1, keepdim=True)

# Prints the (total, negative term, positive term) tuple.
print(compute_loss(Z_input, Z_pos_input, M=1000, tau=0.2))





The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
(tensor(10.1420), tensor(9.9930), tensor(0.1490))
