In [None]:
# ============================================
# Ultrasound Preprocessing & DataLoader Pipeline
# ============================================

# STEP 1: Install & Import Required Libraries
!pip install opencv-python albumentations torch torchvision tqdm

import os
import random

import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision.transforms as T
import albumentations as A
from tqdm import tqdm

# ============================================
# STEP 2: Configurations
# ============================================

DATA_PATH = "/content/ultrasound_dataset"   # <-- unzip Zenodo data here; expects videos/ and labels/ subfolders
FRAME_SIZE = (224, 224)                    # resize resolution, passed to cv2.resize as (width, height)
FRAMES_PER_VIDEO = 16                      # number of frames to sample per clip
BATCH_SIZE = 8
VAL_SPLIT = 0.2                            # fraction of clips held out for validation
TEST_SPLIT = 0.1                           # fraction of clips held out for testing
SEED = 42

# Seed every RNG the pipeline may draw from. torch drives random_split and
# DataLoader shuffling; albumentations (at least the 1.x series) samples its
# augmentation parameters from Python's `random` / NumPy. Seeding torch alone
# leaves the augmentations non-reproducible.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# ============================================
# STEP 3: Utility Functions
# ============================================

def sample_frames(video_path, num_frames=FRAMES_PER_VIDEO, frame_size=FRAME_SIZE):
    """Load a video and return exactly ``num_frames`` uniformly sampled RGB frames.

    Target indices are spread evenly over the frame count reported by OpenCV
    using ``np.linspace``. If the video is shorter than ``num_frames``, some
    indices repeat and the corresponding frame is repeated in the output. If
    decoding ends before a target index is reached (the reported count comes
    from container metadata and may be inaccurate), the previous sample is
    reused. Either way the clip length is always ``num_frames``, so clips from
    different videos can be batched together.

    Args:
        video_path: Path to a video file readable by ``cv2.VideoCapture``.
        num_frames: Number of frames to return (must be >= 1).
        frame_size: Output size as ``(width, height)``, as taken by ``cv2.resize``.

    Returns:
        ``np.ndarray`` of shape ``(num_frames, height, width, 3)`` in RGB order,
        with the dtype OpenCV decodes to (uint8 for typical videos).

    Raises:
        ValueError: If ``num_frames`` < 1.
        OSError: If the video cannot be opened or no frame can be decoded.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise OSError(f"Could not open video: {video_path}")

        # CAP_PROP_FRAME_COUNT can be 0 or negative for some containers; clamp
        # so every target index is non-negative (falls back to frame 0).
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        last_index = max(total_frames - 1, 0)
        indices = np.linspace(0, last_index, num_frames, dtype=np.int64)

        frames = []
        current = None     # most recently sampled frame, already converted/resized
        position = -1      # stream index of the most recently grabbed frame
        exhausted = False  # set once grab()/retrieve() fails
        for target in indices:
            if target > position and not exhausted:
                # grab() advances the stream without retrieving, so skipped
                # frames never go through color conversion or resizing.
                while position < target:
                    if not cap.grab():
                        exhausted = True
                        break
                    position += 1
                if not exhausted:
                    ok, raw = cap.retrieve()
                    if ok:
                        current = cv2.resize(
                            cv2.cvtColor(raw, cv2.COLOR_BGR2RGB), frame_size
                        )
                    else:
                        exhausted = True
            if current is None:
                raise OSError(f"No frames could be decoded from: {video_path}")
            # Repeated indices (short videos) or an early end of stream reuse
            # the previous sample, keeping the output length fixed.
            frames.append(current)
    finally:
        cap.release()
    return np.array(frames)

def sample_label_frames(label_video_path, num_frames=FRAMES_PER_VIDEO, frame_size=FRAME_SIZE):
    """Load a segmentation-mask video and return exactly ``num_frames`` sampled frames.

    Frame selection follows the same scheme as ``sample_frames``. Target
    indices come from ``np.linspace`` over the reported frame count.
    Repeated indices or an early end of stream reuse the previous sample, so
    the output always has ``num_frames`` frames. For a mask video with the
    same length as its input video, this selects the same frame positions.

    Frames are converted to single-channel grayscale and resized with
    nearest-neighbour interpolation so no new in-between mask values are
    introduced. NOTE(review): the returned values are the decoded gray
    intensities. If the mask video was lossy-encoded, they may not be exact
    class ids; confirm the dataset's label encoding.

    Args:
        label_video_path: Path to a mask video readable by ``cv2.VideoCapture``.
        num_frames: Number of frames to return (must be >= 1).
        frame_size: Output size as ``(width, height)``, as taken by ``cv2.resize``.

    Returns:
        ``np.ndarray`` of shape ``(num_frames, height, width)``.

    Raises:
        ValueError: If ``num_frames`` < 1.
        OSError: If the video cannot be opened or no frame can be decoded.
    """
    if num_frames < 1:
        raise ValueError(f"num_frames must be >= 1, got {num_frames}")

    cap = cv2.VideoCapture(label_video_path)
    try:
        if not cap.isOpened():
            raise OSError(f"Could not open label video: {label_video_path}")

        # CAP_PROP_FRAME_COUNT can be 0 or negative for some containers; clamp
        # so every target index is non-negative (falls back to frame 0).
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        last_index = max(total_frames - 1, 0)
        indices = np.linspace(0, last_index, num_frames, dtype=np.int64)

        frames = []
        current = None     # most recently sampled mask, already converted/resized
        position = -1      # stream index of the most recently grabbed frame
        exhausted = False  # set once grab()/retrieve() fails
        for target in indices:
            if target > position and not exhausted:
                # grab() advances the stream without retrieving, so skipped
                # frames never go through color conversion or resizing.
                while position < target:
                    if not cap.grab():
                        exhausted = True
                        break
                    position += 1
                if not exhausted:
                    ok, raw = cap.retrieve()
                    if ok:
                        current = cv2.resize(
                            cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY),
                            frame_size,
                            interpolation=cv2.INTER_NEAREST,
                        )
                    else:
                        exhausted = True
            if current is None:
                raise OSError(f"No frames could be decoded from: {label_video_path}")
            # Repeated indices (short videos) or an early end of stream reuse
            # the previous sample, keeping the output length fixed.
            frames.append(current)
    finally:
        cap.release()
    return np.array(frames)

# ============================================
# STEP 4: Dataset Class
# ============================================

class UltrasoundDataset(Dataset):
    """Pairs ultrasound videos with segmentation-mask videos, one clip per item.

    Videos and label videos are matched by position after sorting each
    directory's visible file names. The two directories must therefore hold
    the same number of files whose sorted order lines up (e.g. identical names).

    Each item is a tuple ``(frames, labels)``:
      * ``frames``: float32 tensor ``(T, C, H, W)``, pixel values divided by 255.
      * ``labels``: int64 tensor ``(T, H, W)`` holding the mask values as decoded.

    Args:
        video_dir: Directory containing the input videos.
        label_dir: Directory containing the matching mask videos.
        transform: Stored as ``self.transform``. It is not applied anywhere in
            this class and is kept for interface compatibility.
        augment: If true, apply one randomly sampled horizontal flip /
            brightness-contrast setting to every frame of a clip. Spatial
            transforms such as the flip are applied to the masks too.

    Raises:
        ValueError: If the two directories contain different numbers of files.
    """

    def __init__(self, video_dir, label_dir, transform=None, augment=False):
        self.video_dir = video_dir
        self.label_dir = label_dir
        self.video_files = self._list_files(video_dir)
        self.label_files = self._list_files(label_dir)
        if len(self.video_files) != len(self.label_files):
            raise ValueError(
                f"Found {len(self.video_files)} videos in {video_dir!r} but "
                f"{len(self.label_files)} label videos in {label_dir!r}; "
                "videos and labels are paired by sorted position and must match one-to-one."
            )
        self.transform = transform
        self.augment = augment

        # ReplayCompose records the randomly sampled parameters of each call so
        # they can be replayed on the remaining frames. A clip is flipped as a
        # whole (or not at all) and gets a single brightness/contrast setting.
        self.aug = A.ReplayCompose([
            A.HorizontalFlip(p=0.5),
            A.RandomBrightnessContrast(p=0.5),
        ])

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of visible regular files in ``directory``.

        Hidden entries such as ``.DS_Store`` or Jupyter's ``.ipynb_checkpoints``
        directory, and sub-directories, are skipped so they cannot shift the
        positional video/label pairing.
        """
        return sorted(
            name
            for name in os.listdir(directory)
            if not name.startswith(".")
            and os.path.isfile(os.path.join(directory, name))
        )

    def __len__(self):
        return len(self.video_files)

    def __getitem__(self, idx):
        video_path = os.path.join(self.video_dir, self.video_files[idx])
        label_path = os.path.join(self.label_dir, self.label_files[idx])

        frames = sample_frames(video_path)
        labels = sample_label_frames(label_path)
        # zip() in the augmentation would silently truncate mismatched clips.
        if len(frames) != len(labels):
            raise ValueError(
                f"{video_path!r} yielded {len(frames)} frames but "
                f"{label_path!r} yielded {len(labels)} label frames"
            )

        if self.augment:
            frames, labels = self._augment_clip(frames, labels)

        # Normalize
        frames = frames.astype(np.float32) / 255.0
        labels = labels.astype(np.int64)

        # To torch tensors
        frames = torch.from_numpy(frames).permute(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
        labels = torch.from_numpy(labels)  # (T, H, W)

        return frames, labels

    def _augment_clip(self, frames, labels):
        """Sample one augmentation on the first frame and replay it on all others."""
        if len(frames) == 0:
            return frames, labels
        first = self.aug(image=frames[0], mask=labels[0])
        replay = first["replay"]
        aug_frames, aug_labels = [first["image"]], [first["mask"]]
        for frame, label in zip(frames[1:], labels[1:]):
            out = A.ReplayCompose.replay(replay, image=frame, mask=label)
            aug_frames.append(out["image"])
            aug_labels.append(out["mask"])
        return np.array(aug_frames), np.array(aug_labels)

# ============================================
# STEP 5: Dataset Splitting
# ============================================

video_dir = os.path.join(DATA_PATH, "videos")
label_dir = os.path.join(DATA_PATH, "labels")

# `random_split` returns Subsets that all wrap ONE shared dataset object, so
# setting `augment` on it would also augment the validation and test clips.
# Instead, split indices on a non-augmenting dataset and point the training
# Subset at a separate, augmenting instance over the same files.
dataset = UltrasoundDataset(video_dir, label_dir, augment=False)
train_dataset = UltrasoundDataset(video_dir, label_dir, augment=True)

total_len = len(dataset)
val_len = int(total_len * VAL_SPLIT)
test_len = int(total_len * TEST_SPLIT)
train_len = total_len - val_len - test_len

# A dedicated seeded generator keeps the split identical across notebook
# re-runs, independent of how much of the global torch RNG other cells used.
split_generator = torch.Generator().manual_seed(SEED)
train_set, val_set, test_set = random_split(
    dataset, [train_len, val_len, test_len], generator=split_generator
)

# Enable augmentation only for the training set: same indices, but backed by
# the augmenting dataset (both instances list the directories identically).
train_set = torch.utils.data.Subset(train_dataset, train_set.indices)

# ============================================
# STEP 6: DataLoaders
# ============================================

# Only the training loader reshuffles each epoch; evaluation order stays fixed.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=BATCH_SIZE, shuffle=is_training)
    for split, is_training in (
        (train_set, True),
        (val_set, False),
        (test_set, False),
    )
)

print(
    f"Train: {len(train_loader)} batches, "
    f"Val: {len(val_loader)} batches, "
    f"Test: {len(test_loader)} batches"
)

# ============================================
# STEP 7: Quick Sanity Check
# ============================================

# Push a single batch through the whole pipeline and report tensor shapes.
# Expected: frames -> [B, T, C, H, W], labels -> [B, T, H, W]
frames, labels = next(iter(train_loader))
for _caption, _tensor in (("Frames shape:", frames), ("Labels shape:", labels)):
    print(_caption, _tensor.shape)
del _caption, _tensor