1. **Per-batch tensor footprint**

   $$
   \begin{aligned}
   \text{bytes}_{\text{wave}} &= \text{chunk\_samples} \times 4 \quad \text{(float32 samples, 4 bytes each)} \\
   \text{bytes}_{\text{token}} &= \text{max\_token\_length} \times 4 \quad \text{(int32 token ids, 4 bytes each)} \\
   \text{batch\_tensor\_MB} &= \frac{\text{batch\_size} \times \left(\text{bytes}_{\text{wave}} + \text{bytes}_{\text{token}}\right)}{1024^2}
   \end{aligned}
   $$

2. **Inflight batches**

   $$
   \begin{aligned}
   \text{inflight} &= \text{prefetch\_factor} \times \min(\text{num\_workers}, \text{batch\_size}) \\
   \text{BatchTensors}_{\text{total}} &= \text{batch\_tensor\_MB} \times \left(1 + \text{inflight}\right)
   \end{aligned}
   $$

3. **Temperature-weighted shard sizes**

   $$
   \begin{aligned}
   w_i &=
     \begin{cases}
       \text{count}_i^{1 / \text{temperature}} & \text{if temperature is set (neither None nor 0)} \\
       \text{count}_i & \text{otherwise}
     \end{cases} \\
   \tilde{w}_i &= \frac{w_i}{\sum_j w_j} \\
   \text{avg\_shard\_MB} &= \sum_i \tilde{w}_i \cdot \text{mean\_shard\_size}_i
   \end{aligned}
   $$

4. **Temperature-weighted MIDI sizes**

   $$
   \text{avg\_midi\_MB} = \sum_i \tilde{w}_i \cdot \text{mean\_noteseq\_size}_i
   $$

5. **Cache footprint per worker**

   $$
   \begin{aligned}
   \text{ShardCache}_{\text{per}} &= \text{shard\_cache\_size} \times \text{avg\_shard\_MB} \\
   \text{MidiCache}_{\text{per}} &= \text{max\_midi\_cache\_size} \times \text{avg\_midi\_MB}
   \end{aligned}
   $$

6. **Cache total across workers**

   $$
   \text{Cache}_{\text{total}} = \text{num\_workers} \times \left(\text{ShardCache}_{\text{per}} + \text{MidiCache}_{\text{per}}\right)
   $$

7. **Peak RAM estimate**

   $$
   \text{PeakRAM}_{\text{total}} \approx \text{BatchTensors}_{\text{total}} + \text{Cache}_{\text{total}}
   $$


In [1]:
from __future__ import annotations

from pathlib import Path
import random
import statistics
from typing import Dict, Iterable, Tuple

import note_seq  # pip install note-seq
import pandas as pd

from configs import load_project_config


