In [20]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import layers, models
from tqdm import tqdm
from PIL import Image

# Path & data settings
DATA_DIR = r"C:\Users\Alma\Desktop\Data\sliced_data"
AXES = ['X', 'Y', 'Z']
METRICS = ['FA', 'MD', 'RD', 'AD']
CLASSES = ['MCI', 'NC']

# Image and sequence shape settings
NUM_FRAMES = 100    # slices per sequence
IMG_SIZE = 128      # target size of images
CHANNELS = len(METRICS)  # 4 (FA, MD, RD, AD)

BATCH_SIZE = 4
EPOCHS = 10


In [21]:
def load_subject_slices(subject_path, axis):
    """
    Load slices for a single subject along a specific axis.
    Returns a 4D array: (frames, height, width, channels)
    """
    frames = []
    for slice_idx in range(NUM_FRAMES):
        slice_imgs = []
        for metric in METRICS:
            folder = f"Sliced_{axis}_{metric}"
            slice_filename = f"slice_{slice_idx:03d}.png"
            slice_path = os.path.join(subject_path, folder, slice_filename)

            if not os.path.exists(slice_path):
                print(f"Warning: Missing slice {slice_path}")
                continue

            # Load and preprocess the image
            img = Image.open(slice_path).convert('L')  # grayscale
            img = img.resize((IMG_SIZE, IMG_SIZE))     # resize image
            img = np.array(img) / 255.0                # normalize [0, 1]

            slice_imgs.append(img)

        if len(slice_imgs) == len(METRICS):
            # Stack all metrics as channels: (H, W, 4)
            slice_stack = np.stack(slice_imgs, axis=-1)
            frames.append(slice_stack)
        else:
            print(f"Warning: Incomplete metrics at slice {slice_idx} in {subject_path} for axis {axis}")

    return np.array(frames)


In [22]:
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import tensorflow as tf

AXES = ['X', 'Y', 'Z']
METRICS = ['FA', 'MD', 'RD', 'AD']
SLICE_COUNT = 100   # Number of slices per axis
IMG_SIZE = (224, 224)
NUM_CLASSES = 2


In [23]:

def load_subject_slices(subject_path):
    """Load slices for one subject."""
    subject_data = []
    axes_found = True  # We use this to skip subjects with missing data

    for axis in AXES:
        axis_slices = []
        for slice_idx in range(SLICE_COUNT):
            metrics_slices = []
            for metric in METRICS:
                # FIXED path construction
                metric_folder = os.path.join(subject_path, f"Sliced_{axis}_{metric}")

                if not os.path.exists(metric_folder):
                    print(f"❌ Missing folder: {metric_folder}")
                    axes_found = False
                    break

                slice_filename = f"slice_{slice_idx:03d}.png"
                slice_path = os.path.join(metric_folder, slice_filename)

                if not os.path.exists(slice_path):
                    print(f"❌ Missing slice {slice_path}")
                    axes_found = False
                    break

                # Load and preprocess image
                img = Image.open(slice_path).convert('L')  # grayscale
                img = img.resize(IMG_SIZE)
                img_array = np.array(img, dtype=np.float32) / 255.0
                metrics_slices.append(img_array)

            if not axes_found:
                break

            # Stack FA, MD, RD, AD channels into (224, 224, 4)
            axis_slices.append(np.stack(metrics_slices, axis=-1))

        if not axes_found:
            print(f"⚠️ Skipping subject {subject_path} due to missing data on axis {axis}")
            return None

        # Stack 100 slices along time axis (100, 224, 224, 4)
        subject_data.append(np.stack(axis_slices))

    # Final shape (3 axes, 100 slices, 224, 224, 4 metrics)
    return np.stack(subject_data)

def build_dataset(data_dir):
    """Iterate over folders and build X and y datasets."""
    X = []
    y = []

    # Get correct class folders (MCI and NC ONLY)
    class_names = sorted([
        cls for cls in os.listdir(data_dir)
        if os.path.isdir(os.path.join(data_dir, cls)) and cls in ['MCI', 'NC']
    ])

    print("✅ Classes found:", class_names)

    for label_idx, class_name in enumerate(class_names):
        class_folder = os.path.join(data_dir, class_name)
        subject_folders = [
            sf for sf in os.listdir(class_folder)
            if os.path.isdir(os.path.join(class_folder, sf))
        ]

        print(f"\n🔎 Processing class '{class_name}' with {len(subject_folders)} subjects")

        for subject_id in tqdm(subject_folders):
            subject_path = os.path.join(class_folder, subject_id)

            subject_slices = load_subject_slices(subject_path)

            if subject_slices is not None:
                X.append(subject_slices)
                y.append(label_idx)

    X = np.array(X)
    y = np.array(y)

    print(f"\n✅ Dataset loaded. X shape: {X.shape}, y shape: {y.shape}")
    return X, y


In [24]:
X, y = build_dataset(DATA_DIR)

# Split into train and validation sets
X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

print(f"\nTrain set shape: {X_train.shape}, Validation set shape: {X_val.shape}")


✅ Classes found: ['MCI', 'NC']

🔎 Processing class 'MCI' with 101 subjects


100%|████████████████████████████████████████████████████████████████████████████████| 101/101 [38:22<00:00, 22.80s/it]



🔎 Processing class 'NC' with 101 subjects


 54%|████████████████████████████████████████████                                     | 55/101 [26:01<21:45, 28.39s/it]


MemoryError: Unable to allocate 230. MiB for an array with shape (3, 100, 224, 224, 4) and data type float32

In [None]:
def create_vivit_model(input_shape, num_classes=2, projection_dim=64, transformer_layers=4, num_heads=4):
    inputs = layers.Input(shape=input_shape)

    # Tubelet embedding (3D patches)
    patch_size = (10, 16, 16)  # Time, Height, Width
    num_patches = (
        input_shape[0] // patch_size[0],
        input_shape[1] // patch_size[1],
        input_shape[2] // patch_size[2]
    )
    
    x = layers.Conv3D(
        filters=projection_dim,
        kernel_size=patch_size,
        strides=patch_size,
        padding='valid'
    )(inputs)
    
    x = layers.Reshape((num_patches[0] * num_patches[1] * num_patches[2], projection_dim))(x)

    # Positional Encoding
    positions = tf.range(start=0, limit=x.shape[1], delta=1)
    pos_encoding = layers.Embedding(input_dim=x.shape[1], output_dim=projection_dim)(positions)
    x = x + pos_encoding

    # Transformer Encoder blocks
    for _ in range(transformer_layers):
        # LayerNorm
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        
        # Multi-head Self Attention
        attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim)(x1, x1)
        x2 = layers.Add()([attention, x])

        # MLP
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        mlp = layers.Dense(projection_dim * 4, activation='gelu')(x3)
        mlp = layers.Dense(projection_dim)(mlp)
        x = layers.Add()([x2, mlp])

    # Classification Head
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = models.Model(inputs, outputs)
    return model



In [None]:
model = create_vivit_classifier()

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

history = model.fit(
    X_train, y_train_cat,
    validation_data=(X_val, y_val_cat),
    epochs=EPOCHS,
    batch_size=BATCH_SIZE
)


In [None]:
# Evaluate model on validation set
val_loss, val_acc = model.evaluate(X_val, y_val_cat)
print(f"Validation Accuracy: {val_acc:.4f}")

# Plot accuracy curves
plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Val Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy')
plt.legend()
plt.show()
