In [None]:
import os
# Select the Keras backend. The environment variable is set before `import keras`
# below so that it is already in place when Keras is imported.
os.environ["KERAS_BACKEND"] = "tensorflow"
import keras
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
# keras.backend.set_image_data_format("channels_first")
# Use the channels-last layout; the model later in this notebook is built with
# an input shape of (16, 16, 2).
keras.backend.set_image_data_format("channels_last")

In [None]:
# Move into the PQuant source tree (the `pquant` package is imported from the
# working directory later on). This is best-effort: when the path is not
# available we stay in the current directory, but report it instead of
# silently swallowing every exception.
try:
    os.chdir("/home/das214/PQuant/mdmm_dev/src")
except OSError as err:
    print(f"Could not change directory ({err}); staying in {os.getcwd()}")

# List the working directory to confirm what is importable from here.
for f in os.listdir(os.getcwd()):
    print(f)

pquant


In [None]:
from keras.layers import Dense, SeparableConv2D, Conv2D, AveragePooling2D, Flatten, Input, Activation
from keras.models import Model

def var_network(var, hidden, output):
    """Dense head: flatten, two tanh-activated hidden layers, linear output layer.

    Args:
        var: Input tensor; it is flattened before the dense layers.
        hidden: Number of units in each of the two hidden Dense layers.
        output: Number of units in the final Dense layer (no activation applied).

    Returns:
        The output tensor of the final Dense layer.
    """
    features = Flatten()(var)
    # Two identical Dense -> tanh stages.
    for _ in range(2):
        features = Dense(hidden)(features)
        features = Activation("tanh")(features)
    return Dense(output)(features)

def conv_network(var, n_filters, kernel_size):
    """Convolutional stem: separable conv followed by a 1x1 conv, each with tanh.

    Args:
        var: Input tensor.
        n_filters: Number of filters used by both convolutions.
        kernel_size: Kernel size of the SeparableConv2D layer.

    Returns:
        The tanh-activated output of the 1x1 convolution.
    """
    separable = SeparableConv2D(n_filters, kernel_size)(var)
    separable = Activation("tanh")(separable)
    pointwise = Conv2D(n_filters, 1)(separable)
    return Activation("tanh")(pointwise)

def CreateModel(shape, n_filters, kernel_size, pool_size, hidden, output):
    """Build the model: conv stem -> average pooling -> tanh -> dense head.

    Args:
        shape: Shape passed to `Input` for the model's input tensor.
        n_filters: Number of filters for both convolutions in `conv_network`.
        kernel_size: Kernel size of the separable convolution in `conv_network`.
        pool_size: Side length of the square average-pooling window.
        hidden: Width of the two hidden Dense layers in `var_network`.
        output: Number of units in the final Dense layer.

    Returns:
        A `Model` mapping the input tensor to the output of the dense head.
    """
    x_in = Input(shape)
    stack = conv_network(x_in, n_filters, kernel_size)
    stack = AveragePooling2D(
        pool_size=(pool_size, pool_size),
        strides=None,
        padding="valid",
        data_format=None,
    )(stack)
    stack = Activation("tanh")(stack)
    # Forward the caller's `hidden`; it used to be ignored in favour of a
    # hard-coded width of 16.
    stack = var_network(stack, hidden=hidden, output=output)
    return Model(inputs=x_in, outputs=stack)

In [None]:
# Build the baseline network for (16, 16, 2) inputs with 14 outputs, then
# print its layer summary.
model = CreateModel(
    shape=(16, 16, 2),
    n_filters=5,
    kernel_size=3,
    pool_size=3,
    hidden=10,
    output=14,
)

model.summary()

## Add pruning and quantization
To add pruning and quantization, we need a config file that defines how to do that. Let's load a config file from `pquant/configs/`. <br/>
The training function we use later will add the pruning layers and quantized activations automatically using this config

In [None]:
from pquant import get_default_config
from IPython.display import JSON

# pruning_methods: "autosparse, cl, cs, dst, pdp, wanda, mdmm"
pruning_method = "mdmm"
# Fetch the default configuration for the selected pruning method.
config = get_default_config(pruning_method)
# Last expression of the cell: display the config through IPython's JSON renderer.
JSON(config)

<IPython.core.display.JSON object>

In [None]:
# Replace layers with compressed layers
from pquant import add_compression_layers
# Shape handed to add_compression_layers: a leading 1 followed by the
# (16, 16, 2) shape the model was built with.
input_shape = (1, 16, 16, 2)

# NOTE(review): `model` is rebound to the returned model, so re-running this
# cell passes an already-converted model back in; rebuild `model` before
# re-running.
model = add_compression_layers(model, config, input_shape)
model.summary()

