# ASLive Sign2Text Model

This notebook implements the sign language to text model following the architecture:
- **Vision Layer (CNN)**: Extracts spatial features from each frame
- **Positional Encoding (PE)**: Adds temporal position information
- **Attention Layer (LSTM)**: Processes temporal sequence with attention
- **FC Layer**: Final classification layer


In [18]:
!pip install kagglehub torchcodec torchvision



In [19]:
# Before running, add everything from SQ_dataloader.ipynb into the cell below and run

## 1. Data Loading (from SQ_dataloader)


In [20]:
#add SQ_dataloader code here
from torch.utils.data import Dataset
from torchvision import transforms
from torchcodec.decoders import VideoDecoder
import kagglehub

class WLASLTorchCodec(Dataset):
  def __init__(self, json_path=None, video_dir=None, download=True, split="train", num_frames=32, transform=None, max_samples=None):
    if (json_path is None or video_dir is None) and not download:
      raise ValueError("json_path and video_dir must be provided when download=False")
    if download:
      path = kagglehub.dataset_download("sttaseen/wlasl2000-resized")
      print("Downloaded at path: ", path)
      self.video_dir = os.path.join(path, "wlasl-complete", "videos")
      json_path = os.path.join(path, "wlasl-complete", "WLASL_v0.3.json")
    else:
      self.video_dir = video_dir
    self.num_frames = num_frames
    self.transform = transform

    # Read json
    with open(json_path, "r") as f:
      data = json.load(f)

    self.samples = []
    self.label_map = {}
    label_id = 0

    for entry in data:
      gloss = entry["gloss"]

      if gloss not in self.label_map:
        self.label_map[gloss] = label_id
        label_id += 1

      label = self.label_map[gloss]

      for inst in entry["instances"]:
        if inst["split"] != split:
          continue

        video_id = inst["video_id"]
        file_path = os.path.join(self.video_dir, f"{video_id}.mp4")

        if os.path.isfile(file_path):
          self.samples.append((file_path, label))
    
    # Limit samples for quick testing
    if max_samples is not None and len(self.samples) > max_samples:
      self.samples = self.samples[:max_samples]
    
    self.num_classes = label_id

  def __len__(self):
    return len(self.samples)

  def _sample_frames(self, frames):
    """frames is a list of PIL images or Tensors."""
    T = len(frames)
    idx = torch.linspace(0, T - 1, self.num_frames).long()
    return [frames[i] for i in idx]

  def __getitem__(self, idx):
    video_path, label = self.samples[idx]

    decoder = VideoDecoder(video_path)

    frames = []
    for chunk in decoder:
      # chunk is a Tensor of frames: (T, H, W) OR (T, H, W, C)
      for frame_tensor in chunk:
        # Handle grayscale frames (H, W)
        if frame_tensor.dim() == 2:
          frame_tensor = frame_tensor.unsqueeze(2)  # → (H, W, 1)

        # Handle RGB frames (H, W, C)
        if frame_tensor.dim() == 3:
          pass
        else:
          raise ValueError(f"Unexpected frame shape: {frame_tensor.shape}")

        # Convert to C x H x W
        frame_chw = frame_tensor.permute(2, 0, 1)

        frame_pil = transforms.ToPILImage()(frame_chw)
        # Ensure RGB (some videos are grayscale)
        frame_pil = frame_pil.convert('RGB')
        frames.append(frame_pil)

    # Guard for short videos
    if len(frames) < self.num_frames:
        while len(frames) < self.num_frames:
            frames.extend(frames)
        frames = frames[: self.num_frames]

    # Uniform sampling
    frames = self._sample_frames(frames)

    if self.transform:
      frames = torch.stack([self.transform(f) for f in frames])
    else:
      frames = torch.stack([transforms.ToTensor()(f) for f in frames])

    return frames, label



# once this cell successfully runs, continue with the other cells

In [None]:
import os
import json
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchcodec.decoders import VideoDecoder
import numpy as np
from tqdm import tqdm

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cpu


## 2. Vision Layer (CNN Backbone)

The Vision Layer extracts spatial features from each video frame using a CNN. We use a pretrained ResNet-18 as the backbone and remove the final classification layer to get feature embeddings.


