From aea19b7072fd3817454014683a2a480638785a4b Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 8 Feb 2024 15:35:44 -0800 Subject: [PATCH 1/7] integrate mock vision backbone into model --- olmo/config.py | 19 ++++++++++ olmo/model.py | 94 ++++++++++++++++++++++++++++++++++++++++++-------- 2 files changed, 98 insertions(+), 15 deletions(-) diff --git a/olmo/config.py b/olmo/config.py index c0f26b08b..bc0f033dd 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -33,6 +33,8 @@ "CompilerConfig", "LayerNormType", "InitFnType", + "VisionBackboneType", + "VisionBackboneConfig", "ModelConfig", "OptimizerType", "OptimizerConfig", @@ -225,6 +227,18 @@ class InitFnType(StrEnum): """ +class VisionBackboneType(StrEnum): + linear = "linear" + + +@dataclass +class VisionBackboneConfig(BaseConfig): + name: VisionBackboneType = VisionBackboneType.linear + patch_width: int = 16 + patch_height: int = 16 + frozen: bool = False + + @dataclass class ModelConfig(BaseConfig): """ @@ -297,6 +311,11 @@ class ModelConfig(BaseConfig): apply RoPE at the precision of the input. """ + vision_backbone: Optional[VisionBackboneConfig] = None + """ + Vision backbone settings for multi-modal models. + """ + flash_attention: bool = False """ If ``True``, use ``FlashAttention``. diff --git a/olmo/model.py b/olmo/model.py index cc621a37b..5731c7be0 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -41,6 +41,7 @@ FSDPWrapStrategy, LayerNormType, ModelConfig, + VisionBackboneType, ) from .exceptions import OlmoConfigurationError from .initialization import ModuleType, init_weights @@ -997,6 +998,44 @@ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointin block.set_activation_checkpointing(strategy) +class OlmoVisionBackbone(nn.Module): + def __init__(self, config: ModelConfig): + super().__init__() + self.config = config + + @classmethod + def build(cls, config: ModelConfig) -> OlmoVisionBackbone: + v_cfg = config.vision_backbone + assert v_cfg is not None + if v_cfg.name == VisionBackboneType.linear: + return OlmoLinearVisionBackbone(config) + else: + raise NotImplementedError(v_cfg.name) + + def reset_parameters(self): + pass + + +class OlmoLinearVisionBackbone(OlmoVisionBackbone): + def __init__(self, config: ModelConfig): + super().__init__(config) + v_cfg = self.config.vision_backbone + assert v_cfg is not None + self.ff = nn.Linear( + v_cfg.patch_width * v_cfg.patch_height * 3, self.config.d_model, device=self.config.init_device + ) + if v_cfg.frozen: + for param in self.ff.parameters(): + param.requires_grad = False + + def forward(self, image_patches: torch.Tensor) -> torch.Tensor: + batch_size, num_patches, *_ = image_patches.shape + # Reshape image patches from (batch_size, num_patches, patch_width, patch_height, 3) + # to (batch_size, num_patches, patch_width * patch_height * 3) + image_patches = image_patches.view(batch_size, num_patches, -1) + return self.ff(image_patches) + + class Olmo(nn.Module): def __init__(self, config: ModelConfig, init_params: bool = True): super().__init__() @@ -1056,6 +1095,7 @@ def __init__(self, config: ModelConfig, init_params: bool = True): self.transformer.update( {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)} ) + if not config.weight_tying: self.transformer.update( { @@ -1067,6 +1107,11 @@ def __init__(self, config: ModelConfig, init_params: bool = True): ) } ) + + self.vision_backbone: Optional[OlmoVisionBackbone] = None + if config.vision_backbone is not None: + self.vision_backbone = 
OlmoVisionBackbone.build(config) + # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights. if init_params and self.config.init_device != "meta": self.reset_parameters() @@ -1113,6 +1158,10 @@ def reset_parameters(self): if hasattr(self.transformer, "ff_out"): init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore + # Vision backbone. + if self.vision_backbone is not None: + self.vision_backbone.reset_parameters() + # Let the blocks handle themselves. if self.config.block_group_size == 1: for block in self.transformer.blocks: @@ -1140,6 +1189,8 @@ def forward( attention_mask: Optional[torch.Tensor] = None, attention_bias: Optional[torch.Tensor] = None, past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None, + image_patches: Optional[torch.Tensor] = None, + image_offsets: Optional[torch.Tensor] = None, use_cache: bool = False, last_logits_only: bool = False, ) -> OlmoOutput: @@ -1167,6 +1218,10 @@ def forward( :param past_key_values: Pre-computed keys and values for each attention block. Can be used to speed up sequential decoding. The `input_ids` which have their past given to this model should not be passed as `input_ids` as they have already been computed. + :param image_patches: For multi-modal models, image patch inputs of shape + `(batch_size, num_patches, patch_width, patch_height, 3)`. + :param image_offsets: For multi-modal models, specifies where in the input IDs the embedded image + patches should go. Shape `(batch_size, num_patches)`. :param use_cache: If `True`, return key and value tensors for each block. :param last_logits_only: If `True`, only compute the logits for the last token of each sequence. This can speed up decoding when you only care about the next token. @@ -1180,10 +1235,25 @@ def forward( else: past_length = past_key_values[0][0].size(-2) + img_emb: Optional[torch.Tensor] = None + if image_patches is not None: + # Get image patch embeddings. + assert self.vision_backbone is not None + assert image_offsets is not None + # shape: (batch_size, num_patches, d_model) + img_emb = self.vision_backbone(image_patches) + # Get embeddings of input. # shape: (batch_size, seq_len, d_model) x = self.transformer.wte(input_ids) # type: ignore + if img_emb is not None: + # Inject image patch embeddings into input embeddings. + assert image_offsets is not None + image_offsets_mask = image_offsets > 0 + batch_idx = torch.arange(0, batch_size).repeat_interleave(image_offsets_mask.sum(dim=-1)) + x.index_put_((batch_idx, image_offsets[image_offsets_mask]), img_emb[image_offsets_mask]) + if not (self.config.alibi or self.config.rope): # Get positional embeddings. # shape: (1, seq_len) @@ -1309,13 +1379,13 @@ def forward( def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None): if wrap_strategy is None: return None - if wrap_strategy == FSDPWrapStrategy.by_block: + elif wrap_strategy == FSDPWrapStrategy.by_block: def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): del nonwrapped_numel if recurse: return True # always recurse for simplicity - return isinstance(module, OlmoBlock) + return isinstance(module, (OlmoVisionBackbone, OlmoBlock)) return fsdp_wrap_fn elif wrap_strategy == FSDPWrapStrategy.by_block_and_size: def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): del nonwrapped_numel if recurse: - # Determine if we should recurse. 
- return not isinstance(module, OlmoBlock) - else: - # Determine if we should wrap. - return isinstance(module, (OlmoBlock, nn.Linear, nn.Embedding)) + return True # always recurse for simplicity + return isinstance(module, (OlmoVisionBackbone, OlmoBlock, nn.Linear, nn.Embedding)) return fsdp_wrap_fn elif wrap_strategy == FSDPWrapStrategy.by_block_group: @@ -1340,7 +1407,7 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): del nonwrapped_numel if recurse: return True # always recurse for simplicity - return isinstance(module, OlmoBlockGroup) + return isinstance(module, (OlmoVisionBackbone, OlmoBlockGroup)) return fsdp_wrap_fn elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size: @@ -1352,17 +1419,14 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): del nonwrapped_numel if recurse: - # Determine if we should recurse. - return not isinstance(module, OlmoBlockGroup) - else: - # Determine if we should wrap. - return isinstance(module, (OlmoBlockGroup, nn.Linear, nn.Embedding)) + return True # always recurse for simplicity + return isinstance(module, (OlmoVisionBackbone, OlmoBlockGroup, nn.Linear, nn.Embedding)) return fsdp_wrap_fn elif wrap_strategy == FSDPWrapStrategy.size_based: from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy - return size_based_auto_wrap_policy + return partial(size_based_auto_wrap_policy, force_leaf_modules={OlmoVisionBackbone}) elif wrap_strategy in { FSDPWrapStrategy.one_in_two, FSDPWrapStrategy.one_in_three, @@ -1380,7 +1444,7 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0): del nonwrapped_numel if recurse: return True # always recurse for simplicity - return isinstance(module, OlmoBlock) and module.layer_id % c == 0 + return isinstance(module, (OlmoVisionBackbone, OlmoBlock)) and module.layer_id % c == 0 return fsdp_wrap_fn else: From ca4673e73b8322c783750f267c7292b029e96e2e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 8 Feb 2024 15:55:33 -0800 Subject: [PATCH 2/7] Pass image inputs to model within trainer --- olmo/train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/olmo/train.py b/olmo/train.py index f459ad88d..fdacee5cb 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -536,6 +536,8 @@ def model_forward( input_ids=batch["input_ids"], attention_mask=batch.get("attention_mask"), attention_bias=batch.get("attention_bias"), + image_patches=batch.get("image_patches"), + image_offsets=batch.get("image_offsets"), ).logits logits_for_loss = logits[..., :-1, :].contiguous() # shape: (batch_size * seq_len, vocab_size) From 472c386ba38b5f3f6fe4bf4ded826799f8bd0ddf Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 12 Feb 2024 14:22:29 -0800 Subject: [PATCH 3/7] Update data collator for image fields --- olmo/data/collator.py | 28 ++++++++++++++++++++++++++++ tests/data/collator_test.py | 22 ++++++++++++++++++++++ 2 files changed, 50 insertions(+) diff --git a/olmo/data/collator.py b/olmo/data/collator.py index d86a0b9af..0345b4976 100644 --- a/olmo/data/collator.py +++ b/olmo/data/collator.py @@ -23,10 +23,15 @@ def from_train_config(cls, config: TrainConfig) -> DataCollator: def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Dict[str, Any]: assert items max_len = max((len(x["input_ids"] if isinstance(x, dict) else x) for x in items)) + max_images = 0 + if items and isinstance(items[0], dict) and "image_offsets" in items[0]: + max_images = 
max((len(x["image_offsets"]) for x in items)) # type: ignore all_input_ids = [] all_attention_mask = [] all_attention_bias = [] all_label_mask = [] + all_image_patches = [] + all_image_offsets = [] all_indices = [] all_metadata = [] for x in items: @@ -92,6 +97,25 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di ) ) + # Image patches and offsets. + image_offsets = x.get("image_offsets") if isinstance(x, dict) else None + if image_offsets is not None: + pad_shape = (0, max_images - len(image_offsets)) + image_patches = x["image_patches"] # type: ignore + image_patches = F.pad( + image_patches.to(dtype=torch.float), + (0, 0, 0, 0, 0, 0) + pad_shape, + value=0.0, + ) + all_image_patches.append(image_patches) + all_image_offsets.append( + F.pad( + image_offsets.to(dtype=torch.int32), + pad_shape, + value=-1, + ) + ) + # Indices. index = x.get("index") if isinstance(x, dict) else None if index is not None: @@ -109,6 +133,10 @@ def __call__(self, items: Union[List[Dict[str, Any]], List[torch.Tensor]]) -> Di out["attention_bias"] = torch.stack(all_attention_bias) if all_label_mask: out["label_mask"] = torch.stack(all_label_mask) + if all_image_patches: + out["image_patches"] = torch.stack(all_image_patches) + if all_image_offsets: + out["image_offsets"] = torch.stack(all_image_offsets) if all_indices: out["index"] = torch.stack(all_indices) if all_metadata: diff --git a/tests/data/collator_test.py b/tests/data/collator_test.py index e94451313..ecc670a41 100644 --- a/tests/data/collator_test.py +++ b/tests/data/collator_test.py @@ -129,3 +129,25 @@ def test_collate_with_label_mask(train_config, pad_direction): [[True, False, True, True], [False, True, True, False]], ) ).all() + + +def test_collate_with_images(): + collator = DataCollator(pad_direction=PaddingDirection.right, pad_token_id=0) + patch_size = 5 # width and height + + inputs = [ + { + "input_ids": torch.tensor([1, 2, 3, 0, 4]), + "image_offsets": torch.tensor([3]), + "image_patches": torch.rand(1, patch_size, patch_size, 3), + }, + { + "input_ids": torch.tensor([4, 0, 0, 1, 2, 3]), + "image_offsets": torch.tensor([1, 2]), + "image_patches": torch.rand(2, patch_size, patch_size, 3), + }, + ] + batch = collator(inputs) # type: ignore + assert batch["image_offsets"].shape == (2, 2) + assert (batch["image_offsets"] == torch.tensor([[3, -1], [1, 2]])).all() + assert batch["image_patches"].shape == (2, 2, patch_size, patch_size, 3) From b81a99ebbce15e449a4f6fcc2a4635420913368f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 12 Feb 2024 15:06:33 -0800 Subject: [PATCH 4/7] Add mock multi-modal dataset --- olmo/config.py | 1 + olmo/data/__init__.py | 35 +++++++++++--- olmo/data/multi_modal_iterable_dataset.py | 56 +++++++++++++++++++++++ 3 files changed, 85 insertions(+), 7 deletions(-) create mode 100644 olmo/data/multi_modal_iterable_dataset.py diff --git a/olmo/config.py b/olmo/config.py index bc0f033dd..ae9019274 100644 --- a/olmo/config.py +++ b/olmo/config.py @@ -543,6 +543,7 @@ class DataConfig(BaseConfig): prefetch_factor: Optional[int] = None persistent_workers: bool = False timeout: int = 0 + multi_modal: bool = False class EvaluatorType(StrEnum): diff --git a/olmo/data/__init__.py b/olmo/data/__init__.py index 52421b57a..1e50547a1 100644 --- a/olmo/data/__init__.py +++ b/olmo/data/__init__.py @@ -10,8 +10,16 @@ from .collator import DataCollator from .iterable_dataset import IterableDataset from .memmap_dataset import MemMapDataset +from .multi_modal_iterable_dataset import 
MultiModalIterableDataset -__all__ = ["MemMapDataset", "DataCollator", "IterableDataset", "build_eval_dataloader", "build_train_dataloader"] +__all__ = [ + "MemMapDataset", + "DataCollator", + "IterableDataset", + "MultiModalIterableDataset", + "build_eval_dataloader", + "build_train_dataloader", +] def build_memmap_dataset( @@ -83,7 +91,6 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader: collator = DataCollator( pad_direction=train_config.data.pad_direction, pad_token_id=train_config.model.pad_token_id ) - dataset = build_memmap_dataset(train_config, train_config.data, include_instance_metadata=False) work_dir = Path(train_config.save_folder) / "train_data" if get_global_rank() == 0: if work_dir.is_dir() and not train_config.save_overwrite: @@ -92,16 +99,30 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader: ) else: work_dir.mkdir(exist_ok=True, parents=True) - barrier() - return DataLoader( - IterableDataset( - dataset, # type: ignore + + if train_config.data.multi_modal: + assert train_config.model.vision_backbone is not None + dataset = MultiModalIterableDataset( + pad_token_id=train_config.model.pad_token_id, + max_sequence_length=train_config.model.max_sequence_length, + vocab_size=train_config.model.vocab_size, + patch_width=train_config.model.vision_backbone.patch_width, + patch_height=train_config.model.vision_backbone.patch_height, + ) + else: + dataset = IterableDataset( + build_memmap_dataset(train_config, train_config.data, include_instance_metadata=False), # type: ignore train_config.global_train_batch_size, seed=train_config.seed + (train_config.epoch or 0), shuffle=True, drop_last=train_config.data.drop_last, work_dir=work_dir, - ), + ) + + barrier() + + return DataLoader( + dataset, batch_size=train_config.device_train_batch_size, drop_last=train_config.data.drop_last, collate_fn=collator, diff --git a/olmo/data/multi_modal_iterable_dataset.py b/olmo/data/multi_modal_iterable_dataset.py new file mode 100644 index 000000000..f1787fd47 --- /dev/null +++ b/olmo/data/multi_modal_iterable_dataset.py @@ -0,0 +1,56 @@ +import random +from typing import Any, Dict, Iterator, Optional + +import torch +import torch.utils.data + +from ..torch_util import get_fs_local_rank, get_global_rank, get_world_size + +__all__ = ["MultiModalIterableDataset"] + + +class MultiModalIterableDataset(torch.utils.data.IterableDataset[Dict[str, Any]]): + def __init__( + self, + *, + pad_token_id: int, + max_sequence_length: int, + vocab_size: int, + patch_width: int, + patch_height: int, + world_size: Optional[int] = None, + rank: Optional[int] = None, + fs_local_rank: Optional[int] = None, + ): + self.pad_token_id = pad_token_id + self.max_sequence_length = max_sequence_length + self.vocab_size = vocab_size + self.patch_width = patch_width + self.patch_height = patch_height + self.rank = rank if rank is not None else get_global_rank() + self.fs_local_rank = fs_local_rank if fs_local_rank is not None else get_fs_local_rank() + self.world_size = world_size if world_size is not None else get_world_size() + + def __iter__(self) -> Iterator[Dict[str, Any]]: + index = self.rank + while True: + # Generate mock input IDs. + input_ids = torch.randint(0, self.vocab_size, (self.max_sequence_length,)) + # Make sure there are no padding tokens so far. + input_ids.masked_fill_(input_ids == self.pad_token_id, self.pad_token_id + 1) + # Determine where to place image patches. 
+ image_offsets = torch.tensor( + sorted(random.sample(range(self.max_sequence_length), random.randint(1, 5))) + ) + # Mask out patch location in input IDs. + input_ids.index_fill_(0, image_offsets, self.pad_token_id) + # Generate mock image patches. + image_patches = torch.rand(len(image_offsets), self.patch_width, self.patch_height, 3) + yield { + "index": index, + "input_ids": input_ids, + "label_mask": input_ids != self.pad_token_id, + "image_offsets": image_offsets, + "image_patches": image_patches, + } + index += self.world_size From 25b9f3caa6c3dcd82461fe952ac92dd1b7665e37 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Mon, 12 Feb 2024 15:06:46 -0800 Subject: [PATCH 5/7] Add config for testing --- configs/mm-tiny.yaml | 101 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 configs/mm-tiny.yaml diff --git a/configs/mm-tiny.yaml b/configs/mm-tiny.yaml new file mode 100644 index 000000000..15c3107b1 --- /dev/null +++ b/configs/mm-tiny.yaml @@ -0,0 +1,101 @@ +run_name: mm-tiny +seed: 6198 +dry_run: false + +wandb: null + +model: + d_model: 1024 + n_heads: 16 + n_layers: 16 + mlp_ratio: 8 + weight_tying: false + alibi: false + rope: true + vision_backbone: + name: linear + patch_width: 16 + patch_height: 16 + frozen: false + flash_attention: true + attention_dropout: 0.0 + attention_layer_norm: false + multi_query_attention: false + include_bias: false + block_type: sequential + layer_norm_type: default + layer_norm_with_affine: false + bias_for_layer_norm: false + attention_layer_norm_with_affine: false + activation_type: swiglu + residual_dropout: 0.0 + embedding_dropout: 0.0 + max_sequence_length: 2048 + vocab_size: 50280 + embedding_size: 50304 + eos_token_id: 0 + pad_token_id: 1 + init_device: meta + init_fn: mitchell + +compile: null + # fullgraph: false + +optimizer: + name: adamw + learning_rate: 2e-4 + weight_decay: 0.1 + betas: + - 0.9 + - 0.95 + metrics_log_interval: 10 + +scheduler: + name: linear_with_warmup + t_warmup: 200 + alpha_f: 0.001 + +tokenizer: + identifier: tokenizers/allenai_eleuther-ai-gpt-neox-20b-pii-special.json + truncate_direction: right + +save_folder: runs/${run_name} +remote_save_folder: null +save_overwrite: true +# Sharded checkpoints (best for restarts) +save_interval: 1000 +save_num_checkpoints_to_keep: -1 +# Unsharded checkpoints (for final storage) +save_interval_unsharded: null # getting errors on LUMI right now +save_num_unsharded_checkpoints_to_keep: -1 + +load_path: null + +max_duration: 10e9 +global_train_batch_size: 512 +device_train_microbatch_size: 2 +time_limit: null + +precision: amp_bf16 + +fsdp: + wrapping_strategy: by_block_and_size + precision: mixed + +max_grad_norm: 1.0 +max_grad_norm_ratio: null + +speed_monitor: + window_size: 20 + +eval_interval: ${save_interval} +eval_subset_num_batches: -1 +device_eval_batch_size: ${device_train_microbatch_size} +evaluators: [] + +data: + pad_direction: right + num_workers: 0 + drop_last: true + multi_modal: true + paths: [] From 8976f36ea161697c71fb9daaa3e0abe5bf2a9940 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 13 Feb 2024 12:27:05 -0800 Subject: [PATCH 6/7] fix dtype --- olmo/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/olmo/model.py b/olmo/model.py index c717f438f..91d684cc0 100644 --- a/olmo/model.py +++ b/olmo/model.py @@ -1261,7 +1261,9 @@ def forward( # Inject image patch embeddings into input embeddings. 
assert image_offsets is not None image_offsets_mask = image_offsets > 0 - batch_idx = torch.arange(0, batch_size).repeat_interleave(image_offsets_mask.sum(dim=-1)) + batch_idx = torch.arange(0, batch_size, device=x.device).repeat_interleave( + image_offsets_mask.sum(dim=-1) + ) x.index_put_((batch_idx, image_offsets[image_offsets_mask]), img_emb[image_offsets_mask]) if not (self.config.alibi or self.config.rope): From e39737305dc8216681dd7ec9117237e13484774e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Tue, 13 Feb 2024 12:34:04 -0800 Subject: [PATCH 7/7] Add some comments --- olmo/data/__init__.py | 2 ++ olmo/data/multi_modal_iterable_dataset.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/olmo/data/__init__.py b/olmo/data/__init__.py index 1e50547a1..4050d0e40 100644 --- a/olmo/data/__init__.py +++ b/olmo/data/__init__.py @@ -101,6 +101,8 @@ def build_train_dataloader(train_config: TrainConfig) -> DataLoader: work_dir.mkdir(exist_ok=True, parents=True) if train_config.data.multi_modal: + # TODO: this block will need to change a little when we integrate the actual + # vision dataset, instead of the mock one. assert train_config.model.vision_backbone is not None dataset = MultiModalIterableDataset( pad_token_id=train_config.model.pad_token_id, diff --git a/olmo/data/multi_modal_iterable_dataset.py b/olmo/data/multi_modal_iterable_dataset.py index f1787fd47..9b4663e82 100644 --- a/olmo/data/multi_modal_iterable_dataset.py +++ b/olmo/data/multi_modal_iterable_dataset.py @@ -9,6 +9,10 @@ __all__ = ["MultiModalIterableDataset"] +# TODO: at the moment this class is a mock dataset that just generates random data. +# But this is where we should integrate the actual vision dataset when we get there. + + class MultiModalIterableDataset(torch.utils.data.IterableDataset[Dict[str, Any]]): def __init__( self,
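
Usage notes (a sketch, not part of the patches above):

The following is a minimal end-to-end smoke test of the pieces this series wires together: the mock MultiModalIterableDataset, the DataCollator's new image fields, and the image-patch injection in Olmo.forward(). It assumes the patched checkout is importable, the tiny ModelConfig values are illustrative rather than taken from configs/mm-tiny.yaml, and PaddingDirection is assumed to be importable from olmo.config as in the existing collator tests.

import torch

from olmo.config import ModelConfig, PaddingDirection, VisionBackboneConfig, VisionBackboneType
from olmo.data import DataCollator, MultiModalIterableDataset
from olmo.model import Olmo

config = ModelConfig(
    d_model=128,
    n_heads=4,
    n_layers=2,
    max_sequence_length=64,
    vocab_size=100,
    embedding_size=128,
    eos_token_id=0,
    pad_token_id=1,
    vision_backbone=VisionBackboneConfig(
        name=VisionBackboneType.linear, patch_width=16, patch_height=16, frozen=False
    ),
)
model = Olmo(config).eval()

# The mock dataset yields dicts with input_ids, label_mask, image_patches, and image_offsets.
dataset = MultiModalIterableDataset(
    pad_token_id=config.pad_token_id,
    max_sequence_length=config.max_sequence_length,
    vocab_size=config.vocab_size,
    patch_width=config.vision_backbone.patch_width,
    patch_height=config.vision_backbone.patch_height,
    world_size=1,
    rank=0,
    fs_local_rank=0,
)
collator = DataCollator(pad_direction=PaddingDirection.right, pad_token_id=config.pad_token_id)

# Collate two samples: image_offsets are padded with -1 and image_patches are zero-padded
# up to the largest number of patches in the batch.
samples = iter(dataset)
batch = collator([next(samples), next(samples)])

with torch.no_grad():
    output = model(
        input_ids=batch["input_ids"],
        image_patches=batch["image_patches"],
        image_offsets=batch["image_offsets"],
    )
print(output.logits.shape)  # torch.Size([2, 64, 128]) with the config above

The forward pass only injects patches whose offset is greater than zero, which is also how it skips the -1 padding added by the collator.

This series never builds image_patches from a real image. As a reference, one way to produce the expected (num_patches, patch_width, patch_height, 3) layout from an RGB image tensor is sketched below; the helper name is hypothetical, it assumes the image dimensions are exact multiples of the patch size, and with square patches the width/height order is immaterial since the linear backbone flattens each patch anyway.

def image_to_patches(image: torch.Tensor, patch_width: int = 16, patch_height: int = 16) -> torch.Tensor:
    # image: (height, width, 3) with height % patch_height == 0 and width % patch_width == 0.
    height, width, channels = image.shape
    # Split the height/width axes into (grid, patch) pairs, then flatten the patch grid.
    patches = image.reshape(height // patch_height, patch_height, width // patch_width, patch_width, channels)
    patches = patches.permute(0, 2, 1, 3, 4).reshape(-1, patch_height, patch_width, channels)
    return patches

print(image_to_patches(torch.rand(224, 224, 3)).shape)  # torch.Size([196, 16, 16, 3])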