In [12]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
import logging
import pandas as pd
import numpy as np
import nibabel as nib
import torchvision.transforms.functional as TF
import os
import matplotlib.pyplot as plt  # For visualization

from scipy.special import erf
from skimage.morphology import binary_erosion
from skimage.measure import label

# ===============================
# Logging Configuration
# ===============================
# Root-logger setup: records at INFO level or above are appended to
# 'visualization.log'; each line starts with HH:MM:SS and the millisecond part.
_LOG_SETTINGS = {
    'level': logging.INFO,
    'datefmt': '%H:%M:%S',
    'format': '%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
    'filemode': 'a',  # append, so earlier runs stay in the file
    'filename': 'visualization.log',
}
logging.basicConfig(**_LOG_SETTINGS)

# ===============================
# Model Definitions
# ===============================

class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> batch norm -> ReLU stages.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the second stage.
        mid_channels: Number of channels between the two stages. A falsy
            value (``None`` or ``0``) selects ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels or out_channels
        stages = []
        # Build the stages in order so the Sequential indices (0..5) and the
        # parameter initialisation order stay the same as a literal listing.
        for c_in, c_out in ((in_channels, hidden), (hidden, out_channels)):
            stages.extend([
                # padding=1 keeps the spatial size; the bias is omitted
                # because the following BatchNorm has its own shift term.
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ])
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both stages and return the result."""
        return self.double_conv(x)

class Down(nn.Module):
    """Encoder step: 2x2 max pooling followed by a ``DoubleConv``.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        """Pool ``x`` and then convolve it."""
        return self.maxpool_conv(x)

