In [1]:
#-------------------------------------------------------------------------------------JUPYTER NOTEBOOK SETTINGS-------------------------------------------------------------------------------------
# `display` and `HTML` are imported from the public `IPython.display` module; importing them
# from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
# Inject CSS so the notebook container spans the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

  from IPython.core.display import display, HTML


In [2]:
import gc
import inspect
import os
import re
from types import SimpleNamespace

import librosa
import numpy as np
import pandas as pd
from joblib import dump, load
from tqdm.notebook import tqdm

from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder

import tensorflow as tf
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.layers import Input, Conv1D, MaxPooling1D, Dropout, Flatten, Dense, BatchNormalization
from tensorflow.keras.regularizers import l2
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import Callback, ReduceLROnPlateau, ModelCheckpoint, EarlyStopping

import seaborn as sns
import matplotlib.pyplot as plt
import plotly.express as px

In [3]:
# Query TensorFlow for visible GPUs once, report how many there are,
# and leave the device list as the cell's displayed value.
gpu_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpu_devices))
gpu_devices

Num GPUs Available:  1


[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

### Data Loading and Processing

In [None]:
def calculate_frames(audio_path, window_size_ms=20, hop_size_ms=10, sr=16000):
    """
    Calculate the number of frames in an audio file given the window size, hop size, and sample rate.

    Only complete windows are counted (no padding of the signal), so a clip that is
    shorter than one window yields 0 frames rather than a zero or negative count.

    NOTE(review): this count is independent of the MFCC extraction used elsewhere in the
    notebook (different window length, and librosa may pad the signal) — confirm before
    using it to size model inputs.

    Parameters:
    - audio_path: str, path to the audio file.
    - window_size_ms: float, window size in milliseconds.
    - hop_size_ms: float, hop size in milliseconds.
    - sr: int, sample rate in Hz.

    Returns:
    - int, number of complete frames (never negative).
    """
    # Load audio file, resampled to `sr`
    audio, _ = librosa.load(audio_path, sr=sr)

    # Convert window and hop size from milliseconds to samples
    window_size_samples = int(sr * window_size_ms / 1000)
    hop_size_samples = int(sr * hop_size_ms / 1000)

    # Not even one full window fits: the closed-form expression below would go <= 0
    if len(audio) < window_size_samples:
        return 0

    # One frame for the first window, plus one per additional full hop
    return (len(audio) - window_size_samples) // hop_size_samples + 1

# Example usage
# NOTE(review): the path below is absolute and machine-specific, so this cell only runs
# where that exact file exists — consider a configurable data directory.
audio_path = '/Users/ciprian/Desktop/Projects/Smart Plant Pot/Audio/Voice Recognition/Prototype 3/silence/silence_sample_600.wav'
num_frames = calculate_frames(audio_path)
print(f"Number of frames per sample: {num_frames}")

In [None]:
def time_masking(mfccs, width=10, num_masks=1):
    """
    Apply time masking: zero out random runs of consecutive columns (axis 1) of an MFCC matrix.

    Parameters:
    - mfccs: 2-D array; masking is applied along axis 1. The input is not modified.
    - width: int, requested number of consecutive columns per mask. It is capped at
      (number of columns - 1) so a single mask can never blank the whole matrix.
    - num_masks: int, number of masks to apply (masks may overlap).

    Returns:
    - ndarray, a masked copy with the same shape as `mfccs`.
    """
    masked_mfccs = np.copy(mfccs)
    num_frames = masked_mfccs.shape[1]

    # A width >= the axis length previously produced an inverted sampling range, which
    # made the mask either wrap around via a negative index or erase every column.
    width = min(int(width), num_frames - 1)
    if width < 1:
        return masked_mfccs  # nothing sensible to mask

    for _ in range(num_masks):
        # randint's upper bound is exclusive, so +1 makes the last valid start reachable
        t = np.random.randint(0, num_frames - width + 1)
        masked_mfccs[:, t:t + width] = 0
    return masked_mfccs

def frequency_masking(mfccs, width=5, num_masks=1):
    """
    Apply frequency masking: zero out random runs of consecutive rows (axis 0) of an MFCC matrix.

    Parameters:
    - mfccs: 2-D array; masking is applied along axis 0. The input is not modified.
    - width: int, requested number of consecutive rows per mask. It is capped at
      (number of rows - 1) so a single mask can never blank the whole matrix.
    - num_masks: int, number of masks to apply (masks may overlap).

    Returns:
    - ndarray, a masked copy with the same shape as `mfccs`.
    """
    masked_mfccs = np.copy(mfccs)
    num_coeffs = masked_mfccs.shape[0]

    # A width >= the axis length previously produced an inverted sampling range, which
    # made the mask either wrap around via a negative index or erase every row.
    width = min(int(width), num_coeffs - 1)
    if width < 1:
        return masked_mfccs  # nothing sensible to mask

    for _ in range(num_masks):
        # randint's upper bound is exclusive, so +1 makes the last valid start reachable
        f = np.random.randint(0, num_coeffs - width + 1)
        masked_mfccs[f:f + width, :] = 0
    return masked_mfccs

def time_warping(signal, sr, warping_factor=0.5):
    """
    Stretch the leading part of a 1-D signal back out to the signal's full length.

    A count of trailing samples to drop is derived from `warping_factor`; the remaining
    leading samples are linearly interpolated onto the original number of sample positions.
    When that count is below 1 (e.g. `warping_factor` >= 1) or not smaller than the signal
    length, the input is returned untouched.

    Parameters:
    - signal: 1-D numpy array of samples.
    - sr: sample rate; accepted for interface symmetry but not used here.
    - warping_factor: float, roughly the fraction of the signal that is kept.

    Returns:
    - ndarray with `signal.size` samples, or `signal` itself when no warp is possible.
    """
    total = signal.size
    dropped = int(total * (1 - warping_factor))

    # Guard: only warp when at least one sample is dropped and at least one is kept
    if not (1 <= dropped < total):
        return signal

    kept = total - dropped
    query_positions = np.arange(total)
    anchor_positions = np.linspace(0, total, num=kept)
    return np.interp(query_positions, anchor_positions, signal[:kept])

def apply_augmentation(signal, sr, mfccs, intensity):
    """
    Apply random augmentations conditionally based on the specified intensity.

    Mask widths are capped at half of the axis they act on. The previously hard-coded
    frequency widths (15 and 10) exceeded or nearly covered the number of MFCC rows,
    so a mask could blank most or all of a sample.

    Note: time warping only changes the returned `signal`; `mfccs` is not recomputed
    from the warped signal inside this function.

    Parameters:
    - signal: 1-D array of audio samples.
    - sr: sample rate, forwarded to `time_warping`.
    - mfccs: 2-D MFCC matrix; axis 0 is masked by `frequency_masking`, axis 1 by `time_masking`.
    - intensity: 'high', 'medium' or 'low'; any other value leaves both inputs unchanged.

    Returns:
    - (signal, mfccs) tuple after augmentation.
    """
    num_coeffs, num_frames = mfccs.shape[0], mfccs.shape[1]

    def _cap(requested, axis_len):
        # Never mask more than half of an axis in a single mask
        return min(requested, axis_len // 2)

    if intensity == 'high':
        signal = time_warping(signal, sr, warping_factor=np.random.uniform(0.7, 1.3))
        mfccs = time_masking(mfccs, width=_cap(20, num_frames), num_masks=2)
        mfccs = frequency_masking(mfccs, width=_cap(15, num_coeffs), num_masks=2)
    elif intensity == 'medium':
        # Warp only half of the time at this intensity
        if np.random.rand() < 0.5:
            signal = time_warping(signal, sr, warping_factor=np.random.uniform(0.85, 1.15))
        mfccs = time_masking(mfccs, width=_cap(10, num_frames), num_masks=1)
        mfccs = frequency_masking(mfccs, width=_cap(10, num_coeffs), num_masks=1)
    elif intensity == 'low':
        # Exactly one narrow mask, on a randomly chosen axis
        choice = np.random.choice(['time', 'freq'])
        if choice == 'time':
            mfccs = time_masking(mfccs, width=_cap(5, num_frames), num_masks=1)
        else:
            mfccs = frequency_masking(mfccs, width=_cap(5, num_coeffs), num_masks=1)
    return signal, mfccs

def load_and_augment_data(directory, augment=True):
    """
    Load every `.wav` file found in the label sub-directories of `directory`, extract
    MFCC features, optionally augment them, and pad or truncate each sample to a fixed length.

    Each immediate sub-directory name is used as the class label of the files it contains.
    Labels and files are processed in sorted order so the sample order does not depend on
    the file system's listing order.

    Note: `apply_augmentation` may return a warped signal, but the features are computed
    before augmentation and are not recomputed, so only the masking affects the output.

    Parameters:
    - directory: str, dataset root containing one sub-directory per label.
    - augment: bool, when True each sample independently receives no, low, medium or high
      augmentation with equal probability.

    Returns:
    - x: float32 ndarray of shape (num_samples, 13, 332) when at least one sample was found.
    - y: ndarray of label strings, aligned with `x`.
    """
    # Fixed number of frames per sample; presumably the longest MFCC sequence in the
    # dataset (a scan to compute it used to live here as commented-out code).
    max_length = 332
    labels = sorted(
        label for label in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, label))
    )
    x, y = [], []

    for label in tqdm(labels, desc="Loading and Padding Data"):
        label_path = os.path.join(directory, label)
        wav_files = sorted(
            os.path.join(label_path, file)
            for file in os.listdir(label_path)
            if file.endswith('.wav')
        )

        for wav_file in tqdm(wav_files, desc=f"Padding {label}", leave=False):
            signal, sr = librosa.load(wav_file, sr=16000)
            mfccs = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=13, n_fft=256, hop_length=160, n_mels=32, fmin=0, fmax=8000)

            if augment:
                # Determine the level of augmentation based on a random choice
                intensity = np.random.choice(['none', 'low', 'medium', 'high'], p=[0.25, 0.25, 0.25, 0.25])
                if intensity != 'none':
                    signal, mfccs = apply_augmentation(signal, sr, mfccs, intensity)

            # Truncate over-long clips first; a negative pad width would make np.pad raise
            mfccs = mfccs[:, :max_length]
            pad_width = max_length - mfccs.shape[1]
            mfccs_padded = np.pad(mfccs, pad_width=((0, 0), (0, pad_width)), mode='constant')
            x.append(mfccs_padded)
            y.append(label)

    # All samples share one shape, so a dense float32 array is used instead of an
    # object array (far smaller in memory and on disk).
    return np.array(x, dtype=np.float32), np.array(y)

