<a href="https://colab.research.google.com/github/Laimo64/COMP0249_24-25/blob/main/AI_small_pelvis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [15]:
!gdown --folder https://drive.google.com/drive/folders/1zERaJ8JhQsCd7-Hbx-OWImvafc47LHTc?usp=drive_link

Retrieving folder contents
Retrieving folder 1YglgePc5HR7xwJsJ4JtxZwCxS_3w4BG1 1PA001
Processing file 14NBV6rs0GBG153iHd7XEqYg0oTKqUJLT ct.nii.gz
Processing file 1PLBoY1bmNQsijtF5sy1xy9CHekEdjRw0 mask.nii.gz
Processing file 1TEC7hyly_wefXVEMkLu6O5KGvktqoHwl mr.nii.gz
Retrieving folder 1G7EXDcDfwwyHtwuvLekjGBsfY7HMX0Vw 1PA004
Processing file 1Kg0rpar2VH96dqIdTPw-xZ40MdeihDWh ct.nii.gz
Processing file 1TOh8uqcwE-2JDSGoupPCpPE_VDIkFaS6 mask.nii.gz
Processing file 1x1ZklUzMMR6bsSUhMcVvZ5mZKLMNKN4m mr.nii.gz
Retrieving folder 1dtZxQG_73GK2E0b12bsW3Kl-1pGNIx3F 1PA005
Processing file 1_Rs9AI_Kcxt1oTuY6RTSOSSpMjwbU0gA ct.nii.gz
Processing file 1vxbkohAz2M-A0EWiNA6-kNb_xlpPdu48 mask.nii.gz
Processing file 1ndEp6swdkzjVH7IkexwWn-HFp2ubd5Gm mr.nii.gz
Retrieving folder 1ylCONRohODlpknK20m6pkNp7aVjVL8s7 1PA009
Processing file 1PmIxMGp8Cxcy8hfDH84JfPtZJV0VPbp5 ct.nii.gz
Processing file 1yMO7J9DSMBnDsZuNG3xbF8KEhbJdgSYl mask.nii.gz
Processing file 1prLf0ASYhXG-FOu1iEtUnFOokn1V55Rs mr.nii.gz
Retrievin

In [None]:
# !unzip -q pelvis_small.zip

In [24]:
import os
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F
from scipy.ndimage import zoom

class MRCTDataset(Dataset):
    """Dataset of paired MR / CT volumes stored as NIfTI files.

    Every directory below ``data_dir`` that contains both ``mr.nii.gz`` and
    ``ct.nii.gz`` is treated as one sample.
    """

    def __init__(self, data_dir, target_size=(128, 500, 310)):
        """
        Index all sample directories found below ``data_dir``.

        Args:
            data_dir (str): Root directory that is searched recursively.
            target_size (tuple): Fixed shape every MR and CT volume is
                resampled to before it is returned.
        """
        self.data_dir = data_dir
        self.target_size = target_size
        # os.walk yields directories in arbitrary order; sorting keeps the
        # index -> sample mapping stable between runs and machines.
        self.samples = sorted(
            root
            for root, _, files in os.walk(data_dir)
            if "mr.nii.gz" in files and "ct.nii.gz" in files
        )

    def __len__(self):
        """Return the number of indexed sample directories."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load, normalize and resample one MR / CT pair.

        Args:
            idx (int): Position of the sample in ``self.samples``.
        Returns:
            Tuple[torch.Tensor, torch.Tensor]: MR and CT as float32 tensors
            of shape ``(1, *target_size)``.
        """
        sample_path = self.samples[idx]

        # Read both modalities from the sample directory.
        mr = nib.load(os.path.join(sample_path, "mr.nii.gz")).get_fdata()
        ct = nib.load(os.path.join(sample_path, "ct.nii.gz")).get_fdata()

        # Z-score normalization, computed independently per volume.
        mr = self._normalize(mr)
        ct = self._normalize(ct)

        # Bring both volumes to the common target shape.
        mr = self._resize_or_pad(mr, self.target_size)
        ct = self._resize_or_pad(ct, self.target_size)

        # Convert to tensors and prepend a channel axis.
        mr = torch.tensor(mr, dtype=torch.float32).unsqueeze(0)
        ct = torch.tensor(ct, dtype=torch.float32).unsqueeze(0)

        return mr, ct

    def _normalize(self, image):
        """
        Z-score normalize an image.

        Args:
            image (np.ndarray): Input image.
        Returns:
            np.ndarray: ``(image - mean) / std``; the input is returned
            unchanged when its standard deviation is zero.
        """
        std = np.std(image)  # computed once and reused for test and division
        if std == 0:
            return image
        return (image - np.mean(image)) / std

    def _resize_or_pad(self, image, desired_shape):
        """
        Resample an image to ``desired_shape`` with linear interpolation.

        Args:
            image (np.ndarray): Input image.
            desired_shape (tuple): Target shape, one entry per image axis.
        Returns:
            np.ndarray: Array of exactly ``desired_shape``. Any part the
            resampled image does not cover is left as zeros.
        """
        scale = [want / have for want, have in zip(desired_shape, image.shape)]
        resized_image = zoom(image, scale, order=1)

        # Copy the overlapping region into a zero canvas. This is a safety net
        # for the case where the resampled shape differs from the target, and
        # it works for any number of axes.
        padded_image = np.zeros(desired_shape, dtype=resized_image.dtype)
        overlap = tuple(
            slice(0, min(want, have))
            for want, have in zip(desired_shape, resized_image.shape)
        )
        padded_image[overlap] = resized_image[overlap]

        return padded_image


