In [26]:
import torch
import torch.nn as nn
import torchvision.models as models
import numpy as np
import cv2
import glob
import os
import random
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

# Target (height, width) for every loaded image; the loaders pass
# IMAGE_SIZE[::-1] because cv2.resize expects (width, height).
# NOTE(review): the checkpoint directory name below contains "512" — confirm
# the model was trained at 600x600 rather than 512x512.
IMAGE_SIZE = (600, 600)
# Fine-tuned checkpoint to evaluate (relative Windows path).
MODEL_PATH = ".\\all_images_multimodel_512_20250915_184924\\efficientnet_b5_fold5_val80.4_test78.4_mult6303.4.pth"
BATCH_SIZE = 32
# Width of the classifier head's output layer; must match the checkpoint or
# load_state_dict will fail.
NUM_CLASSES = 5
# Fold whose train/test split is recreated (main() seeds the split with FOLD * 42).
FOLD = 5

def detect_and_convert_image(image):
    """Normalise the channel layout of an OpenCV image array.

    * 2-D arrays and ``(H, W, 1)`` arrays are expanded to 3-channel RGB.
    * ``(H, W, 4)`` arrays are treated as OpenCV BGRA and converted to RGB
      (alpha dropped).
    * Everything else — including ``(H, W, 3)`` arrays — is returned as-is,
      with no BGR/RGB swap.
    """
    rank = len(image.shape)
    is_single_channel = rank == 2 or (rank == 3 and image.shape[2] == 1)
    if is_single_channel:
        return cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    if rank == 3 and image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_BGRA2RGB)
    return image

def _parse_age_from_filename(img_path):
    """Return the age encoded in *img_path*'s file name, or ``None``.

    The file stem is split on ``'_'``. At least five fields are required, and
    the fourth field (index 3) holds the age with ``'p'`` standing for the
    decimal point, e.g. ``'4p5'`` -> ``4.5``. Fields containing ``'xpx'``,
    lacking a ``'p'``, or not parseable as a float yield ``None``.
    """
    stem = os.path.splitext(os.path.basename(img_path))[0]
    fields = stem.split('_')
    if len(fields) < 5:
        return None
    age_field = fields[3]
    if 'xpx' in age_field.lower() or 'p' not in age_field:
        return None
    try:
        return float(age_field.replace('p', '.'))
    except ValueError:
        return None


def _load_image_folder(pattern, source, read_unchanged):
    """Load and resize every labelled image matching the glob *pattern*.

    Args:
        pattern: glob pattern selecting the image files.
        source: tag used in skip messages (``'color'`` or ``'grayscale'``).
        read_unchanged: if False, decode with ``cv2.imread``'s default
            3-channel BGR mode and convert to RGB. If True, decode with
            ``cv2.IMREAD_UNCHANGED``. In both cases ``detect_and_convert_image``
            then normalises the channel layout.

    Returns:
        ``(images, ages)``: two lists of equal length, in ``glob`` order.
        Files with no usable age in their name, or that cannot be decoded,
        are skipped and counted.
    """
    images = []
    ages = []
    skipped = 0
    for img_path in glob.glob(pattern):
        # Check the label first so unlabelled files are never decoded.
        age_value = _parse_age_from_filename(img_path)
        if age_value is None:
            skipped += 1
            continue
        try:
            if read_unchanged:
                img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
            else:
                img = cv2.imread(img_path)
            if img is None:
                print(f"  Skipping unreadable file: {img_path}")
                skipped += 1
                continue
            if not read_unchanged:
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = detect_and_convert_image(img)
            # cv2.resize takes (width, height); IMAGE_SIZE is (height, width).
            img_resized = cv2.resize(img, IMAGE_SIZE[::-1])
        except Exception as err:  # best-effort per file: report and move on
            print(f"  Skipping {img_path}: {err}")
            skipped += 1
            continue
        images.append(img_resized)
        ages.append(age_value)
    if skipped:
        print(f"Skipped {skipped} {source} files (no usable age in name, or unreadable)")
    return images, ages


