# IterableDataset — Sharding with num_workers and pin_memory

Demonstrates:
- How to write an `IterableDataset` with correct sharding
- Why `num_workers > 0` without sharding causes duplicate data
- How `pin_memory=True` + `non_blocking=True` fits in

In [None]:
import torch
from torch.utils.data import IterableDataset, DataLoader

## 1. The Dataset

Simulates a streaming dataset (e.g. rows from a CSV or log file).
Each sample is a `(features, label)` pair generated from a simple linear relationship.

In [None]:
class StreamingDataset(IterableDataset):
    """
    In-memory stand-in for a streamed source holding 'total_samples' rows.

    Every row pairs a feature vector drawn with torch.randn from a seeded
    generator with a label equal to the sum of that vector. When iterated
    inside a DataLoader worker, only that worker's contiguous slice of the
    rows is produced; otherwise all rows are produced.
    """
    def __init__(self, total_samples: int, n_features: int = 4, seed: int = 42):
        super().__init__()
        self.total_samples = total_samples
        self.n_features = n_features
        self.seed = seed

        # Materialize every row once, up front (plays the role of a fixed file on disk)
        generator = torch.Generator()
        generator.manual_seed(seed)
        self.features = torch.randn(total_samples, n_features, generator=generator)
        self.labels   = self.features.sum(dim=1, keepdim=True)

    def __iter__(self):
        info = torch.utils.data.get_worker_info()

        if info is not None:
            # ── Running inside a DataLoader worker (num_workers > 0) ───────
            # Split the rows into contiguous slices, one per worker.
            #
            # Example: 8 samples, 2 workers
            #   worker 0 → indices 0..3
            #   worker 1 → indices 4..7
            #
            # per_worker is the floor share; the 'extra' leftover rows go
            # one each to the workers with the lowest ids.
            per_worker, extra = divmod(self.total_samples, info.num_workers)
            rank = info.id

            # Each lower-ranked worker that received an extra row shifts
            # this worker's slice right by one.
            first = rank * per_worker + min(rank, extra)
            last  = first + per_worker + (1 if rank < extra else 0)
        else:
            # ── Main process (num_workers=0): the slice is the whole table ─
            first, last = 0, self.total_samples

        row = first
        while row < last:
            yield self.features[row], self.labels[row]
            row += 1

## 2. Verify Sharding is Correct

Check: union of all worker shards == full dataset, with no overlaps.

In [None]:
TOTAL_SAMPLES = 20
N_FEATURES    = 4
BATCH_SIZE    = 4
NUM_WORKERS   = 3   # intentionally not a divisor of 20 to test remainder logic

dataset = StreamingDataset(total_samples=TOTAL_SAMPLES, n_features=N_FEATURES)

# Use num_workers > 0 with pin_memory
# Note: pinning only benefits CUDA transfers; without a CUDA device PyTorch
# may emit a warning and skip pinning, but iteration still works
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,         # pins tensors in CPU RAM for faster GPU transfer
)

all_labels = []
for features, labels in loader:
    # flatten() rather than squeeze(): a batch holding a single sample must
    # stay 1-D so that tolist() returns a list extend() can consume
    all_labels.extend(labels.flatten().tolist())

# Ground truth: the labels held by the dataset object in this process
expected_labels = dataset.labels.flatten().tolist()
n_unique = len(set(round(x, 4) for x in all_labels))

print(f"Total samples seen : {len(all_labels)}")
print(f"Unique samples seen: {n_unique}")
print()
# A matching count alone would hide a gap that is offset by a duplicate,
# so compare the actual values (order-independent) against the ground truth
if sorted(all_labels) == sorted(expected_labels):
    print("Sharding OK — every sample seen exactly once")
else:
    print(f"Problem — expected {TOTAL_SAMPLES} distinct samples, "
          f"got {len(all_labels)} total ({n_unique} unique)")

## 3. Show What Goes Wrong Without Sharding

