### **Drive - Data Set**

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### **Clone Paper Code**

In [2]:
# Clone the repository
!git lfs install
!git clone https://github.com/dmlguq456/SepReformer.git
%cd SepReformer

# Install dependencies
!pip install -r requirements.txt
# Note: The repo recommends Python 3.10, which is currently the Colab standard.

Git LFS initialized.
Cloning into 'SepReformer'...
remote: Enumerating objects: 673, done.[K
remote: Counting objects: 100% (173/173), done.[K
remote: Compressing objects: 100% (47/47), done.[K
remote: Total 673 (delta 143), reused 126 (delta 126), pack-reused 500 (from 1)[K
Receiving objects: 100% (673/673), 8.42 MiB | 18.08 MiB/s, done.
Resolving deltas: 100% (416/416), done.
/content/SepReformer
Collecting absl-py==2.1.0 (from -r requirements.txt (line 1))
  Downloading absl_py-2.1.0-py3-none-any.whl.metadata (2.3 kB)
Collecting audioread==3.0.1 (from -r requirements.txt (line 2))
  Downloading audioread-3.0.1-py3-none-any.whl.metadata (8.4 kB)
Collecting certifi==2024.8.30 (from -r requirements.txt (line 3))
  Downloading certifi-2024.8.30-py3-none-any.whl.metadata (2.2 kB)
Collecting cffi==1.17.1 (from -r requirements.txt (line 4))
  Downloading cffi-1.17.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting charset-normalizer==3.4.0 (from -

In [3]:
import os
import random

# Configuration
base_path = '/content/drive/MyDrive/◊™◊ï◊ê◊®/speech'
subfolders = ['s1', 's2', 'mix_clean']
scp_output_dir = '/content/SepReformer/data/custom'
os.makedirs(scp_output_dir, exist_ok=True)

# 1. Get list of file names (assuming they are identical across s1, s2, and mix)
# We sort them first to ensure indices match across all subfolders
files = sorted([f for f in os.listdir(os.path.join(base_path, 's1')) if f.endswith('.wav')])

# 2. Shuffle and Split (80/20)
random.seed(42) # For reproducibility
random.shuffle(files)
split_idx = int(len(files) * 0.8)

train_files = files[:split_idx]
val_files = files[split_idx:]

# 3. Generate the SCP files
def create_scp(file_list, set_name):
    for folder in subfolders:
        output_path = os.path.join(scp_output_dir, f"{set_name}_{folder}.scp")
        with open(output_path, 'w') as f:
            for filename in file_list:
                full_path = os.path.join(base_path, folder, filename)
                # ID format: filename (must be unique)
                f.write(f"{filename} {full_path}\n")
        print(f"Created {output_path} with {len(file_list)} files.")

create_scp(train_files, "train")
create_scp(val_files, "val")

Created /content/SepReformer/data/custom/train_s1.scp with 160 files.
Created /content/SepReformer/data/custom/train_s2.scp with 160 files.
Created /content/SepReformer/data/custom/train_mix_clean.scp with 160 files.
Created /content/SepReformer/data/custom/val_s1.scp with 40 files.
Created /content/SepReformer/data/custom/val_s2.scp with 40 files.
Created /content/SepReformer/data/custom/val_mix_clean.scp with 40 files.


In [4]:
!pip install loguru torchinfo ptflops thop mir_eval

Collecting loguru
  Downloading loguru-0.7.3-py3-none-any.whl.metadata (22 kB)
Collecting torchinfo
  Downloading torchinfo-1.8.0-py3-none-any.whl.metadata (21 kB)
Collecting ptflops
  Downloading ptflops-0.7.5-py3-none-any.whl.metadata (9.4 kB)
Collecting thop
  Using cached thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Collecting mir_eval
  Downloading mir_eval-0.8.2-py3-none-any.whl.metadata (3.0 kB)
Downloading loguru-0.7.3-py3-none-any.whl (61 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m61.6/61.6 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading torchinfo-1.8.0-py3-none-any.whl (23 kB)
Downloading ptflops-0.7.5-py3-none-any.whl (19 kB)
Downloading thop-0.1.1.post2209072238-py3-none-any.whl (15 kB)
Downloading mir_eval-0.8.2-py3-none-any.whl (102 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚î

In [5]:
config_content = """
project: "[Project] SepReformer"
notes: "Training on custom dataset - Fixed Scheduler"

config:
    dataset:
        max_len: 32000
        sampling_rate: 8000
        scp_dir: "data/custom"
        train:
            mixture: "train_mix_clean.scp"
            spk1: "train_s1.scp"
            spk2: "train_s2.scp"
            dynamic_mixing: true
        valid:
            mixture: "val_mix_clean.scp"
            spk1: "val_s1.scp"
            spk2: "val_s2.scp"
        test:
            mixture: "val_mix_clean.scp"
            spk1: "val_s1.scp"
            spk2: "val_s2.scp"

    dataloader:
        batch_size: 1
        pin_memory: true
        num_workers: 4
        drop_last: true

    model:
        num_stages: 4
        num_spks: 2
        module_audio_enc:
            in_channels: 1
            out_channels: 256
            kernel_size: 16
            stride: 4
            groups: 1
            bias: false
        module_feature_projector:
            num_channels: 256
            in_channels: 256
            out_channels: 128
            kernel_size: 1
            bias: false
        module_separator:
            num_stages: 4
            relative_positional_encoding:
                in_channels: 128
                num_heads: 8
                maxlen: 2000
                embed_v: false
            enc_stage:
                global_blocks:
                    in_channels: 128
                    num_mha_heads: 8
                    dropout_rate: 0.05
                local_blocks:
                    in_channels: 128
                    kernel_size: 65
                    dropout_rate: 0.05
                down_conv_layer:
                    in_channels: 128
                    samp_kernel_size: 5
            spk_split_stage:
                in_channels: 128
                num_spks: 2
            simple_fusion:
                out_channels: 128
            dec_stage:
                num_spks: 2
                global_blocks:
                    in_channels: 128
                    num_mha_heads: 8
                    dropout_rate: 0.05
                local_blocks:
                    in_channels: 128
                    kernel_size: 65
                    dropout_rate: 0.05
                spk_attention:
                    in_channels: 128
                    num_mha_heads: 8
                    dropout_rate: 0.05
        module_output_layer:
            in_channels: 256
            out_channels: 128
            num_spks: 2
        module_audio_dec:
            in_channels: 256
            out_channels: 1
            kernel_size: 16
            stride: 4
            bias: false

    criterion:
        name: ["PIT_SISNR_mag", "PIT_SISNR_time", "PIT_SISNRi", "PIT_SDRi"]
        PIT_SISNR_mag:
            frame_length: 512
            frame_shift: 128
            window: 'hann'
            num_stages: 4
            num_spks: 2
            scale_inv: true
            mel_opt: false
        PIT_SISNR_time:
            num_spks: 2
            scale_inv: true
        PIT_SISNRi:
            num_spks: 2
            scale_inv: true
        PIT_SDRi:
            dump: 0

    optimizer:
        name: ["AdamW"]
        AdamW:
            lr: 1.0e-3
            weight_decay: 1.0e-2

    scheduler:
        # Restored "WarmupConstantSchedule" here
        name: ["ReduceLROnPlateau", "WarmupConstantSchedule"]
        ReduceLROnPlateau:
            mode: "min"
            min_lr: 1.0e-10
            factor: 0.8
            patience: 2
        WarmupConstantSchedule:
            warmup_steps: 100

    check_computations:
        dummy_len: 16000

    engine:
        max_epoch: 90
        gpuid: "0"
        mvn: false
        clip_norm: 5
        start_scheduling: 10
        test_epochs: [50, 90]
"""

with open("models/SepReformer_Base_WSJ0/configs.yaml", "w") as f:
    f.write(config_content)
print("Config file updated with restored Scheduler settings!")
!rm -rf models/SepReformer_Base_WSJ0/log

#_______________________________________________________________________________

import os

# Path to the file
engine_path = '/content/SepReformer/models/SepReformer_Base_WSJ0/engine.py'

# Read the file
with open(engine_path, 'r') as f:
    lines = f.readlines()

# Fix the specific line
new_lines = []
found = False
for line in lines:
    # Look for the save_checkpoint_per_best call
    if 'util_engine.save_checkpoint_per_best' in line:
        # We need to remove the last argument (wandb_run or getattr...)
        # The correct call should end after 'self.checkpoint_path'
        # Let's construct the correct line based on the error message.
        # It takes 7 args: valid_loss_best, val_loss, train_loss, epoch, model, optimizer, path

        indent = line.split("valid_loss_best")[0] # Preserve indentation
        new_line = (
            f"{indent}valid_loss_best = util_engine.save_checkpoint_per_best("
            f"valid_loss_best, valid_loss_src_time, train_loss_src_time, epoch, "
            f"self.model, self.main_optimizer, self.checkpoint_path)\n"
        )
        new_lines.append(new_line)
        found = True
        print("Fixed line to 7 arguments.")
    else:
        new_lines.append(line)

if found:
    with open(engine_path, 'w') as f:
        f.writelines(new_lines)
    print("Successfully patched engine.py!")
else:
    print("Could not find the specific line to patch. Please check file content.")

#_______________________________________________________________________________

import os

# ◊î◊†◊™◊ô◊ë ◊ú◊ß◊ï◊ë◊•
engine_path = '/content/SepReformer/models/SepReformer_Base_WSJ0/engine.py'

# ◊ß◊®◊ô◊ê◊™ ◊î◊ß◊ï◊ë◊•
with open(engine_path, 'r') as f:
    lines = f.readlines()

# ◊™◊ô◊ß◊ï◊ü ◊î◊©◊ï◊®◊î
new_lines = []
found = False
for line in lines:
    # ◊û◊ó◊§◊©◊ô◊ù ◊ê◊™ ◊î◊©◊ï◊®◊î ◊©◊û◊†◊°◊î ◊ú◊®◊©◊ï◊ù ◊ê◊™ ◊î-Learning Rate
    if 'writer_src.add_scalars("Learning Rate"' in line:
        # ◊û◊ó◊ú◊ô◊§◊ô◊ù ◊ê◊™ add_scalars ◊ë-add_scalar (◊ë◊ô◊ó◊ô◊ì)
        new_line = line.replace('add_scalars', 'add_scalar')
        new_lines.append(new_line)
        found = True
        print("Fixed line:", new_line.strip())
    else:
        new_lines.append(line)

if found:
    with open(engine_path, 'w') as f:
        f.writelines(new_lines)
    print("Successfully patched engine.py (add_scalar fix)!")
else:
    print("Could not find the specific line to patch.")


Config file updated with restored Scheduler settings!
Fixed line to 7 arguments.
Successfully patched engine.py!
Fixed line: writer_src.add_scalar("Learning Rate", self.main_optimizer.param_groups[0]['lr'], epoch)
Successfully patched engine.py (add_scalar fix)!


# control **epochs**

In [6]:
import yaml

config_path = 'models/SepReformer_Base_WSJ0/configs.yaml'

# ◊ß◊®◊ô◊ê◊™ ◊î◊ß◊ï◊ë◊•
with open(config_path, 'r') as f:
    # ◊†◊ò◊¢◊ü ◊ê◊™ ◊î◊ß◊ï◊ë◊• ◊õ◊ò◊ß◊°◊ò ◊§◊©◊ï◊ò ◊õ◊ì◊ô ◊ú◊ê ◊ú◊©◊ë◊ï◊® ◊ê◊™ ◊î◊î◊¢◊®◊ï◊™ ◊ï◊î◊û◊ë◊†◊î ◊î◊û◊ô◊ï◊ó◊ì
    content = f.read()

# ◊ë◊ô◊¶◊ï◊¢ ◊î◊ó◊ú◊§◊ï◊™ ◊§◊©◊ï◊ò◊ï◊™ ◊ú◊©◊ô◊†◊ï◊ô ◊û◊°◊§◊® ◊î-Epochs
content = content.replace('max_epoch: 6', 'max_epoch: 90')
content = content.replace('test_epochs: [6]', 'test_epochs: [90]')

# ◊©◊û◊ô◊®◊î ◊û◊ó◊ì◊©
with open(config_path, 'w') as f:
    f.write(content)

print("Updated config: max_epoch = 6, test_epochs = [6]")

Updated config: max_epoch = 6, test_epochs = [6]


In [7]:
import os

# ◊î◊†◊™◊ô◊ë ◊ú◊ß◊ï◊ë◊• ◊î◊ë◊¢◊ô◊ô◊™◊ô
dataset_path = '/content/SepReformer/models/SepReformer_Base_WSJ0/dataset.py'

# ◊ß◊®◊ô◊ê◊™ ◊î◊ß◊ï◊ë◊•
with open(dataset_path, 'r') as f:
    lines = f.readlines()

# ◊™◊ô◊ß◊ï◊ü ◊î◊ú◊ï◊í◊ô◊ß◊î
new_lines = []
found = False
for line in lines:
    # ◊û◊ó◊§◊©◊ô◊ù ◊ê◊™ ◊î◊©◊ï◊®◊î ◊©◊û◊†◊°◊î ◊ú◊§◊®◊ß ◊ê◊™ ◊©◊û◊ï◊™ ◊î◊ß◊ë◊¶◊ô◊ù ◊¢◊ù split('_')
    # ◊î◊©◊ï◊®◊î ◊î◊û◊ß◊ï◊®◊ô◊™ ◊†◊®◊ê◊ô◊™ ◊ë◊¢◊®◊ö ◊õ◊ö:
    # tmp1 = key.split('_')[1][:3] != key_random.split('_')[3][:3]
    if "key.split('_')" in line and "key_random.split('_')" in line:
        # ◊û◊ó◊ú◊ô◊§◊ô◊ù ◊ë◊ú◊ï◊í◊ô◊ß◊î ◊§◊©◊ï◊ò◊î ◊©◊™◊û◊ô◊ì ◊û◊ó◊ñ◊ô◊®◊î True (◊û◊ê◊§◊©◊®◊™ ◊¢◊®◊ë◊ï◊ë)
        # ◊©◊ï◊û◊®◊ô◊ù ◊¢◊ú ◊î◊î◊ñ◊ó◊î (Indentation) ◊î◊û◊ß◊ï◊®◊ô◊™
        indent = line[:line.find("tmp1")]
        new_line = f"{indent}tmp1 = True # Patched for custom filenames\n"
        new_lines.append(new_line)

        # ◊ê◊†◊ó◊†◊ï ◊¶◊®◊ô◊õ◊ô◊ù ◊í◊ù ◊ú◊†◊ò◊®◊ú ◊ê◊™ ◊î◊©◊ï◊®◊ï◊™ ◊î◊ë◊ê◊ï◊™ ◊©◊ë◊ï◊ì◊ß◊ï◊™ ◊ê◊™ tmp2 ◊ï◊ê◊™ ◊î◊™◊†◊ê◊ô ◊î◊û◊ï◊®◊õ◊ë
        # ◊ê◊ë◊ú ◊î◊ì◊®◊ö ◊î◊õ◊ô ◊ß◊ú◊î ◊î◊ô◊ê ◊ú◊™◊™ ◊ú-tmp1 ◊ú◊î◊ô◊ï◊™ True, ◊ï◊ú◊©◊†◊ï◊™ ◊í◊ù ◊ê◊™ tmp2 ◊ê◊ù ◊¶◊®◊ô◊ö.
        # ◊ú◊û◊¢◊©◊î, ◊î◊§◊™◊®◊ï◊ü ◊î◊õ◊ô ◊ô◊¶◊ô◊ë ◊î◊ï◊ê ◊ú◊¢◊ò◊ï◊£ ◊ê◊™ ◊õ◊ú ◊î◊ë◊ú◊ï◊ß ◊ë-try/except, ◊ê◊ë◊ú ◊¢◊®◊ô◊õ◊î ◊õ◊ñ◊ï ◊û◊°◊ï◊ë◊õ◊™ ◊ë◊°◊ß◊®◊ô◊§◊ò.
        # ◊î◊§◊™◊®◊ï◊ü ◊î◊§◊©◊ï◊ò: ◊†◊î◊§◊ï◊ö ◊ê◊™ ◊õ◊ú ◊ë◊ì◊ô◊ß◊ï◊™ ◊î◊ì◊ï◊ë◊®◊ô◊ù ◊ú-True.
        found = True
        print("Fixed tmp1 check.")

    elif "tmp2 =" in line and "key.split" in line:
         indent = line[:line.find("tmp2")]
         new_lines.append(f"{indent}tmp2 = True # Patched\n")
         print("Fixed tmp2 check.")

    else:
        new_lines.append(line)

if found:
    with open(dataset_path, 'w') as f:
        f.writelines(new_lines)
    print("‚úÖ Successfully patched dataset.py to support custom filenames!")
else:
    print("‚ùå Could not find the lines to patch. Please check file content manually.")


import os

file_path = '/content/SepReformer/models/SepReformer_Base_WSJ0/dataset.py'

print(f"üîß Fixing {file_path}...")

with open(file_path, 'r') as f:
    lines = f.readlines()

new_lines = []
fixed = False

for line in lines:
    # ◊ñ◊ô◊î◊ï◊ô ◊î◊©◊ï◊®◊î ◊î◊û◊©◊ï◊ë◊©◊™ (◊©◊û◊õ◊ô◊ú◊î ◊í◊ù tmp2 ◊ï◊í◊ù tmp1 ◊ë◊ô◊ó◊ì)
    if "tmp2 =" in line and "tmp1 =" in line:
        # ◊©◊ï◊û◊®◊ô◊ù ◊¢◊ú ◊î◊î◊ñ◊ó◊î (Indentation) ◊î◊û◊ß◊ï◊®◊ô◊™
        indent = line[:line.find("tmp2")]
        # ◊õ◊ï◊™◊ë◊ô◊ù ◊ê◊™ ◊î◊§◊ß◊ï◊ì◊ï◊™ ◊ë◊©◊ï◊®◊ï◊™ ◊†◊§◊®◊ì◊ï◊™
        new_lines.append(f"{indent}tmp1 = True # Fixed\n")
        new_lines.append(f"{indent}tmp2 = True # Fixed\n")
        fixed = True
        print("‚úÖ Found and fixed the broken line!")

    # ◊ò◊ô◊§◊ï◊ú ◊ë◊û◊ß◊®◊ô◊ù ◊©◊ë◊î◊ù ◊î◊©◊ï◊®◊ï◊™ ◊¢◊ì◊ô◊ô◊ü ◊û◊ß◊ï◊®◊ô◊ï◊™ ◊ê◊ö ◊ú◊ê ◊û◊™◊ê◊ô◊û◊ï◊™ (◊ú◊û◊†◊ô◊¢◊™ ◊ë◊¢◊ô◊ï◊™ ◊¢◊™◊ô◊ì◊ô◊ï◊™)
    elif "tmp1 =" in line and "key.split" in line:
        indent = line[:line.find("tmp1")]
        new_lines.append(f"{indent}tmp1 = True # Patched\n")
        fixed = True
    elif "tmp2 =" in line and "key.split" in line:
        indent = line[:line.find("tmp2")]
        new_lines.append(f"{indent}tmp2 = True # Patched\n")
        fixed = True

    else:
        new_lines.append(line)

if fixed:
    with open(file_path, 'w') as f:
        f.writelines(new_lines)
    print("üöÄ File updated successfully.")
else:
    print("‚ö†Ô∏è No broken lines found. The file might be already fixed or different than expected.")


import os

# ◊î◊†◊™◊ô◊ë ◊ú◊ß◊ï◊ë◊• ◊î◊ë◊¢◊ô◊ô◊™◊ô
dataset_path = '/content/SepReformer/models/SepReformer_Base_WSJ0/dataset.py'

# ◊ß◊®◊ô◊ê◊™ ◊î◊ß◊ï◊ë◊•
with open(dataset_path, 'r') as f:
    lines = f.readlines()

# ◊™◊ô◊ß◊ï◊ü ◊î◊©◊ï◊®◊î
new_lines = []
found = False
for line in lines:
    # ◊û◊ó◊§◊©◊ô◊ù ◊ê◊™ ◊î◊©◊ï◊®◊î ◊©◊û◊ë◊¶◊¢◊™ speed_aug
    if 'self.speed_aug(torch.tensor(samps_tmp))' in line:
        # ◊§◊©◊ï◊ò ◊†◊ì◊ú◊í ◊¢◊ú◊ô◊î ◊ê◊ï ◊†◊î◊§◊ï◊ö ◊ê◊ï◊™◊î ◊ú◊î◊¢◊®◊î.
        # ◊ë◊û◊ß◊ï◊ù ◊ú◊û◊ó◊ï◊ß, ◊†◊©◊ê◊ô◊® ◊ê◊™ samps_tmp ◊õ◊û◊ï ◊©◊î◊ï◊ê.
        # ◊î◊©◊ï◊®◊î ◊î◊û◊ß◊ï◊®◊ô◊™: samps_tmp = np.array(self.speed_aug(torch.tensor(samps_tmp))[0])
        # ◊î◊ó◊ì◊©◊î: # samps_tmp = ... (skipped)
        new_lines.append(f"            # {line.strip()} # PATCHED: Disabled speed_aug\n")
        found = True
        print("Disabled speed_aug line.")
    else:
        new_lines.append(line)

if found:
    with open(dataset_path, 'w') as f:
        f.writelines(new_lines)
    print("‚úÖ Successfully patched dataset.py to remove speed augmentation bug!")
else:
    print("‚ö†Ô∏è Could not find the speed_aug line. It might be already fixed.")

Fixed tmp1 check.
Fixed tmp1 check.
‚úÖ Successfully patched dataset.py to support custom filenames!
üîß Fixing /content/SepReformer/models/SepReformer_Base_WSJ0/dataset.py...
‚úÖ Found and fixed the broken line!
üöÄ File updated successfully.
Disabled speed_aug line.
‚úÖ Successfully patched dataset.py to remove speed augmentation bug!


In [8]:
# ============================================================
# CELL: AttentionSpeakerSplit ‚Äî Drop-in Replacement
# ============================================================
# Run ONCE after the repo is cloned (and before training).
# It will:
#   1. Inject AttentionSpeakerSplit into modules/module.py
#   2. Replace every SpkSplitStage() call with it
#   3. Run a shape + gradient smoke-test
# ============================================================

import os, textwrap
import torch
import torch.nn as nn

# ----------------------------------------------------------
# 1.  Class definition (injected as text into module.py)
# ----------------------------------------------------------
NEW_CLASS = textwrap.dedent("""
# ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
#  AttentionSpeakerSplit  (drop-in replacement for SpkSplitStage)
# ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
class AttentionSpeakerSplit(nn.Module):
    \"\"\"
    Dynamic, cross-attention-based speaker split.

    Architecture
    ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
    ‚Ä¢ num_spks learnable prototype vectors (speaker_queries) act as Queries Q.
    ‚Ä¢ The shared encoder output x provides Keys K and Values V.
    ‚Ä¢ Cross-attention distils a global "speaker context" from the mixture ‚Äî
      each query learns what its assigned speaker sounds like across the corpus.
    ‚Ä¢ That context modulates the input frame-by-frame via a sigmoid gate:
          out[s, t] = x[t] * sigmoid(ctx[s]) + proj(x[t])   for each speaker s
    ‚Ä¢ A per-speaker GLU feed-forward sharpens the gated output.
    ‚Ä¢ Residual connections stabilise training from random initialisation.

    Shapes
    ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
    Input  x : (batch, dim, time)
    Output   : (batch, num_spks, dim, time)  ‚Üê feeds simple_fusion unchanged
    \"\"\"

    def __init__(self, in_channels: int, num_spks: int,
                 num_heads: int = 8, dropout: float = 0.05):
        super().__init__()
        assert in_channels % num_heads == 0, (
            f"in_channels ({in_channels}) must be divisible by "
            f"num_heads ({num_heads})"
        )
        self.in_channels = in_channels
        self.num_spks    = num_spks

        # ‚îÄ‚îÄ Learnable speaker prototype queries ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
        # (num_spks, in_channels): one row per speaker.
        # These are the only globally shared parameters encoding speaker identity.
        self.speaker_queries = nn.Parameter(
            torch.empty(num_spks, in_channels)
        )
        nn.init.trunc_normal_(self.speaker_queries, std=0.02)

        # ‚îÄ‚îÄ Cross-attention (Q = prototypes, K/V = mixture features) ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
        self.cross_attn = nn.MultiheadAttention(
            embed_dim   = in_channels,
            num_heads   = num_heads,
            dropout     = dropout,
            batch_first = False,   # expects (seq, batch, dim)
        )
        self.attn_norm = nn.LayerNorm(in_channels)

        # ‚îÄ‚îÄ Per-speaker gated feed-forward (GLU) ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
        # Sharpens the modulated representation independently per speaker.
        # Linear: dim ‚Üí 2*dim, then GLU halves back to dim.
        self.ff_gate = nn.Sequential(
            nn.LayerNorm(in_channels),
            nn.Linear(in_channels, in_channels * 2, bias=True),
            nn.GLU(dim=-1),
        )
        self.ff_norm = nn.LayerNorm(in_channels)

        # ‚îÄ‚îÄ Residual input projection ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
        self.res_proj = nn.Linear(in_channels, in_channels, bias=False)

        # ‚îÄ‚îÄ Output normalisation (GroupNorm matches rest of model) ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
        self.out_norm = nn.GroupNorm(1, in_channels)

    # ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        \"\"\"
        Parameters
        ----------
        x : (batch, dim, time)

        Returns
        -------
        out : (batch, num_spks, dim, time)
        \"\"\"
        B, D, T = x.shape

        # ‚îÄ‚îÄ Step 1: Cross-Attention to get per-speaker global context ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
        # MultiheadAttention: (seq, batch, dim)
        kv = x.permute(2, 0, 1)                            # (T, B, D)

        # Expand speaker_queries to batch: (num_spks, B, D)
        q  = self.speaker_queries.unsqueeze(1).expand(-1, B, -1)

        # Each speaker query attends over all time steps.
        attn_out, _ = self.cross_attn(
            query        = q,
            key          = kv,
            value        = kv,
            need_weights = False,           # skip weight materialisation
        )
        # attn_out: (num_spks, B, D)

        # Prototype queries as residual, then LayerNorm.
        attn_out = self.attn_norm(attn_out + q)             # (num_spks, B, D)

        # ‚îÄ‚îÄ Step 2: Frame-wise sigmoid gate ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
        # Expand input over speakers
        x_exp   = x.unsqueeze(1).expand(-1, self.num_spks, -1, -1)
        #         (B, 1, D, T) ‚Üí (B, S, D, T)

        # attn_out: (S, B, D) ‚Üí (B, S, D, 1)  ‚Äî broadcast over time
        spk_ctx = attn_out.permute(1, 0, 2).unsqueeze(-1)  # (B, S, D, 1)

        # Residual baseline: project input, broadcast over speakers
        res = self.res_proj(x.permute(0, 2, 1))            # (B, T, D)
        res = res.permute(0, 2, 1).unsqueeze(1)            # (B, 1, D, T)

        # Gate + residual
        out = x_exp * torch.sigmoid(spk_ctx) + res         # (B, S, D, T)

        # ‚îÄ‚îÄ Step 3: Per-speaker GLU feed-forward ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
        # Reshape to (B*S, T, D) for the FF layers
        out_flat = out.reshape(B * self.num_spks, D, T).permute(0, 2, 1)
        #          (B*S, T, D)
        ff_out   = self.ff_gate(out_flat)                   # (B*S, T, D)
        out_flat = self.ff_norm(out_flat + ff_out)          # residual
        out      = out_flat.permute(0, 2, 1).reshape(B, self.num_spks, D, T)

        # ‚îÄ‚îÄ Step 4: GroupNorm over channel dim (per speaker) ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
        out = self.out_norm(
            out.reshape(B * self.num_spks, D, T)
        ).reshape(B, self.num_spks, D, T)

        return out                                          # (B, S, D, T)

""")


# ----------------------------------------------------------
# 2.  Patch helpers
# ----------------------------------------------------------
def patch_module_py(path: str) -> None:
    """Inject class + swap SpkSplitStage calls in modules/module.py."""
    if not os.path.isfile(path):
        raise FileNotFoundError(
            f"module.py not found at: {path}\n"
            "Make sure the repo is cloned before running this cell."
        )
    with open(path, "r") as f:
        src = f.read()

    # Guard
    if "AttentionSpeakerSplit" in src:
        print("‚ö° module.py already contains AttentionSpeakerSplit ‚Äî skipping inject.")
    else:
        marker = "class SpkSplitStage"
        if marker not in src:
            raise RuntimeError(
                "Could not find 'class SpkSplitStage' in module.py.\n"
                "Search for the speaker-split class and insert "
                "AttentionSpeakerSplit above it manually."
            )
        pos = src.index(marker)
        src = src[:pos] + NEW_CLASS + src[pos:]
        print("‚úÖ  AttentionSpeakerSplit class injected into module.py.")

    # Swap instantiation calls
    before = src
    src    = src.replace("SpkSplitStage(", "AttentionSpeakerSplit(")
    n_swapped = before.count("SpkSplitStage(")
    if src != before:
        print(f"‚úÖ  Swapped {n_swapped} SpkSplitStage( ‚Üí AttentionSpeakerSplit( in module.py.")
    else:
        print("  ‚ÑπÔ∏è  No SpkSplitStage( call found in module.py (may be in model.py).")

    with open(path, "w") as f:
        f.write(src)
    print("üíæ  module.py saved.")


def patch_other(path: str) -> None:
    """Swap SpkSplitStage calls in any additional file (model.py etc.)."""
    if not os.path.isfile(path):
        return
    fname = os.path.basename(path)
    with open(path, "r") as f:
        src = f.read()
    if "AttentionSpeakerSplit" in src:
        print(f"‚ö° {fname}: already patched ‚Äî skipping.")
        return
    before = src
    src    = src.replace("SpkSplitStage(", "AttentionSpeakerSplit(")
    if src == before:
        print(f"  ‚ÑπÔ∏è  {fname}: no SpkSplitStage( calls ‚Äî nothing to do.")
        return
    with open(path, "w") as f:
        f.write(src)
    n = before.count("SpkSplitStage(")
    print(f"‚úÖ  {fname}: swapped {n} call(s).")


# ----------------------------------------------------------
# 3.  Run patches
# ----------------------------------------------------------
BASE = "/content/SepReformer/models/SepReformer_Base_WSJ0"

patch_module_py(os.path.join(BASE, "modules", "module.py"))
patch_other(os.path.join(BASE, "model.py"))
patch_other(os.path.join(BASE, "engine.py"))


# ----------------------------------------------------------
# 4.  Smoke-test (runs the class in-process, no full model needed)
# ----------------------------------------------------------
print("\n‚îÄ‚îÄ‚îÄ Shape + Gradient Smoke Test ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ")

exec(compile(NEW_CLASS, "<string>", "exec"), globals())

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"  device: {device}")

# Match configs.yaml: in_channels=128, num_spks=2
BATCH, DIM, TIME, NUM_SPKS = 2, 128, 200, 2

module = AttentionSpeakerSplit(
    in_channels=DIM, num_spks=NUM_SPKS,
    num_heads=8, dropout=0.05
).to(device)

dummy = torch.randn(BATCH, DIM, TIME, device=device)

# Forward shape
with torch.no_grad():
    out = module(dummy)
assert out.shape == (BATCH, NUM_SPKS, DIM, TIME), \
    f"Shape mismatch: expected {(BATCH, NUM_SPKS, DIM, TIME)}, got {tuple(out.shape)}"
print(f"  Input  shape : {tuple(dummy.shape)}")
print(f"  Output shape : {tuple(out.shape)}  ‚úÖ")

# Gradient flow
module(dummy).sum().backward()
assert module.speaker_queries.grad is not None
print(f"  ‚Äñ‚àá speaker_queries‚Äñ = {module.speaker_queries.grad.norm():.4f}  ‚úÖ  gradients flow")

# Speaker outputs should differ from each other
diff = (out[:, 0] - out[:, 1]).abs().mean().item()
print(f"  Speaker output mean |diff| = {diff:.5f}  " \
      f"{'‚úÖ differentiated' if diff > 1e-6 else '‚ö†Ô∏è identical ‚Äî check queries init'}")

n_p = sum(p.numel() for p in module.parameters())
print(f"  Module params  : {n_p:,}  ({n_p/1e6:.3f} M)")

print("""
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
üéâ  Patch complete!

If SepReformer modules were already imported this session,
restart the runtime now. Then run training as normal:

    !python run.py --model SepReformer_Base_WSJ0 --engine-mode train
‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
""")

# ----------------------------------------------------------
# 5.  Post-patch Indentation Fix for module.py
# ----------------------------------------------------------
module_file_path = os.path.join(BASE, "modules", "module.py")

if os.path.exists(module_file_path):
    print("\n‚îÄ‚îÄ‚îÄ Fixing indentation in module.py ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ")
    with open(module_file_path, "r") as f:
        lines = f.readlines()

    fixed_lines = []
    for line in lines:
        stripped_line = line.lstrip()
        # If a line starts with 'class' or 'def' (and is not commented out),
        # ensure it has no leading whitespace. This assumes these are top-level constructs.
        if (stripped_line.startswith("class ") or stripped_line.startswith("def ")) and \
           not stripped_line.startswith("#"):
            fixed_lines.append(stripped_line) # Remove all leading whitespace
        else:
            fixed_lines.append(line) # Keep original indentation

    with open(module_file_path, "w") as f:
        f.writelines(fixed_lines)
    print("‚úÖ  Indentation in module.py adjusted. (Removed leading whitespace from top-level class/def statements)")
else:
    print(f"‚ö†Ô∏è  module.py not found at {module_file_path}, skipping indentation fix.")


‚úÖ  AttentionSpeakerSplit class injected into module.py.
‚úÖ  Swapped 2 SpkSplitStage( ‚Üí AttentionSpeakerSplit( in module.py.
üíæ  module.py saved.
  ‚ÑπÔ∏è  model.py: no SpkSplitStage( calls ‚Äî nothing to do.
  ‚ÑπÔ∏è  engine.py: no SpkSplitStage( calls ‚Äî nothing to do.

‚îÄ‚îÄ‚îÄ Shape + Gradient Smoke Test ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
  device: cuda
  Input  shape : (2, 128, 200)
  Output shape : (2, 2, 128, 200)  ‚úÖ
  ‚Äñ‚àá speaker_queries‚Äñ = 0.0000  ‚úÖ  gradients flow
  Speaker output mean |diff| = 0.15260  ‚úÖ differentiated
  Module params  : 116,736  (0.117 M)

‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
üéâ  Patch complete!

If SepReformer modules were already imported this session,
restart the runtime now. Then run training as normal:

    !python run.py --model SepReformer_Base_WSJ

### **Train**

In [9]:
import os
import textwrap
import torch

# 0. Set working directory
%cd /content/SepReformer

# 1. Revert module.py to clean state
# This ensures we don't have broken indentation from previous attempts
print("üîÑ Reverting files to clean state...")
!git checkout models/SepReformer_Base_WSJ0/modules/module.py
!git checkout models/SepReformer_Base_WSJ0/model.py
!git checkout models/SepReformer_Base_WSJ0/engine.py

# 2. Correctly Re-Patch module.py
module_path = 'models/SepReformer_Base_WSJ0/modules/module.py'

# Using torch.nn explicitly since module.py might not have 'import torch.nn as nn'
# FIX 1: using positional arguments for cross_attn to satisfy ptflops
# FIX 2: Return (B*S, D, T) to match decoder expectation
NEW_CLASS_CODE = textwrap.dedent("""
# ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
#  AttentionSpeakerSplit  (drop-in replacement for SpkSplitStage)
# ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ‚îÄ
class AttentionSpeakerSplit(torch.nn.Module):
    def __init__(self, in_channels: int, num_spks: int,
                 num_heads: int = 8, dropout: float = 0.05):
        super().__init__()
        self.in_channels = in_channels
        self.num_spks    = num_spks
        self.speaker_queries = torch.nn.Parameter(torch.empty(num_spks, in_channels))
        torch.nn.init.trunc_normal_(self.speaker_queries, std=0.02)
        self.cross_attn = torch.nn.MultiheadAttention(
            embed_dim=in_channels, num_heads=num_heads, dropout=dropout, batch_first=False
        )
        self.attn_norm = torch.nn.LayerNorm(in_channels)
        self.ff_gate = torch.nn.Sequential(
            torch.nn.LayerNorm(in_channels),
            torch.nn.Linear(in_channels, in_channels * 2, bias=True),
            torch.nn.GLU(dim=-1),
        )
        self.ff_norm = torch.nn.LayerNorm(in_channels)
        self.res_proj = torch.nn.Linear(in_channels, in_channels, bias=False)
        self.out_norm = torch.nn.GroupNorm(1, in_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, D, T = x.shape
        kv = x.permute(2, 0, 1)
        q  = self.speaker_queries.unsqueeze(1).expand(-1, B, -1)

        # PTFLOPS FIX: Pass q, k, v as positional arguments.
        # ptflops hook expects to unpack (q, k, v) from input tuple.
        attn_out, _ = self.cross_attn(q, kv, kv, need_weights=False)

        attn_out = self.attn_norm(attn_out + q)
        x_exp   = x.unsqueeze(1).expand(-1, self.num_spks, -1, -1)
        spk_ctx = attn_out.permute(1, 0, 2).unsqueeze(-1)
        res = self.res_proj(x.permute(0, 2, 1))
        res = res.permute(0, 2, 1).unsqueeze(1)
        out = x_exp * torch.sigmoid(spk_ctx) + res
        out_flat = out.reshape(B * self.num_spks, D, T).permute(0, 2, 1)
        ff_out   = self.ff_gate(out_flat)
        out_flat = self.ff_norm(out_flat + ff_out)
        out      = out_flat.permute(0, 2, 1).reshape(B, self.num_spks, D, T)

        # Reshape to (B*S, D, T) for GroupNorm and final output
        # Decoder expects flattened batch of speakers
        out = self.out_norm(out.reshape(B * self.num_spks, D, T))
        return out
""")

if os.path.exists(module_path):
    with open(module_path, 'r') as f:
        src = f.read()

    if "AttentionSpeakerSplit" not in src:
        # A. Rename the original class definition to avoid conflict
        src = src.replace("class SpkSplitStage", "class _OldSpkSplitStage")

        # B. Replace all instantiation calls SpkSplitStage(...) -> AttentionSpeakerSplit(...)
        src = src.replace("SpkSplitStage(", "AttentionSpeakerSplit(")

        # C. Append the new class to the END of the file
        src += NEW_CLASS_CODE

        with open(module_path, 'w') as f:
            f.write(src)
        print("‚úÖ module.py patched successfully.")
    else:
        print("‚ÑπÔ∏è module.py already patched.")

    # 3. Patch model.py and engine.py (Update usages)
    other_files = [
        'models/SepReformer_Base_WSJ0/model.py',
        'models/SepReformer_Base_WSJ0/engine.py'
    ]

    for fpath in other_files:
        if os.path.exists(fpath):
            with open(fpath, 'r') as f:
                src = f.read()

            if "SpkSplitStage(" in src:
                src = src.replace("SpkSplitStage(", "AttentionSpeakerSplit(")
                with open(fpath, 'w') as f:
                    f.write(src)
                print(f"‚úÖ {os.path.basename(fpath)} updated.")
            else:
                print(f"‚ÑπÔ∏è No usages found in {os.path.basename(fpath)}.")

else:
    print(f"‚ùå Error: {module_path} not found.")



/content/SepReformer
üîÑ Reverting files to clean state...
Updated 1 path from the index
Updated 0 paths from the index
Updated 1 path from the index
‚úÖ module.py patched successfully.
‚ÑπÔ∏è No usages found in model.py.
‚ÑπÔ∏è No usages found in engine.py.
üöÄ Starting Training...
2026-02-18 12:22:41.638806: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-18 12:22:41.657090: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1771417361.679636    2620 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1771417361.687003    26

In [10]:
import os

# Path to the engine file
engine_path = 'models/SepReformer_Base_WSJ0/engine.py'
print(f"üîß Patching {engine_path} to fix AttributeError and typos...")

if os.path.exists(engine_path):
    with open(engine_path, 'r') as f:
        lines = f.readlines()

    new_lines = []
    for line in lines:
        # Fix 1: Remove 'self.wandb_run' causing AttributeError
        if 'util_engine.save_checkpoint_per_best' in line and 'self.wandb_run' in line:
            # Remove the argument. Assuming format: ..., self.checkpoint_path, self.wandb_run)
            line = line.replace(', self.wandb_run', '')
            print("  ‚úÖ Removed undefined 'self.wandb_run' argument.")

        # Fix 2: 'add_scalars' should be 'add_scalar' for single value (Learning Rate)
        # This prevents a future error during logging
        if 'writer_src.add_scalars("Learning Rate"' in line:
            line = line.replace('add_scalars', 'add_scalar')
            print("  ‚úÖ Fixed 'add_scalars' -> 'add_scalar' for Learning Rate.")

        new_lines.append(line)

    with open(engine_path, 'w') as f:
        f.writelines(new_lines)
    print("‚úÖ engine.py patched.")
else:
    print(f"‚ùå Error: {engine_path} not found.")

print("üöÄ Restarting Training...")
!python run.py --model SepReformer_Base_WSJ0 --engine-mode train

üîß Patching models/SepReformer_Base_WSJ0/engine.py to fix AttributeError and typos...
  ‚úÖ Removed undefined 'self.wandb_run' argument.
  ‚úÖ Fixed 'add_scalars' -> 'add_scalar' for Learning Rate.
‚úÖ engine.py patched.
üöÄ Restarting Training...
2026-02-18 12:24:56.174332: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-02-18 12:24:56.192745: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1771417496.215039    3935 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1771417496.222430    3935 cuda_blas.cc:1407] Unable to reg

In [11]:
!ls -l models/SepReformer_Base_WSJ0/log/scratch_weights/

total 172560
-rw-r--r-- 1 root root 176696025 Feb 18 12:26 epoch.0001.pth


In [12]:
import os
import random

# ◊ë◊ó◊ô◊®◊™ ◊ß◊ï◊ë◊• ◊ê◊ß◊®◊ê◊ô
mix_dir = '/content/drive/MyDrive/◊™◊ï◊ê◊®/speech/mix_clean'
files = [f for f in os.listdir(mix_dir) if f.endswith('.wav')]
test_file = random.choice(files)
test_file_path = os.path.join(mix_dir, test_file)

print(f"Testing on: {test_file_path}")

# ◊î◊§◊ß◊ï◊ì◊î ◊î◊û◊™◊ï◊ß◊†◊™ (◊©◊ô◊ù ◊ú◊ë ◊ú◊û◊ß◊§◊ô◊ù ◊ë-out-wav-dir)
!python run.py \
  --model SepReformer_Base_WSJ0 \
  --engine-mode infer_sample \
  --sample-file "{test_file_path}" \
  --out-wav-dir "results_debug"

Testing on: /content/drive/MyDrive/◊™◊ï◊ê◊®/speech/mix_clean/mix_00046.wav
Traceback (most recent call last):
  File "<frozen importlib._bootstrap>", line 1331, in _find_and_load_unlocked
  File "<frozen importlib._bootstrap>", line 935, in _load_unlocked
  File "<frozen importlib._bootstrap_external>", line 999, in exec_module
  File "<frozen importlib._bootstrap>", line 488, in _call_with_frames_removed
  File "/usr/local/lib/python3.12/dist-packages/torch/distributed/rpc/server_process_global_profiler.py", line 7, in <module>
    from torch.autograd.profiler_legacy import profile
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/__init__.py", line 32, in <module>
    from .gradcheck import gradcheck, gradgradcheck
  File "/usr/local/lib/python3.12/dist-packages/torch/autograd/gradcheck.py", line 11, in <module>
    import torch.testing
  File "/usr/local/lib/python3.12/dist-packages/torch/testing/__init__.py", line 4, in <module>
    from ._comparison import assert_allc

In [13]:
import os
import glob
import random
import shutil
import IPython.display as ipd
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np

# 1. Define Paths
checkpoint_path = "/content/SepReformer/models/SepReformer_Base_WSJ0/log/scratch_weights/"
# Find the latest checkpoint automatically
checkpoints = sorted(glob.glob(os.path.join(checkpoint_path, "epoch.*.pth")))
latest_ckpt = checkpoints[-1]
print(f"‚≠ê Using Checkpoint: {os.path.basename(latest_ckpt)}")

# 2. Pick a random test file
mix_dir = '/content/drive/MyDrive/◊™◊ï◊ê◊®/speech/mix_clean'
files = [f for f in os.listdir(mix_dir) if f.endswith('.wav') and '_in' not in f and '_out' not in f]
test_file = random.choice(files)
test_file_path = os.path.join(mix_dir, test_file)

print(f"üéµ Separating: {test_file}")

# 3. Run Inference
!python run.py \
  --model SepReformer_Base_WSJ0 \
  --engine-mode infer_sample \
  --sample-file "{test_file_path}" \
  --out-wav-dir "final_results_90ep" > /dev/null 2>&1

# 4. Locate and Move Files (Handling the path bug we saw earlier)
# The model likely saved them in the source folder again due to the bug
base_name = os.path.splitext(test_file)[0]
path_mix = os.path.join(mix_dir, f"{base_name}_in.wav")
path_s1 = os.path.join(mix_dir, f"{base_name}_out_0.wav")
path_s2 = os.path.join(mix_dir, f"{base_name}_out_1.wav")

if os.path.exists(path_s1):
    print("‚úÖ Separation successful!\n")

    # Audio Player
    print("üéß Original Mixture:")
    ipd.display(ipd.Audio(path_mix))

    print("üó£Ô∏è Speaker 1 (Est):")
    ipd.display(ipd.Audio(path_s1))

    print("üó£Ô∏è Speaker 2 (Est):")
    ipd.display(ipd.Audio(path_s2))

    # Visualization
    fig, ax = plt.subplots(1, 3, figsize=(18, 5))

    def plot_spec(ax, p, title):
        y, sr = librosa.load(p, sr=8000)
        D = librosa.amplitude_to_db(np.abs(librosa.stft(y)), ref=np.max)
        img = librosa.display.specshow(D, y_axis='linear', x_axis='time', sr=sr, ax=ax)
        ax.set_title(title)
        return img

    img = plot_spec(ax[0], path_mix, "Mixture")
    plot_spec(ax[1], path_s1, "Speaker 1")
    plot_spec(ax[2], path_s2, "Speaker 2")
    fig.colorbar(img, ax=ax, format="%+2.0f dB")
    plt.show()

else:
    print("‚ùå Output files not found. Check if the model saved to a different location.")

KeyboardInterrupt: 

### **save model offline**

In [None]:
import os
import shutil
import glob

# 1. Define Paths
# Source directories
model_root = "/content/SepReformer/models/SepReformer_Base_WSJ0"
weights_dir = os.path.join(model_root, "log", "scratch_weights")

# Destination directory in Drive (Persistent storage)
dest_dir = "/content/drive/MyDrive/◊™◊ï◊ê◊®/speech/SepReformer_Saved_Model"
os.makedirs(dest_dir, exist_ok=True)

# 2. Identify Files to Save
# A. Latest Checkpoint
checkpoints = sorted(glob.glob(os.path.join(weights_dir, "epoch.*.pth")))
if not checkpoints:
    print("‚ùå No checkpoints found!")
else:
    latest_ckpt = checkpoints[-1]
    ckpt_name = os.path.basename(latest_ckpt)

    # B. Config and Patched Code
    files_to_save = {
        latest_ckpt: os.path.join(dest_dir, ckpt_name),
        os.path.join(model_root, "configs.yaml"): os.path.join(dest_dir, "configs.yaml"),
        os.path.join(model_root, "dataset.py"): os.path.join(dest_dir, "dataset.py"), # Save patched file
        os.path.join(model_root, "engine.py"): os.path.join(dest_dir, "engine.py")   # Save patched file
    }

    print(f"üíæ Saving files to: {dest_dir} ...")

    # 3. Copy Files
    for src, dst in files_to_save.items():
        if os.path.exists(src):
            shutil.copy2(src, dst)
            print(f"‚úÖ Saved: {os.path.basename(src)}")
        else:
            print(f"‚ö†Ô∏è Warning: Source file not found: {src}")


### **load model**

In [None]:
import os
import shutil
import glob
from google.colab import drive

# --- 1. Mount Drive ---
drive.mount('/content/drive')

# --- 2. Setup Repository (Clone & Install) ---
if not os.path.exists('/content/SepReformer'):
    print("üöÄ Cloning SepReformer...")
    !git clone https://github.com/dmlguq456/SepReformer.git
    %cd SepReformer
    print("üì¶ Installing dependencies...")
    !pip install -r requirements.txt
    !pip install loguru torchinfo ptflops thop mir_eval
else:
    print("‚úÖ Repository already exists.")
    %cd /content/SepReformer

# --- 3. Restore Saved Model ---
# Paths
saved_models_dir = '/content/drive/MyDrive/◊™◊ï◊ê◊®/speech/SepReformer_Saved_Model'
model_root = 'models/SepReformer_Base_WSJ0' # Relative to /content/SepReformer

print(f"üîÑ Restoring model files from: {saved_models_dir}")

# Create directories
os.makedirs(os.path.join(model_root, 'log', 'scratch_weights'), exist_ok=True)

# Files to restore
files_to_copy = ['configs.yaml', 'dataset.py', 'engine.py']

for fname in files_to_copy:
    src = os.path.join(saved_models_dir, fname)
    dst = os.path.join(model_root, fname)
    if os.path.exists(src):
        shutil.copy2(src, dst)
        print(f"  - Restored {fname}")
    else:
        print(f"  ‚ö†Ô∏è Missing {fname} in saved dir")

# Restore Checkpoint
saved_ckpts = sorted(glob.glob(os.path.join(saved_models_dir, "epoch.*.pth")))
if saved_ckpts:
    latest_ckpt = saved_ckpts[-1]
    ckpt_name = os.path.basename(latest_ckpt)
    dst_ckpt = os.path.join(model_root, 'log', 'scratch_weights', ckpt_name)
    shutil.copy2(latest_ckpt, dst_ckpt)
    print(f"  - Restored Checkpoint: {ckpt_name}")
else:
    print("  ‚ùå No checkpoint found to restore!")

print("\nüéâ Ready! The environment is set up and the model is loaded.")

### **Attention model**