def load_combined_data(color_path=None, gray_path=None):
    """Load the colour and grayscale trail-cam images together with their age labels.

    Args:
        color_path: glob pattern for the colour PNGs. Defaults to the
            project's hard-coded ``squared/color`` folder.
        gray_path: glob pattern for the grayscale PNGs. Defaults to the
            project's hard-coded ``squared/grayscale`` folder.

    Returns:
        ``(images, ages, sources)``:
        * ``images`` — ``np.array`` of resized images, colour first and then
          grayscale, each group in ``glob`` order.
        * ``ages`` — ``np.array`` of the matching float ages.
        * ``sources`` — list of ``'color'`` / ``'grayscale'`` tags aligned
          with the images.
    """
    if color_path is None:
        color_path = "D:\\Dropbox\\AI Projects\\buck\\trail cam\\images\\squared\\color\\*.png"
    if gray_path is None:
        gray_path = "D:\\Dropbox\\AI Projects\\buck\\trail cam\\images\\squared\\grayscale\\*.png"

    print("Loading color images...")
    color_images, color_ages = _load_image_folder(color_path, 'color', read_unchanged=False)
    print(f"Loaded {len(color_images)} color images")

    print("Loading grayscale images...")
    gray_images, gray_ages = _load_image_folder(gray_path, 'grayscale', read_unchanged=True)
    print(f"Loaded {len(gray_images)} grayscale images")

    images = color_images + gray_images
    ages = color_ages + gray_ages
    sources = ['color'] * len(color_images) + ['grayscale'] * len(gray_images)
    print(f"Total images: {len(images)}")

    # NOTE(review): glob order depends on the filesystem. The seeded split in
    # main() only reproduces training if the file order matches the order used
    # there.
    return np.array(images), np.array(ages), sources

def enhanced_augment_image(image, strength='light'):
    """Apply a random subset of mild photometric/geometric perturbations.

    Each perturbation fires independently with a probability set by
    *strength*. ``'light'`` uses (0.3, 0.3, 0.2, 0.2, 0.1); any other value
    uses (0.5, 0.5, 0.4, 0.3, 0.2), in this order:

    1. rotation by U(-15, 15) degrees about the centre, with reflected borders;
    2. horizontal flip;
    3. brightness scaling by U(0.85, 1.15), clipped to [0, 255];
    4. gamma correction with gamma ~ U(0.85, 1.15), via a 256-entry LUT;
    5. additive Gaussian noise (sigma 5), clipped to [0, 255].

    Randomness comes from ``random`` (steps 1-5) and ``np.random`` (noise
    values only). The steps that produce uint8 output (3-5) cast to
    ``np.uint8``.
    """
    if strength == 'light':
        p_rotate, p_flip, p_bright, p_gamma, p_noise = 0.3, 0.3, 0.2, 0.2, 0.1
    else:
        p_rotate, p_flip, p_bright, p_gamma, p_noise = 0.5, 0.5, 0.4, 0.3, 0.2

    if random.random() < p_rotate:
        degrees = random.uniform(-15, 15)
        height, width = image.shape[:2]
        rotation = cv2.getRotationMatrix2D((width / 2, height / 2), degrees, 1.0)
        image = cv2.warpAffine(image, rotation, (width, height), borderMode=cv2.BORDER_REFLECT)

    if random.random() < p_flip:
        image = cv2.flip(image, 1)

    if random.random() < p_bright:
        scale = random.uniform(0.85, 1.15)
        image = np.clip(image * scale, 0, 255).astype(np.uint8)

    if random.random() < p_gamma:
        exponent = 1.0 / random.uniform(0.85, 1.15)
        lut = np.array([((level / 255.0) ** exponent) * 255 for level in np.arange(0, 256)]).astype("uint8")
        image = cv2.LUT(image, lut)

    if random.random() < p_noise:
        jitter = np.random.normal(0, 5, image.shape).astype(np.int16)
        image = np.clip(image.astype(np.int16) + jitter, 0, 255).astype(np.uint8)

    return image

