In [1]:
import os
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, classification_report

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!cd /content/drive/MyDrive/data

In [4]:
import os

os.chdir('/content/drive/MyDrive/data')

# Kiểm tra lại
print(" Thư mục hiện tại:", os.getcwd())
print(" Danh sách file:")
print(os.listdir())

 Thư mục hiện tại: /content/drive/MyDrive/data
 Danh sách file:
['MRI (1).zip', 'DTI.zip', 'test.csv', 'train.csv', 'val.csv', 'MRI', 'DTI', 'best_model.pth']


In [None]:
import zipfile
import os

# Đường dẫn đến file ZIP trong Google Drive
zip_path1 = '/content/drive/MyDrive/data/MRI (1).zip'
zip_path2 = '/content/drive/MyDrive/data/DTI.zip'

# Đường dẫn đến thư mục lưu trữ
extract_dir_1 = '/content/drive/MyDrive/data/MRI/'
extract_dir_2 = '/content/drive/MyDrive/data/DTI/'

# Tạo thư mục lưu nếu chưa có
os.makedirs(extract_dir_1, exist_ok=True)

# Giải nén file 1
with zipfile.ZipFile(zip_path1, 'r') as zip_ref:
    zip_ref.extractall(extract_dir_1)

with zipfile.ZipFile(zip_path2, 'r') as zip_ref:
    zip_ref.extractall(extract_dir_2)
print("Đã giải nén xong!")

Đã giải nén xong!


In [5]:
import pandas as pd

train_df = pd.read_csv('train.csv')
test_df = pd.read_csv('test.csv')
val_df = pd.read_csv('val.csv')

In [6]:
!pip install nibabel



In [7]:
import os
import pandas as pd

def preprocessing(df):
    # Chuẩn hóa đường dẫn
    df['dti_link'] = df['dti_link'].str.replace(r'\\', '/', regex=True)
    df['mri_link'] = df['mri_link'].str.replace(r'\\', '/', regex=True)

    # Thêm base dir
    BASE_DIR = '/content/drive/MyDrive/'
    df['dti_link'] = df['dti_link'].apply(lambda p: os.path.join(BASE_DIR, p))
    df['mri_link'] = df['mri_link'].apply(lambda p: os.path.join(BASE_DIR, p))

    # Đổi tên cột
    df = df.rename(columns={
        'ptdobyy': 'birth_year',
        'ptgender': 'gender',
        'diagnosis': 'label'
    })

    # Chuẩn hóa tuổi: 2025 - năm sinh
    df['age'] = 2025 - pd.to_numeric(df['birth_year'], errors='coerce')
    df['label'] = df['label'] - 1

    # (Tùy chọn) Xóa cột năm sinh nếu không cần dùng nữa
    df.drop(columns=['birth_year'], inplace=True)

    return df


In [8]:
train_df = preprocessing(train_df)
test_df = preprocessing(test_df)
val_df = preprocessing(val_df)

In [9]:
train_df.head()

Unnamed: 0,label,gender,dti_link,mri_link,age
0,0.0,2.0,/content/drive/MyDrive/data/DTI/002_S_1280/Axi...,/content/drive/MyDrive/data/MRI/002_S_1280/Acc...,89.0
1,0.0,2.0,/content/drive/MyDrive/data/DTI/002_S_6030/Axi...,/content/drive/MyDrive/data/MRI/002_S_6030/Acc...,73.0
2,0.0,1.0,/content/drive/MyDrive/data/DTI/016_S_6381/Axi...,/content/drive/MyDrive/data/MRI/016_S_6381/Acc...,80.0
3,0.0,1.0,/content/drive/MyDrive/data/DTI/016_S_6941/Axi...,/content/drive/MyDrive/data/MRI/016_S_6941/Acc...,66.0
4,0.0,1.0,/content/drive/MyDrive/data/DTI/016_S_6971/Axi...,/content/drive/MyDrive/data/MRI/016_S_6971/Acc...,77.0


