In [None]:
!pip install pydicom matplotlib numpy SimpleITK

Collecting pydicom
  Downloading pydicom-3.0.1-py3-none-any.whl.metadata (9.4 kB)
Collecting SimpleITK
  Downloading SimpleITK-2.4.1-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.9 kB)
Downloading pydicom-3.0.1-py3-none-any.whl (2.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.4/2.4 MB[0m [31m20.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading SimpleITK-2.4.1-cp311-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.3/52.3 MB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: SimpleITK, pydicom
Successfully installed SimpleITK-2.4.1 pydicom-3.0.1


In [None]:
from google.colab import drive
# Mount Google Drive; the patient folders are read from '/content/drive/My Drive/'.
drive.mount('/content/drive')



Mounted at /content/drive


In [None]:
import os
import pydicom
import numpy as np
import matplotlib.pyplot as plt
import SimpleITK as sitk
import torch
import torch.nn.functional as F
!wget https://github.com/QIICR/dcmqi/releases/download/v1.2.5/dcmqi-1.2.5-linux.tar.gz
!tar -xvzf dcmqi-1.2.5-linux.tar.gz
DCMQI_BIN = "/content/dcmqi-1.2.5-linux/bin/"

--2025-03-08 06:19:30--  https://github.com/QIICR/dcmqi/releases/download/v1.2.5/dcmqi-1.2.5-linux.tar.gz
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/50675718/79d3ad95-9f0c-42a4-a1c5-bf5a63461894?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250308%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250308T061930Z&X-Amz-Expires=300&X-Amz-Signature=1bface6dea7b561dbd7b482560059751534554cc0761b6af665643525454e424&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3Ddcmqi-1.2.5-linux.tar.gz&response-content-type=application%2Foctet-stream [following]
--2025-03-08 06:19:30--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/50675718/79d3ad95-9f0c-42a4-a1c5-bf5a63461894?X-Amz-Algorithm=AWS4-HMAC-SHA256&

In [None]:
import os
import subprocess
import SimpleITK as sitk
import numpy as np
import torch
import torch.nn.functional as F

def pad_to_depth(x, target_depth):
    """
    Zero-pad a (C, D, H, W) tensor at the end of its depth axis up to `target_depth`.

    The padding has the same dtype and device as `x`. When `x` is already at
    least `target_depth` deep, `x` itself (not a copy) is returned.
    """
    shortfall = target_depth - x.shape[1]
    if shortfall <= 0:
        return x
    filler = x.new_zeros((x.shape[0], shortfall, x.shape[2], x.shape[3]))
    return torch.cat([x, filler], dim=1)

def custom_collate_fn(batch):
    """
    Collate variable-sized CT/segmentation samples into one batch by zero-padding.

    Each sample is a dict with keys 'ct' and 'seg' (each of shape (1, D, H, W)).
    Every spatial axis (D, H and W) is padded at its end to the largest extent
    found in the batch, across both CT and segmentation tensors, so volumes and
    masks stay voxel-aligned and `torch.stack` never sees mismatched shapes.
    The input dicts are not modified.

    :param batch: non-empty list of sample dicts.
    :return: {'ct': (B, 1, D_max, H_max, W_max), 'seg': (B, 1, D_max, H_max, W_max)}.
    :raises ValueError: if `batch` is empty.
    """
    if not batch:
        raise ValueError("custom_collate_fn received an empty batch")

    tensors = [t for sample in batch for t in (sample['ct'], sample['seg'])]
    # target[i] is the max size of spatial axis i+1 (D, H, W) over the whole batch.
    target = [max(t.shape[axis] for t in tensors) for axis in (1, 2, 3)]

    def pad_end(x):
        # F.pad takes (before, after) pairs starting from the LAST axis: W, H, then D.
        pads = []
        for axis in (3, 2, 1):
            pads += [0, target[axis - 1] - x.shape[axis]]
        return torch.nn.functional.pad(x, pads) if any(pads) else x

    return {
        'ct': torch.stack([pad_end(sample['ct']) for sample in batch]),
        'seg': torch.stack([pad_end(sample['seg']) for sample in batch]),
    }

# -----------------------
# Conversion Function
# -----------------------
_DCMQI_DIR = "dcmqi-1.2.5-linux"
_DCMQI_URL = "https://github.com/QIICR/dcmqi/releases/download/v1.2.5/dcmqi-1.2.5-linux.tar.gz"


def _ensure_dcmqi():
    """Download/unpack the dcmqi 1.2.5 release into the CWD if missing; return its bin dir or None."""
    if not os.path.isdir(_DCMQI_DIR):
        archive = os.path.basename(_DCMQI_URL)
        try:
            if not os.path.exists(archive):
                subprocess.run(["wget", "--quiet", _DCMQI_URL], check=True)
            subprocess.run(["tar", "-xzf", archive], check=True)
        except (OSError, subprocess.CalledProcessError) as exc:
            print(f"Could not download/unpack dcmqi: {exc}")
            return None
    return os.path.join(os.getcwd(), _DCMQI_DIR, "bin")


def convert_seg_dicom_to_mha(im1_path, im3_path, output_dir):
    """
    Convert a DICOM segmentation in im3_path into .mha format using dcmqi.

    The segmentation file is located first, so folders without one never trigger
    the dcmqi download. The converter binaries come from the dcmqi 1.2.5 release
    tarball (unpacked into the current working directory), so no pip install is
    performed. Output files are written with the prefix "segmentation" (e.g.
    'segmentation-1.mha') into output_dir; the lowest-numbered one is returned.

    :param im1_path: Path to the folder containing CT DICOM slices (e.g., 'im_1').
        Not used by the conversion; kept for call compatibility.
    :param im3_path: Path to the folder containing segmentation DICOM files (e.g., 'im_3')
    :param output_dir: Patient folder where the .mha file should be saved.
    :return: Path to the created .mha file, or None if conversion fails.
    """
    if not os.path.isdir(im3_path):
        print(f"No segmentation DICOM file found in {im3_path}.")
        return None

    # Sorted so the same file is picked on every run (os.listdir order is arbitrary).
    seg_dcm_files = sorted(f for f in os.listdir(im3_path) if f.lower().endswith(".dcm"))
    if not seg_dcm_files:
        print(f"No segmentation DICOM file found in {im3_path}.")
        return None
    seg_dicom_path = os.path.join(im3_path, seg_dcm_files[0])

    dcmqi_bin = _ensure_dcmqi()
    if dcmqi_bin is None:
        return None

    # Ensure the output directory exists (output_dir is the patient folder)
    os.makedirs(output_dir, exist_ok=True)

    convert_cmd = [
        os.path.join(dcmqi_bin, "segimage2itkimage"),
        "--inputDICOM", seg_dicom_path,
        "--outputDirectory", output_dir,
        "-t", "mha",
        "-p", "segmentation",
    ]

    print("Running dcmqi conversion command:\n", " ".join(convert_cmd))
    try:
        result = subprocess.run(convert_cmd, capture_output=True, text=True)
    except OSError as exc:  # e.g. binary missing or not executable
        print("Conversion failed. Error:\n", exc)
        return None
    if result.returncode != 0:
        print("Conversion failed. Error:\n", result.stderr)
        return None

    # (len, name) ordering puts 'segmentation-2.mha' before 'segmentation-10.mha'.
    mha_files = sorted(
        (f for f in os.listdir(output_dir) if f.startswith("segmentation") and f.endswith(".mha")),
        key=lambda name: (len(name), name),
    )
    if not mha_files:
        print(f"No .mha file was created in {output_dir}.")
        return None

    mha_path = os.path.join(output_dir, mha_files[0])
    print(f"Created .mha file: {mha_path}")
    return mha_path

# -----------------------
# Data Processing Functions
# -----------------------
def get_patient_folders(root_path):
    """
    Return the patient folders directly under `root_path`, sorted by name.

    A patient folder is any sub-directory that itself contains an 'im_1'
    sub-directory. Sorting makes the result deterministic: os.listdir order is
    arbitrary, and callers that keep only the first N patients would otherwise
    pick a different subset on each run.

    :param root_path: directory to scan (non-recursive).
    :return: list of paths, each `os.path.join(root_path, <name>)`.
    :raises FileNotFoundError: if `root_path` does not exist.
    """
    return [
        os.path.join(root_path, name)
        for name in sorted(os.listdir(root_path))
        if os.path.isdir(os.path.join(root_path, name, "im_1"))
    ]

import threading

def get_dicom_series(directory, timeout=60):
    """
    Load the first DICOM series found in `directory`, giving up after `timeout` seconds.

    The read runs in a background thread. Python threads cannot be cancelled, so
    after a timeout the read keeps running; the thread is a daemon so it never
    keeps the interpreter alive on shutdown. Only the timed read is bounded;
    series discovery (GetGDCMSeriesIDs) runs synchronously.

    :param directory: folder containing the DICOM slices.
    :param timeout: seconds to wait for reader.Execute() to finish.
    :return: the loaded sitk.Image, or None if no series was found, the read
        raised, or the timeout expired.
    """
    reader = sitk.ImageSeriesReader()
    series_ids = reader.GetGDCMSeriesIDs(directory)
    if not series_ids:
        print(f"No DICOM series found in {directory}.")
        return None
    if len(series_ids) > 1:
        print(f"Found {len(series_ids)} DICOM series in {directory}; using the first one.")

    reader.SetFileNames(reader.GetGDCMSeriesFileNames(directory, series_ids[0]))

    outcome = {}

    def _load():
        try:
            outcome["image"] = reader.Execute()
        except Exception as exc:  # report and let the caller skip this patient
            print(f"Error loading DICOM series from {directory}: {exc}")

    worker = threading.Thread(target=_load, daemon=True)
    worker.start()
    worker.join(timeout)

    if worker.is_alive():
        print(f"Skipping {directory} due to timeout (took longer than {timeout} seconds).")
        return None
    return outcome.get("image")

def _find_segmentation_file(patient, ct_path, seg_path):
    """Return a .mha segmentation path for `patient` (existing or freshly converted), or None."""
    seg_files = sorted(f for f in os.listdir(seg_path) if f.endswith(".mha"))
    if seg_files:
        return os.path.join(seg_path, seg_files[0])

    # convert_seg_dicom_to_mha writes into the patient folder (not im_3), so reuse
    # the output of an earlier run instead of invoking dcmqi again.
    converted = sorted(
        (f for f in os.listdir(patient) if f.startswith("segmentation") and f.endswith(".mha")),
        key=lambda name: (len(name), name),
    )
    if converted:
        path = os.path.join(patient, converted[0])
        print(f"Using previously converted segmentation: {path}")
        return path

    print(f"No .mha segmentation file found in im_3 for {patient}. Attempting conversion...")
    return convert_seg_dicom_to_mha(ct_path, seg_path, patient)


def _same_grid(image_a, image_b, tol=1e-4):
    """True if two SimpleITK images share size, spacing, origin and direction."""
    return (
        image_a.GetSize() == image_b.GetSize()
        and np.allclose(image_a.GetSpacing(), image_b.GetSpacing(), atol=tol)
        and np.allclose(image_a.GetOrigin(), image_b.GetOrigin(), atol=tol)
        and np.allclose(image_a.GetDirection(), image_b.GetDirection(), atol=tol)
    )


def _resize_volume(volume, size, mode):
    """Resize a 3D array to `size` with F.interpolate ('trilinear' or 'nearest'); returns float32."""
    tensor = torch.from_numpy(np.ascontiguousarray(volume, dtype=np.float32))[None, None]
    extra = {"align_corners": False} if mode == "trilinear" else {}
    # Index [0, 0] (not .squeeze()) so target sizes of 1 are preserved.
    return F.interpolate(tensor, size=tuple(size), mode=mode, **extra)[0, 0].numpy()


def process_files(base_path="/content/drive/My Drive/", max_files=25, target_shape=(128, 128, 72)):
    """
    Load, align and downsample CT / segmentation pairs for up to `max_files` patients.

    For each patient folder under `base_path` (see get_patient_folders):
      1. the segmentation (.mha in im_3, an earlier dcmqi output in the patient
         folder, or a fresh dcmqi conversion of the DICOM SEG) is located,
      2. the CT series is loaded from im_1,
      3. the segmentation is resampled onto the CT voxel grid in physical space,
      4. both volumes are resized to `target_shape` (trilinear for CT, nearest
         neighbour for the labels).

    :param base_path: folder whose sub-folders are patient folders.
    :param max_files: stop after this many patients were processed successfully.
    :param target_shape: output size in the axis order of sitk.GetArrayFromImage
        (D, H, W). NOTE(review): the default gives 128 slices of 128 x 72 pixels,
        i.e. it shrinks the in-plane width far more than the height — confirm
        that (72, 128, 128) was not intended.
    :return: list of dicts with keys 'filename', 'ct_volume', 'segmentation',
        'spacing' and 'origin'.
    """
    processed_data = []
    patients = get_patient_folders(base_path)

    if not patients:
        print("No patient folders found under My Drive.")
        return processed_data

    for patient in patients:
        if len(processed_data) >= max_files:
            print(f"\nReached the limit of {max_files} processed files. Stopping.")
            break

        print(f"\nProcessing patient folder: {patient}")
        ct_path = os.path.join(patient, "im_1")
        seg_path = os.path.join(patient, "im_3")

        # Cheap check first, so a full CT series is never loaded for nothing.
        if not os.path.exists(seg_path):
            print(f"Skipping {patient}: No im_3 folder found.")
            continue

        ct_image = get_dicom_series(ct_path)
        if ct_image is None:
            print(f"Skipping {patient}: No DICOM series found in im_1.")
            continue
        print(f"Original CT Shape: {ct_image.GetSize()[::-1]}")

        seg_file_path = _find_segmentation_file(patient, ct_path, seg_path)
        if seg_file_path is None:
            print(f"Conversion failed for {patient}. Skipping.")
            continue

        seg_image = sitk.ReadImage(seg_file_path)
        print(f"Original Segmentation Shape: {seg_image.GetSize()[::-1]}")

        # The mask may cover only part of the CT (e.g. 246 vs 520 slices) and carries
        # its own origin/spacing. Stretching its index axis onto the CT would shift
        # the anatomy, so it is resampled onto the CT grid in physical coordinates.
        if not _same_grid(seg_image, ct_image):
            print("Resampling segmentation onto the CT grid (physical space)...")
            seg_image = sitk.Resample(
                seg_image, ct_image, sitk.Transform(),
                sitk.sitkNearestNeighbor, 0, seg_image.GetPixelID(),
            )
            print(f"Resampled Segmentation Shape: {seg_image.GetSize()[::-1]}")

        ct_volume_resized = _resize_volume(sitk.GetArrayFromImage(ct_image), target_shape, "trilinear")
        seg_volume_resized = _resize_volume(
            sitk.GetArrayFromImage(seg_image), target_shape, "nearest"
        ).astype(np.uint8)

        print(f"Downsampled CT Shape: {ct_volume_resized.shape}")
        print(f"Downsampled Segmentation Shape: {seg_volume_resized.shape}")

        processed_data.append({
            "filename": patient,
            "ct_volume": ct_volume_resized,
            "segmentation": seg_volume_resized,
            # Spacing/origin of the ORIGINAL CT grid (spacing reversed to array axis
            # order); they do not describe the resized arrays.
            "spacing": ct_image.GetSpacing()[::-1],
            "origin": ct_image.GetOrigin(),
        })

    print("\nProcessing complete!")
    print(f"Processed {len(processed_data)} datasets successfully.")
    return processed_data

# -----------------------
# Main Execution
# -----------------------
from google.colab import drive
# Harmless if Drive is already mounted: drive.mount then only prints a notice.
drive.mount('/content/drive')

# Run the processing pipeline over the patient folders in My Drive
# (process_files stops after 25 patients).
processed_data = process_files()


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

Processing patient folder: /content/drive/My Drive/0A44743795D421F7
Original CT Shape: (520, 512, 512)
No .mha segmentation file found in im_3 for /content/drive/My Drive/0A44743795D421F7. Attempting conversion...
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m40.4/40.4 MB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m
[?25hRunning dcmqi conversion command:
 /content/dcmqi-1.2.5-linux/bin/segimage2itkimage --inputDICOM /content/drive/My Drive/0A44743795D421F7/im_3/x0000.dcm --outputDirectory /content/drive/My Drive/0A44743795D421F7 -t mha -p segmentation
Created .mha file: /content/drive/My Drive/0A44743795D421F7/segmentation-1.mha
Original Segmentation Shape: (246, 512, 512)
Resampling segmentation to match CT dimensions...
Resampled Segmentation Shape: (520, 512, 512)
Downsampled CT Shape: (128, 128, 72)
Downsampled Segmentation Shape: (128, 1

In [None]:
!pip install monai

Collecting monai
  Downloading monai-1.4.0-py3-none-any.whl.metadata (11 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.9->monai)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.9->monai)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.9->monai)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.9->monai)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.9->monai)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch>=1.9->monai)
  Downloading nvidia_cufft_cu12-11.2.1

In [None]:
from monai.transforms import Compose, NormalizeIntensity, ToTensor


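**Multi-pass U-Net Model Training**

The 3D U-Net is run twice per volume. The first pass receives the CT plus a blank shape-context channel. The second pass receives the CT plus the thresholded first-pass prediction. The losses of both passes are summed.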

---



In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from monai.transforms import Compose

# Switch PyTorch's CUDA caching allocator to expandable segments.
# NOTE(review): the allocator reads this variable only when it is first initialised,
# so it has no effect if a CUDA tensor was already created earlier in this kernel.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# -------------------------
# Transforms for CT Scans
# -------------------------
class NormalizeCT:
    """Clip intensities to [-1000, 1000] (presumably Hounsfield units) and map them linearly onto [0, 1]."""

    _HU_MIN = -1000
    _HU_MAX = 1000

    def __call__(self, ct_volume):
        clipped = np.clip(ct_volume, self._HU_MIN, self._HU_MAX)
        return (clipped - self._HU_MIN) / (self._HU_MAX - self._HU_MIN)

class ToTensor:
    """
    Wrap a (D, H, W) NumPy volume as a tensor with a leading channel axis -> (1, D, H, W).

    The tensor shares memory with the input array (torch.from_numpy). Note that this
    class shadows monai.transforms.ToTensor imported in an earlier cell.
    """

    def __call__(self, ct_volume):
        return torch.from_numpy(ct_volume)[None]

# -------------------------
# Custom Dataset
# -------------------------
class CTDataset(Dataset):
    """
    Map-style dataset over the list of dicts produced by `process_files`.

    Item `idx` is {'ct': (1, D, H, W) float32 tensor after NormalizeCT,
    'seg': (1, D, H, W) float32 tensor}.
    """

    def __init__(self, processed_data):
        self.data = processed_data
        # CT pipeline: clip/rescale intensities, then add the channel axis.
        self.transform = Compose([NormalizeCT(), ToTensor()])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]
        ct_tensor = self.transform(record['ct_volume'])
        # The mask gets no intensity transform, only the channel axis and a float cast.
        seg_tensor = torch.from_numpy(record['segmentation'])[None].float()
        return {'ct': ct_tensor.float(), 'seg': seg_tensor}

# -------------------------
# Custom Collate Function (as defined above)
# -------------------------
def pad_to_depth(x, target_depth):
    """
    Append zero slices to a (C, D, H, W) tensor until its depth reaches `target_depth`.

    Same function as the one defined in the preprocessing cell; redefined here so
    this cell is self-contained. Returns `x` unchanged when no padding is needed.
    """
    missing = target_depth - x.shape[1]
    if missing <= 0:
        return x
    channels, _, height, width = x.shape
    zeros = torch.zeros((channels, missing, height, width), dtype=x.dtype, device=x.device)
    return torch.cat((x, zeros), dim=1)

def custom_collate_fn(batch):
    """
    Collate variable-sized CT/segmentation samples into one batch by zero-padding.

    Each sample is a dict with keys 'ct' and 'seg' (each of shape (1, D, H, W)).
    Every spatial axis (D, H and W) is padded at its end to the largest extent
    found in the batch, across both CT and segmentation tensors, so volumes and
    masks stay voxel-aligned and `torch.stack` never sees mismatched shapes.
    The input dicts are not modified.

    :param batch: non-empty list of sample dicts.
    :return: {'ct': (B, 1, D_max, H_max, W_max), 'seg': (B, 1, D_max, H_max, W_max)}.
    :raises ValueError: if `batch` is empty.
    """
    if not batch:
        raise ValueError("custom_collate_fn received an empty batch")

    tensors = [t for sample in batch for t in (sample['ct'], sample['seg'])]
    # target[i] is the max size of spatial axis i+1 (D, H, W) over the whole batch.
    target = [max(t.shape[axis] for t in tensors) for axis in (1, 2, 3)]

    def pad_end(x):
        # F.pad takes (before, after) pairs starting from the LAST axis: W, H, then D.
        pads = []
        for axis in (3, 2, 1):
            pads += [0, target[axis - 1] - x.shape[axis]]
        return torch.nn.functional.pad(x, pads) if any(pads) else x

    return {
        'ct': torch.stack([pad_end(sample['ct']) for sample in batch]),
        'seg': torch.stack([pad_end(sample['seg']) for sample in batch]),
    }

# -------------------------
# Multipass 3D U-Net Model
# -------------------------
class MultiPassUNet3D(nn.Module):
    """
    Two-level 3D U-Net returning raw logits.

    The default `in_channels=2` matches the multi-pass training loop, which feeds
    the CT volume plus a shape-context channel (blank on the first pass, the
    thresholded first-pass prediction on the second).

    Inputs of any spatial size are accepted. D, H and W are zero-padded up to a
    multiple of 4, because the two MaxPool3d(2) stages need it for the skip
    connections to line up. The logits are then cropped back to the input size.
    """

    _DOWNSAMPLE = 4  # 2 ** (number of MaxPool3d(2) stages)

    def __init__(self, in_channels=2, out_channels=1, base_channels=16):
        super().__init__()
        # Encoder
        self.enc1 = self._block(in_channels, base_channels)
        self.pool1 = nn.MaxPool3d(2)
        self.enc2 = self._block(base_channels, base_channels * 2)
        self.pool2 = nn.MaxPool3d(2)
        # Bridge
        self.bridge = self._block(base_channels * 2, base_channels * 4)
        # Decoder (input channels doubled by the skip concatenation)
        self.up1 = nn.ConvTranspose3d(base_channels * 4, base_channels * 2, 2, stride=2)
        self.dec1 = self._block(base_channels * 4, base_channels * 2)
        self.up2 = nn.ConvTranspose3d(base_channels * 2, base_channels, 2, stride=2)
        self.dec2 = self._block(base_channels * 2, base_channels)
        # Output: 1x1x1 conv to per-voxel logits
        self.out = nn.Conv3d(base_channels, out_channels, 1)

    def _block(self, in_channels, features):
        """Two (Conv3d 3x3x3 -> BatchNorm3d -> ReLU) layers; spatial size is preserved."""
        return nn.Sequential(
            nn.Conv3d(in_channels, features, 3, padding=1),
            nn.BatchNorm3d(features),
            nn.ReLU(inplace=True),
            nn.Conv3d(features, features, 3, padding=1),
            nn.BatchNorm3d(features),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """
        :param x: tensor of shape (B, in_channels, D, H, W).
        :return: logits of shape (B, out_channels, D, H, W).
        """
        depth, height, width = x.shape[-3:]
        # F.pad pairs run from the last axis backwards: W, H, then D.
        pad = []
        for size in (width, height, depth):
            pad += [0, (-size) % self._DOWNSAMPLE]
        if any(pad):
            x = F.pad(x, pad)

        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool1(enc1))
        bridge = self.bridge(self.pool2(enc2))
        dec1 = self.dec1(torch.cat((self.up1(bridge), enc2), dim=1))
        dec2 = self.dec2(torch.cat((self.up2(dec1), enc1), dim=1))
        logits = self.out(dec2)
        return logits[..., :depth, :height, :width]

# -------------------------
# Dice Score Metric
# -------------------------
def dice_score(pred, target, smooth=1e-6, threshold=0.5):
    """
    Mean Dice coefficient over a batch.

    :param pred: probabilities (e.g. sigmoid output), shape (B, ...).
    :param target: binary ground truth with the same shape as `pred`.
    :param smooth: added to numerator and denominator; two empty masks score 1.
    :param threshold: probability above which a voxel counts as foreground.
    :return: Python float, the Dice score averaged over the batch dimension.
    """
    with torch.no_grad():  # metric only; never part of the autograd graph
        pred_mask = (pred > threshold).float()
        target = target.float()
        # Reduce over every non-batch axis, so 4D and 5D inputs both work.
        dims = tuple(range(1, pred_mask.dim()))
        intersection = (pred_mask * target).sum(dim=dims)
        union = pred_mask.sum(dim=dims) + target.sum(dim=dims)
        dice = (2.0 * intersection + smooth) / (union + smooth)
    return dice.mean().item()

# -------------------------
# Training Setup
# -------------------------
SEED = 0
torch.manual_seed(SEED)  # reproducible weight initialisation

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultiPassUNet3D().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
# The model outputs logits; BCEWithLogitsLoss applies the sigmoid internally.
criterion = nn.BCEWithLogitsLoss()
# Reduce the learning rate when the epoch loss stops improving (patience=3 epochs).
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3)

dataset = CTDataset(processed_data)
if len(dataset) == 0:
    raise RuntimeError("processed_data is empty - run the preprocessing cell (process_files) first.")

# Custom collate pads volumes in a batch to a common size; the seeded generator
# makes the shuffle order reproducible.
dataloader = DataLoader(
    dataset,
    batch_size=1,
    shuffle=True,
    num_workers=2,
    collate_fn=custom_collate_fn,
    generator=torch.Generator().manual_seed(SEED),
)

# -------------------------
# Multipass Training Loop
# -------------------------
num_epochs = 50
best_loss = float('inf')
MODEL_NAME = "MultiPassUNet3D"
# Google Drive is mounted at /content/drive. Any other root (e.g. /content/gdrive)
# is a plain folder on the VM's temporary disk and is lost when the runtime resets.
SAVE_DIR = "/content/drive/My Drive/UNet Model"
os.makedirs(SAVE_DIR, exist_ok=True)

# Per-epoch history, one entry per completed epoch (used by the plotting cell).
Avg_dice_score = []
Avg_loss = []
best_model_path = None

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    dice_total = 0.0
    num_batches = 0

    for batch in tqdm(dataloader, desc=f"Epoch {epoch+1}"):
        ct = batch['ct'].to(device)   # (B, 1, D, H, W)
        seg = batch['seg'].to(device) # (B, 1, D, H, W)

        # First pass: blank shape context
        input_pass1 = torch.cat([ct, torch.zeros_like(ct)], dim=1)  # (B, 2, D, H, W)
        output_pass1 = model(input_pass1)
        loss1 = criterion(output_pass1, seg)

        # Second pass: thresholded pass-1 prediction as shape context. Thresholding
        # is not differentiable, so the context is built without gradient tracking.
        with torch.no_grad():
            context = (torch.sigmoid(output_pass1) > 0.5).float()
        output_pass2 = model(torch.cat([ct, context], dim=1))
        loss2 = criterion(output_pass2, seg)

        total_loss = loss1 + loss2

        optimizer.zero_grad()
        total_loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        epoch_loss += total_loss.item()
        dice_total += dice_score(torch.sigmoid(output_pass2.detach()), seg)
        num_batches += 1

    avg_loss = epoch_loss / num_batches
    avg_dice = dice_total / num_batches
    scheduler.step(avg_loss)

    print(f"Epoch {epoch+1} | Avg Loss: {avg_loss:.4f} | Avg Dice Score: {avg_dice:.4f}")
    Avg_dice_score.append(avg_dice)
    Avg_loss.append(avg_loss)

    if avg_loss < best_loss:
        best_loss = avg_loss
        best_model_path = os.path.join(SAVE_DIR, f"best_model_epoch{epoch+1}.pth")
        torch.save(model.state_dict(), best_model_path)

print("Training complete!")
print(f"Best model (loss {best_loss:.4f}) saved to: {best_model_path}")


Epoch 1:   0%|          | 0/25 [00:00<?, ?it/s]

In [None]:
import matplotlib.pyplot as plt
# Training curves: one point per completed epoch, taken from the Avg_dice_score /
# Avg_loss lists filled by the training loop. Epochs are numbered from 1.
for history, metric in ((Avg_dice_score, 'Dice Score'), (Avg_loss, 'Loss')):
    values = list(history)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(range(1, len(values) + 1), values, label=metric, marker='o')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric)
    ax.set_title(f'{metric} vs Epoch')
    ax.legend()
    ax.grid(True)
    plt.show()
# Inference
# Load the trained model. Point CHECKPOINT_PATH at a checkpoint written by the
# training loop (replace X with the epoch number).
CHECKPOINT_PATH = "best_model_epochX.pth"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MultiPassUNet3D().to(device)
# map_location lets a GPU-trained checkpoint load on a CPU-only runtime;
# weights_only=True refuses to unpickle anything other than tensors/containers.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True))
model.eval()

# Function to preprocess a test CT scan
def preprocess_ct_scan(ct_path, target_shape=(128, 128, 72)):
    """
    Load a CT volume and prepare it in the same order as the training data.

    Training volumes were resized from raw intensities (process_files) and only
    then clipped and rescaled (NormalizeCT inside CTDataset). Clipping does not
    commute with interpolation, so this function also resizes first and
    normalises second.

    :param ct_path: path to a volume readable by sitk.ReadImage (e.g. .mha).
    :param target_shape: output (D, H, W) size; must match the training shape.
    :return: float32 tensor of shape (1, *target_shape) with values in [0, 1].
    """
    ct_image = sitk.ReadImage(ct_path)
    ct_volume = sitk.GetArrayFromImage(ct_image).astype(np.float32)  # (D, H, W)

    ct_tensor = torch.from_numpy(ct_volume)[None, None]  # (1, 1, D, H, W)
    ct_resized = F.interpolate(ct_tensor, size=tuple(target_shape), mode='trilinear', align_corners=False)

    # Same clip/rescale as NormalizeCT: [-1000, 1000] -> [0, 1]
    ct_resized = ct_resized.clamp(-1000, 1000)
    return ((ct_resized + 1000) / 2000)[0]  # (1, D, H, W)

# Function to run inference
def infer(ct_tensor, net=None):
    """
    Two-pass inference mirroring the training loop.

    Pass 1 feeds the CT together with a blank shape-context channel. Pass 2 feeds
    the CT together with the thresholded (> 0.5) pass-1 prediction.

    :param ct_tensor: CT of shape (1, D, H, W), as returned by preprocess_ct_scan,
        or an already batched (B, 1, D, H, W) tensor.
    :param net: model to run; defaults to the global `model`.
    :return: NumPy array of pass-2 sigmoid probabilities; for MultiPassUNet3D the
        shape is (B, 1, D, H, W).
    """
    net = model if net is None else net
    first_param = next(net.parameters(), None)
    target_device = first_param.device if first_param is not None else ct_tensor.device

    # Move the CT to the model's device up front so it and the pass-1 context
    # (created on that device) can be concatenated.
    x = ct_tensor.to(target_device)
    if x.dim() == 4:  # add the batch axis, so channels are concatenated on dim 1
        x = x.unsqueeze(0)

    with torch.no_grad():
        output_pass1 = net(torch.cat([x, torch.zeros_like(x)], dim=1))
        context = (torch.sigmoid(output_pass1) > 0.5).float()
        output_pass2 = net(torch.cat([x, context], dim=1))

    return torch.sigmoid(output_pass2).cpu().numpy()

# Paths to test images (update these paths)
test1_ct_path = "path/to/test1.mha"
test2_ct_path = "path/to/test2.mha"
test1_seg_path = "path/to/test1_seg.mha"  # Ground truth segmentation
test2_seg_path = "path/to/test2_seg.mha"

# Preprocess test scans -> (1, D, H, W) tensors at the resized resolution
ct_test1 = preprocess_ct_scan(test1_ct_path)
ct_test2 = preprocess_ct_scan(test2_ct_path)

# Run inference -> NumPy arrays of sigmoid probabilities
pred_test1 = infer(ct_test1)
pred_test2 = infer(ct_test2)


def load_ground_truth(seg_path, shape):
    """
    Read a label volume and nearest-neighbour resize it to `shape` (D, H, W).

    The predictions live on the resized grid of preprocess_ct_scan, so the ground
    truth must be put on that same grid before slices can be compared.
    NOTE(review): assumes each *_seg.mha shares the voxel grid of its CT file.
    """
    labels = sitk.GetArrayFromImage(sitk.ReadImage(seg_path)).astype(np.float32)
    resized = F.interpolate(torch.from_numpy(labels)[None, None], size=tuple(shape), mode='nearest')
    return resized[0, 0].numpy().astype(np.uint8)


# Load ground truth segmentations on the same (D, H, W) grid as the CT tensors
gt_test1 = load_ground_truth(test1_seg_path, ct_test1.shape[1:])
gt_test2 = load_ground_truth(test2_seg_path, ct_test2.shape[1:])

# Plot function
def plot_results(ct_volume, ground_truth, prediction, slice_idx):
    """
    Show CT, ground truth and prediction side by side for one slice.

    Each volume is indexed along its first axis at `slice_idx` and drawn in
    grayscale in a 1 x 3 panel figure.
    """
    panels = (
        (ct_volume, "CT Slice"),
        (ground_truth, "Ground Truth Segmentation"),
        (prediction, "Predicted Segmentation"),
    )
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (volume, title) in zip(axes, panels):
        ax.imshow(volume[slice_idx], cmap='gray')
        ax.set_title(title)
    plt.show()

# Visualize the middle slice along the depth axis (index 1 of the (1, D, H, W) CT
# tensor); both test volumes share the same depth after preprocessing.
mid_slice = ct_test1.shape[1] // 2
for ct_vol, gt_vol, pred_vol in (
    (ct_test1, gt_test1, pred_test1),
    (ct_test2, gt_test2, pred_test2),
):
    plot_results(ct_vol.squeeze().numpy(), gt_vol, pred_vol.squeeze(), mid_slice)
