In [1]:
from keras.models import Sequential
from keras.layers import Dense, Dropout, Embedding, LSTM, GRU, Bidirectional, TimeDistributed, BatchNormalization, Embedding

from numpy import array
from keras.models import load_model
from keras.utils import np_utils
from keras.callbacks import ModelCheckpoint
import os

import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn import preprocessing

from trackml.dataset import load_event, load_dataset
from trackml.score import score_event

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import warnings
warnings.filterwarnings('ignore')

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
scl = preprocessing.StandardScaler()

#https://www.kaggle.com/mikhailhushchyn/dbscan-benchmark
#https://www.kaggle.com/mikhailhushchyn/hough-transform
def norm_points(df):
    x = df.x.values
    y = df.y.values
    z = df.z.values
    r = np.sqrt(x**2 + y**2 + z**2)
    df['x2'] = x/r
    df['y2'] = y/r
    df['z2'] = z / np.sqrt(x**2 + y**2)
    return df


In [3]:
def append_to_csv(batch, csv_file):
    props = dict(encoding='utf-8', index=False)
    if not os.path.exists(csv_file):
        batch.to_csv(csv_file, **props)
    else:
        batch.to_csv(csv_file, mode='a', header=False, **props)

def delete_file_if_exists(filename):
    if os.path.exists(filename):
        os.remove(filename)

In [4]:
def create_one_event_submission(event_id, hits, labels):
    sub_data = np.column_stack(([event_id]*len(hits), hits.hit_id.values, labels))
    submission = pd.DataFrame(data=sub_data, columns=["event_id", "hit_id", "track_id"]).astype(int)
    return submission

In [5]:
hits = pd.read_csv('../cache/hits_6488.csv')
# hits['new_pid'] = hits.particle_id.astype('str') + '_' + hits.event_id.astype('str') 
# hits = hits[hits.nhits >= 9]
hits = norm_points(hits)

In [6]:
hits.head()

Unnamed: 0,hit_id,x,y,z,volume_id,layer_id,module_id,particle_id,tx,ty,...,px,py,pz,q,nhits,event_id,new_pid,x2,y2,z2
0,2,-72.4041,-1.15933,-1502.5,7,2,1,702569307170668544,-72.4163,-1.13906,...,-0.376872,-0.027696,-7.8103,1.0,14.0,6488,702569307170668544_6488,-0.048133,-0.000771,-20.748928
1,3,-56.7804,-8.25539,-1502.5,7,2,1,58552223994478592,-56.7717,-8.2386,...,-0.993228,-0.127344,-26.2075,-1.0,12.0,6488,58552223994478592_6488,-0.037763,-0.00549,-26.186271
2,4,-93.1412,-14.678,-1502.5,7,2,1,1062854322422808576,-93.1464,-14.6696,...,-0.194728,0.000293,-3.03961,-1.0,16.0,6488,1062854322422808576_6488,-0.061869,-0.00975,-15.934772
3,6,-68.1201,-6.84073,-1502.5,7,2,1,833166824416739328,-68.1281,-6.81962,...,-0.170736,-0.03791,-3.82846,1.0,12.0,6488,833166824416739328_6488,-0.045291,-0.004548,-21.946251
4,7,-60.4104,-11.6005,-1502.5,7,2,1,887215586222800896,-60.3889,-11.6158,...,-0.35928,-0.049738,-8.85119,-1.0,11.0,6488,887215586222800896_6488,-0.040173,-0.007714,-24.425283


In [7]:
hits.shape

(121814, 28)

In [8]:
hits.particle_id.nunique()

10149

In [9]:
new_pid_list = list(set(hits.new_pid.values))

new_pid_count = list(range(len(new_pid_list)))
new_pid_dict = dict(zip(new_pid_list, new_pid_count))
    
    
hits['nid'] = hits['new_pid'].map(lambda x: new_pid_dict[x])
hits.sort_values(['nid', 'z'], inplace=True)

In [10]:
hits.tail()

