In [72]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset, WeightedRandomSampler
from pytorch_tcn import TCN
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score

In [74]:
# Initialize device: use CUDA when PyTorch reports a usable GPU, otherwise fall back to the CPU.
# Later cells read this module-level `device` when loading tensors and placing the model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [76]:
class Video:
    """One pose-keypoint clip, looked up by its 1-based ID in ``all_videos.pt``.

    The file is read from disk only once per class definition and kept in a
    class-level cache; constructing a ``Video`` for every dataset item would
    otherwise re-read and re-deserialize the whole file for each sample.
    Re-run this cell (or reset ``Video._all_videos`` to ``None``) if the file
    changes while the kernel is alive.

    Attributes:
        video_id: The ID passed to the constructor.
        data: Tensor of shape (features, frames). For a clip stored as
            [200, 17, 2] (the layout noted by the original author) this is
            (34, 200).
    """

    # Lazily populated with the object stored in all_videos.pt (loaded onto `device`).
    _all_videos = None

    @classmethod
    def _load_all_videos(cls):
        """Return the cached contents of ``all_videos.pt``, loading them on first use."""
        if Video._all_videos is None:
            Video._all_videos = torch.load("all_videos.pt", map_location=device, weights_only=True)
        return Video._all_videos

    def __init__(self, video_id):
        self.video_id = video_id
        clip = self._load_all_videos()[video_id - 1]  # IDs are 1-based, storage is 0-based
        # Flatten everything after the frame axis, then put features first:
        # [frames, 17, 2] -> (frames, 34) -> (34, frames).
        # The frame count is read from the clip instead of being hard-coded to 200.
        # clone() detaches the result from the shared cache so an in-place
        # transform applied by a caller cannot corrupt other samples.
        self.data = clip.reshape(clip.shape[0], -1).transpose(0, 1).clone()

In [78]:
# Custom Dataset
class DeadliftVideoDataset(Dataset):
    """Map-style dataset yielding ``(data, labels, video_id)`` for each video ID.

    Args:
        video_ids: Sequence of video IDs; its length is the dataset length.
        labels_csv_path: CSV handed to ``get_labels_from_csv`` to build the
            ID -> label-tensor lookup.
        transform: Optional callable applied to the video tensor before it is
            returned.
    """

    def __init__(self, video_ids, labels_csv_path, transform=None):
        self.video_ids = video_ids
        self.labels_dict = get_labels_from_csv(video_ids, labels_csv_path)
        self.transform = transform

    def __len__(self):
        return len(self.video_ids)

    def __getitem__(self, idx):
        """Load the clip for position ``idx`` and pair it with its labels.

        Raises:
            ValueError: if the loaded clip is ``None``.
            KeyError: if no label was found for the video ID.
        """
        vid = self.video_ids[idx]
        # Video already flattens and transposes the clip, so this is the
        # (features, frames) tensor -- (34, 200) for a [200, 17, 2] clip.
        data = Video(vid).data
        if data is None:
            raise ValueError(f"Data for video {vid} could not be loaded.")
        if self.transform:
            data = self.transform(data)
        return data, self.labels_dict[vid], vid

In [80]:
# Model Definition
class DeadliftTCN(nn.Module):
    """Temporal convolutional backbone followed by a linear multi-label head.

    Args:
        input_size: Number of input channels per time step.
        num_channels: Output channels of each TCN layer; the last entry sizes
            the linear head.
        num_classes: Number of logits produced per sample.
        kernel_size: Convolution kernel size forwarded to ``TCN``.
        dropout: Dropout rate forwarded to ``TCN``.
    """

    def __init__(self, input_size, num_channels, num_classes, kernel_size=2, dropout=0.2):
        super().__init__()
        self.tcn = TCN(input_size, num_channels, kernel_size=kernel_size, dropout=dropout)
        self.fc = nn.Linear(num_channels[-1], num_classes)

    def forward(self, x):
        """Return raw logits for ``x`` of shape (batch, input_size, sequence_length).

        No sigmoid is applied here; callers either feed the logits to
        ``BCEWithLogitsLoss`` or apply ``torch.sigmoid`` themselves.
        """
        features = self.tcn(x)         # (batch, num_channels[-1], sequence_length)
        pooled = features.mean(dim=2)  # global average pooling over the time axis
        return self.fc(pooled)

