# Advanced Visualizations for Fed-MVKM Analysis

This notebook provides sophisticated visualization tools for analyzing the Fed-MVKM algorithm's behavior using interactive and 3D visualizations.

**Author:** Kristina P. Sinaga  
**Date:** May 2024  
**Version:** 1.0

## 1. Setup and Imports

Let's import required libraries with a focus on advanced visualization packages.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
from mpl_toolkits.mplot3d import Axes3D
import networkx as nx
from scipy.stats import gaussian_kde
from sklearn.manifold import TSNE, MDS
from sklearn.decomposition import PCA
import umap
import pandas as pd
from typing import List, Dict, Tuple, Optional
import colorcet as cc  # For better color maps
import plotly.io as pio

# Set default theme for plotly: figures created after this assignment use the
# "plotly_white" template unless they specify another one.
pio.templates.default = "plotly_white"

# Import Fed-MVKM related modules
from mvkm_ed import FedMVKMED, FedMVKMEDConfig
from mvkm_ed.utils import MVKMEDDataProcessor, MVKMEDMetrics

# Configure matplotlib style
# Matplotlib 3.6 renamed the bundled seaborn style sheets to "seaborn-v0_8*"
# and later releases dropped the old names, so use whichever one is installed.
_seaborn_style = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(_seaborn_style)
sns.set_palette(cc.glasbey_dark)

## 2. Enhanced Cluster Visualization Tools

In [None]:
class EnhancedClusterVisualizer:
    """Compare true labels with predicted clusters in low-dimensional embeddings.

    Every view is embedded on its own and drawn twice, side by side: once
    coloured by the true labels and once by the predicted assignments.

    Attributes:
        views: One data matrix per view (each is passed to ``fit_transform``).
        labels: Ground-truth label for every sample.
        predictions: Predicted cluster for every sample.
        n_views: Number of views.
        n_clusters: Number of distinct values found in ``predictions``.
    """

    def __init__(self, views: List[np.ndarray], labels: np.ndarray, predictions: np.ndarray):
        self.views = views
        self.labels = labels
        self.predictions = predictions
        self.n_views = len(views)
        self.n_clusters = len(np.unique(predictions))

    def plot_cluster_comparison(self, method='tsne', dims=2):
        """Compare true labels vs predictions using various embedding methods.

        Args:
            method: Embedding method, case-insensitive; one of ``'tsne'``,
                ``'umap'``, ``'pca'`` or ``'mds'``.
            dims: Dimensionality of the embedding; 2 or 3.

        Raises:
            ValueError: If ``dims`` is neither 2 nor 3.
            KeyError: If ``method`` is not one of the supported names.
        """
        # Only 2D (cartesian) and 3D (scene) subplots can be drawn here.
        if dims not in (2, 3):
            raise ValueError(f"dims must be 2 or 3, got {dims!r}")

        methods = {
            'tsne': TSNE,
            'umap': umap.UMAP,
            'pca': PCA,
            'mds': MDS
        }

        reducer = methods[method.lower()](
            n_components=dims,
            random_state=42
        )

        fig = make_subplots(
            rows=self.n_views,
            cols=2,
            subplot_titles=self._subplot_titles(),
            specs=self._subplot_specs(dims)
        )

        add_scatter = self._add_3d_scatter if dims == 3 else self._add_2d_scatter
        for row, view in enumerate(self.views, start=1):
            # The same embedding is shown in both columns so that the two
            # colourings can be compared point by point.
            embedding = reducer.fit_transform(view)
            add_scatter(fig, embedding, self.labels, row, 1, 'True Labels')
            add_scatter(fig, embedding, self.predictions, row, 2, 'Predictions')

        fig.update_layout(
            height=400*self.n_views,
            width=1200,
            title_text=f"Cluster Comparison using {method.upper()}"
        )
        fig.show()

    def _subplot_titles(self):
        """Return one title per subplot cell in row-major order (two per view)."""
        return [
            f'View {view_idx + 1} - {kind}'
            for view_idx in range(self.n_views)
            for kind in ('True Labels', 'Predictions')
        ]

    def _subplot_specs(self, dims):
        """Return the ``make_subplots`` spec grid: ``n_views`` rows of two cells.

        3D embeddings need ``'scene'`` cells, 2D embeddings ``'xy'`` cells.
        Every cell gets its own dict so that no spec object is shared.
        """
        cell_type = 'scene' if dims == 3 else 'xy'
        return [
            [{'type': cell_type} for _ in range(2)]
            for _ in range(self.n_views)
        ]

    def _add_2d_scatter(self, fig, embedding, labels, row, col, title):
        """Add a 2D scatter of ``embedding`` coloured by ``labels`` at (row, col)."""
        scatter = go.Scatter(
            x=embedding[:, 0],
            y=embedding[:, 1],
            mode='markers',
            marker=dict(
                size=8,
                color=labels,
                colorscale='Viridis',
                showscale=True,
                colorbar=dict(title='Cluster')
            ),
            text=[f'Cluster {l}' for l in labels],
            name=title
        )
        fig.add_trace(scatter, row=row, col=col)
        fig.update_xaxes(title_text="Dimension 1", row=row, col=col)
        fig.update_yaxes(title_text="Dimension 2", row=row, col=col)

    def _add_3d_scatter(self, fig, embedding, labels, row, col, title):
        """Add a 3D scatter of ``embedding`` coloured by ``labels`` at (row, col)."""
        scatter = go.Scatter3d(
            x=embedding[:, 0],
            y=embedding[:, 1],
            z=embedding[:, 2],
            mode='markers',
            marker=dict(
                size=4,
                color=labels,
                colorscale='Viridis',
                showscale=True,
                colorbar=dict(title='Cluster')
            ),
            text=[f'Cluster {l}' for l in labels],
            name=title
        )
        fig.add_trace(scatter, row=row, col=col)

