# In [None]:
from math import sqrt
from pathlib import Path
import toml

import torch
from torch.utils.data import DataLoader

from audiozen.logger import init_logging_logger
from audiozen.utils import instantiate

from accelerate import Accelerator


def _normalize_modes(mode):
    """Return *mode* as a list of stage flags, rejecting unusable ones up front.

    A single string is wrapped in a list so that ``"train"`` is not treated as
    a sequence of characters. Flags are checked in order, and the exception
    types and messages match what the dispatch loop used to raise.
    """
    flags = [mode] if isinstance(mode, str) else list(mode)
    for flag in flags:
        if flag in ("train", "validate", "test"):
            continue
        if flag == "predict":
            raise NotImplementedError("Predict is not implemented yet.")
        if flag == "finetune":
            raise NotImplementedError("Finetune is not implemented yet.")
        raise ValueError(f"Unknown mode: {flag}")
    return flags


def _build_eval_dataloaders(dataset_configs):
    """Build one DataLoader per dataset config; a lone table counts as a list of one."""
    if not isinstance(dataset_configs, list):
        dataset_configs = [dataset_configs]
    return [
        DataLoader(
            dataset=instantiate(cfg["path"], args=cfg["args"]),
            **cfg["dataloader"],
        )
        for cfg in dataset_configs
    ]


def run(config_path: str, mode: list, resume: bool = False, ckpt_path: str = None) -> None:
    """Build the experiment described by a TOML config and run the requested stages.

    Args:
        config_path: Path to the TOML config file. ``~`` is expanded and the
            path is made absolute; its stem is stored as
            ``config["meta"]["exp_id"]``.
        mode: Stage names to run, in order. ``"train"``, ``"validate"`` and
            ``"test"`` are supported. A single string is accepted and treated
            as a one-element list.
        resume: Forwarded unchanged to the trainer constructor.
        ckpt_path: Checkpoint identifier passed to ``trainer.test``. Required
            when ``"test"`` is in ``mode``; unused otherwise.

    Raises:
        ValueError: If ``mode`` contains an unknown flag, or ``"test"`` is
            requested without ``ckpt_path``.
        NotImplementedError: If ``mode`` contains ``"predict"`` or
            ``"finetune"``.
    """
    # Fail fast: reject bad arguments before loading the config or building
    # the model, instead of after earlier stages have already run.
    modes = _normalize_modes(mode)
    if "test" in modes and ckpt_path is None:
        raise ValueError("Checkpoint path is required for 'test'. Please provide 'ckpt_path'.")

    # Load config
    config_path = Path(config_path).expanduser().absolute()
    config = toml.load(config_path.as_posix())

    config["meta"]["exp_id"] = config_path.stem
    config["meta"]["config_path"] = config_path.as_posix()
    if "test" in modes:
        config["meta"]["ckpt_path"] = ckpt_path

    # The trainer is still constructed with an Accelerator (see below), so one
    # is created here and passed through.
    accelerator = Accelerator()

    init_logging_logger(config)
    torch.manual_seed(config["meta"]["seed"])

    # Model
    model = instantiate(config["model"]["path"], args=config["model"]["args"])

    # Optimizer: the model parameters are merged with the configured arguments.
    optimizer = instantiate(
        config["optimizer"]["path"],
        args={"params": model.parameters()} | config["optimizer"]["args"],
    )

    # Loss function
    loss_function = instantiate(config["loss_function"]["path"], args=config["loss_function"]["args"])

    # Dataloaders: only the ones needed by the requested stages are built.
    train_dataloader, validate_dataloaders, test_dataloaders = None, [], []

    if "train" in modes:
        train_dataset = instantiate(config["train_dataset"]["path"], args=config["train_dataset"]["args"])
        # shuffle=True is fixed here, so the train dataloader config must not
        # set ``shuffle`` itself (a duplicate keyword raises TypeError).
        train_dataloader = DataLoader(
            dataset=train_dataset, collate_fn=None, shuffle=True, **config["train_dataset"]["dataloader"]
        )

    # Training also needs the validation loaders, since they are handed to
    # trainer.train() alongside the training loader.
    if "train" in modes or "validate" in modes:
        validate_dataloaders = _build_eval_dataloaders(config["validate_dataset"])

    if "test" in modes:
        test_dataloaders = _build_eval_dataloaders(config["test_dataset"])

    # Trainer: the object returned with initialize=False is called with the
    # keyword arguments below (presumably the trainer class itself; verify
    # against audiozen.utils.instantiate).
    trainer = instantiate(config["trainer"]["path"], initialize=False)(
        accelerator=accelerator,
        config=config,
        resume=resume,
        model=model,
        optimizer=optimizer,
        loss_function=loss_function,
    )

    # Every flag was validated above, so only runnable stages remain.
    for flag in modes:
        if flag == "train":
            trainer.train(train_dataloader, validate_dataloaders)
        elif flag == "validate":
            trainer.validate(validate_dataloaders)
        elif flag == "test":
            trainer.test(test_dataloaders, config["meta"]["ckpt_path"])


# In [None]:
# Experiment settings for this run.
config_path = "./config/baseline_m.toml"  # TOML experiment config; adjust to your setup
mode = ["train"]  # alternatives: ["train", "validate"], ["test"], ...
resume = False
ckpt_path = "best"  # only consulted when "test" is in mode

# Launch the selected stages.
run(config_path, mode, resume=resume, ckpt_path=ckpt_path)
