# Install & Import Libraries


In [68]:
#!pip install tensorflow numpy matplotlib scikit-learn scipy paho-mqtt
import os
import numpy as np
from collections import deque
import time
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Dense, Dropout, BatchNormalization, GlobalAveragePooling1D
from tensorflow.keras.optimizers import Adam
from sklearn.preprocessing import StandardScaler
from scipy.signal import butter, lfilter
import paho.mqtt.client as mqtt
import json
from paho.mqtt.client import CallbackAPIVersion

# Define Constants and Gesture Classes

In [70]:
# Sensor stream configuration
sampling_rate = 20                # samples per second (Hz)
window_size = sampling_rate * 3   # one classification window = 3 s of samples
num_channels = 6                  # values per sample (accel x/y/z + gyro x/y/z)
# Gesture labels, indexed by the model's softmax output position
wand_classes = ["Wave", "Circle", "Square", "Triangle", "Infinity", "Zigzag", "None"]
CONFIDENCE_THRESHOLD = 0.7  # minimum top-class probability to accept a prediction; tune as needed

# Processing Functions


In [72]:
def butter_lowpass(cutoff, fs, order=4):
    """Design a digital Butterworth low-pass filter.

    Args:
        cutoff: Cut-off frequency, in the same units as ``fs``.
        fs: Sampling frequency.
        order: Filter order.

    Returns:
        ``(b, a)``: numerator and denominator coefficients.
    """
    nyquist = fs / 2
    wn = cutoff / nyquist  # cut-off expressed as a fraction of Nyquist
    return butter(order, wn, btype='low', analog=False)

def lowpass_filter(data, cutoff=5, fs=50, order=4):
    """Low-pass ``data`` with a causal Butterworth filter.

    ``lfilter`` runs forward only, starting from zero initial conditions.

    NOTE: the default ``fs=50`` does not match the module's ``sampling_rate``
    (20); pass ``fs`` explicitly, as ``classify_window`` does.
    """
    coeffs = butter_lowpass(cutoff, fs, order=order)
    return lfilter(*coeffs, data)

def normalize_window(window):
    """Z-score every column (channel) of ``window`` independently.

    An empty array is returned unchanged (same object). A small epsilon in
    the denominator keeps constant channels from dividing by zero.
    """
    if not window.size:
        return window
    mu = window.mean(axis=0)
    sigma = window.std(axis=0)
    return (window - mu) / (sigma + 1e-8)

# CNN Model


In [74]:
# 1-D CNN gesture classifier: input is one window of shape
# (window_size, num_channels); output is a softmax over wand_classes.
# NOTE(review): the layer order/types here must match the checkpoint loaded
# below, and the auto-generated layer names (conv1d, conv1d_1, ...) are reused
# as C identifiers by the weight-export cell -- do not reorder casually.
model = Sequential([
    tf.keras.Input(shape=(window_size, num_channels)),
    # Conv1D uses the default 'valid' padding; time-axis lengths below are
    # for window_size = 60.
    Conv1D(64, 5, activation='relu'),      # 60 -> 56
    BatchNormalization(),
    MaxPooling1D(2),                       # 56 -> 28

    Conv1D(128, 5, activation='relu'),     # 28 -> 24
    BatchNormalization(),
    MaxPooling1D(2),                       # 24 -> 12

    Conv1D(256, 3, activation='relu'),     # 12 -> 10
    BatchNormalization(),
    MaxPooling1D(2),                       # 10 -> 5

    Dropout(0.4),                          # active only during training
    GlobalAveragePooling1D(),   # average over time -> 256 features
    Dense(128, activation='relu'),
    Dropout(0.3),
    # One softmax unit per entry of wand_classes (same order).
    Dense(len(wand_classes), activation='softmax')
])

# compile() is not required for predict(); presumably kept so the model can
# be trained/evaluated further in this notebook.
model.compile(optimizer=Adam(learning_rate=0.001),
              loss='categorical_crossentropy',
              metrics=['accuracy'])

# Load fine-tuned weights (the file must come from this exact architecture).
model.load_weights("cnn_finetuned_wand.keras")

# Classification Function

In [76]:
# Sliding window of the most recent samples; the deque discards the oldest
# sample automatically once it holds window_size samples.
buffer = deque(maxlen=window_size)

def process_realtime_data(new_sample):
    """Push one sample into the sliding window and classify once it is full.

    After the window first fills, every call classifies the latest
    ``window_size`` samples (sliding window with a hop of one sample).

    NOTE: relies on the module-level name ``buffer`` still referring to the
    deque above.

    Returns:
        ``classify_window``'s ``(label, confidence)`` when the window is full,
        otherwise ``(None, None)``.
    """
    buffer.append(new_sample)
    if len(buffer) != window_size:
        return None, None  # still filling
    return classify_window(np.array(buffer))