def estimate_peak_loader_ram(
    *,
    manifest_path: str | Path | None = None,
    batch_size: int = 16,
    num_workers: int = 8,
    prefetch_factor: int = 2,
    shard_cache_size: int = 32,
    max_midi_cache_size: int = 2048,
    split: str | Iterable[str] | None = None,
    temperature: float | None = None,
    shard_samples_per_dataset: int = 50,
    midi_samples_per_dataset: int = 25,
) -> Dict[str, float]:
    """
    Estimate peak RAM for the MT3 dataloader (tensors + shard/MIDI caches).

    The estimate is the sum of two terms:

    * batch tensors: one batch being consumed plus the in-flight (prefetched)
      batches, each sized from ``chunk_samples`` and ``max_token_length`` in
      the project config;
    * caches: ``shard_cache_size`` shards and ``max_midi_cache_size`` note
      sequences per cache holder, sized from files sampled off disk.

    Every intermediate value is also printed as it is computed.

    Args:
        manifest_path: Parquet chunk manifest. Falls back to
            ``cfg["paths"]["cache"]["chunk_manifest"]`` when falsy.
        batch_size: Samples per batch.
        num_workers: Loader worker count. ``0`` is treated as a single cache
            holder (the main process) with no prefetched batches.
        prefetch_factor: Batches prefetched per worker; ``None`` counts as 0.
        shard_cache_size: Shards cached per worker; ``None``/negative count
            as 0.
        max_midi_cache_size: Note sequences cached per worker;
            ``None``/negative count as 0.
        split: One split name or an iterable of names used to filter the
            manifest's ``split`` column. ``None`` keeps every row.
        temperature: Exponent ``1 / temperature`` applied to per-dataset row
            counts before normalising. ``None`` or ``0`` uses the raw counts.
        shard_samples_per_dataset: Max shard files stat'ed per dataset.
        midi_samples_per_dataset: Max MIDI files parsed per dataset.

    Returns:
        A dict with all intermediate values (in MiB, except
        ``inflight_batches`` which is a count) and the final estimate under
        ``"PeakRAM_total"``.

    Raises:
        ValueError: If the manifest is empty after the split filter.
    """
    mib = 1024**2

    cfg = load_project_config()
    manifest_path = Path(manifest_path or cfg["paths"]["cache"]["chunk_manifest"])
    manifest = pd.read_parquet(manifest_path)
    if split is not None:
        allowed = {split} if isinstance(split, str) else set(split)
        manifest = manifest[manifest["split"].isin(allowed)].reset_index(drop=True)
    if manifest.empty:
        raise ValueError("Filtered manifest is empty; cannot estimate RAM.")

    # --- 1) Per-batch tensor footprint --------------------------------------
    chunk_samples = cfg["audio"]["features"]["chunk_samples"]
    max_tokens = cfg["symbolic"]["tokenizer"].get("max_token_length", 1024)
    bytes_wave = chunk_samples * 4  # float32
    bytes_token = max_tokens * 4    # int32
    batch_tensor_mb = batch_size * (bytes_wave + bytes_token) / mib
    print(f"Estimated per-batch tensor footprint: {batch_tensor_mb:.2f} MiB")

    # --- 2) Inflight batches -------------------------------------------------
    # NOTE(review): torch documents the prefetched total as
    # prefetch_factor * num_workers; the extra min(..., batch_size) is kept
    # as originally written -- confirm it is intentional.
    # `or 0` tolerates prefetch_factor=None, which torch uses when there are
    # no workers.
    inflight = (prefetch_factor or 0) * min(num_workers, batch_size)
    tensors_total = batch_tensor_mb * (1 + inflight)
    print(f"Estimated total tensor footprint (including inflight={inflight} batches): {tensors_total:.2f} MiB")

    # --- 3) Dataset weights (temperature-aware) ------------------------------
    counts = manifest["dataset"].value_counts().to_dict()
    # Truthiness on purpose: temperature=0 would divide by zero below, so it
    # is handled like None (plain count-proportional weights).
    if temperature:
        weights_raw = {ds: count ** (1.0 / temperature) for ds, count in counts.items()}
    else:
        weights_raw = counts
    weight_norm = sum(weights_raw.values())
    weights = {ds: val / weight_norm for ds, val in weights_raw.items()}
    print(f"Dataset weights (temperature={temperature}): {weights}")

    # --- helpers -------------------------------------------------------------
    def sample_existing_files(paths: pd.Series, k: int) -> list[Path]:
        """Draw up to ``k`` rows (fixed seed) and keep those that exist on disk."""
        paths = paths.dropna()
        if paths.empty:
            return []
        picked = paths.sample(min(k, len(paths)), random_state=0)
        return [p for p in map(Path, picked) if p.is_file()]

    def sample_mean_file_size(paths: pd.Series, k: int) -> float:
        """Mean on-disk size in MiB of the sampled files; 0.0 if none usable."""
        sizes = [p.stat().st_size / mib for p in sample_existing_files(paths, k)]
        return statistics.mean(sizes) if sizes else 0.0

    def sample_mean_noteseq_size(paths: pd.Series, k: int) -> float:
        """Mean serialized NoteSequence size in MiB; 0.0 if none usable."""
        sizes = []
        for p in sample_existing_files(paths, k):
            try:
                ns = note_seq.midi_file_to_note_sequence(str(p))
                sizes.append(len(ns.SerializeToString()) / mib)
            except Exception:
                # Deliberately best-effort: an unparsable MIDI file is left
                # out of the sample instead of aborting the whole estimate.
                continue
        return statistics.mean(sizes) if sizes else 0.0

    # --- 3/4) Temperature-weighted averages ---------------------------------
    avg_shard_mb = 0.0
    avg_midi_mb = 0.0
    for dataset, weight in weights.items():
        subset = manifest[manifest["dataset"] == dataset]
        # Only rows stored as "per_track" contribute shard paths.
        shard_mean = sample_mean_file_size(
            subset.loc[subset["chunk_storage"] == "per_track", "chunk_shard_path"],
            shard_samples_per_dataset,
        )
        midi_mean = sample_mean_noteseq_size(
            subset["midi_path"],
            midi_samples_per_dataset,
        )
        avg_shard_mb += weight * shard_mean
        avg_midi_mb += weight * midi_mean
    print(f"Estimated average shard size: {avg_shard_mb:.2f} MiB")
    print(f"Estimated average MIDI size: {avg_midi_mb:.2f} MiB")

    # --- 5) Cache footprint per worker --------------------------------------
    # Both cache sizes get the same None/negative clamp.
    shard_cache_per = max(0, shard_cache_size or 0) * avg_shard_mb
    midi_cache_per = max(0, max_midi_cache_size or 0) * avg_midi_mb
    print(f"Estimated shard cache per worker: {shard_cache_per:.2f} MiB")
    print(f"Estimated MIDI cache per worker: {midi_cache_per:.2f} MiB")

    # --- 6) Cache total across workers --------------------------------------
    # Assumes torch DataLoader semantics: with num_workers=0 the main process
    # does the loading and so still holds one set of caches -- never
    # multiply the cache footprint by zero.
    cache_holders = max(1, num_workers)
    cache_total = cache_holders * (shard_cache_per + midi_cache_per)
    print(f"Estimated total cache footprint across workers: {cache_total:.2f} MiB")

    # --- 7) Final estimate ---------------------------------------------------
    peak_ram = tensors_total + cache_total
    print(f"Estimated peak dataloader RAM: {peak_ram:.2f} MiB")

    return {
        "batch_tensor_mb": batch_tensor_mb,
        "inflight_batches": inflight,
        "BatchTensors_total": tensors_total,
        "avg_shard_mb": avg_shard_mb,
        "avg_midi_mb": avg_midi_mb,
        "ShardCache_per_worker": shard_cache_per,
        "MidiCache_per_worker": midi_cache_per,
        "Cache_total": cache_total,
        "PeakRAM_total": peak_ram,
    }




