# Importing

In [1]:
import argparse
import copy, math
import functools
import gc
import itertools
import json
import os
import random
import re
import shutil
import time
from math import cos, pi, floor, sin
from pprint import pprint
from types import SimpleNamespace
from typing import Dict, Literal
from typing import List, Tuple
from typing import Optional

import h5py
import ipykernel
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import requests
import scipy.io as sio
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim
import torchaudio
from einops import rearrange
from einops.layers.torch import Rearrange
from pesq import pesq
from pystoi import stoi
from sacred import Experiment
from sacred.observers import FileStorageObserver
from sacred.utils import apply_backspaces_and_linefeeds
from scipy.io import loadmat
from scipy.signal import butter, lfilter, sosfilt
from torch import Tensor
from torch import einsum
from torch.autograd import Variable
from torch.nn import TransformerEncoder, TransformerEncoderLayer
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, DataLoader, Subset, ConcatDataset
from torch.utils.tensorboard import SummaryWriter
from torchaudio import transforms
from tqdm import tqdm

# Preparation

## Configs

In [3]:
class Config:
    """Global settings, read everywhere as plain class attributes.

    Nothing is instantiated; code in this file accesses e.g. ``Config.chan_num``
    directly.  ``encoder_out_nchannels`` used to be assigned twice (once in the
    FLASHT group, once in the MASK group) with the same value; it is now
    defined a single time so the two copies cannot drift apart.
    """

    # --- Paths (empty placeholders by default) ---
    root = ''
    file_name = ''
    root_feature = ''

    # --- Audio / optimisation ---
    sample_rate = 16000
    learning_rate = 1e-4
    batch_size = 8

    # NOTE: mutable class attribute, shared by every reader of Config.
    SUBJECTS = []
    seed = 42

    # --- Mossformer2 / FLASHT ---
    emb_size = 128
    encoder_kernel_size = 20
    encoder_out_nchannels = 256  # also listed under the MASK settings originally
    intra_numlayers = 4
    intra_nhead = 8
    intra_dffn = 512
    intra_dropout = 0.1
    intra_use_positional = True
    intra_norm_before = True

    # --- MASK ---
    masknet_numspks = 1
    masknet_chunksize = 100
    masknet_numlayers = 6
    masknet_norm = "ln"
    masknet_useextralinearlayer = False
    masknet_extraskipconnection = True

    # --- Encoder ---
    num_channels = 64
    sequence_length = 128
    num_subjects = 32
    num_features = 128
    num_latents = 256
    num_blocks = 1

    chan_num = 128  # size of Electrodes.edge_importance and resGCN output channels
    band_num = 5    # passed as ``band`` to resGCN by HGCN_Dual
    inc = 32        # output channels of EEGEncoder.chanattn
    # Evaluated once, when the class body runs.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Utils

In [10]:
class AvgMeter:
    """Running weighted mean of a metric.

    Args:
        name: label stored on the instance as ``self.name``.
    """

    def __init__(self, name):
        self.name = name
        self.reset()

    def reset(self):
        """Discard everything accumulated so far."""
        self.sum = 0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value`` as the mean of ``n`` new samples."""
        self.sum += value * n
        self.count += n

    @property
    def avg(self):
        """Mean of all recorded samples, or ``0.0`` if none were recorded.

        Reading the average of a fresh or just-reset meter used to raise
        ZeroDivisionError; it now yields 0.0 instead.
        """
        if self.count == 0:
            return 0.0
        return self.sum / self.count

def get_lr(optimizer):
    """Return the ``"lr"`` entry of the optimizer's first parameter group.

    Only the first element of ``optimizer.param_groups`` is looked at.  When
    there are no parameter groups the result is ``None``.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group["lr"]

## Loss

In [14]:
def loss_fn(net, X, mrstftloss, **kwargs):
    """Run the network on one batch and compute its training loss.

    The previous docstring described ``ell_p`` / ``stft_lambda`` arguments and
    an ``output_dic`` return value that this function does not have, and it
    contained the invalid escape sequence ``\\e``; both are corrected here.

    Parameters:
    net: callable invoked as ``net(noisy_audio, eeg, eeg_feature)``
    X: tuple ``(noisy_audio, eeg, clean_audio, eeg_feature)``; tuple
       subclasses such as namedtuples are accepted
    mrstftloss: callable ``mrstftloss(prediction, target)`` returning the loss
    **kwargs: accepted for call-site compatibility and ignored

    Returns:
    loss: ``0.0 + mrstftloss(...)``
    denoised_audio: the unmodified output of ``net``
    """
    # AssertionError is kept as the failure mode so existing callers see the
    # same exception type as before.
    assert isinstance(X, tuple) and len(X) == 4

    noisy_audio, eeg, clean_audio, eeg_feature = X
    denoised_audio = net(noisy_audio, eeg, eeg_feature)

    # Dimension 1 is removed from both signals (only if it has size 1) before
    # they are handed to the loss.
    sc_loss = mrstftloss(denoised_audio.squeeze(1), clean_audio.squeeze(1))

    loss = 0.0
    loss = loss + sc_loss

    return loss, denoised_audio

In [15]:
def calc_SISDR(estimate_source, source):
    """Calculate Scale-Invariant Source-to-Distortion Ratio (SI-SDR)

    Singleton axes other than the leading (batch) axis are removed before the
    computation.  The batch axis is kept even when it has length 1, so a
    ``[1, T]`` or ``[1, 1, T]`` input yields a result of shape ``[1]`` (a bare
    ``squeeze()`` used to collapse it to a 0-d scalar).  1-D inputs are
    treated as a single signal and give a 0-d result.

    Args:
        source: torch tensor, [batch size, sequence length]
        estimate_source: torch tensor, [batch size, sequence length]
    Returns:
        SISDR, [batch size]
    Raises:
        AssertionError: if the two tensors differ in shape after squeezing.
    """
    def _squeeze_keep_batch(x):
        # Drop size-1 axes but never the first one.
        if x.dim() < 2:
            return x
        kept = [x.shape[0]] + [n for n in x.shape[1:] if n != 1]
        return x.reshape(kept)

    source = _squeeze_keep_batch(source)
    estimate_source = _squeeze_keep_batch(estimate_source)
    assert source.size() == estimate_source.size()

    # Step 1. Zero-mean both signals along the time axis.
    source = source - torch.mean(source, dim=-1, keepdim=True)
    estimate_source = estimate_source - torch.mean(estimate_source, dim=-1, keepdim=True)

    # Step 2. SI-SDR calculation
    EPS = 1e-9  # Small value to avoid division by zero / log of zero

    # Scaling factor alpha = <s', s> / <s, s>
    source_energy = torch.sum(source * source, dim=-1, keepdim=True) + EPS
    scaling = torch.sum(source * estimate_source, dim=-1, keepdim=True) / source_energy

    # s_target = alpha * s ; e_noise = s' - s_target
    s_target = scaling * source
    noise = estimate_source - s_target

    # SI-SDR = 10 * log_10(||s_target||^2 / ||e_noise||^2)
    target_energy = torch.sum(s_target ** 2, dim=-1)
    noise_energy = torch.sum(noise ** 2, dim=-1) + EPS
    ratio = target_energy / noise_energy

    sisdr = 10 * torch.log10(ratio + EPS)
    return sisdr

In [16]:
class si_sidrloss(nn.Module):
    """Loss module returning the negated mean of ``calc_SISDR``."""

    def forward(self, pred, target):
        """Compute the SI-SDR loss.

        Args:
            pred: torch tensor, [batch size, sequence length]
            target: torch tensor, [batch size, sequence length]
        Returns:
            Scalar tensor: minus the mean SI-SDR.  A higher SI-SDR is better,
            hence the sign flip to obtain something to minimise.
        """
        per_item = calc_SISDR(pred, target)
        return per_item.mean().neg()

In [17]:
def cal_sisdri(mix_wave, target_wave, estmate_wave):
    """Return the SI-SDR improvement of the estimate over the mixture.

    Both SI-SDR values are measured against ``target_wave`` with
    ``calc_SISDR``; the result is ``SI-SDR(estimate) - SI-SDR(mixture)``.
    """
    baseline = calc_SISDR(mix_wave, target_wave)
    enhanced = calc_SISDR(estmate_wave, target_wave)
    return enhanced - baseline

# Model

## EEGEncoder

In [None]:
class down_sample(nn.Module):
    """Strided ``1 x k`` convolution, batch norm and ELU along the last axis.

    Args:
        c: number of input channels (the output has the same number).
        k: kernel width along the last axis.
        s: stride along the last axis.
        p: zero padding added on both sides of the last axis.
    """

    def __init__(self, c, k, s, p):
        super().__init__()
        self.op = nn.Sequential(
            nn.Conv2d(c, c, (1, k), stride=(1, s), padding=(0, p), bias=False),
            nn.BatchNorm2d(c),
            # The first positional argument of nn.ELU is ``alpha``, not
            # ``inplace``.  The former ``nn.ELU(False)`` therefore ran with
            # alpha == 0, i.e. behaved like ReLU.  Models trained with the old
            # activation should be re-validated.
            nn.ELU(inplace=False),
        )
        nn.init.xavier_uniform_(self.op[0].weight)

    def forward(self, x):
        """Apply conv -> batch norm -> ELU to a 4-D input."""
        return self.op(x)

class Residual_Block(nn.Module):
    """Two ``1 x 3`` conv + batch-norm stages with a residual shortcut.

    Args:
        inc: input channels.
        outc: output channels.  When it differs from ``inc`` the shortcut goes
            through a bias-free 1x1 convolution (``self.exp``); otherwise the
            input is added unchanged.
        g: ``groups`` argument of the two 1x3 convolutions.
    """

    def __init__(self, inc, outc, g=1):
        super().__init__()
        self.exp = nn.Conv2d(inc, outc, 1, bias=False) if inc != outc else None
        self.conv1 = nn.Conv2d(inc, outc, (1, 3), padding=(0, 1), groups=g, bias=False)
        self.conv2 = nn.Conv2d(outc, outc, (1, 3), padding=(0, 1), groups=g, bias=False)
        self.bn1, self.bn2 = nn.BatchNorm2d(outc), nn.BatchNorm2d(outc)
        # nn.ELU(False) would set alpha=False (== 0) and degrade to ReLU,
        # because ``alpha`` is the first positional parameter.  ``inplace`` is
        # now passed by keyword; re-validate checkpoints trained before this fix.
        self.elu = nn.ELU(inplace=False)
        for m in (self.conv1, self.conv2):
            nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        """Return ``ELU(bn2(conv2(bn1(conv1(x)))) + shortcut(x))``."""
        y = self.bn1(self.conv1(x))
        y = self.bn2(self.conv2(y))
        shortcut = self.exp(x) if self.exp is not None else x
        return self.elu(y + shortcut)

class input_layer(nn.Module):
    """Lift a single-channel 4-D input to ``outc`` channels.

    A bias-free ``1 x 3`` convolution (padding 1 on the last axis, so the
    spatial size is unchanged) followed by batch normalisation.

    Args:
        outc: number of output channels.
        g: ``groups`` argument of the convolution.
    """

    def __init__(self, outc, g=1):
        super().__init__()
        conv = nn.Conv2d(1, outc, kernel_size=(1, 3), padding=(0, 1),
                         groups=g, bias=False)
        nn.init.xavier_uniform_(conv.weight)
        self.op = nn.Sequential(conv, nn.BatchNorm2d(outc))

    def forward(self, x):
        """Apply convolution then batch norm."""
        return self.op(x)

def embedding_network(outc, nlayer, g=1):
    """Build ``input_layer`` followed by ``nlayer`` channel-doubling blocks.

    Block ``i`` (0-based) is ``Residual_Block(2**i * outc, 2**(i+1) * outc, g)``,
    so the returned ``nn.Sequential`` outputs ``outc * 2**nlayer`` channels.
    """
    stages = [input_layer(outc, g)]
    width = outc
    for _ in range(nlayer):
        doubled = 2 * width
        stages.append(Residual_Block(int(width), int(doubled), g))
        width = doubled
    return nn.Sequential(*stages)

