Importing All Required Libraries

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Dropout, LayerNormalization, Add, GlobalAveragePooling1D
from tensorflow.keras.models import Model
from tensorflow.keras.layers import MultiHeadAttention
from tensorflow.keras import mixed_precision
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.optimizers import Adam

Downloading the Eye Disease Dataset

In [None]:
# Install gdown if not already installed
!pip install -q gdown

# Download the zip file using file ID from Google Drive
!gdown --id 1vHCvd9ZFkY9lfNwUnnMUbDQZjjpby8lv -O disease.zip

# Extract the ZIP file
import zipfile
import os

zip_file = "disease.zip"  # Name of the downloaded file

# Make sure output directory exists
output_dir = "/content/extracted_file"
os.makedirs(output_dir, exist_ok=True)

# Extract the contents
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
    zip_ref.extractall(output_dir)

print("Extraction completed!")


Downloading...
From (original): https://drive.google.com/uc?id=1QwItNnETzlWHgZ89EbHFYg04ijkDszF6
From (redirected): https://drive.google.com/uc?id=1QwItNnETzlWHgZ89EbHFYg04ijkDszF6&confirm=t&uuid=dd8cae99-b0c3-43cc-b657-13e52c96e3d9
To: /kaggle/working/disease.zip
100%|██████████████████████████████████████| 3.85G/3.85G [00:44<00:00, 85.9MB/s]
Extraction completed!


Vision Transformer Architecture and Training

In [None]:
# Mixed precision: layers compute in float16 while their variables stay float32.
# NOTE(review): the logged GPU is a Tesla P100 (compute capability 6.0); TensorFlow's
# mixed-precision guide says the large speedups need compute capability >= 7.0 — confirm
# this policy is still worth it on that hardware.
precision_policy = mixed_precision.Policy("mixed_float16")
mixed_precision.set_global_policy(precision_policy)

