In [23]:
import os
# Select the Keras backend. The variable is assigned before `import keras` below
# so that it is already in the environment when Keras is imported.
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
# Use the channels-last layout (channel axis is the last one) for image tensors.
keras.backend.set_image_data_format("channels_last")
import tensorflow as tf

In [24]:
# Move into the project's `src` directory (presumably so that the local
# `pquant` and `smartpixels` packages imported later can be found).
# This stays best-effort: if the directory is not available the notebook keeps
# running from the current working directory.
try:
    os.chdir("/home/das214/PQuant/mdmm_dev/src")
except OSError as err:
    # Catch only filesystem errors (missing directory, permissions, ...).
    # A bare `except:` would also swallow KeyboardInterrupt and hide typos.
    print(f"Could not change directory ({err}); staying in {os.getcwd()}")

# List the entries of the working directory, one per line.
for f in os.listdir(os.getcwd()):
    print(f)

pquant
smartpixels


In [25]:
# Root directory of the dataset; every other path below is derived from it.
dataset_base_dir = "/depot/cms/users/das214/datasets/dataset_3sr/dataset_3sr_16x16_50x12P5_parquets/contained"

# Train / test input directories.
dataset_train_dir, dataset_test_dir = (
    os.path.join(dataset_base_dir, split) for split in ("train", "test")
)

# TFRecord directories for training and validation.
tfrecords_base_dir = os.path.join(dataset_base_dir, "TFR_files", "2t")
tfrecords_dir_train, tfrecords_dir_val = (
    os.path.join(tfrecords_base_dir, name) for name in ("TFR_train", "TFR_val")
)

# Count directory entries; note that `val_file_size` is measured on the *test* directory.
train_file_size, val_file_size = (
    len(os.listdir(directory)) for directory in (dataset_train_dir, dataset_test_dir)
)
print(f"train_file_size: {train_file_size}\nval_file_size: {val_file_size}")

from smartpixels.DG.OptimizedDataGenerator_v2 import OptimizedDataGenerator
from smartpixels.losses.maxNLL import custom_loss
from smartpixels.models.conv2D import CreateModel


train_file_size: 80
val_file_size: 20


In [26]:
# Options shared by both generators: shuffling on, fixed seed, quantization off.
_loader_options = dict(shuffle=True, seed=42, quantize=False)

# Validation generator first, then the training generator; each is pointed at
# its own TFRecords directory.
val_loader = OptimizedDataGenerator(
    load_from_tfrecords_dir=tfrecords_dir_val,
    **_loader_options,
)
train_loader = OptimizedDataGenerator(
    load_from_tfrecords_dir=tfrecords_dir_train,
    **_loader_options,
)


Loading metadata from /depot/cms/users/das214/datasets/dataset_3sr/dataset_3sr_16x16_50x12P5_parquets/contained/TFR_files/2t/TFR_val/metadata.json




Loading metadata from /depot/cms/users/das214/datasets/dataset_3sr/dataset_3sr_16x16_50x12P5_parquets/contained/TFR_files/2t/TFR_train/metadata.json




In [27]:
# Build the network from the project's conv2D factory and print its layer table.
# The input shape is (16, 16, 2); the remaining arguments are forwarded as-is
# (their exact meaning is defined by CreateModel — see smartpixels.models.conv2D).
model = CreateModel(
    shape=(16, 16, 2),
    n_filters=5,
    kernel_size=3,
    pool_size=3,
    hidden=10,
    output=14,
)

model.summary()

In [28]:
# model.compile(
#     optimizer=tf.keras.optimizers.Nadam(learning_rate=1e-3, clipnorm=1.0),
#     loss=custom_loss,
# )
# history = model.fit(
#         x=train_loader,
#         validation_data=val_loader,
#         epochs=5,
#         shuffle=False,
#         verbose=1
#     )

## Add pruning and quantization
To add pruning and quantization, we need a config that defines how they are applied. Here we load pquant's default config for the chosen pruning method (the config files live in `pquant/configs/`). <br/>
The training function we use later will add the pruning layers and quantized activations automatically using this config.

In [29]:
from pquant import get_default_config
from IPython.display import JSON

# pruning_methods: "autosparse, cl, cs, dst, pdp, wanda, mdmm"
# Method selected for this run.
pruning_method = "mdmm"
# Default pquant configuration for the selected method; later cells read config["lr"]
# and pass `config` to add_compression_layers and iterative_train.
config = get_default_config(pruning_method)
# Last expression of the cell: shows the config through IPython's JSON display object.
JSON(config)

<IPython.core.display.JSON object>

In [30]:
# Replace layers with compressed layers
from pquant import add_compression_layers
# Shape passed to add_compression_layers: a leading 1 followed by the model's
# (16, 16, 2) input shape (presumably a batch dimension — TODO confirm against pquant).
input_shape = (1, 16, 16, 2)

# NOTE(review): `model` is rebound to the wrapped version of itself, so re-running
# this cell without rebuilding the model feeds an already-compressed model back in.
# Confirm that this is safe, or re-run the CreateModel cell first.
model = add_compression_layers(model, config, input_shape)
model.summary()

## Create data set
#### The data loaders were created above; here we set up the input quantizer and define the training and validation loops

In [31]:
from quantizers.fixed_point.fixed_point_ops import get_fixed_quantizer
# Set up input quantizer: a fixed-point quantizer created with overflow_mode="SAT"
# (presumably saturating on overflow — TODO confirm against the quantizers package).
# NOTE(review): `quantizer` is not referenced by any other cell visible in this notebook.
quantizer = get_fixed_quantizer(overflow_mode="SAT")

import tensorflow as tf
from tqdm import tqdm
from pquant import get_layer_keep_ratio, get_model_losses


@tf.function
def train_step(model, inputs, logits, optimizer: keras.optimizers.Optimizer):
    """Run one optimisation step on a single batch and return the batch loss.

    Args:
        model: Model to train; it is called with ``training=True``.
        inputs: Batch fed to the model.
        logits: Batch of targets, passed to ``custom_loss`` as its first argument.
        optimizer: Keras optimizer that applies the gradients.

    Returns:
        ``custom_loss(logits, model(inputs))`` plus the sum of ``model.losses``
        when that list is non-empty.
    """
    with tf.GradientTape() as grad_tape:
        predictions = model(inputs, training=True)
        loss = custom_loss(logits, predictions)
        # Losses registered on the model are added on top of the data loss
        # (the pquant `get_model_losses` helper was the earlier, now disabled, alternative).
        extra_losses = model.losses
        if extra_losses:
            loss = loss + tf.add_n(extra_losses)

    # Differentiate with respect to every trainable weight and apply the update.
    weights = model.trainable_weights
    gradients = grad_tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(gradients, weights))
    return loss

@tf.function
def valid_step(model, inputs, logits):
    """Compute the loss of one batch without updating any weights.

    Args:
        model: Model to evaluate; it is called with ``training=False``.
        inputs: Batch fed to the model.
        logits: Batch of targets, passed to ``custom_loss`` as its first argument.

    Returns:
        ``custom_loss(logits, model(inputs))`` plus the sum of ``model.losses``
        when that list is non-empty.
    """
    predictions = model(inputs, training=False)
    loss = custom_loss(logits, predictions)
    # Same composition as in training: add the losses registered on the model
    # (the pquant `get_model_losses` helper was the earlier, now disabled, alternative).
    extra_losses = model.losses
    if extra_losses:
        loss = loss + tf.add_n(extra_losses)
    return loss

def train_smart_pixels_tf(model, epoch, trainloader, optimizer: keras.optimizers.Optimizer, *args, **kwargs):
    """Train ``model`` for one pass over ``trainloader`` with a progress bar.

    Batches are fetched by index (``trainloader[i]`` for ``i`` in
    ``range(len(trainloader))``) and each one is handed to ``train_step``.

    Args:
        model: Model to train.
        epoch: Epoch number, used only in the progress-bar label.
        trainloader: Indexable loader returning ``(inputs, targets)`` pairs.
        optimizer: Keras optimizer forwarded to ``train_step``.
        *args, **kwargs: Accepted and ignored.

    Returns:
        None.
    """
    num_batches = len(trainloader)
    progress = tqdm(range(num_batches), desc=f"Epoch {epoch} [Training]")
    for step in progress:
        batch_inputs, batch_targets = trainloader[step]
        loss = train_step(model, batch_inputs, batch_targets, optimizer)
        # Show the most recent batch loss next to the bar.
        progress.set_postfix(loss=f'{loss.numpy():.4f}')

        
def validate_smart_pixels_tf(model, epoch, testloader, *args, **kwargs):
    """Evaluate ``model`` on every batch of ``testloader`` and print a summary.

    Batches are fetched by index and passed to ``valid_step``. After the loop
    the mean batch loss and the value returned by ``get_layer_keep_ratio``
    (printed as a percentage) are reported on one line.

    Args:
        model: Model to evaluate.
        epoch: Epoch number, used only in the progress-bar label.
        testloader: Indexable loader returning ``(inputs, targets)`` pairs.
        *args, **kwargs: Accepted and ignored.

    Returns:
        None.
    """
    num_batches = len(testloader)
    total_loss = 0.0
    pbar = tqdm(range(num_batches), desc=f"Epoch {epoch} [Validation]")

    for batch_idx in pbar:
        inputs, logits = testloader[batch_idx]
        loss = valid_step(model, inputs, logits)
        total_loss += loss
        # Show the most recent batch loss next to the bar.
        pbar.set_postfix(loss=f'{loss.numpy():.4f}')

    # An empty loader has no mean loss: report NaN instead of raising ZeroDivisionError.
    mean_loss = total_loss / num_batches if num_batches else float("nan")
    ratio = get_layer_keep_ratio(model)
    print(f'NLL+mdmm_loss: {mean_loss:.2f}, remaining_weights: {ratio * 100:.2f}%')


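## Create loss function, scheduler and optimizer
The loss function is `custom_loss` (imported earlier) and no learning-rate scheduler is used in this notebook, so only the optimizer is created here.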

In [32]:
from keras.optimizers import Adam

# Adam optimizer; the learning rate is read from the "lr" entry of the pquant config.
optimizer = Adam(learning_rate=config["lr"])

## Train model
Training time. We use the `iterative_train` function from pquant to train. We need to provide some parameters, such as the training and validation functions, their input parameters, the model and the config. The function automatically adds pruning layers and replaces activations with a quantized variant, trains the model, and removes the pruning layers after training is done.

In [None]:
# Report the current `lmbda` value of every compressed layer that exposes one.
for layer in model.layers:
    # Skip anything that is not a compressed layer carrying a pruning layer.
    if "CompressedLayer" not in layer.__class__.__name__:
        continue
    if not hasattr(layer, 'pruning_layer'):
        continue
    pruning_layer = layer.pruning_layer  # described by the author as the MDMM instance

    # The multiplier lives on the pruning layer's constraint layer, when present.
    if not hasattr(pruning_layer, 'constraint_layer'):
        continue
    constraint = pruning_layer.constraint_layer

    if not hasattr(constraint, 'lmbda'):
        continue
    print(f"  Layer: {layer.name}, Lambda Value: {constraint.lmbda.numpy():.4f}")


  Layer: compressed_layer_conv2d_keras_3, Lambda Value: 0.0000
  Layer: compressed_layer_dense_keras_4, Lambda Value: 0.0000
  Layer: compressed_layer_dense_keras_5, Lambda Value: 0.0000
  Layer: compressed_layer_dense_keras_6, Lambda Value: 0.0000


In [None]:
from pquant import iterative_train
# `iterative_train` is given the two callbacks defined earlier in this notebook:
#   train_smart_pixels_tf(model, epoch, trainloader, optimizer, *args, **kwargs)
#   validate_smart_pixels_tf(model, epoch, testloader, *args, **kwargs)
# Both callbacks accept *args/**kwargs, so any additional arguments forwarded to them
# are ignored (presumably `device` and `loss_func` — TODO confirm against pquant).
# The loss that is actually optimised is `custom_loss`, which train_step and
# valid_step call directly; no scheduler is passed.

trained_model = iterative_train(
    model=model,
    config=config,
    train_func=train_smart_pixels_tf,
    valid_func=validate_smart_pixels_tf,
    trainloader=train_loader,
    testloader=val_loader,
    device=None,
    loss_func=custom_loss,
    optimizer=optimizer,
)

Epoch 0 [Training]: 100%|██████████| 84/84 [00:11<00:00,  7.36it/s, loss=103616.3281]
Epoch 0 [Validation]: 100%|██████████| 21/21 [00:02<00:00,  8.84it/s, loss=103616.3281]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


Epoch 1 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.52it/s, loss=103616.3281]
Epoch 1 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.53it/s, loss=103616.3281]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


Epoch 2 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.58it/s, loss=103616.3281]
Epoch 2 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 15.14it/s, loss=103616.3281]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


Epoch 3 [Training]: 100%|██████████| 84/84 [00:06<00:00, 13.95it/s, loss=103616.3281]
Epoch 3 [Validation]: 100%|██████████| 21/21 [00:01<00:00, 14.75it/s, loss=103616.3281]


NLL+mdmm_loss: 102783.46, remaining_weights: 98.46%


Epoch 4 [Training]:  93%|█████████▎| 78/84 [00:05<00:00, 13.70it/s, loss=103616.3281]


KeyboardInterrupt: 