In [12]:
import tensorflow as tf
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Conv2D, Input, Lambda, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import tensorflow_model_optimization as tfmot
import cv2
import numpy as np
import os
import shutil
import zipfile
from pycocotools.coco import COCO
import matplotlib.pyplot as plt
import requests
from tensorflow_model_optimization.sparsity import keras as sparsity
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer, prune_registry

In [2]:
# Define the COCO dataset URLs
coco_images_url = 'http://images.cocodataset.org/zips/train2017.zip'
coco_annotations_url = 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip'

# Define download paths
data_dir = './coco2017/'
images_zip_path = os.path.join(data_dir, 'train2017.zip')
annotations_zip_path = os.path.join(data_dir, 'annotations_trainval2017.zip')

# Create directory to store dataset
os.makedirs(data_dir, exist_ok=True)


def _download_and_extract(url, zip_path, extract_dir, timeout=60, chunk_size=1 << 20):
    """Download a zip archive from ``url`` to ``zip_path`` and extract it.

    The archive is streamed into ``zip_path + '.part'`` and only renamed to
    ``zip_path`` once the download finished, so an interrupted download never
    leaves a truncated file under the final name.

    Args:
        url: HTTP(S) location of the zip archive.
        zip_path: where the downloaded archive is stored.
        extract_dir: directory the archive is extracted into.
        timeout: value passed to ``requests.get`` as ``timeout``.
        chunk_size: number of bytes requested per streamed chunk.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        zipfile.BadZipFile: if the downloaded file is not a valid zip archive.
    """
    partial_path = zip_path + '.part'
    try:
        # The context manager releases the connection even if writing fails.
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Fail here instead of saving an HTML error page as a ".zip".
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        os.replace(partial_path, zip_path)
    finally:
        # Only present when the download did not complete.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)


# Download images (skipped when the extracted folder already exists)
if not os.path.exists(os.path.join(data_dir, 'train2017')):
    print("Downloading COCO train images...")
    _download_and_extract(coco_images_url, images_zip_path, data_dir)
    print("COCO train images downloaded and extracted.")

# Download annotations (skipped when the extracted folder already exists)
if not os.path.exists(os.path.join(data_dir, 'annotations')):
    print("Downloading COCO annotations...")
    _download_and_extract(coco_annotations_url, annotations_zip_path, data_dir)
    print("COCO annotations downloaded and extracted.")

In [3]:
# Initialize COCO API from the train2017 instance annotations
annotation_file = os.path.join(data_dir, 'annotations/instances_train2017.json')
coco = COCO(annotation_file)

# Get the category ID for "person"; fail with a clear message if it is missing
person_category_ids = coco.getCatIds(catNms=['person'])
if not person_category_ids:
    raise ValueError(f"No 'person' category found in {annotation_file}")
person_category_id = person_category_ids[0]

# Get all images containing people (positive samples).
# Sorted so that a `[:N]` slice of this list selects the same images on every run.
person_image_ids = sorted(coco.getImgIds(catIds=[person_category_id]))

# Get the images without people (negative samples).
# The set difference has no defined order, so sort it for a stable subset as well.
all_image_ids = coco.getImgIds()
non_person_image_ids = sorted(set(all_image_ids) - set(person_image_ids))

# Load one COCO training image as a resized, single-channel (grayscale) array
def load_and_preprocess_image(coco, img_id, data_dir, img_size=(224, 224)):
    """Read a COCO training image from disk as grayscale and resize it.

    Args:
        coco: COCO API object; only ``loadImgs`` is used, to look up the file name.
        img_id: id of the image to load.
        data_dir: directory that contains the ``train2017`` image folder.
        img_size: target size handed to ``cv2.resize``.

    Returns:
        The resized grayscale image with one trailing channel axis of length 1
        appended (the image is NOT replicated to 3 channels here).

    Raises:
        FileNotFoundError: if OpenCV could not read the image file.
    """
    img_info = coco.loadImgs(img_id)[0]
    img_path = os.path.join(data_dir, 'train2017', img_info['file_name'])

    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)  # Read as grayscale
    if img is None:
        # cv2.imread signals a missing/unreadable file by returning None rather
        # than raising; fail here with the path instead of inside cv2.resize.
        raise FileNotFoundError(f"Could not read image: {img_path}")

    img = cv2.resize(img, img_size)  # Resize to match MobileNetV2 input size
    return np.expand_dims(img, axis=-1)  # Add channel dimension for grayscale image


# Number of images loaded per class
SAMPLES_PER_CLASS = 2000
# Fixed seed so the shuffle below gives the same order on every run
SHUFFLE_SEED = 42

# Load person images (positives) and non-person images (negatives)
person_images = [
    load_and_preprocess_image(coco, img_id, data_dir)
    for img_id in person_image_ids[:SAMPLES_PER_CLASS]
]
non_person_images = [
    load_and_preprocess_image(coco, img_id, data_dir)
    for img_id in non_person_image_ids[:SAMPLES_PER_CLASS]
]

