# **DEPENDENCIES**

In [None]:
# Remove any installed torch / torchvision / torchaudio packages; -y answers the confirmation prompt.
!pip3 uninstall torch torchvision torchaudio -y

In [None]:
# Install torch / torchvision / torchaudio from the PyTorch cu118 wheel index.
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
%%capture
# %%capture hides all output of this cell, including pip errors; remove it when debugging an install.
# Only exllamav2 is version-pinned; every other package installs whatever version is current.
!pip install exllamav2==0.0.8
!pip install huggingface_hub
!pip install git+https://github.com/m-bain/whisperx.git
!pip install hyperdb-python
!pip install sentence-transformers
!pip install edge-tts
!pip install gradio
!pip install diffusers
!pip install pytz
!pip install pyngrok

# **IMPORTS**

In [None]:
from exllamav2 import (ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer)
from exllamav2.generator import (ExLlamaV2StreamingGenerator, ExLlamaV2Sampler)
import whisperx
import torch
from hyperdb import HyperDB
from sentence_transformers import SentenceTransformer
import gradio as gr
from diffusers import StableDiffusionPipeline
import soundfile as sf
import pytz

from huggingface_hub import snapshot_download, hf_hub_download
import os
import sys
import subprocess
import json
import re
from datetime import datetime
import ast

In [None]:
# Set TOKENIZERS_PARALLELISM to "false" for this process.
# NOTE(review): presumably to avoid the Hugging Face tokenizers fork/parallelism warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# **SETUP**

In [None]:
project_path = "/kaggle/working/AIW"
input_data_path = "/kaggle/input/ai-data"  # Modify
!mkdir AIW
!mkdir AIW/LLM
!cp {input_data_path}/conversation.jsonl {project_path}/conversation.jsonl
!cp {input_data_path}/lore.txt {project_path}/lore.txt

Download the LLM (GPTQ model)

In [None]:
# Hugging Face repository of the LLM and the local folder it is mirrored into.
llm_model_id = "TheBloke/dolphin-2.1-mistral-7B-GPTQ"
llm_local_dir = f"{project_path}/LLM/dolphin-2.1-mistral-7B-GPTQ"

# Fetch the repository only when the local folder does not exist yet.
if not os.path.exists(llm_local_dir):
  snapshot_download(
      repo_id=llm_model_id,
      local_dir=llm_local_dir,
      local_dir_use_symlinks=False,
  )

# **CONFIGURATION**

Configure and initialize the LLM (GPTQ model)

In [None]:
# Point the ExLlamaV2 config at the downloaded model folder, then let it prepare itself
# (model_dir is assigned before prepare() is called).
llm_config = ExLlamaV2Config()
llm_config.model_dir = llm_local_dir
llm_config.prepare()

In [None]:
# Build the tokenizer and the model from the prepared config.
ExLlamatokenizer = ExLlamaV2Tokenizer(llm_config)
llm_model = ExLlamaV2(llm_config)
# NOTE(review): [16, 24] is presumably the per-GPU memory split (in GB) for two devices —
# confirm against the pinned exllamav2 version and the available hardware.
llm_model.load([16, 24])

# Cache object bound to the loaded model; handed to the generator together with the tokenizer.
llm_cache = ExLlamaV2Cache(llm_model)

In [None]:
# Streaming generator used by process_text / process_self_text.
llm_generator = ExLlamaV2StreamingGenerator(llm_model, llm_cache, ExLlamatokenizer)
# Stop on '"}': the prompts end with an open JSON-style message ('... "content": "'),
# so this string marks the end of the generated message.
llm_generator.set_stop_conditions(['"}'])

In [None]:
# Sampling settings passed to llm_generator.begin_stream().
llm_settings = ExLlamaV2Sampler.Settings()
llm_settings.temperature = 0.85
llm_settings.top_k = 50
llm_settings.top_p = 0.8
llm_settings.token_repetition_penalty = 1.15
# Forbid the EOS token, so generation ends via the stop string or the token cap below.
llm_settings.disallow_tokens(ExLlamatokenizer, [ExLlamatokenizer.eos_token_id])

# Upper bound on stream() iterations per reply (checked in the generation loops).
llm_max_new_tokens = 250

Configure and initialize Stable Diffusion Model

In [None]:
# !wget https://civitai.com/api/download/models/???? --content-disposition

For LoRA: (Not really tested)

In [None]:
# sd_model_id = "runwayml/stable-diffusion-v1-5"
# sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id, torch_dtype=torch.float16,
#     safety_checker = None, requires_safety_checker = False)
# sd_pipe.scheduler = DPMSolverMultistepScheduler.from_config(sd_pipe.scheduler.config)

