# Chat Inference API on Google Colab

This notebook demonstrates how to set up and use the chat inference API on Google Colab. It includes both the server and client code.

## Setup

First, let's install the required dependencies and set up the environment.

In [None]:
!pip install flask pyyaml torch transformers

# Clone your repository (replace with your actual repo URL)
!git clone https://github.com/your-username/your-repo.git
!cd your-repo

# Download your model checkpoint (replace with actual download command)
!wget https://your-model-checkpoint-url.com/chat.pt -O chat.pt

## Server Code

Now, let's define the server code for our chat inference API.

In [None]:
import os
import yaml
import torch
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, set_seed
from src.factory import create_model_and_transforms
from data import AudioTextDataProcessor

# Flask application object; the /chat route further down is registered on it.
app = Flask(__name__)

# Global variables
# Placeholders for the shared inference components. They are rebound to real
# objects by the initialization code later in this cell and are read by the
# /chat request handler.
model = None
text_tokenizer = None
DataProcessor = None

def load_config(config_file):
    """Read a YAML configuration file and return its parsed contents.

    Args:
        config_file: Path of the YAML file to read.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file (``None`` for an
        empty document).

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the content is not valid YAML.
    """
    # Pin the encoding so parsing does not depend on the platform's default
    # locale encoding; YAML configs are conventionally UTF-8.
    with open(config_file, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)

def prepare_tokenizer(model_config):
    """Build the text tokenizer and register the special tokens the chat model uses.

    Args:
        model_config: Mapping that provides ``'tokenizer_path'`` and
            ``'cache_dir'``.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained`` with
        ``<audio>`` and ``<|endofchunk|>`` added as additional special tokens,
        and with pad / sep tokens added when the tokenizer had none.
    """
    tok = AutoTokenizer.from_pretrained(
        model_config['tokenizer_path'],
        local_files_only=False,
        trust_remote_code=True,
        cache_dir=model_config['cache_dir'],
    )

    # Markers added first, so the order in which tokens are appended to the
    # vocabulary stays: markers, then pad, then sep.
    tok.add_special_tokens(
        {"additional_special_tokens": ["<audio>", "<|endofchunk|>"]}
    )

    # Only fill in pad / sep when the tokenizer does not already define them.
    for attr_name, fallback in (("pad_token", "<PAD>"), ("sep_token", "<SEP>")):
        if getattr(tok, attr_name) is None:
            tok.add_special_tokens({attr_name: fallback})

    return tok

def prepare_model(model_config, clap_config, checkpoint_path, device_id=0):
    """Construct the model, move it to the target device and load checkpoint weights.

    Args:
        model_config: Keyword arguments forwarded to
            ``create_model_and_transforms``.
        clap_config: Passed through as the ``clap_config`` argument of
            ``create_model_and_transforms``.
        checkpoint_path: Path of a checkpoint readable by ``torch.load`` that
            contains a ``"model_state_dict"`` entry.
        device_id: Device passed to ``model.to`` (defaults to ``0``).

    Returns:
        The model in eval mode, on ``device_id``, with the checkpoint applied.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "false"

    fixed_options = dict(
        clap_config=clap_config,
        use_local_files=False,
        gradient_checkpointing=False,
        freeze_lm_embeddings=False,
    )
    model, _ = create_model_and_transforms(**model_config, **fixed_options)
    model.eval()
    model = model.to(device_id)

    # NOTE(review): torch.load unpickles the file; only use checkpoints from a
    # trusted source.
    saved_weights = torch.load(checkpoint_path, map_location="cpu")["model_state_dict"]

    # Drop every "module." fragment from the parameter names before loading.
    renamed_weights = {}
    for param_name, tensor in saved_weights.items():
        renamed_weights[param_name.replace("module.", "")] = tensor

    # Second positional argument is `strict`: mismatching keys are tolerated.
    model.load_state_dict(renamed_weights, False)

    return model

def inference(model, tokenizer, item, processed_item, inference_kwargs, device_id=0):
    """Generate one reply for a processed audio/dialogue item.

    Args:
        model: Object exposing ``generate(audio_x=..., audio_x_mask=...,
            lang_x=..., **kwargs)``.
        tokenizer: Tokenizer providing ``decode``, ``eos_token_id``,
            ``eos_token``, ``pad_token`` and ``sep_token``.
        item: The raw request item. Not used here; kept so existing callers
            keep working.
        processed_item: 5-tuple ``(filename, audio_clips, audio_embed_mask,
            input_ids, attention_mask)``. Only the three tensors in the middle
            are used.
        inference_kwargs: Extra keyword arguments forwarded to
            ``model.generate``. They take precedence over the defaults set
            here (``eos_token_id`` and ``max_new_tokens=128``). ``None`` is
            treated as "no extra arguments".
        device_id: Device the input tensors are moved to (defaults to ``0``).

    Returns:
        The decoded text of the first generated sequence: the part after the
        last separator token, with the EOS, pad and ``<|endofchunk|>`` markers
        removed.
    """
    _, audio_clips, audio_embed_mask, input_ids, _ = processed_item
    audio_clips = audio_clips.to(device_id, dtype=None, non_blocking=True)
    audio_embed_mask = audio_embed_mask.to(device_id, dtype=None, non_blocking=True)
    input_ids = input_ids.to(device_id, dtype=None, non_blocking=True).squeeze()

    # Merge instead of passing both explicit keywords and **inference_kwargs:
    # a caller supplying e.g. max_new_tokens would otherwise trigger
    # "got multiple values for keyword argument".
    generate_kwargs = {
        "eos_token_id": tokenizer.eos_token_id,
        "max_new_tokens": 128,
    }
    generate_kwargs.update(inference_kwargs or {})

    # Each input gets a leading batch dimension of size 1.
    outputs = model.generate(
        audio_x=audio_clips.unsqueeze(0),
        audio_x_mask=audio_embed_mask.unsqueeze(0),
        lang_x=input_ids.unsqueeze(0),
        **generate_kwargs,
    )

    # Only the first sequence is returned, so only that one is decoded.
    decoded = tokenizer.decode(outputs[0])
    # Keep what follows the last separator, i.e. drop the echoed prompt.
    reply = decoded.split(tokenizer.sep_token)[-1]
    for marker in (tokenizer.eos_token, tokenizer.pad_token, '<|endofchunk|>'):
        reply = reply.replace(marker, '')

    return reply

@app.route('/chat', methods=['POST'])
def chat():
    """Handle ``POST /chat``: answer the latest dialogue turn about an audio file.

    Expects a JSON object with:
        audio_file: Non-empty string naming the audio file to discuss.
        dialogue: Non-empty list of dialogue turns.

    Returns:
        ``{"response": <text>}`` on success, or ``({"error": <message>}, 400)``
        when the body is not a JSON object or a field is missing or has the
        wrong type.
    """
    # silent=True yields None for a missing or unparsable body, so every
    # malformed request ends up in the same JSON 400 reply below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    audio_file = data.get('audio_file')
    dialogue = data.get('dialogue', [])

    if not audio_file or not dialogue:
        return jsonify({"error": "Missing audio_file or dialogue"}), 400
    if not isinstance(audio_file, str) or not isinstance(dialogue, list):
        return jsonify({"error": "audio_file must be a string and dialogue must be a list"}), 400

    # NOTE(review): audio_file comes straight from the client and is handed to
    # DataProcessor as a file name; confirm it cannot point outside the
    # processor's data root (e.g. via ".." or an absolute path).
    item = {
        'name': audio_file,
        'prefix': "The task is dialog.",
        'dialogue': dialogue
    }

    processed_item = DataProcessor.process(item)

    inference_kwargs = {
        "do_sample": True,
        "top_k": 50,
        "top_p": 0.95,
        "num_return_sequences": 1
    }

    response = inference(model, text_tokenizer, item, processed_item, inference_kwargs)

    return jsonify({"response": response})

# Initialize the model and other components
# All paths below are relative to the notebook's current working directory.
config = load_config('configs/chat.yaml')
clap_config = config['clap_config']
model_config = config['model_config']

# Fix the random seed before the tokenizer and model are built.
set_seed(0)
text_tokenizer = prepare_tokenizer(model_config)
# device_id is left at prepare_model's default (0).
model = prepare_model(
    model_config=model_config, 
    clap_config=clap_config, 
    checkpoint_path="chat.pt"
)

# Rebinds the module-level DataProcessor placeholder to a processor instance;
# the /chat handler calls DataProcessor.process(item) on it.
# NOTE(review): presumably request audio_file values are resolved relative to
# data_root - verify against AudioTextDataProcessor.
DataProcessor = AudioTextDataProcessor(
    data_root='model_ckpts/datasets',
    clap_config=clap_config,
    tokenizer=text_tokenizer,
    max_tokens=512,
)

print("Server initialized and ready to accept requests.")

## Start the Server

Now, let's start the Flask server in the background.

In [None]:
import socket
import time
from threading import Thread

SERVER_HOST = "127.0.0.1"
SERVER_PORT = 5000
SERVER_STARTUP_TIMEOUT_S = 30


def run_app():
    """Serve the Flask app; blocks, so it is meant to run in a background thread."""
    # The reloader only works from the main thread, so keep it (and debug
    # mode) explicitly off when serving from a worker thread.
    app.run(port=SERVER_PORT, debug=False, use_reloader=False)


def _server_is_up(timeout=0.5):
    """Return True if something accepts TCP connections on the server port."""
    try:
        with socket.create_connection((SERVER_HOST, SERVER_PORT), timeout=timeout):
            return True
    except OSError:
        return False


if _server_is_up():
    # Re-running this cell must not start a second server on the same port.
    print(f"Port {SERVER_PORT} is already accepting connections; not starting another server.")
else:
    # daemon=True: the server thread must not keep the process alive on exit.
    thread = Thread(target=run_app, daemon=True)
    thread.start()

    # Wait until the port is open so the next cell ("Run All") does not race
    # the server start-up.
    deadline = time.monotonic() + SERVER_STARTUP_TIMEOUT_S
    while not _server_is_up() and time.monotonic() < deadline:
        time.sleep(0.2)

    if _server_is_up():
        print("Server is running in the background.")
    else:
        print(f"Server did not start listening on port {SERVER_PORT} within {SERVER_STARTUP_TIMEOUT_S} seconds.")

## Client Code

Now that our server is running, let's create a client to interact with it.

In [None]:
import requests

def chat_with_audio(audio_file, dialogue, url="http://localhost:5000/chat", timeout=300):
    """Send an audio file name and a dialogue to the chat API and return its reply.

    Args:
        audio_file: Name of the audio file, sent as ``"audio_file"``.
        dialogue: List of dialogue turns, sent as ``"dialogue"``.
        url: Endpoint to POST to. Defaults to the local server's ``/chat``.
        timeout: Seconds to wait for the server before giving up. Generation
            can be slow, hence the generous default.

    Returns:
        The ``"response"`` field of the server's JSON reply, or ``None`` if the
        request failed or the reply was malformed (an error is printed).
    """
    payload = {
        "audio_file": audio_file,
        "dialogue": dialogue
    }

    try:
        # json= serialises the payload and sets the JSON Content-Type header.
        # Without a timeout a stalled server would block this call forever.
        response = requests.post(url, json=payload, timeout=timeout)
        response.raise_for_status()
        return response.json()["response"]
    except requests.exceptions.HTTPError as e:
        # Surface the server's own error body (truncated), e.g. the JSON
        # message sent with a 400, instead of just the status line.
        detail = e.response.text[:200] if e.response is not None else ""
        suffix = f" - {detail}" if detail else ""
        print(f"An error occurred: {e}{suffix}")
        return None
    except requests.exceptions.RequestException as e:
        print(f"An error occurred: {e}")
        return None
    except (KeyError, TypeError, ValueError) as e:
        # Reply was not JSON, or not an object carrying a "response" field.
        print(f"An error occurred: malformed response from server ({e!r})")
        return None

# Example usage: two questions in a row about the same clip
audio_file = "audioset/eval_segments/22khz/Y0bRUkLsttto.wav"
dialogue = []

for question in (
    "What genre does this music belong to?",
    # Follow-up question, sent together with the earlier turn
    "Can you describe the vocals in this track?",
):
    dialogue.append({"user": question})
    response = chat_with_audio(audio_file, dialogue)
    print("API Response:", response)

## Interactive Chat Session

Here's an interactive cell where you can have a multi-turn conversation about an audio file.

In [None]:
audio_file = "audioset/eval_segments/22khz/YXyktNsq4SZU.wav"
dialogue = []

print("Starting an interactive chat session. Type 'quit' to end the conversation.")
while True:
    # Strip so that " quit " or a trailing newline still ends the session.
    user_input = input("Your question: ").strip()
    if user_input.lower() == 'quit':
        break
    if not user_input:
        # Nothing to ask; prompt again instead of sending an empty question.
        continue

    dialogue.append({"user": user_input})
    response = chat_with_audio(audio_file, dialogue)

    # chat_with_audio returns None on failure; an empty string is still a
    # valid (if unhelpful) reply and must not be reported as an error.
    if response is not None:
        print(f"API Response: {response}")
    else:
        # Drop the turn that was never answered so it is not resent with the
        # next question.
        dialogue.pop()
        print("Failed to get a response from the API.")

print("Chat session ended.")

## Cleanup

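When you're done, run this cell to stop the Flask server. The server runs inside the notebook's own process, so the cell sends an interrupt signal (SIGINT) to that process.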

In [None]:
import os
import signal

# os.getpid() is this notebook process itself: the Flask server was started in
# a thread of the same process, so there is no separate server PID to signal.
# NOTE(review): SIGINT is normally delivered to the main thread as
# KeyboardInterrupt - confirm that this really ends the background server
# thread and that the print below is reached.
os.kill(os.getpid(), signal.SIGINT)
print("Server stopped.")