In [None]:
import os
import numpy as np
import cv2
from PIL import Image
import tensorflow as tf
import torch
import torch.nn as nn
import pandas as pd
from collections import defaultdict
from tqdm import tqdm
from torchvision.models.resnet import ResNet, BasicBlock
import random
from tensorflow.keras.models import load_model
try:
    import timm
except ImportError:
    timm = None
from torchvision import transforms

# ======== CONFIGURABLE SETTINGS ========
# Dataset root; class names are the sorted subfolder names of DATASET_DIR/train.
DATASET_DIR = "dataset"
# Walked recursively for evaluation images; an image's parent folder name is
# its label, and folders whose name is not a known class are ignored.
TEST_DIR = "unprocessed_dataset"
IMAGES_PER_CLASS = 100 # Max images sampled per class; classes with fewer images use all they have
OUTPUT_SUBFOLDER = "confusionMatrix" # Define the subfolder for saving CSVs
# Define Model Weights for different Ensembles
# Each ensemble maps a model name to the weight its top-1 prediction adds to
# that class's score; the class with the highest total score is the ensemble's
# prediction. A weight of 0.0 is used to leave a model out of an ensemble.
ENSEMBLE_WEIGHTS = {
    'EnsembleA': {'VGG16': 0.4, 'ResNet': 0.4, 'ViT': 0.2}, # tie break A
    'EnsembleB': {'VGG16': 0.34, 'ResNet': 0.34, 'ViT': 0.32}, # tie break B
    'EnsembleC': {'VGG16': 0.5, 'ResNet': 0.5, 'ViT': 0.0}, # no ViT
    'EnsembleD': {'VGG16': 0.0, 'ResNet': 0.5, 'ViT': 0.5}, # no VGG16
    'EnsembleE': {'VGG16': 0.5, 'ResNet': 0.0, 'ViT': 0.5} # no ResNet
}
# =======================================

# File extensions accepted as images when scanning TEST_DIR (file names are
# lower-cased before the comparison, so the match is case-insensitive).
valid_exts = ('.png', '.jpg', '.jpeg', '.bmp', '.gif')

# Preprocessing function for OCR-like binarization
def ocr_preprocessing(image, block_size=31, C=20):
    """Binarize an RGB image with a Gaussian adaptive threshold.

    Args:
        image: Array whose last axis holds at least three channels (only the
            first three are used). The only caller in this file passes values
            scaled to [0, 1]; other ranges are not handled here.
        block_size: Neighbourhood size forwarded to ``cv2.adaptiveThreshold``.
        C: Constant forwarded to ``cv2.adaptiveThreshold``.

    Returns:
        Float array with three identical channels stacked on a new last axis;
        every value is 0.0 or 1.0.
    """
    luma_weights = [0.2989, 0.5870, 0.1140]
    luminance = np.dot(image[..., :3], luma_weights)
    luminance_u8 = (luminance * 255).astype(np.uint8)
    mask = cv2.adaptiveThreshold(
        luminance_u8,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        block_size,
        C,
    )
    plane = mask / 255.0
    # Replicate the single binary plane so downstream RGB models accept it.
    return np.stack([plane, plane, plane], axis=-1)

# Binarize image function using TensorFlow/Keras
def binarize_image(image_path, target_size=(56,56)):
    """Load an image from disk, binarize it and return it as a PIL image.

    Args:
        image_path: Path of the image file to load.
        target_size: Size passed to the Keras image loader.

    Returns:
        ``PIL.Image.Image`` built from the uint8 version of the array that
        ``ocr_preprocessing`` produces (three identical channels).
    """
    keras_image = tf.keras.preprocessing.image
    loaded = keras_image.load_img(image_path, target_size=target_size)
    scaled = keras_image.img_to_array(loaded) / 255.0
    binary_rgb = ocr_preprocessing(scaled)
    pixels = (binary_rgb * 255).astype(np.uint8)
    return Image.fromarray(pixels)

# Function to get sorted class names from directory structure
def get_class_names(train_dir):
    """Return the sorted names of the immediate subdirectories of *train_dir*.

    Prints a warning and returns an empty list when *train_dir* is not an
    existing directory. Plain files inside *train_dir* are ignored.
    """
    if os.path.isdir(train_dir):
        def is_subfolder(entry):
            return os.path.isdir(os.path.join(train_dir, entry))

        return sorted(filter(is_subfolder, os.listdir(train_dir)))

    print(f"Warning: Training directory '{train_dir}' not found for deriving class names.")
    return []

