In [1]:
!pip install moshi

In [2]:
import sys
# Clone the repository
!rm -rf Pleias-RAG-Library/
!git clone https://github.com/Pleias/Pleias-RAG-Library

# Install the cloned package in development mode
%cd Pleias-RAG-Library
!pip install -e .

%cd ..
sys.path.append('/content/Pleias-RAG-Library')

In [3]:
!pip install fastapi uvicorn nest_asyncio pydantic

In [4]:
!pip install pyngrok

## Pleias RAG model

In [5]:
import pleias_rag_interface
print(pleias_rag_interface.__version__)

In [6]:
# Simple manual test
from pleias_rag_interface import RAGWithCitations

# We initialize our model and interface here; `rag` is used by llm_forward below.
rag = RAGWithCitations(
  model_path_or_name="PleIAs/Pleias-RAG-350M",
  backend = "transformers"
)

## Speech-to-text (STT)

In [7]:
from dataclasses import dataclass
import time
import sentencepiece
import sphn
import textwrap
import torch

from moshi.models import loaders, MimiModel, LMModel, LMGen

In [8]:
import nest_asyncio
import os
import threading
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from pyngrok import ngrok
import uvicorn
from pydantic import BaseModel, Field
import numpy as np
import torch
import logging


# Logger whose records are captured by the handler attached in the server-start cell
# and printed by print_server_logs().
# NOTE(review): no code in this notebook actually logs to 'stt', so that capture
# buffer stays empty unless something else logs here.
logger = logging.getLogger('stt')
logger.setLevel(logging.DEBUG)

app = FastAPI()
# Permissive CORS so any origin (e.g. a browser client behind the ngrok URL) can call the API.
# NOTE(review): wildcard origins together with allow_credentials=True is very permissive;
# restrict allow_origins if the server ever handles anything sensitive.
app.add_middleware(
    CORSMiddleware,
    allow_origins=['*'],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)

In [9]:
# Register the ngrok auth token (replace the placeholder before running).
# NOTE(review): do not commit a real token in the notebook; read it from an
# environment variable or a Colab secret instead.
!ngrok authtoken YOUR_NGROK_AUTH_TOKEN_HERE

In [10]:
# Running transcript; receive_audio appends each /stt chunk's text to it.
TRANSCRIPT_BUFFER = ""

@dataclass
class InferenceState:
    """Streaming speech-to-text state: a Mimi codec plus a Kyutai STT language model.

    Both models are switched into ``streaming_forever`` mode in ``__init__``, so
    their internal state carries over between successive ``run`` calls: each call
    continues the same audio stream instead of starting a new one.
    """

    mimi: MimiModel
    text_tokenizer: sentencepiece.SentencePieceProcessor
    lm_gen: LMGen

    def __init__(
        self,
        mimi: MimiModel,
        text_tokenizer: sentencepiece.SentencePieceProcessor,
        lm: LMModel,
        batch_size: int,
        device: str | torch.device,
    ):
        self.mimi = mimi
        self.text_tokenizer = text_tokenizer
        # Greedy decoding (temperature 0, no sampling) for deterministic transcripts.
        self.lm_gen = LMGen(lm, temp=0, temp_text=0, use_sampling=False)
        self.device = device
        # Number of PCM samples consumed per codec frame.
        self.frame_size = int(self.mimi.sample_rate / self.mimi.frame_rate)
        self.batch_size = batch_size
        # True until the very first frame of the stream has been fed to the LM.
        # Tracked per instance (not per run() call) because the streaming state
        # persists across calls; only the stream's first frame needs the extra step.
        self._awaiting_first_frame = True
        self.mimi.streaming_forever(batch_size)
        self.lm_gen.streaming_forever(batch_size)

    def run(self, in_pcms: torch.Tensor) -> str:
        """Transcribe one slice of audio, continuing the current stream.

        Args:
            in_pcms: PCM tensor split along dim 2 into frames of ``frame_size``
                samples; the callers in this notebook pass shape (1, 1, samples).
        Returns:
            Text decoded for this slice (possibly empty). Trailing samples that do
            not fill a whole frame are dropped.
        """
        frames = [
            frame
            for frame in in_pcms.split(self.frame_size, dim=2)
            if frame.shape[-1] == self.frame_size
        ]
        start_time = time.time()
        pieces = []
        ntokens = 0
        for frame in frames:
            codes = self.mimi.encode(frame)
            if self._awaiting_first_frame:
                # Ensure that the first slice of codes is properly seen by the transformer
                # as otherwise the first slice is replaced by the initial tokens.
                self.lm_gen.step(codes)
                self._awaiting_first_frame = False
            tokens = self.lm_gen.step(codes)
            if tokens is None:
                continue
            assert tokens.shape[1] == 1
            token_id = tokens[0, 0].item()
            # Ids 0 and 3 are not emitted as text (presumably pad/special ids — TODO confirm).
            if token_id not in (0, 3):
                # "▁" marks a word start in sentencepiece pieces; map it to a space.
                pieces.append(self.text_tokenizer.id_to_piece(token_id).replace("▁", " "))
            ntokens += 1
        dt = time.time() - start_time
        if ntokens:
            print(
                f"processed {ntokens} steps in {dt:.0f}s, {1000 * dt / ntokens:.2f}ms/step"
            )
        else:
            # Short or empty input (or steps still returning None): avoid dividing by zero.
            print(f"processed 0 steps in {dt:.0f}s")
        return "".join(pieces)

