In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
from torch.utils.data import DataLoader
from typing import Tuple, Optional
import logging

# Configure the root logger once for the whole notebook: INFO level, with a
# timestamp and the level name in front of every message.
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=logging.INFO)

# Module-level logger used by the cells below.
logger = logging.getLogger(__name__)

In [None]:
import os
import torch
from torch.utils.data import Dataset
import numpy as np
import cv2


class VideoLandmarkDataset(Dataset):
    """Dataset pairing a video clip with its per-frame landmark coordinates.

    Each annotation line has the form ``video_file,landmark_file``; both names
    are resolved against ``video_prefix`` / ``landmark_prefix`` respectively.
    A sample is the tuple ``(video, landmarks)`` where ``video`` is a float32
    tensor of shape ``(clip_len, C, crop_size, crop_size)`` scaled to [0, 1]
    and ``landmarks`` is a float32 tensor of shape ``(clip_len, N, 2)``.
    """

    def __init__(self, anno_path, video_prefix, landmark_prefix, mode, clip_len, frame_sample_rate, crop_size, short_side_size, args):
        """
        Initialize the VideoLandmarkDataset.

        Args:
            anno_path (str): Path to the annotation CSV file.
            video_prefix (str): Directory containing video files.
            landmark_prefix (str): Directory containing landmark files.
            mode (str): Dataset mode ('train', 'validation', 'test').
            clip_len (int): Number of frames per video clip.
            frame_sample_rate (int): Keep every ``frame_sample_rate``-th frame.
            crop_size (int): Frames are resized to ``crop_size x crop_size``.
            short_side_size (int): Stored but not used by this class.
            args (Namespace): Additional arguments (accepted for interface
                compatibility; not used by this class).

        Raises:
            ValueError: If ``clip_len`` or ``frame_sample_rate`` is below 1,
                or an annotation line is malformed.
            FileNotFoundError: If a referenced video or landmark file is missing.
        """
        # Fail early instead of hitting ZeroDivisionError / empty clips later
        # inside a DataLoader worker.
        if clip_len < 1:
            raise ValueError(f"clip_len must be >= 1, got {clip_len}")
        if frame_sample_rate < 1:
            raise ValueError(f"frame_sample_rate must be >= 1, got {frame_sample_rate}")

        self.mode = mode
        self.clip_len = clip_len
        self.frame_sample_rate = frame_sample_rate
        self.crop_size = crop_size
        self.short_side_size = short_side_size

        self.video_prefix = video_prefix
        self.landmark_prefix = landmark_prefix

        # Load annotations
        self.annotations = self._load_annotations(anno_path)

    def _load_annotations(self, anno_path):
        """
        Load ``(video_path, landmark_path)`` pairs from the annotation CSV.

        Blank lines are ignored and whitespace around each field is stripped.

        Raises:
            ValueError: If a non-blank line does not contain exactly two fields.
            FileNotFoundError: If a referenced file does not exist.
        """
        annotations = []
        with open(anno_path, 'r') as file:
            for line_no, line in enumerate(file, start=1):
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing newline at EOF).
                    continue

                fields = [field.strip() for field in line.split(',')]
                if len(fields) != 2:
                    raise ValueError(
                        f"Malformed annotation at {anno_path}:{line_no}: "
                        f"expected 'video_file,landmark_file', got {line!r}"
                    )
                video_file, landmark_file = fields
                video_path = os.path.join(self.video_prefix, video_file)
                landmark_path = os.path.join(self.landmark_prefix, landmark_file)

                if not os.path.exists(video_path):
                    raise FileNotFoundError(f"Video file not found: {video_path}")
                if not os.path.exists(landmark_path):
                    raise FileNotFoundError(f"Landmark file not found: {landmark_path}")

                annotations.append((video_path, landmark_path))
        return annotations

    def _load_video(self, video_path):
        """
        Decode the first ``clip_len`` sampled frames of a video.

        Every ``frame_sample_rate``-th frame (starting with frame 0) is kept
        and resized to ``crop_size x crop_size``.

        Returns:
            torch.Tensor: float32 tensor of shape (T, C, H, W) in [0, 1].

        Raises:
            ValueError: If fewer than ``clip_len`` frames could be sampled
                (this includes a video that cannot be opened).
        """
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            count = 0
            # Stop decoding as soon as the clip is full; frames past that
            # point would be discarded anyway.
            while len(frames) < self.clip_len:
                success, frame = cap.read()
                if not success:
                    break
                if count % self.frame_sample_rate == 0:
                    # NOTE(review): frames are kept in OpenCV's channel order
                    # as decoded; confirm the consumer expects that order.
                    frames.append(cv2.resize(frame, (self.crop_size, self.crop_size)))
                count += 1
        finally:
            # Always free the decoder handle, even if resizing raised.
            cap.release()

        if len(frames) < self.clip_len:
            raise ValueError(f"Insufficient frames in video: {video_path}")

        clip = np.stack(frames).transpose(0, 3, 1, 2)  # (T, H, W, C) -> (T, C, H, W)
        return torch.tensor(clip, dtype=torch.float32) / 255.0

    def _load_landmarks(self, landmark_path):
        """
        Load landmark coordinates from a comma-separated text file.

        The values are reshaped to ``(clip_len, num_landmarks, 2)``.

        Raises:
            ValueError: If the number of values in the file cannot be arranged
                into ``clip_len`` frames of (x, y) pairs.
        """
        landmarks = np.loadtxt(landmark_path, delimiter=',')  # CSV with x, y columns
        # NOTE(review): landmarks are not subsampled with frame_sample_rate,
        # so the file is presumably already aligned with the sampled frames.
        if landmarks.size == 0 or landmarks.size % (self.clip_len * 2) != 0:
            raise ValueError(
                f"Landmark file {landmark_path} holds {landmarks.size} values, "
                f"which cannot be split into {self.clip_len} frames of (x, y) pairs"
            )
        landmarks = landmarks.reshape(self.clip_len, -1, 2)  # (T, num_landmarks, 2)
        return torch.tensor(landmarks, dtype=torch.float32)

    def __getitem__(self, index):
        """
        Return the ``(video, landmarks)`` pair for the annotation at ``index``.
        """
        video_path, landmark_path = self.annotations[index]
        video = self._load_video(video_path)
        landmarks = self._load_landmarks(landmark_path)
        return video, landmarks

    def __len__(self):
        """
        Return the total number of samples.
        """
        return len(self.annotations)


