In [170]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import json
from sklearn.metrics import accuracy_score
import math

In [171]:
class KeyboardDataset(Dataset):
    def __init__(self, data, labels):
        self.data = self.pad_data(data)
        self.labels = labels
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

    def pad_data(self, data):
        new_data_list = []
        max_len = max(len(sublist) for sublist in data)
        padded_data = []
        for sublist in data:
            curr_len = len(sublist)
            new_data = sublist
            while (curr_len < max_len):
                new_data.append([-1.0, -1, -1])
                curr_len += 1
            new_data_list.append(new_data)
                
        return new_data_list


In [172]:
class PositionalEncoding(nn.Module):
    """
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    """

    def __init__(self, d_model, vocab_size=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(vocab_size, d_model)
        position = torch.arange(0, vocab_size, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        # pe[:, 1::2] = torch.cos(position * div_term)
        pe[:, 0::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1), :].permute(2, 1, 0)
        return self.dropout(x)

In [173]:
class TransformerModel(nn.Module):
    def __init__(self, vocab_size, embedding_size, num_classes, num_layers=1, num_heads=2, hidden_size=64, dropout=0.1):
        super(TransformerModel, self).__init__()
        self.pos_encoder = PositionalEncoding(
            d_model=3,
            dropout=dropout,
            vocab_size=vocab_size,
        )
        
        encoder_layer = nn.TransformerEncoderLayer(embedding_size, num_heads, hidden_size, dropout)
        
        # Pass the encoder layer instance to nn.TransformerEncoder
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers)
        
        self.fc = nn.Linear(embedding_size, num_classes)
        
    def forward(self, x):
        # embedded = self.embedding(x)
        # embedded = embedded.permute(1, 0, 2)  # Change dimensions for transformer
        # output = self.transformer(embedded)
        x = x.permute(2, 1, 0)
        x = self.pos_encoder(x)
        x = x.permute(1, 0, 2).squeeze()
        output = self.transformer(x)
        output = output.mean(dim=0)  # Average across time steps
        output = self.fc(output)
        return output

In [174]:
def train(model, iterator, optimizer, criterion):
    model.train()
    epoch_loss = 0
    for src, trg in iterator:
        optimizer.zero_grad()
        output = model(src, trg)
        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        trg = trg[1:].view(-1)
        loss = criterion(output, trg)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(iterator)

def evaluate(model, iterator, criterion):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for src, trg in iterator:
            output = model(src, trg)
            output_dim = output.shape[-1]
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss / len(iterator)



In [175]:
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(device)

mps


In [176]:
# labels key:
# 0: Aidan
# 1: Srujan
# 2: Eric
# 3: Tony

# vocab_size = 113
vocab_size = 120  # made it 120 instead of 113 to clear an error
embedding_size = 3
num_classes = 4
num_layers = 2
num_heads = 1
hidden_size = 128
dropout = 0.1
learning_rate = 0.001
# batch_size = 32
batch_size = 1
epochs = 10
file_prefix = '../'
datapoints_per_person = 3000

fh = open(f'{file_prefix}aidan_final_data_overlapping.json', 'r')
aidan_data = json.load(fh)[:datapoints_per_person]

fh = open(f'{file_prefix}srujan_final_data_overlapping.json', 'r')
srujan_data = json.load(fh)[:datapoints_per_person]

fh = open(f'{file_prefix}eric_final_data_overlapping.json', 'r')
eric_data = json.load(fh)[:datapoints_per_person]

fh = open(f'{file_prefix}tony_final_data_overlapping.json', 'r')
tony_data = json.load(fh)[:datapoints_per_person]

data = aidan_data + srujan_data + eric_data + tony_data
labels = ([0] * datapoints_per_person) + ([1] * datapoints_per_person) + ([2] * datapoints_per_person) + ([3] * datapoints_per_person)

print(f'There are {len(data)} 5-second intervals, {datapoints_per_person} intervals from each person')

train_data, val_data, train_labels, val_labels = train_test_split(data, labels, test_size=0.2, random_state=42)

train_dataset = KeyboardDataset(train_data, train_labels)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = KeyboardDataset(val_data, val_labels)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

