In [63]:
import os
import pickle

import numpy as np
import torch
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset, Dataset

In [64]:
class WindowedSequenceDataset(Dataset):
    """Fixed-length sliding windows over a collection of traces and their labels.

    Every trace is cut into windows of ``window_size`` samples whose start
    indices are ``0, stride, 2 * stride, ...``; a trailing remainder shorter
    than ``window_size`` is dropped. The matching slice of the trace's
    annotation sequence is kept as that window's label sequence.

    Args:
        data: iterable of sliceable traces (e.g. 1-D numpy arrays).
        labels: iterable of per-sample annotations, parallel to ``data``;
            presumably each has the same length as its trace -- TODO confirm.
        window_size: number of samples per window (must be > 0).
        stride: step between consecutive window starts (must be > 0);
            ``stride < window_size`` produces overlapping windows.

    Attributes:
        data: ``np.ndarray`` of stacked windows, ``(n_windows, window_size)``
            for 1-D traces (``(0, window_size)`` when no window fits).
        labels: ``np.ndarray`` of stacked label windows with the same leading
            shape as ``data``.

    Raises:
        ValueError: if ``window_size`` or ``stride`` is not positive.
    """

    def __init__(self, data, labels, window_size, stride):
        # A zero stride would make range() raise an opaque error and a
        # negative one would silently emit nonsense windows; fail clearly.
        if window_size <= 0 or stride <= 0:
            raise ValueError(
                f"window_size and stride must be positive, got "
                f"window_size={window_size}, stride={stride}"
            )
        windows = []
        window_labels = []
        for trace, event_annotations in zip(data, labels):
            # Last valid start is len(trace) - window_size, so partial
            # windows at the end of a trace are skipped.
            for start in range(0, len(trace) - window_size + 1, stride):
                windows.append(trace[start:start + window_size])
                window_labels.append(event_annotations[start:start + window_size])
        if windows:
            self.data = np.array(windows)
            self.labels = np.array(window_labels)
        else:
            # np.array([]) would be 1-D with shape (0,); keep the 2-D layout so
            # downstream code such as labels.sum(axis=1) still works.
            self.data = np.empty((0, window_size), dtype=np.float32)
            self.labels = np.empty((0, window_size), dtype=np.int64)

    def __len__(self):
        """Number of windows extracted across all traces."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(window, target)`` for window ``idx``.

        ``window`` is a float32 tensor with a trailing channel axis of size 1
        (``(window_size, 1)`` for 1-D traces); ``target`` is an int64 tensor
        of the window's per-sample labels.
        """
        window = torch.tensor(self.data[idx], dtype=torch.float32).unsqueeze(-1)
        target = torch.tensor(self.labels[idx], dtype=torch.long)
        return window, target


In [65]:
def load_data(fdata):
    """Flatten a nested ``{experiment: {scan: record}}`` mapping into two parallel lists.

    Experiments are visited in ``fdata.keys()`` order and scans in iteration
    order of each experiment. For every scan, ``record['SOMA']['DENOISED_TRACE']``
    goes into the first list and ``record['SOMA']['SPIKE_INTERVAL']`` into the
    second, so index ``k`` of both lists always refers to the same scan.

    Returns:
        tuple ``(traces, annotations)`` of lists.
    """
    traces, annotations = [], []
    for experiment in fdata.keys():
        scans = fdata[experiment]
        for scan_name in scans:
            soma = scans[scan_name]['SOMA']
            traces.append(soma['DENOISED_TRACE'])
            annotations.append(soma['SPIKE_INTERVAL'])
    return traces, annotations

# SECURITY: pickle.load can execute arbitrary code -- only open files you trust.
# The raw nested dict is handed straight to load_data and never kept around,
# so only the flattened `data` / `labels` lists stay in kernel memory.
pickle_path = '/scratch/da3245/datasets/neuro_scans/cnn_training_set.pkl'
with open(pickle_path, 'rb') as f:
    data, labels = load_data(pickle.load(f))

In [66]:
# Chronological split of every trace *before* windowing: the first ~70% of
# samples go to training, the next ~15% to validation and the last ~15% to
# testing, so the three sets never share samples.
split_percent=0.15
X_train, y_train = [], []
X_val, y_val = [], []
X_test, y_test = [], []
for i, trace in enumerate(data):
    events = labels[i]
    len_trace = len(trace)
    test_len = split_percent * len_trace
    train_end = int(len_trace - 2 * test_len)
    val_end = int(len_trace - test_len)
    # (destination lists, slice start, slice stop) for each split.
    for X_split, y_split, lo, hi in (
        (X_train, y_train, 0, train_end),
        (X_val, y_val, train_end, val_end),
        (X_test, y_test, val_end, None),
    ):
        X_split.append(trace[lo:hi])
        y_split.append(events[lo:hi])

    

In [67]:

# 500-sample windows whose starts advance by 450 samples, so consecutive
# windows from the same trace share 50 samples.
window_size, stride = 500, 450
windowed_train_dataset = WindowedSequenceDataset(
    X_train, y_train, window_size=window_size, stride=stride
)

