# DEPENDENCIES

In [None]:
%%capture
# Only exllamav2 is version-pinned; every other package resolves to whatever
# is latest when the cell runs.
# NOTE(review): fastapi, uvicorn, nest_asyncio, pytz, python-dotenv and
# cryptography are imported below but not installed here — presumably they are
# provided by the runtime image or pulled in transitively; confirm.
!pip install hyperdb-python
!pip install sentence-transformers
!pip install pyngrok
!pip install exllamav2==0.0.8
!pip install edge-tts
!pip install rvc

# IMPORTS

In [None]:
from exllamav2 import (ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer)
from exllamav2.generator import (ExLlamaV2StreamingGenerator, ExLlamaV2Sampler)
from sentence_transformers import SentenceTransformer
from hyperdb import HyperDB
from huggingface_hub import snapshot_download, hf_hub_download

import edge_tts
from pathlib import Path
from dotenv import load_dotenv
from scipy.io import wavfile
from rvc.modules.vc.modules import VC

import uvicorn
import asyncio
from fastapi import FastAPI, Request
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from typing import Any, Optional
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import padding
from datetime import datetime
import base64
import json
import calendar
import time
import pytz
import threading
import os
import re
import shutil
import subprocess
import sys
import logging
import nest_asyncio
nest_asyncio.apply()

# CONSTANTS

In [None]:
# NGROK
# Auth token handed to `ngrok config add-authtoken` and the domain handed to
# `ngrok http --domain=...`. Both are blank here and must be filled in before
# the tunnel cells are run.
NGROK_AUTHTOKEN = ""
NGROK_STATIC_DOMAIN = ""

# LLM
LLM_MODEL_ID = "TheBloke/dolphin-2.1-mistral-7B-GPTQ"  # Hugging Face repo passed to snapshot_download
LLM_FOLDER = "llm"  # local directory holding the model snapshot
EMBEDDING_MODEL = "all-MiniLM-L6-v2"  # SentenceTransformer model used to embed conversation lines
MAX_NEW_TOKENS = 256  # upper bound on stream() iterations per generated reply
LAST_CONTEXT = 11  # number of trailing conversation lines copied verbatim into the prompt
LONG_TERM_CONTEXT = 10  # top_k used for the HyperDB similarity query
APPLY_CHAT_TEMPLATE = True  # wrap prompt pieces in ChatML <|im_start|>/<|im_end|> markers

# RVC
RVC_FOLDER = "rvc"
RVCMD_FOLDER = "rvcmd"
RVC_ASSETS_DOWNLOADER_URL = "https://github.com/RVC-Project/RVC-Models-Downloader/releases/download/v0.2.4/rvcmd_linux_386.tar.gz"

# Each voice has: the zip URL to download, the key used for its folder and in
# rvc_map, and the .pth file name expected inside the unzipped folder.
RVC_MIKU_MODEL_URL = "https://huggingface.co/juuxn/RVCModels/resolve/main/miku222333333.zip"
RVC_MIKU = "miku"
RVC_MIKU_MODEL_NAME = "miku222333333.pth"

# The inner double quotes are part of the value. This URL contains parentheses
# and is interpolated into a `!wget` shell line, so the quotes presumably exist
# to stop the shell from interpreting them — do not "clean up" the quoting.
RVC_HOSHINO_MODEL_URL = '"https://huggingface.co/juuxn/RVCModels/resolve/main/Ai_Hoshino_(From_Oshi_no_Ko)_(RVC_v2)_300_Epoch.zip"'
RVC_HOSHINO = "hoshino"
RVC_HOSHINO_MODEL_NAME = "AiHoshino.pth"

RVC_GLADOS_MODEL_URL = "https://huggingface.co/juuxn/RVCModels/resolve/main/glados2333333.zip"
RVC_GLADOS = "glados"
RVC_GLADOS_MODEL_NAME = "glados2333333.pth"


# Authentication
DELIMITER_DATA_DATE = "<DATA_DATE>"  # separates the payload from the expiry epoch inside a token
# Both keys are 32 ASCII characters and are used as raw UTF-8 bytes, i.e. as
# 256-bit AES keys.
# NOTE(review): these secrets are committed in the notebook; anyone who can read
# it can mint valid auth keys and tokens. Consider loading them from the
# environment instead.
AUTH_KEY_DECRYPTION_KEY = "B9CF1133E770E069695ZX8E6F4F0B9B5"
AUTH_TOKEN_ENCRYPTION_KEY = "A5DE2143E770E069695ZX8E6FDBB7BBE"
ENCRYPTION_ALGORITHM = "AES"  # not referenced by the code visible in this notebook
AUTH_TOKEN_EXPIRY_DAYS = 7
CHARACTER_ENCODING = "utf-8"

# Assets
# On-disk layout: DEVICES_FOLDER/<device id>/ASSISTANTS_FOLDER/<assistant id>/<files below>
ASSISTANTS_FOLDER = "assistants"
DEVICES_FOLDER = "devices"
CONVERSATION_FILE = "conversation.jsonl"  # one JSON message per line
VOICE_CONFIGURATION_FILE = "voice_configuration.jsonl"  # single JSON object: edgeVoice, edgePitch, rvcVoice
HYPER_DB_FILE = "conversation.pickle.gz"  # HyperDB snapshot saved when the index is built
TEMP_FILE = "temp.jsonl"  # scratch file used while rewriting the conversation on edit/delete
PROMPT_FILE = "prompt.txt"  # system/instruction prompt of the assistant
EDGE_TTS_FILE = "edge_tts.wav"  # suffix of the per-message Edge TTS output
RVC_STS_FILE = "rvc_tts.wav"  # suffix of the per-message RVC-converted output

# Application
APP_ID = "ara"  # becomes the URL prefix of every route
DEVICE_ID_KEY = "DeviceID"  # request header carrying the caller's device id
AUTHORIZATION_KEY = "Authorization"  # request header carrying the encrypted token

# SETUP

In [None]:
# Download LLM
# The existence of LLM_FOLDER is the only "already downloaded" check, so an
# interrupted download is not resumed on the next run.
# NOTE(review): delete LLM_FOLDER manually to force a fresh snapshot.
if not os.path.exists(LLM_FOLDER):
    os.makedirs(LLM_FOLDER, exist_ok=True)
    snapshot_download(repo_id=LLM_MODEL_ID, local_dir=LLM_FOLDER, local_dir_use_symlinks=False)

In [None]:
# Setup LLM
# Builds the module-level objects (tokenizer, model, cache, streaming generator,
# sampler settings) that generate_text() reads as globals.
llm_config = ExLlamaV2Config()
llm_config.model_dir = LLM_FOLDER
llm_config.prepare()

