<a href="https://colab.research.google.com/github/aakarshhh/AI_ML/blob/main/medical_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from experta import *
import os
from common import Severity

class Condition(Fact):
    """A known symptom together with the diseases that list it.

    One fact per symptom is declared when the engine is reset
    (see ``Diagnose._initial_action``).
    """

    # Symptom identifier as stored in the data files.
    name = Field(str, mandatory=True)
    # Names of the diseases whose symptom list contains this symptom.
    disez = Field(list)

class Disease(Fact):
    """A known disease with its symptoms and their accepted severities.

    ``symptom`` and ``severity`` are parallel lists: ``severity[i]`` is the
    tuple of ``Severity`` values read for ``symptom[i]``.
    """

    # Disease identifier (also the base name of its data files).
    name = Field(str, mandatory=True)
    # Symptom identifiers associated with the disease.
    symptom = Field(list)
    # One tuple of Severity bounds per symptom, in the same order.
    severity = Field(list)

class Task(Fact):
    """Control fact naming the phase the engine is currently in.

    Declared positionally in this file with the values ``'type-symptom'``,
    ``'best-match'`` and ``'store-result'``.
    """

class Error(Fact):
    """Marker fact declared (with a message) when a rule hits an
    inconsistent state."""

class Result(Fact):
    """A disease selected as a diagnosis, waiting to be stored."""

    # Name of the diagnosed disease.
    name = Field(str, mandatory=True)

class DiseaseWatch(Fact):
    """Bookkeeping fact tracking which candidate diseases still need to be
    examined for an exact match."""

    # Names of candidate diseases not yet examined.
    disezs = Field(list, mandatory=True)
    # True once every DiseaseStub has been folded into ``disezs``.
    completed = Field(bool, mandatory=True)

class DiseaseStub(Fact):
    """Short-lived marker naming a disease that became a candidate; it is
    consumed when the name is added to the ``DiseaseWatch`` list."""

    # Name of the candidate disease.
    name = Field(str)

class Num(Fact):
    """Per-disease tally of how many of its symptoms the patient matched."""

    # Name of the disease being tallied.
    name = Field(str, mandatory=True)
    # Total number of symptoms the disease lists.
    required = Field(int, mandatory=True)
    # Number of reported symptoms accepted so far.
    obtained = Field(int, mandatory=True)
    # Accepted (symptom, severity) pairs.
    symptom = Field(list, mandatory=True)

class MaxNum(Num):
    """A ``Num`` tally whose obtained/required ratio equals the best ratio
    found; it is later turned into a ``Result``.

    It inherits the field schema of ``Num`` (``name``, ``required``,
    ``obtained``, ``symptom``), which is what ``Diagnose.keep_max`` passes
    when declaring it.
    """

class Ratio(Fact):
    """Best obtained/required ratio seen so far during the best-match
    phase."""

    # Current maximum of obtained / required across the Num facts.
    ratio = Field(float, mandatory=True)

class Transaction(Fact):
    """One reported symptom paired with one disease that lists it; each
    transaction is checked against that disease's severity bounds."""

    # Reported symptom identifier.
    symptom = Field(str)
    # Severity the patient reported for the symptom.
    severity = Field(Severity)
    # Name of the disease this symptom is being checked against.
    disez = Field(str)

class Query(Fact):
    """A symptom and severity exactly as entered by the patient."""

    # Reported symptom identifier.
    symptom = Field(str, mandatory=True)
    # Reported severity.
    severity = Field(Severity, mandatory=True)