In [10]:
val_df.head()

Unnamed: 0,label,gender,dti_link,mri_link,age
0,0.0,2.0,/content/drive/MyDrive/data/DTI/041_S_4200/Axi...,/content/drive/MyDrive/data/MRI/041_S_4200/Acc...,84.0
1,0.0,2.0,/content/drive/MyDrive/data/DTI/002_S_0413/Axi...,/content/drive/MyDrive/data/MRI/002_S_0413/Acc...,96.0
2,0.0,1.0,/content/drive/MyDrive/data/DTI/036_S_4389/Axi...,/content/drive/MyDrive/data/MRI/036_S_4389/Acc...,95.0
3,0.0,2.0,/content/drive/MyDrive/data/DTI/041_S_4200/Axi...,/content/drive/MyDrive/data/MRI/041_S_4200/Acc...,84.0
4,0.0,2.0,/content/drive/MyDrive/data/DTI/024_S_5290/Axi...,/content/drive/MyDrive/data/MRI/024_S_5290/Acc...,79.0


In [11]:
test_df.head()

Unnamed: 0,label,gender,dti_link,mri_link,age
0,0.0,2.0,/content/drive/MyDrive/data/DTI/041_S_6785/Axi...,/content/drive/MyDrive/data/MRI/041_S_6785/Acc...,71.0
1,0.0,2.0,/content/drive/MyDrive/data/DTI/016_S_4951/Axi...,/content/drive/MyDrive/data/MRI/016_S_4951/Acc...,85.0
2,0.0,1.0,/content/drive/MyDrive/data/DTI/002_S_5178/Axi...,/content/drive/MyDrive/data/MRI/002_S_5178/Acc...,81.0
3,0.0,2.0,/content/drive/MyDrive/data/DTI/041_S_6786/Axi...,/content/drive/MyDrive/data/MRI/041_S_6786/Acc...,93.0
4,0.0,2.0,/content/drive/MyDrive/data/DTI/041_S_6192/Axi...,/content/drive/MyDrive/data/MRI/041_S_6192/Acc...,91.0


In [None]:
print(train_df['label'].value_counts())
print(test_df['label'].value_counts())
print(val_df['label'].value_counts())

label
0.0    47
1.0    42
2.0    26
Name: count, dtype: int64
label
0.0    29
1.0     9
2.0     5
Name: count, dtype: int64
label
0.0    16
1.0     8
2.0     6
Name: count, dtype: int64


In [55]:
import os
import glob
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, df, target_shape=(6, 182, 182)):
        super().__init__()
        self.df = df.reset_index(drop=True)
        self.target_shape = target_shape

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        def load_nifti(path):
            # Nếu path là thư mục, tìm file .nii hoặc .nii.gz bên trong
            if os.path.isdir(path):
               nii_files = glob.glob(os.path.join(path, "*.nii*"))
               if not nii_files:
                  raise FileNotFoundError(f"No NIfTI file found in folder {path}")
               path = nii_files[0]

            arr = nib.load(path).get_fdata().astype(np.float32)
            return (arr - arr.mean()) / (arr.std() + 1e-8)

        mri_vol = load_nifti(row['mri_link'])
        dti_vol = load_nifti(row['dti_link'])

        # Resize bằng crop/pad
        def resize_vol(vol, shape):
            tz, ty, tx = shape
            z, y, x = vol.shape
            cz, cy, cx = min(z, tz), min(y, ty), min(x, tx)
            sz, sy, sx = (z-cz)//2, (y-cy)//2, (x-cx)//2
            dz, dy, dx = (tz-cz)//2, (ty-cy)//2, (tx-cx)//2
            out = np.zeros((tz, ty, tx), dtype=vol.dtype)
            out[dz:dz+cz, dy:dy+cy, dx:dx+cx] = vol[sz:sz+cz, sy:sy+cy, sx:sx+cx]
            return out

        mri_vol = resize_vol(mri_vol, self.target_shape)
        dti_vol = resize_vol(dti_vol, self.target_shape)

        # Chuyển thành tensor 5D: [C=1, D, H, W] mỗi modality
        mri_tensor = torch.from_numpy(mri_vol).unsqueeze(0)  # (1, D, H, W)
        dti_tensor = torch.from_numpy(dti_vol).unsqueeze(0)  # (1, D, H, W)

        # Demo và label
        age    = torch.tensor(row['age'], dtype=torch.float32)
        gender = torch.tensor(row['gender'], dtype=torch.float32)
        label  = torch.tensor(row['label'], dtype=torch.long)

        return {
            'mri':    mri_tensor,
            'dti':    dti_tensor,
            'age':    age,
            'gender': gender,
            'label':  label
        }


