In [3]:
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
class ModelTrainingFramework:
    """Convenience wrapper to train, plot and evaluate a Keras-style model.

    The wrapper turns in-memory data into batched ``tf.data`` pipelines,
    compiles the model with the Adam optimizer and sparse categorical
    cross-entropy, and offers helpers to plot the training curves and to
    report validation metrics.

    Args:
        model: Object exposing Keras-style ``compile``, ``fit`` and
            ``evaluate`` methods.
        train_data: Training data handed to
            ``tf.data.Dataset.from_tensor_slices``. Because the resulting
            dataset is passed straight to ``fit``, this is presumably a
            ``(features, labels)`` tuple with integer class labels --
            TODO confirm against the caller.
        val_data: Validation data, same structure as ``train_data``.
        batch_size: Number of examples per batch.
        epochs: Number of training epochs.
        shuffle_buffer_size: Buffer size used when shuffling the training
            examples. Use a value >= the number of training examples for a
            full uniform shuffle.
    """

    def __init__(self, model, train_data, val_data, batch_size=32, epochs=20,
                 shuffle_buffer_size=1000):
        self.model = model
        self.train_data = train_data
        self.val_data = val_data
        self.batch_size = batch_size
        self.epochs = epochs
        self.shuffle_buffer_size = shuffle_buffer_size

    def _make_dataset(self, data, shuffle=False):
        """Build a batched dataset from ``data``, optionally shuffling examples."""
        dataset = tf.data.Dataset.from_tensor_slices(data)
        if shuffle:
            # Shuffle BEFORE batching so individual examples are mixed.
            # Shuffling after batching only reorders whole batches and keeps
            # the same examples grouped together in every epoch.
            dataset = dataset.shuffle(buffer_size=self.shuffle_buffer_size)
        return dataset.batch(self.batch_size)

    def prepare_dataset(self):
        """Return ``(train_dataset, val_dataset)`` as batched datasets.

        Only the training dataset is shuffled; the validation dataset keeps
        its original order.
        """
        train_dataset = self._make_dataset(self.train_data, shuffle=True)
        val_dataset = self._make_dataset(self.val_data)
        return train_dataset, val_dataset

    def train(self):
        """Compile the model, fit it, and return the object ``fit`` returns."""
        train_dataset, val_dataset = self.prepare_dataset()

        # Only 'accuracy' is tracked. The former 'f1_score' entry was dropped:
        # Keras' F1 metric appears to expect one-hot style targets, which
        # conflicts with the integer labels the sparse loss needs, and the
        # extra metric also broke the two-value unpacking in evaluate_model().
        self.model.compile(
            optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy'],
        )
        return self.model.fit(
            train_dataset,
            epochs=self.epochs,
            validation_data=val_dataset,
        )

    def plot_history(self, history):
        """Plot training/validation accuracy and loss side by side.

        Args:
            history: Object whose ``history`` attribute is a mapping with the
                keys ``accuracy``, ``val_accuracy``, ``loss`` and ``val_loss``
                (as returned by ``train``).
        """
        curves = history.history
        fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(12, 4))

        ax_acc.plot(curves['accuracy'], label='Training Accuracy')
        ax_acc.plot(curves['val_accuracy'], label='Validation Accuracy')
        ax_acc.set_title('Accuracy Over Epochs')
        ax_acc.set_xlabel('Epochs')
        ax_acc.set_ylabel('Accuracy')
        ax_acc.legend()

        ax_loss.plot(curves['loss'], label='Training Loss')
        ax_loss.plot(curves['val_loss'], label='Validation Loss')
        ax_loss.set_title('Loss Over Epochs')
        ax_loss.set_xlabel('Epochs')
        ax_loss.set_ylabel('Loss')
        ax_loss.legend()

        plt.show()

    def evaluate_model(self):
        """Evaluate on the validation data, print and return the metrics.

        Returns:
            The dict produced by ``model.evaluate(..., return_dict=True)``.
        """
        val_dataset = self._make_dataset(self.val_data)
        # Read results by name instead of tuple-unpacking so the call keeps
        # working no matter how many metrics the model was compiled with.
        results = self.model.evaluate(val_dataset, return_dict=True)
        loss = results['loss']
        accuracy = results.get('accuracy')
        print(f"Validation Loss: {loss}")
        print(f"Validation Accuracy: {accuracy}")
        return results

In [None]:
# Build the model to train. `create_model` is expected to be defined in an
# earlier cell of this notebook.
ssd_model = create_model()  # TODO: pass whatever arguments create_model requires

# TODO: `train_data` and `val_data` must be defined in an earlier cell. They are
# fed to tf.data.Dataset.from_tensor_slices and then to model.fit, so they are
# presumably (features, labels) tuples -- confirm before running.
framework = ModelTrainingFramework(
    model=ssd_model,
    train_data=train_data,
    val_data=val_data,
    batch_size=32,  # same as the class default; tune as needed
    epochs=20,      # same as the class default; tune as needed
)

history = framework.train()

framework.plot_history(history)

framework.evaluate_model()