ExLlamatokenizer = ExLlamaV2Tokenizer(llm_config)
llm_model = ExLlamaV2(llm_config)
# NOTE(review): [16, 24] is presumably a per-GPU memory split for a two-GPU
# host — confirm against the target hardware.
llm_model.load([16, 24])

llm_cache = ExLlamaV2Cache(llm_model)

llm_generator = ExLlamaV2StreamingGenerator(llm_model, llm_cache, ExLlamatokenizer)
# Replies are generated as the tail of a JSON object, so a closing brace marks
# the end of a message.
llm_generator.set_stop_conditions(['}'])

llm_settings = ExLlamaV2Sampler.Settings()
llm_settings.temperature = 0.85
llm_settings.top_k = 50
llm_settings.top_p = 0.8
llm_settings.token_repetition_penalty = 1.15
# The EOS token is banned from sampling; generation ends on '}' or the token cap instead.
llm_settings.disallow_tokens(ExLlamatokenizer, [ExLlamatokenizer.eos_token_id])

# Not read by generate_text(), which uses MAX_NEW_TOKENS directly.
llm_max_new_tokens = MAX_NEW_TOKENS

In [None]:
# Download RVC Models
# For more: https://docs.google.com/spreadsheets/d/1owfUtQuLW9ReiIwg6U9UkkDmPOTkuNHf0OKQtWu1iaI
os.makedirs(RVC_FOLDER, exist_ok=True)

# Each model is downloaded and unzipped only when its target folder does not
# exist yet. The zip is written next to the folder as <voice>.zip.

# Download Model-1:
if not os.path.exists(f"{RVC_FOLDER}/{RVC_MIKU}"):
    !wget {RVC_MIKU_MODEL_URL} -O {RVC_FOLDER}/{RVC_MIKU}.zip
    !unzip {RVC_FOLDER}/{RVC_MIKU}.zip -d {RVC_FOLDER}/{RVC_MIKU}
    
# Download Model-2:
if not os.path.exists(f"{RVC_FOLDER}/{RVC_HOSHINO}"):
    !wget {RVC_HOSHINO_MODEL_URL} -O {RVC_FOLDER}/{RVC_HOSHINO}.zip
    !unzip {RVC_FOLDER}/{RVC_HOSHINO}.zip -d {RVC_FOLDER}/{RVC_HOSHINO}
    
# Download Model-3:
if not os.path.exists(f"{RVC_FOLDER}/{RVC_GLADOS}"):
    !wget {RVC_GLADOS_MODEL_URL} -O {RVC_FOLDER}/{RVC_GLADOS}.zip
    !unzip {RVC_FOLDER}/{RVC_GLADOS}.zip -d {RVC_FOLDER}/{RVC_GLADOS}

In [None]:
# Download RVC Assets downloader
# Fetches the rvcmd release tarball and unpacks it into RVC_FOLDER/RVCMD_FOLDER;
# skipped when that folder already exists.
if not os.path.exists(f"{RVC_FOLDER}/{RVCMD_FOLDER}"):
    !wget {RVC_ASSETS_DOWNLOADER_URL} -O {RVC_FOLDER}/rvcmd.tar.gz
    os.makedirs(f'{RVC_FOLDER}/{RVCMD_FOLDER}', exist_ok=True)
    !tar -xvf {RVC_FOLDER}/rvcmd.tar.gz -C {RVC_FOLDER}/{RVCMD_FOLDER}

In [None]:
# Remember the starting directory, then work inside RVC_FOLDER so the asset
# download below uses paths relative to it. A later cell changes back to `cwd`.
cwd = os.getcwd()
os.chdir(RVC_FOLDER)

In [None]:
# Download RVC Assets
# Runs with RVC_FOLDER as the working directory (set by the previous cell), so
# "assets" and the rvcmd binary path are both relative to it.
if not os.path.exists("assets"):
    !{RVCMD_FOLDER}/rvcmd assets/all

In [None]:
# Return to the directory saved in `cwd` before entering RVC_FOLDER.
os.chdir(cwd)

In [None]:
%%writefile {RVC_FOLDER}/rvc.env
weight_root='/kaggle/working/rvc/assets'
weight_uvr5_root='/kaggle/working/rvc/assets/uvr5_weights'
index_root='/kaggle/working/rvc/assets'
rmvpe_root='/kaggle/working/rvc/assets/rmvpe'
hubert_path='/kaggle/working/rvc/assets/hubert/hubert_base.pt'
pretrained='/kaggle/working/rvc/assets/pretrained_v2'

In [None]:
# Load env file
# Exports the paths from rvc.env into the process environment.
# NOTE(review): that file hardcodes absolute /kaggle/working/... paths, so this
# presumably only works unchanged on Kaggle — confirm for other hosts.
load_dotenv(f"{RVC_FOLDER}/rvc.env")

In [None]:
# One VC instance per voice; each is pointed at the .pth file inside the folder
# the corresponding zip was extracted to.

# Load Model-1
miku_vc = VC()
miku_vc.get_vc(f"{RVC_FOLDER}/{RVC_MIKU}/{RVC_MIKU_MODEL_NAME}")
    
# Load Model-2:
hoshino_vc = VC()
hoshino_vc.get_vc(f"{RVC_FOLDER}/{RVC_HOSHINO}/{RVC_HOSHINO_MODEL_NAME}")

# Load Model-3:
glados_vc = VC()
glados_vc.get_vc(f"{RVC_FOLDER}/{RVC_GLADOS}/{RVC_GLADOS_MODEL_NAME}")

In [None]:
# Registry of loaded RVC voice converters, keyed by the voice name that
# assistants store in their voice configuration.
rvc_map = {
    RVC_MIKU: miku_vc,
    RVC_HOSHINO: hoshino_vc,
    RVC_GLADOS: glados_vc,
}

In [None]:
voices = await edge_tts.list_voices()
edge_en_voices = list(filter(lambda x: x.startswith("en"), [v["ShortName"] for v in voices]))

In [None]:
# HyperDB
# Shared embedding model plus a cache of HyperDB instances keyed by
# "<device id>_<assistant id>".
embedding_model = SentenceTransformer(EMBEDDING_MODEL)
hyper_db_map = {}

# Fast API
# Every route is registered under "/<APP_ID>".
router_prefix = f"/{APP_ID}"
app = FastAPI()

# DATA TRANSFER OBJECTS

In [None]:
class MessageRequest(BaseModel):
    """Chat message sent by the client for upload, edit, or response generation."""

    # Client-side message id; also used to name the generated audio files.
    id: int
    # Id of the message being replied to. Declared Optional so that an explicit
    # JSON null is accepted as well as an omitted field.
    quotedId: Optional[int] = None
    content: str
    # Epoch time in milliseconds (it is divided by 1000 before conversion).
    timestamp: int
    # Sender role; arrives in JSON under the key "from", which is a Python keyword.
    fromRole: str = Field(alias="from")
    chatId: int
    # Selects the assistant folder the message belongs to.
    assistantId: int
    assistantId: int

