In [1]:
from itertools import product
from models.generator import SEMamba
import torch
from utils.util import load_config
import os
import librosa
from models.stfts import mag_phase_stft, mag_phase_istft
from models.pcs400 import cal_pcs

# Device on which the generator is instantiated and the checkpoint is mapped.
device = torch.device('cuda:1')

# Experiment artefacts: the YAML config and the generator checkpoint to restore.
config = '/disk4/chocho/SEMamba/exp/VCTK/dep3_h32_tf4_ds32_dc3_ex4/config.yaml'
checkpoint_file = '/disk4/chocho/SEMamba/exp/VCTK/dep3_h32_tf4_ds32_dc3_ex4/g_00025000.pth'
# Alternative experiment (swap in by uncommenting):
# config = '/disk4/chocho/SEMamba/exp/VCTK-400/dep4_h64_tf4_ds16_dc4_ex4/config.yaml'
# checkpoint_file = "/disk4/chocho/SEMamba/exp/VCTK-400/dep4_h64_tf4_ds16_dc4_ex4/g_00093000.pth"
cfg = load_config(config)

# STFT and magnitude-compression settings are read once here and reused by
# the analysis cells further down.
stft_cfg = cfg['stft_cfg']
n_fft = stft_cfg['n_fft']
hop_size = stft_cfg['hop_size']
win_size = stft_cfg['win_size']
compress_factor = cfg['model_cfg']['compress_factor']
sampling_rate = stft_cfg['sampling_rate']

# Build the generator on `device`, then restore the weights stored under the
# 'generator' key of the checkpoint.
model = SEMamba(cfg).to(device)
state_dict = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(state_dict['generator'])

# Switch to inference mode. Kept as the last expression so the notebook
# displays the module structure as the cell output.
model.eval()
# output_folder = '/disk4/chocho/SEMamba/_202505/encoder-mamba-decoder'
# os.makedirs(output_folder, exist_ok=True)