In [18]:
# Knobs for a manual run of the RAM estimator (large-batch configuration).
temperature = 10 / 3
max_midi_cache_size = 2048
shard_cache_size = 128
num_workers = 32
batch_size = 256
prefetch_factor = 8
# Left as the final expression so the notebook displays the returned dict.
estimate_peak_loader_ram(
    temperature=temperature,
    max_midi_cache_size=max_midi_cache_size,
    shard_cache_size=shard_cache_size,
    prefetch_factor=prefetch_factor,
    num_workers=num_workers,
    batch_size=batch_size,
)

Estimated per-batch tensor footprint: 33.00 MiB
Estimated total tensor footprint (including inflight=256 batches): 8481.00 MiB
Dataset weights (temperature=3.3333333333333335): {'slakh_stem': 0.48384983409481314, 'maestro': 0.2780256627908883, 'slakh_full_mix': 0.23812450311429859}
Estimated average shard size: 14.16 MiB
Estimated average MIDI size: 0.16 MiB
Estimated shard cache per worker: 1812.89 MiB
Estimated MIDI cache per worker: 330.07 MiB
Estimated total cache footprint across workers: 68574.65 MiB
Estimated peak dataloader RAM: 77055.65 MiB


{'batch_tensor_mb': 33.0,
 'inflight_batches': 256,
 'BatchTensors_total': 8481.0,
 'avg_shard_mb': 14.163182199859687,
 'avg_midi_mb': 0.1611671600624968,
 'ShardCache_per_worker': 1812.88732158204,
 'MidiCache_per_worker': 330.0703438079934,
 'Cache_total': 68574.64529248107,
 'PeakRAM_total': 77055.64529248107}

