In [None]:
# HDI_v2
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.feature_selection import mutual_info_regression
import numpy as np

# Target Pearson correlations between level 1 (index 0) and each subsequent
# level, in order: correlation_loss pairs entry i with tensor[i + 1].
target_correlations = [0.9, 0.7, 0.5, 0.3, 0.1]

def rho_to_mi(rho, dim):
    """Convert a correlation coefficient into a mutual-information value.

    Evaluates ``-dim / 2 * ln(1 - rho**2)``. Because the natural logarithm
    is used, the result is expressed in nats. This is the closed form
    usually quoted for ``dim`` independent pairs of jointly Gaussian
    variables that each have correlation ``rho``.

    Args:
        rho: Correlation coefficient; ``|rho| < 1`` keeps the result finite.
        dim: Multiplier applied to the one-dimensional value.

    Returns:
        The mutual information in nats.
    """
    unexplained_fraction = 1 - rho ** 2
    return -dim / 2 * np.log(unexplained_fraction)

## Target mutual information for each target correlation (dim=1),
## rounded to two decimals.
target_MIs = [round(rho_to_mi(rho, dim=1), 2) for rho in target_correlations]

# print('target_MIs= ',target_MIs) # target_MIs=  [0.83, 0.34, 0.14, 0.05, 0.01]


# Define the correlation loss
def correlation_loss(tensor, target_correlations):
    """Penalise deviation of level-to-level Pearson correlations from targets.

    ``tensor[0]`` is the reference level. For every index ``i`` of
    ``target_correlations`` the Pearson correlation between the flattened
    reference level and the flattened ``tensor[i + 1]`` is computed, and the
    squared difference to ``target_correlations[i]`` is added to the loss.

    Covariance and standard deviations both use the population (1/N)
    normalisation. Mixing a 1/N covariance with the 1/(N-1) default of
    ``Tensor.std()`` would scale every correlation by (N-1)/N, so identical
    inputs could never reach a correlation of 1.

    Args:
        tensor: Tensor indexed by level along its first axis. It must hold at
            least ``len(target_correlations) + 1`` levels, and each level
            needs at least two dimensions (it is flattened over dims 0-1).
        target_correlations: Sequence of desired correlations; entry ``i``
            applies to ``tensor[i + 1]``.

    Returns:
        tuple: ``(loss, correlations)`` where ``loss`` is the summed squared
        error (a tensor that keeps the autograd graph; the float ``0.0`` when
        there are no targets) and ``correlations`` is a list of Python floats
        holding the measured correlations.
    """

    def _centered_and_std(level):
        # Flatten one level, remove its mean, and return the centred values
        # together with their population standard deviation.
        flat = level.flatten(start_dim=0, end_dim=1)
        centered = flat - flat.mean()
        return centered, torch.sqrt(torch.mean(centered ** 2))

    # The reference level is identical for every comparison: compute it once.
    ref_centered, ref_std = _centered_and_std(tensor[0])

    loss = 0.0
    correlations = []
    for level_idx, target_corr in enumerate(target_correlations, start=1):
        lvl_centered, lvl_std = _centered_and_std(tensor[level_idx])

        cov = torch.mean(ref_centered * lvl_centered)
        # The epsilon guards against division by zero for constant levels.
        corr = cov / (ref_std * lvl_std + 1e-8)

        correlations.append(corr.item())  # detached copy for logging
        loss = loss + (corr - target_corr) ** 2
    return loss, correlations