# ============================================================
#                     Temporal Branch
# ============================================================
class Multi_Scale_Temporal_Block(nn.Module):
    """Embed the input, then down-sample it at several rates and concatenate.

    The embedding output is stacked with the raw input along the channel axis
    (hence the ``+ 1`` in the channel count, which presumes a single-channel
    input).  Five ``down_sample`` branches then process that tensor and their
    results are joined along the last axis.
    """

    # (kernel, stride, padding) for ds1 .. ds5; ds4 and ds5 share a setting.
    _BRANCHES = ((4, 2, 1), (8, 4, 2), (16, 8, 4), (16, 16, 0), (16, 16, 0))

    def __init__(self, outc, nlayer=1):
        super().__init__()
        self.embed = embedding_network(outc, nlayer, 1)
        ch = outc * (2 ** nlayer) + 1
        for idx, (kernel, stride, padding) in enumerate(self._BRANCHES, start=1):
            setattr(self, f"ds{idx}", down_sample(ch, kernel, stride, padding))

    def forward(self, x):
        """Return the five down-sampled views concatenated on dimension 3."""
        stacked = torch.cat((self.embed(x), x), 1)
        branches = (self.ds1, self.ds2, self.ds3, self.ds4, self.ds5)
        return torch.cat([branch(stacked) for branch in branches], 3)

class Temporal_Block(nn.Module):
    """Five ``Multi_Scale_Temporal_Block(2)`` branches, one per input channel."""

    def __init__(self):
        super().__init__()
        branches = [Multi_Scale_Temporal_Block(2) for _ in range(5)]
        self.blocks = nn.ModuleList(branches)
        self.up = nn.Identity()

    def forward(self, x):
        """Feed channel ``i`` of ``x`` to branch ``i`` and concatenate on dim 1.

        The original shape comment gives the result as [B,25,128,256]; that
        presumes an input of [B,5,128,256] — confirm against the caller.
        """
        outs = []
        for idx, blk in enumerate(self.blocks):
            channel = x[:, idx:idx + 1]  # slice keeps the channel axis
            outs.append(blk(channel))
        merged = torch.cat(outs, 1)
        return self.up(merged)

# ============================================================
#                     GCN 
# ============================================================
class Electrodes:
    """128 electrode coordinates + learnable edge importance.

    Holds a hard-coded 128 x 3 coordinate table, channel names ``A1`` .. ``D32``
    with a name -> index lookup, and an ``edge_importance`` tensor wrapped in
    ``nn.Parameter``.

    NOTE(review): this is a plain Python class, not an ``nn.Module``, so
    ``edge_importance`` is not registered anywhere by this class itself; it
    will only be optimised/saved if an owning module registers it explicitly.
    """
    def __init__(self):
        # ---------- 128×3  ----------
        # Raw table; values are scaled by 1000 and cast to integers below
        # (presumably metres -> millimetres — TODO confirm units).
        self.positions_3d = np.array([
            # A1–A32 --------------------------------------------------
            [ 0.000000,  0.095000,  0.000000], [-0.012383,  0.093504,  0.011344],
            [ 0.002068,  0.092008, -0.023564], [ 0.017557,  0.090512,  0.022899],
            [-0.032677,  0.089016, -0.005780], [ 0.031177,  0.087520, -0.019832],
            [-0.010465,  0.086024,  0.038928], [-0.019985,  0.084528, -0.038480],
            [ 0.043359,  0.083031,  0.015835], [-0.045066,  0.081535,  0.018602],
            [ 0.021690,  0.080039, -0.046349], [ 0.015994,  0.078543,  0.050992],
            [-0.048085,  0.077047, -0.027866], [ 0.056250,  0.075551, -0.012366],
            [-0.034223,  0.074055,  0.048679], [-0.007880,  0.072559, -0.060812],
            [ 0.048231,  0.071063,  0.036332], [ 0.002143,  0.069567,  0.065043],
            [-0.059341,  0.068071, -0.001459], [ 0.062635,  0.066575,  0.004402],
            [-0.016635,  0.065079, -0.070415], [-0.028852,  0.063583,  0.065405],
            [ 0.055184,  0.062087, -0.040042], [-0.070070,  0.060591,  0.019264],
            [ 0.032360,  0.059095,  0.062903], [ 0.010572,  0.057599, -0.074345],
            [-0.050861,  0.056103,  0.055513], [ 0.074536,  0.054607, -0.018668],
            [-0.030532,  0.053111, -0.071740], [-0.002514,  0.051615,  0.078259],
            [ 0.068120,  0.050119,  0.041998], [-0.078236,  0.048623,  0.003167],
            # B1–B32 --------------------------------------------------
            [ 0.095000,  0.047127,  0.000000], [ 0.088017,  0.044381, -0.029607],
            [ 0.082092,  0.041635,  0.035895], [ 0.060069,  0.038888, -0.059242],
            [ 0.044573,  0.036142,  0.066201], [ 0.015313,  0.033396, -0.079206],
            [-0.015313,  0.033396,  0.079206], [-0.044573,  0.036142, -0.066201],
            [-0.060069,  0.038888,  0.059242], [-0.082092,  0.041635, -0.035895],
            [-0.088017,  0.044381,  0.029607], [-0.095000,  0.047127,  0.000000],
            [-0.078236,  0.048623, -0.003167], [-0.068120,  0.050119, -0.041998],
            [ 0.002514,  0.051615, -0.078259], [ 0.030532,  0.053111,  0.071740],
            [-0.074536,  0.054607,  0.018668], [ 0.050861,  0.056103, -0.055513],
            [-0.010572,  0.057599,  0.074345], [-0.032360,  0.059095, -0.062903],
            [ 0.070070,  0.060591, -0.019264], [-0.055184,  0.062087,  0.040042],
            [ 0.028852,  0.063583, -0.065405], [ 0.016635,  0.065079,  0.070415],
            [-0.062635,  0.066575, -0.004402], [ 0.059341,  0.068071,  0.001459],
            [-0.002143,  0.069567, -0.065043], [-0.048231,  0.071063, -0.036332],
            [ 0.007880,  0.072559,  0.060812], [ 0.034223,  0.074055, -0.048679],
            [-0.056250,  0.075551,  0.012366], [ 0.048085,  0.077047,  0.027866],
            # C1–C32 --------------------------------------------------
            [ 0.000000, -0.095000,  0.000000], [ 0.012383, -0.093504, -0.011344],
            [-0.002068, -0.092008,  0.023564], [-0.017557, -0.090512, -0.022899],
            [ 0.032677, -0.089016,  0.005780], [-0.031177, -0.087520,  0.019832],
            [ 0.010465, -0.086024, -0.038928], [ 0.019985, -0.084528,  0.038480],
            [-0.043359, -0.083031, -0.015835], [ 0.045066, -0.081535, -0.018602],
            [-0.021690, -0.080039,  0.046349], [-0.015994, -0.078543, -0.050992],
            [ 0.048085, -0.077047,  0.027866], [-0.056250, -0.075551,  0.012366],
            [ 0.034223, -0.074055, -0.048679], [ 0.007880, -0.072559,  0.060812],
            [-0.048231, -0.071063, -0.036332], [-0.002143, -0.069567,  0.065043],
            [ 0.059341, -0.068071, -0.001459], [-0.062635, -0.066575, -0.004402],
            [ 0.016635, -0.065079,  0.070415], [ 0.028852, -0.063583, -0.065405],
            [-0.055184, -0.062087,  0.040042], [ 0.070070, -0.060591, -0.019264],
            [-0.032360, -0.059095, -0.062903], [-0.010572, -0.057599,  0.074345],
            [ 0.050861, -0.056103, -0.055513], [-0.074536, -0.054607,  0.018668],
            [ 0.030532, -0.053111,  0.071740], [ 0.002514, -0.051615, -0.078259],
            [-0.068120, -0.050119, -0.041998], [ 0.078236, -0.048623, -0.003167],
            # D1–D32 --------------------------------------------------
            [-0.095000, -0.047127,  0.000000], [-0.088017, -0.044381,  0.029607],
            [-0.082092, -0.041635, -0.035895], [-0.060069, -0.038888,  0.059242],
            [-0.044573, -0.036142, -0.066201], [-0.015313, -0.033396,  0.079206],
            [ 0.015313, -0.033396, -0.079206], [ 0.044573, -0.036142,  0.066201],
            [ 0.060069, -0.038888, -0.059242], [ 0.082092, -0.041635,  0.035895],
            [ 0.088017, -0.044381, -0.029607], [ 0.095000, -0.047127,  0.000000],
            [ 0.078236, -0.048623,  0.003167], [ 0.068120, -0.050119,  0.041998],
            [-0.002514, -0.051615,  0.078259], [-0.030532, -0.053111, -0.071740],
            [ 0.074536, -0.054607, -0.018668], [-0.050861, -0.056103,  0.055513],
            [ 0.010572, -0.057599, -0.074345], [ 0.032360, -0.059095,  0.062903],
            [-0.070070, -0.060591,  0.019264], [ 0.055184, -0.062087, -0.040042],
            [-0.028852, -0.063583,  0.065405], [-0.016635, -0.065079, -0.070415],
            [ 0.062635, -0.066575,  0.004402], [-0.059341, -0.068071, -0.001459],
            [ 0.002143, -0.069567, -0.065043], [ 0.048231, -0.071063,  0.036332],
            [-0.007880, -0.072559, -0.060812], [-0.034223, -0.074055,  0.048679],
            [ 0.056250, -0.075551, -0.012366], [-0.048085, -0.077047, -0.027866]
        ])
        # NOTE(review): the integer cast discards the fractional part (it does
        # not round), so coordinates lose sub-unit precision — confirm intended.
        self.positions_3d = np.int_(self.positions_3d * 1000)

        # -------- Channel name --------
        # 'A1'..'A32', 'B1'..'B32', 'C1'..'C32', 'D1'..'D32', in table order.
        self.channel_names   = np.array([f'{sec}{i+1}' for sec in "ABCD" for i in range(32)])
        self.channel_to_index = {n: i for i, n in enumerate(self.channel_names)}
        # -------- Learnable edge --------
        # 0.1 * identity plus small Gaussian noise, created on Config.device.
        self.edge_importance = nn.Parameter(
            0.1 * torch.eye(Config.chan_num, device=Config.device) +
            0.01 * torch.randn(Config.chan_num, Config.chan_num, device=Config.device),
            requires_grad=True
        )

    # -------- Basic distance weighted adjacency --------
    def get_adjacency_matrix(self, calibration_constant=6.0, active_threshold=0.1):
        """Return a dense float32 adjacency matrix of inverse-distance weights.

        Steps: weight = ``calibration_constant / distance`` for every pair
        (0 where the distance is 0); weights below ``active_threshold`` are
        zeroed; the diagonal is zeroed; if the matrix is not constant it is
        min-max scaled to [0, 1]; finally the diagonal is set to 1.
        """
        # Pairwise Euclidean distances via broadcasting: (N,1,3) - (N,3).
        dists = np.linalg.norm(self.positions_3d[:, None] - self.positions_3d, axis=-1)
        # np.where evaluates both branches, so silence the divide-by-zero
        # warning raised for zero distances.
        with np.errstate(divide='ignore'):
            weights = np.where(dists != 0, calibration_constant / dists, 0.0)
        weights[weights < active_threshold] = 0.0
        np.fill_diagonal(weights, 0.0)
        # Skip the rescale for a constant matrix to avoid dividing by zero.
        if weights.max() > weights.min():
            weights = (weights - weights.min()) / (weights.max() - weights.min())
        np.fill_diagonal(weights, 1.0)
        return weights.astype(np.float32)

# ---------- Residual GCN ----------
class resGCN(nn.Module):
    """Two conv + batch-norm + ELU stages mapping ``dim * band`` channels to
    ``Config.chan_num`` channels.

    Args:
        dim: multiplied with ``band`` to obtain the input channel count.
        band: see ``dim``.

    NOTE(review): ``forward`` accepts ``x_p`` and ``L`` but reads neither, so
    no graph propagation or residual input takes part in the output —
    confirm whether that is intended.
    """

    def __init__(self, dim, band):
        super().__init__()
        inc = dim * band
        outc = Config.chan_num
        self.g1 = nn.Conv2d(inc, outc, (1, 3), padding=(0, 1), bias=False)
        self.g2 = nn.Conv2d(outc, outc, (1, 1), bias=False)
        self.bn1, self.bn2 = nn.BatchNorm2d(outc), nn.BatchNorm2d(outc)
        # nn.ELU's first positional parameter is ``alpha``; the previous
        # ``nn.ELU(False)`` ran with alpha == 0 (a ReLU).  ``inplace`` is now
        # passed by keyword.  Re-validate checkpoints trained before this fix.
        self.elu = nn.ELU(inplace=False)

    def forward(self, x, x_p, L):
        """Return ``ELU(bn2(g2(ELU(bn1(g1(x))))))``; ``x_p`` and ``L`` are unused."""
        hidden = self.elu(self.bn1(self.g1(x)))
        return self.elu(self.bn2(self.g2(hidden)))

