In [1]:
import sys
from pathlib import Path

%load_ext autoreload
%autoreload 2

notebook_path = Path().resolve()
project_root = notebook_path.parent
sys.path.append(str(project_root))
print('Project root added:', project_root)

Project root added: /Users/aleksandr/MMLS/Music_MMLS


In [2]:
import os
import torch

from lightning.pytorch.loggers import WandbLogger

from Music_MMLS.data.ldatamodule import MusicDataModule
from Music_MMLS.models.lmodule import MusicModelModule
import lightning as L




In [3]:
from hydra import initialize, compose
from omegaconf import OmegaConf

with initialize(config_path='../configs', job_name='demo', version_base='1.3'):
    cfg = compose(config_name='config')

print('Hydra Config:\n')
print(OmegaConf.to_yaml(cfg))

Hydra Config:

project:
  name: Music_MMLS
  seed: 42
  device: cuda
  wandb_entity: sasha_kovylyaev-hse
  experiment_name: default
dataset:
  size: 500
  data_dir: ../content/sample_data/Data
  clean_dir: ../content/sample_data/Data/all_records
  noise_dir: ../content/sample_data/Data/noise
  test_size: 0.2
model:
  model: UNet
  n_channels: 1
training:
  epochs: 10
  learning_rate: 0.001
  batch_size: 4
  precision: 32
  optimizer: Adam
  criterion: MSE
  scheduler: ''



In [4]:
clean_files = [os.path.join(cfg.dataset.clean_dir, f)
               for f in os.listdir(cfg.dataset.clean_dir) if f.endswith('.wav')]
noise_files = [os.path.join(cfg.dataset.noise_dir, f) 
               for f in os.listdir(cfg.dataset.noise_dir) if f.endswith('.wav')]

print(f'Found {len(clean_files)} clean files and {len(noise_files)} noise files.')

Found 275 clean files and 5000 noise files.


In [None]:
batch_size = cfg.training.batch_size
num_epochs = cfg.training.epochs
lr = cfg.training.learning_rate

music_data_module = MusicDataModule(cfg.dataset, cfg.training.batch_size)
music_model_module = MusicModelModule(cfg.model, cfg.training)

device = torch.device(cfg.project.device if torch.cuda.is_available() else 'cpu')

wandb_config = OmegaConf.to_container(
    cfg, resolve=True, throw_on_missing=True
)

wandb_logger = WandbLogger(project=cfg.project.name, name=cfg.project.experiment_name, log_model='all', dir='../checkpoints')
wandb_logger.log_hyperparams(wandb_config)

trainer = L.Trainer(
    max_epochs=cfg.training.epochs,
    logger=wandb_logger,
    devices=-1,
    precision=cfg.training.precision
)

print(f'Training on {device} with batch size {batch_size} for {num_epochs} epochs.')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33msasha_kovylyaev[0m ([33msasha_kovylyaev-hse[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Training on cpu with batch size 4 for 10 epochs.


In [None]:

trainer.fit(
    music_model_module,
    datamodule=music_data_module,
)

print("Training complete.")


  | Name          | Type             | Params | Mode 
-----------------------------------------------------------
0 | model         | UNet             | 1.9 M  | train
1 | criterion     | MSELoss          | 0      | train
2 | train_metrics | MetricCollection | 0      | train
3 | test_metrics  | MetricCollection | 0      | train
-----------------------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.769     Total estimated model params size (MB)
103       Modules in train mode
0         Modules in eval mode


ℹ️ Папка ../content/sample_data/Data/all_records не пуста, возможно, датасет уже там
ℹ️ Папка ../content/sample_data/Data/noise не пуста, возможно, датасет уже там


/Users/aleksandr/MMLS/Music_MMLS/.venv/lib/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:420: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Training: |          | 0/? [00:00<?, ?it/s]

