In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

In [2]:
img_size = (133, 133)
batch_size = 32

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "Solution Dataset/train",
    image_size=img_size,
    batch_size=batch_size,
    label_mode='int'
)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "Solution Dataset/val",
    image_size=img_size,
    batch_size=batch_size,
    label_mode='int'
)

# Load the test dataset
test_ds = tf.keras.preprocessing.image_dataset_from_directory(
    "Solution Dataset/test/",
    image_size=img_size,
    batch_size=batch_size,
    label_mode='int'
)

Found 284900 files belonging to 2 classes.
Found 60157 files belonging to 2 classes.
Found 60157 files belonging to 2 classes.


In [3]:
def vit_preprocess_fn(image, label):
    image = tf.cast(image, tf.float32) / 127.5 - 1.0
    return image, label

In [4]:
AUTOTUNE = tf.data.AUTOTUNE # Prefetch the datasets for performance. This allows the data loading of the next batch to be done in parallel with model training from the current batch.

# Preprocess the datasets for ViT. Caching in RAM.
ViT_train_ds = train_ds.map(vit_preprocess_fn, num_parallel_calls=AUTOTUNE) \
                       .cache() \
                       .shuffle(1000) \
                       .prefetch(AUTOTUNE)

ViT_val_ds = val_ds.map(vit_preprocess_fn, num_parallel_calls=AUTOTUNE) \
                       .cache() \
                       .prefetch(AUTOTUNE)

ViT_test_ds = test_ds.map(vit_preprocess_fn, num_parallel_calls=AUTOTUNE) \
                       .cache() \
                       .prefetch(AUTOTUNE)

In [5]:
#ViT Setup
from tensorflow.keras import layers, models
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint

In [6]:
# Positional Encoding Layer
class PositionalEncoding(tf.keras.layers.Layer):
    def __init__(self, max_steps, max_dims, **kwargs):
        super().__init__(**kwargs)
        if max_dims % 2 == 1:
            max_dims += 1
        self.max_steps = max_steps
        self.max_dims = max_dims

        p = np.arange(max_steps)[:, np.newaxis]
        i = np.arange(max_dims // 2)[np.newaxis, :]
        angle_rates = 1 / np.power(10000, (2 * i) / np.float32(max_dims))

        pos_emb = np.zeros((max_steps, max_dims))
        pos_emb[:, 0::2] = np.sin(p * angle_rates)
        pos_emb[:, 1::2] = np.cos(p * angle_rates)

        self.positional_embedding = tf.constant(pos_emb.astype(np.float32))

    def call(self, inputs):
        shape = tf.shape(inputs)
        return inputs + self.positional_embedding[:shape[-2], :shape[-1]]

    def get_config(self):
        config = super().get_config()
        config.update({
            "max_steps": self.max_steps,
            "max_dims": self.max_dims
        })
        return config

In [7]:
# Transformer Encoder
def transformer_encoder(x, num_heads, ff_dim):
    attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=ff_dim)(x, x)
    x = layers.LayerNormalization(epsilon=1e-6)(x + attention)
    ffn = tf.keras.Sequential([
        layers.Dense(ff_dim, activation='gelu'),
        layers.Dense(x.shape[-1])
    ])
    x = layers.LayerNormalization(epsilon=1e-6)(x + ffn(x))
    return x


In [8]:
# Patch Embedding Layer
class PatchEmbedding(tf.keras.layers.Layer):
    def __init__(self, patch_size, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size
        self.num_patches = num_patches
        self.projection_dim = projection_dim

        self.projection = tf.keras.layers.Dense(projection_dim)
        self.positions = tf.range(start=0, limit=num_patches, delta=1)
        self.position_embedding = tf.keras.layers.Embedding(input_dim=num_patches, output_dim=projection_dim)

    def build(self, input_shape):
        # No manual weight creation needed, but this tells Keras we're built
        super().build(input_shape)

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding='VALID',
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        embeddings = self.projection(patches) + self.position_embedding(self.positions)
        return embeddings

    def get_config(self):
        config = super().get_config()
        config.update({
            "patch_size": self.patch_size,
            "num_patches": self.num_patches,
            "projection_dim": self.projection_dim
        })
        return config

In [101]:
def build_ViT(image_size=133, patch_size=16, num_classes=1, projection_dim=64, transformer_layers=4):
    num_patches = (image_size // patch_size) ** 2
    inputs = layers.Input(shape=(image_size, image_size, 3))

    x = PatchEmbedding(patch_size, num_patches, projection_dim)(inputs)
    x = PositionalEncoding(max_steps=num_patches, max_dims=projection_dim)(x)

    for _ in range(transformer_layers):
        x = transformer_encoder(x, num_heads=4, ff_dim=projection_dim * 2)

    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.GlobalAveragePooling1D()(x)
    x = layers.Dense(64, activation='gelu')(x)
    outputs = layers.Dense(num_classes, activation='sigmoid')(x) 

    return models.Model(inputs, outputs)

In [111]:
from tensorflow.keras import backend as K

def focal_loss(gamma=2.0, alpha=0.75):
    def loss(y_true, y_pred):
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)
        p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)
        focal_weight = alpha_factor * K.pow((1 - p_t), gamma)
        return -K.mean(focal_weight * K.log(p_t))
    return loss

