In [14]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
# keras.backend.set_image_data_format("channels_first")
keras.backend.set_image_data_format("channels_last")

In [15]:
dataset_base_dir = "/depot/cms/users/das214/datasets/dataset_3sr/dataset_3sr_16x16_50x12P5_parquets/contained"
tfrecords_base_dir = os.path.join(dataset_base_dir, "TFR_files", "2t")

dataset_train_dir = os.path.join(dataset_base_dir, "train")
dataset_test_dir = os.path.join(dataset_base_dir, "test")
tfrecords_dir_train = os.path.join(tfrecords_base_dir, "TFR_train")
tfrecords_dir_val   = os.path.join(tfrecords_base_dir, "TFR_val")

train_file_size = len(os.listdir(dataset_train_dir))
val_file_size = len(os.listdir(dataset_test_dir))
print(f"train_file_size: {train_file_size}\nval_file_size: {val_file_size}")

os.chdir("/home/das214/SmartPix/smrtpx_PQ/")
from DG.OptimizedDataGenerator_v2 import OptimizedDataGenerator
from losses.loss import custom_loss


train_file_size: 80
val_file_size: 20


In [16]:
val_loader = OptimizedDataGenerator(
    load_from_tfrecords_dir= tfrecords_dir_val,
    shuffle=True,
    seed=42,
    quantize=False,
)

train_loader = OptimizedDataGenerator(
    load_from_tfrecords_dir = tfrecords_dir_train,
    shuffle=True,
    seed=42,
    quantize=False,
)


Loading metadata from /depot/cms/users/das214/datasets/dataset_3sr/dataset_3sr_16x16_50x12P5_parquets/contained/TFR_files/2t/TFR_val/metadata.json




Loading metadata from /depot/cms/users/das214/datasets/dataset_3sr/dataset_3sr_16x16_50x12P5_parquets/contained/TFR_files/2t/TFR_train/metadata.json




In [27]:
try:
    os.chdir("/home/das214/PQuant/mdmm_dev/src")
except:
    pass

for f in os.listdir(os.getcwd()):
    print(f)

pquant


In [28]:
from keras.layers import Dense, SeparableConv2D, Conv2D, AveragePooling2D, Flatten, Input, Activation
from keras.models import Model
import tensorflow as tf

def var_network(var, hidden, output):
    var = Flatten()(var)
    var = Dense(hidden,
        kernel_regularizer=tf.keras.regularizers.L1L2(0.01),
        activity_regularizer=tf.keras.regularizers.L2(0.01)
    )(var)
    var = Activation("tanh")(var)
    var = Dense(hidden,
        kernel_regularizer=tf.keras.regularizers.L1L2(0.01),
        activity_regularizer=tf.keras.regularizers.L2(0.01)
    )(var)
    var = Activation("tanh")(var)
    return Dense(output,
        kernel_regularizer=tf.keras.regularizers.L1L2(0.01),
    )(var)

def conv_network(var, n_filters, kernel_size):
    var = SeparableConv2D(
        n_filters,kernel_size,
        depthwise_regularizer=tf.keras.regularizers.L1L2(0.01),
        pointwise_regularizer=tf.keras.regularizers.L1L2(0.01),
        activity_regularizer=tf.keras.regularizers.L2(0.01)
        )(var)
    var = Activation("tanh")(var)
    var = Conv2D(
        n_filters,1,
        kernel_regularizer=tf.keras.regularizers.L1L2(0.01),
        activity_regularizer=tf.keras.regularizers.L2(0.01)
        )(var)
    var = Activation("tanh")(var)    
    return var

def CreateModel(shape, n_filters, kernel_size, pool_size, hidden, output):
    x_base = x_in = Input(shape)
    stack = conv_network(x_base,  n_filters, kernel_size)
    stack = AveragePooling2D(
        pool_size=(pool_size, pool_size), 
        strides=None, 
        padding="valid", 
        data_format=None,        
    )(stack)
    stack = Activation("tanh")(stack)
    stack = var_network(stack, hidden=16, output=output)
    model = Model(inputs=x_in, outputs=stack)
    return model

In [29]:
model=CreateModel(shape = (16,16,2), 
                  n_filters=5, kernel_size=3,
                  pool_size=3, 
                  hidden=10,
                  output = 14)

model.summary()

In [30]:
model.compile(
    optimizer=tf.keras.optimizers.Nadam(learning_rate=1e-3, clipnorm=1.0),
    loss=custom_loss,
)
history = model.fit(
        x=train_loader,
        validation_data=val_loader,
        epochs=5,
        shuffle=False,
        verbose=1
    )

