In [19]:
import torch  
import torch.nn as nn  
import torch.optim as optim  
from torch.utils.data import DataLoader, Dataset  
from torchvision import transforms  
import numpy as np  
import pandas as pd  
from PIL import Image  
import random  
import nibabel as nib
import os

class BraTSDataset(Dataset):
    def __init__(self, dataframe, transform=None):
        self.t1_files = dataframe[dataframe["Scan Type"] == "t1"]
        self.seg_files = dataframe[dataframe["Scan Type"] == "seg"]
        self.transform = transform

    def __len__(self):
        return len(self.t1_files) * 155

    def __getitem__(self, idx):
        subject_idx = idx // 155
        slice_idx = idx % 155

        t1_row = self.t1_files.iloc[subject_idx]
        seg_row = self.seg_files.iloc[subject_idx]

        t1_path = t1_row["File Path"]
        seg_path = seg_row["File Path"]

        # Skip blank slices
        if slice_idx < 15 or slice_idx > 142:
            return self.__getitem__((idx + 1) % self.__len__())  # Safely fetch another valid index

        # Load T1 slice
        t1_img = nib.load(t1_path).get_fdata()
        t1_slice = t1_img[:, :, slice_idx]
        t1_slice = torch.tensor(t1_slice, dtype=torch.float32).unsqueeze(0)

        # Load segmentation slice
        seg_img = nib.load(seg_path).get_fdata()
        seg_slice = seg_img[:, :, slice_idx]
        seg_label = 1 if np.any(seg_slice > 0) else 0

        if self.transform:
            t1_slice = self.transform(t1_slice)

        return t1_slice, seg_label



import matplotlib.pyplot as plt

def evaluate_discriminator_with_logging(discriminator, dataloader, device, sens=0.1):
    """
    Evaluate the discriminator on a dataset, compute accuracy and sensitivity, 
    and log misclassified slices.
    """
    discriminator.eval()
    correct = 0
    total = 0
    tp, fp, tn, fn = 0, 0, 0, 0  # Initialize counters for sensitivity calculation

    output_dir = f"misclassified_slices{sens}"
    # Create directory for saving misclassified slices
    os.makedirs(output_dir, exist_ok=True)
    misclassified_count = 0

    with torch.no_grad():
        for i, (t1_slice, seg_label) in enumerate(dataloader):
            if t1_slice is None:  # Skip invalid slices
                continue

            t1_slice = t1_slice.to(device)
            seg_label = seg_label.to(device)

            # Get discriminator output (pixel-wise probabilities)
            output = discriminator(t1_slice).squeeze()

            # Compute slice-level prediction
            slice_prediction = (output.mean() <= sens).long()  # Threshold at `sens` for slice-level classification

            # Update counters for accuracy and sensitivity
            if slice_prediction == seg_label:
                correct += 1
                if seg_label == 1:  # True Positive
                    tp += 1
                else:  # True Negative
                    tn += 1
            else:
                misclassified_count += 1
                if seg_label == 1:  # False Negative
                    fn += 1
                else:  # False Positive
                    fp += 1

                # Log and save misclassified slice
                # print(f"Misclassified Slice Index: {i}, Prediction: {slice_prediction.item()}, Ground Truth: {seg_label.item()}")
                save_misclassified_slice(t1_slice.cpu().squeeze().numpy(), seg_label.item(), slice_prediction.item(), i, output_dir)

            total += 1

    # Calculate metrics
    accuracy = (correct / total * 100) if total > 0 else 0
    sensitivity = (tp / (tp + fn) * 100) if (tp + fn) > 0 else 0  # Avoid division by zero

    # Print results
    print(f"Discriminator Accuracy: {accuracy:.2f}%")
    print(f"Discriminator Sensitivity: {sensitivity:.2f}%")
    print(f"Total Misclassified Slices: {misclassified_count}")
    print(f"True Positives (TP): {tp}, False Negatives (FN): {fn}")
    print(f"True Negatives (TN): {tn}, False Positives (FP): {fp}")

    return accuracy, sensitivity


    # Calculate accuracy
    accuracy = (correct / total * 100) if total > 0 else 0
    print(f"Discriminator Accuracy: {accuracy:.2f}%")
    print(f"Total Misclassified Slices: {misclassified_count}")
    return accuracy

