In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import numpy as np
import pandas as pd
from pathlib import Path

### Dataset

In [2]:
# Define your dataset
class CustomDataset(Dataset):
    def __init__(self, ):
        data_path = Path("data/processed")
        csv_files = list(data_path.glob("*.csv"))

        self.Ys = []
        self.Xs = []

        for f in csv_files:
            df = pd.read_csv(f)

            # interpolate missing values
            df = df.interpolate(method="linear", limit_direction="both")

            # trim to 1800 samples
            df = df.iloc[:1800]

            # extract features
            self.Ys.append(df[["cursor_x", "cursor_y"]].values.astype(np.float32))
            self.Xs.append(df.drop(columns=["cursor_x", "cursor_y"]).values.astype(np.float32))


        self.data = list(zip(self.Xs, self.Ys))

    def __getitem__(self, index):
        return self.data[index]
    
    def __len__(self):
        return len(self.data)


## Model

In [3]:
class pytorchLSTM(nn.Module):
    def __init__(self, n_in, hidden_size, n_out):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_in = n_in
        self.lstm = nn.LSTM(n_in, hidden_size, num_layers=1, batch_first = True)
        self.output_layer = nn.Linear(hidden_size, n_out)

        
    def forward(self, X, hidden = None):
        if hidden == None:
            b =  X.shape[0]
            hidden = (
                torch.zeros(1, b, self.hidden_size).to(X.device),
                torch.zeros(1, b, self.hidden_size).to(X.device)
            )
            out, hidden = self.lstm(X, hidden)
            out = self.output_layer(out)
        else:
            out, hidden = self.lstm(X, hidden)
            out = self.output_layer(out)
        return out, hidden

## Training loop

In [4]:
batch_size = 4
hidden_size = 128
lr = 1e-3
num_epochs = 1000

In [5]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Create the dataset and data loader
dataset = CustomDataset()

n_in = dataset.data[0][0].shape[1]
n_out = dataset.data[0][1].shape[1]

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [6]:
model = pytorchLSTM(n_in, hidden_size, n_out).to(device)

# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=lr)


In [7]:
hidden = None
for epoch in range(num_epochs):
    epoch_loss = 0
    for i, (X, Y) in enumerate(dataloader):
        optimizer.zero_grad()

        # Detach the hidden state from the previous sequence
        if hidden is not None:
            hidden = (hidden[0].detach(), hidden[1].detach())

        # Forward pass
        outputs, hidden = model(X, hidden=hidden)
        loss = criterion(outputs, Y)
        loss.backward()

        # Adjust learning weights
        optimizer.step()
        epoch_loss += loss.item()
        
    if epoch % 100 == 0 or epoch == 0:
        print(f'Epoch [{epoch}/{num_epochs}],  Loss: {epoch_loss:.4f}')

# Save the trained model
torch.save(model.state_dict(), 'models/lstm_model.pth')

Epoch [1/1000],  Loss: 1690763.1250
Epoch [11/1000],  Loss: 1689966.2500
Epoch [21/1000],  Loss: 1682905.0000
Epoch [31/1000],  Loss: 1677774.2500
Epoch [41/1000],  Loss: 1673452.7500
Epoch [51/1000],  Loss: 1669764.7500
Epoch [61/1000],  Loss: 1666315.2500
Epoch [71/1000],  Loss: 1663033.0000
Epoch [81/1000],  Loss: 1659873.0000
Epoch [91/1000],  Loss: 1656802.3750
Epoch [101/1000],  Loss: 1653800.8750
Epoch [111/1000],  Loss: 1650839.8750
Epoch [121/1000],  Loss: 1647869.8750
Epoch [131/1000],  Loss: 1644947.6250
Epoch [141/1000],  Loss: 1641987.2500
Epoch [151/1000],  Loss: 1639082.3750
Epoch [161/1000],  Loss: 1636212.5000
Epoch [171/1000],  Loss: 1633371.8750
Epoch [181/1000],  Loss: 1630557.2500
Epoch [191/1000],  Loss: 1627764.8750
Epoch [201/1000],  Loss: 1624992.8750
Epoch [211/1000],  Loss: 1622238.3750
Epoch [221/1000],  Loss: 1619500.2500
Epoch [231/1000],  Loss: 1616776.8750
Epoch [241/1000],  Loss: 1614067.5000
Epoch [251/1000],  Loss: 1611370.6250
Epoch [261/1000],  Loss