查看模型的内容，参数量

In [None]:
from utils import demix, get_model_from_config
import json

model_type = 'mel_band_roformer'
config_path = 'configs/KimberleyJensen/config_musdb18_vocals_mel_band_roformer_kj_with_augue.yaml'
model, config = get_model_from_config(model_type, config_path)

# 计算模型的总参数量
total_params = sum(param.numel() for param in model.state_dict().values())

# 打印总参数量，单位是个数
print(f"Total number of parameters: {total_params}, {total_params/1e6:.2f}M")

# 如果需要显示以MB为单位的参数量
params_in_mb = total_params * 4 / 1e6  # 假设每个参数占 4 字节（float32）
print(f"Total number of parameters (in MB): {params_in_mb:.2f} MB")
print(json.dumps(list(model.state_dict().keys())[:40], indent=4))

In [None]:
# 获取模型的参数字典
state_dict = model.state_dict()

# 计算并显示每个参数的大小和参数量
parameter_info = []
for name, param in state_dict.items():
    num_params = param.numel()  # 获取参数的数量（numel 返回参数的元素数量）
    parameter_info.append({
        'name': name,
        'shape': list(param.shape),
        'num_params': num_params
    })

# 打印前40个参数的名称、形状和数量
print(json.dumps(parameter_info[:40], indent=4))