In [None]:
def build_dataset(is_train, test_mode, args):
    """
    Construct the video/landmark dataset for one split.

    The split name is 'train' when ``is_train`` is truthy, otherwise 'test'
    when ``test_mode`` is truthy, otherwise 'validation'. The annotation file
    is expected at ``<args.data_path>/<split>.csv``.

    Args:
        is_train (bool): Select the training split.
        test_mode (bool): Select the test split (only consulted when
            ``is_train`` is falsy).
        args (Namespace): Configuration providing ``data_path``,
            ``video_prefix``, ``landmark_prefix``, ``num_frames``,
            ``sampling_rate``, ``input_size`` and ``short_side_size``.

    Returns:
        tuple: ``(dataset, nb_classes)`` where ``nb_classes`` is always None
        because this is a regression task, not classification.

    Raises:
        FileNotFoundError: If the annotation file for the split is missing.
    """
    # Resolve the split name with explicit branches.
    if is_train:
        mode = 'train'
    elif test_mode:
        mode = 'test'
    else:
        mode = 'validation'

    anno_file = os.path.join(args.data_path, f"{mode}.csv")
    if os.path.exists(anno_file) is False:
        raise FileNotFoundError(f"Annotation file not found: {anno_file}")

    dataset_kwargs = {
        'anno_path': anno_file,
        'video_prefix': args.video_prefix,
        'landmark_prefix': args.landmark_prefix,
        'mode': mode,
        'clip_len': args.num_frames,
        'frame_sample_rate': args.sampling_rate,
        'crop_size': args.input_size,
        'short_side_size': args.short_side_size,
        'args': args,
    }
    dataset = VideoLandmarkDataset(**dataset_kwargs)

    print(f"Dataset built successfully for mode: {mode}")
    # No class count: landmarks are regressed, not classified.
    return dataset, None


