In [1]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from typing import List
import imageio
import matplotlib.pyplot as plt

In [2]:
# Prefer the Metal (MPS) backend when this PyTorch build reports it available; otherwise run on CPU.
device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

In [5]:
# Output alphabet for the first pipeline: a character's position in this string
# is its class index (so 'a' -> 0 and ' ' -> the last index).
# NOTE(review): '0' is not in the alphabet; confirm transcripts never contain it.
vocab = list("abcdefghijklmnopqrstuvwxyz'?!123456789 ")
char_to_idx = dict(zip(vocab, range(len(vocab))))
idx_to_char = dict(enumerate(vocab))

In [7]:
def norm_frames(frames, eps: float = 1e-8):
    """Standardize a stack of frames to zero mean and unit variance.

    A single mean/std is computed over every pixel of every frame (global
    standardization, not per-frame).

    Args:
        frames: Anything ``np.asarray`` can turn into a numeric array, e.g. a
            list of equally sized 2-D grayscale frames.
        eps: Standard deviations at or below this value are treated as zero.

    Returns:
        A float32 array with the same shape as the input. Empty input is
        returned as an empty float32 array. Input with (near-)zero variance,
        such as a blank clip, is only mean-centred. The result is then all
        zeros instead of NaN/inf from dividing by a zero std.
    """
    frames = np.asarray(frames, dtype=np.float32)
    if frames.size == 0:
        # mean()/std() of an empty array would be NaN with a RuntimeWarning.
        return frames
    mean = frames.mean()
    std = frames.std()
    centered = frames - mean
    if std <= eps:
        return centered
    return centered / std

