_____
# ***End to End Models***

#### **Conditioners**

In [3]:
# Base
from torch import nn


class Conditioning(nn.Module):
    def __init__(
        self, cond_embedding_dim: int, channels: int, channels_per_group: int = 16
    ):
        super().__init__()

        self.channels = channels
        self.cond_embedding_dim = cond_embedding_dim
        self.channels_per_group = channels_per_group

        self.gn = nn.GroupNorm(self.channels // self.channels_per_group, self.channels)

    def forward(self, x, w):
        raise NotImplementedError


class PassThroughConditioning(Conditioning):
    def __init__(
        self, cond_embedding_dim: int, channels: int, channels_per_group: int = 16
    ):
        super().__init__(cond_embedding_dim, channels, channels_per_group)

    def forward(self, x, w):
        return self.gn(x)

In [4]:
# Film
import math
import torch
from torch import nn
from torch.nn.modules import activation as activation_


class FiLM(Conditioning):
    def __init__(
        self,
        cond_embedding_dim: int,
        channels: int,
        additive: bool = True,
        multiplicative: bool = False,
        depth: int = 1,
        activation: str = "ELU",
        channels_per_group: int = 16,
    ):
        super().__init__(
            channels=channels,
            channels_per_group=channels_per_group,
            cond_embedding_dim=cond_embedding_dim,
        )

        self.additive = additive
        self.multiplicative = multiplicative
        self.depth = depth
        self.activation = activation

        Activation = activation_.__dict__[activation]

        if self.multiplicative:

            if depth == 1:
                self.gamma = nn.Linear(self.cond_embedding_dim, self.channels)
            else:
                layers = [nn.Linear(self.cond_embedding_dim, self.channels)]
                for _ in range(depth - 1):
                    layers += [Activation(), nn.Linear(self.channels, self.channels)]
                self.gamma = nn.Sequential(*layers)
        else:
            self.gamma = None

        if self.additive:
            if depth == 1:
                self.beta = nn.Linear(self.cond_embedding_dim, self.channels)
            else:
                layers = [nn.Linear(self.cond_embedding_dim, self.channels)]
                for _ in range(depth - 1):
                    layers += [Activation(), nn.Linear(self.channels, self.channels)]
                self.beta = nn.Sequential(*layers)
        else:
            self.beta = None

    def forward(self, x, w):

        x = self.gn(x)

        if self.multiplicative:
            gamma = self.gamma(w)

            if len(x.shape) == 4:
                gamma = gamma[:, :, None, None]
            elif len(x.shape) == 3:
                gamma = gamma[:, :, None]
            elif len(x.shape) == 2:
                pass
            else:
                raise ValueError(f"Invalid shape for input tensor: {x.shape}")

            x = gamma * x

        if self.additive:
            beta = self.beta(w)
            if len(x.shape) == 4:
                beta = beta[:, :, None, None]
            elif len(x.shape) == 3:
                beta = beta[:, :, None]
            elif len(x.shape) == 2:
                pass
            else:
                raise ValueError(f"Invalid shape for input tensor: {x.shape}")

            x = x + beta

        return x
        
class CosineSimiliarity(Conditioning):
    def __init__(self, cond_embedding_dim: int, channels: int, channels_per_group: int = 16):
        super().__init__(cond_embedding_dim, channels, channels_per_group)
        
        self.csim = nn.CosineSimilarity(dim=1)
        self.proj = nn.Linear(self.cond_embedding_dim, self.channels * self.channels)
        
    def forward(self, x, w):
        
        
        x = self.gn(x)

        gamma = self.gamma(w)

        if len(x.shape) == 4:
            gamma = gamma[:, :, None, None]
        elif len(x.shape) == 3:
            gamma = gamma[:, :, None]
        elif len(x.shape) == 2:
            pass
        else:
            raise ValueError(f"Invalid shape for input tensor: {x.shape}")
        
        c = self.csim(gamma, x)
        
        x = c[:, None, ...] * x

        
        


class GeneralizedBilinear(nn.Bilinear):
    def __init__(self, in1_features: int, in2_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
        super().__init__(in1_features, in2_features, out_features, bias, device, dtype)
        
    def forward(self, x1, x2):
        
        out = torch.einsum(
            "bc...,acd,bd->ba...", x1, self.weight, x2
        )
        
        if self.bias is not None:
            ndim = out.ndim
            bias = torch.reshape(self.bias, (1, -1) + (1,) * (ndim - 2))
            
            out = out + bias
            
        return out
           

class BilinearFiLM(Conditioning):
    def __init__(
        self,
        cond_embedding_dim: int,
        channels: int,
        additive: bool = True,
        multiplicative: bool = False,
        depth: int = 2,
        activation: str = "ELU",
        channels_per_group: int = 16,
    ):
        super().__init__(
            channels=channels,
            channels_per_group=channels_per_group,
            cond_embedding_dim=cond_embedding_dim,
        )

        self.additive = additive
        self.multiplicative = multiplicative
        self.depth = depth
        assert depth == 2, "Only depth 2 is supported for BilinearFiLM"
        self.activation = activation

        Activation = activation_.__dict__[activation]

        if self.multiplicative:
            self.gamma_proj = nn.Sequential(
                nn.Linear(self.cond_embedding_dim, self.channels),
                Activation(),
            )
            self.gamma_bilinear = GeneralizedBilinear(self.channels, self.channels, self.channels)
        else:
            self.gamma = None

        if self.additive:
            self.beta_proj = nn.Sequential(
                nn.Linear(self.cond_embedding_dim, self.channels),
                Activation(),
            )
            self.beta_bilinear = GeneralizedBilinear(self.channels, self.channels, self.channels)
        else:
            self.beta = None

    def forward(self, x, w):

        x = self.gn(x)

        if self.multiplicative:
            gamma = self.gamma_proj(w)
            gamma = self.gamma_bilinear(x, gamma)
            x = gamma * x

        if self.additive:
            beta = self.beta_proj(w)
            beta = self.beta_bilinear(x, beta)
            x = x + beta

        return x

#### **Types**

In [16]:
# InputType  OutputType LossOutputType MetricOutputType 
# ModelType  OptimizerType SchedulerType MetricType LossType 
# OptimizationBundle
# LossHandler MetricHandler AugmentationHandle InferenceHandler 

from types import SimpleNamespace
from typing import Any, Dict, Optional, TypedDict

import torch
from torch import nn, optim
import torchmetrics as tm


class OperationMode:
    TRAIN = "train"
    VAL = "val"
    TEST = "test"
    PREDICT = "predict"


RawInputType = Dict


def nested_dict_to_nested_namespace(d: dict) -> SimpleNamespace:
    d_ = d.copy()

    for k, v in d.items():
        if isinstance(v, dict):
            v = nested_dict_to_nested_namespace(v)

        d_[k] = v

    return SimpleNamespace(**d_)


RawInputType = TypedDict(
    "RawInputType",
    {
        "mixture": torch.Tensor,
        "sources": Dict[str, torch.Tensor],
        "estimates": Optional[Dict[str, torch.Tensor]],
        "metadata": Dict[str, Any],
    },
    total=False,
)


def input_dict(
    mixture: torch.Tensor = None,
    sources: Dict[str, torch.Tensor] = None,
    query: torch.Tensor = None,
    metadata: Dict[str, Any] = None,
    modality: str = "audio",
) -> RawInputType:

    out = {
        "estimates": {
            k: {
                modality: torch.empty(
                    0,
                )
            }
            for k, v in sources.items()
        }
    }

    if mixture is not None:
        out["mixture"] = {modality: torch.from_numpy(mixture).to(torch.float32)}

    if sources is not None:
        out["sources"] = {k: {modality: torch.from_numpy(v).to(torch.float32)} for k, v in sources.items()}

    if query is not None:
        out["query"] = {modality: torch.from_numpy(query).to(torch.float32)}

    if metadata is not None:
        out["metadata"] = metadata

    return out


class SimpleishNamespace(SimpleNamespace):
    def __init__(self, **kwargs: Any) -> None:
        kwargs_ = kwargs.copy()

        for k, v in kwargs.items():
            if isinstance(v, dict):
                v = SimpleishNamespace(**v)

            kwargs_[k] = v

        super().__init__(**kwargs_)

    def copy(self) -> "SimpleishNamespace":
        return SimpleishNamespace(**{k: v for k, v in self.__dict__.items()})

    def add_subnamespace(self, name: str, **kwargs: Any) -> None:
        if hasattr(self, name):
            raise ValueError(f"Namespace already has attribute {name}")

        setattr(self, name, SimpleishNamespace(**kwargs))

    def keys(self):
        return self.__dict__.keys()

    def __getitem__(self, key: str) -> Any:
        return self.__dict__[key]

    def __setitem__(self, key: str, value: Any) -> None:
        self.__dict__[key] = value

    def items(self):
        return self.__dict__.items()


class BatchedInputOutput(SimpleishNamespace):
    mixture: torch.Tensor
    sources: Dict[str, torch.Tensor]
    estimates: Optional[Dict[str, torch.Tensor]]
    metadata: Dict[str, Any]

    def __init__(self, **kwargs: Any) -> None:
        super().__init__(**kwargs)

    @classmethod
    def from_dict(cls, d: dict) -> "BatchedInputOutput":
        return cls(**d)

    def to_dict(self) -> dict:
        return self.__dict__


class TensorCollection(SimpleishNamespace):
    def __init__(self, **kwargs: torch.Tensor) -> None:
        super().__init__(**kwargs)

    def apply(self, func: Any, *args: Any, **kwargs: Any) -> "TensorCollection":
        return TensorCollection(
            **{k: func(v, *args, **kwargs) for k, v in self.__dict__.items()}
        )

    def as_stacked_tensor(self, dim: int = 0) -> torch.Tensor:
        return torch.stack(list(self.__dict__.values()), dim=dim)

    def as_concatenated_tensor(self, dim: int = 0) -> torch.Tensor:
        return torch.cat(list(self.__dict__.values()), dim=dim)

    def __getitem__(self, key: str) -> torch.Tensor:
        return self.__dict__[key]


InputType = BatchedInputOutput
OutputType = BatchedInputOutput
LossOutputType = Any
MetricOutputType = Any

ModelType = nn.Module
OptimizerType = optim.Optimizer
SchedulerType = optim.lr_scheduler._LRScheduler
MetricType = tm.Metric
LossType = nn.Module

OptimizationBundle = Any

LossHandler = Any
MetricHandler = Any
AugmentationHandler = Any
InferenceHandler = Any

#### **End to end base**

In [20]:
import warnings
from typing import Any, Dict, Optional, Tuple, Union
import pytorch_lightning as pl

#from audiocraft.models import encodec
import torch
from torch import nn
from torch.nn import functional as F

# from ...types import (
#     BatchedInputOutput,
#     InputType,
#     OperationMode,
#     OutputType,
#     SimpleishNamespace,
#     TensorCollection
# )

import torchaudio as ta


class BaseEndToEndModule(pl.LightningModule):

    def __init__(
        self,
    ) -> None:
        super().__init__()


if __name__ == "__main__":
    model = BaseEndToEndModule()
    print(model)
    print(model.__class__.__name__)
    print(model.__module__)

BaseEndToEndModule()
BaseEndToEndModule
__main__


#### **Querier**

In [13]:
#PaSST
import torch
import torchaudio as ta
from hear21passt.base import get_basic_model
from torch import nn

class Passt(nn.Module):

    PASST_EMB_DIM: int = 768
    PASST_FS: int = 32000

    def __init__(
        self,
        original_fs: int=44100,
        passt_fs: int=PASST_FS,
    ):
        super().__init__()

        self.passt = get_basic_model(mode="embed_only", arch="openmic").eval()
        self.resample = ta.transforms.Resample(
            orig_freq=original_fs, new_freq=passt_fs
        ).eval()

        for p in self.passt.parameters():
            p.requires_grad = False

    def forward(self, x):
        """
        Forward pass of the PasstWrapper model.

        Args:
            qspec (torch.Tensor): Query spectrogram.
            qaudio (torch.Tensor): Query audio.

        Returns:
            torch.Tensor: Embedding output.
        """
        with torch.no_grad():
            x = torch.mean(x, dim=1)
            x = self.resample(x)

            specs = self.passt.mel(x)[..., :998]
            specs = specs[:, None, ...]
            _, z = self.passt.net(specs)

        return z


class PasstWrapper(nn.Module):

    PASST_EMB_DIM: int = 768
    PASST_FS: int = 32000

    def __init__(
        self,
        cond_emb_dim: int = 384,
        original_cond_emb_dim=PASST_EMB_DIM,
        original_fs: int=44100,
        passt_fs: int=PASST_FS,
    ):
        super().__init__()
        self.cond_emb_dim = cond_emb_dim

        self.passt = get_basic_model(mode="embed_only", arch="openmic").eval()
        self.proj = nn.Linear(original_cond_emb_dim, cond_emb_dim) if cond_emb_dim is not None else nn.Identity()
        self.resample = ta.transforms.Resample(
            orig_freq=original_fs, new_freq=passt_fs
        ).eval()

        for p in self.passt.parameters():
            p.requires_grad = False

    def forward(self, qspec, qaudio):
        """
        Forward pass of the PasstWrapper model.

        Args:
            qspec (torch.Tensor): Query spectrogram.
            qaudio (torch.Tensor): Query audio.

        Returns:
            torch.Tensor: Embedding output.
        """
        with torch.no_grad():
            x = torch.mean(qaudio, dim=1)
            x = self.resample(x)

            specs = self.passt.mel(x)[..., :998]
            specs = specs[:, None, ...]
            _, z = self.passt.net(specs)

        z = self.proj(z)

        return z

#### **Bandit**

In [5]:
# Utils
import os
from abc import abstractmethod
from typing import Any, Callable

import numpy as np
import torch
from librosa import hz_to_midi, midi_to_hz
from torch import Tensor
from torchaudio import functional as taF
# from spafe.fbanks import bark_fbanks
# from spafe.utils.converters import erb2hz, hz2bark, hz2erb
from torchaudio.functional.functional import _create_triangular_filterbank


def band_widths_from_specs(band_specs):
    return [e - i for i, e in band_specs]


def check_nonzero_bandwidth(band_specs):
    # pprint(band_specs)
    for fstart, fend in band_specs:
        if fend - fstart <= 0:
            raise ValueError("Bands cannot be zero-width")


def check_no_overlap(band_specs):
    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr <= fend_prev:
            raise ValueError("Bands cannot overlap")


def check_no_gap(band_specs):
    fstart, _ = band_specs[0]
    assert fstart == 0

    fend_prev = -1
    for fstart_curr, fend_curr in band_specs:
        if fstart_curr - fend_prev > 1:
            raise ValueError("Bands cannot leave gap")
        fend_prev = fend_curr


class BandsplitSpecification:
    def __init__(self, nfft: int, fs: int) -> None:
        self.fs = fs
        self.nfft = nfft
        self.nyquist = fs / 2
        self.max_index = nfft // 2 + 1

        self.split500 = self.hertz_to_index(500)
        self.split1k = self.hertz_to_index(1000)
        self.split2k = self.hertz_to_index(2000)
        self.split4k = self.hertz_to_index(4000)
        self.split8k = self.hertz_to_index(8000)
        self.split16k = self.hertz_to_index(16000)
        self.split20k = self.hertz_to_index(20000)

        self.above20k = [(self.split20k, self.max_index)]
        self.above16k = [(self.split16k, self.split20k)] + self.above20k

    def index_to_hertz(self, index: int):
        return index * self.fs / self.nfft

    def hertz_to_index(self, hz: float, round: bool = True):
        index = hz * self.nfft / self.fs

        if round:
            index = int(np.round(index))

        return index

    def get_band_specs_with_bandwidth(
            self,
            start_index,
            end_index,
            bandwidth_hz
            ):
        band_specs = []
        lower = start_index

        while lower < end_index:
            upper = int(np.floor(lower + self.hertz_to_index(bandwidth_hz)))
            upper = min(upper, end_index)

            band_specs.append((lower, upper))
            lower = upper

        return band_specs

    @abstractmethod
    def get_band_specs(self):
        raise NotImplementedError


class VocalBandsplitSpecification(BandsplitSpecification):
    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

        self.version = version

    def get_band_specs(self):
        return getattr(self, f"version{self.version}")()

    @property
    def version1(self):
        return self.get_band_specs_with_bandwidth(
                start_index=0, end_index=self.max_index, bandwidth_hz=1000
        )

    def version2(self):
        below16k = self.get_band_specs_with_bandwidth(
                start_index=0, end_index=self.split16k, bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
                start_index=self.split16k,
                end_index=self.split20k,
                bandwidth_hz=2000
        )

        return below16k + below20k + self.above20k

    def version3(self):
        below8k = self.get_band_specs_with_bandwidth(
                start_index=0, end_index=self.split8k, bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
                start_index=self.split8k,
                end_index=self.split16k,
                bandwidth_hz=2000
        )

        return below8k + below16k + self.above16k

    def version4(self):
        below1k = self.get_band_specs_with_bandwidth(
                start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below8k = self.get_band_specs_with_bandwidth(
                start_index=self.split1k,
                end_index=self.split8k,
                bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
                start_index=self.split8k,
                end_index=self.split16k,
                bandwidth_hz=2000
        )

        return below1k + below8k + below16k + self.above16k

    def version5(self):
        below1k = self.get_band_specs_with_bandwidth(
                start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below16k = self.get_band_specs_with_bandwidth(
                start_index=self.split1k,
                end_index=self.split16k,
                bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
                start_index=self.split16k,
                end_index=self.split20k,
                bandwidth_hz=2000
        )
        return below1k + below16k + below20k + self.above20k

    def version6(self):
        below1k = self.get_band_specs_with_bandwidth(
                start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
                start_index=self.split1k,
                end_index=self.split4k,
                bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
                start_index=self.split4k,
                end_index=self.split8k,
                bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
                start_index=self.split8k,
                end_index=self.split16k,
                bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + self.above16k

    def version7(self):
        below1k = self.get_band_specs_with_bandwidth(
                start_index=0, end_index=self.split1k, bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
                start_index=self.split1k,
                end_index=self.split4k,
                bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
                start_index=self.split4k,
                end_index=self.split8k,
                bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
                start_index=self.split8k,
                end_index=self.split16k,
                bandwidth_hz=1000
        )
        below20k = self.get_band_specs_with_bandwidth(
                start_index=self.split16k,
                end_index=self.split20k,
                bandwidth_hz=2000
        )
        return below1k + below4k + below8k + below16k + below20k + self.above20k


class OtherBandsplitSpecification(VocalBandsplitSpecification):
    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs, version="7")


class BassBandsplitSpecification(BandsplitSpecification):
    def __init__(self, nfft: int, fs: int, version: str = "7") -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        below500 = self.get_band_specs_with_bandwidth(
                start_index=0, end_index=self.split500, bandwidth_hz=50
        )
        below1k = self.get_band_specs_with_bandwidth(
                start_index=self.split500,
                end_index=self.split1k,
                bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
                start_index=self.split1k,
                end_index=self.split4k,
                bandwidth_hz=500
        )
        below8k = self.get_band_specs_with_bandwidth(
                start_index=self.split4k,
                end_index=self.split8k,
                bandwidth_hz=1000
        )
        below16k = self.get_band_specs_with_bandwidth(
                start_index=self.split8k,
                end_index=self.split16k,
                bandwidth_hz=2000
        )
        above16k = [(self.split16k, self.max_index)]

        return below500 + below1k + below4k + below8k + below16k + above16k


class DrumBandsplitSpecification(BandsplitSpecification):
    def __init__(self, nfft: int, fs: int) -> None:
        super().__init__(nfft=nfft, fs=fs)

    def get_band_specs(self):
        below1k = self.get_band_specs_with_bandwidth(
                start_index=0, end_index=self.split1k, bandwidth_hz=50
        )
        below2k = self.get_band_specs_with_bandwidth(
                start_index=self.split1k,
                end_index=self.split2k,
                bandwidth_hz=100
        )
        below4k = self.get_band_specs_with_bandwidth(
                start_index=self.split2k,
                end_index=self.split4k,
                bandwidth_hz=250
        )
        below8k = self.get_band_specs_with_bandwidth(
                start_index=self.split4k,
                end_index=self.split8k,
                bandwidth_hz=500
        )
        below16k = self.get_band_specs_with_bandwidth(
                start_index=self.split8k,
                end_index=self.split16k,
                bandwidth_hz=1000
        )
        above16k = [(self.split16k, self.max_index)]

        return below1k + below2k + below4k + below8k + below16k + above16k




class PerceptualBandsplitSpecification(BandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            fbank_fn: Callable[[int, int, float, float, int], torch.Tensor],
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(nfft=nfft, fs=fs)
        self.n_bands = n_bands
        if f_max is None:
            f_max = fs / 2

        self.filterbank = fbank_fn(
                n_bands, fs, f_min, f_max, self.max_index
        )

        weight_per_bin = torch.sum(
            self.filterbank,
            dim=0,
            keepdim=True
            )  # (1, n_freqs)
        normalized_mel_fb = self.filterbank / weight_per_bin  # (n_mels, n_freqs)

        freq_weights = []
        band_specs = []
        for i in range(self.n_bands):
            active_bins = torch.nonzero(self.filterbank[i, :]).squeeze().tolist()
            if isinstance(active_bins, int):
                active_bins = (active_bins, active_bins)
            if len(active_bins) == 0:
                continue
            start_index = active_bins[0]
            end_index = active_bins[-1] + 1
            band_specs.append((start_index, end_index))
            freq_weights.append(normalized_mel_fb[i, start_index:end_index])

        self.freq_weights = freq_weights
        self.band_specs = band_specs

    def get_band_specs(self):
        return self.band_specs

    def get_freq_weights(self):
        return self.freq_weights

    def save_to_file(self, dir_path: str) -> None:

        os.makedirs(dir_path, exist_ok=True)

        import pickle

        with open(os.path.join(dir_path, "mel_bandsplit_spec.pkl"), "wb") as f:
            pickle.dump(
                    {
                            "band_specs": self.band_specs,
                            "freq_weights": self.freq_weights,
                            "filterbank": self.filterbank,
                    },
                    f,
            )

def mel_filterbank(n_bands, fs, f_min, f_max, n_freqs):
    fb = taF.melscale_fbanks(
                n_mels=n_bands,
                sample_rate=fs,
                f_min=f_min,
                f_max=f_max,
                n_freqs=n_freqs,
        ).T

    fb[0, 0] = 1.0

    return fb


class MelBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(fbank_fn=mel_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)

def musical_filterbank(n_bands, fs, f_min, f_max, n_freqs,
                       scale="constant"):

    nfft = 2 * (n_freqs - 1)
    df = fs / nfft
    # init freqs
    f_max = f_max or fs / 2
    f_min = f_min or 0
    f_min = fs / nfft

    n_octaves = np.log2(f_max / f_min)
    n_octaves_per_band = n_octaves / n_bands
    bandwidth_mult = np.power(2.0, n_octaves_per_band)

    low_midi = max(0, hz_to_midi(f_min))
    high_midi = hz_to_midi(f_max)
    midi_points = np.linspace(low_midi, high_midi, n_bands)
    hz_pts = midi_to_hz(midi_points)

    low_pts = hz_pts / bandwidth_mult
    high_pts = hz_pts * bandwidth_mult

    low_bins = np.floor(low_pts / df).astype(int)
    high_bins = np.ceil(high_pts / df).astype(int)

    fb = np.zeros((n_bands, n_freqs))

    for i in range(n_bands):
        fb[i, low_bins[i]:high_bins[i]+1] = 1.0

    fb[0, :low_bins[0]] = 1.0
    fb[-1, high_bins[-1]+1:] = 1.0

    return torch.as_tensor(fb)

class MusicalBandsplitSpecification(PerceptualBandsplitSpecification):
    def __init__(
            self,
            nfft: int,
            fs: int,
            n_bands: int,
            f_min: float = 0.0,
            f_max: float = None
    ) -> None:
        super().__init__(fbank_fn=musical_filterbank, nfft=nfft, fs=fs, n_bands=n_bands, f_min=f_min, f_max=f_max)


if __name__ == "__main__":
    import pandas as pd

    band_defs = []

    for bands in [VocalBandsplitSpecification]:  
        band_name = bands.__name__.replace("BandsplitSpecification", "")

        mbs = bands(nfft=2048, fs=44100).get_band_specs()

        for i, (f_min, f_max) in enumerate(mbs):
            band_defs.append({
                "band": band_name,
                "band_index": i,
                "f_min": f_min,
                "f_max": f_max
            })

    df = pd.DataFrame(band_defs)
    df.to_csv("vox7bands.csv", index=False)

In [9]:
# Time Frequency Model
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules import rnn

import torch.backends.cuda


class TimeFrequencyModellingModule(nn.Module):
    def __init__(self) -> None:
        super().__init__()


class ResidualRNN(nn.Module):
    def __init__(
            self,
            emb_dim: int,
            rnn_dim: int,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            use_batch_trick: bool = True,
            use_layer_norm: bool = True,
    ) -> None:
        super().__init__()

        self.use_layer_norm = use_layer_norm
        if use_layer_norm:
            self.norm = nn.LayerNorm(emb_dim)
        else:
            self.norm = nn.GroupNorm(num_groups=emb_dim, num_channels=emb_dim)

        self.rnn = rnn.__dict__[rnn_type](
            input_size=emb_dim,
            hidden_size=rnn_dim,
            num_layers=1,
            batch_first=True,
            bidirectional=bidirectional,
        )

        self.fc = nn.Linear(
            in_features=rnn_dim * (2 if bidirectional else 1),
            out_features=emb_dim
        )

        self.use_batch_trick = use_batch_trick
        if not self.use_batch_trick:
            warnings.warn("NOT USING BATCH TRICK IS EXTREMELY SLOW!!")

    def forward(self, z):
        # z = (batch, n_uncrossed, n_across, emb_dim)

        z0 = torch.clone(z)

        if self.use_layer_norm:
            z = self.norm(z)  # (batch, n_uncrossed, n_across, emb_dim)
        else:
            z = torch.permute(
                z, (0, 3, 1, 2)
            )  # (batch, emb_dim, n_uncrossed, n_across)

            z = self.norm(z)  # (batch, emb_dim, n_uncrossed, n_across)

            z = torch.permute(
                z, (0, 2, 3, 1)
            )  # (batch, n_uncrossed, n_across, emb_dim)

        batch, n_uncrossed, n_across, emb_dim = z.shape

        if self.use_batch_trick:
            z = torch.reshape(z, (batch * n_uncrossed, n_across, emb_dim))
            z = self.rnn(z.contiguous())[0]  # (batch * n_uncrossed, n_across, dir_rnn_dim)

            z = torch.reshape(z, (batch, n_uncrossed, n_across, -1))
            # (batch, n_uncrossed, n_across, dir_rnn_dim)
        else:
            # Note: this is EXTREMELY SLOW
            zlist = []
            for i in range(n_uncrossed):
                zi = self.rnn(z[:, i, :, :])[0]  # (batch, n_across, emb_dim)
                zlist.append(zi)

            z = torch.stack(
                zlist,
                dim=1
            )  # (batch, n_uncrossed, n_across, dir_rnn_dim)

        z = self.fc(z)  # (batch, n_uncrossed, n_across, emb_dim)

        z = z + z0

        return z


class SeqBandModellingModule(TimeFrequencyModellingModule):
    def __init__(
            self,
            n_modules: int = 12,
            emb_dim: int = 128,
            rnn_dim: int = 256,
            bidirectional: bool = True,
            rnn_type: str = "LSTM",
            parallel_mode=False,
    ) -> None:
        super().__init__()
        self.seqband = nn.ModuleList([])

        if parallel_mode:
            for _ in range(n_modules):
                self.seqband.append(
                    nn.ModuleList(
                        [ResidualRNN(
                            emb_dim=emb_dim,
                            rnn_dim=rnn_dim,
                            bidirectional=bidirectional,
                            rnn_type=rnn_type,
                        ),
                            ResidualRNN(
                                emb_dim=emb_dim,
                                rnn_dim=rnn_dim,
                                bidirectional=bidirectional,
                                rnn_type=rnn_type,
                            )]
                    )
                )
        else:

            for _ in range(2 * n_modules):
                self.seqband.append(
                    ResidualRNN(
                        emb_dim=emb_dim,
                        rnn_dim=rnn_dim,
                        bidirectional=bidirectional,
                        rnn_type=rnn_type,
                    )
                )

        self.parallel_mode = parallel_mode

    def forward(self, z):
        # z = (batch, n_bands, n_time, emb_dim)

        if self.parallel_mode:
            for sbm_pair in self.seqband:
                # z: (batch, n_bands, n_time, emb_dim)
                sbm_t, sbm_f = sbm_pair[0], sbm_pair[1]
                zt = sbm_t(z)  # (batch, n_bands, n_time, emb_dim)
                zf = sbm_f(z.transpose(1, 2))  # (batch, n_time, n_bands, emb_dim)
                z = zt + zf.transpose(1, 2)
        else:
            for sbm in self.seqband:
                z = sbm(z)
                z = z.transpose(1, 2)

                # (batch, n_bands, n_time, emb_dim)
                #   --> (batch, n_time, n_bands, emb_dim)
                # OR
                # (batch, n_time, n_bands, emb_dim)
                #   --> (batch, n_bands, n_time, emb_dim)

        q = z
        return q  # (batch, n_bands, n_time, emb_dim)


# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Example of moving model and data to the correct device
model = SeqBandModellingModule().to(device)

# # Example data tensor
# input_data = torch.randn(32, 64, 128, 128).to(device)  # Adjust dimensions as needed

# # Forward pass
# output = model(input_data)

# print(f"Output shape: {output.shape}")

'''Commented this example input as memory is not sufficient on my disc'''


'Commented this example input as memory is not sufficient on my disc'

In [10]:
# Mask estimation
import warnings
from typing import Dict, List, Optional, Tuple, Type

import torch
from torch import nn
from torch.nn.modules import activation

# from core.models.e2e.bandit.utils import (
#     band_widths_from_specs,
#     check_no_gap,
#     check_no_overlap,
#     check_nonzero_bandwidth,
# )


class BaseNormMLP(nn.Module):
    def __init__(
            self,
            emb_dim: int,
            mlp_dim: int,
            bandwidth: int,
            in_channel: Optional[int],
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs=None,
            complex_mask: bool = True, ):

        super().__init__()
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}
        self.hidden_activation_kwargs = hidden_activation_kwargs
        self.norm = nn.LayerNorm(emb_dim)
        self.hidden = torch.jit.script(nn.Sequential(
                nn.Linear(in_features=emb_dim, out_features=mlp_dim),
                activation.__dict__[hidden_activation](
                        **self.hidden_activation_kwargs
                ),
        ))

        self.bandwidth = bandwidth
        self.in_channel = in_channel

        self.complex_mask = complex_mask
        self.reim = 2 if complex_mask else 1
        self.glu_mult = 2


class NormMLP(BaseNormMLP):
    def __init__(
            self,
            emb_dim: int,
            mlp_dim: int,
            bandwidth: int,
            in_channel: Optional[int],
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs=None,
            complex_mask: bool = True,
    ) -> None:
        super().__init__(
                emb_dim=emb_dim,
                mlp_dim=mlp_dim,
                bandwidth=bandwidth,
                in_channel=in_channel,
                hidden_activation=hidden_activation,
                hidden_activation_kwargs=hidden_activation_kwargs,
                complex_mask=complex_mask,
        )

        self.output = torch.jit.script(
                nn.Sequential(
                        nn.Linear(
                                in_features=mlp_dim,
                                out_features=bandwidth * in_channel * self.reim * 2,
                        ),
                        nn.GLU(dim=-1),
                )
        )

    def reshape_output(self, mb):
        # print(mb.shape)
        batch, n_time, _ = mb.shape
        if self.complex_mask:
            mb = mb.reshape(
                    batch,
                    n_time,
                    self.in_channel,
                    self.bandwidth,
                    self.reim
            ).contiguous()
            # print(mb.shape)
            mb = torch.view_as_complex(
                    mb
            )  # (batch, n_time, in_channel, bandwidth)
        else:
            mb = mb.reshape(batch, n_time, self.in_channel, self.bandwidth)

        mb = torch.permute(
                mb,
                (0, 2, 3, 1)
        )  # (batch, in_channel, bandwidth, n_time)

        return mb

    def forward(self, qb):
        # qb = (batch, n_time, emb_dim)

        # if torch.any(torch.isnan(qb)):
        #     raise ValueError("qb0")


        qb = self.norm(qb)  # (batch, n_time, emb_dim)

        # if torch.any(torch.isnan(qb)):
        #     raise ValueError("qb1")

        qb = self.hidden(qb)  # (batch, n_time, mlp_dim)
        # if torch.any(torch.isnan(qb)):
        #     raise ValueError("qb2")
        mb = self.output(qb)  # (batch, n_time, bandwidth * in_channel * reim)
        # if torch.any(torch.isnan(qb)):
        #     raise ValueError("mb")
        mb = self.reshape_output(mb)  # (batch, in_channel, bandwidth, n_time)

        return mb


# class MultAddNormMLP(NormMLP):
#     def __init__(self, emb_dim: int, mlp_dim: int, bandwidth: int, in_channel: int | None, hidden_activation: str = "Tanh", hidden_activation_kwargs=None, complex_mask: bool = True) -> None:
#         super().__init__(emb_dim, mlp_dim, bandwidth, in_channel, hidden_activation, hidden_activation_kwargs, complex_mask)

#         self.output2 = torch.jit.script(
#                 nn.Sequential(
#                         nn.Linear(
#                                 in_features=mlp_dim,
#                                 out_features=bandwidth * in_channel * self.reim * 2,
#                         ),
#                         nn.GLU(dim=-1),
#                 )
#         )

#     def forward(self, qb):

#         qb = self.norm(qb)  # (batch, n_time, emb_dim)
#         qb = self.hidden(qb)  # (batch, n_time, mlp_dim)
#         mmb = self.output(qb)  # (batch, n_time, bandwidth * in_channel * reim)
#         mmb = self.reshape_output(mmb)  # (batch, in_channel, bandwidth, n_time)
#         amb = self.output2(qb)  # (batch, n_time, bandwidth * in_channel * reim)
#         amb = self.reshape_output(amb)  # (batch, in_channel, bandwidth, n_time)

#         return mmb, amb


class MaskEstimationModuleSuperBase(nn.Module):
    pass


class MaskEstimationModuleBase(MaskEstimationModuleSuperBase):
    def __init__(
            self,
            band_specs: List[Tuple[float, float]],
            emb_dim: int,
            mlp_dim: int,
            in_channel: Optional[int],
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Dict = None,
            complex_mask: bool = True,
            norm_mlp_cls: Type[nn.Module] = NormMLP,
            norm_mlp_kwargs: Dict = None,
    ) -> None:
        super().__init__()

        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)

        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        if norm_mlp_kwargs is None:
            norm_mlp_kwargs = {}

        self.norm_mlp = nn.ModuleList(
                [
                        (
                                norm_mlp_cls(
                                        bandwidth=self.band_widths[b],
                                        emb_dim=emb_dim,
                                        mlp_dim=mlp_dim,
                                        in_channel=in_channel,
                                        hidden_activation=hidden_activation,
                                        hidden_activation_kwargs=hidden_activation_kwargs,
                                        complex_mask=complex_mask,
                                        **norm_mlp_kwargs,
                                )
                        )
                        for b in range(self.n_bands)
                ]
        )

    def compute_masks(self, q):
        batch, n_bands, n_time, emb_dim = q.shape

        masks = []

        for b, nmlp in enumerate(self.norm_mlp):
            # print(f"maskestim/{b:02d}")
            qb = q[:, b, :, :]
            mb = nmlp(qb)
            masks.append(mb)

        return masks



class OverlappingMaskEstimationModule(MaskEstimationModuleBase):
    def __init__(
            self,
            in_channel: int,
            band_specs: List[Tuple[float, float]],
            freq_weights: List[torch.Tensor],
            n_freq: int,
            emb_dim: int,
            mlp_dim: int,
            cond_dim: int = 0,
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Dict = None,
            complex_mask: bool = True,
            norm_mlp_cls: Type[nn.Module] = NormMLP,
            norm_mlp_kwargs: Dict = None,
            use_freq_weights: bool = True,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)

        # if cond_dim > 0:
        #     raise NotImplementedError

        super().__init__(
                band_specs=band_specs,
                emb_dim=emb_dim + cond_dim,
                mlp_dim=mlp_dim,
                in_channel=in_channel,
                hidden_activation=hidden_activation,
                hidden_activation_kwargs=hidden_activation_kwargs,
                complex_mask=complex_mask,
                norm_mlp_cls=norm_mlp_cls,
                norm_mlp_kwargs=norm_mlp_kwargs,
        )

        self.n_freq = n_freq
        self.band_specs = band_specs
        self.in_channel = in_channel

        if freq_weights is not None:
            for i, fw in enumerate(freq_weights):
                self.register_buffer(f"freq_weights/{i}", fw)

                self.use_freq_weights = use_freq_weights
        else:
            self.use_freq_weights = False

        self.cond_dim = cond_dim

    def forward(self, q, cond=None):
        # q = (batch, n_bands, n_time, emb_dim)

        batch, n_bands, n_time, emb_dim = q.shape

        if cond is not None:
            print(cond)
            if cond.ndim == 2:
                cond = cond[:, None, None, :].expand(-1, n_bands, n_time, -1)
            elif cond.ndim == 3:
                assert cond.shape[1] == n_time
            else:
                raise ValueError(f"Invalid cond shape: {cond.shape}")

            q = torch.cat([q, cond], dim=-1)
        elif self.cond_dim > 0:
            cond = torch.ones(
                    (batch, n_bands, n_time, self.cond_dim),
                    device=q.device,
                    dtype=q.dtype,
            )
            q = torch.cat([q, cond], dim=-1)
        else:
            pass

        mask_list = self.compute_masks(
                q
        )  # [n_bands  * (batch, in_channel, bandwidth, n_time)]

        masks = torch.zeros(
                (batch, self.in_channel, self.n_freq, n_time),
                device=q.device,
                dtype=mask_list[0].dtype,
        )

        for im, mask in enumerate(mask_list):
            fstart, fend = self.band_specs[im]
            if self.use_freq_weights:
                fw = self.get_buffer(f"freq_weights/{im}")[:, None]
                mask = mask * fw
            masks[:, :, fstart:fend, :] += mask

        return masks


class MaskEstimationModule(OverlappingMaskEstimationModule):
    def __init__(
            self,
            band_specs: List[Tuple[float, float]],
            emb_dim: int,
            mlp_dim: int,
            in_channel: Optional[int],
            hidden_activation: str = "Tanh",
            hidden_activation_kwargs: Dict = None,
            complex_mask: bool = True,
            **kwargs,
    ) -> None:
        check_nonzero_bandwidth(band_specs)
        check_no_gap(band_specs)
        check_no_overlap(band_specs)
        super().__init__(
                in_channel=in_channel,
                band_specs=band_specs,
                freq_weights=None,
                n_freq=None,
                emb_dim=emb_dim,
                mlp_dim=mlp_dim,
                hidden_activation=hidden_activation,
                hidden_activation_kwargs=hidden_activation_kwargs,
                complex_mask=complex_mask,
        )

    def forward(self, q, cond=None):
        # q = (batch, n_bands, n_time, emb_dim)

        masks = self.compute_masks(
                q
        )  # [n_bands  * (batch, in_channel, bandwidth, n_time)]

        # TODO: currently this requires band specs to have no gap and no overlap
        masks = torch.concat(
                masks,
                dim=2
        )  # (batch, in_channel, n_freq, n_time)

        return masks

In [11]:
# Band split
from typing import List, Tuple

import torch
from torch import nn

class NormFC(nn.Module):
    def __init__(
            self,
            emb_dim: int,
            bandwidth: int,
            in_channel: int,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        self.treat_channel_as_feature = treat_channel_as_feature

        if normalize_channel_independently:
            raise NotImplementedError

        reim = 2

        self.norm = nn.LayerNorm(in_channel * bandwidth * reim)

        fc_in = bandwidth * reim

        if treat_channel_as_feature:
            fc_in *= in_channel
        else:
            assert emb_dim % in_channel == 0
            emb_dim = emb_dim // in_channel

        self.fc = nn.Linear(fc_in, emb_dim)

    def forward(self, xb):
        # xb = (batch, n_time, in_chan, reim * band_width)

        batch, n_time, in_chan, ribw = xb.shape
        xb = self.norm(xb.reshape(batch, n_time, in_chan * ribw))
        # (batch, n_time, in_chan * reim * band_width)

        if not self.treat_channel_as_feature:
            xb = xb.reshape(batch, n_time, in_chan, ribw)
            # (batch, n_time, in_chan, reim * band_width)

        zb = self.fc(xb)
        # (batch, n_time, emb_dim)
        # OR
        # (batch, n_time, in_chan, emb_dim_per_chan)

        if not self.treat_channel_as_feature:
            batch, n_time, in_chan, emb_dim_per_chan = zb.shape
            # (batch, n_time, in_chan, emb_dim_per_chan)
            zb = zb.reshape((batch, n_time, in_chan * emb_dim_per_chan))

        return zb  # (batch, n_time, emb_dim)


class BandSplitModule(nn.Module):
    def __init__(
            self,
            band_specs: List[Tuple[float, float]],
            emb_dim: int,
            in_channel: int,
            require_no_overlap: bool = False,
            require_no_gap: bool = True,
            normalize_channel_independently: bool = False,
            treat_channel_as_feature: bool = True,
    ) -> None:
        super().__init__()

        check_nonzero_bandwidth(band_specs)

        if require_no_gap:
            check_no_gap(band_specs)

        if require_no_overlap:
            check_no_overlap(band_specs)

        self.band_specs = band_specs
        # list of [fstart, fend) in index.
        # Note that fend is exclusive.
        self.band_widths = band_widths_from_specs(band_specs)
        self.n_bands = len(band_specs)
        self.emb_dim = emb_dim

        self.norm_fc_modules = nn.ModuleList(
                [  # type: ignore
                        (
                                NormFC(
                                        emb_dim=emb_dim,
                                        bandwidth=bw,
                                        in_channel=in_channel,
                                        normalize_channel_independently=normalize_channel_independently,
                                        treat_channel_as_feature=treat_channel_as_feature,
                                )
                        )
                        for bw in self.band_widths
                ]
        )

    def forward(self, x: torch.Tensor):
        # x = complex spectrogram (batch, in_chan, n_freq, n_time)

        batch, in_chan, _, n_time = x.shape

        z = torch.zeros(
            size=(batch, self.n_bands, n_time, self.emb_dim),
            device=x.device
        )

        xr = torch.view_as_real(x)  # batch, in_chan, n_freq, n_time, 2
        xr = torch.permute(
            xr,
            (0, 3, 1, 4, 2)
            )  # batch, n_time, in_chan, 2, n_freq
        batch, n_time, in_chan, reim, band_width = xr.shape
        for i, nfm in enumerate(self.norm_fc_modules):
            # print(f"bandsplit/band{i:02d}")
            fstart, fend = self.band_specs[i]
            xb = xr[..., fstart:fend]
            # (batch, n_time, in_chan, reim, band_width)
            xb = torch.reshape(xb, (batch, n_time, in_chan, -1))
            # (batch, n_time, in_chan, reim * band_width)
            # z.append(nfm(xb))  # (batch, n_time, emb_dim)
            z[:, i, :, :] = nfm(xb.contiguous())

        # z = torch.stack(z, dim=1)

        return z

In [21]:
# Bandit
from typing import Dict, List, Optional, Tuple
from torch import Tensor, nn
import torch
import torchaudio as ta



class BaseBandit(BaseEndToEndModule):

    def __init__(
        self,
        in_channel: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        fs: int = 44100,
    ):
        super().__init__()

        self.instantitate_spectral(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            pad_mode=pad_mode,
            onesided=onesided,
        )

        self.instantiate_bandsplit(
            in_channel=in_channel,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
            n_fft=n_fft,
            fs=fs,
        )

        self.instantiate_tf_modelling(
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

    def instantitate_spectral(
        self,
        n_fft: int = 2048,
        win_length: Optional[int] = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Optional[Dict] = None,
        power: Optional[int] = None,
        normalized: bool = True,
        center: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
    ):

        assert power is None

        window_fn = torch.__dict__[window_fn]

        self.stft = ta.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

        self.istft = ta.transforms.InverseSpectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad_mode=pad_mode,
            pad=0,
            window_fn=window_fn,
            wkwargs=wkwargs,
            normalized=normalized,
            center=center,
            onesided=onesided,
        )

    def instantiate_bandsplit(
        self,
        in_channel: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        emb_dim: int = 128,
        n_fft: int = 2048,
        fs: int = 44100,
    ):

        assert band_type == "musical"

        self.band_specs = MusicalBandsplitSpecification(
            nfft=n_fft, fs=fs, n_bands=n_bands
        )

        self.band_split = BandSplitModule(
            in_channel=in_channel,
            band_specs=self.band_specs.get_band_specs(),
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            emb_dim=emb_dim,
        )

    def instantiate_tf_modelling(
        self,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
    ):
        self.tf_model = SeqBandModellingModule(
            n_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
        )

    def mask(self, x, m):
        return x * m

    def forward(self, batch: InputType, mode: OperationMode = OperationMode.TRAIN):

        with torch.no_grad():
            x = self.stft(batch.mixture.audio)
            batch.mixture.spectrogram = x

            if "sources" in batch.keys():
                for stem in batch.sources.keys():
                    s = batch.sources[stem].audio
                    s = self.stft(s)
                    batch.sources[stem].spectrogram = s

        batch = self.separate(batch)

        return batch

    def encode(self, batch):
        x = batch.mixture.spectrogram
        length = batch.mixture.audio.shape[-1]

        z = self.band_split(x)  # (batch, emb_dim, n_band, n_time)
        q = self.tf_model(z)  # (batch, emb_dim, n_band, n_time)

        return x, q, length

    def separate(self, batch):
        raise NotImplementedError


class Bandit(BaseBandit):
    def __init__(
        self,
        in_channel: int,
        stems: List[str],
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict | None = None,
        complex_mask: bool = True,
        use_freq_weights: bool = True,
        n_fft: int = 2048,
        win_length: int | None = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Dict | None = None,
        power: int | None = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        fs: int = 44100,
    ):
        super().__init__(
            in_channel=in_channel,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            fs=fs,
        )

        self.instantiate_mask_estim(
            in_channel=in_channel,
            stems=stems,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
        )

    def instantiate_mask_estim(
        self,
        in_channel: int,
        stems: List[str],
        emb_dim: int,
        mlp_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
    ):
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        assert n_freq is not None

        self.mask_estim = nn.ModuleDict(
            {
                stem: OverlappingMaskEstimationModule(
                    band_specs=self.band_specs.get_band_specs(),
                    freq_weights=self.band_specs.get_freq_weights(),
                    n_freq=n_freq,
                    emb_dim=emb_dim,
                    mlp_dim=mlp_dim,
                    in_channel=in_channel,
                    hidden_activation=hidden_activation,
                    hidden_activation_kwargs=hidden_activation_kwargs,
                    complex_mask=complex_mask,
                    use_freq_weights=use_freq_weights,
                )
                for stem in stems
            }
        )

    def separate(self, batch):

        x, q, length = self.encode(batch)

        for stem, mem in self.mask_estim.items():
            m = mem(q)
            s = self.mask(x, m)
            s = torch.reshape(s, x.shape)
            batch.estimates[stem] = SimpleishNamespace(
                audio=self.istft(s, length), spectrogram=s
            )

        return batch


class BaseConditionedBandit(BaseBandit):
    query_encoder: nn.Module

    def __init__(
        self,
        in_channel: int,
        band_type: str = "musical",
        n_bands: int = 64,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict | None = None,
        complex_mask: bool = True,
        use_freq_weights: bool = True,
        n_fft: int = 2048,
        win_length: int | None = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Dict | None = None,
        power: int | None = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        fs: int = 44100,
    ):
        super().__init__(
            in_channel=in_channel,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            fs=fs,
        )

        self.instantiate_mask_estim(
            in_channel=in_channel,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            n_freq=n_fft // 2 + 1,
            use_freq_weights=use_freq_weights,
        )

    def instantiate_mask_estim(
        self,
        in_channel: int,
        emb_dim: int,
        mlp_dim: int,
        hidden_activation: str,
        hidden_activation_kwargs: Optional[Dict] = None,
        complex_mask: bool = True,
        n_freq: Optional[int] = None,
        use_freq_weights: bool = True,
    ):
        if hidden_activation_kwargs is None:
            hidden_activation_kwargs = {}

        assert n_freq is not None

        self.mask_estim = OverlappingMaskEstimationModule(
            band_specs=self.band_specs.get_band_specs(),
            freq_weights=self.band_specs.get_freq_weights(),
            n_freq=n_freq,
            emb_dim=emb_dim,
            mlp_dim=mlp_dim,
            in_channel=in_channel,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            use_freq_weights=use_freq_weights,
        )

    def separate(self, batch):

        x, q, length = self.encode(batch)

        q = self.adapt_query(q, batch)

        m = self.mask_estim(q)
        s = self.mask(x, m)
        s = torch.reshape(s, x.shape)
        batch.estimates["target"] = SimpleishNamespace(
            audio=self.istft(s, length), spectrogram=s
        )

        return batch

    def adapt_query(self, q, batch):
        raise NotImplementedError


class PasstFiLMConditionedBandit(BaseConditionedBandit):

    def __init__(
        self,
        in_channel: int,
        band_type: str = "musical",
        n_bands: int = 64,
        additive_film: bool = True,
        multiplicative_film: bool = True,
        film_depth: int = 2,
        require_no_overlap: bool = False,
        require_no_gap: bool = True,
        normalize_channel_independently: bool = False,
        treat_channel_as_feature: bool = True,
        n_sqm_modules: int = 12,
        emb_dim: int = 128,
        rnn_dim: int = 256,
        bidirectional: bool = True,
        rnn_type: str = "LSTM",
        mlp_dim: int = 512,
        hidden_activation: str = "Tanh",
        hidden_activation_kwargs: Dict | None = None,
        complex_mask: bool = True,
        use_freq_weights: bool = True,
        n_fft: int = 2048,
        win_length: int | None = 2048,
        hop_length: int = 512,
        window_fn: str = "hann_window",
        wkwargs: Dict | None = None,
        power: int | None = None,
        center: bool = True,
        normalized: bool = True,
        pad_mode: str = "constant",
        onesided: bool = True,
        fs: int = 44100,
        pretrain_encoder = None,
        freeze_encoder = False
    ):
        super().__init__(
            in_channel=in_channel,
            band_type=band_type,
            n_bands=n_bands,
            require_no_overlap=require_no_overlap,
            require_no_gap=require_no_gap,
            normalize_channel_independently=normalize_channel_independently,
            treat_channel_as_feature=treat_channel_as_feature,
            n_sqm_modules=n_sqm_modules,
            emb_dim=emb_dim,
            rnn_dim=rnn_dim,
            bidirectional=bidirectional,
            rnn_type=rnn_type,
            mlp_dim=mlp_dim,
            hidden_activation=hidden_activation,
            hidden_activation_kwargs=hidden_activation_kwargs,
            complex_mask=complex_mask,
            use_freq_weights=use_freq_weights,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            window_fn=window_fn,
            wkwargs=wkwargs,
            power=power,
            center=center,
            normalized=normalized,
            pad_mode=pad_mode,
            onesided=onesided,
            fs=fs,
        )

        self.query_encoder = Passt(
            original_fs=fs,
            passt_fs=32000,
        )
        
        self.film = FiLM(
            self.query_encoder.PASST_EMB_DIM,
            emb_dim,
            additive=additive_film,
            multiplicative=multiplicative_film,
            depth=film_depth,
        )
        
        if pretrain_encoder is not None:
            self.load_pretrained_encoder(pretrain_encoder)
            
            for p in self.band_split.parameters():
                p.requires_grad = not freeze_encoder
                
            for p in self.tf_model.parameters():
                p.requires_grad = not freeze_encoder
            
        
        
    def load_pretrained_encoder(self, path):
        
        state_dict = torch.load(path, map_location="cpu")["state_dict"]
        
        state_dict_ = {k.replace("model.", "") if k.startswith("model.") else k: v for k, v in state_dict.items()}
        
        state_dict = {}

        for k, v in state_dict_.items():
            if "mask_estim" in k:
                continue
            
            if "tf_seqband" in k:
                k = k.replace("tf_seqband", "tf_model.seqband")
            
            state_dict[k] = v
            
        
        res = self.load_state_dict(state_dict, strict=False)
        
        for k in res.unexpected_keys:
            if "mask_estim" in k:
                continue
            print(f"Unexpected key: {k}")
        
        for k in res.missing_keys:
            print(f"Missing key: {k}")
            for kw in ["band_split", "tf_model"]:
                if kw in k:
                    raise ValueError(f"Missing key: {k}")
                
            for kw in ["mask_estim", "query_encoder"]:
                if kw in k:
                    continue
            


    def adapt_query(self, q, batch):
        
        w = self.query_encoder(batch.query.audio)
        q = torch.permute(q, (0, 3, 1, 2)) # (batch, n_band, n_time, emb_dim) -> (batch, emb_dim, n_band, n_time)
        q = self.film(q, w)
        q = torch.permute(q, (0, 2, 3, 1)) # -> (batch, n_band, n_time, emb_dim)
        
        return q


    def optimized_forward(self, batch: InputType, mode: OperationMode = OperationMode.TRAIN):

        with torch.no_grad():
            x = self.stft(batch.mixture.audio)
            batch.mixture.spectrogram = x

            if "sources" in batch.keys():
                for stem in batch.sources.keys():
                    s = batch.sources[stem].audio
                    s = self.stft(s)
                    batch.sources[stem].spectrogram = s

        batch = self.optimized_separate(batch)

        return batch


    def optimized_separate(self, batch):

        x, q, length = self.encode(batch)

        for stem, query in batch.query.items():

            batch_ = SimpleishNamespace(**batch.__dict__)
            batch_.query = query

            q = self.adapt_query(q, batch_)

            m = self.mask_estim(q)
            s = self.mask(x, m)
            s = torch.reshape(s, x.shape)
            batch.estimates[stem] = SimpleishNamespace(
                audio=self.istft(s, length), spectrogram=s
            )

        return batch

#### **ebase.py**

In [23]:
# Inference Probably
import math
import os.path
from collections import defaultdict
from itertools import chain, combinations
from pprint import pprint
from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Type, TypedDict

import pytorch_lightning as pl
import torch
import torchaudio as ta
import torchmetrics as tm
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import LRScheduler
from tqdm import tqdm

from torch.nn import functional as F

# from core.types import BatchedInputOutput, OperationMode, RawInputType, SimpleishNamespace
# from core.types import (
#     InputType,
#     OutputType,
#     LossOutputType,
#     MetricOutputType,
#     ModelType,
#     OptimizerType,
#     SchedulerType,
#     MetricType,
#     LossType,
#     OptimizationBundle,
#     LossHandler,
#     MetricHandler,
#     AugmentationHandler,
#     InferenceHandler,
# )


class EndToEndLightningSystem(pl.LightningModule):
    def __init__(
        self,
        model: ModelType,
        loss_handler: LossHandler,
        metrics: MetricHandler,
        augmentation_handler: AugmentationHandler,
        inference_handler: InferenceHandler,
        optimization_bundle: OptimizationBundle,
        fast_run: bool = False,
        commitment_weight: float = 1.0,
        batch_size: Optional[int] = None,
        effective_batch_size: Optional[int] = None,
    ) -> None:
        super().__init__()

        self.model = model

        self.loss = loss_handler

        self.metrics = metrics
        self.optimization = optimization_bundle
        self.augmentation = augmentation_handler
        self.inference = inference_handler

        self.fast_run = fast_run

        self.model.fast_run = fast_run

        self.commitment_weight = commitment_weight

        self.batch_size = batch_size
        self.effective_batch_size = effective_batch_size if effective_batch_size is not None else batch_size
        self.accum_ratio = self.effective_batch_size // self.batch_size if self.effective_batch_size is not None else 1

        self.output_dir = None
        self.split_size = None

    def configure_optimizers(self) -> Any:
        optimizer = self.optimization.optimizer.cls(
            self.model.parameters(),
            **self.optimization.optimizer.kwargs
        )

        ret = {
            "optimizer": optimizer,
        }

        if self.optimization.scheduler is not None:
            scheduler = self.optimization.scheduler.cls(
                optimizer,
                **self.optimization.scheduler.kwargs
            )
            ret["lr_scheduler"] = scheduler

        return ret

    def compute_loss(
        self,
        batch: BatchedInputOutput,
        mode=OperationMode.TRAIN
    ) -> LossOutputType:
        loss_dict = self.loss(batch)
        return loss_dict

    # TODO: move to a metric handler
    def update_metrics(
        self,
        batch: BatchedInputOutput,
        mode: OperationMode = OperationMode.TRAIN,
    ) -> None:
        metrics: MetricType = self.metrics.get_mode(mode)

        for stem, metric in metrics.items():
            if stem not in batch.estimates.keys():
                continue
            metric.update(batch)

    # TODO: move to a metric handler
    def compute_metrics(self, mode: OperationMode) -> MetricOutputType:
        metrics: MetricType = self.metrics.get_mode(mode)

        metric_dict = {}

        for stem, metric in metrics.items():
            md = metric.compute()
            metric_dict.update({f"{stem}/{k}": v for k, v in md.items()})

        self.log_dict(metric_dict, prog_bar=True, logger=False)

        return metric_dict

    # TODO: move to a metric handler
    def reset_metrics(self, mode: OperationMode) -> None:
        metrics: MetricType = self.metrics.get_mode(mode)

        for _, metric in metrics.items():
            metric.reset()

    def forward(self, batch: RawInputType) -> Tuple[InputType, OutputType]:
        batch = self.model(batch)
        return batch

    def common_step(
        self, batch: RawInputType, mode: OperationMode, batch_idx: int = -1
    ) -> Tuple[OutputType, LossOutputType]:
        batch = BatchedInputOutput.from_dict(batch)
        batch = self.forward(batch)

        loss_dict = self.compute_loss(batch, mode=mode)

        if not self.fast_run:
            with torch.no_grad():
                self.update_metrics(batch, mode=mode)

        return loss_dict

    def training_step(self, batch: RawInputType, batch_idx: int) -> LossOutputType:
        # augmented_batch = self.augmentation(batch, mode=OperationMode.TRAIN)

        self.model.train()

        loss_dict = self.common_step(batch, mode=OperationMode.TRAIN, batch_idx=batch_idx)

        self.log_dict_with_prefix(loss_dict, prefix=OperationMode.TRAIN, prog_bar=True)

        return loss_dict

    def on_train_batch_end(
        self, outputs: OutputType, batch: RawInputType, batch_idx: int
    ) -> None:

        if self.fast_run:
            return

        if (batch_idx + 1) % self.accum_ratio == 0:
            metric_dict = self.compute_metrics(mode=OperationMode.TRAIN)
            self.log_dict_with_prefix(metric_dict, prefix=OperationMode.TRAIN)
            self.reset_metrics(mode=OperationMode.TRAIN)

    @torch.inference_mode()
    def validation_step(
        self, batch: RawInputType, batch_idx: int, dataloader_idx: int = 0
    ) -> Dict[str, Any]:

        self.model.eval()

        with torch.inference_mode():
            loss_dict = self.common_step(batch, mode=OperationMode.VAL)

        self.log_dict_with_prefix(loss_dict, prefix=OperationMode.VAL)

        return loss_dict

    def on_validation_epoch_start(self) -> None:
        self.reset_metrics(mode=OperationMode.VAL)

    def on_validation_epoch_end(self) -> None:
        if self.fast_run:
            return

        metric_dict = self.compute_metrics(mode=OperationMode.VAL)
        self.log_dict_with_prefix(
            metric_dict, OperationMode.VAL, prog_bar=True, add_dataloader_idx=False
        )
        self.reset_metrics(mode=OperationMode.VAL)

    
    def save_to_audio(self, batch: BatchedInputOutput, batch_idx: int) -> None:
        
        batch_size = batch["mixture"]["audio"].shape[0]
        
        assert batch_size == 1, "Batch size must be 1 for inference"
        
        metadata = batch.metadata
        
        song_id = metadata["mix"][0]
        stem = metadata["stem"][0]
        
        log_dir = os.path.join(self.logger.log_dir, "audio")
        
        os.makedirs(os.path.join(log_dir, song_id), exist_ok=True)
        
        audio = batch.estimates[stem]["audio"]
        
        audio = audio.squeeze(0).cpu().numpy()
        
        audio_path = os.path.join(log_dir, song_id, f"{stem}.wav")
        
        ta.save(audio_path, torch.tensor(audio), self.inference.fs)

    def save_vdbo_to_audio(self, batch: BatchedInputOutput, batch_idx: int) -> None:
        
        batch_size = batch["mixture"]["audio"].shape[0]
        
        assert batch_size == 1, "Batch size must be 1 for inference"
        
        metadata = batch.metadata
        
        song_id = metadata["song_id"][0]
        
        log_dir = os.path.join(self.logger.log_dir, "audio")
        
        os.makedirs(os.path.join(log_dir, song_id), exist_ok=True)
        
        for stem, audio in batch.estimates.items():
            audio = audio["audio"]
            audio = audio.squeeze(0).cpu().numpy()
            
            audio_path = os.path.join(log_dir, song_id, f"{stem}.wav")
            
            ta.save(audio_path, torch.tensor(audio), self.inference.fs)

    @torch.inference_mode()
    def chunked_inference(
        self, batch: RawInputType, batch_idx: int = -1, dataloader_idx: int = 0
    ) -> BatchedInputOutput:
        batch = BatchedInputOutput.from_dict(batch)
        
        audio = batch["mixture"]["audio"]
        
        b, c, n_samples = audio.shape
        
        assert b == 1

        fs = self.inference.fs

        chunk_size = int(self.inference.chunk_size_seconds * fs)
        hop_size = int(self.inference.hop_size_seconds * fs)
        
        batch_size = self.inference.batch_size
        
        overlap = chunk_size - hop_size
        
        scaler = chunk_size / (2 * hop_size)

        n_chunks = int(math.ceil(
            (n_samples + 4 * overlap - chunk_size) / hop_size
        )) + 1
        
        pad = (n_chunks - 1) * hop_size + chunk_size - n_samples

        # print(audio.shape)
        audio = F.pad(
            audio,
            pad=(2 * overlap, 2 * overlap + pad),
            mode="reflect"
        )
        padded_length = audio.shape[-1]
        audio = audio.reshape(c, 1, -1, 1)
        
        chunked_audio = F.unfold(
            audio,
            kernel_size=(chunk_size, 1), 
            stride=(hop_size, 1)
        ) # (c, chunk_size, n_chunk)

        # print(chunked_audio.shape)

        chunked_audio = chunked_audio.permute(2, 0, 1).reshape(-1, c, chunk_size)
        
        n_chunks = chunked_audio.shape[0]
        
        n_batch = math.ceil(n_chunks / batch_size)

        outputs = []
        
        for i in tqdm(range(n_batch)):
            start = i * batch_size
            end = min((i + 1) * batch_size, n_chunks)
            
            chunked_batch = SimpleishNamespace(
                mixture={
                    "audio": chunked_audio[start:end]
                },
                query=batch["query"],
                estimates=batch["estimates"]
            )
            
            output = self.forward(chunked_batch)
            outputs.append(output.estimates["target"]["audio"])

        output = torch.cat(outputs, dim=0) # (n_chunks, c, chunk_size)
        window = torch.hann_window(chunk_size, device=self.device).reshape(1, 1, chunk_size)
        output = output * window / scaler

        output = torch.permute(output, (1, 2, 0))

        output = F.fold(
            output,
            output_size=(padded_length, 1),
            kernel_size=(chunk_size, 1),
            stride=(hop_size, 1)
        ) # (c, 1, t, 1)

        output = output[None, :, 0, 2*overlap: n_samples + 2*overlap, 0]

        stem = batch.metadata["stem"][0]

        batch["estimates"][stem] = {
            "audio": output
        }

        return batch

    def chunked_vdbo_inference(
        self, batch: RawInputType, batch_idx: int = -1, dataloader_idx: int = 0
    ) -> BatchedInputOutput:
        batch = BatchedInputOutput.from_dict(batch)
        
        audio = batch["mixture"]["audio"]
        
        b, c, n_samples = audio.shape
        
        assert b == 1

        fs = self.inference.fs

        chunk_size = int(self.inference.chunk_size_seconds * fs)
        hop_size = int(self.inference.hop_size_seconds * fs)
        
        batch_size = self.inference.batch_size
        
        overlap = chunk_size - hop_size
        
        scaler = chunk_size / (2 * hop_size)

        n_chunks = int(math.ceil(
            (n_samples + 4 * overlap - chunk_size) / hop_size
        )) + 1
        
        pad = (n_chunks - 1) * hop_size + chunk_size - n_samples

        # print(audio.shape)
        audio = F.pad(
            audio,
            pad=(2 * overlap, 2 * overlap + pad),
            mode="reflect"
        )
        padded_length = audio.shape[-1]
        audio = audio.reshape(c, 1, -1, 1)
        
        chunked_audio = F.unfold(
            audio,
            kernel_size=(chunk_size, 1), 
            stride=(hop_size, 1)
        ) # (c, chunk_size, n_chunk)

        # print(chunked_audio.shape)

        chunked_audio = chunked_audio.permute(2, 0, 1).reshape(-1, c, chunk_size)
        
        n_chunks = chunked_audio.shape[0]
        
        n_batch = math.ceil(n_chunks / batch_size)

        outputs = defaultdict(list)
        
        for i in tqdm(range(n_batch)):
            start = i * batch_size
            end = min((i + 1) * batch_size, n_chunks)
            
            chunked_batch = SimpleishNamespace(
                mixture={
                    "audio": chunked_audio[start:end]
                },
                estimates=batch["estimates"]
            )
            
            output = self.forward(chunked_batch)
            
            for stem, estimate in output.estimates.items():
                outputs[stem].append(estimate["audio"])

        for stem, outputs_ in outputs.items():                

            output = torch.cat(outputs_, dim=0) # (n_chunks, c, chunk_size)
            window = torch.hann_window(chunk_size, device=self.device).reshape(1, 1, chunk_size)
            output = output * window / scaler

            output = torch.permute(output, (1, 2, 0))

            output = F.fold(
                output,
                output_size=(padded_length, 1),
                kernel_size=(chunk_size, 1),
                stride=(hop_size, 1)
            ) # (c, 1, t, 1)

            output = output[None, :, 0, 2*overlap: n_samples + 2*overlap, 0]

            batch["estimates"][stem] = {
                "audio": output
            }

        return batch


    def on_test_epoch_start(self) -> None:
        self.reset_metrics(mode=OperationMode.TEST)

    def test_step(
        self, batch: RawInputType, batch_idx: int, dataloader_idx: int = 0
    ) -> Any:

        self.model.eval()
        
        if "query" in batch.keys():
            batch = self.chunked_inference(batch, batch_idx, dataloader_idx)
        else:
            batch = self.chunked_vdbo_inference(batch, batch_idx, dataloader_idx)
        
        self.reset_metrics(mode=OperationMode.TEST)
        self.update_metrics(batch, mode=OperationMode.TEST)
        metrics = self.compute_metrics(mode=OperationMode.TEST)
        # metrics["song_id"] = batch.metadata["mix"][0]
        self.log_dict_with_prefix(metrics, OperationMode.TEST, 
                                  on_step=True, on_epoch=False, prog_bar=True)
        self.reset_metrics(mode=OperationMode.TEST)

        # pprint(metrics)

        return batch

    def on_test_epoch_end(self) -> None:
        self.reset_metrics(mode=OperationMode.TEST)

    def set_output_path(self, output_dir: str) -> None:
        self.output_dir = output_dir

    def predict_step(
        self, batch: RawInputType, batch_idx: int, dataloader_idx: int = 0
    ) -> Any:

        self.model.eval()
        
        if "query" in batch.keys():    
            batch = self.chunked_inference(batch, batch_idx, dataloader_idx)
            
            self.save_to_audio(batch, batch_idx)
        else:
            batch = self.chunked_vdbo_inference(batch, batch_idx, dataloader_idx)
            self.save_vdbo_to_audio(batch, batch_idx)

    def load_state_dict(
        self, state_dict: Mapping[str, Any], strict: bool = False
    ) -> Any:
        return super().load_state_dict(state_dict, strict=False)

    def log_dict_with_prefix(
        self,
        dict_: Dict[str, torch.Tensor],
        prefix: str,
        batch_size: Optional[int] = None,
        **kwargs: Any,
    ) -> None:

        
        self.log_dict(
            {f"{prefix}/{k}": v for k, v in dict_.items()},
            batch_size=batch_size,
            logger=True,
            sync_dist=True,
            **kwargs,
            # on_step=True,
            # on_epoch=False,
        )

        self.logger.save()

_____
# ***Data***

#### **Base Dataset handlers**

In [24]:
# Train, Val, Test, Predict Data loaders
import inspect
import os
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import numpy as np
import torch
import torchaudio as ta
from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils import data

from pytorch_lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader, IterableDataset


def from_datasets(
    train_dataset: Optional[Union[Dataset, Sequence[Dataset], Mapping[str, Dataset]]] = None,
    val_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
    test_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
    predict_dataset: Optional[Union[Dataset, Sequence[Dataset]]] = None,
    batch_size: int = 1,
    num_workers: int = 0,
    **datamodule_kwargs: Any,
) -> "LightningDataModule":

    def dataloader(ds: Dataset, shuffle: bool = False) -> DataLoader:
        shuffle &= not isinstance(ds, IterableDataset)
        return DataLoader(
            ds,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
            pin_memory=True,
            prefetch_factor=4,
            persistent_workers=True,
        )

    def train_dataloader() -> TRAIN_DATALOADERS:
        assert train_dataset

        if isinstance(train_dataset, Mapping):
            return {key: dataloader(ds, shuffle=True) for key, ds in train_dataset.items()}
        if isinstance(train_dataset, Sequence):
            return [dataloader(ds, shuffle=True) for ds in train_dataset]
        return dataloader(train_dataset, shuffle=True)

    def val_dataloader() -> EVAL_DATALOADERS:
        assert val_dataset

        if isinstance(val_dataset, Sequence):
            return [dataloader(ds) for ds in val_dataset]
        return dataloader(val_dataset)

    def test_dataloader() -> EVAL_DATALOADERS:
        assert test_dataset

        if isinstance(test_dataset, Sequence):
            return [dataloader(ds) for ds in test_dataset]
        return dataloader(test_dataset)

    def predict_dataloader() -> EVAL_DATALOADERS:
        assert predict_dataset

        if isinstance(predict_dataset, Sequence):
            return [dataloader(ds) for ds in predict_dataset]
        return dataloader(predict_dataset)

    candidate_kwargs = {"batch_size": batch_size, "num_workers": num_workers}
    accepted_params = inspect.signature(LightningDataModule.__init__).parameters
    accepts_kwargs = any(param.kind == param.VAR_KEYWORD for param in accepted_params.values())
    if accepts_kwargs:
        special_kwargs = candidate_kwargs
    else:
        accepted_param_names = set(accepted_params)
        accepted_param_names.discard("self")
        special_kwargs = {k: v for k, v in candidate_kwargs.items() if k in accepted_param_names}

    datamodule = LightningDataModule(**datamodule_kwargs, **special_kwargs)
    if train_dataset is not None:
        datamodule.train_dataloader = train_dataloader  # type: ignore[method-assign]
    if val_dataset is not None:
        datamodule.val_dataloader = val_dataloader  # type: ignore[method-assign]
    if test_dataset is not None:
        datamodule.test_dataloader = test_dataloader  # type: ignore[method-assign]
    if predict_dataset is not None:
        datamodule.predict_dataloader = predict_dataloader  # type: ignore[method-assign]

    return datamodule


class BaseSourceSeparationDataset(data.Dataset, ABC):
    def __init__(
        self,
        split: str,
        stems: List[str],
        files: List[str],
        data_path: str,
        fs: int,
        npy_memmap: bool,
        recompute_mixture: bool,
    ):
        if "mixture" not in stems:
            stems = ["mixture"] + stems

        self.split = split
        self.stems = stems
        self.stems_no_mixture = [s for s in stems if s != "mixture"]
        self.files = files
        self.data_path = data_path
        self.fs = fs
        self.npy_memmap = npy_memmap
        self.recompute_mixture = recompute_mixture

    @abstractmethod
    def get_stem(self, *, stem: str, identifier: Dict[str, Any]) -> torch.Tensor:
        raise NotImplementedError

    def _get_audio(self, stems, identifier: Dict[str, Any]):
        audio = {}
        for stem in stems:
            audio[stem] = self.get_stem(stem=stem, identifier=identifier)

        return audio

    def get_audio(self, identifier: Dict[str, Any]):
        if self.recompute_mixture:
            audio = self._get_audio(self.stems_no_mixture, identifier=identifier)
            audio["mixture"] = self.compute_mixture(audio)
            return audio
        else:
            return self._get_audio(self.stems, identifier=identifier)

    @abstractmethod
    def get_identifier(self, index: int) -> Dict[str, Any]:
        pass

    def compute_mixture(self, audio) -> torch.Tensor:
        return sum(audio[stem] for stem in audio if stem != "mixture")

### ***MoisesDB***

#### **Taxonomy**

In [25]:
# Taxonomy
taxonomy = {
    "vocals": [
        "lead male singer",
        "lead female singer",
        "human choir",
        "background vocals",
        "other (vocoder, beatboxing etc)",
    ],
    "bass": [
        "bass guitar",
        "bass synthesizer (moog etc)",
        "contrabass/double bass (bass of instrings)",
        "tuba (bass of brass)",
        "bassoon (bass of woodwind)",
    ],
    "drums": [
        "snare drum",
        "toms",
        "kick drum",
        "cymbals",
        "overheads",
        "full acoustic drumkit",
        "drum machine",
    ],
    "other": [
        "fx/processed sound, scratches, gun shots, explosions etc",
        "click track",
    ],
    "guitar": [
        "clean electric guitar",
        "distorted electric guitar",
        "lap steel guitar or slide guitar",
        "acoustic guitar",
    ],
    "other plucked": ["banjo, mandolin, ukulele, harp etc"],
    "percussion": [
        "a-tonal percussion (claps, shakers, congas, cowbell etc)",
        "pitched percussion (mallets, glockenspiel, ...)",
    ],
    "piano": [
        "grand piano",
        "electric piano (rhodes, wurlitzer, piano sound alike)",
    ],
    "other keys": [
        "organ, electric organ",
        "synth pad",
        "synth lead",
        "other sounds (hapischord, melotron etc)",
    ],
    "bowed strings": [
        "violin (solo)",
        "viola (solo)",
        "cello (solo)",
        "violin section",
        "viola section",
        "cello section",
        "string section",
        "other strings",
    ],
    "wind": [
        "brass (trumpet, trombone, french horn, brass etc)",
        "flutes (piccolo, bamboo flute, panpipes, flutes etc)",
        "reeds (saxophone, clarinets, oboe, english horn, bagpipe)",
        "other wind",
    ],
}


def clean_track_inst(inst):

    if "fx" in inst:
        inst = "fx"

    if "contrabass_double_bass" in inst:
        inst = "double_bass"

    if "banjo" in inst:
        return "other_plucked"

    if "(" in inst:
        inst = inst.split("(")[0]

    for s in [",", "-"]:
        if s in inst:
            inst = inst.replace(s, "")

    for s in ["/"]:
        if s in inst:
            inst = inst.replace(s, "_")

    if inst[-1] == "_":
        inst = inst[:-1]

    return inst


taxonomy = {k: [clean_track_inst(i.replace(" ", "_")) for i in v] for k, v in taxonomy.items()}

#### **Main Data Input**

In [30]:
# Dataset.py
import math
import os
import random
import warnings
from abc import ABC
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
from omegaconf import OmegaConf
import pandas as pd
import torch
from torch_audiomentations.utils.object_dict import ObjectDict
import torchaudio as ta
from torch.utils import data
from tqdm import tqdm

# from core.data.base import BaseSourceSeparationDataset
# from core.types import input_dict

# from . import clean_track_inst

from torch import Tensor, nn

DBFS_HOP_SIZE = int(0.125 * 44100)
DBFS_CHUNK_SIZE = int(1 * 44100)

INST_BY_OCCURRENCE = [
    "bass_guitar",
    "kick_drum",
    "snare_drum",
    "lead_male_singer",
    "distorted_electric_guitar",
    "clean_electric_guitar",
    "toms",
    "acoustic_guitar",
    "background_vocals",
    "hi_hat",
    "overheads",
    "atonal_percussion",
    "grand_piano",
    "cymbals",
    "lead_female_singer",
    "synth_lead",
    "bass_synthesizer",
    "synth_pad",
    "organ_electric_organ",
    "fx",
    "drum_machine",
    "string_section",
    "electric_piano",
    "full_acoustic_drumkit",
    "other_sounds",
    "pitched_percussion",
    "brass",
    "reeds",
    "contrabass_double_bass",
    "other_plucked",
    "other_strings",
    "other_wind",
    "cello",
    "other",
    "flutes",
    "viola_section",
    "viola",
    "cello_section",
]

FINE_LEVEL_INSTRUMENTS = {
    "lead_male_singer",
    "lead_female_singer",
    "human_choir",
    "background_vocals",
    "other_vocals",
    "bass_guitar",
    "bass_synthesizer",
    "contrabass_double_bass",
    "tuba",
    "bassoon",
    "snare_drum",
    "toms",
    "kick_drum",
    "cymbals",
    "overheads",
    "full_acoustic_drumkit",
    "drum_machine",
    "hihat",
    "fx",
    "click_track",
    "clean_electric_guitar",
    "distorted_electric_guitar",
    "lap_steel_guitar_or_slide_guitar",
    "acoustic_guitar",
    "other_plucked",
    "atonal_percussion",
    "pitched_percussion",
    "grand_piano",
    "electric_piano",
    "organ_electric_organ",
    "synth_pad",
    "synth_lead",
    "other_sounds",
    "violin",
    "viola",
    "cello",
    "violin_section",
    "viola_section",
    "cello_section",
    "string_section",
    "other_strings",
    "brass",
    "flutes",
    "reeds",
    "other_wind",
}

COARSE_LEVEL_INSTRUMENTS = {
    "vocals",
    "bass",
    "drums",
    "guitar",
    "other_plucked",
    "percussion",
    "piano",
    "other_keys",
    "bowed_strings",
    "wind",
    "other",
}

COARSE_TO_FINE = {
    "vocals": [
        "lead_male_singer",
        "lead_female_singer",
        "human_choir",
        "background_vocals",
        "other_vocals",
    ],
    "bass": [
        "bass_guitar",
        "bass_synthesizer",
        "contrabass_double_bass",
        "tuba",
        "bassoon",
    ],
    "drums": [
        "snare_drum",
        "toms",
        "kick_drum",
        "cymbals",
        "overheads",
        "full_acoustic_drumkit",
        "drum_machine",
        "hihat",
    ],
    "other": ["fx", "click_track"],
    "guitar": [
        "clean_electric_guitar",
        "distorted_electric_guitar",
        "lap_steel_guitar_or_slide_guitar",
        "acoustic_guitar",
    ],
    "other_plucked": ["other_plucked"],
    "percussion": ["atonal_percussion", "pitched_percussion"],
    "piano": ["grand_piano", "electric_piano"],
    "other_keys": ["organ_electric_organ", "synth_pad", "synth_lead", "other_sounds"],
    "bowed_strings": [
        "violin",
        "viola",
        "cello",
        "violin_section",
        "viola_section",
        "cello_section",
        "string_section",
        "other_strings",
    ],
    "wind": ["brass", "flutes", "reeds", "other_wind"],
}

COARSE_TO_FINE = {k: set(v) for k, v in COARSE_TO_FINE.items()}
FINE_TO_COARSE = {k: kk for kk, v in COARSE_TO_FINE.items() for k in v}

ALL_LEVEL_INSTRUMENTS = COARSE_LEVEL_INSTRUMENTS.union(FINE_LEVEL_INSTRUMENTS)


class MoisesDBBaseDataset(BaseSourceSeparationDataset, ABC):
    def __init__(
        self,
        split: str,
        data_path: str = "/home/kwatchar3/Documents/data/moisesdb",
        fs: int = 44100,
        return_stems: Union[bool, List[str]] = False,
        npy_memmap=True,
        recompute_mixture=False,
        train_folds=None,
        val_folds=None,
        test_folds=None,
        query_file="query",
    ) -> None:
        if test_folds is None:
            test_folds = [5]

        if val_folds is None:
            val_folds = [4]

        if train_folds is None:
            train_folds = [1, 2, 3]

        split_path = os.path.join(data_path, "splits.csv")
        splits = pd.read_csv(split_path)

        metadata_path = os.path.join(data_path, "stems.csv")
        metadata = pd.read_csv(metadata_path)

        if split == "train":
            folds = train_folds
        elif split == "val":
            folds = val_folds
        elif split == "test":
            folds = test_folds
        else:
            raise NameError

        files = splits[splits["split"].isin(folds)]["song_id"].tolist()
        metadata = metadata[metadata["song_id"].isin(files)]

        super().__init__(
            split=split,
            stems=["mixture"],
            files=files,
            data_path=data_path,
            fs=fs,
            npy_memmap=npy_memmap,
            recompute_mixture=recompute_mixture,
        )

        self.folds = folds

        self.metadata = metadata.rename(
            columns={k: k.replace(" ", "_") for k in metadata.columns}
        )

        self.song_to_stem = (
            metadata.set_index("song_id")
            .apply(lambda row: row[row == 1].index.tolist(), axis=1)
            .to_dict()
        )
        self.stem_to_song = (
            metadata.set_index("song_id")
            .transpose()
            .apply(lambda row: row[row == 1].index.tolist(), axis=1)
            .to_dict()
        )

        self.true_length = len(self.files)
        self.n_channels = 2

        self.audio_path = os.path.join(data_path, "npy2")

        self.return_stems = return_stems

        self.query_file = query_file

    def get_full_stem(self, *, stem: str, identifier) -> torch.Tensor:
        song_id = identifier["song_id"]
        path = os.path.join(self.data_path, "npy2", song_id)
        # noinspection PyUnresolvedReferences

        assert self.npy_memmap

        if os.path.exists(os.path.join(path, f"{stem}.npy")):
            audio = np.load(os.path.join(path, f"{stem}.npy"), mmap_mode="r")
        else:
            audio = None

        return audio

    def get_query_stem(self, *, stem: str, identifier) -> torch.Tensor:
        song_id = identifier["song_id"]
        path = os.path.join(self.data_path, "npyq", song_id)
        # noinspection PyUnresolvedReferences

        if self.npy_memmap:
            # print(self.npy_memmap)
            audio = np.load(
                os.path.join(path, f"{stem}.{self.query_file}.npy"), mmap_mode="r"
            )
        else:
            raise NotImplementedError

        return audio

    def get_stem(self, *, stem: str, identifier) -> torch.Tensor:
        audio = self.get_full_stem(stem=stem, identifier=identifier)
        return audio

    def get_identifier(self, index):
        return dict(song_id=self.files[index % self.true_length])

    def __getitem__(self, index: int):
        identifier = self.get_identifier(index)
        audio = self.get_audio(identifier)

        mixture = audio["mixture"].copy()

        if isinstance(self.return_stems, list):
            sources = {
                stem: audio.get(stem, np.zeros_like(mixture))
                for stem in self.return_stems
            }
        elif isinstance(self.return_stems, bool):
            if self.return_stems:
                sources = {
                    stem: audio[stem].copy()
                    for stem in self.song_to_stem[identifier["song_id"]]
                }
            else:
                sources = None
        else:
            raise ValueError

        return input_dict(
            mixture=mixture,
            sources=sources,
            metadata=identifier,
            modality="audio",
        )


class MoisesDBFullTrackDataset(MoisesDBBaseDataset):
    def __init__(
        self,
        data_root: str,
        split: str,
        return_stems: Union[bool, List[str]] = False,
        npy_memmap=True,
        recompute_mixture=False,
        query_file="query",
    ) -> None:
        super().__init__(
            split=split,
            data_path=data_root,
            return_stems=return_stems,
            npy_memmap=npy_memmap,
            recompute_mixture=recompute_mixture,
            query_file=query_file,
        )

    def __len__(self) -> int:
        return self.true_length


class MoisesDBVDBOFullTrackDataset(MoisesDBFullTrackDataset):
    def __init__(
        self, data_root: str, split: str, npy_memmap=True, recompute_mixture=False
    ) -> None:
        super().__init__(
            data_root=data_root,
            split=split,
            return_stems=["vocals", "bass", "drums", "vdbo_others"],
            npy_memmap=npy_memmap,
            recompute_mixture=recompute_mixture,
            query_file=None,
        )


import torch_audiomentations as audiomentations
from torch_audiomentations.utils.dsp import convert_decibels_to_amplitude_ratio


class SmartGain(audiomentations.Gain):
    def __init__(
        self, p=0.5, min_gain_in_db=-6, max_gain_in_db=6, dbfs_threshold=-45.0
    ):
        super().__init__(
            p=p, min_gain_in_db=min_gain_in_db, max_gain_in_db=max_gain_in_db
        )

        self.dbfs_threshold = dbfs_threshold

    def randomize_parameters(
        self,
        samples: Tensor = None,
        sample_rate: Optional[int] = None,
        targets: Optional[Tensor] = None,
        target_rate: Optional[int] = None,
    ):

        dbfs = 10 * torch.log10(torch.mean(torch.square(samples)) + 1e-6)

        if dbfs > self.dbfs_threshold:
            low = self.min_gain_in_db
        else:
            low = max(0.0, self.min_gain_in_db)

        distribution = torch.distributions.Uniform(
            low=torch.tensor(low, dtype=torch.float32, device=samples.device),
            high=torch.tensor(
                self.max_gain_in_db, dtype=torch.float32, device=samples.device
            ),
            validate_args=True,
        )
        selected_batch_size = samples.size(0)
        self.transform_parameters["gain_factors"] = (
            convert_decibels_to_amplitude_ratio(
                distribution.sample(sample_shape=(selected_batch_size,))
            )
            .unsqueeze(1)
            .unsqueeze(1)
        )


class Audiomentations(audiomentations.Compose):
    def __init__(self, augment="gssp", fs: int = 44100):

        if isinstance(augment, str):
            if augment == "gssp":
                augment = OmegaConf.create(
                    [
                        dict(
                            cls="Shift",
                            kwargs=dict(p=1.0, min_shift=-0.5, max_shift=0.5),
                        ),
                        dict(
                            cls="Gain",
                            kwargs=dict(p=1.0, min_gain_in_db=-6, max_gain_in_db=6),
                        ),
                        dict(cls="ShuffleChannels", kwargs=dict(p=0.5)),
                        dict(cls="PolarityInversion", kwargs=dict(p=0.5)),
                    ]
                )
            else:
                raise ValueError

        transforms = []

        for transform in augment:

            if transform.cls == "Gain":
                transforms.append(SmartGain(**transform.kwargs))
            else:
                transforms.append(
                    getattr(audiomentations, transform.cls)(**transform.kwargs)
                )

        super().__init__(transforms=transforms, shuffle=True)

        self.fs = fs

    def forward(
        self,
        samples: torch.Tensor = None,
    ) -> ObjectDict:
        return super().forward(samples, sample_rate=self.fs)


class MoisesDBVDBORandomChunkDataset(MoisesDBVDBOFullTrackDataset):
    def __init__(
        self,
        data_root: str,
        split: str,
        chunk_size_seconds: float = 4.0,
        fs: int = 44100,
        target_length: int = 8192,
        augment=None,
        npy_memmap=True,
        recompute_mixture=True,
        db_threshold=-24.0,
        db_step=-12.0,
    ) -> None:
        super().__init__(
            data_root=data_root,
            split=split,
            npy_memmap=npy_memmap,
            recompute_mixture=recompute_mixture,
        )

        self.chunk_size_seconds = chunk_size_seconds
        self.chunk_size_samples = int(chunk_size_seconds * fs)
        self.fs = fs

        self.target_length = target_length

        self.db_threshold = db_threshold
        self.db_step = db_step

        if augment is not None:
            assert self.recompute_mixture
            self.augment = Audiomentations(augment, fs)
        else:
            self.augment = None

    def __len__(self) -> int:
        return self.target_length

    def _chunk_audio(self, audio, start, end):
        audio = {k: v[..., start:end] for k, v in audio.items()}

        return audio

    def _get_start_end(self, audio, identifier):
        n_samples = audio.shape[-1]
        start = np.random.randint(0, n_samples - self.chunk_size_samples)
        end = start + self.chunk_size_samples

        return start, end

    def _get_audio(self, stems, identifier: Dict[str, Any]):
        audio = {}

        for stem in stems:
            audio[stem] = self.get_full_stem(stem=stem, identifier=identifier)

        for stem in stems:
            if audio[stem] is None:
                audio[stem] = np.zeros(
                    audio[
                        (
                            "mixture"
                            if "mixture" in stems
                            else [s for s in stems if audio[s] is not None][0]
                        )
                    ].shape,
                    dtype=np.float32,
                )

        start, end = self._get_start_end(audio[stems[0]], identifier)
        audio = self._chunk_audio(audio, start, end)

        if self.augment is not None:
            audio = {
                k: self.augment(torch.from_numpy(v[None, :, :]))[0, :, :].numpy()
                for k, v in audio.items()
            }

        return audio

    def get_audio(self, identifier: Dict[str, Any]):
        if self.recompute_mixture:
            audio = self._get_audio(
                ["vocals", "bass", "drums", "vdbo_others"], identifier=identifier
            )
            audio["mixture"] = self.compute_mixture(audio)
            return audio
        else:
            return self._get_audio(
                ["mixture", "vocals", "bass", "drums", "vdbo_others"],
                identifier=identifier,
            )

    def __getitem__(self, index: int):

        identifier = self.get_identifier(index)
        audio = self.get_audio(identifier=identifier)

        mixture = audio["mixture"].copy()

        sources = {
            stem: audio.get(stem, np.zeros_like(mixture)) for stem in self.return_stems
        }

        return input_dict(
            mixture=mixture,
            sources=sources,
            metadata=identifier,
            modality="audio",
        )


class MoisesDBVDBODeterministicChunkDataset(MoisesDBVDBORandomChunkDataset):
    def __init__(
        self,
        data_root: str,
        split: str,
        chunk_size_seconds: float = 4.0,
        hop_size_seconds: float = 8.0,
        fs: int = 44100,
        npy_memmap=True,
        recompute_mixture=False,
    ) -> None:
        super().__init__(
            data_root=data_root,
            split=split,
            chunk_size_seconds=chunk_size_seconds,
            npy_memmap=npy_memmap,
            recompute_mixture=recompute_mixture,
        )

        self.hop_size_seconds = hop_size_seconds
        self.hop_size_samples = int(hop_size_seconds * fs)

        self.index_to_identifiers = self._generate_index()
        self.length = len(self.index_to_identifiers)

    def __len__(self) -> int:
        return self.length

    def _generate_index(self):

        identifiers = []

        for song_id in self.files:
            audio = self.get_full_stem(stem="mixture", identifier=dict(song_id=song_id))
            n_samples = audio.shape[-1]
            n_chunks = math.floor(
                (n_samples - self.chunk_size_samples) / self.hop_size_samples
            )

            for i in range(n_chunks):
                chunk_start = i * self.hop_size_samples
                identifiers.append(dict(song_id=song_id, chunk_start=chunk_start))

        return identifiers

    def get_identifier(self, index):
        return self.index_to_identifiers[index]

    def _get_start_end(self, audio, identifier):

        start = identifier["chunk_start"]
        end = start + self.chunk_size_samples

        return start, end


def round_samples(seconds, fs, hop_size, downsample):
    n_frames = math.ceil(seconds * fs / hop_size) + 1
    n_frames_down = math.ceil(n_frames / downsample)
    n_frames = n_frames_down * downsample
    n_samples = (n_frames - 1) * hop_size

    return int(n_samples)


class MoisesDBRandomChunkRandomQueryDataset(MoisesDBFullTrackDataset):
    def __init__(
        self,
        data_root: str,
        split: str,
        target_length: int,
        chunk_size_seconds: float = 4.0,
        query_size_seconds: float = 1.0,
        round_query: bool = False,
        min_query_dbfs: float = -40.0,
        min_target_dbfs: float = -36.0,
        min_target_dbfs_step: float = -12.0,
        max_dbfs_tries: int = 10,
        top_k_instrument: int = 10,
        mixture_stem: str = "mixture",
        use_own_query: bool = True,
        npy_memmap=True,
        allowed_stems=None,
        query_file="query",
        augment=None,
    ) -> None:

        super().__init__(
            data_root=data_root,
            split=split,
            npy_memmap=npy_memmap,
            recompute_mixture=augment is not None,
            query_file=query_file,
        )

        self.mixture_stem = mixture_stem

        self.chunk_size_seconds = chunk_size_seconds
        self.chunk_size_samples = round_samples(
            self.chunk_size_seconds, self.fs, 512, 2**6
        )

        self.query_size_seconds = query_size_seconds

        if round_query:
            self.query_size_samples = round_samples(
                self.query_size_seconds, self.fs, 512, 2**6
            )
        else:
            self.query_size_samples = int(self.query_size_seconds * self.fs)

        self.target_length = target_length

        self.min_query_dbfs = min_query_dbfs

        if min_target_dbfs is None:
            min_target_dbfs = -np.inf
            min_target_dbfs_step = None
            max_dbfs_tries = 1

        self.min_target_dbfs = min_target_dbfs
        self.min_target_dbfs_step = min_target_dbfs_step
        self.max_dbfs_tries = max_dbfs_tries

        self.top_k_instrument = top_k_instrument

        if allowed_stems is None:
            allowed_stems = INST_BY_OCCURRENCE[: self.top_k_instrument]
        else:
            self.top_k_instrument = None

        self.allowed_stems = allowed_stems

        self.song_to_all_stems = {
            k: list(set(v) & set(ALL_LEVEL_INSTRUMENTS))
            for k, v in self.song_to_stem.items()
        }

        self.song_to_stem = {
            k: list(set(v) & set(self.allowed_stems))
            for k, v in self.song_to_stem.items()
        }
        self.stem_to_song = {
            k: list(set(v) & set(self.files)) for k, v in self.stem_to_song.items()
        }

        self.queriable_songs = [k for k, v in self.song_to_stem.items() if len(v) > 0]

        self.use_own_query = use_own_query

        if self.use_own_query:
            self.files = [k for k in self.files if len(self.song_to_stem[k]) > 0]
            self.true_length = len(self.files)

        if augment is not None:
            assert self.recompute_mixture
            self.augment = Audiomentations(augment, self.fs)
        else:
            self.augment = None

    def __len__(self) -> int:
        return self.target_length

    def _chunk_audio(self, audio, start, end):
        audio = {k: v[..., start:end] for k, v in audio.items()}

        return audio

    def _get_start_end(self, audio):
        n_samples = audio.shape[-1]
        start = np.random.randint(0, n_samples - self.chunk_size_samples)
        end = start + self.chunk_size_samples

        return start, end

    def _target_dbfs(self, audio):
        return 10.0 * np.log10(np.mean(np.square(np.abs(audio))) + 1e-6)

    def _chunk_and_check_dbfs_threshold(self, audio_, target_stem, threshold):

        target_dict = {target_stem: audio_[target_stem]}

        for _ in range(self.max_dbfs_tries):
            start, end = self._get_start_end(audio_[target_stem])
            taudio = self._chunk_audio(target_dict, start, end)

            dbfs = self._target_dbfs(taudio[target_stem])
            if dbfs > threshold:
                return self._chunk_audio(audio_, start, end)

        return None

    def _chunk_and_check_dbfs(self, audio_, target_stem):
        out = self._chunk_and_check_dbfs_threshold(
            audio_, target_stem, self.min_target_dbfs
        )

        if out is not None:
            return out

        out = self._chunk_and_check_dbfs_threshold(
            audio_, target_stem, self.min_target_dbfs + self.min_target_dbfs_step
        )

        if out is not None:
            return out

        start, end = self._get_start_end(audio_[target_stem])
        audio = self._chunk_audio(audio_, start, end)

        return audio

    def _augment(self, audio, target_stem):
        stack_audio = np.stack([v for v in audio.values()], axis=0)
        aug_audio = self.augment(torch.from_numpy(stack_audio)).numpy()
        mixture = np.sum(aug_audio, axis=0)

        out = {
            "mixture": mixture,
        }

        if target_stem is not None:
            target_idx = list(audio.keys()).index(target_stem)
            out[target_stem] = aug_audio[target_idx]

        return out

    def _choose_stems_for_augment(self, identifier, target_stem):
        stems_for_song = set(self.song_to_all_stems[identifier["song_id"]])

        stems_ = []
        coarse_level_accounted = set()

        is_none_target = target_stem is None
        is_coarse_target = target_stem in COARSE_LEVEL_INSTRUMENTS

        if is_coarse_target or is_none_target:
            coarse_target = target_stem
        else:
            coarse_target = FINE_TO_COARSE[target_stem]

        fine_level_stems = stems_for_song & FINE_LEVEL_INSTRUMENTS
        coarse_level_stems = stems_for_song & COARSE_LEVEL_INSTRUMENTS

        for s in fine_level_stems:
            coarse_level = FINE_TO_COARSE[s]

            if is_coarse_target and coarse_level == coarse_target:
                continue
            else:
                stems_.append(s)

            coarse_level_accounted.add(coarse_level)

        stems_ += list(coarse_level_stems - coarse_level_accounted)

        if target_stem is not None:
            assert target_stem in stems_, f"stems: {stems_}, target stem: {target_stem}"

            if len(stems_for_song) > 1:
                assert (
                    len(stems_) > 1
                ), f"stems: {stems_}, stems in song: {stems_for_song},\n target stem: {target_stem}"

        assert "mixture" not in stems_

        return stems_

    def _get_audio(
        self, stems, identifier: Dict[str, Any], check_dbfs=True, no_target=False
    ):

        target_stem = stems[0] if not no_target else None

        if self.augment is not None:
            stems_ = self._choose_stems_for_augment(identifier, target_stem)
        else:
            stems_ = stems

        audio = {}
        for stem in stems_:
            audio[stem] = self.get_full_stem(stem=stem, identifier=identifier)

        audio_ = {k: v.copy() for k, v in audio.items()}

        if check_dbfs:
            assert target_stem is not None
            audio = self._chunk_and_check_dbfs(audio_, target_stem)
        else:
            first_key = list(audio_.keys())[0]
            start, end = self._get_start_end(audio_[first_key])
            audio = self._chunk_audio(audio_, start, end)

        if self.augment is not None:
            assert "mixture" not in audio
            audio = self._augment(audio, target_stem)
            assert "mixture" in audio

        return audio

    def __getitem__(self, index: int):

        mix_identifier = self.get_identifier(index)
        mix_stems = self.song_to_stem[mix_identifier["song_id"]]

        if self.use_own_query:
            query_id = mix_identifier["song_id"]
            query_identifier = dict(song_id=query_id)
            possible_stem = mix_stems

            assert len(possible_stem) > 0

            zero_target = False
        else:
            query_id = random.choice(self.queriable_songs)
            query_identifier = dict(song_id=query_id)
            query_stems = self.song_to_stem[query_id]
            possible_stem = list(set(mix_stems) & set(query_stems))

            if len(possible_stem) == 0:
                possible_stem = query_stems
                zero_target = True
                # print(f"Mix {mix_identifier['song_id']} and query {query_id} have no common stems.")
                # return self.__getitem__(index + 1)
            else:
                zero_target = False

        assert (
            len(possible_stem) > 0
        ), f"{mix_identifier['song_id']} and {query_id} have no common stems. zero target is {zero_target}"
        stem = random.choice(possible_stem)

        if zero_target:
            audio = self._get_audio(
                [self.mixture_stem],
                identifier=mix_identifier,
                check_dbfs=False,
                no_target=True,
            )
            mixture = audio[self.mixture_stem].copy()
            sources = {"target": np.zeros_like(mixture)}
        else:
            audio = self._get_audio(
                [stem, self.mixture_stem], identifier=mix_identifier, check_dbfs=True
            )
            mixture = audio[self.mixture_stem].copy()
            sources = {"target": audio[stem].copy()}

        query = self.get_query_stem(stem=stem, identifier=query_identifier)
        query = query.copy()

        assert mixture.shape[-1] == self.chunk_size_samples
        assert query.shape[-1] == self.query_size_samples
        assert sources["target"].shape[-1] == self.chunk_size_samples

        return input_dict(
            mixture=mixture,
            sources=sources,
            query=query,
            metadata={
                "mix": mix_identifier,
                "query": query_identifier,
                "stem": stem,
            },
            modality="audio",
        )


class MoisesDBRandomChunkBalancedRandomQueryDataset(
    MoisesDBRandomChunkRandomQueryDataset
):
    def __init__(
        self,
        data_root: str,
        split: str,
        target_length: int,
        chunk_size_seconds: float = 4,
        query_size_seconds: float = 1,
        round_query: bool = False,
        min_query_dbfs: float = -40.0,
        min_target_dbfs: float = -36.0,
        min_target_dbfs_step: float = -12.0,
        max_dbfs_tries: int = 10,
        top_k_instrument: int = 10,
        mixture_stem: str = "mixture",
        use_own_query: bool = True,
        npy_memmap=True,
        allowed_stems=None,
        query_file="query",
        augment=None,
    ) -> None:
        super().__init__(
            data_root,
            split,
            target_length,
            chunk_size_seconds,
            query_size_seconds,
            round_query,
            min_query_dbfs,
            min_target_dbfs,
            min_target_dbfs_step,
            max_dbfs_tries,
            top_k_instrument,
            mixture_stem,
            use_own_query,
            npy_memmap,
            allowed_stems,
            query_file,
            augment,
        )
        
        self.stem_to_n_songs = {k: len(v) for k, v in self.stem_to_song.items()}
        self.trainable_stems = [k for k, v in self.stem_to_n_songs.items() if v > 1]
        self.n_allowed_stems = len(self.allowed_stems)
        
        
        
    def __getitem__(self, index: int):
        
        stem = self.allowed_stems[index % self.n_allowed_stems]
        song_ids_with_stem = self.stem_to_song[stem]
        
        song_id = song_ids_with_stem[index % self.stem_to_n_songs[stem]]
        
        mix_identifier = dict(song_id=song_id)
        
        audio = self._get_audio([stem, self.mixture_stem], identifier=mix_identifier, check_dbfs=True)
        mixture = audio[self.mixture_stem].copy()
        
        if self.use_own_query:
            query_id = song_id
            query_identifier = dict(song_id=query_id)
        else:
            query_id = random.choice(song_ids_with_stem)
            query_identifier = dict(song_id=query_id)
            
        query = self.get_query_stem(stem=stem, identifier=query_identifier)
        query = query.copy()
        
        sources = {"target": audio[stem].copy()}
        
        return input_dict(
            mixture=mixture,
            sources=sources,
            query=query,
            metadata={
                "mix": mix_identifier,
                "query": query_identifier,
                "stem": stem,
            },
            modality="audio",
        )
        
        


class MoisesDBDeterministicChunkDeterministicQueryDataset(
    MoisesDBRandomChunkRandomQueryDataset
):
    def __init__(
        self,
        data_root: str,
        split: str,
        chunk_size_seconds: float = 4.0,
        hop_size_seconds: float = 8.0,
        query_size_seconds: float = 1.0,
        min_query_dbfs: float = -40.0,
        top_k_instrument: int = 10,
        n_queries_per_chunk: int = 1,
        mixture_stem: str = "mixture",
        use_own_query: bool = True,
        npy_memmap=True,
        allowed_stems: List[str] = None,
        query_file="query",
    ) -> None:

        super().__init__(
            data_root=data_root,
            split=split,
            target_length=None,
            chunk_size_seconds=chunk_size_seconds,
            query_size_seconds=query_size_seconds,
            min_query_dbfs=min_query_dbfs,
            top_k_instrument=top_k_instrument,
            mixture_stem=mixture_stem,
            use_own_query=use_own_query,
            npy_memmap=npy_memmap,
            allowed_stems=allowed_stems,
            query_file=query_file,
        )

        if hop_size_seconds is None:
            hop_size_seconds = chunk_size_seconds

        self.chunk_hop_size_seconds = hop_size_seconds

        self.chunk_hop_size_samples = int(hop_size_seconds * self.fs)

        self.n_queries_per_chunk = n_queries_per_chunk

        self._overwrite = False

        self.query_tuples = self.find_query_tuples_or_generate()
        self.n_chunks = len(self.query_tuples)

    def __len__(self) -> int:
        return self.n_chunks

    def _get_audio(self, stems, identifier: Dict[str, Any]):
        audio = {}

        for stem in stems:
            audio[stem] = self.get_full_stem(stem=stem, identifier=identifier)

        start = identifier["chunk_start"]
        end = start + self.chunk_size_samples
        audio = self._chunk_audio(audio, start, end)

        return audio

    def find_query_tuples_or_generate(self):
        query_path = os.path.join(self.data_path, "queries")
        val_folds = "-".join(map(str, self.folds))

        path_so_far = os.path.join(query_path, val_folds)

        if not os.path.exists(path_so_far):
            return self.generate_index()

        chunk_specs = f"chunk{self.chunk_size_samples}-hop{self.chunk_hop_size_samples}"
        path_so_far = os.path.join(path_so_far, chunk_specs)

        if not os.path.exists(path_so_far):
            return self.generate_index()

        query_specs = f"query{self.query_size_samples}-n{self.n_queries_per_chunk}"
        path_so_far = os.path.join(path_so_far, query_specs)

        if not os.path.exists(path_so_far):
            return self.generate_index()

        if self.top_k_instrument is not None:
            path_so_far = os.path.join(
                path_so_far, f"queries-top{self.top_k_instrument}.csv"
            )
        else:
            if len(self.allowed_stems) > 5:
                allowed_stems = (
                    str(len(self.allowed_stems))
                    + "stems:"
                    + ":".join([k[0] for k in self.allowed_stems if k != "mixture"])
                )
            else:
                allowed_stems = ":".join(self.allowed_stems)

            path_so_far = os.path.join(path_so_far, f"queries-{allowed_stems}.csv")

        if not os.path.exists(path_so_far):
            return self.generate_index()

        print(f"Loading query tuples from {path_so_far}")

        return pd.read_csv(path_so_far)

    def _get_index_path(self):
        query_root = os.path.join(self.data_path, "queries")
        val_folds = "-".join(map(str, self.folds))
        chunk_specs = f"chunk{self.chunk_size_samples}-hop{self.chunk_hop_size_samples}"
        query_specs = f"query{self.query_size_samples}-n{self.n_queries_per_chunk}"
        query_dir = os.path.join(query_root, val_folds, chunk_specs, query_specs)

        if self.top_k_instrument is not None:
            query_path = os.path.join(
                query_dir, f"queries-top{self.top_k_instrument}.csv"
            )
        else:
            if len(self.allowed_stems) > 5:
                allowed_stems = (
                    str(len(self.allowed_stems))
                    + "stems:"
                    + ":".join([k[0] for k in self.allowed_stems if k != "mixture"])
                )
            else:
                allowed_stems = ":".join(self.allowed_stems)
            query_path = os.path.join(query_dir, f"queries-{allowed_stems}.csv")

        if not self._overwrite:
            assert not os.path.exists(
                query_path
            ), f"Query path {query_path} already exists."

        os.makedirs(query_dir, exist_ok=True)

        return query_path

    def generate_index(self):

        query_path = self._get_index_path()

        durations = pd.read_csv(os.path.join(self.data_path, "durations.csv"))
        durations = (
            durations[["song_id", "duration"]]
            .set_index("song_id")["duration"]
            .to_dict()
        )

        tuples = []

        stems_without_queries = defaultdict(list)

        for i, song_id in tqdm(enumerate(self.files), total=len(self.files)):
            song_duration = durations[song_id]
            mix_stems = self.song_to_stem[song_id]

            n_mix_chunks = math.floor(
                (song_duration - self.chunk_size_seconds) / self.chunk_hop_size_seconds
            )

            for stem in mix_stems:
                possible_queries = self.stem_to_song[stem]
                if song_id in possible_queries:
                    possible_queries.remove(song_id)

                if len(possible_queries) == 0:
                    stems_without_queries[song_id].append(stem)
                    continue

                for k in tqdm(range(n_mix_chunks), desc=f"song{i + 1}/{stem}"):
                    mix_chunk_start = int(k * self.chunk_hop_size_samples)

                    for j in range(self.n_queries_per_chunk):
                        query = random.choice(possible_queries)

                        tuples.append(
                            dict(
                                mix=song_id,
                                query=query,
                                stem=stem,
                                mix_chunk_start=mix_chunk_start,
                            )
                        )

        if len(stems_without_queries) > 0:
            print("Stems without queries:")
            for song_id, stems in stems_without_queries.items():
                print(f"{song_id}: {stems}")

        tuples = pd.DataFrame(tuples)

        print(
            f"Generating query tuples for {self.split} set with {len(tuples)} tuples."
        )
        print(f"Saving query tuples to {query_path}")

        tuples.to_csv(query_path, index=False)

        return tuples

    def index_to_identifiers(self, index: int) -> Tuple[str, str, str, int]:

        row = self.query_tuples.iloc[index]
        mix_id = row["mix"]

        if self.use_own_query:
            query_id = mix_id
        else:
            query_id = row["query"]

        stem = row["stem"]
        mix_chunk_start = row["mix_chunk_start"]

        return mix_id, query_id, stem, mix_chunk_start

    def __getitem__(self, index: int):

        mix_id, query_id, stem, mix_chunk_start = self.index_to_identifiers(index)

        mix_identifier = dict(song_id=mix_id, chunk_start=mix_chunk_start)
        query_identifier = dict(song_id=query_id)

        audio = self._get_audio([stem, self.mixture_stem], identifier=mix_identifier)
        query = self.get_query_stem(stem=stem, identifier=query_identifier)

        mixture = audio[self.mixture_stem].copy()
        sources = {"target": audio[stem].copy()}
        query = query.copy()

        assert mixture.shape[-1] == self.chunk_size_samples
        # print(query.shape[-1], self.query_size_samples)
        assert query.shape[-1] == self.query_size_samples
        assert sources["target"].shape[-1] == self.chunk_size_samples

        return input_dict(
            mixture=mixture,
            sources=sources,
            query=query,
            metadata={
                "mix": mix_identifier,
                "query": query_identifier,
                "stem": stem,
            },
            modality="audio",
        )


class MoisesDBFullTrackTestQueryDataset(MoisesDBFullTrackDataset):
    def __init__(
        self,
        data_root: str,
        split: str = "test",
        top_k_instrument: int = 10,
        mixture_stem: str = "mixture",
        use_own_query: bool = True,
        npy_memmap=True,
        allowed_stems: List[str] = None,
        query_file="query-10s",
    ) -> None:
        super().__init__(
            data_root=data_root,
            split=split,
            npy_memmap=npy_memmap,
            recompute_mixture=False,
            query_file=query_file,
        )

        self.use_own_query = use_own_query

        self.allowed_stems = allowed_stems

        test_indices = pd.read_csv(os.path.join(data_root, "test_indices.csv"))

        test_indices = test_indices[test_indices.stem.isin(self.allowed_stems)]

        self.test_indices = test_indices

        self.length = len(self.test_indices)

    def __len__(self) -> int:
        return self.length

    def index_to_identifiers(self, index: int) -> Tuple[str, str, str]:

        row = self.test_indices.iloc[index]
        mix_id = row["song_id"]
        if self.use_own_query:
            query_id = mix_id
        else:
            query_id = row["query_id"]
        stem = row["stem"]

        return mix_id, query_id, stem

    def _get_audio(self, stems, identifier: Dict[str, Any]):
        audio = {}

        for stem in stems:
            audio[stem] = self.get_full_stem(stem=stem, identifier=identifier)

        return audio

    def __getitem__(self, index: int):

        mix_id, query_id, stem = self.index_to_identifiers(index)

        mix_identifier = dict(song_id=mix_id)

        query_identifier = dict(song_id=query_id)

        audio = self._get_audio([stem, "mixture"], identifier=mix_identifier)
        query = self.get_query_stem(stem=stem, identifier=query_identifier)

        mixture = audio["mixture"].copy()
        sources = {stem: audio[stem].copy()}
        query = query.copy()

        return input_dict(
            mixture=mixture,
            sources=sources,
            query=query,
            metadata={
                "mix": mix_identifier["song_id"],
                "query": query_identifier["song_id"],
                "stem": stem,
            },
            modality="audio",
        )


if __name__ == "__main__":

    print("Beginning")

    config = "config\data\setup-c\moisesdb-everything-query-d-aug.yml"

    config = OmegaConf.load(config)

    print("Loaded config")

    dataset = MoisesDBRandomChunkRandomQueryDataset(
        data_root=config.data_root, split="train", **config.train_kwargs
    )

    print("Loaded dataset")

    for item in tqdm(dataset, total=len(dataset)):
        target_audio = item["sources"]["target"]["audio"]
        mixture = item["mixture"]["audio"]

        if target_audio is None:
            raise ValueError
        else:
            tdb = 10.0 * torch.log10(torch.mean(torch.square(target_audio)) + 1e-6)
            mdb = 10.0 * torch.log10(torch.mean(torch.square(mixture)) + 1e-6)
            print(f"Target db: {tdb}, Mixture db: {mdb}")

  config = "config\data\setup-c\moisesdb-everything-query-d-aug.yml"
  config = "config\data\setup-c\moisesdb-everything-query-d-aug.yml"


Beginning
Loaded config


InterpolationResolutionError: KeyError raised while resolving interpolation: "Environment variable 'DATA_ROOT' not found"
    full_key: data_root
    object_type=dict

In [31]:
# Moises data modules
import os.path
from typing import Mapping, Optional

import pytorch_lightning as pl

# from core.data.base import from_datasets
# from core.data.moisesdb.dataset import MoisesDBRandomChunkBalancedRandomQueryDataset, MoisesDBRandomChunkRandomQueryDataset, \
#     MoisesDBDeterministicChunkDeterministicQueryDataset, \
#     MoisesDBFullTrackDataset, MoisesDBVDBODeterministicChunkDataset, \
#     MoisesDBVDBOFullTrackDataset, MoisesDBVDBORandomChunkDataset, \
#     MoisesDBFullTrackTestQueryDataset
    
def MoisesDataModule(
    data_root: str,
    batch_size: int,
    num_workers: int = 8,
    train_kwargs: Optional[Mapping] = None,
    val_kwargs: Optional[Mapping] = None,
    test_kwargs: Optional[Mapping] = None,
    datamodule_kwargs: Optional[Mapping] = None,
) -> pl.LightningDataModule:
    if train_kwargs is None:
        train_kwargs = {}

    if val_kwargs is None:
        val_kwargs = {}

    if test_kwargs is None:
        test_kwargs = {}

    if datamodule_kwargs is None:
        datamodule_kwargs = {}

    train_dataset = MoisesDBRandomChunkRandomQueryDataset(
        data_root=data_root, split="train", **train_kwargs
    )

    val_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
        data_root=data_root, split="val", **val_kwargs
    )

    test_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
        data_root=data_root, split="test", **test_kwargs
    )

    datamodule = from_datasets(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs
    )

    datamodule.predict_dataloader = (  # type: ignore[method-assign]
        datamodule.test_dataloader
    )

    return datamodule

def MoisesBalancedTrainDataModule(
    data_root: str,
    batch_size: int,
    num_workers: int = 8,
    train_kwargs: Optional[Mapping] = None,
    val_kwargs: Optional[Mapping] = None,
    test_kwargs: Optional[Mapping] = None,
    datamodule_kwargs: Optional[Mapping] = None,
) -> pl.LightningDataModule:
    if train_kwargs is None:
        train_kwargs = {}

    if val_kwargs is None:
        val_kwargs = {}

    if test_kwargs is None:
        test_kwargs = {}

    if datamodule_kwargs is None:
        datamodule_kwargs = {}

    train_dataset = MoisesDBRandomChunkBalancedRandomQueryDataset(
        data_root=data_root, split="train", **train_kwargs
    )

    val_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
        data_root=data_root, split="val", **val_kwargs
    )

    test_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
        data_root=data_root, split="test", **test_kwargs
    )

    datamodule = from_datasets(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs
    )

    datamodule.predict_dataloader = (  # type: ignore[method-assign]
        datamodule.test_dataloader
    )

    return datamodule
    

