# In [None]:
# HDI-v3
# 14030902

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from sklearn.feature_selection import mutual_info_regression
import numpy as np

# Target Pearson correlations between level 1 and levels 2..6, in level order.
target_correlations = [
    0.9,
    0.7,
    0.5,
    0.3,
    0.1,
]

# Target distance correlations for the adjacent level pairs (1, 2), (2, 3), ...
# in pair order; adjust as needed. A flat alternative is [0.5] * 5.
target_dependence = [
    0.85,
    0.65,
    0.45,
    0.25,
    0.15,
]

# Under the Gaussian distribution, the correlation coefficient and mutual information have one-to-one mapping:
def rho_to_mi(rho, dim):
    result = -dim / 2 * np.log(1 - rho **2)
    return result


def mi_to_rho(mi, dim):
    """Invert ``rho_to_mi``: recover the non-negative correlation from an MI in nats.

    Solves ``mi = -dim / 2 * log(1 - rho**2)`` for ``rho`` and returns the
    positive root ``sqrt(1 - exp(-2 * mi / dim))``.
    """
    decay = np.exp(-2 * mi / dim)
    return np.sqrt(1 - decay)


##
# Target mutual information (nats, dim=1) implied by each target correlation
# under the Gaussian mapping, rounded to 2 decimals -> [0.83, 0.34, 0.14, 0.05, 0.01].
# The rounding is coarse for the smallest targets (rho=0.1 gives ~0.005 nats,
# stored as 0.01).
target_MIs = [round(rho_to_mi(rho, dim=1), 2) for rho in target_correlations]

# Correlation loss: pull the Pearson correlation between level 1 and each
# later level toward its target.
def correlation_loss(tensor, target_correlations):
    """Squared-error loss between achieved and target Pearson correlations.

    ``tensor[0]`` is the reference level; ``target_correlations[i]`` is the
    target for ``tensor[i + 1]``. Each level is flattened to one vector, so the
    correlation is taken over all of its entries together.

    Covariance and standard deviations both use population (1/N) statistics,
    so ``corr(x, a * x + b)`` is exactly +/-1. Mixing a 1/N covariance with
    ``torch.std``'s default 1/(N-1) would shrink every correlation by (N-1)/N.

    Args:
        tensor: tensor indexable along dim 0 with at least
            ``len(target_correlations) + 1`` levels.
        target_correlations: sequence of target correlations.

    Returns:
        (loss, correlations): ``loss`` is a scalar tensor, differentiable with
        respect to ``tensor`` (zero for an empty target list); ``correlations``
        is a list of Python floats, one per target.
    """
    # The reference statistics do not depend on the loop variable: compute once.
    reference = tensor[0].reshape(-1)
    reference_centered = reference - reference.mean()
    reference_std = torch.sqrt(torch.mean(reference_centered ** 2))

    loss = tensor.new_zeros(())
    correlations = []
    for i, target_corr in enumerate(target_correlations):
        level_i = tensor[i + 1].reshape(-1)
        level_centered = level_i - level_i.mean()
        cov = torch.mean(reference_centered * level_centered)
        level_std = torch.sqrt(torch.mean(level_centered ** 2))
        corr = cov / (reference_std * level_std + 1e-8)  # 1e-8 guards a zero std
        correlations.append(corr.item())
        loss = loss + (corr - target_corr) ** 2
    return loss, correlations

