<center><h1>CTA Prediction Tool</h1></center>

In [1]:
import os
import shutil
import zipfile
import numpy as np
import torch
import nibabel as nib
import cv2
import pydicom
import pathlib
import pandas as pd
import torch.nn as nn
import torchvision.models as models
import ipywidgets as widgets
from ipyfilechooser import FileChooser
from IPython.display import display, HTML, clear_output
from tqdm.notebook import tqdm

# --- Global run-state flags shared by the UI callbacks and the batch loop ---
# stop_flag['terminate'] is set by the Stop button and polled between scans.
stop_flag = dict(terminate=False)
# last_processed['scan'] holds the most recent scan handed to the model.
last_processed = dict(scan=None)

# --- ResNet18 Model Definition ---
class SliceCNN(nn.Module):
    """Per-slice 2D encoder: a ResNet-18 trunk adapted to single-channel input.

    ``forward`` maps a batch of images ``(N, 1, H, W)`` to ``(N, feature_dim)``.
    """

    def __init__(self, feature_dim=128):
        super().__init__()
        # NOTE(review): torchvision >= 0.13 deprecates `pretrained` in favour of
        # `weights=None`; kept as-is for compatibility with older installs.
        resnet = models.resnet18(pretrained=False)
        # Swap the RGB stem for a 1-channel one so grayscale slices can be fed in.
        resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        # Drop only the final ImageNet classifier; the trunk ends in global pooling.
        trunk = list(resnet.children())[:-1]
        self.backbone = nn.Sequential(*trunk)
        self.fc = nn.Linear(512, feature_dim)

    def forward(self, x):
        pooled = self.backbone(x)
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

class CTAQuality2D(nn.Module):
    """Scan-level quality model: encode each slice, average, then score.

    Input ``x`` is ``(B, S, C, H, W)``: B scans with S slices each.
    ``forward`` returns a ``(B,)`` tensor that has ALREADY been passed through
    sigmoid, i.e. probabilities in [0, 1]. Callers must not apply sigmoid again.
    """

    def __init__(self, feature_dim=128):
        super().__init__()
        self.slice_cnn = SliceCNN(feature_dim)
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        batch, n_slices, channels, height, width = x.shape
        # Fold slices into the batch axis so the 2D encoder sees one image per row.
        per_slice = self.slice_cnn(x.view(batch * n_slices, channels, height, width))
        # Mean-pool slice embeddings into a single vector per scan.
        scan_feats = per_slice.view(batch, n_slices, -1).mean(dim=1)
        logits = self.classifier(scan_feats).squeeze(1)
        return torch.sigmoid(logits)

# --- Preprocessing ---
def load_scan(file_path):
    """Load a volume from a NIfTI file or from a directory of ``.dcm`` slices.

    Args:
        file_path: path to a ``.nii`` / ``.nii.gz`` file, or a directory.

    Returns:
        numpy array. For DICOM, each slice's ``pixel_array`` is stacked along the
        last axis in ascending ``ImagePositionPatient[2]`` order.

    Raises:
        FileNotFoundError: the directory holds no ``.dcm`` files.
        ValueError: the path is neither a NIfTI file name nor a directory.
    """
    path_str = str(pathlib.Path(file_path))

    if path_str.endswith(('.nii', '.nii.gz')):
        return nib.load(path_str).get_fdata()

    if not os.path.isdir(path_str):
        raise ValueError("Unsupported file type")

    dcm_names = [name for name in os.listdir(path_str) if name.endswith('.dcm')]
    if not dcm_names:
        raise FileNotFoundError(f"No DICOM files in {path_str}")

    datasets = sorted(
        (pydicom.dcmread(os.path.join(path_str, name)) for name in dcm_names),
        key=lambda ds: float(ds.ImagePositionPatient[2]),
    )
    return np.stack([ds.pixel_array for ds in datasets], axis=-1)

def normalize_volume(volume):
    """Min-max scale ``volume`` into [0, 1] as a new float32 array.

    A constant volume (zero intensity range) is returned as all zeros. Dividing
    by the zero range would otherwise turn the whole volume into NaN.
    The input array is never modified (``astype`` always copies).

    Raises:
        ValueError: ``volume`` is empty (from ``min``/``max``).
    """
    volume = volume.astype(np.float32)
    volume -= volume.min()
    value_range = volume.max()
    if value_range > 0:
        volume /= value_range
    return volume

