In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
import Norms as Norms

In [2]:
##########################################################################
# 影像轉token序列 (B, C, H, W) to (B, HW, C)
def to_3d(x):
    """Flatten the spatial grid of a feature map into a token sequence.

    Args:
        x: array/tensor laid out as (B, C, H, W).

    Returns:
        The same data laid out as (B, H*W, C), i.e. one C-dimensional token
        per spatial position, with positions enumerated row by row.
    """
    # einops pattern: merge the two spatial axes and move channels last.
    image_to_tokens = 'b c h w -> b (h w) c'
    return rearrange(x, image_to_tokens)

##########################################################################
# token序列轉影像 (B, HW, C) to (B, C, H, W)
def to_4d(x, h, w):
    """Fold a token sequence back into a spatial feature map.

    Args:
        x: array/tensor laid out as (B, H*W, C).
        h: height of the target grid.
        w: width of the target grid.

    Returns:
        The same data laid out as (B, C, H, W).
    """
    # einops pattern: split the token axis into (h, w) and move channels first.
    tokens_to_image = 'b (h w) c -> b c h w'
    grid = {'h': h, 'w': w}
    return rearrange(x, tokens_to_image, **grid)

##########################################################################
# MLP (float16)
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear(dim -> 4*dim) -> GELU -> Linear(4*dim -> dim) -> Dropout.

    Args:
        dim: size of the last (feature) axis of the input and of the output.
        dropout: dropout probability applied to the output.
        bias: whether the two linear layers carry a bias term.
        dtype: parameter dtype of the two linear layers. Defaults to
            ``torch.float16``, which is what this module always used.
    """

    def __init__(self, dim, dropout=0.1, bias=True, dtype=torch.float16):
        super().__init__()
        self.c_fc    = nn.Linear(dim, 4*dim, bias=bias, dtype=dtype)
        self.c_proj  = nn.Linear(4*dim, dim, bias=bias, dtype=dtype)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the feed-forward network to ``x`` (any shape whose last axis is ``dim``).

        The input is cast to the dtype of the layer weights, so the module keeps
        working after ``.float()`` / ``.half()`` conversions. The result has the
        same shape as ``x``.
        """
        # Follow the weights instead of hard-coding float16: a dtype mismatch
        # between input and weight makes nn.Linear fail outside autocast.
        x = x.to(self.c_fc.weight.dtype)
        # torch.cuda.amp.autocast() is deprecated. CUDA autocast has no effect on
        # CPU tensors, so it is only enabled for CUDA inputs, which also avoids
        # the "CUDA is not available" warning on CPU-only machines.
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=x.is_cuda):
            x = self.c_fc(x)
            x = F.gelu(x)
            x = self.c_proj(x)
            x = self.dropout(x)
        return x