In [16]:
from sklearn.model_selection import train_test_split
import shutil

data_path = "/content/pelvis_smalllll"
output_root = "/content/split"  # root directory that receives the split copies

# Create the output root and one sub-folder per split.
os.makedirs(output_root, exist_ok=True)
for split_name in ("train", "validation", "test"):
    os.makedirs(os.path.join(output_root, split_name), exist_ok=True)

# Collect the sample folder names. os.listdir returns entries in arbitrary
# order, so the list is sorted; otherwise random_state=42 would not reproduce
# the same split on another machine or filesystem.
samples = sorted(
    name for name in os.listdir(data_path)
    if os.path.isdir(os.path.join(data_path, name))
)

# Hold out 30% of the samples, then halve the hold-out into validation and
# test (nominally a 70:15:15 split).
train_samples, test_samples = train_test_split(samples, test_size=0.3, random_state=42)
validation_samples, test_samples = train_test_split(test_samples, test_size=0.5, random_state=42)

print(f"Total samples: {len(samples)}")
print(f"Train samples: {len(train_samples)}, Validation samples: {len(validation_samples)}, Test samples: {len(test_samples)}")

# 定義拷貝函數
def move_samples(samples, output_dir, source_dir=None):
    """
    Copy sample folders into ``output_dir``.

    Despite the name, nothing is moved: every folder is duplicated with
    ``shutil.copytree`` and the source is left in place. Samples that already
    exist in ``output_dir`` are skipped with a message.

    Args:
        samples (list[str]): Folder names relative to the source directory.
        output_dir (str): Directory that receives the copies.
        source_dir (str, optional): Directory holding the sample folders.
            Defaults to the notebook-level ``data_path``.
    """
    if source_dir is None:
        source_dir = data_path
    for sample in samples:
        src_path = os.path.join(source_dir, sample)  # original location
        dst_path = os.path.join(output_dir, sample)  # destination location
        if os.path.exists(dst_path):
            print(f"Sample {sample} already exists in {output_dir}, skipping.")
            continue
        shutil.copytree(src_path, dst_path)  # copy the whole folder

# Copy every split into its matching sub-folder of the output root.
split_assignments = {
    "train": train_samples,
    "validation": validation_samples,
    "test": test_samples,
}
for split_folder, split_members in split_assignments.items():
    move_samples(split_members, os.path.join(output_root, split_folder))

