# ASLive Sign2Text Model

This notebook implements the sign language to text model following the architecture:
- **Vision Layer (CNN)**: Extracts spatial features from each frame
- **Positional Encoding (PE)**: Adds temporal position information
- **Attention Layer (LSTM)**: Processes temporal sequence with attention
- **FC Layer**: Final classification layer


In [1]:
!pip install kagglehub torchcodec torchvision
!pip install git+https://github.com/facebookresearch/pytorchvideo

Collecting git+https://github.com/facebookresearch/pytorchvideo
  Cloning https://github.com/facebookresearch/pytorchvideo to /tmp/pip-req-build-av8zpa0l
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/pytorchvideo /tmp/pip-req-build-av8zpa0l
  Resolved https://github.com/facebookresearch/pytorchvideo to commit 0f9a5e102e4d84972b829fd30e3c3f78c7c7fd1a
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [2]:
# Before running, add everything from SQ_dataloader.ipynb into the cell below and run

## 1. Data Loading (from SQ_dataloader)


In [3]:
#add SQ_dataloader code here
from torch.utils.data import Dataset
from torchvision import transforms
from torchcodec.decoders import VideoDecoder
import kagglehub
import os
import json

import torch # Assuming torch is imported elsewhere
from torch.utils.data import Dataset
from torchvision import transforms
from torchcodec.decoders import VideoDecoder
import kagglehub
import os
import json
from PIL import Image # Needed for cropping if working with PIL images

class WLASLTorchCodec(Dataset):
  download_path = None

  def __init__(self, json_path=None, video_dir=None, download=True, max_classes=None, split="train", num_frames=32, transform=None):
    print("Will download:", download)
    if (json_path is None or video_dir is None) and download == False:
      raise ValueError("json_path and video_dir must be provided with download false")
    if download:
      if WLASLTorchCodec.download_path is None:
        path = kagglehub.dataset_download("sttaseen/wlasl2000-resized")
        WLASLTorchCodec.download_path = path
      else:
        path = WLASLTorchCodec.download_path
      print("Downloaded at path: ", path)

      self.video_dir = os.path.join(path, "wlasl-complete", "videos")
      json_path = os.path.join(path, "wlasl-complete","WLASL_v0.3.json")
      downloaded = True
    else:
      self.video_dir = video_dir
    self.num_frames = num_frames
    self.transform = transform

    # Read json
    with open(json_path, "r") as f:
      data = json.load(f)
    if max_classes is not None:
        if isinstance(max_classes, int):
            # Keep only the first N entries (Usually the most frequent in WLASL)
            data = data[:max_classes]
            print(f"Limiting dataset to top {max_classes} classes.")
        elif isinstance(max_classes, list):
            # Keep only entries that match specific glosses
            data = [entry for entry in data if entry['gloss'] in max_classes]
            print(f"Limiting dataset to {len(data)} specific classes.")
    self.samples = []
    self.label_map = {}
    label_id = 0

    for entry in data:
      gloss = entry["gloss"]

      if gloss not in self.label_map:
        self.label_map[gloss] = label_id
        label_id += 1

      label = self.label_map[gloss]

      for inst in entry["instances"]:
        if inst["split"] != split:
          continue

        video_id = inst["video_id"]
        file_path = os.path.join(self.video_dir, f"{video_id}.mp4")

        # 1. Modification in __init__: Extract and store frame/bbox info
        frame_start = inst.get("frame_start", 1) # Default to 1 if missing
        frame_end = inst.get("frame_end", -1)   # Default to -1 if missing
        bbox = inst.get("bbox", [0, 0, 1.0, 1.0]) # Default to normalized full frame if missing

        if os.path.isfile(file_path):
          # Store a tuple of (file_path, label, frame_start, frame_end, bbox)
          self.samples.append((file_path, label, frame_start, frame_end, bbox))
        self.num_classes = label_id
  def __len__(self):
    return len(self.samples)

  def __getitem__(self, idx):
    # 2. Modification in __getitem__: Unpack all instance info
    video_path, label, frame_start, frame_end, bbox = self.samples[idx]

    # Convert WLASL 1-based indices (inclusive start, exclusive end) to
    # torchcodec's 0-based indices (inclusive start, inclusive end).

    decoder = VideoDecoder(video_path)
    video_length = decoder.metadata.num_frames
    end_frame = frame_end - 1 if frame_end > 0 else video_length
    start_frame = 0
    if end_frame > video_length:
      end_frame = video_length
    else:
      end_frame = frame_end - 2 if frame_end > 0 else None
    if frame_start > video_length:
      start_frame = 0
    else:
      start_frame = frame_start - 1
    frames = decoder[start_frame:end_frame]
    if self.transform:
      # Transform should handle T x C x H x W input
      frames = self.transform(frames)
    return frames, torch.tensor(label) # Ensure label is a tensor

