In [4]:
import pandas as pd
import re
from typing import List, Dict
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

class STIDataPreprocessor:
    """
    Preprocessor for Sexually Transmitted Infections (STI) medical data.
    
    This class handles cleaning, normalization, and preparation of text data
    related to STIs for use in a medical chatbot.
    """
    
    def __init__(self):
        """Initialize the preprocessor with necessary NLTK downloads."""
        nltk.download('punkt', quiet=True)
        nltk.download('stopwords', quiet=True)
        self.stop_words = set(stopwords.words('english'))
        
        # STI-specific terms to keep even if they're in stopwords
        self.sti_terms = {'hiv', 'aids', 'std', 'sti', 'hpv', 'hsv'}
        self.stop_words = self.stop_words - self.sti_terms

    def clean_text(self, text: str) -> str:
        """
        Clean and normalize the input text.
        
        Args:
            text (str): Raw input text
        
        Returns:
            str: Cleaned and normalized text
        """
        if pd.isna(text):
            return ""
        
        # Convert to lowercase
        text = text.lower()
        
        # Remove special characters but keep medical symbols
        text = re.sub(r'[^a-zA-Z0-9\s+\-/%]', '', text)
        
        # Remove extra whitespace
        text = ' '.join(text.split())
        
        return text

    def remove_stopwords(self, text: str) -> str:
        """
        Remove stopwords from the text, keeping STI-specific terms.
        
        Args:
            text (str): Input text
        
        Returns:
            str: Text with stopwords removed
        """
        words = word_tokenize(text)
        filtered_words = [word for word in words if word.lower() not in self.stop_words]
        return ' '.join(filtered_words)

    def preprocess_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Preprocess the entire dataframe.
        
        Args:
            df (pd.DataFrame): Input dataframe with 'abstract' and 'results' columns
        
        Returns:
            pd.DataFrame: Preprocessed dataframe
        """
        # change columns names depending which csv file you are using
        df['clean_abstract'] = df['abstract'].apply(self.clean_text).apply(self.remove_stopwords)
        df['clean_full_text'] = df['full_text'].apply(self.clean_text).apply(self.remove_stopwords)
        
        # Combine cleaned abstract and results or full_texts depending on which csv you are using
        df['combined_text'] = df['clean_abstract'] + ' ' + df['clean_full_text']
        
        return df

    def prepare_for_model(self, text: str, max_length: int = 512) -> str:
        """
        Prepare text for model input, truncating if necessary.
        
        Args:
            text (str): Input text
            max_length (int): Maximum number of words
        
        Returns:
            str: Prepared text
        """
        words = text.split()
        if len(words) > max_length:
            return ' '.join(words[:max_length])
        return text



In [6]:
# Example usage
def main():
    # Load your DataFrame
    df = pd.read_csv('pmc_dataset.csv')  # choose a csv file
    
    # Initialize the preprocessor
    preprocessor = STIDataPreprocessor()
    
    # Preprocess the data
    processed_df = preprocessor.preprocess_dataframe(df)
    
    # Prepare a sample for model input
    sample_text = processed_df['combined_text'].iloc[0]
    model_ready_text = preprocessor.prepare_for_model(sample_text)
    
    print("Sample preprocessed and model-ready text:")
    print(model_ready_text[:200] + "...")  # Print first 200 characters

if __name__ == "__main__":
    main()

Sample preprocessed and model-ready text:
case report explores challenges associated management chronic medical diseases patients psychiatric disorders patient 36-year-old female patient diagnosed aids multiple secondary infections major depr...


In [7]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForCausalLM

class STIChatbot:
    def __init__(self, summary_model="FalconsAI/medical_summarization", 
                 qa_model="microsoft/BioGPT"):
        self.preprocessor = STIDataPreprocessor()
        self.summary_tokenizer = AutoTokenizer.from_pretrained(summary_model)
        self.summary_model = AutoModelForSeq2SeqLM.from_pretrained(summary_model)
        self.qa_tokenizer = AutoTokenizer.from_pretrained(qa_model)
        self.qa_model = AutoModelForCausalLM.from_pretrained(qa_model)

    def preprocess_data(self, df):
        return self.preprocessor.preprocess_dataframe(df)

    def generate_summary(self, text):
        inputs = self.summary_tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
        summary_ids = self.summary_model.generate(
            inputs["input_ids"], 
            max_new_tokens=150,
            min_length=40,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True
        )
        return self.summary_tokenizer.decode(summary_ids[0], skip_special_tokens=True)

    def answer_question(self, context, question):
        prompt = f"Context: {context}\nQuestion: {question}\nAnswer:"
        inputs = self.qa_tokenizer(prompt, return_tensors="pt", truncation=True, max_length=512)
        
        # Calculate available tokens for the answer
        input_length = inputs["input_ids"].shape[1]
        max_new_tokens = min(100, 1024 - input_length)  # Assuming 1024 is the model's maximum context length
        
        output = self.qa_model.generate(
            inputs["input_ids"],
            max_new_tokens=max_new_tokens,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7
        )
        return self.qa_tokenizer.decode(output[0], skip_special_tokens=True)



In [8]:
# Usage
chatbot = STIChatbot()
df = pd.read_csv('pmc_dataset.csv')
processed_df = chatbot.preprocess_data(df)

# Example usage
context = processed_df['combined_text'].iloc[0]
summary = chatbot.generate_summary(context)
question = "What are the main symptoms of chlamydia?"
answer = chatbot.answer_question(summary, question)
print(f"Q: {question}\nA: {answer}")

Q: What are the main symptoms of chlamydia?
A: Context: background management chronic medical diseases patients psychiatric disorders patient 36-year-old female patient diagnosed aids multiple secondary infections major depressive disorder anxiety polysubstance abuse simultaneous occurrence physical mental health conditions presents unique obstacles managing chronic medical diseases goal report discuss patients medical history psychosocial factors interventions outcomes provide insights future patient care report also illuminate relationship psychiatric disorders diminished health maintenance human immunodeficiency virus hiv predominantly sexually transmitted infection destroys cd4 + lymphocyte cells acquired immunodeficiency syndrome aids final disease stage original viral infection diagnosis and / or presence aids-defining illness complex treatment plan patient included management psychiatric medical diagnoses upon admission patient given diphenhydramine Question: What are the main s