In [1]:
import os
import cv2
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
from sklearn.metrics import confusion_matrix, classification_report

In [2]:
import pandas as pd

train_df = pd.read_csv('/kaggle/input/data-medical/train.csv')
test_df = pd.read_csv('/kaggle/input/data-medical/test.csv')
val_df = pd.read_csv('/kaggle/input/data-medical/val.csv')

In [3]:
!pip install nibabel



In [4]:
import os
import pandas as pd

def preprocessing(df):
    # Chuẩn hóa đường dẫn
    df['dti_link'] = df['dti_link'].str.replace(r'\\', '/', regex=True)
    df['mri_link'] = df['mri_link'].str.replace(r'\\', '/', regex=True)

    df['mri_link'] = df['mri_link'].str.replace('/data/', '/', regex=False)
    df['dti_link'] = df['dti_link'].str.replace('/data/', '/', regex=False)


    # Thêm base dir
    dti_base_dir = '/kaggle/input/data-medical/DTI/'
    mri_base_dir = '/kaggle/input/data-medical/'
    df['dti_link'] = df['dti_link'].apply(lambda p: os.path.join(dti_base_dir, p))
    df['mri_link'] = df['mri_link'].apply(lambda p: os.path.join(mri_base_dir, p))

    # Đổi tên cột
    df = df.rename(columns={
        'ptgender': 'gender',
        'diagnosis': 'label',
        'age_at_visit': 'age'
    })

    # Chuẩn hóa nhãn nếu cần: label - 1
    if df['label'].min() == 1:
        df['label'] = df['label'] - 1
    df['gender'] = df['gender'] - 1
    return df



In [5]:
train_df = preprocessing(train_df)
test_df = preprocessing(test_df)
val_df = preprocessing(val_df)

In [6]:
train_df.head()

Unnamed: 0,label,gender,mri_link,dti_link,age
0,0.0,1.0,/kaggle/input/data-medical/data/MRI/014_S_4401...,/kaggle/input/data-medical/DTI/data/DTI/014_S_...,75.0
1,0.0,1.0,/kaggle/input/data-medical/data/MRI/007_S_6310...,/kaggle/input/data-medical/DTI/data/DTI/007_S_...,69.0
2,0.0,1.0,/kaggle/input/data-medical/data/MRI/141_S_6416...,/kaggle/input/data-medical/DTI/data/DTI/141_S_...,72.0
3,0.0,1.0,/kaggle/input/data-medical/data/MRI/141_S_0767...,/kaggle/input/data-medical/DTI/data/DTI/141_S_...,84.0
4,0.0,1.0,/kaggle/input/data-medical/data/MRI/016_S_4951...,/kaggle/input/data-medical/DTI/data/DTI/016_S_...,77.0


In [7]:
val_df.head()

Unnamed: 0,label,gender,mri_link,dti_link,age
0,0.0,1.0,/kaggle/input/data-medical/data/MRI/130_S_5258...,/kaggle/input/data-medical/DTI/data/DTI/130_S_...,84.0
1,0.0,1.0,/kaggle/input/data-medical/data/MRI/014_S_6145...,/kaggle/input/data-medical/DTI/data/DTI/014_S_...,73.0
2,0.0,1.0,/kaggle/input/data-medical/data/MRI/070_S_6548...,/kaggle/input/data-medical/DTI/data/DTI/070_S_...,56.0
3,0.0,0.0,/kaggle/input/data-medical/data/MRI/041_S_5141...,/kaggle/input/data-medical/DTI/data/DTI/041_S_...,81.0
4,0.0,1.0,/kaggle/input/data-medical/data/MRI/116_S_4483...,/kaggle/input/data-medical/DTI/data/DTI/116_S_...,76.0


In [8]:
test_df.head()

