In [1]:
import sys
import cPickle as pickle
from datetime import datetime
import numpy as np
import random
import math

In [2]:
# icd-9编码处理

# 还原icd，保留最后两位小数点
def convert_to_icd9(dxStr):
    if dxStr.startswith('E'):
        if len(dxStr) > 4:
            return dxStr[:4] + '.' + dxStr[4:]
        else:
            return dxStr
    else:
        if len(dxStr) > 3:
            return dxStr[:3] + '.' + dxStr[3:]
        else:
            return dxStr
        
# 只取前三位数据，不保留小数点
def convert_to_3digit_icd9(dxStr):
    if dxStr.startswith('E'):
        if len(dxStr) > 4:
            return dxStr[:4]
        else:
            return dxStr
    else:
        if len(dxStr) > 3:
            return dxStr[:3]
        else:
            return dxStr

In [3]:
# 文件输入和输出路径
admissionFile = '数据/ADMISSIONS.csv'
diagnosisFile = '数据/DIAGNOSES_ICD.csv'
patientinfoFile='数据/PATIENTS.csv'
outFile = 'SEQ_full'

In [4]:
# pidAdmMap映射pid到admid，admDateMap映射admid到admtime
print 'Building pid-admission mapping, pid-death mapping, admission-date mapping'
all_death_num=0
pidAdmMap = {}
admDateMap = {}
pidDeathMap= {}
infd = open(admissionFile, 'r')
infd.readline()
for line in infd:
    tokens = line.strip().split(',')
    pid = int(tokens[1])
    admId = int(tokens[2])
    admTime = datetime.strptime(tokens[3], '%Y-%m-%d %H:%M:%S')
    admDateMap[admId] = admTime
    if pid in pidAdmMap:
        pidAdmMap[pid].append(admId)
    else:
        pidAdmMap[pid] = [admId]
    if tokens[5] != '':
        pidDeathMap[pid]=1
        all_death_num+=1
    else:
        if pid not in pidDeathMap:
            pidDeathMap[pid]=0      
infd.close()

Building pid-admission mapping, pid-death mapping, admission-date mapping


In [5]:
# pidGenBirthMap人的基本信息（性别、年龄、婚姻、人种、宗教），先使用性别和年龄
print 'Building patient information'
pidGenBirthMap={}
infd = open(patientinfoFile, 'r')
infd.readline();
for line in infd:
    tokens = line.strip().split(',')
    pid = int(tokens[1])
    if tokens[2]=='F':gender = 0
    else: gender = 1
    birth = datetime.strptime(tokens[3], '%Y-%m-%d %H:%M:%S')
    pidGenBirthMap[pid] = [gender,birth]
infd.close()

Building patient information


In [6]:
# 创建admission-dxList mapping，admDxMap映射admid到在这个住院中的诊断的icd9码
print 'Building admission-dxList mapping'
admDxMap = {}
admDxMap_3digit = {}
dx_icd={} #icd和dx的映射
dx_icd2={} #前面是dx码，后面是icd码
infd = open(diagnosisFile, 'r')
infd.readline()
raw_all_codes_num=0
for line in infd:
    raw_all_codes_num+=1
    tokens = line.strip().split(',')
    admId = int(tokens[2])
    dxStr = 'D_' + convert_to_icd9(tokens[4])
    dxStr_3digit = 'D_' + convert_to_3digit_icd9(tokens[4])
    
    if admId in admDxMap:
        admDxMap[admId].append(dxStr)
    else:
        admDxMap[admId] = [dxStr]

    if admId in admDxMap_3digit:
        admDxMap_3digit[admId].append(dxStr_3digit)
    else:
        admDxMap_3digit[admId] = [dxStr_3digit]
        
    if dxStr in dx_icd:
        continue
    else:
        dx_icd[dxStr]=tokens[4]
        dx_icd2[tokens[4]]=dxStr
infd.close()

