In [1]:
!pip install einops



In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torchvision.models import vit_b_16
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from einops import rearrange
import os
from datetime import datetime

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [3]:
# Hyperparameters
lambda_rep = 1.0  # Weight for JEPA predictive loss
lambda_flow = 0.5  # Weight for optical flow loss
lambda_smooth = 0.1  # Weight for flow smoothness loss
batch_size = 8
num_epochs = 50
learning_rate = 1e-4
max_displacement = 4  # For correlation layer

In [4]:
# TensorBoard setup
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = f"runs/jepa_flow_hybrid_{timestamp}"
writer = SummaryWriter(log_dir)

In [5]:
# Custom Dataset for video frames (dummy implementation)
class VideoFrameDataset(Dataset):
    def __init__(self, num_samples=1000, frame_size=224):
        self.num_samples = num_samples
        self.frame_size = frame_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Generate random adjacent frames with some motion
        x_t = torch.rand(3, self.frame_size, self.frame_size)

        # Create x_t1 by applying a simple transformation to x_t
        # In a real dataset, this would be actual video frames
        affine_mat = torch.eye(2, 3)
        affine_mat[0, 2] = torch.rand(1) * 0.1  # random x translation
        affine_mat[1, 2] = torch.rand(1) * 0.1  # random y translation

        grid = F.affine_grid(affine_mat.unsqueeze(0), x_t.unsqueeze(0).shape)
        x_t1 = F.grid_sample(x_t.unsqueeze(0), grid)
        x_t1 = x_t1.squeeze(0)

        return x_t, x_t1

In [6]:
# Correlation Layer for optical flow
class CorrelationLayer(nn.Module):
    def __init__(self, max_displacement=4):
        super().__init__()
        self.max_displacement = max_displacement
        self.kernel_size = 2 * max_displacement + 1

    def forward(self, f_t, f_t1):
        b, c, h, w = f_t.size()
        corr = torch.zeros(b, self.kernel_size * self.kernel_size, h, w).to(f_t.device)

        # Pad feature maps for displacement
        padding = self.max_displacement
        f_t1_padded = F.pad(f_t1, (padding, padding, padding, padding))

        # Compute correlation for each displacement
        for i in range(-self.max_displacement, self.max_displacement + 1):
            for j in range(-self.max_displacement, self.max_displacement + 1):
                # Extract shifted version of f_t1
                shifted_f_t1 = f_t1_padded[
                    :, :, padding + i : padding + i + h, padding + j : padding + j + w
                ]

                # Compute correlation
                idx = (i + self.max_displacement) * self.kernel_size + (
                    j + self.max_displacement
                )
                corr[:, idx, :, :] = (f_t * shifted_f_t1).sum(dim=1)

        return corr


# Flow Head with correlation
class FlowHead(nn.Module):
    def __init__(self, input_channels, hidden_channels=256):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, hidden_channels, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_channels, hidden_channels // 2, 3, padding=1)
        self.conv3 = nn.Conv2d(
            hidden_channels // 2, 2, 3, padding=1
        )  # Output flow field (u, v)

    def forward(self, corr_volume):
        x = F.relu(self.conv1(corr_volume))
        x = F.relu(self.conv2(x))
        flow = self.conv3(x)
        return flow