In [10]:
def load_vid(path: str, rows=(190, 236), cols=(80, 220)):
    """Read a video, convert each frame to grayscale, crop a fixed window, and standardize.

    Args:
        path: Path to a video file readable by OpenCV.
        rows: ``(start, stop)`` row slice of the crop window (default 190:236,
            i.e. 46 rows).
        cols: ``(start, stop)`` column slice of the crop window (default 80:220,
            i.e. 140 columns).

    Returns:
        float32 array of shape ``(T, H, W, 1)``, standardized by ``norm_frames``.

    Raises:
        ValueError: if the file cannot be opened or yields no frames.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        if not cap.isOpened():
            raise ValueError(f"could not open video: {path}")
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
            # Fixed pixel window. NOTE(review): presumably the mouth region for
            # this recording setup; confirm before reusing on other footage.
            frames.append(gray[rows[0]:rows[1], cols[0]:cols[1]])
    finally:
        # Release the capture handle even if decoding raises.
        cap.release()
    if not frames:
        raise ValueError(f"no frames decoded from video: {path}")
    return norm_frames(frames)[..., np.newaxis]

In [11]:
def load_ali(path: str, vocab_map=None, skip_tokens=('sil', 'sp')):
    """Parse an alignment file into a list of character class indices.

    Each usable line has at least three whitespace-separated fields; the third is
    the word. Words listed in ``skip_tokens`` are dropped. The remaining words
    are joined with single spaces, with no leading or trailing space, and each
    character is mapped through ``vocab_map``.

    Args:
        path: Path to the ``.align`` file.
        vocab_map: char -> index mapping; defaults to the module-level ``char_to_idx``.
        skip_tokens: Words to drop (default ``'sil'`` and ``'sp'``, the same
            tokens the other ``Dataset`` in this notebook skips).

    Returns:
        List of int class indices (empty if every word was skipped).

    Raises:
        KeyError: if a kept word contains a character missing from ``vocab_map``.
    """
    if vocab_map is None:
        vocab_map = char_to_idx
    words = []
    with open(path, 'r') as f:
        for line in f:
            parts = line.split()
            # Blank or malformed lines are skipped instead of raising IndexError.
            if len(parts) < 3 or parts[2] in skip_tokens:
                continue
            words.append(parts[2])
    # join() puts spaces only *between* words; prepending ' ' before every word
    # would start each transcript with a spurious space.
    return [vocab_map[c] for c in ' '.join(words)]

In [12]:
class LipReadingDataset(Dataset):
    """Pairs each ``.mpg`` clip in ``data_dir`` with ``<align_dir>/<clip>.align``.

    Items are ``(video, label)``:
      * video: float32 tensor laid out ``(C, T, H, W)``. ``load_vid`` returns
        ``(T, H, W, C)``, and this class permutes it with ``(3, 0, 1, 2)``.
      * label: 1-D int64 tensor of character indices from ``load_ali``.
    """

    def __init__(self, data_dir, align_dir=os.path.join('data', 'alignments', 's1')):
        """
        Args:
            data_dir: Directory scanned for ``.mpg`` files.
            align_dir: Directory holding the ``.align`` files. The default keeps
                the previous hard-coded location ``data/alignments/s1``.
        """
        # Sorted so that item indices refer to the same clip on every run/filesystem.
        self.video_paths = sorted(
            os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith('.mpg')
        )
        self.align_dir = align_dir

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        video_path = self.video_paths[idx]
        base_name = os.path.splitext(os.path.basename(video_path))[0]
        align_path = os.path.join(self.align_dir, f'{base_name}.align')

        video = load_vid(video_path)
        alignment = load_ali(align_path)
        return (
            torch.from_numpy(video).permute(3, 0, 1, 2),
            # Explicit dtype: torch.tensor([]) would otherwise be float32.
            torch.tensor(alignment, dtype=torch.long),
        )

In [29]:
def collate(batch):
    """Zero-pad a batch of ``(video, label)`` pairs to common lengths.

    Args:
        batch: Sequence of ``(video, label)``. ``video`` is ``(C, T, H, W)``
            (time on dim 1, as produced by ``LipReadingDataset``). ``label`` is
            a 1-D tensor of class indices.

    Returns:
        ``(videos, labels, input_lengths, label_lengths)``:
          * videos: ``(B, C, T_max, H, W)``, zero-padded at the end of time.
          * labels: ``(B, L_max)``, right-padded with 0. Index 0 is also a real
            class ('a' in ``char_to_idx``), so consumers must slice by
            ``label_lengths``.
          * input_lengths / label_lengths: unpadded sizes as int64 tensors.
    """
    videos, labels = zip(*batch)
    max_video_len = max(v.shape[1] for v in videos)
    max_label_len = max(l.shape[0] for l in labels)

    padded_videos = []
    padded_labels = []
    for v, l in zip(videos, labels):
        # F.pad takes (before, after) pairs starting from the LAST dim:
        # (W_l, W_r, H_l, H_r, T_l, T_r). Only time (dim 1) is padded, at the end.
        padded_videos.append(F.pad(v, (0, 0, 0, 0, 0, max_video_len - v.shape[1])))
        padded_labels.append(F.pad(l, (0, max_label_len - l.shape[0]), value=0))

    return (
        torch.stack(padded_videos),
        torch.stack(padded_labels),
        torch.tensor([v.shape[1] for v in videos], dtype=torch.long),
        torch.tensor([l.shape[0] for l in labels], dtype=torch.long),
    )


In [33]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LipReadingModel(nn.Module):
    """Per-frame 2-D CNN encoder followed by a 2-layer bidirectional LSTM.

    Input ``(B, T, C, H, W)`` with ``C == 1``. Output is unnormalized scores of
    shape ``(B, T, vocab_size)``; apply ``log_softmax`` before a CTC loss.

    Args:
        vocab_size: Number of output classes (including any CTC blank).
        input_hw: ``(H, W)`` of each frame, used to size the LSTM input.
            The default ``(46, 140)`` reproduces the original ``256 * 5 * 17``.

    Raises:
        ValueError: if ``input_hw`` is too small to survive three 2x2 poolings.
    """

    def __init__(self, vocab_size, input_hw=(46, 140)):
        super().__init__()

        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # The padded 3x3 convs keep the spatial size; each MaxPool2d(2) floors it by half.
        h, w = input_hw
        for _ in range(3):
            h, w = h // 2, w // 2
        if h < 1 or w < 1:
            raise ValueError(f"input_hw={input_hw} is too small for three 2x2 poolings")

        self.lstm_input_size = 256 * h * w
        self.lstm = nn.LSTM(self.lstm_input_size, 256, num_layers=2, bidirectional=True, batch_first=True)
        # 512 = 2 directions x 256 hidden units.
        self.fc = nn.Linear(512, vocab_size)

    def forward(self, x):
        B, T, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. permuted tensors, also work.
        x = self.conv(x.reshape(B * T, C, H, W))
        x = x.reshape(B, T, -1)
        x, _ = self.lstm(x)
        return self.fc(x)


In [17]:
def decode(logits, blank_idx=None, idx_map=None):
    """Greedy (best-path) CTC decoding of a batch of per-frame scores.

    Args:
        logits: ``(B, T, V)`` scores. Only the argmax is used, so either raw
            logits or log-probabilities work.
        blank_idx: Class index of the CTC blank, if known. An index missing from
            ``idx_map`` is also treated as blank, as is one mapping to ``'_'``.
        idx_map: index -> character mapping; defaults to ``idx_to_char``.

    Returns:
        List of ``B`` decoded strings.
    """
    if idx_map is None:
        idx_map = idx_to_char
    pred = torch.argmax(logits, dim=-1)
    results = []
    for seq in pred.tolist():
        chars = []
        prev = None
        for idx in seq:
            # Merge repeats on the raw index stream first, then drop blanks,
            # so "a <blank> a" still yields "aa".
            if idx != prev and idx != blank_idx and idx in idx_map and idx_map[idx] != '_':
                chars.append(idx_map[idx])
            prev = idx
        results.append(''.join(chars))
    return results

In [27]:
def train(model, train_loader, val_loader, opt, criterion, epochs):
    """Train with a CTC-style loss and print one greedy-decoded validation sample per epoch.

    Args:
        model: Maps a video batch to ``(B, T, V)`` scores.
        train_loader: Yields ``(videos, labels, input_lengths, label_lengths)``,
            e.g. via ``collate``.
        val_loader: Same format; only its first batch is decoded each epoch.
        opt: Optimizer over ``model.parameters()``.
        criterion: Called as ``criterion(log_probs[T, B, V], labels, input_lengths, label_lengths)``.
        epochs: Number of passes over ``train_loader``.

    Uses the module-level ``device``, ``decode`` and ``idx_to_char``.
    """
    model.to(device)
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for videos, labels, input_lengths, label_lengths in train_loader:
            videos = videos.to(device)
            opt.zero_grad()
            output = model(videos)  # [B, T, V]
            log_probs = output.log_softmax(2).permute(1, 0, 2)  # CTC expects [T, B, V]
            # The loss is evaluated on CPU, as in the other training loop in this
            # notebook. NOTE(review): presumably because CTC loss is not available
            # on every accelerator backend; confirm for MPS.
            loss = criterion(log_probs.cpu(), labels.cpu(), input_lengths.cpu(), label_lengths.cpu())
            loss.backward()
            opt.step()
            total_loss += loss.item()
        print(f"Epoch {epoch+1}, Loss: {total_loss:.4f}")

        model.eval()
        with torch.no_grad():
            for videos, labels, input_lengths, label_lengths in val_loader:
                output = model(videos.to(device))
                decoded = decode(output)  # decode() takes [B, T, V]
                # Slice off padding: label padding is 0, which is a real character.
                true_len = int(label_lengths[0])
                actual = ''.join(
                    idx_to_char[i] for i in labels[0][:true_len].tolist() if i in idx_to_char
                )
                print("Predicted:", decoded[0])
                print("Actual:", actual)
                break

In [3]:
class Dataset(Dataset):
    """Pairs every ``.mpg`` clip in ``video_dir`` with ``alignment_dir/<clip>.align``.

    NOTE: this class shadows ``torch.utils.data.Dataset``, which it subclasses,
    at module level. Re-running this cell subclasses the previous definition.

    Items are ``(frames, label)``:
      * frames: float32 ``(T, 1, 46, 140)`` scaled to [0, 1]. The whole
        grayscale frame is resized, not cropped.
      * label: int64 1-D tensor of ``vocab`` indices for the transcript, with
        words separated by a single space (so ``vocab`` must contain ``' '``).
    """

    def __init__(self, video_dir, alignment_dir, vocab):
        self.vocab = vocab
        # Sorted so that dataset[i] refers to the same clip on every machine/run.
        video_files = sorted(f for f in os.listdir(video_dir) if f.endswith('.mpg'))
        self.video_paths = [os.path.join(video_dir, f) for f in video_files]
        self.align_paths = [
            os.path.join(alignment_dir, f.replace('.mpg', '.align')) for f in video_files
        ]

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        frames = self._load_frames(self.video_paths[idx])
        label = self._load_label(self.align_paths[idx])
        return frames, label

    def _load_frames(self, video_path):
        """Decode a clip into a float32 ``(T, 1, 46, 140)`` tensor scaled to [0, 1]."""
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                # cv2.resize takes (width, height).
                frames.append(cv2.resize(gray, (140, 46)))
        finally:
            cap.release()
        if not frames:
            raise ValueError(f"no frames decoded from video: {video_path}")
        video = torch.tensor(np.stack(frames), dtype=torch.float32) / 255.0
        return video.unsqueeze(1)

    def _load_label(self, align_path):
        """Map the transcript words, minus 'sil'/'sp', to vocab indices, space-separated."""
        words = []
        with open(align_path, 'r') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) == 3 and parts[2] not in ['sil', 'sp']:
                    words.append(parts[2])
        # Words are separated by ' ', which is in the vocab. Concatenating them
        # directly would make the target unsegmentable and the space class unused.
        return torch.tensor([self.vocab[c] for c in ' '.join(words)], dtype=torch.long)

In [4]:
def collate_fn(batch):
    """Pad a batch of ``(frames, label)`` pairs for CTC training.

    Args:
        batch: Sequence of ``(frames, label)``. ``frames`` has time on dim 0
            (``(T, ...)``); ``label`` is a 1-D index tensor.

    Returns:
        ``(videos, labels, input_lengths, label_lengths)``:
          * videos: ``(B, T_max, ...)``, zero-padded at the end of time (same dtype as input).
          * labels: all labels concatenated into one 1-D tensor, to be used
            together with ``label_lengths``.
          * input_lengths: each clip's real frame count, not ``T_max``, so
            the CTC loss does not treat padding frames as real input.
          * label_lengths: each label's length.
    """
    videos, labels = zip(*batch)
    lengths = [v.size(0) for v in videos]
    max_len = max(lengths)
    padded = [
        torch.cat([v, v.new_zeros(max_len - v.size(0), *v.shape[1:])], dim=0)
        for v in videos
    ]
    return (
        torch.stack(padded),
        torch.cat(labels),
        torch.tensor(lengths, dtype=torch.long),
        torch.tensor([len(l) for l in labels], dtype=torch.long),
    )


In [5]:
class LipReadingModel(nn.Module):
    """Per-frame 2-D CNN encoder followed by a 2-layer bidirectional LSTM.

    Input ``(B, T, C, H, W)`` with ``C == 1``. Output is unnormalized scores of
    shape ``(B, T, vocab_size)``.

    Args:
        vocab_size: Number of output classes (including the CTC blank).
        input_hw: ``(H, W)`` of each frame, used to size the LSTM input.
            The default ``(46, 140)`` reproduces the original ``256 * 5 * 17``.
            Submodule names (conv/lstm/fc) are unchanged, so existing
            checkpoints still load.

    Raises:
        ValueError: if ``input_hw`` is too small to survive three 2x2 poolings.
    """

    def __init__(self, vocab_size, input_hw=(46, 140)):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(1, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),

            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )
        # The padded 3x3 convs keep the spatial size; each MaxPool2d(2) floors it by half.
        h, w = input_hw
        for _ in range(3):
            h, w = h // 2, w // 2
        if h < 1 or w < 1:
            raise ValueError(f"input_hw={input_hw} is too small for three 2x2 poolings")
        self.lstm_input_size = 256 * h * w
        self.lstm = nn.LSTM(self.lstm_input_size, 256, num_layers=2, bidirectional=True, batch_first=True)
        # 512 = 2 directions x 256 hidden units.
        self.fc = nn.Linear(512, vocab_size)

    def forward(self, x):
        B, T, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs also work.
        x = self.conv(x.reshape(B * T, C, H, W))
        x = x.reshape(B, T, -1)
        x, _ = self.lstm(x)
        return self.fc(x)


In [6]:
def train(model, train_loader, optimizer, criterion, epochs, device):
    """Fit ``model`` with a CTC-style criterion and print the summed loss per epoch.

    Args:
        model: Maps ``videos`` to ``(B, T, V)`` scores.
        train_loader: Iterable of ``(videos, labels, input_lengths, label_lengths)``.
        optimizer: Optimizer over the model's parameters.
        criterion: Called as ``criterion(log_probs[T, B, V], labels, input_lengths, label_lengths)``.
        epochs: Number of passes over ``train_loader``.
        device: Where the model and inputs live; the loss itself is evaluated on CPU.
    """
    model.to(device)

    for epoch_idx in range(epochs):
        model.train()
        epoch_loss = 0

        for batch in train_loader:
            videos, labels, input_lengths, label_lengths = batch
            videos, labels = videos.to(device), labels.to(device)

            optimizer.zero_grad()
            scores = model(videos)  # (B, T, V)
            log_probs = F.log_softmax(scores, dim=2).permute(1, 0, 2)  # (T, B, V)

            batch_loss = criterion(
                log_probs.cpu(), labels.cpu(), input_lengths.cpu(), label_lengths.cpu()
            )
            batch_loss.backward()
            optimizer.step()
            epoch_loss += batch_loss.item()

        print(f"Epoch {epoch_idx+1}, Loss: {epoch_loss:.4f}")


In [None]:
# Train the second pipeline (Dataset + collate_fn + LipReadingModel) end to end.
from torch.utils.data import DataLoader

# Character alphabet, sorted: ' ' -> 0 and 'a'..'z' -> 1..26 (27 classes in total).
vocab_chars = sorted(set("abcdefghijklmnopqrstuvwxyz ")) 
vocab = {c: i for i, c in enumerate(vocab_chars)}
vocab_size = len(vocab)

dataset = Dataset(
    video_dir="./data/s1",
    alignment_dir="./data/alignments/s1",
    vocab=vocab
)
# NOTE(review): no random seed is set, so shuffling and weight init differ between runs.
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = LipReadingModel(vocab_size)
# FIXME(review): blank = vocab_size - 1 = 26, which is also vocab['z']. 'z' can
# then never be emitted, and any transcript containing 'z' gives CTC a target
# equal to the blank. The usual fix is a dedicated extra class:
# LipReadingModel(vocab_size + 1) with nn.CTCLoss(blank=vocab_size). The
# checkpoint-loading and decoding cells below must change the same way, at the
# same time.
criterion = nn.CTCLoss(blank=vocab_size - 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

train(model, loader, optimizer, criterion, epochs=30, device=device)


In [8]:
# Save CPU copies of the weights so the checkpoint can be opened on a machine
# without the training device (e.g. weights trained on MPS, loaded on a CPU-only box).
torch.save({k: v.detach().cpu() for k, v in model.state_dict().items()}, 'saved_model.pth')

In [9]:
# Rebuild vocabulary, dataset and loader for inference in a fresh kernel. This
# must stay identical to the training cell, or the checkpoint's output layer
# (vocab_size classes) will not line up with the vocabulary.
# FIXME(review): the decoding cell uses vocab_size - 1 as the CTC blank, which
# is also vocab['z'] here; any fix must change training, loading and decoding together.
vocab_chars = sorted(set("abcdefghijklmnopqrstuvwxyz "))
vocab = dict(zip(vocab_chars, range(len(vocab_chars))))
vocab_size = len(vocab)

dataset = Dataset(video_dir="./data/s1", alignment_dir="./data/alignments/s1", vocab=vocab)
loader = DataLoader(dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

device = torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")

In [10]:
# Rebuild the architecture and load the trained weights onto `device`.
model = LipReadingModel(vocab_size)
# map_location remaps tensors saved from another device (e.g. MPS) onto `device`.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load('saved_model.pth', map_location=device))
model.to(device)
model.eval()


LipReadingModel(
  (conv): Sequential(
    (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (lstm): LSTM(21760, 256, num_layers=2, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=512, out_features=27, bias=True)
)

In [11]:
#TESTING MODEL
def greedy_decode(output, blank_idx):
    """Best-path CTC decoding.

    Args:
        output: ``(T, B, V)`` scores (e.g. log-probabilities).
        blank_idx: Class index of the CTC blank.

    Returns:
        One list per batch element with the argmax class indices, after merging
        consecutive repeats and then removing blanks. Elements are NumPy integers.
    """
    decoded_batch = []
    for scores in output.permute(1, 0, 2):  # one (T, V) slice per batch element
        best = torch.argmax(scores, dim=1).cpu().numpy()
        # Keep a frame if it is not blank and differs from the previous frame;
        # a blank previous frame still counts as "previous".
        kept = [
            best[t]
            for t in range(len(best))
            if best[t] != blank_idx and (t == 0 or best[t] != best[t - 1])
        ]
        decoded_batch.append(kept)
    return decoded_batch
def indices_to_text(indices, inv_vocab):
    """Map each class index through ``inv_vocab`` and concatenate the characters."""
    return ''.join(inv_vocab[i] for i in indices)


In [12]:
# Reverse lookup: class index -> character.
inv_vocab = dict(zip(vocab.values(), vocab.keys()))

In [None]:
# Sanity check: greedy-decode one clip and print it next to its transcript.
with torch.no_grad():
    sample, label = dataset[0]  # sample: (T, 1, H, W); label: 1-D index tensor

    batch = sample.unsqueeze(0).to(device)  # add batch dim -> (1, T, 1, H, W)
    output = F.log_softmax(model(batch), dim=2).permute(1, 0, 2)  # (T, 1, V)

    # NOTE(review): blank_idx must match the `blank` used by CTCLoss in training.
    pred_indices = greedy_decode(output, blank_idx=vocab_size - 1)[0]
    pred_text = indices_to_text(pred_indices, inv_vocab)
    true_text = indices_to_text(label.tolist(), inv_vocab)

    print(f"\nPredicted: {pred_text}")
    print(f"Ground Truth: {true_text}")