Building admission-dxList mapping


In [7]:
# sortedList按照时间储存了[时间，icd9诊断]对，pidSeqMap存入每个病人的sortedList
print 'Building pid-sortedVisits mapping'
pidSeqMap = {}
pidSeqMap_3digit = {}
for pid, admIdList in pidAdmMap.iteritems():
    if len(admIdList) < 2: continue

    sortedList = sorted([(admDateMap[admId], admDxMap[admId]) for admId in admIdList])
    pidSeqMap[pid] = sortedList

    sortedList_3digit = sorted([(admDateMap[admId], admDxMap_3digit[admId]) for admId in admIdList])
    pidSeqMap_3digit[pid] = sortedList_3digit

Building pid-sortedVisits mapping


# 提取HF的病人记录

In [36]:
# 心力衰竭的病人记录
HF_patient={}
for i in pidSeqMap:
    for j in range(len(pidSeqMap[i])):
        code=pidSeqMap[i][j][1]
        for dx in code:
            if (dx == 'D_"42.80"') or (dx == 'D_"42.81"')or(dx == 'D_"42.82"')or(dx == 'D_"42.83"')or(dx == 'D_"42.84"')or (dx == 'D_"42.89"')or(dx == 'D_"42.820"')or(dx == 'D_"42.821"')or(dx == 'D_"42.822"')or(dx == 'D_"42.823"')or(dx == 'D_"42.830"')or(dx == 'D_"42.831"')or(dx == 'D_"42.832"')or(dx == 'D_"42.833"')or(dx == 'D_"42.840"')or(dx == 'D_"42.841"')or(dx == 'D_"42.842"')or(dx == 'D_"42.843"'):
                if i in HF_patient:
                    continue
                else:
                    HF_patient[i]=pidSeqMap[i]

In [37]:
len(HF_patient)

3370

In [74]:
# 心力衰竭的住院记录
HF_record={}
record_code=[]
last_adm=[]
n=0
for i in pidSeqMap:
    for j in range(len(pidSeqMap[i])):
        code=pidSeqMap[i][j][1]
        for dx in code:
            if (dx == 'D_"42.80"') or (dx == 'D_"42.81"')or(dx == 'D_"42.82"')or(dx == 'D_"42.83"')or(dx == 'D_"42.84"')or (dx == 'D_"42.89"')or(dx == 'D_"42.820"')or(dx == 'D_"42.821"')or(dx == 'D_"42.822"')or(dx == 'D_"42.823"')or(dx == 'D_"42.830"')or(dx == 'D_"42.831"')or(dx == 'D_"42.832"')or(dx == 'D_"42.833"')or(dx == 'D_"42.840"')or(dx == 'D_"42.841"')or(dx == 'D_"42.842"')or(dx == 'D_"42.843"'):
                if (j==len(pidSeqMap[i])-1):
                    last_adm.append(i)
                HF_record[n]=pidSeqMap[i][j]
                n=n+1
                record_code.append(i)           

In [75]:
len(last_adm)

4323

In [76]:
#记录的总个数
len(record_code)

10760

In [77]:
# 最后一次死亡的
death_code_num=0
for i in range(len(last_adm)):
    if pidDeathMap[last_adm[i]]==1:
        death_code_num+=1

In [78]:
death_code_num

983

# 提取Sepsis的病人记录

In [57]:
# 败血症衰竭的病人记录
Sep_patient={}
for i in pidSeqMap:
    for j in range(len(pidSeqMap[i])):
        code=pidSeqMap[i][j][1]
        for dx in code:
            if (dx == 'D_"99.592"') :
                if i in HF_patient:
                    continue
                else:
                    Sep_patient[i]=pidSeqMap[i]

In [58]:
#患病病人个数
len(Sep_patient)

624

