## In this tutorial we create a CNN and dataloaders, and train / prune the model.

In [None]:
import os
os.environ["KERAS_BACKEND"] = "torch" # Needs to be set, some pruning layers as well as the quantizers are Keras
import keras
keras.config.set_backend("torch")
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
keras.backend.set_image_data_format("channels_first")

In [None]:
try:
    os.chdir("/home/das214/PQuant/mdmm_dev/src")
except:
    pass

for f in os.listdir(os.getcwd()):
    print(f)

In [None]:
model = torchvision.models.resnet18()
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
model = model.to(device)

model

## Add pruning and quantization
Begin prunning with MDMM pruning with Unstructured Sparsity metric function

In [None]:
from pquant import get_default_config
from IPython.display import JSON

pruning_method = "mdmm"
config = get_default_config(pruning_method)
JSON(config)

In [None]:
# Replace layers with compressed layers
from pquant import add_compression_layers
input_shape = (256,3,32,32)
model = add_compression_layers(model, config, input_shape)
model

In [None]:
import torchvision.transforms as transforms
from pquant import get_layer_keep_ratio, get_model_losses
from quantizers.fixed_point.fixed_point_ops import get_fixed_quantizer
from tqdm import tqdm


def get_cifar10_data(batch_size):
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), 
                                          transforms.ToTensor(), normalize])
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])  
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=train_transform)
    valset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                       download=True, transform=test_transform)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                          shuffle=True, num_workers=4, pin_memory=True)

    val_loader = torch.utils.data.DataLoader(valset, batch_size=batch_size,
                                         shuffle=False, num_workers=4, pin_memory=True)

    return train_loader, val_loader

# Set up input quantizer
quantizer = get_fixed_quantizer(overflow_mode="SAT")

def train_resnet(model, trainloader, device, loss_func,
                 epoch, optimizer, scheduler, *args, **kwargs):
    """
    One epoch of training with a live ETA/throughput bar.
    """
    model.train()

    with tqdm(trainloader,
              desc=f"Train ‖ Epoch {epoch}",
              total=len(trainloader),
              unit="batch",
              dynamic_ncols=True) as pbar:

        for inputs, labels in pbar:
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            inputs = quantizer(inputs, k=torch.tensor(1.), i=torch.tensor(0.), f=torch.tensor(7.))

            optimizer.zero_grad(set_to_none=True)              # cleaner gradient reset
            outputs = model(inputs)
            loss = loss_func(outputs, labels)
            losses = get_model_losses(model, torch.tensor(0.).to(device))
            loss += losses
            loss.backward()
            optimizer.step()

            if scheduler is not None:
                scheduler.step()

            pbar.set_postfix(loss=f"{loss.item():.4f} ")
        
    # ----- Diagnostics on Last mini-batch -----
    print(f"Loss={loss_func(outputs, labels).item():.4f} | Reg={loss.item() - loss_func(outputs, labels).item():.4f}")

def validate_resnet(model, testloader, device, loss_func, epoch, *args, **kwargs):
    """
    Validation with progress bar and accuracy summary.
    """
    model.eval()
    correct = total = 0


    with torch.no_grad():
        with tqdm(testloader,
                  desc=f"Val   ‖ Epoch {epoch}",
                  total=len(testloader),
                  unit="batch",
                  dynamic_ncols=True) as pbar:

            for inputs, labels in pbar:
                inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
                inputs = quantizer(inputs, k=torch.tensor(1.), i=torch.tensor(0.), f=torch.tensor(7.))
                outputs = model(inputs)
                _, predicted = outputs.max(1)
                total += labels.size(0)
                correct += predicted.eq(labels).sum().item()

                running_acc = 100. * correct / total
                pbar.set_postfix(acc=f"{running_acc:.2f}%")

    ratio = get_layer_keep_ratio(model)
    print(f"Accuracy: {correct/total*100:.2f}% | Remaining weights: {ratio*100:.2f}% \n")



BATCH_SIZE = 256
train_loader, val_loader = get_cifar10_data(BATCH_SIZE)

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.0001, momentum=0.9)
scheduler = CosineAnnealingLR(optimizer, 200)
loss_function = nn.CrossEntropyLoss()

In [None]:
# from pquant import iterative_train
# """
# Inputs to train_resnet we defined previously are:
#           model, trainloader, device, loss_func, epoch, optimizer, scheduler, **kwargs
# """

