<a href="https://colab.research.google.com/github/Sug-ar-N-Spice/Dr.Chats/blob/Patricia/Patricia_Dr_chat_pre.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ! pip install sacremoses
# ! pip install transformers
# ! pip install datasets
# ! pip install torch
#!pip install -q gradio

In [4]:
import pandas as pd
import re
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM
#import gradio as gr

# from google.colab import drive
# drive.mount('/content/drive')

from typing import List, Dict
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from datasets import load_dataset




In [3]:

class STIChatbot:
    def __init__(self, summary_model="FalconsAI/medical_summarization",
                 qa_model="microsoft/BioGPT"):
        self.preprocessor = STIDataPreprocessor()
        self.summary_tokenizer = AutoTokenizer.from_pretrained(summary_model)
        self.summary_model = AutoModelForSeq2SeqLM.from_pretrained(summary_model)
        self.qa_tokenizer = AutoTokenizer.from_pretrained(qa_model)
        self.qa_model = AutoModelForCausalLM.from_pretrained(qa_model)

    def preprocess_data(self, df):
        return self.preprocessor.preprocess_dataframe(df)

    def generate_summary(self, text):
        inputs = self.summary_tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
        summary_ids = self.summary_model.generate(
            inputs["input_ids"],
            max_new_tokens=150,
            min_length=40,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True
        )
        return self.summary_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def answer_question(self, context, question):
        prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"
        inputs = self.qa_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)

        # Calculate available tokens for the answer
        input_length = inputs["input_ids"].shape[1]
        max_new_tokens = min(100, 1024 - input_length)  # Assuming 1024 is the model's maximum context length

        output = self.qa_model.generate(
            inputs["input_ids"],
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7
        )
        return self.qa_tokenizer.decode(output[0], skip_special_tokens=True)



In [12]:

##STOP WORDS IN NLP DONT MEAN ANYTHING LIKE WE THEY THEY JUST COMPLETE THE SENTENCE