In [112]:
# Build & Compile Model
ViT_model = build_ViT()
ViT_model.compile(optimizer='adam',
                  loss=focal_loss(gamma=2, alpha=0.75),
                  metrics=[
                    'accuracy',
                    tf.keras.metrics.Precision(name='precision', thresholds=0.5),
                    tf.keras.metrics.Recall(name='recall', thresholds=0.5)
                ])

In [117]:
class_weights = {0: 0.5, 1: 1000.0}  # Tune as needed

In [118]:
# Define callbacks
early_stopping = EarlyStopping(monitor='val_loss', patience=8, restore_best_weights=True)
ViT_model_checkpoint = ModelCheckpoint('Vision Transformer/Model Epoch Checkpoints/vit_model_epoch_{epoch:02d}.keras', save_weights_only=False,                  # Save the full model
    save_freq='epoch',                        # Save after each epoch
    save_best_only=False,                     # Save every epoch, not just the best
    verbose=1)

In [119]:
ViT_model.summary()

In [None]:
# Train block
ViT_history = ViT_model.fit(
    ViT_train_ds,
    class_weight=class_weights,
    epochs=3,
    validation_data=ViT_val_ds,
    verbose=1,
    callbacks=[early_stopping, ViT_model_checkpoint]
)

# Save the model
ViT_model.save('Vision Transformer/Model Epoch Checkpoints/vit_final_model.keras')  # Save the final model