# ---------- HGCN ----------
class HGCN_Dual(nn.Module):
    """Two ``resGCN`` branches, one given a short-range and one a long-range
    normalised adjacency, whose outputs are summed.

    The distance-based adjacency and the pairwise electrode distances depend
    only on the fixed electrode table, so they are computed once in
    ``__init__`` and kept in non-persistent buffers (not part of
    ``state_dict``) instead of being rebuilt from NumPy on every forward call.
    Changing ``self.elec.positions_3d`` after construction is therefore not
    picked up.

    Args:
        dim: forwarded to both ``resGCN`` branches.
        tau_mm: distance threshold separating "short" (< tau) from "long"
            (>= tau) edges, in the units of ``Electrodes.positions_3d``.

    NOTE(review): ``Electrodes`` is not an ``nn.Module``, so
    ``self.elec.edge_importance`` is not registered as a parameter of this
    module.  Also, ``resGCN.forward`` in this file does not read its adjacency
    argument, so the matrices built here do not currently affect the output.
    """

    def __init__(self, dim, tau_mm: float = 25.0):
        super().__init__()
        self.elec = Electrodes()
        self.tau = tau_mm
        # Two independent residual GCN branches
        self.gcn_short = resGCN(dim, Config.band_num)
        self.gcn_long = resGCN(dim, Config.band_num)

        static_adj = torch.tensor(self.elec.get_adjacency_matrix(), dtype=torch.float32)
        positions = torch.tensor(self.elec.positions_3d, dtype=torch.float32)
        self.register_buffer("_static_adj", static_adj, persistent=False)
        self.register_buffer("_pair_dist", torch.cdist(positions, positions),
                             persistent=False)

    @staticmethod
    def _normalize(A: torch.Tensor):
        """Return ``D^-1/2 (A + I) D^-1/2`` with ``D`` the row sums of ``A + I``."""
        A = A + torch.eye(A.shape[-1], device=A.device)
        deg = A.sum(-1, keepdim=True)
        D_inv_sqrt = deg.pow(-0.5)
        # Rows with zero degree would give inf; treat them as isolated.
        D_inv_sqrt[torch.isinf(D_inv_sqrt)] = 0.0
        return D_inv_sqrt * A * D_inv_sqrt.transpose(-1, -2)

    def _split_adj(self, base_adj: torch.Tensor):
        """Mask ``base_adj`` by distance and normalise each part.

        Returns:
            ``(A_short, A_long)`` for edges with distance ``< tau`` and
            ``>= tau`` respectively.
        """
        dist = self._pair_dist.to(base_adj.device)
        mask_short = (dist < self.tau).float()
        mask_long = (dist >= self.tau).float()
        A_short = self._normalize(base_adj * mask_short)
        A_long = self._normalize(base_adj * mask_long)
        return A_short, A_long

    def forward(self, x: torch.Tensor):
        """Return ``gcn_short(x, x, A_short) + gcn_long(x, x, A_long)``."""
        # ``.to`` is a no-op when the buffer already lives on x's device.
        base_adj = self._static_adj.to(x.device)
        # Add the learnable edge term, then symmetrise.
        base_adj = base_adj + self.elec.edge_importance
        base_adj = (base_adj + base_adj.t()) / 2.0

        A_short, A_long = self._split_adj(base_adj)

        out_s = self.gcn_short(x, x, A_short)
        out_l = self.gcn_long(x, x, A_long)
        return out_s + out_l


class SGCN_Dual(nn.Module):
    """Thin wrapper around ``HGCN_Dual`` that also exposes its ``elec`` object.

    Args:
        dim: forwarded to ``HGCN_Dual``.
        tau_mm: forwarded to ``HGCN_Dual``.
    """

    def __init__(self, dim, tau_mm: float = 25.0):
        super().__init__()
        graph_net = HGCN_Dual(dim, tau_mm)
        self.hgcn = graph_net
        self.elec = graph_net.elec

    def forward(self, x):
        """Delegate to the wrapped ``HGCN_Dual``."""
        result = self.hgcn(x)
        return result

# ============================================================
#                     EEGEncoder
# ============================================================
class EEGEncoder(nn.Module):
    """Fuse a temporal branch and a spectral branch into one EEG embedding.

    Shape comments in ``forward`` follow the original annotations and assume
    ``filtered`` is [B,128,256] and ``prefeat`` is [B,128,10] — TODO confirm
    against the data loader.
    """

    def __init__(self):
        super().__init__()
        # --- Temporal ---
        self.t_block = Temporal_Block()
        self.tgcn = SGCN_Dual(dim=5)
        self.chanattn = nn.Conv2d(Config.chan_num, Config.inc, 1)  # 128→32
        self.proj1 = nn.Linear(256 * Config.chan_num, 256)

        # --- Spectral ---
        self.dgcn = SGCN_Dual(dim=1)
        self.pgcn = SGCN_Dual(dim=1)
        # nn.ELU's first positional parameter is ``alpha``; the previous
        # ``nn.ELU(False)`` ran with alpha == 0 (a ReLU).  ``inplace`` is now
        # passed by keyword.  Re-validate checkpoints trained before this fix.
        self.mlp_low = nn.Sequential(nn.Linear(128, 128), nn.ELU(inplace=False),
                                     nn.Linear(128, 128))
        self.mlp_high = nn.Sequential(nn.Linear(128, 128), nn.ELU(inplace=False),
                                      nn.Linear(128, 128))
        self.spec_up = nn.ConvTranspose1d(128, 128, 4, 4)

        # --- Fuse ---
        # 288 = 32 temporal channels (Config.inc) + 2 * 128 spectral channels.
        self.final_conv = nn.Conv2d(288, 128, 1)
        self.dpout = nn.Dropout(0.5)

    def forward(self, filtered, prefeat):
        """Encode EEG.

        Args:
            filtered: time-domain EEG; a channel axis is inserted at dim 1 and
                expanded to 5 copies.
            prefeat: pre-computed features whose last axis is split at index 5
                into a "low" and a "high" part.
        Returns:
            Tensor after the final 1x1 convolution and dropout
            ([B,128,256] under the shape assumptions above).
        """
        B = filtered.size(0)

        # ----- Temporal -----
        x = filtered.unsqueeze(1).expand(-1, 5, -1, -1)        # [B,5,128,256]
        t = self.t_block(x)                                    # [B,25,128,256]
        t = self.tgcn(t)                                       # [B,128,128,256]
        t = self.chanattn(t)                                   # [B,32,128,256]
        t = t.permute(0, 3, 1, 2).contiguous()                 # [B,256,32,128]
        # NOTE(review): after the permute the leading non-batch axis is time,
        # so this reshape regroups elements rather than keeping the 32
        # ``chanattn`` channels as rows — confirm this mixing is intended.
        t = t.reshape(B, Config.inc, -1)                       # [B,32,32768]
        t = self.proj1(t)                                      # [B,32,256]

        # ----- Spectral -----
        low = self.dgcn(prefeat[:, :, :5].permute(0, 2, 1).unsqueeze(3)).squeeze(-1)
        high = self.pgcn(prefeat[:, :, 5:].permute(0, 2, 1).unsqueeze(3)).squeeze(-1)
        # The transposed convolution upsamples x4; keep the first 256 steps.
        low = self.spec_up(self.mlp_low(low))[:, :, :256]
        high = self.spec_up(self.mlp_high(high))[:, :, :256]
        spec = torch.cat((low, high), 1)                       # [B,256,256]

        # ----- Fuse & Output -----
        fused = torch.cat((t, spec), 1).unsqueeze(-1)          # [B,288,256,1]
        out = self.final_conv(fused).squeeze(-1)               # [B,128,256]
        return self.dpout(out)


## AudioEncoder

In [24]:
class AudioEncoder(nn.Module):
    """Waveform encoder: one bias-free strided 1-D convolution plus ReLU.

    Args:
        L: kernel size of the convolution; the stride is ``L // 2``.
        N: number of output channels.
    """

    def __init__(self, L, N):
        super().__init__()
        self.L = L
        self.N = N
        self.conv1d_U = nn.Conv1d(
            in_channels=1,
            out_channels=N,
            kernel_size=L,
            stride=L // 2,
            bias=False,
        )

    def forward(self, mixture):
        """Encode a batch of waveforms.

        Args:
            mixture: [M, T], M is batch size, T is #samples
        Returns:
            mixture_w: [M, N, K], where K = (T-L)/(L/2)+1 = 2T/L-1
        """
        framed = self.conv1d_U(mixture.unsqueeze(1))  # add channel axis: [M, 1, T]
        return F.relu(framed)

## utils_mossformer

### RotaryEmbedding

In [None]:
def exists(x):
    """Return True unless ``x`` is ``None`` (falsy values such as 0 count as existing)."""
    if x is None:
        return False
    return True

def rotate_half(x: Tensor) -> Tensor:
    """Split the last axis in two chunks ``(a, b)`` and return ``(-b, a)``."""
    front, back = torch.chunk(x, 2, dim=-1)
    return torch.cat([back.neg(), front], dim=-1)

# NOTE(review): recent PyTorch releases deprecate ``torch.cuda.amp.autocast``
# in favour of ``torch.amp.autocast("cuda", ...)`` — left unchanged here.
@torch.cuda.amp.autocast(enabled=False)
def _apply_rotary(
    angles: Tensor,
    t: Tensor,
    *,
    start_index: int = 0,
    scale: Tensor | float = 1.0,
    seq_dim: int = -2,
) -> Tensor:
    """Rotate a slice of the last axis of ``t`` by ``angles``.

    The slice ``t[..., start_index:start_index + angles.shape[-1]]`` is
    replaced by ``slice * cos * scale + rotate_half(slice) * sin * scale``;
    features outside that slice pass through untouched.  The result is cast
    back to the dtype of ``t``.  Runs with CUDA autocast disabled.

    Raises:
        AssertionError: if ``angles`` has more features than ``t``.
    """
    if t.ndim == 3:
        # Keep only the trailing ``seq_len`` entries along angles' first axis.
        seq_len = t.shape[seq_dim]
        angles = angles[-seq_len:]

    rot_d = angles.shape[-1]
    stop = start_index + rot_d
    assert rot_d <= t.shape[-1], (
        f"rotation dim {rot_d} > tensor dim {t.shape[-1]}"
    )

    head = t[..., :start_index]
    body = t[..., start_index:stop]
    tail = t[..., stop:]

    cos_term = body * angles.cos() * scale
    sin_term = rotate_half(body) * angles.sin() * scale
    rotated = cos_term + sin_term
    return torch.cat((head, rotated, tail), dim=-1).type_as(t)