# Predict top N classes using a Keras model
def predict_top_n_keras(model, image, class_names, top_n=3):
    """Return the ``top_n`` most probable classes predicted by a Keras model.

    Args:
        model: Object exposing ``predict(batch, verbose=0)``; the first row of
            its result is taken as the per-class scores.
        image: Image (anything ``np.array`` accepts) with 0-255 pixel values;
            it is scaled to [0, 1] and given a leading batch axis.
        class_names: Sequence mapping class index to class name.
        top_n: Maximum number of predictions to return.

    Returns:
        List of ``(label, score)`` tuples ordered from highest to lowest
        score. ``label`` is ``"Index_<i>"`` when index ``i`` has no entry in
        ``class_names``. The list is empty when ``top_n`` is not positive or
        when prediction fails (the error is printed, not raised).
    """
    # Basic check if model expects classes and class_names list is available.
    # `or 0` keeps the comparison valid when the last output dimension is None.
    if not class_names and hasattr(model, 'output_shape') and (model.output_shape[-1] or 0) > 0:
        print("Warning: Keras model prediction called without class names derived.")

    # A slice of [-0:] would select every class, so handle this case up front.
    if top_n <= 0:
        return []

    try:
        batch = np.expand_dims(np.array(image) / 255.0, axis=0)
        preds = model.predict(batch, verbose=0)[0]
        top_indices = np.argsort(preds)[-top_n:][::-1]
        results = []
        for raw_idx in top_indices:
            idx = int(raw_idx)
            # Attempt to map index to name, fallback to index string if needed
            if class_names and idx < len(class_names):
                label = class_names[idx]
            else:
                label = f"Index_{idx}"
            # Plain floats keep the result type consistent with the PyTorch helper.
            results.append((label, float(preds[idx])))
        return results
    except Exception as e:
        print(f"Error during Keras prediction: {e}")
        return []


# Predict top N classes using a PyTorch model
def predict_top_n_torch(model, transform, image, class_names, device, top_n=3):
    """Return the ``top_n`` most probable classes predicted by a PyTorch model.

    Args:
        model: Callable taking a batched tensor and returning either logits or
            an object with a ``logits`` attribute.
        transform: Callable turning ``image`` into an unbatched tensor.
        image: Input forwarded unchanged to ``transform``.
        class_names: Sequence mapping class index to class name.
        device: Device the input tensor is moved to before inference.
        top_n: Maximum number of predictions to return; it is capped at the
            number of classes the model outputs.

    Returns:
        List of ``(label, probability)`` tuples ordered from most to least
        probable. ``label`` is ``"Index_<i>"`` when index ``i`` has no entry in
        ``class_names``. The list is empty when ``top_n`` is not positive or
        when prediction fails (the error is printed, not raised).
    """
    # Basic check for class names
    if not class_names:
        print("Warning: PyTorch model prediction called without class names derived.")

    if top_n <= 0:
        return []

    try:
        input_tensor = transform(image).unsqueeze(0).to(device)
        with torch.no_grad():
            output = model(input_tensor)
            logits = output.logits if hasattr(output, "logits") else output
            probabilities = torch.nn.functional.softmax(logits, dim=1)[0]

        # torch.topk raises when k exceeds the number of classes; clamp so a
        # small model still yields its best predictions instead of nothing.
        k = min(top_n, probabilities.shape[-1])
        top_probs, top_indices = torch.topk(probabilities, k)

        top_n_preds = []
        for prob_t, idx_t in zip(top_probs, top_indices):
            cls_idx = idx_t.item()
            prob = prob_t.item()
            # Attempt to map index to name, fallback to index string if needed
            if class_names and cls_idx < len(class_names):
                top_n_preds.append((class_names[cls_idx], prob))
            else:
                top_n_preds.append((f"Index_{cls_idx}", prob))
        return top_n_preds
    except Exception as e:
        print(f"Error during PyTorch prediction: {e}")
        return []