# Define the independence loss
def independence_loss(tensor):
    """Penalise linear correlation between feature columns within each level.

    For every level, the Pearson correlation of each unordered pair of
    feature columns is computed across the batch axis. The loss is the sum of
    the squared correlations over all pairs and all levels.

    Covariance and standard deviations both use the population (1/N)
    normalisation so that perfectly correlated columns give exactly 1 (up to
    the stabilising epsilon). All pairs are evaluated with one batched matrix
    product instead of a Python loop over column pairs.

    Args:
        tensor: Tensor of shape ``(levels, batch_size, features)``.

    Returns:
        tuple: ``(loss, avg_correlations)`` where ``loss`` is a scalar tensor
        that keeps the autograd graph and ``avg_correlations`` is a list with
        one Python float per level: the mean signed correlation over that
        level's feature pairs. With fewer than two features there are no
        pairs, so the loss is zero and every average is reported as ``0.0``.
    """
    levels, batch_size, features = tensor.shape

    # No column pairs exist: nothing to penalise and nothing to average.
    if features < 2:
        return tensor.new_zeros(()), [0.0] * levels

    # Centre every column over the batch axis.
    centered = tensor - tensor.mean(dim=1, keepdim=True)

    # (levels, features, features) population covariance matrices.
    cov = torch.matmul(centered.transpose(1, 2), centered) / batch_size

    # (levels, features) population standard deviations.
    std = torch.sqrt(torch.mean(centered ** 2, dim=1))

    # The outer product of the stds gives the per-pair normaliser; the
    # epsilon guards against division by zero for constant columns.
    corr = cov / (std.unsqueeze(2) * std.unsqueeze(1) + 1e-8)

    # Keep only the strict upper triangle: each unordered pair once.
    rows, cols = torch.triu_indices(
        features, features, offset=1, device=tensor.device
    )
    pair_corrs = corr[:, rows, cols]  # (levels, n_pairs)

    loss = torch.sum(pair_corrs ** 2)
    avg_correlations = pair_corrs.mean(dim=1).detach().tolist()
    return loss, avg_correlations

# Define the mutual information loss
# def mutual_information_loss(tensor, target_MIs):
#     level_1 = tensor[0].detach().cpu().numpy()  # Level 1
#     loss = 0.0
#     mutual_infos = []
#     for i, target_corr in enumerate(target_MIs):
#         level_i = tensor[i + 1].detach().cpu().numpy()  # Other levels

#         # Flatten the tensors for mutual information calculation
#         level_1_flat = level_1.reshape(-1)
#         level_i_flat = level_i.reshape(-1)

#         # Compute mutual information using sklearn
#         mi = mutual_info_regression(level_1_flat.reshape(-1, 1), level_i_flat, random_state=42)
#         mi_value = mi[0] / np.log(2)  # Normalize MI to the range [0, 1]

#         mutual_infos.append(mi_value)  # Store normalized mutual information
#         loss += (mi_value - target_corr) ** 2  # Penalize deviation from target
#     return loss, mutual_infos

def mutual_information_loss(tensor, target_MIs):
    """Score how far estimated mutual information is from its targets.

    ``tensor[0]`` is the reference level. For every index ``i`` of
    ``target_MIs`` the mutual information between the flattened reference
    level and the flattened ``tensor[i + 1]`` is estimated with scikit-learn's
    ``mutual_info_regression`` and compared with ``target_MIs[i]``.

    Units: ``mutual_info_regression`` reports nats, and the estimate is kept
    in nats so it is directly comparable with targets derived from a natural
    logarithm (as ``rho_to_mi`` produces). Dividing by ``ln 2`` would convert
    the estimate to bits and compare mismatched units.

    The computation runs on detached NumPy copies, so the returned loss is a
    plain number: it can be logged or added to a total, but it contributes no
    gradient during backpropagation.

    Args:
        tensor: Tensor indexed by level along its first axis, with at least
            ``len(target_MIs) + 1`` levels.
        target_MIs: Sequence of target mutual-information values in nats;
            entry ``i`` applies to ``tensor[i + 1]``.

    Returns:
        tuple: ``(loss, mutual_infos)`` where ``loss`` is a ``numpy.float64``
        (weighted squared error plus an L2 penalty on the estimates) and
        ``mutual_infos`` is the list of estimated values in nats.
    """
    lambda_reg = 0.01  # strength of the L2 penalty on the MI estimates

    # The reference level is shared by every comparison: flatten it once and
    # shape it as the single-feature matrix that sklearn expects.
    reference = tensor[0].detach().cpu().numpy().reshape(-1, 1)

    loss = np.float64(0.0)
    mutual_infos = []
    for level_idx, target_mi in enumerate(target_MIs, start=1):
        level_flat = tensor[level_idx].detach().cpu().numpy().reshape(-1)

        # A fixed random_state keeps the estimate reproducible.
        mi_value = mutual_info_regression(
            reference, level_flat, random_state=42
        )[0]
        mutual_infos.append(mi_value)

        # Undershooting the target is penalised twice as hard as overshooting.
        weight = 1.0 if mi_value < target_mi else 0.5
        loss += weight * (mi_value - target_mi) ** 2

    # L2 regularisation on the estimated MI values.
    loss += lambda_reg * np.sum(np.square(mutual_infos))

    return loss, mutual_infos


