In [1]:
import logging
from pathlib import Path

import numpy as np
from cached_path import cached_path
from torch.utils.data import DataLoader, DistributedSampler

from olmo.config import TrainConfig
from olmo.data import build_memmap_dataset, DataCollator, IterableDataset

import torch.distributed as dist

# Single-process "distributed" setup (world_size=1, rank=0, in-memory HashStore).
# NOTE(review): presumably required because the olmo data helpers query
# torch.distributed for rank/world size — confirm.
# init_process_group raises if the default process group already exists, so guard it
# to keep this cell safe to re-run in the same kernel.
if not dist.is_initialized():
    dist.init_process_group(backend="gloo", world_size=1, rank=0, store=dist.HashStore())

In [2]:
train_config_path = "configs/mitchish7-s3.yaml"
train_config = TrainConfig.load(train_config_path, ["global_train_batch_size=1024"])
cfg = train_config  # short alias: `cfg` and `train_config` are the same object
log = logging.getLogger(__name__)

# Fill derived configuration options.
# This notebook runs as a single process (world_size=1), so the per-device batch
# equals the global batch.
world_size = 1
cfg.model.precision = cfg.precision
cfg.device_train_batch_size = cfg.global_train_batch_size // world_size
assert cfg.device_train_batch_size is not None  # for mypy
cfg.device_train_grad_accum = cfg.device_train_batch_size // cfg.device_train_microbatch_size
if cfg.optimizer.no_decay_norm_and_bias is not None:
    # Trailing spaces matter: adjacent string literals are concatenated verbatim.
    log.warning(
        "You set the deprecated config option `no_decay_norm_and_bias`. For compatibility, this "
        "setting will take precedence over all other weight decay configurations. Please change "
        "your config to use `decay_norm_and_bias` and `decay_embeddings` instead."
    )
    cfg.optimizer.decay_norm_and_bias = not cfg.optimizer.no_decay_norm_and_bias
    cfg.optimizer.decay_embeddings = not cfg.optimizer.no_decay_norm_and_bias
    cfg.optimizer.no_decay_norm_and_bias = None  # So nobody uses this by accident.

# Overrides because we have only one process here.
# NOTE(review): DataLoader keeps up to num_workers * prefetch_factor batches in flight
# (128 * 256 = 32768 here). With large batches this can exhaust host memory, so
# confirm these values fit the machine.
train_config.data.num_workers = 128
train_config.data.pin_memory = False
train_config.data.prefetch_factor = 256

# Pads each batch using the configured pad direction and the model's pad token id.
collator = DataCollator(
    pad_direction=train_config.data.pad_direction,
    pad_token_id=train_config.model.pad_token_id,
)
dataset = build_memmap_dataset(train_config, train_config.data, include_instance_metadata=False)

# Prefer the data-specific seed; fall back to the top-level training seed.
seed = train_config.data.seed
if seed is None:
    seed = train_config.seed
# NOTE(review): not used below; the IterableDataset is given work_dir=None.
work_dir = Path("./temp-work-dir")

data_cfg = train_config.data
loader_has_workers = data_cfg.num_workers != 0

# Shuffled view over the memmap dataset, seeded per epoch.
shuffled_dataset = IterableDataset(
    dataset,  # type: ignore
    train_config.global_train_batch_size,
    seed=seed + (train_config.epoch or 0),
    shuffle=True,
    drop_last=data_cfg.drop_last,
    work_dir=None,
)
loader = DataLoader(
    shuffled_dataset,
    batch_size=train_config.device_train_batch_size,
    drop_last=data_cfg.drop_last,
    collate_fn=collator,
    num_workers=data_cfg.num_workers,
    pin_memory=data_cfg.pin_memory,
    # prefetch_factor / persistent_workers only apply when worker processes are used.
    prefetch_factor=data_cfg.prefetch_factor if loader_has_workers else None,
    persistent_workers=data_cfg.persistent_workers if loader_has_workers else False,
    timeout=data_cfg.timeout,
)
# Iterator consumed by the counting loop; with workers enabled, creating it starts them.
batches = iter(loader)

In [3]:
# Number of full device batches in the dataset; floor division ignores a trailing partial batch.
num_instances = len(dataset)
max_step = num_instances // train_config.device_train_batch_size

In [4]:
import torch
from tqdm.notebook import trange, tqdm

# Per-step token histogram: row `step` counts how often each vocabulary id occurs
# in the `step`-th batch.
# NOTE(review): this allocates max_step * vocab_size int32 values (4 bytes each) up
# front, which may be very large for long runs. Confirm it fits in memory, or
# aggregate coarser.
total_counts = torch.zeros((max_step, cfg.model.vocab_size,), dtype=torch.int32)

for step, batch in tqdm(enumerate(batches), total=max_step):
    # `total_counts` has exactly `max_step` rows (valid indices 0..max_step-1), so any
    # extra batch (e.g. a trailing partial one) must stop the loop rather than index
    # past the end.
    if step >= max_step:
        break
    token_ids = batch['input_ids'].flatten()
    # One count per id in [0, vocab_size). This gives the same per-id totals as
    # unique(return_counts=True) followed by a scatter, without the sort.
    total_counts[step] += torch.bincount(token_ids, minlength=cfg.model.vocab_size)

  0%|          | 0/408923 [00:00<?, ?it/s]

KeyboardInterrupt: 