integrate mock vision backbone into model #441

Open
wants to merge 14 commits into mm-dev
101 changes: 101 additions & 0 deletions configs/mm-tiny.yaml
@@ -0,0 +1,101 @@
run_name: mm-tiny
seed: 6198
dry_run: false

wandb: null

model:
  d_model: 1024
  n_heads: 16
  n_layers: 16
  mlp_ratio: 8
  weight_tying: false
  alibi: false
  rope: true
  vision_backbone:
    name: linear
    patch_width: 16
    patch_height: 16
    frozen: false
  flash_attention: true
  attention_dropout: 0.0
  attention_layer_norm: false
  multi_query_attention: false
  include_bias: false
  block_type: sequential
  layer_norm_type: default
  layer_norm_with_affine: false
  bias_for_layer_norm: false
  attention_layer_norm_with_affine: false
  activation_type: swiglu
  residual_dropout: 0.0
  embedding_dropout: 0.0
  max_sequence_length: 2048
  vocab_size: 50280
  embedding_size: 50304
  eos_token_id: 0
  pad_token_id: 1
  init_device: meta
  init_fn: mitchell

compile: null
# fullgraph: false

optimizer:
  name: adamw
  learning_rate: 2e-4
  weight_decay: 0.1
  betas:
    - 0.9
    - 0.95
  metrics_log_interval: 10

scheduler:
  name: linear_with_warmup
  t_warmup: 200
  alpha_f: 0.001

tokenizer:
  identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json
  truncate_direction: right

save_folder: runs/${run_name}
remote_save_folder: null
save_overwrite: true
# Sharded checkpoints (best for restarts)
save_interval: 1000
save_num_checkpoints_to_keep: -1
# Unsharded checkpoints (for final storage)
save_interval_unsharded: null # getting errors on LUMI right now
save_num_unsharded_checkpoints_to_keep: -1

load_path: null

max_duration: 10e9
global_train_batch_size: 512
device_train_microbatch_size: 2
time_limit: null

precision: amp_bf16

fsdp:
  wrapping_strategy: by_block_and_size
  precision: mixed

max_grad_norm: 1.0
max_grad_norm_ratio: null

speed_monitor:
  window_size: 20

eval_interval: ${save_interval}
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
evaluators: []

data:
  pad_direction: right
  num_workers: 0
  drop_last: true
  multi_modal: true
  paths: []
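For reference, a minimal sketch of how the new `vision_backbone` block in this config can be read back, assuming plain `omegaconf` loading (the project's own `TrainConfig` loader is not part of this diff):

```python
# Hypothetical sketch: inspect the vision_backbone block from configs/mm-tiny.yaml.
# This is not the project's config-loading code, just a quick way to look at it.
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/mm-tiny.yaml")
vb = cfg.model.vision_backbone
print(vb.name, vb.patch_width, vb.patch_height, vb.frozen)  # linear 16 16 False
```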
20 changes: 20 additions & 0 deletions olmo/config.py
@@ -33,6 +33,8 @@
    "CompilerConfig",
    "LayerNormType",
    "InitFnType",
    "VisionBackboneType",
    "VisionBackboneConfig",
    "ModelConfig",
    "OptimizerType",
    "OptimizerConfig",
@@ -220,6 +222,18 @@ class InitFnType(StrEnum):
    """


class VisionBackboneType(StrEnum):
    linear = "linear"


@dataclass
class VisionBackboneConfig(BaseConfig):
    name: VisionBackboneType = VisionBackboneType.linear
    patch_width: int = 16
    patch_height: int = 16
    frozen: bool = False


@dataclass
class ModelConfig(BaseConfig):
    """
@@ -297,6 +311,11 @@ class ModelConfig(BaseConfig):
    apply RoPE at the precision of the input.
    """

    vision_backbone: Optional[VisionBackboneConfig] = None
    """
    Vision backbone settings for multi-modal models.
    """

    flash_attention: bool = False
    """
    If ``True``, use ``FlashAttention``.
@@ -524,6 +543,7 @@ class DataConfig(BaseConfig):
    prefetch_factor: Optional[int] = None
    persistent_workers: bool = False
    timeout: int = 0
    multi_modal: bool = False


class EvaluatorType(StrEnum):
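A hedged usage sketch of the new config classes: building a `ModelConfig` that carries a vision backbone, relying on the dataclass defaults for every field not shown. Names follow this diff; nothing here is new API beyond what the diff adds.

```python
# Sketch only: construct a ModelConfig with the new VisionBackboneConfig attached.
# All other ModelConfig fields fall back to their dataclass defaults.
from olmo.config import ModelConfig, VisionBackboneConfig, VisionBackboneType

model_cfg = ModelConfig(
    d_model=1024,
    n_heads=16,
    n_layers=16,
    vision_backbone=VisionBackboneConfig(
        name=VisionBackboneType.linear,
        patch_width=16,
        patch_height=16,
        frozen=False,
    ),
)

# Code that consumes the config can branch on whether a backbone is configured.
assert model_cfg.vision_backbone is not None
print(model_cfg.vision_backbone.patch_width * model_cfg.vision_backbone.patch_height * 3)  # 768 values per patch
```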
37 changes: 30 additions & 7 deletions olmo/data/__init__.py
@@ -10,8 +10,16 @@
from .collator import DataCollator
from .iterable_dataset import IterableDataset
from .memmap_dataset import MemMapDataset
from .multi_modal_iterable_dataset import MultiModalIterableDataset

__all__ = ["MemMapDataset", "DataCollator", "IterableDataset", "build_eval_dataloader", "build_train_dataloader"]
__all__ = [
    "MemMapDataset",
    "DataCollator",
    "IterableDataset",
    "MultiModalIterableDataset",
    "build_eval_dataloader",
    "build_train_dataloader",
]


def build_memmap_dataset(
@@ -83,7 +91,6 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
    collator = DataCollator(
        pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id
    )
    dataset = build_memmap_dataset(train_config, train_config.data, include_instance_metadata=False)
    work_dir = Path(train_config.save_folder) / "train_data"
    if get_global_rank() == 0:
        if work_dir.is_dir() and not train_config.save_overwrite:
@@ -92,16 +99,32 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader:
            )
        else:
            work_dir.mkdir(exist_ok=True, parents=True)
    barrier()
    return DataLoader(
        IterableDataset(
            dataset,  # type: ignore

    if train_config.data.multi_modal:
        # TODO: this block will need to change a little when we integrate the actual
        # vision dataset, instead of the mock one.
        assert train_config.model.vision_backbone is not None
        dataset = MultiModalIterableDataset(
            pad_token_id=train_config.model.pad_token_id,
            max_sequence_length=train_config.model.max_sequence_length,
            vocab_size=train_config.model.vocab_size,
            patch_width=train_config.model.vision_backbone.patch_width,
            patch_height=train_config.model.vision_backbone.patch_height,
        )
    else:
        dataset = IterableDataset(
            build_memmap_dataset(train_config, train_config.data, include_instance_metadata=False),  # type: ignore
            train_config.global_train_batch_size,
            seed=train_config.seed + (train_config.epoch or 0),
            shuffle=True,
            drop_last=train_config.data.drop_last,
            work_dir=work_dir,
        ),
        )

    barrier()

    return DataLoader(
        dataset,
        batch_size=train_config.device_train_batch_size,
        drop_last=train_config.data.drop_last,
        collate_fn=collator,
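One difference worth noting between the two branches: the mem-mapped path appears to rely on `IterableDataset` (which gets the global batch size, seed, and work dir) to handle ordering and sharding across ranks, while the mock dataset shards itself by starting at its rank and striding by the world size (see `multi_modal_iterable_dataset.py` below). A tiny illustration of that striding, with assumed values:

```python
# Illustration only: which mock-example indices each rank generates when
# world_size is 4. Mirrors `index = self.rank; ...; index += self.world_size` below.
world_size = 4
for rank in range(world_size):
    indices = [rank + step * world_size for step in range(5)]
    print(f"rank {rank}: {indices}")
# rank 0: [0, 4, 8, 12, 16]
# rank 1: [1, 5, 9, 13, 17]
# ...
```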
28 changes: 28 additions & 0 deletions olmo/data/collator.py
@@ -23,10 +23,15 @@ def from_train_config(cls, config: TrainConfig) -> DataCollator:
    def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Dict[str, Any]:
        assert items
        max_len = max((len(x["input_ids"] if isinstance(x, dict) else x) for x in items))
        max_images = 0
        if items and isinstance(items[0], dict) and "image_offsets" in items[0]:
            max_images = max((len(x["image_offsets"]) for x in items))  # type: ignore
        all_input_ids = []
        all_attention_mask = []
        all_attention_bias = []
        all_label_mask = []
        all_image_patches = []
        all_image_offsets = []
        all_indices = []
        all_metadata = []
        for x in items:
@@ -92,6 +97,25 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Dict[str, Any]:
                )
            )

            # Image patches and offsets.
            image_offsets = x.get("image_offsets") if isinstance(x, dict) else None
            if image_offsets is not None:
                pad_shape = (0, max_images - len(image_offsets))
                image_patches = x["image_patches"]  # type: ignore
                image_patches = F.pad(
                    image_patches.to(dtype=torch.float),
                    (0, 0, 0, 0, 0, 0) + pad_shape,
                    value=0.0,
                )
                all_image_patches.append(image_patches)
                all_image_offsets.append(
                    F.pad(
                        image_offsets.to(dtype=torch.int32),
                        pad_shape,
                        value=-1,
                    )
                )

            # Indices.
            index = x.get("index") if isinstance(x, dict) else None
            if index is not None:
@@ -109,6 +133,10 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Dict[str, Any]:
            out["attention_bias"] = torch.stack(all_attention_bias)
        if all_label_mask:
            out["label_mask"] = torch.stack(all_label_mask)
        if all_image_patches:
            out["image_patches"] = torch.stack(all_image_patches)
        if all_image_offsets:
            out["image_offsets"] = torch.stack(all_image_offsets)
        if all_indices:
            out["index"] = torch.stack(all_indices)
        if all_metadata:
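To make the new padding behavior concrete, here is a small standalone example of what the collator does to the image fields of a single item: patches are zero-padded along the first (image-count) dimension up to the batch maximum, and offsets are padded with -1 so downstream code can tell real offsets from padding. This mirrors the `F.pad` calls above but is not the `DataCollator` class itself.

```python
# Standalone illustration of the image-field padding added to DataCollator.
import torch
import torch.nn.functional as F

max_images = 3  # the largest number of images in any item of the batch

# One item with 2 RGB patches of 16x16, at token offsets 4 and 100.
image_patches = torch.rand(2, 16, 16, 3)
image_offsets = torch.tensor([4, 100])

pad_shape = (0, max_images - len(image_offsets))  # pad only at the end of dim 0
padded_patches = F.pad(image_patches.to(dtype=torch.float), (0, 0, 0, 0, 0, 0) + pad_shape, value=0.0)
padded_offsets = F.pad(image_offsets.to(dtype=torch.int32), pad_shape, value=-1)

print(padded_patches.shape)  # torch.Size([3, 16, 16, 3]); the third patch is all zeros
print(padded_offsets)        # tensor([  4, 100,  -1], dtype=torch.int32); -1 marks padding
```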
60 changes: 60 additions & 0 deletions olmo/data/multi_modal_iterable_dataset.py
@@ -0,0 +1,60 @@
import random
from typing import Any, Dict, Iterator, Optional

import torch
import torch.utils.data

from ..torch_util import get_fs_local_rank, get_global_rank, get_world_size

__all__ = ["MultiModalIterableDataset"]


# TODO: at the moment this class is a mock dataset, it just generates random data.
# But this is where we should integrate the actual vision dataset when we get there.


class MultiModalIterableDataset(torch.utils.data.IterableDataset[Dict[str, Any]]):
    def __init__(
        self,
        *,
        pad_token_id: int,
        max_sequence_length: int,
        vocab_size: int,
        patch_width: int,
        patch_height: int,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        fs_local_rank: Optional[int] = None,
    ):
        self.pad_token_id = pad_token_id
        self.max_sequence_length = max_sequence_length
        self.vocab_size = vocab_size
        self.patch_width = patch_width
        self.patch_height = patch_height
        self.rank = rank if rank is not None else get_global_rank()
        self.fs_local_rank = fs_local_rank if fs_local_rank is not None else get_fs_local_rank()
        self.world_size = world_size if world_size is not None else get_world_size()

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        index = self.rank
        while True:
            # Generate mock input IDs.
            input_ids = torch.randint(0, self.vocab_size, (self.max_sequence_length,))
            # Make sure there are no padding tokens so far.
            input_ids.masked_fill_(input_ids == self.pad_token_id, self.pad_token_id + 1)
            # Determine where to place image patches.
            image_offsets = torch.tensor(
                sorted(random.sample(range(self.max_sequence_length), random.randint(1, 5)))
            )
            # Mask out patch location in input IDs.
            input_ids.index_fill_(0, image_offsets, self.pad_token_id)
            # Generate mock image patches.
            image_patches = torch.rand(len(image_offsets), self.patch_width, self.patch_height, 3)
            yield {
                "index": index,
                "input_ids": input_ids,
                "label_mask": input_ids != self.pad_token_id,
                "image_offsets": image_offsets,
                "image_patches": image_patches,
            }
            index += self.world_size
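A quick smoke test of the mock pipeline, roughly what `build_train_dataloader` assembles when `data.multi_modal` is true, but run in-process on a single rank. The plain string value for `pad_direction` is an assumption here; the trainer normally passes the parsed config value.

```python
# Sketch: feed the mock multi-modal dataset through the collator and
# inspect one collated batch. Values mirror configs/mm-tiny.yaml.
from torch.utils.data import DataLoader

from olmo.data import DataCollator, MultiModalIterableDataset

dataset = MultiModalIterableDataset(
    pad_token_id=1,
    max_sequence_length=2048,
    vocab_size=50280,
    patch_width=16,
    patch_height=16,
    world_size=1,
    rank=0,
    fs_local_rank=0,
)
collator = DataCollator(pad_direction="right", pad_token_id=1)
loader = DataLoader(dataset, batch_size=2, collate_fn=collator)

batch = next(iter(loader))
print(batch["input_ids"].shape)      # torch.Size([2, 2048])
print(batch["image_offsets"].shape)  # (2, max images in this batch), padded with -1
print(batch["image_patches"].shape)  # (2, max images, 16, 16, 3), zero-padded
```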