In [88]:
# 败血症的住院记录
Sep_record={}
sep_record_code=[]
sep_last_ad=[]
n=0
for i in pidSeqMap:
    for j in range(len(pidSeqMap[i])):
        code=pidSeqMap[i][j][1]
        for dx in code:
            if (dx == 'D_"99.591"') or(dx == 'D_"99.592"') or (dx == 'D_"99.593"')or  (dx == 'D_"99.594"'):
                if (j==len(pidSeqMap[i])-1):
                    sep_last_ad.append(i)
                Sep_record[n]=pidSeqMap[i][j]
                n=n+1
                sep_record_code.append(i)

In [89]:
len(sep_record_code)

2499

In [90]:
len(sep_last_ad)

1260

In [91]:
# 死亡的
sep_death_code_num=0
for i in range(len(sep_last_ad)):
    if pidDeathMap[sep_last_ad[i]]==1:
        sep_death_code_num+=1

In [92]:
sep_death_code_num

561

In [27]:
1# 创建病人id--pid；visit时间--date；诊断--seq, 每个病人visit的个数--visits_num, 每个visit中code的个数codes_num，病人性别和出生日期patientsinfo
print 'Building pids, dates, strSeqs, visits_num, codes_num, patientsinfo, death_labels'
death_num=0 #住院中死亡的人数
pids = []
dates = []
seqs = []
visits_num = []
codes_num = []
all_codes_num = []
patientsinfo = [] # 性别和出生日期
death_labels=[]
for pid, visits in pidSeqMap.iteritems():
    pids.append(pid)
    visits_num.append([len(visits)])
    seq = []
    date = []
    code_num = []
    all_code_num=0
    patientsinfo.append(pidGenBirthMap[pid])
    if pidDeathMap[pid] == 1:
        death_labels.append([0,1])
        death_num+=1
    else: death_labels.append([1,0])
    for visit in visits:
        date.append(visit[0])
        seq.append(visit[1])
        code_num.append(len(visit[1]))
        all_code_num+=len(visit[1])
    dates.append(date)
    seqs.append(seq)
    codes_num.append(code_num)
    all_codes_num.append(all_code_num)
# 创建病人id字典，重新编号
type_pids={}
for i in range(len(pids)):
    type_pids[pids[i]]=i
    
# 3digital
print 'Building pids, dates, strSeqs for 3digit ICD9 code'
seqs_3digit = []
for pid, visits in pidSeqMap_3digit.iteritems():
    seq = []
    for visit in visits:
        seq.append(visit[1])
    seqs_3digit.append(seq)

Building pids, dates, strSeqs, visits_num, codes_num, patientsinfo, death_labels
Building pids, dates, strSeqs for 3digit ICD9 code


In [28]:
death_num

1462

In [29]:
# 对icd9重新编码，从0开始，获取newSeqs和types对应表
print 'Converting strSeqs to intSeqs, and making types'
types = {}
newSeqs = []
for patient in seqs:
    newPatient = []
    for visit in patient:
        newVisit = []
        for code in visit:
            if code in types:
                newVisit.append(types[code])
            else:
                types[code] = len(types)
                newVisit.append(types[code])
        newPatient.append(newVisit)
    newSeqs.append(newPatient)

# 同上
print 'Converting strSeqs to intSeqs, and making types for 3digit ICD9 code'
types_3digit = {}
newSeqs_3digit = []
for patient in seqs_3digit:
    newPatient = []
    for visit in patient:
        newVisit = []
        for code in set(visit):
            if code in types_3digit:
                newVisit.append(types_3digit[code])
            else:
                types_3digit[code] = len(types_3digit)
                newVisit.append(types_3digit[code])
        newPatient.append(newVisit)
    newSeqs_3digit.append(newPatient)

Converting strSeqs to intSeqs, and making types
Converting strSeqs to intSeqs, and making types for 3digit ICD9 code


In [30]:
# 把病人之间的记录用-1隔开
print 'Re-formatting seqs'
seqs = []
for patient in newSeqs:
    seqs.extend(patient)
    seqs.append([-1])
