# 1. 导入 + 设置路径

In [9]:
import os
import sys
import torch
from torch.utils.data import DataLoader
from torch import optim
import matplotlib.pyplot as plt

# 添加自定义模块路径
sys.path.append(os.path.abspath("../utils"))
sys.path.append(os.path.abspath("../models"))
sys.path.append(os.path.abspath("../scripts"))
sys.path.append("/root/Implementation") 
sys.path.append("../dataset")  # 如果你在 notebooks/ 中

from train import train
from simple_cnn import Simple3DCNNEncoder
from decoder3D import LightDecoder
from encoder3D import SparseEncoder
from AnatoMask import SparK
from real_dataset import Real3DMedicalDataset, SegmentationPatchDataset,RareClassPatchDataset
from checkpoint import save_checkpoint, load_checkpoint
from visualize import show_slice_comparison, show_slice_with_error
from scripts.preprocess_segmentation_amos import preprocess_segmentation_amos
from segmentation_dataset import SegmentationDataset
from segmentation_model import SegmentationModel
from losses import segmentation_loss
from train_segmentation import train_full_segmentation
from FullVolumeDataset import FullVolumeDataset


from STUNet_head import STUNet  # ✅ 不是 STUNet.py 的
from encoder3D import SparseEncoder
from decoder3D import SMiMTwoDecoder  # 或 LightDecoder
from AnatoMask import SparK

The history saving thread hit an unexpected error (OperationalError('unable to open database file')).History will not be written to the database.


In [10]:
import sys
print(sys.path)

['/root/miniconda/lib/python38.zip', '/root/miniconda/lib/python3.8', '/root/miniconda/lib/python3.8/lib-dynload', '', '/root/miniconda/lib/python3.8/site-packages', '/root/Implementation/utils', '/root/Implementation/models', '/root/Implementation/scripts', '/root/Implementation', '../dataset', '/root/Implementation/utils', '/root/Implementation/models', '/root/Implementation/scripts', '/root/Implementation', '../dataset', '/root/Implementation/utils', '/root/Implementation/models', '/root/Implementation/scripts', '/root/Implementation', '../dataset']


# 数据处理

# 2. 超参数设置

In [11]:
# ====== Config ======
PATCH_SIZE = (64, 64, 64)
BATCH_SIZE = 2
EPOCHS = 10
LR = 1e-4
DATA_DIR = "/root/lanyun-tmp/amos_dataset/amos22/npy_patches"
CHECKPOINT_PATH = "/root/lanyun-tmp/checkpoints/my_model.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#  3. 模型构建

In [12]:
# 定义输入 patch 尺寸
PATCH_SIZE = (64, 64, 64)
DIMS = [32, 64, 128, 256, 512, 512]
POOL_KERNELS = [[2, 2, 2]] * 5
CONV_KERNELS = [[3, 3, 3]] * 6

# 构造 STU-Net
cnn = STUNet(
    input_channels=1,
    num_classes=1,
    depth=[1, 1, 1, 1, 1, 1],
    dims=DIMS,
    pool_op_kernel_sizes=POOL_KERNELS,
    conv_kernel_sizes=CONV_KERNELS,
    enable_deep_supervision=False  # 预训练时不需要
)

# 包装成 sparse encoder
encoder = SparseEncoder(cnn, input_size=PATCH_SIZE, sbn=False)

# 构造 decoder
decoder = SMiMTwoDecoder(up_sample_ratio=encoder.downsample_ratio,
                         width=encoder.enc_feat_map_chs[-1],  # = 512
                         sbn=False)

# 构造完整 SparK 模型
model = SparK(
    sparse_encoder=encoder,
    dense_decoder=decoder,
    mask_ratio=0.6  # SparK 预训练遮盖比例
).to(DEVICE)


optimizer = optim.Adam(model.parameters(), lr=LR)

