Import libraries

In [16]:
import torch
import cv2
%reload_ext autoreload
%autoreload 2
# this is the main library used (sits on top of PyTorch)
from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *
from sklearn import metrics

Set up data: PATH should point to the data folder, which must contain the subfolders train, valid, and test. The train and valid folders must each contain the subfolders "yes" and "no". These "yes" and "no" folders must not be empty for the notebook to run, but their contents do not affect the predictions, so you can simply copy in a few images from your test data.

In [17]:
# Folder holding the data (with train/, valid/ and test/ subfolders).
# Raw string: "\e" is not a valid escape sequence, so the plain literal only
# worked by accident and triggers a Deprecation/SyntaxWarning on newer Pythons.
# The resulting path is unchanged.
PATH = r"data\e_UMHS-0018-0002.day-04_BD3-BD4"
# using resnet architecture
arch = resnet34
# size of square image in pixels
sz = 44
# transforms used on training data
transforms_up_down = [RandomScale(sz, 1.2), RandomRotate(1)]
tfms = tfms_from_model(arch, sz, crop_type=CropType.NO, aug_tfms=transforms_up_down)
# data: comes from PATH, uses tfms on training data, bs of 8 for training data,
# test data located in the test folder
data = ImageClassifierData.from_paths(PATH, tfms=tfms, bs=8, test_name='test')
# load in the pretrained weights
# NOTE(review): torch.load unpickles the file, so only load saved_model.pkl
# from a source you trust.
state = torch.load('saved_model.pkl', map_location=torch.device('cpu'))  # remove map_location parameter if on GPU
learn2 = ConvLearner.pretrained(arch, data, precompute=False)
learn2.model.load_state_dict(state)

<All keys matched successfully>

In [18]:
# Run the loaded model over the test set; the exp() below treats the
# outputs as log-probabilities (one row per image, one column per class).
log_preds_test = learn2.predict(is_test=True)
# Predicted class = index of the largest entry in each row
preds_test = np.argmax(log_preds_test, axis=1)
# Probability assigned to class 1 for every test image
class1_log_probs = log_preds_test[:, 1]
probs_test = np.exp(class1_log_probs)

In [19]:
# Collect the test file names into an array (same contents as the former
# element-by-element copy loop), then assemble one row per test image:
# file name, predicted class, and class-1 probability.
test_names = np.array(data.test_ds.fnames)
test_df = pd.DataFrame(data=test_names, columns=['image_number'])
test_df['prediction'] = preds_test
test_df['probability'] = probs_test

In [20]:
# Render the whole prediction table as text and print it
_test_df_text = test_df.to_string()
print(_test_df_text)

             image_number  prediction   probability
0          test\img_1.jpg           0  6.568760e-06
1         test\img_10.jpg           0  2.097026e-04
2        test\img_100.jpg           0  2.503952e-04
3       test\img_1000.jpg           0  1.407602e-03
4      test\img_10000.jpg           0  8.248540e-04
5      test\img_10001.jpg           0  4.378020e-06
6      test\img_10002.jpg           0  2.191625e-04
7      test\img_10003.jpg           0  2.184888e-04
8      test\img_10004.jpg           0  3.545652e-04
9      test\img_10005.jpg           0  7.569134e-05
10     test\img_10006.jpg           0  1.704193e-05
11     test\img_10007.jpg           0  1.015291e-03
12     test\img_10008.jpg           0  1.188789e-04
13     test\img_10009.jpg           1  9.475780e-01
14      test\img_1001.jpg           0  9.269609e-04
15     test\img_10010.jpg           0  2.077658e-04
16     test\img_10011.jpg           0  1.681919e-04
17     test\img_10012.jpg           0  6.889042e-05
18     test\

In [21]:
# Print only the rows with a prediction of 1.
# The flagged rows are found with a single vectorized comparison instead of
# a Python-level lookup per row; detected_index holds their index labels
# (in ascending order) and is reused by the next cell.
detected_index = test_df.index[test_df.prediction == 1].tolist()
num_detected = len(detected_index)

# Print each detection as its own one-row table (same output as before)
for i in detected_index:
    print(test_df.loc[[i]])

print('\nThe number of detected spike-ripples: ',num_detected)

          image_number  prediction  probability
13  test\img_10009.jpg           1     0.947578
           image_number  prediction  probability
100  test\img_10088.jpg           1     0.620529
           image_number  prediction  probability
241  test\img_10214.jpg           1     0.832503
           image_number  prediction  probability
242  test\img_10215.jpg           1     0.531013
           image_number  prediction  probability
310  test\img_10277.jpg           1     0.845432
           image_number  prediction  probability
971  test\img_10872.jpg           1     0.536029
            image_number  prediction  probability
1040  test\img_10934.jpg           1     0.717062
            image_number  prediction  probability
1070  test\img_10961.jpg           1     0.979805
            image_number  prediction  probability
1071  test\img_10962.jpg           1     0.919648
            image_number  prediction  probability
1219  test\img_11095.jpg           1     0.937083
            im

In [22]:
# Build the table of detections without touching test_df:
#   1. drop every row whose index is not in detected_index
#      (drop returns a new frame, so test_df itself is left unchanged)
#   2. sort the remaining rows by increasing probability
not_detected = test_df.index.difference(detected_index)
spike_ripple_df = (
    test_df
    .drop(index=not_detected)
    .sort_values(by=['probability'])
)

print(spike_ripple_df)

             image_number  prediction  probability
2895   test\img_12603.jpg           1     0.505601
8822   test\img_17939.jpg           1     0.511505
8742   test\img_17867.jpg           1     0.514990
16160    test\img_825.jpg           1     0.527075
10862   test\img_3481.jpg           1     0.530605
...                   ...         ...          ...
1978   test\img_11779.jpg           1     0.991343
10022   test\img_2725.jpg           1     0.991491
8682   test\img_17812.jpg           1     0.992501
16184   test\img_8271.jpg           1     0.993342
2070   test\img_11861.jpg           1     0.998512

[114 rows x 3 columns]


In [23]:
# Export the detected spike ripples to a CSV file with the given name.
# Raw string: "\e" is not a valid escape sequence, so the plain literal only
# worked by accident and warns on newer Pythons; the path itself is unchanged.
# NOTE: to_csv does not create the Detected_spike_ripples folder, so it must
# already exist.
spike_ripple_csv_path = r"Detected_spike_ripples\e_UMHS-0018-0002.day-04_BD3-BD4.csv"
# to_csv returns None when given a path, so spike_ripple_csv is always None;
# the name is kept only in case a later cell refers to it.
spike_ripple_csv = spike_ripple_df.to_csv(spike_ripple_csv_path)