# Six Seven Gesture Recognition Training

Two-head model on a shared video backbone: a binary detection head (is the gesture present?) and a multi-class counting head (how many gesture cycles?).

In [None]:
!pip install -q pytorchvideo av

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from pathlib import Path
import av
import random
from tqdm.notebook import tqdm

In [None]:
# Hyperparameters shared by data loading, optimisation and the model heads.
CONFIG = dict(
    frames_per_clip=48,          # frames sampled from each video
    spatial_size=224,            # side length of the square frames fed to the model
    batch_size=4,
    learning_rate=3e-4,
    weight_decay=0.01,
    epochs=20,
    early_stopping_patience=5,   # epochs without val-loss improvement before stopping
    dropout=0.3,
    num_count_classes=12,        # number of classes in the counting head
)

DATA_DIR = Path('/kaggle/input/sixseven-gesture')
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f'Device: {device}')

In [None]:
class VideoTransform:
    """Resize, crop, augment and normalise a video clip tensor.

    Input is a float tensor of shape ``(C, T, H, W)`` with ``C == 3`` (the
    mean/std buffers have three channels). Values are expected in ``[0, 1]``;
    the brightness jitter clamps to that range. The output has shape
    ``(C, T, spatial_size, spatial_size)``. It is normalised per channel with
    the standard ImageNet mean/std values.

    The clip is first resized so its short side equals ``spatial_size``, with
    the aspect ratio preserved. In training mode the transform then takes a
    random square crop, applies a horizontal flip with probability 0.5, and
    scales brightness by a random factor in ``[0.8, 1.2]``. In evaluation mode
    it takes a centre crop instead.

    Args:
        spatial_size: Side length of the square output frames.
        training: Enable the random crop/flip/brightness augmentation.
    """

    def __init__(self, spatial_size=224, training=True):
        self.spatial_size = spatial_size
        self.training = training
        # Shaped (3, 1, 1, 1) so they broadcast over (C, T, H, W).
        self.mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1, 1)
        self.std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1, 1)

    @staticmethod
    def _resize(frames, new_h, new_w):
        """Bilinearly resize every (channel, frame) plane of a (C, T, H, W) tensor."""
        c, t, h, w = frames.shape
        # reshape, not view: callers may pass a non-contiguous tensor (e.g. a
        # permuted (T, H, W, C) array), on which view() raises.
        planes = frames.reshape(c * t, 1, h, w)
        planes = F.interpolate(planes, size=(new_h, new_w), mode='bilinear', align_corners=False)
        return planes.reshape(c, t, new_h, new_w)

    def __call__(self, frames):
        size = self.spatial_size
        h, w = frames.shape[2], frames.shape[3]
        scale = size / min(h, w)
        # round() + max() guarantee both sides are >= size. Plain int()
        # truncation could leave the short side one pixel short.
        new_h = max(size, round(h * scale))
        new_w = max(size, round(w * scale))
        frames = self._resize(frames, new_h, new_w)

        if self.training:
            # Both sides are >= size here, so the crop always applies. A strict
            # ">" test on both sides is never true after the short-side
            # resize, which would squash training clips instead of cropping
            # them like the eval path does.
            top = random.randint(0, new_h - size)
            left = random.randint(0, new_w - size)
            frames = frames[:, :, top:top + size, left:left + size]
            if random.random() < 0.5:
                frames = frames.flip(dims=[3])  # horizontal flip (width axis)
            brightness = random.uniform(0.8, 1.2)
            frames = (frames * brightness).clamp(0, 1)
        else:
            top = (new_h - size) // 2
            left = (new_w - size) // 2
            frames = frames[:, :, top:top + size, left:left + size]

        return (frames - self.mean) / self.std

