In [None]:
import locale
def getpreferredencoding(do_setlocale=True):
    """Replacement for ``locale.getpreferredencoding`` that always answers UTF-8.

    ``do_setlocale`` is accepted only so the signature matches the standard
    library function; its value is ignored.
    """
    return "UTF-8"


# Install the override so every later lookup through the locale module gets UTF-8.
locale.getpreferredencoding = getpreferredencoding

In [None]:
!pip install -q  torch peft bitsandbytes transformers trl accelerate sentencepiece numpy matplotlib seaborn

In [None]:
# imports
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, PeftModel, PeftConfig
from IPython.display import clear_output

In [None]:
from huggingface_hub import login

In [None]:
# Checkpoint the tokenizer and the quantized base model are loaded from.
BASE_MODEL_NAME = "meta-llama/Llama-3.1-8B"

# The adapter repository id is assembled from the project and run names.
PROJECT_NAME = "messages"
RUN_NAME = "a100"
MODEL_NAME = "Tom10117/" + PROJECT_NAME + "-" + RUN_NAME

# Sequence-length setting (presumably in tokens; confirm against the training run).
MAX_LENGTH = 200

# Display name used for the notebook user's side of every conversation.
ME = "Tiến Dũng Nguyễn"

In [None]:
# Tokenizer of the base checkpoint. The end-of-sequence token doubles as the
# padding token, and padding (when it happens) is appended on the right.
# NOTE(review): trust_remote_code=True allows the hub repository to run its own
# Python code; confirm it is actually required for this checkpoint.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# 4-bit NF4 quantization with double quantization; matmuls run in bfloat16.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    quantization_config=quant_config,
    device_map="auto",
)

# NOTE(review): use_cache=False is usually a training-time setting; for pure
# inference it may only slow generation down — confirm it is intended here.
base_model.config.use_cache = False
base_model.config.pretraining_tp = 1

# Attach the fine-tuned adapter weights to the quantized base model.
# Only the base model and the adapter id are passed: `tokenizer` and
# `max_seq_length` are not parameters of PeftModel.from_pretrained.
# NOTE(review): PEFT is expected to inject the adapter into `base_model` in
# place, so `base_model` may no longer be adapter-free afterwards — confirm.
model = PeftModel.from_pretrained(
    base_model,
    MODEL_NAME,
)

In [None]:
# Vocabulary ids that generation must never emit.
# NOTE(review): raw token ids are specific to one tokenizer vocabulary — confirm
# these still point at the intended tokens for BASE_MODEL_NAME.
SUPPRESS_TOKENS = [26308, 243, 162, 155, 149, 160, 47, 18610]
# The same ids expressed as single-token "bad words", plus one banned
# three-token sequence.
BAD_WORDS = [[token_id] for token_id in SUPPRESS_TOKENS] + [[229, 159, 171]]


def generate_next(text, min_tokens, max_tokens):
  """Continue `text` with the module-level `model` and `tokenizer`.

  Args:
    text: Prompt to continue.
    min_tokens: Passed to generate() as `min_new_tokens`.
    max_tokens: Passed to generate() as `max_new_tokens`.

  Returns:
    A tuple `(attempt, final_tokens)`. `attempt` is the decoded full sequence
    (prompt plus continuation, special tokens skipped). `final_tokens` is a CPU
    tensor holding the last 10 token ids of that sequence, intended for
    debugging which tokens were produced.

  Any exception from tokenization or generation propagates to the caller; the
  CUDA cache is emptied either way.
  """
  # Pre-bind every name that the `finally` clause deletes so cleanup cannot
  # itself fail when an earlier step raised.
  inputs = outputs = sequence = None
  attempt = ""
  final_tokens = []
  try:
    inputs = tokenizer(text, return_tensors="pt").to('cuda')
    outputs = model.generate(**inputs,
                             max_new_tokens=max_tokens,
                             min_new_tokens=min_tokens,
                             return_dict_in_generate=True,
                             output_scores=False,
                             no_repeat_ngram_size=6,
                             suppress_tokens=SUPPRESS_TOKENS,
                             bad_words_ids=BAD_WORDS)
    sequence = outputs['sequences'][0]
    attempt = tokenizer.decode(sequence, skip_special_tokens=True)
    # Keep the *last* ten ids (the original `[:-10]` dropped them instead).
    # Copying to CPU means the returned value does not pin the GPU tensor.
    final_tokens = sequence[-10:].cpu()
  finally:
    # Drop the tensor references before asking the allocator to release memory.
    del inputs, outputs, sequence
    torch.cuda.empty_cache()
  return attempt, final_tokens