Epoch 1/3
[1m  44/8904[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m11:01[0m 75ms/step - accuracy: 0.9762 - loss: 0.7546 - precision: 0.0000e+00 - recall: 0.0000e+00

In [133]:
# Train block
ViT_history = ViT_model.fit(
    ViT_train_ds,
    class_weight=class_weights,
    epochs=5,
    initial_epoch=3,
    validation_data=ViT_val_ds,
    verbose=1,
    callbacks=[early_stopping, ViT_model_checkpoint]
)

Epoch 4/5
[1m8904/8904[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 75ms/step - accuracy: 0.9802 - loss: 0.4602 - precision: 0.0165 - recall: 0.0044
Epoch 4: saving model to Vision Transformer/Model Epoch Checkpoints/vit_model_epoch_04.keras
[1m8904/8904[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m725s[0m 81ms/step - accuracy: 0.9802 - loss: 0.4602 - precision: 0.0165 - recall: 0.0044 - val_accuracy: 0.9990 - val_loss: 0.0130 - val_precision: 0.0000e+00 - val_recall: 0.0000e+00
Epoch 5/5
[1m8904/8904[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 75ms/step - accuracy: 0.9842 - loss: 0.4428 - precision: 0.0000e+00 - recall: 0.0000e+00
Epoch 5: saving model to Vision Transformer/Model Epoch Checkpoints/vit_model_epoch_05.keras
[1m8904/8904[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m719s[0m 81ms/step - accuracy: 0.9842 - loss: 0.4428 - precision: 0.0000e+00 - recall: 0.0000e+00 - val_accuracy: 0.9990 - val_loss: 0.0090 - val_precision: 0.0000e+00 - val_recall:

In [123]:
from tensorflow.keras import backend as K
from tensorflow.keras.utils import register_keras_serializable

@register_keras_serializable()
def focal_loss(gamma=2.0, alpha=0.75):
    def loss(y_true, y_pred):
        epsilon = K.epsilon()
        y_pred = K.clip(y_pred, epsilon, 1. - epsilon)
        p_t = y_true * y_pred + (1 - y_true) * (1 - y_pred)
        alpha_factor = y_true * alpha + (1 - y_true) * (1 - alpha)
        focal_weight = alpha_factor * K.pow((1 - p_t), gamma)
        return -K.mean(focal_weight * K.log(p_t))
    return loss

In [124]:
from tensorflow.keras.models import load_model
ViTmodel = load_model('Vision Transformer/Model Epoch Checkpoints/vit_model_epoch_03.keras', custom_objects={'PatchEmbedding': PatchEmbedding, 'PositionalEncoding': PositionalEncoding}, compile=False)

In [125]:
ViT_model.compile(optimizer='adam',
                  loss=focal_loss(gamma=2.0, alpha=0.75),
                  metrics=[
                    'accuracy',
                    tf.keras.metrics.Precision(name='precision', thresholds=0.5),
                    tf.keras.metrics.Recall(name='recall', thresholds=0.5)
                ])

In [134]:
loss, accuracy, precision, recall = ViT_model.evaluate(ViT_test_ds)
print(f"Loss: {loss:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")

[1m1880/1880[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m51s[0m 27ms/step - accuracy: 0.9991 - loss: 0.0090 - precision: 0.0000e+00 - recall: 0.0000e+00
Loss: 0.0090
Accuracy: 0.9990
Precision: 0.0000
Recall: 0.0000


In [135]:
# Fine-tune the classification threshold on the validation dataset

from sklearn.metrics import recall_score
import numpy as np

# Predictions and true labels
predictions = ViT_model.predict(ViT_val_ds).flatten()
true_classes = np.concatenate([y.numpy() for _, y in ViT_val_ds], axis=0)

# Thresholds to test
thresholds = np.arange(0.0, 1.001, 0.001)

recalls_class_0 = []
recalls_class_1 = []

# Evaluate each threshold
for thresh in thresholds:
    preds = (predictions >= thresh).astype(int)
    recall_per_class = recall_score(true_classes, preds, average=None, zero_division=0)
    recalls_class_0.append(recall_per_class[0])
    recalls_class_1.append(recall_per_class[1])


# Convert to arrays
recalls_class_0 = np.array(recalls_class_0)
recalls_class_1 = np.array(recalls_class_1)

# Find threshold where the difference in recall is minimized
recall_diff = np.abs(recalls_class_0 - recalls_class_1)
crossing_idx = np.argmin(recall_diff)
crossing_threshold = thresholds[crossing_idx]
recall0 = recalls_class_0[crossing_idx]
recall1 = recalls_class_1[crossing_idx]

# Print results
print(f"Crossing threshold: {crossing_threshold:.3f}")
print(f"Recall for class 0 (benign): {recall0:.4f}")
print(f"Recall for class 1 (malignant): {recall1:.4f}")


[1m1880/1880[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 26ms/step
Crossing threshold: 0.307
Recall for class 0 (benign): 0.9882
Recall for class 1 (malignant): 0.0345


In [136]:
# Make predictions on the test dataset
import numpy as np
class_names = test_ds.class_names
predictions = ViT_model.predict(ViT_test_ds)
predicted_classes = (predictions > crossing_threshold).astype("int32")
true_classes = np.concatenate([y.numpy() for _, y in ViT_test_ds], axis=0)
from sklearn.metrics import classification_report, confusion_matrix

report = classification_report(true_classes, predicted_classes, target_names=class_names)

# Print classification report
print(report)

# Save to file
#with open("Vision Transformer/vit_classification_report.txt", "w") as f:
#    f.write(report)

[1m1880/1880[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m50s[0m 26ms/step
              precision    recall  f1-score   support

      benign       1.00      0.99      0.99     60099
   malignant       0.00      0.00      0.00        58

    accuracy                           0.99     60157
   macro avg       0.50      0.49      0.50     60157
weighted avg       1.00      0.99      0.99     60157



In [129]:
conf_matrix = confusion_matrix(true_classes, predicted_classes)
print(conf_matrix)
#np.save("Vision Transformer/vit_confusion_matrix.npy", conf_matrix)

[[35867 24232]
 [   21    37]]


In [132]:
ViT_model.save('Vision Transformer/Model Epoch Checkpoints/vit_final_model.h5')  # Save the final model