## Create data set
#### Let's create the data loader and the training and validation loops

In [None]:
# Root directory of the dataset; the train/test directories and the TFRecord
# directories are all resolved relative to it.
dataset_base_dir = "/depot/cms/users/das214/datasets/dataset_3sr/dataset_3sr_16x16_50x12P5_parquets/contained"
dataset_train_dir = os.path.join(dataset_base_dir, "train")
dataset_test_dir = os.path.join(dataset_base_dir, "test")

# TFRecord directories consumed by the data generators.
tfrecords_base_dir = os.path.join(dataset_base_dir, "TFR_files", "2t")
tfrecords_dir_train = os.path.join(tfrecords_base_dir, "TFR_train")
tfrecords_dir_val = os.path.join(tfrecords_base_dir, "TFR_val")

# Count the entries in the train/test directories as a quick sanity check.
train_file_size = len(os.listdir(dataset_train_dir))
val_file_size = len(os.listdir(dataset_test_dir))
print(f"train_file_size: {train_file_size}\nval_file_size: {val_file_size}")

train_file_size: 80
val_file_size: 20


In [None]:
# Switch into the SmartPix project directory before importing from it —
# presumably so that the `DG` and `losses` packages resolve from the working
# directory (TODO confirm; adding the path to sys.path would avoid the chdir).
os.chdir("/home/das214/SmartPix/smrtpx_PQ/")
from DG.OptimizedDataGenerator_v2 import OptimizedDataGenerator
from losses.loss import custom_loss

In [None]:
# Return to the PQuant source directory used at the start of the notebook.
os.chdir("/home/das214/PQuant/mdmm_dev/src")

In [None]:
# Validation batches, read from the validation TFRecord directory.
val_loader = OptimizedDataGenerator(
    load_from_tfrecords_dir= tfrecords_dir_val,
    shuffle=True,
    seed=42,
    quantize=False,
)

# Training batches, read from the training TFRecord directory with the same
# shuffle/seed/quantize settings as the validation loader.
train_loader = OptimizedDataGenerator(
    load_from_tfrecords_dir = tfrecords_dir_train,
    shuffle=True,
    seed=42,
    quantize=False,
)


Loading metadata from /depot/cms/users/das214/datasets/dataset_3sr/dataset_3sr_16x16_50x12P5_parquets/contained/TFR_files/2t/TFR_val/metadata.json




Loading metadata from /depot/cms/users/das214/datasets/dataset_3sr/dataset_3sr_16x16_50x12P5_parquets/contained/TFR_files/2t/TFR_train/metadata.json




In [None]:
from quantizers.fixed_point.fixed_point_ops import get_fixed_quantizer
# Set up input quantizer
# NOTE(review): `quantizer` is not referenced by any other cell shown in this
# notebook, and both data loaders are created with quantize=False — confirm
# whether it is still needed.
quantizer = get_fixed_quantizer(overflow_mode="SAT")

import tensorflow as tf
from tqdm.auto import tqdm
from pquant import get_layer_keep_ratio, get_model_losses


@tf.function
def train_step(model, inputs, logits, optimizer: keras.optimizers.Optimizer):
    """Run one optimisation step on a single batch and return the batch loss.

    Args:
        model: Model to update; called with `training=True`.
        inputs: Batch of model inputs.
        logits: Batch of targets, passed as the first argument to `custom_loss`.
        optimizer: Keras optimizer used to apply the gradients.

    Returns:
        The loss that was differentiated: `custom_loss` plus the value returned
        by `get_model_losses` for this model.
    """
    with tf.GradientTape() as tape:
        predictions = model(inputs, training=True)
        task_loss = custom_loss(logits, predictions)
        # Extra loss terms collected from the model, accumulated from zero.
        model_loss = get_model_losses(model, losses=keras.ops.convert_to_tensor(0.))
        total_loss = task_loss + model_loss

    gradients = tape.gradient(total_loss, model.trainable_weights)
    optimizer.apply(gradients, model.trainable_weights)
    return total_loss

def train_smart_pixels_tf(model, epoch, trainloader, optimizer: keras.optimizers.Optimizer, *args, **kwargs):
    """Train `model` for one full pass over `trainloader`.

    Args:
        model: Model to train.
        epoch: Current epoch index (accepted for the callback signature; unused).
        trainloader: Indexable batch source supporting `len()` and
            `trainloader[i] -> (inputs, targets)`.
        optimizer: Keras optimizer forwarded to `train_step`.
        *args, **kwargs: Accepted and ignored so extra arguments from the
            caller do not raise.
    """
    # Iterate over the loader that was passed in. The previous version read the
    # notebook-global `train_loader`, silently ignoring this argument.
    for batch_idx in tqdm(range(len(trainloader))):
        inputs, logits = trainloader[batch_idx]
        train_step(model, inputs, logits, optimizer)
        
