In [None]:
# Standard Libraries
import os
import re
import subprocess
import traceback
import pickle
# import spacy
from collections import Counter

# Data Handling and Processing
import pandas as pd
import numpy as np
import nltk
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns  

# Scikit-Learn: Preprocessing and Model Selection
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler, OneHotEncoder, LabelEncoder
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import MultiLabelBinarizer


# Scikit-Learn: Models
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from xgboost import XGBClassifier

# Scikit-Learn: Evaluation Metrics
from sklearn.metrics import (
    accuracy_score,
    confusion_matrix,
    classification_report,
    roc_auc_score
)

# Scikit-Learn: Class Weights
from sklearn.utils.class_weight import compute_class_weight

# Scikit-Learn: Feature Engineering
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
from sklearn.metrics.pairwise import cosine_similarity

# Transformers
from transformers import AutoModelForTokenClassification, AutoTokenizer
from transformers import pipeline  
from sentence_transformers import SentenceTransformer

# Torch
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

# Imbalanced Data Handling
from imblearn.combine import SMOTEENN

# TensorFlow & Keras
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import (
    Input, Dense, Dropout, Concatenate, BatchNormalization
)
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (
    EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
)
from tensorflow.keras.regularizers import l2
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, LSTM
from tensorflow.keras.preprocessing.text import Tokenizer


import logging
from tqdm import tqdm


# Joblib (for saving/loading models)
import joblib

In [None]:
# Fetch the NLTK "punkt" tokenizer data (downloaded on first run).
nltk.download('punkt')

In [5]:
# Configure root logging at INFO level and create the logger that the
# functions in this notebook write to.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:

def check_emergency_symptoms(symptoms):
    """Return True when the text mentions a symptom that needs urgent care.

    Matching is a case-insensitive substring search against a fixed keyword
    list, so it deliberately errs on the side of over-triggering.

    Args:
        symptoms: Free-text symptom description. Anything that is not a
            string (for example ``None``) is treated as "no emergency".

    Returns:
        bool: True if any emergency keyword occurs in the text.
    """
    if not isinstance(symptoms, str):
        return False

    emergency_keywords = (
        'chest pain', 'heart attack', 'stroke', 'unconscious', 'breathing difficulty',
        'severe bleeding', 'head injury', 'suicide', 'poisoning', 'overdose',
        'seizure', 'severe burn', 'gunshot', 'drowning', 'choking',
        'anaphylaxis', 'allergic shock', 'coughing blood', 'severe trauma',
        'loss of vision', 'paralysis', 'severe abdominal pain',
    )

    # Lowercase once (not once per keyword) and collapse runs of whitespace so
    # that "chest   pain" or a line break between the words still matches.
    normalized = " ".join(symptoms.lower().split())
    return any(keyword in normalized for keyword in emergency_keywords)

def get_emergency_message():
    """Return the fixed notice shown when emergency symptoms are detected."""
    emergency_notice = """
    EMERGENCY MEDICAL ATTENTION NEEDED
    ---------------------------------
    Based on the symptoms you've described, you should seek immediate medical attention:
    
    1. Call emergency services (911 in the US) or your local emergency number
    2. Go to the nearest emergency room
    3. Do not wait for symptoms to improve on their own
    
    This is not a situation for an AI chatbot. Please seek professional medical help immediately.
    """
    return emergency_notice

