In [None]:
from utils.data_loader import SRDataset
from torch.utils.data import DataLoader
import time

# Dataset root directory handed to SRDataset inside benchmark().
root = "data/preprocessed"

def benchmark(num_workers, num_batches=20):
    """Time how long a fresh DataLoader takes to yield its first batches.

    A new ``SRDataset`` (train split under ``root``) and a new
    ``DataLoader`` (batch size 64, ``pin_memory=True``) are built on every
    call. The timed region starts before the loader's iterator is created,
    so any work done when the iterator is created is part of the result.

    Args:
        num_workers: Value forwarded to ``DataLoader(num_workers=...)``.
        num_batches: Maximum number of batches to pull from the loader.
            Fewer are read if the loader is exhausted first.

    Returns:
        Elapsed time in seconds, as a float.
    """
    from itertools import islice

    dataset = SRDataset(root, split="train")
    loader = DataLoader(dataset, batch_size=64, num_workers=num_workers, pin_memory=True)

    # perf_counter() is monotonic and meant for measuring short durations;
    # time.time() can jump when the system clock is adjusted.
    start = time.perf_counter()
    # islice pulls exactly `num_batches` items. The earlier
    # `enumerate` + `if i == 20: break` form had already fetched a 21st
    # batch by the time the break was reached.
    for _batch in islice(loader, num_batches):
        pass
    return time.perf_counter() - start

# Run the benchmark once per candidate worker count and print each timing.
worker_counts = [4, 8, 12, 16]
for w in worker_counts:
    t = benchmark(w)
    print(f"num_workers={w} → {t:.3f} sec")


In [None]:
import time
import numpy as np
from torch.utils.data import DataLoader
from utils.data_loader import SRDataset  # ou SRDatasetAug
import torch

# Dataset root directory handed to SRDataset inside benchmark().
root = "data/preprocessed"

def benchmark(num_workers, repeats=20, batches_to_load=30):
    """Measure the time a DataLoader needs to yield a fixed number of batches.

    For each repetition a new ``SRDataset`` (train split under ``root``)
    and a new ``DataLoader`` (batch size 128, ``pin_memory=True``) are
    built, then up to ``batches_to_load`` batches are pulled while timing.
    The timed region starts before the loader's iterator is created.

    Args:
        num_workers: Value forwarded to ``DataLoader(num_workers=...)``.
        repeats: Number of timed repetitions to aggregate.
        batches_to_load: Maximum number of batches read per repetition.
            Fewer are read if the loader is exhausted first.

    Returns:
        A ``(mean, std)`` tuple of the per-repetition durations in seconds,
        computed with ``np.mean`` and ``np.std``.

    Raises:
        ValueError: If ``repeats`` is less than 1, since there would be no
            samples to average.
    """
    from itertools import islice

    if repeats < 1:
        raise ValueError("repeats must be at least 1")

    times = []

    for _ in range(repeats):
        dataset = SRDataset(root=root, split="train")
        loader = DataLoader(
            dataset,
            batch_size=128,
            num_workers=num_workers,
            pin_memory=True,
        )

        # NOTE: empty_cache() only releases unused GPU memory held by
        # PyTorch's caching allocator. It does not drop the OS file cache,
        # so repetitions after the first may still see warm disk reads.
        torch.cuda.empty_cache()

        # perf_counter() is monotonic and meant for measuring durations;
        # time.time() can jump when the system clock is adjusted.
        start = time.perf_counter()

        # islice pulls exactly `batches_to_load` items. The earlier
        # `if i >= batches_to_load: break` check ran after one extra batch
        # had already been fetched.
        for _batch in islice(loader, batches_to_load):
            pass

        times.append(time.perf_counter() - start)

    return np.mean(times), np.std(times)


# ---- Worker counts to benchmark ----
workers_list = [8, 12]

print("===== Benchmark num_workers =====")
for w in workers_list:
    # benchmark() returns the mean and standard deviation of the timings.
    stats = benchmark(w)
    mean, std = stats
    print(f"num_workers={w:2d} → mean={mean:.3f} sec | std={std:.3f}")