class ServerResponse(BaseModel):
    """Uniform envelope returned by the JSON endpoints."""

    isSuccess: bool
    # HTTP-style status describing the outcome, carried inside the body.
    statusCode: int
    # The error fields are None on success, so they are declared Optional to
    # make None a valid value rather than only an unvalidated default.
    errorMessage: Optional[str] = None
    errorMessages: Optional[dict] = None
    # Endpoint-specific result: a string, a dict, or None.
    payload: Any = None
    
class AssistantRequest(BaseModel):
    """Definition of an assistant sent by the client when creating it."""

    # Assistant id; becomes the name of the assistant's folder on disk.
    id: int
    # Display name; an empty file with this name is created in the folder.
    name: str
    # Instruction prompt stored in the assistant's prompt file.
    prompt: str
    # Edge TTS voice name passed to edge_tts.Communicate.
    edgeVoice: str
    # Signed pitch offset; stored as "+NHz" or "-NHz".
    edgePitch: int
    # Key into rvc_map; an empty value skips the RVC conversion step.
    rvcVoice: str

# UTILITY

In [None]:
# Custom exceptions
class InvalidToken(Exception):
    """Signals a token that could not be decrypted or has the wrong layout.

    Attributes:
        message: Short description of the problem.
        errors: Optional underlying exception or extra detail.
    """

    def __init__(self, message, errors=None):
        self.message, self.errors = message, errors

    def __str__(self):
        text = f"Message: [{self.message}] Trace: [{self.errors}]"
        return text

class TokenExpiredException(Exception):
    """Signals a well-formed token whose expiry time has already passed.

    Attributes:
        message: Short description of the problem.
    """

    def __init__(self, message):
        self.message = message

    def __str__(self):
        text = f"Message: [{self.message}]"
        return text

# Authentication
def encrypt(input: str, key: str) -> str:
    """Encrypt `input` with AES-ECB and return URL-safe base64 text.

    Args:
        input: Plain text to encrypt.
        key: AES key as text; its encoded bytes are used directly as the key.

    Returns:
        The PKCS7-padded ciphertext, URL-safe base64 encoded.
    """
    # NOTE(review): ECB leaks equality of plaintext blocks and provides no
    # integrity check. Changing the mode would break existing clients, so it
    # is only flagged here.
    key_bytes = key.encode(CHARACTER_ENCODING)
    aes = Cipher(algorithms.AES(key_bytes), modes.ECB(), backend=default_backend())
    aes_encryptor = aes.encryptor()
    plaintext = pad(input.encode(CHARACTER_ENCODING))
    ciphertext = aes_encryptor.update(plaintext) + aes_encryptor.finalize()
    encoded = base64.urlsafe_b64encode(ciphertext)
    return encoded.decode(CHARACTER_ENCODING)

def decrypt(encrypted_input: str, key: str) -> str:
    """Reverse `encrypt`: base64-decode, AES-ECB decrypt, strip PKCS7 padding.

    Args:
        encrypted_input: URL-safe base64 ciphertext.
        key: AES key as text; its encoded bytes are used directly as the key.

    Returns:
        The decrypted plain text.
    """
    key_bytes = key.encode(CHARACTER_ENCODING)
    aes = Cipher(algorithms.AES(key_bytes), modes.ECB(), backend=default_backend())
    aes_decryptor = aes.decryptor()
    ciphertext = base64.urlsafe_b64decode(encrypted_input)
    padded_plaintext = aes_decryptor.update(ciphertext) + aes_decryptor.finalize()
    plaintext = unpad(padded_plaintext)
    return plaintext.decode(CHARACTER_ENCODING)

def pad(data: bytes) -> bytes:
    """Return `data` with PKCS7 padding for a 128-bit block size."""
    pkcs7_padder = padding.PKCS7(128).padder()
    body = pkcs7_padder.update(data)
    tail = pkcs7_padder.finalize()
    return body + tail

def unpad(data: bytes) -> bytes:
    """Return `data` with its PKCS7 padding (128-bit block size) removed."""
    pkcs7_unpadder = padding.PKCS7(128).unpadder()
    body = pkcs7_unpadder.update(data)
    tail = pkcs7_unpadder.finalize()
    return body + tail

def generate_token_with_expiration(input: str, expiry_time_days: int, key: str) -> str:
    """Build and encrypt a token of the form "<input><delimiter><expiry epoch>".

    Args:
        input: Payload to embed in the token.
        expiry_time_days: Lifetime of the token in days, counted from now.
        key: Encryption key forwarded to `encrypt`.

    Returns:
        The encrypted token text.
    """
    seconds_per_day = 86400
    issued_at = calendar.timegm(time.gmtime())  # current UTC time, whole seconds
    expires_at = issued_at + expiry_time_days * seconds_per_day
    return encrypt(f"{input}{DELIMITER_DATA_DATE}{expires_at}", key)

def decrypt_token_with_expiration(token: str, key: str) -> str:
    """Decrypt a token and return its payload if it has not expired.

    The decrypted text must look like "<payload><DELIMITER_DATA_DATE><expiry>",
    where <expiry> is a UTC epoch time in seconds.

    Args:
        token: Encrypted token text.
        key: Decryption key forwarded to `decrypt`.

    Returns:
        The payload part of the token.

    Raises:
        InvalidToken: The token cannot be decrypted, does not split into exactly
            two parts, or its expiry is not an integer.
        TokenExpiredException: The expiry time lies in the past.
    """
    try:
        decrypted_token = decrypt(token, key)
    except Exception as e:
        print(f"[decrypt_token_with_expiration] Invalid token: {e}")
        raise InvalidToken("Invalid token", e) from e

    token_parts = decrypted_token.split(DELIMITER_DATA_DATE)
    if len(token_parts) != 2:
        print("[decrypt_token_with_expiration] Invalid token format")
        raise InvalidToken("Invalid token format")

    data, raw_expiration = token_parts
    try:
        expiration_time = int(raw_expiration)
    except ValueError as e:
        # A non-numeric expiry is a malformed token, not a server error.
        print("[decrypt_token_with_expiration] Invalid token format")
        raise InvalidToken("Invalid token format", e) from e

    # Compare epoch seconds directly; this avoids building datetime objects,
    # which can fail for out-of-range expiry values.
    if expiration_time < time.time():
        print("[decrypt_token_with_expiration] Token has expired")
        raise TokenExpiredException("Token has expired")
    return data
        
