## Extending `generate_exaone_response` into a multi-model `generate_response`

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def load_models(names=None):
  """Loads language models and tokenizers for the generate_* helpers.

  Every checkpoint is loaded in bfloat16 with ``device_map="auto"`` and
  ``trust_remote_code=True``, and is paired with its tokenizer.

  Args:
    names: Optional iterable of registry keys ("exaone3.0", "llama3.1",
      "gemma2.0") to load. A single key may also be passed as a plain string.
      Defaults to None, which loads every registered model, as before.
      Loading only the models you need avoids materialising all checkpoints
      (one of them 33B parameters) at the same time.

  Returns:
    A dict mapping each requested key to a ``(model, tokenizer)`` tuple. When
    ``names`` is None the keys follow registry order; otherwise they follow
    the order given, with duplicates loaded once.

  Raises:
    ValueError: If ``names`` contains a key that is not in the registry.
  """
  # Registry key -> Hugging Face Hub checkpoint id.
  checkpoints = {
      "exaone3.0": "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct",
      # NOTE(review): despite its key, this slot loads Vicuna 33B v1.3, not
      # Llama 3.1. The key is kept so existing callers keep working.
      "llama3.1": "lmsys/vicuna-33b-v1.3",
      # "google/gemini-pro" is not a Hub checkpoint (Gemini is only served via
      # Google's API), so loading it always failed. Use an instruction-tuned
      # Gemma 2 model to match the "gemma2.0" key. The repo is gated, so
      # accept its licence on the Hub first.
      "gemma2.0": "google/gemma-2-9b-it",
  }

  if names is None:
    requested = list(checkpoints)
  else:
    if isinstance(names, str):
      # Don't iterate a lone key character by character.
      names = [names]
    requested = list(dict.fromkeys(names))
    unknown = [name for name in requested if name not in checkpoints]
    if unknown:
      raise ValueError(
          f"Unknown model name(s) {unknown}; "
          f"expected any of {sorted(checkpoints)}"
      )

  loaded = {}
  for name in requested:
    repo_id = checkpoints[name]
    model = AutoModelForCausalLM.from_pretrained(
        repo_id,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    loaded[name] = (model, tokenizer)
  return loaded

# Load every model/tokenizer pair once, at import time, into a module-level
# registry. The generate_* helpers below look entries up as models[<key>].
# NOTE(review): this materialises all three checkpoints (including the 33B
# one) at the same time, which needs a lot of accelerator/host memory.
# Consider loading only the models you actually use.
models = load_models() 

def generate_exaone_response(prompt, max_new_tokens=128):
  """Generates a response from the EXAONE 3.0 model.

  Wraps ``prompt`` in the tokenizer's chat template after a fixed system
  message, generates a continuation, and prints the decoded output.

  Args:
    prompt: The user message to send to the model.
    max_new_tokens: Upper bound on the number of generated tokens
      (default 128, the previously hard-coded value).

  Returns:
    The decoded output sequence: the templated prompt followed by the
    generated continuation, special tokens included. This is the same text
    that is printed.
  """
  model, tokenizer = models["exaone3.0"]

  messages = [
      {"role": "system", 
       "content": "You are EXAONE model from LG AI Research, a helpful assistant."},
      {"role": "user", "content": prompt}
  ]
  input_ids = tokenizer.apply_chat_template(
      messages,
      tokenize=True,
      add_generation_prompt=True,
      return_tensors="pt"
  )

  # Put the inputs on the device holding the model's weights rather than
  # hard-coding "cuda", which raises on machines without a CUDA GPU.
  output = model.generate(
      input_ids.to(model.device),
      eos_token_id=tokenizer.eos_token_id,
      max_new_tokens=max_new_tokens
  )

  text = tokenizer.decode(output[0])
  print(text)
  return text


def generate_llama_response(prompt, max_new_tokens=128):
  """Generates a response from the model stored under the "llama3.1" key.

  NOTE(review): load_models() loads lmsys/vicuna-33b-v1.3 into this slot, so
  despite the key name this talks to Vicuna, not Llama 3.1.

  The prompt is formatted with the tokenizer's chat template when it has one.
  Otherwise the Vicuna v1.1 conversation layout is built by hand. The decoded
  output is printed and returned.

  Args:
    prompt: The user message to send to the model.
    max_new_tokens: Upper bound on the number of generated tokens
      (default 128, the previously hard-coded value).

  Returns:
    The decoded output sequence: the formatted prompt followed by the
    generated continuation, special tokens included. This is the same text
    that is printed.
  """
  model, tokenizer = models["llama3.1"]

  system_prompt = "You are Vicuna model from lmsys, a helpful assistant."
  if getattr(tokenizer, "chat_template", None):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": prompt}
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    )
  else:
    # Recent transformers releases make apply_chat_template raise when the
    # tokenizer has no chat template, instead of guessing a default. The
    # Vicuna checkpoint presumably ships none (verify), so use the Vicuna
    # v1.1 prompt layout: "<system> USER: <message> ASSISTANT:".
    conversation = f"{system_prompt} USER: {prompt} ASSISTANT:"
    input_ids = tokenizer(conversation, return_tensors="pt").input_ids

  # Put the inputs on the device holding the model's weights rather than
  # hard-coding "cuda", which raises on machines without a CUDA GPU.
  output = model.generate(
      input_ids.to(model.device),
      eos_token_id=tokenizer.eos_token_id,
      max_new_tokens=max_new_tokens
  )

  text = tokenizer.decode(output[0])
  print(text)
  return text


