In [4]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

from src.models.resnet_ssd_v0 import create_model

In [None]:
class ModelTrainingFramework:
    """Builds tf.data pipelines for a Keras model, then trains, plots and evaluates it.

    NOTE(review): train_data / val_data are handed straight to
    tf.data.Dataset.from_tensor_slices, so presumably each is an (inputs, labels)
    tuple of equal-length arrays — confirm against the caller.
    """

    def __init__(self, model, train_data, val_data, batch_size=32, epochs=20,
                 shuffle_buffer_size=1000, optimizer='adam',
                 loss='sparse_categorical_crossentropy', metrics=('accuracy',)):
        """Store the model, raw data and training hyper-parameters.

        Args:
            model: Keras model to compile, fit and evaluate.
            train_data: training data accepted by from_tensor_slices.
            val_data: validation data accepted by from_tensor_slices.
            batch_size: examples per batch for both pipelines.
            epochs: number of passes over the training data.
            shuffle_buffer_size: number of individual examples held in the
                training shuffle buffer.
            optimizer: optimizer passed to model.compile.
            loss: loss passed to model.compile. NOTE(review): an SSD detector
                usually needs a combined localisation + classification loss —
                confirm sparse categorical cross-entropy is really intended.
            metrics: metrics passed to model.compile. The default deliberately
                omits 'f1_score': Keras' F1Score expects one-hot labels of shape
                (batch, num_classes), which conflicts with the integer labels a
                sparse categorical loss implies.
        """
        self.model = model
        self.train_data = train_data
        self.val_data = val_data
        self.batch_size = batch_size
        self.epochs = epochs
        self.shuffle_buffer_size = shuffle_buffer_size
        self.optimizer = optimizer
        self.loss = loss
        self.metrics = list(metrics)

    def _make_dataset(self, data, shuffle=False):
        """Slice `data` into examples, optionally shuffle them, then batch and prefetch."""
        dataset = tf.data.Dataset.from_tensor_slices(data)
        if shuffle:
            # Shuffle individual examples *before* batching; shuffling after
            # .batch() only reorders whole batches, so their contents never change.
            dataset = dataset.shuffle(buffer_size=self.shuffle_buffer_size)
        return dataset.batch(self.batch_size).prefetch(tf.data.AUTOTUNE)

    def prepare_dataset(self):
        """Return (train_dataset, val_dataset); only the training set is shuffled."""
        train_dataset = self._make_dataset(self.train_data, shuffle=True)
        val_dataset = self._make_dataset(self.val_data)
        return train_dataset, val_dataset

    def train(self):
        """Compile the model, fit it on the training set and return the Keras History."""
        train_dataset, val_dataset = self.prepare_dataset()
        self.model.compile(optimizer=self.optimizer, loss=self.loss, metrics=self.metrics)
        return self.model.fit(train_dataset, epochs=self.epochs, validation_data=val_dataset)

    def plot_history(self, history):
        """Plot training vs. validation accuracy (left) and loss (right) per epoch."""
        _, axes = plt.subplots(1, 2, figsize=(12, 4))
        for ax, key, label in zip(axes, ('accuracy', 'loss'), ('Accuracy', 'Loss')):
            ax.plot(history.history[key], label=f'Training {label}')
            ax.plot(history.history[f'val_{key}'], label=f'Validation {label}')
            ax.set_title(f'{label} Over Epochs')
            ax.set_xlabel('Epochs')
            ax.set_ylabel(label)
            ax.legend()
        plt.show()

    def evaluate_model(self):
        """Evaluate on the validation set, print every metric and return them as a dict."""
        val_dataset = self._make_dataset(self.val_data)
        # return_dict=True: a positional `loss, accuracy = ...` unpack breaks as
        # soon as the compiled metric list is anything other than one metric.
        results = self.model.evaluate(val_dataset, return_dict=True)
        for name, value in results.items():
            print(f"Validation {name.replace('_', ' ').title()}: {value}")
        return results

In [None]:
# TODO: fill in the datasets before running this cell. They are passed to
# tf.data.Dataset.from_tensor_slices, so presumably each is an
# (inputs, labels) tuple — confirm.
TRAIN_DATA = None
VAL_DATA = None
BATCH_SIZE = 32
EPOCHS = 20

# Fail fast with a clear message instead of building the model first.
if TRAIN_DATA is None or VAL_DATA is None:
    raise ValueError("Set TRAIN_DATA and VAL_DATA before running the training cell.")

# NOTE(review): create_model reportedly takes an output shape and an input
# shape, but it is called with no arguments here — confirm its signature.
ssd_model = create_model()

framework = ModelTrainingFramework(
    model=ssd_model,
    train_data=TRAIN_DATA,
    val_data=VAL_DATA,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
)

history = framework.train()

framework.plot_history(history)

framework.evaluate_model()

In [None]:
# Non-maximum suppression (NMS): among overlapping bounding boxes, keep only the most confident one and discard the rest.

def iou(box1, box2):
    """Intersection-over-union of two axis-aligned boxes given as (x, y, w, h).

    (x, y) is the corner that the width and height extend from (e.g. top-left in
    image coordinates); w and h are assumed non-negative.

    Returns:
        inter_area / union_area, or 0.0 when the union has zero area (e.g. two
        zero-size boxes). Returning 0.0 avoids a ZeroDivisionError and means such
        boxes never suppress each other in NMS.
    """
    x1, y1, w1, h1 = box1
    x2, y2, w2, h2 = box2

    # Overlap extent along each axis; clamp at 0 for disjoint boxes.
    inter_w = max(0, min(x1 + w1, x2 + w2) - max(x1, x2))
    inter_h = max(0, min(y1 + h1, y2 + h2) - max(y1, y2))
    inter_area = inter_w * inter_h

    union_area = w1 * h1 + w2 * h2 - inter_area
    if union_area <= 0:
        return 0.0
    return inter_area / union_area

def NMS(boxes, scores, iou_threshold):
    """Greedy non-maximum suppression.

    Visits boxes from highest to lowest score. Each visited box is kept, and every
    still-pending box whose IoU with it is >= iou_threshold is discarded.

    Args:
        boxes: indexable collection of boxes in the (x, y, w, h) form `iou` unpacks.
        scores: one score per box; must be accepted by np.argsort.
        iou_threshold: overlap at or above which a lower-scored box is suppressed.

    Returns:
        List of kept indices into `boxes`, ordered from highest to lowest score.
    """
    pending = list(np.argsort(scores)[::-1])
    survivors = []
    while pending:
        winner, *rest = pending
        survivors.append(winner)
        # Only boxes that overlap the winner less than the threshold stay in play.
        pending = [cand for cand in rest if iou(boxes[winner], boxes[cand]) < iou_threshold]
    return survivors