In [1]:
import math
import collections

import pandas as pd
import numpy as np
from keras import Input,backend
from keras.models import Model, Sequential
from keras.layers import *
from keras.callbacks import Callback
from sklearn.metrics import confusion_matrix, f1_score, precision_score, recall_score, accuracy_score, matthews_corrcoef

Using TensorFlow backend.


In [2]:
dataset = pd.read_csv("../datasets/train_test_data.csv",
                           header=0, parse_dates=[0], index_col=0)

#params for generator
label_index = len(dataset.columns) -1


In [99]:
#hyperparamters
batch_size=64

#params for generator
delay=10
step=1 # 1 timestep = 1 day
lookback=10

#ratio for train/val/test split
train_ratio=0.7
val_ratio=0.15

train_max_idx = math.ceil(len(dataset)*train_ratio)
val_max_idx = math.ceil(len(dataset)*(train_ratio+val_ratio))

# 1 step = 1 batche of samples 
train_steps = (train_max_idx+1) // batch_size
val_steps =  (val_max_idx - train_max_idx - lookback) // batch_size
test_steps = (len(dataset) - val_max_idx - lookback) // batch_size

In [100]:
def generator_for_binary_classifier(data, label_index, lookback, delay, min_index, max_index,
              shuffle=False, batch_size=64, step=1, interval_label=False):
    if max_index is None:
        max_index = len(data) - delay - 1
    i = min_index + lookback
    while 1:
        if shuffle:
            rows = np.random.randint(
                min_index + lookback, max_index, size=batch_size)
        else:
            if i + batch_size >= max_index: 
                i = min_index + lookback #reset 'i'
            rows = np.arange(i, min(i + batch_size, max_index))
            i += len(rows)

        samples = np.zeros((len(rows),
                           lookback // step,
                           (data.shape[-1])))
        labels = np.zeros((len(rows),))
        
        #generate one batch of samples and targets
        for j, row in enumerate(rows):
            indices = range(rows[j] - lookback, rows[j], step)
            samples[j] = data[indices]
            
            if interval_label is False:
                labels[j] = data[rows[j] + delay][label_index]
            else:
                labels[j] = 1 if 1 in data[rows[j]:(rows[j] + delay)][:,label_index] else 0
                
        yield samples, labels

In [101]:
#init generator_for_binary_classifiers
train_rand_gen = generator_for_binary_classifier(dataset.to_numpy(),
                      label_index=label_index,
                      lookback=lookback,
                      delay=delay,
                      min_index=0,
                      max_index=train_max_idx,
                      shuffle=True,
                      step=step, 
                      batch_size=batch_size,
                      interval_label=True)
train_gen = generator_for_binary_classifier(dataset.to_numpy(),
                      label_index=label_index,
                      lookback=lookback,
                      delay=delay,
                      min_index=0,
                      max_index=train_max_idx,
                      shuffle=False,
                      step=step, 
                      batch_size=batch_size,
                      interval_label=True)



val_gen = generator_for_binary_classifier(dataset.to_numpy(),
                    label_index=label_index,
                      lookback=lookback,
                      delay=delay,
                      min_index=train_max_idx+1,
                      max_index=val_max_idx,
                      shuffle=False,
                      step=step, 
                      batch_size=batch_size,
                      interval_label=True)

test_gen = generator_for_binary_classifier(dataset.to_numpy(),
                      label_index=label_index,
                      lookback=lookback,
                      delay=delay,
                      min_index=val_max_idx+1,
                      max_index=None,
                      shuffle=False,
                      step=step, 
                      batch_size=batch_size,
                      interval_label=True)


In [102]:
X,Y=generator_to_samples_and_targets(train_rand_gen, train_steps)

print(X.shape)
print(Y.shape)
print(collections.Counter(Y))

