# Digital Guru
In this notebook we will fine-tune an open-source LLM on our own data so that it learns the user's behaviour and tone.

Steps:
* **Data Curation:** In this section we will parse the WhatsApp and Instagram data from CSV files. We will clean it for use and upload it to Hugging Face.
  Note: For the WhatsApp CSV I used iMazing to export the CSV files, but for Instagram I used Meta's official data exporter. The Instagram data is exported as HTML, which was then converted to CSV with a custom Jupyter notebook.
* **Fine-Tuning:** In this section we will use QLoRA for fine tuning the *Llama-3.2-3B-Instruct* model.
* **Inference:** Lastly we will run our model. I use gradio to provide a GUI for quick and easy usage.

# Data Curation

In [None]:
# installs
#!pip install ipywidgets datasets cryptography torch transformers sentencepiece matplotlib wordcloud

# imports
import csv
import datetime
import random
from collections import Counter
import datasets
from cryptography.fernet import Fernet  # to encrypt our texts
import tqdm
import torch
from transformers import AutoTokenizer
import matplotlib
import matplotlib.pyplot as plt
from wordcloud import WordCloud
import re
%matplotlib inline

In [None]:
# constants
MAX_LENGTH = 200  # Token threshold per chunk: a document stops growing once adding a message would reach it
DATA_NAME = "biggestFudge/messagesv3-whatsapp"  # Hugging Face dataset repo the data is pushed to
ME = "Guru" # your name here! Used as sender/receiver name for your own messages
base_model_name = "meta-llama/Llama-3.2-3B-Instruct"  # model whose tokenizer measures chunk length

In [None]:
# Login to HF
# Load variables from a .env file, then read the Hugging Face token from the environment
import os
from dotenv import load_dotenv
load_dotenv(override=True)
# Falls back to the placeholder string when HF_TOKEN is not set; replace it with a real token
os.environ['HF_TOKEN'] = os.getenv('HF_TOKEN', 'your-key-if-not-using-env')

from huggingface_hub import login
hf_token = os.environ['HF_TOKEN']
login(hf_token, add_to_git_credential=True)

In [None]:
# Parse the CSV files containing the data.
# Each CSV must contain the columns 'Chat Session', 'Type', 'Text', 'Message Date'.
# You can have many more sources of data; I used my personal and business WhatsApp and Instagram.
# 'Chat Session' : Name of the chat, e.g. if the chat is with a person named ABC this column must contain ABC
# 'Type' : One of ['Incoming', 'Outgoing'], i.e. whether the text was received or sent by the user
# 'Text' : The actual text of the message
# 'Message Date' : The date of the message
txts = []
filenames = ['WhatsApp_1.csv', 'WhatsApp_2.csv', 'insta_1.csv']
for filename in filenames:
    # Explicit encoding: chat exports contain emojis and umlauts, and the platform
    # default (e.g. cp1252 on Windows) would fail on or mangle them. 'utf-8-sig'
    # also drops a leading BOM, which would otherwise end up in the first column name.
    with open(filename, newline='', encoding='utf-8-sig') as csvfile:
        reader = csv.DictReader(csvfile, delimiter=',', quotechar='"')
        txts.extend(reader)

print(f'Read in a total of {len(txts):,} messages')

## Data Cleanup

In [None]:
# Drop WhatsApp's end-to-end-encryption notice, which is not a real message
INTRO = 'Messages to this chat and calls are now secured with end-to-end encryption'
txts = list(filter(lambda row: row['Text'] != INTRO, txts))
print(f'Now a total of {len(txts):,} messages')

In [None]:
# Quick check of what a parsed row looks like
txts[20]

In [None]:
# Some cleanup for Instagram-based texts
# Here I tried removing some artifacts in the text coming from Instagram.
# My name "Guru Deep" is hard-coded in the cleanup; adapt it to your own use case

# I took help from ChatGPT to devise this function
# The patterns below are compiled once at definition time instead of on every call,
# since clean_text runs for every single message.

