In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from transformers import ViTForImageClassification, ViTImageProcessor, AutoFeatureExtractor
from tqdm import tqdm
import time
import torch.nn.functional as F

# Set device: use the GPU when CUDA is available, otherwise fall back to the CPU.
# Later cells move models and batches onto this torch.device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Uncomment to force CPU execution (note: this makes `device` a plain str, not a torch.device).
# device = 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Hyperparameters
learning_rate = 0.0002  # passed as `lr` to the Adam optimizer used for fine-tuning
train_batch_size = 16
eval_batch_size = 8
num_epochs = 4
seed = 42

# Set seed for reproducibility
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

# Preprocessing applied to every CIFAR-100 image: resize to 224x224, convert to a
# tensor, then normalize each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

In [3]:
# CIFAR-100 train/test splits, downloaded into ./data when missing and preprocessed with `transform`.
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = DataLoader(trainset, batch_size=train_batch_size, shuffle=True, num_workers=2)
test_loader = DataLoader(testset, batch_size=eval_batch_size, shuffle=False, num_workers=2)

# Load pre-trained ViT model with a 100-class head. The recorded output below reports the
# classifier weights as newly initialized, so the head has to be trained before use.
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=100)
model.to(device)

Files already downloaded and verified
Files already downloaded and verified


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [4]:
# Define optimizer and loss function.
# Adam updates every parameter of `model` (backbone and classifier head alike);
# `criterion` is also reused by the evaluation helpers defined later in the notebook.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08)
criterion = torch.nn.CrossEntropyLoss()

In [None]:
# Training loop: each epoch makes one pass over train_loader, then evaluates on test_loader.
# NOTE(review): in the recorded run below, eval accuracy peaks at epoch 2 (0.8113) and drops to
# 0.7794 by epoch 4 while train accuracy keeps rising; the weights left in `model` afterwards
# are the epoch-4 ones.
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    train_correct = 0
    train_total = 0

    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images, labels = images.to(device), labels.to(device)

        # Standard step: clear old gradients, forward, backward, parameter update.
        optimizer.zero_grad()
        outputs = model(images).logits
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Accumulate the per-batch loss and the number of correct arg-max predictions.
        train_loss += loss.item()
        _, predicted = outputs.max(1)
        train_total += labels.size(0)
        train_correct += predicted.eq(labels).sum().item()

    # Loss is averaged over batches, accuracy over samples.
    train_loss /= len(train_loader)
    train_accuracy = train_correct / train_total

    # Evaluation loop
    model.eval()
    eval_loss = 0.0
    eval_correct = 0
    eval_total = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images).logits
            loss = criterion(outputs, labels)

            eval_loss += loss.item()
            _, predicted = outputs.max(1)
            eval_total += labels.size(0)
            eval_correct += predicted.eq(labels).sum().item()

    eval_loss /= len(test_loader)
    eval_accuracy = eval_correct / eval_total

    # Per-epoch summary.
    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}")
    print(f"Eval Loss: {eval_loss:.4f}, Eval Accuracy: {eval_accuracy:.4f}")
    print()

Epoch 1/4: 100%|██████████| 3125/3125 [29:53<00:00,  1.74it/s]
Evaluating: 100%|██████████| 1250/1250 [01:58<00:00, 10.58it/s]


Epoch 1/4
Train Loss: 1.1488, Train Accuracy: 0.7275
Eval Loss: 0.7587, Eval Accuracy: 0.7886



Epoch 2/4: 100%|██████████| 3125/3125 [29:54<00:00,  1.74it/s]
Evaluating: 100%|██████████| 1250/1250 [01:59<00:00, 10.48it/s]


Epoch 2/4
Train Loss: 0.5075, Train Accuracy: 0.8518
Eval Loss: 0.6755, Eval Accuracy: 0.8113



Epoch 3/4: 100%|██████████| 3125/3125 [30:04<00:00,  1.73it/s]
Evaluating: 100%|██████████| 1250/1250 [01:59<00:00, 10.50it/s]


Epoch 3/4
Train Loss: 0.3714, Train Accuracy: 0.8891
Eval Loss: 0.6988, Eval Accuracy: 0.8056