In [11]:
from pathlib import Path
import time
import psutil

from data.datasets.loader import build_chunk_dataloader
from configs import load_project_config


def profile_dataloader(prefetch_factor: int, shard_cache_size: int, max_midi_cache_size: int, num_workers: int, batch_size: int, *,
                       max_batches: int = 64, split: str = "train", temperature: float = 10/3, log_dir: Path | str = "dataloader_checkpoints") -> Path:
    """Build the chunk dataloader with the provided knobs, print diagnostics, and
    dump RAM/time checkpoints to `log_dir`.

    Args:
        prefetch_factor: Forwarded to ``build_chunk_dataloader``.
        shard_cache_size: Forwarded to ``build_chunk_dataloader``.
        max_midi_cache_size: Forwarded to ``build_chunk_dataloader``.
        num_workers: Forwarded to ``build_chunk_dataloader``.
        batch_size: Forwarded to ``build_chunk_dataloader``.
        max_batches: Iteration stops once this many batches were recorded.
        split: Forwarded to ``build_chunk_dataloader``.
        temperature: Forwarded to ``build_chunk_dataloader``.
        log_dir: Directory for the checkpoint file; created (including
            missing parents) if needed.

    Returns:
        The path to the written checkpoint file.

    Raises:
        FileNotFoundError: If the chunk manifest named in the project config
            does not exist.
    """
    cfg = load_project_config()
    manifest_path = Path(cfg["paths"]["cache"]["chunk_manifest"])
    if not manifest_path.exists():
        raise FileNotFoundError(f"Chunk manifest not found at {manifest_path}")

    def available_ram_mib() -> float:
        """System-wide available memory in MiB, as reported by psutil."""
        return psutil.virtual_memory().available / 1024**2

    loader_kwargs = dict(
        manifest_path=manifest_path,
        batch_size=batch_size,
        feature_type="waveform",
        load_tokens=True,
        max_examples_per_mix=4,
        temperature=temperature,
        num_workers=num_workers,
        pin_memory=True,
        compute_log_mel_in_collate=False,
        collate_device="cpu",
        seed=0,
        shard_cache_size=shard_cache_size,
        prefetch_factor=prefetch_factor,
        max_midi_cache_size=max_midi_cache_size,
        split=split,
    )

    t_build = time.perf_counter()
    dataloader = build_chunk_dataloader(**loader_kwargs)
    print(f"Built dataloader with {loader_kwargs} in {time.perf_counter() - t_build:.2f}s")

    # Baseline is taken after the build, so every delta below measures what
    # iterating the loader consumed, not what constructing it cost.
    base_ram = available_ram_mib()
    batch_times: list[float] = []
    batch_ram: list[float] = []

    t0 = time.perf_counter()
    for idx, batch in enumerate(dataloader):
        load_time = time.perf_counter() - t0
        # Drop in system-wide available memory since the baseline; other
        # processes on the machine influence this number too.
        delta_ram = base_ram - available_ram_mib()
        batch_times.append(load_time)
        batch_ram.append(delta_ram)

        print(f"[Batch {idx}] load={load_time:.2f}s, tokens={batch['tokens'].shape}, "
              f"waveform={batch['waveform'].shape}, ΔRAM={delta_ram:.2f} MiB")
        if idx + 1 >= max_batches:
            break
        # Restart the clock after the diagnostics so printing is not billed
        # to the next batch's load time.
        t0 = time.perf_counter()

    # An empty loader must not crash the summary: mirror the `default=0`
    # used for min/max instead of dividing by zero.
    avg_time = sum(batch_times) / len(batch_times) if batch_times else 0

    out_dir = Path(log_dir)
    # parents=True so a nested log_dir such as "runs/today/checkpoints" works.
    out_dir.mkdir(parents=True, exist_ok=True)
    out_path = out_dir / f"dataloader_checkpoints_bs{batch_size}_nw{num_workers}_scs{shard_cache_size}_pr{prefetch_factor}.txt"
    with out_path.open("w", encoding="utf-8") as fp:
        fp.write(f"Loader kwargs: {loader_kwargs}\n")
        for i, (ram, t) in enumerate(zip(batch_ram, batch_times)):
            # NOTE: despite the "Available RAM" label, `ram` is the RAM
            # consumed since the baseline (the ΔRAM printed above). The label
            # is kept so existing checkpoint files stay comparable.
            fp.write(f"Batch {i}: Available RAM: {ram:.2f} MiB, Load time: {t:.2f} seconds\n")
        fp.write(f"Peak memory usage during batch loading: {max(batch_ram, default=0):.2f} MiB\n")
        fp.write(f"Min batch load time: {min(batch_times, default=0):.2f} seconds\n")
        fp.write(f"Max batch load time: {max(batch_times, default=0):.2f} seconds\n")
        fp.write(f"Avg batch load time: {avg_time:.2f} seconds\n")
    print(f"Checkpoint saved to {out_path}")
    return out_path


