In [None]:
# Install requirements
pip install nibabel numpy scipy pandas torch torchvision sklearn matplotlib seaborn alive_progress

In [49]:
# Import
import nibabel as nib
import numpy as np
import pandas as pd
from scipy.ndimage import zoom, shift, rotate
import os

# Torch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F

# 
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, roc_curve, auc
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from alive_progress import alive_bar


In [None]:
# Preprocess
def resize_and_normalize_fMRI(data, affine, target_shape=(64, 64, 48)):
    """
    Resize và chuẩn hóa cường độ cho ảnh fMRI.
    
    Parameters:
    - data: Mảng 3D của ảnh fMRI
    - affine: Ma trận affine của ảnh gốc
    - target_shape: Kích thước mục tiêu (mặc định: 64x64x48)
    
    Returns:
    - normalized_data: Mảng đã resize và chuẩn hóa
    """
    # Resize ảnh về kích thước mục tiêu
    zoom_factors = [t / s for t, s in zip(target_shape, data.shape)]
    resized_data = zoom(data, zoom_factors, order=3)  # Nội suy tricubic

    # Chuẩn hóa cường độ
    mask = resized_data != 0
    mean_intensity = np.mean(resized_data[mask]) if np.any(mask) else 0
    std_intensity = np.std(resized_data[mask]) if np.any(mask) else 1
    normalized_data = (resized_data - mean_intensity) / std_intensity
    
    return normalized_data

def augment_fMRI(input_path, output_dir, target_shape=(64, 64, 48)):
    """
    Augment ảnh fMRI bằng zoom, shift, và rotate, tạo ra 3 ảnh mới và giữ ảnh gốc.
    
    Parameters:
    - input_path: Đường dẫn đến file NIfTI đầu vào
    - output_dir: Thư mục để lưu các file NIfTI đã augment
    - target_shape: Kích thước mục tiêu (mặc định: 64x64x48)
    
    Returns:
    - List các tên file đã tạo
    """
    try:
        # Đọc ảnh fMRI
        img = nib.load(input_path)
        data = img.get_fdata()
        affine = img.affine

        # Tạo tên file cơ bản
        base_name = os.path.splitext(os.path.basename(input_path))[0]
        if base_name.endswith('.nii'):
            base_name = os.path.splitext(base_name)[0]  # Xử lý .nii.gz

        output_files = []

        # 1. Ảnh gốc
        normalized_data = resize_and_normalize_fMRI(data, affine, target_shape)
        output_path = os.path.join(output_dir, f"{base_name}_original.nii.gz")
        nib.save(nib.Nifti1Image(normalized_data, affine), output_path)
        output_files.append(f"{base_name}_original.nii.gz")
        print(f"Đã lưu: {output_path}")

        # 2. Zoom
        zoom_factor = np.random.uniform(0.9, 1.1)  # Zoom ngẫu nhiên 90%-110%
        zoomed_data = zoom(data, zoom_factor, order=3)
        normalized_zoomed = resize_and_normalize_fMRI(zoomed_data, affine, target_shape)
        output_path = os.path.join(output_dir, f"{base_name}_zoomed.nii.gz")
        nib.save(nib.Nifti1Image(normalized_zoomed, affine), output_path)
        output_files.append(f"{base_name}_zoomed.nii.gz")
        print(f"Đã lưu: {output_path}")

        # 3. Shift
        shift_pixels = [np.random.uniform(-5, 5) for _ in range(3)]  # Dịch chuyển ±5 voxel
        shifted_data = shift(data, shift_pixels, order=3)
        normalized_shifted = resize_and_normalize_fMRI(shifted_data, affine, target_shape)
        output_path = os.path.join(output_dir, f"{base_name}_shifted.nii.gz")
        nib.save(nib.Nifti1Image(normalized_shifted, affine), output_path)
        output_files.append(f"{base_name}_shifted.nii.gz")
        print(f"Đã lưu: {output_path}")

        # 4. Rotate
        angle = np.random.uniform(-10, 10)  # Xoay ngẫu nhiên ±10 độ
        rotated_data = rotate(data, angle, axes=(0, 1), reshape=False, order=3)
        normalized_rotated = resize_and_normalize_fMRI(rotated_data, affine, target_shape)
        output_path = os.path.join(output_dir, f"{base_name}_rotated.nii.gz")
        nib.save(nib.Nifti1Image(normalized_rotated, affine), output_path)
        output_files.append(f"{base_name}_rotated.nii.gz")
        print(f"Đã lưu: {output_path}")

        return output_files

    except Exception as e:
        print(f"Lỗi khi xử lý {input_path}: {str(e)}")
        return []