In [82]:
def get_labels_from_csv(video_ids, csv_path="DLdataset/simple_labels.csv"):
    """Build a ``{video_id: label_tensor}`` lookup from the label CSV.

    Every ID found in the CSV gets a float32 tensor holding that row's values
    (all columns except ``ID``). Each found label is additionally registered
    under ``video_id + n``, where ``n`` is the number of labels found, so that
    flipped copies of the videos share the label of their source clip.

    Args:
        video_ids: Iterable of IDs to look up in the CSV's ``ID`` column.
        csv_path: Path of the label CSV.

    Returns:
        dict mapping original IDs and shifted duplicate IDs to label tensors.
        IDs whose lookup failed are reported on stdout and omitted.
    """
    labels_csv = pd.read_csv(csv_path).set_index("ID")  # index rows by video ID
    labels_dict = {}
    for vid in video_ids:
        try:
            # Row values -> float array -> float32 tensor
            label_values = labels_csv.loc[vid].values.astype(float)
            labels_dict[vid] = torch.tensor(label_values, dtype=torch.float32)
        except Exception as e:
            # Best effort: report the problem and keep going with the remaining IDs.
            print(f"Label for video {vid} not found or error occurred: {e}")

    n = len(labels_dict)  # offset used for the flipped-copy IDs
    combined = dict(labels_dict)
    for vid, label in labels_dict.items():
        # A real label must never be replaced by a synthetic duplicate: with
        # non-contiguous IDs (e.g. 1 and 3, n == 2) the shifted key of one video
        # can coincide with the genuine ID of another.
        combined.setdefault(vid + n, label)
    return combined

In [84]:
# Utility Functions
def compute_sample_weights(dataset):
    """Return one sampling weight per entry of ``dataset.video_ids``.

    A video counts as positive when its label tensor sums to more than zero;
    videos without a label entry count as negative. Each sample receives
    ``N / (2 * class_count)`` for its own class, so the positive and negative
    groups carry the same total weight.
    """
    def _flag(vid):
        # 1 for a labelled video with at least one active label, else 0
        labels = dataset.labels_dict.get(vid, None)
        return 1 if labels is not None and labels.sum() > 0 else 0

    flags = pd.Series([_flag(vid) for vid in dataset.video_ids])
    total = len(flags)
    num_pos = flags.sum()
    num_neg = total - num_pos
    # The divisor is only evaluated for a class that actually occurs, so it is never zero.
    return [total / (2.0 * (num_pos if flag else num_neg)) for flag in flags]

In [86]:
def compute_pos_weight(labels_csv_path, label_columns=('Red', 'Blue', 'Yellow')):
    """Compute the per-label ``pos_weight`` vector for ``BCEWithLogitsLoss``.

    For every label column the weight is ``negatives / positives`` counted over
    all rows of the CSV; a column without any positive falls back to 1.0 so the
    division is always defined.

    Args:
        labels_csv_path: Path of the label CSV.
        label_columns: Names of the label columns, in the order the model
            emits its logits. Defaults to the three columns used by this
            notebook.

    Returns:
        float32 tensor with one weight per entry of ``label_columns``.
    """
    df = pd.read_csv(labels_csv_path)
    total = len(df)
    ratios = []
    for col in label_columns:
        pos_count = df[col].sum()
        neg_count = total - pos_count
        ratios.append(neg_count / pos_count if pos_count > 0 else 1.0)
    return torch.tensor(ratios, dtype=torch.float32)

