In [2]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from typing import List
import imageio
import matplotlib.pyplot as plt

In [3]:
# Run on Apple's Metal (MPS) backend when PyTorch reports it as available;
# otherwise stay on the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [5]:
# Character inventory for the transcripts: lowercase letters, a few
# punctuation marks, the digits 1-9 and the space character.
vocab = list("abcdefghijklmnopqrstuvwxyz'?!123456789 ")
# Forward (char -> index) and reverse (index -> char) lookup tables; a
# character's index is its position in `vocab`.
char_to_idx = dict(zip(vocab, range(len(vocab))))
idx_to_char = dict(enumerate(vocab))

In [7]:
def norm_frames(frames):
    """Standardise a stack of frames to zero mean and unit variance.

    The mean and standard deviation are computed over *all* values in
    ``frames`` (one global statistic, not per frame).

    Args:
        frames: array-like of numeric values (e.g. a list of 2-D frames).

    Returns:
        ``np.ndarray`` of dtype ``float32`` with the same shape as the input.
        Empty input is returned unchanged (as an empty float32 array), and
        input with zero variance is only mean-centred, so the result never
        contains NaN/inf caused by a division by zero.
    """
    frames = np.asarray(frames, dtype=np.float32)
    if frames.size == 0:
        # mean()/std() of an empty array only produce NaN plus warnings.
        return frames
    mean = frames.mean()
    std = frames.std()
    # Constant input (e.g. an all-black clip) has std == 0; dividing by it
    # would turn every value into NaN, so fall back to a unit scale.
    scale = std if std > 0 else np.float32(1.0)
    return (frames - mean) / scale

In [10]:
def load_vid(path: str):
    """Read a video file into a normalised stack of cropped grayscale frames.

    Every frame is converted from BGR to grayscale and cropped to rows
    190:236 and columns 80:220 (presumably the mouth region -- confirm
    against the dataset). The stack is then standardised with
    ``norm_frames`` and given a trailing channel axis.

    Args:
        path: path of the video file to open with OpenCV.

    Returns:
        ``np.ndarray`` of shape ``(num_frames, height, width, 1)``.

    Raises:
        ValueError: if no frame could be read from ``path`` (missing file,
            unsupported codec, empty clip). Without this check the function
            would silently return a malformed empty array.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            frames.append(gray[190:236, 80:220])
    finally:
        # Release the capture handle even if a frame conversion raises.
        cap.release()
    if not frames:
        raise ValueError(f"no frames could be read from {path!r}")
    return norm_frames(frames)[..., np.newaxis]

In [11]:
def load_ali(path: str, char_map=None):
    """Read an alignment file and encode its transcript as character indices.

    Each line is split on whitespace and its third field is taken as the
    word. Lines whose word is ``'sil'`` are dropped, as are blank lines and
    lines with fewer than three fields. The remaining words are joined with
    single spaces -- no leading or trailing space -- and mapped to indices
    one character at a time.

    Args:
        path: path of the alignment file.
        char_map: optional ``{char: index}`` mapping; defaults to the
            module-level ``char_to_idx``.

    Returns:
        List of integer indices, one per character of the transcript.

    Raises:
        KeyError: if the transcript contains a character missing from the
            mapping.
    """
    mapping = char_to_idx if char_map is None else char_map
    words = []
    with open(path, 'r') as f:
        for line in f:
            parts = line.split()
            # Skip blank/short lines instead of failing on parts[2].
            if len(parts) < 3 or parts[2] == 'sil':
                continue
            words.append(parts[2])
    # Joining (rather than prefixing each word with ' ') avoids a spurious
    # leading space token at the start of every label.
    return [mapping[c] for c in ' '.join(words)]

In [12]:
class LipReadingDataset(Dataset):
    """Dataset pairing each ``.mpg`` clip in a directory with its alignment.

    Args:
        data_dir: directory scanned (non-recursively) for ``*.mpg`` files.
        align_dir: directory holding ``<clip name>.align`` files; defaults
            to ``data/alignments/s1``.
    """

    def __init__(self, data_dir, align_dir=None):
        # Sort so that index -> clip is stable; os.listdir order is arbitrary.
        self.video_paths = [
            os.path.join(data_dir, f)
            for f in sorted(os.listdir(data_dir))
            if f.endswith('.mpg')
        ]
        self.align_dir = (
            os.path.join('data', 'alignments', 's1') if align_dir is None else align_dir
        )

    def __len__(self):
        """Number of video clips found in ``data_dir``."""
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Load clip ``idx`` and its label.

        Returns:
            ``(video, label)`` where ``video`` is a tensor laid out as
            ``(channel, time, height, width)`` and ``label`` is a 1-D tensor
            of character indices.
        """
        video_path = self.video_paths[idx]
        base_name = os.path.splitext(os.path.basename(video_path))[0]
        align_path = os.path.join(self.align_dir, f'{base_name}.align')

        # Use the loaders defined in this notebook (load_vid / load_ali).
        video = load_vid(video_path)
        alignment = load_ali(align_path)
        # load_vid yields (T, H, W, 1); move the channel axis to the front.
        return torch.tensor(video).permute(3, 0, 1, 2), torch.tensor(alignment)

