# Skin Disease Classification using Vision Transformer (ViT)

This notebook implements skin disease classification using the Vision Transformer architecture.

In [None]:
import tensorflow as tf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report, confusion_matrix
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import tensorflow_addons as tfa

## 1. Vision Transformer Implementation

In [None]:
def mlp(x, hidden_units, dropout_rate):
    """Stack Dense(GELU) + Dropout pairs on top of ``x``.

    Args:
        x: Tensor fed into the first Dense layer.
        hidden_units: Iterable of layer widths; one Dense/Dropout pair is
            appended per entry, in order.
        dropout_rate: Rate handed to every Dropout layer.

    Returns:
        The output of the last Dropout layer, or ``x`` itself when
        ``hidden_units`` is empty.
    """
    output = x
    for width in hidden_units:
        hidden = tf.keras.layers.Dense(width, activation=tf.nn.gelu)(output)
        output = tf.keras.layers.Dropout(dropout_rate)(hidden)
    return output

class Patches(tf.keras.layers.Layer):
    """Split a batch of images into non-overlapping square patches.

    Each patch is flattened, so a 4-D image batch becomes a 3-D tensor of
    shape ``(batch, num_patches, patch_size * patch_size * channels)``.

    Args:
        patch_size: Side length, in pixels, of each square patch.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...).
            Forwarding them is what lets Keras rebuild the layer from its
            config when a saved model is loaded.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        """Return the flattened patches of ``images``."""
        # The batch size is read dynamically because it is usually unknown
        # (None) when the functional model is being built.
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            # Stride equals the patch size, so patches do not overlap.
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            # VALID drops any partial patch at the right/bottom border.
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        # Collapse the patch grid into a single sequence axis.
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

    def get_config(self):
        """Return the layer config, including ``patch_size``.

        Keras needs this to serialize a model containing the layer, e.g.
        when saving the whole model to an ``.h5`` file.
        """
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

class PatchEncoder(tf.keras.layers.Layer):
    """Linearly project patches and add a learned position embedding.

    Args:
        num_patches: Number of patches per image; also the size of the
            position-embedding table.
        projection_dim: Width of the projected patch vectors and of the
            position embeddings.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...).
            Forwarding them is what lets Keras rebuild the layer from its
            config when a saved model is loaded.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Kept only so that get_config() can report it.
        self.projection_dim = projection_dim
        self.projection = tf.keras.layers.Dense(units=projection_dim)
        self.position_embedding = tf.keras.layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        """Return ``projection(patch) + position_embedding``."""
        # One position index per patch: 0 .. num_patches - 1.
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # The (num_patches, projection_dim) embedding broadcasts over the
        # batch axis of the projected patches.
        encoded = self.projection(patch) + self.position_embedding(positions)
        return encoded

    def get_config(self):
        """Return the layer config, including both constructor arguments.

        Keras needs this to serialize a model containing the layer, e.g.
        when saving the whole model to an ``.h5`` file.
        """
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,
            }
        )
        return config

## 2. Build ViT Model

In [None]:
def create_vit_classifier(
    input_shape=(224, 224, 3),
    patch_size=16,
    projection_dim=64,
    num_heads=4,
    transformer_units=None,
    transformer_layers=8,
    mlp_head_units=(2048, 1024),
    num_classes=7,
):
    """Build a Vision Transformer image classifier.

    Called with no arguments, this builds the same architecture as before;
    every hyperparameter that used to be hard-coded is now a keyword
    argument whose default is the previous value.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        patch_size: Side length of the square patches.
        projection_dim: Embedding width of each patch token.
        num_heads: Number of attention heads per transformer block.
        transformer_units: Widths of the MLP inside each transformer block.
            Defaults to ``[projection_dim * 2, projection_dim]``. The last
            width must equal ``projection_dim`` so the residual ``Add`` has
            matching shapes.
        transformer_layers: Number of transformer blocks.
        mlp_head_units: Widths of the classification-head MLP.
        num_classes: Number of output classes.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping images to softmax class
        probabilities.

    Raises:
        ValueError: If ``patch_size`` is not positive or is larger than the
            image, so that no patch could be extracted.
    """
    if patch_size < 1:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    # Matches the VALID padding in Patches: partial border patches are dropped.
    num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
    if num_patches == 0:
        raise ValueError(
            f"patch_size {patch_size} is larger than the image size "
            f"{tuple(input_shape[:2])}; no patches can be extracted"
        )
    if transformer_units is None:
        transformer_units = [projection_dim * 2, projection_dim]

    inputs = tf.keras.layers.Input(shape=input_shape)
    patches = Patches(patch_size)(inputs)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    # Pre-norm transformer blocks: LayerNorm -> self-attention -> residual,
    # then LayerNorm -> MLP -> residual.
    for _ in range(transformer_layers):
        x1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = tf.keras.layers.MultiHeadAttention(
            num_heads=num_heads, key_dim=projection_dim, dropout=0.1
        )(x1, x1)
        x2 = tf.keras.layers.Add()([attention_output, encoded_patches])
        x3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        encoded_patches = tf.keras.layers.Add()([x3, x2])

    # No class token is used: all patch tokens are flattened into one vector
    # that feeds the classification head.
    representation = tf.keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    representation = tf.keras.layers.Flatten()(representation)
    representation = tf.keras.layers.Dropout(0.3)(representation)

    features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.3)
    logits = tf.keras.layers.Dense(num_classes)(features)
    outputs = tf.keras.layers.Activation('softmax')(logits)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