In [88]:
def create_dataloaders(video_ids, labels_csv_path, batch_size=16, test_size=0.08):
    """Split the videos into train/validation loaders.

    The split is a fixed-seed ``train_test_split`` over dataset positions. The
    training loader draws samples with replacement through a
    ``WeightedRandomSampler``; the validation loader iterates in order.

    Note: the class-balancing weights are computed from the class counts of the
    whole dataset and then restricted to the training positions.

    Returns:
        ``(train_loader, val_loader, dataset)``.
    """
    dataset = DeadliftVideoDataset(video_ids, labels_csv_path)
    train_indices, val_indices = train_test_split(
        list(range(len(dataset))), test_size=test_size, random_state=18
    )

    # Weights for every sample, then only those belonging to the training split
    all_weights = compute_sample_weights(dataset)
    train_weights = [all_weights[i] for i in train_indices]
    sampler = WeightedRandomSampler(train_weights, num_samples=len(train_weights), replacement=True)

    train_loader = DataLoader(Subset(dataset, train_indices), batch_size=batch_size, sampler=sampler)
    val_loader = DataLoader(Subset(dataset, val_indices), batch_size=batch_size, shuffle=False)
    return train_loader, val_loader, dataset

In [90]:
# Training Function
def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs, printing the mean training loss per epoch.

    Args:
        model: Module mapping a batch of inputs to raw logits.
        train_loader: Iterable yielding ``(inputs, labels, video_ids)`` batches.
        val_loader: Accepted for interface compatibility; it is not used here.
            Call ``evaluate_model`` separately to obtain validation metrics.
        criterion: Loss taking ``(outputs, labels)``.
        optimizer: Optimizer over the model's parameters.
        device: Device the batches are moved to.
        num_epochs: Number of passes over ``train_loader``.
    """
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        samples_seen = 0
        for inputs, labels, _vids in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)  # Raw logits
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            batch_count = inputs.size(0)
            running_loss += loss.item() * batch_count
            samples_seen += batch_count
        # Average over the samples actually processed this epoch. The dataset
        # length is not a safe divisor: it differs from the processed count when
        # the loader drops the last batch or its sampler draws another number
        # of samples. An epoch without any batch reports NaN.
        epoch_loss = running_loss / samples_seen if samples_seen else float("nan")
        print(f"Epoch {epoch+1}/{num_epochs} | Training Loss: {epoch_loss:.4f}")

In [92]:
def judge_score(labels, predictions):
    """Fraction of samples whose good-lift / bad-lift verdict matches the labels.

    A row counts as a good lift when its values sum to zero. The score is one
    minus the share of rows where the predicted verdict differs from the true
    verdict, regardless of which individual labels were set.
    """
    predicted_good = predictions.sum(dim=1).eq(0)
    actual_good = labels.sum(dim=1).eq(0)
    error_rate = torch.ne(predicted_good, actual_good).float().mean().item()
    return 1.0 - error_rate

In [94]:
# Evaluation Function
def evaluate_model(model, data_loader, device, threshold=0.9):
    """Evaluate ``model`` on ``data_loader`` and print the resulting metrics.

    Logits are passed through a sigmoid and binarized with ``threshold``. The
    function prints accuracy, weighted precision/recall/F1 and the judge
    score; it does not return anything.

    Args:
        model: Module mapping a batch of inputs to raw logits.
        data_loader: Iterable yielding ``(inputs, labels, video_ids)`` batches.
        device: Device the batches are moved to.
        threshold: Probability above which a label is predicted as present.
            Defaults to 0.9, the same default ``predict_video`` uses.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for inputs, labels, _vids in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            # Sigmoid turns the raw logits into independent per-label probabilities
            probs = torch.sigmoid(model(inputs))
            preds = (probs > threshold).float()
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())
    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)
    # With multilabel indicator input, scikit-learn's accuracy_score is subset
    # accuracy: a sample counts only when all of its labels are predicted correctly.
    acc = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
    judge = judge_score(all_labels, all_preds)
    print(f"Evaluation Metrics -> Acc: {acc:.4f} | Precision: {precision:.4f} | Recall: {recall:.4f} | F1: {f1:.4f}")
    print(f"Evaluation Metrics -> Judge Score: {judge:.4f}")