# Independence loss: push the features inside every level to be mutually
# uncorrelated.
def independence_loss(tensor):
    """Sum of squared pairwise feature correlations within each level.

    For each level ``tensor[k]`` of shape (batch_size, features), the Pearson
    correlation of every feature pair (i < j) is computed over the batch axis,
    squared, and added to the loss. Covariances and standard deviations both
    use population (1/N) normalisation, so perfectly correlated features give
    |corr| = 1.

    Args:
        tensor: (levels, batch_size, features) tensor.

    Returns:
        (loss, avg_correlations): ``loss`` is a scalar tensor. For each level,
        ``avg_correlations`` holds the mean of the *signed* pairwise
        correlations as a float. It is 0.0 when a level has fewer than two
        features, so there are no pairs to average.
    """
    levels, batch_size, features = tensor.shape
    # Index pairs (i, j) with i < j, i.e. the strict upper triangle.
    rows, cols = torch.triu_indices(features, features, offset=1, device=tensor.device)

    loss = tensor.new_zeros(())
    avg_correlations = []
    for level in tensor:
        centered = level - level.mean(dim=0, keepdim=True)
        cov = centered.T @ centered / batch_size  # (features, features), 1/N
        std = torch.sqrt(torch.diagonal(cov))
        corr_matrix = cov / (torch.outer(std, std) + 1e-8)  # 1e-8 guards zero std
        pair_corrs = corr_matrix[rows, cols]
        loss = loss + (pair_corrs ** 2).sum()  # penalize non-zero correlations
        avg_correlations.append(pair_corrs.mean().item() if pair_corrs.numel() > 0 else 0.0)
    return loss, avg_correlations

# Mutual information loss: compare estimated MI between level 1 and each later
# level against a target.
def mutual_information_loss(tensor, target_MIs):
    """Weighted squared error between estimated and target mutual information.

    MI between the flattened reference level ``tensor[0]`` and each flattened
    ``tensor[i + 1]`` is estimated with scikit-learn's
    ``mutual_info_regression``. That function reports MI in nats, the same unit
    ``rho_to_mi`` produces (natural log), so the estimates are kept in nats. The
    earlier division by log(2) converted to bits; it did not normalize to
    [0, 1], and it made the comparison with nat-valued targets inconsistent.

    The estimate runs on detached NumPy copies, so the returned loss is a
    plain NumPy float with no gradient path back to ``tensor``.

    Args:
        tensor: indexable along dim 0 with at least ``len(target_MIs) + 1`` levels.
        target_MIs: target MI values in nats, one per level after the first.

    Returns:
        (loss, mutual_infos): ``loss`` is a NumPy float. Deviations below the
        target get weight 1.0 and those above get 0.5, plus an L2 penalty
        ``0.01 * sum(mi**2)``. ``mutual_infos`` lists the MI estimates in nats.
    """
    lambda_reg = 0.01  # weight of the L2 penalty on the MI estimates

    # The reference level is the same for every comparison: convert it once.
    reference = tensor[0].detach().cpu().numpy().reshape(-1, 1)

    loss = 0.0
    mutual_infos = []
    for i, target_mi in enumerate(target_MIs):
        level_i = tensor[i + 1].detach().cpu().numpy().reshape(-1)
        mi_value = mutual_info_regression(reference, level_i, random_state=42)[0]  # nats

        mutual_infos.append(mi_value)
        weight = 1.0 if mi_value < target_mi else 0.5  # penalize shortfall more
        loss += weight * (mi_value - target_mi) ** 2

    loss += lambda_reg * np.sum(np.square(mutual_infos))
    return loss, mutual_infos