class AudioChunk(BaseModel):
    """Schema for one chunk of audio posted by a client.

    NOTE(review): not used by any endpoint in this notebook; receive_audio parses
    the raw JSON body instead.
    """
    type: str  # "audio_chunk"
    # Raw PCM samples; receive_audio converts them to float32. Presumably mono at the
    # STT model's sample rate — TODO confirm with the client.
    pcm: list[float]

In [11]:
from typing import List, Dict, Any, Optional
import asyncio
import json

# Separate CUDA streams for STT work (stt_forward) and LLM work (llm_forward).
# NOTE(review): these names only exist when CUDA is available; code that uses them
# must guard for the CPU case.
if torch.cuda.is_available():
  stt_stream = torch.cuda.Stream()
  llm_stream = torch.cuda.Stream()

class ModelInfo(BaseModel):
    """One entry of the OpenAI-style /v1/models listing that the client expects."""
    id: str = "llm-model"
    object: str = "model"
    # Static placeholder (presumably a Unix timestamp, as in the OpenAI schema).
    created: int = 1677652288
    owned_by: str = "local"

class ModelList(BaseModel):
    """Response body of GET /v1/models: an OpenAI-style list wrapper around ModelInfo."""
    object: str = "list"
    data: List[ModelInfo]

class ChatCompletionRequest(BaseModel):
    """Request body of POST /v1/chat/completions (a subset of the OpenAI schema)."""
    model: str  # accepted but not read by the handlers in this notebook
    messages: List[Dict[str, Any]]  # OpenAI-style {"role": ..., "content": ...} dicts
    stream: bool = False  # only stream=True is served
    # Extra options; rag_to_openai_stream reads retrieval "sources" from here.
    # NOTE(review): the official OpenAI client merges its `extra_body` argument into the
    # top-level JSON instead of sending an "extra_body" key — confirm what the caller sends.
    extra_body: Optional[Dict[str, Any]] = None

def messages_to_prompt(messages: list[dict]) -> str:
    """Build the RAG query string from OpenAI-style chat messages.

    Only a leading system message and a final user message are kept; any
    intermediate turns are dropped.

    Args:
        messages: list of {"role": ..., "content": ...} dicts.
    Returns:
        A "[System]" section and/or a "[User]" section separated by a blank line,
        or "" when ``messages`` is empty or contains neither role in those positions.
    """
    if not messages:
        return ""
    first, last = messages[0], messages[-1]
    sections = []
    if first.get("role") == "system":
        sections.append(f"[System]\n{first['content']}\n")
    if last.get("role") == "user":
        sections.append(f"[User]\n{last['content']}\n")
    return "\n".join(sections).strip()

