<a href="https://colab.research.google.com/github/SalmanShahzadAli/HealthMate---Software-Engineering-Semester-Project/blob/main/medical_chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from pydantic import BaseModel, Field, validator
from typing import List, Optional, Dict, Any
from datetime import datetime, date
from enum import Enum
import json


class SeverityLevel(str, Enum):
    LOW = "low"
    MODERATE = "moderate"
    HIGH = "high"
    CRITICAL = "critical"


class Gender(str, Enum):
    MALE = "male"
    FEMALE = "female"
    OTHER = "other"


class BloodType(str, Enum):
    A_POSITIVE = "A+"
    A_NEGATIVE = "A-"
    B_POSITIVE = "B+"
    B_NEGATIVE = "B-"
    AB_POSITIVE = "AB+"
    AB_NEGATIVE = "AB-"
    O_POSITIVE = "O+"
    O_NEGATIVE = "O-"


class MedicationFrequency(str, Enum):
    ONCE_DAILY = "once_daily"
    TWICE_DAILY = "twice_daily"
    THREE_TIMES_DAILY = "three_times_daily"
    FOUR_TIMES_DAILY = "four_times_daily"
    AS_NEEDED = "as_needed"
    WEEKLY = "weekly"


class Patient(BaseModel):
    patient_id: str = Field(..., description="Unique patient identifier")
    name: str = Field(..., min_length=1, max_length=100)
    age: int = Field(..., ge=0, le=150)
    gender: Gender
    blood_type: Optional[BloodType] = None
    weight_kg: Optional[float] = Field(None, ge=0, le=500)
    height_cm: Optional[float] = Field(None, ge=0, le=300)
    phone: Optional[str] = Field(None, pattern=r"^\+?[\d\s\-\(\)]+$")
    emergency_contact: Optional[str] = None
    allergies: List[str] = Field(default_factory=list)
    chronic_conditions: List[str] = Field(default_factory=list)

    @validator('age')
    def validate_age(cls, v):
        if v < 0 or v > 150:
            raise ValueError('Age must be between 0 and 150')
        return v


class Symptom(BaseModel):
    symptom_id: str = Field(..., description="Unique symptom identifier")
    name: str = Field(..., min_length=1, max_length=100)
    description: Optional[str] = Field(None, max_length=500)
    severity: SeverityLevel
    onset_date: datetime
    duration_hours: Optional[int] = Field(None, ge=0)
    location: Optional[str] = Field(None, max_length=100)
    triggers: List[str] = Field(default_factory=list)
    associated_symptoms: List[str] = Field(default_factory=list)


class Medication(BaseModel):
    medication_id: str = Field(..., description="Unique medication identifier")
    name: str = Field(..., min_length=1, max_length=100)
    dosage: str = Field(..., min_length=1, max_length=50)
    frequency: MedicationFrequency
    start_date: date
    end_date: Optional[date] = None
    prescribing_doctor: Optional[str] = Field(None, max_length=100)
    instructions: Optional[str] = Field(None, max_length=500)
    side_effects: List[str] = Field(default_factory=list)

    @validator('end_date')
    def validate_end_date(cls, v, values):
        if v and 'start_date' in values and v < values['start_date']:
            raise ValueError('End date cannot be before start date')
        return v


class VitalSigns(BaseModel):
    reading_id: str = Field(..., description="Unique reading identifier")
    timestamp: datetime = Field(default_factory=datetime.now)
    systolic_bp: Optional[int] = Field(None, ge=50, le=300)
    diastolic_bp: Optional[int] = Field(None, ge=30, le=200)
    heart_rate: Optional[int] = Field(None, ge=30, le=250)
    temperature_celsius: Optional[float] = Field(None, ge=30.0, le=50.0)
    respiratory_rate: Optional[int] = Field(None, ge=8, le=60)
    oxygen_saturation: Optional[int] = Field(None, ge=70, le=100)
    blood_glucose: Optional[float] = Field(None, ge=20.0, le=500.0)
    notes: Optional[str] = Field(None, max_length=200)


class HealthAssessment(BaseModel):
    assessment_id: str = Field(..., description="Unique assessment identifier")
    patient_id: str
    timestamp: datetime = Field(default_factory=datetime.now)
    symptoms: List[Symptom]
    vital_signs: Optional[VitalSigns] = None
    risk_level: SeverityLevel
    recommendations: List[str] = Field(default_factory=list)
    requires_immediate_attention: bool = False
    follow_up_needed: bool = False
    follow_up_date: Optional[date] = None