In [13]:
# Create datasets
train_dataset = CustomDataset(train_df)
val_dataset = CustomDataset(val_df)
test_dataset = CustomDataset(test_df)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [61]:
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.utils.class_weight import compute_class_weight


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


train_dataset = CustomDataset(train_df)
val_dataset = CustomDataset(val_df)
test_dataset = CustomDataset(test_df)



train_labels = [int(sample['label']) for sample in train_dataset]

all_classes = np.array([0, 1, 2])


class_weights = compute_class_weight(
    class_weight='balanced',
    classes=all_classes,
    y=train_labels
)

class_weights_tensor = torch.tensor(class_weights, dtype=torch.float).to(device)



class_sample_counts = np.bincount(train_labels)
weights_per_class = 1.0 / class_sample_counts
sample_weights = [weights_per_class[label] for label in train_labels]

sampler = WeightedRandomSampler(weights=sample_weights,
                                num_samples=len(sample_weights),
                                replacement=True)


train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)



mô hình gốc (acc = 60)

In [62]:
import torch
import torch.nn as nn
import torchvision.models as models

class AlzheimerMultiModalModel(nn.Module):
    def __init__(self, num_classes=3):
        super(AlzheimerMultiModalModel, self).__init__()

        # Sử dụng ResNet18 cho ảnh, sửa conv1 để nhận 6 kênh (thay vì 3 RGB)
        self.sMRI_backbone = models.resnet18(pretrained=False)
        self.DTI_backbone = models.resnet18(pretrained=False)

        # Sửa conv1 để nhận input có 6 kênh (tương ứng với 6 lát D trong MRI/DTI)
        self.sMRI_backbone.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.DTI_backbone.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)

        # Đầu ra backbone là 128 chiều
        self.sMRI_backbone.fc = nn.Linear(self.sMRI_backbone.fc.in_features, 128)
        self.DTI_backbone.fc = nn.Linear(self.DTI_backbone.fc.in_features, 128)

        # MLP cho dữ liệu tabular (tuổi, giới tính)
        self.tabular_mlp = nn.Sequential(
            nn.Linear(2, 32),
            nn.ReLU(),
            nn.Linear(32, 32)
        )

        # Head phân loại
        self.classifier = nn.Sequential(
            nn.Linear(128 * 2 + 32, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, smri, dti, age, gender):
        # Convert volume: [B, 1, D, H, W] → [B, D, H, W]
        smri = smri.squeeze(1)
        dti = dti.squeeze(1)

        # Trích đặc trưng ảnh
        smri_features = self.sMRI_backbone(smri)
        dti_features = self.DTI_backbone(dti)

        tabular = torch.stack((age, gender), dim=1)
        tabular_features = self.tabular_mlp(tabular)

        # Kết hợp
        combined = torch.cat((smri_features, dti_features, tabular_features), dim=1)
        out = self.classifier(combined)
        return out


Đã cải tiến nhưng chưa hiệu quả

In [63]:
import torch
import torch.nn as nn
import torchvision.models as models

class AlzheimerMultiModalModel(nn.Module):
    def __init__(self, num_classes=3):
        super(AlzheimerMultiModalModel, self).__init__()

        # Backbone ResNet18 for MRI
        self.sMRI_backbone = models.resnet18(pretrained=True)
        self.sMRI_backbone.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.sMRI_backbone.fc = nn.Sequential(
            nn.Linear(self.sMRI_backbone.fc.in_features, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.4)
        )

        # Backbone ResNet18 for DTI
        self.DTI_backbone = models.resnet18(pretrained=True)
        self.DTI_backbone.conv1 = nn.Conv2d(6, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.DTI_backbone.fc = nn.Sequential(
            nn.Linear(self.DTI_backbone.fc.in_features, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.4)
        )

        # MLP for tabular data (age, gender)
        self.tabular_mlp = nn.Sequential(
            nn.Linear(2, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Dropout(0.3)
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(128*2 + 32, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, smri, dti, age, gender):
        smri = smri.squeeze(1)  # [B, 1, D, H, W] -> [B, D, H, W]
        dti  = dti.squeeze(1)

        smri_feat = self.sMRI_backbone(smri)
        dti_feat  = self.DTI_backbone(dti)

        tabular = torch.stack((age, gender), dim=1)
        tabular_feat = self.tabular_mlp(tabular)

        combined = torch.cat((smri_feat, dti_feat, tabular_feat), dim=1)
        return self.classifier(combined)




In [64]:
model = AlzheimerMultiModalModel(num_classes=3)
model.to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1, weight=class_weights_tensor)


optimizer = optim.AdamW(model.parameters(), lr=1e-4,weight_decay=1e-5)




In [None]:
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

In [43]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AlzheimerMultiModalModel(num_classes=3)
model.to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=0.00015)

In [None]:
from tqdm import tqdm

def train_model(model, train_loader, criterion, optimizer, device, num_epochs=30):
    best_val_acc = 0.0
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch in tqdm(train_loader):
            mri    = batch['mri'].to(device).float()
            dti    = batch['dti'].to(device).float()
            age    = batch['age'].to(device).float()
            gender = batch['gender'].to(device).float()
            label  = batch['label'].to(device).long()

            optimizer.zero_grad()
            outputs = model(mri, dti, age, gender)
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += label.size(0)
            correct += (predicted == label).sum().item()

        train_acc = correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {train_acc:.4f}")

        #Validation
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
           for batch in tqdm(val_loader):
                mri    = batch['mri'].to(device).float()
                dti    = batch['dti'].to(device).float()
                age    = batch['age'].to(device).float()
                gender = batch['gender'].to(device).float()
                label  = batch['label'].to(device).long()

                outputs = model(mri, dti, age, gender)
                _, predicted = torch.max(outputs, 1)
                val_total += label.size(0)
                val_correct += (predicted == label).sum().item()

        val_acc = val_correct / val_total
        print(f"Validation Accuracy: {val_acc:.4f}")

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")
train_model(model, train_loader, criterion, optimizer, device)


100%|██████████| 4/4 [07:19<00:00, 109.92s/it]


Epoch [1/30], Loss: 1.1464, Accuracy: 0.3739


100%|██████████| 1/1 [01:49<00:00, 109.10s/it]


Validation Accuracy: 0.2667


100%|██████████| 4/4 [01:33<00:00, 23.28s/it]


Epoch [2/30], Loss: 1.0658, Accuracy: 0.5043


100%|██████████| 1/1 [00:18<00:00, 18.25s/it]


Validation Accuracy: 0.2667


100%|██████████| 4/4 [01:31<00:00, 22.91s/it]


Epoch [3/30], Loss: 0.7243, Accuracy: 0.7565


100%|██████████| 1/1 [00:15<00:00, 15.02s/it]


Validation Accuracy: 0.2667


100%|██████████| 4/4 [01:30<00:00, 22.73s/it]


Epoch [4/30], Loss: 0.4105, Accuracy: 0.8957


100%|██████████| 1/1 [00:16<00:00, 16.32s/it]


Validation Accuracy: 0.2667


100%|██████████| 4/4 [01:33<00:00, 23.47s/it]


Epoch [5/30], Loss: 0.1729, Accuracy: 0.9826


100%|██████████| 1/1 [00:15<00:00, 15.23s/it]


Validation Accuracy: 0.2667


100%|██████████| 4/4 [01:33<00:00, 23.47s/it]


Epoch [6/30], Loss: 0.0559, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.93s/it]


Validation Accuracy: 0.2333


100%|██████████| 4/4 [01:33<00:00, 23.42s/it]


Epoch [7/30], Loss: 0.0334, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.23s/it]


