### Small example to check how autocast operates in torch

In [None]:
# small mixed precision test
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda.amp import GradScaler, autocast
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import time
# Define a simple model
class SimpleCNN(nn.Module):
    """Minimal CNN: one 3x3 convolution followed by a linear classifier.

    The linear layer expects 32 * 26 * 26 flattened features, which is what
    the un-padded 3x3, stride-1 convolution yields for a 1-channel 28x28
    input. The network emits 10 logits per sample.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 32, 3, 1)
        self.fc = nn.Linear(32 * 26 * 26, 10)

    def _run_layers(self, x):
        """Apply conv -> ReLU -> flatten -> linear and return the logits."""
        features = torch.relu(self.conv(x))
        return self.fc(torch.flatten(features, 1))

    def forward(self, x, use_autocast=False):
        """Compute logits for ``x``.

        Args:
            x: Input batch fed to the convolution.
            use_autocast: When True, run the layers inside a float16
                ``autocast`` region opened by the model itself. When False,
                the layers run in whatever context the caller established.

        Returns:
            The output of the final linear layer.
        """
        if not use_autocast:
            return self._run_layers(x)
        with autocast(dtype=torch.float16):
            return self._run_layers(x)

# Training function with mixed precision
def train(model, device, train_loader, optimizer, criterion, scaler, epoch, use_autocast=False):
    """Run one pass over ``train_loader``, optionally with mixed precision.

    Args:
        model: Module to train; it is switched to training mode.
        device: Device the batches are moved to.
        train_loader: Iterable of ``(data, target)`` batches that also exposes
            a ``dataset`` attribute supporting ``len()`` (used for progress
            output).
        optimizer: Optimizer stepping the model's parameters.
        criterion: Loss callable invoked as ``criterion(output, target)``.
        scaler: Gradient scaler; only used when ``use_autocast`` is True.
        epoch: Epoch number, used for progress output only.
        use_autocast: When True, the forward pass and loss run inside a
            float16 ``autocast`` region and the backward pass/optimizer step
            go through ``scaler``. When False, a plain forward/backward/step
            is performed.
    """
    model.train()

    # Feed inputs in the dtype of the model's parameters. Casting every batch
    # to float16 unconditionally raises a dtype-mismatch error whenever a
    # float32 model runs outside an active autocast region. Inside autocast
    # the float16 casts are applied per-op, so float32 inputs are fine there,
    # and a model converted with ``.half()`` still receives half inputs.
    first_param = next(model.parameters(), None)
    input_dtype = first_param.dtype if first_param is not None else torch.float32

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.to(device=device, dtype=input_dtype)
        target = target.to(device)

        optimizer.zero_grad()

        if use_autocast:
            # Enables autocasting for the forward pass and the loss only
            with autocast(dtype=torch.float16):
                output = model(data)
                loss = criterion(output, target)
            # Scales the loss, and calls backward()
            scaler.scale(loss).backward()

            # Unscales gradients and calls optimizer.step()
            scaler.step(optimizer)

            # Updates the scale for next iteration
            scaler.update()
        else:
            output = model(data)
            # Compute the loss in float32 even if the model produced
            # lower-precision outputs.
            loss = criterion(output.float(), target)
            loss.backward()
            optimizer.step()

        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/'
                f'{len(train_loader.dataset)}] Loss: {loss.item():.6f}')
            # Show the parameters' storage dtype alongside the loss so it can
            # be compared between the autocast and non-autocast runs.
            for name, param in model.named_parameters():
                print(f'Parameter: {name}, dtype: {param.dtype}')


def main():
    """Benchmark one epoch of MNIST training of ``SimpleCNN`` under autocast.

    Downloads MNIST into ``./data`` if needed, trains with Adam and
    cross-entropy loss, and prints the average wall-clock time per batch.
    """
    # Check device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hyperparameters
    batch_size = 64
    benchmark_epochs = 1
    learning_rate = 1e-3

    # Data loaders
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    train_dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    # Initialize model, loss, optimizer
    model = SimpleCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Initialize GradScaler
    scaler = GradScaler()

    # NOTE: other variants worth comparing against this run are the same loop
    # with use_autocast=False (plain float32) and with a model converted via
    # model.half() (pure float16, use_autocast=False).
    print("\nBenchmarking with autocast...")
    start_time = time.perf_counter()
    for epoch in range(benchmark_epochs):
        # The run labelled "with autocast" must actually enable it.
        train(model, device, train_loader, optimizer, criterion, scaler, epoch, use_autocast=True)
    if device.type == "cuda":
        # CUDA kernels run asynchronously; wait for them before stopping the
        # timer so the measurement covers the work that was queued.
        torch.cuda.synchronize()
    time_with_autocast = time.perf_counter() - start_time

    # Report a true per-batch average rather than the total elapsed time.
    num_batches = max(1, benchmark_epochs * len(train_loader))
    print(f"Average time per batch with autocast: {time_with_autocast / num_batches:.6f} seconds")

# Run the benchmark only when executed directly (script or notebook cell),
# not as a side effect of importing this module.
if __name__ == "__main__":
    main()
