In [5]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score
import random

In [40]:
#Load dataset
df = pd.read_csv('tracking_distraction_expanded_200_samples_dataset.csv')

# Auto-infer features and label
feature_columns = [col for col in df.columns if col not in ['Sequence_ID', 'Timestep', 'Distracted_Label']]
label_column = 'Distracted_Label'
sequence_length = df.groupby('Sequence_ID').size().mode()[0]

# Dataset class
class BallTrackingDataset(Dataset):
    def __init__(self, dataframe, sequence_length):
        self.data = []
        self.labels = []
        grouped = dataframe.groupby('Sequence_ID')
        for _, group in grouped:
            group = group.sort_values('Timestep')
            features = group[feature_columns].values
            label = group[label_column].iloc[0]
            if len(features) == sequence_length:
                self.data.append(features)
                self.labels.append([label])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return torch.tensor(self.data[idx], dtype=torch.float32), torch.tensor(self.labels[idx], dtype=torch.float32)

# GRU model with dropout
class GRUDistractionModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size, dropout_rate=0.3):
        super(GRUDistractionModel, self).__init__()
        self.gru = nn.GRU(input_size, hidden_size, batch_first=True)
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        h0 = torch.zeros(1, x.size(0), hidden_size)
        out, _ = self.gru(x, h0)
        out = self.dropout(out[:, -1, :])
        out = self.fc(out)
        return self.sigmoid(out)

# Fixed hyperparameters
input_size = len(feature_columns)
hidden_size = 64
output_size = 1
batch_size = 16
epochs = 20
learning_rate = 0.005

# Dataset and split
dataset = BallTrackingDataset(df, sequence_length)
indices = list(range(len(dataset)))
labels = [dataset[i][1].item() for i in indices]
train_idx, val_idx = train_test_split(indices, test_size=0.2, stratify=labels)

# Check label distribution in validation set
val_labels = [dataset[i][1].item() for i in val_idx]
val_distribution = {label: val_labels.count(label) for label in set(val_labels)}
print("Validation Label Distribution:", val_distribution)

train_loader = DataLoader(torch.utils.data.Subset(dataset, train_idx), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(torch.utils.data.Subset(dataset, val_idx), batch_size=batch_size)

# Model setup
model = GRUDistractionModel(input_size, hidden_size, output_size)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(epochs):
    model.train()
    for sequences, labels in train_loader:
        outputs = model(sequences).squeeze()
        loss = criterion(outputs, labels.squeeze())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{epochs} completed")

# Accuracy evaluation
model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for sequences, labels in val_loader:
        outputs = model(sequences).squeeze()
        preds = (outputs > 0.5).float()
        all_preds.extend(preds.tolist())
        all_labels.extend(labels.squeeze().tolist())

accuracy = accuracy_score(all_labels, all_preds)
print(f"\nFinal Validation Accuracy: {accuracy:.4f}")


Validation Label Distribution: {0.0: 2}
Epoch 1/20 completed
Epoch 2/20 completed
Epoch 3/20 completed
Epoch 4/20 completed
Epoch 5/20 completed
Epoch 6/20 completed
Epoch 7/20 completed
Epoch 8/20 completed
Epoch 9/20 completed
Epoch 10/20 completed
Epoch 11/20 completed
Epoch 12/20 completed
Epoch 13/20 completed
Epoch 14/20 completed
Epoch 15/20 completed
Epoch 16/20 completed
Epoch 17/20 completed
Epoch 18/20 completed
Epoch 19/20 completed
Epoch 20/20 completed

Final Validation Accuracy: 1.0000