# http/https links, including the common scheme typo "https;//"
_LINK_PATTERN = re.compile(r"https?[;:][/\\]{2}\S+", flags=re.IGNORECASE)

_MONTH = r"(?:Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Sept|Oct|Nov|Dec)[a-z]*"
_DATETIME_CORE = rf"{_MONTH}\s+\d{{1,2}},\s+\d{{4}}\s+\d{{1,2}}[;:]\d{{2}}\s*(?:am|pm)"

# A name followed by a parenthesized timestamp, e.g.
#   "Guru Deep (Sep 29, 2024 3;09 am)"
#   "Alex: (Jan 1, 2025 12:05 pm)"
_NAME_BEFORE_DATETIME = re.compile(
    rf"""
    (?:\bGuru\ Deep|\b[^\s()]+)      # "Guru Deep" OR a single word just before the (
    \s*[:\-–—]?\s*                   # optional punctuation (e.g., colon) between name and (
    \(\s*{_DATETIME_CORE}\s*\)        # the parenthesized datetime
    """,
    flags=re.IGNORECASE | re.VERBOSE,
)

_MULTI_SPACE = re.compile(r"\s{2,}")
_SPACE_BEFORE_PUNCT = re.compile(r"\s+([,.;:!?])")


def clean_text(text: str) -> str:
    """Strip links and "<name> (<timestamp>)" artifacts from a message.

    Steps, in order: remove http/https links, remove a name directly followed
    by a parenthesized timestamp, collapse runs of whitespace to one space,
    drop whitespace before ``,.;:!?`` and trim both ends.

    Args:
        text: Raw message text. ``None`` or an empty string yields ``""``.

    Returns:
        The cleaned text.
    """
    if not text:
        return ""

    text = _LINK_PATTERN.sub("", text)
    text = _NAME_BEFORE_DATETIME.sub("", text)

    # Cleanup: collapse spaces and fix spaces before punctuation
    text = _MULTI_SPACE.sub(" ", text)
    text = _SPACE_BEFORE_PUNCT.sub(r"\1", text)
    return text.strip()
    

# Class Message will be used to create object for every message 
class Message:
    """A single chat message with sender/receiver resolved relative to ``ME``.

    The text is normalised on construction (see ``massage_text``) and the
    timestamp is parsed into ``self.when`` (``None`` when missing or unparseable).
    """

    # A message whose text contains any of these substrings is excluded
    AVOID_WORDS = ["reagiert", "You sent an attachment", ".gif", "Missed voice call", "Gefällt"]

    # Compiled once for the class instead of once per message.
    # NOTE(review): the range U+24C2-U+1F251 is very wide and also covers
    # non-emoji scripts (e.g. CJK) — confirm this is acceptable for your data.
    EMOJI_PATTERN = re.compile(
        "["
        "\U0001F600-\U0001F64F"  # Emoticons
        "\U0001F300-\U0001F5FF"  # Symbols & pictographs
        "\U0001F680-\U0001F6FF"  # Transport & map symbols
        "\U0001F1E0-\U0001F1FF"  # Flags
        "\U00002500-\U00002BEF"  # Misc symbols
        "\U00002702-\U000027B0"
        "\U000024C2-\U0001F251"
        "\U0001f926-\U0001f937"
        "\U00010000-\U0010ffff"
        "\u2640-\u2642"
        "\u2600-\u2B55"
        "\u200d"
        "\u23cf"
        "\u23e9"
        "\u231a"
        "\ufe0f"
        "\u3030"
        "]+",
        flags=re.UNICODE,
    )

    def __init__(self, chat_session, message_type, text, when):
        """Build a message from one CSV row.

        Args:
            chat_session: Name of the chat the message belongs to.
            message_type: ``'Incoming'`` if received; anything else is treated as sent by ``ME``.
            text: Raw message text.
            when: Timestamp string in ``%Y-%m-%d %H:%M:%S`` format.
        """
        # csv.DictReader yields None for columns missing from a short row;
        # normalise to '' so the string operations below cannot crash.
        self.name = chat_session or ''
        self.sender = self.name if message_type == 'Incoming' else ME
        self.receiver = ME if message_type == 'Incoming' else self.name
        self.text = text or ''
        self.when = self._parse_when(when)

        self.massage_text()

    @staticmethod
    def _parse_when(when):
        """Parse a ``%Y-%m-%d %H:%M:%S`` timestamp; return None if empty or invalid."""
        if not when:
            return None
        try:
            return datetime.datetime.strptime(when, "%Y-%m-%d %H:%M:%S")
        except ValueError:
            # unparseable date
            return None

    def massage_text(self):
        """Normalise ``self.text`` in place for use in the training format."""
        # Some cleanup for instagram
        self.text = clean_text(self.text)

        # Replace special characters used in our format for training
        self.text = self.text.replace('\n','  ').replace(':',';').replace('#',';')

        # Remove all emojis
        self.text = self.EMOJI_PATTERN.sub('', self.text)

        # Nothing left (e.g. an image-only message): use a placeholder
        if self.text == '': self.text = '***'

    def should_exclude(self):
        """Return True if this message must be left out of the dataset.

        Excluded are messages without a parsed date, messages containing any of
        ``AVOID_WORDS``, and chats whose name contains ``+``, ``&`` or ``,`` or
        consists only of digits (an empty name also counts).
        """
        if self.when is None:
            return True

        if any(word in self.text for word in self.AVOID_WORDS):
            return True
        return any(ch in self.name for ch in '+&,') or all(ch.isdigit() for ch in self.name)