# ---------------- RotaryEmbedding ---------------- #
class RotaryEmbedding(Module):
    def __init__(
        self,
        dim: int,
        *,
        freqs_for: Literal["lang", "pixel", "constant"] = "lang",
        theta: float = 10000.0,
        max_freq: int = 10,
        num_freqs: int = 1,
        learned_freq: bool = False,
        use_xpos: bool = False,
        xpos_scale_base: int = 512,
        interpolate_factor: float = 1.0,
        theta_rescale_factor: float = 1.0,
        seq_before_head_dim: bool = False,
        cache_if_possible: bool = True,
    ):
        super().__init__()

        # —— frequency —— #
        theta *= theta_rescale_factor ** (dim / max(dim - 2, 1))
        if freqs_for == "lang":
            freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        elif freqs_for == "pixel":
            freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
        else:  # 'constant'
            freqs = torch.ones(num_freqs)

        self.freqs = torch.nn.Parameter(freqs, requires_grad=learned_freq)
        self.freqs_for = freqs_for
        self.learned_freq = learned_freq

        self.seq_before_head_dim = seq_before_head_dim
        self.default_seq_dim = -3 if seq_before_head_dim else -2

        self.interpolate_factor = max(1.0, interpolate_factor)
        self.use_xpos = use_xpos
        self.xpos_scale_base = xpos_scale_base
        self.cache_if_possible = cache_if_possible
        self._angle_cache: Dict[Tuple[int, torch.device, torch.dtype], Tensor] = {}
        self._scale_cache: Dict[Tuple[int, torch.device, torch.dtype], Tensor] = {}

        if use_xpos:
            scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
            self.register_buffer("xpos_scale", scale)
        else:
            self.register_buffer("xpos_scale", torch.tensor([]))  

    def _get_angles(
        self, seq_len: int, *, device: torch.device, dtype: torch.dtype
    ) -> Tensor:
        if self.interpolate_factor != 1.0 and self.freqs_for == "lang":
            pos = torch.arange(seq_len, device=device, dtype=dtype)
            pos = pos / self.interpolate_factor
            angles = einsum("i , j -> i j", pos, self.freqs.to(device=device, dtype=dtype))
        else:
            angles = einsum(
                "i , j -> i j",
                torch.arange(seq_len, device=device, dtype=dtype),
                self.freqs.to(device=device, dtype=dtype),
            )
        return torch.cat((angles, angles), dim=-1)  

    def _maybe_get_cached(
        self, seq_len: int, *, device: torch.device, dtype: torch.dtype
    ) -> Tuple[Tensor, Tensor]:
        key = (seq_len, device, dtype)
        if not self.cache_if_possible or key not in self._angle_cache:
            angles = self._get_angles(seq_len, device=device, dtype=dtype)
            self._angle_cache[key] = angles  
            if self.use_xpos:
                s = (
                    (torch.arange(seq_len, device=device, dtype=dtype) + 0.5)
                    / self.xpos_scale_base
                ) ** self.xpos_scale
                self._scale_cache[key] = torch.cat((s, s), dim=-1)
        return self._angle_cache[key], (
            self._scale_cache[key] if self.use_xpos else 1.0
        )

    def forward(
        self, seq_len: int, *, device: torch.device | None = None, dtype: torch.dtype | None = None
    ) -> Tuple[Tensor, Tensor]:
        device = device or self.freqs.device
        dtype = dtype or self.freqs.dtype
        angles, _ = self._maybe_get_cached(seq_len, device=device, dtype=dtype)
        return angles.sin(), angles.cos()

    def rotate_queries_or_keys(
        self,
        t: Tensor,
        *,
        seq_dim: int | None = None,
        offset: int = 0,
        start_index: int = 0,
    ) -> Tensor:
        seq_dim = self.default_seq_dim if seq_dim is None else seq_dim
        seq_len = t.shape[seq_dim]
        angles, scale = self._maybe_get_cached(seq_len + offset, device=t.device, dtype=t.dtype)
        angles = angles[offset : offset + seq_len]
        if isinstance(scale, Tensor):
            scale = scale[offset : offset + seq_len]

        while angles.ndim < t.ndim:
            angles = angles.unsqueeze(0)
        if isinstance(scale, Tensor) and scale.ndim < t.ndim:
            scale = scale.unsqueeze(0)

        return _apply_rotary(angles, t, start_index=start_index, scale=scale, seq_dim=seq_dim)

    def rotate_queries_with_cached_keys(
        self,
        q: Tensor,
        k: Tensor,
        *,
        seq_dim: int | None = None,
        start_index: int = 0,
    ) -> Tuple[Tensor, Tensor]:
        seq_dim = self.default_seq_dim if seq_dim is None else seq_dim
        offset = k.shape[seq_dim] - q.shape[seq_dim]
        q = self.rotate_queries_or_keys(q, seq_dim=seq_dim, offset=offset, start_index=start_index)
        k = self.rotate_queries_or_keys(k, seq_dim=seq_dim, start_index=start_index)
        return q, k


### fsmn

In [27]:
class UniDeepFsmn(nn.Module):
    """Linear -> ReLU -> linear projection followed by a depthwise memory
    convolution over the time axis, with residual connections.

    Args:
        input_dim: feature size of the input.
        output_dim: feature size after projection; the final residual
            ``input + out`` requires it to be broadcast-compatible with
            ``input_dim``.
        lorder: memory order; the convolution kernel spans ``2 * lorder - 1``
            time steps.  If ``None`` no layers are created and ``forward``
            cannot be used.
        hidden_size: width of the hidden linear layer.
    """

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
        super(UniDeepFsmn, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        if lorder is None:
            return

        self.lorder = lorder
        self.hidden_size = hidden_size

        self.linear = nn.Linear(input_dim, hidden_size)

        self.project = nn.Linear(hidden_size, output_dim, bias=False)

        self.conv1 = nn.Conv2d(output_dim, output_dim, [lorder+lorder-1, 1], [1, 1], groups=output_dim, bias=False)

    def forward(self, input):
        """Apply the block to ``input`` of shape (batch, time, input_dim)."""
        f1 = F.relu(self.linear(input))

        p1 = self.project(f1)

        # (B, T, D) -> (B, 1, T, D) -> (B, D, T, 1): features become conv channels.
        x = torch.unsqueeze(p1, 1)
        x_per = x.permute(0, 3, 2, 1)

        # Symmetric zero padding on the time axis keeps the length unchanged.
        y = F.pad(x_per, [0, 0, self.lorder - 1, self.lorder - 1])

        out = x_per + self.conv1(y)

        out1 = out.permute(0, 3, 2, 1)

        # Remove only the axis inserted above.  A bare squeeze() would also
        # drop a length-1 batch or time axis and make the residual sum
        # broadcast to the wrong shape.
        return input + out1.squeeze(1)

class DilatedDenseNet(nn.Module):
    """Densely connected stack of dilated, grouped time-axis convolutions.

    Stage ``i`` (1-based) pads the time axis symmetrically, applies a
    ``(2 * lorder - 1, 1)`` convolution with dilation ``2 ** (i - 1)`` on the
    concatenation of all earlier outputs and the input, then instance norm
    and PReLU.  The output of the last stage is returned.

    Args:
        depth: number of stages.
        lorder: memory order defining the kernel height.
        in_channels: channels of the input and of every stage output.
    """

    def __init__(self, depth=4, lorder=20, in_channels=64):
        super(DilatedDenseNet, self).__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.pad = nn.ConstantPad2d((1, 1, 1, 0), value=0.)  # not used in forward
        self.twidth = lorder*2-1
        self.kernel_size = (self.twidth, 1)
        for stage in range(1, self.depth + 1):
            dil = 2 ** (stage - 1)
            # Padding that keeps the time length unchanged for this dilation.
            pad_length = lorder + (dil - 1) * (lorder - 1) - 1
            self.add_module(
                f'pad{stage}',
                nn.ConstantPad2d((0, 0, pad_length, pad_length), value=0.),
            )
            self.add_module(
                f'conv{stage}',
                nn.Conv2d(self.in_channels * stage, self.in_channels,
                          kernel_size=self.kernel_size, dilation=(dil, 1),
                          groups=self.in_channels, bias=False),
            )
            self.add_module(f'norm{stage}', nn.InstanceNorm2d(in_channels, affine=True))
            self.add_module(f'prelu{stage}', nn.PReLU(self.in_channels))

    def forward(self, x):
        """Run all stages; each sees the channel-wise concat of prior results."""
        dense = x
        for stage in range(1, self.depth + 1):
            out = getattr(self, f'pad{stage}')(dense)
            out = getattr(self, f'conv{stage}')(out)
            out = getattr(self, f'norm{stage}')(out)
            out = getattr(self, f'prelu{stage}')(out)
            dense = torch.cat([out, dense], dim=1)
        return out

class UniDeepFsmn_dilated(nn.Module):
    """Linear -> ReLU -> linear projection followed by a ``DilatedDenseNet``
    (depth 2) over the time axis, added back to the input.

    Args:
        input_dim: feature size of the input.
        output_dim: feature size after projection; the final residual
            ``input + out`` requires it to be broadcast-compatible with
            ``input_dim``.
        lorder: memory order handed to ``DilatedDenseNet``.  If ``None`` no
            layers are created and ``forward`` cannot be used.
        hidden_size: width of the hidden linear layer.
    """

    def __init__(self, input_dim, output_dim, lorder=None, hidden_size=None):
        super(UniDeepFsmn_dilated, self).__init__()

        self.input_dim = input_dim
        self.output_dim = output_dim

        if lorder is None:
            return

        self.lorder = lorder
        self.hidden_size = hidden_size

        self.linear = nn.Linear(input_dim, hidden_size)

        self.project = nn.Linear(hidden_size, output_dim, bias=False)

        self.conv = DilatedDenseNet(depth=2, lorder=lorder, in_channels=output_dim)

    def forward(self, input):
        """Apply the block to ``input`` of shape (batch, time, input_dim)."""
        f1 = F.relu(self.linear(input))

        p1 = self.project(f1)

        # (B, T, D) -> (B, 1, T, D) -> (B, D, T, 1): features become conv channels.
        x = torch.unsqueeze(p1, 1)
        x_per = x.permute(0, 3, 2, 1)

        out = self.conv(x_per)

        out1 = out.permute(0, 3, 2, 1)

        # Remove only the axis inserted above.  A bare squeeze() would also
        # drop a length-1 batch or time axis and make the residual sum
        # broadcast to the wrong shape.
        return input + out1.squeeze(1)

### normalization

In [28]:
class LayerNorm(nn.Module):
    """Wrapper around ``torch.nn.LayerNorm`` (adapted from
    ``sb.nnet.normalization.LayerNorm``).

    Arguments
    ---------
    input_size : int or sequence of int
        Normalised shape passed to ``torch.nn.LayerNorm``.  Ignored when
        ``input_shape`` is given.
    input_shape : tuple
        The expected shape of the input; everything from the third dimension
        onwards (``input_shape[2:]``) is used as the normalised shape.
    eps : float
        This value is added to std deviation estimation to improve the
        numerical stability.
    elementwise_affine : bool
        If True, this module has learnable per-element affine parameters
        initialized to ones (for weights) and zeros (for biases).

    Example
    -------
    >>> input = torch.randn(100, 101, 128)
    >>> norm = LayerNorm(input_shape=input.shape)
    >>> output = norm(input)
    >>> output.shape
    torch.Size([100, 101, 128])
    """

    def __init__(
        self,
        input_size=None,
        input_shape=None,
        eps=1e-05,
        elementwise_affine=True,
    ):
        super().__init__()
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        # An explicit shape wins over ``input_size``.
        normalized_shape = input_size if input_shape is None else input_shape[2:]

        self.norm = torch.nn.LayerNorm(
            normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
        )

    def forward(self, x):
        """Returns the normalized input tensor.

        Arguments
        ---------
        x : torch.Tensor (batch, time, channels)
            input to normalize. 3d or 4d tensors are expected.
        """
        normalized = self.norm(x)
        return normalized

class CLayerNorm(nn.LayerNorm):
    """Channel-wise layer normalization for channel-first 3-D tensors.

    The input is transposed so that the channel axis is last, normalized
    with ``nn.LayerNorm``, and transposed back. Constructor arguments are
    passed unchanged to ``nn.LayerNorm``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, sample):
        """Normalize ``sample`` over its channel axis.

        Arguments
        ---------
        sample : torch.Tensor
            Tensor of shape [batch_size, channels, length].

        Returns
        -------
        torch.Tensor
            Normalized tensor with the same [batch_size, channels, length]
            layout.

        Raises
        ------
        RuntimeError
            If ``sample`` is not 3-D.
        """
        if sample.dim() != 3:
            # Instances have no ``__name__`` attribute; take the name from the
            # class so that this raises the intended RuntimeError.
            raise RuntimeError('{} only accept 3-D tensor as input'.format(
                type(self).__name__))
        # [N, C, T] -> [N, T, C]
        sample = torch.transpose(sample, 1, 2)
        # LayerNorm over the (now last) channel axis
        sample = super().forward(sample)
        # [N, T, C] -> [N, C, T]
        sample = torch.transpose(sample, 1, 2)
        return sample

class ScaleNorm(nn.Module):
    """Normalize by the scaled L2 norm of the last axis, with one learnable gain.

    Arguments
    ---------
    dim : int
        Size used to derive the fixed scale ``dim ** -0.5``.
    eps : float
        Lower bound applied to the scaled norm before dividing.
    """

    def __init__(self, dim, eps = 1e-5):
        super().__init__()
        self.eps = eps
        self.scale = dim ** -0.5
        # Single scalar gain shared by all features.
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Return ``x`` divided by its clamped, scaled norm and multiplied by ``g``."""
        scaled_norm = x.norm(dim = -1, keepdim = True) * self.scale
        denominator = scaled_norm.clamp(min = self.eps)
        return x / denominator * self.g

### conv_module

In [29]:
class Transpose(nn.Module):
    """Module form of ``torch.transpose`` so it can be placed in ``nn.Sequential``.

    ``shape`` holds the two axes to swap and is unpacked into the transpose call.
    """

    def __init__(self, shape: tuple):
        super().__init__()
        self.shape = shape

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x`` with the two configured axes swapped."""
        return torch.transpose(x, *self.shape)

class DepthwiseConv1d(nn.Module):
    """
    1-D convolution with ``groups == in_channels`` (depthwise convolution).

    ``out_channels`` must be an integer multiple of ``in_channels``.

        in_channels (int): Number of channels in the input
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
        bias (bool, optional): If True, adds a learnable bias to the output. Default: False
    Inputs: inputs
        - **inputs** (batch, in_channels, time): Tensor containing input vector
    Returns: outputs
        - **outputs** (batch, out_channels, time): Tensor produced by the depthwise 1-D convolution.
    """
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            padding: int = 0,
            bias: bool = False,
    ) -> None:
        super().__init__()
        assert out_channels % in_channels == 0, "out_channels should be constant multiple of in_channels"
        # One group per input channel makes the convolution depthwise.
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=bias,
        )

    def forward(self, inputs: Tensor) -> Tensor:
        """Apply the depthwise convolution to ``inputs``."""
        outputs = self.conv(inputs)
        return outputs

class ConvModule(nn.Module):
    """
    Residual depthwise convolution along the time axis.

    The module transposes the input to channel-first, applies a single
    depthwise 1-D convolution with 'SAME' padding, transposes back and adds
    the result to the input. No pointwise convolution, GLU, batch norm or
    dropout layer is built here.
        in_channels (int): Number of channels in the input
        kernel_size (int, optional): Size of the convolving kernel, must be odd. Default: 17
        expansion_factor (int, optional): Only validated (must be 2); not used to build any layer
        dropout_p (float, optional): Accepted but not used by this implementation
    Inputs: inputs
        inputs (batch, time, dim): Tensor contains input sequences
    Outputs: outputs
        outputs (batch, time, dim): Input plus the depthwise convolution output.
    """
    def __init__(
            self,
            in_channels: int,
            kernel_size: int = 17,
            expansion_factor: int = 2,
            dropout_p: float = 0.1,
    ) -> None:
        super().__init__()
        assert (kernel_size - 1) % 2 == 0, "kernel_size should be a odd number for 'SAME' padding"
        assert expansion_factor == 2, "Currently, Only Supports expansion_factor 2"

        # Half the kernel on each side keeps the time length unchanged.
        same_padding = (kernel_size - 1) // 2
        to_channel_first = Transpose(shape=(1, 2))
        depthwise = DepthwiseConv1d(
            in_channels, in_channels, kernel_size, stride=1, padding=same_padding
        )
        self.sequential = nn.Sequential(to_channel_first, depthwise)

    def forward(self, inputs: Tensor) -> Tensor:
        """Return ``inputs`` plus the depthwise convolution of ``inputs``."""
        conv_out = self.sequential(inputs)
        # Back to (batch, time, dim) before the residual sum.
        return inputs + conv_out.transpose(1, 2)

### Transformer

In [30]:
def exists(val):
    """Return True unless ``val`` is ``None``."""
    return not (val is None)

def padding_to_multiple_of(n, mult):
    """Return how much must be added to ``n`` to reach the next multiple of ``mult``.

    The result is 0 when ``n`` is already a multiple of ``mult``.
    """
    leftover = n % mult
    return 0 if leftover == 0 else mult - leftover

def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return ``d``."""
    if val is None:
        return d
    return val

class FFConvM(nn.Module):
    """Feed-forward projection followed by a convolution module.

    The layers are applied in this order: ``norm_klass(dim_in)``,
    ``nn.Linear(dim_in, dim_out)``, ``nn.SiLU``, ``ConvModule(dim_out)`` and
    ``nn.Dropout(dropout)``.

    Arguments
    ---------
    dim_in : int
        Feature size of the input.
    dim_out : int
        Feature size of the output.
    norm_klass : callable
        Factory called with ``dim_in`` to build the leading normalization.
    dropout : float
        Dropout probability of the final layer.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        norm_klass = nn.LayerNorm,
        dropout = 0.1
    ):
        super().__init__()
        stages = [
            norm_klass(dim_in),
            nn.Linear(dim_in, dim_out),
            nn.SiLU(),
            ConvModule(dim_out),
            nn.Dropout(dropout),
        ]
        self.mdl = nn.Sequential(*stages)

    def forward(
        self,
        x,
    ):
        """Run ``x`` through the layer stack and return the result."""
        return self.mdl(x)

class Gated_FSMN_dilated(nn.Module):
    """Gated memory unit: ``to_v(x) * fsmn(to_u(x)) + x``.

    Two ``FFConvM`` branches (layer norm, dropout 0.1) project the input to
    ``hidden_size`` features. The ``u`` branch is passed through a
    ``UniDeepFsmn_dilated`` block and multiplied with the ``v`` branch; the
    input is then added back.

    Arguments
    ---------
    in_channels : int
        Feature size of the input; also the first argument of the FSMN block.
    out_channels : int
        Second argument of the FSMN block.
    lorder : int
        Third argument of the FSMN block.
    hidden_size : int
        Output size of both branches and fourth argument of the FSMN block.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        lorder,
        hidden_size
    ):
        super().__init__()
        # Both branches are configured identically but hold separate weights.
        branch_kwargs = dict(
            dim_in = in_channels,
            dim_out = hidden_size,
            norm_klass = nn.LayerNorm,
            dropout = 0.1,
        )
        self.to_u = FFConvM(**branch_kwargs)
        self.to_v = FFConvM(**branch_kwargs)
        self.fsmn = UniDeepFsmn_dilated(in_channels, out_channels, lorder, hidden_size)

    def forward(
        self,
        x,
    ):
        """Return the gated memory output added to ``x``."""
        skip = x
        branch_u = self.to_u(x)
        branch_v = self.to_v(x)
        memory = self.fsmn(branch_u)
        return branch_v * memory + skip