def extract_12_slices(volume):
    """Pick 12 2D views from a 3D volume.

    Returns a list of:
      * 10 slices along axis 2 at evenly spaced indices (first and last included),
      * the middle slice along axis 1 (labelled "coronal"),
      * the middle slice along axis 0 (labelled "sagittal").
    NOTE(review): the coronal/sagittal labels assume axis 2 is the axial (z) axis.
    """
    depth = volume.shape[2]
    picks = np.linspace(0, depth - 1, 10, dtype=int)
    mid_axis1 = volume.shape[1] // 2
    mid_axis0 = volume.shape[0] // 2

    views = [volume[:, :, z] for z in picks]
    views.append(volume[:, mid_axis1, :])  # "coronal"
    views.append(volume[mid_axis0, :, :])  # "sagittal"
    return views

def preprocess_scan(path, target_size=(224, 224)):
    """Load, normalise and resample one scan into a stack of 12 2D views.

    Args:
        path: NIfTI file or DICOM directory (see ``load_scan``).
        target_size: ``(width, height)`` passed to ``cv2.resize``.

    Returns:
        numpy array of shape ``(12, height, width)``.
    """
    normalized = normalize_volume(load_scan(path))
    views = extract_12_slices(normalized)
    # INTER_AREA is the usual choice when shrinking images.
    return np.stack(
        [cv2.resize(view, target_size, interpolation=cv2.INTER_AREA) for view in views],
        axis=0,
    )

def predict_single(model, path, device, threshold=0.5):
    """Score one scan and return its quality probability.

    Args:
        model: network whose output is already a probability per scan, as
            ``CTAQuality2D.forward`` is (it applies sigmoid itself).
        path: NIfTI file or DICOM directory accepted by ``preprocess_scan``.
        device: torch device the model lives on.
        threshold: cut-off used only for the printed 0/1 label.

    Returns:
        The probability rounded to 4 decimals, or None when preprocessing failed
        (the reason is printed). Exceptions raised by the model propagate.
    """
    try:
        slices = preprocess_scan(path)
        print("Input stats:")
        print(f"  Min: {np.min(slices):.4f}, Max: {np.max(slices):.4f}, Std: {np.std(slices):.4f}")
    except Exception as e:
        print(f"Failed to preprocess {path}: {e}")
        return None  # callers must treat None as "no score"

    # (S, H, W) -> (1, S, 1, H, W): a batch of one scan, one channel per slice.
    batch = torch.tensor(slices).unsqueeze(1).unsqueeze(0).float().to(device)

    with torch.no_grad():
        # The model output is already sigmoid(logit). Applying sigmoid a second
        # time would squeeze every score into [0.5, 0.731], labelling every scan
        # 1 at threshold 0.5.
        prob = model(batch).item()
    pred = 1 if prob > threshold else 0

    print(f"Scan: {path}")
    print(f"  Sigmoid Probability: {prob:.4f}")
    print(f"  Predicted Label: {pred} (Threshold: {threshold})")
    return round(prob, 4)

def colorize_score(val):
    """Return the CSS background for one quality-score cell.

    Missing (None/NaN) -> light gray; >= 0.7 -> light green;
    >= 0.4 -> khaki; anything lower -> light coral.
    """
    if pd.isnull(val):
        return 'background-color: lightgray'
    bands = (
        (0.7, 'background-color: lightgreen'),
        (0.4, 'background-color: khaki'),
    )
    for cutoff, css in bands:
        if val >= cutoff:
            return css
    return 'background-color: lightcoral'

def clear_done_cache():
    """Delete the ``.predictions`` marker directory in the working directory.

    Without the ``.done`` markers, the next batch run re-scores every scan.
    """
    cache_dir = ".predictions"
    if not os.path.exists(cache_dir):
        print("No .predictions cache to clear.")
        return
    shutil.rmtree(cache_dir)
    print("Cleared .predictions cache.")

