In [1]:
import os
import numpy as np

In [2]:
from scipy.io import loadmat
import six
def mat_to_data(path):
    """Load the ``dataStruct`` MATLAB struct from a ``.mat`` file into a dict.

    Each struct field becomes one dict entry. Fields stored as 1x1 arrays
    (MATLAB scalars) are unwrapped to the bare scalar; all other fields are
    returned as the ndarray produced by ``scipy.io.loadmat``.

    Parameters
    ----------
    path : str
        Path to a ``.mat`` file containing a top-level ``dataStruct`` variable.

    Returns
    -------
    dict
        Mapping of field name -> ndarray (or scalar for 1x1 fields).

    Raises
    ------
    KeyError
        If the file has no ``dataStruct`` variable.
    """
    struct = loadmat(path)['dataStruct']
    ndata = {}
    for name in struct.dtype.names:
        value = struct[name][0, 0]
        # MATLAB scalars load as (1, 1) arrays; unwrap them to a plain scalar.
        ndata[name] = value[0, 0] if value.shape == (1, 1) else value
    return ndata

def get_label(infile):
    """Return True when `infile` is a class-0 segment.

    Training files are named ``I_J_K.mat`` (patient I, segment J, class K,
    with K=0 interictal and K=1 preictal), so the class is the last character
    of the file-name stem.

    NOTE(review): this returns True for K == 0 (interictal) and False for
    K == 1 (preictal). If downstream code treats ``label == 1`` as
    "preictal", the polarity is inverted -- confirm the intended convention.

    Parameters
    ----------
    infile : str
        Path to the file; directory components and the extension are ignored.

    Returns
    -------
    bool
    """
    # Inspect only the bare file stem, so dots in directory names
    # (e.g. "./data" or "data.v2") or a missing extension cannot change
    # which part of the path is examined.
    stem = os.path.splitext(os.path.basename(infile))[0]
    return stem[-1:] == "0"

def test_mat_to_data(path = "./data/1_1_0.mat"):
    """Smoke-test `mat_to_data` and `get_label` on one training file.

    Checks that the loaded struct exposes a 2-D ``data`` array, as the
    generators below expect (they read ``shape[0]`` as samples and
    ``shape[1]`` as channels), and that the label is a bool.

    Raises
    ------
    AssertionError
        If any of the checks fail.
    """
    data = mat_to_data(path)
    label = get_label(path)
    assert "data" in data, "dataStruct has no 'data' field"
    assert np.ndim(data["data"]) == 2, "expected a 2-D (samples, channels) array"
    assert isinstance(label, bool)

In [3]:
def simplegen(folder):
    """Yield one whole recording per ``.mat`` file in `folder`.

    File names follow ``I_J_K.mat`` - the Jth training data segment
    for the Ith patient (there are three patients) corresponding to the
    Kth class (K=0 for interictal, K=1 for preictal).

    Files are visited in sorted name order, so repeated passes (and any row
    ids derived from them) are reproducible; ``os.listdir`` order is
    arbitrary. Each file name is printed as it is read.

    Yields
    ------
    (xx, yy, meta)
        xx   : the ``data`` field of the file's ``dataStruct``
        yy   : 0-d ndarray holding ``get_label(path)``
        meta : ``[I, J]`` as ints (patient, segment)
    """
    infiles = sorted(name for name in os.listdir(folder) if name.endswith(".mat"))
    for fname in infiles:
        print(fname)
        path = os.path.join(folder, fname)
        # "I_J_K.mat" -> "I_J_K" -> [I, J]; `fname` is already a bare name,
        # so no path-separator handling is needed.
        stem = os.path.splitext(fname)[0]
        meta = [int(k) for k in stem.split("_")[:2]]
        label = get_label(path)
        data = mat_to_data(path)
        yield data["data"], np.array(label), meta


def filegenchop(folder, piece_len=1000):
    """Yield fixed-length pieces of every recording produced by `simplegen`.

    Each recording ``xbig`` of shape ``(n_samples, nchannels)`` is cut into
    ``n_samples // piece_len`` consecutive, non-overlapping pieces of shape
    ``(piece_len, nchannels)``. Trailing samples that do not fill a whole
    piece are dropped.

    Parameters
    ----------
    folder : str
        Directory passed to `simplegen`.
    piece_len : int
        Number of samples per piece (must be positive).

    Yields
    ------
    (xx, ybig, meta)
        One piece plus the label and ``[I, J]`` metadata of its source file.
    """
    # `piece_len` is honoured here; it was previously overwritten with 1000.
    for xbig, ybig, meta in simplegen(folder):
        n_pieces = xbig.shape[0] // piece_len
        nchannels = xbig.shape[1]
        # Drop the incomplete tail so the reshape is always valid.
        pieces = np.reshape(xbig[:n_pieces * piece_len],
                            (n_pieces, piece_len, nchannels))
        for xx in pieces:
            yield xx, ybig, meta