In [None]:
"""Token statistics transformer: linear-time attention via variational rate reduction"""
##########################################################################
# ToST（Token Statistics Transformer） 版本的自注意力，取代傳統的 QK 相似性計算
class CausalSelfAttention_TSSA(nn.Module):
    """Token Statistics Self-Attention (TSSA), the ToST replacement for QK-similarity attention.

    A single linear projection ``w`` of the tokens is used. All statistics are
    accumulated with cumulative sums along the token axis, so position ``t``
    only depends on positions ``<= t``.

    Args:
        dim: embedding size of the tokens; must be divisible by ``num_heads``.
        num_heads: number of heads the embedding is split into.
        block_size: maximum supported sequence length (size of ``denom_bias``).
        dropout: dropout probability for the attention weights and the output.
        bias: whether the two linear layers carry a bias term.
        dtype: dtype of all parameters; the input is cast to this dtype.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, block_size=1024, dropout=0.1, bias=False, dtype=torch.float16):
        super().__init__()

        # forward() splits the channel axis into (num_heads, dim // num_heads);
        # fail here with a clear message instead of inside .view().
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")

        # single token projection (TSSA has no separate query/key/value)
        self.c_attn = nn.Linear(dim, dim, bias=bias, dtype=dtype)
        # output projection
        self.c_proj = nn.Linear(dim, dim, bias=bias, dtype=dtype)
        # regularization
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)
        self.n_head = num_heads
        self.dim = dim
        self.dropout = dropout
        self.block_size = block_size
        # per-head temperature applied before the softmax
        self.temp = nn.Parameter(torch.ones((self.n_head, 1), dtype=dtype))
        # learned per-head, per-position additive bias (one row per position up to block_size)
        self.denom_bias = nn.Parameter(torch.zeros((self.n_head, block_size, 1), dtype=dtype))

    def forward(self, x):
        """
        x: (B, N, C) - token sequence, N <= block_size
        return: (B, N, C) - token sequence after TSSA

        Raises:
            RuntimeError: if N exceeds ``block_size``.
        """
        # Cast to the parameter dtype rather than a hard-coded float16, so a
        # module built with another ``dtype`` does not hit a dtype mismatch.
        x = x.to(self.c_attn.weight.dtype)
        B, N, C = x.shape  # batch size, sequence length, embedding dimensionality (dim)

        # denom_bias only has block_size rows; without this check a longer input
        # fails with an unrelated-looking broadcasting error (or, when
        # block_size == 1, silently reuses the single bias row for every position).
        if N > self.block_size:
            raise RuntimeError(
                f"sequence length {N} exceeds block_size {self.block_size}"
            )

        # Numerical floor for the divisions below (kept at float16 epsilon as before).
        eps = torch.finfo(torch.float16).eps

        # torch.cuda.amp.autocast() is deprecated. CUDA autocast has no effect on
        # CPU tensors, so it is only enabled for CUDA inputs.
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=x.is_cuda):
            # project the tokens and move the head axis forward: (B, nh, N, hs)
            w = self.c_attn(x).view(B, N, self.n_head, C // self.n_head).transpose(1, 2)
            w_sq = w ** 2
            # running (causal) sum of squared projections along the token axis
            denom = (torch.cumsum(w_sq, dim=-2)).clamp_min(eps)
            w_normed = (w_sq / denom) + self.denom_bias[:, :N, :]

            # per-token scores, normalised with a softmax across the head axis (dim=1)
            tmp = torch.sum(w_normed, dim=-1) * self.temp
            Pi = F.softmax(tmp, dim=1)  # (B, nh, N)

            # Pi-weighted running second moment, normalised by the running sum of Pi
            dots = torch.cumsum(w_sq * Pi.unsqueeze(-1), dim=-2) / (Pi.cumsum(dim=-1) + eps).unsqueeze(-1)
            attn = 1. / (1 + dots)
            attn = self.attn_dropout(attn)

            # apply the weights to the projected tokens and combine heads
            y = - torch.mul(w.mul(Pi.unsqueeze(-1)), attn)
            y = y.transpose(1, 2).contiguous().view(B, N, C)  # re-assemble all head outputs side by side
            y = self.resid_dropout(self.c_proj(y))

        return y

##########################################################################
# ToST（Token Statistics Transformer）塊
class ToSTBlock(nn.Module):
    """ToST (Token Statistics Transformer) block operating on image feature maps.

    Structure: norm -> TSSA attention -> scaled residual, then
    norm -> MLP -> scaled residual. The feature map is flattened to a token
    sequence for each sub-layer and folded back afterwards.

    Args:
        dim: channel count of the feature map (token embedding size).
        norm_type: forwarded to ``Norms.Norm``.
        num_heads: number of attention heads (forwarded to the TSSA layer).
        block_size: maximum number of tokens, i.e. the largest supported H*W
            (forwarded to the TSSA layer).
        dropout: dropout probability used by the TSSA layer and the MLP.

    The last three arguments default to the values the sub-layers previously
    used implicitly, so existing ``ToSTBlock(dim, norm_type)`` calls build the
    same block.
    """

    def __init__(self, dim=1024, norm_type='WithBias', num_heads=8, block_size=1024, dropout=0.1):
        super().__init__()
        self.ln_1 = Norms.Norm(dim, norm_type)  # normalisation before attention
        self.attn = CausalSelfAttention_TSSA(
            dim, num_heads=num_heads, block_size=block_size, dropout=dropout
        )  # TSSA

        self.ln_2 = Norms.Norm(dim, norm_type)  # normalisation before the MLP
        self.mlp = MLP(dim, dropout=dropout)

        # Per-channel learnable residual scales, initialised to float16 machine
        # epsilon so both branches start out contributing almost nothing.
        eta = torch.finfo(torch.float16).eps
        self.gamma1 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)
        self.gamma2 = nn.Parameter(eta * torch.ones(dim), requires_grad=True)

    def forward(self, x):
        """
        x: (B, C, H, W) - image feature map, with H*W <= block_size
        return: (B, C, H, W) - feature map after the ToST block
        """
        _, _, H, W = x.shape

        # attention branch: image -> tokens -> norm -> TSSA -> image
        attn_out = to_4d(self.attn(self.ln_1(to_3d(x))), H, W)
        x = x + self.gamma1.view(1, -1, 1, 1) * attn_out

        # feed-forward branch, applied to the updated feature map
        mlp_out = to_4d(self.mlp(self.ln_2(to_3d(x))), H, W)
        x = x + self.gamma2.view(1, -1, 1, 1) * mlp_out
        return x

In [23]:
import torch
import torch.nn as nn
# torch.cuda.amp.autocast / torch.cuda.amp.GradScaler are deprecated; the
# device-agnostic torch.amp versions take the device type as first argument.
from torch.amp import autocast, GradScaler

# Test parameters
B = 2       # batch size
C = 1024    # channel count (passed to the block as dim)
H = 32      # height
W = 32      # width
dtype = torch.float16  # half-precision input to reduce memory use

# Build the ToSTBlock. No .to(device) is applied here, so the block stays on
# the default device.
tost_block = ToSTBlock(dim=C)

# Random input feature map (B, C, H, W) in float16
x = torch.randn(B, C, H, W, dtype=dtype)

# AMP gradient scaler. When CUDA is unavailable PyTorch warns and disables it.
scaler = GradScaler("cuda", enabled=True)

# Forward pass under CUDA autocast (disabled with a warning when CUDA is unavailable)
with autocast("cuda", dtype=torch.float16):
    y = tost_block(x)

# Test 1: output shape matches the input shape
assert y.shape == x.shape, f"ToSTBlock 輸出形狀錯誤！預期 {x.shape}，但得到 {y.shape}"

# Test 2: no NaN or Inf in the output
assert not torch.isnan(y).any(), "ToSTBlock 輸出包含 NaN！"
assert not torch.isinf(y).any(), "ToSTBlock 輸出包含 Inf！"

# Test 3: backward pass and one optimizer step
optimizer = torch.optim.Adam(tost_block.parameters(), lr=1e-3)
optimizer.zero_grad()  # clear gradients

with autocast("cuda", dtype=torch.float16):
    loss = y.mean()  # dummy loss: mean of the output
# NOTE(review): the block's linear layers are created in float16, and an
# enabled GradScaler refuses to unscale float16 gradients; these steps are
# presumably only exercised with the scaler disabled (no CUDA) - confirm
# before running this cell on a GPU.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

print("✅ ToSTBlock 測試通過，一切正常！")


  scaler = GradScaler(enabled=True)  # ✅ 允許 AMP
  with autocast(dtype=torch.float16):
  with torch.cuda.amp.autocast():  # ✅ AMP 自動管理精度
  with torch.cuda.amp.autocast():  # ✅ AMP 自動管理精度


輸入形狀： torch.Size([2, 1024, 32, 32])
to3d torch.Size([2, 1024, 1024])
NORM torch.Size([2, 1024, 1024])
CausalSelfAttention_TSSA torch.Size([2, 1024, 1024])
to4d torch.Size([2, 1024, 32, 32])
to3d torch.Size([2, 1024, 1024])
MLP torch.Size([2, 1024, 1024])
to4d torch.Size([2, 1024, 32, 32])


  with autocast(dtype=torch.float16):  # ✅ AMP 運行


✅ ToSTBlock 測試通過，一切正常！


In [5]:
# Scratch cell: draw a small random bias with one scalar per head, in half precision.
dtype = torch.float16
num_heads = 100
bias_shape = (1, num_heads, 1, 1)
position_bias = 0.01 * torch.randn(bias_shape, dtype=dtype)

In [6]:
# Show the bias tensor's shape, then its dtype
# (expected: torch.Size([1, 100, 1, 1]) and torch.float16).
for described in (position_bias.shape, position_bias.dtype):
    print(described)

torch.Size([1, 100, 1, 1])
torch.float16


In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import Norms as Norms

In [2]:
##########################################################################
# dilated dense residual block (DDRB)
class DDRB(nn.Module):
    """
    Dilated Dense Residual Block.

    Three conv-ReLU-conv branches with different dilation rates. Each branch
    receives the block input plus the outputs of all previous branches, and
    the block returns the sum of the input and all branch outputs.

    Args:
        in_channels: channels of the input feature map.
        mid_channels: channels produced by every branch. The residual sums
            ``x + x1`` etc. require ``mid_channels == in_channels``.
        kernel: convolution kernel size.
        stride: convolution stride.
        d: dilation rates of the three branches (only the first three entries
            are used); also used as padding.
        bias: whether the convolutions carry a bias term.

    Usage:
        self.ddrb = DDRB(in_channels=32, mid_channels=32, kernel=3, stride=1, d=[1, 2, 5], bias=False)
    """
    def __init__(self,
                 in_channels=32,
                 mid_channels=32,
                 kernel=3,
                 stride=1,
                 d=(1, 2, 5),  # immutable default; any indexable sequence is accepted
                 bias=False):
        super(DDRB, self).__init__()

        def dilated_branch(rate):
            # conv -> ReLU -> conv with padding equal to the dilation rate; for
            # kernel=3 and stride=1 this keeps the spatial size unchanged.
            return nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, kernel, stride, padding=rate, dilation=rate, bias=bias),
                nn.ReLU(inplace=True),
                nn.Conv2d(mid_channels, mid_channels, kernel, stride, padding=rate, dilation=rate, bias=bias)
            )

        # Built in the original order so attribute names, state_dict keys and
        # the order of weight initialisation are unchanged.
        self.convD1 = dilated_branch(d[0])  # dilation=1 with the default d
        self.convD2 = dilated_branch(d[1])  # dilation=2 with the default d
        self.convD3 = dilated_branch(d[2])  # dilation=5 with the default d

    def forward(self, x):
        """
        Args:
            x: input feature map
        Returns:
            enhanced feature map (input plus the three branch outputs)
        Usage:
            enhanced_feature = DDRB(input_feature)
        """
        # torch.cuda.amp.autocast() is deprecated. CUDA autocast has no effect on
        # CPU tensors, so it is only enabled for CUDA inputs.
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=x.is_cuda):
            x1 = self.convD1(x)
            x2 = self.convD2(x+x1)
            x3 = self.convD3(x+x1+x2)

        return x + x1 + x2 + x3
    
##########################################################################
# enhanced residual pixel-wise attention block (ERPAB)
class ERPAB(nn.Module):
    """
    Enhanced Residual Pixel-wise Attention Block.

    Three parallel dilated convolutions ("experts") are concatenated, fused by
    a 3x3 convolution + ReLU, and the fused features are gated by a sigmoid
    attention map before being added back to the input.

    Args:
        in_channels: channels of the input feature map.
        mid_channels: channels produced by each expert and by the fusion conv.
            The residual ``x + x1 * gate`` requires ``mid_channels == in_channels``.
        kernel: kernel size of the expert convolutions.
        stride: stride of the expert convolutions.
        d: dilation rates of the three experts (only the first three entries
            are used); also used as padding.
        bias: whether the expert convolutions carry a bias term (the fusion and
            attention convolutions never do).

    Usage:
        self.erpab = ERPAB(in_channels=32, mid_channels=32, kernel=3, stride=1, d=[1, 2, 5], bias=False)
    """
    def __init__(self,
                 in_channels=32,
                 mid_channels=32,
                 kernel=3,
                 stride=1,
                 d=(1, 2, 5),  # immutable default; any indexable sequence is accepted
                 bias=False):
        super(ERPAB, self).__init__()

        # One dilated convolution per rate (C32D1, C32D2, C32D5 with the defaults).
        self.experts = nn.ModuleList([
            nn.Conv2d(in_channels, mid_channels, kernel, stride, padding=rate, dilation=rate, bias=bias)
            for rate in (d[0], d[1], d[2])
        ])

        # Fuses the concatenated expert outputs back to mid_channels.
        self.conv1 = nn.Conv2d(mid_channels*3, mid_channels, kernel_size=3, padding=1, bias=False)
        # Squeezes to a single-channel map, then expands to one gate per input channel.
        self.attn_map = nn.Sequential(
            nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(1, in_channels, kernel_size=3, padding=1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Args:
            x: input feature map
        Returns:
            enhanced feature map (input plus the attention-gated fused features)
        Usage:
            enhanced_feature = ERPAB(input_feature)
        """
        # torch.cuda.amp.autocast() is deprecated. CUDA autocast has no effect on
        # CPU tensors, so it is only enabled for CUDA inputs.
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=x.is_cuda):
            expert_outputs = torch.cat([expert(x) for expert in self.experts], dim=1)
            x1 = F.relu(self.conv1(expert_outputs))
            attn_map = self.attn_map(x1)

        return x + x1 * self.sigmoid(attn_map)

