# WSDM 2017 Neural Survival Recommender

## 0. Setup

import numpy, tensorflow

In [1]:
import numpy as np
import tensorflow as tf
import pandas as pd
import time
import os
# Enumerate GPUs in PCI bus order and expose only the device with id "3" to
# this process. NOTE(review): the device id is machine-specific — adjust it
# (or remove both lines) when running on different hardware.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

# Session options: fall back to another device when an op cannot be placed as
# requested, cap the per-process GPU memory fraction at 0.9, and allocate GPU
# memory incrementally instead of all at once.
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.per_process_gpu_memory_fraction = 0.9
config.gpu_options.allow_growth = True
# NOTE(review): the visible training/testing functions open their own
# `tf.Session()` without this config, so neither `sess` nor `config` is used by
# them in the cells shown here.
sess = tf.Session(config=config)

  from ._conv import register_converters as _register_converters


## 1. Data Path

data_path = "../data_sample/indoor/store_B/train_60days/"


In [2]:
# Store whose visit logs are loaded; interpolated into the data directory below.
store_name = 'store_B'
# Alternative (smaller) sample dataset, kept for quick switching:
#data_path = "../data_sample/indoor/%s/train_60days/" % (store_name)
data_path = "../data/indoor/%s/train_180days/" % (store_name)
# NOTE(review): the visible loading cell rebuilds these two paths from
# `data_path` instead of reading `train_file` / `test_file`.
train_file = data_path + "train_visits.tsv"
test_file = data_path + "test_visits.tsv"

# NOTE(review): none of the upper-case constants below is referenced in the
# visible cells; the training cell defines its own lower-case hyperparameters
# (e.g. batch_size = 256 there vs. BATCH_SIZE = 128 here).
LEARNING_RATE = 0.001
TRAINING_ITERS = 10000
DISPLAY_STEP = 10

INPUT_SIZE = 300
BATCH_SIZE = 128


## 2. Data Preprocessing

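- G: gap time — hours (rounded to the nearest hour) between the end of a user's previous visit and the start of the current one
- D: date — hour-of-week of the visit's end time (weekday × 24 + hour, after a +9 h offset)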
- A: action
- U: user

In [3]:
def logging_time(original_fn):
    """Decorator that prints the wall-clock duration of every call.

    Args:
        original_fn: the callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``original_fn``, prints ``Running Time[<name>]: <seconds> sec`` once the
        call has returned, and passes the return value through unchanged. If
        ``original_fn`` raises, the exception propagates and nothing is printed.
    """
    # Local import keeps this cell self-contained.
    from functools import wraps

    # wraps() copies __name__, __doc__, etc. so the decorated function does not
    # show up as "wrapper_fn" in tracebacks, help() or nested decorators.
    @wraps(original_fn)
    def wrapper_fn(*args, **kwargs):
        start_time = time.time()
        result = original_fn(*args, **kwargs)
        end_time = time.time()
        print("Running Time[{}]: {:.4f} sec".format(original_fn.__name__, end_time-start_time))
        return result
    return wrapper_fn

In [4]:
@logging_time
def _add_infos(df, sessions=None):
    """Attach per-visit session details (areas, dwell times, start/end) to ``df``.

    Every row of ``df`` names, in its ``indices`` column, the ';'-separated
    integer index labels of the session rows that make up that visit. The
    labels are looked up in ``sessions`` in a single ``.loc`` call, and the
    flat result is cut back into one slice per visit.

    Args:
        df: visits DataFrame with a string column ``indices``. It is modified
            in place (new columns are added).
        sessions: DataFrame indexed by session label with the columns
            ``area``, ``dwell_time`` and ``ts``. Defaults to the notebook-level
            ``wifi_sessions`` table, which keeps existing one-argument calls
            working.

    Returns:
        The same ``df`` object with these columns added:
        ``l_index`` (list of int labels), ``dwell_times`` and ``areas`` (lists,
        one entry per session of the visit), ``ts_start`` (minimum session
        ``ts``) and ``ts_end`` (maximum of ``ts + dwell_time``).

    Raises:
        ValueError: if the session lookup does not return exactly one row per
            requested label (e.g. duplicated index labels), because the
            per-visit slices would then be misaligned.
        AssertionError: if any visit would end before it starts.
    """
    if sessions is None:
        # Backward-compatible default: the table loaded at notebook level.
        sessions = wifi_sessions

    df['l_index'] = df['indices'].apply(lambda x: [int(y) for y in x.split(';')])

    # Flatten all visits' labels so the session table is queried only once.
    newidx = [item for sublist in list(df.l_index) for item in sublist]
    tmpdf = sessions.loc[newidx]
    if len(tmpdf) != len(newidx):
        raise ValueError(
            "session lookup returned {} rows for {} requested labels; "
            "the session index must contain each label exactly once".format(len(tmpdf), len(newidx)))
    traj_lens = df.l_index.apply(len)

    tmp_areas = list(tmpdf['area'])
    tmp_dt = list(tmpdf['dwell_time'])
    tmp_ts_start = list(np.array(tmpdf['ts']))
    # End time of a session = its start timestamp plus its dwell time
    # (assumes both are in the same unit — TODO confirm against the data spec).
    tmp_ts_end = list(np.array(tmpdf['ts']) + np.array(tmp_dt))

    rslt_dt = []
    rslt_areas = []
    rslt_ts_start = []
    rslt_ts_end = []

    # Walk the flat lists with a running offset; visit k owns the next
    # traj_lens[k] entries.
    i = 0
    for x in traj_lens:
        rslt_dt.append(tmp_dt[i:i + x])
        rslt_areas.append(tmp_areas[i:i + x])
        rslt_ts_start.append(min(tmp_ts_start[i:i + x]))
        rslt_ts_end.append(max(tmp_ts_end[i:i + x]))
        i += x

    df['dwell_times'] = rslt_dt
    df['areas'] = rslt_areas
    df['ts_start'] = rslt_ts_start
    df['ts_end'] = rslt_ts_end
    # Sanity check: a visit can never end before it starts.
    assert all(df.ts_end - df.ts_start >= 0)
    return df

In [5]:
# Load the tab-separated train/test tables from `data_path`.
# NOTE(review): `train_labels` / `test_labels` are not referenced in any other
# visible cell.
train_labels = pd.read_csv(data_path + 'train_labels.tsv', sep='\t')
test_labels = pd.read_csv(data_path + 'test_labels.tsv', sep='\t')
train_visits = pd.read_csv(data_path + 'train_visits.tsv', sep='\t')
test_visits = pd.read_csv(data_path + 'test_visits.tsv', sep='\t')
# Indexed by the file's 'index' column so that _add_infos can look sessions up
# by label with `.loc`.
wifi_sessions = pd.read_csv(data_path + 'wifi_sessions.tsv', sep='\t').set_index('index')

# Enrich each visit with its areas, dwell times and start/end timestamps.
train_visits = _add_infos(train_visits)
test_visits = _add_infos(test_visits)
# Sanity check: the test set contains at least three visits.
assert len(test_visits.tail(3)) == 3

def _add_infos2(df):
    df['areas_str'] = df['areas'].apply(lambda x: ' '.join(x))
    return df

# Add the space-joined `areas_str` column to both tables.
train_visits = _add_infos2(train_visits)
test_visits = _add_infos2(test_visits)

# Store device ids as strings; they are used as dictionary keys when the
# per-user visit sequences are built.
train_visits['wifi_id'] = train_visits['wifi_id'].astype(str)
test_visits['wifi_id'] = test_visits['wifi_id'].astype(str)
    
    

Running Time[_add_infos]: 0.6649 sec
Running Time[_add_infos]: 0.3308 sec


In [6]:
train_visits

Unnamed: 0,visit_id,wifi_id,date,indices,l_index,dwell_times,areas,ts_start,ts_end,areas_str
0,v0,4,17201,163795;163797;163798;163800;163812;163814;1638...,"[163795, 163797, 163798, 163800, 163812, 16381...","[1566, 675, 1575, 795, 982, 1158, 1035, 1243, ...","[1f-left-2, 1f-left-3, 1f-right-2, 1f-right-3,...",1486182424,1486184012,1f-left-2 1f-left-3 1f-right-2 1f-right-3 2f-l...
1,v1,8,17252,400820;400821;400827;400852;400860;400861;400864,"[400820, 400821, 400827, 400852, 400860, 40086...","[1075, 1013, 121, 127, 953, 207, 642]","[1f-right-2, 1f-right-1, 2f-right-1, 1f-left-1...",1490600915,1490602298,1f-right-2 1f-right-1 2f-right-1 1f-left-1 1f-...
2,v2,10,17296,550515;550517;550518;550521;550549,"[550515, 550517, 550518, 550521, 550549]","[341, 262, 470, 226, 60]","[1f-left-3, 1f-right-3, 1f-right-2, 1f-left-2,...",1494392280,1494392766,1f-left-3 1f-right-3 1f-right-2 1f-left-2 1f-l...
3,v3,11,17178,58212;58224;58225;58226;58227;58240;58249;5825...,"[58212, 58224, 58225, 58226, 58227, 58240, 582...","[617, 6, 28, 28, 30, 37, 248, 69, 226, 227]","[1f-right-2, 2f-right-1, 2f-right-2, 2f-left-1...",1484204206,1484204823,1f-right-2 2f-right-1 2f-right-2 2f-left-1 2f-...
4,v4,13,17328,651250;651251;651259;651260;651268;651271;6512...,"[651250, 651251, 651259, 651260, 651268, 65127...","[733, 724, 660, 659, 568, 544, 432, 328, 256, ...","[1f-right-3, 1f-left-3, 1f-left-2, 1f-right-2,...",1497144840,1497174948,1f-right-3 1f-left-3 1f-left-2 1f-right-2 1f-r...
5,v5,14,17238,343964;343971;343972,"[343964, 343971, 343972]","[160, 9, 39]","[1f-left-3, 1f-right-2, 1f-right-3]",1489396397,1489396563,1f-left-3 1f-right-2 1f-right-3
6,v6,17,17200,160644;160645;160667;160668,"[160644, 160645, 160667, 160668]","[200, 200, 19, 11]","[1f-left-2, 1f-right-2, 1f-left-3, 1f-right-3]",1486113925,1486114137,1f-left-2 1f-right-2 1f-left-3 1f-right-3
7,v7,18,17221,265656;265657;265658;265662;265674;265689;2657...,"[265656, 265657, 265658, 265662, 265674, 26568...","[888, 874, 852, 826, 701, 419, 216, 192, 200, ...","[1f-left-3, 1f-right-3, 1f-left-2, 1f-right-2,...",1487933951,1487934839,1f-left-3 1f-right-3 1f-left-2 1f-right-2 1f-r...
8,v8,18,17347,714012;714016;714018;714020;714030;714040;7140...,"[714012, 714016, 714018, 714020, 714030, 71404...","[947, 756, 913, 909, 837, 770, 85, 60, 556, 4]","[1f-left-3, 1f-right-3, 1f-left-2, 1f-right-2,...",1498811487,1498812434,1f-left-3 1f-right-3 1f-left-2 1f-right-2 1f-r...
9,v9,22,17253,404590;404591;404597;404604;404606;404608,"[404590, 404591, 404597, 404604, 404606, 404608]","[608, 756, 729, 694, 694, 678]","[1f-right-3, 1f-left-3, 1f-left-2, 1f-right-1,...",1490680668,1490681424,1f-right-3 1f-left-3 1f-left-2 1f-right-1 1f-r...


In [7]:
# Offset in seconds (32400 s = 9 h) added to epoch timestamps before
# time.gmtime(), so weekday/hour are read in local time
# (presumably Korea Standard Time, UTC+9 — confirm).
kor_UTC = 32400

# Reference "previous visit end" used for every user's FIRST visit.
# Known issue (translated from the original author's note): the gap of the
# first visit has no real predecessor; the earliest visit start in the training
# data is used as a stand-in, but the start date of the observation period is
# probably the right reference.
# Computed once here instead of once per newly seen user.
first_visit_reference = min(train_visits.ts_start)

# visit_dict[wifi_id] = {"g": [gap in hours, ...], "d": [hour-of-week, ...]},
# one entry per visit, in the row order of train_visits.
visit_dict = {}
# wifi_id -> end timestamp of that user's most recently processed visit.
last_ts_end = {}

for u, ts_start, ts_end in zip(train_visits['wifi_id'],
                               train_visits['ts_start'],
                               train_visits['ts_end']):
    if u not in visit_dict:
        visit_dict[u] = {"g": [], "d": []}
        last_ts_end[u] = first_visit_reference

    local_end = time.gmtime(ts_end + kor_UTC)

    # Gap between the previous visit's end and this visit's start, quantized
    # to whole hours.
    visit_dict[u]["g"].append(int(round((ts_start - last_ts_end[u]) / 3600.)))
    # Hour-of-week of the visit's end: weekday * 24 + hour.
    visit_dict[u]["d"].append(local_end.tm_wday * 24 + local_end.tm_hour)
    last_ts_end[u] = ts_end

In [8]:
# Number of recorded visits (gap entries) per user; the maximum is the longest
# visit history in the training data.
chk_list = [len(history['g']) for history in visit_dict.values()]
print(max(chk_list))

15


In [9]:
visit_dict

{'4': {'d': [48], 'g': [818]},
 '8': {'d': [17], 'g': [2045]},
 '10': {'d': [28], 'g': [3098]},
 '11': {'d': [37], 'g': [268]},
 '13': {'d': [60], 'g': [3863]},
 '14': {'d': [18], 'g': [1711]},
 '17': {'d': [46], 'g': [799]},
 '18': {'d': [48, 45], 'g': [1304, 3021]},
 '22': {'d': [22], 'g': [2067]},
 '23': {'d': [22], 'g': [874]},
 '24': {'d': [39], 'g': [2287]},
 '25': {'d': [60], 'g': [2527]},
 '28': {'d': [32], 'g': [1440]},
 '29': {'d': [58], 'g': [3700]},
 '30': {'d': [37], 'g': [269]},
 '33': {'d': [51, 45, 50], 'g': [316, 817, 358]},
 '34': {'d': [53], 'g': [650]},
 '35': {'d': [63, 18], 'g': [178, 36]},
 '36': {'d': [50], 'g': [467]},
 '37': {'d': [51, 46, 26, 48], 'g': [2164, 163, 1424, 577]},
 '39': {'d': [34], 'g': [2091]},
 '40': {'d': [40], 'g': [4135]},
 '44': {'d': [19], 'g': [720]},
 '47': {'d': [20, 41, 57], 'g': [2568, 577, 873]},
 '48': {'d': [48], 'g': [3825]},
 '55': {'d': [33, 49], 'g': [2281, 33]},
 '56': {'d': [56], 'g': [3682]},
 '58': {'d': [43], 'g': [292]},

In [10]:
# Dump every user's gap sequence (one list per line) for eyeballing.
for history in visit_dict.values():
    print(history['g'])

[818]
[2045]
[3098]
[268]
[3863]
[1711]
[799]
[1304, 3021]
[2067]
[874]
[2287]
[2527]
[1440]
[3700]
[269]
[316, 817, 358]
[650]
[178, 36]
[467]
[2164, 163, 1424, 577]
[2091]
[4135]
[720]
[2568, 577, 873]
[3825]
[2281, 33]
[3682]
[292]
[1155, 260]
[2354]
[1878]
[3727, 236]
[1493]
[3144]
[2504]
[2242]
[3058]
[2552, 88]
[2290]
[368]
[3994]
[365]
[5, 144, 3190]
[3203, 854]
[2303]
[8, 289]
[391, 2180]
[3363]
[3226]
[1444]
[945]
[754, 784, 120, 29, 24, 1849]
[984, 460, 93]
[755]
[418, 403, 840, 264]
[632]
[170, 290, 917]
[1469, 1685]
[1851]
[1232]
[633]
[292, 2064]
[867]
[3795]
[634, 19]
[1799, 196]
[1185]
[2956]
[3411]
[3010]
[1256]
[2192]
[2255]
[3863]
[4111]
[2452]
[3466]
[4136]
[1159]
[657]
[2313]
[1610]
[96]
[1657]
[2040, 1683]
[2314]
[1775]
[298, 1123, 20]
[2531, 45]
[1560]
[3201]
[124, 268, 19, 3505]
[1132]
[1634]
[698]
[1132]
[1831]
[2524]
[3489]
[2687]
[1378, 2949]
[1156]
[2833]
[731]
[3773]
[819]
[2146]
[1424]
[34, 2034, 1948]
[2856]
[3343]
[292]
[1089, 1224]
[4180]
[1473, 448, 365

[4252]
[1879, 835, 121, 28, 258, 170]
[697]
[585, 187, 170, 531, 1742, 193, 148]
[3028]
[417]
[750]
[3677]
[3026]
[611]
[3508]
[1155]
[632, 929, 1923, 314]
[2812]
[335]
[1906]
[730]
[2984]
[1306]
[770, 1950, 351]
[457]
[249]
[1306]
[364]
[3844]
[2065]
[727]
[4187, 36]
[1161]
[2186]
[3866]
[600]
[2809]
[1993, 263, 106, 301]
[3098]
[2289]
[4158]
[632]
[778]
[2456]
[2359]
[2839]
[170]
[2624]
[850]
[989, 2523]
[2046, 91]
[731]
[32]
[1226]
[890]
[4035]
[3772, 25]
[222]
[321]
[1234, 711, 199]
[1346]
[1778, 17]
[1396]
[3006]
[3267]
[364]
[1949]
[792, 338]
[1609]
[434]
[365]
[1281]
[1111, 115, 2732]
[2067, 486, 161, 462]
[707, 162, 1133, 480, 1124]
[2809]
[1347]
[250, 3264, 600, 212]
[3536]
[3674]
[3265]
[3942]
[1251]
[7, 1872, 1773]
[1709]
[5, 1294]
[487]
[2144, 2089]
[3747]
[2047]
[1978]
[2188]
[294, 1392, 1552, 25, 1035]
[315, 940, 1461, 936]
[3915, 199]
[702, 193, 1083, 1236, 775]
[2863]
[1995]
[195]
[1350, 409]
[4307]
[3273]
[1513]
[2835]
[1667]
[361, 3530]
[2092]
[2929]
[729]
[3779]
[356

[3364]
[1180, 124]
[3534]
[582]
[681]
[1325]
[1930, 39]
[32, 26]
[3727, 45]
[412]
[3176, 41, 606]
[4302]
[1633]
[3965]
[1343]
[752]
[1]
[3369]
[2122]
[3098]
[751, 2896, 273, 18]
[3510]
[1300]
[1176]
[1137]
[801]
[3535]
[1492, 24]
[2954, 626]
[177, 20]
[510]
[844]
[1400]
[4041]
[344]
[2159, 24]
[560]
[3918]
[1490, 188]
[2578]
[899, 120, 10, 1250, 1008, 848]
[1159]
[3825, 38]
[670, 26]
[1775]
[2694]
[3575]
[1326, 763, 747, 336, 219, 882]
[3369]
[2555]
[1040, 90]
[4210, 65]
[1872, 1569, 884]
[1153, 968]
[3486]
[1396]
[1276]
[1777, 701, 473, 269, 44]
[1226]
[1421, 1751, 1029]
[4207, 46]
[269]
[1276]
[3579]
[2288]
[1282, 522]
[311]
[458]
[4257]
[1594, 1417, 39, 959, 53]
[295]
[3079, 89]
[1181]
[1755]
[342]
[146, 87, 1061]
[275, 86]
[509, 45]
[624]
[1]
[653]
[1564]
[440]
[2378, 317, 1100]
[3531]
[1350]
[678]
[1560]
[1130, 340, 700]
[3514]
[2351]
[99]
[1233, 16]
[3215]
[1587]
[102]
[2336]
[53, 1179, 1893, 723, 15, 248]
[193]
[512]
[1299, 740]
[771]
[1834]
[3055]
[3993, 15]
[215]
[3345]
[347, 

[199, 1125]
[3819]
[1276]
[1491]
[2280]
[441, 1030, 26, 23, 18, 1394]
[699]
[1257]
[2378]
[960, 3150]
[1323, 246]
[344, 73]
[462, 2691]
[3437]
[171, 48]
[1537, 1543, 215]
[178]
[3559, 19]
[1156]
[10, 3860]
[2478]
[3752]
[677, 164, 653]
[4015]
[3514]
[4132]
[2568]
[2193]
[74]
[3675]
[1323]
[1730]
[171, 3627]
[1704]
[4209]
[1066, 279]
[383]
[1682]
[3774]
[2219]
[2240, 857]
[1276]
[3600]
[31]
[3820]
[3509]
[3602]
[730, 1382, 269]
[218]
[1689]
[179]
[2044, 2066]
[203, 1748]
[1034]
[1732]
[338]
[1781]
[1468, 622]
[820, 1002, 80, 120, 1725]
[1352]
[4183]
[1420]
[4091, 43, 45]
[197]
[1562]
[1086]
[631]
[53]
[1040, 308]
[3697]
[3586]
[418]
[388, 739]
[748, 52, 1051, 45, 778]
[2544]
[1106, 1467, 1221]
[2337, 49]
[241]
[529]
[4327]
[1762]
[965]
[1252]
[873]
[1011]
[4031]
[3724]
[1306]
[1107, 960, 413]
[1565, 577]
[1930]
[3943]
[3298]
[611]
[3921]
[985]
[2939]
[2594]
[1853]
[196]
[943, 1319]
[3238, 25]
[2374]
[1185]
[2242]
[220]
[1661]
[2386]
[2386, 20, 49]
[431]
[4256]
[1137, 25, 95, 521, 198]
[

[3703]
[3217]
[321]
[4185]
[3388]
[3270, 717, 53]
[3608]
[3346]
[2962]
[2314]
[923, 71, 23, 1824]
[1419]
[3077]
[2169, 46]
[1536]
[3966]
[149]
[2548]
[2458]
[922]
[2455, 1461]
[3076]
[1469]
[1254, 268, 253, 73, 1730, 293]
[432, 346, 2346]
[1492, 2500]
[1322, 1351, 1026, 318]
[940]
[820, 408]
[1994]
[3533]
[3296]
[817, 146]
[1661]
[1446]
[1762]
[2647]
[1323, 864]
[1320]
[3049]
[2908]
[2335]
[2792, 74, 12, 635, 548]
[1416, 1563]
[530]
[2819]
[1321]
[2649]
[1513, 2666]
[4306]
[2471]
[1686]
[1155, 31]
[2459]
[1059]
[2353]
[5]
[202, 23, 40, 786]
[2927]
[1679]
[1469]
[1368]
[275, 18, 771]
[679]
[2453, 1822]
[1780]
[2715]
[3863]
[3894]
[3321]
[4018]
[3749, 71]
[723]
[81, 283, 1279, 213, 66, 914, 292, 95, 521, 31]
[3361]
[611]
[2314, 1818]
[1009]
[3026]
[3872]
[603]
[1258, 37]
[3749]
[1153, 417]
[1346]
[457, 1228]
[554, 1125, 173]
[293, 20, 1062, 16]
[866, 344]
[3226]
[3721, 56]
[824]
[243]
[2452]
[1729]
[1955, 70]
[769]
[418]
[2354]
[2120, 22]
[1033]
[1710]
[3538]
[3874, 36]
[3006]
[2602]
[81

[2191]
[772]
[127]
[2449]
[681]
[966]
[2099]
[1299]
[2523]
[1906]
[1515]
[2807]
[1113]
[2024, 21]
[3942, 68]
[3514]
[1320]
[1162]
[2958]
[3870]
[538]
[890, 264, 197]
[986, 78, 692, 143]
[3031, 121]
[531, 17]
[3537]
[2768]
[4206, 71]
[4136]
[147, 698, 605, 2079]
[1856]
[863, 851]
[1395]
[4064, 39]
[2120, 22]
[801]
[751]
[3898]
[1713]
[3635, 87]
[674]
[1325, 1034, 962]
[984, 3321]
[2406]
[4040]
[489, 765]
[1276]
[1560]
[1304]
[1543, 74]
[4155]
[370, 1573, 896, 165]
[3273]
[1205]
[1063, 121]
[2424]
[513]
[122]
[825]
[1106]
[311]
[1710]
[865, 73]
[3706]
[321]
[4009]
[1584, 32, 1238]
[819]
[127]
[1160]
[2304, 298, 15]
[27, 452, 96, 294, 1485, 870, 746]
[4128]
[3893]
[131]
[912, 984, 179]
[1783, 18]
[1883]
[392]
[802, 1964, 712, 625]
[3874]
[1159, 26]
[3217]
[2554]
[3371, 60]
[2434, 213, 25, 354]
[293]
[488]
[1513]
[411]
[1396, 359, 312, 122, 772]
[1258, 94, 75, 1503]
[49]
[2472]
[989]
[2217, 18]
[1419, 294]
[3631, 145]
[2449]
[1924, 24]
[3727]
[1995]
[412]
[2530]
[3166]
[1595]
[2122, 25]
[9

[2289, 69, 717]
[1659, 29]
[198, 1439]
[3965]
[2361]
[847]
[3248]
[3988]
[3915]
[3415]
[1132, 1650, 272]
[1689]
[1199, 250, 2563, 93]
[1161, 2736]
[438, 3220]
[2955]
[940]
[2361]
[701, 22]
[749]
[772, 51]
[913, 813, 2447]
[345, 258, 582, 145]
[2162]
[3897]
[3176]
[3412, 360, 45, 271]
[32, 1223, 378, 415, 42]
[4007]
[3803, 23]
[143, 3074, 33, 305]
[917]
[4014]
[465, 1314, 1702, 58, 282, 96, 44, 287]
[726]
[127, 21, 1368]
[1112]
[1372]
[104, 97, 135, 990, 162, 1903]
[2167, 49]
[3987]
[3006]
[3743]
[3721]
[4105]
[82, 619, 170]
[610]
[671]
[2784]
[1448]
[3202]
[676]
[2499]
[2712]
[2672]
[360, 128, 305, 792, 1743, 262]
[320]
[2842]
[2]
[802]
[1683]
[294, 1266, 2363]
[341]
[2191]
[2359, 815, 184]
[721]
[4113]
[3289]
[52, 821]
[418, 23]
[509]
[4013]
[4233]
[4279]
[1034, 240, 835, 101, 67, 33, 400, 606, 69, 210, 239, 442]
[533, 50]
[104]
[2837]
[1592]
[1275]
[2785]
[4012]
[1919]
[2171]
[1713, 408, 360, 863, 813]
[2001]
[362]
[603, 96, 2236]
[5]
[29, 1412, 967, 738, 316, 674, 147]
[2092]
[1133]

[295, 700, 110, 2095, 115]
[3989]
[4305]
[1521]
[2168, 329]
[53]
[1323]
[297, 48]
[1640, 16]
[1205]
[2858]
[1590, 574, 138, 119]
[794]
[682]
[822]
[273, 624]
[4199]
[2166, 1147]
[1950]
[2190]
[1061]
[410, 292]
[1466]
[3913]
[2978]
[2424]
[1300, 2419, 265]
[676, 1703]
[3818]
[3367, 144, 454]
[3587]
[295, 1008]
[1897]
[657]
[106]
[4182]
[2075]
[558]
[1568, 26]
[1705]
[920]
[1752]
[632]
[1517, 48, 822, 1025, 892]
[100, 385, 506]
[581]
[940]
[2387]
[1449, 143]
[1833]
[2459, 1360]
[942, 1845]
[4230]
[3729]
[1305]
[2569, 1209]
[1013, 978]
[575]
[197, 477, 2310, 25, 157, 1161]
[1187]
[891]
[1780]
[802]
[3128, 17]
[3649]
[171]
[1351, 161, 2071]
[1543]
[1231, 51]
[1897]
[1449]
[2309]
[754]
[507]
[2544]
[2478]
[99, 168, 691, 29, 24, 192]
[3624]
[2002, 258]
[2186, 1845]
[2049]
[489, 48]
[54]
[2524]
[1643]
[1511]
[4014]
[1017]
[3391]
[35]
[3555]
[443]
[1489]
[1033]
[4112]
[869]
[1713]
[2865]
[1035, 409, 748]
[4179]
[3723]
[461]
[4206, 45]
[793, 2433, 400, 360, 21, 45]
[1595]
[146, 796]
[815]
[1703

[1130]
[1039]
[749]
[3527]
[2426]
[1233, 19, 52, 287, 41, 1610, 770]
[2695]
[1106]
[338, 1880, 1434, 329]
[2694]
[875, 1743]
[730, 551, 862, 165, 116, 561]
[697]
[244]
[2046]
[1207]
[989]
[1163, 23, 646]
[264]
[3898, 426]
[387]
[1755]
[1274]
[1254, 26]
[1545]
[439]
[153, 978, 575, 626, 214, 292, 188, 126, 18, 526, 582]
[1328, 14]
[1521]
[865]
[4019]
[75, 104]
[485]
[3700]
[3704]
[3893, 147]
[512]
[530]
[697]
[1354, 24]
[556]
[2147, 61]
[4058]
[2527]
[295, 67]
[3891]
[3102]
[3490, 187, 221]
[2160]
[321]
[415]
[681]
[2865]
[2067, 43]
[2088]
[1687]
[1737]
[360]
[53]
[3988, 192]
[1373, 170, 2305]
[599]
[1614]
[3681]
[701, 1275]
[1779]
[2192]
[1105]
[985, 123, 623]
[1931]
[8]
[3923]
[1042]
[1079]
[1396]
[1850]
[2219]
[1281]
[1761]
[653, 360, 1031, 1318, 412, 399]
[1974]
[2183]
[1560, 46]
[847]
[1947]
[824]
[1038]
[608, 407]
[1250]
[961]
[1374]
[2386]
[770, 3124]
[2233, 81]
[2020]
[298, 24]
[755]
[3178]
[1369, 25, 223]
[1162]
[2407]
[3580]
[608]
[1375, 1630]
[127, 1219]
[922, 1749]
[2554]
[6

[720]
[291, 887]
[821]
[1690, 1800, 97]
[2020, 28]
[176, 74, 403, 165]
[177]
[1059]
[2330]
[1852]
[196]
[4200, 125]
[125, 1779, 193]
[1880]
[987]
[559]
[959]
[2407, 48]
[3727]
[1757]
[1806]
[2471]
[1283]
[55, 1032]
[2836, 54]
[1499, 13]
[1635, 147, 1122]
[3869]
[1736, 373, 54, 73, 288, 766, 331]
[1641, 24]
[3316]
[1568]
[2907]
[657]
[2359]
[650]
[4324]
[817]
[1976]
[966]
[487]
[1248]
[1616, 2179]
[1468]
[345]
[1320]
[2842]
[1280]
[3579]
[3441, 48]
[3412]
[2218, 24]
[1976, 19, 170]
[1105, 939, 932, 152, 19]
[1011]
[891]
[2505]
[1039, 2020, 46, 25]
[216]
[3388]
[3338]
[2287]
[701]
[216]
[1511]
[128]
[2793]
[371, 1727]
[1064]
[342]
[534]
[3533]
[3347]
[1946]
[1498]
[515]
[864]
[79, 215, 17, 270, 409, 42, 103, 477, 2492, 50]
[2816]
[295, 1443, 23, 94]
[2359]
[30]
[3699]
[1249]
[1110]
[365]
[1391]
[1710]
[1953, 67]
[1497]
[1305, 86]
[1392]
[343]
[4257]
[3344, 359]
[730]
[803]
[1153, 1702]
[3625]
[176]
[1067, 1319]
[1248]
[1906, 1249]
[1321]
[441, 18]
[971]
[4272]
[1928]
[1162, 66, 25]
[3650

[698]
[2836]
[4034, 23]
[2834, 1414]
[4017]
[2452]
[29]
[4184]
[2190]
[531]
[2761]
[3897, 23]
[3752]
[3752, 15]
[1162]
[1465]
[1181]
[586]
[222]
[4087, 144]
[843]
[1348]
[1954]
[2694]
[2723, 41]
[3178]
[1809]
[392]
[2861]
[3941]
[2283]
[96, 1156]
[3293]
[2333]
[1015]
[149]
[3720]
[1019]
[628]
[656, 63, 557]
[1665, 13]
[1323]
[3120]
[2115]
[191, 558, 2612]
[528, 1350]
[916]
[896]
[532, 69, 3560]
[3442]
[1898]
[1689]
[3581]
[274]
[2, 267, 502]
[1183]
[2406]
[2278]
[1155, 75]
[706]
[34]
[265, 1008, 557, 2328]
[969, 15, 26, 319]
[1399]
[1636]
[1499]
[1786]
[483]
[3791]
[1922]
[1352]
[3487]
[225, 23]
[369]
[2927, 387, 190, 58]
[4019]
[1136]
[1449, 410]
[2764]
[818, 702]
[2573]
[167]
[3101, 98]
[105]
[873]
[1850]
[1643]
[3130]
[3534]
[2861]
[3586]
[818]
[1997]
[264, 26, 31, 15, 26, 3174]
[1129, 1466]
[2549, 19]
[1394]
[2696]
[3562]
[720]
[1374]
[1948]
[796, 500, 49]
[2000]
[1739]
[723]
[1690]
[1513]
[1567]
[3101, 44]
[1786, 15]
[1102, 724]
[778, 1195]
[1351, 142, 2227]
[2930]
[2576]
[344, 24

[1919]
[437]
[198, 293, 119, 792, 550, 66, 483, 329, 770, 200]
[1584]
[2934]
[1690]
[1257, 783, 387, 790, 192]
[824]
[56]
[3345]
[1897]
[1226]
[150, 4060]
[657]
[944]
[1326]
[123]
[1133]
[536, 136, 559, 2322]
[1369]
[946, 912, 1032, 448, 727]
[3630]
[98, 2085, 1929]
[871]
[1759]
[369, 113, 1298]
[725]
[2982]
[5, 1732]
[846]
[4206]
[105, 743, 2330]
[3219]
[1834]
[2049]
[921]
[9]
[3053, 94]
[3125, 99, 259, 288]
[2934]
[1136]
[895]
[1396]
[125]
[1281, 1442]
[3681]
[2890]
[3585]
[2666, 1372]
[887]
[1399, 553, 18]
[778, 1414]
[983, 804]
[1279]
[674]
[3917]
[1400]
[2122]
[703, 115]
[4251]
[1878, 2373]
[2857]
[1758, 791]
[3897]
[1391, 719]
[1370]
[2983]
[410]
[1949]
[274]
[2266]
[26]
[1107]
[291, 67, 221]
[1546]
[153]
[2836]
[2409]
[1397, 837, 368]
[362]
[1421, 1756]
[2405, 18]
[4330]
[2550]
[319, 48]
[839]
[1739, 36]
[1828]
[272, 435, 811, 841]
[730]
[652]
[2698, 192]
[2164]
[3994]
[821]
[394]
[316]
[3818]
[4227]
[1251, 74]
[142, 59, 62, 841, 529, 223, 22, 1504]
[1929]
[4083]
[4258]
[683, 22

[360]
[1330]
[2935, 652]
[456]
[314]
[4319]
[507]
[1898]
[1969, 527, 816]
[867, 1581, 1032, 53]
[3268]
[148, 767, 43]
[1161]
[718]
[3298]
[538, 1287, 1712]
[602]
[9]
[963]
[149]
[1010, 289]
[630, 1722]
[3584]
[4179]
[1728]
[2232]
[1328]
[1180]
[1783]
[2193]
[3824]
[192]
[484]
[1993]
[393, 546]
[2910]
[3609]
[1615, 17]
[198]
[290]
[3656]
[3177]
[385]
[1907, 24]
[4061]
[247, 69]
[2240]
[915]
[1061]
[1499, 880, 1183]
[2169]
[2039]
[579]
[918]
[3083]
[2954]
[1976]
[3747]
[1739, 1200, 311, 712, 26]
[3864]
[3993]
[2212]
[272, 44]
[224]
[2435, 47]
[1898]
[149, 42]
[965]
[82, 42]
[4038]
[915, 534, 549, 410]
[558]
[3632, 404]
[533]
[750]
[2434]
[1803]
[273]
[31]
[1324]
[2787]
[1419]
[56]
[1999]
[4061]
[1905, 814]
[56]
[509]
[673, 28]
[1177]
[2002]
[71]
[364]
[321]
[3969]
[1568, 235, 1448]
[1737]
[3583, 26]
[369, 880]
[3604, 51, 643]
[1210]
[2615]
[1640]
[4299]
[4009]
[1083]
[1062]
[4035]
[484, 1943, 1417]
[1619, 838, 48, 958]
[264]
[51, 933, 20]
[3267, 287]
[3199]
[1153]
[2840]
[1590]
[2692, 12

In [11]:
# Number of visit rows in the training table.
print(len(train_visits))

45286


## 3. Model

x = tf.placeholder(tf.float64, name = 'x')
y = tf.placeholder(tf.float64, name = 'y')

weight = tf.Variable(tf.random_normal([None, None]))
bias = tf.Variable(tf.random_normal([None]))

In [12]:
def model_inputs():
    """Create the graph's feed placeholders.

    Returns:
        A tuple ``(inputs_, labels_, keep_prob_)``: two int32 placeholders of
        unconstrained 2-D shape named 'inputs' and 'labels', and a float32
        placeholder named 'keep_prob' (fed to the dropout wrappers).
    """
    def _int_matrix(name):
        # 2-D int32 placeholder whose two dimensions are left unspecified.
        return tf.placeholder(tf.int32, [None, None], name=name)

    inputs_ = _int_matrix('inputs')
    labels_ = _int_matrix('labels')
    keep_prob_ = tf.placeholder(tf.float32, name='keep_prob')
    return inputs_, labels_, keep_prob_

In [13]:
def build_embedding_layer(inputs_, vocab_size, embed_size):
    """Embed integer ids into dense vectors.

    Args:
        inputs_: integer tensor of ids to look up.
        vocab_size: number of rows in the embedding table.
        embed_size: width of each embedding vector.

    Returns:
        The tensor produced by looking ``inputs_`` up in a trainable
        ``(vocab_size, embed_size)`` table initialised uniformly in [-1, 1).
    """
    table_init = tf.random_uniform((vocab_size, embed_size), minval=-1, maxval=1)
    embedding_table = tf.Variable(table_init)
    return tf.nn.embedding_lookup(embedding_table, inputs_)

In [14]:
def build_lstm_layers(lstm_sizes, embed, keep_prob_, batch_size):
    """Build a stacked LSTM with output dropout and unroll it over ``embed``.

    Args:
        lstm_sizes: iterable of hidden sizes, one per stacked layer.
        embed: input tensor handed to ``tf.nn.dynamic_rnn``.
        keep_prob_: keep probability for each layer's output dropout.
        batch_size: batch size used to build the all-zero initial state.

    Returns:
        ``(initial_state, lstm_outputs, cell, final_state)``.
    """
    # Alternative cell type, kept for reference:
    #lstms = [tf.contrib.cudnn_rnn.CudnnCompatibleLSTMCell(size) for size in lstm_sizes]
    base_cells = []
    for hidden_size in lstm_sizes:
        base_cells.append(tf.contrib.rnn.BasicLSTMCell(hidden_size))

    # Wrap every layer so dropout is applied to its outputs.
    dropout_cells = []
    for base_cell in base_cells:
        dropout_cells.append(
            tf.contrib.rnn.DropoutWrapper(base_cell, output_keep_prob=keep_prob_))

    # Stack the wrapped layers into one multi-layer cell.
    cell = tf.contrib.rnn.MultiRNNCell(dropout_cells)

    # All-zero starting state for a batch of `batch_size` sequences.
    initial_state = cell.zero_state(batch_size, tf.float32)

    lstm_outputs, final_state = tf.nn.dynamic_rnn(
        cell, embed, initial_state=initial_state)

    return initial_state, lstm_outputs, cell, final_state

In [15]:
def build_cost_fn_and_opt(lstm_outputs, labels_, learning_rate):
    """Add the prediction head, the loss and the training op.

    Args:
        lstm_outputs: output tensor of the recurrent stack; only the slice
            ``[:, -1]`` (last position along the second axis) is used.
        labels_: target tensor compared against the predictions.
        learning_rate: learning rate passed to the Adadelta optimizer.

    Returns:
        ``(predictions, loss, train_op)`` where ``predictions`` is a single
        ReLU-activated unit per example, ``loss`` is the mean squared error
        against ``labels_`` and ``train_op`` minimises ``loss``.

    NOTE(review): a ReLU output combined with MSE is used for what the caller
    builds as 0/1 labels — confirm this head is intended.
    """
    last_step = lstm_outputs[:, -1]
    predictions = tf.contrib.layers.fully_connected(last_step, 1, activation_fn=tf.nn.relu)
    loss = tf.losses.mean_squared_error(labels_, predictions)
    train_op = tf.train.AdadeltaOptimizer(learning_rate).minimize(loss)

    return predictions, loss, train_op

In [16]:
def build_accuracy(predictions, labels_):
    """Build the accuracy tensor.

    Predictions are rounded and cast to int32, compared element-wise with
    ``labels_``, and the fraction of matches is returned as a float32 scalar.
    """
    rounded = tf.cast(tf.round(predictions), tf.int32)
    hits = tf.cast(tf.equal(rounded, labels_), tf.float32)
    return tf.reduce_mean(hits)

In [17]:
def get_batches(x, y, batch_size=100):
    """
    Yield consecutive, equally sized (x, y) mini-batches.

    Only full batches are produced: trailing items that do not fill a whole
    batch are dropped.

    :param x: sliceable sequence/array of inputs
    :param y: sliceable sequence/array of targets, aligned with x
    :param batch_size: number of items per batch
    :return: generator of (x_batch, y_batch) tuples
    """
    usable = (len(x) // batch_size) * batch_size
    xs = x[:usable]
    ys = y[:usable]
    for lo in range(0, len(xs), batch_size):
        hi = lo + batch_size
        yield xs[lo:hi], ys[lo:hi]

In [18]:
def build_and_train_network(lstm_sizes, vocab_size, embed_size, epochs, batch_size,
                            learning_rate, keep_prob, train_x, val_x, train_y, val_y):
    """Build the embedding + stacked-LSTM model, train it and save a checkpoint.

    The graph is built in the current default graph. After the last training
    batch of every epoch the model is evaluated on the validation data and a
    progress line is printed. When all epochs are done the variables are saved
    to ``checkpoints/sentiment.ckpt`` (the directory is created if missing).

    Args:
        lstm_sizes: hidden sizes of the stacked LSTM layers.
        vocab_size: number of rows of the embedding table.
        embed_size: embedding width.
        epochs: number of passes over the training data.
        batch_size: fixed batch size; incomplete trailing batches are dropped
            by ``get_batches``.
        learning_rate: optimizer learning rate.
        keep_prob: dropout keep probability used while training (validation
            always uses 1).
        train_x, train_y: training inputs and labels (arrays; labels are fed
            as ``y[:, None]``).
        val_x, val_y: validation inputs and labels, same layout.

    Returns:
        None.
    """
    inputs_, labels_, keep_prob_ = model_inputs()
    embed = build_embedding_layer(inputs_, vocab_size, embed_size)
    initial_state, lstm_outputs, lstm_cell, final_state = build_lstm_layers(lstm_sizes, embed, keep_prob_, batch_size)
    predictions, loss, optimizer = build_cost_fn_and_opt(lstm_outputs, labels_, learning_rate)
    accuracy = build_accuracy(predictions, labels_)

    saver = tf.train.Saver()
    checkpoint_path = "checkpoints/sentiment.ckpt"

    with tf.Session() as sess:

        sess.run(tf.global_variables_initializer())
        n_batches = len(train_x)//batch_size
        for e in range(epochs):
            # Evaluating `initial_state` without feeding it yields the zero state.
            # NOTE(review): the final state of one batch is fed as the initial
            # state of the next batch — confirm this carry-over is intended.
            state = sess.run(initial_state)

            train_acc = []
            # `ii` is 1-based, so it equals n_batches on the epoch's last batch.
            for ii, (x, y) in enumerate(get_batches(train_x, train_y, batch_size), 1):
                feed = {inputs_: x,
                        labels_: y[:, None],
                        keep_prob_: keep_prob,
                        initial_state: state}

                loss_, state, _,  batch_acc = sess.run([loss, final_state, optimizer, accuracy], feed_dict=feed)
                train_acc.append(batch_acc)

                # Validate once per epoch, after the LAST training batch, so the
                # reported train accuracy covers the whole epoch.
                if ii == n_batches:

                    val_acc = []
                    # Reuse the existing zero-state tensors rather than adding
                    # new zero_state ops to the graph on every epoch.
                    val_state = sess.run(initial_state)
                    for xx, yy in get_batches(val_x, val_y, batch_size):
                        feed = {inputs_: xx,
                                labels_: yy[:, None],
                                keep_prob_: 1,
                                initial_state: val_state}
                        val_batch_acc, val_state = sess.run([accuracy, final_state], feed_dict=feed)
                        val_acc.append(val_batch_acc)

                    print("Epoch: {}/{}...".format(e+1, epochs),
                          "Batch: {}/{}...".format(ii, n_batches),
                          "Train Loss: {:.3f}...".format(loss_),
                          "Train Accruacy: {:.3f}...".format(np.mean(train_acc)),
                          "Val Accuracy: {:.3f}".format(np.mean(val_acc)))

        # Saver.save() does not create missing parent directories.
        checkpoint_dir = os.path.dirname(checkpoint_path)
        if not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        saver.save(sess, checkpoint_path)

In [19]:
# Define Inputs and Hyperparameters
# Hidden sizes of the two stacked LSTM layers.
lstm_sizes = [128, 64]
# NOTE(review): the embedding is indexed by the values in train_x (gap hours),
# yet its size is derived from the number of visit rows. This works only while
# every input value stays below len(train_visits) + 1 — confirm, and consider
# deriving it from the largest gap value instead.
vocab_size = len(train_visits) + 1 #add one for padding
embed_size = 300
epochs = 50
batch_size = 256
learning_rate = 0.001
# Dropout keep probability used while training.
keep_prob = 0.5

In [20]:
# One gap sequence (list of hour gaps) per user; the lists differ in length.
g = [visit_dict[u]['g'] for u in visit_dict.keys()]
# dtype=object is required because the sequences are ragged: recent NumPy
# versions refuse to build an array from lists of unequal length otherwise.
# This array is only a placeholder — the padding cell replaces train_x with a
# rectangular integer matrix.
train_x = np.array(g, dtype=object)
# Label: 1 if the user has more than one visit in the training data, else 0.
# NOTE(review): the label is derived from the length of the very sequence that
# is fed as input, and validation reuses the training arrays — confirm this is
# intended.
train_y = np.array([1 if len(g_value) > 1 else 0 for g_value in g])
val_x = train_x
val_y = train_y

In [None]:
# Longest gap sequence across all users; every sequence is padded to this length.
chk_list = list(map(len, g))
max_size = max(chk_list)
print(max_size)

# Left-pad each sequence with zeros so the real values sit at the end of the row.
new_g = [[0] * (max_size - len(g_value)) + g_value for g_value in g]
train_x = np.array(new_g)
val_x = train_x

train_x

15


array([[   0,    0,    0, ...,    0,    0,  818],
       [   0,    0,    0, ...,    0,    0, 2045],
       [   0,    0,    0, ...,    0,    0, 3098],
       ...,
       [   0,    0,    0, ...,    0,    0, 1233],
       [   0,    0,    0, ...,    0,    0,  455],
       [   0,    0,    0, ...,    0,    0, 4302]])

In [None]:
# Build and train inside a fresh graph so re-running this cell does not pile
# ops onto a previously built graph.
with tf.Graph().as_default():
    build_and_train_network(
        lstm_sizes=lstm_sizes,
        vocab_size=vocab_size,
        embed_size=embed_size,
        epochs=epochs,
        batch_size=batch_size,
        learning_rate=learning_rate,
        keep_prob=keep_prob,
        train_x=train_x,
        val_x=val_x,
        train_y=train_y,
        val_y=val_y,
    )

Epoch: 1/50... Batch: 115/115... Train Loss: 0.273... Train Accruacy: 0.656... Val Accuracy: 0.714
Epoch: 2/50... Batch: 115/115... Train Loss: 0.269... Train Accruacy: 0.663... Val Accuracy: 0.714
Epoch: 3/50... Batch: 115/115... Train Loss: 0.279... Train Accruacy: 0.672... Val Accuracy: 0.714
Epoch: 4/50... Batch: 115/115... Train Loss: 0.261... Train Accruacy: 0.673... Val Accuracy: 0.714
Epoch: 5/50... Batch: 115/115... Train Loss: 0.251... Train Accruacy: 0.671... Val Accuracy: 0.714
Epoch: 6/50... Batch: 115/115... Train Loss: 0.261... Train Accruacy: 0.677... Val Accuracy: 0.714
Epoch: 7/50... Batch: 115/115... Train Loss: 0.258... Train Accruacy: 0.677... Val Accuracy: 0.714
Epoch: 8/50... Batch: 115/115... Train Loss: 0.243... Train Accruacy: 0.679... Val Accuracy: 0.714
Epoch: 9/50... Batch: 115/115... Train Loss: 0.223... Train Accruacy: 0.678... Val Accuracy: 0.714
Epoch: 10/50... Batch: 115/115... Train Loss: 0.256... Train Accruacy: 0.675... Val Accuracy: 0.714


In [None]:
def test_network(model_dir, batch_size, test_x, test_y):
    """Restore the latest checkpoint from ``model_dir`` and print test accuracy.

    The model graph is rebuilt in the current default graph using the
    notebook-level hyperparameters ``vocab_size``, ``embed_size``,
    ``lstm_sizes`` and ``learning_rate``, which therefore must match the values
    used for training. The loss/optimizer ops are rebuilt as well, presumably so
    that the set of variables known to the Saver matches the checkpoint.

    Args:
        model_dir: directory containing the checkpoint files.
        batch_size: fixed batch size; incomplete trailing batches are dropped
            by ``get_batches``.
        test_x, test_y: test inputs and labels (arrays; labels are fed as
            ``y[:, None]``).

    Returns:
        None. The mean per-batch accuracy is printed.

    Raises:
        ValueError: if ``model_dir`` contains no checkpoint.
    """
    # Fail early with a clear message instead of passing None to Saver.restore().
    checkpoint = tf.train.latest_checkpoint(model_dir)
    if checkpoint is None:
        raise ValueError("No checkpoint found in {!r}; train the model first.".format(model_dir))

    inputs_, labels_, keep_prob_ = model_inputs()
    embed = build_embedding_layer(inputs_, vocab_size, embed_size)
    initial_state, lstm_outputs, lstm_cell, final_state = build_lstm_layers(lstm_sizes, embed, keep_prob_, batch_size)
    predictions, loss, optimizer = build_cost_fn_and_opt(lstm_outputs, labels_, learning_rate)
    accuracy = build_accuracy(predictions, labels_)

    saver = tf.train.Saver()

    test_acc = []
    with tf.Session() as sess:
        saver.restore(sess, checkpoint)
        # Evaluating `initial_state` without feeding it yields the zero state;
        # no extra zero_state ops are added to the graph.
        test_state = sess.run(initial_state)
        for ii, (x, y) in enumerate(get_batches(test_x, test_y, batch_size), 1):
            feed = {inputs_: x,
                    labels_: y[:, None],
                    keep_prob_: 1,  # no dropout at test time
                    initial_state: test_state}
            batch_acc, test_state = sess.run([accuracy, final_state], feed_dict=feed)
            test_acc.append(batch_acc)
        print("Test Accuracy: {:.3f}".format(np.mean(test_acc)))

In [None]:
# Evaluate the saved model inside a fresh graph.
# NOTE(review): `test_x` and `test_y` are not defined in any visible cell, so
# this cell raises NameError on a fresh kernel. They presumably need to be built
# from `test_visits` with the same gap/padding steps used for `train_x` — confirm.
with tf.Graph().as_default():
    test_network('checkpoints', batch_size, test_x, test_y)