In [21]:
pip install pytorch-tcn torchvision pillow

Note: you may need to restart the kernel to use updated packages.


In [24]:
# Updated dataset for annotation CSV format (each split folder contains _annotations.csv + images)
import os
import csv
from typing import List, Tuple, Dict
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# ---- Hyperparameters (tune as needed) ----
# Task / data
NUM_CLASSES = 3        # Rock, Paper, Scissors
SEQ_LEN = 8            # frames per (pseudo) sequence fed to the model
FRAME_SIZE = 128       # frames are resized to FRAME_SIZE x FRAME_SIZE before encoding

# Model
FEATURE_DIM = 64       # per-frame embedding width; used as the TCN's num_inputs
CHANNELS = [64] * 2    # TCN hidden channels (passed as num_channels)
KERNEL_SIZE = 3
DROPOUT = 0.1

# Early-emphasis loss weighting
EARLY_WEIGHT_MODE = 'linear'   # 'linear' or 'exp'; any other value gives uniform weights
ALPHA_EXP = 2.0                # decay rate, only read when EARLY_WEIGHT_MODE == 'exp'

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Label ids keyed by the class names exactly as written in the CSV 'class' column
CLASS_MAP = dict(Rock=0, Paper=1, Scissors=2)

# Training-time augmentation. The random transforms resample their parameters on
# every call, so applying this pipeline repeatedly to one crop yields a sequence
# of slightly different frames (a stand-in for temporal evolution).
base_transform = transforms.Compose(
    [
        transforms.Resize((FRAME_SIZE, FRAME_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.25, hue=0.05),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
    ]
)

# Deterministic pipeline for evaluation: resize + tensor conversion only.
noaug_transform = transforms.Compose(
    [
        transforms.Resize((FRAME_SIZE, FRAME_SIZE)),
        transforms.ToTensor(),
    ]
)

class AnnotatedGestureDataset(Dataset):
    """Gesture dataset backed by a Roboflow-style ``_annotations.csv``.

    The CSV is read with ``csv.DictReader``. The columns
    ``filename,width,height,class,xmin,ymin,xmax,ymax`` are expected; only
    ``filename``, ``class`` and the four box columns are used. Rows whose class
    is not a key of ``CLASS_MAP`` are skipped. When an image has several boxes,
    the one with the largest area is kept, and the first one wins on ties.
    Images that do not exist under ``root`` are dropped.

    Each item is a pseudo-sequence. The box region is passed through
    ``self.transform`` ``seq_len`` times and the results are stacked, giving a
    ``(T, C, H, W)`` tensor plus the integer label. With ``augment=True`` each
    frame gets its own random augmentation.

    Raises:
        FileNotFoundError: if the annotations file does not exist.
        RuntimeError: if no usable samples were parsed.
    """

    def __init__(self, root: str, seq_len: int = SEQ_LEN, augment: bool = True, annotations_file: str = '_annotations.csv'):
        self.root = root
        self.seq_len = seq_len
        self.augment = augment
        self.annotations_path = os.path.join(root, annotations_file)
        self.transform = base_transform if augment else noaug_transform
        # Each entry is (image_path, label, (xmin, ymin, xmax, ymax))
        self.samples: List[Tuple[str, int, Tuple[int, int, int, int]]] = []
        if not os.path.isfile(self.annotations_path):
            raise FileNotFoundError(f'Annotations file not found: {self.annotations_path}')
        # Maps filename -> list of (area, xmin, ymin, xmax, ymax, label)
        boxes_per_file: Dict[str, List[Tuple[int, int, int, int, int, int]]] = {}
        # utf-8-sig decodes UTF-8 regardless of platform locale and also strips a BOM
        # if one is present. Otherwise the BOM would corrupt the 'filename' header.
        with open(self.annotations_path, 'r', newline='', encoding='utf-8-sig') as f:
            for row in csv.DictReader(f):
                fname = row['filename']
                cls = row['class']
                if cls not in CLASS_MAP:
                    continue
                xmin, ymin, xmax, ymax = (int(float(row[k])) for k in ('xmin', 'ymin', 'xmax', 'ymax'))
                area = (xmax - xmin) * (ymax - ymin)
                boxes_per_file.setdefault(fname, []).append((area, xmin, ymin, xmax, ymax, CLASS_MAP[cls]))
        for fname, box_list in boxes_per_file.items():
            # Keep the largest box. max() returns the first of several equal-area
            # boxes, the same choice a stable descending sort would make.
            _, xmin, ymin, xmax, ymax, label = max(box_list, key=lambda box: box[0])
            img_path = os.path.join(root, fname)
            if os.path.isfile(img_path):
                self.samples.append((img_path, label, (xmin, ymin, xmax, ymax)))
        if not self.samples:
            raise RuntimeError(f'No samples parsed from {self.annotations_path}. Check class names or paths.')

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx: int):
        """Return ``(seq, label)`` where ``seq`` stacks ``seq_len`` transformed crops as (T, C, H, W)."""
        path, label, (xmin, ymin, xmax, ymax) = self.samples[idx]
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        # Clamp the box to the image bounds.
        xmin, ymin = max(0, xmin), max(0, ymin)
        xmax, ymax = min(img.width, xmax), min(img.height, ymax)
        if xmax > xmin and ymax > ymin:
            crop = img.crop((xmin, ymin, xmax, ymax))
        else:
            # Degenerate or fully out-of-bounds box. An empty crop is not a usable
            # frame, so fall back to the whole image instead of failing in the transform.
            crop = img
        seq = torch.stack([self.transform(crop) for _ in range(self.seq_len)], dim=0)  # (T,C,H,W)
        return seq, label

class FrameEncoder(nn.Module):
    """Lightweight CNN that embeds each frame of a batch into an ``out_dim`` vector.

    The network has three 3x3, stride-2 convolutions (channels 3 -> 32 -> 64 -> 64),
    each followed by ReLU. These are followed by global average pooling and a
    linear projection.
    Input: (N, 3, H, W) image batch. Output: (N, out_dim).
    """

    def __init__(self, out_dim=FEATURE_DIM):
        super().__init__()
        widths = (3, 32, 64, 64)
        layers = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Conv2d(c_in, c_out, 3, stride=2, padding=1))
            layers.append(nn.ReLU())
        layers.append(nn.AdaptiveAvgPool2d(1))  # -> (N, 64, 1, 1)
        self.net = nn.Sequential(*layers)
        self.proj = nn.Linear(widths[-1], out_dim)

    def forward(self, x):
        pooled = self.net(x)
        flat = pooled.view(x.size(0), -1)  # (N, 64)
        return self.proj(flat)

class EarlyGestureTCN(nn.Module):
    """Frame encoder plus causal TCN that produces class logits at every timestep.

    Each frame is embedded independently by ``FrameEncoder``. The resulting
    (B, F, T) feature sequence goes through a pytorch-tcn ``TCN`` built with
    ``causal=True``. A shared linear head then maps each timestep's TCN output
    to ``num_classes`` logits. With a causal TCN the logits at step t should
    depend only on frames 0..t, which is what early prediction relies on.

    Note: ``channels`` defaults to the module-level ``CHANNELS`` list object and
    is passed straight to the TCN, so do not mutate it in place.
    """

    def __init__(self, feature_dim=FEATURE_DIM, channels=CHANNELS, kernel_size=KERNEL_SIZE, dropout=DROPOUT, num_classes=NUM_CLASSES):
        super().__init__()
        self.encoder = FrameEncoder(feature_dim)
        # Imported here rather than at the top of the file; requires the pytorch-tcn package.
        from pytorch_tcn import TCN as LibTCN
        self.tcn = LibTCN(
            num_inputs=feature_dim,
            num_channels=channels,
            kernel_size=kernel_size,
            dropout=dropout,
            causal=True,
            use_norm='weight_norm',
            activation='relu',
            input_shape='NCL'
        )
        self.classifier = nn.Linear(channels[-1], num_classes)

    def forward(self, seq_frames):
        """Map (B, T, C, H, W) frames to (B, T, num_classes) per-timestep logits."""
        B, T, C, H, W = seq_frames.shape
        # reshape (not view) so that non-contiguous inputs, e.g. permuted tensors, also work.
        frames = seq_frames.reshape(B * T, C, H, W)
        feats = self.encoder(frames)                   # (B*T, F)
        feats = feats.view(B, T, -1).transpose(1, 2)   # (B, F, T): 'NCL' layout for the TCN
        tcn_out = self.tcn(feats)                      # (B, channels[-1], T)
        logits_time = self.classifier(tcn_out.transpose(1, 2))  # (B, T, num_classes)
        return logits_time

    def predict_early(self, seq_frames, thresh: float = 0.6):
        """Return the earliest confident prediction for each sequence in the batch.

        For each sequence, the first timestep whose top softmax probability is
        >= ``thresh`` is selected. If no timestep qualifies, the last timestep is
        used.

        Runs under ``torch.no_grad()`` but does not switch the module to eval mode.
        Call ``.eval()`` first if dropout should be disabled.

        Returns:
            (preds, probs_time). ``preds`` is a list with one
            ``(timestep, class_index, confidence)`` tuple per sequence; the
            timestep is 0-based. ``probs_time`` is the (B, T, num_classes)
            softmax tensor.
        """
        with torch.no_grad():
            logits_time = self.forward(seq_frames)
            probs_time = torch.softmax(logits_time, dim=-1)
            # Top-class probability and index at every step, both (B, T).
            conf, cls = probs_time.max(dim=-1)
            confident = conf >= thresh
            # argmax returns the first maximal index, i.e. the earliest confident step.
            first_hit = confident.float().argmax(dim=1)
            fallback = torch.full_like(first_hit, probs_time.size(1) - 1)
            steps = torch.where(confident.any(dim=1), first_hit, fallback)  # (B,)
            idx = steps.unsqueeze(1)
            picked_cls = cls.gather(1, idx).squeeze(1)
            picked_conf = conf.gather(1, idx).squeeze(1)
            # A single host transfer per output instead of several .item() calls per timestep.
            preds = list(zip(steps.tolist(), picked_cls.tolist(), picked_conf.tolist()))
        return preds, probs_time

# Time weighting for early emphasis
def time_weights(T: int, mode: str = EARLY_WEIGHT_MODE, alpha: float = ALPHA_EXP, device=DEVICE):
    """Per-timestep loss weights that favour early frames.

    * ``'linear'``: evenly spaced values from 1.0 down to 0.3.
    * ``'exp'``:    exp(-alpha * t / T) for t = 0..T-1.
    * any other mode: all ones.

    The weights are rescaled to sum to ``T`` (mean 1) and returned on ``device``.
    """
    if mode == 'exp':
        steps = torch.arange(T)
        raw = torch.exp(-alpha * steps / T)
    elif mode == 'linear':
        raw = torch.linspace(1.0, 0.3, steps=T)
    else:
        raw = torch.ones(T)
    normalised = raw / raw.sum() * T
    return normalised.to(device)

# Early loss function
def compute_loss(logits_time, labels):
    """Time-weighted cross-entropy applied at every timestep of each sequence.

    Args:
        logits_time: (B, T, C) per-timestep class logits.
        labels: (B,) integer class labels. Each sequence's label is used as the
            target at all of its timesteps.

    Returns:
        A tuple ``(loss, per_step_loss, weights)``:
        - ``loss``: scalar mean of the weighted per-frame losses.
        - ``per_step_loss``: (T,) unweighted loss averaged over the batch, detached.
        - ``weights``: the (T,) weights used, detached.
    """
    B, T, C = logits_time.shape
    # Build the weights on the logits' device rather than the global DEVICE. The loss
    # then also works for tensors that live elsewhere (e.g. CPU evaluation on a CUDA box).
    weights = time_weights(T, device=logits_time.device)
    labels_time = labels.unsqueeze(1).expand(B, T).reshape(B * T)
    # reshape (not view) tolerates non-contiguous logits.
    per_frame_loss = nn.functional.cross_entropy(
        logits_time.reshape(B * T, C), labels_time, reduction='none'
    ).view(B, T)
    weighted = per_frame_loss * weights
    return weighted.mean(), per_frame_loss.mean(dim=0).detach(), weights.detach()

# Build the early-prediction model on DEVICE. Re-running this cell replaces it with a freshly initialised instance.
model_early = EarlyGestureTCN().to(DEVICE)
param_count = sum(param.numel() for param in model_early.parameters())
print('Model (updated) parameters:', param_count)

Model (updated) parameters: 110339


In [25]:
# Training loop adapted for annotation CSV dataset
import math, time
from torch.optim import Adam

# Dataset split folders, resolved relative to the current working directory.
DATA_ROOT = os.path.join(os.getcwd(), 'Rock Paper Scissors SXSW.v14i.tensorflow', 'train')
VAL_ROOT = os.path.join(os.getcwd(), 'Rock Paper Scissors SXSW.v14i.tensorflow', 'valid')

train_ds = AnnotatedGestureDataset(DATA_ROOT, seq_len=SEQ_LEN, augment=True)
val_ds   = AnnotatedGestureDataset(VAL_ROOT,   seq_len=SEQ_LEN, augment=False)
print('Train samples:', len(train_ds), 'Val samples:', len(val_ds))

BATCH_SIZE = 16
LR = 1e-3
EPOCHS = 5

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

optimizer = Adam(model_early.parameters(), lr=LR)
best_val_acc = 0.0

for epoch in range(1, EPOCHS+1):
    # ---- Training ----
    model_early.train()
    total_loss = 0.0
    batches = 0
    for seq, label in train_loader:
        seq = seq.to(DEVICE)
        label = label.to(DEVICE)
        optimizer.zero_grad()
        logits_time = model_early(seq)
        loss, _, _ = compute_loss(logits_time, label)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        batches += 1
    avg_loss = total_loss / max(1, batches)

    # ---- Validation ----
    # Metrics:
    # - final_acc: accuracy at the last timestep.
    # - early_cov: fraction of sequences whose prediction matches the label at any timestep.
    # - avg_time_first: mean 1-based step of the first correct prediction, over covered sequences.
    model_early.eval()
    correct_final = 0
    correct_early = 0
    total = 0
    time_to_first_correct = []
    with torch.no_grad():
        for seq, label in val_loader:
            seq = seq.to(DEVICE)
            label = label.to(DEVICE)
            logits_time = model_early(seq)
            probs_time = torch.softmax(logits_time, dim=-1)  # (B,T,C)
            preds_time = probs_time.argmax(dim=-1)           # (B,T)
            correct_final += (preds_time[:, -1] == label).sum().item()
            # hits[b, t] is True when the prediction at step t matches the label.
            # Vectorised, so there is no per-sample / per-timestep .item() sync.
            hits = preds_time == label.unsqueeze(1)          # (B,T)
            covered = hits.any(dim=1)                        # (B,)
            # argmax returns the first maximal index, i.e. the earliest correct step.
            first_correct = hits.float().argmax(dim=1)       # (B,)
            correct_early += int(covered.sum().item())
            time_to_first_correct.extend((first_correct[covered] + 1).tolist())
            total += seq.size(0)
    final_acc = correct_final / max(1, total)
    early_cov = correct_early / max(1, total)
    avg_time_first = sum(time_to_first_correct)/len(time_to_first_correct) if time_to_first_correct else math.nan

    print(f"Epoch {epoch}: train_loss={avg_loss:.4f} val_final_acc={final_acc:.3f} early_cover={early_cov:.3f} avg_first_correct_t={avg_time_first}")

    # Checkpoint on strict improvement of last-timestep validation accuracy.
    if final_acc > best_val_acc:
        best_val_acc = final_acc
        torch.save({'model':model_early.state_dict()}, 'best_early_tcn.pt')
        print('Saved new best model.')

print('Training complete. Best final acc:', best_val_acc)

Train samples: 3939 Val samples: 338


KeyboardInterrupt: 

In [None]:
# Streaming / real-time inference utilities
import collections
from torchvision import transforms as T

# Inference-time preprocessing: deterministic resize and tensor conversion, with no
# augmentation. This is the same pipeline as noaug_transform used for validation.
stream_transform = T.Compose(
    [T.Resize((FRAME_SIZE, FRAME_SIZE)), T.ToTensor()]
)

class StreamingEarlyPredictor:
    """Sliding-window early predictor for a live stream of PIL frames.

    Keeps the most recent ``seq_len`` frames in a deque. Every ``update`` call
    re-runs the full model on the buffered window and returns the earliest
    timestep *within the current window* whose top-class probability reaches
    ``threshold``. Timesteps are therefore window-relative, not absolute frame
    indices.
    """

    def __init__(self, model: EarlyGestureTCN, seq_len: int = SEQ_LEN, threshold: float = 0.6):
        self.model = model
        self.seq_len = seq_len
        self.threshold = threshold
        self.buffer = collections.deque(maxlen=seq_len)  # oldest frames drop off automatically
        self.model.eval()

    def reset(self):
        """Clear the frame buffer and, if available, the TCN's internal streaming buffers."""
        self.buffer.clear()
        # update() uses the regular forward pass, so these buffers presumably only
        # matter for pytorch-tcn's streaming-inference path. The call is guarded so
        # pytorch-tcn versions without reset_buffers() do not raise AttributeError.
        reset_buffers = getattr(self.model.tcn, 'reset_buffers', None)
        if callable(reset_buffers):
            reset_buffers()

    def update(self, frame_pil: Image.Image):
        """Add one frame and return the earliest confident prediction in the window.

        Returns ``None`` while fewer than 2 frames are buffered, or when no step
        reaches ``threshold``. Otherwise returns a dict with a 1-based
        ``'timestep'``, ``'pred'`` (class index) and ``'confidence'``.
        """
        # The training dataset converts every image to RGB, and the encoder's first
        # conv expects 3 channels, so RGBA/grayscale frames are normalised the same way.
        tensor = stream_transform(frame_pil.convert('RGB')).unsqueeze(0)  # (1,C,H,W)
        self.buffer.append(tensor)
        if len(self.buffer) < 2:
            return None  # heuristic: wait for at least 2 frames
        # Put the window wherever the model's weights live instead of assuming the global DEVICE.
        device = next(self.model.parameters()).device
        seq = torch.cat(list(self.buffer), dim=0).unsqueeze(0).to(device)  # (1,T,C,H,W)
        with torch.no_grad():
            probs_time = torch.softmax(self.model(seq), dim=-1)[0]  # (T,C)
        conf, cls = probs_time.max(dim=-1)  # both (T,)
        hits = torch.nonzero(conf >= self.threshold)
        if hits.numel() == 0:
            return None
        t = int(hits[0, 0])  # earliest confident step in the window
        return {'timestep': t+1, 'pred': int(cls[t]), 'confidence': float(conf[t])}

# Example usage (pseudocode; webcam_frames() is a placeholder for your own source of PIL frames):
# predictor = StreamingEarlyPredictor(model_early)
# for frame in webcam_frames():
#     result = predictor.update(frame)
#     if result:
#         print('EARLY PRED', result)
#         # optionally call predictor.reset() after a stable prediction


In [None]:
# # Keras TCN alternative (only if you later switch to TensorFlow).
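# # NOTE: Illustrative only; run in a TensorFlow environment. Every line of this cell is commented out, so as written nothing is executed or stored. Uncomment it to define keras_tcn_code.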
# keras_tcn_code = r"""
# import tensorflow as tf
# from tensorflow import keras
# from tensorflow.keras import layers

# NUM_CLASSES = 3
# SEQ_LEN = 8
# FEATURE_DIM = 64  # if you extract per-frame CNN features separately

# # Frame encoder (simple) - assuming input (H,W,3)
# def build_frame_encoder():
#     inputs = keras.Input(shape=(128,128,3))
#     x = layers.Conv2D(32,3,strides=2,padding='same',activation='relu')(inputs)
#     x = layers.Conv2D(64,3,strides=2,padding='same',activation='relu')(x)
#     x = layers.Conv2D(64,3,strides=2,padding='same',activation='relu')(x)
#     x = layers.GlobalAveragePooling2D()(x)
#     outputs = layers.Dense(FEATURE_DIM)(x)
#     return keras.Model(inputs, outputs, name='frame_encoder')

# frame_encoder = build_frame_encoder()

# # TCN block (simplified causal dilated conv stack)
# def tcn_block(x, filters, kernel_size, dilation_rate, dropout):
#     prev = x
#     x = layers.Conv1D(filters, kernel_size, padding='causal', dilation_rate=dilation_rate, activation='relu')(x)
#     x = layers.Dropout(dropout)(x)
#     x = layers.Conv1D(filters, kernel_size, padding='causal', dilation_rate=dilation_rate, activation='relu')(x)
#     if prev.shape[-1] != filters:
#         prev = layers.Conv1D(filters, 1, padding='same')(prev)
#     return layers.Add()([prev, x])

# # Full model: input (T, H, W, 3)
# video_inputs = keras.Input(shape=(SEQ_LEN,128,128,3))
# # TimeDistributed encoding
# encoded = layers.TimeDistributed(frame_encoder)(video_inputs)  # (B,T,F)
# # TCN stacks
# x = encoded
# for d in [1,2,4]:
#     x = tcn_block(x, 64, 3, dilation_rate=d, dropout=0.1)
# # Per-timestep classification
# logits = layers.TimeDistributed(layers.Dense(NUM_CLASSES))(x)  # (B,T,C)
# # Optionally final timestep output
# final_logits = layers.Lambda(lambda t: t[:,-1])(logits)
# model = keras.Model(video_inputs, [logits, final_logits])
# model.compile(optimizer='adam', loss=[keras.losses.CategoricalCrossentropy(from_logits=True), keras.losses.CategoricalCrossentropy(from_logits=True)], loss_weights=[1.0, 0.3])
# print(model.summary())
# """
# print('Stored Keras TCN example in variable keras_tcn_code (string).')