##########################################################################
# cross-stage feature interaction module (CFIM)
class CFIM(nn.Module):
    """
    Cross-Stage Feature Interaction Module.

    Exchanges information between the rain-streak-removal branch and the
    detail-reconstruction branch through a shared interaction matrix.

    Args:
        in_channels: channels of both incoming feature maps.
        norm_type: forwarded to ``Norms.Norm`` (e.g. 'DyT', 'WithBias' or 'BiasFree').

    Usage:
        self.cfim = CFIM(in_channels=32, norm_type = 'DyT' or 'WithBias' or 'BiasFree')
    """
    def __init__(self, in_channels, norm_type='DyT'):
        super(CFIM, self).__init__()
        self.norm1 = Norms.Norm(in_channels, norm_type)
        self.norm2 = Norms.Norm(in_channels, norm_type)
        self.rsconv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
        self.rsconv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
        self.drconv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)
        self.drconv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, bias=False)

    def forward(self, r_net, dr_net):
        """
        Args:
            r_net: rain streaks removal network intermediate output
            dr_net: details reconstruction network intermediate output
                (same shape as ``r_net``)
        Returns:
            to_rs_net: updated rain streaks removal network intermediate output
            to_dr_net: updated details reconstruction network intermediate output
        Note:
            ``torch.matmul`` multiplies the last two axes, so for (B, C, H, W)
            inputs the interaction below is a per-channel H x W matrix product
            and only works for square feature maps (H == W).
        Usage:
            to_rs_net, to_dr_net = CFIM(r_net, dr_net)
        """
        # torch.cuda.amp.autocast() is deprecated. CUDA autocast has no effect on
        # CPU tensors, so it is only enabled for CUDA inputs.
        with torch.autocast(device_type="cuda", dtype=torch.float16, enabled=r_net.is_cuda):
            rs1 = self.rsconv1(self.norm1(r_net))
            dr1 = self.drconv1(self.norm2(dr_net))
            # interaction matrix shared by both directions
            A = torch.matmul(rs1, dr1)
            rs2 = self.rsconv2(rs1)
            dr2 = self.drconv2(dr1)
            rs_side = torch.matmul(A, rs2)
            dr_side = torch.matmul(A, dr2)
            # cross connection: each branch receives the other branch's side output
            to_rs_net = dr_side + r_net
            to_dr_net = rs_side + dr_net

        return to_rs_net, to_dr_net

