# Fed-MVKM Advanced Visualization and Diagnostic Tools

This notebook provides advanced visualization and diagnostic tools for analyzing the Fed-MVKM algorithm's behavior on the DHA dataset.

**Author:** Kristina P. Sinaga  
**Date:** May 2024  
**Version:** 1.0

## 1. Setup and Imports

First, let's import the necessary libraries and set up our visualization environment.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import scipy.io as sio
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
from typing import List, Dict, Tuple
import networkx as nx
from scipy.stats import entropy
from scipy.spatial.distance import pdist, squareform

# Import our Fed-MVKM implementation
from mvkm_ed import FedMVKMED, FedMVKMEDConfig
from mvkm_ed.utils import MVKMEDDataProcessor, MVKMEDMetrics, MVKMEDVisualizer

# Set style for static plots. Matplotlib 3.6 renamed its bundled seaborn
# styles to "seaborn-v0_8*" (deprecating the old "seaborn" name, which was
# removed in 3.8), so use whichever name this installation provides.
for _style in ('seaborn-v0_8', 'seaborn'):
    if _style in plt.style.available:
        plt.style.use(_style)
        break
sns.set_palette("husl")

# Configure notebook for plotly
import plotly.io as pio
pio.templates.default = "plotly_white"

## 2. Advanced Visualization Classes

Let's create specialized classes for different types of visualizations.

In [None]:
class ClusterVisualization:
    """Advanced cluster visualization tools.

    Args:
        views: One feature matrix per view. Each view is passed to
            ``TSNE``/``PCA.fit_transform`` and to ``pdist``, so it must be a
            2-D ``(n_samples, n_features)`` array. Feature counts may differ
            between views, but every view must cover the same samples in the
            same order.
        labels: Ground-truth label per sample, used for colouring.
        predictions: Predicted cluster per sample, used for colouring.
    """

    def __init__(self, views: List[np.ndarray], labels: np.ndarray,
                 predictions: np.ndarray):
        self.views = views
        self.labels = labels
        self.predictions = predictions
        self.n_views = len(views)

    @staticmethod
    def _view_correlation_matrix(views) -> np.ndarray:
        """Absolute correlation between the sample-distance structures of views.

        Views usually have different feature dimensionalities, so their raw
        feature matrices cannot be correlated element-wise. Each view is
        instead summarised by its condensed pairwise sample-distance vector
        (``pdist``), which has the same length for every view covering the
        same samples; those vectors are then correlated.

        Returns:
            ``(n_views, n_views)`` array of absolute correlations with a zero
            diagonal (so self-loops never appear). Undefined correlations
            (e.g. constant distances) are reported as 0.
        """
        n_views = len(views)
        if n_views == 0:
            return np.zeros((0, 0))
        dists = np.vstack([
            pdist(np.asarray(v, dtype=float).reshape(len(v), -1))
            for v in views
        ])
        with np.errstate(invalid='ignore', divide='ignore'):
            # atleast_2d: np.corrcoef returns a scalar for a single row.
            corr = np.atleast_2d(np.corrcoef(dists))
        corr = np.nan_to_num(np.abs(corr), nan=0.0)
        np.fill_diagonal(corr, 0.0)
        return corr

    def plot_cluster_embeddings(self, method='tsne', interactive=True):
        """Plot a 2-D embedding of every view.

        Args:
            method: ``'tsne'`` (case-insensitive) selects t-SNE; any other
                value uses PCA.
            interactive: If True, show one Plotly figure with a subplot per
                view and two layers ("True Labels", "Predictions") toggled
                from the legend. If False, show one Matplotlib figure per view
                coloured by the true labels.
        """
        use_tsne = method.lower() == 'tsne'
        fig = None
        if interactive:
            fig = make_subplots(
                rows=1, cols=self.n_views,
                subplot_titles=[f'View {i+1}' for i in range(self.n_views)])

        for i, view in enumerate(self.views):
            # Dimensionality reduction (fixed seed for reproducible layouts).
            if use_tsne:
                embedding = TSNE(n_components=2, random_state=42).fit_transform(view)
            else:
                embedding = PCA(n_components=2, random_state=42).fit_transform(view)

            # Create DataFrame for plotting
            df = pd.DataFrame({
                'x': embedding[:, 0],
                'y': embedding[:, 1],
                'True Label': self.labels,
                'Predicted': self.predictions
            })

            if interactive:
                first = i == 0
                # legendgroup ties every subplot's trace to the single legend
                # entry shown for view 1. Without it, the 'legendonly'
                # prediction traces of views 2..n could never be switched on.
                # Only view 1 draws colour bars so they do not pile up.
                true_trace = go.Scatter(
                    x=df['x'], y=df['y'],
                    mode='markers',
                    marker=dict(color=df['True Label'],
                                showscale=first,
                                colorscale='Viridis'),
                    name='True Labels',
                    legendgroup='true_labels',
                    showlegend=first
                )
                pred_trace = go.Scatter(
                    x=df['x'], y=df['y'],
                    mode='markers',
                    marker=dict(color=df['Predicted'],
                                showscale=first,
                                colorscale='Plasma',
                                # offset so it does not overlap the first bar
                                colorbar=dict(x=1.15)),
                    name='Predictions',
                    legendgroup='predictions',
                    visible='legendonly',
                    showlegend=first
                )
                fig.add_trace(true_trace, row=1, col=i+1)
                fig.add_trace(pred_trace, row=1, col=i+1)
            else:
                # Matplotlib static plot
                plt.figure(figsize=(6, 6))
                plt.scatter(df['x'], df['y'], c=df['True Label'],
                            cmap='viridis', alpha=0.6)
                plt.title(f'View {i+1} Clusters')
                plt.colorbar(label='Cluster')
                plt.show()

        if interactive:
            fig.update_layout(height=500, width=250*self.n_views,
                              title=f'Cluster Embeddings ({method.upper()})')
            fig.show()

    def plot_view_relationships(self):
        """Draw the views as a network whose edges encode inter-view correlation.

        An edge joins two views when the absolute correlation of their
        sample-distance structures exceeds 0.1. Edge width is 5x that
        correlation.
        """
        n_views = len(self.views)
        correlations = self._view_correlation_matrix(self.views)
        names = [f'View {i+1}' for i in range(n_views)]

        # Create network graph
        G = nx.Graph()
        G.add_nodes_from(names)
        for i in range(n_views):
            for j in range(i + 1, n_views):
                if correlations[i, j] > 0.1:  # Threshold for visibility
                    G.add_edge(names[i], names[j], weight=correlations[i, j])

        plt.figure(figsize=(8, 8))
        pos = nx.spring_layout(G)

        # Draw nodes
        nx.draw_networkx_nodes(G, pos, node_color='lightblue',
                               node_size=1000, alpha=0.7)

        # Draw edges with varying thickness
        weights = [G[u][v]['weight'] * 5 for u, v in G.edges()]
        nx.draw_networkx_edges(G, pos, width=weights, alpha=0.5)

        # Add labels
        nx.draw_networkx_labels(G, pos)

        plt.title('View Relationship Network\n(Edge thickness indicates correlation strength)')
        plt.axis('off')
        plt.tight_layout()
        plt.show()

## 3. Federation Analysis Tools

Tools for analyzing the federated learning process and client behavior.

In [None]:
class FederationAnalyzer:
    """Tools for analyzing federated learning behavior.

    Args:
        model: Fitted ``FedMVKMED`` instance. Its ``history`` is read by the
            plotting methods (``'client_objectives'`` and ``'center_updates'``
            entries).
        client_data: Per-client views keyed by client id. Each value is
            iterated as a sequence of view arrays.
        client_labels: Per-client label arrays, keyed by the same client ids
            as ``client_data``.
    """

    def __init__(self, model: FedMVKMED, client_data: Dict, client_labels: Dict):
        self.model = model
        self.client_data = client_data
        self.client_labels = client_labels
        self.history = model.history

    def plot_client_convergence_3d(self):
        """Plot each client's objective trajectory as a 3-D line.

        ``history['client_objectives']`` maps client id -> sequence of
        objective values. The x axis is the position in that sequence.
        """
        if 'client_objectives' not in self.history:
            print("No client objectives found in history")
            return

        fig = go.Figure()
        for client_id, objectives in self.history['client_objectives'].items():
            n_points = len(objectives)
            fig.add_trace(go.Scatter3d(
                x=np.arange(n_points),
                y=[client_id] * n_points,
                z=objectives,
                mode='lines+markers',
                name=f'Client {client_id}',
                line=dict(width=4),
                marker=dict(size=4)
            ))

        fig.update_layout(
            title='3D Client Convergence Trajectories',
            scene=dict(
                xaxis_title='Iteration',
                yaxis_title='Client ID',
                zaxis_title='Objective Value',
                camera=dict(
                    up=dict(x=0, y=0, z=1),
                    center=dict(x=0, y=0, z=0),
                    eye=dict(x=1.5, y=1.5, z=1.5)
                )
            ),
            width=800,
            height=800
        )

        fig.show()

    def analyze_client_diversity(self):
        """Compare clients by sample count, class count, class entropy and
        per-view feature statistics.

        The feature statistics for each view are the L2 norm of the feature
        mean and the L2 norm of the feature standard deviation.
        """
        # Per-client row: [client_id, n_samples, n_classes, entropy, *view_stats]
        stats = []
        for client_id, client_views in self.client_data.items():
            client_labels = self.client_labels[client_id]

            # Basic statistics
            n_samples = len(client_labels)
            _, class_counts = np.unique(client_labels, return_counts=True)
            n_classes = len(class_counts)
            class_entropy = entropy(class_counts / n_samples)

            # View statistics
            view_stats = []
            for view in client_views:
                view_stats.extend([np.linalg.norm(view.mean(axis=0)),
                                   np.linalg.norm(view.std(axis=0))])

            stats.append([client_id, n_samples, n_classes, class_entropy] + view_stats)

        client_names = [f'Client {s[0]}' for s in stats]

        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=(
                'Sample Distribution',
                'Class Distribution',
                'Feature Statistics',
                'Class Entropy'
            )
        )

        # Sample distribution
        fig.add_trace(
            go.Bar(x=client_names, y=[s[1] for s in stats], name='Samples'),
            row=1, col=1
        )

        # Class distribution
        fig.add_trace(
            go.Bar(x=client_names, y=[s[2] for s in stats], name='Classes'),
            row=1, col=2
        )

        # Feature statistics: one box per client over its view statistics.
        # (A single Box whose ``y`` is a list of lists does not yield
        # per-client boxes.)
        for name, s in zip(client_names, stats):
            fig.add_trace(
                go.Box(y=s[4:], name=name, boxpoints='all'),
                row=2, col=1
            )

        # Class entropy
        fig.add_trace(
            go.Scatter(x=client_names,
                       y=[s[3] for s in stats],
                       mode='lines+markers',
                       name='Entropy'),
            row=2, col=2
        )

        fig.update_layout(height=800, width=1000,
                          showlegend=False,
                          title_text="Client Diversity Analysis")
        fig.show()

    def plot_model_evolution(self):
        """Show, per view, a heatmap of flattened cluster centres vs iteration.

        ``history['center_updates'][t][v]`` holds the centres of view ``v``
        at update ``t``. Each heatmap row is one update.
        """
        if 'center_updates' not in self.history:
            print("No center updates found in history")
            return

        center_updates = self.history['center_updates']
        # Guard against an empty history (or updates with no views), which
        # would otherwise fail on center_updates[0] / make_subplots(rows=0).
        if len(center_updates) == 0 or len(center_updates[0]) == 0:
            print("No center updates found in history")
            return

        n_iterations = len(center_updates)
        n_views = len(center_updates[0])

        fig = make_subplots(
            rows=n_views, cols=1,
            subplot_titles=[f'View {i+1} Center Evolution' for i in range(n_views)],
            # make_subplots requires spacing <= 1 / (rows - 1).
            vertical_spacing=min(0.1, 1.0 / n_views)
        )

        for view_idx in range(n_views):
            # Get center trajectories for this view
            centers = np.array([update[view_idx] for update in center_updates])

            # Create heatmap of center movement
            fig.add_trace(
                go.Heatmap(
                    z=centers.reshape(n_iterations, -1),
                    colorscale='Viridis',
                    showscale=True,
                    name=f'View {view_idx+1}'
                ),
                row=view_idx+1, col=1
            )

        fig.update_layout(
            height=300*n_views,
            width=800,
            title_text="Model Parameter Evolution"
        )
        fig.show()

## 4. Performance Diagnostic Tools

Tools for detailed performance analysis and debugging.

In [None]:
class PerformanceDiagnostics:
    """Advanced diagnostic tools for Fed-MVKM performance analysis.

    Args:
        model: Fitted ``FedMVKMED`` instance. ``analyze_view_importance`` reads
            ``model.history['view_weights']`` when it exists.
        views: One 2-D ``(n_samples, n_features)`` array per view (feature
            counts may differ between views), all covering the same samples.
        predictions: Predicted cluster per sample, shared by every view.
        true_labels: Ground-truth label per sample (stored; not used by the
            methods below).
    """

    def __init__(self, model: FedMVKMED, views: List[np.ndarray],
                 predictions: np.ndarray, true_labels: np.ndarray):
        self.model = model
        self.views = views
        self.predictions = predictions
        self.true_labels = true_labels

    @staticmethod
    def _mean_or_nan(values) -> float:
        """Mean of ``values`` as a Python float, or NaN (no warning) if empty."""
        return float(np.mean(values)) if len(values) else float('nan')

    @staticmethod
    def _view_correlation_matrix(views) -> np.ndarray:
        """Signed correlation between the sample-distance structures of views.

        Raw feature matrices of different views generally have different
        widths, so they cannot be correlated element-wise. Each view is
        summarised by its condensed pairwise sample-distance vector
        (``pdist``), which has the same length for every view covering the
        same samples, and those vectors are correlated.
        """
        if len(views) == 0:
            return np.zeros((0, 0))
        dists = np.vstack([
            pdist(np.asarray(v, dtype=float).reshape(len(v), -1))
            for v in views
        ])
        with np.errstate(invalid='ignore', divide='ignore'):
            # atleast_2d: np.corrcoef returns a scalar for a single row.
            return np.atleast_2d(np.corrcoef(dists))

    def analyze_cluster_quality(self):
        """Analyze cluster quality per view and plot the distributions.

        Returns:
            Dict ``'view_<k>'`` -> metrics, with plain Python values
            (JSON-serializable): ``silhouette_mean``, ``calinski_harabasz``,
            ``cluster_sizes`` (cluster label -> count),
            ``intra_cluster_dist_mean`` and ``inter_cluster_dist_mean``
            (NaN when there are no pairs to measure).
        """
        from sklearn.metrics import silhouette_samples, calinski_harabasz_score

        preds = np.asarray(self.predictions)
        results = {}
        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=(
                'Silhouette Scores Distribution',
                'Cluster Sizes',
                'Intra-cluster Distances',
                'Inter-cluster Distances'
            )
        )

        # Cluster assignments are shared by all views, so sizes and member
        # masks are computed (and the size bar drawn) once.
        unique, counts = np.unique(preds, return_counts=True)
        masks = [preds == cluster for cluster in unique]
        fig.add_trace(
            go.Bar(
                x=[f'Cluster {c}' for c in unique],
                y=counts,
                name='Cluster sizes'
            ),
            row=1, col=2
        )
        cluster_sizes = {c.item(): int(n) for c, n in zip(unique, counts)}

        for view_idx, view in enumerate(self.views):
            view_name = f'View {view_idx+1}'

            # Silhouette scores
            sil_scores = silhouette_samples(view, preds)
            fig.add_trace(
                go.Violin(
                    y=sil_scores,
                    name=view_name,
                    box_visible=True,
                    meanline_visible=True
                ),
                row=1, col=1
            )

            # Intra-cluster distances (clusters with >= 2 members only)
            intra_dist = []
            for mask in masks:
                if np.sum(mask) > 1:
                    intra_dist.extend(pdist(view[mask]))
            fig.add_trace(
                go.Box(y=intra_dist, name=view_name, boxpoints='outliers'),
                row=2, col=1
            )

            # Inter-cluster distances between cluster means
            centers = np.array([view[mask].mean(axis=0) for mask in masks])
            center_dist = pdist(centers)
            fig.add_trace(
                go.Box(y=center_dist, name=view_name, boxpoints='all'),
                row=2, col=2
            )

            # Store numerical results as builtin types so json.dumps works.
            results[f'view_{view_idx+1}'] = {
                'silhouette_mean': float(np.mean(sil_scores)),
                'calinski_harabasz': float(calinski_harabasz_score(view, preds)),
                'cluster_sizes': dict(cluster_sizes),
                'intra_cluster_dist_mean': self._mean_or_nan(intra_dist),
                'inter_cluster_dist_mean': self._mean_or_nan(center_dist)
            }

        fig.update_layout(height=800, width=1000,
                          title_text="Cluster Quality Analysis")
        fig.show()

        return results

    def analyze_view_importance(self):
        """Plot the evolution, final split and stability of the view weights,
        and the correlation between views.

        ``model.history['view_weights']`` is treated as one row of weights
        per iteration (a single 1-D snapshot is treated as one iteration).
        """
        history = getattr(self.model, 'history', None)
        if history is None or 'view_weights' not in history:
            print("No view weight history available")
            return

        view_weights = np.atleast_2d(np.array(history['view_weights']))
        n_weighted_views = view_weights.shape[1]
        weight_labels = [f'View {i+1}' for i in range(n_weighted_views)]

        # The pie chart needs a 'domain' subplot; Plotly rejects pie traces
        # in the default 'xy' cells.
        fig = make_subplots(
            rows=2, cols=2,
            specs=[[{'type': 'xy'}, {'type': 'domain'}],
                   [{'type': 'xy'}, {'type': 'xy'}]],
            subplot_titles=(
                'View Weight Evolution',
                'Final Weight Distribution',
                'Weight Stability',
                'View Correlation'
            )
        )

        # View weight evolution
        for i, label in enumerate(weight_labels):
            fig.add_trace(
                go.Scatter(y=view_weights[:, i], mode='lines', name=label),
                row=1, col=1
            )

        # Final weight distribution
        fig.add_trace(
            go.Pie(labels=weight_labels, values=view_weights[-1], hole=0.3),
            row=1, col=2
        )

        # Weight stability (lower std = more stable)
        fig.add_trace(
            go.Bar(x=weight_labels, y=view_weights.std(axis=0),
                   name='Weight Stability'),
            row=2, col=1
        )

        # View correlation; labelled by the views actually correlated.
        corr_matrix = self._view_correlation_matrix(self.views)
        data_labels = [f'View {i+1}' for i in range(len(self.views))]
        fig.add_trace(
            go.Heatmap(
                z=corr_matrix,
                x=data_labels,
                y=data_labels,
                colorscale='RdBu',
                zmin=-1,
                zmax=1
            ),
            row=2, col=2
        )

        fig.update_layout(height=800, width=1000,
                          title_text="View Importance Analysis")
        fig.show()

## 5. Example Usage

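Let's demonstrate how to use these visualization and diagnostic tools. The cell below reuses `model`, `client_data`, `client_labels`, `all_predictions`, and `all_true_labels` from the main tutorial notebook, so run that notebook in the same kernel first.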

In [None]:
# Load data and model from previous analysis
# Assuming you have run the main tutorial notebook first
data_dir = Path("../data")

def load_and_preprocess():
    """Read the DHA RGB/depth views and labels from ``data_dir``.

    Returns:
        ``(views, labels)``: ``views`` is the result of
        ``MVKMEDDataProcessor.preprocess_views`` on ``[rgb, depth]``;
        ``labels`` is the flattened ``label_DHA`` array.
    """
    def _read(stem):
        # For all three files the .mat variable key equals the file stem.
        return sio.loadmat(data_dir / f'{stem}.mat')[stem]

    processor = MVKMEDDataProcessor()
    raw_views = [_read('RGB_DHA'), _read('Depth_DHA')]
    labels = _read('label_DHA').ravel()
    return processor.preprocess_views(raw_views), labels

import json

# `model`, `client_data`, `client_labels`, `all_predictions` and
# `all_true_labels` are not created in this notebook; they must already exist
# in the kernel (from the main tutorial notebook). Fail early with a clear
# message instead of a bare NameError halfway through the cell.
_required = ('model', 'client_data', 'client_labels',
             'all_predictions', 'all_true_labels')
_missing = [name for name in _required if name not in globals()]
if _missing:
    raise NameError(
        "Run the main Fed-MVKM tutorial notebook in this kernel first; "
        "missing: " + ", ".join(_missing)
    )

# Load data
views, true_labels = load_and_preprocess()

# Create visualizers
# NOTE(review): ClusterVisualization expects one prediction per row of
# `views`; confirm that model.predict(client_data) returns that shape.
cluster_viz = ClusterVisualization(views, true_labels, model.predict(client_data))
fed_analyzer = FederationAnalyzer(model, client_data, client_labels)
diagnostics = PerformanceDiagnostics(model, views, all_predictions, all_true_labels)

# Generate visualizations
print("1. Cluster Embeddings:")
cluster_viz.plot_cluster_embeddings(method='tsne', interactive=True)

print("\n2. View Relationships:")
cluster_viz.plot_view_relationships()

print("\n3. Federation Analysis:")
fed_analyzer.plot_client_convergence_3d()
fed_analyzer.analyze_client_diversity()
fed_analyzer.plot_model_evolution()

print("\n4. Performance Diagnostics:")
cluster_quality = diagnostics.analyze_cluster_quality()
diagnostics.analyze_view_importance()


def _to_builtin(obj):
    """Recursively convert NumPy scalars/arrays (keys included) so that
    json.dumps accepts the structure."""
    if isinstance(obj, dict):
        return {_to_builtin(k): _to_builtin(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_builtin(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    return obj


# Print numerical results
print("\nCluster Quality Metrics:")
print(json.dumps(_to_builtin(cluster_quality), indent=2))

## Conclusion

This notebook provided advanced visualization and diagnostic tools for:
1. Cluster analysis and embedding visualization
2. Federation behavior analysis
3. Client diversity and contribution analysis
4. Performance diagnostics and quality metrics

These tools help in understanding and improving the Fed-MVKM algorithm's behavior on multi-view datasets.