# Importation des librairies

In [1]:
import os
import warnings
from typing import Dict, List, Optional, Tuple

import av # PyAV library for efficient video I/O
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.lr_scheduler as lr_scheduler
import torchvision.models as models # Needed for R(2+1)D model
import torchvision.transforms as T # Used for frame preprocessing
from torch.utils.data import Dataset, DataLoader
from torchvision.models.video import R2Plus1D_18_Weights

# Extraction des features

In [None]:
# Width of the clip embedding produced by R(2+1)D-18 once its final FC layer is removed.
VIDEO_FEATURE_DIM = 512
# Each clip fed to the backbone is TARGET_CLIP_LENGTH frames of TARGET_H x TARGET_W pixels.
TARGET_CLIP_LENGTH = 16
TARGET_H = TARGET_W = 112

# Cache for the lazily-loaded feature extractor; stays None until the first load.
_r2plus1d_feature_extractor = None

def get_r2plus1d_feature_extractor():
    """
    Lazily load the Kinetics-400 pre-trained R(2+1)D-18 model as a feature extractor.

    The final fully connected layer is replaced with ``nn.Identity`` so the model
    returns its pooled VIDEO_FEATURE_DIM-dimensional clip embedding instead of class
    logits. The model is put in eval mode, moved to CUDA when available, and cached in
    the module-level ``_r2plus1d_feature_extractor`` so later calls reuse it.

    Returns:
        The cached ``nn.Module`` feature extractor.
    """
    global _r2plus1d_feature_extractor
    if _r2plus1d_feature_extractor is None:
        print("Loading pre-trained r2plus1d_18 model for feature extraction...")
        # Load the R(2+1)D-18 model pre-trained on Kinetics-400
        model = models.video.r2plus1d_18(weights=R2Plus1D_18_Weights.KINETICS400_V1)

        # Drop the classification head so the pooled backbone features become the output.
        model.fc = nn.Identity()

        model.eval()  # inference only: fixes BatchNorm statistics, disables dropout
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        _r2plus1d_feature_extractor = model
        # nn.Module has no `.device` attribute (accessing it raised AttributeError on the
        # first call); report the device the weights actually live on instead.
        print(f"R(2+1)D Feature Extractor loaded on device: {next(model.parameters()).device}")

    return _r2plus1d_feature_extractor

def _decode_video_frames(video_path: str) -> Optional[torch.Tensor]:
    """
    Decode and preprocess every frame of the first video stream of ``video_path``.

    Frames are resized to (TARGET_H, TARGET_W), converted to float tensors and
    normalised with Kinetics-400 statistics.

    Returns:
        A float tensor of shape (num_frames, 3, TARGET_H, TARGET_W), or ``None`` (after
        emitting a ``UserWarning``) when PyAV cannot read the file, the file has no
        video stream, or no frame could be decoded.
    """
    preprocess = T.Compose([
        T.Resize((TARGET_H, TARGET_W)),
        T.ToTensor(),
        T.Normalize(
            mean=[0.43216, 0.394666, 0.37645],  # Kinetics-400 means
            std=[0.22803, 0.221459, 0.216328],  # Kinetics-400 stds
        ),
    ])

    frames = []
    try:
        with av.open(video_path) as container:
            if not container.streams.video:
                warnings.warn(f"No video stream found in {os.path.basename(video_path)}. Skipping.", UserWarning)
                return None
            stream = container.streams.video[0]
            for frame in container.decode(stream):
                frames.append(preprocess(frame.to_rgb().to_image()))
    # NOTE(review): some newer PyAV releases name this base class av.FFmpegError;
    # confirm av.AVError exists in the installed version.
    except av.AVError as e:
        warnings.warn(f"PyAV failed to load video {os.path.basename(video_path)}: {e}. Skipping.", UserWarning)
        return None

    if not frames:
        warnings.warn(f"Could not decode any frames from {os.path.basename(video_path)}. Skipping.", UserWarning)
        return None

    return torch.stack(frames).float()