# Đường dẫn đầu vào và đầu ra
input_dir = "./data"
output_dir = "./data_augmented"

# Tạo thư mục đầu ra
os.makedirs(output_dir, exist_ok=True)

# Lấy danh sách file NIfTI
files = [f for f in os.listdir(input_dir) if f.endswith((".nii", ".nii.gz"))]


for file in files:
    input_path = os.path.join(input_dir, file)
    output_files = augment_fMRI(input_path, output_dir)

In [None]:
# Create csv file
def create_augmented_csv(original_csv_path, output_dir, output_csv_path):
    df = pd.read_csv(original_csv_path)
    if not all(col in df.columns for col in ['filename', 'image_id', 'subject_id', 'label']):
        raise ValueError("CSV phải chứa các cột 'filename', 'image_id', 'subject_id', 'label'")
    new_rows = []
    
    for file in os.listdir(output_dir):
        if file.endswith((".nii", ".nii.gz")):
            base_name = file.replace('_original', '').replace('_zoomed', '').replace('_shifted', '').replace('_rotated', '').replace('.nii', '').replace('.nii.gz', '').replace('_rest', '').replace('_filtered', '')
            base_name = os.path.splitext(base_name)[0]
            image_id = base_name.split('-')
            if len(image_id) == 1:
                image_id = image_id[0]
            else:
                image_id = image_id[1]
            if image_id[0] == "I":
                image_id = image_id.replace('I', '')
            matching_rows = df[df["image_id"] == int(image_id)]
            if not matching_rows.empty:
                row = matching_rows.iloc[0]
                new_rows.append({
                    'filename': file,
                    'image_id': image_id,
                    'subject_id': row['subject_id'],
                    'label': row['label']
                })

    new_df = pd.DataFrame(new_rows)
    new_df.to_csv(output_csv_path, index=False)
    print(f"Đã tạo CSV mới: {output_csv_path}")

output_dir = "./data_augmented"
output_csv_path = "./labels_augmented.csv"
original_csv_path = "./final_nii_metadata.csv"

create_augmented_csv(original_csv_path, output_dir, output_csv_path)

In [None]:
# Split train_test_val
def split_dataset(csv_path, output_dir, train_size=1272, val_size=70, test_size=50):
    df = pd.read_csv(csv_path)
    
    # Lấy các file gốc (không augment) cho val/test
    original_files = df[df['filename'].str.contains('_original')]
    
    # Chia val/test từ file gốc
    val_test_df = original_files.sample(n=val_size + test_size, random_state=42)
    val_df, test_df = train_test_split(val_test_df, test_size=test_size, random_state=42)
    
    # Tập train bao gồm tất cả file (gốc + augment), trừ các file trong val/test
    train_df = df[~df['filename'].isin(val_test_df['filename'])]
    
    # Lấy đúng số mẫu train
    train_df = train_df.sample(n=train_size, random_state=42)
    
    # Lưu các file CSV
    os.makedirs(output_dir, exist_ok=True)
    train_df.to_csv(os.path.join(output_dir, 'train.csv'), index=False)
    val_df.to_csv(os.path.join(output_dir, 'val.csv'), index=False)
    test_df.to_csv(os.path.join(output_dir, 'test.csv'), index=False)
    print(f"Đã tạo: train.csv ({len(train_df)}), val.csv ({len(val_df)}), test.csv ({len(test_df)})")

# Chạy chia dữ liệu
output_dir = "./data_splits"
split_dataset("./labels_augmented.csv", output_dir)

Đã tạo: train.csv (1272), val.csv (70), test.csv (50)