In [29]:
def collate(batch):
    """Pad a list of ``(video, label)`` pairs into batch tensors.

    Args:
        batch: sequence of ``(video, label)`` pairs where ``video`` is a 4-D
            tensor whose dim 1 is the frame (time) axis and ``label`` is a
            1-D tensor of character indices.

    Returns:
        Tuple ``(videos, labels, input_lengths, label_lengths)``:
        videos zero-padded along dim 1 and stacked on a new batch axis,
        labels right-padded with 0 and stacked, and the un-padded frame and
        label lengths as 1-D tensors.
    """
    videos, labels = zip(*batch)

    # Longest clip (in frames) and longest label in this batch.
    max_video_len = max(v.shape[1] for v in videos)
    max_label_len = max(l.shape[0] for l in labels)

    padded_videos, padded_labels = [], []
    input_lengths, label_lengths = [], []

    for v, l in zip(videos, labels):
        pad_len = max_video_len - v.shape[1]
        # F.pad consumes (before, after) pairs starting from the LAST
        # dimension, so three pairs reach dim 1 (time). A fourth pair would
        # pad dim 0 instead and make torch.stack fail for uneven clips.
        padded_videos.append(F.pad(v, (0, 0, 0, 0, 0, pad_len)))
        # The pad value 0 is also a real character index; consumers must
        # rely on label_lengths to ignore the padded tail.
        padded_labels.append(F.pad(l, (0, max_label_len - l.shape[0]), value=0))
        input_lengths.append(v.shape[1])
        label_lengths.append(l.shape[0])

    return (
        torch.stack(padded_videos),
        torch.stack(padded_labels),
        torch.tensor(input_lengths),
        torch.tensor(label_lengths),
    )


In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LipReadingModel(nn.Module):
    """Per-frame 2-D CNN followed by a bidirectional LSTM and a linear head.

    Args:
        vocab_size: number of output classes per time step. When training
            with CTC this must include the blank class.
        input_hw: ``(height, width)`` of the input frames. Defaults to
            ``(46, 140)``, which yields 5 x 17 feature maps after the three
            2x2 max-pools (each one floors the halved size).
    """

    def __init__(self, vocab_size, input_hw=(46, 140)):
        # Zero-argument super() cannot go stale when the class is renamed.
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # The padded 3x3 convs keep the spatial size; three floor-halvings
        # are equivalent to a single floor division by 8.
        feat_h, feat_w = input_hw[0] // 8, input_hw[1] // 8
        self.lstm_input_size = 256 * feat_h * feat_w
        self.lstm = nn.LSTM(self.lstm_input_size, 256, num_layers=2, bidirectional=True, batch_first=True)
        self.fc = nn.Linear(512, vocab_size)  # 512 = 2 directions x 256 hidden

    def forward(self, x):
        """Map a batch of clips to per-frame class scores.

        Args:
            x: 5-D tensor, either ``(B, T, C, H, W)`` or channel-first
               ``(B, C, T, H, W)``; ``C`` must equal the first conv layer's
               input channels (1).

        Returns:
            Tensor of shape ``(B, T, vocab_size)`` with unnormalised scores.
        """
        in_channels = self.conv[0].in_channels
        # Accept channel-first clips by moving time in front of channels.
        # Only triggers when dim 2 cannot be the channel axis, so valid
        # (B, T, C, H, W) input is never altered.
        if x.shape[2] != in_channels and x.shape[1] == in_channels:
            x = x.permute(0, 2, 1, 3, 4)

        B, T, C, H, W = x.shape

        # Fold time into the batch so the 2-D CNN sees one frame per sample.
        # reshape (not view) because a permuted tensor is non-contiguous.
        x = x.reshape(B * T, C, H, W)
        x = self.conv(x)
        x = x.reshape(B, T, -1)  # flatten each frame's feature maps

        x, _ = self.lstm(x)
        x = self.fc(x)
        return x


