In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np
from datetime import datetime
from vit_keras import vit

In [None]:
# Mount Google Drive at /content/drive; the dataset, log and checkpoint paths
# used by the later cells all live under that mount point.
# NOTE(review): this import is kept next to its only use rather than in the
# top import cell; presumably because the notebook targets Colab — confirm.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Locations of the preprocessed arrays on the mounted Drive.
# base_path keeps its trailing slash, so file names are appended directly.
base_path = '/content/drive/My Drive/vit_dnn/'

# Training split: inputs and labels.
train_data_path = f'{base_path}Processed_Training_Data_Single/preprocessed_batch_final.npy'
train_labels_path = f'{base_path}Processed_Training_Data_Single/labels_batch_final.npy'

# Validation split: inputs and labels.
val_data_path = f'{base_path}Processed_Validation_Data_Single/validation_preprocessed_batch_final.npy'
val_labels_path = f'{base_path}Processed_Validation_Data_Single/validation_labels_batch_final.npy'

In [None]:
class PersonalityAnalysisModel(tf.keras.Model):
    """Frozen ViT-B/16 feature extractor followed by a small trainable head.

    The backbone is created with ``include_top=False`` and pretrained weights,
    then frozen. The head is ``Dense(128, relu) -> Dropout(0.5) ->
    Dense(5, sigmoid)``, so the model emits five values in (0, 1) per input.
    NOTE(review): the five outputs are presumably personality-trait scores
    (going by the class name) — confirm against the label files.
    """

    def __init__(self):
        super().__init__()
        # Define the Vision Transformer model
        # NOTE(review): `activation` looks like it only configures the
        # classification head, which include_top=False leaves out — confirm
        # against vit_keras before relying on it.
        self.vit_model = vit.vit_b16(
            image_size=224,
            activation='sigmoid',
            pretrained=True,
            include_top=False,
            pretrained_top=False
        )
        # Freeze the pre-trained layers; setting `trainable` on the model
        # applies to every layer it contains.
        self.vit_model.trainable = False

        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dropout = tf.keras.layers.Dropout(0.5)
        self.output_layer = tf.keras.layers.Dense(5, activation='sigmoid')

    def call(self, inputs, training=None):
        """Run a forward pass.

        Args:
            inputs: batch of images accepted by the ViT backbone
                (built for ``image_size=224``).
            training: Keras training flag. It controls the head's dropout
                only; ``None`` lets Keras decide from the calling context.

        Returns:
            Tensor with 5 sigmoid-activated outputs per input sample.
        """
        # `trainable = False` freezes weights but does not switch layers to
        # inference behaviour, so force inference mode on the backbone to keep
        # any dropout/stochastic layers inside it disabled while the head trains.
        x = self.vit_model(inputs, training=False)
        x = self.dense1(x)
        # Pass the flag explicitly so dropout is active during fit() only.
        x = self.dropout(x, training=training)
        return self.output_layer(x)

# Build the network, then attach its training configuration:
# Adam at a 1e-4 learning rate, MSE as the loss, MAE reported as a metric.
model = PersonalityAnalysisModel()
model.compile(
    loss='mean_squared_error',
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    metrics=[tf.keras.metrics.MeanAbsoluteError()],
)

In [None]:
# TensorBoard monitoring: every run writes to its own timestamped folder
# under <base_path>logs/ on Google Drive; histograms are recorded each epoch.
log_dir = f"{base_path}logs/{datetime.now():%Y%m%d-%H%M%S}"
tensorboard_callback = tf.keras.callbacks.TensorBoard(
    log_dir=log_dir,
    histogram_freq=1,
)

