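# Teacher vs Ensemble Visualization
Compare the predictions of the Teacher model (PyTorch) with those of the Ensemble model (hard overlay: every pixel the specialist labels as cancer is written over the generalist's prediction).
Visualize: Original Image, Ground Truth Mask, Teacher Prediction, Ensemble Prediction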

In [None]:
import os
import sys
import warnings
warnings.filterwarnings('ignore')

import numpy as np
import torch
import cv2
import matplotlib.pyplot as plt
import albumentations as A
from albumentations.pytorch import ToTensorV2
from pathlib import Path

# Suppress TensorFlow logging
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['SM_FRAMEWORK'] = 'tf.keras'

import tensorflow as tf
import segmentation_models as sm
tf.get_logger().setLevel('ERROR')

# Configuration
# Every image and mask is resized to IMG_HEIGHT x IMG_WIDTH before inference.
IMG_HEIGHT, IMG_WIDTH = 512, 512
NUM_CLASSES = 3
CLASS_NAMES = ['Background', 'Cancer', 'Other Tissue']

# Colour used to draw each class index when rendering a label map.
PRIMARY_CLASS_COLORS = {
    0: (0, 0, 0),           # black -> background
    1: (245, 66, 66),       # red   -> cancer
    2: (66, 135, 245),      # blue  -> other tissue
}

# RGB values that may occur in ground-truth mask files, mapped to class indices.
# More than one shade can encode the same class.
CLASS_MAPPING = {
    (0, 0, 0): 0,           # background: black
    (245, 66, 66): 1,       # cancer: red
    (255, 0, 0): 1,         # cancer: pure red
    (66, 135, 245): 2,      # other tissue: blue
    (0, 110, 255): 2,       # other tissue: blue variant
}

# Class index -> rendering colour, for indices 0 .. NUM_CLASSES - 1.
INVERSE_CLASS_MAPPING = dict(
    (class_idx, PRIMARY_CLASS_COLORS[class_idx]) for class_idx in range(NUM_CLASSES)
)

# Model weight paths (blank placeholders; fill in before loading the models).
GENERALIST_PATH = ''
SPECIALIST_PATH = ''
TEACHER_PATH = ''

print("Imports and configuration loaded")

In [None]:
# Helper functions
def decode_mask_to_colors(mask):
    """Render a label map of class indices as an RGB uint8 image.

    Each pixel equal to a key of INVERSE_CLASS_MAPPING is painted with that
    key's colour; pixels holding any other value stay black (0, 0, 0).
    """
    height, width = mask.shape[0], mask.shape[1]
    rgb = np.zeros((height, width, 3), dtype=np.uint8)
    for class_idx in INVERSE_CLASS_MAPPING:
        rgb[mask == class_idx] = INVERSE_CLASS_MAPPING[class_idx]
    return rgb

def convert_mask_to_classes(mask_image):
    """Translate an RGB mask image into a 2-D uint8 array of class indices.

    A pixel whose three channels exactly equal a colour key of CLASS_MAPPING
    gets that key's class index; every other colour falls through to 0.
    """
    labels = np.zeros((mask_image.shape[0], mask_image.shape[1]), dtype=np.uint8)
    for rgb, class_idx in CLASS_MAPPING.items():
        hits = (mask_image == rgb).all(axis=-1)
        labels[hits] = class_idx
    return labels

def load_mask(mask_path):
    """Load a colour-coded mask file and convert it to class indices.

    The mask is read as RGB and resized to (IMG_HEIGHT, IMG_WIDTH) with
    nearest-neighbour interpolation, so no blended colours are created. It is
    then mapped to indices with ``convert_mask_to_classes``.

    Args:
        mask_path: str or Path of the mask image.

    Returns:
        2-D uint8 array of class indices.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``mask_path``.
    """
    mask = cv2.imread(str(mask_path))
    if mask is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read mask image: {mask_path}")
    mask = cv2.cvtColor(mask, cv2.COLOR_BGR2RGB)
    mask = cv2.resize(mask, (IMG_WIDTH, IMG_HEIGHT), interpolation=cv2.INTER_NEAREST)
    return convert_mask_to_classes(mask)

def load_image(image_path):
    """Load an image as RGB, resize it to (IMG_HEIGHT, IMG_WIDTH) and scale it to [0, 1].

    Args:
        image_path: str or Path of the image.

    Returns:
        Float array of shape (IMG_HEIGHT, IMG_WIDTH, 3): the uint8 pixels divided by 255.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(str(image_path))
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (IMG_WIDTH, IMG_HEIGHT))
    return img / 255.0

def build_tf_model():
    """Create the TF U-Net (SE-ResNet50 encoder) shared by generalist and specialist.

    The encoder is not pre-initialised (``encoder_weights=None``); weights are
    loaded afterwards with ``load_weights``.
    """
    unet_kwargs = dict(
        input_shape=(IMG_HEIGHT, IMG_WIDTH, 3),
        classes=NUM_CLASSES,
        activation='softmax',
        encoder_weights=None,
    )
    return sm.Unet('seresnet50', **unet_kwargs)

def pad_to_divisible_by_32(array, divisor=32):
    """Zero-pad the bottom and right of an array so H and W are multiples of ``divisor``.

    Args:
        array: array shaped (H, W) or (H, W, ...). Only the first two axes are
            padded; any trailing (channel) axes are left untouched, whatever
            their size.
        divisor: required multiple for H and W. Defaults to 32, as the name says.

    Returns:
        A new array with the same dtype. The original content sits in the
        top-left corner and the padding is 0.

    Raises:
        ValueError: if ``divisor`` is not positive.
    """
    if divisor <= 0:
        raise ValueError(f"divisor must be positive, got {divisor}")
    h, w = array.shape[:2]
    # (-x) % d is the distance to the next multiple of d (0 when already aligned).
    pad_h = (-h) % divisor
    pad_w = (-w) % divisor
    pad_width = [(0, pad_h), (0, pad_w)] + [(0, 0)] * (array.ndim - 2)
    return np.pad(array, pad_width, mode='constant', constant_values=0)

def get_transforms():
    """Build the preprocessing pipeline for images fed to the teacher model.

    Normalises pixels with max_pixel_value=255 and the fixed per-channel
    mean/std below, then wraps the result as a torch tensor via ToTensorV2.
    """
    channel_mean = (0.485, 0.456, 0.406)
    channel_std = (0.229, 0.224, 0.225)
    steps = [
        A.Normalize(mean=channel_mean, std=channel_std, max_pixel_value=255, p=1.0),
        ToTensorV2(),
    ]
    return A.Compose(steps)

# Teacher output index (6 classes) -> this notebook's class index (see CLASS_NAMES).
# Same grouping as the probability merge in compare_predictions:
# {0, 5} -> 0, {2} -> 1, {1, 3, 4} -> 2.
TEACHER_CLASS_REMAP = {
    0: 0,
    1: 2,
    2: 1,
    3: 2,
    4: 2,
    5: 0,
}

print("Helper functions defined")

In [None]:
# Load all models
print("Loading models...")


def _load_tf_model(weights_path):
    """Build the shared TF U-Net architecture and load ``weights_path`` into it."""
    net = build_tf_model()
    net.load_weights(weights_path)
    return net


# TensorFlow models: generalist and specialist share one architecture.
print("Loading generalist model...")
model_generalist = _load_tf_model(GENERALIST_PATH)

print("Loading specialist model...")
model_specialist = _load_tf_model(SPECIALIST_PATH)

# Teacher model (PyTorch): 6-class U-Net returning raw logits (activation=None).
print("Loading teacher model...")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import segmentation_models_pytorch as smp

teacher_model = smp.Unet(
    encoder_name="efficientnet-b0",
    encoder_weights=None,
    classes=6,
    activation=None,
)
# NOTE(review): torch.load unpickles the checkpoint, so only load trusted files.
# Recent PyTorch versions can restrict this with weights_only=True.
teacher_model.load_state_dict(torch.load(TEACHER_PATH, map_location=device))
teacher_model.eval()
teacher_model.to(device)

print("All models loaded successfully")

In [None]:
from sklearn.metrics import f1_score, jaccard_score

def find_mask_for_image(image_path, mask_root=""):
    """
    Locate the ground-truth mask that belongs to ``image_path``.

    The mask is expected at ``<mask_root>/<case_id>/<image_name>``, where
    ``case_id`` is the name of the image's parent directory. This matches a
    layout such as clean/<case_id>/<image_name> -> mask/<case_id>/<image_name>.

    Args:
        image_path: str or Path of the image.
        mask_root: directory that contains one sub-folder per case. The default
            "" keeps the previous hard-coded behaviour (paths are resolved
            relative to the current working directory).

    Returns:
        The mask path as a string.

    Raises:
        FileNotFoundError: if no regular file exists at the expected location.
    """
    image_path = Path(image_path)
    case_id = image_path.parent.name
    mask_path = Path(mask_root) / case_id / image_path.name

    # is_file() rather than exists(): a directory with the same name is not a mask.
    if not mask_path.is_file():
        raise FileNotFoundError(f"Mask not found for image: {image_path}\nExpected mask at: {mask_path}")

    return str(mask_path)

def calculate_metrics(gt_mask, pred_mask):
    """Return ``(f1, iou)``: macro-averaged F1 and Jaccard of ``pred_mask`` vs ``gt_mask``.

    Both masks are flattened into 1-D label vectors before scoring. Any
    per-class zero division is scored as 0 (``zero_division=0``).
    """
    y_true = gt_mask.ravel()
    y_pred = pred_mask.ravel()
    scoring = dict(average='macro', zero_division=0)
    return f1_score(y_true, y_pred, **scoring), jaccard_score(y_true, y_pred, **scoring)

def _read_rgb_resized(image_path):
    """Read ``image_path`` as an RGB uint8 array resized to (IMG_HEIGHT, IMG_WIDTH).

    Raises:
        FileNotFoundError: if OpenCV cannot read the file (cv2.imread returns None).
    """
    bgr = cv2.imread(str(image_path))
    if bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    return cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (IMG_WIDTH, IMG_HEIGHT))


def _teacher_predict(image_uint8):
    """Return the teacher's 3-class label map for an RGB uint8 image.

    The teacher emits 6 classes. Their softmax probabilities are summed into the
    notebook's classes according to TEACHER_CLASS_REMAP (teacher idx -> idx)
    before the argmax is taken.
    """
    height, width = image_uint8.shape[:2]
    padded = pad_to_divisible_by_32(image_uint8)
    input_tensor = get_transforms()(image=padded)['image'].unsqueeze(0).to(device, dtype=torch.float32)

    with torch.no_grad():
        probs_6class = torch.softmax(teacher_model(input_tensor), dim=1)
        probs_3class = torch.zeros(
            (probs_6class.shape[0], NUM_CLASSES, *probs_6class.shape[2:]),
            device=probs_6class.device,
            dtype=probs_6class.dtype,
        )
        for teacher_idx, target_idx in TEACHER_CLASS_REMAP.items():
            probs_3class[:, target_idx] += probs_6class[:, teacher_idx]

    label_map = torch.argmax(probs_3class, dim=1).squeeze().cpu().numpy()
    # Crop away the bottom/right padding added before inference.
    return label_map[:height, :width]


def _ensemble_predict(image):
    """Return the hard-overlay ensemble label map for a [0, 1] float RGB image.

    Starts from the generalist's argmax and sets to 1 every pixel that the
    specialist labels as class 1 (cancer).
    """
    batch = np.expand_dims(image, axis=0)
    mask_gen = np.argmax(model_generalist.predict(batch, verbose=0)[0], axis=-1)
    mask_spec = np.argmax(model_specialist.predict(batch, verbose=0)[0], axis=-1)
    ensemble = mask_gen.copy()
    ensemble[mask_spec == 1] = 1
    return ensemble


def compare_predictions(image_path):
    """
    Evaluate one image with the teacher and the ensemble, and plot the results.

    Shows a 4-panel figure: original image, ground-truth mask, teacher
    prediction and ensemble prediction. The prediction panels are titled with
    their macro F1 / IoU against the ground truth, and the same metrics are
    printed.

    Args:
        image_path: str or Path of the image file. The mask is located
            automatically with ``find_mask_for_image``.

    Returns:
        (image, gt_mask, teacher_pred, ensemble_pred,
         teacher_f1, teacher_iou, ensemble_f1, ensemble_iou), where ``image``
        is the resized RGB image scaled to [0, 1].

    Raises:
        FileNotFoundError: if the mask cannot be located or either file cannot be read.
    """
    image_path = Path(image_path)

    mask_path = find_mask_for_image(image_path)

    # Read the image from disk once. The uint8 copy feeds the teacher; the
    # [0, 1] float copy (same values as load_image) feeds the TF models and the plot.
    image_uint8 = _read_rgb_resized(image_path)
    image = image_uint8 / 255.0
    gt_mask = load_mask(mask_path)

    teacher_pred = _teacher_predict(image_uint8)
    ensemble_pred = _ensemble_predict(image)

    teacher_f1, teacher_iou = calculate_metrics(gt_mask, teacher_pred.astype(np.uint8))
    ensemble_f1, ensemble_iou = calculate_metrics(gt_mask, ensemble_pred.astype(np.uint8))

    panels = [
        (image, "Original Image"),
        (decode_mask_to_colors(gt_mask), "Ground Truth Mask"),
        (decode_mask_to_colors(teacher_pred.astype(np.uint8)),
         f"Teacher \nF1: {teacher_f1:.3f} | IoU: {teacher_iou:.3f}"),
        (decode_mask_to_colors(ensemble_pred.astype(np.uint8)),
         f"Ensemble \nF1: {ensemble_f1:.3f} | IoU: {ensemble_iou:.3f}"),
    ]
    _, axes = plt.subplots(1, len(panels), figsize=(18, 5))
    for ax, (panel_image, title) in zip(axes, panels):
        ax.imshow(panel_image)
        ax.set_title(title, fontsize=12)
        ax.axis('off')

    plt.tight_layout()
    plt.show()

    # Print file info and metrics
    print(f"\nImage: {image_path.name}")
    print(f"   Path: {image_path}")
    print(f"   Mask: {Path(mask_path).name}")
    print("\nMetrics:")
    print(f"   Teacher - F1: {teacher_f1:.4f} | IoU: {teacher_iou:.4f}")
    print(f"   Ensemble - F1: {ensemble_f1:.4f} | IoU: {ensemble_iou:.4f}")

    return image, gt_mask, teacher_pred, ensemble_pred, teacher_f1, teacher_iou, ensemble_f1, ensemble_iou


print("Comparison function refactored with auto mask detection and metrics")

In [None]:
# Replace the empty string with the path of an image to inspect; the matching
# mask is located automatically by find_mask_for_image().
compare_predictions(
    ""
)

In [None]:
import glob
from pathlib import Path
import pandas as pd

def evaluate_all_images(test_image_dir="", pattern="*/*.png"):
    """
    Evaluate all test images and return results sorted by ensemble F1 score (worst first).

    Args:
        test_image_dir: root directory with one sub-folder per case. The default
            "" keeps the previous hard-coded behaviour (current working directory).
        pattern: glob pattern, relative to ``test_image_dir``, selecting the images.

    Returns:
        pandas DataFrame with columns image_path, mask_path, case_id, image_name,
        teacher_f1, teacher_iou, ensemble_f1, ensemble_iou, sorted by ensemble_f1
        ascending. Images that raise during processing are reported and skipped.
        If nothing could be evaluated, the frame is empty but keeps these columns.
    """
    result_columns = [
        'image_path', 'mask_path', 'case_id', 'image_name',
        'teacher_f1', 'teacher_iou', 'ensemble_f1', 'ensemble_iou',
    ]
    # Glob order depends on the filesystem; sorting makes runs reproducible.
    all_images = sorted(Path(test_image_dir).glob(pattern))
    # The preprocessing pipeline is identical for every image, so build it once.
    transforms = get_transforms()

    results = []

    print(f"Evaluating {len(all_images)} test images...")

    for i, image_path in enumerate(all_images):
        if i % 20 == 0:
            print(f"  Progress: {i}/{len(all_images)} images processed")

        try:
            image_path_str = str(image_path)

            mask_path = find_mask_for_image(image_path_str)

            # Read the image once. The uint8 copy feeds the teacher; the [0, 1]
            # float copy (same values as load_image) feeds the TF models.
            image_bgr = cv2.imread(image_path_str)
            if image_bgr is None:
                raise FileNotFoundError(f"Could not read image: {image_path_str}")
            image_resized = cv2.resize(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB), (IMG_WIDTH, IMG_HEIGHT))
            image = image_resized / 255.0
            gt_mask = load_mask(mask_path)

            # Teacher: the 6-class softmax is merged into 3 classes via TEACHER_CLASS_REMAP.
            image_padded = pad_to_divisible_by_32(image_resized)
            input_tensor = transforms(image=image_padded)['image'].unsqueeze(0).to(device, dtype=torch.float32)

            with torch.no_grad():
                probs_6class = torch.softmax(teacher_model(input_tensor), dim=1)

            probs_3class = torch.zeros((1, NUM_CLASSES, probs_6class.shape[2], probs_6class.shape[3]), device=device)
            for teacher_idx, target_idx in TEACHER_CLASS_REMAP.items():
                probs_3class[:, target_idx] += probs_6class[:, teacher_idx]

            teacher_pred = torch.argmax(probs_3class, dim=1).squeeze().cpu().numpy()
            teacher_pred = teacher_pred[:IMG_HEIGHT, :IMG_WIDTH]  # drop padding

            # Ensemble: generalist mask with the specialist's cancer pixels overlaid.
            image_array = np.expand_dims(image, axis=0)
            mask_gen = np.argmax(model_generalist.predict(image_array, verbose=0)[0], axis=-1)
            mask_spec = np.argmax(model_specialist.predict(image_array, verbose=0)[0], axis=-1)
            ensemble_pred = mask_gen.copy()
            ensemble_pred[mask_spec == 1] = 1

            teacher_f1, teacher_iou = calculate_metrics(gt_mask, teacher_pred.astype(np.uint8))
            ensemble_f1, ensemble_iou = calculate_metrics(gt_mask, ensemble_pred.astype(np.uint8))

            results.append({
                'image_path': image_path_str,
                'mask_path': mask_path,
                'case_id': image_path.parent.name,
                'image_name': image_path.name,
                'teacher_f1': teacher_f1,
                'teacher_iou': teacher_iou,
                'ensemble_f1': ensemble_f1,
                'ensemble_iou': ensemble_iou
            })

        except Exception as e:
            # Deliberate best-effort: one bad image must not abort the whole run.
            print(f"  Error processing {image_path.name}: {e}")

    print(f"\nEvaluation complete: {len(results)} images successfully processed")

    # Explicit columns keep sort_values working when `results` is empty; a
    # column-less DataFrame would raise KeyError on 'ensemble_f1'.
    df = pd.DataFrame(results, columns=result_columns)
    # A stable sort makes ties keep the (sorted) path order.
    df_sorted = df.sort_values('ensemble_f1', ascending=True, kind='mergesort')

    return df_sorted

def showcase_worst_predictions(n=10):
    """
    Find and visualize the N worst predictions based on ensemble F1 score.

    Runs ``evaluate_all_images()`` and prints the ``n`` rows with the lowest
    ensemble F1, plus summary statistics over all evaluated images. It then
    calls ``compare_predictions()`` on each of those rows.

    Args:
        n: Number of worst predictions to show (default 10).

    Returns:
        (results_df, worst_predictions): all evaluated images sorted by ensemble
        F1 ascending, and the first ``n`` rows of it.
    """
    print("Finding worst predictions...")

    results_df = evaluate_all_images()

    worst_predictions = results_df.head(n)

    if results_df.empty:
        print("\nNo images were evaluated successfully; nothing to show.")
        return results_df, worst_predictions

    # Fewer than n rows exist when fewer images were evaluated.
    shown = len(worst_predictions)
    worst_rows = [row for _, row in worst_predictions.iterrows()]

    print(f"\n{shown} Worst Ensemble Predictions (by F1 score):")
    print("-" * 80)

    # Ranks count down, as before: the worst row is numbered `shown` and the
    # least-bad row of the selection is numbered 1.
    for position, row in enumerate(worst_rows):
        print(f"{shown - position}. {row['image_name']}")
        print(f"   Case: {row['case_id']}")
        print(f"   Ensemble F1: {row['ensemble_f1']:.4f} | IoU: {row['ensemble_iou']:.4f}")
        print(f"   Teacher  F1: {row['teacher_f1']:.4f} | IoU: {row['teacher_iou']:.4f}")
        print()

    print("\nOverall Statistics:")
    print(f"   Best Ensemble F1:  {results_df['ensemble_f1'].max():.4f}")
    print(f"   Worst Ensemble F1: {results_df['ensemble_f1'].min():.4f}")
    print(f"   Mean Ensemble F1:  {results_df['ensemble_f1'].mean():.4f}")
    print(f"   Std Ensemble F1:   {results_df['ensemble_f1'].std():.4f}")

    print(f"\nVisualizing {shown} worst predictions...")

    for position, row in enumerate(worst_rows):
        print(f"\n{'='*60}")
        print(f"Rank {shown - position}: {row['image_name']}")
        print(f"Ensemble F1: {row['ensemble_f1']:.4f}")
        print(f"{'='*60}")

        compare_predictions(row['image_path'])

    return results_df, worst_predictions


print("Worst predictions analysis functions defined")

In [None]:
# Evaluate every test image, then list and plot the 10 lowest-ensemble-F1 cases.
# NOTE(review): the rows are stored in `worst_5` although n=10 is requested;
# consider renaming (check later cells that may use this name first).
results_df, worst_5 = showcase_worst_predictions(n=10)