class Gated_FSMN_Block_Dilated(nn.Module):
    """Residual block wrapping a gated dilated FSMN between two 1x1 convolutions.

    Pipeline: 1x1 ``Conv1d`` + ``PReLU`` (``dim`` -> ``inner_channels``),
    channel-wise layer norm, ``Gated_FSMN_dilated`` (lorder 20), channel-wise
    layer norm, 1x1 ``Conv1d`` back to ``dim``, then a residual sum with the
    block input.

    Arguments
    ---------
    dim : int
        Feature size of the block input and output.
    inner_channels : int
        Feature size used inside the block.
    group_size : int
        Stored on the instance as ``group_size``; not used by the layers
        built here.
    norm_type : str
        Accepted for signature compatibility; it does not influence any
        layer built by this block.
    """

    def __init__(self,
                 dim,
                 inner_channels = 256,
                 group_size = 256,
                 norm_type = 'scalenorm',
                 ):
        super().__init__()
        self.group_size = group_size

        pointwise_in = nn.Conv1d(dim, inner_channels, kernel_size=1)
        self.conv1 = nn.Sequential(pointwise_in, nn.PReLU())
        self.norm1 = CLayerNorm(inner_channels)
        # Dilated FSMN with gating.
        self.gated_fsmn = Gated_FSMN_dilated(
            inner_channels, inner_channels, lorder=20, hidden_size=inner_channels
        )
        self.norm2 = CLayerNorm(inner_channels)
        self.conv2 = nn.Conv1d(inner_channels, dim, kernel_size=1)

    def forward(self, input):
        """Return ``input`` plus the output of the gated FSMN pipeline.

        Axes 1 and 2 of ``input`` are swapped before each channel-first layer
        (convolutions and norms) and swapped back for the FSMN and the
        residual sum.
        """
        channel_first = input.transpose(2, 1)
        expanded = self.norm1(self.conv1(channel_first))
        gated = self.gated_fsmn(expanded.transpose(2, 1))
        reduced = self.conv2(self.norm2(gated.transpose(2, 1)))
        return reduced.transpose(2, 1) + input

class OffsetScale(nn.Module):
    """Per-head learnable scale and offset applied to the last axis.

    Arguments
    ---------
    dim : int
        Size of the last axis of the input.
    heads : int
        Number of (scale, offset) pairs, i.e. number of returned tensors.
    """

    def __init__(self, dim, heads = 1):
        super().__init__()
        param_shape = (heads, dim)
        self.gamma = nn.Parameter(torch.ones(param_shape))
        self.beta = nn.Parameter(torch.zeros(param_shape))
        # The scales are re-drawn from a narrow normal distribution.
        nn.init.normal_(self.gamma, std = 0.02)

    def forward(self, x):
        """Return a tuple with one scaled-and-shifted copy of ``x`` per head."""
        scaled = einsum('... d, h d -> ... h d', x, self.gamma)
        per_head = scaled + self.beta
        return per_head.unbind(dim = -2)
        
class FLASH_ShareA_FFConvM(nn.Module):
    """Gated attention layer combining grouped quadratic and linear attention.

    The input is projected to a hidden tensor that is split in half into
    ``v`` and ``u``, and to a shared query/key tensor from which four
    offset-scaled copies are derived (quadratic/linear queries and keys).
    The same attention is applied to ``v`` and ``u``; the two results gate
    each other and the output projection is added to the layer input.

    Arguments
    ---------
    dim : int
        Feature size of the input and output.
    group_size : int
        Number of sequence positions per group for the quadratic attention.
    query_key_dim : int
        Feature size of the shared query/key projection.
    expansion_factor : float
        ``int(dim * expansion_factor)`` is the width of the hidden
        projection that is split in half into ``v`` and ``u``.
    causal : bool
        If True, a causal mask is used inside each group and the linear
        attention only sees preceding groups.
    dropout : float
        Dropout probability for the quadratic attention weights and for the
        ``FFConvM`` projections.
    rotary_pos_emb : object or None
        Optional object whose ``rotate_queries_or_keys`` is applied to the
        queries and keys.
    norm_klass : callable
        Normalization factory handed to every ``FFConvM`` projection.
    shift_tokens : bool
        If True, half of the features are shifted by one position along the
        sequence before the projections.
    """

    def __init__(
        self,
        *,
        dim,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 1.,
        causal = False,
        dropout = 0.1,
        rotary_pos_emb = None,
        norm_klass = nn.LayerNorm,
        shift_tokens = True
    ):
        super().__init__()
        hidden_dim = int(dim * expansion_factor)
        self.group_size = group_size
        self.causal = causal
        self.shift_tokens = shift_tokens

        # positional embeddings
        self.rotary_pos_emb = rotary_pos_emb
        # dropout applied to the quadratic attention weights
        self.dropout = nn.Dropout(dropout)

        # projections
        self.to_hidden = FFConvM(
            dim_in = dim,
            dim_out = hidden_dim,
            norm_klass = norm_klass,
            dropout = dropout,
            )
        self.to_qk = FFConvM(
            dim_in = dim,
            dim_out = query_key_dim,
            norm_klass = norm_klass,
            dropout = dropout,
            )

        # four heads: quadratic query, linear query, quadratic key, linear key
        self.qk_offset_scale = OffsetScale(query_key_dim, heads = 4)

        # The gated product fed to ``to_out`` has the width of one half of the
        # hidden projection. Deriving it from ``hidden_dim`` (instead of the
        # hard-coded ``dim * 2``) keeps the layer usable for any expansion
        # factor; for ``expansion_factor = 4`` it is still ``dim * 2``.
        self.to_out = FFConvM(
            dim_in = hidden_dim // 2,
            dim_out = dim,
            norm_klass = norm_klass,
            dropout = dropout,
            )

        self.gateActivate = nn.Sigmoid()

    def forward(
        self,
        x,
        *,
        mask = None
    ):
        """Apply the gated attention layer and return ``x`` plus its output.

        Arguments
        ---------
        x : torch.Tensor
            Input whose second-to-last axis is the sequence and whose last
            axis has ``dim`` features.
        mask : torch.Tensor or None
            Optional boolean tensor of shape (batch, sequence); positions
            set to False are excluded from the attention.

        Notation used in the einsum patterns:
        b - batch
        n - sequence length (within groups)
        g - group dimension
        d - feature dimension (keys)
        e - feature dimension (values)
        i - sequence dimension (source)
        j - sequence dimension (target)
        """
        # No pre-normalization is applied here; each FFConvM projection
        # starts with its own normalization layer.
        normed_x = x

        # Token shift: half of the features are delayed by one position
        # along the sequence axis (zero-filled at the first position).
        if self.shift_tokens:
            x_shift, x_pass = normed_x.chunk(2, dim = -1)
            x_shift = F.pad(x_shift, (0, 0, 1, -1), value = 0.)
            normed_x = torch.cat((x_shift, x_pass), dim = -1)

        # initial projections
        v, u = self.to_hidden(normed_x).chunk(2, dim = -1)
        qk = self.to_qk(normed_x)

        # offset and scale
        quad_q, lin_q, quad_k, lin_k = self.qk_offset_scale(qk)

        # Forward the caller's mask; it used to be accepted here but never
        # reached the attention computation.
        att_v, att_u = self.cal_attention(x, quad_q, lin_q, quad_k, lin_k, v, u, mask = mask)

        # The attended ``u`` modulates ``v`` and vice versa, through a sigmoid gate.
        out = (att_u*v ) * self.gateActivate(att_v*u)

        x = x + self.to_out(out)
        return x

    def cal_attention(self, x, quad_q, lin_q, quad_k, lin_k, v, u, mask = None):
        """Compute quadratic (within-group) plus linear attention for ``v`` and ``u``.

        Arguments
        ---------
        x : torch.Tensor
            Layer input; only its batch size, sequence length and device are read.
        quad_q, lin_q, quad_k, lin_k : torch.Tensor
            Queries and keys for the quadratic and the linear attention.
        v, u : torch.Tensor
            The two value tensors that share the attention weights.
        mask : torch.Tensor or None
            Optional boolean (batch, sequence) mask; False marks positions
            to ignore.

        Returns
        -------
        map
            Iterator yielding the attended ``v`` and the attended ``u``, each
            folded back to the full sequence and trimmed to its original length.
        """
        b, n, device, g = x.shape[0], x.shape[-2], x.device, self.group_size

        if exists(mask):
            # Zero out masked keys so they do not contribute to the linear attention.
            lin_mask = rearrange(mask, '... -> ... 1')
            lin_k = lin_k.masked_fill(~lin_mask, 0.)

        # rotate queries and keys

        if exists(self.rotary_pos_emb):
            quad_q, lin_q, quad_k, lin_k = map(self.rotary_pos_emb.rotate_queries_or_keys, (quad_q, lin_q, quad_k, lin_k))

        # padding for groups

        padding = padding_to_multiple_of(n, g)

        if padding > 0:
            quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: F.pad(t, (0, 0, 0, padding), value = 0.), (quad_q, quad_k, lin_q, lin_k, v, u))

            # Padded positions are always masked out, even without a caller mask.
            mask = default(mask, torch.ones((b, n), device = device, dtype = torch.bool))
            mask = F.pad(mask, (0, padding), value = False)

        # group along sequence

        quad_q, quad_k, lin_q, lin_k, v, u = map(lambda t: rearrange(t, 'b (g n) d -> b g n d', n = self.group_size), (quad_q, quad_k, lin_q, lin_k, v, u))

        if exists(mask):
            mask = rearrange(mask, 'b (g j) -> b g 1 j', j = g)

        # calculate quadratic attention output

        sim = einsum('... i d, ... j d -> ... i j', quad_q, quad_k) / g

        # NOTE: a relative position bias used to be added to ``sim`` here; the
        # original author's note says removing it solved an infinite-loss problem.

        attn = F.relu(sim) ** 2
        attn = self.dropout(attn)

        if exists(mask):
            attn = attn.masked_fill(~mask, 0.)

        if self.causal:
            causal_mask = torch.ones((g, g), dtype = torch.bool, device = device).triu(1)
            attn = attn.masked_fill(causal_mask, 0.)

        quad_out_v = einsum('... i j, ... j d -> ... i d', attn, v)
        quad_out_u = einsum('... i j, ... j d -> ... i d', attn, u)

        # calculate linear attention output

        if self.causal:
            lin_kv = einsum('b g n d, b g n e -> b g d e', lin_k, v) / g
            # exclusive cumulative sum along group dimension
            lin_kv = lin_kv.cumsum(dim = 1)
            lin_kv = F.pad(lin_kv, (0, 0, 0, 0, 1, -1), value = 0.)
            lin_out_v = einsum('b g d e, b g n d -> b g n e', lin_kv, lin_q)

            lin_ku = einsum('b g n d, b g n e -> b g d e', lin_k, u) / g
            # exclusive cumulative sum along group dimension
            lin_ku = lin_ku.cumsum(dim = 1)
            lin_ku = F.pad(lin_ku, (0, 0, 0, 0, 1, -1), value = 0.)
            lin_out_u = einsum('b g d e, b g n d -> b g n e', lin_ku, lin_q)
        else:
            # Global key/value summary, normalized by the unpadded sequence length.
            lin_kv = einsum('b g n d, b g n e -> b d e', lin_k, v) / n
            lin_out_v = einsum('b g n d, b d e -> b g n e', lin_q, lin_kv)

            lin_ku = einsum('b g n d, b g n e -> b d e', lin_k, u) / n
            lin_out_u = einsum('b g n d, b d e -> b g n e', lin_q, lin_ku)

        # fold back groups into full sequence, and excise out padding
        return map(lambda t: rearrange(t, 'b g n d -> b (g n) d')[:, :n], (quad_out_v+lin_out_v, quad_out_u+lin_out_u))