def Db_Read():
    """Load the disease/symptom knowledge base from the working directory.

    Reads ``disezs.txt`` (one disease name per line) and, for every disease,
    ``Disease symptoms/<name>.txt``.  In a symptom file the non-blank lines
    alternate: a symptom name, then a comma-separated list of integer
    severity levels for that symptom (each converted with ``Severity``).

    Blank and whitespace-only lines are ignored in both kinds of file, and
    names are stripped of surrounding whitespace, so a trailing newline does
    not create an empty disease or an empty symptom.

    Returns:
        A 4-tuple ``(d_s_l, sdd, d_l, sl)``:

        * ``d_s_l`` -- ``{disease: {'symp': [symptom, ...],
          'sev': [tuple_of_Severity, ...]}}``
        * ``sdd`` -- ``{symptom: [disease, ...]}``
        * ``d_l`` -- disease names in file order
        * ``sl`` -- unique symptom names in first-seen order

    Raises:
        FileNotFoundError: if ``disezs.txt`` or a symptom file is missing.
        ValueError: if a severity entry is not an integer accepted by
            ``Severity``.
    """
    base_dir = os.getcwd()

    with open(os.path.join(base_dir, "disezs.txt")) as disezs_t:
        # Strip first, then filter, so whitespace-only lines are dropped too.
        d_l = [name for name in (line.strip() for line in disezs_t) if name]

    d_s_l = {}
    for disez in d_l:
        symptom_path = os.path.join(
            base_dir, "Disease symptoms", "{}.txt".format(disez))
        with open(symptom_path) as disez_temp_file:
            rows = [row.strip() for row in disez_temp_file.read().split("\n")]
        # Dropping blank rows keeps the symptom/severity alternation intact
        # even when the file ends with a newline.
        rows = [row for row in rows if row]
        d_s_l[disez] = {
            'symp': rows[0::2],
            'sev': [
                tuple(Severity(int(x.strip())) for x in row.split(","))
                for row in rows[1::2]
            ],
        }

    # Reverse index: symptom -> diseases listing it.
    sdd = {}
    for disez in d_l:
        for symptom in d_s_l[disez]['symp']:
            sdd.setdefault(symptom, []).append(disez)

    # Dict keys keep insertion order, which is the first-seen symptom order.
    sl = list(sdd)
    return (d_s_l, sdd, d_l, sl)


def fetch_d(disez):
    """Return the text of ``Disease descriptions/<disez>.txt`` found under
    the current working directory."""
    file_name = "{}.txt".format(disez)
    full_path = os.path.join(os.getcwd(), "Disease descriptions", file_name)
    with open(full_path) as handle:
        return handle.read()


def fetch_treat(disez):
    """Return the text of ``Disease treatments/<disez>.txt`` found under
    the current working directory."""
    file_name = "{}.txt".format(disez)
    full_path = os.path.join(os.getcwd(), "Disease treatments", file_name)
    with open(full_path) as handle:
        return handle.read()


def pld(disez):
    """Print a "likely" diagnosis: the disease name, its description and
    its suggested treatments."""
    # Both files are read before anything is printed, so a missing file
    # raises without leaving partial output.
    description = fetch_d(disez)
    remedies = fetch_treat(disez)
    report = (
        "\n\nA likely disez that you have is: {}".format(disez),
        "A short description of the disez is given below :\n",
        description,
        "The common medications and procedures suggested by other real doctors are:\n",
        remedies,
    )
    for chunk in report:
        print(chunk)

def deff_disez(disez):
    """Print the "most probable" diagnosis: the disease name, its
    description and its suggested treatments."""
    # Both files are read before anything is printed, so a missing file
    # raises without leaving partial output.
    description = fetch_d(disez)
    remedies = fetch_treat(disez)
    report = (
        "\n\nThe most probable disez that you have is {}".format(disez),
        "\nA short description of the disez is given below :\n",
        description,
        "\nThe common medications and procedures suggested by other real doctors are:\n",
        remedies,
    )
    for chunk in report:
        print(chunk)