seqs = seqs[:-1]

seqs_3digit = []
for patient in newSeqs_3digit:
    seqs_3digit.extend(patient)
    seqs_3digit.append([-1])
seqs_3digit = seqs_3digit[:-1]

# 把时间用-1隔开
print 'Re-formatting dates'
dates2 = []
for d in dates:
    dates2.extend([d])
    dates2.append([-1])
dates2 = dates2[:-1]

Re-formatting seqs
Re-formatting dates


In [31]:
# 每个病人的每个visit的info,性别，年龄（当前visit的时间-病人出生日期）
visitspatientsinfo = []
for i in range(len(visits_num)):
    visitspatientinfo = []
    for j in range(visits_num[i][0]):
        age=(dates[i][j] - patientsinfo[i][1]).days/365
        if age==300: age=90
        visitpatientinfo = [patientsinfo[i][0], age]
        visitspatientinfo.append(visitpatientinfo)
    visitspatientsinfo.append(visitspatientinfo)

In [32]:
# visit的时间差，作为TMGRUAE网络的visit层的时间输入
print 'visit_delt_dates'
visit_delt_dates=[]
for date in dates:
    if date == [-1] : continue
    visit_delt_date=[]
    for d in range(len(date)):
        if d==0:visit_delt_date.append(0)
        else: visit_delt_date.append((date[d]-date[d-1]).days)
    visit_delt_dates.append(visit_delt_date)

visit_delt_dates


In [33]:
max_day=0
for i in range(len(visit_delt_dates)):
    for j in range (len(visit_delt_dates[i])):
        if max_day<visit_delt_dates[i][j]:
            max_day=visit_delt_dates[i][j]
        

In [34]:
max_day/12

343

In [35]:
# 一个patient的整个序列的时间，不再区分visit，作为TMGRUAE网络的encoder两层的时间输入
print 'code_delt_dates'
code_delt_dates=[]
for i in range(len(codes_num)):
    newdeltdates = []
    newdeltdate = []
    for j in range(len(codes_num[i])):
        newdeltdate.append (visit_delt_dates[i][j])
        for k in range(codes_num[i][j]-1):
            newdeltdate.append(0)
    newdeltdates.extend(newdeltdate)
    code_delt_dates.append(newdeltdates)
    
print 'codespatientsinfo'
codespatientsinfo=[]
for i in range(len(codes_num)):
    newcodespatientinfo = []
    newcodepatientinfo = []
    for j in range(len(codes_num[i])):
        newcodepatientinfo.append (visitspatientsinfo[i][j])
        for k in range(codes_num[i][j]-1):
            newcodepatientinfo.append([0,0])
    newcodespatientinfo.extend(newcodepatientinfo)
    codespatientsinfo.append(newcodespatientinfo)


code_delt_dates
codespatientsinfo


In [36]:
# code级别的每次住院的年龄
# 做差

In [37]:
print 'codeEach__dates'
codeEach__dates=[]
for i in range(len(visitspatientsinfo)):
    newdeltdates = []
    newdeltdate = []
    for j in range(len(codes_num[i])):
        newdeltdate.append(visitspatientsinfo[i][j][1])
        for k in range(codes_num[i][j]-1):
            newdeltdate.append(visitspatientsinfo[i][j][1])
    newdeltdates.extend(newdeltdate)
    codeEach__dates.append(newdeltdates)
    

codeEach__dates


In [38]:
codeEachOther_delt_dates=[]
for i in range(len(codeEach__dates)):
    patient_codedeltadate=[]
    for j in range(len(codeEach__dates[i])):
        eachcode=[]
        for k in range(len(codeEach__dates[i])):
            eachcode.append(float(abs(codeEach__dates[i][k]-codeEach__dates[i][j])))
        patient_codedeltadate.append(eachcode)
        
    codeEachOther_delt_dates.append(patient_codedeltadate)
            

