In [None]:


import os, json, random, traceback
import numpy as np
import cv2
import pydicom
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler


# --------------------------- SEED ---------------------------
# Seed every RNG the pipeline touches: Python, NumPy, PyTorch CPU, all CUDA devices.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn

# ---------------------- GPU SPEED SETTINGS ------------------
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
if device.type == "cuda":
    print("GPU:", torch.cuda.get_device_name(0))
    # Speed over strict reproducibility: cuDNN autotuning (benchmark=True) may pick
    # different kernels between runs, so results are not guaranteed bit-identical
    # even with the seeds above. TF32 trades a little matmul/conv precision for speed.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

# ------------------------- PARAMS ---------------------------
# Every volume is resampled to NUM_SLICES x IMG_SIZE x IMG_SIZE.
NUM_SLICES = 18
IMG_SIZE = 256

# One regression target per (level, side) pair, level-major:
# L1-L2_left, L1-L2_right, L2-L3_left, ... (10 outputs).
LEVELS = ["L1-L2", "L2-L3", "L3-L4", "L4-L5", "L5-S1"]
SIDES = ["left", "right"]
TARGET_KEYS = []
for _lvl in LEVELS:
    TARGET_KEYS.extend(f"{_lvl}_{_side}" for _side in SIDES)
del _lvl

# ------------------------- PATHS ----------------------------
# All data lives under one split directory (the folder names, typos included,
# are the real on-disk names and must not be "corrected").
_SPLIT_ROOT = r"D:\Submitted Matrial (conference&journal)\Sagittal Data Artical\V0.47 Dataset analysis\dataspitting\split_dataset"
DICOM_ROOT = _SPLIT_ROOT + r"\images\train_val"
LABEL_ROOT = _SPLIT_ROOT + r"\label\train_val"
CACHE_ROOT = _SPLIT_ROOT + r"\cache_npy"

os.makedirs(CACHE_ROOT, exist_ok=True)

# ---------------------- TRAINING CONFIG ---------------------
EPOCHS = 100
LR = 1e-4
BATCH_SIZE = 8

# Keep 0 while debugging (errors surface in the main process); raise to 4 or 8
# once the pipeline runs cleanly.
NUM_WORKERS = 0

# Relative to the current working directory.
BEST_MODEL_PATH = "best_model_convnext3d_regression.pth"


# ============================================================
#                   PREPROCESS + CACHE
# ============================================================

def extract_foreground(img, threshold=10):
    """Crop a 2D image to the bounding box of pixels strictly above ``threshold``.

    Args:
        img: 2D array (rows, cols).
        threshold: pixels with value > threshold count as foreground.

    Returns:
        The cropped view of ``img``; ``img`` itself (same object) when no pixel
        exceeds the threshold.
    """
    foreground = img > threshold
    if not foreground.any():
        return img
    rows, cols = np.nonzero(foreground)
    return img[rows.min():rows.max() + 1, cols.min():cols.max() + 1]