############################################################################
# Define the distance correlation function
def distance_correlation(x, y):
    """
    Compute the (sample) distance correlation between two tensors x and y.

    Rows are observations: ``x`` is (n, p) and ``y`` is (n, q) with the same n.
    1-D inputs are treated as (n, 1), i.e. n scalar observations. The result
    is differentiable and lies in [0, 1] (up to the 1e-8 denominator guard):

        dCor = dCov(x, y) / sqrt(dVar(x) * dVar(y))

    The distance (co)variances come from double-centered pairwise Euclidean
    distance matrices.
    """
    # Without this, unsqueeze(0) below would turn n scalars into a single
    # n-dimensional point, and the distance matrix would be 1 x 1.
    if x.dim() == 1:
        x = x.unsqueeze(1)
    if y.dim() == 1:
        y = y.unsqueeze(1)

    # Center the data. Distances are translation invariant; this only helps
    # numerical conditioning.
    x = x - x.mean()
    y = y - y.mean()

    # Pairwise distance matrices (n, n). cdist is batched, so add a batch dim
    # and index it away again; indexing (unlike squeeze) keeps 2-D when n == 1.
    a = torch.cdist(x.unsqueeze(0), x.unsqueeze(0), p=2)[0]
    b = torch.cdist(y.unsqueeze(0), y.unsqueeze(0), p=2)[0]

    # Double centering: remove row and column means, add back the grand mean.
    A = a - a.mean(dim=0, keepdim=True) - a.mean(dim=1, keepdim=True) + a.mean()
    B = b - b.mean(dim=0, keepdim=True) - b.mean(dim=1, keepdim=True) + b.mean()

    # The squared distance covariance is non-negative in exact arithmetic.
    # Clamp tiny negative rounding errors so sqrt does not return NaN.
    dcov = torch.sqrt((A * B).mean().clamp_min(0.0))
    dvar_x = torch.sqrt((A * A).mean())
    dvar_y = torch.sqrt((B * B).mean())

    return dcov / (torch.sqrt(dvar_x * dvar_y) + 1e-8)


# Define the adjacent level dependence loss using distance correlation
def adjacent_level_dependence_loss(tensor, target_dependence=None):
    """
    Ensure that dependence between adjacent levels is close to target values.

    For every adjacent pair (tensor[i], tensor[i + 1]) the distance correlation
    is measured. Pair i is penalized by ``(dcorr - target_dependence[i]) ** 2``
    when a target exists for it. Pairs without a target are measured but not
    penalized; when ``target_dependence`` is None, that is every pair.

    Args:
        tensor: (levels, batch_size, features) tensor.
        target_dependence: optional sequence of targets, one per adjacent pair.

    Returns:
        (loss, dependence_values): ``loss`` is always a scalar tensor (zero if
        nothing was penalized), so callers can rely on ``.item()``.
        ``dependence_values`` is a list of ``levels - 1`` Python floats.
    """
    levels, batch_size, features = tensor.shape
    n_targets = 0 if target_dependence is None else len(target_dependence)
    loss = tensor.new_zeros(())
    dependence_values = []

    for i in range(levels - 1):
        level_i = tensor[i].reshape(-1, features)
        level_next = tensor[i + 1].reshape(-1, features)

        dcorr = distance_correlation(level_i, level_next)
        dependence_values.append(dcorr.item())

        if i < n_targets:
            loss = loss + (dcorr - target_dependence[i]) ** 2

    return loss, dependence_values

import torch
import numpy as np
from sklearn.cross_decomposition import CCA

def hsic_loss(tensor, sigma=None):
    """
    Sum of negative HSIC values over adjacent level pairs of a tensor.

    Args:
        tensor: Input tensor with shape (levels, batch_size, features).
        sigma: Kernel width shared by all pairs (float), or None to estimate a
            separate width for every adjacent pair. The estimate is the mean of
            the two levels' median pairwise distances (diagonal zeros included),
            each divided by sqrt(2 * log(batch_size)).

    Returns:
        (loss, hsic_values): ``loss`` is a scalar tensor, the sum of the
        per-pair -HSIC values (lower means more dependence). ``hsic_values``
        lists each pair's value as a float.
    """
    levels, batch_size, features = tensor.shape
    loss = tensor.new_zeros(())
    hsic_values = []
    # Median-heuristic divisor; the same for every pair.
    median_scale = np.sqrt(2 * np.log(batch_size))

    for i in range(levels - 1):
        level_i = tensor[i].reshape(batch_size, features)
        level_next = tensor[i + 1].reshape(batch_size, features)

        if sigma is None:
            # Store the estimate in a per-pair local. Assigning it to `sigma`
            # would make every later pair silently reuse the first pair's width.
            sigma_i = torch.median(torch.cdist(level_i, level_i, p=2)) / median_scale
            sigma_next = torch.median(torch.cdist(level_next, level_next, p=2)) / median_scale
            pair_sigma = (sigma_i + sigma_next) / 2
        else:
            pair_sigma = sigma

        level_hsic = hsic_loss_level(level_i, level_next, pair_sigma)
        loss = loss + level_hsic
        hsic_values.append(level_hsic.item())

    return loss, hsic_values

