In [1]:
import numpy as np
from copy import deepcopy
import tensorflow as tf
import tensorflow.keras as keras
import matplotlib.pyplot as plt
from tensorflow.keras import layers,regularizers,metrics,optimizers
import random
import pandas as pd
from scipy.linalg import sqrtm
import pickle
from typing import Any, Callable, Dict, List, Optional, Union
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
import math
import scipy.stats as st
from scipy.special import comb
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from collections import defaultdict
import itertools
import json
from collections import deque

In [2]:
from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
config=tf.compat.v1.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.8
config.gpu_options.allow_growth=True
sess=tf.compat.v1.Session(config=config) 
import math

In [3]:
"""This algorithm is used to evaluate the structural redundancy of ResNet-18
and outputs the evaluation criteria of hidden layer redundancy as well as 
the entire redundancy evaluation criteria under each pruning parameter. 
Here, "Lam" refers to the pruning parameter set used in the evaluation 
algorithm, and "repeats" represents the number of times the pruning network 
is repeatedly fine-tuned."""
Lam=[1.0,0.9,0.8,0.7]
repeats=3

In [4]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train = x_train/255
x_test = x_test/255
y_train_onehot=tf.keras.utils.to_categorical(y_train,num_classes=10)
y_test_onehot=tf.keras.utils.to_categorical(y_test,num_classes=10)

In [5]:
with open('data_dist_ResNet_18.pkl', 'rb') as f:
    [x_dist,y_dist]=pickle.load(f)
x_dist=x_dist.numpy()
y_dist=y_dist.numpy().reshape(len(y_dist),1)

In [6]:
initial_lr = 0.1
weight_decay = 1e-4
epochs = 200
warmup_epochs = 5
batch_size = 128
image_size = 32

In [7]:
class WarmUpCosine(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, base_lr, total_steps, warmup_steps, warmup_lr=0.0):
        super().__init__()
        self.base_lr = base_lr
        self.total_steps = total_steps
        self.warmup_steps = warmup_steps
        self.warmup_lr = warmup_lr
    def __call__(self, step):
        if step is None:
            step = tf.constant(0)
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        total_steps = tf.cast(self.total_steps, tf.float32)
        warmup_percent_done = step / warmup_steps
        learning_rate = tf.where(
            step < warmup_steps,
            self.warmup_lr + (self.base_lr - self.warmup_lr) * warmup_percent_done,
            self.base_lr * 0.5 * (1.0 + tf.cos(math.pi * (step - warmup_steps) / (total_steps - warmup_steps)))
        )
        return learning_rate
    def get_config(self):
        return {
            "base_lr": self.base_lr,
            "total_steps": self.total_steps,
            "warmup_steps": self.warmup_steps,
            "warmup_lr": self.warmup_lr,
        }

In [8]:
class CustomWeightDecaySGD(tf.keras.optimizers.SGD):
    def __init__(self, weight_decay, **kwargs):
        super().__init__(**kwargs)
        self.weight_decay = weight_decay
    def apply_gradients(self, grads_and_vars, name=None, experimental_aggregate_gradients=True):
        super().apply_gradients(grads_and_vars, name, experimental_aggregate_gradients)
        for grad, var in grads_and_vars:
            if ('kernel' in var.name) and ('bn' not in var.name.lower()):
                var.assign_sub(self.weight_decay * var)
    def get_config(self):
        config = super().get_config()
        config.update({
            "weight_decay": float(self.weight_decay),  # 确保是float
        })
        return config

In [9]:
class LastNSaver(tf.keras.callbacks.Callback):
    def __init__(self, n=10):
        super().__init__()
        self.n = n
        self.history = deque(maxlen=n)  

    def on_epoch_end(self, epoch, logs=None):
        val_acc = logs.get("val_accuracy")
        if val_acc is not None:
            weights = self.model.get_weights()
            self.history.append((val_acc, weights))

    def on_train_end(self, logs=None):
        if not self.history:
            return
        best_acc, best_weights = max(self.history, key=lambda x: x[0])
        print(f" Using best val_acc={best_acc:.4f} from last {self.n} epochs")
        self.model.set_weights(best_weights) 

In [10]:
def load_Res():
    model = tf.keras.models.load_model('Res18_cifar10.h5',custom_objects={
        'CustomWeightDecaySGD': CustomWeightDecaySGD,
        'WarmUpCosine': WarmUpCosine
    })
    return model