Unnamed: 0,label,gender,mri_link,dti_link,age
0,0.0,1.0,/kaggle/input/data-medical/data/MRI/068_S_2184...,/kaggle/input/data-medical/DTI/data/DTI/068_S_...,89.0
1,0.0,1.0,/kaggle/input/data-medical/data/MRI/003_S_4872...,/kaggle/input/data-medical/DTI/data/DTI/003_S_...,78.0
2,0.0,1.0,/kaggle/input/data-medical/data/MRI/014_S_4576...,/kaggle/input/data-medical/DTI/data/DTI/014_S_...,78.0
3,0.0,1.0,/kaggle/input/data-medical/data/MRI/116_S_6133...,/kaggle/input/data-medical/DTI/data/DTI/116_S_...,71.0
4,0.0,1.0,/kaggle/input/data-medical/data/MRI/023_S_0031...,/kaggle/input/data-medical/DTI/data/DTI/023_S_...,90.0


In [9]:
print(train_df['label'].value_counts())
print(test_df['label'].value_counts())
print(val_df['label'].value_counts())

label
0.0    413
1.0    357
2.0    241
Name: count, dtype: int64
label
1.0    97
0.0    73
2.0    39
Name: count, dtype: int64
label
1.0    83
0.0    79
2.0    36
Name: count, dtype: int64


In [10]:
print(train_df['gender'].value_counts())
print(test_df['gender'].value_counts())
print(val_df['gender'].value_counts())

gender
0.0    520
1.0    491
Name: count, dtype: int64
gender
1.0    106
0.0    103
Name: count, dtype: int64
gender
0.0    115
1.0     83
Name: count, dtype: int64


In [11]:
import os
import glob
import nibabel as nib
import numpy as np
import torch
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, df, target_shape=(6, 182, 182)):
        super().__init__()
        self.df = df.reset_index(drop=True)
        self.target_shape = target_shape

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        def load_nifti(path):
            # Nếu path là thư mục, tìm file .nii hoặc .nii.gz bên trong
            if os.path.isdir(path):
               nii_files = glob.glob(os.path.join(path, "*.nii*"))
               if not nii_files:
                  raise FileNotFoundError(f"No NIfTI file found in folder {path}")
               path = nii_files[0]

            arr = nib.load(path).get_fdata().astype(np.float32)
            return (arr - arr.mean()) / (arr.std() + 1e-8)

        mri_vol = load_nifti(row['mri_link'])
        dti_vol = load_nifti(row['dti_link'])

        # Resize bằng crop/pad
        def resize_vol(vol, shape):
            tz, ty, tx = shape
            z, y, x = vol.shape
            cz, cy, cx = min(z, tz), min(y, ty), min(x, tx)
            sz, sy, sx = (z-cz)//2, (y-cy)//2, (x-cx)//2
            dz, dy, dx = (tz-cz)//2, (ty-cy)//2, (tx-cx)//2
            out = np.zeros((tz, ty, tx), dtype=vol.dtype)
            out[dz:dz+cz, dy:dy+cy, dx:dx+cx] = vol[sz:sz+cz, sy:sy+cy, sx:sx+cx]
            return out

        mri_vol = resize_vol(mri_vol, self.target_shape)
        dti_vol = resize_vol(dti_vol, self.target_shape)

        # Chuyển thành tensor 5D: [C=1, D, H, W] mỗi modality
        mri_tensor = torch.from_numpy(mri_vol).unsqueeze(0)  # (1, D, H, W)
        dti_tensor = torch.from_numpy(dti_vol).unsqueeze(0)  # (1, D, H, W)

        # Demo và label
        age    = torch.tensor(row['age'], dtype=torch.float32)
        gender = torch.tensor(row['gender'], dtype=torch.float32)
        label  = torch.tensor(row['label'], dtype=torch.long)

        return {
            'mri':    mri_tensor,
            'dti':    dti_tensor,
            'age':    age,
            'gender': gender,
            'label':  label
        }


In [12]:
from torch.utils.data import DataLoader, Subset
from collections import defaultdict
import numpy as np
import random