Unnamed: 0,hit_id,x,y,z,volume_id,layer_id,module_id,particle_id,tx,ty,...,py,pz,q,nhits,event_id,new_pid,x2,y2,z2,nid
96469,116061,-471.454,154.666,52.0,13,6,933,216176218087620608,-471.462,154.651,...,-0.089198,0.021324,1.0,13.0,6488,216176218087620608_6488,-0.945,0.310018,0.104802,10148
96119,115565,-473.197,157.971,52.2,13,6,855,216176218087620608,-473.192,157.982,...,-0.089198,0.021324,1.0,13.0,6488,216176218087620608_6488,-0.943389,0.314939,0.104637,10148
96114,115559,-476.419,164.489,52.5429,13,6,854,216176218087620608,-476.431,164.469,...,-0.089198,0.021324,1.0,13.0,6488,216176218087620608_6488,-0.940152,0.324598,0.104249,10148
100945,123411,-493.809,433.838,83.2,13,8,1213,216176218087620608,-493.831,433.819,...,-0.089198,0.021324,1.0,13.0,6488,216176218087620608_6488,-0.745305,0.654791,0.126576,10148
100942,123407,-489.842,445.385,84.4,13,8,1212,216176218087620608,-489.841,445.386,...,-0.089198,0.021324,1.0,13.0,6488,216176218087620608_6488,-0.733944,0.667333,0.127482,10148


In [11]:
model = Sequential()
# model.add(BatchNormalization(input_shape=(1,2)))
model.add(Bidirectional(LSTM(100, return_sequences=True), input_shape=(1,3)))
# model.add(Bidirectional(LSTM(100, return_sequences=True)))
# model.add(Bidirectional(LSTM(100, return_sequences=True)))
model.add(Bidirectional(LSTM(100)))
# model.add(Dropout(0.3))
model.add(Dense(32, activation='relu'))

# model.add(Dense(100000, activation='relu'))

# model.add(Dropout(0.1))
model.add(Dense(10149, activation='softmax'))
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam')

In [12]:
for n in tqdm(range(102)):
#     if n == 0:
#         continue
    w_start = n*100
    w_end = (n+1)*100
    if w_end > 10149:
        w_end = 10149
    hits1 = hits[(hits.nid >= w_start) & (hits.nid < w_end)]
   
    str1='../cache/checkpoint/checkpoint-6488-copy3-{}'.format(n)
    str2 = '{epoch:02d}.hdf5'
    filepath = str1 + '-' + str2
    checkpoint = ModelCheckpoint(filepath, verbose=0, save_best_only=False)
    callbacks_list = [checkpoint]

#     print(hits1.head())
    X_train = scl.fit_transform(hits1[['x', 'y', 'z']].values)
#     X_train = hits1[['x', 'y', 'z']].values
    
    y = hits1['nid'].values
    print(y)
#     print(len(y))
    print(X_train.shape)
    
    model.fit(X_train.reshape(X_train.shape[0], 1, 3), y, batch_size=32, epochs=3000, shuffle=False, verbose=0, 
          callbacks=callbacks_list)

  0%|          | 0/102 [00:00<?, ?it/s]

[ 0  0  0 ... 99 99 99]
(1174, 3)


  1%|          | 1/102 [1:02:09<104:38:15, 3729.66s/it]

[100 100 100 ... 199 199 199]
(1212, 3)


  2%|▏         | 2/102 [2:00:08<100:07:13, 3604.34s/it]

[200 200 200 ... 299 299 299]
(1203, 3)


  3%|▎         | 3/102 [2:59:26<98:41:49, 3588.98s/it] 

[300 300 300 ... 399 399 399]
(1195, 3)


  4%|▍         | 4/102 [3:58:37<97:26:27, 3579.46s/it]

[400 400 400 ... 499 499 499]
(1202, 3)


  5%|▍         | 5/102 [4:59:08<96:43:16, 3589.65s/it]

[500 500 500 ... 599 599 599]
(1202, 3)


  6%|▌         | 6/102 [5:59:16<95:48:20, 3592.71s/it]

[600 600 600 ... 699 699 699]
(1183, 3)


  7%|▋         | 7/102 [6:58:10<94:35:19, 3584.42s/it]

[700 700 700 ... 799 799 799]
(1195, 3)


  8%|▊         | 8/102 [7:58:15<93:39:35, 3586.98s/it]

[800 800 800 ... 899 899 899]
(1209, 3)


  9%|▉         | 9/102 [8:58:37<92:45:45, 3590.81s/it]

[900 900 900 ... 999 999 999]
(1195, 3)


 10%|▉         | 10/102 [9:58:38<91:47:33, 3591.89s/it]