print("Data split and moved successfully!")


Total samples: 10
Train samples: 7, Validation samples: 1, Test samples: 2
Data split and moved successfully!


In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F


# Swin Transformer Block (簡化版，適用於3D)
class SwinTransformerBlock(nn.Module):
    """Simplified block for 3D feature maps: LayerNorm followed by a Linear
    layer, both applied to the channel vector of every voxel.

    No windowed attention is implemented; ``window_size`` is stored but is not
    used by ``forward``.
    """

    def __init__(self, dim, input_size):
        """
        Args:
            dim (int): Number of channels of the incoming feature map.
            input_size (int): Used only to derive ``window_size``.
        """
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fc = nn.Linear(dim, dim)
        self.window_size = input_size // 4

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Feature map of shape (batch, dim, D, H, W).
        Returns:
            torch.Tensor: Tensor of the same shape as ``x``.
        """
        b, c, d, h, w = x.shape
        # Channels must be the last axis before LayerNorm / Linear are applied.
        # A plain view(b, -1, c) on a channels-first tensor only reinterprets
        # memory, so each "token" would be a run of spatial values rather than
        # one voxel's channel vector.
        tokens = x.flatten(2).transpose(1, 2)  # (b, d*h*w, c)
        tokens = self.fc(self.norm(tokens))
        # Back to channels-first and the original spatial layout.
        return tokens.transpose(1, 2).reshape(b, c, d, h, w)

# MSEP 網路
class MSEP(nn.Module):
    """3D encoder -> transformer-style block -> decoder network.

    All convolutions use kernel 3, stride 1 and padding 1. The ``skip`` module
    is applied in sequence between encoder and decoder; ``forward`` contains
    no residual addition.
    """

    def __init__(self):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)

        # Encoder: 1 -> 64 -> 128 channels with 3D convolutions.
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 64, **conv_kwargs),
            nn.ReLU(),
            nn.Conv3d(64, 128, **conv_kwargs),
            nn.ReLU(),
        )

        # Block placed between encoder and decoder.
        self.skip = SwinTransformerBlock(128, input_size=160)

        # Decoder: 128 -> 64 -> 1 channels with 3D transposed convolutions.
        self.decoder = nn.Sequential(
            nn.ConvTranspose3d(128, 64, **conv_kwargs),
            nn.ReLU(),
            nn.ConvTranspose3d(64, 1, **conv_kwargs),
        )

    def forward(self, x):
        """Run encoder, ``skip`` block and decoder one after another."""
        features = self.encoder(x)
        mixed = self.skip(features)
        return self.decoder(mixed)

# Initialize model
model = MSEP()

# Smoke-test the forward pass with random data laid out as
# [batch, channel, depth, height, width].
dummy_input = torch.randn(1, 1, 128, 128, 128)

# Only the output shape is inspected, so no gradients are needed. Disabling
# autograd avoids keeping the intermediate activations of the 128^3 volume.
with torch.no_grad():
    output = model(dummy_input)

print("Output shape:", output.shape)  # expected to equal the input shape


Output shape: torch.Size([1, 1, 128, 128, 128])


In [27]:
import torch
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

class EarlyStopping:
    """Track the validation loss, checkpoint the best model and signal when
    training should stop because the loss no longer improves."""

    def __init__(self, patience=5, verbose=True, delta=0.00005, path="checkpoint.pt"):
        """
        Args:
            patience (int): Number of consecutive non-improving calls that are
                tolerated before ``early_stop`` is set (default: 5).
            verbose (bool): Print progress messages (default: True).
            delta (float): Minimum decrease of the validation loss that counts
                as an improvement (default: 0.00005).
            path (str): File the best model weights are written to
                (default: "checkpoint.pt").
        """
        self.patience = patience
        self.verbose = verbose
        self.delta = delta
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = float("inf")

    def __call__(self, val_loss, model):
        """
        Record one validation result.

        Args:
            val_loss (float): Validation loss of the current epoch.
            model (torch.nn.Module): Model whose weights are saved when the
                loss improved.
        """
        import math  # local import: keeps the block self-contained

        # A NaN / infinite loss must never count as an improvement. Without
        # this guard NaN fails the "<" comparison, lands in the improvement
        # branch and overwrites the checkpoint with a diverged model.
        if not math.isfinite(val_loss):
            self._register_no_improvement()
            return

        # Score is the negated loss, so larger is better.
        score = -val_loss

        if self.best_score is None or score >= self.best_score + self.delta:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0
        else:
            self._register_no_improvement()

    def _register_no_improvement(self):
        """Advance the patience counter and set ``early_stop`` once it is exhausted."""
        self.counter += 1
        if self.verbose:
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model):
        """Save the current model weights to ``self.path`` and remember the loss."""
        if self.verbose:
            print(f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model...")
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


def visualize_results(input_image, target_image, predicted_image, epoch, idx, batch_idx=0, max_slices=None):
    """
    Plot input, target and predicted volumes slice by slice.

    One figure row is drawn per slice, with the columns input MR, target CT
    and predicted CT.

    Args:
        input_image (torch.Tensor): Batch of input volumes; indexed as
            ``[batch_idx, 0]`` and then sliced along the first remaining axis
            (assumed layout: batch, channel, depth, height, width).
        target_image (torch.Tensor): Batch of target volumes, same layout.
        predicted_image (torch.Tensor): Batch of predictions, same layout.
        epoch: Value shown as the epoch in the figure title.
        idx: Value shown as the batch number in the figure title.
        batch_idx (int): Which element of the batch to draw.
        max_slices (int, optional): Upper bound on the number of rows. ``None``
            (default) draws every slice; otherwise that many evenly spaced
            slices are drawn, which keeps the figure manageable for deep
            volumes.
    Raises:
        ValueError: If ``max_slices`` is given and is smaller than 1.
    """
    def _as_volume(batch):
        # Select one sample and its first channel, then move it to host memory.
        return batch[batch_idx, 0].cpu().detach().numpy()

    input_volume = _as_volume(input_image)
    target_volume = _as_volume(target_image)
    predicted_volume = _as_volume(predicted_image)

    depth = input_volume.shape[0]  # number of slices along the first axis
    if max_slices is None or max_slices >= depth:
        slice_ids = list(range(depth))
    elif max_slices < 1:
        raise ValueError("max_slices must be at least 1")
    else:
        # Evenly spaced subset that always starts at slice 0.
        stride = depth / max_slices
        slice_ids = [int(k * stride) for k in range(max_slices)]

    num_rows = len(slice_ids)
    # squeeze=False always returns a 2D axes array, even for a single row.
    fig, axes = plt.subplots(num_rows, 3, figsize=(10, num_rows * 3), squeeze=False)

    columns = (
        (input_volume, "Input MR"),
        (target_volume, "Target CT"),
        (predicted_volume, "Predicted CT"),
    )
    for row, i in enumerate(slice_ids):
        for col, (volume, label) in enumerate(columns):
            axes[row, col].imshow(volume[i], cmap="gray")
            axes[row, col].set_title(f"{label} - Slice {i}")
            axes[row, col].axis("off")

    plt.suptitle(f"Epoch {epoch}, Batch {idx}, Patient {batch_idx}")
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure so repeated calls do not accumulate memory


Checkpoint

In [28]:
import os

def save_checkpoint(epoch, val_loss, optimizer, model, path):
    """
    Write a training checkpoint to ``path``.

    The saved dictionary holds the epoch, the model and optimizer state dicts
    and the validation loss.

    Args:
        epoch: Epoch value stored under ``'epoch'``.
        val_loss: Validation loss stored under ``'val_loss'``.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        model: Model whose ``state_dict()`` is stored.
        path: Destination file.
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        val_loss=val_loss,
    )
    torch.save(state, path)
    print(f"Checkpoint saved at {path}")