def balance_train_dataset(train_dataset, batch_size=16, seed=42, num_workers=4):
    random.seed(seed)
    label_to_indices = defaultdict(list)

    for idx in range(len(train_dataset)):
        label = int(train_dataset[idx]['label']) 
        label_to_indices[label].append(idx)

    min_class_size = min(len(idxs) for idxs in label_to_indices.values())
    print(f"Undersampling to {min_class_size} samples per class")

    balanced_indices = []
    for label, indices in label_to_indices.items():
        sampled = random.sample(indices, min_class_size)
        balanced_indices.extend(sampled)

    random.shuffle(balanced_indices)

    balanced_subset = Subset(train_dataset, balanced_indices)
    balanced_loader = DataLoader(balanced_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    return balanced_loader


In [13]:
# Create datasets
train_dataset = CustomDataset(train_df)
val_dataset = CustomDataset(val_df)
test_dataset = CustomDataset(test_df)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers = 4)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers = 4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers = 4)

In [14]:
# Create datasets
train_dataset = CustomDataset(train_df)
val_dataset = CustomDataset(val_df)
test_dataset = CustomDataset(test_df)

# Create dataloaders
train_loader = balance_train_dataset(train_dataset, batch_size=16,  num_workers = 4)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers = 4)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers = 4)

FileNotFoundError: No such file or no access: '/kaggle/input/data-medical/data/MRI/014_S_4401/Accelerated_Sagittal_MPRAGE_ND/2019-09-09_08_18_44.0/I1224466'

In [None]:
import torch
import torch.nn as nn
import torchvision.models.video as models


class TabularMLP(nn.Module):
    def __init__(self, in_features=2, hidden_dim=64):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU()
        )

    def forward(self, x):
        return self.model(x)


class AttentionFusion(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.attn_layer = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.Tanh(),
            nn.Linear(input_dim, input_dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        alpha = self.attn_layer(x)
        return x * alpha


class MultimodalAlzheimerClassifier(nn.Module):
    def __init__(self, num_classes=3, tabular_dim=2, backbone_out_dim=512, freeze_backbone=True):
        super().__init__()

        # MRI backbone
        self.mri_backbone = models.r3d_18(pretrained=True)
        self.mri_backbone.stem[0] = nn.Conv3d(1, 64, kernel_size=(3,7,7), stride=(1,2,2), padding=(1,3,3), bias=False)
        self.mri_backbone.fc = nn.Identity()

        # DTI backbone
        self.dti_backbone = models.r3d_18(pretrained=True)
        self.dti_backbone.stem[0] = nn.Conv3d(1, 64, kernel_size=(3,7,7), stride=(1,2,2), padding=(1,3,3), bias=False)
        self.dti_backbone.fc = nn.Identity()

        if freeze_backbone:
            for param in self.mri_backbone.parameters():
                param.requires_grad = False
            for param in self.dti_backbone.parameters():
                param.requires_grad = False

        self.tabular_branch = TabularMLP(in_features=tabular_dim, hidden_dim=64)

        self.fusion_dim = 2 * backbone_out_dim + 64
        self.attn_fusion = AttentionFusion(self.fusion_dim)

        self.classifier = nn.Sequential(
            nn.Linear(self.fusion_dim, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes)
        )

    def forward(self, mri, dti, age, gender):
        tabular = torch.stack([age, gender], dim=1)

        mri_feat = self.mri_backbone(mri)      # (B, 512)
        dti_feat = self.dti_backbone(dti)      # (B, 512)
        tab_feat = self.tabular_branch(tabular)  # (B, 64)

        fused = torch.cat([mri_feat, dti_feat, tab_feat], dim=1)  # (B, 1088)
        fused = self.attn_fusion(fused)

        out = self.classifier(fused)
        return out

    def unfreeze_backbones(self):
        for param in self.mri_backbone.parameters():
            param.requires_grad = True
        for param in self.dti_backbone.parameters():
            param.requires_grad = True



In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
model = MultimodalAlzheimerClassifier(num_classes=3, freeze_backbone=True).to(device)
criterion = nn.CrossEntropyLoss(label_smoothing=0.2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)


In [None]:
import numpy as np
import torch

class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0, path='/kaggle/working/checkpoint.pt', trace_func=print):
 
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):
        score = -val_loss  # lower val_loss is better

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)

        elif score < self.best_score + self.delta:
            self.counter += 1
            if self.verbose:
                self.trace_func(f'EarlyStopping counter: {self.counter} / {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True

        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Lưu model khi validation loss giảm'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss


In [None]:
from tqdm import tqdm
import torch
import numpy as np
from sklearn.metrics import confusion_matrix

def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs=50, patience=15):
    best_val_acc = 0.0
    early_stopping = EarlyStopping(patience=patience, verbose=True)

    train_losses = []
    train_accuracies = []
    val_losses = []
    val_accuracies = []

    y_true_best = []
    y_pred_best = []

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
            mri    = batch['mri'].to(device).float()
            dti    = batch['dti'].to(device).float()
            age    = batch['age'].to(device).float()
            gender = batch['gender'].to(device).float()
            label  = batch['label'].to(device).long()

            optimizer.zero_grad()
            outputs = model(mri, dti, age, gender)
            loss = criterion(outputs, label)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += label.size(0)
            correct += (predicted == label).sum().item()

        train_loss = running_loss / len(train_loader)
        train_acc = correct / total
        train_losses.append(train_loss)
        train_accuracies.append(train_acc)

        print(f"\nEpoch [{epoch+1}/{num_epochs}]")
        print(f"Train Loss: {train_loss:.4f} | Train Accuracy: {train_acc:.4f}")

        # Validation
        model.eval()
        val_loss_total = 0.0
        val_correct = 0
        val_total = 0
        y_true = []
        y_pred = []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
                mri    = batch['mri'].to(device).float()
                dti    = batch['dti'].to(device).float()
                age    = batch['age'].to(device).float()
                gender = batch['gender'].to(device).float()
                label  = batch['label'].to(device).long()

                outputs = model(mri, dti, age, gender)
                loss = criterion(outputs, label)
                val_loss_total += loss.item()

                _, predicted = torch.max(outputs, 1)
                val_total += label.size(0)
                val_correct += (predicted == label).sum().item()

                y_true.extend(label.cpu().numpy())
                y_pred.extend(predicted.cpu().numpy())

        val_loss = val_loss_total / len(val_loader)
        val_acc = val_correct / val_total
        val_losses.append(val_loss)
        val_accuracies.append(val_acc)

        print(f"Val Loss: {val_loss:.4f} | Val Accuracy: {val_acc:.4f}")

        scheduler.step(val_loss)  

        # Lưu mô hình tốt nhất
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            y_true_best = y_true
            y_pred_best = y_pred
            torch.save(model.state_dict(), "/kaggle/working/best_model.pth")


        early_stopping(val_loss, model)
        if early_stopping.early_stop:
            print(" Early stopping!")
            break

        torch.cuda.empty_cache()  

    return train_losses, train_accuracies, val_losses, val_accuracies, y_true_best, y_pred_best