Validation Accuracy: 0.2667


100%|██████████| 4/4 [01:33<00:00, 23.28s/it]


Epoch [8/30], Loss: 0.0167, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.02s/it]


Validation Accuracy: 0.2667


100%|██████████| 4/4 [01:38<00:00, 24.74s/it]


Epoch [9/30], Loss: 0.0093, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.58s/it]


Validation Accuracy: 0.3333


100%|██████████| 4/4 [01:37<00:00, 24.35s/it]


Epoch [10/30], Loss: 0.0101, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.02s/it]


Validation Accuracy: 0.4667


100%|██████████| 4/4 [01:39<00:00, 24.80s/it]


Epoch [11/30], Loss: 0.0048, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.14s/it]


Validation Accuracy: 0.4667


100%|██████████| 4/4 [01:31<00:00, 22.84s/it]


Epoch [12/30], Loss: 0.0063, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.38s/it]


Validation Accuracy: 0.4667


100%|██████████| 4/4 [01:32<00:00, 23.17s/it]


Epoch [13/30], Loss: 0.0030, Accuracy: 1.0000


100%|██████████| 1/1 [00:14<00:00, 14.89s/it]


Validation Accuracy: 0.4667


100%|██████████| 4/4 [01:30<00:00, 22.61s/it]


Epoch [14/30], Loss: 0.0033, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.67s/it]


