diff --git a/CHANGELOG_DEV.md b/CHANGELOG_DEV.md index 36a41c9c9..32c54ed41 100644 --- a/CHANGELOG_DEV.md +++ b/CHANGELOG_DEV.md @@ -85,3 +85,19 @@ This PR mainly addresses the warmstart of model training, e.g., after GPU crashe **Breaking Changes** * the settings part of the configs have been completely refactored + + +## PR #263 CoCa model updates + +This PR adds the following updates to the CoCa model: + + +**General Changes** +* add the AudioTransformer model +* update the VisionTransformer model to support video +* add the MultimodalWebDataset dataset for loading audio-text, image-text, and video-text data in the WebDataset format +* add a multi-loss function for specifying a weighted sum of different losses +* update the CoCa model to include encoders for video and audio + +**Breaking Changes** +* the LLMDataLoader now wraps a PyTorch DataLoader object instead of inheriting from it. diff --git a/config_files/training/config_coca_img_aud_vid_dataset.yaml b/config_files/training/config_coca_img_aud_vid_dataset.yaml new file mode 100644 index 000000000..4003e0e3d --- /dev/null +++ b/config_files/training/config_coca_img_aud_vid_dataset.yaml @@ -0,0 +1,566 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpoint_saving_path: data/checkpoints + train_dataset_path: ./data/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 2 + checkpointing_interval_in_steps: 2 + evaluation_interval_in_steps: 2 + consistency_enforcement: + enforce_tokens_per_step_consistency: true + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 10 + sequence_length: 256 + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_num_steps + config: + num_steps: ${settings.training_target.num_target_steps} + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: # for the batch progress subscriber + component_key: number_conversion + variant_key: num_steps_from_num_samples + config: + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + global_num_samples: ${settings.coca_example_settings.train_num_samples} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + local_num_seen_batches: 0 + last_step: -1 + coca_example_settings: + train_num_samples: 64 + val_num_samples: 32 + +tokenizer: + component_key: tokenizer + variant_key: pretrained_hf_tokenizer + config: + pretrained_model_name_or_path: openai/clip-vit-base-patch32 + padding: true + max_length: ${settings.step_profile.sequence_length} + +collate_fn: + component_key: collate_fn + variant_key: coca_collator + config: + sample_keys: + - images + - audio + - audio_len + - video + - ${settings.referencing_keys.sample_key} + target_keys: [] + text_sample_key: ${settings.referencing_keys.sample_key} + text_target_key: 
${settings.referencing_keys.target_key} + +train_audio_transform: + component_key: transform + variant_key: audio_transform + config: + is_training: True + block_size_audio_encoder: ${model_raw.config.audio_encoder_config.block_size} + freq_domain_mask_length: 30 + time_domain_mask_length: 100 + +train_image_transform: + component_key: transform + variant_key: image_transform + config: + is_training: True + input_size: ${model_raw.config.image_encoder_config.img_size} + +train_video_transform: + component_key: transform + variant_key: video_transform + config: + is_training: True + hflip: 0.5 + color_jitter: [0.5, 0.5, 0.5, 0.5] + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + input_size: ${model_raw.config.video_encoder_config.img_size} + num_frames: ${model_raw.config.video_encoder_config.num_video_frames} + +val_audio_transform: + component_key: transform + variant_key: audio_transform + config: + is_training: False + block_size_audio_encoder: ${model_raw.config.audio_encoder_config.block_size} + freq_domain_mask_length: 30 + time_domain_mask_length: 100 + +val_image_transform: + component_key: transform + variant_key: image_transform + config: + is_training: False + input_size: ${model_raw.config.image_encoder_config.img_size} + +val_video_transform: + component_key: transform + variant_key: video_transform + config: + is_training: False + hflip: 0.0 + color_jitter: [0.0, 0.0, 0.0, 0.0] + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + input_size: ${model_raw.config.video_encoder_config.img_size} + num_frames: ${model_raw.config.video_encoder_config.num_video_frames} + +text_transform: + component_key: transform + variant_key: text_transform + config: + tokenizer: + instance_key: tokenizer + pass_type: BY_REFERENCE + +train_video_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "youcook2/training/000000.tar" + is_audio_video: ${model_raw.config.is_audio_video} + modality_key_mapping: + TEXT: ["json", "input_ids"] + VIDEO: ["mp4", "video"] + modality_transforms: + VIDEO: + instance_key: train_video_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 100_000 + +val_video_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "youcook2/training/000000.tar" + is_audio_video: ${model_raw.config.is_audio_video} + modality_key_mapping: + TEXT: ["json", "input_ids"] + VIDEO: ["mp4", "video"] + modality_transforms: + VIDEO: + instance_key: val_video_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 10 + +train_audio_dataset_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "commonvoice/commonvoice_17_dev_wav_000001.tar" + modality_key_mapping: + TEXT: ["transcript.txt", "input_ids"] # source and target keys + AUDIO: ["wav", "audio"] + modality_transforms: + AUDIO: + instance_key: train_audio_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 30000 + +val_audio_dataset_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "commonvoice/commonvoice_17_dev_wav_000001.tar" + modality_key_mapping: + TEXT: ["transcript.txt", "input_ids"] # source and target keys + AUDIO: ["wav", "audio"] + modality_transforms: + AUDIO: + instance_key: val_audio_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + 
num_samples: 10 + +train_coco_dataset_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "coco_captions/data/train/000000.tar" + modality_key_mapping: + TEXT: ["json_text0", "input_ids"] + IMAGE: ["jpg", "images"] + modality_transforms: + IMAGE: + instance_key: train_image_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 10 + +val_coco_dataset_builder: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: "coco_captions/data/train/000000.tar" + modality_key_mapping: + TEXT: ["json_text0", "input_ids"] + IMAGE: ["jpg", "images"] + modality_transforms: + IMAGE: + instance_key: val_image_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 10 + + +train_dataset: + component_key: dataset + variant_key: web_dataset + config: + builders: + - instance_key: train_audio_dataset_builder + pass_type: BY_REFERENCE + - instance_key: train_coco_dataset_builder + pass_type: BY_REFERENCE + - instance_key: train_video_builder + pass_type: BY_REFERENCE + mixing_ratios: [0.5, 0.4, 0.1] + batch_size: ${settings.step_profile.local_train_micro_batch_size} + shardshuffle: 100 + repeat: false + resample: false + shuffle_buffer: 10_000 + +val_dataset: + component_key: dataset + variant_key: web_dataset + config: + builders: + - instance_key: val_audio_dataset_builder + pass_type: BY_REFERENCE + - instance_key: val_coco_dataset_builder + pass_type: BY_REFERENCE + - instance_key: val_video_builder + pass_type: BY_REFERENCE + mixing_ratios: [0.5, 0.4, 0.1] + batch_size: ${settings.step_profile.local_train_micro_batch_size} + shardshuffle: 1000 + repeat: true + resample: true + shuffle_buffer: 10_000 + +train_dataloader: + component_key: data_loader + variant_key: web_dataloader + config: + num_workers: 8 + pin_memory: true + drop_last: true + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_size: ${settings.step_profile.local_train_micro_batch_size} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +val_dataloader: + component_key: data_loader + variant_key: web_dataloader + config: + num_workers: 8 + pin_memory: true + drop_last: false + dataloader_tag: "val" + dataset: + instance_key: val_dataset + pass_type: BY_REFERENCE + batch_size: ${settings.step_profile.local_train_micro_batch_size} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: val_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: fsdp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +captioning_loss: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${model_raw.config.prediction_key} + +contrastive_loss_audio: + component_key: loss + variant_key: clip_loss + config: + prediction_keys: + - ${model_raw.config.audio_cls_prediction_key} + - ${model_raw.config.audio_text_cls_prediction_key} + logit_scale_key: 
${model_raw.config.logit_scale_prediction_key} + tag: contrastive_loss_audio + +contrastive_loss_image: + component_key: loss + variant_key: clip_loss + config: + prediction_keys: + - ${model_raw.config.image_cls_prediction_key} + - ${model_raw.config.image_text_cls_prediction_key} + logit_scale_key: ${model_raw.config.logit_scale_prediction_key} + tag: contrastive_loss_image + +contrastive_loss_video: + component_key: loss + variant_key: clip_loss + config: + prediction_keys: + - ${model_raw.config.video_cls_prediction_key} + - ${model_raw.config.video_text_cls_prediction_key} + logit_scale_key: ${model_raw.config.logit_scale_prediction_key} + tag: contrastive_loss_video + +loss_fn: + component_key: loss + variant_key: multiple_functions_loss + config: + losses: + - instance_key: captioning_loss + pass_type: BY_REFERENCE + - instance_key: contrastive_loss_audio + pass_type: BY_REFERENCE + - instance_key: contrastive_loss_image + pass_type: BY_REFERENCE + - instance_key: contrastive_loss_video + pass_type: BY_REFERENCE + corrsp_weights: + - 2.0 + - 1.0 + - 1.0 + - 1.0 + +wrapped_model: + component_key: model + variant_key: fsdp_wrapped + config: + model: + instance_key: model + pass_type: BY_REFERENCE + sync_module_states: true + mixed_precision_settings: FP_16 + sharding_strategy: HYBRID_SHARD + block_names: [TransformerBlock, VisionTransformerBlock, ConformerBlock] + +model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: coca + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: 4 # text_decoder_config.n_layer_text + text_decoder_config.n_layer_multimodal_text + +model_raw: + component_key: model + variant_key: coca + config: + prediction_key: ${settings.referencing_keys.prediction_key} + audio_embd_prediction_key: audio_embeddings + image_embd_prediction_key: image_embeddings + video_embd_prediction_key: video_embeddings + text_embd_prediction_key: text_embeddings + image_cls_prediction_key: image_cls + image_text_cls_prediction_key: image_text_cls + audio_cls_prediction_key: audio_cls + audio_text_cls_prediction_key: audio_text_cls + video_cls_prediction_key: video_cls + video_text_cls_prediction_key: video_text_cls + text_cls_prediction_key: text_cls + modality_keys: ${collate_fn.config.sample_keys} + is_audio_video: false + individual_datasets: true + logit_scale_prediction_key: logit_scale + audio_encoder_config: + sample_key: audio + prediction_key: audio_embeddings + block_size: 2_000 + n_mels: 128 + n_embd: 768 + n_heads: 8 + n_conformer_blocks: 2 + attention_config: + attention_engine_type: default_attention + pointwise_conv_kernel_size: 1 + depthwise_conv_kernel_size: 31 + image_encoder_config: + sample_key: images + prediction_key: image_embeddings + img_size: 256 # 288 in the original coca + n_classes: Null # Disable vision transformer head + n_layer: 2 + attention_config: + attention_engine_type: default_attention + n_head: 12 + n_embd: 768 + dropout: 0.0 + patch_size: 16 # 18 in the original coca + patch_stride: 16 # 18 in the original coca + n_img_channels: 3 + add_cls_token: False + bias: True + video_encoder_config: + sample_key: video + prediction_key: video_embeddings + img_size: 256 # 288 in the original coca + n_classes: Null # Disable vision transformer head + n_layer: 2 + attention_config: + attention_engine_type: default_attention + n_head: 12 + n_embd: 768 + dropout: 
0.0 + patch_size: 18 # 18 in the original coca + patch_stride: 18 # 18 in the original coca + n_img_channels: 3 + add_cls_token: False + bias: True + num_video_frames: 16 + n_latents: 64 + text_decoder_config: + sample_key: ${settings.referencing_keys.sample_key} + prediction_key: ${model_raw.config.prediction_key} + block_size: 512 + vocab_size: 50304 # 64k in the original coca + n_layer_text: 2 # update model_initializer num_layers + n_layer_multimodal_text: 2 # update model_initializer num_layers + attention_config: + attention_engine_type: default_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: swiglu + epsilon: 1e-5 + n_pool_head: 12 + n_queries: 256 + bias_attn_pool: False + epsilon_attn_pool: 1e-5 + +scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 8e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.02 + anneal_strategy: linear # COCA uses linear decay + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 8e-4 + betas: [0.9, 0.999] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, norm, parameter] + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp + config: + wrapped_model: + instance_key: wrapped_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + global_rank: ${settings.cuda_env.global_rank} + project: modalities_coca + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: wandb_storage + config_file_path: ${settings.config_file_path} diff --git a/config_files/training/config_example_coca.yaml b/config_files/training/config_example_coca.yaml deleted file mode 100644 index 570f9e5ae..000000000 --- a/config_files/training/config_example_coca.yaml +++ /dev/null @@ -1,317 +0,0 @@ -settings: - experiment_id: ${modalities_env:experiment_id} - config_file_path: ${modalities_env:config_file_path} - referencing_keys: - sample_key: input_ids - target_key: target_ids - prediction_key: logits - cuda_env: - local_rank: ${cuda_env:LOCAL_RANK} - global_rank: ${cuda_env:RANK} - world_size: ${cuda_env:WORLD_SIZE} - paths: - checkpoint_saving_path: data/checkpoints - train_dataset_path: ./data/lorem_ipsum.pbin - intervals: - training_log_interval_in_steps: 2 - checkpointing_interval_in_steps: 2 - evaluation_interval_in_steps: 2 - consistency_enforcement: - enforce_tokens_per_step_consistency: true - enforce_last_step_logged: false - enforce_last_step_evaluated: false - enforce_last_step_checkpointed: false - step_profile: - gradient_accumulation_steps: 1 - local_train_micro_batch_size: 1 - sequence_length: 256 - training_target: - num_target_tokens: - component_key: number_conversion - variant_key: num_tokens_from_num_steps - config: - num_steps: 
${settings.training_target.num_target_steps} - num_ranks: ${settings.cuda_env.world_size} - local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} - sequence_length: ${settings.step_profile.sequence_length} - gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} - num_target_steps: # for the batch progress subscriber - component_key: number_conversion - variant_key: num_steps_from_num_samples - config: - num_ranks: ${settings.cuda_env.world_size} - local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} - global_num_samples: ${settings.coca_example_settings.train_num_samples} - gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} - training_progress: - global_num_seen_tokens: 0 - num_seen_steps: 0 - local_num_seen_batches: 0 - last_step: -1 - coca_example_settings: - train_num_samples: 64 - val_num_samples: 32 - -collate_fn: - component_key: collate_fn - variant_key: coca_collator - config: - sample_keys: - - images - - ${settings.referencing_keys.sample_key} - target_keys: [] - text_sample_key: ${settings.referencing_keys.sample_key} - text_target_key: ${settings.referencing_keys.target_key} - -train_dataset: - component_key: dataset - variant_key: dummy_dataset - config: - num_samples: ${settings.coca_example_settings.train_num_samples} - sample_definition: - - sample_key: images - sample_shape: [3, 224, 224] - sample_type: float - - sample_key: input_ids - sample_shape: [1024] - sample_type: int - -val_dataset: - component_key: dataset - variant_key: dummy_dataset - config: - num_samples: ${settings.coca_example_settings.val_num_samples} - sample_definition: - - sample_key: images - sample_shape: [3, 224, 224] - sample_type: float - - sample_key: input_ids - sample_shape: [1024] - sample_type: int - -train_dataloader: - component_key: data_loader - variant_key: default - config: - num_workers: 2 - pin_memory: true - dataloader_tag: train - skip_num_batches: ${settings.training_progress.local_num_seen_batches} - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.step_profile.local_train_micro_batch_size} - drop_last: true - sampler: - component_key: sampler - variant_key: distributed_sampler - config: - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: true - drop_last: true - seed: 42 - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -val_dataloader: - component_key: data_loader - variant_key: default - config: - num_workers: 2 - pin_memory: true - dataloader_tag: val - dataset: - instance_key: val_dataset - pass_type: BY_REFERENCE - batch_sampler: - component_key: batch_sampler - variant_key: default - config: - batch_size: ${settings.step_profile.local_train_micro_batch_size} - drop_last: true - - sampler: - component_key: sampler - variant_key: distributed_sampler - config: - rank: ${settings.cuda_env.global_rank} - num_replicas: ${settings.cuda_env.world_size} - shuffle: false - drop_last: true - dataset: - instance_key: train_dataset - pass_type: BY_REFERENCE - collate_fn: - instance_key: collate_fn - pass_type: BY_REFERENCE - -eval_dataloaders: - - instance_key: val_dataloader - pass_type: BY_REFERENCE - -checkpoint_saving: - component_key: checkpoint_saving - variant_key: default - config: - checkpoint_saving_strategy: - 
component_key: checkpoint_saving_strategy - variant_key: save_k_most_recent_checkpoints_strategy - config: - k: -1 # -1 to save all checkpoints - checkpoint_saving_execution: - component_key: checkpoint_saving_execution - variant_key: fsdp - config: - checkpoint_path: ${settings.paths.checkpoint_saving_path} - global_rank: ${settings.cuda_env.global_rank} - experiment_id: ${settings.experiment_id} - -loss_fn: - component_key: loss - variant_key: clm_cross_entropy_loss - config: - target_key: ${settings.referencing_keys.target_key} - prediction_key: ${settings.referencing_keys.prediction_key} - -wrapped_model: - component_key: model - variant_key: fsdp_wrapped - config: - model: - instance_key: model - pass_type: BY_REFERENCE - sync_module_states: true - mixed_precision_settings: FP_16 - sharding_strategy: FULL_SHARD - block_names: [TransformerBlock, VisionTransformerBlock] - -model: - component_key: model - variant_key: model_initialized - config: - model: - instance_key: model_raw - pass_type: BY_REFERENCE - model_initializer: - component_key: model_initialization - variant_key: composed - config: - model_type: coca - weight_init_type: plain - mean: 0.0 - std: 0.02 - -model_raw: - component_key: model - variant_key: coca - config: - prediction_key: logits - vision_embd_prediction_key: vision_embeddings - text_embd_prediction_key: text_embeddings - vision_cls_prediction_key: vision_cls - text_cls_prediction_key: text_cls - vision_encoder_config: - sample_key: images - prediction_key: vision_embeddings - img_size: 224 - n_classes: Null # Disable vision transformer head - n_layer: 12 - attention_config: - attention_engine_type: default_attention - n_head: 12 - n_embd: 768 - dropout: 0.0 - patch_size: 16 - patch_stride: 16 - n_img_channels: 3 - add_cls_token: False - bias: True - text_decoder_config: - sample_key: ${settings.referencing_keys.sample_key} - prediction_key: ${loss_fn.config.prediction_key} - block_size: 1024 - vocab_size: 50304 - n_layer_text: 12 - n_layer_multimodal_text: 12 - attention_config: - attention_engine_type: default_attention - n_head: 12 - ffn_hidden: 2048 - n_embd: 768 - dropout: 0.0 - bias: true - activation: swiglu - epsilon: 1e-5 - n_pool_head: 8 - n_vision_queries: 256 - bias_attn_pool: False - epsilon_attn_pool: 1e-5 - -scheduler: - component_key: scheduler - variant_key: onecycle_lr - config: - optimizer: - instance_key: optimizer - pass_type: BY_REFERENCE - max_lr: 6e-4 - div_factor: 10 - final_div_factor: 1 - total_steps: ${settings.training_target.num_target_steps} - pct_start: 0.01 - anneal_strategy: cos - last_epoch: ${settings.training_progress.last_step} - -optimizer: - component_key: optimizer - variant_key: adam_w - config: - lr: 0.0001 - betas: [0.9, 0.95] - eps: 1e-8 - weight_decay: 1e-1 - weight_decay_groups_excluded: [] - wrapped_model: - instance_key: wrapped_model - pass_type: BY_REFERENCE - -gradient_clipper: - component_key: gradient_clipper - variant_key: fsdp_logging_only - config: - wrapped_model: - instance_key: wrapped_model - pass_type: BY_REFERENCE - norm_type: P2_NORM - -progress_subscriber: - component_key: progress_subscriber - variant_key: rich - config: - global_rank: ${settings.cuda_env.global_rank} - num_seen_steps: ${settings.training_progress.num_seen_steps} - num_target_steps: ${settings.training_target.num_target_steps} - train_dataloader_tag: ${train_dataloader.config.dataloader_tag} - eval_dataloaders: - instance_key: eval_dataloaders - pass_type: BY_REFERENCE - -evaluation_subscriber: - component_key: results_subscriber - 
variant_key: wandb - config: - global_rank: ${settings.cuda_env.global_rank} - project: modalities - mode: OFFLINE - experiment_id: ${settings.experiment_id} - directory: wandb_storage - config_file_path: ${settings.config_file_path} \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index b81a30f5e..79b3884ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ requires-python = ">=3.10,<3.12" description = "Modalities, a PyTorch-native framework for distributed and reproducible foundation model training." readme = "README.md" dependencies = [ + "torchaudio", "numpy<2.0", "torch~=2.4.1", "packaging", @@ -12,6 +13,7 @@ dependencies = [ "pyyaml", "transformers", "datasets", + "decord", "protobuf", "SentencePiece", "rich", @@ -23,7 +25,10 @@ dependencies = [ "class_resolver", "wandb", "einops>=0.7.0", - "flash-attn", # install this directly via `pip install flash-attn --no-build-isolation` + "webdataset>=0.2.86", + "timm>=0.9.16", + "pyav", + "flash-attn", # install this directly via `pip install flash-attn --no-build-isolation` ] [project.urls] @@ -84,4 +89,4 @@ exclude_also = [ ignore_errors = true [tool.coverage.html] -directory = "coverage_html_report" \ No newline at end of file +directory = "coverage_html_report" diff --git a/src/modalities/batch.py b/src/modalities/batch.py index 19aa673e2..cd32ed807 100644 --- a/src/modalities/batch.py +++ b/src/modalities/batch.py @@ -50,8 +50,8 @@ def device(self) -> torch.device: return self.samples[key].device def __len__(self) -> int: - key = list(self.samples.keys())[0] - return self.samples[key].shape[self.batch_dim] + lengths = [self.samples[key].shape[self.batch_dim] for key in self.samples.keys()] + return max(lengths) @dataclass @@ -89,8 +89,8 @@ def get_targets(self, key: str) -> torch.Tensor: return self.targets[key] def __len__(self) -> int: - key = list(self.predictions.keys())[0] - return self.predictions[key].shape[self.batch_dim] + lengths = [self.predictions[key].shape[self.batch_dim] for key in self.predictions.keys()] + return max(lengths) @dataclass diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 80cce3b98..0a6bf9eea 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -16,9 +16,9 @@ PydanticCheckpointSavingExecutionIFType, PydanticCheckpointSavingStrategyIFType, PydanticCollateFnIFType, + PydanticDataLoaderIFType, PydanticDatasetIFType, PydanticFSDPModuleType, - PydanticLLMDataLoaderIFType, PydanticModelInitializationIFType, PydanticOptimizerIFType, PydanticPytorchDeviceType, @@ -67,6 +67,26 @@ class CLMCrossEntropyLossConfig(BaseModel): prediction_key: str +class NCELossConfig(BaseModel): + prediction_key1: str + prediction_key2: str + is_asymmetric: bool = True + temperature: float = 1.0 + tag: str = "NCELoss" + + +class ClipLossConfig(BaseModel): + logit_scale_key: str + prediction_keys: list[str] + local_loss: bool = True + tag: str = "ClipLoss" + + +class MultipleFunctionsLossConfig(BaseModel): + losses: list + corrsp_weights: list + + # Checkpointing class SaveEveryKStepsCheckpointingStrategyConfig(BaseModel): k: PositiveInt @@ -312,8 +332,18 @@ class LLMDataLoaderConfig(BaseModel): fixed_num_batches: Optional[int] = None +class WebDataLoaderConfig(BaseModel): + dataloader_tag: str + dataset: PydanticDatasetIFType + batch_size: int + collate_fn: PydanticCollateFnIFType + num_workers: Annotated[int, Field(strict=True, ge=0)] + pin_memory: bool + drop_last: bool + + class RepeatingDataLoaderConfig(BaseModel): - 
dataloader: PydanticLLMDataLoaderIFType + dataloader: PydanticDataLoaderIFType reshuffle_after_epoch: Optional[bool] = False num_epochs: Annotated[int, Field(strict=True, ge=1)] @@ -322,8 +352,16 @@ class DummyProgressSubscriberConfig(BaseModel): pass +class SimpleProgressSubscriberConfig(BaseModel): + eval_dataloaders: Optional[list[PydanticDataLoaderIFType]] = Field(default_factory=list) + train_dataloader_tag: str + num_seen_steps: Annotated[int, Field(strict=True, ge=0)] + num_target_steps: Annotated[int, Field(strict=True, gt=0)] + global_rank: Annotated[int, Field(strict=True, ge=0)] + + class RichProgressSubscriberConfig(BaseModel): - eval_dataloaders: Optional[list[PydanticLLMDataLoaderIFType]] = Field(default_factory=list) + eval_dataloaders: Optional[list[PydanticDataLoaderIFType]] = Field(default_factory=list) train_dataloader_tag: str num_seen_steps: Annotated[int, Field(strict=True, ge=0)] num_target_steps: Annotated[int, Field(strict=True, gt=0)] diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index bd203c9ca..6a8b7c411 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -6,9 +6,9 @@ from modalities.config.pydanctic_if_types import ( PydanticCheckpointSavingIFType, + PydanticDataLoaderIFType, PydanticDatasetIFType, PydanticGradientClipperIFType, - PydanticLLMDataLoaderIFType, PydanticLossIFType, PydanticLRSchedulerIFType, PydanticMessageSubscriberIFType, @@ -170,8 +170,8 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel scheduler: PydanticLRSchedulerIFType loss_fn: PydanticLossIFType train_dataset: PydanticDatasetIFType - train_dataloader: PydanticLLMDataLoaderIFType - eval_dataloaders: list[PydanticLLMDataLoaderIFType] + train_dataloader: PydanticDataLoaderIFType + eval_dataloaders: list[PydanticDataLoaderIFType] progress_subscriber: PydanticMessageSubscriberIFType evaluation_subscriber: PydanticMessageSubscriberIFType checkpoint_saving: PydanticCheckpointSavingIFType diff --git a/src/modalities/config/pydanctic_if_types.py b/src/modalities/config/pydanctic_if_types.py index 3761eb8df..25d879737 100644 --- a/src/modalities/config/pydanctic_if_types.py +++ b/src/modalities/config/pydanctic_if_types.py @@ -13,7 +13,7 @@ from modalities.checkpointing.checkpoint_loading import CheckpointLoadingIF from modalities.checkpointing.checkpoint_saving import CheckpointSaving, CheckpointSavingExecutionABC from modalities.checkpointing.checkpoint_saving_strategies import CheckpointSavingStrategyIF -from modalities.dataloader.dataloader import LLMDataLoader +from modalities.dataloader.dataloader import DataLoaderIF from modalities.inference.text.inference_component import TextInferenceComponent from modalities.logging_broker.subscriber import MessageSubscriberIF from modalities.loss_functions import Loss @@ -56,7 +56,7 @@ def __get_pydantic_core_schema__( PydanticDatasetIFType = Annotated[Dataset, PydanticThirdPartyTypeIF(Dataset)] PydanticSamplerIFType = Annotated[Sampler, PydanticThirdPartyTypeIF(Sampler)] PydanticCollateFnIFType = Annotated[CollateFnIF, PydanticThirdPartyTypeIF(CollateFnIF)] -PydanticLLMDataLoaderIFType = Annotated[LLMDataLoader, PydanticThirdPartyTypeIF(LLMDataLoader)] +PydanticDataLoaderIFType = Annotated[DataLoaderIF, PydanticThirdPartyTypeIF(DataLoaderIF)] PydanticOptimizerIFType = Annotated[Optimizer, PydanticThirdPartyTypeIF(Optimizer)] PydanticLRSchedulerIFType = Annotated[LRScheduler, 
PydanticThirdPartyTypeIF(LRScheduler)] PydanticLossIFType = Annotated[Loss, PydanticThirdPartyTypeIF(Loss)] diff --git a/src/modalities/dataloader/dataloader.py b/src/modalities/dataloader/dataloader.py index f06139627..48d4b1d6a 100644 --- a/src/modalities/dataloader/dataloader.py +++ b/src/modalities/dataloader/dataloader.py @@ -1,5 +1,7 @@ +import multiprocessing from typing import Iterable, Optional +import webdataset as wd from torch.utils.data import Dataset, DistributedSampler, Sampler from torch.utils.data.dataloader import DataLoader, _collate_fn_t, _worker_init_fn_t @@ -11,7 +13,11 @@ from modalities.dataloader.samplers import ResumableBatchSampler -class LLMDataLoader(DataLoader[T_co]): +class DataLoaderIF: + pass + + +class LLMDataLoader(DataLoaderIF): """LLMDataLoader is a custom DataLoader class that extends the PyTorch DataLoader class.""" def __init__( @@ -62,7 +68,9 @@ def __init__( None """ assert batch_sampler is not None and batch_size == 1 - super().__init__( + self._dataloader_tag = dataloader_tag + self._batch_size = batch_sampler.batch_size + self._torch_dataloader = DataLoader( dataset=dataset, batch_size=batch_size, shuffle=False, # shuffling must be implemented on a dataset level @@ -81,9 +89,6 @@ def __init__( pin_memory_device=pin_memory_device, ) - self._dataloader_tag = dataloader_tag - self._batch_size = batch_sampler.batch_size - @property def dataloader_tag(self) -> str: """ @@ -125,6 +130,47 @@ def batch_size(self, value: int): """ self._batch_size = value + def __len__(self): + return self._torch_dataloader.__len__() + + def __iter__(self): + return self._torch_dataloader.__iter__() + + @property + def dataset(self) -> Dataset[T_co]: + return self._torch_dataloader.dataset + + @property + def batch_sampler(self) -> ResumableBatchSampler: + return self._torch_dataloader.batch_sampler + + @property + def sampler(self) -> Sampler | Iterable | None: + return self._torch_dataloader.sampler + + @property + def collate_fn(self) -> _collate_fn_t: + return self._torch_dataloader.collate_fn + + @property + def multiprocessing_context(self) -> str | multiprocessing.context.BaseContext: + return self._torch_dataloader.multiprocessing_context + + @multiprocessing_context.setter + def multiprocessing_context(self, multiprocessing_context): + self._torch_dataloader.multiprocessing_context = multiprocessing_context + + @property + def _auto_collation(self): + return self._torch_dataloader._auto_collation + + @property + def _index_sampler(self): + return self._torch_dataloader._index_sampler + + def check_worker_number_rationality(self): + return self._torch_dataloader.check_worker_number_rationality() + @property def fast_forward_batch_id(self) -> int: """ @@ -133,15 +179,15 @@ def fast_forward_batch_id(self) -> int: Returns: int: fast forward batch ID """ - return self.batch_sampler.start_index + return self._torch_dataloader.batch_sampler.start_index -class RepeatingDataLoader(LLMDataLoader[T_co]): +class RepeatingDataLoader(LLMDataLoader): """ RepeatingDataLoader is a custom DataLoader class that repeats the given dataloader for the specified number of epochs.""" - def __init__(self, dataloader: LLMDataLoader[T_co], num_epochs: int, reshuffle_after_epoch: bool = False): + def __init__(self, dataloader: LLMDataLoader, num_epochs: int, reshuffle_after_epoch: bool = False): """ Initializes a RepeatingDataLoader object that repeats the given dataloader for the specified number of epochs. 
This is especially useful for DataLoader types that we wish to automatically restart upon completion. @@ -245,3 +291,68 @@ def __len__(self) -> int: int: The total number of steps. """ return self.num_epochs * len(self.dataloader) + + +class WebDataLoader(DataLoaderIF): + """WebDataLoader is a custom DataLoader class that wraps the webdataset.WebLoader class.""" + + def __init__( + self, + dataloader_tag: str, + dataset: Dataset[T_co], + batch_size: Optional[int] = 1, + num_workers: int = 0, + collate_fn: Optional[_collate_fn_t] = None, + pin_memory: bool = False, + drop_last: bool = False, + ): + """Initializes WebDataLoader, which is a wrapper for webdataset.WebLoader. + + Args: + dataloader_tag (str): The tag for the dataloader. + dataset (Dataset[T_co]): The dataset to load the data from. + batch_size (Optional[int], optional): The batch size. Defaults to 1. + num_workers (int, optional): The number of worker processes to use for data loading. Defaults to 0. + collate_fn (Optional[_collate_fn_t], optional): The function used to collate the data samples. + Defaults to None. + pin_memory (bool, optional): Flag indicating whether to pin the memory. Defaults to False. + drop_last (bool, optional): Flag indicating whether to drop the last incomplete batch. Defaults to False. + """ + self.num_batches = len(dataset) // batch_size + int(not drop_last) + dataset = dataset.batched(batch_size, collation_fn=collate_fn) + self.webloader = wd.WebLoader(dataset=dataset, batch_size=None, num_workers=num_workers, pin_memory=pin_memory) + self.webloader = self.webloader.with_epoch(self.num_batches) + self.dataloader_tag = dataloader_tag + self.batch_size = batch_size + + def __len__(self): + return self.num_batches + + def __iter__(self): + return iter(self.webloader) + + @property + def batch_size(self) -> int: + return self._batch_size + + @batch_size.setter + def batch_size(self, value: int): + self._batch_size = value + + @property + def fast_forward_sample_id(self) -> int: + """The sample id until which we fast-forward, as specified in the ResumableBatchSampler. + + Returns: + int: fast forward sample id + """ + return 0 # self.batch_size * self.batch_sampler.start_index + + @property + def fast_forward_batch_id(self) -> int: + """The batch id until which we fast-forward, as specified in the ResumableBatchSampler. 
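+ Note: fast-forwarding is not yet wired up for web datasets, so this currently always returns 0.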
+ + Returns: + int: fast forward batch id + """ + return 0 # self.batch_sampler.start_index diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py index e327f11ea..fe21ab679 100644 --- a/src/modalities/dataloader/dataloader_factory.py +++ b/src/modalities/dataloader/dataloader_factory.py @@ -3,7 +3,8 @@ from torch.utils.data import BatchSampler from torch.utils.data.dataset import Dataset -from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader +from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader, WebDataLoader +from modalities.dataloader.dataset import MultimodalWebDataset from modalities.dataloader.samplers import ResumableBatchSampler from modalities.exceptions import ConfigError @@ -89,3 +90,39 @@ def get_repeating_dataloader( """ dataloader = RepeatingDataLoader(dataloader, num_epochs, reshuffle_after_epoch) return dataloader + + @staticmethod + def get_web_dataloader( + dataloader_tag: str, + dataset: MultimodalWebDataset, + batch_size: int, + collate_fn: Callable, + num_workers: int, + pin_memory: bool, + drop_last: bool, + ) -> WebDataLoader: + """ + Returns a WebDataLoader object for a MultimodalWebDataset + + Args: + dataloader_tag (str): Tag for the dataloader + dataset (Dataset): The MultimodalWebDataset to be used + batch_size (int): batch size per device + collate_fn (Callable): Callable for shaping the batch + num_workers (int): Number of workers for the dataloader + pin_memory (bool): Flag indicating whether to pin memory + drop_last (bool): Flag indicating whether to drop the last non-full batch + + Returns: + WebDataLoader: A WebDataLoader object + """ + dataloader = WebDataLoader( + dataloader_tag=dataloader_tag, + dataset=dataset, + batch_size=batch_size, + collate_fn=collate_fn, + num_workers=num_workers, + pin_memory=pin_memory, + drop_last=drop_last, + ) + return dataloader diff --git a/src/modalities/dataloader/dataset.py b/src/modalities/dataloader/dataset.py index 67c3585a1..f2ef94d93 100644 --- a/src/modalities/dataloader/dataset.py +++ b/src/modalities/dataloader/dataset.py @@ -1,20 +1,37 @@ from __future__ import annotations +import io +import random +import re from enum import Enum from pathlib import Path -from typing import Optional +from typing import Annotated, Any, Optional +import decord import jq import numpy as np -from pydantic import BaseModel +import PIL +import torch +import torchaudio +import webdataset as wds +from pydantic import BaseModel, Field +from timm.data import create_transform +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from torch.utils.data import IterableDataset from torch.utils.data.dataset import Dataset as TorchdataSet +from torchvision.transforms import v2 as transforms from tqdm import tqdm from transformers import BatchEncoding +from modalities.config.config import PydanticTokenizerIFType +from modalities.config.lookup_enum import LookupEnum +from modalities.config.pydanctic_if_types import PydanticThirdPartyTypeIF +from modalities.dataloader.create_packed_data import EmbeddedStreamData +from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper +from modalities.util import flatten_dict -from ..dataloader.large_file_lines_reader import LargeFileLinesReader -from .create_packed_data import EmbeddedStreamData +decord.bridge.set_bridge("torch") class Dataset(TorchdataSet): @@ -48,6 +65,7 @@ class 
DummySampleDataType(str, Enum): FLOAT = "float" INT = "int" + CONSTANT = "const" class DummySampleConfig(BaseModel): @@ -99,6 +117,8 @@ def __init__(self, num_samples: int, sample_definition: tuple[DummySampleConfig] self.num_samples = num_samples self.sample_definition = sample_definition + self.VISION = 1 + def __len__(self) -> int: """ Returns the length of the dataset. @@ -131,6 +151,8 @@ def _create_random_sample(self) -> dict: data = np.random.randn(*s.sample_shape) elif s.sample_type == DummySampleDataType.INT: data = np.random.randint(low=0, high=512, size=s.sample_shape) + elif s.sample_type == DummySampleDataType.CONSTANT: + data = self.VISION else: raise NotImplementedError(f"DummyDataset does not support type { s.sample_type}") sample[s.sample_key] = data @@ -368,3 +390,773 @@ def _generate_packing_index(self) -> list[tuple[int, int]]: curr_offset = segment_offset curr_len = segment_len return index + + +class ModalityEnum(LookupEnum): + TEXT = "text" + IMAGE = "image" + VIDEO = "video" + AUDIO = "audio" + + +class TransformConfig(BaseModel): + pass + + +class Transform: + pass + + +PydanticTransformIFType = Annotated[Transform, PydanticThirdPartyTypeIF(Transform)] + + +class ImageTransformConfig(TransformConfig): + input_size: int | tuple[int, int] | tuple[int, int, int] = 224 + is_training: bool = False + no_aug: bool = False + train_crop_mode: Optional[str] = None + scale: Optional[tuple[float, float]] = None + ratio: Optional[tuple[float, float]] = None + hflip: float = 0.5 + vflip: float = 0.0 + color_jitter: float | tuple[float, ...] = 0.4 + color_jitter_prob: Optional[float] = None + grayscale_prob: float = 0.0 + gaussian_blur_prob: float = 0.0 + auto_augment: Optional[str] = None + interpolation: str = "bilinear" + mean: tuple[float, ...] = IMAGENET_DEFAULT_MEAN + std: tuple[float, ...] = IMAGENET_DEFAULT_STD + re_prob: float = 0.0 + re_mode: str = "const" + re_count: int = 1 + re_num_splits: int = 0 + crop_pct: Optional[float] = None + crop_mode: Optional[str] = None + crop_border_pixels: Optional[int] = None + tf_preprocessing: bool = False + use_prefetcher: bool = False + separate: bool = False + + +class ImageTransform(Transform): + """ImageTransform class.""" + + def __init__(self, **kwargs) -> None: + """ + Initializes a Transform object for image transformations. + + The following argument descriptions are duplicated from: + https://github.com/huggingface/pytorch-image-models/blob/main/timm/data/transforms_factory.py + + Args: + input_size (int, tuple[int,int], tuple[int, int, int]: + Target input size (channels, height, width) tuple or size scalar. + is_training (bool): Return training (random) transforms. + no_aug (bool): Disable augmentation for training (useful for debug). + train_crop_mode (Optional[str]): Training random crop mode ('rrc', 'rkrc', 'rkrr'). + scale (Optional[tuple[float, float]]) : Random resize scale range (crop area, < 1.0 => zoom in). + ratio (Optional[tuple[float, float]]): Random aspect ratio range + (crop ratio for RRC, ratio adjustment factor for RKR). + hflip (float): Horizontal flip probability. + vflip (float): Vertical flip probability. + color_jitter (float | tuple[float, ...]): Random color jitter component factors + (brightness, contrast, saturation, hue). + Scalar is applied as (scalar,) * 3 (no hue). + color_jitter_prob (Optional[float]): Apply color jitter with this + probability if not None (for SimlCLR-like aug). + grayscale_prob (float): Probability of converting image to grayscale (for SimCLR-like aug). 
+ gaussian_blur_prob (float): Probability of applying gaussian blur (for SimCLR-like aug). + auto_augment (Optional[str]): Auto augment configuration string (see auto_augment.py). + interpolation (str): Image interpolation mode. + mean (tuple[float, ...]): Image normalization mean. + std (tuple[float, ...]): Image normalization standard deviation. + re_prob (float): Random erasing probability. + re_mode (str): Random erasing fill mode. + re_count (int): Number of random erasing regions. + re_num_splits (int): Control split of random erasing across batch size. + crop_pct (Optional[float]): Inference crop percentage (output size / resize size). + crop_mode (Optional[str]): Inference crop mode. + One of ['squash', 'border', 'center']. Defaults to 'center' when None. + crop_border_pixels (Optional[int]): Inference crop border of + specified # pixels around edge of original image. + tf_preprocessing (bool): Use TF 1.0 inference preprocessing for testing model ports + use_prefetcher (bool): Pre-fetcher enabled. Do not convert image to tensor or normalize. + """ + + self._timm_image_transform = create_transform(**kwargs) + + def __call__(self, image: PIL.Image.Image) -> torch.Tensor: + return self._timm_image_transform(image) + + +class TextTransformConfig(TransformConfig): + tokenizer: PydanticTokenizerIFType + max_length: int = 77 + padding: str = "max_length" + truncation: bool = True + return_attention_mask: bool = True + + +class TextTransform(Transform): + def __init__( + self, + tokenizer: TokenizerWrapper, + max_length: int = 77, + padding: str = "max_length", + truncation: bool = True, + return_attention_mask: bool = True, + ) -> None: + """ + Args: + tokenizer (TokenizerWrapper): text tokenizer + max_length (int): maximum number of tokens. Default 77 + padding (str): padding strategy. Default "max_length" + truncation (bool): Flag which determines whether to apply truncation. Default True. + return_attention_mask (bool): Flag which determines whether the attention mask is returned. Default True. + + Returns: + None + """ + self.tokenizer = tokenizer + self.max_length = max_length + self.padding = padding + self.truncation = truncation + self.return_attention_mask = return_attention_mask + + def __call__(self, text: str) -> BatchEncoding: + batch_encoding: BatchEncoding = self.tokenizer.tokenizer( + text, + max_length=self.max_length, + padding=self.padding, + truncation=self.truncation, + return_attention_mask=self.return_attention_mask, + ) + return batch_encoding + + +class AudioTransformConfig(TransformConfig): + """ + Configuration class for the audio transformation module. + + This class defines various parameters that control the behavior of the AudioTransform. + These parameters include whether the module is in training mode, the number of mel-frequency bands, + lengths for frequency and time domain masking during training, and the target block size for audio encoding. + + Attributes: + is_training (bool): Whether the module is in training mode. Defaults to False. + n_mels (int): Number of mel-frequency bands. Defaults to 128. + freq_domain_mask_length (int): Length of frequency masking during training. Defaults to 30. + time_domain_mask_length (int): Length of time masking during training. Defaults to 100. + block_size_audio_encoder (int): Maximum allowed input length to the audio encoder. 
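+ Note: AudioTransform zero-pads the extracted mel-spectrogram along the time axis up to this block size (times the sub-sampling factor of 4) and rejects longer inputs with an assertion.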
+ """ + + is_training: bool = False + n_mels: int = 128 + freq_domain_mask_length: int = 30 + time_domain_mask_length: int = 100 + block_size_audio_encoder: int + + +class AudioTransform(Transform): + """ + An audio transformation module that processes raw audio into mel-spectrogram features. + + This module includes steps such as feature extraction, frequency and time domain masking during training, + padding to match a fixed block size, and returns the processed features along with their length. + """ + + def __init__( + self, + block_size_audio_encoder: int, + is_training: bool = False, + n_mels: int = 128, + freq_domain_mask_length: int = 30, + time_domain_mask_length: int = 100, + ) -> None: + """ + Initializes the AudioTransform class. + + Args: + block_size_audio_encoder (int): Maximum allowed input length to the audio encoder. + is_training (bool, optional): Whether the module is in training mode. Defaults to False. + n_mels (int, optional): Number of mel-frequency bands. Defaults to 128. + freq_domain_mask_length (int, optional): Length of frequency masking. Defaults to 30. + time_domain_mask_length (int, optional): Length of time masking. Defaults to 100. + + Returns: + tuple[torch.Tensor, int]: A tuple containing the processed audio features and their length. + """ + self.block_size_audio_encoder = block_size_audio_encoder + self.is_training = is_training + self.n_mels = n_mels + self.freq_domain_mask_length = freq_domain_mask_length + self.time_domain_mask_length = time_domain_mask_length + + def __call__(self, raw_audio: tuple[torch.Tensor, int]) -> tuple[torch.Tensor, int]: + """ + Processes the input raw audio into mel-spectrogram features. + + Args: + raw_audio (tuple[torch.Tensor, int]): A tuple containing the raw audio tensor and its sample rate. + + Returns: + tuple[torch.Tensor, int]: A tuple containing the processed audio features and their length. + """ + + SUB_SAMPLING_FACTOR = 4 # reduce the number of features (i.e., time frames) + + self.extract_features = torchaudio.transforms.MelSpectrogram(n_mels=self.n_mels) + + if self.is_training: + self.masking = torch.nn.Sequential( + torchaudio.transforms.FrequencyMasking(freq_mask_param=self.freq_domain_mask_length), + torchaudio.transforms.TimeMasking(time_mask_param=self.time_domain_mask_length), + ) + + log_mel_spec = torch.clamp(self.extract_features(raw_audio[0]), 1e-10).log10().squeeze(0) + log_mel_spec = self.masking(log_mel_spec) if self.is_training else log_mel_spec + feats_len = log_mel_spec.shape[-1] // SUB_SAMPLING_FACTOR + + assert feats_len * SUB_SAMPLING_FACTOR <= SUB_SAMPLING_FACTOR * self.block_size_audio_encoder + log_mel_spec = torch.nn.functional.pad( + log_mel_spec, (0, SUB_SAMPLING_FACTOR * self.block_size_audio_encoder - log_mel_spec.shape[-1]) + ).transpose(0, 1) + return log_mel_spec, feats_len + + +class TemporalCrop: + """ + This module crops a video along the temporal dimension + """ + + def __init__( + self, + num_frames: int, + is_training: bool = False, + ) -> None: + """ + Initializes the TemporalCrop class + + Args: + num_frames (int): The length of the clip to be cropped + is_training (bool, optional): Whether the module is in training mode. Defaults to False. + + Returns: + None + """ + self.num_frames = num_frames + self.is_training = is_training + + def __call__(self, video: torch.Tensor) -> torch.Tensor: + """ + Crops the video to a length of `num_frames`. If in training mode, the start of the crop is chosen randomly. 
+ + Args: + video (torch.Tensor): the video to be cropped with dimensions T x H x W x C + + Returns: + cropped video (torch.Tensor): the cropped video with dimensions num_frames x C x H x W + """ + total_frames = len(video) + if self.is_training: + start = random.randint(0, total_frames - self.num_frames) + else: + start = 0 + return video[start : start + self.num_frames].permute(0, 3, 1, 2) # F C H W + + +class VideoTransformConfig(TransformConfig): + input_size: int | tuple[int, int] | tuple[int, int, int] = 224 + is_training: bool = False + num_frames: int = 16 + hflip: float = 0.0 + color_jitter: list[float] = [0.0, 0.0, 0.0, 0.0] + mean: list[float] = IMAGENET_DEFAULT_MEAN + std: list[float] = IMAGENET_DEFAULT_STD + + +class VideoTransform(Transform): + """ + A video transformation module that performs spatial and temporal transformations. + """ + + def __init__( + self, + input_size: int | tuple[int, int] | tuple[int, int, int] = 224, + is_training: bool = False, + num_frames: int = 16, + hflip: float = 0.0, + color_jitter: list[float] = [0.0, 0.0, 0.0, 0.0], + mean: list[float] = IMAGENET_DEFAULT_MEAN, + std: list[float] = IMAGENET_DEFAULT_STD, + ) -> None: + """ + Initializes the VideoTransform class + + Args: + input_size (int | tuple[int, int] | tuple[int, int, int] ): target spatial size of video frames. + is_training (bool, optional): Whether the module is in training mode. Defaults to False. + When not in training mode, resize and center crop is used instead of RandomResizedCrop, + no horizontal flip nor color jitter is performed, and the temporal crop is deterministic. + num_frames (int): target number of frames in the transformed video. Defaults to 16. + hflip (float): probability of performing horizontal flip on the frames. Defaults to 0.0. + color_jitter (list[float]): Random color jitter component factors + (brightness, contrast, saturation, hue). + Defaults to 0.0 for all components. + mean (list[float]): Image normalization mean. Defaults to IMAGENET defaults. + std (list[float]): Image normalization standard deviation. Defaults to IMAGENET defaults. + + + Returns: + None + """ + if is_training: + self.spatial_transform = transforms.Compose( + [ + transforms.RandomResizedCrop(input_size, antialias=True), + transforms.RandomHorizontalFlip(p=hflip), + transforms.ColorJitter( + brightness=color_jitter[0], + contrast=color_jitter[1], + saturation=color_jitter[2], + hue=color_jitter[3], + ), + transforms.ConvertImageDtype(torch.float), + transforms.Normalize(mean=mean, std=std), + ] + ) + else: + self.spatial_transform = transforms.Compose( + [ + transforms.Resize(input_size, antialias=True), + transforms.CenterCrop(input_size), + transforms.ConvertImageDtype(torch.float), + transforms.Normalize(mean=mean, std=std), + ] + ) + self.temporal_transform = TemporalCrop(num_frames=num_frames, is_training=is_training) + + def __call__(self, video: tuple[torch.Tensor, torch.Tensor | None, int]) -> torch.Tensor: + """ + Performs spatial and temporal transformations on the input video + + Args: + video (tuple[torch.Tensor, torch.Tensor, ]): the first element is the video + to be transformed T x H' x W' x C. + The second and third elements are ignored (optional audio, audio sample rate). 
+ + Returns: + transformed video (torch.Tensor): with dimensions num_frames x C x H x W + """ + video = video[0] + video = self.temporal_transform(video) + return self.spatial_transform(video) + + +def decord_video(key: str, data: bytes) -> None | tuple[torch.Tensor, Optional[torch.Tensor], int]: + """ + Based on the torch_video decoder in webdataset + https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py#L394 + + Decode a video file using Decord and optionally extract audio. + + This function decodes a video file from the provided data. + It first checks if the file extension is one of the supported formats. + If an audio stream exists, it extracts the audio with a mean across channels (if there are multiple). + It then uses Decord to decode uniformly sampled frames from the video. + + Args: + key (str): The key or identifier for the video data. + data (bytes): The binary data of the video file. + + Returns: + tuple: A tuple containing the decoded video frames, audio tensor (if available), and audio sample rate. + If no audio stream exists, the audio tensor will be None. + """ + extension = re.sub(r".*[.]", "", key) + if extension not in "mp4 ogv mjpeg avi mov h264 mpg webm wmv".split(): + return None + + audio = None + audio_sample_rate = -1 + stream = torchaudio.io.StreamReader(data) + for idx in range(stream.num_src_streams): + if stream.get_src_stream_info(idx).media_type == "audio": + audio, audio_sample_rate = torchaudio.load(data) + if audio.shape[0] > 1: # more than one audio channel + audio = torch.mean(audio, dim=0, keepdim=True) + break + + file_obj = io.BytesIO(data) + vr = decord.VideoReader(file_obj) + clip_num_frames = 64 + # sample clip_num_frames uniformly from the full video + frame_ids = torch.linspace(0, len(vr) - 1, clip_num_frames, dtype=torch.int64) + frames = vr.get_batch(frame_ids.tolist()) # T x H x W x C + + return (frames, audio, audio_sample_rate) + + +def torch_audio(key: str, data: bytes) -> None | tuple[torch.Tensor, int]: + """ + Based on the torch_audio decoder in webdataset + https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py#L418 + + Decode an audio file using torchaudio. + + This function decodes an audio file from the provided data. + It first checks if the file extension is one of the supported formats. + If there are multiple channels in the audio file, it averages them to produce a mono audio tensor. + + Args: + key (str): The key or identifier for the audio data. + data (bytes): The binary data of the audio file. + + Returns: + tuple: A tuple containing the decoded audio tensor and its sample rate. If the file extension is not supported, + the function will return None. + """ + + extension = re.sub(r".*[.]", "", key) + valid_extensions = "mp4 ogv mjpeg avi mov h264 mpg webm wmv flac mp3 sox wav m4a ogg wma".split() + if extension not in valid_extensions: + return None + + audio, sample_rate = torchaudio.load(data) + if audio.shape[0] > 1: # more than one channel + audio = torch.mean(audio, dim=0, keepdim=True) + return (audio, sample_rate) + + +def fixed_ratio_round_robin(*sources, samples_per_batch: list[int]): + """ + Iterator over a list of iterators. 
+ Samples from each source iterator are selected in a round-robin fashion, with a fixed number + of samples from each iterator for a given batch, as defined by `samples_per_batch` + + + Args: + sources (list[iterator]): An arbitrary number of source iterators + samples_per_batch (list[int]): Number of samples from each source iterator + which should be present in one batch + + Yields: + sample: a sample from one of the iterators + """ + + sources = list(sources) + remaining_samples_in_batch = samples_per_batch.copy() + i = 0 + while len(sources) > 0: + try: + sample = next(sources[i]) + remaining_samples_in_batch[i] -= 1 + + # reset + if sum(remaining_samples_in_batch) == 0: + remaining_samples_in_batch = samples_per_batch.copy() + + # go to next source which has some remaining samples + i = (i + 1) % len(sources) + while remaining_samples_in_batch[i] == 0: + i = (i + 1) % len(sources) + yield sample + except StopIteration: + # stop if any modality runs out of samples + break + + +class FixedRatioRoundRobinMix(IterableDataset): + def __init__( + self, + datasets: list[wds.WebDataset], + mixing_ratios: list[float], + batch_size: int, + ) -> None: + """An iterator for a list of datasets. + Samples are yielded in a round robin manner + with a fixed ratio of samples per dataset. There is no random sampling, so the number of + samples per modality is guaranteed to be fixed per batch. + + Args: + datasets (list[WebDataset]): a list of WebDatasets to be iterated over + mixing_ratios (list[float]): the ratio of samples from each dataset that should be present in a batch + batch_size (int): size of batch containing samples from all datasets in the specified ratio + + Returns: + None + """ + self.datasets = datasets + self.samples_per_batch = [int(batch_size * ratio) for ratio in mixing_ratios] + # ensure ratio sums up to 1.0 + self.samples_per_batch[0] += batch_size - sum(self.samples_per_batch) + + def __iter__(self): + """ + Returns: + an iterator over the source datasets + """ + sources = [iter(d) for d in self.datasets] + return fixed_ratio_round_robin(*sources, samples_per_batch=self.samples_per_batch) + + +class MultimodalWebDatasetBuilderConfig(BaseModel): + urls: list[str] | str + modality_key_mapping: dict[ModalityEnum, tuple[str, str]] + modality_transforms: dict[ModalityEnum, PydanticTransformIFType] + is_audio_video: Optional[bool] = False + num_samples: Annotated[int, Field(ge=1)] + + +class MultimodalWebDatasetBuilder: + def __init__( + self, + urls: list[str] | str, + modality_key_mapping: dict[str, tuple[str, str]], + modality_transforms: dict[str, Transform], + is_audio_video: bool, + num_samples: int, + ) -> None: + """A multimodal dataset instance for the WebDataset. + + Args: + urls (list[str] or str): A webdataset url. For example: "/data/path/{00000..00012}.tar". + modality_key_mapping (dict[str, tuple[str, str]]): Mapping from dataset keys to keys + expected by the forward pass of the model. + For example: {ModalityEnum.IMAGE: ("jpg", "image"), ModalityEnum.TEXT: ("text", "caption")}} + modality_transforms (dict[str, Transform]): The transforms for each modality as a dictionary. + is_audio_video (bool): Whether the dataset is a video dataset which contains audio + num_samples (int): The number of samples for each modality combination. 
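+
+        Example:
+            A minimal construction sketch; the shard path, the target keys and
+            the two transform objects are illustrative assumptions:
+
+            >>> builder = MultimodalWebDatasetBuilder(
+            ...     urls="/data/path/{00000..00012}.tar",
+            ...     modality_key_mapping={
+            ...         ModalityEnum.IMAGE: ("jpg", "images"),
+            ...         ModalityEnum.TEXT: ("text", "input_ids"),
+            ...     },
+            ...     modality_transforms={
+            ...         ModalityEnum.IMAGE: image_transform,  # an assumed image Transform instance
+            ...         ModalityEnum.TEXT: text_transform,    # an assumed TextTransform instance
+            ...     },
+            ...     is_audio_video=False,
+            ...     num_samples=64,
+            ... )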
+
+        Returns:
+            None
+        """
+        self.urls = urls
+        self.is_audio_video = is_audio_video
+        self.modality_key_mapping = modality_key_mapping
+        self.modality_transforms = modality_transforms
+        # transforms should be specified for all modality_key mappings,
+        # but we can also specify more transforms than necessary
+        # so modality_key_mappings should be a subset of modality_transforms
+        assert set(self.modality_key_mapping.keys()).issubset(self.modality_transforms.keys())
+        self.modalities = list(self.modality_key_mapping.keys())
+        self.num_samples = num_samples
+        self.web_dataset = None
+
+        # Mapping between modality and the decode "function"
+        self.modality_to_decode_fn = {
+            ModalityEnum.TEXT: None,
+            ModalityEnum.IMAGE: "pil",
+            ModalityEnum.VIDEO: decord_video,
+            ModalityEnum.AUDIO: wds.torch_audio,
+        }
+
+        self.additional_extracted_keys = []
+        if ModalityEnum.TEXT in self.modality_transforms:
+            self.additional_extracted_keys.append("attention_mask")
+
+        if ModalityEnum.AUDIO in self.modality_transforms or ModalityEnum.VIDEO in self.modality_transforms:
+            self.additional_extracted_keys.append("audio_len")
+
+        # Mapping between modality and transform
+        self.modality_to_transform_fn = {
+            ModalityEnum.TEXT: self._transform_text,
+            ModalityEnum.IMAGE: self._transform_image,
+            ModalityEnum.VIDEO: self._transform_video,
+            ModalityEnum.AUDIO: self._transform_audio,
+        }
+
+    def prepare(
+        self, shardshuffle: int = 100, resample: bool = True, repeat: bool = False, shuffle_buffer: int = 10_000
+    ) -> None:
+        """
+        Prepares a WebDataset object as a pipeline that includes shuffling, decoding data, and transformations.
+
+        Args:
+            shardshuffle (int): Number of shards that should be used for shuffling. Defaults to 100.
+            resample (bool): Instead of iterating in order sample random shards.
+                This has the issue that the model will see samples multiple times but is significantly more
+                efficient. Defaults to True.
+            repeat (bool): Repeat the dataset. Defaults to False.
+            shuffle_buffer (Optional[int]): Number of samples that should be used for shuffling. Defaults to 10_000.
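+
+        Example:
+            A minimal sketch of preparing and iterating the pipeline; a builder
+            constructed as in the __init__ example above is assumed, and the keys
+            of the yielded sample depend on the configured modality_key_mapping:
+
+            >>> builder.prepare(shardshuffle=100, resample=True, repeat=False, shuffle_buffer=10_000)
+            >>> sample = next(iter(builder.web_dataset))  # doctest: +SKIP
+            >>> sorted(sample.keys())                     # doctest: +SKIP
+            ['attention_mask', 'images', 'input_ids']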
+ + Returns: + None + + """ + self.web_dataset = wds.WebDataset( + urls=self.urls, + nodesplitter=self.dummy_nodesplitter if not resample else None, + shardshuffle=shardshuffle, + repeat=repeat, + handler=wds.ignore_and_continue, + resampled=resample, + ) + + # Apply shuffling to samples + if shuffle_buffer is not None and shuffle_buffer > 0: + self.web_dataset.append(wds.filters.shuffle(shuffle_buffer)) + + # Flatten the json structure for convenience + self.web_dataset.append(wds.filters.decode(partial=True)) # Decode json byte string + self.web_dataset.append(wds.filters.map(self._flatten_sample)) + + # Load the actual data + for modality_key in self.modalities: + decode_fn = self.modality_to_decode_fn[modality_key] + if decode_fn is None: + continue + self.web_dataset.append(wds.filters.decode(decode_fn, partial=True)) + + # Transform the data + for modality_key in self.modalities: + transform_fn = self.modality_to_transform_fn[modality_key] + self.web_dataset.append(wds.filters.map(transform_fn)) + + self.web_dataset.append(wds.filters.map(self._select_keys)) + + def _transform_text(self, sample: dict[str, Any]) -> dict[str, Any]: + source_key, target_key = self.modality_key_mapping[ModalityEnum.TEXT] + transform: TextTransform = self.modality_transforms[ModalityEnum.TEXT] + batch_encoding: BatchEncoding = transform(sample[source_key]) + del sample[source_key] + sample[target_key] = batch_encoding.input_ids + sample["attention_mask"] = batch_encoding.attention_mask + return sample + + def _transform_image(self, sample: dict[str, Any]) -> dict[str, Any]: + source_key, target_key = self.modality_key_mapping[ModalityEnum.IMAGE] + transform: TextTransform = self.modality_transforms[ModalityEnum.IMAGE] + sample[target_key] = transform(sample[source_key]) + del sample[source_key] + return sample + + def _transform_video(self, sample: dict[str, Any]) -> dict[str, Any]: + source_key, target_key = self.modality_key_mapping[ModalityEnum.VIDEO] + transform: VideoTransform = self.modality_transforms[ModalityEnum.VIDEO] + sample[target_key] = transform(sample[source_key]) + # if the video contains audio + if sample[source_key][1] is not None and ModalityEnum.AUDIO in self.modality_transforms and self.is_audio_video: + transform: AudioTransform = self.modality_transforms[ModalityEnum.AUDIO] + sample["audio"], sample["audio_len"] = transform((sample[source_key][1], sample[source_key][2])) + if "audio" not in self.additional_extracted_keys: + self.additional_extracted_keys.append("audio") + del sample[source_key] + return sample + + def _transform_audio(self, sample: dict[str, Any]) -> dict[str, Any]: + source_key, target_key = self.modality_key_mapping[ModalityEnum.AUDIO] + transform: AudioTransform = self.modality_transforms[ModalityEnum.AUDIO] + sample[target_key], sample["audio_len"] = transform(sample[source_key]) + del sample[source_key] + return sample + + def _flatten_sample(self, sample: dict[str, Any]) -> dict[str, Any]: + return flatten_dict(sample) + + def _select_keys(self, sample: dict[str, Any]) -> dict[str, Any]: + # only select the required keys from the sample + # i.e. 
the keys specified in modality_key_mapping
+        # and the additional_extracted_keys
+        select_keys = self.additional_extracted_keys + [v[1] for v in self.modality_key_mapping.values()]
+        new_sample = {}
+        for k, v in sample.items():
+            if k not in select_keys:
+                continue
+            new_sample[k] = v
+        return new_sample
+
+    @staticmethod
+    def dummy_nodesplitter(src, group=None):
+        # This node splitter is not actually splitting the data over the nodes
+        # but keeps the complete dataset on each node.
+        # This is required so that each node has the same amount of data.
+        # In the case of 25 shards and 16 ranks for example 7 ranks are
+        # without data in the second iteration. This will cause a crash once all_gather is called.
+        # This is only relevant for validation.
+        yield from src
+
+
+PydanticMultimodalWebDatasetBuilderIFType = Annotated[
+    MultimodalWebDatasetBuilder, PydanticThirdPartyTypeIF(MultimodalWebDatasetBuilder)
+]
+
+
+class MultimodalWebDatasetConfig(BaseModel):
+    builders: list[PydanticMultimodalWebDatasetBuilderIFType]
+    batch_size: Optional[int] = None
+    mixing_ratios: Optional[list[float]] = None
+    shardshuffle: int = 100
+    repeat: bool = False
+    resample: bool = True
+    shuffle_buffer: Optional[int] = 10_000
+
+
+class MultimodalWebDataset(wds.DataPipeline, wds.compat.FluidInterface):
+    def __init__(
+        self,
+        builders: list[MultimodalWebDatasetBuilder],
+        batch_size: int = None,
+        mixing_ratios: Optional[list[float]] = None,
+        shardshuffle: int = 100,
+        repeat: bool = False,
+        resample: bool = True,
+        shuffle_buffer: Optional[int] = 10_000,
+    ) -> None:
+        """WebDataset for loading and combining multimodal datasets.
+
+        Args:
+            builders: WebDatasetBuilder instances.
+            batch_size (int): batch size per device
+            mixing_ratios (Optional[list[float]]): Mixing ratios of the different modality combinations.
+                For example: [0.3, 0.7]
+            shardshuffle (int): Number of shards that should be used for shuffling. Defaults to 100.
+            repeat (bool): Repeat the dataset. Defaults to False.
+            resample (bool): Instead of iterating in order sample random shards.
+                This has the issue that the model will see samples multiple times but is significantly more
+                efficient. Defaults to True.
+            shuffle_buffer (Optional[int]): Number of samples that should be used for shuffling. Defaults to 10_000.
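+
+        Example:
+            A minimal sketch with two already constructed builders; an
+            image-text and an audio-text builder are assumed here:
+
+            >>> dataset = MultimodalWebDataset(
+            ...     builders=[image_text_builder, audio_text_builder],
+            ...     batch_size=10,
+            ...     mixing_ratios=[0.3, 0.7],
+            ... )
+            >>> # every batch of 10 samples then holds 3 image-text and 7 audio-text samples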
+ + Raises: + NotImplementedError: if multiple builders are specified and at least one builder contains a + video dataset which contains audio + ValueError: if multiple builders are specified and batch size is None + + Returns: + None + """ + super().__init__() + self.builders = builders + + for builder in self.builders: + if builder.is_audio_video and len(self.builders) > 1: + raise NotImplementedError( + "It is not yet possible to include a video-audio dataset with other types of modalities" + ) + + # Build datasets + [ + b.prepare(shardshuffle=shardshuffle, resample=resample, repeat=repeat, shuffle_buffer=shuffle_buffer) + for b in self.builders + ] + + # Setup mixing ratios + self.mixing_ratios = mixing_ratios + if self.mixing_ratios is None: + uniform_ratio = 1 / len(self.builders) + self.mixing_ratios = [uniform_ratio for _ in self.builders] + assert len(self.mixing_ratios) == len(self.builders) + + if len(self.builders) > 1: + if batch_size is None: + raise ValueError("batch_size cannot be None if multiple builders are used") + datasets = [] + for b in self.builders: + datasets.append(b.web_dataset) + dataset = FixedRatioRoundRobinMix(datasets, self.mixing_ratios, batch_size) # Apply mixing at sample level + self.pipeline.append(dataset) + else: + self.pipeline.extend(self.builders[0].web_dataset.pipeline) + + self.with_length(sum([b.num_samples for b in self.builders])) diff --git a/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py b/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py index 0c5e7f072..f54a3f782 100644 --- a/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/progress_subscriber.py @@ -14,10 +14,51 @@ class DummyProgressSubscriber(MessageSubscriberIF[ProgressUpdate]): def consume_message(self, message: Message[ProgressUpdate]): pass - def consume_dict(self, mesasge_dict: dict[str, Any]): + def consume_dict(self, message_dict: dict[str, Any]): pass +class SimpleProgressSubscriber(MessageSubscriberIF[ProgressUpdate]): + """A subscriber object for the RichProgress observable.""" + + def __init__( + self, + train_split_num_samples: dict[str, int], + eval_splits_num_samples: dict[str, int], + ) -> None: + self.train_split_num_samples = train_split_num_samples + self.eval_splits_num_samples = eval_splits_num_samples + + def consume_message(self, message: Message[ProgressUpdate]): + if not isinstance(message.payload, ProgressUpdate): + return + + batch_progress = message.payload + completed_samples = 0 + total_samples = 0 + + [batch_progress.dataloader_tag] + + prefix = "" + if message.payload.experiment_status == ExperimentStatus.TRAIN: + prefix = "Train" + completed_samples = batch_progress.num_steps_done + total_samples = self.train_split_num_samples[batch_progress.dataloader_tag] + + elif message.payload.experiment_status == ExperimentStatus.EVALUATION: + prefix = "Evaluation" + completed_samples = batch_progress.num_steps_done + total_samples = self.eval_splits_num_samples[batch_progress.dataloader_tag] + + print( + f"{prefix}[{batch_progress.dataloader_tag}] " + f"[{completed_samples}/{total_samples} ({completed_samples*100/total_samples:.01f}%)]" + ) + + def consume_dict(self, mesasge_dict: dict[str, Any]): + raise NotImplementedError + + class RichProgressSubscriber(MessageSubscriberIF[ProgressUpdate]): """A subscriber object for the RichProgress observable.""" diff --git a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py 
b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py index 8086f9a15..e44054913 100644 --- a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py @@ -105,6 +105,15 @@ def consume_message(self, message: Message[EvaluationResultBatch]): wandb.log(data=throughput_metrics, step=eval_result.num_train_steps_done) - # wandb.log({"tokens_loss": wandb.plot.scatter("num_tokens", "loss", title="Tokens vs Loss")}) - # wandb.log({"steps_loss": wandb.plot.scatter("steps_loss", "loss", title="Steps vs Loss")}) - # wandb.log({"samples_loss": wandb.plot.scatter("samples_loss", "loss", title="Samples vs Loss")}) + num_samples = eval_result.num_train_steps_done + group_content = [f"Train [{num_samples}]:"] + + losses = [f"{k}: {v}" for k, v in losses.items()] + metrics = [f"{k}: {v}" for k, v in metrics.items()] + + if losses: + group_content.append(" ".join(losses)) + if metrics: + group_content.append(" ".join(metrics)) + + print(" ".join(group_content)) diff --git a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py index c5ae3c4f1..7f9d3b576 100644 --- a/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py +++ b/src/modalities/logging_broker/subscriber_impl/subscriber_factory.py @@ -7,6 +7,7 @@ from modalities.logging_broker.subscriber_impl.progress_subscriber import ( DummyProgressSubscriber, RichProgressSubscriber, + SimpleProgressSubscriber, ) from modalities.logging_broker.subscriber_impl.results_subscriber import ( DummyResultSubscriber, @@ -38,6 +39,24 @@ def get_rich_progress_subscriber( subscriber = ProgressSubscriberFactory.get_dummy_progress_subscriber() return subscriber + @staticmethod + def get_simple_progress_subscriber( + eval_dataloaders: list[LLMDataLoader], + train_dataloader_tag: str, + num_seen_steps: int, + num_target_steps: int, + global_rank: int, + ) -> SimpleProgressSubscriber: + if global_rank == 0: + train_split_num_samples = {train_dataloader_tag: (num_target_steps)} + + eval_splits_num_samples = {dataloader.dataloader_tag: len(dataloader) for dataloader in eval_dataloaders} + + subscriber = SimpleProgressSubscriber(train_split_num_samples, eval_splits_num_samples) + else: + subscriber = ProgressSubscriberFactory.get_dummy_progress_subscriber() + return subscriber + @staticmethod def get_dummy_progress_subscriber() -> DummyProgressSubscriber: return DummyProgressSubscriber() diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index 54d8de36b..63e5c3cd1 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod import torch +import torch.distributed as dist +import torch.nn.functional as F from torch.nn import CrossEntropyLoss from modalities.batch import InferenceResultBatch @@ -23,6 +25,65 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: raise NotImplementedError +class MultipleFunctionsLoss(Loss): + """Loss objects of this type use more + than one loss function and weights corresponding + to the losses to compute total loss. + """ + + def __init__( + self, + losses: list[Loss], + corrsp_weights: list[float], + tag: str = "MultipleFunctionsLoss", + ) -> None: + """MultipleFunctionsLoss Constructor + + Args: + losses (list): Initialized losses. This list should contain more than one loss. 
+ corrsp_weights (list): Weights to be multiplied to each loss while summing up. + + Returns: + None + """ + super().__init__(tag) + + if len(losses) <= 1: + raise ValueError("Number of losses used should be more than 1.") + + self.groups = [(loss_func, weight) for loss_func, weight in zip(losses, corrsp_weights, strict=True)] + + self.cumulated_individual_losses = None + # variable storing each loss, + # summed over local batches, + # separately. + + self.reset_cumulated_individual_losses() + + def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: + device = forward_batch.predictions[list(forward_batch.predictions.keys())[0]].device + total_loss = torch.tensor(0, dtype=torch.float, device=device) + for ind, (loss_func, weight) in enumerate(self.groups): + loss = loss_func(forward_batch) + self.cumulated_individual_losses[ind] += loss + total_loss += weight * loss + return total_loss + + def reset_cumulated_individual_losses( + self, + ) -> None: + """Initializes and resets the variable + accumulating each loss separately. + + Called first when the class is initialized, and then + after every logging step in trainer.py. + """ + if torch.cuda.is_available(): + self.cumulated_individual_losses = torch.zeros(len(self.groups)).to(torch.device("cuda")) + else: + self.cumulated_individual_losses = torch.zeros(len(self.groups)).to("cpu") + + class CLMCrossEntropyLoss(Loss): def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEntropyLoss"): super().__init__(tag) @@ -33,6 +94,11 @@ def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEnt def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: labels = forward_batch.get_targets(self.target_key) + + if "attention_mask" in forward_batch.targets: + attention_mask = forward_batch.get_targets("attention_mask") + labels[attention_mask == 0] = -100 + lm_logits = forward_batch.get_predictions(self.prediction_key) # move labels to correct device to enable model parallelism @@ -122,3 +188,87 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: contiguous_embedding1, contiguous_embedding2, embedding1.device, self.is_asymmetric, self.temperature ) return loss + + +class ClipLoss(Loss): + def __init__( + self, + logit_scale_key: str, + prediction_keys: list[str], + local_loss: bool, + tag: str = "ClipLoss", + ): + """ + CLIP Loss (Source: https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/loss.py) + + Args: + logit_scale_key (str): Value of a learnable logit scale parameter. + prediction_keys (list[str]): Keys to access embeddings. + tag (str, optional): Defaults to "ClipLoss". + """ + super().__init__(tag) + self.logit_scale_key = logit_scale_key + self.prediction_keys = prediction_keys + self.local_loss = local_loss + + if not (2 <= len(prediction_keys) <= 3): + raise ValueError("ClipLoss requires either 2 or 3 prediction keys.") + + def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: + """ + Args: + forward_batch (InferenceResultBatch): data batch. + + Returns: + torch.Tensor: loss tensor. 
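+
+        Example:
+            A sketch of the underlying symmetric contrastive objective for two
+            modalities on a single rank, ignoring the distributed all_gather;
+            batch size and embedding width are illustrative assumptions:
+
+            >>> image_cls = F.normalize(torch.randn(4, 256), dim=-1)
+            >>> text_cls = F.normalize(torch.randn(4, 256), dim=-1)
+            >>> logit_scale = torch.tensor(1.0 / 0.07)
+            >>> logits = logit_scale * image_cls @ text_cls.T
+            >>> labels = torch.arange(4)
+            >>> loss = F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)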
+ """ + logit_scale = forward_batch.get_predictions(self.logit_scale_key) + + embeddings = [forward_batch.get_predictions(key).contiguous() for key in self.prediction_keys] + device = embeddings[0].device + + # Gather all embeddings from each rank + world_size = dist.get_world_size() + rank = dist.get_rank() + + gathered_embeddings = [[torch.zeros_like(embedding) for _ in range(world_size)] for embedding in embeddings] + + for gathered_embedding, embedding in zip(gathered_embeddings, embeddings): + dist.all_gather(gathered_embedding, embedding) + + # Make sure we have gradients for the "local" embeddings + if not self.local_loss: + for gathered_embedding, embedding in zip(gathered_embeddings, embeddings): + gathered_embedding[rank] = embedding + + # Combine embeddings + gathered_embeddings = [torch.cat(gathered_embedding, dim=0) for gathered_embedding in gathered_embeddings] + + # Calculate logits + logits_per_embeddings = [] + for i, embedding in enumerate(embeddings): + for j, gathered_embedding in enumerate(gathered_embeddings): + if i != j: + if self.local_loss: + logits = logit_scale * embedding @ gathered_embedding.T + else: + logits = logit_scale * gathered_embeddings[i] @ gathered_embeddings[j].T + logits_per_embeddings.append(logits) + + # Build gt labels for diagonal + num_logits = logits_per_embeddings[0].shape[0] + labels = torch.arange(num_logits, device=device, dtype=torch.long) + if world_size > 1 and self.local_loss: + labels = labels + num_logits * rank + + # Calculate loss + losses = None + for logits in logits_per_embeddings: + if losses is None: + losses = F.cross_entropy(logits, labels) + else: + losses += F.cross_entropy(logits, labels) + + clip_loss = losses.mean() + + return clip_loss diff --git a/src/modalities/models/audio_transformer/__init__.py b/src/modalities/models/audio_transformer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/modalities/models/audio_transformer/audio_transformer_model.py b/src/modalities/models/audio_transformer/audio_transformer_model.py new file mode 100644 index 000000000..66184f9f8 --- /dev/null +++ b/src/modalities/models/audio_transformer/audio_transformer_model.py @@ -0,0 +1,350 @@ +from typing import Annotated + +import torch +from pydantic import BaseModel, Field +from torch import nn + +from modalities.nn.attention import AttentionConfig, AttentionType, MultiHeadAttention +from modalities.nn.mlp import MLP + + +class AudioTransformerConfig(BaseModel): + """ + Configuration for an audio transformer model using conformer blocks. + + This configuration class defines all necessary parameters to instantiate and configure an `AudioTransformer` model. + + Attributes: + sample_key (str): The key in the input dictionary that contains the audio samples. + prediction_key (str): The key under which the model's output will be stored in the output dictionary. + block_size (int): The size of each block for positional embeddings. Must be a positive integer. + n_mels (int): The number of mel-frequency bands used for input audio feature extraction. + Must be a positive integer. + n_embd (int): The embedding dimension used throughout the model. Must be a positive integer. + n_heads (int): The number of attention heads in the conformer blocks. Must be a positive integer. + n_conformer_blocks (int): The number of conformer blocks to include in the transformer model. + Must be a positive integer. + attention_config (AttentionConfig): Configuration object for attention mechanisms. 
+ pointwise_conv_kernel_size (int): Kernel size for the pointwise convolutional layers in conformer blocks. + Must be a positive integer. + depthwise_conv_kernel_size (int): Kernel size for the depthwise convolutional layers in conformer blocks. + Must be a positive integer. + ffmodule_dropout (float, optional): Dropout rate for feed-forward modules in conformer blocks. + Must be a float less than 1.0. Default is 0.1. + attn_dropout (float, optional): Dropout rate for attention mechanisms. Must be a float less than 1.0. + Default is 0.1. + convmodule_dropout (float, optional): Dropout rate for depthwise convolutional layers in conformer blocks. + Must be a float less than 1.0. Default is 0.1. + """ + + sample_key: str + prediction_key: str + block_size: Annotated[int, Field(ge=1)] + n_mels: Annotated[int, Field(ge=1)] + n_embd: Annotated[int, Field(ge=1)] + n_heads: Annotated[int, Field(ge=1)] + n_conformer_blocks: Annotated[int, Field(ge=1)] + attention_config: AttentionConfig + pointwise_conv_kernel_size: Annotated[int, Field(ge=1)] + depthwise_conv_kernel_size: Annotated[int, Field(ge=1)] + ffmodule_dropout: Annotated[float, Field(lt=1.0)] = 0.1 + attn_dropout: Annotated[float, Field(lt=1.0)] = 0.1 + convmodule_dropout: Annotated[float, Field(lt=1.0)] = 0.1 + + +class ConvolutionModule(nn.Module): + """ + A convolutional module designed to process sequences using a series of layers including LayerNorm, + pointwise convolutions, GLU activation, depthwise convolution, batch normalization, SiLU (Swish) activation, + and a final pointwise convolution. + """ + + def __init__( + self, + n_embd: int, + pointwise_conv_kernel_size: int, + depthwise_conv_kernel_size: int, + dropout: float, + ): + """ + Initializes the ConvolutionModule class. + + Args: + n_embd (int): The number of embedding dimensions. Must be a positive integer. + pointwise_conv_kernel_size (int): The kernel size for both the first and second pointwise convolutions. + depthwise_conv_kernel_size (int): The kernel size for the depthwise convolution. + dropout (float): Dropout rate applied after each layer. Must be a float between 0 and 1. + """ + super().__init__() + self.ln_1 = nn.LayerNorm(n_embd) + self.pointwise_conv_1 = nn.Conv1d( + n_embd, + 2 * n_embd, + pointwise_conv_kernel_size, + padding="same", + ) + self.glu = nn.GLU(dim=1) + self.depthwise_conv = nn.Conv1d( + n_embd, + n_embd, + kernel_size=depthwise_conv_kernel_size, + groups=n_embd, + padding="same", + ) + self.batch_norm = nn.BatchNorm1d( + n_embd, + ) + self.swish = nn.SiLU() + self.pointwise_conv_2 = nn.Conv1d( + n_embd, + n_embd, + pointwise_conv_kernel_size, + padding="same", + ) + self.dropout = nn.Dropout(dropout) + + def forward( + self, + x: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass through the convolutional module. + + Args: + x (torch.Tensor): Input tensor of shape (B, T, D), where B is the batch size, + T is the number of time steps, and D is the embedding dimension. + + Returns: + torch.Tensor: Output tensor of shape (B, T, D). 
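+
+        Example:
+            A shape sketch; batch size, sequence length and embedding width are
+            illustrative assumptions:
+
+            >>> module = ConvolutionModule(n_embd=256, pointwise_conv_kernel_size=1,
+            ...                            depthwise_conv_kernel_size=31, dropout=0.1)
+            >>> x = torch.randn(8, 100, 256)  # B x T x D
+            >>> module(x).shape
+            torch.Size([8, 100, 256])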
+ """ + if x.shape[1] == 1: + raise ValueError("The time dimension of the input to the convolution module cannot be 1!") + + x = self.ln_1(x) + x = x.transpose(1, 2) + x = self.glu(self.pointwise_conv_1(x)) + x = self.swish(self.batch_norm(self.depthwise_conv(x))) + x = self.pointwise_conv_2(x) + return self.dropout(x.transpose(1, 2)) + + +class ConformerBlock(nn.Module): + """ + This block combines self-attention, feed-forward modules, and depthwise convolutional layers to provide + efficient processing of sequential data. + """ + + def __init__( + self, + n_embd: int, + n_heads: int, + attention_config: AttentionConfig, + pointwise_conv_kernel_size: int, + depthwise_conv_kernel_size: int, + ffmodule_dropout: float, + attn_dropout: float, + convmodule_dropout: float, + ) -> None: + """Initializes the ConformerBlock class. + + Args: + n_embd (int): The number of expected features in the input. + n_heads (int): Number of parallel attention heads. + attention_config (AttentionConfig): Configuration for the attention mechanism, typically a dictionary or \ + class instance. + pointwise_conv_kernel_size (int): Kernel size of the depthwise convolutional layer. + depthwise_conv_kernel_size (int): The kernel size for the depthwise convolutional module. + ffmodule_dropout (float): Dropout rate for feed-forward modules. + attn_dropout (float): Dropout rate for attention mechanism. + convmodule_dropout (float): Dropout rate for the convolutional module. + """ + super().__init__() + + self.ln_1 = nn.LayerNorm(n_embd) + self.entry_ffmodule = MLP( + in_features=n_embd, + act_fn=nn.SiLU, + dropout=ffmodule_dropout, + ) + self.ln_mhsa = nn.LayerNorm(n_embd) + self.attn = MultiHeadAttention( + attention_config=attention_config, + attention_type=AttentionType.NON_CAUSAL_SELF_ATTENTION, + n_embd=n_embd, + n_head=n_heads, + dropout=attn_dropout, + ) + self.convmodule = ConvolutionModule( + n_embd, + pointwise_conv_kernel_size, + depthwise_conv_kernel_size, + convmodule_dropout, + ) + self.ln_2 = nn.LayerNorm( + n_embd, + ) + self.exit_ffmodule = MLP( + in_features=n_embd, + act_fn=nn.SiLU, + dropout=ffmodule_dropout, + ) + self.exit_ln = nn.LayerNorm( + n_embd, + ) + + def forward( + self, + x: torch.Tensor, + mask: torch.Tensor, + ) -> torch.Tensor: + """ + Forward pass through the conformer block. + + Args: + x (torch.Tensor): Input tensor of shape (B, T, D), where B is the batch size, + T is the number of time steps, and D is the embedding dimension. + mask (torch.Tensor): Attention mask of shape (N, 1, L) or (N, L, L), where N is the batch size, + L is the sequence length. If not provided, no attention mask will be used. + + Returns: + torch.Tensor: Output tensor of shape (B, T, D). + """ + x = self.ln_1(x) + x = x + 0.5 * self.entry_ffmodule(x) + x = x + self.attn(self.ln_mhsa(x), mask=mask) + x = x + self.convmodule(x) + x = self.ln_2(x) + x = x + 0.5 * self.exit_ffmodule(x) + return self.exit_ln(x) + + +class AudioTransformer(nn.Module): + """An audio transformer model using conformer blocks for processing audio data and generating predictions. 
+ + This model includes convolutional layers, subsampling, positional embeddings, + and multiple conformer blocks for feature extraction and processing.""" + + def __init__( + self, + *, + sample_key: str, + prediction_key: str, + block_size: int, + n_mels: int, + n_embd: int, + n_heads: int, + n_conformer_blocks: int, + attention_config: AttentionConfig, + pointwise_conv_kernel_size: int, + depthwise_conv_kernel_size: int, + ffmodule_dropout: float = 0.1, + attn_dropout: float = 0.1, + convmodule_dropout: float = 0.1, + ): + """ + Initializes the AudioTransformer model. + + Args: + sample_key (str): The key in the input dictionary that contains the audio samples. + prediction_key (str): The key under which the model's output will be stored in the output dictionary. + block_size (int): The size of each block for positional embeddings. + n_mels (int): The number of mel-frequency bands used for input audio feature extraction. + n_embd (int): The embedding dimension used throughout the model. + n_heads (int): The number of attention heads in the conformer blocks. + n_conformer_blocks (int): The number of conformer blocks to include in the transformer model. + attention_config (AttentionConfig): Configuration object for attention mechanisms. + pointwise_conv_kernel_size (int): Kernel size for the pointwise convolutional layers in conformer blocks. + depthwise_conv_kernel_size (int): Kernel size for the depthwise convolutional layers in conformer blocks. + ffmodule_dropout (float): Dropout rate for feed-forward modules in conformer blocks. Default is 0.1. + attn_dropout (float): Dropout rate for attention mechanisms. Default is 0.1. + convmodule_dropout (float): Dropout rate for depthwise convolutional layers in conformer blocks. + Default is 0.1. + """ + super().__init__() + self.sample_key = sample_key + self.prediction_key = prediction_key + self.block_size = block_size + + self.project = nn.Conv1d(in_channels=n_mels, out_channels=n_embd, kernel_size=3, padding="same") + self.subsampler = nn.Sequential( + nn.Conv1d( + in_channels=n_embd, + out_channels=n_embd, + kernel_size=2, + stride=2, + ), + nn.Conv1d( + in_channels=n_embd, + out_channels=n_embd, + kernel_size=2, + stride=2, + ), + ) + self.post_subsampler_linear = nn.Sequential( + nn.Linear(n_embd, n_embd), + nn.Dropout(0.1), + ) + + self.positional_embeddings = nn.Embedding(self.block_size, n_embd) + self.conformer_blocks = nn.ModuleList( + [ + ConformerBlock( + n_embd, + n_heads, + attention_config, + pointwise_conv_kernel_size, + depthwise_conv_kernel_size, + ffmodule_dropout, + attn_dropout, + convmodule_dropout, + ) + for _ in range(n_conformer_blocks) + ] + ) + + def forward( + self, + inputs: dict[str, tuple[torch.Tensor, torch.Tensor]], + ) -> dict[str, tuple[torch.Tensor, torch.Tensor]]: + """ + Forward pass of the AudioTransformer model. + + Args: + inputs (dict[str, tuple[torch.Tensor, torch.Tensor]]): A dictionary containing the input tensors. + It must include the key specified by `sample_key`. + + Returns: + dict[str, tuple[torch.Tensor, torch.Tensor]]: A dictionary with a single key specified by `prediction_key`,\ + containing the model's output. 
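+
+        Example:
+            A minimal sketch; the sample key "audio", the prediction key
+            "audio_embeddings" and all sizes are illustrative assumptions. The
+            time dimension of the input must be 4 * block_size, because the
+            subsampler halves it twice before the positional embeddings are added:
+
+            >>> # model = AudioTransformer(sample_key="audio", prediction_key="audio_embeddings",
+            >>> #                          block_size=128, n_mels=80, n_embd=256, ...)
+            >>> batch = {"audio": torch.randn(2, 4 * 128, 80),
+            ...          "audio_len": torch.tensor([128, 64])}
+            >>> model(batch)["audio_embeddings"].shape  # doctest: +SKIP
+            torch.Size([2, 128, 256])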
+ """ + x = inputs[self.sample_key] # x.shape: B, T, D + attn_key_mask = self._get_attn_key_mask(inputs["audio_len"]) + # x.shape: B, T, D + x = self.project(x.transpose(1, 2)) # x.shape: B, D, T + x = self.subsampler(x) # x.shape: B, D, T/4 + x = x.transpose(1, 2) + x = self.post_subsampler_linear(x) + x = x + self.positional_embeddings.weight + + for block in self.conformer_blocks: + x = block(x, attn_key_mask) + return {self.prediction_key: x} + + def _get_attn_key_mask( + self, + lengths: torch.Tensor, + ) -> torch.Tensor: + # Generates an attention key mask based on input sequence lengths. + stack = [] + for length in lengths: + ones = torch.ones(length, self.block_size) + ones[1:, length:] = 0 + stack.append(ones) + return ( + torch.nn.utils.rnn.pad_sequence( + stack + [torch.zeros(self.block_size, self.block_size)], + batch_first=True, + ) + .transpose(1, 2)[:-1] + .unsqueeze_(1) + ).to(lengths.device) diff --git a/src/modalities/models/coca/coca_model.py b/src/modalities/models/coca/coca_model.py index e299d768c..20bbaceba 100644 --- a/src/modalities/models/coca/coca_model.py +++ b/src/modalities/models/coca/coca_model.py @@ -1,10 +1,13 @@ -from typing import Annotated +from typing import Annotated, Optional +import numpy as np import torch +import torch.nn.functional as F from einops import repeat from pydantic import BaseModel, Field from torch import nn +from modalities.models.audio_transformer.audio_transformer_model import AudioTransformer, AudioTransformerConfig from modalities.models.coca.attention_pooling import AttentionPooling from modalities.models.coca.multi_modal_decoder import MultiModalTextDecoder from modalities.models.coca.text_decoder import TextDecoder @@ -57,30 +60,59 @@ class CoCaConfig(BaseModel): Args: prediction_key (str): The key for the predictions. - vision_embd_prediction_key (str): The key for the vision embeddings. text_embd_prediction_key (str): The key for the text embeddings. - vision_cls_prediction_key (str): The key for the vision cls token. - text_cls_prediction_key (str): The key for the text cls token. - vision_encoder_config (VisionTransformerConfig): Configuration for the vision encoder. + logit_scale_prediction_key (str): The key for the logit scale + text_cls_prediction_key (Optional[str]): The key for the text cls token. + audio_embd_prediction_key (Optional[str]): The key for audio embeddings + image_embd_prediction_key (Optional[str]): The key for image embeddings + video_embd_prediction_key (Optional[str]): The key for video embeddings + audio_cls_prediction_key (Optional[str]): Th key for the audio cls token + audio_text_cls_prediction_key (Optional[str]): Th key for the text cls token associated with the audio samples + image_cls_prediction_key (Optional[str]): Th key for the image cls token + image_text_cls_prediction_key (Optional[str]): Th key for the text cls token associated with the image samples + video_cls_prediction_key (Optional[str]): Th key for the video cls token + video_text_cls_prediction_key (Optional[str]): Th key for the text cls token associated with the video samples + modality_keys (list[str]): sample keys in the input associated with the input modalities + individual_datasets (Optional[bool]): flag indicating whether + there are separate datasets for different modalities + is_audio_video (Optional[bool]): flag indicating whether the video samples contain audio + audio_encoder_config (Optional[AudioTransformerConfig]): config for the audio encoder. Defaults to None. 
+ image_encoder_config (Optional[VisionTransformerConfig]): config for the image encoder. Defaults to None + video_encoder_config (Optional[VisionTransformerConfig]): config for the video encoder. Defaults to None text_decoder_config (TextDecoderConfig): Configuration for the text decoder. n_pool_head (int): Number of attention heads for pooling. - n_vision_queries (int): Number of vision queries. + n_queries (int): Number of queries for attention pooling. bias_attn_pool (bool): Flag indicating whether to use bias in attention pooling. epsilon_attn_pool (float): Epsilon value for attention pooling. + seed (Optional[int]): The random seed. Defaults to None """ prediction_key: str = "logits" - vision_embd_prediction_key: str # same key as vision encoder text_embd_prediction_key: str - vision_cls_prediction_key: str - text_cls_prediction_key: str - vision_encoder_config: VisionTransformerConfig + logit_scale_prediction_key: str + text_cls_prediction_key: Optional[str] = None + audio_embd_prediction_key: Optional[str] = None + image_embd_prediction_key: Optional[str] = None + video_embd_prediction_key: Optional[str] = None + audio_cls_prediction_key: Optional[str] = None + audio_text_cls_prediction_key: Optional[str] = None + image_cls_prediction_key: Optional[str] = None + image_text_cls_prediction_key: Optional[str] = None + video_cls_prediction_key: Optional[str] = None + video_text_cls_prediction_key: Optional[str] = None + modality_keys: list[str] + individual_datasets: Optional[bool] = False + is_audio_video: Optional[bool] = False + audio_encoder_config: Optional[AudioTransformerConfig] = None + image_encoder_config: Optional[VisionTransformerConfig] = None + video_encoder_config: Optional[VisionTransformerConfig] = None text_decoder_config: TextDecoderConfig n_pool_head: Annotated[int, Field(ge=1)] - n_vision_queries: Annotated[int, Field(ge=1)] + n_queries: Optional[Annotated[int, Field(ge=1)]] bias_attn_pool: bool epsilon_attn_pool: Annotated[float, Field(ge=0.0)] + seed: Optional[int] = None class CoCa(NNModel): @@ -97,45 +129,147 @@ class CoCa(NNModel): def __init__( self, prediction_key: str, - vision_cls_prediction_key: str, - text_cls_prediction_key: str, - vision_embd_prediction_key: str, text_embd_prediction_key: str, - n_vision_queries: int, + logit_scale_prediction_key: str, + text_cls_prediction_key: Optional[str], + audio_embd_prediction_key: Optional[str], + image_embd_prediction_key: Optional[str], + video_embd_prediction_key: Optional[str], + audio_cls_prediction_key: Optional[str], + audio_text_cls_prediction_key: Optional[str], + image_cls_prediction_key: Optional[str], + image_text_cls_prediction_key: Optional[str], + video_cls_prediction_key: Optional[str], + video_text_cls_prediction_key: Optional[str], + modality_keys: list[str], + individual_datasets: Optional[bool], + is_audio_video: Optional[bool], + audio_encoder_config: Optional[AudioTransformerConfig], + image_encoder_config: Optional[VisionTransformerConfig], + video_encoder_config: Optional[VisionTransformerConfig], + text_decoder_config: TextDecoderConfig, n_pool_head: int, + n_queries: Optional[int], bias_attn_pool: bool, epsilon_attn_pool: float, - vision_encoder_config: VisionTransformerConfig, - text_decoder_config: TextDecoderConfig, + seed: int = None, ) -> None: """ Initializes the CocaModel object. Args: prediction_key (str): The key for the predictions. - vision_cls_prediction_key (str): The key for the vision cls token. - text_cls_prediction_key (str): The key for the text cls token. 
- vision_embd_prediction_key (str): The key for the vision embeddings. text_embd_prediction_key (str): The key for the text embeddings. - - n_vision_queries (int): The number of vision queries. - n_pool_head (int): The number of pool heads. + logit_scale_prediction_key (str): The key for the logit scale + text_cls_prediction_key (Optional[str]): The key for the text cls token. + audio_embd_prediction_key (Optional[str]): The key for audio embeddings + image_embd_prediction_key (Optional[str]): The key for image embeddings + video_embd_prediction_key (Optional[str]): The key for video embeddings + audio_cls_prediction_key (Optional[str]): Th key for the audio cls token + audio_text_cls_prediction_key (Optional[str]): Th key for the text cls token + associated with the audio samples + image_cls_prediction_key (Optional[str]): Th key for the image cls token + image_text_cls_prediction_key (Optional[str]): Th key for the text cls token + associated with the image samples + video_cls_prediction_key (Optional[str]): Th key for the video cls token + video_text_cls_prediction_key (Optional[str]): Th key for the text cls token + associated with the video samples + modality_keys (list[str]): sample keys in the input associated with the input modalities + individual_datasets (Optional[bool]): flag indicating whether there are separate datasets + for different modalities + is_audio_video (Optional[bool]): flag indicating whether the video samples contain audio + audio_encoder_config (Optional[AudioTransformerConfig]): config for the audio encoder. Defaults to None. + image_encoder_config (Optional[VisionTransformerConfig]): config for the image encoder. Defaults to None + video_encoder_config (Optional[VisionTransformerConfig]): config for the video encoder. Defaults to None + text_decoder_config (TextDecoderConfig): Configuration for the text decoder. + n_pool_head (int): Number of attention heads for pooling. + n_queries (int): Number of queries for attention pooling. bias_attn_pool (bool): Flag indicating whether to use bias in attention pooling. - epsilon_attn_pool (float): The epsilon value for attention pooling. - vision_encoder_config (VisionTransformerConfig): The configuration for the vision encoder. - text_decoder_config (TextDecoderConfig): The configuration for the text decoder. + epsilon_attn_pool (float): Epsilon value for attention pooling. + seed (Optional[int]): The random seed. 
Defaults to None + + Raises: + ValueError: if none of the modality encoders are defined + ValueError: if using individual dataset and none of the text cls tokens + corresponding to the modalities is defined + ValueError: if training on a single dataset and text_cls_prediction_key is not defined Returns: None """ - super().__init__() + weight_decay_groups = { + "linear": [r"attention", r"\.attn", r"\.cross_attn", r"\.post_subsampler", r"_ffmodule", r"mlp"], + "conv": [r"embedding_fn\.conv", r"project", r"\.subsampler", r"pointwise_conv", r"depthwise_conv"], + "embedding": [r"wte", r"wpe", r"positional_embedding", r"time_embd"], + "norm": [r"norm", r"norm_latents", r"\.ln_", r"\.batch_norm", r"exit_ln"], + "parameter": [r"_queries", r"logit_scale", r"\.latents", r"cls_token"], + } + super().__init__(weight_decay_groups=weight_decay_groups, seed=seed) + + if individual_datasets: + if ( + not audio_text_cls_prediction_key + and not image_text_cls_prediction_key + and not video_text_cls_prediction_key + ): + raise ValueError("All text_cls_prediction_keys cannot be None") + else: + if not text_cls_prediction_key: + raise ValueError("text_cls_prediction key cannot be None") + if not audio_encoder_config and not image_encoder_config and not video_encoder_config: + raise ValueError("Atleast one modality encoder config should be specified") + self.prediction_key = prediction_key - self.vision_cls_prediction_key = vision_cls_prediction_key - self.text_cls_prediction_key = text_cls_prediction_key - self.vision_embd_prediction_key = vision_embd_prediction_key self.text_embd_prediction_key = text_embd_prediction_key + self.logit_scale_prediction_key = logit_scale_prediction_key + self.text_cls_prediction_key = text_cls_prediction_key + + self.audio_embd_prediction_key = audio_embd_prediction_key + self.image_embd_prediction_key = image_embd_prediction_key + self.video_embd_prediction_key = video_embd_prediction_key + self.audio_cls_prediction_key = audio_cls_prediction_key + self.audio_text_cls_prediction_key = audio_text_cls_prediction_key + self.image_cls_prediction_key = image_cls_prediction_key + self.image_text_cls_prediction_key = image_text_cls_prediction_key + self.video_cls_prediction_key = video_cls_prediction_key + self.video_text_cls_prediction_key = video_text_cls_prediction_key + + self.modality_keys = modality_keys + self.individual_datasets = individual_datasets + self.is_audio_video = is_audio_video + + self.n_pool_head = n_pool_head + self.bias_attn_pool = bias_attn_pool + self.epsilon_attn_pool = epsilon_attn_pool + self.text_decoder_config = text_decoder_config + + self.image_sample_key = None + if image_encoder_config is not None: + self.image_sample_key = image_encoder_config.sample_key + self.image_encoder, self.image_queries, self.image_attn_pool = self._init_modality( + VisionTransformer, + image_encoder_config, + n_queries, + ) + + self.video_sample_key = None + if video_encoder_config is not None: + self.video_sample_key = video_encoder_config.sample_key + self.video_encoder, self.video_queries, self.video_attn_pool = self._init_modality( + VisionTransformer, + video_encoder_config, + n_queries, + ) + + self.audio_sample_key = None + if audio_encoder_config is not None: + self.audio_sample_key = audio_encoder_config.sample_key + self.audio_encoder, self.audio_queries, self.audio_attn_pool = self._init_modality( + AudioTransformer, + audio_encoder_config, + n_queries, + ) - self.vision_encoder = VisionTransformer(**dict(vision_encoder_config)) self.text_decoder = TextDecoder( 
sample_key=text_decoder_config.sample_key, prediction_key=text_embd_prediction_key, @@ -160,6 +294,7 @@ def __init__( n_head=text_decoder_config.n_head, n_embd=text_decoder_config.n_embd, ffn_hidden=text_decoder_config.ffn_hidden, + is_audio_video=self.is_audio_video, dropout=text_decoder_config.dropout, bias=text_decoder_config.bias, attention_config=text_decoder_config.attention_config, @@ -171,80 +306,148 @@ def __init__( self.multimodal_decoder.lm_head.weight ) # https://paperswithcode.com/method/weight-tying - # vision_queries: 256 queries for multimodal cross attention and 1 as vision cls token for contrastive learning - self.vision_queries = nn.Parameter(torch.randn(n_vision_queries + 1, vision_encoder_config.n_embd)) - self.attn_pool = AttentionPooling( - n_embd=vision_encoder_config.n_embd, - n_head=n_pool_head, - bias=bias_attn_pool, - epsilon=epsilon_attn_pool, - attention_config=text_decoder_config.attention_config, + # Logit scale for contrastive loss + self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) + + def _init_modality( + self, encoder_class: type, encoder_config: VisionTransformerConfig | AudioTransformerConfig, n_queries: int + ) -> tuple[VisionTransformer | AudioTransformer, nn.Parameter, AttentionPooling]: + # initialize modality encoder, returns a tuple containing the encoder, queries and attention pooling layer + encoder = encoder_class(**dict(encoder_config)) + queries = nn.Parameter(torch.randn(n_queries + 1, encoder_config.n_embd)) + attn_pool = AttentionPooling( + n_embd=encoder_config.n_embd, + n_head=self.n_pool_head, + bias=self.bias_attn_pool, + epsilon=self.epsilon_attn_pool, + attention_config=self.text_decoder_config.attention_config, ) + return encoder, queries, attn_pool def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Forward pass of the CoCa model. Args: - inputs (dict[str, torch.Tensor]): Input dictionary containing the tensors. + inputs (dict[str, torch.Tensor]): Input dictionary containing the text and modality samples + In case of multiple modalities, the 'input_ids' key contain the token ids for + the text corresponding to all the modalities stacked together. Thus the length (batch size) + of 'input_ids' will be equal to the sum of the lengths of the individual modality + samples. Returns: - dict[str, torch.Tensor]: Output dictionary. + dict[str, torch.Tensor]: Output dictionary containing + - cls token(s) for the modality or modalities + - text cls token(s) corresponding to the modality sample(s) + - logits from the text decoder + - logit_scale """ - vision_embd, vision_cls_token = self._forward_encode_vision(inputs) - text_embd, text_cls_token = self._forward_encode_text(inputs) - logits = self._forward_decode(text_embd, vision_embd) - return { - self.prediction_key: logits, - self.vision_cls_prediction_key: vision_cls_token, - self.text_cls_prediction_key: text_cls_token, - } + output = {} - def _forward_encode_vision(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: - """ - Encodes the input image using the vision encoder. + # encode modalities + image_embd = audio_embd = video_embd = None + if self.image_sample_key: + image_embd, image_cls_token = self._forward_encode_image(inputs) + output[self.image_cls_prediction_key] = image_cls_token - Args: - inputs (dict[str, torch.Tensor]): Dictionary containing vision inputs. 
+ if self.audio_sample_key: + audio_embd, audio_cls_token = self._forward_encode_audio(inputs) + output[self.audio_cls_prediction_key] = audio_cls_token - Returns: - tuple[torch.Tensor, torch.Tensor]: Tuple containing encoded vision embeddings and classification token. - """ - vision_embd = self.vision_encoder(inputs)[self.vision_embd_prediction_key] - queries = repeat(self.vision_queries, "n d -> b n d", b=vision_embd.shape[0]) - vision_embd = self.attn_pool(queries, context=vision_embd) - vision_embd, vision_cls_token = vision_embd[:, :-1, :], vision_embd[:, -1:, :] - return vision_embd, vision_cls_token + if self.video_sample_key: + video_embd, video_cls_token = self._forward_encode_video(inputs) + output[self.video_cls_prediction_key] = video_cls_token - def _forward_encode_text(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: - """ - Encodes the input text using the text decoder. + # encode text + text_embd, text_cls_token = self._forward_encode_text(inputs) - Args: - inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. + # decode modality + text + if self.individual_datasets: # multiple modalities (from different datasets) + start = 0 + modality_logits = [] + # this ensures that we select the text input_ids corresponding to each modality_key in the order + # they are stacked by the collator + for modality_key in self.modality_keys: + if modality_key == "images" and image_embd is not None: + image_text_cls_token = text_cls_token[start : start + len(image_embd)] + image_text_embd = text_embd[start : start + len(image_embd)] + image_logits = self._forward_decode(image_text_embd, image_embd) + output.update({self.image_text_cls_prediction_key: image_text_cls_token}) + modality_logits.append(image_logits) + start = start + len(image_embd) + if modality_key == "audio" and audio_embd is not None: + audio_text_cls_token = text_cls_token[start : start + len(audio_embd)] + audio_text_embd = text_embd[start : start + len(audio_embd)] + audio_logits = self._forward_decode(audio_text_embd, audio_embd) + output.update({self.audio_text_cls_prediction_key: audio_text_cls_token}) + modality_logits.append(audio_logits) + start = start + len(audio_embd) + if modality_key == "video" and video_embd is not None: + video_text_cls_token = text_cls_token[start : start + len(video_embd)] + video_text_embd = text_embd[start : start + len(video_embd)] + video_logits = self._forward_decode(video_text_embd, video_embd) + output.update({self.video_text_cls_prediction_key: video_text_cls_token}) + modality_logits.append(video_logits) + start = start + len(video_embd) + logits = torch.cat(modality_logits) + elif audio_embd is not None and video_embd is not None: # video dataset that contains audio + modality_embd = {"audio": audio_embd, "video": video_embd} + logits = self._forward_decode(text_embd, modality_embd) + output.update({self.text_cls_prediction_key: text_cls_token}) + else: # single modality + output.update({self.text_cls_prediction_key: text_cls_token}) + if image_embd is not None: + logits = self._forward_decode(text_embd, image_embd) + elif audio_embd is not None: + logits = self._forward_decode(text_embd, audio_embd) + elif video_embd is not None: + logits = self._forward_decode(text_embd, video_embd) + + output.update( + { + self.prediction_key: logits, + self.logit_scale_prediction_key: self.logit_scale.exp(), + } + ) + return output + + def _forward_encode_image(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + # returns a 
tuple containing the image embeddings and cls token + image_embd = self.image_encoder(inputs)[self.image_embd_prediction_key] + queries = repeat(self.image_queries, "n d -> b n d", b=image_embd.shape[0]) + image_embd = self.image_attn_pool(queries, context=image_embd) + image_embd, image_cls_token = image_embd[:, :-1, :], F.normalize(image_embd[:, -1, :], dim=-1) + return image_embd, image_cls_token + + def _forward_encode_video(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + # returns a tuple containing the video embeddings and cls token + video_embd = self.video_encoder(inputs)[self.video_embd_prediction_key] + queries = repeat(self.video_queries, "n d -> b n d", b=video_embd.shape[0]) + video_embd = self.video_attn_pool(queries, context=video_embd) + video_embd, video_cls_token = video_embd[:, :-1, :], F.normalize(video_embd[:, -1, :], dim=-1) + return video_embd, video_cls_token + + def _forward_encode_audio(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + # returns a tuple containing the audio embeddings and cls token + audio_embd = self.audio_encoder(inputs)[self.audio_embd_prediction_key] + queries = repeat(self.audio_queries, "n d -> b n d", b=audio_embd.shape[0]) + audio_embd = self.audio_attn_pool(queries, context=audio_embd) + audio_embd, audio_cls_token = audio_embd[:, :-1, :], F.normalize(audio_embd[:, -1, :], dim=-1) + return audio_embd, audio_cls_token - Returns: - tuple[torch.Tensor, torch.Tensor]: A tuple containing the encoded text tensor - and the classification token tensor. - """ + def _forward_encode_text(self, inputs: dict[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + # returns a tuple containing the encoded text tensor and the cls token text_embd = self.text_decoder(inputs)[self.text_embd_prediction_key] - text_embd, text_cls_token = text_embd[:, :-1, :], text_embd[:, -1:, :] + text_embd, text_cls_token = text_embd[:, :-1, :], F.normalize(text_embd[:, -1, :], dim=-1) return text_embd, text_cls_token - def _forward_decode(self, text_embd: torch.Tensor, vision_embd: torch.Tensor) -> torch.Tensor: - """ - Perform forward decoding using the given text and vision embeddings. - - Args: - text_embd (torch.Tensor): The text embeddings. - vision_embd (torch.Tensor): The vision embeddings. - - Returns: - torch.Tensor: The logits obtained from the multimodal decoder. - """ + def _forward_decode( + self, text_embd: torch.Tensor, modality_embd: list[torch.Tensor] | torch.Tensor + ) -> torch.Tensor: + # forward decode given the text and modality embedding(s) decoder_inputs = { self.text_embd_prediction_key: text_embd, - "context": vision_embd, + "context": modality_embd, } decoder_outputs = self.multimodal_decoder(decoder_inputs) logits = decoder_outputs[self.multimodal_decoder.prediction_key] diff --git a/src/modalities/models/coca/collator.py b/src/modalities/models/coca/collator.py index 437db1ece..c476c6831 100644 --- a/src/modalities/models/coca/collator.py +++ b/src/modalities/models/coca/collator.py @@ -71,14 +71,55 @@ def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: Raises: None. """ - samples = { - sample_key: torch.stack([torch.tensor(d[sample_key]) for d in batch]) for sample_key in self.sample_keys - } - targets = { - target_key: torch.stack([torch.tensor(d[target_key]) for d in batch]) for target_key in self.target_keys - } - + # only keys related to the other modalities (e.g. 
images, audio, video) + modality_keys = [key for key in self.sample_keys if key not in ["audio_len", self.text_sample_key]] + + samples = {sample_key: [] for sample_key in self.sample_keys if sample_key != self.text_sample_key} + text_samples = {sample_key: [] for sample_key in modality_keys} + attention_masks = {sample_key: [] for sample_key in modality_keys} + # gather samples by modality + for sample in batch: + text_sample_added = False # make sure text is only added once per sample + for sample_key in self.sample_keys: + if sample_key in sample: + if sample_key in samples: + samples[sample_key].append(self._prepare_sample(sample[sample_key])) + if "attention_mask" in sample and sample_key in attention_masks and not text_sample_added: + attention_masks[sample_key].append(self._prepare_sample(sample["attention_mask"])) + if sample_key in text_samples and not text_sample_added: + text_samples[sample_key].append(self._prepare_sample(sample[self.text_sample_key])) + text_sample_added = True + # remove keys with no samples + for sample_key in modality_keys: + if len(text_samples[sample_key]) == 0: + del text_samples[sample_key] + if len(attention_masks[sample_key]) == 0: + del attention_masks[sample_key] + # stack samples by modality + for sample_key in self.sample_keys: + if sample_key in samples: + samples[sample_key] = torch.stack(samples[sample_key]) + if sample_key in text_samples: + text_samples[sample_key] = torch.stack(text_samples[sample_key]) + if sample_key in attention_masks: + attention_masks[sample_key] = torch.stack(attention_masks[sample_key]) + # stack input_ids and attention masks for all modalities + samples[self.text_sample_key] = torch.cat([text_samples[sample_key] for sample_key in text_samples]) + samples["attention_mask"] = torch.cat([attention_masks[sample_key] for sample_key in attention_masks]) + + targets = {} # Create target for text input targets[self.text_target_key] = samples[self.text_sample_key][:, 1:].clone().detach() - samples[self.text_sample_key] = samples[self.text_sample_key][:, :-1].clone().detach() + samples[self.text_sample_key] = samples[self.text_sample_key][:, :-1] + + if "attention_mask" in batch[0]: + targets["attention_mask"] = samples["attention_mask"][:, 1:].clone().detach() + samples["attention_mask"] = samples["attention_mask"][:, :-1] + return DatasetBatch(targets=targets, samples=samples) + + @staticmethod + def _prepare_sample(x): + if isinstance(x, torch.Tensor): + return x + return torch.tensor(x) diff --git a/src/modalities/models/coca/multi_modal_decoder.py b/src/modalities/models/coca/multi_modal_decoder.py index 6c6165233..aca2cea21 100644 --- a/src/modalities/models/coca/multi_modal_decoder.py +++ b/src/modalities/models/coca/multi_modal_decoder.py @@ -22,6 +22,7 @@ def __init__( dropout: float, ffn_hidden: int, with_context: bool, + is_audio_video: bool, attention_type: AttentionType, attention_config: AttentionConfig = None, add_extra_mlp: bool = False, @@ -38,6 +39,8 @@ def __init__( dropout (float): The dropout rate. ffn_hidden (int): The number of hidden units in the feed-forward network. with_context (bool): Flag indicating whether to include context in the decoder. + is_audio_video (bool): Flag indicating whether an additional cross attention block is required for + data that consists of both audio and video from the same source. attention_type (AttentionType): The type of attention mechanism to use. attention_config (AttentionConfig, optional): The configuration for the attention mechanism. Defaults to None. 
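Read together with the collator test further below, this is a minimal sketch of how the updated CoCaCollatorFn groups a mixed image/audio batch; caption length, vocabulary size and feature shapes are illustrative, not prescribed by the PR.

import torch
from modalities.models.coca.collator import CoCaCollatorFn

collate_fn = CoCaCollatorFn(
    sample_keys=["images", "audio", "audio_len", "input_ids"],
    target_keys=[],
    text_sample_key="input_ids",
    text_target_key="target_ids",
)

def caption():
    # dummy tokenized caption plus text attention mask (lengths are illustrative)
    return torch.randint(0, 50304, (100,)), torch.randint(0, 2, (100,))

ids, mask = caption()
image_sample = {"images": torch.randn(3, 224, 224), "input_ids": ids, "attention_mask": mask}
ids, mask = caption()
audio_sample = {"audio": torch.randn(2000, 128), "audio_len": torch.tensor([500]),
                "input_ids": ids, "attention_mask": mask}

batch = collate_fn([image_sample, audio_sample, image_sample])
# samples are grouped per modality; input_ids are concatenated in the modality order of sample_keys
print(batch.samples["images"].shape)      # torch.Size([2, 3, 224, 224])
print(batch.samples["audio"].shape)       # torch.Size([1, 2000, 128])
print(batch.samples["input_ids"].shape)   # torch.Size([3, 99])  -> shifted by one token
print(batch.targets["target_ids"].shape)  # torch.Size([3, 99])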
@@ -45,6 +48,7 @@ def __init__( """ super().__init__() self.with_context = with_context + self.is_audio_video = is_audio_video self.add_extra_mlp = add_extra_mlp if activation == ActivationType.GELU: @@ -75,13 +79,22 @@ def __init__( self.ln_4 = nn.LayerNorm(normalized_shape=n_embd, bias=bias, eps=epsilon) self.mlp_2 = mlp() - def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor: + if self.is_audio_video: + self.cross_attn2 = MultiHeadAttention( + n_embd=n_embd, + n_head=n_head, + bias=bias, + attention_config=attention_config, + attention_type=AttentionType.CROSS_ATTENTION, + ) + + def forward(self, x: torch.Tensor, context: list[torch.Tensor] | torch.Tensor | None = None) -> torch.Tensor: """ Forward pass of the TransformerBlock module. Args: x (torch.Tensor): Input tensor. - context (torch.Tensor, optional): Context tensor. Defaults to None. + context (list[torch.Tensor] | torch.Tensor, optional): Context tensor. Defaults to None. Returns: torch.Tensor: Output tensor. @@ -90,8 +103,13 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor if not self.with_context or self.add_extra_mlp: x = x + self.mlp(self.ln_2(x)) if self.with_context: - x = x + self.cross_attn(self.ln_3(x), context=context) - x = x + self.mlp_2(self.ln_4(x)) + if isinstance(context, dict): + x = self.ln_3(x) + x = x + self.cross_attn(x, context=context["audio"]) + self.cross_attn2(x, context=context["video"]) + x = x + self.mlp_2(self.ln_4(x)) + else: + x = x + self.cross_attn(self.ln_3(x), context=context) + x = x + self.mlp_2(self.ln_4(x)) return x @@ -108,6 +126,7 @@ def __init__( n_head: int, n_embd: int, ffn_hidden: int, + is_audio_video: bool, dropout: float, bias: bool, activation: ActivationType, @@ -126,6 +145,8 @@ def __init__( n_head (int): The number of attention heads. n_embd (int): The dimension of the embeddings. ffn_hidden (int): The size of the feed-forward network hidden layer. + is_audio_video (bool): Flag indicating whether an additional cross attention block is required for + data that consists of both audio and video from the same source. dropout (float): The dropout rate. bias (bool): Flag indicating whether to include bias terms. activation (ActivationType): The activation function to use. @@ -153,6 +174,7 @@ def __init__( dropout=dropout, ffn_hidden=ffn_hidden, with_context=True, + is_audio_video=is_audio_video, attention_type=AttentionType.CAUSAL_SELF_ATTENTION, attention_config=attention_config, add_extra_mlp=False, diff --git a/src/modalities/models/coca/text_decoder.py b/src/modalities/models/coca/text_decoder.py index 39204d18b..e6b15c7ff 100644 --- a/src/modalities/models/coca/text_decoder.py +++ b/src/modalities/models/coca/text_decoder.py @@ -66,6 +66,7 @@ def __init__( dropout=dropout, ffn_hidden=ffn_hidden, with_context=False, + is_audio_video=False, attention_type=AttentionType.CAUSAL_SELF_ATTENTION, attention_config=attention_config, ) diff --git a/src/modalities/models/vision_transformer/vision_transformer_model.py b/src/modalities/models/vision_transformer/vision_transformer_model.py index 0b504bd39..3345a28a4 100644 --- a/src/modalities/models/vision_transformer/vision_transformer_model.py +++ b/src/modalities/models/vision_transformer/vision_transformer_model.py @@ -24,12 +24,15 @@ class VisionTransformerConfig(BaseModel): attention_config (AttentionConfig, optional): The configuration for the attention mechanism. Defaults to None. n_head (int): The number of attention heads. Defaults to 8. 
n_embd (int): The dimensionality of the embedding. Defaults to 768. + ffn_hidden (int): The number of hidden units in the feed-forward network. Defaults to 3072. dropout (float): The dropout rate. Defaults to 0.0. patch_size (int): The size of the image patches. Defaults to 16. patch_stride (int): The stride of the image patches. Defaults to 16. n_img_channels (int): The number of image channels. Defaults to 3. add_cls_token (bool): Flag indicating whether to add a classification token. Defaults to True. bias (bool): Flag indicating whether to include bias terms. Defaults to True. + num_video_frames (int): the number of video frames in case of video input + n_latents: the number of latent queries used for the Perceiver block in case of video input. Defaults to 64. """ sample_key: str @@ -40,12 +43,15 @@ class VisionTransformerConfig(BaseModel): attention_config: AttentionConfig = None n_head: Annotated[int, Field(ge=1)] = 8 n_embd: Annotated[int, Field(ge=1)] = 768 + ffn_hidden: Annotated[int, Field(ge=1)] = 3072 dropout: Annotated[float, Field(ge=0.0)] = 0.0 patch_size: Annotated[int, Field(ge=1)] = 16 patch_stride: Annotated[int, Field(ge=1)] = 16 n_img_channels: Annotated[int, Field(ge=1)] = 3 add_cls_token: bool = True bias: bool = True + num_video_frames: Annotated[int, Field(ge=0)] = 1 + n_latents: Annotated[int, Field(ge=1)] = 64 class ImagePatchEmbedding(nn.Module): @@ -108,6 +114,56 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class VideoPatchEmbedding(nn.Module): + def __init__( + self, + n_img_channels: int = 3, + n_embd: int = 768, + patch_size: int = 16, + patch_stride: int = 16, + ) -> None: + """ + Initializes a VideoPatchEmbedding object. + + + Args: + n_img_channels (int): Number of image channels. Defaults to 3. + n_embd (int): Number of embedding dimensions. Defaults to 768. + patch_size (int): Patch size for convolutional layer. Defaults to 16. + patch_stride (int): Patch stride for convolutional layer. Defaults to 16. + + Returns: + None + """ + super().__init__() + self.input_rearrange = Rearrange("b T c h w -> b c T h w") + self.conv = nn.Conv3d( + in_channels=n_img_channels, + out_channels=n_embd, + kernel_size=(1, patch_size, patch_size), + stride=(1, patch_size, patch_stride), + ) # TODO: check the 3D conv again + + # See https://github.com/arogozhnikov/einops/wiki/Using-torch.compile-with-einops + self.rearrange = Rearrange("b c T h w -> b T (h w) c") # TODO: this might change when implementing dataloader + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the VideoPatchEmbedding. + + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ + x = self.input_rearrange(x) + x = self.conv(x) + x = self.rearrange(x) + return x # [b T S D] + + class VisionTransformerBlock(nn.Module): """VisionTransformerBlock class.""" @@ -161,6 +217,69 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +class PerceiverTransformerBlock(nn.Module): + """Perceiver Resampler + + This is a transformer based architecture that performs cross and self attention to compress and embed video + or other high-dimensional inputs. + paper: 'Flamingo: a Visual Language Model for Few-Shot Learning' + Link: https://github.com/mlfoundations/open_flamingo + """ + + def __init__( + self, + n_embd: int = 768, + n_head: int = 8, + ffn_hidden: int = 3072, + bias: bool = True, + dropout: float = 0.0, + attention_config: AttentionConfig = None, + ) -> None: + """ + Initializes a PerceiverTransformerBlock object. 
+ + Args: + n_embd (int, optional): The dimensionality of the embedding layer. Defaults to 768. + n_head (int, optional): The number of attention heads. Defaults to 8. + ffn_hidden (int, optional): The number of hidden units in the feed-forward network. Defaults to 3072. + bias (bool, optional): Flag indicating whether to include bias terms. Defaults to True. + dropout (float, optional): The dropout rate. Defaults to 0.0. + attention_config (AttentionConfig, optional): The configuration for the attention mechanism. + Defaults to None. + + Returns: + None + """ + super().__init__() + self.norm_latents = nn.LayerNorm(n_embd) + self.norm = nn.LayerNorm(n_embd) + self.attention = MultiHeadAttention( + n_embd=n_embd, + n_head=n_head, + attention_config=attention_config, + attention_type=AttentionType.CROSS_ATTENTION, + ) + self.mlp = MLP(in_features=n_embd, hidden_features=ffn_hidden, bias=bias, dropout=dropout) + + def forward(self, x: torch.Tensor, latents: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the PerceiverTransformerBlock module. + + Args: + x (torch.Tensor): Input tensor. + latents (torch.Tensor): input latent array tensor + + Returns: + torch.Tensor: Output tensor. + """ + latents = self.norm_latents(latents) + x = self.norm(x) + context = torch.cat((x, latents), dim=-2) # video features and the latent together + latents = latents + self.attention(latents, context=context) + latents = latents + self.mlp(latents) + return latents + + class VisionTransformer(nn.Module): """ VisionTransformer class. @@ -170,6 +289,8 @@ class VisionTransformer(nn.Module): Paper: `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` Link: https://arxiv.org/abs/2010.11929 + + This architecture is extended to encode videos using a Perceiver transformer model """ def __init__( @@ -189,6 +310,8 @@ def __init__( n_img_channels: int = 3, add_cls_token: bool = True, bias: bool = True, + num_video_frames: int = 1, # 1: Image, >1: Video + n_latents: int = 64, ) -> None: """ Initializes the VisionTransformer object. @@ -209,22 +332,44 @@ def __init__( n_img_channels (int, optional): The number of image channels. Defaults to 3. add_cls_token (bool, optional): Flag indicating whether to add a classification token. Defaults to True. bias (bool, optional): Flag indicating whether to include bias terms. Defaults to True. + num_video_frames (int): Number of frames. Defaults to 1. + n_latents (int, optional): Size of latent array. Defaults to 64. 
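A quick shape check of the resampler block above; the attention engine name is taken from the yaml configs in this PR, and the exact keyword arguments are assumed to match the signature shown here.

import torch
from modalities.models.vision_transformer.vision_transformer_model import PerceiverTransformerBlock
from modalities.nn.attention import AttentionConfig

block = PerceiverTransformerBlock(
    n_embd=768, n_head=8, ffn_hidden=3072,
    attention_config=AttentionConfig(attention_engine_type="pytorch_flash_attention"),
)
frame_tokens = torch.randn(2, 16 * 196, 768)  # flattened (T*S) video patch tokens
latents = torch.randn(2, 64, 768)             # latent queries, repeated per batch element
compressed = block(frame_tokens, latents)
print(compressed.shape)  # torch.Size([2, 64, 768]) -> 16*196 patch tokens resampled to 64 latents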
Returns: None """ super().__init__() self.sample_key = sample_key + self.has_cls_token = add_cls_token self.prediction_key = prediction_key self.img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) self.block_size = self._calculate_block_size(self.img_size, patch_size, patch_stride, add_cls_token) - - self.embedding_fn = ImagePatchEmbedding(n_img_channels, n_embd, patch_size, patch_stride, add_cls_token) - self.positional_embedding_fn = nn.Embedding(num_embeddings=self.block_size, embedding_dim=n_embd) self.dropout = nn.Dropout(dropout) + + self.head = None + if n_classes is not None: + self.norm = nn.LayerNorm(n_embd) + self.head = nn.Linear(in_features=n_embd, out_features=n_classes, bias=bias) + + self.vision_input = "Image" + if num_video_frames > 1: # video data + self.vision_input = "Video" + self.embedding_fn = VideoPatchEmbedding(n_img_channels, n_embd, patch_size, patch_stride) # [b T S D] + self.time_embd = nn.Parameter(torch.randn(num_video_frames, 1, n_embd)) # [T,1,d] + if add_cls_token: + n_latents += 1 # to count for a video level cls token + self.block_size -= 1 + self.latents = nn.Parameter(torch.randn(n_latents, n_embd)) # [R,d] + self.rearrange = Rearrange("b T S D -> b (T S) D") + else: + self.embedding_fn = ImagePatchEmbedding(n_img_channels, n_embd, patch_size, patch_stride, add_cls_token) + + self.positional_embedding_fn = nn.Embedding(num_embeddings=self.block_size, embedding_dim=n_embd) # [S D] + block_classes = {"Video": PerceiverTransformerBlock, "Image": VisionTransformerBlock} + self.blocks = nn.ModuleList( [ - VisionTransformerBlock( + block_classes[self.vision_input]( n_embd=n_embd, n_head=n_head, ffn_hidden=ffn_hidden, @@ -236,11 +381,6 @@ def __init__( ] ) - self.head = None - if n_classes is not None: - self.norm = nn.LayerNorm(n_embd) - self.head = nn.Linear(in_features=n_embd, out_features=n_classes, bias=bias) - def forward_images(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for processing images using the VisionTransformer module. @@ -257,6 +397,31 @@ def forward_images(self, x: torch.Tensor) -> torch.Tensor: x = block(x) return x + def forward_videos(self, x: torch.Tensor) -> torch.Tensor: + """Encode video data into a shorter sequence of tokens + + Args: + x (torch.Tensor): images from multiple video frames + shape (b c T h w) + b: batch size + T: temporal dim + h,w: spatial dims (S=h*w) + c: embedding dim (D) + + Returns: + torch.Tensor: latents + shape (b R D) R << T*S + """ + x = self.embedding_fn(x) # [b T S D] + b, T = x.shape[:2] + x = self.dropout(x + self.positional_embedding_fn.weight) + x = self.dropout(x + self.time_embd.repeat(b, 1, 1, 1)) + x = self.rearrange(x) # [b T*S D] + latents = self.latents.repeat(b, 1, 1) # [b,R,d] with R< dict[str, torch.Tensor]: """ Forward pass of the VisionTransformer module. 
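For reference, a hedged end-to-end sketch of the new video path. Keyword names mirror VisionTransformerConfig above; n_layer is reduced relative to the reference configs to keep the sketch light, and the output key is assumed to follow prediction_key as in the image path.

import torch
from modalities.models.vision_transformer.vision_transformer_model import VisionTransformer
from modalities.nn.attention import AttentionConfig

video_encoder = VisionTransformer(
    sample_key="video",
    prediction_key="video_embeddings",
    img_size=224,
    n_classes=None,        # no classification head, we want the token sequence
    n_layer=2,             # reference configs use 6
    attention_config=AttentionConfig(attention_engine_type="pytorch_flash_attention"),
    n_head=8,
    n_embd=768,
    ffn_hidden=3072,
    dropout=0.0,
    patch_size=16,
    patch_stride=16,
    n_img_channels=3,
    add_cls_token=False,
    bias=True,
    num_video_frames=16,   # >1 switches to VideoPatchEmbedding + Perceiver blocks
    n_latents=64,
)
clip = torch.randn(1, 16, 3, 224, 224)  # (b, T, c, h, w)
out = video_encoder({"video": clip})
print(out["video_embeddings"].shape)    # torch.Size([1, 64, 768])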
@@ -269,9 +434,12 @@ def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ x = inputs[self.sample_key] - x = self.forward_images(x) + if self.vision_input == "Video": + x = self.forward_videos(x) + else: + x = self.forward_images(x) if self.head: - if self.embedding_fn.cls_token is not None: + if self.has_cls_token: x = x[:, 0] else: x = x.mean(dim=1) diff --git a/src/modalities/nn/attention.py b/src/modalities/nn/attention.py index 789602a1b..35ac8f45b 100644 --- a/src/modalities/nn/attention.py +++ b/src/modalities/nn/attention.py @@ -59,25 +59,42 @@ def __init__( ) self.resid_dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() - def forward(self, x: Tensor, context: Optional[Tensor] = None) -> Tensor: + def forward(self, x: Tensor, context: Optional[Tensor] = None, mask: Tensor = None) -> Tensor: context = context if self.use_cross_attention else x B, T, C = x.shape # batch size, sequence length, embedding dimensionality (n_embd) q, k, v = self._forward_input_projection(x, context=context) if self.use_flash: - y = F.scaled_dot_product_attention( - query=q, - key=k, - value=v, - attn_mask=None, - dropout_p=self.dropout if self.training else 0, - is_causal=self.is_causal, + y = ( + self._flash_with_mask(query=q, key=k, value=v, mask=mask) + if mask is not None + else self._flash_without_mask(query=q, key=k, value=v) ) else: - y = self._forward_attention(query=q, key=k, value=v) + y = self._forward_attention(query=q, key=k, value=v, mask=mask) y = y.transpose(1, 2).contiguous().view(B, T, C) y = self.resid_dropout(self.c_proj(y)) return y + def _flash_with_mask(self, query: Tensor, key: Tensor, value: Tensor, mask: Tensor) -> Tensor: + return F.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=(mask == 0).logical_not(), + dropout_p=self.dropout if self.training else 0, + is_causal=self.is_causal, + ) + + def _flash_without_mask(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor: + return F.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=None, + dropout_p=self.dropout if self.training else 0, + is_causal=self.is_causal, + ) + def _forward_input_projection(self, x: Tensor, context: Tensor) -> tuple[Tensor, Tensor, Tensor]: B, T, C = x.shape # batch size, sequence length, embedding dimensionality (n_embd) _, Tc, Cc = context.shape # batch size, context length, context embedding dimensionality @@ -88,11 +105,13 @@ def _forward_input_projection(self, x: Tensor, context: Tensor) -> tuple[Tensor, v = self.wv(context).view(B, Tc, self.n_head, Cc // self.n_head).transpose(1, 2) return q, k, v - def _forward_attention(self, query: Tensor, key: Tensor, value: Tensor) -> Tensor: + def _forward_attention(self, query: Tensor, key: Tensor, value: Tensor, mask: Tensor) -> Tensor: att = (query @ key.transpose(-2, -1)) * (1.0 / math.sqrt(key.size(-1))) if self.is_causal: T = query.size(2) att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float("-inf")) + if mask is not None: + att = att.masked_fill(mask == 0, float("-inf")) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) return att @ value diff --git a/src/modalities/nn/model_initialization/parameter_name_filters.py b/src/modalities/nn/model_initialization/parameter_name_filters.py index ff4edede0..c15ac99bd 100644 --- a/src/modalities/nn/model_initialization/parameter_name_filters.py +++ b/src/modalities/nn/model_initialization/parameter_name_filters.py @@ -67,10 +67,27 @@ class RegexFilter(BaseModel): }, 
SupportWeightInitModels.COCA: { # we reject all bias and weight parameters belonging to norms + # optional .weight so that we include nn.Parameters WeightInitTypes.PLAIN: RegexFilter( - weights=[r"^(?!.*norm)(?!.*ln_).*\.weight$"], biases=[r"^(?!.*norm)(?!.*ln_).*\.bias$"] + weights=[r"^(?!.*norm)(?!.*ln)(?!.*batch_norm).*(.weight)?$"], + biases=[r"^(?!.*norm)(?!.*ln)(?!.*batch_norm).*.bias$"], + ), + # scaled init for residual layers: + # https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf (pp 4) + WeightInitTypes.SCALED: RegexFilter( + weights=[ + r"transformer\.h\.\d+\.attn\.c_proj\.weight", + ] + ), + WeightInitTypes.SCALED_EMBED: RegexFilter( + weights=[ + # embedding weights + r"\.wte\.weight", + r"\.wpe\.weight", + r"positional_embeddings\.weight", + r"positional_embedding_fn\.weight", + r"time_embd$", + ] ), - WeightInitTypes.SCALED: RegexFilter(weights=[], biases=[]), - WeightInitTypes.SCALED_EMBED: RegexFilter(weights=[], biases=[]), }, } diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 5e4ef2378..6f07036c3 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -22,6 +22,7 @@ CheckpointedModelConfig, CheckpointedOptimizerConfig, CheckpointSavingConfig, + ClipLossConfig, CLMCrossEntropyLossConfig, ConstantLRSchedulerConfig, CosineAnnealingLRSchedulerConfig, @@ -35,6 +36,8 @@ GPT2LLMCollateFnConfig, LLMDataLoaderConfig, MemMapDatasetConfig, + MultipleFunctionsLossConfig, + NCELossConfig, OneCycleLRSchedulerConfig, PackedMemMapDatasetContinuousConfig, PackedMemMapDatasetMegatronConfig, @@ -45,19 +48,35 @@ RichResultSubscriberConfig, SaveEveryKStepsCheckpointingStrategyConfig, SaveKMostRecentCheckpointsStrategyConfig, + SimpleProgressSubscriberConfig, StepLRSchedulerConfig, TorchCheckpointLoadingConfig, WandBEvaluationResultSubscriberConfig, + WebDataLoaderConfig, WeightInitializedModelConfig, ) from modalities.dataloader.dataloader_factory import DataloaderFactory -from modalities.dataloader.dataset import DummyDatasetConfig +from modalities.dataloader.dataset import ( + AudioTransform, + AudioTransformConfig, + DummyDatasetConfig, + ImageTransform, + ImageTransformConfig, + MultimodalWebDataset, + MultimodalWebDatasetBuilder, + MultimodalWebDatasetBuilderConfig, + MultimodalWebDatasetConfig, + TextTransform, + TextTransformConfig, + VideoTransform, + VideoTransformConfig, +) from modalities.dataloader.dataset_factory import DatasetFactory from modalities.logging_broker.subscriber_impl.subscriber_factory import ( ProgressSubscriberFactory, ResultsSubscriberFactory, ) -from modalities.loss_functions import CLMCrossEntropyLoss +from modalities.loss_functions import ClipLoss, CLMCrossEntropyLoss, MultipleFunctionsLoss, NCELoss from modalities.models.coca.coca_model import CoCa, CoCaConfig from modalities.models.coca.collator import CoCaCollateFnConfig, CoCaCollatorFn from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig @@ -142,6 +161,9 @@ class ComponentEntity: ), # losses ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig), + ComponentEntity("loss", "nce_loss", NCELoss, NCELossConfig), + ComponentEntity("loss", "clip_loss", ClipLoss, ClipLossConfig), + ComponentEntity("loss", "multiple_functions_loss", MultipleFunctionsLoss, MultipleFunctionsLossConfig), # optmizers ComponentEntity("optimizer", "adam", OptimizerFactory.get_adam, AdamOptimizerConfig), 
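To make the comment about the optional ".weight" suffix concrete, a hedged illustration of what the relaxed PLAIN pattern accepts; the parameter names are made up for the example.

import re

plain_weights = re.compile(r"^(?!.*norm)(?!.*ln)(?!.*batch_norm).*(.weight)?$")

for name in [
    "image_queries",                              # bare nn.Parameter, no ".weight" suffix
    "logit_scale",                                # bare nn.Parameter
    "image_encoder.blocks.0.attn.c_proj.weight",  # regular module weight
    "image_encoder.blocks.0.ln_1.weight",         # rejected: contains "ln"
    "multimodal_decoder.norm.weight",             # rejected: contains "norm"
]:
    print(name, bool(plain_weights.match(name)))
# -> True, True, True, False, False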
ComponentEntity("optimizer", "adam_w", OptimizerFactory.get_adam_w, AdamWOptimizerConfig), @@ -175,6 +197,13 @@ class ComponentEntity: PackedMemMapDatasetMegatronConfig, ), ComponentEntity("dataset", "dummy_dataset", DatasetFactory.get_dummy_dataset, DummyDatasetConfig), + ComponentEntity("dataset", "web_dataset", MultimodalWebDataset, MultimodalWebDatasetConfig), + ComponentEntity("dataset", "web_dataset_builder", MultimodalWebDatasetBuilder, MultimodalWebDatasetBuilderConfig), + # Data transforms & augmentations + ComponentEntity("transform", "text_transform", TextTransform, TextTransformConfig), + ComponentEntity("transform", "image_transform", ImageTransform, ImageTransformConfig), + ComponentEntity("transform", "audio_transform", AudioTransform, AudioTransformConfig), + ComponentEntity("transform", "video_transform", VideoTransform, VideoTransformConfig), # samplers ComponentEntity("sampler", "distributed_sampler", DistributedSampler, DistributedSamplerConfig), # batch samplers @@ -184,6 +213,7 @@ class ComponentEntity: ComponentEntity("collate_fn", "coca_collator", CoCaCollatorFn, CoCaCollateFnConfig), # data loaders ComponentEntity("data_loader", "default", DataloaderFactory.get_dataloader, LLMDataLoaderConfig), + ComponentEntity("data_loader", "web_dataloader", DataloaderFactory.get_web_dataloader, WebDataLoaderConfig), ComponentEntity( "data_loader", "repeating_data_loader", DataloaderFactory.get_repeating_dataloader, RepeatingDataLoaderConfig ), @@ -214,6 +244,12 @@ class ComponentEntity: ProgressSubscriberFactory.get_dummy_progress_subscriber, DummyProgressSubscriberConfig, ), + ComponentEntity( + "progress_subscriber", + "simple", + ProgressSubscriberFactory.get_simple_progress_subscriber, + SimpleProgressSubscriberConfig, + ), ComponentEntity( "progress_subscriber", "rich", diff --git a/src/modalities/running_env/cuda_env.py b/src/modalities/running_env/cuda_env.py index 611ff51a8..6d60b849c 100644 --- a/src/modalities/running_env/cuda_env.py +++ b/src/modalities/running_env/cuda_env.py @@ -21,7 +21,7 @@ def __init__( """ self.process_group_backend = process_group_backend # TODO we might want to set this from outside via the config - self.local_rank = int(os.getenv("LOCAL_RANK", "0")) + self.local_rank = int(os.environ["LOCAL_RANK"]) def __enter__(self) -> "CudaEnv": """Sets the CUDA environment for distributed training. diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index 195d050a9..c1f542b99 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -12,7 +12,7 @@ from modalities.dataloader.dataloader import LLMDataLoader from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate from modalities.logging_broker.publisher import MessagePublisher -from modalities.loss_functions import Loss +from modalities.loss_functions import Loss, MultipleFunctionsLoss from modalities.models.model import model_predict_batch from modalities.running_env.fsdp.reducer import Reducer from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF @@ -259,6 +259,25 @@ def train( "train loss last": ResultItem(train_loss_last_batch, decimal_places=2), } + # If there are multiple loss functions being used, + # this block computes and logs all the individual + # losses, averaged over the global batch size. 
+ if isinstance(loss_fun, MultipleFunctionsLoss): + global_batch_size = Reducer.reduce( + tensor=cumulated_losses[-1], operation=dist.ReduceOp.SUM, post_processing_fun=None + ) + reduced_individual_losses = Reducer.reduce( + tensor=loss_fun.cumulated_individual_losses, + operation=dist.ReduceOp.SUM, + post_processing_fun=lambda t: torch.stack( + [t[ind] / global_batch_size for ind in range(len(t))] + ), + ) + for ind, (loss, _) in enumerate(loss_fun.groups): + losses[f"train {loss.tag} avg"] = ResultItem(reduced_individual_losses[ind], decimal_places=2) + + loss_fun.reset_cumulated_individual_losses() + consumed_tokens = torch.Tensor([training_progress.num_seen_tokens_total]) metrics = { "consumed tokens": ResultItem(consumed_tokens, 0), diff --git a/src/modalities/util.py b/src/modalities/util.py index 1bbe3ff4a..bda55b873 100644 --- a/src/modalities/util.py +++ b/src/modalities/util.py @@ -181,6 +181,28 @@ def get_all_reduced_value( return value +def flatten_dict(d, parent_key="", sep="_"): + """ + Flatten a nested dictionary. + + Args: + d: The dictionary to flatten. + parent_key: The base key to use for concatenation. + sep: The separator to use between concatenated keys. + + Return: + A flattened dictionary with concatenated keys. + """ + items = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(flatten_dict(v, new_key, sep=sep).items()) + else: + items.append((new_key, v)) + return dict(items) + + def get_module_class_from_name(module: torch.nn.Module, name: str) -> Type[torch.nn.Module] | None: """From Accelerate source code (https://github.com/huggingface/accelerate/blob/1f7a79b428749f45187ec69485f2c966fe21926e/src/accelerate/utils/dataclasses.py#L1902) diff --git a/tests/dataloader/distributed/test_distributed_dataloader.py b/tests/dataloader/distributed/test_distributed_dataloader.py index 0038d04a6..0d2b0b098 100644 --- a/tests/dataloader/distributed/test_distributed_dataloader.py +++ b/tests/dataloader/distributed/test_distributed_dataloader.py @@ -9,7 +9,7 @@ from modalities.__main__ import Main from modalities.config.config import ProcessGroupBackendType -from modalities.config.pydanctic_if_types import PydanticLLMDataLoaderIFType +from modalities.config.pydanctic_if_types import PydanticDataLoaderIFType from modalities.running_env.cuda_env import CudaEnv from tests.dataloader.dummy_sequential_dataset import TestDataset, TestDatasetConfig @@ -18,7 +18,7 @@ class DataloaderInstantiationModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType + train_dataloader: PydanticDataLoaderIFType @pytest.mark.skipif( diff --git a/tests/dataloader/distributed/test_distributed_repeating_dataloader.py b/tests/dataloader/distributed/test_distributed_repeating_dataloader.py index 7f40cc974..418793a43 100644 --- a/tests/dataloader/distributed/test_distributed_repeating_dataloader.py +++ b/tests/dataloader/distributed/test_distributed_repeating_dataloader.py @@ -8,7 +8,7 @@ from pydantic import BaseModel from modalities.__main__ import Main -from modalities.config.config import ProcessGroupBackendType, PydanticLLMDataLoaderIFType +from modalities.config.config import ProcessGroupBackendType, PydanticDataLoaderIFType from modalities.running_env.cuda_env import CudaEnv from tests.dataloader.dummy_sequential_dataset import TestDataset, TestDatasetConfig @@ -17,7 +17,7 @@ class DataloaderInstantiationModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType + train_dataloader: PydanticDataLoaderIFType 
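The flatten_dict helper above concatenates nested keys with the separator; a quick usage example (keys are illustrative):

from modalities.util import flatten_dict

nested = {"train": {"loss": 0.7, "clip": {"image": 0.2}}, "step": 10}
print(flatten_dict(nested))
# {'train_loss': 0.7, 'train_clip_image': 0.2, 'step': 10}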
@pytest.mark.skipif( diff --git a/tests/dataloader/test_dataloader.py b/tests/dataloader/test_dataloader.py index 65139ce6a..9d6f171a9 100644 --- a/tests/dataloader/test_dataloader.py +++ b/tests/dataloader/test_dataloader.py @@ -10,7 +10,7 @@ from modalities.config.component_factory import ComponentFactory from modalities.config.config import load_app_config_dict -from modalities.config.pydanctic_if_types import PydanticLLMDataLoaderIFType +from modalities.config.pydanctic_if_types import PydanticDataLoaderIFType from modalities.dataloader.dataloader import LLMDataLoader, RepeatingDataLoader from modalities.dataloader.dataset import Dataset from modalities.dataloader.samplers import ResumableBatchSampler @@ -49,7 +49,7 @@ def test_dataloader_from_config(dummy_config: dict): dummy_config["train_dataloader"]["config"]["skip_num_batches"] = start_index class DataloaderTestModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType + train_dataloader: PydanticDataLoaderIFType registry = Registry(COMPONENTS) component_factory = ComponentFactory(registry=registry) @@ -167,7 +167,7 @@ def test_repeating_dataloader_with_shuffling(): def test_skipped_and_distributed_dataloader_from_config(): class DataloaderTestModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType + train_dataloader: PydanticDataLoaderIFType skip_num_batches: int root_dir = Path(__file__).parents[0] @@ -244,7 +244,7 @@ class DataloaderTestModel(BaseModel): ) def test_dataloader_with_fixed_num_batches(global_rank): class DataloaderTestModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType + train_dataloader: PydanticDataLoaderIFType fixed_num_batches: int class IdentityCollateFn(CollateFnIF): diff --git a/tests/dataloader/test_webdataset.py b/tests/dataloader/test_webdataset.py new file mode 100644 index 000000000..3ed7c7762 --- /dev/null +++ b/tests/dataloader/test_webdataset.py @@ -0,0 +1,140 @@ +import io +import tarfile +from pathlib import Path + +import numpy as np +import pytest +import torch +import torchaudio +import webdataset as wds +from pydantic import BaseModel + +from modalities.__main__ import load_app_config_dict +from modalities.config.component_factory import ComponentFactory +from modalities.config.pydanctic_if_types import PydanticDataLoaderIFType +from modalities.registry.components import COMPONENTS +from modalities.registry.registry import Registry +from tests.conftest import _ROOT_DIR + + +def create_image_sample(): + img = np.random.randint(0, 255, size=(224, 224, 3)).astype(np.uint8) + img = wds.writer.imageencoder(img, format="JPG") + text = {"text0": "this is an image caption %d" % np.random.randint(10)} + return img, text + + +@pytest.fixture(scope="session") +def image_tar_path(tmp_path_factory): + data_path = str(tmp_path_factory.mktemp("data") / "images.tar") + dataset_sink = wds.TarWriter(data_path) + # 10 image samples + for idx in range(10): + img, text = create_image_sample() + dataset_sink.write( + { + "__key__": "%02d" % idx, + "jpg": img, + "json": text, + } + ) + dataset_sink.close() + return data_path + + +def create_audio_sample(): + sample_rate = 16000 + audio = torch.from_numpy(np.random.uniform(-1, 1, sample_rate)).unsqueeze(0) + audio_buf = io.BytesIO() + torchaudio.save(audio_buf, audio, sample_rate, format="wav") + audio_buf.seek(0) + text = "this is an audio caption %d" % np.random.randint(10) + text_f = io.BytesIO() + text_f.write(text.encode("utf-8")) + text_f.seek(0) + return audio_buf, text_f + + +@pytest.fixture(scope="session") +def 
audio_tar_path(tmp_path_factory): + data_path = str(tmp_path_factory.mktemp("data") / "audio.tar") + with tarfile.open(data_path, "w") as fp: + # 25 audio samples + for idx in range(25): + key = "%02d" % idx + wav, text = create_audio_sample() + info = tarfile.TarInfo(key + ".wav") + info.size = wav.getbuffer().nbytes + fp.addfile(info, wav) + info = tarfile.TarInfo(key + ".transcript.txt") + info.size = text.getbuffer().nbytes + fp.addfile(info, text) + return data_path + + +@pytest.mark.parametrize( + "mixing_ratios,resample,batch_size", + [ + ([0.9, 0.1], False, 10), # we run out of image samples after the second batch + ([0.9, 0.1], True, 10), # since we resample, there are enough samples for >2 batches + ([0.7, 0.3], False, 20), # the first batch won't have 0.7*20 samples + ([0.3, 0.6], False, 10), # ratios don't add up to 1 + ([0.8, 0.2], True, 100), + ], +) +def test_web_dataloader(image_tar_path, audio_tar_path, mixing_ratios, resample, batch_size): + class DataloaderTestModel(BaseModel): + train_dataloader: PydanticDataLoaderIFType + + config_file_path = _ROOT_DIR / Path("tests/dataloader/yaml_configs/web_dataloader.yaml") + config_dict = load_app_config_dict(config_file_path=config_file_path) + config_dict["image_dataset"]["config"]["urls"] = image_tar_path + config_dict["audio_dataset"]["config"]["urls"] = audio_tar_path + config_dict["train_dataset"]["config"]["mixing_ratios"] = mixing_ratios + config_dict["train_dataset"]["config"]["resample"] = resample + config_dict["train_dataset"]["config"]["batch_size"] = batch_size + config_dict["train_dataloader"]["config"]["batch_size"] = batch_size + registry = Registry(COMPONENTS) + component_factory = ComponentFactory(registry=registry) + components = component_factory.build_components(config_dict=config_dict, components_model_type=DataloaderTestModel) + + expected_images = int(mixing_ratios[0] * batch_size) + expected_audio = int(mixing_ratios[1] * batch_size) + # if ratios don't add up to 1, extra samples are added to first modality + remaining = batch_size - (expected_audio + expected_images) + expected_images += remaining + + loader = iter(components.train_dataloader) + + # image, audio + total_samples = [10, 25] + seen_samples = [0, 0] + + for idx in range(5): + batch_expected_images = expected_images + batch_expected_audio = expected_audio + try: + batch = next(loader) + except StopIteration: + break + + if not resample: + # if resample is False, the last batch may have less + # samples than expected if one of the modalities + # runs out of samples + if total_samples[0] - seen_samples[0] < expected_images: + expected_images - (total_samples[0] - seen_samples[0]) + batch_expected_images = total_samples[0] - seen_samples[0] + if total_samples[1] - seen_samples[1] < expected_audio: + expected_audio - (total_samples[1] - seen_samples[1]) + batch_expected_audio = total_samples[1] - seen_samples[1] + + assert batch.samples["images"].shape[0] == batch_expected_images + seen_samples[0] += batch.samples["images"].shape[0] + assert batch.samples["audio"].shape[0] == batch_expected_audio + seen_samples[1] += batch.samples["audio"].shape[0] + assert batch.samples["input_ids"].shape[0] == batch_expected_audio + batch_expected_images + for idx in range(2): + # reset if the complete dataset has been seen already + if seen_samples[idx] == total_samples[idx]: + seen_samples[idx] = 0 diff --git a/tests/dataloader/yaml_configs/web_dataloader.yaml b/tests/dataloader/yaml_configs/web_dataloader.yaml new file mode 100644 index 000000000..843d3d801 
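The per-modality batch sizes asserted in the test above boil down to this arithmetic; per the test's own comment, samples left over from integer truncation go to the first builder.

def expected_counts(mixing_ratios: list[float], batch_size: int) -> list[int]:
    counts = [int(r * batch_size) for r in mixing_ratios]
    counts[0] += batch_size - sum(counts)  # leftover samples go to the first modality
    return counts

print(expected_counts([0.9, 0.1], 10))   # [9, 1]
print(expected_counts([0.3, 0.6], 10))   # [4, 6]  -> ratios need not sum to 1
print(expected_counts([0.8, 0.2], 100))  # [80, 20]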
--- /dev/null +++ b/tests/dataloader/yaml_configs/web_dataloader.yaml @@ -0,0 +1,111 @@ +tokenizer: + component_key: tokenizer + variant_key: pretrained_hf_tokenizer + config: + pretrained_model_name_or_path: openai/clip-vit-base-patch32 + padding: true + max_length: 50 + +train_image_transform: + component_key: transform + variant_key: image_transform + config: + is_training: True + input_size: 224 + +train_audio_transform: + component_key: transform + variant_key: audio_transform + config: + is_training: True + block_size_audio_encoder: 500 + freq_domain_mask_length: 30 + time_domain_mask_length: 100 + +text_transform: + component_key: transform + variant_key: text_transform + config: + tokenizer: + instance_key: tokenizer + pass_type: BY_REFERENCE + +collate_fn: + component_key: collate_fn + variant_key: coca_collator + config: + sample_keys: + - images + - audio + - audio_len + - input_ids + target_keys: [] + text_sample_key: input_ids + text_target_key: logits + +image_dataset: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: None + modality_key_mapping: + TEXT: ["json_text0", "input_ids"] + IMAGE: ["jpg", "images"] + modality_transforms: + IMAGE: + instance_key: train_image_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 10 + +audio_dataset: + component_key: dataset + variant_key: web_dataset_builder + config: + urls: None + modality_key_mapping: + TEXT: ["transcript.txt", "input_ids"] # source and target keys + AUDIO: ["wav", "audio"] + modality_transforms: + AUDIO: + instance_key: train_audio_transform + pass_type: BY_REFERENCE + TEXT: + instance_key: text_transform + pass_type: BY_REFERENCE + num_samples: 10 + + +train_dataset: + component_key: dataset + variant_key: web_dataset + config: + builders: + - instance_key: image_dataset + pass_type: BY_REFERENCE + - instance_key: audio_dataset + pass_type: BY_REFERENCE + mixing_ratios: [0.9, 0.1] + batch_size: 10 + shardshuffle: 100 + repeat: false + resample: false + shuffle_buffer: 10_000 + +train_dataloader: + component_key: data_loader + variant_key: web_dataloader + config: + num_workers: 0 + pin_memory: true + drop_last: true + dataloader_tag: "train" + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_size: 10 + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE diff --git a/tests/end2end_tests/test_fsdp_warmstart.py b/tests/end2end_tests/test_fsdp_warmstart.py index 3261eb4b4..dac0b402c 100644 --- a/tests/end2end_tests/test_fsdp_warmstart.py +++ b/tests/end2end_tests/test_fsdp_warmstart.py @@ -11,7 +11,7 @@ from modalities.__main__ import Main, load_app_config_dict from modalities.batch import EvaluationResultBatch -from modalities.config.config import ProcessGroupBackendType, PydanticLLMDataLoaderIFType +from modalities.config.config import ProcessGroupBackendType, PydanticDataLoaderIFType from modalities.config.instantiation_models import TrainingComponentsInstantiationModel from modalities.dataloader.dataloader import LLMDataLoader from modalities.logging_broker.messages import Message @@ -46,7 +46,7 @@ class SaveAllResultSubscriberConfig(BaseModel): class TrainDataloaderInstantiationModel(BaseModel): - train_dataloader: PydanticLLMDataLoaderIFType + train_dataloader: PydanticDataLoaderIFType @pytest.mark.skipif( diff --git a/tests/models/audio_transformer/__init__.py b/tests/models/audio_transformer/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git 
a/tests/models/audio_transformer/test_audio_transformer_model.py b/tests/models/audio_transformer/test_audio_transformer_model.py new file mode 100644 index 000000000..09d3d672a --- /dev/null +++ b/tests/models/audio_transformer/test_audio_transformer_model.py @@ -0,0 +1,122 @@ +import pytest +import torch + +from modalities.models.audio_transformer.audio_transformer_model import ( + AudioTransformer, + ConformerBlock, + ConvolutionModule, +) +from modalities.nn.attention import AttentionConfig + + +@pytest.fixture +def params() -> dict: + return { + "sample_key": "audio", + "prediction_key": "audio_embeddings", + "block_size": 5, + "n_mels": 1, + "n_conformer_blocks": 1, + "n_embd": 1, + "n_heads": 1, + "attention_config": AttentionConfig(attention_engine_type="pytorch_flash_attention"), + "pointwise_conv_kernel_size": 1, + "depthwise_conv_kernel_size": 1, + "dropout": 0.1, + } + + +@pytest.fixture +def audio_transformer_model(params) -> AudioTransformer: + return AudioTransformer( + sample_key=params["sample_key"], + prediction_key=params["prediction_key"], + block_size=params["block_size"], + n_mels=params["n_mels"], + n_conformer_blocks=params["n_conformer_blocks"], + n_embd=params["n_embd"], + n_heads=params["n_heads"], + attention_config=params["attention_config"], + pointwise_conv_kernel_size=params["pointwise_conv_kernel_size"], + depthwise_conv_kernel_size=params["depthwise_conv_kernel_size"], + ffmodule_dropout=params["dropout"], + attn_dropout=params["dropout"], + convmodule_dropout=params["dropout"], + ) + + +@pytest.fixture +def invalid_forward_input() -> torch.Tensor: + return torch.randn((1, 1, 256)) + + +@pytest.fixture +def forward_input() -> dict[str, torch.Tensor]: + return {"x": torch.randn((1, 2, 1)), "mask": torch.ones((1, 2))} + + +def test_convolution_module_forward_return_shape( + params, + forward_input, +): + convolution = ConvolutionModule( + params["n_embd"], + params["pointwise_conv_kernel_size"], + params["depthwise_conv_kernel_size"], + params["dropout"], + ) + + out = convolution(forward_input["x"]) + + assert out.shape == (1, 2, 1) + + +def test_convolution_module_forward_raise( + params, + invalid_forward_input, +): + convolution = ConvolutionModule( + params["n_embd"], + params["pointwise_conv_kernel_size"], + params["depthwise_conv_kernel_size"], + params["dropout"], + ) + + with pytest.raises(ValueError, match="The time dimension of the input to the convolution module cannot be 1!"): + convolution(invalid_forward_input) + + +def test_conformer_forward(params, forward_input): + conformer = ConformerBlock( + params["n_embd"], + params["n_heads"], + params["attention_config"], + params["pointwise_conv_kernel_size"], + params["depthwise_conv_kernel_size"], + params["dropout"], + params["dropout"], + params["dropout"], + ) + + conformer(forward_input["x"], forward_input["mask"]) + + +def test_audio_transformer__get_attn_key_mask(audio_transformer_model): + lengths = torch.tensor([3]) + + CORRECT_MASK = torch.Tensor( + [ + [ + [ + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0], + [1, 1, 1, 0, 0], + [1, 0, 0, 0, 0], + [1, 0, 0, 0, 0], + ] + ] + ] + ) + + CREATED_MASK = audio_transformer_model._get_attn_key_mask(lengths) + assert torch.equal(CORRECT_MASK, CREATED_MASK) diff --git a/tests/models/coca/coca_config_aud_vid.yaml b/tests/models/coca/coca_config_aud_vid.yaml new file mode 100644 index 000000000..d6b14d305 --- /dev/null +++ b/tests/models/coca/coca_config_aud_vid.yaml @@ -0,0 +1,68 @@ +prediction_key: logits +audio_embd_prediction_key: audio_embeddings 
+video_embd_prediction_key: video_embeddings +text_embd_prediction_key: text_embeddings +audio_cls_prediction_key: audio_cls +audio_text_cls_prediction_key: audio_text_cls +video_cls_prediction_key: video_cls +video_text_cls_prediction_key: video_text_cls +text_cls_prediction_key: text_cls +modality_keys: + - images + - audio + - audio_len + - video + - input_ids +is_audio_video: true +individual_datasets: false +logit_scale_prediction_key: logit_scale +audio_encoder_config: + sample_key: audio + prediction_key: audio_embeddings + block_size: 500 + n_mels: 128 + n_embd: 768 + n_heads: 4 + n_conformer_blocks: 3 + attention_config: + attention_engine_type: default_attention + pointwise_conv_kernel_size: 1 + depthwise_conv_kernel_size: 31 +video_encoder_config: + sample_key: video + prediction_key: video_embeddings + img_size: 224 + n_classes: Null + n_layer: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 8 + n_embd: 768 + dropout: 0.0 + patch_size: 16 + patch_stride: 16 + n_img_channels: 3 + add_cls_token: False + bias: True + num_video_frames: 16 + n_latents: 64 +text_decoder_config: + sample_key: input_ids + prediction_key: text_embeddings + block_size: 1024 + vocab_size: 50304 + n_layer_text: 6 + n_layer_multimodal_text: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: swiglu + epsilon: 1e-5 +n_pool_head: 12 +n_queries: 256 +bias_attn_pool: False +epsilon_attn_pool: 1e-5 diff --git a/tests/models/coca/coca_config_audio.yaml b/tests/models/coca/coca_config_audio.yaml new file mode 100644 index 000000000..0b28b8b5e --- /dev/null +++ b/tests/models/coca/coca_config_audio.yaml @@ -0,0 +1,44 @@ +prediction_key: logits +audio_embd_prediction_key: audio_embeddings +text_embd_prediction_key: text_embeddings +audio_cls_prediction_key: audio_cls +text_cls_prediction_key: text_cls +modality_keys: + - audio + - audio_len + - input_ids +is_audio_video: false +individual_datasets: false +logit_scale_prediction_key: logit_scale +audio_encoder_config: + sample_key: audio + prediction_key: audio_embeddings + block_size: 500 + n_mels: 128 + n_embd: 768 + n_heads: 4 + n_conformer_blocks: 3 + attention_config: + attention_engine_type: default_attention + pointwise_conv_kernel_size: 1 + depthwise_conv_kernel_size: 31 +text_decoder_config: + sample_key: input_ids + prediction_key: text_embeddings + block_size: 1024 + vocab_size: 50304 + n_layer_text: 6 + n_layer_multimodal_text: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: swiglu + epsilon: 1e-5 +n_pool_head: 8 +n_queries: 256 +bias_attn_pool: False +epsilon_attn_pool: 1e-5 diff --git a/tests/models/coca/coca_config.yaml b/tests/models/coca/coca_config_image.yaml similarity index 72% rename from tests/models/coca/coca_config.yaml rename to tests/models/coca/coca_config_image.yaml index 2cc2e9195..21d6318f4 100644 --- a/tests/models/coca/coca_config.yaml +++ b/tests/models/coca/coca_config_image.yaml @@ -1,11 +1,17 @@ prediction_key: logits -vision_embd_prediction_key: vision_embeddings +image_embd_prediction_key: image_embeddings text_embd_prediction_key: text_embeddings -vision_cls_prediction_key: vision_cls +image_cls_prediction_key: image_cls text_cls_prediction_key: text_cls -vision_encoder_config: +modality_keys: + - images + - input_ids +is_audio_video: false +individual_datasets: false 
+logit_scale_prediction_key: logit_scale +image_encoder_config: sample_key: images - prediction_key: vision_embeddings + prediction_key: image_embeddings img_size: 224 n_classes: Null # Disable vision transformer head n_layer: 6 @@ -36,6 +42,6 @@ text_decoder_config: activation: swiglu epsilon: 1e-5 n_pool_head: 8 -n_vision_queries: 256 +n_queries: 256 bias_attn_pool: False -epsilon_attn_pool: 1e-5 \ No newline at end of file +epsilon_attn_pool: 1e-5 diff --git a/tests/models/coca/coca_config_img_aud_vid.yaml b/tests/models/coca/coca_config_img_aud_vid.yaml new file mode 100644 index 000000000..bcb2b9d51 --- /dev/null +++ b/tests/models/coca/coca_config_img_aud_vid.yaml @@ -0,0 +1,87 @@ +prediction_key: logits +audio_embd_prediction_key: audio_embeddings +image_embd_prediction_key: image_embeddings +video_embd_prediction_key: video_embeddings +text_embd_prediction_key: text_embeddings +image_cls_prediction_key: image_cls +image_text_cls_prediction_key: image_text_cls +audio_cls_prediction_key: audio_cls +audio_text_cls_prediction_key: audio_text_cls +video_cls_prediction_key: video_cls +video_text_cls_prediction_key: video_text_cls +text_cls_prediction_key: text_cls +modality_keys: + - images + - audio + - audio_len + - video + - input_ids +is_audio_video: false +individual_datasets: true +logit_scale_prediction_key: logit_scale +audio_encoder_config: + sample_key: audio + prediction_key: audio_embeddings + block_size: 500 + n_mels: 128 + n_embd: 768 + n_heads: 4 + n_conformer_blocks: 3 + attention_config: + attention_engine_type: default_attention + pointwise_conv_kernel_size: 1 + depthwise_conv_kernel_size: 31 +image_encoder_config: + sample_key: images + prediction_key: image_embeddings + img_size: 224 + n_classes: Null # Disable vision transformer head + n_layer: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 8 + n_embd: 768 + dropout: 0.0 + patch_size: 16 + patch_stride: 16 + n_img_channels: 3 + add_cls_token: False + bias: True +video_encoder_config: + sample_key: video + prediction_key: video_embeddings + img_size: 224 + n_classes: Null + n_layer: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 8 + n_embd: 768 + dropout: 0.0 + patch_size: 16 + patch_stride: 16 + n_img_channels: 3 + add_cls_token: False + bias: True + num_video_frames: 16 + n_latents: 64 +text_decoder_config: + sample_key: input_ids + prediction_key: text_embeddings + block_size: 1024 + vocab_size: 50304 + n_layer_text: 6 + n_layer_multimodal_text: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: swiglu + epsilon: 1e-5 +n_pool_head: 12 +n_queries: 256 +bias_attn_pool: False +epsilon_attn_pool: 1e-5 diff --git a/tests/models/coca/coca_config_video.yaml b/tests/models/coca/coca_config_video.yaml new file mode 100644 index 000000000..aa2b45576 --- /dev/null +++ b/tests/models/coca/coca_config_video.yaml @@ -0,0 +1,49 @@ +prediction_key: logits +video_embd_prediction_key: video_embeddings +text_embd_prediction_key: text_embeddings +video_cls_prediction_key: video_cls +text_cls_prediction_key: text_cls +modality_keys: + - video + - input_ids +is_audio_video: false +individual_datasets: false +logit_scale_prediction_key: logit_scale +video_encoder_config: + sample_key: video + prediction_key: video_embeddings + img_size: 224 + n_classes: Null # Disable vision transformer head + n_layer: 6 + attention_config: + attention_engine_type: 
pytorch_flash_attention + n_head: 8 + n_embd: 768 + dropout: 0.0 + patch_size: 16 + patch_stride: 16 + n_img_channels: 3 + add_cls_token: False + bias: True + num_video_frames: 16 + n_latents: 64 +text_decoder_config: + sample_key: input_ids + prediction_key: text_embeddings + block_size: 1024 + vocab_size: 50304 + n_layer_text: 6 + n_layer_multimodal_text: 6 + attention_config: + attention_engine_type: pytorch_flash_attention + n_head: 12 + ffn_hidden: 2048 + n_embd: 768 + dropout: 0.0 + bias: true + activation: swiglu + epsilon: 1e-5 +n_pool_head: 8 +n_queries: 256 +bias_attn_pool: False +epsilon_attn_pool: 1e-5 diff --git a/tests/models/coca/test_coca.py b/tests/models/coca/test_coca.py index b5eb3f683..5d2d523b8 100644 --- a/tests/models/coca/test_coca.py +++ b/tests/models/coca/test_coca.py @@ -10,38 +10,140 @@ from modalities.running_env.cuda_env import CudaEnv from tests.conftest import _ROOT_DIR +# shared config +N_EMBD = 768 -def test_coca(): +# text_decoder_config +TEXT_DECODER_VOCAB_SIZE = 50_304 +TEXT_DECODER_BLOCK_SIZE = 1_024 + +# vision_transformer_config +N_IMAGE_CLASSES = 1_000 +IMG_SIZE = 224 +N_IMG_CHANNELS = 3 +N_FRAMES = 16 + +# audio_transformer_config +AUDIO_BLOCK_SIZE = 500 +N_MELS = 128 +SUB_SAMPLING_FACTOR = 4 + +BATCH_SIZE = 2 + + +def dummy_image_sample(): + input_image = torch.randn(BATCH_SIZE, N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (BATCH_SIZE, TEXT_DECODER_BLOCK_SIZE)) + return dict( + images=input_image, + input_ids=input_text, + ) + + +def dummy_video_sample(): + input_video = torch.randn(BATCH_SIZE, N_FRAMES, N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (BATCH_SIZE, TEXT_DECODER_BLOCK_SIZE)) + return dict( + video=input_video, + input_ids=input_text, + ) + + +def dummy_audio_sample(): + audio_features = torch.randn(BATCH_SIZE, AUDIO_BLOCK_SIZE * SUB_SAMPLING_FACTOR, N_MELS) + audio_len = torch.tensor([N_IMAGE_CLASSES / SUB_SAMPLING_FACTOR]).type(torch.int16) + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (BATCH_SIZE, TEXT_DECODER_BLOCK_SIZE)) + return dict( + audio=audio_features, + audio_len=audio_len, + input_ids=input_text, + ) + + +def dummy_img_aud_vid_sample(): + # separate image, audio, and video datasets + input_image = torch.randn(BATCH_SIZE, N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) + audio_features = torch.randn(BATCH_SIZE, AUDIO_BLOCK_SIZE * SUB_SAMPLING_FACTOR, N_MELS) + audio_len = torch.tensor([N_IMAGE_CLASSES / SUB_SAMPLING_FACTOR]).type(torch.int16) + input_video = torch.randn(BATCH_SIZE, N_FRAMES, N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) + + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (BATCH_SIZE * 3, TEXT_DECODER_BLOCK_SIZE)) + return dict( + images=input_image, + audio=audio_features, + audio_len=audio_len, + video=input_video, + input_ids=input_text, + ) + + +def dummy_aud_vid_sample(): + # single video dataset which contains audio + audio_features = torch.randn(BATCH_SIZE, AUDIO_BLOCK_SIZE * SUB_SAMPLING_FACTOR, N_MELS) + audio_len = torch.tensor([N_IMAGE_CLASSES / SUB_SAMPLING_FACTOR]).type(torch.int16) + input_video = torch.randn(BATCH_SIZE, N_FRAMES, N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) + + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (BATCH_SIZE, TEXT_DECODER_BLOCK_SIZE)) + return dict( + audio=audio_features, + audio_len=audio_len, + video=input_video, + input_ids=input_text, + ) + + +@pytest.mark.parametrize( + "yaml,dummy_sample", + [ + ("tests/models/coca/coca_config_image.yaml", dummy_image_sample()), + 
("tests/models/coca/coca_config_audio.yaml", dummy_audio_sample()), + ("tests/models/coca/coca_config_video.yaml", dummy_video_sample()), + ("tests/models/coca/coca_config_img_aud_vid.yaml", dummy_img_aud_vid_sample()), + ("tests/models/coca/coca_config_aud_vid.yaml", dummy_aud_vid_sample()), + ], +) +def test_coca(yaml, dummy_sample): # Create model - config_file_path = _ROOT_DIR / Path("tests/models/coca/coca_config.yaml") + config_file_path = _ROOT_DIR / Path(yaml) config_dict = load_app_config_dict(config_file_path=config_file_path) coca_config = CoCaConfig.model_validate(config_dict) model = CoCa(**dict(coca_config)) - # Create dummy inputs - dummy_input_image = torch.randn(1, 3, 224, 224) - dummy_input_text = torch.randint( - 0, coca_config.text_decoder_config.vocab_size, (1, coca_config.text_decoder_config.block_size) - ) - dummy_input = dict(images=dummy_input_image, input_ids=dummy_input_text) - # Create optimizer optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # Run one training step optimizer.zero_grad() - out = model(dummy_input) + out = model(dummy_sample) loss = out["logits"].sum() loss.backward() optimizer.step() # Test outputs - assert "logits" in out - assert "vision_cls" in out - assert "text_cls" in out - assert out["logits"].shape == (1, 1024, 50304) - assert out["vision_cls"].shape == (1, 1, 768) - assert out["text_cls"].shape == (1, 1, 768) + text_output_batch_size = 0 + if coca_config.audio_encoder_config: + assert "audio_cls" in out + assert out["audio_cls"].shape == (BATCH_SIZE, N_EMBD) + if coca_config.individual_datasets: + assert out["audio_text_cls"].shape == (BATCH_SIZE, N_EMBD) + if not coca_config.is_audio_video: + text_output_batch_size += BATCH_SIZE + if coca_config.image_encoder_config: + assert "image_cls" in out + assert out["image_cls"].shape == (BATCH_SIZE, N_EMBD) + if coca_config.individual_datasets: + assert out["image_text_cls"].shape == (BATCH_SIZE, N_EMBD) + text_output_batch_size += BATCH_SIZE + if coca_config.video_encoder_config: + assert "video_cls" in out + assert out["video_cls"].shape == (BATCH_SIZE, N_EMBD) + if coca_config.individual_datasets: + assert out["video_text_cls"].shape == (BATCH_SIZE, N_EMBD) + text_output_batch_size += BATCH_SIZE + if not coca_config.individual_datasets: + assert out["text_cls"].shape == (BATCH_SIZE, N_EMBD) + assert out["logits"].shape == (text_output_batch_size, TEXT_DECODER_BLOCK_SIZE, TEXT_DECODER_VOCAB_SIZE) + assert "logit_scale" in out @pytest.mark.skip( diff --git a/tests/models/coca/test_collator.py b/tests/models/coca/test_collator.py new file mode 100644 index 000000000..af342c9d6 --- /dev/null +++ b/tests/models/coca/test_collator.py @@ -0,0 +1,142 @@ +import pytest +import torch + +from modalities.models.coca.collator import CoCaCollatorFn + +# shared config +N_EMBD = 768 + +# text_decoder_config +TEXT_DECODER_VOCAB_SIZE = 50_304 +TEXT_DECODER_BLOCK_SIZE = 100 + +# vision_transformer_config +N_IMAGE_CLASSES = 1_000 +IMG_SIZE = 224 +N_IMG_CHANNELS = 3 +N_FRAMES = 16 + +# audio_transformer_config +AUDIO_BLOCK_SIZE = 500 +N_MELS = 128 +SUB_SAMPLING_FACTOR = 4 + + +def dummy_image_sample(): + input_image = torch.randn(N_IMG_CHANNELS, IMG_SIZE, IMG_SIZE) + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (TEXT_DECODER_BLOCK_SIZE,)) + attn_mask = torch.randint(0, 2, (TEXT_DECODER_BLOCK_SIZE,)) + return dict( + images=input_image, + input_ids=input_text, + attention_mask=attn_mask, + ) + + +def dummy_video_sample(): + input_video = torch.randn(N_FRAMES, N_IMG_CHANNELS, 
IMG_SIZE, IMG_SIZE) + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (TEXT_DECODER_BLOCK_SIZE,)) + attn_mask = torch.randint(0, 2, (TEXT_DECODER_BLOCK_SIZE,)) + return dict( + video=input_video, + input_ids=input_text, + attention_mask=attn_mask, + ) + + +def dummy_audio_sample(): + audio_features = torch.randn(AUDIO_BLOCK_SIZE * SUB_SAMPLING_FACTOR, N_MELS) + audio_len = torch.tensor([N_IMAGE_CLASSES / SUB_SAMPLING_FACTOR]).type(torch.int16) + input_text = torch.randint(0, TEXT_DECODER_VOCAB_SIZE, (TEXT_DECODER_BLOCK_SIZE,)) + attn_mask = torch.randint(0, 2, (TEXT_DECODER_BLOCK_SIZE,)) + return dict( + audio=audio_features, + audio_len=audio_len, + input_ids=input_text, + attention_mask=attn_mask, + ) + + +@pytest.mark.parametrize( + "modality_sequence", + [ + ("iiiii"), + ("aaaaa"), + ("vvvvv"), + ("iiaav"), + ("iaiav"), + ("iviaa"), + ("iaiavaivaiiiiaaaviaa"), + ], +) +def test_collator(modality_sequence): + sample_keys = ["input_ids"] + target_keys = [] + text_sample_key = "input_ids" + text_target_key = "target_ids" + + num_image = modality_sequence.count("i") + num_audio = modality_sequence.count("a") + num_video = modality_sequence.count("v") + + # sample_keys in the order: images, audio, video + if num_image: + sample_keys.append("images") + if num_audio: + sample_keys.append("audio") + sample_keys.append("audio_len") + if num_video: + sample_keys.append("video") + + # create samples + image_samples = [] + for idx in range(num_image): + image_samples.append(dummy_image_sample()) + audio_samples = [] + for idx in range(num_audio): + audio_samples.append(dummy_audio_sample()) + + video_samples = [] + for idx in range(num_video): + video_samples.append(dummy_video_sample()) + + modality_samples = {"images": image_samples, "audio": audio_samples, "video": video_samples} + + collate_fn = CoCaCollatorFn(sample_keys, target_keys, text_sample_key, text_target_key) + + batch = [] + image_idx = 0 + video_idx = 0 + audio_idx = 0 + # create the batch according to the specified modality sequence + for ch in modality_sequence: + if ch == "i": + batch.append(image_samples[image_idx]) + image_idx += 1 + if ch == "a": + batch.append(audio_samples[audio_idx]) + audio_idx += 1 + if ch == "v": + batch.append(video_samples[video_idx]) + video_idx += 1 + + dataset_batch = collate_fn(batch) + + batch_idx = 0 + + # regardless of the order of the modality sequence, + # the batch (esp. input_ids and target_ids) should be in the same order as sample_keys + # i.e. 
batch.samples['input_ids'] = [*image input_ids, *audio_input_ids, *video_input_ids] + for modality_key in sample_keys: + if modality_key in ["audio_len", "input_ids"]: + continue + if modality_key in dataset_batch.samples: + for modality_idx, gt_sample in enumerate(modality_samples[modality_key]): + assert torch.equal(gt_sample[modality_key], dataset_batch.samples[modality_key][modality_idx]) + assert torch.equal(gt_sample["input_ids"][:-1], dataset_batch.samples[text_sample_key][batch_idx]) + assert torch.equal(gt_sample["input_ids"][1:], dataset_batch.targets[text_target_key][batch_idx]) + assert torch.equal(gt_sample["attention_mask"][:-1], dataset_batch.samples["attention_mask"][batch_idx]) + assert torch.equal(gt_sample["attention_mask"][1:], dataset_batch.targets["attention_mask"][batch_idx]) + if modality_key == "audio": + assert torch.equal(gt_sample["audio_len"], dataset_batch.samples["audio_len"][modality_idx]) + batch_idx += 1 diff --git a/tests/models/vision_transformer/test_vision_transformer.py b/tests/models/vision_transformer/test_vision_transformer.py index 24b03921a..be72a6de1 100644 --- a/tests/models/vision_transformer/test_vision_transformer.py +++ b/tests/models/vision_transformer/test_vision_transformer.py @@ -8,16 +8,33 @@ from tests.conftest import _ROOT_DIR -def test_vision_transformer(): +@pytest.mark.parametrize( + "input,sample_key,n_classes,num_video_frames,add_cls_token,output", + [ + (torch.randn(1, 3, 224, 224), "images", 1000, 1, True, (1, 1000)), + (torch.randn(1, 3, 224, 224), "images", None, 1, True, (1, 197, 768)), + (torch.randn(1, 3, 224, 224), "images", None, 1, False, (1, 196, 768)), + (torch.randn(1, 3, 224, 224), "images", 1000, 1, False, (1, 1000)), + (torch.randn(1, 16, 3, 224, 224), "videos", 1000, 16, True, (1, 1000)), + (torch.randn(1, 16, 3, 224, 224), "videos", None, 16, True, (1, 65, 768)), + (torch.randn(1, 16, 3, 224, 224), "videos", None, 16, False, (1, 64, 768)), + (torch.randn(1, 16, 3, 224, 224), "videos", 1000, 16, False, (1, 1000)), + ], +) +def test_vision_transformer(input, sample_key, n_classes, num_video_frames, add_cls_token, output): # Create model config_file_path = _ROOT_DIR / Path("tests/models/vision_transformer/vision_transformer_config.yaml") config_dict = load_app_config_dict(config_file_path=config_file_path) config = VisionTransformerConfig.model_validate(config_dict) + config.sample_key = sample_key + config.n_classes = n_classes + config.num_video_frames = num_video_frames + config.add_cls_token = add_cls_token + model = VisionTransformer(**dict(config)) # Create dummy inputs - dummy_input_image = torch.randn(1, 3, 224, 224) - dummy_input = dict(images=dummy_input_image) + dummy_input = {sample_key: input} # Create optimizer optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9) @@ -31,7 +48,7 @@ def test_vision_transformer(): # Test outputs assert "logits" in out - assert out["logits"].shape == (1, 1000) + assert out["logits"].shape == output @pytest.mark.parametrize( diff --git a/tests/models/vision_transformer/vision_transformer_config.yaml b/tests/models/vision_transformer/vision_transformer_config.yaml index d6657c5c1..507719791 100644 --- a/tests/models/vision_transformer/vision_transformer_config.yaml +++ b/tests/models/vision_transformer/vision_transformer_config.yaml @@ -11,3 +11,5 @@ patch_stride: 16 n_img_channels: 3 add_cls_token: True bias: True +num_video_frames: 1 +n_latents: 64 diff --git a/tests/test_initialization.py b/tests/test_initialization.py index a169ade87..8e38ba113 
100644 --- a/tests/test_initialization.py +++ b/tests/test_initialization.py @@ -82,7 +82,7 @@ def _load_coca(initialization_type: str, std: float | str) -> FSDP: coca_wrapped_model = ModelFactory.get_fsdp_wrapped_model( coca_model, sync_module_states=True, - block_names=["TransformerBlock", "VisionTransformerBlock"], + block_names=["TransformerBlock", "VisionTransformerBlock", "ConformerBlock"], mixed_precision_settings=MixedPrecisionSettings.FP_16, sharding_strategy=ShardingStrategy.NO_SHARD, ) @@ -111,9 +111,23 @@ def _load_model(model_name: str, initialization: str = "plain", std: float | str "other": [], } MAPPING_COCA = { - "embedding": [], # TODO + "embedding": [ + r"wte\.weight$", + r"wpe\.weight$", + r"positional_embeddings\.weight$", + "positional_embedding_fn.weight$", + "time_embd$", + ], "weight-projection": [r"c_proj\.weight$"], # TODO - "weight-norm": [r"norm[12]\.weight$", r"ln_[1234f]\.weight$"], # TODO + "weight-norm": [ + r"norm[12]?\.weight$", + r"norm_latents\.weight$", + r"ln_[1234f]\.weight$", + r"ln_mhsa.weight", + r"batch_norm.*weight$", + r"exit_ln.weight$", + r"attention_norm.weight$", + ], "weight-normal": [r"\.weight$"], "other": [r"conv", r".*(? dict[str, Optional[torch.T GPT2_WEIGHT_NORMAL = GPT2_ALL - GPT2_WEIGHT_PROJECTION - GPT2_EMBEDDING - GPT2_WEIGHT_NORM - GPT2_BIAS # 40107264 COCA_NLAYERS = 6 + 6 # text + multimodal -COCA_ALL = 184502784 -COCA_EMBEDDING = 0 # TODO -COCA_WEIGHT_PROJECTION = 14745600 -COCA_WEIGHT_NORM = 34560 -COCA_BIAS = 191232 -COCA_OTHER = 198912 -COCA_WEIGHT_NORMAL = 169332480 +COCA_ALL = 277424641 +COCA_EMBEDDING = 40118016 +COCA_WEIGHT_PROJECTION = 21233664 +COCA_WEIGHT_NORM = 768 * 79 +COCA_BIAS = 292608 +COCA_OTHER = 657409 +COCA_WEIGHT_NORMAL = 215062272 NR_PARAMETERS = { "gpt2": { diff --git a/tests/test_loss_functions.py b/tests/test_loss_functions.py index 8825f15c3..eccd05765 100644 --- a/tests/test_loss_functions.py +++ b/tests/test_loss_functions.py @@ -2,7 +2,7 @@ import torch from modalities.batch import InferenceResultBatch -from modalities.loss_functions import NCELoss, nce_loss +from modalities.loss_functions import ClipLoss, CLMCrossEntropyLoss, MultipleFunctionsLoss, NCELoss, nce_loss @pytest.fixture @@ -36,3 +36,101 @@ def test_nce_loss_correctness(embedding1, embedding2): bidirectional_loss = nce_loss(embedding1, embedding2, device="cpu", is_asymmetric=False, temperature=1.0) assert unidirectional_loss == pytest.approx(1.1300, 0.0001) assert bidirectional_loss == pytest.approx(2.2577, 0.0001) + + +@pytest.fixture +def clm_cross_entropy_loss_object() -> CLMCrossEntropyLoss: + return CLMCrossEntropyLoss(target_key="target_ids", prediction_key="logits") + + +@pytest.fixture +def clip_loss_object() -> ClipLoss: + return ClipLoss( + logit_scale_key="logit_scale", + prediction_keys=["image_cls", "image_text_cls"], + local_loss=False, + ) + + +@pytest.fixture +def clip_loss_forward_batch() -> InferenceResultBatch: + # BATCH SIZE, LENGTH OF SEQUENCE, EMBEDDING SIZE + predictions = { + "image_cls": torch.Tensor([[1, 2, 3], [4, 5, 6]]).to("cuda"), + "image_text_cls": torch.Tensor([[7, 8, 9], [10, 11, 12]]).to("cuda"), + "logit_scale": 0.07, + } + return InferenceResultBatch(targets={}, predictions=predictions) + + +@pytest.fixture +def setup_distributed(monkeypatch): + import torch.distributed as dist + + monkeypatch.setenv("RANK", "0") + monkeypatch.setenv("LOCAL_RANK", "0") + monkeypatch.setenv("WORLD_SIZE", "1") + monkeypatch.setenv("MASTER_ADDR", "localhost") + monkeypatch.setenv("MASTER_PORT", "9948") + + 
dist.init_process_group(backend="nccl") + yield + dist.destroy_process_group() + + +def test_clip_loss(clip_loss_object, clip_loss_forward_batch, setup_distributed): + loss_fn = clip_loss_object + loss_fn(clip_loss_forward_batch) + + +@pytest.fixture +def multiple_functions_loss_object_with_two_losses( + clm_cross_entropy_loss_object, clip_loss_object +) -> MultipleFunctionsLoss: + return MultipleFunctionsLoss( + [clm_cross_entropy_loss_object, clip_loss_object], + corrsp_weights=[1.0, 1.0], + ) + + +def test_multiple_functions_loss_initialized_with_single_loss( + clm_cross_entropy_loss_object, +): + with pytest.raises(ValueError, match="Number of losses used should be more than 1."): + MultipleFunctionsLoss([clm_cross_entropy_loss_object], corrsp_weights=[1.0]) + + +def test_multiple_functions_loss_reset_cumulated_individual_losses( + multiple_functions_loss_object_with_two_losses, +): + loss = multiple_functions_loss_object_with_two_losses + num_losses = len(loss.groups) + loss.cumulated_individual_losses = torch.randn(num_losses) + loss.reset_cumulated_individual_losses() + + assert torch.equal( + loss.cumulated_individual_losses, torch.zeros(num_losses, device=loss.cumulated_individual_losses.device) + ) + + +@pytest.fixture +def multiple_functions_loss_forward_batch() -> InferenceResultBatch: + targets = {"target_ids": torch.Tensor([[1, 2, 1], [1, 1, 2]])} + predictions = { + "image_cls": torch.Tensor([[1, 2, 3], [4, 5, 6]]).to("cuda"), + "image_text_cls": torch.Tensor([[7, 8, 9], [10, 11, 12]]).to("cuda"), + "logit_scale": 0.07, + "logits": torch.Tensor( + [[[0.1, 0.2, 0.7], [0.3, 0.2, 0.5], [0.0, 0.3, 0.7]], [[0.1, 0.2, 0.7], [0.3, 0.2, 0.5], [0.0, 0.3, 0.7]]] + ), + } + + return InferenceResultBatch(targets=targets, predictions=predictions) + + +def test_multiple_functions_loss( + multiple_functions_loss_object_with_two_losses, + multiple_functions_loss_forward_batch, + setup_distributed, +): + multiple_functions_loss_object_with_two_losses(multiple_functions_loss_forward_batch) diff --git a/tests/test_optimizer_factory.py b/tests/test_optimizer_factory.py index 4f273ad01..46ac8ee72 100644 --- a/tests/test_optimizer_factory.py +++ b/tests/test_optimizer_factory.py @@ -55,14 +55,14 @@ def _load_gpt2() -> FSDP: def _load_coca() -> FSDP: - config_file_path = _ROOT_DIR / Path("tests/models/coca/coca_config.yaml") + config_file_path = _ROOT_DIR / Path("tests/models/coca/coca_config_img_aud_vid.yaml") config_dict = load_app_config_dict(config_file_path=config_file_path) coca_config = CoCaConfig.model_validate(config_dict) coca_model = CoCa(**dict(coca_config)) coca_wrapped_model = ModelFactory.get_fsdp_wrapped_model( coca_model, sync_module_states=True, - block_names=["TransformerBlock", "VisionTransformerBlock"], + block_names=["TransformerBlock", "VisionTransformerBlock", "ConformerBlock"], mixed_precision_settings=MixedPrecisionSettings.FP_16, sharding_strategy=ShardingStrategy.NO_SHARD, ) @@ -73,7 +73,16 @@ def _load_coca() -> FSDP: GPT2_LINEAR = 66130944 GPT2_EMBEDDING = 768 * (50304 + 2048) # n_embd * (vocab_size + sequence_length) GPT2_LAYERNORM = 768 * 50 # n_embd * num_layer_norms -COCA_ALL = 184502784 + +COCA_LINEAR = 227321088 +COCA_CONV = 9226752 +# (n_embd * vocab_size) + +# (n_embd * (text_block_size + img_block_size + vid_block_size + aud_block_size + num_frames) +COCA_EMBEDDING = (768 * 50304) + (768 * ((1024 + 1) + 196 + 196 + 500 + 16)) +COCA_NORM = 768 * 152 # n_embd * norm layers +# 3 * (n_queries + 1) n_embd + logit_scale + (1(cls_token) * n_embd) + (n_latents * 
n_embd) +COCA_PARAMETER = (3 * 257 * 768) + 1 + (768) + (64 * 768) +COCA_ALL = COCA_LINEAR + COCA_CONV + COCA_EMBEDDING + COCA_NORM + COCA_PARAMETER @pytest.mark.skipif( @@ -92,6 +101,16 @@ def _load_coca() -> FSDP: ("gpt2", 1e-1, ["non-existing-group"], False, None, None), ("coca", 0, [], True, 0, COCA_ALL), ("coca", 1e-1, [], True, COCA_ALL, 0), + ("coca", 1e-1, ["embedding"], True, COCA_ALL - COCA_EMBEDDING, COCA_EMBEDDING), + ("coca", 1e-1, ["embedding", "norm"], True, COCA_ALL - COCA_EMBEDDING - COCA_NORM, COCA_EMBEDDING + COCA_NORM), + ( + "coca", + 1e-1, + ["embedding", "norm", "parameter"], + True, + COCA_LINEAR + COCA_CONV, + COCA_EMBEDDING + COCA_NORM + COCA_PARAMETER, + ), ("coca", 1e-1, ["non-existing-group"], False, None, None), ], ) diff --git a/tests/test_yaml_configs/coca_config_initialization.yaml b/tests/test_yaml_configs/coca_config_initialization.yaml index bda3fb253..42547001c 100644 --- a/tests/test_yaml_configs/coca_config_initialization.yaml +++ b/tests/test_yaml_configs/coca_config_initialization.yaml @@ -19,13 +19,39 @@ model_raw: variant_key: coca config: prediction_key: logits - vision_embd_prediction_key: vision_embeddings + audio_embd_prediction_key: audio_embeddings + image_embd_prediction_key: image_embeddings + video_embd_prediction_key: video_embeddings text_embd_prediction_key: text_embeddings - vision_cls_prediction_key: vision_cls + image_cls_prediction_key: image_cls + image_text_cls_prediction_key: image_text_cls + audio_cls_prediction_key: audio_cls + audio_text_cls_prediction_key: audio_text_cls + video_cls_prediction_key: video_cls + video_text_cls_prediction_key: video_text_cls text_cls_prediction_key: text_cls - vision_encoder_config: + modality_keys: + - audio + - images + - video + is_audio_video: false + individual_datasets: true + logit_scale_prediction_key: logit_scale + audio_encoder_config: + sample_key: audio + prediction_key: audio_embeddings + block_size: 500 + n_mels: 128 + n_embd: 768 + n_heads: 4 + n_conformer_blocks: 3 + attention_config: + attention_engine_type: default_attention + pointwise_conv_kernel_size: 1 + depthwise_conv_kernel_size: 31 + image_encoder_config: sample_key: images - prediction_key: vision_embeddings + prediction_key: image_embeddings img_size: 224 n_classes: Null # Disable vision transformer head n_layer: 6 @@ -39,6 +65,24 @@ model_raw: n_img_channels: 3 add_cls_token: False bias: True + video_encoder_config: + sample_key: video + prediction_key: video_embeddings + img_size: 224 # 288 in the original coca + n_classes: Null # Disable vision transformer head + n_layer: 6 + attention_config: + attention_engine_type: default_attention + n_head: 8 + n_embd: 768 + dropout: 0.0 + patch_size: 16 # 18 in the original coca + patch_stride: 16 # 18 in the original coca + n_img_channels: 3 + add_cls_token: False + bias: True + num_video_frames: 16 + n_latents: 64 text_decoder_config: sample_key: input_ids prediction_key: logits @@ -55,7 +99,7 @@ model_raw: bias: true activation: swiglu epsilon: 1e-5 - n_pool_head: 8 - n_vision_queries: 256 + n_pool_head: 12 + n_queries: 256 bias_attn_pool: False - epsilon_attn_pool: 1e-5 \ No newline at end of file + epsilon_attn_pool: 1e-5
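
The `MultipleFunctionsLoss` exercised in `tests/test_loss_functions.py` above combines several loss terms into a weighted sum: it is built from a list of loss callables plus matching `corrsp_weights`, rejects being constructed with a single loss, and keeps per-loss running totals that can be cleared via `reset_cumulated_individual_losses`. A minimal sketch of that interface is shown below, assuming a plain PyTorch implementation; the class name `WeightedSumLossSketch` and its internals are illustrative only and are not the modalities implementation.

```python
import torch


class WeightedSumLossSketch:
    """Illustrative sketch (not the modalities code) of the weighted-sum
    interface exercised in tests/test_loss_functions.py."""

    def __init__(self, losses: list, corrsp_weights: list[float]):
        if len(losses) < 2:
            raise ValueError("Number of losses used should be more than 1.")
        # pair each loss with its weight, mirroring the `groups` attribute used in the tests
        self.groups = list(zip(losses, corrsp_weights))
        # running per-loss totals, as inspected by the reset test
        self.cumulated_individual_losses = torch.zeros(len(self.groups))

    def reset_cumulated_individual_losses(self) -> None:
        self.cumulated_individual_losses = torch.zeros(len(self.groups))

    def __call__(self, forward_batch) -> torch.Tensor:
        total = torch.zeros(())
        for idx, (loss_fn, weight) in enumerate(self.groups):
            value = loss_fn(forward_batch)
            # accumulate the unweighted term so individual losses can be logged separately
            self.cumulated_individual_losses[idx] += value.detach()
            total = total + weight * value
        return total
```

Under this reading, a call such as `MultipleFunctionsLoss([clm_loss, clip_loss], corrsp_weights=[1.0, 1.0])` returns `1.0 * clm_loss(batch) + 1.0 * clip_loss(batch)` while tracking each term separately for logging.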