# Test Age and Gender Models on Custom Images
This notebook loads your trained age and gender models from the `models` folder and predicts age and gender for every image (`.jpg`, `.jpeg`, `.png`) in the `test` folder.

In [None]:
# Imports and setup
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import os
import matplotlib.pyplot as plt

# Model and image config
IMG_SIZE = (128, 128)  # resize target for test images; also sizes the classifier head of BaseCNN
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_DIR = 'models'      # directory holding age_model.pth / gender_model.pth
TEST_IMG_DIR = 'test'     # directory scanned for images to predict on
IMG_EXTENSIONS = ('.jpg', '.png', '.jpeg')  # matched case-insensitively

# List test images.
# os.listdir returns entries in arbitrary, platform-dependent order, so sort
# the names to make the printed list and the plotted results reproducible.
# Directories whose names happen to end in an image extension are skipped.
test_images = sorted(
    f for f in os.listdir(TEST_IMG_DIR)
    if f.lower().endswith(IMG_EXTENSIONS)
    and os.path.isfile(os.path.join(TEST_IMG_DIR, f))
)
print('Test images:', test_images)

In [None]:
# Define model class (must match training definition)
class BaseCNN(nn.Module):
    """Small three-stage CNN shared by the age and gender models.

    Each stage is Conv2d(3x3, padding=1) -> ReLU -> MaxPool2d(2) -> BatchNorm2d,
    so the spatial size is floored-divided by 8 before the flattened features
    reach the two-layer classifier head.

    The attribute names (``features``, ``classifier``) and the layer order are
    exactly those of the training definition, so saved state_dicts load as-is.

    Args:
        output_units: Number of outputs of the final linear layer.
        task: If ``'gender'``, ``forward`` applies a sigmoid to the output;
            any other value (e.g. ``'age'``) returns the raw linear output.
        img_size: ``(height, width)`` of the input images. Defaults to the
            module-level ``IMG_SIZE`` (read when the model is constructed),
            which reproduces the original behavior.
    """

    def __init__(self, output_units=1, task='age', img_size=None):
        super().__init__()
        if img_size is None:
            img_size = IMG_SIZE
        height, width = img_size
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2), nn.BatchNorm2d(128),
            nn.Flatten()
        )
        self.classifier = nn.Sequential(
            # 128 channels remain after the last stage; three MaxPool2d(2)
            # layers floor each spatial dimension by 8 in total.
            nn.Linear(128 * (height // 8) * (width // 8), 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, output_units)
        )
        self.task = task

    def forward(self, x):
        """Run the network on a batch ``x`` shaped (N, 3, height, width).

        Returns an (N, output_units) tensor. For the ``'gender'`` task the
        values are passed through a sigmoid and therefore lie in [0, 1].
        """
        x = self.features(x)
        x = self.classifier(x)
        if self.task == 'gender':
            x = torch.sigmoid(x)
        return x

In [None]:
# Load models
def _load_weights(model, filename):
    """Load ``MODEL_DIR/filename`` into ``model``, move it to DEVICE, and return it in eval mode."""
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered .pth file cannot execute arbitrary code (torch >= 1.13).
    state_dict = torch.load(
        os.path.join(MODEL_DIR, filename),
        map_location=DEVICE,
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    # eval() disables Dropout and makes BatchNorm use its running statistics.
    return model.to(DEVICE).eval()


age_model = _load_weights(BaseCNN(output_units=1, task='age'), 'age_model.pth')
gender_model = _load_weights(BaseCNN(output_units=1, task='gender'), 'gender_model.pth')
gender_model

In [None]:
# Preprocessing applied to every test image before it reaches the models:
# resize to IMG_SIZE, then convert the PIL image to a tensor.
# NOTE(review): this should mirror the validation/test transforms used in
# training; no normalization is applied here. If training normalized its
# inputs, the same step must be added, or predictions will be off.
test_transform = transforms.Compose(
    [transforms.Resize(IMG_SIZE), transforms.ToTensor()]
)

# Maps the thresholded gender-model output (0 or 1) to a display label.
# NOTE(review): the index order must match the label encoding used in training; edit if needed.
gender_map = {
    0: 'Male',
    1: 'Female',
}

In [None]:
# Predict on all test images.
# Each image is loaded, preprocessed once, and run through both models. It is
# then shown with the predicted age and gender in the title.
for fname in test_images:
    img_path = os.path.join(TEST_IMG_DIR, fname)
    try:
        # The context manager closes the underlying file handle; convert('RGB')
        # returns a new, fully loaded image, so `img` stays usable afterwards.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert('RGB')
    except OSError as err:
        # Pillow raises UnidentifiedImageError (an OSError subclass) for files
        # it cannot decode; skip such files instead of aborting the whole run.
        print(f'Skipping {fname}: {err}')
        continue

    img_tensor = test_transform(img).unsqueeze(0).to(DEVICE)  # shape (1,3,H,W)

    # Inference only: a single no_grad context covers both forward passes.
    with torch.no_grad():
        age_pred = age_model(img_tensor).item()
        # gender_model was built with task='gender', so its output is already
        # sigmoid-squashed into [0, 1].
        gender_pred = gender_model(img_tensor).item()

    # Threshold at 0.5 to pick the class index (same result as round() on [0, 1]).
    gender_label = gender_map[int(gender_pred > 0.5)]

    # Show result
    plt.imshow(img)
    plt.axis('off')
    plt.title(f'Predicted Age: {age_pred:.1f}\nPredicted Gender: {gender_label}')
    plt.show()