# Memory-Efficient Text DataLoader

This notebook demonstrates how to:
1. Unroll and pre-compute numericalized text data
2. Store it efficiently to disk using numpy memmap
3. Create a custom DataLoader that streams from disk
4. Configure number of workers to reduce memory usage

This approach fixes memory leaks in the standard LMDataLoader by avoiding:
- ReindexCollection caching issues
- Repeated reshuffling and chunk recreation
- High memory usage with many workers

In [None]:
from fastai.basics import *
from fastai.callback.all import *
from fastai.text.all import *
import numpy as np
import json
from pathlib import Path

## 1. Download the WikiText Tiny Dataset

In [None]:
# Fetch the dataset referenced by URLs.WIKITEXT_TINY with fastai's `untar_data`
# and keep the returned location; later cells read the CSVs and write the cache under it.
path = untar_data(URLs.WIKITEXT_TINY)
# Last expression of the cell: lists the dataset directory in the notebook output.
path.ls()

In [None]:
# Load the data
# `header=None`: the files are read without a header row, so the text column is named 0.
df_train = pd.read_csv(path/'train.csv', header=None)
# test.csv is used as the validation split throughout this notebook.
df_valid = pd.read_csv(path/'test.csv', header=None)
# Train rows come first, validation rows after; the splits built later rely on this order.
df_all = pd.concat([df_train, df_valid])
print(f"Train samples: {len(df_train)}, Valid samples: {len(df_valid)}, Total: {len(df_all)}")

In [None]:
# Peek at the first two training rows to check the text loaded as expected.
df_train.head(2)

## 2. Tokenize and Numericalize Text

First, we'll process the text through the standard fastai pipeline to create vocabulary and numericalized tokens.

In [None]:
# Create splits
# Two lists of row positions into df_all: the first len(df_train) rows are the
# training split, the remaining rows (from df_valid) are the validation split.
splits = [list(range_of(df_train)), list(range(len(df_train), len(df_all)))]
print(f"Train indices: 0-{splits[0][-1]}, Valid indices: {splits[1][0]}-{splits[1][-1]}")

In [None]:
# Create the transforms pipeline
# Note: We use Tokenizer.from_df for dataframe input
tok = Tokenizer.from_df(0)  # Column 0 contains text
# Kept in a variable so its `vocab` can be read back in a later cell
# (presumably populated when the Datasets object is set up — confirm against fastai docs).
num = Numericalize()

In [None]:
# First, tokenize all texts to build vocabulary
# This is the standard approach using Datasets
# Pipeline per item: read the "text" attribute, tokenize, then numericalize.
tfms = [attrgetter("text"), tok, num]
# A single transform pipeline (no targets); `splits` selects the train/valid rows of df_all.
dsets = Datasets(df_all, [tfms], splits=splits)

In [None]:
# Get vocabulary from the numericalize transform
# `vocab` is reused below when the disk caches are built, so token ids can be decoded later.
vocab = num.vocab
print(f"Vocabulary size: {len(vocab)}")
print(f"First 20 tokens: {vocab[:20]}")

## 3. Unroll Dataset to Disk Storage

Now we'll extract all numericalized tokens and store them efficiently using numpy memmap.

In [None]:
class DiskLMDataset:
    """
    Stores numericalized text data on disk and reads it back through a numpy memmap.

    The cache directory holds up to three files:
    - ``tokens.npy``: all token ids of one split, concatenated into a single int32 array
    - ``meta.json``: ``total_tokens``, per-document ``doc_lengths`` and ``num_docs``
    - ``vocab.pkl``: the vocabulary (only written when one is passed to `from_datasets`)

    Each piece is loaded lazily and independently on first access. Instances
    pickle *without* the open memmap, so copies handed to DataLoader worker
    processes reopen the file instead of carrying the whole token array.
    """

    def __init__(self, cache_dir: Path):
        """Point the dataset at `cache_dir` (created if missing); nothing is loaded yet."""
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        self.data_path = self.cache_dir / 'tokens.npy'
        self.meta_path = self.cache_dir / 'meta.json'
        self.vocab_path = self.cache_dir / 'vocab.pkl'

        self._data = None
        self._meta = None
        self._vocab = None

    def __getstate__(self):
        """Return picklable state with the memmap dropped.

        Pickling a ``np.memmap`` serializes the full array contents, which would
        copy the entire corpus into every worker process that receives this
        object. The `data` property reopens the file on demand after unpickling.
        """
        state = self.__dict__.copy()
        state['_data'] = None
        return state

    @classmethod
    def from_datasets(cls, dsets, cache_dir: Path, split_idx: int = 0,
                      force_rebuild: bool = False, vocab=None):
        """
        Create a DiskLMDataset from a fastai Datasets object.

        Args:
            dsets: fastai Datasets with numericalized text
            cache_dir: Directory to store cached data
            split_idx: Which split to use (0=train, 1=valid)
            force_rebuild: If True, rebuild even if cache exists
            vocab: Vocabulary list to save (pass from Numericalize transform)

        Returns:
            A DiskLMDataset backed by the files in `cache_dir`.

        Raises:
            ValueError: If the selected split contains no items.
        """
        obj = cls(cache_dir)

        # Reuse an existing cache: `dsets` and `vocab` are not consulted on this path.
        if not force_rebuild and obj.data_path.exists() and obj.meta_path.exists():
            print(f"Loading from cache: {cache_dir}")
            obj._load_cache()
            return obj

        print(f"Building cache in: {cache_dir}")

        # Get the specific split
        split_dset = dsets.subset(split_idx)

        # Collect the numericalized tokens of every document
        chunks = []
        doc_lengths = []

        print("Extracting tokens...")
        for i in progress_bar(range(len(split_dset))):
            item = split_dset[i]
            # Handle tuple (x,) or just x
            tokens = item[0] if isinstance(item, tuple) else item
            tokens = tokens.numpy() if hasattr(tokens, 'numpy') else np.array(tokens)
            chunks.append(tokens)
            doc_lengths.append(len(tokens))

        if not chunks:
            # np.concatenate would fail with an unrelated-looking message.
            raise ValueError(f"Split {split_idx} is empty; there are no tokens to cache")

        # Concatenate all tokens into one contiguous stream
        print("Concatenating tokens...")
        all_tokens = np.concatenate(chunks).astype(np.int32)
        total_tokens = len(all_tokens)

        print(f"Total tokens: {total_tokens:,}")

        # Save to disk
        print("Saving to disk...")
        np.save(obj.data_path, all_tokens)

        # Save metadata
        meta = {
            'total_tokens': total_tokens,
            'doc_lengths': doc_lengths,
            'num_docs': len(doc_lengths),
        }
        with open(obj.meta_path, 'w', encoding='utf-8') as f:
            json.dump(meta, f)

        # Save vocabulary if provided
        if vocab is not None:
            save_pickle(obj.vocab_path, vocab)

        obj._load_cache()
        print(f"Cache built successfully!")
        return obj

    def _read_tokens(self):
        """Open ``tokens.npy`` as a read-only memory map."""
        return np.load(self.data_path, mmap_mode='r')

    def _read_meta(self):
        """Read ``meta.json`` into a dict."""
        with open(self.meta_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _read_vocab(self):
        """Read ``vocab.pkl`` if it exists, otherwise return None."""
        if not self.vocab_path.exists():
            return None
        # NOTE: unpickling can execute arbitrary code; only use cache directories you trust.
        return load_pickle(self.vocab_path)

    def _load_cache(self):
        """Load cached tokens (memory-mapped), metadata and, when present, the vocabulary."""
        self._data = self._read_tokens()
        self._meta = self._read_meta()
        self._vocab = self._read_vocab()

    @property
    def data(self):
        """Read-only memmap over the token array (opened on first access)."""
        if self._data is None:
            self._data = self._read_tokens()
        return self._data

    @property
    def meta(self):
        """Metadata dict read from ``meta.json`` (loaded on first access)."""
        if self._meta is None:
            self._meta = self._read_meta()
        return self._meta

    @property
    def vocab(self):
        """Cached vocabulary, or None when no ``vocab.pkl`` was saved.

        Only the vocabulary file is consulted here, so a cache without a
        vocabulary does not reopen the token memmap on every access.
        """
        if self._vocab is None:
            self._vocab = self._read_vocab()
        return self._vocab

    @property
    def total_tokens(self):
        """Number of tokens in the cached stream, as recorded in the metadata."""
        return self.meta['total_tokens']

    def __len__(self):
        return self.total_tokens

    def __getitem__(self, idx):
        """Get token(s) at index. Supports slicing."""
        return self.data[idx]

In [None]:
# Create disk datasets for train and valid
# Both caches live under the dataset directory, one subdirectory per split.
cache_base = path / 'lm_cache'

# If a cache already exists in the target directory it is loaded as-is
# (force_rebuild defaults to False), so `dsets`/`vocab` are only used on the first run.
train_disk = DiskLMDataset.from_datasets(dsets, cache_base / 'train', split_idx=0, vocab=vocab)
valid_disk = DiskLMDataset.from_datasets(dsets, cache_base / 'valid', split_idx=1, vocab=vocab)

In [None]:
# Token counts come from each cache's meta.json; the vocabulary from its vocab.pkl.
print(f"Train tokens: {train_disk.total_tokens:,}")
print(f"Valid tokens: {valid_disk.total_tokens:,}")
print(f"Vocab size: {len(train_disk.vocab)}")

In [None]:
# Verify the data
# Slicing goes through DiskLMDataset.__getitem__, i.e. straight to the memory-mapped array.
print("First 20 tokens:", train_disk[:20])
# Map the same ids back to their vocabulary entries as a sanity check.
print("Decoded:", [train_disk.vocab[i] for i in train_disk[:20]])

## 4. Create Memory-Efficient DataLoader

Now we create a custom DataLoader that:
1. Reads from the disk-cached data
2. Uses minimal memory
3. Supports configurable number of workers

In [None]:
class DiskLMDataLoaderDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset that reads from DiskLMDataset.

    The token stream is cut into `bs` equal streams of `bl` tokens each. Item
    `i` is chunk `i // bs` of stream `i % bs`, so `bs` consecutive items form
    one batch whose rows all have the same length. Every chunk is `seq_len`
    tokens long except the last chunk of a stream, which holds the remaining
    `last_len` tokens.

    NOTE(review): `shuffle` permutes whole batches. Text stays contiguous
    inside each returned sequence, but consecutive batches no longer continue
    each other — relevant for models that carry hidden state across batches.
    """

    def __init__(self, disk_data: "DiskLMDataset", bs: int = 64, seq_len: int = 72):
        """
        Args:
            disk_data: Token store exposing `total_tokens` and slice indexing
            bs: Batch size (number of parallel streams); must be positive
            seq_len: Tokens per chunk; must be positive

        Raises:
            ValueError: If `bs` or `seq_len` is not positive.
        """
        if bs <= 0 or seq_len <= 0:
            raise ValueError(f"bs and seq_len must be positive, got bs={bs}, seq_len={seq_len}")

        self.disk_data = disk_data
        self.bs = bs
        self.seq_len = seq_len

        # One token is held back so that every input position has a target.
        total = max(disk_data.total_tokens - 1, 0)
        # Largest multiple of bs that fits, computed with exact integer arithmetic.
        self.corpus = (total // bs) * bs
        self.bl = self.corpus // bs  # batch length: tokens per stream
        self.n_batches = self.bl // seq_len + int(self.bl % seq_len != 0)
        self.last_len = self.bl - (self.n_batches - 1) * seq_len

        # Permutation of batch indices set by `shuffle`; None means sequential order.
        self._batch_order = None

    def __len__(self):
        return self.n_batches * self.bs

    def shuffle(self):
        """Create new random batch ordering for this epoch."""
        # We shuffle at the batch level, not token level
        self._batch_order = torch.randperm(self.n_batches).numpy()

    def __getitem__(self, seq):
        """
        Get a single training example (input, target).

        Args:
            seq: Linear index into all sequences; negative values count from the end.

        Returns:
            Tuple `(x, y)` where `y` is `x` shifted one token ahead.

        Raises:
            IndexError: If `seq` is outside `[-len(self), len(self))`.
        """
        size = len(self)
        pos = seq + size if seq < 0 else seq
        if not 0 <= pos < size:
            raise IndexError(f"Index {seq} out of range for dataset of size {size}")

        # batch_idx: which batch (0 to n_batches-1); stream_idx: which row of it (0 to bs-1)
        batch_idx, stream_idx = divmod(pos, self.bs)

        # Apply batch shuffling if set
        if self._batch_order is not None:
            batch_idx = int(self._batch_order[batch_idx])

        # Only the final chunk of a stream can be shorter than seq_len
        sl = self.last_len if batch_idx == self.n_batches - 1 else self.seq_len

        # Each stream is bl tokens long and starts at stream_idx * bl;
        # within the stream this chunk starts at batch_idx * seq_len
        start = stream_idx * self.bl + batch_idx * self.seq_len

        # sl+1 tokens give sl inputs plus their shifted targets. Copying detaches
        # the chunk from the read-only memory map before it becomes a tensor.
        chunk = np.array(self.disk_data[start:start + sl + 1])
        tokens = torch.from_numpy(chunk).long()

        x = LMTensorText(tokens[:-1])
        y = tokens[1:]

        return x, y

In [None]:
class MemoryEfficientLMDataLoader:
    """
    Memory-efficient Language Model DataLoader.

    Key differences from standard LMDataLoader:
    1. Reads from disk-cached data (no in-memory caching issues)
    2. Configurable num_workers (default 0 to avoid multiprocessing memory)
    3. Optional batch-level shuffling

    A fresh `torch.utils.data.DataLoader` is created for every iteration, so
    each epoch gets a new shuffle order when `shuffle` is True.
    """

    def __init__(self,
                 disk_data: "DiskLMDataset",
                 bs: int = 64,
                 seq_len: int = 72,
                 num_workers: int = 0,
                 shuffle: bool = True,
                 drop_last: bool = True,
                 pin_memory: bool = False,
                 device=None):
        """
        Args:
            disk_data: Disk-backed token store to read from
            bs: Batch size
            seq_len: Sequence length of each training example
            num_workers: Worker processes for the underlying PyTorch DataLoader
            shuffle: Reshuffle the batch order at the start of every iteration
            drop_last: Forwarded to the PyTorch DataLoader
            pin_memory: Forwarded to the PyTorch DataLoader
            device: If given, batches are moved to this device before being yielded
        """
        self.disk_data = disk_data
        self.bs = bs
        self.seq_len = seq_len
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.pin_memory = pin_memory
        self.device = device

        # Create underlying dataset
        self.dataset = DiskLMDataLoaderDataset(disk_data, bs=bs, seq_len=seq_len)

        # Store vocab for decode
        self.vocab = disk_data.vocab

    def __len__(self):
        return self.dataset.n_batches

    def to(self, device):
        """Set the device batches are moved to and return self.

        Gives this loader the `to` method that wrappers which place their
        loaders on a device expect to be able to call.
        """
        self.device = device
        return self

    def __iter__(self):
        """Iterate over batches of `(x, y)`."""
        if self.shuffle:
            self.dataset.shuffle()
        else:
            # Drop any order left over from an earlier shuffled pass.
            self.dataset._batch_order = None

        # Create DataLoader
        dl = torch.utils.data.DataLoader(
            self.dataset,
            batch_size=self.bs,
            shuffle=False,  # We handle shuffling ourselves
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=self.drop_last,
            collate_fn=self._collate
        )

        for batch in dl:
            if self.device is not None:
                batch = tuple(t.to(self.device) for t in batch)
            yield batch

    def _collate(self, batch):
        """Collate a batch of (x, y) pairs into two stacked tensors."""
        xs, ys = zip(*batch)
        return torch.stack(xs), torch.stack(ys)

    def one_batch(self):
        """Get one batch for testing."""
        batches = iter(self)
        try:
            return next(batches)
        finally:
            # Close the generator right away so the underlying DataLoader
            # (and any workers it started) is released.
            batches.close()

    def decode(self, tokens):
        """Decode a 1-D sequence of token ids to space-joined text."""
        if isinstance(tokens, torch.Tensor):
            tokens = tokens.cpu().numpy()
        return ' '.join(self.vocab[t] for t in tokens)

    def show_batch(self, max_n: int = 3):
        """Print the first 50 tokens of up to `max_n` input/target pairs from one batch."""
        x, y = self.one_batch()
        for i in range(min(max_n, len(x))):
            print(f"--- Example {i+1} ---")
            print(f"Input:  {self.decode(x[i][:50])}...")
            print(f"Target: {self.decode(y[i][:50])}...")
            print()

In [None]:
# Create memory-efficient dataloaders
# bs = batch size, sl = sequence length (tokens per training example)
bs, sl = 64, 72

# Training loader: batch order is reshuffled at the start of every iteration.
train_dl = MemoryEfficientLMDataLoader(
    train_disk, 
    bs=bs, 
    seq_len=sl, 
    num_workers=0,  # Use 0 workers to minimize memory
    shuffle=True
)

# Validation loader: same geometry, fixed sequential order.
valid_dl = MemoryEfficientLMDataLoader(
    valid_disk, 
    bs=bs, 
    seq_len=sl, 
    num_workers=0,
    shuffle=False
)

In [None]:
# len() of a loader is the number of batches per epoch (the dataset's n_batches).
print(f"Train batches: {len(train_dl)}")
print(f"Valid batches: {len(valid_dl)}")

In [None]:
# Show a batch
# Prints the decoded start of two input/target pairs; targets are the inputs shifted by one token.
train_dl.show_batch(max_n=2)

In [None]:
# Verify batch shapes
# one_batch() starts a fresh iteration, so with shuffle=True this is a randomly chosen batch.
x, y = train_dl.one_batch()
print(f"Input shape: {x.shape}")
print(f"Target shape: {y.shape}")

## 5. Integration with fastai Learner

Now let's create a wrapper that makes this work with fastai's training loop.

In [None]:
class DiskLMDataLoaders(DataLoaders):
    """
    DataLoaders wrapper for memory-efficient LM training.

    Holds a training and a validation loader plus the shared vocabulary, and
    exposes them through the `train` / `valid` / `vocab` properties and by index.
    """

    def __init__(self, train_dl, valid_dl, vocab, device=None, path='.'):
        """
        Args:
            train_dl: Loader used for training (index 0)
            valid_dl: Loader used for validation (index 1)
            vocab: Vocabulary shared by both loaders
            device: Target device; defaults to `default_device()`
            path: Base path stored on the object, mirroring the base class attribute
        """
        # The loaders must be in place before `device` is assigned: the setter walks them.
        self.loaders = [train_dl, valid_dl]
        self.path = Path(path)
        self._vocab = vocab
        self.device = default_device() if device is None else device

    @property
    def device(self):
        """Device this object was configured with."""
        return self._device

    @device.setter
    def device(self, d):
        # Overridden so that assigning a device never fails on loaders that do
        # not implement `to`; loaders that do implement it are moved as usual.
        for dl in self.loaders:
            if hasattr(dl, 'to'):
                dl.to(d)
        self._device = d

    @property
    def train(self):
        return self.loaders[0]

    @property
    def valid(self):
        return self.loaders[1]

    @property
    def vocab(self):
        return self._vocab

    def __iter__(self):
        return iter(self.loaders)

    def __len__(self):
        return len(self.loaders)

    def __getitem__(self, i):
        return self.loaders[i]

    @classmethod
    def from_cache(cls, cache_dir: Path, bs: int = 64, seq_len: int = 72,
                   num_workers: int = 0, device=None):
        """
        Create DataLoaders from cached disk data.

        Args:
            cache_dir: Base cache directory (should have 'train' and 'valid' subdirs)
            bs: Batch size
            seq_len: Sequence length
            num_workers: Number of dataloader workers (0 recommended for low memory)
            device: PyTorch device

        Raises:
            FileNotFoundError: If either split has not been cached yet.
        """
        cache_dir = Path(cache_dir)

        # Load disk datasets
        train_disk = DiskLMDataset(cache_dir / 'train')
        valid_disk = DiskLMDataset(cache_dir / 'valid')

        # Fail early with an actionable message instead of a bare error from
        # deep inside the first lazy load.
        for disk in (train_disk, valid_disk):
            missing = [p.name for p in (disk.data_path, disk.meta_path) if not p.exists()]
            if missing:
                raise FileNotFoundError(
                    f"No cache found in {disk.cache_dir} (missing: {', '.join(missing)}); "
                    "build it first with DiskLMDataset.from_datasets"
                )

        # Create dataloaders: only the training loader reshuffles its batches
        train_dl = MemoryEfficientLMDataLoader(
            train_disk, bs=bs, seq_len=seq_len,
            num_workers=num_workers, shuffle=True
        )
        valid_dl = MemoryEfficientLMDataLoader(
            valid_disk, bs=bs, seq_len=seq_len,
            num_workers=num_workers, shuffle=False
        )

        return cls(train_dl, valid_dl, train_disk.vocab, device=device)

In [None]:
# Create DataLoaders
# Reads the 'train' and 'valid' caches written earlier under path / 'lm_cache';
# no tokenization happens here.
dls = DiskLMDataLoaders.from_cache(
    path / 'lm_cache',
    bs=64,
    seq_len=72,
    num_workers=0  # Zero workers for minimal memory
)

In [None]:
# The vocabulary comes from the training cache; batch counts are per epoch.
print(f"Vocab size: {len(dls.vocab)}")
print(f"Train batches: {len(dls.train)}")
print(f"Valid batches: {len(dls.valid)}")

## 6. Create and Train a Language Model

In [None]:
# Create a simple training loop compatible with our DataLoaders
import gc

def train_epoch(model, dl, loss_func, opt, device):
    """Train `model` for one pass over `dl` and return the mean batch loss.

    Args:
        model: Module whose forward returns a sequence with the logits first
        dl: Iterable of `(x, y)` batches
        loss_func: Callable taking flattened logits and flattened targets
        opt: Optimizer exposing `zero_grad()` and `step()`
        device: Device the batches are moved to

    Returns:
        Mean of the per-batch losses, or NaN when `dl` yields no batches.
    """
    model.train()

    # Recurrent language models keep hidden state between calls; start each
    # epoch from a clean state when the model offers a reset hook (this is
    # what fastai's ModelResetter callback does inside Learner).
    reset = getattr(model, 'reset', None)
    if callable(reset):
        reset()

    total_loss = 0.0
    n_batches = 0

    for x, y in progress_bar(dl):
        x, y = x.to(device), y.to(device)

        opt.zero_grad()
        out = model(x)[0]  # AWD-LSTM returns (output, hidden, dropped_output)
        loss = loss_func(out.view(-1, out.shape[-1]), y.view(-1))
        loss.backward()
        opt.step()

        total_loss += loss.item()
        n_batches += 1

        # Drop references to this batch's tensors so they can be freed
        # before the next batch is loaded.
        del x, y, out, loss

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # An empty loader has no loss to average; report NaN instead of dividing by zero.
    return total_loss / n_batches if n_batches else float('nan')

@torch.no_grad()
def validate(model, dl, loss_func, device):
    """Evaluate `model` on `dl` without gradients and return the mean batch loss.

    Args:
        model: Module whose forward returns a sequence with the logits first
        dl: Iterable of `(x, y)` batches
        loss_func: Callable taking flattened logits and flattened targets
        device: Device the batches are moved to

    Returns:
        Mean of the per-batch losses, or NaN when `dl` yields no batches.
    """
    model.eval()

    # Do not let hidden state left over from training leak into validation
    # when the model offers a reset hook.
    reset = getattr(model, 'reset', None)
    if callable(reset):
        reset()

    total_loss = 0.0
    n_batches = 0

    for x, y in dl:
        x, y = x.to(device), y.to(device)
        out = model(x)[0]
        loss = loss_func(out.view(-1, out.shape[-1]), y.view(-1))
        total_loss += loss.item()
        n_batches += 1

    # An empty loader has no loss to average; report NaN instead of dividing by zero.
    return total_loss / n_batches if n_batches else float('nan')

In [None]:
# Create model
# Copy the default config first so the overrides below do not modify the shared default.
config = awd_lstm_lm_config.copy()
# Override the dropout probabilities used by the model.
config.update({'input_p': 0.6, 'output_p': 0.4, 'weight_p': 0.5, 'embed_p': 0.1, 'hidden_p': 0.2})

# The vocabulary size is passed as the second argument to get_language_model.
model = get_language_model(AWD_LSTM, len(dls.vocab), config=config)
device = default_device()
model = model.to(device)

print(f"Model on: {device}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

In [None]:
# Create optimizer and loss
# NOTE(review): `wd=` is the keyword of fastai's Adam (torch.optim.Adam uses
# `weight_decay`), so this presumably resolves to the fastai optimizer — confirm.
opt = Adam(model.parameters(), lr=5e-3, wd=0.1)
# The training/validation helpers flatten logits and targets before calling this.
loss_func = CrossEntropyLossFlat()

In [None]:
# Train for a few epochs
n_epochs = 2

for epoch in range(n_epochs):
    # One full pass over the training loader, then one over the validation loader.
    train_loss = train_epoch(model, dls.train, loss_func, opt, device)
    valid_loss = validate(model, dls.valid, loss_func, device)
    
    # Perplexity is reported as exp(validation loss).
    print(f"Epoch {epoch+1}: train_loss={train_loss:.4f}, valid_loss={valid_loss:.4f}, perplexity={np.exp(valid_loss):.2f}")

## 7. Memory Usage Comparison

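Let's measure the memory usage of the disk-based approach over several epochs; running the same loop with the standard `LMDataLoader` gives the baseline to compare against.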

In [None]:
import psutil
import os

def get_memory_mb():
    """Return the resident set size (RSS) of the current process, in MB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    # bytes -> KB -> MB
    return rss_bytes / 1024 / 1024

# Report this process's resident memory at this point in the notebook.
print(f"Current memory usage: {get_memory_mb():.1f} MB")

In [None]:
# Test iteration without memory leaks
# Collect garbage first so the baseline is not inflated by objects awaiting collection.
gc.collect()
mem_before = get_memory_mb()

# Iterate through multiple epochs
for epoch in range(3):
    # Each `for` over dls.train starts a fresh iteration of the training loader.
    for batch_idx, (x, y) in enumerate(dls.train):
        if batch_idx >= 10:  # Just test a few batches
            break
        # Release the batch tensors before the next batch is fetched.
        del x, y
    gc.collect()
    print(f"Epoch {epoch+1}: Memory = {get_memory_mb():.1f} MB")

# A change close to zero means iterating does not accumulate memory across epochs.
mem_after = get_memory_mb()
print(f"\nMemory change: {mem_after - mem_before:.1f} MB")

## 8. Summary

This notebook demonstrates a memory-efficient approach for language model training:

### Key Components:
1. **DiskLMDataset**: Stores numericalized tokens on disk using numpy memmap
2. **DiskLMDataLoaderDataset**: PyTorch Dataset that reads from disk cache
3. **MemoryEfficientLMDataLoader**: Custom DataLoader with configurable workers
4. **DiskLMDataLoaders**: fastai-compatible DataLoaders wrapper

### Benefits:
- **Reduced memory**: Data stays on disk, only loaded when needed
- **No memory leaks**: Avoids ReindexCollection caching issues
- **Configurable workers**: Use `num_workers=0` for minimal memory
- **Fast second load**: Pre-computed cache loads instantly
- **Large dataset support**: Can handle datasets larger than RAM

### Usage:
```python
# First time: Build cache (pass vocab from Numericalize transform)
train_disk = DiskLMDataset.from_datasets(dsets, cache_dir / 'train', split_idx=0, vocab=num.vocab)

# Second time: Load from cache (fast!)
dls = DiskLMDataLoaders.from_cache(cache_dir, bs=64, seq_len=72, num_workers=0)
```