In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
import re
import os
from lq_layers import LQLinear, LQConv, LQActiv
from resnet import ResNet

In [17]:
def parse_quantization_bits_from_filename(filename):
    """Extract quantization settings encoded in a checkpoint filename.

    Looks for the tags ``_wq<int>`` (weight bits), ``_aq<int>`` (activation
    bits) and ``_lambda<float>`` (soft-thresholding value). It prints everything
    it found, including the lambda string, which is not returned.

    Args:
        filename: path or bare name of the checkpoint file.

    Returns:
        ``(w_nbits, a_nbits, base_filename)``. Each bit-width is an ``int``, or
        ``None`` when its tag is absent. ``base_filename`` is the file name
        without directory and without its final extension.
    """
    def _first_group(pattern):
        # Text of the first capture group, or None when the tag is missing.
        found = re.search(pattern, filename)
        return found.group(1) if found else None

    wq_text = _first_group(r'_wq(\d+)')
    aq_text = _first_group(r'_aq(\d+)')
    soft_thresh = _first_group(r'_lambda(\d+\.\d+)')

    w_bits = None if wq_text is None else int(wq_text)
    a_bits = None if aq_text is None else int(aq_text)
    stem, _extension = os.path.splitext(os.path.basename(filename))

    print('weight bits:', w_bits, 'activation bits:', a_bits,
          'soft_thresholding:', soft_thresh, 'base filename:', stem)

    return w_bits, a_bits, stem

In [3]:
def load_model(model_path, num_layers, num_classes, w_nbits, a_nbits, device):
    """Build a quantized ResNet, load a saved state dict into it, and set eval mode.

    Args:
        model_path: path to a file holding a ``state_dict`` for ``ResNet``.
        num_layers: accepted for interface compatibility but not used; the
            network is built from the bit-widths only.
        num_classes: accepted for interface compatibility but not used.
        w_nbits: weight bit-width forwarded to ``ResNet``.
        a_nbits: activation bit-width forwarded to ``ResNet``.
        device: target device; the module is moved there and the checkpoint
            tensors are mapped there.

    Returns:
        The loaded network, in eval mode.
    """
    net = ResNet(w_nbits=w_nbits, a_nbits=a_nbits)
    net.to(device)

    # NOTE(review): torch.load may unpickle arbitrary objects (depending on the
    # installed torch version's weights_only default) — only load trusted files.
    state = torch.load(model_path, map_location=device)
    net.load_state_dict(state)

    net.eval()
    return net

In [4]:
def collect_activations_and_quantized_weights(model, dataloader, device, max_batches=None):
    """Run batches through ``model`` and capture LQ activations and quantized weights.

    Forward hooks are attached to every ``LQConv`` module, which records its
    quantized weight, and every ``LQActiv`` module, which records its output.
    Activations from every processed batch are kept and concatenated, so the
    returned distributions cover all processed inputs, not just the last batch.
    The hooks are removed before returning, even if the forward pass raises.
    Calling this repeatedly on the same model therefore does not stack
    duplicate hooks.

    Args:
        model: network containing ``LQConv`` / ``LQActiv`` modules. It should
            already be on ``device`` and in eval mode.
        dataloader: iterable of ``(inputs, targets)`` batches; targets are ignored.
        device: device the inputs are moved to before each forward pass.
        max_batches: optional cap on the number of batches processed. ``None``
            (default) processes the whole loader. Memory use grows with the
            number of batches, since every batch's activations are kept.

    Returns:
        ``(activations, quantized_weights)``: two dicts keyed by module name (as
        produced by ``model.named_modules()``). Activation values are the
        per-batch numpy outputs concatenated along axis 0. Weight values are the
        numpy quantized weights computed during the most recent forward pass.
    """
    activation_batches = {}
    quantized_weights = {}

    def get_quantized_weight(name, layer):
        def hook(module, inputs, output):
            # NOTE(review): relies on LQConv exposing .lq, .conv and .basis as
            # used here; index [0] selects the quantized weight — confirm
            # against lq_layers.
            q_weight = layer.lq.apply(layer.conv.weight, layer.basis, False)[0]
            quantized_weights[name] = q_weight.detach().cpu().numpy()
        return hook

    def get_activation(name):
        def hook(module, inputs, output):
            activation_batches.setdefault(name, []).append(output.detach().cpu().numpy())
        return hook

    handles = []
    for name, layer in model.named_modules():
        if isinstance(layer, LQConv):
            handles.append(layer.register_forward_hook(get_quantized_weight(name, layer)))
        if isinstance(layer, LQActiv):
            handles.append(layer.register_forward_hook(get_activation(name)))

    try:
        with torch.no_grad():
            for batch_idx, (inputs, _) in enumerate(dataloader):
                if max_batches is not None and batch_idx >= max_batches:
                    break
                model(inputs.to(device))
    finally:
        # Detach hooks so they neither accumulate across calls nor keep
        # writing into these dicts on later forward passes.
        for handle in handles:
            handle.remove()

    activations = {
        name: np.concatenate(batches, axis=0)
        for name, batches in activation_batches.items()
    }
    return activations, quantized_weights