# Compute mutual information between all pairs of levels
def mutual_information_between_levels(tensor):
    """Estimate mutual information for every pair of levels.

    Each level is detached, moved to the CPU and flattened to one dimension.
    For every pair ``i < j``, ``mutual_info_regression`` is called with level
    ``i`` as the single feature and level ``j`` as the target. The estimate
    is divided by ``ln 2`` and written to both ``[i, j]`` and ``[j, i]``.

    Note: dividing by ``ln 2`` rescales the estimate by a constant (a change
    of logarithm base from e to 2); it does not bound the value to [0, 1].

    Args:
        tensor: Tensor of shape ``(levels, batch_size, features)``.

    Returns:
        A ``(levels, levels)`` NumPy array with zeros on the diagonal.
    """
    levels, batch_size, features = tensor.shape

    # Flatten each level once instead of once per pair.
    flat_levels = [
        tensor[idx].detach().cpu().numpy().reshape(-1) for idx in range(levels)
    ]

    mi_matrix = np.zeros((levels, levels))
    for i in range(levels):
        predictor = flat_levels[i].reshape(-1, 1)
        for j in range(i + 1, levels):
            estimate = mutual_info_regression(
                predictor, flat_levels[j], random_state=42
            )
            scaled = estimate[0] / np.log(2)
            # Mirror the value so the matrix is symmetric.
            mi_matrix[i, j] = scaled
            mi_matrix[j, i] = scaled

    return mi_matrix

# Define the model
class CorrelationModel(nn.Module):
    """Module whose only learnable parameter is the data tensor itself.

    A ``(levels, batch_size, features)`` tensor is drawn from a standard
    normal distribution and registered as ``transform``. The forward pass
    takes no input and returns that parameter unchanged.
    """

    def __init__(self, levels, batch_size, features):
        super().__init__()
        self.levels = levels
        self.batch_size = batch_size
        self.features = features
        initial_values = torch.randn(levels, batch_size, features)
        self.transform = nn.Parameter(initial_values)

    def forward(self):
        """Return the learnable tensor (the parameter object, not a copy)."""
        return self.transform

# Hyperparameters
levels = 6  # correlation_loss reads tensor[i + 1] for each of the 5 targets
batch_size = 256
features = 5
num_epochs = 500
learning_rate = 0.0025

# Model holding the learnable tensor, and the optimizer that updates it
model = CorrelationModel(levels, batch_size, features)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Per-epoch loss values
total_loss_history = []  # combined loss
corr_loss_history = []  # correlation loss
indep_loss_history = []  # independence loss
mi_loss_history = []  # mutual information loss

# Per-epoch metrics for plotting
correlation_history = []  # correlations between level 1 and the other levels
independence_history = []  # average within-level feature correlations
mutual_information_history = []  # MI between level 1 and the other levels
mi_between_levels_history = []  # one MI matrix (all level pairs) per epoch


