In [None]:
import os
import platform
import shutil
import subprocess
import torch

# Record the runtime so logs from different Colab sessions can be compared.
print(f"Python: {platform.python_version()}")
print(f"Platform: {platform.platform()}")
print(f"Torch: {torch.__version__}")

# Fail fast, before the slow clone/install cells run, if no GPU is attached.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA GPU not detected. In Colab, enable GPU: Runtime -> Change runtime type -> GPU")

print(f"CUDA device: {torch.cuda.get_device_name(0)}")

# nvidia-smi output is diagnostic only. check=False tolerates a non-zero exit, but
# subprocess.run still raises FileNotFoundError when the executable is absent, so look
# it up on PATH first instead of letting a missing diagnostic tool abort the notebook.
if shutil.which("nvidia-smi") is not None:
    subprocess.run(["nvidia-smi"], check=False)
else:
    print("nvidia-smi not found on PATH; skipping driver report.")


In [None]:
from pathlib import Path

# --- Source checkout ---------------------------------------------------------
REPO_URL = "https://github.com/AsyncThunky/AudioBuff.git"
BRANCH = "main"
# Optional: e.g. https://github.com/your-fork/xcodec.git.
# Empty -> training/inference use the built-in random fallback conditioning.
XCODEC_GIT_URL = ""

# --- Storage on Google Drive -------------------------------------------------
DRIVE_ROOT = "/content/drive/MyDrive/AudioBuff"
# Derived from DRIVE_ROOT so relocating the project on Drive needs only one edit.
RAW_AUDIO_DIR = (Path(DRIVE_ROOT) / "raw_pristine_audio").as_posix()
# Optional explicit input wav path. Empty -> first *.wav (name order) in RAW_AUDIO_DIR.
INFER_INPUT_WAV = ""

# --- Run profile and stage toggles -------------------------------------------
PROFILE = "smoke"  # smoke | poc | full
RUN_PREP = True
RUN_TRAIN = True
RUN_INFER = True

# --- Training / codec options ------------------------------------------------
RESUME_FROM_LATEST = True  # resume from the newest checkpoint_epoch_*.pt, if one exists
DAC_MODEL_TYPE = "44khz"  # forwarded as --dac_model_type to latent prep and inference
SEED = 1337


In [None]:
from google.colab import drive
from pathlib import Path

# Keep latents, checkpoints and outputs on Drive so they outlive the Colab runtime.
drive.mount('/content/drive')

DRIVE_ROOT_PATH = Path(DRIVE_ROOT)
PERSIST_LATENTS = DRIVE_ROOT_PATH.joinpath("latents")
PERSIST_CHECKPOINTS = DRIVE_ROOT_PATH.joinpath("checkpoints")
PERSIST_ARTIFACTS = DRIVE_ROOT_PATH.joinpath("artifacts")

persisted_dirs = {
    "Drive root": DRIVE_ROOT_PATH,
    "Latents": PERSIST_LATENTS,
    "Checkpoints": PERSIST_CHECKPOINTS,
    "Artifacts": PERSIST_ARTIFACTS,
}

# Create every directory (idempotently) before reporting any of them.
for folder in persisted_dirs.values():
    folder.mkdir(parents=True, exist_ok=True)

for label, folder in persisted_dirs.items():
    print(f"{label}: {folder}")


In [None]:
import os
import shutil
import subprocess
from pathlib import Path

REPO_DIR = Path("/content/AudioBuff")


def run(cmd: list[str], cwd: Path | None = None) -> None:
    """Print ``cmd`` (space-joined, prefixed with ``+``) and run it, raising on failure.

    Args:
        cmd: Program followed by its arguments; items are passed through ``str`` only for
            the printed echo.
        cwd: Directory to run in; ``None`` inherits the kernel's working directory.

    Raises:
        subprocess.CalledProcessError: if the command exits with a non-zero status.
    """
    print("+", " ".join(map(str, cmd)))
    workdir = None if not cwd else str(cwd)
    subprocess.run(cmd, check=True, cwd=workdir)


# Always start from a fresh shallow clone so files from a previous session cannot linger.
if REPO_DIR.exists():
    shutil.rmtree(REPO_DIR)
run(["git", "clone", "--depth", "1", "--branch", BRANCH, REPO_URL, str(REPO_DIR)])

_pip_install = ["python", "-m", "pip", "install"]
run([*_pip_install, "--upgrade", "pip"], cwd=REPO_DIR)
run([*_pip_install, "-r", "requirements.txt"], cwd=REPO_DIR)

_xcodec_url = XCODEC_GIT_URL.strip()
if not _xcodec_url:
    print("XCODEC_GIT_URL is empty. Training/inference will use the built-in random fallback conditioning.")
else:
    run([*_pip_install, f"git+{_xcodec_url}"], cwd=REPO_DIR)


In [None]:
from pathlib import Path


