<a href="https://colab.research.google.com/github/IRPARKS/NMML/blob/main/NMMLHW12P2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import h5py
import numpy as np
from sklearn.model_selection import train_test_split
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras import backend as K
import tensorflow as tf
from memory_profiler import profile

# Function to create a data generator
def data_generator(data, labels, batch_size):
    num_samples = len(data)
    while True:
        indices = np.random.permutation(num_samples)
        for start in range(0, num_samples, batch_size):
            batch_indices = indices[start:start+batch_size]
            batch_data = data[batch_indices]
            batch_labels = labels[batch_indices]
            yield batch_data[:, :, np.newaxis], batch_labels

# Load data from HDF5 file
filepath = 'Rat08-20130711_017.h5'
f = h5py.File(filepath, 'r')

# Extract LFP data and labels
states = list(f.keys())  # ['NREM', 'WAKE']
lfp_data = []
labels = []

for state in states:
    group = f[state]
    n_segments = len(group)
    for i in range(n_segments):
        lfp_data.append(group[str(i+1)][()].astype(float))
        labels.append(0 if state == 'NREM' else 1)  # Use 0 for NREM and 1 for WAKE

# Pad sequences to a maximum length (adjust maxlen based on your data)
max_sequence_length = 1000
lfp_data_padded = pad_sequences(lfp_data, dtype='float32', padding='post', truncating='post', maxlen=max_sequence_length)

# Convert to numpy arrays
labels = np.array(labels)

# Split data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(lfp_data_padded, labels, test_size=0.2, random_state=42)

# Define batch size and create data generators
batch_size = 16
train_generator = data_generator(X_train, y_train, batch_size)
test_generator = data_generator(X_test, y_test, batch_size)

# Clear session and set GPU memory growth (if using GPU)
K.clear_session()
physical_devices = tf.config.list_physical_devices('GPU')
if len(physical_devices) > 0:
    tf.config.experimental.set_memory_growth(physical_devices[0], True)

# Define LSTM model
model = Sequential()
model.add(LSTM(units=32, input_shape=(max_sequence_length, 1)))
model.add(Dropout(0.2))
model.add(Dense(1, activation='sigmoid'))

# Compile model
model.compile(optimizer=Adam(learning_rate=0.001), loss='binary_crossentropy', metrics=['accuracy'])

# Train model with memory profiling
@profile
def train_model():
    model.fit(train_generator, epochs=10, steps_per_epoch=len(X_train)//batch_size, validation_data=test_generator, validation_steps=len(X_test)//batch_size)

train_model()

# Evaluate model
loss, accuracy = model.evaluate(test_generator, steps=len(X_test)//batch_size)
print(f"Test Accuracy: {accuracy*100:.2f}%")


Collecting memory_profiler
  Downloading memory_profiler-0.61.0-py3-none-any.whl (31 kB)
Installing collected packages: memory_profiler
Successfully installed memory_profiler-0.61.0



sys.settrace() should not be used when the debugger is being used.
This may cause the debugger to stop working correctly.
If this is needed, please check: 
http://pydev.blogspot.com/2007/06/why-cant-pydev-debugger-work-with.html
to see how to restore the debug tracing back correctly.
Call Location:
  File "/usr/local/lib/python3.10/dist-packages/memory_profiler.py", line 847, in enable
    sys.settrace(self.trace_memory_usage)



ERROR: Could not find file <ipython-input-2-f375ada17fef>
NOTE: %mprun can only be used on functions defined in physical files, and not in the IPython environment.
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10



sys.settrace() should not be used when the debugger is being used.
This may cause the debugger to stop working correctly.
If this is needed, please check: 
http://pydev.blogspot.com/2007/06/why-cant-pydev-debugger-work-with.html
to see how to restore the debug tracing back correctly.
Call Location:
  File "/usr/local/lib/python3.10/dist-packages/memory_profiler.py", line 850, in disable
    sys.settrace(self._original_trace_function)



Test Accuracy: 62.50%