# Build one Message per CSV row and keep only those not flagged for exclusion
messages = [
    m
    for m in (Message(t['Chat Session'], t['Type'], t['Text'], t['Message Date']) for t in txts)
    if not m.should_exclude()
]
print(f'A total of {len(messages):,} messages')

In [None]:
# Quick check: sender, timestamp and cleaned text of one message, one per line
print(messages[1].sender, messages[1].when, messages[1].text, sep='\n')

In [None]:
# Group messages by chat: key = chat name, value = list of that chat's messages
chats = {}
for message in messages:
    chats.setdefault(message.name, []).append(message)

# Put every chat into chronological order
for message_list in chats.values():
    message_list.sort(key=lambda msg: msg.when)

print(f'{sum(len(chat) for chat in chats.values()):,} messages with {len(chats)} people')

In [None]:
# Explicitly removing some chats
# I did not want to use group chats as they had few messages from me, so I removed them
# You need to customize this list

REMOVE_GROUP_CHATS = []
print("All Conversations before explicit removal: ")
print(chats.keys())

for rKey in REMOVE_GROUP_CHATS:
    chats.pop(rKey, None)  # does nothing when the chat is not present

In [None]:
# Check that the chats were removed (names printed in sorted order)
print("All Conversations after explicit removal: ")
print(sorted(chats.keys()))

In [None]:
# Keep only chats that contain at least AT_LEAST messages
AT_LEAST = 20
chats = {
    chat_name: chat_messages
    for chat_name, chat_messages in chats.items()
    if len(chat_messages) >= AT_LEAST
}
print(f'{sum(len(chat_messages) for chat_messages in chats.values()):,} messages with {len(chats)} people')

## Visualizing the Data

In [None]:
# After all filtering, flatten the remaining chats into one list for the plots below.
# A comprehension keeps its loop variables local, so the global `messages` list
# is left untouched and no unused `name` variable is bound.
messages_cut = [message for chat_messages in chats.values() for message in chat_messages]

In [None]:
# Collect the timestamp of every remaining message
# NOTE(review): messages_cut holds incoming and outgoing messages alike, while the
# title says "I've sent" — confirm which is intended.
dates = [msg.when for msg in messages_cut]

# Histogram of message volume over time
fig, ax = plt.subplots(1, 1)
ax.set_title("How many texts I've sent over time")
ax.set_xlabel('Year')
ax.set_ylabel('How many texts')
ax.yaxis.set_major_formatter(
    matplotlib.ticker.FuncFormatter(lambda y, p: format(int(y), ','))
)
_ = ax.hist(dates, bins=20, color='purple', rwidth=0.5)