In [22]:
class VisionLayer(nn.Module):
    """CNN backbone for extracting spatial features from video frames.

    Uses pretrained ResNet-18 as feature extractor.
    Input: (batch, T, C, H, W) - batch of T frames
    Output: (batch, T, feature_dim) - feature vectors for each frame
    """

    def __init__(self, feature_dim=512, pretrained=True, freeze_backbone=False):
        super(VisionLayer, self).__init__()

        # Load pretrained ResNet-18
        resnet = models.resnet18(weights='IMAGENET1K_V1' if pretrained else None)

        # Remove the final FC layer
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # ResNet-18 outputs 512-dim features
        self.resnet_feature_dim = 512

        # Optional projection layer to adjust feature dimension
        if feature_dim != self.resnet_feature_dim:
            self.projection = nn.Linear(self.resnet_feature_dim, feature_dim)
        else:
            self.projection = None

        self.feature_dim = feature_dim

        # Freeze backbone if specified
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch, T, C, H, W)
        Returns:
            Feature tensor of shape (batch, T, feature_dim)
        """
        batch_size, T, C, H, W = x.shape

        # Reshape to process all frames together: (batch * T, C, H, W)
        x = x.view(batch_size * T, C, H, W)

        # Extract features: (batch * T, 512, 1, 1)
        features = self.backbone(x)

        # Flatten: (batch * T, 512)
        features = features.view(batch_size * T, -1)

        # Project features if needed
        if self.projection is not None:
            features = self.projection(features)

        # Reshape back: (batch, T, feature_dim)
        features = features.view(batch_size, T, self.feature_dim)

        return features


## 3. Positional Encoding (PE)

Sinusoidal positional encoding adds temporal position information to the frame features before feeding them to the LSTM.


In [23]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for temporal sequences.

    Adds position information to help the model understand the order of frames.
    """

    def __init__(self, d_model, max_len=500, dropout=0.1):
        super(PositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(p=dropout)

        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Add batch dimension: (1, max_len, d_model)
        pe = pe.unsqueeze(0)

        # Register as buffer (not a parameter, but should be saved/loaded)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch, T, d_model)
        Returns:
            Tensor with positional encoding added: (batch, T, d_model)
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


## 4. Attention Layer (LSTM with Attention)

Bidirectional LSTM processes the sequence of frame features, followed by an attention mechanism to weight the importance of different time steps.


In [24]:
class Attention(nn.Module):
    """Attention mechanism for weighting LSTM outputs.

    Computes attention weights over the sequence and returns a weighted sum.
    """

    def __init__(self, hidden_dim):
        super(Attention, self).__init__()

        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.Tanh(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, lstm_output):
        """
        Args:
            lstm_output: LSTM outputs of shape (batch, T, hidden_dim)
        Returns:
            context: Weighted sum of shape (batch, hidden_dim)
            attention_weights: Attention weights of shape (batch, T)
        """
        # Compute attention scores: (batch, T, 1)
        scores = self.attention(lstm_output)

        # Apply softmax over time dimension: (batch, T, 1)
        attention_weights = F.softmax(scores, dim=1)

        # Compute weighted sum: (batch, hidden_dim)
        context = torch.sum(attention_weights * lstm_output, dim=1)

        return context, attention_weights.squeeze(-1)


class AttentionLSTM(nn.Module):
    """Bidirectional LSTM with attention mechanism.

    Processes temporal sequence of frame features and outputs a fixed-size representation.
    """

    def __init__(self, input_dim, hidden_dim=256, num_layers=2, dropout=0.3, bidirectional=True):
        super(AttentionLSTM, self).__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        # LSTM layer
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional
        )

        # Attention mechanism
        self.attention = Attention(hidden_dim * self.num_directions)

        # Output dimension
        self.output_dim = hidden_dim * self.num_directions

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch, T, input_dim)
        Returns:
            output: Context vector of shape (batch, hidden_dim * num_directions)
            attention_weights: Attention weights of shape (batch, T)
        """
        # LSTM forward pass: (batch, T, hidden_dim * num_directions)
        lstm_output, (hidden, cell) = self.lstm(x)

        # Apply attention
        context, attention_weights = self.attention(lstm_output)

        return context, attention_weights


## 5. Complete Sign2Text Model

Combines all components: Vision Layer → Positional Encoding → Attention LSTM → FC Layer → Classification


In [25]:
class Sign2TextModel(nn.Module):
    """Complete Sign Language to Text model.

    Architecture:
    1. Vision Layer (CNN): Extract spatial features from each frame
    2. Positional Encoding: Add temporal position information
    3. Attention LSTM: Process temporal sequence with attention
    4. FC Layer: Final classification
    """

    def __init__(self, num_classes, feature_dim=512, hidden_dim=256,
                 num_lstm_layers=2, dropout=0.3, pretrained_cnn=True,
                 freeze_cnn=False, max_frames=100):
        super(Sign2TextModel, self).__init__()

        # Vision Layer (CNN)
        self.vision_layer = VisionLayer(
            feature_dim=feature_dim,
            pretrained=pretrained_cnn,
            freeze_backbone=freeze_cnn
        )

        # Positional Encoding
        self.positional_encoding = PositionalEncoding(
            d_model=feature_dim,
            max_len=max_frames,
            dropout=dropout
        )

        # Attention Layer (LSTM)
        self.attention_lstm = AttentionLSTM(
            input_dim=feature_dim,
            hidden_dim=hidden_dim,
            num_layers=num_lstm_layers,
            dropout=dropout,
            bidirectional=True
        )

        # FC Layer (Classification)
        self.fc_layer = nn.Sequential(
            nn.Linear(self.attention_lstm.output_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

        self.num_classes = num_classes

    def forward(self, x, return_attention=False):
        """
        Args:
            x: Input video frames of shape (batch, T, C, H, W)
            return_attention: If True, also return attention weights
        Returns:
            logits: Classification logits of shape (batch, num_classes)
            attention_weights (optional): Attention weights of shape (batch, T)
        """
        # Vision Layer: (batch, T, C, H, W) → (batch, T, feature_dim)
        features = self.vision_layer(x)

        # Positional Encoding: (batch, T, feature_dim) → (batch, T, feature_dim)
        features = self.positional_encoding(features)

        # Attention LSTM: (batch, T, feature_dim) → (batch, hidden_dim * 2)
        context, attention_weights = self.attention_lstm(features)

        # FC Layer: (batch, hidden_dim * 2) → (batch, num_classes)
        logits = self.fc_layer(context)

        if return_attention:
            return logits, attention_weights
        return logits


## 6. Training Utilities


In [26]:
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train the model for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(dataloader, desc="Training")
    for frames, labels in progress_bar:
        frames = frames.to(device)
        labels = labels.to(device)

        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(frames)
        loss = criterion(outputs, labels)

        # Backward pass
        loss.backward()
        optimizer.step()

        # Statistics
        running_loss += loss.item() * frames.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        progress_bar.set_postfix({
            'loss': loss.item(),
            'acc': 100 * correct / total
        })

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc


def evaluate(model, dataloader, criterion, device):
    """Evaluate the model on a dataset."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for frames, labels in tqdm(dataloader, desc="Evaluating"):
            frames = frames.to(device)
            labels = labels.to(device)

            outputs = model(frames)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * frames.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc


## 7. Configuration and Setup


In [None]:
# ============================================
# CONFIGURATION - Modify these paths and hyperparameters
# ============================================

# Data paths (not needed when download=True, but kept for manual override)
JSON_PATH = None  # Will be auto-set by kagglehub download
VIDEO_DIR = None  # Will be auto-set by kagglehub download

# Model hyperparameters
NUM_FRAMES = 16          # Number of frames to sample from each video
FEATURE_DIM = 512        # CNN feature dimension
HIDDEN_DIM = 256         # LSTM hidden dimension
NUM_LSTM_LAYERS = 2      # Number of LSTM layers
DROPOUT = 0.3            # Dropout rate

# Training hyperparameters
BATCH_SIZE = 32           # Batch size (adjust based on GPU memory)
LEARNING_RATE = 1e-4     # Learning rate
WEIGHT_DECAY = 1e-4      # L2 regularization

# ============================================
# QUICK TEST MODE - Set to True for 1-2 hour test run
# Set to False for full training (~14 days)
# ============================================
QUICK_TEST_MODE = True

if QUICK_TEST_MODE:
    NUM_EPOCHS = 3           # Quick test: 3 epochs
    MAX_TRAIN_SAMPLES = 500  # Limit training samples for faster iteration
    MAX_VAL_SAMPLES = 100    # Limit validation samples
else:
    NUM_EPOCHS = 30          # Full training: 30 epochs
    MAX_TRAIN_SAMPLES = None # Use all samples
    MAX_VAL_SAMPLES = None   # Use all samples

# Options
FREEZE_CNN = False       # Whether to freeze CNN backbone
PRETRAINED_CNN = True    # Use pretrained CNN weights


In [28]:
# Data transforms for training and validation
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