async def llm_forward(query, sources):
    """Run the blocking ``rag.generate(query=..., sources=...)`` in a worker thread.

    On CUDA the generation is issued on ``llm_stream``. PyTorch's current CUDA
    stream is per-thread, so the stream context has to be entered inside the
    worker thread. Entering it in the event-loop thread would not affect
    ``rag.generate``. It would also switch the stream for every other coroutine
    the loop runs while this call is awaited.

    Returns:
        Whatever ``rag.generate`` returns.
    """
    import contextlib

    def _generate():
        # llm_stream is only created when CUDA is available.
        if torch.cuda.is_available():
            stream_ctx = torch.cuda.stream(llm_stream)
        else:
            stream_ctx = contextlib.nullcontext()
        with stream_ctx:
            return rag.generate(query=query, sources=sources)

    return await asyncio.to_thread(_generate)

async def rag_to_openai_stream(query: Any):
    """Answer a chat request with the RAG model, streamed as OpenAI-style SSE chunks.

    Args:
        query: the parsed ChatCompletionRequest. ``query.extra_body["sources"]``,
            if present, is passed to the model as retrieval sources.
    Yields:
        ``data: {...}`` server-sent events carrying up to ``chunk_size`` characters
        of the answer each, followed by ``data: [DONE]``.
    """
    messages = query.messages
    print("query: ", query)
    # Parenthesised so that a provided extra_body yields its "sources" entry rather
    # than the whole dict; a missing or null entry becomes [].
    sources = (query.extra_body or {}).get("sources") or []
    prompt = messages_to_prompt(messages)
    print("prompt: ", prompt)
    print("sources: ", sources)
    full_text_response = await llm_forward(prompt, sources)

    chunk_size = 10
    print("raw text response: ", full_text_response)
    # Presumably rag.generate returns a dict with the cleaned answer under
    # ["processed"]["clean_answer"] — schema not visible here.
    answer = full_text_response["processed"]["clean_answer"]
    for start in range(0, len(answer), chunk_size):
        data = {
            "id": "chatcmpl-12345",
            "object": "chat.completion.chunk",
            "choices": [{
                "delta": {"content": answer[start:start + chunk_size]},
                "index": 0,
                "finish_reason": None,
            }],
        }
        yield f"data: {json.dumps(data)}\n\n"
        # Small pause so the chunks arrive as a visible stream.
        await asyncio.sleep(0.01)
    yield "data: [DONE]\n\n"

In [12]:
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse

# Load the Kyutai STT model used by the /stt endpoint.
# NOTE(review): device is hard-coded to "cuda", so this cell needs a GPU runtime.
device = "cuda"
# Use the en+fr low latency model, an alternative is kyutai/stt-2.6b-en
checkpoint_info = loaders.CheckpointInfo.from_hf_repo("kyutai/stt-1b-en_fr")
mimi = checkpoint_info.get_mimi(device=device)  # codec: encodes PCM frames to codes
text_tokenizer = checkpoint_info.get_text_tokenizer()
lm = checkpoint_info.get_moshi(device=device)
# Single shared streaming state (batch_size=1): every /stt request continues this stream.
inference_state = InferenceState(mimi, text_tokenizer, lm, batch_size=1, device=device)

# Serialises access to the shared streaming STT state: /stt handlers dispatch
# stt_forward to worker threads, so two requests could otherwise run it at once.
_STT_LOCK = threading.Lock()


def stt_forward(in_pcms):
    """Transcribe one audio tensor with the global streaming ``inference_state``.

    Runs on ``stt_stream`` when CUDA is available. Calls are serialised with
    ``_STT_LOCK`` because ``inference_state`` keeps streaming state across calls
    and is not safe to drive from several threads concurrently. The lock prevents
    interleaving but does not order concurrent requests.

    Returns:
        The text decoded for ``in_pcms``.
    """
    import contextlib

    # stt_stream is only created when CUDA is available.
    if torch.cuda.is_available():
        stream_ctx = torch.cuda.stream(stt_stream)
    else:
        stream_ctx = contextlib.nullcontext()
    with _STT_LOCK, stream_ctx:
        return inference_state.run(in_pcms)

@app.get("/")
def root():
    """Health check: a fixed JSON payload showing the FastAPI app is reachable."""
    payload = {"message": "Colab FastAPI is working!"}
    return payload