def must_contain(path: Path, tokens: list[str]) -> None:
    """Assert that every string in ``tokens`` occurs somewhere in the text of ``path``.

    This is a plain substring scan of the UTF-8 source. It is used as a cheap guard that
    the cloned scripts mention the CLI flags the later cells pass to them.

    Raises:
        AssertionError: listing every token that was not found.
        FileNotFoundError: if ``path`` does not exist (raised by ``read_text``).
    """
    source = path.read_text(encoding="utf-8")
    absent = [tok for tok in tokens if tok not in source]
    if not absent:
        return
    raise AssertionError(f"{path} missing expected tokens: {absent}")


for script, flags in (
    (REPO_DIR / "train.py", ["--profile", "--grad_accum_steps", "--amp", "--no_amp"]),
    (REPO_DIR / "data" / "prepare_latents.py", ["--segment_seconds", "--max_files", "--dac_model_type"]),
    (REPO_DIR / "inference" / "generate.py", ["--chunk_tokens", "--hop_tokens", "--steps"]),
):
    must_contain(script, flags)

print("Sanity checks passed: expected CLI options are present.")


In [None]:
import json

# Per-profile settings for the prep / train / infer cells below. "smoke" processes 2 files
# for 1 epoch, "poc" 80 files for 8 epochs, "full" every file (max_files=-1) for 100 epochs.
PROFILES = {
    "smoke": dict(
        prep=dict(segment_seconds=2.0, max_files=2),
        train=dict(
            epochs=1, batch_size=2, num_workers=2,
            grad_accum_steps=1, amp=False, log_interval=5,
        ),
        infer=dict(cfg=1.5, steps=8, chunk_tokens=64, hop_tokens=32),
    ),
    "poc": dict(
        prep=dict(segment_seconds=5.0, max_files=80),
        train=dict(
            epochs=8, batch_size=4, num_workers=2,
            grad_accum_steps=4, amp=True, log_interval=20,
        ),
        infer=dict(cfg=1.8, steps=24, chunk_tokens=192, hop_tokens=96),
    ),
    "full": dict(
        prep=dict(segment_seconds=5.0, max_files=-1),
        train=dict(
            epochs=100, batch_size=8, num_workers=4,
            grad_accum_steps=8, amp=True, log_interval=50,
        ),
        infer=dict(cfg=2.0, steps=32, chunk_tokens=256, hop_tokens=128),
    ),
}

try:
    ACTIVE = PROFILES[PROFILE]
except KeyError:
    raise ValueError(f"Unsupported PROFILE '{PROFILE}'. Use one of: {list(PROFILES)}") from None

print(json.dumps(ACTIVE, indent=2))


In [None]:
import subprocess
from pathlib import Path

# Extract latents from the raw audio with data.prepare_latents, writing them to Drive.
# Uses the `run` helper and REPO_DIR from the clone/install cell; `run` is intentionally
# not redefined here, to avoid duplicate, possibly diverging copies.
if RUN_PREP:
    prep = ACTIVE["prep"]
    source_dir = Path(RAW_AUDIO_DIR)
    # Fail here with a clear message instead of somewhere inside the extraction script.
    if not source_dir.is_dir():
        raise FileNotFoundError(f"RAW_AUDIO_DIR is not an existing directory: {source_dir}")
    cmd = [
        "python", "-m", "data.prepare_latents",
        "--source_dir", str(source_dir),
        "--out_dir", str(PERSIST_LATENTS),
        "--dac_model_type", DAC_MODEL_TYPE,
        "--segment_seconds", str(prep["segment_seconds"]),
        "--seed", str(SEED),
    ]
    # A non-positive max_files (the "full" profile uses -1) omits the flag entirely, so the
    # script's own default applies. Presumably that means "all files"; confirm in prepare_latents.py.
    if prep["max_files"] > 0:
        cmd.extend(["--max_files", str(prep["max_files"])])
    run(cmd, cwd=REPO_DIR)
else:
    print("Skipping latent extraction (RUN_PREP=False)")


In [None]:
import re
import subprocess
from pathlib import Path

