In [58]:
import pickle
import random
import numpy as np

import torch

import transformers
from transformers import AutoTokenizer, AutoModel, AutoConfig

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GCNConv, GATConv
from torch_geometric.nn import global_mean_pool
from torch_geometric.data import InMemoryDataset, download_url

# Fix the seeds of Python's `random`, NumPy and PyTorch (including every CUDA
# generator) to the same value so repeated runs draw the same random numbers.
random.seed(123)
np.random.seed(123)
torch.manual_seed(123)
torch.cuda.manual_seed_all(123)


In [59]:
# Load the knowledge graph. The entry displayed in the next cell maps a node id
# ('S_237') to a dict of neighbouring 'S_*' / 'D_*' ids with float weights.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("SD++.p", "rb") as f:
    kg = pickle.load(f)

In [60]:
kg

{'S_237': {'S_31': 1.0,
  'S_166': 0.6103225806451613,
  'S_204': 0.46838709677419355,
  'S_11': 0.41419354838709677,
  'D_81': 1.0,
  'D_12': 0.9545454545454544,
  'S_227': 0.23483870967741935,
  'S_213': 0.22838709677419355,
  'S_29': 0.22193548387096773,
  'S_129': 0.2064516129032258,
  'D_27': 0.6245059288537549,
  'D_59': 0.5948616600790513,
  'D_5': 0.5869565217391304,
  'S_169': 0.18387096774193548,
  'S_168': 0.16903225806451613,
  'S_208': 0.1670967741935484,
  'S_4': 0.16193548387096773,
  'D_62': 0.458498023715415,
  'D_55': 0.40909090909090906,
  'S_0': 0.12129032258064516,
  'D_28': 0.33992094861660077,
  'S_165': 0.08903225806451613,
  'D_89': 0.2727272727272727,
  'S_226': 0.08516129032258063,
  'S_84': 0.08516129032258063,
  'S_10': 0.08451612903225807,
  'S_149': 0.08387096774193549,
  'D_53': 0.24703557312252963,
  'S_265': 0.06709677419354838,
  'S_177': 0.06709677419354838,
  'S_51': 0.06064516129032258,
  'S_54': 0.05548387096774193,
  'S_88': 0.05290322580645161,


In [61]:
len(kg.keys())

356

In [62]:
# Load the symptom lookup; as displayed in the next cell it maps a symptom
# name (e.g. 'Abnormal appearing skin') to its 'S_*' id.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("SID.p", "rb") as f:
    sid = pickle.load(f)

In [63]:
sid

{'Abnormal appearing skin': 'S_237',
 'Abnormal involuntary movements': 'S_141',
 'Abnormal movement of eyelid': 'S_180',
 'Abnormal size or shape of ear': 'S_242',
 'Absence of menstruation': 'S_99',
 'Abusing alcohol': 'S_41',
 'Ache all over': 'S_88',
 'Acne or pimples': 'S_29',
 'Allergic reaction': 'S_169',
 'Ankle pain': 'S_131',
 'Ankle stiffness or tightness': 'S_87',
 'Ankle swelling': 'S_76',
 'Anxiety and nervousness': 'S_10',
 'Apnea': 'S_53',
 'Arm lump or mass': 'S_199',
 'Arm pain': 'S_77',
 'Arm stiffness or tightness': 'S_250',
 'Arm swelling': 'S_245',
 'Arm weakness': 'S_3',
 'Back cramps or spasms': 'S_35',
 'Back mass or lump': 'S_189',
 'Back pain': 'S_73',
 'Back stiffness or tightness': 'S_142',
 'Back weakness': 'S_94',
 'Bladder mass': 'S_123',
 'Bleeding from eye': 'S_235',
 'Bleeding gums': 'S_216',
 'Bleeding or discharge from nipple': 'S_46',
 'Blindness': 'S_262',
 'Blood in urine': 'S_254',
 'Bones are painful': 'S_158',
 'Bowlegged or knock-kneed': 'S_2

In [64]:
# Lower-cased copy of `sid`: same 'S_*' ids, keyed by the lower-case symptom name.
sid_small = {symptom_name.lower(): symptom_id for symptom_name, symptom_id in sid.items()}

In [65]:
sid_small

{'abnormal appearing skin': 'S_237',
 'abnormal involuntary movements': 'S_141',
 'abnormal movement of eyelid': 'S_180',
 'abnormal size or shape of ear': 'S_242',
 'absence of menstruation': 'S_99',
 'abusing alcohol': 'S_41',
 'ache all over': 'S_88',
 'acne or pimples': 'S_29',
 'allergic reaction': 'S_169',
 'ankle pain': 'S_131',
 'ankle stiffness or tightness': 'S_87',
 'ankle swelling': 'S_76',
 'anxiety and nervousness': 'S_10',
 'apnea': 'S_53',
 'arm lump or mass': 'S_199',
 'arm pain': 'S_77',
 'arm stiffness or tightness': 'S_250',
 'arm swelling': 'S_245',
 'arm weakness': 'S_3',
 'back cramps or spasms': 'S_35',
 'back mass or lump': 'S_189',
 'back pain': 'S_73',
 'back stiffness or tightness': 'S_142',
 'back weakness': 'S_94',
 'bladder mass': 'S_123',
 'bleeding from eye': 'S_235',
 'bleeding gums': 'S_216',
 'bleeding or discharge from nipple': 'S_46',
 'blindness': 'S_262',
 'blood in urine': 'S_254',
 'bones are painful': 'S_158',
 'bowlegged or knock-kneed': 'S_2

In [66]:
# Load the disease lookup; as displayed in the next cell it maps a disease
# name (e.g. 'Acne') to its 'D_*' id.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("DID.p", "rb") as f:
    did = pickle.load(f)

In [67]:
did

{'Acanthosis nigricans': 'D_3',
 'Acariasis': 'D_85',
 'Acne': 'D_53',
 'Actinic keratosis': 'D_12',
 'Acute glaucoma': 'D_48',
 'Acute kidney injury': 'D_86',
 'Acute stress reaction': 'D_76',
 'Adhesive capsulitis of the shoulder': 'D_42',
 'Adjustment reaction': 'D_33',
 'Air embolism': 'D_78',
 'Alcohol intoxication': 'D_47',
 'Allergy': 'D_28',
 'Alzheimer disease': 'D_79',
 'Amyloidosis': 'D_59',
 'Amyotrophic lateral sclerosis ALS': 'D_4',
 'Ankylosing spondylitis': 'D_75',
 'Anxiety': 'D_82',
 'Aphakia': 'D_50',
 'Carbon monoxide poisoning': 'D_38',
 'Carcinoid syndrome': 'D_44',
 'Carpal tunnel syndrome': 'D_64',
 'Cat scratch disease': 'D_58',
 'Central retinal artery or vein occlusion': 'D_80',
 'Cerebral edema': 'D_51',
 'Chagas disease': 'D_61',
 'Chalazion': 'D_37',
 'Chancroid': 'D_74',
 'Chickenpox': 'D_27',
 'Chlamydia': 'D_23',
 'Chondromalacia of the patella': 'D_84',
 'Chronic back pain': 'D_19',
 'Chronic kidney disease': 'D_7',
 'Chronic pain disorder': 'D_56',
 '

In [68]:
# For every symptom id S_0 .. S_265, list the positively weighted disease
# neighbours found in `kg` and remember only the first one encountered
# (wrapped in a list; the list is empty when the symptom has no such neighbour).
symptom_most_common_disease = {}
for symptom_index in range(266):
    key = "S_" + str(symptom_index)
    print("Key : ", key)

    # Keep the dict's own iteration order: the "first" disease depends on it.
    positive_diseases = [
        (node_id, weight)
        for node_id, weight in kg[key].items()
        if node_id[0] == 'D' and weight > 0
    ]
    for node_id, weight in positive_diseases:
        print(node_id, " : ", weight)

    symptom_most_common_disease[key] = [node_id for node_id, _ in positive_diseases[:1]]
    print()

Key :  S_0
D_87  :  1.0
D_0  :  0.9424460431654677
D_9  :  0.5539568345323741
D_69  :  0.4508393285371703
D_27  :  0.4364508393285372
D_31  :  0.4244604316546763
D_28  :  0.4196642685851319
D_21  :  0.381294964028777
D_5  :  0.36211031175059955

Key :  S_1
D_45  :  1.0
D_58  :  0.9664948453608246
D_69  :  0.6726804123711341
D_37  :  0.6469072164948454
D_43  :  0.520618556701031
D_49  :  0.4587628865979381
D_28  :  0.3479381443298969
D_17  :  0.13917525773195877
D_55  :  0.1211340206185567

Key :  S_2
D_9  :  1.0

Key :  S_3
D_64  :  1.0

Key :  S_4
D_18  :  1.0
D_2  :  0.7434052757793764
D_12  :  0.47961630695443647
D_37  :  0.28776978417266186
D_81  :  0.2685851318944844
D_53  :  0.26378896882494
D_30  :  0.2302158273381295
D_3  :  0.16546762589928057

Key :  S_5
D_45  :  1.0
D_69  :  0.865350089766607
D_49  :  0.6894075403949732
D_17  :  0.5565529622980252
D_43  :  0.540394973070018
D_37  :  0.4272890484739677
D_50  :  0.35727109515260325
D_48  :  0.3267504488330341
D_27  :  0.181328

D_51  :  0.3825242718446602
D_86  :  0.37281553398058254
D_52  :  0.37087378640776697
D_47  :  0.32233009708737864

Key :  S_179
D_83  :  1.0
D_78  :  0.9228346456692913
D_61  :  0.3826771653543307
D_68  :  0.3653543307086614
D_54  :  0.36220472440944884
D_74  :  0.34960629921259845

Key :  S_180
D_43  :  1.0
D_18  :  0.4733009708737864
D_41  :  0.11893203883495146

Key :  S_181
D_60  :  1.0

Key :  S_182
D_4  :  1.0
D_64  :  0.3377777777777778

Key :  S_183
D_88  :  1.0
D_75  :  0.14156079854809436
D_52  :  0.08348457350272233

Key :  S_184
D_79  :  1.0
D_40  :  1.0

Key :  S_185
D_46  :  1.0
D_0  :  0.8571428571428572
D_85  :  0.7428571428571429
D_57  :  0.4857142857142857
D_73  :  0.4285714285714286

Key :  S_186
D_31  :  1.0

Key :  S_187
D_39  :  1.0

Key :  S_188
D_32  :  1.0

Key :  S_189
D_88  :  1.0

Key :  S_190
D_35  :  1.0
D_33  :  0.2380952380952381
D_79  :  0.19841269841269843

Key :  S_191
D_60  :  1.0

Key :  S_192
D_83  :  1.0
D_78  :  0.8998459167950694
D_61  :  0.355

In [69]:
symptom_most_common_disease

{'S_0': ['D_87'],
 'S_1': ['D_45'],
 'S_2': ['D_9'],
 'S_3': ['D_64'],
 'S_4': ['D_18'],
 'S_5': ['D_45'],
 'S_6': ['D_36'],
 'S_7': ['D_40'],
 'S_8': ['D_81'],
 'S_9': ['D_39'],
 'S_10': ['D_82'],
 'S_11': ['D_12'],
 'S_12': ['D_33'],
 'S_13': ['D_6'],
 'S_14': ['D_2'],
 'S_15': ['D_77'],
 'S_16': ['D_9'],
 'S_17': ['D_88'],
 'S_18': ['D_69'],
 'S_19': ['D_21'],
 'S_20': ['D_24'],
 'S_21': ['D_60'],
 'S_22': ['D_60'],
 'S_23': ['D_65'],
 'S_24': ['D_30'],
 'S_25': ['D_63'],
 'S_26': ['D_29'],
 'S_27': ['D_70'],
 'S_28': ['D_0'],
 'S_29': ['D_53'],
 'S_30': ['D_77'],
 'S_31': ['D_85'],
 'S_32': ['D_62'],
 'S_33': ['D_71'],
 'S_34': ['D_32'],
 'S_35': ['D_19'],
 'S_36': ['D_20'],
 'S_37': ['D_44'],
 'S_38': ['D_80'],
 'S_39': ['D_2'],
 'S_40': ['D_84'],
 'S_41': ['D_47'],
 'S_42': ['D_36'],
 'S_43': ['D_73'],
 'S_44': ['D_75'],
 'S_45': ['D_26'],
 'S_46': ['D_29'],
 'S_47': ['D_60'],
 'S_48': ['D_31'],
 'S_49': ['D_23'],
 'S_50': ['D_72'],
 'S_51': ['D_46'],
 'S_52': ['D_80'],
 'S_53': 

In [70]:
# Load the self-report map; as displayed in the next cell it maps an integer
# id (e.g. 1059) to a list of lower-case symptom names.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("new_self_report_symptom_map.p", "rb") as f:
    new_self_report_symptom_map = pickle.load(f)

In [71]:
new_self_report_symptom_map

{1059: ['spots or clouds in vision'],
 19510: ['shoulder pain'],
 25630: ['foreign body sensation in eye'],
 1467: ['low back pain'],
 5780: ['wrist pain'],
 26258: ['skin rash'],
 22733: ['depression'],
 1564: ['depression'],
 29211: ['plugged feeling in ear'],
 26235: ['anxiety and nervousness'],
 16427: ['dizziness'],
 16512: ['allergic reaction'],
 17308: ['pain or soreness of breast'],
 16360: ['vaginal pain'],
 15081: ['restlessness'],
 15386: ['pain in eye'],
 27441: ['skin rash'],
 9641: ['arm stiffness or tightness'],
 27768: ['skin growth'],
 21778: ['hot flashes'],
 18706: ['depressive or psychotic symptoms'],
 7857: ['abusing alcohol'],
 29589: ['leg pain'],
 15191: ['eye redness'],
 22535: ['blindness'],
 6047: ['depressive or psychotic symptoms'],
 20116: ['sharp abdominal pain'],
 26439: ['skin lesion', 'skin growth'],
 11928: ['back pain'],
 20634: ['hand or finger pain'],
 3590: ['retention of urine'],
 3726: ['neck pain'],
 7899: ['double vision'],
 9125: ['skin rash'

In [72]:
# Load the SapBERT (PubMedBERT full-text) tokenizer and encoder via
# from_pretrained; the trailing expression displays the model architecture.
tokenizer  = AutoTokenizer.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
bert_model = AutoModel.from_pretrained("cambridgeltl/SapBERT-from-PubMedBERT-fulltext")
bert_model

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          

In [73]:
# Show the special tokens and their ids before the vocabulary is extended.
print(tokenizer.all_special_tokens)
print(tokenizer.all_special_ids)

# Register the dialogue markers as additional special tokens.
p = {'additional_special_tokens': ['[DOC]', '[PAT]', '[SR_START]', '[SR_END]']}
tokenizer.add_special_tokens(p)

# Print the id assigned to each new marker, in registration order.
for marker in p['additional_special_tokens']:
    print(tokenizer.convert_tokens_to_ids(marker))

# Resize the model's token embeddings to the enlarged vocabulary size.
print(len(tokenizer))
print(bert_model.resize_token_embeddings(len(tokenizer)))

# Show the special tokens and their ids after the extension.
print(tokenizer.all_special_tokens)
print(tokenizer.all_special_ids)

['[UNK]', '[SEP]', '[PAD]', '[CLS]', '[MASK]']
[1, 3, 0, 2, 4]
30522
30523
30524
30525
30526


KeyboardInterrupt: 

In [None]:
# Load the per-dialogue records. A later cell indexes this object by position
# and reads the 'first_utterance' and 'label' fields of each record.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("final_combined_all_self_report.p", "rb") as f:
    final_combined_all_self_report = pickle.load(f)

In [None]:
final_combined_all_self_report

In [None]:
# All lower-case symptom names, in the iteration order of `sid_small`.
symptom_list = list(sid_small)

In [None]:
symptom_list

In [None]:
len(symptom_list)

In [None]:
# All disease names (original casing), in the iteration order of `did`.
disease_list = list(did)

In [None]:
# Sort in place (default string ordering). The cell that builds
# `disease_to_id_small` enumerates this list, so this order fixes the label ids.
disease_list.sort()
disease_list

In [None]:
# Map each lower-cased disease name to its position in `disease_list`.
disease_to_id_small = {
    disease_name.lower(): position
    for position, disease_name in enumerate(disease_list)
}

In [None]:
disease_to_id_small

In [None]:
# Tokenize every symptom name in one call, padded to a common length and
# returned as PyTorch tensors; the trailing expression shows the input_ids shape.
symptom_tokenizer = tokenizer(symptom_list, padding = True, truncation = True, return_tensors = 'pt')
symptom_tokenizer['input_ids'].shape

In [None]:
# Embed each symptom name with the BERT encoder, one row of the tokenized
# batch at a time, and store the squeezed second output of the model per name.
symptom_embedding_map = {}

# Inference only: without no_grad every stored embedding would keep the whole
# autograd graph of its forward pass alive for as long as the map exists.
with torch.no_grad():
    for i, key in enumerate(symptom_list):
        bert_output = bert_model(input_ids = symptom_tokenizer['input_ids'][i:i+1],
                                attention_mask = symptom_tokenizer['attention_mask'][i:i+1])
        # Index 1 of the model output — presumably the pooled [CLS]
        # representation for a BertModel; confirm against the transformers docs.
        pooled = bert_output[1]
        print(pooled.shape)
        temp5 = pooled.squeeze()  # drop the batch dimension of size 1
        print(temp5.shape)
        print(key)
        symptom_embedding_map[key] = temp5
        print(i)
        print()

In [None]:
# Tokenize every disease name in one call, padded to a common length and
# returned as PyTorch tensors; the trailing expression shows the input_ids shape.
disease_tokenizer = tokenizer(disease_list, padding = True, truncation = True, return_tensors = 'pt')
disease_tokenizer['input_ids'].shape

In [None]:
# Embed each disease name with the BERT encoder, one row of the tokenized
# batch at a time, and store the squeezed second output of the model per name.
disease_embedding_map = {}

# Inference only: without no_grad every stored embedding would keep the whole
# autograd graph of its forward pass alive for as long as the map exists.
with torch.no_grad():
    for i, key in enumerate(disease_list):
        bert_output = bert_model(input_ids = disease_tokenizer['input_ids'][i:i+1],
                                attention_mask = disease_tokenizer['attention_mask'][i:i+1])
        # Index 1 of the model output — presumably the pooled [CLS]
        # representation for a BertModel; confirm against the transformers docs.
        pooled = bert_output[1]
        temp5 = pooled.squeeze()  # drop the batch dimension of size 1
        print(pooled.shape)
        print(temp5.shape)
        disease_embedding_map[key] = temp5
        print(key)
        print(i)
        print()
    

In [None]:
def symptom_disease_edge_weight(begin, end):
    """Return the knowledge-graph weight stored for the edge begin -> end.

    `begin` and `end` are node ids used as keys of the notebook-level `kg`
    dict (`kg[begin][end]`). Raises KeyError when either key is missing.
    """
    neighbours = kg[begin]
    return neighbours[end]
    

In [None]:
def symptom_symptom_edge_weight(begin, end):
    """Return the knowledge-graph weight stored for the edge begin -> end.

    `begin` and `end` are node ids used as keys of the notebook-level `kg`
    dict (`kg[begin][end]`). The plain stored value is returned (not a
    tensor). Raises KeyError when either key is missing.
    """
    neighbours = kg[begin]
    return neighbours[end]
    

In [None]:
def make_edge(symp_list):
    """Build node numbering, edge index and edge weights for one dialogue graph.

    Node numbers are assigned as follows:
      * 0 is reserved for the dialogue node (source of the "cls" edges below);
      * the symptoms of `symp_list` get 1, 2, ... in list order;
      * every disease listed in `symptom_most_common_disease` for one of the
        symptoms gets the next free number, in set iteration order.

    Edges are appended in this order (weights are appended in lock-step):
      1. symptom -> symptom for every ordered pair of different names,
         weighted by `symptom_symptom_edge_weight`;
      2. symptom -> disease for each listed disease of each symptom, weighted
         by `symptom_disease_edge_weight`;
      3. node 0 -> symptom for every symptom, with weight 1.0.

    Args:
        symp_list: list of lower-case symptom names (keys of `sid_small`).

    Returns:
        A 4-tuple `(symptom_disease_numbering,
        symptom_disease_numbering_with_ids, edge_list, weight_list)` where the
        first dict maps symptom name -> node number, the second maps 'S_*' /
        'D_*' id -> node number, `edge_list` is a long tensor of shape
        (2, num_edges) and `weight_list` is a tensor of shape (num_edges,).

    Raises:
        KeyError: if a symptom name is missing from `sid_small`, a symptom id
            is missing from `symptom_most_common_disease`, or a weight is
            missing from the knowledge graph.
    """
    edge_list = []
    weight_list = []

    # Resolve every name once; the original order of symp_list is preserved.
    symptom_ids = [sid_small[name] for name in symp_list]

    symptom_disease_numbering = {}
    symptom_disease_numbering_with_ids = {}
    cnt = 1  # 0 is reserved for the dialogue node
    for name, symptom_id in zip(symp_list, symptom_ids):
        symptom_disease_numbering[name] = cnt
        symptom_disease_numbering_with_ids[symptom_id] = cnt
        cnt += 1

    # Collect the distinct diseases attached to the symptoms.
    temp_disease_set = set()
    for symptom_id in symptom_ids:
        print(symptom_id)
        print(symptom_most_common_disease[symptom_id])
        temp_disease_set.update(symptom_most_common_disease[symptom_id])
        print()
    print(temp_disease_set)

    for disease_id in temp_disease_set:
        symptom_disease_numbering_with_ids[disease_id] = cnt
        cnt += 1

    # 1. symptom -> symptom edges (both directions, compared by name).
    for name_a, sid_begin in zip(symp_list, symptom_ids):
        for name_b, sid_end in zip(symp_list, symptom_ids):
            if name_a != name_b:
                weight_list.append(symptom_symptom_edge_weight(sid_begin, sid_end))
                edge_list.append([
                    symptom_disease_numbering_with_ids[sid_begin],
                    symptom_disease_numbering_with_ids[sid_end],
                ])

    # 2. symptom -> disease edges.
    for sid_symp in symptom_ids:
        print("---------Symptom_Disease----------")
        print(sid_symp)
        most_common = symptom_most_common_disease[sid_symp]
        print(most_common)
        for did_end in most_common:
            weight_list.append(symptom_disease_edge_weight(sid_symp, did_end))
            edge_list.append([
                symptom_disease_numbering_with_ids[sid_symp],
                symptom_disease_numbering_with_ids[did_end],
            ])

    # 3. dialogue node -> symptom edges with unit weight.
    cls_dialog_token_id = 0
    for sid_symp in symptom_ids:
        weight_list.append(1.0)
        edge_list.append([cls_dialog_token_id, symptom_disease_numbering_with_ids[sid_symp]])

    # Reshape explicitly to (num_edges, 2) before transposing so that an empty
    # symptom list yields a well-formed (2, 0) edge index instead of a 1-D
    # tensor (calling .T on a 1-D tensor is also deprecated in PyTorch).
    num_edges = len(edge_list)
    edge_list = torch.tensor(edge_list, dtype = torch.long).reshape(num_edges, 2).t().contiguous()
    print("Edge list : ",edge_list.shape)
    weight_list = torch.tensor(weight_list)
    print("weight list : ",weight_list.shape)

    return symptom_disease_numbering, symptom_disease_numbering_with_ids, edge_list, weight_list
    
    

In [None]:
d_s_list = new_self_report_symptom_map[1059]
d_s_list

In [None]:
make_edge(d_s_list)

In [None]:
# Inverse of `sid_small`: 'S_*' id -> lower-case symptom name.
reverse_sid = {symptom_id: symptom_name for symptom_name, symptom_id in sid_small.items()}

In [None]:
reverse_sid

In [None]:
# Inverse of `did`: 'D_*' id -> disease name (original casing).
reverse_did = {disease_id: disease_name for disease_name, disease_id in did.items()}

In [None]:
reverse_did

In [None]:
# Load the dialogue ids. A later cell pairs them positionally (via enumerate)
# with the records of `final_combined_all_self_report`.
# NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
with open("new_dialog_id_list.p", "rb") as f:
    new_dialog_id_list = pickle.load(f)

In [None]:
new_dialog_id_list

In [None]:
len(new_dialog_id_list)

In [None]:
final_combined_all_self_report

In [None]:

# Pair every dialogue id with the record at the same position in
# `final_combined_all_self_report`: one map for the first utterance, one for
# the label.
dialog_id_dialog_map = {
    dialog_id: final_combined_all_self_report[position]['first_utterance']
    for position, dialog_id in enumerate(new_dialog_id_list)
}
dialog_label_id_map = {
    dialog_id: final_combined_all_self_report[position]['label']
    for position, dialog_id in enumerate(new_dialog_id_list)
}

In [None]:
len(dialog_id_dialog_map.keys())

In [None]:
dialog_id_dialog_map[1059]

In [None]:
dialog_label_id_map[1059]

In [None]:
new_self_report_symptom_map

In [None]:
new_dialog_id_list[2:9]

In [None]:
new_dialog_id_list[2:5]

In [None]:
new_dialog_id_list[5:9]

In [None]:
# Build one torch_geometric `Data` graph per dialogue for a slice of
# `new_dialog_id_list`. Node 0 carries the dialogue's first-token BERT
# embedding; the remaining nodes carry the pre-computed symptom / disease
# embeddings in the numbering produced by make_edge().

# Half-open slice of new_dialog_id_list handled by this run. The printed
# counter is 1-based, so it runs from DIALOG_START + 1 to DIALOG_END.
DIALOG_START, DIALOG_END = 950, 1000

full_dialog_data = []

# Inference only: without no_grad every graph would keep the autograd history
# of a full 512-token BERT forward pass alive inside its `x` tensor.
with torch.no_grad():
    for count, dlog_id in enumerate(new_dialog_id_list[DIALOG_START:DIALOG_END], start = DIALOG_START + 1):
        print("Count : ",count)

        # Deliberately NOT named `symptom_list`: that notebook-level name holds
        # the full symptom vocabulary used to build `symptom_embedding_map`,
        # and overwriting it here would break a re-run of that cell.
        dialog_symptoms = new_self_report_symptom_map[dlog_id]
        print(dialog_symptoms)
        symptom_disease_numbering, symptom_disease_numbering_with_ids, edge_list, weight_list = make_edge(dialog_symptoms)
        print(symptom_disease_numbering)
        print(symptom_disease_numbering_with_ids)

        curr_dialog = dialog_id_dialog_map[dlog_id]
        dialog_tokenizer = tokenizer(curr_dialog, padding='max_length', max_length = 512, truncation=True, return_tensors = 'pt')
        print("Dialog tokenizer shape : ", dialog_tokenizer['input_ids'].shape)

        output = bert_model(**dialog_tokenizer)

        # Hidden state at sequence position 0, keeping the batch dimension
        # (same values as `.permute(1,0,2)[0]`).
        dialog_embedding = output['last_hidden_state'][:, 0]
        print("Dialog embedding shape : ", dialog_embedding.shape)

        # Collect the node features and concatenate once at the end. The dict
        # is iterated in insertion order, which is the order in which
        # make_edge() handed out the node numbers.
        node_features = [dialog_embedding]
        for k in symptom_disease_numbering_with_ids:
            if(k[0]=='S'):
                print("K : ",k)
                symptom_name = reverse_sid[k]
                print("symptom_name : ",symptom_name)
                temp_embedding = symptom_embedding_map[symptom_name].unsqueeze(0)
                print("S_embedding shape : ", temp_embedding.shape)
                node_features.append(temp_embedding)

            elif(k[0]=='D'):
                print("K : ",k)
                disease_name = reverse_did[k]
                print("Disease_name : ", disease_name)
                temp_embedding = disease_embedding_map[disease_name].unsqueeze(0)
                print("D_embedding shape : ", temp_embedding.shape)
                node_features.append(temp_embedding)

        full_graph_embedding = torch.cat(node_features, dim = 0)
        print("Full_graph_embedding shape : ", full_graph_embedding.shape)

        disease_label = dialog_label_id_map[dlog_id]
        print("Disease label : ",disease_label)
        disease_label_number = disease_to_id_small[disease_label]
        # Stored as a float tensor of shape (1,), exactly as before, so that
        # previously saved datasets and this one stay compatible.
        disease_number_z = torch.zeros(1)
        disease_number_z[0] = disease_label_number

        print("Disease label number : ", disease_number_z)

        data = Data(x = full_graph_embedding, edge_index = edge_list, edge_attr = weight_list, y = disease_number_z)
        full_dialog_data.append(data)
        print("\n\n")
    print("\n\n")
    
    

In [None]:
len(full_dialog_data)

In [None]:
full_dialog_data[0].y

In [None]:
full_dialog_data[1].y

In [None]:
# MyOwnDataset.process() reads the notebook-level name `data_list`, so the
# graphs built above are exposed under that name (same list object, no copy).
data_list = full_dialog_data

In [None]:
class MyOwnDataset(InMemoryDataset):
    """In-memory dataset backed by the notebook-level `data_list`.

    `process()` collates `data_list` (after the optional `pre_filter` and
    `pre_transform`) and saves the result to the first processed path; the
    constructor then loads that file into `self.data` / `self.slices`.
    """

    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        processed_file = self.processed_paths[0]
        self.data, self.slices = torch.load(processed_file)

    @property
    def processed_file_names(self):
        """Name of the single processed file kept under the dataset root."""
        return ['data_951_to_1000.pt']

    def process(self):
        """Filter, transform, collate and save the graphs of `data_list`."""
        graphs = data_list

        # Keep only the graphs accepted by the optional filter.
        if self.pre_filter is not None:
            graphs = list(filter(self.pre_filter, graphs))

        # Apply the optional one-off transform to every remaining graph.
        if self.pre_transform is not None:
            graphs = list(map(self.pre_transform, graphs))

        collated_data, slice_index = self.collate(graphs)
        torch.save((collated_data, slice_index), self.processed_paths[0])

In [None]:
# Instantiate the dataset for this slice; the constructor loads the processed
# file found under `root` (torch.load on processed_paths[0]).
# NOTE(review): presumably InMemoryDataset runs process() first when that file
# does not exist yet — confirm against the PyTorch Geometric docs.
dataset = MyOwnDataset(root="all_self_report_knowledge_add_one_disease/dialog_951_to_1000/")
dataset