class Up(nn.Module):
    """Decoder step: upsample, merge with a skip connection, then ``DoubleConv``.

    Args:
        in_channels: Number of channels of the concatenated tensor
            (skip connection plus upsampled features) fed to the convolution.
        out_channels: Number of channels produced by the block.
        bilinear: If True, upsample with bilinear interpolation (channel
            count unchanged); otherwise use a 2x2, stride-2 transposed
            convolution that maps ``in_channels`` to ``in_channels // 2``.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, pad it to the size of ``x2``, concatenate and convolve.

        Args:
            x1: Feature map from the previous (coarser) decoder stage.
            x2: Skip-connection feature map whose height/width ``x1`` must match.

        Returns:
            The output of the ``DoubleConv`` applied to ``cat([x2, x1], dim=1)``.
        """
        x1 = self.up(x1)
        # The upsampled map can be smaller than the skip connection when an
        # encoder level had an odd height or width.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)

        # torch.nn.functional.pad takes (left, right, top, bottom) for the last
        # two dimensions. torchvision's functional pad reads a 4-element list as
        # (left, top, right, bottom), which pads the wrong sides whenever the
        # differences are not symmetric, so the torch version is used here.
        x1 = nn.functional.pad(
            x1,
            [diff_x // 2, diff_x - diff_x // 2,
             diff_y // 2, diff_y - diff_y // 2],
        )
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class OutConv(nn.Module):
    """1x1 convolution mapping ``in_channels`` feature maps to ``out_channels`` maps."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x):
        """Apply the 1x1 convolution to ``x``."""
        projected = self.conv(x)
        return projected

class UNet(nn.Module):
    """U-Net with four ``Down`` stages, four ``Up`` stages and a 1x1 output head.

    Args:
        n_channels: Number of channels of the input tensor.
        n_classes: Number of output channels (one logit map per class).
        bilinear: Forwarded to every ``Up`` block. When True the bottleneck
            and the first three decoder outputs use half as many channels
            (``factor == 2``).
    """

    def __init__(self, n_channels, n_classes, bilinear=False):
        super().__init__()
        self.n_channels, self.n_classes, self.bilinear = n_channels, n_classes, bilinear

        factor = 2 if bilinear else 1

        # Encoder
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024 // factor)

        # Decoder
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)

        # Output head
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return the logits computed for ``x``.

        The encoder outputs are kept on a stack; each decoder stage pops the
        next-finer one as its skip connection.
        """
        skips = [self.inc(x)]  # 64 channels at the input resolution
        for encoder in (self.down1, self.down2, self.down3, self.down4):
            # Channels: 128, 256, 512, 1024 // factor.
            skips.append(encoder(skips[-1]))

        out = skips.pop()  # bottleneck features
        for decoder in (self.up1, self.up2, self.up3, self.up4):
            # Channels: 512 // factor, 256 // factor, 128 // factor, 64.
            out = decoder(out, skips.pop())

        return self.outc(out)  # n_classes channels

# ===============================
# Dataset Definition with Enhanced Error Handling
# ===============================

class BrainSegmentationDataset(Dataset):
    """Slice-level dataset of multi-modal volumes with segmentation masks.

    The CSV must provide the columns 'Subject ID', 'Scan Type' and
    'File Path'. One sample is a single slice along the third array axis: the
    four modalities stacked as channels, plus the matching mask slice. Only
    slices whose mask contains at least one tumor label are indexed.
    """

    # Scan types stacked, in this order, as the channels of each sample.
    MODALITIES = ('flair', 't1', 't1ce', 't2')
    # Mask values that count as tumor (label 4 is merged into 3 on load).
    TUMOR_LABELS = (1, 2, 3)
    # Slices this close to the start / end of the third axis are never indexed.
    SKIP_LEADING_SLICES = 15
    SKIP_TRAILING_SLICES = 12

    def __init__(self, csv_path, crop_coords=None):
        """
        Initializes the dataset by reading the CSV and preparing slice indices that contain tumors.

        Args:
            csv_path (str): Path to the CSV file containing dataset information.
            crop_coords (tuple, optional): ``(min_x, max_x, min_y, max_y)`` bounds applied to the
                first two array axes of every image and mask. Defaults to None (no cropping).
        """
        self.data = pd.read_csv(csv_path)
        self.crop_coords = crop_coords
        self.slice_info = self._create_slice_index()

    @staticmethod
    def _load_seg_mask(seg_path):
        """Load a segmentation volume as uint8 with label 4 merged into label 3."""
        seg_mask = nib.load(seg_path).get_fdata().astype(np.uint8)
        seg_mask[seg_mask == 4] = 3
        return seg_mask

    def _create_slice_index(self):
        """
        Creates a list of tuples containing subject IDs and slice indices that have tumor labels.

        Subjects that cannot be processed (missing rows, unreadable files, ...) are
        logged and skipped so that one bad subject does not abort indexing.

        Returns:
            list: List of (subject_id, slice_index) tuples.
        """
        slice_info = []
        for subject in self.data['Subject ID'].unique():
            subject_data = self.data[self.data['Subject ID'] == subject]
            try:
                flair_path = subject_data[subject_data['Scan Type'] == 'flair']['File Path'].values[0]
                seg_path = subject_data[subject_data['Scan Type'] == 'seg']['File Path'].values[0]

                # Only the depth of the FLAIR volume is needed; nibabel exposes the
                # shape from the header, so the voxel data is not decoded here.
                depth = nib.load(flair_path).shape[2]
                seg_mask = self._load_seg_mask(seg_path)

                for z in range(self.SKIP_LEADING_SLICES, depth - self.SKIP_TRAILING_SLICES):
                    if np.isin(seg_mask[:, :, z], self.TUMOR_LABELS).any():
                        slice_info.append((subject, z))
            except Exception as e:
                logging.error(f"Error processing subject {subject}: {e}")
        return slice_info

    @staticmethod
    def Img_proc(image, _lambda=-0.8, epsilon=1e-6):
        """
        Processes the image using a specific transformation pipeline.

        Steps: min-max normalise the input, compute a logarithmic curve (IMG1) and an
        exponential curve (IMG2) of it, combine them as
        ``(IMG1 + IMG2) / (_lambda + IMG1 * IMG2)``, pass the result through
        ``erf(_lambda * arctan(exp(.)) - 0.5 * .)`` and min-max normalise again.

        Args:
            image (numpy.ndarray): Input image slice.
            _lambda (float, optional): Lambda parameter for transformation. Defaults to -0.8.
            epsilon (float, optional): Small value to prevent division by zero. Defaults to 1e-6.

        Returns:
            numpy.ndarray: Processed image with minimum 0 and maximum just below 1;
            all zeros if the input (or the transformed image) is constant.

        Raises:
            ValueError: If the image contains NaN or infinite values.
        """
        if np.isnan(image).any() or np.isinf(image).any():
            raise ValueError("Invalid image values.")
        I_img = image
        min_val = np.min(I_img)
        max_val = np.max(I_img)
        if max_val == min_val:
            return np.zeros_like(I_img)
        I_img_norm = (I_img - min_val) / (max_val - min_val + epsilon)
        max_I_img = np.max(I_img_norm)
        IMG1 = (max_I_img / np.log(max_I_img + 1 + epsilon)) * np.log(I_img_norm + 1)
        IMG2 = 1 - np.exp(-I_img_norm)
        IMG3 = (IMG1 + IMG2) / (_lambda + (IMG1 * IMG2))
        IMG4 = erf(_lambda * np.arctan(np.exp(IMG3)) - 0.5 * IMG3)
        min_IMG4 = np.min(IMG4)
        max_IMG4 = np.max(IMG4)
        if max_IMG4 == min_IMG4:
            return np.zeros_like(IMG4)
        IMG5 = (IMG4 - min_IMG4) / (max_IMG4 - min_IMG4 + epsilon)
        return IMG5

    def __len__(self):
        """Return the number of indexed (subject, slice) pairs."""
        return len(self.slice_info)

    def __getitem__(self, idx):
        """
        Retrieves the image and mask for the given index.

        Every modality volume is z-score normalised as a whole, the requested slice is
        cut out (after optional cropping) and passed through ``Img_proc``.

        Args:
            idx (int): Index of the sample.

        Returns:
            tuple: (image_tensor, mask_tensor) - a float32 tensor with one channel per
            modality and an int64 tensor holding the mask slice.

        Raises:
            Exception: Any error raised while loading is logged and re-raised unchanged.
        """
        try:
            subject, z = self.slice_info[idx]
            subject_data = self.data[self.data['Subject ID'] == subject]

            # Resolve the crop window once, up front, so images and mask are cut
            # identically and the mask does not depend on loop-local variables.
            if self.crop_coords:
                min_x, max_x, min_y, max_y = self.crop_coords
                rows, cols = slice(min_x, max_x), slice(min_y, max_y)
            else:
                rows = cols = slice(None)

            slices = []
            for modality in self.MODALITIES:
                file_path = subject_data[subject_data['Scan Type'] == modality]['File Path'].values[0]
                if not os.path.exists(file_path):
                    raise FileNotFoundError(f"File not found: {file_path}")
                image = nib.load(file_path).get_fdata().astype(np.float32)
                # Statistics are taken over the full volume, before cropping.
                image = (image - np.mean(image)) / (np.std(image) + 1e-6)
                slices.append(self.Img_proc(image[rows, cols, z]))
            images = np.stack(slices, axis=0)

            seg_data = subject_data[subject_data['Scan Type'] == 'seg']
            if seg_data.empty:
                raise ValueError(f"Missing segmentation for subject {subject}")
            seg_path = seg_data['File Path'].values[0]
            if not os.path.exists(seg_path):
                raise FileNotFoundError(f"Segmentation file not found: {seg_path}")
            seg_slice = self._load_seg_mask(seg_path)[rows, cols, z]

            return torch.tensor(images, dtype=torch.float32), torch.tensor(seg_slice, dtype=torch.long)
        except Exception as e:
            logging.error(f"Error in __getitem__ at index {idx}: {e}")
            raise  # Bare raise keeps the original traceback for the DataLoader

# ===============================
# Visualization Function
# ===============================

def visualize_segmentation(image, ground_truth, prediction, save_path, slice_idx):
    """
    Visualize the input image, ground truth mask, and predicted mask side by side.

    The figure is written to ``<save_path>/slice_<slice_idx>.png``. Both masks are
    drawn with the same colormap and value range so that a given label has the
    same color in the two panels.

    Args:
        image (numpy.ndarray): The input MRI slice (H, W).
        ground_truth (numpy.ndarray): The ground truth segmentation mask (H, W).
        prediction (numpy.ndarray): The predicted segmentation mask (H, W).
        save_path (str): Directory path to save the visualization.
        slice_idx (int): Index of the slice for naming.
    """
    # Size the colormap from the labels of BOTH masks and pin vmin/vmax.
    # Otherwise imshow normalises each panel to its own min/max, and a label
    # gets different colors when the two masks contain different label sets.
    label_min = int(min(np.min(ground_truth), np.min(prediction)))
    label_max = int(max(np.max(ground_truth), np.max(prediction)))
    cmap = plt.get_cmap('jet', label_max - label_min + 1)
    mask_style = {
        'cmap': cmap,
        'interpolation': 'none',
        'vmin': label_min,
        'vmax': label_max,
    }

    # Create save directory if it doesn't exist
    os.makedirs(save_path, exist_ok=True)

    panels = (
        ('Input Image', image, {'cmap': 'gray'}),
        ('Ground Truth Mask', ground_truth, mask_style),
        ('Predicted Mask', prediction, mask_style),
    )

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))  # Wide figure for clarity
    try:
        for ax, (title, data, style) in zip(axes, panels):
            ax.imshow(data, **style)
            ax.set_title(title)
            ax.axis('off')
        fig.savefig(os.path.join(save_path, f'slice_{slice_idx}.png'))
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