# Load data and split
# Seed NumPy's global RNG so the randomly chosen augmentations are repeatable,
# matching the fixed random_state used for the splits below.
np.random.seed(42)
directory = "/Users/ciprian/Desktop/Projects/Smart Plant Pot/Audio/Voice Recognition/Prototype 3"
x, y = load_and_augment_data(directory)
# NOTE(review): augmentation runs before the split, so the validation and test sets also
# contain augmented samples — confirm this is intended.
# 70% train, then the remaining 30% is halved into validation and test (stratified by label).
x_train, x_temp, y_train, y_temp = train_test_split(x, y, test_size=0.3, random_state=42, stratify=y)
x_val, x_test, y_val, y_test = train_test_split(x_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp)

In [None]:
# Assuming x_train is properly shaped and contains MFCC features
# Both figures are read from the same (first) sample, so a single sample is enough.
if len(x_train) > 0:
    first_sample = x_train[0]
    print(f"The number of features extracted from one of the samples is: {len(first_sample)}")
    print(f"Number of frames (windows) in the first sample: {first_sample.shape[1]}\n")
else:
    print("Not enough samples in x_train to display features.\n")

# Size of each split, with the label count alongside as a consistency check
print(f"The number of samples for training is {len(x_train)}, with the number of labels {len(y_train)}")
print(f"The number of samples for validation is {len(x_val)}, with the number of labels {len(y_val)}")
print(f"The number of samples for testing is {len(x_test)}, with the number of labels {len(y_test)}")

