In [None]:
# Environment diagnostics: working directory, active Python interpreter, and the
# CUDA/cuDNN-related environment variables (an empty line means the variable is unset).
!echo $PWD
!which python

!echo $CUDA_HOME
!echo $CUDNN_LIB_DIR
!echo $CUDNN_INCLUDE_DIR
!echo $LD_LIBRARY_PATH

In [None]:
# Cell 1: Imports and Setup
import copy
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.prune as prune
import torch.quantization
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader, Subset
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from transformers import ViTForImageClassification, ViTFeatureExtractor

# Run on the default CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [None]:
# Environment sanity check: PyTorch/CUDA versions and GPU visibility.
print(torch.__version__)
print(torch.version.cuda)
cuda_ok = torch.cuda.is_available()
print(cuda_ok)  # Should return True if CUDA is available
if cuda_ok:
    # current_device()/get_device_name() require a CUDA device; guarding them keeps
    # the notebook runnable top-to-bottom on CPU-only machines (where `device` is "cpu").
    print(torch.cuda.current_device())  # Should return the device index
    print(torch.cuda.get_device_name(0))  # Should return the name of the GPU
else:
    print("CUDA not available; running on CPU")

In [None]:
# Cell 2: Configuration
model_name: str = "google/vit-huge-patch14-224-in21k"  # HF Hub checkpoint shared by every variant
PRUNE_PERCENTILE: int = 30  # output channels whose activation statistic falls below this percentile are pruned
BATCH_SIZE: int = 64        # evaluation batch size


In [None]:
# Cell 3: Load Models and Feature Extractor
# Load the checkpoint once and deep-copy it for the other variants. Calling
# `from_pretrained` separately for each variant would give each one its own freshly
# initialized copy of any weights the checkpoint does not provide (such as a
# `num_labels=10` classification head), so the variants would already differ before
# any pruning or quantization. Copying also avoids re-reading the checkpoint.
# NOTE(review): if this in21k checkpoint ships without a CIFAR-10 head (transformers
# warns about newly initialized `classifier.*` weights), that head is untrained and
# accuracy will be near chance for every variant unless the model is fine-tuned.
original_fp32_model = ViTForImageClassification.from_pretrained(model_name, num_labels=10)
pruned_fp32_model = copy.deepcopy(original_fp32_model)
pruned_awq_model = copy.deepcopy(original_fp32_model)
pruned_fp16_model = copy.deepcopy(original_fp32_model)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)


In [None]:
# Cell 4: Load and Prepare Calibration Data
# --- Load CIFAR-10 Datasets ---

# The test split is kept as raw PIL images (transform=None); resizing and
# normalization are done later by the feature extractor.
cifar10_dataset = CIFAR10(
    root='./data',
    train=False,
    transform=None,
    download=True,
)

In [None]:
# Full test set: unpack the (PIL image, label) pair of every test example.
full_test_images, full_test_labels = zip(
    *(cifar10_dataset[idx] for idx in range(len(cifar10_dataset)))
)
full_test_labels = torch.tensor(full_test_labels)

# Preview the first six test images with their labels.
fig, axes = plt.subplots(2, 3, figsize=(10, 4))
for ax, image, label in zip(axes.flat, full_test_images[:6], full_test_labels[:6]):
    ax.imshow(image)
    ax.set_title(f"Label: {label}")
    ax.axis('off')
fig.suptitle("full Sample Images")
fig.tight_layout()
plt.show()

In [None]:
# Calibration subset (32 samples): the first 32 test images, used to profile activations.
calibration_indices = list(range(32))
calibration_subset = Subset(cifar10_dataset, calibration_indices)
calibration_images, calibration_labels = zip(
    *(calibration_subset[pos] for pos in range(len(calibration_subset)))
)
calibration_labels = torch.tensor(calibration_labels)

# Preview the first six calibration images with their labels.
fig, axes = plt.subplots(2, 3, figsize=(10, 4))
for ax, image, label in zip(axes.flat, calibration_images[:6], calibration_labels[:6]):
    ax.imshow(image)
    ax.set_title(f"Label: {label}")
    ax.axis('off')
fig.suptitle("Calibration Sample Images")
fig.tight_layout()
plt.show()


In [None]:
# Extract inputs using feature extractor
# NOTE(review): the second call materializes pixel tensors for the whole test split at
# once (roughly 10,000 x 3 x 224 x 224 float32, ~6 GB, assuming 224x224 outputs —
# confirm); consider per-batch extraction if host memory is tight.
calib_inputs = feature_extractor(calibration_images, return_tensors="pt")
test_inputs = feature_extractor(full_test_images, return_tensors="pt")

