In [0]:
pip install keras

In [0]:
pip install tensorflow

In [0]:
from pyspark import SparkContext, SparkConf
from pyspark.sql.types import *
import pandas as pd
import pickle

def _load_pickle(path):
    """Unpickle and return the single object stored in the file at `path`."""
    # NOTE: pickle.load can execute arbitrary code; only use it on trusted files.
    # The context manager closes the handle, which the bare open() calls never did.
    with open(path, 'rb') as handle:
        return pickle.load(handle)

# ICD-9 lookup objects pickled by an earlier step of the project.
ICD9CODES = _load_pickle("/dbfs/FileStore/data_2/ICD9CODES.p")
ICD9CODES_TOP10 = _load_pickle("/dbfs/FileStore/data_2/ICD9CODES_TOP10.p")
ICD9CODES_TOP50 = _load_pickle("/dbfs/FileStore/data_2/ICD9CODES_TOP50.p")
ICD9CAT_TOP10 = _load_pickle("/dbfs/FileStore/data_2/ICD9CAT_TOP10.p")
ICD9CAT_TOP50 = _load_pickle("/dbfs/FileStore/data_2/ICD9CAT_TOP50.p")

In [0]:
from pyspark.ml.feature import StopWordsRemover
# Stop-word list for the Spark pipeline: Spark's default English stop words
# followed by every entry of ICD9CODES, each coerced to str.
_default_stopwords = StopWordsRemover.loadDefaultStopWords("english")
STOPWORDS_v0 = list(map(str, _default_stopwords + ICD9CODES))

In [0]:
from pyspark.ml.feature import HashingTF, IDF, RegexTokenizer, StopWordsRemover

def create_TFIDF_v0(trainData, applyData, inputCol="text", outputCol="features", minDocFreq=3, numFeatures=20):
    """Compute hashed TF-IDF features for `applyData` with IDF weights fit on `trainData`.

    Both inputs are Spark DataFrames holding the raw text in `inputCol`. Each is
    tokenized with RegexTokenizer, filtered with StopWordsRemover (using
    STOPWORDS_v0) and hashed into `numFeatures` buckets with HashingTF. The IDF
    model is fit on the training term frequencies only and then applied to
    `applyData`.

    Returns `applyData` with the new `outputCol` column and without `inputCol`
    or the intermediate z_* columns.
    """
    # Raw string: "\s" inside a plain literal is an invalid escape sequence.
    tokenizer = RegexTokenizer(pattern=r"[.:\s]+", inputCol=inputCol, outputCol="z_words")
    remover = StopWordsRemover(inputCol="z_words", outputCol="z_filtered", stopWords=STOPWORDS_v0)
    # alternatively, CountVectorizer can also be used to get term frequency vectors
    hashingTF = HashingTF(inputCol="z_filtered", outputCol="z_rawFeatures", numFeatures=numFeatures)

    def _term_frequencies(frame):
        # tokenize -> drop stop words -> hashed term frequencies
        return hashingTF.transform(remover.transform(tokenizer.transform(frame)))

    idf = IDF(inputCol="z_rawFeatures", outputCol=outputCol, minDocFreq=minDocFreq)
    idfModel = idf.fit(_term_frequencies(trainData))

    rescaledData = idfModel.transform(_term_frequencies(applyData))
    return rescaledData.drop("z_words", "z_filtered", "z_rawFeatures", inputCol)

In [0]:
import numpy as np
import re
from sklearn.feature_extraction.text import TfidfVectorizer
from pyspark.mllib.util import Vectors
from pyspark.mllib.linalg import VectorUDT
from pyspark.sql.functions import UserDefinedFunction

def preprocessor_v1(text):
    """Normalize a raw note for TF-IDF.

    Removes ``[** ... **]`` spans and ``<...>`` tags, lower-cases the text and
    collapses every run of non-word characters into a single space.

    Returns the cleaned string (it may keep a leading or trailing space).
    """
    # Raw strings avoid the invalid-escape warnings the plain literals triggered.
    text = re.sub(r'\[\*\*[^\]]*\*\*\]', '', text)
    text = re.sub(r'<[^>]*>', '', text)
    text = re.sub(r'[\W]+', ' ', text.lower())
    return text

def create_TFIDF_v1(df_train, df_apply, inputCol="text", outputCol="features",
                    minDocFreq=3, maxDocFreq=1.0, numFeatures=20):
    """Fit a scikit-learn TfidfVectorizer on `df_train` and featurize `df_apply`.

    Both frames are modified in place: a temporary 'z_cleaned' column is added
    to and later deleted from each, and `inputCol` is deleted from `df_apply`.
    `df_apply[outputCol]` receives one sparse row of the document-term matrix
    per record.

    Returns `df_apply` (the same object that was passed in).

    NOTE(review): STOPWORDS_v1 is not defined in the visible part of this
    notebook (only STOPWORDS_v0 is) - confirm where it comes from.
    """
    df_train['z_cleaned'] = df_train[inputCol].apply(preprocessor_v1)
    df_apply['z_cleaned'] = df_apply[inputCol].apply(preprocessor_v1)

    # Now we create the sparse matrix of tfidf values
    tfidf = TfidfVectorizer(input='content',ngram_range=(1, 1),
                            stop_words=STOPWORDS_v1, 
                            min_df=minDocFreq,
                            max_df=maxDocFreq,
                            max_features=numFeatures)
    # Stop words are removed, and min_df drops terms that appear in fewer than
    # `minDocFreq` documents (default 3, not 10 as an older comment stated).

    # The vocabulary and IDF weights come from the training text only.
    tfidf.fit([c for c in df_train['z_cleaned']])
    dtm = tfidf.transform([c for c in df_apply['z_cleaned']]).tocsr()
    dtm.sort_indices()
    df_apply[outputCol] = list(dtm)

    # Drop the temporary column from both frames and the raw text from df_apply.
    del df_train['z_cleaned']
    del df_apply['z_cleaned']
    del df_apply[inputCol]

    return df_apply