class OptimizedDataset(Dataset):
    """Class-balanced view that serves exactly ``target_per_class`` items per label.

    Indices are grouped by label in ascending label order: items
    ``[k * target_per_class, (k + 1) * target_per_class)`` all belong to the
    k-th label. Within a label, slot ``s`` uses that label's image number
    ``s mod n`` (where n is the label's image count). The first ``n`` slots
    return the images un-augmented; every wrap-around slot is passed through
    ``enhanced_augment_image(image, aug_strength)``.

    Each item is ``(float32 CHW tensor, label)``. Pixels are scaled to [0, 1]
    and then normalised with the standard ImageNet per-channel mean/std.
    """

    def __init__(self, base_images, labels, aug_strength='light', target_per_class=200, training=False):
        self.base_images = base_images
        self.labels = np.array(labels)
        self.aug_strength = aug_strength
        self.training = training
        self.target_per_class = target_per_class

        present = np.unique(labels)
        self.class_to_indices = {label: np.where(self.labels == label)[0] for label in present}
        self.num_classes = len(present)
        self.class_list = sorted(present)
        self.length = self.num_classes * self.target_per_class

        # ImageNet channel statistics, shaped (3, 1, 1) to broadcast over CHW.
        self.mean = np.array([0.485, 0.456, 0.406], dtype=np.float32).reshape(3, 1, 1)
        self.std = np.array([0.229, 0.224, 0.225], dtype=np.float32).reshape(3, 1, 1)

        print(f"Dataset: {self.length} samples from {len(base_images)} base images")

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        group = idx // self.target_per_class
        slot = idx % self.target_per_class

        label = self.class_list[group]
        members = self.class_to_indices[label]
        member_count = len(members)

        picture = self.base_images[members[slot % member_count]].copy()
        # Slots past the real images of this label are synthetic: augment them.
        if slot >= member_count:
            picture = enhanced_augment_image(picture, self.aug_strength)

        picture = picture.astype(np.float32) / 255.0
        if picture.ndim == 3:
            picture = picture.transpose(2, 0, 1)  # HWC -> CHW

        # NOTE(review): the random flip fires only when NOT training, which
        # looks inverted. It is harmless for evaluate_model's flip-TTA average,
        # but confirm the intent before reusing this for training.
        if not self.training and random.random() < 0.5:
            picture = np.flip(picture, axis=2).copy()

        picture = (picture - self.mean) / self.std
        return torch.from_numpy(picture.astype(np.float32)), label

def load_model(model_path, num_classes, device):
    """Rebuild the EfficientNet-B5 classifier and load fine-tuned weights into it.

    Args:
        model_path: path to a ``torch.save``'d checkpoint. It is either a dict
            with a ``'model_state_dict'`` entry (plus optional
            ``'validation_accuracy'`` / ``'test_accuracy'`` metrics) or a bare
            state_dict.
        num_classes: output width of the final linear layer; must match the
            checkpoint.
        device: device to map the checkpoint to and move the model onto.

    Returns:
        The model on *device*, in eval mode.
    """
    # No pretrained download: the strict load_state_dict below overwrites every
    # parameter and buffer, so ImageNet weights would be discarded anyway.
    model = models.efficientnet_b5(weights=None)

    # Freeze the first three feature stages. This does not affect eval-mode
    # inference, but it keeps the same requires_grad flags for callers that
    # fine-tune further.
    for layer in list(model.features.children())[:3]:
        for param in layer.parameters():
            param.requires_grad = False

    if isinstance(model.classifier, nn.Sequential):
        original_features = model.classifier[-1].in_features
    else:
        original_features = model.classifier.in_features

    # Custom head; its layout must match the one the checkpoint was saved with.
    model.classifier = nn.Sequential(
        nn.Dropout(0.3),
        nn.Linear(original_features, 512),
        nn.ReLU(inplace=True),
        nn.Dropout(0.15),
        nn.Linear(512, 256),
        nn.ReLU(inplace=True),
        nn.Dropout(0.075),
        nn.Linear(256, num_classes)
    )

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint.get('model_state_dict', checkpoint)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    print(f"Model loaded from: {model_path}")
    val_acc = checkpoint.get('validation_accuracy')
    if val_acc is not None:
        print(f"Validation accuracy: {val_acc:.1f}%")
    test_acc = checkpoint.get('test_accuracy')
    if test_acc is not None:
        print(f"Test accuracy: {test_acc:.1f}%")

    return model

def evaluate_model(model, data_loader, device):
    """Return the top-1 accuracy (percent) of *model* over *data_loader*.

    Every batch is scored with horizontal-flip test-time augmentation: logits
    for the batch and for the batch flipped along dim 3 (the width axis for
    NCHW input) are averaged before the argmax. Mixed precision is used only
    when running on CUDA.

    Args:
        model: module mapping an image batch to per-class logits
            (argmax is taken over dim 1).
        data_loader: iterable of ``(images, labels)`` tensor pairs.
        device: ``torch.device`` or device string to run on.

    Raises:
        ValueError: if *data_loader* yields no samples.
    """
    import contextlib

    model.eval()
    device = torch.device(device)
    # The original hard-coded autocast('cuda'), which only warns and disables
    # itself on a CPU run. Enable AMP only when the device really is CUDA.
    use_amp = device.type == 'cuda'
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)

            amp_ctx = torch.amp.autocast('cuda') if use_amp else contextlib.nullcontext()
            with amp_ctx:
                outputs = (model(images) + model(torch.flip(images, [3]))) / 2

            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("data_loader yielded no samples; accuracy is undefined")
    return 100 * correct / total

