# Fed-MVKM Tutorial with DHA Dataset

<VSCode.Cell language="markdown">
# Federated Multi-View K-Means Clustering Tutorial
## Using the DHA (Depth-included Human Action) Dataset

This notebook demonstrates how to use the Fed-MVKM algorithm for federated multi-view clustering on the DHA dataset, which contains human action data captured with both RGB and depth sensors.

**Author:** Kristina P. Sinaga  
**Date:** May 2024  
**Version:** 1.0

### Overview:
1. Setup and Imports
2. Data Loading and Preprocessing
3. Client Data Distribution
4. Model Configuration and Training
5. Results Analysis
6. Visualization
7. Save Results

### Required Files:
- Depth_DHA.mat: Depth sensor data
- RGB_DHA.mat: RGB camera data
- label_DHA.mat: Ground truth labels
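- clients_MVDHA.mat: Multi-view data already distributed across clients (not loaded in this notebook; Section 3 builds its own partitions)
- clients_labelset_DHA.mat: Label sets for each client (not loaded in this notebook)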
</VSCode.Cell>

<VSCode.Cell language="markdown">
## 1. Setup and Imports

First, let's import all necessary libraries and set up our environment.
</VSCode.Cell>

<VSCode.Cell language="python">
import numpy as np
import torch
import scipy.io as sio
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler

# Import our Fed-MVKM implementation
from mvkm_ed import FedMVKMED, FedMVKMEDConfig
from mvkm_ed.utils import MVKMEDDataProcessor, MVKMEDMetrics, MVKMEDVisualizer

# Set random seed for reproducibility
np.random.seed(42)
torch.manual_seed(42)

# Enable CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
</VSCode.Cell>

<VSCode.Cell language="markdown">
## 2. Data Loading and Preprocessing

We'll now load the DHA dataset from the MAT files and preprocess it using our utility functions.
</VSCode.Cell>
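
<VSCode.Cell language="markdown">
Before loading, it can help to confirm which variable names each MAT file actually contains, since the keys passed to `sio.loadmat(...)[...]` must match them exactly. The sketch below uses `scipy.io.whosmat` for that. The helper name `list_mat_variables` and the throwaway `demo.mat` file are hypothetical and only illustrate the call; they are not part of the DHA data.

```python
import tempfile
from pathlib import Path

import numpy as np
import scipy.io as sio

def list_mat_variables(mat_path):
    """Return {variable_name: shape} for a MAT file without loading the arrays."""
    return {name: shape for name, shape, _dtype in sio.whosmat(str(mat_path))}

# Demo on a small synthetic file (an assumption for illustration, not DHA data).
with tempfile.TemporaryDirectory() as tmp:
    demo_path = Path(tmp) / "demo.mat"
    sio.savemat(str(demo_path), {"demo_view": np.zeros((5, 3))})
    demo_vars = list_mat_variables(demo_path)

print(demo_vars)
```

Pointing the same helper at `data_dir / 'RGB_DHA.mat'` would show whether the key `'RGB_DHA'` exists before the loader indexes into it.
</VSCode.Cell>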

<VSCode.Cell language="python">
def load_mat_data(data_dir: Path):
    """Load MAT files containing DHA dataset."""
    # Load raw data
    rgb_data = sio.loadmat(data_dir / 'RGB_DHA.mat')['RGB_DHA']
    depth_data = sio.loadmat(data_dir / 'Depth_DHA.mat')['Depth_DHA']
    labels = sio.loadmat(data_dir / 'label_DHA.mat')['label_DHA'].ravel()
    
    print(f"RGB data shape: {rgb_data.shape}")
    print(f"Depth data shape: {depth_data.shape}")
    print(f"Number of samples: {len(labels)}")
    print(f"Number of unique classes: {len(np.unique(labels))}")
    
    return rgb_data, depth_data, labels

# Set data directory
data_dir = Path("../data")  # Adjust this path as needed

# Load and preprocess data
rgb_data, depth_data, labels = load_mat_data(data_dir)

# Preprocess views
processor = MVKMEDDataProcessor()
views = processor.preprocess_views(
    [rgb_data, depth_data],
    scale=True,
    normalize=True
)

print("\nAfter preprocessing:")
for i, view in enumerate(views):
    print(f"View {i+1} shape: {view.shape}")
</VSCode.Cell>
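
<VSCode.Cell language="markdown">
The implementation of `MVKMEDDataProcessor.preprocess_views` is not shown in this notebook. As a rough mental model only, assuming `scale=True` means per-feature standardization and `normalize=True` means L2-normalizing each sample (both are assumptions about the library, and the helper name `standardize_and_normalize` is hypothetical), the effect on a single view might look like this:

```python
import numpy as np

def standardize_and_normalize(view, eps=1e-12):
    """Hypothetical stand-in: z-score each feature, then L2-normalize each row."""
    view = np.asarray(view, dtype=np.float64)
    # Per-feature standardization (zero mean, unit variance).
    mean = view.mean(axis=0)
    std = view.std(axis=0)
    std[std < eps] = 1.0  # constant features become zeros instead of dividing by 0
    scaled = (view - mean) / std
    # Per-sample L2 normalization.
    norms = np.linalg.norm(scaled, axis=1, keepdims=True)
    norms[norms < eps] = 1.0  # leave all-zero rows untouched
    return scaled / norms

# Synthetic data for illustration only.
rng = np.random.default_rng(0)
demo_view = rng.normal(loc=5.0, scale=3.0, size=(6, 4))
demo_out = standardize_and_normalize(demo_view)
print(np.linalg.norm(demo_out, axis=1))  # row norms after preprocessing
```

Each view is processed independently, so views with very different feature counts end up on a comparable scale before clustering.
</VSCode.Cell>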

<VSCode.Cell language="markdown">
## 3. Client Data Distribution

Now we'll split the data across clients to simulate a federated learning scenario.
</VSCode.Cell>

<VSCode.Cell language="python">
def create_client_partitions(views, labels, n_clients=2, balanced=True):
    """Split data into client partitions."""
    n_samples = views[0].shape[0]
    
    if balanced:
        # Create balanced partitions
        indices = np.random.permutation(n_samples)
        client_size = n_samples // n_clients
    else:
        # Create unbalanced partitions (more realistic scenario)
        ratios = np.random.dirichlet(np.ones(n_clients))
        client_sizes = (ratios * n_samples).astype(int)
        client_sizes[-1] = n_samples - client_sizes[:-1].sum()  # Ensure total = n_samples
        indices = np.random.permutation(n_samples)
        
    client_data = {}
    client_labels = {}
    start_idx = 0
    
    for i in range(n_clients):
        if balanced:
            end_idx = start_idx + client_size if i < n_clients - 1 else n_samples
        else:
            end_idx = start_idx + client_sizes[i]
            
        client_indices = indices[start_idx:end_idx]
        client_data[i] = [view[client_indices] for view in views]
        client_labels[i] = labels[client_indices]
        
        start_idx = end_idx
        
        print(f"Client {i+1}:")
        print(f"  Number of samples: {len(client_indices)}")
        print(f"  Classes present: {np.unique(client_labels[i])}")
        print()
    
    return client_data, client_labels

# Create client partitions
n_clients = 2  # Number of clients
client_data, client_labels = create_client_partitions(
    views, 
    labels, 
    n_clients=n_clients,
    balanced=False  # Use unbalanced partitions for more realistic scenario
)
</VSCode.Cell>
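
<VSCode.Cell language="markdown">
With `balanced=False`, the sizes come from truncating Dirichlet ratios to integers, so a client can occasionally receive very few samples or none at all. One possible safeguard, sketched here under the assumption that every client should hold at least `min_size` samples, is to reserve that minimum first and split only the remainder. The function name `dirichlet_client_sizes` and the `min_size` parameter are hypothetical and not part of `mvkm_ed`.

```python
import numpy as np

def dirichlet_client_sizes(n_samples, n_clients, min_size=1, rng=None):
    """Draw unbalanced client sizes that sum to n_samples, each >= min_size."""
    if n_clients * min_size > n_samples:
        raise ValueError("not enough samples to give every client min_size")
    rng = np.random.default_rng() if rng is None else rng
    # Reserve min_size per client, then split the remainder by Dirichlet ratios.
    spare = n_samples - n_clients * min_size
    ratios = rng.dirichlet(np.ones(n_clients))
    extra = np.floor(ratios * spare).astype(int)
    extra[-1] += spare - extra.sum()  # rounding leftovers go to the last client
    return extra + min_size

# Illustrative numbers only (500 samples, 4 clients).
demo_sizes = dirichlet_client_sizes(500, 4, min_size=10, rng=np.random.default_rng(42))
print(demo_sizes, demo_sizes.sum())
```

The returned sizes could replace `client_sizes` inside `create_client_partitions` without changing the rest of the slicing logic.
</VSCode.Cell>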

<VSCode.Cell language="markdown">
## 4. Model Configuration and Training

Let's set up the Fed-MVKM model with appropriate parameters and train it on our distributed data.
</VSCode.Cell>

<VSCode.Cell language="python">
# Configure the federated model
config = FedMVKMEDConfig(
    cluster_num=len(np.unique(labels)),  # Number of clusters
    points_view=len(views),              # Number of views (RGB and Depth)
    alpha=15.0,                          # View weight control
    beta=1.0,                            # Initial distance parameter
    gamma=0.04,                          # Model update rate
    max_iterations=10,                   # Maximum federation rounds
    convergence_threshold=1e-4,          # Convergence criterion
    random_state=42,                     # For reproducibility
    verbose=True                         # Show progress
)

# Initialize and train the model
model = FedMVKMED(config)
print("Starting federated training...")
model = model.fit(client_data)
print("Training complete!")
</VSCode.Cell>
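
<VSCode.Cell language="markdown">
To build intuition for what a federation round involves, here is a generic single-view federated k-means sketch: each client sends per-cluster sums and counts, and the server forms count-weighted centers. This is **not** Fed-MVKM's update rule. View weights and the `alpha`, `beta`, and `gamma` parameters are not modeled, and the names `local_update` and `server_aggregate` are hypothetical.

```python
import numpy as np

def local_update(X, centers):
    """Client step: assign points to the nearest center; return cluster sums and counts."""
    dists = ((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    assign = dists.argmin(axis=1)
    k, d = centers.shape
    sums = np.zeros((k, d))
    counts = np.zeros(k)
    np.add.at(sums, assign, X)    # accumulate points per cluster
    np.add.at(counts, assign, 1)  # count points per cluster
    return sums, counts

def server_aggregate(centers, client_stats):
    """Server step: count-weighted average; empty clusters keep their old center."""
    total_sums = sum(s for s, _ in client_stats)
    total_counts = sum(c for _, c in client_stats)
    new_centers = centers.copy()
    nonempty = total_counts > 0
    new_centers[nonempty] = total_sums[nonempty] / total_counts[nonempty][:, None]
    return new_centers

# Synthetic demo: two well-separated blobs split unevenly across two clients.
rng = np.random.default_rng(0)
blob_a = rng.normal(0.0, 0.5, size=(40, 2))
blob_b = rng.normal(5.0, 0.5, size=(40, 2))
clients = [np.vstack([blob_a[:30], blob_b[:10]]), np.vstack([blob_a[30:], blob_b[10:]])]
centers = np.array([[1.0, 1.0], [4.0, 4.0]])
for _ in range(5):
    stats = [local_update(X, centers) for X in clients]
    centers = server_aggregate(centers, stats)
print(np.round(centers, 2))
```

Only the aggregated sums and counts cross the client boundary in this sketch; the raw points in `clients` are never pooled.
</VSCode.Cell>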

<VSCode.Cell language="markdown">
## 5. Results Analysis

Now let's evaluate the clustering results against the ground-truth labels.
</VSCode.Cell>

<VSCode.Cell language="python">
# Get predictions for all clients
predictions = model.predict(client_data)

# Combine all predictions and true labels
all_predictions = []
all_true_labels = []

for client_id in predictions:
    all_predictions.append(predictions[client_id])
    all_true_labels.append(client_labels[client_id])

all_predictions = np.concatenate(all_predictions)
all_true_labels = np.concatenate(all_true_labels)

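# Compute clustering metrics.
# The partitions were shuffled, so rebuild each view in the same client order
# as the predictions before passing it to the metrics helper.
all_views = [
    np.concatenate([client_data[client_id][v] for client_id in predictions])
    for v in range(len(views))
]
metrics = MVKMEDMetrics.compute_metrics(
    all_views,
    all_predictions,
    all_true_labels
)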

print("\nClustering Performance Metrics:")
for metric, value in metrics.items():
    print(f"{metric:15s}: {value:.4f}")
</VSCode.Cell>
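
<VSCode.Cell language="markdown">
Cluster IDs are arbitrary, so comparing them with ground-truth labels requires a matching step. Which metrics `MVKMEDMetrics.compute_metrics` returns is not shown here. As an independent cross-check, a common choice is clustering accuracy with the best one-to-one mapping found by the Hungarian algorithm. The sketch below uses `scipy.optimize.linear_sum_assignment`; the function name `clustering_accuracy` is hypothetical and the labels are synthetic.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def clustering_accuracy(y_true, y_pred):
    """Best-match accuracy: map cluster IDs to class labels via the Hungarian algorithm."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    classes, true_idx = np.unique(y_true, return_inverse=True)
    clusters, pred_idx = np.unique(y_pred, return_inverse=True)
    # Contingency table: rows are predicted clusters, columns are true classes.
    table = np.zeros((len(clusters), len(classes)), dtype=np.int64)
    np.add.at(table, (pred_idx, true_idx), 1)
    # linear_sum_assignment minimizes cost, so negate to maximize matched counts.
    rows, cols = linear_sum_assignment(-table)
    return table[rows, cols].sum() / len(y_true)

# Synthetic example: cluster IDs are permuted and one sample is misassigned.
demo_true = np.array([1, 1, 1, 2, 2, 2, 3, 3, 3])
demo_pred = np.array([2, 2, 2, 0, 0, 1, 1, 1, 1])
demo_acc = clustering_accuracy(demo_true, demo_pred)
print(f"accuracy: {demo_acc:.4f}")
```

Calling `clustering_accuracy(all_true_labels, all_predictions)` would apply the same check to the notebook's results. Permutation-invariant scores such as `sklearn.metrics.normalized_mutual_info_score` and `sklearn.metrics.adjusted_rand_score` need no matching step.
</VSCode.Cell>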

<VSCode.Cell language="markdown">
## 6. Visualization

Let's create some visualizations to understand the model's behavior during training.
</VSCode.Cell>

<VSCode.Cell language="python">
# Create visualizer
visualizer = MVKMEDVisualizer(model)

# Set up the plotting area
plt.figure(figsize=(15, 5))

# Plot 1: Convergence Analysis
plt.subplot(1, 2, 1)
visualizer.plot_convergence()

# Plot 2: View Weight Evolution
plt.subplot(1, 2, 2)
visualizer.plot_view_weights()

plt.tight_layout()
plt.show()

# Plot client-specific objectives
plt.figure(figsize=(10, 6))
for client_id, objectives in model.history['client_objectives'].items():
    plt.plot(objectives, label=f'Client {client_id+1}')
plt.xlabel('Iteration')
plt.ylabel('Objective Value')
plt.title('Client-specific Convergence')
plt.legend()
plt.grid(True, alpha=0.3)
plt.show()
</VSCode.Cell>

<VSCode.Cell language="markdown">
## 7. Save Results

Finally, let's save our trained model and results for future use.
</VSCode.Cell>

<VSCode.Cell language="python">
from mvkm_ed.utils import MVKMEDPersistence
import json
from datetime import datetime

# Save the model
save_dir = Path("results")
save_dir.mkdir(exist_ok=True)

timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
model_path = save_dir / f"fed_mvkm_model_{timestamp}.pkl"
results_path = save_dir / f"fed_mvkm_results_{timestamp}.json"

# Save model
MVKMEDPersistence.save_model(model, str(model_path))

# Save metrics and parameters
results = {
    "metrics": metrics,
    "parameters": {
        "n_clients": n_clients,
        "n_clusters": config.cluster_num,
        "n_views": config.points_view,
        "alpha": config.alpha,
        "gamma": config.gamma,
    },
    "timestamp": timestamp
}

with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

print(f"Model saved to: {model_path}")
print(f"Results saved to: {results_path}")
</VSCode.Cell>
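
<VSCode.Cell language="markdown">
`json.dump` raises a `TypeError` on some NumPy types (for example `np.float32`, `np.int64`, or arrays). If the metrics dictionary contains such values, a small converter like the one sketched below can be applied before saving. The helper name `to_jsonable` is hypothetical and the demo values are synthetic.

```python
import json

import numpy as np

def to_jsonable(obj):
    """Recursively convert NumPy scalars and arrays into plain Python types."""
    if isinstance(obj, dict):
        return {str(k): to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj

# Synthetic values for illustration only.
demo_metrics = {
    "nmi": np.float32(0.5),
    "n_clusters": np.int64(3),
    "weights": np.array([0.25, 0.75]),
}
demo_json = json.dumps(to_jsonable(demo_metrics), indent=2)
demo_back = json.loads(demo_json)
print(demo_back)
```

In the cell above, this would amount to calling `json.dump(to_jsonable(results), f, indent=2)`.
</VSCode.Cell>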

<VSCode.Cell language="markdown">
## Conclusion

This notebook demonstrated the complete workflow of using the Fed-MVKM algorithm on the DHA dataset:
1. Data loading and preprocessing
2. Client data distribution
3. Model configuration and training
4. Results analysis and visualization
5. Model persistence

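The algorithm clustered the multi-view data across simulated clients. Because all partitions live in a single process here, the notebook illustrates the federated workflow rather than demonstrating privacy guarantees; in a real deployment, each client's raw data would stay local to that client.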
</VSCode.Cell>