In [0]:
import nltk
nltk.download('stopwords')
from nltk.corpus import stopwords

# Sanity check: size of NLTK's English stop-word list.
print(len(stopwords.words('english')))

In [0]:
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from nltk.corpus import stopwords

# Stop words for the word2vec path: NLTK's English list followed by the ICD9CODES entries.
STOPWORDS_WORD2VEC = stopwords.words('english') + ICD9CODES

In [0]:
def preprocessor_word2vec(text):
    """Normalize a raw note before tokenization for the word2vec model.

    Removes ``[** ... **]`` spans and ``<...>`` tags, lower-cases the text,
    collapses every run of non-word characters into one space, and finally
    replaces each space followed by digits with a single space.

    Returns the cleaned string.
    """
    # Raw strings avoid the invalid-escape warnings the plain literals triggered.
    text = re.sub(r'\[\*\*[^\]]*\*\*\]', '', text)
    text = re.sub(r'<[^>]*>', '', text)
    text = re.sub(r'[\W]+', ' ', text.lower())
    text = re.sub(r" \d+", " ", text)

    return text

def create_WORD2VEC_DL_V0(df, max_sequence_len=600, inputCol='text'):
    """Turn the text column of `df` into padded integer token sequences.

    The text in `inputCol` is cleaned with preprocessor_word2vec, a keras
    Tokenizer is fit on it, and the resulting sequences are passed through
    pad_sequences with `maxlen=max_sequence_len`. The mean sequence length
    before padding and the vocabulary size are printed.

    Returns (data, word_index, reverse_word_index): the padded sequences, the
    tokenizer's word -> index dict, and the inverse index -> word dict.
    """
    cleaned_texts = df[inputCol].apply(preprocessor_word2vec)

    tokenizer = Tokenizer()
    tokenizer.fit_on_texts(cleaned_texts)
    sequences = tokenizer.texts_to_sequences(cleaned_texts)

    # Mean number of tokens per document before padding.
    lengths = list(map(len, sequences))
    print(1.0 * sum(lengths) / len(lengths))

    word_index = tokenizer.word_index
    # Inverse mapping, e.g. {1: 'the', 2: 'a', ...}
    reverse_word_index = {idx: token for token, idx in word_index.items()}

    print('Found %s unique tokens.' % len(word_index))

    return pad_sequences(sequences, maxlen=max_sequence_len), word_index, reverse_word_index

def create_EmbeddingMatrix_V0(word_index, word2vec_model_path, remove_stopwords=True):
    """Build an embedding matrix aligned with `word_index` from a word2vec text file.

    Parameters
    ----------
    word_index : dict
        Maps each word to its integer row index in the returned matrix.
    word2vec_model_path : str
        Text file with one entry per line: the word, then its whitespace-separated
        vector components.
    remove_stopwords : bool
        When True, words listed in STOPWORDS_WORD2VEC are treated as missing
        and therefore keep an all-zero row.

    Returns
    -------
    numpy.ndarray of shape (len(word_index) + 1, embedding_dim). Rows of words
    without a usable pre-trained vector stay all-zero.

    Raises
    ------
    ValueError
        If the file contains no vectors.
    """
    embeddings_index = {}
    # The context manager closes the file even if a line fails to parse.
    with open(word2vec_model_path) as f:
        for line in f:
            values = line.split()
            if not values:
                continue  # tolerate blank lines instead of failing on values[0]
            embeddings_index[values[0]] = np.asarray(values[1:], dtype='float32')

    print('Found %s word vectors.' % len(embeddings_index))

    if not embeddings_index:
        raise ValueError('No word vectors found in %s' % word2vec_model_path)

    if remove_stopwords:
        # Delete stopwords and ICD9 codes from pre-trained dictionary,
        # so they will be zeros when we create embedding_matrix.
        # A set makes each membership test O(1) instead of a list scan.
        excluded = set(STOPWORDS_WORD2VEC)
        index2word_set = {word for word in embeddings_index if word not in excluded}
    else:
        index2word_set = set(embeddings_index)

    # dict views are not subscriptable on Python 3, so `.values()[0]` fails there.
    EMBEDDING_DIM = next(iter(embeddings_index.values())).size  # dimensions of the word2vec model

    embedding_matrix = np.zeros((len(word_index) + 1, EMBEDDING_DIM))
    for word, i in word_index.items():
        if word in index2word_set:
            # words not found in embedding index will be all-zeros.
            embedding_matrix[i] = embeddings_index[word]

    return embedding_matrix

