In [None]:
import torch
import torchaudio
from mamba_ssm import Mamba
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
import pandas as pd
import numpy as np
import os
import IPython.display as ipd
import librosa

In [None]:
# Number of .wav clips in the dataset — same `.endswith('.wav')` filter the DataLoader cell applies,
# so this count matches what is actually split into train/test
sum(1 for f in os.listdir('data/youtubemix/piano') if f.endswith('.wav'))

In [None]:
# Listen to a single raw clip from the dataset (file 0, "Raindrop")
fname = 'data/youtubemix/piano/0.wav'
ipd.Audio(filename=fname)

In [None]:
# Decode the same clip, resampled to 96 kHz; the returned sample rate is discarded
data, _ = librosa.load(path='data/youtubemix/piano/0.wav', sr=96000, res_type='kaiser_fast')

In [None]:
# Peek at 10 consecutive samples ~1 s into the clip (100 000 samples at the 96 kHz load rate)
data[100_000:100_010]

In [None]:
# Shape of the decoded waveform array
np.shape(data)

In [None]:
# DataGenerator class to load .wav files with librosa and return a PyTorch tensor on a target device
class DataGenerator(torch.utils.data.Dataset):
    """Dataset that decodes one audio file per item with librosa.

    Args:
        files: sequence of paths to audio files; item ``idx`` decodes ``files[idx]``.
        sr: sample rate requested from ``librosa.core.load``.
        device: device each returned tensor is moved to. Defaults to ``"cuda"`` (the
            previously hard-coded behaviour). Pass ``"cpu"`` when using
            ``DataLoader(num_workers>0)`` or ``pin_memory=True``: CUDA cannot be initialised
            inside forked worker processes, and only CPU tensors can be pinned.

    Returns (``__getitem__``):
        A float32 tensor of the decoded samples on ``device``.
    """

    def __init__(self, files, sr=96000, device="cuda"):
        self.files = files
        self.sr = sr
        self.device = device

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        # NOTE(review): in librosa >= 0.10 the 'kaiser_*' resamplers come from the optional
        # resampy package — confirm it is installed in the environment.
        data, _ = librosa.core.load(self.files[idx], sr=self.sr, res_type='kaiser_fast')
        return torch.tensor(data, dtype=torch.float32).to(self.device)

In [None]:
# DataLoader to load torch tensors in batches
def get_dataloader(files, batch_size=8, shuffle=True):
    """Wrap ``files`` in a DataGenerator and return a DataLoader over it.

    Args:
        files: list of paths to .wav files.
        batch_size: number of clips per batch.
        shuffle: reshuffle clip order every epoch. Pass ``False`` for evaluation so the
            order of collected predictions is stable across runs.

    NOTE(review): the default collate function stacks the clips of a batch into one
    tensor, so every clip must decode to the same number of samples — confirm the
    dataset guarantees this (otherwise a padding ``collate_fn`` is needed).
    """
    dataset = DataGenerator(files)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


# Load all .wav files in the directory. Sort first so the split does not depend on the
# arbitrary order of os.listdir, and seed the split so it is reproducible across runs.
DATA_DIR = 'data/youtubemix/piano'
SPLIT_SEED = 42
files = sorted(os.path.join(DATA_DIR, f) for f in os.listdir(DATA_DIR) if f.endswith('.wav'))
train_files, test_files = train_test_split(files, test_size=0.2, random_state=SPLIT_SEED)

train_loader = get_dataloader(train_files)
# Fixed order for evaluation: y_pred[i] then refers to the same test clip on every run
test_loader = get_dataloader(test_files, shuffle=False)

In [None]:
# Define the model: a single Mamba block applied to a one-value-per-step waveform sequence.
# The upstream "~3 * expand * d_model**2 parameters" estimate is a large-d_model approximation;
# with d_model=1 the d_state/d_conv-sized terms account for most of the (few) parameters.
model = Mamba(
    d_model=1,   # model dimension: one scalar sample per time step
    d_state=16,  # SSM state expansion factor
    d_conv=4,    # local convolution width
    expand=2,    # block expansion factor
)
model = model.to("cuda")