def MoisesValidationDataModule(
    data_root: str,
    batch_size: int,
    num_workers: int = 8,
    val_kwargs: Optional[Mapping] = None,
    datamodule_kwargs: Optional[Mapping] = None,
    **kwargs
) -> pl.LightningDataModule:
    if val_kwargs is None:
        val_kwargs = {}

    if datamodule_kwargs is None:
        datamodule_kwargs = {}
        
    allowed_stems = val_kwargs.get("allowed_stems", None)
    
    assert allowed_stems is not None, "allowed_stems must be provided"
    
    val_datasets = []
    
    for allowed_stem in allowed_stems:
        kwargs = val_kwargs.copy()
        kwargs["allowed_stems"] = [allowed_stem]
        val_dataset = MoisesDBDeterministicChunkDeterministicQueryDataset(
            data_root=data_root, split="val", 
            **kwargs
        )
        
        val_datasets.append(val_dataset)

    datamodule = from_datasets(
        val_dataset=val_datasets,
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs
    )

    datamodule.predict_dataloader = (  # type: ignore[method-assign]
        datamodule.val_dataloader
    )

    return datamodule

def MoisesTestDataModule(
    data_root: str,
    batch_size: int = 1,
    num_workers: int = 8,
    test_kwargs: Optional[Mapping] = None,
    datamodule_kwargs: Optional[Mapping] = None,
    **kwargs
) -> pl.LightningDataModule:
    if test_kwargs is None:
        test_kwargs = {}

    if datamodule_kwargs is None:
        datamodule_kwargs = {}
        
    allowed_stems = test_kwargs.get("allowed_stems", None)
    
    assert allowed_stems is not None, "allowed_stems must be provided"

    test_dataset = MoisesDBFullTrackTestQueryDataset(
        data_root=data_root, split="test",
        **test_kwargs
    )

    datamodule = from_datasets(
        test_dataset=test_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs
    )

    datamodule.predict_dataloader = (  # type: ignore[method-assign]
        datamodule.test_dataloader
    )

    return datamodule


