integrate mock vision backbone into model #441

Open
wants to merge 14 commits into mm-dev
19 changes: 19 additions & 0 deletions olmo/config.py
@@ -33,6 +33,8 @@
"CompilerConfig",
"LayerNormType",
"InitFnType",
"VisionBackboneType",
"VisionBackboneConfig",
"ModelConfig",
"OptimizerType",
"OptimizerConfig",
@@ -225,6 +227,18 @@ class InitFnType(StrEnum):
"""


class VisionBackboneType(StrEnum):
linear = "linear"


@dataclass
class VisionBackboneConfig(BaseConfig):
name: VisionBackboneType = VisionBackboneType.linear
patch_width: int = 16
patch_height: int = 16
frozen: bool = False


@dataclass
class ModelConfig(BaseConfig):
"""
@@ -297,6 +311,11 @@ class ModelConfig(BaseConfig):
apply RoPE at the precision of the input.
"""

vision_backbone: Optional[VisionBackboneConfig] = None
"""
Vision backbone settings for multi-modal models.
"""

flash_attention: bool = False
"""
If ``True``, use ``FlashAttention``.
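To make the new fields concrete, here is a minimal sketch of how the vision backbone could be enabled on a model config. The class and field names (`VisionBackboneConfig`, `VisionBackboneType.linear`, `patch_width`, `patch_height`, `frozen`, `vision_backbone`) come from this diff; constructing `ModelConfig` with its remaining fields left at their defaults is an assumption made only for illustration.

from olmo.config import ModelConfig, VisionBackboneConfig, VisionBackboneType

# Hypothetical usage: enable the (currently only) linear vision backbone variant.
config = ModelConfig(
    vision_backbone=VisionBackboneConfig(
        name=VisionBackboneType.linear,
        patch_width=16,
        patch_height=16,
        frozen=False,  # set True to keep the patch projection weights fixed
    ),
)
assert config.vision_backbone is not None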
94 changes: 79 additions & 15 deletions olmo/model.py
@@ -41,6 +41,7 @@
FSDPWrapStrategy,
LayerNormType,
ModelConfig,
VisionBackboneType,
)
from .exceptions import OlmoConfigurationError
from .initialization import ModuleType, init_weights
@@ -997,6 +998,44 @@ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointin
block.set_activation_checkpointing(strategy)


class OlmoVisionBackbone(nn.Module):
def __init__(self, config: ModelConfig):
super().__init__()
self.config = config

@classmethod
def build(cls, config: ModelConfig) -> OlmoVisionBackbone:
v_cfg = config.vision_backbone
assert v_cfg is not None
if v_cfg.name == VisionBackboneType.linear:
return OlmoLinearVisionBackbone(config)
else:
raise NotImplementedError(v_cfg.name)

def reset_parameters(self):
pass


class OlmoLinearVisionBackbone(OlmoVisionBackbone):
def __init__(self, config: ModelConfig):
super().__init__(config)
v_cfg = self.config.vision_backbone
assert v_cfg is not None
self.ff = nn.Linear(
v_cfg.patch_width * v_cfg.patch_height * 3, self.config.d_model, device=self.config.init_device
)
if v_cfg.frozen:
for param in self.ff.parameters():
param.requires_grad = False

def forward(self, image_patches: torch.Tensor) -> torch.Tensor:
batch_size, num_patches, *_ = image_patches.shape
# Reshape image patches from (batch_size, num_patches, patch_width, patch_height, 3)
# to (batch_size, num_patches, patch_width * patch_height * 3)
image_patches = image_patches.view(batch_size, num_patches, -1)
return self.ff(image_patches)
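        # Illustrative note (not in the diff itself): with the default 16x16 patches,
        # an input of shape (batch_size, num_patches, 16, 16, 3) flattens to
        # (batch_size, num_patches, 16 * 16 * 3) = (batch_size, num_patches, 768),
        # which self.ff then projects to (batch_size, num_patches, d_model).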


class Olmo(nn.Module):
def __init__(self, config: ModelConfig, init_params: bool = True):
super().__init__()
@@ -1056,6 +1095,7 @@ def __init__(self, config: ModelConfig, init_params: bool = True):
self.transformer.update(
{"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
)

if not config.weight_tying:
self.transformer.update(
{
@@ -1067,6 +1107,11 @@
)
}
)

self.vision_backbone: Optional[OlmoVisionBackbone] = None
if config.vision_backbone is not None:
self.vision_backbone = OlmoVisionBackbone.build(config)

# When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
if init_params and self.config.init_device != "meta":
self.reset_parameters()
@@ -1113,6 +1158,10 @@ def reset_parameters(self):
if hasattr(self.transformer, "ff_out"):
init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out) # type: ignore

# Vision backbone.
if self.vision_backbone is not None:
self.vision_backbone.reset_parameters()

# Let the blocks handle themselves.
if self.config.block_group_size == 1:
for block in self.transformer.blocks:
@@ -1140,6 +1189,8 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
attention_bias: Optional[torch.Tensor] = None,
past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
image_patches: Optional[torch.Tensor] = None,
image_offsets: Optional[torch.Tensor] = None,
use_cache: bool = False,
last_logits_only: bool = False,
) -> OlmoOutput:
@@ -1167,6 +1218,10 @@ def forward(
:param past_key_values: Pre-computed keys and values for each attention block.
Can be used to speed up sequential decoding. The `input_ids` which have
their past given to this model should not be passed as `input_ids` as they have already been computed.
        :param image_patches: For multi-modal models, image patch inputs of shape
            `(batch_size, num_patches, patch_width, patch_height, 3)`.
        :param image_offsets: For multi-modal models, specifies where in the input IDs the embedded image
            patches should go. Shape `(batch_size, num_patches)`; use ``-1`` to pad rows with fewer patches.
:param use_cache: If `True`, return key and value tensors for each block.
:param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
This can speed up decoding when you only care about the next token.
@@ -1180,10 +1235,25 @@ def forward(
else:
past_length = past_key_values[0][0].size(-2)

img_emb: Optional[torch.Tensor] = None
if image_patches is not None:
# Get image patch embeddings.
assert self.vision_backbone is not None
assert image_offsets is not None
# shape: (batch_size, num_patches, d_model)
img_emb = self.vision_backbone(image_patches)

# Get embeddings of input.
# shape: (batch_size, seq_len, d_model)
x = self.transformer.wte(input_ids) # type: ignore

if img_emb is not None:
# Inject image patch embeddings into input embeddings.
assert image_offsets is not None
image_offsets_mask = image_offsets > 0
batch_idx = torch.arange(0, batch_size).repeat_interleave(image_offsets_mask.sum(dim=-1))
x.index_put_((batch_idx, image_offsets[image_offsets_mask]), img_emb[image_offsets_mask])
Review comment from the PR author (Member):

This took some thinking, but you can validate it with this little example:

import torch

B = 2
S = 8
D = 16
P = 3  # num patches (max across instances)
x = torch.zeros(B, S, D)

img_emb = torch.rand(B, P, D)

# use -1 for padding
image_offsets = torch.tensor([[1, 5, 6], [3, -1, -1]])
assert image_offsets.shape == (B, P)

image_offsets_mask = image_offsets > 0
batch_idx = torch.arange(0, B).repeat_interleave(image_offsets_mask.sum(dim=-1))
x.index_put_((batch_idx, image_offsets[image_offsets_mask]), img_emb[image_offsets_mask])
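As an extra sanity check on top of the author's snippet (not part of the original comment), the loop below verifies that each non-padding offset row ends up holding the corresponding patch embedding:

# Every valid (non -1) offset position should now equal its patch embedding.
for b in range(B):
    for p in range(P):
        if image_offsets[b, p] > 0:
            assert torch.equal(x[b, image_offsets[b, p]], img_emb[b, p])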


if not (self.config.alibi or self.config.rope):
# Get positional embeddings.
# shape: (1, seq_len)
@@ -1309,25 +1379,22 @@ def forward(
def get_fsdp_wrap_policy(self, wrap_strategy: Optional[FSDPWrapStrategy] = None):
if wrap_strategy is None:
return None
-        if wrap_strategy == FSDPWrapStrategy.by_block:
+        elif wrap_strategy == FSDPWrapStrategy.by_block:

def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
if recurse:
return True # always recurse for simplicity
-                return isinstance(module, OlmoBlock)
+                return isinstance(module, (OlmoVisionBackbone, OlmoBlock))

return fsdp_wrap_fn
elif wrap_strategy == FSDPWrapStrategy.by_block_and_size:

def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
if recurse:
-                    # Determine if we should recurse.
-                    return not isinstance(module, OlmoBlock)
-                else:
-                    # Determine if we should wrap.
-                    return isinstance(module, (OlmoBlock, nn.Linear, nn.Embedding))
+                    return True  # always recurse for simplicity
+                return isinstance(module, (OlmoVisionBackbone, OlmoBlock, nn.Linear, nn.Embedding))

return fsdp_wrap_fn
elif wrap_strategy == FSDPWrapStrategy.by_block_group:
@@ -1340,7 +1407,7 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
if recurse:
return True # always recurse for simplicity
-                return isinstance(module, OlmoBlockGroup)
+                return isinstance(module, (OlmoVisionBackbone, OlmoBlockGroup))

return fsdp_wrap_fn
elif wrap_strategy == FSDPWrapStrategy.by_block_group_and_size:
@@ -1352,17 +1419,14 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
if recurse:
-                    # Determine if we should recurse.
-                    return not isinstance(module, OlmoBlockGroup)
-                else:
-                    # Determine if we should wrap.
-                    return isinstance(module, (OlmoBlockGroup, nn.Linear, nn.Embedding))
+                    return True  # always recurse for simplicity
+                return isinstance(module, (OlmoVisionBackbone, OlmoBlockGroup, nn.Linear, nn.Embedding))

return fsdp_wrap_fn
elif wrap_strategy == FSDPWrapStrategy.size_based:
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

-            return size_based_auto_wrap_policy
+            return partial(size_based_auto_wrap_policy, force_leaf_modules={OlmoVisionBackbone})
elif wrap_strategy in {
FSDPWrapStrategy.one_in_two,
FSDPWrapStrategy.one_in_three,
@@ -1380,7 +1444,7 @@ def fsdp_wrap_fn(module, recurse: bool = True, nonwrapped_numel: int = 0):
del nonwrapped_numel
if recurse:
return True # always recurse for simplicity
-                return isinstance(module, OlmoBlock) and module.layer_id % c == 0
+                return isinstance(module, (OlmoVisionBackbone, OlmoBlock)) and module.layer_id % c == 0

return fsdp_wrap_fn
else:
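Finally, a rough end-to-end sketch of what a forward pass with image inputs could look like under this change. The argument names, tensor shapes, and the ``-1`` padding convention come from the diff and the review comment above; the default-constructed config, the dummy tensors, and the `output.logits` access are assumptions for illustration, not something the PR demonstrates.

import torch

from olmo.config import ModelConfig, VisionBackboneConfig
from olmo.model import Olmo

# Hypothetical setup: default config plus the linear vision backbone.
config = ModelConfig(vision_backbone=VisionBackboneConfig())
model = Olmo(config)

batch_size, seq_len, num_patches = 2, 32, 3
pw, ph = config.vision_backbone.patch_width, config.vision_backbone.patch_height

input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
image_patches = torch.rand(batch_size, num_patches, pw, ph, 3)  # RGB patches
image_offsets = torch.tensor([[1, 5, 6], [3, -1, -1]])  # -1 marks padding

output = model(input_ids, image_patches=image_patches, image_offsets=image_offsets)
print(output.logits.shape)  # expected: (batch_size, seq_len, vocab size)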