def validate_request(headers: dict):
    """Authenticate a request from its headers.

    The token header is decrypted and its payload must equal the device-id
    header.

    Args:
        headers: Mapping of request headers.

    Returns:
        The device id on success, otherwise a `ServerResponse` describing the
        failure (422 invalid token, 410 expired token, 401 device-id mismatch).
        Callers tell the two apart with `isinstance`.
    """
    device_id = headers.get(DEVICE_ID_KEY)
    raw_token = headers.get(AUTHORIZATION_KEY)

    try:
        token_device_id = decrypt_token_with_expiration(raw_token, AUTH_TOKEN_ENCRYPTION_KEY)
    except InvalidToken as e:
        return ServerResponse(isSuccess=False, statusCode=422, errorMessage=e.message)
    except TokenExpiredException as e:
        return ServerResponse(isSuccess=False, statusCode=410, errorMessage=e.message)

    if device_id == token_device_id:
        return device_id

    print("[validate_request] Invalid token, DeviceID mismatch")
    return ServerResponse(isSuccess=False, statusCode=401, errorMessage="Invalid token, DeviceID mismatch")


def convert_timestamp_to_date_time(timestamp: int) -> str:
    """Format an epoch-milliseconds timestamp as Asia/Kolkata local time.

    Args:
        timestamp: Milliseconds since the Unix epoch.

    Returns:
        The time as "DD-MM-YYYY HH:MM:SS" in the Asia/Kolkata time zone.
    """
    naive_utc = datetime.utcfromtimestamp(timestamp / 1000)
    aware_utc = pytz.utc.localize(naive_utc)
    local_dt = aware_utc.astimezone(pytz.timezone('Asia/Kolkata'))
    return local_dt.strftime("%d-%m-%Y %H:%M:%S")

# File structure
def get_assistant_folder(device_id: str, id: int) -> str:
    """Return the storage folder of an assistant, rejecting unsafe path parts.

    Both values can originate from request data and the result is later passed
    to `shutil.rmtree` and `open(..., 'w')`, so neither may move the path
    outside DEVICES_FOLDER/<device>/ASSISTANTS_FOLDER/<id>.

    Args:
        device_id: Device identifier; must form a single path component.
        id: Assistant id; an int or a string holding an integer.

    Returns:
        DEVICES_FOLDER/<device_id>/ASSISTANTS_FOLDER/<id>.

    Raises:
        ValueError: `device_id` is empty, is "." or "..", or contains a path
            separator or NUL byte; or `id` is not an integer.
    """
    device_component = str(device_id)
    has_separator = any(ch in device_component for ch in ("/", "\\", "\0"))
    if not device_component or device_component in (".", "..") or has_separator:
        raise ValueError(f"Invalid device id: {device_id!r}")

    try:
        # Normalising through int() rules out "..", separators and other
        # non-numeric path tricks.
        assistant_component = str(int(id))
    except (TypeError, ValueError):
        raise ValueError(f"Invalid assistant id: {id!r}") from None

    return os.path.join(DEVICES_FOLDER, device_component, ASSISTANTS_FOLDER, assistant_component)

def create_assistant_folder(device_id: str, assistant_request: AssistantRequest) -> None:
    """Create (or reset) the on-disk folder of an assistant.

    Writes an empty conversation file, the voice configuration, an empty marker
    file named after the assistant, and the prompt file. An existing
    conversation is truncated, so any cached HyperDB for it is dropped too.

    Args:
        device_id: Owner device id.
        assistant_request: Assistant definition received from the client.
    """
    id = assistant_request.id
    prompt = assistant_request.prompt
    sign = "" if assistant_request.edgePitch < 0 else "+"
    edge_rvc_config = {
        "edgeVoice": assistant_request.edgeVoice,
        "edgePitch":f"{sign}{assistant_request.edgePitch}Hz",
        "rvcVoice":assistant_request.rvcVoice
    }

    # The name is client-supplied and only used as a marker file name. Reduce
    # it to a single path component so it cannot point outside the folder, and
    # skip it when it would clobber one of the assistant's own data files.
    marker_name = os.path.basename(str(assistant_request.name).replace("\\", "/"))
    reserved_names = {CONVERSATION_FILE, VOICE_CONFIGURATION_FILE, PROMPT_FILE, HYPER_DB_FILE, TEMP_FILE}
    if marker_name in ("", ".", "..") or "\0" in marker_name or marker_name in reserved_names:
        print(f"[WARN] [create_assistant_folder] Skipping unsafe assistant name marker: {assistant_request.name!r}")
        marker_name = None

    folder_path = get_assistant_folder(device_id, id)
    os.makedirs(folder_path, exist_ok=True)

    # Opening in 'w' mode creates the file or empties an existing one.
    conversation_file = os.path.join(folder_path, CONVERSATION_FILE)
    with open(conversation_file, 'w', encoding=CHARACTER_ENCODING):
        pass
    # The conversation is now empty; a cached index built from the old content
    # would otherwise keep feeding stale context into prompts.
    hyper_db_map.pop(f"{device_id}_{id}", None)

    voice_config_file = os.path.join(folder_path, VOICE_CONFIGURATION_FILE)
    with open(voice_config_file, 'w', encoding=CHARACTER_ENCODING) as f:
        f.write(json.dumps(edge_rvc_config, ensure_ascii=False))

    if marker_name is not None:
        with open(os.path.join(folder_path, marker_name), 'w', encoding=CHARACTER_ENCODING):
            pass

    prompt_file = os.path.join(folder_path, PROMPT_FILE)
    with open(prompt_file, 'w', encoding=CHARACTER_ENCODING) as f:
        f.write(prompt)
        
def delete_assistant_folder(device_id: str, id: int) -> None:
    """Delete an assistant's folder and forget its cached HyperDB.

    Args:
        device_id: Owner device id.
        id: Assistant id.

    Raises:
        OSError: Propagated from `shutil.rmtree`, e.g. when the folder does not
            exist.
    """
    # Drop the cached index first, so that an assistant re-created with the
    # same id does not inherit the deleted conversation's memory.
    hyper_db_map.pop(f"{device_id}_{id}", None)
    folder_path = get_assistant_folder(device_id, id)
    shutil.rmtree(folder_path)
    

# LLM text generation
def get_hyper_db(device_id, assistant_id):
    """Return the HyperDB index of an assistant's conversation, building it once.

    Indexes are cached in `hyper_db_map` under "<device_id>_<assistant_id>". On
    a cache miss the conversation file is read, its non-blank lines become the
    documents, and the new index is saved next to the conversation.

    Args:
        device_id: Owner device id.
        assistant_id: Assistant id.

    Returns:
        The cached or newly built HyperDB instance.
    """
    cache_key = f"{device_id}_{assistant_id}"
    if cache_key in hyper_db_map:
        return hyper_db_map[cache_key]

    assistant_folder = get_assistant_folder(device_id, assistant_id)
    conversation_file = os.path.join(assistant_folder, CONVERSATION_FILE)
    hyper_db_file = os.path.join(assistant_folder, HYPER_DB_FILE)

    # Lines keep their trailing newline; only blank lines are dropped.
    with open(conversation_file, 'r', encoding=CHARACTER_ENCODING) as f:
        documents = [line for line in f if line not in ("\n", '')]

    hyper_db = HyperDB(documents, key="content", embedding_function=embedding_model.encode)
    hyper_db.save(hyper_db_file)
    hyper_db_map[cache_key] = hyper_db
    return hyper_db
    