## 3. Federation Process Visualization

In [None]:
class FederationVisualizer:
    """Visualize the federation topology and the evolution of the cluster centers.

    Attributes:
        model: The federated model being inspected.
        client_data: Mapping of client data; only its keys are used here,
            one topology node per key.
        history: ``model.history`` when the model has one, otherwise an
            empty dict.
    """

    def __init__(self, model: "FedMVKMED", client_data: Dict):
        self.model = model
        self.client_data = client_data
        # Keep a reference to the model's own history object (not a copy) so
        # later updates stay visible; fall back to an empty dict when absent.
        history = getattr(model, 'history', None)
        self.history = {} if history is None else history

    def plot_federation_topology(self):
        """Draw a star graph with the server at the hub and one node per client."""
        G = nx.Graph()

        # Add nodes for clients and server
        G.add_node('Server', node_type='server')
        for client_id in self.client_data:
            client_node = f'Client {client_id}'
            G.add_node(client_node, node_type='client')
            G.add_edge('Server', client_node)

        # Fixed seed so the layout is the same on every run.
        pos = nx.spring_layout(G, seed=42)

        node_colors = ['red' if n == 'Server' else 'blue' for n in G.nodes()]
        node_sizes = [2000 if n == 'Server' else 1000 for n in G.nodes()]

        plt.figure(figsize=(12, 8))

        nx.draw_networkx_nodes(G, pos,
                               node_color=node_colors,
                               node_size=node_sizes,
                               alpha=0.7)

        # The graph is undirected, for which networkx draws no arrowheads by
        # default. Ask for heads on both ends to show two-way communication,
        # and pass the node sizes so the heads stop at the node borders.
        nx.draw_networkx_edges(G, pos, edge_color='gray',
                               arrows=True, arrowstyle='<|-|>', arrowsize=20,
                               node_size=node_sizes)

        nx.draw_networkx_labels(G, pos)

        plt.title("Federation Topology and Communication Flow")
        plt.axis('off')
        plt.show()

    def animate_model_evolution(self):
        """Animate the cluster centers stored in ``history['center_updates']``.

        Each entry of ``center_updates`` becomes one animation frame holding
        one trace per view. Only the first two coordinates of every center
        are plotted. Prints a message and returns when there is nothing to
        animate.
        """
        if 'center_updates' not in self.history:
            print("No center updates found in history")
            return

        center_updates = self.history['center_updates']
        # len() rather than truthiness: the history may hold a numpy array.
        if len(center_updates) == 0:
            print("No center updates found in history")
            return

        # One frame per recorded update, containing every view's centers.
        frames = [
            go.Frame(data=self._center_traces(centers), name=f'Frame {i}')
            for i, centers in enumerate(center_updates)
        ]

        fig = go.Figure(
            # Frames only update existing traces, so the figure must start
            # with the traces of the first recorded update.
            data=self._center_traces(center_updates[0]),
            frames=frames,
            layout=go.Layout(
                title="Model Evolution Animation",
                updatemenus=[{
                    'type': 'buttons',
                    'showactive': False,
                    'buttons': [{
                        'label': 'Play',
                        'method': 'animate',
                        'args': [None, {'frame': {'duration': 500}}]
                    }]
                }]
            )
        )

        fig.show()

    @staticmethod
    def _center_traces(centers):
        """Return one labelled star-marker scatter trace per view in ``centers``."""
        return [
            go.Scatter(
                x=view_centers[:, 0],
                y=view_centers[:, 1],
                mode='markers+text',
                marker=dict(size=15, symbol='star'),
                text=[f'C{j}' for j in range(len(view_centers))],
                textposition='top center',
                name=f'View {view_idx+1}'
            )
            for view_idx, view_centers in enumerate(centers)
        ]

