In [None]:
import pandas as pd
import numpy as np
import os
import torch
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import tifffile as tiff
import cv2

# --- Import from your project's src files ---
import sys
sys.path.append('..')
from src import config
from src.dataset import HubmapWsiDataset # Using the openslide dataset
from src.utils import dice_score, rle_decode

# --- 1. LOAD MODEL AND DATA ---
# Build the U-Net on GPU when available, otherwise on CPU. encoder_weights=None
# because the weights come from the checkpoint loaded below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = smp.Unet(config.MODEL_ENCODER, encoder_weights=None, in_channels=3, classes=1).to(device)

# Path to your best saved model from one of the folds
MODEL_PATH = '../models/best_model_fold_0.pth'
if not os.path.exists(MODEL_PATH):
    raise FileNotFoundError(f"Saved model not found at {MODEL_PATH}.")

# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint saved on a CUDA machine still loads on a CPU-only host.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Inference mode: disables dropout and uses running batch-norm statistics.
model.eval()

# Create validation set
# NOTE(review): this is a fixed 80/20 split (random_state=42), while the
# checkpoint name refers to a fold; confirm this split matches the data the
# model was validated on, otherwise scores may include training samples.
df = pd.read_csv(config.METADATA_FILE)
_, val_df = train_test_split(df, test_size=0.2, random_state=42)
val_dataset = HubmapWsiDataset(val_df)
# shuffle=False keeps the loader order fixed; later code relies on this to map
# loader positions back to rows of val_df.
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)

# --- 2. CALCULATE SCORES & IDENTIFY WORST PERFORMERS ---
# Score every validation sample on its own, then keep the five samples with
# the lowest Dice score for visual inspection.
results = []
loader_batch_size = val_loader.batch_size
n_val_rows = len(val_df)
with torch.no_grad():
    for batch_no, (images, masks) in enumerate(tqdm(val_loader)):
        images = images.to(device)
        # Insert a channel axis at dim 1 and cast to float before scoring.
        masks = masks.to(device).unsqueeze(1).float()
        outputs = model(images)

        # A sample's position in the unshuffled loader is used as its row
        # position in val_df.
        # NOTE(review): assumes the dataset yields exactly one sample per
        # val_df row, in row order -- TODO confirm against HubmapWsiDataset.
        batch_start = batch_no * loader_batch_size
        for offset in range(images.size(0)):
            row_pos = batch_start + offset
            if row_pos >= n_val_rows:
                continue
            image_id = val_df.iloc[row_pos]['id']
            # NOTE(review): raw model outputs are passed to dice_score; verify
            # that it applies sigmoid/thresholding itself.
            sample_dice = dice_score(outputs[offset], masks[offset]).item()
            results.append({'id': image_id, 'dice_score': sample_dice})

results_df = pd.DataFrame(results)
# Ascending sort puts the lowest Dice scores first.
ranked_by_dice = results_df.sort_values(by='dice_score')
worst_performers = ranked_by_dice.head(5)

print("\n--- Visualizing 5 Worst Performers ---")
# Every panel sourced from disk is resized to this (width, height) for display.
panel_size = (config.PATCH_SIZE, config.PATCH_SIZE)
for _, row in worst_performers.iterrows():
    # iterrows() yields each row as a Series, so the id may come back as a
    # float next to dice_score; cast it back to int for lookups and the path.
    image_id = int(row['id'])
    dice = row['dice_score']

    # Metadata row for this image, looked up in the full (unsplit) dataframe.
    image_info = df[df['id'] == image_id].iloc[0]
    image_path = os.path.join(config.IMAGES_DIR, f"{image_id}.tiff")
    original_image = tiff.imread(image_path)
    true_mask = rle_decode(image_info['rle'], (image_info['img_height'], image_info['img_width']))

    # image_info.name is the row's index label in df; get_loc converts it to
    # the positional index that val_dataset is indexed by.
    image_tensor, _ = val_dataset[val_df.index.get_loc(image_info.name)]
    image_tensor = image_tensor.unsqueeze(0).to(device)
    # Inference only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        logits = model(image_tensor)
    # Sigmoid + 0.5 threshold turns the model output into a binary mask.
    pred_mask = (torch.sigmoid(logits) > 0.5).float().squeeze().cpu().numpy()

    # Plotting
    # NOTE(review): the first two panels show the whole TIFF / decoded mask
    # resized to PATCH_SIZE, while the prediction comes from the dataset
    # sample; confirm the dataset returns the same view of the image.
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f"ID: {image_id} | Dice Score: {dice:.4f}", fontsize=14)
    axes[0].imshow(cv2.resize(original_image, panel_size))
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    # Nearest-neighbour keeps the mask's label values intact instead of
    # blending them at object borders as the default bilinear resize can.
    axes[1].imshow(cv2.resize(true_mask, panel_size, interpolation=cv2.INTER_NEAREST), cmap='gray')
    axes[1].set_title("Ground Truth Mask")
    axes[1].axis('off')

    axes[2].imshow(pred_mask, cmap='gray')
    axes[2].set_title("Model Prediction")
    axes[2].axis('off')

    plt.show()