<a href="https://colab.research.google.com/github/abubakarafzal/beatrice-trainer/blob/copy/beatrice_trainer_colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Beatrice Trainer - Google Colab

This notebook allows you to train Beatrice voice conversion models on Google Colab with T4 GPU support.

## Instructions:
1. Make sure to select a GPU runtime: Runtime → Change runtime type → GPU (T4)
2. Your training data is already in the repository at `datasets/model_1/`
3. Your configuration is already set in `datasets/model_1_config_lowmem.json`
4. Run the cells in order to start training with your existing setup (section 6 is only needed to resume an interrupted run)


## 1. Setup Environment


In [1]:
# Install dependencies
!pip install -q torch torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q tqdm numpy tensorboard soundfile pyworld ipynbname

# Verify GPU
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA version: {torch.version.cuda}")


[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/261.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.0/261.0 kB[0m [31m7.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m47.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pyworld (pyproject.toml) ... [?25l[?25hdone
CUDA available: False


In [2]:
# Clone the repository
import os
from pathlib import Path

if not os.path.exists("beatrice-trainer"):
    !git clone -q https://github.com/abubakarafzal/beatrice-trainer.git
    !cd beatrice-trainer && git lfs pull

os.chdir("beatrice-trainer")
print(f"Current directory: {os.getcwd()}")


Current directory: /content/beatrice-trainer


## 2. Verify Training Data

Your training data is already in the repository at `datasets/model_1/`. The structure should be:
```
datasets/model_1/
  - my/
    - b1.wav
    - ...
```

**Note**: The data is already configured and ready to use.


In [None]:
from pathlib import Path

# Use existing data directory from repository (relative to the repo root, which
# the clone cell made the working directory).
data_dir = Path("datasets/model_1")

# Audio file patterns this check looks for.
AUDIO_PATTERNS = ("*.wav", "*.flac")

if not data_dir.exists():
    raise FileNotFoundError(f"Training data directory not found: {data_dir}")

print(f"✓ Found training data directory: {data_dir}")

# Walk the tree once per pattern and reuse the result for listing and counting.
audio_files = [path for pattern in AUDIO_PATTERNS for path in data_dir.rglob(pattern)]

print("\nTraining data structure:")
for item in audio_files:
    print(f"  {item.relative_to(data_dir)}")
print(f"\nTotal audio files found: {len(audio_files)}")

# Fail here, with a clear message, rather than deep inside training.
if not audio_files:
    raise FileNotFoundError(f"No .wav or .flac files found under {data_dir}")


## 3. Load Training Configuration

Your training configuration is already set in `datasets/model_1_config_lowmem.json`. This configuration is optimized for T4 GPU (16GB VRAM).


In [None]:
import json
from pathlib import Path

# Load your existing config file
config_path = Path("datasets/model_1_config_lowmem.json")

if not config_path.exists():
    raise FileNotFoundError(f"Config file not found: {config_path}")

print(f"✓ Loading configuration from: {config_path}")

with open(config_path, "r", encoding="utf-8") as f:
    config = json.load(f)

# Add data_dir and out_dir to the config.
# `data_dir` is defined by the "Verify Training Data" cell; run that cell first.
config["data_dir"] = str(data_dir)
config["out_dir"] = "/content/outputs"  # Output directory in Colab

# Save the updated config; the training cell reads it back from this path.
training_config_path = Path("training_config.json")
with open(training_config_path, "w", encoding="utf-8") as f:
    json.dump(config, f, indent=2)

# Keys absent from the JSON are filled from the trainer's defaults in the training
# cell, so show a marker instead of raising KeyError on them here.
print("\nTraining configuration:")
for label, key in (
    ("Data directory", "data_dir"),
    ("Output directory", "out_dir"),
    ("Batch size", "batch_size"),
    ("Hidden channels", "hidden_channels"),
    ("Total steps", "n_steps"),
    ("Wav length", "wav_length"),
    ("Segment length", "segment_length"),
    ("Use AMP", "use_amp"),
):
    print(f"  {label}: {config.get(key, '(trainer default)')}")
print(f"\n✓ Configuration saved to: {training_config_path}")


In [None]:
# 4. Train: import the trainer module, patch its config loading for Colab,
# prepare all training state, and run the training loop.
import sys
import os

REPO_PATH = "/content/beatrice-trainer"
# Guard so re-running this cell doesn't keep prepending duplicate entries.
if REPO_PATH not in sys.path:
    sys.path.insert(0, REPO_PATH)
os.chdir(REPO_PATH)

# The trainer module; its config-loading function is replaced below for Colab.
import beatrice_trainer.__main__ as trainer_module
from contextlib import nullcontext
from tqdm.auto import tqdm
import math
import torch
import torch.nn as nn
import torchaudio
from torch.nn import functional as F
import gc
import gzip
import shutil
from collections import defaultdict

# Override the prepare_training_configs_for_experiment function for Colab
def prepare_training_configs_for_colab():
    """Load hyperparameters for a fresh (non-resumed) training run in Colab.

    Installed as ``trainer_module.prepare_training_configs_for_experiment``. It
    reads ``training_config.json`` from the working directory (written by the
    config cell) and pops the ``data_dir`` and ``out_dir`` entries out of it. Any
    hyperparameter missing from the file is filled from
    ``trainer_module.dict_default_hparams``.

    Returns:
        Tuple ``(h, data_dir, out_dir, False, False)``: ``h`` is the hyperparameter
        dict, the two directories are ``Path`` objects, and the fourth element is
        the resume flag (False here). NOTE(review): the meaning of the fifth flag
        is defined by the trainer; confirm there.

    Raises:
        KeyError: if the JSON lacks ``data_dir`` or ``out_dir``.
    """
    from copy import deepcopy
    from pathlib import Path
    import json

    config_path = Path("training_config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        h = json.load(f)

    data_dir = Path(h.pop("data_dir"))
    out_dir = Path(h.pop("out_dir"))

    # Fill in defaults for any missing keys. deepcopy so mutable defaults
    # (lists/dicts) are not shared with the module-level defaults table.
    for key, value in trainer_module.dict_default_hparams.items():
        h.setdefault(key, deepcopy(value))

    return h, data_dir, out_dir, False, False


# Replace the function the trainer calls when preparing a run.
trainer_module.prepare_training_configs_for_experiment = prepare_training_configs_for_colab

# Build every object the training loop below needs. The names must stay in the
# exact order of the tuple returned by trainer_module.prepare_training().
(
    # run setup
    device, in_wav_dataset_dir, h, out_dir,
    # data
    speakers, test_filelist, training_loader, speaker_f0s, test_pitch_shifts,
    # networks
    phone_extractor, pitch_estimator, net_g, net_d,
    # optimisation
    optim_g, optim_d, grad_scaler, grad_balancer,
    resample_to_in_sample_rate, initial_iteration,
    scheduler_g, scheduler_d,
    # logging / evaluation
    dict_scalars, quality_tester, writer,
) = trainer_module.prepare_training()

# Import necessary functions and classes
from beatrice_trainer.__main__ import (
    ConvNeXtStack,
    MultiPeriodDiscriminator,
    PhoneExtractor,
    PitchEstimator,
    ConverterNetwork,
    get_resampler,
    compute_grad_norm,
    get_compressed_optimizer_state_dict,
    PARAPHERNALIA_VERSION,
    repo_root,
)

# Training loop (adapted from the trainer's __main__ block).
# `writer` comes from prepare_training(); the loop only runs when it is not None.
if writer is not None:
    # Optional torch.compile acceleration, controlled by hyperparameters.
    if h.compile_convnext:
        # Keep both versions: the compiled forward is swapped in only around the
        # generator forward pass below.
        raw_convnextstack_forward = ConvNeXtStack.forward
        compiled_convnextstack_forward = torch.compile(
            ConvNeXtStack.forward, mode="reduce-overhead"
        )
    if h.compile_d4c:
        # `d4c` is not defined in this notebook (the old `d4c = torch.compile(d4c, ...)`
        # raised NameError). In the original __main__ block it is a module-level
        # global of the trainer, so rebind it on the module for the trainer's own
        # code to pick up the compiled version.
        # NOTE(review): assumes trainer_module exposes `d4c` — confirm.
        trainer_module.d4c = torch.compile(trainer_module.d4c, mode="reduce-overhead")
    if h.compile_discriminator:
        MultiPeriodDiscriminator.forward_and_compute_loss = torch.compile(
            MultiPeriodDiscriminator.forward_and_compute_loss, mode="reduce-overhead"
        )

    # Training loop (optionally under the torch profiler when h.profile is set)
    with (
        torch.profiler.profile(
            schedule=torch.profiler.schedule(wait=1500, warmup=10, active=5, repeat=1),
            on_trace_ready=torch.profiler.tensorboard_trace_handler(out_dir),
            record_shapes=True,
            with_stack=True,
            profile_memory=True,
            with_flops=True,
        )
        if h.profile
        else nullcontext()
    ) as profiler:
        data_iter = iter(training_loader)
        for iteration in tqdm(range(initial_iteration, h.n_steps), desc="Training"):
            # === 1. Data preprocessing ===
            # Restart the loader whenever an epoch is exhausted.
            try:
                batch = next(data_iter)
            except (NameError, StopIteration):
                data_iter = iter(training_loader)
                batch = next(data_iter)
            (
                clean_wavs,
                noisy_wavs_16k,
                slice_starts,
                speaker_ids,
                formant_shift_semitone,
            ) = map(lambda x: x.to(device, non_blocking=True), batch)

            # === 2. Training ===
            with torch.amp.autocast("cuda", enabled=h.use_amp):
                # === 2.1 Generator forward pass ===
                if h.compile_convnext:
                    ConvNeXtStack.forward = compiled_convnextstack_forward
                (
                    y,
                    y_hat,
                    y_hat_for_backward,
                    loss_loudness,
                    loss_mel,
                    loss_ap,
                    generator_stats,
                ) = net_g.forward_and_compute_loss(
                    noisy_wavs_16k[:, None, :],
                    speaker_ids,
                    formant_shift_semitone,
                    slice_start_indices=slice_starts,
                    slice_segment_length=h.segment_length,
                    y_all=clean_wavs[:, None, :],
                    enable_loss_ap=h.grad_weight_ap != 0.0,
                )
                if h.compile_convnext:
                    ConvNeXtStack.forward = raw_convnextstack_forward
                assert y_hat.isfinite().all()
                assert loss_loudness.isfinite().all()
                assert loss_mel.isfinite().all()
                assert loss_ap.isfinite().all()

                # === 2.2 Discriminator forward pass ===
                loss_discriminator, loss_adv, loss_fm, discriminator_stats = (
                    net_d.forward_and_compute_loss(y, y_hat)
                )
                assert loss_discriminator.isfinite().all()
                assert loss_adv.isfinite().all()
                assert loss_fm.isfinite().all()

            # === 2.3 Discriminator backward pass ===
            for param in net_d.parameters():
                assert param.grad is None
            # retain_graph: the generator backward below reuses this graph.
            grad_scaler.scale(loss_discriminator).backward(
                retain_graph=True, inputs=list(net_d.parameters())
            )
            loss_discriminator = loss_discriminator.item()
            grad_scaler.unscale_(optim_d)
            # Gradient norms are only computed every 5 iterations.
            if iteration % 5 == 0:
                grad_norm_d, d_grad_norm_stats = compute_grad_norm(net_d, True)
            else:
                grad_norm_d = math.nan
                d_grad_norm_stats = {}

            # === 2.4 Generator backward pass ===
            for param in net_g.parameters():
                assert param.grad is None
            gradient_balancer_stats = grad_balancer.backward(
                {
                    "loss_loudness": loss_loudness,
                    "loss_mel": loss_mel,
                    "loss_adv": loss_adv,
                    "loss_fm": loss_fm,
                }
                | ({"loss_ap": loss_ap} if h.grad_weight_ap else {}),
                y_hat_for_backward,
                grad_scaler,
                skip_update_ema=iteration > 10 and iteration % 5 != 0,
            )
            loss_loudness = loss_loudness.item()
            loss_mel = loss_mel.item()
            loss_adv = loss_adv.item()
            loss_fm = loss_fm.item()
            if h.grad_weight_ap:
                loss_ap = loss_ap.item()
            grad_scaler.unscale_(optim_g)
            if iteration % 5 == 0:
                grad_norm_g, g_grad_norm_stats = compute_grad_norm(net_g, True)
            else:
                grad_norm_g = math.nan
                g_grad_norm_stats = {}

            # === 2.5 Parameter update ===
            grad_scaler.step(optim_g)
            optim_g.zero_grad(set_to_none=True)
            grad_scaler.step(optim_d)
            optim_d.zero_grad(set_to_none=True)
            grad_scaler.update()

            # === 3. Logging ===
            dict_scalars["loss_g/loss_loudness"].append(loss_loudness)
            dict_scalars["loss_g/loss_mel"].append(loss_mel)
            if h.grad_weight_ap:
                dict_scalars["loss_g/loss_ap"].append(loss_ap)
            dict_scalars["loss_g/loss_fm"].append(loss_fm)
            dict_scalars["loss_g/loss_adv"].append(loss_adv)
            dict_scalars["other/grad_scale"].append(grad_scaler.get_scale())
            dict_scalars["loss_d/loss_discriminator"].append(loss_discriminator)
            if math.isfinite(grad_norm_d):
                dict_scalars["other/gradient_norm_d"].append(grad_norm_d)
                for name, value in d_grad_norm_stats.items():
                    dict_scalars[f"~gradient_norm_d/{name}"].append(value)
            if math.isfinite(grad_norm_g):
                dict_scalars["other/gradient_norm_g"].append(grad_norm_g)
                for name, value in g_grad_norm_stats.items():
                    dict_scalars[f"~gradient_norm_g/{name}"].append(value)
            dict_scalars["other/lr_g"].append(scheduler_g.get_last_lr()[0])
            dict_scalars["other/lr_d"].append(scheduler_d.get_last_lr()[0])
            for k, v in generator_stats.items():
                dict_scalars[f"~loss_generator/{k}"].append(v)
            for k, v in discriminator_stats.items():
                dict_scalars[f"~loss_discriminator/{k}"].append(v)
            for k, v in gradient_balancer_stats.items():
                dict_scalars[f"~gradient_balancer/{k}"].append(v)

            # Flush averaged scalars to TensorBoard every 1000 steps (and on step 1).
            if (iteration + 1) % 1000 == 0 or iteration == 0:
                for name, scalars in dict_scalars.items():
                    if scalars:
                        writer.add_scalar(
                            name, sum(scalars) / len(scalars), iteration + 1
                        )
                        scalars.clear()

            # === 4. Validation ===
            if (iteration + 1) % h.evaluation_interval == 0 or iteration + 1 in {
                1,
                h.n_steps,
            }:
                torch.backends.cudnn.benchmark = False
                net_g.eval()
                torch.cuda.empty_cache()

                dict_qualities_all = defaultdict(list)
                n_added_wavs = 0
                with torch.inference_mode():
                    for i, ((file, target_ids), pitch_shift_semitones) in enumerate(
                        zip(test_filelist, test_pitch_shifts)
                    ):
                        source_wav, sr = torchaudio.load(file, backend="soundfile")
                        source_wav = source_wav.to(device)
                        if sr != h.in_sample_rate:
                            source_wav = get_resampler(sr, h.in_sample_rate, device)(
                                source_wav
                            )
                        original_source_wav_length = source_wav.size(1)
                        # Right-pad to a whole number of seconds at the input rate.
                        if source_wav.size(1) % h.in_sample_rate == 0:
                            padded_source_wav = source_wav
                        else:
                            padded_source_wav = F.pad(
                                source_wav,
                                (
                                    0,
                                    h.in_sample_rate
                                    - source_wav.size(1) % h.in_sample_rate,
                                ),
                            )
                        converted = net_g(
                            padded_source_wav[[0] * len(target_ids), None],
                            torch.tensor(target_ids, device=device),
                            torch.tensor(
                                [0.0] * len(target_ids), device=device
                            ),
                            torch.tensor(
                                [float(p) for p in pitch_shift_semitones], device=device
                            ),
                        ).squeeze_(1)[:, : original_source_wav_length // 160 * 240]
                        # Log audio for at most the first 12 test files.
                        if i < 12:
                            if iteration == 0:
                                writer.add_audio(
                                    f"source/y_{i:02d}",
                                    source_wav,
                                    iteration + 1,
                                    h.in_sample_rate,
                                )
                            for d in range(
                                min(
                                    len(target_ids),
                                    1 + (12 - i - 1) // len(test_filelist),
                                )
                            ):
                                idx_in_batch = n_added_wavs % len(target_ids)
                                writer.add_audio(
                                    f"converted/y_hat_{i:02d}_{target_ids[idx_in_batch]:03d}_{pitch_shift_semitones[idx_in_batch]:+02d}",
                                    converted[idx_in_batch],
                                    iteration + 1,
                                    h.out_sample_rate,
                                )
                                n_added_wavs += 1
                        converted = resample_to_in_sample_rate(converted)
                        quality = quality_tester.test(converted, source_wav)
                        for metric_name, values in quality.items():
                            dict_qualities_all[metric_name].extend(values)
                dict_qualities = {
                    metric_name: sum(values) / len(values)
                    for metric_name, values in dict_qualities_all.items()
                    if len(values)
                }
                for metric_name, value in dict_qualities.items():
                    writer.add_scalar(f"validation/{metric_name}", value, iteration + 1)

                net_g.train()
                torch.backends.cudnn.benchmark = True
                gc.collect()
                torch.cuda.empty_cache()

            # === 5. Saving ===
            if (iteration + 1) % h.save_interval == 0 or iteration + 1 in {
                1,
                h.n_steps,
            }:
                # Checkpoint (for resuming training)
                name = f"{in_wav_dataset_dir.name}_{iteration + 1:08d}"
                checkpoint_file_save = out_dir / f"checkpoint_{name}.pt.gz"
                if checkpoint_file_save.exists():
                    checkpoint_file_save = checkpoint_file_save.with_name(
                        f"{checkpoint_file_save.name}_{hash(None):x}"
                    )
                with gzip.open(checkpoint_file_save, "wb") as f:
                    torch.save(
                        {
                            "iteration": iteration + 1,
                            "net_g": net_g.state_dict(),
                            "phone_extractor": phone_extractor.state_dict(),
                            "pitch_estimator": pitch_estimator.state_dict(),
                            "net_d": {
                                k: v.half() for k, v in net_d.state_dict().items()
                            },
                            "optim_g": get_compressed_optimizer_state_dict(optim_g),
                            "optim_d": get_compressed_optimizer_state_dict(optim_d),
                            "grad_balancer": grad_balancer.state_dict(),
                            "grad_scaler": grad_scaler.state_dict(),
                            "h": dict(h),
                        },
                        f,
                    )
                shutil.copy(checkpoint_file_save, out_dir / "checkpoint_latest.pt.gz")

                # Files for inference (fp16 weights + TOML metadata)
                paraphernalia_dir = out_dir / f"paraphernalia_{name}"
                if paraphernalia_dir.exists():
                    paraphernalia_dir = paraphernalia_dir.with_name(
                        f"{paraphernalia_dir.name}_{hash(None):x}"
                    )
                paraphernalia_dir.mkdir()
                phone_extractor_fp16 = PhoneExtractor()
                phone_extractor_fp16.load_state_dict(phone_extractor.state_dict())
                phone_extractor_fp16.remove_weight_norm()
                phone_extractor_fp16.merge_weights()
                phone_extractor_fp16.half()
                phone_extractor_fp16.dump(paraphernalia_dir / "phone_extractor.bin")
                del phone_extractor_fp16
                pitch_estimator_fp16 = PitchEstimator()
                pitch_estimator_fp16.load_state_dict(pitch_estimator.state_dict())
                pitch_estimator_fp16.merge_weights()
                pitch_estimator_fp16.half()
                pitch_estimator_fp16.dump(paraphernalia_dir / "pitch_estimator.bin")
                del pitch_estimator_fp16
                net_g_fp16 = ConverterNetwork(
                    nn.Module(),
                    nn.Module(),
                    len(speakers),
                    h.pitch_bins,
                    h.hidden_channels,
                    h.vq_topk,
                    h.training_time_vq,
                    h.phone_noise_ratio,
                    h.floor_noise_level,
                )
                net_g_fp16.load_state_dict(net_g.state_dict())
                net_g_fp16.merge_weights()
                net_g_fp16.half()
                net_g_fp16.dump(paraphernalia_dir / "waveform_generator.bin")
                net_g_fp16.dump_speaker_embeddings(
                    paraphernalia_dir / "speaker_embeddings.bin"
                )
                net_g_fp16.dump_embedding_setter(
                    paraphernalia_dir / "embedding_setter.bin"
                )
                del net_g_fp16
                shutil.copy(
                    repo_root() / "assets/images/noimage.png", paraphernalia_dir
                )
                with open(
                    paraphernalia_dir / f"beatrice_paraphernalia_{name}.toml",
                    "w",
                    encoding="utf-8",
                ) as f:
                    f.write(
                        f'''[model]
version = "{PARAPHERNALIA_VERSION}"
name = "{name}"
description = """
No description for this model.
このモデルの説明はありません。
"""
'''
                    )
                    for speaker_id, (speaker, speaker_f0) in enumerate(
                        zip(speakers, speaker_f0s)
                    ):
                        # Convert the average F0 (Hz) to a MIDI-style note number,
                        # rounded to 1/8 semitone.
                        average_pitch = 69.0 + 12.0 * math.log2(speaker_f0 / 440.0)
                        average_pitch = round(average_pitch * 8.0) / 8.0
                        f.write(
                            f'''
[voice.{speaker_id}]
name = "{speaker}"
description = """
No description for this voice.
この声の説明はありません。
"""
average_pitch = {average_pitch}

[voice.{speaker_id}.portrait]
path = "noimage.png"
description = """
"""
'''
                        )

            # === 6. Scheduler update ===
            scheduler_g.step()
            scheduler_d.step()
            if h.profile:
                profiler.step()

    print("Training finished.")
else:
    print("prepare_training() returned writer=None; skipping the training loop.")


## 5. Download Results

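After training completes, run the cell below. It zips the most recently written `paraphernalia_*` directory (the files used for inference) from `/content/outputs` and downloads it.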


In [None]:
from pathlib import Path
import zipfile
from google.colab import files

output_dir = Path("/content/outputs")

# Inference bundles written by the training loop, one directory per save.
# Guard against the output directory not existing yet (training never ran).
paraphernalia_dirs = (
    [p for p in output_dir.glob("paraphernalia_*") if p.is_dir()]
    if output_dir.is_dir()
    else []
)

if paraphernalia_dirs:
    # The most recently written bundle is the latest saved model.
    latest_dir = max(paraphernalia_dirs, key=lambda p: p.stat().st_mtime)
    print(f"Found model: {latest_dir.name}")

    # Zip it, keeping the bundle's directory name as the top-level folder.
    zip_path = Path(f"/content/{latest_dir.name}.zip")
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for file in latest_dir.rglob("*"):
            if file.is_file():
                zipf.write(file, file.relative_to(latest_dir.parent))

    print(f"\nDownloading {zip_path.name}...")
    files.download(str(zip_path))
    print("\nDownload complete!")
else:
    print("No paraphernalia directory found. Training may not have completed yet.")
    if output_dir.is_dir():
        print("Available files in output directory:")
        for item in output_dir.iterdir():
            print(f"  {item.name}")
    else:
        print(f"Output directory does not exist: {output_dir}")


## 6. Resume Training (Optional)

If training was interrupted, you can resume from the latest checkpoint.


In [None]:
# Resume an interrupted run.
#
# trainer_module.prepare_training() only builds the training state; the training
# loop itself lives in the training cell. Calling prepare_training() here and
# discarding the result never trained anything. Re-running the training cell also
# re-installed the non-resume config loader. Instead, wrap prepare_training so
# that the NEXT run of the training cell uses the resume config loader.
import sys

REPO_PATH = "/content/beatrice-trainer"
if REPO_PATH not in sys.path:
    sys.path.insert(0, REPO_PATH)

# Import here too, so this cell also works in a fresh kernel.
import beatrice_trainer.__main__ as trainer_module


def prepare_training_configs_for_colab_resume():
    """Load hyperparameters like the fresh-run loader, but with resume enabled.

    Reads ``training_config.json`` from the working directory and pops
    ``data_dir`` / ``out_dir``. Missing hyperparameters are filled from
    ``trainer_module.dict_default_hparams``.

    Returns:
        Tuple ``(h, data_dir, out_dir, True, False)``: the fourth element is the
        resume flag. NOTE(review): the fifth flag's meaning is defined by the
        trainer; confirm there.
    """
    from copy import deepcopy
    from pathlib import Path
    import json

    config_path = Path("training_config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        h = json.load(f)

    data_dir = Path(h.pop("data_dir"))
    out_dir = Path(h.pop("out_dir"))

    # Fill in defaults for any missing keys, without aliasing mutable defaults.
    for key, value in trainer_module.dict_default_hparams.items():
        h.setdefault(key, deepcopy(value))

    return h, data_dir, out_dir, True, False  # resume=True


# Remember the unwrapped function once, so re-running this cell never nests wrappers.
if not hasattr(trainer_module, "_colab_original_prepare_training"):
    trainer_module._colab_original_prepare_training = trainer_module.prepare_training


def prepare_training_resumed():
    """Run the trainer's original prepare_training() with the resume loader forced.

    The training cell installs the non-resume loader immediately before calling
    ``trainer_module.prepare_training()``. Reinstalling the resume loader here,
    right before delegating, makes that call resume instead.
    """
    trainer_module.prepare_training_configs_for_experiment = (
        prepare_training_configs_for_colab_resume
    )
    return trainer_module._colab_original_prepare_training()


trainer_module.prepare_training = prepare_training_resumed
print("Resume enabled: re-run the training cell to continue from the saved checkpoint.")