In [None]:
# Preview table: the first 25 values of each flattened MFCC matrix, for up to the first 25 samples.
# The row count is bounded by the dataset size so small datasets do not raise IndexError.
num_preview = min(25, len(x))
x_flattened_for_display = [x[i].flatten()[:25] for i in range(num_preview)]

# Create a DataFrame with the flattened MFCC features and labels
df = pd.DataFrame(x_flattened_for_display)
df['Label'] = y[:num_preview]  # Add the labels as the last column

# Setting column names for clarity in display
feature_columns = [f'Feature_{i+1}' for i in range(df.shape[1] - 1)]  # Feature column names
df.columns = feature_columns + ['Label']  # Rename the columns for better understanding

# Display the DataFrame in Jupyter Notebook
df

In [None]:
# Persist the three splits so later cells can reload them without re-extracting features.
# The output directory is created first; joblib does not create missing parent directories.
os.makedirs('saved_data', exist_ok=True)
dump((x_train, y_train), 'saved_data/train_data.joblib')
dump((x_val, y_val), 'saved_data/val_data.joblib')
dump((x_test, y_test), 'saved_data/test_data.joblib')
print("All extraded features from the samples have been saved properly!")

### Data Visualisation

In [None]:
def flatten_features(x):
    """Flatten every sample in `x` to 1-D (with a tqdm progress bar) and stack the results into one array."""
    flat_rows = []
    for mfcc_matrix in tqdm(x, desc="Flattening MFCCs"):
        flat_rows.append(mfcc_matrix.reshape(-1))
    return np.array(flat_rows)