def extract_features_from_video(video_path: str) -> Tuple[Optional[torch.Tensor], str]:
    """
    Extract R(2+1)D features from a raw video file using PyAV.

    The video is cut into consecutive, non-overlapping clips of TARGET_CLIP_LENGTH
    frames (trailing frames that do not fill a clip are dropped), and each clip is
    embedded by the backbone from ``get_r2plus1d_feature_extractor()``.

    Args:
        video_path: Path to the video file.

    Returns:
        ``(features, feature_save_path)``. ``features`` is a CPU tensor of shape
        (num_clips, feature_dim), or ``None`` when extraction was skipped: the feature
        file already exists, the video could not be decoded, or it is shorter than
        TARGET_CLIP_LENGTH frames. ``feature_save_path`` is the video path with its
        extension replaced by ``_features.npy``.
    """
    feature_save_path = os.path.splitext(video_path)[0] + '_features.npy'

    # Check before loading the model so a fully-extracted dataset never loads it.
    if os.path.exists(feature_save_path):
        print(f"  --> Features already exist: {os.path.basename(feature_save_path)}. Skipping.")
        return None, feature_save_path

    print(f"  --> Starting extraction for {os.path.basename(video_path)}")

    all_frames_tensor = _decode_video_frames(video_path)
    if all_frames_tensor is None:
        return None, feature_save_path

    total_frames = all_frames_tensor.size(0)
    if total_frames < TARGET_CLIP_LENGTH:
        warnings.warn(f"Video {os.path.basename(video_path)} is too short ({total_frames} frames). Skipping extraction.", UserWarning)
        return None, feature_save_path

    feature_extractor = get_r2plus1d_feature_extractor()
    # Feed clips to whatever device the model's weights live on.
    device = next(feature_extractor.parameters()).device

    num_clips = total_frames // TARGET_CLIP_LENGTH
    # (num_clips * T, C, H, W) -> (num_clips, T, C, H, W) -> (num_clips, C, T, H, W),
    # the layout R(2+1)D expects.
    video_clips = (
        all_frames_tensor[: num_clips * TARGET_CLIP_LENGTH]
        .view(num_clips, TARGET_CLIP_LENGTH, 3, TARGET_H, TARGET_W)
        .permute(0, 2, 1, 3, 4)
    )

    clip_batch_size = 32
    all_clip_features = []
    with torch.no_grad():
        for start in range(0, num_clips, clip_batch_size):
            # Only one batch at a time is moved to the device, bounding GPU memory use.
            clip_batch = video_clips[start:start + clip_batch_size].to(device)
            all_clip_features.append(feature_extractor(clip_batch).cpu())

    # (num_clips, feature_dim) on the CPU, ready for numpy saving.
    video_features_sequence = torch.cat(all_clip_features, dim=0)
    return video_features_sequence, feature_save_path

def extract(data_root_dir: str):
    """
    Extract and save R(2+1)D features for every video found under ``data_root_dir``.

    The directory tree is walked recursively. Every file whose lower-cased name ends in
    .mp4, .avi, .mov or .webm is passed to ``extract_features_from_video``. Whenever
    that returns features, they are written with ``np.save`` to the path it returns.
    A failure on one video is printed and the loop continues with the next one.

    Args:
        data_root_dir: Root directory to scan for videos.
    """
    if not os.path.isdir(data_root_dir):
        print(f"Error: Directory not found at {data_root_dir}")
        return

    print(f"Starting batch feature extraction from: {data_root_dir}")

    video_suffixes = ('.mp4', '.avi', '.mov', '.webm')
    video_files = [
        os.path.join(folder, name)
        for folder, _subfolders, names in os.walk(data_root_dir)
        for name in names
        if name.lower().endswith(video_suffixes)
    ]

    if not video_files:
        print(f"No video files found in {data_root_dir}. Ensure they are in class subfolders.")
        return

    print(f"Found {len(video_files)} video files to process.")

    for path in video_files:
        try:
            features, out_path = extract_features_from_video(path)
            if features is None:
                continue  # skipped: already extracted, unreadable, or too short
            np.save(out_path, features.numpy())
            print(f"  -- Successfully saved features of shape {features.shape} to: {os.path.basename(out_path)}")
        except Exception as e:
            # Batch boundary: report and keep going with the remaining videos.
            print(f"An unexpected error occurred processing {os.path.basename(path)}: {e}")

