In [1]:
import sys
from pathlib import Path

# Find project root by looking for .git or requirements.txt
current = Path.cwd()
while not any((current / marker).exists() for marker in ['.git', 'requirements.txt']):
    if current.parent == current:
        raise FileNotFoundError("Project root not found")
    current = current.parent

sys.path.append(str(current))
print(f"Added project root: {current}")

Added project root: /notebooks


In [2]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import random
import torch

# Set seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)
random.seed(42)

# Add project root to path for imports
# if not '..' in sys.path:
#     sys.path.append('..')

# Utils

In [3]:
def visualize_mixing_matrix(mixing_matrix, title="Ground Truth Mixing Matrix"):
    """Visualize the relationship between factors and latents."""
    plt.figure(figsize=(10, 8))
    sns.heatmap(mixing_matrix, annot=True, fmt=".2f", cmap="viridis",
                xticklabels=[f"z{i}" for i in range(mixing_matrix.shape[1])],
                yticklabels=[f"f{i}" for i in range(mixing_matrix.shape[0])])
    plt.title(title)
    plt.xlabel("Latent Dimensions")
    plt.ylabel("Ground Truth Factors")
    plt.tight_layout()
    plt.show()

def visualize_metric_results(results, metric_name="DCId"):
    """Visualize metric results across different disentanglement levels."""
    plt.figure(figsize=(12, 6))
    
    # Extract relevant metrics
    x = [r['level'] for r in results]
    metrics = {k: [r[k] for r in results] for k in results[0].keys() if k != 'level'}
    
    # Plot each metric
    for name, values in metrics.items():
        plt.plot(x, values, 'o-', label=name)
        
    plt.xlabel('Disentanglement Level of Generated Data')
    plt.ylabel('Metric Value')
    plt.title(f'{metric_name} Metrics vs True Disentanglement Level')
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()

# Artificial disentanglement data

In [4]:
def generate_synthetic_data(factor_sizes, num_latents, disentanglement_level, noise_level):
    """
    Generate synthetic data with controlled disentanglement properties.
    All possible combinations of factor values will be generated.
    
    Parameters:
    - factor_sizes: List where each element specifies the number of discrete values for that factor
                   (length of list determines number of factors)
    - num_latents: Number of latent dimensions
    - disentanglement_level: How disentangled the representation should be (0-1)
    - noise_level: Amount of noise to add
    
    Returns:
    - gt_factors: Ground truth factors [num_samples, num_factors]
    - latent_reps: Latent representations [num_samples, num_latents]
    - mixing_matrix: Matrix showing how factors map to latents
    """
    import itertools
    
    # Determine number of factors from the length of factor_sizes
    num_factors = len(factor_sizes)
    
    # Calculate total number of samples based on all combinations
    num_samples = np.prod(factor_sizes)
    
    # Generate evenly spaced values for each factor
    factor_values = []
    for i in range(num_factors):
        # Generate values from 0 to 1 with factor_sizes[i] discrete values
        factor_values.append(np.linspace(0, 1, factor_sizes[i]))
    
    # Generate all combinations of factor values
    all_combinations = list(itertools.product(*factor_values))
    
    # Create ground truth factors from all combinations
    gt_factors = np.array(all_combinations)
    
    # Create a mixing matrix that defines how factors map to latents
    # For perfect disentanglement, this would be an identity matrix (with zeros for any extra dimensions)
    mixing_matrix = np.zeros((num_factors, num_latents))
    
    # Set up the mixing based on disentanglement level
    for i in range(num_factors):
        # Primary dimension gets most of the weight based on disentanglement level
        primary_dim = i % num_latents  # In case we have more factors than latents
        mixing_matrix[i, primary_dim] = disentanglement_level
        
        # Distribute remaining weight across other dimensions
        remaining = 1.0 - disentanglement_level
        other_dims = [j for j in range(num_latents) if j != primary_dim]
        if other_dims:  # Check if there are other dimensions
            for j in other_dims:
                mixing_matrix[i, j] = remaining / len(other_dims)
    
    # Generate latent representations based on the mixing matrix
    latent_reps = np.dot(gt_factors, mixing_matrix)
    
    # Add some noise to make it more realistic
    latent_reps += np.random.normal(0, noise_level, latent_reps.shape)
    
    return gt_factors, latent_reps, mixing_matrix

In [5]:
factor_sizes = [10, 10, 10, 8, 5, 15]  # Number of discrete values for each factor
num_latents = 10  # Number of latent dimensions
disentanglement_level = 0.5  # Disentanglement level (0-1)
noise_level = 0.02  # Noise level to add to the latent representations
total_samples = np.prod(factor_sizes)  # Total number of samples based on all combinations

print("Total samples:", total_samples)

# Generate synthetic disentangled data
gt_factors, latent_reps, mixing_matrix = generate_synthetic_data(
    factor_sizes, num_latents, disentanglement_level, noise_level
)

print("gt factors shape:", gt_factors.shape, "latent reps shape:", latent_reps.shape, "mixing matrix shape:", mixing_matrix.shape)

Total samples: 600000
gt factors shape: (600000, 6) latent reps shape: (600000, 10) mixing matrix shape: (6, 10)


# DCI_d

Testing the Disentanglement, Completeness, and Informativeness (DCI) metric with synthetic data.

In [6]:
from metrics.dci_d import DCId


train_ratio_dci_d = 0.8

dci_d_num_train = int(train_ratio_dci_d * total_samples) 
dci_d_num_test = total_samples - dci_d_num_train

In [7]:
# Convert numpy arrays to torch tensors

# Initialize the DCId metric
dci_metric = DCId(num_train=dci_d_num_train, num_test=dci_d_num_test, backend='sklearn', num_workers=4)

# Compute the metric
dci_results = dci_metric(latent_reps, gt_factors)

for key, value in dci_results.items():
    print(f"{key}: {value}")


                                     

KeyboardInterrupt: 

# MIG

In [8]:
from metrics.mig import MIG

# Set the parameters for MIG calculation
mig_num_bins = 20
mig_num_workers = 8

# Initialize the MIG metric
mig_metric = MIG(num_bins=mig_num_bins, num_workers=mig_num_workers, mi_method='pyitlib', entropy_method='pyitlib')

# We can reuse the same tensors we used for DCI_d
# Compute the metric
mig_result = mig_metric(latent_reps, gt_factors)

print(f"MIG Score: {mig_result:.4f}") # TODO check the correctness of this metric

                                                                               

MIG Score: 0.5271


# Modularity_d

In [None]:
from metrics.modularity_d import Modularityd

# Set the parameters for Modularityd calculation
modularity_d_num_bins = 20
modularity_d_num_workers = 4

# Initialize the Modularityd metric
modularity_d_metric = Modularityd(num_bins=modularity_d_num_bins, num_workers=modularity_d_num_workers, mi_method='numpy', device='cpu')

# Compute the metric
modularity_d_result = modularity_d_metric(latent_reps, gt_factors)
print(f"Modularityd Score: {modularity_d_result:.4f}")


TypeError: BaseMetric.__init__() got an unexpected keyword argument 'device'