[SparK.__init__, densify 1/5]: use nn.Identity() as densify_proj
[SparK.__init__, densify 2/5]: densify_proj(ksz=3, #para=1.77M)
[SparK.__init__, densify 3/5]: densify_proj(ksz=3, #para=0.44M)
[SparK.__init__, densify 4/5]: densify_proj(ksz=3, #para=0.11M)
[SparK.__init__, densify 5/5]: densify_proj(ksz=3, #para=0.03M)
[SparK.__init__] dims of mask_tokens=(512, 256, 128, 64, 32)


# 4. 数据加载

In [13]:
dataset = Real3DMedicalDataset(DATA_DIR)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

# 5. 启动训练

In [None]:
train(model, optimizer, DEVICE, epochs=EPOCHS, dataloader=loader, save_path=CHECKPOINT_PATH)

Traceback (most recent call last):
  File "/root/miniconda/lib/python3.8/multiprocessing/queues.py", line 239, in _feed
    obj = _ForkingPickler.dumps(obj)
  File "/root/miniconda/lib/python3.8/multiprocessing/reduction.py", line 51, in dumps
    cls(buf, protocol).dump(obj)
  File "/root/miniconda/lib/python3.8/site-packages/torch/multiprocessing/reductions.py", line 370, in reduce_storage
    df = multiprocessing.reduction.DupFd(fd)
  File "/root/miniconda/lib/python3.8/multiprocessing/reduction.py", line 198, in DupFd
    return resource_sharer.DupFd(fd)
  File "/root/miniconda/lib/python3.8/multiprocessing/resource_sharer.py", line 53, in __init__
    self._id = _resource_sharer.register(send, close)
  File "/root/miniconda/lib/python3.8/multiprocessing/resource_sharer.py", line 77, in register
    self._start()
  File "/root/miniconda/lib/python3.8/multiprocessing/resource_sharer.py", line 130, in _start
    self._listener = Listener(authkey=process.current_process().authkey)
  F

# 6. 加载训练好的模型（如已训练过）

In [None]:
# 如果你想加载已有的模型，设置这个变量为 True
LOAD_PRETRAINED = True

if LOAD_PRETRAINED:
    model, optimizer, start_epoch = load_checkpoint(model, optimizer, CHECKPOINT_PATH, DEVICE)
    print(f"✅ 成功加载模型！从 epoch {start_epoch+1} 继续")
else:
    print("⚠️ 未加载任何 checkpoint，重新开始训练")

# 7. 可视化重建效果

In [None]:
model.eval()
sample = dataset[0].unsqueeze(0).to(DEVICE)
mask = model.mask(sample.shape[0], sample.device)

with torch.no_grad():
    original, masked, reconstructed = model(sample, active_b1ff=mask, vis=True)

show_slice_with_error(original, masked, reconstructed, axis=2)
# show_slice_comparison(original, masked, reconstructed, axis=2) 如果不需要重建误差图

# 8. 分割微调

## 8.1 数据整理

## 8.2 构建 segmentation Dataset + DataLoader

In [None]:
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# 确保你之前已经定义了 FullVolumeDataset 类

# 替换为你的实际路径
image_dir = "/root/lanyun-tmp/amos_dataset/amos22/imagesTr"
label_dir = "/root/lanyun-tmp/amos_dataset/amos22/labelsTr"

# 创建 Dataset 实例
dataset = FullVolumeDataset(image_dir=image_dir, label_dir=label_dir)

# 用 DataLoader 测试读取
loader = DataLoader(dataset, batch_size=1, shuffle=False)

# 取出一条数据
image, label = next(iter(loader))

print(f"Image shape: {image.shape}")   # [B, C, H, W, D]
print(f"Label shape: {label.shape}")   # [B, H, W, D]

# 可视化中间切片（中间的 axial 切片）
mid_slice = image[0, 0, :, :, image.shape[-1] // 2].numpy()
plt.imshow(mid_slice, cmap='gray')
plt.title("Middle slice of image volume")
plt.axis('off')
plt.show()


In [None]:
import numpy as np

# 提取整个标签 volume，维度: [1, H, W, D]
label_volume = label[0].numpy()

unique_classes = np.unique(label_volume)
print(f"Unique class IDs in this label: {unique_classes}")
print(f"Total classes: {len(unique_classes)}")

## 8.3 构建 Segmentation 微调模型（预训练 encoder + 分割 head）

In [None]:
import gc

gc.collect()
torch.cuda.empty_cache()

model = SegmentationModel(
    input_size=(768, 768, 90),
    num_classes=16,
    checkpoint_path="/root/lanyun-tmp/checkpoints/anatomask_real.pth"
).cuda()


## 8.4 模型训练

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
from torch.utils.data import DataLoader
dataset = RareClassPatchDataset(
    image_dir="/root/lanyun-tmp/amos_dataset/segmentation_npy/images",
    label_dir="/root/lanyun-tmp/amos_dataset/segmentation_npy/labels",
    focus_foreground=True  # ✅ 启用“只取含前景”的采样方式
)
loader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

train_full_segmentation(
    model, loader, optimizer, device="cuda",
    epochs=10,
    loss_fn="dice_ce",
    save_path="/root/lanyun-tmp/checkpoints/segmentation_finetune.pth"
)


## 8.5 Sliding Window 全图预测 + 拼接还原

In [None]:
import numpy as np
import torch
import torch.nn.functional as F

def predict_full_volume(model, volume_tensor, patch_size=(128, 128, 64), stride=(64, 64, 32), num_classes=16):
    """
    Args:
        model: segmentation model
        volume_tensor: [1, D, H, W] torch.Tensor (single volume)
        patch_size: spatial window
        stride: spatial step for sliding
        num_classes: number of segmentation classes

    Returns:
        full_pred: [num_classes, D, H, W] torch.Tensor
    """
    model.eval()
    device = next(model.parameters()).device
    volume_tensor = volume_tensor.unsqueeze(0).to(device)  # → [1, 1, D, H, W]

    C, D, H, W = volume_tensor.shape[1:]
    output_volume = torch.zeros((1, num_classes, D, H, W), device=device)
    count_map = torch.zeros((1, 1, D, H, W), device=device)

    for z in range(0, D - patch_size[0] + 1, stride[0]):
        for y in range(0, H - patch_size[1] + 1, stride[1]):
            for x in range(0, W - patch_size[2] + 1, stride[2]):
                patch = volume_tensor[:, :, z:z+patch_size[0], y:y+patch_size[1], x:x+patch_size[2]]
                with torch.no_grad():
                    pred = model(patch)  # [1, C, z, y, x]
                    output_volume[:, :, z:z+patch_size[0], y:y+patch_size[1], x:x+patch_size[2]] += pred
                    count_map[:, :, z:z+patch_size[0], y:y+patch_size[1], x:x+patch_size[2]] += 1

    # Avoid division by zero
    count_map = torch.clamp(count_map, min=1.0)
    output_volume = output_volume / count_map

    return output_volume.squeeze(0).softmax(dim=0)  # [C, D, H, W]


In [None]:
import nibabel as nib

nii_path = "/root/lanyun-tmp/amos_dataset/amos22/imagesTr/amos_0001.nii.gz"
nii_img = nib.load(nii_path)
volume = torch.from_numpy(nii_img.get_fdata()).float().permute(2, 0, 1)  # [D, H, W]

volume = (volume - volume.mean()) / (volume.std() + 1e-5)
volume = volume.unsqueeze(0).unsqueeze(0)  # [1, 1, D, H, W]

# Run prediction
pred = predict_full_volume(model, volume.squeeze(0), patch_size=(128, 128, 64), stride=(64, 64, 32), num_classes=16)
mask = pred.argmax(dim=0).cpu().numpy()  # [D, H, W]
print(mask.shape)

## 8.6 结果保存

In [None]:
import nibabel as nib
import os

def save_mask_as_nii(mask_np, reference_nii_path, output_path):
    """
    Args:
        mask_np: np.ndarray, shape: [D, H, W] (uint8 or int)
        reference_nii_path: 用于对齐空间信息的原始图像路径
        output_path: 保存路径
    """
    ref_img = nib.load(reference_nii_path)
    affine = ref_img.affine
    header = ref_img.header

    mask_img = nib.Nifti1Image(mask_np.astype(np.uint8), affine, header)
    nib.save(mask_img, output_path)
    print(f"✅ Saved mask to {output_path}")


In [None]:
save_mask_as_nii(
    mask,  # ← 来自 predict_full_volume 的 argmax 结果
    reference_nii_path="/root/lanyun-tmp/amos_dataset/amos22/imagesTr/amos_0001.nii.gz",
    output_path="/root/Implementation/predicted_masks/amos_0001_pred.nii.gz"
)


## 8.7 Dice per class 评估

In [None]:
def compute_dice_per_class(pred_mask, gt_mask, num_classes=16, ignore_background=False):
    """
    Args:
        pred_mask: np.ndarray, shape [D, H, W], predicted label map
        gt_mask: np.ndarray, same shape, ground truth
        num_classes: total number of classes (e.g. 16)
        ignore_background: whether to exclude class 0

    Returns:
        dict: {class_idx: dice_score}
    """
    dice_dict = {}
    classes = range(1, num_classes) if ignore_background else range(num_classes)

    for c in classes:
        pred_c = (pred_mask == c).astype(np.uint8)
        gt_c = (gt_mask == c).astype(np.uint8)

        intersect = (pred_c * gt_c).sum()
        denom = pred_c.sum() + gt_c.sum()

        dice = (2. * intersect) / denom if denom > 0 else 1.0
        dice_dict[c] = dice

    return dice_dict


In [None]:
gt_path = "/root/lanyun-tmp/amos_dataset/amos22/labelsTr/amos_0001.nii.gz"
gt = nib.load(gt_path).get_fdata().astype(np.uint8)
gt = np.transpose(gt, (2, 0, 1))  # [D, H, W]

dice_scores = compute_dice_per_class(mask, gt, num_classes=16)
for c, d in dice_scores.items():
    print(f"Class {c}: Dice = {d:.4f}")

## 8.8 可视化分割结果

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def visualize_segmentation_slice(image_3d, mask_3d, slice_index=None, title=None):
    """
    Args:
        image_3d: np.ndarray, shape [D, H, W]
        mask_3d: np.ndarray, shape [D, H, W]
        slice_index: which slice to visualize (default: center)
        title: optional title for the plot
    """
    D = image_3d.shape[0]
    if slice_index is None:
        slice_index = D // 2

    img_slice = image_3d[slice_index]
    mask_slice = mask_3d[slice_index]

    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(img_slice, cmap='gray')
    plt.title(f"Original Slice {slice_index}")
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(img_slice, cmap='gray')
    plt.imshow(mask_slice, cmap='jet', alpha=0.5)  # ⬅️ overlay 分割 mask
    plt.title(title or "Overlay Segmentation")
    plt.axis('off')
    plt.show()


In [None]:
# 原图 [D, H, W]
image = nib.load("/root/lanyun-tmp/amos_dataset/amos22/imagesTr/amos_0001.nii.gz").get_fdata()
image = np.transpose(image, (2, 0, 1))  # → [D, H, W]

# 可视化中间切片
visualize_segmentation_slice(image_3d=image, mask_3d=mask, slice_index=1)

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

label_dir = "/root/lanyun-tmp/amos_dataset/segmentation_npy/labels"
cls_hist = np.zeros(16)

for f in os.listdir(label_dir):
    y = np.load(os.path.join(label_dir, f))
    hist = np.bincount(y.flatten(), minlength=16)
    cls_hist += hist

plt.bar(range(16), cls_hist)
plt.xlabel("Class")
plt.ylabel("Pixel Count")
plt.title("Class Distribution")
plt.show()
