# Pico-V-JEPA

An even smaller V-JEPA implementation.

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.io import read_video
from torch.utils.data import Dataset, DataLoader

# from google.colab import drive
# Mount Google Drive
# drive.mount('/content/drive')

In [None]:
# Dataset that decodes one video file per item
class VideoDataset(Dataset):
    """Map-style dataset returning the decoded frames of one video per index.

    Args:
        video_paths: Sequence of video file paths passed to ``read_video``.
        transform: Optional callable applied to the frame tensor after it has
            been rearranged to (T, C, H, W).
    """

    def __init__(self, video_paths, transform=None):
        self.transform = transform
        self.video_paths = video_paths

    def __len__(self):
        """Return the number of videos in the dataset."""
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Decode the video at ``idx`` and return its (optionally transformed) frames."""
        # read_video returns (frames, audio, metadata); only the frames are used
        frames, _audio, _info = read_video(self.video_paths[idx], pts_unit="sec")

        # Move the channel axis next to time: result is (T, C, H, W)
        frames = frames.permute(0, 3, 1, 2)

        if not self.transform:
            return frames
        return self.transform(frames)

In [None]:
# Per-video preprocessing applied to the frame tensor produced by the dataset.
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),  # Resize to the desired input size
        # The dataset passes a torch.Tensor of decoded frames, which
        # transforms.ToTensor() rejects (it only accepts PIL images and
        # ndarrays). Convert the dtype instead; for integer inputs this also
        # rescales pixel values to [0, 1].
        transforms.ConvertImageDtype(torch.float32),
        # Other necessary transformations
    ]
)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class VJEPA(nn.Module):
    """Minimal V-JEPA-style model that predicts future frame embeddings.

    Every frame is embedded independently by a small 3D-CNN encoder. A linear
    predictor maps the embedding of the *masked* frame at time ``t`` to the
    embedding of the *unmasked* frame at time ``t + prediction_horizon``; the
    mean squared error between the two is the training loss.

    Args:
        input_shape: Clip shape ``(C, T, H, W)``. Only ``C`` is used, to size
            the first convolution.
        embedding_dim: Size of the per-frame embedding.
        prediction_horizon: How many frames ahead the predictor looks.

    Raises:
        ValueError: If ``prediction_horizon`` is negative.
    """

    def __init__(self, input_shape, embedding_dim=256, prediction_horizon=1):
        super().__init__()
        if prediction_horizon < 0:
            raise ValueError(
                f"prediction_horizon must be >= 0, got {prediction_horizon}"
            )
        self.input_shape = input_shape  # (C, T, H, W)
        self.embedding_dim = embedding_dim
        self.prediction_horizon = prediction_horizon

        # Encoder (3D CNN): three spatially-strided convolutions followed by
        # global average pooling and a linear projection to the embedding.
        self.encoder = nn.Sequential(
            nn.Conv3d(
                input_shape[0],
                32,
                kernel_size=(3, 3, 3),
                stride=(1, 2, 2),
                padding=(1, 1, 1),
            ),
            nn.ReLU(),
            nn.Conv3d(
                32, 64, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)
            ),
            nn.ReLU(),
            nn.Conv3d(
                64, 128, kernel_size=(3, 3, 3), stride=(1, 2, 2), padding=(1, 1, 1)
            ),
            nn.ReLU(),
            nn.AdaptiveAvgPool3d((1, 1, 1)),  # Global Average Pooling
            nn.Flatten(),
            nn.Linear(128, embedding_dim),
        )

        # Predictor (simple linear layer - can be replaced with a more complex network)
        self.predictor = nn.Linear(embedding_dim, embedding_dim)

    def _encode_frames(self, clip):
        """Embed every frame of ``clip`` (B, C, T, H, W) in one encoder pass -> (B, T, D)."""
        batch_size, channels, time_steps, height, width = clip.shape
        # Fold time into the batch axis so each frame becomes a length-1 clip.
        # This is what the encoder would see if called once per frame, but it
        # needs a single forward pass instead of T of them.
        frames = clip.permute(0, 2, 1, 3, 4).reshape(
            batch_size * time_steps, channels, 1, height, width
        )
        return self.encoder(frames).reshape(batch_size, time_steps, -1)

    def forward(self, video_frames, masked_frames):
        """
        Args:
            video_frames (torch.Tensor):  (batch_size, C, T, H, W)
            masked_frames (torch.Tensor): (batch_size, C, T, H, W)  Masked version of the input
        Returns:
            torch.Tensor: scalar loss, the MSE between predicted and target
                embeddings averaged over all predicted time steps.
        Raises:
            ValueError: If the two clips differ in shape, or the clip has no
                more than ``prediction_horizon`` frames (nothing to predict).
        """
        _, _, time_steps, _, _ = video_frames.shape
        if masked_frames.shape != video_frames.shape:
            raise ValueError(
                "masked_frames must have the same shape as video_frames, got "
                f"{tuple(masked_frames.shape)} and {tuple(video_frames.shape)}"
            )

        num_steps = time_steps - self.prediction_horizon
        if num_steps <= 0:
            # Without this guard the loss would be divided by zero.
            raise ValueError(
                f"clip has {time_steps} frame(s) but prediction_horizon="
                f"{self.prediction_horizon}; need at least "
                f"{self.prediction_horizon + 1} frames"
            )

        # Encode original and masked videos
        encoded_frames = self._encode_frames(video_frames)  # (B, T, D)
        encoded_masked_frames = self._encode_frames(masked_frames)  # (B, T, D)

        # Predict future representation from masked context: step t of the
        # masked clip is paired with step t + prediction_horizon of the original.
        predicted_embeddings = self.predictor(encoded_masked_frames[:, :num_steps])
        # NOTE(review): the targets share the encoder and are not detached, so
        # gradients also flow through the target branch. V-JEPA uses a
        # stop-gradient target encoder; consider `.detach()` here - confirm intent.
        target_embeddings = encoded_frames[:, self.prediction_horizon :]

        # Every time step contributes the same number of elements, so one mean
        # over everything equals the average of the per-step MSE losses.
        return F.mse_loss(predicted_embeddings, target_embeddings)


def create_mask(video_frames, mask_ratio=0.5):
    """
    Creates a spatio-temporal mask for the input video frames.

    For every sample, ``int(mask_ratio * T * H * W)`` positions are chosen
    uniformly at random and zeroed out across all channels.

    Args:
        video_frames (torch.Tensor): (batch_size, C, T, H, W)
        mask_ratio (float):            Fraction of the video to mask (between 0 and 1)

    Returns:
        torch.Tensor: masked_video_frames (batch_size, C, T, H, W)
        torch.Tensor: mask (batch_size, T, H, W)  Bool Tensor, 1 for keep, 0 for mask

    Raises:
        ValueError: If ``mask_ratio`` is outside [0, 1].
    """
    # A negative ratio would turn into a negative slice bound below and
    # silently mask the wrong number of positions, so reject it up front.
    if not 0.0 <= mask_ratio <= 1.0:
        raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio}")

    batch_size, _, time_steps, height, width = video_frames.shape
    num_positions = time_steps * height * width
    num_masked_pixels = int(mask_ratio * num_positions)

    mask = torch.ones(
        (batch_size, time_steps, height, width),
        dtype=torch.bool,
        device=video_frames.device,
    )
    # View sharing storage with `mask`: flat index == (t * H + h) * W + w
    flat_mask = mask.view(batch_size, num_positions)

    for b in range(batch_size):
        # Randomly select indices to mask (one permutation per sample) and
        # clear them with a single indexed write instead of a per-pixel loop.
        masked_indices = torch.randperm(num_positions, device=video_frames.device)[
            :num_masked_pixels
        ]
        flat_mask[b, masked_indices] = False

    # Broadcast the (B, 1, T, H, W) mask over the channel axis.
    masked_video_frames = video_frames * mask.unsqueeze(1)

    return masked_video_frames, mask

Before running the code

Set `video_folder` in the next cell to the actual path of your video folder (for example /content/drive/MyDrive/videos if the videos are stored in Google Drive).
Make sure your videos are in a compatible format (only files ending in .mp4 are loaded); frames are resized to the expected input size (64x64 in this case) by the transform.
Consider adding more data preprocessing steps (e.g., normalization, data augmentation) as needed.

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Data configuration
video_folder = "nano-dataset-k400-5c-10v"  # Replace with your video folder path
input_shape = (3, 16, 64, 64)  # (C, T, H, W)

# Optimisation configuration
batch_size = 2
num_epochs = 10
learning_rate = 1e-3

In [None]:
# Create video paths list (only .mp4 files directly inside video_folder)
video_paths = [
    os.path.join(video_folder, filename)
    for filename in os.listdir(video_folder)
    if filename.endswith(".mp4")
]


def _sample_frames(video, num_frames=input_shape[1]):
    """Uniformly subsample a clip along its first (time) axis to ``num_frames`` frames.

    Clips shorter than ``num_frames`` repeat frames. Giving every clip the same
    length is required for the DataLoader to stack clips into a batch.
    """
    if video.shape[0] == 0:
        raise ValueError("cannot sample frames from an empty video")
    indices = torch.linspace(0, video.shape[0] - 1, num_frames).long()
    return video[indices]


# Create dataset and dataloader
transform = transforms.Compose(
    [
        # Sample first so the resize only touches the frames that are kept
        transforms.Lambda(_sample_frames),
        transforms.Resize((input_shape[2], input_shape[3])),  # Resize frames
        # The dataset passes a torch.Tensor of decoded frames, which
        # transforms.ToTensor() rejects (it only accepts PIL images and
        # ndarrays). Convert the dtype instead; for integer inputs this also
        # rescales pixel values to [0, 1].
        transforms.ConvertImageDtype(torch.float32),
    ]
)
dataset = VideoDataset(video_paths, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [None]:
# Build the network, move its parameters to the selected device, and attach
# an Adam optimizer to those parameters
model = VJEPA(input_shape)
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
# Training loop: mask each batch, predict future embeddings, and take one
# optimizer step per batch.
for epoch in range(num_epochs):
    for batch_idx, video_data in enumerate(dataloader):
        # VideoDataset yields clips as (T, C, H, W), so a batch arrives as
        # (B, T, C, H, W). create_mask and VJEPA both expect (B, C, T, H, W);
        # without this swap the time axis would be treated as channels.
        video_data = video_data.to(device).permute(0, 2, 1, 3, 4)

        # The boolean mask itself is not needed here, only the masked clip
        masked_video_data, _ = create_mask(video_data, mask_ratio=0.75)
        loss = model(video_data, masked_video_data)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        print(
            f"Epoch [{epoch + 1}/{num_epochs}], Batch [{batch_idx + 1}/{len(dataloader)}], Loss: {loss.item()}"
        )