# --- Batch prediction (synchronous; walks a folder recursively) ---
def _done_flag_path(scan_path, root_dir):
    """Return the ``.done`` marker path for ``scan_path``, keyed relative to ``root_dir``."""
    # NOTE(review): keys are relative to root_dir, so two different roots that
    # contain the same relative path share a marker.
    return f".predictions/{os.path.relpath(scan_path, root_dir)}.done"


def _collect_scan_targets(root_dir):
    """Walk ``root_dir`` and return ``(path, type)`` pairs for scans lacking a marker.

    Each NIfTI file is its own target. A directory containing any ``.dcm`` file
    is a single DICOM target.
    """
    targets = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        dirnames.sort()  # in-place sort makes the walk (and result order) deterministic
        for name in sorted(filenames):
            if name.endswith(('.nii', '.nii.gz')):
                nifti_path = os.path.join(dirpath, name)
                if not os.path.exists(_done_flag_path(nifti_path, root_dir)):
                    targets.append((nifti_path, 'NIfTI'))
        if any(name.endswith('.dcm') for name in filenames):
            if not os.path.exists(_done_flag_path(dirpath, root_dir)):
                targets.append((dirpath, 'DICOM'))
    return targets


def _write_done_flag(scan_path, root_dir):
    """Create the ``.done`` marker so later runs skip ``scan_path``."""
    done_flag = _done_flag_path(scan_path, root_dir)
    os.makedirs(os.path.dirname(done_flag), exist_ok=True)
    with open(done_flag, 'w') as f:
        f.write("done")


def predict_batch_recursive(model_path, root_dir):
    """Score every not-yet-processed scan under ``root_dir``.

    Args:
        model_path: file holding a ``CTAQuality2D`` state_dict (read with torch.load).
        root_dir: folder searched recursively for NIfTI files and DICOM folders.

    Returns:
        DataFrame with columns ``scan`` (relative to root_dir), ``type``
        ('NIfTI' / 'DICOM') and ``quality_score``. Failed scans have
        ``quality_score`` None plus an ``error`` message.

    Side effects:
        Writes ``.predictions/<relative path>.done`` only for successfully
        scored scans, so failures are retried next run. Overwrites
        ``cta_prediction_log.txt`` in the working directory. Updates the
        ``last_processed`` global.
    """
    results = []
    log_lines = []
    last_processed['scan'] = None  # don't report a scan left over from a previous run

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CTAQuality2D()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    scan_targets = _collect_scan_targets(root_dir)
    os.makedirs(".predictions", exist_ok=True)

    for scan_path, scan_type in tqdm(scan_targets, desc="Processing Scans"):
        # NOTE(review): this runs inside a synchronous widget callback, so a Stop
        # click is probably only delivered after the loop finishes.
        if stop_flag['terminate']:
            tqdm.write("Termination requested. Stopping early.")
            break
        last_processed['scan'] = scan_path
        tqdm.write(f"Processing: {scan_path}")
        row = {"scan": os.path.relpath(scan_path, root_dir), "type": scan_type}

        try:
            score = predict_single(model, scan_path, device)
        except Exception as e:
            error = str(e)
        else:
            # predict_single returns None (after printing why) when preprocessing fails.
            error = None if score is not None else "preprocessing failed"

        if error is not None:
            results.append({**row, "quality_score": None, "error": error})
            log_lines.append(f"ERROR: {scan_path} -> {error}")
            continue  # no .done marker, so the scan is retried next run

        results.append({**row, "quality_score": score})
        log_lines.append(f"SUCCESS: {scan_path} -> {score:.2f}")
        try:
            _write_done_flag(scan_path, root_dir)
        except OSError as e:
            log_lines.append(f"WARNING: could not write .done marker for {scan_path} -> {e}")

    with open("cta_prediction_log.txt", "w", encoding="utf-8") as log:
        log.writelines(line + "\n" for line in log_lines)

    print(f"Finished. Last processed scan: {last_processed['scan']}")
    print(f"Total processed: {len(results)}")
    return pd.DataFrame(results)