In [None]:
class GestureDataset(Dataset):
    """Video-clip dataset backed by an annotations CSV.

    Each CSV row must provide ``video_path`` (relative to ``root_dir``),
    ``is_gesture`` and ``cycle_count``. Items are
    ``(frames, is_gesture, cycle_count)``:

    * ``frames``: float tensor of shape ``(3, num_frames, H, W)`` with values
      in ``[0, 1]``, before ``transform`` (if any) is applied.
    * ``is_gesture``: float32 scalar tensor.
    * ``cycle_count``: int64 scalar tensor clipped to ``[0, max_count]``.

    Args:
        annotations_csv: Path to the annotations CSV.
        root_dir: Directory that the ``video_path`` column is relative to.
        num_frames: Frames sampled uniformly per clip. Shorter clips are
            padded by repeating their last frame.
        transform: Optional callable applied to the frame tensor.
        split: ``None`` (or empty) for all rows, or ``'train'`` / ``'val'`` /
            ``'test'`` for a deterministic 70/15/15 shuffled split (seed 42).
        max_count: Largest cycle-count class; larger counts are clipped to it.
            Should equal the counting head's number of classes minus one.

    Raises:
        ValueError: If ``split`` is non-empty and not a supported split name.
    """

    # Fractional [start, end) bounds of each split within the shuffled rows.
    _SPLIT_BOUNDS = {
        'train': (0.0, 0.7),
        'val': (0.7, 0.85),
        'test': (0.85, 1.0),
    }

    def __init__(self, annotations_csv, root_dir, num_frames=48, transform=None, split=None,
                 max_count=11):
        if split and split not in self._SPLIT_BOUNDS:
            raise ValueError(f"split must be 'train', 'val', 'test' or None, got {split!r}")
        self.root_dir = Path(root_dir)
        self.num_frames = num_frames
        self.transform = transform
        self.max_count = max_count
        df = pd.read_csv(annotations_csv)

        if split:
            n = len(df)
            # A private RandomState(42) yields the same permutation as
            # np.random.seed(42) followed by np.random.permutation(), so the
            # splits stay the same, and the global NumPy RNG is not reseeded.
            order = np.random.RandomState(42).permutation(n)
            start, end = self._SPLIT_BOUNDS[split]
            df = df.iloc[order[int(start * n):int(end * n)]]

        self.annotations = df.reset_index(drop=True)

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        row = self.annotations.iloc[idx]
        video_path = self.root_dir / row['video_path']

        is_gesture = torch.tensor(row['is_gesture'], dtype=torch.float32)
        raw_count = row['cycle_count']
        # NOTE(review): a missing count (e.g. on non-gesture rows) is treated
        # as 0; confirm against the annotation schema.
        count = 0 if pd.isna(raw_count) else int(raw_count)
        cycle_count = torch.tensor(min(max(count, 0), self.max_count), dtype=torch.long)

        # The context manager closes the container even if decoding raises.
        with av.open(str(video_path)) as container:
            frames_list = list(container.decode(video=0))
            total_frames = len(frames_list)
            if total_frames <= self.num_frames:
                indices = range(total_frames)
            else:
                indices = np.linspace(0, total_frames - 1, self.num_frames, dtype=int)
            selected = [frames_list[i].to_ndarray(format='rgb24') for i in indices]

        # Pad short (or empty) clips up to num_frames.
        while len(selected) < self.num_frames:
            selected.append(selected[-1] if selected else np.zeros((224, 224, 3), dtype=np.uint8))

        # (T, H, W, C) -> (C, T, H, W). permute() returns a non-contiguous
        # view, so make it contiguous for downstream view()/reshape callers.
        frames = torch.from_numpy(np.stack(selected)).permute(3, 0, 1, 2).contiguous().float() / 255.0

        if self.transform:
            frames = self.transform(frames)

        return frames, is_gesture, cycle_count

In [None]:
class TwoHeadGestureModel(nn.Module):
    """Video model with a shared backbone and two task-specific heads.

    The backbone is ``slow_r50`` loaded from the
    ``facebookresearch/pytorchvideo`` hub repository. Its final ``proj``
    layer is replaced by an identity, so it returns the features that would
    have fed that classifier.

    Heads:
        * ``detection_head``: dropout, linear layer, sigmoid. Outputs one
          probability per clip.
        * ``counting_head``: dropout, linear layer. Outputs
          ``num_count_classes`` raw logits per clip.

    Args:
        num_count_classes: Output size of the counting head.
        pretrained: Forwarded to ``torch.hub.load`` for the backbone weights.
        dropout: Dropout probability applied before each head's linear layer.
    """

    def __init__(self, num_count_classes=12, pretrained=True, dropout=0.3):
        super().__init__()
        backbone = torch.hub.load('facebookresearch/pytorchvideo', 'slow_r50', pretrained=pretrained)
        last_block = backbone.blocks[-1]
        feat_dim = last_block.proj.in_features
        # Drop the original classifier so the backbone emits features.
        last_block.proj = nn.Identity()
        self.backbone = backbone

        self.detection_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(feat_dim, 1),
            nn.Sigmoid(),
        )
        self.counting_head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(feat_dim, num_count_classes),
        )

    def forward(self, x):
        """Return ``(detection_probability, count_logits)`` for a batch of clips."""
        feats = self.backbone(x)
        # Average away any trailing (dim 2, 3, 4) axes the backbone left unpooled.
        feats = feats.mean(dim=[2, 3, 4]) if feats.dim() > 2 else feats
        prob = self.detection_head(feats).squeeze(-1)
        logits = self.counting_head(feats)
        return prob, logits

