<a href="https://colab.research.google.com/github/OneFineStarstuff/OneFineStarstuff/blob/main/Model_Quantization_for_Reduced_Energy_Usage.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define a simple model
class MyModel(nn.Module):
    """Two-layer fully connected network producing 10 output scores.

    Layout: Linear(784 -> 128) -> ReLU -> Linear(128 -> 10). The final layer
    has no activation, so the outputs are raw scores (suitable as logits for
    ``nn.CrossEntropyLoss``).
    """

    def __init__(self):
        super().__init__()
        # Registration order (fc1 before fc2) fixes parameter order and the
        # sequence of RNG draws during weight initialization.
        self.fc1 = nn.Linear(784, 128)  # input features -> hidden units
        self.fc2 = nn.Linear(128, 10)   # hidden units -> output scores

    def forward(self, x):
        """Map a batch with 784 features per row to 10 scores per row."""
        hidden = torch.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return logits

# Initialize and train the float32 model.
model_fp32 = MyModel()
optimizer = optim.Adam(model_fp32.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# Dummy data: random features and random labels, for demonstration only
# (the model cannot learn anything meaningful from random labels).
X = torch.randn(1000, 784)
y = torch.randint(0, 10, (1000,))

# Wrap the tensors in a DataLoader for mini-batch training.
train_dataset = TensorDataset(X, y)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

# Train the model.
model_fp32.train()
for epoch in range(10):  # 10 epochs
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model_fp32(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# Post-training static quantization is performed in evaluation mode.
model_fp32.eval()

# Fuse modules (if applicable).
# Fusion (e.g. Conv2d+BatchNorm2d+ReLU, or Linear+ReLU) only works on
# submodules. MyModel applies ReLU functionally (torch.relu), so there is
# nothing to fuse here.
# model_fp32 = torch.quantization.fuse_modules(model_fp32, [['conv', 'bn', 'relu']])

# Eager-mode static quantization needs explicit float<->int8 boundaries.
# Without a QuantStub/DeQuantStub, the converted quantized Linear layers would
# receive float tensors and fail at inference time. QuantWrapper adds these
# stubs around the unmodified model.
model_to_quantize = torch.quantization.QuantWrapper(model_fp32)
model_to_quantize.qconfig = torch.quantization.default_qconfig

# Insert observers. inplace=False operates on a deep copy, so model_fp32
# remains a usable float baseline for the comparison below.
model_prepared = torch.quantization.prepare(model_to_quantize, inplace=False)

# Calibration: feed representative data through the observers so they record
# activation ranges used to pick quantization parameters.
with torch.no_grad():
    for inputs, _ in train_loader:
        model_prepared(inputs)

# Convert observed modules to their int8 quantized counterparts.
model_int8 = torch.quantization.convert(model_prepared, inplace=False)

# Check the quantized model structure.
print(model_int8)

# Sanity check: the int8 model must actually run on float input. Its
# predictions should largely agree with those of the float model.
with torch.no_grad():
    sample = X[:100]
    preds_fp32 = model_fp32(sample).argmax(dim=1)
    preds_int8 = model_int8(sample).argmax(dim=1)
agreement = (preds_fp32 == preds_int8).float().mean().item()
print(f"Top-1 agreement between fp32 and int8 models: {agreement:.2%}")