In [4]:
import pytorchvideo.transforms as ptv_transforms
from torchvision.transforms import Compose, Lambda


mean = [0.45, 0.45, 0.45]
std = [0.225, 0.225, 0.225]

# Test out dataset
train_transform = Compose(
    [
        # 1. Spatial Resize: Scale the shortest edge to SIDE_SIZE
        ptv_transforms.UniformTemporalSubsample(num_samples=24, temporal_dim=0),
        ptv_transforms.ConvertUint8ToFloat(),
        Lambda(lambda x: x.permute(1, 0, 2, 3)),
        ptv_transforms.Normalize(mean, std),
        Lambda(lambda x: x.permute(1, 0, 2, 3)),
        ptv_transforms.ShortSideScale(size=224),
        # ptv_transforms.RandAugment(magnitude=6, num_layers=2),
        # ptv_transforms.AugMix(magnitude=3),
    ]
)

def show_frame(video, frame_idx):
  import matplotlib.pyplot as plt
  import numpy as np
  single_frame = video[frame_idx]
  frame_np = single_frame.detach().cpu().numpy()

  frame_np = np.transpose(frame_np, (1, 2, 0))
  plt.figure(figsize=(5, 5))
  plt.imshow(frame_np)
  plt.title(f'Frame {frame_idx} from Video Batch')
  plt.axis('off') # Hide axis ticks and labels
  plt.show()

# clip = WLASLTorchCodec(max_classes=1, transform=train_transform)

# for video, label in clip:
#   show_frame(video, 0)

In [5]:
import os
import json
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchcodec.decoders import VideoDecoder
import numpy as np
from tqdm import tqdm

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


## 2. Vision Layer (CNN Backbone)

The Vision Layer extracts spatial features from each video frame using a CNN. We use a pretrained ResNet-18 as the backbone and remove the final classification layer to get feature embeddings.


