# Multimodal Dataset & Collator Checks


This notebook exercises the ImageNet-based multimodal dataset and collator setup used in training.
Each cell follows the same code paths the training script relies on, using the file paths from the default config.


In [None]:
import os
import random
from pprint import pprint

import torch

from src.multimodal.multimodal_training_config import MultimodalTrainingConfig
from src.multimodal.multimodal_training import load_multimodal_dataset
from src.datasets.imagenet.imagenet_dataset import ImageNetDataset, MultimodalCollator
from src.utils import create_transforms
from transformers import AutoTokenizer, AutoImageProcessor
from torch.utils.data import DataLoader, Subset


In [None]:
# Instantiate the default multimodal config. No overrides are passed, so the
# repository paths it carries are assumed to be valid on this machine.
config = MultimodalTrainingConfig.from_params({})

# Show only the fields the rest of this notebook depends on.
fields_to_show = (
    'mapping_path',
    'extra_mapping_path',
    'image_root',
    'vision_model_name',
    'language_model_name',
    'num_vision_tokens',
    'train_transforms',
    'val_transforms',
)
config_summary = {field: getattr(config, field) for field in fields_to_show}
pprint(config_summary)


In [None]:
# Confirm every filesystem input exists before any dataset code touches it.
# The extra mapping CSV is optional and is only verified when the config sets it.
required_inputs = [
    (config.mapping_path, f'Missing mapping CSV at {config.mapping_path}'),
]
if config.extra_mapping_path:
    required_inputs.append(
        (config.extra_mapping_path, f'Missing extra mapping CSV at {config.extra_mapping_path}')
    )
required_inputs.append((config.image_root, f'Missing image root at {config.image_root}'))

# Check in the order mapping -> extra mapping -> image root; stop at the first miss.
for required_path, failure_message in required_inputs:
    assert os.path.exists(required_path), failure_message
print('All required paths located.')


In [None]:
# Build the dataset straight from the raw ImageNet mapping: no transform is
# applied and no train/val split is made.
full_dataset = ImageNetDataset(
    config.mapping_path,
    config.image_root,
    transform=None,
    return_synset=True,
)

# Summarise its size, then peek at the leading rows and class labels.
row_count = len(full_dataset)
class_count = full_dataset.num_classes
print(f'Total rows: {row_count} | Unique classes: {class_count}')

leading_rows = full_dataset.dataset[:3]
print('First three entries:', leading_rows)

leading_labels = full_dataset.unique_labels[:5]
print('First five class labels:', leading_labels)


In [None]:
# Draw one random index and report what the dataset hands back for it.
last_valid_index = len(full_dataset) - 1
sample_idx = random.randint(0, last_valid_index)

# Each item unpacks into an (image, label) pair.
sample = full_dataset[sample_idx]
sample_image, sample_label = sample

print(f'Random sample index: {sample_idx} | Class name: {sample_label}')
print('Image type:', type(sample_image))

# Only objects exposing a `size` attribute get their size reported.
image_has_size = hasattr(sample_image, 'size')
if image_has_size:
    print('Image size:', sample_image.size)


In [None]:
# Reuse the training helper to produce the stratified train/val subsets
train_dataset, val_dataset = load_multimodal_dataset(config)
print(f'Train subset: {len(train_dataset)} rows | Val subset: {len(val_dataset)} rows')

# Report whether the train subset still exposes `unique_labels`, and only
# dereference the attribute when it is present. Accessing it unconditionally
# would turn a "False" answer into an AttributeError and hide the diagnostic.
has_class_metadata = hasattr(train_dataset, 'unique_labels')
print('Train subset retains class metadata?', has_class_metadata)
if has_class_metadata:
    print('Example labels:', train_dataset.unique_labels[:5])
else:
    print('Example labels: unavailable (train_dataset has no unique_labels attribute)')


In [None]:
# Load the tokenizer and the image processor named in the config.
tokenizer = AutoTokenizer.from_pretrained(
    config.language_model_name,
    use_fast=config.use_fast_tokenizer,
)
image_processor = AutoImageProcessor.from_pretrained(config.vision_model_name)

# Wire both into the collator, along with the prompt settings from the config
# and the class names exposed by the train subset.
collator_kwargs = {
    'image_processor': image_processor,
    'tokenizer': tokenizer,
    'num_vision_tokens': config.num_vision_tokens,
    'prompt_template': config.prompt_template,
    'all_class_names': train_dataset.unique_labels,
}
collator = MultimodalCollator(**collator_kwargs)

print('Tokenizer vocab size:', tokenizer.vocab_size)

# Fall back to 'n/a' when the processor does not expose a `size` attribute.
processor_size = getattr(image_processor, 'size', 'n/a')
print('Vision processor size:', processor_size)


In [None]:
# Take at most the first four training rows and collate them two at a time,
# in order, so the inspected batch is reproducible.
inspection_size = min(4, len(train_dataset))
inspection_indices = list(range(inspection_size))
inspection_subset = Subset(train_dataset, inspection_indices)
inspection_loader = DataLoader(
    inspection_subset,
    batch_size=2,
    shuffle=False,
    collate_fn=collator,
)

# Pull only the first collated batch and describe each entry in it.
batch = next(iter(inspection_loader))
print('Batch keys:', batch.keys())
for name, entry in batch.items():
    if not isinstance(entry, torch.Tensor):
        # Non-tensor entries have no shape/dtype; report their Python type instead.
        print(f"{name}: type={type(entry)}")
        continue
    print(f"{name}: shape={tuple(entry.shape)} dtype={entry.dtype}")


In [None]:
# Check the label masking in the collated batch.
MASKED_LABEL = -100  # value the checks below expect at unsupervised positions
labels = batch['labels']

# The first `num_vision_tokens` columns of every row must all be masked.
vision_token_block = labels[:, :config.num_vision_tokens]
assert vision_token_block.eq(MASKED_LABEL).all(), 'Vision positions should be masked with -100'

# Every row must keep at least one unmasked (supervised) label.
answer_mask = labels.ne(MASKED_LABEL)
answer_counts = answer_mask.sum(dim=1)
print('Answer tokens per example:', answer_counts.tolist())
assert answer_counts.gt(0).all(), 'Each prompt should contain supervised answer tokens'

print('Label masking checks passed.')


In [None]:
# Drop the leading `num_vision_tokens` columns and decode what remains, keeping
# special tokens visible so the prompt layout can be inspected by eye.
text_start = config.num_vision_tokens
text_only_ids = batch['input_ids'][:, text_start:]
decoded_texts = tokenizer.batch_decode(text_only_ids, skip_special_tokens=False)

for example_number, decoded in enumerate(decoded_texts):
    print(f'Example {example_number}: {decoded}')
