# Health Classification Model Training

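This notebook trains a three-class health classifier (normal / warning / critical) on synthetic sensor data (heart rate, SpO2, and acceleration magnitude), evaluates it on a held-out test set, and exports the trained model to Keras and TFLite formats.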


In [None]:
import pickle
import sys
from pathlib import Path

# Make the project root (the parent of the current working directory)
# importable so the `src.*` imports below resolve.
project_root = Path.cwd().parent
project_root_str = str(project_root)
# Only prepend when it is not already the first entry: re-running this cell
# must not stack duplicate entries onto sys.path.
if sys.path[:1] != [project_root_str]:
    sys.path.insert(0, project_root_str)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import classification_report, confusion_matrix

from src.config.settings import MODELS_DIR
from src.train_model import generate_synthetic_data, create_model
from src.edge_ml.model_converter import convert_keras_model_to_tflite

# Seed the NumPy and TensorFlow global RNGs with one shared value
# for reproducibility.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)


## 1. Load or Generate Data


In [None]:
# Build the synthetic dataset (swap this call for a database/file loader
# when real data is available).
df = generate_synthetic_data(n_samples=10000)

# Quick sanity summary: overall size, then how many rows each label has.
label_counts = df['label'].value_counts()
print(f"Dataset shape: {df.shape}")
print("\nClass distribution:")
print(label_counts)

# Last expression of the cell -> rendered as a rich table preview.
df.head()


## 2. Exploratory Data Analysis


In [None]:
# Per-class box plots: one panel per sensor feature, grouped by label.
sensor_features = ['heart_rate', 'spo2', 'acceleration_magnitude']
fig, axes = plt.subplots(1, len(sensor_features), figsize=(15, 4))

for ax, col in zip(axes, sensor_features):
    df.boxplot(column=col, by='label', ax=ax)
    # Override the title/xlabel that pandas sets on the panel.
    ax.set_title(f'{col} by Health Status')
    ax.set_xlabel('Health Status (0=normal, 1=warning, 2=critical)')

# pandas adds a figure-level "Boxplot grouped by label" title on every
# `by=` call; clear it so it does not collide with the per-panel titles.
fig.suptitle('')
plt.tight_layout()
plt.show()


In [None]:
# Pairwise correlations between the three features and the label.
corr_columns = ['heart_rate', 'spo2', 'acceleration_magnitude', 'label']
corr_matrix = df[corr_columns].corr()

fig, ax = plt.subplots()
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', center=0, ax=ax)
ax.set_title('Feature Correlation Matrix')
plt.show()


## 3. Prepare Data for Training


In [None]:
# Feature matrix (X) and target vector (y) as NumPy arrays.
feature_cols = ['heart_rate', 'spo2', 'acceleration_magnitude']
X = df[feature_cols].values
y = df['label'].values

# Hold out 20% of the rows for testing; stratify so both splits keep the
# same class proportions.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    stratify=y,
    random_state=42,
)

# Learn the scaling statistics from the training rows only, then apply the
# same transform to both splits so no test information leaks into training.
scaler = StandardScaler()
scaler.fit(X_train)
X_train_scaled = scaler.transform(X_train)
X_test_scaled = scaler.transform(X_test)

for split_name, split_array in (("Training", X_train_scaled), ("Test", X_test_scaled)):
    print(f"{split_name} set shape: {split_array.shape}")


## 4. Create and Train Model


In [None]:
# Build the classifier: one input per scaled feature column, three output
# classes (normal / warning / critical).
n_features = X_train_scaled.shape[1]
model = create_model(input_dim=n_features, num_classes=3)
model.summary()


In [None]:
# Fit the model; Keras holds out 20% of the training rows for validation.
fit_options = {
    "epochs": 50,
    "batch_size": 32,
    "validation_split": 0.2,
    "verbose": 1,
}
history = model.fit(X_train_scaled, y_train, **fit_options)


## 5. Evaluate Model


In [None]:
# Training vs. validation curves: loss on the left, accuracy on the right.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# (training key, validation key, display name) for each panel
curve_specs = [
    ('loss', 'val_loss', 'Loss'),
    ('accuracy', 'val_accuracy', 'Accuracy'),
]

for ax, (train_key, val_key, display_name) in zip(axes, curve_specs):
    ax.plot(history.history[train_key], label=f'Training {display_name}')
    ax.plot(history.history[val_key], label=f'Validation {display_name}')
    ax.set_title(f'Model {display_name}')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(display_name)
    ax.legend()

plt.tight_layout()
plt.show()


In [None]:
# Class names in label order (0=normal, 1=warning, 2=critical). The ids are
# passed explicitly to the sklearn metrics so that names stay aligned with
# labels even if a class happens to be missing from y_test or the predictions.
class_names = ['normal', 'warning', 'critical']
class_ids = list(range(len(class_names)))

# Evaluate on test set
test_loss, test_accuracy = model.evaluate(X_test_scaled, y_test, verbose=0)
print(f"Test accuracy: {test_accuracy:.4f}")

# Predictions: pick the highest-scoring class for each row
y_pred = model.predict(X_test_scaled)
y_pred_classes = np.argmax(y_pred, axis=1)

# Classification report
print("\nClassification Report:")
print(classification_report(y_test, y_pred_classes,
                            labels=class_ids,
                            target_names=class_names))

# Confusion matrix: rows are true classes, columns are predicted classes.
# Shows *which* classes get confused, which the report alone does not.
cm = confusion_matrix(y_test, y_pred_classes, labels=class_ids)
fig, ax = plt.subplots(figsize=(5, 4))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set_xlabel('Predicted')
ax.set_ylabel('Actual')
ax.set_title('Confusion Matrix')
plt.tight_layout()
plt.show()


## 6. Save Model and Convert to TFLite


In [None]:
# Make sure the output directory exists before writing anything into it;
# on a fresh checkout it may not have been created yet.
Path(MODELS_DIR).mkdir(parents=True, exist_ok=True)

# Save model
model_path = MODELS_DIR / 'health_classifier.h5'
model.save(model_path)
print(f"Model saved to {model_path}")

# Save scaler: inference code needs the same scaling that was fitted on the
# training data. NOTE: pickle files can execute arbitrary code when loaded,
# so only load this file from a trusted location.
scaler_path = MODELS_DIR / 'health_classifier_scaler.pkl'
with open(scaler_path, 'wb') as f:
    pickle.dump(scaler, f)
print(f"Scaler saved to {scaler_path}")


In [None]:
# Export the trained Keras model to TFLite next to the other model artifacts.
tflite_path = MODELS_DIR / 'health_classifier.tflite'
use_quantization = False  # Set to True for quantization
convert_keras_model_to_tflite(model, tflite_path, quantize=use_quantization)
print(f"TFLite model saved to {tflite_path}")
