In [16]:
import os
# Select the TensorFlow backend for Keras. This is set before `import keras`
# on purpose: the backend is chosen when keras is first imported.
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
# Use (height, width, channels) ordering for image tensors.
keras.backend.set_image_data_format("channels_last")
import tensorflow as tf

import numpy as np
import random

# Single seed shared by every random number generator used in this notebook
# (it is also passed to the data generators further down).
SEED = 42
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
# here likely only affects child processes, not this kernel — confirm intent.
os.environ['PYTHONHASHSEED'] = str(SEED)
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [17]:
# Move into the project source directory; the listing printed below shows the
# `pquant` and `smartpixels` packages that later cells import.
try:
    os.chdir("/home/das214/PQuant/mdmm_dev/src")
except OSError as err:
    # Best effort: stay in the current directory, but say so instead of
    # silently swallowing the failure (a bare `except:` would also have
    # hidden KeyboardInterrupt and unrelated bugs).
    print(f"Could not change working directory ({err}); staying in {os.getcwd()}")

# Show what is available in the working directory that will actually be used.
for f in os.listdir(os.getcwd()):
    print(f)

pquant
smartpixels


In [18]:
# Root of the dataset and of the TFRecord files stored beneath it.
dataset_base_dir = "/depot/cms/users/das214/datasets/dataset_3sr/dataset_3sr_16x16_50x12P5_parquets/contained"
tfrecords_base_dir = os.path.join(dataset_base_dir, "TFR_files", "2t")

# Train / test split directories of the dataset.
dataset_train_dir, dataset_test_dir = (
    os.path.join(dataset_base_dir, split) for split in ("train", "test")
)

# TFRecord directories that the data generators load from.
tfrecords_dir_train, tfrecords_dir_val = (
    os.path.join(tfrecords_base_dir, subdir) for subdir in ("TFR_train", "TFR_val")
)

# Number of directory entries per split; note that the "val" count is taken
# from the test directory.
train_file_size, val_file_size = (
    len(os.listdir(directory)) for directory in (dataset_train_dir, dataset_test_dir)
)
print(f"train_file_size: {train_file_size}\nval_file_size: {val_file_size}")

from smartpixels.DG.OptimizedDataGenerator_v2 import OptimizedDataGenerator
from smartpixels.losses.maxNLL import custom_loss
from smartpixels.models.conv2D import CreateModel


train_file_size: 80
val_file_size: 20


In [19]:
# Options shared by the validation and training generators: shuffled batches,
# a fixed seed, and no quantization inside the generator.
loader_options = dict(shuffle=True, seed=SEED, quantize=False)

# Validation generator, loaded from the validation TFRecord directory.
val_loader = OptimizedDataGenerator(
    load_from_tfrecords_dir=tfrecords_dir_val,
    **loader_options,
)

# Training generator, loaded from the training TFRecord directory.
train_loader = OptimizedDataGenerator(
    load_from_tfrecords_dir=tfrecords_dir_train,
    **loader_options,
)


Loading metadata from /depot/cms/users/das214/datasets/dataset_3sr/dataset_3sr_16x16_50x12P5_parquets/contained/TFR_files/2t/TFR_val/metadata.json




Loading metadata from /depot/cms/users/das214/datasets/dataset_3sr/dataset_3sr_16x16_50x12P5_parquets/contained/TFR_files/2t/TFR_train/metadata.json




In [20]:
# Build the baseline (uncompressed) network from the smartpixels conv2D module.
# The input shape here matches the (1, 16, 16, 2) `input_shape` used when the
# compression layers are added later in the notebook.
model = CreateModel(
    shape=(16, 16, 2),
    n_filters=5,
    SepConv2D_kernel_size=3,
    Conv2D_kernel_size=1,
    pool_size=3,
    hidden=16,
    output=14,
)

model.summary()

In [21]:
# model.compile(
#     optimizer=tf.keras.optimizers.Nadam(learning_rate=1e-3),
#     loss=custom_loss,
# )
# history = model.fit(
#         x=train_loader,
#         validation_data=val_loader,
#         epochs=5,
#         shuffle=False,
#         verbose=1
#     )