@app.post("/stt")
async def receive_audio(request: Request):
    """Transcribe one posted chunk of audio and append the text to TRANSCRIPT_BUFFER.

    Expects a JSON object with a "pcm" list of float samples (presumably mono audio
    at the STT model's sample rate — TODO confirm with the client).

    Returns:
        {"text": <text decoded for this chunk>}; "" for an empty "pcm" list.
    Raises:
        HTTPException(400) if the body is not an object with a "pcm" entry.
    """
    global TRANSCRIPT_BUFFER
    from fastapi import HTTPException

    data = await request.json()
    pcm = data.get("pcm") if isinstance(data, dict) else None
    if pcm is None:
        raise HTTPException(status_code=400, detail='Body must be a JSON object with a "pcm" list.')
    # Convert back to numpy array
    audio_np = np.asarray(pcm, dtype=np.float32)
    if audio_np.size == 0:
        # Nothing to decode; skip the model entirely.
        return {"text": ""}

    # Add batch and channel dims: (1, 1, samples).
    audio_tensor = torch.from_numpy(audio_np)[None, None, :].to(device)

    text_chunk = await asyncio.to_thread(
        stt_forward,
        in_pcms=audio_tensor
    )
    # Decoded pieces already carry their own word-start spaces, so chunks are
    # concatenated as-is. Inserting a separator would split words that straddle
    # two requests and pile up spaces for empty chunks.
    TRANSCRIPT_BUFFER += text_chunk

    return {"text": text_chunk}

@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """Advertise the single served model (queried when LLMService initializes)."""
    advertised = ModelInfo(id="colab-llama-3-mock")
    return ModelList(data=[advertised])


@app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
    """OpenAI-compatible chat completions endpoint (used by LLMService.stream_response).

    Only ``stream=True`` is supported, and the answer is sent as server-sent events.
    Non-streaming requests get HTTP 501 so that clients see a failure instead of a
    200 response whose body happens to be an error.
    """
    if not request.stream:
        return JSONResponse(
            status_code=501,
            content={"error": "Non-streaming is not implemented."},
        )
    return StreamingResponse(rag_to_openai_stream(request), media_type="text/event-stream")

In [13]:
import io
import sys
import time
# Expose local port 7680 (the port uvicorn listens on in run() below) via an ngrok tunnel.
# NOTE(review): every re-run of this cell opens another tunnel.
public_url = ngrok.connect('7680')
print(public_url)

# Patch asyncio so nested event loops work inside the notebook kernel.
nest_asyncio.apply()

# In-memory sink for the 'stt' logger; print_server_logs() reads and clears it.
log_capture_buffer = io.StringIO()

# Create a StreamHandler that writes to the StringIO buffer.
capture_handler = logging.StreamHandler(log_capture_buffer)
capture_handler.setLevel(logging.DEBUG) # Set the handler level

# (Optional) Define a simple formatter.
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
capture_handler.setFormatter(formatter)

# Attach the handler to the application logger.
# NOTE(review): re-running this cell adds another handler; the previous one keeps
# writing to the old buffer.
logger.addHandler(capture_handler)

def run():
    """Serve `app` with uvicorn on 0.0.0.0:7680; blocks until the server exits."""
    try:
      uvicorn.run(app, host="0.0.0.0", port=7680, log_level="debug")
    finally:
      # Restore the original stderr when the server stops.
      # NOTE(review): nothing in this notebook redirects stderr, so this looks vestigial.
      sys.stderr = sys.__stderr__


# Daemon thread so the server does not block the kernel and dies with it.
# NOTE(review): re-running this cell starts a second server on the same port.
server_thread = threading.Thread(target=run, daemon=True)
server_thread.start()

In [14]:
def print_server_logs(buffer: Optional[io.StringIO] = None, handler: Optional[logging.Handler] = None) -> None:
    """Print the logs captured so far, then clear them.

    Args:
        buffer: StringIO the logs are written to; defaults to the global
            ``log_capture_buffer``.
        handler: the handler writing into ``buffer``. Its lock is held while the
            buffer is read and cleared, so records emitted concurrently by the
            server thread are neither lost nor printed twice. Defaults to the global
            ``capture_handler`` when one exists.
    """
    import contextlib

    if buffer is None:
        buffer = log_capture_buffer
    if handler is None:
        handler = globals().get("capture_handler")
    handler_lock = getattr(handler, "lock", None)
    guard = handler_lock if handler_lock is not None else contextlib.nullcontext()
    with guard:
        logs = buffer.getvalue()
        # Clear the buffer after reading to avoid duplicates on the next call.
        buffer.seek(0)
        buffer.truncate(0)
    if logs:
        print("\n--- SERVER LOGS/ERRORS ---")
        print(logs, end='')
        print("--------------------------\n")
    else:
        print("No new server logs captured.")