def classify_window(window):
    """Classify one window of IMU samples.

    Args:
        window: array-like of shape (window_size, num_channels). It is not
            modified; a float copy is filtered and normalized.

    Returns:
        ``(label, confidence)``: ``label`` is an entry of ``wand_classes``, or
        ``"None"`` when the top softmax probability is below
        ``CONFIDENCE_THRESHOLD``; ``confidence`` is that top probability as a
        Python float.
    """
    # Work on a float copy: filtering in place would mutate the caller's array
    # and, if the input were an integer array (e.g. integer sensor readings),
    # silently truncate the filtered values back to integers.
    window = np.array(window, dtype=float)

    # Low-pass filter each channel independently.
    for ch in range(num_channels):
        window[:, ch] = lowpass_filter(window[:, ch], cutoff=5, fs=sampling_rate)
    window = normalize_window(window)
    batch = window[np.newaxis, ...]  # add the batch dimension expected by Keras

    probs = model.predict(batch, verbose=0)[0]
    predicted_idx = int(np.argmax(probs))
    confidence = float(probs[predicted_idx])

    if confidence < CONFIDENCE_THRESHOLD:
        return "None", confidence
    return wand_classes[predicted_idx], confidence

# Generate test data

In [78]:
# Parameters
# sampling_rate / window_size / num_channels / wand_classes come from the
# constants cell, so the exported header always matches the live model's
# input shape instead of a second, possibly diverging copy.
num_classes = len(wand_classes)
num_samples = 5   # how many test vectors to export
TEST_SEED = 0     # fixed seed -> test_data.h is reproducible across runs

# Example: random test windows (replace with a real, preprocessed dataset).
rng = np.random.default_rng(TEST_SEED)
x_test = rng.standard_normal((num_samples, window_size, num_channels)).astype(np.float32)
y_pred = model.predict(x_test, verbose=0)   # shape (num_samples, num_classes)

# Quantize to Q8.8
def quantize_q8_8(x):
    """Convert ``x`` to signed Q8.8 fixed point (int16, 8 fractional bits).

    Values are scaled by 256 and rounded, then saturated to the int16 range
    (about [-128, 127.996] in real units) instead of wrapping around on
    overflow. The input shape is preserved.
    """
    info = np.iinfo(np.int16)
    scaled = np.round(np.asarray(x, dtype=np.float64) * 256)
    return np.clip(scaled, info.min, info.max).astype(np.int16)

x_q = quantize_q8_8(x_test)       # int16, shape (num_samples, window_size, num_channels)
y_q = quantize_q8_8(y_pred)       # int16, shape (num_samples, num_classes)

# Assemble the whole C++ header in memory, then write it out in one call.
header = [
    "/* Auto-generated test dataset from Keras CNN */\n",
    "#ifndef TEST_DATA_H\n#define TEST_DATA_H\n\n",
    "#include <cstdint>\n\n",
    f"#define NUM_SAMPLES {num_samples}\n",
    f"#define INPUT_LEN {window_size}\n",
    f"#define INPUT_CH {num_channels}\n",
    f"#define NUM_CLASSES {num_classes}\n\n",
    # Input array
    "const int16_t test_inputs[NUM_SAMPLES][INPUT_LEN][INPUT_CH] = {\n",
]
for sample in range(num_samples):
    header.append("  {\n")
    header.extend(
        "    {" + ", ".join(map(str, x_q[sample, step])) + "},\n"
        for step in range(window_size)
    )
    header.append("  },\n")
header.append("};\n\n")

# Output array
header.append("const int16_t test_outputs[NUM_SAMPLES][NUM_CLASSES] = {\n")
header.extend(
    "  {" + ", ".join(map(str, y_q[sample])) + "},\n"
    for sample in range(num_samples)
)
header.append("};\n\n")
header.append("#endif // TEST_DATA_H\n")

with open("test_data.h", "w") as f:
    f.writelines(header)


# Export Layer Parameters (Weights, Biases, BatchNorm Statistics) to a Q8.8 C Header

In [80]:
# Quantization helper (Q8.8 fixed-point)
def quantize_q8_8(x):
    """Convert ``x`` to signed Q8.8 fixed point (int16, 8 fractional bits).

    Values are scaled by 256 and rounded, then saturated to the int16 range
    (about [-128, 127.996] in real units) instead of wrapping around on
    overflow. The input shape is preserved.
    """
    info = np.iinfo(np.int16)
    scaled = np.round(np.asarray(x, dtype=np.float64) * 256)
    return np.clip(scaled, info.min, info.max).astype(np.int16)  # signed 16-bit

