# SharedMemoryDataLoader Demo & Benchmark

This notebook demonstrates how to use the `SharedMemoryDataLoader` and runs a short benchmark to measure its throughput (MB/s, samples/s) and per-batch latency.


In [ ]:
import time
import torch
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display, clear_output

# Use the new factory pattern!
from data_loader import actvs_loader_from_test_config

In [ ]:
# Build the loader with the factory pattern: the factory supplies the test
# configuration, so no manual setup is needed here.

batch_size = 1000  # Smaller batch size for demo

# The values below are hard-coded literals describing the test config; they
# are not read back from the loader.
print("📋 Using test configuration with factory pattern")
print(f"   Batch size: {batch_size}")
print("   Buffer size: 100,000 samples")
print("   Generation batch size: 2")

# NOTE(review): every later cell uses `loader`, but the original cell never
# created it. Assumes the factory accepts a `batch_size` keyword — confirm
# against data_loader.actvs_loader_from_test_config.
loader = actvs_loader_from_test_config(batch_size=batch_size)


In [None]:
# Wait a bit for the loader's buffer to fill before the first read.
# NOTE(review): a fixed 10 s sleep is a heuristic; adjust if the buffer
# fills faster or slower on your machine.
time.sleep(10)

print("📦 Loading first batch...")
try:
    # Only the fetch itself is guarded, so display bugs below are not
    # misreported as load failures.
    batch = next(iter(loader))
except Exception as e:
    print(f"❌ Failed to load batch: {type(e).__name__}: {e}")
else:
    print("✅ Successfully loaded batch!")
    print(f"   Shape: {batch.shape}")
    print(f"   Data type: {batch.dtype}")
    print(f"   Device: {batch.device}")
    print(f"   Memory usage: {batch.numel() * batch.element_size() / 1024 / 1024:.2f} MB")

    # Show some sample data (the slice below assumes a 2-D batch).
    print("\n📊 Sample data (first 5 sequences, first 10 tokens):")
    print(batch[:5, :10])


In [None]:
def run_mini_benchmark(loader, duration=10, batch_size=1000):
    """Pull batches from *loader* for *duration* seconds and report throughput.

    A single iterator over ``loader`` is created and reused for every batch,
    so the timings measure batch retrieval rather than iterator construction.
    The first batch is fetched before timing starts (warm-up) and is used to
    report the per-batch memory footprint.

    Args:
        loader: Iterable yielding tensor-like batches that expose ``numel()``
            and ``element_size()`` (e.g. ``torch.Tensor``).
        duration: Wall-clock seconds to keep fetching batches.
        batch_size: Nominal samples per batch; used for display and for the
            sample-count / samples-per-second figures.

    Returns:
        dict with keys ``duration``, ``batches``, ``total_samples``,
        ``total_mb``, ``mb_per_second``, ``samples_per_second``,
        ``avg_batch_time_ms`` (NaN when no batch was timed) and
        ``batch_times`` (per-batch latency in milliseconds).

    Raises:
        ValueError: If ``loader`` yields no batches at all.
    """
    print(f"🏃 Running mini benchmark for {duration} seconds...")

    # Reuse one iterator: calling iter(loader) per batch would rebuild the
    # iterator (and, for some loaders, restart workers) on every step.
    batches = iter(loader)

    # Warm-up batch: reports per-batch memory and is excluded from timing.
    first_batch = next(batches, None)
    if first_batch is None:
        raise ValueError("loader produced no batches")
    bytes_per_batch = first_batch.numel() * first_batch.element_size()
    mb_per_batch = bytes_per_batch / (1024 * 1024)

    print(f"   Batch size: {batch_size} samples")
    print(f"   Memory per batch: {mb_per_batch:.2f} MB")
    print("   Starting benchmark...\n")

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()
    batch_count = 0
    total_bytes = 0
    batch_times = []

    try:
        while time.perf_counter() - start_time < duration:
            batch_start = time.perf_counter()
            batch = next(batches, None)
            batch_time = time.perf_counter() - batch_start
            if batch is None:
                print("   ℹ️  Loader exhausted before the time limit.")
                break

            batch_count += 1
            # Count the bytes actually received (a final batch may be smaller).
            total_bytes += batch.numel() * batch.element_size()
            batch_times.append(batch_time * 1000)  # Convert to ms

            # Print progress every few batches
            if batch_count % 3 == 0:
                elapsed = time.perf_counter() - start_time
                current_mb_s = (total_bytes / (1024 * 1024)) / elapsed
                print(f"   Batch {batch_count}: {batch_time*1000:.1f}ms, Running avg: {current_mb_s:.1f} MB/s")

    except KeyboardInterrupt:
        # Stopping the cell should still yield the partial results.
        print("   ⚠️  Benchmark interrupted by user")
    except Exception as e:
        print(f"   ⚠️  Benchmark interrupted: {e}")

    # Calculate final results
    total_time = time.perf_counter() - start_time
    total_mb = total_bytes / (1024 * 1024)
    mb_per_second = total_mb / total_time
    samples_per_second = (batch_count * batch_size) / total_time
    # np.mean([]) warns and returns NaN; return NaN explicitly instead.
    avg_batch_time = float(np.mean(batch_times)) if batch_times else float("nan")

    return {
        'duration': total_time,
        'batches': batch_count,
        'total_samples': batch_count * batch_size,
        'total_mb': total_mb,
        'mb_per_second': mb_per_second,
        'samples_per_second': samples_per_second,
        'avg_batch_time_ms': avg_batch_time,
        'batch_times': batch_times
    }

# Benchmark the loader for 10 seconds at the demo batch size.
results = run_mini_benchmark(loader, batch_size=batch_size, duration=10)