In [None]:
class Message:
  """One chat line: who sent it and what they said.

  A message is built either from explicit parts (`sender`, `message`) or by
  parsing a raw `text` chunk of the form `"<sender>: <message>"`.

  Attributes (set in __init__):
    sender: Name of the sender.
    message: Body of the message ('' when no separator was found).
    is_complete: False when `text` contained no sender/message separator.
  """

  def __init__(self, sender=None, message=None, text=None):
    """Create a message.

    Args:
      sender: Sender name; defaults to the module-level `ME` when `message`
        is given and `sender` is None.
      message: Message body. When not None, `text` is ignored.
      text: Raw chunk to parse; must be a string when `message` is None.
    """
    self.is_complete = True
    if message is not None:
      self.sender = ME if sender is None else sender
      self.message = message
      return

    if ':' not in text and ';' in text:
      # Treat a ';' as a mistyped separator. Only the first one is repaired so
      # semicolons inside the message body are left alone.
      text = text.replace(';', ':', 1)

    if ':' not in text:
      # No separator at all: everything seen so far is (part of) the sender.
      self.sender = text
      self.message = ''
      self.is_complete = False
    else:
      # Split on the FIRST colon only, so bodies such as "see you at 10:30"
      # or URLs no longer raise ValueError.
      beginning, _, ending = text.partition(':')
      self.sender = beginning.replace('###', '').strip()
      self.message = ending.strip()

  def __repr__(self):
    """Render as `### sender: message`, or `### sender` when incomplete."""
    if self.is_complete:
      return f'### {self.sender}: {self.message}'
    return f'### {self.sender}'


class Conversation:
  """Running chat between `ME` and one other participant.

  Holds the message history, tracks whose turn it is, renders the history as
  a prompt for the model, and folds generated text back into the history.
  """

  # Upper bound on how many trailing messages are rendered into the prompt.
  NLI_MAX_MESSAGES = 20

  def __init__(self, who):
    """Start an empty conversation with `who`; `ME` speaks first."""
    self.who = who
    self.messages = []
    self.nli_message_count = 0
    self.current_sender = ME

  def prefix(self):
    """Return the fixed system/instruction header placed before the messages."""
    system_block = "<<SYS>>Write a realistic text message chat. Avoid repetition.<</SYS>>\n"
    instruction = f"[INST]Write a chat between {ME} and {self.who}[/INST]\n"
    return system_block + instruction

  def next_sender(self):
    """Hand the turn to the other participant."""
    if self.current_sender == ME:
      self.current_sender = self.who
    else:
      self.current_sender = ME

  def add(self, message_contents):
    """Append `message_contents` as a message from the current sender."""
    entry = Message(message=message_contents, sender=self.current_sender)
    self.add_message(entry)

  def add_message(self, message):
    """Append an already-built message object to the history."""
    self.messages.append(message)

  def add_prompt(self):
    """Append an empty message from the current sender for the model to fill in."""
    self.add('')

  def nli(self):
    """Render the header plus the most recent messages as one prompt string.

    Also records how many messages were rendered, which `process` relies on.
    """
    window = self.messages[-Conversation.NLI_MAX_MESSAGES:]
    rendered = [repr(entry) for entry in window]
    self.nli_message_count = len(rendered)
    return self.prefix() + ' '.join(rendered)

  def __repr__(self):
    """Return every message on its own line."""
    return ''.join(repr(entry) + '\n' for entry in self.messages)

  def process(self, language):
    """Merge generated text back into the history.

    `language` is expected to be the prompt from `nli()` followed by the
    model's continuation. The trailing history entry is replaced by whatever
    the model produced from that position onward.

    Returns:
      True once a later chunk comes from someone other than the current
      sender (the turn is finished); False if the text ran out first.
    """
    cleaned = language.replace('?:', ':').replace('::', ':')
    # Element 0 of the split is the header, which carries no message.
    chunks = cleaned.replace(' ###', '###').split('### ')[1:]
    # Drop the trailing entry; it is rebuilt from the generated text below.
    self.messages = self.messages[:-1]
    fresh = chunks[self.nli_message_count - 1:]
    for position, chunk in enumerate(fresh):
      parsed = Message(text=chunk)
      # The first chunk always belongs to the current turn; any later chunk
      # from a different sender marks the end of that turn.
      if position != 0 and parsed.sender != self.current_sender:
        return True
      self.add_message(parsed)
    return False

