# In [None]:
# HDI-v5 - AE-enabled ==> Uses MNIST as input and runs on the GPU when one is available
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

############################################################ Target Values
# Desired Pearson correlation between level 1 and each following level
# (entry i is compared against level i + 2).
target_correlations = [0.9, 0.7, 0.5, 0.3, 0.1]

# Desired distance correlation between each pair of adjacent levels
# (entry i is compared against the pair (level i + 1, level i + 2)); tune as needed.
target_dependence = [0.85, 0.65, 0.45, 0.25, 0.15]


# For jointly Gaussian variables, correlation and mutual information are in
# one-to-one correspondence:  MI = -(dim / 2) * ln(1 - rho^2)   (natural log, nats).
def rho_to_mi(rho, dim):
    """Return the Gaussian mutual information (nats) implied by correlation ``rho``."""
    half_dim = dim / 2
    return -half_dim * np.log(1 - rho ** 2)


def mi_to_rho(mi, dim):
    """Inverse of ``rho_to_mi``: the non-negative correlation implied by MI ``mi`` (nats)."""
    exponent = -2 * mi / dim
    return np.sqrt(1 - np.exp(exponent))


# Target MI for every target correlation, treating each level as a single
# Gaussian variable (dim=1) and rounding to two decimals.
# NOTE(review): the rounding noticeably shifts the smallest targets
# (rho = 0.1 gives ~0.005 nats, stored as 0.01) — confirm this is intended.
target_MIs = [round(rho_to_mi(rho, dim=1), 2) for rho in target_correlations]
############################################################# Dataset


# Image pipeline: convert to a tensor, then apply (x - 0.5) / 0.5 so that the
# pixel range becomes [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST train/test splits stored under ./data (downloaded if not already present).
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)

# Mini-batches of 256 samples; only the training split is shuffled.
train_loader = DataLoader(dataset=train_dataset, batch_size=256, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=256, shuffle=False)

############################################################# CUDA Setup
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

######################################################################
# ############################################################# Losses
# Define the correlation loss
def correlation_loss(tensor, target_correlations):
    """Penalise deviation of Pearson correlations between level 1 and the other levels.

    Every level ``tensor[k]`` is flattened to a 1-D vector (batch and feature
    entries pooled together). The Pearson correlation between level 1
    (``tensor[0]``) and ``tensor[i + 1]`` is compared with
    ``target_correlations[i]``.

    Args:
        tensor: Stacked levels. ``tensor[0]`` is the reference level, so there
            must be at least ``len(target_correlations) + 1`` levels.
        target_correlations: Sequence of target correlation values.

    Returns:
        tuple: ``(loss, correlations)``.
            ``loss`` is the sum of ``(corr - target) ** 2`` (a tensor when at
            least one target is given, otherwise ``0.0``).
            ``correlations`` lists the measured correlations as Python floats.
    """
    # Level 1 is the same for every comparison: centre it and measure its
    # spread once, outside the loop.
    ref = tensor[0].reshape(-1)
    ref_centered = ref - ref.mean()
    # Population (1/N) standard deviation, to match the 1/N covariance below.
    # Mixing it with the unbiased (1/(N-1)) .std() would shrink every
    # correlation by (N-1)/N, so a perfectly correlated pair could never reach 1.
    ref_std = ref_centered.pow(2).mean().sqrt()

    loss = 0.0
    correlations = []
    for i, target_corr in enumerate(target_correlations):
        other = tensor[i + 1].reshape(-1)
        other_centered = other - other.mean()
        other_std = other_centered.pow(2).mean().sqrt()

        cov = torch.mean(ref_centered * other_centered)
        corr = cov / (ref_std * other_std + 1e-8)  # eps guards constant levels
        correlations.append(corr.item())
        loss += (corr - target_corr) ** 2

    return loss, correlations