In [6]:
class VisionLayer(nn.Module):
    """CNN backbone for extracting spatial features from video frames.

    Uses pretrained ResNet-18 as feature extractor.
    Input: (batch, T, C, H, W) - batch of T frames
    Output: (batch, T, feature_dim) - feature vectors for each frame
    """

    def __init__(self, feature_dim=512, pretrained=True, freeze_backbone=False):
        super(VisionLayer, self).__init__()

        # Load pretrained ResNet-18
        resnet = models.resnet18(weights='IMAGENET1K_V1' if pretrained else None)

        # Remove the final FC layer
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])

        # ResNet-18 outputs 512-dim features
        self.resnet_feature_dim = 512

        # Optional projection layer to adjust feature dimension
        if feature_dim != self.resnet_feature_dim:
            self.projection = nn.Linear(self.resnet_feature_dim, feature_dim)
        else:
            self.projection = None

        self.feature_dim = feature_dim

        # Freeze backbone if specified
        self.set_freeze_backbone(freeze_backbone)

    def set_freeze_backbone(self, is_frozen):
      for param in self.backbone.parameters():
          param.requires_grad = not is_frozen

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch, T, C, H, W)
        Returns:
            Feature tensor of shape (batch, T, feature_dim)
        """
        batch_size, T, C, H, W = x.shape

        # Reshape to process all frames together: (batch * T, C, H, W)
        x = x.view(batch_size * T, C, H, W)

        # Extract features: (batch * T, 512, 1, 1)
        features = self.backbone(x)

        # Flatten: (batch * T, 512)
        features = features.view(batch_size * T, -1)

        # Project features if needed
        if self.projection is not None:
            features = self.projection(features)

        # Reshape back: (batch, T, feature_dim)
        features = features.view(batch_size, T, self.feature_dim)

        return features


## 3. Positional Encoding (PE)

Sinusoidal positional encoding adds temporal position information to the frame features before feeding them to the LSTM.


In [7]:
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding for temporal sequences.

    Adds position information to help the model understand the order of frames.
    """

    def __init__(self, d_model, max_len=500, dropout=0.1):
        super(PositionalEncoding, self).__init__()

        self.dropout = nn.Dropout(p=dropout)

        # Create positional encoding matrix
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)

        # Add batch dimension: (1, max_len, d_model)
        pe = pe.unsqueeze(0)

        # Register as buffer (not a parameter, but should be saved/loaded)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch, T, d_model)
        Returns:
            Tensor with positional encoding added: (batch, T, d_model)
        """
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


## 4. Attention Layer (LSTM with Attention)

Bidirectional LSTM processes the sequence of frame features, followed by an attention mechanism to weight the importance of different time steps.


In [8]:

class Attention(nn.Module):
    """Attention mechanism for weighting LSTM outputs.

    Computes attention weights over the sequence and returns a weighted sum.
    """

    def __init__(self, hidden_dim):
        super(Attention, self).__init__()

        self.attention = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.Tanh(),
            nn.Linear(hidden_dim // 2, 1)
        )

    def forward(self, lstm_output):
        """
        Args:
            lstm_output: LSTM outputs of shape (batch, T, hidden_dim)
        Returns:
            context: Weighted sum of shape (batch, hidden_dim)
            attention_weights: Attention weights of shape (batch, T)
        """
        # Compute attention scores: (batch, T, 1)
        scores = self.attention(lstm_output)

        # Apply softmax over time dimension: (batch, T, 1)
        attention_weights = F.softmax(scores, dim=1)

        # Compute weighted sum: (batch, hidden_dim)
        context = torch.sum(attention_weights * lstm_output, dim=1)

        return context, attention_weights.squeeze(-1)


class AttentionLSTM(nn.Module):
    """Bidirectional LSTM with attention mechanism.

    Processes temporal sequence of frame features and outputs a fixed-size representation.
    """

    def __init__(self, input_dim, hidden_dim=256, num_layers=2, dropout=0.3, bidirectional=True):
        super(AttentionLSTM, self).__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1

        # LSTM layer
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=bidirectional
        )

        # Attention mechanism
        self.attention = Attention(hidden_dim * self.num_directions)

        # Output dimension
        self.output_dim = hidden_dim * self.num_directions

    def forward(self, x):
        """
        Args:
            x: Input tensor of shape (batch, T, input_dim)
        Returns:
            output: Context vector of shape (batch, hidden_dim * num_directions)
            attention_weights: Attention weights of shape (batch, T)
        """
        # LSTM forward pass: (batch, T, hidden_dim * num_directions)
        lstm_output, (hidden, cell) = self.lstm(x)

        # Apply attention
        context, attention_weights = self.attention(lstm_output)

        return context, attention_weights

## 5. Complete Sign2Text Model

Combines all components: Vision Layer → Positional Encoding → Attention LSTM → FC Layer → Classification


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.video import r3d_18, R3D_18_Weights
from torchvision.models.video import mc3_18, MC3_18_Weights

class SlowFast(nn.Module):
    def __init__(self, num_classes=100, pretrained=True, dropout=0.5):
        super().__init__()

        # === 1. Backbones ===
        slow_weights = R3D_18_Weights.KINETICS400_V1 if pretrained else None
        slow_net = r3d_18(weights=slow_weights)

        fast_weights = MC3_18_Weights.KINETICS400_V1 if pretrained else None
        fast_net = mc3_18(weights=fast_weights)

        # Split backbones
        self.slow_path = nn.ModuleList([
            slow_net.stem, slow_net.layer1, slow_net.layer2, slow_net.layer3, slow_net.layer4
        ])
        self.fast_path = nn.ModuleList([
            fast_net.stem, fast_net.layer1, fast_net.layer2, fast_net.layer3, fast_net.layer4
        ])

        # === 2. Lateral Connections (Fixed Channels) ===
        # Output channels must match the Slow pathway at that layer for summation
        # Slow Layer1: 64, Layer2: 128, Layer3: 256, Layer4: 512
        # We assume Fast path has same channel progression (64, 128, 256, 512)

        # Note: We use stride=(8,1,1) assuming Fast input has 8x frames and we need to match Slow T.
        # Padding ensures we don't lose frames excessively.
        self.lateral_p3 = nn.Conv3d(64,  64,  kernel_size=(5,1,1), stride=(8,1,1), padding=(2,0,0), bias=False)
        self.lateral_p4 = nn.Conv3d(128, 128, kernel_size=(5,1,1), stride=(8,1,1), padding=(2,0,0), bias=False)
        self.lateral_p5 = nn.Conv3d(256, 256, kernel_size=(5,1,1), stride=(8,1,1), padding=(2,0,0), bias=False)

        # === 3. Fusion + Head ===
        slow_feat_dim = 512
        fast_proj_dim = 64

        # Fixed: Defined in __init__, not forward
        self.fast_proj = nn.Conv3d(512, fast_proj_dim, kernel_size=1, bias=False)

        # Fixed: Input dim accounts for concatenation (512 + 64)
        fusion_dim = slow_feat_dim + fast_proj_dim

        self.fusion = nn.Conv3d(fusion_dim, fusion_dim, kernel_size=1, bias=False)
        self.bn_fusion = nn.BatchNorm3d(fusion_dim)

        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(fusion_dim, num_classes)

        # === 4. Fixed Initialization ===
        # Initialize ONLY new layers. Do NOT wipe out backbone weights.
        new_layers = [self.lateral_p3, self.lateral_p4, self.lateral_p5,
                      self.fast_proj, self.fusion, self.bn_fusion, self.fc]

        for module in new_layers:
            for m in module.modules():
                if isinstance(m, nn.Conv3d):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm3d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _upsample_add(self, fast_feat, slow_feat):
        # Match Temporal, Height, Width
        target_shape = slow_feat.shape[2:] # (T, H, W)

        fast_up = F.interpolate(
            fast_feat,
            size=target_shape,
            mode='trilinear',
            align_corners=False
        )
        return slow_feat + fast_up

    def forward(self, x):
        # x: (B, T, C, H, W) -> (B, C, T, H, W)
        x = x.permute(0, 2, 1, 3, 4)

        # Slow: sample every 8th frame
        slow = x[:, :, ::8, :, :]
        # Fast: all frames
        fast = x

        # Block 0 (Stem)
        slow = self.slow_path[0](slow)
        fast = self.fast_path[0](fast)

        # Block 1 + Lateral
        slow = self.slow_path[1](slow)
        fast = self.fast_path[1](fast)
        lateral = self.lateral_p3(fast)
        slow = self._upsample_add(lateral, slow)

        # Block 2 + Lateral
        slow = self.slow_path[2](slow)
        fast = self.fast_path[2](fast)
        lateral = self.lateral_p4(fast)
        slow = self._upsample_add(lateral, slow)

        # Block 3 + Lateral
        slow = self.slow_path[3](slow)
        fast = self.fast_path[3](fast)
        lateral = self.lateral_p5(fast)
        slow = self._upsample_add(lateral, slow)

        # Block 4
        slow = self.slow_path[4](slow)
        fast = self.fast_path[4](fast)

        # Final Fusion
        # Project Fast (512 -> 64)
        fast_final = self.fast_proj(fast)

        # Match Slow spatial/temporal dims for concatenation
        fast_final = F.interpolate(fast_final, size=slow.shape[2:], mode='trilinear', align_corners=False)

        # Concatenate
        x = torch.cat([slow, fast_final], dim=1) # (B, 576, T, H, W)

        # Head
        x = self.fusion(x)
        x = self.bn_fusion(x)
        x = F.relu_(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)

        return x

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models.video import r3d_18, R3D_18_Weights
from torchvision.models.video import mc3_18, MC3_18_Weights


class SlowFast(nn.Module):
    """
    Accepts: (B, T, C, H, W)
    Returns: (B, num_classes)
    """
    def __init__(self, num_classes=100, pretrained=True, dropout=0.5):
        super().__init__()

        # === Slow pathway (R3D-18) ===
        slow_weights = R3D_18_Weights.KINETICS400_V1 if pretrained else None
        slow_net = r3d_18(weights=slow_weights)

        self.slow_path = nn.Sequential(
            slow_net.stem,
            slow_net.layer1,
            slow_net.layer2,
            slow_net.layer3,
            slow_net.layer4
        )
        slow_feat_dim = 512

        # === Fast pathway (MC3-18) ===
        fast_weights = MC3_18_Weights.KINETICS400_V1 if pretrained else None
        fast_net = mc3_18(weights=fast_weights)

        self.fast_path = nn.Sequential(
            fast_net.stem,
            fast_net.layer1,
            fast_net.layer2,
            fast_net.layer3,
            fast_net.layer4
        )

        # === Lateral connections (Fast → Slow) ===
        # Match channel sizes exactly:
        self.lateral_p3 = nn.Conv3d(64, 64, kernel_size=(5,1,1), stride=(8,1,1), padding=(2,0,0), bias=False)
        self.lateral_p4 = nn.Conv3d(128, 128, kernel_size=(5,1,1), stride=(8,1,1), padding=(2,0,0), bias=False)
        self.lateral_p5 = nn.Conv3d(256, 256, kernel_size=(5,1,1), stride=(8,1,1), padding=(2,0,0), bias=False)

        # === Final fast projection (512 → 64) ===
        self.fast_proj = nn.Conv3d(512, 64, kernel_size=1, bias=False)

        # === Fusion + head ===
        fusion_dim = slow_feat_dim + 64  # 512 + 64 = 576
        self.fusion = nn.Conv3d(fusion_dim, fusion_dim, kernel_size=1, bias=False)
        self.bn_fusion = nn.BatchNorm3d(fusion_dim)

        self.avgpool = nn.AdaptiveAvgPool3d(1)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(fusion_dim, num_classes)

        # Weight initialization
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm3d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _upsample_add(self, fast_feat, slow_feat):
        _, _, T_f, H_f, W_f = fast_feat.shape
        target_T, target_H, target_W = slow_feat.shape[2:]
        fast_up = F.interpolate(
            fast_feat,
            size=(target_T, target_H, target_W),
            mode='trilinear',
            align_corners=False
        )
        return slow_feat + fast_up

    def forward(self, x):
        # x: (B, T, C, H, W) → (B, C, T, H, W)
        x = x.permute(0, 2, 1, 3, 4).contiguous()

        # Slow pathway (sample 1/8)
        slow_x = x[:, :, ::8]
        slow = self.slow_path[0](slow_x)
        slow = self.slow_path[1](slow)

        # Fast pathway
        fast = self.fast_path[0](x)
        fast = self.fast_path[1](fast)

        # Lateral p3
        lateral = self.lateral_p3(fast)
        slow = self._upsample_add(lateral, slow)

        # layer2
        slow = self.slow_path[2](slow)
        fast = self.fast_path[2](fast)

        lateral = self.lateral_p4(fast)
        slow = self._upsample_add(lateral, slow)

        # layer3
        slow = self.slow_path[3](slow)
        fast = self.fast_path[3](fast)

        lateral = self.lateral_p5(fast)
        slow = self._upsample_add(lateral, slow)

        # layer4
        slow = self.slow_path[4](slow)
        fast = self.fast_path[4](fast)

        # Final fast projection
        fast_final = F.adaptive_avg_pool3d(fast, slow.shape[2:])
        fast_final = self.fast_proj(fast_final)

        # Concatenate
        slow = torch.cat([slow, fast_final], dim=1)  # 512+64 = 576

        # Fusion head
        x = self.fusion(slow)
        x = self.bn_fusion(x)
        x = F.relu_(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = self.fc(x)
        return x


## 6. Training Utilities


In [11]:
from torch.cuda.amp import autocast, GradScaler
scaler = GradScaler()

def train_epoch(model, dataloader, criterion, optimizer, device):
    """Train the model for one epoch."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(dataloader, desc="Training")
    for frames, labels in progress_bar:
        frames = frames.to(device)
        labels = labels.to(device)
        # print(frames.shape)
        # Zero gradients
        optimizer.zero_grad()

        # Forward pass
        with autocast():
          outputs = model(frames)
          loss = criterion(outputs, labels)
        # Backward pass

        scaler.scale(loss).backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        scaler.step(optimizer)
        scaler.update()
        # loss.backward()
        # optimizer.step()

        # Statistics
        running_loss += loss.item() * frames.size(0)
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        progress_bar.set_postfix({
            'loss': loss.item(),
            'acc': 100 * correct / total
        })

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc


def evaluate(model, dataloader, criterion, device):
    """Evaluate the model on a dataset."""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for frames, labels in tqdm(dataloader, desc="Evaluating"):
            frames = frames.to(device)
            labels = labels.to(device)

            with autocast():
                outputs = model(frames)
                loss = criterion(outputs, labels)

            running_loss += loss.item() * frames.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total

    return epoch_loss, epoch_acc


  scaler = GradScaler()


## 7. Configuration and Setup


In [22]:
# ============================================
# CONFIGURATION - Modify these paths and hyperparameters
# ============================================

# Data paths
JSON_PATH = "/content/drive/MyDrive/wlasl_resized/wlasl-complete/WLASL_v0.3.json"  # Path to WLASL JSON
VIDEO_DIR = "/content/drive/MyDrive/wlasl_resized/wlasl-complete/videos"  # Path to video directory

# Model hyperparameters
NUM_FRAMES = 16           # Number of frames to sample from each video
FEATURE_DIM = 512        # CNN feature dimension
HIDDEN_DIM = 256         # LSTM hidden dimension
NUM_LSTM_LAYERS = 2      # Number of LSTM layers
DROPOUT = 0.0            # Dropout rate

# Training hyperparameters
BATCH_SIZE = 12           # Batch size (adjust based on GPU memory)
LEARNING_RATE = 1e-4     # Learning rate
NUM_EPOCHS = 30          # Number of training epochs
WEIGHT_DECAY = 1e-4     # L2 regularization
IMG_SIZE=168
# Options
FREEZE_CNN = True       # Whether to freeze CNN backbone
PRETRAINED_CNN = True    # Use pretrained CNN weights
WORKERS = 6
EPOCHS_UNTIL_UNFREEZE = 50
CLASSES_COUNT = 50
PREFETCH = 3

