In [28]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import os

# Dataset Path
dataset_dir = "C:/Users/samik/Documents/GitHub/MS-disease/SplitDataset"

# Define batch size and image size
img_size = (146, 81)  # Image is already a patch
batch_size = 8  # ✅ Reduce batch size to fit in memory

# Load datasets
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.join(dataset_dir, "train"),
    image_size=img_size,
    batch_size=batch_size,
    color_mode="grayscale",  # ViT supports grayscale images too
    label_mode="binary"
)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.join(dataset_dir, "val"),
    image_size=img_size,
    batch_size=batch_size,
    color_mode="grayscale",
    label_mode="binary"
)

test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    os.path.join(dataset_dir, "test"),
    image_size=img_size,
    batch_size=batch_size,
    color_mode="grayscale",
    label_mode="binary"
)

# Optimize dataset for performance
AUTOTUNE = tf.data.AUTOTUNE
train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)
test_ds = test_ds.prefetch(buffer_size=AUTOTUNE)

Found 198798 files belonging to 2 classes.
Found 24849 files belonging to 2 classes.
Found 24851 files belonging to 2 classes.


In [31]:
class TransformerBlock(layers.Layer):
    """Transformer Encoder Block with Local Self-Attention"""
    def __init__(self, num_heads, embed_dim, mlp_dim, dropout_rate=0.1, local_window_size=64):
        super().__init__()
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)

        # ✅ Use Local Self-Attention (Fixed)
        self.local_window_size = local_window_size  
        self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)

        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        self.mlp = keras.Sequential([
            layers.Dense(mlp_dim, activation=tf.nn.gelu),
            layers.Dropout(dropout_rate),
            layers.Dense(embed_dim),  
        ])

    def call(self, x):
        seq_len = tf.shape(x)[1]

        # ✅ Correct the Attention Mask Shape
        attention_mask = tf.sequence_mask(seq_len, seq_len)  # Shape: (seq_len, seq_len)
        attention_mask = tf.cast(attention_mask, dtype=tf.float32)  # Convert to float
        attention_mask = tf.expand_dims(attention_mask, axis=0)  # Shape: (1, seq_len, seq_len)

        attn_output = self.attn(x, x, attention_mask=attention_mask)  # ✅ Now works
        x = self.norm1(x + attn_output)
        x = self.norm2(x + self.mlp(x))
        
        return x


In [32]:
# ✅ Vision Transformer Model
def build_vit(image_size, num_heads, embed_dim, mlp_dim, num_layers):
    input_layer = layers.Input(shape=(image_size[0], image_size[1], 1))  # Grayscale input

    # ✅ Flatten image into a single token
    x = layers.Reshape((image_size[0] * image_size[1], 1))(input_layer)

    # ✅ Linear projection (Embedding Layer)
    x = layers.Dense(embed_dim)(x)
    x = layers.LayerNormalization(epsilon=1e-6)(x)

    # ✅ Transformer Encoder Blocks
    for _ in range(num_layers):
        x = TransformerBlock(num_heads, embed_dim, mlp_dim)(x)

    # ✅ Use Global Average Pooling Instead of Flatten
    x = layers.GlobalAveragePooling1D()(x)  

    # ✅ Classification Head
    x = layers.Dense(mlp_dim, activation="gelu")(x)
    x = layers.Dropout(0.5)(x)
    
    output_layer = layers.Dense(1, activation="sigmoid")(x)

    model = keras.Model(inputs=input_layer, outputs=output_layer)
    return model

# ✅ Model Parameters (Reduced for Efficiency)
num_heads = 2  # ✅ Reduce attention heads to lower memory
embed_dim = 16  # ✅ Reduce embedding size
mlp_dim = 32
num_layers = 4

# ✅ Build the Vision Transformer model
vit_model = build_vit(img_size, num_heads, embed_dim, mlp_dim, num_layers)
vit_model.summary()


In [33]:
# Compile model
vit_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-4),
                  loss="binary_crossentropy",
                  metrics=["accuracy"])


In [34]:
# ✅ Train without `batch_size`
history = vit_model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10
)

# ✅ Evaluate the model
test_loss, test_acc = vit_model.evaluate(test_ds)
print(f"Test Accuracy: {test_acc * 100:.2f}%")

Epoch 1/10


ResourceExhaustedError: Graph execution error:

Detected at node StatefulPartitionedCall/functional_18_1/transformer_block_17_1/multi_head_attention_17_1/MatMul defined at (most recent call last):
<stack traces unavailable>
OOM when allocating tensor with shape[8,2,11826,11826] and type float on /job:localhost/replica:0/task:0/device:CPU:0 by allocator mklcpu
	 [[{{node StatefulPartitionedCall/functional_18_1/transformer_block_17_1/multi_head_attention_17_1/MatMul}}]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. This isn't available when running in Eager mode.
 [Op:__inference_one_step_on_iterator_54005]