In [11]:
def conv_bn_relu(x, filters, kernel_size, strides=1):
    x = tf.keras.layers.Conv2D(filters, kernel_size, strides=strides, padding='same',use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    return tf.keras.layers.ReLU()(x)

def residual_block(x, filter1, filter2, downsample=False):
    shortcut = x
    strides = 2 if downsample else 1
    x = conv_bn_relu(x, filter1, 3, strides)
    x = tf.keras.layers.Conv2D(filter2, 3, strides=1, padding='same',use_bias=False)(x)
    x = tf.keras.layers.BatchNormalization()(x)
    if downsample:
        shortcut = tf.keras.layers.Conv2D(filter2, 1, strides=strides, padding='same',use_bias=False)(shortcut)
        shortcut = tf.keras.layers.BatchNormalization()(shortcut)
    x = tf.keras.layers.add([x, shortcut])
    return tf.keras.layers.ReLU()(x)

def Res_model(NN,input_shape=(32,32,3), num_classes=10):
    inputs = tf.keras.Input(shape=input_shape)
    x = conv_bn_relu(inputs, NN[2], 3)
    #x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
    x = residual_block(x, NN[0], NN[2])
    x = residual_block(x, NN[1], NN[2])
    x = residual_block(x, NN[3], NN[5], downsample=True)
    x = residual_block(x, NN[4], NN[5])
    x = residual_block(x, NN[6], NN[8], downsample=True)
    x = residual_block(x, NN[7], NN[8])
    x = residual_block(x, NN[9], NN[11], downsample=True)
    x = residual_block(x, NN[10],NN[11])
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    outputs = tf.keras.layers.Dense(num_classes,activation='softmax')(x)
    return tf.keras.Model(inputs, outputs)

In [12]:
model=load_Res()

In [13]:
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 32, 32, 3)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 32, 32, 64)   1728        ['input_1[0][0]']                
                                                                                                  
 batch_normalization (BatchNorm  (None, 32, 32, 64)  256         ['conv2d[0][0]']                 
 alization)                                                                                       
                                                                                                  
 re_lu (ReLU)                   (None, 32, 32, 64)   0           ['batch_normalization[0][0]']

In [14]:
def JW(m, M):
    """
    Compute the binary MD-LP J_w_value.

    Args:
        m: 1-D tensor of shape [d], the mean of the Minkowski difference of a 
           binary classification dataset.
        M: 2-D tensor of shape [d, N], binary classification dataset Minkowski 
           difference set.
           
    Key idea:
        - Calculate the approximate solution m_weighted for the optimal weights 
          in the MD-LP.
        - Calculate the MD-LP based on the approximately optimal weights, and 
          perform a left truncation at 0.5. 
    Returns:
        Binary MD-LP value.
    """
    row_norm_sq = tf.reduce_sum(tf.square(M), axis=1)  
    reciprocal_norm = tf.where(row_norm_sq != 0,
                               tf.math.reciprocal(row_norm_sq),
                               tf.zeros_like(row_norm_sq))  
    m_weighted = m * reciprocal_norm  
    m_weighted = tf.reshape(m_weighted, [1, -1])  
    mM = tf.matmul(m_weighted, M)
    L1 = tf.reduce_sum(mM)
    L_1 = tf.reduce_sum(tf.abs(mM))
    J_w_value = tf.abs(L1) / (L_1 + 1e-8)
    J_w_value = tf.maximum(J_w_value, 0.5)
    return J_w_value
def W(X, Y, k, n_c=10):
    """
    This function is used to calculate the top k largest binary classification 
    problems MD-LP used in the multi-classification problem calculation. Here, 
    the binary classification problems are obtained by combining each pair of 
    categories of the multi-classification problem.
    Args:
        X: Tensor/array of shape [b, l, w]. Channel output.
        Y: Tensor/array of labels of shape [b]. Data labels.
        k: Number of the largest binary MD-LP to keep.
        n_c: Number of classes.
    Returns:
        JK_list: Tensor of shape [k], the top-k MD-LP.
    """
    b, l, w = X.shape
    X = tf.reshape(X, [b, l*w])   # flatten
    J_list = []
    for i, j in itertools.combinations(range(n_c), 2):
        mask_1 = tf.reshape(tf.equal(Y, i), [-1])
        mask_2 = tf.reshape(tf.equal(Y, j), [-1])
        X1 = tf.boolean_mask(X, mask_1)
        X2 = tf.boolean_mask(X, mask_2)
        n1 = tf.shape(X1)[0]
        n2 = tf.shape(X2)[0]
        m_i = tf.reduce_sum(X1, axis=0) * tf.cast(n2, tf.float32) - tf.reduce_sum(X2, axis=0) * tf.cast(n1, tf.float32)
        m_i = m_i / tf.linalg.norm(m_i + 1e-8)
        M_i = tf.reshape(X1[:, None, :] - X2[None, :, :], [-1, l*w])
        M_i = tf.transpose(M_i)
        J = JW(m_i, M_i)
        J_list.append(J)
    J_list = tf.stack(J_list)
    JK_list , JK_inde = tf.math.top_k(J_list,k)
    return JK_list

In [15]:
def L1_channel(x_L,y, nnn, alpha=2.5):
    """
    This function computes TCR measure of each channel in convolutional layer.
    
    Given the output of a convolutional layer, this function will execute:
    - Treating each channel independently and computing a multi-class MD-LP 
      via function W;
    - By applying nonlinear transformation, a TCR measure is constructed 
      to enhance the separability of MD-LP.
    
    Key Args:
    x_L (Tensor):
        Output of a convolutional hidden layer, with shape 
        [batch_size, height, width, channels].
    y (Tensor):
        Ground-truth labels corresponding to the input samples.
    alpha (float, optional):
        LP transformation parameter. Used to enhance the separability 
        of the MD-LP close to 1.
    
    Returns:
    jw:
        TCR measure of each channel.
    """
    a, b, d, c = x_L.shape
    jw = tf.zeros([c], dtype=tf.float32)
    alpha = tf.cast(alpha, tf.float32)
    for j in tf.range(c):
        N_tf = W(x_L[:,:,:,j], y, nnn)
        jw_j = tf.norm(N_tf) / tf.sqrt(float(nnn))
        jw_j = (tf.exp(alpha * (2*jw_j-1)) - 1.0) / (tf.exp(alpha) - 1.0)
        jw = tf.tensor_scatter_nd_update(jw, [[j]], [jw_j])
    return jw