# Define the independence loss
def independence_loss(tensor):
    """Penalise correlation between distinct features within each level.

    For every level, the Pearson correlation of each pair of distinct feature
    columns (computed across the batch) is squared and added to the loss.

    Args:
        tensor: 3-D tensor ``(levels, batch_size, features)``.

    Returns:
        tuple: ``(loss, avg_correlations)``.
            ``loss`` is a 0-d tensor: the sum of squared pairwise correlations
            over all levels.
            ``avg_correlations`` holds, for each level, the mean signed pairwise
            correlation as a Python float. It is 0.0 when a level has fewer than
            two features, because there are no pairs.
    """
    levels, batch_size, features = tensor.shape
    if features < 2:
        # No feature pairs exist, so there is nothing to decorrelate.
        return tensor.new_zeros(()), [0.0] * levels

    # Centre each feature column within its level.
    centered = tensor - tensor.mean(dim=1, keepdim=True)          # (L, B, F)
    # Population (1/N) covariance and standard deviation. Using the same
    # normalisation in both lets |corr| actually reach 1.
    cov = centered.transpose(1, 2) @ centered / batch_size         # (L, F, F)
    std = centered.pow(2).mean(dim=1).sqrt()                       # (L, F)
    corr = cov / (std.unsqueeze(2) * std.unsqueeze(1) + 1e-8)      # (L, F, F)

    # Keep each unordered pair (i < j) exactly once.
    rows, cols = torch.triu_indices(features, features, offset=1, device=tensor.device)
    pair_corrs = corr[:, rows, cols]                               # (L, n_pairs)

    loss = pair_corrs.pow(2).sum()
    avg_correlations = pair_corrs.mean(dim=1).detach().tolist()
    return loss, avg_correlations

# Define the mutual information loss
def mutual_information_loss(tensor, target_MIs):
    """Penalise deviation of estimated MI between level 1 and the other levels.

    The MI between flattened level 1 (``tensor[0]``) and each flattened level
    ``tensor[i + 1]`` is estimated with scikit-learn's ``mutual_info_regression``
    and compared with ``target_MIs[i]``.

    Args:
        tensor: Stacked levels. It needs at least ``len(target_MIs) + 1`` levels.
        target_MIs: Target MI values in nats, e.g. from ``rho_to_mi``, which
            uses the natural log.

    Returns:
        tuple: ``(loss, mutual_infos)``.
            ``loss`` is a float: an asymmetrically weighted squared error plus
            an L2 penalty on the estimates.
            ``mutual_infos`` lists the MI estimates in nats.
    """
    # NOTE(review): mutual_info_regression (sklearn.feature_selection) is not
    # imported anywhere in this file; add the import before enabling this loss.
    # The estimates are computed on detached NumPy copies, so this loss carries
    # no gradient and cannot steer the model if added to a training objective.
    level_1_flat = tensor[0].detach().cpu().numpy().reshape(-1, 1)  # reused for every level
    loss = 0.0
    mutual_infos = []
    lambda_reg = 0.01  # weight of the L2 penalty on the MI estimates

    for i, target_mi in enumerate(target_MIs):
        level_i_flat = tensor[i + 1].detach().cpu().numpy().reshape(-1)

        # The estimate is kept in nats so it is in the same unit as target_MIs.
        # Dividing by ln(2) would convert it to bits; that neither bounds it to
        # [0, 1] nor matches the targets.
        mi_value = mutual_info_regression(level_1_flat, level_i_flat, random_state=42)[0]

        mutual_infos.append(mi_value)
        # Undershooting the target is penalised twice as hard as overshooting.
        weight = 1.0 if mi_value < target_mi else 0.5
        loss += weight * (mi_value - target_mi) ** 2

    loss += lambda_reg * np.sum(np.square(mutual_infos))  # L2 regularisation

    return loss, mutual_infos