[
    {
        "name": "layers.0.0.layers.0.0.rotary_embed.freqs",
        "shape": [
            32
        ],
        "num_params": 32
    },
    {
        "name": "layers.0.0.layers.0.0.norm.gamma",
        "shape": [
            384
        ],
        "num_params": 384
    },
    {
        "name": "layers.0.0.layers.0.0.to_qkv.weight",
        "shape": [
            1536,
            384
        ],
        "num_params": 589824
    },
    {
        "name": "layers.0.0.layers.0.0.to_gates.weight",
        "shape": [
            8,
            384
        ],
        "num_params": 3072
    },
    {
        "name": "layers.0.0.layers.0.0.to_gates.bias",
        "shape": [
            8
        ],
        "num_params": 8
    },
    {
        "name": "layers.0.0.layers.0.0.to_out.0.weight",
        "shape": [
            384,
            512
        ],
        "num_params": 196608
    },
    {
        "name": "layers.0.0.layers.0.1.net.0.gamma",
        "shape": [
            384
       

In [None]:
from collections import defaultdict
# 获取模型的参数字典
state_dict = model.state_dict()

# 用 defaultdict 按层（前缀）分组
layer_params = defaultdict(int)  # 默认每组的参数量为 0

# 将每个参数按其层（前缀）进行分组，并计算每组的参数量
for name, param in state_dict.items():
    # 获取层的前缀部分（例如 layers.0.0, layers.0.1 等）
    prefix_parts = name.split('.')[:3]  # 提取前三部分作为层的前缀
    layer_prefix = '.'.join(prefix_parts)  # 拼接成前缀，如 layers.0.0
    num_params = param.numel()  # 获取该参数的元素数量
    
    # 累加每一层的参数量
    layer_params[layer_prefix] += num_params

# 将每个层的参数信息整理成字典形式
# layer_info = [{'layer': layer, 'num_params': num_params} for layer, num_params in layer_params.items()]
layer_info = [{'layer': layer, 'num_params (MB)': num_params / 1e6} for layer, num_params in layer_params.items()]

# 打印前40个层的名称和参数量
print(json.dumps(layer_info[:40], indent=4))

[
    {
        "layer": "layers.0.0",
        "num_params (MB)": 1.972264
    },
    {
        "layer": "layers.0.1",
        "num_params (MB)": 1.972264
    },
    {
        "layer": "layers.1.0",
        "num_params (MB)": 1.972264
    },
    {
        "layer": "layers.1.1",
        "num_params (MB)": 1.972264
    },
    {
        "layer": "layers.2.0",
        "num_params (MB)": 1.972264
    },
    {
        "layer": "layers.2.1",
        "num_params (MB)": 1.972264
    },
    {
        "layer": "layers.3.0",
        "num_params (MB)": 1.972264
    },
    {
        "layer": "layers.3.1",
        "num_params (MB)": 1.972264
    },
    {
        "layer": "layers.4.0",
        "num_params (MB)": 1.972264
    },
    {
        "layer": "layers.4.1",
        "num_params (MB)": 1.972264
    },
    {
        "layer": "layers.5.0",
        "num_params (MB)": 1.972264
    },
    {
        "layer": "layers.5.1",
        "num_params (MB)": 1.972264
    },
    {
        "layer": "band_split.to_

In [1]:
from rotary_embedding_torch import RotaryEmbedding
rope = RotaryEmbedding(64)

theta: 10000.0


查看Melformer mel band得到的内容是什么

In [2]:
print(model.freq_per_bands_with_complex)

(28, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 28, 28, 28, 36, 36, 36, 40, 40, 44, 52, 52, 52, 60, 64, 68, 76, 80, 80, 88, 96, 104, 112, 116, 124, 132, 144, 156, 164, 176, 188, 200, 216, 228, 244, 264, 284, 304, 320, 344, 372, 396, 420, 452, 488, 520)


In [8]:
from librosa import filters
import torch
from einops import rearrange, pack, unpack, reduce, repeat
from einops.layers.torch import Rearrange
# create mel filter bank
# with librosa.filters.mel as in section 2 of paper
sample_rate = 44100
stft_n_fft = 2048
num_bands = 60
stft_hop_length = 441
stft_win_length = 2048
stft_normalized = False
stereo = True
audio_channels = 2 if stereo else 1

stft_kwargs = dict(
    n_fft=stft_n_fft,
    hop_length=stft_hop_length,
    win_length=stft_win_length,
    normalized=stft_normalized
)
# 首先这里计算得到的频率数量是 1025，因为 n_fft 是 2048，所以频率数量是 n_fft // 2 + 1
freqs = torch.stft(torch.randn(1, 4096), **stft_kwargs, window=torch.ones(stft_n_fft), return_complex=True).shape[1]

# 得到的 mel_filter_bank_numpy 的 shape 是 (num_bands, freqs)
mel_filter_bank_numpy = filters.mel(sr=sample_rate, n_fft=stft_n_fft, n_mels=num_bands)

mel_filter_bank = torch.from_numpy(mel_filter_bank_numpy)

# for some reason, it doesn't include the first freq? just force a value for now
# 将 mel_filter_bank 的第一行第一列的值设为 1，因为 mel_filter_bank 的第一行第一列的值是 0，这里强制设为 1，确保所有频率都被所有频带覆盖
mel_filter_bank[0][0] = 1.

# In some systems/envs we get 0.0 instead of ~1.9e-18 in the last position,
# so let's force a positive value
# 将 mel_filter_bank 的最后一行最后一列的值设为 1，确保所有频率都被所有频带覆盖
mel_filter_bank[-1, -1] = 1.

# binary as in paper (then estimated masks are averaged for overlapping regions)
# 得到的 freqs_per_band 的 shape 是 (num_bands, freqs)，其中True表示该频率在该频带中，False表示不在
freqs_per_band = mel_filter_bank > 0
print(f"freqs_per_band: {freqs_per_band.shape}, {freqs_per_band}")
assert freqs_per_band.any(dim=0).all(), 'all frequencies need to be covered by all bands for now'
# 将频率索引的序列（0, 1, 2, ..., 1024）复制 num_bands 次
repeated_freq_indices = repeat(torch.arange(freqs), 'f -> b f', b=num_bands)
print(f"repeated_freq_indices: {repeated_freq_indices.shape}, {repeated_freq_indices}")
# 然后根据 mel_filter_bank 的值进行筛选，得到每个频带对应的频率索引，输出为True的索引，False的索引被过滤掉
# 布尔索引会将所有标记为 True 的频率索引提取到一个一维张量中，
freq_indices = repeated_freq_indices[freqs_per_band]
print(f"1:freq_indices: {freq_indices.shape}, {freq_indices}")

if stereo: # 将索引的序列复制一次，然后每个索引乘以2，再加上0和1，最后将所有索引展平
    freq_indices = repeat(freq_indices, 'f -> f s', s=2)
    print(f"2:freq_indices: {freq_indices.shape}, {freq_indices}")
    freq_indices = freq_indices * 2 + torch.arange(2)
    print(f"3:freq_indices: {freq_indices.shape}, {freq_indices}")
    freq_indices = rearrange(freq_indices, 'f s -> (f s)')
    print(f"4:freq_indices: {freq_indices.shape}, {freq_indices}")

# 求True的数量，即每个频带中的频率数量
num_freqs_per_band = reduce(freqs_per_band, 'b f -> b', 'sum')
print(f"num_freqs_per_band: {num_freqs_per_band.shape}, {num_freqs_per_band}")
# 求每个频率对应的频带数量
num_bands_per_freq = reduce(freqs_per_band, 'b f -> f', 'sum')
print(f"num_bands_per_freq: {num_bands_per_freq.shape}, {num_bands_per_freq}")

# band split and mask estimator

freqs_per_bands_with_complex = tuple(2 * f * audio_channels for f in num_freqs_per_band.tolist())

print(f"""freqs: {freqs} 
mel_filter_bank_numpy: {mel_filter_bank_numpy.shape} ,content: {mel_filter_bank_numpy} 
freqs_per_band: {freqs_per_band.shape} ,content: {freqs_per_band} 
freq_indices: {freq_indices.shape} ,content: {freq_indices} 
num_freqs_per_band: {num_freqs_per_band.shape} ,content: {num_freqs_per_band} 
freqs_per_bands_with_complex: {freqs_per_bands_with_complex}
""")


freqs_per_band: torch.Size([60, 1025]), tensor([[ True,  True,  True,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        ...,
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ..., False, False, False],
        [False, False, False,  ...,  True,  True,  True]])
repeated_freq_indices: torch.Size([60, 1025]), tensor([[   0,    1,    2,  ..., 1022, 1023, 1024],
        [   0,    1,    2,  ..., 1022, 1023, 1024],
        [   0,    1,    2,  ..., 1022, 1023, 1024],
        ...,
        [   0,    1,    2,  ..., 1022, 1023, 1024],
        [   0,    1,    2,  ..., 1022, 1023, 1024],
        [   0,    1,    2,  ..., 1022, 1023, 1024]])
1:freq_indices: torch.Size([1979]), tensor([   0,    1,    2,  ..., 1022, 1023, 1024])
2:freq_indices: torch.Size([1979, 2]), tensor([[   0,    0],
        [   1,    1],
        [   2,    2],
        ...,
        [1022, 1022]

In [None]:
# 最终实现Mel频带分解，主要是先将音频得到的STFT结果进行频率索引的筛选，然后再进行频带的合并
# to stft
from einops import rearrange, pack, unpack, reduce, repeat
from functools import partial
def exists(val):
    return val is not None


def default(v, d):
    return v if exists(v) else d


def pack_one(t, pattern):
    return pack([t], pattern)


def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]


raw_audio = torch.randn(20, 2, 44100) # [batch_size, num_channels, raw_audio_length]
batch, channels, raw_audio_length = raw_audio.shape
raw_audio, batch_audio_channel_packed_shape = pack_one(raw_audio, '* t') # [batch_size*num_channels, raw_audio_length]
print(f"raw_audio: {raw_audio.shape}, batch_audio_channel_packed_shape: {batch_audio_channel_packed_shape}")

stft_window_fn = partial(torch.hann_window, stft_win_length)
device = raw_audio.device

stft_kwargs = dict(
    n_fft=stft_n_fft,
    hop_length=stft_hop_length,
    win_length=stft_win_length,
    normalized=stft_normalized
)
# num_freqs = stft_n_fft // 2 + 1 num_frames = 1 + L // hop_length for center=True, or 1 + (L - n_fft) // hop_length for center=False
# center = True, pad_mode = 'reflect',这时，会在两边pad n_fft // 2，所以num_frames = 1 + L // hop_length
stft_window = stft_window_fn(device=device)
stft_repr = torch.stft(raw_audio, **stft_kwargs, window=stft_window, return_complex=True) # [batch_size*num_channels, num_freqs, num_frames]
print(f"stft_repr: {stft_repr.shape}")
stft_repr = torch.view_as_real(stft_repr)# [batch_size*num_channels, num_freqs, num_frames, 2]
print(f"stft_repr: {stft_repr.shape}")

stft_repr = unpack_one(stft_repr, batch_audio_channel_packed_shape, '* f t c') # [batch_size, num_channels, num_freqs, num_frames, 2]
print(f"stft_repr: {stft_repr.shape}")
stft_repr = rearrange(stft_repr,
                        'b s f t c -> b (f s) t c')  # merge stereo / mono into the frequency, with frequency leading dimension, for band splitting
print(f"stft_repr: {stft_repr.shape}") # [batch_size, num_freqs*num_channels, num_frames, 2]

# index out all frequencies for all frequency ranges across bands ascending in one go

batch_arange = torch.arange(batch, device=device)[..., None] # [batch_size, 1]
print(f"batch_arange: {batch_arange.shape}")
# account for stereo

x = stft_repr[batch_arange, freq_indices] # 根据freq_indices 的索引，将每个音频对应的STFT结果进行频率索引的筛选，按照mel freq的索引进行扩充拉平
# [batch_size, 1+sum(mel_freqs), num_frames, 2]
print(f"x: {x.shape}, {x}")

raw_audio: torch.Size([40, 44100]), batch_audio_channel_packed_shape: [torch.Size([20, 2])]
num_freqs=1025, num_frames=96, num_frames_pad = 101
stft_repr: torch.Size([40, 1025, 101])
stft_repr: torch.Size([40, 1025, 101, 2])
stft_repr: torch.Size([20, 2, 1025, 101, 2])
stft_repr: torch.Size([20, 2050, 101, 2])
batch_arange: torch.Size([20, 1])
x: torch.Size([20, 3958, 101, 2]), tensor([[[[ 1.9414e+01,  0.0000e+00],
          [ 1.6111e+01,  0.0000e+00],
          [-1.2920e+01,  0.0000e+00],
          ...,
          [ 1.7361e+01,  0.0000e+00],
          [ 2.1856e+00,  0.0000e+00],
          [-2.2472e+01,  0.0000e+00]],

         [[ 7.0733e+01,  0.0000e+00],
          [ 4.0397e+01,  0.0000e+00],
          [-2.1741e+01,  0.0000e+00],
          ...,
          [ 5.3809e+01,  0.0000e+00],
          [ 2.7783e+01,  0.0000e+00],
          [ 1.3805e+01,  0.0000e+00]],

         [[-5.3832e+00,  0.0000e+00],
          [-1.6903e+01,  2.5503e-01],
          [ 3.4481e+00, -3.4564e+01],
          ...,


In [1]:
from utils import demix, get_model_from_config
import json

model_type = 'mel_band_roformer'
config_path = 'configs/KimberleyJensen/config_musdb18_vocals_mel_band_roformer_kj_with_augue.yaml'
model, config = get_model_from_config(model_type, config_path)

import torch
x = torch.randn(2, 2, 4096, dtype=torch.float32)
y = torch.randn(2, 2, 4096, dtype=torch.float32)
loss = model(x, y)

  from .autonotebook import tqdm as notebook_tqdm


GPU Compute Capability equal or above 8.0, using flash attention if input tensor is on cuda
28 384
24 384
24 384
24 384
24 384
24 384
24 384
24 384
24 384
24 384
24 384
24 384
24 384
24 384
24 384
28 384
28 384
28 384
36 384
36 384
36 384
40 384
40 384
44 384
52 384
52 384
52 384
60 384
64 384
68 384
76 384
80 384
80 384
88 384
96 384
104 384
112 384
116 384
124 384
132 384
144 384
156 384
164 384
176 384
188 384
200 384
216 384
228 384
244 384
264 384
284 384
304 384
320 384
344 384
372 384
396 384
420 384
452 384
488 384
520 384




torch.Size([2, 10, 60, 384])
torch.Size([2, 10, 384]) Sequential(
  (0): Sequential(
    (0): Linear(in_features=384, out_features=1536, bias=True)
    (1): Tanh()
    (2): Linear(in_features=1536, out_features=1536, bias=True)
    (3): Tanh()
    (4): Linear(in_features=1536, out_features=56, bias=True)
  )
  (1): GLU(dim=-1)
)
torch.Size([2, 10, 384]) Sequential(
  (0): Sequential(
    (0): Linear(in_features=384, out_features=1536, bias=True)
    (1): Tanh()
    (2): Linear(in_features=1536, out_features=1536, bias=True)
    (3): Tanh()
    (4): Linear(in_features=1536, out_features=48, bias=True)
  )
  (1): GLU(dim=-1)
)
torch.Size([2, 10, 384]) Sequential(
  (0): Sequential(
    (0): Linear(in_features=384, out_features=1536, bias=True)
    (1): Tanh()
    (2): Linear(in_features=1536, out_features=1536, bias=True)
    (3): Tanh()
    (4): Linear(in_features=1536, out_features=48, bias=True)
  )
  (1): GLU(dim=-1)
)
torch.Size([2, 10, 384]) Sequential(
  (0): Sequential(
    (0): 

# 检查 Mel band llama

In [1]:
from utils import demix, get_model_from_config
import json

model_type = 'mel_band_llama'
config_path = 'configs/KimberleyJensen/config_musdb18_vocals_mel_band_roformer_kj_with_augue.yaml'
model, config = get_model_from_config(model_type, config_path)

# 计算模型的总参数量
total_params = sum(param.numel() for param in model.state_dict().values())

# 打印总参数量，单位是个数
print(f"Total number of parameters: {total_params}, {total_params/1e6:.2f}M")

# 如果需要显示以MB为单位的参数量
params_in_mb = total_params * 4 / 1e6  # 假设每个参数占 4 字节（float32）
print(f"Total number of parameters (in MB): {params_in_mb:.2f} MB")
print(json.dumps(list(model.state_dict().keys())[:40], indent=4))

  from .autonotebook import tqdm as notebook_tqdm


eager
Total number of parameters: 232856772, 232.86M
Total number of parameters (in MB): 931.43 MB
[
    "layers.0.0.self_attn.q_proj.weight",
    "layers.0.0.self_attn.k_proj.weight",
    "layers.0.0.self_attn.v_proj.weight",
    "layers.0.0.self_attn.o_proj.weight",
    "layers.0.0.mlp.gate_proj.weight",
    "layers.0.0.mlp.up_proj.weight",
    "layers.0.0.mlp.down_proj.weight",
    "layers.0.0.input_layernorm.weight",
    "layers.0.0.post_attention_layernorm.weight",
    "layers.0.1.self_attn.q_proj.weight",
    "layers.0.1.self_attn.k_proj.weight",
    "layers.0.1.self_attn.v_proj.weight",
    "layers.0.1.self_attn.o_proj.weight",
    "layers.0.1.mlp.gate_proj.weight",
    "layers.0.1.mlp.up_proj.weight",
    "layers.0.1.mlp.down_proj.weight",
    "layers.0.1.input_layernorm.weight",
    "layers.0.1.post_attention_layernorm.weight",
    "layers.1.0.self_attn.q_proj.weight",
    "layers.1.0.self_attn.k_proj.weight",
    "layers.1.0.self_attn.v_proj.weight",
    "layers.1.0.self_attn

In [2]:
import torch
x = torch.randn(2, 2, 4096, dtype=torch.float32)
y = torch.randn(2, 2, 4096, dtype=torch.float32)
loss = model(x, y)

In [None]:
for i, transformer_block in enumerate(model.layers):
    
    if len(transformer_block) == 3:
        linear_transformer, time_transformer, freq_transformer = transformer_block

        x, ft_ps = pack([x], 'b * d')

        x = linear_transformer(x)
        x, = unpack(x, ft_ps, 'b * d')
    else:
        time_transformer, freq_transformer = transformer_block
        print(f"i: {i}, time_transformer: \n {time_transformer}, freq_transformer:  \n {freq_transformer}")
        break

In [1]:
from transformers.utils import is_flash_attn_2_available
print(is_flash_attn_2_available())

  from .autonotebook import tqdm as notebook_tqdm


True


# 检查两个的rope embedding 是否能够相同

In [9]:
import torch

seq_length = 352800
t = torch.ones(1, 1, seq_length, 128, dtype=torch.float32)

from rotary_embedding_torch import RotaryEmbedding
mel_band_roformer_rope_embedding = RotaryEmbedding(48)

device, dtype, seq_len = t.device, t.dtype, t.shape[-2]

seq = mel_band_roformer_rope_embedding.get_seq_pos(seq_len, device = device, dtype = dtype, offset = 0)

freqs = mel_band_roformer_rope_embedding.forward(seq, seq_len = seq_len, offset = 0)
        
from transformers import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding
config = LlamaConfig(
    hidden_size=384,
    num_attention_heads=8,
    intermediate_size=384 * 4,  # Llama 默认前馈网络扩展为 4 倍
    rms_norm_eps=1e-6,
    hidden_dropout_prob=0.1,
    attention_dropout_prob=0.1,
    rope_theta=10000.0  # 默认值，可调整
)
mel_band_llama_rotary_emb = LlamaRotaryEmbedding(config=config)

def get_pos_ids(seq_length, device):
    return torch.arange(seq_length, device=device).unsqueeze(0)

pos_ids = get_pos_ids(seq_length, 'cpu')
mel_band_llama_embedding = mel_band_llama_rotary_emb(t, pos_ids)[0].squeeze(0)

print(mel_band_llama_embedding.shape, freqs.shape)
assert mel_band_llama_embedding.shape == freqs.shape
print(mel_band_llama_embedding, freqs)

torch.Size([352800, 48]) torch.Size([352800, 48])
tensor([[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
        [ 0.5403,  0.7768,  0.8942,  ...,  1.0000,  1.0000,  1.0000],
        [-0.4161,  0.2067,  0.5992,  ...,  1.0000,  1.0000,  1.0000],
        ...,
        [-0.7561,  0.6777,  0.0868,  ...,  0.0377,  0.8199,  0.0527],
        [-0.9592,  0.0727, -0.3726,  ...,  0.0380,  0.8198,  0.0526],
        [-0.2804, -0.5767, -0.7517,  ...,  0.0383,  0.8196,  0.0524]]) tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00, 0.0000e+00,
         0.0000e+00],
        [1.0000e+00, 1.0000e+00, 6.8129e-01,  ..., 2.1544e-04, 1.4678e-04,
         1.4678e-04],
        [2.0000e+00, 2.0000e+00, 1.3626e+00,  ..., 4.3089e-04, 2.9356e-04,
         2.9356e-04],
        ...,
        [3.5280e+05, 3.5280e+05, 2.4036e+05,  ..., 7.6008e+01, 5.1784e+01,
         5.1784e+01],
        [3.5280e+05, 3.5280e+05, 2.4036e+05,  ..., 7.6008e+01, 5.1784e+01,
         5.1784e+01],
        [3.5280e+0

# 查看split 和 unbind

In [10]:
import torch
from torch import nn
from einops import rearrange, repeat

dim_inputs = (16,128,256)
mlp_list = nn.ModuleList()
for dim_in in dim_inputs:
    mlp_list.append(nn.Linear(dim_in, dim_in))
f = 16+128+256
f = f // 2
b = 2
t = 4096
c = 2
x = torch.randn(b, f, t, c, 128, dtype=torch.float32)
# [b f t c]
x = rearrange(x, 'b f t c d -> b t (f c) d')
print(f"x: {x.shape}")
x = x.split(dim_inputs, dim=-2)
print(f"x: {x[0].shape}, {x[1].shape}, {x[2].shape}")
x = torch.cat(x, dim=-2)
print(f"x: {x.shape}")
x = x.unbind(dim=-2)
for band_feature, mlp in zip(x, mlp_list):
    print(band_feature.shape)

x: torch.Size([2, 4096, 400, 128])
x: torch.Size([2, 4096, 16, 128]), torch.Size([2, 4096, 128, 128]), torch.Size([2, 4096, 256, 128])
x: torch.Size([2, 4096, 400, 128])
torch.Size([2, 4096, 128])
torch.Size([2, 4096, 128])
torch.Size([2, 4096, 128])


# 查看重复indices的操作

In [8]:
import torch
from einops import repeat

# need to average the estimated mask for the overlapped frequencies
masks = [[[0.1, 0.2, 0.3, 0.4, 0.5, 0.6]]]
masks = torch.tensor(masks, dtype=torch.float32)
freq_indices = torch.tensor([1,2,3,2,3,4], dtype=torch.long)
batch = 2
num_stems = 1
stft_repr = torch.randn(batch, num_stems, 2, dtype=torch.float32)
scatter_indices = repeat(freq_indices, 'f -> b n f t', b=batch, n=num_stems, t=stft_repr.shape[-1])
print(f"scatter_indices: {scatter_indices.shape}, {scatter_indices}")
stft_repr_expanded_stems = repeat(stft_repr, 'b 1 ... -> b n ...', n=num_stems)
masks_summed = torch.zeros_like(stft_repr_expanded_stems).scatter_add_(2, scatter_indices, masks)


scatter_indices: torch.Size([2, 1, 6, 2]), tensor([[[[1, 1],
          [2, 2],
          [3, 3],
          [2, 2],
          [3, 3],
          [4, 4]]],


        [[[1, 1],
          [2, 2],
          [3, 3],
          [2, 2],
          [3, 3],
          [4, 4]]]])


RuntimeError: Index tensor must have the same number of dimensions as self tensor