In [23]:
from torchvision.transforms import v2
# Data transforms for training and validation
# train_transform = transforms.Compose([
#     v2.Resize((IMG_SIZE, IMG_SIZE)),
#     # v2.RandomHorizontalFlip(),
#     # v2.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
#     v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
# ])

# val_transform = transforms.Compose([
#     v2.Resize((IMG_SIZE, IMG_SIZE)),
#     v2.Compose([v2.ToImage(), v2.ToDtype(torch.float32, scale=True)]),
# ])

from torchvision.transforms import Compose
import pytorchvideo.transforms as ptv_transforms
from pytorchvideo.transforms import functional as ptv_functional

import torch
import torch.nn as nn
import torch.nn.functional as F

# Note: The transforms below expect the video tensor to be in the range [0.0, 1.0]
# and of shape (T, C, H, W). The `WLASLTorchCodec` implementation already ensures
# the shape is (T, C, H, W), but you must ensure the pixel values are converted
# to float and normalized to [0, 1] before applying the standard normalization.



mean = [0.45, 0.45, 0.45]
std = [0.225, 0.225, 0.225]

# Test out dataset
train_transform = Compose(
    [
        # 1. Spatial Resize: Scale the shortest edge to SIDE_SIZE
        ptv_transforms.UniformTemporalSubsample(num_samples=24, temporal_dim=0),
        ptv_transforms.ConvertUint8ToFloat(),
        Lambda(lambda x: x.permute(1, 0, 2, 3)),
        ptv_transforms.Normalize(mean, std),
        Lambda(lambda x: x.permute(1, 0, 2, 3)),
        ptv_transforms.ShortSideScale(size=224),
        ptv_transforms.RandAugment(magnitude=4, num_layers=2),
        # ptv_transforms.AugMix(magnitude=3),
    ]
)

