In [None]:
import sys
from pathlib import Path

# Put the parent of the current working directory on the import path so that
# project packages can be imported from this notebook (presumably this is where
# the `Music_MMLS` package lives — confirm against the repository layout).
notebook_path = Path().resolve()
project_root = notebook_path.parent
# Guard the append: re-running this cell must not pile up duplicate entries.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))


In [None]:
import os
import torch
from torch.utils.data import DataLoader
from torch import nn, optim
import wandb

from Music_MMLS.data.dataset import Music_Dataset
from Music_MMLS.models.unet import UNet
from Music_MMLS.training.train import Trainer


In [2]:
# Directories scanned for the clean and the noise `.wav` files, relative to the
# notebook's working directory.
clean_path = '../content/sample_data/Data/all_records/'
noise_path = '../content/sample_data/Data/noise/'


def _list_wav_files(directory):
    """Return paths of the `.wav` files directly inside `directory`, sorted by name.

    Args:
        directory: Directory to scan (not recursive).

    Returns:
        A list of `os.path.join(directory, name)` strings, one per entry whose
        name ends with `.wav` (case-sensitive).

    Raises:
        OSError: Propagated from `os.listdir` if `directory` cannot be read.
    """
    # os.listdir returns entries in arbitrary order; sorting makes the file
    # lists identical across machines and runs.
    return [os.path.join(directory, f) for f in sorted(os.listdir(directory)) if f.endswith('.wav')]


clean_files = _list_wav_files(clean_path)
noise_files = _list_wav_files(noise_path)

In [4]:
# --- Hyperparameters and run labels -----------------------------------------
batch_size = 4
num_epochs = 10
learning_rate = 1e-3
model_name = "UNet"
dataset_name = "Dataset"

# --- Data --------------------------------------------------------------------
# `clean_files` / `noise_files` come from the file-listing cell above.
dataset = Music_Dataset(size=500, clean_files=clean_files, noise_files=noise_files)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# --- Model, optimizer, loss ---------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): the model is not moved to `device` in this cell; presumably
# Trainer does that with the `device` it is given — confirm.
model = UNet(n_channels=1)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# Run description handed to the Trainer (presumably logged to Weights & Biases
# because `use_wandb=True` — confirm in Trainer).
config = {
    "batch_size": batch_size,
    "num_epochs": num_epochs,
    "learning_rate": learning_rate,
    "model": model_name,
    "dataset": dataset_name
}

trainer = Trainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    use_wandb=True,
    config=config
)

# Create the checkpoint directory before training so that an unwritable path
# fails immediately instead of after all epochs have run.
os.makedirs("../checkpoints", exist_ok=True)

# --- Training loop ------------------------------------------------------------
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    # train_epoch returns the epoch loss and a sequence of per-step metric dicts
    # (it is indexed with [0] and each element is indexed by key below).
    train_loss, train_metrics = trainer.train_epoch(dataloader)
    print(f"Train Loss: {train_loss:.4f}")
    # An epoch that produced no metric entries has nothing to average; skipping
    # it avoids an IndexError on train_metrics[0].
    if len(train_metrics) == 0:
        continue
    # Mean of every metric over the epoch, using the first entry's keys.
    avg_metrics = {key: sum(d[key] for d in train_metrics)/len(train_metrics) for key in train_metrics[0]}
    for k, v in avg_metrics.items():
        print(f"{k.upper()}: {v:.4f}")

# Persist only the weights (state_dict) of the model after the last epoch.
torch.save(model.state_dict(), "../checkpoints/unet_final.pth")

Epoch 1/10


Train:   2%|▏         | 3/125 [00:02<01:49,  1.11it/s]


KeyboardInterrupt: 