def apply_tsne(features, perplexity=30, n_components=3, n_iter=1000):
    """
    Project `features` to `n_components` dimensions with t-SNE.

    Parameters:
    - features: 2-D array-like of shape (num_samples, num_features).
    - perplexity: float, t-SNE perplexity.
    - n_components: int, dimensionality of the embedding.
    - n_iter: int, maximum number of optimisation iterations.

    Returns:
    - ndarray of shape (num_samples, n_components).
    """
    print("Running t-SNE, this may take a while...")
    # Newer scikit-learn releases renamed TSNE's `n_iter` argument to `max_iter`;
    # pick whichever keyword the installed version accepts.
    accepted = inspect.signature(TSNE).parameters
    iter_kwarg = "max_iter" if "max_iter" in accepted else "n_iter"
    tsne = TSNE(
        n_components=n_components,
        perplexity=perplexity,
        random_state=42,
        verbose=1,
        **{iter_kwarg: n_iter},
    )
    return tsne.fit_transform(features)

# Assuming x and y have been loaded and prepared
x_flat = flatten_features(x)  # Flatten with progress update
x_tsne = apply_tsne(x_flat)  # This will print updates due to verbose=1

# Cache the embedding; the target directory is created first so the dump cannot fail on a fresh checkout
os.makedirs('saved_data', exist_ok=True)
dump(x_tsne, 'saved_data/tsne_results.joblib')

In [None]:
# Load back the computed t-SNE
# (lets the plotting cell below run without repeating the t-SNE computation)
x_tsne = load('saved_data/tsne_results.joblib')

In [None]:
# Integer class codes used to colour the points. No earlier cell defines `y_encoded`,
# so it is derived here from the string labels in `y`.
# NOTE(review): assumes the rows of `x_tsne` are in the same order as `y` — confirm when
# the embedding was loaded from disk.
y_encoded = LabelEncoder().fit_transform(y)

# Create the 3D scatter plot
fig = px.scatter_3d(
    x=x_tsne[:, 0], y=x_tsne[:, 1], z=x_tsne[:, 2],
    color=y_encoded,
    color_continuous_scale=px.colors.sequential.Viridis,  # Using a continuous color scale
    labels={'color': 'Label'},
    title="3D Scatter Plot of Voice Samples via t-SNE"
)

# Update plot layout to increase height
fig.update_layout(
    width=1500,  # Adjust width as necessary
    height=1500,  # Increased height
    autosize=False
)

# Update marker size if needed
fig.update_traces(marker=dict(size=2))

# Display the plot
fig.show()

### CNN Setup and Training

In [4]:
# Reload the cached (features, labels) tuples written by the earlier dump cell,
# so training can start without re-running feature extraction.
x_train, y_train = load('saved_data/train_data.joblib')
x_val, y_val = load('saved_data/val_data.joblib')
x_test, y_test = load('saved_data/test_data.joblib')

In [None]:
# x = np.concatenate((x_train, x_val, x_test))
# y = np.concatenate((y_train, y_val, y_test))

In [5]:
# Ensure all input data is float32
# Each split is rebuilt as a fresh float32 array; the tuple of inputs is read before any name is rebound.
x_train, x_val, x_test = (
    np.array(split, dtype=np.float32) for split in (x_train, x_val, x_test)
)

In [6]:
# Report shape and dtype of every split
for _caption, _split in (("Train data:", x_train), ("Validation data:", x_val), ("Test data:", x_test)):
    print(_caption, _split.shape, _split.dtype)

# Check for any NaN or inf values in your dataset
for _caption, _split in (("NaNs in train:", x_train), ("NaNs in validation:", x_val), ("NaNs in test:", x_test)):
    print(_caption, np.isnan(_split).any())

for _caption, _split in (("Infs in train:", x_train), ("Infs in validation:", x_val), ("Infs in test:", x_test)):
    print(_caption, np.isinf(_split).any())