# train_transform = Compose(
#     [
#         # 1. Spatial Resize: Scale the shortest edge to SIDE_SIZE
#         ptv_transforms.UniformTemporalSubsample(num_samples=NUM_FRAMES, temporal_dim=0),
#         ptv_transforms.ConvertUint8ToFloat(),
#         ptv_transforms.ShortSideScale(size=IMG_SIZE),
#         ptv_transforms.RandAugment(magnitude=15, num_layers=2),
#         ptv_transforms.AugMix(magnitude=3),
#     ]
# )

test_transform = Compose(
    [
        ptv_transforms.UniformTemporalSubsample(num_samples=24, temporal_dim=0),
        ptv_transforms.ConvertUint8ToFloat(),
        Lambda(lambda x: x.permute(1, 0, 2, 3)),
        ptv_transforms.Normalize(mean, std),
        Lambda(lambda x: x.permute(1, 0, 2, 3)),
        ptv_transforms.ShortSideScale(size=IMG_SIZE),
    ]
)
val_transform =test_transform

In [24]:
# Create datasets
train_dataset = WLASLTorchCodec(
    download=True,
    split="train",
    max_classes=CLASSES_COUNT,
    num_frames=NUM_FRAMES,
    transform=train_transform
)

val_dataset = WLASLTorchCodec(
    download=True,
    split="val",
    max_classes=CLASSES_COUNT,
    num_frames=NUM_FRAMES,
    transform=val_transform
)

test_dataset = WLASLTorchCodec(
    download=True,
    split="test",
    max_classes=CLASSES_COUNT,
    num_frames=NUM_FRAMES,
    transform=val_transform
)

# Get number of classes from dataset
NUM_CLASSES = train_dataset.num_classes

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")
print(f"Number of classes: {NUM_CLASSES}")

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=WORKERS,           # Start high. The optimal value is often 4 to 12.
    pin_memory=True,         # Accelerates the transfer of data from CPU to GPU VRAM.
    persistent_workers=True, # Recommended for PyTorch multi-process workers to save epoch setup time.
    prefetch_factor=PREFETCH
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=WORKERS,           # Start high. The optimal value is often 4 to 12.
                             # Since video decoding is CPU-heavy, 8 is a good starting point.
    pin_memory=True,         # Accelerates the transfer of data from CPU to GPU VRAM.
    persistent_workers=True, # Recommended for PyTorch multi-process workers to save epoch setup time.
    prefetch_factor=PREFETCH
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=WORKERS,           # Start high. The optimal value is often 4 to 12.
                             # Since video decoding is CPU-heavy, 8 is a good starting point.
    pin_memory=True,         # Accelerates the transfer of data from CPU to GPU VRAM.
    persistent_workers=True, # Recommended for PyTorch multi-process workers to save epoch setup time.
    prefetch_factor=PREFETCH
)

Will download: True
Downloaded at path:  /kaggle/input/wlasl2000-resized
Limiting dataset to top 50 classes.
Will download: True
Downloaded at path:  /kaggle/input/wlasl2000-resized
Limiting dataset to top 50 classes.
Will download: True
Downloaded at path:  /kaggle/input/wlasl2000-resized
Limiting dataset to top 50 classes.
Number of training samples: 785
Number of validation samples: 183
Number of test samples: 143
Number of classes: 50


In [25]:
# Initialize model
# model = Sign2TextModel(
#     num_classes=NUM_CLASSES,
#     feature_dim=FEATURE_DIM,
#     hidden_dim=HIDDEN_DIM,
#     num_lstm_layers=NUM_LSTM_LAYERS,
#     dropout=DROPOUT,
#     pretrained_cnn=PRETRAINED_CNN,
#     freeze_cnn=FREEZE_CNN,
#     max_frames=NUM_FRAMES
# ).to(device)

model = SlowFast(num_classes=NUM_CLASSES, dropout=DROPOUT).to(device)

