In [None]:
import sys, os
# Put the parent of the current working directory on the import path
# (presumably so the local `IPDL` package imported below can be resolved).
project_dir = os.path.dirname(os.getcwd())
if sys.path.count(project_dir) == 0:
    sys.path.append(project_dir)

import torch
from torch import Tensor, nn
from IPDL import MatrixEstimator, ClassificationInformationPlane, AutoEncoderInformationPlane
from IPDL.optim import AligmentOptimizer, SilvermanOptimizer

import torchvision
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, ToTensor, Lambda
from torch.nn.functional import one_hot
from tqdm import tqdm

# Device selection. The notebook is pinned to the CPU; set FORCE_CPU to False
# to run on CUDA whenever it is available.
FORCE_CPU = True
device = 'cuda' if (not FORCE_CPU and torch.cuda.is_available()) else 'cpu'
print(device)

# RBF Kernel

In [None]:
from IPDL import TensorKernel

def RBF(x: Tensor, sigma: float) -> Tensor:
    '''
        Radial Basis Function (Gaussian) kernel between every pair of rows of x.

        @param x: Tensor of shape (n, features) or (batch, n, features)
        @param sigma: kernel bandwidth
        @return Tensor of shape (n, n) or (batch, n, n) whose entries are
                exp(-||x_i - x_j||^2 / (2 * sigma^2))
    '''
    assert 1 < x.ndim < 4, "The dimension of X must be 2 or 3"
    rows = x.unsqueeze(x.ndim - 1)   # (..., n, 1, features)
    cols = x.unsqueeze(x.ndim - 2)   # (..., 1, n, features)
    squared_diff = (rows - cols) ** 2
    # Reduce over the trailing feature axis of the broadcast difference.
    squared_distance = squared_diff.sum(dim=x.ndim)
    return torch.exp(-squared_distance / (2 * (sigma ** 2)))

In [None]:
# Seed the global RNG so the random fixtures, and therefore every comparison
# that follows, are reproducible across Restart & Run All.
torch.manual_seed(0)

sigma = 5  # RBF bandwidth used by the checks in this section
a = torch.rand(4, 32, 128).to(device)  # (batch, n, features) input for the batched RBF

In [None]:
# Batched kernel for every sample set in `a`, then scaled by the matrix size n.
# The RBF diagonal is exp(0) = 1, so dividing by n gives each matrix unit trace.
rbf_result = RBF(a, sigma)
A = rbf_result / rbf_result.shape[-1]

In [None]:
# Compare each slice of the batched kernel with the reference implementation.
print("Test RBF")
for sample, batched_kernel in zip(a, rbf_result):
    _result = TensorKernel.RBF(sample, sigma)
    print(torch.isclose(_result, batched_kernel).all())

# Same comparison after scaling by the number of samples.
print("Test A")
len_x = a.size(1)
for sample, batched_A in zip(a, A):
    _A = TensorKernel.RBF(sample, sigma) / len_x
    print(torch.isclose(_A, batched_A).all())

# Entropy

In [None]:
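# The entropy checks below reuse `a` and `A` from the RBF Kernel section;
# this unbatched alternative input is disabled.
# sigma = 5
# a = torch.rand(32, 128)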

In [None]:
from IPDL import MatrixBasedRenyisEntropy as MRE

def entropy(A: Tensor) -> Tensor:
    '''
        Entropy, in bits, of the eigenvalue spectrum of A: -sum_i(l_i * log2(l_i)).

        @param A: symmetric (Hermitian) matrix of shape (n, n), or a batch of
                  such matrices of shape (..., n, n)
        @return Tensor holding one entropy value per matrix: 0-dim for a single
                matrix, shape (...) for a batch
    '''
    eigval, _ = torch.linalg.eigh(A)
    # Use magnitudes and add a small offset so log2 never receives a zero or a
    # slightly negative eigenvalue.
    epsilon = 1e-8
    eigval = eigval.abs() + epsilon
    return -torch.sum(eigval * torch.log2(eigval), dim=-1)

In [None]:
# Batched entropy in one call versus the reference applied matrix by matrix.
entropy_result = entropy(A)
_entropy_result = [MRE.entropy(A[i]) for i in range(len(a))]

print(torch.isclose(torch.hstack(_entropy_result), entropy_result))

In [None]:
# Show both results, then display their exact (bitwise) element-wise equality.
reference_entropy = torch.hstack(_entropy_result)
print(reference_entropy, entropy_result, sep="\n")

reference_entropy == entropy_result

# Joint entropy

In [None]:
sigma = 5
a = torch.rand(32, 128).to(device)      # single sample set: (n, features)
b = torch.rand(4, 32, 128).to(device)   # batch of sample sets: (batch, n, features)

# Kernel of the single sample set, scaled by its size n.
kernel_a = RBF(a, sigma)
A = kernel_a / kernel_a.shape[-1]

# Kernels of the batch, scaled the same way.
rbf_result = RBF(b, sigma)
B = rbf_result / rbf_result.shape[-1]

In [None]:
def jointEntropy(Kx: Tensor, *args: Tensor) -> Tensor:
    '''
        Joint entropy of several kernel matrices: the entropy of their
        element-wise product after normalising that product to unit trace.

        Parameters
        ----------
            Kx: Tensor
                Kernel matrix of shape (n, n), or a batch of shape (..., n, n).

            args: Tensor
                Further kernel matrices, broadcastable against Kx.

        Returns
        -------
            Tensor
                One entropy value per matrix of the (broadcast) product.
    '''
    # Every operation below is out-of-place, so Kx is never modified and no
    # defensive copy is required.
    A = Kx
    for val in args:
        A = A * val

    if A.ndim == 2:
        A = A / A.trace()
    else:
        # Per-matrix trace over the last two dimensions, kept broadcastable so
        # any number of leading batch dimensions is supported.
        batch_trace = A.diagonal(offset=0, dim1=-1, dim2=-2).sum(dim=-1)
        A = A / batch_trace.unsqueeze(-1).unsqueeze(-1)
    return entropy(A)

In [None]:
from IPDL import MatrixBasedRenyisEntropy as MRE

# Batched joint entropy versus the reference evaluated once per matrix of B.
je_result = jointEntropy(A, B)
_je_result = [MRE.jointEntropy(A, kernel_b) for kernel_b in B]

print(torch.isclose(torch.hstack(_je_result), je_result))
print(torch.stack(_je_result) == je_result)

In [None]:
# Swap the argument order and show both results with their exact comparison.
je_result2 = jointEntropy(B, A)
print(je_result, je_result2, je_result == je_result2, sep="\n")

# Mutual Information

In [None]:
from IPDL import MatrixBasedRenyisEntropy as MRE

def mutualInformation(Ax: Tensor, Ay: Tensor) -> Tensor:
    '''
        Mutual information estimate: entropy(Ax) + entropy(Ay) - jointEntropy(Ax, Ay).

        @param Ax: kernel matrix of shape (n, n), or a batch of shape (..., n, n)
        @param Ay: kernel matrix (or batch), broadcastable against Ax
        @return Tensor holding one estimate per matrix of the broadcast inputs
    '''
    marginal_entropies = entropy(Ax) + entropy(Ay)
    return marginal_entropies - jointEntropy(Ax, Ay)

In [None]:
sigma = 5
a = torch.rand(12, 128).to(device)      # single sample set: (n, features)
b = torch.rand(4, 12, 128).to(device)   # batch of sample sets: (batch, n, features)

# Kernel of the single sample set, scaled by its size n.
kernel_a = RBF(a, sigma)
A = kernel_a / kernel_a.shape[-1]

# Kernels of the batch, scaled the same way.
rbf_result = RBF(b, sigma)
B = rbf_result / rbf_result.shape[-1]

In [None]:
# Individual terms of the mutual-information formula.
entropy_Ax = entropy(A)
entropy_Ay = entropy(B)
joint_entropy = jointEntropy(A, B)

# Inspect how the single-matrix and batched results are shaped.
print(entropy_Ax.shape, entropy_Ay.shape, joint_entropy.shape, sep="\n")

In [None]:
# Mutual information assembled by hand from the terms computed in the previous
# cell; the entropy of the single matrix A broadcasts against the batched terms.
entropy_Ax + entropy_Ay - joint_entropy

In [None]:
# Batched mutual information versus the reference evaluated once per matrix of B.
mi_estimation = mutualInformation(A, B)

_mi_estimation = [MRE.mutualInformation(A, kernel_b) for kernel_b in B]

print(torch.isclose(torch.stack(_mi_estimation), mi_estimation, rtol=1e-5))

In [None]:
# Reference values first, batched values second.
print(torch.stack(_mi_estimation), mi_estimation, sep="\n")

In [None]:
# Shapes of the single kernel matrix and of the kernel batch.
print(A.shape, B.shape, sep="\n")

In [None]:
# Joint entropy with both argument orders, printed one per line.
print(jointEntropy(A, B), jointEntropy(B, A), sep="\n")

# Mutual information with the arguments swapped, compared exactly with the
# original ordering.
mi_estimation2 = mutualInformation(B, A)

mi_estimation == mi_estimation2


In [None]:
# Display the estimate obtained with the arguments swapped, mutualInformation(B, A).
mi_estimation2

In [None]:
# Reference values first, batched values second.
print(torch.stack(_mi_estimation), mi_estimation, sep="\n")