# Ekodi — Fine-tune MMS-TTS (VITS) for Bambara Voice

This notebook trains a custom Bambara TTS model on Google Colab.

**What it does:**
1. Clones the Ekodi repo from GitHub (code + voice samples)
2. Downloads up to 10K Bambara speech samples from HuggingFace (configurable)
3. Includes your custom voice recordings from the repo
4. Fine-tunes Facebook MMS-TTS-BAM (VITS architecture)
5. Pushes the trained model to HuggingFace Hub
6. Commits training results back to GitHub

**Requirements:** Google Colab with GPU (free T4 works)

**Repo:** https://github.com/adiarra14/ekodi

---

In [None]:
#@title 0. Setup â€” Check GPU & Install Dependencies
!nvidia-smi
import torch
print(f"\nPyTorch {torch.__version__}")
print(f"CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected! Go to Runtime > Change runtime type > GPU")

In [None]:
#@title 1. Clone repo & install packages
import os

# Clone the Ekodi repo from GitHub, or pull the existing checkout when this
# cell is re-run in the same Colab session.
REPO_URL = "https://github.com/adiarra14/ekodi.git"
REPO_DIR = "/content/ekodi"

if not os.path.exists(REPO_DIR):
    !git clone {REPO_URL} {REPO_DIR}
    print(f"Cloned {REPO_URL}")
else:
    !cd {REPO_DIR} && git pull
    print(f"Updated {REPO_DIR}")

# Change to project directory: later cells resolve their relative paths
# (config/, data/, assets/, checkpoints/) from tts/.
os.chdir(f"{REPO_DIR}/tts")
print(f"Working directory: {os.getcwd()}")

# Install dependencies
!pip install -q transformers datasets soundfile librosa numpy scipy \
    huggingface_hub accelerate torchaudio pydub pyyaml

# Install ffmpeg (output silenced). Step 4b shells out to it to convert the
# .m4a voice samples to .wav; this cell does not convert anything itself.
!apt-get install -qq ffmpeg > /dev/null 2>&1
print("Packages + ffmpeg installed!")

In [None]:
#@title 2. Login to HuggingFace
from huggingface_hub import login
from google.colab import userdata

# Option A: Use Colab Secrets (recommended — no token in code!)
# Go to: Colab left panel > 🔑 Secrets > Add "HF_TOKEN" with your token
try:
    HF_TOKEN = userdata.get("HF_TOKEN")
    print("Token loaded from Colab Secrets")
except Exception:
    # Secrets lookup failed: fall back to the form field below.
    # Option B: Paste your token here (will prompt if empty)
    HF_TOKEN = ""  #@param {type:"string"}

# Tolerate stray whitespace from copy/paste; a blank token falls through
# to the interactive prompt.
HF_TOKEN = (HF_TOKEN or "").strip()

if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    login()  # Interactive prompt

print("HuggingFace login OK!")

In [None]:
#@title 3. Configuration
import os
import yaml
from pathlib import Path

# Load config from the repo (relative to tts/, the working dir set in step 1).
# An empty YAML file parses to None, so fall back to an empty mapping.
CONFIG_PATH = Path("config/ekodi-port.yml")
cfg = yaml.safe_load(CONFIG_PATH.read_text(encoding="utf-8")) or {}

# === Training settings ===
BASE_MODEL = cfg["model"]["base_model"]     # "facebook/mms-tts-bam"
HUB_REPO = "adiarra14/ekodi-bambara-tts"    # Where to push trained model
MAX_PUBLIC_SAMPLES = 10000                    # How many public dataset samples
USE_CUSTOM_VOICE = True                       # Include your voice recordings
EPOCHS = 5                                    # Training epochs
BATCH_SIZE = 4                                # Batch size (4 works on T4 16GB)
LEARNING_RATE = 1e-4                          # Learning rate
MAX_AUDIO_SEC = 10.0                          # Max audio duration in seconds
# The `data:` section is optional; sample_rate defaults to 16 kHz.
TARGET_SR = (cfg.get("data") or {}).get("sample_rate", 16000)