Epoch 4/4: 100%|██████████| 3125/3125 [30:01<00:00,  1.73it/s]
Evaluating: 100%|██████████| 1250/1250 [01:59<00:00, 10.48it/s]

Epoch 4/4
Train Loss: 0.2958, Train Accuracy: 0.9101
Eval Loss: 0.8210, Eval Accuracy: 0.7794






In [None]:
# Save the fine-tuned model: only the weights (state_dict) are written, so reloading
# requires rebuilding the same architecture first.
torch.save(model.state_dict(), "vit_cifar100_finetuned.pth")

In [None]:
# state_dict = torch.load('vit_cifar100_finetuned.pth', weights_only=True)
# Read the checkpoint with all tensors mapped to the CPU, regardless of where they were saved from.
state_dict = torch.load('vit_cifar100_finetuned.pth', weights_only=True, map_location=torch.device('cpu'))

# Load the state dict into your model
model.load_state_dict(state_dict)

<All keys matched successfully>

In [5]:
def eval_op(model):
    """Evaluate ``model`` on the notebook's test set and print the metrics.

    Relies on notebook globals: ``test_loader`` (yields ``(images, labels)``
    batches), ``criterion`` (loss applied to logits and labels) and ``device``
    (used only as a fallback when the model has no parameters).

    Args:
        model: module whose call result exposes a ``.logits`` tensor with the
            class scores along dimension 1.

    Returns:
        None. Average per-batch loss, top-1 and top-3 accuracy and the elapsed
        wall-clock time are printed.
    """
    model.eval()

    # Feed batches to the device the model actually lives on, so a model kept
    # on the CPU is still evaluated correctly when the global default is a GPU.
    first_param = next(iter(model.parameters()), None)
    run_device = first_param.device if first_param is not None else torch.device(device)
    # Compare the device *type*: a torch.device does not compare equal to the
    # string 'cuda', so a plain `device == 'cuda'` check never synchronizes.
    on_cuda = run_device.type == "cuda"

    eval_loss = 0.0
    eval_correct_top1 = 0
    eval_correct_top3 = 0
    eval_total = 0

    # Start timing (drain pending GPU work first so it is not billed to this run)
    if on_cuda:
        torch.cuda.synchronize()
    start_time = time.perf_counter()

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            images, labels = images.to(run_device), labels.to(run_device)

            outputs = model(images).logits
            loss = criterion(outputs, labels)

            eval_loss += loss.item()

            # Get the top 1 and top 3 predictions (capped for heads with fewer than 3 classes)
            _, top1_predicted = outputs.max(1)
            _, top3_predicted = outputs.topk(min(3, outputs.size(1)), 1, largest=True, sorted=True)

            # Calculate top-1 accuracy
            eval_correct_top1 += top1_predicted.eq(labels).sum().item()

            # Expand labels to match the shape of top3_predicted
            labels_expanded = labels.view(-1, 1).expand_as(top3_predicted)

            # Check if the correct label is in the top 3 predictions
            correct_top3 = top3_predicted.eq(labels_expanded).any(dim=1)

            eval_total += labels.size(0)
            eval_correct_top3 += correct_top3.sum().item()

    # End timing (wait for outstanding GPU kernels so they are included)
    if on_cuda:
        torch.cuda.synchronize()
    end_time = time.perf_counter()

    # Calculate evaluation time
    eval_time = end_time - start_time

    # Loss is averaged over batches, accuracies over samples.
    eval_loss /= len(test_loader)
    eval_accuracy_top1 = eval_correct_top1 / eval_total
    eval_accuracy_top3 = eval_correct_top3 / eval_total

    print(f"Eval Loss: {eval_loss:.4f}")
    print(f"Eval Top-1 Accuracy: {eval_accuracy_top1:.4f}")
    print(f"Eval Top-3 Accuracy: {eval_accuracy_top3:.4f}")
    print(f"Evaluation Time: {eval_time:.2f} seconds")