In [None]:
# Additional Callbacks
# Checkpoint: keep only the full model with the lowest validation loss.
# The path deliberately has no .h5 extension. ModelCheckpoint has no
# `save_format` argument (the on-disk format is derived from the file path),
# so the previously passed save_format="tf" has been dropped.
# NOTE(review): with TF2 Keras an extension-less path should produce a
# SavedModel directory; newer Keras releases expect a `.keras` suffix — confirm
# against the installed version.
checkpoint_path = base_path + 'best_model_subclass'
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val_loss',
    save_best_only=True,
    save_weights_only=False
)
# Stop once val_loss has failed to improve for 3 consecutive epochs.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3
)

In [None]:
# Data generator function
def data_generator(data_path, labels_path, batch_size, shuffle=True):
    """Yield ``(X_batch, y_batch)`` pairs forever from two ``.npy`` files.

    Both files are opened once with ``mmap_mode='r'``; each batch is gathered
    by index, so only the rows of the current batch are copied into memory.
    Every pass visits each row exactly once and the last batch of a pass may
    be smaller than ``batch_size``, i.e. one pass is
    ``ceil(n_samples / batch_size)`` batches long.

    Args:
        data_path: path of the ``.npy`` file holding the inputs.
        labels_path: path of the ``.npy`` file holding the labels; must have
            the same length along the first axis as the inputs.
        batch_size: positive number of rows per batch.
        shuffle: when True (default) the row order is reshuffled with
            ``np.random.shuffle`` at the start of every pass; when False rows
            are served in file order.

    Yields:
        Tuple ``(X_batch, y_batch)`` of arrays sharing the same first dimension.

    Raises:
        ValueError: if ``batch_size`` is not positive, the dataset is empty, or
            inputs and labels differ in length. Because this is a generator,
            the error surfaces on the first ``next()`` call.
    """
    # A non-positive batch size or an empty dataset would make the endless
    # loop below spin without ever yielding, so reject them up front.
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    # Opening the memory maps is loop-invariant: do it once, not on every pass.
    X = np.load(data_path, mmap_mode='r')
    y = np.load(labels_path, mmap_mode='r')

    n_samples = len(X)
    if n_samples == 0:
        raise ValueError(f"no samples found in {data_path}")
    if len(y) != n_samples:
        raise ValueError(
            f"data/label length mismatch: {n_samples} samples in {data_path} "
            f"but {len(y)} labels in {labels_path}"
        )

    indices = np.arange(n_samples)
    while True:
        if shuffle:
            # Global NumPy RNG: seed it with np.random.seed for reproducible order.
            np.random.shuffle(indices)

        for start_idx in range(0, n_samples, batch_size):
            # Slicing clamps at the end, so the final batch is simply shorter.
            batch_indices = indices[start_idx:start_idx + batch_size]
            yield X[batch_indices], y[batch_indices]

# Create the endless batch streams consumed by model.fit: one over the
# training arrays and one over the validation arrays, both 32 rows per batch.
batch_size = 32
train_gen = data_generator(
    data_path=train_data_path,
    labels_path=train_labels_path,
    batch_size=batch_size,
)
val_gen = data_generator(
    data_path=val_data_path,
    labels_path=val_labels_path,
    batch_size=batch_size,
)

In [None]:
# Training the model
# data_generator yields ceil(n / batch_size) batches per pass (the last batch
# may be short), so the step counts are rounded up. With floor division the
# leftover batch spilled into the next epoch, each validation run scored a
# different subset, and a dataset smaller than one batch produced 0 steps.
try:
    # The arrays are memory-mapped only to read their length.
    train_steps = int(np.ceil(len(np.load(train_data_path, mmap_mode='r')) / batch_size))
    val_steps = int(np.ceil(len(np.load(val_data_path, mmap_mode='r')) / batch_size))

    history = model.fit(
        train_gen,
        steps_per_epoch=train_steps,
        epochs=10,
        validation_data=val_gen,
        validation_steps=val_steps,
        verbose=1,
        callbacks=[tensorboard_callback, checkpoint_callback, early_stopping_callback]
    )
except Exception as e:
    # Surface a short message in the cell output, then re-raise so the
    # failure is not swallowed and the full traceback is still shown.
    print(f"Error during training: {str(e)}")
    raise