In [None]:
# Imports and Configuration (litdata backend)
import os, sys, math
import torch
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
from omegaconf import OmegaConf
import litdata as ld

# Add project root to path
sys.path.append('/home/paul/path-fm')

from dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator

# Load and merge config: the train config (vits14_reg4) is layered over ssl_default_config.
# NOTE: these paths are relative, so this cell assumes the kernel's CWD is the project root.
CONFIG_PATHS = (
    'dinov2/configs/ssl_default_config.yaml',
    'dinov2/configs/train/vits14_reg4.yaml',
)
cfg_default, cfg_train = (OmegaConf.load(path) for path in CONFIG_PATHS)
cfg = OmegaConf.merge(cfg_default, cfg_train)

# litdata storage options, read from the environment (never hardcode credentials).
# Unset variables are passed through as None.
storage_options = {
    option: os.environ.get(env_var)
    for option, env_var in (
        ('endpoint_url', 'AWS_ENDPOINT_URL'),
        ('aws_access_key_id', 'AWS_ACCESS_KEY_ID'),
        ('aws_secret_access_key', 'AWS_SECRET_ACCESS_KEY'),
    )
}
LITDATA_ROOT = 's3://sophont/paul/data/litTCGA'

print('Merged config ready.')
print(' litdata root:', LITDATA_ROOT)
print(' crops: global', cfg.crops.global_crops_size, 'local', cfg.crops.local_crops_size)
print(' patch_size:', cfg.student.patch_size)


In [None]:
# Build augmentation, masking and collate the same way the training script does.
img_size = cfg.crops.global_crops_size
patch_size = cfg.student.patch_size
patches_per_side = img_size // patch_size
n_tokens = patches_per_side ** 2  # patch tokens per global crop

# NOTE(review): max_num_patches is kept verbatim, presumably copied from the training script.
# Python evaluates it left to right as ((0.5 * img_size) // patch_size * img_size) // patch_size.
# That can be smaller than 0.5 * patches_per_side**2 (e.g. when patches_per_side is odd),
# so do not "simplify" it here without changing training too.
mask_generator = MaskingGenerator(
    input_size=(patches_per_side, patches_per_side),
    max_num_patches=0.5 * img_size // patch_size * img_size // patch_size,
)

crops_cfg = cfg.crops
data_transform = DataAugmentationDINO(
    crops_cfg.global_crops_scale,
    crops_cfg.local_crops_scale,
    crops_cfg.local_crops_number,
    global_crops_size=crops_cfg.global_crops_size,
    local_crops_size=crops_cfg.local_crops_size,
)

# Collate into half precision, with iBOT masks sampled by mask_generator.
ibot_cfg = cfg.ibot
collate_fn = partial(
    collate_data_and_cast,
    mask_ratio_tuple=ibot_cfg.mask_ratio_min_max,
    mask_probability=ibot_cfg.mask_sample_probability,
    n_tokens=n_tokens,
    mask_generator=mask_generator,
    dtype=torch.half,
)

print('n_tokens:', n_tokens)
print('cfg.crops.global_crops_size', crops_cfg.global_crops_size)
print('cfg.crops.local_crops_size', crops_cfg.local_crops_size)
print('mask_ratio_min_max:', list(ibot_cfg.mask_ratio_min_max))
print('mask_probability:', ibot_cfg.mask_sample_probability)


In [None]:
# Create litdata StreamingDataset + DataLoader
def extract_and_transform(item):
    """Turn one litdata record into the ``(crops, label)`` tuple expected by ``collate_fn``.

    Args:
        item: a record yielded by the StreamingDataset; only its ``'image'`` entry is used
            (presumably a PIL image or array accepted by DataAugmentationDINO — confirm).

    Returns:
        ``(crops, None)``: ``crops`` is the output of ``data_transform``, and ``None`` is a
        placeholder label (this self-supervised preview has no labels).
    """
    crops = data_transform(item['image'])
    placeholder_label = None
    return crops, placeholder_label

# Samples per preview batch. StreamingDataLoader defaults to batch_size=1, which hides the
# cross-sample ordering of the collated tensors (and makes "pairs per sample" trivial).
# NOTE(review): in upstream dinov2's collate_data_and_cast, masks are assigned across the 2*B
# global crops, so B also affects the spread of iBOT mask ratios. Set this to
# cfg.train.batch_size_per_gpu to mirror a real training batch (slower: more S3 reads).
VIS_BATCH_SIZE = 4

dataset = ld.StreamingDataset(
    LITDATA_ROOT,
    storage_options=storage_options,
    shuffle=True,
    drop_last=True,
    transform=extract_and_transform,
)

data_loader = ld.StreamingDataLoader(dataset, batch_size=VIS_BATCH_SIZE, collate_fn=collate_fn)
batch = next(iter(data_loader))

print('Batch keys:', list(batch.keys()))
for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        print(f' {k}: shape={tuple(v.shape)}, dtype={v.dtype}')

# Two global crops per sample (see the [2*B, ...] layout used below), hence the // 2.
actual_bs = batch['collated_global_crops'].shape[0] // 2
print('Actual batch size (from crops):', actual_bs)
assert actual_bs == VIS_BATCH_SIZE, f'expected {VIS_BATCH_SIZE} samples per batch, got {actual_bs}'