In [51]:
def normalize_intensity(img_tensor, normalization="mean"):
    if normalization == "mean":
        mask = img_tensor.ne(0.0)
        desired = img_tensor[mask]
        mean_val, std_val = desired.mean(), desired.std()
        img_tensor = (img_tensor - mean_val) / std_val
    elif normalization == "max":
        MAX, MIN = img_tensor.max(), img_tensor.min()
        img_tensor = (img_tensor - MIN) / (MAX - MIN)
    return img_tensor

class loader(Dataset):
    def __init__(self, csv_path, data_dir, transform=None):
        self.df = pd.read_csv(csv_path)
        self.data_dir = data_dir
        self.transform = transform
        self.label_map = {1: 0, 2: 1, 3: 2}  # Ánh xạ: 1->0 (normal), 2->1 (MCI), 3->2 (AD)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        filename = self.df.iloc[index]['filename']
        label = self.label_map[self.df.iloc[index]['label']]
        img_path = os.path.join(self.data_dir, filename)
        img = nib.load(img_path).get_fdata()
        
        if img.shape != (64, 64, 48):
            zoom_factors = [t / s for t, s in zip((64, 64, 48), img.shape)]
            img = zoom(img, zoom_factors, order=3)
        
        img_tensor = torch.tensor(img, dtype=torch.float32).unsqueeze(0)
        img_tensor = normalize_intensity(img_tensor)
        
        if self.transform:
            img_tensor = self.transform(img_tensor)
        
        return img_tensor, torch.tensor(label, dtype=torch.long)

# 4. Mô hình CNN
def compute_output_size(i, K, P, S):
    return int(((i - K + 2*P) / S) + 1)

class CNN_8CL_B:
    def __init__(self):
        self.input_dim = [64, 64, 48]
        self.out_channels = [8, 8, 16, 16, 32, 32, 64, 64]
        self.in_channels = [1] + self.out_channels[:-1]
        self.n_conv = len(self.out_channels)
        self.kernels = [(3, 3, 3)] * self.n_conv
        self.pooling = [(4, 4, 4), (0, 0, 0), (3, 3, 3), (0, 0, 0), (2, 2, 2),
                        (0, 0, 0), (2, 2, 2), (0, 0, 0)]
        for i in range(self.n_conv):
            for d in range(3):
                if self.pooling[i][d] != 0:
                    self.input_dim[d] = compute_output_size(self.input_dim[d], self.pooling[i][d], 0, self.pooling[i][d])
        out = self.input_dim[0] * self.input_dim[1] * self.input_dim[2]
        self.fweights = [self.out_channels[-1] * out, 3]  # [64, 3] cho 3 lớp
        self.dropout = 0.0

class CNN(nn.Module):
    def __init__(self, param):
        super(CNN, self).__init__()
        self.embedding = nn.ModuleList()
        for i in range(param.n_conv):
            pad = tuple([int((k-1)/2) for k in param.kernels[i]])
            if param.pooling[i] != (0, 0, 0):
                self.embedding.append(nn.Sequential(
                    nn.Conv3d(param.in_channels[i], param.out_channels[i], param.kernels[i], stride=(1, 1, 1), padding=pad, bias=False),
                    nn.BatchNorm3d(param.out_channels[i]),
                    nn.ReLU(inplace=True),
                    nn.MaxPool3d(param.pooling[i], stride=param.pooling[i])))
            else:
                self.embedding.append(nn.Sequential(
                    nn.Conv3d(param.in_channels[i], param.out_channels[i], param.kernels[i], stride=(1, 1, 1), padding=pad, bias=False),
                    nn.BatchNorm3d(param.out_channels[i]),
                    nn.ReLU(inplace=True)))
        self.ReLU = nn.ReLU(inplace=True)
        self.Dropout = nn.Dropout(p=param.dropout)
        self.f = nn.ModuleList()
        for i in range(len(param.fweights)-1):
            self.f.append(nn.Linear(param.fweights[i], param.fweights[i+1]))
    
    def forward(self, x, return_conv=False):
        out = self.embedding[0](x)
        if return_conv:
            all_layers = [out]
        for i in range(1, len(self.embedding)):
            out = self.embedding[i](out)
            if return_conv:
                all_layers.append(out)
        out = out.view(out.size(0), -1)
        for fc in self.f[:-1]:
            out = fc(out)
            out = self.ReLU(out)
            out = self.Dropout(out)
        out = self.f[-1](out)
        if return_conv:
            return F.softmax(out, dim=1), all_layers
        else:
            return F.softmax(out, dim=1)