Validation Accuracy: 0.4333


100%|██████████| 4/4 [01:32<00:00, 23.09s/it]


Epoch [15/30], Loss: 0.0029, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.09s/it]


Validation Accuracy: 0.4667


100%|██████████| 4/4 [01:30<00:00, 22.51s/it]


Epoch [16/30], Loss: 0.0015, Accuracy: 1.0000


100%|██████████| 1/1 [00:19<00:00, 19.64s/it]


Validation Accuracy: 0.5000


100%|██████████| 4/4 [01:35<00:00, 23.82s/it]


Epoch [17/30], Loss: 0.0024, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.87s/it]


Validation Accuracy: 0.5000


100%|██████████| 4/4 [01:31<00:00, 22.90s/it]


Epoch [18/30], Loss: 0.0012, Accuracy: 1.0000


100%|██████████| 1/1 [00:14<00:00, 14.95s/it]


Validation Accuracy: 0.5000


100%|██████████| 4/4 [01:30<00:00, 22.73s/it]


Epoch [19/30], Loss: 0.0018, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.85s/it]


Validation Accuracy: 0.5000


100%|██████████| 4/4 [01:31<00:00, 22.93s/it]


Epoch [20/30], Loss: 0.0023, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.00s/it]


Validation Accuracy: 0.5000


100%|██████████| 4/4 [01:31<00:00, 22.88s/it]


Epoch [21/30], Loss: 0.0018, Accuracy: 1.0000


100%|██████████| 1/1 [00:16<00:00, 16.28s/it]


Validation Accuracy: 0.5667


100%|██████████| 4/4 [01:38<00:00, 24.68s/it]


Epoch [22/30], Loss: 0.0010, Accuracy: 1.0000


100%|██████████| 1/1 [00:16<00:00, 16.48s/it]