def hsic_loss_level(x, y, sigma):
    """
    Negative HSIC (Hilbert-Schmidt Independence Criterion) between two sample sets.

    Uses Gaussian kernels k(a, b) = exp(-||a - b||^2 / (2 sigma^2)) with the
    same width for x and y and returns

        -(1 / n^2) * trace(Kx H Ky H),  with  H = I - (1/n) 1 1^T.

    The sign is flipped so that minimizing the result maximizes dependence.

    Args:
        x: (n, p) tensor of samples.
        y: (n, q) tensor with the same n, dtype and device as x.
        sigma: Kernel width (float or 0-dim tensor).

    Returns:
        0-dim tensor holding -HSIC.
    """
    n = x.shape[0]

    # Gaussian kernel
    def gaussian_kernel(z, sigma):
        pairwise_dists = torch.cdist(z, z, p=2)
        return torch.exp(-pairwise_dists**2 / (2 * sigma**2))

    Kx = gaussian_kernel(x, sigma)
    Ky = gaussian_kernel(y, sigma)

    # Build the centering matrix in x's dtype and on x's device. A default
    # float32 CPU identity makes the matmul fail for float64 or GPU inputs.
    eye = torch.eye(n, dtype=x.dtype, device=x.device)
    ones = torch.ones((n, n), dtype=x.dtype, device=x.device)
    H = eye - (1.0 / n) * ones
    HSIC = (1.0 / (n**2)) * torch.trace(Kx @ H @ Ky @ H)
    return -HSIC  # Negative HSIC to maximize dependence



def kcca_loss(tensor, n_components=1):
    """
    Negative canonical correlation between adjacent levels of a tensor.

    NOTE(review): despite the name, this uses scikit-learn's linear ``CCA``,
    not a kernel CCA. Each level is detached to NumPy, so the returned loss
    carries no gradient back to ``tensor``.

    Args:
        tensor: Input tensor with shape (levels, batch_size, features).
        n_components: Number of canonical components to fit and sum over.

    Returns:
        (loss, kcca_values): ``loss`` is the sum over adjacent pairs of the
        negated summed canonical correlations (a torch tensor, or 0.0 with
        fewer than two levels). ``kcca_values`` lists each pair's term as a float.
    """
    levels, batch_size, features = tensor.shape
    loss = 0.0
    kcca_values = []

    for lower, upper in zip(tensor[:-1], tensor[1:]):
        lower_np = lower.reshape(batch_size, features).detach().cpu().numpy()
        upper_np = upper.reshape(batch_size, features).detach().cpu().numpy()

        lower_scores, upper_scores = (
            CCA(n_components=n_components).fit(lower_np, upper_np).transform(lower_np, upper_np)
        )
        pair_total = sum(
            np.corrcoef(lower_scores[:, k], upper_scores[:, k])[0, 1]
            for k in range(n_components)
        )

        level_kcca = -torch.tensor(pair_total)  # negated: minimizing maximizes correlation
        loss += level_kcca
        kcca_values.append(level_kcca.item())

    return loss, kcca_values





