In [None]:
# Re-import edited project modules (e.g. overfit_trial) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import torch
import numpy as np

from pathlib import Path
from transformers import MimiModel, AutoFeatureExtractor

In [None]:
from overfit_trial.data_models import (
    MimiChannelArchive,
)
from overfit_trial.model import MachOverfitModel
from overfit_trial.inference import SlidingDuplexModelInference

from matplotlib import pyplot as plt
from IPython.display import Audio

In [None]:
# Latent frames per second (seconds * LATENT_FREQ -> frames) and the number of Mimi
# quantizers used for both the codec model and the code archives.
LATENT_FREQ, num_quantizers = 12.5, 32

In [None]:
%%capture

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mimi = MimiModel.from_pretrained("kyutai/mimi", num_quantizers=num_quantizers)
feature_extractor = AutoFeatureExtractor.from_pretrained("kyutai/mimi")

In [None]:
from contextlib import contextmanager


@contextmanager
def offload_after_context():
    """Yield a tensor-wrapper class that records tensors moved to CUDA inside the block.

    Usage::

        with offload_after_context() as Tracker:
            x_gpu = Tracker(x).to("cuda")

    ``Tracker(t).to(device)`` returns exactly ``t.to(device)``; if the resulting tensor
    lives on a CUDA device it is recorded. On exit (normal or via an exception) the
    recorded references are dropped and, when CUDA is available, PyTorch's caching
    allocator is asked to return unused cached blocks with ``torch.cuda.empty_cache()``.

    Note: ``Tensor.cpu()`` is not in-place, so a recorded tensor's GPU memory only becomes
    reclaimable once every other reference to it (such as ``x_gpu`` above) is gone too.
    """
    gpu_tensors = []

    class Tracker:
        """Wrap a tensor so that ``.to(...)`` forwards to it and tracks CUDA results."""

        def __init__(self, tensor):
            self.tensor = tensor

        def to(self, device):
            moved = self.tensor.to(device)
            # Inspect the result instead of the argument: works for "cuda:1",
            # torch.device objects and ints, none of which support str.startswith.
            if moved.is_cuda:
                gpu_tensors.append(moved)
            return moved

    try:
        yield Tracker
    finally:
        # Calling `t.cpu()` here would only create a discarded host copy and leave `t`
        # on the GPU; releasing our references is what actually allows freeing.
        gpu_tensors.clear()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
# Channel ids of the paired recording; index 0 is loaded as the user channel below.
channel_id_pair = ["V00_S0696_I00000375_P0844A", "V00_S0696_I00000375_P0847"]
# `num_quantizers` is defined once in the config cell above; re-assigning it here could let it
# drift from the value the Mimi model was loaded with.
SAMPLE_RATE_HZ = 24000  # audio sample rate of the Mimi waveforms handled in this notebook

In [None]:
# Assumes the notebook is run from a direct subdirectory of the repository.
repo_root = Path.cwd().parent

# Archive of Mimi codes for the first channel (used as the user side of the conversation).
user_npz_path = (
    repo_root / "asset" / "single_pair_dataset" / f"{channel_id_pair[0]}.{num_quantizers}q.npz"
)
assert user_npz_path.is_file(), f"missing code archive: {user_npz_path}"

channel1 = MimiChannelArchive(
    channel_id=channel_id_pair[0],
    npz_path=user_npz_path,
)

In [None]:
user = channel1.load_codes()
user_codes = user.to_torch()

# mimi.decode takes (batch, num_quantizers, frames) and the decode cell below adds the batch
# axis with unsqueeze(0), so the codes must be (num_quantizers, frames) here.
assert user_codes.ndim == 2 and user_codes.shape[0] == num_quantizers, (
    f"unexpected user code shape {tuple(user_codes.shape)}; expected ({num_quantizers}, frames)"
)