# Launch train.py through torchrun (single node, single process).
# Uses the `run` helper and REPO_DIR from the clone/install cell; `run` is intentionally
# not redefined here, to avoid duplicate, possibly diverging copies.
if RUN_TRAIN:
    train_cfg = ACTIVE["train"]
    cmd = [
        "torchrun", "--standalone", "--nnodes=1", "--nproc_per_node=1", "train.py",
        "--profile", PROFILE,
        "--data_dir", str(PERSIST_LATENTS),
        "--checkpoint_dir", str(PERSIST_CHECKPOINTS),
        "--epochs", str(train_cfg["epochs"]),
        "--batch_size", str(train_cfg["batch_size"]),
        "--num_workers", str(train_cfg["num_workers"]),
        "--grad_accum_steps", str(train_cfg["grad_accum_steps"]),
        "--log_interval", str(train_cfg["log_interval"]),
        "--seed", str(SEED),
    ]
    # train.py takes an explicit on/off pair, so always pass one of the two.
    cmd.append("--amp" if train_cfg["amp"] else "--no_amp")

    if RESUME_FROM_LATEST:
        # Digit-aware sort: plain string order ranks checkpoint_epoch_10.pt before
        # checkpoint_epoch_9.pt, resuming from the wrong epoch once epochs reach two digits
        # (unless the numbers are zero-padded). re.split with a capture group alternates
        # text/digit tokens, so compared list positions always hold the same type.
        checkpoints = sorted(
            PERSIST_CHECKPOINTS.glob("checkpoint_epoch_*.pt"),
            key=lambda p: [int(tok) if tok.isdecimal() else tok for tok in re.split(r"(\d+)", p.name)],
        )
        if checkpoints:
            cmd.extend(["--resume_path", str(checkpoints[-1])])
            print(f"Resuming from: {checkpoints[-1]}")

    run(cmd, cwd=REPO_DIR)
else:
    print("Skipping training (RUN_TRAIN=False)")


In [None]:
import re
import subprocess
from pathlib import Path

# Repair one wav with inference.generate and save the result to Drive.
# Uses the `run` helper and REPO_DIR from the clone/install cell; `run` is intentionally
# not redefined here, to avoid duplicate, possibly diverging copies.
if RUN_INFER:
    infer_cfg = ACTIVE["infer"]

    explicit_wav = INFER_INPUT_WAV.strip()
    if explicit_wav:
        # Use the same stripped value the emptiness test looked at.
        input_wav = Path(explicit_wav)
        if not input_wav.is_file():
            raise FileNotFoundError(f"INFER_INPUT_WAV does not exist: {input_wav}")
    else:
        # Fall back to the first (name-sorted) wav in the raw audio directory.
        wav_candidates = sorted(Path(RAW_AUDIO_DIR).glob("*.wav"))
        if not wav_candidates:
            raise FileNotFoundError(f"No .wav files found in {RAW_AUDIO_DIR} and INFER_INPUT_WAV is empty.")
        input_wav = wav_candidates[0]

    # Prefer the best checkpoint; otherwise fall back to the highest-numbered epoch checkpoint.
    checkpoint = PERSIST_CHECKPOINTS / "checkpoint_best.pt"
    if not checkpoint.exists():
        # Digit-aware sort: plain string order ranks checkpoint_epoch_10.pt before
        # checkpoint_epoch_9.pt (unless zero-padded), which would pick a stale checkpoint.
        epoch_checkpoints = sorted(
            PERSIST_CHECKPOINTS.glob("checkpoint_epoch_*.pt"),
            key=lambda p: [int(tok) if tok.isdecimal() else tok for tok in re.split(r"(\d+)", p.name)],
        )
        if not epoch_checkpoints:
            raise FileNotFoundError("No checkpoint found for inference.")
        checkpoint = epoch_checkpoints[-1]

    output_wav = PERSIST_ARTIFACTS / f"repaired_{input_wav.stem}_{PROFILE}.wav"

    cmd = [
        "python", "-m", "inference.generate",
        "--input", str(input_wav),
        "--output", str(output_wav),
        "--checkpoint", str(checkpoint),
        "--cfg", str(infer_cfg["cfg"]),
        "--steps", str(infer_cfg["steps"]),
        "--chunk_tokens", str(infer_cfg["chunk_tokens"]),
        "--hop_tokens", str(infer_cfg["hop_tokens"]),
        "--dac_model_type", DAC_MODEL_TYPE,
    ]
    run(cmd, cwd=REPO_DIR)
    print(f"Saved: {output_wav}")
else:
    print("Skipping inference (RUN_INFER=False)")


In [None]:
from pathlib import Path

# Final inventory of what the pipeline left on Drive.
latent_files = sorted(PERSIST_LATENTS.glob("*.pt"))
# The "Latest ..." listings below mean most recently written, so order by modification
# time. Name order is wrong for this: it puts checkpoint_epoch_10 before checkpoint_epoch_2,
# and artifact names (repaired_<stem>_<profile>.wav) carry no ordering information at all.
checkpoint_files = sorted(PERSIST_CHECKPOINTS.glob("*.pt"), key=lambda p: p.stat().st_mtime)
artifact_wavs = sorted(PERSIST_ARTIFACTS.glob("*.wav"), key=lambda p: p.stat().st_mtime)

print(f"Latent files: {len(latent_files)}")
print(f"Checkpoint files: {len(checkpoint_files)}")
print(f"Artifact wav files: {len(artifact_wavs)}")

if checkpoint_files:
    print("Latest checkpoints:")
    for path in checkpoint_files[-5:]:
        print(" -", path)

if artifact_wavs:
    print("Latest artifact wavs:")
    for path in artifact_wavs[-5:]:
        print(" -", path)