Train data: (63000, 13, 332) float32
Validation data: (13500, 13, 332) float32
Test data: (13500, 13, 332) float32
NaNs in train: False
NaNs in validation: False
NaNs in test: False
Infs in train: False
Infs in validation: False
Infs in test: False


In [7]:
# ONEHOT ENCODING THE LABELS
# The encoder is fitted on the training labels only and reused for the other splits.
label_encoder = LabelEncoder()
y_train_encoded = label_encoder.fit_transform(y_train)
y_val_encoded = label_encoder.transform(y_val)
y_test_encoded = label_encoder.transform(y_test)

# Convert labels to one-hot encoding
# num_classes is pinned so every split gets the same width even if a split happens
# to lack the highest-numbered class.
num_classes = len(label_encoder.classes_)
y_train_onehot = to_categorical(y_train_encoded, num_classes=num_classes)
y_val_onehot = to_categorical(y_val_encoded, num_classes=num_classes)
y_test_onehot = to_categorical(y_test_encoded, num_classes=num_classes)

In [9]:
# CONVOLUTIONAL NEURAL NETWORK SETUP AND TRAINING
# Model architecture
# Two convolutional stages (32 filters, then 64), each made of two Conv1D + BatchNormalization
# pairs followed by max-pooling and dropout, then an L2-regularised dense layer and a softmax output.
# NOTE(review): with Input(shape=(13, 332)) the Conv1D layers presumably slide along the
# 13-long axis and treat the 332-long axis as channels (Keras channels-last default) —
# confirm this is intended rather than convolving along the 332 frames.
model = Sequential([
    Input(shape=(13, 332)),  # Input shape is specified here
    Conv1D(32, kernel_size=3, activation='relu', padding='same'),
    BatchNormalization(),
    Conv1D(32, kernel_size=3, activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling1D(pool_size=2),
    Dropout(0.25),
    Conv1D(64, kernel_size=2, activation='relu', padding='same'),
    BatchNormalization(),
    Conv1D(64, kernel_size=2, activation='relu', padding='same'),
    BatchNormalization(),
    MaxPooling1D(pool_size=2),
    Dropout(0.25),
    Flatten(),
    Dense(64, activation='relu', kernel_regularizer=l2(0.01)),
    Dropout(0.5),
    Dense(y_train_onehot.shape[1], activation='softmax')  # Adjust the number of units to match the number of classes
])

def load_latest_weights(weights_dir, file_pattern):
    """
    Find the most recently modified weights file in a directory.

    Despite its name, this function does not load anything into a model: it only
    returns the path, and the caller is expected to pass it to `model.load_weights`.

    Parameters:
    - weights_dir: str, directory to search (non-recursive).
    - file_pattern: str, substring that a file name must contain to be considered.

    Returns:
    - str path of the newest matching regular file, or None when the directory does
      not exist or holds no matching file.
    """
    # A missing directory simply means there is nothing to resume from
    if not os.path.isdir(weights_dir):
        print("No weights file found.")
        return None

    # Keep matching regular files only (a sub-directory with a matching name is not a weights file)
    candidates = (os.path.join(weights_dir, f) for f in os.listdir(weights_dir) if file_pattern in f)
    all_weights = [path for path in candidates if os.path.isfile(path)]

    # Find the most recent file based on modification time
    latest_weights = max(all_weights, key=os.path.getmtime, default=None)
    if latest_weights:
        print(f"Loading weights from {latest_weights}")
        return latest_weights
    else:
        print("No weights file found.")
        return None

class SaveWeightsCallback(Callback):
    """Keras callback that writes the model's weights to disk every `save_freq` epochs."""

    def __init__(self, save_freq, filepath):
        """
        Parameters:
        - save_freq: int, save after every `save_freq`-th completed epoch.
        - filepath: str, path template; its `{epoch}` field is filled with the 1-based epoch number.
        """
        super().__init__()
        self.filepath = filepath
        self.save_freq = save_freq

    def on_epoch_end(self, epoch, logs=None):
        """Save the weights when the number of completed epochs is a multiple of `save_freq`."""
        completed_epochs = epoch + 1  # `epoch` is used as a zero-based index
        if completed_epochs % self.save_freq != 0:
            return
        self.model.save_weights(self.filepath.format(epoch=completed_epochs))
            
# Create an Adam optimizer with a specified initial learning rate
optimizer = Adam(learning_rate=0.001)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

# Define paths and initial settings
directory = 'saved_data/'
os.makedirs(directory, exist_ok=True)

# Setup callbacks
# The learning-rate scheduler uses a shorter patience than early stopping: with equal
# patience both fire after the same number of stagnant epochs, so training would
# typically stop before a reduced learning rate could have any effect.
early_stopping_monitor = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True, verbose=1)
reduce_lr_on_plateau = ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=5, min_lr=0.0000001, verbose=1)
weights_saver = SaveWeightsCallback(save_freq=50, filepath='saved_data/custom-cnn_weights_epoch_{epoch}.weights.h5')

