# Day 2

## Dataloader
Desiderata:
* Work with any dataset implementing `__len__` and `__getitem__(i)`
* Deterministic shuffling. Reproducible via seed and per-epoch control, no data leakage
* Batching controls: batch_size, drop_last
* Collation: stack NumPy arrays and recursively handle tuples/lists/dicts.
* Length semantics: len(dataloader) returns number of batches
* Memory-friendly: no copying of the underlying dataset (only the per-batch stacking done by collation)
* Composability: (TBD)
* Per-epoch reseeding
* Clear errors: shape/length mismatches fail loudly
* Samplers, prefetch...

In [8]:
from dataloader import Dataloader, DatasetProtocol
import numpy as np
from typing import List, Dict, Tuple

In [9]:
class ArrayDataset(DatasetProtocol):
    """Map-style dataset over a pair of NumPy arrays: features and labels.

    Item ``i`` is the pair ``(X[i], y[i])``. Both arrays are stored by
    reference (no copy is made), so each item is whatever NumPy indexing
    on the originals returns -- e.g. for a 1-D ``y`` the label is a NumPy
    scalar rather than an array.

    Args:
        X: Feature array; its first axis indexes samples.
        y: Label array; its first axis must have the same length as ``X``.

    Raises:
        ValueError: If ``len(X) != len(y)``.
    """

    def __init__(self, X: np.ndarray, y: np.ndarray) -> None:
        # Fail loudly up front: otherwise a shorter y would surface as an
        # IndexError mid-epoch, and a longer y would have its extra labels
        # silently ignored (length is taken from X).
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length "
                f"(got len(X)={len(X)}, len(y)={len(y)})."
            )
        self.X = X
        self.y = y

    def __len__(self) -> int:
        """Return the number of samples (the length of ``X``)."""
        return len(self.X)

    def __getitem__(self, i: int) -> Tuple[np.ndarray, np.ndarray]:
        """Return the ``i``-th ``(features, label)`` pair."""
        return (self.X[i], self.y[i])
    
# Toy binary-classification data: 10 samples with 2 float32 features each
# (values 0..19); a sample is labelled 1 when its features sum to more than 10.
X = np.arange(20, dtype=np.float32).reshape(10, 2)
y = (X.sum(axis=1) > 10).astype(np.int64)

ds = ArrayDataset(X, y)
loader = Dataloader(
    ds,
    seed=42,
    shuffle=True,
    batch_size=4,
    drop_last=False,
)

# Two passes over the data. set_epoch is called before each pass; the
# captured output shows a different shuffle order per epoch, presumably
# because the loader mixes the epoch into its seed -- confirm in dataloader.py.
for epoch in range(2):
    loader.set_epoch(epoch)
    print(f"\n[Dataloader] Epoch {epoch}")
    for batch in loader:
        # Collated batch: (stacked features, stacked labels).
        xb, yb = batch
        print(f"[Dataloader]\n[Batch] X: \n {xb} \n[Batch] y:\n {yb} \n")




[Dataloader] Epoch 0
[Dataloader]
[Batch] X: 
 [[12. 13.]
 [ 2.  3.]
 [ 0.  1.]
 [ 8.  9.]] 
[Batch] y:
 [1 0 0 1] 

[Dataloader]
[Batch] X: 
 [[10. 11.]
 [ 6.  7.]
 [14. 15.]
 [18. 19.]] 
[Batch] y:
 [1 1 1 1] 

[Dataloader]
[Batch] X: 
 [[16. 17.]
 [ 4.  5.]] 
[Batch] y:
 [1 0] 


[Dataloader] Epoch 1
[Dataloader]
[Batch] X: 
 [[12. 13.]
 [18. 19.]
 [ 4.  5.]
 [10. 11.]] 
[Batch] y:
 [1 1 0 1] 

[Dataloader]
[Batch] X: 
 [[16. 17.]
 [ 6.  7.]
 [14. 15.]
 [ 0.  1.]] 
[Batch] y:
 [1 1 1 0] 

[Dataloader]
[Batch] X: 
 [[8. 9.]
 [2. 3.]] 
[Batch] y:
 [1 0] 