In [39]:
# 一个病人所有的code，不区分visit
print 'patient_code'
patient_code = []
for patient in newSeqs:
    visit_code=[]
    for visit in patient:
        visit_code.extend(visit)
    patient_code.append(visit_code)

patient_codes_3digit = []
for patient in newSeqs_3digit:
    visit_code=[]
    for visit in patient:
        visit_code.extend(visit)
    patient_codes_3digit.append(visit_code)

patient_code


In [40]:
types2 = {value:key for key,value in types.items()}
types_3digit2= {value:key for key,value in types_3digit.items()}
type_pids2= {value:key for key,value in type_pids.items()}

In [41]:
# 保存序列化
pickle.dump(patient_code, open(outFile + '/patient_code' + '.seqs', 'wb'), -1) # 一个病人所有的code，不区分visit
pickle.dump(patientsinfo, open(outFile + '/patientsinfo' + '.seqs', 'wb'), -1) # 病人的信息（性别和出生日期）
pickle.dump(visitspatientsinfo, open(outFile + '/visitspatientsinfo' + '.seqs', 'wb'), -1)  # 每个病人的每个visit的info 
pickle.dump(codespatientsinfo, open(outFile + '/codespatientsinfo' + '.seqs', 'wb'), -1)  # 每个病人的每个code的info和code_delt_dates类似
pickle.dump(code_delt_dates, open(outFile + '/code_delt_dates' + '.seqs', 'wb'), -1)  # code的时间差，一个patient的所有code的时间差
pickle.dump(visit_delt_dates, open(outFile + '/visit_delt_dates' + '.seqs', 'wb'), -1) #visit的时间差，一个patient的visit第一个时间差是0
pickle.dump(visits_num, open(outFile + '/visits_num' + '.seqs', 'wb'), -1)  # 每个patient的visit的个数visits_num
pickle.dump(codes_num, open(outFile + '/codes_num' + '.seqs', 'wb'), -1)  # 每个visit下code的个数codes_num
pickle.dump(all_codes_num, open(outFile + '/all_codes_num' + '.seqs', 'wb'), -1)  # 每个patient下code的个数codes_num
pickle.dump(pidSeqMap, open(outFile + '/pidSeqMap' + '.seqs', 'wb'), -1)  # pidSeqMap
pickle.dump(pids, open(outFile + '/pids' + '.seqs', 'wb'), -1)  # pids
pickle.dump(dates, open(outFile + '/dates' + '.seqs', 'wb'), -1)  # dates
pickle.dump(dates2, open(outFile + '/dates2' + '.seqs', 'wb'), -1)  # dates2
pickle.dump(newSeqs, open(outFile + '/newSeqs' + '.seqs', 'wb'), -1) # 每个病人的visit下的code
pickle.dump(seqs, open(outFile + '/seqs' + '.seqs', 'wb'), -1) # 每个病人的visit下的code,-1隔开
pickle.dump(types, open(outFile + '/types' + '.seqs', 'wb'), -1) # code的新编码和原编码对照表
pickle.dump(types2, open(outFile + '/types2' + '.seqs', 'wb'), -1) # code的新编码和原编码对照表
pickle.dump(type_pids, open(outFile + '/type_pids' + '.seqs', 'wb'), -1) #type_pids
pickle.dump(type_pids2, open(outFile + '/type_pids2' + '.seqs', 'wb'), -1) #type_pids
pickle.dump(death_labels, open(outFile + '/death_labels' + '.seqs', 'wb'), -1) #death_labels
pickle.dump(codeEachOther_delt_dates, open(outFile + '/codeEachOther_delt_dates' + '.seqs', 'wb'), -1) #每两个code之间的时间差

