In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import pandas as pd
import numpy as np
import math
import torch.optim as optim

In [2]:
# Lookup table from the textual label to the integer class id used downstream.
d = dict(good=0, neutral=1, bad=2)


def encoding(label):
    """Return the integer class id that ``d`` assigns to ``label``.

    Raises:
        KeyError: if ``label`` is not one of the keys of ``d``.
    """
    class_id = d[label]
    return class_id

In [3]:
class PatientDataset(Dataset):
    """Map-style dataset pairing each entry of ``df_as_np`` with its label.

    The constructor arguments are stored as given on ``data``, ``labels`` and
    ``seq_len``. ``seq_len`` is kept on the instance but is not read by any
    method of this class.
    """

    def __init__(self, df_as_np, labels, seq_len):
        self.data, self.labels, self.seq_len = df_as_np, labels, seq_len

    def __len__(self):
        # The dataset size is the length of the sample container.
        n_samples = len(self.data)
        return n_samples

    def __getitem__(self, idx):
        # Fetch the sample and the label stored at the same position.
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

In [23]:
def load_patient_data(df_as_np, labels, seq_len, batch_size=500, seed=None):
    """Split the samples 80/10/10 and wrap each split in a DataLoader.

    Args:
        df_as_np: Indexable collection of samples, handed to PatientDataset.
        labels: Indexable collection of labels, handed to PatientDataset.
        seq_len: Forwarded unchanged to PatientDataset.
        batch_size: Batch size used by all three loaders.
        seed: Optional integer. When given, the random split and the
            shuffling of the training loader draw from a ``torch.Generator``
            seeded with this value, so a re-run yields the same partition.
            When ``None`` (the default) the global torch RNG is used, which
            is the previous behaviour.

    Returns:
        Tuple ``(trainloader, valloader, testloader)``. Only the training
        loader shuffles its samples.
    """
    dataset = PatientDataset(df_as_np, labels, seq_len)

    n_total = len(dataset)
    train_size = int(0.8 * n_total)
    val_size = int(0.1 * n_total)
    # The test split takes whatever is left so the three sizes sum to n_total.
    test_size = n_total - train_size - val_size

    # Only pass a generator when a seed was requested, so the unseeded call
    # keeps using the library defaults.
    rng_kwargs = {}
    if seed is not None:
        rng_kwargs["generator"] = torch.Generator().manual_seed(seed)

    train_dataset, val_dataset, test_dataset = random_split(
        dataset, [train_size, val_size, test_size], **rng_kwargs
    )

    trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, **rng_kwargs)
    valloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return trainloader, valloader, testloader

In [24]:
class RecurrentNetwork(nn.Module):
    """ReLU RNN whose per-step outputs are flattened and linearly classified.

    Args:
        seq_length: Number of time steps in every input sequence. The
            classifier consumes the RNN output of all steps, so its input
            width is ``seq_length * hidden_size``.
        hidden_size: Size of the RNN hidden state.
        num_layers: Number of stacked RNN layers.
        input_size: Number of features per time step (default 8).
        num_classes: Number of output scores per sample (default 3).
    """

    def __init__(self, seq_length, hidden_size, num_layers, input_size=8, num_classes=3):
        super().__init__()
        self.rnn = nn.RNN(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            nonlinearity='relu',
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            # Derived from the arguments instead of a fixed 30, so the layer
            # matches the flattened RNN output for any configuration.
            nn.Linear(seq_length * hidden_size, num_classes),
        )

    def forward(self, x):
        """Return class scores of shape (batch, num_classes).

        ``x`` is expected as (batch, seq_length, input_size) because the RNN
        is built with ``batch_first=True``.
        """
        # Keep the per-step outputs; the final hidden state is not used.
        x, _ = self.rnn(x)
        return self.classifier(x)

In [25]:
# Read the sample CSV and replace the textual labels with integer class ids.
df = pd.read_csv("cleaned_parquet/007/sample.csv")
df["label"] = df["label"].apply(encoding)