def get_instruction_prompt(device_id, assistant_id):
    """Read and return the full text of an assistant's prompt file.

    Args:
        device_id: Owner device id.
        assistant_id: Assistant id.

    Returns:
        The prompt file's content as a string.
    """
    prompt_path = os.path.join(get_assistant_folder(device_id, assistant_id), PROMPT_FILE)
    with open(prompt_path, 'r', encoding=CHARACTER_ENCODING) as prompt_fh:
        return prompt_fh.read()
    
def apply_chat_template(input_data):
    """Wrap one stored conversation line in a ChatML turn.

    The line is a JSON object; its "role" value becomes the turn header and the
    raw line itself is the turn body. Empty and newline-only input is returned
    unchanged.
    """
    if input_data in ('', '\n'):
        return input_data
    role = json.loads(input_data)['role']
    return f"<|im_start|>{role}\n" + f"{input_data}<|im_end|>"
        
def apply_chat_template_generate(input_data):
    """Open a ChatML assistant turn that the model is expected to continue.

    No end marker is appended, so generation carries on from `input_data`.
    Empty and newline-only input is returned unchanged.
    """
    if input_data in ('', '\n'):
        return input_data
    return "<|im_start|>assistant\n" + f"{input_data}"

def apply_chat_template_instruction(input_data):
    """Wrap the instruction prompt in a complete ChatML system turn.

    Empty and newline-only input is returned unchanged.
    """
    if input_data in ('', '\n'):
        return input_data
    return "<|im_start|>system\n" + f"{input_data}<|im_end|>"

def generate_text(prompt):
    """Stream a completion for `prompt` and return the concatenated text.

    Each chunk is echoed to stdout as it arrives. Streaming stops when the
    generator reports end of stream, when a chunk ends with '}', or after
    MAX_NEW_TOKENS calls to stream().

    Args:
        prompt: Full prompt text.

    Returns:
        All streamed chunks joined together.
    """
    input_ids = ExLlamatokenizer.encode(prompt)
    sys.stdout.flush()

    llm_generator.begin_stream(input_ids, llm_settings)

    print("\nOutput: ", end = "")
    pieces = []
    steps = 0
    finished = False
    while not finished:
        chunk, eos, _ = llm_generator.stream()
        steps += 1
        pieces.append(chunk)
        print (chunk, end = "")
        sys.stdout.flush()
        finished = eos or chunk.endswith('}') or steps == MAX_NEW_TOKENS
    print()
    return "".join(pieces)

def remove_brackets(text):
    """Strip "[...]" spans and then "<...>" spans from `text`.

    Both matches are non-greedy and do not cross line breaks. Square-bracket
    spans are removed first, then angle-bracket spans from what remains.
    """
    for bracket_pattern in (r'\[.*?\]', r'<.*?>'):
        text = re.sub(bracket_pattern, '', text)
    return text

# Edge TTS Generation
async def generate_tts(device_id, assistant_id, message_id, text, voice="en-US-AriaNeural", pitch="+0Hz", rate="+0%"):
    """Synthesize `text` with Edge TTS into the assistant's folder.

    Args:
        device_id: Owner device id.
        assistant_id: Assistant id.
        message_id: Message id used as the file name prefix.
        text: Text to speak.
        voice: Edge TTS voice name.
        pitch: Pitch offset string such as "+0Hz".
        rate: Rate offset string such as "+0%".

    Returns:
        Path of the written audio file, "<message_id>_<EDGE_TTS_FILE>".
    """
    # NOTE(review): the file name ends in .wav, but the container Edge TTS
    # actually writes is not established here — confirm if the format matters.
    assistant_folder = get_assistant_folder(device_id, assistant_id)
    tts_file = os.path.join(assistant_folder, f"{message_id}_{EDGE_TTS_FILE}")
    await edge_tts.Communicate(text, voice=voice, rate=rate, pitch=pitch).save(tts_file)
    return tts_file
    
# RVC STS Generation
def generate_sts(device_id, assistant_id, message_id, voice=RVC_MIKU):
    """Convert a message's Edge TTS audio with an RVC voice model.

    Reads "<message_id>_<EDGE_TTS_FILE>" from the assistant's folder and writes
    the converted audio to "<message_id>_<RVC_STS_FILE>" in the same folder.

    Args:
        device_id: Owner device id.
        assistant_id: Assistant id.
        message_id: Message id used as the file name prefix.
        voice: Key into `rvc_map` selecting the voice converter.

    Returns:
        Path of the converted audio file.
    """
    assistant_folder = get_assistant_folder(device_id, assistant_id)
    tts_file = os.path.join(assistant_folder, f"{message_id}_{EDGE_TTS_FILE}")
    sts_file = os.path.join(assistant_folder, f"{message_id}_{RVC_STS_FILE}")
    print(tts_file, sts_file)
    converter = rvc_map[voice]
    sample_rate, converted_audio, _timings, _ = converter.vc_single(1, Path(tts_file))
    wavfile.write(sts_file, sample_rate, converted_audio)
    return sts_file

# API ENDPOINTS

In [None]:
@app.get(f"{router_prefix}/token", response_model=ServerResponse)
async def get_auth_token(request: Request):
    """Exchange an encrypted auth key for a time-limited auth token.

    The `auth_key` header is decrypted with AUTH_KEY_DECRYPTION_KEY; the text
    before DELIMITER_DATA_DATE is taken as the device id and embedded in a new
    token valid for AUTH_TOKEN_EXPIRY_DAYS days.

    Returns:
        ServerResponse with the token as payload (200), or an error response:
        400 when the header is missing, 422 when it cannot be decrypted, 500
        when token generation fails.
    """
    auth_key = request.headers.get('auth_key')
    if not auth_key:
        print("[ERROR] [get_auth_token] Authentication key is missing")
        return ServerResponse(isSuccess=False, statusCode=400, errorMessage="Missing authentication key")

    try:
        device_id = decrypt(auth_key, AUTH_KEY_DECRYPTION_KEY).split(DELIMITER_DATA_DATE)[0]
        print(f"[INFO] [get_auth_token] DeviceID: {device_id}")
    except Exception as e:
        print(f"[ERROR] [get_auth_token] Invalid key format: {e}")
        return ServerResponse(isSuccess=False, statusCode=422, errorMessage="Invalid key format")

    try:
        token = generate_token_with_expiration(device_id, AUTH_TOKEN_EXPIRY_DAYS, AUTH_TOKEN_ENCRYPTION_KEY)
        # The token is a bearer credential, so its value is deliberately kept
        # out of the notebook output.
        print("[INFO] [get_auth_token] Token issued")
        return ServerResponse(isSuccess=True, statusCode=200, payload=token)
    except Exception as e:
        print(f"[ERROR] [get_auth_token] Error generating token: {e}")
        return ServerResponse(isSuccess=False, statusCode=500, errorMessage="[get_auth_token] Error generating token")