In [17]:
def decode(logits, index_to_char=None):
    """Greedy (best-path) CTC decoding.

    For every sequence the highest-scoring class is picked at each time
    step, consecutive repeats are collapsed, and blanks are removed.

    Args:
        logits: tensor of shape ``(batch, time, classes)``.
        index_to_char: optional ``{index: char}`` mapping; defaults to the
            module-level ``idx_to_char``. Any class index missing from the
            mapping (such as a CTC blank placed after the vocabulary) is
            treated as blank.

    Returns:
        List with one decoded string per batch element.
    """
    mapping = idx_to_char if index_to_char is None else index_to_char
    pred = torch.argmax(logits, dim=-1)
    results = []
    for seq in pred.tolist():
        collapsed = []
        prev = None
        for k in seq:
            # Unknown indices become the blank symbol instead of raising
            # KeyError.
            c = mapping.get(k, '_')
            if c != prev and c != '_':
                collapsed.append(c)
            # Track blanks too, so "a, blank, a" decodes as "aa".
            prev = c
        results.append(''.join(collapsed))
    return results

In [27]:
def train(model, train_loader, val_loader, opt, criterion, epochs):
    """Train ``model`` with a CTC-style criterion and print a sample each epoch.

    Args:
        model: network mapping a video batch to ``(B, T, V)`` scores.
        train_loader: iterable yielding ``(videos, labels, input_lengths,
            label_lengths)`` batches.
        val_loader: iterable with the same batch layout; only its first
            batch is used per epoch, to print one prediction.
        opt: optimizer over ``model``'s parameters.
        criterion: loss called as ``criterion(log_probs, labels,
            input_lengths, label_lengths)`` with ``log_probs`` shaped
            ``(T, B, V)``.
        epochs: number of passes over ``train_loader``.

    The model is moved to the module-level ``device``. The printed loss is
    the sum of the per-batch losses of the epoch (not an average).
    """
    model.to(device)
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for videos, labels, input_lengths, label_lengths in train_loader:
            videos, labels = videos.to(device), labels.to(device)
            opt.zero_grad()
            output = model(videos)  # [B, T, V]
            output = output.log_softmax(2).permute(1, 0, 2)  # For CTC Loss: [T, B, V]
            loss = criterion(output, labels, input_lengths, label_lengths)
            loss.backward()
            opt.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

        model.eval()
        with torch.no_grad():
            for videos, labels, input_lengths, label_lengths in val_loader:
                videos = videos.to(device)
                output = model(videos)
                # The decoder defined in this notebook is `decode`.
                decoded = decode(output)
                # Labels are right-padded with index 0, which is also a real
                # character; cut at the true length so padding is not printed.
                target = labels[0][:int(label_lengths[0])]
                print("Predicted:", decoded[0])
                print("Actual:", ''.join([idx_to_char[i] for i in target.tolist() if i in idx_to_char]))
                break

In [30]:
# Build the data pipeline over the clips in ./data/s1.
train_dataset = LipReadingDataset('./data/s1')
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate)

# NOTE: validation currently reuses the training set; split it for a real
# validation signal.
val_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, collate_fn=collate)  # Can split for actual val

# The CTC blank is an extra class placed right after the vocabulary
# (index len(vocab)), so the network must emit len(vocab) + 1 classes;
# with only len(vocab) outputs the blank index would be out of range.
model = LipReadingModel(vocab_size=len(vocab) + 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CTCLoss(blank=len(vocab))

train(model, train_loader, val_loader, opt, criterion, 10)


NotImplementedError: The operator 'aten::max_pool3d_with_indices' is not currently implemented for the MPS device. If you want this op to be considered for addition please comment on https://github.com/pytorch/pytorch/issues/141287 and mention use-case, that resulted in missing op as well as commit hash 134179474539648ba7dee1317959529fbd0e7f89. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.