# Patch Embedding Layer
class PatchEmbedding(tf.keras.layers.Layer):
    """Split images into non-overlapping patches and linearly project each patch.

    A Conv2D whose kernel size and stride both equal ``patch_size`` (``'valid'``
    padding) projects every patch to ``projection_dim`` channels. The resulting
    spatial grid of patches is then flattened into a sequence of tokens.

    Args:
        patch_size: Side length of each square patch, in pixels.
        projection_dim: Number of channels each patch is projected to.
        **kwargs: Standard ``Layer`` arguments (``name``, ``dtype``, ...). They are
            forwarded to the base class so the layer can be named and rebuilt from
            its config.

    Input:
        Image batch ``(batch, height, width, channels)`` (Conv2D channels_last default).

    Output:
        ``(batch, num_patches, projection_dim)``, where
        ``num_patches = (height // patch_size) * (width // patch_size)``.
    """

    def __init__(self, patch_size, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.projection_dim = projection_dim
        self.projection = tf.keras.layers.Conv2D(filters=projection_dim,
                                                 kernel_size=patch_size,
                                                 strides=patch_size,
                                                 padding='valid')
        # Collapse the 2-D grid of patches into a single token axis.
        self.flatten = tf.keras.layers.Reshape((-1, projection_dim))

    def call(self, images):
        x = self.projection(images)
        return self.flatten(x)

    def get_config(self):
        # Explicit config so saved models (e.g. .keras checkpoints) can rebuild this layer.
        config = super().get_config()
        config.update({
            "patch_size": self.patch_size,
            "projection_dim": self.projection_dim,
        })
        return config

class PositionalEncoding(tf.keras.layers.Layer):
    """Prepend a learnable CLS token, then add learnable position embeddings.

    Args:
        num_patches: Number of patch tokens per image. The layer learns
            ``num_patches + 1`` position vectors, one extra for the CLS token.
        projection_dim: Width of each token embedding.
        **kwargs: Standard ``Layer`` arguments (``name``, ``dtype``, ...). They are
            forwarded to the base class so the layer can be named and rebuilt from
            its config.

    Input:
        Token sequence expected to be ``(batch, num_patches, projection_dim)``.
        Any other token count or width cannot broadcast against the position
        embeddings.

    Output:
        ``(batch, num_patches + 1, projection_dim)``. Token index 0 is the CLS token.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        self.pos_embedding = self.add_weight(
            name="pos_embed",
            shape=(1, num_patches + 1, projection_dim),
            initializer="random_normal",
            trainable=True
        )
        self.cls_token = self.add_weight(
            name="cls_token",
            shape=(1, 1, projection_dim),
            initializer="random_normal",
            trainable=True
        )

    def call(self, x):
        batch_size = tf.shape(x)[0]
        # Give every example in the batch its own copy of the shared CLS token.
        cls_tokens = tf.broadcast_to(self.cls_token, [batch_size, 1, tf.shape(x)[-1]])
        x = tf.concat([cls_tokens, x], axis=1)
        # pos_embedding has a leading axis of 1, so it broadcasts over the batch.
        return x + self.pos_embedding

    def get_config(self):
        # Explicit config so saved models (e.g. .keras checkpoints) can rebuild this layer.
        config = super().get_config()
        config.update({
            "num_patches": self.num_patches,
            "projection_dim": self.projection_dim,
        })
        return config


# Transformer Encoder Block
def transformer_encoder(x, projection_dim, num_heads, mlp_dim, dropout_rate, key_dim=None):
    """Apply one pre-norm Transformer encoder block to a token sequence.

    Structure:
        LayerNorm -> multi-head self-attention -> residual add, then
        LayerNorm -> Dense(mlp_dim, GELU) -> Dropout -> Dense(projection_dim)
        -> Dropout -> residual add.

    Args:
        x: Token tensor. Its last dimension must equal ``projection_dim`` so the
            residual adds line up.
        projection_dim: Token width; the MLP projects back to this size.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward MLP.
        dropout_rate: Dropout rate applied after each of the two MLP Dense layers.
        key_dim: Per-head query/key size. Defaults to ``projection_dim``, which is
            the original behaviour and gives every head the full token width.
            The usual ViT choice is ``projection_dim // num_heads``, which needs
            far fewer attention parameters.

    Returns:
        Tensor with the same shape as ``x``.
    """
    if key_dim is None:
        key_dim = projection_dim

    # Self-attention sub-block (pre-norm, residual).
    x1 = LayerNormalization(epsilon=1e-6)(x)
    attention_output = MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)(x1, x1)
    x2 = Add()([x, attention_output])

    # Feed-forward MLP sub-block (pre-norm, residual).
    x3 = LayerNormalization(epsilon=1e-6)(x2)
    x3 = Dense(mlp_dim, activation='gelu')(x3)
    x3 = Dropout(dropout_rate)(x3)
    x3 = Dense(projection_dim)(x3)
    x3 = Dropout(dropout_rate)(x3)
    return Add()([x2, x3])

# Build ViT
def build_vit(input_shape=(224, 224, 3), num_classes=7,
              patch_size=16, projection_dim=256, transformer_layers=12,
              num_heads=8, mlp_dim=512, dropout_rate=0.1,
              head_dropout_rate=0.3):
    """Build a Vision Transformer classifier as a Keras functional ``Model``.

    The model does NOT rescale pixels itself. The training pipeline in this
    notebook feeds images already divided by 255, so inference inputs must be
    scaled the same way.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        num_classes: Number of output classes (softmax units).
        patch_size: Side length of each square patch.
        projection_dim: Token embedding width.
        transformer_layers: Number of stacked encoder blocks.
        num_heads: Attention heads per encoder block.
        mlp_dim: Hidden width of each encoder block's MLP.
        dropout_rate: Dropout rate inside the encoder MLPs.
        head_dropout_rate: Dropout rate applied to the CLS token before the
            classifier. The default of 0.3 matches the original hard-coded value.

    Returns:
        An uncompiled ``Model`` that maps images to class probabilities.
    """
    inputs = Input(shape=input_shape)

    # Patch Embedding + Positional Encoding.
    patch_embed = PatchEmbedding(patch_size, projection_dim)(inputs)
    # Must agree with the 'valid' Conv2D patch grid produced by PatchEmbedding.
    num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
    encoded = PositionalEncoding(num_patches, projection_dim)(patch_embed)

    # Transformer Encoder
    for _ in range(transformer_layers):
        encoded = transformer_encoder(encoded, projection_dim, num_heads, mlp_dim, dropout_rate)

    # Classification head
    x = LayerNormalization(epsilon=1e-6)(encoded)
    x = x[:, 0, :]  # CLS token (prepended at index 0 by PositionalEncoding)
    x = Dropout(head_dropout_rate)(x)
    # Keep the softmax output in float32 even under the mixed_float16 policy.
    outputs = Dense(num_classes, activation='softmax', dtype='float32')(x)

    return Model(inputs, outputs)

# ----------------------------
# Training Pipeline
# ----------------------------

# Image size, batch size, data location and reproducibility settings
img_size = (224, 224)
batch_size = 32
data_path = "/content/extracted_file/Retinal Fundus Images - Copy/train"
seed = 42
use_early_stopping = False  # set True to stop once val_loss stops improving

# Seed Python, NumPy and TensorFlow RNGs so weight init and augmentation are repeatable
tf.keras.utils.set_random_seed(seed)

# Datasets: one directory scan, split 90/10 into disjoint training/validation subsets
train_ds, val_ds = tf.keras.utils.image_dataset_from_directory(
    data_path,
    validation_split=0.1,
    subset='both',
    seed=seed,
    image_size=img_size,
    batch_size=batch_size,
    label_mode='categorical'
)

# Augmentation + Normalization (the model expects inputs already scaled to [0, 1])
data_augmentation = tf.keras.Sequential([
    tf.keras.layers.Rescaling(1./255),
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomRotation(0.1),
    tf.keras.layers.RandomZoom(0.1),
])
# Validation images are only rescaled; build the layer once, not inside the map lambda
rescale = tf.keras.layers.Rescaling(1./255)

train_ds = (
    train_ds
    .map(lambda x, y: (data_augmentation(x), y), num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)
val_ds = (
    val_ds
    .map(lambda x, y: (rescale(x), y), num_parallel_calls=tf.data.AUTOTUNE)
    .prefetch(tf.data.AUTOTUNE)
)

# Build and compile
vit_model = build_vit()
vit_model.compile(optimizer=Adam(1e-4),
                  loss=tf.keras.losses.CategoricalCrossentropy(label_smoothing=0.1),
                  metrics=['accuracy'])

# Train
callbacks = [ModelCheckpoint("/content/best_vit_model.keras", save_best_only=True)]
if use_early_stopping:
    callbacks.append(EarlyStopping(patience=5, restore_best_weights=True))

history = vit_model.fit(train_ds, validation_data=val_ds, epochs=30, callbacks=callbacks)

# Evaluate. Without early stopping, vit_model holds the LAST epoch's weights; the
# best-by-val_loss model is the separate checkpoint file written above.
loss, acc = vit_model.evaluate(val_ds)
print(f"Final Accuracy: {acc * 100:.2f}%")
print(f"Best val_accuracy during training: {max(history.history['val_accuracy']) * 100:.2f}%")


2025-04-15 10:14:32.553202: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744712072.749709      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744712072.805162      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Found 55446 files belonging to 7 classes.
Using 49902 files for training.


I0000 00:00:1744712088.330170      31 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:00:04.0, compute capability: 6.0


Found 55446 files belonging to 7 classes.
Using 5544 files for validation.
Epoch 1/30


I0000 00:00:1744712148.670470     105 service.cc:148] XLA service 0x7bbfb4004ca0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1744712148.671365     105 service.cc:156]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
I0000 00:00:1744712154.505427     105 cuda_dnn.cc:529] Loaded cuDNN version 90300
I0000 00:00:1744712181.939858     105 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.


[1m1560/1560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m839s[0m 481ms/step - accuracy: 0.2794 - loss: 1.9586 - val_accuracy: 0.5422 - val_loss: 1.3059
Epoch 2/30
[1m1560/1560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m718s[0m 460ms/step - accuracy: 0.5476 - loss: 1.3175 - val_accuracy: 0.6533 - val_loss: 1.1513
Epoch 3/30
[1m1560/1560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m716s[0m 458ms/step - accuracy: 0.6138 - loss: 1.2022 - val_accuracy: 0.6595 - val_loss: 1.1105
Epoch 4/30
[1m1560/1560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m713s[0m 457ms/step - accuracy: 0.6427 - loss: 1.1440 - val_accuracy: 0.6732 - val_loss: 1.0852
Epoch 5/30
[1m1560/1560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m714s[0m 457ms/step - accuracy: 0.6701 - loss: 1.0942 - val_accuracy: 0.7098 - val_loss: 1.0121
Epoch 6/30
[1m1560/1560[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m712s[0m 456ms/step - accuracy: 0.6883 - loss: 1.0557 - val_accuracy: 0.7087 - val_loss: 1.0049
Epo

Save the Vision Transformer Model

In [None]:
# Save the entire model (architecture + weights) in the native Keras format.
# NOTE: vit_model holds the weights from the end of training; the best-by-val_loss
# model is the separate ModelCheckpoint file written during training.
model_path = '/content/vit_model.keras'
vit_model.save(model_path)