In [32]:
import torch
import torch.nn as nn
import time
import numpy as np

def test(device = None):
    # Select device: use XPU if available, else fallback to CPU
    device = torch.device(device)

    # Define the neural network architecture
    class NeuralNetwork(nn.Module):
        def __init__(self):
            super(NeuralNetwork, self).__init__()
            self.layer1 = nn.Linear(40, 1024)
            self.layer2 = nn.Linear(1024, 1024)
            self.layer3 = nn.Linear(1024, 1024)
            self.layer4 = nn.Linear(1024, 20)
            self.relu = nn.ReLU()

        def forward(self, x):
            x = self.relu(self.layer1(x))
            x = self.relu(self.layer2(x))
            x = self.relu(self.layer3(x))
            x = self.layer4(x)
            return x

    # Instantiate the model and move to device
    model = NeuralNetwork().to(device)
    model.eval()  # Set the model to evaluation mode

    # Create a dummy input tensor and move to device
    dummy_input = torch.randn(1, 40, device=device)

    # Number of times to run the inference for averaging
    num_runs = 100
    inference_times = []

    with torch.no_grad():  # Deactivate autograd for faster inference
        # Warm-up run (important to initialize everything)
        for _ in range(10):
            _ = model(dummy_input)

        # Measure inference time over multiple runs
        for _ in range(num_runs):
            start_time = time.perf_counter()
            output = model(dummy_input)
            end_time = time.perf_counter()
            inference_times.append(end_time - start_time)

    # Calculate the average inference time and standard deviation
    average_inference_time = np.mean(inference_times)
    std_dev_inference_time = np.std(inference_times)

    print(f"Number of runs: {num_runs}")
    print(f"Average inference time: {average_inference_time:.6f} seconds")
    print(f"Standard deviation: {std_dev_inference_time:.6f} seconds")

test('xpu')
test('cpu')

Number of runs: 100
Average inference time: 0.000275 seconds
Standard deviation: 0.000016 seconds
Number of runs: 100
Average inference time: 0.000427 seconds
Standard deviation: 0.000167 seconds