In [None]:
class Parameters:
    """Plain configuration container for the dataset, model and training loop.

    All values are set in ``__init__`` and are meant to be edited (or
    overridden on the instance) before the object is passed around.
    """

    def __init__(self):
        # Dataset parameters
        self.data_set = 'UCF101'
        self.data_path = 'your_data_path'  # directory holding <split>.csv
        # Directories against which the video / landmark file names listed in
        # the annotation CSV are resolved. build_dataset() reads both
        # attributes, so they must exist; they default to the dataset root.
        self.video_prefix = self.data_path
        self.landmark_prefix = self.data_path
        self.prefix = ''
        self.split = ' '
        self.filename_tmpl = 'img_{:05}.jpg'
        self.nb_classes = 101
        self.use_decord = True
        self.trimmed = 60
        self.time_stride = 16

        # Model parameters
        self.input_size = 224       # frames are resized to input_size x input_size
        self.short_side_size = 224
        self.num_frames = 16        # frames per clip
        self.sampling_rate = 4      # keep every sampling_rate-th frame

        # Training parameters
        self.batch_size = 32
        self.num_workers = 4
        self.learning_rate = 0.001
        self.log_interval = 10      # log every log_interval training batches
        self.epochs = 100
        self.test_num_segment = 5
        self.test_num_crop = 3

