# In [None]:
import argparse
import os
import random
import time

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from gradio import processing_utils
import json
from transformers import LlamaTokenizer, WhisperFeatureExtractor
from transformers import GenerationConfig
from blsp.src.modeling_blsp import BlspModel
from blsp.src.speech_text_paired_dataset import get_waveform

def parse_args():
    """Parse the command-line options of the speech-language demo.

    Returns:
        argparse.Namespace: holds ``blsp_model`` (str), ``max_new_tokens``
        (int), ``min_new_tokens`` (int), ``temperature`` (float) and
        ``top_p`` (float).
    """
    cli = argparse.ArgumentParser(description="Speech-Language Demo")
    cli.add_argument(
        "--blsp_model",
        type=str,
        default='checkpoints/stage2',
        help="Path to the blsp model",
    )

    # Generation options, one row per flag: (flag, type, default, help).
    generation_options = (
        ("--max_new_tokens", int, 128, "max new tokens for generation"),
        ("--min_new_tokens", int, 1, "min new tokens for generation"),
        ("--temperature", float, 0.1, "temperature for generation"),
        ("--top_p", float, 0.75, "top_p for generation"),
    )
    for flag, value_type, default_value, help_text in generation_options:
        cli.add_argument(
            flag, type=value_type, default=default_value, help=help_text
        )

    return cli.parse_args()

# Module-level generation defaults. The length, temperature and top_p values
# here are placeholders: they are overwritten from the command-line arguments
# once those have been parsed (see the `generation_config.update` call below).
generation_config = GenerationConfig(
    # Sampling behaviour.
    do_sample=True,
    temperature=0.9,
    top_p=0.75,
    # Output length bounds.
    min_new_tokens=1,
    max_new_tokens=128,
    # One sampled sequence per request, no beam search.
    num_beams=1,
    num_return_sequences=1,
)

class ChatHistory:
    """Accumulates the tokenized text and encoded speech segments of a chat.

    Every entry of ``history`` is a tuple of tensors already moved to the GPU:

    * text segments are stored as ``(input_ids,)``;
    * speech segments are stored as ``(speech_values, speech_attention_mask)``.

    Args:
        tokenizer: callable used as ``tokenizer(text, return_tensors="pt")``;
            its result must expose ``input_ids``.
        extractor: speech feature extractor exposing ``sampling_rate`` and
            callable on a waveform; its result must expose ``input_features``
            and ``attention_mask``.
    """

    def __init__(self, tokenizer, extractor):
        super().__init__()
        self.tokenizer = tokenizer
        self.extractor = extractor
        self.history = []
        self.audio_file = []
        # True while there is no registered audio still waiting to be encoded.
        self.audio_to_history = True

        ### add bos token
        self.add_bos()

    def add_bos(self):
        """Append the tokenization of the empty string to the history.

        Presumably this yields just the BOS token the tokenizer prepends
        (depends on the tokenizer configuration).
        """
        # Use the tokenizer owned by this instance rather than a module-level
        # global, so the class works wherever it is constructed.
        input_ids = self.tokenizer("", return_tensors="pt").input_ids.cuda()
        self.history.append(
            (input_ids,)
        )

    def add_text_history(self, text):
        """Tokenize ``text`` and append it to the history as ``(input_ids,)``.

        The first token column is dropped; presumably that is the BOS token
        the tokenizer prepends, which ``add_bos`` has already contributed.
        """
        input_ids = self.tokenizer(text, return_tensors="pt").input_ids[:, 1:].cuda()
        self.history.append(
            (input_ids,)
        )

    def add_audio(self, audio_file):
        """Record ``audio_file`` and mark it as not yet encoded into history."""
        self.audio_to_history = False
        self.audio_file.append(audio_file)

    def add_speech_history(self, speech):
        """Encode ``speech`` and append its features to the history.

        Does nothing unless ``add_audio`` has been called since the last
        encoding, so the same clip is not appended twice.

        Args:
            speech: audio input handed to ``get_waveform`` (the module-level
                helpers pass the most recently registered audio file).
        """
        if self.audio_to_history:
            return
        self.audio_to_history = True
        # Load the waveform at the sampling rate the extractor is configured for.
        speech = get_waveform(speech, output_sample_rate=self.extractor.sampling_rate)
        speech_inputs = self.extractor(
            speech,
            sampling_rate=self.extractor.sampling_rate,
            return_attention_mask=True,
            return_tensors="pt"
        )
        speech_values = speech_inputs.input_features.cuda()
        speech_attention_mask = speech_inputs.attention_mask.cuda()
        self.history.append(
            (speech_values, speech_attention_mask)
        )

print('Initializing Chat')
args = parse_args()

# Tokenizer, speech feature extractor and model are all loaded from the same
# checkpoint directory given by --blsp_model.
tokenizer = LlamaTokenizer.from_pretrained(args.blsp_model)
extractor = WhisperFeatureExtractor.from_pretrained(args.blsp_model)
model = BlspModel.from_pretrained(args.blsp_model)

# Replace the module-level generation defaults with the command-line values
# and register the tokenizer's special-token ids.
generation_config.update(
    max_new_tokens=args.max_new_tokens,
    min_new_tokens=args.min_new_tokens,
    temperature=args.temperature,
    top_p=args.top_p,
    pad_token_id=tokenizer.pad_token_id,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)

def add_text(user_message):
    """Append one text turn from the user to the shared chat history.

    The message is wrapped between the human prefix and the assistant prompt
    suffix; each piece is tokenized and stored as a separate history entry.
    Relies on the module-level ``history`` object.

    Args:
        user_message: text typed by the user.
    """
    segments = ("###[Human]:", user_message, "\n\n\n###[Assistant]:")
    for segment in segments:
        history.add_text_history(segment)
    
def add_file(audio_file):
    """Append one spoken turn from the user to the shared chat history.

    The audio is wrapped between the human prefix and the assistant prompt
    suffix. Relies on the module-level ``history`` object.

    Args:
        audio_file: audio input to register and encode.
    """
    history.add_text_history("###[Human]:")

    # Register the clip first, then encode the most recently registered one.
    history.add_audio(audio_file)
    latest_clip = history.audio_file[-1]
    history.add_speech_history(latest_clip)

    history.add_text_history("\n\n\n###[Assistant]:")

# Move the loaded model to the GPU; ChatHistory also places its tensors there
# via .cuda(), so both live on the same device.
model = model.cuda()
# Switch to evaluation mode (assuming BlspModel follows torch.nn.Module
# semantics, this selects inference behaviour for layers such as dropout).
model.eval()