# Directories
RAW_DIR = Path("data/raw")
CUSTOM_DIR = Path("data/custom_voice")
PROC_DIR = Path("data/processed")
CKPT_DIR = Path("checkpoints")
for d in [RAW_DIR, CUSTOM_DIR, PROC_DIR, CKPT_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# Check if custom voice samples exist in the repo (sorted for a stable order)
voice_dir = Path("assets/voice")
m4a_files = sorted(voice_dir.glob("*.m4a")) if voice_dir.exists() else []
print(f"Base model:    {BASE_MODEL}")
print(f"Hub repo:      {HUB_REPO}")
print(f"Public data:   up to {MAX_PUBLIC_SAMPLES} samples")
print(f"Custom voice:  {len(m4a_files)} .m4a files in assets/voice/")
print(f"Sample rate:   {TARGET_SR} Hz")
print(f"Config loaded from: {CONFIG_PATH}")

In [None]:
#@title 4a. Download public Bambara speech data
from datasets import load_dataset, Audio
import soundfile as sf
import numpy as np
import io

# Stream the dataset so only the first MAX_PUBLIC_SAMPLES rows are fetched.
print("Downloading OumarDicko/Bambara_AudioSynthetique_42K_V3...")
ds = load_dataset("OumarDicko/Bambara_AudioSynthetique_42K_V3", split="train", streaming=True)

public_records = []
skipped = 0
errors = 0            # decode/write failures (also counted in `skipped`)
MAX_ERRORS_SHOWN = 3  # print details for the first few failures only


def _write_example_audio(audio, wav_path):
    """Write one example's `audio` field to `wav_path` as WAV.

    Two layouts are handled: a dict with a decoded "array" (and optional
    "sampling_rate", default 16000), or a dict with non-empty encoded "bytes"
    (decoded with soundfile; multi-channel audio is averaged to mono).
    Returns False, writing nothing, when `audio` matches neither layout.
    """
    if isinstance(audio, dict) and "array" in audio:
        arr = np.array(audio["array"], dtype=np.float32)
        sr = audio.get("sampling_rate", 16000)
    elif isinstance(audio, dict) and audio.get("bytes"):
        arr, sr = sf.read(io.BytesIO(audio["bytes"]), dtype="float32")
        if arr.ndim > 1:
            arr = arr.mean(axis=1)
    else:
        return False
    sf.write(str(wav_path), arr, sr)
    return True


for i, example in enumerate(ds):
    if i >= MAX_PUBLIC_SAMPLES:
        break

    # Report progress by examples scanned, so the message appears even when
    # the example at the boundary is skipped.
    if i > 0 and i % 2000 == 0:
        print(f"  ... {i} processed, {len(public_records)} saved")

    text = example.get("sentence") or example.get("text") or example.get("transcription") or ""
    if not text.strip():
        skipped += 1
        continue

    audio = example.get("audio")
    if audio is None:
        skipped += 1
        continue

    wav_path = RAW_DIR / f"{i:06d}.wav"

    try:
        saved = _write_example_audio(audio, wav_path)
    except Exception as e:  # one corrupt clip must not abort the whole download
        skipped += 1
        errors += 1
        if errors <= MAX_ERRORS_SHOWN:
            print(f"  Failed example {i}: {e}")
        continue

    if not saved:
        skipped += 1
        continue

    public_records.append({"file_path": str(wav_path), "text": text.strip(), "source": "public"})

print(f"\nPublic data: {len(public_records)} samples ({skipped} skipped, {errors} errors)")

In [None]:
#@title 4b. Convert custom voice samples from repo
#
# Voice samples are in the repo at assets/voice/*.m4a
# Transcriptions are at assets/voice/transcriptions.csv (columns read:
# file_name, text). This cell converts m4a -> mono 16-bit wav at TARGET_SR
# and pairs each converted file with its transcription.

import csv
import subprocess

custom_records = []

if not USE_CUSTOM_VOICE:
    print("Custom voice disabled (USE_CUSTOM_VOICE = False)")
elif not m4a_files:
    print("No custom voice files found in assets/voice/")
else:
    print(f"Converting {len(m4a_files)} voice samples from assets/voice/ ...")
    CUSTOM_DIR.mkdir(parents=True, exist_ok=True)

    # Convert m4a -> wav
    converted = 0
    wav_map = {}  # m4a file name -> converted wav path
    for m4a in sorted(m4a_files):
        wav_path = CUSTOM_DIR / (m4a.stem + ".wav")
        try:
            subprocess.run([
                "ffmpeg", "-y", "-i", str(m4a),
                "-ar", str(TARGET_SR), "-ac", "1", "-sample_fmt", "s16",
                "-f", "wav", str(wav_path)
            ], capture_output=True, timeout=30, check=True)
        except (subprocess.SubprocessError, OSError) as e:
            # Non-zero exit, timeout, or ffmpeg missing: skip this file.
            print(f"  Failed: {m4a.name} ({e})")
            continue
        wav_map[m4a.name] = str(wav_path)
        converted += 1
    print(f"  Converted {converted}/{len(m4a_files)} files to WAV")

    # Load transcriptions from assets/voice/ (committed to git)
    trans_path = Path("assets/voice/transcriptions.csv")
    if trans_path.exists():
        unmatched = 0
        # utf-8-sig drops a leading BOM (e.g. from Excel) that would otherwise
        # corrupt the first header name; newline="" is what the csv module expects.
        with open(trans_path, encoding="utf-8-sig", newline="") as f:
            for row in csv.DictReader(f):
                # Short rows give None for missing fields, so coerce before strip().
                text = (row.get("text") or "").strip()
                fname = (row.get("file_name") or "").strip()
                if not text:
                    continue
                # Match the m4a name to its converted wav; accept a bare name or a path.
                wav = wav_map.get(fname) or wav_map.get(Path(fname).name)
                if wav is None:
                    unmatched += 1
                    continue
                custom_records.append({
                    "file_path": wav,
                    "text": text,
                    "source": "custom",
                })
        print(f"  Loaded {len(custom_records)} transcribed samples from assets/voice/transcriptions.csv")
        if unmatched:
            print(f"  {unmatched} transcription rows had no matching converted .m4a file")
    else:
        print("  No transcriptions.csv found at assets/voice/transcriptions.csv")
        print("  TIP: Add transcriptions and push to git")

print(f"\nTotal: {len(public_records)} public + {len(custom_records)} custom = {len(public_records) + len(custom_records)} samples")

In [None]:
#@title 5. Preprocess all data
import unicodedata
import re
import librosa
import random

_QUOTE_FIXES = str.maketrans({
    "\u2019": "'",
    "\u2018": "'",
    "\u201c": '"',
    "\u201d": '"',
})


def normalize_bambara(text):
    """Return `text` normalized for Bambara TTS training.

    Applies Unicode NFC composition, maps curly single/double quotes to their
    ASCII forms, collapses every whitespace run to one space, and trims both ends.
    """
    composed = unicodedata.normalize("NFC", text)
    straight = composed.translate(_QUOTE_FIXES)
    collapsed = re.sub(r"\s+", " ", straight)
    return collapsed.strip()

all_records = public_records + custom_records
processed = []
failed = 0
SPLIT_SEED = 42  # fixed seed so the train/val split is reproducible across runs

for rec in all_records:
    try:
        # Decode and resample to TARGET_SR.
        audio, _ = librosa.load(rec["file_path"], sr=TARGET_SR)
        dur = len(audio) / TARGET_SR
        if dur < 0.5 or dur > MAX_AUDIO_SEC:
            continue

        text = normalize_bambara(rec["text"])
        if len(text) < 2:
            continue

        source = rec.get("source", "public")
        # Prefix with the source: public clips are named 000000.wav, 000001.wav, ...
        # and a custom recording with the same stem would otherwise overwrite one.
        out_path = PROC_DIR / f"{source}_{Path(rec['file_path']).name}"
        sf.write(str(out_path), audio, TARGET_SR)
    except Exception as e:  # best-effort: a bad file is reported and skipped
        failed += 1
        if failed <= 3:
            print(f"  Failed {rec.get('file_path')}: {e}")
        continue

    processed.append({
        "file_path": str(out_path),
        "text": text,
        "source": source,
    })

# Shuffle (seeded) and split 95/5, keeping at least one training sample.
random.Random(SPLIT_SEED).shuffle(processed)
split_idx = max(1, int(len(processed) * 0.95)) if processed else 0
train_records = processed[:split_idx]
val_records = processed[split_idx:]

n_custom = sum(1 for r in train_records if r["source"] == "custom")
print(f"Processed: {len(processed)} total ({failed} failed to load/write)")
print(f"  Train: {len(train_records)} ({n_custom} custom)")
print(f"  Val:   {len(val_records)}")
if not processed:
    print("WARNING: no usable samples - check the output of steps 4a/4b.")

In [None]:
#@title 6. Load model
from transformers import VitsModel, AutoTokenizer

# Pretrained tokenizer + VITS weights for the configured base model, on the GPU.
print(f"Loading {BASE_MODEL}...")
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = VitsModel.from_pretrained(BASE_MODEL)
model = model.to("cuda")

# Total parameter count; step 8 reuses it to report the trainable fraction.
total_params = sum(param.numel() for param in model.parameters())
model_sr = model.config.sampling_rate
print(f"Model: {total_params/1e6:.1f}M params, SR={model_sr}Hz")

In [None]:
#@title 7. Test BEFORE fine-tuning
import IPython.display as ipd

# Reference sentences; step 10 synthesizes the same ones for comparison.
test_texts = [
    "I ni ce",
    "Aw ni ce, ne togo ye Ekodi",
    "Bamanankan ye kan nafama ye",
    "Lakoli ye yoro \u0272uman ye denmis\u025bnw ye",
]

print("=== BEFORE fine-tuning ===")
out_rate = model.config.sampling_rate
for sentence in test_texts:
    encoded = tokenizer(sentence, return_tensors="pt").to("cuda")
    with torch.no_grad():
        clip = model(**encoded).waveform[0]
    wav = clip.cpu().numpy()
    seconds = len(wav) / out_rate
    print(f'\n"{sentence}"  ({seconds:.1f}s)')
    ipd.display(ipd.Audio(wav, rate=out_rate))

In [None]:
#@title 8. Setup training
import torchaudio.transforms as T
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import time

# Mel spectrogram at the model's output rate (80 mel bins, 1024-point FFT,
# 256-sample hop); the training loss takes its log.
mel_fn = T.MelSpectrogram(
    sample_rate=model.config.sampling_rate,
    n_mels=80,
    n_fft=1024,
    hop_length=256,
).to("cuda")


class BambaraDataset(Dataset):
    """Map-style dataset yielding token ids, attention mask and waveform per record.

    Args:
        records: list of dicts with at least "file_path" (an audio file
            readable by soundfile) and "text".
        tokenizer: called as ``tokenizer(text, return_tensors="pt", padding=False)``;
            must return "input_ids" and "attention_mask" tensors.
        sr: sample rate (Hz) every audio file is expected to have.
        max_len: maximum waveform length in samples; a falsy value means
            ``MAX_AUDIO_SEC * sr``.
    """

    def __init__(self, records, tokenizer, sr=16000, max_len=None):
        self.records = records
        self.tokenizer = tokenizer
        self.sr = sr
        self.max_len = max_len or int(MAX_AUDIO_SEC * sr)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        """Load, check, truncate and tokenize one record.

        Raises:
            ValueError: if the file's sample rate differs from ``self.sr``.
        """
        rec = self.records[idx]
        audio, file_sr = sf.read(rec["file_path"], dtype="float32")
        # Nothing here resamples, so a file at another rate would be compared
        # against model output at the wrong rate; fail loudly instead.
        if file_sr != self.sr:
            raise ValueError(
                f"{rec['file_path']}: sample rate {file_sr} Hz, expected {self.sr} Hz"
            )
        if audio.ndim > 1:
            audio = audio.mean(axis=1)  # downmix to mono
        audio = audio[:self.max_len]  # truncate over-long clips
        tokens = self.tokenizer(rec["text"], return_tensors="pt", padding=False)
        return {
            "input_ids": tokens["input_ids"].squeeze(0),
            "attention_mask": tokens["attention_mask"].squeeze(0),
            "waveform": torch.tensor(audio, dtype=torch.float32),
        }


def collate(batch):
    """Zero-pad a list of dataset items into one batch of tensors.

    Each item holds 1-D "input_ids", "attention_mask" and "waveform" tensors.
    Token ids and masks are right-padded with 0 to the longest token sequence
    (dtype long). Waveforms are right-padded with 0.0 to the longest clip
    (default float dtype). Returns a dict with keys "input_ids",
    "attention_mask" and "waveforms".
    """
    size = len(batch)
    text_len = max(item["input_ids"].shape[0] for item in batch)
    wave_len = max(item["waveform"].shape[0] for item in batch)

    input_ids = torch.zeros(size, text_len, dtype=torch.long)
    attention_mask = torch.zeros(size, text_len, dtype=torch.long)
    waveforms = torch.zeros(size, wave_len)

    for row, item in enumerate(batch):
        n_tok = item["input_ids"].shape[0]
        n_wav = item["waveform"].shape[0]
        input_ids[row, :n_tok] = item["input_ids"]
        attention_mask[row, :n_tok] = item["attention_mask"]
        waveforms[row, :n_wav] = item["waveform"]

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "waveforms": waveforms,
    }