In [None]:
def generate_medical_response(user_input, model, label_encoder, embeddings_model,
                              embeddings=None, cases_df=None, top_k=3,
                              min_similarity=0.2):
    """Build a text response for a symptom description from similar past cases.

    Emergency keywords short-circuit to the emergency notice. Otherwise the
    input is embedded, compared against the case embeddings by cosine
    similarity, and the most similar cases above ``min_similarity`` are listed
    together with a disclaimer.

    Args:
        user_input: Free-text symptom description.
        model: Accepted for interface compatibility; not used here.
        label_encoder: Accepted for interface compatibility; not used here.
        embeddings_model: Object exposing ``encode(...)`` (called with the
            same keyword arguments as a SentenceTransformer).
        embeddings: Case embedding matrix, row-aligned with ``cases_df``.
            Defaults to the notebook-level ``embeddings`` variable.
        cases_df: DataFrame with ``prognosis`` and ``output`` columns.
            Defaults to the notebook-level ``healthCareMagic_df`` variable.
        top_k: Maximum number of similar cases to report.
        min_similarity: Cases must score strictly above this to be listed.

    Returns:
        str: The formatted response, the emergency notice, or an apology
        message if anything fails while building the response.
    """
    try:
        # Check for emergency symptoms first
        if check_emergency_symptoms(user_input):
            return get_emergency_message()

        # Fall back to the notebook globals so existing four-argument callers
        # keep working; a missing global surfaces as the apology below.
        if embeddings is None:
            embeddings = globals()['embeddings']
        if cases_df is None:
            cases_df = globals()['healthCareMagic_df']
        if len(embeddings) != len(cases_df):
            raise ValueError(
                f"embeddings ({len(embeddings)} rows) and cases "
                f"({len(cases_df)} rows) are not aligned"
            )

        # Create embedding for user input
        user_embedding = embeddings_model.encode(
            [user_input],
            batch_size=1,
            show_progress_bar=False,
            device='cuda' if torch.cuda.is_available() else 'cpu'
        )

        # Rank all cases by cosine similarity, best first
        scores = cosine_similarity(user_embedding, embeddings)[0]
        top_indices = scores.argsort()[-top_k:][::-1]

        parts = [
            "Based on your symptoms, here's my analysis:\n\n",
            "Similar Cases and Their Diagnoses:\n",
        ]

        relevant_cases = 0
        for idx in top_indices:
            score = scores[idx]
            if score > min_similarity:  # Threshold for relevance
                case = cases_df.iloc[idx]
                parts.append(f"\nSimilarity Score: {score:.2f}\n")
                parts.append(f"Diagnosis: {case['prognosis']}\n")
                parts.append(f"Doctor's Response: {case['output']}\n")
                parts.append("-" * 80 + "\n")
                relevant_cases += 1

        if relevant_cases == 0:
            # Avoid an empty section that looks like a rendering failure.
            parts.append("\nNo sufficiently similar cases were found for these symptoms.\n")

        # Add disclaimer
        parts.append("\nIMPORTANT DISCLAIMER:\n")
        parts.append("- This is an AI-generated response based on similar medical cases.\n")
        parts.append("- This should not be considered as professional medical advice.\n")
        parts.append("- Please consult with a qualified healthcare provider for proper diagnosis and treatment.\n")

        return "".join(parts)

    except Exception as e:
        # Deliberate catch-all boundary: the user always gets a reply.
        # logger.exception keeps the traceback for debugging.
        logger.exception(f"Error generating response: {str(e)}")
        return "I apologize, but I encountered an error while processing your symptoms. Please try again or consult a healthcare provider."


In [None]:
def interactive_medical_consultant(model, label_encoder, embeddings_model, embeddings):
    """Run a console loop that answers symptom descriptions until the user quits.

    Args:
        model: Forwarded to ``generate_medical_response``.
        label_encoder: Forwarded to ``generate_medical_response``.
        embeddings_model: Forwarded to ``generate_medical_response``.
        embeddings: Accepted but not forwarded by this function.
            TODO(review): confirm whether it should be passed on.

    The loop ends when the user types 'quit', 'exit' or 'bye', or when the
    input stream is closed.
    """
    print("\nMedical Symptom Consultant")
    print("-------------------------")
    print("Please describe your symptoms in detail.")
    print("Type 'quit' to exit.")

    while True:
        try:
            user_input = input("\nDescribe your symptoms: ").strip()

            if user_input.lower() in ['quit', 'exit', 'bye']:
                print("\nThank you for using the Medical Symptom Consultant. Take care!")
                break

            if len(user_input) < 10:
                print("Please provide more details about your symptoms for better analysis.")
                continue

            response = generate_medical_response(
                user_input,
                model,
                label_encoder,
                embeddings_model
            )
            print("\n", response)

        except EOFError:
            # input() raises EOFError once stdin is exhausted. It is a
            # subclass of Exception, so without this branch the generic
            # handler below would swallow it and the loop would never end.
            print("\nInput stream closed. Ending the consultation.")
            break

        except Exception as e:
            logger.error(f"Consultation error: {str(e)}")
            print("\nAn error occurred while processing your request.")
            print("Please try again or seek medical attention if you're concerned.")


In [None]:

# def check_emergency_symptoms(symptoms):
#     """Check for emergency symptoms that require immediate medical attention"""
#     emergency_keywords = [
#         'chest pain', 'heart attack', 'stroke', 'unconscious', 'breathing difficulty',
#         'severe bleeding', 'head injury', 'suicide', 'poisoning', 'overdose',
#         'seizure', 'severe burn', 'gunshot', 'drowning', 'choking',
#         'anaphylaxis', 'allergic shock', 'coughing blood', 'severe trauma',
#         'loss of vision', 'paralysis', 'severe abdominal pain'
#     ]
    
#     for keyword in emergency_keywords:
#         if keyword in symptoms.lower():
#             return True
#     return False

# def get_emergency_message():
#     return """
#     EMERGENCY MEDICAL ATTENTION NEEDED
#     ---------------------------------
#     Based on the symptoms you've described, you should seek immediate medical attention:
    
#     1. Call emergency services (911 in the US) or your local emergency number
#     2. Go to the nearest emergency room
#     3. Do not wait for symptoms to improve on their own
    
#     This is not a situation for an AI chatbot. Please seek professional medical help immediately.
#     """

In [6]:
# Load the raw HealthCareMagic CSV. Later cells read its 'input' and
# 'output' columns.
healthCareMagic_df = pd.read_csv('Pivot_Resource/HealthCareMagic-100k.csv')


In [None]:
# Preview the first rows of the raw dataset.
healthCareMagic_df.head()

In [None]:
# (rows, columns) of the raw dataset.
healthCareMagic_df.shape