def main():
    """Recreate the fold's held-out test split and report the checkpoint's accuracy on it.

    Steps:
    1. Load all labelled images.
    2. Map each distinct age to a class index in ascending age order.
    3. Split off 20% as a test set with ``random_state=FOLD * 42``,
       stratified when every class has at least two samples.
    4. Wrap the test images in an ``OptimizedDataset`` (200 items per class,
       light augmentation for oversampled slots).
    5. Evaluate the model from ``MODEL_PATH`` with flip test-time augmentation.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name()}")

    print("\nLoading data...")
    images, ages, _sources = load_combined_data()
    if len(images) == 0:
        print("\nNo labelled images were loaded; nothing to evaluate.")
        return

    unique_ages = sorted(set(ages))
    label_mapping = {age: i for i, age in enumerate(unique_ages)}
    y_indices = np.array([label_mapping[age] for age in ages])

    print(f"\nAge distribution: {dict(zip(*np.unique(ages, return_counts=True)))}")
    print(f"Classes: {unique_ages}")
    print(f"Label indices distribution: {dict(zip(*np.unique(y_indices, return_counts=True)))}")

    # The head only has NUM_CLASSES outputs, so any label index >= NUM_CLASSES
    # can never be predicted correctly. A different label mapping also changes
    # the stratified split, so the fold may not match the training run.
    if len(unique_ages) != NUM_CLASSES:
        print(f"\nWARNING: found {len(unique_ages)} distinct ages but the model has "
              f"{NUM_CLASSES} outputs.")
        print("The label mapping here does not match the model; the accuracy and "
              "the train/test split may not reproduce the training run.")
        # TODO(review): apply the same age grouping that was used at training time.

    # Stratified splitting needs at least two samples in every class.
    min_class_count = np.bincount(y_indices).min()
    print(f"Minimum class count: {min_class_count}")

    if min_class_count < 2:
        print("\nWARNING: Cannot stratify - some classes have only 1 sample")
        print("Using non-stratified split instead...")
        stratify = None
    else:
        print(f"\nReplicating fold {FOLD} train/test split...")
        stratify = y_indices
    _, X_test, _, y_test = train_test_split(
        images, y_indices, test_size=0.2, random_state=FOLD * 42, stratify=stratify
    )

    print(f"Test set: {len(X_test)} base images")

    test_dataset = OptimizedDataset(X_test, y_test, 'light', 200, False)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    print("\nLoading model...")
    model = load_model(MODEL_PATH, NUM_CLASSES, device)

    print("\nEvaluating model on held-out test set with augmentation...")
    accuracy = evaluate_model(model, test_loader, device)

    print(f"\nTest Accuracy (with augmentation): {accuracy:.2f}%")
    print("Expected from checkpoint: 78.4%")

# Run the evaluation when executed as a script or notebook cell, not when imported.
if __name__ == "__main__":
    main()


Using device: cuda
GPU: NVIDIA GeForce RTX 5090

Loading data...
Loading color images...
Loaded 368 color images
Loading grayscale images...
Loaded 112 grayscale images
Total images: 480

Age distribution: {np.float64(1.5): np.int64(67), np.float64(2.5): np.int64(85), np.float64(3.5): np.int64(115), np.float64(4.5): np.int64(90), np.float64(5.5): np.int64(91), np.float64(6.5): np.int64(22), np.float64(7.5): np.int64(1), np.float64(8.5): np.int64(6), np.float64(9.5): np.int64(1), np.float64(12.5): np.int64(2)}
Classes: [np.float64(1.5), np.float64(2.5), np.float64(3.5), np.float64(4.5), np.float64(5.5), np.float64(6.5), np.float64(7.5), np.float64(8.5), np.float64(9.5), np.float64(12.5)]
Label indices distribution: {np.int64(0): np.int64(67), np.int64(1): np.int64(85), np.int64(2): np.int64(115), np.int64(3): np.int64(90), np.int64(4): np.int64(91), np.int64(5): np.int64(22), np.int64(6): np.int64(1), np.int64(7): np.int64(6), np.int64(8): np.int64(1), np.int64(9): np.int64(2)}
Minimum 