# Define the distance correlation function
def distance_correlation(x, y):
    """Sample distance correlation (Székely et al.) between paired observations.

    Args:
        x: 2-D tensor ``(n, p)``; row k is the k-th observation.
        y: 2-D tensor ``(n, q)`` with the same number of rows as ``x``.

    Returns:
        0-d tensor: ``dCov(x, y) / sqrt(dVar(x) * dVar(y))``. It is nominally
        in [0, 1]; the 1e-8 epsilon makes it 0 when either input has zero
        distance variance (e.g. n == 1).
    """
    # Distances are translation-invariant, so this does not change the result;
    # it likely helps cdist's matrix-multiply path numerically.
    x = x - x.mean()
    y = y - y.mean()

    # Pairwise Euclidean distances. Indexing [0] (instead of .squeeze()) keeps
    # the (n, n) shape even when n == 1.
    a = torch.cdist(x.unsqueeze(0), x.unsqueeze(0), p=2)[0]
    b = torch.cdist(y.unsqueeze(0), y.unsqueeze(0), p=2)[0]

    # Double centering: subtract the row and column means, add back the grand mean.
    A = a - a.mean(dim=0) - a.mean(dim=1).unsqueeze(1) + a.mean()
    B = b - b.mean(dim=0) - b.mean(dim=1).unsqueeze(1) + b.mean()

    # mean(A * B) is non-negative in exact arithmetic but can come out as a
    # tiny negative in floating point; clamp so sqrt never produces NaN.
    dcov = torch.sqrt((A * B).mean().clamp_min(0.0))
    dvar_x = torch.sqrt((A * A).mean())
    dvar_y = torch.sqrt((B * B).mean())

    return dcov / (torch.sqrt(dvar_x * dvar_y) + 1e-8)

# Define the adjacent level dependence loss
def adjacent_level_dependence_loss(tensor, target_dependence):
    """Match the distance correlation of each adjacent level pair to a target.

    Args:
        tensor: 3-D tensor ``(levels, batch_size, features)``.
        target_dependence: Target distance correlations. Entry ``k`` is used
            for the pair ``(tensor[k], tensor[k + 1])``; pairs beyond its
            length are measured but not penalised.

    Returns:
        tuple: ``(loss, dependence_values)``.
            ``loss`` is the sum of squared deviations from the available targets.
            ``dependence_values`` holds every pair's distance correlation as a
            Python float.
    """
    n_levels, _, n_features = tensor.shape
    loss = 0.0
    dependence_values = []

    for idx in range(n_levels - 1):
        current = tensor[idx].reshape(-1, n_features)
        following = tensor[idx + 1].reshape(-1, n_features)

        dcorr = distance_correlation(current, following)
        dependence_values.append(dcorr.item())

        if idx >= len(target_dependence):
            continue  # no target for this pair: record it only
        loss += (dcorr - target_dependence[idx]) ** 2

    return loss, dependence_values
########################################################################## Models
# Define the hierarchical bottleneck model
class HDI_model(nn.Module):
    """Hierarchical stack of residual linear levels built on top of an encoder output.

    Level 1 is the encoder output itself. Each following level is
    ``dropout(activation(LayerNorm(Linear(prev) + prev)))`` of the level before it.
    """

    def __init__(self, levels, features, dropout_rate=0.1, activation=nn.ReLU):
        """
        Args:
            levels (int): Total number of hierarchical levels, including the input level.
            features (int): Width of every level (input and output size of each Linear).
            dropout_rate (float): Dropout probability applied after each activation.
            activation (type): Activation module class; instantiated once with no arguments.
        """
        super().__init__()
        self.levels = levels
        self.features = features
        self.activation = activation()

        n_transforms = levels - 1  # the first level passes through unchanged
        self.linear_layers = nn.ModuleList(
            [nn.Linear(features, features) for _ in range(n_transforms)]
        )
        self.norm_layers = nn.ModuleList(
            [nn.LayerNorm(features) for _ in range(n_transforms)]
        )
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, encoder_output):
        """
        Build all hierarchical levels from the encoder output.

        Args:
            encoder_output (torch.Tensor): ``(batch_size, features)`` tensor.
        Returns:
            torch.Tensor: ``(levels, batch_size, features)`` stack whose first slice is ``encoder_output``.
        """
        stack = [encoder_output]
        current = encoder_output
        for linear, norm in zip(self.linear_layers, self.norm_layers):
            pre_norm = linear(current) + current  # residual connection
            current = self.dropout(self.activation(norm(pre_norm)))
            stack.append(current)
        return torch.stack(stack)