# # load lora weight
# model_path = "something-10.safetensors"
# state_dict = safetensors.torch.load_file(model_path)

# LORA_PREFIX_UNET = 'lora_unet'
# LORA_PREFIX_TEXT_ENCODER = 'lora_te'

# alpha = 0.75

# visited = []

# # directly update weight in diffusers model
# for key in state_dict:
    
#     # it is suggested to print out the key, it usually will be something like below
#     # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
    
#     # as we have set the alpha beforehand, so just skip
#     if '.alpha' in key or key in visited:
#         continue
        
#     if 'text' in key:
#         layer_infos = key.split('.')[0].split(LORA_PREFIX_TEXT_ENCODER+'_')[-1].split('_')
#         curr_layer = sd_pipe.text_encoder
#     else:
#         layer_infos = key.split('.')[0].split(LORA_PREFIX_UNET+'_')[-1].split('_')
#         curr_layer = sd_pipe.unet

#     # find the target layer
#     temp_name = layer_infos.pop(0)
#     while len(layer_infos) > -1:
#         try:
#             curr_layer = curr_layer.__getattr__(temp_name)
#             if len(layer_infos) > 0:
#                 temp_name = layer_infos.pop(0)
#             elif len(layer_infos) == 0:
#                 break
#         except Exception:
#             if len(temp_name) > 0:
#                 temp_name += '_'+layer_infos.pop(0)
#             else:
#                 temp_name = layer_infos.pop(0)
    
#     # org_forward(x) + lora_up(lora_down(x)) * multiplier
#     pair_keys = []
#     if 'lora_down' in key:
#         pair_keys.append(key.replace('lora_down', 'lora_up'))
#         pair_keys.append(key)
#     else:
#         pair_keys.append(key)
#         pair_keys.append(key.replace('lora_up', 'lora_down'))
    
#     # update weight
#     if len(state_dict[pair_keys[0]].shape) == 4:
#         weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
#         weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
#         curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
#     else:
#         weight_up = state_dict[pair_keys[0]].to(torch.float32)
#         weight_down = state_dict[pair_keys[1]].to(torch.float32)
#         curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
        
#      # update visited list
#     for item in pair_keys:
#         visited.append(item)

# sd_pipe = sd_pipe.to("cuda")

In [None]:
# Choose the Stable Diffusion checkpoint; swap the commented line to use the general model.
# sd_model_id = "runwayml/stable-diffusion-v1-5"  # General
sd_model_id = "Linaqruf/anything-v3.0"  # Anime styled
# Load in float16 with the safety checker disabled, then move the pipeline to the GPU.
sd_pipe = StableDiffusionPipeline.from_pretrained(sd_model_id, torch_dtype=torch.float16,
    safety_checker = None,
    requires_safety_checker = False)
sd_pipe = sd_pipe.to("cuda")

Configure and initialize WhisperX Model

In [None]:
%%capture
# NOTE(review): whisper_output is not referenced by any other cell visible in this notebook.
whisper_output = r"temp.wav"
# Batch size passed to whisper_model.transcribe() in process_audio.
whisper_batch_size = 16

# English-only "medium.en" model on the GPU, computing in float32.
whisper_model = whisperx.load_model("medium.en", device="cuda", language="en", compute_type="float32")

Configure and initialize HyperDB

In [None]:
%%capture
documents = []

conversation_path = project_path + "/conversation.jsonl"
db_path = project_path + "/conversation.pickle.gz"

with open(conversation_path, "r", encoding="utf-8") as f:
    for line in f:
         documents.append(line)

model = SentenceTransformer('all-MiniLM-L6-v2')
db = HyperDB(documents, key="doesnt.really.matter.here", embedding_function=model.encode)

db.save(db_path)
db.load(db_path)

# **AI SETTINGS**

In [None]:
# Read the lore file into a list of lines (line endings kept).
lore_path = f"{project_path}/lore.txt"
with open(lore_path, encoding='UTF-8') as file:
  lore = list(file)

In [None]:
# Names written to the "role" field of stored messages and used to prime the prompt.
user = "user name"  # Modify
ai = "ai name"  # Modify
# Collapse the lore lines into one string with newlines replaced by spaces.
lore = ''.join(lore).replace('\n', ' ')
# Timezone used for the date/time stamps of every message.
timezone = pytz.timezone('Asia/Shanghai')  # Modify (For list of timezones run: pytz.all_timezones)

# **MODULAR FUNCTIONS**

