In [None]:
import os, json, random
from collections import Counter

import numpy as np
import cv2
import pydicom
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast

# --------------------------- SEED ---------------------------
# Seed every RNG in use so any stochastic step is repeatable between runs.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# ---------------------- DEVICE ------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
    # Throughput-oriented settings: cuDNN autotuning and TF32 math trade some
    # floating-point precision/determinism for speed.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# ------------------------- PARAMS ---------------------------
NUM_SLICES = 18   # depth every volume is cropped/padded to
IMG_SIZE   = 256  # in-plane size (pixels) every slice is resized to

LEVELS = ["L1-L2","L2-L3","L3-L4","L4-L5","L5-S1"]
SIDES  = ["left","right"]
# 10 outputs, ordered L1-L2_left, L1-L2_right, L2-L3_left, ... (even = left, odd = right)
TARGET_KEYS = [f"{l}_{s}" for l in LEVELS for s in SIDES]

BATCH_SIZE  = 8
NUM_WORKERS = 0  # Windows safe

# ------------------------- PATHS ----------------------------
# Experiment root. Set the LUMBAR_PROJECT_ROOT environment variable to run on
# another machine instead of editing the hardcoded default below.
PROJECT_ROOT = os.environ.get(
    "LUMBAR_PROJECT_ROOT",
    r"D:\Submitted Matrial (conference&journal)\Sagittal Data Artical\V0.47 Dataset analysis\dataspitting",
)
images_root = os.path.join(PROJECT_ROOT, "split_dataset", "images")
dicom_root  = os.path.join(images_root, "test")  # one DICOM sub-folder per test patient
label_root  = os.path.join(PROJECT_ROOT, "split_dataset", "label", "test")

model_path  = os.path.join(PROJECT_ROOT, "best_model_convnext3d_regression.pth")


# ============================================================
#                   PREPROCESSING
# ============================================================

def extract_foreground(img, threshold=10):
    """Crop a 2-D image to the bounding box of pixels strictly above ``threshold``.

    Args:
        img: 2-D array.
        threshold: pixels with value > threshold count as foreground.

    Returns:
        The cropped view of ``img``, or ``img`` itself (same object) when no
        pixel exceeds the threshold.
    """
    foreground = img > threshold
    if not foreground.any():
        return img
    rows, cols = np.nonzero(foreground)
    return img[rows.min():rows.max() + 1, cols.min():cols.max() + 1]

def extract_labels(json_path, num_slices=None, target_keys=None):
    """Read one patient's label JSON into per-output slice targets and a validity mask.

    For every key ``"<level>_<side>"`` in ``target_keys`` the JSON is looked up
    as ``data[level][side]``; its element at index 2 is a normalized z position
    in [0..1]. That z becomes a slice index ``round(z * num_slices)``, clamped
    to ``[0, num_slices - 1]``.

    Args:
        json_path: path to the label file (read as UTF-8).
        num_slices: volume depth; defaults to ``NUM_SLICES``.
        target_keys: output keys, in model output order; defaults to ``TARGET_KEYS``.

    Returns:
        ``(labels, mask)``: two float32 arrays of length ``len(target_keys)``.
        Missing entries (absent level/side, or a level stored as null) get
        label 0.0 and mask 0.0; present entries get mask 1.0.
    """
    if num_slices is None:
        num_slices = NUM_SLICES
    if target_keys is None:
        target_keys = TARGET_KEYS

    # JSON is UTF-8 by spec; without an explicit encoding Windows decodes with
    # the locale code page.
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    label, mask = [], []
    for key in target_keys:
        lvl, side = key.rsplit("_", 1)
        # `or {}` also covers a level stored as JSON null, which .get(lvl, {}) would not.
        coord = (data.get(lvl) or {}).get(side)
        if coord is None:
            label.append(0.0)
            mask.append(0.0)
            continue
        # round() (half-to-even) is kept unchanged so targets match the encoding
        # presumably used at training time — confirm before altering.
        z_index = int(round(coord[2] * num_slices))
        label.append(float(min(max(z_index, 0), num_slices - 1)))
        mask.append(1.0)

    return np.array(label, dtype=np.float32), np.array(mask, dtype=np.float32)

