In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.utils.data as data
import timm
import matplotlib.pyplot as plt

# Hyperparameters
BATCH_SIZE = 256
NUM_BITS = 3
INIT_A = 0.55  # Initial power function exponent
FIXED_T = 100.5  # Fixed temperature for soft rounding
LR = 0.01  # Corrected Learning rate
EPOCHS = 10  # Training epochs for optimizing `a`
NUM_ITERATIONS = 100  # Per-layer optimization iterations

# Load CIFAR-10 dataset
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
test_loader = data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Load Pretrained ResNet18 from timm
import resnet18
import mobilenetv2
import densenet

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = resnet18.resnet18(pretrained=False, device=device)
#model = mobilenetv2.mobilenet_v2(pretrained=False, device=device)
#model = densenet.densenet121(pretrained=False, device=device)
#model.to(device)

state_dict = torch.load('/content/resnet18.pt', map_location=torch.device('cpu'))
#state_dict = torch.load('/content/mobilenet_v2.pt', map_location=torch.device('cpu'))
#state_dict = torch.load('/content/densenet121.pt', map_location=torch.device('cpu'))
model.load_state_dict(state_dict, strict=False)
model.eval()
model.to(device)

print("Accuracy BEFORE Quantization:")
def evaluate(model, test_loader):
    model.to(device)
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print(f"Test Accuracy: {100 * correct / total:.2f}%")

#evaluate(model, test_loader)

Files already downloaded and verified
Files already downloaded and verified
Accuracy BEFORE Quantization:


  state_dict = torch.load('/content/resnet18.pt', map_location=torch.device('cpu'))


In [2]:
import torch
import torch.nn as nn

class DAQTanhQuantization(nn.Module):
    def __init__(self, num_levels=2**NUM_BITS, init_beta=1.8184, init_c=1.5, fixed_T=FIXED_T, hard_rounding=False):
        super().__init__()
        self.num_levels = num_levels
        self.beta = nn.Parameter(torch.tensor(init_beta, dtype=torch.float32))  # Trainable beta
        self.c = nn.Parameter(torch.tensor(init_c, dtype=torch.float32))  # Trainable scaling factor
        self.fixed_T = fixed_T
        self.hard_rounding = hard_rounding

    def forward(self, w):
        offset = w.mean()  # Centering the weights
        w_shifted = w - offset
        EPSILON = 1e-6

        # Apply piecewise tanh scaling
        tau = 0.5  # Threshold
        w_transformed = torch.where(
            torch.abs(w_shifted) < tau,
            self.beta * w_shifted,  # Linear scaling for small values
            self.c * torch.tanh(self.beta * w_shifted) + EPSILON  # Tanh scaling for large values
        )

        # Normalize weights to [0, 1]
        w_min, w_max = w_transformed.min(), w_transformed.max()
        w_normalized = (w_transformed - w_min) / (w_max - w_min + 1e-6)  # Avoid division by zero

        # Define quantization levels
        q_levels = torch.linspace(0, 1, self.num_levels, device=w.device)

        if self.hard_rounding:
            # HARD ROUNDING: Direct rounding to nearest quantization bin
            w_quantized = torch.round(w_normalized * (self.num_levels - 1)) / (self.num_levels - 1)
        else:
            # SOFT ROUNDING: Distance-aware quantization with softmax
            distances = -torch.abs(w_normalized.unsqueeze(-1) - q_levels)  # Negative for softmax
            soft_weights = torch.softmax(distances * self.fixed_T, dim=-1)  # Softmax with temperature
            w_quantized = (soft_weights * q_levels).sum(dim=-1)

        # De-normalize back to original scale
        w_dequantized = w_quantized * (w_max - w_min) + w_min

        # Clip before inverse transformation to avoid NaN in atanh
        w_dequantized = torch.clamp(w_dequantized, -0.9999, 0.9999)
        w_dequantized = torch.where(
            torch.abs(w_shifted) < tau,
            w_dequantized / self.beta,  # Linear inverse for small values
            torch.atanh(w_dequantized - EPSILON) / self.beta
        ) + offset

        return w_dequantized