In [4]:
datadir = "data/"
BATCH_SIZE = 128

# Peek at the first chopped piece to sanity-check its shape, label and metadata.
gen = filegenchop(datadir, piece_len=1000)

for xx, yy, meta in gen:
    print(xx.shape, yy, meta)
    break

1_1_0.mat
(1000, 16) True [1, 1]


In [5]:
import sqlite3

def chop_into_database(datadir, piece_len=1000, tablename="train1"):
    """Chop every recording in `datadir` into pieces and store them in SQLite.

    Writes to ``<tablename>_piece_<piece_len>.db``. Any existing table named
    `tablename` is dropped and recreated, then filled with one row per piece
    from `filegenchop`:

    - id         : running index of the piece (0, 1, 2, ...)
    - label      : the file's label, stored as 0/1
    - data       : raw bytes of the piece array (``xx.tobytes()``)
    - individual : patient number I from the file name
    - segment    : segment number J from the file name

    Parameters
    ----------
    datadir : str
        Directory with the ``.mat`` files.
    piece_len : int
        Samples per piece; also part of the database file name.
    tablename : str
        Table to (re)create. It is interpolated into SQL, so it must be a
        plain SQL identifier.

    Returns
    -------
    str
        The database file name.
    """
    dbname = "%s_piece_%u.db" % (tablename, piece_len)
    print("SAVING TO ", dbname)
    conn = sqlite3.connect(dbname)
    try:
        # `with conn` commits on success / rolls back on error, but does not
        # close the connection -- that happens in `finally`.
        with conn:
            curs = conn.cursor()

            print("PURGING")
            curs.execute("DROP TABLE IF EXISTS %s" % tablename)

            print("CREATING")
            # Use `tablename` here as well; this was hard-coded to "train1",
            # so any other table name failed at INSERT time.
            curs.execute("""CREATE TABLE IF NOT EXISTS %s(
            id INT PRIMARY KEY,
            label INT,
            data BLOB,
            individual INT,
            segment INT
            )""" % tablename)

            insert_qry = "INSERT INTO %s (id, label, data, individual, segment) VALUES (?,?,?,?,?)" % tablename
            gen = filegenchop(datadir, piece_len=piece_len)

            print("CHOPPING AND INSERTING")
            # NOTE(review): the array dtype is not stored; reconstruct_from_sql
            # decodes the bytes as little-endian float32 by default, so this
            # assumes the source arrays are float32 -- confirm.
            curs.executemany(
                insert_qry,
                ((id_, bool(yy), sqlite3.Binary(xx.tobytes()), meta[0], meta[1])
                 for id_, (xx, yy, meta) in enumerate(gen)),
            )
    finally:
        conn.close()
    return dbname

In [6]:
# Build (or rebuild) the SQLite cache of chopped pieces.
tablename = "train1"
piece_len = 1000
dbname = chop_into_database(datadir, piece_len=piece_len, tablename=tablename)

SAVING TO  train1_piece_1000.db
PURGING
CREATING
CHOPPING AND INSERTING
1_1_0.mat
1_1_1.mat
1_2_0.mat
1_2_1.mat


In [7]:
# Sanity check: number of stored pieces. The connection is closed explicitly,
# because using sqlite3.connect(...) as a context manager only manages the
# transaction and does not close the connection.
conn = sqlite3.connect(dbname)
try:
    curs = conn.cursor()
    curs.execute("SELECT COUNT(*) FROM %s" % tablename)
    print(curs.fetchone())
finally:
    conn.close()

(960,)


In [8]:
# Close the connection opened in the row-count cell above. Using
# sqlite3.connect(...) as a context manager only commits/rolls back the
# transaction; it does not close the connection. Calling close() on an
# already-closed connection is a no-op.
conn.close()

In [9]:
def reconstruct_from_sql(row, nchannels = 16, dtype=np.dtype('<f4')):
    """Decode the ``data`` BLOB of a table row back into a 2-D array.

    Parameters
    ----------
    row : sequence
        A row in ``(id, label, data, individual, segment)`` column order
        (the table's column order, as returned by ``SELECT *``); ``row[2]``
        holds the bytes written by ``xx.tobytes()``.
    nchannels : int
        Number of columns of the stored array; the number of rows is inferred.
    dtype : numpy dtype
        Element type of the stored bytes (default: little-endian float32).

    Returns
    -------
    list
        A copy of `row` as a list, with element 2 replaced by an ndarray of
        shape ``(-1, nchannels)``.
    """
    row = list(row)
    # np.frombuffer replaces the deprecated binary mode of np.fromstring. It
    # returns a read-only view of the bytes, so copy to keep the array
    # writable as before.
    row[2] = np.frombuffer(row[2], dtype=dtype).reshape(-1, nchannels).copy()
    return row