In [None]:
# Print captured server logs every 3 seconds. There is no exit condition, so this
# occupies the kernel until the cell is interrupted.
while True:
  time.sleep(3)
  print_server_logs()

In [None]:
# Display the transcript accumulated by the /stt endpoint so far.
TRANSCRIPT_BUFFER

In [None]:
import io
# NOTE(review): this cell duplicates the server-start cell above. Running it after that
# cell opens a second ngrok tunnel, redefines run(), and starts a second uvicorn server
# on port 7680 while the first still holds the port (so it presumably fails to bind).
public_url = ngrok.connect('7680')
print(public_url)

nest_asyncio.apply()

# NOTE(review): rebinding log_capture_buffer here detaches it from capture_handler, which
# keeps writing to the previous StringIO; print_server_logs() then reports no logs.
log_capture_buffer = io.StringIO()

# NOTE(review): assigned but never used; nothing in this notebook writes this file.
log_file = "/content/server.log"
def run():
    """Serve `app` with uvicorn on 0.0.0.0:7680 (shadows the earlier run())."""
    try:
      uvicorn.run(app, host="0.0.0.0", port=7680, log_level="debug")
    finally:
      sys.stderr = sys.__stderr__

threading.Thread(target=run, daemon=True).start()

In [17]:
# NOTE(review): the server in this notebook listens on 7680, not 8000, so the kill below
# targets a different port. Do not simply switch it to 7680: uvicorn runs in a thread of
# this kernel process, so killing the listener's PID would kill the kernel itself.
!kill -9 $(lsof -t -i:8000) 2>/dev/null || true
# Stop the ngrok tunnel process(es).
!pkill ngrok || true

In [None]:
# NOTE(review): nothing in this notebook writes /content/server.log (log_file is only
# assigned), so this tails a file that may not exist; `tail -f` also never returns.
!tail -f /content/server.log

In [None]:
# Offline sanity check: transcribe a local sample file end-to-end with fresh models.
# NOTE(review): this loads a second copy of the STT models next to the server's ones
# and rebinds the global names mimi / text_tokenizer / lm.
device = "cuda"
# Use the en+fr low latency model, an alternative is kyutai/stt-2.6b-en
checkpoint_info = loaders.CheckpointInfo.from_hf_repo("kyutai/stt-1b-en_fr")
mimi = checkpoint_info.get_mimi(device=device)
text_tokenizer = checkpoint_info.get_text_tokenizer()
lm = checkpoint_info.get_moshi(device=device)
in_pcms, _ = sphn.read("sample_fr_hibiki_crepes.mp3", sample_rate=mimi.sample_rate)
in_pcms = torch.from_numpy(in_pcms).to(device=device)

# Pad with silence: a prefix from the STT config on the left, and audio delay + 1 s on
# the right (presumably so the delayed text output covers the end of the speech).
# Sample counts use the codec rate the audio was resampled to above, not a literal.
stt_config = checkpoint_info.stt_config
sample_rate = mimi.sample_rate
pad_left = int(stt_config.get("audio_silence_prefix_seconds", 0.0) * sample_rate)
pad_right = int((stt_config.get("audio_delay_seconds", 0.0) + 1.0) * sample_rate)
in_pcms = torch.nn.functional.pad(in_pcms, (pad_left, pad_right), mode="constant")
# Keep the first channel and add a batch dim (presumably (channels, samples) -> (1, 1, samples)).
in_pcms = in_pcms[None, 0:1].expand(1, -1, -1)

state = InferenceState(mimi, text_tokenizer, lm, batch_size=1, device=device)
text = state.run(in_pcms)
print(textwrap.fill(text, width=100))

In [None]:
from IPython.display import Audio

# Play the sample clip that the offline check above transcribes.
Audio("sample_fr_hibiki_crepes.mp3")