@tf.function
def valid_step(model, inputs, logits):
    """Compute the loss for one validation batch without updating any weights.

    Args:
        model: Model to evaluate; called with `training=False`.
        inputs: Batch of model inputs.
        logits: Batch of targets, passed as the first argument to `custom_loss`.

    Returns:
        `custom_loss` plus the value returned by `get_model_losses` for this model.
    """
    predictions = model(inputs, training=False)
    task_loss = custom_loss(logits, predictions)
    model_loss = get_model_losses(model, losses=keras.ops.convert_to_tensor(0.))
    return task_loss + model_loss


def validate_smart_pixels_tf(model, epoch, testloader, *args, **kwargs):
    """Evaluate `model` on `testloader` and print the loss and keep ratio.

    Args:
        model: Model to evaluate.
        epoch: Current epoch index (accepted for the callback signature; unused).
        testloader: Indexable batch source supporting `len()` and
            `testloader[i] -> (inputs, targets)`.
        *args, **kwargs: Accepted and ignored so extra arguments from the
            caller do not raise.

    The printed loss is the mean of the per-batch losses. The previous version
    printed only the last batch's loss, read the notebook-global `val_loader`
    instead of `testloader`, and raised NameError on an empty loader.
    """
    n_batches = len(testloader)
    total_loss = 0.0
    for batch_idx in tqdm(range(n_batches)):
        inputs, logits = testloader[batch_idx]
        # Convert to a Python float so the running sum stays a plain scalar.
        total_loss += float(valid_step(model, inputs, logits))

    ratio = get_layer_keep_ratio(model)
    if n_batches == 0:
        # Nothing to average; still report the sparsity so the epoch is logged.
        print(f'NLL+mdmm_loss: n/a (no validation batches), remaining_weights: {ratio * 100:.2f}%')
        return

    loss = total_loss / n_batches
    print(f'NLL+mdmm_loss: {loss:.2f}, remaining_weights: {ratio * 100:.2f}%')


## Create loss function, scheduler and optimizer

In [None]:
from keras.optimizers import Adam

# Adam optimizer whose learning rate is read from the "lr" entry of the config.
optimizer = Adam(learning_rate=config["lr"])

## Train model
Training time. We use the `iterative_train` function from pquant to train. We need to provide some parameters such as training and validation functions, their input parameters, the model and the config file. The function automatically adds pruning layers and replaces activations with a quantized variant, trains the model, and removes the pruning layers after training is done.

In [None]:
from pquant import iterative_train

# The callbacks handed to `iterative_train` are the ones defined earlier in
# this notebook, with these signatures:
#   train_smart_pixels_tf(model, epoch, trainloader, optimizer, *args, **kwargs)
#   validate_smart_pixels_tf(model, epoch, testloader, *args, **kwargs)
# NOTE(review): `device` and `loss_func` are not named by either callback;
# presumably they are forwarded and absorbed by *args/**kwargs — confirm
# against the pquant `iterative_train` implementation.
trained_model = iterative_train(
    model=model,
    config=config,
    train_func=train_smart_pixels_tf,
    valid_func=validate_smart_pixels_tf,
    trainloader=train_loader,
    testloader=val_loader,
    device=None,
    loss_func=custom_loss,
    optimizer=optimizer,
)

  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.38%, remaining_weights: 50.41%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.40%, remaining_weights: 59.62%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.40%, remaining_weights: 45.39%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.41%, remaining_weights: 51.30%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.42%, remaining_weights: 52.68%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.39%, remaining_weights: 63.76%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.39%, remaining_weights: 50.14%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.38%, remaining_weights: 47.99%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.39%, remaining_weights: 50.97%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.41%, remaining_weights: 55.49%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.38%, remaining_weights: 35.74%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.39%, remaining_weights: 49.64%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.40%, remaining_weights: 44.35%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.41%, remaining_weights: 50.69%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.42%, remaining_weights: 55.05%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.40%, remaining_weights: 43.19%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.38%, remaining_weights: 53.78%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.41%, remaining_weights: 46.55%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.42%, remaining_weights: 50.30%


  0%|          | 0/84 [00:00<?, ?it/s]

  0%|          | 0/21 [00:00<?, ?it/s]

NLL+mdmm_loss: 103616.41%, remaining_weights: 32.93%


  0%|          | 0/84 [00:00<?, ?it/s]

KeyboardInterrupt: 