In [9]:
# Lookup table: search term found in the consultation text -> prognosis label.
# Insertion order is significant: when two terms occur equally often in a
# text, the one listed first here is the one that wins.
# NOTE(review): a few keys are not lowercase ("Autonomic Neuropathy", "IVF",
# "MRSA", "TIA"); a matcher that lowercases the text without also lowercasing
# the keys can never hit them.
medical_conditions = {
    "abortion": "Abortion",
    "abscess": "Dental Abscess",
    "acne": "Acne Vulgaris",
    "acid reflux": "acid reflux",
    "allergy": "Allergic Reaction",
    "allergic rhinitis": "Allergic rhinitis",
    "allergic rash": "Allergic Rash",
    "allergic reaction": "Allergic Reaction",
    "anemia": "Iron Deficiency Anemia",
    "angina": "Angina Pectoris (Chest Pain Due to Reduced Blood Flow to the Heart)",
    "anxiety": "Generalized Anxiety Disorder",
    "anxiety-related disorders": "Generalized Anxiety Disorder",
    "arterial insufficiency": "Arterial Insufficiency (Reduced Blood Flow Due to Blocked or Narrowed Arteries)",
    "arthritis": "Rheumatoid Arthritis",
    "arrhythmia": "Arrhythmia",
    "asthma": "Chronic Asthma",
    "atrial fibrillation": "Atrial Fibrillation",
    "autism": "Autism Spectrum Disorder",
    "Autonomic Neuropathy": "Autonomic Neuropathy",
    "autonomic neuropathy": "Autonomic Neuropathy",
    "baby": "fertility",
    "birth control": "Birth Control",
    "bipolar": "Bipolar Disorder",
    "bronchitis": "Bronchitis",
    "bursitis": "Bursitis",
    "breast cancer": "Breast Cancer",
    "body aches": "Myalgia (Body Aches)",
    "bruise": "Contusion",
    "candida": "Candida Infection",
    "cardiac": "heart related",
    "complex partial seizure": "Complex Partial Seizure (CPS)",
    # This key used to be written twice in the literal ("Paternity" here and
    # "Conception" further down). Python keeps the first position and the
    # last value, which is exactly what this single entry encodes.
    "conception": "Conception",
    "conceive": "Infertility (Difficulty in Conceiving)",
    "conceiving": "Infertility (Difficulty in Conceiving)",
    "carpal tunnel syndrome": "Carpal Tunnel Syndrome",
    "cancer": "Cancer",
    "cataracts": "Cataracts",
    "celiac disease": "Celiac Disease",
    "chickenpox": "chickenpox",
    "cholesterol": "Hyperlipidemia",
    "chronic back pain": "Chronic Back Pain",
    "chronic bronchitis": "Chronic Bronchitis",
    "chronic fatigue": "Chronic Fatigue Syndrome",
    "chronic kidney disease": "chronic kidney disease",
    "chronic migraine": "Chronic Migraines",
    "chronic obstructive pulmonary disease": "Chronic Obstructive Pulmonary Disease",
    "chronic psoriasis": "Chronic Psoriasis",
    "chronic sinusitis": "Chronic Sinusitis",
    "chronic cough": "Chronic Cough",
    "cyst": "cyst",
    "cystic fibrosis": "Cystic Fibrosis",
    "cold": "Common Cold",
    "de quarries tenosynovitis": "De Quarries tenosynovitis",
    "diarrhea": "Diarrhea",
    "dehydration": "Dehydration",
    "dementia": "Alzheimer's Disease",
    "depression": "Major Depressive Disorder",
    "diabetes": "Diabetes",
    "diabetes (type 1 and type 2)": "Type 2 Diabetes",
    "dyslexia": "Dyslexia",
    "eczema": "Atopic Dermatitis",
    "epidermal cyst": "Sebaceous Cyst",
    "epilepsy": "Epileptic Seizures",
    "erectile dysfunction": "Erectile Dysfunction",
    "endometriosis": "Endometriosis",
    "fracture": "Fracture",
    "fever": "fever",
    "fibroid": "Uterine Fibroids",
    "fibromyalgia": "Fibromyalgia",
    "folate deficiency": "Folate Deficiency Anemia",
    "gas": "gas or bloating",
    "gastritis": "Gastritis",
    "gallbladder": "gallbladder",
    "gallstones": "gallstones",
    "gastroesophageal reflux disease": "Gastroesophageal Reflux Disease",
    "gestational diabetes": "Gestational Diabetes Mellitus",
    "glaucoma": "Primary Open-Angle Glaucoma",
    "gout": "Gouty Arthritis",
    "gingivitis": "Gingivitis",
    "heart flutters": "Arrhythmia",
    "herniated disc": "Herniated Disc",
    "heart disease": "Coronary Artery Disease",
    "hepatitis": "Hepatitis",
    "hernia": "Hernia",
    "hiv": "HIV/AIDS",
    "hodgkin's disease": "Hodgkin's Disease",
    "hyperglycemia": "Hyperglycemia",
    "hyperhidrosis": "Hyperhidrosis",
    "hyperthyroidism": "Graves' Disease",
    "hypertension": "Hypertension",
    "hypertension (high blood pressure)": "Hypertension",
    "hyperpigmentation": "Hyperpigmentation",
    "hypoglycemia": "Hypoglycemia",
    "hypothyroidism": "Hypothyroidism",
    "hives": "Hives (Raised, Itchy Skin Rash Caused by Allergic Reactions, Also Known as Urticaria)",
    "irregular heart beats": "Arrhythmia",
    "irregular menstruation": "Menstrual Irregularities",
    "ivf": "In Vitro Fertilization",
    "IVF": "In Vitro Fertilization",
    "infertility": "Infertility",
    "infection": "Infection",
    "infected boil": "Infected Boil",
    "influenza": "Seasonal Influenza",
    "insomnia": "Chronic Insomnia",
    "irritable bowel syndrome": "Irritable Bowel Syndrome",
    "kidney stones": "Kidney Stones",
    "kidney disease": "Chronic Kidney Disease",
    "kidney infection": "Pyelonephritis",
    "lumbar disc prolapse": "Lumbar Disc Prolapse",
    "lung disease": "Lung Disease",
    "liver disease": "Non-alcoholic Fatty Liver Disease",
    "lupus": "Systemic Lupus Erythematosus",
    "lymphoma": "Lymphoma",
    "male pattern baldness": "Male Pattern Baldness",
    "mono": "Mononucleosis (Infectious Mononucleosis, Often Caused by Epstein-Barr Virus)",
    "multi infarct dementia": "Multi Infarct Dementia (Cognitive Decline Due to Multiple Small Strokes)",
    "multi infarct dementiafurther": "Multi Infarct Dementia (Cognitive Decline Due to Multiple Small Strokes)",
    "male baldness": "Male Pattern Baldness",
    "measles": "Measles",
    "meningitis": "Meningitis",
    "menopause": "Menopausal Syndrome",
    "micronutrient deficiency": "Micronutrient Deficiency",
    "microvascular ischemia": "Microvascular Ischemia (Reduced Blood Flow in the Small Blood Vessels of the Heart, Not Visible on Angiography)",
    "MRSA": "Methicillin-Resistant Staphylococcus Aureus",
    "mitral valve prolapse": "Mitral Valve Prolapse",
    "migraine": "Chronic Migraine",
    "migraine headaches": "Chronic Migraines",
    "muscular pain": "Musculoskeletal Pain (Pain Due to Muscle Strain or Tension)",
    "muscular strain": "Muscular Strain (Injury or Damage to Muscle Tissue Due to Overstretching or Excessive Stress)",
    "multiple sclerosis": "Multiple Sclerosis",
    "mumps": "Mumps",
    "mesenteric adenopathy": "mesenteric adenopathy",
    "nephritis": "Glomerulonephritis",
    "neurofibromatosis": "Neurofibromatosis Type 1",
    "obesity": "Obesity-related Metabolic Syndrome",
    "osteomyelitis": "Osteomyelitis",
    "osteoporosis": "Postmenopausal Osteoporosis",
    "ovarian cysts": "Ovarian Cysts",
    "parkinson's": "Parkinson's Disease",
    "peanut allergy": "Peanut Allergy",
    "peptic ulcers": "Peptic Ulcer Disease",
    "periventricular leukoplakia": "Periventricular Leukoplakia",
    "pleural effusion": "Pleural Effusion",
    "pneumonia": "Pneumonia",
    "pneumothorax": "Spontaneous Pneumothorax",
    "postpartum depression": "Postpartum Depression",
    "post-traumatic stress disorder": "Post-traumatic Stress Disorder",
    "premature ejaculation": "Premature Ejaculation",
    "pilonidal cyst": "Pilonidal Cyst",
    "pregnancy": "Pregnancy",
    "pressure sore": "Pressure Sore",
    "preeclampsia": "Preeclampsia",
    "psoriasis": "Chronic Psoriasis",
    "pulmonary embolism": "Pulmonary Embolism",
    "pulmonary contusions": "Pulmonary Contusions (Bruising of the Lung Tissue, Often Due to Trauma)",
    "pcod": "Polycystic Ovary Disease (PCOD)",
    "pcos": "Polycystic Ovary Syndrome (PCOS)",
    "pleurisy": "Pleurisy (Inflammation of the Pleura, the Membranes Surrounding the Lungs)",
    "rabies": "Rabies",
    "rabis": "Rabies",
    "root canal": "Root Canal",
    "rotator cuff injury": "Rotator Cuff Injury",
    "retained products of conception": "Retained Products of Conception (Tissue Left in the Uterus After Abortion or Miscarriage)",
    "renal stone": "Renal Stone",
    "renal calculus": "Kidney Stones",
    "scar rupture": "Uterine Scar Rupture",
    "seizure": "Seizure",
    "swollen forehead": "Sinusitis",
    "sun stroke": "Heat Stroke",
    "sternum injury": "Sternal Injury",
    "skin flora": "Skin Flora",
    "scabies mite": "Scabies Mite (Sarcoptes scabiei)",
    "scoliosis": "Idiopathic Scoliosis",
    "sciatica": "Sciatica",
    "sepsis": "Sepsis",
    "shingles": "Shingles",
    "sinusitis": "Chronic Sinusitis",
    "sjögren's syndrome": "Sjögren's Syndrome",
    "skin infections": "Bacterial Skin Infections",
    "sleep disorders": "Insomnia",
    "sickle cell anemia": "Sickle Cell Disease",
    "stroke": "Ischemic Stroke",
    "systolic hypertension": "Essential Hypertension",
    "sprain": "Sprain",
    "std": "Sexually Transmitted Disease",
    "stent": "Coronary Artery Stent (Used to Treat Blockages in Heart Arteries)",
    "surgery": "surgery related complications",
    "swelling": "Swelling",
    "TIA": "Transient Ischemic Attack (Mini-Stroke, Temporary Blockage of Blood Flow to the Brain)",
    "tooth": "Dental Issues",
    "tetanus": "Tetanus",
    "teeth": "Dental Issues",
    "tachycardia": "Tachycardia",
    "testicular cancer": "Testicular Cancer",
    "thrombophilia": "Thrombophilia",
    "thyroid": "Hypothyroidism",
    "tonsillitis": "Tonsillitis",
    "tuberculosis": "Pulmonary Tuberculosis",
    "tinnitus": "Tinnitus",
    "transient ischemic attack": "Transient Ischemic Attack (TIA)",
    "urticaria": "Urticaria (Hives or Skin Rash Caused by an Allergic Reaction, Leading to Itchy, Raised Skin)",
    "ulcer": "Ulcer",
    "urinary tract infection": "Urinary Tract Infection (UTI)",
    "vitiligo": "Vitiligo",
    "vitamin d deficiency": "Osteomalacia",
    "wart": "Viral Warts",
    "warts": "Viral Warts",
    "vertigo": "Vertigo",
    "vaginitis": "Vaginal Infections",
    "viral illness": "Viral Infection",
    "viral hepatitis": "Viral Hepatitis",
    "viral gastroenteritis": "Viral Gastroenteritis",
    "wheezing": "Wheezing",
    "yeast infection": "Yeast Infection",
    "zika virus": "Zika Virus",
    "heavy period": "menstrual irregularities",
    "heavy periods": "menstrual irregularities",
    "neuropathy": "neuropathy",
    "autonomic neuropathy2": "Autonomic Neuropathy",
    "pinched nerve": "pinched nerve",
    "low self-esteem": "Psychoactive Misalignment",
    "low self-confidence": "Psychoactive Misalignment",
    "supraspinatus tendinitis": "Rotator Cuff Injury",
    "biceps tendinitis": "Rotator Cuff Injury",
    "pressure headache": "Pressure Headache",
    "h pilori desease": "H. Pylori Infection",
    "h. pylori": "H. Pylori Infection",
    "hematoma": "Hematoma",
    "cervical radiculopathy": "Cervical Radiculopathy",
    "ms": "Multiple Sclerosis",
    "excessive fatigue": "Excessive Fatigue",
    "irregularities in heart rhythm": "Arrhythmia",
    "allergic component": "Allergic Reaction",
    "orthopedic": "Orthopedic Related Issues",
    "haematospermia": "Haematospermia",
    "fainted": "Fainting",
    "blood in my urine": "Bloody Urine",
    "iron deficiency": "Iron Deficiency",
    "blood pressure": "Blood Pressure",
    "fatty liver": "Fatty Liver",
    "low hemoglobin": "Low Hemoglobin",
    "indigestion": "Indigestion",
    "acidity": "Acidity",
    "liver": "Liver Related Issues",
}

