In [18]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
import math
import torch.optim as optim

In [19]:
# Integer class index assigned to each textual label.
d = dict(good=0, neutral=1, bad=2)


def encoding(label):
    """Look up the integer class index for a textual label.

    Raises:
        KeyError: if ``label`` is not one of the keys of ``d``.
    """
    return d[label]

In [20]:
class PatientDataset(Dataset):
    """Map-style dataset that pairs ``data[idx]`` with ``labels[idx]``."""

    def __init__(self, df_as_np, labels, seq_len):
        # seq_len is only recorded here; no method of this class reads it back.
        self.data, self.labels, self.seq_len = df_as_np, labels, seq_len

    def __len__(self):
        # The size comes from the samples alone; labels are not consulted.
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

In [21]:
def load_patient_data(df_as_np, labels, seq_len, batch_size=50, seed=None):
    """Split the samples 80/10/10 and wrap each part in a DataLoader.

    Args:
        df_as_np: indexable collection of samples, handed to PatientDataset.
        labels: indexable collection of targets aligned with ``df_as_np``.
        seq_len: forwarded unchanged to PatientDataset.
        batch_size: batch size used by all three loaders.
        seed: optional int. When given, the random split is drawn from a
            dedicated ``torch.Generator`` seeded with it, so the same samples
            land in the same subset on every run. ``None`` (the default)
            keeps the previous behaviour of using the global RNG.
            The shuffling order of the training loader is not affected.

    Returns:
        Tuple ``(trainloader, valloader, testloader)``; only the training
        loader shuffles.
    """
    dataset = PatientDataset(df_as_np, labels, seq_len)
    total = len(dataset)
    train_size = int(0.8 * total)
    val_size = int(0.1 * total)
    # The remainder goes to the test split so the three sizes always sum to total.
    test_size = total - train_size - val_size

    # Only pass a generator when asked to, so the default call is unchanged.
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size], **split_kwargs
    )

    trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return trainloader, valloader, testloader

In [22]:
class RecurrentNetwork(nn.Module):
    """ReLU RNN whose per-step hidden states are flattened and classified.

    Args:
        seq_length: number of time steps in every input sequence. Together
            with ``hidden_size`` it fixes the classifier's input width.
        hidden_size: width of the RNN hidden state.
        num_layers: number of stacked RNN layers.
        input_size: number of features per time step (default 7).
        num_classes: number of output logits (default 3).

    ``forward`` expects a float tensor of shape
    ``(batch, seq_length, input_size)`` and returns ``(batch, num_classes)``.
    """

    def __init__(self, seq_length, hidden_size, num_layers, input_size=7, num_classes=3):
        super(RecurrentNetwork, self).__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            nonlinearity='relu',
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # The flattened RNN output has seq_length * hidden_size values per
            # sample; deriving the width keeps the layer in sync with the
            # constructor arguments (50 * 3 == 150 for the original setup).
            nn.Linear(seq_length * hidden_size, num_classes),
        )

    def forward(self, x):
        # Keep the per-step outputs and discard the final hidden state.
        x, _ = self.rnn(x)
        return self.classifier(x)

In [23]:
# Read the feature table and drop the columns that are not fed to the model.
df_features = pd.read_csv("minute_data/encoded.csv").drop(
    columns=["Unnamed: 0.1", "Unnamed: 0", "Unnamed: 0.2", "date", "date_only", "time"]
)
# Read the label table and drop its "Unnamed: 0" column.
df_labels = pd.read_csv("minute_data/labels.csv").drop(columns=["Unnamed: 0"])

# Replace every textual label with its integer class index, column by column.
for column in df_labels.columns:
    df_labels[column] = df_labels[column].apply(encoding)

# Keep the first 180950 rows (= 3619 * 50) and cut them into 3619 windows of
# 50 rows with 7 features each.
df_features_as_np = df_features.to_numpy()[:180950, :].reshape(3619, 50, 7)
df_labels_as_np = df_labels.to_numpy()
print(df_features_as_np.shape, df_labels_as_np.shape, sep="\n")