class Diagnose(KnowledgeEngine):
    """Expert-system engine that maps reported symptoms to diseases.

    Phases, driven by ``Task`` control facts:

    1. ``startup`` greets the patient and declares ``Task('type-symptom')``.
    2. ``type_Condition`` reads symptoms/severities and declares ``Query``
       facts.
    3. Each ``Query`` is expanded into ``Transaction`` facts (one per
       disease listing the symptom), which are tallied into ``Num`` facts.
    4. Diseases whose every symptom matched become ``Result`` facts;
       if none did, the best-match phase keeps the diseases with the highest
       obtained/required ratio.
    5. ``store_result`` copies the results into ``self.diagnosis``.

    Attributes set while running: ``patients`` (patient name),
    ``symp_list``, ``diagnosis``, ``incomplete`` and ``all_matches``.
    """

    @DefFacts()
    def _initial_action(self, dis_symp_dict, symp_dis_dict, symp_list):
        """Yield the knowledge-base facts and initialise per-run state.

        The arguments are the keyword arguments given to ``reset()``.
        The attribute assignments sit after the last ``yield``, so they run
        once the generator has been fully consumed.
        """
        for symptom in symp_dis_dict:
            yield Condition(name=symptom, disez=symp_dis_dict[symptom])
        for disez in dis_symp_dict:
            yield Disease(name=disez, symptom=dis_symp_dict[disez]['symp'], severity=dis_symp_dict[disez]['sev'])
        self.symp_list = symp_list
        self.diagnosis = []
        self.incomplete = False
        self.all_matches = []

    @Rule(salience=1000)
    def startup(self):
        """Ask for the patient's name, print the banner and start the
        symptom-entry phase."""
        self.patients = input("Patient's Name: ").strip().upper()
        print("Hello! I am a Custom Diagnosis Expert System (CUSTODES)\n"
                "I am here to help you diagnose your disez.\n"
                "Please start typing in your symptoms and its severity."
                "Press enter on a blank line to get diagnosis")
        self.declare(Task('type-symptom'))

    @Rule(AS.f1 << Task('type-symptom'))
    def type_Condition(self, f1):
        """Interactively collect symptoms and severities until a blank line.

        A symptom may be typed in full (spaces or underscores) or as a
        unique prefix, which is auto-completed.  Entering the same symptom
        twice keeps the last severity.  One ``Query`` fact is declared per
        collected symptom and the ``Task`` fact is retracted.
        """
        ans = ' '
        sev = ' '
        reported = {}
        while ans != '' and sev != '':
            ans = input('Symptom>').strip().lower()
            if ans == '':
                # A blank line always ends entry; without this guard a blank
                # would auto-complete when exactly one symptom is known.
                break
            canonical = ans.replace(" ", "_").strip()
            auto_fill = [x for x in self.symp_list if x.startswith(ans)]
            if len(auto_fill) == 1 or canonical in self.symp_list:
                if canonical in self.symp_list:
                    # Store the underscore form: that is the spelling the
                    # Condition facts use, so the Query can match them.
                    ans = canonical
                else:
                    ans = auto_fill[0]
                    print(ans.replace("_", " "))
                sev = input('Severity (0-5)>')
                try:
                    sev = Severity(int(sev))
                except ValueError:
                    print("Please enter a number between 0 - 5")
                    sev = ' '
                    continue
                reported[ans] = sev
            elif len(auto_fill) == 0:
                # No prefix match: fall back to substring suggestions.
                suggest = [x for x in self.symp_list if canonical in x]
                if len(suggest) == 0:
                    print("Could not find any matching symptoms.\n")
                else:
                    print("Did you mean:")
                    print([x.replace("_", " ") for x in suggest])
                ans = ' '
            elif ans != '':
                # Ambiguous prefix: list the candidates and ask again.
                print("Did you mean:")
                print(auto_fill)
                ans = ' '
                sev = ' '
        for symp in reported:
            self.declare(Query(symptom=symp, severity=reported[symp]))
        self.retract(f1)

    @Rule(NOT(Task()),
          AS.f1 << Query(symptom=MATCH.symp, severity=MATCH.sev),
          Condition(name=MATCH.symp, disez=MATCH.dis))
    def process_input_query(self, f1, symp, sev, dis):
        """Expand a ``Query`` into one ``Transaction`` per disease that
        lists the symptom, then retract the query."""
        for disez in dis:
            self.declare(Transaction(symptom=symp, severity=sev, disez=disez))
        self.retract(f1)

    @Rule(Transaction(symptom=MATCH.symp, severity=MATCH.sev, disez=MATCH.dis),
          Disease(name=MATCH.dis, symptom=MATCH.sy, severity=MATCH.se))
    def create_Num(self, dis, sy):
        """Declare the zeroed tally and the stub for a disease touched by a
        transaction."""
        self.declare(Num(name=dis, required=len(sy), obtained=0, symptom=[]))
        self.declare(DiseaseStub(name=dis))

    @Rule(AS.f1 << Transaction(symptom=MATCH.symp, severity=MATCH.sev, disez=MATCH.dis),
          Disease(name=MATCH.dis, symptom=MATCH.syflist, severity=MATCH.sevflist),
          AS.f2 << Num(name=MATCH.dis, required=MATCH.req, obtained=MATCH.obt, symptom=MATCH.clist))
    def count_activations(self, f1, f2, symp, sev, syflist, sevflist, req, obt, clist):
        """Count the transaction towards its disease when the reported
        severity lies within the disease's bounds for that symptom; the
        transaction is retracted either way."""
        try:
            idx = syflist.index(symp)
            sev0, sev1 = sevflist[idx]
            if sev.value <= sev1.value and sev.value >= sev0.value:
                in_list = list(clist)
                in_list.append(tuple([symp, sev]))
                self.modify(f2, obtained=obt+1, symptom=in_list)
        except ValueError:
            # Raised by index() when the disease does not list the symptom
            # (or by the unpacking when the bounds are not a pair).
            self.declare(Error("Something went wrong."))
        self.retract(f1)

    @Rule(NOT(Transaction()), NOT(DiseaseWatch()),
          EXISTS(Num()))
    def add_disez_watcher(self):
        """Once all transactions are consumed, create the empty watch
        list."""
        self.declare(DiseaseWatch(disezs=[], completed=False))

    @Rule(AS.f1 << DiseaseWatch(disezs=MATCH.dis_list, completed=MATCH.com),
          AS.f2 << DiseaseStub(name=MATCH.dis))
    def add_disezs_to_watcher(self, f1, f2, dis, dis_list):
        """Move a stub's disease name into the watch list (no duplicates)
        and retract the stub."""
        in_list = list(dis_list)
        if dis not in in_list:
            in_list.append(dis)
        self.modify(f1, disezs=in_list)
        self.retract(f2)

    @Rule(AS.f1 << DiseaseWatch(disezs=MATCH.dis_list, completed=MATCH.com),
          TEST(lambda com: not com),
          NOT(DiseaseStub()))
    def mark_disez_completion(self, f1):
        """Flag the watch list as complete when no stubs remain."""
        self.modify(f1, completed=True)

    @Rule(AS.f1 << DiseaseWatch(disezs=MATCH.dis_list, completed=MATCH.com),
          TEST(lambda dis_list: len(dis_list) > 0),
          TEST(lambda com: com),
          AS.f2 << Num(name=MATCH.dis, required=MATCH.req, obtained=MATCH.obt, symptom=MATCH.clist))
    def obtain_exact_diagnosis(self, f1, f2, dis, req, obt, dis_list):
        """Promote a fully matched disease to a ``Result`` and tick the
        disease off the watch list."""
        if req == obt:
            self.retract(f2)
            self.all_matches.append(dis)
            self.declare(Result(name=dis))
        if dis in dis_list:
            in_list = list(dis_list)
            try:
                in_list.remove(dis)
            except ValueError:
                in_list = []
                self.declare(Error("Something went wrong while obtaining exact diagnosis"))
            self.modify(f1, disezs=in_list)

    @Rule(AS.f1 << DiseaseWatch(disezs=MATCH.dis_list, completed=MATCH.com),
          TEST(lambda dis_list: len(dis_list) == 0),
          TEST(lambda com: com),
          NOT(Task()),
          Result())
    def signal_exact_completion(self, f1):
        """All candidates examined and at least one exact match exists:
        start storing results and drop the watch list."""
        # The signature must accept ``self`` and the bound ``f1`` fact: the
        # body uses both.
        self.declare(Task('store-result'))
        self.retract(f1)

    @Rule(AS.f1 << DiseaseWatch(disezs=MATCH.dis_list, completed=MATCH.com),
          TEST(lambda dis_list: len(dis_list) == 0),
          TEST(lambda com: com),
          NOT(Task()),
          NOT(Result()))
    def incomplete_information(self, f1):
        """All candidates examined without an exact match: switch to the
        best-match phase, starting from a ratio of 0."""
        self.incomplete = True
        self.retract(f1)
        self.declare(Task('best-match'))
        self.declare(Ratio(ratio=0.0))

    @Rule(Task('best-match'),
          AS.f1 << Num(name=MATCH.dis, required=MATCH.req, obtained=MATCH.obt, symptom=MATCH.clist),
          AS.f2 << Ratio(ratio=MATCH.ratio),
          TEST(lambda req, obt, ratio: obt/req > ratio))
    def compute_max(self, f1, f2, req, obt):
        """Raise the stored ratio when a tally beats it."""
        self.modify(f2, ratio=obt/req)

    @Rule(Task('best-match'),
          AS.f1 << Num(name=MATCH.dis, required=MATCH.req, obtained=MATCH.obt, symptom=MATCH.clist),
          AS.f2 << Ratio(ratio=MATCH.ratio),
          TEST(lambda req, obt, ratio: obt/req < ratio))
    def remove_mins(self, f1, f2, dis, req, obt):
        """Discard a tally below the best ratio, recording it in
        ``all_matches``."""
        self.retract(f1)
        self.all_matches.append(tuple((dis, obt, req)))

    @Rule(Task('best-match'),
        AS.f1 << Num(name=MATCH.dis, required=MATCH.req, obtained=MATCH.obt, symptom=MATCH.clist),
        AS.f2 << Ratio(ratio=MATCH.ratio),
        TEST(lambda req, obt, ratio: obt/req == ratio))
    def keep_max(self, f1, f2, dis, req, obt, clist):
        """Convert a tally equal to the best ratio into a ``MaxNum`` and
        record it in ``all_matches``."""
        self.retract(f1)
        self.all_matches.append(tuple((dis, obt, req)))
        self.declare(MaxNum(name=dis, required=req, obtained=obt, symptom=clist))

    @Rule(AS.f2 << Task('best-match'), NOT(Num()), AS.f1 << Ratio(ratio=MATCH.ratio))
    def cleanup_max_operation(self, f1, f2):
        """End the best-match phase once no ``Num`` tallies remain."""
        self.retract(f1)
        self.retract(f2)
        self.declare(Task('store-result'))

    @Rule(AS.f1 << MaxNum(name=MATCH.dis, required=MATCH.req, obtained=MATCH.obt, symptom=MATCH.clist))
    def max_to_result(self, f1, dis):
        """Turn a ``MaxNum`` into a ``Result``."""
        self.retract(f1)
        self.declare(Result(name=dis))

    @Rule(Task('store-result'),
          AS.f3 << Result(name=MATCH.dis))
    def store_result(self, f3, dis):
        """Append a result's disease name to ``self.diagnosis`` and retract
        the result."""
        self.diagnosis.append(dis)
        self.retract(f3)