In [None]:
# Count messages per chat and keep the 40 chats with the most messages
counter = Counter(msg.name for msg in messages_cut)
results = counter.most_common(40)
names, counts = zip(*results)

# Bar chart of message count per chat
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.set_ylabel('How many texts')
ax.yaxis.set_major_formatter(
    matplotlib.ticker.FuncFormatter(lambda y, p: format(int(y), ','))
)
plt.xticks(range(len(names)), names, rotation='vertical')
_ = ax.bar(names, counts, color='teal', width=0.5)

In [None]:
# Join every remaining message into one corpus for the word cloud
text = ' '.join(msg.text for msg in messages_cut)

# Render the word cloud without axes
wordcloud = WordCloud(max_font_size=60, max_words=100).generate(text)
plt.figure(figsize=(10, 10))
plt.imshow(wordcloud, interpolation="bilinear")
plt.axis("off")
plt.show()

## Uploading the data to Huggingface

In [None]:
# Load the tokenizer (used below to measure document length) and the 4-bit quantized base model
# NOTE(review): the remaining data-curation cells only use `tokenizer`; `base_model`
# does not appear to be used before it is loaded again for fine-tuning — confirm
# whether loading the model here is needed.
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline

tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
# Reuse the EOS token for padding and pad on the right
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# 4-bit NF4 quantization with double quantization and bfloat16 compute
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=quant_config,
    device_map="auto",
)

In [None]:
# System prompt placed at the start of every training document (see Document.to_messages)
# Please use your own prompt

system_prompt = """You are Guru in this conversation. Respond only as Guru would in a realistic text message exchange.
Write naturally, using the tone, style, and pacing of everyday messaging.
Avoid repetition.
Never make up facts or invent details.
Answer in brief only."""

class Document:
    """A run of consecutive messages from one chat, convertible to chat-template turns."""

    def __init__(self, name, messages):
        self.name, self.messages = name, messages

    def to_messages(self):
        """Return the system prompt followed by one role/content dict per message.

        Messages whose sender is ``ME`` become ``assistant`` turns; all others
        become ``user`` turns. Each content is prefixed with ``### <sender>: ``.
        """
        turns = [{"role": "system", "content": system_prompt}]
        for m in self.messages:
            if m.sender == ME:
                role = "assistant"
            else:
                role = "user"
            turns.append({"role": role, "content": f"### {m.sender}: {m.text}"})
        return turns

    def token_len(self, tokenizer):
        """Return the length of the tokenized chat-template rendering of this document."""
        token_ids = tokenizer.apply_chat_template(
            self.to_messages(),
            tokenize=True,
            add_generation_prompt=False,
            return_tensors=None,
        )
        return len(token_ids)

In [None]:
# Greedily split every chat into consecutive documents: keep adding the next
# message while the document including it stays below MAX_LENGTH tokens.
documents = []
for name, message_list in tqdm.tqdm(chats.items()):
    pointer = 0
    total = len(message_list)
    while pointer < total:
        end = pointer + 1  # a document always holds at least one message
        while end < total:
            candidate = Document(name, message_list[pointer:end + 1])
            if candidate.token_len(tokenizer) >= MAX_LENGTH:
                break
            end += 1
        documents.append(Document(name, message_list[pointer:end]))
        pointer = end
print(f"{len(documents):,} documents")

In [None]:
# Render every document to its message list and measure its token length
data = [doc.to_messages() for doc in documents]
lengths = [doc.token_len(tokenizer) for doc in documents]

# Count messages: skip the initial system message
counts = sum(len(msgs) - 1 for msgs in data)  # -1 for system

# ':.2f' prints fixed-point with two decimals. A bare ':.2' means two significant
# digits and switches to scientific notation for larger values (12.3 -> '1.2e+01').
average = counts / len(documents) if documents else 0.0
print(f'There are {counts:,} messages; average {average:.2f} messages in each of {len(documents):,} documents')

