In [3]:
import torch


In [4]:
import torch

# Report the installed PyTorch release and the CUDA version that build was made with.
for reported_version in (torch.__version__, torch.version.cuda):
    print(reported_version)

2.2.2
11.8


In [5]:
# !pip install mamba-ssm

Reference: Kaggle notebook "Run Mamba on P100", which the build steps below follow: https://www.kaggle.com/code/wolfy73/run-mamba-on-p100/notebook

In [6]:
# !git clone https://github.com/Dao-AILab/causal-conv1d.git
# %cd causal-conv1d
# !git checkout v1.0.2

In [7]:
%%writefile /home/main/projects/CourseWork/CourseWork2024/causal-conv1d/setup.py
# Copyright (c) 2023, Tri Dao.
import sys
import warnings
import os
import re
import shutil
import ast
from pathlib import Path
from packaging.version import parse, Version
import platform

from setuptools import setup, find_packages
import subprocess

import urllib.request
import urllib.error
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

import torch
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDAExtension,
    CUDA_HOME,
)


# The README becomes the package's long description.  The path is relative, so a
# README.md must exist in the directory this script is run from.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

# Distribution name passed to setup(); also the prefix of the prebuilt wheel filename.
PACKAGE_NAME = "causal_conv1d"

# URL template for prebuilt wheels; get_wheel_url() fills in {tag_name} and {wheel_name}.
BASE_WHEEL_URL = "https://github.com/Dao-AILab/causal-conv1d/releases/download/{tag_name}/{wheel_name}"

# Each switch below is enabled only when its environment variable is exactly "TRUE"
# (the comparison is case-sensitive); any other value, or no value, leaves it off.
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("CAUSAL_CONV1D_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("CAUSAL_CONV1D_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("CAUSAL_CONV1D_FORCE_CXX11_ABI", "FALSE") == "TRUE"


def get_platform():
    """
    Map ``sys.platform`` to the platform tag used in wheel filenames.

    Returns "linux_x86_64" for any platform string starting with "linux",
    "win_amd64" for "win32", and "macosx_<major>.<minor>_x86_64" for "darwin"
    (version taken from ``platform.mac_ver()``).

    Raises:
        ValueError: for every other platform.
    """
    current = sys.platform
    if current.startswith("linux"):
        return "linux_x86_64"
    if current == "win32":
        return "win_amd64"
    if current != "darwin":
        raise ValueError("Unsupported platform: {}".format(sys.platform))
    # Only the first two components of the macOS version go into the tag.
    major_minor = platform.mac_ver()[0].split(".")[:2]
    return "macosx_{}_x86_64".format(".".join(major_minor))


def get_cuda_bare_metal_version(cuda_dir):
    """
    Ask the nvcc found under ``cuda_dir`` for its release number.

    Runs ``<cuda_dir>/bin/nvcc -V`` and takes the token that follows the word
    "release" in its output (up to the first comma) as the version string.

    Returns:
        (raw_output, version): the full text printed by nvcc and the release
        number passed through ``packaging.version.parse``.
    """
    nvcc_banner = subprocess.check_output(
        [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
    )
    tokens = nvcc_banner.split()
    release_token = tokens[tokens.index("release") + 1]
    version_text = release_token.split(",")[0]
    return nvcc_banner, parse(version_text)


def check_if_cuda_home_none(global_option: str) -> None:
    """
    Emit a warning when ``CUDA_HOME`` is None; do nothing otherwise.

    Args:
        global_option: name reported at the start of the warning text.
    """
    if CUDA_HOME is None:
        # A warning rather than an error: the user may be about to download a
        # prebuilt wheel, in which case nvcc is not needed.
        message = (
            f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
            "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
            "only images whose names contain 'devel' will provide nvcc."
        )
        warnings.warn(message)


def append_nvcc_threads(nvcc_extra_args):
    """Return a new list: ``nvcc_extra_args`` followed by ``--threads 4``; the input list is not modified."""
    thread_flags = ["--threads", "4"]
    return nvcc_extra_args + thread_flags


cmdclass = {}
ext_modules = []

# Describe the CUDA extension unless the build was explicitly skipped via
# CAUSAL_CONV1D_SKIP_CUDA_BUILD.
if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    check_if_cuda_home_none("causal_conv1d")
    cc_flag = []
    # The nvcc version is only known when a CUDA toolkit was located.  It starts
    # out as None so the sm_90 check below cannot hit an unbound name when
    # CUDA_HOME is None.
    bare_metal_version = None
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                "causal_conv1d is only supported on CUDA 11.6 and above.  "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )

    # Architectures that are always targeted.
    cc_flag += [
        "-gencode", "arch=compute_60,code=sm_60",
        "-gencode", "arch=compute_70,code=sm_70",
        "-gencode", "arch=compute_80,code=sm_80",
    ]
    # sm_90 is only requested when nvcc is known to be 11.8 or newer.
    if bare_metal_version is not None and bare_metal_version >= Version("11.8"):
        cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    ext_modules.append(
        CUDAExtension(
            name="causal_conv1d_cuda",
            sources=[
                "csrc/causal_conv1d.cpp",
                "csrc/causal_conv1d_fwd.cu",
                "csrc/causal_conv1d_bwd.cu",
                "csrc/causal_conv1d_update.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3"],
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                        "--ptxas-options=-v",
                        "-lineinfo",
                    ]
                    + cc_flag
                ),
            },
            # Absolute path: see the note next to this_dir about ninja builds.
            include_dirs=[Path(this_dir) / "csrc" / "causal_conv1d"],
        )
    )


def get_package_version():
    """
    Read ``__version__`` from ``causal_conv1d/__init__.py`` next to this script.

    Returns the version as a string.  When the CAUSAL_CONV1D_LOCAL_VERSION
    environment variable is set to a non-empty value, it is appended as
    ``+<value>``.
    """
    init_path = Path(this_dir) / "causal_conv1d" / "__init__.py"
    with open(init_path, "r") as f:
        init_source = f.read()
    # The right-hand side of the assignment is a Python literal; evaluate it safely.
    assignment = re.search(r"^__version__\s*=\s*(.*)$", init_source, re.MULTILINE)
    base_version = ast.literal_eval(assignment.group(1))
    local_suffix = os.environ.get("CAUSAL_CONV1D_LOCAL_VERSION")
    return f"{base_version}+{local_suffix}" if local_suffix else str(base_version)


def get_wheel_url():
    """
    Guess the download URL and filename of a prebuilt wheel for this environment.

    The filename encodes the package version, a CUDA tag, the torch major.minor
    version, the C++11 ABI flag, the running Python version and the platform tag.

    Returns:
        (wheel_url, wheel_filename)
    """
    # The CUDA tag comes from the CUDA version torch was built with, not from the
    # locally installed toolkit.
    built_with_cuda = parse(torch.version.cuda)
    torch_release = parse(torch.__version__)
    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
    # to save CI time. Minor versions should be compatible.
    wheel_cuda = parse("11.8") if built_with_cuda.major == 11 else parse("12.2")

    py_tag = "cp{}{}".format(sys.version_info.major, sys.version_info.minor)
    platform_tag = get_platform()
    package_version = get_package_version()
    abi_flag = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    build_tag = (
        f"cu{wheel_cuda.major}{wheel_cuda.minor}"
        f"torch{torch_release.major}.{torch_release.minor}"
        f"cxx11abi{abi_flag}"
    )
    wheel_filename = f"{PACKAGE_NAME}-{package_version}+{build_tag}-{py_tag}-{py_tag}-{platform_tag}.whl"
    wheel_url = BASE_WHEEL_URL.format(
        tag_name=f"v{package_version}", wheel_name=wheel_filename
    )
    return wheel_url, wheel_filename


class CachedWheelsCommand(_bdist_wheel):
    """
    bdist_wheel replacement that first tries to download a matching prebuilt wheel.

    pip runs bdist_wheel when it cannot find an existing wheel (which is currently the
    case for all installs). Unless FORCE_BUILD is set, run() guesses the wheel URL with
    get_wheel_url(), downloads it and moves it to the path the regular command would have
    produced. If the download fails, the standard from-source build runs instead.
    """

    def run(self):
        """Download a prebuilt wheel if one is available, otherwise build from source."""
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)

            # Make the archive
            # Lifted from the root wheel processing command
            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
            if not os.path.exists(self.dist_dir):
                os.makedirs(self.dist_dir)

            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            shutil.move(wheel_filename, wheel_path)
        except urllib.error.URLError:
            # URLError is the base class of HTTPError, so this covers both a missing
            # release asset (HTTP error) and an unreachable host (no network, DNS failure).
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()


# bdist_wheel always goes through the cached-wheel lookup; build_ext is only
# registered when there is an extension module to compile.
_command_classes = {"bdist_wheel": CachedWheelsCommand}
if ext_modules:
    _command_classes["build_ext"] = BuildExtension

# Directory names excluded from package discovery.
_non_package_dirs = (
    "build",
    "csrc",
    "include",
    "tests",
    "dist",
    "docs",
    "benchmarks",
    "causal_conv1d.egg-info",
)

# NOTE(review): "buildtools" is not imported anywhere in this script — confirm it is
# really needed as a runtime requirement.
_runtime_requirements = [
    "torch",
    "packaging",
    "buildtools",
    "ninja",
]

setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(exclude=_non_package_dirs),
    author="Tri Dao",
    author_email="tri@tridao.me",
    description="Causal depthwise conv1d in CUDA, with a PyTorch interface",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/Dao-AILab/causal-conv1d",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass=_command_classes,
    python_requires=">=3.7",
    install_requires=_runtime_requirements,
)

Overwriting /home/main/projects/CourseWork/CourseWork2024/causal-conv1d/setup.py


In [8]:
!CAUSAL_CONV1D_FORCE_BUILD=TRUE pip install .