if __name__ == "__main__":
    # Load the knowledge base from the data files in the working directory.
    dis_symp_dict, symp_dis_dict, dis_list, symp_list = Db_Read()
    '''
    Print debugging!
    print(dis_symp_dict)
    print(symp_dis_dict)
    print(dis_list)
    print(symp_list)
    '''
    # Build the engine, seed it with the knowledge base and run the rules.
    eng_client = Diagnose()
    eng_client.reset(
        dis_symp_dict=dis_symp_dict,
        symp_dis_dict=symp_dis_dict,
        symp_list=symp_list,
    )
    '''
    print(eng_client.truths)
    '''
    eng_client.run()
    '''
    print(eng_client.truths)
    print(eng_client.all_matches)
    print(eng_client.diagnosis)
    print(eng_client.incomplete)
    '''
    print("Hey {},".format(eng_client.patients))
    if eng_client.incomplete:
        # No exact match: every stored diagnosis is only a likely one.
        likely_disez = eng_client.diagnosis
        for dis in likely_disez:
            pld(dis)
    else:
        # Exact matches are reported first as the most probable diseases.
        for dis in eng_client.diagnosis:
            deff_disez(dis)
        # Whatever remains in all_matches after removing the exact
        # diagnoses is reported as merely likely.  The list is modified in
        # place (it is the engine's own list, not a copy).
        likely_disez = eng_client.all_matches
        for item in eng_client.diagnosis:
            if item in likely_disez:
                likely_disez.remove(item)
        for dis, obt, req in likely_disez:
            pld(dis)