# Step 5 writes audio at TARGET_SR, the loss compares it with model output at
# the model's sampling rate, and nothing in between resamples: they must match.
if TARGET_SR != model.config.sampling_rate:
    raise ValueError(
        f"TARGET_SR={TARGET_SR} Hz does not match the model's "
        f"{model.config.sampling_rate} Hz; set data.sample_rate in the config "
        f"and re-run steps 4-5."
    )
if not train_records:
    raise RuntimeError("No training samples - check the output of steps 4a/4b/5.")

train_ds = BambaraDataset(train_records, tokenizer, model.config.sampling_rate)
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=collate, num_workers=2, pin_memory=True
)

# Freeze everything, then unfreeze parameters whose names contain
# embed_tokens, duration_predictor or proj.
# NOTE(review): "proj" is a plain substring match and may unfreeze more layers
# than the projection intended; inspect the names with
# [n for n, p in model.named_parameters() if p.requires_grad] to confirm.
for p in model.parameters():
    p.requires_grad = False

for name, p in model.named_parameters():
    if any(k in name for k in ["embed_tokens", "duration_predictor", "proj"]):
        p.requires_grad = True

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable: {trainable:,} / {total_params:,} ({100*trainable/total_params:.1f}%)")
print(f"Dataset: {len(train_ds)} train samples")
print(f"Batches/epoch: {len(train_loader)}")
print(f"Epochs: {EPOCHS}, LR: {LEARNING_RATE}")