In [16]:
def prune_channel(x_LG, y, prune_rate, nnn=45, esp=1e-8):
    """
    This function computes the structural redundancy evaluation criterion R_L 
    and determines the set of retained channel indices `channel_i_label` used 
    by the pruning algorithm for a layer group.
    
    Given the output of a convolutional layer, this function will execute:
    - By analyzing the propensity calculation of TCR measure, an evaluation 
      criterion for evaluating the redundancy of convolutional layers is derived.
    - Based on the TCR measure, the pruning threshold is calculated and the 
      channels that remain after pruning are selected.
    
    Key Args:
    x_LG (a list of Tensors):
        Output of a group of convolutional hidden layers.
    y (Tensor):
        Ground-truth labels corresponding to the input samples.
    prune_rate (float):
        Pruning parameter. Used to control the strictness of pruning.
    
    Returns:
    channel_i_label (ndarray):
        Indices of channels retained after pruning.
    R_L (float):
        Structural redundancy evaluation criterion of the layer group,
    """
    a, b, d, c = x_LG[-1].shape
    L1_list = []
    for i in range(len(x_LG)):
        L1_i = L1_channel(x_LG[i], y, nnn)
        L1_list.append(L1_i)
    L1 = tf.stack(L1_list, axis=0)  
    jw = tf.reduce_mean(L1, axis=0) 
    jw_min = tf.maximum(tf.reduce_min(jw) - esp, 0.0)
    jw_max = tf.reduce_max(jw)
    me = tf.sqrt(tf.reduce_mean(tf.square(jw - jw_min)))
    jd = jw_min + prune_rate * me
    mean = tf.maximum(tf.reduce_mean(jw) - esp, 0.0)
    R_L = tf.reduce_mean(tf.sign(jw - mean))
    channel_i_label = tf.where(jw >= jd)[:,0]
    return channel_i_label.numpy(), R_L.numpy()

In [17]:
def Group_1(x_L,w,S):
    """This function is used to collect the key hidden layer outputs of the first block 
    in ResNet. Here, x_L1 is used to update the list of output from the ResNet 
    convolutional layers, and x_L2 is used to provide the hidden layer outputs required 
    for layer group pruning."""
    wb=w[0]
    bn=w[1]
    x_L1=[]
    x_L2=[]
    x_1=layer_xL("conv2d",x_L,w=wb[0])
    x_1=layer_xL("batch_normalization",x_1,w=bn[0])
    x_1=layer_xL("activation",x_1)
    x_L2.append(deepcopy(x_1))
    #x_1=layer_xL("maxpooling",x_1)
    x_L1.append(deepcopy(x_1))

    x_2=layer_xL("conv2d",x_1,w=wb[1])
    x_2=layer_xL("batch_normalization",x_2,w=bn[1])
    x_2=layer_xL("activation",x_2)
    x_L1.append(deepcopy(x_2))
    
    x_2=layer_xL("conv2d",x_2,w=wb[2])
    x_2=layer_xL("batch_normalization",x_2,w=bn[2])
    x_1=layer_xL("add",[x_2,x_1])
    x_1=layer_xL("activation",x_1)
    x_L1.append(deepcopy(x_1))
    x_L2.append(deepcopy(x_1))

    x_2=layer_xL("conv2d",x_1,w=wb[3])
    x_2=layer_xL("batch_normalization",x_2,w=bn[3])
    x_2=layer_xL("activation",x_2)
    x_L1.append(deepcopy(x_2))
    
    x_2=layer_xL("conv2d",x_2,w=wb[4])
    x_2=layer_xL("batch_normalization",x_2,w=bn[4])
    x_1=layer_xL("add",[x_2,x_1])
    x_1=layer_xL("activation",x_1)
    x_L1.append(deepcopy(x_1))
    x_L2.append(deepcopy(x_1))
    return x_L1,x_L2

In [18]:
def Group_L(x_L,w,S):
    """This function is used to collect the key hidden layer outputs of the other block 
    in ResNet. Here, x_L1 is used to update the list of output from the ResNet 
    convolutional layers, and x_L2 is used to provide the hidden layer outputs required 
    for layer group pruning."""
    wb=w[0]
    bn=w[1]
    x_L1=[]
    x_L2=[]
    x_1=layer_xL("conv2d",x_L,w=wb[0],S=2)
    x_1=layer_xL("batch_normalization",x_1,w=bn[0])
    x_1=layer_xL("activation",x_1)
    x_L1.append(deepcopy(x_1))
            
    x_1=layer_xL("conv2d",x_1,w=wb[1])
    x_2=layer_xL("conv2d",x_L,w=wb[2],S=2)
    x_1=layer_xL("batch_normalization",x_1,w=bn[1])
    x_2=layer_xL("batch_normalization",x_2,w=bn[2])
    x_1=layer_xL("add",[x_1,x_2])
    x_1=layer_xL("activation",x_1)
    x_L1.append(deepcopy(x_1))
    x_L2.append(deepcopy(x_1))
            
    x_2=layer_xL("conv2d",x_1,w=wb[3])
    x_2=layer_xL("batch_normalization",x_2,w=bn[3])
    x_2=layer_xL("activation",x_2)
    x_L1.append(deepcopy(x_2))
    
    x_2=layer_xL("conv2d",x_2,w=wb[4])
    x_2=layer_xL("batch_normalization",x_2,w=bn[4])
    x_1=layer_xL("add",[x_2,x_1])
    x_1=layer_xL("activation",x_1)
    x_L1.append(deepcopy(x_1))
    x_L2.append(deepcopy(x_1))
    return x_L1,x_L2