# Define a ResNet18 model adapted for binary (1-channel) input
class BinaryResNet18(ResNet):
    """ResNet built from ``BasicBlock`` with layer counts [2, 2, 2, 2] and a one-channel stem.

    Args:
        num_classes: Number of output classes passed to the ``ResNet`` base
            constructor (defaults to 45).
    """

    def __init__(self, num_classes=45): # Default classes if not provided
        blocks_per_stage = [2, 2, 2, 2]
        super().__init__(BasicBlock, blocks_per_stage, num_classes=num_classes)
        # Swap the stem convolution for one that takes a single input channel.
        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

# ---- Evaluation pipeline: model loading, voting and confusion-matrix export ----
def _strip_model_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``model.`` removed from each key."""
    prefix = "model."
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


def _load_state_dict_with_fallback(model, state_dict, label):
    """Load *state_dict* strictly, retrying with ``strict=False``; return the model or None."""
    try:
        model.load_state_dict(state_dict, strict=True)
    except RuntimeError as e:
        print(f"Error loading {label} state_dict (likely mismatch): {e}. Trying strict=False.")
        try:
            # NOTE(review): a non-strict load can leave layers at their random
            # initial values; treat results from such a model with caution.
            model.load_state_dict(state_dict, strict=False)
            print(f"Loaded {label} state_dict with strict=False.")
        except RuntimeError as e2:
            print(f"Failed loading {label} state_dict even with strict=False: {e2}")
            return None
    return model


def _load_keras_model(model_path, num_classes):
    """Load the Keras model at *model_path*; return None when the file is missing or unreadable."""
    if not os.path.exists(model_path):
        print(f"Keras model file not found: {model_path}")
        return None

    model = None
    try:
        model = load_model(model_path)
        # Basic validation if model output matches num_classes
        if model.output_shape[-1] != num_classes:
            print(f"Warning: Keras model output units ({model.output_shape[-1]}) mismatch dataset classes ({num_classes}).")
        print(f"Keras model loaded from {model_path}")
    except Exception as e:
        print(f"Error loading Keras model from {model_path}: {e}")
    return model


def _load_resnet_model(model_path, num_classes, device):
    """Load the ResNet checkpoint (state dict or whole module); return None on failure."""
    if not os.path.exists(model_path):
        print(f"ResNet model file not found: {model_path}")
        return None

    try:
        model = BinaryResNet18(num_classes=num_classes).to(device)
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, dict):
            state_dict = checkpoint.get("model_state_dict", checkpoint)
            model = _load_state_dict_with_fallback(model, _strip_model_prefix(state_dict), "ResNet")
        else:
            # The file holds a whole module rather than a state dict.
            model = checkpoint.to(device)

        if model is not None:
            # Validate output layer if possible (depends on model structure)
            if hasattr(model, 'fc') and model.fc.out_features != num_classes:
                print(f"Warning: ResNet model FC layer units ({model.fc.out_features}) mismatch dataset classes ({num_classes}).")
            model.eval()
            print(f"ResNet model loaded from {model_path}")
        return model
    except Exception as e:
        print(f"Error loading or instantiating ResNet model from {model_path}: {e}")
        return None


def _load_vit_model(model_path, num_classes, device):
    """Load the ViT checkpoint (state dict via timm, or whole module); return None on failure."""
    if not os.path.exists(model_path):
        print(f"ViT model file not found: {model_path}")
        return None

    try:
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(model_path, map_location=device)
        if isinstance(checkpoint, dict):
            if timm is None:
                print("timm library not available, cannot load ViT state_dict automatically.")
                return None
            print("Loading ViT state_dict using timm...")
            model = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=num_classes)
            # Strip only a leading "model." (same rule as the ResNet loader);
            # removing the substring anywhere could corrupt nested key names.
            state_dict = _strip_model_prefix(checkpoint.get("model_state_dict", checkpoint))
            model = _load_state_dict_with_fallback(model, state_dict, "ViT")
            if model is not None:
                model = model.to(device)
        else:
            print("Loading saved ViT model object...")
            model = checkpoint.to(device) # Assumes saved model object is compatible

        if model is not None:
            # Validate output layer if possible (depends on timm model structure)
            if hasattr(model, 'head') and hasattr(model.head, 'out_features') and model.head.out_features != num_classes:
                print(f"Warning: ViT model head layer units ({model.head.out_features}) mismatch dataset classes ({num_classes}).")
            model.eval()
            print(f"ViT model loaded from {model_path}")
        return model
    except Exception as e:
        print(f"Error loading ViT model from {model_path}: {e}")
        return None


def _select_test_images(class_names):
    """Sample up to IMAGES_PER_CLASS images per class from TEST_DIR; return (paths, labels)."""
    if not os.path.isdir(TEST_DIR):
        raise FileNotFoundError(f"Test directory '{TEST_DIR}' not found.")

    known_classes = set(class_names)
    image_paths_by_class = defaultdict(list)
    for root, _, files in os.walk(TEST_DIR):
        class_name = os.path.basename(root)
        if class_name not in known_classes: # Only include images from known classes
            continue
        for file in files:
            if file.lower().endswith(valid_exts):
                image_paths_by_class[class_name].append(os.path.join(root, file))

    selected_paths = []
    selected_labels = []
    print(f"Selecting {IMAGES_PER_CLASS} image(s) per class for evaluation...")
    for cls in class_names:
        available_images = image_paths_by_class[cls]
        if len(available_images) < IMAGES_PER_CLASS:
            print(f"Warning: Class {cls} has insufficient test images ({len(available_images)}) - needs at least {IMAGES_PER_CLASS}. Using all available.")
            if not available_images:
                print(f"Warning: Class {cls} has no images in {TEST_DIR}. Skipping this class.")
                continue
        num_to_sample = min(len(available_images), IMAGES_PER_CLASS)
        # No seed is set here, so the sample differs between runs.
        selected_paths.extend(random.sample(available_images, num_to_sample))
        selected_labels.extend([cls] * num_to_sample)
    return selected_paths, selected_labels


def _weighted_vote(weights, model_top1_predictions, class_to_idx):
    """Return the class with the highest summed weight, or None when no model can vote."""
    scores = defaultdict(float)
    for model_name, top1_class in model_top1_predictions.items():
        weight = weights.get(model_name, 0.0)
        # Zero-weight models are excluded entirely: otherwise their class would
        # enter the score table, altering tie-breaks or winning with score 0.
        if weight <= 0 or top1_class is None or top1_class not in class_to_idx:
            continue
        scores[top1_class] += weight
    if not scores:
        return None
    # Ties resolve to the class that was scored first (model iteration order).
    return max(scores, key=scores.get)


def _save_confusion_matrices(confusion_matrices, class_names, loaded_models):
    """Write each non-empty matrix whose contributing model(s) loaded to OUTPUT_SUBFOLDER as CSV."""
    output_dir = OUTPUT_SUBFOLDER
    os.makedirs(output_dir, exist_ok=True)
    print(f"\nSaving confusion matrices to subfolder: {output_dir}")

    # Which models feed each matrix; a matrix is saved if at least one loaded.
    contributors = {
        'VGG16': ('VGG16',),
        'ResNet': ('ResNet',),
        'ViT': ('ViT',),
        'ConfidenceEnsemble': ('VGG16', 'ResNet', 'ViT'),
        'ConfidenceEnsemble_NoViT': ('VGG16', 'ResNet'),
    }
    for ensemble_name, weights in ENSEMBLE_WEIGHTS.items():
        contributors[ensemble_name] = tuple(m for m, w in weights.items() if w > 0)

    for name, matrix in confusion_matrices.items():
        should_save = any(loaded_models.get(m, False) for m in contributors.get(name, ()))
        if not should_save:
            print(f"Skipping saving confusion matrix for {name} (required model(s) not loaded or no predictions made).")
            continue
        # Only save if the matrix contains counts (i.e., predictions were made)
        if matrix.sum() <= 0:
            print(f"Skipping saving empty confusion matrix for {name} (no predictions recorded).")
            continue
        df = pd.DataFrame(matrix, index=class_names, columns=class_names)
        output_filename = f'confusion_{name}.csv'
        output_path = os.path.join(output_dir, output_filename)
        try:
            df.to_csv(output_path)
            print(f"Saved confusion matrix to {output_path}")
        except IOError as e:
            print(f"Error saving confusion matrix {output_filename} to CSV: {e}")


def evaluate_models_confusion_matrix():
    """Evaluate the available models and ensembles and export confusion matrices.

    Steps: derive class names from ``DATASET_DIR/train``, sample test images
    from ``TEST_DIR``, load the VGG16 (Keras), ResNet18 and ViT (PyTorch)
    checkpoints from the working directory, predict every sampled image, and
    write one ``confusion_<name>.csv`` per model/ensemble into
    ``OUTPUT_SUBFOLDER``. Rows are true classes, columns predicted classes.

    Ensembles:
        * ``ConfidenceEnsemble`` / ``ConfidenceEnsemble_NoViT`` sum the top-3
          probabilities of the models (all three / all but ViT).
        * Each entry of ``ENSEMBLE_WEIGHTS`` is a weighted vote over the
          models' top-1 predictions.

    Returns:
        None. Returns early (after printing a message) when no image is
        selected or no model could be loaded.

    Raises:
        ValueError: No class subdirectories exist under ``DATASET_DIR/train``.
        FileNotFoundError: ``TEST_DIR`` is not an existing directory.
    """
    # Derive class names and count
    train_dir = os.path.join(DATASET_DIR, "train")
    class_names = get_class_names(train_dir)
    num_classes = len(class_names)
    if num_classes == 0:
        raise ValueError(f"No class subdirectories found in '{train_dir}'. Cannot proceed without classes.")
    # Built once so per-image lookups avoid repeated linear list searches.
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}

    selected_paths, selected_labels = _select_test_images(class_names)
    if not selected_paths:
        print("No images selected for evaluation. Exiting.")
        return
    print(f"Total images selected for evaluation: {len(selected_paths)}")

    # Set up device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    keras_model = _load_keras_model("vgg16_model.h5", num_classes)
    resnet_model = _load_resnet_model("resnet18_model.pth", num_classes, device)
    vit_model = _load_vit_model("vit_model.pth", num_classes, device)

    # Check if at least one model loaded
    if keras_model is None and resnet_model is None and vit_model is None:
        print("No models were loaded successfully. Cannot generate confusion matrices.")
        return

    # Define image transformations
    resnet_transform = transforms.Compose([
        transforms.Grayscale(1),
        transforms.Resize((56,56)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5])
    ])

    # NOTE(review): the timm model name suggests 224x224 inputs, yet images are
    # fed at 56x56 - confirm this matches how the ViT checkpoint was trained.
    vit_transform = transforms.Compose([
        transforms.Resize((56,56)),
        transforms.ToTensor(),
        transforms.Normalize([0.5]*3, [0.5]*3)
    ])

    # One matrix per model and ensemble; ensemble names come from the config so
    # adding an entry to ENSEMBLE_WEIGHTS needs no further change here.
    matrix_names = ['VGG16', 'ResNet', 'ViT', 'ConfidenceEnsemble', 'ConfidenceEnsemble_NoViT'] + list(ENSEMBLE_WEIGHTS)
    confusion_matrices = {
        name: np.zeros((num_classes, num_classes), dtype=int) for name in matrix_names
    }

    # (matrix key, name used in messages, prediction callable) per loaded model.
    # The order fixes the insertion order that breaks ties in the ensembles.
    predictors = []
    if keras_model is not None:
        predictors.append(('VGG16', 'Keras', lambda img: predict_top_n_keras(
            keras_model, img, class_names, top_n=3)))
    if resnet_model is not None:
        predictors.append(('ResNet', 'ResNet', lambda img: predict_top_n_torch(
            resnet_model, resnet_transform, img.convert("L"), class_names, device, top_n=3)))
    if vit_model is not None:
        predictors.append(('ViT', 'ViT', lambda img: predict_top_n_torch(
            vit_model, vit_transform, img, class_names, device, top_n=3)))

    # Process each selected image
    print("Generating predictions and updating confusion matrices...")
    for path, true_label in tqdm(zip(selected_paths, selected_labels), total=len(selected_paths), desc="Evaluating Images"):
        try:
            true_idx = class_to_idx[true_label]
            img_binarized = binarize_image(path, target_size=(56, 56))

            model_top1_predictions = {}
            top_n_by_model = {}
            for matrix_key, display_name, predict in predictors:
                try:
                    preds = predict(img_binarized)
                except Exception as e:
                    print(f"\nError during {display_name} prediction for {path}: {e}")
                    preds = []
                top_n_by_model[matrix_key] = preds
                if not preds:
                    model_top1_predictions[matrix_key] = None
                    continue
                top1_pred_class = preds[0][0]
                model_top1_predictions[matrix_key] = top1_pred_class
                # Update individual model confusion matrix if prediction is valid class name
                pred_idx = class_to_idx.get(top1_pred_class)
                if pred_idx is None:
                    print(f"Warning: {display_name} predicted '{top1_pred_class}' not in class_names.")
                else:
                    confusion_matrices[matrix_key][true_idx, pred_idx] += 1

            # Confidence ensembles: sum top-3 probabilities per class
            confidence_sums = defaultdict(float)
            confidence_sums_no_vit = defaultdict(float)
            for matrix_key in ('VGG16', 'ResNet', 'ViT'):
                for cls, prob in top_n_by_model.get(matrix_key, []):
                    if cls not in class_to_idx:
                        continue
                    confidence_sums[cls] += prob
                    if matrix_key != 'ViT':
                        confidence_sums_no_vit[cls] += prob

            if confidence_sums:
                winner = max(confidence_sums, key=confidence_sums.get)
                confusion_matrices['ConfidenceEnsemble'][true_idx, class_to_idx[winner]] += 1

            if confidence_sums_no_vit:
                winner = max(confidence_sums_no_vit, key=confidence_sums_no_vit.get)
                confusion_matrices['ConfidenceEnsemble_NoViT'][true_idx, class_to_idx[winner]] += 1

            # Weighted ensembles: vote with each model's top-1 prediction
            for ensemble_name, weights in ENSEMBLE_WEIGHTS.items():
                winner = _weighted_vote(weights, model_top1_predictions, class_to_idx)
                if winner is not None:
                    confusion_matrices[ensemble_name][true_idx, class_to_idx[winner]] += 1

        except Exception as e:
            # Best effort: report the failing image and carry on with the rest.
            print(f"\nError processing image {path}: {e}")
            continue

    loaded_models = {
        'VGG16': keras_model is not None,
        'ResNet': resnet_model is not None,
        'ViT': vit_model is not None,
    }
    _save_confusion_matrices(confusion_matrices, class_names, loaded_models)


# --- Main Execution ---
if __name__ == "__main__":
    # Run the evaluation and turn any failure into a one-line message; the
    # checks run from most to least specific exception type.
    try:
        evaluate_models_confusion_matrix()
        print("\nScript finished.")
    except Exception as err:
        if isinstance(err, FileNotFoundError):
            print(f"\nError: {err}. Please ensure dataset directories and model files exist.")
        elif isinstance(err, ValueError):
            print(f"\nError: {err}. Please check dataset structure or configuration.")
        else:
            print(f"\nAn unexpected error occurred: {err}")

Selecting 100 image(s) per class for evaluation...
Total images selected for evaluation: 4500
Using device: cpu




Keras model loaded from vgg16_model.h5
ResNet model loaded from resnet18_model.pth
Loading saved ViT model object...
ViT model loaded from vit_model.pth
Generating predictions and updating confusion matrices...


Evaluating Images: 100%|██████████| 4500/4500 [22:57<00:00,  3.27it/s]


Saving confusion matrices to subfolder: confusionMatrix
Saved confusion matrix to confusionMatrix\confusion_VGG16.csv
Saved confusion matrix to confusionMatrix\confusion_ResNet.csv
Saved confusion matrix to confusionMatrix\confusion_ViT.csv
Saved confusion matrix to confusionMatrix\confusion_ConfidenceEnsemble.csv
Saved confusion matrix to confusionMatrix\confusion_ConfidenceEnsemble_NoViT.csv
Saved confusion matrix to confusionMatrix\confusion_EnsembleA.csv
Saved confusion matrix to confusionMatrix\confusion_EnsembleB.csv
Saved confusion matrix to confusionMatrix\confusion_EnsembleC.csv
Saved confusion matrix to confusionMatrix\confusion_EnsembleD.csv
Saved confusion matrix to confusionMatrix\confusion_EnsembleE.csv

Script finished.