In [None]:
# Create datasets (uses kagglehub download by default)
train_dataset = WLASLTorchCodec(
    download=True,
    split="train",
    num_frames=NUM_FRAMES,
    transform=train_transform,
    max_samples=MAX_TRAIN_SAMPLES  # Limit samples in quick test mode
)

val_dataset = WLASLTorchCodec(
    download=True,
    split="val",
    num_frames=NUM_FRAMES,
    transform=val_transform,
    max_samples=MAX_VAL_SAMPLES  # Limit samples in quick test mode
)

test_dataset = WLASLTorchCodec(
    download=True,
    split="test",
    num_frames=NUM_FRAMES,
    transform=val_transform,
    max_samples=MAX_VAL_SAMPLES  # Limit samples in quick test mode
)

# Get number of classes from dataset
NUM_CLASSES = train_dataset.num_classes

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")
print(f"Number of classes: {NUM_CLASSES}")
if QUICK_TEST_MODE:
    print(f"\n⚡ QUICK TEST MODE ENABLED - Using limited samples for faster testing")


Downloaded at path:  C:\Users\Zach\.cache\kagglehub\datasets\sttaseen\wlasl2000-resized\versions\1
Downloaded at path:  C:\Users\Zach\.cache\kagglehub\datasets\sttaseen\wlasl2000-resized\versions\1
Downloaded at path:  C:\Users\Zach\.cache\kagglehub\datasets\sttaseen\wlasl2000-resized\versions\1
Number of training samples: 500
Number of validation samples: 100
Number of test samples: 100
Number of classes: 2000

⚡ QUICK TEST MODE ENABLED - Using limited samples for faster testing


In [30]:
# Create data loaders
# Note: num_workers=0 for Windows compatibility, pin_memory only when CUDA available
NUM_WORKERS = 0  # Set to 0 for Windows (multiprocessing issues), increase on Linux
PIN_MEMORY = torch.cuda.is_available()

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

print(f"DataLoader config: num_workers={NUM_WORKERS}, pin_memory={PIN_MEMORY}")


DataLoader config: num_workers=0, pin_memory=False


In [31]:
# Initialize model
model = Sign2TextModel(
    num_classes=NUM_CLASSES,
    feature_dim=FEATURE_DIM,
    hidden_dim=HIDDEN_DIM,
    num_lstm_layers=NUM_LSTM_LAYERS,
    dropout=DROPOUT,
    pretrained_cnn=PRETRAINED_CNN,
    freeze_cnn=FREEZE_CNN,
    max_frames=NUM_FRAMES
).to(device)

# Print model summary
print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")


Sign2TextModel(
  (vision_layer): VisionLayer(
    (backbone): Sequential(
      (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (4): Sequential(
        (0): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): BasicBlock(
          (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn1): BatchNor

# Loss function and optimizer



In [32]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW( model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY )

# Learning rate scheduler



In [33]:
#no verbose
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode='min', factor=0.5, patience=3 )

## 8. Training Loop


In [None]:
# Training loop
best_val_acc = 0.0
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': []
}

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
    print("-" * 40)

    # Train
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    # Update scheduler
    scheduler.step(val_loss)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'label_map': train_dataset.label_map
        }, 'best_model.pth')
        print(f"✓ Saved new best model with Val Acc: {val_acc:.2f}%")

print(f"\nTraining complete! Best Val Acc: {best_val_acc:.2f}%")



Epoch 1/3
----------------------------------------


Training:  12%|█▎        | 2/16 [02:28<17:41, 75.83s/it, loss=7.61, acc=0]

## 9. Evaluation and Visualization


In [None]:
import matplotlib.pyplot as plt

# Plot training history
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss plot
axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history['val_loss'], label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True)

# Accuracy plot
axes[1].plot(history['train_acc'], label='Train Acc', marker='o')
axes[1].plot(history['val_acc'], label='Val Acc', marker='s')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training and Validation Accuracy')
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.savefig('training_history.png', dpi=150)
plt.show()