class MedicalHistory(BaseModel):
    history_id: str = Field(..., description="Unique history identifier")
    patient_id: str
    condition: str = Field(..., min_length=1, max_length=100)
    diagnosis_date: date
    treatment: Optional[str] = Field(None, max_length=500)
    outcome: Optional[str] = Field(None, max_length=200)
    doctor: Optional[str] = Field(None, max_length=100)


class MedicationReminder(BaseModel):
    reminder_id: str = Field(..., description="Unique reminder identifier")
    patient_id: str
    medication_id: str
    scheduled_time: datetime
    taken: bool = False
    taken_time: Optional[datetime] = None
    notes: Optional[str] = Field(None, max_length=200)


class MedicalAssistantAgent:
    def __init__(self):
        self.patients: Dict[str, Patient] = {}
        self.assessments: Dict[str, HealthAssessment] = {}
        self.medications: Dict[str, Medication] = {}
        self.medical_history: Dict[str, List[MedicalHistory]] = {}
        self.reminders: Dict[str, List[MedicationReminder]] = {}

    def register_patient(self, patient_data: Dict[str, Any]) -> Patient:
        """Register a new patient in the system"""
        patient = Patient(**patient_data)
        self.patients[patient.patient_id] = patient
        self.medical_history[patient.patient_id] = []
        self.reminders[patient.patient_id] = []
        return patient

    def add_symptom_assessment(self, patient_id: str, symptoms_data: List[Dict[str, Any]],
                             vital_signs_data: Optional[Dict[str, Any]] = None) -> HealthAssessment:
        """Create a health assessment based on symptoms"""
        if patient_id not in self.patients:
            raise ValueError(f"Patient {patient_id} not found")

        symptoms = [Symptom(**symptom_data) for symptom_data in symptoms_data]
        vital_signs = VitalSigns(**vital_signs_data) if vital_signs_data else None

        # Determine risk level based on symptoms
        risk_level = self._calculate_risk_level(symptoms, vital_signs)

        # Generate recommendations
        recommendations = self._generate_recommendations(symptoms, vital_signs, risk_level)

        assessment = HealthAssessment(
            assessment_id=f"assess_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            patient_id=patient_id,
            symptoms=symptoms,
            vital_signs=vital_signs,
            risk_level=risk_level,
            recommendations=recommendations,
            requires_immediate_attention=risk_level in [SeverityLevel.HIGH, SeverityLevel.CRITICAL],
            follow_up_needed=risk_level != SeverityLevel.LOW
        )

        self.assessments[assessment.assessment_id] = assessment
        return assessment

    def add_medication(self, patient_id: str, medication_data: Dict[str, Any]) -> Medication:
        """Add a medication to patient's regimen"""
        if patient_id not in self.patients:
            raise ValueError(f"Patient {patient_id} not found")

        medication = Medication(**medication_data)
        self.medications[medication.medication_id] = medication

        # Create medication reminders
        self._create_medication_reminders(patient_id, medication)

        return medication

    def record_vital_signs(self, patient_id: str, vital_signs_data: Dict[str, Any]) -> VitalSigns:
        """Record vital signs for a patient"""
        if patient_id not in self.patients:
            raise ValueError(f"Patient {patient_id} not found")

        vital_signs = VitalSigns(**vital_signs_data)

        # Check for abnormal values and create alert if needed
        self._check_vital_signs_alerts(patient_id, vital_signs)

        return vital_signs

    def get_patient_summary(self, patient_id: str) -> Dict[str, Any]:
        """Get comprehensive patient summary"""
        if patient_id not in self.patients:
            raise ValueError(f"Patient {patient_id} not found")

        patient = self.patients[patient_id]
        recent_assessments = [
            assessment for assessment in self.assessments.values()
            if assessment.patient_id == patient_id
        ]

        active_medications = [
            med for med in self.medications.values()
            if med.end_date is None or med.end_date >= date.today()
        ]

        return {
            "patient_info": patient.dict(),
            "recent_assessments": len(recent_assessments),
            "active_medications": len(active_medications),
            "medical_history": len(self.medical_history.get(patient_id, [])),
            "upcoming_reminders": len([
                r for r in self.reminders.get(patient_id, [])
                if not r.taken and r.scheduled_time > datetime.now()
            ])
        }

    def _calculate_risk_level(self, symptoms: List[Symptom], vital_signs: Optional[VitalSigns]) -> SeverityLevel:
        """Calculate overall risk level based on symptoms and vital signs"""
        max_severity = SeverityLevel.LOW

        # Check symptom severity
        for symptom in symptoms:
            if symptom.severity == SeverityLevel.CRITICAL:
                return SeverityLevel.CRITICAL
            elif symptom.severity == SeverityLevel.HIGH:
                max_severity = SeverityLevel.HIGH
            elif symptom.severity == SeverityLevel.MODERATE and max_severity == SeverityLevel.LOW:
                max_severity = SeverityLevel.MODERATE

        # Check vital signs for abnormalities
        if vital_signs:
            if (vital_signs.systolic_bp and vital_signs.systolic_bp > 180) or \
               (vital_signs.heart_rate and vital_signs.heart_rate > 120) or \
               (vital_signs.temperature_celsius and vital_signs.temperature_celsius > 39.0):
                max_severity = SeverityLevel.HIGH

        return max_severity

    def _generate_recommendations(self, symptoms: List[Symptom], vital_signs: Optional[VitalSigns],
                                risk_level: SeverityLevel) -> List[str]:
        """Generate recommendations based on assessment"""
        recommendations = []

        if risk_level == SeverityLevel.CRITICAL:
            recommendations.append("Seek immediate emergency medical attention")
            recommendations.append("Call emergency services if symptoms worsen")
        elif risk_level == SeverityLevel.HIGH:
            recommendations.append("Contact your healthcare provider immediately")
            recommendations.append("Monitor symptoms closely")
        elif risk_level == SeverityLevel.MODERATE:
            recommendations.append("Schedule appointment with healthcare provider")
            recommendations.append("Continue monitoring symptoms")
        else:
            recommendations.append("Monitor symptoms and rest")
            recommendations.append("Stay hydrated")

        # Add specific recommendations based on symptoms
        symptom_names = [s.name.lower() for s in symptoms]
        if "fever" in symptom_names:
            recommendations.append("Take temperature regularly and use fever reducers as needed")
        if "pain" in symptom_names:
            recommendations.append("Apply appropriate pain management techniques")

        return recommendations

    def _create_medication_reminders(self, patient_id: str, medication: Medication):
        """Create medication reminders based on frequency"""
        # This is a simplified implementation - in practice, you'd create recurring reminders
        base_time = datetime.now().replace(hour=8, minute=0, second=0, microsecond=0)

        if medication.frequency == MedicationFrequency.ONCE_DAILY:
            times = [base_time]
        elif medication.frequency == MedicationFrequency.TWICE_DAILY:
            times = [base_time, base_time.replace(hour=20)]
        elif medication.frequency == MedicationFrequency.THREE_TIMES_DAILY:
            times = [base_time, base_time.replace(hour=14), base_time.replace(hour=20)]
        else:
            times = [base_time]  # Default

        for time in times:
            reminder = MedicationReminder(
                reminder_id=f"rem_{medication.medication_id}_{time.strftime('%H%M')}",
                patient_id=patient_id,
                medication_id=medication.medication_id,
                scheduled_time=time
            )
            self.reminders[patient_id].append(reminder)

    def _check_vital_signs_alerts(self, patient_id: str, vital_signs: VitalSigns):
        """Check for abnormal vital signs and create alerts"""
        alerts = []

        if vital_signs.systolic_bp and vital_signs.systolic_bp > 140:
            alerts.append("High blood pressure detected")
        if vital_signs.heart_rate and vital_signs.heart_rate > 100:
            alerts.append("Elevated heart rate detected")
        if vital_signs.temperature_celsius and vital_signs.temperature_celsius > 38.0:
            alerts.append("Fever detected")

        if alerts:
            print(f"ALERTS for patient {patient_id}: {', '.join(alerts)}")


