In [126]:
import torch
import random
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt
import importnb
import numpy as np

In [2]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [185]:
class IT_calculator():
    """Matrix-based Renyi alpha-entropy / mutual-information estimator.

    Every quantity is derived from a Gaussian Gram matrix ``A`` built over the
    samples and scaled to unit trace.  The entropy of such a matrix is

        S_alpha(A) = 1 / (1 - alpha) * log2( sum_i lambda_i(A) ** alpha )

    where ``lambda_i`` are the eigenvalues of ``A``.

    Parameters
    ----------
    alpha : float, optional
        Order of the entropy (default ``1.01``).  Must be positive and
        different from 1, because ``1 / (1 - alpha)`` is undefined at 1.
    device : torch.device or str, optional
        Device on which Gram matrices and eigenvalues are computed.  When
        omitted, CUDA is used if available, otherwise the CPU.
    """

    def __init__(self, alpha=1.01, device=None):
        if alpha <= 0 or alpha == 1:
            raise ValueError("alpha must be positive and different from 1")
        self.alpha = alpha
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = torch.device(device)

    def kernel(self, x, y, sigma=0.1):
        """Gaussian kernel ``exp(-||x - y||^2 / (2 * sigma^2))`` of two samples.

        ``x`` and ``y`` may be NumPy arrays or torch tensors of equal shape;
        the norm is taken over all elements.  Tensors are detached and copied
        to host memory first so that tensors living on a GPU work as well.
        Returns a NumPy scalar.
        """
        diff = x - y
        if torch.is_tensor(diff):
            diff = diff.detach().cpu().numpy()
        return np.exp(-np.linalg.norm(diff) ** 2 / (2 * sigma ** 2))

    def Corr(self, X):
        """Return the normalised Gaussian Gram matrix of the samples in ``X``.

        Parameters
        ----------
        X : torch.Tensor
            Tensor whose first dimension indexes the ``n`` samples; all
            remaining dimensions are flattened into one feature vector.

        Returns
        -------
        torch.Tensor
            ``(n, n)`` float32 matrix on ``self.device`` with entries
            ``K_ij / (n * sqrt(K_ii * K_jj))``, so its trace is 1.
            The result is detached from any autograd graph.

        Raises
        ------
        ValueError
            If ``X`` contains no samples.
        """
        n = X.shape[0]  # number of samples
        if n == 0:
            raise ValueError("X must contain at least one sample")
        d = X[0].numel()  # number of features per sample

        # Bandwidth shrinks with the sample count n and depends on the
        # dimensionality d (variant of Silverman's rule of thumb; the
        # alternative 5 * d * n ** (-1 / (4 + d)) was considered in the notebook).
        sigma = 5 * n ** (-1 / (4 + d))

        flat = X.detach().reshape(n, d).to(self.device)
        if flat.dtype not in (torch.float32, torch.float64):
            # Integer / half inputs: subtracting them directly could wrap
            # around or lose precision, so promote to float32 first.
            flat = flat.float()

        # Pairwise Euclidean distances in one call.  The non-matmul mode keeps
        # the diagonal at exactly zero, hence K_ii == 1.
        dist = torch.cdist(flat, flat, compute_mode="donot_use_mm_for_euclid_dist")
        gram = torch.exp(-dist.pow(2) / (2 * sigma ** 2))

        # Normalisation: A_ij = K_ij / (n * sqrt(K_ii * K_jj)).
        diag = torch.diagonal(gram)
        gram = gram / torch.sqrt(torch.outer(diag, diag)) / n
        return gram.to(torch.float32)

    def _entropy_from_gram(self, A):
        """Renyi alpha-entropy (in bits) of a unit-trace PSD matrix ``A``."""
        eigvals = torch.linalg.eigvalsh(A)
        # A is positive semi-definite; negative values are round-off noise.
        eigvals = torch.clamp(eigvals, min=0)
        return torch.log2(torch.pow(eigvals, self.alpha).sum()) / (1 - self.alpha)

    def Entropy(self, X):
        """Entropy (in bits) of the samples ``X``; returns a 0-d tensor."""
        return self._entropy_from_gram(self.Corr(X))

    def JointE(self, X, Y):
        """Joint entropy (in bits) of paired samples ``X`` and ``Y``.

        The joint Gram matrix is the element-wise (Hadamard) product of the
        two individual Gram matrices, rescaled to unit trace.  Its eigenvalues
        are used directly: the product already *is* a Gram matrix and must not
        be passed through the kernel a second time.

        Raises
        ------
        ValueError
            If ``X`` and ``Y`` do not contain the same number of samples.
        """
        if X.shape[0] != Y.shape[0]:
            raise ValueError("X and Y must contain the same number of samples")
        hadamard = torch.mul(self.Corr(X), self.Corr(Y))
        # Each factor has diagonal 1/n, so the trace is 1/n > 0.
        joint = hadamard / torch.trace(hadamard)
        return self._entropy_from_gram(joint)

    def MI(self, X, Y):
        """Mutual information ``S(X) + S(Y) - S(X, Y)`` in bits."""
        return self.Entropy(X) + self.Entropy(Y) - self.JointE(X, Y)