@app.get(f"{router_prefix}/status", response_model=ServerResponse)
async def get_status(request: Request):
    """Unauthenticated liveness probe; always reports success."""
    alive = ServerResponse(isSuccess=True, statusCode=200, payload="[get_status] ARA Server is up and running")
    return alive
    
    
@app.post(f"{router_prefix}/assistant/create", response_model=ServerResponse)
async def create_assistant(request: Request, assistant_request: AssistantRequest):
    """Create the on-disk folder for a new assistant of the calling device.

    Returns:
        The authentication failure response if validation fails, a 200 response
        on success, or a 500 response if creating the folder raises.
    """
    auth_result = validate_request(request.headers)
    if isinstance(auth_result, ServerResponse):
        return auth_result

    try:
        create_assistant_folder(auth_result, assistant_request)
    except Exception as e:
        print(f"[ERROR] [create_assistant] Error creating assistant: {e}")
        return ServerResponse(isSuccess=False, statusCode=500, errorMessage="Error creating assistant")
    return ServerResponse(isSuccess=True, statusCode=200, payload="Assistant created successfully")
    
    
@app.delete(f"{router_prefix}/assistant/delete/{{assistantId}}", response_model=ServerResponse)
async def delete_assistant(request: Request, assistantId):
    """Delete the folder of one assistant belonging to the calling device.

    Returns:
        The authentication failure response if validation fails, a 200 response
        on success, or a 500 response if deleting the folder raises.
    """
    auth_result = validate_request(request.headers)
    if isinstance(auth_result, ServerResponse):
        return auth_result

    try:
        delete_assistant_folder(auth_result, assistantId)
    except Exception as e:
        print(f"[ERROR] [delete_assistant] Error deleting assistant: {e}")
        return ServerResponse(isSuccess=False, statusCode=500, errorMessage="Error deleting assistant")
    return ServerResponse(isSuccess=True, statusCode=200, payload="Assistant deleted successfully")
    
    
@app.delete(f"{router_prefix}/message/delete/{{assistantId}}/{{messageId}}", response_model=ServerResponse)
async def delete_message(request: Request, assistantId, messageId):
    """Remove one message from an assistant's conversation file.

    The conversation is copied line by line into a temp file, leaving out the
    line whose "id" equals `messageId`, and the temp file then replaces the
    original. Blank lines are dropped; lines that are not valid JSON are kept.
    When a line is removed, the assistant's cached HyperDB is discarded so it
    gets rebuilt on next use.

    Returns:
        The authentication failure response, a 200 response when a message was
        removed, a 404 response when no line matched, or a 500 response on any
        other error.
    """
    auth_result = validate_request(request.headers)
    if isinstance(auth_result, ServerResponse):
        return auth_result
    device_id = auth_result

    try:
        deleted = False
        assistant_folder = get_assistant_folder(device_id, assistantId)
        conversation_file = os.path.join(assistant_folder, CONVERSATION_FILE)
        temp_file = os.path.join(assistant_folder, TEMP_FILE)

        with open(conversation_file, 'r', encoding=CHARACTER_ENCODING) as source, \
                open(temp_file, 'w', encoding=CHARACTER_ENCODING) as rewritten:
            for raw_line in source:
                if raw_line in ("\n", ''):
                    continue
                try:
                    message = json.loads(raw_line)
                except json.JSONDecodeError:
                    print(f"[WARN] Error decoding line: {raw_line}")
                    rewritten.write(raw_line)
                    continue

                if int(message.get("id")) == int(messageId):
                    deleted = True
                    hyper_db_map.pop(f"{device_id}_{assistantId}", -1)
                else:
                    rewritten.write(raw_line)

        os.replace(temp_file, conversation_file)

        if not deleted:
            return ServerResponse(isSuccess=False, statusCode=404, payload="Message not found")
        return ServerResponse(isSuccess=True, statusCode=200, payload="Message deleted successfully")
    except Exception as e:
        print(f"[ERROR] [delete_message] Error deleting message: {e}")
        return ServerResponse(isSuccess=False, statusCode=500, errorMessage="Error deleting message")
    
@app.put(f"{router_prefix}/message/edit", response_model=ServerResponse)
async def edit_message(request: Request, message_request: MessageRequest):
    """Replace the content of one stored message.

    The conversation is copied line by line into a temp file; the line whose
    "id" equals the request's id is rewritten with the new content, and the
    temp file then replaces the original. Blank lines are dropped; lines that
    are not valid JSON are kept. When a line is edited, the assistant's cached
    HyperDB is discarded so it gets rebuilt on next use.

    Returns:
        The authentication failure response, a 200 response when a message was
        edited, a 404 response when no line matched, or a 500 response on any
        other error.
    """
    auth_result = validate_request(request.headers)
    if isinstance(auth_result, ServerResponse):
        return auth_result
    device_id = auth_result

    try:
        edited = False
        assistant_folder = get_assistant_folder(device_id, message_request.assistantId)
        conversation_file = os.path.join(assistant_folder, CONVERSATION_FILE)
        temp_file = os.path.join(assistant_folder, TEMP_FILE)

        with open(conversation_file, 'r', encoding=CHARACTER_ENCODING) as source, \
                open(temp_file, 'w', encoding=CHARACTER_ENCODING) as rewritten:
            for raw_line in source:
                if raw_line in ("\n", ''):
                    continue
                try:
                    message = json.loads(raw_line)
                except json.JSONDecodeError:
                    print(f"[WARN] Error decoding line: {raw_line}")
                    rewritten.write(raw_line)
                    continue

                if int(message.get("id")) == int(message_request.id):
                    edited = True
                    message["content"] = message_request.content
                    rewritten.write(json.dumps(message, ensure_ascii=False) + "\n")
                    hyper_db_map.pop(f"{device_id}_{message_request.assistantId}", -1)
                else:
                    rewritten.write(raw_line)

        os.replace(temp_file, conversation_file)

        if not edited:
            return ServerResponse(isSuccess=False, statusCode=404, payload="Message not found")
        return ServerResponse(isSuccess=True, statusCode=200, payload="Message edited successfully")
    except Exception as e:
        print(f"[ERROR] [edit_message] Error editing message: {e}")
        return ServerResponse(isSuccess=False, statusCode=500, errorMessage="Error editing message")
    
    