In [None]:
# Training clips use the augmenting transform; validation clips the deterministic one.
train_transform = VideoTransform(spatial_size=CONFIG['spatial_size'], training=True)
val_transform = VideoTransform(spatial_size=CONFIG['spatial_size'], training=False)

# Both splits are carved out of the same annotations CSV by GestureDataset.
train_dataset, val_dataset = (
    GestureDataset(
        annotations_csv=str(DATA_DIR / 'annotations.csv'),
        root_dir=str(DATA_DIR),
        num_frames=CONFIG['frames_per_clip'],
        transform=transform,
        split=split,
    )
    for split, transform in (('train', train_transform), ('val', val_transform))
)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=CONFIG['batch_size'], num_workers=2)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=CONFIG['batch_size'], num_workers=2)

print(f'Train: {len(train_dataset)}, Val: {len(val_dataset)}')

In [None]:
model = TwoHeadGestureModel(
    num_count_classes=CONFIG['num_count_classes'],
    pretrained=True,
    dropout=CONFIG['dropout'],
)
model = model.to(device)

# The detection head already ends in a sigmoid, so the matching criterion is
# plain BCELoss on probabilities (not BCEWithLogitsLoss).
bce_loss = nn.BCELoss()
# The counting head emits raw logits.
ce_loss = nn.CrossEntropyLoss()

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=CONFIG['learning_rate'],
    weight_decay=CONFIG['weight_decay'],
)
# T_max equals the epoch count; the training loop steps this once per epoch.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG['epochs'])

In [None]:
def train_epoch(model, loader):
    """Run one optimisation pass of ``model`` over ``loader``.

    Relies on the module-level ``optimizer``, ``bce_loss``, ``ce_loss`` and
    ``device``. The per-batch loss is the detection BCE plus, when the batch
    contains clips with ``is_gesture == 1``, the counting cross-entropy
    computed on those clips only.

    Returns:
        ``(mean_batch_loss, detection_accuracy, count_within_one_accuracy)``.
        Detection accuracy thresholds the probability at 0.5. Counting
        accuracy is the fraction of positive clips whose argmax count is
        within 1 of the label, or 0 when no positive clips were seen.
    """
    model.train()
    loss_sum = 0
    n_seen, n_det_ok = 0, 0
    n_pos, n_count_ok = 0, 0

    for frames, is_gesture, cycle_count in tqdm(loader, desc='Train'):
        frames = frames.to(device)
        is_gesture = is_gesture.to(device)
        cycle_count = cycle_count.to(device)

        optimizer.zero_grad()
        detection, count_logits = model(frames)

        pos = is_gesture == 1
        has_pos = bool(pos.any())
        detection_loss = bce_loss(detection, is_gesture)
        counting_loss = (
            ce_loss(count_logits[pos], cycle_count[pos])
            if has_pos
            else torch.tensor(0.0, device=device)
        )
        loss = detection_loss + counting_loss
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        n_det_ok += ((detection > 0.5).float() == is_gesture).sum().item()
        n_seen += is_gesture.size(0)

        if has_pos:
            within_one = torch.abs(count_logits[pos].argmax(dim=-1) - cycle_count[pos]) <= 1
            n_count_ok += within_one.sum().item()
            n_pos += pos.sum().item()

    count_acc = n_count_ok / n_pos if n_pos > 0 else 0
    return loss_sum / len(loader), n_det_ok / n_seen, count_acc

def validate(model, loader):
    """Evaluate ``model`` on ``loader`` in eval mode without gradients.

    Relies on the module-level ``bce_loss``, ``ce_loss`` and ``device``. The
    per-batch loss is the detection BCE plus, when the batch contains clips
    with ``is_gesture == 1``, the counting cross-entropy on those clips only.

    Returns:
        ``(mean_batch_loss, detection_accuracy, count_within_one_accuracy)``.
        Detection accuracy thresholds the probability at 0.5. Counting
        accuracy is the fraction of positive clips whose argmax count is
        within 1 of the label, or 0 when no positive clips were seen.
    """
    model.eval()
    loss_sum = 0
    n_seen, n_det_ok = 0, 0
    n_pos, n_count_ok = 0, 0

    with torch.no_grad():
        for frames, is_gesture, cycle_count in tqdm(loader, desc='Val'):
            frames = frames.to(device)
            is_gesture = is_gesture.to(device)
            cycle_count = cycle_count.to(device)

            detection, count_logits = model(frames)

            pos = is_gesture == 1
            has_pos = bool(pos.any())
            detection_loss = bce_loss(detection, is_gesture)
            if has_pos:
                counting_loss = ce_loss(count_logits[pos], cycle_count[pos])
            else:
                counting_loss = torch.tensor(0.0, device=device)
            loss_sum += (detection_loss + counting_loss).item()

            n_det_ok += ((detection > 0.5).float() == is_gesture).sum().item()
            n_seen += is_gesture.size(0)

            if has_pos:
                within_one = torch.abs(count_logits[pos].argmax(dim=-1) - cycle_count[pos]) <= 1
                n_count_ok += within_one.sum().item()
                n_pos += pos.sum().item()

    count_acc = n_count_ok / n_pos if n_pos > 0 else 0
    return loss_sum / len(loader), n_det_ok / n_seen, count_acc