In [None]:
class FacialLandmarkCNN(nn.Module):
    """Per-frame landmark regressor with a heatmap decoder on a ResNet-18 trunk.

    ``forward`` takes a clip of shape ``(B, T, C, H, W)`` and returns

    * ``feature_maps``: ``(B, T, 1, H, W)``, sigmoid-activated values in
      [0, 1] (i.e. probabilities, not logits);
    * ``landmarks``: ``(B, T, num_landmarks, 2)``.
    """

    # The refinement decoder stacks four stride-2 transposed convolutions,
    # so it enlarges its input by this factor along each spatial axis.
    _DECODER_UPSCALE = 16

    def __init__(self, args: Parameters, num_landmarks: int = 68):
        """
        Args:
            args: Configuration; ``input_size`` and ``num_frames`` are read.
            num_landmarks: Number of (x, y) landmarks predicted per frame.
        """
        super(FacialLandmarkCNN, self).__init__()
        self.args = args
        self.num_landmarks = num_landmarks

        # Load and modify ResNet backbone
        self.backbone = models.resnet18(pretrained=True)
        # Remove final FC and pooling layers
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])

        # Probe the backbone output shape. Run the probe in eval mode so the
        # all-zero dummy batch does not update the pretrained BatchNorm
        # running statistics.
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.no_grad():
            dummy_input = torch.zeros(1, 3, args.input_size, args.input_size)
            backbone_output = self.backbone(dummy_input)
            self.backbone_output_shape = backbone_output.shape[1:]
        self.backbone.train(was_training)
        backbone_channels = self.backbone_output_shape[0]

        # Temporal modeling with 3D convolutions.
        # NOTE(review): this module is not called in forward(); it is kept so
        # the parameter set (state_dict keys) stays unchanged.
        self.temporal_conv = nn.Sequential(
            nn.Conv3d(backbone_channels, 128, kernel_size=(3, 3, 3),
                      padding=(1, 1, 1)),
            nn.BatchNorm3d(128),
            nn.ReLU(inplace=True),
            nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1)),
            nn.BatchNorm3d(256),
            nn.ReLU(inplace=True)
        )

        # Pool to input_size / 16 so that the 16x decoder below lands back on
        # the input resolution (e.g. 14 -> 224) instead of overshooting it.
        pooled_size = max(1, args.input_size // self._DECODER_UPSCALE)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((pooled_size, pooled_size))

        # Refinement network: each transposed convolution doubles H and W.
        self.refinement = nn.Sequential(
            nn.Conv2d(backbone_channels, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1)
        )

        # Landmark prediction head (spatially averaged in forward()).
        self.landmark_head = nn.Sequential(
            nn.Conv2d(backbone_channels, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, num_landmarks * 2, kernel_size=1)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Kaiming-initialize the newly added heads only.

        The backbone is deliberately skipped: re-initializing it would throw
        away the pretrained ResNet weights loaded in ``__init__``.
        """
        for head in (self.temporal_conv, self.refinement, self.landmark_head):
            for m in head.modules():
                if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.ConvTranspose2d)):
                    nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm3d)):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: Clip tensor of shape ``(B, T, C, H, W)`` with ``T ==
               args.num_frames`` and ``H == W == args.input_size``.

        Returns:
            ``(feature_maps, landmarks)`` with shapes ``(B, T, 1, H, W)`` and
            ``(B, T, num_landmarks, 2)``. ``feature_maps`` has already been
            passed through a sigmoid.

        Raises:
            ValueError: If the frame count or spatial size does not match
                the configuration.
        """
        batch_size, num_frames, c, h, w = x.shape

        # Validate input dimensions
        if num_frames != self.args.num_frames:
            raise ValueError(f"Expected {self.args.num_frames} frames, got {num_frames}")
        if h != self.args.input_size or w != self.args.input_size:
            raise ValueError(f"Expected size {self.args.input_size}, got {h}x{w}")

        # Fold time into the batch axis so all frames go through the 2D
        # backbone at once. reshape() also accepts non-contiguous input.
        x = x.reshape(batch_size * num_frames, c, h, w)
        features = self.adaptive_pool(self.backbone(x))

        # Landmark regression: average the head's spatial map down to one
        # vector of 2 * num_landmarks values per frame.
        landmarks = self.landmark_head(features).mean(dim=(2, 3))
        landmarks = landmarks.view(batch_size, num_frames, self.num_landmarks, 2)

        # Heatmap decoding back to the input resolution.
        heat = self.refinement(features)
        if tuple(heat.shape[-2:]) != (h, w):
            # Only reached when input_size is not a multiple of 16.
            heat = F.interpolate(heat, size=(h, w), mode='bilinear', align_corners=False)
        feature_maps = torch.sigmoid(heat).reshape(batch_size, num_frames, 1, h, w)

        return feature_maps, landmarks

class FacialLandmarkDetector:
    """Training and inference wrapper around ``FacialLandmarkCNN``.

    Owns the model, the two loss functions (MSE on landmark coordinates and
    binary cross-entropy on landmark heatmaps), the AdamW optimizer, a
    ReduceLROnPlateau scheduler and the per-epoch loss history.
    """

    def __init__(self, args: Parameters):
        """
        Args:
            args: Configuration; ``learning_rate``, ``num_frames``,
                ``log_interval`` and ``epochs`` are read by this class.
        """
        self.args = args
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = FacialLandmarkCNN(args).to(self.device)

        # Initialize loss functions with reduction method
        self.criterion_landmarks = nn.MSELoss(reduction='mean')
        # The model's forward() already applies a sigmoid to the heatmaps, so
        # the loss must take probabilities. BCEWithLogitsLoss would apply a
        # second sigmoid on top.
        self.criterion_features = nn.BCELoss(reduction='mean')

        # Initialize optimizer with weight decay
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.args.learning_rate,
            weight_decay=0.01
        )

        # Learning-rate scheduler. The deprecated ``verbose`` flag is not
        # passed; train() logs learning-rate changes itself.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', factor=0.5, patience=5
        )

        # Initialize loss history
        self.train_losses = []
        self.val_losses = []

    @torch.no_grad()
    def create_landmark_map(self, landmarks: torch.Tensor, height: int,
                          width: int, sigma: float = 3.0) -> torch.Tensor:
        """
        Render one Gaussian heatmap per sample from its landmarks.

        Each in-bounds landmark contributes ``exp(-d^2 / (2 sigma^2))``; the
        contributions are merged with a pixel-wise maximum. Values farther
        than ``3 * sigma`` from a landmark are set to exactly 0, so the
        background stays at zero.

        Args:
            landmarks: Tensor of shape (batch_size, num_landmarks, 2) holding
                (x, y) pixel coordinates.
            height: Height of output heatmap
            width: Width of output heatmap
            sigma: Standard deviation for Gaussian kernel (must be > 0)

        Returns:
            Tensor of shape (batch_size, 1, height, width) on ``self.device``.

        Raises:
            ValueError: If ``sigma`` is not positive.
        """
        if sigma <= 0:
            raise ValueError(f"sigma must be positive, got {sigma}")

        landmarks = landmarks.to(self.device, dtype=torch.float32)
        batch_size = landmarks.size(0)
        landmark_map = torch.zeros(batch_size, 1, height, width,
                                 device=self.device)

        # Create coordinate grids; 'ij' indexing makes y_grid vary along rows
        # and x_grid along columns.
        y_grid, x_grid = torch.meshgrid(
            torch.arange(height, device=self.device, dtype=torch.float32),
            torch.arange(width, device=self.device, dtype=torch.float32),
            indexing='ij'
        )

        # Gaussians are truncated at 3 sigma.
        max_squared_dist = 9 * sigma * sigma

        # One vectorized bounds test for the whole batch. NaN coordinates
        # compare False and are therefore skipped as well.
        xs, ys = landmarks[..., 0], landmarks[..., 1]
        in_bounds = (xs >= 0) & (xs < width) & (ys >= 0) & (ys < height)

        for b in range(batch_size):
            points = landmarks[b][in_bounds[b]]  # (num_valid, 2)
            if points.numel() == 0:
                continue

            # Broadcast (1, H, W) grids against (num_valid, 1, 1) centers.
            dx = x_grid.unsqueeze(0) - points[:, 0].reshape(-1, 1, 1)
            dy = y_grid.unsqueeze(0) - points[:, 1].reshape(-1, 1, 1)
            squared_dist = dx ** 2 + dy ** 2

            gaussian = torch.exp(-squared_dist / (2 * sigma * sigma))
            gaussian = gaussian.masked_fill(squared_dist > max_squared_dist, 0.0)

            # Merge all landmarks of this sample with a pixel-wise maximum.
            landmark_map[b, 0] = gaussian.amax(dim=0)

        return landmark_map

    def _compute_loss(self, data: torch.Tensor,
                      target_landmarks: torch.Tensor) -> torch.Tensor:
        """Run the model on one on-device batch and return landmark + heatmap loss."""
        feature_maps, predicted_landmarks = self.model(data)

        # Flatten (B, T, K, 2) -> (B*T, K, 2), render, then restore the
        # (B, T, 1, H, W) layout of the predictions.
        target_maps = self.create_landmark_map(
            target_landmarks.reshape(-1, target_landmarks.size(-2), 2),
            data.size(-2), data.size(-1)
        ).view_as(feature_maps)

        landmark_loss = self.criterion_landmarks(
            predicted_landmarks, target_landmarks
        )
        feature_loss = self.criterion_features(feature_maps, target_maps)
        return landmark_loss + feature_loss

    @staticmethod
    def _mean_loss(total_loss: float, processed: int, phase: str) -> float:
        """Average over the batches that actually ran; NaN if none did."""
        if processed == 0:
            logger.warning(f'No valid {phase} batches were processed')
            return float('nan')
        return total_loss / processed

    def train_epoch(self, train_loader: DataLoader) -> float:
        """Train for one epoch and return the mean loss over processed batches.

        Batches whose frame count differs from ``args.num_frames`` are
        skipped and do not count towards the average. Returns NaN when no
        batch was processed.
        """
        self.model.train()
        total_loss = 0.0
        processed = 0
        num_batches = len(train_loader)

        for batch_idx, (data, target_landmarks) in enumerate(train_loader):
            # Skip invalid batches
            if data.size(1) != self.args.num_frames:
                continue

            # Move data to device
            data = data.to(self.device)
            target_landmarks = target_landmarks.to(self.device)

            self.optimizer.zero_grad()
            loss = self._compute_loss(data, target_landmarks)

            # Backward pass
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            batch_loss = loss.item()
            total_loss += batch_loss
            processed += 1

            # Log progress
            if batch_idx % self.args.log_interval == 0:
                logger.info(
                    f'Train Batch: {batch_idx}/{num_batches} '
                    f'Loss: {batch_loss:.6f}'
                )

        return self._mean_loss(total_loss, processed, 'training')

    @torch.no_grad()
    def validate(self, val_loader: DataLoader) -> float:
        """Evaluate the model and return the mean loss over processed batches.

        Uses the same batch-skipping rule and NaN convention as
        ``train_epoch``.
        """
        self.model.eval()
        total_loss = 0.0
        processed = 0

        for data, target_landmarks in val_loader:
            if data.size(1) != self.args.num_frames:
                continue

            data = data.to(self.device)
            target_landmarks = target_landmarks.to(self.device)

            total_loss += self._compute_loss(data, target_landmarks).item()
            processed += 1

        return self._mean_loss(total_loss, processed, 'validation')

    def train(self, train_loader: DataLoader, val_loader: DataLoader,
              num_epochs: Optional[int] = None,
              checkpoint_path: str = 'best_model.pth') -> None:
        """Train with validation, LR scheduling, checkpointing and early stopping.

        Args:
            train_loader: Loader yielding ``(video, landmarks)`` batches.
            val_loader: Loader used for the per-epoch validation loss.
            num_epochs: Number of epochs; ``None`` means ``args.epochs``.
            checkpoint_path: Where the best-so-far checkpoint is written.

        Training stops early after 10 consecutive epochs without a new best
        validation loss.
        """
        # ``is None`` rather than ``or`` so an explicit 0 is honoured.
        num_epochs = self.args.epochs if num_epochs is None else num_epochs
        best_val_loss = float('inf')
        patience = 10
        patience_counter = 0

        for epoch in range(num_epochs):
            logger.info(f'\nEpoch: {epoch+1}/{num_epochs}')

            # Train and validate
            train_loss = self.train_epoch(train_loader)
            val_loss = self.validate(val_loader)

            # Update learning rate and report a change, if any.
            lr_before = self.optimizer.param_groups[0]['lr']
            self.scheduler.step(val_loss)
            lr_after = self.optimizer.param_groups[0]['lr']
            if lr_after != lr_before:
                logger.info(f'Learning rate changed from {lr_before:.6e} to {lr_after:.6e}')

            # Store losses
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)

            logger.info(f'Training Loss: {train_loss:.6f}')
            logger.info(f'Validation Loss: {val_loss:.6f}')

            # Save best model (a NaN validation loss never compares as better)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # Save checkpoint
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'scheduler_state_dict': self.scheduler.state_dict(),
                    'best_val_loss': best_val_loss,
                    'train_losses': self.train_losses,
                    'val_losses': self.val_losses
                }, checkpoint_path)

                logger.info('Saved best model checkpoint')
            else:
                patience_counter += 1

            # Early stopping
            if patience_counter >= patience:
                logger.info(f'Early stopping triggered after {epoch + 1} epochs')
                break

    @torch.no_grad()
    def extract_features(self, video_tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run the model in eval mode on a ``(B, T, C, H, W)`` clip tensor.

        Returns:
            ``(feature_maps, landmarks)`` as produced by the model.

        Raises:
            ValueError: If the clip does not have ``args.num_frames`` frames.
        """
        self.model.eval()
        video_tensor = video_tensor.to(self.device)

        if video_tensor.size(1) != self.args.num_frames:
            raise ValueError(
                f"Expected {self.args.num_frames} frames, "
                f"got {video_tensor.size(1)}"
            )

        feature_maps, landmarks = self.model(video_tensor)
        return feature_maps, landmarks

In [None]:
def main():
    """Build the datasets and loaders, train the detector, save the final weights.

    Returns:
        FacialLandmarkDetector: The trained detector.

    Any exception raised while building data or training is logged and
    re-raised.
    """
    args = Parameters()
    logger.info("Initialized parameters")

    detector = FacialLandmarkDetector(args)
    logger.info("Created facial landmark detector")

    try:
        train_dataset, _ = build_dataset(is_train=True, test_mode=False, args=args)
        val_dataset, _ = build_dataset(is_train=False, test_mode=False, args=args)

        # Settings shared by both loaders; only shuffling differs.
        common_loader_options = dict(
            batch_size=args.batch_size,
            num_workers=args.num_workers,
            pin_memory=True,
        )
        train_loader = torch.utils.data.DataLoader(
            train_dataset, shuffle=True, **common_loader_options
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset, shuffle=False, **common_loader_options
        )

        logger.info("Starting training...")
        detector.train(train_loader, val_loader)
        logger.info("Training completed successfully")

        # Persist the end-of-training state alongside the loss history.
        final_state = {
            'model_state_dict': detector.model.state_dict(),
            'optimizer_state_dict': detector.optimizer.state_dict(),
            'scheduler_state_dict': detector.scheduler.state_dict(),
            'train_losses': detector.train_losses,
            'val_losses': detector.val_losses
        }
        torch.save(final_state, 'final_model.pth')
        logger.info("Saved final model")

    except Exception as exc:
        logger.error(f"An error occurred during training: {str(exc)}")
        raise

    return detector

def load_pretrained_detector(checkpoint_path: str, args: Parameters = None) -> FacialLandmarkDetector:
    """
    Build a detector and restore its state from a checkpoint file.

    The model weights are required. Optimizer state, scheduler state and the
    train/validation loss histories are restored only when the checkpoint
    contains them.

    Args:
        checkpoint_path: Path to the checkpoint file
        args: Optional Parameters object. If None, default parameters will be used.

    Returns:
        FacialLandmarkDetector: Loaded detector with pretrained weights

    Any exception raised while building or loading is logged and re-raised.
    """
    args = Parameters() if args is None else args

    try:
        detector = FacialLandmarkDetector(args)

        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(checkpoint_path, map_location=detector.device)

        # Mandatory part of every checkpoint.
        detector.model.load_state_dict(checkpoint['model_state_dict'])

        # Optional optimizer / scheduler state.
        optional_states = (
            ('optimizer_state_dict', detector.optimizer),
            ('scheduler_state_dict', detector.scheduler),
        )
        for state_key, component in optional_states:
            if state_key in checkpoint:
                component.load_state_dict(checkpoint[state_key])

        # Optional training history.
        for history_key in ('train_losses', 'val_losses'):
            if history_key in checkpoint:
                setattr(detector, history_key, checkpoint[history_key])

        logger.info(f"Successfully loaded pretrained model from {checkpoint_path}")
        return detector

    except Exception as exc:
        logger.error(f"Error loading pretrained model: {str(exc)}")
        raise

# Script entry point: run the full training pipeline and keep the trained
# detector in a module-level variable; failures are logged and re-raised.
if __name__ == "__main__":
    try:
        detector = main()
    except Exception as exc:
        logger.error(f"Program terminated with error: {str(exc)}")
        raise