In [None]:
# Extract R(2+1)D features for every video under ./TrainValVideo; each result is saved
# as a .npy file next to its source video.
# NOTE(review): training below reads from ./TrainValFeatures, so the .npy files are
# presumably moved there afterwards — confirm.
extract(data_root_dir="./TrainValVideo")

# Data Loader

In [2]:
# Width of the R(2+1)D-18 clip embedding (the backbone's output once its FC head is removed).
VIDEO_FEATURE_DIM = 512
# Placeholder class count; setup_data_pipeline() overwrites it from the class folders it finds.
NUM_CLASSES = 2
# Clear the cached extractor so the next call goes through the loader defined below.
_r2plus1d_feature_extractor = None

def get_r2plus1d_feature_extractor():
    """
    Lazily load the Kinetics-400 pre-trained R(2+1)D-18 model as a feature extractor.

    The final fully connected layer is replaced with ``nn.Identity`` so the model
    returns its pooled clip embedding (VIDEO_FEATURE_DIM = 512 values) instead of
    class logits. The model is put in eval mode, kept on the CPU, and cached in the
    module-level ``_r2plus1d_feature_extractor`` so later calls reuse it.

    Returns:
        The cached ``nn.Module`` feature extractor.
    """
    global _r2plus1d_feature_extractor
    if _r2plus1d_feature_extractor is None:
        print("Loading pre-trained r2plus1d_18 model for feature extraction...")
        # The explicit weights enum selects the same Kinetics-400 checkpoint as the
        # deprecated `pretrained=True` flag.
        model = models.video.r2plus1d_18(weights=R2Plus1D_18_Weights.KINETICS400_V1)

        # Drop the classification head: the 512-d global-average-pooled features
        # become the model output.
        model.fc = nn.Identity()

        model.eval()  # inference only: fixes BatchNorm statistics, disables dropout
        _r2plus1d_feature_extractor = model
        print(f"R(2+1)D Feature Extractor loaded. Output dimension: {VIDEO_FEATURE_DIM}")

    return _r2plus1d_feature_extractor


# --- 1. PyTorch Custom Dataset for Video Classification ---

