<a href="https://colab.research.google.com/github/Swayam21345/EPIChat/blob/main/Epichat.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install PyTorch and Hugging Face transformers.
# NOTE(review): the DiseasePredictor cell below imports neither torch nor
# transformers; it needs pandas, scikit-learn, joblib, matplotlib, seaborn
# and wordcloud instead.
!pip install torch torchvision torchaudio
!pip install transformers

In [None]:
# Print the installed transformers package version and metadata.
!pip show transformers

In [None]:
import pandas as pd
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import accuracy_score
import joblib
import pickle
import warnings
from collections import defaultdict
import matplotlib.pyplot as plt
import seaborn as sns
from wordcloud import WordCloud

# Silence every warning for the whole process (this also hides library
# deprecation notices and sklearn parameter warnings).
# NOTE(review): consider narrowing to specific categories/modules so real
# problems stay visible.
warnings.filterwarnings('ignore')

class DiseasePredictor:
    """Naive Bayes disease predictor over comma-separated symptom lists.

    Reads a CSV with ``Disease`` and ``Symptoms`` columns, trains a
    ``MultinomialNB`` on bag-of-symptom counts, persists it as
    ``disease_model.joblib`` plus ``symptom_vocabulary.pkl`` in the current
    working directory, and offers prediction, related-symptom lookup and
    a few plots.
    """

    def __init__(self, csv_path):
        """Create an untrained predictor.

        Args:
            csv_path: Anything ``pandas.read_csv`` accepts (a path or a
                file-like object) pointing at the disease/symptom CSV.
        """
        self.csv_path = csv_path
        self.model = None              # fitted MultinomialNB
        self.vectorizer = None         # CountVectorizer over symptom tokens
        self.vocabulary_ = None        # symptom token -> feature column
        self.symptom_frequency = None  # symptom -> occurrences across rows
        self.disease_symptoms = None   # disease -> its list of symptoms

    @staticmethod
    def _split_symptoms(text):
        """Split a comma-separated symptom string into normalized tokens.

        Lower-cases the text, strips whitespace around every item and drops
        empty items, e.g. ``" Fever ,  Cough,"`` -> ``["fever", "cough"]``.
        The same function tokenizes the training data, the vectorizer input
        and user input, so all of them agree on what a symptom looks like.
        """
        return [item.strip() for item in str(text).lower().split(',')
                if item.strip()]

    def load_and_preprocess_data(self):
        """Load the CSV and build the symptom lookup tables.

        Drops rows missing a disease or symptom list, rewrites ``Symptoms``
        into the canonical ``"a,b,c"`` form (lower-case, no stray spaces or
        empty items), drops rows left with no symptoms, and fills
        ``symptom_frequency`` and ``disease_symptoms``.

        Returns:
            The normalized DataFrame, or ``None`` if loading failed (the
            error is printed).
        """
        try:
            df = pd.read_csv(self.csv_path)

            if 'Disease' not in df.columns or 'Symptoms' not in df.columns:
                raise ValueError("CSV must contain 'Disease' and 'Symptoms' columns")

            # Rows without a disease or a symptom list cannot be trained on.
            df = df.dropna(subset=['Disease', 'Symptoms']).copy()

            # Tolerates any spacing around commas (the previous version only
            # collapsed a single ", ") and trailing/empty items.
            df['Symptoms'] = df['Symptoms'].map(
                lambda text: ','.join(self._split_symptoms(text)))
            df = df[df['Symptoms'] != '']

            self.symptom_frequency = defaultdict(int)
            self.disease_symptoms = {}

            for disease, text in zip(df['Disease'], df['Symptoms']):
                symptoms = text.split(',')
                for symptom in symptoms:
                    self.symptom_frequency[symptom] += 1
                # A disease that appears on several rows keeps its last row.
                self.disease_symptoms[disease] = symptoms

            return df

        except Exception as e:
            print(f"Error loading data: {e}")
            return None

    def train_model(self, df):
        """Fit the classifier on ``df``, report held-out accuracy, and save.

        A 20% split (``random_state=42``) is used only to estimate accuracy.
        The model that is kept and saved is then refit on *all* rows: with a
        single row per disease, every held-out disease is absent from the
        training split, so a model fit on the split alone could never
        predict those diseases.

        Args:
            df: Normalized DataFrame returned by ``load_and_preprocess_data``.
        """
        self.vectorizer = CountVectorizer(tokenizer=self._split_symptoms)
        X = self.vectorizer.fit_transform(df['Symptoms'])
        y = df['Disease']

        self.vocabulary_ = self.vectorizer.vocabulary_

        # A split needs at least one training and one test row.
        if len(df) > 1:
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42)

            eval_model = MultinomialNB()
            eval_model.fit(X_train, y_train)
            accuracy = accuracy_score(y_test, eval_model.predict(X_test))
            print(f"\nModel trained with accuracy: {accuracy:.2%}\n")

            unseen = set(y_test) - set(y_train)
            if unseen:
                print(f"Note: {len(unseen)} held-out disease(s) never appear in "
                      "the training split; the final model is refit on all rows.")

        # Final model: trained on every row so every disease can be predicted.
        self.model = MultinomialNB()
        self.model.fit(X, y)

        self.save_model()

    def save_model(self):
        """Persist the model and lookup tables to the working directory.

        Writes ``disease_model.joblib`` (the classifier) and
        ``symptom_vocabulary.pkl`` (vocabulary, symptom frequency and the
        disease -> symptoms map). Does nothing if no model is trained.
        """
        if self.model is None or not self.vocabulary_:
            return
        joblib.dump(self.model, 'disease_model.joblib')
        with open('symptom_vocabulary.pkl', 'wb') as f:
            pickle.dump({
                'vocabulary': self.vocabulary_,
                'symptom_frequency': self.symptom_frequency,
                'disease_symptoms': self.disease_symptoms
            }, f)
        print("Model and data saved to disk")

    def load_saved_model(self):
        """Load the files written by ``save_model`` and rebuild the vectorizer.

        Returns:
            True on success; False if any file is missing or unreadable (the
            error is printed).
        """
        try:
            self.model = joblib.load('disease_model.joblib')
            # SECURITY: pickle/joblib execute code from the file; only load
            # files this class wrote, never ones from an untrusted source.
            with open('symptom_vocabulary.pkl', 'rb') as f:
                data = pickle.load(f)
            self.vocabulary_ = data['vocabulary']
            self.symptom_frequency = data['symptom_frequency']
            self.disease_symptoms = data['disease_symptoms']

            # A fixed vocabulary makes the vectorizer usable without refitting.
            self.vectorizer = CountVectorizer(
                tokenizer=self._split_symptoms,
                vocabulary=self.vocabulary_
            )
            print("\nPre-trained model loaded successfully!")
            return True
        except Exception as e:
            print(f"\nError loading model: {e}")
            return False

    def predict_disease(self, symptoms):
        """Return the three most probable diseases for ``symptoms``.

        Args:
            symptoms: Comma-separated symptom string; case and spacing are
                normalized.

        Returns:
            A list of up to three dicts with keys ``'disease'``,
            ``'probability'`` (formatted percentage string) and
            ``'common_symptoms'``, most probable first; or ``None`` if the
            model is not ready, no entered symptom is known, or prediction
            fails (the reason is printed).
        """
        try:
            if self.vectorizer is None or self.model is None:
                raise ValueError("Model not loaded properly")

            symptoms_processed = ','.join(self._split_symptoms(symptoms))
            symptoms_vec = self.vectorizer.transform([symptoms_processed])

            # With no known symptom the feature row is all zeros and Naive
            # Bayes would just return the class priors as a "prediction".
            if symptoms_vec.nnz == 0:
                raise ValueError("None of the entered symptoms are known to the model")

            probabilities = self.model.predict_proba(symptoms_vec)[0]
            # argsort is ascending: take the last three, highest first.
            top3_indices = probabilities.argsort()[-3:][::-1]
            known = self.disease_symptoms or {}

            return [
                {
                    'disease': disease,
                    'probability': f"{prob:.1%}",
                    'common_symptoms': known.get(disease, []),
                }
                for disease, prob in zip(self.model.classes_[top3_indices],
                                         probabilities[top3_indices])
            ]

        except Exception as e:
            print(f"\nPrediction error: {e}")
            return None

    def find_related_symptoms(self, input_symptoms, top_n=5):
        """Rank symptoms that co-occur with ``input_symptoms``.

        A symptom scores one point for each disease whose list contains it
        and also contains at least one of the input symptoms. Input symptoms
        themselves are excluded.

        Args:
            input_symptoms: Comma-separated symptom string.
            top_n: Maximum number of results (default 5).

        Returns:
            Up to ``top_n`` ``(symptom, count)`` pairs, highest count first;
            an empty list if no data is loaded or nothing matches.
        """
        wanted = set(self._split_symptoms(input_symptoms))
        related = defaultdict(int)

        for symptoms in (self.disease_symptoms or {}).values():
            if wanted.intersection(symptoms):
                for symptom in symptoms:
                    if symptom not in wanted:
                        related[symptom] += 1

        return sorted(related.items(), key=lambda x: x[1], reverse=True)[:top_n]

    def plot_symptom_frequency(self, top_n=20):
        """Plot a bar chart of the ``top_n`` most frequent symptoms."""
        top = sorted(self.symptom_frequency.items(),
                     key=lambda kv: kv[1], reverse=True)[:top_n]
        plt.figure(figsize=(10, 6))
        sns.barplot(x=[count for _, count in top],
                    y=[name for name, _ in top])
        plt.title(f'Top {top_n} Most Common Symptoms')
        plt.xlabel('Frequency')
        plt.ylabel('Symptom')
        plt.tight_layout()
        plt.show()

    def plot_disease_wordcloud(self):
        """Show a word cloud built from all disease names."""
        text = ' '.join(self.disease_symptoms.keys())
        if not text.strip():
            # WordCloud.generate raises on empty text.
            print("No diseases to plot.")
            return
        wordcloud = WordCloud(width=800, height=400,
                              background_color='white').generate(text)
        plt.figure(figsize=(10, 5))
        plt.imshow(wordcloud, interpolation='bilinear')
        plt.axis('off')
        plt.title('Disease Word Cloud')
        plt.show()

    def plot_symptom_heatmap(self, top_diseases=10, top_symptoms=15):
        """Plot a 0/1 heatmap of which frequent symptoms each disease lists.

        Args:
            top_diseases: Number of diseases (rows), chosen by how many
                symptoms they list.
            top_symptoms: Number of symptoms (columns), chosen by overall
                frequency.
        """
        diseases = sorted(self.disease_symptoms,
                          key=lambda d: len(self.disease_symptoms[d]),
                          reverse=True)[:top_diseases]
        symptoms = [name for name, _ in sorted(self.symptom_frequency.items(),
                                               key=lambda kv: kv[1],
                                               reverse=True)[:top_symptoms]]

        matrix = [[1 if symptom in self.disease_symptoms[disease] else 0
                   for symptom in symptoms]
                  for disease in diseases]

        plt.figure(figsize=(12, 8))
        sns.heatmap(matrix,
                    xticklabels=symptoms,
                    yticklabels=diseases,
                    cmap='Blues', cbar=False)
        plt.title('Disease-Symptom Relationships')
        plt.xlabel('Symptoms')
        plt.ylabel('Diseases')
        plt.xticks(rotation=45, ha='right')
        plt.tight_layout()
        plt.show()

