In [1]:
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from mx import mx_mapping, finalize_mx_specs

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class SimpleCNN(nn.Module):
    """Two conv/ReLU/max-pool stages followed by a two-layer MLP classifier.

    Sized for 32x32 inputs (e.g. CIFAR-10): the 3x3 convs use padding=1 so they
    keep the spatial size, and the two 2x2 poolings reduce 32x32 to 8x8, which
    is what ``fc1``'s ``64 * 8 * 8`` input features assume.

    Args:
        num_classes: number of output logits (default 10).
        in_channels: number of channels in the input images (default 3).
    """

    def __init__(self, num_classes=10, in_channels=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        """Map an (N, in_channels, 32, 32) batch to (N, num_classes) logits."""
        x = self.pool(F.relu(self.conv1(x)))  # -> (N, 32, 16, 16)
        x = self.pool(F.relu(self.conv2(x)))  # -> (N, 64, 8, 8)
        # Flatten per sample, keeping the batch dimension N. A wrongly sized
        # input then makes fc1 raise a shape error instead of being silently
        # reshaped into extra rows.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)

In [3]:
# Data preprocessing and loading.
# Each image is converted to a tensor, then every channel is normalized with
# mean 0.5 and std 0.5. Only the train call passes download=True; the test
# split is read from the same ./data root afterwards.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

train_data = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_data = datasets.CIFAR10(root='./data', train=False, transform=transform)

# Shuffle only the training batches; evaluation order stays fixed.
train_loader, test_loader = (
    DataLoader(train_data, batch_size=64, shuffle=True),
    DataLoader(test_data, batch_size=64, shuffle=False),
)

Files already downloaded and verified


In [4]:
# Training function
def train(model, train_loader, device, epochs=15, lr=0.001,
          best_model_path='best_model.pth'):
    """Train ``model`` with Adam + cross-entropy and checkpoint the best epoch.

    After every epoch the mean per-batch training loss is printed. Whenever it
    improves on the best so far, ``model.state_dict()`` is saved to
    ``best_model_path``, overwriting any existing file there.

    Args:
        model: module producing class logits; expected to already be on ``device``.
        train_loader: iterable of ``(data, target)`` batches that supports ``len()``.
        device: device each batch is moved to.
        epochs: number of passes over ``train_loader`` (default 15).
        lr: Adam learning rate (default 0.001).
        best_model_path: checkpoint destination (default ``'best_model.pth'``).

    Returns:
        ``best_model_path``.

    Raises:
        ValueError: if ``train_loader`` is empty or ``epochs < 1``. In either
            case no checkpoint would be written, and the returned path would
            point at a stale or missing file.
    """
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader is empty")
    if epochs < 1:
        raise ValueError("epochs must be >= 1")

    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_loss = float('inf')

    for epoch in range(epochs):
        epoch_loss = 0.0
        for data, target in train_loader:
            data = data.to(device)
            target = target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            # .item() yields a Python float, so the running sum holds no tensors.
            epoch_loss += loss.item()

        # Mean of per-batch losses (a smaller final batch counts the same).
        epoch_loss /= num_batches
        print(f"Epoch {epoch+1} done. Loss: {epoch_loss}")

        # NOTE(review): "best" is judged on training loss, not a held-out set,
        # so this is usually just the most recent epoch.
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), best_model_path)

    return best_model_path

In [5]:
# Evaluation function
def evaluate_average(model, test_loader, device, num_evaluations=10):
    """Average test accuracy and wall-clock time over repeated full passes.

    The model is switched to eval mode with gradients disabled for the passes.
    Its previous train/eval mode is restored afterwards, even on error.

    Args:
        model: module producing class logits; expected to already be on ``device``.
        test_loader: iterable of ``(data, target)`` batches.
        device: device each batch is moved to.
        num_evaluations: number of full passes to average over (default 10).

    Returns:
        ``(average_accuracy, average_time)``: the first value is percent correct,
        and the second is seconds per pass. The time includes data loading and
        host/device transfers, not just the forward pass.

    Raises:
        ValueError: if ``num_evaluations < 1`` or ``test_loader`` yields no samples.
    """
    if num_evaluations < 1:
        raise ValueError("num_evaluations must be >= 1")

    was_training = model.training
    # train() leaves the model in training mode. Layers such as dropout and
    # batch-norm behave differently there, so switch to eval for measurement.
    model.eval()
    total_accuracy = 0.0
    total_time = 0.0
    try:
        for _ in range(num_evaluations):
            start_time = time.perf_counter()
            correct = 0
            total = 0
            with torch.no_grad():
                for data, target in test_loader:
                    data = data.to(device)
                    target = target.to(device)
                    outputs = model(data)
                    predicted = outputs.argmax(dim=1)
                    total += target.size(0)
                    # .item() copies the count to the host, forcing the device
                    # to finish this batch before the timer moves on.
                    correct += (predicted == target).sum().item()
            if total == 0:
                raise ValueError("test_loader yielded no samples")
            total_accuracy += 100 * correct / total
            total_time += time.perf_counter() - start_time
    finally:
        model.train(was_training)

    return total_accuracy / num_evaluations, total_time / num_evaluations

In [6]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so the weight init of both SimpleCNN runs and the
# training-set shuffling are repeatable. DataLoader(shuffle=True) without an
# explicit generator draws from this RNG. Some CUDA kernels may still be
# nondeterministic on GPU.
SEED = 0
torch.manual_seed(SEED)