In [19]:
def layer_xL(layer_name,x_L,w=None,F=False,S=1):
    """This function is used to obtain the outputs of various hidden layers or layer groups in the 
    network, which is utilized for updating the outputs of hidden layers or for performing layer 
    group pruning calculations."""
    if "conv2d" in layer_name:
        weight=w[0]
        strides=[1,S,S,1]
        x_L1=tf.nn.conv2d(x_L,weight,strides=strides,padding="SAME")
        #x_L1=tf.nn.bias_add(x_L1,bias)
        return x_L1
    if "first" in layer_name:
        x_L1,x_L2=Group_1(x_L,w,S)
        if F==True:
            return x_L2
        else:
            return x_L1
    if "group" in layer_name:
        x_L1,x_L2=Group_L(x_L,w,S)
        if F==True:
            return x_L2
        else:
            return x_L1
    if "batch_normalization" in layer_name:
        gamma,beta,mean,var=w
        x_L1=tf.nn.batch_normalization(x_L,mean=mean,
                                          variance=var,
                                          offset=beta,
                                          scale=gamma,variance_epsilon=1e-5)
        return x_L1
    if "activation" in layer_name:
        x_L1=tf.nn.relu(x_L)
        return x_L1
    if "add" in layer_name:
        x1,x2=x_L
        x_L1=tf.math.add(x1,x2)
        return x_L1

In [20]:
def x_block(a,b,G,weight_list,x_LG,First=False,Group=False,R=False,F=False):
    """This function is used to update the list of hidden layer outputs 
    of ResNet after function pruning. For single-hidden-layer pruning, 
    it only needs to update the output of the convolutional layer within 
    the block. However, for layer group pruning, since it involves multiple 
    blocks, the output of all convolutional layers in these blocks needs 
    to be updated."""
    if Group==False:
        if R==True:
            x_1=x_LG[b]
            layer_l=G[a][0]
            wc=weight_list[layer_l]
            wb=weight_list[layer_l+1]
            x_2=layer_xL("conv2d",x_1,wc,S=2)
            x_2=layer_xL("batch_normalization",x_2,wb)
            x_2=layer_xL("activation",x_2)
            x_LG[b+1]=x_2
            layer_l=G[a+2][0]
            wc=weight_list[layer_l]
            wb=weight_list[layer_l+2]
            x_2=layer_xL("conv2d",x_2,wc)
            x_2=layer_xL("batch_normalization",x_2,wb)
            layer_l=G[a+2][1]
            wc=weight_list[layer_l]
            wb=weight_list[layer_l+2]
            x_1=layer_xL("conv2d",x_1,wc,S=2)
            x_1=layer_xL("batch_normalization",x_1,wb)
            x_1=layer_xL("add",[x_2,x_1])
            x_1=layer_xL("activation",x_1)
            x_LG[b+2]=x_1
            return x_LG
        if R==False:
            x_1=x_LG[b]
            layer_l=G[a][0]
            wc=weight_list[layer_l]
            wb=weight_list[layer_l+1]
            x_2=layer_xL("conv2d",x_1,wc)
            x_2=layer_xL("batch_normalization",x_2,wb)
            x_2=layer_xL("activation",x_2)
            x_LG[b+1]=x_2
            if First==True:
                layer_l=G[a+2][1]
            else:
                layer_l=G[a+1][2]
            wc=weight_list[layer_l]
            wb=weight_list[layer_l+1]
            x_2=layer_xL("conv2d",x_2,wc)
            x_2=layer_xL("batch_normalization",x_2,wb)
            x_1=layer_xL("add",[x_2,x_1])
            x_1=layer_xL("activation",x_1)
            x_LG[b+2]=x_1
            return x_LG
    if Group==True:
        if First==True:
            label=[G[a][0],G[a-2][0],G[a][1],G[a-1][0],G[a][2]]
            w1=[]
            w2=[]
            x_1=x_LG[b]
            for i in range(5):
                w1.append(weight_list[label[i]])
                w2.append(weight_list[label[i]+1])
            w=[w1,w2]
            x_1=layer_xL("first",x_1,w,F=False,S=1)
            for i in range(5):
                x_LG[b+i+1]=x_1[i]
            return x_LG
        else:
            label=[G[a-2][0],G[a][0],G[a][1],G[a-1][0],G[a][2]]
            w1=[]
            w2=[]
            x_1=x_LG[b]
            l=[1,2,2,1,1]
            for i in range(5):
                w1.append(weight_list[label[i]])
                w2.append(weight_list[label[i]+l[i]])
            w=[w1,w2]
            x_1=layer_xL("group",x_1,w)
            for i in range(4):
                x_LG[b+i+1]=x_1[i]
            return x_LG