class FLASHTransformer_DualA_FSMN(nn.Module):
    """Stack of ``depth`` FLASH attention layers, each followed by a gated FSMN block.

    Arguments
    ---------
    dim : int
        Feature size of the input and output.
    depth : int
        Number of (FLASH layer, FSMN block) pairs.
    group_size : int
        Group size forwarded to every FLASH layer.
    query_key_dim : int
        Query/key size forwarded to every FLASH layer; the shared rotary
        embedding uses ``min(32, query_key_dim)`` dimensions.
    expansion_factor : float
        Expansion factor forwarded to every FLASH layer.
    causal : bool
        Causal flag forwarded to every FLASH layer.
    attn_dropout : float
        Dropout probability forwarded to every FLASH layer.
    norm_type : str
        Either ``'scalenorm'`` or ``'layernorm'``; selects the normalization
        class handed to the FLASH layers.
    shift_tokens : bool
        Token-shift flag forwarded to every FLASH layer.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        group_size = 256,
        query_key_dim = 128,
        expansion_factor = 4.,
        causal = False,
        attn_dropout = 0.1,
        norm_type = 'scalenorm',
        shift_tokens = True
    ):
        super().__init__()
        assert norm_type in ('scalenorm', 'layernorm'), 'norm_type must be one of scalenorm or layernorm'

        norm_klass = {'scalenorm': ScaleNorm, 'layernorm': nn.LayerNorm}[norm_type]

        self.group_size = group_size

        # max rotary embedding dimensions of 32, partial Rotary embeddings, from Wang et al - GPT-J
        rotary_pos_emb = RotaryEmbedding(dim = min(32, query_key_dim))

        self.fsmn = nn.ModuleList([Gated_FSMN_Block_Dilated(dim) for _ in range(depth)])

        # Every FLASH layer gets the same configuration and the same rotary
        # embedding object.
        flash_kwargs = dict(
            dim = dim,
            group_size = group_size,
            query_key_dim = query_key_dim,
            expansion_factor = expansion_factor,
            causal = causal,
            dropout = attn_dropout,
            rotary_pos_emb = rotary_pos_emb,
            norm_klass = norm_klass,
            shift_tokens = shift_tokens,
        )
        self.layers = nn.ModuleList([FLASH_ShareA_FFConvM(**flash_kwargs) for _ in range(depth)])

    def _build_repeats(self, in_channels, out_channels, lorder, hidden_size, repeats=1):
        """Return an ``nn.Sequential`` of ``repeats`` identically configured ``UniDeepFsmn`` blocks."""
        blocks = [
            UniDeepFsmn(in_channels, out_channels, lorder, hidden_size)
            for _ in range(repeats)
        ]
        return nn.Sequential(*blocks)

    def forward(
        self,
        x,
        *,
        mask = None
    ):
        """Apply each FLASH layer followed by its paired FSMN block.

        ``mask`` is handed to every FLASH layer unchanged.
        """
        for flash, fsmn_block in zip(self.layers, self.fsmn):
            x = fsmn_block(flash(x, mask = mask))
        return x

class TransformerEncoder_FLASH_DualA_FSMN(nn.Module):
    """Encoder made of a ``FLASHTransformer_DualA_FSMN`` stack and a final layer norm.

    Only ``num_layers`` and ``d_model`` configure the layers built here. The
    remaining arguments are accepted so the constructor matches a generic
    transformer-encoder signature, but they are not used.

    Arguments
    ---------
    num_layers : int
        Depth of the FLASH/FSMN stack.
    nhead : int
        Not used.
    d_ffn : int
        Not used.
    input_shape : tuple
        Not used.
    d_model : int
        Feature size of the input, the stack and the final layer norm.
    kdim : int
        Not used.
    vdim : int
        Not used.
    dropout : float
        Not used.
    activation : torch class
        Not used.
    normalize_before : bool
        Not used.
    causal : bool
        Not used.
    attention_type : str
        Not used.
    """
    def __init__(
        self,
        num_layers,
        nhead,
        d_ffn,
        input_shape=None,
        d_model=None,
        kdim=None,
        vdim=None,
        dropout=0.0,
        activation=nn.ReLU,
        normalize_before=False,
        causal=False,
        attention_type="regularMHA",
    ):

        super().__init__()

        # The stack keeps its own defaults for everything except size and depth.
        self.flashT = FLASHTransformer_DualA_FSMN(depth=num_layers, dim=d_model)
        self.norm = LayerNorm(d_model, eps=1e-6)

    def forward(
        self,
        src,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        pos_embs: Optional[torch.Tensor] = None,
    ):
        """Return the layer-normalized output of the FLASH/FSMN stack.

        Arguments
        ----------
        src : tensor
            The sequence to the encoder (required).
        src_mask : tensor
            Accepted but not used.
        src_key_padding_mask : tensor
            Accepted but not used.
        pos_embs : tensor
            Accepted but not used.
        """
        encoded = self.flashT(src)
        return self.norm(encoded)

### one_path_flash_fsmn

In [31]:
# Small module-level constant. It is not referenced in the visible part of
# this section -- presumably a numerical-stability epsilon used elsewhere
# (TODO confirm).
EPS = 1e-8

class ScaledSinuEmbedding(nn.Module):
    """Sinusoidal position table multiplied by a single learnable scale.

    Arguments
    ---------
    dim : int
        Embedding size; one inverse frequency is created for every second
        index in ``range(0, dim)``.
    """

    def __init__(self, dim):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1,))
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer('inv_freq', 1. / (10000 ** exponents))

    def forward(self, x):
        """Return the scaled table for the positions along axis 1 of ``x``.

        Only the length of axis 1 and the device of ``x`` are read. The
        result has one row per position; each row holds the sines followed
        by the cosines.
        """
        seq_len = x.shape[1]
        positions = torch.arange(seq_len, device = x.device).type_as(self.inv_freq)
        # Outer product: one angle per (position, frequency) pair.
        angles = einsum('i , j -> i j', positions, self.inv_freq)
        table = torch.cat((angles.sin(), angles.cos()), dim = -1)
        return table * self.scale

class Linear(torch.nn.Module):
    """Linear transformation y = wx + b backed by ``nn.Linear``.

    Arguments
    ---------
    n_neurons : int
        Number of output neurons (i.e. the dimensionality of the output).
    input_shape : tuple
        Shape of the input tensor; used to infer the input size when
        ``input_size`` is not given.
    input_size : int
        Size of the input features. Takes precedence over ``input_shape``.
    bias : bool
        If True, the additive bias b is adopted.
    combine_dims : bool
        If True and the input is 4D, combine 3rd and 4th dimensions of input.

    Example
    -------
    >>> inputs = torch.rand(10, 50, 40)
    >>> lin_t = Linear(input_shape=(10, 50, 40), n_neurons=100)
    >>> output = lin_t(inputs)
    >>> output.shape
    torch.Size([10, 50, 100])
    """

    def __init__(
        self,
        n_neurons,
        input_shape=None,
        input_size=None,
        bias=True,
        combine_dims=False,
    ):
        super().__init__()
        self.combine_dims = combine_dims

        if input_size is None:
            if input_shape is None:
                raise ValueError("Expected one of input_shape or input_size")
            if len(input_shape) == 4 and self.combine_dims:
                # The last two axes are flattened together in ``forward``.
                input_size = input_shape[2] * input_shape[3]
            else:
                input_size = input_shape[-1]

        # Weights are initialized following pytorch approach
        self.w = nn.Linear(input_size, n_neurons, bias=bias)

    def forward(self, x):
        """Return the linear transformation of ``x``.

        Arguments
        ---------
        x : torch.Tensor
            Input to transform linearly.
        """
        if self.combine_dims and x.ndim == 4:
            batch, steps, dim_a, dim_b = x.shape
            x = x.reshape(batch, steps, dim_a * dim_b)

        return self.w(x)

class GlobalLayerNorm(nn.Module):
    """Global layer normalization over all non-batch axes.

    Arguments
    ---------
       dim : int
           Number of channels; size of the learnable weight and bias.
       shape : int
           Rank of the expected input (3 or 4); decides how many trailing
           singleton axes the affine parameters get.
       eps : float
           A value added to the denominator for numerical stability.
       elementwise_affine : bool
          A boolean value that when set to True,
          this module has learnable per-channel affine parameters
          initialized to ones (for weights) and zeros (for biases).

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> GLN = GlobalLayerNorm(10, 3)
    >>> x_norm = GLN(x)
    """

    def __init__(self, dim, shape, eps=1e-8, elementwise_affine=True):
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine

        if not self.elementwise_affine:
            self.register_parameter("weight", None)
            self.register_parameter("bias", None)
        else:
            # Trailing singleton axes let the parameters broadcast over the
            # non-channel axes; other ranks create no parameters.
            trailing = {3: (1,), 4: (1, 1)}.get(shape)
            if trailing is not None:
                self.weight = nn.Parameter(torch.ones(self.dim, *trailing))
                self.bias = nn.Parameter(torch.zeros(self.dim, *trailing))

    def forward(self, x):
        """Return the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor of size [N, C, K, S] or [N, C, L]. Tensors of any other
            rank are returned unchanged.
        """
        rank = x.dim()
        if rank != 3 and rank != 4:
            return x

        # gln: mean and variance are taken over everything except the batch axis.
        reduce_axes = tuple(range(1, rank))
        mean = torch.mean(x, reduce_axes, keepdim=True)
        centered = x - mean
        var = torch.mean(centered ** 2, reduce_axes, keepdim=True)
        std = torch.sqrt(var + self.eps)

        if self.elementwise_affine:
            return self.weight * centered / std + self.bias
        return centered / std


class CumulativeLayerNorm(nn.LayerNorm):
    """Channel-only layer normalization for channel-first tensors.

       Arguments
       ---------
       dim : int
        Size of the channel axis that is normalized.
       elementwise_affine : True
        Learnable per-element affine parameters.

    Example
    -------
    >>> x = torch.randn(5, 10, 20)
    >>> CLN = CumulativeLayerNorm(10)
    >>> x_norm = CLN(x)
    """

    def __init__(self, dim, elementwise_affine=True):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=1e-8)

    def forward(self, x):
        """Return the normalized tensor.

        Arguments
        ---------
        x : torch.Tensor
            Tensor size [N, C, K, S] or [N, C, L]. Tensors of any other rank
            are returned unchanged.
        """
        if x.dim() == 4:
            # N x C x K x S -> N x K x S x C, normalize channels, then back.
            channels_last = x.permute(0, 2, 3, 1).contiguous()
            x = super().forward(channels_last).permute(0, 3, 1, 2).contiguous()
        if x.dim() == 3:
            # N x C x L -> N x L x C, normalize channels, then back.
            x = super().forward(x.transpose(1, 2)).transpose(1, 2)
        return x


def select_norm(norm, dim, shape):
    """Build the normalization layer named by ``norm``.

    ``"gln"`` gives a ``GlobalLayerNorm``, ``"cln"`` a
    ``CumulativeLayerNorm``, ``"ln"`` a single-group ``nn.GroupNorm``; any
    other value falls back to ``nn.BatchNorm1d``. ``shape`` is only used by
    the ``"gln"`` variant.
    """
    builders = {
        "gln": lambda: GlobalLayerNorm(dim, shape, elementwise_affine=True),
        "cln": lambda: CumulativeLayerNorm(dim, elementwise_affine=True),
        "ln": lambda: nn.GroupNorm(1, dim, eps=1e-8),
    }
    build = builders.get(norm, lambda: nn.BatchNorm1d(dim))
    return build()


class SBFLASHBlock_DualA(nn.Module):
    """Wrapper that builds a ``TransformerEncoder_FLASH_DualA_FSMN`` from
    transformer-style arguments.

    Every argument except ``use_positional_encoding`` is forwarded to the
    wrapped encoder after ``activation`` has been mapped to a module class.

    Arguments
    ---------
    num_layers : int
        Number of layers.
    d_model : int
        Dimensionality of the representation.
    nhead : int
        Number of attention heads.
    d_ffn : int
        Dimensionality of positional feed forward.
    input_shape : tuple
        Shape of input.
    kdim : int
        Dimension of the key (Optional).
    vdim : int
        Dimension of the value (Optional).
    dropout : float
        Dropout rate.
    activation : str
        Activation function, ``"relu"`` or ``"gelu"``.
    use_positional_encoding : bool
        Stored on the instance as ``use_positional_encoding``.
    norm_before : bool
        Forwarded as ``normalize_before``.
    attention_type : str
        Forwarded unchanged.

    Raises
    ------
    ValueError
        If ``activation`` is neither ``"relu"`` nor ``"gelu"``.
    """

    def __init__(
        self,
        num_layers,
        d_model,
        nhead,
        d_ffn=2048,
        input_shape=None,
        kdim=None,
        vdim=None,
        dropout=0.1,
        activation="relu",
        use_positional_encoding=False,
        norm_before=False,
        attention_type="regularMHA",
    ):

        super().__init__()
        self.use_positional_encoding = use_positional_encoding

        # Map the activation name to its module class.
        known_activations = (("relu", nn.ReLU), ("gelu", nn.GELU))
        matches = [klass for name, klass in known_activations if activation == name]
        if not matches:
            raise ValueError("unknown activation")
        activation = matches[0]

        encoder_kwargs = dict(
            num_layers=num_layers,
            nhead=nhead,
            d_ffn=d_ffn,
            input_shape=input_shape,
            d_model=d_model,
            kdim=kdim,
            vdim=vdim,
            dropout=dropout,
            activation=activation,
            normalize_before=norm_before,
            attention_type=attention_type,
        )
        self.mdl = TransformerEncoder_FLASH_DualA_FSMN(**encoder_kwargs)

    def forward(self, x):
        """Return the output of the wrapped encoder.

        Arguments
        ---------
        x : torch.Tensor
            Tensor shape [B, L, N],
            where, B = Batchsize,
                   L = time points
                   N = number of filters

        """
        return self.mdl(x)


def _get_activation_fn(activation):
    """Just a wrapper to get the activation functions.
    """

    if activation == "relu":
        return F.relu
    elif activation == "gelu":
        return F.gelu


class Dual_Computation_Block(nn.Module):
    """Computation block that applies only the intra model.

    The inter-path of the original dual-path design is not built or run
    here; the block output is the intra-path result.

    Arguments
    ---------
    intra_mdl : torch.nn.module
        Model applied to the [B, S, N] view of the input.
    out_channels : int
        Dimensionality of the intra model output; sizes the optional norm
        and linear layers.
    norm : str
        Normalization type passed to ``select_norm``; ``None`` disables it.
    skip_around_intra : bool
        Skip connection around the intra layer.
    linear_layer_after_inter_intra : bool
        Linear layer or not after the intra model.
    """

    def __init__(
        self,
        intra_mdl,
        out_channels,
        norm="ln",
        skip_around_intra=True,
        linear_layer_after_inter_intra=True,
    ):
        super().__init__()

        self.intra_mdl = intra_mdl
        self.skip_around_intra = skip_around_intra
        self.linear_layer_after_inter_intra = linear_layer_after_inter_intra

        # Optional normalization applied on the [B, N, S] layout.
        self.norm = norm
        if norm is not None:
            self.intra_norm = select_norm(norm, out_channels, 3)

        # Optional linear layer applied on the [B, S, N] layout.
        if linear_layer_after_inter_intra:
            self.intra_linear = Linear(out_channels, input_size=out_channels)

    def forward(self, x):
        """Returns the output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, S].


        Return
        ---------
        out: torch.Tensor
            Output tensor of dimension [B, N, S].
            where, B = Batchsize,
               N = number of filters
               S = sequence length
        """
        # Unpacking enforces a 3-D input; the sizes themselves are not needed.
        _batch, _filters, _steps = x.shape

        # [B, N, S] -> [B, S, N]
        seq = self.intra_mdl(x.permute(0, 2, 1).contiguous())

        if self.linear_layer_after_inter_intra:
            seq = self.intra_linear(seq)

        # [B, S, N] -> [B, N, S]
        out = seq.permute(0, 2, 1).contiguous()
        if self.norm is not None:
            out = self.intra_norm(out)

        if self.skip_around_intra:
            out = out + x

        return out


class Dual_Path_Model(nn.Module):
    """Mask-estimation network built from a stack of Dual_Computation_Block.

    The class name is kept from the dual-path family (dualpathrnn, sepformer,
    dptnet), but this variant only takes an intra model; there is no
    ``inter_model`` argument.

    Arguments
    ---------
    in_channels : int
        Number of channels produced by the final 1x1 decoder convolution.
    out_channels : int
        Number of channels of the tensor given to ``forward`` and processed
        by the computation blocks.
    intra_model : torch.nn.module
        Model handed to every Dual_Computation_Block. Each block is
        deep-copied, so layers do not share weights.
    num_layers : int
        Number of layers of Dual Computation Block.
    norm : str
        Normalization type, forwarded to Dual_Computation_Block.
    K : int
        Chunk length. Stored as ``self.K``; ``forward`` does not use it.
    num_spks : int
        Number of sources (speakers).
    skip_around_intra : bool
        Skip connection around intra.
    linear_layer_after_inter_intra : bool
        Linear layer after the intra model.
    use_global_pos_enc : bool
        Add global positional encodings (ScaledSinuEmbedding) to the input.
    max_length : int
        Accepted for backward compatibility; not used.

    Example
    ---------
    ``intra_block`` is any module mapping [B, L, 64] to [B, L, 64].

    >>> dual_path_model = Dual_Path_Model(64, 64, intra_block, num_spks=2)
    >>> x = torch.randn(10, 64, 2000)
    >>> x = dual_path_model(x)
    >>> x.shape
    torch.Size([2, 10, 64, 2000])
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        intra_model,
        num_layers=1,
        norm="ln",
        K=200,
        num_spks=2,
        skip_around_intra=True,
        linear_layer_after_inter_intra=True,
        use_global_pos_enc=True,
        max_length=20000,
    ):
        super().__init__()
        self.K = K
        self.num_spks = num_spks
        self.num_layers = num_layers
        self.use_global_pos_enc = use_global_pos_enc

        if self.use_global_pos_enc:
            self.pos_enc = ScaledSinuEmbedding(out_channels)

        self.dual_mdl = nn.ModuleList([])
        for _ in range(num_layers):
            # deepcopy so every layer owns an independent copy of intra_model.
            self.dual_mdl.append(
                copy.deepcopy(
                    Dual_Computation_Block(
                        intra_model,
                        out_channels,
                        norm,
                        skip_around_intra=skip_around_intra,
                        linear_layer_after_inter_intra=linear_layer_after_inter_intra,
                    )
                )
            )

        # Expands the channel axis so it can be split into one slice per speaker.
        self.conv1d_out = nn.Conv1d(
            out_channels, out_channels * num_spks, kernel_size=1
        )
        self.conv1_decoder = nn.Conv1d(out_channels, in_channels, 1, bias=False)
        self.prelu = nn.PReLU()
        self.activation = nn.ReLU()
        # gated output layer
        self.output = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Tanh()
        )
        self.output_gate = nn.Sequential(
            nn.Conv1d(out_channels, out_channels, 1), nn.Sigmoid()
        )

    def forward(self, x):
        """Returns the output tensor.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of dimension [B, N, L] with N = out_channels.

        Returns
        -------
        out : torch.Tensor
            Output tensor of dimension [spks, B, N_in, L]
            where, spks = Number of speakers
               B = Batchsize,
               N_in = in_channels
               L = the number of time points
        """
        # Each shape comment describes the tensor after the line below it.

        if self.use_global_pos_enc:
            # pos_enc receives [B, L, N]; its result is transposed on its
            # first/last axes and broadcast-added to x.
            # NOTE(review): presumably pos_enc returns [L, N] -- confirm
            # against ScaledSinuEmbedding.
            emb = self.pos_enc(x.transpose(1, -1))
            x = x + emb.transpose(0, -1)

        # [B, N, L]
        for block in self.dual_mdl:
            x = block(x)
        x = self.prelu(x)

        # [B, N*spks, L]
        x = self.conv1d_out(x)
        B, _, S = x.shape

        # [B*spks, N, L]
        x = x.view(B * self.num_spks, -1, S)

        # [B*spks, N, L]: tanh branch gated by the sigmoid branch.
        x = self.output(x) * self.output_gate(x)

        # [B*spks, N_in, L]
        x = self.conv1_decoder(x)

        # [B, spks, N_in, L]
        _, N, L = x.shape
        x = x.view(B, self.num_spks, N, L)
        x = self.activation(x)

        # [spks, B, N_in, L]
        return x.transpose(0, 1)

    def _padding(self, input, K):
        """Zero-pad the time axis so it can be cut into half-overlapping chunks.

        Not called by ``forward``; kept as a helper for chunked processing.

        Arguments
        ---------
        input : torch.Tensor
            Tensor of size [B, N, L].
            where, B = Batchsize,
                   N = number of filters
                   L = time points
        K : int
            Chunk length. The hop size is P = K // 2.

        Return
        -------
        output : torch.Tensor
            Tensor of size [B, N, P + L + gap + P].
        gap : int
            Number of zeros appended after the signal (before the trailing
            P zeros).
        """
        B, N, L = input.shape
        P = K // 2
        # gap lies in [1, K]: a full chunk of zeros is added even when
        # P + L is already a multiple of K.
        gap = K - (P + L % K) % K
        if gap > 0:
            # new_zeros keeps the pad on input's dtype AND device.
            input = torch.cat([input, input.new_zeros(B, N, gap)], dim=2)

        _pad = input.new_zeros(B, N, P)
        input = torch.cat([_pad, input, _pad], dim=2)

        return input, gap

    def _Segmentation(self, input, K):
        """Split the time axis into chunks of length K with 50% overlap.

        Not called by ``forward``.

        Arguments
        ---------
        K : int
            Length of the chunks.
        input : torch.Tensor
            Tensor with dim [B, N, L].

        Return
        -------
        output : torch.tensor
            Tensor with dim [B, N, K, S].
            where, B = Batchsize,
               N = number of filters
               K = time points in each chunk
               S = the number of chunks
               L = the number of time points
        gap : int
            Padding length reported by ``_padding``.
        """
        B, N, L = input.shape
        P = K // 2
        input, gap = self._padding(input, K)
        # Two chunkings of the padded signal, offset by the hop size P.
        # [B, N, S/2, K] each
        input1 = input[:, :, :-P].contiguous().view(B, N, -1, K)
        input2 = input[:, :, P:].contiguous().view(B, N, -1, K)
        # Interleave them chunk by chunk, then move K in front of S.
        # [B, N, K, S]
        input = (
            torch.cat([input1, input2], dim=3).view(B, N, -1, K).transpose(2, 3)
        )

        return input.contiguous(), gap

    def _over_add(self, input, gap):
        """Merge the sequence with the overlap-and-add method.

        Not called by ``forward``. Applied to the unmodified output of
        ``_Segmentation`` it returns twice the original signal, because
        every sample is covered by two chunks.

        Arguments
        ---------
        input : torch.tensor
            Tensor with dim [B, N, K, S].
        gap : int
            Padding length.

        Return
        -------
        output : torch.tensor
            Tensor with dim [B, N, L].
            where, B = Batchsize,
               N = number of filters
               K = time points in each chunk
               S = the number of chunks
               L = the number of time points

        """
        B, N, K, S = input.shape
        P = K // 2
        # [B, N, S/2, 2K]: each row holds one chunk from each chunking.
        input = input.transpose(2, 3).contiguous().view(B, N, -1, K * 2)

        # Undo the P-sample offset of each chunking before summing.
        input1 = input[:, :, :, :K].contiguous().view(B, N, -1)[:, :, P:]
        input2 = input[:, :, :, K:].contiguous().view(B, N, -1)[:, :, :-P]
        input = input1 + input2
        # [B, N, L]
        if gap > 0:
            input = input[:, :, :-gap]

        return input