# model = SignTimeSformer(
#     num_classes=NUM_CLASSES,
#     img_size=IMG_SIZE,
#     num_frames=NUM_FRAMES,
#     heads=12,
#     L=5,
#     dropout=DROPOUT
# ).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW( model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY )
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode='min', factor=0.5, patience=3 )

# Print model summary
# print(model)
print(f"\nTotal parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")



Total parameters: 45,481,138
Trainable parameters: 45,481,138


## 8. Training Loop


In [None]:
from typing import *
# Training loop
best_val_acc = 0.0
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': []
}

from torch.utils.flop_counter import FlopCounterMode

def get_flops(model, inp: Union[torch.Tensor, Tuple], with_backward=False):

    istrain = model.training
    model.eval()

    inp = inp if isinstance(inp, torch.Tensor) else torch.randn(inp)

    flop_counter = FlopCounterMode(mods=model, display=False, depth=None)
    with flop_counter:
        if with_backward:
            model(inp).sum().backward()
        else:
            model(inp)
    total_flops =  flop_counter.get_total_flops()
    if istrain:
        model.train()
    return total_flops

import torch
torch.cuda.empty_cache()

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch + 1}/{NUM_EPOCHS}")
    print("-" * 40)
    # if epoch > EPOCHS_UNTIL_UNFREEZE and FREEZE_CNN:
    #     model.set_freeze(False)
    # Train

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    # train_flops = get_flops(model, )
    # Validate
    val_loss, val_acc = evaluate(model, val_loader, criterion, device)

    # Update scheduler
    scheduler.step(val_loss)

    # Save history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_acc': val_acc,
            'label_map': train_dataset.label_map
        }, 'best_model.pth')
        print(f"✓ Saved new best model with Val Acc: {val_acc:.2f}%")

print(f"\nTraining complete! Best Val Acc: {best_val_acc:.2f}%")



Epoch 1/30
----------------------------------------


  with autocast():
Training: 100%|██████████| 66/66 [04:55<00:00,  4.48s/it, loss=4.02, acc=1.78]
  with autocast():
Evaluating: 100%|██████████| 16/16 [00:11<00:00,  1.39it/s]


Train Loss: 3.9655, Train Acc: 1.78%
Val Loss: 3.9505, Val Acc: 2.73%
✓ Saved new best model with Val Acc: 2.73%

Epoch 2/30
----------------------------------------


Training: 100%|██████████| 66/66 [04:52<00:00,  4.44s/it, loss=4.03, acc=2.55]
Evaluating: 100%|██████████| 16/16 [00:10<00:00,  1.48it/s]


Train Loss: 3.9399, Train Acc: 2.55%
Val Loss: 3.9272, Val Acc: 3.28%
✓ Saved new best model with Val Acc: 3.28%

Epoch 3/30
----------------------------------------


Training: 100%|██████████| 66/66 [04:52<00:00,  4.43s/it, loss=3.82, acc=3.31]
Evaluating: 100%|██████████| 16/16 [00:10<00:00,  1.51it/s]


Train Loss: 3.9274, Train Acc: 3.31%
Val Loss: 3.9735, Val Acc: 4.37%
✓ Saved new best model with Val Acc: 4.37%

Epoch 4/30
----------------------------------------


Training: 100%|██████████| 66/66 [04:52<00:00,  4.43s/it, loss=4.04, acc=4.59]
Evaluating: 100%|██████████| 16/16 [00:10<00:00,  1.48it/s]


Train Loss: 3.9087, Train Acc: 4.59%
Val Loss: 4.0103, Val Acc: 3.28%

Epoch 5/30
----------------------------------------


Training: 100%|██████████| 66/66 [04:52<00:00,  4.43s/it, loss=3.73, acc=4.2]
Evaluating: 100%|██████████| 16/16 [00:10<00:00,  1.47it/s]


Train Loss: 3.9067, Train Acc: 4.20%
Val Loss: 3.9655, Val Acc: 3.83%

Epoch 6/30
----------------------------------------


Training:  38%|███▊      | 25/66 [01:53<03:02,  4.44s/it, loss=3.74, acc=4.67]

## 9. Evaluation and Visualization


In [None]:
import matplotlib.pyplot as plt

# Plot training history
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss plot
axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history['val_loss'], label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True)

# Accuracy plot
axes[1].plot(history['train_acc'], label='Train Acc', marker='o')
axes[1].plot(history['val_acc'], label='Val Acc', marker='s')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy (%)')
axes[1].set_title('Training and Validation Accuracy')
axes[1].legend()
axes[1].grid(True)

plt.tight_layout()
plt.savefig('training_history.png', dpi=150)
plt.show()


In [None]:
# Load best model and evaluate on test set
checkpoint = torch.load('best_model.pth')
model.load_state_dict(checkpoint['model_state_dict'])

