In [13]:
import os
import torch
import torch.nn as nn
import torch.ao.quantization as quant
from torch.ao.quantization import get_default_qat_qconfig, prepare_qat, convert
from torch.ao.quantization.qconfig import QConfig
from torch.ao.quantization.observer import MinMaxObserver

from transformers import ViTForImageClassification
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.transforms import InterpolationMode

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm



In [2]:
# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
# Load a pretrained ViT model from Hugging Face
model_name = "google/vit-base-patch16-224"
model = ViTForImageClassification.from_pretrained(
    model_name,
    ignore_mismatched_sizes=True,
)

# Modify the classifier to match the Food101 dataset (101 classes)
model.classifier = nn.Linear(model.config.hidden_size, 101)
model = model.to(device)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/69.7k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [4]:
# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)


In [7]:
# Optimized Data Transform Pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])


In [10]:
# Load Food101 Dataset
train_dataset = datasets.Food101(root='./data', split='train', transform=transform, download=True)
val_dataset = datasets.Food101(root='./data', split='test', transform=transform, download=True)
test_dataset = datasets.Food101(root='./data', split='test', transform=transform, download=True)

train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=torch.multiprocessing.cpu_count(),
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=torch.multiprocessing.cpu_count(),
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=torch.multiprocessing.cpu_count(),
    pin_memory=True
)

Downloading https://data.vision.ee.ethz.ch/cvl/food-101.tar.gz to ./data/food-101.tar.gz


100%|██████████| 5.00G/5.00G [02:34<00:00, 32.3MB/s]


Extracting ./data/food-101.tar.gz to ./data


In [11]:
model.train()  # ENsuring Model is in training mode
model = prepare_qat(model, inplace=True)



In [12]:
model.train()

qconfig = get_default_qat_qconfig("fbgemm")
model.qconfig = qconfig
model = prepare_qat(model, inplace=True)


In [15]:
def train_qat_model(model, dataloader, optimizer, criterion, device, epochs=1, save_dir="qat_checkpoints"):
    model.train()
    os.makedirs(save_dir, exist_ok=True)

    for epoch in range(epochs):
        total_loss, correct, total = 0.0, 0, 0
        with tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}") as tepoch:
            for images, labels in tepoch:
                images = images.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()
                outputs = model(images).logits
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                preds = torch.argmax(outputs, dim=1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

                running_loss = total_loss / total
                running_accuracy = 100.0 * correct / total

                tepoch.set_postfix(loss=running_loss, accuracy=running_accuracy)

        avg_loss = total_loss / len(dataloader.dataset)
        accuracy = 100.0 * correct / total
        print(f"Epoch {epoch+1}/{epochs} - Loss: {avg_loss:.4f}, Accuracy: {accuracy:.2f}%")

        checkpoint_path = os.path.join(save_dir, f"checkpoint_epoch_{epoch+1}.pth")
        torch.save(model.state_dict(), checkpoint_path)
        print(f"Checkpoint saved at {checkpoint_path}")

In [16]:
train_qat_model(model, train_loader, optimizer, criterion, epochs=1, device=device)



Epoch 1/1:   0%|          | 9/2368 [00:11<51:18,  1.30s/it, accuracy=3.82, loss=0.148]


KeyboardInterrupt: 

In [59]:
from torch.ao.quantization import convert
quantized_model = convert(model)

In [56]:

def evaluate_model(model, dataloader, device):
    correct, total = 0, 0
    all_preds, all_labels = [], []

    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(dataloader, desc="Evaluating"):
            images, labels = images.to(device, non_blocking=True), labels.to(device, non_blocking=True)

            # Forward pass
            outputs = model(images).logits
            preds = torch.argmax(outputs, dim=1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    # Calculate Metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    # Print Results
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1 Score:  {f1:.4f}")

    return accuracy, precision, recall, f1


In [60]:
# Print all layers and check their data type
for name, module in quantized_model.named_modules():
    # For layers with parameters
    if hasattr(module, 'weight'):
        print(f"Layer: {name}")
        print(f"  Type: {type(module)}")
        print(f"  Weight dtype: {module.weight.dtype}")
        print(f"  Device: {module.weight.device}")
        print("-" * 40)

Layer: vit.embeddings.patch_embeddings.projection
  Type: <class 'torch.nn.modules.conv.Conv2d'>
  Weight dtype: torch.float32
  Device: cuda:0
----------------------------------------
Layer: vit.encoder.layer.0.attention.attention.query
  Type: <class 'torch.nn.modules.linear.Linear'>
  Weight dtype: torch.float32
  Device: cuda:0
----------------------------------------
Layer: vit.encoder.layer.0.attention.attention.key
  Type: <class 'torch.nn.modules.linear.Linear'>
  Weight dtype: torch.float32
  Device: cuda:0
----------------------------------------
Layer: vit.encoder.layer.0.attention.attention.value
  Type: <class 'torch.nn.modules.linear.Linear'>
  Weight dtype: torch.float32
  Device: cuda:0
----------------------------------------
Layer: vit.encoder.layer.0.attention.output.dense
  Type: <class 'torch.nn.modules.linear.Linear'>
  Weight dtype: torch.float32
  Device: cuda:0
----------------------------------------
Layer: vit.encoder.layer.0.intermediate.dense
  Type: <class