### 1. Post-Training Quantization

In [None]:
import torch
import torchvision.models as models
from torchvision.models import ResNet18_Weights
import torch.quantization

# Start from the pretrained FP32 ResNet-18 and switch it to eval mode
# before any fusion happens.
model_fp32 = models.resnet18(weights=ResNet18_Weights.DEFAULT).eval()

# Fuse the stem conv/bn/relu triple. inplace=False means the fused network
# is a separate object and model_fp32 itself is left as loaded.
stem_fusion_groups = [["conv1", "bn1", "relu"]]
model_fp32_fused = torch.quantization.fuse_modules(
    model_fp32,
    stem_fusion_groups,
    inplace=False,
)

# Post-Training Quantization (PTQ), dynamic flavour: only torch.nn.Linear
# modules are targeted, with qint8 as the requested dtype.
model_int8 = torch.quantization.quantize_dynamic(
    model_fp32_fused,
    qconfig_spec={torch.nn.Linear},
    dtype=torch.qint8,
)

# Persist the quantized model's state_dict (not the pickled module).
ptq_checkpoint_path = "resnet18_ptq_int8.pth"
torch.save(model_int8.state_dict(), ptq_checkpoint_path)


### 2. Quantization-Aware Training

In [None]:
import torch
import torchvision
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.models import ResNet18_Weights

# Input pipeline built on torchvision's FakeData, so no real images are needed.
preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
]
transform = transforms.Compose(preprocessing_steps)

dataset = torchvision.datasets.FakeData(
    size=1000,
    image_size=(3, 224, 224),
    transform=transform,
)

# Batches of 32, loaded in the main process (num_workers=0).
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    num_workers=0,
)

# Model
# Use torchvision's *quantizable* ResNet-18 rather than the plain one: it wraps
# the forward pass in QuantStub/DeQuantStub and implements the residual add in
# a quantization-friendly way, so the network produced by convert() can
# actually run on float inputs. quantize=False loads the regular FP32
# pretrained weights, which is the starting point for QAT.
model = models.quantization.resnet18(
    weights=ResNet18_Weights.DEFAULT,
    quantize=False,
)

# QAT fusion must happen in training mode. fuse_model(is_qat=True) fuses the
# stem and every residual block's Conv+BN(+ReLU) groups with the QAT-aware
# fuser, keeping batch-norm trainable instead of folding it away up front
# (eval-mode fuse_modules on the stem alone would leave the blocks unfused).
model.train()
model.fuse_model(is_qat=True)

# Quantization config
# Attach the default QAT qconfig for the fbgemm backend, then insert the
# fake-quantize modules and observers in place.
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
torch.quantization.prepare_qat(model, inplace=True)

# Loss and optimizer: cross-entropy, optimised with plain SGD over all
# model parameters.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

# Training loop: one optimizer step per batch, logging progress and the
# scalar loss as it goes.
num_epochs = 2  # Reduced for testing
for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}...")
    for batch_idx, batch in enumerate(dataloader):
        images, labels = batch
        print(f"Batch {batch_idx + 1}...")

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()

        # Forward pass and loss.
        loss = criterion(model(images), labels)
        print(f"Loss: {loss.item()}")

        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()

# Calibration: run exactly one batch through the model in eval mode with
# gradient tracking disabled. Nothing is run if the loader yields no batches.
print("Calibrating model...")
model.eval()
with torch.no_grad():
    first_batch = next(iter(dataloader), None)
    if first_batch is not None:
        images, _ = first_batch
        model(images)
print("Calibration complete!")

# Convert the prepared model into its quantized counterpart. inplace=False
# returns a new module and leaves `model` as it is.
quantized_model = torch.quantization.convert(model, inplace=False)
print("Quantization complete!")

# Save only the state_dict of the quantized model.
qat_checkpoint_path = "resnet18_qat_int8.pth"
torch.save(quantized_model.state_dict(), qat_checkpoint_path)
print("Model saved!")