#####################################################################################################
# Define the model: the tensor being optimized is itself the model parameter.
class CorrelationModel(nn.Module):
    """Holds one learnable (levels, batch_size, features) tensor and returns it.

    ``forward()`` takes no inputs; the optimizer adjusts ``transform`` directly.

    Args:
        levels, batch_size, features: shape of the learnable tensor.
        init_scale: multiplier for the standard-normal initialization. The
            default 1.0 keeps plain ``torch.randn``; e.g. 0.1 starts from a
            smaller random tensor.
    """

    def __init__(self, levels, batch_size, features, init_scale=1.0):
        super().__init__()
        self.levels = levels
        self.batch_size = batch_size
        self.features = features
        self.transform = nn.Parameter(torch.randn(levels, batch_size, features) * init_scale)

    def forward(self):
        """Return the learnable tensor itself (the Parameter, not a copy)."""
        return self.transform

# ---- Hyperparameters ---------------------------------------------------
levels, batch_size, features = 6, 256, 5
learning_rate = 0.0025
num_epochs = 500

# ---- Model and optimizer -----------------------------------------------
model = CorrelationModel(levels, batch_size, features)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# ---- Per-epoch metric histories (the training loop appends once per epoch)
correlation_history = []  # correlations between level 1 and the later levels
independence_history = []  # per level, average within-level feature correlation
mutual_information_history = []  # MI estimates between level 1 and the later levels
adjacent_dependence_history = []  # distance correlation between adjacent levels

# ---- Per-epoch loss histories ------------------------------------------
corr_loss_history, indep_loss_history, total_loss_history = [], [], []
mi_loss_history, adj_loss_history = [], []

for epoch in range(num_epochs):
    optimizer.zero_grad()
    output = model()  # the raw (levels, batch_size, features) parameter

    # Standardize every level over all of its entries (batch and feature axes
    # together). The 1e-8 keeps the division finite if a level collapses.
    level_mean = output.mean(dim=(1, 2), keepdim=True)
    level_std = output.std(dim=(1, 2), keepdim=True)
    normalized_output = (output - level_mean) / (level_std + 1e-8)

    # Compute the losses. The MI loss is compared against the MI targets
    # (target_MIs), not against the raw correlation targets.
    corr_loss, correlations = correlation_loss(normalized_output, target_correlations)
    indep_loss, avg_independence_corrs = independence_loss(normalized_output)
    mi_loss, mutual_infos = mutual_information_loss(normalized_output, target_MIs)
    adj_loss, dependence_values = adjacent_level_dependence_loss(normalized_output, target_dependence)

    # Alternative dependence measures, currently disabled:
    # hsic, hsic_values = hsic_loss(normalized_output)
    # kcca, kcca_values = kcca_loss(normalized_output)

    # mutual_information_loss works on detached NumPy copies, so mi_loss shifts
    # the reported total but contributes no gradient.
    # (An earlier weighting: corr + 0.1 * indep + 0.1 * mi + 0.1 * adj.)
    total_loss = corr_loss + 0.1 * indep_loss + mi_loss + adj_loss

    # Backward pass and optimization
    total_loss.backward()
    optimizer.step()

    # Record per-epoch metrics for plotting
    correlation_history.append(correlations)
    independence_history.append(avg_independence_corrs)
    mutual_information_history.append(mutual_infos)
    adjacent_dependence_history.append(dependence_values)

    # Record per-epoch losses
    corr_loss_history.append(corr_loss.item())
    indep_loss_history.append(indep_loss.item())
    total_loss_history.append(total_loss.item())
    mi_loss_history.append(float(mi_loss))
    adj_loss_history.append(adj_loss.item())

    # Print progress
    if (epoch + 1) % 10 == 0:
        print(
            f"\nEpoch [{epoch + 1}/{num_epochs}], "
            f"Correlation Loss: {corr_loss.item():.4f}, "
            f"Independence Loss: {indep_loss.item():.4f}, "
            f"Mutual Information Loss: {float(mi_loss):.4f}, "
            f"Adjacent Dependence Loss: {adj_loss.item():.4f}, "
            f"Total Loss: {total_loss.item():.4f}"
        )