In [53]:
def train_model(model, train_loader, val_loader, device, max_epochs=200, max_errors=20, lr=0.001, weight_decay=0.01):
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    
    best_val_acc = 0.0
    patience = max_errors
    errors = 0
    history = {'train_loss': [], 'val_acc': []}
    
    for epoch in range(max_epochs):
        model.train()
        train_loss = 0.0
        with alive_bar(len(train_loader), bar='classic', spinner='arrow') as bar:
            for x, y in train_loader:
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                outputs = model(x)
                loss = criterion(outputs, y)
                loss.backward()
                optimizer.step()
                train_loss += loss.item()
                bar()
        
        train_loss /= len(train_loader)
        history['train_loss'].append(train_loss)
        
        model.eval()
        val_preds, val_labels = [], []
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                outputs = model(x)
                _, preds = torch.max(outputs, 1)
                val_preds.extend(preds.cpu().numpy())
                val_labels.extend(y.cpu().numpy())
        
        val_acc = accuracy_score(val_labels, val_preds)
        history['val_acc'].append(val_acc)
        
        print(f"Epoch {epoch+1}/{max_epochs}, Train Loss: {train_loss:.4f}, Val Acc: {val_acc:.4f}")
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), 'best_model.pt')
            errors = 0
        else:
            errors += 1
            if errors >= max_errors:
                print("Early stopping!")
                break
    
    return history

def evaluate_model(model, test_loader, device, saver_path):
    model.eval()
    labels, prob0, prob1, prob2, preds = predict(model, test_loader, device, return_prob=True)
    
    test_acc = accuracy_score(labels, preds)
    print(f"Test Accuracy: {test_acc:.4f}")
    
    plot_confusion_matrix(labels, preds, saver_path)
    
    # ROC Curve cho từng lớp
    fpr = {}
    tpr = {}
    roc_auc = {}
    for i, cls in enumerate(['normal', 'MCI', 'AD']):
        fpr[i], tpr[i], _ = roc_curve(labels == i, [prob0, prob1, prob2][i])
        roc_auc[i] = auc(fpr[i], tpr[i])
    plot_auc_curve(fpr, tpr, roc_auc, saver_path)
    
    # Classification Report
    report = pd.DataFrame({
        'Precision': [0.0, 0.0, 0.0],
        'Recall': [0.0, 0.0, 0.0],
        'F1-Score': [0.0, 0.0, 0.0]
    }, index=['normal', 'MCI', 'AD'])
    cm = confusion_matrix(labels, preds)
    for i, cls in enumerate(['normal', 'MCI', 'AD']):
        precision = cm[i, i] / cm[:, i].sum() if cm[:, i].sum() > 0 else 0
        recall = cm[i, i] / cm[i, :].sum() if cm[i, :].sum() > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        report.loc[cls] = [precision, recall, f1]
    plot_complete_report(report, saver_path, labels=['normal', 'MCI', 'AD'])

def predict(net, data_loader, device, return_prob=True):
    predictions, labels = torch.tensor([]), torch.tensor([])
    all_prob0, all_prob1, all_prob2 = torch.tensor([]), torch.tensor([]), torch.tensor([])
    print('\nModel prediction. Total number of steps:', len(data_loader))
    with alive_bar(len(data_loader), bar='classic', spinner='arrow') as bar:
        for _, (x, y) in enumerate(data_loader):
            x = x.to(device)
            output = net(x).detach().cpu()
            if return_prob:
                prob0 = output[:, 0]
                prob1 = output[:, 1]
                prob2 = output[:, 2]
                all_prob0 = torch.cat((all_prob0, prob0), dim=0)
                all_prob1 = torch.cat((all_prob1, prob1), dim=0)
                all_prob2 = torch.cat((all_prob2, prob2), dim=0)
            y_pred = torch.argmax(output, dim=1)
            predictions = torch.cat((predictions, y_pred), dim=0)
            labels = torch.cat((labels, y), dim=0)
            bar()
        if return_prob:
            return labels.numpy(), all_prob0.numpy(), all_prob1.numpy(), all_prob2.numpy(), predictions.numpy()
        else:
            return labels.numpy(), predictions.numpy()