# Segment-based training setup
num_epochs_per_stage = 50   # epochs per call to model.fit
total_epochs = 500          # upper bound on epochs across all stages
current_epoch = 0           # epochs completed so far
all_history = []            # one history dict per completed stage

# Train in stages of `num_epochs_per_stage` epochs until `total_epochs` is reached,
# early stopping fires, or an error occurs.
while current_epoch < total_epochs:
    try:
        # Load the latest model weights if available
        # (a failure here is reported and training continues with the weights already in memory)
        try:
            latest_weights_file = load_latest_weights('saved_data', '.weights.h5')
            if latest_weights_file:
                model.load_weights(latest_weights_file)
        except Exception as e:
            print("Loading weights failed:", e)
        
        # Train the model for a stage
        # `epochs` is the absolute epoch index to stop at, so it is offset by the epochs already done
        history = model.fit(
            x_train,
            y_train_onehot, 
            epochs=current_epoch + num_epochs_per_stage,
            batch_size=2048,
            validation_data=(x_val, y_val_onehot),
            callbacks=[weights_saver, early_stopping_monitor, reduce_lr_on_plateau],
            initial_epoch=current_epoch,
            verbose=1  
        )
        
        # Append segment history to the total history
        all_history.append(history.history)
        
        # Update the current epoch count
        # (uses the number of epochs actually run, which may be fewer than a full stage)
        current_epoch += len(history.history['loss'])
        
        # Optionally perform garbage collection
        gc.collect()

        # Check if early stopping was triggered
        if early_stopping_monitor.stopped_epoch > 0:
            print(f"Early stopping triggered at epoch {current_epoch}")
            break

    except Exception as e:
        # NOTE(review): any training error is only printed and ends the loop; the cell then
        # goes on to save the model — confirm that swallowing the exception is intended.
        print("An error occurred during training:", e)
        break

# Save final model
model.save('saved_data/custom-cnn_final_model.keras')

# Merge the per-stage histories into one dictionary (metric name -> values over all epochs)
# and persist it; when no stage completed there is nothing to merge.
if not all_history:
    print("No training history was recorded.")
else:
    final_history = {
        metric: np.concatenate([stage[metric] for stage in all_history])
        for metric in all_history[0]
    }
    # Save the final training history
    dump(final_history, 'saved_data/custom-cnn_training_history.joblib')

No weights file found.
Epoch 1/50
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 118ms/step - accuracy: 0.1267 - loss: 3.9921 - val_accuracy: 0.1548 - val_loss: 2.9878 - learning_rate: 0.0010
Epoch 2/50
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 60ms/step - accuracy: 0.2021 - loss: 2.8622 - val_accuracy: 0.2070 - val_loss: 2.6728 - learning_rate: 0.0010
Epoch 3/50
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 62ms/step - accuracy: 0.2645 - loss: 2.4975 - val_accuracy: 0.2869 - val_loss: 2.4051 - learning_rate: 0.0010
Epoch 4/50
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 64ms/step - accuracy: 0.3177 - loss: 2.2038 - val_accuracy: 0.3534 - val_loss: 2.1677 - learning_rate: 0.0010
Epoch 5/50
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 58ms/step - accuracy: 0.3741 - loss: 1.9707 - val_accuracy: 0.4439 - val_loss: 1.9437 - learning_rate: 0.0010
Epoch 6/50
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━

Epoch 45/50
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 58ms/step - accuracy: 0.6641 - loss: 0.9583 - val_accuracy: 0.6773 - val_loss: 0.9079 - learning_rate: 0.0010
Epoch 46/50
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 56ms/step - accuracy: 0.6642 - loss: 0.9516 - val_accuracy: 0.6812 - val_loss: 0.8929 - learning_rate: 0.0010
Epoch 47/50
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 56ms/step - accuracy: 0.6620 - loss: 0.9529 - val_accuracy: 0.6812 - val_loss: 0.8884 - learning_rate: 0.0010
Epoch 48/50
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 58ms/step - accuracy: 0.6662 - loss: 0.9429 - val_accuracy: 0.6762 - val_loss: 0.9027 - learning_rate: 0.0010
Epoch 49/50
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 56ms/step - accuracy: 0.6650 - loss: 0.9521 - val_accuracy: 0.6793 - val_loss: 0.8923 - learning_rate: 0.0010
Epoch 50/50
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2