In [6]:
# Function to count parameters
def count_parameters(model, trainable_only=True):
    """Return the number of scalar parameters registered on ``model``.

    Args:
        model: module exposing ``parameters()``.
        trainable_only: when True (the default, matching the original
            behavior) only parameters with ``requires_grad=True`` are counted;
            pass False to include frozen parameters as well, e.g. for a model
            whose weights were frozen with ``requires_grad = False``.

    Returns:
        int: total element count over the selected parameters. Only tensors
        returned by ``model.parameters()`` are seen, so weights stored in any
        other form are not included.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )

In [7]:
# Load pre-trained ViT model (architecture + 100-class head), then overwrite its
# weights with the fine-tuned checkpoint.
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=100)
model.to(device)

# The checkpoint tensors are read onto the CPU first.
state_dict = torch.load('vit_cifar100_finetuned.pth', weights_only=True, map_location=torch.device('cpu'))

# Load the state dict into your model
model.load_state_dict(state_dict)



# Student: DeiT-tiny checkpoint loaded through the ViT classification class, with its
# classifier replaced by a fresh 100-output Linear layer.
student_model = ViTForImageClassification.from_pretrained("facebook/deit-tiny-patch16-224")
student_model.classifier = torch.nn.Linear(student_model.classifier.in_features, 100)
student_model.to(device)

# NOTE(review): 'distilled_deit_tiny.pth' is only written by the distillation cells further
# down, so on a fresh kernel this cell works only if that file already exists from an earlier run.
state_dict = torch.load('distilled_deit_tiny.pth', weights_only=True, map_location=torch.device('cpu'))

# Load the state dict into your model
student_model.load_state_dict(state_dict)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


<All keys matched successfully>

In [26]:
# Baseline: trainable-parameter count and test metrics of the fine-tuned ViT.
print(count_parameters(model))
eval_op(model)

85875556


Evaluating: 100%|██████████| 1250/1250 [04:07<00:00,  5.06it/s]

Eval Loss: 0.8210
Eval Top-1 Accuracy: 0.7794
Eval Top-3 Accuracy: 0.9228
Evaluation Time: 247.13 seconds





In [27]:
import torch.nn.functional as F
import torch_pruning as tp
from transformers.models.vit.modeling_vit import ViTSelfAttention
import math

# Define the new forward function for ViTSelfAttention
def new_forward(self, hidden_states, head_mask=None, output_attentions=False):
    """Eager scaled dot-product self-attention bound onto a ViTSelfAttention module.

    The head size and total head width are read from ``self`` on every call
    (``attention_head_size`` / ``all_head_size``), so updating those attributes
    on the module changes how the next forward pass scales and reshapes.

    Args:
        self: attention module providing ``query``/``key``/``value``,
            ``transpose_for_scores``, ``dropout``, ``attention_head_size`` and
            ``all_head_size``.
        hidden_states: 3-D input tensor (the shape is unpacked into three dims).
        head_mask: optional tensor multiplied into the attention probabilities.
        output_attentions: when True the probabilities are returned as well.

    Returns:
        ``(context,)`` or ``(context, attention_probs)``.
    """
    # Unpacking doubles as a rank check: anything but a 3-D input raises here.
    _batch, _tokens, _ = hidden_states.shape

    # Project the input, then let the module split the channel axis into heads.
    q = self.transpose_for_scores(self.query(hidden_states))
    k = self.transpose_for_scores(self.key(hidden_states))
    v = self.transpose_for_scores(self.value(hidden_states))

    # Raw query/key similarity, scaled by sqrt(head size).
    scaled_scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.attention_head_size)

    # Softmax over the last axis, followed by dropout on the probabilities
    # (the same ordering as the original Transformer formulation).
    attention_probs = self.dropout(F.softmax(scaled_scores, dim=-1))

    # Optional head masking.
    if head_mask is not None:
        attention_probs = attention_probs * head_mask

    # Weighted sum of values, then merge the head axes back into one channel axis.
    context = torch.matmul(attention_probs, v).permute(0, 2, 1, 3).contiguous()
    context = context.view(*context.size()[:-2], self.all_head_size)

    if output_attentions:
        return (context, attention_probs)
    return (context,)


# model.eval().to("cuda")

# Prepare example inputs on the same device as the model: the pruner runs a forward
# pass with them, which fails with a device mismatch if the model sits on the GPU
# while the inputs stay on the CPU.
example_inputs = torch.randn(1, 3, 224, 224, device=next(model.parameters()).device)

# Set up pruning configuration
num_heads = {}
ignored_layers = [model.classifier]  # keep the 100-class head intact

# Replace the forward function and set up num_heads.
# The patched forward reads head sizes from the module at call time; the query
# projection of each attention block is the key under which its head count is tracked.
for m in model.modules():
    if isinstance(m, ViTSelfAttention):
        m.forward = new_forward.__get__(m, ViTSelfAttention)
        num_heads[m.query] = m.num_attention_heads

imp = tp.importance.GroupNormImportance(2)
pruner = tp.pruner.MetaPruner(
    model,
    example_inputs,
    iterative_steps=5,
    global_pruning=False,
    importance=imp,
    ignored_layers=ignored_layers,
    num_heads=num_heads,
    prune_head_dims=False,
    prune_num_heads=True,
    head_pruning_ratio=(1.0/10.0),
    round_to=2,
)

# Perform pruning: a single pruner.step() is executed here, i.e. one of the
# `iterative_steps` increments configured above.
for g in pruner.step(interactive=True):
    g.prune()

# Modify the attention head size and all head size after pruning so the patched
# forward reshapes with the new geometry.
head_id = 0
for m in model.modules():
    if isinstance(m, ViTSelfAttention):
        print(f"Head #{head_id}")
        print(f"[Before Pruning] Num Heads: {m.num_attention_heads}, Head Dim: {m.attention_head_size} =>")
        m.num_attention_heads = pruner.num_heads[m.query]
        m.attention_head_size = m.query.out_features // m.num_attention_heads
        m.all_head_size = m.num_attention_heads * m.attention_head_size
        print(f"[After Pruning] Num Heads: {m.num_attention_heads}, Head Dim: {m.attention_head_size}")
        print()
        head_id += 1

# Print the modified model structure
print(model)

# Save the pruned model
# NOTE(review): save_pretrained receives a name ending in ".pth" and writes the model's
# config alongside the weights; confirm the pruned layer shapes can be reloaded from it.
pruned_model_path = "pruned_model.pth"
model.save_pretrained(pruned_model_path)

Head #0
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 12, Head Dim: 64

Head #1
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 12, Head Dim: 64

Head #2
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 12, Head Dim: 64

Head #3
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 12, Head Dim: 64

Head #4
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 12, Head Dim: 64

Head #5
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 12, Head Dim: 64

Head #6
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 12, Head Dim: 64

Head #7
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 12, Head Dim: 64

Head #8
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 12, Head Dim: 64

Head #9
[Before Pruning] Num Heads: 12, Head Dim: 64 =>
[After Pruning] Num Heads: 12, Head

In [28]:
# step 5
# Trainable-parameter count and test metrics of the pruned `model`.
# NOTE(review): "step 5" presumably refers to iterative_steps=5 in the pruning cell — TODO confirm.
print(count_parameters(model))
eval_op(model)

72056206


Evaluating: 100%|██████████| 1250/1250 [03:52<00:00,  5.39it/s]

Eval Loss: 1.1247
Eval Top-1 Accuracy: 0.7015
Eval Top-3 Accuracy: 0.8695
Evaluation Time: 232.08 seconds





In [18]:
# step 4
# Trainable-parameter count and test metrics of the pruned `model`.
# NOTE(review): presumably a re-run of the pruning cell with iterative_steps=4 — TODO confirm.
print(count_parameters(model))
eval_op(model)

68951428


Evaluating: 100%|██████████| 1250/1250 [04:43<00:00,  4.42it/s]

Eval Loss: 1.2572
Eval Top-1 Accuracy: 0.6726
Eval Top-3 Accuracy: 0.8476
Evaluation Time: 283.12 seconds





In [21]:
# step 3
# Trainable-parameter count and test metrics of the pruned `model`.
# NOTE(review): presumably a re-run of the pruning cell with iterative_steps=3 — TODO confirm.
print(count_parameters(model))
eval_op(model)

63703268


Evaluating: 100%|██████████| 1250/1250 [03:59<00:00,  5.23it/s]

Eval Loss: 1.7575
Eval Top-1 Accuracy: 0.5660
Eval Top-3 Accuracy: 0.7551
Evaluation Time: 239.13 seconds





In [24]:
# step 2
# Trainable-parameter count and test metrics of the pruned `model`.
# NOTE(review): presumably a re-run of the pruning cell with iterative_steps=2 — TODO confirm.
print(count_parameters(model))
eval_op(model)

53796772


Evaluating: 100%|██████████| 1250/1250 [03:28<00:00,  5.99it/s]

Eval Loss: 3.4750
Eval Top-1 Accuracy: 0.2342
Eval Top-3 Accuracy: 0.3854
Evaluation Time: 208.84 seconds





In [30]:
# Teacher: ViT-base with a 100-class head, restored from the fine-tuned checkpoint.
teacher_model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=100)
teacher_model.to(device)

state_dict = torch.load('vit_cifar100_finetuned.pth', weights_only=True, map_location=torch.device('cpu'))

# Load the state dict into your model
teacher_model.load_state_dict(state_dict)
# Student: DeiT-tiny with its classifier replaced by a fresh 100-output Linear layer.
# It is not moved to `device` in this cell; the distillation cell does that before training.
student_model = ViTForImageClassification.from_pretrained("facebook/deit-tiny-patch16-224")
student_model.classifier = torch.nn.Linear(student_model.classifier.in_features, 100)

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# Freeze the teacher: none of its parameters receive gradients during distillation.
for param in teacher_model.parameters():
    param.requires_grad = False

# Set models to evaluation mode (teacher) and training mode (student)
teacher_model.eval()
student_model.train()

# Load feature extractor
# NOTE(review): `feature_extractor` is not referenced by any other cell visible in this
# notebook (batches are preprocessed by `transform`) — confirm whether it is still needed.
feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

In [None]:
# Define optimizer: only the student's parameters are optimized.
# This rebinds the global `optimizer` that the fine-tuning cell used earlier.
optimizer = torch.optim.AdamW(student_model.parameters(), lr=1e-4)

# Distillation parameters
temperature = 2.0  # divides both logit sets before the softmax in distillation_loss
alpha = 0.5        # weight of the hard-label loss; (1 - alpha) weights the soft-label loss

In [None]:
def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
    """Blend hard-label cross-entropy with a temperature-scaled KL term.

    Args:
        student_logits: student scores with classes along dimension 1.
        teacher_logits: teacher scores of the same shape.
        labels: ground-truth class indices for the cross-entropy term.
        temperature: divisor applied to both logit sets before the softmax.
        alpha: weight of the hard-label term; the soft term gets ``1 - alpha``.

    Returns:
        Scalar tensor ``alpha * CE + (1 - alpha) * T**2 * KL(teacher || student)``.
    """
    # Softened distributions: log-probabilities for the student, probabilities for the teacher.
    student_log_probs = F.log_softmax(student_logits / temperature, dim=1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=1)

    # KL divergence summed over classes and averaged over the batch, rescaled by T^2.
    soft_term = F.kl_div(student_log_probs, teacher_probs, reduction='batchmean')
    soft_term = soft_term * (temperature ** 2)

    # Ordinary cross-entropy against the ground-truth labels.
    hard_term = F.cross_entropy(student_logits, labels)

    return alpha * hard_term + (1 - alpha) * soft_term

# Evaluation function
def evaluate(model, data_loader, device):
    """Compute top-1 and top-3 accuracy of ``model`` over ``data_loader``.

    Args:
        model: module whose call result exposes ``.logits`` (classes along dim 1).
        data_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to before the forward pass.

    Returns:
        tuple: ``(top1_accuracy, top3_accuracy)`` as fractions of all samples seen.
    """
    model.eval()
    seen = 0
    hits_top1 = 0
    hits_top3 = 0

    with torch.no_grad():
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images).logits

            # Top-1: the highest-scoring class must equal the label.
            best_class = logits.max(dim=1).indices
            hits_top1 += (best_class == labels).sum().item()

            # Top-3: the label must appear among the three highest-scoring classes.
            top3_classes = logits.topk(3, 1, largest=True, sorted=True).indices
            in_top3 = top3_classes.eq(labels.view(-1, 1)).any(dim=1)
            hits_top3 += in_top3.sum().item()

            seen += labels.size(0)

    return hits_top1 / seen, hits_top3 / seen

In [None]:
# Distillation run: the frozen teacher supplies soft targets, the student is optimized.
# `num_epochs` and `device` are re-assigned here, so the cell does not depend on the values
# set in the first cells of the notebook.
num_epochs = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
teacher_model.to(device)
student_model.to(device)

for epoch in range(num_epochs):
    # evaluate() at the end of each epoch switches the student to eval mode; switch it back.
    student_model.train()
    total_loss = 0
    for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images = images.to(device)
        labels = labels.to(device)

        # Get teacher predictions (no gradients are tracked for the teacher)
        with torch.no_grad():
            teacher_outputs = teacher_model(images)
            teacher_logits = teacher_outputs.logits

        # Get student predictions
        student_outputs = student_model(images)
        student_logits = student_outputs.logits

        # Calculate loss: weighted mix of hard-label and soft-label terms
        loss = distillation_loss(student_logits, teacher_logits, labels, temperature, alpha)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Calculate test accuracy
    test_accuracy_top1, test_accuracy_top3 = evaluate(student_model, test_loader, device)

    # The reported loss is the mean per-batch distillation loss over the epoch.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(train_loader):.4f}, "
          f"Test Accuracy Top-1: {test_accuracy_top1:.4f}, Top-3: {test_accuracy_top3:.4f}")


Epoch 1/4: 100%|██████████| 3125/3125 [12:56<00:00,  4.02it/s]


Epoch 1/4, Loss: 1.9217, Test Accuracy Top-1: 0.7226, Top-3: 0.9002


Epoch 2/4: 100%|██████████| 3125/3125 [12:56<00:00,  4.03it/s]


Epoch 2/4, Loss: 0.9378, Test Accuracy Top-1: 0.7553, Top-3: 0.9142


Epoch 3/4: 100%|██████████| 3125/3125 [12:55<00:00,  4.03it/s]


Epoch 3/4, Loss: 0.6997, Test Accuracy Top-1: 0.7682, Top-3: 0.9167


Epoch 4/4: 100%|██████████| 3125/3125 [12:57<00:00,  4.02it/s]


Epoch 4/4, Loss: 0.5574, Test Accuracy Top-1: 0.7785, Top-3: 0.9256


In [None]:
# Persist the distilled student's weights; the checkpoint-loading cell near the top of the
# notebook reads this file back.
torch.save(student_model.state_dict(), 'distilled_deit_tiny.pth')

In [33]:
# Trainable-parameter count and test metrics of the distilled student.
print(count_parameters(student_model))
eval_op(student_model)

5543716


Evaluating: 100%|██████████| 1250/1250 [01:18<00:00, 15.89it/s]

Eval Loss: 0.7914
Eval Top-1 Accuracy: 0.7785
Eval Top-3 Accuracy: 0.9256
Evaluation Time: 78.65 seconds





In [16]:
# Display the module tree of `model` (the cell's value is rendered as its repr).
model

ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
            (dense): Linear(in_fe

In [17]:
import torch
from transformers import ViTForImageClassification

# Load the fine-tuned model (kept on the CPU: it is never moved after construction)
model = ViTForImageClassification.from_pretrained("google/vit-base-patch16-224-in21k", num_labels=100)
state_dict = torch.load('vit_cifar100_finetuned.pth', weights_only=True, map_location=torch.device('cpu'))
model.load_state_dict(state_dict)


# Perform dynamic quantization.
# Only nn.Linear is listed: the recorded model repr shows that Linear layers are the ones
# converted, LayerNorm has no dynamically quantized counterpart, and nn.Embedding cannot
# be converted with the qint8 dynamic config selected by `dtype`.
quantized_model = torch.quantization.quantize_dynamic(
    model,
    {torch.nn.Linear},  # Quantize only Linear layers
    dtype=torch.qint8
)

# Save the quantized model
torch.save(quantized_model.state_dict(), 'vit_cifar100_quantized.pth')

# To use the quantized model: dynamically quantized Linear layers execute on the CPU
# backend, so the model (and the batches fed to it) must stay on the CPU.
quantized_model.to("cpu")
quantized_model.eval()

Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTSdpaAttention(
            (attention): ViTSdpaSelfAttention(
              (query): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (key): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (value): DynamicQuantizedLinear(in_features=768, out_features=768, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): DynamicQuantizedLinear(in_features=768,

In [18]:
# Parameter count and test metrics of the dynamically quantized model.
# NOTE(review): the recorded count (781056) is far below the float model's 85875556;
# presumably the quantized Linear weights are no longer returned by parameters(), so this
# number does not measure the model's size — confirm before reporting it.
print(count_parameters(quantized_model))
eval_op(quantized_model)

781056


Evaluating: 100%|██████████| 1250/1250 [03:48<00:00,  5.46it/s]

Eval Loss: 0.8766
Eval Top-1 Accuracy: 0.7673
Eval Top-3 Accuracy: 0.9150
Evaluation Time: 228.86 seconds





In [4]:
!pip install onnxruntime

Defaulting to user installation because normal site-packages is not writeable
Collecting onnxruntime
  Using cached onnxruntime-1.20.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Using cached onnxruntime-1.20.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (13.3 MB)
Installing collected packages: onnxruntime
[0mSuccessfully installed onnxruntime-1.20.1


In [15]:
from transformers import ViTForImageClassification, ViTFeatureExtractor
import onnx
from onnxruntime.quantization import quantize_dynamic, QuantType
import onnxruntime as ort
from PIL import Image
import numpy as np

# 1. Export the PyTorch Model to ONNX
def export_to_onnx(onnx_path="vit_model.onnx"):
    """Export the notebook's global ``model`` to an ONNX file.

    Side effects: ``model`` is switched to eval mode and moved to the CPU in
    place, and the ONNX file is written to ``onnx_path``.

    Args:
        onnx_path: destination file; defaults to the name used originally.

    Returns:
        str: the path the model was written to.
    """
    model.eval()  # Set to evaluation mode
    model.to("cpu")  # Export on CPU

    # Create a dummy input tensor
    dummy_input = torch.randn(1, 3, 224, 224)  # Batch size 1, RGB, 224x224

    # Export the model. The batch axis is declared dynamic on the output as well as the
    # input: the export is traced with batch size 1, but evaluation feeds larger batches.
    with torch.no_grad():
        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
            export_params=True,
            opset_version=16,
            input_names=["input"],
            output_names=["output"],
            dynamic_axes={
                "input": {0: "batch_size"},
                "output": {0: "batch_size"},
            },
        )
    print(f"ONNX model exported to {onnx_path}")
    return onnx_path

# 2. Quantize the ONNX Model
def quantize_model(onnx_path):
    """Write a dynamically quantized (QInt8 weights) copy of an ONNX model.

    Args:
        onnx_path: path of the float ONNX model to read.

    Returns:
        str: path of the quantized model, always "vit_model_quantized.onnx".
    """
    # The patch-embedding convolution is excluded from quantization by node name.
    excluded_nodes = ["/vit/embeddings/patch_embeddings/projection/Conv"]
    quantized_onnx_path = "vit_model_quantized.onnx"

    quantize_dynamic(
        model_input=onnx_path,
        model_output=quantized_onnx_path,
        weight_type=QuantType.QInt8,
        nodes_to_exclude=excluded_nodes,
    )

    print(f"Quantized ONNX model saved to {quantized_onnx_path}")
    return quantized_onnx_path


In [16]:
import numpy as np
import torch
import onnxruntime as ort
from tqdm import tqdm

def eval_onnx_model(ort_session, test_loader, criterion, device="cuda"):
    """Evaluate an ONNX Runtime session on ``test_loader`` and print the metrics.

    Inference itself runs inside ``ort_session`` on NumPy inputs; ``device``
    only selects where the loss and accuracy are computed.

    Args:
        ort_session: object exposing ``get_inputs()`` and ``run(None, feeds)``;
            the first output is taken as the logits (classes along dim 1).
        test_loader: iterable of ``(images, labels)`` tensor batches.
        criterion: loss applied to ``(logits, labels)``.
        device: device for the metric computation. When a CUDA device is
            requested but CUDA is unavailable, the CPU is used instead of
            raising.

    Returns:
        None. Average per-batch loss, top-1 and top-3 accuracy and the elapsed
        wall-clock time are printed.
    """
    metric_device = torch.device(device)
    if metric_device.type == "cuda" and not torch.cuda.is_available():
        metric_device = torch.device("cpu")
    on_cuda = metric_device.type == "cuda"

    # The input name does not change between batches, so look it up once.
    input_name = ort_session.get_inputs()[0].name

    eval_loss = 0.0
    eval_correct_top1 = 0
    eval_correct_top3 = 0
    eval_total = 0

    if on_cuda:
        torch.cuda.synchronize()
    start_time = time.time()

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Evaluating"):
            labels = labels.to(metric_device)

            # Run ONNX model inference. The session consumes NumPy arrays, so the images
            # are handed over from host memory without a detour through the GPU.
            ort_inputs = {input_name: images.cpu().numpy()}
            outputs = ort_session.run(None, ort_inputs)

            logits = torch.as_tensor(outputs[0], device=metric_device)

            loss = criterion(logits, labels)
            eval_loss += loss.item()

            # Get top-1 and top-3 predictions (capped for heads with fewer than 3 classes)
            top1_predicted = logits.argmax(dim=1)
            top3_predicted = logits.topk(min(3, logits.size(1)), 1, largest=True, sorted=True).indices

            eval_correct_top1 += (top1_predicted == labels).sum().item()

            # Check if the correct label is in the top-3 predictions
            correct_top3 = top3_predicted.eq(labels.view(-1, 1).expand_as(top3_predicted)).any(dim=1)
            eval_correct_top3 += correct_top3.sum().item()

            eval_total += labels.size(0)

    if on_cuda:
        torch.cuda.synchronize()
    end_time = time.time()

    eval_time = end_time - start_time

    # Loss is averaged over batches, accuracies over samples.
    eval_loss /= len(test_loader)
    eval_accuracy_top1 = eval_correct_top1 / eval_total
    eval_accuracy_top3 = eval_correct_top3 / eval_total

    print(f"Eval Loss: {eval_loss:.4f}")
    print(f"Eval Top-1 Accuracy: {eval_accuracy_top1:.4f}")
    print(f"Eval Top-3 Accuracy: {eval_accuracy_top3:.4f}")
    print(f"Evaluation Time: {eval_time:.2f} seconds")


In [17]:
# Step 1: Export PyTorch model to ONNX
# Exports whatever the global `model` currently is and keeps the output path for the next step.
onnx_path = export_to_onnx()


ONNX model exported to vit_model.onnx


In [18]:
# Step 2: Quantize the ONNX model
# Reads the file exported in step 1 and keeps the path of the quantized copy for the session below.
quantized_onnx_path = quantize_model(onnx_path)





Quantized ONNX model saved to vit_model_quantized.onnx


In [14]:
# The session runs on ONNX Runtime's CPU provider. Metrics are computed on the notebook's
# `device` instead of a hard-coded "cuda", so the cell also runs on machines without a GPU.
providers = ["CPUExecutionProvider"]
ort_session = ort.InferenceSession(quantized_onnx_path, providers=providers)
eval_onnx_model(ort_session, test_loader, criterion, device=str(device))

NotImplemented: [ONNXRuntimeError] : 9 : NOT_IMPLEMENTED : Could not find an implementation for ConvInteger(10) node with name '/vit/embeddings/patch_embeddings/projection/Conv_quant'

In [None]:
# Fetch Torch-Pruning from GitHub and install it in editable (development) mode.
# NOTE(review): an earlier cell imports torch_pruning, so on a fresh kernel this setup has
# to be run before that cell.
!git clone https://github.com/VainF/Torch-Pruning.git
!cd Torch-Pruning && pip install -e .

Cloning into 'Torch-Pruning'...
remote: Enumerating objects: 6909, done.[K
remote: Total 6909 (delta 0), reused 0 (delta 0), pack-reused 6909 (from 1)[K
Receiving objects: 100% (6909/6909), 10.08 MiB | 13.50 MiB/s, done.
Resolving deltas: 100% (4633/4633), done.
Obtaining file:///content/Torch-Pruning
  Preparing metadata (setup.py) ... [?25l[?25hdone
Installing collected packages: torch-pruning
  Running setup.py develop for torch-pruning
Successfully installed torch-pruning-1.5.0


In [None]:
# Copy the torch_pruning package directory next to the notebook, then delete the clone.
# NOTE(review): the editable install in the previous cell points at the clone removed here;
# presumably `import torch_pruning` then resolves to this local copy — confirm.
!cp -r Torch-Pruning/torch_pruning/ torch_pruning/
!rm -rf Torch-Pruning/