Validation Accuracy: 0.5667


100%|██████████| 4/4 [01:30<00:00, 22.73s/it]


Epoch [23/30], Loss: 0.0014, Accuracy: 1.0000


100%|██████████| 1/1 [00:14<00:00, 14.91s/it]


Validation Accuracy: 0.5667


100%|██████████| 4/4 [01:31<00:00, 22.90s/it]


Epoch [24/30], Loss: 0.0008, Accuracy: 1.0000


100%|██████████| 1/1 [00:16<00:00, 16.32s/it]


Validation Accuracy: 0.5667


100%|██████████| 4/4 [01:30<00:00, 22.70s/it]


Epoch [25/30], Loss: 0.0008, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.13s/it]


Validation Accuracy: 0.5333


100%|██████████| 4/4 [01:32<00:00, 23.03s/it]


Epoch [26/30], Loss: 0.0010, Accuracy: 1.0000


100%|██████████| 1/1 [00:16<00:00, 16.11s/it]


Validation Accuracy: 0.5000


100%|██████████| 4/4 [01:30<00:00, 22.68s/it]


Epoch [27/30], Loss: 0.0005, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.00s/it]


Validation Accuracy: 0.5000


100%|██████████| 4/4 [01:39<00:00, 24.81s/it]


Epoch [28/30], Loss: 0.0007, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.01s/it]


Validation Accuracy: 0.5333


100%|██████████| 4/4 [01:30<00:00, 22.74s/it]


Epoch [29/30], Loss: 0.0004, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.35s/it]


Validation Accuracy: 0.5333


100%|██████████| 4/4 [01:33<00:00, 23.33s/it]


Epoch [30/30], Loss: 0.0006, Accuracy: 1.0000


100%|██████████| 1/1 [00:15<00:00, 15.12s/it]

Validation Accuracy: 0.5333





đã cải tiến nhưng chưa cải thiện về kết quả chung

In [65]:
class EarlyStopping:
    def __init__(self, patience=5, delta=0.001, path='best_model.pth'):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.counter = 0
        self.early_stop = False
        self.path = path

    def __call__(self, val_acc, model):
        score = val_acc

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f"⏳ EarlyStopping counter: {self.counter} / {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(model)
            self.counter = 0

    def save_checkpoint(self, model):
        torch.save(model.state_dict(), self.path)


In [None]:
import torch
import torch.nn as nn
from tqdm import tqdm
from sklearn.metrics import accuracy_score


def freeze_backbones(model):
    for param in model.sMRI_backbone.parameters():
        param.requires_grad = False
    for param in model.DTI_backbone.parameters():
        param.requires_grad = False


def unfreeze_backbones(model):
    for param in model.sMRI_backbone.parameters():
        param.requires_grad = True
    for param in model.DTI_backbone.parameters():
        param.requires_grad = True


def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=50, patience=7, warmup_epochs=3):
    best_val_acc = 0.0
    early_stop_counter = 0

    # Scheduler
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=3, verbose=True)

    for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}:")

        # Warmup giai đoạn đầu
        if epoch == 0:
            freeze_backbones(model)
        if epoch == warmup_epochs:
            unfreeze_backbones(model)

        model.train()
        running_loss = 0.0
        total = 0
        correct = 0

        for batch in tqdm(train_loader):
            mri = batch['mri'].to(device).float()
            dti = batch['dti'].to(device).float()
            age = batch['age'].to(device).float()
            gender = batch['gender'].to(device).float()
            label = batch['label'].to(device).long()

            optimizer.zero_grad()
            outputs = model(mri, dti, age, gender)
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += label.size(0)
            correct += (predicted == label).sum().item()

        train_acc = correct / total
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {train_acc:.4f}")

        # Validation phase
        model.eval()
        val_correct = 0
        val_total = 0
        with torch.no_grad():
            for batch in tqdm(val_loader):
                mri = batch['mri'].to(device).float()
                dti = batch['dti'].to(device).float()
                age = batch['age'].to(device).float()
                gender = batch['gender'].to(device).float()
                label = batch['label'].to(device).long()

                outputs = model(mri, dti, age, gender)
                _, predicted = torch.max(outputs, 1)
                val_total += label.size(0)
                val_correct += (predicted == label).sum().item()

        val_acc = val_correct / val_total
        print(f"Validation Accuracy: {val_acc:.4f}")

        scheduler.step(val_acc)

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")
            print(" Validation improved. Model saved.")
            early_stop_counter = 0
        else:
            early_stop_counter += 1
            print(f" EarlyStopping counter: {early_stop_counter} / {patience}")
            if early_stop_counter >= patience:
                print(" Early stopping triggered.")
                break