def _instance_sort_key(path):
    """Sort key: slices by DICOM InstanceNumber, then unreadable ones last by path."""
    try:
        return (0, int(pydicom.dcmread(path, stop_before_pixels=True).InstanceNumber), path)
    except Exception:
        # Tag the fallback so every key is a comparable tuple; mixing bare ints
        # with str paths makes sorting raise TypeError.
        return (1, 0, path)


def _fit_slice_count(files, n):
    """Centre-crop ``files`` to ``n`` entries, or pad by repeating the last entry."""
    if len(files) > n:
        start = (len(files) - n) // 2
        return files[start:start + n]
    return files + [files[-1]] * (n - len(files))


def _preprocess_slice(path):
    """Read one DICOM slice, crop to foreground, min-max scale to [0, 1], resize to IMG_SIZE."""
    img = pydicom.dcmread(path).pixel_array.astype(np.float32)
    img = extract_foreground(img)
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)  # epsilon guards constant slices
    return cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)


def load_and_preprocess_volume(dicom_dir):
    """Load one patient's DICOM series as a float32 ``(NUM_SLICES, IMG_SIZE, IMG_SIZE)`` volume.

    Slices are ordered by InstanceNumber (files whose header cannot be read go
    last, ordered by path), centre-cropped or padded (repeating the last slice)
    to exactly ``NUM_SLICES``, then each slice is foreground-cropped, min-max
    scaled and resized.

    Raises:
        RuntimeError: if the folder holds no ``.dcm`` files or the stacked
            volume has an unexpected shape.
    """
    files = [
        os.path.join(dicom_dir, f)
        for f in os.listdir(dicom_dir)
        # case-insensitive so series exported with an upper-case ".DCM" suffix are found
        if f.lower().endswith(".dcm")
    ]
    if not files:
        raise RuntimeError(f"No DICOM files in {dicom_dir}")

    files = _fit_slice_count(sorted(files, key=_instance_sort_key), NUM_SLICES)

    vol = np.stack([_preprocess_slice(p) for p in files]).astype(np.float32)  # (D,H,W)
    if vol.shape != (NUM_SLICES, IMG_SIZE, IMG_SIZE):
        raise RuntimeError(f"Bad volume shape {vol.shape} in {dicom_dir}")

    return vol


# ============================================================
#                   DATASET / LOADER
# ============================================================

class LumbarDataset(Dataset):
    """Map-style dataset yielding ``(volume, targets, mask)`` per patient id.

    Each id names a sub-folder of ``dicom_root`` (the DICOM series) and a
    ``<id>.json`` file in ``label_root`` (the labels). ``volume`` is the
    preprocessed ``(D, H, W)`` array with a leading channel axis added,
    ``(1, D, H, W)``; ``targets`` and ``mask`` come from ``extract_labels``.
    All three are float32 tensors.
    """

    def __init__(self, ids):
        self.ids = ids

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        patient = self.ids[idx]
        volume = load_and_preprocess_volume(os.path.join(dicom_root, patient))
        targets, valid = extract_labels(os.path.join(label_root, f"{patient}.json"))
        return (
            torch.from_numpy(volume).unsqueeze(0).float(),
            torch.from_numpy(targets).float(),
            torch.from_numpy(valid).float(),
        )


# Collect the test patients that have both a DICOM folder and a JSON label file.
# os.listdir order is filesystem-dependent; sorting makes the evaluation order
# (and so the row order of the stacked predictions) the same on every machine.
patient_ids = sorted(
    pid
    for pid in os.listdir(dicom_root)
    if os.path.isdir(os.path.join(dicom_root, pid))
    and os.path.exists(os.path.join(label_root, f"{pid}.json"))
)

print("Total test samples:", len(patient_ids))
if not patient_ids:
    # Fail here with a clear message rather than later inside torch.cat on an empty list.
    raise RuntimeError(f"No labelled test patients found under {dicom_root} / {label_root}")

test_loader = DataLoader(
    LumbarDataset(patient_ids),
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    # Pinned host memory only speeds up host->GPU copies; skip it on CPU-only runs.
    pin_memory=(device.type == "cuda"),
)


