In [7]:
# execute on Google Colab with T4 GPU

import numpy as np
import torch
import time
from ncon_torch import ncon
from tqdm import tqdm
import matplotlib.pyplot as plt
import scienceplots

In [8]:
def benchmark(n, device, num_instances=10):
  num_qubits = n
  local_dim = 2

  np_times = []
  torch_times = []
  for _ in range(num_instances): 
    np_sv = np.random.rand(local_dim**num_qubits) + 1j * np.random.rand(local_dim**num_qubits)
    np_sv /= np.linalg.norm(np_sv)
    np_sv = np_sv.reshape([local_dim] * num_qubits)
    torch_sv = torch.tensor(np_sv, dtype=torch.complex64, device=device)

    gate_np = np.random.rand(4, 4) + 1j * np.random.rand(4, 4)
    gate_np /= np.linalg.norm(gate_np)
    gate_np = gate_np.reshape(2, 2, 2, 2)  # shape: [i', j', i, j] (bra, ket)
    gate_torch = torch.tensor(gate_np, dtype=torch.complex64, device=device)

    gate_inds = [-1, -2, 1, 2]
    state_inds = [1,2] + [-i for i in range(3, num_qubits+1)]

    start_np = time.time()
    result_np = ncon([gate_np, np_sv], [gate_inds, state_inds])
    end_np = time.time()

    start_torch = time.time()
    result_torch = ncon([gate_torch, torch_sv], [gate_inds, state_inds])
    torch.cuda.synchronize()
    end_torch = time.time()

    np_times.append(end_np-start_np)
    torch_times.append(end_torch-start_torch)

    assert np.allclose(result_np, result_torch.cpu().numpy()), "Mismatch between NumPy and PyTorch results!"
  return np.mean(np_times), np.mean(torch_times)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

n_range = np.arange(4,28,2)
np_times = []
torch_times = []
for n in tqdm(n_range):
    np_time, torch_time = benchmark(n, device)
    np_times.append(np_time)
    torch_times.append(torch_time)

cpu


  0%|          | 0/12 [00:00<?, ?it/s]


AssertionError: Torch not compiled with CUDA enabled

In [None]:
def set_style():
    plt.style.use(['science', 'grid'])
    plt.rcParams['text.usetex'] = True
    plt.rcParams['image.cmap'] = 'cividis'
    
    plt.rcParams['axes.labelsize'] = 8
    plt.rcParams['xtick.labelsize'] = 8  
    plt.rcParams['ytick.labelsize'] = 8 
    plt.rcParams['legend.fontsize'] = 8
    plt.rcParams['axes.titlesize'] = 10
    
    plt.rcParams['lines.linewidth'] = 1.5
    plt.rcParams['savefig.dpi'] = 300
    plt.rcParams['figure.dpi'] = 300 
    
    plt.rcParams['axes.grid'] = True  
    plt.rcParams['axes.spines.top'] = True 
    plt.rcParams['axes.spines.right'] = True
    plt.rcParams['xtick.top'] = False  
    plt.rcParams['ytick.right'] = False
    plt.rcParams['axes.spines.left'] = True
    plt.rcParams['ytick.left'] = False

set_style()
colors = plt.cm.plasma(np.linspace(0.3,0.7,2))
fig, ax = plt.subplots(figsize=(3.2,2))
plt.plot(n_range, np_times, label='NumPy', marker='o', color=colors[0])
plt.plot(n_range, torch_times, label='PyTorch', marker='o', color=colors[1])
plt.xlabel('number of qubits')
plt.ylabel('time (s)')
plt.legend()
plt.yscale('log')
plt.savefig('benchmark.pdf', bbox_inches='tight')