In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cv2
import numpy as np
%matplotlib inline

In [None]:
# Decode every frame of the training video into a grayscale array of shape (T, H, W).
video_path = "/content/training_video.mp4"
cap = cv2.VideoCapture(video_path)

# Fail loudly here: otherwise a bad path silently yields an empty array and the
# error only surfaces later when the (T, H, W) shape is unpacked.
if not cap.isOpened():
    raise IOError(f"Could not open video file: {video_path}")

gray_frames = []
try:
    while cap.isOpened():
        ret, frame = cap.read()
        if not ret:
            # No more frames could be read.
            break
        gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
        gray_frames.append(gray_frame)
finally:
    # Always release the capture handle, even if decoding raised.
    cap.release()

if not gray_frames:
    raise ValueError(f"No frames could be decoded from: {video_path}")

video_data = np.array(gray_frames)
print("Video data shape:", video_data.shape)

In [None]:
def extract_patches(frame, patch_size=16):
    """
    Split a 2D frame into non-overlapping square patches, each flattened row-major.

    Patches are emitted in row-major grid order (left to right, then top to bottom).
    Only whole patches are kept: trailing rows/columns that do not fill a complete
    patch are dropped, so every patch has the same length and the patch count is
    (H // patch_size) * (W // patch_size).

    frame: 2D NumPy array of shape (H, W), e.g. (128, 128).
    patch_size: side length of each square patch.

    Returns:
      A NumPy array of shape (num_patches, patch_size * patch_size),
      e.g. (64, 256) for a 128x128 frame with patch_size=16.

    Raises:
      ValueError: if the frame is smaller than a single patch (or patch_size
                  is not positive), so no whole patch can be extracted.
    """
    H, W = frame.shape

    # Exclusive upper bounds for the top-left corner of a *whole* patch.
    row_stop = H - patch_size + 1
    col_stop = W - patch_size + 1

    patches = [
        frame[i:i + patch_size, j:j + patch_size].reshape(-1)
        for i in range(0, row_stop, patch_size)
        for j in range(0, col_stop, patch_size)
    ]
    if not patches:
        raise ValueError(
            f"frame of shape {(H, W)} contains no whole {patch_size}x{patch_size} patch"
        )
    return np.stack(patches, axis=0)