## TODO (Ludwe): consider recording both a primary (best) and a secondary prognosis for each row.

In [10]:

def extract_conditions_with_count(text, conditions=None):
    """Count how often each known condition term occurs in ``text``.

    Matching is case-insensitive and whole-term: a term only counts when it
    is not directly preceded or followed by a word character.

    Args:
        text: Text to scan. Non-string values (e.g. NaN) yield an empty
            Counter.
        conditions: Iterable of search terms (a dict is iterated by key).
            Defaults to the module-level ``medical_conditions`` table.

    Returns:
        collections.Counter: ``{term: occurrences}`` for the terms that
        matched at least once. Keys are the terms exactly as given in
        ``conditions`` so they can be looked up in that table again.
    """
    if not isinstance(text, str):
        return Counter()

    if conditions is None:
        conditions = medical_conditions

    # Preprocess the text
    text = text.lower()

    condition_counts = Counter()

    for condition in conditions:
        if not condition:
            # An empty term would match at every position.
            continue
        # The text is lowercased, so lowercase the term as well; otherwise
        # keys such as "MRSA" or "TIA" can never match. Lookarounds replace
        # \b because \b after a trailing ")" requires a word character to
        # follow, which makes terms like "hypertension (high blood pressure)"
        # effectively unmatchable.
        pattern = r'(?<!\w)' + re.escape(condition.lower()) + r'(?!\w)'
        hits = len(re.findall(pattern, text))
        if hits:
            condition_counts[condition] += hits

    return condition_counts