model = TransformerModel(vocab_size, embedding_size, num_classes, num_layers, num_heads, hidden_size, dropout)
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)


There are 12000 5-second intervals, 3000 intervals from each person


In [178]:
# Training loop
epochs = 100
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

for epoch in range(epochs):
    model.train()
    temp_train_losses = []
    correct_train = 0  
    total_train = 0
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        labels = torch.tensor(labels)
        
        # re-format the data
        # batch length list, where each list is length 96
        reformatted_data = [[] for i in range(batch_size)]
        for group in inputs:
            j = 0
            times = group[0]  # len batch_size
            characters = group[1]  # len batch_size
            updown = group[2]  # len batch_size
            while j < batch_size:
                reformatted_data[j].append([float(times[j]), float(characters[j]), float(updown[j])])
                j = j + 1
        
        reformatted_data = torch.tensor(reformatted_data)
        outputs = model(reformatted_data.to(device))
        loss = criterion(outputs.to(device), labels[0].to(device))
        temp_train_losses.append(loss.item())
        loss.backward()
        optimizer.step()

        _, predicted_train = torch.max(outputs.data, dim = 0)
        total_train += labels.size(0)
        correct_train += (predicted_train == labels[0].to(device)).sum().item()
    
    train_losses.append(np.mean(temp_train_losses))  # only append at the end of the batch
    train_accuracy = correct_train / total_train
    train_accuracies.append(train_accuracy)
    print(f'Epoch [{epoch+1}/{epochs}], Training Accuracy: {train_accuracy:.4f}')

    # Validation
    model.eval()
    val_predictions = []
    val_targets = []
    temp_val_losses = []
    with torch.no_grad():
        for inputs, labels in val_loader:
            labels = torch.tensor(labels)
            reformatted_data = [[] for i in range(batch_size)]
            for group in inputs:
                j = 0
                times = group[0]  # len batch_size
                characters = group[1]  # len batch_size
                updown = group[2]  # len batch_size
                while j < batch_size:
                    reformatted_data[j].append([float(times[j]), float(characters[j]), float(updown[j])])
                    j = j + 1
            
            reformatted_data = torch.tensor(reformatted_data)

            outputs = model(reformatted_data.to(device))
            loss = criterion(outputs.to(device), labels[0].to(device))
            temp_val_losses.append(loss)
            predicted = torch.argmax(outputs)
            val_predictions.append(predicted.cpu().numpy())
            val_targets.append(labels.cpu().numpy())
    
    val_accuracy = accuracy_score(val_targets, val_predictions)
    val_accuracies.append(val_accuracy)
    val_losses.append(np.mean(temp_val_losses))
    print(f'Epoch [{epoch+1}/{epochs}], Validation Accuracy: {val_accuracy:.4f}')


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [1/100], Validation Accuracy: 0.3221


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [2/100], Validation Accuracy: 0.3225


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [3/100], Validation Accuracy: 0.3196


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [4/100], Validation Accuracy: 0.3225


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [5/100], Validation Accuracy: 0.3312


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [6/100], Validation Accuracy: 0.3154


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [7/100], Validation Accuracy: 0.3325


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [8/100], Validation Accuracy: 0.3221


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [9/100], Validation Accuracy: 0.3192


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [10/100], Validation Accuracy: 0.3287


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [11/100], Validation Accuracy: 0.3221


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [12/100], Validation Accuracy: 0.3350


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [13/100], Validation Accuracy: 0.3300


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [14/100], Validation Accuracy: 0.3233


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [15/100], Validation Accuracy: 0.3196


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [16/100], Validation Accuracy: 0.3237


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [17/100], Validation Accuracy: 0.3221


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [18/100], Validation Accuracy: 0.3346


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [19/100], Validation Accuracy: 0.3404


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [20/100], Validation Accuracy: 0.3017


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [21/100], Validation Accuracy: 0.3233


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [22/100], Validation Accuracy: 0.3250


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [23/100], Validation Accuracy: 0.3325


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [24/100], Validation Accuracy: 0.3350


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [25/100], Validation Accuracy: 0.3221


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [26/100], Validation Accuracy: 0.3167


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [27/100], Validation Accuracy: 0.3196


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [28/100], Validation Accuracy: 0.3337


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [29/100], Validation Accuracy: 0.3279


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [30/100], Validation Accuracy: 0.3187


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [31/100], Validation Accuracy: 0.3233


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [32/100], Validation Accuracy: 0.3183


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [33/100], Validation Accuracy: 0.3183


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [34/100], Validation Accuracy: 0.3208


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [35/100], Validation Accuracy: 0.3246


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [36/100], Validation Accuracy: 0.3167


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [37/100], Validation Accuracy: 0.3283


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [38/100], Validation Accuracy: 0.3175


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [39/100], Validation Accuracy: 0.3258


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [40/100], Validation Accuracy: 0.3229


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [41/100], Validation Accuracy: 0.3246


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [42/100], Validation Accuracy: 0.3229


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [43/100], Validation Accuracy: 0.3217


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [44/100], Validation Accuracy: 0.3279


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [45/100], Validation Accuracy: 0.3292


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [46/100], Validation Accuracy: 0.3208


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [47/100], Validation Accuracy: 0.3233


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [48/100], Validation Accuracy: 0.3362


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [49/100], Validation Accuracy: 0.3229


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [50/100], Validation Accuracy: 0.3246


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [51/100], Validation Accuracy: 0.3262


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [52/100], Validation Accuracy: 0.3175


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [53/100], Validation Accuracy: 0.3212


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [54/100], Validation Accuracy: 0.3362


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [55/100], Validation Accuracy: 0.3283


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [56/100], Validation Accuracy: 0.3246


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [57/100], Validation Accuracy: 0.3354


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [58/100], Validation Accuracy: 0.3179


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [59/100], Validation Accuracy: 0.3317


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [60/100], Validation Accuracy: 0.3258


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [61/100], Validation Accuracy: 0.3200


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [62/100], Validation Accuracy: 0.3300


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [63/100], Validation Accuracy: 0.3225


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [64/100], Validation Accuracy: 0.3217


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [65/100], Validation Accuracy: 0.3312


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [66/100], Validation Accuracy: 0.3321


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [67/100], Validation Accuracy: 0.3242


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [68/100], Validation Accuracy: 0.3204


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [69/100], Validation Accuracy: 0.3279


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [70/100], Validation Accuracy: 0.3317


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [71/100], Validation Accuracy: 0.3308


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [72/100], Validation Accuracy: 0.3275


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [73/100], Validation Accuracy: 0.3271


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [74/100], Validation Accuracy: 0.3329


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [75/100], Validation Accuracy: 0.3287


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [76/100], Validation Accuracy: 0.3279


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [77/100], Validation Accuracy: 0.3262


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [78/100], Validation Accuracy: 0.3287


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [79/100], Validation Accuracy: 0.3229


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [80/100], Validation Accuracy: 0.3237


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [81/100], Validation Accuracy: 0.3233


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [82/100], Validation Accuracy: 0.3271


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [83/100], Validation Accuracy: 0.3237


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [84/100], Validation Accuracy: 0.3225


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [85/100], Validation Accuracy: 0.3337


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [86/100], Validation Accuracy: 0.3242


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [87/100], Validation Accuracy: 0.3262


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [88/100], Validation Accuracy: 0.3258


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [89/100], Validation Accuracy: 0.3271


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [90/100], Validation Accuracy: 0.3233


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [91/100], Validation Accuracy: 0.3187


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [92/100], Validation Accuracy: 0.3279


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [93/100], Validation Accuracy: 0.3233


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [94/100], Validation Accuracy: 0.3300


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [95/100], Validation Accuracy: 0.3242


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [96/100], Validation Accuracy: 0.3300


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [97/100], Validation Accuracy: 0.3229


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [98/100], Validation Accuracy: 0.3271


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [99/100], Validation Accuracy: 0.3287


  labels = torch.tensor(labels)
  labels = torch.tensor(labels)


Epoch [100/100], Validation Accuracy: 0.3233