In [25]:
# Knobs for a manual profiling run (large-batch configuration).
max_batches = 128
temperature = 10 / 3
max_midi_cache_size = 2048
shard_cache_size = 128
num_workers = 32
batch_size = 256
prefetch_factor = 16
# Left as the final expression so the notebook displays the returned path.
profile_dataloader(
    max_batches=max_batches,
    temperature=temperature,
    batch_size=batch_size,
    num_workers=num_workers,
    max_midi_cache_size=max_midi_cache_size,
    shard_cache_size=shard_cache_size,
    prefetch_factor=prefetch_factor,
)

Built dataloader with {'manifest_path': PosixPath('cache/chunk_manifest.parquet'), 'batch_size': 256, 'feature_type': 'waveform', 'load_tokens': True, 'max_examples_per_mix': 4, 'temperature': 3.3333333333333335, 'num_workers': 32, 'pin_memory': True, 'compute_log_mel_in_collate': False, 'collate_device': 'cpu', 'seed': 0, 'shard_cache_size': 128, 'prefetch_factor': 16, 'max_midi_cache_size': 2048, 'split': 'train'} in 12.12s
[Batch 0] load=45.71s, tokens=torch.Size([256, 1024]), waveform=torch.Size([256, 32768]), ΔRAM=47719.47 MiB
[Batch 1] load=2.31s, tokens=torch.Size([256, 1024]), waveform=torch.Size([256, 32768]), ΔRAM=47898.15 MiB
[Batch 2] load=0.00s, tokens=torch.Size([256, 1024]), waveform=torch.Size([256, 32768]), ΔRAM=47898.08 MiB
[Batch 3] load=0.41s, tokens=torch.Size([256, 1024]), waveform=torch.Size([256, 32768]), ΔRAM=47838.77 MiB
[Batch 4] load=0.52s, tokens=torch.Size([256, 1024]), waveform=torch.Size([256, 32768]), ΔRAM=47809.16 MiB
[Batch 5] load=0.00s, tokens=torch

PosixPath('dataloader_checkpoints/dataloader_checkpoints_bs256_nw32_scs128_pr16.txt')

In [None]:
from pathlib import Path
from configs import load_project_config
from data.datasets.loader import build_chunk_dataloader
import time
import psutil


def get_available_memory_mib():
    """Report the system's currently available memory, converted to MiB."""
    bytes_per_mib = 1024 ** 2
    memory = psutil.virtual_memory()
    return memory.available / bytes_per_mib

# Flat profiling script: build the chunk dataloader with fixed settings,
# iterate up to `batch_number` batches while recording load time and RAM
# consumption, then write the checkpoints to a text file.
print("--- Script Start ---")


cfg = load_project_config()
manifest_path = Path(cfg["paths"]["cache"]["chunk_manifest"])
if not manifest_path.exists():
    raise FileNotFoundError(f"Chunk manifest not found at {manifest_path}")

start_time = time.time()

loader_kwargs = dict(
    manifest_path=manifest_path,
    batch_size=16,
    feature_type="waveform",
    load_tokens=True,
    max_examples_per_mix=4,
    temperature=0.3,
    num_workers=16,
    pin_memory=True,
    compute_log_mel_in_collate=False,
    collate_device="cpu",
    seed=0,
    shard_cache_size=32,
    prefetch_factor=4,
    max_midi_cache_size=2048,
    split="train",
)