In [11]:
def get_highest_prognosis(row):
    """Label a row with the prognosis of its most frequently mentioned condition.

    The row's 'input' and 'output' fields are scanned together. Returns
    "Unknown" when no known condition term is found.
    """
    patient_and_doctor_text = f"{row['input']} {row['output']}"
    counts = extract_conditions_with_count(patient_and_doctor_text)

    if counts:
        # most_common(1) yields a single (term, count) pair
        ((top_condition, _occurrences),) = counts.most_common(1)
        return medical_conditions.get(top_condition, "Unknown")

    return "Unknown"


In [None]:
# Derive a 'prognosis' label for every row from its input/output text.
print("Processing medical conditions...")
healthCareMagic_df['prognosis'] = healthCareMagic_df.apply(get_highest_prognosis, axis=1)

# Display summary statistics
print("\nPrognosis Distribution:")
prognosis_counts = healthCareMagic_df['prognosis'].value_counts()
print(prognosis_counts.head(20))

# Share of rows that received a concrete label. get_highest_prognosis marks
# unmatched rows as "Unknown", so that is the value to exclude here.
total_cases = len(healthCareMagic_df)
specified_cases = int((healthCareMagic_df['prognosis'] != "Unknown").sum())
percentage_specified = (specified_cases / total_cases) * 100 if total_cases else 0.0
print(f"\nTotal cases: {total_cases}")
print(f"Cases with specific prognosis: {specified_cases}")
print(f"Percentage specified: {percentage_specified:.2f}%")