pickle.dump(patient_codes_3digit, open(outFile + '/patient_codes_3digit' + '.seqs', 'wb'), -1) # 一个病人所有的3code，不区分visit
pickle.dump(pidSeqMap_3digit, open(outFile + '/pidSeqMap_3digit' + '.3digitICD9.types', 'wb'), -1)
pickle.dump(pids, open(outFile + '/pids' + '.3digitICD9.types', 'wb'), -1)
pickle.dump(dates, open(outFile + '/dates' + '.3digitICD9.types', 'wb'), -1)
pickle.dump(seqs_3digit, open(outFile + '/seqs'+ '.3digitICD9.seqs', 'wb'), -1)
pickle.dump(types_3digit, open(outFile + '/types' + '.3digitICD9.seqs', 'wb'), -1)
pickle.dump(types_3digit2, open(outFile + '/types' + '.3digit2ICD9.seqs', 'wb'), -1)

IOError: [Errno 2] No such file or directory: 'SEQ_full/patient_code.seqs'

In [25]:
#生成一个patient的padding 后的code的彼此的时间差
def get_codeEachOther_delt_dates(no,visit_num,code_num):
    
    codeEachOther_delt_dates_file=open(outFile + '/codeEachOther_delt_dates' + '.seqs','rb')
    codeEachOther_delt_dates=pickle.load(codeEachOther_delt_dates_file)
    
    codeEachOther_delt_dates_file.close()
    a=np.zeros([visit_num*code_num,visit_num*code_num])
    for i in range(visit_num*code_num):
        for j in range(visit_num*code_num):

            indexi=0
            indexj=0

            if codes_num[no][i/code_num] >= code_num:#j/code_num是在那个visit里面
                indexi=i%code_num
                for k in range(i/code_num):
                    indexi+=codes_num[no][k]

                if codes_num[no][j/code_num]>=code_num:
                    indexj=j%code_num
                    for k in range(j/code_num):
                        indexj+=codes_num[no][k]
                    a[i][j]=codeEachOther_delt_dates[no][indexi][indexj]
                else:
                    indexj=j%code_num
                    if indexj>(codes_num[no][j/code_num]-1): 
                        a[i][j]=0
                        continue
                    else:
                        for k in range(j/code_num):
                            indexj+=codes_num[no][k]
                        a[i][j]=codeEachOther_delt_dates[no][indexi][indexj]

            else:
                indexi=i%code_num
                if indexi>(codes_num[no][i/code_num]-1): 
                    a[i][j]=0
                    continue
                else:
                    for k in range(i/code_num):
                        indexi+=codes_num[no][k]
                if codes_num[no][j/code_num] >= code_num:
                    indexj=j%code_num
                    for k in range(j/code_num):
                        indexj+=codes_num[no][k]
                    a[i][j]=codeEachOther_delt_dates[no][indexi][indexj]  
                else:
                    indexj=j%code_num
                    if indexj > (codes_num[no][j/code_num]-1): 
                        a[i][j]=0
                        continue
                    else:
                        for k in range(j/code_num):
                            indexj+=codes_num[no][k]

                        a[i][j]=codeEachOther_delt_dates[no][indexi][indexj]
    
    return a
    
    
    # batch生成器
    # bacth中有cutting和padding，由于长度不同，因此开始学习的时候可以每个数据单独学习
    # 生成的batch里的病人visit和code的个数一样
