<a href="https://www.kaggle.com/code/hazemegy/bone-fracture-classification-using-vit?scriptVersionId=179103589" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

In [1]:
import numpy as np
import pandas as pd 
import os
import glob

In [2]:
import os
import numpy as np
from PIL import Image
from tensorflow.keras.utils import Sequence

class CustomImageDataGenerator(Sequence):
    def __init__(self, image_paths, labels, batch_size, target_size, rescale):
        self.image_paths = image_paths
        self.labels = labels
        self.batch_size = batch_size
        self.target_size = target_size
        self.rescale = rescale
        self.indices = np.arange(len(self.image_paths))

    def __len__(self):
        return int(np.ceil(len(self.image_paths) / self.batch_size))

    def __getitem__(self, idx):
        batch_indices = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_x = [self.load_image(self.image_paths[i]) for i in batch_indices]
        batch_y = [self.labels[i] for i in batch_indices]

        return np.array(batch_x), np.array(batch_y)

    def load_image(self, image_path):
        try:
            img = Image.open(image_path)
            if img.mode != 'RGB':
                img = img.convert('RGB')
            img = img.resize(self.target_size)
            img_array = np.array(img)
            img_array = img_array * self.rescale
            return img_array
        except (OSError, ValueError) as e:
            print(f"Error loading image {image_path}: {e}")
            return np.zeros((self.target_size[0], self.target_size[1], 3))  # Return a black image of the target size

    def on_epoch_end(self):
        np.random.shuffle(self.indices)

def filter_truncated_images(directory, target_size):
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}
    image_paths = []
    for root, _, files in os.walk(directory):
        for file in files:
            if any(file.lower().endswith(ext) for ext in image_extensions):
                image_path = os.path.join(root, file)
                try:
                    img = Image.open(image_path)
                    if img.mode != 'RGB':
                        img = img.convert('RGB')
                    img = img.resize(target_size)
                    img_array = np.array(img)
                    img_array = img_array * 1./255  # Rescale the image
                    image_paths.append(image_path)
                except (OSError, ValueError) as e:
                    print(f"Skipping corrupted image: {image_path} due to {e}")
    return image_paths

# Example usage
train_image_paths = filter_truncated_images('/kaggle/input/fracture-multi-region-x-ray-data/Bone_Fracture_Binary_Classification/Bone_Fracture_Binary_Classification/train', target_size=(224, 224))
val_image_paths = filter_truncated_images('/kaggle/input/fracture-multi-region-x-ray-data/Bone_Fracture_Binary_Classification/Bone_Fracture_Binary_Classification/val', target_size=(224, 224))
test_image_paths = filter_truncated_images('/kaggle/input/fracture-multi-region-x-ray-data/Bone_Fracture_Binary_Classification/Bone_Fracture_Binary_Classification/test', target_size=(224, 224))

# Generate labels based on file paths
train_labels = [0 if 'not fractured' in path else 1 for path in train_image_paths]
val_labels = [0 if 'not fractured' in path else 1 for path in val_image_paths]
test_labels = [0 if 'not fractured' in path else 1 for path in test_image_paths]

batch_size = 32  
target_size = (224, 224) 

train_generator = CustomImageDataGenerator(train_image_paths, train_labels, batch_size, target_size, rescale=1./255)
val_generator = CustomImageDataGenerator(val_image_paths, val_labels, batch_size, target_size, rescale=1./255)
test_generator = CustomImageDataGenerator(test_image_paths, test_labels, batch_size, target_size, rescale=1./255)


2024-05-22 10:17:17.156778: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-22 10:17:17.156872: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-22 10:17:17.292727: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Skipping corrupted image: /kaggle/input/fracture-multi-region-x-ray-data/Bone_Fracture_Binary_Classification/Bone_Fracture_Binary_Classification/train/not fractured/IMG0004347.jpg due to image file is truncated (40 bytes not processed)
Skipping corrupted image: /kaggle/input/fracture-multi-region-x-ray-data/Bone_Fracture_Binary_Classification/Bone_Fracture_Binary_Classification/train/not fractured/IMG0004148.jpg due to image file is truncated (14 bytes not processed)
Skipping corrupted image: /kaggle/input/fracture-multi-region-x-ray-data/Bone_Fracture_Binary_Classification/Bone_Fracture_Binary_Classification/train/not fractured/IMG0004134.jpg due to image file is truncated (1 bytes not processed)
Skipping corrupted image: /kaggle/input/fracture-multi-region-x-ray-data/Bone_Fracture_Binary_Classification/Bone_Fracture_Binary_Classification/train/not fractured/IMG0004149.jpg due to image file is truncated (33 bytes not processed)
Skipping corrupted image: /kaggle/input/fracture-multi-re