###
# Loss curves: one subplot per loss term, plus the total.
loss_panels = [
    (corr_loss_history, "Correlation Loss", "blue"),
    (indep_loss_history, "Independence Loss", "orange"),
    (mi_loss_history, "MI Loss", "red"),
    (adj_loss_history, "Adjacent Level Loss", "magenta"),
    (total_loss_history, "Total Loss", "green"),
]

fig, axs = plt.subplots(1, len(loss_panels), figsize=(18, 6))
for ax, (history, name, color) in zip(axs, loss_panels):
    # The legend label names the curve actually drawn on this axis.
    ax.plot(history, label=name, color=color, linewidth=4)
    ax.set_title(f"{name} Over Epochs")
    ax.set_xlabel("Epochs")
    ax.set_ylabel("Loss")
    ax.legend()
    ax.grid()

plt.tight_layout()
plt.show()


# #################################################
# Convert the per-epoch metric lists to tensors so each column can be sliced
# out as a time series (row = epoch).
correlation_history = torch.tensor(correlation_history)  # (num_epochs, len(target_correlations))
independence_history = torch.tensor(independence_history)  # (num_epochs, levels)
mutual_information_history = torch.tensor(mutual_information_history)  # (num_epochs, number of MI targets)
adjacent_dependence_history = torch.tensor(adjacent_dependence_history)  # (num_epochs, levels - 1)

# Correlation between level 1 and every other level. The target lines are
# taken from target_correlations (not hard-coded) so they stay in sync.
colors = ['r', 'g', 'b', 'orange', 'purple']  # one per target, cycled if more

plt.figure(figsize=(12, 6))

for i, target in enumerate(target_correlations):
    plt.plot(correlation_history[:, i],
             label=f"Level {i + 2} (Target: {target})",
             linewidth=2,
             color=colors[i % len(colors)])

for i, target in enumerate(target_correlations):
    plt.axhline(y=target, color=colors[i % len(colors)], linestyle='--',
                label=f"Target Correlation (Level {i + 2})", linewidth=2)

plt.title("Correlation Between Level 1 and Other Levels Over Epochs")
plt.xlabel("Epochs")
plt.ylabel("Correlation")
plt.legend()
plt.grid()
plt.show()

# # ##########################################
# Average within-level feature correlation over epochs, one curve per level.
plt.figure(figsize=(12, 6))
level_curves = [(independence_history[:, k], f"Level {k + 1}") for k in range(levels)]
for curve, curve_label in level_curves:
    plt.plot(curve, label=curve_label, linewidth=2)

plt.title("Average Independence Correlations (Within Levels) Over Epochs")
plt.xlabel("Epochs")
plt.ylabel("Average Correlation")
plt.legend()
plt.grid()
plt.show()


# # ########################################
# Mutual information between level 1 and every other level. The target lines
# are taken from target_MIs (not hard-coded) so they stay in sync.
colors = ['r', 'g', 'b', 'orange', 'purple']  # one per target, cycled if more
plt.figure(figsize=(12, 6))

for i, target in enumerate(target_MIs):
    plt.plot(mutual_information_history[:, i],
             label=f"Level {i + 2} (Target: {target:.2f})",
             linewidth=2,
             color=colors[i % len(colors)])

for i, target in enumerate(target_MIs):
    plt.axhline(y=target, color=colors[i % len(colors)], linestyle='--',
                label=f"Target MI (Level {i + 2})", linewidth=2)

plt.title("MI Between Level 1 and Other Levels Over Epochs")
plt.xlabel("Epochs")
plt.ylabel("MI")
plt.legend()
plt.grid()
plt.show()

# ##########################################
# Level-pair MI plots. These need `mi_between_levels_history`: one
# (levels x levels) MI matrix per epoch, indexable as matrix[i, j] (presumably
# a NumPy array or tensor). The training loop above never records it, so
# running these plots unguarded raises NameError. They run only if that
# history exists in the session. The heatmap is drawn with matplotlib because
# seaborn is not imported.
mi_between_levels_history = globals().get("mi_between_levels_history")