def batch_generator(outFile, visit_num, code_num, patient_num, batch_size):
    patient_code_file=open(outFile + '/patient_code' + '.seqs','rb')
    patient_code=pickle.load(patient_code_file)

    codespatientsinfo_file=open(outFile + '/codespatientsinfo' + '.seqs','rb')
    codespatientsinfo=pickle.load(codespatientsinfo_file)

    visit_delt_dates_file=open(outFile + '/visit_delt_dates' + '.seqs','rb')
    visit_delt_dates=pickle.load(visit_delt_dates_file)

    code_delt_dates_file=open(outFile + '/code_delt_dates' + '.seqs','rb')
    code_delt_dates=pickle.load(code_delt_dates_file)



    visits_num_file=open(outFile + '/visits_num' + '.seqs','rb')
    visits_num=pickle.load(visits_num_file)

    codes_num_file=open(outFile + '/codes_num' + '.seqs','rb')
    codes_num=pickle.load(codes_num_file)

    death_labels_file=open(outFile + '/death_labels' + '.seqs','rb')
    death_labels=pickle.load(death_labels_file)

    patient_code_file.close()
    codespatientsinfo_file.close()
    visit_delt_dates_file.close()
    code_delt_dates_file.close()
    visits_num_file.close()
    codes_num_file.close()

    batch_patient_code=[]
    batch_code_delt_dates=[]
    batch_codespatientsinfo=[]
    batch_visit_delt_dates=[]
    batch_labels=[]
    batch_codeEachOther_delt_dates=[]

    # padding and cutting
    for i in range(batch_size):
        j=random.randint(0, patient_num-1)
        if visits_num[j][0]== visit_num:
           # print 'pick',j
            code=[]
            code_date=[]
            code_info=[]
            codeEachOther_dates=[]
            for k in range(visit_num):
                if  codes_num[j][k]> code_num:
                    if k==0: 
                        code.extend(patient_code[j][0:code_num])
                        code_date.extend(code_delt_dates[j][0:code_num])
                        code_info.extend(codespatientsinfo[j][0:code_num])

                    else:
                        start=k*codes_num[j][k-1]
                        code.extend(patient_code[j][start :start+code_num])
                        code_date.extend(code_delt_dates[j][start:start+code_num])
                        code_info.extend(codespatientsinfo[j][start:start+code_num])


                elif codes_num[j][k] < code_num:
                    if k==0:
                        code.extend(patient_code[j][:codes_num[j][k]])
                        code.extend([0]*(code_num-codes_num[j][k]))
                        code_date.extend(code_delt_dates[j][:codes_num[j][k]])
                        code_date.extend([0]*(code_num-codes_num[j][k]))
                        code_info.extend(codespatientsinfo[j][:codes_num[j][k]])
                        code_info.extend([[0,0]]*(code_num-codes_num[j][k]))

                    else:
                        start2=0
                        for n in range(k):
                            start2+=codes_num[j][n]
                        code.extend(patient_code[j][start2: start2+codes_num[j][k]])
                        code.extend([0]*(code_num-codes_num[j][k]))
                        code_date.extend(code_delt_dates[j][start2:start2+codes_num[j][k]])
                        code_date.extend([0]*(code_num-codes_num[j][k]))
                        code_info.extend(codespatientsinfo[j][start2:start2+codes_num[j][k]])
                        code_info.extend([[0,0]]*(code_num-codes_num[j][k]))
                else:
                    code.extend(patient_code[j])
                    code_date.extend(code_delt_dates[j])
                    code_info.extend(codespatientsinfo[j])

            batch_patient_code.append(code)
            batch_code_delt_dates.append(code_date)  
            batch_codespatientsinfo.append(code_info)
            batch_labels.append(death_labels[j])
            batch_codeEachOther_delt_dates.append(get_codeEachOther_delt_dates(j,visit_num,code_num))
    return batch_patient_code, batch_codespatientsinfo, batch_code_delt_dates,batch_labels,batch_codeEachOther_delt_dates
        

In [26]:
outFile='SEQ'
batch_size=4
visit_num=2 #visit_num=random.randint() 
code_num=10 #code_num=random.randint() # code_num=visit_length
patient_num=14 #病人的总个数
xindex, info, t,batch_labels,batch_codeEachOther_delt_dates = batch_generator(outFile, visit_num, code_num, patient_num,batch_size)

In [28]:
a=batch_codeEachOther_delt_dates[:][:(visit_num-1)*code_num][:(visit_num-1)*code_num]

In [31]:
len(a[0][0])

20

In [None]:
admissionFile = '数据/ADMISSIONS.csv'