def generate_gemma_response(prompt, max_new_tokens=128):
  """Generates a response from the model stored under the "gemma2.0" key.

  Expects a Gemma-family instruction-tuned checkpoint in that slot. Gemma's
  published chat templates accept only user/model turns and raise on a
  "system" message. The system instruction is therefore prepended to the user
  turn instead of being sent as its own message. The decoded output is
  printed and returned.

  Args:
    prompt: The user message to send to the model.
    max_new_tokens: Upper bound on the number of generated tokens
      (default 128, the previously hard-coded value).

  Returns:
    The decoded output sequence: the templated prompt followed by the
    generated continuation, special tokens included. This is the same text
    that is printed.
  """
  model, tokenizer = models["gemma2.0"]

  system_prompt = "You are GEMMA model from Google, a helpful assistant."
  messages = [
      {"role": "user", "content": f"{system_prompt}\n\n{prompt}"}
  ]
  input_ids = tokenizer.apply_chat_template(
      messages,
      tokenize=True,
      add_generation_prompt=True,
      return_tensors="pt"
  )

  # Gemma instruction-tuned models close a turn with <end_of_turn>, not <eos>.
  # Passing only eos_token_id would override the model's own stop list and let
  # generation run past the answer, so stop on either token when it exists.
  stop_ids = [tokenizer.eos_token_id]
  end_of_turn_id = tokenizer.convert_tokens_to_ids("<end_of_turn>")
  if end_of_turn_id is not None and end_of_turn_id != tokenizer.unk_token_id:
    stop_ids.append(end_of_turn_id)

  # Put the inputs on the device holding the model's weights rather than
  # hard-coding "cuda", which raises on machines without a CUDA GPU.
  output = model.generate(
      input_ids.to(model.device),
      eos_token_id=stop_ids,
      max_new_tokens=max_new_tokens
  )

  text = tokenizer.decode(output[0])
  print(text)
  return text


def generate_response(prompt, model_name="exaone3.0"):
  """
  Generates a response from a specified language model.

  Dispatches to the per-model helper registered for ``model_name``.

  Args:
    prompt: The input prompt.
    model_name: Registry key of the model: "exaone3.0" (default),
      "llama3.1" or "gemma2.0".

  Returns:
    The selected helper's return value, passed through unchanged instead of
    being discarded. Returns None when ``model_name`` is not supported.
  """
  if model_name == "exaone3.0":
    return generate_exaone_response(prompt)
  if model_name == "llama3.1":
    return generate_llama_response(prompt)
  if model_name == "gemma2.0":
    return generate_gemma_response(prompt)

  # Unsupported names are reported rather than raised, so notebook cells
  # keep running.
  print(f"Model '{model_name}' is not currently supported.")
  return None


# Example usage: ask the default model (EXAONE 3.0) to introduce itself.
prompt, model_name = "Explain who you are", "exaone3.0"
generate_response(prompt, model_name=model_name)

In [None]:
def generate_response(prompt, model_name="exaone3.0"):
  """
  Generates a response from a specified language model.

  Dispatches to the per-model helper registered for ``model_name``.

  Args:
    prompt: The input prompt.
    model_name: Registry key of the model: "exaone3.0" (default),
      "llama3.1" or "gemma2.0".

  Returns:
    The selected helper's return value, passed through unchanged instead of
    being discarded. Returns None when ``model_name`` is not supported.
  """
  if model_name == "exaone3.0":
    return generate_exaone_response(prompt)
  if model_name == "llama3.1":
    return generate_llama_response(prompt)
  if model_name == "gemma2.0":
    return generate_gemma_response(prompt)

  # Unsupported names are reported rather than raised, so notebook cells
  # keep running.
  print(f"Model '{model_name}' is not currently supported.")
  return None