# trained_model = iterative_train(model = model, 
#                                 config = config, 
#                                 train_func = train_resnet, 
#                                 valid_func = validate_resnet, 
#                                 trainloader = train_loader, 
#                                 testloader = val_loader, 
#                                 device = device, 
#                                 loss_func = loss_function,
#                                 optimizer = optimizer, 
#                                 scheduler = scheduler
#                                 )

In [None]:
SAVE_PATH = 'resnet_mdmm_unstr_pruned.pth'
# torch.save(model_copy.state_dict(), SAVE_PATH)

model = torchvision.models.resnet18()
model.load_state_dict(torch.load(SAVE_PATH))

In [None]:
# Plot remaining weights
import numpy as np
import matplotlib.pyplot as plt

names = []
remaining = []
total_w = []
nonzeros = []
for n, m in model.named_modules():
    if isinstance(m, (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Linear)):
        names.append(n)
        nonzero = np.count_nonzero(m.weight.detach().cpu())
        remaining_pct = nonzero / m.weight.numel()
        remaining.append(remaining_pct)
        total_w.append(m.weight.numel())
        nonzeros.append(nonzero)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].bar(range(len(names)), remaining)
ax[0].set_xticks(range(len(names)))
ax[0].set_xticklabels(names)
ax[0].tick_params(axis='x', labelrotation=270)
new_ytick = []
for i in ax[0].get_yticklabels():
    ytick = f"{float(i.get_text()) * 100:.2f}%"
    new_ytick.append(ytick)
ax[0].set_yticklabels(new_ytick)
ax[0].title.set_text("Remaining weights per layer")

ax[1].bar(range(len(nonzeros)), total_w, color="lightcoral", label="total weights")
ax[1].bar(range(len(nonzeros)), nonzeros, color="steelblue", label="nonzero weights")
ax[1].set_xticks(range(len(names)))
ax[1].set_xticklabels(names)
ax[1].tick_params(axis='x', labelrotation=270)
ax[1].title.set_text("Weights per layer")
ax[1].legend()
ax[1].set_yscale("log")

plt.tight_layout()
plt.show()

## Add PACA prunning
#### After pruning we will have multiple patterns, so we force all of them to have a lower num,ber of dominant patterns

In [None]:
import yaml 

with open("pquant/configs/config_mdmm_paca.yaml", 'r') as f:
    config = yaml.safe_load(f)
JSON(config)

In [None]:
from pquant import add_compression_layers
input_shape = (256,3,32,32)
model_paca = model.to(device)
model_paca = add_compression_layers(model_paca, config, input_shape)
model_paca

In [None]:
def train_resnet(model, trainloader, device, loss_func,
                 epoch, optimizer, scheduler, *args, **kwargs):
    """
    One epoch of training with a live ETA/throughput bar.
    """
    print()
    model.train()

    with tqdm(trainloader,
              desc=f"Train ‖\tEpoch {epoch}",
              total=len(trainloader),
              unit="batch",
              dynamic_ncols=True) as pbar:

        for inputs, labels in pbar:
            inputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)
            inputs = quantizer(inputs, k=torch.tensor(1.), i=torch.tensor(0.), f=torch.tensor(7.))

            optimizer.zero_grad(set_to_none=True)              # cleaner gradient reset
            outputs = model(inputs)
            loss = loss_func(outputs, labels)
            losses = get_model_losses(model, torch.tensor(0.).to(device))
            loss += losses
            loss.backward()
            optimizer.step()

            if scheduler is not None:
                scheduler.step()

            pbar.set_postfix(loss=f"{loss.item():.4f} ")
        
    # ----- Diagnostics on Last mini-batch -----
    print(f"Loss={loss_func(outputs, labels).item():.4f} | Reg={loss.item() - loss_func(outputs, labels).item():.4f}")


In [None]:
import torch
from keras import ops
from tqdm import tqdm
from pquant.pruning_methods.utils import patterns