In [None]:
# Define paths and initial settings
directory = 'saved_data/'
os.makedirs(directory, exist_ok=True)

# Save the aggregated history
# `history` only holds the last training stage, so the per-stage dictionaries collected in
# `all_history` are merged instead. The result is wrapped in an object exposing a `.history`
# attribute, so code that reads `history.history[...]` after loading keeps working.
if all_history:
    aggregated_history = {
        key: np.concatenate([seg[key] for seg in all_history]).tolist()
        for key in all_history[0]
    }
    dump(SimpleNamespace(history=aggregated_history), os.path.join(directory, 'model_training_history.joblib'))
else:
    print("No training history was recorded.")

# Save the model
model.save(os.path.join(directory, 'model_full.keras'))

### Model Testing and Statistics Plotting

In [None]:
# Reload the persisted training history and model
history = load('saved_data/model_training_history.joblib')
model = load_model('saved_data/model_full.keras')

# Predicted class index per test sample (arg-max over the softmax outputs)
class_probabilities = model.predict(x_test)
y_pred = np.argmax(class_probabilities, axis=1)

# String form of the predictions, and integer form of the ground truth
y_pred_labels = label_encoder.inverse_transform(y_pred)
y_test_encoded = label_encoder.transform(y_test)  # Only if y_test is in string format

# Confusion matrix is computed on integer labels
conf_matrix = confusion_matrix(y_test_encoded, y_pred)

# Text report uses string labels
print("Number of epochs trained:", len(history.history['loss']))
print(classification_report(y_test, y_pred_labels))

# Confusion matrix heatmap, ticks labelled with class names for readability
fig, ax = plt.subplots(figsize=(8, 8))
sns.heatmap(
    conf_matrix,
    annot=True,
    fmt='d',
    cmap='Greens',
    ax=ax,
    xticklabels=label_encoder.classes_,
    yticklabels=label_encoder.classes_,
)
ax.set_xlabel('Predicted Labels')
ax.set_ylabel('True Labels')
ax.set_title('Confusion Matrix')
plt.show()

# Loss and accuracy curves, side by side: (panel, train key, val key, train legend, val legend, title, y label)
plt.figure(figsize=(12, 6))
_curve_specs = (
    (1, 'loss', 'val_loss', 'Training Loss', 'Validation Loss', 'Loss Per Epoch', 'Loss'),
    (2, 'accuracy', 'val_accuracy', 'Training Accuracy', 'Validation Accuracy', 'Accuracy Per Epoch', 'Accuracy'),
)
for _panel, _train_key, _val_key, _train_legend, _val_legend, _title, _ylabel in _curve_specs:
    plt.subplot(1, 2, _panel)
    plt.plot(history.history[_train_key], label=_train_legend)
    plt.plot(history.history[_val_key], label=_val_legend)
    plt.title(_title)
    plt.xlabel('Epoch')
    plt.ylabel(_ylabel)
    plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# Custom Samples
def process_wav_file(wav_file_path, max_length=332):
    """
    Load an audio file and return its MFCC matrix with exactly `max_length` columns.

    Clips that yield fewer columns are zero-padded on the right; clips that yield more
    are truncated, so the result always has the same number of columns.

    Parameters:
    - wav_file_path: str, path to the audio file.
    - max_length: int, number of columns in the returned matrix.

    Returns:
    - ndarray with 13 rows (n_mfcc=13) and `max_length` columns.
    """
    # Load the audio file
    signal, sr = librosa.load(wav_file_path, sr=16000)  # Ensure sample rate is 16000 Hz
    # Compute MFCC features
    mfccs = librosa.feature.mfcc(y=signal, sr=sr, n_mfcc=13, n_fft=256, hop_length=160, n_mels=32, fmin=0, fmax=8000)
    # Truncate over-long clips; previously they were returned with too many columns
    mfccs = mfccs[:, :max_length]
    # Calculate padding width
    pad_width = max_length - mfccs.shape[1]
    if pad_width > 0:  # Apply padding if needed
        mfccs = np.pad(mfccs, pad_width=((0, 0), (0, pad_width)), mode='constant')
    return mfccs