In [None]:
#@title 9. TRAIN! 🚀
# Only the parameters left unfrozen in step 8 are optimized and clipped.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable_params, lr=LEARNING_RATE, weight_decay=0.01)

ACCUM_STEPS = 4     # Effective batch = BATCH_SIZE * ACCUM_STEPS
MIN_SAMPLES = 1024  # skip batches whose overlap is shorter than the mel n_fft
best_loss = float("inf")

# NOTE(review): the loss compares free-running synthesis with the recording
# after truncating both to the shorter length. Nothing aligns them in time,
# and zero-padded positions of shorter clips in a batch are included. VITS
# durations in this inference forward path are probably not differentiable,
# so duration_predictor may receive no gradient. Treat the loss as a rough
# signal and verify before relying on it.


def _apply_accumulated_gradients():
    """Clip the accumulated gradients, take one optimizer step, then reset them."""
    torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
    optimizer.step()
    optimizer.zero_grad()


print(f"Starting training: {EPOCHS} epochs, batch={BATCH_SIZE}, accum={ACCUM_STEPS}")
print(f"Effective batch size: {BATCH_SIZE * ACCUM_STEPS}")
print("=" * 60)

for epoch in range(EPOCHS):
    model.train()
    epoch_loss = 0.0
    n_batches = 0
    n_errors = 0
    pending = 0  # micro-batches whose gradients are accumulated but not yet applied
    t0 = time.time()
    optimizer.zero_grad()

    for batch_idx, batch in enumerate(train_loader):
        ids = batch["input_ids"].to("cuda")
        mask = batch["attention_mask"].to("cuda")
        target = batch["waveforms"].to("cuda")

        try:
            # Forward: generate audio
            output = model(input_ids=ids, attention_mask=mask)
            pred = output.waveform

            # Log-mel L1 + spectral convergence on the common prefix
            min_len = min(pred.shape[-1], target.shape[-1])
            if min_len < MIN_SAMPLES:
                continue

            pred_mel = torch.log(mel_fn(pred[..., :min_len]).clamp(min=1e-5))
            tgt_mel = torch.log(mel_fn(target[..., :min_len]).clamp(min=1e-5))

            l1 = F.l1_loss(pred_mel, tgt_mel)
            sc = torch.norm(tgt_mel - pred_mel) / (torch.norm(tgt_mel) + 1e-7)
            loss = (l1 + sc) / ACCUM_STEPS
            loss.backward()
        except torch.cuda.OutOfMemoryError:
            # Drop this batch together with any partially accumulated gradients.
            torch.cuda.empty_cache()
            optimizer.zero_grad()
            pending = 0
            continue
        except Exception as e:
            n_errors += 1
            if n_errors <= 3:
                print(f"  Error batch {batch_idx}: {e}")
            continue

        epoch_loss += loss.item() * ACCUM_STEPS
        n_batches += 1
        pending += 1

        # Step after ACCUM_STEPS *successful* micro-batches (not raw batch
        # indices), so skipped or failed batches don't shrink the effective batch.
        if pending == ACCUM_STEPS:
            _apply_accumulated_gradients()
            pending = 0

        if (batch_idx + 1) % 100 == 0:
            print(f"  [{batch_idx+1}/{len(train_loader)}] loss={loss.item()*ACCUM_STEPS:.4f}")

    # Apply a trailing partial accumulation window; otherwise the zero_grad()
    # at the start of the next epoch would silently discard it.
    if pending:
        _apply_accumulated_gradients()

    avg = epoch_loss / max(n_batches, 1)
    dt = time.time() - t0
    print(f"\n{'='*60}")
    print(f"Epoch {epoch+1}/{EPOCHS} | loss={avg:.4f} | batches={n_batches} | errors={n_errors} | time={dt:.0f}s")
    print(f"{'='*60}")

    if avg < best_loss and n_batches > 0:
        best_loss = avg
        model.save_pretrained(str(CKPT_DIR / "best"))
        tokenizer.save_pretrained(str(CKPT_DIR / "best"))
        print(f"  >>> New best model saved! (loss={best_loss:.4f})")