In [None]:
# --- Register Hooks on Linear Layers for Activation Collection ---
# Profile the unpruned original_fp32_model on the calibration images.
#
# Statistics are keyed by each module's qualified name from `named_modules()`
# (e.g. "vit.encoder.layer.0.intermediate.dense"). Unlike `id(module)`, that name is
# the same on every model built from the same architecture, so the stats collected
# here can be looked up on the separately constructed pruned/quantized copies.

activation_stats = {}    # module key -> mean |activation| per output feature (CPU tensor)
memory_stats = {}        # module key -> largest single-batch output size, in bytes
_activation_sums = {}    # module key -> running sum of |activation| per output feature
_activation_counts = {}  # module key -> number of rows (samples x tokens) summed so far


def activation_and_memory_hook(module, input, output, module_key=None):
    """Forward hook accumulating per-feature |activation| sums and output size.

    Args:
        module: the hooked module (only used to build the fallback key).
        input: the forward inputs (unused).
        output: the module's output; a tuple is reduced to its first element. All
            leading dimensions are flattened, so every row along the last dimension
            (one per sample, or per sample and token) contributes equally.
        module_key: key under which to record the statistics. Defaults to the
            legacy ``f"{str(module)}_{id(module)}"`` key.

    Raises:
        ValueError: if the output has fewer than 2 dimensions.
    """
    # nn.Linear returns a plain tensor; indexing it with [0] would silently keep only
    # the first sample of the batch, so only unwrap genuine tuple outputs.
    if isinstance(output, tuple):
        output = output[0]
    if output.dim() < 2:
        raise ValueError(f"Unexpected output dimensions: {output.dim()}")
    if module_key is None:
        module_key = f"{str(module)}_{id(module)}"
    with torch.no_grad():
        rows = output.detach().reshape(-1, output.shape[-1]).abs().float()
        batch_sum = rows.sum(dim=0).cpu()
    # Accumulate across batches (instead of overwriting) so the final mean covers the
    # whole calibration set rather than just the last batch.
    if module_key in _activation_sums:
        _activation_sums[module_key] += batch_sum
        _activation_counts[module_key] += rows.shape[0]
    else:
        _activation_sums[module_key] = batch_sum
        _activation_counts[module_key] = rows.shape[0]
    mem_bytes = output.numel() * output.element_size()
    memory_stats[module_key] = max(memory_stats.get(module_key, 0), mem_bytes)


def _make_named_hook(module_key):
    """Return a forward hook that records statistics under `module_key`."""
    def hook(module, input, output):
        activation_and_memory_hook(module, input, output, module_key=module_key)
    return hook


qualified_names = {id(m): n for n, m in original_fp32_model.named_modules()}
hooks = []
for i, encoder_layer in enumerate(original_fp32_model.vit.encoder.layer):
    for sub_name, sub_module in encoder_layer.named_modules():
        if isinstance(sub_module, torch.nn.Linear) and (
            "intermediate.dense" in sub_name or "output.dense" in sub_name
        ):
            module_key = qualified_names[id(sub_module)]
            hooks.append(sub_module.register_forward_hook(_make_named_hook(module_key)))
            print(f"Registered hook on: {module_key}")

# Run forward passes to collect activations
original_fp32_model.eval()
original_fp32_model.to(device)
batch_size = 4
num_samples = calib_inputs['pixel_values'].size(0)

try:
    for i in range(0, num_samples, batch_size):
        batch_inputs = {k: v[i:i + batch_size].to(device) for k, v in calib_inputs.items()}
        with torch.no_grad():
            _ = original_fp32_model(**batch_inputs)
        torch.cuda.empty_cache()
finally:
    # --- Clean Up Hooks ---
    # Always remove the hooks, so a failed run cannot leave stale hooks behind that
    # would double-count activations when the cell is re-run.
    for hook in hooks:
        hook.remove()
    torch.cuda.empty_cache()

activation_stats.update(
    {key: _activation_sums[key] / _activation_counts[key] for key in _activation_sums}
)
print("\nHooks removed and activations stored.")