## Add pruning and quantization
To add pruning and quantization, we need a config file that defines how to do that. Let's load a config file from `pquant/configs/`. <br/>
The training function we use later will add the pruning layers and quantized activations automatically using this config

In [22]:
from pquant import get_default_config
from IPython.display import JSON

# Pruning methods listed as available: "autosparse, cl, cs, dst, pdp, wanda, mdmm".
# This notebook uses "mdmm"; the returned config is reused below when adding
# the compression layers, building the optimizer (config["lr"]) and training.
pruning_method = "mdmm"
config = get_default_config(pruning_method)
# Last expression of the cell: render the config as a JSON widget.
JSON(config)

<IPython.core.display.JSON object>

In [23]:
# Replace layers with compressed layers
from pquant import add_compression_layers
# Batch of one sample with the same 16x16x2 shape the model was created with.
input_shape = (1, 16, 16, 2)

# `model` is rebound to the returned model, so re-running this cell without
# re-creating the model applies add_compression_layers to its own output.
model = add_compression_layers(model, config, input_shape)
model.summary()

## Define the training and validation loops
#### The data loaders were created above; here we define the per-batch steps and the per-epoch training and validation loops

In [24]:
from quantizers.fixed_point.fixed_point_ops import get_fixed_quantizer
# Set up input quantizer
# NOTE(review): `quantizer` is not referenced by any cell visible in this
# notebook (the loaders are built with quantize=False) — confirm it is needed.
quantizer = get_fixed_quantizer(overflow_mode="SAT")

from tqdm import tqdm
from pquant import get_layer_keep_ratio, get_model_losses
import re

def get_lambda_values(model):
    """Collect the per-layer ``lmbda`` values as short progress-bar labels.

    Only layers whose class name contains ``"CompressedLayer"`` and that
    expose ``pruning_layer.constraint_layer.lmbda`` are reported.

    Args:
        model: Object with a ``layers`` iterable; each reported layer must
            have a ``name`` and an ``lmbda`` that supports ``.numpy()``.

    Returns:
        dict mapping a short key to the value formatted with two decimals.
        Keys are ``λc2D<n>`` for names containing "conv2d", ``λd<n>`` for
        names containing "dense" and ``λo<n>`` otherwise, where ``<n>``
        counts from 1 separately for each category.
    """
    # (substring searched in the lower-cased layer name, label prefix);
    # "conv2d" is tested before "dense", everything else falls back to "other".
    prefixes = (('conv2d', 'λc2D'), ('dense', 'λd'))
    next_index = {'conv2d': 1, 'dense': 1, 'other': 1}
    postfix = {}

    for layer in model.layers:
        # Guard clauses: skip anything that does not carry a lambda.
        if "CompressedLayer" not in layer.__class__.__name__:
            continue
        if not hasattr(layer, 'pruning_layer'):
            continue
        pruner = layer.pruning_layer
        if not hasattr(pruner, 'constraint_layer'):
            continue
        if not hasattr(pruner.constraint_layer, 'lmbda'):
            continue

        lowered_name = layer.name.lower()
        kind, prefix = next(
            ((k, p) for k, p in prefixes if k in lowered_name),
            ('other', 'λo'),
        )

        label = f"{prefix}{next_index[kind]}"
        next_index[kind] += 1

        postfix[label] = f"{pruner.constraint_layer.lmbda.numpy():.2f}"

    return postfix


@tf.function
def train_step(model, inputs, logits, optimizer: keras.optimizers.Optimizer):
    """Run one optimisation step on a single batch and return its loss.

    The loss is ``custom_loss(logits, outputs)`` plus the value returned by
    ``get_model_losses`` plus, when present, the sum of ``model.losses``.

    Args:
        model: Model called with ``training=True``.
        inputs: Batch passed to the model.
        logits: Targets handed to ``custom_loss`` as its first argument.
        optimizer: Optimizer that applies the gradients.

    Returns:
        The total loss tensor for this batch.
    """
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        data_term = custom_loss(logits, predictions)
        pquant_term = get_model_losses(model, losses=keras.ops.convert_to_tensor(0.))
        loss = data_term + pquant_term
        # Extra losses registered on the model itself, if any.
        layer_losses = model.losses
        if layer_losses:
            loss = loss + tf.add_n(layer_losses)

    variables = model.trainable_weights
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))
    return loss