In [None]:
# Load best model and evaluate on test set
checkpoint = torch.load('best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])

test_loss, test_acc = evaluate(model, test_loader, criterion, device)
print(f"\nTest Results:")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")


## 10. Attention Visualization

Visualize which frames the model attends to most when making predictions.


In [None]:
def visualize_attention(model, frames, true_label, label_map, device):
    """Visualize attention weights over video frames."""
    model.eval()

    # Get reverse label map
    idx_to_label = {v: k for k, v in label_map.items()}

    with torch.no_grad():
        # Add batch dimension
        frames_batch = frames.unsqueeze(0).to(device)

        # Get predictions and attention weights
        logits, attention_weights = model(frames_batch, return_attention=True)
        pred_label = torch.argmax(logits, dim=1).item()
        attention = attention_weights[0].cpu().numpy()

    # Create visualization
    num_frames = frames.shape[0]
    fig, axes = plt.subplots(2, num_frames, figsize=(2 * num_frames, 6))

    # Denormalize frames for visualization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    for i in range(num_frames):
        frame = frames[i].cpu()
        frame = frame * std + mean
        frame = frame.clamp(0, 1).permute(1, 2, 0).numpy()

        # Frame image
        axes[0, i].imshow(frame)
        axes[0, i].set_title(f"Frame {i+1}")
        axes[0, i].axis('off')

        # Attention weight bar
        axes[1, i].bar([0], [attention[i]], color='blue', alpha=0.7)
        axes[1, i].set_ylim(0, max(attention) * 1.2)
        axes[1, i].set_title(f"{attention[i]:.3f}")
        axes[1, i].axis('off')

    plt.suptitle(
        f"True: {idx_to_label.get(true_label, true_label)} | "
        f"Predicted: {idx_to_label.get(pred_label, pred_label)}",
        fontsize=14
    )
    plt.tight_layout()
    plt.show()

    return pred_label, attention


In [None]:
# Visualize attention for a sample from the test set
sample_idx = 0
frames, label = test_dataset[sample_idx]
pred, attn = visualize_attention(model, frames, label, train_dataset.label_map, device)


## 11. Inference Function


In [None]:
def predict_video(model, video_path, transform, num_frames, label_map, device):
    """Predict the sign language class for a video file."""
    model.eval()

    # Get reverse label map
    idx_to_label = {v: k for k, v in label_map.items()}

    # Decode video
    decoder = VideoDecoder(video_path)
    frames = []

    for chunk in decoder:
        for frame_tensor in chunk:
            if frame_tensor.dim() == 2:
                frame_tensor = frame_tensor.unsqueeze(2)
            frame_chw = frame_tensor.permute(2, 0, 1)
            frame_pil = transforms.ToPILImage()(frame_chw)
            frames.append(frame_pil)

    # Handle short videos
    while len(frames) < num_frames:
        frames.extend(frames)

    # Sample frames
    T = len(frames)
    idx = torch.linspace(0, T - 1, num_frames).long()
    frames = [frames[i] for i in idx]

    # Apply transforms
    frames = torch.stack([transform(f) for f in frames])

    # Predict
    with torch.no_grad():
        frames_batch = frames.unsqueeze(0).to(device)
        logits, attention = model(frames_batch, return_attention=True)
        probabilities = F.softmax(logits, dim=1)
        pred_idx = torch.argmax(logits, dim=1).item()
        confidence = probabilities[0, pred_idx].item()

    predicted_label = idx_to_label.get(pred_idx, f"Unknown ({pred_idx})")

    return {
        'prediction': predicted_label,
        'confidence': confidence,
        'attention_weights': attention[0].cpu().numpy(),
        'all_probabilities': probabilities[0].cpu().numpy()
    }


In [None]:
# Example inference (uncomment and modify path to use)
# result = predict_video(
#     model=model,
#     video_path="/path/to/your/video.mp4",
#     transform=val_transform,
#     num_frames=NUM_FRAMES,
#     label_map=train_dataset.label_map,
#     device=device
# )
# print(f"Prediction: {result['prediction']}")
# print(f"Confidence: {result['confidence']:.2%}")


## 12. Save Final Model


In [None]:
# Save complete model for deployment
torch.save({
    'model_state_dict': model.state_dict(),
    'label_map': train_dataset.label_map,
    'config': {
        'num_classes': NUM_CLASSES,
        'feature_dim': FEATURE_DIM,
        'hidden_dim': HIDDEN_DIM,
        'num_lstm_layers': NUM_LSTM_LAYERS,
        'num_frames': NUM_FRAMES,
        'dropout': DROPOUT
    }
}, 'sign2text_model_final.pth')

print("Model saved to sign2text_model_final.pth")


---
**Note:** The cells above contain the complete implementation. Make sure to run them in order from top to bottom.


In [None]:
# PositionalEncoding class is defined below cell 7 - this cell can be ignored
# The model requires running cells in sequential order