[1000 1000 1000 ... 1099 1099 1099]
(1200, 3)


 11%|█         | 11/102 [10:58:39<90:48:53, 3592.68s/it]

[1100 1100 1100 ... 1199 1199 1199]
(1196, 3)


 12%|█▏        | 12/102 [11:58:49<89:51:09, 3594.10s/it]

[1200 1200 1200 ... 1299 1299 1299]
(1207, 3)


 13%|█▎        | 13/102 [12:59:22<88:55:46, 3597.15s/it]

[1300 1300 1300 ... 1399 1399 1399]
(1192, 3)


 14%|█▎        | 14/102 [14:01:41<88:10:36, 3607.23s/it]

[1400 1400 1400 ... 1499 1499 1499]
(1188, 3)


 15%|█▍        | 15/102 [15:03:25<87:19:53, 3613.72s/it]

[1500 1500 1500 ... 1599 1599 1599]
(1192, 3)


 16%|█▌        | 16/102 [16:02:16<86:12:14, 3608.54s/it]

[1600 1600 1600 ... 1699 1699 1699]
(1197, 3)


 17%|█▋        | 17/102 [17:00:36<85:03:03, 3602.16s/it]

[1700 1700 1700 ... 1799 1799 1799]
(1216, 3)


 18%|█▊        | 18/102 [18:01:23<84:06:28, 3604.63s/it]

[1800 1800 1800 ... 1899 1899 1899]
(1169, 3)


 19%|█▊        | 19/102 [18:58:14<82:52:21, 3594.47s/it]

[1900 1900 1900 ... 1999 1999 1999]
(1201, 3)


 20%|█▉        | 20/102 [19:55:11<81:40:15, 3585.55s/it]

[2000 2000 2000 ... 2099 2099 2099]
(1199, 3)


 21%|██        | 21/102 [20:55:49<80:43:55, 3588.09s/it]

[2100 2100 2100 ... 2199 2199 2199]
(1226, 3)


 22%|██▏       | 22/102 [21:56:53<79:48:43, 3591.54s/it]

[2200 2200 2200 ... 2299 2299 2299]
(1218, 3)


 23%|██▎       | 23/102 [22:54:53<78:42:26, 3586.67s/it]

[2300 2300 2300 ... 2399 2399 2399]
(1224, 3)


 24%|██▎       | 24/102 [23:54:24<77:41:48, 3586.01s/it]

[2400 2400 2400 ... 2499 2499 2499]
(1201, 3)


 25%|██▍       | 25/102 [24:56:01<76:47:45, 3590.47s/it]

[2500 2500 2500 ... 2599 2599 2599]
(1198, 3)


 25%|██▌       | 26/102 [25:54:21<75:43:31, 3587.00s/it]

[2600 2600 2600 ... 2699 2699 2699]
(1205, 3)


 26%|██▋       | 27/102 [26:54:55<74:45:53, 3588.71s/it]

[2700 2700 2700 ... 2799 2799 2799]
(1213, 3)


 27%|██▋       | 28/102 [27:55:50<73:48:59, 3591.07s/it]

[2800 2800 2800 ... 2899 2899 2899]
(1196, 3)


 28%|██▊       | 29/102 [28:56:32<72:51:18, 3592.86s/it]

[2900 2900 2900 ... 2999 2999 2999]
(1211, 3)


 29%|██▉       | 30/102 [29:58:31<71:56:28, 3597.06s/it]

[3000 3000 3000 ... 3099 3099 3099]
(1188, 3)


 30%|███       | 31/102 [31:00:27<71:01:03, 3600.90s/it]

[3100 3100 3100 ... 3199 3199 3199]
(1191, 3)


 31%|███▏      | 32/102 [32:02:29<70:05:26, 3604.66s/it]

[3200 3200 3200 ... 3299 3299 3299]
(1208, 3)


 32%|███▏      | 33/102 [33:05:05<69:10:38, 3609.25s/it]

[3300 3300 3300 ... 3399 3399 3399]
(1185, 3)


 33%|███▎      | 34/102 [34:07:00<68:14:00, 3612.36s/it]

[3400 3400 3400 ... 3499 3499 3499]
(1213, 3)


 34%|███▍      | 35/102 [35:09:35<67:18:21, 3616.44s/it]