def load_checkpoint(model, optimizer, path, device='cuda'):
    """
    Restore model and optimizer state from a checkpoint file.

    Args:
        model: Model that receives ``'model_state_dict'``.
        optimizer: Optimizer that receives ``'optimizer_state_dict'``.
        path: Checkpoint file to read.
        device: Device the stored tensors are mapped to while loading. Falls
            back to CPU when a CUDA device is requested but none is available.
    Returns:
        tuple: ``(model, optimizer, epoch, val_loss)``. When ``path`` does not
        exist the inputs are returned untouched with ``epoch == 0`` and
        ``val_loss == inf``.
    """
    if not os.path.exists(path):
        print(f"Checkpoint file '{path}' does not exist. Starting from scratch.")
        return model, optimizer, 0, float('inf')  # initial values

    # Map the stored tensors onto a device that exists on this machine.
    # Without map_location, a checkpoint written on a GPU cannot be read on a
    # CPU-only host.
    target_device = torch.device(device)
    if target_device.type == 'cuda' and not torch.cuda.is_available():
        target_device = torch.device('cpu')

    checkpoint = torch.load(path, map_location=target_device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    val_loss = checkpoint['val_loss']
    print(f"Checkpoint loaded from {path}, starting from epoch {epoch+1}" )
    return model, optimizer, epoch, val_loss


In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Optimizer, learning-rate schedule and loss function
optimizer = optim.Adam(model.parameters(), lr=0.00005)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
criterion = nn.L1Loss()

# Training settings
epochs = 1
best_loss = float("inf")  # lowest average validation loss seen so far

train_path = "/content/split/train"
valid_path = "/content/split/validation"
test_path = "/content/split/test"

train_dataset = MRCTDataset(train_path)
valid_dataset = MRCTDataset(valid_path)
test_dataset = MRCTDataset(test_path)


train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=2)
val_loader = DataLoader(valid_dataset, batch_size=2, shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False, num_workers=2)