def MoisesVDBODataModule(
    data_root: str,
    batch_size: int,
    num_workers: int = 8,
    train_kwargs: Optional[Mapping] = None,
    val_kwargs: Optional[Mapping] = None,
    test_kwargs: Optional[Mapping] = None,
    datamodule_kwargs: Optional[Mapping] = None,
):
    
    
    if train_kwargs is None:
        train_kwargs = {}

    if val_kwargs is None:
        val_kwargs = {}

    if test_kwargs is None:
        test_kwargs = {}

    if datamodule_kwargs is None:
        datamodule_kwargs = {}
        
    train_dataset = MoisesDBVDBORandomChunkDataset(
        data_root=data_root, split="train", **train_kwargs
    )
    
    val_dataset = MoisesDBVDBODeterministicChunkDataset(
        data_root=data_root, split="val", **val_kwargs
    )
    
    test_dataset = MoisesDBVDBOFullTrackDataset(
        data_root=data_root, split="test", **test_kwargs
    )
    
    predict_dataset = test_dataset
    
    datamodule = from_datasets(
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        test_dataset=test_dataset,
        predict_dataset=predict_dataset,
        batch_size=batch_size,
        num_workers=num_workers,
        **datamodule_kwargs
    )
    
    return datamodule
    
    

#### **npyify.py**

