In [9]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
import math
import torch.optim as optim

In [10]:
# Integer class ids for the three textual rating levels, in ascending id order.
d = {name: class_id for class_id, name in enumerate(("good", "neutral", "bad"))}


def encoding(label):
    """Return the integer class id stored in `d` for the rating string `label`.

    Raises:
        KeyError: if `label` is not one of the keys of `d`.
    """
    class_id = d[label]
    return class_id

In [11]:
class RecurrentNetworkWithSleepData(nn.Module):
    """ReLU RNN over the actigraphy channels followed by an MLP classifier.

    For every time step the first `num_actigraphy_features` channels are encoded
    by a stacked RNN; the remaining channels (the sleep features) bypass the RNN
    and are concatenated with the RNN output of the same time step. The whole
    sequence is then flattened and mapped to 3 output logits.

    Args:
        seq_length: number of time steps in every input sequence.
        hidden_size: width of the RNN hidden state.
        num_layers: number of stacked RNN layers.
        num_sleep_features: number of pass-through channels that follow the
            actigraphy channels (default 17).
        num_actigraphy_features: number of leading channels fed to the RNN
            (default 7).
    """

    def __init__(self, seq_length, hidden_size, num_layers,
                 num_sleep_features=17, num_actigraphy_features=7):
        super().__init__()
        self.seq_length = seq_length
        self.num_actigraphy_features = num_actigraphy_features
        self.num_sleep_features = num_sleep_features
        self.rnn = nn.RNN(
            input_size=num_actigraphy_features,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            nonlinearity='relu',
        )
        # The classifier sees every time step at once, so its input width is
        # seq_length * (RNN output width + pass-through width). With the
        # defaults and seq_length=50, hidden_size=3 this is the former
        # hard-coded 1000.
        flat_features = seq_length * (hidden_size + num_sleep_features)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 500),
            nn.ReLU(),
            nn.Linear(500, 100),
            nn.ReLU(),
            nn.Linear(100, 50),
            nn.ReLU(),
            nn.Linear(50, 3)
        )

    def forward(self, x):
        """Classify a batch shaped (batch, seq_length, actigraphy + sleep channels).

        Returns:
            Tensor of shape (batch, 3) holding the raw class logits.
        """
        split = self.num_actigraphy_features
        encoded_actigraphy_features, _ = self.rnn(x[:, :, :split])  # actigraphy features
        sleep_features = x[:, :, split:]
        concatenated = torch.cat((encoded_actigraphy_features, sleep_features), dim=2)
        return self.classifier(concatenated)

In [12]:
class PatientDataset(Dataset):
    """Map-style dataset pairing each entry of `df_as_np` with the label at the same index."""

    def __init__(self, df_as_np, labels, seq_len):
        # seq_len is only stored on the instance; the dataset itself never reads it.
        self.data, self.labels, self.seq_len = df_as_np, labels, seq_len

    def __len__(self):
        """Number of samples, taken from the data container (labels are not consulted)."""
        n_samples = len(self.data)
        return n_samples

    def __getitem__(self, idx):
        """Return the `(sample, label)` pair stored at position `idx`."""
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

In [13]:
def load_patient_data(df_as_np, labels, seq_len, batch_size=50, seed=None):
    """Build train/validation/test loaders from a random 80/10/10 split.

    Args:
        df_as_np: indexable container of samples, passed to PatientDataset.
        labels: indexable container of labels, passed to PatientDataset.
        seq_len: forwarded to PatientDataset.
        batch_size: batch size used by all three loaders.
        seed: optional integer. When given, the split and the shuffling of the
            training loader are driven by a dedicated `torch.Generator` seeded
            with this value, so repeated calls yield the same partition. When
            None (default) the global torch RNG is used, as before.

    Returns:
        Tuple `(trainloader, valloader, testloader)`; only the training loader
        shuffles its samples.
    """
    dataset = PatientDataset(df_as_np, labels, seq_len)
    n_total = len(dataset)
    train_size = int(0.8 * n_total)
    val_size = int(0.1 * n_total)
    # The test split absorbs the rounding remainder so the sizes sum to n_total.
    test_size = n_total - train_size - val_size

    generator = None if seed is None else torch.Generator().manual_seed(seed)
    # Only pass a generator when one was requested, so the default call is unchanged.
    split_kwargs = {} if generator is None else {"generator": generator}
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size], **split_kwargs
    )

    trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, generator=generator)
    valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return trainloader, valloader, testloader