In [3]:
def test_modules():
    """Smoke-test DDRB, ERPAB and CFIM on random 32-channel 64x64 feature maps.

    Prints each output shape and asserts that every module preserves the shape
    of its input, so a regression fails loudly instead of only changing the
    printed text.
    """
    x = torch.randn(1, 32, 64, 64)  # Example input
    ddrb = DDRB(in_channels=32)
    erpab = ERPAB(in_channels=32)
    cfim = CFIM(in_channels=32)

    ddrb_out = ddrb(x)
    print("DDRB Output Shape:", ddrb_out.shape)
    assert ddrb_out.shape == x.shape, "DDRB changed the feature-map shape"

    erpab_out = erpab(x)
    print("ERPAB Output Shape:", erpab_out.shape)
    assert erpab_out.shape == x.shape, "ERPAB changed the feature-map shape"

    x1 = torch.randn(1, 32, 64, 64)
    x2 = torch.randn(1, 32, 64, 64)
    to_rs_net, to_dr_net = cfim(x1, x2)
    print("CFIM Output Shape to Rs Net:", to_rs_net.shape)
    print("CFIM Output Shape to Dr Net:", to_dr_net.shape)
    assert to_rs_net.shape == x1.shape, "CFIM changed the shape of the rain-streak branch"
    assert to_dr_net.shape == x2.shape, "CFIM changed the shape of the detail branch"

