# Digital Witness - LSTM Model Training

This notebook trains the LSTM behavior classifier for the Digital Witness system.

**Pipeline:** Video → CNN (ResNet18) → Features → LSTM → Classification

**Classes:** `normal` (0) | `shoplifting` (1)

## 1. Mount Google Drive & Setup

In [None]:
from google.colab import drive
drive.mount('/content/drive')

# Check that the project archive exists before anything else runs.
import os
zip_path = "/content/drive/MyDrive/DigitalWitness.zip"

if not os.path.exists(zip_path):
    # Stop here: every later cell unpacks or reads this archive, and continuing would
    # only surface as confusing unzip / missing-path errors further down the notebook.
    raise FileNotFoundError(f"ERROR: DigitalWitness.zip not found! Expected at: {zip_path}")
print("Found DigitalWitness.zip in Google Drive!")

In [None]:
# Extract project files
# Remove any earlier extraction first so stale files from a previous run cannot mix
# with the fresh copy from Drive.
!rm -rf /content/Project_DigitalWitness  # Clean previous extraction
# Unpack quietly (-q) into /content/. Later cells read /content/Project_DigitalWitness,
# so this assumes the zip's top-level folder is named Project_DigitalWitness.
# NOTE(review): confirm against the archive layout.
!unzip -q "/content/drive/MyDrive/DigitalWitness.zip" -d /content/

# List extracted contents
print("Extracted files:")
!ls /content/Project_DigitalWitness/

## 2. Install Dependencies

In [None]:
!pip install -q torch torchvision ultralytics opencv-python-headless mediapipe numpy

import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 3. Check Training Data

In [None]:
from pathlib import Path

project_root = Path("/content/Project_DigitalWitness")
training_root = project_root / "data" / "training"
normal_dir = training_root / "normal"
shoplifting_dir = training_root / "shoplifting"


def _list_mp4s(folder):
    """Return the .mp4 files directly inside `folder`, or [] when the folder is missing."""
    if not folder.exists():
        return []
    return list(folder.glob("*.mp4"))


def _print_video_summary(label, videos, preview=5):
    """Print a class's video count followed by up to `preview` file names."""
    print(f"\n{label} videos: {len(videos)}")
    for video in videos[:preview]:
        print(f"  - {video.name}")
    hidden = len(videos) - preview
    if hidden > 0:
        print(f"  ... and {hidden} more")


separator = "=" * 50
print(separator)
print("TRAINING DATA")
print(separator)

normal_videos = _list_mp4s(normal_dir)
shoplifting_videos = _list_mp4s(shoplifting_dir)

_print_video_summary("Normal", normal_videos)
_print_video_summary("Shoplifting", shoplifting_videos)

# Both classes must have at least one video for training to make sense.
if normal_videos and shoplifting_videos:
    print("\n" + separator)
    print("Training data ready!")
    print(separator)
else:
    alert = "!" * 50
    print("\n" + alert)
    print("ERROR: Training videos not found!")
    print(f"Expected in: {normal_dir}")
    print(f"Expected in: {shoplifting_dir}")
    print(alert)

## 4. Training Configuration

Adjust these parameters as needed:

In [None]:
# ============================================================
# TRAINING CONFIGURATION: edit values here; later cells only read them
# ============================================================

# --- Optimisation -------------------------------------------
EPOCHS = 50                    # Maximum number of passes over the training set
BATCH_SIZE = 32                # Sequences per mini-batch (lower it on out-of-memory errors)
LEARNING_RATE = 0.001          # Adam learning rate
VAL_SPLIT = 0.2                # Fraction of sequences held out for validation (20%)
EARLY_STOPPING = 10            # Patience: epochs without val-loss improvement before stopping

# --- Feature extraction -------------------------------------
SEQUENCE_LENGTH = 30           # Frames per LSTM sequence (~1 s of 30 fps video)
SLIDING_STRIDE = 15            # Step between window starts (15 of 30 frames overlap -> 50%)
PROCESS_EVERY_FRAME = True     # False keeps only every 2nd frame (faster, coarser)

# --- LSTM architecture --------------------------------------
LSTM_HIDDEN_DIM = 256          # Hidden units per LSTM direction
LSTM_NUM_LAYERS = 2            # Number of stacked LSTM layers
LSTM_DROPOUT = 0.3             # Dropout between LSTM layers and in the classifier head

# --- Quick-test switch --------------------------------------
MAX_VIDEOS_PER_CLASS = None    # None = use every video; e.g. 5 for a fast smoke test

_config_summary = (
    f"Epochs: {EPOCHS}",
    f"Batch size: {BATCH_SIZE}",
    f"Sequence length: {SEQUENCE_LENGTH} frames",
    f"Process every frame: {PROCESS_EVERY_FRAME}",
)
print("Configuration set!")
for _entry in _config_summary:
    print(f"  - {_entry}")

## 5. Initialize CNN Feature Extractor

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
import cv2
import numpy as np

class CNNFeatureExtractor:
    """Frozen ImageNet ResNet18 that turns individual video frames into feature vectors.

    The final fully-connected (classification) layer is removed, so each frame is
    mapped to the 512-dim output of the remaining ResNet18 layers.
    """

    def __init__(self, device='auto'):
        """
        Args:
            device: 'auto' selects CUDA when available (otherwise CPU); any other value
                is passed to ``torch.device`` unchanged.
        """
        if device == 'auto':
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        # Load pretrained ResNet18 and drop its last child (the FC classification head).
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        self.model = nn.Sequential(*list(resnet.children())[:-1])
        self.model = self.model.to(self.device)
        self.model.eval()  # inference only: BatchNorm uses its running statistics

        # Resize to 224x224 and apply ImageNet mean/std normalization.
        self.transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            )
        ])

        self.feature_dim = 512
        print(f"CNN initialized on {self.device}")

    @staticmethod
    def _to_rgb(frame):
        """Convert an OpenCV frame (BGR, BGRA or grayscale) to a 3-channel RGB array.

        The ResNet input and the 3-value Normalize above both require exactly three
        channels, so every other layout must be converted (or rejected) here.

        Raises:
            ValueError: for any layout other than (H, W), (H, W, 1), (H, W, 3), (H, W, 4).
        """
        if frame.ndim == 2 or (frame.ndim == 3 and frame.shape[2] == 1):
            return cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)
        if frame.ndim == 3 and frame.shape[2] == 3:
            return cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        if frame.ndim == 3 and frame.shape[2] == 4:
            return cv2.cvtColor(frame, cv2.COLOR_BGRA2RGB)
        raise ValueError(f"Unsupported frame shape {frame.shape}; expected grayscale, BGR or BGRA")

    def extract_features(self, frame):
        """Extract a 512-dim feature vector from a single frame.

        Args:
            frame: Image array as produced by OpenCV: BGR (H, W, 3), BGRA (H, W, 4),
                or grayscale (H, W) / (H, W, 1).

        Returns:
            1-D numpy array with ``feature_dim`` values.

        Raises:
            ValueError: if the frame has an unsupported channel layout.
        """
        frame_rgb = self._to_rgb(frame)

        # Add a batch dimension of 1 and move to the extractor's device.
        input_tensor = self.transform(frame_rgb).unsqueeze(0).to(self.device)

        with torch.no_grad():
            features = self.model(input_tensor)

        return features.squeeze().cpu().numpy()

# Build one shared extractor; every video below is encoded with this instance.
cnn = CNNFeatureExtractor(device='auto')
print(f"Feature dimension: {cnn.feature_dim}")

## 6. Extract Features from Training Videos

In [None]:
from pathlib import Path
import cv2
import numpy as np
from tqdm import tqdm

def _read_frame_features(video_path, cnn, process_every_frame, frame_skip):
    """Decode one video and return its per-frame CNN features, or None if it cannot be opened.

    The capture is always released, even if decoding or feature extraction raises.
    """
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            print(f"  Warning: Could not open {Path(video_path).name}")
            return None

        frame_features = []
        frame_num = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # Keep every frame, or only every `frame_skip`-th one when subsampling for speed.
            if process_every_frame or frame_num % frame_skip == 0:
                frame_features.append(cnn.extract_features(frame))
            frame_num += 1
        return frame_features
    finally:
        cap.release()


def _window_features(frame_features, sequence_length, stride):
    """Split per-frame features into sliding windows; a too-short video becomes one zero-padded window."""
    feature_array = np.asarray(frame_features)
    num_frames = len(feature_array)

    if num_frames < sequence_length:
        # Pad at the end with zeros of the same dtype so every sequence has one dtype.
        padding = np.zeros((sequence_length - num_frames, feature_array.shape[1]),
                           dtype=feature_array.dtype)
        return [np.vstack([feature_array, padding])]

    return [feature_array[start:start + sequence_length]
            for start in range(0, num_frames - sequence_length + 1, stride)]


def extract_features_from_videos(video_paths, cnn, sequence_length, stride,
                                 process_every_frame=True, frame_skip=2):
    """
    Extract CNN features from videos and cut them into fixed-length LSTM sequences.

    Each video is decoded frame by frame and every kept frame is passed through
    ``cnn.extract_features``. The resulting per-frame vectors are split into windows
    of ``sequence_length`` frames whose starts are ``stride`` frames apart. A video
    shorter than one window is zero-padded at the end to a single sequence. Frames
    after the last full window are not used. Videos that cannot be opened, or that
    yield no frames, are skipped.

    Args:
        video_paths: Iterable of video file paths (``Path`` or ``str``).
        cnn: Feature extractor exposing ``extract_features(frame) -> np.ndarray``.
        sequence_length: Frames per sequence.
        stride: Step, in kept frames, between consecutive window starts.
        process_every_frame: If True, use every decoded frame. Otherwise keep only every
            ``frame_skip``-th frame, which is faster but makes each window span
            ``frame_skip`` times more real time.
        frame_skip: Subsampling factor used when ``process_every_frame`` is False.

    Returns:
        List of feature sequences, each of shape (sequence_length, feature_dim).

    Raises:
        ValueError: if ``frame_skip`` is smaller than 1.
    """
    if frame_skip < 1:
        raise ValueError(f"frame_skip must be >= 1, got {frame_skip}")

    all_sequences = []
    for video_path in tqdm(video_paths, desc="Processing videos"):
        frame_features = _read_frame_features(video_path, cnn, process_every_frame, frame_skip)
        if not frame_features:  # unreadable video, or no frames decoded
            continue
        all_sequences.extend(_window_features(frame_features, sequence_length, stride))

    return all_sequences

print("Feature extraction function ready!")

In [None]:
# Extract features from normal videos
print("=" * 50)
print("EXTRACTING FEATURES FROM NORMAL VIDEOS")
print("=" * 50)

normal_dir = Path("/content/Project_DigitalWitness/data/training/normal")
# Sort: Path.glob yields files in filesystem-dependent order, which would make the
# MAX_VIDEOS_PER_CLASS subset (and the sequence order) differ between runs/machines.
normal_videos = sorted(normal_dir.glob("*.mp4"))

if MAX_VIDEOS_PER_CLASS:
    normal_videos = normal_videos[:MAX_VIDEOS_PER_CLASS]
    print(f"Limited to {MAX_VIDEOS_PER_CLASS} videos for testing")

normal_sequences = extract_features_from_videos(
    normal_videos,
    cnn,
    SEQUENCE_LENGTH,
    SLIDING_STRIDE,
    PROCESS_EVERY_FRAME
)

print(f"\nExtracted {len(normal_sequences)} sequences from {len(normal_videos)} normal videos")

In [None]:
# Extract features from shoplifting videos
print("=" * 50)
print("EXTRACTING FEATURES FROM SHOPLIFTING VIDEOS")
print("=" * 50)

shoplifting_dir = Path("/content/Project_DigitalWitness/data/training/shoplifting")
# Sort: Path.glob yields files in filesystem-dependent order, which would make the
# MAX_VIDEOS_PER_CLASS subset (and the sequence order) differ between runs/machines.
shoplifting_videos = sorted(shoplifting_dir.glob("*.mp4"))

if MAX_VIDEOS_PER_CLASS:
    shoplifting_videos = shoplifting_videos[:MAX_VIDEOS_PER_CLASS]
    print(f"Limited to {MAX_VIDEOS_PER_CLASS} videos for testing")

shoplifting_sequences = extract_features_from_videos(
    shoplifting_videos,
    cnn,
    SEQUENCE_LENGTH,
    SLIDING_STRIDE,
    PROCESS_EVERY_FRAME
)

print(f"\nExtracted {len(shoplifting_sequences)} sequences from {len(shoplifting_videos)} shoplifting videos")

In [None]:
# Prepare training data
print("=" * 50)
print("PREPARING TRAINING DATA")
print("=" * 50)

# Combine sequences and create labels
# Class 0 = normal, Class 1 = shoplifting
all_sequences = normal_sequences + shoplifting_sequences
all_labels = [0] * len(normal_sequences) + [1] * len(shoplifting_sequences)

print(f"\nTotal sequences: {len(all_sequences)}")
print(f"  - Normal (class 0): {len(normal_sequences)}")
print(f"  - Shoplifting (class 1): {len(shoplifting_sequences)}")

# Check class balance as minority count / majority count. The division is guarded:
# if extraction produced no sequences at all the majority count is 0.
class_counts = (len(normal_sequences), len(shoplifting_sequences))
majority_count = max(class_counts)
balance_ratio = min(class_counts) / majority_count if majority_count else 0.0
print(f"\nClass balance ratio: {balance_ratio:.2f}")
if balance_ratio < 0.5:
    print("  Warning: Classes are imbalanced. Consider adding more data to minority class.")

if len(all_sequences) < 10:
    print("\nERROR: Not enough training data! Need at least 10 sequences.")
else:
    print("\nData preparation complete!")

In [None]:
# Split into training and validation sets
from sklearn.model_selection import train_test_split

# Stratified shuffle-and-split: each class is split on its own, so both classes keep
# roughly the same share in training and validation, even for imbalanced data.
#
# NOTE: the split is per *sequence*, not per video. With the default stride, consecutive
# windows overlap, so windows from the same video can land on both sides and the
# validation metrics are optimistic. A leak-free split needs each sequence's source
# video, which the extraction step does not currently return.
split_rng = np.random.default_rng(42)  # local RNG: reproducible, leaves global numpy state alone
label_array = np.asarray(all_labels, dtype=int)

val_parts, train_parts = [], []
for class_id in np.unique(label_array):
    class_indices = split_rng.permutation(np.flatnonzero(label_array == class_id))
    n_val = int(len(class_indices) * VAL_SPLIT)
    val_parts.append(class_indices[:n_val])
    train_parts.append(class_indices[n_val:])

no_indices = np.array([], dtype=int)
val_indices = split_rng.permutation(np.concatenate(val_parts)) if val_parts else no_indices
train_indices = split_rng.permutation(np.concatenate(train_parts)) if train_parts else no_indices

train_sequences = [all_sequences[i] for i in train_indices]
train_labels = [all_labels[i] for i in train_indices]
val_sequences = [all_sequences[i] for i in val_indices]
val_labels = [all_labels[i] for i in val_indices]

print(f"Training samples: {len(train_sequences)}")
print(f"Validation samples: {len(val_sequences)}")


def _print_distribution(title, labels):
    """Print per-class counts and percentages; safe for an empty label list."""
    total = max(len(labels), 1)
    n_normal = labels.count(0)
    n_shoplifting = labels.count(1)
    print(f"\n{title} set distribution:")
    print(f"  - Normal: {n_normal} ({n_normal/total*100:.1f}%)")
    print(f"  - Shoplifting: {n_shoplifting} ({n_shoplifting/total*100:.1f}%)")


# Check label distribution
_print_distribution("Training", train_labels)
_print_distribution("Validation", val_labels)

## 7. Define LSTM Model

In [None]:
import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    """
    Bidirectional LSTM with attention pooling for behavior classification.

    Architecture:
        Input (batch, seq_len, input_dim)
            ↓
        Bidirectional LSTM (num_layers layers)  -> (batch, seq_len, 2 * hidden_dim)
            ↓
        Attention: per-step score -> softmax over time -> weighted sum
            ↓
        Fully connected head (dropout -> linear -> ReLU -> dropout -> linear)
            ↓
        Output logits (batch, num_classes)  [normal, shoplifting]
    """

    def __init__(self, input_dim=512, hidden_dim=256, num_layers=2,
                 num_classes=2, dropout=0.3):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # Bidirectional LSTM; nn.LSTM only applies dropout between stacked layers,
        # so it is disabled for a single layer.
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True
        )

        # Attention layer: maps each time step's LSTM output to one scalar score.
        attention_dim = hidden_dim * 2  # Bidirectional doubles the dim
        self.attention = nn.Sequential(
            nn.Linear(attention_dim, attention_dim // 2),
            nn.Tanh(),
            nn.Linear(attention_dim // 2, 1)
        )

        # Output layers
        self.fc = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(attention_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x, mask=None):
        """
        Args:
            x: Float tensor of shape (batch, seq_len, input_dim).
            mask: Optional tensor of shape (batch, seq_len); truthy entries mark real
                time steps. Masked (e.g. padding) steps receive zero attention weight.
                A row with no valid step falls back to uniform weights. When None,
                every step is attended. The LSTM itself still runs over every step;
                the mask only affects attention pooling.

        Returns:
            Tuple ``(logits, attention_weights)``. ``logits`` has shape
            (batch, num_classes) and is not softmaxed. ``attention_weights`` has shape
            (batch, seq_len), and each row sums to 1.
        """
        # LSTM forward pass
        lstm_out, _ = self.lstm(x)  # (batch, seq_len, hidden*2)

        # Attention scores
        scores = self.attention(lstm_out)  # (batch, seq_len, 1)
        if mask is not None:
            keep = mask.to(device=scores.device, dtype=torch.bool).unsqueeze(-1)
            # Use the most negative finite value rather than -inf, so a fully masked
            # row yields uniform weights instead of NaNs.
            scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
        attention_weights = torch.softmax(scores, dim=1)

        # Weighted sum (context vector)
        context = torch.sum(attention_weights * lstm_out, dim=1)  # (batch, hidden*2)

        # Classification
        output = self.fc(context)  # (batch, num_classes)

        return output, attention_weights.squeeze(-1)

# Initialize model
# Seed torch's RNG (CPU and CUDA) right before building the model so the initial
# weights are the same on every Restart & Run All.
torch.manual_seed(42)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = LSTMClassifier(
    input_dim=512,  # must match the CNN feature dimension
    hidden_dim=LSTM_HIDDEN_DIM,
    num_layers=LSTM_NUM_LAYERS,
    num_classes=2,
    dropout=LSTM_DROPOUT
).to(device)

print(f"LSTM model initialized on {device}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

## 8. Train the Model

In [None]:
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

def pad_sequences(sequences, max_len=None, dtype=np.float64):
    """Stack variable-length feature sequences into one zero-padded 3-D array.

    Args:
        sequences: Non-empty list of 2-D arrays shaped (length_i, feature_dim). All
            arrays must share the feature_dim of the first one.
        max_len: Target length. Defaults to the longest sequence. Longer sequences are
            truncated and shorter ones are zero-padded at the end.
        dtype: dtype of the returned array (float64 by default).

    Returns:
        np.ndarray of shape (len(sequences), max_len, feature_dim).

    Raises:
        ValueError: if ``sequences`` is empty.
    """
    if len(sequences) == 0:
        raise ValueError("pad_sequences() needs at least one sequence")

    if max_len is None:
        max_len = max(seq.shape[0] for seq in sequences)

    feature_dim = sequences[0].shape[1]
    padded = np.zeros((len(sequences), max_len, feature_dim), dtype=dtype)

    for i, seq in enumerate(sequences):
        length = min(seq.shape[0], max_len)
        padded[i, :length, :] = seq[:length]

    return padded

# Prepare data: zero-pad each split into one dense (N, T, feature_dim) array.
X_train = pad_sequences(train_sequences)
y_train = np.array(train_labels)
X_val = pad_sequences(val_sequences)
y_val = np.array(val_labels)

# Create data loaders
train_dataset = TensorDataset(
    torch.FloatTensor(X_train),
    torch.LongTensor(y_train)
)
val_dataset = TensorDataset(
    torch.FloatTensor(X_val),
    torch.LongTensor(y_val)
)

# A dedicated, seeded generator makes the per-epoch shuffle order reproducible and
# independent of anything else that draws from torch's global RNG.
train_generator = torch.Generator().manual_seed(42)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          generator=train_generator)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

In [None]:
# Training loop
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)


def _run_epoch(net, loader, loss_fn, run_device, opt=None):
    """Run one pass over `loader`: trains when `opt` is given, otherwise evaluates.

    Returns:
        (mean loss per sample, accuracy) over every sample yielded by ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples (e.g. an empty validation split).
    """
    training = opt is not None
    net.train(training)
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for batch_X, batch_y in loader:
            batch_X, batch_y = batch_X.to(run_device), batch_y.to(run_device)
            if training:
                opt.zero_grad()
            outputs, _ = net(batch_X)
            loss = loss_fn(outputs, batch_y)
            if training:
                loss.backward()
                opt.step()
            # Weight by batch size so a short final batch doesn't skew the epoch mean.
            loss_sum += loss.item() * batch_y.size(0)
            n_correct += (outputs.argmax(dim=1) == batch_y).sum().item()
            n_seen += batch_y.size(0)
    if n_seen == 0:
        raise ValueError("DataLoader yielded no samples; check the train/validation split.")
    return loss_sum / n_seen, n_correct / n_seen


# History tracking
history = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': []
}

best_val_loss = float('inf')
best_model_state = None
best_epoch = 0  # 1-based epoch whose weights are in best_model_state (0 = none yet)
patience_counter = 0

print("=" * 50)
print("TRAINING LSTM CLASSIFIER")
print("=" * 50)
print(f"Epochs: {EPOCHS}, Batch size: {BATCH_SIZE}, LR: {LEARNING_RATE}")
print("=" * 50)

for epoch in range(EPOCHS):
    avg_train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)
    avg_val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

    # Save history
    history['train_loss'].append(avg_train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(avg_val_loss)
    history['val_acc'].append(val_acc)

    # Early stopping check
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        best_epoch = epoch + 1
        # state_dict() returns tensors that share storage with the live parameters, so a
        # shallow dict copy would keep changing as training continues. Clone a snapshot.
        best_model_state = {name: tensor.detach().clone()
                            for name, tensor in model.state_dict().items()}
        patience_counter = 0
    else:
        patience_counter += 1

    # Print progress
    if (epoch + 1) % 5 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:3d}/{EPOCHS}: "
              f"Train Loss={avg_train_loss:.4f}, Acc={train_acc:.4f} | "
              f"Val Loss={avg_val_loss:.4f}, Acc={val_acc:.4f}")

    # Early stopping
    if patience_counter >= EARLY_STOPPING:
        print(f"\nEarly stopping at epoch {epoch + 1}!")
        break

# Load best model
if best_model_state is not None:
    model.load_state_dict(best_model_state)
    print(f"\nLoaded best model from epoch {best_epoch} (val_loss: {best_val_loss:.4f})")

print("\n" + "=" * 50)
print("TRAINING COMPLETE")
print("=" * 50)
if history['train_acc']:
    print(f"Last-epoch Training Accuracy: {history['train_acc'][-1]:.1%}")
    print(f"Last-epoch Validation Accuracy: {history['val_acc'][-1]:.1%}")
if best_epoch:
    print(f"Restored model (epoch {best_epoch}) Validation Accuracy: "
          f"{history['val_acc'][best_epoch - 1]:.1%}")

In [None]:
# Plot training history
import matplotlib.pyplot as plt

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# One panel per metric: (axis, train key, val key, train label, val label, y label, title)
panel_specs = [
    (ax1, 'train_loss', 'val_loss', 'Train Loss', 'Val Loss', 'Loss',
     'Training & Validation Loss'),
    (ax2, 'train_acc', 'val_acc', 'Train Acc', 'Val Acc', 'Accuracy',
     'Training & Validation Accuracy'),
]

for panel_ax, train_key, val_key, train_label, val_label, y_label, panel_title in panel_specs:
    panel_ax.plot(history[train_key], label=train_label)
    panel_ax.plot(history[val_key], label=val_label)
    panel_ax.set_xlabel('Epoch')
    panel_ax.set_ylabel(y_label)
    panel_ax.set_title(panel_title)
    panel_ax.legend()
    panel_ax.grid(True)

plt.tight_layout()
plt.show()

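## 9. Evaluate Model

The metrics below are computed on the validation split. The split is made per sequence, not per video, and with the default stride consecutive windows overlap. Windows from the same video can therefore appear in both training and validation, so treat these numbers as optimistic.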

In [None]:
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Get predictions on validation set.
# The eval_* names keep this cell from overwriting `all_labels` from the data-prep cell.
# Otherwise, re-running the split cell afterwards would pair sequences with validation-only labels.
model.eval()
eval_preds = []
eval_labels = []

with torch.no_grad():
    for batch_X, batch_y in val_loader:
        outputs, _ = model(batch_X.to(device))
        eval_preds.extend(outputs.argmax(dim=1).cpu().numpy())
        eval_labels.extend(batch_y.numpy())

# Classification report
class_names = ['normal', 'shoplifting']
# Pin the label set, so a class absent from the validation labels/predictions neither
# breaks `target_names` nor shrinks the confusion matrix below 2x2.
class_ids = list(range(len(class_names)))

print("=" * 50)
print("CLASSIFICATION REPORT")
print("=" * 50)
print(classification_report(eval_labels, eval_preds, labels=class_ids,
                            target_names=class_names, zero_division=0))

# Confusion matrix
cm = confusion_matrix(eval_labels, eval_preds, labels=class_ids)
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')
plt.show()

## 10. Save Model

In [None]:
import json
from datetime import datetime

# Create models directory (parents=True so it also works if the project tree is incomplete)
models_dir = Path("/content/Project_DigitalWitness/models")
models_dir.mkdir(parents=True, exist_ok=True)

# Save model in the format expected by the application
model_path = models_dir / "lstm_classifier.pt"

torch.save({
    'model_state_dict': model.state_dict(),
    'config': {
        'input_dim': 512,
        'hidden_dim': LSTM_HIDDEN_DIM,
        'num_layers': LSTM_NUM_LAYERS,
        'num_classes': 2,
        'dropout': LSTM_DROPOUT,
        'bidirectional': True,
        'classes': ['normal', 'shoplifting']
    }
}, model_path)

print(f"Model saved to: {model_path}")

# The training cell restores the weights recorded at the lowest validation loss (when it
# recorded any), so also record that epoch's metrics. The `final_*` entries are from the
# last epoch trained. np.argmin picks the first minimum, matching the strict `<` used
# when the best weights were chosen.
best_idx = int(np.argmin(history['val_loss']))

# Save training info
info = {
    'training_date': datetime.now().isoformat(),
    'n_train_samples': len(train_sequences),
    'n_val_samples': len(val_sequences),
    'n_normal_videos': len(normal_videos),
    'n_shoplifting_videos': len(shoplifting_videos),
    'sequence_length': SEQUENCE_LENGTH,
    'stride': SLIDING_STRIDE,
    'epochs': len(history['train_loss']),
    'final_train_acc': history['train_acc'][-1],
    'final_val_acc': history['val_acc'][-1],
    'best_epoch': best_idx + 1,
    'best_val_loss': history['val_loss'][best_idx],
    'best_val_acc': history['val_acc'][best_idx],
    'classes': ['normal', 'shoplifting'],
    'config': {
        'hidden_dim': LSTM_HIDDEN_DIM,
        'num_layers': LSTM_NUM_LAYERS,
        'dropout': LSTM_DROPOUT,
        'learning_rate': LEARNING_RATE,
        'batch_size': BATCH_SIZE
    }
}

info_path = models_dir / "lstm_classifier_info.json"
with open(info_path, 'w') as f:
    json.dump(info, f, indent=2)

print(f"Training info saved to: {info_path}")

In [None]:
# Copy to Google Drive
import shutil

drive_dest = "/content/drive/MyDrive/DigitalWitness_Models/"
os.makedirs(drive_dest, exist_ok=True)

# Put the checkpoint and its metadata side by side in Drive.
for artifact in (model_path, info_path):
    shutil.copy(artifact, drive_dest)

banner = "=" * 50
print(banner)
print("MODEL SAVED TO GOOGLE DRIVE")
print(banner)
print(f"Location: {drive_dest}")
print("\nFiles:")
for file_name in ("lstm_classifier.pt", "lstm_classifier_info.json"):
    print(f"  - {file_name}")
print("\nDownload these files and place them in your local project's 'models/' folder.")

## 11. Test Inference (Optional)

In [None]:
# Test the model on a sample sequence (the first validation sample).
# Class names are defined here too, so this optional cell does not depend on the
# evaluation cell having been run.
class_names = ['normal', 'shoplifting']
model.eval()

test_seq = torch.FloatTensor(X_val[0:1]).to(device)
true_label = int(y_val[0])

with torch.no_grad():
    output, attention = model(test_seq)
    probs = torch.softmax(output, dim=1)
    pred_class = torch.argmax(probs, dim=1).item()

print("=" * 50)
print("SAMPLE INFERENCE")
print("=" * 50)
print(f"True label: {class_names[true_label]}")
print(f"Predicted: {class_names[pred_class]}")
print(f"\nProbabilities:")
print(f"  - normal: {probs[0][0].item():.2%}")
print(f"  - shoplifting: {probs[0][1].item():.2%}")
print(f"\nAttention weights shape: {attention.shape}")
# Indices of the time steps with the largest attention weight (fewer if the sequence is shorter).
print(f"Top 5 attention frames: {attention[0].argsort(descending=True)[:5].tolist()}")

---

## Done!

Your trained model is saved to Google Drive. To use it:

1. Download `lstm_classifier.pt` and `lstm_classifier_info.json` from Google Drive
2. Place them in your local project's `models/` folder
3. Run `python run.py your_video.mp4` to analyze videos

**Model outputs:**
- `normal` - Regular shopping behavior
- `shoplifting` - Suspicious/stealing behavior