In [None]:
!pip install moshi

In [None]:
!pip install fastapi uvicorn nest_asyncio pydantic

In [None]:
!pip install pyngrok

In [None]:
from dataclasses import dataclass
import time
import sentencepiece
import sphn
import textwrap
import torch

from moshi.models import loaders, MimiModel, LMModel, LMGen

In [None]:
import nest_asyncio
import os
import threading
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pyngrok import ngrok
import uvicorn
from pydantic import BaseModel
import numpy as np
import torch

# FastAPI application; it is served by uvicorn (behind the ngrok tunnel) in a later cell.
app = FastAPI()

# Permissive CORS so a browser page on any origin can call the endpoints.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True makes the
# middleware accept credentialed cross-origin requests from any site (it presumably
# reflects the caller's Origin rather than "*"). No endpoint here reads cookies, so
# allow_credentials=False is likely the safer choice — confirm with the client.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

In [None]:
# Authenticate pyngrok without hardcoding the token in the notebook: read it from the
# NGROK_AUTH_TOKEN environment variable, or prompt for it (getpass does not echo input).
import getpass

ngrok.set_auth_token(
    os.environ.get("NGROK_AUTH_TOKEN") or getpass.getpass("ngrok auth token: ")
)

In [None]:
# Running transcript: the /stt handler appends every decoded text chunk to this string.
# NOTE(review): it grows without bound and no visible cell reads it.
TRANSCRIPT_BUFFER: str = ""

@dataclass
class InferenceState:
    """Streaming speech-to-text state: the Mimi audio codec plus the STT language model.

    Both models are switched to ``streaming_forever`` mode in ``__init__``, so their
    internal state carries over between ``run`` calls. Each call continues the same
    audio stream rather than starting a new one.
    """

    mimi: MimiModel
    text_tokenizer: sentencepiece.SentencePieceProcessor
    lm_gen: LMGen

    def __init__(
        self,
        mimi: MimiModel,
        text_tokenizer: sentencepiece.SentencePieceProcessor,
        lm: LMModel,
        batch_size: int,
        device: str | torch.device,
    ):
        """Wrap ``lm`` in a greedy ``LMGen`` and put both models in streaming mode.

        Args:
            mimi: audio codec that turns PCM frames into codes.
            text_tokenizer: SentencePiece model used to decode text token ids.
            lm: the STT language model.
            batch_size: number of parallel streams passed to ``streaming_forever``.
            device: device that input audio is moved to before encoding.
        """
        self.mimi = mimi
        self.text_tokenizer = text_tokenizer
        # temp=0 and use_sampling=False: deterministic greedy decoding.
        self.lm_gen = LMGen(lm, temp=0, temp_text=0, use_sampling=False)
        self.device = device
        # Number of PCM samples consumed per Mimi step.
        self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
        self.batch_size = batch_size
        self.mimi.streaming_forever(batch_size)
        self.lm_gen.streaming_forever(batch_size)
        # The streaming state outlives a single run() call, so the first-frame
        # warm-up in run() must happen once per stream, not once per call.
        self._stream_started = False
        # Trailing samples (fewer than frame_size) left over from the previous call.
        # They are prepended to the next call so no audio is dropped between chunks.
        self._pending = None

    def run(self, in_pcms: torch.Tensor) -> str:
        """Transcribe a block of audio, continuing the current stream.

        Args:
            in_pcms: PCM audio split along dim 2 into ``frame_size`` chunks. It is
                presumably shaped (batch, channels, samples) at
                ``mimi.sample_rate`` — TODO confirm against callers.

        Returns:
            The text decoded for this block ("" if no text token was emitted).
            Samples that do not fill a whole frame are kept and processed at the
            start of the next call.
        """
        in_pcms = in_pcms.to(self.device)
        if self._pending is not None:
            in_pcms = torch.cat([self._pending, in_pcms], dim=2)
            self._pending = None

        chunks = list(in_pcms.split(self.frame_size, dim=2))
        if chunks and chunks[-1].shape[2] < self.frame_size:
            # Only the last split can be short. Keep a copy so the full input tensor
            # is not retained just for its tail.
            self._pending = chunks.pop().clone()

        start_time = time.time()
        all_text = []
        ntokens = 0
        # Inference only: avoid building autograd graphs across streaming steps.
        with torch.no_grad():
            for chunk in chunks:
                codes = self.mimi.encode(chunk)
                if not self._stream_started:
                    # Ensure that the first slice of codes is properly seen by the transformer
                    # as otherwise the first slice is replaced by the initial tokens.
                    self.lm_gen.step(codes)
                    self._stream_started = True
                tokens = self.lm_gen.step(codes)
                if tokens is None:
                    continue
                assert tokens.shape[1] == 1
                token_id = tokens[0, 0].item()
                # Ids 0 and 3 are skipped (presumably padding/special text tokens —
                # confirm against the tokenizer).
                if token_id not in (0, 3):
                    # SentencePiece marks word starts with U+2581 ("▁"). The escape
                    # keeps this immune to source-file re-encoding.
                    text = self.text_tokenizer.id_to_piece(token_id).replace("\u2581", " ")
                    all_text.append(text)
                ntokens += 1
        dt = time.time() - start_time
        if ntokens:
            print(
                f"processed {ntokens} steps in {dt:.0f}s, {1000 * dt / ntokens:.2f}ms/step"
            )
        else:
            # Input shorter than one frame, or no token emitted yet: no per-step rate.
            print(f"processed {ntokens} steps in {dt:.0f}s")
        return "".join(all_text)