def save_misclassified_slice(slice_data, true_label, predicted_label, index, output_dir):
    """
    Save a misclassified slice as an image for visualization.

    Args:
        slice_data (numpy array): The MRI slice data.
        true_label (int): The ground truth label (0 or 1).
        predicted_label (int): The predicted label (0 or 1).
        index (int): Index of the slice.
        output_dir (str): Directory to save the image.
    """
    plt.figure(figsize=(6, 6))
    plt.imshow(slice_data, cmap="gray")
    plt.title(f"True: {true_label}, Predicted: {predicted_label}")
    plt.axis("off")
    save_path = os.path.join(output_dir, f"slice_{index}_true_{true_label}_pred_{predicted_label}.png")
    plt.savefig(save_path)
    plt.close()



# Load test data
import pandas as pd

test_csv_path = "../data/selected_test_subject.csv"  # Update with the correct path
test_data = pd.read_csv(test_csv_path)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prepare test dataset and dataloader
test_dataset = BraTSDataset(test_data)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Load the discriminator
class Discriminator(nn.Module):
    def __init__(self, in_channels):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

# Initialize and load the discriminator model
discriminator = Discriminator(in_channels=1).to(device)
discriminator.load_state_dict(torch.load("../model/discriminator_epoch_10.pth"))  # Update with the correct path

# Evaluate the discriminator
accuracy, sensitivity = evaluate_discriminator_with_logging(discriminator, test_loader, device, sens=0.1)
print(f"Final Results - Accuracy: {accuracy:.2f}%, Sensitivity: {sensitivity:.2f}%")




  discriminator.load_state_dict(torch.load("../model/discriminator_epoch_10.pth"))  # Update with the correct path


Discriminator Accuracy: 75.07%
Discriminator Sensitivity: 51.24%
Total Misclassified Slices: 1391
True Positives (TP): 1219, False Negatives (FN): 1160
True Negatives (TN): 2970, False Positives (FP): 231
Final Results - Accuracy: 75.07%, Sensitivity: 51.24%


In [20]:
accuracy, sensitivity = evaluate_discriminator_with_logging(discriminator, test_loader, device, sens=0.2)
print(f"Final Results - Accuracy: {accuracy:.2f}%, Sensitivity: {sensitivity:.2f}%")

Discriminator Accuracy: 79.53%
Discriminator Sensitivity: 88.52%
Total Misclassified Slices: 1142
True Positives (TP): 2106, False Negatives (FN): 273
True Negatives (TN): 2332, False Positives (FP): 869
Final Results - Accuracy: 79.53%, Sensitivity: 88.52%


In [21]:
accuracy, sensitivity = evaluate_discriminator_with_logging(discriminator, test_loader, device, sens=0.3)
print(f"Final Results - Accuracy: {accuracy:.2f}%, Sensitivity: {sensitivity:.2f}%")

Discriminator Accuracy: 76.40%
Discriminator Sensitivity: 95.00%
Total Misclassified Slices: 1317
True Positives (TP): 2260, False Negatives (FN): 119
True Negatives (TN): 2003, False Positives (FP): 1198
Final Results - Accuracy: 76.40%, Sensitivity: 95.00%


In [22]:
accuracy, sensitivity = evaluate_discriminator_with_logging(discriminator, test_loader, device, sens=0.4)
print(f"Final Results - Accuracy: {accuracy:.2f}%, Sensitivity: {sensitivity:.2f}%")

Discriminator Accuracy: 72.92%
Discriminator Sensitivity: 98.11%
Total Misclassified Slices: 1511
True Positives (TP): 2334, False Negatives (FN): 45
True Negatives (TN): 1735, False Positives (FP): 1466
Final Results - Accuracy: 72.92%, Sensitivity: 98.11%