## Fusion

In [32]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence, then dropout.

    The table is stored in the ``pe`` buffer with shape
    [max_len, 1, d_model]: even feature indices hold sines and odd indices
    hold cosines.

    Args:
        d_model: size of the last (embedding) dimension. Odd values are
            supported; the last feature is then a sine without a cosine
            partner.
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions in the table.
    """

    def __init__(self, d_model = 256, dropout = 0.1, max_len = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one cosine slot fewer than sine slots,
        # so drop the surplus column instead of failing on a shape mismatch.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Tensor, shape [seq_len, batch_size, embedding_dim]

        Returns:
            Tensor of the same shape: ``x`` plus the first ``seq_len`` rows
            of the table (broadcast over the batch axis), after dropout.
        """
        x = x + self.pe[:x.size(0)].to(x.device)
        return self.dropout(x)

In [33]:
class Fusion(nn.Module):
    """Fuse encoded audio with transformer-processed EEG features.

    The audio branch is normalised and passed through a 1x1 convolution. The
    EEG branch gets positional encodings and a 5-layer transformer encoder
    (d_model=256), and is then linearly interpolated to the audio length.
    Both are concatenated on the channel axis and projected back to ``N``
    channels.

    Args:
        N: number of audio channels.
        B: embedding size. The fusion convolution expects ``N + 2 * B`` input
            channels, so ``2 * B`` has to equal the EEG feature size of 256.
    """

    def __init__(self, N, B):
        super().__init__()

        # Audio branch: [M, N, K] -> [M, N, K]
        self.layer_norm = nn.GroupNorm(1, N, eps=1e-8)
        self.bottleneck_conv1x1 = nn.Conv1d(N, N, 1, bias=False)

        # EEG branch: positional encoding followed by a transformer encoder.
        self.po_encoding = PositionalEncoding(d_model=256)
        eeg_layer = TransformerEncoderLayer(
            d_model=256, nhead=4, dim_feedforward=64 * 4
        )
        self.eeg_net = TransformerEncoder(eeg_layer, num_layers=5)

        # 1x1 projection of the concatenated audio + EEG channels back to N.
        self.fusion = nn.Conv1d(N + 2 * B, N, 1, bias=False)

    def forward(self, x, eeg, reference, speech):
        """Return the fused representation.

        Args:
            x: audio features, a 3-D tensor [M, N, D].
            eeg: EEG features; the first two axes are swapped before the
                transformer (assumes [batch, seq, 256] -- TODO confirm
                against the EEG encoder).
            reference: unused.
            speech: unused.

        Returns:
            Tensor produced by the fusion convolution ([M, N, D] under the
            shape assumptions above).
        """
        _, _, audio_len = x.size()

        audio = self.layer_norm(x)
        audio = self.bottleneck_conv1x1(audio)

        # Swap the first two axes for the positional encoding / transformer.
        eeg_seq = self.po_encoding(eeg.transpose(0, 1))
        eeg_seq = self.eeg_net(eeg_seq)

        # Swap back, then move the feature axis in front of the sequence axis.
        eeg_feat = eeg_seq.transpose(0, 1).transpose(1, 2)
        # Resample the EEG sequence to the audio length.
        eeg_feat = F.interpolate(eeg_feat, audio_len, mode='linear')

        stacked = torch.cat((audio, eeg_feat), 1)
        return self.fusion(stacked)