class AudioChunk(BaseModel):
    """Pydantic schema for an audio-chunk message.

    NOTE(review): no visible endpoint references this model. /stt parses the raw
    request JSON instead of validating against it.
    """

    type: str  # message tag, e.g. "audio_chunk"
    pcm: list[float]  # audio samples; presumably mono float PCM at mimi.sample_rate — TODO confirm

In [None]:
# Device that hosts both the Mimi codec and the STT language model.
device = "cuda"
# Use the en+fr low latency model, an alternative is kyutai/stt-2.6b-en
checkpoint_info = loaders.CheckpointInfo.from_hf_repo("kyutai/stt-1b-en_fr")
mimi = checkpoint_info.get_mimi(device=device)
text_tokenizer = checkpoint_info.get_text_tokenizer()
lm = checkpoint_info.get_moshi(device=device)
# A single stream (batch_size=1), matching the (1, 1, samples) tensors built by /stt.
inference_state = InferenceState(
    mimi=mimi,
    text_tokenizer=text_tokenizer,
    lm=lm,
    batch_size=1,
    device=device,
)

@app.get("/")
def root():
    """Health check: a fixed message confirming the app is reachable."""
    payload = {"message": "Colab FastAPI is working!"}
    return payload

@app.post("/stt")
async def receive_audio(request: Request):
    """Transcribe one chunk of streamed audio and append the text to TRANSCRIPT_BUFFER.

    Expects a JSON object with a ``"pcm"`` list of float samples. These are
    presumably mono audio at ``mimi.sample_rate`` (TODO confirm against the client).
    Returns ``{"text": <text decoded for this chunk>}``. A body without a flat
    numeric ``"pcm"`` list is rejected with HTTP 422.
    """
    global TRANSCRIPT_BUFFER
    from fastapi import HTTPException

    data = await request.json()
    pcm = data.get("pcm") if isinstance(data, dict) else None
    if not isinstance(pcm, list):
        raise HTTPException(status_code=422, detail='expected a JSON object with a "pcm" list')
    if not pcm:
        # Nothing to transcribe: skip the models entirely.
        return {"text": ""}
    try:
        audio_np = np.asarray(pcm, dtype=np.float32)
    except (TypeError, ValueError):
        raise HTTPException(status_code=422, detail='"pcm" must be a flat list of numbers') from None
    if audio_np.ndim != 1:
        raise HTTPException(status_code=422, detail='"pcm" must be a flat list of numbers')

    # (samples,) -> (1, 1, samples). torch.from_numpy yields a CPU tensor, so move
    # it to the device the models were loaded on.
    audio_tensor = torch.from_numpy(audio_np)[None, None, :].to(inference_state.device)
    # NOTE(review): run() is synchronous, so this async handler blocks the event loop
    # while it executes. That also keeps concurrent requests from interleaving on the
    # shared streaming state.
    with torch.no_grad():  # inference only; no autograd graph needed
        text_chunk = inference_state.run(audio_tensor)
    # Token pieces encode their own word boundaries, so append verbatim instead of
    # inserting an extra space (which would also split words spanning two chunks).
    TRANSCRIPT_BUFFER += text_chunk

    return {"text": text_chunk}

In [None]:
import io
import sys
import time
# Local port shared by the ngrok tunnel and the uvicorn server.
SERVER_PORT = 7680

# Public tunnel to the local server; its URL is printed for the client.
public_url = ngrok.connect(str(SERVER_PORT))
print(public_url)

nest_asyncio.apply()

# Receives the server's stderr output (see run()); print_server_logs() drains it.
log_capture_buffer = io.StringIO()

# NOTE(review): no visible cell uses this path; captured logs stay in memory only.
log_file = "/content/server.log"


def run():
    """Serve ``app`` with uvicorn, sending its stderr output to ``log_capture_buffer``.

    The stderr stream that was active before is restored if uvicorn returns or raises.
    """
    previous_stderr = sys.stderr
    # uvicorn's default logging config presumably points its main handler at
    # "ext://sys.stderr", resolved when uvicorn.run() configures logging. Redirecting
    # first therefore routes those logs into the buffer (access logs go to stdout).
    # NOTE: sys.stderr is process-wide, so other stderr writes are captured too.
    sys.stderr = log_capture_buffer
    try:
        uvicorn.run(app, host="0.0.0.0", port=SERVER_PORT, log_level="debug")
    finally:
        # Restore the previous stream. In a notebook kernel this is the cell-output
        # stream, not sys.__stderr__.
        sys.stderr = previous_stderr


# Daemon thread, so the running server does not block interpreter shutdown.
server_thread = threading.Thread(target=run, daemon=True)
server_thread.start()

In [None]:
def print_server_logs():
    """Prints and clears the captured logs."""
    logs = log_capture_buffer.getvalue()
    if logs:
        print("\n--- SERVER LOGS/ERRORS ---")
        print(logs, end='')
        # Clear the buffer after printing to avoid duplicates
        log_capture_buffer.truncate(0)
        log_capture_buffer.seek(0)
        print("--------------------------\n")

In [None]:
# Seconds between echoes of captured server logs into this cell's output.
LOG_POLL_SECONDS = 3

# Blocks this cell until it is interrupted. The server keeps running in its
# background thread either way.
try:
    while True:
        time.sleep(LOG_POLL_SECONDS)
        print_server_logs()
except KeyboardInterrupt:
    # Show anything logged since the last poll, then end without a traceback.
    print_server_logs()
    print("Stopped log polling; the server thread is still running.")