In [68]:
np.shape(windowed_train_dataset.data)  # (n_windows, window_size) for 1-D traces

(25418, 500)

In [69]:
# 1 for each window whose label sum is positive (i.e. it contains an event), else 0.
stratify_labels = (windowed_train_dataset.labels.sum(axis=1) > 0).astype(int)
print(stratify_labels.sum())
print('Imbalance Perecentage Training', stratify_labels.sum() / len(windowed_train_dataset.labels))

10362
Imbalance Perecentage Training 0.40766386025651113


In [70]:

windowed_val_dataset = WindowedSequenceDataset(X_val, y_val, window_size=window_size, stride=stride)
print(np.shape(windowed_val_dataset.data))
# 1 for each window whose label sum is positive (i.e. it contains an event), else 0.
stratify_labels = (windowed_val_dataset.labels.sum(axis=1) > 0).astype(int)
print(stratify_labels.sum())
print('Imbalance Perecentage valing', stratify_labels.sum() / len(windowed_val_dataset.labels))

(5373, 500)
1534
Imbalance Perecentage valing 0.28550158198399406


In [71]:
windowed_test_dataset = WindowedSequenceDataset(X_test, y_test, window_size=window_size, stride=stride)
print(np.shape(windowed_test_dataset.data))
# 1 for each window whose label sum is positive (i.e. it contains an event), else 0.
stratify_labels = (windowed_test_dataset.labels.sum(axis=1) > 0).astype(int)
print(stratify_labels.sum())
print('Imbalance Perecentage testing', stratify_labels.sum() / len(windowed_test_dataset.labels))

(5373, 500)
1483
Imbalance Perecentage testing 0.27600967801972826


In [72]:
# Persist the chronologically split, windowed arrays. The path is relative to
# the kernel's working directory; create the parent directory if it is missing
# so the dump does not fail with FileNotFoundError on a fresh checkout.
pre_split_path = 'datasets/timeseries_voltage/pre_split_overlap50.pkl'
os.makedirs(os.path.dirname(pre_split_path), exist_ok=True)
with open(pre_split_path, 'wb') as f:
    pickle.dump({'X_train': windowed_train_dataset.data, 'y_train': windowed_train_dataset.labels,
                 'X_val': windowed_val_dataset.data, 'y_val': windowed_val_dataset.labels,
                 'X_test': windowed_test_dataset.data, 'y_test': windowed_test_dataset.labels},
                f)


In [22]:

# Non-overlapping 200-sample windows over the *full* traces; the train/val/test
# split happens afterwards at the window level.
window_size = stride = 200
windowed_dataset = WindowedSequenceDataset(data, labels, window_size=window_size, stride=stride)

In [23]:
np.shape(windowed_dataset.labels)  # (n_windows, window_size) for 1-D annotations

(81984, 200)

In [24]:
# 1 for each window whose label sum is positive (i.e. it contains an event), else 0;
# reused below as the stratification target.
stratify_labels = (windowed_dataset.labels.sum(axis=1) > 0).astype(int)
print(stratify_labels.sum())
print('Imbalance Perecentage', stratify_labels.sum() / len(windowed_dataset.labels))

14977
Imbalance Perecentage 0.18268198672911787


In [25]:
# Two-stage stratified split on "window contains an event" so every split keeps
# the same positive-window fraction.
# NOTE(review): the second call takes 15% of the remaining 85%, i.e. ~12.75% of
# all windows go to validation (not 15% as in the chronological split above).
# Windows are also shuffled across time, so neighbouring windows of the same
# recording can land in different splits -- possible temporal leakage.
X_train, X_test, y_train, y_test = train_test_split(
    windowed_dataset.data,
    windowed_dataset.labels,
    random_state=42,
    stratify=stratify_labels,
    test_size=0.15,
)
new_strat = (y_train.sum(axis=1) > 0).astype(int)
X_train, X_val, y_train, y_val = train_test_split(
    X_train,
    y_train,
    random_state=42,
    stratify=new_strat,
    test_size=0.15,
)

In [26]:
print(*(split.shape for split in (X_train, X_test, X_val)))

(59233, 200) (12298, 200) (10453, 200)


In [27]:
(y_train.sum(axis=1) > 0).mean()  # fraction of training windows containing an event

np.float64(0.18266844495467055)

In [28]:
(y_val.sum(axis=1) > 0).mean()  # fraction of validation windows containing an event

np.float64(0.18272266335023438)

In [29]:
(y_test.sum(axis=1) > 0).mean()  # fraction of test windows containing an event

np.float64(0.1827126362010083)

In [30]:
# Persist the stratified, non-overlapping windows. The path is relative to the
# kernel's working directory; create the parent directory if it is missing so
# the dump does not fail with FileNotFoundError on a fresh checkout.
stratified_path = 'datasets/timeseries_voltage/stratified_no_overlap.pkl'
os.makedirs(os.path.dirname(stratified_path), exist_ok=True)
with open(stratified_path, 'wb') as f:
    pickle.dump({'X_train': X_train, 'y_train': y_train,
                 'X_val': X_val, 'y_val': y_val,
                 'X_test': X_test, 'y_test': y_test}, f)