In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
import numpy as np

def plot_results(train_losses, train_accuracies, val_losses, val_accuracies, y_true, y_pred, class_names=None):
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(14, 6))

    # Plot loss
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label="Train Loss")
    plt.plot(epochs, val_losses, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training & Validation Loss")
    plt.legend()
    plt.grid(True)

    # Plot accuracy
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, label="Train Acc")
    plt.plot(epochs, val_accuracies, label="Val Acc")
    plt.xlabel("Epoch")
    plt.ylabel("Accuracy")
    plt.title("Training & Validation Accuracy")
    plt.legend()
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    # Confusion Matrix
    cm = confusion_matrix(y_true, y_pred)

    if class_names is None:
        class_names = [str(c) for c in np.unique(y_true)]

    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    fig, ax = plt.subplots(figsize=(6, 6))
    disp.plot(cmap=plt.cm.Blues, ax=ax, values_format='d')
    plt.title("Confusion Matrix")
    plt.grid(False)
    plt.show()

    # Classification report
    print("Classification Report:")
    print(classification_report(y_true, y_pred, target_names=class_names))

In [None]:
train_losses, train_accs, val_losses, val_accs, y_true, y_pred = train_model(
    model, train_loader, val_loader, criterion, optimizer, scheduler, device, num_epochs=50, patience=15
)

In [None]:
class_names = ['CN', 'MCI', 'AD']
plot_results(train_losses, train_accs, val_losses, val_accs, y_true, y_pred, class_names)