[3500 3500 3500 ... 3599 3599 3599]
(1179, 3)


 35%|███▌      | 36/102 [36:10:19<66:18:56, 3617.22s/it]

[3600 3600 3600 ... 3699 3699 3699]
(1202, 3)


 36%|███▋      | 37/102 [37:12:36<65:22:09, 3620.45s/it]

[3700 3700 3700 ... 3799 3799 3799]
(1203, 3)


 37%|███▋      | 38/102 [38:14:58<64:25:13, 3623.64s/it]

[3800 3800 3800 ... 3899 3899 3899]
(1190, 3)


 38%|███▊      | 39/102 [39:16:23<63:26:29, 3625.23s/it]

[3900 3900 3900 ... 3999 3999 3999]
(1202, 3)


 39%|███▉      | 40/102 [40:14:24<62:22:19, 3621.60s/it]

[4000 4000 4000 ... 4099 4099 4099]
(1195, 3)


 40%|████      | 41/102 [41:11:10<61:16:38, 3616.37s/it]

[4100 4100 4100 ... 4199 4199 4199]
(1207, 3)


 41%|████      | 42/102 [42:08:31<60:12:11, 3612.18s/it]

[4200 4200 4200 ... 4299 4299 4299]
(1181, 3)


 42%|████▏     | 43/102 [43:04:59<59:06:50, 3606.96s/it]

[4300 4300 4300 ... 4399 4399 4399]
(1210, 3)


 43%|████▎     | 44/102 [44:05:38<58:07:26, 3607.70s/it]

[4400 4400 4400 ... 4499 4499 4499]
(1200, 3)


 44%|████▍     | 45/102 [45:07:12<57:09:07, 3609.61s/it]

[4500 4500 4500 ... 4599 4599 4599]
(1208, 3)


 45%|████▌     | 46/102 [46:08:12<56:09:59, 3610.70s/it]

[4600 4600 4600 ... 4699 4699 4699]
(1212, 3)


 46%|████▌     | 47/102 [47:09:22<55:10:58, 3611.96s/it]

[4700 4700 4700 ... 4799 4799 4799]
(1199, 3)


 47%|████▋     | 48/102 [48:10:23<54:11:41, 3613.00s/it]

[4800 4800 4800 ... 4899 4899 4899]
(1184, 3)


 48%|████▊     | 49/102 [49:10:19<53:11:10, 3612.65s/it]

[4900 4900 4900 ... 4999 4999 4999]
(1203, 3)


 49%|████▉     | 50/102 [50:11:23<52:11:51, 3613.68s/it]

[5000 5000 5000 ... 5099 5099 5099]
(1194, 3)


 50%|█████     | 51/102 [51:12:10<51:12:10, 3614.33s/it]

[5100 5100 5100 ... 5199 5199 5199]
(1197, 3)


 51%|█████     | 52/102 [52:13:01<50:12:31, 3615.03s/it]

[5200 5200 5200 ... 5299 5299 5299]
(1184, 3)


 52%|█████▏    | 53/102 [53:12:45<49:11:48, 3614.45s/it]

[5300 5300 5300 ... 5399 5399 5399]
(1195, 3)


 53%|█████▎    | 54/102 [54:13:07<48:11:39, 3614.58s/it]

[5400 5400 5400 ... 5499 5499 5499]
(1213, 3)


 54%|█████▍    | 55/102 [55:14:16<47:12:11, 3615.57s/it]

[5500 5500 5500 ... 5599 5599 5599]
(1182, 3)


 55%|█████▍    | 56/102 [56:14:44<46:12:06, 3615.80s/it]

[5600 5600 5600 ... 5699 5699 5699]
(1187, 3)


 56%|█████▌    | 57/102 [57:16:03<45:12:40, 3616.91s/it]

[5700 5700 5700 ... 5799 5799 5799]
(1153, 3)


 57%|█████▋    | 58/102 [58:15:41<44:11:54, 3616.23s/it]

[5800 5800 5800 ... 5899 5899 5899]
(1192, 3)


 58%|█████▊    | 59/102 [59:16:59<43:12:23, 3617.28s/it]

[5900 5900 5900 ... 5999 5999 5999]
(1182, 3)


 59%|█████▉    | 60/102 [60:17:26<42:12:12, 3617.45s/it]