Skipping corrupted image: /kaggle/input/fracture-multi-region-x-ray-data/Bone_Fracture_Binary_Classification/Bone_Fracture_Binary_Classification/test/not fractured/IMG0004347.jpg due to image file is truncated (40 bytes not processed)
Skipping corrupted image: /kaggle/input/fracture-multi-region-x-ray-data/Bone_Fracture_Binary_Classification/Bone_Fracture_Binary_Classification/test/not fractured/IMG0004148.jpg due to image file is truncated (14 bytes not processed)
Skipping corrupted image: /kaggle/input/fracture-multi-region-x-ray-data/Bone_Fracture_Binary_Classification/Bone_Fracture_Binary_Classification/test/not fractured/IMG0004134.jpg due to image file is truncated (1 bytes not processed)
Skipping corrupted image: /kaggle/input/fracture-multi-region-x-ray-data/Bone_Fracture_Binary_Classification/Bone_Fracture_Binary_Classification/test/not fractured/IMG0004149.jpg due to image file is truncated (33 bytes not processed)
Skipping corrupted image: /kaggle/input/fracture-multi-region

In [3]:
import tensorflow as tf
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        print(e)

In [4]:
import tensorflow as tf
from tensorflow.keras import layers

class PatchEmbedding(layers.Layer):
    def __init__(self, patch_size, embed_dim):
        super(PatchEmbedding, self).__init__()
        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.proj = layers.Dense(embed_dim)

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding='VALID'
        )
        patch_dims = self.patch_size * self.patch_size * tf.shape(images)[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return self.proj(patches)

class MultiHeadSelfAttention(layers.Layer):
    def __init__(self, num_heads, embed_dim):
        super(MultiHeadSelfAttention, self).__init__()
        self.attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)

    def call(self, inputs):
        return self.attention(inputs, inputs)

class TransformerBlock(layers.Layer):
    def __init__(self, embed_dim, num_heads, mlp_dim, dropout_rate):
        super(TransformerBlock, self).__init__()
        self.attention = MultiHeadSelfAttention(num_heads, embed_dim)
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.mlp = tf.keras.Sequential([
            layers.Dense(mlp_dim, activation='gelu'),
            layers.Dropout(dropout_rate),
            layers.Dense(embed_dim),
            layers.Dropout(dropout_rate)
        ])

    def call(self, inputs):
        x = self.layernorm1(inputs)
        x = self.attention(x)
        x = x + inputs
        x = self.layernorm2(x)
        x = self.mlp(x)
        return x + inputs

class VisionTransformer(tf.keras.Model):
    def __init__(self, image_size, patch_size, embed_dim, num_heads, num_blocks, mlp_dim, num_classes, dropout_rate=0.1):
        super(VisionTransformer, self).__init__()
        self.patch_embed = PatchEmbedding(patch_size, embed_dim)
        height, width, _ = image_size
        num_patches = (height // patch_size) * (width // patch_size)
        self.pos_embed = self.add_weight(name="pos_embed", shape=(1, num_patches + 1, embed_dim), initializer=tf.initializers.RandomNormal(stddev=0.02), trainable=True)
        self.cls_token = self.add_weight(name="cls_token", shape=(1, 1, embed_dim), initializer=tf.initializers.RandomNormal(stddev=0.02), trainable=True)
        self.dropout = layers.Dropout(dropout_rate)
        self.transformer_blocks = [TransformerBlock(embed_dim, num_heads, mlp_dim, dropout_rate) for _ in range(num_blocks)]
        self.layernorm = layers.LayerNormalization(epsilon=1e-6)
        self.classifier = layers.Dense(num_classes)

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = self.patch_embed(images)
        cls_tokens = tf.broadcast_to(self.cls_token, [batch_size, 1, tf.shape(patches)[-1]])
        x = tf.concat([cls_tokens, patches], axis=1)
        x = x + self.pos_embed
        x = self.dropout(x)
        for block in self.transformer_blocks:
            x = block(x)
        x = self.layernorm(x)
        cls_token_final = x[:, 0]  # Extract the cls_token for classification
        return self.classifier(cls_token_final)

In [5]:
# Define your model parameters (Reduced complexity)
image_size = (224, 224, 3)
patch_size = 16
embed_dim = 256
num_heads = 8
num_blocks = 6
mlp_dim = 256
num_classes = 1
dropout_rate = 0.1
learning_rate = 1e-4

vit_model = VisionTransformer(image_size, patch_size, embed_dim, num_heads, num_blocks, mlp_dim, num_classes, dropout_rate)
vit_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
vit_model.summary()

In [6]:
# Train the model
epochs = 2
vit_model.fit(train_generator, epochs=epochs, validation_data=val_generator)

Epoch 1/2


  self._warn_if_super_not_called()
I0000 00:00:1716373180.137797      67 device_compiler.h:186] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
W0000 00:00:1716373180.268941      67 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update


[1m 51/289[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m1:20[0m 339ms/step - accuracy: 0.5478 - loss: nan

W0000 00:00:1716373197.259212      67 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update


[1m289/289[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m99s[0m 236ms/step - accuracy: 0.5318 - loss: nan - val_accuracy: 0.5905 - val_loss: nan
Epoch 2/2


W0000 00:00:1716373248.282537      69 graph_launch.cc:671] Fallback to op-by-op mode because memset node breaks graph update


[1m289/289[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m58s[0m 196ms/step - accuracy: 0.5015 - loss: nan - val_accuracy: 0.5905 - val_loss: nan


<keras.src.callbacks.history.History at 0x7ea92c46ec80>