In [4]:
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [3]:
!python -c "import torch; print(torch.__version__)"
!python -c "import torch; print(torch.version.cuda)"

!pip install torch-geometric-temporal

2.0.1+cu118
11.8
Collecting torch-geometric-temporal
  Downloading torch_geometric_temporal-0.54.0.tar.gz (48 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m48.1/48.1 kB[0m [31m956.0 kB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pandas<=1.3.5 (from torch-geometric-temporal)
  Downloading pandas-1.3.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m11.5/11.5 MB[0m [31m80.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch_sparse (from torch-geometric-temporal)
  Downloading torch_sparse-0.6.18.tar.gz (209 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m210.0/210.0 kB[0m [31m22.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting torch_scatter (from torch-geometric-temporal)
  Downloading torch_scatter-2.1.2.tar.gz (108 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━

### `ASLDatasetLoader` Class

The `ASLDatasetLoader` class is designed for loading and processing the ASL dataset. Given a directory, it reads sign language data from JSON files and constructs graph representations suitable for graph-based neural networks. Crucially, the class converts JSON data into PyTorch Geometric (PyG) `Data` objects comprising `x` (node features), `edge_index` (graph connectivity), and `y' (labels) attributes.

**Methods**:

- `_create_sign_to_label_map`: Generates a mapping from sign names to unique labels.

- `_read_file_data`: Reads data from a given JSON file.

- `_augment_data`: Implements data augmentation by applying random rotation, translation, and scaling to landmarks, which can enhance the model's robustness.

- `_create_graph_from_frame`: Constructs a PyG `Data` object from frame data, concentrating on hand and face landmarks. Edges are created between consecutive landmarks and between left and right hand landmarks. Additional features, like hand-to-face distances, are also computed.

- `get_dataset`: Assembles the dataset, optionally incorporating data augmentation. The function outputs a list of PyG `Data` objects ready for graph neural network processing.

In [5]:
import torch
import os
import json
import numpy as np
from torch_geometric.data import Data

HAND_TO_FACE_THRESHOLD = 0.05

class ASLDatasetLoader:
    def __init__(self, directory_path):
        self.directory_path = directory_path
        self.sign_to_label = self._create_sign_to_label_map()

    def _create_sign_to_label_map(self):
        signs = [os.path.splitext(filename)[0] for filename in os.listdir(self.directory_path)]
        return {sign: i for i, sign in enumerate(signs)}

    def _read_file_data(self, file_path):
        with open(file_path, 'r') as f:
            return json.load(f)

    def _augment_data(self, frame_data, rotation_range=10, translation_range=0.05, scaling_range=0.1):
        """
        Augment the frame data with random rotation, translation, and scaling.

        :param frame_data: Dictionary containing frame landmarks and deltas.
        :param rotation_range: Maximum rotation angle in degrees.
        :param translation_range: Maximum translation as a fraction of landmark range.
        :param scaling_range: Maximum scaling factor.
        :return: Augmented frame data.
        """
        landmarks = np.array(frame_data["landmarks"])
        centroid = np.mean(landmarks, axis=0)

        # Random rotation
        theta = np.radians(np.random.uniform(-rotation_range, rotation_range))
        rotation_matrix = np.array([
            [np.cos(theta), -np.sin(theta)],
            [np.sin(theta), np.cos(theta)]
        ])
        landmarks = np.dot(landmarks - centroid, rotation_matrix) + centroid

        # Random translation
        max_translation = translation_range * (landmarks.max(axis=0) - landmarks.min(axis=0))
        translations = np.random.uniform(-max_translation, max_translation)
        landmarks += translations

        # Random scaling
        scale = np.random.uniform(1 - scaling_range, 1 + scaling_range)
        landmarks = centroid + scale * (landmarks - centroid)

        frame_data["landmarks"] = landmarks.tolist()
        return frame_data

    def _create_graph_from_frame(self, sign_name, frame_data, landmark_types):
        left_hand_indices = [i for i, t in enumerate(landmark_types) if t == "L"]
        right_hand_indices = [i for i, t in enumerate(landmark_types) if t == "R"]
        face_indices = [i for i, t in enumerate(landmark_types) if t == "F"]

        landmarks = np.array(frame_data["landmarks"])
        deltas = np.array(frame_data["deltas"])

        # Create weights based on landmark importance
        weights = [2 if t == "L" or t == "R" else 1 for t in landmark_types]

        # Create edges based on the number of available landmarks (or nodes)
        edges = [[i, i + 1] for i in range(len(landmarks) - 1)]

        # Add edges between the left and right hand landmarks
        for i in left_hand_indices:
            for j in right_hand_indices:
                edges.append([i, j])

        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

        # Compute additional features like hand-to-face and hand-to-body distances
        hand_to_face_contact = []
        for idx, ltype in enumerate(landmark_types):
            if ltype in ["L", "R"] and any(t == "F" for t in landmark_types):
                min_distance = min([np.linalg.norm(landmarks[idx] - landmarks[j]) for j, t in enumerate(landmark_types) if t == "F"])
                hand_to_face_contact.append(1 if min_distance < HAND_TO_FACE_THRESHOLD else 0)
            else:
                hand_to_face_contact.append(0)

        # Reshape the 1D arrays to 2D for concatenation
        weights_2d = np.array(weights)[:, np.newaxis]
        hand_to_face_contact_2d = np.array(hand_to_face_contact)[:, np.newaxis]

        # Concatenate landmarks, deltas, importance weights, and hand-to-face contact features
        x = torch.tensor(np.hstack((landmarks, deltas, weights_2d, hand_to_face_contact_2d)), dtype=torch.float)
        y = torch.tensor([self.sign_to_label[sign_name]], dtype=torch.long)

        return Data(x=x, edge_index=edge_index, y=y)


    def get_dataset(self, augment=False):
        dataset = []

        for filename in os.listdir(self.directory_path):
            sign_name = os.path.splitext(filename)[0]
            file_path = os.path.join(self.directory_path, filename)
            sign_data = self._read_file_data(file_path)

            for frame_data in sign_data["frames"]:
                landmark_types = sign_data.get("landmark_types", ["F", "L", "P", "R"])  # defaulting to all types

                if augment:
                  frame_data = self._augment_data(frame_data)
                graph_data = self._create_graph_from_frame(sign_name, frame_data, landmark_types)

                dataset.append(graph_data)

        return dataset

    def number_of_classes(self):
        return len(self.sign_to_label)

### `ASLGraphClassifier` Class

The `ASLGraphClassifier` is a Graph Convolutional Network (GCN) classifier that handles graph-structured data. It accepts a PyG `Data` object as input and produces class logits via the forward pass, which, when paired with a suitable loss function, aids in model training.

**Methods**:

- `forward`: Defines the forward pass of the model. Accepting a PyG `Data` object containing the entire graph, the method comprises two GCN layers with subsequent batch normalization and dropout layers process the input. Post global max-pooling, two linear layers coupled with dropout ensure final classification, leading to log-softmax outputs.

In [7]:
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_max_pool, global_mean_pool

class ASLGraphClassifier(torch.nn.Module):
    def __init__(self, num_features, num_classes):
        super(ASLGraphClassifier, self).__init__()
        self.conv1 = GCNConv(num_features, 256)  # Increased channels
        self.bn1 = torch.nn.BatchNorm1d(256)    # Batch normalization layer
        self.conv2 = GCNConv(256, 512)          # Increased channels
        self.bn2 = torch.nn.BatchNorm1d(512)    # Batch normalization layer
        self.lin1 = torch.nn.Linear(512, 256)
        self.lin2 = torch.nn.Linear(256, num_classes)
        self.dropout = torch.nn.Dropout(p=0.5)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.bn1(self.conv1(x, edge_index)))
        x = self.dropout(x)
        x = F.relu(self.bn2(self.conv2(x, edge_index)))
        x = self.dropout(x)
        x = global_max_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = self.dropout(x)
        x = self.lin2(x)
        return F.log_softmax(x, dim=1)

In [9]:
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.loader import DataLoader
from collections import Counter
import random

EPOCHS = 200
LEARNING_RATE = 0.001


def stratified_data_split(data_list, test_size=0.2):
    """
    This function splits a dataset into training and testing subsets, preserving
    the class distribution by leveraging the stratification capabilities of
    `train_test_split` from `sklearn`. Stratification helps with potential class
    imbalances.
    """
    # Extract labels from data list
    labels = [data.y.item() for data in data_list]

    # Use sklearn's train_test_split with stratify option
    train_data, test_data = train_test_split(data_list, test_size=test_size, stratify=labels, random_state=42)

    return train_data, test_data


def validate(loader, model, device):
    """
    Used to evaluate the model on validation/test data, computing accuracy as a
    performance metric, and offering insights into the model's efficacy.
    """
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        with torch.no_grad():
            out = model(data)
        pred = out.argmax(dim=1)
        correct += int((pred == data.y).sum())
    return correct / len(loader.dataset)

def train():
    """
    The `train` function establishes the training loop for the graph-based
    neural network. It enacts typical training loop tasks like logging
    epoch-wise loss, validation, and early stopping.

    The function also harnesses schedulers, regularization techniques, and
    gradient clipping to ensure smooth and optimal training.
    """
    directory_path = "/content/drive/MyDrive/Colab Notebooks/DGMD E-14 Project/Datasets/ASL"
    loader = ASLDatasetLoader(directory_path)

    # Create the entire dataset without augmentation and then perform stratified split
    data_list = loader.get_dataset()
    train_dataset, test_dataset = stratified_data_split(data_list, test_size=0.2)

    # Now augment only the training dataset
    augmented_train_dataset = loader.get_dataset(augment=True)

    num_classes = loader.number_of_classes()

    train_labels = [data.y.item() for data in train_dataset]
    test_labels = [data.y.item() for data in test_dataset]

    print("Training label distribution:", Counter(train_labels))
    print("Test label distribution:", Counter(test_labels))

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    num_features = train_dataset[0].x.size(1)
    model = ASLGraphClassifier(num_features=num_features, num_classes=num_classes).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=5e-4)
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.7, patience=5, verbose=True)

    max_epochs_without_improvement = 20
    epochs_without_improvement = 0
    best_val_accuracy = 0

    model.train()
    for epoch in range(EPOCHS):
        total_loss = 0
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            out = model(batch)
            loss = F.nll_loss(out, batch.y)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)

            optimizer.step()
            total_loss += loss.item()

            # Check for NaN loss
            if np.isnan(loss.item()):
                print("Warning: NaN loss detected!")

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch}, Loss: {avg_loss}")

        val_accuracy = validate(test_loader, model, device)
        scheduler.step(val_accuracy)

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= max_epochs_without_improvement:
            print("Early stopping triggered.")
            break

    model.eval()
    correct = 0
    all_preds = []
    all_labels = []

    for batch in test_loader:
        batch = batch.to(device)
        with torch.no_grad():
            pred = model(batch).max(dim=1)[1]
            all_preds.extend(pred.cpu().numpy())
            all_labels.extend(batch.y.cpu().numpy())
            correct += pred.eq(batch.y).sum().item()

    print(f"Accuracy: {correct / len(test_dataset)}")
    print("Sample predictions:", all_preds[:20])
    print("Sample true labels:", all_labels[:20])

In [10]:
train()

Training label distribution: Counter({0: 438, 2: 430, 131: 430, 3: 420, 125: 402, 1: 400, 116: 363, 128: 346, 130: 338, 126: 338, 104: 332, 127: 330, 85: 324, 101: 323, 103: 316, 113: 310, 98: 309, 84: 308, 96: 307, 102: 306, 129: 304, 121: 298, 97: 297, 86: 297, 87: 296, 100: 295, 83: 295, 95: 295, 99: 294, 88: 293, 64: 292, 89: 290, 90: 289, 94: 288, 66: 286, 119: 282, 77: 279, 91: 279, 76: 278, 63: 278, 93: 278, 73: 277, 50: 277, 80: 276, 92: 275, 72: 274, 78: 274, 68: 274, 52: 273, 55: 273, 79: 271, 74: 270, 70: 270, 53: 270, 71: 270, 120: 269, 56: 269, 67: 268, 75: 266, 82: 265, 117: 265, 62: 265, 45: 264, 57: 263, 81: 263, 48: 263, 60: 262, 69: 262, 35: 262, 65: 262, 61: 261, 54: 261, 59: 261, 40: 260, 123: 258, 41: 258, 38: 257, 49: 256, 32: 255, 34: 255, 24: 255, 30: 254, 44: 254, 58: 253, 51: 253, 46: 252, 28: 251, 26: 250, 43: 250, 47: 250, 25: 250, 29: 249, 37: 249, 14: 247, 17: 247, 42: 247, 15: 247, 114: 247, 39: 247, 23: 246, 122: 246, 20: 246, 31: 244, 19: 244, 226: 243,