# Write every trainable/stateful parameter tensor of `model` as a flat
# Q8.8 array named "<layer name>_param<i>".
with open("cnn_weights.h", "w") as f:
    f.write("// Auto-generated CNN weights (Q8.8 fixed point)\n\n")
    # Include guard (as in test_data.h) so a second #include of this header
    # in the same translation unit does not redefine the arrays.
    f.write("#ifndef CNN_WEIGHTS_H\n#define CNN_WEIGHTS_H\n\n")
    for layer in model.layers:
        params = layer.get_weights()
        if not params:
            continue  # skip parameter-free layers (pooling, dropout, ...)

        for i, w in enumerate(params):
            flat = quantize_q8_8(w).flatten()  # row-major (C) order
            name = f"{layer.name}_param{i}"
            # Keep the original tensor shape so the C side can index the
            # flattened array correctly.
            f.write(f"// {layer.name} param {i}: shape {tuple(w.shape)}\n")
            f.write(f"const short {name}[{len(flat)}] = {{\n")
            f.write(", ".join(map(str, flat)))
            f.write("\n};\n\n")
    f.write("#endif // CNN_WEIGHTS_H\n")


# MQTT

In [None]:
# Samples for the current tumbling (non-overlapping) MQTT window. Deliberately
# NOT named `buffer`: that global is the deque used by process_realtime_data,
# and rebinding it to a list here would break that function.
mqtt_buffer = []

def on_connect(client, userdata, flags, rc):
    """paho-mqtt connect handler (callback API v1 signature).

    Subscribes to the IMU topic whenever a connection succeeds (rc == 0).
    """
    if rc == 0:
        print("Connected to broker")
        client.subscribe("wand/mpu")
    else:
        print(f"Failed to connect, rc={rc}")


def on_message(client, userdata, msg):
    """Parse one IMU sample from an MQTT message and classify full windows.

    Expects a JSON object with ``accelx``/``accely``/``accelz`` and optional
    ``gyrox``/``gyroy``/``gyroz`` keys (missing gyro keys default to 0.0).
    Every ``window_size`` accepted samples are classified once, then dropped.
    """
    try:
        payload = json.loads(msg.payload.decode())
        # extract the 6 channels; adjust keys to match the ESP32 payload
        sample = [
            float(payload["accelx"]),
            float(payload["accely"]),
            float(payload["accelz"]),
            float(payload.get("gyrox", 0.0)),
            float(payload.get("gyroy", 0.0)),
            float(payload.get("gyroz", 0.0)),
        ]
        if len(sample) != num_channels:
            print("⚠️ Incorrect sample length:", sample)
            return

        mqtt_buffer.append(sample)
        if len(mqtt_buffer) < window_size:
            return

        # Snapshot and reset BEFORE classifying: if classification raises, the
        # buffer must still be emptied, otherwise it would keep growing past
        # window_size and no further window would ever be classified.
        window = np.array(mqtt_buffer[:window_size], dtype=float)
        mqtt_buffer.clear()
        symbol, confidence = classify_window(window)
        print(f"Detected: {symbol}, Confidence: {confidence:.2f}")

    # Broad catch at the callback boundary: an exception escaping a paho
    # callback may propagate out of the network loop; log it and keep going.
    except Exception as e:
        print("⚠️ Failed to process message:", e)

# Real-time Inference

In [None]:
# run mosquitto first
# callback_api_version is passed explicitly: paho-mqtt 2.x introduced it (and
# warns or fails when it is omitted, depending on the 2.x release). VERSION1
# matches the 4-argument on_connect(client, userdata, flags, rc) used here;
# moving to VERSION2 requires updating on_connect's signature as well.
client = mqtt.Client(
    callback_api_version=CallbackAPIVersion.VERSION1,
    client_id="LaptopInference",
)
client.on_connect = on_connect
client.on_message = on_message

broker_address = "192.168.68.56"  # broker host; adjust for your network
client.connect(broker_address, 1883, 60)

print("📡 Waiting for IMU data...")
client.loop_start()  # network loop runs in a background thread

# Keep the notebook cell alive while MQTT messages are processed. Cleanup is
# in `finally` so the client is shut down however the wait ends.
try:
    while True:
        time.sleep(1)
except KeyboardInterrupt:
    print("Stopping...")
finally:
    client.disconnect()
    client.loop_stop()