In [None]:
import ast
import re
import tkinter as tk
from tkinter import ttk

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

class PharMEApp:
    """Tkinter front end that asks a medical LLM for drug names matching a diagnosis.

    The window has a diagnosis entry, a "Generate Medication" button and a
    text area that shows the extracted drug names (or a help/error message).
    The model and tokenizer are loaded once, synchronously, in ``__init__``.
    """

    def __init__(self, root):
        """Load the model and build the widgets inside ``root``.

        Args:
            root: The top-level Tk window (e.g. ``tk.Tk()``) that hosts the UI.
        """
        self.root = root
        self.root.title("PharME - Personalized Medication Generator")
        self.root.geometry("600x400")
        self.root.configure(bg="white")

        # Load model and tokenizer before building the widgets. This blocks until
        # loading finishes; on failure both are None and generate_drugs() reports
        # that the model is unavailable.
        self.model, self.tokenizer = self.load_model_and_tokenizer()

        # Header Frame
        header_frame = tk.Frame(root, bg="#0033cc", height=50)
        header_frame.pack(side="top", fill="x")

        header_label = tk.Label(
            header_frame, text="PharME", font=("Arial", 20, "bold"), bg="#0033cc", fg="white"
        )
        header_label.pack(pady=10)

        # Diagnosis Input Frame
        diagnosis_frame = tk.Frame(root, bg="white")
        diagnosis_frame.pack(pady=20, padx=20, fill="x")

        diagnosis_label = tk.Label(
            diagnosis_frame, text="Enter Diagnosis:", font=("Arial", 14), bg="white", fg="#0033cc"
        )
        diagnosis_label.grid(row=0, column=0, sticky="w")

        self.diagnosis_entry = tk.Entry(diagnosis_frame, font=("Arial", 12), width=30)
        self.diagnosis_entry.grid(row=0, column=1, padx=10, pady=5)

        # Generate Button
        generate_button = tk.Button(
            root, text="Generate Medication", font=("Arial", 14), bg="#0033cc", fg="white",
            command=self.generate_drugs
        )
        generate_button.pack(pady=10)

        # Output Frame
        output_frame = tk.Frame(root, bg="white")
        output_frame.pack(pady=10, padx=20, fill="both", expand=True)

        output_label = tk.Label(
            output_frame, text="Generated Medication:", font=("Arial", 14), bg="white", fg="#0033cc"
        )
        output_label.pack(anchor="w")

        self.output_text = tk.Text(output_frame, font=("Arial", 12), height=10, bg="#f0f8ff", wrap="word")
        self.output_text.pack(fill="both", expand=True, padx=10, pady=5)

    def load_model_and_tokenizer(self):
        """Load the Medical-Llama3-8B model and tokenizer with 4-bit quantization.

        Returns:
            A ``(model, tokenizer)`` tuple, or ``(None, None)`` if loading fails
            for any reason. The error is printed rather than raised, so the GUI
            can still start.
        """
        model_name = "ruslanmv/Medical-Llama3-8B"

        # 4-bit NF4 weights with float16 compute (bitsandbytes quantization).
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )

        try:
            # NOTE(review): trust_remote_code=True runs Python code shipped in the
            # model repository; keep it only if this repo actually requires it.
            # NOTE(review): use_cache=False disables the key/value cache, which is
            # mainly a training setting and likely slows interactive generation —
            # confirm it is intentional.
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                quantization_config=bnb_config,
                trust_remote_code=True,
                use_cache=False,
                device_map="auto"
            )
            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            # Reuse EOS as the padding token (presumably the tokenizer ships
            # without one, as Llama 3 tokenizers typically do).
            tokenizer.pad_token = tokenizer.eos_token
            return model, tokenizer
        except Exception as e:
            print(f"Error loading model: {e}")
            return None, None

    def generate_drugs(self):
        """Generate a medication list for the diagnosis typed into the entry box.

        Reads the diagnosis from ``self.diagnosis_entry``, prompts the model and
        writes either a numbered list of extracted drug names or an explanatory
        message into ``self.output_text``. Errors during generation are shown in
        the text area instead of being raised.
        """
        diagnosis = self.diagnosis_entry.get().strip()
        self.output_text.delete("1.0", tk.END)

        if not diagnosis:
            self.output_text.insert(tk.END, "Please enter a diagnosis to generate medication.")
            return

        if self.model is None or self.tokenizer is None:
            self.output_text.insert(tk.END, "Model not loaded. Unable to generate medication.")
            return

        try:
            prompt = f"Based on the diagnosis '{diagnosis}', provide a list of drug names:"

            # The model was placed with device_map="auto", so send the inputs to
            # wherever the model reports it lives instead of guessing cuda/cpu.
            inputs = self.tokenizer(
                prompt,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512
            ).to(self.model.device)

            outputs = self.model.generate(
                # Pass attention_mask as well as input_ids: pad == eos here, so
                # the mask cannot be reliably inferred from the ids alone.
                **inputs,
                # Bound the *generated* length; max_length would also count the
                # prompt tokens and shrink the answer budget.
                max_new_tokens=100,
                num_beams=5,
                early_stopping=True,
                pad_token_id=self.tokenizer.pad_token_id,
            )

            # Decode only the newly generated tokens. The returned sequence starts
            # with the prompt, which must not be parsed as drug names.
            prompt_length = inputs["input_ids"].shape[-1]
            generated_text = self.tokenizer.decode(
                outputs[0][prompt_length:], skip_special_tokens=True
            )

            drug_list = self.extract_drug_names(generated_text)

            if drug_list:
                self.output_text.insert(tk.END, f"Medications for diagnosis '{diagnosis}':\n\n")
                for idx, drug in enumerate(drug_list, start=1):
                    self.output_text.insert(tk.END, f"{idx}. {drug}\n")
            else:
                self.output_text.insert(tk.END, "No specific drug names could be generated.")
        except Exception as e:
            self.output_text.insert(tk.END, f"Error during prediction: {e}")

    def extract_drug_names(self, text):
        """Extract drug names from generated model text.

        First looks at the first ``[...]`` span in ``text``. If it parses as a
        Python literal list (e.g. ``"['Paracetamol', 'Ibuprofen']"``) holding at
        least one non-blank string, those strings are returned, stripped.
        Otherwise every line that contains a letter becomes one entry. Leading
        bullet markers ("-", "•", "*") and list numbering ("1.", "2)") are
        removed from each entry.

        Args:
            text: Generated text to parse.

        Returns:
            A list of drug-name strings, possibly empty.
        """
        start = text.find("[")
        end = text.find("]", start)
        if start != -1 and end != -1:
            try:
                # literal_eval only evaluates literals, so untrusted model output
                # cannot execute code here.
                parsed = ast.literal_eval(text[start:end + 1])
            except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
                parsed = None  # Not a literal list; use the line-based fallback.
            if isinstance(parsed, list):
                # Keep only string items: a bracket such as a citation "[1]"
                # parses to a list of ints, which are not drug names.
                names = [
                    item.strip()
                    for item in parsed
                    if isinstance(item, str) and item.strip()
                ]
                if names:
                    return names

        # Fallback: one entry per line that contains at least one letter.
        drug_names = []
        for line in text.split("\n"):
            line = line.strip()
            if not any(char.isalpha() for char in line):
                continue
            # Drop bullet characters, then "1." / "2)" numbering. The (?!\d)
            # keeps decimals such as "1.5" intact.
            line = re.sub(r"^[-•*\s]+", "", line)
            line = re.sub(r"^\d+[.)](?!\d)\s*", "", line)
            drug_names.append(line)
        return drug_names


if __name__ == "__main__":
    # Script entry point: create the Tk root window, build the PharME UI on it,
    # then block in the Tk event loop until the window is closed.
    root = tk.Tk()
    app = PharMEApp(root)
    app.root.mainloop()