## THIS IS CLASS THAT Cleans the data related to STIs
class STIDataPreprocessor:
    """
    Preprocessor for Sexually Transmitted Infections (STI) medical data.

    This class handles cleaning, normalization, and preparation of text data
    related to STIs for use in a medical chatbot.
    """

    def __init__(self):
        """Initialize the preprocessor with necessary NLTK downloads."""
        nltk.download('punkt', quiet=True)
        nltk.download('stopwords', quiet=True)
        self.stop_words = set(stopwords.words('english'))

        # STI-specific terms to keep even if they're in stopwords
        self.sti_terms = {'hiv', 'aids', 'std', 'sti', 'hpv', 'hsv'}
        self.stop_words = self.stop_words - self.sti_terms

    def clean_text(self, text: str) -> str: #str) -> str: this part is more developed skill level can do without in python but in Ctt C sharp need it
    #"""" is a multi-line comment


        """
        Clean and normalize the input text.

        Args:
            text (str): Raw input text

        Returns:
            str: Cleaned and normalized text
        """

        if pd.isna(text):
            return ""

        # Convert to lowercase
        text = text.lower()

        # Remove special characters but keep medical symbols
        text = re.sub(r'[^a-zA-Z0-9\s+\-/%]', '', text)

        # Remove extra whitespace
        text = ' '.join(text.split())

        return text

    def remove_stopwords(self, text: str) -> str:
        """
        Remove stopwords from the text, keeping STI-specific terms.

        Args:
            text (str): Input text

        Returns:
            str: Text with stopwords removed
        """
        words = word_tokenize(text) ### this is resulting in a list of words was converting sentence / paragraph into a list of words

        filtered_words = [word for word in words if word.lower() not in self.stop_words]
        return ' '.join(filtered_words)

    def preprocess_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocess the entire dataframe.

        Args:
            df (pd.DataFrame): Input dataframe with 'train' and 'results' columns

        Returns:
            pd.DataFrame: Preprocessed dataframe
        """
        # change columns names depending which csv file you are using cleaning text and removing stopwords
        df['clean_abstract'] = df['question'].apply(self.clean_text).apply(self.remove_stopwords)
        df['clean_full_text'] = df['context'].apply(self.clean_text).apply(self.remove_stopwords)

        # Combine cleaned abstract and results or full_texts depending on which csv you are using
        df['combined_text'] = df['clean_abstract'] + ' ' + df['clean_full_text']

        return df

    def prepare_for_model(self, text: str, max_length: int = 512) -> str: ##looks at paragraph, cuts the paragraph if more than 512 This takes the sentence splits to words and has a max length
        """
        Prepare text for model input, truncating if necessary.

        Args:
            text (str): Input text
            max_length (int): Maximum number of words

        Returns:
            str: Prepared text
        """
        words = text.split()
        if len(words) > max_length:
            return ' '.join(words[:max_length])
        return text



In [13]:
def preprocess_dataframe(self, df): ###USUALLY YOU SEE SELF IN A CLASS This allows you to code attributes in a class
    """
    Preprocess the entire dataframe.

    Args:
        df (pd.DataFrame or DatasetDict): Input dataframe with 'abstract' and 'results' columns

    Returns:
        pd.DataFrame: Preprocessed dataframe
    """

    # Check if df is a DatasetDict object-- When we try to find dataset Dict- that means we havent converted it to pandas and cleaning whole dataset


    if isinstance(df.DatasetDict):
        # Process each split separately and combine into a DataFrame
        all_data = []
        for split in df:
                    # Assuming all splits have the same columns
                    # Process the columns in each split
            df_split = df[split].to_pandas()  # Convert to DataFrame

            df_split['clean_abstract'] = df_split['question'].apply(self.clean_text).apply(self.remove_stopwords)
            df_split['clean_full_text'] = df_split['context'].apply(self.clean_text).apply(self.remove_stopwords)
            df_split['combined_text'] = df_split['clean_abstract'] + ' ' + df_split['clean_full_text']

            all_data.extend(df_split.to_dict('records')) # Add processed data to the list

        processed_df_1= pd.DataFrame(all_data) # Create a new DataFrame from the combined data

    else:  # If it's a regular DataFrame, process as before
        processed_df_1= df.copy()  # Create a copy to avoid modifying the original DataFrame
        processed_df_1['clean_abstract'] = processed_df_1['question'].apply(self.clean_text).apply(self.remove_stopwords)
        processed_df_1['clean_full_text'] = processed_df_1['context'].apply(self.clean_text).apply(self.remove_stopwords)
        processed_df_1['combined_text'] = processed_df_1['clean_abstract'] + ' ' + processed_df_1['clean_full_text']

    return processed_df_1  # Return the processed data

In [14]:
# Example usage
def main():
    # Load your DataFrame
    df_1 = pd.read_csv('million_sample.csv')  # choose a csv file

    # Initialize the preprocessor
    preprocessor = STIDataPreprocessor()

    # Preprocess the data
    processed_df = preprocessor.preprocess_dataframe(df_1)

    # Prepare a sample for model input
    sample_text = processed_df['combined_text'].iloc[0]
    model_ready_text = preprocessor.prepare_for_model(sample_text)

    print("Sample preprocessed and model-ready text:")
    print(model_ready_text[:200] + "...")  # Print first 200 characters

if __name__ == "__main__":
    main()

Sample preprocessed and model-ready text:
resurgent sodium current mouse cerebellar purkinje neurons fgf14 modulates resurgent sodium current mouse cerebellar purkinje neurons...


In [None]:
!pip install -q gradio
!pip install sacremoses
import gradio as gr
from google.colab import drive
drive.mount('/content/drive')


In [None]:
# Create a function that takes in a message.
# def run(msg):
#     return f'Returning this message: {msg}'


# Create an instance of the Gradio Interface application function with the following parameters.
app = gr.Interface(fn=run, inputs="text", outputs="text")

# Launch the app
app.launch()

In [15]:
  # processed_df_2 = chatbot.preprocess_data(df_2
  #processed_df_2 = chatbot.preprocess_data(df_2)



# df_2 = load_dataset("ruslanmv/ai-medical-dataset")
chatbot = STIChatbot()
processed_df_1 = chatbot.preprocess_data(first_10_rows_df)




# Example usage
# context_1 = processed_df_1['combined_text'].iloc[0]
context_2 = processed_df_1['combined_text']
summary = chatbot.generate_summary("context")
question = "What are the main symptoms of gonorrhea?"
answer = chatbot.answer_question(summary, question)
print(f"Q: {question}\nA: {answer}")

Q: What are the main symptoms of gonorrhea?
A: Context: Question: What are the main symptoms of gonorrhea? Answer: Gonorrhea is a sexually transmitted infection and a cause of urethritis, cervicitis, and pelvic inflammatory disease.


In [None]:
import gradio as gr

# Initialize the chatbot and preprocess the data
chatbot = STIChatbot()
processed_df_1 = chatbot.preprocess_data(first_10_rows_df)

# Function to handle the summary and answering the question
def generate_response(question):
    if not question.strip():
        return "Please enter a valid question."

    try:
        context_2 = processed_df_1['combined_text'].iloc[0]
        summary = chatbot.generate_summary(context_2)
        answer = chatbot.answer_question(summary, question)
        return answer
    except Exception as e:
        return f"An error occurred while processing your request: {str(e)}"

# Create a Gradio interface
iface = gr.Interface(
    fn=generate_response,            # The function to call
    inputs=gr.Textbox(label="Your Question", placeholder="Ask about symptoms..."),  # Input type
    outputs=gr.Textbox(label="Answer"),        # Output type
    title="Symptoms Q&A",  # Title of the interface
    description="Ask about the main symptoms you are concerned with."
)

# Launch the Gradio app
iface.launch(share=True)  # Set share=True if you want a public link


In [None]:
# import gradio as gr

# # Initialize the chatbot and preprocess the data
# chatbot = STIChatbot()
# processed_df_1 = chatbot.preprocess_data(first_10_rows_df)

# # Function to handle the summary and answering the question
# def generate_response(question):
#     try:
#         context_2 = processed_df_1['combined_text'].iloc[0]
#         summary = chatbot.generate_summary(context_2)
#         answer = chatbot.answer_question(summary, question)
#         return answer
#     except Exception as e:
#         return f"An error occurred: {str(e)}"

# # Create a Gradio interface
# iface = gr.Interface(
#     fn=generate_response,            # The function to call
#     inputs=gr.Textbox(label="Your Question"),  # Input type
#     outputs=gr.Textbox(label="Answer"),        # Output type
#     title="Symptoms Q&A",  # Title of the interface
#     description="Ask about the main symptoms you are concerned with."
# )

# # Launch the Gradio app
# iface.launch()