def predict_wav_file(wav_file_path, model, label_encoder):
    """
    Run the model on one audio file and tabulate the per-label scores.

    Parameters:
    - wav_file_path: str, path to the audio file.
    - model: object whose `predict` takes a batch and returns one row of scores per item.
    - label_encoder: object whose `classes_` lists the labels in output order.

    Returns:
    - DataFrame indexed by label with a single 'Probability' column.
    """
    # Padded MFCC features for the file, with a leading batch axis of size 1
    features = process_wav_file(wav_file_path)
    batch = np.expand_dims(features, axis=0)

    # Scores of the only item in the batch
    class_scores = model.predict(batch)[0]

    return pd.DataFrame(class_scores, index=label_encoder.classes_, columns=['Probability'])

model = load_model('saved_data/model_full.keras')  # Load pre-trained model

# Rebuild a label encoder from a hand-written label list.
# NOTE(review): 'noise' appears twice, so fitting yields 10 distinct classes — possibly one
# label is missing; confirm against the labels the model was trained on. The class order must
# match the encoder used for training, otherwise probabilities are attributed to the wrong labels.
all_labels = ['battery', 'description', 'environment', 'greeting', 'health', 'noise', 'noise', 'nutrition', 'silence', 'sun', 'water']  
label_encoder = LabelEncoder()
label_encoder.fit(all_labels)

# Path to the WAV file 
# NOTE(review): the file below is an .mp3 at an absolute, machine-specific path.
wav_file_path = '/Users/ciprian/Desktop/Projects/Smart Plant Pot/Audio/Voice Recognition/Prototype 3/_testing_samples_bianca/what can you do.mp3'  

# Get the prediction DataFrame
probabilities_df = predict_wav_file(wav_file_path, model, label_encoder)

# Find the highest probability
max_label = probabilities_df['Probability'].idxmax()
highest_probability = probabilities_df['Probability'].max()
print(f"The highest probability is {highest_probability}, associated with label '{max_label}'")

# Show the full probability table as the cell output
probabilities_df

### Conversion of model to TensorFlow Lite

In [None]:
# quantization of the model and conversion to tensorflow lite
import numpy as np

# Load the model
# NOTE(review): no cell in this notebook writes 'my_model.h5' (the trained model is saved
# under 'saved_data/') — confirm this is the model that should be converted.
model = tf.keras.models.load_model('my_model.h5')

# Per-sample input shape taken from the model itself, instead of a hard-coded shape
# that can silently disagree with the model being converted.
_calibration_shape = tuple(model.input_shape[1:])


def representative_dataset():
    """
    Yield up to 100 single-sample float32 batches used to calibrate the quantizer.

    Real training samples from `x_train` are used when that array exists and matches
    the model's input shape, because calibration ranges estimated from random noise
    do not reflect real inputs. Otherwise random data of the correct shape is used.
    """
    calibration_pool = globals().get('x_train')
    if calibration_pool is not None and tuple(np.shape(calibration_pool)[1:]) == _calibration_shape:
        rng = np.random.default_rng(42)
        num_samples = min(100, len(calibration_pool))
        for i in rng.choice(len(calibration_pool), size=num_samples, replace=False):
            yield [np.asarray(calibration_pool[i:i + 1], dtype=np.float32)]
    else:
        for _ in range(100):
            yield [np.random.rand(1, *_calibration_shape).astype(np.float32)]


# Convert the model (full-integer quantization with int8 input and output)
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8  # or tf.uint8
converter.inference_output_type = tf.int8  # or tf.uint8
tflite_model = converter.convert()

# Save the model
with open('my_model_quantized.tflite', 'wb') as f:
    f.write(tflite_model)

In [None]:
# verify the model
interpreter = tf.lite.Interpreter(model_path='my_model_quantized.tflite')
interpreter.allocate_tensors()

# Get input and output tensors
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Test the model on random input data
# Casting uniform [0, 1) floats to int8 truncates every value to 0, so the random input is
# drawn directly in the input tensor's own dtype and value range instead.
input_shape = input_details[0]['shape']
input_dtype = input_details[0]['dtype']
if np.issubdtype(input_dtype, np.integer):
    dtype_info = np.iinfo(input_dtype)
    input_data = np.random.randint(dtype_info.min, int(dtype_info.max) + 1, size=input_shape, dtype=input_dtype)
else:
    input_data = np.random.random_sample(input_shape).astype(input_dtype)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()

# Get the results
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)