In [76]:
import numpy as np
import torch
import torch.nn as nn
import subprocess

# Write raw file 
def write_raw_file(filename, data):
    with open(filename, 'wb') as f:
        # Convert the tensor to a numpy array and write it as raw bytes
        np_data = data.numpy().astype(np.float16)
        f.write(np_data.tobytes())
        
        
def read_raw_file(filename, shape):
    with open(filename, 'rb') as f:
        # Read the raw bytes and convert them to a numpy array
        data = np.fromfile(f, dtype=np.float16)
        return torch.tensor(data, dtype=torch.float32).reshape(shape)


def execute_inference(model: str):
    # Run the inference command
    command = ["../runtime/simple", model, "input.raw", "output.raw"]
    subprocess.run(command)
    
def perform_complex_fft(input_tensor): 
    # For the batch dim
    batch_size = input_tensor.shape[0]
    
    # output tensor
    output_tensor = torch.empty_like(input_tensor)
    
    for i in range(batch_size):
        # The input tensor is interleaved, so split it into real and imaginary parts
        real_part = input_tensor[i, 0::2]
        imag_part = input_tensor[i, 1::2]
        
        c = torch.complex(real_part, imag_part)
        
        fft = torch.fft.fft(c, dim=-1)
        
        # Interleave the real and imaginary parts back into a single tensor
        output_tensor[i, 0::2] = fft.real
        output_tensor[i, 1::2] = fft.imag

    return output_tensor

In [193]:
# Generate a random input tensor
BATCH = 512
POINTS = 256

input_tensor = torch.rand(BATCH, POINTS*2, dtype=torch.float32) * 100
# input_tensor[0, 2] = 1.0  # Set first element to 1.0


write_raw_file("input.raw", input_tensor)
execute_inference(f"/var/scratch/dsl2511/nvdla-parts/dft_model-c256-b{BATCH}.nvdla")
# execute_inference(f"/var/scratch/dsl2511/fft_model-c256-b{BATCH}.nvdla")

results = read_raw_file("output.raw", (BATCH, POINTS*2))

First FP16 value in input data: 19.6562
numInputTensors = 1
numOutputTensors = 1
Tensor desc: 
Name: input'
Size: 524288
N: 512 C: 512 H: 1 W: 1
Data format: 3
Data type: 2
Data category: 2
Pixel format: 36
Pixel mapping: 0
Stride: 2 32 32 0 1024 0 0 0 

Tensor desc: 
Name: output'
Size: 524288
N: 512 C: 512 H: 1 W: 1
Data format: 3
Data type: 2
Data category: 2
Pixel format: 36
Pixel mapping: 0
Stride: 2 32 32 0 1024 0 0 0 

12608.000000000012808.0000000000-666.5000000000-750.0000000000734.5000000000-448.750000000093.1250000000390.0000000000202.8750000000340.0000000000322.2500000000-427.7500000000-85.875000000064.8750000000207.2500000000465.5000000000662.5000000000468.5000000000479.5000000000131.8750000000-298.5000000000-34.9687500000444.5000000000-166.6250000000403.0000000000369.2500000000350.5000000000344.7500000000-330.000000000048.4687500000-558.0000000000-3.7226562500-87.6875000000-525.5000000000-353.5000000000259.0000000000-227.1250000000-135.5000000000349.7500000000-98.62500000

In [194]:
# print("Input Tensor:")
# print(input_tensor)
# print("Output Tensor:")
# tIn = torch.tensor(input_tensor[0], dtype=torch.float32)]
base = perform_complex_fft(input_tensor)

print("Base Output Tensor:")
print(base[0, 0:10])  # Print first 10 elements for brevity
print("Results Tensor:")
print(results[0, 0:10])  # Print first 10 elements for brevity

# Calculate error
max_abs_error = torch.max(torch.abs(results - base))
percent_error = (max_abs_error / torch.max(torch.abs(base))) * 100
print(f"Max absolute error: {max_abs_error.item()}")
print(f"Percent error: {percent_error.item()}%")

# Mean absolute error
mae = torch.mean(torch.abs(results - base))
print(f"Mean Absolute Error: {mae.item()}")

mse = torch.mean((results - base)**2)
rel_mse = mse / torch.mean(base**2)
print(f"Mean Squared Error: {mse.item()}")
print(f"Relative Mean Squared Error: {rel_mse.item()}")

rel_l2_error = torch.norm(results - base) / torch.norm(base)
print(f"Relative L2 Error: {rel_l2_error.item()}")

Base Output Tensor:
tensor([12607.1758, 12808.2764,  -666.1183,  -749.3750,   733.9633,  -448.9785,
           93.0284,   389.9610,   202.6745,   340.2129])
Results Tensor:
tensor([12608.0000, 12808.0000,  -666.5000,  -750.0000,   734.5000,  -448.7500,
           93.1250,   390.0000,   202.8750,   340.0000])
Max absolute error: 4.201171875
Percent error: 0.029812637716531754%
Mean Absolute Error: 0.1869499832391739
Mean Squared Error: 0.07134275138378143
Relative Mean Squared Error: 8.360179748478913e-08
Relative L2 Error: 0.00028914198628626764