for epoch in range(num_epochs):
    optimizer.zero_grad()
    output = model()  # the learnable tensor itself

    # Individual loss terms, each with its diagnostic values
    corr_loss, correlations = correlation_loss(output, target_correlations)
    indep_loss, avg_independence_corrs = independence_loss(output)
    mi_loss, mutual_infos = mutual_information_loss(output, target_MIs)

    # Pairwise MI snapshot. This must happen before optimizer.step():
    # `output` is the parameter, which the step modifies in place.
    mi_matrix = mutual_information_between_levels(output)
    mi_between_levels_history.append(mi_matrix)

    # Weighted sum of the three terms.
    # NOTE(review): mutual_information_loss works on detached NumPy copies,
    # so mi_loss shifts the reported total but contributes no gradient.
    # total_loss = corr_loss + 0.1 * indep_loss + 0.1 * mi_loss
    total_loss = corr_loss + 0.1 * indep_loss + mi_loss

    # Backpropagate and update the tensor
    total_loss.backward()
    optimizer.step()

    # Record scalar losses
    total_loss_history.append(total_loss.item())
    corr_loss_history.append(corr_loss.item())
    indep_loss_history.append(indep_loss.item())
    mi_loss_history.append(mi_loss.item())

    # Record per-epoch diagnostics
    correlation_history.append(correlations)
    independence_history.append(avg_independence_corrs)
    mutual_information_history.append(mutual_infos)

    # Progress report every 10 epochs
    if (epoch + 1) % 10 == 0:
        print(
            f"\nEpoch [{epoch + 1}/{num_epochs}], "
            f"Correlation Loss: {corr_loss.item():.4f}, "
            f"Independence Loss: {indep_loss.item():.4f}, "
            f"Mutual Information Loss: {mi_loss:.4f}, "
            f"Total Loss: {total_loss.item():.4f}"
        )
        # print(f"Correlations between Level 1 and other levels: {correlations}")
        # print(f"Mutual Information between Level 1 and other levels: {mutual_infos}")
        # print(f"Average independence correlations (within levels): {avg_independence_corrs}")
        # print(f"Mutual Information Matrix (Levels):\n{mi_matrix}")



##########################
# One subplot per tracked loss. Each legend label names the curve actually
# drawn in that panel.
fig, axs = plt.subplots(1, 4, figsize=(18, 6))

# (history, panel title, legend label, line colour)
loss_panels = [
    (corr_loss_history, "Correlation Loss Over Epochs", 'Correlation Loss', 'blue'),
    (indep_loss_history, "Independence Loss Over Epochs", 'Independence Loss', 'orange'),
    (total_loss_history, "Total Loss Over Epochs", 'Total Loss', 'green'),
    (mi_loss_history, "MI Loss Over Epochs", 'MI Loss', 'green'),
]

for panel_ax, (panel_history, panel_title, panel_label, panel_color) in zip(axs, loss_panels):
    panel_ax.plot(panel_history, label=panel_label, color=panel_color, linewidth=4)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("Epochs")
    panel_ax.set_ylabel("Loss")
    panel_ax.legend()
    panel_ax.grid()

#################################################
# Convert tracked metrics to tensors for easier plotting
correlation_history = torch.tensor(correlation_history)  # Shape: (num_epochs, len(target_correlations))
independence_history = torch.tensor(independence_history)  # Shape: (num_epochs, levels)
mutual_information_history = torch.tensor(mutual_information_history)  # Shape: (num_epochs, len(target_correlations))


##################################
# Plot correlations between level 1 and every other level
colors = ['r', 'g', 'b', 'orange', 'purple']  # One colour per target; cycled if there are more targets

plt.figure(figsize=(12, 6))

# Measured correlation per level, in that level's colour
for i, target in enumerate(target_correlations):
    plt.plot(correlation_history[:, i],
             label=f"Level {i + 2} (Target: {target})",
             linewidth=2,
             color=colors[i % len(colors)])

# Dashed target lines taken from target_correlations itself, so they cannot
# drift away from the values the loss actually optimises towards
for i, target in enumerate(target_correlations):
    plt.axhline(y=target, color=colors[i % len(colors)], linestyle='--',
                label=f"Target Correlation (Level {i + 2})", linewidth=2)

