In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

只在 visual encoder + MLP/linear 得到的 [T, N_patch, D_llm] image_features (T, N, D) 上做处理

[T, C, H, W] - [T, N_patch, D_vision] - [T, N_patch, D_llm]

- slow 通路：



- fast 通路：



- slow_fast 拼接：




In [2]:
from dataclasses import dataclass

@dataclass
class SlowFastConfig:
    # slow
    slow_num_frames: int = 8
    slow_spatial_pool: str = "1d_max"   # "1d_max" | "1d_avg" | "none"

    # fast
    fast_spatial_size: int = 4          # e.g. 4 → 4x4

    assume_square_patches: bool = True

In [3]:
import torch
import torch.nn as nn
from einops import rearrange
import math

class SlowFastConfig:
    def __init__(self, slow_num_frames=8, slow_spatial_pool="1d_max", fast_output_size=(4,4)):
        self.slow_num_frames = slow_num_frames
        self.slow_spatial_pool = slow_spatial_pool
        self.fast_output_size = fast_output_size

class TokenSlowFastAggregator(nn.Module):
    def __init__(self, cfg: SlowFastConfig, num_patches_per_side: int):
        super().__init__()
        self.cfg = cfg
        self.H = num_patches_per_side # ViT 每帧的 patch 网格大小 H*H

    def forward(self, x, debug=False):
        T, N, D = x.shape
        H = self.H
        assert N == H*H, f"[SlowFast] N={N} not equal to ViT grid H^2={H*H}"

        info = {}

        # Slow
        slow_T = min(T, self.cfg.slow_num_frames)
        slow_idx = torch.linspace(0, T-1, slow_T).long()
        slow_x = x[slow_idx]  # (T_s, N, D)

        # if self.cfg.slow_spatial_pool == "1d_max":
        #     # 将每两个 token 做 max pool -> token 数量减半
        #     slow_x = slow_x.reshape(slow_T, -1, 2, D).max(dim=2).values  # (T_s, N//2, D)
        # elif self.cfg.slow_spatial_pool == "1d_avg":
        #     # 将每两个 token 做平均池化
        #     slow_x = slow_x.reshape(slow_T, -1, 2, D).mean(dim=2)

        # slow_tokens = slow_x.reshape(1, -1, D)  # flatten 所有帧 token -> shape=(1, T_s*N_slow, D)

        slow_x = rearrange(slow_x, 't (h w) d -> t d h w', h=H, w=H)

        if self.cfg.slow_spatial_pool == "2x2_max":
            pool = nn.MaxPool2d(kernel_size=2, stride=2)
            slow_x = pool(slow_x)  # (T_s, D, H/2, H/2)
        elif self.cfg.slow_spatial_pool == "2x2_avg":
            pool = nn.AvgPool2d(kernel_size=2, stride=2)
            slow_x = pool(slow_x)

        _, _, H_slow, W_slow = slow_x.shape
        slow_tokens = rearrange(slow_x, 't d h w -> t (h w) d')
        slow_tokens = slow_tokens.reshape(1, slow_T*H_slow*W_slow, D)

        # Fast
        fast_x = x  # (T, N, D)
        # reshape -> D, T, H, H
        fast_x = rearrange(fast_x, 't (h w) d -> d t h w', h=H, w=H)
        pool2 = nn.AdaptiveAvgPool2d(self.cfg.fast_output_size) # 自适应平均池化到 fast_output_size=(4,4)
        fast_x = pool2(fast_x)  # (D, T, 4, 4)

        D, T, Hf, Wf = fast_x.shape
        fast_x = rearrange(fast_x, 'd t h w -> t h w d') # 对应之前的 rearrange， (T, 4, 4, D)
        fast_tokens = fast_x.reshape(1, T*Hf*Wf, D) # flatten (1, T*4*4, D)

        out = torch.cat([slow_tokens, fast_tokens], dim=1)

        if debug:
            info["input_shape"] = (T, N, D)
            info["slow_idx"] = slow_idx.tolist()
            info["slow_tokens_shape"] = slow_tokens.shape
            info["fast_tokens_shape"] = fast_tokens.shape
            info["output_shape"] = out.shape
            return out, info

        return out


In [5]:
def slowfast_debug_demo():
    torch.manual_seed(0)

    T, H, D =64, 14, 512  # ViT: 14x14 patches
    N = H*H
    x = torch.randn(T, N, D)

    cfg = SlowFastConfig(
        slow_num_frames=32,
        slow_spatial_pool="1d_max",
        fast_output_size=(4,4)
    )

    model = TokenSlowFastAggregator(cfg, num_patches_per_side=H)
    y, info = model.forward(x, debug=True)

    total_input_tokens = T * N
    total_output_tokens = info['output_shape'][1]
    ratio = total_output_tokens / total_input_tokens

    print(f"Input tokens:        {info['input_shape']}   (T, N, D)")
    print(f"Slow frame indices:  {info['slow_idx']}")
    print(f"Slow tokens shape:   {info['slow_tokens_shape']}")
    print(f"Fast tokens shape:   {info['fast_tokens_shape']}")
    print("---------------------------------------")
    print(f"Fused output shape:  {info['output_shape']}")
    print(f"ratio:   {ratio:.2f}")


if __name__ == "__main__":
    slowfast_debug_demo()


Input tokens:        (64, 196, 512)   (T, N, D)
Slow frame indices:  [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 63]
Slow tokens shape:   torch.Size([1, 6272, 512])
Fast tokens shape:   torch.Size([1, 1024, 512])
---------------------------------------
Fused output shape:  torch.Size([1, 7296, 512])
ratio:   0.58