In [None]:
# Cell 6: Pruning
def prune_model(model, activation_stats, percentile=10):
    """Mask whole low-activation output channels of the encoder's dense layers.

    For every ``torch.nn.Linear`` inside ``model.vit.encoder.layer[*]`` whose
    layer-relative name contains "intermediate.dense" or "output.dense", the
    per-output-feature statistic is looked up in ``activation_stats``. Every output
    channel (weight row) whose statistic is below
    ``np.percentile(stat, percentile)`` is zeroed via
    ``torch.nn.utils.prune.custom_from_mask``. The bias is left untouched.

    Statistics are looked up by the module's qualified name in ``model`` (as given
    by ``model.named_modules()``, e.g. "vit.encoder.layer.0.intermediate.dense"),
    which matches across separately constructed copies of the same architecture.
    As a fallback, the legacy ``f"{str(module)}_{id(module)}"`` key is tried; it only
    matches statistics gathered on this exact model instance. Modules without
    statistics are skipped.

    Args:
        model: a model exposing ``model.vit.encoder.layer`` (e.g. ViTForImageClassification).
        activation_stats: mapping of module key -> 1-D CPU tensor with one value per
            output feature of that layer.
        percentile: pruning threshold percentile in [0, 100].

    Raises:
        ValueError: if a statistic's length differs from the layer's number of
            output features.
    """
    qualified_names = {id(m): n for n, m in model.named_modules()}
    for encoder_layer in model.vit.encoder.layer:
        for name, module in encoder_layer.named_modules():
            if not isinstance(module, torch.nn.Linear):
                continue
            if "intermediate.dense" not in name and "output.dense" not in name:
                continue
            stats = activation_stats.get(qualified_names[id(module)])
            if stats is None:
                stats = activation_stats.get(f"{str(module)}_{id(module)}")
            if stats is None:
                continue
            act_stat = stats.numpy()
            out_features = module.weight.shape[0]
            if act_stat.shape[0] != out_features:
                raise ValueError(
                    f"{qualified_names[id(module)]}: statistic has {act_stat.shape[0]} "
                    f"entries but the layer has {out_features} output features"
                )
            threshold = np.percentile(act_stat, percentile)
            keep = torch.tensor(act_stat >= threshold, dtype=torch.float32, device=module.weight.device)
            # Broadcast the per-row decision across the input dimension so entire
            # output channels (weight rows) are masked.
            mask = keep.unsqueeze(-1).expand_as(module.weight)
            prune.custom_from_mask(module, name="weight", mask=mask)

# Apply the same activation-based channel pruning to every copy that gets benchmarked.
for model_to_prune in (pruned_fp32_model, pruned_awq_model, pruned_fp16_model):
    prune_model(model_to_prune, activation_stats, percentile=PRUNE_PERCENTILE)

In [None]:
# Cell 7: Quantization (AWQ)
def awq_quantize_model(model, activation_stats, bitwidth=8):
    """Fake-quantize, in place, the weights of the profiled Linear layers.

    Every ``torch.nn.Linear`` in ``model`` that has an entry in ``activation_stats``
    gets its weight quantized to signed ``bitwidth``-bit integers, with one symmetric
    scale per output channel (``max|w_row| / (2**(bitwidth-1) - 1)``). The weight is
    then immediately dequantized back to its own dtype. The model keeps running in
    floating point but carries the quantization error, so its accuracy reflects
    ``bitwidth``-bit weights. Memory use and latency are NOT reduced.

    Statistics are looked up by the module's qualified name from
    ``model.named_modules()``, falling back to the legacy
    ``f"{str(module)}_{id(module)}"`` key. Here they only select which layers are
    quantized.

    Scales are derived from the weights alone. Folding the per-output activation
    statistic into the scale (``max|w| * act / qmax``) would clip every row whose
    statistic is below 1, and would give the most active rows the coarsest grid.
    NOTE(review): true AWQ rescales salient *input* channels using input-activation
    statistics before quantizing; the collected statistics are per output feature,
    so that is not implemented. TODO: collect input activations if real AWQ is needed.

    Args:
        model: module tree to modify in place.
        activation_stats: mapping of module key -> statistic; only key presence matters.
        bitwidth: width of the signed integer grid (must be >= 2).

    Raises:
        ValueError: if ``bitwidth`` < 2.
    """
    if bitwidth < 2:
        raise ValueError(f"bitwidth must be >= 2, got {bitwidth}")
    qmax = 2 ** (bitwidth - 1) - 1
    qmin = -(2 ** (bitwidth - 1))
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        if name not in activation_stats and f"{str(module)}_{id(module)}" not in activation_stats:
            continue
        # torch.nn.utils.prune keeps the real parameter in `weight_orig` and recomputes
        # `weight = weight_orig * weight_mask` before each forward. Writing to `weight`
        # would therefore be overwritten, so quantize the original and leave the mask intact.
        target = module.weight_orig if hasattr(module, "weight_orig") else module.weight
        with torch.no_grad():
            w = target.float()
            # clamp avoids division by zero for all-zero rows (they stay zero).
            scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / qmax
            w_q = torch.clamp(torch.round(w / scales), qmin, qmax)
            target.copy_((w_q * scales).to(target.dtype))