In [None]:
# Histogram of document lengths in tokens
fig, ax = plt.subplots(1, 1)
ax.set_xlabel('Number of tokens in a document')
ax.set_ylabel('Count of documents')
ax.yaxis.set_major_formatter(
    matplotlib.ticker.FuncFormatter(lambda y, p: format(int(y), ','))
)
# Cap very long documents at MAX_LENGTH+100.
# NOTE(review): the bins end below MAX_LENGTH+50, so values above the last bin
# edge (including the capped ones) may not be drawn at all — confirm this is intended.
l2 = [min(MAX_LENGTH+100, length) for length in lengths]
_ = ax.hist(l2, bins=range(0, MAX_LENGTH+50, 10), color='darkorange', rwidth=0.5)

In [None]:
# Shuffle the documents in place with a fixed seed, before the train/test split below
random.seed(42) # for reproducibility
random.shuffle(data)

In [None]:
# Quick schema sanity check: every document must be a list of dicts that each
# carry a "role" and a "content" key; raises TypeError at the first violation
for i, msgs in enumerate(data):
    if not isinstance(msgs, list):
        raise TypeError(f"Document {i} is not a list")
    for j, m in enumerate(msgs):
        if not (isinstance(m, dict) and "role" in m and "content" in m):
            raise TypeError(f"Doc {i}, message {j} is not a {{'role','content'}} dict: {m}")


In [None]:
import json
# Serialize each document (a list of role/content dicts) to a compact JSON string.
# Compact separators reduce size; ensure_ascii=False keeps unicode readable pre-encryption.
serialized = [
    json.dumps(doc, ensure_ascii=False, separators=(",", ":"))
    for doc in data
]

In [None]:
# I chose to encrypt the data.
# Even though the data is in my private repo on HF, I still wanted
# one added layer of safety.

key = Fernet.generate_key()
# Print the key as plain text rather than as a bytes repr (b'...'), so it can be
# pasted unchanged into a prompt when decrypting.
# NOTE: keep this safe; you'll need it to decrypt later. Clear this cell's output
# before sharing the notebook.
print(key.decode())
f = Fernet(key)


In [None]:
# Encrypt every JSON string; the resulting token bytes are decoded to str for storage
encrypted = [f.encrypt(s.encode("utf-8")).decode("utf-8") for s in serialized]

In [None]:
# Train/test split: the first 95% of the (already shuffled) documents train, the rest test
split = int(0.95 * len(encrypted))
train = encrypted[:split]
test = encrypted[split:]

In [None]:
# Build a DatasetDict with "train" and "test" splits; each row has one "text" column
# holding an encrypted document
from datasets import Dataset, DatasetDict
train_dataset = Dataset.from_dict({"text": train})
test_dataset  = Dataset.from_dict({"text": test})
dataset = DatasetDict({"train": train_dataset, "test": test_dataset})

In [None]:
# Push the encrypted dataset to the Hugging Face Hub as a private repo
dataset.push_to_hub(DATA_NAME, private=True)

# Fine Tuning

In [None]:
import torch
from datasets import load_dataset, Dataset, DatasetDict
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline
)
from peft import LoraConfig
from trl import SFTTrainer, SFTConfig
from cryptography.fernet import Fernet
from getpass import getpass
from huggingface_hub import notebook_login
import os
import wandb
import json

In [None]:
# SETUP
DATA_NAME = 'biggestFudge/messagesv3-whatsapp'  # dataset repo to load the encrypted data from
PROJECT_NAME = 'messages'  # W&B project name; also part of the model repo name
RUN_NAME = 'v4'  # W&B run name; also part of the model repo name
MAX_SEQ_LENGTH = 200  # passed to SFTConfig as max_seq_length
BASE_MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
REFINED_MODEL_NAME = f"biggestFudge/{PROJECT_NAME}-{RUN_NAME}"  # output dir and Hub model id