In [96]:
def predict_video(video_id, model, device, video_file="all_videos.pt", threshold=0.9):
    """Predict the binary infraction labels for one stored video.

    Args:
        video_id: 1-based ID of the video inside ``video_file``.
        model: Module mapping a (1, features, frames) batch to raw logits.
        device: Device used for loading the data and running the model.
        video_file: Path of the stacked video tensor; per the original author
            its shape is (N, 200, 17, 2).
        threshold: Probability above which a label is predicted as present.

    Returns:
        Nested list ``[[...]]`` with one 0.0/1.0 entry per model output.

    Raises:
        IndexError: if ``video_id`` is outside ``1..N``.
    """
    # weights_only=True restricts unpickling to tensor data, matching the other
    # torch.load calls in this notebook.
    all_videos = torch.load(video_file, map_location=device, weights_only=True)
    if not 1 <= video_id <= len(all_videos):
        # Without this check video_id == 0 would silently select the last video (index -1).
        raise IndexError(f"video_id {video_id} is out of range 1..{len(all_videos)}")
    clip = all_videos[video_id - 1]  # (frames, 17, 2)
    # Flatten the joint coordinates and put features first: (frames, 34) -> (34, frames).
    # The frame count comes from the clip itself instead of a hard-coded 200.
    features = clip.reshape(clip.shape[0], -1).transpose(0, 1)
    batch = features.unsqueeze(0).to(device)  # add the batch dimension: (1, 34, frames)
    model.eval()
    with torch.no_grad():
        probs = torch.sigmoid(model(batch))       # logits -> probabilities
        prediction = (probs > threshold).float()  # binarize with the threshold
    return prediction.cpu().tolist()

In [98]:
# Initialize data loaders, model and hyperparameters.
# NOTE(review): no random seed is set in the visible cells, so weight initialisation and
# the weighted sampling order presumably differ between runs -- confirm if reproducibility matters.
labels_csv_path = "DLdataset/simple_labels.csv"
video_ids = pd.read_csv(labels_csv_path)["ID"].tolist() # Every ID listed in the label CSV
batch_size = 64
train_loader, val_loader, dataset = create_dataloaders(video_ids, labels_csv_path, batch_size=batch_size) # Create DataLoaders
# pos_weight is counted over every row of the CSV, i.e. including the videos that
# create_dataloaders assigned to the validation split.
pos_weight_tensor = compute_pos_weight(labels_csv_path).to(device) # Compute pos_weight for loss balancing
print(f"Pos weight: {pos_weight_tensor}")

# Model hyperparameters
input_size = 34 # 17 joints * 2 coordinates per frame, flattened
num_classes = 3 # 3 infractions
num_channels = [50,75,100,125,125] # Filters per layer
kernel_size = 7 # Overrides the DeadliftTCN default of 2
dropout = 0.5 # Overrides the DeadliftTCN default of 0.2

# Initialize model, loss function, and optimizer
model = DeadliftTCN(input_size, num_channels, num_classes, kernel_size, dropout).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor) # Consumes the raw logits returned by the model
optimizer = optim.Adam(model.parameters(), lr=0.001)

Pos weight: tensor([11.5000, 13.4231,  2.1513], device='cuda:0')


In [None]:
# Train the model. train_model only prints the training loss for each epoch; it does not
# run validation, so call evaluate_model separately to obtain validation metrics.
num_epochs = 400
train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=num_epochs)

In [100]:
# Replace the model's parameters with the checkpoint stored in model_weights.pt.
# NOTE(review): the training cell above carries no execution count, so the evaluation below
# presumably reflects weights saved by an earlier run -- confirm.
model.load_state_dict(torch.load("model_weights.pt", map_location=device, weights_only=True))

<All keys matched successfully>

In [102]:
# Print the metrics of the loaded model on the validation loader.
evaluate_model(model, val_loader, device)

Evaluation Metrics -> Acc: 0.8000 | Precision: 0.7418 | Recall: 0.9231 | F1: 0.8215
Evaluation Metrics -> Judge Score: 0.8333