In [None]:
def generate_response(model, tokenizer, input_text, max_tokens=100, min_tokens=5):
    """Generate a continuation of `input_text` and return it decoded.

    Args:
        model: Object exposing `generate(**kwargs)`.
        tokenizer: Callable that encodes text and offers `decode`.
        input_text: Prompt to continue.
        max_tokens: Forwarded as `max_new_tokens`.
        min_tokens: Forwarded as `min_new_tokens`.

    Returns:
        The first generated sequence decoded with special tokens skipped.
    """
    encoded = tokenizer(input_text, return_tensors="pt").to('cuda')

    # Decoding constraints shared with the module-level suppression lists.
    generation_options = {
        "max_new_tokens": max_tokens,
        "min_new_tokens": min_tokens,
        "return_dict_in_generate": True,
        "output_scores": False,
        "no_repeat_ngram_size": 6,
        "suppress_tokens": SUPPRESS_TOKENS,
        "bad_words_ids": BAD_WORDS,
    }
    result = model.generate(**encoded, **generation_options)

    first_sequence = result['sequences'][0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)

# Quick smoke test: run one fixed prompt through the adapter-wrapped model and
# print the decoded output (the decoded sequence includes the prompt itself).
input_text = "what is 1 + 1 equal to"
response = generate_response(model, tokenizer, input_text)
print(response)

In [None]:
# Compare generation from baseline and fine-tuned model
def compare_models(baseline_model, fine_tuned_model, tokenizer, text):
    """Print the continuation of `text` from the baseline and the fine-tuned model.

    When `fine_tuned_model` is a PEFT wrapper built around `baseline_model`
    itself, the adapter layers live inside the shared base model, so calling
    `baseline_model.generate` directly would still run with the adapter. In
    that case the adapter is switched off for the duration of the baseline
    pass, so the two printed responses really come from different weights.

    Args:
        baseline_model: Model used for the "baseline" response.
        fine_tuned_model: Model used for the "fine-tuned" response.
        tokenizer: Tokenizer used to encode `text` and decode both outputs.
        text: Prompt given to both models.
    """
    from contextlib import nullcontext

    inputs = tokenizer(text, return_tensors="pt").to('cuda')

    # Detect the shared-weights case: the wrapper's underlying model is the
    # very object passed in as the baseline.
    get_base_model = getattr(fine_tuned_model, "get_base_model", None)
    shares_weights = callable(get_base_model) and get_base_model() is baseline_model
    if shares_weights and hasattr(fine_tuned_model, "disable_adapter"):
        baseline_context = fine_tuned_model.disable_adapter()
    else:
        # Independent models (or no adapter support): nothing to switch off.
        baseline_context = nullcontext()

    # Baseline model output
    with baseline_context:
        baseline_output = baseline_model.generate(**inputs, max_new_tokens=100)
    baseline_response = tokenizer.decode(baseline_output[0], skip_special_tokens=True)

    # Fine-tuned model output (adapter active again at this point)
    fine_tuned_output = fine_tuned_model.generate(**inputs, max_new_tokens=100)
    fine_tuned_response = tokenizer.decode(fine_tuned_output[0], skip_special_tokens=True)

    print("Baseline Model Response:\n", baseline_response)
    print("Fine-Tuned Model Response:\n", fine_tuned_response)

# Run the same fixed prompt through both models and print the two responses.
# NOTE(review): `base_model` was handed to PeftModel.from_pretrained earlier; if
# PEFT attached the adapter to it in place, it is no longer an adapter-free
# baseline on its own — confirm before trusting the comparison.
text = "what is 1 + 1 equal to"
compare_models(base_model, model, tokenizer, text)

In [None]:
# Interactive chat driver.
#   - typing text adds it as the current sender's message
#   - an empty line asks the model to write the current sender's message
#   - typing "stop" ends the session
print('Who is the conversation with?')
who = input()
conversation = Conversation(who)

while True:
  print(f'{conversation.current_sender}: ')
  reply = input()
  if reply == 'stop':
    break

  if reply == '':
    # Let the model write this turn: add an empty placeholder message, then
    # keep generating until `process` reports that the turn is finished.
    conversation.add_prompt()
    ready = False
    while not ready:
      language, final_tokens = generate_next(conversation.nli(), 3, 8)
      ready = conversation.process(language)
      # Redraw the whole transcript after every generation step.
      clear_output(wait=True)
      print(conversation)
  else:
    conversation.add(reply)

  # Either way, the other participant speaks next.
  conversation.next_sender()