In [33]:
from collections import defaultdict
import glob
import json
import math
import os
import shutil
from itertools import chain
from pprint import pprint
from types import SimpleNamespace
import numpy as np
import pandas as pd

from omegaconf import OmegaConf

from tqdm.contrib.concurrent import process_map

from tqdm import tqdm as tdqm, tqdm
import torchaudio as ta

import librosa

taxonomy = {
    "vocals": [
        "lead male singer",
        "lead female singer",
        "human choir",
        "background vocals",
        "other (vocoder, beatboxing etc)",
    ],
    "bass": [
        "bass guitar",
        "bass synthesizer (moog etc)",
        "contrabass/double bass (bass of instrings)",
        "tuba (bass of brass)",
        "bassoon (bass of woodwind)",
    ],
    "drums": [
        "snare drum",
        "toms",
        "kick drum",
        "cymbals",
        "overheads",
        "full acoustic drumkit",
        "drum machine",
        "hi-hat"
    ],
    "other": [
        "fx/processed sound, scratches, gun shots, explosions etc",
        "click track",
    ],
    "guitar": [
        "clean electric guitar",
        "distorted electric guitar",
        "lap steel guitar or slide guitar",
        "acoustic guitar",
    ],
    "other plucked": ["banjo, mandolin, ukulele, harp etc"],
    "percussion": [
        "a-tonal percussion (claps, shakers, congas, cowbell etc)",
        "pitched percussion (mallets, glockenspiel, ...)",
    ],
    "piano": [
        "grand piano",
        "electric piano (rhodes, wurlitzer, piano sound alike)",
    ],
    "other keys": [
        "organ, electric organ",
        "synth pad",
        "synth lead",
        "other sounds (hapischord, melotron etc)",
    ],
    "bowed strings": [
        "violin (solo)",
        "viola (solo)",
        "cello (solo)",
        "violin section",
        "viola section",
        "cello section",
        "string section",
        "other strings",
    ],
    "wind": [
        "brass (trumpet, trombone, french horn, brass etc)",
        "flutes (piccolo, bamboo flute, panpipes, flutes etc)",
        "reeds (saxophone, clarinets, oboe, english horn, bagpipe)",
        "other wind",
    ],
}

