<a href="https://colab.research.google.com/github/Laimo64/Laimo64/blob/main/ML_Cw2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!gdown --fuzzy https://drive.google.com/file/d/1Pf35070CMmKGFMTccZORTUb7bMROQj9W/view?usp=sharing

Downloading...
From (original): https://drive.google.com/uc?id=1Pf35070CMmKGFMTccZORTUb7bMROQj9W
From (redirected): https://drive.google.com/uc?id=1Pf35070CMmKGFMTccZORTUb7bMROQj9W&confirm=t&uuid=a17a4095-9a1b-40a1-a094-50a08e794019
To: /content/data.zip
100% 3.29G/3.29G [00:51<00:00, 64.4MB/s]


In [2]:
!unzip -q data.zip

In [6]:
# Processes data for Semi-Supervised Learning: Train, Test, and Unlabeled Data Preparation

import os
import nibabel as nib
import numpy as np
from scipy.ndimage import zoom

# Parameters
data_dir = "/content/data"  # Directory containing images and masks
output_dir = "/content/processed_data"  # Directory for processed outputs
unlabeled_dir = os.path.join(output_dir, "unlabeled")

num_train = 50  # Number of training images
num_test = 30  # Number of test images
target_shape = (64, 64, 32)  # Shape to re-process images to (x, y, slices)
slices_first = True  # Put the slice dimension first
roi_value = 6  # Value in the mask to extract as ROI
image_normalisation = True  # Whether to normalize images between 0 and 1

# Create necessary directories
os.makedirs(output_dir, exist_ok=True)
os.makedirs(unlabeled_dir, exist_ok=True)

# Check if output directory is empty
# if os.path.exists(output_dir) and os.listdir(output_dir):
#     raise ValueError(f"Output directory '{output_dir}' is not empty! Please clear the directory and rerun the script.")

# Initialize counters
train_count, test_count, unlabeled_count, total_processed = 0, 0, 0, 0

# Process files
for file_name in sorted(os.listdir(data_dir)):
    # Skip non-NIfTI files
    if not file_name.endswith(".nii"):
        continue

    # Identify image and mask files
    if "_img" in file_name:
        base_name = file_name.replace("_img.nii", "")
        mask_file_name = f"{base_name}_mask.nii"
        img_path = os.path.join(data_dir, file_name)
        mask_path = os.path.join(data_dir, mask_file_name)

        # Ensure mask file exists
        if not os.path.exists(mask_path):
            print(f"Mask file for {file_name} not found, skipping...")
            continue

        # Load image and mask
        img = nib.load(img_path)
        mask = nib.load(mask_path)

        img_data = img.get_fdata()
        mask_data = mask.get_fdata()

        # Skip files with insufficient slices
        if img_data.shape[2] < target_shape[2]:
            print(f"{file_name} has fewer slices ({img_data.shape[2]}) than required ({target_shape[2]}), skipping...")
            continue

        # Resample image and mask to target shape
        scale_factors = (
            target_shape[0] / img_data.shape[0],
            target_shape[1] / img_data.shape[1],
            1,  # Slice scaling handled by slicing
        )
        resampled_img = zoom(img_data, scale_factors, order=3)  # Cubic interpolation
        resampled_mask = zoom(mask_data, scale_factors, order=0)  # Nearest-neighbor interpolation

        # Select middle slices
        middle_slice_idx = img_data.shape[2] // 2
        slice_start = middle_slice_idx - target_shape[2] // 2
        slice_end = middle_slice_idx + target_shape[2] // 2
        resampled_img = resampled_img[:, :, slice_start:slice_end]
        resampled_mask = resampled_mask[:, :, slice_start:slice_end]

        # Skip if slicing went out of bounds
        if resampled_img.shape[2] != target_shape[2]:
            print(f"{file_name} could not be resized correctly to {target_shape}, skipping...")
            continue

        # Normalize image
        if image_normalisation:
            resampled_img = (resampled_img - np.min(resampled_img)) / (np.max(resampled_img) - np.min(resampled_img))

        # Create binary mask for ROI
        binary_mask = (resampled_mask == roi_value).astype(np.uint8)

        # Rearrange axes if slices_first is True
        if slices_first:
            resampled_img = np.transpose(resampled_img, (2, 0, 1))  # From (x, y, slices) to (slices, x, y)
            binary_mask = np.transpose(binary_mask, (2, 0, 1))

        # Save data into train, test, or unlabeled
        if total_processed < num_train:
            img_output_path = os.path.join(output_dir, f"image_train{train_count:02d}.npy")
            mask_output_path = os.path.join(output_dir, f"label_train{train_count:02d}.npy")
            np.save(img_output_path, resampled_img)
            np.save(mask_output_path, binary_mask)
            print(f"Processed TRAIN: {file_name} -> {os.path.basename(img_output_path)}, {os.path.basename(mask_output_path)}")
            train_count += 1
        elif total_processed < num_train + num_test:
            img_output_path = os.path.join(output_dir, f"image_test{test_count:02d}.npy")
            mask_output_path = os.path.join(output_dir, f"label_test{test_count:02d}.npy")
            np.save(img_output_path, resampled_img)
            np.save(mask_output_path, binary_mask)
            print(f"Processed TEST: {file_name} -> {os.path.basename(img_output_path)}, {os.path.basename(mask_output_path)}")
            test_count += 1
        else:
            unlabeled_output_path = os.path.join(unlabeled_dir, f"image_unlabeled{unlabeled_count:02d}.npy")
            np.save(unlabeled_output_path, resampled_img)
            print(f"Processed UNLABELED: {file_name} -> {os.path.basename(unlabeled_output_path)}")
            unlabeled_count += 1

        total_processed += 1