In [21]:
def get_x(a,b,G,x_LG,weight_list,First=False):
    """This function is used to provide the required output list 
    for 'prune_channel' function when performing layer group pruning."""
    x_l=x_LG[b]
    if First==True:
        label=[G[a][0],G[a-2][0],G[a][1],G[a-1][0],G[a][2]]
    else:
        label=[G[a-2][0],G[a][0],G[a][1],G[a-1][0],G[a][2]]
    w1=[]
    w2=[]
    for i in range(5):
        w1.append(weight_list[label[i]])
    l=[1,2,2,1,1]
    if First==True:
        for i in range(5):
            w2.append(weight_list[label[i]+1])
        w=[w1,w2]
        x_L=layer_xL('first',x_l,w,F=True)
    else:
        for i in range(5):
            w2.append(weight_list[label[i]+l[i]])
        w=[w1,w2]
        x_L=layer_xL('group',x_l,w,F=True)
    return x_L

In [22]:
def prune_model(model,G,P,x,y,prune_rate,q=67):
    """
    Structured Channel Group Pruning Function Based on MD-LP (Channel-wise Pruning) 
    
    This function is the main function for pruning in ResNet-18. It achieves 
    the pruning of layer groups by using the provided convolution layer groups G.
    The main process of this function is as follows:
    - Pruning Preparation: Based on the hidden layer positions provided by P, 
      construct the input/output lists for the pruning-required convolutional layers, 
      as well as the list of convolution kernel parameters and BN layer parameters.
    - Layer Group Pruning: In accordance with the sequence in G, the pruning_channel 
      function is used to perform pruning successively, resulting in the channels 
      that are retained after pruning, which are labeled as channel_new_label.
    - Output/Parameter Update: Based on the channel_new_label, the parameters of the 
      convolutional layers included in the group, as well as the parameters of the 
      convolutional layers whose outputs are used as inputs, are updated. And the 
      input/output lists of the convolutional layers are updated according to the 
      updated parameters.
    
    Input:
    model : Original Keras ResNet-18
    x : Network input samples (used for forward propagation and channel evaluation)
    y : Sample labels (used for metric calculation in prune_channel)
    prune_rate : Pruning parameter
    G : The grouped list of network convolution layers, based on the ReNet network 
    structure, divides the hidden layers of ResNet. It is used to ensure that the 
    output results of the network hidden layers within the same group after pruning 
    can still maintain the same size and can be added together.
    
    Output:
    weight_list: List of weights for each convolutional / BN layer after pruning.
    channel_label: Record of the number of retained channels for each layer group.   
    """
    layer_outputs =[layer.output for layer in model.layers] 
    weight_list=[]
    # =========================
    # Collect the input/output of the convolutional layers, and the parameters
    # of the convolutional layers and BN layers in the ResNet network.
    # =========================
    for i in range(q+1):
        layer=model.layers[i]
        if "conv" in layer.name:
            w=layer.get_weights()
            weight_list.append(w)
            #print(i)
        elif "dense" in layer.name:
            w,b=layer.get_weights()
            weight_list.append([w,b])
        elif "batch_normalization" in layer.name:
            g,b,m,v=layer.get_weights()
            weight_list.append([g,b,m,v])
        else:
            weight_list.append(None)
    x_LG=[]
    for i in range(len(P)):
        activation_model = tf.keras.models.Model(inputs=model.input,outputs=layer_outputs[P[i]])
        layer_x=activation_model.predict(x)
        x_LG.append(layer_x)
    channel_label=[]
    b=1
    for i in range(len(G)):
        if len(G[i])==1:
            # =========================
            # According to the prune_channel function, the single convolution layer in G is pruned, 
            # resulting in the retained channels after pruning. Based on the pruning results, the 
            # network parameters and network input/output are updated.
            # =========================
            if (b-1)%4==0:
                a=G[i][0]
                channel_new_label,r_L=prune_channel([x_LG[b+1]],y,prune_rate)
                print(len(channel_new_label),0)
                weight_list[a][0]=weight_list[a][0][:,:,:,channel_new_label]
                for j in range(4):
                    weight_list[a+1][j]=weight_list[a+1][j][channel_new_label]
                if i==0:
                    a1=G[i+2][1]
                    weight_list[a1][0]=weight_list[a1][0][:,:,channel_new_label,:]
                    x_LG=x_block(i,b,G,weight_list,x_LG,First=True)
                else:
                    a1=G[i+2][0]
                    weight_list[a1][0]=weight_list[a1][0][:,:,channel_new_label,:]
                    x_LG=x_block(i,b,G,weight_list,x_LG,First=False,R=True)
                channel_label.append(len(channel_new_label))
                b+=2
                continue
            if (b-1)%4==2:
                a=G[i][0]
                a1=G[i+1][2]
                channel_new_label,r_L=prune_channel([x_LG[b+1]],y,prune_rate)
                print(len(channel_new_label),1)
                weight_list[a][0]=weight_list[a][0][:,:,:,channel_new_label]
                for j in range(4):
                    weight_list[a+1][j]=weight_list[a+1][j][channel_new_label]
                weight_list[a1][0]=weight_list[a1][0][:,:,channel_new_label,:]
                x_LG=x_block(i,b,G,weight_list,x_LG)
                channel_label.append(len(channel_new_label))
                if i==1:
                    b-=3
                else:
                    b-=2
                continue
        if len(G[i])>1:
            # =========================
            # According to the prune_channel function, the layer group in G with multiple convolutional 
            # layers is pruned to obtain a unified set of pruned channels. Based on the pruning results, 
            # the network parameters and network input/output are updated.
            # =========================
            if i==2:
                x_LP=get_x(i,b,G,x_LG,weight_list,First=True)
            else:
                x_LP=get_x(i,b,G,x_LG,weight_list)
            channel_new_label,r_L=prune_channel(x_LP,y,prune_rate)
            print(len(channel_new_label),2)
            for g in G[i]:
                weight_list[g][0]=weight_list[g][0][:,:,:,channel_new_label]
            weight_list[G[i-1][0]][0]=weight_list[G[i-1][0]][0][:,:,channel_new_label,:]
            if i==2:
                for g in G[i]:
                    for j in range(4):
                        weight_list[g+1][j]=weight_list[g+1][j][channel_new_label]
                weight_list[G[i-2][0]][0]=weight_list[G[i-2][0]][0][:,:,channel_new_label,:]
            else:
                l=[2,2,1]
                for g in range(3):
                    for j in range(4):
                        weight_list[G[i][g]+l[g]][j]=weight_list[G[i][g]+l[g]][j][channel_new_label]
            if i!=len(G)-1:
                weight_list[G[i+1][0]][0]=weight_list[G[i+1][0]][0][:,:,channel_new_label,:]
                weight_list[G[i+3][1]][0]=weight_list[G[i+3][1]][0][:,:,channel_new_label,:]
            if i==2:
                x_LG=x_block(i,b,G,weight_list,x_LG,First=True,R=False,Group=True)
                b+=5
            else:
                x_LG=x_block(i,b,G,weight_list,x_LG,First=False,R=False,Group=True)
                b+=4
            channel_label.append(len(channel_new_label))
            continue
    weight_list[q][0]=weight_list[q][0][channel_new_label]
    return weight_list,channel_label

