In [1]:
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from mx import mx_mapping, finalize_mx_specs

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class SimpleCNN(nn.Module):
    """Small convolutional classifier for 3-channel 32x32 images.

    Architecture: two 3x3 convolutions (padding 1, so spatial size is kept),
    each followed by ReLU and 2x2 max pooling, then two fully connected
    layers (64*8*8 -> 256 -> 10). The two pooling steps halve the spatial
    size twice, which is why the first linear layer expects 8x8 feature maps
    and therefore 32x32 inputs.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        """Return the 10 class logits for a batch of images.

        Args:
            x: tensor of shape (batch, 3, 32, 32).

        Returns:
            Tensor of shape (batch, 10) with unnormalised class scores.

        Raises:
            RuntimeError: if the spatial size of ``x`` does not produce the
                64*8*8 features that ``fc1`` expects.
        """
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten everything except the batch axis. Reshaping with a free
        # batch dimension (view(-1, 4096)) would silently fold a wrong-sized
        # input into extra "samples" instead of failing.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [3]:
# Data preprocessing and loading
# One preprocessing pipeline for both splits: convert to tensor, then
# normalise every RGB channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# CIFAR-10 splits stored under ./data; only the training split requests a download.
train_data = datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform
)
test_data = datasets.CIFAR10(root='./data', train=False, transform=transform)

# Mini-batches of 64 samples; only the training loader shuffles.
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64, shuffle=False)

Files already downloaded and verified


In [18]:
# Training function
def train(model, train_loader, device, epochs=15, lr=0.001,
          best_model_path='best_model.pth'):
    """Train ``model`` with Adam and cross-entropy, checkpointing the best epoch.

    After every epoch the mean batch loss is printed; whenever it is lower
    than any previous epoch's, the model's ``state_dict`` is written to
    ``best_model_path``.

    Args:
        model: ``nn.Module`` to optimise in place (already on ``device``).
        train_loader: sized iterable yielding ``(data, target)`` batches.
        device: device the batches are moved to before the forward pass.
        epochs: number of passes over ``train_loader`` (default 15).
        lr: Adam learning rate (default 0.001).
        best_model_path: file the best checkpoint is saved to. Pass a
            distinct path per model, otherwise successive runs overwrite
            each other's checkpoint.

    Returns:
        ``best_model_path``; the file is guaranteed to hold weights of
        ``model`` from this call.

    Raises:
        ZeroDivisionError: if ``train_loader`` is empty.
    """
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    best_loss = float('inf')

    for epoch in range(epochs):
        epoch_loss = 0.0
        for data, target in train_loader:
            data = data.to(device)
            target = target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        epoch_loss /= len(train_loader)
        print(f"Epoch {epoch+1} done. Loss: {epoch_loss}")

        # Save the best model
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            torch.save(model.state_dict(), best_model_path)

    # No epoch produced a finite, improving loss (e.g. epochs == 0 or the
    # loss was NaN/inf throughout): still write the current weights so the
    # returned path never refers to a stale file left by an earlier run.
    if best_loss == float('inf'):
        torch.save(model.state_dict(), best_model_path)

    return best_model_path

In [5]:
# Evaluation function
def evaluate_average(model, test_loader, device, num_evaluations=10):
    """Measure classification accuracy and wall-clock time, averaged over repeats.

    The whole ``test_loader`` is run ``num_evaluations`` times with gradients
    disabled and the model in evaluation mode; the model's previous
    train/eval mode is restored before returning.

    Args:
        model: ``nn.Module`` producing per-class scores of shape (batch, classes).
        test_loader: iterable yielding ``(data, target)`` batches.
        device: device the batches are moved to before the forward pass.
        num_evaluations: number of full passes to average over (default 10).

    Returns:
        Tuple ``(average_accuracy, average_time)``: accuracy in percent and
        seconds per full pass, each averaged over the repeats.

    Raises:
        ZeroDivisionError: if ``test_loader`` yields no samples or
            ``num_evaluations`` is 0.
    """
    # Score in eval mode so layers such as Dropout/BatchNorm behave
    # deterministically; remember the caller's mode to put it back afterwards.
    was_training = model.training
    model.eval()

    total_accuracy = 0
    total_time = 0
    try:
        for _ in range(num_evaluations):
            # perf_counter is monotonic, unlike the wall clock time.time().
            start_time = time.perf_counter()
            correct = 0
            total = 0
            with torch.no_grad():
                for data, target in test_loader:
                    data = data.to(device)
                    target = target.to(device)
                    outputs = model(data)
                    predicted = outputs.argmax(dim=1)
                    total += target.size(0)
                    # .item() copies the count to the host, which also waits
                    # for the batch to finish on asynchronous devices.
                    correct += (predicted == target).sum().item()
            accuracy = 100 * correct / total
            inference_time = time.perf_counter() - start_time

            total_accuracy += accuracy
            total_time += inference_time
    finally:
        model.train(was_training)

    average_accuracy = total_accuracy / num_evaluations
    average_time = total_time / num_evaluations
    return average_accuracy, average_time

In [6]:
# Use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [7]:
# Train the model in FP32
# A fresh SimpleCNN is moved to `device` and trained; train() returns the
# path of the checkpoint file it wrote.
# NOTE(review): the recorded output below lists 10 epochs, while train() as
# defined in this file loops 15 times and its cell has a later execution
# count ([18] vs [7]) — the output is presumably stale; re-run from a fresh
# kernel to confirm.
model_fp32 = SimpleCNN().to(device)
best_model_path = train(model_fp32, train_loader, device)

Epoch 1 done. Loss: 1.3176605977365732
Epoch 2 done. Loss: 0.9313664588019671
Epoch 3 done. Loss: 0.7648372553346102
Epoch 4 done. Loss: 0.6300224576078718
Epoch 5 done. Loss: 0.5067216605900804
Epoch 6 done. Loss: 0.3874464797051361
Epoch 7 done. Loss: 0.2889585506904613
Epoch 8 done. Loss: 0.20265076383757774
Epoch 9 done. Loss: 0.14291796909021143
Epoch 10 done. Loss: 0.11030141576586286


In [8]:
# Restore the weights stored in the checkpoint file returned by train().
model_fp32.load_state_dict(torch.load(best_model_path))

<All keys matched successfully>

In [9]:
# Evaluate in FP32
# evaluate_average returns test accuracy (percent) and seconds per pass,
# both averaged over its repeated passes through test_loader.
average_accuracy_fp32, average_time_fp32 = evaluate_average(model_fp32, test_loader, device)
print(f"FP32 - Average Accuracy: {average_accuracy_fp32}%, Average Inference Time: {average_time_fp32} seconds")

FP32 - Average Accuracy: 71.86%, Average Inference Time: 2.2540157556533815 seconds


In [39]:
# Setup MXFP4
# NOTE(review): the per-key remarks below are inferred from the key names
# only — confirm each against the mx library documentation.
mx_specs = {
    'w_elem_format': 'fp4_e2m1',         # presumably weight element format, forward pass
    'a_elem_format': 'fp4_e2m1',         # presumably activation element format, forward pass
    'w_elem_format_bp':'fp4_e2m1',       # presumably weight element format, backward pass
    'a_elem_format_bp_ex':'fp4_e2m1',    # presumably a backward-pass activation format — TODO confirm
    'a_elem_format_bp_os': 'int8',       # NOTE(review): int8 here, unlike the fp4 formats above — intentional?
    'scale_bits': 8,
    'block_size': 32,
    # NOTE(review): enabled unconditionally although `device` may fall back
    # to the CPU — verify this works on a machine without CUDA.
    'custom_cuda': True,
    'bfloat': 16,
    'quantize_backprop': True
}

# The raw dict is replaced by whatever finalize_mx_specs returns for it.
mx_specs = finalize_mx_specs(mx_specs)
# NOTE(review): inject_pyt_ops presumably patches PyTorch ops process-wide,
# in which case everything run after this cell (including any re-evaluation
# of model_fp32) would use the MX ops — confirm.
mx_mapping.inject_pyt_ops(mx_specs)

In [40]:
# Train a fresh SimpleCNN from scratch, after the MX setup cell has run.
# The FP32 weights are NOT loaded: the line that would do so is disabled.
model_mxfp4 = SimpleCNN().to(device)
# model_mxfp4.load_state_dict(torch.load(best_model_path))
# NOTE(review): train() is called without a separate checkpoint path and
# saves to 'best_model.pth', so this run overwrites the FP32 checkpoint.
best_model_mxfp4_model_path = train(model_mxfp4, train_loader, device)

Epoch 1 done. Loss: 1.3548200424674832
Epoch 2 done. Loss: 0.9949734219352303
Epoch 3 done. Loss: 0.8371523920913486
Epoch 4 done. Loss: 0.7198510145592263
Epoch 5 done. Loss: 0.607438024390689
Epoch 6 done. Loss: 0.5043549986408494
Epoch 7 done. Loss: 0.4094330814221631
Epoch 8 done. Loss: 0.31546328812265945
Epoch 9 done. Loss: 0.2577301764198581
Epoch 10 done. Loss: 0.20743465050578574
Epoch 11 done. Loss: 0.17977961066805417
Epoch 12 done. Loss: 0.14907906397872264
Epoch 13 done. Loss: 0.13879902921545575
Epoch 14 done. Loss: 0.1274116559351897
Epoch 15 done. Loss: 0.11853514945306018


In [41]:
# Restore the weights stored in the checkpoint file returned by the MXFP4 training run.
model_mxfp4.load_state_dict(torch.load(best_model_mxfp4_model_path))

<All keys matched successfully>

In [42]:
# Evaluate in FP4
# Same measurement as the FP32 run: accuracy (percent) and seconds per pass,
# averaged over evaluate_average's repeated passes through test_loader.
average_accuracy_mxfp4, average_time_mxfp4 = evaluate_average(model_mxfp4, test_loader, device)
print(f"MXFP4 - Average Accuracy: {average_accuracy_mxfp4}%, Average Inference Time: {average_time_mxfp4} seconds")

MXFP4 - Average Accuracy: 68.33000000000001%, Average Inference Time: 2.59999737739563 seconds