def main(csv_path='disease_symptoms.csv'):
    """Interactive console front-end for ``DiseasePredictor``.

    Asks whether to train a new model or load the saved one (falling back to
    training if loading fails), then loops over a menu of prediction,
    related-symptom lookup and plotting actions until the user exits.

    Args:
        csv_path: Training CSV; defaults to ``disease_symptoms.csv`` in the
            working directory.
    """
    print("\n=== Advanced Disease Prediction System ===")
    print("1. Train new model")
    print("2. Load existing model")
    print("3. Exit")

    choice = input("\nEnter your choice (1-3): ").strip()

    if choice == '3':
        return
    if choice not in ('1', '2'):
        print("Invalid choice")
        return

    predictor = DiseasePredictor(csv_path)

    loaded = False
    if choice == '2':
        loaded = predictor.load_saved_model()
        if not loaded:
            print("\nNo saved model found. Training new model...")

    if not loaded:
        df = predictor.load_and_preprocess_data()
        if df is not None:
            predictor.train_model(df)

    # Without a usable model every menu action below would fail; stop here
    # instead of showing a menu that cannot work.
    if predictor.model is None or predictor.vectorizer is None:
        print("\nNo model available. Please check the training data and try again.")
        return

    while True:
        print("\n=== Disease Prediction Menu ===")
        print("1. Predict disease from symptoms")
        print("2. Find related symptoms")
        print("3. View symptom frequency chart")
        print("4. View disease word cloud")
        print("5. View disease-symptom heatmap")
        print("6. Exit")

        sub_choice = input("\nEnter your choice (1-6): ").strip()

        if sub_choice == '1':
            symptoms = input("\nEnter your symptoms (comma separated): ").strip()
            # Typing "exit" at the symptom prompt leaves the whole menu.
            if symptoms.lower() == 'exit':
                break

            results = predictor.predict_disease(symptoms)
            if results:
                print("\nTop 3 possible diseases:")
                for i, result in enumerate(results, 1):
                    print(f"\n{i}. {result['disease']} ({result['probability']} confidence)")
                    print("   Common symptoms:", ', '.join(result['common_symptoms']))
            else:
                print("Could not make prediction. Please check your symptoms.")

        elif sub_choice == '2':
            symptoms = input("\nEnter symptoms to find related ones (comma separated): ").strip()
            related = predictor.find_related_symptoms(symptoms)
            if related:
                print("\nSymptoms that often occur with your input:")
                for symptom, count in related:
                    print(f"- {symptom} (appears with {count} diseases)")
            else:
                print("No related symptoms found.")

        elif sub_choice == '3':
            predictor.plot_symptom_frequency()

        elif sub_choice == '4':
            predictor.plot_disease_wordcloud()

        elif sub_choice == '5':
            predictor.plot_symptom_heatmap()

        elif sub_choice == '6':
            break

        else:
            print("Invalid choice")