In [None]:
# Decode the user's codes back to a mono waveform. inference_mode disables autograd
# tracking, so no graph is built or kept alive while the codec runs.
with torch.inference_mode():
    model_output = mimi.decode(user_codes.unsqueeze(0))
    user_audio = model_output.audio_values.detach().cpu().squeeze().numpy()

In [None]:
# Quick visual sanity check of the first seconds of the decoded user waveform.
preview_seconds = 2
user_preview = user_audio[: SAMPLE_RATE_HZ * preview_seconds]
preview_time_s = np.arange(user_preview.shape[0]) / SAMPLE_RATE_HZ
fig, ax = plt.subplots()
ax.plot(preview_time_s, user_preview)
ax.set(title=f"User audio (first {preview_seconds} s)", xlabel="time [s]", ylabel="amplitude");

In [None]:
# Generation window: start offset is given in frames; durations are given in seconds and
# converted to latent frames with LATENT_FREQ (frames per second).
start_frame = 0  # start offset (in frames) passed to engine.generate
warmp_up_seconds = 2  # (sic) spelling kept because later cells reference this name
warmup_frames = int(warmp_up_seconds * LATENT_FREQ)
decode_seconds = 10  # converted to the number of generation steps below

In [None]:
# Trained duplex checkpoint. Resolved relative to the repository root (like the dataset path)
# rather than a machine-specific absolute path; DUPLEX_CHECKPOINT overrides it if set.
duplex_checkpoint = os.environ.get(
    "DUPLEX_CHECKPOINT",
    str(repo_root / "checkpoints" / "overfit_trial" / "checkpoint_update_500.pt"),
)

In [None]:
# Duplex model + sliding-window inference engine. The embedding directory is resolved from
# `repo_root`, and the quantizer count and device reuse the notebook-level settings instead
# of hard-coded 32 / "cuda" (so the engine follows the device chosen when loading Mimi).
duplex = MachOverfitModel(
    num_quantizers=num_quantizers,
    mimi_audio_embed_dir=str(repo_root / "asset" / "mimi_audio_embeddings"),
)
engine = SlidingDuplexModelInference(
    duplex,
    num_quantizers=num_quantizers,
    window_size=512,
    device=str(device),  # "cuda" or "cpu"; the original passed a plain string too
)
engine.load_checkpoint(duplex_checkpoint)

In [None]:
# Generate the assistant side: `decode_seconds` worth of steps, starting at `start_frame`
# with `warmup_frames` of warm-up, conditioned on the user's codes.
result = engine.generate(
    num_steps=int(LATENT_FREQ * decode_seconds),
    warmup_frames=warmup_frames,
    start_frame=start_frame,
    user_codes=user_codes.to(device),
)

In [None]:
# Decode the generated assistant codes to a waveform without autograd tracking.
with torch.inference_mode():
    model_output = mimi.decode(result.cpu())
    audio_arr = model_output.audio_values.detach().cpu().squeeze().numpy()

# Left-pad with silence for the warm-up span (the generated audio presumably covers only the
# frames after warm-up). Sizing the pad from `warmup_frames` rather than the seconds value keeps
# it aligned even when LATENT_FREQ * seconds is not a whole number of frames.
warmup_samples = int(round(warmup_frames * SAMPLE_RATE_HZ / LATENT_FREQ))
full_assistant_audio = np.concatenate(
    [np.zeros(warmup_samples, dtype=audio_arr.dtype), audio_arr]
)

In [None]:
# Both tracks start at `start_frame`; convert that frame index to a sample offset
# (SAMPLE_RATE_HZ / LATENT_FREQ samples per frame) and take an equally long user slice.
start_sample = int(round(start_frame * SAMPLE_RATE_HZ / LATENT_FREQ))
full_user = user_audio[start_sample : start_sample + full_assistant_audio.shape[0]]
assert full_user.shape == full_assistant_audio.shape, (
    f"user audio {full_user.shape} is shorter than the assistant track {full_assistant_audio.shape}"
)
full_convo = full_user + full_assistant_audio
Audio(full_convo, rate=SAMPLE_RATE_HZ)