<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Quantization_with_PyTorch_for_Efficient_Inference.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.quantization as tq
import torch.quantization.observer as observer
from torch.utils.data import DataLoader, TensorDataset

# Define a simple model
class SimpleModel(nn.Module):
    """Single linear layer (10 input features -> 2 outputs), ready for static quantization.

    The linear layer is wrapped between a ``QuantStub`` and a ``DeQuantStub``.
    In the plain float model both stubs pass their input through unchanged;
    once the model has been prepared and converted with ``torch.quantization``
    they mark where tensors are quantized on entry and dequantized on exit,
    so the converted model still accepts and returns ordinary float tensors.
    """

    def __init__(self):
        super().__init__()
        # Entry point of the quantized region (float -> quantized after convert).
        self.quant = tq.QuantStub()
        self.fc = nn.Linear(10, 2)
        # Exit point of the quantized region (quantized -> float after convert).
        self.dequant = tq.DeQuantStub()

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the float result.

        Args:
            x: Float tensor whose last dimension is 10 (the ``in_features``
                of ``self.fc``).

        Returns:
            Float tensor whose last dimension is 2.
        """
        x = self.quant(x)
        x = self.fc(x)
        return self.dequant(x)

# Synthetic dataset: 100 samples of 10 standard-normal features, each with a
# random class label in {0, 1}; served in shuffled mini-batches of 16.
data = torch.randn((100, 10))
labels = torch.randint(low=0, high=2, size=(100,))
dataset = TensorDataset(data, labels)
data_loader = DataLoader(dataset=dataset, batch_size=16, shuffle=True)

# Initialize the model, loss function, and optimizer
model = SimpleModel()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the float model for 5 epochs (for demonstration purposes)
model.train()
for epoch in range(5):
    # The per-batch labels are named `targets` so the loop does not rebind
    # the module-level `labels` tensor that holds the full label set.
    for inputs, targets in data_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()
    # NOTE: this is the loss of the epoch's last mini-batch, not an epoch mean.
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# Post-training static quantization (eager mode):
# eval -> attach qconfig -> prepare (insert observers) -> calibrate -> convert.

# Put the model in eval mode before inserting observers; the PyTorch eager-mode
# static quantization workflow expects an eval-mode model at this point.
model.eval()

# Quantization configuration:
#   activations - default observer, overridden to unsigned 8-bit over 0..255
#   weights     - default per-channel weight observer
model.qconfig = tq.QConfig(
    activation=observer.default_observer.with_args(dtype=torch.quint8, quant_min=0, quant_max=255),
    weight=observer.default_per_channel_weight_observer
)

# Prepare model for quantization (modifies `model` in place)
tq.prepare(model, inplace=True)

# Calibrate: run the data through the prepared model so the observers see
# representative inputs. Labels are not needed and gradients are disabled.
with torch.no_grad():
    for inputs, _ in data_loader:
        model(inputs)

# Convert to quantized version (modifies `model` in place)
tq.convert(model, inplace=True)

# Verify the quantized model by printing its module structure
print(model)