In [67]:
train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs=50, patience=10)



Epoch 1:


100%|██████████| 4/4 [00:33<00:00,  8.38s/it]


Epoch [1/50], Loss: 1.2081, Accuracy: 0.3304


100%|██████████| 1/1 [00:12<00:00, 12.46s/it]


Validation Accuracy: 0.2667
 Validation improved. Model saved.
Epoch 2:


100%|██████████| 4/4 [00:38<00:00,  9.56s/it]


Epoch [2/50], Loss: 1.1731, Accuracy: 0.3478


100%|██████████| 1/1 [00:12<00:00, 12.21s/it]


Validation Accuracy: 0.2667
 EarlyStopping counter: 1 / 10
Epoch 3:


100%|██████████| 4/4 [00:35<00:00,  8.93s/it]


Epoch [3/50], Loss: 1.1608, Accuracy: 0.3652


100%|██████████| 1/1 [00:12<00:00, 12.39s/it]


Validation Accuracy: 0.2667
 EarlyStopping counter: 2 / 10
Epoch 4:


100%|██████████| 4/4 [00:34<00:00,  8.73s/it]


Epoch [4/50], Loss: 1.0667, Accuracy: 0.4087


100%|██████████| 1/1 [00:12<00:00, 12.52s/it]


Validation Accuracy: 0.2667
 EarlyStopping counter: 3 / 10
Epoch 5:


100%|██████████| 4/4 [00:35<00:00,  8.86s/it]


Epoch [5/50], Loss: 0.8264, Accuracy: 0.6000


100%|██████████| 1/1 [00:12<00:00, 12.24s/it]


Validation Accuracy: 0.3000
 Validation improved. Model saved.
Epoch 6:


100%|██████████| 4/4 [00:37<00:00,  9.50s/it]


Epoch [6/50], Loss: 0.6742, Accuracy: 0.6957


100%|██████████| 1/1 [00:12<00:00, 12.51s/it]


Validation Accuracy: 0.2333
 EarlyStopping counter: 1 / 10
Epoch 7:


100%|██████████| 4/4 [00:34<00:00,  8.72s/it]


Epoch [7/50], Loss: 0.5697, Accuracy: 0.8174


100%|██████████| 1/1 [00:12<00:00, 12.48s/it]


Validation Accuracy: 0.4667
 Validation improved. Model saved.
Epoch 8:


100%|██████████| 4/4 [00:38<00:00,  9.74s/it]


Epoch [8/50], Loss: 0.5042, Accuracy: 0.8696


100%|██████████| 1/1 [00:12<00:00, 12.50s/it]


Validation Accuracy: 0.4333
 EarlyStopping counter: 1 / 10
Epoch 9:


100%|██████████| 4/4 [00:34<00:00,  8.67s/it]


Epoch [9/50], Loss: 0.4730, Accuracy: 0.9391


100%|██████████| 1/1 [00:12<00:00, 12.61s/it]


Validation Accuracy: 0.4333
 EarlyStopping counter: 2 / 10
Epoch 10:


100%|██████████| 4/4 [00:34<00:00,  8.72s/it]