def clean_npy_other_vox(data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npyq"):
    npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)
    
    
    npys = [npy for npy in npys if "other" in npy]
    npys = [npy for npy in npys if "vdbo_" not in npy]
    npys = [npy for npy in npys if "other_" not in npy]

    stems = set([
        os.path.basename(npy).split(".")[0] for npy in npys
    ])
    
    assert len(stems) == 1
    
    for npy in tqdm(npys):
        shutil.move(npy, npy.replace("other", "other_vocals"))
    
    


def clean_track_inst(inst):
    
    if "vocoder" in inst:
        inst = "other_vocals"

    if "fx" in inst:
        inst = "fx"

    if "contrabass_double_bass" in inst:
        inst = "double_bass"

    if "banjo" in inst:
        return "other_plucked"

    if "(" in inst:
        inst = inst.split("(")[0]

    for s in [",", "-"]:
        if s in inst:
            inst = inst.replace(s, "")

    for s in ["/"]:
        if s in inst:
            inst = inst.replace(s, "_")

    if inst[-1] == "_":
        inst = inst[:-1]

    return inst


taxonomy = {
    k.replace(" ", "_"): [clean_track_inst(i.replace(" ", "_")) for i in v] for k, v in taxonomy.items()
}

fine_to_coarse = {}

for k, v in taxonomy.items():
    for vv in v:
        fine_to_coarse[vv] = k