# Discard the columns that are not passed on to the model.
df = df.drop(
    columns=[
        "Unnamed: 0.1",
        "Unnamed: 0",
        "date_time",
        "date",
    ]
)

# View the table as 4500 windows of 10 rows x 9 columns.
df_as_np = df.to_numpy().reshape(4500, 10, 9)
# Column 8 is used as the label (presumably "label" is the last remaining
# column - verify against the CSV); each window takes the value of its first row.
labels = df_as_np[:, 0, 8]
# Columns 0-7 are kept as the per-step inputs.
df_as_np = df_as_np[:, :, :8]
print(df_as_np.shape)

(4500, 10, 8)


In [26]:
def train(dataloader, lr, epochs):
    """Fit a freshly created RecurrentNetwork on ``dataloader`` and return it.

    Args:
        dataloader: Iterable yielding (sequence batch, label batch) pairs.
        lr: Learning rate handed to the Adam optimizer.
        epochs: Number of full passes over ``dataloader``.

    Returns:
        The trained RecurrentNetwork instance.

    One line is printed per batch: epoch index, running batch counter and
    the batch loss.
    """
    net = RecurrentNetwork(seq_length=10, hidden_size=3, num_layers=5)
    adam = optim.Adam(net.parameters(), lr=lr, weight_decay=0.001)
    loss_fn = nn.CrossEntropyLoss()
    net.train()

    # Running counter of processed batches; it is not reset between epochs.
    step = 0
    for epoch_idx in range(epochs):
        for sequences, targets in dataloader:
            adam.zero_grad()
            logits = net(sequences.float())
            # CrossEntropyLoss needs integer class indices as targets.
            batch_loss = loss_fn(logits, targets.long())
            batch_loss.backward()
            adam.step()
            print(epoch_idx, step, batch_loss.item())
            step += 1
    return net

In [27]:
# Build the train / validation / test loaders from the windowed arrays.
train_loader, val_loader, test_loader = load_patient_data(
    df_as_np,
    labels,
    seq_len=10,
    batch_size=500,
)

In [28]:
# Train on the training split only.
# NOTE(review): the losses recorded below stay close to ln(3) ~= 1.0986,
# which is the cross-entropy of a uniform guess over three classes.
model = train(
    dataloader=train_loader,
    lr=0.01,
    epochs=5,
)

0 0 1.1192748546600342
0 1 1.1224114894866943
0 2 1.1015324592590332
0 3 1.1014318466186523
0 4 1.0983772277832031
0 5 1.1023683547973633
0 6 1.0999263525009155
0 7 1.1077044010162354
1 8 1.102987289428711
1 9 1.097174048423767
1 10 1.1025125980377197
1 11 1.1005340814590454
1 12 1.101534128189087
1 13 1.1007795333862305
1 14 1.099208950996399
1 15 1.0997675657272339
2 16 1.0985827445983887
2 17 1.0986138582229614
2 18 1.1027570962905884
2 19 1.0985654592514038
2 20 1.1045836210250854
2 21 1.1008334159851074
2 22 1.0979225635528564
2 23 1.0956114530563354
3 24 1.098318099975586
3 25 1.098105549812317
3 26 1.099797248840332
3 27 1.100170612335205
3 28 1.0983442068099976
3 29 1.0990275144577026
3 30 1.1001068353652954
3 31 1.1037793159484863
4 32 1.0990773439407349
4 33 1.098301887512207
4 34 1.0989100933074951
4 35 1.0990291833877563
4 36 1.1005297899246216
4 37 1.0985257625579834
4 38 1.099249243736267
4 39 1.0983099937438965


In [29]:
def test(model, dataloader):
    for seq, labels in dataloader:
        output = model(seq.float())
        pred_labels = torch.argmax(output, dim=1)
        acc = (pred_labels == labels).float().mean().item()
        print(acc)

In [30]:
# Report per-batch accuracy on the held-out test split.
test(model=model, dataloader=test_loader)

0.3400000035762787