# Sample available RAM BEFORE the build; sampling only afterwards (as this
# script used to) makes the "used RAM after building" figure always ~0.
ram_before_build = get_available_memory_mib()
dataloader = build_chunk_dataloader(**loader_kwargs)
print(f"Building dataloader in {time.time() - start_time:.2f} seconds...")
# 1. Capture the post-build available RAM. It is both the end point of the
#    build measurement and the baseline for the per-batch deltas below.
initial_available_ram = get_available_memory_mib()
checkpoint_1 = ram_before_build - initial_available_ram

print("\n--- DataLoader Checkpoint ---")
print(f"used RAM after building dataloader: {checkpoint_1:.2f} MiB")

batch_checkpoints_ram = []
batch_checkpoints_time = []

t0 = time.time()
batch_number = 64  # stop after this many batches
for idx, batch in enumerate(dataloader):
    print(f"Batch {idx+1}")
    print("batch keys:", batch.keys())
    print("  waveform:", batch["waveform"].shape)
    print("  tokens:", batch["tokens"].shape)
    # token list sample examples (first 10 tokens of the first sample in the batch)
    print("  token sample (first 10 tokens of first sample):", batch["tokens"][0][:10])
    # token list sample examples (last 10 tokens of the first sample in the batch)
    print("  token sample (last 10 tokens of first sample):", batch["tokens"][0][-10:])
    print("token mask:", batch["token_mask"].shape)
    print("  metadata :", batch["metadata"])
    # Measured after the diagnostics above, so the recorded time includes
    # the cost of printing them.
    t1 = time.time() - t0
    print(f"Batch {idx+1} loaded in {t1:.2f} seconds")
    batch_checkpoints_time.append(t1)
    # RAM consumed since the post-build baseline (system-wide, so other
    # processes influence it too).
    batch_checkpoints_ram.append(initial_available_ram - get_available_memory_mib())
    t0 = time.time()
    if idx >= batch_number - 1:
        break

total_time = time.time() - start_time
print(f"total time: {total_time:.2f} seconds")
# Save the time and RAM checkpoints, together with the loader kwargs, under
# the "dataloader_checkpoints" folder. The file name encodes batch size,
# worker count, shard cache size and prefetch factor.
output_dir = Path("dataloader_checkpoints")
output_dir.mkdir(exist_ok=True)
output_path = output_dir / f"dataloader_checkpoints_bs{loader_kwargs['batch_size']}_nw{loader_kwargs['num_workers']}_scs{loader_kwargs['shard_cache_size']}_pr{loader_kwargs['prefetch_factor']}.txt"
with output_path.open("w") as f:
    f.write(f"Loader kwargs: {loader_kwargs}\n")
    for i, (ram, t) in enumerate(zip(batch_checkpoints_ram, batch_checkpoints_time)):
        # NOTE: despite the "Available RAM" label, `ram` is the RAM consumed
        # since the baseline; the label is kept so existing files stay
        # comparable.
        f.write(f"Batch {i}: Available RAM: {ram:.2f} MiB, Load time: {t:.2f} seconds\n")
    # Summary statistics; the fallbacks keep an empty dataloader from
    # crashing the report with ValueError / ZeroDivisionError.
    peak_memory = max(batch_checkpoints_ram, default=0)
    min_time = min(batch_checkpoints_time, default=0)
    max_time = max(batch_checkpoints_time, default=0)
    if batch_checkpoints_time:
        avg_time = sum(batch_checkpoints_time) / len(batch_checkpoints_time)
    else:
        avg_time = 0
    f.write(f"Peak memory usage during batch loading: {peak_memory:.2f} MiB\n")
    f.write(f"Min batch load time: {min_time:.2f} seconds\n")
    f.write(f"Max batch load time: {max_time:.2f} seconds\n")
    f.write(f"Avg batch load time: {avg_time:.2f} seconds\n")
print(f"Checkpoints saved to {output_path}")