# In [None]:
import torch
from torch.utils.data import DataLoader
import pandas as pd
from tqdm import tqdm
from pathlib import Path
from unet_v1 import UNet, BrainSegmentationDataset, dice_coeff, multiclass_dice_coeff

# File paths
checkpoint_path = "/mnt/data/checkpoint_epoch1.pth"
csv_path = "/mnt/data/training_detailed_summary_2020.csv"

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load dataset.
# batch_size=1 with shuffle=False keeps each loader position equal to the
# dataset index, so the per-sample "Index" recorded later maps straight back
# to the dataset.
dataset = BrainSegmentationDataset(csv_path)
# Pinned host memory only speeds up host->GPU copies; on a CPU-only machine it
# is pointless (recent PyTorch versions emit a warning), so enable it for CUDA only.
loader_args = dict(batch_size=1, num_workers=0, pin_memory=device.type == "cuda")
data_loader = DataLoader(dataset, shuffle=False, **loader_args)

# Load model
model = UNet(n_channels=4, n_classes=4)  # Adjust based on your dataset
state_dict = torch.load(checkpoint_path, map_location=device)
# NOTE(review): checkpoints saved by the Pytorch-UNet style training script
# store a non-parameter "mask_values" entry alongside the weights, which makes
# a strict load_state_dict() fail with an unexpected-key error. Drop it when
# present; for a plain state dict this is a no-op. Confirm against the trainer.
if isinstance(state_dict, dict):
    state_dict.pop("mask_values", None)
model.load_state_dict(state_dict)
model.to(device)
model.eval()

# Validation loop: score every sample individually so the worst cases can be
# ranked afterwards.
#
# For multiclass models the Dice score is averaged over the foreground classes
# only (channels 1..n_classes-1), as the reference Pytorch-UNet evaluate() does.
# Channel 0 is treated as background; it covers most of each slice, so
# including it pushes every score towards 1 and hides the segmentation failures
# this script is meant to surface. Set to True to average over all classes.
include_background = False
results = []
with torch.no_grad():
    for idx, (images, true_masks) in enumerate(tqdm(data_loader, desc="Validation")):
        images = images.to(device, dtype=torch.float32)
        true_masks = true_masks.to(device, dtype=torch.long)

        # Prediction: presumably raw logits of shape (B, n_classes, H, W)
        pred_masks = model(images)

        # Calculate Dice score
        if model.n_classes == 1:
            # true_masks is (B, H, W) (see the one_hot/permute below), while the
            # single-channel prediction is (B, 1, H, W); squeeze the channel so
            # both tensors have the same shape before comparing them.
            pred_binary = (torch.sigmoid(pred_masks.squeeze(1)) > 0.5).float()
            dice_score = dice_coeff(pred_binary, true_masks.float()).item()
        else:
            # One-hot both label maps into (B, n_classes, H, W) indicator tensors.
            true_masks_onehot = torch.nn.functional.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float()
            pred_masks_onehot = torch.nn.functional.one_hot(pred_masks.argmax(dim=1), model.n_classes).permute(0, 3, 1, 2).float()
            first_class = 0 if include_background else 1
            dice_score = multiclass_dice_coeff(
                pred_masks_onehot[:, first_class:],
                true_masks_onehot[:, first_class:],
            ).item()

        # Record result; with batch_size=1 and an unshuffled loader, idx is the
        # dataset index of this sample.
        results.append({"Index": idx, "Dice Score": dice_score})

# Convert results to DataFrame. Naming the columns explicitly keeps the
# "Dice Score" column present even when no samples were scored; otherwise the
# sort below raises KeyError on an empty result list.
results_df = pd.DataFrame(results, columns=["Index", "Dice Score"])

# Sort ascending so the worst-performing images come first. A stable sort keeps
# tied scores in dataset order, so repeated runs produce identical files.
results_df = results_df.sort_values(by="Dice Score", ascending=True, kind="stable")

# Save results, creating the output directory if it does not exist yet.
output_path = Path("/mnt/data/validation_results.csv")
output_path.parent.mkdir(parents=True, exist_ok=True)
results_df.to_csv(output_path, index=False)
print(f"Results saved to {output_path}")