In [2]:
# --- UI Block ---
def stop_prediction(_):
    """Stop-button handler: ask the batch loop to exit before its next scan."""
    stop_flag.update(terminate=True)
# run_prediction runs synchronously: the kernel stays busy until prediction finishes (so a Stop click may only be handled afterwards).
def _run_single_file(model_path, scan_path):
    """Load the model and print the quality score for one NIfTI file."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = CTAQuality2D()
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    score = predict_single(model, scan_path, device)
    if score is None:
        # predict_single has already printed why preprocessing failed.
        print(f"No quality score for {scan_path}")
    else:
        print(f"Predicted quality score for {scan_path} -> {score:.2f}")


def _style_scores(df):
    """Return a Styler that colours the quality_score column via colorize_score."""
    styler = df.style
    # Styler.map exists from pandas 2.1; older versions only have applymap.
    apply_elementwise = getattr(styler, 'map', None) or styler.applymap
    return apply_elementwise(colorize_score, subset=['quality_score'])


def _run_folder(model_path, scan_path):
    """Run batch prediction on a folder, show the table and Save/ZIP buttons."""
    df = predict_batch_recursive(model_path, scan_path)
    if 'quality_score' in df.columns:
        display(HTML(_style_scores(df).to_html()))
    else:
        print("No valid results to display.")
        display(df)

    save_btn = widgets.Button(description="Save CSV")
    zip_btn = widgets.Button(description="Download Results as ZIP")

    # Button callbacks fire after the enclosing `with output:` has exited; they
    # re-enter it so their messages land in the Output widget.
    def save_callback(_btn):
        with output:
            try:
                df.to_csv("cta_quality_results.csv", index=False)
                print("Saved to cta_quality_results.csv")
            except Exception as e:
                print("Failed to save CSV:", e)

    def zip_callback(_btn):
        with output:
            try:
                # Write this run's CSV first so the bundle never holds a stale
                # or missing results file.
                df.to_csv("cta_quality_results.csv", index=False)
                with zipfile.ZipFile("cta_results_bundle.zip", 'w') as zipf:
                    zipf.write("cta_quality_results.csv")
                    zipf.write("cta_prediction_log.txt")
                print("Created cta_results_bundle.zip")
            except Exception as e:
                print("Failed to create ZIP:", e)

    save_btn.on_click(save_callback)
    zip_btn.on_click(zip_callback)
    display(widgets.HBox([save_btn, zip_btn]))


def run_prediction(_):
    """Run-button handler: score one NIfTI file, or every scan under a folder.

    Reads the selections from ``model_chooser`` / ``scan_chooser`` and writes
    all output to the ``output`` widget.
    """
    stop_flag['terminate'] = False
    output.clear_output()
    with output:
        model_path = model_chooser.selected
        scan_path = scan_chooser.selected

        if not model_path or not scan_path:
            print("Please select both a model and a scan input.")
            return

        if os.path.isfile(scan_path) and scan_path.endswith(('.nii', '.nii.gz')):
            _run_single_file(model_path, scan_path)
        elif os.path.isdir(scan_path):
            _run_folder(model_path, scan_path)
        else:
            print("Unsupported input.")

# Widgets read by run_prediction at click time: model_chooser, scan_chooser, output.
model_chooser = FileChooser(os.getcwd(), title="Select Model (.pt)", filter_pattern="*.pt")
scan_chooser = FileChooser(os.getcwd(), title="Select scan file or folder")
output = widgets.Output()

# Each button is wired to its handler right where it is created.
run_btn = widgets.Button(description="Run Prediction")
run_btn.on_click(run_prediction)

stop_btn = widgets.Button(description="Stop")
stop_btn.on_click(stop_prediction)

clear_cache_btn = widgets.Button(description="Clear .done Cache", button_style='warning')
clear_cache_btn.on_click(lambda _btn: clear_done_cache())

display(
    widgets.VBox([
        model_chooser,
        scan_chooser,
        widgets.HBox([run_btn, stop_btn, clear_cache_btn]),
        output,
    ])
)


VBox(children=(FileChooser(path='C:\Users\HenryLi\Desktop\Python Projects\CTA Scan Binary Classifier', filenam…