In [16]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os

# Dataset Path
dataset_dir = "C:/Users/samik/Documents/GitHub/MS-disease/SplitDataset"

# Define batch size and image size
img_size = (146, 81)  # Image is already a patch
batch_size = 32

# Load datasets
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.join(dataset_dir, "train"),
    image_size=img_size,
    batch_size=batch_size,
    color_mode="grayscale",  # ViT supports grayscale images too
    label_mode="binary"
)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.join(dataset_dir, "val"),
    image_size=img_size,
    batch_size=batch_size,
    color_mode="grayscale",
    label_mode="binary"
)

test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.join(dataset_dir, "test"),
    image_size=img_size,
    batch_size=batch_size,
    color_mode="grayscale",
    label_mode="binary"
)

# Optimize dataset for performance
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.prefetch(buffer_size=AUTOTUNE)


Found 198798 files belonging to 2 classes.
Found 24849 files belonging to 2 classes.
Found 24851 files belonging to 2 classes.


In [13]:
def mlp(x, hidden_units, dropout_rate):
    """MLP Block"""
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x
class TransformerBlock(layers.Layer):
    """Transformer Encoder Block"""
    def __init__(self, num_heads, embed_dim, mlp_dim, dropout_rate=0.1):
        super().__init__()
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)

        # ✅ Define MLP once in __init__
        self.mlp = keras.Sequential([
            layers.Dense(mlp_dim, activation=tf.nn.gelu),
            layers.Dropout(dropout_rate),
            layers.Dense(embed_dim),  # Output shape must match `embed_dim`
        ])

    def call(self, x):
        attn_output = self.attn(x, x)
        x = self.norm1(x + attn_output)
        x = self.norm2(x + self.mlp(x))  # ✅ Use pre-defined MLP model
        return x



In [18]:
def build_vit(image_size, num_heads, embed_dim, mlp_dim, num_layers):
    input_layer = layers.Input(shape=(image_size[0], image_size[1], 1))  # Grayscale input

    # Flatten image into a single token
    x = layers.Reshape((image_size[0] * image_size[1], 1))(input_layer)

    # Linear projection (Embedding Layer)
    x = layers.Dense(embed_dim)(x)

    # Positional Encoding (Optional, since images already have spatial structure)
    x = layers.LayerNormalization(epsilon=1e-6)(x)

    # Transformer Encoder Blocks
    for _ in range(num_layers):
        x = TransformerBlock(num_heads, embed_dim, mlp_dim)(x)

    # Global Average Pooling
    x = layers.GlobalAveragePooling1D()(x)

    # Classification Head
    x = layers.Dense(mlp_dim, activation="gelu")(x)
    x = layers.Dropout(0.5)(x)
    
    output_layer = layers.Dense(1, activation="sigmoid")(x)

    model = keras.Model(inputs=input_layer, outputs=output_layer)
    return model

# Define model parameters
num_heads = 4  # Number of attention heads
embed_dim = 64  # Embedding dimension
mlp_dim = 128  # Hidden layer dimension
num_layers = 2  # Number of Transformer Blocks

# Build model
vit_model = build_vit(img_size, num_heads, embed_dim, mlp_dim, num_layers)
vit_model.summary()

# Compile model
vit_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
                  loss="binary_crossentropy",
                  metrics=["accuracy"])


In [19]:
history = vit_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,  # Adjust based on performance
)
test_loss, test_acc = vit_model.evaluate(test_ds)
print(f"Test Accuracy: {test_acc * 100:.2f}%")

Epoch 1/10


ResourceExhaustedError: Graph execution error:

Detected at node StatefulPartitionedCall/functional_15_1/transformer_block_10_1/multi_head_attention_10_1/MatMul defined at (most recent call last):
<stack traces unavailable>
OOM when allocating tensor with shape[32,4,11826,11826] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator mklcpu
	 [[{{node StatefulPartitionedCall/functional_15_1/transformer_block_10_1/multi_head_attention_10_1/MatMul}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_one_step_on_iterator_38696]