In [None]:
def process_text(input_text, role=user):
  """Store a message, generate the AI's reply, and produce speech/image for it.

  The message is appended to the conversation file and added to the vector DB.
  Unless it starts with "[NOP]", a prompt is assembled from the lore, related
  older lines ("long-term memory"), the most recent conversation lines and the
  new message; the reply is streamed from the LLM, spoken with edge-tts,
  optionally illustrated, and finally stored as well.

  Args:
    input_text: Message text. A leading "[NOP]" marker means "store only": the
      marker is removed, the message is saved and no reply is generated.
    role: Value of the stored "role" field; defaults to the module-level `user`.

  Returns:
    A tuple (reply_text, audio_path, image_path). All three are None for a
    "[NOP]" message. audio_path is "tts_out.mp3", or None when edge-tts could
    not be run successfully. image_path is None when the reply contains no
    [bracketed] image prompt.
  """
  date_time = datetime.now(timezone)
  date = date_time.strftime("%m/%d/%Y")
  time = date_time.strftime("%H:%M:%S")

  nop = input_text.startswith("[NOP]")
  if nop:
    input_text = input_text[len("[NOP]"):]

  new_document = {"role": role, "date": date, "time": time, "content": input_text}

  print("\n\nYou: " + str(new_document))

  with open(conversation_path, 'a', encoding='UTF-8') as c:
    c.write("\n" + json.dumps(new_document, ensure_ascii=False))
  db.add_document(str(new_document))

  if nop:
    print("\n\n[NOP]")
    return None, None, None

  # Semantic search for earlier messages related to the new one.
  results = db.query(new_document["content"], top_k=4)
  related_content = [ast.literal_eval(doc)["content"] for doc, _ in results]
  if new_document["content"] in related_content:
    related_content.remove(new_document["content"])
  print("\n\nRelated content:")
  for r in results:
    print(r)

  with open(conversation_path, 'r', encoding='UTF-8') as file:
    lines = file.readlines()

  # Short-term memory: the lines before the just-written message (the last
  # line). With lfrom == 0 the slice is lines[0:-1], i.e. the whole history.
  lsize = 11
  lfrom = lsize if len(lines) > lsize else 0
  print("\n\nLast lines:")
  for ll in lines[-lfrom:-1]:
    print(ll)

  last_lines = ''.join(lines[-lfrom:-1])

  long_term_memory = []
  if len(lines) >= lsize:
    for content in related_content:
      if content not in last_lines:
        for i, line in enumerate(lines):
          if content in line:
            # Keep the matching line plus the line that followed it. Slicing
            # avoids an IndexError when the match is the final line.
            long_term_memory.extend(lines[i:i + 2])
            break
  print("\n\nFinal LTM:")
  for ltm in long_term_memory:
    print(ltm)

  # The prompt ends with an open JSON-style message for the model to complete.
  prompt = lore + "\n\n" + ''.join(long_term_memory) + last_lines + str(new_document) + f'\n{{"role": "{ai}", "date": "{date}", "time": "{time}", "content": "'
  print('\nPrompt:\n', prompt)

  input_ids = ExLlamatokenizer.encode(prompt)
  sys.stdout.flush()

  llm_generator.begin_stream(input_ids, llm_settings)

  generated_tokens = 0

  print("\n"+ai+": ", end = "")
  generated_text = ""
  while True:
    chunk, eos, _ = llm_generator.stream()
    generated_tokens += 1
    generated_text += chunk
    print (chunk, end = "")
    sys.stdout.flush()
    if eos or (len(chunk)>0 and chunk[-1] == '}') or generated_tokens == llm_max_new_tokens: break
  print()
  print()

  # endswith() is safe when the model produced no text at all.
  if generated_text.endswith('}'):
    generated_text = generated_text[:-1]
  new_document = {"role": ai, "date": date, "time": time, "content": generated_text}

  # Speak the reply without its (parenthesised) / [bracketed] parts. The text is
  # passed as a single argv entry, never through a shell, so quotes, "$" or
  # backticks in model output cannot break or inject into the command.
  spoken_text = re.sub(r"[\(\[].*?[\)\]]", "", generated_text)
  tts_file = "tts_out.mp3"
  try:
    tts_run = subprocess.run(
        ["edge-tts", "--pitch=+40Hz", f"--text={spoken_text}", "--write-media", tts_file],
        check=False)
    tts_ok = tts_run.returncode == 0
  except OSError as err:
    # Best effort, as before: a missing edge-tts binary must not abort the reply.
    print("edge-tts could not be started:", err)
    tts_ok = False

  image_prompt = extract_image_prompt(generated_text)
  image = None
  if image_prompt is not None:
    image = generate_image(image_prompt)

  with open(conversation_path, "a", encoding='UTF-8') as c:
    c.write("\n" + json.dumps(new_document, ensure_ascii=False))
  db.add_document(str(new_document))
  db.save(db_path)

  # Do not hand back a stale or missing audio file when TTS failed.
  return generated_text, (tts_file if tts_ok else None), image