# ===============================
# Helper Functions
# ===============================

def load_checkpoint(model, checkpoint_path, device):
    """
    Load model weights from a checkpoint file.

    The checkpoint may be a plain state_dict or a dict that stores the state_dict
    under 'state_dict' or 'model_state_dict'. A leading 'module.' prefix (as
    written by ``nn.DataParallel``) is added or removed so that the checkpoint
    keys match the model's keys.

    Args:
        model (nn.Module): The model to load weights into.
        checkpoint_path (str): Path to the checkpoint file.
        device (torch.device): Device to map the model weights.

    Returns:
        nn.Module: The model loaded with weights.

    Raises:
        RuntimeError: If none of the model's state_dict entries are present in
            the checkpoint (the model would silently keep its initial weights),
            or if PyTorch reports a tensor shape mismatch.
    """
    try:
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
        logging.info(f"Checkpoint loaded with weights_only=True from {checkpoint_path}")
    except TypeError:
        # If torch.load doesn't support weights_only (older PyTorch versions)
        checkpoint = torch.load(checkpoint_path, map_location=device)
        logging.warning("torch.load does not support weights_only. Loaded without it.")

    # Unwrap checkpoints saved as {'state_dict': ..., 'epoch': ...} and similar.
    if isinstance(checkpoint, dict):
        for wrapper_key in ("state_dict", "model_state_dict"):
            nested = checkpoint.get(wrapper_key)
            if isinstance(nested, dict):
                checkpoint = nested
                break

    # Reconcile the DataParallel 'module.' prefix between checkpoint and model.
    prefix = "module."
    model_keys = list(model.state_dict().keys())
    model_prefixed = any(key.startswith(prefix) for key in model_keys)
    checkpoint_prefixed = any(key.startswith(prefix) for key in checkpoint.keys())

    if checkpoint_prefixed and not model_prefixed:
        # Strip only the LEADING prefix; a sub-module that is itself named
        # 'module' must keep its name (str.replace would remove it as well).
        checkpoint = {
            (key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in checkpoint.items()
        }
    elif model_prefixed and not checkpoint_prefixed:
        checkpoint = {prefix + key: value for key, value in checkpoint.items()}

    # strict=False tolerates partial matches, so report what did not match
    # instead of discarding that information.
    result = model.load_state_dict(checkpoint, strict=False)
    if model_keys and len(result.missing_keys) == len(model_keys):
        raise RuntimeError(
            f"No entries of {checkpoint_path} match the model's state_dict; "
            "the model would keep its initial weights."
        )
    if result.missing_keys:
        logging.warning(f"Keys missing from checkpoint {checkpoint_path}: {result.missing_keys}")
    if result.unexpected_keys:
        logging.warning(f"Unexpected keys in checkpoint {checkpoint_path}: {result.unexpected_keys}")

    logging.info(f"Loaded state_dict into the model from {checkpoint_path}")
    return model

def test_dataset(dataset, num_samples=5):
    """
    Smoke-test the dataset by fetching its first few samples.

    Each sample's tensor shapes, or the error raised while fetching it, are
    printed and written to the log. Errors never propagate to the caller.

    Args:
        dataset (Dataset): The dataset to test.
        num_samples (int): Upper bound on the number of samples to fetch.
    """
    print("Testing dataset integrity...")
    limit = min(num_samples, len(dataset))
    for idx in range(limit):
        try:
            images, masks = dataset[idx]
            summary = f"Sample {idx}: images shape {images.shape}, masks shape {masks.shape}"
            print(summary)
            logging.info(summary)
        except Exception as e:
            problem = f"Error accessing sample {idx}: {e}"
            print(problem)
            logging.error(problem)

# ===============================
# Main Function
# ===============================

def _save_visualizations(model, dataloader, device, save_path, file, num_visualizations):
    """Run ``model`` over ``dataloader`` and save up to ``num_visualizations`` figures.

    Args:
        model (nn.Module): Model already switched to eval mode.
        dataloader (DataLoader): Yields (images, masks) batches.
        device (torch.device): Device the model lives on.
        save_path (str): Directory receiving the PNG files.
        file (str): Checkpoint file name, used in progress and log messages.
        num_visualizations (int): Maximum number of figures to save.

    Returns:
        int: Number of figures actually saved.
    """
    saved_visualizations = 0
    with torch.no_grad():
        for images, masks in tqdm(dataloader, desc=f"Visualizing with {file}"):
            images = images.to(device, dtype=torch.float32)
            masks = masks.numpy()  # Shape: (B, H, W)

            outputs = model(images)  # Shape: (B, n_classes, H, W)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()  # Shape: (B, H, W)

            for i in range(preds.shape[0]):
                if saved_visualizations >= num_visualizations:
                    return saved_visualizations

                try:
                    visualize_segmentation(
                        # Channel 0 is the first modality stacked by the dataset.
                        image=images[i, 0].cpu().numpy(),
                        ground_truth=masks[i],
                        prediction=preds[i],
                        save_path=save_path,
                        slice_idx=saved_visualizations + 1
                    )
                    saved_visualizations += 1
                    logging.info(f"Saved visualization slice {saved_visualizations} for model {file}")
                except Exception as e:
                    # A failed figure is reported and skipped; its number is
                    # reused by the next sample.
                    print(f"Error during visualization: {e}")
                    logging.error(f"Error during visualization for slice {saved_visualizations + 1} in model {file}: {e}")

            if saved_visualizations >= num_visualizations:
                break  # Enough figures saved for this checkpoint
    return saved_visualizations


def main():
    """Visualize predictions of every checkpoint found under the model directory.

    Builds the test dataset once, then for each ``.pt`` / ``.pth`` file under
    ``model_dir`` loads the weights into a fresh ``UNet`` and saves a few
    input / ground-truth / prediction figures below ``visualizations/``.
    Checkpoints that fail to load are reported and skipped.
    """
    print("Start visualization process")
    logging.info("Visualization process started")

    csv_path = '../data/selected_test_subject.csv'  # Update this path as needed
    model_dir = "../model"  # Directory containing model checkpoints
    num_visualizations = 5  # Number of slices to visualize per model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Reading dataset")
    logging.info(f"Reading dataset from {csv_path}")

    # Initialize dataset without cropping coordinates (if needed, set crop_coords)
    dataset = BrainSegmentationDataset(csv_path)

    # Test dataset integrity before proceeding
    test_dataset(dataset)

    # NOTE: only the first `num_visualizations` samples are drawn per checkpoint,
    # so a smaller batch_size would avoid loading slices that are never shown.
    dataloader = DataLoader(
        dataset,
        batch_size=32,         # Adjust as per your system's capability
        shuffle=False,
        num_workers=0,         # Set to 0 to prevent worker errors
        # Pinned host memory only helps host-to-GPU copies; skip it on CPU runs.
        pin_memory=(device.type == "cuda")
    )

    # Directory to save all visualizations
    visualization_root = "visualizations"
    os.makedirs(visualization_root, exist_ok=True)
    logging.info(f"Visualization root directory set at '{visualization_root}'")

    # Iterate over all checkpoint files in the model directory
    for root, dirs, files in os.walk(model_dir):
        dirs.sort()  # In-place sort makes the traversal order deterministic
        for file in sorted(files):
            if not file.endswith((".pt", ".pth")):
                continue

            checkpoint_path = os.path.join(root, file)
            print(f"Read file from checkpoint {checkpoint_path}")
            logging.info(f"Processing checkpoint: {checkpoint_path}")

            # Fresh model per checkpoint so no weights carry over between files
            model = UNet(n_channels=4, n_classes=4, bilinear=True).to(device)

            try:
                model = load_checkpoint(model, checkpoint_path, device)
                print("Model Loaded Successfully")
                logging.info(f"Model loaded successfully from {checkpoint_path}")
            except Exception as e:
                print(f"Failed to load the Model: {e}")
                logging.error(f"Failed to load the model from {checkpoint_path}: {e}")
                continue  # Skip to the next checkpoint

            model.eval()
            print(f"Dataset size: {len(dataset)}, Dataloader batches: {len(dataloader)}")
            logging.info(f"Dataset size: {len(dataset)}, Dataloader batches: {len(dataloader)}")

            # Name the output directory after the checkpoint's path relative to
            # model_dir. Files directly in model_dir keep the plain file stem,
            # while equally named checkpoints in different sub-directories no
            # longer overwrite each other's figures.
            relative_stem = os.path.splitext(os.path.relpath(checkpoint_path, model_dir))[0]
            save_path = os.path.join(visualization_root, relative_stem)
            os.makedirs(save_path, exist_ok=True)
            logging.info(f"Saving visualizations to '{save_path}'")

            saved_visualizations = _save_visualizations(
                model, dataloader, device, save_path, file, num_visualizations
            )

            print(f"Saved {saved_visualizations} visualizations for model {file} at '{save_path}'")
            logging.info(f"Saved {saved_visualizations} visualizations for model {file} at '{save_path}'")

    print("Visualization process completed")
    logging.info("Visualization process completed")

if __name__ == "__main__":
    # Top-level boundary: report any failure on the console and in the log
    # file instead of letting the run end with an unhandled exception.
    try:
        main()
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        # exc_info=True stores the traceback in the log; the message alone
        # does not say where the failure happened.
        logging.critical(f"An unexpected error occurred: {e}", exc_info=True)


Start visualization process
Reading dataset
Testing dataset integrity...
Sample 0: images shape torch.Size([4, 240, 240]), masks shape torch.Size([240, 240])
Sample 1: images shape torch.Size([4, 240, 240]), masks shape torch.Size([240, 240])
Sample 2: images shape torch.Size([4, 240, 240]), masks shape torch.Size([240, 240])
Sample 3: images shape torch.Size([4, 240, 240]), masks shape torch.Size([240, 240])
Sample 4: images shape torch.Size([4, 240, 240]), masks shape torch.Size([240, 240])
Read file from checkpoint ../model\model.pth
Model Loaded Successfully
Dataset size: 2379, Dataloader batches: 75


Visualizing with model.pth:   0%|                                                               | 0/75 [00:10<?, ?it/s]

Saved 5 visualizations for model model.pth at 'visualizations\model'
Visualization process completed