In [None]:
# Train with early stopping on validation loss. Whenever validation loss
# improves, the weights are checkpointed to /kaggle/working/best_model.pth.
best_val_loss = float('inf')
patience_counter = 0  # consecutive epochs without a val-loss improvement

for epoch in range(CONFIG['epochs']):
    print(f'\nEpoch {epoch + 1}/{CONFIG["epochs"]}')
    train_loss, train_det_acc, train_count_acc = train_epoch(model, train_loader)
    val_loss, val_det_acc, val_count_acc = validate(model, val_loader)
    scheduler.step()  # learning-rate schedule advances once per epoch

    print(f'Train Loss: {train_loss:.4f}, Det Acc: {train_det_acc:.4f}, Count ±1: {train_count_acc:.4f}')
    print(f'Val Loss: {val_loss:.4f}, Det Acc: {val_det_acc:.4f}, Count ±1: {val_count_acc:.4f}')

    if val_loss < best_val_loss:
        best_val_loss, patience_counter = val_loss, 0
        torch.save(model.state_dict(), '/kaggle/working/best_model.pth')
        print('Saved best model')
        continue

    patience_counter += 1
    if patience_counter >= CONFIG['early_stopping_patience']:
        print('Early stopping')
        break

In [None]:
# Restore the checkpoint with the lowest validation loss, then export it as TorchScript.
model.load_state_dict(state_dict=torch.load('/kaggle/working/best_model.pth'))
model.eval()  # switch dropout (and any batch-norm layers) to inference mode before tracing

# Tracing records the ops run on one concrete clip of shape (1, 3, T, H, W).
example_input = torch.randn(
    1, 3, CONFIG['frames_per_clip'], CONFIG['spatial_size'], CONFIG['spatial_size']
).to(device)
traced = torch.jit.trace(model, example_inputs=example_input)
traced.save('/kaggle/working/model_traced.pt')
print('Saved traced model')

In [None]:
# Held-out test split, read with the deterministic evaluation transform.
test_dataset = GestureDataset(
    annotations_csv=str(DATA_DIR / 'annotations.csv'),
    root_dir=str(DATA_DIR),
    num_frames=CONFIG['frames_per_clip'],
    transform=val_transform,
    split='test',
)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=CONFIG['batch_size'], num_workers=2)

# Detection decisions/labels are collected for every clip. Count predictions
# and labels are collected only for clips labelled as gestures.
det_preds, det_labels = [], []
count_preds, count_labels = [], []

model.eval()
with torch.no_grad():
    for frames, is_gesture, cycle_count in test_loader:
        frames = frames.to(device)
        detection, count_logits = model(frames)

        det_labels.extend(is_gesture.numpy())
        det_preds.extend((detection > 0.5).float().cpu().numpy())

        positive_mask = is_gesture == 1
        if not positive_mask.any():
            continue
        count_labels.extend(cycle_count[positive_mask].numpy())
        count_preds.extend(count_logits[positive_mask].argmax(dim=-1).cpu().numpy())

det_preds = np.array(det_preds)
det_labels = np.array(det_labels)
count_preds = np.array(count_preds)
count_labels = np.array(count_labels)

det_acc = np.mean(det_preds == det_labels)
if (det_labels == 0).sum() > 0:
    # Fraction of non-gesture clips that were flagged as gestures.
    fpr = det_preds[det_labels == 0].sum() / (det_labels == 0).sum()
else:
    fpr = 0

if len(count_preds) > 0:
    count_acc = (np.abs(count_preds - count_labels) <= 1).mean()  # within 1 cycle
    mae = np.abs(count_preds - count_labels).mean()
else:
    count_acc = mae = 0

print(f'Detection Accuracy: {det_acc:.4f}')
print(f'False Positive Rate: {fpr:.4f}')
print(f'Counting ±1 Accuracy: {count_acc:.4f}')
print(f'Counting MAE: {mae:.4f}')