# Rank Dynamics Exploration

This notebook fetches per-layer rank metrics logged to Weights & Biases and interactively analyzes their rank-decay dynamics over the course of training.

In [None]:
import wandb
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from architectural_and_learning_on_loss_landscape.src.utils.zeroth_order_features import compute_rank_decay_dynamics

# --- Configuration ---
# W&B run to analyse, written as "<entity>/<project>/<run_id>" -- replace with your own run.
RUN_PATH = "dordorhome/drifting_experiment_continuous_input_deformation/z9s2p5z2"

# Rank metric to load for every layer (e.g. "effective_rank", "approximate_rank", "numerical_rank").
RANK_TYPE = "effective_rank"

# Number of layers whose ranks were logged under "ranks/layer_<i>/..." -- replace to match your run.
NUM_LAYERS = 5

# Passed as `mode` to compute_rank_decay_dynamics: "difference" or "ratio".
ANALYSIS_MODE = "difference"

## Fetch Data from W&B

In [None]:
# Connect to the W&B public API and look up the run to analyse.
api = wandb.Api()
run = api.run(RUN_PATH)

# Define the metric keys to fetch: one "ranks/layer_<i>/<RANK_TYPE>" series per layer, in layer order.
rank_metric_prefix = "ranks/layer_"
rank_keys = [f"{rank_metric_prefix}{i}/{RANK_TYPE}" for i in range(NUM_LAYERS)]

# Run.history() returns a *sampled* subset of the logged rows rather than every step.
# Make the sample count explicit here; raise it (or switch to run.scan_history) when
# full step resolution is needed.
HISTORY_SAMPLES = 500

# Fetch history - this returns a pandas DataFrame
history = run.history(samples=HISTORY_SAMPLES, keys=rank_keys, pandas=True)

print(f"Successfully fetched {len(history)} steps from run: {run.name}")

# Warn early when a requested metric is absent (e.g. a wrong RANK_TYPE, or NUM_LAYERS
# larger than the number of logged layers); otherwise the next cells silently yield
# no dynamics at all.
missing_keys = [key for key in rank_keys if key not in history.columns]
if missing_keys:
    print(f"Warning: {len(missing_keys)} rank metric(s) not found in history: {missing_keys}")

history.head()

## Compute Derived Dynamics Metrics

In [None]:
# For each fetched step, collect the per-layer ranks (in layer order) and summarise
# them with compute_rank_decay_dynamics; each result is tagged with its training step.
dynamics_results = []
skipped_steps = 0

for _, row in history.iterrows():
    # Extract ranks for the current step, ensuring correct order. A key that is not a
    # column, or whose value is NaN at this step (W&B history may leave gaps where a
    # metric was not logged -- NOTE(review): confirm for your logging cadence), means
    # this step has no complete rank profile and must not be fed to the analysis.
    ranks_at_step = [row[key] for key in rank_keys if key in row and pd.notna(row[key])]

    if len(ranks_at_step) == NUM_LAYERS:
        dynamics = compute_rank_decay_dynamics(ranks_at_step, mode=ANALYSIS_MODE)
        dynamics['_step'] = row['_step']
        dynamics_results.append(dynamics)
    else:
        skipped_steps += 1

if skipped_steps:
    print(f"Skipped {skipped_steps} step(s) with missing rank values.")

dynamics_df = pd.DataFrame(dynamics_results)
if dynamics_df.empty:
    # Without this, the plotting cell fails with an opaque KeyError on '_step'.
    print("No step had a complete rank profile; check RANK_TYPE and NUM_LAYERS.")
dynamics_df.head()

## Plot the Results

In [None]:
# One stacked panel per derived metric:
# (column in dynamics_df, line label, panel title, y-axis label).
panel_specs = [
    ('rank_drop_gini', 'Gini Coefficient',
     'Rank Drop Gini Coefficient vs. Training Step', 'Gini Coefficient'),
    ('rank_decay_centroid', 'Rank Decay Centroid',
     'Rank Decay Centroid vs. Training Step', 'Centroid (Layer Index)'),
    ('normalized_aurc', 'Normalized AURC',
     'Normalized AURC vs. Training Step', 'Normalized Area'),
]

fig, axes = plt.subplots(len(panel_specs), 1, figsize=(12, 18), sharex=True)
fig.suptitle(f'Rank Decay Dynamics ({ANALYSIS_MODE.capitalize()} Mode) for Run: {run.name}', fontsize=16)

steps = dynamics_df['_step']
for ax, (column, line_label, title, y_label) in zip(axes, panel_specs):
    ax.plot(steps, dynamics_df[column], label=line_label)
    ax.set_title(title)
    ax.set_ylabel(y_label)
    ax.grid(True, linestyle='--')

# The x-axis is shared, so only the bottom panel carries the x label.
axes[-1].set_xlabel('Training Step')

# Keep the top 4% of the figure free for the suptitle.
plt.tight_layout(rect=[0, 0, 1, 0.96])
plt.show()