def build_dataset_patches(video_data, context_length=3, patch_size=16):
    """
    Build (context -> current frame) training pairs in patch form.

    For each time step t the input is the patches of the `context_length`
    frames preceding t (oldest first), and the target is the patches of frame t.
    Context slots that would refer to frames before the start of the video are
    filled with zeros.

    video_data: NumPy array of shape (T, H, W), e.g. (11072, 128, 128).
    context_length: number of previous frames' patches to include as input.
    patch_size: side length of each square patch.

    Returns:
        X: float32 tensor of shape (T, context_length * P, D)
        Y: float32 tensor of shape (T, P, D)
        where P = (H // patch_size) * (W // patch_size) patches per frame and
        D = patch_size * patch_size values per patch
        (P=64, D=256 for 128x128 frames with patch_size=16).
    """
    T, H, W = video_data.shape

    num_patches = (H // patch_size) * (W // patch_size)  # e.g. 64
    patch_dim = patch_size * patch_size                  # e.g. 256

    # Targets: patches of every frame, written straight into one preallocated
    # float32 array (the assignment performs the cast).
    targets = np.empty((T, num_patches, patch_dim), dtype=np.float32)
    for t in range(T):
        targets[t] = extract_patches(video_data[t], patch_size=patch_size)

    # Inputs: preallocated as zeros so the "before the video starts" context is
    # already in place; each context slot is then filled with a time-shifted
    # view of the targets instead of building T separate concatenations.
    inputs = np.zeros((T, context_length * num_patches, patch_dim), dtype=np.float32)
    for slot in range(context_length):
        # Slot 0 is the oldest frame (t - context_length); the last slot is t - 1.
        lag = context_length - slot
        if lag >= T:
            # The video is too short for this slot to ever hold a real frame.
            continue
        rows = slice(slot * num_patches, (slot + 1) * num_patches)
        inputs[lag:, rows] = targets[:T - lag]

    # from_numpy wraps the arrays without another full copy of the dataset.
    X = torch.from_numpy(inputs)   # shape (T, N*P, D)
    Y = torch.from_numpy(targets)  # shape (T, P, D)

    print("X shape:", X.shape, "Y shape:", Y.shape)
    return X, Y


# Build every (context, target) pair for the whole video in one pass.
X_all, Y_all = build_dataset_patches(video_data, context_length=3, patch_size=16)

# Split by position in the sequence: first 80% train, next 10% dev, last 10% test.
n1 = int(0.8 * len(X_all))
n2 = int(0.9 * len(X_all))
Xtr, Xdev, Xte = X_all[:n1], X_all[n1:n2], X_all[n2:]
Ytr, Ydev, Yte = Y_all[:n1], Y_all[n1:n2], Y_all[n2:]

print("Train shapes:", Xtr.shape, Ytr.shape)
print("Dev shapes:", Xdev.shape, Ydev.shape)
print("Test shapes:", Xte.shape, Yte.shape)

In [None]:
# =====================
# Model configuration
# =====================
g = torch.Generator().manual_seed(2147483647)  # seeded RNG so the initial weights are repeatable

context_length = 3    # how many previous frames feed the model
n_patches = 64        # patches per frame (128x128 frame cut into 16x16 patches)
patch_dim = 256       # values in one flattened 16x16 patch
embedding_dim = 64    # width of each patch embedding
hidden_dim = 400      # units in the single hidden layer

# Sizes derived from the settings above.
output_dim = n_patches * patch_dim                           # 64 * 256 = 16384
mlp_input_dim = context_length * n_patches * embedding_dim   # 3 * 64 * 64 = 12288
mlp_output_dim = output_dim                                  # one value per pixel of the next frame

# =====================
# Parameter tensors
# =====================
# The random tensors are drawn from `g` in the order C, W1, W2; keep that order
# so the seeded initial values stay the same.

# Patch embedding: projects each patch_dim-long patch to embedding_dim values.
C = 0.01 * torch.randn(patch_dim, embedding_dim, generator=g)

# Hidden layer.
W1 = 0.01 * torch.randn(mlp_input_dim, hidden_dim, generator=g)
b1 = torch.zeros(hidden_dim)

# Output layer.
W2 = 0.01 * torch.randn(hidden_dim, mlp_output_dim, generator=g)
b2 = torch.zeros(mlp_output_dim)

# Single list so gradient resets and updates can loop over everything.
parameters = [C, W1, b1, W2, b2]

In [None]:
sum(p.numel() for p in parameters)  # total count of scalar values across all parameter tensors

In [None]:
# Turn on gradient tracking for every parameter tensor so backward() populates .grad.
for p in parameters:
  p.requires_grad_(True)

In [None]:
'''
X shape: torch.Size([11072, 192, 256]) Y shape: torch.Size([11072, 64, 256])
Train shapes: torch.Size([8857, 192, 256]) torch.Size([8857, 64, 256])
Dev shapes: torch.Size([1107, 192, 256]) torch.Size([1107, 64, 256])
Test shapes: torch.Size([1108, 192, 256]) torch.Size([1108, 64, 256])
C: 256x64
W1: 12288x400
'''

In [None]:
def forward(X_batch):
    """
    Run the patch-embedding MLP on a batch of context patches.

    Uses the module-level parameters C, W1, b1, W2, b2.

    X_batch: (batch_size, num_context_patches, patch_dim),
        e.g. (batch_size, 192, 256) for 3 context frames of 64 patches each.
    Returns:
        (batch_size, W2.shape[1]) tensor of predicted values for the next frame,
        e.g. (batch_size, 16384) = 64 patches of 256 values each.
    """
    # Project every patch through the shared embedding matrix C.
    patch_emb = torch.matmul(X_batch, C)

    # Concatenate all patch embeddings of a sample into one feature vector.
    mlp_in = patch_emb.reshape(patch_emb.shape[0], -1)

    # Hidden layer with tanh non-linearity: (batch_size, W1.shape[1]).
    hidden = torch.tanh(torch.matmul(mlp_in, W1) + b1)

    # Linear output layer: raw predicted pixel values (no activation).
    return torch.matmul(hidden, W2) + b2

num_epochs = 10000   # number of minibatch SGD steps (one random batch per step)
batch_size = 32
lr = 0.1

for epoch in range(num_epochs):
    # Sample a random minibatch. Drawing from the seeded generator `g` makes the
    # sequence of batches (and therefore the whole run) reproducible.
    ix = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
    X_batch = Xtr[ix]  # (batch_size, num_context_patches, patch_dim)
    Y_batch = Ytr[ix]  # (batch_size, patches_per_frame, patch_dim)

    logits = forward(X_batch)   # (batch_size, patches_per_frame * patch_dim)

    # Flatten the target patches so they line up with the flat prediction.
    Y_batch_flat = Y_batch.reshape(Y_batch.shape[0], -1)

    loss = F.mse_loss(logits, Y_batch_flat)

    # Reset gradients, then backpropagate.
    for p in parameters:
        p.grad = None
    loss.backward()

    # Plain SGD step. The in-place update is done under no_grad so it is not
    # recorded by autograd.
    with torch.no_grad():
        for p in parameters:
            p -= lr * p.grad

    if epoch % 10 == 0:
        print(f"Epoch {epoch}, loss = {loss.item():.4f}")


In [None]:
# training loss
# Evaluation only: no_grad avoids building an autograd graph over the entire split.
with torch.no_grad():
    logits = forward(Xtr)
    loss = F.mse_loss(logits, Ytr.reshape(Ytr.shape[0], -1))
loss

In [None]:
# validation loss
# Evaluation only: no_grad avoids building an autograd graph over the entire split.
with torch.no_grad():
    logits = forward(Xdev)
    # Flatten the targets using the dev split's own sample count so the shape
    # matches the predictions row for row.
    loss = F.mse_loss(logits, Ydev.reshape(Ydev.shape[0], -1))
loss

In [None]:
# test loss
# Evaluation only: no_grad avoids building an autograd graph over the entire split.
with torch.no_grad():
    logits = forward(Xte)
    # Flatten the targets using the test split's own sample count so the shape
    # matches the predictions row for row.
    loss = F.mse_loss(logits, Yte.reshape(Yte.shape[0], -1))
loss

In [None]:
def unpatchify(patches, patch_size=16, frame_size=128):
    """
    Reassemble a square frame from flattened patches stored in row-major grid order.

    patches: indexable collection of flattened patches, each holding
             patch_size * patch_size values, e.g. shape (64, 256).
    patch_size: side length of each square patch.
    frame_size: side length of the square output frame.

    Returns: a float32 2D array of shape (frame_size, frame_size), e.g. (128, 128).
    """
    canvas = np.zeros((frame_size, frame_size), dtype=np.float32)

    # Number of patch positions along one side of the frame.
    per_side = len(range(0, frame_size, patch_size))

    for k in range(per_side * per_side):
        # Patch k sits at grid row k // per_side, grid column k % per_side.
        grid_row, grid_col = divmod(k, per_side)
        top = grid_row * patch_size
        left = grid_col * patch_size
        tile = patches[k].reshape(patch_size, patch_size)
        canvas[top:top + patch_size, left:left + patch_size] = tile

    return canvas

def predict_next_frame(context_frames,
                       C, W1, b1, W2, b2,
                       patch_size=16, embed_dim=64, hidden_dim=200):
    """
    Predict the next frame's patches from the most recent context frames.

    context_frames: the last N frames in patch form, each of shape (P, D).
        May be a NumPy array of shape (N, P, D), a torch tensor of that shape,
        or a list/tuple of N arrays or tensors.
    C, W1, b1, W2, b2: model parameters (embedding, hidden layer, output layer).
    patch_size, embed_dim, hidden_dim: kept for backward compatibility only;
        they are not used. All sizes are taken from the parameter tensors, so
        the function works for any patch/embedding/hidden configuration.

    Returns: predicted patches as a torch tensor of shape (P_out, D), where
        D = C.shape[0] and P_out = W2.shape[1] // D (e.g. (64, 256)).
    """
    # Length of one flattened patch, as expected by the embedding matrix.
    patch_dim = C.shape[0]

    # 1) Convert the context to a single tensor. A list of ndarrays is stacked in
    #    NumPy first (building a tensor element-by-element from such a list is
    #    slow), and a list of tensors is stacked directly.
    if torch.is_tensor(context_frames):
        ctx = context_frames.detach()
    elif (isinstance(context_frames, (list, tuple)) and len(context_frames) > 0
          and all(torch.is_tensor(f) for f in context_frames)):
        ctx = torch.stack([f.detach() for f in context_frames])
    else:
        ctx = torch.from_numpy(np.array(context_frames, dtype=np.float32))

    # Merge the N frames into one batch of size 1: (1, N*P, D).
    ctx = ctx.to(device=C.device, dtype=torch.float32).reshape(1, -1, patch_dim)

    # 2) Embedding: (1, N*P, D) @ (D, E) => (1, N*P, E)
    emb = ctx @ C

    # 3) Flatten: => (1, N*P*E)
    emb_flat = emb.reshape(1, -1)

    # 4) Hidden layer
    h = torch.tanh(emb_flat @ W1 + b1)

    # 5) Output layer => (1, P_out*D)
    out = h @ W2 + b2

    # 6) Back to patch form => (P_out, D)
    predicted_patches = out.view(-1, patch_dim)
    return predicted_patches

def sample_frames(initial_frames,
                  C, W1, b1, W2, b2,
                  num_generate=20,
                  context_length=3):
    """
    Autoregressively roll the model forward, feeding its own predictions back in.

    initial_frames: list (or np array) of at least `context_length` frames in
                    patch form; only the last `context_length` are used as the seed.
    C, W1, b1, W2, b2: model parameters, passed through to predict_next_frame.
    num_generate: how many frames to generate.
    context_length: how many most-recent frames form the model input.

    Returns: list of `num_generate` predicted frames (NumPy arrays, patch form).
    """
    # Sliding window over the most recent frames (real seed frames at first,
    # then progressively replaced by predictions).
    window = list(initial_frames[-context_length:])
    outputs = []

    for _ in range(num_generate):
        next_patches = predict_next_frame(window, C, W1, b1, W2, b2)

        # Detach from autograd and move to NumPy before storing / reusing.
        next_np = next_patches.detach().cpu().numpy()
        outputs.append(next_np)

        # Slide the window: discard the oldest frame, append the new prediction.
        window = window[1:]
        window.append(next_np)

    return outputs

# Example usage:
# Seed the generator with the first three training targets (each one frame in
# patch form), then let the model roll forward on its own predictions.
seed_frames = [Ytr[i].numpy() for i in range(3)]

# Generate frames
generated_frames = sample_frames(
    seed_frames,
    C, W1, b1, W2, b2,
    num_generate=200,
    context_length=3,
)