In [23]:
def model_pr(model,weight_list,channel_label):
    """This function is used to construct a pruned network by using 
    the given pruned network structure and parameters."""
    model_p=Res_model(channel_label)
    for i in range(len(weight_list)):
        if weight_list[i]!=None:
            w = [ww for ww in weight_list[i]]
            model_p.layers[i].set_weights(w)
    return model_p

In [24]:
G=[[4],[11],[1,7,14],[18],[27],[21,22,30],[34],[43],[37,38,46],[50],[59],[53,54,62]]

In [25]:
P=[0,3,6,10,13,17,20,26,29,33,36,42,45,49,52,58,61,65]

In [26]:
datagen = ImageDataGenerator(
            featurewise_center=False,  # set input mean to 0 over the dataset
            samplewise_center=False,  # set each sample mean to 0
            featurewise_std_normalization=False,  # divide inputs by std of the dataset
            samplewise_std_normalization=False,  # divide each input by its std
            zca_whitening=False,  # apply ZCA whitening
            rotation_range=15,  # randomly rotate images in the range (degrees, 0 to 180)
            width_shift_range=0.1,  # randomly shift images horizontally (fraction of total width)
            height_shift_range=0.1,  # randomly shift images vertically (fraction of total height)
            horizontal_flip=True,  # randomly flip images
            vertical_flip=False)  # randomly flip images
        # (std, mean, and principal components if ZCA whitening is applied).
datagen.fit(x_train)

