# Visualizing Training Loss with `plot_loss`

This notebook shows how to use the `plot_loss` utility to visualize training loss curves
after running SFT, OSFT, or LoRA training. It covers:

1. Basic usage - plotting loss from a single training run
2. Comparing multiple training runs (even across different algorithms)
3. Using EMA (Exponential Moving Average) smoothing
4. Customizing the output and handling different metrics files

**Supported algorithms:** SFT, OSFT, and LoRA (automatically detects the metrics format)

## Setup

In [None]:
from training_hub import plot_loss

## Overview

The `plot_loss` function reads metrics from checkpoint directories and creates
visualizations of the loss curve over training steps.

### How It Works

1. **Auto-detection**: The function automatically finds metrics files in your checkpoint directory
2. **Format support**: Handles both JSONL (SFT/OSFT) and trainer_state.json (LoRA) formats
3. **Metric extraction**: Reads loss values using common keys like `avg_loss`, `loss`, `train_loss`
4. **Visualization**: Creates a matplotlib plot and saves it as a PNG file
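
The metric-extraction step above can be sketched roughly as follows. This is a minimal illustration, not the actual `plot_loss` implementation: the helper name `extract_losses` is hypothetical, the default key order is borrowed from the list under "Custom Metric Keys", and skipping malformed lines is an assumption about how such a reader might behave.

```python
import json

# Key order taken from this notebook's "Custom Metric Keys" section.
DEFAULT_KEYS = ["avg_loss", "loss", "avg_loss_backwards", "train_loss"]


def extract_losses(lines, metric_keys=None):
    """Hypothetical helper: return one loss per JSONL record.

    `lines` is any iterable of strings (for example an open file).
    For each record, the first key in `metric_keys` holding a number wins.
    """
    keys = metric_keys or DEFAULT_KEYS
    losses = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            continue  # assumption: malformed lines are skipped, not fatal
        if not isinstance(record, dict):
            continue
        for key in keys:
            value = record.get(key)
            # bool is a subclass of int, so exclude it explicitly
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                losses.append(float(value))
                break
    return losses


sample = [
    '{"step": 1, "avg_loss": 2.5}',
    "not valid json",
    '{"step": 2, "loss": 2.0}',
    '{"step": 3, "lr": 0.1}',
]
print(extract_losses(sample))  # [2.5, 2.0]
```

Passing an open file handle instead of `sample` would read a real metrics file line by line in the same way.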

### Supported Formats

| Algorithm | Backend | Metrics Format | Location |
|-----------|---------|----------------|----------|
| SFT | instructlab-training | JSONL | `ckpt_dir/*.jsonl` |
| OSFT | mini-trainer | JSONL | `ckpt_dir/*.jsonl` |
| LoRA | Unsloth/TRL | trainer_state.json | `ckpt_dir/checkpoint-*/trainer_state.json` |
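
For the LoRA row, the Hugging Face Trainer stores its logged metrics in the `log_history` list of `trainer_state.json`. The sketch below shows one way to pull training loss out of that structure. The helper name `losses_from_trainer_state` is hypothetical, the sample data is made up for illustration, and `plot_loss` itself may treat evaluation or summary entries differently.

```python
import json


def losses_from_trainer_state(state):
    """Hypothetical helper: (step, loss) pairs from a parsed trainer_state.json."""
    points = []
    for entry in state.get("log_history", []):
        # Training log entries carry "loss". Evaluation entries use "eval_loss"
        # and the final summary entry uses "train_loss", so this sketch skips both.
        if "loss" in entry and "step" in entry:
            points.append((entry["step"], float(entry["loss"])))
    return points


# Small in-memory stand-in for a real trainer_state.json (illustrative values)
example_state = json.loads("""
{"log_history": [
  {"step": 10, "loss": 1.9, "learning_rate": 1e-4},
  {"step": 20, "loss": 1.5, "learning_rate": 9e-5},
  {"step": 20, "eval_loss": 1.6},
  {"step": 20, "train_loss": 1.7, "train_runtime": 12.3}
]}
""")
print(losses_from_trainer_state(example_state))  # [(10, 1.9), (20, 1.5)]
```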

### Function Signature

```python
def plot_loss(
    ckpt_output_dirs: str | list[str],  # Checkpoint directory or list of directories
    *,
    metrics_file: str | None = None,     # Metrics filename (auto-detected if None)
    output_path: str | None = None,      # Output path (default: loss_plot.png in first ckpt dir)
    labels: list[str] | None = None,     # Labels for legend (auto-generated if None)
    ema: bool = False,                   # Enable EMA smoothing overlay
    ema_span: int = 30,                  # EMA span for smoothing
    metric_keys: list[str] | None = None,  # Custom metric keys to search for
    show: bool = False,                  # Display plot interactively
) -> str:                                # Returns path to saved plot
```

## Basic Usage

After running training with `sft()`, `osft()`, or `lora_sft()`, simply pass the checkpoint directory to `plot_loss`:

In [None]:
# Example: Plot loss from a training run
# (Replace with your actual checkpoint directory)

ckpt_dir = "/path/to/your/checkpoints"

# This will:
# 1. Auto-detect the metrics file in the directory (JSONL or trainer_state.json)
# 2. Extract loss values from the file
# 3. Save the plot to loss_plot.png in the checkpoint directory

# plot_path = plot_loss(ckpt_dir)
# print(f"Plot saved to: {plot_path}")

### Typical Workflow

Here's a complete example showing training followed by visualization:

In [None]:
# from training_hub import osft, plot_loss
#
# # Run OSFT training
# osft(
#     model_path="Qwen/Qwen2.5-7B-Instruct",
#     data_path="/path/to/training_data.jsonl",
#     ckpt_output_dir="./my_experiment",
#     unfreeze_rank_ratio=0.25,
#     effective_batch_size=128,
#     max_tokens_per_gpu=10000,
#     max_seq_len=4096,
#     learning_rate=5e-6,
#     num_epochs=3,
# )
#
# # After training completes, visualize the loss
# plot_loss("./my_experiment")

## Comparing Multiple Training Runs

One of the most useful features is comparing loss curves from different experiments.
Pass a list of checkpoint directories to overlay multiple runs.

This works across different algorithms - you can compare SFT, OSFT, and LoRA runs
in the same plot:

In [None]:
# Compare different experiments (same algorithm)
experiments = [
    "/path/to/experiment_lr_1e-5",
    "/path/to/experiment_lr_5e-6",
    "/path/to/experiment_lr_1e-6",
]

# With custom labels for the legend
# plot_loss(
#     experiments,
#     labels=["lr=1e-5", "lr=5e-6", "lr=1e-6"],
#     output_path="./learning_rate_comparison.png"
# )

# Compare across different algorithms
# plot_loss(
#     ["./sft_outputs", "./osft_outputs", "./lora_outputs"],
#     labels=["SFT", "OSFT", "LoRA"],
#     output_path="./algorithm_comparison.png"
# )

If you don't provide labels, they will be auto-generated from the directory names.
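
As a rough idea of what labels derived from directory names could look like, the sketch below takes the last path component of each directory. The helper `default_labels` is hypothetical, and the exact rule `plot_loss` applies is an assumption here.

```python
import os


def default_labels(ckpt_dirs):
    """Hypothetical helper: use each directory's final path component as its label."""
    # normpath drops a trailing slash so basename does not return an empty string
    return [os.path.basename(os.path.normpath(d)) for d in ckpt_dirs]


print(default_labels(["/path/to/experiment_lr_1e-5", "./sft_outputs/"]))
# ['experiment_lr_1e-5', 'sft_outputs']
```

Under this rule, two runs stored in directories with the same final component would get identical labels, which is one reason to pass `labels` explicitly in that case.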

## EMA Smoothing

Training loss can be noisy. Enable EMA (Exponential Moving Average) smoothing to see
the underlying trend more clearly:

In [None]:
# Enable EMA smoothing with default span (30 steps)
# plot_loss(ckpt_dir, ema=True)

# Customize the smoothing window
# Larger span = smoother curve, but less responsive to changes
# plot_loss(ckpt_dir, ema=True, ema_span=50)

When EMA is enabled, you'll see both:
- **Solid line**: Raw loss values
- **Dashed line**: EMA-smoothed values

This is particularly useful for:
- Identifying overall training trends
- Spotting when training starts to plateau
- Comparing noisy runs more easily
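
To make the effect of `ema_span` concrete, here is a minimal sketch of span-based exponential smoothing. It assumes the common convention `alpha = 2 / (span + 1)` with the first value used as the seed (the same recursion as pandas' `ewm(span=..., adjust=False)`). Whether `plot_loss` uses exactly this formula is an assumption, and `ema_smooth` is a hypothetical name.

```python
def ema_smooth(values, span=30):
    """Hypothetical helper: exponential moving average with a span-based alpha."""
    if span < 1:
        raise ValueError("span must be >= 1")
    alpha = 2.0 / (span + 1)  # larger span -> smaller alpha -> smoother curve
    smoothed = []
    for value in values:
        if not smoothed:
            smoothed.append(float(value))  # seed with the first observation
        else:
            smoothed.append(alpha * value + (1 - alpha) * smoothed[-1])
    return smoothed


noisy = [2.0, 1.0, 1.8, 0.9, 1.6, 0.8]
print(ema_smooth(noisy, span=3))
print(ema_smooth(noisy, span=30))  # reacts more slowly to each new point
```

With a larger span, each new loss value moves the smoothed curve less, which is the "smoother but less responsive" trade-off mentioned in the cell above.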

## Customization Options

### Specifying the Metrics File

By default, `plot_loss` auto-detects metrics files by looking for:

**For SFT/OSFT (JSONL format):**
1. `training_log.jsonl`
2. `training_metrics.jsonl`
3. `metrics.jsonl`
4. Any `.jsonl` file (excluding `data.jsonl`)

**For LoRA (HuggingFace Trainer format):**
1. `trainer_state.json` inside `checkpoint-*` subdirectories
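
The lookup order above can be sketched as two small pure functions that operate on directory listings. Both helper names are hypothetical. Choosing the first remaining `.jsonl` file alphabetically and choosing the highest-numbered `checkpoint-*` directory are assumptions made for this illustration, not documented `plot_loss` behavior.

```python
import re

# Preferred JSONL names, in the order listed above.
PREFERRED_JSONL = ["training_log.jsonl", "training_metrics.jsonl", "metrics.jsonl"]


def pick_jsonl_metrics(filenames):
    """Hypothetical helper: choose a JSONL metrics file from a directory listing."""
    names = set(filenames)
    for preferred in PREFERRED_JSONL:
        if preferred in names:
            return preferred
    # Fall back to any other .jsonl file, skipping the training data itself.
    # Assumption: ties are broken alphabetically.
    fallbacks = sorted(n for n in names if n.endswith(".jsonl") and n != "data.jsonl")
    return fallbacks[0] if fallbacks else None


def latest_checkpoint(dirnames):
    """Hypothetical helper: pick the checkpoint-N directory with the largest N."""
    numbered = []
    for name in dirnames:
        match = re.fullmatch(r"checkpoint-(\d+)", name)
        if match:
            numbered.append((int(match.group(1)), name))
    return max(numbered)[1] if numbered else None


print(pick_jsonl_metrics(["data.jsonl", "metrics.jsonl", "training_log.jsonl"]))
# training_log.jsonl
print(latest_checkpoint(["checkpoint-100", "checkpoint-1000", "checkpoint-500"]))
# checkpoint-1000
```

In practice the listings would come from something like `os.listdir(ckpt_dir)`. Note that the checkpoint names are compared numerically, so `checkpoint-1000` sorts after `checkpoint-500`.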

If your metrics file has a different name, specify it explicitly:

In [None]:
# Specify a custom metrics filename
# plot_loss(ckpt_dir, metrics_file="my_custom_metrics.jsonl")

### Custom Output Path

By default, plots are saved to `loss_plot.png` in the checkpoint directory (the first one, when
several directories are passed). You can specify a custom path:

In [None]:
# Save to a specific location
# plot_loss(ckpt_dir, output_path="./reports/training_loss.png")

# Save as PDF for publication
# plot_loss(ckpt_dir, output_path="./figures/loss_curve.pdf")

### Custom Metric Keys

The function looks for loss values using these keys (in order):
- `avg_loss`
- `loss`
- `avg_loss_backwards`
- `train_loss`

If your metrics file uses different key names, specify them:

In [None]:
# Use custom metric keys
# plot_loss(ckpt_dir, metric_keys=["training_loss", "batch_loss"])

### Interactive Display

By default, the plot is saved to a file. To also display it interactively
(useful in Jupyter notebooks or for quick inspection):

In [None]:
# Display the plot interactively (in addition to saving)
# plot_loss(ckpt_dir, show=True)

## Parameter Reference

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `ckpt_output_dirs` | `str \| list[str]` | Required | Checkpoint directory or list of directories |
| `metrics_file` | `str \| None` | `None` | Metrics filename (auto-detected if None) |
| `output_path` | `str \| None` | `None` | Output path (default: `loss_plot.png` in first ckpt dir) |
| `labels` | `list[str] \| None` | `None` | Legend labels (auto-generated from dir names if None) |
| `ema` | `bool` | `False` | Enable EMA smoothing overlay |
| `ema_span` | `int` | `30` | EMA span for smoothing |
| `metric_keys` | `list[str] \| None` | `None` | Custom metric keys to search for |
| `show` | `bool` | `False` | Display plot interactively |

## Tips and Best Practices

1. **Compare fairly**: When comparing runs, ensure they used the same dataset and similar configurations

2. **Use EMA for noisy data**: If your loss curve is very noisy, EMA smoothing helps reveal the trend

3. **Organize experiments**: Use descriptive directory names so auto-generated labels are meaningful

4. **Save comparisons**: When comparing runs, save to a dedicated comparison file rather than overwriting individual plots

5. **Check early**: Plot loss periodically during long training runs (checkpoints contain metrics up to that point)

## Troubleshooting

### "No metrics file found"

The function couldn't auto-detect a metrics file. Solutions:
- Verify your checkpoint directory path is correct
- Check that training completed and wrote metrics
- Use the `metrics_file` parameter to specify the exact filename

### "No matching metric found"

The metrics file was found but doesn't contain recognized loss keys. Solutions:
- Check your JSONL file format (each line should be valid JSON)
- Use the `metric_keys` parameter to specify your custom key names
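
If you are unsure which key names your metrics file uses, a quick diagnostic like the one below can help. It is a standalone sketch that does not rely on `plot_loss`: the helper `numeric_key_counts` is hypothetical, and the sample records are invented for illustration.

```python
import json
from collections import Counter


def numeric_key_counts(lines):
    """Hypothetical helper: count how often each key holds a numeric value."""
    counts = Counter()
    for line in lines:
        line = line.strip()
        if not line:
            continue
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            continue  # an invalid line is itself a hint that the file is malformed
        if not isinstance(record, dict):
            continue
        for key, value in record.items():
            # bool is a subclass of int, so exclude it explicitly
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                counts[key] += 1
    return counts


sample_lines = [
    '{"step": 1, "training_loss": 2.4, "lr": 1e-5}',
    '{"step": 2, "training_loss": 2.1, "note": "warmup done"}',
]
print(numeric_key_counts(sample_lines).most_common())
# [('step', 2), ('training_loss', 2), ('lr', 1)]
```

A key that appears in most records and holds loss-like values (here `training_loss`) is a good candidate to pass as `metric_keys`.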

### "No loss data found"

None of the provided directories contained valid loss data. Solutions:
- Verify at least one directory has a valid metrics file
- Check that the metrics files contain numeric loss values