In [0]:
# Select the distinct rows of tempdb.hadm_top_50_icd_step_5: row_id (renamed to id),
# the 50 code columns (presumably per-note ICD-9 label indicators - confirm) and the note text.
df1_sp = spark.sql('''
select distinct `row_id` as id,
`0389`,
`2449`,
`25000`,
`2720`,
`2724`,
`2760`,
`2761`,
`2762`,
`2767`,
`2851`,
`2859`,
`2875`,
`3051`,
`311`,
`32723`,
`4019`,
`40390`,
`40391`,
`41071`,
`412`,
`41401`,
`4168`,
`4240`,
`4241`,
`42731`,
`42789`,
`4280`,
`486`,
`49390`,
`496`,
`5070`,
`5119`,
`5180`,
`51881`,
`53081`,
`5845`,
`5849`,
`5859`,
`5990`,
`7742`,
`78552`,
`99592`,
`9971`,
`V053`,
`V1582`,
`V290`,
`V4581`,
`V4582`,
`V5861`,
`V5867`,
`text`
--`clean_text`
from tempdb.hadm_top_50_icd_step_5
''')

# Collect the whole result to the driver as a pandas DataFrame.
df1_pd = df1_sp.toPandas()

In [0]:
# Tokenize the note text and pad every sequence to 500 token ids; also keep the
# word -> index map and its inverse.
data, word_index, reverse_word_index = create_WORD2VEC_DL_V0(df1_pd.copy(), max_sequence_len=500)

In [0]:
import pandas as pd
# Wrap the padded token-id array in a DataFrame so it can be joined with the ids later.
p_data = pd.DataFrame(data)

In [0]:
# em = create_EmbeddingMatrix_V0(word_index, "/dbfs/FileStore/data/word2vec.txt", remove_stopwords=True)

In [0]:
# Parse the pre-trained word2vec text file into {word: float32 vector}.
embeddings_index = {}
remove_stopwords = True
# The context manager closes the file even if a line fails to parse.
with open("/dbfs/FileStore/data/word2vec.txt") as f:
    for line in f:
        values = line.split()
        if not values:
            continue  # tolerate blank lines instead of failing on values[0]
        word = values[0]
        coefs = np.asarray(values[1:], dtype='float32')
        embeddings_index[word] = coefs
            

In [0]:
# Report how many entries were parsed from the word2vec file.
print('Found %s word vectors.' % len(embeddings_index))

In [0]:
if remove_stopwords:
    # Delete stopwords and ICD9 codes from pre-trained dictionary,
    # so they will be zeros when we create embedding_matrix.
    # A set makes each membership test O(1) instead of scanning the list per word.
    _excluded_words = set(STOPWORDS_WORD2VEC)
    keys_updated = [word for word in embeddings_index.keys() if word not in _excluded_words]
    index2word_set = set(keys_updated)
else:
    index2word_set = set(embeddings_index.keys())

In [0]:
# NOTE(review): hard-coded; it has to equal the vector length stored in word2vec.txt - confirm.
EMBEDDING_DIM = 300  # dimensions of the word2vec model

In [0]:
# list(embeddings_index.values())

In [0]:

# One row per tokenizer index (plus one extra row); a row is filled only when its
# word has a usable pre-trained vector, every other row stays all-zero.
embedding_matrix = np.zeros((len(word_index) + 1, EMBEDDING_DIM))
for token, row in word_index.items():
    if token not in index2word_set:
        continue
    embedding_matrix[row] = embeddings_index.get(token)

In [0]:
# Display the matrix as the cell output for a quick visual check.
embedding_matrix

In [0]:
# NOTE(review): this binds a second name to the same array, it does not copy it;
# in-place changes to embedding_matrix will also show up here - confirm intent.
embedding_matrix_safe = embedding_matrix

In [0]:
import pickle
def output_pickle(obj, fname):
    """Serialize `obj` to the file `fname` using the highest pickle protocol.

    Any existing file at `fname` is overwritten.
    """
    # The context manager closes the file even when pickle.dump raises.
    with open(fname, 'wb') as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)

In [0]:
# Persist the embedding matrix itself; the previous call passed the string
# "embedding_matrix", so the pickle file contained only that name.
output_pickle(embedding_matrix, "/dbfs/FileStore/data/EMBMATRIXV1_WORD2VEC_v2_300dim.p")

In [0]:
# Put the note ids next to their padded token sequences; both sides get a fresh
# 0..n-1 index so the column-wise concat aligns them by position.
_id_frame = df1_pd[["id"]].reset_index(drop=True)
_sequence_frame = p_data.reset_index(drop=True)
combo = pd.concat([_id_frame, _sequence_frame], axis=1)

In [0]:
# Column labels for combo: 'id' followed by the sequence positions '0'..'499' as strings.
col_ls = ['id'] + [str(position) for position in range(0, 500)]

combo.columns = col_ls
  

In [0]:
# Preview the first rows of the id + token-sequence frame.
combo.head()