In [5]:
def plot_distributions_side_by_side(activations, quantized_weights, base_filename):
    """Save one PNG per (quantized weight, activation) pair, drawn side by side.

    Keys of both dicts are sorted and paired positionally with ``zip``. The i-th
    weight key in sorted order is shown next to the i-th activation key in
    sorted order. If the two dicts have different numbers of keys, a warning is
    printed and the unmatched extras are not plotted.

    Each pair is written to
    ``./output/<base_filename>/<base_filename>_<weight_key>_and_<activation_key>.png``.

    Args:
        activations: mapping of layer name to an array of activation values.
        quantized_weights: mapping of layer name to an array of quantized weights.
        base_filename: used for the output sub-directory and as the file prefix.
    """
    sorted_weight_keys = sorted(quantized_weights.keys())
    sorted_activation_keys = sorted(activations.keys())

    if len(sorted_weight_keys) != len(sorted_activation_keys):
        print(f"Warning: {len(sorted_weight_keys)} weight layers vs "
              f"{len(sorted_activation_keys)} activation layers; unmatched keys are skipped.")

    output_directory = f"./output/{base_filename}"
    os.makedirs(output_directory, exist_ok=True)

    for weight_key, activation_key in zip(sorted_weight_keys, sorted_activation_keys):
        # A fresh 1x2 figure per pair so each saved PNG contains only this pair.
        fig, (ax_w, ax_a) = plt.subplots(1, 2, figsize=(12, 6))
        try:
            ax_w.hist(quantized_weights[weight_key].ravel(), bins=50, alpha=0.75)
            ax_w.set_title(f'Quantized Weight Distribution: {weight_key}')
            ax_w.set_xlabel('Weight Values')
            ax_w.set_ylabel('Frequency')

            ax_a.hist(activations[activation_key].ravel(), bins=50, alpha=0.75)
            ax_a.set_title(f'Activation Distribution: {activation_key}')
            ax_a.set_xlabel('Activation Values')
            ax_a.set_ylabel('Frequency')

            plot_filename = f"{output_directory}/{base_filename}_{weight_key}_and_{activation_key}.png"
            fig.savefig(plot_filename)
        finally:
            # Release each figure immediately; many open figures exhaust memory.
            plt.close(fig)
        print(f"Saved plot as {plot_filename}")

In [19]:
# CIFAR-10 test split as [0, 1] tensors, shared by every analyze_model call.
TEST_SET = datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)


def analyze_model(model_filename, device='cuda'):
    """Load one quantized checkpoint and save its weight/activation histograms.

    Bit-widths and the output prefix come from the checkpoint filename. The model
    is run over the full CIFAR-10 test split in fixed order (batch size 64), and
    the resulting plots go under ``./output/<base filename>/``.

    Args:
        model_filename: path to the checkpoint file.
        device: device used for loading and inference (defaults to ``'cuda'``).
    """
    w_bits, a_bits, stem = parse_quantization_bits_from_filename(model_filename)

    # 20 / 10 describe ResNet-20 on CIFAR-10; load_model receives them as
    # num_layers / num_classes.
    model = load_model(model_filename, num_layers=20, num_classes=10,
                       w_nbits=w_bits, a_nbits=a_bits, device=device)
    loader = DataLoader(TEST_SET, batch_size=64, shuffle=False)

    acts, q_weights = collect_activations_and_quantized_weights(model, loader, device)
    plot_distributions_side_by_side(acts, q_weights, stem)


Files already downloaded and verified


In [20]:
# Directory scanned for checkpoints; only files whose names end in '_.pt' are analyzed.
# NOTE(review): absolute, user-specific path — adjust for other machines.
MODEL_DIR = '/home/users/xt37/LQ-Nets'

# os.listdir returns entries in arbitrary order; sorting makes the processing
# order (and the printed log) deterministic across runs.
for filename in sorted(os.listdir(MODEL_DIR)):
    if filename.endswith('_.pt'):
        analyze_model(os.path.join(MODEL_DIR, filename))
        print(f'Analyzed {filename}')

weight bits: 3 activation bits: None soft_thresholding: 0.001 base filename: resnet20_cifar_wq3_lambda0.001epoch_179_
Saved plot as ./output/resnet20_cifar_wq3_lambda0.001epoch_179_/resnet20_cifar_wq3_lambda0.001epoch_179__layer1.0.conv1_and_layer1.0.activ1.1.png
Saved plot as ./output/resnet20_cifar_wq3_lambda0.001epoch_179_/resnet20_cifar_wq3_lambda0.001epoch_179__layer1.0.conv2_and_layer1.0.activ2.1.png
Saved plot as ./output/resnet20_cifar_wq3_lambda0.001epoch_179_/resnet20_cifar_wq3_lambda0.001epoch_179__layer1.1.conv1_and_layer1.1.activ1.1.png
Saved plot as ./output/resnet20_cifar_wq3_lambda0.001epoch_179_/resnet20_cifar_wq3_lambda0.001epoch_179__layer1.1.conv2_and_layer1.1.activ2.1.png
Saved plot as ./output/resnet20_cifar_wq3_lambda0.001epoch_179_/resnet20_cifar_wq3_lambda0.001epoch_179__layer1.2.conv1_and_layer1.2.activ1.1.png
Saved plot as ./output/resnet20_cifar_wq3_lambda0.001epoch_179_/resnet20_cifar_wq3_lambda0.001epoch_179__layer1.2.conv2_and_layer1.2.activ2.1.png
Saved 