In [1]:
import tensorflow as tf
from tensorflow.core.util import event_pb2
import pandas as pd
import numpy as np
from pathlib import Path

def read_tensorboard_logs(log_dir):
    """
    Read TensorBoard event files and extract metrics.
    
    Parameters:
    - log_dir: Path to directory containing TensorBoard event files
    
    Returns:
    - metrics_df: DataFrame with columns [step, epoch, metric_name, value]
    """
    log_path = Path(log_dir)
    event_files = list(log_path.glob("events.out.tfevents.*"))
    
    if not event_files:
        raise ValueError(f"No TensorBoard event files found in {log_dir}")
    
    all_metrics = []
    
    for event_file in event_files:
        for record in tf.data.TFRecordDataset(str(event_file)):
            event = event_pb2.Event.FromString(record.numpy())
            
            # Extract scalar values
            for value in event.summary.value:
                all_metrics.append({
                    'step': event.step,
                    'wall_time': event.wall_time,
                    'metric_name': value.tag,
                    'value': value.simple_value
                })
    
    metrics_df = pd.DataFrame(all_metrics)
    return metrics_df

def calculate_saturation_point(metrics_df, metric_name='train/loss', threshold_pct=0.1):
    """
    Find the first logged step at which a metric comes within ``threshold_pct`` of its final value.

    This version expects the long-format frame produced by ``read_tensorboard_logs``
    (one row per step/metric_name/value). NOTE: a later cell in this notebook defines
    another function with this same name for wide CSV frames; whichever cell was
    executed last is the one that is bound.
    
    Parameters:
    - metrics_df: DataFrame with columns 'step', 'metric_name' and 'value'
    - metric_name: tag to analyse. The default 'train/loss' is kept for
      compatibility; the value_counts output in this notebook shows the training
      tag as 'train_loss'.
    - threshold_pct: fractional tolerance around the final value (0.1 means 10%)
    
    Returns:
    - dict with keys:
        saturation_step  first step whose value is <= threshold_value, or None
        initial_loss     value at the earliest step
        final_loss       mean of the last 10% (at least one) of logged values
        threshold_value  final_loss widened by threshold_pct
        total_steps      largest logged step
        saturation_pct   saturation_step / total_steps * 100, or None when no step
                         qualifies or total_steps is not positive

    Raises:
    - ValueError: if ``metric_name`` has no rows in ``metrics_df``
    """
    # Stable sort so duplicate steps keep their original relative order.
    metric_data = metrics_df[metrics_df['metric_name'] == metric_name].sort_values(
        'step', kind='mergesort'
    )
    
    if metric_data.empty:
        raise ValueError(f"Metric '{metric_name}' not found in data")

    values = metric_data['value']
    
    # "Final" value is averaged over the tail to damp step-to-step noise.
    n_final = max(1, len(metric_data) // 10)
    final_loss = values.tail(n_final).mean()
    initial_loss = values.iloc[0]
    
    # Widen the band *above* final_loss in both signs; multiplying a negative
    # final value by (1 + pct) would push the threshold below it instead.
    if final_loss >= 0:
        threshold_value = final_loss * (1 + threshold_pct)
    else:
        threshold_value = final_loss * (1 - threshold_pct)
    
    # First step (in step order) whose value falls within the band. This is the
    # first hit, so a single noisy dip can register early.
    within_threshold = metric_data[values <= threshold_value]
    total_steps = int(metric_data['step'].max())
    
    if within_threshold.empty:
        saturation_step = None
        saturation_pct = None
    else:
        saturation_step = int(within_threshold['step'].iloc[0])
        # `is not None` semantics: step 0 is a valid saturation point (0%).
        saturation_pct = saturation_step / total_steps * 100 if total_steps > 0 else None
    
    return {
        'saturation_step': saturation_step,
        'initial_loss': float(initial_loss),
        'final_loss': float(final_loss),
        'threshold_value': float(threshold_value),
        'total_steps': total_steps,
        'saturation_pct': saturation_pct,
    }

def extract_all_metrics(log_dir, output_csv=None, metric_names=('train_loss', 'val_loss_step')):
    """
    Extract all metrics from TensorBoard logs and calculate saturation for selected tags.
    
    Parameters:
    - log_dir: Path to TensorBoard log directory
    - output_csv: Optional path to save the long-format metrics CSV
    - metric_names: tags to run the saturation analysis on
      (default: 'train_loss' and 'val_loss_step')
    
    Returns:
    - metrics_df: DataFrame with all metrics
    - saturation_metrics: dict mapping each successfully analysed tag to the
      result dict of ``calculate_saturation_point``

    Raises:
    - ValueError: propagated from ``read_tensorboard_logs`` when no event files exist
    """
    metrics_df = read_tensorboard_logs(log_dir)
    
    if output_csv:
        metrics_df.to_csv(output_csv, index=False)
        print(f"Metrics saved to {output_csv}")
    
    saturation_metrics = {}
    
    for metric_name in metric_names:
        try:
            saturation = calculate_saturation_point(metrics_df, metric_name)
        except (ValueError, KeyError) as e:
            # Tag missing from this run (or the bound calculate_saturation_point
            # expects a different frame layout): report instead of silently skipping.
            print(f"  Could not calculate saturation for {metric_name}: {e}")
            continue
        saturation_metrics[metric_name] = saturation
        print(f"\n{metric_name}:")
        print(f"  Saturation step: {saturation['saturation_step']}")
        # saturation_pct is None when no step reached the threshold; formatting
        # None with :.1f would raise TypeError.
        if saturation['saturation_pct'] is not None:
            print(f"  Saturation at: {saturation['saturation_pct']:.1f}% of training")
        print(f"  Initial loss: {saturation['initial_loss']:.4f}")
        print(f"  Final loss: {saturation['final_loss']:.4f}")
    
    return metrics_df, saturation_metrics

# Usage example: analyse one run's TensorBoard event files.
# NOTE(review): this directory held no event files when last executed (see the
# ValueError output of this cell); the CSV cell below reads metrics.csv from it.
log_dir = "/gpfs/home/asun/jin_lab/perturbench/debug/out_L6_CT_CTX_holdout/version_0"
metrics_df, saturation_metrics = extract_all_metrics(
    log_dir,
    output_csv="training_metrics.csv",
)

# Access a specific saturation point. The logged tag is 'train_loss' (not
# 'train/loss'), and it is only present if the analysis succeeded for it.
if 'train_loss' in saturation_metrics:
    train_saturation_step = saturation_metrics['train_loss']['saturation_step']
    print(f"\nTraining speed saturation: {train_saturation_step} steps")

2025-10-02 13:59:39.921703: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


ValueError: No TensorBoard event files found in /gpfs/home/asun/jin_lab/perturbench/debug/out_L6_CT_CTX_holdout/version_0

In [4]:
# Number of logged rows per tag. Requires the long-format frame from
# read_tensorboard_logs; the CSV cell rebinds metrics_df to a wide frame
# without a 'metric_name' column.
metrics_df.loc[:, "metric_name"].value_counts()

metric_name
val_loss_step                   560
epoch                           278
train_loss                      268
lr-Adam                          10
lr-Adam-momentum                 10
val_loss_epoch                   10
hp_metric                         2
_hparams_/experiment              2
_hparams_/session_start_info      2
_hparams_/session_end_info        2
rmse_average                      1
rmse_rank_average                 1
cosine_logfc                      1
cosine_rank_logfc                 1
Name: count, dtype: int64

In [3]:
import pandas as pd
import numpy as np
from pathlib import Path

def read_csv_metrics(csv_path):
    """
    Load a metrics CSV into a DataFrame without any post-processing.

    Parameters:
    - csv_path: anything ``pd.read_csv`` accepts (path string, Path, or buffer)

    Returns:
    - DataFrame exactly as parsed by ``pd.read_csv`` with default options
    """
    return pd.read_csv(csv_path)

def calculate_saturation_point(metrics_df, metric_name='train_loss', threshold_pct=0.1):
    """
    Find the first step at which a metric comes within ``threshold_pct`` of its final value.

    This version expects a wide frame (one column per metric, as read from a
    metrics CSV). It shares its name with the earlier TensorBoard-based function
    in this notebook; whichever cell ran last is the one that is bound.
    
    Parameters:
    - metrics_df: DataFrame with columns including 'step' and the metric column
    - metric_name: Name of the metric column to analyze (e.g., 'train_loss')
    - threshold_pct: fractional tolerance around the final value (0.1 means 10%)
    
    Returns:
    - dict with keys:
        saturation_step  first step whose value is <= threshold_value, or None
        initial_loss     value at the earliest step
        final_loss       mean of the last 10% (at least one) of non-NaN values
        threshold_value  final_loss widened by threshold_pct
        total_steps      largest step with a non-NaN value
        saturation_pct   saturation_step / total_steps * 100, or None when no step
                         qualifies or total_steps is not positive

    Raises:
    - KeyError: if 'step' or ``metric_name`` is not a column of ``metrics_df``
    - ValueError: if the metric column has no rows with both values present
    """
    # Keep only rows where both the step and this metric are present; stable
    # sort so duplicate steps keep their original relative order.
    metric_data = (
        metrics_df[['step', metric_name]]
        .dropna()
        .sort_values('step', kind='mergesort')
    )
    
    if metric_data.empty:
        raise ValueError(f"Metric '{metric_name}' not found or has no valid data")

    values = metric_data[metric_name]
    
    # "Final" value is averaged over the tail to damp step-to-step noise.
    n_final = max(1, len(metric_data) // 10)
    final_loss = values.tail(n_final).mean()
    initial_loss = values.iloc[0]
    
    # Widen the band *above* final_loss in both signs; multiplying a negative
    # final value by (1 + pct) would push the threshold below it instead.
    if final_loss >= 0:
        threshold_value = final_loss * (1 + threshold_pct)
    else:
        threshold_value = final_loss * (1 - threshold_pct)
    
    # First step (in step order) whose value falls within the band.
    within_threshold = metric_data[values <= threshold_value]
    total_steps = int(metric_data['step'].max())
    
    if within_threshold.empty:
        saturation_step = None
        saturation_pct = None
    else:
        saturation_step = int(within_threshold['step'].iloc[0])
        # Guard against all data sitting at step 0 (would divide by zero).
        saturation_pct = saturation_step / total_steps * 100 if total_steps > 0 else None
    
    return {
        'saturation_step': saturation_step,
        'initial_loss': float(initial_loss),
        'final_loss': float(final_loss),
        'threshold_value': float(threshold_value),
        'total_steps': total_steps,
        'saturation_pct': saturation_pct,
    }

def extract_all_metrics(csv_path, output_summary=None, metric_names=None):
    """
    Extract all metrics from a metrics CSV and calculate saturation for loss columns.
    
    Parameters:
    - csv_path: Path to CSV file with metrics
    - output_summary: Optional path to save saturation summary JSON
    - metric_names: optional list of columns to analyse; when None, every column
      whose name contains 'loss' (case-insensitive) is used
    
    Returns:
    - metrics_df: DataFrame with all metrics
    - saturation_metrics: dict mapping each successfully analysed column to the
      result dict of ``calculate_saturation_point``
    """
    import json

    metrics_df = read_csv_metrics(csv_path)
    
    print(f"Available metrics: {list(metrics_df.columns)}")
    print(f"Total rows: {len(metrics_df)}")
    
    if metric_names is None:
        metric_names = [col for col in metrics_df.columns if 'loss' in col.lower()]
    
    saturation_metrics = {}
    
    for metric_name in metric_names:
        try:
            saturation = calculate_saturation_point(metrics_df, metric_name)
        except (ValueError, KeyError) as e:
            print(f"  Could not calculate saturation for {metric_name}: {e}")
            continue
        saturation_metrics[metric_name] = saturation
        print(f"\n{metric_name}:")
        print(f"  Saturation step: {saturation['saturation_step']}")
        # Explicit None check: a saturation at 0.0% is a real result and must print.
        if saturation['saturation_pct'] is not None:
            print(f"  Saturation at: {saturation['saturation_pct']:.1f}% of training")
        print(f"  Initial loss: {saturation['initial_loss']:.4f}")
        print(f"  Final loss: {saturation['final_loss']:.4f}")
    
    if output_summary:
        with open(output_summary, 'w') as f:
            json.dump(saturation_metrics, f, indent=2)
        print(f"\nSaturation summary saved to {output_summary}")
    
    return metrics_df, saturation_metrics

# Usage example: saturation analysis on one run's metrics.csv, with a JSON summary.
csv_path = "/gpfs/home/asun/jin_lab/perturbench/debug/out_L6_CT_CTX_holdout/version_0/metrics.csv"
metrics_df, saturation_metrics = extract_all_metrics(csv_path, output_summary="saturation_summary.json")

# Report the training-loss saturation step, but only when that column was analysed.
train_loss_result = 'train_loss' in saturation_metrics
if train_loss_result:
    train_saturation_step = saturation_metrics['train_loss'].get('saturation_step')
    print(f"\nTraining speed saturation: {train_saturation_step} steps")

Available metrics: ['batch_time_avg', 'batch_time_cv_percent', 'batch_time_max', 'batch_time_max_min_ratio', 'batch_time_min', 'batches_per_second', 'decoder_loss', 'epoch', 'step', 'train_loss', 'val/decoder_loss', 'val_loss']
Total rows: 1200

decoder_loss:
  Saturation step: 649
  Saturation at: 1.6% of training
  Initial loss: 8.2925
  Final loss: 1.5523

train_loss:
  Saturation step: 1499
  Saturation at: 3.7% of training
  Initial loss: 8.1391
  Final loss: 0.5946

val/decoder_loss:
  Saturation step: 5999
  Saturation at: 15.0% of training
  Initial loss: 4.2559
  Final loss: 0.8933

val_loss:
  Saturation step: 7399
  Saturation at: 18.5% of training
  Initial loss: 2.4406
  Final loss: 0.2517

Saturation summary saved to saturation_summary.json

Training speed saturation: 1499 steps