# Final save
model.save_pretrained(str(CKPT_DIR / "final"))
tokenizer.save_pretrained(str(CKPT_DIR / "final"))
print(f"\nTraining complete! Best loss: {best_loss:.4f}")
print(f"Checkpoints: {CKPT_DIR}")

In [None]:
#@title 10. Test AFTER fine-tuning
print("=== AFTER fine-tuning ===")
model.eval()  # leave the training mode set in step 9 before synthesizing

out_rate = model.config.sampling_rate
for sentence in test_texts:
    encoded = tokenizer(sentence, return_tensors="pt").to("cuda")
    with torch.no_grad():
        clip = model(**encoded).waveform[0]
    wav = clip.cpu().numpy()
    seconds = len(wav) / out_rate
    print(f'\n"{sentence}"  ({seconds:.1f}s)')
    ipd.display(ipd.Audio(wav, rate=out_rate))

In [None]:
#@title 11. Push to HuggingFace Hub 🚀
from huggingface_hub import HfApi

# Prefer the best-loss checkpoint and fall back to the final one.
best_path = CKPT_DIR / "best"
if not best_path.exists():
    best_path = CKPT_DIR / "final"
if not best_path.exists():
    raise FileNotFoundError(
        f"No checkpoint in {CKPT_DIR} (expected 'best' or 'final'); run step 9 first."
    )