class SimpleHDI(nn.Module):
    """Trivial bottleneck: every hierarchical level is a copy of the input."""

    def __init__(self, levels, features):
        """
        Args:
            levels (int): How many identical levels to emit.
            features (int): Feature width of each level (stored only; forward does not read it).
        """
        super().__init__()
        self.levels = levels
        self.features = features

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): ``(batch_size, features)`` input.
        Returns:
            torch.Tensor: ``(levels, batch_size, features)`` tensor; every slice along dim 0 equals ``x``.
        """
        with_level_axis = x.unsqueeze(dim=0)  # (1, batch_size, features)
        repeat_counts = (self.levels, 1, 1)
        return with_level_axis.repeat(*repeat_counts)




########################################################################

# Autoencoder model
# class Autoencoder(nn.Module):
#     def __init__(self, input_dim, hidden_dim, BN_dim ):
#         super(Autoencoder, self).__init__()
#         # Encoder
#         self.encoder = nn.Sequential(
#             nn.Flatten(),
#             nn.Linear(28 * 28, hidden_dim[0]),
#             nn.ReLU(),
#             nn.Linear(hidden_dim[0], hidden_dim[1]),
#             nn.ReLU(),
#             nn.Linear(hidden_dim[1], latent_dim)
#         )
#         # Decoder
#         self.bottleneck=SimpleBottleneck(levels,BN_dim)  # Hierarchical bottleneck
#         self.decoder = nn.Sequential(
#             nn.Linear(latent_dim, hidden_dim[1]),
#             nn.ReLU(),
#             nn.Linear(hidden_dim[1], hidden_dim[0]),
#             nn.ReLU(),
#             nn.Linear(hidden_dim[0], input_dim),
#             nn.Tanh()  # Output normalized to [-1, 1]
#         )

#     def forward(self, x):
#         z = self.encoder(x)
#         HDI_levels = self.bottleneck(z)
#         #print('HDI_levels= ', HDI_levels.shape) # (levels,batch_size,128)
#         # x_reconstructed = self.decoder(z)
#         x_reconstructed = self.decoder(HDI_levels)
#         return x_reconstructed

#########

import torch
import torch.nn as nn
import torch.nn.functional as F

class Autoencoder2(nn.Module):
    """MLP autoencoder whose latent code is expanded into hierarchical HDI levels.

    The data flows through four stages:
    1. The encoder maps the flattened input to a ``BN_dim`` code.
    2. ``HDI_model`` expands that code into ``levels`` stacked representations.
    3. A single representation is extracted from the stack (learned attention,
       a fixed level, or manual weights).
    4. The decoder reconstructs the input from that representation.
    """

    def __init__(self, input_dim, hidden_dim, BN_dim, levels):
        """
        Args:
            input_dim (int): Size of the flattened input (784 for 28x28 MNIST).
            hidden_dim (Sequence[int]): Widths of the two hidden layers.
                ``hidden_dim[0]`` sits next to the input and ``hidden_dim[1]``
                next to the bottleneck.
            BN_dim (int): Bottleneck (latent) feature dimension.
            levels (int): Number of hierarchical levels produced by the bottleneck.
        """
        super().__init__()
        # Encoder: input_dim -> hidden_dim[0] -> hidden_dim[1] -> BN_dim.
        # Uses input_dim (not a hard-coded 28 * 28) so the encoder and decoder
        # agree on the input size.
        self.encoder = nn.Sequential(
            nn.Flatten(),
            nn.Linear(input_dim, hidden_dim[0]),
            nn.ReLU(),
            nn.Linear(hidden_dim[0], hidden_dim[1]),
            nn.ReLU(),
            nn.Linear(hidden_dim[1], BN_dim),  # Output is (batch_size, BN_dim)
        )

        # Bottleneck (HDI layer). Alternative: SimpleHDI(levels, BN_dim)
        self.bottleneck = HDI_model(levels, BN_dim)

        # Attention: one scalar score per level and sample, softmax-normalised
        # across levels (dim 0 of the (levels, batch_size, 1) score tensor).
        self.attention = nn.Sequential(
            nn.Linear(BN_dim, 1),
            nn.Softmax(dim=0),
        )

        # Decoder: BN_dim -> hidden_dim[1] -> hidden_dim[0] -> input_dim
        self.decoder = nn.Sequential(
            nn.Linear(BN_dim, hidden_dim[1]),
            nn.ReLU(),
            nn.Linear(hidden_dim[1], hidden_dim[0]),
            nn.ReLU(),
            nn.Linear(hidden_dim[0], input_dim),
            nn.Tanh(),  # Output in [-1, 1], matching the input normalisation
        )

    def forward(self, x, method="attention", fixed_level=0, manual_weights=None):
        """
        Encode, expand into HDI levels, combine the levels, and decode.

        Args:
            x (torch.Tensor): Input batch. nn.Flatten flattens every dim after
                the first, so ``(batch_size, input_dim)`` or image-shaped input
                both work.
            method (str): How to reduce ``(levels, batch_size, BN_dim)`` to
                ``(batch_size, BN_dim)``. One of "attention", "fixed", "weighted".
            fixed_level (int): Index of the level to select when ``method="fixed"``.
            manual_weights (torch.Tensor): ``(levels,)`` weights used when
                ``method="weighted"``; they are normalised to sum to 1.

        Returns:
            tuple: ``(x_reconstructed, HDI_levels)``.
                ``x_reconstructed`` has shape ``(batch_size, input_dim)``.
                ``HDI_levels`` is the bottleneck output, ``(levels, batch_size, BN_dim)``.

        Raises:
            ValueError: Unknown ``method``, or missing or mis-sized ``manual_weights``.
        """
        z = self.encoder(x)  # (batch_size, BN_dim)
        HDI_levels = self.bottleneck(z)  # (levels, batch_size, BN_dim)

        if method == "attention":
            attention_scores = self.attention(HDI_levels)  # (levels, batch_size, 1)
            weighted_output = (HDI_levels * attention_scores).sum(dim=0)  # (batch_size, BN_dim)

        elif method == "fixed":
            weighted_output = HDI_levels[fixed_level]  # (batch_size, BN_dim)

        elif method == "weighted":
            if manual_weights is None:
                raise ValueError("manual_weights must be provided when method='weighted'")
            if manual_weights.shape[0] != HDI_levels.shape[0]:
                raise ValueError("manual_weights must have the same length as the number of levels")

            # Normalise to sum to 1 and broadcast over (batch_size, BN_dim).
            manual_weights = manual_weights / manual_weights.sum()
            manual_weights = manual_weights.unsqueeze(-1).unsqueeze(-1)  # (levels, 1, 1)
            weighted_output = (HDI_levels * manual_weights).sum(dim=0)  # (batch_size, BN_dim)

        else:
            raise ValueError(f"Invalid method: {method}. Choose from 'attention', 'fixed', or 'weighted'.")

        x_reconstructed = self.decoder(weighted_output)  # (batch_size, input_dim)
        return x_reconstructed, HDI_levels
############################################################################### Configuration
# Hyperparameters
input_dim = 28 * 28       # flattened MNIST image size
hidden_dim = [256, 128]   # widths of the two hidden encoder/decoder layers
levels = 6                # number of hierarchical (HDI) levels
BN_dim = 5                # bottleneck feature dimension
learning_rate, num_epochs = 1e-3, 10  # Adam step size; epochs (kept small for demonstration)

################################################################## Initialization
# Build the autoencoder (with its HDI bottleneck) on the selected device, an MSE
# reconstruction criterion, and an Adam optimiser over all model parameters.
model = Autoencoder2(input_dim=input_dim, hidden_dim=hidden_dim, BN_dim=BN_dim, levels=levels)
model = model.to(device)
criterion = nn.MSELoss()  # reconstruction loss
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

#####################################################################

# Track metrics for plotting (one entry per training batch, not per epoch)
correlation_history = []  # Level-1 vs level-k correlations
independence_history = []  # Average within-level feature correlation, per level
mutual_information_history = []  # MI between levels (filled only if the MI loss is enabled)
adjacent_dependence_history = []  # Distance correlation between adjacent levels

# Loss histories (one entry per training batch)
corr_loss_history = []
indep_loss_history = []
mi_loss_history = []
adj_loss_history = []
total_loss_history = []

model.train()
for epoch in range(num_epochs):
    train_loss = 0.0  # sum of total_loss over the epoch's batches
    for images, _ in train_loader:
        images = images.to(device)
        images = images.view(images.size(0), -1)  # Flatten images

        # ---- Forward pass ------------------------------------------------
        # Attention-weighted combination of the hierarchical levels.
        output, HDI_levels = model(images, method="attention")

        # Alternatives (the model always returns (reconstruction, HDI_levels)):
        # output, HDI_levels = model(images, method="fixed", fixed_level=3)  # select level 4
        # manual_weights = torch.linspace(0.1, 0.6, levels, device=device)  # one weight per level
        # output, HDI_levels = model(images, method="weighted", manual_weights=manual_weights)

        # ---- Losses ------------------------------------------------------
        reconstruction_loss = criterion(output, images)
        corr_loss, correlations = correlation_loss(HDI_levels, target_correlations)
        indep_loss, avg_independence_corrs = independence_loss(HDI_levels)
        # mi_loss, mutual_infos = mutual_information_loss(HDI_levels, target_MIs)
        adj_loss, dependence_values = adjacent_level_dependence_loss(HDI_levels, target_dependence)

        total_loss = reconstruction_loss + corr_loss + 0.1 * indep_loss + 0.1 * adj_loss

        # ---- Backward pass and optimisation ------------------------------
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # ---- Bookkeeping -------------------------------------------------
        correlation_history.append(correlations)
        independence_history.append(avg_independence_corrs)
        # mutual_information_history.append(mutual_infos)
        adjacent_dependence_history.append(dependence_values)

        corr_loss_history.append(corr_loss.item())
        indep_loss_history.append(indep_loss.item())
        # mi_loss_history.append(mi_loss.item())
        adj_loss_history.append(adj_loss.item())
        total_loss_history.append(total_loss.item())

        train_loss += total_loss.item()

    # The individual losses shown are from the epoch's last batch; the average
    # total loss covers every batch of the epoch.
    avg_train_loss = train_loss / len(train_loader)
    print(f"Epoch [{epoch + 1}/{num_epochs}], "
          f"Reconstruction Loss: {reconstruction_loss.item():.4f}, "
          f"Correlation Loss: {corr_loss.item():.4f}, Independence Loss: {indep_loss.item():.4f}, "
          f"Adjacent Dependence Loss: {adj_loss.item():.4f}, "
          f"Total Loss: {total_loss.item():.4f}, "
          f"Avg Total Loss (epoch): {avg_train_loss:.4f}")


##########
# Show reconstructions for one batch of the test set.
model.eval()
with torch.no_grad():
    # A single batch is enough for the figure, so there is no need to run
    # inference over the whole test set and keep only its last batch.
    images, _ = next(iter(test_loader))
    images = images.to(device)
    images = images.view(images.size(0), -1)  # Flatten images
    outputs, _ = model(images)  # default method="attention"
images = images.view(-1, 28, 28).cpu().numpy()
outputs = outputs.view(-1, 28, 28).cpu().numpy()

# Plot original and reconstructed images
n = min(10, len(images))  # Number of images to display (never more than the batch holds)
plt.figure(figsize=(20, 4))
for i in range(n):
    # Original images (top row)
    plt.subplot(2, n, i + 1)
    plt.imshow(images[i], cmap='gray')
    plt.title("Original")
    plt.axis('off')

    # Reconstructed images (bottom row)
    plt.subplot(2, n, i + 1 + n)
    plt.imshow(outputs[i], cmap='gray')
    plt.title("Reconstructed")
    plt.axis('off')

plt.show()

####################################################
# Loss curves. Every history list receives one value per training batch, so
# the x-axis counts optimisation steps, not epochs.
losses = [
    (corr_loss_history, "Correlation Loss", "blue"),
    (indep_loss_history, "Independence Loss", "orange"),
    (adj_loss_history, "Adjacent Dependence Loss", "magenta"),
    (total_loss_history, "Total Loss", "green"),
]
# squeeze=False keeps axs 2-D, so indexing axs[0] works for any number of panels.
fig, axs = plt.subplots(1, len(losses), figsize=(18, 6), squeeze=False)

for ax, (history, title, color) in zip(axs[0], losses):
    ax.plot(history, color=color, linewidth=2)
    ax.set_title(title)
    ax.set_xlabel("Training step (batch)")
    ax.set_ylabel("Loss")
    ax.grid()

plt.tight_layout()
plt.show()

# Convert tracked metrics to tensors for easier plotting.
# Rows are training batches; columns are levels (or adjacent level pairs).
correlation_history = torch.tensor(correlation_history)
independence_history = torch.tensor(independence_history)
# mutual_information_history = torch.tensor(mutual_information_history)
adjacent_dependence_history = torch.tensor(adjacent_dependence_history)

# Colour palette, cycled when there are more curves than colours.
colors = ['r', 'g', 'b', 'orange', 'purple']

# Plot Correlations Between Levels (one curve per target correlation)
plt.figure(figsize=(12, 6))
for i, target in enumerate(target_correlations):
    color = colors[i % len(colors)]
    plt.plot(correlation_history[:, i], label=f"Level {i + 2} (Target: {target})", color=color, linewidth=2)
    plt.axhline(y=target, color=color, linestyle='--', linewidth=1.5)  # Target line
plt.title("Correlation Between Level 1 and Other Levels")
plt.xlabel("Training step (batch)")
plt.ylabel("Correlation")
plt.legend()
plt.grid()
plt.show()

# Plot Average Independence Correlations Within Levels (one curve per level)
plt.figure(figsize=(12, 6))
for i in range(independence_history.shape[1]):
    plt.plot(independence_history[:, i], label=f"Level {i + 1}", linewidth=2)
plt.title("Average Independence Correlations (Within Levels)")
plt.xlabel("Training step (batch)")
plt.ylabel("Average Correlation")
plt.legend()
plt.grid()
plt.show()

# Plot Mutual Information Between Levels
# plt.figure(figsize=(12, 6))
# for i, color in enumerate(colors):
#     plt.plot(mutual_information_history[:, i], label=f"Level {i + 2} (Target: {target_MIs[i]})", color=color, linewidth=2)
#     plt.axhline(y=target_MIs[i], color=color, linestyle='--', linewidth=1.5)  # Target line
# plt.title("Mutual Information Between Level 1 and Other Levels")
# plt.xlabel("Training step (batch)")
# plt.ylabel("MI")
# plt.legend()
# plt.grid()
# plt.show()

# Plot Dependence Between Adjacent Levels: one curve for every adjacent pair
# (levels - 1 of them), with a dashed target line wherever a target exists.
plt.figure(figsize=(12, 6))
for i in range(adjacent_dependence_history.shape[1]):
    color = colors[i % len(colors)]
    plt.plot(adjacent_dependence_history[:, i], label=f"Level {i + 1} vs Level {i + 2}", color=color, linewidth=2)
    if i < len(target_dependence):
        plt.axhline(y=target_dependence[i], color=color, linestyle='--', linewidth=1.5)  # Target line
plt.title("Dependence Between Adjacent Levels")
plt.xlabel("Training step (batch)")
plt.ylabel("Dependence (Distance Correlation)")
plt.legend()
plt.grid()
plt.show()