@tf.function
def valid_step(model, inputs, logits):
    """Evaluate a single batch without updating any weights.

    Args:
        model: Model called with ``training=False``.
        inputs: Batch passed to the model.
        logits: Targets handed to ``custom_loss`` as its first argument.

    Returns:
        Tuple ``(total_loss, pquant_loss)``: ``total_loss`` is the custom loss
        plus any ``model.losses`` plus ``pquant_loss``, and ``pquant_loss`` is
        the value returned by ``get_model_losses``.
    """
    predictions = model(inputs, training=False)
    loss = custom_loss(logits, predictions)
    pquant_loss = get_model_losses(model, losses=keras.ops.convert_to_tensor(0.))

    # Extra losses registered on the model itself, if any.
    layer_losses = model.losses
    if layer_losses:
        loss = loss + tf.add_n(layer_losses)

    total_loss = loss + pquant_loss
    return total_loss, pquant_loss

def train_smart_pixels_tf(model, epoch, trainloader, optimizer: keras.optimizers.Optimizer, *args, **kwargs):
    """Train ``model`` for one epoch over ``trainloader``.

    Args:
        model: Model to train; passed to ``train_step`` for every batch.
        epoch: Epoch number, used only in the progress-bar description.
        trainloader: Indexable batch source supporting ``len()``,
            ``trainloader[i] -> (inputs, logits)`` and ``on_epoch_end()``.
        optimizer: Optimizer handed to ``train_step``.
        *args, **kwargs: Accepted and ignored, so callers may pass extra
            arguments that this function does not need.

    Returns:
        Mean of the per-batch losses over the epoch, converted with ``.numpy()``.
    """
    total_loss = 0.0
    pbar = tqdm(range(len(trainloader)), desc=f"Epoch {epoch} [Training]")

    for batch_idx in pbar:
        inputs, logits = trainloader[batch_idx]
        loss = train_step(model, inputs, logits, optimizer)
        total_loss += loss

        # Show the current batch loss and the per-layer lambdas next to the bar.
        metrics = {
            'loss': f'{loss.numpy():.2f}',
        }
        metrics.update(get_lambda_values(model))
        pbar.set_postfix(**metrics)

    # Notify the loader that was actually passed in. The previous code called
    # the notebook-global `train_loader`, which only worked when the caller
    # happened to pass that same object.
    trainloader.on_epoch_end()
    return (total_loss / len(trainloader)).numpy()

        

        
def validate_smart_pixels_tf(model, epoch, testloader, *args, **kwargs):
    """Evaluate ``model`` over every batch of ``testloader`` and print a summary.

    Args:
        model: Model to evaluate; passed to ``valid_step`` for every batch.
        epoch: Epoch number, used only in the progress-bar description.
        testloader: Indexable batch source supporting ``len()``,
            ``testloader[i] -> (inputs, logits)`` and ``on_epoch_end()``.
        *args, **kwargs: Accepted and ignored.

    Returns:
        Mean of the per-batch total losses, converted with ``.numpy()``.
    """
    pbar = tqdm(range(len(testloader)), desc=f"Epoch {epoch} [Validation]")

    total_loss = 0.0
    for batch_idx in pbar:
        inputs, logits = testloader[batch_idx]
        loss, pquant_loss = valid_step(model, inputs, logits)
        total_loss += loss
        pbar.set_postfix(loss=f'{loss.numpy():.2f}')

    testloader.on_epoch_end()

    avg_total_loss = total_loss / len(testloader)
    ratio = get_layer_keep_ratio(model)

    # `pquant_loss` here is the value returned for the last batch of the loop.
    summary_lines = [
        f"\tTotal Loss (NLL + MDMM): {avg_total_loss:.2f}",
        f"\tMDMM Loss:               {pquant_loss:.2f}",
        f"\tRemaining Weights:       {ratio * 100:.2f} %",
    ]
    print("\n".join(summary_lines))
    return avg_total_loss.numpy()


## Create loss function, scheduler and optimizer

## Train model
Training time. We use the `iterative_train` function from pquant to train. We need to provide some parameters, such as the training and validation functions, their input parameters, the model and the config file. The function automatically adds pruning layers and replaces activations with a quantized variant, trains the model, and removes the pruning layers after training is done.