In [144]:
# class SimpleCNN(nn.Module):
#     "Build a simple 3 layer CNN model"
#     def __init__(self):
#         super(SimpleCNN, self).__init__()
#         self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
#         self.relu1 = nn.ReLU()
#         self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
#         self.relu2 = nn.ReLU()
#         self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
#         self.fc1 = nn.Linear(7 * 7 * 32, 64)
#         self.relu3 = nn.ReLU()
#         self.fc2 = nn.Linear(64, 10)

#     def forward(self, x):
#         x = self.conv1(x)
#         x = self.relu1(x)
#         x = self.pool1(x)
#         x = self.conv2(x)
#         x = self.relu2(x)
#         x = self.pool2(x)
#         x = x.view(x.size(0), -1)
#         x = self.fc1(x)
#         x = self.relu3(x)
#         x = self.fc2(x)
#         return x

In [145]:
# model = SimpleCNN()
# model.load_state_dict(torch.load('SimpleCNN.pth'))

# transform = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.5,), (0.5,))
# ])
# trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# indices = random.sample(range(len(trainset)), 50)
# # Create a new DataLoader with the selected indices
# subset_trainloader = torch.utils.data.DataLoader(trainset, batch_size=1, shuffle=False, sampler=torch.utils.data.SubsetRandomSampler(indices))
# outputs = torch.empty(0)
# inputs = torch.empty(0)
# with torch.no_grad():
#     for images, labels in subset_trainloader:
#         inputs = torch.cat((inputs, images), dim=0)
#         intermediate_output = model.conv1(images)  # Output of the first convolutional layer
#         # intermediate_output = model.relu1(intermediate_output)  # Apply ReLU activation
#         # intermediate_output = model.pool1(intermediate_output)  # Apply max pooling
#         # intermediate_output = model.conv2(intermediate_output) 
#         outputs = torch.cat((outputs, intermediate_output), dim=0)

# IT = IT_calculator()
# JE = IT.JointE(inputs, outputs)

In [None]:
# E_inputs = IT.Entropy(inputs)
# print('E_inputs:', E_inputs)
# E_outputs = IT.Entropy(outputs)
# print('E_outputs:', E_outputs)

In [None]:
# IT = IT_calculator()
# def JointE(X, Y):
#     # print(IT.Corr(X))
#     # print(IT.Corr(Y))
#     # print('')
#     _ = torch.mul(IT.Corr(X), IT.Corr(Y))
#     # print(_)
#     tr = torch.trace(_).item()**3
#     # print(tr)
#     _ = _/tr
#     # print(_)
#     res = IT.Entropy(_)
#     return res
# print(JointE(inputs, outputs))