# Create labels (1 for person, 0 for non-person), in the same order as the images below
labels = np.array([1] * len(person_images) + [0] * len(non_person_images))

# Combine images and shuffle images and labels with the same permutation
images = np.array(person_images + non_person_images)
indices = np.random.default_rng(SHUFFLE_SEED).permutation(len(images))
images, labels = images[indices], labels[indices]

loading annotations into memory...
Done (t=9.94s)
creating index...
index created!


Prune first

In [19]:
# Pruning schedule: polynomial ramp from 0% to 50% sparsity between steps 0 and 1000
pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0, final_sparsity=0.5, begin_step=0, end_step=1000)
}

# Define the input layer: 224x224 images with a single (grayscale) channel
input_layer = Input(shape=(224, 224, 1))  # Grayscale input

# 3x3 convolution with 3 output channels, so the tensor matches the
# 3-channel input shape MobileNetV2 is built with below
x = Conv2D(3, (3, 3), padding='same', activation='relu')(input_layer)

# Load pre-trained MobileNetV2 (ImageNet weights, without the classification head)
base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze the base model layers
base_model.trainable = False

# Pass the input through MobileNetV2
x = base_model(x)

# Add a Dropout layer (applied to the feature map, before pooling)
x = Dropout(0.5)(x)

# Add Global Average Pooling layer
x = GlobalAveragePooling2D()(x)

# Final Dense layer for binary classification
output = Dense(1, activation='sigmoid')(x)

# Build the model
model = Model(inputs=input_layer, outputs=output)

# Apply pruning to the entire functional model
# NOTE(review): the recorded run of this cell raised a ValueError here saying an
# object of type `Functional` cannot be pruned. Presumably the model was built
# with Keras 3 while tfmot expects Keras 2 (tf_keras) objects; setting
# TF_USE_LEGACY_KERAS=1 before TensorFlow is imported is the likely fix - confirm.
pruned_model = tfmot.sparsity.keras.prune_low_magnitude(model, **pruning_params)

# Compile the pruned model
pruned_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Summary of the pruned model to verify the structure
pruned_model.summary()

ValueError: `prune_low_magnitude` can only prune an object of the following types: keras.models.Sequential, keras functional model, keras.layers.Layer, list of keras.layers.Layer. You passed an object of type: Functional.

In [13]:
# Clone function: wrap prunable layers for pruning, pass every other layer through
def apply_pruning_to_prunable_layers(layer):
    """Return ``layer`` wrapped for magnitude pruning when it is prunable.

    A layer counts as prunable when it is a ``PrunableLayer`` instance, exposes
    a ``get_prunable_weights`` attribute, or is supported by the prune registry.
    Any other layer is reported on stdout and returned unchanged.
    """
    is_prunable = (
        isinstance(layer, prunable_layer.PrunableLayer)
        or hasattr(layer, 'get_prunable_weights')
        or prune_registry.PruneRegistry.supports(layer)
    )
    if not is_prunable:
        print("Not Prunable: ", layer)
        return layer
    return sparsity.prune_low_magnitude(layer)

# Define the input layer: 224x224 images with a single (grayscale) channel
input_layer = Input(shape=(224, 224, 1))

# 3x3 convolution with 3 output channels, so the tensor matches the
# 3-channel input shape MobileNetV2 is built with below
x = Conv2D(3, (3, 3), padding='same', activation='relu')(input_layer)

# Load pre-trained MobileNetV2 (ImageNet weights, without the classification head).
# Kept under its own name so it is not confused with the full model built below.
backbone = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))

# Freeze the backbone to prevent retraining
backbone.trainable = False

# Connect the backbone to the converted input
base_model_output = backbone(x)

# Add custom layers
x = GlobalAveragePooling2D()(base_model_output)
x = Dropout(0.5)(x)  # Dropout for regularization

# Dense layer for added complexity
x = Dense(64, activation='relu')(x)

# Final Dense layer for binary classification (person vs non-person)
output = Dense(1, activation='sigmoid')(x)

# Create the full (unpruned) model
base_model = Model(inputs=input_layer, outputs=output)

# Clone the model and apply pruning only to prunable layers
# NOTE(review): the recorded output of this cell lists every layer, including
# both Dense layers, as "Not Prunable". Presumably the layers are Keras 3
# objects that tfmot (built against Keras 2 / tf_keras) does not recognise;
# setting TF_USE_LEGACY_KERAS=1 before TensorFlow is imported is the likely
# fix - confirm.
model_for_pruning = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_pruning_to_prunable_layers,
)

# Compile the pruned model
model_for_pruning.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Summary of the pruned model to verify everything
model_for_pruning.summary()

Not Prunable:  <Conv2D name=conv2d, built=True>
Not Prunable:  <Functional name=mobilenetv2_1.00_224, built=True>
Not Prunable:  <GlobalAveragePooling2D name=global_average_pooling2d_5, built=True>
Not Prunable:  <Dropout name=dropout_5, built=True>
Not Prunable:  <Dense name=dense_5, built=True>
Not Prunable:  <Dense name=dense_6, built=True>