if __name__ == "__main__":
    test_modules()

DDRB Output Shape: torch.Size([1, 32, 64, 64])
ERPAB Output Shape: torch.Size([1, 32, 64, 64])
CFIM Output Shape to Rs Net: torch.Size([1, 32, 64, 64])
CFIM Output Shape to Dr Net: torch.Size([1, 32, 64, 64])


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


In [8]:

##########################################################################
# DPENet_v2 with CFIM
class DPENet_CFIM(nn.Module):
    """DPENet_v2 with cross-stage feature interaction (CFIM).

    Alternates rain-streak removal (stacks of DDRB) and detail reconstruction
    (ERPAB), exchanging features between the two through one CFIM module.

    Args:
        in_channels: channels of the input image and of both outputs.
        mid_channels: channels of the internal feature maps.
        kernel: kernel size forwarded to the DDRB / ERPAB blocks.
        stride: stride forwarded to the DDRB / ERPAB blocks.
        dilation_list: dilation rates forwarded to the DDRB / ERPAB blocks.
        bias: whether the 1x1 in/out convolutions and the blocks use a bias.
    """
    def __init__(self,
                 in_channels=3,
                 mid_channels=32,
                 kernel=3,
                 stride=1,
                 dilation_list=(1, 2, 5),  # immutable default; any indexable sequence is accepted
                 bias=False):
        super(DPENet_CFIM, self).__init__()

        # 1x1 convolutions mapping image <-> feature space around each stage
        self.inconv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, padding=0, bias=bias)
        self.outconv1 = nn.Conv2d(mid_channels, in_channels, kernel_size=1, padding=0, bias=bias)
        self.inconv2 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, padding=0, bias=bias)
        self.outconv2 = nn.Conv2d(mid_channels, in_channels, kernel_size=1, padding=0, bias=bias)
        self.inconv3 = nn.Conv2d(in_channels, mid_channels, kernel_size=1, padding=0, bias=bias)
        self.outconv3 = nn.Conv2d(mid_channels, in_channels, kernel_size=1, padding=0, bias=bias)

        # Rain-streak removal stacks: five DDRB blocks each
        self.ddrb1 = nn.Sequential(*[DDRB(mid_channels, mid_channels, kernel, stride, dilation_list, bias) for _ in range(5)])
        self.ddrb2 = nn.Sequential(*[DDRB(mid_channels, mid_channels, kernel, stride, dilation_list, bias) for _ in range(5)])

        # erpab1 is one shared instance: forward() applies it in both Stage 2
        # and Stage 4, so those stages use the same weights.
        self.erpab1 = ERPAB(mid_channels, mid_channels, kernel, stride, dilation_list, bias)
        self.erpab2 = nn.Sequential(*[ERPAB(mid_channels, mid_channels, kernel, stride, dilation_list, bias) for _ in range(2)])

        # Single CFIM instance, also used twice in forward()
        self.cfim = CFIM(mid_channels)

    def forward(self, x):
        """Run the four-stage de-raining pipeline.

        Args:
            x: input image batch with ``in_channels`` channels.
                NOTE(review): CFIM multiplies feature maps with torch.matmul over
                the last two axes, so H == W appears to be required - confirm.

        Returns:
            (x_rain_removed, x_final): the output after the second rain-removal
            stage and the final output after detail enhancement, both with the
            same shape as ``x``.
        """
        input_ = x

        # Stage 1: Initial Rain Streaks Removal
        x = self.inconv1(x)
        rs1 = self.ddrb1(x)
        x = self.outconv1(rs1)
        x_mid = x + input_  # Residual connection

        # Stage 2: Initial Detail Reconstruction
        x = self.inconv2(F.relu(x_mid))
        dr1 = self.erpab1(x)

        # Cross-stage Feature Interaction (only the rain-streak side is kept)
        rs2, _ = self.cfim(rs1, dr1)

        # Stage 3: Further Rain Streaks Removal
        x = self.ddrb2(rs2)
        x = self.outconv2(x)
        x_rain_removed = x + x_mid  # Residual connection

        # Stage 4: Further Detail Reconstruction
        x = self.inconv3(x_rain_removed)
        dr2 = self.erpab1(x)

        # Cross-stage Feature Interaction (only the detail side is kept).
        # NOTE(review): this reuses the Stage 1 features rs1 rather than the
        # Stage 3 features - presumably intentional, confirm.
        _, dr3 = self.cfim(rs1, dr2)

        # Final Detail Enhancement
        x = self.erpab2(dr3)
        x = self.outconv3(x)
        x_final = x + x_rain_removed  # Residual connection

        return x_rain_removed, x_final

