In [None]:
import copy
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Subset

The following code cell is kept for reference only: we do not use `torch.quantization`, or quantization in general. Quantization would let us go to much lower precision (e.g., int8); see [DoReFa](https://arxiv.org/pdf/1606.06160) for an example.

We also don't use [Automatic Mixed Precision](https://pytorch.org/docs/stable/amp.html).

For now, we just go from float64 down to float16 as a first step.


In [None]:
# Reference only: inspect (and optionally select) the quantized backend.
# Nothing in this notebook actually performs quantization.
print("Supported engines:", torch.backends.quantized.supported_engines)
print("Currently active quantized engine:", torch.backends.quantized.engine)
# Assigning an engine that this PyTorch build does not support raises an error,
# so only switch to qnnpack when it is actually available.
if "qnnpack" in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = "qnnpack"
print("Now active quantized engine:", torch.backends.quantized.engine)

# Device used by every training / evaluation cell below.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Experiment configuration shared by all cells below.
SEED = 42  # seeds the fold split (numpy) and weight init / loader shuffling (torch)
EPOCHS = 3
K_FOLDS = 5
MOMENTUM = 0.9
BATCH_SIZE = 64
LEARNING_RATE = 0.01
# Misspelled alias kept because later cells still reference this name.
LEARNING_REATE = LEARNING_RATE
DTYPES = [torch.float64, torch.float32, torch.float16]

# Seed both RNGs so Restart & Run All reproduces the same folds and initial weights.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
class MLP(nn.Module):
    """Three-layer fully connected classifier for 28x28 MNIST digits.

    Architecture: 784 -> 128 -> 64 -> 10 with ReLU between the linear layers.
    The output is 10 raw logits (no softmax), intended for ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        # Linear layers and ReLUs are separate modules so they could be fused
        # if someone later implements quantization.
        # Linear = fully connected layer, y = x A^T + b; ReLU adds non-linearity.
        self.fc1 = nn.Linear(784, 128)  # MNIST: 28 * 28 = 784 input pixels
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(128, 64)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(64, 10)  # one logit per digit class

    def forward(self, x):
        """Map a batch of images to class logits.

        Args:
            x: tensor whose first dimension is the batch and whose remaining
               dimensions hold 784 values per sample, e.g. [B, 1, 28, 28].

        Returns:
            Logits of shape [B, 10].
        """
        # flatten (unlike view) also accepts non-contiguous and empty batches.
        x = torch.flatten(x, start_dim=1)  # [B, 1, 28, 28] -> [B, 784]
        x = self.relu1(self.fc1(x))
        x = self.relu2(self.fc2(x))
        return self.fc3(x)


def trainEpoch(model, loader, criterion, optimizer, device="cpu", dtype=torch.float32):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: network to train; it is switched to train mode.
        loader: iterable of (images, labels) batches that supports ``len()``.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters, stepped once per batch.
        device: device the batches are moved to.
        dtype: floating dtype the images are cast to; it should match the
            model's parameter dtype. Labels keep their integer dtype.

    Returns:
        The average of the per-batch loss values.

    Raises:
        ValueError: if ``loader`` has no batches.
    """
    num_batches = len(loader)
    if num_batches == 0:
        raise ValueError("trainEpoch: loader is empty, nothing to train on")

    model.train()
    running_loss = 0.0
    for images, labels in loader:
        images = images.to(device, dtype=dtype)  # adapt input data to the precision
        labels = labels.to(device)  # labels stay integer class indices

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() yields a Python float, so the running sum is accumulated in
        # double precision even when the model trains in float16.
        running_loss += loss.item()
    return running_loss / num_batches


def evaluate(model, loader, device="cpu", dtype=torch.float32):
    """Return the top-1 accuracy (in percent) of ``model`` on ``loader``.

    Args:
        model: network to evaluate; it is switched to eval mode.
        loader: iterable of (images, labels) batches.
        device: device the batches are moved to.
        dtype: floating dtype the images are cast to; it should match the
            model's parameter dtype.

    Returns:
        ``100 * correct / total`` as a float.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, dtype=dtype)
            labels = labels.to(device)
            predicted = model(images).argmax(dim=1)  # index of the largest logit
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")
    return 100.0 * correct / total

In [None]:
# Normalize with the commonly quoted MNIST training-set mean (0.1307) and std (0.3081):
# https://stackoverflow.com/questions/63746182/correct-way-of-normalizing-and-scaling-the-mnist-dataset
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
# MNIST training split (train=True); the k-fold experiments below draw both
# their training and validation folds from this dataset.
full_dataset = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform,
)
# Same split and transform as full_dataset, so reuse that object instead of
# constructing an identical second dataset.
train_set = full_dataset
# MNIST test split (train=False).
# NOTE(review): the k-fold cells below reassign train_loader / val_loader before
# using them, so this test split is not used by the experiments as written.
val_set = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=transform,
)
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
val_loader = torch.utils.data.DataLoader(
    val_set,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

## Training from scratch with decreasing precision

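This should answer the question: _How does changing the precision during training affect the final model’s accuracy and performance?_ For each dtype, a fresh model is trained and evaluated with its weights and inputs in that precision. We use `K_FOLDS`-fold (5-fold) cross-validation on the MNIST training split. The reported accuracy is therefore the mean accuracy over the validation folds, not the accuracy on the MNIST test split.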


In [None]:
# K-fold cross-validation on the MNIST training split, repeated for every dtype
# in DTYPES: each fold trains a fresh model whose weights and inputs use that dtype.
fold_size = len(full_dataset) // K_FOLDS
indices = list(range(len(full_dataset)))
# One shuffled split, shared by every dtype (and reused by the inference-only experiment).
np.random.shuffle(indices)

results = {}  # str(dtype) -> mean fold accuracy in %
timings = {}  # str(dtype) -> {"avg_train_time": s, "avg_eval_time": s}

for current_dtype in DTYPES:
    print("\n==============================================")
    print(f"Training with dtype = {current_dtype}")
    fold_accuracies = []
    train_times = []
    eval_times = []

    for fold in range(K_FOLDS):
        print(f"\n--- Fold {fold+1}/{K_FOLDS} ---")

        # Fold `fold` is held out for validation; all remaining samples train.
        val_indices = indices[fold * fold_size : (fold + 1) * fold_size]
        train_indices = indices[: fold * fold_size] + indices[(fold + 1) * fold_size :]

        train_subset = Subset(full_dataset, train_indices)
        val_subset = Subset(full_dataset, val_indices)

        train_loader = torch.utils.data.DataLoader(
            train_subset, batch_size=BATCH_SIZE, shuffle=True
        )
        val_loader = torch.utils.data.DataLoader(
            val_subset, batch_size=BATCH_SIZE, shuffle=False
        )

        model = MLP().to(device, dtype=current_dtype)
        # No dtype cast is needed for the loss: CrossEntropyLoss() without a class
        # `weight` tensor holds no parameters or buffers, so .to() would be a no-op;
        # it computes in the dtype of the logits it receives.
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=LEARNING_REATE, momentum=MOMENTUM)

        # Only the optimisation passes count as training time; the per-epoch
        # validation pass (progress reporting) is excluded from the measurement.
        train_duration = 0.0
        for epoch in range(EPOCHS):
            epoch_start = time.perf_counter()
            loss_val = trainEpoch(
                model,
                train_loader,
                criterion,
                optimizer,
                device=device,
                dtype=current_dtype,
            )
            if device == "cuda":
                torch.cuda.synchronize()  # wait for queued GPU work before reading the clock
            train_duration += time.perf_counter() - epoch_start

            acc_val = evaluate(model, val_loader, device=device, dtype=current_dtype)
            print(
                f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss_val:.4f}, Val Acc: {acc_val:.2f}%"
            )
        train_times.append(train_duration)

        eval_start = time.perf_counter()
        final_acc = evaluate(model, val_loader, device=device, dtype=current_dtype)
        eval_duration = time.perf_counter() - eval_start
        eval_times.append(eval_duration)

        fold_accuracies.append(final_acc)
        print(f"Fold {fold+1} accuracy: {final_acc:.2f}%")
        print(
            f"Training time: {train_duration:.2f}s | Inference time: {eval_duration:.2f}s"
        )

    avg_acc = sum(fold_accuracies) / K_FOLDS
    avg_train_time = sum(train_times) / K_FOLDS
    avg_eval_time = sum(eval_times) / K_FOLDS

    results[str(current_dtype)] = avg_acc
    timings[str(current_dtype)] = {
        "avg_train_time": avg_train_time,
        "avg_eval_time": avg_eval_time,
    }

print(f"\nAveraged results over {K_FOLDS} folds:")
for dtype_str, acc_val in results.items():
    print(f"{dtype_str}: {acc_val:.2f}%")

print("\nAverage timing per dtype:")
for dtype_str, t in timings.items():
    print(
        f"{dtype_str}: Train = {t['avg_train_time']:.2f}s, Eval = {t['avg_eval_time']:.2f}s"
    )

In [None]:
# Mean k-fold accuracy per training precision.
# NOTE(review): these numbers are hard-coded (presumably copied from an earlier
# run of the experiment above) and do not change when it is re-run. They live in
# their own variable so the computed `results` dict is not overwritten.
recorded_accuracies = {
    "float64": 96.82,
    "float32": 96.76,
    "float16": 96.74,
}
df = pd.DataFrame(list(recorded_accuracies.items()), columns=["Precision", "Accuracy"])
fig, ax = plt.subplots(figsize=(8, 5), dpi=300)
ax.bar(df["Precision"], df["Accuracy"], color="#b9a4c6")
ax.set_title("Model Accuracies (Training and Inference)")
ax.set_xlabel("Precision Type")
ax.set_ylabel("Accuracy (%)")
# The differences are tiny, so the y-axis is zoomed in; the limits come from the
# data (with 0.25 pp padding) so no bar is clipped if the numbers change.
ax.set_ylim(df["Accuracy"].min() - 0.25, df["Accuracy"].max() + 0.25)
ax.grid(axis="y", linestyle="--", alpha=0.7)
fig.tight_layout()
plt.show()

In [None]:
# Average per-fold training and inference time per training precision.
# NOTE(review): hard-coded values (presumably copied from an earlier run); they
# do not change when the experiment is re-run. They live in their own variable so
# the computed `timings` dict (keys "avg_train_time" / "avg_eval_time") is not
# overwritten by a dict with a different schema.
recorded_timings = {
    "torch.float64": {"Train": 17.92, "Eval": 1.13},
    "torch.float32": {"Train": 17.20, "Eval": 1.05},
    "torch.float16": {"Train": 17.88, "Eval": 1.14},
}
timing_df = (
    pd.DataFrame(recorded_timings)
    .T.reset_index()
    .rename(columns={"index": "Precision"})[["Precision", "Train", "Eval"]]
)
x = np.arange(len(timing_df))  # one group of two bars per precision
width = 0.35
fig, ax = plt.subplots(figsize=(9, 5), dpi=300)
ax.bar(x - width / 2, timing_df["Train"], width=width, label="Train Time", color="#b9a4c6")
ax.bar(x + width / 2, timing_df["Eval"], width=width, label="Eval Time", color="#d2bfd8")
ax.set_xticks(x)
ax.set_xticklabels(timing_df["Precision"])
ax.set_title("Average Training and Inference Times")
ax.set_xlabel("Precision Type")
ax.set_ylabel("Time (sec)")
ax.grid(axis="y", linestyle="--", alpha=0.7)
ax.legend()
fig.tight_layout()
plt.show()

## Train in high precision (float64), then test lower-precision inference (inference-only testing)

This shows how reducing precision at inference time affects a model that was trained only once, in float64.

In each fold, the model is trained in float64. Then a **post-training precision reduction** is applied: the trained weights and the inputs are cast to float32 and to float16, and the model is evaluated again at each precision.


In [None]:
# Per-fold accuracies (%): the float64-trained model evaluated in float64, and
# copies of it cast to float32 / float16 for evaluation only.
results = {
    "float64": [],
    "float32_eval_only": [],
    "float16_eval_only": [],
}

for fold in range(K_FOLDS):
    print("\n==============================================")
    print(f"Training Fold {fold+1}/{K_FOLDS} with dtype = float64")

    # Same fold split as the training-precision experiment (shared `indices`).
    val_indices = indices[fold * fold_size : (fold + 1) * fold_size]
    train_indices = indices[: fold * fold_size] + indices[(fold + 1) * fold_size :]

    train_subset = Subset(full_dataset, train_indices)
    val_subset = Subset(full_dataset, val_indices)

    train_loader = torch.utils.data.DataLoader(
        train_subset, batch_size=BATCH_SIZE, shuffle=True
    )
    val_loader = torch.utils.data.DataLoader(
        val_subset, batch_size=BATCH_SIZE, shuffle=False
    )

    model = MLP().to(device, dtype=torch.float64)  # float64 only
    # CrossEntropyLoss() has no parameters/buffers; it follows the logits' dtype.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=LEARNING_REATE, momentum=MOMENTUM)

    for epoch in range(EPOCHS):
        loss_val = trainEpoch(
            model,
            train_loader,
            criterion,
            optimizer,
            device=device,
            dtype=torch.float64,
        )
        acc_val = evaluate(model, val_loader, device=device, dtype=torch.float64)
        print(
            f"Epoch {epoch+1}/{EPOCHS}, Loss: {loss_val:.4f}, Val Acc: {acc_val:.2f}%"
        )

    # float64 eval
    acc_float64 = evaluate(model, val_loader, device=device, dtype=torch.float64)
    results["float64"].append(acc_float64)
    print(f"Fold {fold+1} accuracy (float64): {acc_float64:.2f}%")

    # Eval-only at reduced precision. nn.Module.to() converts the module in place
    # and returns it, so cast a deep copy instead. This keeps the float64 model
    # intact and rounds each reduced model directly from float64 (rather than
    # float64 -> float32 -> float16).
    for eval_dtype, result_key, dtype_label in (
        (torch.float32, "float32_eval_only", "float32"),
        (torch.float16, "float16_eval_only", "float16"),
    ):
        reduced_model = copy.deepcopy(model).to(dtype=eval_dtype)
        acc_reduced = evaluate(reduced_model, val_loader, device=device, dtype=eval_dtype)
        results[result_key].append(acc_reduced)
        print(f"Fold {fold+1} accuracy (eval only, {dtype_label}): {acc_reduced:.2f}%")


print(f"\nAveraged results over {K_FOLDS} folds:")
for key, fold_accs in results.items():
    avg = sum(fold_accs) / K_FOLDS
    print(f"{key}: {avg:.2f}%")

In [None]:
# Mean accuracy of the float64-trained model evaluated at each precision.
# NOTE(review): these numbers are hard-coded (presumably copied from an earlier
# run of the cell above) and do not change when it is re-run. They live in their
# own variable so the computed per-fold `results` dict is not overwritten.
recorded_inference_accuracies = {
    "float64 (trained + eval)": 96.77,
    "float32 (eval only)": 96.77,
    "float16 (eval only)": 96.77,
}
df = pd.DataFrame(
    list(recorded_inference_accuracies.items()), columns=["Precision", "Accuracy"]
)
fig, ax = plt.subplots(figsize=(8, 5), dpi=300)
ax.bar(df["Precision"], df["Accuracy"], color="#b9a4c6")
ax.set_title("Model Accuracies (Inference Only)")
ax.set_xlabel("Precision Type")
ax.set_ylabel("Accuracy (%)")
# Zoomed y-axis derived from the data (0.25 pp padding) so no bar is clipped.
ax.set_ylim(df["Accuracy"].min() - 0.25, df["Accuracy"].max() + 0.25)
ax.grid(axis="y", linestyle="--", alpha=0.7)
fig.tight_layout()
plt.show()