Unnamed: 0,id,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,...,460,461,462,463,464,465,466,467,468,469,470,471,472,473,474,475,476,477,478,479,480,481,482,483,484,485,486,487,488,489,490,491,492,493,494,495,496,497,498,499
0,1334282,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...,5446,5,525,12,548,4,31,11,5932,128,466,3610,4143,631,25373,4,25,190,65,37,13,2519,1,399,610,106,1727,4,13,12,53,1,31,4,66,1183,1495,4627,5978,2597
1,902355,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...,105,199,233,21,184,362,3246,880,804,140,2399,935,772,70,772,2146,18,359,29,531,838,45,1218,735,4,3,22,1543,365,735,4,3,22,799,436,142,2225,393,1381,628
2,1751765,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...,600,655,54,20,794,421,4392,8,3,1220,5473,112,51,14,4682,241,32,51,47,18264,7,81,22,1214,51,2347,13805,717,3,394,25,61,677,9608,2,1169,981,97,6,421
3,1382677,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,...,8,554,102,452,193,2,3268,13,2,1819,11,443,267,664,269,8,329,3578,4,25,69,736,36,1,7384,270,1681,1,58,414,92,53,139,53,190,28,136,18,102,1877
4,424127,1987,811,207,4216,54,13728,11266,3143,1092,193,2,75,3291,6560,6139,2093,628,12,25,1397,2713,4,22,2171,16156,203,419,401,369,30,1657,462,1329,1214,881,1916,506,204,117,...,587,228,427,247,720,592,354,537,101,16,283,336,1025,484,417,847,483,1170,167,480,24,710,8,2166,1432,104,189,53,197,189,379,104,124,111,1074,789,24,10,1822,1369


In [0]:
# Auxiliar functions
def equivalent_type(f):
    """Return the Spark SQL type matching a pandas/numpy dtype (or its string name).

    datetime64[ns] -> TimestampType, int64 -> LongType, int32 -> IntegerType,
    float64 -> DoubleType; any other dtype falls back to StringType.
    """
    if f == 'datetime64[ns]':
        return TimestampType()
    if f == 'int64':
        return LongType()
    if f == 'int32':
        return IntegerType()
    if f == 'float64':
        # float64 is double precision; FloatType is 32-bit and would lose precision.
        return DoubleType()
    return StringType()

def define_structure(string, format_type):
    """Build a StructField named `string` whose type is derived from `format_type`.

    `format_type` is a pandas/numpy dtype; if it cannot be mapped the field
    falls back to StringType.
    """
    try:
        typo = equivalent_type(format_type)
    except Exception:
        # Narrower than a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
        typo = StringType()
    return StructField(string, typo)

# Given pandas dataframe, it will return a spark's dataframe.
def pandas_to_spark(pandas_df):
    """Convert a pandas DataFrame to a Spark DataFrame with an explicit schema.

    One StructField is built per column from the column name and its pandas
    dtype (see define_structure); the frame is then created through the global
    `sqlContext`.
    """
    fields = [
        define_structure(name, dtype)
        for name, dtype in zip(list(pandas_df.columns), list(pandas_df.dtypes))
    ]
    return sqlContext.createDataFrame(pandas_df, StructType(fields))

In [0]:
# Convert the id + token-sequence frame to a Spark DataFrame with an explicit schema.
spark_seq_df = pandas_to_spark(combo)

In [0]:
# Expose the frame to SQL cells. createOrReplaceTempView replaces the deprecated
# registerTempTable and can be re-run without error.
spark_seq_df.createOrReplaceTempView('spark_seq_df')

In [0]:
%sql
-- Persist the temp view spark_seq_df as the table tempdb.spark_seq_df.
-- NOTE(review): there is no IF NOT EXISTS or prior DROP, so re-running this cell
-- presumably fails once the table exists - confirm.
create table tempdb.spark_seq_df as (
select * 
from spark_seq_df
)

num_affected_rows,num_inserted_rows


In [0]:
# Training split: join the notes table with the split lookup (group_type = 'train')
# and the persisted token sequences. The select list is generated instead of
# hand-enumerating all 500 sequence columns.
_TOP10_LABEL_COLS = ["4019", "2724", "25000", "4280", "41401",
                     "53081", "51881", "42731", "5849", "5990"]
_SEQ_LEN = 500  # number of padded token positions stored as columns `0`..`499`
_train_select_list = ",\n  ".join(
    ["`id`"]
    + ["`%s`" % code for code in _TOP10_LABEL_COLS]
    + ["c.`%d`" % position for position in range(_SEQ_LEN)]
)

spark_seq_train_df = spark.sql(
'''
select distinct 

  %s
from tempdb.hadm_top_50_icd_step_5 a 
inner join tempdb.hadm_dataset_lookup b  on a.row_id = b.row_id  and b.group_type = 'train' 
inner join tempdb.spark_seq_df c  on c.id = b.row_id
''' % _train_select_list
)

In [0]:
# Test split: join the notes table with the split lookup (group_type = 'test')
# and the persisted token sequences. The select list is generated instead of
# hand-enumerating all 500 sequence columns.
_TOP10_LABEL_COLS = ["4019", "2724", "25000", "4280", "41401",
                     "53081", "51881", "42731", "5849", "5990"]
