In [1]:
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torch.optim as optim
import torch.quantization
from torch.cuda.amp import autocast, GradScaler
from collections import Counter
import numpy as np
from scipy.stats import entropy

# Define a simple model for demonstration
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 10)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return self.fc3(x)

# Initialize the model, optimizer, and scaler
model = SimpleModel()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scaler = GradScaler()  # For AMP training stability

# Example input and target
inputs = torch.randn(128, 784)  # 128 samples, 784 features each
targets = torch.randint(0, 10, (128,))  # 128 targets for classification

# --- Step 1: Pruning using Taylor Method ---
def taylor_pruning(model, inputs, targets, pruning_amount=0.2):
    model.eval()
    loss_fn = nn.CrossEntropyLoss()
    outputs = model(inputs)
    loss = loss_fn(outputs, targets)
    loss.backward()  # Compute gradients for Taylor-based importance

    # Prune weights based on gradient magnitude (approximation of Taylor method)
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=pruning_amount)
    print("Pruning applied!")

taylor_pruning(model, inputs, targets, pruning_amount=0.2)

# --- Step 2: Quantization-Aware Training (QAT) ---
def apply_quantization(model):
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(model, inplace=True)
    print("Quantization-aware training setup complete.")
    return model

model = apply_quantization(model)

# --- Step 3: Huffman Coding ---
def huffman_encoding(tensor):
    # Flatten weights and compute frequencies
    flat_tensor = tensor.detach().cpu().numpy().flatten()
    frequencies = Counter(flat_tensor)
    prob_dist = np.array(list(frequencies.values())) / len(flat_tensor)
    print(f"Compression Entropy: {entropy(prob_dist, base=2)} bits")

    # Placeholder for Huffman encoding logic (to be implemented for actual compression)
    return frequencies

# Example: Apply Huffman encoding to quantized weights
def apply_huffman_to_model(model):
    for name, param in model.named_parameters():
        if "weight" in name:
            print(f"Applying Huffman coding to {name}...")
            huffman_encoding(param)

# Huffman coding applied post-QAT conversion
apply_huffman_to_model(model)

# --- Step 4: Training with Automatic Mixed Precision (AMP) ---
def train_step_amp(model, optimizer, inputs, targets, scaler):
    model.train()
    optimizer.zero_grad()

    with autocast():  # Use mixed precision
        outputs = model(inputs)
        loss = nn.CrossEntropyLoss()(outputs, targets)

    # Scale gradients, backward pass, and optimizer step
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

    return loss.item()

# Train the model with QAT + AMP + Huffman-ready weights
batch_size = 32
for epoch in range(3):  # Example 3 epochs
    for i in range(0, inputs.size(0), batch_size):
        batch_inputs = inputs[i:i+batch_size]
        batch_targets = targets[i:i+batch_size]
        loss = train_step_amp(model, optimizer, batch_inputs, batch_targets, scaler)
        print(f"Epoch {epoch}, Batch Loss: {loss}")

# Convert QAT model to fully quantized
model = torch.quantization.convert(model.eval(), inplace=True)
print("Model quantized and ready for deployment.")


  scaler = GradScaler()  # For AMP training stability


Pruning applied!


AssertionError: prepare_qat only works on models in training mode