**DeiT without SAVE**

Import Library

In [1]:
import os
import time
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import Flatten, Dense, Layer, Lambda
from tensorflow.keras.models import Model
from tensorflow.keras.callbacks import ModelCheckpoint
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from vit_keras import vit

Dataset Configuration and Class Labels

In [None]:
source_dir = "Blended malware image"
image_size = (224, 224)
batch_size = 16

classes = sorted(
    d for d in os.listdir(source_dir)
    if os.path.isdir(os.path.join(source_dir, d))
)

Gather Files Path and Integer Labels

In [3]:
file_paths, labels = [], []
for label, cls in enumerate(classes):
    cls_dir = os.path.join(source_dir, cls)
    for fname in os.listdir(cls_dir):
        fpath = os.path.join(cls_dir, fname)
        if os.path.isfile(fpath):
            file_paths.append(fpath)
            labels.append(label)

Split Dataset

In [4]:
X_train, X_temp, y_train, y_temp = train_test_split(file_paths, labels, test_size=0.30, stratify=labels, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.50, stratify=y_temp, random_state=42)

TF-Data Pipeline

In [5]:
def process_path(file_path, label):
    img = tf.io.decode_jpeg(tf.io.read_file(file_path), channels=3)
    img = tf.image.resize(img, image_size)
    img = tf.keras.applications.imagenet_utils.preprocess_input(img)
    label = tf.one_hot(label, depth=len(classes))
    return img, label

def make_dataset(X, y, shuffle=True):
    ds = tf.data.Dataset.from_tensor_slices((X, y))
    if shuffle:
        ds = ds.shuffle(buffer_size=len(X))
    return (
        ds
        .map(process_path, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )

train_ds = make_dataset(X_train, y_train, shuffle=True)
val_ds = make_dataset(X_val, y_val, shuffle=False)
test_ds = make_dataset(X_test, y_test, shuffle=False)

Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'


DeiT Model

In [6]:
def build_deit_model(num_classes, image_size=image_size, patch_size=16):
    # Load pretrained ViT backbone
    backbone = vit.vit_b16(
        image_size=image_size,
        pretrained=True,
        include_top=False,
        pretrained_top=False
    )

    # Extract sequence output before classification token
    seq_output = backbone.get_layer("Transformer/encoder_norm").output
    seq_model  = Model(inputs=backbone.input, outputs=seq_output)

    # Fold patch tokens into feature map
    def fold_patches(x):
        patch_tokens = x[:, 1:, :]
        return tf.reshape(
            patch_tokens,
           (-1,
             image_size[0] // patch_size,
             image_size[1] // patch_size,
             x.shape[-1])
        )

    x = Lambda(fold_patches, name="fold_patches")(seq_model.output)

    # Classification head
    x = Flatten()(x)
    x = Dense(1000, activation="relu")(x)
    outputs = Dense(num_classes, activation="softmax")(x)

    model = Model(inputs=seq_model.input, outputs=outputs)
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=[
            'accuracy',
            tf.keras.metrics.AUC(name='auc'),
            tf.keras.metrics.FalsePositives(name='false_positives'),
            tf.keras.metrics.Precision(name='precision'),
            tf.keras.metrics.Recall(name='recall'),
        ]
    )
    return model

model = build_deit_model(num_classes=len(classes))
model.summary()



Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 embedding (Conv2D)          (None, 14, 14, 768)       590592    
                                                                 
 reshape (Reshape)           (None, 196, 768)          0         
                                                                 
 class_token (ClassToken)    (None, 197, 768)          768       
                                                                 
 Transformer/posembed_input   (None, 197, 768)         151296    
 (AddPositionEmbs)                                               
                                                                 
 Transformer/encoderblock_0   ((None, 197, 768),       7087872   
 (TransformerBlock)           (None, 12, None, None))      

Model Training

In [None]:
os.makedirs('checkpoints', exist_ok=True)
os.makedirs('models', exist_ok=True)

checkpoint_cb = ModelCheckpoint(
    filepath='checkpoints/vit_malware_best.h5',
    monitor='val_accuracy',
    save_best_only=True,
    verbose=1
)

train_ds_inf = make_dataset(X_train, y_train, shuffle=True) \
                   .repeat()  

history = model.fit(
    train_ds_inf,
    epochs=20,
    steps_per_epoch=300,      
    validation_data=val_ds,
    callbacks=[checkpoint_cb]
)

model.save('models/vit_malware_final.h5')

Epoch 1/20
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: 'arguments' object has no attribute 'posonlyargs'

Epoch 1: val_accuracy improved from -inf to 0.85911, saving model to checkpoints\vit_malware_best.h5
Epoch 2/20
Epoch 2:

Plot Training History

In [None]:
def plot_history(hist):
    metrics = ['loss', 'accuracy', 'auc', 'false_positives', 'precision', 'recall']
    fig, axes = plt.subplots(3, 2, figsize=(12, 12))
    for idx, m in enumerate(metrics):
        ax = axes[idx // 2, idx % 2]
        ax.plot(hist.history[m], label='train')
        ax.plot(hist.history[f'val_{m}'], label='val')
        ax.set_title(m)
        ax.set_xlabel('epoch')
        ax.legend()
    plt.tight_layout()
    plt.show()

plot_history(history)

Model Validation

In [None]:
val_metrics = model.evaluate(val_ds, verbose=1)
print("\nValidation set metrics:")
for name, value in zip(model.metrics_names, val_metrics):
    print(f" - {name}: {value:.4f}")

Model Testing

In [None]:
start = time.time()
preds = model.predict(test_ds)
print(f"\nInference time on test set: {time.time() - start:.2f}s")

y_true = np.argmax(
    np.vstack([y for _, y in test_ds.unbatch()]), axis=1
)
y_pred = np.argmax(preds, axis=1)

print("\nClassification Report:\n")
print(classification_report(y_true, y_pred, target_names=classes))

Model Evaluation

In [None]:
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
plt.title('Confusion Matrix Heatmap')
plt.imshow(cm, interpolation='nearest', aspect='auto')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.colorbar(fraction=0.046, pad=0.04)

tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45, ha='right')
plt.yticks(tick_marks, classes)

# annotate each cell
thresh = cm.max() / 2
for i, j in np.ndindex(cm.shape):
    plt.text(
        j, i, f"{cm[i, j]:d}",
        ha='center', va='center',
        color='white' if cm[i, j] > thresh else 'black'
    )

plt.tight_layout()
plt.show()