In [None]:
class BrokenStreamingDataset(IterableDataset):
    """Counter-example: same row generation, but __iter__ never shards by worker."""
    def __init__(self, total_samples: int, n_features: int = 4, seed: int = 42):
        super().__init__()
        self.total_samples = total_samples
        generator = torch.Generator()
        generator.manual_seed(seed)
        self.features = torch.randn(total_samples, n_features, generator=generator)
        self.labels   = self.features.sum(dim=1, keepdim=True)

    def __iter__(self):
        # Deliberately skips get_worker_info(): every call walks the full
        # range of rows, whichever process it runs in
        n_rows = self.total_samples
        row = 0
        while row < n_rows:
            yield self.features[row], self.labels[row]
            row += 1


# Same loader settings as the sharded run; only the dataset class differs
broken_loader = DataLoader(
    BrokenStreamingDataset(total_samples=TOTAL_SAMPLES),
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

broken_labels = []
for features, labels in broken_loader:
    # flatten() rather than squeeze(): a batch holding a single sample must
    # stay 1-D so that tolist() returns a list extend() can consume
    broken_labels.extend(labels.flatten().tolist())

print(f"Total samples seen : {len(broken_labels)}")
print(f"Unique samples seen: {len(set(round(x, 4) for x in broken_labels))}")
print()
# Each worker streamed the whole dataset, so the total is a multiple of TOTAL_SAMPLES
print(f"Expected {TOTAL_SAMPLES}, got {len(broken_labels)} — "
      f"{len(broken_labels) // TOTAL_SAMPLES}x duplication")

## 4. Full Training Loop

Shows the complete pattern: sharded `IterableDataset` + `pin_memory` + `non_blocking` transfer.

In [None]:
import torch.nn as nn

# ── Device ────────────────────────────────────────────────────────────────────
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")

# Pinned (page-locked) host memory only speeds up copies to a CUDA device, so
# request it — and the matching non_blocking copy — only when CUDA is selected.
use_pinned = device.type == "cuda"

# ── Model ─────────────────────────────────────────────────────────────────────
model = nn.Sequential(
    nn.Linear(N_FEATURES, 16),
    nn.ReLU(),
    nn.Linear(16, 1)
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# ── DataLoader (sharded IterableDataset + pin_memory) ─────────────────────────
train_dataset = StreamingDataset(total_samples=1000, n_features=N_FEATURES)

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    num_workers=2,
    pin_memory=use_pinned,     # pins CPU tensors for fast DMA transfer to GPU
    persistent_workers=True,   # keeps worker processes alive between epochs
)

# ── Training loop ─────────────────────────────────────────────────────────────
EPOCHS = 3

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    n_batches  = 0

    for features, labels in train_loader:
        # non_blocking=True pairs with pin_memory=True: from pinned memory the
        # host-to-device copy is asynchronous with respect to the CPU, which
        # can move on instead of waiting for the copy to finish
        features = features.to(device, non_blocking=use_pinned)
        labels   = labels.to(device, non_blocking=use_pinned)

        optimizer.zero_grad()
        preds = model(features)
        loss  = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches  += 1

    # An empty loader would otherwise raise ZeroDivisionError here
    avg_loss = total_loss / n_batches if n_batches else float("nan")
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch+1}/{EPOCHS} | loss={avg_loss:.4f} | lr={current_lr:.2e} | batches={n_batches}")

## Summary

```
IterableDataset __iter__
  └── get_worker_info()          # None → single process, else → shard
        ├── worker_info.id       # which worker am I? (0, 1, 2, ...)
        └── worker_info.num_workers

DataLoader
  ├── num_workers=2              # 2 subprocesses load data in parallel
  ├── pin_memory=True            # lock CPU tensors in RAM for fast DMA to GPU
  └── persistent_workers=True   # keep workers alive between epochs (avoids respawn cost)

Training loop
  └── tensor.to(device, non_blocking=True)   # async host→device copy (needs pinned source memory, CUDA)
```

### Shard correctness rule
```
Union of all shards    = full dataset   (no gaps)
Intersection of shards = empty          (no overlaps)
```