if mi_between_levels_history is not None and len(mi_between_levels_history) > 0:
    level_names = [f"Level {i + 1}" for i in range(levels)]

    # Heatmap of the last epoch's MI matrix, annotated with 2-decimal values.
    last_mi = np.asarray(mi_between_levels_history[-1], dtype=float)
    heat_fig, heat_ax = plt.subplots(figsize=(10, 8))
    heat_image = heat_ax.imshow(last_mi, cmap="coolwarm")
    heat_fig.colorbar(heat_image, ax=heat_ax)
    heat_ax.set_xticks(range(levels))
    heat_ax.set_xticklabels(level_names)
    heat_ax.set_yticks(range(levels))
    heat_ax.set_yticklabels(level_names)
    for r in range(last_mi.shape[0]):
        for c in range(last_mi.shape[1]):
            heat_ax.text(c, r, f"{last_mi[r, c]:.2f}", ha="center", va="center")
    heat_ax.set_title("Mutual Information Between Levels (Last Epoch)")
    heat_ax.set_xlabel("Levels")
    heat_ax.set_ylabel("Levels")
    plt.show()

    # MI for every unique pair of levels over epochs.
    level_pairs = [(i, j) for i in range(levels) for j in range(i + 1, levels)]
    mi_over_epochs = {
        pair: [mi_matrix[pair[0], pair[1]] for mi_matrix in mi_between_levels_history]
        for pair in level_pairs
    }

    plt.figure(figsize=(12, 6))
    for pair, mi_values in mi_over_epochs.items():
        plt.plot(mi_values, label=f"Level {pair[0] + 1} vs Level {pair[1] + 1}", linewidth=2)
    plt.title("Mutual Information Between Level Pairs Over Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("Mutual Information")
    plt.legend()
    plt.grid()
    plt.show()

    # Same series restricted to adjacent levels: (1, 2), (2, 3), ...
    adjacent_level_pairs = [(i, i + 1) for i in range(levels - 1)]
    mi_over_epochs_adjacent = {pair: mi_over_epochs[pair] for pair in adjacent_level_pairs}

    plt.figure(figsize=(12, 6))
    for pair, mi_values in mi_over_epochs_adjacent.items():
        plt.plot(mi_values, label=f"Level {pair[0] + 1} vs Level {pair[1] + 1}", linewidth=2)
    plt.title("Mutual Information Between Adjacent Level Pairs Over Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("Mutual Information")
    plt.legend()
    plt.grid()
    plt.show()
else:
    print("Skipping level-pair MI plots: 'mi_between_levels_history' is not defined "
          "(the training loop does not record pairwise MI matrices).")



#################################################
# Adjacent-level dependence: distance correlation between level i and level
# i + 1 over epochs. Each pair's target is drawn as a dashed line in the same color.
colors = ['r', 'g', 'b', 'orange', 'purple']  # one per pair, cycled if more

plt.figure(figsize=(12, 8))
num_pairs = levels - 1

for i in range(num_pairs):
    dependence_by_level = [dependence[i] for dependence in adjacent_dependence_history]
    plt.plot(dependence_by_level, linewidth=4, color=colors[i % len(colors)],
             label=f'Dependence (Level {i + 1} vs Level {i + 2})')

# adjacent_level_dependence_loss only penalizes pairs that have a target, so
# only those pairs get a target line (this also avoids an IndexError).
for i, target in enumerate(target_dependence[:num_pairs]):
    plt.axhline(y=target, color=colors[i % len(colors)], linestyle='--', linewidth=2,
                label=f'Target Dependence (Level {i + 1} vs Level {i + 2})')

plt.title('Dependence Between Adjacent Levels')
plt.xlabel('Epochs')
plt.ylabel('Dependence distance_correlation')

# One legend in the top-left corner: dependence curves first, then targets.
plt.legend(loc='upper left')

plt.grid()
plt.tight_layout()  # Adjust layout to make room for the legend
plt.show()