In [7]:
# Spatial JEPA Predictor
class SpatialPredictor(nn.Module):
    def __init__(self, input_dim, output_dim, hidden_dim=512):
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 1)  # 1x1 convolution
        self.conv2 = nn.Conv2d(hidden_dim, output_dim, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.conv2(x)
        return x

In [8]:
# EMA Teacher Network
class EMA:
    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

        # Register shadow parameters
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                new_average = (
                    1.0 - self.decay
                ) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data
                param.data = self.shadow[name]

    def restore(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data = self.backup[name]
        self.backup = {}

In [9]:
# Differentiable warping function
def warp(x, flow):
    # x: [B, C, H, W]
    # flow: [B, 2, H, W] (dx, dy)
    B, C, H, W = x.size()

    # Create grid
    grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
    grid = torch.stack((grid_x, grid_y), 2).float().to(x.device)
    grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)  # [B, H, W, 2]

    # Add flow to grid
    new_grid = grid + flow.permute(0, 2, 3, 1)

    # Normalize grid to [-1, 1]
    new_grid[:, :, :, 0] = 2.0 * new_grid[:, :, :, 0] / max(W - 1, 1) - 1.0
    new_grid[:, :, :, 1] = 2.0 * new_grid[:, :, :, 1] / max(H - 1, 1) - 1.0

    # Sample using grid_sample
    warped = F.grid_sample(x, new_grid, align_corners=True)
    return warped


# Smoothness loss for optical flow
def smoothness_loss(flow):
    # Calculate gradients of flow
    dx = torch.abs(flow[:, :, :, :-1] - flow[:, :, :, 1:])
    dy = torch.abs(flow[:, :, :-1, :] - flow[:, :, 1:, :])

    # Sum of all gradients
    loss = dx.sum() + dy.sum()
    return loss

In [10]:
# Main model
class JEPAFlowHybrid(nn.Module):
    def __init__(self):
        super().__init__()

        # Backbone - ViT with spatial output
        self.backbone = vit_b_16(pretrained=True)

        # Modify ViT to output spatial features
        self.backbone.heads = nn.Identity()  # Remove classification head

        # JEPA predictor
        self.jepa_predictor = SpatialPredictor(768, 768)

        # Correlation layer
        self.correlation = CorrelationLayer(max_displacement=max_displacement)

        # Flow head
        self.flow_head = FlowHead(max_displacement * 2 + 1)

        # EMA teacher
        self.teacher = EMA(self.backbone)

    def forward(self, x_t, x_t1):
        # Extract features
        f_t = self.backbone(x_t)  # [B, 768, H/16, W/16] for ViT-B/16
        f_t1 = self.backbone(x_t1)

        # JEPA prediction
        f_t1_pred = self.jepa_predictor(f_t)

        # Flow estimation
        corr_volume = self.correlation(f_t, f_t1)
        flow = self.flow_head(corr_volume)

        return f_t1_pred, flow, f_t1

In [11]:
# Initialize model
model = JEPAFlowHybrid().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Dataset and DataLoader
dataset = VideoFrameDataset(num_samples=1000)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Training loop
global_step = 0
for epoch in range(num_epochs):
    epoch_rep_loss = 0.0
    epoch_flow_loss = 0.0
    epoch_total_loss = 0.0

    for batch_idx, (x_t, x_t1) in enumerate(dataloader):
        # Move data to device
        x_t = x_t.to(device)
        x_t1 = x_t1.to(device)

        # Forward pass
        f_t1_pred, flow, f_t1 = model(x_t, x_t1)

        # JEPA loss
        with torch.no_grad():
            f_t1_target = model.teacher.model(x_t1)
        rep_loss = F.mse_loss(f_t1_pred, f_t1_target)

        # Flow loss
        x_t1_warped = warp(x_t, flow)
        photo_loss = F.l1_loss(x_t1_warped, x_t1)
        smooth_loss = smoothness_loss(flow)
        flow_loss = photo_loss + lambda_smooth * smooth_loss

        # Total loss
        total_loss = lambda_rep * rep_loss + lambda_flow * flow_loss

        # Backward pass and optimize
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Update EMA teacher
        model.teacher.update()

        # Log losses
        writer.add_scalar("Loss/JEPA", rep_loss.item(), global_step)
        writer.add_scalar("Loss/Flow", flow_loss.item(), global_step)
        writer.add_scalar("Loss/Total", total_loss.item(), global_step)
        writer.add_scalar("Loss/Photometric", photo_loss.item(), global_step)
        writer.add_scalar("Loss/Smoothness", smooth_loss.item(), global_step)

        # Update epoch statistics
        epoch_rep_loss += rep_loss.item()
        epoch_flow_loss += flow_loss.item()
        epoch_total_loss += total_loss.item()
        global_step += 1

        if batch_idx % 10 == 0:
            print(
                f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}/{len(dataloader)}], "
                f"JEPA Loss: {rep_loss.item():.4f}, Flow Loss: {flow_loss.item():.4f}"
            )

    # Log epoch averages
    avg_rep_loss = epoch_rep_loss / len(dataloader)
    avg_flow_loss = epoch_flow_loss / len(dataloader)
    avg_total_loss = epoch_total_loss / len(dataloader)

    writer.add_scalar("Epoch/JEPA", avg_rep_loss, epoch)
    writer.add_scalar("Epoch/Flow", avg_flow_loss, epoch)
    writer.add_scalar("Epoch/Total", avg_total_loss, epoch)

    print(
        f"Epoch [{epoch+1}/{num_epochs}], "
        f"Avg JEPA Loss: {avg_rep_loss:.4f}, Avg Flow Loss: {avg_flow_loss:.4f}, "
        f"Avg Total Loss: {avg_total_loss:.4f}"
    )

# Close TensorBoard writer
writer.close()
print("Training completed. Run 'tensorboard --logdir=runs' to view the loss plots.")



Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to /home/jefferyfan/.cache/torch/hub/checkpoints/vit_b_16-c867db91.pth


100%|██████████| 330M/330M [00:34<00:00, 10.1MB/s] 


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [8, 768]