Epoch 1/5
[1m84/84[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 302ms/step - loss: 64958.6094 - val_loss: 9073.3701
Epoch 2/5
[1m84/84[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 98ms/step - loss: 6669.2617 - val_loss: 615.1271
Epoch 3/5
[1m84/84[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 97ms/step - loss: -515.6271 - val_loss: -3601.6577
Epoch 4/5
[1m84/84[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 98ms/step - loss: -4654.4536 - val_loss: -7608.0542
Epoch 5/5
[1m84/84[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 99ms/step - loss: -8298.7861 - val_loss: -9543.2158


## Add pruning and quantization
To add pruning and quantization, we need a config file that defines how to do that. Let's load a config file from `pquant/configs/`. <br/>
The training function we use later will add the pruning layers and quantized activations automatically using this config

In [21]:
from pquant import get_default_config
from IPython.display import JSON

# pruning_methods: "autosparse, cl, cs, dst, pdp, wanda, mdmm"
pruning_method = "mdmm"
config = get_default_config(pruning_method)
JSON(config)

<IPython.core.display.JSON object>

In [22]:
# Replace layers with compressed layers
from pquant import add_compression_layers
input_shape = (1, 16, 16, 2)

model = add_compression_layers(model, config, input_shape)
model.summary()

## Create data set
#### Let's create the data loader and the training and validation loops

In [23]:
from quantizers.fixed_point.fixed_point_ops import get_fixed_quantizer
# Set up input quantizer
quantizer = get_fixed_quantizer(overflow_mode="SAT")

import tensorflow as tf
from tqdm import tqdm
from pquant import get_layer_keep_ratio, get_model_losses


@tf.function
def train_step(model, inputs, logits, optimizer: keras.optimizers.Optimizer):
    with tf.GradientTape() as tape:
        outputs = model(inputs, training=True)
        loss = custom_loss(logits, outputs)
        # loss += get_model_losses(model, losses=keras.ops.convert_to_tensor(0.))
        if model.losses:
            loss += tf.add_n(model.losses)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    return loss

@tf.function
def valid_step(model, inputs, logits):
    outputs = model(inputs, training=False)
    loss = custom_loss(logits, outputs)
    # loss += get_model_losses(model, losses=keras.ops.convert_to_tensor(0.))
    if model.losses:
            loss += tf.add_n(model.losses)
    return loss

def train_smart_pixels_tf(model, epoch, trainloader, optimizer: keras.optimizers.Optimizer, *args, **kwargs):
    for batch_idx in tqdm(range(trainloader.__len__())):
        inputs, logits = trainloader[batch_idx]
        loss = train_step(model, inputs, logits, optimizer)
        
def validate_smart_pixels_tf(model, epoch, testloader, *args, **kwargs):
    total_loss = 0.0
    for batch_idx in tqdm(range(testloader.__len__())):
        inputs, logits = testloader[batch_idx]
        loss = valid_step(model, inputs, logits)
        total_loss += loss
        
    ratio = get_layer_keep_ratio(model)
    print(f'NLL+mdmm_loss: {total_loss/testloader.__len__():.2f}, remaining_weights: {ratio * 100:.2f}%')


## Create loss function, scheduler and optimizer

In [24]:
from keras.optimizers import Adam

optimizer = Adam(learning_rate=config["lr"])

## Train model
Training time. We use the train_compressed_model function from pquant to train. We need to provide some parameters such as training and validation functions, their input parameters, the model and the config file. The function automatically adds pruning layers and replaces activations with a quantized variant, trains the model, and removes the pruning layers after training is done

In [25]:
for layer in model.layers:
        if "CompressedLayer" in layer.__class__.__name__ and hasattr(layer, 'pruning_layer'):
            pruning_layer = layer.pruning_layer # This is your MDMM instance
            
            # Check if the MDMM instance has the constraint_layer
            if hasattr(pruning_layer, 'constraint_layer'):
                constraint = pruning_layer.constraint_layer
                
                # Finally, check for 'lmbda' on the constraint layer
                if hasattr(constraint, 'lmbda'):
                    print(f"  Layer: {layer.name}, Lambda Value: {constraint.lmbda.numpy():.4f}")


  Layer: compressed_layer_conv2d_keras_3, Lambda Value: 0.0000
  Layer: compressed_layer_dense_keras_3, Lambda Value: 0.0000
  Layer: compressed_layer_dense_keras_4, Lambda Value: 0.0000
  Layer: compressed_layer_dense_keras_5, Lambda Value: 0.0000


In [26]:
from pquant import iterative_train
"""
Inputs to train_resnet we defined previously are:
          model, trainloader, device, loss_func, epoch, optimizer, scheduler, **kwargs
"""

trained_model = iterative_train(model = model, 
                                config = config, 
                                train_func = train_smart_pixels_tf, 
                                valid_func = validate_smart_pixels_tf, 
                                trainloader = train_loader, 
                                testloader = val_loader, 
                                device = None, 
                                loss_func = custom_loss,
                                optimizer = optimizer, 
                                )

100%|██████████| 84/84 [00:10<00:00,  7.80it/s]
100%|██████████| 21/21 [00:02<00:00,  8.99it/s]


NLL+mdmm_loss: 102784.91, remaining_weights: 99.78%


100%|██████████| 84/84 [00:05<00:00, 14.54it/s]
100%|██████████| 21/21 [00:01<00:00, 14.98it/s]


NLL+mdmm_loss: 102784.09, remaining_weights: 99.50%


100%|██████████| 84/84 [00:05<00:00, 14.38it/s]
100%|██████████| 21/21 [00:01<00:00, 15.11it/s]


NLL+mdmm_loss: 102783.71, remaining_weights: 99.34%


100%|██████████| 84/84 [00:05<00:00, 14.31it/s]
100%|██████████| 21/21 [00:01<00:00, 14.95it/s]


NLL+mdmm_loss: 102783.54, remaining_weights: 98.95%


100%|██████████| 84/84 [00:05<00:00, 14.57it/s]
100%|██████████| 21/21 [00:01<00:00, 15.32it/s]


NLL+mdmm_loss: 102783.48, remaining_weights: 98.79%


100%|██████████| 84/84 [00:05<00:00, 14.33it/s]
100%|██████████| 21/21 [00:01<00:00, 14.91it/s]


NLL+mdmm_loss: 102783.48, remaining_weights: 98.62%


100%|██████████| 84/84 [00:05<00:00, 14.18it/s]
100%|██████████| 21/21 [00:01<00:00, 14.76it/s]


NLL+mdmm_loss: 102783.48, remaining_weights: 98.62%


100%|██████████| 84/84 [00:05<00:00, 14.06it/s]
100%|██████████| 21/21 [00:01<00:00, 14.82it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.62%


100%|██████████| 84/84 [00:05<00:00, 14.39it/s]
100%|██████████| 21/21 [00:01<00:00, 14.70it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.51%


100%|██████████| 84/84 [00:06<00:00, 13.96it/s]
100%|██████████| 21/21 [00:01<00:00, 14.94it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.42it/s]
100%|██████████| 21/21 [00:01<00:00, 15.10it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.49it/s]
100%|██████████| 21/21 [00:01<00:00, 14.75it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.40it/s]
100%|██████████| 21/21 [00:01<00:00, 15.25it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.44it/s]
100%|██████████| 21/21 [00:01<00:00, 14.97it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.33it/s]
100%|██████████| 21/21 [00:01<00:00, 15.17it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.57it/s]
100%|██████████| 21/21 [00:01<00:00, 15.26it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.50it/s]
100%|██████████| 21/21 [00:01<00:00, 15.17it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.04it/s]
100%|██████████| 21/21 [00:01<00:00, 15.07it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.32it/s]
100%|██████████| 21/21 [00:01<00:00, 15.06it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.56it/s]
100%|██████████| 21/21 [00:01<00:00, 14.75it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.47it/s]
100%|██████████| 21/21 [00:01<00:00, 14.77it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.07it/s]
100%|██████████| 21/21 [00:01<00:00, 10.85it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.44it/s]
100%|██████████| 21/21 [00:01<00:00, 14.86it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.36it/s]
100%|██████████| 21/21 [00:01<00:00, 11.78it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.45it/s]
100%|██████████| 21/21 [00:01<00:00, 15.04it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.33it/s]
100%|██████████| 21/21 [00:01<00:00, 15.24it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.49it/s]
100%|██████████| 21/21 [00:01<00:00, 15.34it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:06<00:00, 13.88it/s]
100%|██████████| 21/21 [00:01<00:00, 14.97it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


100%|██████████| 84/84 [00:05<00:00, 14.18it/s]
100%|██████████| 21/21 [00:01<00:00, 15.13it/s]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


 94%|█████████▍| 79/84 [00:05<00:00, 14.37it/s]


KeyboardInterrupt: 