In [2]:
class DAQPowerQuantization(nn.Module):
    def __init__(self, num_levels=2**NUM_BITS, init_a=INIT_A, fixed_T=FIXED_T, hard_rounding=False):
        super().__init__()
        self.num_levels = num_levels
        self.a = nn.Parameter(torch.tensor(init_a, dtype=torch.float32))  # Trainable power exponent
        self.fixed_T = fixed_T  # Fixed temperature for soft rounding
        self.hard_rounding = hard_rounding  # Toggle for hard rounding

    def forward(self, w):
        offset = w.mean()  # Centering the weights
        w_shifted = w - offset
        EPSILON = 1e-6
        w_transformed = torch.sign(w_shifted) * (torch.abs(w_shifted) ** self.a) + EPSILON

        # Normalize weights to [0, 1]
        w_min, w_max = w_transformed.min(), w_transformed.max()
        w_normalized = (w_transformed - w_min) / (w_max - w_min + 1e-6)  # Avoid division by zero

        # Define quantization levels
        q_levels = torch.linspace(0, 1, self.num_levels, device=w.device)

        if self.hard_rounding:
            # HARD ROUNDING: Direct rounding to nearest quantization bin
            w_quantized = torch.round(w_normalized * (self.num_levels - 1)) / (self.num_levels - 1)
        else:
            # SOFT ROUNDING: Distance-aware quantization with softmax
            distances = -torch.abs(w_normalized.unsqueeze(-1) - q_levels)  # Negative for softmax
            soft_weights = torch.softmax(distances * self.fixed_T, dim=-1)  # Softmax with temperature
            w_quantized = (soft_weights * q_levels).sum(dim=-1)

        # De-normalize back to original scale
        w_dequantized = w_quantized * (w_max - w_min) + w_min
        w_dequantized = (torch.abs(w_dequantized) ** (1/self.a)) * torch.sign(w_dequantized) + offset  # Descale

        return w_dequantized


