# In [None]:
import tensorflow as tf

# Custom Model Class
class CustomModel(tf.keras.Model):
    """Fully connected classifier: Dense(128, ReLU) -> Dense(64, ReLU) -> Dense(n_classes, softmax).

    Because the final layer applies softmax, the model returns class
    probabilities rather than raw logits.
    """

    def __init__(self, n_classes):
        """Create the three dense sublayers.

        Args:
            n_classes: number of units in the softmax output layer.
        """
        super().__init__()
        # Keras tracks sublayers by attribute name, so these names are kept
        # as-is and assigned in the same order.
        self.hidden1 = tf.keras.layers.Dense(128, activation='relu')
        self.hidden2 = tf.keras.layers.Dense(64, activation='relu')
        self.output_layer = tf.keras.layers.Dense(n_classes, activation='softmax')

    def call(self, inputs):
        """Feed ``inputs`` through hidden1, hidden2 and the softmax head in order."""
        activations = inputs
        for layer in (self.hidden1, self.hidden2, self.output_layer):
            activations = layer(activations)
        return activations

# Custom Training Loop
def train_model(model, X_train, y_train, X_val, y_val, learning_rate=0.001, epochs=10, batch_size=32):
    """Train ``model`` with a hand-written loop, printing per-epoch accuracy.

    The model is optimised with Adam on sparse categorical cross-entropy.
    The loss is built with the default ``from_logits=False``, so the model's
    outputs must already be probabilities (e.g. a final softmax layer).

    Args:
        model: a ``tf.keras.Model`` mapping a feature batch to class probabilities.
        X_train, y_train: training features and integer class labels.
        X_val, y_val: validation features and integer class labels.
        learning_rate: Adam learning rate.
        epochs: number of passes over the training data.
        batch_size: mini-batch size used for both training and validation.

    Returns:
        A dict with keys ``"loss"`` (mean of the per-batch training losses),
        ``"accuracy"`` (training accuracy) and ``"val_accuracy"``; each maps
        to a list with one float per epoch.
    """
    optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
    # from_logits defaults to False: the model is expected to output probabilities.
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    train_loss_metric = tf.keras.metrics.Mean()
    train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
    val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

    # Only the training data is shuffled; validation order does not matter.
    train_dataset = (
        tf.data.Dataset.from_tensor_slices((X_train, y_train))
        .shuffle(buffer_size=1024)
        .batch(batch_size)
    )
    val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val)).batch(batch_size)

    history = {"loss": [], "accuracy": [], "val_accuracy": []}

    for epoch in range(epochs):
        print(f"\nStart of epoch {epoch}")

        # Training loop: one optimizer step per mini-batch.
        for x_batch_train, y_batch_train in train_dataset:
            with tf.GradientTape() as tape:
                # These are softmax probabilities, not logits (see loss_fn above).
                probs = model(x_batch_train, training=True)
                loss_value = loss_fn(y_batch_train, probs)

            grads = tape.gradient(loss_value, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_weights))

            train_loss_metric.update_state(loss_value)
            train_acc_metric.update_state(y_batch_train, probs)

        # Display and record training metrics for this epoch.
        train_acc = float(train_acc_metric.result())
        print(f"Training acc over epoch: {train_acc:.4f}")
        history["loss"].append(float(train_loss_metric.result()))
        history["accuracy"].append(train_acc)

        # reset_state() is the supported Metric API; reset_states() no longer
        # exists in Keras 3 and would raise AttributeError here.
        train_loss_metric.reset_state()
        train_acc_metric.reset_state()

        # Validation loop: inference mode, no gradient updates.
        for x_batch_val, y_batch_val in val_dataset:
            val_probs = model(x_batch_val, training=False)
            val_acc_metric.update_state(y_batch_val, val_probs)

        val_acc = float(val_acc_metric.result())
        print(f"Validation acc: {val_acc:.4f}")
        history["val_accuracy"].append(val_acc)
        val_acc_metric.reset_state()

    return history

# Usage Example Function
def run_custom_training(X, y, input_shape=(10,), n_classes=10, epochs=10):
    """Split ``X``/``y``, build a ``CustomModel`` and train it with ``train_model``.

    Args:
        X: feature data, forwarded to ``prepare_data``.
        y: integer class labels, forwarded to ``prepare_data``.
        input_shape: currently unused; the model's layers are built from the
            first batch they receive. Kept so existing callers keep working.
        n_classes: number of units in the model's softmax output layer.
        epochs: number of training epochs.

    Returns:
        The trained ``CustomModel``, so callers can evaluate, predict or save it.
    """
    # NOTE(review): prepare_data is defined outside this cell; this assumes it
    # returns (X_train, X_test, y_train, y_test) in train_test_split order — confirm.
    X_train, X_test, y_train, y_test = prepare_data(X, y)

    custom_model = CustomModel(n_classes=n_classes)
    # The held-out split doubles as the validation set reported each epoch.
    train_model(custom_model, X_train, y_train, X_test, y_test, epochs=epochs)
    return custom_model

# Sample usage with custom dataset
# run_custom_training(X, y, input_shape=(X.shape[1],), n_classes=10, epochs=20)
