# LRU Cache - Look-Ahead

In [8]:
# Install the local chinidataset package (one directory up) in editable mode.
!uv pip install -e ../
# Uncomment if the `datasets` package imported in the next cell is not installed.
# !uv pip install datasets

[2mUsing Python 3.11.14 environment at: /Users/ariff.a/Documents/GitHub/StreamingDataset/.venv[0m
[2K[2mResolved [1m22 packages[0m [2min 46ms[0m[0m                                         [0m
[2K   [36m[1mBuilding[0m[39m chinidataset[2m @ file:///Users/ariff.a/Documents/GitHub/StreamingDatas[0m
[2K[1A   [36m[1mBuilding[0m[39m chinidataset[2m @ file:///Users/ariff.a/Documents/GitHub/StreamingDatas[0m
[2K[1A   [36m[1mBuilding[0m[39m chinidataset[2m @ file:///Users/ariff.a/Documents/GitHub/StreamingDatas[0m
[2K[1A   [36m[1mBuilding[0m[39m chinidataset[2m @ file:///Users/ariff.a/Documents/GitHub/StreamingDatas[0m
[2K[1A   [36m[1mBuilding[0m[39m chinidataset[2m @ file:///Users/ariff.a/Documents/GitHub/StreamingDatas[0m
[2K[1A   [36m[1mBuilding[0m[39m chinidataset[2m @ file:///Users/ariff.a/Documents/GitHub/StreamingDatas[0m
[2K[1A      [32m[1mBuilt[0m[39m chinidataset[2m @ file:///Users/ariff.a/Documents/GitHub/StreamingDatas[

In [2]:
## 1. Load Wikipedia & build word vocabulary

import time
import numpy as np
from collections import Counter
from pathlib import Path
from tqdm import tqdm

from datasets import load_dataset

PARQUET_URL = "hf://datasets/wikimedia/wikipedia/20231101.en/train-00000-of-00041.parquet"
VOCAB_SIZE = 50_000  # total vocabulary entries, including the "<unk>" slot

print("Loading Wikipedia EN shard 0 ...")
wiki = load_dataset("parquet", data_files=PARQUET_URL, split="train")
print(f"  {len(wiki):,} articles")

# Count whitespace-separated words over every article in the shard.
print("Building word vocabulary ...")
counter = Counter()
for row in tqdm(wiki, desc="  counting words"):
    counter.update(row["text"].split())

# Id 0 is reserved for unknown words; ids 1..VOCAB_SIZE-1 go to the most
# frequent words in descending order of frequency.
UNK_ID = 0
VOCAB = {"<unk>": UNK_ID}
# A generator keeps the loop variables out of the notebook namespace
# (in particular it does not rebind the interactive `_`).
VOCAB.update(
    (word, idx)
    for idx, (word, _count) in enumerate(counter.most_common(VOCAB_SIZE - 1), start=1)
)
print(f"  {len(VOCAB):,} words in vocab")

def tokenize(text):
    """Convert ``text`` into an array of vocabulary ids.

    The text is split on whitespace; each word is looked up in the
    module-level ``VOCAB`` mapping, and words that are not present are
    mapped to ``UNK_ID``.

    Returns:
        A 1-D ``np.uint32`` array with one id per word (empty for empty text).
    """
    token_ids = []
    for word in text.split():
        token_ids.append(VOCAB.get(word, UNK_ID))
    return np.array(token_ids, dtype=np.uint32)

  from .autonotebook import tqdm as notebook_tqdm


Loading Wikipedia EN shard 0 ...
  156,289 articles
Building word vocabulary ...


  counting words: 100%|██████████| 156289/156289 [00:14<00:00, 10909.57it/s]


  50,000 words in vocab


In [3]:
## 2. Write to ChiniDataset shards

from chinidataset import ParquetWriter

DATA_DIR = Path("./demo_look_ahead_data")

# Column name -> type string handed to ParquetWriter.
columns = {"input_ids": "uint32[]", "labels": "uint32[]"}

with ParquetWriter(out=str(DATA_DIR), columns=columns, exist_ok=True) as writer:
    for article in tqdm(wiki, desc="Writing"):
        token_ids = tokenize(article["text"])
        # The same id array is stored under both column names.
        writer.write({"input_ids": token_ids, "labels": token_ids})

# Count the files named "shard.*" that ended up in the output directory.
n_shards = sum(1 for _shard_path in DATA_DIR.glob("shard.*"))
print(f"Done! {len(wiki):,} samples across {n_shards} shards")

Directory /Users/ariff.a/Documents/GitHub/StreamingDataset/examples/demo_look_ahead_data exists; removing contents.
Writing: 100%|██████████| 156289/156289 [00:14<00:00, 11106.75it/s]

Done! 156,289 samples across 14 shards





In [7]:
## 3. Benchmark: `look_ahead=0` vs `look_ahead=2` vs `look_ahead=4`
from chinidataset import StreamingDataset
import gc

REPEATS = 4  # 1 warmup + 3 measured
LOOK_AHEADS = [0, 2, 4]  # first entry is the baseline for the speedup column
assert REPEATS >= 2, "need at least one measured run after the warmup"
results = {}

# Time a full sequential pass over the dataset for every look_ahead setting.
for la in LOOK_AHEADS:
    times = []
    for run in range(REPEATS):
        # A fresh dataset per run so no reader state carries over between runs.
        ds = StreamingDataset(
            local=str(DATA_DIR),
            look_ahead=la,
            max_open_shards=max(8, la + 2),
            shuffle=False,
        )

        t0 = time.perf_counter()
        count = 0
        for sample in ds:
            _ = sample["input_ids"]  # touch the column so the access is part of the timing
            count += 1
        elapsed = time.perf_counter() - t0

        label = "warmup" if run == 0 else f"run {run}"
        print(f"  look_ahead={la}  {label}: {count:,} samples in {elapsed:.3f}s "
              f"({count / elapsed:,.0f} samples/s)")
        times.append(elapsed)

        del ds
        gc.collect()

    measured = times[1:]  # skip warmup
    results[la] = {
        "avg": np.mean(measured),
        "best": min(measured),
        "samples": count,
    }
    print()

# Summary: one row per look_ahead value, speedup relative to the first entry
# of LOOK_AHEADS. Iterating the same list as the benchmark loop guarantees
# that every row has a matching entry in `results`.
baseline = results[LOOK_AHEADS[0]]
n = baseline["samples"]
print("=" * 60)
print(f"  {'look_ahead':>12s}  {'Avg time':>10s}  {'Avg samp/s':>12s}  {'Speedup':>8s}")
print(f"  {'-' * 50}")
for la in LOOK_AHEADS:
    r = results[la]
    sp = baseline["avg"] / r["avg"]
    print(f"  {la:>12d}  {r['avg']:>9.3f}s  {n / r['avg']:>11,.0f}/s  {sp:>7.2f}x")
print("=" * 60)


  look_ahead=0  warmup: 156,289 samples in 1.104s (141,517 samples/s)
  look_ahead=0  run 1: 156,289 samples in 0.811s (192,691 samples/s)
  look_ahead=0  run 2: 156,289 samples in 0.749s (208,748 samples/s)
  look_ahead=0  run 3: 156,289 samples in 0.753s (207,629 samples/s)

  look_ahead=2  warmup: 156,289 samples in 0.566s (276,081 samples/s)
  look_ahead=2  run 1: 156,289 samples in 0.567s (275,476 samples/s)
  look_ahead=2  run 2: 156,289 samples in 0.560s (278,867 samples/s)
  look_ahead=2  run 3: 156,289 samples in 0.588s (265,857 samples/s)

  look_ahead=4  warmup: 156,289 samples in 0.564s (277,271 samples/s)
  look_ahead=4  run 1: 156,289 samples in 0.589s (265,534 samples/s)
  look_ahead=4  run 2: 156,289 samples in 0.593s (263,701 samples/s)
  look_ahead=4  run 3: 156,289 samples in 0.669s (233,614 samples/s)

    look_ahead    Avg time    Avg samp/s   Speedup
  --------------------------------------------------
             0      0.771s      202,752/s     1.00x
          

In [8]:
## 4. Inspect LRU cache behavior

MAX_OPEN_SHARDS = 3  # small cache to trigger evictions

ds = StreamingDataset(
    local=str(DATA_DIR),
    look_ahead=0,          # disable look-ahead so we can see pure LRU
    max_open_shards=MAX_OPEN_SHARDS,
    shuffle=False,
)

# NOTE: `_samples_per_shard` and `_readers` are private attributes, read here
# for demonstration only. `_readers` appears to be keyed by shard index
# (see the printed output) -- confirm against the StreamingDataset source.
samples_per_shard = ds._samples_per_shard

# Access first sample of each shard
offset = 0
for shard_idx in range(ds.num_shards):
    _ = ds[offset]
    cached = list(ds._readers.keys())
    # Back the closing message with a check instead of relying on eyeballing.
    assert len(cached) <= MAX_OPEN_SHARDS, (
        f"cache holds {len(cached)} shards, limit is {MAX_OPEN_SHARDS}"
    )
    print(f"Accessed shard {shard_idx:>2d} -> cache: {cached}  (size={len(cached)})")
    offset += samples_per_shard[shard_idx]

print(f"\nmax_open_shards={MAX_OPEN_SHARDS}, so only {MAX_OPEN_SHARDS} most recent shards stay in memory.")
print("Oldest shard is evicted each time a new one is loaded.")

Accessed shard  0 -> cache: [0]  (size=1)
Accessed shard  1 -> cache: [0, 1]  (size=2)
Accessed shard  2 -> cache: [0, 1, 2]  (size=3)
Accessed shard  3 -> cache: [1, 2, 3]  (size=3)
Accessed shard  4 -> cache: [2, 3, 4]  (size=3)
Accessed shard  5 -> cache: [3, 4, 5]  (size=3)
Accessed shard  6 -> cache: [4, 5, 6]  (size=3)
Accessed shard  7 -> cache: [5, 6, 7]  (size=3)
Accessed shard  8 -> cache: [6, 7, 8]  (size=3)
Accessed shard  9 -> cache: [7, 8, 9]  (size=3)
Accessed shard 10 -> cache: [8, 9, 10]  (size=3)
Accessed shard 11 -> cache: [9, 10, 11]  (size=3)
Accessed shard 12 -> cache: [10, 11, 12]  (size=3)
Accessed shard 13 -> cache: [11, 12, 13]  (size=3)

max_open_shards=3, so only 3 most recent shards stay in memory.
Oldest shard is evicted each time a new one is loaded.


In [9]:
## 5. LRU touch: re-access prevents eviction

ds = StreamingDataset(
    local=str(DATA_DIR),
    look_ahead=0,
    max_open_shards=3,
    shuffle=False,
)

# NOTE: `_samples_per_shard` and `_readers` are private attributes, read here
# for demonstration only.
sps = ds._samples_per_shard

# The walkthrough needs four shards, and index 5 must still fall in shard 0.
assert ds.num_shards >= 4, "demo needs at least 4 shards"
assert sps[0] > 5, "demo assumes shard 0 holds more than 5 samples"

# Global index of the first sample in shards 0, 1, 2 and 3.
first_index = [0, sps[0], sps[0] + sps[1], sps[0] + sps[1] + sps[2]]

# Load shards 0, 1, 2
_ = ds[first_index[0]]          # shard 0
_ = ds[first_index[1]]          # shard 1
_ = ds[first_index[2]]          # shard 2
print(f"After loading 0,1,2:  cache = {list(ds._readers.keys())}")

# Touch shard 0 again (moves to most recent)
_ = ds[5]                       # shard 0 again
print(f"After touching 0:     cache = {list(ds._readers.keys())}")

# Load shard 3 — shard 1 gets evicted (not 0!)
_ = ds[first_index[3]]          # shard 3
cache_after = list(ds._readers.keys())
print(f"After loading 3:      cache = {cache_after}")

# Only print the conclusion if the cache really is in the claimed state.
assert 0 in cache_after and 1 not in cache_after, (
    f"expected shard 1 evicted and shard 0 retained, got cache = {cache_after}"
)
print("\nShard 1 was evicted, not shard 0, because we touched 0 -- LRU in action!")

After loading 0,1,2:  cache = [0, 1, 2]
After touching 0:     cache = [1, 2, 0]
After loading 3:      cache = [2, 0, 3]

Shard 1 was evicted, not shard 0, because we touched 0 -- LRU in action!


In [11]:
## 6. Clean up
# Remove the shard directory written in step 2. With ignore_errors=True a
# missing directory or a failed removal does not raise, so the message below
# is printed whether or not anything was actually deleted.
import shutil
shutil.rmtree(DATA_DIR, ignore_errors=True)
print("Cleaned up demo data.")

Cleaned up demo data.