[6000 6000 6000 ... 6099 6099 6099]
(1200, 3)


 60%|█████▉    | 61/102 [61:19:36<41:13:10, 3619.28s/it]

[6100 6100 6100 ... 6199 6199 6199]
(1214, 3)


 61%|██████    | 62/102 [62:26:56<40:17:22, 3626.07s/it]

[6200 6200 6200 ... 6299 6299 6299]
(1192, 3)


 62%|██████▏   | 63/102 [63:34:33<39:21:23, 3632.91s/it]

[6300 6300 6300 ... 6399 6399 6399]
(1210, 3)


 63%|██████▎   | 64/102 [64:40:50<38:24:14, 3638.29s/it]

[6400 6400 6400 ... 6499 6499 6499]
(1184, 3)


 64%|██████▎   | 65/102 [65:45:50<37:26:05, 3642.31s/it]

[6500 6500 6500 ... 6599 6599 6599]
(1197, 3)


 65%|██████▍   | 66/102 [66:50:24<36:27:29, 3645.82s/it]

[6600 6600 6600 ... 6699 6699 6699]
(1214, 3)


 66%|██████▌   | 67/102 [67:56:46<35:29:39, 3650.84s/it]

[6700 6700 6700 ... 6799 6799 6799]
(1210, 3)


 67%|██████▋   | 68/102 [69:06:08<34:33:04, 3658.36s/it]

[6800 6800 6800 ... 6899 6899 6899]
(1179, 3)


 68%|██████▊   | 69/102 [70:14:48<33:35:46, 3665.05s/it]

[6900 6900 6900 ... 6999 6999 6999]
(1202, 3)


 69%|██████▊   | 70/102 [71:24:35<32:38:40, 3672.50s/it]

[7000 7000 7000 ... 7099 7099 7099]
(1179, 3)


 70%|██████▉   | 71/102 [72:32:42<31:40:28, 3678.34s/it]

[7100 7100 7100 ... 7199 7199 7199]
(1220, 3)


 71%|███████   | 72/102 [73:42:32<30:42:43, 3685.45s/it]

[7200 7200 7200 ... 7299 7299 7299]
(1200, 3)


 72%|███████▏  | 73/102 [74:50:29<29:43:53, 3690.81s/it]

[7300 7300 7300 ... 7399 7399 7399]
(1240, 3)


 73%|███████▎  | 74/102 [76:00:03<28:45:25, 3697.34s/it]

[7400 7400 7400 ... 7499 7499 7499]
(1207, 3)


 74%|███████▎  | 75/102 [77:09:00<27:46:26, 3703.21s/it]

[7500 7500 7500 ... 7599 7599 7599]
(1199, 3)


 75%|███████▍  | 76/102 [78:17:11<26:46:55, 3708.31s/it]

[7600 7600 7600 ... 7699 7699 7699]
(1220, 3)


 75%|███████▌  | 77/102 [79:26:25<25:47:32, 3714.10s/it]

[7700 7700 7700 ... 7799 7799 7799]
(1190, 3)


 76%|███████▋  | 78/102 [80:34:10<24:47:26, 3718.59s/it]

[7800 7800 7800 ... 7899 7899 7899]
(1196, 3)


 77%|███████▋  | 79/102 [81:43:03<23:47:28, 3723.84s/it]

[7900 7900 7900 ... 7999 7999 7999]
(1202, 3)


 78%|███████▊  | 80/102 [82:52:37<22:47:28, 3729.47s/it]

[8000 8000 8000 ... 8099 8099 8099]
(1178, 3)


 79%|███████▉  | 81/102 [84:00:38<21:46:50, 3733.81s/it]

[8100 8100 8100 ... 8199 8199 8199]
(1193, 3)


 80%|████████  | 82/102 [85:09:23<20:46:11, 3738.58s/it]

[8200 8200 8200 ... 8299 8299 8299]
(1216, 3)


 81%|████████▏ | 83/102 [86:15:51<19:44:50, 3741.58s/it]

[8300 8300 8300 ... 8399 8399 8399]
(1210, 3)


 82%|████████▏ | 84/102 [87:18:36<18:42:33, 3741.87s/it]

[8400 8400 8400 ... 8499 8499 8499]
(1195, 3)


KeyboardInterrupt: 