# 5s Transfer Learning on ESP32 Dataset (2s -> 5s random placement)

This notebook keeps the pipeline simple:
1. Load your trained 5s base model.
2. Load ESP32 2s clips (`cough` and `non_cough`).
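3. Create synthetic 5s samples (for the train, validation, and test splits) by:
   - building a noisy 5s background from ESP32 non-cough clips,
   - placing the 2s clip at a random offset inside the 5s window (the cough for the cough class, the non-cough event for the non-cough class).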
4. Extract MFCC with the same settings as the base model.
5. Transfer learn and export model.

Cough clips are never extended to 5s by zero-padding alone; each one is mixed into a 5s background instead.


In [1]:
import json
import random
from pathlib import Path

import numpy as np
import pandas as pd
import librosa
from tqdm import tqdm

import tensorflow as tf
import keras
from keras.utils import to_categorical

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

# Record the TensorFlow version in the cell output so saved results can be tied to the runtime used.
print('TensorFlow:', tf.__version__)


TensorFlow: 2.10.0


In [2]:
# =====================
# Configuration
# =====================
# Seed the three global RNGs used in this notebook: NumPy, stdlib `random`, and TensorFlow.
SEED = 42
np.random.seed(SEED)
random.seed(SEED)
tf.random.set_seed(SEED)

# Each *_CANDIDATES list is searched in order by first_existing(); the first
# entry is also the fallback when none of the locations exists.
BASE_MODEL_CANDIDATES = [
    Path('./cough_cnn_5s_base.h5'),
    Path('model/cough_cnn_5s_base.h5'),
]
ESP32_ROOT_CANDIDATES = [
    Path('./esp32_dataset'),
    Path('model/esp32_dataset'),
]
PUBLIC_DATASET_CANDIDATES = [
    Path('../public_dataset'),
    Path('public_dataset'),
]


def first_existing(paths):
    """Return the first path in `paths` that exists on disk.

    When none of the candidates exists, the first candidate is returned
    unchanged so that later existence checks can report a meaningful path.
    `paths` must be an indexable, non-empty sequence of objects that provide
    an `.exists()` method (e.g. `pathlib.Path`).
    """
    return next((candidate for candidate in paths if candidate.exists()), paths[0])


# Resolve each resource to the first candidate location that exists
# (or to the first candidate when none does).
BASE_MODEL_PATH = first_existing(BASE_MODEL_CANDIDATES)
ESP32_ROOT = first_existing(ESP32_ROOT_CANDIDATES)
PUBLIC_DATASET_DIR = first_existing(PUBLIC_DATASET_CANDIDATES)

# Class folders inside the ESP32 dataset root.
COUGH_DIR = ESP32_ROOT / 'cough'
NON_COUGH_DIR = ESP32_ROOT / 'non_cough'

# Filename prefix shared by every artefact this notebook writes (.h5, report .json, .tflite).
OUTPUT_PREFIX = 'cough_cnn_5s_transfer_esp32'

# Audio setup
SR = 16000
SRC_SECONDS = 2.0
TARGET_SECONDS = 5.0
SRC_SAMPLES = int(SR * SRC_SECONDS)
TARGET_SAMPLES = int(SR * TARGET_SECONDS)

# MFCC (must match base model preprocessing)
N_MFCC = 40
N_MELS = 128
N_FFT = 1024
HOP_LENGTH = 512
# Number of full N_FFT-sized windows that fit in TARGET_SAMPLES at this hop (155 here).
# NOTE(review): this is the frame count for un-centred framing. librosa.feature.mfcc is
# called below without a `center` argument; if librosa's centred default applies, a 5 s
# clip yields a few more frames and extract_mfcc_2d truncates them. Confirm this is the
# same convention the base model was trained with.
EXPECTED_FRAMES = 1 + int(np.floor((TARGET_SAMPLES - N_FFT) / float(HOP_LENGTH)))

# Training setup
TEST_SIZE = 0.15
VAL_SIZE_FROM_TRAIN = 0.1765  # ~15% val of total
BATCH_SIZE = 32
HEAD_EPOCHS = 8
FINE_TUNE_EPOCHS = 12
HEAD_LR = 5e-4
FINE_TUNE_LR = 1e-4

# Data synthesis controls
TRAIN_VERSIONS_PER_SAMPLE = 1  # avoid over-augmenting positives with noise
VAL_VERSIONS_PER_SAMPLE = 1
TEST_VERSIONS_PER_SAMPLE = 1
USE_PUBLIC_NOISE_IN_TRANSFER = False  # keep transfer domain close to ESP32 by default
COUGH_SNR_DB_RANGE = (8.0, 20.0)      # keep cough clearly audible in positive samples
NON_COUGH_EVENT_SNR_DB_RANGE = (0.0, 18.0)

# Echo the resolved configuration so the cell output documents what this run used.
print('Base model path:', BASE_MODEL_PATH)
print('ESP32 root:', ESP32_ROOT)
print('Public dataset:', PUBLIC_DATASET_DIR)
print('Target shape (frames, mfcc):', (EXPECTED_FRAMES, N_MFCC))
print('USE_PUBLIC_NOISE_IN_TRANSFER:', USE_PUBLIC_NOISE_IN_TRANSFER)


Base model path: cough_cnn_5s_base.h5
ESP32 root: esp32_dataset
Public dataset: ..\public_dataset
Target shape (frames, mfcc): (155, 40)
USE_PUBLIC_NOISE_IN_TRANSFER: False


In [3]:
# =====================
# Step 1: Build ESP32 metadata
# =====================

def collect_labeled_files(folder, label):
    """List the `*.wav` files directly inside `folder` as labelled metadata rows.

    Args:
        folder: `pathlib.Path` of the directory to scan (not recursive).
        label: class id stored with every file; converted with `int()`.

    Returns:
        A list of `{'wav_path': <absolute path str>, 'label': <int>}` dicts in
        sorted path order, or an empty list when `folder` does not exist.
    """
    if not folder.exists():
        return []
    return [
        {'wav_path': str(wav_file.resolve()), 'label': int(label)}
        for wav_file in sorted(folder.glob('*.wav'))
    ]

# Gather one metadata row per clip: label 1 for cough, label 0 for non-cough.
rows = []
rows += collect_labeled_files(COUGH_DIR, 1)
rows += collect_labeled_files(NON_COUGH_DIR, 0)

esp32_df = pd.DataFrame(rows)

if len(esp32_df) == 0:
    raise RuntimeError('No ESP32 wav files found. Check esp32_dataset/cough and esp32_dataset/non_cough')

# Both classes are required downstream (stratified split, class weights for
# classes [0, 1], ROC AUC). Fail here with a clear message instead of letting a
# later cell raise an unrelated-looking error.
missing_labels = sorted({0, 1} - set(esp32_df['label'].tolist()))
if missing_labels:
    raise RuntimeError(
        f'No ESP32 wav files found for label(s) {missing_labels} '
        f'(1 = {COUGH_DIR}, 0 = {NON_COUGH_DIR}). Both classes are required.'
    )

print('ESP32 files total:', len(esp32_df))
print('Class counts:', esp32_df['label'].value_counts().to_dict())
print(esp32_df.head(3))


ESP32 files total: 512
Class counts: {0: 258, 1: 254}
                                            wav_path  label
0  E:\minor-project\model\esp32_dataset\cough\cou...      1
1  E:\minor-project\model\esp32_dataset\cough\cou...      1
2  E:\minor-project\model\esp32_dataset\cough\cou...      1


In [None]:
# =====================
# Step 2: Train/Val/Test split
# =====================

# First carve off the held-out test files, stratified by label.
train_df, test_df = train_test_split(
    esp32_df, stratify=esp32_df['label'], test_size=TEST_SIZE, random_state=SEED
)

# Then split the remainder into train and validation, again stratified.
train_df, val_df = train_test_split(
    train_df, stratify=train_df['label'], test_size=VAL_SIZE_FROM_TRAIN, random_state=SEED
)

# Give every split a clean 0..n-1 index.
train_df, val_df, test_df = (
    part.reset_index(drop=True) for part in (train_df, val_df, test_df)
)

print('Train:', train_df.shape, train_df['label'].value_counts().to_dict())
print('Val:  ', val_df.shape, val_df['label'].value_counts().to_dict())
print('Test: ', test_df.shape, test_df['label'].value_counts().to_dict())


Train: (358, 2) {0: 180, 1: 178}
Val:   (77, 2) {0: 39, 1: 38}
Test:  (77, 2) {0: 39, 1: 38}


In [5]:
# =====================
# Step 3: Noise pools
# =====================

# Main background pool from ESP32 non-cough clips.
# Only training-split files are used, so val/test clips never serve as backgrounds.
ESP32_NOISE_PATHS = train_df.loc[train_df['label'] == 0, 'wav_path'].tolist()

# Optional extra noise pool from public dataset non-cough samples (cough_detected <= 0.2).
PUBLIC_NOISE_PATHS = []
# Count of public clips whose sidecar JSON could not be used, reported below
# so that bad metadata is visible instead of being silently dropped.
PUBLIC_NOISE_SKIPPED = 0
if USE_PUBLIC_NOISE_IN_TRANSFER and PUBLIC_DATASET_DIR.exists():
    for wav in sorted(PUBLIC_DATASET_DIR.glob('*.wav')):
        js = wav.with_suffix('.json')
        if not js.exists():
            continue
        try:
            score = float(json.loads(js.read_text(encoding='utf-8')).get('cough_detected'))
            if score <= 0.20:
                PUBLIC_NOISE_PATHS.append(str(wav.resolve()))
        except (OSError, ValueError, TypeError, AttributeError):
            # Still best-effort: unreadable file, malformed JSON / bad encoding
            # (ValueError), missing or non-numeric score (TypeError / ValueError),
            # or a JSON document that is not an object (AttributeError).
            PUBLIC_NOISE_SKIPPED += 1

print('ESP32 background pool:', len(ESP32_NOISE_PATHS))
print('Public background pool:', len(PUBLIC_NOISE_PATHS))
if PUBLIC_NOISE_SKIPPED:
    print('Public clips skipped (unusable metadata):', PUBLIC_NOISE_SKIPPED)


ESP32 background pool: 180
Public background pool: 0


In [6]:
# =====================
# Step 4: Waveform helpers
# =====================


def pad_or_trim(y, target_len):
    """Force a 1-D array to exactly `target_len` samples and return it as float32.

    Shorter inputs are zero-padded at the end, longer inputs keep only their
    first `target_len` samples, and inputs of the right length are only cast.
    """
    shortfall = target_len - len(y)
    if shortfall > 0:
        y = np.pad(y, (0, shortfall))
    elif shortfall < 0:
        y = y[:target_len]
    return y.astype(np.float32)


def load_2s_clip(path):
    """Load the first SRC_SECONDS of an audio file as mono float32 at SR Hz.

    The result always has exactly SRC_SAMPLES samples: shorter recordings are
    zero-padded at the end by `pad_or_trim`.
    """
    samples, _ = librosa.load(path, sr=SR, mono=True, offset=0.0, duration=SRC_SECONDS)
    return pad_or_trim(samples, SRC_SAMPLES).astype(np.float32)


def load_any_clip(path):
    """Load a whole audio file as mono float32 at SR Hz, whatever its duration."""
    samples = librosa.load(path, sr=SR, mono=True)[0]
    return samples.astype(np.float32)


def rms(x):
    """Return the root-mean-square level of `x` as a Python float.

    The mean of the squares is accumulated in float64, and a 1e-10 offset is
    added before the square root so silent input gives a small positive value.
    """
    mean_power = np.mean(np.square(x), dtype=np.float64)
    return float(np.sqrt(mean_power + 1e-10))


def mix_at_snr(signal, noise, snr_db):
    """Add `noise` to `signal`, rescaled so the mix has roughly `snr_db` dB SNR.

    Levels are measured with `rms`. The signal level is floored at 1e-4 and
    the noise level at 1e-6 to avoid extreme gains for near-silent inputs.
    Only the noise is rescaled; the signal is left at its original level.
    """
    signal_level = max(rms(signal), 1e-4)
    noise_level = max(rms(noise), 1e-6)
    amplitude_ratio = 10.0 ** (snr_db / 20.0)
    wanted_noise_level = signal_level / amplitude_ratio
    return signal + noise * (wanted_noise_level / noise_level)


def sample_noise_5s(rng):
    """Draw a TARGET_SAMPLES-long float32 background from the noise pools.

    One file is picked uniformly from ESP32_NOISE_PATHS plus PUBLIC_NOISE_PATHS.
    A clip at least TARGET_SAMPLES long contributes a random crop; a shorter
    clip is tiled end-to-end and cut to length.

    Args:
        rng: random generator providing `integers(low, high)`.

    Returns:
        A float32 array of length TARGET_SAMPLES. It is all zeros when both
        pools are empty or when the chosen file decodes to zero samples.
    """
    pool = list(ESP32_NOISE_PATHS) + list(PUBLIC_NOISE_PATHS)
    if not pool:
        return np.zeros(TARGET_SAMPLES, dtype=np.float32)

    p = pool[int(rng.integers(0, len(pool)))]
    y = load_any_clip(p)

    if len(y) == 0:
        # A header-only or otherwise empty recording would make the tiling
        # below divide by zero; treat it as a silent background instead.
        return np.zeros(TARGET_SAMPLES, dtype=np.float32)

    if len(y) >= TARGET_SAMPLES:
        # Long clip: take a random TARGET_SAMPLES-wide window.
        start = int(rng.integers(0, len(y) - TARGET_SAMPLES + 1))
        return y[start:start + TARGET_SAMPLES].astype(np.float32)

    # Short clip: repeat it enough times to cover the window, then trim.
    reps = int(np.ceil(TARGET_SAMPLES / len(y)))
    return np.tile(y, reps)[:TARGET_SAMPLES].astype(np.float32)


def wind_noise(n, rng):
    """Generate `n` samples of smoothed brown noise, peak-normalised to about ±1.

    A cumulative sum of Gaussian samples (brown noise) is peak-normalised,
    smoothed with a moving-average window of max(16, 0.03 * SR) samples, and
    peak-normalised again. Returns float32.
    """
    steps = rng.normal(0.0, 1.0, n)
    walk = np.cumsum(steps).astype(np.float32)
    walk = walk / (np.max(np.abs(walk)) + 1e-8)

    width = max(16, int(0.03 * SR))
    box = np.ones(width, dtype=np.float32) / width

    smoothed = np.convolve(walk, box, mode='same')
    smoothed = smoothed / (np.max(np.abs(smoothed)) + 1e-8)
    return smoothed.astype(np.float32)


def pink_noise(n, rng):
    """Generate `n` samples of pink-like noise, peak-normalised to about ±1.

    White Gaussian noise is shaped in the frequency domain by 1/sqrt(f); the
    DC bin is zeroed. The result is transformed back and returned as float32.
    """
    spectrum = np.fft.rfft(rng.normal(0.0, 1.0, n))
    bin_freqs = np.fft.rfftfreq(n, d=1.0 / SR)

    # 1/sqrt(f) for every positive frequency; bins at f == 0 keep a gain of 0.
    gain = np.divide(
        1.0,
        np.sqrt(bin_freqs),
        out=np.zeros_like(bin_freqs),
        where=bin_freqs > 0,
    )

    shaped = np.fft.irfft(spectrum * gain, n=n)
    shaped = shaped / (np.max(np.abs(shaped)) + 1e-8)
    return shaped.astype(np.float32)



In [7]:
# =====================
# Step 5: Build synthetic 5s sample from 2s clip
# =====================

def synthesize_5s_from_2s(src_2s, label, rng):
    """
    Build one synthetic 5 s waveform by embedding a 2 s clip in a sampled background.

    label=1 (cough): place cough at random position over noisy 5s background.
    label=0 (non-cough): place non-cough event + disturbances, so loud noise is learned as non-cough.

    Args:
        src_2s: source clip added into a SRC_SAMPLES-long slice of the background
            (build_feature_set passes the output of load_2s_clip).
        label: class id; compared as int(label) == 1.
        rng: random generator providing uniform/random/normal/integers. The result
            depends on the exact order of draws below, so reordering statements
            changes the seeded output.

    Returns:
        float32 array clipped to [-1.0, 1.0], with the length returned by sample_noise_5s.
    """
    # Start from real background noise
    y5 = sample_noise_5s(rng)

    # Baseline ambient variation
    y5 *= rng.uniform(0.70, 1.20)

    if int(label) == 1:
        # Keep positive background relatively clean; cough should stay dominant.
        # With probability 0.25, add low-level white Gaussian noise.
        if rng.random() < 0.25:
            y5 += rng.uniform(0.0005, 0.006) * rng.normal(0.0, 1.0, len(y5)).astype(np.float32)

        # Random insertion offset such that the whole clip fits inside the window.
        start = int(rng.integers(0, TARGET_SAMPLES - SRC_SAMPLES + 1))
        # NOTE(review): this random gain is largely divided out again by the RMS
        # rescale below, so it mainly advances the RNG stream.
        cough = src_2s.copy() * rng.uniform(0.9, 1.3)

        # Rescale the cough so its RMS sits snr_db above the RMS of the background
        # it will overlap (both RMS values floored at 1e-6).
        local_bg = y5[start:start + SRC_SAMPLES]
        c_rms = max(rms(cough), 1e-6)
        b_rms = max(rms(local_bg), 1e-6)
        snr_db = float(rng.uniform(COUGH_SNR_DB_RANGE[0], COUGH_SNR_DB_RANGE[1]))
        target_cough_rms = b_rms * (10.0 ** (snr_db / 20.0))
        cough = cough * (target_cough_rms / c_rms)

        y5[start:start + SRC_SAMPLES] += cough

    else:
        # Non-cough hard negatives: include loud non-cough events to break loudness shortcut.
        start = int(rng.integers(0, TARGET_SAMPLES - SRC_SAMPLES + 1))
        # NOTE(review): as in the cough branch, this gain is largely cancelled by the rescale below.
        non_cough_event = src_2s.copy() * rng.uniform(0.8, 1.6)

        # Same RMS-relative placement as the cough branch, using the non-cough SNR range.
        local_bg = y5[start:start + SRC_SAMPLES]
        e_rms = max(rms(non_cough_event), 1e-6)
        b_rms = max(rms(local_bg), 1e-6)
        snr_db = float(rng.uniform(NON_COUGH_EVENT_SNR_DB_RANGE[0], NON_COUGH_EVENT_SNR_DB_RANGE[1]))
        target_event_rms = b_rms * (10.0 ** (snr_db / 20.0))
        non_cough_event = non_cough_event * (target_event_rms / e_rms)

        y5[start:start + SRC_SAMPLES] += non_cough_event

        # Extra disturbances only for non-cough class
        # Each one is applied independently with its own probability.
        if rng.random() < 0.60:
            y5 += rng.uniform(0.001, 0.020) * wind_noise(len(y5), rng)
        if rng.random() < 0.55:
            y5 += rng.uniform(0.001, 0.015) * pink_noise(len(y5), rng)
        if rng.random() < 0.50:
            y5 += rng.uniform(0.0005, 0.012) * rng.normal(0.0, 1.0, len(y5)).astype(np.float32)
        if rng.random() < 0.40:
            # Short segment (0.005*SR up to, but excluding, 0.03*SR samples) at a random position.
            burst_len = int(rng.integers(int(0.005 * SR), int(0.03 * SR)))
            bstart = int(rng.integers(0, max(1, len(y5) - burst_len)))
            # NOTE(review): a single scalar is drawn, so this adds a constant offset
            # (a rectangular step) over the segment rather than a noise burst. Confirm intent.
            y5[bstart:bstart + burst_len] += rng.uniform(-0.35, 0.35)

    # Gentle compression effect (both classes)
    # With probability 0.20, soft-limit the waveform with tanh at a random ceiling c.
    if rng.random() < 0.20:
        c = rng.uniform(0.30, 0.90)
        y5 = np.tanh(y5 / c) * c

    # Hard-limit to the valid audio range before returning.
    return np.clip(y5, -1.0, 1.0).astype(np.float32)


In [8]:
# =====================
# Step 6: MFCC extraction (same settings as base)
# =====================

def extract_mfcc_2d(y,
                    sr=SR,
                    n_mfcc=N_MFCC,
                    n_mels=N_MELS,
                    n_fft=N_FFT,
                    hop_length=HOP_LENGTH,
                    max_frames=EXPECTED_FRAMES,
                    normalise=True):
    """Convert a waveform into a fixed-size (max_frames, n_mfcc) float32 MFCC matrix.

    The MFCCs are transposed to time-major order, then zero-padded at the end
    or truncated to `max_frames` rows. With `normalise=True` each coefficient
    column is standardised to zero mean and unit variance over the (padded or
    truncated) frames, with 1e-6 added to the standard deviation.
    """
    coeffs = librosa.feature.mfcc(
        y=y, sr=sr, n_mfcc=n_mfcc, n_mels=n_mels,
        n_fft=n_fft, hop_length=hop_length, htk=False,
    )
    frames = coeffs.T.astype(np.float32)

    # Positive: rows to append; negative: rows to drop.
    missing = max_frames - frames.shape[0]
    if missing > 0:
        filler = np.zeros((missing, n_mfcc), dtype=np.float32)
        frames = np.vstack([frames, filler])
    elif missing < 0:
        frames = frames[:max_frames, :]

    if normalise:
        column_mean = np.mean(frames, axis=0, keepdims=True)
        column_std = np.std(frames, axis=0, keepdims=True) + 1e-6
        frames = (frames - column_mean) / column_std

    return frames.astype(np.float32)


In [9]:
# =====================
# Step 7: Build feature tensors
# =====================

def build_feature_set(split_df, versions_per_sample=1, seed=42):
    """Synthesize 5 s samples for every row of a split and return MFCC tensors.

    Args:
        split_df: DataFrame with `wav_path` and `label` columns.
        versions_per_sample: number of synthetic variants generated per file.
        seed: seed for the NumPy generator that drives all synthesis randomness.

    Returns:
        `(X, y)` where `X` is float32 of shape (n, EXPECTED_FRAMES, N_MFCC)
        and `y` is int32 of shape (n,), with n = len(split_df) * versions_per_sample.
    """
    rng = np.random.default_rng(seed)
    n_versions = int(versions_per_sample)

    capacity = len(split_df) * n_versions
    X = np.zeros((capacity, EXPECTED_FRAMES, N_MFCC), dtype=np.float32)
    y = np.zeros((capacity,), dtype=np.int32)

    written = 0
    for record in tqdm(split_df.itertuples(index=False), total=len(split_df)):
        clip = load_2s_clip(record.wav_path)
        target = int(record.label)

        # All variants of one file share the loaded clip but draw fresh randomness.
        for _ in range(n_versions):
            wave = synthesize_5s_from_2s(clip, target, rng).astype(np.float32)
            X[written] = extract_mfcc_2d(wave)
            y[written] = target
            written += 1

    return X[:written], y[:written]


# Each split gets its own seed so the synthesized backgrounds and offsets differ between splits.
X_train, y_train = build_feature_set(train_df, TRAIN_VERSIONS_PER_SAMPLE, SEED)
X_val, y_val = build_feature_set(val_df, VAL_VERSIONS_PER_SAMPLE, SEED + 1)
X_test, y_test = build_feature_set(test_df, TEST_VERSIONS_PER_SAMPLE, SEED + 2)

print('X_train:', X_train.shape, 'y_train:', y_train.shape)
print('X_val:  ', X_val.shape, 'y_val:  ', y_val.shape)
print('X_test: ', X_test.shape, 'y_test: ', y_test.shape)
print('Train class counts:', {cls: int(np.sum(y_train == cls)) for cls in (0, 1)})


100%|██████████| 358/358 [00:05<00:00, 61.81it/s] 
100%|██████████| 77/77 [00:00<00:00, 88.31it/s]
100%|██████████| 77/77 [00:00<00:00, 82.33it/s]

X_train: (358, 155, 40) y_train: (358,)
X_val:   (77, 155, 40) y_val:   (77,)
X_test:  (77, 155, 40) y_test:  (77,)
Train class counts: {0: 180, 1: 178}





In [10]:
# Labels + class weights
# One-hot targets for the two-unit output of the model.
y_train_onehot, y_val_onehot, y_test_onehot = (
    to_categorical(labels, num_classes=2) for labels in (y_train, y_val, y_test)
)

# 'balanced' weights computed from the training labels only.
class_weights = compute_class_weight(class_weight='balanced', classes=np.array([0, 1]), y=y_train)
class_weight = {cls: float(weight) for cls, weight in enumerate(class_weights)}

print('Class weights:', class_weight)


Class weights: {0: 0.9944444444444445, 1: 1.0056179775280898}


In [11]:
# =====================
# Step 8: Transfer learning
# =====================
# Two phases: (1) train only the last four layers of the loaded base model,
# (2) unfreeze everything and fine-tune at a lower learning rate.
if not BASE_MODEL_PATH.exists():
    raise FileNotFoundError(f'Base model not found: {BASE_MODEL_PATH}')

model = keras.models.load_model(str(BASE_MODEL_PATH))
print('Loaded base model from:', BASE_MODEL_PATH)

# Phase 1: freeze most layers, train only top layers
for layer in model.layers:
    layer.trainable = False

# Re-enable training for the last four layers only.
for layer in model.layers[-4:]:
    layer.trainable = True

# Compile after changing `trainable` so the new flags are used for training.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=HEAD_LR),
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=0.03),
    metrics=['accuracy']
)

# Phase-1 callbacks: keep the best-val_loss head model on disk and halve the
# learning rate after 3 epochs without val_loss improvement.
callbacks = [
    keras.callbacks.ModelCheckpoint(
        filepath=f'{OUTPUT_PREFIX}_head.h5',
        monitor='val_loss',
        mode='min',
        save_best_only=True,
        verbose=1,
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        mode='min',
        factor=0.5,
        patience=3,
        min_lr=1e-5,
        verbose=1,
    ),
]

history_head = model.fit(
    X_train, y_train_onehot,
    validation_data=(X_val, y_val_onehot),
    epochs=HEAD_EPOCHS,
    batch_size=BATCH_SIZE,
    class_weight=class_weight,
    callbacks=callbacks,
    verbose=1
)

# Phase 2: unfreeze all, fine-tune with lower LR
# NOTE(review): fine-tuning continues from the weights of the last head epoch;
# the best head checkpoint written to '<prefix>_head.h5' is not reloaded here.
for layer in model.layers:
    layer.trainable = True

# Recompile so the unfrozen layers and the lower learning rate take effect.
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=FINE_TUNE_LR),
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=0.02),
    metrics=['accuracy']
)

# Phase-2 callbacks: checkpoint the best full model, reduce LR on plateau, and
# stop after 6 epochs without val_loss improvement.
callbacks_ft = [
    keras.callbacks.ModelCheckpoint(
        filepath=f'{OUTPUT_PREFIX}.h5',
        monitor='val_loss',
        mode='min',
        save_best_only=True,
        verbose=1,
    ),
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',
        mode='min',
        factor=0.5,
        patience=3,
        min_lr=1e-6,
        verbose=1,
    ),
    keras.callbacks.EarlyStopping(
        monitor='val_loss',
        mode='min',
        patience=6,
        restore_best_weights=True,
        verbose=1,
    )
]

history_ft = model.fit(
    X_train, y_train_onehot,
    validation_data=(X_val, y_val_onehot),
    epochs=FINE_TUNE_EPOCHS,
    batch_size=BATCH_SIZE,
    class_weight=class_weight,
    callbacks=callbacks_ft,
    verbose=1
)

# Continue with the checkpointed best-val_loss model from phase 2 for evaluation and export.
model = keras.models.load_model(f'{OUTPUT_PREFIX}.h5')
print('Loaded best transfer model:', f'{OUTPUT_PREFIX}.h5')


Loaded base model from: cough_cnn_5s_base.h5
Epoch 1/8
Epoch 1: val_loss improved from inf to 1.00509, saving model to cough_cnn_5s_transfer_esp32_head.h5
Epoch 2/8
 1/12 [=>............................] - ETA: 0s - loss: 0.8832 - accuracy: 0.5938
Epoch 2: val_loss improved from 1.00509 to 0.87555, saving model to cough_cnn_5s_transfer_esp32_head.h5
Epoch 3/8
 1/12 [=>............................] - ETA: 0s - loss: 0.9720 - accuracy: 0.5000
Epoch 3: val_loss improved from 0.87555 to 0.79889, saving model to cough_cnn_5s_transfer_esp32_head.h5
Epoch 4/8
Epoch 4: val_loss improved from 0.79889 to 0.75927, saving model to cough_cnn_5s_transfer_esp32_head.h5
Epoch 5/8
Epoch 5: val_loss improved from 0.75927 to 0.73632, saving model to cough_cnn_5s_transfer_esp32_head.h5
Epoch 6/8
 1/12 [=>............................] - ETA: 0s - loss: 0.6829 - accuracy: 0.5312
Epoch 6: val_loss improved from 0.73632 to 0.72119, saving model to cough_cnn_5s_transfer_esp32_head.h5
Epoch 7/8
Epoch 7: val_los

In [12]:
# =====================
# Step 9: Raw probabilities + test metrics (argmax decision)
# =====================
# Score both held-out splits with the reloaded best checkpoint.
val_pred_prob = model.predict(X_val, batch_size=BATCH_SIZE, verbose=1)
test_pred_prob = model.predict(X_test, batch_size=BATCH_SIZE, verbose=1)

# Column 0 holds the non-cough score and column 1 the cough score,
# matching the 0/1 labels that were one-hot encoded for training.
val_prob_non_cough, val_prob_cough = val_pred_prob[:, 0], val_pred_prob[:, 1]
test_prob_non_cough, test_prob_cough = test_pred_prob[:, 0], test_pred_prob[:, 1]

# Decision rule: take the highest-scoring class, with no tuned threshold.
test_pred = test_pred_prob.argmax(axis=1).astype(np.int32)

print('Val probability summary:')
print(pd.DataFrame({'p_cough': val_prob_cough, 'p_non_cough': val_prob_non_cough}).describe().T)

print('\nTest probability summary:')
print(pd.DataFrame({'p_cough': test_prob_cough, 'p_non_cough': test_prob_non_cough}).describe().T)

preview_df = pd.DataFrame(
    {
        'y_true': y_test,
        'pred': test_pred,
        'p_cough': test_prob_cough,
        'p_non_cough': test_prob_non_cough,
    }
)

# Heading and payload are printed together, one per line.
print('\nRaw probability preview (first 20 test samples):', preview_df.head(20), sep='\n')
print('\nConfusion matrix (test, argmax):', confusion_matrix(y_test, test_pred), sep='\n')
print(
    '\nClassification report (test, argmax):',
    classification_report(y_test, test_pred, digits=4),
    sep='\n',
)

print('Test AUC (using p_cough):', roc_auc_score(y_test, test_prob_cough))


Val probability summary:
             count      mean      std       min       25%       50%       75%  \
p_cough       77.0  0.450966  0.14087  0.081205  0.381805  0.443077  0.546257   
p_non_cough   77.0  0.549034  0.14087  0.240031  0.453743  0.556923  0.618195   

                  max  
p_cough      0.759969  
p_non_cough  0.918795  

Test probability summary:
             count      mean       std       min       25%       50%  \
p_cough       77.0  0.451555  0.148593  0.097577  0.361660  0.433655   
p_non_cough   77.0  0.548445  0.148593  0.174669  0.440806  0.566345   

                  75%       max  
p_cough      0.559194  0.825330  
p_non_cough  0.638340  0.902423  

Raw probability preview (first 20 test samples):
    y_true  pred   p_cough  p_non_cough
0        1     1  0.599284     0.400716
1        1     1  0.511056     0.488944
2        0     0  0.489676     0.510324
3        0     0  0.102863     0.897137
4        1     0  0.285394     0.714606
5        0     0  0.379

In [13]:
# =====================
# Step 10: Save report + TFLite
# =====================
# JSON-serialisable summary of this run: configuration, dataset sizes, synthesis
# settings and test metrics. NumPy scalars are cast to int/float before being stored.
report = {
    'output_prefix': OUTPUT_PREFIX,
    'base_model_path': str(BASE_MODEL_PATH),
    'seed': SEED,
    # Waveform lengths used for the source clips and the synthesized window.
    'audio': {
        'sr': SR,
        'src_seconds': SRC_SECONDS,
        'target_seconds': TARGET_SECONDS,
        'src_samples': SRC_SAMPLES,
        'target_samples': TARGET_SAMPLES,
    },
    # Feature-extraction settings passed to extract_mfcc_2d by default.
    'mfcc': {
        'n_mfcc': N_MFCC,
        'n_mels': N_MELS,
        'n_fft': N_FFT,
        'hop_length': HOP_LENGTH,
        'expected_frames': EXPECTED_FRAMES,
    },
    # File counts per split, sample counts after synthesis, and noise-pool sizes.
    'counts': {
        'esp32_files_total': int(len(esp32_df)),
        'train_files': int(len(train_df)),
        'val_files': int(len(val_df)),
        'test_files': int(len(test_df)),
        'train_samples_after_synthesis': int(len(y_train)),
        'val_samples_after_synthesis': int(len(y_val)),
        'test_samples_after_synthesis': int(len(y_test)),
        'esp32_noise_pool': int(len(ESP32_NOISE_PATHS)),
        'public_noise_pool': int(len(PUBLIC_NOISE_PATHS)),
    },
    'synthesis': {
        'train_versions_per_sample': TRAIN_VERSIONS_PER_SAMPLE,
        'val_versions_per_sample': VAL_VERSIONS_PER_SAMPLE,
        'test_versions_per_sample': TEST_VERSIONS_PER_SAMPLE,
        'cough_insert_random_position': True,
        'zero_padding_only': False,
    },
    # Test metrics, recomputed here from the predictions of the evaluation cell.
    'decision_rule': 'argmax(class_probabilities)',
    'test_auc': float(roc_auc_score(y_test, test_prob_cough)),
    'classification_report': classification_report(y_test, test_pred, output_dict=True, digits=6),
    'confusion_matrix': confusion_matrix(y_test, test_pred).tolist(),
    # Keep a tiny preview so report stays small but still shows raw scores.
    'test_probability_preview_first_20': [
        {
            'y_true': int(y_test[i]),
            'pred': int(test_pred[i]),
            'p_cough': float(test_prob_cough[i]),
            'p_non_cough': float(test_prob_non_cough[i]),
        }
        for i in range(min(20, len(test_pred)))
    ],
}

# Written next to the notebook as '<OUTPUT_PREFIX>_report.json'.
report_path = Path(f'{OUTPUT_PREFIX}_report.json')
report_path.write_text(json.dumps(report, indent=2), encoding='utf-8')
print('Saved report:', report_path.resolve())


def representative_data_gen():
    """Yield up to 300 training feature matrices for TFLite quantization calibration.

    Each item is a one-element list holding a float32 array of shape
    (1, frames, n_mfcc), sliced from X_train without replacement.

    The indices come from a local generator seeded with SEED, so the same
    calibration subset is produced on every invocation and regardless of how
    much the global NumPy RNG was used by earlier cells.
    """
    n = min(300, len(X_train))
    calib_rng = np.random.default_rng(SEED)
    idx = calib_rng.choice(len(X_train), size=n, replace=False)
    for i in idx:
        # Slice (not index) to keep the leading batch dimension of size 1.
        yield [X_train[i:i+1].astype(np.float32)]

# Convert the fine-tuned Keras model to a TFLite model restricted to INT8 builtin
# ops, using representative_data_gen as the representative dataset.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_data_gen
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
# The model's input and output tensors are int8 as well, so callers must
# quantize inputs / dequantize outputs with the scale and zero point printed below.
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8

quant_tflite = converter.convert()
tflite_path = Path(f'{OUTPUT_PREFIX}_int8.tflite')
tflite_path.write_bytes(quant_tflite)
print('Saved tflite:', tflite_path.resolve())

# Reload the written file to confirm it parses and to read back its tensor metadata.
interpreter = tf.lite.Interpreter(model_path=str(tflite_path))
interpreter.allocate_tensors()

in_d = interpreter.get_input_details()[0]
out_d = interpreter.get_output_details()[0]

# 'quantization' is the (scale, zero_point) pair of each tensor.
print('Input quantization:', in_d['quantization'], 'dtype:', in_d['dtype'], 'shape:', in_d['shape'])
print('Output quantization:', out_d['quantization'], 'dtype:', out_d['dtype'], 'shape:', out_d['shape'])


Saved report: E:\minor-project\model\cough_cnn_5s_transfer_esp32_report.json




INFO:tensorflow:Assets written to: C:\Users\Aman\AppData\Local\Temp\tmpsfli8_rb\assets


INFO:tensorflow:Assets written to: C:\Users\Aman\AppData\Local\Temp\tmpsfli8_rb\assets


Saved tflite: E:\minor-project\model\cough_cnn_5s_transfer_esp32_int8.tflite
Input quantization: (0.05734632909297943, -4) dtype: <class 'numpy.int8'> shape: [  1 155  40]
Output quantization: (0.00390625, -128) dtype: <class 'numpy.int8'> shape: [1 2]
