In [None]:
import subprocess
import re
import numpy as np

import matplotlib.pyplot as plt

# Sweep the external ./gemm benchmark over square matrix sizes and collect
# the GFLOP/s figure it prints for each one.
GEMM_BINARY = "./gemm"  # path to the compiled benchmark executable

# Matrix sizes to sweep: 224, 256, ..., 8192 (inclusive), step 32.
sizes = np.arange(224, 8193, 32)

# Accepts "GFLOP/s: 42.75", "GFLOP/s: 80" and "GFLOP/s:1.5e+02".
_GFLOPS_RE = re.compile(r"GFLOP/s:\s*(\d+(?:\.\d*)?(?:[eE][-+]?\d+)?)")


def parse_gflops(output):
    """Return the GFLOP/s value found in ``output``, or None if there is none."""
    match = _GFLOPS_RE.search(output)
    return float(match.group(1)) if match else None


def run_gemm(size, binary=GEMM_BINARY):
    """Run ``binary`` for one matrix size and return its reported GFLOP/s.

    Args:
        size: matrix dimension passed as the single command-line argument.
        binary: path of the benchmark executable.

    Returns:
        The parsed GFLOP/s as a float, or ``np.nan`` when the program cannot
        be started, exits non-zero, or prints no parseable GFLOP/s line.
        A diagnostic is printed in every failure case, so a single failed
        size does not abort the whole sweep.
    """
    try:
        # argv list instead of shell=True: no /bin/sh is involved.
        result = subprocess.run(
            [binary, str(size)], capture_output=True, text=True, check=True
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # Without a shell, a missing/non-executable binary raises OSError
        # (e.g. FileNotFoundError) instead of exiting with status 127.
        print(f"Error executing gemm for size {size}: {e}")
        return np.nan

    output = result.stdout.strip()
    gflops = parse_gflops(output)
    if gflops is None:
        print(f"Could not parse output for size {size}: {output}")
        return np.nan
    print(f"Size {size}: {gflops} GFLOP/s")
    return gflops


# Failed sizes are recorded as np.nan so the list stays numeric;
# matplotlib leaves a gap at NaN points.
gflops_values = [run_gemm(int(size)) for size in sizes]

In [None]:
# Plot measured GEMM throughput against matrix size. The dashed horizontal
# rule marks an 80 GFLOP/s reference ceiling (presumably the machine's peak;
# confirm where that figure comes from).
fig, ax = plt.subplots(figsize=(12, 6))
ax.plot(sizes, gflops_values)
ax.axhline(
    y=80,
    color='black',
    linestyle='--',
    label='Theoretical Maximum (80 GFLOP/s)',
)
ax.set_title('GEMM Performance')
ax.set_xlabel('Size')
ax.set_ylabel('GFLOP/s')
ax.grid(True)
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
%env JAX_PLATFORMS=cpu
%env NPROC=1
import jax
import time


# For matrix multiplication, the number of operations is 2*n^3
n = 4096
flops = 2 * n**3  # Total number of floating point operations

# Benchmark using a loop instead of timeit
num_runs = 10
total_time = 0
for i in range(num_runs):
    x = jax.random.uniform(jax.random.key(i), (n, n))
    start_time = time.monotonic()
    (x @ x).block_until_ready()  # Ensure computation is complete
    total_time += time.monotonic() - start_time

avg_time = total_time / num_runs
gflops_per_second = (flops / 1e9) / avg_time

print(f"For a {n}Ã—{n} matrix multiplication:")
print(f"Number of operations: {flops:,} (2*n^3)")
print(f"Average execution time: {avg_time*1000:.3f} ms")
print(f"Performance: {gflops_per_second:.2f} GFLOP/s")