_SEQ_LEN = 500  # number of padded token positions stored as columns `0`..`499`
_test_select_list = ",\n  ".join(
    ["`id`"]
    + ["`%s`" % code for code in _TOP10_LABEL_COLS]
    + ["c.`%d`" % position for position in range(_SEQ_LEN)]
)

spark_seq_test_df = spark.sql(
'''
select distinct 

  %s
from tempdb.hadm_top_50_icd_step_5 a 
inner join tempdb.hadm_dataset_lookup b  on a.row_id = b.row_id  and b.group_type = 'test' 
inner join tempdb.spark_seq_df c  on c.id = b.row_id
''' % _test_select_list
)

In [0]:
# Validation split: join the notes table with the split lookup (group_type = 'valid')
# and the persisted token sequences. The select list is generated instead of
# hand-enumerating all 500 sequence columns.
_TOP10_LABEL_COLS = ["4019", "2724", "25000", "4280", "41401",
                     "53081", "51881", "42731", "5849", "5990"]
_SEQ_LEN = 500  # number of padded token positions stored as columns `0`..`499`
_valid_select_list = ",\n  ".join(
    ["`id`"]
    + ["`%s`" % code for code in _TOP10_LABEL_COLS]
    + ["c.`%d`" % position for position in range(_SEQ_LEN)]
)

spark_seq_valid_df = spark.sql(
'''
select distinct 

  %s
from tempdb.hadm_top_50_icd_step_5 a 
inner join tempdb.hadm_dataset_lookup b  on a.row_id = b.row_id  and b.group_type = 'valid' 
inner join tempdb.spark_seq_df c  on c.id = b.row_id
''' % _valid_select_list
)

In [0]:
# Expose the validation split to SQL cells. createOrReplaceTempView replaces the
# deprecated registerTempTable and can be re-run without error.
spark_seq_valid_df.createOrReplaceTempView("spark_seq_valid_df_new")

In [0]:
%sql
-- Preview the validation split through the temp view spark_seq_valid_df_new.
select * 
from spark_seq_valid_df_new


id,4019,2724,25000,4280,41401,53081,51881,42731,5849,5990,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40,41,42,43,44,45,46,47,48,49,50,51,52,53,54,55,56,57,58,59,60,61,62,63,64,65,66,67,68,69,70,71,72,73,74,75,76,77,78,79,80,81,82,83,84,85,86,87,88,89,90,91,92,93,94,95,96,97,98,99,100,101,102,103,104,105,106,107,108,109,110,111,112,113,114,115,116,117,118,119,120,121,122,123,124,125,126,127,128,129,130,131,132,133,134,135,136,137,138,139,140,141,142,143,144,145,146,147,148,149,150,151,152,153,154,155,156,157,158,159,160,161,162,163,164,165,166,167,168,169,170,171,172,173,174,175,176,177,178,179,180,181,182,183,184,185,186,187,188,189,190,191,192,193,194,195,196,197,198,199,200,201,202,203,204,205,206,207,208,209,210,211,212,213,214,215,216,217,218,219,220,221,222,223,224,225,226,227,228,229,230,231,232,233,234,235,236,237,238,239,240,241,242,243,244,245,246,247,248,249,250,251,252,253,254,255,256,257,258,259,260,261,262,263,264,265,266,267,268,269,270,271,272,273,274,275,276,277,278,279,280,281,282,283,284,285,286,287,288,289,290,291,292,293,294,295,296,297,298,299,300,301,302,303,304,305,306,307,308,309,310,311,312,313,314,315,316,317,318,319,320,321,322,323,324,325,326,327,328,329,330,331,332,333,334,335,336,337,338,339,340,341,342,343,344,345,346,347,348,349,350,351,352,353,354,355,356,357,358,359,360,361,362,363,364,365,366,367,368,369,370,371,372,373,374,375,376,377,378,379,380,381,382,383,384,385,386,387,388,389,390,391,392,393,394,395,396,397,398,399,400,401,402,403,404,405,406,407,408,409,410,411,412,413,414,415,416,417,418,419,420,421,422,423,424,425,426,427,428,429,430,431,432,433,434,435,436,437,438,439,440,441,442,443,444,445,446,447,448,449,450,451,452,453,454,455,456,457,458,459,460,461,462,463,464,465,466,467,468,469,470,471,472,473,474,475,476,477,478,479,480,481,482,483,484,485,486,487,488,489,490,491,492,493,494,495,496,497,498,499


In [0]:
# Write each split as a single-partition CSV with a header row.
# NOTE(review): these paths go through the Spark API, so the leading "/dbfs" may
# place the output under dbfs:/dbfs/FileStore/... rather than dbfs:/FileStore/...
# (where the pickle files above are read via the local mount) - confirm.
for _split_frame, _csv_target in (
    (spark_seq_valid_df, "/dbfs/FileStore/output/DATA_WORDSEQV0_HADM_TOP10_valid.csv"),
    (spark_seq_train_df, "/dbfs/FileStore/output/DATA_WORDSEQV0_HADM_TOP10_train.csv"),
    (spark_seq_test_df, "/dbfs/FileStore/output/DATA_WORDSEQV0_HADM_TOP10_test.csv"),
):
    _csv_writer = _split_frame.repartition(1).write.format('com.databricks.spark.csv')
    _csv_writer.save(_csv_target, header = 'true')

