In [13]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import re
import os
import tkinter as tk
from tkinter import scrolledtext

In [14]:
from datasets import load_dataset
dataset = load_dataset("M-A-D/Mixed-Arabic-Dataset-Main")


In [15]:

# Suppress huggingface_hub symlink warning
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
# Load tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("aubmindlab/aragpt2-base")
model = AutoModelForCausalLM.from_pretrained("aubmindlab/aragpt2-base")

In [32]:
# Normalize Arabic text
def normalize_arabic(text):
    text = re.sub(r'[ًٌٍَُِّْ]', '', text)  # Remove diacritics
    text = re.sub(r'ـ', '', text)  # Remove tatweel
    text = re.sub(r'[إأآ]', 'ا', text)  # Normalize alif variants
    return text

# Autocomplete function (runs on every key release)
def autocomplete(event=None):
    prefix = input_field.get().strip()
    if not prefix:
        clear_suggestions()
        return
    
    # Normalize and tokenize input
    prefix = normalize_arabic(prefix)
    inputs = tokenizer(prefix, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        top_k_probs, top_k_indices = torch.topk(probs, k=3)
    
    # Decode top 3 suggestions
    top_words = []
    for idx in top_k_indices[0]:
        token = tokenizer.decode([idx], skip_special_tokens=True).strip()
        if token:
            top_words.append(token)
    top_words += [""] * (3 - len(top_words))

    # Update suggestion buttons
    for i, (btn, word) in enumerate(zip(suggestion_buttons, top_words)):
        if word:
            btn.config(text=word, state="normal", command=lambda w=word: append_word(w))
        else:
            btn.config(text="", command=lambda: None, state="disabled")

def append_word(word):
    current_text = input_field.get().strip()
    new_text = f"{current_text} {word}".strip()
    input_field.delete(0, tk.END)
    input_field.insert(0, new_text)
    
    

def send_message():
    message = input_field.get().strip()
    if message:
        message_box.config(state="normal")
        message_box.insert("1.0", f" {message}\n\n")  # Insert at the top
        message_box.config(state="disabled")
        input_field.delete(0, tk.END)
        
        


def clear_suggestions():
    for btn in suggestion_buttons:
        btn.config(text="", command=lambda: None, state="disabled")

# GUI setup
root = tk.Tk()
root.title("Arabic Chat Autocomplete")
root.geometry("500x600")
root.configure(bg="#e5ddd5")

# Chat display area
message_box = tk.Text(root, height=20, width=60, wrap="word", font=("Arial", 12), bg="#f7f7f7")
message_box.pack(pady=10, padx=10)
message_box.config(state="disabled")

# Suggestion buttons
suggestion_frame = tk.Frame(root, bg="#e5ddd5")
suggestion_frame.pack(pady=5)
suggestion_buttons = []
for i in range(3):
    btn = tk.Button(suggestion_frame, text="", font=("Arial", 12), width=12, state="disabled", bg="#ffffff")
    btn.pack(side=tk.RIGHT, padx=5)
    suggestion_buttons.append(btn)

# Input frame 
input_frame = tk.Frame(root, bg="#ffffff", pady=10)
input_frame.pack(fill=tk.X, padx=10, pady=10)

input_field = tk.Entry(input_frame, width=30, font=("Arial", 14), justify="right")
input_field.pack(side=tk.LEFT, padx=5, fill=tk.X, expand=True)
input_field.bind("<KeyRelease>", autocomplete)

send_button = tk.Button(input_frame, text="Append", font=("Arial", 12), command=send_message, bg="#25d366", fg="white")
send_button.pack(side=tk.RIGHT, padx=5)
clear_suggestions()



root.option_add("*Font", "Arial 12")
root.mainloop()