In [9]:
def test_dpenet_cfim():
    """Smoke-test DPENet_CFIM on one random 3-channel 64x64 image.

    Prints both output shapes and asserts that each matches the input shape,
    so a regression fails loudly instead of only changing the printed text.
    """
    model = DPENet_CFIM()
    test_input = torch.randn(1, 3, 64, 64)  # Batch size = 1, 3 channels, 64x64 image
    output_rain_removed, output_final = model(test_input)
    print("Output shape (rain removed):", output_rain_removed.shape)
    print("Output shape (final reconstruction):", output_final.shape)
    assert output_rain_removed.shape == test_input.shape, "rain-removed output changed shape"
    assert output_final.shape == test_input.shape, "final output changed shape"

if __name__ == "__main__":
    test_dpenet_cfim()

Output shape (rain removed): torch.Size([1, 3, 64, 64])
Output shape (final reconstruction): torch.Size([1, 3, 64, 64])


  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():
  with torch.cuda.amp.autocast():


In [10]:
import torch

# Report the installed PyTorch version for reproducibility.
torch_build = torch.__version__
print(torch_build)


2.6.0+cpu


In [11]:
import sys

# Report the Python interpreter build for reproducibility.
python_build = sys.version
print(python_build)

3.10.16 | packaged by Anaconda, Inc. | (main, Dec 11 2024, 16:19:12) [MSC v.1929 64 bit (AMD64)]