# Example usage
if __name__ == "__main__":
    # Initialize the medical assistant
    agent = MedicalAssistantAgent()

    # Register a patient
    patient_data = {
        "patient_id": "P001",
        "name": "John Doe",
        "age": 35,
        "gender": Gender.MALE,
        "blood_type": BloodType.A_POSITIVE,
        "weight_kg": 75.0,
        "height_cm": 180.0,
        "allergies": ["penicillin", "nuts"],
        "chronic_conditions": ["hypertension"]
    }

    patient = agent.register_patient(patient_data)
    print(f"Registered patient: {patient.name}")

    # Add symptoms
    symptoms_data = [
        {
            "symptom_id": "S001",
            "name": "headache",
            "description": "Throbbing pain in temples",
            "severity": SeverityLevel.MODERATE,
            "onset_date": datetime.now(),
            "duration_hours": 6,
            "location": "temples"
        },
        {
            "symptom_id": "S002",
            "name": "fever",
            "description": "Body temperature feels high",
            "severity": SeverityLevel.HIGH,
            "onset_date": datetime.now(),
            "duration_hours": 4
        }
    ]

    # Add vital signs
    vital_signs_data = {
        "reading_id": "VS001",
        "systolic_bp": 145,
        "diastolic_bp": 90,
        "heart_rate": 85,
        "temperature_celsius": 38.5,
        "respiratory_rate": 18,
        "oxygen_saturation": 98
    }

    # Create assessment
    assessment = agent.add_symptom_assessment("P001", symptoms_data, vital_signs_data)
    print(f"Assessment created with risk level: {assessment.risk_level}")
    print(f"Recommendations: {assessment.recommendations}")

    # Add medication
    medication_data = {
        "medication_id": "M001",
        "name": "Ibuprofen",
        "dosage": "400mg",
        "frequency": MedicationFrequency.THREE_TIMES_DAILY,
        "start_date": date.today(),
        "prescribing_doctor": "Dr. Smith",
        "instructions": "Take with food"
    }

    medication = agent.add_medication("P001", medication_data)
    print(f"Added medication: {medication.name}")

    # Get patient summary
    summary = agent.get_patient_summary("P001")
    print(f"Patient summary: {json.dumps(summary, indent=2, default=str)}")


        # Manually input patient details
    patient_data = {
        "patient_id": input("Enter patient ID: "),
        "name": input("Enter patient name: "),
        "age": int(input("Enter patient age: ")),
        "gender": input("Enter gender (male/female/other): "),
        "blood_type": input("Enter blood type (e.g., A+, O-): "),
        "weight_kg": float(input("Enter weight in kg: ")),
        "height_cm": float(input("Enter height in cm: ")),
        "allergies": input("Enter allergies (comma-separated): ").split(","),
        "chronic_conditions": input("Enter chronic conditions (comma-separated): ").split(",")
    }

    # Convert Enums properly
    patient_data["gender"] = Gender(patient_data["gender"].strip().lower())
    patient_data["blood_type"] = BloodType(patient_data["blood_type"].strip().upper())

    patient = agent.register_patient(patient_data)
    print(f"âœ… Registered patient: {patient.name}")

    # Input symptoms
    symptoms_data = []
    num_symptoms = int(input("How many symptoms do you want to enter? "))
    for i in range(num_symptoms):
        print(f"\nEnter data for Symptom #{i+1}")
        symptoms_data.append({
            "symptom_id": f"S{i+1:03}",
            "name": input("Symptom name: "),
            "description": input("Description: "),
            "severity": SeverityLevel(input("Severity (low/moderate/high/critical): ").lower()),
            "onset_date": datetime.now(),
            "duration_hours": int(input("Duration in hours: ")),
            "location": input("Location: ")
        })

    # Ask if you want to enter vitals
    add_vitals = input("Do you want to enter vital signs? (yes/no): ").lower() == "yes"
    vital_signs_data = None
    if add_vitals:
        vital_signs_data = {
            "reading_id": "VS001",
            "systolic_bp": int(input("Systolic BP: ")),
            "diastolic_bp": int(input("Diastolic BP: ")),
            "heart_rate": int(input("Heart rate: ")),
            "temperature_celsius": float(input("Temperature (Â°C): ")),
            "respiratory_rate": int(input("Respiratory rate: ")),
            "oxygen_saturation": int(input("Oxygen saturation: "))
        }

    # Create assessment
    assessment = agent.add_symptom_assessment(patient.patient_id, symptoms_data, vital_signs_data)
    print(f"ðŸ©º Assessment risk level: {assessment.risk_level}")
    print("Recommendations:")
    for rec in assessment.recommendations:
        print(f" - {rec}")