@app.post(f"{router_prefix}/message/upload", response_model=ServerResponse)
async def upload_message(request: Request, message_request: MessageRequest):
    """Append a client message to the conversation file and to the HyperDB index.

    The message is stored as one JSON line with the keys id, role (lower-cased),
    time (formatted local time), content and, when a quoted id is given,
    replyto.

    Returns:
        The authentication failure response, a 201 response on success, or a
        500 response on any error.
    """
    auth_result = validate_request(request.headers)
    if isinstance(auth_result, ServerResponse):
        return auth_result
    device_id = auth_result

    try:
        assistant_id = message_request.assistantId
        hyper_db = get_hyper_db(device_id, assistant_id)

        message_data = dict(
            id=message_request.id,
            role=message_request.fromRole.lower(),
            time=convert_timestamp_to_date_time(message_request.timestamp),
            content=message_request.content,
        )
        if message_request.quotedId is not None:
            message_data["replyto"] = message_request.quotedId

        print(f"[INFO] [upload_message] YOU: {str(message_data)}")
        serialized = json.dumps(message_data, ensure_ascii=False)

        conversation_file = os.path.join(get_assistant_folder(device_id, assistant_id), CONVERSATION_FILE)
        # Records are separated by a leading newline rather than a trailing one.
        with open(conversation_file, 'a', encoding=CHARACTER_ENCODING) as f:
            f.write("\n" + serialized)
        hyper_db.add_document(serialized)

        return ServerResponse(isSuccess=True, statusCode=201, payload="Message uploaded successfully")
    except Exception as e:
        print(f"[ERROR] [upload_message] Error uploading message: {e}")
        return ServerResponse(isSuccess=False, statusCode=500, errorMessage="Error uploading message")
    
    
@app.post(f"{router_prefix}/assistant/response", response_model=ServerResponse)
async def assistant_response(request: Request, message_request: MessageRequest):
    """Generate the assistant's next message, synthesize its audio, and store it.

    Steps:
      1. Build the prompt from the instruction prompt, the HyperDB hits for the
         latest stored message, the last LAST_CONTEXT conversation lines, and
         an unfinished JSON message that the model has to complete.
      2. Generate the completion with the LLM.
      3. Best effort: speak the reply with Edge TTS and, when an RVC voice is
         configured, convert it with RVC. Audio errors are logged and ignored.
      4. Reassemble the JSON message, append it to the conversation file and
         the HyperDB index, and return it to the client.

    Returns:
        The authentication failure response; a 200 response whose payload holds
        id, from, content, quotedId and timestamp; a 500 response carrying the
        raw text when the generated message is not valid JSON; or a generic 500
        response on any other error.
    """
    validated_data = validate_request(request.headers)
    if isinstance(validated_data, ServerResponse):
        return validated_data
    device_id = validated_data
    
    try:
        hyper_db = get_hyper_db(device_id, message_request.assistantId)
        conversation_file = os.path.join(get_assistant_folder(device_id, message_request.assistantId), CONVERSATION_FILE)
        lines = []
        with open(conversation_file, 'r', encoding=CHARACTER_ENCODING) as f:
            lines = f.read().splitlines()
            
        # Short-term memory: the most recent lines of the conversation, verbatim.
        last_context = lines[-LAST_CONTEXT:] if len(lines) > 0 else []
        
        # Long-term memory: stored lines most similar to the latest message's content.
        results = hyper_db.query(json.loads(last_context[-1])["content"], top_k=LONG_TERM_CONTEXT) if len(lines) > 0 else []
        related_context = [doc.rstrip('\n') for doc, _ in results]
        
        # Avoid repeating lines that are already part of the short-term context.
        related_context = [line for line in related_context if line not in last_context]
                
        timestamp = time.time() * 1000
        date_time = convert_timestamp_to_date_time(timestamp)
        message_data = {
            "id": message_request.id,
            "role": message_request.fromRole.lower(),
            "time": date_time,
            "content": ""
        }
        # Dropping the last two characters leaves the JSON open right after the
        # opening quote of "content", so the model writes the content itself.
        new_message = [json.dumps(message_data, ensure_ascii=False)[:-2]]  # Leave out the "}
        
        instruction_prompt = get_instruction_prompt(device_id, message_request.assistantId)
        combined_inputs = related_context + last_context + new_message
        
        print(f"[INFO] [assistant_response] Related context: {related_context}")
        print(f"[INFO] [assistant_response] Last context: {last_context}")
        print(f"[INFO] [assistant_response] New context: {new_message}")
        
        if APPLY_CHAT_TEMPLATE:
            # Stored lines become complete ChatML turns; the unfinished message
            # becomes an open assistant turn for the model to continue.
            combined_inputs = [apply_chat_template(input_data) for input_data in (related_context + last_context)] \
                            + [apply_chat_template_generate(new_message[0])]
            instruction_prompt = apply_chat_template_instruction(instruction_prompt)
            
        prompt = instruction_prompt + '\n' + '\n'.join(combined_inputs)
        print(f"[INFO] [assistant_response] Final prompt:\n{prompt}")
        
        message_generated = generate_text(prompt)
        print(f"[INFO] [assistant_response] Generated: {message_generated}")
        
        # Audio is best effort: a failure here is logged and the text reply is
        # still stored and returned.
        try:
            voice_config_file = os.path.join(get_assistant_folder(device_id, message_request.assistantId), VOICE_CONFIGURATION_FILE)
            voice_config = "{}"
            with open(voice_config_file, 'r', encoding=CHARACTER_ENCODING) as f:
                voice_config = f.readline()
            edge_rvc_config = json.loads(voice_config)
            
            # Speak only the text before the first double quote, with bracketed
            # spans removed and line breaks flattened to spaces.
            cleaned_generation = remove_brackets(message_generated).split('"')[0].replace("\n", " ").replace("/n", " ")
            print(f"[INFO] [assistant_response] Cleaned: {cleaned_generation}")
            audio_generated = await generate_tts(device_id, message_request.assistantId, message_request.id,
                 text=cleaned_generation, voice=edge_rvc_config["edgeVoice"], pitch=edge_rvc_config["edgePitch"])

            print(f"[INFO] [assistant_response] Edge TTS Completed")
            # RVC conversion is skipped when no RVC voice is configured.
            if edge_rvc_config.get("rvcVoice") != None and edge_rvc_config.get("rvcVoice") != "":
                audio_generated = generate_sts(device_id, message_request.assistantId,
                                               message_request.id, edge_rvc_config["rvcVoice"])
            print(f"[INFO] [assistant_response] RVC STS Completed")
            
        except Exception as e:
            print(f"[ERROR] [assistant_response] TTS Error : {e}")
        
        # Rebuild the full JSON message: opened prefix + generated tail + closing brace.
        message_generated = new_message[0] + message_generated + "}"
        message_generated_json = {}
        try:
            message_generated_json = json.loads(message_generated)
        except Exception as e:
            print(f"[ERROR] [assistant_response] JSON Error : {e}")
            return ServerResponse(isSuccess=False, statusCode=500, payload=message_generated,
                                  errorMessage="Generated response is not in JSON")
        
        # Persist only messages that parsed as JSON.
        with open(conversation_file, 'a', encoding=CHARACTER_ENCODING) as f:
            f.write("\n" + message_generated)
        hyper_db.add_document(message_generated)

        message_response = {
            "id": message_request.id,
            "from": message_request.fromRole,
            "content": message_generated_json["content"],
            "quotedId": message_generated_json.get("replyto"),
            "timestamp": round(timestamp)
        }
        return ServerResponse(isSuccess=True, statusCode=200, payload=message_response)
    except Exception as e:
        print(f"[ERROR] [assistant_response] Error generating assistant response: {e}")
        return ServerResponse(isSuccess=False, statusCode=500, errorMessage="Error generating assistant response")
    