## 4. Advanced Performance Analysis

In [None]:
class PerformanceAnalyzer:
    """Performance analysis plots built from a model's training history.

    Attributes:
        model: Model whose ``history`` mapping is read by the plots.
        views: One data matrix per view.
        true_labels: Ground-truth label for every sample.
    """

    def __init__(self, model: "FedMVKMED", views: List[np.ndarray],
                 true_labels: np.ndarray):
        self.model = model
        self.views = views
        self.true_labels = true_labels

    def plot_cluster_stability(self):
        """Show a heatmap of pairwise similarity between recorded assignments.

        Reads ``model.history['assignments']`` and compares every pair of
        entries with ``_compute_assignment_similarity``. Prints a message
        and returns when no history or no assignments are available.
        """
        history = getattr(self.model, 'history', None)
        if history is None:
            print("No history available in model")
            return

        assignments_history = history.get('assignments', [])
        # len() rather than truthiness: the history may hold a numpy array.
        if assignments_history is None or len(assignments_history) == 0:
            print("No assignment history available")
            return

        n_iterations = len(assignments_history)
        stability_matrix = np.zeros((n_iterations, n_iterations))

        # The similarity is symmetric, so compute the upper triangle
        # (including the diagonal) once and mirror it.
        for i in range(n_iterations):
            for j in range(i, n_iterations):
                similarity = self._compute_assignment_similarity(
                    assignments_history[i],
                    assignments_history[j]
                )
                stability_matrix[i, j] = similarity
                stability_matrix[j, i] = similarity

        fig = go.Figure(data=go.Heatmap(
            z=stability_matrix,
            colorscale='RdBu',
            text=np.around(stability_matrix, decimals=2),
            texttemplate="%{text}",
            textfont={"size": 10},
            hoverongaps=False
        ))

        fig.update_layout(
            title="Cluster Assignment Stability Matrix",
            xaxis_title="Iteration j",
            yaxis_title="Iteration i",
            width=800,
            height=800
        )

        fig.show()

    def _compute_assignment_similarity(self, assign1, assign2):
        """Return the adjusted Rand index between two cluster assignments."""
        from sklearn.metrics import adjusted_rand_score
        return adjusted_rand_score(assign1, assign2)

    def plot_view_contribution(self):
        """Plot a 2x2 summary of how each view contributes to the clustering.

        Panels: view-weight curves over the recorded history, a pie of the
        last recorded weights, a box plot of the weights per view, and the
        silhouette score of each view with respect to ``true_labels``.
        Reads ``model.history['view_weights']``, indexed as
        ``[iteration, view]``. Prints a message and returns when no view
        weights are available.
        """
        history = getattr(self.model, 'history', None)
        if history is None or 'view_weights' not in history:
            print("No view weight history available")
            return

        view_weights = np.array(history['view_weights'])
        if view_weights.size == 0:
            print("No view weight history available")
            return

        n_weighted_views = view_weights.shape[1]
        view_names = [f'View {i+1}' for i in range(n_weighted_views)]

        fig = make_subplots(
            rows=2, cols=2,
            subplot_titles=(
                'View Weight Evolution',
                'Final View Importance',
                'View Weight Stability',
                'View Discriminative Power'
            ),
            # Pie traces cannot be placed in the default 'xy' cell type;
            # they need a 'domain' cell.
            specs=[
                [{'type': 'xy'}, {'type': 'domain'}],
                [{'type': 'xy'}, {'type': 'xy'}]
            ]
        )

        # 1. View weight evolution
        for i, view_name in enumerate(view_names):
            fig.add_trace(
                go.Scatter(
                    y=view_weights[:, i],
                    mode='lines',
                    name=view_name
                ),
                row=1, col=1
            )

        # 2. Final view importance (pie chart)
        fig.add_trace(
            go.Pie(
                values=view_weights[-1],
                labels=view_names,
                hole=0.3
            ),
            row=1, col=2
        )

        # 3. View weight stability (box plot)
        # flatten() is row-major, so repeating the view names once per
        # iteration lines each weight up with its view.
        fig.add_trace(
            go.Box(
                y=view_weights.flatten(),
                x=view_names * len(view_weights),
                name='Weight Distribution'
            ),
            row=2, col=1
        )

        # 4. View discriminative power
        from sklearn.metrics import silhouette_score
        discriminative_power = [
            silhouette_score(view, self.true_labels) for view in self.views
        ]

        fig.add_trace(
            go.Bar(
                x=[f'View {i+1}' for i in range(len(self.views))],
                y=discriminative_power,
                name='Discriminative Power'
            ),
            row=2, col=2
        )

        fig.update_layout(
            height=800,
            width=1000,
            title_text="Comprehensive View Analysis"
        )

        fig.show()