# HYPER-PARAMETERS
LORA_ALPHA = 64
LORA_R = 32
LORA_DROPOUT = 0.1
BATCH_SIZE = 1  # per-device train batch size
GRADIENT_ACCUMULATION_STEPS = 4
LEARNING_RATE = 2e-4
LR_SCHEDULER_TYPE = 'cosine'
WEIGHT_DECAY = 0.001
TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

# OTHER TRAINING CONFIG
# Choose your config carefully depending on the scale of your dataset
STEPS = 10  # used as logging_steps
SAVE_STEPS = 20  # used as save_steps
EVAL_STEPS = 60  # used as eval_steps
LOG_TO_WANDB = True


In [None]:
# Setting up Weights & Biases for live plotting
# Raises KeyError if WANDB_API_KEY is not set in the environment
wandb_api_key = os.environ['WANDB_API_KEY']
os.environ["WANDB_API_KEY"] = wandb_api_key
wandb.login()

# Configure Weights & Biases to record against our project
os.environ["WANDB_PROJECT"] = PROJECT_NAME
os.environ["WANDB_LOG_MODEL"] = "checkpoint" if LOG_TO_WANDB else "end"
os.environ["WANDB_WATCH"] = "gradients"

In [None]:
# Start the W&B run only when logging is enabled
if LOG_TO_WANDB:
  wandb.init(project=PROJECT_NAME, name=RUN_NAME)

In [None]:
# First load the encrypted dataset from Hugging Face
from getpass import getpass
encrypted_data = load_dataset(DATA_NAME)

# Next, decrypt.
# It will ask for the key you used during encryption.
raw_key = getpass("Enter encryption key").strip()
# Tolerate a key pasted as a Python bytes repr (b'...' or b"..."), which is what
# printing a bytes key produces; a real Fernet key never contains quote characters.
if len(raw_key) >= 3 and raw_key[0] == "b" and raw_key[1] in "'\"" and raw_key[-1] == raw_key[1]:
    raw_key = raw_key[2:-1]
key = raw_key.encode()
f = Fernet(key)

decrypted = {"train": [], "test": []}

for split_name in ["train", "test"]:
    split = encrypted_data[split_name]
    for row in split:
        cipher_str = row["text"]                   # base64 string
        plain_bytes = f.decrypt(cipher_str.encode("utf-8"))
        json_str = plain_bytes.decode("utf-8")
        obj = json.loads(json_str)                 # -> list[{"role","content"}, ...]
        decrypted[split_name].append(obj)

# Finally, recreate the dataset; each row's "text" is now a list of role/content dicts
train_dataset = Dataset.from_dict({'text':decrypted['train']})
test_dataset = Dataset.from_dict({'text':decrypted['test']})
data = DatasetDict({'train':train_dataset, 'test':test_dataset})

# Quick check
print(data)
print(data['train'][0])
print(data['test'][0])

In [None]:
# Model and tokenizer names
base_model_name = BASE_MODEL_NAME
refined_model = REFINED_MODEL_NAME

# Tokenizer: reuse the EOS token for padding and pad on the right
llama_tokenizer = AutoTokenizer.from_pretrained(base_model_name, trust_remote_code=True)
llama_tokenizer.pad_token = llama_tokenizer.eos_token
llama_tokenizer.padding_side = "right"

# Quantization Config: 4-bit NF4 with double quantization and bfloat16 compute
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True
)

# Model
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_name,
    quantization_config=quant_config,
    device_map="auto"
)

# Disable the generation cache on the model config for training
base_model.config.use_cache = False
base_model.config.pretraining_tp = 1

In [None]:
def render_to_prompt(example):
    """Render one example's message list into a single chat-template string.

    Reads the list of role/content dicts stored under ``example["text"]`` and
    returns a dict whose ``"text"`` is the string produced by
    ``llama_tokenizer.apply_chat_template`` (no tokenization).
    """
    rendered = llama_tokenizer.apply_chat_template(
        example["text"],
        tokenize=False,
        add_generation_prompt=False,  # training on full dialogues
    )
    return {"text": rendered}

# Replace the message lists in both splits with their rendered prompt strings
data = DatasetDict({
    split_name: data[split_name].map(render_to_prompt)
    for split_name in ("train", "test")
})