class VideoClassificationDataset(Dataset):
    """
    A PyTorch Dataset over a directory of class sub-folders.

    Each immediate sub-folder of ``data_root`` is a class (indices are assigned in
    sorted folder-name order). Every file directly inside a class folder whose name
    ends with one of ``extensions`` becomes a sample. Samples are either pre-extracted
    feature files (``.npy`` / ``.pt``) or raw ``.mp4`` videos, whose R(2+1)D features
    are extracted on the fly.

    Each item is ``(features, label)``: a 2-D float tensor of shape
    (num_clips, feature_dim) and the integer class index.
    """

    # Kinetics-400 normalisation statistics used by the pre-trained R(2+1)D backbone.
    _KINETICS_MEAN = [0.43216, 0.394666, 0.37645]
    _KINETICS_STD = [0.22803, 0.221459, 0.216328]
    _CLIP_LENGTH = 16          # frames per clip fed to the backbone
    _FRAME_SIZE = (112, 112)   # (H, W) every frame is resized to (no cropping)
    _CLIP_BATCH_SIZE = 32      # clips embedded per forward pass

    def __init__(self, data_root: str, extensions: Tuple[str, ...] = ('.npy', '.pt', '.mp4')):
        """
        Args:
            data_root: Directory whose sub-folders are the classes.
            extensions: File-name suffixes to include (any sequence of strings).
        """
        self.data_root = data_root
        # Stored as a tuple: immutable (no shared mutable default) and directly usable
        # with str.endswith.
        self.extensions = tuple(extensions)
        self.samples: List[Tuple[str, int]] = []  # (file path, class index)
        self.class_to_idx: Dict[str, int] = {}

        self._index_dataset()

    def _index_dataset(self):
        """Scan ``data_root`` for class folders and collect matching files as samples."""
        with os.scandir(self.data_root) as entries:
            # Sorted so the class -> index mapping is reproducible.
            class_names = sorted(entry.name for entry in entries if entry.is_dir())

        for class_idx, class_name in enumerate(class_names):
            self.class_to_idx[class_name] = class_idx
            class_dir = os.path.join(self.data_root, class_name)
            with os.scandir(class_dir) as entries:
                # Sorted so sample order does not depend on filesystem listing order.
                paths = sorted(
                    entry.path
                    for entry in entries
                    if entry.is_file() and entry.name.endswith(self.extensions)
                )
            self.samples.extend((path, class_idx) for path in paths)

        print(f"Found {len(self.samples)} files across {len(self.class_to_idx)} classes.")
        print(f"Class Mapping: {self.class_to_idx}")

    def _decode_video_frames(self, video_path: str) -> torch.Tensor:
        """
        Decode and preprocess every frame of the first video stream of ``video_path``.

        Frames are resized to ``_FRAME_SIZE``, converted to float tensors and normalised
        with Kinetics-400 statistics.

        Returns:
            A tensor of shape (num_frames, 3, H, W), or an empty tensor when PyAV fails
            or no frame could be decoded.
        """
        preprocess = T.Compose([
            T.Resize(self._FRAME_SIZE),
            T.ToTensor(),
            T.Normalize(mean=self._KINETICS_MEAN, std=self._KINETICS_STD),
        ])

        frames = []
        try:
            with av.open(video_path) as container:
                stream = container.streams.video[0]
                for frame in container.decode(stream):
                    frames.append(preprocess(frame.to_rgb().to_image()))
        # NOTE(review): some newer PyAV releases name this base class av.FFmpegError;
        # confirm av.AVError exists in the installed version.
        except av.AVError as e:
            warnings.warn(f"PyAV failed to load video {os.path.basename(video_path)}: {e}. Returning zero tensor.", UserWarning)
            return torch.empty(0)

        if not frames:
            # Treated like an unreadable video: the caller's too-short path returns zeros.
            return torch.empty(0)
        return torch.stack(frames).float()

    def _extract_features_from_video(self, video_path: str) -> torch.Tensor:
        """
        Extract R(2+1)D features for a raw video file.

        The video is cut into consecutive, non-overlapping clips of ``_CLIP_LENGTH``
        frames (trailing frames that do not fill a clip are dropped). Each clip is
        embedded by the backbone from ``get_r2plus1d_feature_extractor()``, on whatever
        device that backbone lives on.

        Returns:
            A CPU tensor of shape (num_clips, feature_dim). When the video cannot be
            decoded or has fewer than ``_CLIP_LENGTH`` frames, a warning is emitted and a
            single zero vector of shape (1, VIDEO_FEATURE_DIM) is returned instead.
        """
        feature_extractor = get_r2plus1d_feature_extractor()

        all_frames_tensor = self._decode_video_frames(video_path)
        total_frames = all_frames_tensor.size(0)

        if total_frames < self._CLIP_LENGTH:
            warnings.warn(f"Video {os.path.basename(video_path)} is too short ({total_frames} frames). Returning single zero feature vector.", UserWarning)
            # A single zero feature vector; collate_fn pads it like any other sequence.
            return torch.zeros(1, VIDEO_FEATURE_DIM)

        num_clips = total_frames // self._CLIP_LENGTH
        height, width = self._FRAME_SIZE
        # (num_clips * T, C, H, W) -> (num_clips, T, C, H, W) -> (num_clips, C, T, H, W),
        # the layout the R(2+1)D backbone expects.
        video_clips = (
            all_frames_tensor[: num_clips * self._CLIP_LENGTH]
            .view(num_clips, self._CLIP_LENGTH, 3, height, width)
            .permute(0, 2, 1, 3, 4)
        )

        device = next(feature_extractor.parameters()).device
        clip_features = []
        with torch.no_grad():
            for start in range(0, num_clips, self._CLIP_BATCH_SIZE):
                batch = video_clips[start:start + self._CLIP_BATCH_SIZE].to(device)
                clip_features.append(feature_extractor(batch).cpu())

        return torch.cat(clip_features, dim=0)

    def _load_features(self, feature_filepath: str) -> torch.Tensor:
        """Load (or extract, for .mp4) the feature tensor stored at ``feature_filepath``."""
        if feature_filepath.endswith('.mp4'):
            return self._extract_features_from_video(feature_filepath)
        if feature_filepath.endswith('.npy'):
            return torch.from_numpy(np.load(feature_filepath)).float()
        if feature_filepath.endswith('.pt'):
            # map_location keeps GPU-saved tensors on the CPU, where collate_fn pads them.
            return torch.load(feature_filepath, map_location='cpu').float()
        raise IOError(f"Unsupported file type: {feature_filepath}")

    def __len__(self) -> int:
        return len(self.samples)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        """
        Return ``(features, label)`` for sample ``idx``.

        Samples whose features are not 2-D are reported and replaced by the next sample
        (wrapping around). Each sample is tried at most once.

        Raises:
            IndexError: if ``idx`` is out of range or the dataset is empty.
            RuntimeError: if no sample in the dataset has 2-D features.
        """
        num_samples = len(self.samples)
        if num_samples == 0:
            raise IndexError(f"Index {idx} out of range for an empty dataset.")

        position = idx
        for _ in range(num_samples):
            feature_filepath, label = self.samples[position]
            features = self._load_features(feature_filepath)
            if features.dim() == 2:
                return features, label
            print(f"Warning: Features at {feature_filepath} have shape {features.shape}. Expected 2D. Skipping.")
            position = (position + 1) % num_samples

        raise RuntimeError(f"No sample in {self.data_root} has 2-D features (num_clips, feature_dim).")