# pprint(fine_to_coarse)

def save_taxonomy():
    with open("taxonomy.json", "w") as f:
        json.dump(taxonomy, f, indent=4)

    taxonomy_coarse = list(taxonomy.keys())
    
    with open("taxonomy_coarse.json", "w") as f:
        json.dump(taxonomy_coarse, f, indent=4)
        
    taxonomy_fine = list(chain(*taxonomy.values()))
    
    count_ = defaultdict(int)
    for t in taxonomy_fine:
        count_[t] += 1
        
    with open("taxonomy_fine.json", "w") as f:
        json.dump(taxonomy_fine, f, indent=4)
    


possible_coarse = list(taxonomy.keys())
possible_fine = list(set(chain(*taxonomy.values())))


def trim_and_mix(audios, length_=None):
    length = min([a.shape[-1] for a in audios])
    
    if length_ is not None:
        length = min(length, length_)
    
    audios = [a[..., :length] for a in audios]
    return np.sum(np.stack(audios, axis=0), axis=0), length


def retrim_npys(saved_npy, new_length):
    print("retrimming")
    for npy in saved_npy:
        audio = np.load(npy)
        audio = audio[..., :new_length]
        np.save(npy, audio)


def convert_one(inout):
    input_path = inout.input_path
    output_root = inout.output_root

    song_id = os.path.basename(input_path)
    output_root = os.path.join(output_root, song_id)
    os.makedirs(output_root, exist_ok=True)

    metadata = OmegaConf.load(os.path.join(input_path, "data.json"))
    stems = metadata.stems

    min_length = None
    saved_npy = []

    all_tracks = []
    other_tracks = []
    
    outfile = None
    
    added_tracks = set()
    duplicated_tracks = set()
    track_to_stem = defaultdict(list)
    added_stems = set()
    duplicated_stems = set()
    
    stem_name_to_stems = defaultdict(list)
    
    for stem in stems:
        stem_name = stem.stemName
        stem_name_to_stems[stem_name].append(stem)
    
        
    for stem_name in tqdm(stem_name_to_stems):
        stem_tracks = []
        for stem in stem_name_to_stems[stem_name]:
            stem_name = stem.stemName
            
            if stem_name in added_stems:
                print(f"Duplicate stem {stem_name} in {song_id}")
                duplicated_stems.add(stem_name)
            
            added_stems.add(stem_name)
            
            for track in stem.tracks:
                track_inst = track.trackType
                track_inst = clean_track_inst(track_inst)
                
                if track_inst in added_tracks:
                    if stem_name in track_to_stem[track_inst]:
                        continue
                    print(f"Duplicate track {track_inst} in {song_id}")
                    print(f"Stems: {track_to_stem[track_inst]}")
                    duplicated_tracks.add(track_inst)
                    raise ValueError
                else:
                    added_tracks.add(track_inst)
                    
                track_to_stem[track_inst].append(stem_name)
                track_id = track.id
                
                audio, fs = ta.load(os.path.join(input_path, stem_name, f"{track_id}.wav"))

                if fs != 44100:
                    print(f"fs is {fs} for {track_id}")
                    with open(os.path.join(output_root, "fs.txt"), "w") as f:
                        f.write(f"{song_id}\t{track_id}\t{fs}\n")

                if min_length is None:
                    min_length = audio.shape[-1]
                else:
                    if audio.shape[-1] < min_length:
                        min_length = audio.shape[-1]

                        if len(saved_npy) > 0:
                            retrim_npys(saved_npy, min_length)

                audio = audio[..., :min_length]
                audio = audio.numpy()
                audio = audio.astype(np.float32)

                if audio.shape[0] == 1:
                    print("mono")
                if audio.shape[0] > 2:
                    print("multi channel")

                assert outfile is None
                outfile = os.path.join(output_root, f"{track_inst}.npy")
                np.save(outfile, audio)
                saved_npy.append(outfile)
                outfile = None
                stem_tracks.append(audio)
                audio = None
                
        stem_track, min_length = trim_and_mix(stem_tracks)

        assert outfile is None
        outfile = os.path.join(output_root, f"{stem_name}.npy")
        np.save(outfile, stem_track)
        saved_npy.append(outfile)
        outfile = None
        
        all_tracks.append(stem_track)
        
        if stem_name not in ["vocals", "drums", "bass"]:
            # print(f"Putting {stem_name} in other")
            other_tracks.append(stem_track)
            
        
    assert outfile is None
    all_track, min_length_ = trim_and_mix(all_tracks, min_length)
    outfile = os.path.join(output_root, f"mixture.npy")
    np.save(outfile, all_track)
    
    if min_length_ != min_length:
        retrim_npys(saved_npy, min_length_)
        min_length = min_length_
    
    saved_npy.append(outfile)
    outfile = None
    
    other_track, min_length_ = trim_and_mix(other_tracks, min_length)
    np.save(os.path.join(output_root, f"vdbo_others.npy"), other_track)
    
    if min_length_ != min_length:
        retrim_npys(saved_npy, min_length_)
        min_length = min_length_


def convert_to_npy(
    data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/canonical",
    output_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy2",
):
    if output_root is None:
        output_root = os.path.join(os.path.dirname(data_root), "npy")

    files = os.listdir(data_root)
    files = [
        os.path.join(data_root, f)
        for f in files
        if os.path.isdir(os.path.join(data_root, f))
    ]

    inout = [SimpleNamespace(input_path=f, output_root=output_root) for f in files]

    process_map(convert_one, inout)

    # for io in tdqm(inout):
    #     convert_one(io)


def make_others_one(input_path, dry_run=False):

    other_stems = [k for k in taxonomy.keys() if k not in ["vocals", "bass", "drums"]]
    npys = glob.glob(os.path.join(input_path, "**/*.npy"), recursive=True)

    npys = [npy for npy in npys if ".dbfs" not in npy]
    npys = [npy for npy in npys if ".query" not in npy]
    npys = [npy for npy in npys if "mixture" not in npy]
    npys = [npy for npy in npys if os.path.basename(npy).split(".")[0] in other_stems]

    print(f"Using stems: {[os.path.basename(npy).split('.')[0] for npy in npys]}")

    if len(npys) == 0:
        audio = np.zeros_like(np.load(os.path.join(input_path, "mixture.npy")))
    else:
        audio = [np.load(npy) for npy in npys]

        audio = np.sum(np.stack(audio, axis=0), axis=0)
    assert audio.shape[0] == 2

    output = os.path.join(input_path, "vdbo_others.npy")

    if dry_run:
        return

    np.save(output, audio)


def check_vdbo_one(f):
    s = np.sum(
        np.stack(
            [
                np.load(os.path.join(f, s + ".npy"))
                for s in ["vocals", "drums", "bass", "vdbo_others"]
                if os.path.exists(os.path.join(f, s + ".npy"))
            ],
            axis=0,
        ),
        axis=0,
    )
    m = np.load(os.path.join(f, "mixture.npy"))
    snr = 10 * np.log10(np.mean(np.square(m)) / np.mean(np.square(s - m)))
    print(snr)
    
    return snr

def check_vdbo(data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy2"):
    files = os.listdir(data_root)

    files = [
        os.path.join(data_root, f)
        for f in files
        if os.path.isdir(os.path.join(data_root, f))
    ]

    snrs = process_map(check_vdbo_one, files)

    np.save("/storage/home/hcoda1/1/kwatchar3/data/vdbo.npy", np.array(snrs))


def make_others(data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy2"):

    files = os.listdir(data_root)

    files = [
        os.path.join(data_root, f)
        for f in files
        if os.path.isdir(os.path.join(data_root, f))
    ]

    process_map(make_others_one, files)

    # for f in tqdm(files):
    #     make_others_one(f, dry_run=False)


def extract_metadata_one(input_path):
    song_id = os.path.basename(input_path)
    metadata = OmegaConf.load(os.path.join(input_path, "data.json"))

    song = metadata.song
    artist = metadata.artist
    genre = metadata.genre

    stems = metadata.stems
    data_out = []

    for stem in stems:
        stem_name = stem.stemName
        stem_id = stem.id
        for track in stem.tracks:
            track_inst = track.trackType
            track_id = track.id

            data_out.append(
                {
                    "song_id": song_id,
                    "song": song,
                    "artist": artist,
                    "genre": genre,
                    "stem_name": stem_name,
                    "stem_id": stem_id,
                    "track_inst": track_inst,
                    "track_id": track_id,
                    "has_bleed": track.has_bleed,
                }
            )

    return data_out


def consolidate_metadata(
    data_root="/home/kwatchar3/Documents/data/moisesdb/canonical",
):

    files = os.listdir(data_root)
    files = [
        os.path.join(data_root, f)
        for f in files
        if os.path.isdir(os.path.join(data_root, f))
    ]

    data = process_map(extract_metadata_one, files)

    df = pd.DataFrame.from_records(list(chain(*data)))

    df.to_csv(os.path.join(os.path.dirname(data_root), "metadata.csv"), index=False)


def clean_canonical(data_root="/home/kwatchar3/Documents/data/moisesdb/canonical"):

    npy = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)

    for n in tqdm(npy):
        os.remove(n)


def remove_dbfs(data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy"):
    npy = glob.glob(os.path.join(data_root, "**/*.dbfs.npy"), recursive=True)

    for n in tqdm(npy):
        os.remove(n)


def make_split(
    metadata_path="/home/kwatchar3/Documents/data/moisesdb/metadata.csv",
    n_splits=5,
    seed=42,
):

    df = pd.read_csv(metadata_path)
    # print(df.columns)
    df = df[["song_id", "genre"]].drop_duplicates()

    genres = df["genre"].value_counts()
    genres_map = {g: g if c > n_splits else "other" for g, c in genres.items()}

    df["genre"] = df["genre"].map(genres_map)

    n_samples = len(df)
    n_per_split = n_samples // n_splits

    np.random.seed(seed)

    from sklearn.model_selection import train_test_split

    splits = []

    df_ = df.copy()

    for i in range(n_splits - 1):
        df_, test = train_test_split(
            df_,
            test_size=n_per_split,
            random_state=seed,
            stratify=df_["genre"],
            shuffle=True,
        )

        dfs = test[["song_id"]].copy().sort_values(by="song_id")
        dfs["split"] = i + 1
        splits.append(dfs)

    test = df_
    dfs = test[["song_id"]].copy().sort_values(by="song_id")
    dfs["split"] = n_splits
    splits.append(dfs)

    splits = pd.concat(splits)

    splits.to_csv(
        os.path.join(os.path.dirname(metadata_path), "splits.csv"), index=False
    )


def consolidate_stems(data_root="/home/kwatchar3/Documents/data/moisesdb/npy"):

    metadata = pd.read_csv(os.path.join(os.path.dirname(data_root), "metadata.csv"))

    dfg = metadata.groupby("song_id")[["stem_name", "track_inst"]]

    pprint(dfg)

    df = []

    def make_stem_dict(song_id, track_inst, stem_names):

        d = {"song_id": song_id}

        for inst in possible_fine:
            d[inst] = int(inst in track_inst)

        for inst in possible_coarse:
            d[inst] = int(inst in stem_names)

        return d

    for song_id, dfgg in dfg:

        track_inst = dfgg["track_inst"].tolist()
        track_inst = list(set(track_inst))
        track_inst = [clean_track_inst(inst) for inst in track_inst]

        stem_names = dfgg["stem_name"].tolist()
        stem_names = list(set([clean_track_inst(inst) for inst in stem_names]))

        d = make_stem_dict(song_id, track_inst, stem_names)
        df.append(d)

    print(df)

    df = pd.DataFrame.from_records(df)

    df.to_csv(os.path.join(os.path.dirname(data_root), "stems.csv"), index=False)


def get_dbfs(data_root="/home/kwatchar3/Documents/data/moisesdb/npy"):
    npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)

    dbfs = []

    for npy in tqdm(npys):
        audio = np.load(npy)
        song_id = os.path.basename(os.path.dirname(npy))
        track_id = os.path.basename(npy).split(".")[0]

        dbfs.append(
            {
                "song_id": song_id,
                "track_id": track_id,
                "dbfs": 10 * np.log10(np.mean(np.square(audio))),
            }
        )

    dbfs = pd.DataFrame.from_records(dbfs)

    dbfs.to_csv(os.path.join(os.path.dirname(data_root), "dbfs.csv"), index=False)

    return dbfs


def get_dbfs_by_chunk_one(inout):

    audio = np.load(inout.audio_path, mmap_mode="r")
    chunk_size = inout.chunk_size
    fs = inout.fs
    hop_size = inout.hop_size

    n_chan, n_samples = audio.shape
    chunk_size_samples = int(chunk_size * fs)
    hop_size_samples = int(hop_size * fs)

    x2win = np.lib.stride_tricks.sliding_window_view(
        np.square(audio), chunk_size_samples, axis=1
    )[:, ::hop_size_samples, :]

    x2win_mean = np.mean(x2win, axis=(0, 2))
    x2win_mean[x2win_mean == 0] = 1e-8
    dbfs = 10 * np.log10(x2win_mean)

    # song_id = os.path.basename(os.path.dirname(inout.audio_path))
    track_id = os.path.basename(inout.audio_path).split(".")[0]

    np.save(
        os.path.join(os.path.dirname(inout.audio_path), f"{track_id}.dbfs.npy"), dbfs
    )


def clean_data_root(data_root="/home/kwatchar3/Documents/data/moisesdb/npy"):
    npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)

    for npy in tqdm(npys):
        if ".dbfs" in npy or ".query" in npy:
            # print("removing", npy)
            os.remove(npy)