def extract_labels(json_path, num_slices=None, target_keys=None):
    """Read one patient's label JSON into per-target slice indices and a validity mask.

    The JSON is expected to map level -> side -> coordinate, e.g.
    ``{"L4-L5": {"left": [x, y, z]}}``. For each key "<level>_<side>" in
    ``target_keys`` the z component (index 2) is converted to a slice index
    ``round(z * num_slices)`` clamped to ``[0, num_slices - 1]``.

    NOTE(review): this assumes z is normalized to [0, 1] over the slice stack.
    ``load_and_preprocess_volume`` centre-crops/pads series to NUM_SLICES
    slices, which this mapping does not account for — confirm the convention.

    Args:
        json_path: path to the label JSON (read as UTF-8).
        num_slices: number of slices per volume; defaults to NUM_SLICES.
        target_keys: ordered "<level>_<side>" keys; defaults to TARGET_KEYS.

    Returns:
        (label, mask): float32 arrays of length ``len(target_keys)``. Missing
        targets (absent level/side, or a JSON null) get label 0.0 and mask 0.0;
        present targets get their slice index and mask 1.0.

    Raises:
        ValueError: if a present coordinate has no numeric z component.
    """
    if num_slices is None:
        num_slices = NUM_SLICES
    if target_keys is None:
        target_keys = TARGET_KEYS

    # Explicit encoding: the default depends on the platform locale (e.g. cp1252 on Windows).
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    label = np.zeros(len(target_keys), dtype=np.float32)
    mask = np.zeros(len(target_keys), dtype=np.float32)
    for i, key in enumerate(target_keys):
        lvl, side = key.rsplit("_", 1)
        # `or {}` also covers a level explicitly stored as null.
        coord = (data.get(lvl) or {}).get(side)
        if coord is None:
            continue
        try:
            z = float(coord[2])
        except (TypeError, IndexError, ValueError) as exc:
            raise ValueError(f"{json_path}: malformed coordinate for {key}: {coord!r}") from exc
        z_index = int(round(z * num_slices))
        label[i] = max(0, min(num_slices - 1, z_index))
        mask[i] = 1.0

    return label, mask

def load_and_preprocess_volume(dicom_dir):
    """Load a DICOM series as a normalized float32 volume of shape (NUM_SLICES, IMG_SIZE, IMG_SIZE).

    Steps:
      1. Collect files ending in ".dcm" (case-insensitive).
      2. Order them by the InstanceNumber header. Files whose header cannot be
         read come after the numbered ones, ordered by path.
      3. Force exactly NUM_SLICES slices: centre-crop longer series, pad shorter
         ones by repeating the last slice.
      4. Per slice: crop to foreground, min-max normalize to [0, 1], resize to
         IMG_SIZE x IMG_SIZE with INTER_AREA.

    NOTE(review): each slice is cropped to its *own* foreground box and
    normalized by its *own* min/max before being stretched to a square, so
    slices are not spatially aligned and the aspect ratio is not preserved.
    Changing this would invalidate the existing cache and trained weights, so
    it is left as-is — confirm it is intended.

    Raises:
        RuntimeError: if ``dicom_dir`` contains no DICOM files, or the stacked
            volume has an unexpected shape.
    """
    files = [
        os.path.join(dicom_dir, name)
        for name in os.listdir(dicom_dir)
        if name.lower().endswith(".dcm")
    ]
    if not files:
        raise RuntimeError(f"No DICOM files in {dicom_dir}")

    def sort_key(path):
        # Tuple keys keep numbered and unnumbered files mutually comparable.
        # Returning a bare int for some files and a str path for others makes
        # sorted() raise TypeError.
        try:
            return (0, int(pydicom.dcmread(path, stop_before_pixels=True).InstanceNumber), path)
        except Exception:
            return (1, 0, path)

    files = sorted(files, key=sort_key)

    # enforce NUM_SLICES
    if len(files) > NUM_SLICES:
        start = (len(files) - NUM_SLICES) // 2
        files = files[start:start + NUM_SLICES]
    elif len(files) < NUM_SLICES:
        files = files + [files[-1]] * (NUM_SLICES - len(files))

    vol = []
    for path in files:
        ds = pydicom.dcmread(path)
        img = ds.pixel_array.astype(np.float32)

        img = extract_foreground(img)
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)

        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
        vol.append(img)

    vol = np.stack(vol).astype(np.float32)  # (D,H,W)
    if vol.shape != (NUM_SLICES, IMG_SIZE, IMG_SIZE):
        raise RuntimeError(f"Bad volume shape {vol.shape} in {dicom_dir}")

    return vol