In [7]:
# Train the FP32 baseline.
# Called with default arguments, train() checkpoints to 'best_model.pth'. The
# MXFP4 run further down writes that same file, so move the FP32 checkpoint to
# its own path to keep it from being overwritten.
model_fp32 = SimpleCNN().to(device)
best_model_path = 'best_model_fp32.pth'
os.replace(train(model_fp32, train_loader, device), best_model_path)

Epoch 1 done. Loss: 1.3398408833367135
Epoch 2 done. Loss: 0.9525014681127065
Epoch 3 done. Loss: 0.7797165930728474
Epoch 4 done. Loss: 0.6401776306693207
Epoch 5 done. Loss: 0.5138734489907999
Epoch 6 done. Loss: 0.3950684686832111
Epoch 7 done. Loss: 0.29112268844262107
Epoch 8 done. Loss: 0.20943060401074418
Epoch 9 done. Loss: 0.14275533614723046
Epoch 10 done. Loss: 0.11132074313004837
Epoch 11 done. Loss: 0.08878128969079584
Epoch 12 done. Loss: 0.07690738045367533
Epoch 13 done. Loss: 0.07928652126911333
Epoch 14 done. Loss: 0.060265426246194485
Epoch 15 done. Loss: 0.052613999150024106


In [8]:
# Restore the FP32 weights from the lowest-training-loss epoch. map_location
# lets a checkpoint saved on GPU load even when `device` fell back to CPU.
model_fp32.load_state_dict(torch.load(best_model_path, map_location=device))

<All keys matched successfully>

In [9]:
# Evaluate the FP32 baseline. evaluate_average returns the mean test accuracy
# (%) and the mean seconds per pass over test_loader, using its default
# number of passes.
average_accuracy_fp32, average_time_fp32 = evaluate_average(
    model_fp32,
    test_loader,
    device,
)
print(f"FP32 - Average Accuracy: {average_accuracy_fp32}%, Average Inference Time: {average_time_fp32} seconds")

FP32 - Average Accuracy: 71.56999999999998%, Average Inference Time: 1.7250673532485963 seconds


In [10]:
# Set up MXFP4 quantization via the `mx` library.
# Judging by the key names, w_/a_ select weight/activation element formats and
# the *_bp* keys select backward-pass formats. Confirm against the mx docs.
# NOTE(review): 'a_elem_format_bp_os' is 'int8', not fp4. Confirm this is intended.
mx_specs = {
    'w_elem_format': 'fp4_e2m1',
    'a_elem_format': 'fp4_e2m1',
    'w_elem_format_bp': 'fp4_e2m1',
    'a_elem_format_bp_ex': 'fp4_e2m1',
    'a_elem_format_bp_os': 'int8',
    'scale_bits': 8,
    'block_size': 32,
    # Custom CUDA kernels are assumed to need a GPU. Use the same availability
    # check that chose `device`, so CPU-only runs don't request them.
    'custom_cuda': torch.cuda.is_available(),
    'bfloat': 16,
    'quantize_backprop': True,
}

mx_specs = finalize_mx_specs(mx_specs)
# Presumably patches PyTorch ops globally for the rest of the session, which
# would affect every model used afterwards, including model_fp32. Confirm.
mx_mapping.inject_pyt_ops(mx_specs)

In [11]:
# Train a fresh SimpleCNN from scratch after mx_mapping.inject_pyt_ops(...).
# It does not start from the FP32 weights; to fine-tune from them instead,
# load the FP32 checkpoint into model_mxfp4 before calling train().
# train() with default arguments writes 'best_model.pth', so move the result to
# its own file to keep the FP32 and MXFP4 checkpoints from colliding.
model_mxfp4 = SimpleCNN().to(device)
best_model_mxfp4_model_path = 'best_model_mxfp4.pth'
os.replace(train(model_mxfp4, train_loader, device), best_model_mxfp4_model_path)

Epoch 1 done. Loss: 1.3566702778077186
Epoch 2 done. Loss: 1.0043991924551747
Epoch 3 done. Loss: 0.84799560622486
Epoch 4 done. Loss: 0.7216999481248734
Epoch 5 done. Loss: 0.6086551892711684
Epoch 6 done. Loss: 0.49725582181950057
Epoch 7 done. Loss: 0.39372597097435874
Epoch 8 done. Loss: 0.3075755459763815
Epoch 9 done. Loss: 0.24632705811916106
Epoch 10 done. Loss: 0.19897673609654618
Epoch 11 done. Loss: 0.1632521321015704
Epoch 12 done. Loss: 0.14085606753807087
Epoch 13 done. Loss: 0.12328241868456706
Epoch 14 done. Loss: 0.12319542178074303
Epoch 15 done. Loss: 0.11006366661862921


In [12]:
# Restore the MXFP4 run's lowest-training-loss weights. map_location lets a
# checkpoint saved on GPU load even when `device` fell back to CPU.
model_mxfp4.load_state_dict(torch.load(best_model_mxfp4_model_path, map_location=device))

<All keys matched successfully>

In [13]:
# Evaluate the MXFP4-trained model with the same routine as the FP32 baseline,
# so the accuracy (%) and seconds-per-pass figures are directly comparable.
average_accuracy_mxfp4, average_time_mxfp4 = evaluate_average(
    model_mxfp4,
    test_loader,
    device,
)
print(f"MXFP4 - Average Accuracy: {average_accuracy_mxfp4}%, Average Inference Time: {average_time_mxfp4} seconds")

MXFP4 - Average Accuracy: 68.98%, Average Inference Time: 2.143258571624756 seconds