def apply_projection_mask(model, src, epsilon, distance_metric, alpha=16, beta=0.85):
    """
    Applies a projection mask to the Conv2d layers of a model, ensuring device consistency.
    This version includes extra checks to guarantee the weight remains on its original device.
    """
    module_iterator = tqdm(model.named_modules(), desc="Applying Projection Mask", leave=False)
    
    with torch.no_grad():
        for name, module in module_iterator:
            if isinstance(module, torch.nn.Conv2d):
                weight = module.weight
                original_device = weight.device
                
                module_iterator.set_postfix_str(f"Processing {name} on {original_device}")
                _, all_patterns, _ = patterns._get_kernels_and_patterns(weight)
                unique_patterns, counts = patterns._get_unique_patterns_with_counts(all_patterns)
                
                if ops.shape(unique_patterns)[0] == 0:
                    continue
                    
                dominant_patterns = patterns._select_dominant_patterns(all_patterns, 
                                                              unique_patterns, counts,
                                                              alpha=alpha, beta=beta, dtype=weight.dtype)

                projection_mask_keras = patterns._get_projection_mask(weight, dominant_patterns, 
                                                                      src=src, epsilon=epsilon, 
                                                                      distance_metric=distance_metric)
                

                projection_mask_torch = ops.convert_to_tensor(projection_mask_keras, dtype=weight.dtype)
                projection_mask_final = projection_mask_torch.to(original_device)
                module.weight.data.mul_(projection_mask_final)

In [None]:
import torchvision.transforms as transforms
from tqdm import tqdm

def evaluate_model(model, testloader, device, desc):
    """
    Evaluates the model's accuracy on the provided data loader.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        with tqdm(testloader,
                  desc=desc,
                  total=len(testloader),
                  unit="batch",
                  dynamic_ncols=True) as pbar:

            for inputs, labels in pbar:
                inputs, labels = inputs.to(device), labels.to(device)
                inputs = quantizer(inputs, k=torch.tensor(1.), i=torch.tensor(0.), f=torch.tensor(7.))
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
                
                postfix_stats = {
                    'acc': f"{(100 * correct / total):.2f}%"
                }
                pbar.set_postfix(**postfix_stats)
    
    accuracy = 100. * correct / total
    # print(f"Final Accuracy: {final_accuracy:.2f}%")
    return accuracy

In [None]:
import random
fingerprint = '%08x' % random.randrange(16**8)
SAVE_DIR = f'trained_models/resnet/mdmm/{fingerprint}/unstr_paca/'
print(fingerprint)
print(SAVE_DIR)

In [None]:
from pquant import remove_pruning_from_model

os.makedirs(SAVE_DIR, exist_ok=True)

def validate_resnet(model, testloader, device, loss_func, epoch, **kwargs):
    model.eval()
    
    accuracy_bp = evaluate_model(model, testloader, device, desc=f" - Val (BP)\t‖\tEpoch {epoch}")
    print(f"Accuracy (Before Projection): {accuracy_bp:.2f}%")
    
    ratio_bp = get_layer_keep_ratio(model)
    print(f'Remaining Weights: {ratio_bp * 100:.2f}%')
    
    print("Saving Before-Projection model...")
    model_to_save_bp = torchvision.models.resnet18()
    model_to_save_bp.load_state_dict(model.state_dict(), strict=False)
    
    clean_model_bp = remove_pruning_from_model(model_to_save_bp, config)
    
    model_fname_bp = f'BP_e{epoch}_a{accuracy_bp:.2f}_r{ratio_bp:.2f}.pth'
    SAVE_PATH_BP = os.path.join(SAVE_DIR, model_fname_bp)
    torch.save(clean_model_bp.state_dict(), SAVE_PATH_BP)
    print(f"Saved to {SAVE_PATH_BP}")
    
    model_to_save_ap = torchvision.models.resnet18()
    model_to_save_ap.load_state_dict(model.state_dict(), strict=False)
    model_to_save_ap.to(device)
    
    apply_projection_mask(model_to_save_ap, 
                          src=config['pruning_parameters']['src'],
                          epsilon=config['pruning_parameters']['epsilon'],
                          distance_metric=config['pruning_parameters']['distance_metric'],
                          alpha=config['pruning_parameters']['num_patterns_to_keep'], 
                          beta=config['pruning_parameters']['beta'])
        
    accuracy_ap = evaluate_model(model_to_save_ap, testloader, device, desc=f" - Val (AP)\t‖\tEpoch {epoch}")
    print(f"Accuracy (After Projection): {accuracy_ap:.2f}%\n")
    
    ratio_ap = get_layer_keep_ratio(model_to_save_ap)
    print(f'Remaining Weights: {ratio_ap * 100:.2f}%')
    
    print("Saving After-Projection model...")
    clean_model_ap = remove_pruning_from_model(model_to_save_ap, config)
    
    model_fname_ap = f'AP_e{epoch}_a{accuracy_ap:.2f}_r{ratio_ap:.2f}.pth'
    SAVE_PATH_AP = os.path.join(SAVE_DIR, model_fname_ap)
    torch.save(clean_model_ap.state_dict(), SAVE_PATH_AP)
    print(f"Saved to {SAVE_PATH_AP}")
    
    
BATCH_SIZE = 256
train_loader, val_loader = get_cifar10_data(BATCH_SIZE)

## Train model


In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, weight_decay=0.0001, momentum=0.9)
scheduler = CosineAnnealingLR(optimizer, 200)
loss_function = nn.CrossEntropyLoss()

In [19]:
from pquant import iterative_train
"""
Inputs to train_resnet we defined previously are:
          model, trainloader, device, loss_func, epoch, optimizer, scheduler, **kwargs