# Start the interactive menu when executed as a script. In a notebook cell
# ``__name__`` is also "__main__", so running this cell starts the prompt.
if __name__ == "__main__":
    main()


=== Advanced Disease Prediction System ===
1. Train new model
2. Load existing model
3. Exit

Model trained with accuracy: 0.00%

Model and data saved to disk

=== Disease Prediction Menu ===
1. Predict disease from symptoms
2. Find related symptoms
3. View symptom frequency chart
4. View disease word cloud
5. View disease-symptom heatmap
6. Exit

Top 3 possible diseases:

1. Migraine (20.6% confidence)
   Common symptoms: headache, vomiting, sensitivity to light, sensitivity to sound

2. Meningitis (5.0% confidence)
   Common symptoms: fever, headache, neck pain, nausea, vomiting

3. Hypertension (High Blood Pressure) (2.6% confidence)
   Common symptoms: high blood pressure, headache, dizziness

=== Disease Prediction Menu ===
1. Predict disease from symptoms
2. Find related symptoms
3. View symptom frequency chart
4. View disease word cloud
5. View disease-symptom heatmap
6. Exit


In [None]:
# NOTE(review): if the cells run in order, the transformers==4.30.0 pin
# below is undone by the later `pip install --upgrade transformers` cell;
# keep only one of the two.
!pip install torch torchvision torchaudio
!pip install transformers==4.30.0
!pip install scikit-learn pandas tqdm

In [None]:
# Upgrade torch to the newest release available to pip.
!pip install --upgrade torch

In [None]:
# NOTE(review): this upgrade overrides the transformers==4.30.0 pin installed
# in an earlier cell.
!pip install --upgrade transformers