In [12]:
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim.adam
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from mx import mx_mapping, finalize_mx_specs


In [13]:
class SimpleCNN(nn.Module):
    """Small two-conv-layer CNN classifier for 3-channel 32x32 images (CIFAR-10 here).

    Architecture: [conv3x3 -> ReLU -> maxpool 2x2] x 2, then two fully
    connected layers. The ``64 * 8 * 8`` input size of ``fc1`` assumes 32x32
    inputs: each 2x2 max-pool halves the spatial size (32 -> 16 -> 8), and
    ``conv2`` produces 64 channels.

    Args:
        num_classes: Number of output logits (default 10, as for CIFAR-10).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map an ``(N, 3, 32, 32)`` batch to ``(N, num_classes)`` logits."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten all dims except the batch dim. Unlike view(-1, 64 * 8 * 8),
        # this never silently changes the batch size when the input is not
        # 32x32; fc1 then raises a clear shape-mismatch error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [14]:
# Dataset location and batch size shared by both CIFAR-10 splits.
DATA_ROOT = './data'
BATCH_SIZE = 64

# ToTensor maps pixels to [0, 1]; Normalize with mean = std = 0.5 per channel
# then computes (x - 0.5) / 0.5, i.e. rescales every channel to [-1, 1].
CHANNEL_MEAN = (0.5, 0.5, 0.5)
CHANNEL_STD = (0.5, 0.5, 0.5)
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(CHANNEL_MEAN, CHANNEL_STD)]
)

# Only the training split requests a download; the test split relies on the
# files already present under DATA_ROOT.
train_data = datasets.CIFAR10(root=DATA_ROOT, train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root=DATA_ROOT, train=False, transform=transform)

# Shuffle training batches each epoch; keep evaluation order fixed.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

Files already downloaded and verified


In [15]:
def train(model, train_loader, device, epochs=10, lr=0.001):
    """Train ``model`` in place with Adam and cross-entropy loss.

    Args:
        model: Module producing class logits. Its parameters are expected to
            already live on ``device`` (the caller moves the model there).
        train_loader: Iterable of ``(data, target)`` batches, e.g. a DataLoader.
        device: Device each batch is moved to before the forward pass.
        epochs: Number of full passes over ``train_loader`` (default 10, the
            value that was previously hard-coded).
        lr: Adam learning rate (default 0.001, previously hard-coded).

    Prints one ``"Epoch N done."`` line per epoch and returns None.
    """
    model.train()
    # A fresh optimizer is built on every call, so calling train() again
    # restarts Adam's moment estimates instead of resuming them.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        for data, target in train_loader:
            data = data.to(device)
            target = target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1} done.")

In [16]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [17]:
# Set up MXFP4 quantization through the `mx` library.
# NOTE(review): the per-key comments below are read off the key names and
# values; confirm them against the mx library's documentation.
mx_specs = {
    'w_elem_format': 'fp4_e2m1',  # weight element format: FP4 (E2M1)
    'a_elem_format': 'fp4_e2m1',  # activation element format: FP4 (E2M1)
    'scale_bits': 8,              # presumably bits of each block's shared scale
    'block_size': 32,             # presumably elements sharing one scale
    # The custom kernels are CUDA kernels (per the flag's name). The device
    # cell falls back to CPU when no GPU is visible, so only request them
    # when a CUDA device was actually selected. Forcing True on a CPU run
    # would presumably fail at the first quantized op.
    'custom_cuda': device.type == 'cuda',
    'bfloat': 16,                 # presumably bfloat16 for non-MX ops; confirm
    'quantize_backprop': True,    # presumably also quantizes the backward pass
}

mx_specs = finalize_mx_specs(mx_specs)
# NOTE(review): inject_pyt_ops presumably patches PyTorch ops globally for the
# rest of the kernel session. Any model built after this cell (not only
# model_mxfp4) would then run quantized; confirm before adding an FP32 baseline.
mx_mapping.inject_pyt_ops(mx_specs)

In [18]:
# Seed PyTorch's global RNG immediately before the stochastic steps. This
# covers weight initialisation and, since the loader cell passes no explicit
# generator, the DataLoader's shuffling. Re-running this cell then
# reproduces the same model and accuracy.
# NOTE(review): some GPU kernels (cuDNN, possibly the mx custom CUDA ops) can
# still be nondeterministic; see PyTorch's reproducibility notes if exact
# bitwise repeatability is required.
SEED = 0
torch.manual_seed(SEED)
model_mxfp4 = SimpleCNN().to(device)
train(model_mxfp4, train_loader, device)

Epoch 1 done.
Epoch 2 done.
Epoch 3 done.
Epoch 4 done.
Epoch 5 done.
Epoch 6 done.
Epoch 7 done.
Epoch 8 done.
Epoch 9 done.
Epoch 10 done.


In [19]:
def _evaluate_once(model, test_loader, device):
    """Run one full pass over ``test_loader``; return ``(accuracy_percent, seconds)``."""
    start = time.perf_counter()  # monotonic, high-resolution clock for intervals
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = data.to(device)
            target = target.to(device)
            predicted = model(data).argmax(dim=1)
            total += target.size(0)
            # .item() forces a device-to-host copy, which waits for queued GPU
            # work, so the elapsed time below includes the actual compute.
            correct += (predicted == target).sum().item()
    elapsed = time.perf_counter() - start
    if total == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    return 100 * correct / total, elapsed


def evaluate_average(model, test_loader, device, num_evaluations=10):
    """Measure test accuracy and per-pass wall-clock time, averaged over repeated passes.

    Each pass iterates the whole ``test_loader`` with the model in eval mode
    and gradients disabled. For a deterministic model and an unshuffled
    loader, accuracy is the same on every pass, so repetition mainly smooths
    the timing.

    Args:
        model: Module returning class logits of shape ``(batch, num_classes)``.
        test_loader: Iterable of ``(data, target)`` batches.
        device: Device each batch is moved to before the forward pass.
        num_evaluations: Number of full passes to average over (must be >= 1).

    Returns:
        ``(average_accuracy, average_time)``: accuracy in percent and mean
        seconds per pass. The time includes loader iteration and host-to-device
        copies, not just the forward pass.

    Raises:
        ValueError: if ``num_evaluations < 1`` or the loader yields no samples.
    """
    if num_evaluations < 1:
        raise ValueError(f"num_evaluations must be >= 1, got {num_evaluations}")

    # Evaluate in eval mode (matters for dropout / batch-norm layers), then
    # restore whatever mode the caller had, even if a pass raises.
    was_training = model.training
    model.eval()
    try:
        runs = [_evaluate_once(model, test_loader, device) for _ in range(num_evaluations)]
    finally:
        model.train(was_training)

    average_accuracy = sum(accuracy for accuracy, _ in runs) / num_evaluations
    average_time = sum(seconds for _, seconds in runs) / num_evaluations
    return average_accuracy, average_time

In [20]:
# Score the MXFP4-trained model on the test split, averaged over 10 passes.
average_accuracy_mxfp4, average_time_mxfp4 = evaluate_average(
    model_mxfp4,
    test_loader,
    device,
    num_evaluations=10,
)
print(
    f"MXFP4 - Average Accuracy: {average_accuracy_mxfp4}%, "
    f"Average Inference Time: {average_time_mxfp4} seconds"
)

MXFP4 - Average Accuracy: 57.67%, Average Inference Time: 2.3505164623260497 seconds