In [10]:
# TEST: the first stored row (id 0) must round-trip to the first chopped piece.
conn = sqlite3.connect(dbname)
try:
    curs = conn.cursor()
    gen = filegenchop(datadir, piece_len=piece_len)
    xx, yy, meta = next(gen)

    # Without ORDER BY, SQLite does not guarantee which row comes first.
    curs.execute("SELECT * FROM %s ORDER BY id LIMIT 1" % tablename)
    out = reconstruct_from_sql(curs.fetchone())
    xx_reconstr = out[2]

    # array_equal also checks shapes; `==` alone could broadcast.
    assert np.array_equal(xx_reconstr, xx)
    assert out[1] == int(yy) and list(out[3:]) == meta
finally:
    conn.close()

1_1_0.mat


In [11]:
batch_size = 10
# Draw `batch_size` random ids (with replacement) as abs(random() % COUNT) and
# join them back to the table on `id`. This relies on ids being 0..COUNT-1
# (chop_into_database assigns them with enumerate). It is presumably used
# instead of ORDER BY RANDOM() to avoid ordering the whole table.
select_qry = """SELECT id, label, data, individual, segment
                FROM (SELECT abs(random() % (SELECT COUNT(id) FROM {0})) AS dummyid FROM {0} LIMIT ?) AS t1
                INNER JOIN
                (SELECT * FROM {0}) AS t2
                ON dummyid = t2.id""".format(tablename)
conn = sqlite3.connect(dbname)
try:
    curs = conn.cursor()
    curs.execute(select_qry, (batch_size,))
    print(*[(x[0], x[1]) for x in curs.fetchmany(batch_size)], sep="\n")
finally:
    conn.close()

(740, 0)
(771, 0)
(364, 0)
(763, 0)
(473, 0)
(23, 1)
(267, 0)
(911, 0)
(4, 1)
(336, 0)
(833, 0)
(932, 0)
(672, 1)
(616, 1)
(523, 1)
(773, 0)
(262, 0)
(575, 1)
(590, 1)
(737, 0)


In [12]:
def get_random_sqite_sample(dbname, tablename, batch_size = 20, nchannels = 16):
    """Draw a random batch of stored pieces from the SQLite cache.

    Row ids are drawn uniformly *with replacement* (the same piece can appear
    more than once) as ``abs(random() % COUNT)`` and joined back on ``id``.
    This relies on ids being the contiguous range 0..COUNT-1 assigned by
    ``chop_into_database``.

    Parameters
    ----------
    dbname, tablename : str
        Database file and table written by ``chop_into_database``.
    batch_size : int
        Maximum number of pieces to draw; fewer come back if the table has
        fewer rows.
    nchannels : int
        Channel count passed to ``reconstruct_from_sql``.

    Returns
    -------
    (x, y)
        x : ndarray of shape ``(n, seqlen, nchannels)``
        y : 1-D ndarray of the n labels

    Raises
    ------
    ValueError
        If no rows were sampled (for example, an empty table).
    """
    select_qry = """SELECT id, label, data, individual, segment
                    FROM (SELECT abs(random() % (SELECT COUNT(id) FROM {0})) AS dummyid FROM {0} LIMIT ?) AS t1
                    INNER JOIN
                    (SELECT * FROM {0}) AS t2
                    ON dummyid = t2.id""".format(tablename)
    # `with sqlite3.connect(...)` would not close the connection; do it explicitly.
    conn = sqlite3.connect(dbname)
    try:
        rows = conn.execute(select_qry, (batch_size,)).fetchall()
    finally:
        conn.close()
    if not rows:
        raise ValueError("no rows sampled from table %r in %r" % (tablename, dbname))

    decoded = [reconstruct_from_sql(row, nchannels=nchannels) for row in rows]
    output_x = [row[2] for row in decoded]
    output_y = [row[1] for row in decoded]
    # Stack to (batch_size, seqlen, nchannels), the same layout the former
    # np.dstack(...).transpose(2, 0, 1) produced.
    return np.stack(output_x, axis=0), np.array(output_y).ravel()
    
# Draw one random batch and display its (batch, seqlen, nchannels) shape.
x_, y_ = get_random_sqite_sample(
    dbname,
    tablename,
    batch_size=20,
)
x_.shape

(20, 1000, 16)