Processed TRAIN: 001000_img.nii -> image_train00.npy, label_train00.npy
001001_img.nii has fewer slices (25) than required (32), skipping...
Processed TRAIN: 001002_img.nii -> image_train01.npy, label_train01.npy
Processed TRAIN: 001003_img.nii -> image_train02.npy, label_train02.npy
Processed TRAIN: 001004_img.nii -> image_train03.npy, label_train03.npy
Processed TRAIN: 001005_img.nii -> image_train04.npy, label_train04.npy
001006_img.nii has fewer slices (28) than required (32), skipping...
Processed TRAIN: 002000_img.nii -> image_train05.npy, label_train05.npy
Processed TRAIN: 002002_img.nii -> image_train06.npy, label_train06.npy
Processed TRAIN: 002003_img.nii -> image_train07.npy, label_train07.npy
Processed TRAIN: 002004_img.nii -> image_train08.npy, label_train08.npy
Processed TRAIN: 002005_img.nii -> image_train09.npy, label_train09.npy
Processed TRAIN: 002006_img.nii -> image_train10.npy, label_train10.npy
002007_img.nii has fewer slices (23) than required (32), skipping...
0

U Net

In [7]:
import torch
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        # 定義編碼器
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        # 定義解碼器
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(64, 64, kernel_size=2, stride=2),
            nn.Conv2d(64, 1, kernel_size=1)
        )

    def forward(self, x):
        x1 = self.encoder(x)
        x2 = self.decoder(x1)
        return x2

model = UNet()


Semi-supervised learning

In [9]:
import torch
import torch.nn.functional as F
from torch.optim import Adam

# 模型
model = UNet()
optimizer = Adam(model.parameters(), lr=1e-3)

# 標籤數據損失函數
def supervised_loss(predictions, labels, smoothing=0.1):
    n_classes = predictions.size(1)
    one_hot = F.one_hot(labels, num_classes=n_classes).float()
    one_hot = one_hot * (1 - smoothing) + smoothing / n_classes
    return F.cross_entropy(predictions, one_hot)

# 一致性損失函數
def consistency_loss(original_logits, perturbed_logits, gamma=1.0):
    T = 0.5  # 溫度
    original_probs = F.softmax(original_logits / T, dim=1)
    perturbed_probs = F.softmax(perturbed_logits / T, dim=1)
    return gamma * F.mse_loss(original_probs, perturbed_probs)

# 動態調整 λ
def get_lambda(epoch, max_epoch, max_lambda=10.0):
    return max_lambda * (epoch / max_epoch)

# 數據
labeled_data, labels = get_labeled_data()  # 有標籤數據
unlabeled_data = get_unlabeled_data()      # 無標籤數據

# 訓練參數
max_epoch = 100

# 訓練過程
for epoch in range(1, max_epoch + 1):
    model.train()
    lambda_ = get_lambda(epoch, max_epoch)

    # 擾動無標籤數據
    perturbed_unlabeled_data = augment(unlabeled_data)

    # 前向傳播
    labeled_logits = model(labeled_data)
    original_logits = model(unlabeled_data)
    perturbed_logits = model(perturbed_unlabeled_data)

    # 計算損失
    supervised_loss_value = supervised_loss(labeled_logits, labels)
    consistency_loss_value = consistency_loss(original_logits, perturbed_logits)

    total_loss = supervised_loss_value + lambda_ * consistency_loss_value

    # 優化
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()

    # 日誌
    print(f"Epoch {epoch}/{max_epoch}, Supervised Loss: {supervised_loss_value.item():.4f}, "
          f"Consistency Loss: {consistency_loss_value.item():.4f}, Total Loss: {total_loss.item():.4f}")


NameError: name 'get_labeled_data' is not defined