# Face Detection and Recognition Training Pipeline

This pipeline outlines the overall steps for training and using a face recognition model:

## Table of Contents
1. [Data Preparation](#1-import-libraries)
2. [Image Preprocessing (Face Detection and Cropping)](#2-face-detection)
3. [InceptionResnetV1 Model Initialization](#3-inceptionresnetv1)
4. [Model Training (Fine-tuning)](#4-fine-tuning)
5. [Testing and Inference](#5-inference-and-testing)

## Pipeline Overview

1. **[Data Preparation](#1-import-libraries)**: Face image data is organized in a specific directory structure, with each subdirectory representing an identity (person).

2. **[Image Preprocessing (Face Detection and Cropping)](#2-face-detection)**:
   * The MTCNN (Multi-task Cascaded Convolutional Networks) model is used to detect faces in each image.
   * Detected faces are cropped and aligned to a standard size (160x160 pixels), then saved to a new directory. This step ensures that only faces are fed into training, removing noise from the surrounding environment.

3. **[InceptionResnetV1 Model Initialization](#3-inceptionresnetv1)**:
   * An InceptionResnetV1 model is initialized with weights pretrained on a large face dataset (VGGFace2). Its classification layer is sized to the number of classes (people) in your dataset so that it can be fine-tuned.

4. **[Model Training (Fine-tuning)](#4-fine-tuning)**:
   * The cropped and normalized face data is split into training and validation sets.
   * The model is trained using the cross-entropy loss function and the Adam optimizer.
   * The training process runs through multiple epochs, with performance monitoring on both training and validation sets.
   * After training, the model weights are saved.

5. **[Testing and Inference](#5-inference-and-testing)**:
   * The trained model is loaded back.
   * When a new image arrives, MTCNN is again used to detect and crop the face.
   * The cropped face is fed into the InceptionResnetV1 model to predict the identity and confidence of that prediction.
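
The final step above, turning the classifier's raw scores into an identity and a confidence, can be sketched in plain Python. This is a minimal illustration, not the notebook's implementation: the helper names `softmax` and `predict_from_logits` and the example logits are assumptions made for the sketch.

```python
import math

def softmax(logits):
    """Convert raw scores (logits) to probabilities that sum to 1."""
    # Subtract the maximum for numerical stability; the result is unchanged.
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict_from_logits(logits, class_names):
    """Return (class_name, confidence) for one sample's logits."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return class_names[best], probs[best]

# Hypothetical logits for a two-class model
name, conf = predict_from_logits([2.0, 0.0], ["vanhau", "vantoan"])
print(name, round(conf, 3))  # vanhau 0.881
```

The confidence is simply the largest softmax probability, so it always lies between 0 and 1 and the probabilities over all classes sum to 1.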

## 1. Import Libraries

In [1]:
from facenet_pytorch import MTCNN, InceptionResnetV1, fixed_image_standardization, training
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torch import optim
from torch.optim.lr_scheduler import MultiStepLR
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
import numpy as np
import os
import matplotlib.pyplot as plt
from PIL import Image



### 1.1. Define run parameters

The dataset should follow the VGGFace2/ImageNet-style directory layout, with one subdirectory per identity. Set `data_dir` to the location of the dataset you wish to fine-tune on.
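
To make the expected layout concrete, the sketch below builds a throwaway directory tree and derives the class-to-index mapping from it. It uses only the standard library and mimics how `torchvision.datasets.ImageFolder` assigns labels (sorted subdirectory names, numbered from 0). The helper `find_classes` and the placeholder files are assumptions written for this illustration.

```python
import os
import tempfile

def find_classes(root):
    """Map each class subdirectory to an integer index, in sorted order."""
    classes = sorted(
        d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
    )
    return {name: idx for idx, name in enumerate(classes)}

# Build a throwaway layout: <root>/<person>/<image files>
with tempfile.TemporaryDirectory() as root:
    for person in ("vantoan", "vanhau"):
        os.makedirs(os.path.join(root, person))
        # Empty placeholder file standing in for a real photo
        open(os.path.join(root, person, "image1.jpg"), "w").close()
    class_to_idx = find_classes(root)

print(class_to_idx)  # {'vanhau': 0, 'vantoan': 1}
```

Because the indices follow alphabetical order, adding or renaming a person's folder can shift the labels of the other classes.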

In [2]:
data_dir = 'Dataset/raw'

batch_size = 32
epochs = 8
workers = 0 if os.name == 'nt' else 8

In [3]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print('Running on device: {}'.format(device))

Running on device: cpu


### 1.2. TODO 1: Function dataset_info

In [4]:
def dataset_info(data_dir):
    """Print dataset information including number of classes and images per class
    TODO 1
    Args:
        data_dir (str): Path to the dataset directory (e.g., 'Dataset/raw')
    
    Expected Output:
        Should print information about the dataset in this format:
        
        Dataset Information for: Dataset/raw
        ==================================================
        Number of classes: 2
        Classes: ['vanhau', 'vantoan']
        
        Images per class:
        ------------------------------
          vanhau: 15 images
            Examples: OIP (1).jpg, OIP (2).jpg, OIP (3).jpg
            ... and 12 more
        
          vantoan: 12 images
            Examples: OIP (1).jpg, image1.jpg, image2.jpg
            ... and 9 more
        
        Total images in dataset: 27
        ==================================================
    
    Requirements:
        1. Check if the data_dir exists, if not print error message
        2. Find all subdirectories (these represent classes/people)
        3. Count total number of classes
        4. For each class, count how many image files (.jpg, .jpeg, .png, .bmp, .gif)
        5. Show first 3 image names as examples for each class
        6. Calculate and display total images across all classes
        7. Handle edge cases (empty directories, no classes found)
    
    Hints:
        - Use os.path.exists() to check if directory exists
        - Use os.listdir() to get directory contents
        - Use os.path.isdir() to check if item is a directory
        - Use str.lower().endswith() to check file extensions
        - Use len() to count items
        - Use sorted() to display classes in alphabetical order
    """
    pass

# Test the function with your dataset
dataset_info('Dataset/raw')

# def dataset_info(data_dir):
#     """Print dataset information including number of classes and images per class"""
#     # Step 1: Check if directory exists
#     if not os.path.exists(data_dir):
#         print(f"Error: Dataset directory '{data_dir}' not found.")
#         return
    
#     # Step 2: Find all subdirectories (classes)
#     items = os.listdir(data_dir)
#     classes = []
#     for item in items:
#         item_path = os.path.join(data_dir, item)
#         if os.path.isdir(item_path):
#             classes.append(item)
    
#     # Step 3: Handle no classes found
#     if not classes:
#         print(f"No class directories found in '{data_dir}'")
#         return
    
#     # Sort classes alphabetically
#     classes = sorted(classes)
    
#     # Step 4: Print header information
#     print(f"Dataset Information for: {data_dir}")
#     print("=" * 50)
#     print(f"Number of classes: {len(classes)}")
#     print(f"Classes: {classes}")
#     print()
#     print("Images per class:")
#     print("-" * 30)
    
#     # Step 5: Count images in each class
#     total_images = 0
#     image_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')
    
#     for class_name in classes:
#         class_path = os.path.join(data_dir, class_name)
        
#         # Get all files in class directory
#         files = os.listdir(class_path)
        
#         # Filter only image files
#         images = []
#         for file in files:
#             if file.lower().endswith(image_extensions):
#                 images.append(file)
        
#         num_images = len(images)
#         total_images += num_images
        
#         # Print class information
#         print(f"  {class_name}: {num_images} images")
        
#         # Show examples if images exist
#         if num_images > 0:
#             examples = images[:3]  # First 3 images
#             print(f"    Examples: {', '.join(examples)}")
#             if num_images > 3:
#                 print(f"    ... and {num_images - 3} more")
#         print()
    
#     # Step 6: Print total
#     print(f"Total images in dataset: {total_images}")
#     print("=" * 50)
    
# dataset_info('Dataset/raw')

## 2. Face Detection

In [5]:
mtcnn = MTCNN(
    image_size=160, margin=0, min_face_size=20,
    thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
    device=device
)

In [6]:
dataset = datasets.ImageFolder(data_dir, transform=transforms.Resize((512, 512)))
dataset.samples = [
    (p, p.replace(data_dir, data_dir + '_cropped'))
        for p, _ in dataset.samples
]
        
loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    collate_fn=training.collate_pil
)

for i, (x, y) in enumerate(loader):
    mtcnn(x, save_path=y)
    print('\rBatch {} of {}'.format(i + 1, len(loader)), end='')
    
# Remove mtcnn to reduce GPU memory usage
del mtcnn

Batch 1 of 1
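
The cropping cell derives each output path from the input path by string replacement. The sketch below isolates that mapping. Note one deliberate difference: the hypothetical helper `cropped_path` replaces only the first occurrence of `data_dir`, whereas the cell above replaces every occurrence.

```python
def cropped_path(path, data_dir):
    """Mirror a raw image path into the sibling '<data_dir>_cropped' tree."""
    # Replace only the first occurrence, in case the data_dir string happens
    # to appear again later in the path (for example in a file name).
    return path.replace(data_dir, data_dir + "_cropped", 1)

data_dir = "Dataset/raw"
print(cropped_path("Dataset/raw/vanhau/OIP (1).jpg", data_dir))
# Dataset/raw_cropped/vanhau/OIP (1).jpg
```

With the paths used in this notebook both variants give the same result. They differ only when the `data_dir` string recurs inside a path.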

## 3. InceptionResNetV1

### 3.1. Init model

In [7]:
resnet = InceptionResnetV1(
    classify=True,
    pretrained='vggface2',
    num_classes=len(dataset.class_to_idx)
).to(device)

### 3.2. Setup Optimizer, Scheduler, ...

In [8]:
optimizer = optim.Adam(resnet.parameters(), lr=0.001)
scheduler = MultiStepLR(optimizer, [5, 10])

trans = transforms.Compose([
    np.float32,
    transforms.ToTensor(),
    fixed_image_standardization
])
dataset = datasets.ImageFolder(data_dir + '_cropped', transform=trans)
img_inds = np.arange(len(dataset))
np.random.shuffle(img_inds)
train_inds = img_inds[:int(0.8 * len(img_inds))]
val_inds = img_inds[int(0.8 * len(img_inds)):]

train_loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(train_inds)
)
val_loader = DataLoader(
    dataset,
    num_workers=workers,
    batch_size=batch_size,
    sampler=SubsetRandomSampler(val_inds)
)
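
The transform chain above converts each image to `float32` before `ToTensor`, which keeps pixel values in the 0 to 255 range, since `ToTensor` only rescales `uint8` input. The last step then standardizes them. The sketch below shows the arithmetic on single values, assuming `fixed_image_standardization` applies the fixed map `(x - 127.5) / 128.0`. The function `standardize` is a stand-in written for illustration.

```python
def standardize(pixel):
    """Map a pixel value in [0, 255] to roughly [-1, 1]."""
    return (pixel - 127.5) / 128.0

for value in (0, 127.5, 255):
    print(value, "->", standardize(value))
# 0 -> -0.99609375
# 127.5 -> 0.0
# 255 -> 0.99609375
```

Unlike per-image normalization, this map uses the same constants for every image, so training and inference inputs are scaled identically.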

### 3.3. TODO 2: The Accuracy metric

In [9]:
def pseudo_accuracy(outputs, targets):
    """
    TODO 2
    Calculate accuracy using simple math operations (for educational purposes)
    
    Args:
        outputs: Model predictions (tensor with shape [batch_size, num_classes])
        targets: True labels (tensor with shape [batch_size])
    
    Returns:
        accuracy: Float value between 0 and 1
    
    Example:
        If we have 4 samples with predictions and true labels:
        outputs = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]]  # 2 classes
        targets = [1, 0, 1, 0]  # true labels
        
        Step 1: Find predicted class (highest value index)
        predicted = [1, 0, 1, 0]  # [0.9>0.1, 0.8>0.2, 0.7>0.3, 0.6>0.4]
        
        Step 2: Compare with true labels
        correct = [True, True, True, True]  # all predictions match targets
        
        Step 3: Calculate accuracy
        accuracy = 4/4 = 1.0 (100% correct)
    
    Your task:
        1. Convert PyTorch tensors to Python lists
        2. Find the predicted class for each sample (index of maximum value)
        3. Compare predictions with true labels
        4. Count correct predictions
        5. Calculate accuracy = correct_count / total_count
    
    Hints:
        - Use .tolist() to convert tensor to Python list
        - Use max() and list.index() to find the index of maximum value
        - Use sum() to count True values in a boolean list
        - Use len() to get total number of samples
    """
    pass

# def pseudo_accuracy(outputs, targets):
#     """Calculate accuracy using simple math operations (for educational purposes)"""
#     # Step 1: Convert PyTorch tensors to Python lists
#     outputs_list = outputs.tolist()
#     targets_list = targets.tolist()
    
#     # Step 2: Find predicted class for each sample (index of maximum value)
#     predicted_classes = []
#     for output_row in outputs_list:
#         # Find the index of maximum value
#         max_value = max(output_row)
#         predicted_class = output_row.index(max_value)
#         predicted_classes.append(predicted_class)
    
#     # Step 3: Compare predictions with true labels
#     correct_predictions = []
#     for i in range(len(predicted_classes)):
#         is_correct = predicted_classes[i] == targets_list[i]
#         correct_predictions.append(is_correct)
    
#     # Step 4: Count correct predictions
#     correct_count = sum(correct_predictions)  # sum() counts True as 1, False as 0
    
#     # Step 5: Calculate accuracy
#     total_count = len(targets_list)
#     accuracy = correct_count / total_count
    
#     return torch.tensor(accuracy)


def test_pseudo_accuracy():
    """Test function to verify pseudo_accuracy implementation"""
    print("Testing pseudo_accuracy function...")
    print("=" * 50)
    
    # Check if pseudo_accuracy is implemented
    try:
        # Try a simple test first to see if function is implemented
        test_outputs = torch.tensor([[0.1, 0.9]], dtype=torch.float32)
        test_targets = torch.tensor([1], dtype=torch.long)
        test_result = pseudo_accuracy(test_outputs, test_targets)
        
        # If we get here, function is implemented but might return None
        if test_result is None:
            print("❌ Function pseudo_accuracy() returns None - not implemented yet!")
            print("Please implement the function according to the TODO instructions.")
            return
            
    except (NotImplementedError, TypeError, AttributeError):
        print("❌ Function pseudo_accuracy() is not implemented yet!")
        print("Please implement the function according to the TODO instructions.")
        return
    except Exception as e:
        print(f"❌ Error in pseudo_accuracy() implementation: {str(e)}")
        print("Please check your implementation and try again.")
        return
    
    print("✅ Function pseudo_accuracy() is implemented! Running tests...")
    print()
    
    # Test case 1: Perfect accuracy
    # Simulating outputs for 2-class problem
    outputs_list = [
        [0.1, 0.9],  # predicted class 1, target should be 1
        [0.8, 0.2],  # predicted class 0, target should be 0  
        [0.3, 0.7],  # predicted class 1, target should be 1
        [0.9, 0.1]   # predicted class 0, target should be 0
    ]
    targets_list = [1, 0, 1, 0]
    
    # Convert to tensors (simulating PyTorch tensors)
    outputs_tensor = torch.tensor(outputs_list, dtype=torch.float32)
    targets_tensor = torch.tensor(targets_list, dtype=torch.long)
    
    try:
        # Test your function
        accuracy = pseudo_accuracy(outputs_tensor, targets_tensor)
        expected_accuracy = 1.0  # 100% correct
        
        print(f"Test Case 1 - Perfect Accuracy:")
        print(f"Outputs: {outputs_list}")
        print(f"Targets: {targets_list}")
        print(f"Your result: {accuracy}")
        print(f"Expected: {expected_accuracy}")
        
        # Handle different return types
        if accuracy is None:
            print(f"Status: ❌ FAIL - Function returns None")
        elif isinstance(accuracy, (int, float, torch.Tensor)):
            # Convert tensor to float for comparison
            acc_value = float(accuracy) if isinstance(accuracy, torch.Tensor) else accuracy
            status = "✅ PASS" if abs(acc_value - expected_accuracy) < 0.001 else "❌ FAIL"
            print(f"Status: {status}")
        else:
            print(f"Status: ❌ FAIL - Unexpected return type: {type(accuracy)}")
        print()
        
    except Exception as e:
        print(f"Test Case 1 - Error: {str(e)}")
        print("❌ FAIL - Exception occurred during test")
        print()
        return
    
    # Test case 2: 50% accuracy
    outputs_list_2 = [
        [0.8, 0.2],  # predicted class 0, target is 1 (wrong)
        [0.1, 0.9],  # predicted class 1, target is 1 (correct)
        [0.6, 0.4],  # predicted class 0, target is 1 (wrong) 
        [0.3, 0.7]   # predicted class 1, target is 1 (correct)
    ]
    targets_list_2 = [1, 1, 1, 1]
    
    outputs_tensor_2 = torch.tensor(outputs_list_2, dtype=torch.float32)
    targets_tensor_2 = torch.tensor(targets_list_2, dtype=torch.long)
    
    try:
        accuracy_2 = pseudo_accuracy(outputs_tensor_2, targets_tensor_2)
        expected_accuracy_2 = 0.5  # 50% correct (2 out of 4)
        
        print(f"Test Case 2 - 50% Accuracy:")
        print(f"Outputs: {outputs_list_2}")
        print(f"Targets: {targets_list_2}")
        print(f"Your result: {accuracy_2}")
        print(f"Expected: {expected_accuracy_2}")
        
        # Handle different return types
        if accuracy_2 is None:
            print(f"Status: ❌ FAIL - Function returns None")
        elif isinstance(accuracy_2, (int, float, torch.Tensor)):
            # Convert tensor to float for comparison
            acc_value_2 = float(accuracy_2) if isinstance(accuracy_2, torch.Tensor) else accuracy_2
            status = "✅ PASS" if abs(acc_value_2 - expected_accuracy_2) < 0.001 else "❌ FAIL"
            print(f"Status: {status}")
        else:
            print(f"Status: ❌ FAIL - Unexpected return type: {type(accuracy_2)}")
        print()
        
    except Exception as e:
        print(f"Test Case 2 - Error: {str(e)}")
        print("❌ FAIL - Exception occurred during test")
        print()
        return
    
    # Compare with facenet_pytorch's training.accuracy (only if both tests passed)
    try:
        if accuracy is not None and accuracy_2 is not None:
            pytorch_acc_1 = training.accuracy(outputs_tensor, targets_tensor)
            pytorch_acc_2 = training.accuracy(outputs_tensor_2, targets_tensor_2)
            
            print(f"Comparison with facenet_pytorch's training.accuracy:")
            print(f"Test 1 - Your: {accuracy}, Reference: {pytorch_acc_1}")
            print(f"Test 2 - Your: {accuracy_2}, Reference: {pytorch_acc_2}")
            
            # Convert to float for comparison
            acc_1_val = float(accuracy) if isinstance(accuracy, torch.Tensor) else accuracy
            acc_2_val = float(accuracy_2) if isinstance(accuracy_2, torch.Tensor) else accuracy_2
            pytorch_1_val = float(pytorch_acc_1)
            pytorch_2_val = float(pytorch_acc_2)
            
            if (abs(acc_1_val - pytorch_1_val) < 0.001 and 
                abs(acc_2_val - pytorch_2_val) < 0.001):
                print("🎉 Your implementation matches the reference accuracy!")
            else:
                print("❌ Your implementation doesn't match the reference accuracy.")
                print("Please check your logic and try again.")
        
    except Exception as e:
        print(f"Error comparing with PyTorch accuracy: {str(e)}")

# Test the pseudo_accuracy function with error handling
test_pseudo_accuracy()

Testing pseudo_accuracy function...
❌ Function pseudo_accuracy() returns None - not implemented yet!
Please implement the function according to the TODO instructions.


The accuracy function above is an educational exercise. For training, we switch to `training.accuracy`, the metric that ships with `facenet_pytorch`.

In [10]:
loss_fn = torch.nn.CrossEntropyLoss()
metrics = {
    'fps': training.BatchTimer(),
    'acc': training.accuracy
}
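
For intuition about the loss configured above, here is a minimal pure-Python sketch of cross-entropy for a single sample. The function `cross_entropy` and the logits are illustrative assumptions. `torch.nn.CrossEntropyLoss` computes the same quantity per sample and, by default, averages it over the batch.

```python
import math

def cross_entropy(logits, target):
    """Cross-entropy for one sample: -log(softmax(logits)[target])."""
    # log-sum-exp with the maximum subtracted for numerical stability
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

print(round(cross_entropy([0.0, 0.0], 0), 4))    # 0.6931 = ln(2), a coin flip
print(round(cross_entropy([5.0, 0.0], 0), 4))    # 0.0067, confident and correct
print(round(cross_entropy([100.0, 0.0], 1), 4))  # 100.0, confident and wrong
```

The loss grows with the margin by which a wrong class wins, while accuracy only counts whether the top class is right. A few very confident mistakes can therefore push the loss into the tens or hundreds without changing the accuracy.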

## 4. Fine-tuning

In [11]:
writer = SummaryWriter()
writer.iteration, writer.interval = 0, 10

print('\n\nInitial')
print('-' * 10)
resnet.eval()
training.pass_epoch(
    resnet, loss_fn, val_loader,
    batch_metrics=metrics, show_running=True, device=device,
    writer=writer
)

for epoch in range(epochs):
    print('\nEpoch {}/{}'.format(epoch + 1, epochs))
    print('-' * 10)

    resnet.train()
    training.pass_epoch(
        resnet, loss_fn, train_loader, optimizer, scheduler,
        batch_metrics=metrics, show_running=True, device=device,
        writer=writer
    )

    resnet.eval()
    training.pass_epoch(
        resnet, loss_fn, val_loader,
        batch_metrics=metrics, show_running=True, device=device,
        writer=writer
    )

writer.close()

# Save the trained model after training completes
model_save_path = 'facenet_vantoan_vanhau.pth'
torch.save(resnet.state_dict(), model_save_path)
print(f'\nModel saved to: {model_save_path}')

# Save class names for inference
class_names_save_path = 'class_names.txt'
with open(class_names_save_path, 'w') as f:
    for class_name in dataset.classes:
        f.write(f"{class_name}\n")
print(f'Class names saved to: {class_names_save_path}')
print(f'Classes: {dataset.classes}')



Initial
----------
Valid |     1/1    | loss:    0.6379 | fps:   18.5233 | acc:    0.5000   

Epoch 1/8
----------
Train |     1/1    | loss:    0.7728 | fps:    9.0614 | acc:    0.6000   
Valid |     1/1    | loss:    0.1957 | fps:   20.9738 | acc:    1.0000   

Epoch 2/8
----------
Train |     1/1    | loss:    0.1837 | fps:    9.5596 | acc:    0.9333   
Valid |     1/1    | loss:   22.0069 | fps:   24.8276 | acc:    0.5000   

Epoch 3/8
----------
Train |     1/1    | loss:    0.3759 | fps:   11.0099 | acc:    0.8667   
Valid |     1/1    | loss:   94.0064 | fps:   25.0130 | acc:    0.5000   

Epoch 4/8
----------
Train |     1/1    | loss:    0.0022 | fps:    9.4539 | acc:    1.0000   
Valid |     1/1    | loss:  122.8393 | fps:   17.0070 | acc:    0.5000   

Epoch 5/8
----------
Train |     1/1    | loss:    0.0571 | fps:   11.3488 | acc:    0.9333   
Valid |     1/1    | loss:  120.2414 | fps:   24.9019 | acc:    0.5000   

Epoch 6/8
----------
Train |     1/1    | loss:    0.3
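
The learning rate used in each epoch follows from the `MultiStepLR(optimizer, [5, 10])` schedule configured earlier. The sketch below reproduces that decay rule in plain Python, assuming the default decay factor `gamma=0.1` and one scheduler step per training epoch. The helper `multistep_lr` is hypothetical.

```python
def multistep_lr(base_lr, milestones, epoch, gamma=0.1):
    """Learning rate after `epoch` scheduler steps under multi-step decay."""
    # Each milestone that has been reached multiplies the rate by gamma once.
    passed = sum(1 for m in milestones if epoch >= m)
    return base_lr * gamma ** passed

# Epochs are printed 1-based to match the training log above.
for epoch in range(8):
    print(epoch + 1, multistep_lr(0.001, [5, 10], epoch))
```

Under these assumptions the first five epochs run at 0.001 and epochs 6 to 8 at roughly 0.0001. With `epochs = 8`, the second milestone at 10 is never reached.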

## 5. Inference and Testing
Test the trained model on sample images.

In [12]:
from PIL import Image
import matplotlib.pyplot as plt

# Load trained model for inference
def load_trained_model(model_path, num_classes, device):
    model = InceptionResnetV1(
        classify=True,
        pretrained='vggface2',
        num_classes=num_classes
    ).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    return model

# Initialize MTCNN for inference
mtcnn_inference = MTCNN(
    image_size=160, margin=0, min_face_size=20,
    thresholds=[0.6, 0.7, 0.7], factor=0.709, post_process=True,
    device=device
)

# Load class names with error handling
class_names_path = 'class_names.txt'
model_path = 'facenet_vantoan_vanhau.pth'

try:
    # Try to load class names from file
    with open(class_names_path, 'r') as f:
        class_names = [line.strip() for line in f.readlines()]
    print(f"Loaded class names from file: {class_names}")
except FileNotFoundError:
    # Fallback: use current dataset classes
    try:
        class_names = dataset.classes
        print(f"Using current dataset classes: {class_names}")
    except NameError:
        print("Error: No dataset or class names file found. Please run training first.")
        class_names = []

# Load trained model with error handling
if class_names and os.path.exists(model_path):
    try:
        model_inference = load_trained_model(model_path, len(class_names), device)
        print(f"Model loaded successfully for inference.")
    except Exception as e:
        print(f"Error loading model: {e}")
        model_inference = None
else:
    if not class_names:
        print("Cannot load model: No class names available.")
    else:
        print(f"Cannot load model: Model file not found at {model_path}")
        print("Please run the training cells first to create the model.")
    model_inference = None

Loaded class names from file: ['vanhau', 'vantoan']
Model loaded successfully for inference.
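
The model only outputs a class index, so the names file must preserve the class order used during training. The round trip can be sketched with the standard library alone. The helpers `save_class_names` and `load_class_names` are hypothetical. Unlike the loading code above, this version skips blank lines.

```python
import os
import tempfile

def save_class_names(path, class_names):
    """Write one class name per line, preserving order."""
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(class_names) + "\n")

def load_class_names(path):
    """Read class names back, skipping blank lines."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "class_names.txt")
    save_class_names(path, ["vanhau", "vantoan"])
    restored = load_class_names(path)

print(restored)  # ['vanhau', 'vantoan']
```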


In [17]:
def predict_image(image_path, model, mtcnn, class_names, device):
    """
    TODO 3
    Predict class of face in image using trained model
    
    Args:
        image_path (str): Path to the image file
        model: Trained InceptionResnetV1 model
        mtcnn: MTCNN face detection model
        class_names (list): List of class names
        device: PyTorch device (CPU or CUDA)
    
    Returns:
        tuple: (predicted_class, confidence_score)
               - predicted_class (str): Name of predicted class or error message
               - confidence_score (float): Confidence score between 0 and 1
    
    Example:
        If input image contains a face of "vanhau" class:
        return ("vanhau", 0.95)
        
        If no face detected:
        return ("No face detected", 0.0)
        
        If error occurs:
        return ("Error: description", 0.0)
    
    Your task:
        1. Load and convert image to RGB
        2. Use MTCNN to detect and crop face from image
        3. Check if face was detected, return appropriate message if not
        4. Preprocess the cropped face for model input
        5. Run inference with the trained model
        6. Apply softmax to get probabilities
        7. Find the class with highest probability
        8. Return predicted class name and confidence score
        9. Handle any exceptions that might occur
    
    Hints:
        - Use Image.open(image_path).convert('RGB') to load image
        - Use mtcnn(img) to detect and crop face
        - Use img_cropped.unsqueeze(0).to(device) to add batch dimension
        - Use torch.no_grad() context for inference
        - Use torch.nn.functional.softmax(outputs, dim=1) for probabilities
        - Use torch.max(probabilities, 1) to get max probability and index
        - Use predicted.item() to get class index, then class_names[index] for name
        - Use try-except to handle errors gracefully
    """
    pass

# def predict_image(image_path, model, mtcnn, class_names, device):
#     """Predict class of face in image"""
#     try:
#         # Step 1: Load and convert image to RGB
#         img = Image.open(image_path).convert('RGB')
        
#         # Step 2: Use MTCNN to detect and crop face
#         img_cropped = mtcnn(img)
        
#         # Step 3: Check if face was detected
#         if img_cropped is None:
#             return "No face detected", 0.0
        
#         # Step 4: Preprocess for model input (add batch dimension and move to device)
#         img_cropped = img_cropped.unsqueeze(0).to(device)
        
#         # Step 5: Run inference with trained model
#         with torch.no_grad():
#             outputs = model(img_cropped)
            
#             # Step 6: Apply softmax to get probabilities
#             probabilities = torch.nn.functional.softmax(outputs, dim=1)
            
#             # Step 7: Find class with highest probability
#             confidence, predicted = torch.max(probabilities, 1)
        
#         # Step 8: Get predicted class name and confidence score
#         predicted_class = class_names[predicted.item()]
#         confidence_score = confidence.item()
        
#         return predicted_class, confidence_score
        
#     except Exception as e:
#         # Step 9: Handle exceptions gracefully
#         return f"Error: {str(e)}", 0.0


def test_predict_image():
    """Test function to verify predict_image implementation"""
    print("Testing predict_image function...")
    print("=" * 50)
    
    # Check if predict_image is implemented
    try:
        # Create dummy inputs for testing
        test_image_path = "dummy_path.jpg"  # This will cause an error, which is expected for testing
        dummy_model = None
        dummy_mtcnn = None
        dummy_class_names = ["class1", "class2"]
        dummy_device = torch.device('cpu')
        
        test_result = predict_image(test_image_path, dummy_model, dummy_mtcnn, dummy_class_names, dummy_device)
        
        # If we get here, function is implemented but might return None
        if test_result is None:
            print("❌ Function predict_image() returns None - not implemented yet!")
            print("Please implement the function according to the TODO instructions.")
            return
        
        # Check if it returns a tuple with 2 elements
        if not isinstance(test_result, tuple) or len(test_result) != 2:
            print("❌ Function should return a tuple with 2 elements: (predicted_class, confidence)")
            print("Please check your implementation.")
            return
            
    except (NotImplementedError, TypeError, AttributeError):
        print("❌ Function predict_image() is not implemented yet!")
        print("Please implement the function according to the TODO instructions.")
        return
    except Exception as e:
        # This is expected since we're using dummy inputs
        if "dummy_path.jpg" in str(e) or "NoneType" in str(e):
            print("✅ Function predict_image() is implemented! (Error handling works correctly)")
            print("Function correctly handles invalid inputs and returns error messages.")
        else:
            print(f"❌ Unexpected error in predict_image() implementation: {str(e)}")
            print("Please check your implementation and try again.")
            return
    
    print("\n🎉 predict_image() function structure is correct!")
    print("💡 To fully test this function, you need:")
    print("   1. A trained model loaded")
    print("   2. MTCNN initialized") 
    print("   3. Valid image files")
    print("   4. Class names list")
    print("\nOnce training is complete, this function will be tested automatically.")

# Test the predict_image function
test_predict_image()

Testing predict_image function...
❌ Function predict_image() returns None - not implemented yet!
Please implement the function according to the TODO instructions.


In [18]:
def test_sample_images():
    """Test model on sample images from both classes"""
    # Check if predict_image function is implemented first
    try:
        # Test with dummy inputs to see if function is implemented
        dummy_result = predict_image("dummy.jpg", None, None, ["test"], torch.device('cpu'))
        
        # If we get here and result is None, function is not implemented
        if dummy_result is None:
            print("❌ Function predict_image() is not implemented yet!")
            print("Please implement the predict_image() function according to the TODO 3 instructions.")
            return []
            
    except (NotImplementedError, TypeError, AttributeError):
        print("❌ Function predict_image() is not implemented yet!")
        print("Please implement the predict_image() function according to the TODO 3 instructions.")
        return []
    except Exception as e:
        # This is expected for dummy inputs - function is implemented
        pass
    
    # Check if model and class names are available
    if model_inference is None:
        print("Error: Model not loaded. Cannot run inference.")
        return []
    
    if not class_names:
        print("Error: No class names available.")
        return []
    
    print("✅ predict_image() function is implemented! Running tests...")
    print()
    
    test_results = []
    
    for class_name in class_names:
        class_dir = os.path.join(data_dir, class_name)
        if os.path.exists(class_dir):
            images = [f for f in os.listdir(class_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))]
            
            if not images:
                print(f"No images found in {class_dir}")
                continue
            
            # Test first few images from each class
            for img_file in images[:3]:
                img_path = os.path.join(class_dir, img_file)
                predicted_class, confidence = predict_image(
                    img_path, model_inference, mtcnn_inference, class_names, device
                )
                
                result = {
                    'true_class': class_name,
                    'predicted_class': predicted_class,
                    'confidence': confidence,
                    'correct': predicted_class == class_name,
                    'image_path': img_path
                }
                test_results.append(result)
                
                status = "✅" if result['correct'] else "❌"
                print(f"{status} {class_name} -> {predicted_class} (conf: {confidence:.3f})")
        else:
            print(f"Directory not found: {class_dir}")
    
    # Summary
    if test_results:
        correct_predictions = sum(1 for r in test_results if r['correct'])
        accuracy = correct_predictions / len(test_results)
        print(f"\nTest Results: {correct_predictions}/{len(test_results)} correct ({accuracy*100:.1f}%)")
    else:
        print("No test results available.")
    
    return test_results

# Run tests with comprehensive error handling
print("Testing model on sample images...")
print("=" * 50)

# Check all prerequisites
if model_inference is not None and class_names:
    test_results = test_sample_images()
else:
    print("Cannot run tests: Model or class names not available.")
    print("Please ensure training has completed successfully.")
    print()
    
    # Also check if predict_image is implemented
    try:
        dummy_result = predict_image("dummy.jpg", None, None, ["test"], torch.device('cpu'))
        if dummy_result is None:
            print("Additionally: predict_image() function is not implemented.")
    except (NotImplementedError, TypeError, AttributeError):
        print("Additionally: predict_image() function is not implemented.")
    except Exception:
        print("predict_image() function appears to be implemented.")

Testing model on sample images...
❌ Function predict_image() is not implemented yet!
Please implement the predict_image() function according to the TODO 3 instructions.