In [None]:
# Quick check: show the first rendered training prompt
print(data['train'][0]["text"])

## Let the fun begin!

In [None]:
# LoRA Config: adapter hyper-parameters come from the constants defined above;
# no bias terms are trained and the adapter targets a causal language model
peft_parameters = LoraConfig(
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    r=LORA_R,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=TARGET_MODULES,
)

# Depending on your version of trl, a default data collator may or may not be provided.
# I used Google Colab, which had version 0.21 and provided a default data collator.

# # Data Collator
# #from trl import DataCollatorForCompletionOnlyLM
# assistant_prefix = "<|start_header_id|>assistant<|end_header_id|>\n\n"
# data_collator = DataCollatorForCompletionOnlyLM(
#     tokenizer=llama_tokenizer,
#     response_template=assistant_prefix,
# )

# # (Optional) quick sanity check: make sure assistant header exists in samples
# assert assistant_prefix in data["train"][0]["text"], "Assistant header not found in dataset text"

# --- Training / SFT config ---
# One epoch, bf16, with evaluation/saving/logging on a step schedule.
# NOTE(review): report_to is "wandb" regardless of LOG_TO_WANDB — confirm intended.
train_params = SFTConfig(
    output_dir=REFINED_MODEL_NAME,
    num_train_epochs=1,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=1,
    eval_strategy="steps",                 # <- was evaluation_strategy
    eval_steps=EVAL_STEPS,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    optim="paged_adamw_32bit",
    save_steps=SAVE_STEPS,
    save_total_limit=10,
    logging_steps=STEPS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    fp16=False,
    bf16=True,
    max_grad_norm=0.3,
    max_steps=-1,
    warmup_ratio=0.03,
    group_by_length=True,
    lr_scheduler_type=LR_SCHEDULER_TYPE,
    report_to="wandb",
    run_name=RUN_NAME,

    # SFT-specific bits:
    max_seq_length=MAX_SEQ_LENGTH,
    dataset_text_field="text",             # <- moved here
    packing=False,

    # Hub: push checkpoints to a private model repo
    push_to_hub=True,
    hub_model_id=REFINED_MODEL_NAME,
    hub_strategy="checkpoint",
    hub_private_repo=True,
)

# --- Trainer ---
# Wraps the quantized base model with the LoRA config and trains on the rendered prompts
fine_tuning = SFTTrainer(
    model=base_model,
    train_dataset=data["train"],
    eval_dataset=data["test"],
    peft_config=peft_parameters,
    processing_class=llama_tokenizer,      # <- was tokenizer=
    #data_collator=data_collator,
    args=train_params,
)

# Training
fine_tuning.train()

# Save model and tokenizer locally under the refined model name
fine_tuning.save_model(refined_model)
llama_tokenizer.save_pretrained(refined_model)

# --- Push to Hub (adapter + tokenizer) ---
fine_tuning.push_to_hub()                  # uses args.* hub settings
llama_tokenizer.push_to_hub(REFINED_MODEL_NAME)

In [None]:
# Close the W&B run if one was started
if LOG_TO_WANDB:
  wandb.finish()

# Inference

In [None]:
# imports
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, PeftModel, PeftConfig
from IPython.display import clear_output

In [None]:
# Inference settings; the names must match the ones used during fine-tuning
BASE_MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"
PROJECT_NAME = 'messages'
RUN_NAME = 'v4'
MODEL_NAME = f"biggestFudge/{PROJECT_NAME}-{RUN_NAME}"  # Hub repo of the fine-tuned adapter
MAX_LENGTH = 200
ME = "Guru" # your name here
# Optional Hub revision of the adapter (e.g. a commit hash such as
# "d1a8b673cc3a46bde46c71690ef2d89b54bc1b47"); None for the default revision
REVISION = None

In [None]:
# Tokenizer of the base model: reuse the EOS token for padding, pad on the right
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Same 4-bit NF4 quantization as used for fine-tuning
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    quantization_config=quant_config,
    device_map="auto",
)

