In [1]:
import pickle
import random
import numpy as np

import torch

import transformers
from transformers import AutoTokenizer, AutoModel, AutoConfig

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import InMemoryDataset, download_url

random.seed(123)
np.random.seed(123)
torch.manual_seed(123)
torch.cuda.manual_seed_all(123)


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
with open("SD++.p", "rb") as f:
    kg = pickle.load(f)

In [3]:
kg

{'S_237': {'S_31': 1.0,
  'S_166': 0.6103225806451613,
  'S_204': 0.46838709677419355,
  'S_11': 0.41419354838709677,
  'D_81': 1.0,
  'D_12': 0.9545454545454544,
  'S_227': 0.23483870967741935,
  'S_213': 0.22838709677419355,
  'S_29': 0.22193548387096773,
  'S_129': 0.2064516129032258,
  'D_27': 0.6245059288537549,
  'D_59': 0.5948616600790513,
  'D_5': 0.5869565217391304,
  'S_169': 0.18387096774193548,
  'S_168': 0.16903225806451613,
  'S_208': 0.1670967741935484,
  'S_4': 0.16193548387096773,
  'D_62': 0.458498023715415,
  'D_55': 0.40909090909090906,
  'S_0': 0.12129032258064516,
  'D_28': 0.33992094861660077,
  'S_165': 0.08903225806451613,
  'D_89': 0.2727272727272727,
  'S_226': 0.08516129032258063,
  'S_84': 0.08516129032258063,
  'S_10': 0.08451612903225807,
  'S_149': 0.08387096774193549,
  'D_53': 0.24703557312252963,
  'S_265': 0.06709677419354838,
  'S_177': 0.06709677419354838,
  'S_51': 0.06064516129032258,
  'S_54': 0.05548387096774193,
  'S_88': 0.05290322580645161,


In [4]:
len(kg.keys())

356

In [5]:
with open("SID.p", "rb") as f:
    sid = pickle.load(f)

In [6]:
sid

{'Abnormal appearing skin': 'S_237',
 'Abnormal involuntary movements': 'S_141',
 'Abnormal movement of eyelid': 'S_180',
 'Abnormal size or shape of ear': 'S_242',
 'Absence of menstruation': 'S_99',
 'Abusing alcohol': 'S_41',
 'Ache all over': 'S_88',
 'Acne or pimples': 'S_29',
 'Allergic reaction': 'S_169',
 'Ankle pain': 'S_131',
 'Ankle stiffness or tightness': 'S_87',
 'Ankle swelling': 'S_76',
 'Anxiety and nervousness': 'S_10',
 'Apnea': 'S_53',
 'Arm lump or mass': 'S_199',
 'Arm pain': 'S_77',
 'Arm stiffness or tightness': 'S_250',
 'Arm swelling': 'S_245',
 'Arm weakness': 'S_3',
 'Back cramps or spasms': 'S_35',
 'Back mass or lump': 'S_189',
 'Back pain': 'S_73',
 'Back stiffness or tightness': 'S_142',
 'Back weakness': 'S_94',
 'Bladder mass': 'S_123',
 'Bleeding from eye': 'S_235',
 'Bleeding gums': 'S_216',
 'Bleeding or discharge from nipple': 'S_46',
 'Blindness': 'S_262',
 'Blood in urine': 'S_254',
 'Bones are painful': 'S_158',
 'Bowlegged or knock-kneed': 'S_2

In [7]:
sid_small = {}

for i in sid:
    key = i
    value = sid[i]
    
    key = key.lower()
    sid_small[key] = value

In [8]:
sid_small

{'abnormal appearing skin': 'S_237',
 'abnormal involuntary movements': 'S_141',
 'abnormal movement of eyelid': 'S_180',
 'abnormal size or shape of ear': 'S_242',
 'absence of menstruation': 'S_99',
 'abusing alcohol': 'S_41',
 'ache all over': 'S_88',
 'acne or pimples': 'S_29',
 'allergic reaction': 'S_169',
 'ankle pain': 'S_131',
 'ankle stiffness or tightness': 'S_87',
 'ankle swelling': 'S_76',
 'anxiety and nervousness': 'S_10',
 'apnea': 'S_53',
 'arm lump or mass': 'S_199',
 'arm pain': 'S_77',
 'arm stiffness or tightness': 'S_250',
 'arm swelling': 'S_245',
 'arm weakness': 'S_3',
 'back cramps or spasms': 'S_35',
 'back mass or lump': 'S_189',
 'back pain': 'S_73',
 'back stiffness or tightness': 'S_142',
 'back weakness': 'S_94',
 'bladder mass': 'S_123',
 'bleeding from eye': 'S_235',
 'bleeding gums': 'S_216',
 'bleeding or discharge from nipple': 'S_46',
 'blindness': 'S_262',
 'blood in urine': 'S_254',
 'bones are painful': 'S_158',
 'bowlegged or knock-kneed': 'S_2

In [9]:
with open("DID.p", "rb") as f:
    did = pickle.load(f)

In [10]:
did

{'Acanthosis nigricans': 'D_3',
 'Acariasis': 'D_85',
 'Acne': 'D_53',
 'Actinic keratosis': 'D_12',
 'Acute glaucoma': 'D_48',
 'Acute kidney injury': 'D_86',
 'Acute stress reaction': 'D_76',
 'Adhesive capsulitis of the shoulder': 'D_42',
 'Adjustment reaction': 'D_33',
 'Air embolism': 'D_78',
 'Alcohol intoxication': 'D_47',
 'Allergy': 'D_28',
 'Alzheimer disease': 'D_79',
 'Amyloidosis': 'D_59',
 'Amyotrophic lateral sclerosis ALS': 'D_4',
 'Ankylosing spondylitis': 'D_75',
 'Anxiety': 'D_82',
 'Aphakia': 'D_50',
 'Carbon monoxide poisoning': 'D_38',
 'Carcinoid syndrome': 'D_44',
 'Carpal tunnel syndrome': 'D_64',
 'Cat scratch disease': 'D_58',
 'Central retinal artery or vein occlusion': 'D_80',
 'Cerebral edema': 'D_51',
 'Chagas disease': 'D_61',
 'Chalazion': 'D_37',
 'Chancroid': 'D_74',
 'Chickenpox': 'D_27',
 'Chlamydia': 'D_23',
 'Chondromalacia of the patella': 'D_84',
 'Chronic back pain': 'D_19',
 'Chronic kidney disease': 'D_7',
 'Chronic pain disorder': 'D_56',
 '

In [11]:
symptom_most_common_disease = {}
for j in range(266):
    key = "S_" + str(j)
    print("Key : ", key)
    q = kg[key]
    cnt = 0
    temp3 = []
    for i in q:
        
            
        if(i[0] == 'D' and q[i]>0):
#             if(cnt == 0):
#                 symptom_most_common_disease[key] = i
            if(cnt <= 2):
                temp3.append(i)    
                
            
            print(i, " : ",q[i])
            cnt+=1
    symptom_most_common_disease[key] = temp3        
    print() 

Key :  S_0
D_87  :  1.0
D_0  :  0.9424460431654677
D_9  :  0.5539568345323741
D_69  :  0.4508393285371703
D_27  :  0.4364508393285372
D_31  :  0.4244604316546763
D_28  :  0.4196642685851319
D_21  :  0.381294964028777
D_5  :  0.36211031175059955

Key :  S_1
D_45  :  1.0
D_58  :  0.9664948453608246
D_69  :  0.6726804123711341
D_37  :  0.6469072164948454
D_43  :  0.520618556701031
D_49  :  0.4587628865979381
D_28  :  0.3479381443298969
D_17  :  0.13917525773195877
D_55  :  0.1211340206185567

Key :  S_2
D_9  :  1.0

Key :  S_3
D_64  :  1.0

Key :  S_4
D_18  :  1.0
D_2  :  0.7434052757793764
D_12  :  0.47961630695443647
D_37  :  0.28776978417266186
D_81  :  0.2685851318944844
D_53  :  0.26378896882494
D_30  :  0.2302158273381295
D_3  :  0.16546762589928057

Key :  S_5
D_45  :  1.0
D_69  :  0.865350089766607
D_49  :  0.6894075403949732
D_17  :  0.5565529622980252
D_43  :  0.540394973070018
D_37  :  0.4272890484739677
D_50  :  0.35727109515260325
D_48  :  0.3267504488330341
D_27  :  0.181328

D_41  :  0.8795180722891567
D_48  :  0.7538726333907058
D_45  :  0.5679862306368331
D_18  :  0.5335628227194492
D_51  :  0.5043029259896731
D_69  :  0.3717728055077453
D_49  :  0.33562822719449226
D_43  :  0.32185886402753877
D_37  :  0.2719449225473322

Key :  S_174
D_42  :  1.0
D_75  :  0.5636363636363636

Key :  S_175
D_26  :  1.0
D_46  :  0.8079584775086505
D_47  :  0.7352941176470588
D_73  :  0.7352941176470588
D_33  :  0.6903114186851211
D_32  :  0.5743944636678201
D_6  :  0.5173010380622838
D_76  :  0.4948096885813149
D_82  :  0.4671280276816609
D_79  :  0.41522491349480967
D_38  :  0.34602076124567477
D_39  :  0.33044982698961933

Key :  S_176
D_46  :  1.0
D_38  :  0.7428571428571429
D_0  :  0.6857142857142857
D_57  :  0.4857142857142857
D_18  :  0.42857142857142855
D_10  :  0.17142857142857143
D_26  :  0.05714285714285714

Key :  S_177
D_12  :  1.0
D_5  :  0.6265060240963856
D_55  :  0.5783132530120483
D_9  :  0.3373493975903615
D_29  :  0.3132530120481928

Key :  S_178
D_70  

In [12]:
symptom_most_common_disease

{'S_0': ['D_87', 'D_0', 'D_9'],
 'S_1': ['D_45', 'D_58', 'D_69'],
 'S_2': ['D_9'],
 'S_3': ['D_64'],
 'S_4': ['D_18', 'D_2', 'D_12'],
 'S_5': ['D_45', 'D_69', 'D_49'],
 'S_6': ['D_36', 'D_34', 'D_10'],
 'S_7': ['D_40', 'D_82'],
 'S_8': ['D_81'],
 'S_9': ['D_39'],
 'S_10': ['D_82', 'D_76', 'D_26'],
 'S_11': ['D_12', 'D_62', 'D_5'],
 'S_12': ['D_33', 'D_82', 'D_22'],
 'S_13': ['D_6', 'D_26', 'D_33'],
 'S_14': ['D_2'],
 'S_15': ['D_77', 'D_25'],
 'S_16': ['D_9', 'D_39', 'D_38'],
 'S_17': ['D_88'],
 'S_18': ['D_69', 'D_49', 'D_50'],
 'S_19': ['D_21', 'D_11', 'D_35'],
 'S_20': ['D_24'],
 'S_21': ['D_60', 'D_73'],
 'S_22': ['D_60'],
 'S_23': ['D_65', 'D_4', 'D_11'],
 'S_24': ['D_30'],
 'S_25': ['D_63', 'D_22', 'D_42'],
 'S_26': ['D_29'],
 'S_27': ['D_70', 'D_44'],
 'S_28': ['D_0', 'D_6'],
 'S_29': ['D_53', 'D_62', 'D_3'],
 'S_30': ['D_77', 'D_34', 'D_13'],
 'S_31': ['D_85', 'D_16', 'D_27'],
 'S_32': ['D_62', 'D_24', 'D_8'],
 'S_33': ['D_71', 'D_57'],
 'S_34': ['D_32'],
 'S_35': ['D_19'],
 'S

In [13]:
with open("new_dialog_symptom_map.p", "rb") as f:
    new_dialog_symptom_map = pickle.load(f)

In [14]:
new_dialog_symptom_map

{1059: ['spots or clouds in vision',
  'diminished vision',
  'symptoms of eye',
  'pain in eye'],
 19510: ['hip pain',
  'ache all over',
  'neck pain',
  'back pain',
  'low back pain',
  'shoulder pain'],
 25630: ['foreign body sensation in eye'],
 1467: ['side pain', 'back pain', 'low back pain'],
 5780: ['pain during pregnancy',
  'pain or soreness of breast',
  'excessive urination at night',
  'wrist pain',
  'facial pain',
  'joint stiffness or tightness',
  'shoulder cramps or spasms',
  'ankle pain',
  'knee lump or mass',
  'pain in eye'],
 26258: ['skin rash', 'skin dryness, peeling, scaliness, or roughness'],
 22733: ['depression',
  'fatigue',
  'cough',
  'sweating',
  'abnormal involuntary movements'],
 1564: ['shortness of breath', 'depression', 'abnormal involuntary movements'],
 29211: ['diminished hearing',
  'ear pain',
  'redness in ear',
  'plugged feeling in ear',
  'dizziness'],
 26235: ['anxiety and nervousness', 'sharp abdominal pain', 'abusing alcohol'],
 16

In [15]:
tokenizer  = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
bert_model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
bert_model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [16]:
print(tokenizer.all_special_tokens)
print(tokenizer.all_special_ids)
p = {
        'additional_special_tokens' : ['[DOC]', '[PAT]', '[SR_START]', '[SR_END]']
    }

tokenizer.add_special_tokens(p)
print(tokenizer.convert_tokens_to_ids('[DOC]'))
print(tokenizer.convert_tokens_to_ids('[PAT]'))
print(tokenizer.convert_tokens_to_ids('[SR_START]'))
print(tokenizer.convert_tokens_to_ids('[SR_END]'))
print(len(tokenizer))
print(bert_model.resize_token_embeddings(len(tokenizer)))
print(tokenizer.all_special_tokens)
print(tokenizer.all_special_ids)

['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']
[1, 3, 0, 2, 4]
30522
30523
30524
30525
30526
Embedding(30526, 768)
['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]', '[DOC]', '[PAT]', '[SR_START]', '[SR_END]']
[1, 3, 0, 2, 4, 30522, 30523, 30524, 30525]


In [17]:
with open("final_combined_all_dialog.p", "rb") as f:
    final_combined_all_dialog = pickle.load(f)

In [18]:
final_combined_all_dialog


[{'dialog': "[PAT] : Doctor, I see Spots or clouds in my vision. I can't see anything clearly. I feel helpless. [SEP] [DOC] : I understand how you feel. I need to ask a few questions to diagnose your disease. Have you got Diminished vision? [SEP] [PAT] : True, I am also dealing with Diminished vision. [SEP] [DOC] : Do you have Symptoms of eye? [SEP] [PAT] : Yes, I have got Symptoms of eye. [SEP] [DOC] : Have you felt Pain in eye? [SEP] [PAT] : Indeed, I am suffering from Pain in eye. [SEP] [DOC] : Your symptoms indicate that you have [MASK].",
  'label': 'central retinal artery or vein occlusion'},
 {'dialog': "[PAT] : Doctor, I have Shoulder pain. I feel terrible. Would you please tell me what could be its cause? [SEP] [DOC] :  I can feel the pain you feel. Let me ask you some questions. Do you have Back pain? [SEP] [PAT] : Literally, I am suffering from Back pain. [SEP] [DOC] : Do you have Low back pain? [SEP] [PAT] : Yes, I am dealing with Low back pain. [SEP] [DOC] : Are you also d

In [19]:
symptom_list = []
for i in sid_small:
    symptom_list.append(i)

In [20]:
symptom_list

['abnormal appearing skin',
 'abnormal involuntary movements',
 'abnormal movement of eyelid',
 'abnormal size or shape of ear',
 'absence of menstruation',
 'abusing alcohol',
 'ache all over',
 'acne or pimples',
 'allergic reaction',
 'ankle pain',
 'ankle stiffness or tightness',
 'ankle swelling',
 'anxiety and nervousness',
 'apnea',
 'arm lump or mass',
 'arm pain',
 'arm stiffness or tightness',
 'arm swelling',
 'arm weakness',
 'back cramps or spasms',
 'back mass or lump',
 'back pain',
 'back stiffness or tightness',
 'back weakness',
 'bladder mass',
 'bleeding from eye',
 'bleeding gums',
 'bleeding or discharge from nipple',
 'blindness',
 'blood in urine',
 'bones are painful',
 'bowlegged or knock-kneed',
 'burning abdominal pain',
 'chest tightness',
 'chills',
 'congestion in chest',
 'coryza',
 'cough',
 'cramps and spasms',
 'cross-eyed',
 'decreased appetite',
 'decreased heart rate',
 'delusions or hallucinations',
 'depression',
 'depressive or psychotic symptom

In [21]:
len(symptom_list)

266

In [22]:
disease_list = []
for i in did:
    disease_list.append(i)

In [23]:
disease_list.sort()
disease_list

['Acanthosis nigricans',
 'Acariasis',
 'Acne',
 'Actinic keratosis',
 'Acute glaucoma',
 'Acute kidney injury',
 'Acute stress reaction',
 'Adhesive capsulitis of the shoulder',
 'Adjustment reaction',
 'Air embolism',
 'Alcohol intoxication',
 'Allergy',
 'Alzheimer disease',
 'Amyloidosis',
 'Amyotrophic lateral sclerosis ALS',
 'Ankylosing spondylitis',
 'Anxiety',
 'Aphakia',
 'Carbon monoxide poisoning',
 'Carcinoid syndrome',
 'Carpal tunnel syndrome',
 'Cat scratch disease',
 'Central retinal artery or vein occlusion',
 'Cerebral edema',
 'Chagas disease',
 'Chalazion',
 'Chancroid',
 'Chickenpox',
 'Chlamydia',
 'Chondromalacia of the patella',
 'Chronic back pain',
 'Chronic kidney disease',
 'Chronic pain disorder',
 'Complex regional pain syndrome',
 'Concussion',
 'Conductive hearing loss',
 'Conjunctivitis due to allergy',
 'Connective tissue disorder',
 'Contact dermatitis',
 'Conversion disorder',
 'Corneal abrasion',
 'Corneal disorder',
 'Cushing syndrome',
 'Cyst of 

In [24]:
disease_to_id_small = {}

for i, j in enumerate(disease_list):
    key = j.lower()
    disease_to_id_small[key] = i

In [25]:
disease_to_id_small

{'acanthosis nigricans': 0,
 'acariasis': 1,
 'acne': 2,
 'actinic keratosis': 3,
 'acute glaucoma': 4,
 'acute kidney injury': 5,
 'acute stress reaction': 6,
 'adhesive capsulitis of the shoulder': 7,
 'adjustment reaction': 8,
 'air embolism': 9,
 'alcohol intoxication': 10,
 'allergy': 11,
 'alzheimer disease': 12,
 'amyloidosis': 13,
 'amyotrophic lateral sclerosis als': 14,
 'ankylosing spondylitis': 15,
 'anxiety': 16,
 'aphakia': 17,
 'carbon monoxide poisoning': 18,
 'carcinoid syndrome': 19,
 'carpal tunnel syndrome': 20,
 'cat scratch disease': 21,
 'central retinal artery or vein occlusion': 22,
 'cerebral edema': 23,
 'chagas disease': 24,
 'chalazion': 25,
 'chancroid': 26,
 'chickenpox': 27,
 'chlamydia': 28,
 'chondromalacia of the patella': 29,
 'chronic back pain': 30,
 'chronic kidney disease': 31,
 'chronic pain disorder': 32,
 'complex regional pain syndrome': 33,
 'concussion': 34,
 'conductive hearing loss': 35,
 'conjunctivitis due to allergy': 36,
 'connective 

In [26]:
symptom_tokenizer = tokenizer(symptom_list, padding = True, truncation = True, return_tensors = 'pt')
symptom_tokenizer['input_ids'].shape

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


torch.Size([266, 13])

In [27]:
symptom_embedding_map = {}
for i in range(len(symptom_list)):
    key = symptom_list[i]
    bert_output = bert_model(input_ids = symptom_tokenizer['input_ids'][i:i+1],
                            attention_mask = symptom_tokenizer['attention_mask'][i:i+1])
#     print(type(bert_output))
#     print(bert_output[0].shape)
#     print(bert_output[1].shape)
    print(bert_output[1].shape)
    temp5 = bert_output[1].squeeze() 
    print(temp5.shape)
    print(key)
    symptom_embedding_map[key] = temp5
    print(i)
    print()

torch.Size([1, 768])
torch.Size([768])
abnormal appearing skin
0

torch.Size([1, 768])
torch.Size([768])
abnormal involuntary movements
1

torch.Size([1, 768])
torch.Size([768])
abnormal movement of eyelid
2

torch.Size([1, 768])
torch.Size([768])
abnormal size or shape of ear
3

torch.Size([1, 768])
torch.Size([768])
absence of menstruation
4

torch.Size([1, 768])
torch.Size([768])
abusing alcohol
5

torch.Size([1, 768])
torch.Size([768])
ache all over
6

torch.Size([1, 768])
torch.Size([768])
acne or pimples
7

torch.Size([1, 768])
torch.Size([768])
allergic reaction
8

torch.Size([1, 768])
torch.Size([768])
ankle pain
9

torch.Size([1, 768])
torch.Size([768])
ankle stiffness or tightness
10

torch.Size([1, 768])
torch.Size([768])
ankle swelling
11

torch.Size([1, 768])
torch.Size([768])
anxiety and nervousness
12

torch.Size([1, 768])
torch.Size([768])
apnea
13

torch.Size([1, 768])
torch.Size([768])
arm lump or mass
14

torch.Size([1, 768])
torch.Size([768])
arm pain
15

torch.Size

torch.Size([1, 768])
torch.Size([768])
knee lump or mass
134

torch.Size([1, 768])
torch.Size([768])
knee pain
135

torch.Size([1, 768])
torch.Size([768])
knee stiffness or tightness
136

torch.Size([1, 768])
torch.Size([768])
knee swelling
137

torch.Size([1, 768])
torch.Size([768])
knee weakness
138

torch.Size([1, 768])
torch.Size([768])
lack of growth
139

torch.Size([1, 768])
torch.Size([768])
lacrimation
140

torch.Size([1, 768])
torch.Size([768])
leg cramps or spasms
141

torch.Size([1, 768])
torch.Size([768])
leg pain
142

torch.Size([1, 768])
torch.Size([768])
leg swelling
143

torch.Size([1, 768])
torch.Size([768])
leg weakness
144

torch.Size([1, 768])
torch.Size([768])
lip swelling
145

torch.Size([1, 768])
torch.Size([768])
long menstrual periods
146

torch.Size([1, 768])
torch.Size([768])
loss of sensation
147

torch.Size([1, 768])
torch.Size([768])
loss of sex drive
148

torch.Size([1, 768])
torch.Size([768])
low back pain
149

torch.Size([1, 768])
torch.Size([768])
low 

In [28]:
disease_tokenizer = tokenizer(disease_list, padding = True, truncation = True, return_tensors = 'pt')
disease_tokenizer['input_ids'].shape

torch.Size([90, 9])

In [29]:
disease_embedding_map = {}
for i in range(len(disease_list)):
    key = disease_list[i]
    bert_output = bert_model(input_ids = disease_tokenizer['input_ids'][i:i+1],
                            attention_mask = disease_tokenizer['attention_mask'][i:i+1])
    
    temp5 = bert_output[1].squeeze()
    print(bert_output[1].shape)
    print(temp5.shape)
    disease_embedding_map[key] = temp5
    print(key)
    print(i)
    print()
    

torch.Size([1, 768])
torch.Size([768])
Acanthosis nigricans
0

torch.Size([1, 768])
torch.Size([768])
Acariasis
1

torch.Size([1, 768])
torch.Size([768])
Acne
2

torch.Size([1, 768])
torch.Size([768])
Actinic keratosis
3

torch.Size([1, 768])
torch.Size([768])
Acute glaucoma
4

torch.Size([1, 768])
torch.Size([768])
Acute kidney injury
5

torch.Size([1, 768])
torch.Size([768])
Acute stress reaction
6

torch.Size([1, 768])
torch.Size([768])
Adhesive capsulitis of the shoulder
7

torch.Size([1, 768])
torch.Size([768])
Adjustment reaction
8

torch.Size([1, 768])
torch.Size([768])
Air embolism
9

torch.Size([1, 768])
torch.Size([768])
Alcohol intoxication
10

torch.Size([1, 768])
torch.Size([768])
Allergy
11

torch.Size([1, 768])
torch.Size([768])
Alzheimer disease
12

torch.Size([1, 768])
torch.Size([768])
Amyloidosis
13

torch.Size([1, 768])
torch.Size([768])
Amyotrophic lateral sclerosis ALS
14

torch.Size([1, 768])
torch.Size([768])
Ankylosing spondylitis
15

torch.Size([1, 768])
torch

In [30]:
def symptom_disease_edge_weight(begin, end):
#     print(begin)
#     print(end)
#     print(kg[begin])
    weight = kg[begin][end]
    return weight
    

In [31]:
def symptom_symptom_edge_weight(begin, end):
#     print(begin)
#     print(end)
#     print(kg[begin])
    weight = kg[begin][end]
#     weight = torch.tensor(weight)
    return weight
    

In [32]:
def make_edge(symp_list):
    edge_list = []
    weight_list = []
    symp_list_len = len(symp_list)
    
    symptom_disease_numbering = {}
    symptom_disease_numbering_with_ids = {}
    cnt = 1
    for i in range(symp_list_len):
        key = symp_list[i]
        symptom_disease_numbering[key] = cnt
        symptom_disease_numbering_with_ids[sid_small[key]] = cnt
        cnt+=1
    
    temp_disease_set = set()
    for i in symp_list:
        print(sid_small[i])
        print(symptom_most_common_disease[sid_small[i]])
        for y in symptom_most_common_disease[sid_small[i]]:
            temp_disease_set.add(y)
        
        print()
    print(temp_disease_set) 
    
    for i in temp_disease_set:
        symptom_disease_numbering_with_ids[i] = cnt
        cnt+=1
    
    for i in symp_list:
        for j in symp_list:
            if(i!=j):
                sid_begin = sid_small[i]
                sid_end = sid_small[j]
                weight = symptom_symptom_edge_weight(sid_begin, sid_end)
                weight_list.append(weight)
                
                num_begin = symptom_disease_numbering_with_ids[sid_begin]
                num_end = symptom_disease_numbering_with_ids[sid_end]
                
                edge = [num_begin, num_end]
                edge_list.append(edge)
                
                
    for i in symp_list:
        sid_symp = sid_small[i]
        print("---------Symptom_Disease----------")
        print(sid_symp)
        most_common = symptom_most_common_disease[sid_symp]
        print(most_common)
        for j in most_common:
            sid_begin = sid_small[i]
#             print("J : ", j)
#             print("DID J : ", did[j])
            did_end = j
            weight = symptom_disease_edge_weight(sid_begin, did_end)
            weight_list.append(weight)
            
            num_begin = symptom_disease_numbering_with_ids[sid_begin]
            num_end = symptom_disease_numbering_with_ids[did_end]
            
            edge = [num_begin, num_end]
            edge_list.append(edge)
            
            
    cls_dialog_token_id = 0
    for i in symp_list:
        weight = 1.0
        num_begin = cls_dialog_token_id
        num_end = symptom_disease_numbering_with_ids[sid_small[i]]
        
        weight_list.append(weight)
        
        edge = [num_begin, num_end]
        edge_list.append(edge)
        
        
        
    edge_list = torch.tensor(edge_list, dtype = torch.long)
    edge_list = edge_list.T
#     print("HELLO")
    print("Edge list : ",edge_list.shape)
    weight_list = torch.tensor(weight_list)
    print("weight list : ",weight_list.shape)
    
    return symptom_disease_numbering, symptom_disease_numbering_with_ids, edge_list, weight_list
    
    

In [33]:
d_s_list = new_dialog_symptom_map[1059]
d_s_list

['spots or clouds in vision',
 'diminished vision',
 'symptoms of eye',
 'pain in eye']

In [34]:
make_edge(d_s_list)

S_52
['D_80', 'D_17', 'D_41']

S_173
['D_80', 'D_17', 'D_50']

S_112
['D_43', 'D_18', 'D_48']

S_239
['D_83', 'D_78', 'D_49']

{'D_17', 'D_83', 'D_41', 'D_49', 'D_78', 'D_18', 'D_48', 'D_50', 'D_43', 'D_80'}
---------Symptom_Disease----------
S_52
['D_80', 'D_17', 'D_41']
---------Symptom_Disease----------
S_173
['D_80', 'D_17', 'D_50']
---------Symptom_Disease----------
S_112
['D_43', 'D_18', 'D_48']
---------Symptom_Disease----------
S_239
['D_83', 'D_78', 'D_49']
Edge list :  torch.Size([2, 28])
weight list :  torch.Size([28])


({'spots or clouds in vision': 1,
  'diminished vision': 2,
  'symptoms of eye': 3,
  'pain in eye': 4},
 {'S_52': 1,
  'S_173': 2,
  'S_112': 3,
  'S_239': 4,
  'D_17': 5,
  'D_83': 6,
  'D_41': 7,
  'D_49': 8,
  'D_78': 9,
  'D_18': 10,
  'D_48': 11,
  'D_50': 12,
  'D_43': 13,
  'D_80': 14},
 tensor([[ 1,  1,  1,  2,  2,  2,  3,  3,  3,  4,  4,  4,  1,  1,  1,  2,  2,  2,
           3,  3,  3,  4,  4,  4,  0,  0,  0,  0],
         [ 2,  3,  4,  1,  3,  4,  1,  2,  4,  1,  2,  3, 14,  5,  7, 14,  5, 12,
          13, 10, 11,  6,  9,  8,  1,  2,  3,  4]]),
 tensor([1.0000, 0.5014, 0.6444, 0.3974, 0.6385, 1.0000, 0.3109, 0.9966, 1.0000,
         0.2561, 1.0000, 0.6407, 1.0000, 0.9640, 0.9144, 1.0000, 0.8881, 0.8812,
         1.0000, 0.8591, 0.6906, 1.0000, 0.9153, 0.8998, 1.0000, 1.0000, 1.0000,
         1.0000]))

In [35]:
reverse_sid = {}
for i in sid_small:
    reverse_sid[sid_small[i]] = i

In [36]:
reverse_sid

{'S_237': 'abnormal appearing skin',
 'S_141': 'abnormal involuntary movements',
 'S_180': 'abnormal movement of eyelid',
 'S_242': 'abnormal size or shape of ear',
 'S_99': 'absence of menstruation',
 'S_41': 'abusing alcohol',
 'S_88': 'ache all over',
 'S_29': 'acne or pimples',
 'S_169': 'allergic reaction',
 'S_131': 'ankle pain',
 'S_87': 'ankle stiffness or tightness',
 'S_76': 'ankle swelling',
 'S_10': 'anxiety and nervousness',
 'S_53': 'apnea',
 'S_199': 'arm lump or mass',
 'S_77': 'arm pain',
 'S_250': 'arm stiffness or tightness',
 'S_245': 'arm swelling',
 'S_3': 'arm weakness',
 'S_35': 'back cramps or spasms',
 'S_189': 'back mass or lump',
 'S_73': 'back pain',
 'S_142': 'back stiffness or tightness',
 'S_94': 'back weakness',
 'S_123': 'bladder mass',
 'S_235': 'bleeding from eye',
 'S_216': 'bleeding gums',
 'S_46': 'bleeding or discharge from nipple',
 'S_262': 'blindness',
 'S_254': 'blood in urine',
 'S_158': 'bones are painful',
 'S_257': 'bowlegged or knock-kne

In [37]:
reverse_did = {}
for i in did:
    reverse_did[did[i]] = i

In [38]:
reverse_did

{'D_3': 'Acanthosis nigricans',
 'D_85': 'Acariasis',
 'D_53': 'Acne',
 'D_12': 'Actinic keratosis',
 'D_48': 'Acute glaucoma',
 'D_86': 'Acute kidney injury',
 'D_76': 'Acute stress reaction',
 'D_42': 'Adhesive capsulitis of the shoulder',
 'D_33': 'Adjustment reaction',
 'D_78': 'Air embolism',
 'D_47': 'Alcohol intoxication',
 'D_28': 'Allergy',
 'D_79': 'Alzheimer disease',
 'D_59': 'Amyloidosis',
 'D_4': 'Amyotrophic lateral sclerosis ALS',
 'D_75': 'Ankylosing spondylitis',
 'D_82': 'Anxiety',
 'D_50': 'Aphakia',
 'D_38': 'Carbon monoxide poisoning',
 'D_44': 'Carcinoid syndrome',
 'D_64': 'Carpal tunnel syndrome',
 'D_58': 'Cat scratch disease',
 'D_80': 'Central retinal artery or vein occlusion',
 'D_51': 'Cerebral edema',
 'D_61': 'Chagas disease',
 'D_37': 'Chalazion',
 'D_74': 'Chancroid',
 'D_27': 'Chickenpox',
 'D_23': 'Chlamydia',
 'D_84': 'Chondromalacia of the patella',
 'D_19': 'Chronic back pain',
 'D_7': 'Chronic kidney disease',
 'D_56': 'Chronic pain disorder',
 '

In [39]:
with open("new_dialog_id_list.p", "rb") as f:
    new_dialog_id_list = pickle.load(f)

In [40]:
new_dialog_id_list

[1059,
 19510,
 25630,
 1467,
 5780,
 26258,
 22733,
 1564,
 29211,
 26235,
 16427,
 16512,
 17308,
 16360,
 15081,
 15386,
 27441,
 9641,
 27768,
 21778,
 18706,
 7857,
 29589,
 15191,
 22535,
 6047,
 20116,
 26439,
 11928,
 20634,
 3590,
 3726,
 7899,
 9125,
 27079,
 7995,
 21990,
 12775,
 86,
 19263,
 19699,
 19386,
 23606,
 14328,
 14382,
 27516,
 17279,
 20111,
 19480,
 28758,
 7769,
 5013,
 26904,
 19940,
 3996,
 12006,
 6762,
 2874,
 9180,
 18054,
 12207,
 17117,
 20286,
 13025,
 791,
 10394,
 26863,
 11005,
 9671,
 11134,
 15276,
 16094,
 6624,
 6097,
 16311,
 339,
 16364,
 19563,
 4346,
 27942,
 7828,
 13246,
 16560,
 19534,
 120,
 8868,
 1947,
 28436,
 5509,
 24423,
 14480,
 13768,
 22417,
 21493,
 23450,
 23404,
 22550,
 20582,
 11716,
 18034,
 5402,
 4711,
 20581,
 7456,
 750,
 29975,
 28051,
 15254,
 25468,
 26528,
 275,
 10852,
 6183,
 23318,
 29262,
 14475,
 28630,
 621,
 18230,
 19696,
 15869,
 18797,
 880,
 17314,
 11154,
 10795,
 6953,
 29134,
 11948,
 28164,
 13969,


In [41]:
len(new_dialog_id_list)

1367

In [42]:
final_combined_all_dialog

[{'dialog': "[PAT] : Doctor, I see Spots or clouds in my vision. I can't see anything clearly. I feel helpless. [SEP] [DOC] : I understand how you feel. I need to ask a few questions to diagnose your disease. Have you got Diminished vision? [SEP] [PAT] : True, I am also dealing with Diminished vision. [SEP] [DOC] : Do you have Symptoms of eye? [SEP] [PAT] : Yes, I have got Symptoms of eye. [SEP] [DOC] : Have you felt Pain in eye? [SEP] [PAT] : Indeed, I am suffering from Pain in eye. [SEP] [DOC] : Your symptoms indicate that you have [MASK].",
  'label': 'central retinal artery or vein occlusion'},
 {'dialog': "[PAT] : Doctor, I have Shoulder pain. I feel terrible. Would you please tell me what could be its cause? [SEP] [DOC] :  I can feel the pain you feel. Let me ask you some questions. Do you have Back pain? [SEP] [PAT] : Literally, I am suffering from Back pain. [SEP] [DOC] : Do you have Low back pain? [SEP] [PAT] : Yes, I am dealing with Low back pain. [SEP] [DOC] : Are you also d

In [43]:
final_combined_all_dialog

dialog_id_dialog_map = {}
dialog_label_id_map = {}
for i,j in enumerate(new_dialog_id_list):
    dialog_id_dialog_map[j] = final_combined_all_dialog[i]['dialog']
    dialog_label_id_map[j] = final_combined_all_dialog[i]['label']

In [44]:
len(dialog_id_dialog_map.keys())

1367

In [45]:
dialog_id_dialog_map[1059]

"[PAT] : Doctor, I see Spots or clouds in my vision. I can't see anything clearly. I feel helpless. [SEP] [DOC] : I understand how you feel. I need to ask a few questions to diagnose your disease. Have you got Diminished vision? [SEP] [PAT] : True, I am also dealing with Diminished vision. [SEP] [DOC] : Do you have Symptoms of eye? [SEP] [PAT] : Yes, I have got Symptoms of eye. [SEP] [DOC] : Have you felt Pain in eye? [SEP] [PAT] : Indeed, I am suffering from Pain in eye. [SEP] [DOC] : Your symptoms indicate that you have [MASK]."

In [46]:
dialog_label_id_map[1059]

'central retinal artery or vein occlusion'

In [47]:
new_dialog_symptom_map

{1059: ['spots or clouds in vision',
  'diminished vision',
  'symptoms of eye',
  'pain in eye'],
 19510: ['hip pain',
  'ache all over',
  'neck pain',
  'back pain',
  'low back pain',
  'shoulder pain'],
 25630: ['foreign body sensation in eye'],
 1467: ['side pain', 'back pain', 'low back pain'],
 5780: ['pain during pregnancy',
  'pain or soreness of breast',
  'excessive urination at night',
  'wrist pain',
  'facial pain',
  'joint stiffness or tightness',
  'shoulder cramps or spasms',
  'ankle pain',
  'knee lump or mass',
  'pain in eye'],
 26258: ['skin rash', 'skin dryness, peeling, scaliness, or roughness'],
 22733: ['depression',
  'fatigue',
  'cough',
  'sweating',
  'abnormal involuntary movements'],
 1564: ['shortness of breath', 'depression', 'abnormal involuntary movements'],
 29211: ['diminished hearing',
  'ear pain',
  'redness in ear',
  'plugged feeling in ear',
  'dizziness'],
 26235: ['anxiety and nervousness', 'sharp abdominal pain', 'abusing alcohol'],
 16

In [48]:
full_dialog_data = []
# for i in range(2):
count = 161
for i in new_dialog_id_list[160:200]:
    print("Count : ",count)
    count+=1
#     dlog_id = dialog_id_list[i]
    dlog_id = i
    symptom_list = new_dialog_symptom_map[dlog_id]
    print(symptom_list)
    symptom_disease_numbering, symptom_disease_numbering_with_ids, edge_list, weight_list = make_edge(symptom_list)
    print(symptom_disease_numbering)
    print(symptom_disease_numbering_with_ids)
#     curr_dialog = total_sep_dialog[i]
    curr_dialog = dialog_id_dialog_map[dlog_id]
    dialog_tokenizer = tokenizer(curr_dialog, padding='max_length', max_length = 512, truncation=True, return_tensors = 'pt')
    print("Dialog tokenizer shape : ", dialog_tokenizer['input_ids'].shape)
    
    output = bert_model(**dialog_tokenizer)
    
#     full_graph_embedding = []
    
    dialog_embedding = output['last_hidden_state'].permute(1,0,2)[0]
    print("Dialog embedding shape : ", dialog_embedding.shape)
    
#     full_graph_embedding.append(dialog_embedding)
    full_graph_embedding = dialog_embedding
    
    curr_symptom_disease_list = []
    
    for k in symptom_disease_numbering_with_ids:
        if(k[0]=='S'):
            print("K : ",k)
            symptom_name = reverse_sid[k]
            print("symptom_name : ",symptom_name)
            s_embedding = symptom_embedding_map[symptom_name]
            temp_embedding = s_embedding.unsqueeze(0)
            print("S_embedding shape : ", temp_embedding.shape)
            full_graph_embedding = torch.cat((full_graph_embedding,temp_embedding), axis=0)
#             full_graph_embedding.append(s_embedding)
            
        elif(k[0]=='D'):
            print("K : ",k)
            disease_name = reverse_did[k]
            print("Disease_name : ", disease_name)
            d_embedding = disease_embedding_map[disease_name]
            temp_embedding = d_embedding.unsqueeze(0)
            print("D_embedding shape : ", temp_embedding.shape)
            full_graph_embedding = torch.cat((full_graph_embedding,temp_embedding), axis=0)
#             full_graph_embedding.append(d_embedding)
    
#     print(len(full_graph_embedding))
    
#     full_graph_embedding = torch.tensor(full_graph_embedding)
    print("Full_graph_embedding shape : ", full_graph_embedding.shape)
    
    disease_label = dialog_label_id_map[dlog_id]
    print("Disease label : ",disease_label)
    disease_label_number = disease_to_id_small[disease_label]
    disease_number_z = torch.zeros(1)
    disease_number_z[0] = torch.tensor(disease_label_number)
    
    print("Disease label number : ", disease_number_z)
    
    data = Data(x = full_graph_embedding, edge_index = edge_list, edge_attr = weight_list, y = disease_number_z)
    full_dialog_data.append(data)
    print("\n\n")
    
    

Count :  161
['loss of sex drive', 'infertility']
S_252
['D_77', 'D_44', 'D_70']

S_57
['D_25', 'D_3', 'D_36']

{'D_77', 'D_44', 'D_25', 'D_36', 'D_70', 'D_3'}
---------Symptom_Disease----------
S_252
['D_77', 'D_44', 'D_70']
---------Symptom_Disease----------
S_57
['D_25', 'D_3', 'D_36']
Edge list :  torch.Size([2, 10])
weight list :  torch.Size([10])
{'loss of sex drive': 1, 'infertility': 2}
{'S_252': 1, 'S_57': 2, 'D_77': 3, 'D_44': 4, 'D_25': 5, 'D_36': 6, 'D_70': 7, 'D_3': 8}
Dialog tokenizer shape :  torch.Size([1, 512])
Dialog embedding shape :  torch.Size([1, 768])
K :  S_252
symptom_name :  loss of sex drive
S_embedding shape :  torch.Size([1, 768])
K :  S_57
symptom_name :  infertility
S_embedding shape :  torch.Size([1, 768])
K :  D_77
Disease_name :  Erectile dysfunction
D_embedding shape :  torch.Size([1, 768])
K :  D_44
Disease_name :  Carcinoid syndrome
D_embedding shape :  torch.Size([1, 768])
K :  D_25
Disease_name :  Female infertility of unknown cause
D_embedding sh

Dialog embedding shape :  torch.Size([1, 768])
K :  S_4
symptom_name :  skin growth
S_embedding shape :  torch.Size([1, 768])
K :  S_74
symptom_name :  eyelid lesion or rash
S_embedding shape :  torch.Size([1, 768])
K :  S_200
symptom_name :  mass on eyelid
S_embedding shape :  torch.Size([1, 768])
K :  S_173
symptom_name :  diminished vision
S_embedding shape :  torch.Size([1, 768])
K :  S_239
symptom_name :  pain in eye
S_embedding shape :  torch.Size([1, 768])
K :  D_17
Disease_name :  Corneal disorder
D_embedding shape :  torch.Size([1, 768])
K :  D_83
Disease_name :  Fat embolism
D_embedding shape :  torch.Size([1, 768])
K :  D_37
Disease_name :  Chalazion
D_embedding shape :  torch.Size([1, 768])
K :  D_12
Disease_name :  Actinic keratosis
D_embedding shape :  torch.Size([1, 768])
K :  D_49
Disease_name :  Corneal abrasion
D_embedding shape :  torch.Size([1, 768])
K :  D_2
Disease_name :  Ganglion cyst
D_embedding shape :  torch.Size([1, 768])
K :  D_18
Disease_name :  Cyst of th

Dialog embedding shape :  torch.Size([1, 768])
K :  S_182
symptom_name :  hand or finger weakness
S_embedding shape :  torch.Size([1, 768])
K :  S_70
symptom_name :  problems with movement
S_embedding shape :  torch.Size([1, 768])
K :  S_84
symptom_name :  weakness
S_embedding shape :  torch.Size([1, 768])
K :  S_236
symptom_name :  difficulty in swallowing
S_embedding shape :  torch.Size([1, 768])
K :  D_22
Disease_name :  Essential tremor
D_embedding shape :  torch.Size([1, 768])
K :  D_64
Disease_name :  Carpal tunnel syndrome
D_embedding shape :  torch.Size([1, 768])
K :  D_4
Disease_name :  Amyotrophic lateral sclerosis ALS
D_embedding shape :  torch.Size([1, 768])
K :  D_59
Disease_name :  Amyloidosis
D_embedding shape :  torch.Size([1, 768])
K :  D_16
Disease_name :  Erythema multiforme
D_embedding shape :  torch.Size([1, 768])
K :  D_39
Disease_name :  Diabetic ketoacidosis
D_embedding shape :  torch.Size([1, 768])
K :  D_0
Disease_name :  Diabetes insipidus
D_embedding shape :

Dialog embedding shape :  torch.Size([1, 768])
K :  S_154
symptom_name :  neck pain
S_embedding shape :  torch.Size([1, 768])
K :  S_173
symptom_name :  diminished vision
S_embedding shape :  torch.Size([1, 768])
K :  S_71
symptom_name :  headache
S_embedding shape :  torch.Size([1, 768])
K :  S_178
symptom_name :  nausea
S_embedding shape :  torch.Size([1, 768])
K :  D_38
Disease_name :  Carbon monoxide poisoning
D_embedding shape :  torch.Size([1, 768])
K :  D_17
Disease_name :  Corneal disorder
D_embedding shape :  torch.Size([1, 768])
K :  D_52
Disease_name :  Concussion
D_embedding shape :  torch.Size([1, 768])
K :  D_50
Disease_name :  Aphakia
D_embedding shape :  torch.Size([1, 768])
K :  D_20
Disease_name :  Degenerative disc disease
D_embedding shape :  torch.Size([1, 768])
K :  D_70
Disease_name :  Granuloma inguinale
D_embedding shape :  torch.Size([1, 768])
K :  D_63
Disease_name :  Fibromyalgia
D_embedding shape :  torch.Size([1, 768])
K :  D_39
Disease_name :  Diabetic ke

Dialog embedding shape :  torch.Size([1, 768])
K :  S_95
symptom_name :  long menstrual periods
S_embedding shape :  torch.Size([1, 768])
K :  S_59
symptom_name :  lump or mass of breast
S_embedding shape :  torch.Size([1, 768])
K :  D_29
Disease_name :  Fibrocystic breast disease
D_embedding shape :  torch.Size([1, 768])
K :  D_71
Disease_name :  Endometrial cancer
D_embedding shape :  torch.Size([1, 768])
K :  D_25
Disease_name :  Female infertility of unknown cause
D_embedding shape :  torch.Size([1, 768])
Full_graph_embedding shape :  torch.Size([6, 768])
Disease label :  fibrocystic breast disease
Disease label number :  tensor([77.])



Count :  180
['cough']
S_0
['D_87', 'D_0', 'D_9']

{'D_0', 'D_9', 'D_87'}
---------Symptom_Disease----------
S_0
['D_87', 'D_0', 'D_9']
Edge list :  torch.Size([2, 4])
weight list :  torch.Size([4])
{'cough': 1}
{'S_0': 1, 'D_0': 2, 'D_9': 3, 'D_87': 4}
Dialog tokenizer shape :  torch.Size([1, 512])
Dialog embedding shape :  torch.Size([1, 768])
K

Dialog embedding shape :  torch.Size([1, 768])
K :  S_237
symptom_name :  abnormal appearing skin
S_embedding shape :  torch.Size([1, 768])
K :  S_166
symptom_name :  skin lesion
S_embedding shape :  torch.Size([1, 768])
K :  S_31
symptom_name :  skin rash
S_embedding shape :  torch.Size([1, 768])
K :  S_73
symptom_name :  back pain
S_embedding shape :  torch.Size([1, 768])
K :  S_204
symptom_name :  itching of skin
S_embedding shape :  torch.Size([1, 768])
K :  S_131
symptom_name :  ankle pain
S_embedding shape :  torch.Size([1, 768])
K :  D_27
Disease_name :  Chickenpox
D_embedding shape :  torch.Size([1, 768])
K :  D_30
Disease_name :  Flat feet
D_embedding shape :  torch.Size([1, 768])
K :  D_85
Disease_name :  Acariasis
D_embedding shape :  torch.Size([1, 768])
K :  D_55
Disease_name :  Contact dermatitis
D_embedding shape :  torch.Size([1, 768])
K :  D_83
Disease_name :  Fat embolism
D_embedding shape :  torch.Size([1, 768])
K :  D_19
Disease_name :  Chronic back pain
D_embedding

Dialog embedding shape :  torch.Size([1, 768])
K :  S_173
symptom_name :  diminished vision
S_embedding shape :  torch.Size([1, 768])
K :  D_50
Disease_name :  Aphakia
D_embedding shape :  torch.Size([1, 768])
K :  D_17
Disease_name :  Corneal disorder
D_embedding shape :  torch.Size([1, 768])
K :  D_80
Disease_name :  Central retinal artery or vein occlusion
D_embedding shape :  torch.Size([1, 768])
Full_graph_embedding shape :  torch.Size([5, 768])
Disease label :  acute glaucoma
Disease label number :  tensor([4.])



Count :  191
['knee pain']
S_163
['D_84', 'D_2', 'D_72']

{'D_2', 'D_72', 'D_84'}
---------Symptom_Disease----------
S_163
['D_84', 'D_2', 'D_72']
Edge list :  torch.Size([2, 4])
weight list :  torch.Size([4])
{'knee pain': 1}
{'S_163': 1, 'D_2': 2, 'D_72': 3, 'D_84': 4}
Dialog tokenizer shape :  torch.Size([1, 512])
Dialog embedding shape :  torch.Size([1, 768])
K :  S_163
symptom_name :  knee pain
S_embedding shape :  torch.Size([1, 768])
K :  D_2
Disease_name :  Gan

Dialog embedding shape :  torch.Size([1, 768])
K :  S_127
symptom_name :  vomiting
S_embedding shape :  torch.Size([1, 768])
K :  S_244
symptom_name :  dizziness
S_embedding shape :  torch.Size([1, 768])
K :  S_71
symptom_name :  headache
S_embedding shape :  torch.Size([1, 768])
K :  S_178
symptom_name :  nausea
S_embedding shape :  torch.Size([1, 768])
K :  D_38
Disease_name :  Carbon monoxide poisoning
D_embedding shape :  torch.Size([1, 768])
K :  D_85
Disease_name :  Acariasis
D_embedding shape :  torch.Size([1, 768])
K :  D_52
Disease_name :  Concussion
D_embedding shape :  torch.Size([1, 768])
K :  D_70
Disease_name :  Granuloma inguinale
D_embedding shape :  torch.Size([1, 768])
K :  D_39
Disease_name :  Diabetic ketoacidosis
D_embedding shape :  torch.Size([1, 768])
K :  D_0
Disease_name :  Diabetes insipidus
D_embedding shape :  torch.Size([1, 768])
K :  D_51
Disease_name :  Cerebral edema
D_embedding shape :  torch.Size([1, 768])
Full_graph_embedding shape :  torch.Size([12,

In [49]:
full_dialog_data

[Data(x=[9, 768], edge_index=[2, 10], edge_attr=[10], y=[1]),
 Data(x=[6, 768], edge_index=[2, 8], edge_attr=[8], y=[1]),
 Data(x=[21, 768], edge_index=[2, 70], edge_attr=[70], y=[1]),
 Data(x=[13, 768], edge_index=[2, 37], edge_attr=[37], y=[1]),
 Data(x=[18, 768], edge_index=[2, 40], edge_attr=[40], y=[1]),
 Data(x=[14, 768], edge_index=[2, 28], edge_attr=[28], y=[1]),
 Data(x=[13, 768], edge_index=[2, 27], edge_attr=[27], y=[1]),
 Data(x=[15, 768], edge_index=[2, 28], edge_attr=[28], y=[1]),
 Data(x=[13, 768], edge_index=[2, 27], edge_attr=[27], y=[1]),
 Data(x=[12, 768], edge_index=[2, 18], edge_attr=[18], y=[1]),
 Data(x=[14, 768], edge_index=[2, 28], edge_attr=[28], y=[1]),
 Data(x=[4, 768], edge_index=[2, 3], edge_attr=[3], y=[1]),
 Data(x=[12, 768], edge_index=[2, 26], edge_attr=[26], y=[1]),
 Data(x=[15, 768], edge_index=[2, 28], edge_attr=[28], y=[1]),
 Data(x=[10, 768], edge_index=[2, 18], edge_attr=[18], y=[1]),
 Data(x=[4, 768], edge_index=[2, 3], edge_attr=[3], y=[1]),
 D

In [50]:
len(full_dialog_data)

40

In [51]:
full_dialog_data[0].y

tensor([76.])

In [52]:
full_dialog_data[1].y

tensor([70.])

In [53]:
data_list = full_dialog_data

In [54]:
class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])
        
    @property
    def processed_file_names(self):
        return ['data_161_to_200.pt']
    
    def process(self):
#         # Read data into huge `Data` list.
        data_list1 = data_list

        if self.pre_filter is not None:
            data_list1 = [data for data in data_list1 if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list1 = [self.pre_transform(data) for data in data_list1]

        data, slices = self.collate(data_list1)
        torch.save((data, slices), self.processed_paths[0])

In [55]:
dataset = MyOwnDataset(root="all_dialog_knowledge_add_three_disease/dialog_161_to_200/")
dataset

Processing...
Done!


MyOwnDataset(40)