awq_quantize_model(model=pruned_awq_model, activation_stats=activation_stats, bitwidth=8)  # 8-bit weights


In [None]:
# Cell 8: Convert to FP16
# `Module.half()` casts floating-point parameters and buffers in place and returns the
# same module; the trailing `;` suppresses the module repr as cell output.
pruned_fp16_model.half();


In [None]:
# Cell 9: Evaluation Function
def evaluate_model(model, variant_name="", use_amp=False, inputs=test_inputs, labels=full_test_labels):
    """Benchmark accuracy, per-batch latency and GPU memory of `model`.

    Args:
        model: classifier whose output exposes `.logits`; the argmax over dim 1 is
            the prediction. It is put in eval mode and left on `device`.
        variant_name: label copied into the result row.
        use_amp: if True, floating-point inputs are cast to float16 batch by batch,
            and on CUDA the forward pass runs under `autocast()`. Intended for a
            model already converted with `.half()`.
        inputs: mapping of input name -> tensor whose first dim indexes samples.
        labels: 1-D tensor of target class ids, one per sample.

    Returns:
        dict with "Variant", "Accuracy (%)", "Avg Batch Latency (s)",
        "Peak Memory (MB)" and "Total Inference Time (s)". "Peak Memory (MB)" is the
        peak CUDA allocation during this call above what was already allocated
        before the model was placed on the GPU (so other resident models are
        excluded). It is NaN when running on CPU.

    Raises:
        ValueError: if `labels` is empty.
    """
    num_samples = labels.size(0)
    if num_samples == 0:
        raise ValueError("labels is empty; nothing to evaluate")
    is_cuda = device.type == "cuda"

    model.eval()
    if is_cuda:
        # Measure this variant in isolation: take it off the GPU, then treat whatever
        # is still allocated (e.g. previously evaluated variants) as a baseline.
        model.to("cpu")
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
        baseline_bytes = torch.cuda.memory_allocated(device)
    model.to(device)

    correct = 0
    latencies = []
    start_total = time.perf_counter()
    with torch.no_grad():
        for start in range(0, num_samples, BATCH_SIZE):
            stop = start + BATCH_SIZE
            # Move (and for AMP, cast) one batch at a time instead of copying the
            # whole input set to the device up front.
            batch_inputs = {}
            for key, value in inputs.items():
                chunk = value[start:stop]
                if use_amp and chunk.is_floating_point():
                    batch_inputs[key] = chunk.to(device=device, dtype=torch.float16)
                else:
                    batch_inputs[key] = chunk.to(device)
            batch_labels = labels[start:stop].to(device)

            if is_cuda:
                torch.cuda.synchronize()
            start_batch = time.perf_counter()
            if use_amp and is_cuda:
                with autocast():
                    outputs = model(**batch_inputs)
            else:
                outputs = model(**batch_inputs)
            if is_cuda:
                # Kernels run asynchronously; wait for them before stopping the clock.
                torch.cuda.synchronize()
            latencies.append(time.perf_counter() - start_batch)

            preds = torch.argmax(outputs.logits, dim=1)
            correct += (preds == batch_labels).sum().item()
    total_time = time.perf_counter() - start_total

    accuracy = 100.0 * correct / num_samples
    avg_latency = sum(latencies) / len(latencies)
    if is_cuda:
        peak_memory = (torch.cuda.max_memory_allocated(device) - baseline_bytes) / (1024 ** 2)
    else:
        peak_memory = float("nan")
    return {
        "Variant": variant_name,
        "Accuracy (%)": accuracy,
        "Avg Batch Latency (s)": avg_latency,
        "Peak Memory (MB)": peak_memory,
        "Total Inference Time (s)": total_time
    }


In [None]:
# Cell 10: Evaluate Variants
# Display name -> model plus whether to evaluate it with fp16 inputs under autocast.
variants = {
    "Original FP32": dict(model=original_fp32_model, use_amp=False),
    "Pruned FP32": dict(model=pruned_fp32_model, use_amp=False),
    "Pruned + Quantized": dict(model=pruned_awq_model, use_amp=False),
    "Pruned + FP16": dict(model=pruned_fp16_model, use_amp=True),
}

benchmark_results = []
for variant_label, settings in variants.items():
    print(f"Evaluating {variant_label}...")
    benchmark_results.append(
        evaluate_model(settings["model"], variant_name=variant_label, use_amp=settings["use_amp"])
    )


In [None]:
# Cell 11: Display Results
# One row per evaluated variant, with columns in the order evaluate_model returns them.
df = pd.DataFrame(benchmark_results)
print(f"\nBenchmark Results:\n{df}")