In [None]:
def process_self_text():
  """Let the AI continue the conversation from the last stored message.

  Reads the last line of the conversation file, builds a prompt from the lore,
  related older lines ("long-term memory"), the recent conversation lines and
  that last message, then streams a reply from the LLM, speaks it with
  edge-tts, optionally illustrates it, and stores it in the file and the DB.

  Returns:
    A tuple (reply_text, audio_path, image_path). All three are None when the
    conversation file has no last message to continue from. audio_path is
    "tts_out.mp3", or None when edge-tts could not be run successfully.
    image_path is None when the reply contains no [bracketed] image prompt.
  """
  date_time = datetime.now(timezone)
  date = date_time.strftime("%m/%d/%Y")
  time = date_time.strftime("%H:%M:%S")

  with open(conversation_path, 'r', encoding='UTF-8') as file:
    lines = file.readlines()

  # Nothing to continue from: avoid an IndexError / parse error on lines[-1].
  if not lines or not lines[-1].strip():
    print("\n\nNo last message found in the conversation file.")
    return None, None, None

  last_document = ast.literal_eval(lines[-1])
  print("\n\nLast: " + str(last_document))

  # Semantic search for earlier messages related to the last one.
  results = db.query(last_document["content"], top_k=4)
  related_content = [ast.literal_eval(doc)["content"] for doc, _ in results]
  if last_document["content"] in related_content:
    related_content.remove(last_document["content"])
  print("\n\nRelated content:")
  for r in results:
    print(r)

  # Short-term memory: the lines before the last message. With lfrom == 0 the
  # slice is lines[0:-1], i.e. the whole history.
  lsize = 11
  lfrom = lsize if len(lines) > lsize else 0
  print("\n\nLast lines:")
  for ll in lines[-lfrom:-1]:
    print(ll)

  last_lines = ''.join(lines[-lfrom:-1])

  long_term_memory = []
  if len(lines) >= lsize:
    for content in related_content:
      if content not in last_lines:
        for i, line in enumerate(lines):
          if content in line:
            # Keep the matching line plus the line that followed it. Slicing
            # avoids an IndexError when the match is the final line.
            long_term_memory.extend(lines[i:i + 2])
            break
  print("\n\nFinal LTM:")
  for ltm in long_term_memory:
    print(ltm)

  # The prompt ends with an open JSON-style message for the model to complete.
  prompt = lore + "\n\n" + ''.join(long_term_memory) + last_lines + str(last_document) + f'\n{{"role": "{ai}", "date": "{date}", "time": "{time}", "content": "'
  print('\nPrompt:\n', prompt)

  input_ids = ExLlamatokenizer.encode(prompt)
  sys.stdout.flush()

  llm_generator.begin_stream(input_ids, llm_settings)

  generated_tokens = 0

  print("\n"+ai+": ", end = "")
  generated_text = ""
  while True:
    chunk, eos, _ = llm_generator.stream()
    generated_tokens += 1
    generated_text += chunk
    print (chunk, end = "")
    sys.stdout.flush()
    if eos or (len(chunk)>0 and chunk[-1] == '}') or generated_tokens == llm_max_new_tokens: break
  print()
  print()

  # endswith() is safe when the model produced no text at all.
  if generated_text.endswith('}'):
    generated_text = generated_text[:-1]
  new_document = {"role": ai, "date": date, "time": time, "content": generated_text}

  # Speak the reply without its (parenthesised) / [bracketed] parts. The text is
  # passed as a single argv entry, never through a shell, so quotes, "$" or
  # backticks in model output cannot break or inject into the command.
  spoken_text = re.sub(r"[\(\[].*?[\)\]]", "", generated_text)
  tts_file = "tts_out.mp3"
  try:
    tts_run = subprocess.run(
        ["edge-tts", "--pitch=+40Hz", f"--text={spoken_text}", "--write-media", tts_file],
        check=False)
    tts_ok = tts_run.returncode == 0
  except OSError as err:
    # Best effort, as before: a missing edge-tts binary must not abort the reply.
    print("edge-tts could not be started:", err)
    tts_ok = False

  image_prompt = extract_image_prompt(generated_text)
  image = None
  if image_prompt is not None:
    image = generate_image(image_prompt)

  with open(conversation_path, "a", encoding='UTF-8') as c:
    c.write("\n" + json.dumps(new_document, ensure_ascii=False))
  db.add_document(str(new_document))
  db.save(db_path)

  # Do not hand back a stale or missing audio file when TTS failed.
  return generated_text, (tts_file if tts_ok else None), image