# ============================================================
#                   MODEL
# ============================================================

class Block(nn.Module):
    """Residual ConvNeXt-style 3-D block; output has the same shape as the input.

    Residual branch: depthwise (3, 7, 7) conv -> GroupNorm with one group ->
    1x1x1 expansion to 4*c channels -> GELU -> 1x1x1 projection back to c.
    The parameter attribute names (dw, gn, pw1, pw2) are the keys that
    ``load_state_dict`` matches against the checkpoint, so they must not change.
    """

    def __init__(self, c):
        super().__init__()
        hidden = 4 * c
        # one (3,7,7) filter per channel; the padding keeps D/H/W unchanged
        self.dw = nn.Conv3d(c, c, kernel_size=(3, 7, 7), padding=(1, 3, 3), groups=c, bias=False)
        self.gn = nn.GroupNorm(1, c)
        self.pw1 = nn.Conv3d(c, hidden, kernel_size=1, bias=False)
        self.act = nn.GELU()
        self.pw2 = nn.Conv3d(hidden, c, kernel_size=1, bias=False)

    def forward(self, x):
        branch = self.gn(self.dw(x))
        branch = self.pw2(self.act(self.pw1(branch)))
        return x + branch


class ConvNext3D(nn.Module):
    """Small 3-D ConvNeXt-style regressor mapping ``(N, 1, D, H, W)`` to ``(N, n_outputs)``.

    A patchify stem (kernel = stride = (1, 4, 4)) feeds four residual stages at
    64, 128, 256 and 512 channels. Stages 2-4 start with a kernel-2/stride-2
    conv that downsamples D, H and W by 2. Global average pooling then feeds a
    single linear head. Attribute names are the checkpoint's state_dict keys.
    """

    def __init__(self, n_outputs=10):
        super().__init__()
        self.stem = nn.Conv3d(1, 64, kernel_size=(1, 4, 4), stride=(1, 4, 4), bias=False)
        self.b1 = Block(64)
        self.d1 = nn.Conv3d(64, 128, kernel_size=2, stride=2, bias=False)
        self.b2 = Block(128)
        self.d2 = nn.Conv3d(128, 256, kernel_size=2, stride=2, bias=False)
        self.b3 = Block(256)
        self.d3 = nn.Conv3d(256, 512, kernel_size=2, stride=2, bias=False)
        self.b4 = Block(512)
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(512, n_outputs)

    def forward(self, x):
        # (downsampling layer, residual block) per stage; the stem acts as stage 1's downsampler
        stages = (
            (self.stem, self.b1),
            (self.d1, self.b2),
            (self.d2, self.b3),
            (self.d3, self.b4),
        )
        for down, stage_block in stages:
            x = stage_block(down(x))
        pooled = self.pool(x).flatten(1)
        return self.fc(pooled)





In [None]:

# Build the network (one output per TARGET_KEYS entry) and load the trained weights.
# weights_only=True restricts unpickling to tensors/plain containers, so a tampered
# checkpoint cannot execute code; it is also the default from torch 2.6 onwards.
model = ConvNext3D(n_outputs=len(TARGET_KEYS)).to(device).to(memory_format=torch.channels_last_3d)
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

In [None]:


# Run the model over the whole test set (loader order) and gather predictions,
# targets and masks as CPU tensors for scoring.
all_preds, all_targets, all_masks = [], [], []

use_amp = device.type == "cuda"
with torch.no_grad():
    for vol, target, mask in test_loader:
        vol = vol.to(device, non_blocking=True).to(memory_format=torch.channels_last_3d)

        # torch.autocast supersedes the deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type=device.type, enabled=use_amp):
            pred = model(vol)

        # Under CUDA autocast the linear head emits float16; store float32 so later
        # rounding/comparisons are not done in half precision.
        all_preds.append(pred.float().cpu())
        # Targets and masks are only used on the CPU, so they are not shipped to
        # the GPU and back.
        all_targets.append(target)
        all_masks.append(mask)

all_preds   = torch.cat(all_preds, dim=0)
all_targets = torch.cat(all_targets, dim=0)
all_masks   = torch.cat(all_masks, dim=0)