#
def get_dbfs_by_chunk(
    data_root="/home/kwatchar3/Documents/data/moisesdb/npy",
    query_root="/home/kwatchar3/Documents/data/moisesdb/npyq",
):
    npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)

    inout = [
        SimpleNamespace(
            audio_path=npy,
            chunk_size=1,
            hop_size=0.125,
            fs=44100,
            output_path=npy.replace(data_root, query_root).replace(
                ".npy", ".query.npy"
            ),
        )
        for npy in npys
    ]

    process_map(get_dbfs_by_chunk_one, inout, chunksize=2)


def round_samples(seconds, fs, hop_size, downsample):
    n_frames = math.ceil(seconds * fs / hop_size) + 1
    n_frames_down = math.ceil(n_frames / downsample)
    n_frames = n_frames_down * downsample
    n_samples = (n_frames - 1) * hop_size

    return int(n_samples)


def get_query_one(inout):

    audio = np.load(inout.audio_path, mmap_mode="r")
    chunk_size = inout.chunk_size
    fs = inout.fs
    output_path = inout.output_path
    round = inout.round
    hop_size = inout.hop_size

    if round:
        chunk_size_samples = round_samples(chunk_size, fs, 512, 2**6)
    else:
        chunk_size_samples = int(chunk_size * fs)

    audio_mono = np.mean(audio, axis=0)

    onset = librosa.onset.onset_detect(
        y=audio_mono, sr=fs, units="frames", hop_length=hop_size
    )

    onset_strength = librosa.onset.onset_strength(
        y=audio_mono, sr=fs, hop_length=hop_size
    )

    n_frames_per_chunk = chunk_size_samples // hop_size

    onset_strength_slide = np.lib.stride_tricks.sliding_window_view(
        onset_strength, n_frames_per_chunk, axis=0
    )

    onset_strength = np.mean(onset_strength_slide, axis=1)

    max_onset_frame = np.argmax(onset_strength)

    max_onset_samples = librosa.frames_to_samples(max_onset_frame)

    track_id = os.path.basename(inout.audio_path).split(".")[0]

    segment = audio[:, max_onset_samples : max_onset_samples + chunk_size_samples]

    os.makedirs(os.path.dirname(output_path), exist_ok=True)

    np.save(output_path, segment)


def get_query_from_onset(
    data_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npy2",  # "/home/kwatchar3/Documents/data/moisesdb/npy",
    query_root="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/npyq",  # "/home/kwatchar3/Documents/data/moisesdb/npyq",
    query_file="query-10s",
    pmap=True,
):
    npys = glob.glob(os.path.join(data_root, "**/*.npy"), recursive=True)

    npys = [npy for npy in npys if "dbfs" not in npy]

    inout = [
        SimpleNamespace(
            audio_path=npy,
            chunk_size=10,
            hop_size=512,
            round=False,
            fs=44100,
            output_path=npy.replace(data_root, query_root).replace(
                ".npy", f".{query_file}.npy"
            ),
        )
        for npy in npys
    ]

    if pmap:
        process_map(get_query_one, inout, chunksize=2, max_workers=24)
    else:
        for io in tqdm(inout):
            get_query_one(io)


def get_durations(data_root="/home/kwatchar3/Documents/data/moisesdb/npy"):
    npys = glob.glob(os.path.join(data_root, "**/mixture.npy"), recursive=True)

    durations = []

    for npy in tqdm(npys):
        audio = np.load(npy, mmap_mode="r")
        song_id = os.path.basename(os.path.dirname(npy))
        track_id = os.path.basename(npy).split(".")[0]

        durations.append(
            {
                "song_id": song_id,
                "track_id": track_id,
                "duration": audio.shape[-1] / 44100,
            }
        )

    durations = pd.DataFrame.from_records(durations)

    durations.to_csv(
        os.path.join(os.path.dirname(data_root), "durations.csv"), index=False
    )

    return durations


def clean_query_root(
    data_root="/home/kwatchar3/Documents/data/moisesdb/npy",
    query_root="/home/kwatchar3/Documents/data/moisesdb/npyq",
):
    npys = glob.glob(os.path.join(data_root, "**/*.query.npy"), recursive=True)

    for npy in tqdm(npys):
        dst = npy.replace(data_root, query_root)
        dstdir = os.path.dirname(dst)
        os.makedirs(dstdir, exist_ok=True)
        shutil.move(npy, dst)


def make_test_indices(
    metadata_path="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/metadata.csv",
    stem_path="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/stems.csv",
    splits_path="/storage/home/hcoda1/1/kwatchar3/data/data/moisesdb/splits.csv",
    test_split=5,
):
    
    coarse_stems = set(taxonomy.keys())
    fine_stems = set(chain(*taxonomy.values()))

    metadata = pd.read_csv(metadata_path)
    splits = pd.read_csv(splits_path)
    stems = pd.read_csv(stem_path)

    file_in_test = splits[splits["split"] == test_split]["song_id"].tolist()
    
    stems_test = stems[stems["song_id"].isin(file_in_test)]
    metadata_test = metadata[metadata["song_id"].isin(file_in_test)]
    splits_test = splits[splits["split"] == test_split]
    
    stems_test = stems_test.set_index("song_id")
    metadata_test = metadata_test.drop_duplicates("song_id").set_index("song_id")
    splits_test = splits_test.set_index("song_id")
    
    stem_to_song_id = defaultdict(list)
    song_id_to_stem = defaultdict(list)
    
    for song_id in file_in_test:
        
        stems_ = stems_test.loc[song_id]
        stem_names = stems_.T
        stem_names = stem_names[stem_names == 1].index.tolist()
        
        for stem in stem_names:
            stem_to_song_id[stem].append(song_id)
            
        song_id_to_stem[song_id] = stem_names
        
        
    indices = []
    no_query = []
    
    for song_id in file_in_test:
        
        genre = metadata_test.loc[song_id, "genre"]
        # print(genre)
        artist = metadata_test.loc[song_id, "artist"]
        # print(artist)
        
        stems_ = song_id_to_stem[song_id]
        
        for stem in stems_:
            possible_query = stem_to_song_id[stem]
            possible_query = [p for p in possible_query if p != song_id]
            
            if len(possible_query) == 0:
                print(f"No possible query for {song_id} with {stem}")
                
                no_query.append(
                    {
                        "song_id": song_id,
                        "stem": stem
                    }
                )
                continue
            
            query_df = metadata_test.loc[possible_query, ["genre", "artist"]]
            
            assert len(query_df) > 0
            
            query_df_ = query_df.copy()
            
            same_genre = True
            different_artist = True
            query_df = query_df[(query_df["genre"] == genre) & (query_df["artist"] != artist)]
            
            if len(query_df) == 0:
                
                same_genre = False
                different_artist = True
                
                query_df = query_df_.copy()
                query_df = query_df[(query_df["artist"] != artist)]
            
            if len(query_df) == 0:
                
                same_genre = True
                different_artist = False
                
                query_df = query_df_.copy()
                query_df = query_df[(query_df["genre"] == genre)]
            
            if len(query_df) == 0:
                
                same_genre = False
                different_artist = False
                
                query_df = query_df_.copy()
            
            query_id = query_df.sample(1).index[0]
            
            indices.append(
                {
                    "song_id": song_id,
                    "query_id": query_id,
                    "stem": stem,
                    "same_genre": same_genre,
                    "different_artist": different_artist
                }   
            )
            
    indices = pd.DataFrame.from_records(indices)
    no_query = pd.DataFrame.from_records(no_query)
    
    indices.to_csv(
        os.path.join(os.path.dirname(metadata_path), "test_indices.csv"), index=False
    )
    
    no_query.to_csv(
        os.path.join(os.path.dirname(metadata_path), "no_query.csv"), index=False
    )
    
    print("Total number of queries:", len(indices))
    print("Total number of no queries:", len(no_query))
    
    query_type = indices.groupby(["same_genre", "different_artist"]).size()
    
    print(query_type)


if __name__ == "__main__":
    import fire

    fire.Fire()

ERROR: Cannot find key: --f=c:\Users\Dell\AppData\Roaming\jupyter\runtime\kernel-v3f0d43beaf9ae7787c0ec59baf0b0092a7079c985.json
Usage: ipykernel_launcher.py <group|command|value>
  available groups:      In | Out | exit | quit | os | Callable | np | torch |
                         taF | pd | band_defs | mbs | df | nn | math |
                         input_data | Dict | List | Optional | Tuple | Type |
                         activation | ta | optim | tm | Union | pl | Iterator |
                         Mapping | lr_scheduler | inspect | Sequence | data |
                         taxonomy | random | INST_BY_OCCURRENCE |
                         FINE_LEVEL_INSTRUMENTS | COARSE_LEVEL_INSTRUMENTS |
                         COARSE_TO_FINE | FINE_TO_COARSE |
                         ALL_LEVEL_INSTRUMENTS | audiomentations | glob |
                         json | shutil | librosa | fine_to_coarse | v |
                         possible_coarse | possible_fine | fire
  available commands: 

FireExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


#### **Audio/Dataset & Passt EDA**

_____
# ***Loss Functions***

In [35]:
# Base
from typing import Dict, List, Optional, Union
from torch import nn
import torch
from torch.nn.modules.loss import _Loss

# from core.types import BatchedInputOutput
from torch.nn import functional as F

class BaseLossHandler(nn.Module):
    def __init__(
        self, loss: nn.Module, modality: Union[str, List[str]], name: Optional[str] = None
    ) -> None:
        super().__init__()

        self.loss = loss

        if isinstance(modality, str):
            modality = [modality]

        self.modality = modality

        if name is None:
            name = "loss"

        if name == "__auto__":
            name = self.loss.__class__.__name__

        self.name = name

    def _audio_preprocess(self, y_pred, y_true):

        n_sample_true = y_true.shape[-1]
        n_sample_pred = y_pred.shape[-1]

        if n_sample_pred > n_sample_true:
            y_pred = y_pred[..., :n_sample_true]
        elif n_sample_pred < n_sample_true:
            y_true = y_true[..., :n_sample_pred]

        return y_pred, y_true

    def forward(self, batch: BatchedInputOutput):
        y_true = batch.sources
        y_pred = batch.estimates

        loss_contribs = {}

        stem_contribs = {
            stem: 0.0 for stem in y_pred.keys()
        }

        for stem in y_pred.keys():
            for modality in self.modality:

                if modality not in y_pred[stem].keys():
                    continue

                if y_pred[stem][modality].shape[-1] == 0:
                    continue

                y_true_ = y_true[stem][modality]
                y_pred_ = y_pred[stem][modality]

                if modality == "audio":
                    y_pred_, y_true_ = self._audio_preprocess(y_pred_, y_true_)
                elif modality == "spectrogram":
                    y_pred_ = torch.view_as_real(y_pred_)
                    y_true_ = torch.view_as_real(y_true_)

                loss_contribs[f"{self.name}/{stem}/{modality}"] = self.loss(
                    y_pred_, y_true_
                )

                stem_contribs[stem] += loss_contribs[f"{self.name}/{stem}/{modality}"]

        total_loss = sum(stem_contribs.values())
        loss_contribs[self.name] = total_loss

        with torch.no_grad():
            for stem in stem_contribs.keys():
                loss_contribs[f"{self.name}/{stem}"] = stem_contribs[stem]

        return loss_contribs


class AdversarialLossHandler(BaseLossHandler):
    def __init__(self, loss: nn.Module, modality: str, name: Optional[str] = "adv_loss"):

        super().__init__(loss, modality, name)

    def discriminator_forward(self, batch: BatchedInputOutput):

        y_true = batch.sources
        y_pred = batch.estimates

        # g_loss_contribs = {}
        d_loss_contribs = {}

        for stem in y_pred.keys():

            if self.modality not in y_pred[stem].keys():
                continue

            if y_pred[stem][self.modality].shape[-1] == 0:
                continue

            y_true_ = y_true[stem][self.modality]
            y_pred_ = y_pred[stem][self.modality]

            if self.modality == "audio":
                y_pred_, y_true_ = self._audio_preprocess(y_pred_, y_true_)

            # g_loss_contribs[f"{self.name}:g/{stem}"] = self.loss.generator_loss(
            #     y_pred_, y_true_
            # )

            d_loss_contribs[f"{self.name}:d/{stem}"] = self.loss.discriminator_loss(
                y_pred_, y_true_
            )

        # g_total_loss = sum(g_loss_contribs.values())
        d_total_loss = sum(d_loss_contribs.values())

        # g_loss_contribs["loss"] = g_total_loss
        d_loss_contribs["disc_loss"] = d_total_loss

        return d_loss_contribs

    def generator_forward(self, batch: BatchedInputOutput):

        y_true = batch.sources
        y_pred = batch.estimates

        g_loss_contribs = {}
        # d_loss_contribs = {}

        for stem in y_pred.keys():

            if self.modality not in y_pred[stem].keys():
                continue

            if y_pred[stem][self.modality].shape[-1] == 0:
                continue

            y_true_ = y_true[stem][self.modality]
            y_pred_ = y_pred[stem][self.modality]

            if self.modality == "audio":
                y_pred_, y_true_ = self._audio_preprocess(y_pred_, y_true_)

            g_loss_contribs[f"{self.name}:g/{stem}"] = self.loss.generator_loss(
                y_pred_, y_true_
            )

            # d_loss_contribs[f"{self.name}:g/{stem}"] = self.loss.discriminator_loss(
            #     y_pred_, y_true_
            # )

        g_total_loss = sum(g_loss_contribs.values())
        # d_total_loss = sum(d_loss_contribs.values())

        g_loss_contribs["gen_loss"] = g_total_loss
        # d_loss_contribs["loss"] = d_total_loss

        return g_loss_contribs

    def forward(self, batch: BatchedInputOutput):
        return {
            "generator": self.generator_forward(batch),
            "discriminator": self.discriminator_forward(batch)
        }

In [36]:
# L1SNR Loss
import torch
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F

class WeightedL1Loss(_Loss):
    def __init__(self, weights=None):
        super().__init__()

    def forward(self, y_pred, y_true):
        ndim = y_pred.ndim
        dims = list(range(1, ndim))
        loss = F.l1_loss(y_pred, y_true, reduction='none')
        loss = torch.mean(loss, dim=dims)
        weights = torch.mean(torch.abs(y_true), dim=dims)

        loss = torch.sum(loss * weights) / torch.sum(weights)

        return loss


class L1MatchLoss(_Loss):
    def __init__(self):
        super().__init__()

    def forward(self, y_pred, y_true):
        batch_size = y_pred.shape[0]

        y_pred = y_pred.reshape(batch_size, -1)
        y_true = y_true.reshape(batch_size, -1)

        l1_true = torch.mean(torch.abs(y_true), dim=-1)
        l1_pred = torch.mean(torch.abs(y_pred), dim=-1)
        loss = torch.mean(torch.abs(l1_pred - l1_true))

        return loss

class DecibelMatchLoss(_Loss):
    def __init__(self, eps=1e-3):
        super().__init__()

        self.eps = eps

    def forward(self, y_pred, y_true):
        batch_size = y_pred.shape[0]

        y_pred = y_pred.reshape(batch_size, -1)
        y_true = y_true.reshape(batch_size, -1)

        db_true = 10.0 * torch.log10(self.eps + torch.mean(torch.square(torch.abs(y_true)), dim=-1))
        db_pred = 10.0 * torch.log10(self.eps + torch.mean(torch.square(torch.abs(y_pred)), dim=-1))
        loss = torch.mean(torch.abs(db_pred - db_true))

        return loss

class L1SNRLoss(_Loss):
    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = torch.tensor(eps)

    def forward(self, y_pred, y_true):
        batch_size = y_pred.shape[0]

        y_pred = y_pred.reshape(batch_size, -1)
        y_true = y_true.reshape(batch_size, -1)

        l1_error = torch.mean(torch.abs(y_pred - y_true), dim=-1)
        l1_true = torch.mean(torch.abs(y_true), dim=-1)

        snr = 20.0 * torch.log10((l1_true + self.eps) / (l1_error + self.eps))

        return -torch.mean(snr)
    
class L1SNRLossIgnoreSilence(_Loss):
    def __init__(self, eps=1e-3, dbthresh=-20, dbthresh_step=20):
        super().__init__()
        self.eps = torch.tensor(eps)
        self.dbthresh = dbthresh
        self.dbthresh_step = dbthresh_step

    def forward(self, y_pred, y_true):
        batch_size = y_pred.shape[0]

        y_pred = y_pred.reshape(batch_size, -1)
        y_true = y_true.reshape(batch_size, -1)

        l1_error = torch.mean(torch.abs(y_pred - y_true), dim=-1)
        l1_true = torch.mean(torch.abs(y_true), dim=-1)

        snr = 20.0 * torch.log10((l1_true + self.eps) / (l1_error + self.eps))
        
        db = 10.0 * torch.log10(torch.mean(torch.square(y_true), dim=-1) + 1e-6)
        
        if torch.sum(db > self.dbthresh) == 0:
            if torch.sum(db > self.dbthresh - self.dbthresh_step) == 0:
                return -torch.mean(snr)
            else:
                return -torch.mean(snr[db > self.dbthresh  - self.dbthresh_step])

        return -torch.mean(snr[db > self.dbthresh])

class L1SNRDecibelMatchLoss(_Loss):
    def __init__(self, db_weight=0.1, l1snr_eps=1e-3, dbeps=1e-3):
        super().__init__()
        self.l1snr = L1SNRLoss(l1snr_eps)
        self.decibel_match = DecibelMatchLoss(dbeps)
        self.db_weight = db_weight

    def forward(self, y_pred, y_true):

        return self.l1snr(y_pred, y_true) + self.decibel_match(y_pred, y_true)

_____
# ***Metrics***

In [37]:
# Base
from typing import Dict, Optional
from torch import nn
import torchmetrics as tm

# from core.types import BatchedInputOutput, OperationMode


class BaseMetricHandler(nn.Module):
    def __init__(
        self, stem: str, metric: tm.Metric, modality: str, name: Optional[str] = None
    ) -> None:
        super().__init__()

        self.metric = metric
        self.modality = modality
        self.stem = stem

        if name is None or name == "__auto__":
            name = self.metric.__class__.__name__

        self.name = name

    def update(self, batch: BatchedInputOutput):
        y_true = batch.sources[self.stem]
        y_pred = batch.estimates[self.stem]

        self.metric.update(y_pred[self.modality].cuda(), y_true[self.modality].cuda())

    def compute(self) -> Dict[str, float]:

        metric = self.metric.compute()

        if isinstance(metric, dict):
            return {f"{self.name}/{k}": v for k, v in metric.items()}

        return {self.name: self.metric.compute()}

    def reset(self):
        self.metric.reset()


class MultiModeMetricHandler(nn.Module):
    def __init__(
        self,
        train_metrics: Dict[str, BaseMetricHandler],
        val_metrics: Dict[str, BaseMetricHandler],
        test_metrics: Dict[str, BaseMetricHandler],
    ):
        super().__init__()

        self.train_metrics = nn.ModuleDict(train_metrics)
        self.val_metrics = nn.ModuleDict(val_metrics)
        self.test_metrics = nn.ModuleDict(test_metrics)

    def get_mode(self, mode: OperationMode) -> BaseMetricHandler:
        if mode == OperationMode.TRAIN:
            return self.train_metrics
        elif mode == OperationMode.VAL:
            return self.val_metrics
        elif mode == OperationMode.TEST:
            return self.test_metrics
        else:
            raise ValueError(f"Unknown mode: {mode}")

In [38]:
# SNR
from typing import Any, Tuple
import torch
import torchmetrics as tm
from torchmetrics.audio.snr import SignalNoiseRatio
from torchmetrics.functional.audio.snr import signal_noise_ratio, scale_invariant_signal_noise_ratio
from torchmetrics.utilities.checks import _check_same_shape