print(f"Pushing {best_path} to {HUB_REPO}...")

best_model = VitsModel.from_pretrained(str(best_path))
best_tokenizer = AutoTokenizer.from_pretrained(str(best_path))

best_model.push_to_hub(HUB_REPO, private=False)
best_tokenizer.push_to_hub(HUB_REPO)

# Model card: base model and sample rate come from this run, not hard-coded values.
card_sr = best_model.config.sampling_rate
api = HfApi()
card = f"""---
language: bm
license: cc-by-nc-4.0
tags:
  - tts
  - bambara
  - vits
  - mms
  - ekodi
pipeline_tag: text-to-speech
---

# Ekodi Bambara TTS

Fine-tuned MMS-TTS model for Bambara (Bamanankan) text-to-speech.

## Usage

```python
from transformers import VitsModel, AutoTokenizer
import torch

model = VitsModel.from_pretrained("{HUB_REPO}")
tokenizer = AutoTokenizer.from_pretrained("{HUB_REPO}")

text = "I ni ce"  # Hello in Bambara
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    output = model(**inputs)
waveform = output.waveform[0].numpy()
# waveform is float32 audio at {card_sr} Hz
```

## Base model
- {BASE_MODEL} (VITS architecture)

## Training data
- OumarDicko/Bambara_AudioSynthetique_42K_V3
- Custom Bambara voice recordings

## Project
- GitHub: https://github.com/adiarra14/ekodi
"""