In [25]:
# Print the current lambda of every compressed layer that exposes
# `pruning_layer.constraint_layer.lmbda`.
for layer in model.layers:
    # Only compressed layers that wrap a pruning layer are of interest.
    if "CompressedLayer" not in layer.__class__.__name__:
        continue
    if not hasattr(layer, 'pruning_layer'):
        continue

    pruning_layer = layer.pruning_layer
    if not hasattr(pruning_layer, 'constraint_layer'):
        continue

    constraint = pruning_layer.constraint_layer
    if not hasattr(constraint, 'lmbda'):
        continue

    print(f"  Layer: {layer.name}, Lambda Value: {constraint.lmbda.numpy():.4f}")


  Layer: compressed_layer_conv2d_keras_3, Lambda Value: 0.0000
  Layer: compressed_layer_dense_keras_3, Lambda Value: 0.0000
  Layer: compressed_layer_dense_keras_4, Lambda Value: 0.0000
  Layer: compressed_layer_dense_keras_5, Lambda Value: 0.0000


In [26]:
from pathlib import Path
from secrets import token_hex

# Directory layout: <BASE_DIR>/smart_pixels/<random 8-hex-digit run id>.
BASE_DIR = Path("/home/das214/PQuant/mdmm_dev/trainings").resolve()
RUN_DIR = BASE_DIR.joinpath("smart_pixels")
FP = token_hex(4)  # fresh identifier on every execution of this cell
RUN_OUT = RUN_DIR.joinpath(FP)

# Create the shared parent directories first, then this run's own directory.
for directory in (BASE_DIR, RUN_DIR):
    directory.mkdir(parents=True, exist_ok=True)
RUN_OUT.mkdir(exist_ok=True)

print(f"Training artifacts will be stored in: {RUN_OUT}")

Training artifacts will be stored in: /home/das214/PQuant/mdmm_dev/trainings/smart_pixels/e6697199


In [27]:
from pquant.core.tf_impl.callbacks import (
    CSVLogger,
    EarlyStopping,
    EpochCheckpoint,
    ModelCheckpoint,
)

# Callbacks handed to the training call below; both output paths live inside
# this run's RUN_OUT directory.
history_logger = CSVLogger(RUN_OUT / "history.csv")
epoch_checkpoint = EpochCheckpoint(RUN_OUT / "checkpoints")
early_stopping = EarlyStopping(patience=50, min_delta=1e-3)

cbs = [history_logger, epoch_checkpoint, early_stopping]

In [None]:
from pquant import iterative_train
# NOTE(review): the string below is a leftover, unused expression; it mentions
# `train_resnet` and a scheduler, neither of which appears in this notebook.
"""
Inputs to train_resnet we defined previously are:
          model, trainloader, device, loss_func, epoch, optimizer, scheduler, **kwargs
"""

# Nadam optimizer using the learning rate stored in the pruning config.
optimizer = keras.optimizers.Nadam(learning_rate=config["lr"])

# Train with the epoch functions and callbacks defined in the cells above.
trained_model = iterative_train(
    model=model,
    config=config,
    train_func=train_smart_pixels_tf,
    valid_func=validate_smart_pixels_tf,
    trainloader=train_loader,
    testloader=val_loader,
    device=None,
    loss_func=custom_loss,
    optimizer=optimizer,
    callbacks=cbs,
)

Training...


Epoch 0 [Training]:   0%|          | 0/84 [00:00<?, ?it/s]



Epoch 0 [Training]: 100%|██████████| 84/84 [00:14<00:00,  5.78it/s, loss=16211.51, λc2D1=0.23, λd1=0.08, λd2=0.15, λd3=0.15]
Epoch 0 [Validation]: 100%|██████████| 21/21 [00:02<00:00,  7.45it/s, loss=16170.89]


	Total Loss (NLL + MDMM): 16036.61
	MDMM Loss:               4.96
	Remaining Weights:       99.61 %


Epoch 1 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.51it/s, loss=8348.47, λc2D1=0.52, λd1=0.19, λd2=0.34, λd3=0.35] 
Epoch 1 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 15.17it/s, loss=8168.31]


	Total Loss (NLL + MDMM): 8091.49
	MDMM Loss:               8.90
	Remaining Weights:       99.61 %


