In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import random
import pytorch_wavelets.dwt.lowlevel as lowlevel
import pywt
from rich.console import Console
from torch.utils.data import WeightedRandomSampler
import cv2

In [2]:
class DWTForward(nn.Module):
    """
    Single-level 2D discrete wavelet transform (Haar by default).

    The analysis filters are taken from a ``pywt.Wavelet`` and prepared once
    with ``lowlevel.prep_filt_afb2d``; the same low/high-pass pair is used
    for both the column and the row direction.

    Args:
        wave: wavelet name (looked up through ``pywt.Wavelet``) or an object
              exposing ``dec_lo`` / ``dec_hi`` filter coefficients.
        mode: padding mode name, converted with ``lowlevel.mode_to_int``.

    Returns (from ``forward``):
        ll:   low-frequency band.
        high: high-frequency bands, exactly as produced by
              ``lowlevel.AFB2D.apply``.
    """

    # Buffer names, in the order prep_filt_afb2d returns the filters.
    _FILTER_NAMES = ("h0_col", "h1_col", "h0_row", "h1_row")

    def __init__(self, wave="haar", mode="zero"):
        super().__init__()
        wavelet = pywt.Wavelet(wave) if isinstance(wave, str) else wave

        # Decomposition (analysis) low-pass / high-pass coefficients,
        # shared between the column and row filters.
        lo, hi = wavelet.dec_lo, wavelet.dec_hi
        bank = lowlevel.prep_filt_afb2d(lo, hi, lo, hi)

        # Buffers (not parameters): they follow .to(device) but are not trained.
        for buf_name, filt in zip(self._FILTER_NAMES, bank):
            self.register_buffer(buf_name, filt)

        self.mode = lowlevel.mode_to_int(mode)

    def forward(self, x):
        # NOTE(review): with pytorch_wavelets, `ll` is expected to be
        # [B, C, H/2, W/2] and `high` [B, C, 3, H/2, W/2] (three sub-bands on
        # their own axis) for even H, W -- confirm against the installed
        # version; callers reshape `high` to [B, 3*C, ...] themselves.
        return lowlevel.AFB2D.apply(
            x,
            self.h0_col,
            self.h1_col,
            self.h0_row,
            self.h1_row,
            self.mode,
        )


In [3]:
class MLFFE_Frame(nn.Module):
    """
    Extracts a 512-dimensional frequency feature vector for a single frame.

    Pipeline:
        - shallow 3x3 convolution embedding (3 -> embed_ch channels)
        - 3 cascaded single-level DWTs, each applied to the previous LL band
        - per level: LL and the high-frequency bands are concatenated on the
          channel axis (4 * embed_ch channels) and adaptively pooled to 7x7
        - the pooled levels are concatenated (12 * embed_ch channels) and
          mapped to 512 dims by ``self.classifier``

    Args:
        embed_ch: number of channels produced by the initial convolution.
    """

    # Number of cascaded DWT levels; the classifier width depends on it.
    _NUM_LEVELS = 3

    def __init__(self, embed_ch=64):
        super().__init__()

        # Shallow image embedding
        self.conv = nn.Conv2d(3, embed_ch, kernel_size=3, padding=1)

        # DWT operator (Haar)
        self.dwt = DWTForward(wave="haar")

        # Each level is pooled to 7x7 so levels of different resolution
        # can be concatenated on the channel axis.
        self.pool = nn.AdaptiveAvgPool2d((7, 7))

        # (LL + LH + HL + HH) * 3 levels = 12 x C channels
        total_channels = embed_ch * 4 * self._NUM_LEVELS

        # Final mapper -> 512 dims. MaxPool2d(2, 2) turns 7x7 into 3x3,
        # hence the `* 3 * 3` in the flattened width.
        self.classifier = nn.Sequential(
            nn.BatchNorm2d(total_channels),
            nn.MaxPool2d(2, 2),
            nn.Flatten(),
            nn.LayerNorm(total_channels * 3 * 3),
            nn.Dropout(0.3),
            nn.PReLU(),
            nn.Linear(total_channels * 3 * 3, 512),
        )

    @staticmethod
    def _merge_subbands(ll, high):
        """Concatenate LL with the high-frequency bands along the channel axis.

        The spatial size is read from ``high`` itself instead of being derived
        from the input size with ``//``: a DWT of an odd-length axis does not
        produce ``N // 2`` samples, so the derived size would not match.
        Works whether ``high`` arrives as [B, C, 3, H, W] or [B, 3*C, H, W].
        """
        high = high.reshape(high.shape[0], -1, *high.shape[-2:])
        return torch.cat([ll, high], dim=1)

    def forward(self, img):
        """
        img: [B, 3, H, W] (e.g. 224 x 224; H and W need not be divisible by 8)
        return: [B, 512]
        """
        ll = self.conv(img)

        pooled_levels = []
        for _ in range(self._NUM_LEVELS):
            # Each level decomposes the previous level's LL band.
            ll, high = self.dwt(ll)
            pooled_levels.append(self.pool(self._merge_subbands(ll, high)))

        # Combine all levels: [B, 12 * C, 7, 7]
        X = torch.cat(pooled_levels, dim=1)

        return self.classifier(X)   # shape: [B, 512]