Epoch [10/50], Loss: 0.3722, Accuracy: 0.9304


100%|██████████| 1/1 [00:12<00:00, 12.45s/it]


Validation Accuracy: 0.4000
 EarlyStopping counter: 3 / 10
Epoch 11:


100%|██████████| 4/4 [00:36<00:00,  9.04s/it]


Epoch [11/50], Loss: 0.3645, Accuracy: 0.9478


100%|██████████| 1/1 [00:12<00:00, 12.31s/it]


Validation Accuracy: 0.3667
 EarlyStopping counter: 4 / 10
Epoch 12:


100%|██████████| 4/4 [00:35<00:00,  8.83s/it]


Epoch [12/50], Loss: 0.3105, Accuracy: 0.9826


100%|██████████| 1/1 [00:12<00:00, 12.31s/it]


Validation Accuracy: 0.3000
 EarlyStopping counter: 5 / 10
Epoch 13:


100%|██████████| 4/4 [00:34<00:00,  8.58s/it]


Epoch [13/50], Loss: 0.2773, Accuracy: 0.9826


100%|██████████| 1/1 [00:12<00:00, 12.46s/it]


Validation Accuracy: 0.4000
 EarlyStopping counter: 6 / 10
Epoch 14:


100%|██████████| 4/4 [00:35<00:00,  8.85s/it]


Epoch [14/50], Loss: 0.2860, Accuracy: 0.9739


100%|██████████| 1/1 [00:12<00:00, 12.41s/it]


Validation Accuracy: 0.4000
 EarlyStopping counter: 7 / 10
Epoch 15:


100%|██████████| 4/4 [00:36<00:00,  9.05s/it]


Epoch [15/50], Loss: 0.2376, Accuracy: 1.0000


100%|██████████| 1/1 [00:12<00:00, 12.21s/it]


Validation Accuracy: 0.3667
 EarlyStopping counter: 8 / 10
Epoch 16:


100%|██████████| 4/4 [00:34<00:00,  8.57s/it]


Epoch [16/50], Loss: 0.1974, Accuracy: 1.0000


100%|██████████| 1/1 [00:12<00:00, 12.21s/it]


Validation Accuracy: 0.3667
 EarlyStopping counter: 9 / 10
Epoch 17:


100%|██████████| 4/4 [00:33<00:00,  8.49s/it]


Epoch [17/50], Loss: 0.1899, Accuracy: 1.0000


100%|██████████| 1/1 [00:12<00:00, 12.13s/it]

Validation Accuracy: 0.3667
 EarlyStopping counter: 10 / 10
 Early stopping triggered.





In [52]:
import torch
from tqdm import tqdm
from sklearn.metrics import classification_report

# Tải mô hình đã lưu
model.load_state_dict(torch.load("best_model.pth"))
model.eval()

y_true = []
y_pred = []

with torch.no_grad():
    for batch in tqdm(test_loader):
        mri    = batch['mri'].to(device).float()
        dti    = batch['dti'].to(device).float()
        age    = batch['age'].to(device).float()
        gender = batch['gender'].to(device).float()
        label  = batch['label'].to(device).long()

        outputs = model(mri, dti, age, gender)
        _, predicted = torch.max(outputs, 1)
        y_true.extend(label.cpu().numpy())
        y_pred.extend(predicted.cpu().numpy())

# In classification report
print("\nClassification Report:")
try:
    print(classification_report(y_true, y_pred, target_names=label_encoder.classes_))
except:
    print(classification_report(y_true, y_pred))


100%|██████████| 2/2 [00:15<00:00,  7.73s/it]


Classification Report:
              precision    recall  f1-score   support

           0       0.71      0.83      0.76        29
           1       0.17      0.11      0.13         9
           2       0.33      0.20      0.25         5

    accuracy                           0.60        43
   macro avg       0.40      0.38      0.38        43
weighted avg       0.55      0.60      0.57        43