Epoch 2 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.37it/s, loss=3677.15, λc2D1=0.81, λd1=0.30, λd2=0.53, λd3=0.55]
Epoch 2 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.95it/s, loss=3482.80]


	Total Loss (NLL + MDMM): 3460.40
	MDMM Loss:               12.60
	Remaining Weights:       99.72 %


Epoch 3 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.44it/s, loss=1710.83, λc2D1=1.10, λd1=0.41, λd2=0.73, λd3=0.74]
Epoch 3 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.47it/s, loss=1515.59]


	Total Loss (NLL + MDMM): 1489.65
	MDMM Loss:               16.41
	Remaining Weights:       99.67 %


Epoch 4 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.18it/s, loss=191.16, λc2D1=1.39, λd1=0.52, λd2=0.92, λd3=0.94] 
Epoch 4 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.99it/s, loss=43.57]  


	Total Loss (NLL + MDMM): 19.99
	MDMM Loss:               20.55
	Remaining Weights:       99.67 %


Epoch 5 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.56it/s, loss=-922.08, λc2D1=1.69, λd1=0.63, λd2=1.11, λd3=1.14] 
Epoch 5 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.82it/s, loss=-1132.03]


	Total Loss (NLL + MDMM): -1070.18
	MDMM Loss:               24.84
	Remaining Weights:       99.67 %


Epoch 6 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.53it/s, loss=-1647.81, λc2D1=1.98, λd1=0.74, λd2=1.31, λd3=1.34]
Epoch 6 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 15.14it/s, loss=-1862.30]


	Total Loss (NLL + MDMM): -1776.08
	MDMM Loss:               29.23
	Remaining Weights:       99.78 %


Epoch 7 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.39it/s, loss=-2271.63, λc2D1=2.27, λd1=0.85, λd2=1.50, λd3=1.54]
Epoch 7 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 15.01it/s, loss=-2474.14]


	Total Loss (NLL + MDMM): -2400.69
	MDMM Loss:               33.82
	Remaining Weights:       99.94 %


Epoch 8 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.33it/s, loss=-2914.03, λc2D1=2.56, λd1=0.96, λd2=1.70, λd3=1.73]
Epoch 8 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.52it/s, loss=-3109.90]


	Total Loss (NLL + MDMM): -3002.42
	MDMM Loss:               38.46
	Remaining Weights:       99.72 %


Epoch 9 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.30it/s, loss=-3590.65, λc2D1=2.86, λd1=1.07, λd2=1.90, λd3=1.93]
Epoch 9 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.88it/s, loss=-3846.69]


	Total Loss (NLL + MDMM): -3698.96
	MDMM Loss:               43.10
	Remaining Weights:       99.72 %


Epoch 10 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.33it/s, loss=-4272.73, λc2D1=3.15, λd1=1.19, λd2=2.09, λd3=2.13]
Epoch 10 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.47it/s, loss=-4549.61]


	Total Loss (NLL + MDMM): -4381.13
	MDMM Loss:               47.89
	Remaining Weights:       99.67 %


Epoch 11 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.37it/s, loss=-5013.15, λc2D1=3.44, λd1=1.30, λd2=2.29, λd3=2.33]
Epoch 11 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.84it/s, loss=-5317.29]


	Total Loss (NLL + MDMM): -5132.06
	MDMM Loss:               52.85
	Remaining Weights:       99.50 %


Epoch 12 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.26it/s, loss=-5848.98, λc2D1=3.74, λd1=1.41, λd2=2.49, λd3=2.53]
Epoch 12 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.92it/s, loss=-6186.92]


	Total Loss (NLL + MDMM): -5992.00
	MDMM Loss:               57.99
	Remaining Weights:       99.67 %


Epoch 13 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.38it/s, loss=-6763.70, λc2D1=4.03, λd1=1.52, λd2=2.69, λd3=2.72]
Epoch 13 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.35it/s, loss=-7136.27]


	Total Loss (NLL + MDMM): -6930.22
	MDMM Loss:               63.49
	Remaining Weights:       99.83 %