def safe_signal_noise_ratio(
    preds: torch.Tensor, target: torch.Tensor, zero_mean: bool = False
) -> torch.Tensor:

    return torch.nan_to_num(
        signal_noise_ratio(preds, target, zero_mean=zero_mean), nan=torch.nan, posinf=100.0, neginf=-100.0
    )


def safe_scale_invariant_signal_noise_ratio(
    preds: torch.Tensor, target: torch.Tensor,
    zero_mean: bool = False
) -> torch.Tensor:
    """`Scale-invariant signal-to-distortion ratio`_ (SI-SDR).

    The SI-SDR value is in general considered an overall measure of how good a source sound.

    Args:
        preds: float tensor with shape ``(...,time)``
        target: float tensor with shape ``(...,time)``
        zero_mean: If to zero mean target and preds or not

    Returns:
        Float tensor with shape ``(...,)`` of SDR values per sample

    Raises:
        RuntimeError:
            If ``preds`` and ``target`` does not have the same shape
    """
    return torch.nan_to_num(
        scale_invariant_signal_noise_ratio(preds, target), nan=torch.nan, posinf=100.0,
        neginf=-100.0
    )

def decibels(x: torch.Tensor, threshold: float = 1e-6) -> torch.Tensor:
    mean_squared = torch.mean(torch.square(x), dim=-1)
    n_samples = x.shape[0]
    return torch.sum(10 * torch.log10(mean_squared + threshold)), n_samples

class Decibels(tm.Metric):
    def __init__(self, threshold: float = 1e-6, **kwargs: Any) -> None:
        super().__init__(**kwargs)
        self.threshold = threshold

        self.add_state("running_mean", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("running_count", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, y):

        db, count = decibels(y, self.threshold)
        self.running_mean += db.cpu()
        self.running_count += count

    def compute(self) -> torch.Tensor:
        return self.running_mean / self.running_count

    # def reset(self) -> None:
    #     self.running_mean = torch.tensor(0.0)
    #     self.running_count = torch.tensor(0)

class PredictedDecibels(Decibels):
    def update(self, ypred, ytrue) -> None:
        return super().update(ypred)

class TargetDecibels(Decibels):
    def update(self, ypred, ytrue) -> None:
        return super().update(ytrue)


class SafeSignalNoiseRatio(SignalNoiseRatio):
    def __init__(
        self,
        zero_mean: bool = False,
        threshold: float = 1e-6,
        fs: int = 44100,
        **kwargs: Any
    ) -> None:
        super().__init__(zero_mean, **kwargs)

        self.threshold = threshold
        self.fs = fs

        self.sample_mismatch_thresh_seconds = 0.1

        self.add_state("snr_list", default=[], dist_reduce_fx="cat")

    def _fix_shape(self, preds: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        n_samples_preds = preds.shape[-1]
        n_samples_target = target.shape[-1]

        if n_samples_preds != n_samples_target:
            if (
                    abs(n_samples_preds - n_samples_target) / self.fs
                    > self.sample_mismatch_thresh_seconds
            ):
                raise ValueError(
                    "The difference between the number of samples of the predictions and the target is too large (100 ms)"
                )

            min_samples = min(n_samples_preds, n_samples_target)
            preds = preds[..., :min_samples]
            target = target[..., :min_samples]

        return preds, target

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Update state with predictions and targets."""

        preds, target = self._fix_shape(preds, target)

        snr_batch = safe_signal_noise_ratio(
            preds=preds, target=target, zero_mean=self.zero_mean
        )

        self.snr_list.append(snr_batch)

    def compute(self) -> torch.Tensor:
        """Compute metric."""

        if len(self.snr_list) == 0:
            return torch.tensor(float("nan"))

        return torch.nanmedian(torch.cat(self.snr_list))


class SafeScaleInvariantSignalNoiseRatio(SafeSignalNoiseRatio):
    def __init__(
        self,
        zero_mean: bool = False,
        threshold: float = 1e-6,
        fs: int = 44100,
        **kwargs: Any
    ) -> None:
        super().__init__(zero_mean, threshold, fs, **kwargs)

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        """Update state with predictions and targets."""

        preds, target = self._fix_shape(preds, target)

        snr_batch = safe_scale_invariant_signal_noise_ratio(
            preds=preds, target=target, zero_mean=self.zero_mean
        )

        self.snr_list.append(snr_batch)

    def compute(self) -> torch.Tensor:
        """Compute metric."""

        if len(self.snr_list) == 0:
            return torch.tensor(float("nan"))

        return torch.nanmedian(torch.cat(self.snr_list))

_____
# ***Training***

In [40]:
import json
import os.path
from pprint import pprint
import random
import string
from types import SimpleNamespace
import pandas as pd

import torch
from tqdm import tqdm

# from core.data.moisesdb.datamodule import (
#     MoisesTestDataModule,
#     MoisesValidationDataModule,
#     MoisesDataModule,
#     MoisesBalancedTrainDataModule,
#     MoisesVDBODataModule,
# )
# from core.losses.base import AdversarialLossHandler, BaseLossHandler
# from core.losses.l1snr import (
#     L1SNRDecibelMatchLoss,
#     L1SNRLoss,
#     WeightedL1Loss,
#     L1SNRLossIgnoreSilence,
# )
# from core.metrics.base import BaseMetricHandler, MultiModeMetricHandler
# from core.metrics.snr import (
#     SafeScaleInvariantSignalNoiseRatio,
#     SafeSignalNoiseRatio,
#     PredictedDecibels,
#     TargetDecibels,
# )
# from core.models.ebase import EndToEndLightningSystem
# from core.models.e2e.resunet.resunet import (
#     ResUNetPasstConditionedSeparator,
#     ResUNetResQueryConditionedSeparator,
#     # StupidNet
# )

# from core.models.e2e.bandit.bandit import Bandit, PasstFiLMConditionedBandit

from omegaconf import OmegaConf
# from core.types import LossHandler, OptimizationBundle

from torch import nn, optim
from torch.optim import lr_scheduler

import torchmetrics as tm

import pytorch_lightning as pl
import pytorch_lightning.callbacks
import pytorch_lightning.loggers
from pytorch_lightning.profilers import AdvancedProfiler

import torch.backends.cudnn

torch.set_float32_matmul_precision("high")
torch.backends.cudnn.benchmark = True


def _allowed_classes_to_dict(allowed_classes):
    return {cls.__name__: cls for cls in allowed_classes}


ALLOWED_MODELS = [
    # ResUNetPasstConditionedSeparator,
    # ResUNetResQueryConditionedSeparator,
    # StupidNet
    Bandit,
    PasstFiLMConditionedBandit,
]

ALLOWED_MODELS_DICT = _allowed_classes_to_dict(ALLOWED_MODELS)

ALLOWED_DATAMODULES = [
    MoisesDataModule,
    MoisesBalancedTrainDataModule,
    MoisesVDBODataModule,
    MoisesValidationDataModule,
    MoisesTestDataModule,
]

ALLOWED_DATAMODULE_DICT = _allowed_classes_to_dict(ALLOWED_DATAMODULES)

ALLOWED_LOSSES = [
    L1SNRLoss,
    WeightedL1Loss,
    L1SNRDecibelMatchLoss,
    L1SNRLossIgnoreSilence,
]

ALLOWED_LOSS_DICT = _allowed_classes_to_dict(ALLOWED_LOSSES)


def _build_model(config: OmegaConf) -> nn.Module:

    model_config = config.model

    model_name = model_config.cls
    kwargs = model_config.get("kwargs", {})

    if model_name in ALLOWED_MODELS_DICT:
        model = ALLOWED_MODELS_DICT[model_name](**kwargs)
    else:
        raise ValueError(f"Unknown model name: {model_name}")

    return model


def _build_inner_loss(config: OmegaConf) -> nn.Module:
    loss_config = config.loss

    loss_name = loss_config.cls
    kwargs = loss_config.get("kwargs", {})

    if loss_name in ALLOWED_LOSS_DICT:
        loss = ALLOWED_LOSS_DICT[loss_name](**kwargs)
    elif loss_name in nn.modules.loss.__dict__:
        loss = nn.modules.loss.__dict__[loss_name](**kwargs)
    else:
        raise ValueError(f"Unknown loss name: {loss_name}")

    return loss


def _build_loss(config: OmegaConf) -> BaseLossHandler:
    loss_handler = BaseLossHandler(
        loss=_build_inner_loss(config),
        modality=config.loss.modality,
        name=config.loss.get("name", None),
    )

    return loss_handler


def _dummy_metrics(config: OmegaConf) -> MultiModeMetricHandler:
    metrics = MultiModeMetricHandler(
        train_metrics={
            stem: BaseMetricHandler(
                stem=stem,
                metric=tm.MetricCollection(
                    SafeSignalNoiseRatio(),
                    SafeScaleInvariantSignalNoiseRatio(),
                    PredictedDecibels(),
                    TargetDecibels(),
                ),
                modality="audio",
                name="snr",
            )
            for stem in config.stems
        },
        val_metrics={
            stem: BaseMetricHandler(
                stem=stem,
                metric=tm.MetricCollection(
                    SafeSignalNoiseRatio(),
                    SafeScaleInvariantSignalNoiseRatio(),
                    PredictedDecibels(),
                    TargetDecibels(),
                ),
                modality="audio",
                name="snr",
            )
            for stem in config.stems
        },
        test_metrics={
            stem: BaseMetricHandler(
                stem=stem,
                metric=tm.MetricCollection(
                    SafeSignalNoiseRatio(),
                    SafeScaleInvariantSignalNoiseRatio(),
                    PredictedDecibels(),
                    TargetDecibels(),
                ),
                modality="audio",
                name="snr",
            )
            for stem in config.stems
        },
    )

    return metrics


def _build_optimization_bundle(config: OmegaConf) -> OptimizationBundle:
    optim_config = config.optim

    print(optim_config)

    optimizer_name = optim_config.optimizer.cls
    kwargs = optim_config.optimizer.get("kwargs", {})

    optimizer = getattr(optim, optimizer_name)

    optim_bundle = SimpleNamespace(
        optimizer=SimpleNamespace(cls=optimizer, kwargs=kwargs), scheduler=None
    )

    scheduler_config = optim_config.get("scheduler", None)

    if scheduler_config is not None:
        scheduler_name = scheduler_config.cls
        scheduler_kwargs = scheduler_config.get("kwargs", {})

        if scheduler_name in lr_scheduler.__dict__:
            scheduler = lr_scheduler.__dict__[scheduler_name]
        else:
            raise ValueError(f"Unknown scheduler name: {scheduler_name}")

        optim_bundle.scheduler = SimpleNamespace(cls=scheduler, kwargs=scheduler_kwargs)

    return optim_bundle


def _dummy_augmentation():
    return nn.Identity()


def _load_config(config_path: str) -> OmegaConf:
    config = OmegaConf.load(config_path)

    config_dict = {}

    for k, v in config.items():
        if isinstance(v, str) and v.endswith(".yml"):
            config_dict[k] = OmegaConf.load(v)
        else:
            config_dict[k] = v

    config = OmegaConf.merge(config_dict)

    return config


def _build_datamodule(config: OmegaConf) -> pl.LightningDataModule:

    DataModule = ALLOWED_DATAMODULE_DICT[config.data.cls]

    datamodule = DataModule(
        data_root=config.data.data_root,
        batch_size=config.data.batch_size,
        num_workers=config.data.num_workers,
        train_kwargs=config.data.get("train_kwargs", None),
        val_kwargs=config.data.get("val_kwargs", None),
        test_kwargs=config.data.get("test_kwargs", None),
        datamodule_kwargs=config.data.get("datamodule_kwargs", None),
    )

    return datamodule


def train(
    config_path: str,
    profile: bool = False,
    ckpt_path: str = None,
    validate_only: bool = False,
    inference_only: bool = False,
    output_dir: str = None,
    test_datamodule: bool = False,
    precision=32,
):
    config = _load_config(config_path)

    pl.seed_everything(config.seed, workers=True)

    if inference_only:
        config["data"]["batch_size"] = 1

    datamodule = _build_datamodule(config)

    if test_datamodule:
        for batch in tqdm(datamodule.train_dataloader()):
            pass

        for batch in tqdm(datamodule.val_dataloader()):
            pass

        for batch in tqdm(datamodule.test_dataloader()):
            pass

        return

    model = _build_model(config)
    loss_handler = _build_loss(config)

    system = EndToEndLightningSystem(
        model=model,
        loss_handler=loss_handler,
        metrics=_dummy_metrics(config),
        augmentation_handler=_dummy_augmentation(),
        inference_handler=config.get("inference", None),
        optimization_bundle=_build_optimization_bundle(config),
        fast_run=config.fast_run,
        batch_size=config.data.batch_size,
        effective_batch_size=config.data.get("effective_batch_size", None),
        commitment_weight=config.get("commitment_weight", 1.0),
    )

    rand_str = "".join(
        random.choice(string.ascii_uppercase + string.digits) for _ in range(6)
    )

    logger = pytorch_lightning.loggers.TensorBoardLogger(
        save_dir=os.path.join(
            config.trainer.logger.save_dir, os.environ.get("SLURM_JOB_ID", rand_str)
        ),
    )

    callbacks = [
        pytorch_lightning.callbacks.ModelCheckpoint(
            monitor=config.trainer.callbacks.checkpoint.monitor,
            mode=config.trainer.callbacks.checkpoint.mode,
            save_top_k=config.trainer.callbacks.checkpoint.save_top_k,
            save_last=config.trainer.callbacks.checkpoint.save_last,
        ),
        pytorch_lightning.callbacks.ModelCheckpoint(
            monitor=None,
        ),  # also save the last 3 epochs
        pytorch_lightning.callbacks.RichModelSummary(max_depth=3),
    ]

    if profile:
        profiler = AdvancedProfiler(filename="profile.txt", dirpath=".")

    if config.trainer.accumulate_grad_batches is None:
        config.trainer.accumulate_grad_batches = 1
        if config.data.effective_batch_size is not None:
            config.trainer.accumulate_grad_batches = int(
                config.data.effective_batch_size / config.data.batch_size
            )

    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        max_epochs=1 if profile else config.trainer.max_epochs,
        callbacks=callbacks,
        logger=logger,
        profiler=profiler if profile else None,
        limit_train_batches=int(8) if profile else float(1.0),
        limit_val_batches=int(8) if profile else float(1.0),
        accumulate_grad_batches=config.trainer.accumulate_grad_batches,
        precision=precision,
        gradient_clip_val=config.trainer.get("gradient_clip_val", None),
        gradient_clip_algorithm=config.trainer.get("gradient_clip_algorithm", "norm"),
    )

    if validate_only:
        trainer.validate(system, datamodule, ckpt_path=ckpt_path)
    elif inference_only:
        if output_dir is None:
            output_dir = os.path.join(
                os.path.dirname(os.path.dirname(ckpt_path)), "inference"
            )
            system.set_output_path(output_dir)
        trainer.predict(system, datamodule, ckpt_path=ckpt_path)
    else:
        trainer.logger.log_hyperparams(OmegaConf.to_object(config))
        trainer.logger.save()
        trainer.fit(system, datamodule, ckpt_path=ckpt_path)


def query_validate(config_path: str, ckpt_path: str):
    config = _load_config(config_path)

    datamodule = _build_datamodule(config)

    model = _build_model(config)
    loss_handler = _build_loss(config)

    system = EndToEndLightningSystem(
        model=model,
        loss_handler=loss_handler,
        metrics=_dummy_metrics(config),
        augmentation_handler=_dummy_augmentation(),
        inference_handler=None,
        optimization_bundle=_build_optimization_bundle(config),
        fast_run=config.fast_run,
        batch_size=config.data.batch_size,
        effective_batch_size=config.data.get("effective_batch_size", None),
        commitment_weight=config.get("commitment_weight", 1.0),
    )

    logger = pytorch_lightning.loggers.CSVLogger(
        save_dir=os.path.join(config.trainer.logger.save_dir, "validate"),
    )

    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        logger=logger,
    )

    allowed_stems = config.data.val_kwargs.get("allowed_stems", None)

    data = []

    os.makedirs(trainer.logger.log_dir, exist_ok=True)

    with open(trainer.logger.log_dir + "/config.txt", "w") as f:
        f.write(ckpt_path)

    dl = datamodule.val_dataloader()

    for stem, val_dl in zip(allowed_stems, dl):
        metrics = trainer.validate(system, val_dl, ckpt_path=ckpt_path)[0]
        print(stem)
        pprint(metrics)

        for metric, value in metrics.items():
            data.append({"metric": metric, "value": value, "stem": stem})

    df = pd.DataFrame(data)

    df.to_csv(
        os.path.join(trainer.logger.log_dir, "validation_metrics.csv"), index=False
    )


def query_test(config_path: str, ckpt_path: str):
    config = _load_config(config_path)

    pprint(config)
    pprint(config.data.inference_kwargs)

    datamodule = _build_datamodule(config)

    model = _build_model(config)
    loss_handler = _build_loss(config)

    system = EndToEndLightningSystem(
        model=model,
        loss_handler=loss_handler,
        metrics=_dummy_metrics(config),
        augmentation_handler=_dummy_augmentation(),
        inference_handler=config.data.inference_kwargs,
        optimization_bundle=_build_optimization_bundle(config),
        fast_run=config.fast_run,
        batch_size=config.data.batch_size,
        effective_batch_size=config.data.get("effective_batch_size", None),
        commitment_weight=config.get("commitment_weight", 1.0),
    )

    rand_str = "".join(
        random.choice(string.ascii_uppercase + string.digits) for _ in range(6)
    )

    use_own_query = config.data.test_kwargs.get("use_own_query", False)

    prefix = "test-o" if use_own_query else "test"

    logger = pytorch_lightning.loggers.CSVLogger(
        save_dir=os.path.join(
            config.trainer.logger.save_dir,
            prefix,
            os.environ.get("SLURM_JOB_ID", rand_str),
        ),
    )

    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        logger=logger,
    )

    os.makedirs(trainer.logger.log_dir, exist_ok=True)

    with open(trainer.logger.log_dir + "/config.txt", "w") as f:
        f.write(ckpt_path)

    trainer.logger.log_hyperparams(OmegaConf.to_object(config))
    trainer.logger.save()

    dl = datamodule.test_dataloader()

    trainer.test(system, dl, ckpt_path=ckpt_path)

def query_inference(config_path: str, ckpt_path: str):
    config = _load_config(config_path)

    pprint(config)
    pprint(config.data.inference_kwargs)

    datamodule = _build_datamodule(config)

    model = _build_model(config)
    loss_handler = _build_loss(config)

    system = EndToEndLightningSystem(
        model=model,
        loss_handler=loss_handler,
        metrics=_dummy_metrics(config),
        augmentation_handler=_dummy_augmentation(),
        inference_handler=config.data.inference_kwargs,
        optimization_bundle=_build_optimization_bundle(config),
        fast_run=config.fast_run,
        batch_size=config.data.batch_size,
        effective_batch_size=config.data.get("effective_batch_size", None),
        commitment_weight=config.get("commitment_weight", 1.0),
    )

    rand_str = "".join(
        random.choice(string.ascii_uppercase + string.digits) for _ in range(6)
    )

    use_own_query = config.data.test_kwargs.get("use_own_query", False)

    prefix = "inference-o" if use_own_query else "inference-d"

    logger = pytorch_lightning.loggers.CSVLogger(
        save_dir=os.path.join(
            config.trainer.logger.save_dir,
            prefix,
            os.environ.get("SLURM_JOB_ID", rand_str),
        ),
    )

    trainer = pl.Trainer(
        accelerator="gpu" if torch.cuda.is_available() else "cpu",
        logger=logger,
    )

    os.makedirs(trainer.logger.log_dir, exist_ok=True)

    with open(trainer.logger.log_dir + "/config.txt", "w") as f:
        f.write(ckpt_path)

    trainer.logger.log_hyperparams(OmegaConf.to_object(config))
    trainer.logger.save()

    dl = datamodule.test_dataloader()

    trainer.predict(system, dl, ckpt_path=ckpt_path)


def clean_validation_metrics(path):
    df = pd.read_csv(path).T

    data = []

    stems = [
        "drums",
        "lead_male_singer",
        "lead_female_singer",
        # "human_choir",
        "background_vocals",
        # "other_vocals",
        "bass_guitar",
        "bass_synthesizer",
        # "contrabass_double_bass",
        # "tuba",
        # "bassoon",
        "fx",
        "clean_electric_guitar",
        "distorted_electric_guitar",
        # "lap_steel_guitar_or_slide_guitar",
        "acoustic_guitar",
        "other_plucked",
        "pitched_percussion",
        "grand_piano",
        "electric_piano",
        "organ_electric_organ",
        "synth_pad",
        "synth_lead",
        # "violin",
        # "viola",
        # "cello",
        # "violin_section",
        # "viola_section",
        # "cello_section",
        "string_section",
        "other_strings",
        "brass",
        # "flutes",
        "reeds",
        "other_wind",
    ]

    for metric, value in df.iterrows():

        mm = metric.split("/")
        idx = mm[-1]
        m = "/".join(mm[:-1])

        print(metric, idx)

        try:
            idx = int(idx.split("_")[-1])
        except ValueError as e:
            assert "invalid literal for int() with base 10" in str(e)
            continue

        data.append({m: value, "stem": stems[idx]})

    df = pd.DataFrame(data)

    new_path = path.replace(".csv", "_clean.csv")

    df.to_csv(new_path, index=False)


if __name__ == "__main__":
    import fire

    fire.Fire()

ERROR: Cannot find key: --f=c:\Users\Dell\AppData\Roaming\jupyter\runtime\kernel-v3f0d43beaf9ae7787c0ec59baf0b0092a7079c985.json
Usage: ipykernel_launcher.py <group|command|value>
  available groups:      In | Out | exit | quit | os | Callable | np | torch |
                         taF | pd | band_defs | mbs | df | nn | math |
                         input_data | Dict | List | Optional | Tuple | Type |
                         activation | ta | optim | tm | Union | pl | Iterator |
                         Mapping | lr_scheduler | inspect | Sequence | data |
                         taxonomy | random | INST_BY_OCCURRENCE |
                         FINE_LEVEL_INSTRUMENTS | COARSE_LEVEL_INSTRUMENTS |
                         COARSE_TO_FINE | FINE_TO_COARSE |
                         ALL_LEVEL_INSTRUMENTS | audiomentations | glob |
                         json | shutil | librosa | fine_to_coarse | v |
                         possible_coarse | possible_fine | fire | string |
           

FireExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