In [0]:
%sql
-- Spot check of the split lookup for a single admission.
-- NOTE(review): the recorded output lists row_ids of this hadm_id under both
-- 'train' and 'test', i.e. notes of one admission are spread across splits -
-- confirm that this is intended.
select * 
from tempdb.hadm_dataset_lookup 
where hadm_id =  '100852'

hadm_id,row_id,group_type
100852,368191,train
100852,9318,test
100852,368111,train
100852,368133,test
100852,368154,train
100852,368221,train
100852,368223,train
100852,1066162,train
100852,1065938,train
100852,368074,train


In [0]:
import pickle
def separate_2(df, hadmid_pickle):
    """Split the row labels of `df` into train/validation/test by id.

    `hadmid_pickle` is a pickle file holding three consecutive objects: the ids
    of the training, validation and test records (in that order). `df` must have
    an 'id' column; it is not modified.

    Returns (idx_train, idx_val, idx_test): lists of the original index labels
    of `df` for the rows whose 'id' appears in the respective id list.
    """
    # NOTE: pickle.load can execute arbitrary code; only use it on trusted files.
    with open(hadmid_pickle, 'rb') as f:
        hadmid_train = pickle.load(f)
        hadmid_val = pickle.load(f)
        hadmid_test = pickle.load(f)

    # Remember each row's original index label, then index the copy by id.
    df2 = df.copy()
    df2['_idx'] = df2.index
    df2 = df2.set_index('id')

    def _rows_for(ids):
        # .loc selects rows by id label; plain df2[ids] would look up columns.
        return list(df2.loc[list(ids), '_idx'])

    idx_train = _rows_for(hadmid_train)
    idx_val = _rows_for(hadmid_val)
    idx_test = _rows_for(hadmid_test)

    return idx_train, idx_val, idx_test

def batch_output_pickle(df, data, reversemap, fname, labels, hadmid_pickle='/dbfs/FileStore/data_2/TRAIN_VAL_TEST_HADMID.p'):
    """Write the reverse map plus per-split data and labels into one pickle file.

    Objects are dumped sequentially to ``fname`` in this order:
    ``reversemap``, then ``data[idx]`` for the train, validation and test
    index lists, then the ``labels`` columns of ``df`` for the same three lists.

    Args:
        df: DataFrame with an ``id`` column and the ``labels`` columns.
        data: Container indexed with a whole list of row indices at once
            (presumably a numpy array — TODO confirm against caller).
        reversemap: Object written first to the file.
        fname: Output path.
        labels: Column names of ``df`` to export as label values.
        hadmid_pickle: Path forwarded to ``separate_2`` to obtain the splits.
    """
    idx_tuple = separate_2(df, hadmid_pickle)

    # Context manager guarantees the file is closed even if a dump fails.
    with open(fname, 'wb') as f:
        pickle.dump(reversemap, f, protocol=pickle.HIGHEST_PROTOCOL)
        for split_idx in idx_tuple:
            pickle.dump(data[split_idx], f, protocol=pickle.HIGHEST_PROTOCOL)
        for split_idx in idx_tuple:
            pickle.dump(df.loc[split_idx][labels].values, f, protocol=pickle.HIGHEST_PROTOCOL)

In [0]:
# Keep only the id, the ten label columns and the text column of df1_pd.
# NOTE(review): pd_df_lb is not referenced again in this part of the notebook.
pd_df_lb = df1_pd[['id',"4019", "2724","25000","4280","41401","53081","51881","42731","5849","5990",'text']]

In [0]:
# Plain Python list of every value in the id column, in row order.
idx_tuple = df1_pd["id"].tolist()

In [0]:
# fname = "/dbfs/FileStore/data/DATA_WORDSEQV0_HADM_TOP10_new.p"
# f = open(fname, 'wb')
# pickle.dump(reverse_word_index, f, protocol=pickle.HIGHEST_PROTOCOL)
# for count, i in enumerate(idx_tuple):
#     pickle.dump(data[count], f, protocol=pickle.HIGHEST_PROTOCOL)
# for count, i in enumerate(idx_tuple):
#     pickle.dump(df1_pd[count][labels].values, f, protocol=pickle.HIGHEST_PROTOCOL)        
# f.close()
# # pickle.dump(reverse_word_index, f, protocol=pickle.HIGHEST_PROTOCOL)
# # for count, i in enumerate(idx_tuple):
# #     pickle.dump(data[count], f, protocol=pickle.HIGHEST_PROTOCOL)

In [0]:
# f = open("/dbfs/FileStore/data/DATA_WORDSEQV0_HADM_TOP10_new.p", 'rb')
# embmatrix = "/dbfs/FileStore/data/EMBMATRIXV1_WORD2VEC_v2_300dim.p"
# loaded_data = []
# for i in range(2): # [reverse_dictionary, train_sequence, test_sequence, train_label, test_label]:
#     loaded_data.append(pickle.load(f))