In [None]:
def extract_image_prompt(text):
    """Return the contents of the first [bracketed] span in `text`.

    The span must open and close on the same line. Returns None when `text`
    contains no such span; empty brackets yield an empty string.
    """
    found = re.compile(r'\[(.*?)\]').search(text)
    if found is None:
        return None
    return found.group(1)

In [None]:
def generate_image(input_prompt):
  """Render `input_prompt` with the Stable Diffusion pipeline and save it.

  The prompt is wrapped in emphasis/quality tags and paired with a fixed
  negative prompt. The first generated image is written to "img_out.png"
  and that file name is returned.
  """
  negative = "((((ugly)))), (bad eyes), (((duplicate))), ((morbid)), ((mutilated)), [out of frame], extra fingers, mutated hands, ((poorly drawn hands)), ((poorly drawn face)), (((mutation))), (((deformed))), blurry, ((bad anatomy)), (((bad proportions))), ((extra limbs)), cloned face, (((disfigured))), gross proportions, (malformed limbs), ((missing arms)), ((missing legs)), (((extra arms))), (((extra legs))), (fused fingers), (too many fingers), (((long neck))), easynegative"
  styled_prompt = f"(({input_prompt})),(best quality),4K"

  output = sd_pipe(prompt=styled_prompt, negative_prompt=negative)
  picture = output.images[0]

  target = "img_out.png"
  picture.save(target)
  return target

In [None]:
def process_audio(audio_data, input_text=None):
    """Transcribe recorded audio and answer it through process_text.

    Args:
        audio_data: (sample_rate, waveform) pair, unpacked in that order, or
            None when no audio was supplied.
        input_text: Optional text; when it starts with "[NOP]" the
            transcription is stored without generating a reply.

    Returns:
        A tuple (text, audio_path, image_path). On success `text` is
        "<transcription><SEP><reply>" and the other two values come from
        process_text. When there is no audio or no speech was recognised,
        `text` is a short message and the other two values are None.
    """
    if audio_data is None:
        return "No audio provided", None, None

    sample_rate, waveform = audio_data
    audio_file = "tts_in.mp3"
    sf.write(audio_file, waveform, sample_rate)

    audio = whisperx.load_audio(audio_file)
    result = whisper_model.transcribe(audio, batch_size=whisper_batch_size)

    # Use every recognised segment, not only the first, so longer utterances
    # are not truncated; an empty list means nothing was recognised.
    segments = result["segments"]
    if not segments:
        return "No speech detected", None, None
    trans = " ".join(segment["text"].strip() for segment in segments)

    if input_text is not None and input_text.startswith("[NOP]"):
        trans = "[NOP]"+trans

    response, tts, img = process_text(trans)
    return f'{trans}<SEP>{str(response)}', tts, img

# **API**

In [None]:
def router_function(api, input_text, input_audio):
    """Dispatch a request to the handler selected by `api`.

    "RESPOND_TO_TEXT" calls process_text(input_text), "SELF_START" calls
    process_self_text(), and "RESPOND_TO_AUDIO" calls
    process_audio(input_audio, input_text). Any other value returns
    ("Invalid API", None, None).
    """
    # Lambdas defer both name lookup and the call until a route is selected.
    routes = {
        "RESPOND_TO_TEXT": lambda: process_text(input_text),
        "SELF_START": lambda: process_self_text(),
        "RESPOND_TO_AUDIO": lambda: process_audio(input_audio, input_text),
    }
    handler = routes.get(api)
    if handler is None:
        return "Invalid API", None, None
    return handler()

In [None]:
# Expose router_function through Gradio: the inputs map to (api, input_text, input_audio)
# and the outputs to the (text, audio, image) triple it returns.
router_interface = gr.Interface(fn=router_function, inputs=["text", "text", "audio"], outputs=["text", "audio", "image"])
# NOTE(review): the ngrok cell below tunnels port 7860, presumably Gradio's default port — confirm.
router_interface.launch(share=False, debug=False)

In [None]:
# Register the ngrok auth token. "auth_token" appears to be a placeholder:
# substitute your own token and avoid committing it with the notebook.
!ngrok config add-authtoken auth_token

In [None]:
# Open an ngrok HTTP tunnel to local port 7860; stderr is merged into stdout.
!ngrok http 7860 2>&1

# **TEST**

In [None]:
#process_self_text()

In [None]:
#process_text("Hello")