In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os


# Dataset Path
dataset_dir = "C:/Users/samik/Documents/GitHub/MS-disease/SplitDataset"

# Define batch size and image size
img_size = (146, 81)  # No need to split into patches
batch_size = 32

# Load datasets (Grayscale → Convert to RGB)
def preprocess_image(image, label):
    image = tf.image.grayscale_to_rgb(image)  # Convert 1-channel grayscale to 3-channel RGB
    return image, label

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.join(dataset_dir, "train"),
    image_size=img_size,
    batch_size=batch_size,
    color_mode="grayscale",
    label_mode="binary"
).map(preprocess_image)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.join(dataset_dir, "val"),
    image_size=img_size,
    batch_size=batch_size,
    color_mode="grayscale",
    label_mode="binary"
).map(preprocess_image)

test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.join(dataset_dir, "test"),
    image_size=img_size,
    batch_size=batch_size,
    color_mode="grayscale",
    label_mode="binary"
).map(preprocess_image)

# Optimize dataset
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.prefetch(buffer_size=AUTOTUNE)

# ViT Parameters
embedding_dim = 128  # Dimension of embeddings
num_heads = 4  # Multi-head attention
mlp_units = [256, 128]  # MLP Head units
num_transformer_layers = 8  # Transformer depth
patch_size = (8, 8)  # Patching instead of flattening

# Vision Transformer Encoder
def transformer_encoder(inputs):
    x = layers.LayerNormalization(epsilon=1e-6)(inputs)
    x = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embedding_dim)(x, x)
    x = layers.Add()([x, inputs])

    x_skip = x  # Save residual for later

    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = keras.Sequential([layers.Dense(units, activation="relu") for units in mlp_units])(x)
    x = layers.Add()([x, x_skip])  # Residual connection
    return x

# Vision Transformer Model
def build_vit_model():
    inputs = layers.Input(shape=(146, 81, 3))  

    # Patch Embedding using Conv2D
    x = layers.Conv2D(embedding_dim, kernel_size=patch_size, strides=patch_size, padding="valid")(inputs)  
    x = layers.Reshape((-1, embedding_dim))(x)  # Reshape into patches

    # Transformer Layers
    for _ in range(num_transformer_layers):
        x = transformer_encoder(x)

    # Classification Head
    representation = layers.LayerNormalization(epsilon=1e-6)(x)
    representation = layers.GlobalAveragePooling1D()(representation)
    representation = keras.Sequential(
        [layers.Dense(units, activation="relu") for units in mlp_units]
    )(representation)
    outputs = layers.Dense(1, activation="sigmoid")(representation)

    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

# Compile Model
vit_model = build_vit_model()
vit_model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    loss="binary_crossentropy",
    metrics=["accuracy"]
)

# Print Model Summary
vit_model.summary()

# Train Model
history = vit_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    batch_size=batch_size
)

# Evaluate Model
test_loss, test_acc = vit_model.evaluate(test_ds)
print(f"Test Accuracy: {test_acc:.4f}")

# Save Model
vit_model.save("ms_detection_ViT_patched.h5")


Found 198798 files belonging to 2 classes.
Found 24849 files belonging to 2 classes.
Found 24851 files belonging to 2 classes.


Epoch 1/10
[1m6213/6213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11470s[0m 2s/step - accuracy: 0.8463 - loss: 0.3310 - val_accuracy: 0.8706 - val_loss: 0.2756
Epoch 2/10
[1m6213/6213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11857s[0m 2s/step - accuracy: 0.8661 - loss: 0.2765 - val_accuracy: 0.8712 - val_loss: 0.2614
Epoch 3/10
[1m6213/6213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11823s[0m 2s/step - accuracy: 0.8696 - loss: 0.2655 - val_accuracy: 0.8673 - val_loss: 0.2705
Epoch 4/10
[1m6213/6213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12451s[0m 2s/step - accuracy: 0.8837 - loss: 0.2490 - val_accuracy: 0.8916 - val_loss: 0.2368
Epoch 5/10
[1m6213/6213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12422s[0m 2s/step - accuracy: 0.8910 - loss: 0.2352 - val_accuracy: 0.8892 - val_loss: 0.2387
Epoch 6/10
[1m6213/6213[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14449s[0m 2s/step - accuracy: 0.8966 - loss: 0.2268 - val_accuracy: 0.8862 - val_loss: 0.242