## 5. Example Usage

Let's demonstrate how to use these advanced visualization tools.

In [None]:
# Load and preprocess data (assuming we have the data from previous notebooks)
# NOTE: `load_mat_data`, `data_dir`, `model` and `client_data` are not defined
# anywhere in this notebook; they must already exist in the session
# (presumably created in the earlier notebooks) before this cell is run.
processor = MVKMEDDataProcessor()
views, labels = load_mat_data(data_dir)
views = processor.preprocess_views(views)

# Initialize visualizers
# The predictions passed to the cluster visualizer come from the existing model.
cluster_viz = EnhancedClusterVisualizer(views, labels, model.predict(client_data))
fed_viz = FederationVisualizer(model, client_data)
perf_analyzer = PerformanceAnalyzer(model, views, labels)

# Generate visualizations
print("1. Enhanced Cluster Visualization:")
# 3D UMAP embedding of every view: true labels next to predictions.
cluster_viz.plot_cluster_comparison(method='umap', dims=3)

print("\n2. Federation Process Visualization:")
fed_viz.plot_federation_topology()
fed_viz.animate_model_evolution()

print("\n3. Advanced Performance Analysis:")
perf_analyzer.plot_cluster_stability()
perf_analyzer.plot_view_contribution()

## Conclusion

This notebook provided advanced visualization tools for:
1. Enhanced cluster analysis with multiple dimensionality reduction methods (interactive 2D and 3D embeddings)
2. Visualization of the federation topology and client-server communication
3. Animated model evolution tracking
4. Comprehensive performance analysis
5. View contribution analysis

These visualizations help you understand and analyze the behavior and performance of the Fed-MVKM algorithm.