"""

trained_model_paca = iterative_train(model = model_paca, 
                                config = config, 
                                train_func = train_resnet, 
                                valid_func = validate_resnet, 
                                trainloader = train_loader, 
                                testloader = val_loader, 
                                device = device, 
                                loss_func = loss_function,
                                optimizer = optimizer, 
                                scheduler = scheduler
                                )

Train ‖	Epoch 0: 100%|██████████| 196/196 [00:10<00:00, 19.09batch/s, loss=2.7553]


Loss=0.5762 | Reg=2.1791


 - Val (BP)	‖	Epoch 0: 100%|██████████| 196/196 [00:06<00:00, 32.62batch/s, acc=86.97%]


Accuracy (Before Projection): 86.97%
Remaining Weights: 45.78%
Saving Before-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/BP_e0_a86.97_r0.46.pth


 - Val (AP)	‖	Epoch 0: 100%|██████████| 196/196 [00:02<00:00, 78.31batch/s, acc=86.97%]       


Accuracy (After Projection): 86.97%

Remaining Weights: 44.50%
Saving After-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/AP_e0_a86.97_r0.44.pth



Train ‖	Epoch 1: 100%|██████████| 196/196 [00:10<00:00, 19.29batch/s, loss=3.9187]


Loss=0.5352 | Reg=3.3835


 - Val (BP)	‖	Epoch 1: 100%|██████████| 196/196 [00:06<00:00, 31.69batch/s, acc=77.27%]


Accuracy (Before Projection): 77.27%
Remaining Weights: 45.78%
Saving Before-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/BP_e1_a77.27_r0.46.pth


 - Val (AP)	‖	Epoch 1: 100%|██████████| 196/196 [00:02<00:00, 78.12batch/s, acc=77.27%]       


Accuracy (After Projection): 77.27%

Remaining Weights: 44.50%
Saving After-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/AP_e1_a77.27_r0.44.pth



Train ‖	Epoch 2: 100%|██████████| 196/196 [00:10<00:00, 18.90batch/s, loss=4.9096]


Loss=0.3528 | Reg=4.5568


 - Val (BP)	‖	Epoch 2: 100%|██████████| 196/196 [00:06<00:00, 32.40batch/s, acc=87.95%]


Accuracy (Before Projection): 87.95%
Remaining Weights: 45.78%
Saving Before-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/BP_e2_a87.95_r0.46.pth


 - Val (AP)	‖	Epoch 2: 100%|██████████| 196/196 [00:02<00:00, 77.83batch/s, acc=87.95%]       


Accuracy (After Projection): 87.95%

Remaining Weights: 44.50%
Saving After-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/AP_e2_a87.95_r0.44.pth



Train ‖	Epoch 3: 100%|██████████| 196/196 [00:10<00:00, 19.46batch/s, loss=6.4855]


Loss=0.7320 | Reg=5.7535


 - Val (BP)	‖	Epoch 3: 100%|██████████| 196/196 [00:06<00:00, 32.49batch/s, acc=81.43%]


Accuracy (Before Projection): 81.43%
Remaining Weights: 45.78%
Saving Before-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/BP_e3_a81.43_r0.46.pth


 - Val (AP)	‖	Epoch 3: 100%|██████████| 196/196 [00:02<00:00, 79.57batch/s, acc=81.43%]       


Accuracy (After Projection): 81.43%

Remaining Weights: 44.49%
Saving After-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/AP_e3_a81.43_r0.44.pth



Train ‖	Epoch 4: 100%|██████████| 196/196 [00:10<00:00, 19.21batch/s, loss=7.4137]


Loss=0.4904 | Reg=6.9233


 - Val (BP)	‖	Epoch 4: 100%|██████████| 196/196 [00:06<00:00, 32.57batch/s, acc=88.78%]


Accuracy (Before Projection): 88.78%
Remaining Weights: 45.78%
Saving Before-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/BP_e4_a88.78_r0.46.pth


 - Val (AP)	‖	Epoch 4: 100%|██████████| 196/196 [00:02<00:00, 78.11batch/s, acc=88.78%]       


Accuracy (After Projection): 88.78%

Remaining Weights: 44.49%
Saving After-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/AP_e4_a88.78_r0.44.pth



Train ‖	Epoch 5: 100%|██████████| 196/196 [00:10<00:00, 18.99batch/s, loss=8.5892]


Loss=0.4771 | Reg=8.1121


 - Val (BP)	‖	Epoch 5: 100%|██████████| 196/196 [00:06<00:00, 31.65batch/s, acc=82.75%]


Accuracy (Before Projection): 82.75%
Remaining Weights: 45.78%
Saving Before-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/BP_e5_a82.75_r0.46.pth


 - Val (AP)	‖	Epoch 5: 100%|██████████| 196/196 [00:02<00:00, 77.16batch/s, acc=82.75%]       


Accuracy (After Projection): 82.75%

Remaining Weights: 44.49%
Saving After-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/AP_e5_a82.75_r0.44.pth



Train ‖	Epoch 6: 100%|██████████| 196/196 [00:10<00:00, 19.23batch/s, loss=9.7694]


Loss=0.4861 | Reg=9.2834


 - Val (BP)	‖	Epoch 6: 100%|██████████| 196/196 [00:06<00:00, 32.16batch/s, acc=89.23%]


Accuracy (Before Projection): 89.23%
Remaining Weights: 45.78%
Saving Before-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/BP_e6_a89.23_r0.46.pth


 - Val (AP)	‖	Epoch 6: 100%|██████████| 196/196 [00:02<00:00, 78.08batch/s, acc=89.23%]       


Accuracy (After Projection): 89.23%

Remaining Weights: 44.49%
Saving After-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/AP_e6_a89.23_r0.44.pth



Train ‖	Epoch 7: 100%|██████████| 196/196 [00:10<00:00, 19.30batch/s, loss=11.0626]


Loss=0.5922 | Reg=10.4703


 - Val (BP)	‖	Epoch 7: 100%|██████████| 196/196 [00:06<00:00, 31.91batch/s, acc=83.99%]


Accuracy (Before Projection): 83.99%
Remaining Weights: 45.78%
Saving Before-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/BP_e7_a83.99_r0.46.pth


 - Val (AP)	‖	Epoch 7: 100%|██████████| 196/196 [00:02<00:00, 77.40batch/s, acc=83.99%]       


Accuracy (After Projection): 83.99%

Remaining Weights: 44.49%
Saving After-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/AP_e7_a83.99_r0.44.pth



Train ‖	Epoch 8: 100%|██████████| 196/196 [00:10<00:00, 19.29batch/s, loss=12.3056]


Loss=0.6655 | Reg=11.6401


 - Val (BP)	‖	Epoch 8: 100%|██████████| 196/196 [00:06<00:00, 32.08batch/s, acc=89.75%]


Accuracy (Before Projection): 89.75%
Remaining Weights: 45.78%
Saving Before-Projection model...
Saved to trained_models/resnet/mdmm/e92c8d83/unstr_paca/BP_e8_a89.75_r0.46.pth


 - Val (AP)	‖	Epoch 8: 100%|██████████| 196/196 [00:02<00:00, 77.28batch/s, acc=89.75%]       


Accuracy (After Projection): 89.75%

Remaining Weights: 44.49%
Saving After-Projection model...


RuntimeError: [enforce fail at inline_container.cc:659] . unexpected pos 64 vs 0

In [None]:
# from pquant import remove_pruning_from_model

# model_paca = remove_pruning_from_model(trained_model_paca, config)
# SAVE_PATH = 'resnet_mdmm_unstr_paca_pruned.pth'
# torch.save(model_paca.state_dict(), SAVE_PATH)

# model_paca = torchvision.models.resnet18()
# model_paca.load_state_dict(torch.load(SAVE_PATH))

In [None]:
1+1