IM here
row = 1672, row + delay=1682
[[0.36344155 0.76376348 0.37913939 0.43582964 0.61883276 0.59931933
  0.43292555 0.61537882 0.50760735 0.50070873 0.57485541 0.49287255
  0.49103385 0.53085565 0.30814734 0.61060973 0.52306413 0.47961165
  0.43650761 0.80162602 0.64424998 0.17189378 0.09909055 0.        ]
 [0.49846771 0.38956216 0.42080776 0.49291701 0.66469573 0.43592373
  0.48490445 0.63536873 0.50760735 0.61299371 0.56539537 0.48026222
  0.49140035 0.4702363  0.30814734 0.61060973 0.52306413 0.47961165
  0.43650761 0.80162602 0.64424998 0.17189378 0.09909055 0.        ]
 [0.55814095 0.50546727 0.4231342  0.44758273 0.59997641 0.54729107
  0.46125571 0.61092158 0.50760735 0.62463792 0.50284342 0.42558106
  0.49213334 0.47358206 0.30814734 0.61060973 0.52306413 0.47961165
  0.43650761 0.80162602 0.64424998 0.17189378 0.09909055 0.        ]
 [0.49144873 0.33788493 0.42623026 0.41692728 0.51546706 0.40652482
  0.47496894 0.56360643 0.42079597 0.56641934 0.54260236 0.54136047
  0.4921

IM here
row = 1997, row + delay=2007
[[0.4282599  0.41612834 0.46468578 0.42686638 0.42669139 0.44929294
  0.47647276 0.56103533 0.4080367  0.62657649 0.52053859 0.50864146
  0.49323293 0.47048538 0.55551945 0.61277498 0.75327619 0.43321963
  0.42071412 0.58861789 0.62224887 0.6454235  0.24128848 0.        ]
 [0.48549885 0.32227833 0.43939128 0.42387103 0.41687391 0.37135559
  0.4636815  0.56343427 0.43312711 0.50576051 0.57384008 0.61977254
  0.48003931 0.51320891 0.55551945 0.61277498 0.75327619 0.43321963
  0.42071412 0.58861789 0.62224887 0.6454235  0.24128848 0.        ]
 [0.41527354 0.3347063  0.35225781 0.38087794 0.30210584 0.40144986
  0.45027273 0.48181455 0.36727222 0.46413082 0.63270182 0.52219922
  0.50129937 0.4566952  0.55551945 0.61277498 0.75327619 0.43321963
  0.42071412 0.58861789 0.62224887 0.6454235  0.24128848 0.        ]
 [0.5530718  0.38081028 0.35420901 0.40520861 0.16665188 0.40957335
  0.45074005 0.56703941 0.36727222 0.50582986 0.5773912  0.52435859
  0.4935

  0.44807267 0.59481595 0.59542995 0.1058515  0.3017     0.        ]]
IM here
row = 576, row + delay=586
[[0.46833033 0.30249897 0.32470982 0.29685766 0.20184959 0.5179885
  0.45317885 0.46786484 0.45420344 0.50993627 0.44548393 0.54621499
  0.49767314 0.48082533 0.38965578 0.54805195 0.71776744 0.56796117
  0.93292908 0.56753956 0.70252639 0.47202418 0.31178869 1.        ]
 [0.46833033 0.34948827 0.32976051 0.29685766 0.71062433 0.3703427
  0.46283115 0.52188426 0.36476674 0.47406992 0.45262742 0.52147522
  0.47665206 0.46257851 0.38965578 0.54805195 0.71776744 0.56796117
  0.93292908 0.56753956 0.70252639 0.47202418 0.31178869 0.        ]
 [0.43042224 0.65148109 0.42170522 0.49447566 0.40161018 0.41879117
  0.57045989 0.4998567  0.46852795 0.58093809 0.48340517 0.55885605
  0.49398544 0.4799385  0.38965578 0.54805195 0.71776744 0.56796117
  0.93292908 0.56753956 0.70252639 0.47202418 0.31178869 0.        ]
 [0.4861301  0.27055368 0.5333236  0.49447566 0.65754689 0.36974034
  0.459854

[[0.55257199 0.30643667 0.31344646 0.36788954 0.37339211 0.40545498
  0.39190726 0.52690171 0.43669865 0.41124671 0.55910229 0.52079241
  0.50576976 0.4441798  0.35270614 0.59563164 0.64174609 0.33656958
  0.43075237 0.57257029 0.6891252  0.18791897 0.26589965 0.        ]
 [0.46563419 0.42628856 0.25186574 0.47108854 0.27424023 0.49254721
  0.53894223 0.5690563  0.43991822 0.58692524 0.57112591 0.54486092
  0.49618876 0.49654228 0.35270614 0.59563164 0.64174609 0.33656958
  0.43075237 0.57257029 0.6891252  0.18791897 0.26589965 1.        ]
 [0.52201774 0.41169982 0.54440398 0.43659612 0.61548385 0.44738382
  0.44219333 0.58242688 0.4725778  0.52760665 0.54671521 0.51728889
  0.49397898 0.47925749 0.35270614 0.59563164 0.64174609 0.33656958
  0.43075237 0.57257029 0.6891252  0.18791897 0.26589965 0.        ]
 [0.49713497 0.24524086 0.42714578 0.41595226 0.37307129 0.35169285
  0.45858602 0.56769481 0.51086366 0.52760665 0.5478298  0.46706523
  0.49250626 0.47074487 0.35270614 0.59563164

  0.4527068  0.60073604 0.63344559 0.21970683 0.28263492 0.        ]]
IM here
row = 3240, row + delay=3250
[[0.52060373 0.48412212 0.48035177 0.53717959 1.         0.43366311
  0.53506327 0.72369652 0.60040795 0.61558208 0.37855447 0.50135084
  0.49103385 0.4968213  0.42100143 0.60195279 0.82456675 0.47043248
  0.42740461 0.57415421 0.76342517 0.25466431 0.34421708 0.        ]
 [0.66949796 0.55577439 0.46568222 0.59160304 0.91768943 0.60256039
  0.52324962 0.76734887 0.62429101 0.66284342 0.49688248 0.44889878
  0.44205841 0.45191327 0.37820644 0.59847738 0.83933454 0.79935275
  0.40511798 0.66789563 0.69069665 0.25796291 0.33039542 0.        ]
 [0.51606163 0.36317177 0.50474418 0.43045625 0.55595049 0.41987869
  0.48002255 0.59638782 0.62429101 0.67665687 0.57013887 0.51879606
  0.4950914  0.51470387 0.37820644 0.59847738 0.83933454 0.79935275
  0.40511798 0.66789563 0.69069665 0.25796291 0.33039542 0.        ]
 [0.55099285 0.44129814 0.42507992 0.4969366  0.71575424 0.52269723
  0.51

  0.39704884 0.60254989 0.66702179 0.04132059 0.31711375 0.        ]]
IM here
row = 350, row + delay=360
[[0.50284831 0.44232783 0.58413295 0.43536097 0.60569275 0.49156057
  0.60378977 0.57998325 0.4743907  0.67730253 0.523572   0.51169662
  0.4969462  0.51137188 0.37551519 0.6359447  0.50527503 0.16973552
  0.46250381 0.57395683 0.62536512 0.41399245 0.39468854 0.        ]
 [0.51544811 0.42361196 0.5121396  0.47333119 0.56779078 0.43127011
  0.7360726  0.64787596 0.37675011 0.67730253 0.65604395 0.52943862
  0.49103385 0.50201586 0.37551519 0.6359447  0.50527503 0.16973552
  0.46250381 0.57395683 0.62536512 0.41399245 0.39468854 0.        ]
 [0.51544811 0.42361196 0.5121396  0.47333119 0.64081752 0.43127011
  0.7360726  0.64787596 0.37675011 0.67730253 0.65604395 0.52943862
  0.49103385 0.50201586 0.37551519 0.6359447  0.50527503 0.16973552
  0.46250381 0.57395683 0.62536512 0.41399245 0.39468854 0.        ]
 [0.5404945  0.38303623 0.44948281 0.36033965 0.43673625 0.40573116
  0.3242

  0.41579049 0.65542294 0.56820645 0.10297283 0.27387194 0.        ]]
IM here
row = 374, row + delay=384
[[0.5558834  0.34668779 0.31778724 0.37110787 0.43968436 0.38619232
  0.39479488 0.67452673 0.43035841 0.59410028 0.75806782 0.57814566
  0.48218295 0.49738241 0.39019571 0.58002429 0.60310875 0.5542423
  0.33610817 0.60620562 0.57957634 0.22978358 0.23875592 0.        ]
 [0.54423034 0.38996543 0.2763965  0.24016808 0.3754683  0.42532372
  0.4608963  0.46829574 0.48099017 0.55291609 0.69068675 0.45741724
  0.44160101 0.48053009 0.39019571 0.58002429 0.60310875 0.5542423
  0.33610817 0.60620562 0.57957634 0.22978358 0.23875592 0.        ]
 [0.47328179 0.35794444 0.23576321 0.32966448 0.44305825 0.39826736
  0.42155041 0.57717642 0.44397315 0.45056174 0.60867475 0.57080242
  0.52687931 0.48697073 0.39019571 0.58002429 0.60310875 0.5542423
  0.33610817 0.60620562 0.57957634 0.22978358 0.23875592 0.        ]
 [0.51258243 0.3424643  0.13220041 0.30166972 0.45840389 0.29437798
  0.4123210

  0.41130539 0.56505665 0.60141869 0.32301747 0.24210496 0.        ]]
IM here
row = 1049, row + delay=1059
[[0.52180046 0.48450473 0.39492012 0.44532721 0.70444404 0.48412608
  0.46931119 0.56321694 0.47066156 0.49592969 0.55792191 0.52525003
  0.49323377 0.48539738 0.30829646 0.58577244 0.52433384 0.45631068
  0.4463702  0.59780281 0.55122103 0.19364774 0.31382605 0.        ]
 [0.4893891  0.47955812 0.3955732  0.46470421 0.51605554 0.43991011
  0.45516099 0.58717179 0.4791618  0.52667603 0.59373017 0.49581512
  0.4932336  0.49503386 0.30829646 0.58577244 0.52433384 0.45631068
  0.4463702  0.59780281 0.55122103 0.19364774 0.31382605 0.        ]
 [0.4893891  0.47955812 0.3955732  0.46470421 0.55658818 0.43991011
  0.45516099 0.58717179 0.4791618  0.52667603 0.59373017 0.49581512
  0.4932336  0.49503386 0.30829646 0.58577244 0.52433384 0.45631068
  0.4463702  0.59780281 0.55122103 0.19364774 0.31382605 0.        ]
 [0.53074859 0.39544299 0.37282505 0.35331641 0.30954255 0.38448024
  0.45

  0.52023901 0.59623829 0.63946761 0.25117412 0.24243353 0.        ]]
IM here
row = 2766, row + delay=2776
[[0.08368903 0.39749479 0.48273116 0.44014825 0.53818359 0.44723895
  0.48144081 0.62835245 0.44947575 0.63488751 0.58851533 0.49767493
  0.48737299 0.47682593 0.45847524 0.78354978 0.55592728 0.58058252
  0.39944851 0.5927142  0.58737324 0.20092134 0.29015407 0.        ]
 [0.63465333 0.37537457 0.39760608 0.4379414  0.50590528 0.41233396
  0.49647302 0.546009   0.46381994 0.50410153 0.53892974 0.49661733
  0.49103385 0.48223695 0.45847524 0.78354978 0.55592728 0.58058252
  0.39944851 0.5927142  0.58737324 0.20092134 0.29015407 0.        ]
 [0.51375986 0.28027148 0.37630927 0.41775902 0.50723241 0.39527514
  0.48567849 0.53210504 0.44352055 0.54678611 0.57930845 0.48701658
  0.50531302 0.4958034  0.45847524 0.78354978 0.55592728 0.58058252
  0.39944851 0.5927142  0.58737324 0.20092134 0.29015407 0.        ]
 [0.65473237 0.33045384 0.37216552 0.41852293 0.53400581 0.35304363
  0.46

  0.46300409 0.57831325 0.41965863 0.89772905 0.23160061 0.        ]]
IM here
row = 3662, row + delay=3672
[[0.590285   0.45999144 0.42252949 0.4260664  0.55166695 0.4373462
  0.50252983 0.51088761 0.51104045 0.58730362 0.42077777 0.52570762
  0.49287037 0.47001024 0.33583393 0.59104161 0.65020908 0.37864078
  0.42033932 0.51884701 0.67416965 0.16884462 0.37898462 0.        ]
 [0.5816694  0.45999144 0.42252949 0.47472375 0.54148729 0.49315807
  0.50252983 0.57095449 0.51104045 0.58730362 0.61131133 0.48287344
  0.45786841 0.47505857 0.30184626 0.59393939 0.69857664 0.37864078
  0.40347537 0.49593496 0.64551582 0.08960159 0.24513956 0.        ]
 [0.5816694  0.39545603 0.41410126 0.45172971 0.47443481 0.43190222
  0.50811731 0.57095449 0.51104045 0.58730362 0.51127585 0.54283402
  0.48084836 0.48939076 0.30184626 0.59393939 0.69857664 0.37864078
  0.40347537 0.49593496 0.64551582 0.08960159 0.24513956 0.        ]
 [0.47644648 0.37632089 0.40673219 0.43284103 0.3726895  0.46256065
  0.501

  0.42543381 0.57560976 0.76843138 0.42635545 0.47167613 0.        ]]
IM here
row = 2365, row + delay=2375
[[0.47216736 0.29641931 0.42328421 0.37393285 0.32304572 0.51081967
  0.46052777 0.44739663 0.38105406 0.49301064 0.59485715 0.49147754
  0.49066546 0.48385392 0.37577523 0.59574468 0.62671938 0.37864078
  0.41453744 0.57491289 0.57507855 0.32145966 0.3306183  0.        ]
 [0.50735305 0.42583418 0.39471281 0.42896313 0.56809216 0.36541362
  0.47502849 0.56596926 0.4644958  0.54081156 0.55982249 0.5048779
  0.47703473 0.47701786 0.37577523 0.59574468 0.62671938 0.37864078
  0.41453744 0.57491289 0.57507855 0.32145966 0.3306183  0.        ]
 [0.52529264 0.37530364 0.37316639 0.43807642 0.4856631  0.44564357
  0.47592189 0.53555331 0.44913461 0.57593727 0.54145851 0.54273376
  0.48403086 0.5068544  0.37577523 0.59574468 0.62671938 0.37864078
  0.41453744 0.57491289 0.57507855 0.32145966 0.3306183  0.        ]
 [0.50071443 0.47643289 0.40252999 0.40309535 0.46975217 0.51965432
  0.469

  0.48511215 0.55937099 0.63161217 0.24583816 0.31037805 0.        ]]
IM here
row = 131, row + delay=141
[[0.42555809 0.4471633  0.40062318 0.41225664 0.55653496 0.50597901
  0.49018939 0.56980286 0.38498359 0.52537601 0.55513389 0.56356118
  0.48328301 0.48799656 0.35000294 1.         0.64926675 0.15064203
  0.38785341 0.59002354 0.59376082 0.42410658 0.25891914 0.        ]
 [0.63109587 0.43590789 0.39859323 0.42252447 0.5548187  0.41752541
  0.47674221 0.60819507 0.396426   0.51401367 0.55645582 0.49777655
  0.49472572 0.49288799 0.35000294 1.         0.64926675 0.15064203
  0.38785341 0.59002354 0.59376082 0.42410658 0.25891914 0.        ]
 [0.49476908 0.44162025 0.41807853 0.40313594 0.57295326 0.4772089
  0.48356865 0.54208082 0.48058476 0.69926526 0.55469055 0.50385669
  0.48586589 0.48874919 0.35000294 1.         0.64926675 0.15064203
  0.38785341 0.59002354 0.59376082 0.42410658 0.25891914 0.        ]
 [0.60445567 0.42408846 0.41424989 0.40542606 0.51567793 0.49057858
  0.45650

  0.41376181 0.59287892 0.62317077 0.46674025 0.28163591 0.        ]]
IM here
row = 1507, row + delay=1517
[[0.4573451  0.50106805 0.41807732 0.44434143 0.41470579 0.52825467
  0.47397421 0.51048007 0.38045229 0.53867672 0.55137243 0.48873444
  0.49030091 0.49942619 0.26079088 0.59390048 0.50885368 0.37864078
  0.41615005 0.57966243 1.         0.25021501 0.21990221 0.        ]
 [0.47736207 0.47499901 0.39653908 0.40994626 0.51484527 0.54083533
  0.47962764 0.59043354 0.45294329 0.56640243 0.56929815 0.51055995
  0.49103385 0.51993274 0.26079088 0.59390048 0.50885368 0.37864078
  0.41615005 0.57966243 1.         0.25021501 0.21990221 0.        ]
 [0.44541634 0.42906733 0.43647922 0.46019216 0.69798956 0.4108183
  0.48465036 0.69786318 0.54513571 0.49151331 0.55699458 0.49546976
  0.49103385 0.45590774 0.26079088 0.59390048 0.50885368 0.37864078
  0.41615005 0.57966243 1.         0.25021501 0.21990221 0.        ]
 [0.34677318 0.3824335  0.40656034 0.41978587 0.35995901 0.47010477
  0.480

IM here
row = 1731, row + delay=1741
[[0.53283371 0.55977425 0.5001887  0.48320514 0.73833207 0.45155466
  0.48995778 0.62068829 0.47382659 0.62195775 0.58353452 0.51324752
  0.49140032 0.49054878 0.         0.55896104 0.40720283 0.37864078
  0.42543381 0.57560976 0.76843138 0.42635545 0.47167613 0.        ]
 [0.48142543 0.69471563 0.4638341  0.46994711 0.6245792  0.44423283
  0.51426202 0.67546509 0.53154175 0.59145543 0.54342744 0.54640668
  0.49213325 0.50726978 0.         0.55896104 0.40720283 0.37864078
  0.42543381 0.57560976 0.76843138 0.42635545 0.47167613 0.        ]
 [0.51947503 0.33832785 0.4638341  0.36787626 0.45957066 0.42923073
  0.48016574 0.56489043 0.39007996 0.59145543 0.5535218  0.48887379
  0.48993449 0.47532521 0.         0.55896104 0.40720283 0.37864078
  0.42543381 0.57560976 0.76843138 0.42635545 0.47167613 0.        ]
 [0.57829543 0.3531464  0.37465128 0.36195488 0.47299219 0.37033832
  0.46124944 0.50865203 0.41080346 0.55627844 0.55526394 0.46965516
  0.4914

  0.13801602 0.60511746 0.70283852 0.50415724 0.16811794 1.        ]]
IM here
row = 2569, row + delay=2579
[[0.58163485 0.47877376 0.41406709 0.41841291 0.50128087 0.50471098
  0.47622855 0.51974669 0.45261058 0.54862458 0.55947272 0.47511536
  0.48993144 0.47149443 0.27895547 0.5787013  0.53107789 0.37864078
  0.44807267 0.59481595 0.59542995 0.1058515  0.3017     0.        ]
 [0.56716142 0.44522123 0.43158041 0.45272727 0.58943902 0.47543869
  0.46695423 0.60138235 0.50945212 0.57561065 0.55502774 0.48813963
  0.48478663 0.47568297 0.27895547 0.5787013  0.53107789 0.37864078
  0.44807267 0.59481595 0.59542995 0.1058515  0.3017     0.        ]
 [0.46044418 0.42624295 0.41480645 0.40751537 0.53861991 0.44594781
  0.49091164 0.69120059 0.43399688 0.53651368 0.51982523 0.57145026
  0.48552039 0.49197769 0.27895547 0.5787013  0.53107789 0.37864078
  0.44807267 0.59481595 0.59542995 0.1058515  0.3017     0.        ]
 [0.4846256  0.394078   0.38835057 0.41373653 0.48394188 0.45629228
  0.48

  0.40846984 0.58387691 0.66369541 0.41158776 0.39957876 0.        ]]
IM here
row = 2289, row + delay=2299
[[0.48495172 0.43696462 0.39401766 0.4190668  0.39715567 0.43132804
  0.4693574  0.55839293 0.46793065 0.55729923 0.64593241 0.52803434
  0.47300884 0.48776424 0.46245193 0.60714286 0.67774689 0.32815534
  0.44628246 0.5815435  0.57977889 0.37452184 0.2763315  0.        ]
 [0.39127804 0.40867368 0.37529615 0.40483056 0.50994447 0.40382639
  0.47692902 0.58906718 0.448547   0.4978781  0.55718645 0.51201698
  0.50870214 0.48556146 0.46245193 0.60714286 0.67774689 0.32815534
  0.44628246 0.5815435  0.57977889 0.37452184 0.2763315  0.        ]
 [0.38274496 0.43829143 0.36081026 0.41621763 0.53369091 0.45093485
  0.49549786 0.5577553  0.46188814 0.52098608 0.55527    0.60392712
  0.49287316 0.53217476 0.46245193 0.60714286 0.67774689 0.32815534
  0.44628246 0.5815435  0.57977889 0.37452184 0.2763315  0.        ]
 [0.46155572 0.41635659 0.27053061 0.41621763 0.49652438 0.48111732
  0.49

  0.40811156 0.62810915 0.61323239 0.25541708 0.07759713 0.        ]]
IM here
row = 1576, row + delay=1586
[[0.53982083 0.43300126 0.4423095  0.43486912 0.4788299  0.43055426
  0.48883847 0.58388401 0.4433444  0.56664057 0.57893716 0.48424488
  0.49103385 0.48066057 0.29035234 0.58811387 0.42027998 0.58899676
  0.40438234 0.59283537 0.60772149 0.24460433 0.19252765 0.        ]
 [0.49373009 0.60690158 0.39171832 0.41584895 0.45065399 0.5528089
  0.48092835 0.60111923 0.37075037 0.52436686 0.59654699 0.55472678
  0.49103385 0.51915298 0.29035234 0.58811387 0.42027998 0.58899676
  0.40438234 0.59283537 0.60772149 0.24460433 0.19252765 0.        ]
 [0.58459701 0.31948991 0.38768104 0.46954906 0.54529163 0.370563
  0.47331799 0.62729766 0.45511755 0.58308766 0.57884383 0.47452937
  0.49103385 0.47715701 0.29035234 0.58811387 0.42027998 0.58899676
  0.40438234 0.59283537 0.60772149 0.24460433 0.19252765 0.        ]
 [0.50803776 0.53199882 0.33958072 0.41955654 0.44997512 0.50000477
  0.47136

  0.39799962 0.57716569 0.58200406 0.54840706 0.21134582 0.        ]]
IM here
row = 1525, row + delay=1535
[[0.70642298 0.43648181 0.38824778 0.41541586 0.48646347 0.34240086
  0.49897167 0.55718908 0.4164598  0.6131084  0.58447047 0.49310929
  0.49140031 0.46639724 0.23739273 0.59511077 0.48580977 0.54019417
  0.40228144 0.57663151 0.91096714 0.25805598 0.1876169  0.        ]
 [0.52881251 0.35628861 0.33973014 0.39527808 0.28707888 0.43063597
  0.48804994 0.56074013 0.40020663 0.54868124 0.55731953 0.5582089
  0.4914003  0.49664879 0.23739273 0.59511077 0.48580977 0.54019417
  0.40228144 0.57663151 0.91096714 0.25805598 0.1876169  0.        ]
 [0.46076447 0.38573649 0.43741491 0.42235743 0.58858043 0.45342549
  0.48117024 0.56724266 0.42295257 0.60889847 0.55838908 0.53085752
  0.49103385 0.48599726 0.23739273 0.59511077 0.48580977 0.54019417
  0.40228144 0.57663151 0.91096714 0.25805598 0.1876169  0.        ]
 [0.5471563  0.42396554 0.38405934 0.40114013 0.37040295 0.35760796
  0.466

  0.42211612 0.58796056 0.58622131 0.25990562 0.30663438 0.        ]]
IM here
row = 2767, row + delay=2777
[[0.63465333 0.37537457 0.39760608 0.4379414  0.50590528 0.41233396
  0.49647302 0.546009   0.46381994 0.50410153 0.53892974 0.49661733
  0.49103385 0.48223695 0.45847524 0.78354978 0.55592728 0.58058252
  0.39944851 0.5927142  0.58737324 0.20092134 0.29015407 0.        ]
 [0.51375986 0.28027148 0.37630927 0.41775902 0.50723241 0.39527514
  0.48567849 0.53210504 0.44352055 0.54678611 0.57930845 0.48701658
  0.50531302 0.4958034  0.45847524 0.78354978 0.55592728 0.58058252
  0.39944851 0.5927142  0.58737324 0.20092134 0.29015407 0.        ]
 [0.65473237 0.33045384 0.37216552 0.41852293 0.53400581 0.35304363
  0.46809003 0.53551205 0.44902761 0.54962133 0.56093616 0.51461647
  0.50164639 0.50513875 0.45847524 0.78354978 0.55592728 0.58058252
  0.39944851 0.5927142  0.58737324 0.20092134 0.29015407 0.        ]
 [0.53042405 0.39029358 0.40934444 0.37822534 0.44899495 0.49342757
  0.46

[[0.58878953 0.3196101  0.47308759 0.49400869 0.50401376 0.46240868
  0.47950883 0.62039381 0.39950537 0.58706687 0.55866684 0.53814282
  0.51428906 0.50546332 0.31193409 0.25602968 0.53107789 0.54692557
  0.39799962 0.57716569 0.58200406 0.54840706 0.21134582 0.        ]
 [0.70251482 0.44955423 0.42475029 0.46982551 0.64529234 0.42729207
  0.50777716 0.60753111 0.52099036 0.58706687 0.55387663 0.54217467
  0.46890403 0.52693468 0.31193409 0.25602968 0.53107789 0.54692557
  0.39799962 0.57716569 0.58200406 0.54840706 0.21134582 0.        ]
 [0.70251482 0.44955423 0.42475029 0.46982551 0.52884833 0.42729207
  0.50777716 0.60753111 0.52099036 0.58706687 0.55387663 0.54217467
  0.46890403 0.52693468 0.39422228 0.61606    0.5144559  0.45076283
  0.42812707 0.57181572 0.48468108 0.50316539 0.27881608 0.        ]
 [0.5107671  0.47756853 0.38997917 0.37390349 0.55079722 0.433228
  0.49888001 0.53535688 0.46755299 0.50770394 0.56781549 0.49884142
  0.49324854 0.48463906 0.39422228 0.61606    0

IM here
row = 2258, row + delay=2268
[[0.52063516 0.39408471 0.43628483 0.37693937 0.45756863 0.47202062
  0.48645614 0.54590512 0.40278144 0.51837616 0.54558678 0.52338745
  0.49359973 0.49884665 0.22967019 0.58029775 0.52913667 0.49083064
  0.403807   0.58963415 0.59345287 0.46105814 0.3766238  0.        ]
 [0.51669755 0.49157908 0.41651865 0.41224169 0.49362134 0.48687098
  0.46848407 0.61818628 0.44993103 0.59497834 0.51271172 0.47716623
  0.46684338 0.46873597 0.22967019 0.58029775 0.52913667 0.49083064
  0.403807   0.58963415 0.59345287 0.46105814 0.3766238  0.        ]
 [0.50371796 0.44822137 0.43574034 0.43746327 0.59867674 0.46085689
  0.48389343 0.66065882 0.53363506 0.57660833 0.60179026 0.54303522
  0.47636054 0.50119038 0.22967019 0.58029775 0.52913667 0.49083064
  0.403807   0.58963415 0.59345287 0.46105814 0.3766238  0.        ]
 [0.48207935 0.40876329 0.40003844 0.42539268 0.47764921 0.43094123
  0.47024737 0.56221842 0.43222222 0.52009902 0.60164054 0.54109908
  0.4954

  0.5525006  0.5836495  0.54207225 0.22596748 0.34137152 0.        ]]
IM here
row = 63, row + delay=73
[[0.71278038 0.37504793 0.4432018  0.41619088 0.50542885 0.41765656
  0.49162165 0.53991504 0.45181664 0.52172673 0.53044473 0.48090776
  0.49842629 0.49529545 0.36063461 0.54161747 0.67669938 0.30651872
  0.41544049 0.59139168 0.60887341 1.         0.35004944 0.        ]
 [0.68367462 0.4328227  0.44605995 0.43437729 0.47172651 0.4335844
  0.48323609 0.60401824 0.50010428 0.5157918  0.55247335 0.52043872
  0.48696906 0.47844957 0.36063461 0.54161747 0.67669938 0.30651872
  0.41544049 0.59139168 0.60887341 1.         0.35004944 0.        ]
 [0.50475962 0.30305715 0.38599302 0.33261277 0.37630557 0.40712217
  0.419573   0.36575714 0.32170396 0.5157918  0.56927223 0.52569847
  0.47920732 0.47759124 0.36063461 0.54161747 0.67669938 0.30651872
  0.41544049 0.59139168 0.60887341 1.         0.35004944 0.        ]
 [0.64639505 0.47019563 0.35866577 0.43097874 0.46786661 0.50490854
  0.5037709

  0.40954805 0.59589041 0.80134582 0.45959769 0.26687497 0.        ]]
IM here
row = 2542, row + delay=2552
[[0.46861258 0.38934853 0.45671756 0.4247564  0.49936417 0.40552879
  0.49466124 0.59231393 0.44279198 0.5371117  0.52920624 0.46604878
  0.47043384 0.47671527 0.28983556 0.58441558 0.51510659 0.42672215
  0.40799916 0.62332661 0.59850929 0.00826594 0.3476775  0.        ]
 [0.51348322 0.42676582 0.38280982 0.44813928 0.52208393 0.4454014
  0.47941524 0.57593153 0.50899366 0.54404823 0.5452412  0.55310484
  0.50759936 0.509659   0.28983556 0.58441558 0.51510659 0.42672215
  0.40799916 0.62332661 0.59850929 0.00826594 0.3476775  0.        ]
 [0.53486361 0.32858975 0.41786028 0.40018556 0.44932762 0.45775433
  0.48048064 0.52453009 0.42644118 0.56050925 0.59848282 0.47400325
  0.5009674  0.48329433 0.28983556 0.58441558 0.51510659 0.42672215
  0.40799916 0.62332661 0.59850929 0.00826594 0.3476775  0.        ]
 [0.48454374 0.42534316 0.3940547  0.42722555 0.49198463 0.38554788
  0.489

  0.42712276 0.56917764 0.63141369 0.17762139 0.20946791 0.        ]]
IM here
row = 796, row + delay=806
[[0.54819194 0.42313156 0.43036225 0.47918053 0.36461271 0.43582598
  0.3826684  0.52164731 0.43838619 0.51828128 0.55910229 0.50165532
  0.49103385 0.49693292 0.33415922 0.58954998 0.58654977 0.20303926
  0.42471189 0.5909062  0.68678229 0.18151748 0.23621363 0.        ]
 [0.53442935 0.38817758 0.43395455 0.43002803 0.47615552 0.41683094
  0.52199131 0.64150969 0.47796971 0.58097661 0.53652513 0.50165532
  0.48808887 0.4908954  0.33415922 0.58954998 0.58654977 0.20303926
  0.42471189 0.5909062  0.68678229 0.18151748 0.23621363 0.        ]
 [0.57068933 0.4166399  0.43395455 0.4479849  0.43077579 0.37645184
  0.39476232 0.50677123 0.38436781 0.51463358 0.5526935  0.47573871
  0.49950154 0.45904525 0.33415922 0.58954998 0.58654977 0.20303926
  0.42471189 0.5909062  0.68678229 0.18151748 0.23621363 0.        ]
 [0.51490646 0.27417433 0.37764948 0.34506983 0.33934371 0.46941323
  0.4389

  0.42740461 0.57415421 0.76342517 0.25466431 0.34421708 0.        ]]
IM here
row = 3000, row + delay=3010
[[0.47356471 0.40819857 0.42972436 0.41885611 0.61228392 0.48961519
  0.49227171 0.59589701 0.48323227 0.58176687 0.48092503 0.45471513
  0.49543245 0.49302251 0.46694189 0.61879297 0.69980946 0.28684907
  0.40811156 0.62810915 0.61323239 0.25541708 0.07759713 0.        ]
 [0.48317373 0.45468305 0.41327004 0.43391821 0.53037562 0.44433717
  0.49240698 0.615035   0.43747579 0.6165693  0.50008898 0.42023412
  0.4881019  0.45828933 0.46694189 0.61879297 0.69980946 0.28684907
  0.40811156 0.62810915 0.61323239 0.25541708 0.07759713 0.        ]
 [0.47420306 0.46461097 0.40840302 0.43603814 0.57686252 0.44262793
  0.45828755 0.60053665 0.46010273 0.56744391 0.48873078 0.53275418
  0.49763142 0.50365361 0.46694189 0.61879297 0.69980946 0.28684907
  0.40811156 0.62810915 0.61323239 0.25541708 0.07759713 0.        ]
 [0.27460228 0.31831889 0.43306578 0.35283724 0.3162314  0.39328706
  0.45

[[0.54760948 0.41669688 0.40279104 0.41739398 0.56368981 0.42449995
  0.50439322 0.58523062 0.4795826  0.64055033 0.55645611 0.48255686
  0.49509481 0.49253181 0.35000294 1.         0.64926675 0.15064203
  0.38785341 0.59002354 0.59376082 0.42410658 0.25891914 0.        ]
 [0.57627416 0.34529945 0.40072576 0.39035972 0.50525824 0.39594173
  0.48001713 0.52827685 0.42811013 0.55416872 0.56880707 0.4942558
  0.4932486  0.48307482 0.35000294 1.         0.64926675 0.15064203
  0.38785341 0.59002354 0.59376082 0.42410658 0.25891914 0.        ]
 [0.46149054 0.38354996 0.40166629 0.37389886 0.50944701 0.3496097
  0.48118282 0.56216877 0.38541662 0.53395597 0.55425378 0.53522005
  0.48069913 0.49897465 0.35000294 1.         0.64926675 0.15064203
  0.38785341 0.59002354 0.59376082 0.42410658 0.25891914 0.        ]
 [0.53995437 0.45992984 0.38411106 0.37825244 0.50432758 0.46045039
  0.46803231 0.48429866 0.45156281 0.39790591 0.55425185 0.4785458
  0.50801845 0.47288359 0.35000294 1.         0.

[[0.59117679 0.41358721 0.44877618 0.4289903  0.6342131  0.39074476
  0.46865281 0.52278478 0.43041211 0.62389379 0.44047049 0.4299681
  0.49803079 0.45004392 0.40769224 0.60350493 0.56109432 0.34979196
  0.40030247 0.568784   0.90076651 0.30116012 0.38440598 0.        ]
 [0.59117679 0.42417309 0.37696196 0.34749021 0.4995317  0.37225376
  0.4717967  0.4996043  0.41928303 0.46525666 0.530059   0.61055946
  0.49582006 0.54107283 0.40769224 0.60350493 0.56109432 0.34979196
  0.40030247 0.568784   0.90076651 0.30116012 0.38440598 0.        ]
 [0.59117679 0.34011962 0.33810248 0.39885327 0.41176996 0.4624726
  0.47361765 0.53445606 0.41996042 0.51075103 0.55728273 0.55893073
  0.52232306 0.49766319 0.40769224 0.60350493 0.56109432 0.34979196
  0.40030247 0.568784   0.90076651 0.30116012 0.38440598 0.        ]
 [0.38302203 0.36655764 0.39614477 0.41095635 0.5716829  0.42552563
  0.48613248 0.58202478 0.4478408  0.47356852 0.55000312 0.49785107
  0.47044232 0.47369715 0.40769224 0.60350493 0

[[0.52850631 0.483006   0.40127615 0.56129915 0.54221236 0.34094219
  0.51040565 0.72102594 0.41047363 0.54504887 0.46593855 0.57745488
  0.48365791 0.51158697 0.42100143 0.60195279 0.82456675 0.47043248
  0.42740461 0.57415421 0.76342517 0.25466431 0.34421708 0.        ]
 [0.51915414 0.40334995 0.3872736  0.43888282 0.47552547 0.49914489
  0.48809325 0.55975808 0.40219068 0.54504887 0.62771046 0.4736366
  0.49435388 0.50118254 0.42100143 0.60195279 0.82456675 0.47043248
  0.42740461 0.57415421 0.76342517 0.25466431 0.34421708 0.        ]
 [0.49752878 0.47269153 0.48120407 0.40424809 0.6523376  0.51364782
  0.49189189 0.6266225  0.45182122 0.54504887 0.551921   0.57988748
  0.49250925 0.50056885 0.42100143 0.60195279 0.82456675 0.47043248
  0.42740461 0.57415421 0.76342517 0.25466431 0.34421708 0.        ]
 [0.44111371 0.49046738 0.4591544  0.42485114 0.19723663 0.46237724
  0.49018406 0.57188573 0.51439098 0.58399502 0.55550952 0.49798134
  0.51758963 0.5235568  0.42100143 0.60195279 

IM here
row = 1067, row + delay=1077
[[0.53398214 0.47186732 0.40880641 0.47443811 0.41888337 0.49261788
  0.47910273 0.62428912 0.5300563  0.49632157 0.55910229 0.47313923
  0.49066735 0.49093069 0.36401272 0.59330144 0.58319581 0.37864078
  0.43101651 0.57407312 0.49696132 0.34690338 0.27263903 0.        ]
 [0.53398214 0.47186732 0.40880641 0.47443811 0.37844688 0.49261788
  0.47910273 0.62428912 0.5300563  0.49632157 0.55910229 0.47313923
  0.49066735 0.49093069 0.36401272 0.59330144 0.58319581 0.37864078
  0.43101651 0.57407312 0.49696132 0.34690338 0.27263903 0.        ]
 [0.52742304 0.40697779 0.38503791 0.33184848 0.57662158 0.44968328
  0.5121163  0.56130906 0.40166825 0.69126203 0.554026   0.56330298
  0.48956785 0.50897624 0.36401272 0.59330144 0.58319581 0.37864078
  0.43101651 0.57407312 0.49696132 0.34690338 0.27263903 0.        ]
 [0.52916543 0.47059116 0.45741642 0.41591695 0.73686719 0.40578652
  0.46370122 0.58251373 0.49936995 0.51588383 0.558321   0.56083901
  0.4950

IM here
row = 3953, row + delay=3963
[[0.46444245 0.45327839 0.39855084 0.42782029 0.49430337 0.46721759
  0.4863332  0.55348471 0.47464373 0.57483359 0.5549893  0.57523913
  0.49840066 0.50386868 0.37080186 0.59902597 0.48862975 0.37864078
  0.3989661  0.52217295 0.6024331  0.22160695 0.21688797 0.        ]
 [0.56737753 0.41709932 0.48733064 0.41610018 0.43646679 0.44370168
  0.4697824  0.56375726 0.44884815 0.61042881 0.57555983 0.4699923
  0.49250683 0.47487842 0.37080186 0.59902597 0.48862975 0.37864078
  0.3989661  0.52217295 0.6024331  0.22160695 0.21688797 0.        ]
 [0.51282237 0.40209331 0.39899466 0.39203404 0.48744723 0.45754817
  0.48424996 0.59339988 0.45101363 0.55232578 0.55088467 0.47032851
  0.49250676 0.46436767 0.37080186 0.59902597 0.48862975 0.37864078
  0.3989661  0.52217295 0.6024331  0.22160695 0.21688797 0.        ]
 [0.43140321 0.33633661 0.39511116 0.38817165 0.43167095 0.42371668
  0.48683887 0.54260517 0.4249041  0.57473747 0.5549907  0.51733754
  0.49177

IM here
row = 3251, row + delay=3261
[[0.51575969 0.53054484 0.3687106  0.46376023 0.59951255 0.58568465
  0.47624739 0.49858201 0.35073153 0.59644643 0.5886317  0.38143198
  0.51797306 0.47542813 0.37820644 0.59847738 0.83933454 0.79935275
  0.40511798 0.66789563 0.69069665 0.25796291 0.33039542 0.        ]
 [0.54695399 0.52001992 0.4557352  0.51536061 0.58574578 0.42984657
  0.50654511 0.7141216  0.55523459 0.6645108  0.51491576 0.44952972
  0.48660965 0.4710564  0.37820644 0.59847738 0.83933454 0.79935275
  0.40511798 0.66789563 0.69069665 0.25796291 0.33039542 0.        ]
 [0.45362131 0.49056606 0.44640088 0.40928829 0.55141132 0.41337985
  0.52304042 0.58235311 0.46920739 0.58556663 0.54801527 0.51148498
  0.49177133 0.45527534 0.37820644 0.59847738 0.83933454 0.79935275
  0.40511798 0.66789563 0.69069665 0.25796291 0.33039542 0.        ]
 [0.42355594 0.2643986  0.41134476 0.37262878 0.51452442 0.35698357
  0.4659074  0.55961297 0.41261332 0.46868579 0.48142213 0.49149374
  0.4939

  0.40041057 0.46573751 0.62123989 0.1401702  0.27791095 0.        ]]
IM here
row = 2178, row + delay=2188
[[0.49920683 0.34891455 0.36542221 0.43400852 0.37721621 0.45758481
  0.48761603 0.56384276 0.41695315 0.56682652 0.42859941 0.48551536
  0.47563478 0.45438073 0.32462313 0.59369202 0.570558   0.37864078
  0.45413682 0.53736854 0.69248498 0.24228872 0.30213321 0.        ]
 [0.49892393 0.41497253 0.49332152 0.42305369 0.46577902 0.39640839
  0.48296065 0.56860337 0.41379288 0.54608164 0.5705738  0.52255951
  0.49250122 0.50037506 0.32462313 0.59369202 0.570558   0.37864078
  0.45413682 0.53736854 0.69248498 0.24228872 0.30213321 0.        ]
 [0.45833451 0.43209453 0.43093051 0.40345848 0.54292791 0.46739882
  0.47064566 0.55669953 0.47394526 0.59670336 0.61831572 0.56613699
  0.49543573 0.51750532 0.32462313 0.59369202 0.570558   0.37864078
  0.45413682 0.53736854 0.69248498 0.24228872 0.30213321 0.        ]
 [0.38372788 0.42610717 0.4705423  0.45015337 0.56817184 0.48520239
  0.48

  0.41615005 0.57966243 1.         0.25021501 0.21990221 0.        ]]
IM here
row = 136, row + delay=146
[[0.54760948 0.41669688 0.40279104 0.41739398 0.56368981 0.42449995
  0.50439322 0.58523062 0.4795826  0.64055033 0.55645611 0.48255686
  0.49509481 0.49253181 0.35000294 1.         0.64926675 0.15064203
  0.38785341 0.59002354 0.59376082 0.42410658 0.25891914 0.        ]
 [0.57627416 0.34529945 0.40072576 0.39035972 0.50525824 0.39594173
  0.48001713 0.52827685 0.42811013 0.55416872 0.56880707 0.4942558
  0.4932486  0.48307482 0.35000294 1.         0.64926675 0.15064203
  0.38785341 0.59002354 0.59376082 0.42410658 0.25891914 0.        ]
 [0.46149054 0.38354996 0.40166629 0.37389886 0.50944701 0.3496097
  0.48118282 0.56216877 0.38541662 0.53395597 0.55425378 0.53522005
  0.48069913 0.49897465 0.35000294 1.         0.64926675 0.15064203
  0.38785341 0.59002354 0.59376082 0.42410658 0.25891914 0.        ]
 [0.53995437 0.45992984 0.38411106 0.37825244 0.50432758 0.46045039
  0.468032

KeyboardInterrupt: 

In [6]:
def print_metric(y_true, Y_pred):
    print('F1 score: %f' % f1_score(y_true, Y_pred))
    print('precision score: %f' % precision_score(y_true, Y_pred))
    print('recall score: %f' % recall_score(y_true, Y_pred))
    print('accuracy score: %f' % accuracy_score(y_true, Y_pred))
    print('matthews_corrcoef: %f' % matthews_corrcoef(y_true, Y_pred))
    print('\nConfusion matrix:')
    print(confusion_matrix(y_true, Y_pred, labels=[0,1]))

In [7]:
def probs_to_binary_classes(preds, threshold=0.5):
    """
    preds: np array
    threshold: scalar
    """
    return np.where(preds > threshold, 1, 0)

In [8]:
def print_unique_counts(x):
    unique, counts = np.unique(x, return_counts=True)
    print(np.asarray((unique, counts)).T)


In [9]:
#build s FC model from the book

input_shape = (lookback//step, dataset.shape[-1] - 1)

model = Sequential()
model.add(Flatten(input_shape=input_shape))
model.add(Dense(32, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
print(model.summary())

# #fit the model
model.fit_generator(train_rand_gen,
                    steps_per_epoch=train_steps,
                    epochs=50, verbose=1)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_1 (Flatten)          (None, 250)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 32)                8032      
_________________________________________________________________
dense_2 (Dense)              (None, 1)                 33        
Total params: 8,065
Trainable params: 8,065
Non-trainable params: 0
_________________________________________________________________
None
Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50


<keras.callbacks.History at 0x7f97d53a29b0>

In [11]:
def generator_to_samples_and_targets(generator, steps):
    count=0
    X, Y = [], []
    for samples, targets in generator:
        if count >= steps:
            break;
        else:
            count+=1

        X.append(samples)
        Y.append(targets)

    return np.concatenate(X, axis=0), np.concatenate(Y, axis=0)


In [11]:
neg = 2762+158
pos = 177+807
print("We have {0} neg cases and {1} pos cases from train data".format(neg,pos))
print("the common sense baseline (accuracy score) is {0}".format(neg/(pos+neg)))

We have 2920 neg cases and 984 pos cases from train data
the common sense baseline (accuracy score) is 0.7479508196721312


In [12]:
#Make predictions for train set
    
X, Y = generator_to_samples_and_targets(train_gen, train_steps)    
Y_pred = model.predict(X)
print_metric(Y, probs_to_binary_classes(Y_pred))

F1 score: 0.500673
precision score: 0.741036
recall score: 0.378049
accuracy score: 0.809939
matthews_corrcoef: 0.432616

Confusion matrix:
[[2790  130]
 [ 612  372]]


In [13]:
#Make predictions from dev set
X, Y = generator_to_samples_and_targets(val_gen, val_steps)    
Y_pred = model.predict(X)
print_metric(Y, probs_to_binary_classes(Y_pred))


F1 score: 0.000000
precision score: 0.000000
recall score: 0.000000
accuracy score: 0.985577
matthews_corrcoef: -0.004015

Confusion matrix:
[[820   1]
 [ 11   0]]


In [14]:
#Make predictions for test set
X, Y = generator_to_samples_and_targets(test_gen, test_steps)    
Y_pred = model.predict(X)
print_metric(Y, probs_to_binary_classes(Y_pred))


F1 score: 0.000000
precision score: 0.000000
recall score: 0.000000
accuracy score: 0.963942
matthews_corrcoef: 0.000000

Confusion matrix:
[[802   0]
 [ 30   0]]


  'precision', 'predicted', average, warn_for)
  'precision', 'predicted', average, warn_for)
  mcc = cov_ytyp / np.sqrt(cov_ytyt * cov_ypyp)