In [3]:
def optimize_per_layer(model, test_loader, num_iterations=NUM_ITERATIONS, lr=LR):
    model.to(device)
    model.eval()
    updated_state_dict = model.state_dict()
    quantization_layers = {}

    print("Starting per-layer quantization optimization...")

    # Get a batch of test images and labels before optimization
    data_iterator = iter(test_loader)
    images, labels = next(data_iterator)
    images, labels = images.to(device), labels.to(device)

    # Compute classification loss before optimization
    with torch.no_grad():
        outputs = model(images)
        pre_optimization_loss = classification_loss_fn(outputs, labels).item()

    print(f"Initial Classification Loss Before Optimization: {pre_optimization_loss:.6f}")

    for name, param in model.named_parameters():
        if "conv" in name and "weight" in name:
            print(f"Optimizing {name}...")
            layer_name = name.replace(".weight", "")

            if layer_name not in temp_activations:
                continue  # Skip if activation was not stored

            # Initialize differentiable quantization module for this layer
            #quantization_layers[layer_name] = DAQPowerQuantization().to(device)
            quantization_layers[layer_name] = DAQTanhQuantization().to(device)
            quant_layer = quantization_layers[layer_name]

            optimizer = optim.Adam(quant_layer.parameters(), lr=lr)
            loss_fn = nn.MSELoss()

            # Retrieve stored activations
            x = temp_activations[layer_name]
            original_weight = param.clone().detach()

            # Track `a` values
            a_tracking = []
            prev_class_loss = 100
            # Per-layer optimization loop
            for iter_idx in range(num_iterations):
                optimizer.zero_grad()
                quantized_weight = quant_layer(original_weight)  # Apply DAQ quantization

                # # Compute output difference
                # quantized_output = nn.functional.conv2d(x, quantized_weight, stride=param.shape[2], padding=param.shape[3])
                # original_output = nn.functional.conv2d(x, original_weight, stride=param.shape[2], padding=param.shape[3])
                # Detect Depthwise Convolution
                groups = param.shape[0] if param.shape[1] == 1 else 1  # Depthwise conv fix

                # Compute output difference
                quantized_output = nn.functional.conv2d(
                    x, quantized_weight, stride=param.shape[2], padding=param.shape[3], groups=groups
                )
                original_output = nn.functional.conv2d(
                    x, original_weight, stride=param.shape[2], padding=param.shape[3], groups=groups
                )

                # Reconstruction loss
                reconstruction_loss = loss_fn(quantized_output, original_output)

                if reconstruction_loss < 2e-7:
                    break

                # Get a batch of test images and labels
                for images, labels in test_loader:
                    images, labels = images.to(device), labels.to(device)
                    break  # Use only one batch

                # Temporarily replace model weight with quantized weight (without modifying in-place)
                with torch.no_grad():
                    temp_weight = quantized_weight.detach().clone()
                    param_backup = param.data.clone()
                    param.data.copy_(temp_weight)

                # Compute classification loss with quantized model
                classification_loss = classification_loss_fn(model(images), labels)
                if prev_class_loss < classification_loss:
                    break
                prev_class_loss = classification_loss
                # Restore original weight after loss computation
                with torch.no_grad():
                    param.data.copy_(param_backup)

                # Compute total loss
                final_loss = 0.1* reconstruction_loss + 0.9 * classification_loss
                final_loss.backward()
                optimizer.step()

                a_tracking.append(quant_layer.beta.item())

                if iter_idx % 10== 0:
                    #print(classification_loss.item(), reconstruction_loss.item())
                    print(f"Iter {iter_idx}: classification_loss = {classification_loss.item():.8f}, reconstruction_loss = {reconstruction_loss.item():.4f}")
                    print(f"Iter {iter_idx}: Loss = {final_loss.item():.8f}, b = {quant_layer.beta.item():.4f}")

            updated_state_dict[name] = quant_layer(original_weight).detach()

    model.load_state_dict(updated_state_dict)
    print("Per-layer optimization complete.")


In [4]:
# Apply Hard Rounding for Final Evaluation
def apply_hard_rounding(model):
    print("Applying Hard Rounding for Final Evaluation...")
    for name, param in model.named_parameters():
        if "conv" in name and "weight" in name:
            layer = DAQPowerQuantization(hard_rounding=True).to(device)
            param.data = layer(param).detach()  # Apply hard rounding


In [4]:
# Define Cross-Entropy Loss for Classification
classification_loss_fn = nn.CrossEntropyLoss()

# ====== Store Layer Activations for Per-Layer Optimization ======
temp_activations = {}

def activation_hook(layer_name):
    def hook(module, input, output):
        temp_activations[layer_name] = input[0].detach()
    return hook

# Register hooks for convolutional layers
for name, layer in model.named_modules():
    if isinstance(layer, nn.Conv2d):
        layer.register_forward_hook(activation_hook(name))
# First, optimize using soft rounding
optimize_per_layer(model, test_loader)

# Then, apply hard rounding before final evaluation
#apply_hard_rounding(model)

# Evaluate Model After Hard Rounding
evaluate(model, test_loader)


Starting per-layer quantization optimization...
Initial Classification Loss Before Optimization: 0.526322
Optimizing conv1.weight...
Iter 0: classification_loss = 0.79476923, reconstruction_loss = 0.0046
Iter 0: Loss = 0.71575695, b = 1.8133
Optimizing layer1.0.conv1.weight...
Iter 0: classification_loss = 1.93208408, reconstruction_loss = 0.0001
Iter 0: Loss = 1.73888171, b = 1.8168
Optimizing layer1.0.conv2.weight...
Iter 0: classification_loss = 2.12270570, reconstruction_loss = 0.0000
Iter 0: Loss = 1.91043568, b = 1.8185
Optimizing layer1.1.conv1.weight...
Iter 0: classification_loss = 2.06956673, reconstruction_loss = 0.0000
Iter 0: Loss = 1.86261165, b = 1.8191
Optimizing layer1.1.conv2.weight...
Iter 0: classification_loss = 2.11616588, reconstruction_loss = 0.0000
Iter 0: Loss = 1.90454936, b = 1.8184
Optimizing layer2.0.conv1.weight...
Iter 0: classification_loss = 2.13264680, reconstruction_loss = 0.0000
Iter 0: Loss = 1.91938293, b = 1.8183
Optimizing layer2.0.conv2.weight.