def _save_npy_atomic(path, arr):
    """Save ``arr`` to ``path`` (.npy format) via a temp file and an atomic rename."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as fh:
        np.save(fh, arr)
    # os.replace is atomic on the same filesystem, so ``path`` is either the
    # previous file or the complete new one, never a truncated write.
    os.replace(tmp_path, path)


def cache_patient(pid):
    """Ensure CACHE_ROOT/<pid>/{volume,label,mask}.npy exist for one patient.

    Returns True immediately if all three files already exist. Otherwise it
    loads the DICOM series from DICOM_ROOT/<pid> and labels from
    LABEL_ROOT/<pid>.json, then writes each array atomically.

    Returns:
        True if the cache is (now) present. False if the inputs are missing, or
        if preprocessing/saving failed; failures are printed, not raised.
    """
    out_dir = os.path.join(CACHE_ROOT, pid)
    vol_out = os.path.join(out_dir, "volume.npy")
    lab_out = os.path.join(out_dir, "label.npy")
    mask_out = os.path.join(out_dir, "mask.npy")

    # already cached
    if os.path.exists(vol_out) and os.path.exists(lab_out) and os.path.exists(mask_out):
        return True

    dicom_dir = os.path.join(DICOM_ROOT, pid)
    json_path = os.path.join(LABEL_ROOT, f"{pid}.json")

    if not os.path.isdir(dicom_dir) or not os.path.exists(json_path):
        return False

    try:
        vol = load_and_preprocess_volume(dicom_dir)
        label, mask = extract_labels(json_path)
        # Create the directory only once there is something to write, so
        # patients without data or with failures leave no empty folders.
        os.makedirs(out_dir, exist_ok=True)
        _save_npy_atomic(vol_out, vol)
        _save_npy_atomic(lab_out, label)
        _save_npy_atomic(mask_out, mask)
        return True
    except Exception as e:
        # Best-effort: a bad patient is reported and skipped, not fatal.
        print(f"❌ Cache failed for {pid}: {e}")
        return False

def verify_cache(pid):
    """Return True if the cached sample for ``pid`` is readable and valid.

    Valid means all of the following hold:
      - volume.npy has shape (NUM_SLICES, IMG_SIZE, IMG_SIZE);
      - label.npy and mask.npy both have shape (len(TARGET_KEYS),);
      - all three arrays contain only finite values.

    Any load error (missing or truncated file, bad dtype, ...) yields False.
    """
    base = os.path.join(CACHE_ROOT, pid)
    n_targets = len(TARGET_KEYS)  # kept in sync with extract_labels / the model head
    try:
        vol = np.load(os.path.join(base, "volume.npy"))
        y = np.load(os.path.join(base, "label.npy"))
        m = np.load(os.path.join(base, "mask.npy"))

        if vol.shape != (NUM_SLICES, IMG_SIZE, IMG_SIZE):
            return False
        if y.shape != (n_targets,) or m.shape != (n_targets,):
            return False
        # A single NaN/inf would turn the masked loss of its whole batch into NaN.
        return bool(np.isfinite(vol).all() and np.isfinite(y).all() and np.isfinite(m).all())
    except Exception:
        return False


# ============================================================
#                   DATASET (SAFE)
# ============================================================

class CachedNPYDataset(Dataset):
    """Dataset over patient IDs whose arrays were cached as CACHE_ROOT/<pid>/*.npy.

    Each item is ``(vol, y, m)``:
      - ``vol``: tensor from volume.npy with a leading channel axis added;
      - ``y``: tensor from label.npy;
      - ``m``: tensor from mask.npy.

    Missing or corrupt files are not handled here; np.load's exception
    propagates to the caller.
    """

    def __init__(self, ids):
        # Patient IDs, each naming a sub-directory of CACHE_ROOT.
        self.ids = ids

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        sample_dir = os.path.join(CACHE_ROOT, self.ids[idx])
        vol, y, m = (
            np.load(os.path.join(sample_dir, fname))
            for fname in ("volume.npy", "label.npy", "mask.npy")
        )
        return (
            torch.from_numpy(vol).unsqueeze(0),  # (1, D, H, W)
            torch.from_numpy(y),
            torch.from_numpy(m),
        )


def safe_collate(batch):
    """Collate ``batch`` after discarding ``None`` samples.

    Returns:
        The ``default_collate`` result of the remaining samples, or None when
        nothing remains (the training/eval loops skip None batches).
    """
    kept = [sample for sample in batch if sample is not None]
    return torch.utils.data.default_collate(kept) if kept else None


# ============================================================
#                    MODEL (FAST + STABLE)
# ============================================================

class Block(nn.Module):
    """Residual ConvNeXt-style 3D block.

    Computes ``x + pw2(GELU(pw1(GroupNorm(dw(x)))))``; output shape equals
    input shape. Submodule names and registration order are part of the
    checkpoint format (state_dict keys) and must not change.
    """

    def __init__(self, c):
        super().__init__()
        hidden = 4 * c
        # Depthwise conv: kernel 3 along depth, 7x7 in-plane; padding preserves D, H, W.
        self.dw = nn.Conv3d(c, c, kernel_size=(3, 7, 7), padding=(1, 3, 3), groups=c, bias=False)
        # A single group normalizes over all channels and positions of each sample.
        self.gn = nn.GroupNorm(num_groups=1, num_channels=c)
        # Pointwise expand (x4) -> GELU -> pointwise project back to c channels.
        self.pw1 = nn.Conv3d(c, hidden, kernel_size=1, bias=False)
        self.act = nn.GELU()
        self.pw2 = nn.Conv3d(hidden, c, kernel_size=1, bias=False)

    def forward(self, x):
        h = self.dw(x)
        h = self.gn(h)
        h = self.pw1(h)
        h = self.act(h)
        h = self.pw2(h)
        return x + h

class ConvNext3D(nn.Module):
    """Small 3D ConvNeXt-style regressor: (B, 1, D, H, W) -> (B, n_outputs).

    Architecture:
      - Stem: patchifies in-plane with 4x4 patches, leaving depth untouched.
      - Stages: three stride-2 downsampling convs, each followed by a residual
        Block; channels go 64 -> 128 -> 256 -> 512.
      - Head: global average pooling, then a linear layer.

    Attribute names and order define the checkpoint's state_dict keys.
    """

    def __init__(self, n_outputs=10):
        super().__init__()
        self.stem = nn.Conv3d(1, 64, kernel_size=(1, 4, 4), stride=(1, 4, 4), bias=False)
        self.b1 = Block(64)

        # kernel=stride=2: each downsample reduces D, H, W by 2 (floor) and doubles channels.
        self.d1 = nn.Conv3d(64, 128, kernel_size=2, stride=2, bias=False)
        self.b2 = Block(128)

        self.d2 = nn.Conv3d(128, 256, kernel_size=2, stride=2, bias=False)
        self.b3 = Block(256)

        self.d3 = nn.Conv3d(256, 512, kernel_size=2, stride=2, bias=False)
        self.b4 = Block(512)

        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(512, n_outputs)

    def forward(self, x):
        stages = (
            (self.stem, self.b1),
            (self.d1, self.b2),
            (self.d2, self.b3),
            (self.d3, self.b4),
        )
        for down, block in stages:
            x = block(down(x))
        return self.fc(torch.flatten(self.pool(x), 1))


# ============================================================
#                     LOSS + EVAL + TRAIN
# ============================================================

def masked_mse_loss(pred, target, mask, eps=1e-6):
    """Mask-weighted mean squared error.

    Each squared error is multiplied by its mask weight, summed, and divided by
    ``mask.sum() + eps``. ``eps`` keeps an all-zero mask from dividing by zero;
    the result is then 0.
    """
    weighted_sq_err = (pred - target).pow(2) * mask
    denominator = mask.sum() + eps
    return weighted_sq_err.sum() / denominator

@torch.no_grad()
def evaluate(model, loader):
    model.eval()
    total, n = 0.0, 0
    for batch in loader:
        if batch is None:
            continue
        vol, y, m = batch
        vol = vol.to(device, non_blocking=True).to(memory_format=torch.channels_last_3d)
        y   = y.to(device, non_blocking=True)
        m   = m.to(device, non_blocking=True)
        with autocast(enabled=(device.type == "cuda")):
            pred = model(vol)
            loss = masked_mse_loss(pred, y, m)
        total += float(loss.item())
        n += 1
    return total / max(n, 1)


def main() -> None:
    """Run the full training pipeline.

    Stages:
      1. Cache preprocessed volumes and labels for every patient folder.
      2. Verify the cache.
      3. Split the verified IDs into train/val.
      4. Train ConvNext3D with masked MSE, saving the best-val checkpoint to
         BEST_MODEL_PATH.
      5. Plot the loss curves.

    Raises:
        RuntimeError: if fewer than 5 cached samples pass verification.
    """
    # ------------------ CACHE STAGE ------------------
    # Every sub-directory of DICOM_ROOT is treated as one patient ID.
    # NOTE(review): os.listdir order is filesystem-dependent and feeds the seeded
    # train_test_split below; sorting these IDs would make the split portable.
    all_pids = [p for p in os.listdir(DICOM_ROOT) if os.path.isdir(os.path.join(DICOM_ROOT, p))]
    print(f"Found {len(all_pids)} patient folders.")

    # cache_patient returns False for patients it could not cache (missing
    # inputs or a preprocessing error); those are dropped here.
    cached = []
    for pid in tqdm(all_pids, desc="Caching"):
        ok = cache_patient(pid)
        if ok:
            cached.append(pid)

    print(f"Cached candidates: {len(cached)}")

    # ------------------ VERIFY CACHE (CRITICAL) ------------------
    # Re-load each cached sample so unreadable or mis-shaped files are excluded
    # before they can crash a DataLoader mid-epoch.
    good = []
    bad = []
    for pid in tqdm(cached, desc="Verifying cache"):
        if verify_cache(pid):
            good.append(pid)
        else:
            bad.append(pid)

    print(f"✅ Good cached samples: {len(good)}")
    print(f"❌ Bad cached samples:  {len(bad)}")

    # Save bad list (written to the current working directory)
    if len(bad) > 0:
        with open("bad_samples.txt", "w") as f:
            for pid in bad:
                f.write(pid + "\n")
        print("Saved bad sample IDs to bad_samples.txt")

    if len(good) < 5:
        raise RuntimeError("Too few valid samples. Check cache/paths.")

    # ------------------ SPLIT ------------------
    # 87.5% train / 12.5% val, seeded with SEED.
    train_ids, val_ids = train_test_split(good, test_size=0.125, random_state=SEED)

    train_ds = CachedNPYDataset(train_ids)
    val_ds   = CachedNPYDataset(val_ids)

    # IMPORTANT: start with NUM_WORKERS=0 to see true error, then increase later
    # NOTE(review): pin_memory=True is also set when running on CPU, where
    # PyTorch warns that pinning has no effect; consider
    # pin_memory=(device.type == "cuda").
    train_loader = DataLoader(
        train_ds,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        persistent_workers=(NUM_WORKERS > 0),
        collate_fn=safe_collate
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        persistent_workers=(NUM_WORKERS > 0),
        collate_fn=safe_collate
    )

    # ------------------ MODEL ------------------
    # Mixed precision (autocast + GradScaler) is enabled only on CUDA.
    model = ConvNext3D(n_outputs=10).to(device).to(memory_format=torch.channels_last_3d)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    scaler = GradScaler(enabled=(device.type == "cuda"))

    best_val = float("inf")
    train_losses, val_losses = [], []

    # ------------------ TRAIN ------------------
    for epoch in range(EPOCHS):
        model.train()
        epoch_loss, batches = 0.0, 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
        for batch in pbar:
            # safe_collate returns None when every sample in a batch was dropped.
            if batch is None:
                continue
            vol, y, m = batch

            vol = vol.to(device, non_blocking=True).to(memory_format=torch.channels_last_3d)
            y   = y.to(device, non_blocking=True)
            m   = m.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)

            with autocast(enabled=(device.type == "cuda")):
                pred = model(vol)
                loss = masked_mse_loss(pred, y, m)

            # Scaled backward/step/update: the standard GradScaler sequence.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            epoch_loss += float(loss.item())
            batches += 1
            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        # Train loss is the mean of per-batch losses for this epoch.
        avg_train = epoch_loss / max(batches, 1)
        val_loss  = evaluate(model, val_loader)

        train_losses.append(avg_train)
        val_losses.append(val_loss)

        print(f"Epoch {epoch+1}: Train {avg_train:.4f} | Val {val_loss:.4f}")

        # Keep only the weights with the lowest validation loss seen so far.
        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), BEST_MODEL_PATH)
            print(f"✅ Saved Best Model: {BEST_MODEL_PATH} (Val {best_val:.4f})")

    # ------------------ PLOT ------------------
    # One train and one val point is appended per epoch, so both lists have EPOCHS entries.
    plt.figure(figsize=(8, 6))
    plt.plot(range(1, EPOCHS+1), train_losses, label="Train Loss")
    plt.plot(range(1, EPOCHS+1), val_losses, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Train/Val Loss")
    plt.grid(True)
    plt.legend()
    plt.show()

    print("✅ Done.")
    print("\nNEXT STEP:")
    print("Set NUM_WORKERS=4 or 8 after it runs without errors for speed.")

# In a Jupyter kernel __name__ is normally "__main__" as well, so executing this
# cell starts the full caching + training run.
if __name__ == "__main__":
    main()


In [None]:
import os
from torch.utils.data import Dataset, DataLoader
import torch

# ===================== PATHS =====================
# Test split of the same dataset used for training. The folder names, typos
# included, are the real on-disk names.
_split_root = r"D:\Submitted Matrial (conference&journal)\Sagittal Data Artical\V0.47 Dataset analysis\dataspitting\split_dataset"
images_root = _split_root + r"\images"
dicom_root = os.path.join(images_root, "test")  # one sub-folder of .dcm slices per patient
label_root = _split_root + r"\label\test"       # one <patient_id>.json per patient

# ===================== DATASET =====================
class LumbarDataset(Dataset):
    """In-memory dataset over ``(volume, labels)`` pairs.

    ``volume`` is a NumPy array that gains a leading channel axis and is cast
    to float32. ``labels`` is either:
      - a 2-tuple ``(label, mask)``, giving items ``(vol, y, m)``; or
      - a single label array, giving items ``(vol, y)``.
    """

    def __init__(self, data):
        # List of (volume, labels) pairs as built by the loop below.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        volume, labels = self.data[idx]
        vol_t = torch.from_numpy(volume).unsqueeze(0).float()  # (D,H,W) -> (1,D,H,W)

        has_mask = isinstance(labels, tuple) and len(labels) == 2
        if not has_mask:
            return vol_t, torch.from_numpy(labels).float()

        label_arr, mask_arr = labels
        return vol_t, torch.from_numpy(label_arr).float(), torch.from_numpy(mask_arr).float()

# ===================== BUILD TEST DATA =====================
# Preprocess every test patient with the same functions used for the training
# cache. Patients are visited in sorted order so predictions line up
# reproducibly across runs. A patient whose volume or labels fail to load is
# reported and skipped, as cache_patient does for training, instead of
# aborting the whole evaluation. Check the skipped list before quoting metrics.
test_data = []
skipped_test_ids = []
for patient_id in sorted(os.listdir(dicom_root)):
    vol_path   = os.path.join(dicom_root, patient_id)          # folder that contains .dcm slices
    label_path = os.path.join(label_root, f"{patient_id}.json")

    if not os.path.isdir(vol_path):
        continue
    if not os.path.exists(label_path):
        continue

    try:
        volume = load_and_preprocess_volume(vol_path)
        # extract_labels returns (label, mask); LumbarDataset then yields (vol, y, m).
        labels = extract_labels(label_path)
    except Exception as e:
        print(f"❌ Skipping test patient {patient_id}: {e}")
        skipped_test_ids.append(patient_id)
        continue

    test_data.append((volume, labels))

print(f"Total test samples: {len(test_data)}")
if skipped_test_ids:
    print(f"Skipped {len(skipped_test_ids)} test patients: {skipped_test_ids}")

test_loader = DataLoader(
    LumbarDataset(test_data),
    batch_size=8,
    shuffle=False,
    num_workers=0,   # Windows safe
    pin_memory=torch.cuda.is_available(),  # pinning only helps host->GPU copies
)


In [None]:
# ---------------------- MODEL ----------------------
# Rebuild the architecture from the training cell (10 outputs, matching the
# saved checkpoint), then load the trained weights into it.
model = ConvNext3D(n_outputs=10)
model = model.to(device).to(memory_format=torch.channels_last_3d)

# ---------------------- LOAD BEST MODEL ----------------------
# NOTE(review): the training cell saves to BEST_MODEL_PATH relative to the
# working directory; confirm this absolute path refers to that same file and
# not an older checkpoint.
model_path = r"D:\Submitted Matrial (conference&journal)\Sagittal Data Artical\V0.47 Dataset analysis\dataspitting\best_model_convnext3d_regression.pth"
# The checkpoint is a plain state_dict, so weights_only=True suffices and
# avoids unpickling arbitrary objects from the file.
state = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state)

model.eval()  # inference mode for all submodules before evaluation


# Run the model over the whole test set and gather predictions, targets and
# masks on the CPU as (N, 10) tensors, in test_loader order.
pred_chunks, target_chunks, mask_chunks = [], [], []

with torch.no_grad():
    for vol, target, mask in test_loader:
        vol = vol.to(device, non_blocking=True).to(memory_format=torch.channels_last_3d)
        target = target.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)

        pred = model(vol)

        for bucket, tensor in ((pred_chunks, pred), (target_chunks, target), (mask_chunks, mask)):
            bucket.append(tensor.cpu())

all_preds = torch.cat(pred_chunks, dim=0)
all_targets = torch.cat(target_chunks, dim=0)
all_masks = torch.cat(mask_chunks, dim=0)

# ---------------------- ±1 slice PR / R / F1 / Accuracy ----------------------
# Each labelled (level, side) target receives exactly one predicted slice.
#   - Within ±1 slice of the label: a true positive.
#   - Outside the tolerance: both a false positive (the predicted slice is
#     wrong) and a false negative (the true slice was missed).
# Precision is therefore not fixed at 100%; with one prediction per target,
# precision, recall, F1 and accuracy coincide.
SLICE_TOLERANCE = 1.0

valid_mask = all_masks > 0.5
# Predictions are slice indices: round, then clamp to the valid range, just as
# extract_labels clamps the labels.
pred_round = torch.clamp(torch.round(all_preds.float()), 0, NUM_SLICES - 1)
within_tol = torch.abs(pred_round - all_targets) <= SLICE_TOLERANCE

TP = (within_tol & valid_mask).sum().item()
FN = (~within_tol & valid_mask).sum().item()
FP = FN  # every out-of-tolerance prediction is also a wrong detection
n_valid = int(valid_mask.sum().item())

precision = TP / max(TP + FP, 1)
recall    = TP / max(TP + FN, 1)
f1_score  = 2 * precision * recall / max(precision + recall, 1e-8)
accuracy  = TP / max(n_valid, 1)
# Mean absolute error in slices over labelled targets, from the raw (unrounded) predictions.
mae = (
    torch.abs(all_preds.float() - all_targets)[valid_mask].mean().item()
    if n_valid > 0 else float("nan")
)

print(f"Precision (±1 slice): {precision * 100:.2f}%")
print(f"Recall    (±1 slice): {recall * 100:.2f}%")
print(f"F1-score  (±1 slice): {f1_score * 100:.2f}%")
print(f"Accuracy  (±1 slice): {accuracy * 100:.2f}%")
print(f"MAE       (slices)  : {mae:.3f}  over {n_valid} labelled targets")