test_loss, test_acc = evaluate(model, test_loader, criterion, device)
print(f"\nTest Results:")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")


## 10. Attention Visualization

Visualize which frames the model attends to most when making predictions.


In [None]:
import matplotlib.pyplot as plt


def visualize_attention(model, frames, true_label, label_map, device):
    """Visualize attention weights over video frames."""
    model.eval()

    # Get reverse label map
    idx_to_label = {v: k for k, v in label_map.items()}

    with torch.no_grad():
        # Add batch dimension
        frames_batch = frames.unsqueeze(0).to(device)

        # Get predictions and attention weights
        logits, attention_weights = model(frames_batch, return_attention=True)
        pred_label = torch.argmax(logits, dim=1).item()
        attention = attention_weights[0].cpu().numpy()

    # Create visualization
    num_frames = frames.shape[0]
    fig, axes = plt.subplots(2, num_frames, figsize=(2 * num_frames, 6))

    # Denormalize frames for visualization
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    for i in range(num_frames):
        frame = frames[i].cpu()
        frame = frame * std + mean
        frame = frame.clamp(0, 1).permute(1, 2, 0).numpy()

        # Frame image
        axes[0, i].imshow(frame)
        axes[0, i].set_title(f"Frame {i+1}")
        axes[0, i].axis('off')

        # Attention weight bar
        axes[1, i].bar([0], [attention[i]], color='blue', alpha=0.7)
        axes[1, i].set_ylim(0, max(attention) * 1.2)
        axes[1, i].set_title(f"{attention[i]:.3f}")
        axes[1, i].axis('off')

    plt.suptitle(
        f"True: {idx_to_label.get(true_label, true_label)} | "
        f"Predicted: {idx_to_label.get(pred_label, pred_label)}",
        fontsize=14
    )
    plt.tight_layout()
    plt.show()

    return pred_label, attention


In [None]:
# Visualize attention for a sample from the test set
sample_idx = 0
frames, label = test_dataset[sample_idx]
pred, attn = visualize_attention(model, frames, label, train_dataset.label_map, device)


## 11. Inference Function


In [None]:
def predict_video(model, video_path, transform, num_frames, label_map, device):
    """Predict the sign language class for a video file."""
    model.eval()

    # Get reverse label map
    idx_to_label = {v: k for k, v in label_map.items()}

    # Decode video
    decoder = VideoDecoder(video_path)
    frames = []

    for chunk in decoder:
        for frame_tensor in chunk:
            if frame_tensor.dim() == 2:
                frame_tensor = frame_tensor.unsqueeze(2)
            frame_chw = frame_tensor.permute(2, 0, 1)
            frame_pil = transforms.ToPILImage()(frame_chw)
            frames.append(frame_pil)

    # Handle short videos
    while len(frames) < num_frames:
        frames.extend(frames)

    # Sample frames
    T = len(frames)
    idx = torch.linspace(0, T - 1, num_frames).long()
    frames = [frames[i] for i in idx]

    # Apply transforms
    frames = torch.stack([transform(f) for f in frames])

    # Predict
    with torch.no_grad():
        frames_batch = frames.unsqueeze(0).to(device)
        logits, attention = model(frames_batch, return_attention=True)
        probabilities = F.softmax(logits, dim=1)
        pred_idx = torch.argmax(logits, dim=1).item()
        confidence = probabilities[0, pred_idx].item()

    predicted_label = idx_to_label.get(pred_idx, f"Unknown ({pred_idx})")

    return {
        'prediction': predicted_label,
        'confidence': confidence,
        'attention_weights': attention[0].cpu().numpy(),
        'all_probabilities': probabilities[0].cpu().numpy()
    }


In [None]:
# Example inference (uncomment and modify path to use)
# result = predict_video(
#     model=model,
#     video_path="/path/to/your/video.mp4",
#     transform=val_transform,
#     num_frames=NUM_FRAMES,
#     label_map=train_dataset.label_map,
#     device=device
# )
# print(f"Prediction: {result['prediction']}")
# print(f"Confidence: {result['confidence']:.2%}")


## 12. Save Final Model


In [None]:
# Save complete model for deployment
torch.save({
    'model_state_dict': model.state_dict(),
    'label_map': train_dataset.label_map,
    'config': {
        'num_classes': NUM_CLASSES,
        'feature_dim': FEATURE_DIM,
        'hidden_dim': HIDDEN_DIM,
        'num_lstm_layers': NUM_LSTM_LAYERS,
        'num_frames': NUM_FRAMES,
        'dropout': DROPOUT
    }
}, 'sign2text_model_final.pth')

print("Model saved to sign2text_model_final.pth")


---
**Note:** The cells above contain the complete implementation. Make sure to run them in order from top to bottom.


In [None]:
# PositionalEncoding class is defined below cell 7 - this cell can be ignored
# The model requires running cells in sequential order