SEMamba(
  (dense_encoder): DenseEncoder(
    (dense_conv_1): Sequential(
      (0): Conv2d(2, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): PReLU(num_parameters=32)
    )
    (dense_block): DenseBlock(
      (dense_block): ModuleList(
        (0): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (2): PReLU(num_parameters=32)
        )
        (1): Sequential(
          (0): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(2, 1), dilation=(2, 1))
          (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (2): PReLU(num_parameters=32)
        )
        (2): Sequential(
          (0): Conv2d(96, 32, kernel_size=(3, 3), stride=(1, 1), padding=(4, 1), dilation=(4, 1))
          (1): 

In [2]:
# Expose the generator's sub-modules under short names so the forward pass can
# be replayed stage by stage in the cells below.
feature_encoder, mask_decoder, phase_decoder, TSMamba = (
    model.dense_encoder,
    model.mask_decoder,
    model.phase_decoder,
    model.TSMamba,
)

In [None]:
from c66 import pp, pps
from einops import rearrange
import matplotlib.pyplot as plt
import os

def _show_feature_map(n_rows, position, feature_map, label, **imshow_kwargs):
    """Draw one feature map into a slot of the ``n_rows x 6`` subplot grid.

    Args:
        n_rows: Total number of rows in the current figure's grid.
        position: 1-based subplot index (row-major, 6 columns per row).
        feature_map: 2-D array indexed as [time, frequency]; it is transposed
            before plotting so time runs along x and frequency along y.
        label: Title shown above the subplot.
        **imshow_kwargs: Extra keyword arguments forwarded to ``plt.imshow``
            (e.g. ``vmin`` / ``vmax`` to pin the colour range).
    """
    plt.subplot(n_rows, 6, position)
    plt.imshow(feature_map.T, aspect='auto', origin='lower', cmap='viridis', **imshow_kwargs)
    plt.colorbar(label='Amplitude')
    plt.title(label, fontsize=8)
    plt.xlabel('Time Frame')
    plt.ylabel('Frequency Bin')


# Replays the generator stage by stage (encoder -> TSMamba blocks -> decoders)
# for every file in the input folder, records selected channels after each
# stage, and saves them as one PNG grid per file.
with torch.no_grad():
    # for clean_or_noisy in ["clean", "noisy"]:
    for clean_or_noisy in ["noisy"]:

        input_folder = f'/disk4/chocho/SEMamba/_test_feature_map_{clean_or_noisy}'
        # input_folder = "_test_noisy"
        output_folder = f'/disk4/chocho/SEMamba/feature_maps_output_{clean_or_noisy}'

        # os.listdir returns entries in arbitrary order; sort so the processing
        # order (and the variables left behind for later cells) is reproducible.
        for fname in sorted(os.listdir(input_folder)):
            feature_maps = []
            labels = []
            print(input_folder, fname)
            noisy_wav, _ = librosa.load(os.path.join(input_folder, fname), sr=sampling_rate)
            noisy_wav = torch.FloatTensor(noisy_wav).to(device)

            # Scale factor sqrt(L / sum(x^2)), i.e. the waveform is brought to
            # unit RMS before the STFT and scaled back after the inverse STFT.
            norm_factor = torch.sqrt(len(noisy_wav) / torch.sum(noisy_wav ** 2.0)).to(device)
            # noisy_wav: [L] -> [1, L]
            noisy_wav = (noisy_wav * norm_factor).unsqueeze(0)

            noisy_mag, noisy_pha, noisy_com = mag_phase_stft(noisy_wav, n_fft, hop_size, win_size, compress_factor)
            # noisy_mag, noisy_pha: [1, F, T] according to the rearrange patterns below.

            # Reshape inputs for the encoder: [B, F, T] -> [B, 1, T, F].
            # The names are deliberately rebound because later cells read
            # `noisy_mag` / `noisy_pha` in this layout.
            noisy_mag = rearrange(noisy_mag, 'b f t -> b t f').unsqueeze(1)
            noisy_pha = rearrange(noisy_pha, 'b f t -> b t f').unsqueeze(1)

            # Row 1 of the figure: input magnitude and phase, each [T, F].
            feature_maps.append(noisy_mag[0, 0].detach().cpu().numpy())
            labels.append("Noisy Magnitude (after reshape)")
            feature_maps.append(noisy_pha[0, 0].detach().cpu().numpy())
            labels.append("Noisy Phase (after reshape)")

            # Concatenate magnitude and phase along the channel axis: [B, 2, T, F].
            x = torch.cat((noisy_mag, noisy_pha), dim=1)

            # Feature encoder.
            # NOTE(review): the original notes give the output as [B, h, T, F//2];
            # the exact frequency size is not visible here - confirm.
            x = feature_encoder(x)
            # Row 2 of the figure: first three encoder channels.
            for ch in range(3):
                feature_maps.append(x[0, ch].detach().cpu().numpy())
                labels.append(f"Dense Encoder Output (Channel {ch})")

            # TF-Mamba blocks, unrolled so the intermediate tensors can be captured.
            for idx, block in enumerate(TSMamba):
                b, c, t, f = x.size()

                # Block input, channels 0 and 1 (each [T, F]).
                for ch in range(2):
                    feature_maps.append(x[0, ch].detach().cpu().numpy())
                    labels.append(f"TSMamba Block {idx} Input (Channel {ch})")

                # Time branch: fold frequency into the batch axis -> [B*F, T, C].
                x = x.permute(0, 3, 2, 1).contiguous().view(b*f, t, c)
                x = block.tlinear( block.time_mamba(x).permute(0,2,1) ).permute(0,2,1) + x

                # Bring the time-branch output back to [B, C, T, F] for plotting only.
                x_reshaped = x.view(b, f, t, c).permute(0, 3, 2, 1)
                for ch in range(2):
                    feature_maps.append(x_reshaped[0, ch].detach().cpu().numpy())
                    labels.append(f"Time Mamba {idx} Output (Channel {ch})")

                # Frequency branch: fold time into the batch axis -> [B*T, F, C].
                x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b*t, f, c)
                x = block.flinear( block.freq_mamba(x).permute(0,2,1) ).permute(0,2,1) + x

                # Bring the frequency-branch output back to [B, C, T, F] for plotting only.
                x_reshaped = x.view(b, t, f, c).permute(0, 3, 1, 2)
                for ch in range(2):
                    feature_maps.append(x_reshaped[0, ch].detach().cpu().numpy())
                    labels.append(f"Freq Mamba {idx} Output (Channel {ch})")

                # Restore [B, C, T, F] as the input of the next block.
                x = x.view(b, t, f, c).permute(0, 3, 1, 2)

            # Output of the last block, channels 0 and 1. `x` is already
            # [B, C, T, F] here, so it is indexed directly.
            for ch in range(2):
                feature_maps.append(x[0, ch].detach().cpu().numpy())
                labels.append(f"TSMamba Block {len(TSMamba)-1} Output (Channel {ch})")

            # Run each decoder exactly once and reuse the result for both the
            # plots and the waveform reconstruction.
            mask_out = mask_decoder(x)        # [1, 1, T, F]
            phase_out = phase_decoder(x)      # [1, 1, T, F]
            masked_mag = mask_out * noisy_mag  # [1, 1, T, F]

            feature_maps.append(mask_out[0, 0].detach().cpu().numpy())
            labels.append("mask_decoder(x)")
            feature_maps.append(phase_out[0, 0].detach().cpu().numpy())
            labels.append("phase_decoder(x)")
            feature_maps.append(masked_mag[0, 0].detach().cpu().numpy())
            labels.append("denoised mag")

            # [1, 1, T, F] -> [1, F, T, 1] -> [1, F, T]
            denoised_mag = rearrange(masked_mag, 'b c t f -> b f t c').squeeze(-1)
            denoised_pha = rearrange(phase_out, 'b c t f -> b f t c').squeeze(-1)

            # Back to the time domain (original note: roughly [1, L]).
            audio_g = mag_phase_istft(denoised_mag, denoised_pha, n_fft, hop_size, win_size, compress_factor)

            # Undo the input normalisation.
            audio_g = audio_g / norm_factor

            # NOTE(review): cal_pcs is an opaque project helper; presumably a
            # post-processing step on the 1-D waveform - confirm.
            audio_g = cal_pcs(audio_g.squeeze().cpu().numpy())

            pps(audio_g)

            # Figure layout (6 columns per row):
            #   row 1                 : noisy magnitude, noisy phase       (2 plots)
            #   row 2                 : dense-encoder outputs              (3 plots)
            #   rows 3 .. 2+N         : one row per TSMamba block          (6 plots)
            #   row 3+N               : last-block output + decoder maps   (5 plots)
            # where N = len(TSMamba).
            total_rows = 3 + len(TSMamba)
            plt.figure(figsize=(18, 3 * total_rows))

            # Row 1: noisy magnitude and phase.
            for col in range(2):
                _show_feature_map(total_rows, 1 + col, feature_maps[col], labels[col])

            # Row 2: dense-encoder outputs (second row starts at slot 7).
            for col in range(3):
                _show_feature_map(total_rows, 7 + col, feature_maps[2 + col], labels[2 + col])

            # Rows 3..2+N: block input (cols 0-1), time-Mamba output (cols 2-3),
            # freq-Mamba output (cols 4-5). The six maps of a block are stored
            # consecutively in `feature_maps`, starting at index 5.
            start_idx = 5
            for block_idx in range(len(TSMamba)):
                row_start = 13 + block_idx * 6
                for col in range(6):
                    # Only the freq-Mamba panels use a fixed colour range.
                    color_range = {'vmin': -4, 'vmax': 4} if col >= 4 else {}
                    _show_feature_map(
                        total_rows,
                        row_start + col,
                        feature_maps[start_idx + col],
                        labels[start_idx + col],
                        **color_range,
                    )
                start_idx += 6

            # Last row: final block output (2) + mask, phase, denoised magnitude (3).
            row_start = 13 + len(TSMamba) * 6
            for col in range(5):
                _show_feature_map(total_rows, row_start + col, feature_maps[start_idx + col], labels[start_idx + col])

            # Overall title; leave headroom so it does not overlap the subplots.
            plt.suptitle(f"Feature Maps for {fname}", fontsize=14)
            plt.tight_layout(rect=[0, 0, 1, 0.96])

            # Save the figure and release it so figures do not accumulate in memory.
            os.makedirs(output_folder, exist_ok=True)
            plt.savefig(os.path.join(output_folder, f'feature_maps_{fname.split(".")[0]}_{clean_or_noisy}.png'), bbox_inches='tight')
            plt.close()

/disk4/chocho/SEMamba/_test_feature_map_noisy p226_018.wav
audio_g's shape: (94000,)
/disk4/chocho/SEMamba/_test_feature_map_noisy p227_376.wav
audio_g's shape: (53500,)
/disk4/chocho/SEMamba/_test_feature_map_noisy D4_754.wav
audio_g's shape: (214000,)
/disk4/chocho/SEMamba/_test_feature_map_noisy p226_016.wav
audio_g's shape: (124000,)
/disk4/chocho/SEMamba/_test_feature_map_noisy p230_073.wav
audio_g's shape: (54300,)
/disk4/chocho/SEMamba/_test_feature_map_noisy p287_417.wav
audio_g's shape: (41300,)


: 

In [None]:
from calflops import calculate_flops
# Measure FLOPs / MACs / parameter count of one forward pass with calflops,
# with autograd disabled.
# NOTE(review): `noisy_mag`, `noisy_pha` and `fname` are not defined in this
# cell; they are whatever the last iteration of the feature-map loop left in
# the kernel. By then both tensors have been rebound to the [B, 1, T, F]
# layout (rearrange + unsqueeze) - confirm that is the layout model.forward
# expects.
with torch.no_grad():
    flops, macs, params = calculate_flops(
        model=model,
        args=[noisy_mag, noisy_pha],  # passed as a list rather than a tuple
        print_results=True,  # print the per-layer breakdown
    )

# Summary lines.
print(f"Total FLOPs for {fname}: {flops}")
print(f"Total Params: {params}")
print(f"Total MACs: {macs}")

In [None]:
# Same FLOPs / MACs / parameter measurement as the previous cell, but run
# without the torch.no_grad() context.
# NOTE(review): relies on `calculate_flops`, `noisy_mag`, `noisy_pha` and
# `fname` already existing in the kernel from earlier cells.
flops, macs, params = calculate_flops(model=model, args=[noisy_mag, noisy_pha], print_results=True)

# One print call, newline-separated: emits the same three summary lines.
print(
    f"Total FLOPs for {fname}: {flops}",
    f"Total Params: {params}",
    f"Total MACs: {macs}",
    sep="\n",
)