@app.get(f"{router_prefix}/assistant/voice-models", response_model=ServerResponse)
async def assistant_voice_models(request: Request):
    """List the voices a client may choose from.

    Returns:
        The authentication failure response, or a 200 response whose payload
        holds "edgeVoiceModels" (English Edge TTS voice names) and
        "rvcVoiceModels" (keys of `rvc_map`).
    """
    auth_result = validate_request(request.headers)
    if isinstance(auth_result, ServerResponse):
        return auth_result

    return ServerResponse(
        isSuccess=True,
        statusCode=200,
        payload={
            "edgeVoiceModels": edge_en_voices,
            "rvcVoiceModels": list(rvc_map.keys()),
        },
    )
    
    
@app.get(f"{router_prefix}/assistant/response_audio/{{assistantId}}/{{messageId}}")
async def assistant_response_audio(request: Request, assistantId, messageId):
    """Serve the audio generated for a message.

    The RVC-converted file is preferred; the plain Edge TTS file is the
    fallback.

    Returns:
        The authentication failure response, the audio as a FileResponse, a 404
        ServerResponse when neither file exists, or a 500 ServerResponse on any
        other error.
    """
    validated_data = validate_request(request.headers)
    if isinstance(validated_data, ServerResponse):
        return validated_data
    device_id = validated_data

    try:
        assistant_folder = get_assistant_folder(device_id, assistantId)
        sts_file = os.path.join(assistant_folder, f"{messageId}_{RVC_STS_FILE}")
        tts_file = os.path.join(assistant_folder, f"{messageId}_{EDGE_TTS_FILE}")

        # Check for the file up front: FileResponse only fails once the
        # response is being sent, which surfaces as an unhandled server error.
        audio_file_path = next((path for path in (sts_file, tts_file) if os.path.isfile(path)), None)
        if audio_file_path is None:
            print(f"[WARN] [assistant_response_audio] No audio found for message: {messageId}")
            return ServerResponse(isSuccess=False, statusCode=404, errorMessage="Audio not found")

        return FileResponse(audio_file_path, media_type="audio/wav", filename=f"{messageId}.wav")
    except Exception as e:
        print(f"[ERROR] [assistant_response_audio] Error fetching audio: {e}")
        return ServerResponse(isSuccess=False, statusCode=500, errorMessage="Error fetching audio")

# THREADS

In [None]:
stop_event = threading.Event()

def fastAPIThread(stop_event):
    """Run the FastAPI app under uvicorn on a private event loop until stopped.

    Meant to be used as a `threading.Thread` target. The server listens on
    0.0.0.0:5003, the same port the ngrok tunnel forwards to.

    Args:
        stop_event: `threading.Event`; once set, the server is asked to exit
            and this function returns after it has shut down.
    """
    config = uvicorn.Config(app, host="0.0.0.0", port=5003)
    server = uvicorn.Server(config)

    # A dedicated loop for this thread, separate from the notebook's own loop.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    server_task = loop.create_task(server.serve())

    try:
        # Drive the loop in one-second slices: the server task makes progress
        # while the sleep runs, and stop_event is polled between slices.
        while not stop_event.is_set():
            loop.run_until_complete(asyncio.sleep(1))
    except KeyboardInterrupt:
        pass
    finally:
        # Ask uvicorn to shut down, then wait for serve() to finish before
        # closing the loop.
        server.should_exit = True
        loop.run_until_complete(server_task)
        loop.close()

In [None]:
ngrok_process = None

def run_ngrok():
    """Start an ngrok HTTP tunnel to port 5003 and echo its output until it exits.

    The child process is stored in the module-level `ngrok_process` so that
    `stop_ngrok` can terminate it. `--domain` is passed only when
    NGROK_STATIC_DOMAIN is non-empty. Blocks until the process closes its
    output, so it is meant to run in a background thread.
    """
    global ngrok_process
    command = ["ngrok", "http"]
    if NGROK_STATIC_DOMAIN:
        command.append(f"--domain={NGROK_STATIC_DOMAIN}")
    command.append("5003")

    # stderr is merged into stdout: a separate pipe that nobody reads would
    # hide ngrok's error messages and can stall the child once it fills up.
    ngrok_process = subprocess.Popen(command, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)
    for line in ngrok_process.stdout:
        print(line.strip())
    
def stop_ngrok():
    """Terminate the ngrok child process, if one has been started."""
    if not ngrok_process:
        return
    ngrok_process.terminate()

In [None]:
# Register the ngrok auth token with the local ngrok configuration.
# NGROK_AUTHTOKEN is blank in the constants cell and must be set first.
!ngrok config add-authtoken {NGROK_AUTHTOKEN}

# EXECUTE

In [None]:
# Run the API server in a daemon thread so the notebook stays responsive; the
# thread ends when stop_event is set.
server_thread = threading.Thread(target=fastAPIThread, args=(stop_event,), daemon=True)
server_thread.start()

In [None]:
# Run the ngrok tunnel in a daemon thread; run_ngrok blocks while it echoes the
# tunnel's output.
ngrok_thread = threading.Thread(target=run_ngrok, daemon=True)
ngrok_thread.start()

In [None]:
# Infinite loop to prevent idle timeout
# This cell blocks until the kernel is interrupted; the server and tunnel keep
# running in their background threads meanwhile. Interrupt it to move on to the
# cleanup cells.
print('--LOGS--')
try:
    while True:
        # Sleep for 5 minutes
        time.sleep(300)
except KeyboardInterrupt:
    print("Interrupted! Stopping the loop.")

# CLEANUP

In [None]:
# Signal the server thread to shut down and wait until it has exited.
stop_event.set()
server_thread.join()

In [None]:
# Terminate ngrok, then wait for the thread that echoes its output to finish.
stop_ngrok()
ngrok_thread.join()