# Optional: bar chart of the 20 most frequent labels
plt.figure(figsize=(15, 6))
prognosis_counts.head(20).plot(kind='bar')
plt.title('Top 20 Prognoses')
plt.xlabel('Prognosis')
plt.ylabel('Count')
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [None]:
# Preview the text columns next to the derived 'prognosis' label.
healthCareMagic_df[['input', 'output', 'prognosis']].head()

In [None]:
# Independent copy holding only the text columns and the derived label
relevant_columns = ['input', 'output', 'prognosis']
summary_df = healthCareMagic_df[relevant_columns].copy()

# Lift pandas' display limits so long consultation texts are shown in full
for option_name, option_value in (
    ('display.max_rows', None),
    ('display.max_columns', None),
    ('display.width', None),
    ('display.max_colwidth', None),
    ('display.expand_frame_repr', False),
):
    pd.set_option(option_name, option_value)

# Show the first 40 rows (last expression, so the notebook renders it)
print("\nSummary DataFrame:")
summary_df.head(40)

## TODO (Ludwe): try BioBERT on the input and output columns to validate the values in the prognosis column.

In [None]:
# Show the first 40 rows whose prognosis label is 'Unknown'.
healthCareMagic_df[healthCareMagic_df['prognosis'] == 'Unknown'].head(40)

In [None]:
# Write summary_df to CSV without the index column; the next cell and the
# training script read this file back.
print("Saving summary DataFrame to CSV...")
summary_df.to_csv('healthCareMagic_df_with_prognosis.csv', index=False)

In [17]:
# Reload the labeled dataset from 'healthCareMagic_df_with_prognosis.csv'.
# This rebinds healthCareMagic_df to the reloaded frame.
healthCareMagic_df = pd.read_csv('healthCareMagic_df_with_prognosis.csv')


In [None]:
# Preview the reloaded, labeled dataset.
healthCareMagic_df.head()

In [19]:

# Model architecture
class PrognosisClassifier(nn.Module):
    """Feed-forward classifier: two hidden ReLU layers with dropout, then a linear output.

    Layer widths are ``input_dim -> hidden_dim -> hidden_dim // 2 -> output_dim``.
    ``forward`` returns the output of the final linear layer (no softmax).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, dropout_rate=0.3):
        super().__init__()
        bottleneck_dim = hidden_dim // 2
        # Layer order is kept so state_dict keys stay 'network.<index>.*'.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_dim, bottleneck_dim),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(bottleneck_dim, output_dim),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.network(x)
        return logits

In [20]:
def create_and_save_embeddings_in_chunks(df, chunk_size=1000, save_path='medical_embeddings.npz'):
    """Return sentence embeddings for ``df``, computing and caching them if needed.

    If ``save_path`` already exists, the cached array is returned as-is.
    Otherwise each row is encoded as ``"<input> [SEP] <output>"`` with the
    'all-MiniLM-L6-v2' SentenceTransformer, ``chunk_size`` rows at a time,
    and the stacked result is written to ``save_path``.

    Args:
        df: DataFrame with 'input' and 'output' columns.
        chunk_size: Number of rows encoded per chunk; must be at least 1.
        save_path: Location of the compressed ``.npz`` cache.

    Returns:
        numpy.ndarray: One embedding row per row of ``df`` when freshly
        computed. A cached array is returned even if its row count differs
        from ``df``; a warning is logged in that case.

    Raises:
        ValueError: If ``chunk_size`` is below 1 or ``df`` is empty and no
            cache exists.
    """
    if os.path.exists(save_path):
        logger.info("Loading existing embeddings...")
        # NpzFile keeps the archive open until closed; read inside `with`.
        with np.load(save_path) as loaded:
            cached = loaded['embeddings']
        if len(cached) != len(df):
            # A stale cache silently misaligns embeddings and rows.
            logger.warning(
                "Cached embeddings in %s have %d rows but the DataFrame has %d; "
                "delete the cache to regenerate.",
                save_path, len(cached), len(df),
            )
        return cached

    if chunk_size < 1:
        raise ValueError("chunk_size must be at least 1")
    if len(df) == 0:
        raise ValueError("Cannot create embeddings for an empty DataFrame")

    logger.info("Creating new embeddings in chunks...")
    model = SentenceTransformer('all-MiniLM-L6-v2')
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Ceiling division: exact multiples of chunk_size must not add a chunk.
    total_chunks = (len(df) + chunk_size - 1) // chunk_size

    all_embeddings = []
    for chunk_number, start in enumerate(range(0, len(df), chunk_size), start=1):
        logger.info(f"Processing chunk {chunk_number}/{total_chunks}")
        chunk = df.iloc[start:start + chunk_size]
        # Zipping the two columns avoids building a Series per row (iterrows).
        combined_texts = [
            f"{question} [SEP] {answer}"
            for question, answer in zip(chunk['input'], chunk['output'])
        ]

        chunk_embeddings = model.encode(
            combined_texts,
            batch_size=64,
            show_progress_bar=True,
            device=device
        )
        all_embeddings.append(chunk_embeddings)

    embeddings = np.vstack(all_embeddings)
    np.savez_compressed(save_path, embeddings=embeddings)
    return embeddings

In [None]:
class TextEmbeddingDataset(Dataset):
    """Dataset pairing each embedding vector with its integer class label."""

    def __init__(self, embeddings, labels):
        # Converted once up front: float tensor of features, long tensor of labels
        self.embeddings = torch.FloatTensor(embeddings)
        self.labels = torch.LongTensor(labels)

    def __len__(self):
        sample_count = len(self.embeddings)
        return sample_count

    def __getitem__(self, idx):
        features = self.embeddings[idx]
        target = self.labels[idx]
        return features, target

In [22]:
def _stratify_or_none(y, split_name):
    """Return ``y`` for stratification, or None if some class has under 2 samples."""
    _, class_counts = np.unique(y, return_counts=True)
    if len(class_counts) and class_counts.min() < 2:
        logger.warning(
            "Skipping stratification for the %s split: at least one class "
            "has fewer than 2 samples.", split_name,
        )
        return None
    return y


def prepare_data(embeddings, df, test_size=0.2, val_size=0.2, batch_size=32):
    """Encode labels, split into train/val/test and wrap each split in a DataLoader.

    Args:
        embeddings: Feature matrix, row-aligned with ``df``.
        df: DataFrame with a 'prognosis' column used as the label.
        test_size: Fraction of all rows held out for the test split.
        val_size: Fraction of the *remaining* rows (after the test split)
            used for validation.
        batch_size: Batch size for all three loaders.

    Returns:
        tuple: ``(train_loader, val_loader, test_loader, label_encoder)``.
        Only the training loader shuffles.

    Raises:
        ValueError: If ``embeddings`` and ``df`` have different row counts.
    """
    if len(embeddings) != len(df):
        raise ValueError(
            f"embeddings has {len(embeddings)} rows but df has {len(df)} rows"
        )

    # Encode labels
    label_encoder = LabelEncoder()
    labels = label_encoder.fit_transform(df['prognosis'])

    # Stratified splitting needs at least 2 samples per class, so fall back
    # to a plain random split (with a warning) when a rare class breaks that.
    X_temp, X_test, y_temp, y_test = train_test_split(
        embeddings, labels, test_size=test_size, random_state=42,
        stratify=_stratify_or_none(labels, "train/test")
    )

    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=val_size, random_state=42,
        stratify=_stratify_or_none(y_temp, "train/validation")
    )

    # Create DataLoaders
    train_loader = DataLoader(
        TextEmbeddingDataset(X_train, y_train), batch_size=batch_size, shuffle=True
    )
    val_loader = DataLoader(TextEmbeddingDataset(X_val, y_val), batch_size=batch_size)
    test_loader = DataLoader(TextEmbeddingDataset(X_test, y_test), batch_size=batch_size)

    return train_loader, val_loader, test_loader, label_encoder

In [23]:
def train_model(model, train_loader, val_loader, num_epochs=10,
                checkpoint_path='best_model.pth'):
    """Train ``model`` and return it with the best-validation weights loaded.

    Uses cross-entropy loss, Adam with default settings and a
    ReduceLROnPlateau scheduler driven by the mean validation loss. Whenever
    the validation loss improves, the weights are written to
    ``checkpoint_path`` and remembered in memory. The loss curves are saved
    to 'training_curves.png'.

    Args:
        model: ``nn.Module`` to train; it is moved to CUDA when available.
        train_loader: DataLoader yielding ``(inputs, labels)`` batches.
        val_loader: DataLoader yielding ``(inputs, labels)`` batches.
        num_epochs: Number of passes over ``train_loader``.
        checkpoint_path: File that receives the best ``state_dict``.

    Returns:
        The trained model. If any epoch improved the validation loss, the
        weights from the best such epoch are restored before returning.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters())
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2)

    best_val_loss = float('inf')
    best_state = None
    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        # Training phase
        model.train()
        total_train_loss = 0

        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            total_train_loss += loss.item()

        avg_train_loss = total_train_loss / len(train_loader)
        train_losses.append(avg_train_loss)

        # Validation phase
        model.eval()
        val_loss = 0

        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                val_loss += criterion(outputs, labels).item()

        avg_val_loss = val_loss / len(val_loader)
        val_losses.append(avg_val_loss)

        # Learning rate scheduling
        scheduler.step(avg_val_loss)

        logger.info(f"Epoch {epoch+1}: Train Loss = {avg_train_loss:.4f}, Val Loss = {avg_val_loss:.4f}")

        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            torch.save(model.state_dict(), checkpoint_path)
            # state_dict() returns live references, so clone the tensors to
            # freeze this epoch's weights.
            best_state = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }

    # Without this the caller would get the last epoch's weights, which may
    # be worse than the checkpointed ones.
    if best_state is not None:
        model.load_state_dict(best_state)

    # Plot training curves
    plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(val_losses, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig('training_curves.png')
    plt.close()

    return model

In [24]:
def evaluate_model(model, test_loader, label_encoder, device):
    """Score ``model`` on ``test_loader`` and return a classification report.

    Predictions are the arg-max over the model outputs. Numeric labels are
    mapped back to class names with ``label_encoder`` before the report and
    the confusion matrix are built. The confusion-matrix heatmap is written
    to 'confusion_matrix.png'.

    Returns:
        The text report produced by ``classification_report``.
    """
    model.eval()
    predicted_ids, target_ids = [], []

    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            # torch.max over dim 1 returns (values, indices); keep the indices
            batch_preds = torch.max(model(batch_inputs), 1)[1]

            predicted_ids.extend(batch_preds.cpu().numpy())
            target_ids.extend(batch_labels.cpu().numpy())

    # Back from encoded integers to the original class names
    pred_classes = label_encoder.inverse_transform(predicted_ids)
    true_classes = label_encoder.inverse_transform(target_ids)

    report = classification_report(true_classes, pred_classes)

    # Render and save the confusion matrix
    matrix = confusion_matrix(true_classes, pred_classes)
    plt.figure(figsize=(15, 10))
    sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues')
    plt.title('Confusion Matrix')
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.xticks(rotation=45)
    plt.yticks(rotation=45)
    plt.tight_layout()
    plt.savefig('confusion_matrix.png')
    plt.close()

    return report

In [None]:
# End-to-end training script: load the labeled CSV, build or load embeddings,
# train the classifier, evaluate it and save a checkpoint bundle.
# NOTE: the names bound here (healthCareMagic_df, embeddings, model,
# label_encoder) are module-level on purpose; generate_medical_response reads
# the globals `embeddings` and `healthCareMagic_df`.
if __name__ == "__main__":
    # Load dataset
    logger.info("Loading dataset...")
    healthCareMagic_df = pd.read_csv('healthCareMagic_df_with_prognosis.csv')
    
    # Display prognosis distribution
    logger.info("\nPrognosis Distribution:")
    print(healthCareMagic_df['prognosis'].value_counts().head(20))
    
    # Create or load embeddings
    embeddings = create_and_save_embeddings_in_chunks(healthCareMagic_df)
    logger.info(f"Embeddings shape: {embeddings.shape}")
    
    # Prepare data for training
    train_loader, val_loader, test_loader, label_encoder = prepare_data(embeddings, healthCareMagic_df)
    
    # Initialize and train model
    # Input width follows the embedding size; output width is the number of
    # distinct prognosis classes found by the label encoder.
    input_dim = embeddings.shape[1]
    hidden_dim = 256
    output_dim = len(label_encoder.classes_)
    
    model = PrognosisClassifier(input_dim, hidden_dim, output_dim)
    model = train_model(model, train_loader, val_loader)
    
    # Evaluate model
    # Same device rule as train_model uses, so inputs are moved to the device
    # the model already lives on.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    evaluation_report = evaluate_model(model, test_loader, label_encoder, device)
    logger.info("\nModel Evaluation Report:")
    print(evaluation_report)
    
    # Save model and components
    # The bundle stores the weights, the three layer sizes passed to
    # PrognosisClassifier, and the class names.
    torch.save({
        'model_state_dict': model.state_dict(),
        'input_dim': input_dim,
        'hidden_dim': hidden_dim,
        'output_dim': output_dim,
        'label_encoder_classes': label_encoder.classes_
    }, 'medical_classifier_model.pth')
    
    logger.info("Training completed and model saved!")

## TODO (Ludwe): add hyperparameter tuning.

## TODO (Ludwe): make sure to engineer informative ("smart") features.

## TODO: look into tuning the model's hyperparameters with Keras Tuner and a PyTorch tuning library.