In [0]:
# Distinct row_ids per split, read from the lookup table as Spark DataFrames
# (one query per group_type value: 'train', 'valid', 'test').
train_s_idx_df = spark.sql('''select distinct row_id from tempdb.hadm_dataset_lookup where group_type = 'train' ''')
valid_s_idx_df = spark.sql('''select distinct row_id from tempdb.hadm_dataset_lookup where group_type = 'valid' ''')
test_s_idx_df = spark.sql('''select distinct row_id from tempdb.hadm_dataset_lookup where group_type = 'test' ''')


In [0]:
# Collect each split's row_id result to the driver as a pandas DataFrame.
train_p_idx_df = train_s_idx_df.toPandas()
valid_p_idx_df = valid_s_idx_df.toPandas()
test_p_idx_df = test_s_idx_df.toPandas()

In [0]:
# Names of the ten label columns selected from df1_pd further down
# (presumably the top-10 ICD-9 codes, going by the file names — TODO confirm).
labels=["4019", "2724","25000","4280","41401","53081","51881","42731","5849","5990"]

In [0]:
# row_id column of each split; despite the *_ls names these are pandas Series.
test_idx_ls = test_p_idx_df["row_id"]
valid_idx_ls = valid_p_idx_df["row_id"]
train_idx_ls = train_p_idx_df["row_id"]

In [0]:

# Validation row_ids followed by training row_ids, as one flat list.
comb_ls = list(valid_idx_ls) + list(train_idx_ls)

In [0]:
# For each split, collect the original df1_pd index values of the rows whose
# id is one of that split's row_ids (reset_index() exposes them as 'index').
test_ind_list = df1_pd[df1_pd['id'].isin(test_idx_ls)].reset_index()['index']
valid_ind_list = df1_pd[df1_pd['id'].isin(valid_idx_ls)].reset_index()['index']
# Select the training ids directly. The previous mask was the complement of
# (valid + train) ids, which picked every row OUTSIDE those splits — i.e. the
# test rows — as "training" data.
train_ind_list = df1_pd[df1_pd['id'].isin(train_idx_ls)].reset_index()['index']

In [0]:
# Label matrix per split: rows picked by position via df1_pd.index, columns
# restricted to `labels`, converted to a numpy array.
test_labels = np.array(df1_pd.loc[df1_pd.index[test_ind_list], labels])
valid_labels = np.array(df1_pd.loc[df1_pd.index[valid_ind_list], labels])
train_labels = np.array(df1_pd.loc[df1_pd.index[train_ind_list], labels])

In [0]:
# Pick the entries of `data` belonging to each split, using the index lists
# built above (`data` itself is defined outside this section).
test_data = [data[i] for i in test_ind_list]
valid_data = [data[i] for i in valid_ind_list]
train_data = [data[i] for i in train_ind_list]

In [0]:
# NOTE(review): the first argument is the string "test_labels", not the
# test_labels array; output_pickle is defined outside this section — confirm
# it resolves the name to the object rather than pickling the string itself.
output_pickle("test_labels", "/dbfs/FileStore/data/DATA_WORDSEQV0_HADM_TOP10_test_labels.p")


In [0]:
# Export the remaining split arrays, one file per object.
# Fixed: the train_labels file was previously written from "test_labels"
# (copy-paste slip); every call now pairs a name with its matching file.
# NOTE(review): the first argument is a name string, not the array itself;
# output_pickle is defined outside this section — confirm that is what it expects.
output_pickle("train_labels", "/dbfs/FileStore/data/DATA_WORDSEQV0_HADM_TOP10_train_labels.p")
output_pickle("valid_labels", "/dbfs/FileStore/data/DATA_WORDSEQV0_HADM_TOP10_valid_labels.p")
output_pickle("test_data", "/dbfs/FileStore/data/DATA_WORDSEQV0_HADM_TOP10_test_data.p")
output_pickle("train_data", "/dbfs/FileStore/data/DATA_WORDSEQV0_HADM_TOP10_train_data.p")
output_pickle("valid_data", "/dbfs/FileStore/data/DATA_WORDSEQV0_HADM_TOP10_valid_data.p")

After failed attempts to export the data and labels, I am moving the training of the model to this workbook. I apologize for the messy code.

In [0]:
import pandas as pd
import numpy as np
import pickle
import argparse
import os, sys
# import tensorflow as tf
from os.path import join
# import wordseq_models
from keras.callbacks import ModelCheckpoint, EarlyStopping
from keras.layers import Embedding
from keras.utils.vis_utils import plot_model

In [0]:
from keras.models import *
from keras.layers import *