Epoch 14 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.13it/s, loss=-7731.63, λc2D1=4.33, λd1=1.63, λd2=2.88, λd3=2.92]
Epoch 14 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.79it/s, loss=-8079.72]


	Total Loss (NLL + MDMM): -7890.54
	MDMM Loss:               69.18
	Remaining Weights:       99.89 %


Epoch 15 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.35it/s, loss=-8701.97, λc2D1=4.62, λd1=1.75, λd2=3.08, λd3=3.12]
Epoch 15 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.60it/s, loss=-9022.88]


	Total Loss (NLL + MDMM): -8861.05
	MDMM Loss:               74.54
	Remaining Weights:       99.61 %


Epoch 16 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.04it/s, loss=-9612.43, λc2D1=4.92, λd1=1.86, λd2=3.28, λd3=3.31]
Epoch 16 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 13.79it/s, loss=-9956.47] 


	Total Loss (NLL + MDMM): -9791.81
	MDMM Loss:               81.11
	Remaining Weights:       99.56 %


Epoch 17 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.23it/s, loss=-10456.52, λc2D1=5.21, λd1=1.97, λd2=3.48, λd3=3.51]
Epoch 17 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.65it/s, loss=-10802.04]


	Total Loss (NLL + MDMM): -10616.11
	MDMM Loss:               87.15
	Remaining Weights:       99.61 %


Epoch 18 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.08it/s, loss=-11271.01, λc2D1=5.51, λd1=2.09, λd2=3.67, λd3=3.71]
Epoch 18 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.27it/s, loss=-11642.80]


	Total Loss (NLL + MDMM): -11416.81
	MDMM Loss:               93.56
	Remaining Weights:       99.67 %


Epoch 19 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.26it/s, loss=-12081.03, λc2D1=5.80, λd1=2.20, λd2=3.87, λd3=3.90]
Epoch 19 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.56it/s, loss=-12423.58]


	Total Loss (NLL + MDMM): -12185.89
	MDMM Loss:               100.03
	Remaining Weights:       99.72 %


Epoch 20 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.02it/s, loss=-12793.66, λc2D1=6.10, λd1=2.31, λd2=4.07, λd3=4.10]
Epoch 20 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.63it/s, loss=-13076.62]


	Total Loss (NLL + MDMM): -12855.65
	MDMM Loss:               106.37
	Remaining Weights:       99.78 %


Epoch 21 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.30it/s, loss=-13373.74, λc2D1=6.40, λd1=2.43, λd2=4.26, λd3=4.29]
Epoch 21 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.01it/s, loss=-13632.99]


	Total Loss (NLL + MDMM): -13416.96
	MDMM Loss:               112.55
	Remaining Weights:       99.28 %


Epoch 22 [Training]: 100%|██████████| 84/84 [00:06<00:00, 12.93it/s, loss=-13875.42, λc2D1=6.69, λd1=2.54, λd2=4.46, λd3=4.49]
Epoch 22 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.44it/s, loss=-14123.55]


	Total Loss (NLL + MDMM): -13905.73
	MDMM Loss:               118.57
	Remaining Weights:       99.45 %


Epoch 23 [Training]: 100%|██████████| 84/84 [00:06<00:00, 12.98it/s, loss=-14342.71, λc2D1=6.99, λd1=2.66, λd2=4.65, λd3=4.68]
Epoch 23 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.95it/s, loss=-14593.58]


	Total Loss (NLL + MDMM): -14352.20
	MDMM Loss:               124.83
	Remaining Weights:       99.56 %


Epoch 24 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.31it/s, loss=-14754.52, λc2D1=7.29, λd1=2.77, λd2=4.85, λd3=4.88]
Epoch 24 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.09it/s, loss=-15001.77]


	Total Loss (NLL + MDMM): -14760.04
	MDMM Loss:               130.97
	Remaining Weights:       99.72 %


Epoch 25 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.23it/s, loss=-15129.18, λc2D1=7.59, λd1=2.89, λd2=5.04, λd3=5.07]
Epoch 25 [Validation]:  48%|████▊     | 10/21 [00:00<00:00, 14.68it/s, loss=-15177.51]

In [None]:
# TODO: 
#     - Increase number of epochs
#     - Implement early stopping
#     - Saving model weights (checkPoints)
#     - HLS4ml tutorial