[31mERROR: Directory '.' is not installable. Neither 'setup.py' nor 'pyproject.toml' found.[0m[31m
[0m

In [9]:
import torch
from causal_conv1d import causal_conv1d_fn

# Smoke test for the causal_conv1d extension: all-zero input, weight and bias on the GPU.
batch, dim, seq, width = 10, 5, 17, 4
x = torch.zeros(batch, dim, seq, device='cuda')
weight = torch.zeros(dim, width, device='cuda')
bias = torch.zeros(dim, device='cuda')

causal_conv1d_fn(x, weight, bias, None)

tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.

In [10]:
# %cd ..
# !git clone https://github.com/state-spaces/mamba.git
# %cd mamba

In [11]:
%%writefile /home/main/projects/CourseWork/CourseWork2024/mamba/setup.py
# Copyright (c) 2023, Albert Gu, Tri Dao.
import sys
import warnings
import os
import re
import ast
from pathlib import Path
from packaging.version import parse, Version
import platform
import shutil

from setuptools import setup, find_packages
import subprocess

import urllib.request
import urllib.error
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel

import torch
from torch.utils.cpp_extension import (
    BuildExtension,
    CppExtension,
    CUDAExtension,
    CUDA_HOME,
)


# The README becomes the package's long description.  The path is relative, so a
# README.md must exist in the directory this script is run from.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()


# ninja build does not work unless include_dirs are abs path
this_dir = os.path.dirname(os.path.abspath(__file__))

# Distribution name passed to setup(); also the package directory that holds
# __init__.py (see get_package_version) and the prefix of the prebuilt wheel filename.
PACKAGE_NAME = "mamba_ssm"

# URL template for prebuilt wheels; get_wheel_url() fills in {tag_name} and {wheel_name}.
BASE_WHEEL_URL = "https://github.com/state-spaces/mamba/releases/download/{tag_name}/{wheel_name}"

# Each switch below is enabled only when its environment variable is exactly "TRUE"
# (the comparison is case-sensitive); any other value, or no value, leaves it off.
# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels
# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation
FORCE_BUILD = os.getenv("MAMBA_FORCE_BUILD", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("MAMBA_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI
FORCE_CXX11_ABI = os.getenv("MAMBA_FORCE_CXX11_ABI", "FALSE") == "TRUE"


def get_platform():
    """
    Translate ``sys.platform`` into the platform tag that appears in wheel filenames.

    "linux*" -> "linux_x86_64", "win32" -> "win_amd64" and "darwin" ->
    "macosx_<major>.<minor>_x86_64" (version read from ``platform.mac_ver()``).

    Raises:
        ValueError: when ``sys.platform`` is none of the above.
    """
    host = sys.platform
    if host.startswith("linux"):
        return "linux_x86_64"
    if host == "win32":
        return "win_amd64"
    if host == "darwin":
        # Keep only "<major>.<minor>" of the reported macOS version.
        version_parts = platform.mac_ver()[0].split(".")
        return "macosx_" + ".".join(version_parts[:2]) + "_x86_64"
    raise ValueError("Unsupported platform: {}".format(sys.platform))


def get_cuda_bare_metal_version(cuda_dir):
    """
    Query the nvcc under ``cuda_dir`` for its release number.

    Executes ``<cuda_dir>/bin/nvcc -V``; the whitespace-separated token after the
    word "release", cut at the first comma, is treated as the version.

    Returns:
        (raw_output, version): nvcc's full output and the release number parsed
        with ``packaging.version.parse``.
    """
    nvcc_cmd = [cuda_dir + "/bin/nvcc", "-V"]
    raw_output = subprocess.check_output(nvcc_cmd, universal_newlines=True)
    words = raw_output.split()
    after_release = words[words.index("release") + 1]
    return raw_output, parse(after_release.split(",")[0])


def check_if_cuda_home_none(global_option: str) -> None:
    """
    Warn when ``CUDA_HOME`` is None; return silently when it is set.

    Args:
        global_option: name placed at the start of the warning message.
    """
    if CUDA_HOME is None:
        # Deliberately a warning and not an error: a prebuilt wheel may still be
        # downloaded, and that path does not need nvcc.
        warning_text = (
            f"{global_option} was requested, but nvcc was not found.  Are you sure your environment has nvcc available?  "
            "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, "
            "only images whose names contain 'devel' will provide nvcc."
        )
        warnings.warn(warning_text)


def append_nvcc_threads(nvcc_extra_args):
    """Return ``nvcc_extra_args`` with ``--threads 4`` appended, as a new list."""
    parallel_flags = ["--threads", "4"]
    return nvcc_extra_args + parallel_flags


cmdclass = {}
ext_modules = []

# Describe the CUDA extension unless the build was explicitly skipped via
# MAMBA_SKIP_CUDA_BUILD.
if not SKIP_CUDA_BUILD:
    print("\n\ntorch.__version__  = {}\n\n".format(torch.__version__))
    TORCH_MAJOR = int(torch.__version__.split(".")[0])
    TORCH_MINOR = int(torch.__version__.split(".")[1])

    check_if_cuda_home_none(PACKAGE_NAME)
    cc_flag = []
    # The nvcc version is only known when a CUDA toolkit was located.  It starts
    # out as None so the sm_90 check below cannot hit an unbound name when
    # CUDA_HOME is None.
    bare_metal_version = None
    if CUDA_HOME is not None:
        _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME)
        if bare_metal_version < Version("11.6"):
            raise RuntimeError(
                f"{PACKAGE_NAME} is only supported on CUDA 11.6 and above.  "
                "Note: make sure nvcc has a supported version by running nvcc -V."
            )

    # Architectures that are always targeted.
    cc_flag += [
        "-gencode", "arch=compute_60,code=sm_60",
        "-gencode", "arch=compute_70,code=sm_70",
        "-gencode", "arch=compute_80,code=sm_80",
    ]
    # sm_90 is only requested when nvcc is known to be 11.8 or newer.
    if bare_metal_version is not None and bare_metal_version >= Version("11.8"):
        cc_flag += ["-gencode", "arch=compute_90,code=sm_90"]

    # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
    # torch._C._GLIBCXX_USE_CXX11_ABI
    # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
    if FORCE_CXX11_ABI:
        torch._C._GLIBCXX_USE_CXX11_ABI = True

    ext_modules.append(
        CUDAExtension(
            name="selective_scan_cuda",
            sources=[
                "csrc/selective_scan/selective_scan.cpp",
                "csrc/selective_scan/selective_scan_fwd_fp32.cu",
                "csrc/selective_scan/selective_scan_fwd_fp16.cu",
                "csrc/selective_scan/selective_scan_fwd_bf16.cu",
                "csrc/selective_scan/selective_scan_bwd_fp32_real.cu",
                "csrc/selective_scan/selective_scan_bwd_fp32_complex.cu",
                "csrc/selective_scan/selective_scan_bwd_fp16_real.cu",
                "csrc/selective_scan/selective_scan_bwd_fp16_complex.cu",
                "csrc/selective_scan/selective_scan_bwd_bf16_real.cu",
                "csrc/selective_scan/selective_scan_bwd_bf16_complex.cu",
            ],
            extra_compile_args={
                "cxx": ["-O3", "-std=c++17"],
                "nvcc": append_nvcc_threads(
                    [
                        "-O3",
                        "-std=c++17",
                        "-U__CUDA_NO_HALF_OPERATORS__",
                        "-U__CUDA_NO_HALF_CONVERSIONS__",
                        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
                        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
                        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
                        "--expt-relaxed-constexpr",
                        "--expt-extended-lambda",
                        "--use_fast_math",
                        "--ptxas-options=-v",
                        "-lineinfo",
                    ]
                    + cc_flag
                ),
            },
            # Absolute path: see the note next to this_dir about ninja builds.
            include_dirs=[Path(this_dir) / "csrc" / "selective_scan"],
        )
    )


def get_package_version():
    """
    Read ``__version__`` from ``<PACKAGE_NAME>/__init__.py`` next to this script.

    Returns the version as a string.  When the MAMBA_LOCAL_VERSION environment
    variable is set to a non-empty value, it is appended as ``+<value>``.
    """
    init_path = Path(this_dir) / PACKAGE_NAME / "__init__.py"
    with open(init_path, "r") as f:
        init_source = f.read()
    # The right-hand side of the assignment is a Python literal; evaluate it safely.
    assignment = re.search(r"^__version__\s*=\s*(.*)$", init_source, re.MULTILINE)
    base_version = ast.literal_eval(assignment.group(1))
    local_suffix = os.environ.get("MAMBA_LOCAL_VERSION")
    if not local_suffix:
        return str(base_version)
    return f"{base_version}+{local_suffix}"


def get_wheel_url():
    """
    Guess the download URL and filename of a prebuilt wheel for this environment.

    The filename encodes the package version, a CUDA tag, the torch major.minor
    version, the C++11 ABI flag, the running Python version and the platform tag.

    Returns:
        (wheel_url, wheel_filename)
    """
    # The CUDA tag comes from the CUDA version torch was built with, not from the
    # locally installed toolkit.
    built_with_cuda = parse(torch.version.cuda)
    torch_release = parse(torch.__version__)
    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.2
    # to save CI time. Minor versions should be compatible.
    wheel_cuda = parse("11.8") if built_with_cuda.major == 11 else parse("12.2")

    interpreter_tag = "cp{}{}".format(sys.version_info.major, sys.version_info.minor)
    platform_tag = get_platform()
    release = get_package_version()
    abi_flag = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()

    # Determine wheel URL based on CUDA version, torch version, python version and OS
    build_tag = (
        f"cu{wheel_cuda.major}{wheel_cuda.minor}"
        f"torch{torch_release.major}.{torch_release.minor}"
        f"cxx11abi{abi_flag}"
    )
    wheel_filename = f"{PACKAGE_NAME}-{release}+{build_tag}-{interpreter_tag}-{interpreter_tag}-{platform_tag}.whl"
    wheel_url = BASE_WHEEL_URL.format(
        tag_name=f"v{release}", wheel_name=wheel_filename
    )
    return wheel_url, wheel_filename


class CachedWheelsCommand(_bdist_wheel):
    """
    bdist_wheel replacement that first tries to download a matching prebuilt wheel.

    pip runs bdist_wheel when it cannot find an existing wheel (which is currently the
    case for all installs). Unless FORCE_BUILD is set, run() guesses the wheel URL with
    get_wheel_url(), downloads it and moves it to the path the regular command would have
    produced. If the download fails, the standard from-source build runs instead.
    """

    def run(self):
        """Download a prebuilt wheel if one is available, otherwise build from source."""
        if FORCE_BUILD:
            return super().run()

        wheel_url, wheel_filename = get_wheel_url()
        print("Guessing wheel URL: ", wheel_url)
        try:
            urllib.request.urlretrieve(wheel_url, wheel_filename)

            # Make the archive
            # Lifted from the root wheel processing command
            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
            if not os.path.exists(self.dist_dir):
                os.makedirs(self.dist_dir)

            impl_tag, abi_tag, plat_tag = self.get_tag()
            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
            print("Raw wheel path", wheel_path)
            shutil.move(wheel_filename, wheel_path)
        except urllib.error.URLError:
            # URLError is the base class of HTTPError, so this covers both a missing
            # release asset (HTTP error) and an unreachable host (no network, DNS failure).
            print("Precompiled wheel not found. Building from source...")
            # If the wheel could not be downloaded, build from source
            super().run()


# bdist_wheel always goes through the cached-wheel lookup; build_ext is only
# registered when there is an extension module to compile.
_command_classes = {"bdist_wheel": CachedWheelsCommand}
if ext_modules:
    _command_classes["build_ext"] = BuildExtension

# Directory names excluded from package discovery.
_non_package_dirs = (
    "build",
    "csrc",
    "include",
    "tests",
    "dist",
    "docs",
    "benchmarks",
    "mamba_ssm.egg-info",
)

_runtime_requirements = [
    "torch",
    "packaging",
    "ninja",
    "einops",
    "triton",
    "transformers",
    "causal_conv1d",
]

setup(
    name=PACKAGE_NAME,
    version=get_package_version(),
    packages=find_packages(exclude=_non_package_dirs),
    author="Tri Dao, Albert Gu",
    author_email="tri@tridao.me, agu@cs.cmu.edu",
    description="Mamba state-space model",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/state-spaces/mamba",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: BSD License",
        "Operating System :: Unix",
    ],
    ext_modules=ext_modules,
    cmdclass=_command_classes,
    python_requires=">=3.7",
    install_requires=_runtime_requirements,
)

Overwriting /home/main/projects/CourseWork/CourseWork2024/mamba/setup.py


In [12]:
# !MAMBA_FORCE_BUILD=TRUE pip install .


In [13]:
import torch

from mamba_ssm import Mamba

# Smoke test: push one random batch through a single Mamba block on the GPU and
# check that the output shape matches the input shape.
batch, length, dim = 2, 64, 16
x = torch.randn(batch, length, dim).to("cuda")

# This module uses roughly 3 * expand * d_model^2 parameters
mamba_kwargs = dict(
    d_model=dim,  # Model dimension d_model
    d_state=16,   # SSM state expansion factor
    d_conv=4,     # Local convolution width
    expand=2,     # Block expansion factor
)
model = Mamba(**mamba_kwargs).to("cuda")

y = model(x)
assert y.shape == x.shape
y

tensor([[[-1.1146e-02, -3.9352e-02,  2.3822e-02,  ...,  2.1485e-02,
           1.7347e-02,  6.7614e-02],
         [-2.3396e-02,  1.8208e-02,  2.2303e-02,  ..., -3.5535e-02,
          -3.1351e-02,  7.6012e-03],
         [ 8.4175e-02,  1.2027e-03, -4.2173e-02,  ..., -2.7933e-02,
          -1.4844e-02, -6.7537e-02],
         ...,
         [-1.2864e-02, -7.5387e-02,  3.3725e-02,  ..., -6.1465e-02,
           8.2823e-02,  4.4937e-02],
         [ 5.7107e-04,  3.3003e-03,  2.4038e-02,  ...,  3.3715e-02,
           2.3220e-02, -2.2669e-03],
         [ 6.1681e-02,  6.9858e-03, -3.1010e-02,  ..., -3.3422e-02,
          -4.3597e-03, -7.9107e-02]],

        [[ 1.8062e-02,  2.1797e-02, -1.3272e-02,  ...,  8.2128e-03,
           1.2229e-02, -7.1284e-02],
         [ 3.2764e-03,  2.0396e-02, -2.3066e-02,  ..., -3.4217e-03,
          -1.4118e-02, -3.7305e-02],
         [ 3.3982e-03,  3.2881e-02, -2.4321e-02,  ..., -3.0020e-02,
          -3.5167e-02, -4.2316e-02],
         ...,
         [-2.3551e-02,  9

In [14]:
# Copyright (c) 2023, Tri Dao, Albert Gu.

import torch
import torch.nn.functional as F
from torch.cuda.amp import custom_bwd, custom_fwd

from einops import rearrange, repeat

try:
    from causal_conv1d import causal_conv1d_fn
    import causal_conv1d_cuda
except ImportError:
    causal_conv1d_fn = None
    causal_conv1d_cuda = None

import selective_scan_cuda


class SelectiveScanFn(torch.autograd.Function):
    """Autograd function that delegates forward and backward to ``selective_scan_cuda``."""

    @staticmethod
    def forward(ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                return_last_state=False):
        """
        Call ``selective_scan_cuda.fwd`` and save what backward needs.

        Returns the kernel output (the z-gated output when ``z`` is given), or a
        ``(output, last_state)`` tuple when ``return_last_state`` is true.
        """
        # Tensors whose last dimension is not unit-stride are made contiguous
        # before they reach the kernel; D is always made contiguous.
        u = u if u.stride(-1) == 1 else u.contiguous()
        delta = delta if delta.stride(-1) == 1 else delta.contiguous()
        if D is not None:
            D = D.contiguous()
        B = B if B.stride(-1) == 1 else B.contiguous()
        C = C if C.stride(-1) == 1 else C.contiguous()
        if z is not None and z.stride(-1) != 1:
            z = z.contiguous()

        # 3-D B / C get a singleton axis inserted at position 1; the flag lets
        # backward remove that axis from the matching gradient.
        if B.dim() == 3:
            B = rearrange(B, "b dstate l -> b 1 dstate l")
            ctx.squeeze_B = True
        if C.dim() == 3:
            C = rearrange(C, "b dstate l -> b 1 dstate l")
            ctx.squeeze_C = True

        out, scan_x, *extras = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
        ctx.delta_softplus = delta_softplus
        ctx.has_z = z is not None
        last_state = scan_x[:, :, -1, 1::2]  # (batch, dim, dstate)

        if ctx.has_z:
            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, scan_x, out)
            result = extras[0]  # gated output returned by the kernel when z is given
        else:
            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, scan_x)
            result = out
        return (result, last_state) if return_last_state else result

    @staticmethod
    def backward(ctx, dout, *args):
        """
        Call ``selective_scan_cuda.bwd`` and return one gradient per forward argument.

        Gradients for arguments that were None in forward (D, z, delta_bias) and for
        the two boolean flags are returned as None.
        """
        if ctx.has_z:
            u, delta, A, B, C, D, z, delta_bias, scan_x, out = ctx.saved_tensors
        else:
            u, delta, A, B, C, D, delta_bias, scan_x = ctx.saved_tensors
            z = out = None
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        # Here we just pass in None and dz will be allocated in the C++ code.
        kernel_grads = selective_scan_cuda.bwd(
            u, delta, A, B, C, D, z, delta_bias, dout, scan_x, out, None, ctx.delta_softplus,
            False  # option to recompute out_z, not used here
        )
        du, ddelta, dA, dB, dC, dD, ddelta_bias, *extras = kernel_grads
        dz = extras[0] if ctx.has_z else None
        # Undo the singleton axis that forward added to 3-D B / C.
        if getattr(ctx, "squeeze_B", False):
            dB = dB.squeeze(1)
        if getattr(ctx, "squeeze_C", False):
            dC = dC.squeeze(1)
        if D is None:
            dD = None
        if delta_bias is None:
            ddelta_bias = None
        return (du, ddelta, dA, dB, dC, dD, dz, ddelta_bias, None, None)


def selective_scan_fn(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                     return_last_state=False):
    """Functional entry point for ``SelectiveScanFn``.

    if return_last_state is True, returns (out, last_state)
    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
    not considered in the backward pass.
    """
    # autograd's apply() is called positionally, in the order forward() declares.
    positional_args = (u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
    return SelectiveScanFn.apply(*positional_args)


def selective_scan_ref(u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False,
                      return_last_state=False):
    """
    Pure-PyTorch selective scan that steps through the sequence one position at a time.

    u: r(B D L)
    delta: r(B D L)
    A: c(D N) or r(D N)
    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
    D: r(D)
    z: r(B D L)
    delta_bias: r(D), fp32

    out: r(B D L)
    last_state (optional): r(B D dstate) or c(B D dstate)
    """
    input_dtype = u.dtype
    # The computation runs in float; the result is cast back to the input dtype at the end.
    u, delta = u.float(), delta.float()
    if delta_bias is not None:
        delta = delta + delta_bias[..., None].float()
    if delta_softplus:
        delta = F.softplus(delta)

    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
    is_variable_B = B.dim() >= 3
    is_variable_C = C.dim() >= 3
    if not A.is_complex():
        B, C = B.float(), C.float()
    else:
        # With complex A, variable B / C hold (real, imag) pairs interleaved along
        # the last axis; reinterpret them as complex tensors.
        if is_variable_B:
            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
        if is_variable_C:
            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))

    x = A.new_zeros((batch, dim, dstate))
    deltaA = torch.exp(torch.einsum('bdl,dn->bdln', delta, A))
    if not is_variable_B:
        deltaB_u = torch.einsum('bdl,dn,bdl->bdln', delta, B, u)
    elif B.dim() == 3:
        deltaB_u = torch.einsum('bdl,bnl,bdl->bdln', delta, B, u)
    else:
        # Grouped B: expand each group across the channels it covers.
        B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
        deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u)
    if is_variable_C and C.dim() == 4:
        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])

    ys = []
    seqlen = u.shape[2]
    for i in range(seqlen):
        # State update, then project the state to this step's output.
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
        if not is_variable_C:
            y = torch.einsum('bdn,dn->bd', x, C)
        elif C.dim() == 3:
            y = torch.einsum('bdn,bn->bd', x, C[:, :, i])
        else:
            y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i])
        ys.append(y.real * 2 if y.is_complex() else y)
    # The state after the final step (None when the sequence is empty).
    last_state = x if seqlen > 0 else None

    y = torch.stack(ys, dim=2) # (batch dim L)
    if D is None:
        out = y
    else:
        out = y + u * rearrange(D, "d -> d 1")
    if z is not None:
        out = out * F.silu(z)
    out = out.to(dtype=input_dtype)
    if return_last_state:
        return out, last_state
    return out