In [None]:
# Visualization helpers and Global Crops
IMAGENET_DEFAULT_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_DEFAULT_STD  = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def unnormalize_img(t):
    """Undo ImageNet mean/std normalization and return a displayable image.

    Args:
        t: a C,H,W tensor (any float dtype, any device) holding a normalized image.

    Returns:
        A float32 numpy array of shape H,W,C with values clipped to [0, 1].
    """
    chw = t.detach().float().cpu().numpy()
    # Reshape the per-channel stats to (C, 1, 1) so they broadcast over H and W.
    std = IMAGENET_DEFAULT_STD.reshape(-1, 1, 1)
    mean = IMAGENET_DEFAULT_MEAN.reshape(-1, 1, 1)
    restored = np.clip(chw * std + mean, 0.0, 1.0)
    return restored.transpose(1, 2, 0)

def show_image_grid(imgs, ncols=4, title=None, figsize=(14, 6)):
    """Display a sequence of images in a grid of subplots, row-major.

    Args:
        imgs: sequence of images accepted by ``Axes.imshow`` (e.g. the H,W,C arrays
            returned by ``unnormalize_img``). An empty sequence draws nothing.
        ncols: number of grid columns; unused trailing cells are left blank.
        title: optional figure-level title.
        figsize: matplotlib figure size in inches.
    """
    n = len(imgs)
    if n == 0:
        # plt.subplots(0, ncols) raises, so there is nothing sensible to draw.
        return
    nrows = math.ceil(n / ncols)
    # squeeze=False always returns a 2-D ndarray of Axes. The previous n==1 special case
    # wrapped an already-array result, so axes[0, 0] was an ndarray and .axis() crashed.
    fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i < n:
            ax.imshow(imgs[i])
    if title:
        fig.suptitle(title)
    plt.tight_layout()
    plt.show()

global_crops = batch['collated_global_crops']  # [2*B, C, H, W]
B = global_crops.shape[0] // 2
# Upstream dinov2's collate_data_and_cast stacks crops crop-major: every sample's first global
# crop, then every sample's second one. Sample i's two views are therefore rows i and B + i;
# rows 2*i and 2*i + 1 only line up with that when B == 1.
# NOTE(review): confirm this fork's collate keeps the upstream ordering.
to_show = []
for i in range(min(B, 4)):
    to_show.extend([unnormalize_img(global_crops[i]), unnormalize_img(global_crops[B + i])])
show_image_grid(to_show, ncols=4, title='Global crops (pairs per sample)')

print('Global crops tensor info:')
print('  Shape:', tuple(global_crops.shape), '(2*B, C, H, W)')
print('  Dtype:', global_crops.dtype)
print(f'  Value range: [{global_crops.min():.3f}, {global_crops.max():.3f}]')
print(f'  Mean: {global_crops.float().mean():.3f}, Std: {global_crops.float().std():.3f}')


In [None]:
# Local Crops
local_crops = batch['collated_local_crops']  # [n_local*B, C, h, w]; 8*B with the default config
n_local = int(cfg.crops.local_crops_number)
n_samples = local_crops.shape[0] // n_local
# Crops are stacked crop-major (as in upstream dinov2's collate_data_and_cast), so the first
# sample's j-th local crop is row j * n_samples, not row j. The two agree only when B == 1.
first_sample_locals = [unnormalize_img(local_crops[j * n_samples]) for j in range(n_local)]
show_image_grid(first_sample_locals, ncols=4, title='Local crops (first sample)')

print('Local crops tensor info:')
print('  Shape:', tuple(local_crops.shape), f'({n_local}*B, C, h, w)')
print('  Dtype:', local_crops.dtype)
print(f'  Value range: [{local_crops.min():.3f}, {local_crops.max():.3f}]')
print(f'  Mean: {local_crops.float().mean():.3f}, Std: {local_crops.float().std():.3f}')


In [None]:
# iBOT Masks
masks = batch['collated_masks']  # [2*B, N] where N=(img_size//patch_size)^2
grid = img_size // patch_size
assert masks.shape[1] == grid * grid == n_tokens, (
    f'mask width {masks.shape[1]} does not match a {grid}x{grid} patch grid ({n_tokens} tokens)'
)
# Each row is the mask of one global crop (not one sample), so rows are labelled "Mask i" throughout.
n_show = min(masks.shape[0], 4)
# squeeze=False keeps axes 2-D even when n_show == 1, so no special case is needed.
fig, axes = plt.subplots(1, n_show, figsize=(3*n_show, 3), squeeze=False)
for i, ax in enumerate(axes[0]):
    m = masks[i].float().reshape(grid, grid).cpu().numpy()
    ax.imshow(m, cmap='gray_r', vmin=0, vmax=1)
    ax.set_title(f'Mask {i}')
    ax.axis('off')
plt.tight_layout()
plt.show()

print('Mask statistics:')
print('  Shape:', tuple(masks.shape), '(2*B, N)')
print('  Dtype:', masks.dtype)
print('  Total patches per image:', n_tokens)
print('  Patches per side:', grid)
for i in range(n_show):
    mask_ratio = masks[i].float().mean().item()
    n_masked = int(masks[i].sum())
    print(f'  Mask {i}: {mask_ratio:.1%} masked ({n_masked}/{n_tokens} patches)')