(3619, 50, 7)
(3619, 6)


In [24]:
def train(dataloader, lr, epochs):
    """Fit a fresh RecurrentNetwork on ``dataloader`` and return it.

    Args:
        dataloader: iterable yielding ``(seq, label)`` batches.
        lr: learning rate for the Adam optimizer.
        epochs: number of full passes over ``dataloader``.

    Returns:
        The trained model.

    Prints ``epoch``, a running batch index and the loss after every step.
    """
    model = RecurrentNetwork(seq_length=50, hidden_size=3, num_layers=10)
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=0.0001)
    criterion = nn.CrossEntropyLoss()
    model.train()

    # Never reset between epochs, so it counts batches over the whole run.
    step = 0
    for epoch in range(epochs):
        for sequences, targets in dataloader:
            optimizer.zero_grad()
            logits = model(sequences.float())
            batch_loss = criterion(logits, targets.long())
            batch_loss.backward()
            optimizer.step()
            print(epoch, step, batch_loss.item())
            step += 1
    return model

In [25]:
def select_target(i):
    """Build the train/val/test loaders for label column ``i``.

    Uses the module-level ``df_features_as_np`` as inputs and column ``i`` of
    ``df_labels_as_np`` as targets, with a batch size of 500.
    """
    # NOTE(review): seq_len=10 does not match the 50-step windows the features
    # are reshaped into; the value is only stored by the dataset - confirm.
    target_column = df_labels_as_np[:, i]
    loaders = load_patient_data(df_features_as_np, target_column, seq_len=10, batch_size=500)
    train_loader, val_loader, test_loader = loaders
    return train_loader, val_loader, test_loader

In [26]:
def test(model, dataloader):
    for seq, labels in dataloader:
        output = model(seq.float())
        pred_labels = torch.argmax(output, dim=1)
        acc = (pred_labels == labels).float().mean().item()
        print(acc)
        print(pred_labels)
        print("##########################")
        print(labels)

In [32]:
# Train and evaluate a separate model for each of the 6 label columns.
for i in range(6):
    print(f"############### LABEL {i} #################")
    # Fresh loaders for target column i (val_loader is not used in this loop).
    train_loader, val_loader, test_loader = select_target(i)
    model = train(dataloader=train_loader, lr=0.05, epochs=10)
    print("###############################################")
    # Report per-batch results on the test split.
    test(model, test_loader)
    

############### LABEL 0 #################
0 0 1.1241484880447388
0 1 1.8465445041656494
0 2 1.190786600112915
0 3 0.9194248914718628
0 4 1.149585485458374
0 5 1.1491881608963013
1 6 1.039686679840088
1 7 0.9813769459724426
1 8 0.8989889621734619
1 9 0.8837395310401917
1 10 0.8619005084037781
1 11 0.8717852830886841
2 12 0.8987894654273987
2 13 1.0093607902526855
2 14 0.8791966438293457
2 15 0.8836639523506165
2 16 0.8807154297828674
2 17 0.8905461430549622
3 18 0.8910221457481384
3 19 0.9140796065330505
3 20 0.9195306897163391
3 21 0.8821433186531067
3 22 0.9118075370788574
3 23 0.8628560900688171
4 24 0.826291024684906
4 25 0.9026026725769043
4 26 0.9242742657661438
4 27 0.9211835861206055
4 28 0.9050272703170776
4 29 0.8602099418640137
5 30 0.8732168674468994
5 31 0.9166204333305359
5 32 0.897769570350647
5 33 0.8896781206130981
5 34 0.8513472080230713
5 35 0.8831977844238281
6 36 0.8803014159202576
6 37 0.8898118138313293
6 38 0.8865800499916077
6 39 0.8929550647735596
6 40 0.858550