def plot_complete_report(data, saver_path, labels=None):
    if labels:
        data['Class'] = labels
        data = data.set_index('Class')
    plt.figure(figsize=(10,3.5))
    ax = sns.heatmap(data, annot=data, fmt='.2f', square=True, cmap='Blues')
    ax.set(ylabel='Classes')
    plt.title('Classification Report')
    plt.savefig(f'{saver_path}/Evaluation.png')
    plt.close()

def plot_auc_curve(fpr, tpr, roc_auc, saver_path):
    plt.figure(figsize=(7, 7))
    for i, cls in enumerate(['normal', 'MCI', 'AD']):
        plt.plot(fpr[i], tpr[i], label=f"AUC {cls}={roc_auc[i]:.4f}")
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.title('AUC curve')
    plt.legend(loc='lower right')
    plt.savefig(f'{saver_path}/AUC.png')
    plt.close()

def plot_confusion_matrix(y_test, y_pred, saver_path):
    cm = confusion_matrix(y_test, y_pred)
    disp = ConfusionMatrixDisplay(cm, display_labels=['normal', 'MCI', 'AD'])
    disp.plot(cmap=plt.cm.Blues)
    plt.savefig(f'{saver_path}/confusion_matrix.png')
    plt.close()


# Chia dữ liệu
split_dataset(output_csv_path, "./data_splits")

# Đào tạo và đánh giá
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN(CNN_8CL_B()).to(device)
train_loader = DataLoader(loader("./data_splits/train.csv", "./data_augmented"), batch_size=32, shuffle=True)
val_loader = DataLoader(loader("./data_splits/val.csv", "./data_augmented"), batch_size=32)
test_loader = DataLoader(loader("./data_splits/test.csv", "./data_augmented"), batch_size=32)

os.makedirs("./results", exist_ok=True)
history = train_model(model, train_loader, val_loader, device)
evaluate_model(model, test_loader, device, "./results")

Đã tạo: train.csv (1272), val.csv (70), test.csv (50)
Epoch 1/200, Train Loss: 1.0591, Val Acc: 0.3286
Epoch 2/200, Train Loss: 0.9816, Val Acc: 0.4000
Epoch 3/200, Train Loss: 0.9125, Val Acc: 0.3429
Epoch 4/200, Train Loss: 0.8503, Val Acc: 0.7286
Epoch 5/200, Train Loss: 0.7846, Val Acc: 0.6857
Epoch 6/200, Train Loss: 0.7696, Val Acc: 0.8143
Epoch 7/200, Train Loss: 0.7314, Val Acc: 0.8286
Epoch 8/200, Train Loss: 0.7057, Val Acc: 0.4429
Epoch 9/200, Train Loss: 0.6860, Val Acc: 0.4714
Epoch 10/200, Train Loss: 0.6853, Val Acc: 0.7714
Epoch 11/200, Train Loss: 0.6713, Val Acc: 0.5143
Epoch 12/200, Train Loss: 0.6592, Val Acc: 0.8714
Epoch 13/200, Train Loss: 0.6358, Val Acc: 0.7000
Epoch 14/200, Train Loss: 0.6207, Val Acc: 0.7714
Epoch 15/200, Train Loss: 0.6139, Val Acc: 0.6000
Epoch 16/200, Train Loss: 0.6096, Val Acc: 0.4429
Epoch 17/200, Train Loss: 0.5954, Val Acc: 0.8857
Epoch 18/200, Train Loss: 0.5930, Val Acc: 0.9857
Epoch 19/200, Train Loss: 0.5965, Val Acc: 1.0000
Epoch