In [1]:
from itertools import product
from models.generator import SEMamba
import torch
from utils.util import load_config
import os
import librosa
from models.stfts import mag_phase_stft

device = torch.device('cuda:1')
config = '/disk4/chocho/SEMamba/exp/VCTK-400/dep4_h64_tf4_ds16_dc4_ex4/config.yaml'
checkpoint_file = '/disk4/chocho/SEMamba/exp/VCTK-400/dep4_h64_tf4_ds16_dc4_ex4/g_00093000.pth'
cfg = load_config(config)
n_fft, hop_size, win_size = cfg['stft_cfg']['n_fft'], cfg['stft_cfg']['hop_size'], cfg['stft_cfg']['win_size']
compress_factor = cfg['model_cfg']['compress_factor']
sampling_rate = cfg['stft_cfg']['sampling_rate']
model = SEMamba(cfg).to(device)
state_dict = torch.load(checkpoint_file, map_location=device)
model.load_state_dict(state_dict['generator'])
model.eval()
output_folder = '/disk4/chocho/SEMamba/_202505/encoder-mamba-decoder'
os.makedirs(output_folder, exist_ok=True)

In [2]:
with torch.no_grad():
    for clean_or_noisy in  ["clean","noisy"]:
        input_folder = f'/disk4/chocho/SEMamba/_test_feature_map_{clean_or_noisy}'
        for i, fname in enumerate(os.listdir( input_folder )):
            print(input_folder, fname)
            noisy_wav, _ = librosa.load(os.path.join( input_folder, fname ), sr=sampling_rate)
            noisy_wav = torch.FloatTensor(noisy_wav).to(device)
            # exit()
            norm_factor = torch.sqrt(len(noisy_wav) / torch.sum(noisy_wav ** 2.0)).to(device)
            noisy_wav = (noisy_wav * norm_factor).unsqueeze(0)
            noisy_amp, noisy_pha, noisy_com = mag_phase_stft(noisy_wav, n_fft, hop_size, win_size, compress_factor)
            model.get_feature_map(output_folder, fname, noisy_amp, noisy_pha, clean_or_noisy=clean_or_noisy)

/disk4/chocho/SEMamba/_test_feature_map_clean p226_018.wav


/disk4/chocho/SEMamba/_test_feature_map_clean p227_376.wav
/disk4/chocho/SEMamba/_test_feature_map_clean D4_754.wav
/disk4/chocho/SEMamba/_test_feature_map_clean p226_016.wav
/disk4/chocho/SEMamba/_test_feature_map_clean p230_073.wav
/disk4/chocho/SEMamba/_test_feature_map_clean p287_417.wav
/disk4/chocho/SEMamba/_test_feature_map_noisy p226_018.wav
/disk4/chocho/SEMamba/_test_feature_map_noisy p227_376.wav
/disk4/chocho/SEMamba/_test_feature_map_noisy D4_754.wav
/disk4/chocho/SEMamba/_test_feature_map_noisy p226_016.wav
/disk4/chocho/SEMamba/_test_feature_map_noisy p230_073.wav
/disk4/chocho/SEMamba/_test_feature_map_noisy p287_417.wav


In [3]:
for name, module in model.named_children():
    print(f"{name}")

dense_encoder
TSMamba
mask_decoder
phase_decoder


In [5]:
feature_encoder = model.dense_encoder
mask_decoder = model.mask_decoder
phase_decoder = model.phase_decoder

In [18]:
model.TSMamba[0]

TFMambaBlock(
  (time_mamba): MambaBlock(
    (forward_blocks): ModuleList(
      (0): Block(
        (mixer): Mamba(
          (in_proj): Linear(in_features=64, out_features=512, bias=False)
          (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
          (act): SiLU()
          (x_proj): Linear(in_features=256, out_features=36, bias=False)
          (dt_proj): Linear(in_features=4, out_features=256, bias=True)
          (out_proj): Linear(in_features=256, out_features=64, bias=False)
        )
        (norm): RMSNorm()
      )
    )
    (backward_blocks): ModuleList(
      (0): Block(
        (mixer): Mamba(
          (in_proj): Linear(in_features=64, out_features=512, bias=False)
          (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
          (act): SiLU()
          (x_proj): Linear(in_features=256, out_features=36, bias=False)
          (dt_proj): Linear(in_features=4, out_features=256, bias=True)
      

In [9]:
time_mamba0 = model.TSMamba[0].time_mamba
time_mamba0

MambaBlock(
  (forward_blocks): ModuleList(
    (0): Block(
      (mixer): Mamba(
        (in_proj): Linear(in_features=64, out_features=512, bias=False)
        (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
        (act): SiLU()
        (x_proj): Linear(in_features=256, out_features=36, bias=False)
        (dt_proj): Linear(in_features=4, out_features=256, bias=True)
        (out_proj): Linear(in_features=256, out_features=64, bias=False)
      )
      (norm): RMSNorm()
    )
  )
  (backward_blocks): ModuleList(
    (0): Block(
      (mixer): Mamba(
        (in_proj): Linear(in_features=64, out_features=512, bias=False)
        (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
        (act): SiLU()
        (x_proj): Linear(in_features=256, out_features=36, bias=False)
        (dt_proj): Linear(in_features=4, out_features=256, bias=True)
        (out_proj): Linear(in_features=256, out_features=64, bias=False)
     

In [23]:
from c66 import pp, pps
pps(time_mamba0(torch.randn([10, 286, 64]).to("cuda:1")))
pps(model.TSMamba[0].tlinear(torch.randn([10, 128, 286]).to("cuda:1")))

time_mamba0(torch.randn([10, 286, 64]).to('cuda:1'))'s shape: torch.Size([10, 286, 128])
model.TSMamba[0].tlinear(torch.randn([10, 128, 286]).to('cuda:1'))'s shape: torch.Size([10, 64, 286])


In [10]:
freq_mamba0 = model.TSMamba[0].freq_mamba
freq_mamba0

MambaBlock(
  (forward_blocks): ModuleList(
    (0): Block(
      (mixer): Mamba(
        (in_proj): Linear(in_features=64, out_features=512, bias=False)
        (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
        (act): SiLU()
        (x_proj): Linear(in_features=256, out_features=36, bias=False)
        (dt_proj): Linear(in_features=4, out_features=256, bias=True)
        (out_proj): Linear(in_features=256, out_features=64, bias=False)
      )
      (norm): RMSNorm()
    )
  )
  (backward_blocks): ModuleList(
    (0): Block(
      (mixer): Mamba(
        (in_proj): Linear(in_features=64, out_features=512, bias=False)
        (conv1d): Conv1d(256, 256, kernel_size=(4,), stride=(1,), padding=(3,), groups=256)
        (act): SiLU()
        (x_proj): Linear(in_features=256, out_features=36, bias=False)
        (dt_proj): Linear(in_features=4, out_features=256, bias=True)
        (out_proj): Linear(in_features=256, out_features=64, bias=False)
     