plt.title("Correlation Between Level 1 and Other Levels Over Epochs")
plt.xlabel("Epochs")
plt.ylabel("Correlation")
plt.legend()
plt.grid()
plt.show()
##########################################
# Average within-level feature correlation, one curve per level
plt.figure(figsize=(12, 6))
for i in range(levels):
    level_curve = independence_history[:, i]
    plt.plot(level_curve, label=f"Level {i + 1}", linewidth=2)

plt.title("Average Independence Correlations (Within Levels) Over Epochs")
plt.xlabel("Epochs")
plt.ylabel("Average Correlation")
plt.legend()
plt.grid()
plt.show()


########################################
# Plot mutual information between level 1 and every other level
colors = ['r', 'g', 'b', 'orange', 'purple']  # One colour per target; cycled if there are more targets
plt.figure(figsize=(12, 6))

# Estimated MI per level, in that level's colour
for i, target in enumerate(target_MIs):
    plt.plot(mutual_information_history[:, i],
             label=f"Level {i + 2} (Target: {target})",
             linewidth=2,
             color=colors[i % len(colors)])

# Dashed target lines taken from target_MIs itself rather than hard-coded copies
for i, target in enumerate(target_MIs):
    plt.axhline(y=target, color=colors[i % len(colors)], linestyle='--',
                label=f"Target MI (Level {i + 2})", linewidth=2)

# This figure shows mutual information, so title and axis say so
plt.title("Mutual Information Between Level 1 and Other Levels Over Epochs")
plt.xlabel("Epochs")
plt.ylabel("Mutual Information")
plt.legend()
plt.grid()
plt.show()



##########################################
import seaborn as sns

# Heatmap of the pairwise MI matrix recorded in the final epoch
level_names = [f"Level {i+1}" for i in range(levels)]
final_mi_matrix = mi_between_levels_history[-1]

plt.figure(figsize=(10, 8))
sns.heatmap(
    final_mi_matrix,
    annot=True,
    fmt=".2f",
    cmap="coolwarm",
    xticklabels=level_names,
    yticklabels=level_names,
)
plt.title("Mutual Information Between Levels (Last Epoch)")
plt.xlabel("Levels")
plt.ylabel("Levels")
plt.show()

###########################################
# MI trajectory over epochs for every unordered pair of levels (0-based indices)
level_pairs = [(i, j) for i in range(levels) for j in range(i + 1, levels)]
mi_over_epochs = {
    (first, second): [epoch_matrix[first, second] for epoch_matrix in mi_between_levels_history]
    for first, second in level_pairs
}

# One curve per pair; labels use 1-based level numbers
plt.figure(figsize=(12, 6))
for (first, second), mi_values in mi_over_epochs.items():
    plt.plot(mi_values, label=f"Level {first + 1} vs Level {second + 1}", linewidth=2)
plt.title("Mutual Information Between Level Pairs Over Epochs")
plt.xlabel("Epochs")
plt.ylabel("Mutual Information")
plt.legend()
plt.grid()
plt.show()

# Adjacent level pairs as 0-based indices: (0, 1), (1, 2), ..., (levels - 2, levels - 1)
adjacent_level_pairs = [(i, i + 1) for i in range(levels - 1)]

# MI trajectory over epochs for each adjacent pair
mi_over_epochs_adjacent = {
    (lower, upper): [epoch_matrix[lower, upper] for epoch_matrix in mi_between_levels_history]
    for lower, upper in adjacent_level_pairs
}

# One curve per adjacent pair; labels use 1-based level numbers
plt.figure(figsize=(12, 6))
for (lower, upper), mi_values in mi_over_epochs_adjacent.items():
    plt.plot(mi_values, label=f"Level {lower + 1} vs Level {upper + 1}", linewidth=2)
plt.title("Mutual Information Between Adjacent Level Pairs Over Epochs")
plt.xlabel("Epochs")
plt.ylabel("Mutual Information")
plt.legend()
plt.grid()
plt.show()