In [None]:
# ---------------------- Prediction Analysis ----------------------
# Per-side "most common slice" metric: for each patient, take the most frequent
# predicted slice index over the five levels of each side, and count a labelled
# level as correct when its target slice is within ±1 of that side's mode.
# NOTE(review): an earlier note here read "medail point is 6" (medial/median
# slice?); nothing below uses it — confirm what it referred to.

# Round, don't truncate: .int() alone truncates toward zero (8.9 -> 8); rounding
# matches the round() used to build the integer labels in extract_labels.
Y_pr = all_preds.round().int()
true_prediction = 0
false_prediction = 0
for pred_row, target_row, mask_row in zip(Y_pr.tolist(), all_targets.tolist(), all_masks.tolist()):
    # Output order follows TARGET_KEYS: even indices are left, odd are right.
    # Plain Python ints are required: torch tensors hash by identity, so a Counter
    # over 0-dim tensors sees every element as distinct and its "most common"
    # would just be the first level's prediction.
    most_common_left = Counter(pred_row[0::2]).most_common(1)[0][0]
    most_common_right = Counter(pred_row[1::2]).most_common(1)[0][0]

    for idx, (target, valid) in enumerate(zip(target_row, mask_row)):
        # NOTE(review): a target of 0 is skipped even when the mask marks it as
        # present, so genuine slice-0 annotations are excluded — confirm intended.
        if valid != 1 or target == 0:
            continue
        mode = most_common_left if idx % 2 == 0 else most_common_right
        if abs(target - mode) <= 1:
            true_prediction += 1
        else:
            false_prediction += 1

total_predictions = true_prediction + false_prediction
if total_predictions > 0:
    accuracy_percent = true_prediction / total_predictions * 100
    print(f"Accuracy of the model: {accuracy_percent:.2f}%")
else:
    print("No valid predictions for most common left/right accuracy.")


In [None]:


# ---------------------- Most Common Left/Right Accuracy (±1 tolerance) ----------------------
# For each patient, the most frequent predicted slice per side (over the five levels)
# is compared with each labelled level's target on that side; a hit is a target
# within ±1 of that mode. Predictions are rounded (plain .int() truncates toward
# zero) to match the round() used when the integer labels were built.
Y_pr = all_preds.round().int()
true_prediction = 0
false_prediction = 0

for i, x, z in zip(Y_pr, all_targets, all_masks):
    # .item() turns 0-dim tensors into ints: tensors hash by identity, so Counter
    # needs plain values to find a real mode. Even indices = left, odd = right.
    left = [i[j].item() for j in range(len(i)) if j % 2 == 0]
    right = [i[j].item() for j in range(len(i)) if j % 2 == 1]

    most_common_left = Counter(left).most_common(1)[0][0]
    most_common_right = Counter(right).most_common(1)[0][0]

    for idx in range(len(x)):
        # Skip unlabelled levels and (as before) targets equal to 0.
        if z[idx] != 1 or x[idx] == 0:
            continue
        mode = most_common_left if idx % 2 == 0 else most_common_right
        if abs(x[idx].item() - mode) <= 1:
            true_prediction += 1
        else:
            false_prediction += 1

if (true_prediction + false_prediction) > 0:
    accuracy_percent = true_prediction / (true_prediction + false_prediction) * 100
    print(f"Most common left/right accuracy (±1): {accuracy_percent:.2f}%")
else:
    print("No valid predictions for most common left/right accuracy.")

In [None]:
# Summary of the ±1 "most common slice" metric from the previous cell.
# Each labelled level is only scored hit/miss against the side's mode — there are no
# separate positive/negative classes — so no distinct precision exists here; the same
# ratio printed under a "Precision" label would misreport the metric.
total_predictions = true_prediction + false_prediction
if total_predictions > 0:
    accuracy_percent = true_prediction / total_predictions * 100
    print(f"Most common left/right accuracy (±1): {accuracy_percent:.2f}%")
    print(f"Correct levels: {true_prediction}/{total_predictions}")
else:
    print("No valid predictions for most common left/right accuracy.")