# In [None]:
import torch
from torchvision import datasets, transforms
from torch import nn, optim

from dmsq_quantizer import DMSQQuantizer

# Data loading for MNIST.
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them
# to [-1, 1] via (x - 0.5) / 0.5.
DATA_ROOT = './data'
NUM_TRAIN_SAMPLES = 100  # size of the training subset (the first N samples)
BATCH_SIZE = 64

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# Load the full training split; download=True fetches it into DATA_ROOT if missing.
full_train_dataset = datasets.MNIST(DATA_ROOT, train=True, download=True, transform=transform)

# Keep only the first NUM_TRAIN_SAMPLES training samples so the demo trains quickly.
# min() keeps the indices valid if NUM_TRAIN_SAMPLES is raised past the split size.
subset_indices = range(min(NUM_TRAIN_SAMPLES, len(full_train_dataset)))
train_subset = torch.utils.data.Subset(full_train_dataset, subset_indices)

train_loader = torch.utils.data.DataLoader(train_subset, batch_size=BATCH_SIZE, shuffle=True)

# The full test split is used for evaluation. download=True keeps this
# independent of whether the training-split download above already fetched it.
test_dataset = datasets.MNIST(DATA_ROOT, train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Pick the compute device: CUDA if present, else Apple-Silicon MPS, else CPU.
# The getattr guard keeps this working on torch builds without torch.backends.mps.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')

class SimpleCNN(nn.Module):
    """Minimal CNN: one 3x3 convolution, a ReLU, and a linear classifier.

    With the defaults it expects MNIST-shaped input ``(batch, 1, 28, 28)``.
    The unpadded 3x3 convolution shrinks each spatial side by 2, so the
    classifier sees ``32 * 26 * 26`` features per sample.

    Args:
        num_classes: number of output logits (default 10).
        in_channels: channels of the input images (default 1).
        image_size: side length of the square input images (default 28).
    """

    def __init__(self, num_classes=10, in_channels=1, image_size=28):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3)
        self.relu = nn.ReLU()
        # 3x3 kernel, stride 1, no padding -> each spatial side loses 2 pixels.
        conv_side = image_size - 2
        self.fc = nn.Linear(conv_side * conv_side * 32, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for ``x``."""
        x = self.relu(self.conv1(x))
        # flatten(1) keeps the batch dimension intact. A wrongly sized input
        # then fails in self.fc instead of being silently re-split into extra
        # batch rows, as view(-1, ...) would do.
        x = torch.flatten(x, 1)
        return self.fc(x)

model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model for epochs numbered 1..num_epochs (inclusive).
num_epochs = 100
print("Training the model...")
for epoch in range(1, num_epochs + 1):
    model.train()
    total_loss = 0.0
    num_samples = 0
    for data, target in train_loader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss (default reduction='mean') returns the batch mean.
        # Re-weighting by batch size makes the epoch figure a true per-sample
        # average even when the last batch is short.
        batch_size = data.size(0)
        total_loss += loss.item() * batch_size
        num_samples += batch_size
    print(f"Epoch {epoch}, Loss: {total_loss / max(num_samples, 1):.4f}")

# Evaluate the model
def test_model(model, loader, device=device):
    """Return the top-1 classification accuracy of ``model`` over ``loader``.

    Puts ``model`` in eval mode (and leaves it there). Runs every batch under
    ``torch.no_grad()`` and counts predictions whose argmax matches the target.

    Args:
        model: module mapping a batch of inputs to per-class scores.
        loader: iterable (e.g. a DataLoader) yielding ``(data, target)`` pairs.
        device: where each batch is moved before the forward pass. Defaults to
            the notebook-level ``device`` captured when this function is defined.

    Returns:
        Fraction of correctly classified samples among those yielded, in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            pred = model(data).argmax(dim=1)
            correct += (pred == target).sum().item()
            # Count samples actually evaluated so samplers / drop_last are handled.
            seen += target.size(0)
    if seen == 0:
        raise ValueError("test_model: loader yielded no samples; accuracy is undefined")
    return correct / seen

# Baseline: accuracy of the trained, still full-precision model on the test
# split, to compare against the quantized accuracy reported at the end.
accuracy = test_model(model, test_loader)
print(f"Initial Test Accuracy: {accuracy:.4f}")

# Per-layer bit widths handed to DMSQQuantizer. 'conv1' and 'fc' match the
# attribute names of SimpleCNN's layers. 'default' presumably applies to any
# layer not listed explicitly (confirm against DMSQQuantizer).
# NOTE: every entry is currently 8 bits, i.e. uniform 8-bit quantization;
# change individual entries to experiment with mixed precision.
precision_map = {
    'conv1': 8,
    'fc': 8,
    'default': 8,
}

# Wrap the trained model in the DMSQ quantizer, passing the per-layer bit
# widths from `precision_map` and the notebook-level `device`.
quantizer = DMSQQuantizer(model, precision_map, device=device)

# Sensitivity analysis: a single random tensor shaped like one MNIST image
# (batch 1, 1 channel, 28x28) is passed together with the loss criterion.
# NOTE(review): this probe is random noise, not real MNIST data. If the analysis
# is data-driven, its per-layer ranking may not reflect real inputs. Consider a
# batch from train_loader (confirm what DMSQQuantizer expects as input).
dummy_input = torch.randn(1, 1, 28, 28).to(device)
print("Performing sensitivity analysis...")
quantizer.layer_sensitivity_analysis(dummy_input, criterion)

# Workload optimization with arbitrary example targets, reusing the same probe input.
# NOTE(review): the units of target_memory (1e9) and target_latency (50) are not
# visible here. Presumably bytes and milliseconds — confirm in DMSQQuantizer.
print("Optimizing workload...")
quantizer.workload_optimization(target_memory=1e9, target_latency=50, inputs=dummy_input)

# Apply quantization.
print("Applying quantization...")
quantizer.quantize_model()

# Re-evaluate using the same `model` object that was evaluated before quantization.
# NOTE(review): this assumes quantize_model() modifies `model` in place. If it
# returns a new model instead, the accuracy below is for the unquantized model
# — TODO confirm.
print("Testing the quantized model...")
quantized_accuracy = test_model(model, test_loader)
print(f"Quantized Test Accuracy: {quantized_accuracy:.4f}")