class MambaInnerFn(torch.autograd.Function):
    """Fused autograd function for the inner part of a Mamba block.

    ``forward`` chains the causal conv1d kernel, the x / delta projections, the
    selective-scan kernel and the output projection. ``backward`` calls the
    matching backward kernels and computes the projection gradients by hand.
    Both passes call the compiled ``causal_conv1d_cuda`` and
    ``selective_scan_cuda`` extensions directly.
    """

    @staticmethod
    @custom_fwd
    def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
                out_proj_weight, out_proj_bias,
                A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
                C_proj_bias=None, delta_softplus=True, checkpoint_lvl=1):
        """Run conv1d -> projections -> selective scan -> output projection.

        Args:
            xz: (batch, dim, seqlen); split in half along dim=1 into ``x`` and the gate ``z``.
            conv1d_weight: depthwise conv weight laid out as "d 1 w".
            conv1d_bias: optional conv bias.
            x_proj_weight: weight producing the concatenated (delta, B, C) features from the conv output.
            delta_proj_weight: weight mapping the first ``delta_rank`` features to delta;
                ``delta_rank`` is read from its second dimension.
            out_proj_weight, out_proj_bias: final linear projection.
            A: state matrix; a complex ``A`` doubles the number of B/C features read from ``x_dbl``.
            B, C: if ``None`` they are input-dependent and sliced out of ``x_dbl``;
                otherwise they are used as given.
            D, delta_bias, delta_softplus: forwarded to ``selective_scan_cuda.fwd``.
            B_proj_bias, C_proj_bias: optional biases added to the input-dependent B / C.
            checkpoint_lvl: 0 saves ``conv1d_out`` and ``delta`` for backward;
                1 drops them and recomputes them in ``backward``.

        Returns:
            ``F.linear`` applied to the gated scan output rearranged to (batch, seqlen, d).
        """
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        assert checkpoint_lvl in [0, 1]
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        if torch.is_autocast_enabled():
            # Cast the projection weights by hand: they are used in raw matmuls below.
            x_proj_weight = x_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            delta_proj_weight = delta_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_weight = out_proj_weight.to(dtype=torch.get_autocast_gpu_dtype())
            out_proj_bias = (out_proj_bias.to(dtype=torch.get_autocast_gpu_dtype())
                             if out_proj_bias is not None else None)
        if xz.stride(-1) != 1:
            xz = xz.contiguous()
        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
        x, z = xz.chunk(2, dim=1)
        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
        # NOTE(review): the trailing True presumably enables the SiLU activation
        # (mamba_inner_ref uses activation="silu") -- confirm against the extension signature.
        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
            x, conv1d_weight, conv1d_bias, None, None, None, True
        )
        # We're being very careful here about the layout, to avoid extra transposes.
        # We want delta to have d as the slowest moving dimension
        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
        x_dbl = F.linear(rearrange(conv1d_out, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
        # Remember which optional inputs were absent so backward knows which gradients to produce.
        ctx.is_variable_B = B is None
        ctx.is_variable_C = C is None
        ctx.B_proj_bias_is_None = B_proj_bias is None
        ctx.C_proj_bias_is_None = C_proj_bias is None
        if B is None:  # variable B
            B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl dstate)
            if B_proj_bias is not None:
                B = B + B_proj_bias.to(dtype=B.dtype)
            if not A.is_complex():
                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if B.stride(-1) != 1:
                B = B.contiguous()
        if C is None:  # variable C
            C = x_dbl[:, -d_state:]  # (bl dstate)
            if C_proj_bias is not None:
                C = C + C_proj_bias.to(dtype=C.dtype)
            if not A.is_complex():
                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
            else:
                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
        else:
            if C.stride(-1) != 1:
                C = C.contiguous()
        if D is not None:
            D = D.contiguous()
        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
        )
        ctx.delta_softplus = delta_softplus
        ctx.out_proj_bias_is_None = out_proj_bias is None
        ctx.checkpoint_lvl = checkpoint_lvl
        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
            conv1d_out, delta = None, None
        ctx.save_for_backward(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight,
                              delta_proj_weight, out_proj_weight, conv1d_out, delta,
                              A, B, C, D, delta_bias, scan_intermediates, out)
        return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)

    @staticmethod
    @custom_bwd
    def backward(ctx, dout):
        """Compute gradients for every input of ``forward``.

        Args:
            dout: gradient of the output, (batch, seqlen, dim).

        Returns:
            One entry per ``forward`` input, in the same order; non-differentiable
            inputs (``delta_softplus``, ``checkpoint_lvl``) and absent optional
            inputs get ``None``.
        """
        # dout: (batch, seqlen, dim)
        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
        (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
         conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
        L = xz.shape[-1]
        delta_rank = delta_proj_weight.shape[1]
        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
        x, z = xz.chunk(2, dim=1)
        if dout.stride(-1) != 1:
            dout = dout.contiguous()
        if ctx.checkpoint_lvl == 1:
            # conv1d_out and delta were dropped in forward; rebuild them from the saved inputs.
            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
                x, conv1d_weight, conv1d_bias, None, None, None, True
            )
            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                              "d (b l) -> b d l", l=L)
        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
        # backward of selective_scan_cuda with the backward of chunk).
        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
        dx, dz = dxz.chunk(2, dim=1)  # views into dxz, filled in by the kernels below
        dout = rearrange(dout, "b l e -> e (b l)")
        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
            conv1d_out, delta, A, B, C, D, z, delta_bias, dout_y, scan_intermediates, out, dz,
            ctx.delta_softplus,
            True  # option to recompute out_z
        )
        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
        # dout is now (e, b*l): reduce only over the flattened batch/sequence axis so the
        # bias gradient keeps shape (e,). Summing over both axes would yield a scalar.
        dout_proj_bias = dout.sum(dim=1) if not ctx.out_proj_bias_is_None else None
        dD = dD if D is not None else None
        dx_dbl = torch.empty_like(x_dbl)
        dB_proj_bias = None
        if ctx.is_variable_B:
            # B came from x_dbl, so its gradient flows back into the matching x_dbl columns.
            if not A.is_complex():
                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
            dx_dbl[:, delta_rank:delta_rank + d_state] = dB  # (bl d)
            dB = None
        dC_proj_bias = None
        if ctx.is_variable_C:
            # Same as for B: route the gradient into the last d_state columns of x_dbl.
            if not A.is_complex():
                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
            else:
                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
            dx_dbl[:, -d_state:] = dC  # (bl d)
            dC = None
        ddelta = rearrange(ddelta, "b d l -> d (b l)")
        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
        # Accumulate the x_proj contribution into the scan's gradient w.r.t. conv1d_out, in place.
        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
        # backward of conv1d with the backward of chunk).
        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
        )
        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
        # One gradient per forward input. The last two Nones are for delta_softplus and
        # checkpoint_lvl; autograd drops surplus trailing Nones when checkpoint_lvl was
        # left at its default and therefore never passed to apply().
        return (dxz, dconv1d_weight, dconv1d_bias, dx_proj_weight, ddelta_proj_weight,
                dout_proj_weight, dout_proj_bias,
                dA, dB, dC, dD,
                ddelta_bias if delta_bias is not None else None,
                dB_proj_bias, dC_proj_bias, None, None)