# --- 2. Custom Collate Function for Padding ---

def collate_fn(batch: List[Tuple[torch.Tensor, int]]) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Zero-pad variable-length feature sequences into one batch tensor.

    Args:
        batch: ``(features, label)`` pairs; each ``features`` is (num_clips, feature_dim).

    Returns:
        ``(padded, labels)``. ``padded`` is (batch_size, max_num_clips, feature_dim).
        Shorter sequences are padded with zeros at the end. ``labels`` is a
        ``torch.long`` tensor of shape (batch_size,). No lengths or mask are
        returned, so padded steps are indistinguishable from real all-zero clips
        downstream.
    """
    sequences = [sample[0] for sample in batch]
    targets = [sample[1] for sample in batch]

    # pad_sequence appends `padding_value` rows after each shorter sequence.
    padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0.0)
    return padded, torch.tensor(targets, dtype=torch.long)

# --- 3. Example Usage to Integrate with Your Model ---

def setup_data_pipeline(data_root: str, batch_size: int, num_workers: int = 0, shuffle: bool = True) -> DataLoader:
    """
    Build a DataLoader of zero-padded batches over ``data_root``.

    Also sets the global NUM_CLASSES to the number of class folders found under
    ``data_root``. Each call overwrites it, so it reflects the most recently built
    dataset.

    Args:
        data_root: Directory whose sub-folders are the classes.
        batch_size: Samples per batch.
        num_workers: DataLoader worker processes (0 = load in the main process).
        shuffle: Whether to reshuffle samples every epoch.

    Returns:
        A DataLoader yielding ``(padded_features, labels)`` batches built by ``collate_fn``.
    """
    global NUM_CLASSES

    dataset = VideoClassificationDataset(data_root=data_root)

    detected_classes = len(dataset.class_to_idx)
    if detected_classes != NUM_CLASSES:
        print(f"--- INFO: Updating global NUM_CLASSES from {NUM_CLASSES} to {detected_classes} based on data folders. ---")
        NUM_CLASSES = detected_classes

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,  # pads variable-length clip sequences
    )

# Model

In [4]:
class SelfAttention(nn.Module):
    """
    Additive attention pooling over a sequence (intended for BiLSTM outputs).

    For each time step ``t`` the score is ``V_a(tanh(W_a(H_t)))``. Scores are
    softmax-normalised over time, and the context vector is the weighted sum of the
    inputs.
    """

    def __init__(self, hidden_dim, attention_dim):
        """
        Args:
            hidden_dim: Hidden size of ONE LSTM direction; the input width is assumed
                to be ``2 * hidden_dim`` (bidirectional output).
            attention_dim: Width of the intermediate attention projection.
        """
        super().__init__()
        # Bi-LSTM output is hidden_dim * 2
        self.hidden_dim = hidden_dim * 2

        self.W_a = nn.Linear(self.hidden_dim, attention_dim, bias=False)
        self.V_a = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, H, mask=None):
        """
        Args:
            H: (batch_size, sequence_length, hidden_dim * 2) sequence to pool.
            mask: Optional (batch_size, sequence_length) tensor; truthy entries mark
                real time steps, falsy ones (e.g. zero padding) receive zero weight.
                ``None`` (the default) attends over every step.

        Returns:
            ``(context, weights)`` of shapes (batch_size, hidden_dim * 2) and
            (batch_size, sequence_length). If a row is entirely masked, its weights
            fall back to uniform.
        """
        U = torch.tanh(self.W_a(H))
        scores = self.V_a(U).squeeze(-1)
        if mask is not None:
            keep = mask.to(device=scores.device, dtype=torch.bool)
            # The most negative finite value (not -inf) avoids NaNs for fully-masked rows.
            scores = scores.masked_fill(~keep, torch.finfo(scores.dtype).min)
        weights = torch.softmax(scores, dim=1)
        context = torch.sum(weights.unsqueeze(-1) * H, dim=1)

        return context, weights

class SA_LSTM_Classification_Model(nn.Module):
    """
    Sequence classifier: BiLSTM over clip features, attention pooling, then a linear head.

    Input ``features`` is (batch_size, sequence_length, video_feature_dim); the LSTM is
    built with ``batch_first=True``. The single-layer BiLSTM emits ``2 * hidden_dim``
    values per step. SelfAttention pools them into one vector per sample, and dropout
    (p=0.7) followed by a linear layer maps it to ``num_classes`` logits.

    NOTE(review): ``forward`` takes no mask, so zero-padded time steps added when
    batching variable-length sequences are seen by both the LSTM and the attention.
    """

    def __init__(self, video_feature_dim: int, hidden_dim: int, attention_dim: int, num_classes: int):
        super().__init__()

        bidirectional_width = 2 * hidden_dim

        # Submodules are created in the same order as before so that parameter
        # initialisation consumes the RNG identically and state_dict keys are unchanged.
        self.lstm = nn.LSTM(
            input_size=video_feature_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        self.attention = SelfAttention(hidden_dim, attention_dim)
        self.classifier = nn.Sequential(
            nn.Dropout(0.7),
            nn.Linear(bidirectional_width, num_classes),
        )

    def forward(self, features: torch.Tensor):
        """Return class logits of shape (batch_size, num_classes)."""
        sequence_out = self.lstm(features)[0]       # (B, T, 2 * hidden_dim); (h_n, c_n) unused
        pooled = self.attention(sequence_out)[0]    # (B, 2 * hidden_dim); weights unused
        return self.classifier(pooled)



# Exécution et tests

In [5]:
def evaluate_model(model: nn.Module, data_loader: DataLoader, criterion: nn.Module, device: torch.device) -> Tuple[float, float, torch.Tensor, torch.Tensor]:
    """
    Evaluate ``model`` on every batch of ``data_loader`` without tracking gradients.

    The model is switched to eval mode and left there. Each batch's loss is weighted
    by its batch size before averaging (this assumes ``criterion`` uses mean
    reduction). Accuracy is a percentage.

    Returns:
        ``(avg_loss, accuracy, all_predictions, all_labels)``. Predictions and labels
        are concatenated on the CPU. If the loader yields no samples, a warning is
        printed and ``(0.0, 0.0, empty, empty)`` is returned.
    """
    model.eval()
    # Compared with the processed count at the end as a sanity check.
    dataset_size = len(data_loader.dataset)

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch_features, batch_labels in data_loader:
            batch_features = batch_features.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_features)
            batch_loss = criterion(logits, batch_labels)
            loss_sum += batch_loss.item() * batch_features.size(0)

            batch_preds = logits.max(dim=1).indices
            pred_chunks.append(batch_preds.cpu())
            label_chunks.append(batch_labels.cpu())

            n_seen += batch_labels.size(0)
            n_correct += (batch_preds == batch_labels).sum().item()

    if n_seen == 0:
        print("Warning: No samples were processed during evaluation.")
        return 0.0, 0.0, torch.empty(0, dtype=torch.long), torch.empty(0, dtype=torch.long)

    if n_seen != dataset_size:
        print(f"CRITICAL WARNING: Processed {n_seen} samples, but dataset has {dataset_size}. Check for bad data!")

    return loss_sum / n_seen, 100 * n_correct / n_seen, torch.cat(pred_chunks), torch.cat(label_chunks)


In [6]:
def _train_one_epoch(model: nn.Module, data_loader: DataLoader, criterion: nn.Module,
                     optimizer: torch.optim.Optimizer, device: torch.device) -> Tuple[float, float]:
    """Run one optimisation pass over ``data_loader``; return (average loss, accuracy in %)."""
    model.train()
    running_loss = 0.0
    running_corrects = 0
    total_train_samples = 0

    for features, labels in data_loader:
        features, labels = features.to(device), labels.to(device)

        optimizer.zero_grad()
        logits = model(features)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight the (mean-reduced) batch loss by batch size for a per-sample average.
        running_loss += loss.item() * features.size(0)
        predicted = logits.detach().argmax(dim=1)
        running_corrects += (predicted == labels).sum().item()
        total_train_samples += labels.size(0)

    if total_train_samples == 0:
        raise RuntimeError("The training DataLoader produced no samples.")
    return running_loss / total_train_samples, 100 * running_corrects / total_train_samples


def run_classification_example(train_data_root: str, test_data_root: str):
    """
    Train the SA-LSTM classifier on ``train_data_root`` and validate on ``test_data_root``.

    Validation runs after every epoch. Whenever validation accuracy improves, the
    model's ``state_dict`` is saved to ``best_model_checkpoint_full.pt``. The learning
    rate is halved every 10 epochs (StepLR).

    Returns:
        ``(Train_accs, Train_losss, Val_accs, Val_losss, LR)``: per-epoch training
        accuracy (%), training loss, validation accuracy (%), validation loss and
        learning rate. If anything fails, the error is printed and the histories
        gathered so far (possibly empty) are returned. Callers can therefore always
        unpack five lists; previously the function returned ``None``.
    """
    print("\n--- Running SA-LSTM Classification Pipeline Example ---")

    Train_accs: List[float] = []
    Train_losss: List[float] = []
    Val_accs: List[float] = []
    Val_losss: List[float] = []
    LR: List[float] = []

    HIDDEN_DIM = 512
    ATTENTION_DIM = 512
    BATCH_SIZE = 4
    NUM_EPOCHS = 50

    try:
        train_loader = setup_data_pipeline(train_data_root, BATCH_SIZE, shuffle=True)
        test_loader = setup_data_pipeline(test_data_root, BATCH_SIZE, shuffle=False)

        # Each dataset derives label indices from its own folder names. If the folders
        # differ, the same index would silently mean different classes.
        train_classes = train_loader.dataset.class_to_idx
        test_classes = test_loader.dataset.class_to_idx
        if train_classes != test_classes:
            raise ValueError(
                f"Train/validation class folders differ: train={sorted(train_classes)}, "
                f"validation={sorted(test_classes)}"
            )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        model = SA_LSTM_Classification_Model(
            video_feature_dim=VIDEO_FEATURE_DIM,
            hidden_dim=HIDDEN_DIM,
            attention_dim=ATTENTION_DIM,
            num_classes=len(train_classes),
        ).to(device)

        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-3)
        scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)  # Reduce LR by half every 10 epochs

        best_test_accuracy = -1.0
        checkpoint_path = 'best_model_checkpoint_full.pt'

        print("\n--- Starting Training (Validation and Checkpointing enabled) ---")
        for epoch in range(NUM_EPOCHS):
            train_loss, train_acc = _train_one_epoch(model, train_loader, criterion, optimizer, device)
            test_loss, test_accuracy, _, _ = evaluate_model(model, test_loader, criterion, device)

            # LR used during this epoch (the scheduler is stepped at the end of the epoch).
            lr = optimizer.param_groups[0]["lr"]

            print(f"Epoch {epoch+1}/{NUM_EPOCHS} | Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.6f}% | Test Loss: {test_loss:.4f}, Test Acc: {test_accuracy:.6f}%, lr: {lr:3.6f}")

            Train_accs.append(train_acc)
            Train_losss.append(train_loss)
            Val_accs.append(test_accuracy)
            Val_losss.append(test_loss)
            LR.append(lr)

            # Keep only the weights (state_dict) of the best model seen so far.
            if test_accuracy > best_test_accuracy:
                print(f"    --> New best validation accuracy: {test_accuracy:.2f}%. Saving model to {checkpoint_path}")
                best_test_accuracy = test_accuracy
                torch.save(model.state_dict(), checkpoint_path)

            scheduler.step()

        print("\n--- Training and Evaluation Process Complete ---")
        print(f"--- Best model saved with Test Accuracy: {best_test_accuracy:.2f}% ---")

    except Exception as e:
        print(f"Error during data pipeline execution: {e}")
        if "CUDA error" in str(e):
            print("\nSUGGESTION: The CUDA error often means the real error occurred earlier. Try running with the environment variable CUDA_LAUNCH_BLOCKING=1 to pinpoint the exact line of code that failed.")

    return Train_accs, Train_losss, Val_accs, Val_losss, LR

In [7]:
# Train on the pre-extracted training features, validating on the held-out split after
# every epoch; the five per-epoch histories are kept for later plotting/inspection.
Train_accs, Train_losss, Val_accs, Val_losss, LR = run_classification_example(
    train_data_root='./TrainValFeatures/Train',
    test_data_root='./TrainValFeatures/Val',
)


--- Running SA-LSTM Classification Pipeline Example ---
Found 6513 files across 20 classes.
Class Mapping: {'advertisement': 0, 'animals_pets': 1, 'animation': 2, 'beauty_fashion': 3, 'cooking': 4, 'documentary': 5, 'education': 6, 'food_drink': 7, 'gaming': 8, 'howto': 9, 'kids_family': 10, 'movie_comedy': 11, 'music': 12, 'news_events_politics': 13, 'people': 14, 'science_technology': 15, 'sports_actions': 16, 'travel': 17, 'tv shows': 18, 'vehicles_autos': 19}
--- INFO: Updating global NUM_CLASSES from 2 to 20 based on data folders. ---
Found 497 files across 20 classes.
Class Mapping: {'advertisement': 0, 'animals_pets': 1, 'animation': 2, 'beauty_fashion': 3, 'cooking': 4, 'documentary': 5, 'education': 6, 'food_drink': 7, 'gaming': 8, 'howto': 9, 'kids_family': 10, 'movie_comedy': 11, 'music': 12, 'news_events_politics': 13, 'people': 14, 'science_technology': 15, 'sports_actions': 16, 'travel': 17, 'tv shows': 18, 'vehicles_autos': 19}

--- Starting Training (Validation and Che