**Draft Model for Mice Steep Stage Analysis**

In [1]:
import torch
import torch.nn.functional as F
from torch import optim, nn
from torchvision import transforms, datasets, models


import torch.utils.data as utils
from inputMassager import *


In [2]:
inputHandler = inputMassager()
data_filepath = R"C:\Users\Adam\Desktop\CHDCtrl1_CHD801FR_normal\CHD801FR_20221123_normal.txt"
annotated_filepath = R"C:\Users\Adam\Desktop\CHDCtrl1_CHD801FR_normal\CHD801FR_20221123_normal_annotated.txt"
period_size = 200
num_periods = None
labels, eeg_samples, emg_samples, eeg_fft, emg_fft = get_labeled_data(data_filepath, annotated_filepath, period_size, num_periods)

In [None]:
everything = torch.cat((eeg_samples, eeg_fft), dim = 1)
print(everything.size())

torch.Size([100889, 2, 200])


In [None]:
#ds = torch.utils.data.TensorDataset(eeg_fft, emg_fft, labels)
ds = torch.utils.data.TensorDataset(torch.cat((eeg_samples, eeg_fft), dim = 1), torch.cat((emg_samples, emg_fft), dim = 1), labels)

train_size = int(len(ds) *.80)
val_size = len(ds) - int(len(ds) *.80)

train_dataset, val_dataset = torch.utils.data.random_split(ds, [train_size, val_size])

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=100, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=100, shuffle=True)

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

**Creating class for Model**