In [4]:
class VideoMLFFE(nn.Module):
    """
    Applies ML-FFE on all frames of a video.
    Input:  [B, T, 3, H, W]   (e.g. 224 x 224 frames)
    Output: [B, T, 512]       one feature vector per frame (for temporal models)

    Args:
        embed_ch: channel width forwarded to the per-frame ``MLFFE_Frame``.
    """
    def __init__(self, embed_ch=64):
        super().__init__()
        self.frame_model = MLFFE_Frame(embed_ch)

    def forward(self, videos):
        """
        videos: [Batch, Frames, 3, H, W]
        returns:
            frame_feats: [B, T, F], where F is whatever feature shape
            ``self.frame_model`` emits per frame (512 for MLFFE_Frame).
        """
        # Unpacking five names keeps the implicit "input must be 5-D" check.
        B, T, C, H, W = videos.shape

        # Vectorized processing: fold time into the batch axis.
        # reshape (not view) so permuted / transposed, i.e. non-contiguous,
        # clips are accepted; for contiguous input it is still a free view.
        frames = videos.reshape(B * T, C, H, W)    # [B*T, 3, H, W]

        # Extract per-frame features
        feats = self.frame_model(frames)           # [B*T, F]

        # Restore temporal structure. The feature shape is read from the
        # output rather than hard-coded, and no -1 is used so that an empty
        # batch does not make the reshape ambiguous.
        return feats.reshape(B, T, *feats.shape[1:])

In [11]:
# Shape smoke test for VideoMLFFE on a random 64-frame clip.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VideoMLFFE(embed_ch=64).to(device)

B, T = 1, 64
# Allocate directly on the target device instead of building on CPU and copying.
dummy = torch.randn(B, T, 3, 224, 224, device=device)

# This cell only checks the output shape, so run it without autograd:
# recording the graph keeps every intermediate activation of all B*T frames
# alive for a backward pass that never happens, and this cell previously
# failed with a CUDA out-of-memory error. eval() additionally stops the
# random dummy batch from updating BatchNorm running statistics.
# If memory is still insufficient, lower T.
model.eval()
with torch.no_grad():
    out = model(dummy)
model.train()  # restore the default mode a freshly built module starts in

print(out.shape)

RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [None]:
from torchinfo import summary
from rich.console import Console

def pretty_summary(model):
    """Print a torchinfo parameter summary of ``model`` through a rich console.

    The summary is built without input data, listing the parameter count and
    kernel size columns for modules nested up to five levels deep.
    torchinfo's own printing is disabled (``verbose=0``) so the report is
    emitted exactly once, by rich.
    """
    report = summary(
        model,
        col_names=("num_params", "kernel_size"),
        depth=5,
        verbose=0,
    )
    Console().print(report)


# Prefer the GPU whenever PyTorch can see one.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# NOTE(review): this rebinds `model` (a VideoMLFFE in the earlier smoke-test
# cell) to a bare per-frame extractor, and `device` becomes a plain string
# instead of a torch.device -- later cells see these values.
model = MLFFE_Frame(embed_ch=64)
model = model.to(device)

pretty_summary(model)


In [None]:
# Cross-entropy loss for the classification objective.
criterion = nn.CrossEntropyLoss()

# One Adam optimizer over the feature extractor and the classification head.
# NOTE(review): `classifier` is not defined in any cell visible here -- it
# must be created before this cell runs, otherwise this raises NameError.
optimizer = torch.optim.Adam(
    [*model.parameters(), *classifier.parameters()],
    lr=1e-4,
)