In [27]:
def retrain(model,x_train,y_train,x_test,y_test):
    """This function is used to fine-tune the pruned network 
    using the same method as the original network training."""
    total_steps = epochs * (x_train.shape[0] // batch_size)
    warmup_steps = warmup_epochs * (x_train.shape[0] // batch_size)
    lr_schedule = WarmUpCosine(initial_lr, total_steps, warmup_steps)
    optimizer = CustomWeightDecaySGD(weight_decay=weight_decay,learning_rate=lr_schedule,momentum=0.9,nesterov=True)
    loss_fn=tf.keras.losses.CategoricalCrossentropy()
    model.compile(optimizer=optimizer,loss=loss_fn,metrics=['accuracy'])
    saver = LastNSaver(n=20)
    model.fit(datagen.flow(x_train, y_train_onehot,batch_size=batch_size),
                            steps_per_epoch=x_train.shape[0] // batch_size,
                            epochs=epochs,
                            validation_data=(x_test, y_test_onehot),verbose=2,callbacks=[saver])

In [28]:
"""These functions are used to calculate the FLOPs 
and the number of parameters of the network."""
def conv_flops_params(layer, input_shape):
    h_in, w_in, cin = input_shape[1:]
    h_out, w_out, cout = layer.output_shape[1:]
    k_h, k_w = layer.kernel_size
    flops = h_out * w_out * cin * cout * k_h * k_w
    params = cin * cout * k_h * k_w
    if layer.use_bias:
        params += cout
    return flops, params, (h_out, w_out, cout)
def dense_flops_params(layer, input_shape):
    cin = input_shape[-1]
    cout = layer.units
    flops = cin * cout
    params = cin * cout
    if layer.use_bias:
        params += cout
    return flops, params, (cout,)
def compute_flops_params(model, input_shape=(32, 32, 3)):
    total_flops = 0
    total_params = 0
    dummy_input = tf.zeros((1, *input_shape))
    _ = model(dummy_input)
    current_shape = (1, *input_shape)
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.Conv2D):
            flops, params, out_shape = conv_flops_params(layer, current_shape)
            total_flops += flops
            total_params += params
            current_shape = (1, *out_shape)
        elif isinstance(layer, tf.keras.layers.Dense):
            flops, params, out_shape = dense_flops_params(layer, current_shape)
            total_flops += flops
            total_params += params
            current_shape = (1, *out_shape)
    return total_flops, total_params

In [29]:
def R_layers(model,x,y,R=P[1:]):
    """This function is used to obtain the structural redundancy 
    criterion of each block in the ResNet-18 network."""
    layer_outputs =[layer.output for layer in model.layers] 
    R_L=[]
    channel_label=[]
    for i in range(len(R)):
        print('start')
        activation_model = tf.keras.models.Model(inputs=model.input,outputs=layer_outputs[R[i]])
        x_L=activation_model.predict(x)
        channel_new_label,r_L=prune_channel([x_L],y,0,nnn=15)
        print('finish')
        r_L=float(r_L)
        R_L.append(r_L)
        print(r_L)
    R_L=np.array(R_L)
    R=np.mean(R_L)
    LLL=[1,2,2,2,2,2,2,2,2]
    RR_L=[]
    iii=0
    for k in range(len(LLL)):
        print(R_L)
        if LLL[k]==1:
            RR_L.append(R_L[0])
        if LLL[k]==2:
            RR_L.append(R_L[iii:iii+2].sum()/2)
        iii+=LLL[k]
    return R,RR_L

In [30]:
def channel_G(model):
    """This function is used to obtain the number of channels of the layer group."""
    C=[]
    for i in range(len(G)):
        CG=0
        for g in G[i]:
            a,b,d,c=model.layers[g].output.shape
            CG+=c
        C.append(CG)
    return C

In [31]:
P_list=[]
E_list=[]
F_list=[]
#RP_list=[]
#RRP_list=[]
C_list=[]
flops,par=compute_flops_params(model)
loss, acc = model.evaluate(x_test, y_test_onehot)
C_0=channel_G(model)
print(flops)

561714176


In [32]:
SAVE_FILE = "training_ResNet18_log.json"
def load_progress():
    if os.path.exists(SAVE_FILE):
        with open(SAVE_FILE, "r") as f:
            return json.load(f)
    return {"results": [], 
            "RR_L": [],
            "P_list": [],
            "E_list": [],
            "F_list": [],
            "C_list": [],
            "last_lam_idx": 0,
            "last_repeat": 0,
            "RL_exist": 0,
            "Cri_exist": 0}
def save_progress(progress):
    with open(SAVE_FILE, "w") as f:
        json.dump(progress, f)

In [33]:
progress = load_progress()
start_lr_idx = progress["last_lam_idx"]
start_repeat = progress["last_repeat"]
If_RL = progress["RL_exist"]

In [34]:
if If_RL == 0:
    model=load_Res()
    R_L,RR_L=R_layers(model,x_dist,y_dist)
    progress["RR_L"].append(RR_L)
    progress["RL_exist"] = 1
    save_progress(progress)

In [35]:
print(progress["RR_L"])

[[0.03125, -0.03125, 0.296875, 0.2421875, 0.1953125, 0.15625, 0.1015625, 0.146484375, 0.591796875]]


In [36]:
repeats=3

In [37]:
for lam_idx in range(start_lr_idx, len(Lam)):
    lam = Lam[lam_idx]
    for rep in range(start_repeat, repeats):
        print(f"\n lambda: Lam={lam}, Repeat={rep+1}/{repeats}")
        if progress["Cri_exist"] == 0:
            model=load_Res()
            weight_list,channel_label=prune_model(model,G,P,x_dist,y_dist,lam)
            model_p=model_pr(model,weight_list,channel_label)
            flops_p,par_p=compute_flops_params(model_p)
            P_=par_p/par
            F=flops_p/flops
            C_P=channel_G(model_p)
            print(flops_p,flops)
            progress["P_list"].append(P_)
            progress["F_list"].append([flops_p,F])
            progress["C_list"].append([C_P])
            progress["Cri_exist"] = 1
            save_progress(progress)
            model_p.save("Res_18_pruned.h5")
        else:
            model_p=tf.keras.models.load_model('Res_18_pruned.h5',custom_objects={
                'CustomWeightDecaySGD': CustomWeightDecaySGD,
                'WarmUpCosine': WarmUpCosine})
            flops_p,par_p=compute_flops_params(model_p)
            F=flops_p/flops
            print(flops_p,flops)
        retrain(model_p,x_train,y_train_onehot,x_test,y_test_onehot)
        loss_p, acc_p = model_p.evaluate(x_test, y_test_onehot)
        print(f" Finished: Lam={lam}, Repeat={rep+1}, Acc={acc_p:.4f}")
        progress["results"].append(acc_p)
        progress["last_lam_idx"] = lam_idx
        progress["last_repeat"] = rep+1
        save_progress(progress)
    progress["E_list"].append(sum(progress["results"])/(repeats*acc))
    progress["results"]=[]
    progress["Cri_exist"] = 0
    progress["last_repeat"] = 0
    progress["last_lam_idx"] = lam_idx + 1
    start_repeat=0
    save_progress(progress)


 lambda: Lam=1.0, Repeat=2/3
132506468 561714176
Epoch 1/200
390/390 - 25s - loss: 1.0519 - accuracy: 0.6380 - val_loss: 3.0781 - val_accuracy: 0.2232 - 25s/epoch - 64ms/step
Epoch 2/200
390/390 - 22s - loss: 0.7249 - accuracy: 0.7511 - val_loss: 2.2710 - val_accuracy: 0.4803 - 22s/epoch - 55ms/step
Epoch 3/200
390/390 - 21s - loss: 0.6209 - accuracy: 0.7882 - val_loss: 1.1878 - val_accuracy: 0.6627 - 21s/epoch - 54ms/step
Epoch 4/200
390/390 - 21s - loss: 0.5594 - accuracy: 0.8075 - val_loss: 0.9046 - val_accuracy: 0.7076 - 21s/epoch - 54ms/step
Epoch 5/200
390/390 - 20s - loss: 0.5123 - accuracy: 0.8228 - val_loss: 1.4879 - val_accuracy: 0.5927 - 20s/epoch - 53ms/step
Epoch 6/200
390/390 - 22s - loss: 0.4674 - accuracy: 0.8401 - val_loss: 1.2169 - val_accuracy: 0.6743 - 22s/epoch - 57ms/step
Epoch 7/200
390/390 - 22s - loss: 0.4149 - accuracy: 0.8581 - val_loss: 0.8488 - val_accuracy: 0.7497 - 22s/epoch - 56ms/step
Epoch 8/200
390/390 - 22s - loss: 0.3818 - accuracy: 0.8686 - val_lo

  layer_config = serialize_layer_fn(layer)


Epoch 1/200
390/390 - 24s - loss: 0.8601 - accuracy: 0.7157 - val_loss: 2.7571 - val_accuracy: 0.3810 - 24s/epoch - 62ms/step
Epoch 2/200
390/390 - 21s - loss: 0.6597 - accuracy: 0.7769 - val_loss: 2.2355 - val_accuracy: 0.4841 - 21s/epoch - 54ms/step
Epoch 3/200
390/390 - 21s - loss: 0.5617 - accuracy: 0.8091 - val_loss: 0.7927 - val_accuracy: 0.7427 - 21s/epoch - 55ms/step
Epoch 4/200
390/390 - 22s - loss: 0.4984 - accuracy: 0.8314 - val_loss: 1.3957 - val_accuracy: 0.6403 - 22s/epoch - 55ms/step
Epoch 5/200
390/390 - 20s - loss: 0.4536 - accuracy: 0.8445 - val_loss: 1.0387 - val_accuracy: 0.6838 - 20s/epoch - 52ms/step
Epoch 6/200
390/390 - 21s - loss: 0.4102 - accuracy: 0.8588 - val_loss: 0.7868 - val_accuracy: 0.7746 - 21s/epoch - 55ms/step
Epoch 7/200
390/390 - 22s - loss: 0.3582 - accuracy: 0.8759 - val_loss: 0.5158 - val_accuracy: 0.8327 - 22s/epoch - 55ms/step
Epoch 8/200
390/390 - 20s - loss: 0.3263 - accuracy: 0.8883 - val_loss: 0.5305 - val_accuracy: 0.8361 - 20s/epoch - 51

In [38]:
[0.9424,0.9406,0.9391]

[0.9424]

In [47]:
[0.9421,0.9449,0.9433]

[0.9421, 0.9449, 0.9433]

In [48]:
[0.9451,0.9466,0.9461]

[0.9451, 0.9466, 0.9461]

In [None]:
[0.9457,0.9452,0.9475]

In [43]:
progress["P_list"]

[0.21097767067728945,
 0.3448167909478093,
 0.5083822068993015,
 0.570894501373188]

In [44]:
progress["F_list"]

[[132506468, 0.23589660660442366],
 [199343756, 0.35488468070992746],
 [269001286, 0.4788935325000592],
 [321009422, 0.5714817886312344]]

In [45]:
progress["E_list"]

[0.9966098160812437,
 0.9995055994940815,
 1.0021542037030278,
 1.0023660836201163]

In [46]:
progress["C_list"]

[[[22, 39, 102, 69, 66, 180, 126, 118, 378, 224, 281, 606]],
 [[25, 44, 120, 78, 83, 231, 156, 152, 462, 269, 367, 810]],
 [[26, 51, 126, 89, 91, 288, 191, 175, 528, 336, 420, 1053]],
 [[34, 53, 141, 94, 101, 321, 210, 189, 588, 382, 463, 1017]]]