# Sanity check: print the tensor shapes of the first training batch.
for mr, ct in train_loader:
    print(f"MR shape: {mr.shape}, CT shape: {ct.shape}")
    break

# TODO: wire in EarlyStopping / save_checkpoint once longer runs are needed.

for epoch in range(epochs):
    # Training phase
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")
    for mr, ct in progress_bar:
        mr, ct = mr.to(device), ct.to(device)
        optimizer.zero_grad()
        output = model(mr)
        loss = criterion(output, ct)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        progress_bar.set_postfix(loss=loss.item())

    avg_train_loss = running_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{epochs}, Avg Train Loss: {avg_train_loss}")

    # Validation phase
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for mr, ct in val_loader:
            mr, ct = mr.to(device), ct.to(device)
            output = model(mr)
            loss = criterion(output, ct)
            val_loss += loss.item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch+1}/{epochs}, Avg Validation Loss: {avg_val_loss}")

    # Keep the best validation loss up to date; it was initialised above but
    # never updated before.
    best_loss = min(best_loss, avg_val_loss)

    # Advance the learning-rate schedule once per epoch. Without this call the
    # StepLR decay configured above is never applied.
    scheduler.step()



# Show the last validation batch together with the model prediction.
visualize_results(mr, ct, output, epoch, idx=1, batch_idx=0)


AttributeError: partially initialized module 'torch._dynamo' has no attribute 'config' (most likely due to a circular import)

In [20]:
import nibabel as nib

# Open the NIfTI file of sample 1PA001.
ct_path = "/content/pelvis_smalllll/1PA001/ct.nii.gz"
nii_file = nib.load(ct_path)

# Read the shape attribute of the loaded image object.
image_shape = nii_file.shape

print("影像大小:", image_shape)


影像大小: (565, 338, 146)


In [23]:
import nibabel as nib

# Open the NIfTI file of sample 1PA014.
ct_path = "/content/pelvis_smalllll/1PA014/ct.nii.gz"
nii_file = nib.load(ct_path)

# Read the shape attribute of the loaded image object.
image_shape = nii_file.shape

print("影像大小:", image_shape)

影像大小: (568, 392, 147)