In [14]:
def train(dataloader, seq_length, lr, epochs, hidden_size=3, num_layers=10,
          weight_decay=0.0001, max_grad_norm=None):
    """Train a fresh RecurrentNetworkWithSleepData with Adam and cross-entropy.

    Args:
        dataloader: iterable yielding `(seq, label)` batches.
        seq_length: sequence length passed to the model constructor.
        lr: Adam learning rate.
        epochs: number of full passes over `dataloader`.
        hidden_size: RNN hidden width (default 3).
        num_layers: number of stacked RNN layers (default 10).
        weight_decay: Adam weight decay (default 0.0001).
        max_grad_norm: if not None, gradients are clipped to this global L2
            norm before every optimizer step, which helps when the loss spikes.
            None (default) disables clipping.

    Returns:
        The trained model, left in training mode.
    """
    model = RecurrentNetworkWithSleepData(seq_length=seq_length, hidden_size=hidden_size, num_layers=num_layers)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    model.train()
    criterion = nn.CrossEntropyLoss()
    # Running batch counter across all epochs; it is deliberately not reset per epoch.
    batch = 0
    for epoch in range(epochs):
        for seq, label in dataloader:
            optimizer.zero_grad()
            outputs = model(seq.float())
            loss = criterion(outputs, label.long())
            loss.backward()
            if max_grad_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            print(epoch, batch, loss.item())
            batch += 1
    return model

In [15]:
def test(model, dataloader):
    for seq, labels in dataloader:
        output = model(seq.float())
        pred_labels = torch.argmax(output, dim=1)
        acc = (pred_labels == labels).float().mean().item()
        print(acc)
        print(pred_labels)
        print("##########################")
        print(labels)
        break

In [16]:
# Read the feature table and the label table; the labels file carries a stray index column.
df_features = pd.read_csv("minute_data_007_youcef/sleep+encoded.csv")
df_labels = pd.read_csv("minute_data_007_youcef/sleep+encoded_labels.csv").drop(columns=["Unnamed: 0"])
print(df_labels.columns)

# Replace the textual ratings in every label column with their integer class ids.
for col in df_labels.columns:
    df_labels[col] = df_labels[col].apply(encoding)

# Drop the timestamp columns, keep the first 128800 rows and group them into
# 2576 sequences of 50 rows with 24 channels each.
df_features = (
    df_features
    .drop(columns=["date", "date_only", "time"])
    .to_numpy()[:128800, :]
    .reshape(2576, 50, 24)
)
df_labels = df_labels.to_numpy()

Index(['average', 'phq_9', 'cgis', 'gad_7', 'wsas', 'qids'], dtype='object')


In [17]:
# Train and evaluate one independent model per label column. The column count is
# read from the label array instead of being hard-coded.
for i in range(df_labels.shape[1]):
    print(f"############### LABEL {i} #################")
    train_loader, val_loader, test_loader = load_patient_data(df_features, df_labels[:, i], seq_len=50, batch_size=100)
    model = train(dataloader=train_loader, lr=0.05, epochs=4, seq_length=50)
    test(model, test_loader)

############### LABEL 0 #################
0 0 1.0747333765029907
0 1 128.43490600585938
0 2 10.446381568908691
0 3 5.351353645324707
0 4 3.8651211261749268
0 5 1.6691842079162598
0 6 1.6057430505752563
0 7 0.8022276163101196
0 8 1.0506720542907715
0 9 0.47161731123924255
0 10 0.6595698595046997
0 11 0.8132575750350952
0 12 0.685547947883606
0 13 0.6658703684806824
0 14 0.35483384132385254
0 15 1.3660606145858765
0 16 0.47076842188835144
0 17 0.9755422472953796
0 18 6.136995792388916
0 19 3.034501314163208
0 20 2.3431127071380615
1 21 1.0684345960617065
1 22 0.95206618309021
1 23 0.9175847768783569
1 24 0.9577121734619141
1 25 0.9215032458305359
1 26 0.8595895171165466
1 27 15430474.0
1 28 0.808597207069397
1 29 0.8269146680831909
1 30 550.0515747070312
1 31 0.835552453994751
1 32 0.8079496026039124
1 33 0.7627130746841431
1 34 3.324538230895996
1 35 0.788767397403717
1 36 1.3255964517593384
1 37 1.7098971605300903
1 38 1.7856848239898682
1 39 2.211029291152954
1 40 1.3147393465042114
1