/tmp/ipython-input-1191614608.py:54: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  @validator('age')
/tmp/ipython-input-1191614608.py:84: PydanticDeprecatedSince20: Pydantic V1 style `@validator` validators are deprecated. You should migrate to Pydantic V2 style `@field_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.11/migration/
  @validator('end_date')
/tmp/ipython-input-1191614608.py:224: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.

Registered patient: John Doe
Assessment created with risk level: SeverityLevel.HIGH
Recommendations: ['Contact your healthcare provider immediately', 'Monitor symptoms closely', 'Take temperature regularly and use fever reducers as needed']
Added medication: Ibuprofen
Patient summary: {
  "patient_info": {
    "patient_id": "P001",
    "name": "John Doe",
    "age": 35,
    "gender": "male",
    "blood_type": "A+",
    "weight_kg": 75.0,
    "height_cm": 180.0,
    "phone": null,
    "emergency_contact": null,
    "allergies": [
      "penicillin",
      "nuts"
    ],
    "chronic_conditions": [
      "hypertension"
    ]
  },
  "recent_assessments": 1,
  "active_medications": 1,
  "medical_history": 0,
  "upcoming_reminders": 0
}
Enter patient ID: 1123
Enter patient name: Murtaza
Enter patient age: 21
Enter gender (male/female/other): male
Enter blood type (e.g., A+, O-): B+
Enter weight in kg: 72
Enter height in cm: 175
Enter allergies (comma-separated): flu, chicken pox, cough
Enter