In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import time
import os

# Transformaciones CIFAR-10
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465),
                         (0.2023, 0.1994, 0.2010))
])

# Dataset
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

# Dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Simulación de red simple para prueba
import torch.nn as nn
import torch.nn.functional as F

class DummyNet(nn.Module):
    def __init__(self):
        super(DummyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.fc1 = nn.Linear(16 * 32 * 32, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        return x

# Crear modelo
model = DummyNet().to(device)

# Benchmark
batch_sizes = [32, 64, 128, 256]
worker_counts = [0, 2, 4, 8, 12, 16]

results = []

print("Benchmarking combinaciones...\n")

for batch_size in batch_sizes:
    for num_workers in worker_counts:
        dataloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                                num_workers=num_workers)

        start = time.time()

        for i, data in enumerate(dataloader):
            if i >= 100:  # limitar a 100 iteraciones para velocidad
                break
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)

        end = time.time()
        duration = end - start
        results.append((batch_size, num_workers, duration))
        print(f"Batch size: {batch_size:3d}, Workers: {num_workers:2d}, Tiempo: {duration:.2f} s")

# Ordenar por rendimiento
results.sort(key=lambda x: x[2])
print("\n🥇 Mejores combinaciones (por menor tiempo):")
for b, w, t in results[:5]:
    print(f"- Batch: {b}, Workers: {w}, Tiempo: {t:.2f}s")


Benchmarking combinaciones...

Batch size:  32, Workers:  0, Tiempo: 0.47 s
Batch size:  32, Workers:  2, Tiempo: 5.47 s
Batch size:  32, Workers:  4, Tiempo: 10.10 s
Batch size:  32, Workers:  8, Tiempo: 19.25 s
Batch size:  32, Workers: 12, Tiempo: 28.27 s
Batch size:  32, Workers: 16, Tiempo: 38.76 s
Batch size:  64, Workers:  0, Tiempo: 0.67 s
Batch size:  64, Workers:  2, Tiempo: 5.55 s
Batch size:  64, Workers:  4, Tiempo: 9.84 s
Batch size:  64, Workers:  8, Tiempo: 19.18 s
Batch size:  64, Workers: 12, Tiempo: 28.79 s
Batch size:  64, Workers: 16, Tiempo: 38.24 s
Batch size: 128, Workers:  0, Tiempo: 1.40 s
Batch size: 128, Workers:  2, Tiempo: 5.72 s
Batch size: 128, Workers:  4, Tiempo: 10.24 s
Batch size: 128, Workers:  8, Tiempo: 19.13 s
Batch size: 128, Workers: 12, Tiempo: 29.45 s
Batch size: 128, Workers: 16, Tiempo: 38.30 s
Batch size: 256, Workers:  0, Tiempo: 2.40 s
Batch size: 256, Workers:  2, Tiempo: 6.36 s
Batch size: 256, Workers:  4, Tiempo: 11.17 s
Batch size: 