In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import Norms as Norms

In [2]:
##########################################################################
# 影像轉token序列 (B, C, H, W) to (B, HW, C)
def to_3d(x):
    """Reshape from (B, C, H, W) to (B, HW, C)"""
    return rearrange(x, 'b c h w -> b (h w) c')

##########################################################################
# token序列轉影像 (B, HW, C) to (B, C, H, W)
def to_4d(x, h, w):
    """Reshape from (B, HW, C) to (B, C, H, W)"""
    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

##########################################################################
# MLP (flaoat16)
class MLP(nn.Module):

    def __init__(self, dim, dropout=0.1, bias=True):
        super().__init__()
        self.c_fc    = nn.Linear(dim, 4*dim, bias=bias, dtype=torch.float16)
        self.c_proj  = nn.Linear(4*dim, dim, bias=bias, dtype=torch.float16)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x.to(torch.float16)
        with torch.amp.autocast('cuda'):  # ✅ AMP 自動管理精度
            x = self.c_fc(x)
            x = F.gelu(x)
            x = self.c_proj(x)
            x = self.dropout(x)
        return x

In [None]:
"""Token statistics transformer: linear-time attention via variational rate reduction"""
##########################################################################
# ToST（Token Statistics Transformer） 版本的自注意力，取代傳統的 QK 相似性計算
class CausalSelfAttention_TSSA(nn.Module):

    def __init__(self, dim, num_heads = 8, block_size = 1024, dropout = 0.1, bias=False , dtype=torch.float16):
        super().__init__()
        
        # query, key, value projections
        self.c_attn = nn.Linear(dim, dim, bias=bias, dtype=dtype)
        # output projection
        self.c_proj = nn.Linear(dim, dim, bias=bias, dtype=dtype)
        # regularization
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.n_head = num_heads
        self.dim = dim
        self.dropout = dropout
        self.block_size = block_size
        self.temp = nn.Parameter(torch.ones((self.n_head, 1), dtype = dtype))
        self.denom_bias = nn.Parameter(torch.zeros((self.n_head, block_size, 1), dtype = dtype))
        
    def forward(self, x):
        """
        x: (B, N, C) - token 序列
        return: (B, N, C) - 經過 TSSA 處理的 token 序列
        """
        x = x.to(torch.float16) # 確保計算在 float16 上執行
        B, N, C = x.shape # batch size, sequence length, embedding dimensionality (dim)

        with torch.amp.autocast('cuda'):  # ✅ AMP 自動管理精度
            # calculate query, key, values for all heads in batch and move head forward to be the batch dim
            w = self.c_attn(x).view(B, N, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
            w_sq = w ** 2
            denom = (torch.cumsum(w_sq,dim=-2)).clamp_min(torch.finfo(torch.float16).eps) # cumulative sum
            w_normed = (w_sq / denom) + self.denom_bias[:,:N,:]
        
            # calculate attention weights
            tmp = torch.sum(w_normed, dim=-1)* self.temp
            Pi = F.softmax(tmp, dim=1) # B, nh, T
        
            # calculate attention
            dots = torch.cumsum(w_sq * Pi.unsqueeze(-1), dim=-2) / (Pi.cumsum(dim=-1) + torch.finfo(torch.float16).eps).unsqueeze(-1)
            attn = 1. / (1 + dots)
            attn = self.attn_dropout(attn)
        
            # apply attention weights and combine heads
            y = - torch.mul(w.mul(Pi.unsqueeze(-1)), attn)
            y = y.transpose(1, 2).contiguous().view(B, N, C) # re-assemble all head outputs side by side
            y = self.resid_dropout(self.c_proj(y))
            
        return y

##########################################################################
# ToST（Token Statistics Transformer）塊
class ToSTBlock(nn.Module):

    def __init__(self, dim = 1024, norm_type='WithBias'):
        super().__init__()
        self.ln_1 = Norms.Norm(dim, norm_type) # LayerNorm
        self.attn = CausalSelfAttention_TSSA(dim) # TSSA
        
        self.ln_2 = Norms.Norm(dim, norm_type) # LayerNorm
        self.mlp = MLP(dim)
        eta = torch.finfo(torch.float16).eps
        self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
        self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
    def forward(self, x):
        """
        x: (B, C, H, W) - 影像特徵圖
        return: (B, C, H, W) - 經過 ToST 處理的影像特徵圖
        """
        _, _, H, W = x.shape
        
        x = x + self.gamma1.view(1, -1, 1, 1) *to_4d(self.attn(self.ln_1(to_3d(x))), H, W)
        x = x + self.gamma2.view(1, -1, 1, 1) *to_4d(self.mlp(self.ln_2(to_3d(x))), H, W)
        return x

In [23]:
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler

# 設定測試參數
B = 2       # batch size
C = 1024    # 通道數 (與 dim 對應)
H = 32      # 高度
W = 32      # 寬度
dtype = torch.float16  # 減少內存佔用

# 建立 ToSTBlock
tost_block = ToSTBlock(dim=C)  # ✅ 放到 GPU，確保 float16

# 創建隨機輸入影像特徵 (B, C, H, W)
x = torch.randn(B, C, H, W, dtype=dtype)  # ✅ 確保輸入數據是 float16

# 設定 AMP（混合精度）
scaler = GradScaler(enabled=True)  # ✅ 允許 AMP

# 使用 AMP 進行前向運算
with autocast(dtype=torch.float16):
    y = tost_block(x)

# 測試 1: 檢查輸出形狀是否正確
assert y.shape == x.shape, f"ToSTBlock 輸出形狀錯誤！預期 {x.shape}，但得到 {y.shape}"

# 測試 2: 檢查是否有 NaN 或 Inf
assert not torch.isnan(y).any(), "ToSTBlock 輸出包含 NaN！"
assert not torch.isinf(y).any(), "ToSTBlock 輸出包含 Inf！"

# 測試 3: 反向傳播測試
optimizer = torch.optim.Adam(tost_block.parameters(), lr=1e-3)  # ✅ 建立優化器
optimizer.zero_grad()  # 清空梯度

with autocast(dtype=torch.float16):  # ✅ AMP 運行
    loss = y.mean()  # 假設損失函數是均值
scaler.scale(loss).backward()  # ✅ 使用 AMP 反向傳播
scaler.step(optimizer)  # ✅ AMP 更新權重
scaler.update()  # ✅ AMP 調整 scale

print("✅ ToSTBlock 測試通過，一切正常！")


  scaler = GradScaler(enabled=True)  # ✅ 允許 AMP
  with autocast(dtype=torch.float16):
  with torch.amp.autocast('cuda'):  # ✅ AMP 自動管理精度
  with torch.amp.autocast('cuda'):  # ✅ AMP 自動管理精度


輸入形狀： torch.Size([2, 1024, 32, 32])
to3d torch.Size([2, 1024, 1024])
NORM torch.Size([2, 1024, 1024])
CausalSelfAttention_TSSA torch.Size([2, 1024, 1024])
to4d torch.Size([2, 1024, 32, 32])
to3d torch.Size([2, 1024, 1024])
MLP torch.Size([2, 1024, 1024])
to4d torch.Size([2, 1024, 32, 32])


  with autocast(dtype=torch.float16):  # ✅ AMP 運行


✅ ToSTBlock 測試通過，一切正常！