In [None]:
# Adam over every parameter of the Mamba block, learning rate 1e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
NUM_EPOCHS: int = 10  # number of full passes over train_loader

In [None]:
# Train the model to reconstruct its input waveform (the loss target is the input itself)
# and report the mean per-batch MSE for each epoch.
# NOTE(review): with target == input the loss can be minimised by an identity map; if the
# intended task is next-sample prediction, shift the target by one step — confirm.
for epoch in range(NUM_EPOCHS):
    model.train()
    total_loss = 0.0
    for data in train_loader:
        # (batch, samples) -> (batch, samples, 1): trailing axis matches d_model=1
        data = data.reshape(-1, data.shape[1], 1)
        optimizer.zero_grad()
        y = model(data)
        loss = torch.mean((y - data) ** 2)
        # Without backward() no gradients are populated, so step() cannot update any weights
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Mean over batches so the number is comparable regardless of dataset size
    mean_loss = total_loss / max(len(train_loader), 1)
    print(f"Epoch {epoch}: mean loss = {mean_loss:.6f}")

In [None]:
# Predict the output for the test data and score it with the same MSE used during training
model.eval()
with torch.no_grad():
    y_pred = []
    y_true = []
    for data in test_loader:
        data = data.reshape(-1, data.shape[1], 1)
        y = model(data)
        # no_grad() already disables graph tracking, so no .detach() is needed before copying to host
        y_pred.append(y.cpu().numpy())
        y_true.append(data.cpu().numpy())
    # Concatenating along the batch axis requires every test clip to have the same length
    y_pred = np.concatenate(y_pred)
    y_true = np.concatenate(y_true)

test_mse = float(np.mean((y_pred - y_true) ** 2))
test_mse

In [None]:
# Save the model weights (state_dict only). torch.save does not create missing parent
# directories, so make sure the output directory exists first.
MODEL_PATH = 'models/audio_mamba_model.pth'
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)

In [None]:
# Load model: rebuild the architecture with exactly the hyperparameters used for training
# (they must match the saved state_dict's tensor shapes), then restore the weights.
# torch.load unpickles the file, so only load checkpoints you produced yourself.
model_loaded = Mamba(d_model=1, d_state=16, d_conv=4, expand=2).to("cuda")
model_loaded.load_state_dict(
    torch.load('models/audio_mamba_model.pth')
)

In [None]:
# Predict the test data with the reloaded model and score it; if the checkpoint round-trip
# worked, this should reproduce the pre-save model's results.
model_loaded.eval()
with torch.no_grad():
    y_pred = []
    y_true = []
    for data in test_loader:
        data = data.reshape(-1, data.shape[1], 1)
        y = model_loaded(data)
        # no_grad() already disables graph tracking, so no .detach() is needed before copying to host
        y_pred.append(y.cpu().numpy())
        y_true.append(data.cpu().numpy())
    # Concatenating along the batch axis requires every test clip to have the same length
    y_pred = np.concatenate(y_pred)
    y_true = np.concatenate(y_true)

loaded_test_mse = float(np.mean((y_pred - y_true) ** 2))
loaded_test_mse

In [None]:
# Layout of a single predicted clip
np.shape(y_pred[0])

In [None]:
# Write the first test prediction to disk. y_pred[0] keeps the (samples, 1) layout of the
# reshaped model input — time first, channel last — so tell torchaudio it is NOT channels-first;
# with the default it would be read as `samples` channels of a single frame.
torchaudio.save(
    'data/youtubemix/output.wav',
    torch.tensor(y_pred[0], dtype=torch.float32),
    test_loader.dataset.sr,  # write at the same rate the clips were loaded at
    channels_first=False,
)

In [None]:
# Play back the reconstructed clip
ipd.Audio(filename='data/youtubemix/output.wav')