## Separator

In [35]:
class Separator(nn.Module):
    """Mask estimator: a Dual_Path_Model built around an SBFLASHBlock_DualA.

    All hyper-parameters are read from ``Config``.
    """

    def __init__(self):
        super().__init__()

        # MossFormer2 intra block
        intra_kwargs = dict(
            num_layers=Config.intra_numlayers,
            d_model=Config.encoder_out_nchannels,
            nhead=Config.intra_nhead,
            d_ffn=Config.intra_dffn,
            dropout=Config.intra_dropout,
            use_positional_encoding=Config.intra_use_positional,
            norm_before=Config.intra_norm_before,
        )
        intra_model = SBFLASHBlock_DualA(**intra_kwargs)

        masknet_kwargs = dict(
            in_channels=Config.encoder_out_nchannels,
            out_channels=Config.encoder_out_nchannels,
            intra_model=intra_model,
            num_layers=Config.masknet_numlayers,
            norm=Config.masknet_norm,
            K=Config.masknet_chunksize,
            num_spks=Config.masknet_numspks,
            skip_around_intra=Config.masknet_extraskipconnection,
            linear_layer_after_inter_intra=Config.masknet_useextralinearlayer,
        )
        self.masknet = Dual_Path_Model(**masknet_kwargs)

    def forward(self, x):
        """Estimate the mask for ``x``.

        Keep this API same with TasNet.

        Args:
            x: mixture_w, [M, N, K], M is batch size.

        Returns:
            The output of ``self.masknet`` with its leading axis removed when
            that axis has size 1 (``squeeze(0)``).
        """
        mask = self.masknet(x)
        return mask.squeeze(0)



def overlap_and_add(signal, frame_step):
    """Reconstructs a signal from a framed representation.

    Adds potentially overlapping frames of a signal with shape
    `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
    The resulting tensor has shape `[..., output_size]` where

        output_size = (frames - 1) * frame_step + frame_length

    Args:
        signal: A [..., frames, frame_length] Tensor. All dimensions may be
            unknown, and rank must be at least 2. May live on CPU or GPU and
            does not have to be contiguous.
        frame_step: An integer denoting overlap offsets. Must be less than or
            equal to frame_length.

    Returns:
        A Tensor with shape [..., output_size] containing the overlap-added
        frames of signal's inner-most two dimensions, on the same device and
        with the same dtype as `signal`.

    Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
    """
    outer_dimensions = signal.size()[:-2]
    frames, frame_length = signal.size()[-2:]

    # Work on sub-frames of gcd length so that every frame start falls on a
    # sub-frame boundary.
    subframe_length = math.gcd(frame_length, frame_step)
    subframe_step = frame_step // subframe_length
    subframes_per_frame = frame_length // subframe_length
    output_size = frame_step * (frames - 1) + frame_length
    output_subframes = output_size // subframe_length

    subframe_signal = signal.reshape(*outer_dimensions, -1, subframe_length)

    # Output sub-frame index for every input sub-frame. Built on the
    # signal's own device so index_add_ works on both CPU and GPU.
    frame = torch.arange(0, output_subframes, device=signal.device).unfold(
        0, subframes_per_frame, subframe_step
    )
    frame = frame.reshape(-1)

    result = signal.new_zeros(*outer_dimensions, output_subframes, subframe_length)
    result.index_add_(-2, frame, subframe_signal)
    return result.view(*outer_dimensions, -1)


## Decoder

In [37]:
class Decoder(nn.Module):
    """Turn masked encoder features back into a single-channel signal.

    Args:
        N: number of input feature channels.
        L: kernel size of the transposed convolution; its stride is L // 2.
    """

    def __init__(self, N, L):
        super().__init__()
        # Transposed convolution: N feature channels -> 10 channels.
        self.conv = nn.ConvTranspose1d(N, 10, L, stride=L // 2)
        # Bias-free 1x1 convolution collapsing the 10 channels into one.
        self.conv1d = nn.Conv1d(10, 1, 1, bias=False)

    def forward(self, mixture_w, est_mask):
        """Apply ``est_mask`` to ``mixture_w`` and decode the product.

        Args:
            mixture_w: encoder output, [batch, N, frames].
            est_mask: mask multiplied element-wise with ``mixture_w``.

        Returns:
            Tensor of shape [batch, 1, (frames - 1) * (L // 2) + L].
        """
        masked = mixture_w * est_mask            # [batch, N, frames]
        decoded = self.conv(masked).contiguous() # [batch, 10, samples]
        return self.conv1d(decoded)              # [batch, 1, samples]

## Mossformer

In [40]:
# Small constant; the Mossformer class defined in this cell does not reference it.
# NOTE(review): presumably a numerical-stability epsilon used by other cells -- confirm.
EPS = 1e-8

class Mossformer(nn.Module):
    """End-to-end model: EEG encoder, audio encoder, fusion, separator, decoder.

    All sizes and the target device are read from ``Config``. Every parameter
    with more than one dimension is re-initialised with Xavier-normal after
    the sub-modules are built.
    """

    def __init__(self):
        super().__init__()

        self.N = Config.encoder_out_nchannels
        self.L = Config.encoder_kernel_size
        self.B = Config.emb_size
        self.K = Config.masknet_chunksize
        self.device = Config.device

        self.eegencoder = EEGEncoder()
        self.audioencoder = AudioEncoder(self.L, self.N)
        self.fusion = Fusion(self.N, self.B)
        self.separator = Separator()
        self.decoder = Decoder(self.N, self.L)

        # Xavier-normal init for weight tensors; 1-D parameters keep their defaults.
        for weight in self.parameters():
            if weight.dim() > 1:
                nn.init.xavier_normal_(weight)

    def forward(self, mixture, eeg, feature,reference=None):
        """Estimate the source signal from a mixture and EEG input.

        Args:
            mixture: [M, T], M is batch size, T is #samples
            eeg: EEG input for the EEG encoder.
            feature: extra input passed to the EEG encoder next to ``eeg``.
            reference: passed through to the fusion module.

        Returns:
            est_source: decoder output, right-padded along the last axis by
            ``T - T_conv`` so its length matches ``mixture``.
        """
        target = self.device
        mixture, eeg, feature = (
            tensor.to(target) for tensor in (mixture, eeg, feature)
        )

        eeg_emb = self.eegencoder(eeg, feature)
        mixture_w = self.audioencoder(mixture)

        fused = self.fusion(mixture_w, eeg_emb, reference, mixture)
        est_mask = self.separator(fused)
        est_source = self.decoder(mixture_w, est_mask)

        # Align the decoded length with the input mixture length.
        missing = mixture.size(-1) - est_source.size(-1)
        return F.pad(est_source, (0, missing))

# Pre-Training(Single)

In [None]:
# Start training with the train/validation loaders
# (argument semantics are defined by train_model, elsewhere in the notebook).
train_model(
    Train_dataloader,
    Valid_dataloader,
    start_epoch=0,
    additional_epochs=60,
)