def mamba_inner_fn(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    """Functional entry point for ``MambaInnerFn``.

    Accepts keyword arguments and defaults, then forwards everything
    positionally to ``MambaInnerFn.apply``. ``checkpoint_lvl`` is not exposed
    here, so the autograd function uses its own default.
    """
    # Order must match MambaInnerFn.forward (minus ctx and checkpoint_lvl).
    fused_args = (
        xz,
        conv1d_weight,
        conv1d_bias,
        x_proj_weight,
        delta_proj_weight,
        out_proj_weight,
        out_proj_bias,
        A,
        B,
        C,
        D,
        delta_bias,
        B_proj_bias,
        C_proj_bias,
        delta_softplus,
    )
    return MambaInnerFn.apply(*fused_args)


def mamba_inner_ref(
    xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
    out_proj_weight, out_proj_bias,
    A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
    C_proj_bias=None, delta_softplus=True
):
    """Unfused counterpart of ``mamba_inner_fn`` built from ``causal_conv1d_fn`` and ``selective_scan_fn``.

    Takes the same arguments as ``mamba_inner_fn``. ``xz`` is split in half along
    dim=1 into ``x`` and the gate ``z``; when ``B`` / ``C`` are ``None`` they are
    sliced out of the x projection (input-dependent).

    Returns:
        ``F.linear`` applied to the scan output rearranged from "b d l" to "b l d".
    """
    assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
    L = xz.shape[-1]
    delta_rank = delta_proj_weight.shape[1]
    # A complex A needs twice as many real-valued B / C features.
    d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
    x, z = xz.chunk(2, dim=1)
    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
    # We're being very careful here about the layout, to avoid extra transposes.
    # We want delta to have d as the slowest moving dimension
    # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
    x_dbl = F.linear(rearrange(x, 'b d l -> (b l) d'), x_proj_weight)  # (bl d)
    delta = delta_proj_weight @ x_dbl[:, :delta_rank].t()
    delta = rearrange(delta, "d (b l) -> b d l", l=L)
    if B is None:  # variable B
        B = x_dbl[:, delta_rank:delta_rank + d_state]  # (bl d)
        if B_proj_bias is not None:
            B = B + B_proj_bias.to(dtype=B.dtype)
        if not A.is_complex():
            B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            B = rearrange(B, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    if C is None:  # variable C
        C = x_dbl[:, -d_state:]  # (bl d)
        if C_proj_bias is not None:
            C = C + C_proj_bias.to(dtype=C.dtype)
        if not A.is_complex():
            C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
        else:
            C = rearrange(C, "(b l) (dstate two) -> b dstate (l two)", l=L, two=2).contiguous()
    # Honour the caller's delta_softplus instead of forcing True, matching mamba_inner_fn.
    y = selective_scan_fn(
        x, delta, A, B, C, D, z=z, delta_bias=delta_bias, delta_softplus=delta_softplus
    )
    return F.linear(rearrange(y, "b d l -> b l d"), out_proj_weight, out_proj_bias)


In [15]:
# import torch
# import torch.nn as nn

# class Mamba(nn.Module):
#     def __init__(
#         self,
#         d_model,
#         d_state=16,
#         d_conv=4,
#         expand=2,
#         dt_rank="auto",
#         dt_min=0.001,
#         dt_max=0.1,
#         dt_init="random",
#         dt_scale=1.0,
#         dt_init_floor=1e-4,
#         conv_bias=True,
#         bias=False,
#         use_fast_path=True,  # Fused kernel options
#         layer_idx=None,      # Новый параметр
#         device=None,
#         dtype=None,
#     ):
#         super().__init__()
#         self.d_model = d_model
#         self.d_state = d_state
#         self.d_conv = d_conv
#         self.expand = expand
#         self.d_inner = int(self.expand * self.d_model)
#         self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
#         self.use_fast_path = use_fast_path
#         self.layer_idx = layer_idx  # Сохраняем layer_idx
        
#         factory_kwargs = {"device": device, "dtype": dtype}

#         # Проекции
#         self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

#         # Свёртка
#         self.conv1d = nn.Conv1d(
#             in_channels=self.d_inner,
#             out_channels=self.d_inner,
#             bias=conv_bias,
#             kernel_size=d_conv,
#             groups=self.d_inner,
#             padding=d_conv - 1,
#             **factory_kwargs,
#         )

#         # Активация
#         self.activation = "silu"
#         self.act = nn.SiLU()

#         # Проекции для состояний
#         self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs)
#         self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

#         # Инициализация проекций для dt
#         dt_init_std = self.dt_rank**-0.5 * dt_scale
#         if dt_init == "constant":
#             nn.init.constant_(self.dt_proj.weight, dt_init_std)
#         elif dt_init == "random":
#             nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
#         else:
#             raise NotImplementedError

#         # Инициализация bias для dt
#         dt = torch.exp(
#             torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
#             + math.log(dt_min)
#         ).clamp(min=dt_init_floor)
#         inv_dt = dt + torch.log(-torch.expm1(-dt))
#         with torch.no_grad():
#             self.dt_proj.bias.copy_(inv_dt)
#         self.dt_proj.bias._no_reinit = True

#         # Параметры S4D
#         A = torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device).repeat(self.d_inner, 1)
#         self.A_log = nn.Parameter(torch.log(A))
#         self.A_log._no_weight_decay = True

#         # Параметр D
#         self.D = nn.Parameter(torch.ones(self.d_inner, device=device))
#         self.D._no_weight_decay = True

#         # Финальная проекция
#         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

#     def forward(self, hidden_states, inference_params=None):
#         """
#         Основной forward для Mamba.

#         hidden_states: (B, L, D)
#         Returns: same shape as hidden_states
#         """
#         batch_size, seqlen, _ = hidden_states.size()

#         conv_state, ssm_state = None, None
#         if inference_params is not None:
#             conv_state, ssm_state = self._get_states_from_cache(inference_params, batch_size)
#             if inference_params.seqlen_offset > 0:
#                 # В режиме inference возвращаем результат по одному шагу
#                 out, _, _ = self.step(hidden_states, conv_state, ssm_state)
#                 return out

#         # Проекции входа
#         xz = self.in_proj(hidden_states)  # (B, L, 2 * d_inner)
#         x, z = torch.chunk(xz, 2, dim=-1)  # Разделяем на x и z

#         # Свёртка по временной оси
#         x = x.transpose(1, 2)  # (B, d_inner, L)
#         x = self.conv1d(x)  # Свёртка
#         x = self.act(x)  # Активация
#         x = x.transpose(1, 2)  # Возвращаем обратно (B, L, d_inner)

#         # Проекция x в dt, B и C
#         x_db = self.x_proj(x)  # (B, L, dt_rank + 2 * d_state)
#         dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
#         dt = self.dt_proj.weight @ dt.transpose(1, 2)
#         dt = dt.transpose(1, 2)

#         # Работа с состоянием A и его обновление
#         A = -torch.exp(self.A_log.float())
#         B = B.contiguous()
#         C = C.contiguous()

#         # Операция обновления состояния и выхода
#         print(B.shape, C.shape)
#         # y = torch.einsum('bld,bl->bd', B, C) + self.D * x
#         # y = y * self.act(z)
#         y = rearrange(y, "b d l -> b l d")

#         # Финальная проекция
#         out = self.out_proj(y)
#         return out

#     def step(self, hidden_states, conv_state, ssm_state):
#         """
#         Метод для пошагового обновления в режиме inference.
#         """
#         dtype = hidden_states.dtype
#         assert hidden_states.shape[1] == 1, "Только один шаг в режиме декодирования"
#         xz = self.in_proj(hidden_states.squeeze(1))
#         x, z = torch.chunk(xz, 2, dim=-1)

#         # Обновление состояния свёртки
#         conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))
#         conv_state[:, :, -1] = x
#         x = torch.sum(conv_state * self.conv1d.weight.view(-1, 1), dim=-1) + self.conv1d.bias
#         x = self.act(x).to(dtype=dtype)

#         # Обновление состояния SSM
#         x_db = self.x_proj(x)
#         dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
#         dt = F.linear(dt, self.dt_proj.weight)
#         A = -torch.exp(self.A_log.float())

#         dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
#         dA = torch.exp(dt.unsqueeze(-1) * A)
#         dB = dt.unsqueeze(-1) * B
#         ssm_state.copy_(ssm_state * dA + x.unsqueeze(-1) * dB)
#         print(B.shape, C.shape)
#         # y = torch.einsum('bld,bd->bl', B, C) + self.D * x
#         # y = y * self.act(z)
#         y = rearrange(y, "b l d -> b d l")
#         y = y + self.D.to(dtype) * x

#         out = self.out_proj(y)
#         return out.unsqueeze(1), conv_state, ssm_state

#     def _get_states_from_cache(self, inference_params, batch_size):
#         """
#         Получаем состояния для декодирования из кеша.
#         """
#         if self.layer_idx is not None and self.layer_idx in inference_params.key_value_memory_dict:
#             return inference_params.key_value_memory_dict[self.layer_idx]
#         else:
#             conv_state = torch.zeros(batch_size, self.d_inner, self.d_conv, device=self.conv1d.weight.device)
#             ssm_state = torch.zeros(batch_size, self.d_inner, self.d_state, device=self.dt_proj.weight.device)
#             if self.layer_idx is not None:
#                 inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
#             return conv_state, ssm_state
    
# # class TFMambaBlock(nn.Module):
# #     """
# #     Custom Temporal-Frequency блок для последовательной обработки данных.
# #     """
# #     def __init__(self, cfg):
# #         super(TFMambaBlock, self).__init__()
# #         self.hid_feature = cfg['model_cfg']['hid_feature']
        
# #         # Создаём несколько временных и частотных блоков
# #         self.time_block = Mamba(d_state=self.hid_feature, d_conv=cfg['model_cfg']['d_conv'], expand=cfg['model_cfg']['expand'], layer_idx=0)
# #         self.freq_block = Mamba(d_state=self.hid_feature, d_conv=cfg['model_cfg']['d_conv'], expand=cfg['model_cfg']['expand'], layer_idx=1)
        
# #         # ConvTranspose для восстановления размерности после блока
# #         self.tlinear = nn.ConvTranspose1d(self.hid_feature * cfg['model_cfg']['expand'], self.hid_feature, kernel_size=1)
# #         self.flinear = nn.ConvTranspose1d(self.hid_feature * cfg['model_cfg']['expand'], self.hid_feature, kernel_size=1)

# #     def forward(self, x, inference_params=None):
# #         b, c, t, f = x.size()

# #         # Обработка по временной оси
# #         x = x.permute(0, 3, 2, 1).contiguous().view(b * f, t, c)
# #         x = self.time_block(x)
# #         x = self.tlinear(x.permute(0, 2, 1)).permute(0, 2, 1)

# #         # Обработка по частотной оси
# #         x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b * t, f, c)
# #         x = self.freq_block(x)
# #         x = self.flinear(x.permute(0, 2, 1)).permute(0, 2, 1)

# #         # Восстанавливаем оригинальную форму тензора
# #         x = x.view(b, t, f, c).permute(0, 3, 1, 2)

# #         return x
    
# import torch
# import torch.nn as nn
# from typing import Optional

# class Block(nn.Module):
#     def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False):
#         """
#         Блок, который включает остаточные соединения, нормализацию (LayerNorm или RMSNorm) 
#         и вычисления в формате FP32 для остаточных соединений.

#         Args:
#         - dim: Размер входного тензора.
#         - mixer_cls: Класс или функция, которая будет применяться как основной миксер (например, CustomBlock).
#         - norm_cls: Нормализационный слой (по умолчанию LayerNorm).
#         - fused_add_norm: Оптимизация для ускоренного сложения и нормализации.
#         - residual_in_fp32: Если True, вычисления остатка производятся в формате FP32.
#         """
#         super().__init__()
#         self.residual_in_fp32 = residual_in_fp32
#         self.fused_add_norm = fused_add_norm
#         self.mixer = mixer_cls(dim)  # Основной миксер (например, CustomBlock)
#         self.norm = norm_cls(dim)  # Нормализация (LayerNorm или RMSNorm)

#         if self.fused_add_norm:
#             # Если используется fused_add_norm, проверяем, что используется LayerNorm или RMSNorm
#             assert isinstance(self.norm, nn.LayerNorm), "Только LayerNorm поддерживается для fused_add_norm"

#     def forward(self, hidden_states: torch.Tensor, residual: Optional[torch.Tensor] = None, inference_params=None):
#         """
#         Пропуск входных данных через слой с остаточными соединениями и нормализацией.

#         Args:
#         - hidden_states: Входные данные в слой.
#         - residual: Остаточные соединения. Если None, используется hidden_states как residual.
#         - inference_params: Параметры, которые могут быть переданы в миксер для работы в режиме inference.
        
#         Returns:
#         - hidden_states: Выход из слоя.
#         - residual: Обновленные остаточные соединения.
#         """
#         if not self.fused_add_norm:
#             # Обычная операция: складываем hidden_states и residual, нормализуем
#             residual = (hidden_states + residual) if residual is not None else hidden_states
#             hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))

#             # Если требуется, преобразуем residual в FP32
#             if self.residual_in_fp32:
#                 residual = residual.to(torch.float32)
#         else:
#             # Оптимизированное сложение и нормализация (fused add-norm)
#             hidden_states = self.norm(hidden_states + (residual if residual is not None else hidden_states))

#         # Пропускаем данные через миксер
#         hidden_states = self.mixer(hidden_states, inference_params=inference_params)
        
#         return hidden_states, residual

#     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
#         """
#         Выделение кэша для режима inference. Может быть полезно для трансформеров или других блоков,
#         работающих с последовательностями.
#         """
#         return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)

In [16]:
# Copyright (c) 2023, Tri Dao, Albert Gu.

import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from einops import rearrange, repeat

from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn

try:
    from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
except ImportError:
    causal_conv1d_fn, causal_conv1d_update = None, None

try:
    from mamba_ssm.ops.triton.selective_state_update import selective_state_update
except ImportError:
    selective_state_update = None

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None


class Mamba(nn.Module):
    """Mamba mixer layer: input projection, depthwise causal conv, selective scan, output projection.

    ``forward`` processes a whole sequence (using the fused ``mamba_inner_fn``
    when available and no inference cache is requested); ``step`` decodes one
    token at a time against cached convolution and SSM states.
    """

    def __init__(
        self,
        d_model,
        d_state=16,
        d_conv=4,
        expand=2,
        dt_rank="auto",
        dt_min=0.001,
        dt_max=0.1,
        dt_init="random",
        dt_scale=1.0,
        dt_init_floor=1e-4,
        conv_bias=True,
        bias=False,
        use_fast_path=True,  # Fused kernel options
        layer_idx=None,
        device=None,
        dtype=None,
    ):
        """Build the projections, the depthwise conv and the SSM parameters.

        Args:
            d_model: input/output feature size.
            d_state: SSM state size per inner channel.
            d_conv: kernel width of the depthwise convolution.
            expand: inner width multiplier; ``d_inner = int(expand * d_model)``.
            dt_rank: rank of the dt projection, or ``"auto"`` for ``ceil(d_model / 16)``.
            dt_min, dt_max: range that ``softplus(dt_proj.bias)`` is initialised into.
            dt_init: ``"constant"`` or ``"random"`` initialisation of ``dt_proj.weight``.
            dt_scale: scale applied to the dt weight initialisation bound.
            dt_init_floor: lower clamp for the sampled dt values.
            conv_bias: whether the convolution has a bias.
            bias: whether ``in_proj`` / ``out_proj`` have a bias.
            use_fast_path: allow the fused ``mamba_inner_fn`` path in ``forward``.
            layer_idx: key under which this layer stores its inference states.
            device, dtype: forwarded to the parameter constructors.

        Raises:
            NotImplementedError: if ``dt_init`` is neither ``"constant"`` nor ``"random"``.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.d_model = d_model
        self.d_state = d_state
        self.d_conv = d_conv
        self.expand = expand
        self.d_inner = int(self.expand * self.d_model)
        self.dt_rank = math.ceil(self.d_model / 16) if dt_rank == "auto" else dt_rank
        self.use_fast_path = use_fast_path
        self.layer_idx = layer_idx

        self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)

        self.conv1d = nn.Conv1d(
            in_channels=self.d_inner,
            out_channels=self.d_inner,
            bias=conv_bias,
            kernel_size=d_conv,
            groups=self.d_inner,
            padding=d_conv - 1,
            **factory_kwargs,
        )

        self.activation = "silu"
        self.act = nn.SiLU()

        self.x_proj = nn.Linear(
            self.d_inner, self.dt_rank + self.d_state * 2, bias=False, **factory_kwargs
        )
        self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True, **factory_kwargs)

        # Initialize special dt projection to preserve variance at initialization
        dt_init_std = self.dt_rank**-0.5 * dt_scale
        if dt_init == "constant":
            nn.init.constant_(self.dt_proj.weight, dt_init_std)
        elif dt_init == "random":
            nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
        else:
            raise NotImplementedError(
                f"Unsupported dt_init {dt_init!r}; expected 'constant' or 'random'"
            )

        # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
        dt = torch.exp(
            torch.rand(self.d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
            + math.log(dt_min)
        ).clamp(min=dt_init_floor)
        # Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
        inv_dt = dt + torch.log(-torch.expm1(-dt))
        with torch.no_grad():
            self.dt_proj.bias.copy_(inv_dt)
        # Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
        self.dt_proj.bias._no_reinit = True

        # S4D real initialization
        A = repeat(
            torch.arange(1, self.d_state + 1, dtype=torch.float32, device=device),
            "n -> d n",
            d=self.d_inner,
        ).contiguous()
        A_log = torch.log(A)  # Keep A_log in fp32
        self.A_log = nn.Parameter(A_log)
        self.A_log._no_weight_decay = True

        # D "skip" parameter
        self.D = nn.Parameter(torch.ones(self.d_inner, device=device))  # Keep in fp32
        self.D._no_weight_decay = True

        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)

    def forward(self, hidden_states, inference_params=None):
        """
        hidden_states: (B, L, D)
        inference_params: optional cache holder; when given, the layer's conv/SSM
            states are fetched (or created) and, if ``seqlen_offset > 0``, a single
            decoding ``step`` is run instead of the full-sequence path.
        Returns: same shape as hidden_states
        """
        batch, seqlen, _ = hidden_states.shape

        conv_state, ssm_state = None, None
        if inference_params is not None:
            conv_state, ssm_state = self._get_states_from_cache(inference_params, batch)
            if inference_params.seqlen_offset > 0:
                # The states are updated inplace
                out, _, _ = self.step(hidden_states, conv_state, ssm_state)
                return out

        # We do matmul and transpose BLH -> HBL at the same time
        xz = rearrange(
            self.in_proj.weight @ rearrange(hidden_states, "b l d -> d (b l)"),
            "d (b l) -> b d l",
            l=seqlen,
        )
        if self.in_proj.bias is not None:
            xz = xz + rearrange(self.in_proj.bias.to(dtype=xz.dtype), "d -> d 1")

        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
        # In the backward pass we write dx and dz next to each other to avoid torch.cat
        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:  # Doesn't support outputting the states
            out = mamba_inner_fn(
                xz,
                self.conv1d.weight,
                self.conv1d.bias,
                self.x_proj.weight,
                self.dt_proj.weight,
                self.out_proj.weight,
                self.out_proj.bias,
                A,
                None,  # input-dependent B
                None,  # input-dependent C
                self.D.float(),
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
            )
        else:
            x, z = xz.chunk(2, dim=1)
            # Compute short convolution
            if conv_state is not None:
                # If we just take x[:, :, -self.d_conv :], it will error if seqlen < self.d_conv
                # Instead F.pad will pad with zeros if seqlen < self.d_conv, and truncate otherwise.
                conv_state.copy_(F.pad(x, (self.d_conv - x.shape[-1], 0)))  # Update state (B D W)
            if causal_conv1d_fn is None:
                # nn.Conv1d pads both sides; keep only the first seqlen outputs to stay causal.
                x = self.act(self.conv1d(x)[..., :seqlen])
            else:
                assert self.activation in ["silu", "swish"]
                x = causal_conv1d_fn(
                    x=x,
                    weight=rearrange(self.conv1d.weight, "d 1 w -> d w"),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                )

            # We're careful here about the layout, to avoid extra transposes.
            # We want dt to have d as the slowest moving dimension
            # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
            x_dbl = self.x_proj(rearrange(x, "b d l -> (b l) d"))  # (bl d)
            dt, B, C = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=-1)
            dt = self.dt_proj.weight @ dt.t()
            dt = rearrange(dt, "d (b l) -> b d l", l=seqlen)
            B = rearrange(B, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            C = rearrange(C, "(b l) dstate -> b dstate l", l=seqlen).contiguous()
            assert self.activation in ["silu", "swish"]
            y = selective_scan_fn(
                x,
                dt,
                A,
                B,
                C,
                self.D.float(),
                z=z,
                delta_bias=self.dt_proj.bias.float(),
                delta_softplus=True,
                return_last_state=ssm_state is not None,
            )
            if ssm_state is not None:
                # Keep the final SSM state so decoding can continue from it.
                y, last_state = y
                ssm_state.copy_(last_state)
            y = rearrange(y, "b d l -> b l d")
            out = self.out_proj(y)
        return out

    def step(self, hidden_states, conv_state, ssm_state):
        """Decode a single token against the cached states.

        Args:
            hidden_states: (B, 1, D); exactly one token per sequence.
            conv_state: cached conv inputs, (B, d_inner, d_conv).
            ssm_state: cached SSM state, (B, d_inner, d_state).

        Returns:
            ``(out, conv_state, ssm_state)`` with ``out`` of shape (B, 1, D). The
            pure-PyTorch fallbacks update both states in place; the fused kernels
            receive the same state tensors.
        """
        dtype = hidden_states.dtype
        assert hidden_states.shape[1] == 1, "Only support decoding with 1 token at a time for now"
        xz = self.in_proj(hidden_states.squeeze(1))  # (B 2D)
        x, z = xz.chunk(2, dim=-1)  # (B D)

        # Conv step
        if causal_conv1d_update is None:
            conv_state.copy_(torch.roll(conv_state, shifts=-1, dims=-1))  # Update state (B D W)
            conv_state[:, :, -1] = x
            x = torch.sum(conv_state * rearrange(self.conv1d.weight, "d 1 w -> d w"), dim=-1)  # (B D)
            if self.conv1d.bias is not None:
                x = x + self.conv1d.bias
            x = self.act(x).to(dtype=dtype)
        else:
            x = causal_conv1d_update(
                x,
                conv_state,
                rearrange(self.conv1d.weight, "d 1 w -> d w"),
                self.conv1d.bias,
                self.activation,
            )

        x_db = self.x_proj(x)  # (B dt_rank+2*d_state)
        dt, B, C = torch.split(x_db, [self.dt_rank, self.d_state, self.d_state], dim=-1)
        # Don't add dt_bias here
        dt = F.linear(dt, self.dt_proj.weight)  # (B d_inner)
        A = -torch.exp(self.A_log.float())  # (d_inner, d_state)

        # SSM step
        if selective_state_update is None:
            # Discretize A and B
            dt = F.softplus(dt + self.dt_proj.bias.to(dtype=dt.dtype))
            dA = torch.exp(torch.einsum("bd,dn->bdn", dt, A))
            dB = torch.einsum("bd,bn->bdn", dt, B)
            ssm_state.copy_(ssm_state * dA + rearrange(x, "b d -> b d 1") * dB)
            y = torch.einsum("bdn,bn->bd", ssm_state.to(dtype), C)
            y = y + self.D.to(dtype) * x
            y = y * self.act(z)  # (B D)
        else:
            y = selective_state_update(
                ssm_state, x, dt, A, B, C, self.D, z=z, dt_bias=self.dt_proj.bias, dt_softplus=True
            )

        out = self.out_proj(y)
        return out.unsqueeze(1), conv_state, ssm_state

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Allocate zeroed decoding states for ``batch_size`` sequences.

        ``max_seqlen`` and ``**kwargs`` are accepted for interface compatibility
        and are not used here. ``dtype`` overrides the default state dtypes
        (``conv1d.weight`` / ``dt_proj.weight`` dtypes).

        Returns:
            ``(conv_state, ssm_state)`` of shapes (batch_size, d_inner, d_conv)
            and (batch_size, d_inner, d_state), on ``out_proj.weight``'s device.
        """
        device = self.out_proj.weight.device
        conv_dtype = self.conv1d.weight.dtype if dtype is None else dtype
        # Use d_inner (an int, and the width the layers were built with) rather than
        # d_model * expand, which is a float whenever expand is fractional.
        conv_state = torch.zeros(
            batch_size, self.d_inner, self.d_conv, device=device, dtype=conv_dtype
        )
        ssm_dtype = self.dt_proj.weight.dtype if dtype is None else dtype
        # ssm_dtype = torch.float32
        ssm_state = torch.zeros(
            batch_size, self.d_inner, self.d_state, device=device, dtype=ssm_dtype
        )
        return conv_state, ssm_state

    def _get_states_from_cache(self, inference_params, batch_size, initialize_states=False):
        """Fetch this layer's ``(conv_state, ssm_state)`` from the cache, creating them if missing.

        States are stored in ``inference_params.key_value_memory_dict`` under
        ``self.layer_idx``. Existing states are zeroed in place when
        ``initialize_states`` is true.
        """
        assert self.layer_idx is not None
        if self.layer_idx not in inference_params.key_value_memory_dict:
            # Sized with d_inner so the cache always matches the layer width (see
            # allocate_inference_cache).
            conv_state = torch.zeros(
                batch_size,
                self.d_inner,
                self.d_conv,
                device=self.conv1d.weight.device,
                dtype=self.conv1d.weight.dtype,
            )
            ssm_state = torch.zeros(
                batch_size,
                self.d_inner,
                self.d_state,
                device=self.dt_proj.weight.device,
                dtype=self.dt_proj.weight.dtype,
                # dtype=torch.float32,
            )
            inference_params.key_value_memory_dict[self.layer_idx] = (conv_state, ssm_state)
        else:
            conv_state, ssm_state = inference_params.key_value_memory_dict[self.layer_idx]
            # TODO: What if batch size changes between generation, and we reuse the same states?
            if initialize_states:
                conv_state.zero_()
                ssm_state.zero_()
        return conv_state, ssm_state


class Block(nn.Module):
    """Residual wrapper that applies Add -> Norm -> Mixer and returns both streams."""

    def __init__(
        self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        Args:
            dim: Value passed to both ``mixer_cls`` and ``norm_cls``.
            mixer_cls: Callable invoked as ``mixer_cls(dim)`` to build the mixer.
            norm_cls: Callable invoked as ``norm_cls(dim)`` to build the norm.
            fused_add_norm: Use the fused add+norm function in ``forward``.
            residual_in_fp32: Keep the returned residual in ``torch.float32``.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)

        if not self.fused_add_norm:
            return

        # The fused path only knows how to handle these two norm types.
        assert RMSNorm is not None, "RMSNorm import fails"
        norm_is_supported = isinstance(self.norm, (nn.LayerNorm, RMSNorm))
        assert norm_is_supported, "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
            inference_params: forwarded unchanged to the mixer.

        Returns:
            ``(hidden_states, residual)``: the mixer output and the updated
            residual stream.
        """
        if self.fused_add_norm:
            if isinstance(self.norm, RMSNorm):
                add_norm = rms_norm_fn
            else:
                add_norm = layer_norm_fn
            hidden_states, residual = add_norm(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
            )
        else:
            # Unfused path: accumulate the residual, then normalise it.
            if residual is None:
                residual = hidden_states
            else:
                residual = hidden_states + residual
            norm_dtype = self.norm.weight.dtype
            hidden_states = self.norm(residual.to(dtype=norm_dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)

        mixed = self.mixer(hidden_states, inference_params=inference_params)
        return mixed, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        """Delegate inference-cache allocation to the wrapped mixer."""
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)


In [44]:
import torch
import torch.nn as nn
from functools import partial

# def create_block(d_model, cfg, layer_idx=0):
#     """
#     Создание Mamba блока.
#     """
#     d_state = cfg['model_cfg']['d_state']
#     d_conv = cfg['model_cfg']['d_conv']
#     expand = cfg['model_cfg']['expand']

#     mixer_cls = partial(Mamba, layer_idx=layer_idx, d_state=d_state, d_conv=d_conv, expand=expand)
#     block = Block(
#             d_model,
#             mixer_cls=mixer_cls
#         )
#     return block

def create_block(d_model, cfg, layer_idx=0):
    """Build one ``Block`` whose mixer is a ``Mamba`` layer of width ``d_model``.

    Args:
        d_model: Width passed to both ``Mamba`` and ``Block``.
        cfg: Mapping with a ``'model_cfg'`` entry; its optional keys
            ``'d_state'``, ``'d_conv'`` and ``'expand'`` default to 16, 4 and 4.
        layer_idx: Forwarded to ``Mamba`` as ``layer_idx``.

    Returns:
        The constructed ``Block`` (LayerNorm, no fused add+norm, residual not
        forced to fp32).
    """
    model_cfg = cfg['model_cfg']
    mamba_kwargs = {
        'd_state': model_cfg.get('d_state', 16),
        'd_conv': model_cfg.get('d_conv', 4),
        'expand': model_cfg.get('expand', 4),
    }

    def build_mixer(dim):
        # Block calls mixer_cls(dim); the argument is ignored and the
        # closed-over d_model is passed to Mamba instead.
        return Mamba(d_model=d_model, layer_idx=layer_idx, **mamba_kwargs)

    return Block(
        dim=d_model,
        mixer_cls=build_mixer,
        norm_cls=nn.LayerNorm,   # default normalisation
        fused_add_norm=False,    # can be enabled for speed
        residual_in_fp32=False,  # can be enabled to keep residuals in FP32
    )


class MambaBlock(nn.Module):
    """
    Bidirectional wrapper: one stack of blocks reads the sequence as given, a
    second stack reads it reversed along dim 1, and the two results are
    concatenated on the last dimension.
    """
    def __init__(self, in_channels, cfg):
        """Build the forward and backward stacks (one block each) via ``create_block``."""
        super().__init__()
        n_layer = 1
        layer_ids = range(n_layer)
        self.forward_blocks = nn.ModuleList(
            [create_block(in_channels, cfg, layer_id) for layer_id in layer_ids]
        )
        self.backward_blocks = nn.ModuleList(
            [create_block(in_channels, cfg, layer_id) for layer_id in layer_ids]
        )

    def forward(self, x):
        """Run both directions over ``x`` and return their concatenation.

        The backward result is flipped back along dim 1 before concatenation,
        so both halves are aligned position by position. The last dimension of
        the output is the two per-direction outputs joined together.
        """
        fwd_hidden = x.clone()
        bwd_hidden = torch.flip(x, [1])
        fwd_residual = None
        bwd_residual = None

        # Forward direction
        for block in self.forward_blocks:
            fwd_hidden, fwd_residual = block(fwd_hidden, fwd_residual)
        if fwd_residual is None:
            fwd_out = fwd_hidden
        else:
            fwd_out = fwd_hidden + fwd_residual

        # Backward direction (processed on the flipped sequence)
        for block in self.backward_blocks:
            bwd_hidden, bwd_residual = block(bwd_hidden, bwd_residual)
        if bwd_residual is None:
            bwd_out = torch.flip(bwd_hidden, [1])
        else:
            bwd_out = torch.flip(bwd_hidden + bwd_residual, [1])

        return torch.cat([fwd_out, bwd_out], -1)

class TFMambaBlock(nn.Module):
    """
    Time-frequency block: a bidirectional Mamba pass along the time axis
    followed by one along the frequency axis, each with a skip connection.
    """
    def __init__(self, cfg):
        """Read ``cfg['model_cfg']['hid_feature']`` and build both axis branches."""
        super().__init__()
        self.hid_feature = cfg['model_cfg']['hid_feature']
        width = self.hid_feature

        # One bidirectional Mamba block per axis.
        self.time_mamba = MambaBlock(in_channels=width, cfg=cfg)
        self.freq_mamba = MambaBlock(in_channels=width, cfg=cfg)

        # Kernel-size-1 transposed convolutions that map the concatenated
        # (2 * width) bidirectional features back to width channels.
        self.tlinear = nn.ConvTranspose1d(width * 2, width, 1, stride=1)
        self.flinear = nn.ConvTranspose1d(width * 2, width, 1, stride=1)

    def forward(self, x):
        """
        Apply the time branch and then the frequency branch.

        ``x`` is unpacked as ``(b, c, t, f)`` and the result has the same
        dimension order.
        """
        b, c, t, f = x.shape

        # Time axis: fold frequency into the batch -> (b*f, t, c).
        time_seq = x.permute(0, 3, 2, 1).contiguous().view(b * f, t, c)
        time_feat = self.time_mamba(time_seq).permute(0, 2, 1)
        time_out = self.tlinear(time_feat).permute(0, 2, 1) + time_seq

        # Frequency axis: fold time into the batch -> (b*t, f, c).
        freq_seq = time_out.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b * t, f, c)
        freq_feat = self.freq_mamba(freq_seq).permute(0, 2, 1)
        freq_out = self.flinear(freq_feat).permute(0, 2, 1) + freq_seq

        # Restore the (b, c, t, f) layout.
        return freq_out.view(b, t, f, c).permute(0, 3, 1, 2)


In [63]:
import torch
import torch.nn as nn
from einops import rearrange

# Helper functions
def get_padding(kernel_size, dilation=1):
    """
    Return the padding for a dilated 1D convolution kernel.

    Computed as ``int((kernel_size * dilation - dilation) / 2)``.
    """
    dilated_extent = kernel_size * dilation - dilation
    return int(dilated_extent / 2)

def get_padding_2d(kernel_size, dilation=(1, 1)):
    """
    Return the per-axis padding for a dilated 2D convolution kernel.

    Each axis uses ``int((kernel * dilation - dilation) / 2)``; only the first
    two entries of ``kernel_size`` and ``dilation`` are read.
    """
    def axis_padding(axis):
        extent = kernel_size[axis] * dilation[axis] - dilation[axis]
        return int(extent / 2)

    return (axis_padding(0), axis_padding(1))

# DenseBlock definition
class DenseBlock(nn.Module):
    """
    Densely connected stack of dilated 2D convolutions.

    Layer ``i`` reads the concatenation of the block input and all previous
    layer outputs, and uses dilation ``2 ** i`` on the first spatial axis.
    """
    def __init__(self, cfg, kernel_size=(3, 3), depth=4):
        """Build ``depth`` Conv2d -> InstanceNorm2d -> PReLU layers of width ``hid_feature``."""
        super().__init__()
        self.cfg = cfg
        self.depth = depth
        self.hid_feature = cfg['model_cfg']['hid_feature']
        width = self.hid_feature

        layers = []
        for level in range(depth):
            dilation = (2 ** level, 1)
            # Input channels grow by `width` for every earlier layer's output.
            conv = nn.Conv2d(
                width * (level + 1),
                width,
                kernel_size,
                dilation=dilation,
                padding=get_padding_2d(kernel_size, dilation),
            )
            layers.append(
                nn.Sequential(conv, nn.InstanceNorm2d(width, affine=True), nn.PReLU(width))
            )
        self.dense_block = nn.ModuleList(layers)

    def forward(self, x):
        """
        Run the dense stack and return the output of the last layer only.
        """
        stacked = x
        latest = x
        for level in range(self.depth):
            latest = self.dense_block[level](stacked)
            # Newest features are placed in front of everything seen so far.
            stacked = torch.cat((latest, stacked), dim=1)
        return latest

# DenseEncoder definition
class DenseEncoder(nn.Module):
    """
    Encoder: 1x1 convolution, a ``DenseBlock``, then a (1, 3) convolution with
    stride (1, 2) that downsamples the last axis.
    """
    def __init__(self, cfg):
        """Read ``input_channel`` and ``hid_feature`` from ``cfg['model_cfg']`` and build the layers."""
        super().__init__()
        self.cfg = cfg
        self.input_channel = cfg['model_cfg']['input_channel']
        self.hid_feature = cfg['model_cfg']['hid_feature']
        width = self.hid_feature

        # Lift the input channels to the hidden width.
        self.dense_conv_1 = nn.Sequential(
            nn.Conv2d(self.input_channel, width, (1, 1)),
            nn.InstanceNorm2d(width, affine=True),
            nn.PReLU(width),
        )

        self.dense_block = DenseBlock(cfg, depth=4)

        # Strided convolution over the last axis.
        self.dense_conv_2 = nn.Sequential(
            nn.Conv2d(width, width, (1, 3), stride=(1, 2)),
            nn.InstanceNorm2d(width, affine=True),
            nn.PReLU(width),
        )

    def forward(self, x):
        """
        Apply the three stages in order and return the encoded features.
        """
        lifted = self.dense_conv_1(x)
        dense = self.dense_block(lifted)
        return self.dense_conv_2(dense)

# MagDecoder definition
class MagDecoder(nn.Module):
    """
    Magnitude decoder: DenseBlock, a transposed-convolution head, then a
    ``LearnableSigmoid2D`` activation.
    """
    def __init__(self, cfg):
        """Read sizes from ``cfg['model_cfg']`` / ``cfg['stft_cfg']`` and build the layers."""
        super().__init__()
        self.dense_block = DenseBlock(cfg, depth=4)

        model_cfg = cfg['model_cfg']
        self.hid_feature = model_cfg['hid_feature']
        self.output_channel = model_cfg['output_channel']
        self.n_fft = cfg['stft_cfg']['n_fft']
        self.beta = model_cfg['beta']

        width, out_ch = self.hid_feature, self.output_channel
        self.mask_conv = nn.Sequential(
            # (1, 3) transposed conv with stride (1, 2) upsamples the last axis.
            nn.ConvTranspose2d(width, width, (1, 3), stride=(1, 2)),
            nn.Conv2d(width, out_ch, (1, 1)),
            nn.InstanceNorm2d(out_ch, affine=True),
            nn.PReLU(out_ch),
            nn.Conv2d(out_ch, out_ch, (1, 1)),
        )
        self.lsigmoid = LearnableSigmoid2D(self.n_fft // 2 + 1, beta=self.beta)

    def forward(self, x):
        """
        Decode features into the activated magnitude output.

        The trailing ``squeeze(-1)`` only removes the channel axis when it has
        size 1; the final rearrange then expects a 3D tensor.
        """
        mask = self.mask_conv(self.dense_block(x))
        # Move channels last so the size-1 channel axis can be squeezed away.
        mask = rearrange(mask, 'b c t f -> b f t c').squeeze(-1)
        mask = self.lsigmoid(mask)
        # Back to (b, 1, t, f).
        return rearrange(mask, 'b f t -> b t f').unsqueeze(1)

# PhaseDecoder definition
class PhaseDecoder(nn.Module):
    """
    Phase decoder: DenseBlock and a transposed-convolution stage, followed by
    two 1x1 heads whose outputs are combined with ``atan2``.
    """
    def __init__(self, cfg):
        """Read ``hid_feature`` and ``output_channel`` from ``cfg['model_cfg']`` and build the layers."""
        super().__init__()
        self.dense_block = DenseBlock(cfg, depth=4)
        self.hid_feature = cfg['model_cfg']['hid_feature']
        self.output_channel = cfg['model_cfg']['output_channel']
        width, out_ch = self.hid_feature, self.output_channel

        self.phase_conv = nn.Sequential(
            # (1, 3) transposed conv with stride (1, 2) upsamples the last axis.
            nn.ConvTranspose2d(width, width, (1, 3), stride=(1, 2)),
            nn.InstanceNorm2d(width, affine=True),
            nn.PReLU(width),
        )

        # Separate 1x1 heads for the two atan2 arguments.
        self.phase_conv_r = nn.Conv2d(width, out_ch, (1, 1))
        self.phase_conv_i = nn.Conv2d(width, out_ch, (1, 1))

    def forward(self, x):
        """
        Return ``atan2(phase_conv_i(h), phase_conv_r(h))`` for the decoded features ``h``.
        """
        features = self.phase_conv(self.dense_block(x))
        real_part = self.phase_conv_r(features)
        imag_part = self.phase_conv_i(features)
        return torch.atan2(imag_part, real_part)

# SEMamba model definition
class SEMamba(nn.Module):
    """
    SEMamba model for speech enhancement using Mamba blocks.
    This model uses a dense encoder, multiple Mamba blocks, and separate magnitude
    and phase decoders to process noisy magnitude and phase inputs.
    """
    def __init__(self, cfg):
        """Build the encoder, the stack of TFMambaBlocks, and both decoders.

        Args:
            cfg: Configuration mapping. ``cfg['model_cfg']['num_tfmamba']``
                sets the number of TFMambaBlocks; when the key is missing or
                its value is ``None`` the default of 4 is used.
        """
        super(SEMamba, self).__init__()
        self.cfg = cfg
        # Use .get so that an omitted key falls back to the default instead of
        # raising KeyError (same lookup style as create_block's optional keys).
        num_tfmamba = cfg['model_cfg'].get('num_tfmamba')
        self.num_tscblocks = num_tfmamba if num_tfmamba is not None else 4  # default tfmamba: 4

        # Initialize dense encoder
        self.dense_encoder = DenseEncoder(cfg)

        # Initialize Mamba blocks
        self.TSMamba = nn.ModuleList([TFMambaBlock(cfg) for _ in range(self.num_tscblocks)])

        # Initialize decoders
        self.mask_decoder = MagDecoder(cfg)
        self.phase_decoder = PhaseDecoder(cfg)

    def forward(self, noisy_mag, noisy_pha):
        """
        Forward pass for the SEMamba model.

        Args:
            noisy_mag: Noisy magnitude; after ``squeeze(1)`` it must be 3D and
                is read as ``(b, f, t)``.
            noisy_pha: Noisy phase, same layout as ``noisy_mag``.

        Returns:
            ``(denoised_mag, denoised_pha, denoised_com)``. The first two are
            laid out ``(b, f, t)`` when the decoders emit a single channel;
            ``denoised_com`` stacks ``mag * cos(pha)`` and ``mag * sin(pha)``
            on a new last axis.
        """
        # Reshape inputs
        noisy_mag = rearrange(noisy_mag.squeeze(1), 'b f t -> b t f').unsqueeze(1)  # [B, 1, T, F]
        noisy_pha = rearrange(noisy_pha.squeeze(1), 'b f t -> b t f').unsqueeze(1)  # [B, 1, T, F]

        # Concatenate magnitude and phase inputs
        x = torch.cat((noisy_mag, noisy_pha), dim=1)  # [B, 2, T, F]

        # Encode input
        x = self.dense_encoder(x)

        # Apply Mamba blocks
        for block in self.TSMamba:
            x = block(x)

        # Decode magnitude and phase; the magnitude decoder output acts as a
        # multiplicative mask on the noisy magnitude.
        denoised_mag = rearrange(self.mask_decoder(x) * noisy_mag, 'b c t f -> b f t c').squeeze(-1)
        denoised_pha = rearrange(self.phase_decoder(x), 'b c t f -> b f t c').squeeze(-1)

        # Combine denoised magnitude and phase into a complex representation
        denoised_com = torch.stack(
            (denoised_mag * torch.cos(denoised_pha), denoised_mag * torch.sin(denoised_pha)),
            dim=-1
        )

        return denoised_mag, denoised_pha, denoised_com

# LearnableSigmoid2D definition (add this class here if it is used in the code)
class LearnableSigmoid2D(nn.Module):
    """
    Sigmoid with a single learnable scalar applied to its input:
    ``sigmoid(beta * x)``.
    """
    def __init__(self, dim, beta=1.0):
        """
        Args:
            dim: Accepted for interface compatibility; not used by this
                implementation.
            beta: Initial value of the learnable scalar. Integers are accepted
                and converted to float.
        """
        super(LearnableSigmoid2D, self).__init__()
        # torch.tensor(1) would be int64, and nn.Parameter rejects integer
        # tensors because they cannot require gradients; coerce to float first.
        self.beta = nn.Parameter(torch.tensor(float(beta)))

    def forward(self, x):
        """Return ``sigmoid(beta * x)`` element-wise."""
        return torch.sigmoid(self.beta * x)


In [125]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
from torch.utils.data import DataLoader, Dataset
from torchaudio.transforms import Spectrogram, InverseSpectrogram
from tqdm.auto import tqdm

class SpeechDataset(Dataset):
    """Paired noisy/clean ``.wav`` dataset returning fixed-length segments.

    Files are paired by position after sorting each directory listing, so the
    two directories are expected to contain matching file names.
    """

    def __init__(self, noised_dir, clean_dir, sample_rate=16000, segment_length=16000):
        """
        Args:
            noised_dir (str): Directory with the noisy audio files.
            clean_dir (str): Directory with the clean audio files.
            sample_rate (int): Target sample rate in Hz (default 16000). Files
                stored at another rate are resampled to it on load.
            segment_length (int): Length of each returned segment, in samples.

        Raises:
            AssertionError: If the two directories hold a different number of
                ``.wav`` files.
        """
        self.noised_files = self._list_wavs(noised_dir)
        self.clean_files = self._list_wavs(clean_dir)
        self.sample_rate = sample_rate
        self.segment_length = segment_length
        assert len(self.noised_files) == len(self.clean_files), "Количество файлов в папках должно совпадать"

    @staticmethod
    def _list_wavs(directory):
        """Return the sorted full paths of all ``.wav`` files in ``directory``."""
        return sorted(
            os.path.join(directory, name) for name in os.listdir(directory) if name.endswith('.wav')
        )

    def _load_segment(self, path):
        """Load one file and return it as exactly ``segment_length`` samples per channel."""
        waveform, file_rate = torchaudio.load(path, normalize=True)

        # Honour the configured sample rate instead of silently ignoring it.
        if file_rate != self.sample_rate:
            waveform = torchaudio.functional.resample(waveform, file_rate, self.sample_rate)

        # Truncate, then zero-pad on the right if the clip is too short.
        waveform = waveform[:, :self.segment_length]
        missing = self.segment_length - waveform.size(1)
        if missing > 0:
            waveform = torch.nn.functional.pad(waveform, (0, missing))
        return waveform

    def __len__(self):
        """Number of noisy/clean pairs."""
        return len(self.noised_files)

    def __getitem__(self, idx):
        """Return ``(noised_waveform, clean_waveform)`` for pair ``idx``.

        Each clip is truncated/padded independently, so both tensors always
        have ``segment_length`` samples even when the two files differ in
        length (the previous code derived the pad size from the noisy clip
        only, which produced mismatched lengths in that case).
        """
        noised_waveform = self._load_segment(self.noised_files[idx])
        clean_waveform = self._load_segment(self.clean_files[idx])
        return noised_waveform, clean_waveform

def train_model(model, train_loader, num_epochs, learning_rate, device,
                n_fft=400, hop_length=160, win_length=400):
    """
    Train the model on paired noisy/clean waveforms.

    Each batch is converted to a complex spectrogram, split into magnitude and
    phase, and the model is optimised with Adam on the sum of two MSE losses
    (magnitude and phase).

    Args:
    - model: Model to train; called as ``model(noised_mag, noised_phase)`` and
      expected to return a 3-tuple whose first two items are the predicted
      magnitude and phase.
    - train_loader: DataLoader yielding ``(noised_waveform, clean_waveform)``.
    - num_epochs (int): Number of training epochs.
    - learning_rate (float): Adam learning rate.
    - device: Device on which training runs (cuda/mps/cpu).
    - n_fft, hop_length, win_length (int): STFT parameters; the defaults match
      the values that were previously hard-coded.
    """
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    loss_fn = nn.MSELoss()

    # power=None keeps the complex spectrogram so magnitude and phase can both be extracted.
    stft_transform = Spectrogram(
        n_fft=n_fft, hop_length=hop_length, win_length=win_length, power=None
    ).to(device)

    model.train()

    for epoch in tqdm(range(num_epochs)):
        running_loss = 0.0
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for step, (noised_waveform, clean_waveform) in enumerate(progress_bar, start=1):
            # Move the batch to the training device
            noised_waveform = noised_waveform.to(device)
            clean_waveform = clean_waveform.to(device)

            # Waveform -> complex spectrogram -> magnitude and phase
            noised_spectrogram = stft_transform(noised_waveform)
            clean_spectrogram = stft_transform(clean_waveform)
            noised_mag, noised_phase = noised_spectrogram.abs(), noised_spectrogram.angle()
            clean_mag, clean_phase = clean_spectrogram.abs(), clean_spectrogram.angle()

            # Forward pass
            optimizer.zero_grad()
            denoised_mag, denoised_phase, _ = model(noised_mag, noised_phase)

            # Loss: unsqueeze(1) restores the channel axis so predictions line up with the targets
            loss = loss_fn(denoised_mag.unsqueeze(1), clean_mag) + loss_fn(denoised_phase.unsqueeze(1), clean_phase)

            # Backpropagation and weight update
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Average over the batches processed so far; dividing by
            # len(train_loader) under-reported the loss until the last batch.
            progress_bar.set_postfix({"Loss": running_loss / step})

    print("Обучение завершено")

def main(noised_dir='/kaggle/input/speech-dataset/noised_speech', clean_dir='/kaggle/input/speech-dataset/clean_speech', num_epochs=10):
    """Build the dataset and a SEMamba model, train it, and return the trained model.

    Args:
        noised_dir: Directory with noisy ``.wav`` files.
        clean_dir: Directory with the matching clean ``.wav`` files.
        num_epochs: Number of training epochs.

    Returns:
        The trained ``SEMamba`` instance.

    Raises:
        FileNotFoundError: If either directory does not exist.
    """
    # Training hyper-parameters
    sample_rate = 16000
    segment_length = sample_rate * 1  # 1 second
    batch_size = 8
    learning_rate = 0.001

    # Use the GPU when available, otherwise fall back to the CPU.
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    print(f"Используемое устройство: {device}")

    # Fail early when the data is not where it is expected.
    if not (os.path.exists(noised_dir) and os.path.exists(clean_dir)):
        raise FileNotFoundError(f"Директории с данными не найдены. Убедитесь, что данные загружены в {noised_dir} и {clean_dir}.")

    # Dataset and loader
    dataset = SpeechDataset(noised_dir, clean_dir, sample_rate, segment_length)
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Model configuration
    model_cfg = {
        'hid_feature': 64,  # hidden feature width
        'input_channel': 2,
        'output_channel': 1,
        'beta': 1.0,
        'num_tfmamba': 4,  # number of TFMamba blocks
        'd_state': 16,
        'd_conv': 4,
        'expand': 4
    }
    stft_cfg = {
        'n_fft': 400,
        'hop_size': 160,
        'win_size': 400,
        'sampling_rate': sample_rate,
    }
    cfg = {'model_cfg': model_cfg, 'stft_cfg': stft_cfg}

    model = SEMamba(cfg).to(device)

    # Train and hand the model back to the caller.
    train_model(model, train_loader, num_epochs, learning_rate, device)
    return model


In [126]:
# Rebind the notebook-level name `model` to None.
model = None

In [127]:
# Bare expression: shows the current value of `model` as the cell output.
model

In [121]:
# Train on the data/noised_speech -> data/clean_speech pairs for 100 epochs and keep the returned model.
# NOTE(review): the output recorded for this cell shows a 0/10 epoch bar and a KeyboardInterrupt,
# so it appears to come from an earlier, interrupted run with a different num_epochs — confirm
# that SEMamba_model_1 was actually trained before using it below.
SEMamba_model_1 = main(noised_dir='data/noised_speech', clean_dir='data/clean_speech', num_epochs=100)


Используемое устройство: cuda


  0%|          | 0/10 [00:00<?, ?it/s]

Epoch 1/10:   0%|          | 0/336 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Train a second model on the data/back_noised_speech -> data/clean_speech pairs for 100 epochs.
SEMamba_model_2 = main(noised_dir='data/back_noised_speech', clean_dir='data/clean_speech', num_epochs=100)


Используемое устройство: cuda


  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 2/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 3/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 4/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 5/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 6/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 7/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 8/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 9/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 10/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 11/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 12/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 13/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 14/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 15/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 16/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 17/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 18/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 19/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 20/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 21/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 22/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 23/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 24/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 25/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 26/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 27/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 28/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 29/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 30/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 31/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 32/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 33/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 34/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 35/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 36/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 37/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 38/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 39/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 40/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 41/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 42/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 43/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 44/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 45/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 46/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 47/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 48/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 49/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 50/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 51/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 52/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 53/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 54/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 55/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 56/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 57/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 58/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 59/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 60/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 61/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 62/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 63/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 64/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 65/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 66/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 67/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 68/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 69/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 70/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 71/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 72/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 73/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 74/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 75/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 76/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 77/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 78/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 79/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 80/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 81/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 82/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 83/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 84/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 85/100:   0%|          | 0/336 [00:00<?, ?it/s]

Epoch 86/100:   0%|          | 0/336 [00:00<?, ?it/s]

In [110]:
import torch
import librosa
import numpy as np
import soundfile as sf
import IPython

def denoise_audio(audio_path, model, output_path='denoised_output.wav',
                  n_fft=400, hop_length=160, win_length=400):
    """
    Denoise one audio file with a SEMamba model, play it, and save the result.

    The file is loaded at its native sample rate, converted to magnitude and
    phase with an STFT, passed through the model, and reconstructed with the
    inverse STFT. Both the input and the output are displayed as audio
    widgets.

    :param audio_path: path to the noisy audio file
    :param model: trained SEMamba model; called as ``model(magnitude, phase)``
    :param output_path: where the denoised audio file is written
    :param n_fft: STFT size; defaults to the value used by ``train_model``
        (400) so that inference sees the same time-frequency resolution the
        model was trained on (this function previously used 1024)
    :param hop_length: STFT hop; defaults to the training value (160, was 512)
    :param win_length: STFT window length; defaults to the training value
        (400, was 1024)
    :return: None
    """
    # Load at the file's native sample rate.
    audio, sr = librosa.load(audio_path, sr=None)

    display(IPython.display.Audio(audio, rate=sr))

    # STFT -> magnitude and phase
    stft_result = librosa.stft(audio, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
    magnitude, phase = np.abs(stft_result), np.angle(stft_result)

    # Put the inputs on the device the model actually lives on; choosing
    # "cuda if available" breaks when the model was left on the CPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    magnitude_tensor = torch.tensor(magnitude).unsqueeze(0).to(device)
    phase_tensor = torch.tensor(phase).unsqueeze(0).to(device)

    # Inference mode
    model.eval()

    with torch.no_grad():
        denoised_mag, denoised_phase, _ = model(magnitude_tensor, phase_tensor)

    # Drop only the batch axis: a bare squeeze() would also remove a
    # length-1 time axis for very short clips.
    denoised_mag_np = denoised_mag.squeeze(0).cpu().numpy()
    denoised_phase_np = denoised_phase.squeeze(0).cpu().numpy()

    # Rebuild the complex spectrogram and invert it, pinning the output to the
    # input length so the two clips line up sample for sample.
    complex_denoised_signal = denoised_mag_np * np.exp(1j * denoised_phase_np)
    denoised_audio = librosa.istft(
        complex_denoised_signal, hop_length=hop_length, win_length=win_length, length=len(audio)
    )

    display(IPython.display.Audio(denoised_audio, rate=sr))
    # Save the denoised audio
    sf.write(output_path, denoised_audio, sr)
    print(f"Денойзинговое аудио сохранено в: {output_path}")

# Пример использования:
# denoise_audio('path_to_noisy_audio.wav', model, 'path_to_save_denoised_audio.wav')

In [111]:
# Persist the trained model. This pickles the whole module object, so loading
# it later requires the same class definitions to be importable.
# (A stray trailing "." made this line a syntax error.)
torch.save(SEMamba_model_2, './models/SEMamba_model_2_v1.pth')

In [112]:
# Display the model's module tree.
# NOTE(review): `SEMamba_model` is not assigned in any cell visible here (only
# SEMamba_model_1 / SEMamba_model_2 are) — confirm it still exists after Restart & Run All.
SEMamba_model

SEMamba(
  (dense_encoder): DenseEncoder(
    (dense_conv_1): Sequential(
      (0): Conv2d(2, 64, kernel_size=(1, 1), stride=(1, 1))
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): PReLU(num_parameters=64)
    )
    (dense_block): DenseBlock(
      (dense_block): ModuleList(
        (0): Sequential(
          (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (2): PReLU(num_parameters=64)
        )
        (1): Sequential(
          (0): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(2, 1), dilation=(2, 1))
          (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
          (2): PReLU(num_parameters=64)
        )
        (2): Sequential(
          (0): Conv2d(192, 64, kernel_size=(3, 3), stride=(1, 1), padding=(4, 1), dilation=(4, 1))
          (1)

In [123]:
# Denoise one sample from data/noised_speech with SEMamba_model_1 and write it to denoised_audio.wav.
denoise_audio('./data/noised_speech/female_1462-170138-0000.wav', SEMamba_model_1, 'denoised_audio.wav')

Денойзинговое аудио сохранено в: denoised_audio.wav


In [None]:
# Denoise the same utterance from data/back_noised_speech with SEMamba_model_2.
# Note: this writes to 'denoised_audio.wav', the same path used by the SEMamba_model_1 call,
# so that earlier result is overwritten.
denoise_audio('./data/back_noised_speech/female_1462-170138-0000.wav', SEMamba_model_2, 'denoised_audio.wav')