api.upload_file(
    path_or_fileobj=card.encode("utf-8"),
    path_in_repo="README.md",
    repo_id=HUB_REPO,
    repo_type="model",
)

print(f"\n{'='*60}")
print(f"Model pushed to: https://huggingface.co/{HUB_REPO}")
print(f"{'='*60}")
print("\nUse it with:")
print(f'  model = VitsModel.from_pretrained("{HUB_REPO}")')
print(f'  tokenizer = AutoTokenizer.from_pretrained("{HUB_REPO}")')

In [None]:
#@title 12. Push training results back to GitHub
import subprocess

# Configure git
GITHUB_USER = "adiarra14"
GITHUB_REPO = "ekodi"
# For private repos, use a personal access token:
# GITHUB_TOKEN = "ghp_..."  # Uncomment and paste your GitHub token
# !git remote set-url origin https://{GITHUB_TOKEN}@github.com/{GITHUB_USER}/{GITHUB_REPO}.git

os.chdir(f"/content/ekodi")

# Git config
!git config user.email "adiarra@gmail.com"
!git config user.name "adiarra14"

# Add training results (not the large data/checkpoint files)
!echo "tts/data/" >> .gitignore
!echo "tts/checkpoints/" >> .gitignore

# Copy best model to a lightweight location
best_info = f"Training complete. Best loss: {best_loss:.4f}"
with open("tts/TRAINING_LOG.md", "w") as f:
    f.write(f"# Ekodi Training Log\n\n")
    f.write(f"- **Date**: {time.strftime('%Y-%m-%d %H:%M')}\n")
    f.write(f"- **Base model**: {BASE_MODEL}\n")
    f.write(f"- **Epochs**: {EPOCHS}\n")
    f.write(f"- **Best loss**: {best_loss:.4f}\n")
    f.write(f"- **Train samples**: {len(train_records)} ({sum(1 for r in train_records if r.get('source')=='custom')} custom)\n")
    f.write(f"- **HuggingFace model**: https://huggingface.co/{HUB_REPO}\n")
    f.write(f"\n## Usage\n\n```python\nfrom transformers import VitsModel, AutoTokenizer\n")
    f.write(f'model = VitsModel.from_pretrained("{HUB_REPO}")\n')
    f.write(f'tokenizer = AutoTokenizer.from_pretrained("{HUB_REPO}")\n```\n')

!git add -A
!git status
!git commit -m "Add training results â€” best loss {best_loss:.4f}"

print("\nTo push to GitHub, run:")
print(f"  !git push origin main")
print("\n(You may need to authenticate â€” use a GitHub personal access token)")