base_model.config.use_cache = False
base_model.config.pretraining_tp = 1

# Load the fine-tuned adapter on top of the base model.
# `revision=None` selects the repo's default revision, so a single call covers
# both the pinned and the unpinned case. (An if/else that passes `revision` only
# when REVISION is falsy would silently ignore a pinned commit.)
# NOTE(review): `tokenizer` and `max_seq_length` are not documented parameters of
# PeftModel.from_pretrained; they are kept unchanged here — confirm they are needed.
model = PeftModel.from_pretrained(
    base_model,
    MODEL_NAME,
    tokenizer=tokenizer,
    max_seq_length=MAX_LENGTH,
    revision=REVISION,
)

model.eval()
# Device of the model's first parameter; inputs are moved there before generation
device = next(model.parameters()).device

![Gradio Based MOM](../images/digital_guru_stock_discussion.png)

In [None]:
import gradio as gr
import torch

# Name for the model's personality (shown in the UI title)
BOT_NAME = "Guru"

# --- System prompt ---
# Alternative, more detailed prompt kept for reference:
# SYSTEM_PROMPT = (
#     "You are Guru in this conversation. Respond only as Guru would in a realistic text message exchange.\n"
#     "Write naturally, using the tone, style, and pacing of everyday messaging.\n"
#     "Avoid repetition at all cost and always answer in brief, no more than few lines.\n"
#     "If you are unsure or do not know something, say so clearly.\n"
#     "Never make up facts or invent details.\n"
#     "Do not break character."
# )

# Prompt sent as the system message of every chat request
SYSTEM_PROMPT = ("You are Guru in this conversation. You are snarky. Answer in 2-3 sentences only.\n"
"Write naturally, using the tone and style.\n"
"Avoid repetition.\n"
"Never make up facts or invent details.\n"
"Answer in brief only.")


@torch.inference_mode()
def chat_fn(message, history):
    """Generate one reply for the Gradio chat UI.

    Args:
        message: The user's latest message text.
        history: Earlier turns as supplied by Gradio, either
            ``[user_text, assistant_text]`` pairs (``type="tuples"``) or
            ``{"role": ..., "content": ...}`` dicts (``type="messages"``).
            May be ``None`` or empty. Entries in any other shape are skipped.

    Returns:
        The decoded reply with special tokens removed and whitespace stripped.
    """
    # NOTE(review): training documents prefix each content with "### <sender>: "
    # (see Document.to_messages), while raw text is sent here — confirm whether
    # the prompts should match.
    msgs = [{"role": "system", "content": SYSTEM_PROMPT}]
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages" format: keep only plain-text user/assistant turns
            role, content = turn.get("role"), turn.get("content")
            if role in ("user", "assistant") and isinstance(content, str) and content:
                msgs.append({"role": role, "content": content})
        elif isinstance(turn, (list, tuple)) and len(turn) == 2:
            # "tuples" format: [user_text, assistant_text]
            u, a = turn
            if u:
                msgs.append({"role": "user", "content": u})
            if a:
                msgs.append({"role": "assistant", "content": a})

    # current user message
    msgs.append({"role": "user", "content": message})

    # turn into a string prompt using the model's chat template
    prompt = tokenizer.apply_chat_template(
        msgs, tokenize=False, add_generation_prompt=True
    )

    # NOTE(review): the rendered template may already contain the begin-of-text
    # token, in which case tokenizing it again could add a second one — confirm
    # whether add_special_tokens=False is wanted here.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    input_len = inputs.input_ids.shape[-1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=50,
        temperature=0.5,
        top_p=0.9,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Decode only the newly generated tokens, not the prompt
    reply = tokenizer.decode(outputs[0, input_len:], skip_special_tokens=True).strip()
    return reply

# Build the Gradio chat UI around chat_fn and start it
demo = gr.ChatInterface(
    fn=chat_fn,
    type="tuples",  # <- classic (message, history) mode
    title=f"Chat with {BOT_NAME}",
    description="Type anything and get a reply from Guru."
)

demo.launch()