In [None]:
class CNN(nn.Module):
    def __init__(self, period_size, input_channels=1):
        super(CNN, self).__init__()
        self.input_channels = input_channels

        self.eeg_conv1 = nn.Conv1d(in_channels=input_channels, out_channels=period_size, kernel_size=period_size//2)
        self.eeg_conv2 = nn.Conv1d(in_channels=period_size, out_channels=period_size//4, kernel_size=8)
        self.eeg_conv3 = nn.Conv1d(in_channels=period_size // 4, out_channels=256, kernel_size=2)

        self.emg_conv1 = nn.Conv1d(in_channels=input_channels, out_channels=period_size, kernel_size=period_size//2)
        self.emg_conv2 = nn.Conv1d(in_channels=period_size, out_channels=period_size//4, kernel_size=8)
        self.emg_conv3 = nn.Conv1d(in_channels=period_size // 4, out_channels=256, kernel_size=2)



        self.fc1 = nn.Linear(2560*2, 20)
        self.fc2 = nn.Linear(10, 5)


    def forward(self, c1, c2):
        #print(x.size())
        c1 = self.eeg_conv1(c1)
        c1 = F.relu(c1)
        # print(x.size())
        c1 = F.max_pool1d(c1, kernel_size=2)
        c1 = self.eeg_conv2(c1)
        c1 = F.max_pool1d(c1, kernel_size=2)
        # print(x.size())
        c1 = self.eeg_conv3(c1)
        c1 = F.max_pool1d(c1, kernel_size=2)
        #print(x.size())

        c2 = self.eeg_conv1(c2)
        c2 = F.relu(c2)
        # print(x.size())
        c2 = F.max_pool1d(c2, kernel_size=2)
        c2 = self.eeg_conv2(c2)
        c2 = F.max_pool1d(c2, kernel_size=2)
        # print(x.size())
        c2 = self.eeg_conv3(c2)
        c2 = F.max_pool1d(c2, kernel_size=2)


        c1c2 = torch.cat((c1, c2), dim=1)

        c1c2 = c1c2.flatten(1)
        #print(x.size())
        c1c2 = self.fc1(c1c2)
        c1c2 = F.max_pool1d(c1c2, kernel_size=2)
        # print(x.size())
        c1c2 = self.fc2(c1c2)
        c1c2 = F.log_softmax(c1c2, dim=1)
#ME      
        #c1c2 = F.relu(c1c2)
        
        return c1c2

In [None]:
def train_model(epochs, model):
    print("device:", device)
    model.to(device)
    model.train() # set model to training mode
  
    # class_weights = torch.tensor([1/50, 1/5886, 1/13258, 1/285, 0], device=device)
    class_weights = torch.tensor([0.9, 2.6, 0.7, 0.5, 0.9], device=device)

    loss_fun = nn.CrossEntropyLoss(weight=class_weights) #define a loss function object

    for epoch in range(epochs):

      for batch_idx, (channel1, channel2, target) in enumerate(train_loader):
          #print(c)
          # print(target)
          channel1, channel2, target = channel1.to(device), channel2.to(device), target.to(device)
          
          optimizer.zero_grad()
          output = model(channel1, channel2) # guess we have to pass all channels instead of data?]\
          model.to(device)
          loss = loss_fun(output,target)
          loss.backward()
          optimizer.step()
          if batch_idx % 100 == 0:
              print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                  epoch, batch_idx * len(channel1), len(train_loader.dataset),
                  100. * batch_idx / len(train_loader), loss.item()))

In [None]:
model = CNN(period_size, 2)
model.to(device)
  
optimizer = torch.optim.SGD(model.parameters(), lr=.01, momentum=0.9)

train_model(10, model)

device: cuda:0


In [None]:
def evaluate_model(model, dataloader, is_test=False, confidence_level = -0.5):
  # Set model to evaluation mode
  model.eval()

  #dictionary to store the accuracy of predictions by sleep stage
  acurracy_stages = {0:[], 1:[], 2:[], 3:[], 4:[]}
  stage_names = ["not sure(0): ", "rem accuracy(1): ", "non-rem accuracy(2): ", "wake accuracy(3): ", "artifact accuracy(4): "]
  predictions = []
  label_list = []
  

  with torch.no_grad():
    correct = 0
    loss = 0

    criterion = torch.nn.CrossEntropyLoss(reduction='none')

    for channel1, channel2, target in dataloader:
        channel1, channel2, target = channel1.to(device), channel2.to(device), target.to(device)
        outputs = model(channel1, channel2)

        loss += torch.sum(criterion(outputs, target)).item()

        pred_value, pred = outputs.data.max(1, keepdim=True) # get the index of the max log-probability                                                         
          
        for i in range(len(target)):      
          
          #if the predicted value is lower that the confidence level, set the predicted value to 0 and the target value to 0
           if pred_value[i] < confidence_level:
            pred[i] = 4
            target[i] = 4

          #append boolean indicating whether or not each prediction matched the target to the apporpriate sleep stage list
           acurracy_stages[target[i].item()].append(pred[i].eq(target[i]).item())
           predictions.append(pred[i].item())
           label_list.append(target[i].item())
           
        correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()               
          

#calculate total accuracy and loss
    loss /= len(dataloader.dataset)
    accuracy = 100. * correct / len(dataloader.dataset)
    print('\n{} set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        "Test" if is_test else "Validation",
        loss, correct, len(dataloader.dataset),
        accuracy))

  # Set model back to training mode
  model.train()
  return predictions, label_list

In [None]:
#if you make the confidence level too low, you'll get an error about not being able to divede by 0,
#that just means that none of the data is being classified to not sure, keep it below -.75 to avoid this
from sklearn.metrics import classification_report
predictions, label_list = evaluate_model(model, val_loader, is_test=True, confidence_level=-0.75)

target_names = ['REM', 'Non-REM', 'Wake', 'Artifact']
print('Testing Report for Dual CNN with Fourier Transform for Accuracy, Precision, Recall, and F1-Score')
print('------------------------------------------------------------------------------------------------')
print(classification_report(label_list, predictions, labels=[1,2,3,4], target_names=target_names))



# Calculate class-wise counts
predictions_df = pd.Series(predictions, name='Predicted')
counts = predictions_df.value_counts()

print('Classified Periods:')
for class_name, count in counts.items():
    print(f"Class {class_name}, {target_names[class_name-1]}: {count}")


Test set: Average loss: 0.4903, Accuracy: 17810/20178 (88%)

Testing Report for Dual CNN with Fourier Transform for Accuracy, Precision, Recall, and F1-Score
------------------------------------------------------------------------------------------------
              precision    recall  f1-score   support

         REM       0.08      0.06      0.07       714
     Non-REM       0.87      0.95      0.91      5719
        Wake       0.92      0.91      0.92     12924
    Artifact       0.93      0.72      0.81       821

    accuracy                           0.88     20178
   macro avg       0.70      0.66      0.67     20178
weighted avg       0.88      0.88      0.88     20178

Classified Periods:
Class 3, Wake: 12752
Class 2, Non-REM: 6230
Class 4, Artifact: 640
Class 1, REM: 556


In [None]:
print(len(predictions))

20178


with more training data, the model almost always predicts wake becuase it's so overepresented, when i was running just 10000 periods, it was in the mid to low 90s for non-rem and rem accuracy but with all the data, it's 99% accurate for wake and in the 30s for non rem