def lstm_1(input_shape, output_shape, embedding_layer):
    """Build and compile a two-layer LSTM classifier on top of ``embedding_layer``.

    Stack: embedding -> LSTM(256, sequences) -> Dropout(0.5) -> BatchNorm
    -> LSTM(64) -> Dropout(0.5) -> BatchNorm -> Dense(output_shape, sigmoid).

    Args:
        input_shape: Accepted for signature compatibility; not used here.
        output_shape: Number of units in the final sigmoid layer.
        embedding_layer: Layer placed first in the model.

    Returns:
        The compiled ``Sequential`` model (binary cross-entropy loss, rmsprop
        optimizer, 'acc' and 'mse' metrics). Its summary is printed.
    """
    print('Build model...')
    layer_stack = [
        embedding_layer,
        LSTM(256, return_sequences=True),
        Dropout(0.5),
        BatchNormalization(),
        LSTM(64),
        Dropout(0.5),
        BatchNormalization(),
        Dense(output_shape, activation='sigmoid'),
    ]

    model = Sequential()
    for layer in layer_stack:
        model.add(layer)

    model.compile(loss='binary_crossentropy', optimizer='rmsprop', metrics=['acc', 'mse'])
    model.summary()
    return model
  

In [0]:
# Read only the FIRST object of the pickle stream (per the original inline
# note the file holds [reverse_dictionary, train_sequence, test_sequence,
# train_label, test_label]); the context manager closes the handle afterwards.
# NOTE(review): pickle.load can execute arbitrary code; only use trusted files.
embmatrix = "/dbfs/FileStore/data/EMBMATRIXV1_WORD2VEC_v2_300dim.p"
with open("/dbfs/FileStore/data/DATA_WORDSEQV0_HADM_TOP10.p", 'rb') as f:
    loaded_data = [pickle.load(f)]

In [0]:
# First (and only) object loaded from the pickle; train() uses its length to
# size the embedding vocabulary.
dictionary_new = loaded_data[0]

In [0]:
# NOTE(review): echoes the whole object as cell output; consider showing only
# its length or a small sample.
dictionary_new

In [0]:
import gensim
# Load a previously saved gensim Word2Vec model (file name suggests 300-dim
# skip-gram — TODO confirm). model_w2v is not referenced again in this section.
model_w2v= gensim.models.Word2Vec.load('/dbfs/FileStore/data/model_word2vec_skipgram_300dim')

In [0]:
# Sanity check: number of entries in the test split.
len(test_data)

In [0]:
def train(nb_epoch=1, batch_size=128, patience=5, label_tile=4, append_name=''):
    """Train the ``lstm_1`` model on the in-memory train/validation split.

    Reads these notebook globals: ``train_data``, ``valid_data``,
    ``train_labels``, ``valid_labels``, ``dictionary_new`` and
    ``embedding_matrix_safe`` (the last one is defined outside this section).

    Args:
        nb_epoch: Number of epochs passed to ``model.fit``.
        batch_size: Mini-batch size passed to ``model.fit``.
        patience: ``EarlyStopping`` patience on ``val_loss``.
        label_tile: Repetition count given to ``np.tile`` for the label
            arrays ("tile" label mode); the model's output width is the
            tiled label width.
        append_name: Suffix inserted into the checkpoint file name.

    Returns:
        The fitted Keras model. The best weights (by validation loss) are
        also checkpointed under ``/dbfs/FileStore/data/cache``.
    """
    dictionary = dictionary_new
    embedding_matrix = embedding_matrix_safe

    train_sequence = np.array(train_data)
    val_sequence = np.array(valid_data)

    # "tile" label mode: repeat the label array label_tile times via np.tile.
    train_label = np.tile(np.array(train_labels), label_tile)
    val_label = np.tile(np.array(valid_labels), label_tile)
    print('labelmode: tile {0}'.format(train_label.shape))

    max_sequence_length = train_sequence.shape[1]
    vocabulary_size = len(dictionary) + 1
    embedding_dim = embedding_matrix.shape[1]
    category_number = train_label.shape[1]
    input_shape = train_sequence.shape[1:]

    # Frozen embedding initialised from the pre-computed matrix.
    embedding_layer = Embedding(vocabulary_size,
                        embedding_dim,
                        weights=[embedding_matrix],
                        input_length=max_sequence_length,
                        trainable=False,
                        input_shape=input_shape)

    model = lstm_1(input_shape, category_number, embedding_layer)

    # Make sure the checkpoint directory exists before Keras tries to write to it.
    cache_dir = '/dbfs/FileStore/data/cache'
    os.makedirs(cache_dir, exist_ok=True)
    weight_name = 'weights_' + 'lstm_1' + append_name + '.h5'
    weights_path = join(cache_dir, weight_name)

    print('checkpoint')
    checkpointer = ModelCheckpoint(filepath=weights_path, verbose=1, save_best_only=True)
    earlystopping = EarlyStopping(monitor='val_loss', patience=patience, verbose=0, mode='auto')
    print('early stop at ',patience)

    # validation_data is documented as a tuple (x_val, y_val); a list is not
    # accepted consistently across Keras versions.
    model.fit(train_sequence, train_label,
              batch_size = batch_size,
              epochs = nb_epoch,
              validation_data = (val_sequence, val_label),
              callbacks=[checkpointer, earlystopping])

    return model

In [0]:
# Run training with train()'s built-in settings.
train()

In [0]:
# NOTE(review): bare integer literal; the cell only echoes it. Purpose is not
# evident from this notebook section — presumably a scratch value; consider removing.
454148