# Notebook 1: Data Exploration and Preparation

**Objective:** This notebook provides an interactive guide to the initial data preparation workflow for the CT-RATE dataset. We will cover:

1.  **Loading and Configuring**: Setting up the environment and loading the main configuration.
2.  **Label Analysis**: Visualizing the distribution and co-occurrence of pathology labels.
3.  **Data Filtering**: Running the filtering logic that creates a filtered master list of volumes, excluding certain patients to prevent data leakage.
4.  **K-Fold Split Generation**: Demonstrating how patient-aware, stratified cross-validation splits are created.

In [None]:
# Standard library imports
import sys
from pathlib import Path

# Third-party imports
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from IPython.display import display
from skmultilearn.model_selection import iterative_train_test_split

# --- Project-Specific Imports ---
# To import from the 'src' and 'scripts' packages, we need to add the project root to the Python path.
# We assume this notebook is located in a 'notebooks' directory directly under the project root.
project_root = Path.cwd().parent
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from src.config import load_config
from scripts.data_preparation.create_filtered_dataset import get_patient_id, normalize_name_from_path, natural_sort_key
from scripts.data_preparation.create_kfold_splits import create_kfold_splits

# --- Notebook Setup ---
sns.set_theme(context="notebook", style="whitegrid", font_scale=1.2)
# This setting ensures that all columns are displayed in pandas DataFrames.
pd.set_option('display.max_columns', None)

print("Imports successful and project path configured.")

In [None]:
# Load the main configuration file to access all project paths and parameters.
# This ensures that the notebook uses the same settings as the main application scripts.
try:
    config_path = project_root / 'configs' / 'config_example.yaml'
    config = load_config(config_path)
    print("Configuration loaded successfully.")
    # Display a few key paths to verify
    print(f"Data Directory: {config.paths.data_dir}")
    print(f"Labels File: {config.paths.labels.all}")
except FileNotFoundError:
    print("ERROR: Could not find the configuration file.")
    print("Please ensure 'configs/config_example.yaml' exists and the project root is correct.")
    # Stop here: every later cell depends on `config`.
    raise
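Every later cell reads specific keys from `config`, so a missing key only surfaces as an error several cells down. The sketch below is a hypothetical helper (not part of this repository) that checks the dotted key paths this notebook uses. It assumes `load_config` returns either nested dicts or an attribute-style object; if it returns something like an OmegaConf `DictConfig`, missing-key behavior depends on that library's settings, so treat this as a sketch rather than a guaranteed check.

```python
from types import SimpleNamespace
from typing import Any, Iterable, List

# Dotted key paths read by the cells in this notebook.
REQUIRED_KEYS = [
    "paths.data_dir",
    "paths.labels.all",
    "paths.metadata.train",
    "paths.metadata.valid",
    "paths.output_filename",
    "pathologies.columns",
]


def _lookup(node: Any, key: str) -> Any:
    """Fetch one key from a mapping or an attribute-style object."""
    if isinstance(node, dict):
        return node[key]  # KeyError if absent
    return getattr(node, key)  # AttributeError if absent


def missing_config_keys(cfg: Any, dotted_keys: Iterable[str]) -> List[str]:
    """Return the dotted keys that cannot be resolved on cfg or resolve to None."""
    missing = []
    for dotted in dotted_keys:
        node = cfg
        try:
            for part in dotted.split("."):
                node = _lookup(node, part)
        except (KeyError, AttributeError):
            missing.append(dotted)
            continue
        if node is None:
            missing.append(dotted)
    return missing


# Toy config for illustration only: 'paths.metadata.valid' is deliberately None.
demo_cfg = SimpleNamespace(
    paths=SimpleNamespace(
        data_dir="data",
        labels={"all": "labels.csv"},
        metadata=SimpleNamespace(train="train_meta.csv", valid=None),
        output_filename="FILTERED_MASTER_LIST.csv",
    ),
    pathologies=SimpleNamespace(columns=["LabelA"]),
)
demo_missing = missing_config_keys(demo_cfg, REQUIRED_KEYS)
print(demo_missing)  # ['paths.metadata.valid']
```

In the notebook, calling `missing_config_keys(config, REQUIRED_KEYS)` right after loading would give one readable list of problems instead of an `AttributeError` later on.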

In [None]:
# Load the primary data files: the full metadata and the complete set of labels.
try:
    all_labels_df = pd.read_csv(config.paths.labels.all)
    
    # We also need the metadata files to get the full list of available volumes.
    train_metadata_df = pd.read_csv(config.paths.metadata.train)
    valid_metadata_df = pd.read_csv(config.paths.metadata.valid)
    all_volumes_df = pd.concat([train_metadata_df, valid_metadata_df], ignore_index=True)

    print("Full Label Set Info:")
    all_labels_df.info()
    print("\nFull Label Set Head:")
    display(all_labels_df.head())

    print("\nAll Volumes Head:")
    display(all_volumes_df.head())

except FileNotFoundError as e:
    print(f"ERROR: Could not load a required data file: {e}")
    print("Please ensure the paths in your config file are correct and the data files exist.")

## Understanding the Dataset Hierarchy

Before we analyze the labels, it is crucial to understand the structure of the CT-RATE dataset. The data is organized hierarchically:

-   **Patients**: The highest level. A single patient may have multiple CT scans over time.
-   **CT Scans**: An imaging session for a patient.
-   **Volumes (or Reconstructions)**: A single CT scan can be reconstructed with different parameters (e.g., different slice thickness or kernels), resulting in multiple 3D volume files (`.nii.gz`) for the same scan.

This means the total number of `.nii.gz` files is much larger than the number of unique scans, which in turn is larger than the number of unique patients. For our analysis, especially for splitting the data, we must operate at the **patient level** to prevent data leakage.
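The three levels can be recovered from the volume names alone. Below is a minimal sketch on toy names, assuming the `train_123_a_1` pattern described in this notebook (split, patient, scan, reconstruction). The helper `add_hierarchy_columns` is hypothetical and only mirrors the string logic used in the next cell; the project's own `get_patient_id` may handle more cases.

```python
import pandas as pd


def add_hierarchy_columns(df: pd.DataFrame, name_col: str = "VolumeName") -> pd.DataFrame:
    """Derive ScanID and PatientID from names like 'train_123_a_1(.nii.gz)'."""
    out = df.copy()
    # Drop the file extension if present so it does not end up in the IDs.
    stems = out[name_col].str.replace(r"\.nii\.gz$", "", regex=True)
    # ScanID: everything before the last '_' (removes the reconstruction index).
    out["ScanID"] = stems.str.rsplit("_", n=1).str[0]
    # PatientID: the second '_'-separated token, e.g. '123' in 'train_123_a_1'.
    out["PatientID"] = stems.str.split("_").str[1]
    return out


def hierarchy_counts(df: pd.DataFrame) -> dict:
    """Count volumes, unique scans and unique patients."""
    return {
        "volumes": len(df),
        "scans": df["ScanID"].nunique(),
        "patients": df["PatientID"].nunique(),
    }


# Toy names (not real CT-RATE entries): two patients, three scans, four volumes.
toy_names = pd.DataFrame(
    {"VolumeName": ["train_1_a_1", "train_1_a_2", "train_1_b_1", "train_2_a_1.nii.gz"]}
)
toy_counts = hierarchy_counts(add_hierarchy_columns(toy_names))
print(toy_counts)  # {'volumes': 4, 'scans': 3, 'patients': 2}
```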

Let's quantify this structure using our loaded metadata.

In [None]:
# --- Calculate counts for each level of the hierarchy ---

# 1. Total number of reconstructed volumes (each row in the metadata is one volume)
total_volumes = len(all_volumes_df)

# 2. Number of unique CT scans
# We can identify a unique scan by its name minus the final reconstruction part (e.g., '_1', '_2')
# 'train_123_a_1' -> 'train_123_a'
all_volumes_df['ScanID'] = all_volumes_df['VolumeName'].str.rsplit('_', n=1).str[0]
total_scans = all_volumes_df['ScanID'].nunique()

# 3. Number of unique patients
# The PatientID is the second part of the VolumeName, e.g., 'train_123_a_1' -> '123'
all_volumes_df['PatientID'] = all_volumes_df['VolumeName'].apply(get_patient_id)
total_patients = all_volumes_df['PatientID'].nunique()

# --- Display the statistics ---
print("--- CT-RATE Dataset Hierarchy Statistics ---")
print(f"Total Reconstructed Volumes: {total_volumes:,}")
print(f"Total Unique CT Scans:       {total_scans:,}")
print(f"Total Unique Patients:       {total_patients:,}")
print("------------------------------------------")

# Display the new columns to verify
display(all_volumes_df[['VolumeName', 'ScanID', 'PatientID']].head())

## 2. Label Distribution Analysis

Before filtering, let's analyze the raw label data to understand its characteristics. We will visualize three key aspects:

-   **Label Frequency**: How many times does each pathology appear in the dataset? This helps identify class imbalance.
-   **Labels per Scan**: How many pathologies are typically assigned to a single CT scan?
-   **Label Co-occurrence**: Which pathologies tend to appear together? This reveals potential clinical correlations.
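All three quantities come from simple operations on the multi-hot label matrix. The toy example below (hypothetical labels `A`, `B`, `C`, one row per scan) shows the computations the plotting helpers perform. Note the orientation of the conditional matrix: dividing each row by its diagonal count makes entry `[X, Y]` equal to P(Y | X), and since `seaborn.heatmap` draws rows along the y-axis, the conditioning label X ends up on the vertical axis.

```python
import numpy as np
import pandas as pd

# Toy multi-hot label matrix (hypothetical labels A, B, C; one row per scan).
labels = pd.DataFrame({"A": [1, 1, 0, 1], "B": [1, 0, 0, 1], "C": [0, 0, 1, 0]})

# Label frequency: number of positive scans per label, most frequent first.
frequency = labels.sum().sort_values(ascending=False)

# Labels per scan: number of positive labels in each row.
labels_per_scan = labels.sum(axis=1)

# Co-occurrence counts: entry [X, Y] = number of scans where both X and Y are positive.
cooccurrence = labels.T.dot(labels)

# Conditional probability: divide each row X by Count(X), so entry [X, Y] = P(Y | X).
conditional = cooccurrence.div(np.diag(cooccurrence), axis=0)

print(frequency.tolist())                   # [3, 2, 1]
print(labels_per_scan.tolist())             # [2, 1, 1, 2]
print(round(conditional.loc["A", "B"], 3))  # 0.667
print(conditional.loc["B", "A"])            # 1.0
```

The asymmetry (P(B | A) is 0.667 while P(A | B) is 1.0) is why the normalized heatmap is not symmetric, unlike the raw co-occurrence counts.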

In [None]:
# Plotting helpers for the label distribution analysis.

def plot_label_frequencies(df_labels, label_columns):
    """Generates a horizontal bar chart showing the frequency of each label."""
    label_counts = df_labels[label_columns].sum().sort_values(ascending=False)
    
    plt.figure(figsize=(12, 10))
    sns.barplot(x=label_counts.values, y=label_counts.index, palette="viridis", orient='h')
    plt.xlabel("Frequency (Number of Positive Cases)")
    plt.ylabel("Pathology Label")
    plt.title("Frequency of Each Pathology Label", fontsize=16)
    plt.tight_layout()
    plt.show()

def plot_labels_per_scan_distribution(df_labels, label_columns):
    """Generates a histogram of the number of positive labels per scan."""
    labels_per_scan = df_labels[label_columns].sum(axis=1)
    
    plt.figure(figsize=(10, 6))
    sns.histplot(labels_per_scan, kde=False, color="skyblue", discrete=True)
    plt.xlabel("Number of Positive Labels per Scan")
    plt.ylabel("Number of Scans")
    plt.title("Distribution of Number of Labels per Scan", fontsize=16)
    # Cast to int: label columns may be stored as floats, and range() requires integers.
    plt.xticks(range(0, int(labels_per_scan.max()) + 1))
    plt.tight_layout()
    plt.show()

def plot_label_cooccurrence_heatmap(df_labels, label_columns):
    """Generates a heatmap showing the co-occurrence of label pairs."""
    df_label_data = df_labels[label_columns]
    cooccurrence_matrix = df_label_data.T.dot(df_label_data)
    
    # For visualization, we normalize by the diagonal to see conditional probabilities
    # P(Y | X) = Count(X and Y) / Count(X)
    diagonal_counts = np.diag(cooccurrence_matrix)
    with np.errstate(divide='ignore', invalid='ignore'):
        normalized_matrix = cooccurrence_matrix.astype(float).div(diagonal_counts, axis=0)
    normalized_matrix = normalized_matrix.fillna(0)

    plt.figure(figsize=(16, 14))
    sns.heatmap(normalized_matrix, annot=True, fmt=".2f", cmap="viridis", linewidths=.5)
    # Rows (drawn on the y-axis) are the conditioning label X; columns (x-axis) are Y.
    plt.xlabel("Probability of this Label also being Present (Y)")
    plt.ylabel("Given this Label is Present (X)")
    plt.title("Normalized Label Co-occurrence P(Y|X)", fontsize=16)
    plt.xticks(rotation=45, ha="right")
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

# Get the list of pathology columns from the config file.
# Converting to a plain list allows it to be concatenated with other column lists below.
pathology_columns = list(config.pathologies.columns)

### Correcting for the Data Hierarchy

As noted above, `all_labels_df` contains one entry per reconstructed volume, not per unique scan. To get an accurate view of the label distribution, we must first deduplicate the data so that each unique scan is represented only once. We will derive the same `ScanID` used for the metadata earlier and deduplicate on it.
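Keeping one row per scan is only safe if all reconstructions of a scan carry the same labels. The hypothetical helper below (not part of this repository) lists the scans where that assumption fails; it is shown on toy data and assumes a `ScanID` column already exists.

```python
import pandas as pd


def scans_with_inconsistent_labels(df: pd.DataFrame, label_columns, scan_col: str = "ScanID") -> list:
    """Return the ScanIDs whose reconstructions do not all carry identical labels."""
    # Number of distinct values of each label within each scan.
    distinct = df.groupby(scan_col)[list(label_columns)].nunique()
    # A scan is inconsistent if any label takes more than one value across its volumes.
    inconsistent = distinct.gt(1).any(axis=1)
    return inconsistent[inconsistent].index.tolist()


# Toy volume-level labels: scan 's2' has two reconstructions that disagree on label 'A'.
toy_labels = pd.DataFrame(
    {
        "ScanID": ["s1", "s1", "s2", "s2", "s3"],
        "A": [1, 1, 0, 1, 0],
        "B": [0, 0, 1, 1, 1],
    }
)
print(scans_with_inconsistent_labels(toy_labels, ["A", "B"]))  # ['s2']
```

Once `ScanID` has been added to `all_labels_df`, an empty result from `scans_with_inconsistent_labels(all_labels_df, pathology_columns)` would support keeping only the first reconstruction per scan.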

In [None]:
# --- 1. Create the 'ScanID' in the labels DataFrame ---
# The logic is the same as for the metadata: strip the final reconstruction suffix from the VolumeName.
all_labels_df['ScanID'] = all_labels_df['VolumeName'].str.rsplit('_', n=1).str[0]

# --- 2. Now, perform the deduplication ---
# Keep the first reconstruction of each scan. This assumes all reconstructions of a scan share the same labels.
scan_level_labels_df = all_labels_df.drop_duplicates(subset=['ScanID']).reset_index(drop=True)

print(f"Original number of label entries (volumes): {len(all_labels_df):,}")
print(f"Deduplicated number of label entries (scans): {len(scan_level_labels_df):,}")

print("\nHead of the new scan-level DataFrame with 'ScanID':")
display(scan_level_labels_df[['VolumeName', 'ScanID'] + pathology_columns].head())

In [None]:
plot_label_frequencies(scan_level_labels_df, pathology_columns)

In [None]:
plot_labels_per_scan_distribution(scan_level_labels_df, pathology_columns)

In [None]:
plot_label_cooccurrence_heatmap(scan_level_labels_df, pathology_columns)

## 3. Executing the Data Filtering Script and Creating a Master List

For the experiments in this repository, a crucial decision was made to separate the manually labeled scans from the main corpus. These manually labeled scans, which can be found at the [CT-CLIP repository](https://github.com/ibrahimethemhamamci/CT-CLIP/tree/main/text_classifier/data), are treated as a gold standard for evaluation.

Furthermore, as noted in the official [CT-RATE dataset correction note](https://huggingface.co/datasets/ibrahimhamamci/CT-RATE/blob/main/dataset/data_correction_note.md), certain scans have been identified as brain scans or as having a missing z-space.

For these reasons, all of these scans (manually labeled scans, brain scans, and scans with a missing z-space) are removed from the main dataset, and the original train/validation split provided with the CT-RATE dataset on Hugging Face is not used. Instead, this repository provides a `FILTERED_MASTER_LIST.csv`. This approach offers two main advantages:

1.  **Flexibility**: Users of this repository can freely create their own data splits from a clean, reliable master list.
2.  **Gold Standard Evaluation**: The separated manual labels can be used as a high-quality, independent test set.

The trade-off is that results from this repository may not be directly comparable to other models trained on the original, unfiltered dataset splits. However, for the objectives of these experiments, having a reliable, manually verified evaluation set is of higher importance.
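The exclusion step can be pictured as set operations on volume names. The sketch below is an illustration under stated assumptions, not the logic of `create_filtered_dataset`: it assumes brain scans and missing z-space volumes are dropped individually, while manually labeled volumes cause every volume of the same patient to be dropped (the patient-level exclusion mentioned in the objectives). `filter_master_list` and `toy_patient_id` are hypothetical names.

```python
from typing import Callable, Iterable

import pandas as pd


def filter_master_list(
    volumes: pd.DataFrame,
    drop_volumes: Iterable[str],
    drop_patients_of: Iterable[str],
    patient_id: Callable[[str], str],
) -> pd.DataFrame:
    """Drop listed volumes, and every volume of any patient who owns a volume in drop_patients_of."""
    names = volumes["VolumeName"]
    banned_patients = {patient_id(name) for name in drop_patients_of}
    keep = ~names.isin(set(drop_volumes)) & ~names.map(patient_id).isin(banned_patients)
    return volumes[keep].reset_index(drop=True)


def toy_patient_id(name: str) -> str:
    # Hypothetical stand-in for get_patient_id: 'train_123_a_1' -> '123'.
    return name.split("_")[1]


toy_volumes = pd.DataFrame(
    {"VolumeName": ["train_1_a_1", "train_1_b_1", "train_2_a_1", "train_3_a_1", "train_3_a_2"]}
)
filtered_toy = filter_master_list(
    toy_volumes,
    drop_volumes=["train_2_a_1"],      # e.g. a brain scan or missing z-space volume
    drop_patients_of=["train_1_a_1"],  # e.g. a manually labeled volume
    patient_id=toy_patient_id,
)
print(filtered_toy["VolumeName"].tolist())  # ['train_3_a_1', 'train_3_a_2']
```

Dropping the whole patient, rather than only the labeled volume, keeps other scans of a gold-standard patient from leaking into training.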

#### A Note on Implementation

For the filtering script to function correctly, the user must first generate a unified CSV file containing the manual labels.

This file, referenced by the `manual_labels` key in the `exclusion_files` section of the config, must contain a `VolumeName` column listing every manually labeled volume. It is highly recommended that this CSV also include the corresponding pathology labels and the full text report for each scan.

The mapping and labels needed to create this file are available at the CT-CLIP repository linked above.
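As a rough sketch of assembling that unified file, the hypothetical helper below concatenates several tables, checks for the required `VolumeName` column, removes duplicate volumes, and writes the CSV. The input tables and their extra columns (`LabelA`, `Report`) are toy placeholders: the layout of the CT-CLIP files is not described here, so mapping them into this shape is left to the reader.

```python
import tempfile
from pathlib import Path
from typing import Iterable

import pandas as pd


def build_manual_labels_csv(parts: Iterable[pd.DataFrame], out_path: Path) -> pd.DataFrame:
    """Concatenate manual-label tables into one CSV with one row per VolumeName."""
    combined = pd.concat(list(parts), ignore_index=True)
    if "VolumeName" not in combined.columns:
        raise ValueError("The manual-label table needs a 'VolumeName' column.")
    # Keep the first row for any volume that appears in more than one part.
    combined = combined.drop_duplicates(subset="VolumeName").reset_index(drop=True)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    combined.to_csv(out_path, index=False)
    return combined


# Toy parts with hypothetical columns; real inputs would come from the CT-CLIP files.
part_a = pd.DataFrame(
    {"VolumeName": ["valid_1_a_1", "valid_2_a_1"], "LabelA": [1, 0], "Report": ["report 1", "report 2"]}
)
part_b = pd.DataFrame({"VolumeName": ["valid_2_a_1"], "LabelA": [0], "Report": ["report 2"]})

with tempfile.TemporaryDirectory() as tmp:
    out_file = Path(tmp) / "manual_labels.csv"
    manual = build_manual_labels_csv([part_a, part_b], out_file)
    reloaded = pd.read_csv(out_file)

print(manual["VolumeName"].tolist())  # ['valid_1_a_1', 'valid_2_a_1']
print(len(reloaded))                  # 2
```

The config's `manual_labels` entry would then point at the written file.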

In [None]:
# --- 1. Execute the filtering script function ---
# This function encapsulates the entire filtering process.
from scripts.data_preparation.create_filtered_dataset import create_filtered_dataset

print("Running the create_filtered_dataset function...")
create_filtered_dataset(config)

# --- 2. Load and inspect the generated file ---
try:
    filtered_list_path = Path(config.paths.data_dir) / config.paths.output_filename
    filtered_df = pd.read_csv(filtered_list_path)
    
    print(f"\nSuccessfully loaded the filtered list from: {filtered_list_path}")
    
    initial_volume_count = len(all_volumes_df)
    final_volume_count = len(filtered_df)
    
    print(f"\nInitial number of volumes: {initial_volume_count}")
    print(f"Number of volumes after filtering: {final_volume_count}")
    print(f"Total volumes removed: {initial_volume_count - final_volume_count}")

    print("\nHead of the final filtered list of volumes:")
    display(filtered_df.head())

except FileNotFoundError:
    print(f"ERROR: The expected output file was not found at {filtered_list_path}")

## 4. Generating K-Fold Splits

Now, using the filtered master list, we will call the `create_kfold_splits` function. This script will generate the stratified, patient-aware cross-validation folds and save them as separate CSV files.
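The key idea behind patient-aware splitting is to assign whole patients, not volumes, to folds. The sketch below shows only that grouping step on toy data, with a plain shuffled round-robin assignment and no label stratification. It also computes a patient-level label matrix (a label is positive if any of the patient's volumes is positive), which is the kind of input a multi-label stratifier would work on; the `iterative_train_test_split` import at the top suggests the project's script uses iterative stratification, but that is an inference. `kfold_by_patient` and `patient_level_labels` are hypothetical names.

```python
import numpy as np
import pandas as pd


def patient_level_labels(df: pd.DataFrame, label_columns, patient_col: str = "PatientID") -> pd.DataFrame:
    """One row per patient: a label is 1 if any of the patient's volumes is positive for it."""
    return df.groupby(patient_col)[list(label_columns)].max()


def kfold_by_patient(df: pd.DataFrame, n_splits: int, patient_col: str = "PatientID", seed: int = 0):
    """Return (train, valid) DataFrames per fold, with every patient in exactly one valid fold.

    Patients are shuffled and dealt round-robin into folds; there is NO label stratification.
    """
    patients = sorted(df[patient_col].unique())
    order = np.random.default_rng(seed).permutation(len(patients))
    fold_of = {patients[j]: i % n_splits for i, j in enumerate(order)}
    fold = df[patient_col].map(fold_of)
    return [(df[fold != k], df[fold == k]) for k in range(n_splits)]


# Toy data: 10 volumes from 7 patients, one hypothetical label 'A'.
toy_vols = pd.DataFrame(
    {
        "PatientID": ["p1", "p1", "p2", "p3", "p3", "p4", "p5", "p5", "p6", "p7"],
        "A": [1, 0, 0, 1, 1, 0, 0, 1, 0, 0],
    }
)
patient_labels = patient_level_labels(toy_vols, ["A"])
folds = kfold_by_patient(toy_vols, n_splits=3)

for k, (train, valid) in enumerate(folds):
    shared = set(train["PatientID"]) & set(valid["PatientID"])
    print(f"fold {k}: {valid['PatientID'].nunique()} valid patients, shared={len(shared)}")
```

Because assignment happens at the patient level, no patient can appear in both the training and validation part of a fold, whatever the number of volumes per patient.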

In [None]:
# --- 1. Define parameters and execute the splitting function ---
from scripts.data_preparation.create_kfold_splits import create_kfold_splits

N_SPLITS = 5
output_dir = Path(config.paths.data_dir) / "splits" / f"kfold_{N_SPLITS}"

print(f"Running the create_kfold_splits function for {N_SPLITS} splits...")
# The config needs to be updated to point to the correct master list file.
config.paths.full_dataset_csv = config.paths.output_filename
create_kfold_splits(config, n_splits=N_SPLITS, output_dir=output_dir)

# --- 2. Verify that the split files have been created ---
print(f"\nChecking for output files in: {output_dir}")
assert output_dir.exists(), "Output directory for splits was not created."

# Check for the first fold's files
train_path = output_dir / "train_fold_0.csv"
valid_path = output_dir / "valid_fold_0.csv"

assert train_path.exists(), "train_fold_0.csv not found."
assert valid_path.exists(), "valid_fold_0.csv not found."
print("Split files for Fold 0 found successfully.")


# --- 3. Load a sample fold and verify no patient leakage ---
df_train_fold = pd.read_csv(train_path)
df_valid_fold = pd.read_csv(valid_path)

# Extract patient IDs
df_train_fold['PatientID'] = df_train_fold['VolumeName'].apply(get_patient_id)
df_valid_fold['PatientID'] = df_valid_fold['VolumeName'].apply(get_patient_id)

train_patients = set(df_train_fold['PatientID'])
valid_patients = set(df_valid_fold['PatientID'])

print(f"\nNumber of patients in training set (Fold 0): {len(train_patients)}")
print(f"Number of patients in validation set (Fold 0): {len(valid_patients)}")

assert train_patients.isdisjoint(valid_patients), "Patient leakage detected!"
print("\nVerification successful: No patient leakage between training and validation sets for Fold 0.")
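The cell above checks only Fold 0. A hypothetical helper like the one below could check every fold at once, flagging patient leakage within a fold and any patient that appears in the validation set of more than one fold. It is shown on toy data; the commented usage assumes the `train_fold_{k}.csv` / `valid_fold_{k}.csv` naming seen for Fold 0 also holds for the other folds.

```python
import pandas as pd


def check_patient_folds(folds, patient_id) -> list:
    """List problems across all folds: patient leakage, or a patient validated in two folds."""
    problems = []
    first_valid_fold = {}
    for k, (train, valid) in enumerate(folds):
        train_patients = set(train["VolumeName"].map(patient_id))
        valid_patients = set(valid["VolumeName"].map(patient_id))
        leaked = train_patients & valid_patients
        if leaked:
            problems.append(f"fold {k}: {len(leaked)} patient(s) in both train and valid")
        for p in sorted(valid_patients):
            if p in first_valid_fold:
                problems.append(f"patient {p}: valid in folds {first_valid_fold[p]} and {k}")
            else:
                first_valid_fold[p] = k
    return problems


def toy_patient_id(name: str) -> str:
    # Hypothetical stand-in for get_patient_id: 'train_123_a_1' -> '123'.
    return name.split("_")[1]


def vols(*names):
    return pd.DataFrame({"VolumeName": list(names)})


good_folds = [
    (vols("train_1_a_1", "train_2_a_1"), vols("train_3_a_1")),
    (vols("train_2_a_1", "train_3_a_1"), vols("train_1_a_1", "train_1_a_2")),
]
bad_folds = [(vols("train_1_a_1"), vols("train_1_a_2"))]

print(check_patient_folds(good_folds, toy_patient_id))  # []
print(check_patient_folds(bad_folds, toy_patient_id))   # ['fold 0: 1 patient(s) in both train and valid']

# In the notebook (assumed file naming for all folds):
# folds = [(pd.read_csv(output_dir / f"train_fold_{k}.csv"),
#           pd.read_csv(output_dir / f"valid_fold_{k}.csv")) for k in range(N_SPLITS)]
# assert not check_patient_folds(folds, get_patient_id)
```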