In [6]:
#model = resnet18.resnet18(pretrained=False, device=device)
#model = mobilenetv2.mobilenet_v2(pretrained=False, device=device)
model = densenet.densenet121(pretrained=False, device=device)
model.to(device)

#state_dict = torch.load('/content/resnet18.pt', map_location=torch.device('cpu'))
#state_dict = torch.load('/content/mobilenet_v2.pt', map_location=torch.device('cpu'))
state_dict = torch.load('/content/densenet121.pt', map_location=torch.device('cpu'))
model.load_state_dict(state_dict, strict=False)
model.eval()
model.to(device)

def uniform_quantization(tensor, num_bits=NUM_BITS):
    min_val, max_val = tensor.min(), tensor.max()
    scale = (max_val - min_val) / (2 ** num_bits - 1)
    quantized_tensor = torch.round((tensor - min_val) / scale) * scale + min_val
    return quantized_tensor

# Apply Uniform Quantization to Conv Layers Only
for name, param in model.named_parameters():
    if "conv" in name and "weight" in name:
        param.data = uniform_quantization(param.data)

evaluate(model, test_loader)


  state_dict = torch.load('/content/densenet121.pt', map_location=torch.device('cpu'))


Test Accuracy: 10.07%


In [2]:
import torch
import torch.nn as nn

class DAQLogQuantization(nn.Module):
    def __init__(self, num_levels=2**NUM_BITS, init_base=1.44, fixed_T=FIXED_T, hard_rounding=False):
        super().__init__()
        self.num_levels = num_levels
        self.base = nn.Parameter(torch.tensor(init_base, dtype=torch.float32))  # Trainable log base
        self.fixed_T = fixed_T  # Fixed temperature for soft rounding
        self.hard_rounding = hard_rounding  # Toggle for hard rounding

    def forward(self, w):
        offset = w.mean()  # Centering the weights
        w_shifted = w - offset
        EPSILON = 1e-6
        w_transformed = torch.sign(w_shifted) * torch.log1p(torch.abs(w_shifted)) / torch.log(self.base) + EPSILON

        # Normalize weights to [0, 1]
        w_min, w_max = w_transformed.min(), w_transformed.max()
        w_normalized = (w_transformed - w_min) / (w_max - w_min + 1e-6)  # Avoid division by zero

        # Define quantization levels
        q_levels = torch.linspace(0, 1, self.num_levels, device=w.device)

        if self.hard_rounding:
            # HARD ROUNDING: Direct rounding to nearest quantization bin
            w_quantized = torch.round(w_normalized * (self.num_levels - 1)) / (self.num_levels - 1)
        else:
            # SOFT ROUNDING: Distance-aware quantization with softmax
            distances = -torch.abs(w_normalized.unsqueeze(-1) - q_levels)  # Negative for softmax
            soft_weights = torch.softmax(distances * self.fixed_T, dim=-1)  # Softmax with temperature
            w_quantized = (soft_weights * q_levels).sum(dim=-1)

        # De-normalize back to original scale
        w_dequantized = w_quantized * (w_max - w_min) + w_min
        w_dequantized = torch.exp(w_dequantized * torch.log(self.base)) - 1  # Descale using inverse log
        w_dequantized = w_dequantized * torch.sign(w_shifted) + offset

        return w_dequantized
