In [1]:
import os
import sys

# Locate the repository root: walk up from the current working directory until a
# directory containing an entry named "src" is found, then make that directory
# importable so later cells can use ``from src...`` imports.
project_root = os.getcwd()
while "src" not in os.listdir(project_root):
    parent_dir = os.path.dirname(project_root)
    # os.path.dirname is a fixed point at the filesystem root ("/" -> "/"), so
    # without this check the loop would spin forever when "src" is never found.
    if parent_dir == project_root:
        raise FileNotFoundError(
            f"No se encontró la carpeta src en {os.getcwd()} ni en sus directorios padre."
        )
    project_root = parent_dir

# Avoid adding a duplicate sys.path entry every time this cell is re-run.
if project_root not in sys.path:
    sys.path.append(project_root)

In [None]:
import os
import torch
import torchaudio
import numpy as np
from typing import List, Dict, Any, Tuple, Callable
import matplotlib.pyplot as plt
import random
import librosa

# Notebook-wide matplotlib defaults: every new figure is 20 x 6 inches unless a
# cell overrides it explicitly.
plt.rcParams.update({"figure.figsize": [20, 6]})

In [None]:
# Directory holding the processed training archives. Per the error message
# below, it is expected to be created by download_data.sh.
data_root = os.path.join(project_root, "data", "processed", "train")

# isdir rather than exists: a stray *file* at this path would otherwise pass the
# check and make the later os.listdir(data_root) fail with NotADirectoryError.
if not os.path.isdir(data_root):
    raise FileNotFoundError(
        "No se encontró la carpeta data, por favor ejecute el script download_data.sh antes de ejecutar este script."
    )

In [None]:
from src.utils.audio import AudioSample
from src.utils.data.dataset import AudioDataset
from src.utils.data.transforms import TransformChain, FreqSplit, STFTTransform

# Number of .npz archives to load from ``data_root``. The notebook only needs one
# archive for exploration; set to None to load every archive.
MAX_NPZ_FILES = 1

# Sort so the chosen archive is deterministic (os.listdir order is arbitrary),
# and skip entries that are not .npz archives (e.g. .gitkeep), which np.load
# would fail on.
npz_names = sorted(name for name in os.listdir(data_root) if name.endswith(".npz"))
if MAX_NPZ_FILES is not None:
    npz_names = npz_names[:MAX_NPZ_FILES]

samples: List[AudioSample] = []
for file in npz_names:
    npz_file = os.path.join(data_root, file)
    # NOTE(review): allow_pickle=True unpickles arbitrary objects, which can run
    # code on load — only use it with archives produced by this project.
    # The ``with`` block closes the NpzFile's underlying file handle.
    with np.load(npz_file, allow_pickle=True) as data:
        for key in data.files:
            # ``.item()`` unwraps each stored entry into the Python object passed
            # to AudioSample — presumably a per-sample dict; TODO confirm.
            samples.append(AudioSample(data[key].item()))

assert samples, f"No se cargó ninguna muestra desde {data_root}"

In [None]:
def convert_to_tensor(sample: AudioSample) -> Dict[str, torch.Tensor]:
    """Return ``sample.to_tensor()``.

    Written as a ``def`` rather than a lambda bound to a name (PEP 8, E731) so it
    carries a real ``__name__`` in tracebacks and in reprs of the transform chain.
    """
    return sample.to_tensor()


# Transform pipeline for each dataset item: tensor conversion, then STFT
# (assumes TransformChain applies transforms in insertion order — TODO confirm).
tc = TransformChain()
tc.add_transform(convert_to_tensor)
tc.add_transform(STFTTransform())

dataset = AudioDataset(samples, tc)

In [None]:
from src.utils.training.loss import sdr_loss, si_sdr_loss, snr_loss
from src.utils.training.trainer import Trainer
from src.models import SingleChannelUNet

# Seed every RNG the notebook imports before building the model so that weight
# initialisation (presumably random for this network — TODO confirm) and any
# later sampling are reproducible under Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

model = SingleChannelUNet()