## 3. Training Setup

In [None]:
# Training images: multiplied by 1/255 and randomly augmented
# (rotation, horizontal/vertical shifts, horizontal flips).
train_datagen = ImageDataGenerator(
    rescale=1./255,
    fill_mode='nearest',
    horizontal_flip=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
)

# Validation images: rescaling only, no augmentation.
val_datagen = ImageDataGenerator(rescale=1./255)

# Optimizer: AdamW (Adam with decoupled weight decay) from TensorFlow Addons.
optimizer = tfa.optimizers.AdamW(learning_rate=3e-4, weight_decay=1e-4)

# Build the ViT with its default hyperparameters and compile it for
# one-hot (categorical) labels.
model = create_vit_classifier()
model.compile(
    loss='categorical_crossentropy',
    metrics=['accuracy'],
    optimizer=optimizer,
)

# Keep only the checkpoint with the best validation accuracy seen so far.
checkpoint = ModelCheckpoint(
    filepath='best_model_vit.h5',
    save_best_only=True,
    monitor='val_accuracy',
)

# Stop once validation loss has not improved for 10 epochs and roll the
# model back to the best weights.
early_stopping = EarlyStopping(
    restore_best_weights=True,
    patience=10,
    monitor='val_loss',
)

## 4. Attention Visualization

In [None]:
def plot_attention_maps(model, image, layer_name):
    """Plot the head-averaged attention map of one attention layer.

    The attention layers in the model are built without
    ``return_attention_scores=True``, so their regular output holds only
    the attended features. Indexing that output with ``[1]`` selects the
    second batch sample, not the attention weights. To obtain the real
    weights, the layer is invoked again on its original input with
    ``return_attention_scores=True`` and a probe model is built around the
    scores.

    Args:
        model: Functional Keras model containing the attention layer.
        image: Batch of input images accepted by ``model.predict``; only
            the first sample is plotted.
        layer_name: Name of a ``tf.keras.layers.MultiHeadAttention`` layer
            in ``model``.

    Raises:
        ValueError: If ``layer_name`` is not a ``MultiHeadAttention`` layer.
    """
    attention_layer = model.get_layer(layer_name)
    if not isinstance(attention_layer, tf.keras.layers.MultiHeadAttention):
        raise ValueError(
            f"Layer '{layer_name}' is not a MultiHeadAttention layer; "
            "attention weights cannot be extracted from it"
        )

    # get_input_at(0) is used rather than .input because every call of this
    # function adds another inbound node to the layer, and .input refuses to
    # answer once a layer has more than one node.
    # NOTE(review): assumes the layer was wired as self-attention
    # (query == value), as create_vit_classifier does - confirm for other models.
    attention_input = attention_layer.get_input_at(0)
    _, attention_scores = attention_layer(
        attention_input, attention_input, return_attention_scores=True
    )
    attention_model = tf.keras.Model(
        inputs=model.input,
        outputs=attention_scores
    )

    # Scores have shape (batch, heads, query_tokens, key_tokens).
    attention_weights = attention_model.predict(image)
    # imshow needs a 2-D array, so average the first sample over its heads.
    head_averaged = attention_weights[0].mean(axis=0)

    plt.figure(figsize=(15, 8))
    plt.imshow(head_averaged, cmap='viridis')
    plt.colorbar()
    plt.title(f'Attention Map from {layer_name}')
    plt.show()

def visualize_predictions(model, images, true_labels, class_names):
    """Show up to five images with their true and predicted class names.

    Args:
        model: Model whose ``predict`` returns one score vector per image.
        images: Batch of images that can be passed to ``model.predict`` and
            drawn with ``plt.imshow``.
        true_labels: One label per image, either a class index or a one-hot
            vector (the format used with ``categorical_crossentropy``).
        class_names: Mapping from class index to display name.
    """
    predictions = model.predict(images)
    shown = min(5, len(images))

    plt.figure(figsize=(15, 5))
    for i in range(shown):
        label = np.asarray(true_labels[i])
        if label.size > 1:
            # One-hot (or probability) vector: the class is the arg-max.
            true_index = int(np.argmax(label))
        else:
            # Scalar index, possibly wrapped in a length-1 array.
            true_index = int(label.reshape(-1)[0])
        pred_index = int(np.argmax(predictions[i]))

        plt.subplot(1, 5, i + 1)
        plt.imshow(images[i])
        plt.title(f'True: {class_names[true_index]}\nPred: {class_names[pred_index]}')
        plt.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Checkpoints and the final model are written under ../models/vit, which is
# created if it does not exist yet.
model_save_dir = os.path.join('..', 'models', 'vit')
os.makedirs(model_save_dir, exist_ok=True)

# Rebind `checkpoint` so the best model (highest validation accuracy) is
# written into the save directory.
checkpoint = ModelCheckpoint(
    filepath=os.path.join(model_save_dir, 'best_model_vit.h5'),
    mode='max',
    save_best_only=True,
    monitor='val_accuracy',
)

# Save the model as it currently stands.
# NOTE(review): no model.fit(...) call is visible in this notebook, so confirm
# that training has run before this line executes.
model.save(os.path.join(model_save_dir, 'final_model_vit.h5'))