**Draft Model for Mouse Sleep Stage Analysis**

In [None]:
import random
import numpy as np
import pandas as pd
import os
import re
import copy
import time

import torch
from torch import optim, nn
from torchvision import transforms, datasets, models

from collections import OrderedDict

import torch.utils.data as utils

from sklearn.preprocessing import LabelBinarizer

Next, we import the data and build a pandas DataFrame from it.

In [None]:
"""
Insert some code here for loading the data
"""

In [None]:
# Sleep-stage labels predicted by the classifier, keyed by class index.
# They follow the annotated data: W = wake, N = non-REM, R = REM, and
# A = artifact (a class specific to our model).
classes = dict(enumerate(("W", "N", "R", "A")))

**Defining the Model Class**

In [None]:
""" the left hand side of the CNN"""
def CNN_eeg_layer1(fs): 
    return nn.Sequential(
            nn.Conv1d(1, 64, kernel_size=fs//2, stride=fs//16, padding=2),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.MaxPool1d(kernel_size=8, stride=8),
            
            nn.Conv1d(64, 64, kernel_size=8, padding=2),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=8, padding=2),
            nn.ReLU(), 
            nn.Conv1d(64, 64, kernel_size=8, padding=2),
            nn.ReLU())

""" the right hand side of the CNN"""
def CNN_eeg_layer2(fs): 
    return nn.Sequential(
            # 1 input channel for EEG, 64 filters applied 
            nn.Conv1d(1, 64, kernel_size=fs*4, stride=fs//2, padding=2),
            nn.ReLU(), 
            nn.Dropout(0.3),
            nn.MaxPool1d(kernel_size=4, stride=4),
            
            nn.Conv1d(64, 64, kernel_size=6, padding=2),
            nn.ReLU(),
            nn.Conv1d(64, 64, kernel_size=6, padding=2),
            nn.ReLU(), 
            nn.Conv1d(64, 64, kernel_size=6, padding=2),
            nn.ReLU())

class CNN(nn.Module):
    """Two-branch 1-D CNN mapping one EEG epoch to class log-probabilities.

    Both branches read the same EEG signal. ``layer1_eeg`` (from
    ``CNN_eeg_layer1``) and ``layer2_eeg`` (from ``CNN_eeg_layer2``) use
    different filter sizes, so their feature maps generally have different
    lengths. Each map is therefore flattened before the two are concatenated
    and passed through ``fc1`` (Linear + ReLU) and ``fc2`` (Linear +
    LogSoftmax).

    Args:
        n_cnn_dense: width of the hidden fully connected layer.
        fs: passed to both branch builders (presumably the sampling rate
            in Hz).
        num_classes: number of output classes (4 for the W/N/R/A labels).
        epoch_len: number of samples in one input epoch. It is used once,
            with a dummy input, to size ``fc1``. Defaults to ``30 * fs``.
            NOTE(review): the 30 s default is an assumption; set it to the
            real epoch length of the recordings. PyTorch raises an error at
            construction if it is too short for the convolutions.
    """

    def __init__(self, n_cnn_dense=256, fs=100, num_classes=4, epoch_len=None):
        super(CNN, self).__init__()

        # left and right hand sides of the CNN
        self.layer1_eeg = CNN_eeg_layer1(fs)
        self.layer2_eeg = CNN_eeg_layer2(fs)

        # maybe another network at some point for emg?

        if epoch_len is None:
            epoch_len = 30 * fs
        n_features = self._feature_size(epoch_len)

        # Fully connected layers on the concatenated branch features. The
        # input width is measured rather than hard-coded, so it stays correct
        # when fs or epoch_len change.
        self.fc1 = nn.Sequential(
            nn.Linear(n_features, n_cnn_dense),
            nn.ReLU())

        self.fc2 = nn.Sequential(
            nn.Linear(n_cnn_dense, num_classes),
            nn.LogSoftmax(dim=1))

    def _feature_size(self, epoch_len):
        """Return the concatenated feature width for an ``epoch_len``-sample epoch."""
        was_training = self.training
        # eval mode keeps dropout from consuming random numbers during the probe
        self.eval()
        try:
            with torch.no_grad():
                probe = torch.zeros(1, 1, epoch_len)
                return self._features(probe).shape[1]
        finally:
            self.train(was_training)

    @staticmethod
    def _eeg_channel(channels):
        """Return the EEG signal in ``channels`` as a (batch, 1, time) tensor.

        Accepts (batch, time) for a single-channel recording, or
        (batch, n_channels, time), in which case channel 0 is used.
        NOTE(review): assumes EEG is the first channel; confirm against how
        the DataFrame is turned into tensors.
        """
        if channels.dim() == 2:
            return channels.unsqueeze(1)
        if channels.dim() == 3:
            return channels[:, :1, :]
        raise ValueError(
            "expected input of shape (batch, time) or (batch, channels, time), "
            "got {}".format(tuple(channels.shape)))

    def _features(self, ch1):
        """Run both branches on ``ch1`` and concatenate their flattened outputs."""
        out1_eeg = self.layer1_eeg(ch1).flatten(start_dim=1)
        out2_eeg = self.layer2_eeg(ch1).flatten(start_dim=1)
        return torch.cat((out1_eeg, out2_eeg), dim=1)

    def forward(self, channels):
        """Return log-probabilities of shape (batch, num_classes)."""
        # at some point, we'll have a second channel for emg
        ch1 = self._eeg_channel(channels)

        out = self._features(ch1)
        out = self.fc1(out)
        return self.fc2(out)

**Training the Model**

We can tune the hyperparameters further. The function that counts the model's parameters is taken from 06-convnet.ipynb in the CSE144 example repo.

For optimization I used Adam rather than SGD. Adam usually converges faster, and I don't think our data is formatted well enough yet to get decent results with SGD.

In [None]:
def get_n_params(model):
    """Return the total number of elements across all of ``model``'s parameters.

    Every tensor yielded by ``model.parameters()`` is counted, whether or not
    it requires gradients. Returns 0 for a model with no parameters.
    """
    # A generator with sum() avoids a local accumulator named ``np``, which
    # would shadow the notebook's numpy alias.
    return sum(param.nelement() for param in model.parameters())

In [None]:
# make our model
model = CNN()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# set up training and validation data
# TODO: the data-loading cell must build these four tensors from the DataFrame
# (presumably inputs as float tensors and labels as integer class indices 0-3
# matching `classes`). Fail with a clear message until it does.
_missing = [name for name in ("train_data", "train_labels", "val_data", "val_labels")
            if name not in globals()]
if _missing:
    raise NotImplementedError(
        "define {} from the DataFrame before running this cell".format(", ".join(_missing)))

train = utils.TensorDataset(train_data, train_labels)
train_loader = utils.DataLoader(train, batch_size=64, shuffle=True)

test = utils.TensorDataset(val_data, val_labels)
# Validation metrics don't depend on sample order, so there is no need to shuffle.
test_loader = utils.DataLoader(test, batch_size=64, shuffle=False)

data_loaders = {'train': train_loader, 'valid': test_loader}
dataset_sizes = {'train': len(train), 'valid': len(test)}

# define optimization function and print number of params
# (CNN ends in LogSoftmax, so the training criterion should be nn.NLLLoss())
optimizer = optim.Adam(model.parameters(), lr=0.005)
print('Number of parameters: {}'.format(get_n_params(model)))

Following one example I saw, `train_model` uses [model.state_dict()](https://pytorch.org/tutorials/recipes/recipes/what_is_state_dict.html), which holds a model's learnable parameters (i.e. weights and biases). It keeps a copy of the weights that gave the best validation accuracy during training and loads them back into the model at the end.

In [None]:
def _run_phase(model, loader, criteria, optimizer, device, train):
    """Make one pass over ``loader`` and return (mean loss, accuracy).

    With ``train`` true the model is in training mode and the optimizer steps
    after every batch. Otherwise the model is in eval mode with gradients off.
    """
    model.train(train)
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        if train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(train):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criteria(outputs, labels)

            if train:
                loss.backward()
                optimizer.step()

        # assumes criteria averages over the batch (the default reduction),
        # so scale by the batch size to accumulate a per-sample sum
        running_loss += loss.item() * inputs.size(0)
        running_corrects += (preds == labels).sum().item()

    n_samples = len(loader.dataset)
    return running_loss / n_samples, running_corrects / n_samples


def train_model(model, criteria, optimizer, scheduler, num_epochs, *,
                dataloaders=None, device=None):
    """Train ``model`` and return it loaded with its best validation weights.

    Each epoch runs a 'train' pass (weights updated) followed by a 'valid'
    pass (no gradients). The weights with the highest validation accuracy are
    deep-copied and restored into ``model`` before it is returned.

    Args:
        model: the network to train (modified in place).
        criteria: loss called as ``criteria(outputs, labels)``. For a model
            ending in LogSoftmax this should be ``nn.NLLLoss()``.
        optimizer: optimizer over ``model.parameters()``.
        scheduler: learning-rate scheduler, stepped once per epoch after the
            training pass, or None to keep the learning rate fixed.
        num_epochs: number of epochs to run.
        dataloaders: dict with 'train' and 'valid' DataLoaders. Defaults to
            the notebook-level ``data_loaders``.
        device: device batches are moved to. Defaults to the device of the
            model's first parameter (CPU if it has none).

    Returns:
        ``model`` with the best validation weights loaded.

    Raises:
        ValueError: if the 'train' or 'valid' dataset is empty.
    """
    if dataloaders is None:
        dataloaders = data_loaders  # defined in the training setup cell
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    for phase in ('train', 'valid'):
        if len(dataloaders[phase].dataset) == 0:
            raise ValueError("the '{}' dataset is empty".format(phase))

    # deep copy and save the best model
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs))
        print('-' * 10)

        for phase in ('train', 'valid'):
            epoch_loss, epoch_acc = _run_phase(
                model, dataloaders[phase], criteria, optimizer, device,
                train=(phase == 'train'))

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            if phase == 'train' and scheduler is not None:
                # Since PyTorch 1.1 the scheduler must step after
                # optimizer.step(); stepping it first skips the first value
                # of the learning-rate schedule.
                scheduler.step()

            if phase == 'valid' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()
    print('Best val Acc: {:.4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model