In [107]:
import math

def respire_error_estimate(
    sigma: float,
    d: int,
    ds: int,
    dr: int,
    p: int,
    q1: int,
    q2: int,
    q: int,
    nu1: int,
    nu2: int,
    nu3: int,
    t_gsw: int,
    t_auto_expand_regev: int,
    t_auto_expand_gsw: int,
    t_auto_proj: int,
    t_regev_to_gsw: int,
    t_scal_to_vec: int,
    t_compress: int,
    gadget_correction: bool,
    secret_norm_corection: bool,
  ):
  """Print a stage-by-stage noise estimate and the final error-rate exponent.

  The function tracks a squared-noise quantity through the stages named in
  the section comments below (query expansion, first dimension, folding,
  projecting, scal-to-vec, compression) and prints intermediate values as
  ``log2(value) / 2`` ("bits"). The last line printed is
  ``correctness = 2^(<log2 of error_rate>)``. Nothing is returned.

  Args:
    sigma: Its square (``sigma_sq``) is the base term of every noise formula.
    d: Must be 2048 (asserted); ``log_d`` is fixed to 11 accordingly.
    ds: Only enters the post-compression term as ``ds * sigma_sq``.
    dr: Accepted for interface compatibility; not used by the estimate.
    p: Used as ``p / 2`` in the first-dimension term and as the divisor in
      ``floor(q1 / p)`` of the final bound.
    q1: Scales the post-compression width and appears in ``floor(q1 / p)``.
    q2: Modulus from which the compression gadget base is derived; also the
      divisor in the post-compression width.
    q: Modulus from which all other gadget bases are derived.
    nu1: Number of Regev expansion iterations; also contributes ``2 ** nu1``
      to the first-dimension term.
    nu2: Number of folding iterations.
    nu3: Number of projection iterations.
    t_gsw, t_auto_expand_regev, t_auto_expand_gsw, t_auto_proj,
    t_regev_to_gsw, t_scal_to_vec, t_compress: Gadget lengths ``t``; each is
      turned into a base ``z = ceil(modulus ** (1 / t))``.
    gadget_correction: If True the per-digit factor is ``eta(z)`` (mean of the
      squares of ``z`` consecutive integers centred on zero); otherwise it is
      ``floor(z / 2) ** 2``.
    secret_norm_corection: If True the secret-norm factor is 1, otherwise 8.
      (The misspelling is part of the public keyword name.)

  Raises:
    AssertionError: If ``d != 2048`` or the compression base is not 2.
  """

  sigma_sq: float = sigma ** 2

  def eta(z: int) -> float:
    """Mean of i**2 over the z integers from -floor((z-1)/2) to floor(z/2)."""
    s: float = 0
    z_min = -math.floor((z - 1) / 2)
    z_max = math.floor(z / 2)
    assert(len(range(z_min, z_max + 1)) == z)
    for i in range(z_min, z_max + 1):
      s += i ** 2
    return s / z

  def z_ratio(z: int) -> float:
    """Per-digit squared factor for base z, depending on gadget_correction."""
    if gadget_correction:
      return eta(z)
    else:
      return math.floor(z / 2) ** 2

  def t_to_z(t: int, which_q: int) -> int:
    """Gadget base for t digits: ceil of the t-th root of which_q."""
    return math.ceil(which_q ** (1 / t))

  def expand(e_sq: float, t_auto: int, z_auto: int) -> float:
    """One expansion step: doubles the input term and adds a gadget term."""
    return 2 * e_sq + t_auto * d * z_ratio(z_auto) * sigma_sq

  def select(e_gsw_sq: float, e_reg_sq: float) -> float:
    """One select step: adds a GSW gadget term (base z_gsw) to e_reg_sq."""
    return e_reg_sq + 2 * t_gsw * d * z_ratio(z_gsw) * e_gsw_sq

  z_gsw: int = t_to_z(t_gsw, q)
  z_auto_expand_regev: int = t_to_z(t_auto_expand_regev, q)
  z_auto_expand_gsw: int = t_to_z(t_auto_expand_gsw, q)
  z_auto_proj: int = t_to_z(t_auto_proj, q)
  z_regev_to_gsw: int = t_to_z(t_regev_to_gsw, q)
  z_scal_to_vec: int = t_to_z(t_scal_to_vec, q)
  # The compression gadget is the only one derived from q2 instead of q.
  z_compress: int = t_to_z(t_compress, q2)

  #
  # Query expansion
  #

  # log_d and ZERO_ONE_FACTOR below are both hard-coded for d == 2048.
  assert(d == 2048)
  log_d: int = 11

  q_reg: float = sigma_sq
  reg_iters: int = nu1
  # The first (log_d - nu1) steps use the GSW expansion parameters, the
  # remaining nu1 steps use the Regev expansion parameters.
  for _ in range(log_d - reg_iters):
    q_reg = expand(q_reg, t_auto_expand_gsw, z_auto_expand_gsw)
  for _ in range(reg_iters):
    q_reg = expand(q_reg, t_auto_expand_regev, z_auto_expand_regev)

  c_reg: float = q_reg
  print("c_reg bits", math.log2(c_reg) / 2)

  q_gsw: float = sigma_sq
  n_gsw: int = t_gsw * (nu2 + nu3)
  gsw_iters: int = math.ceil(math.log2(n_gsw))
  # Both loops apply the same GSW expansion parameters.
  for _ in range(log_d - gsw_iters):
    q_gsw = expand(q_gsw, t_auto_expand_gsw, z_auto_expand_gsw)
  for _ in range(gsw_iters):
    q_gsw = expand(q_gsw, t_auto_expand_gsw, z_auto_expand_gsw)

  c_reg_prime: float = q_gsw
  print("c_reg_prime bits", math.log2(c_reg_prime) / 2)

  secret_norm = 1 if secret_norm_corection else 8
  print("c_gsw bits (c_reg_prime component)", math.log2(d * c_reg_prime * secret_norm * sigma_sq) / 2)
  print("c_gsw bits (gadget component)", math.log2(2 * t_regev_to_gsw * d * z_ratio(z_regev_to_gsw) * sigma_sq) / 2)
  c_gsw = d * c_reg_prime * secret_norm * sigma_sq + 2 * t_regev_to_gsw * d * z_ratio(z_regev_to_gsw) * sigma_sq

  print("c_gsw bits", math.log2(c_gsw) / 2)

  #
  # First dimension
  #

  c_i = (2 ** nu1) * d * (p / 2) * c_reg

  print("c_i bits (first dim)", math.log2(c_i) / 2)

  #
  # Folding
  #

  for b in range(nu2):
    c_i = select(c_gsw, c_i)
    print(b, math.log2(c_i) / 2)

  print("c_i bits (fold)", math.log2(c_i) / 2)

  #
  # Projecting
  #

  for _ in range(nu3):
    c_i = select(c_gsw, c_i)
    c_i = expand(c_i, t_auto_proj, z_auto_proj)

  print("c_i bits (proj)", math.log2(c_i) / 2)

  #
  # Scal to Vec
  #
  c = c_i + t_scal_to_vec * d * z_ratio(z_scal_to_vec) * sigma_sq

  print("pre compression", math.log2(c) / 2)

  # e2_prime = c * (q1 ** 2) / (q ** 2) + (q1 ** 2) / (4 * (q2 ** 2)) * (d * sigma_sq + ds * sigma_sq + t_compress * d * (z_compress ** 2) * sigma_sq)
  assert(z_compress == 2)

  # https://www.wolframalpha.com/input?i=Sum%5BBinomial%5B2048%2Ci%5D%2F2%5E2048%2C+%7Bi%2C0%2C1185%7D%5D
  # With probability <= 2^(-41.088), the zero one term will have <= 1185 ones. 1185 / 2048 = 0.579
  ZERO_ONE_FACTOR = 0.579
  e2_prime = c * (q1 ** 2) / (q ** 2) + (q1 ** 2) / (4 * (q2 ** 2)) * (d * sigma_sq + ds * sigma_sq + t_compress * d * (z_compress ** 2) * sigma_sq * ZERO_ONE_FACTOR)

  print("post compression width (gadget term)", e2_prime)

  e1_prime = 1
  exponent = -math.pi * ((1/2) * math.floor(q1 / p) - e1_prime) ** 2 / e2_prime
  error_rate = 2 * d * math.exp(exponent)
  if error_rate > 0:
    log2_error_rate = math.log2(error_rate)
  else:
    # math.exp() underflowed to 0.0, and math.log2(0.0) would raise
    # ValueError. Evaluate the same quantity in the log domain instead:
    # log2(2 * d * e**x) == log2(2 * d) + x / ln(2).
    log2_error_rate = math.log2(2 * d) + exponent / math.log(2)
  print(f"correctness = 2^({log2_error_rate})")


In [116]:
# Keyword arguments for respire_error_estimate (the "Respire512" parameter set).
respire512 = dict(
  sigma=6.4,
  d=2048,
  ds=512,
  dr=512,
  p=17,
  q1=6 * 17,
  q2=163841,
  q=268369921 * 249561089,
  nu1=9,
  nu2=9,
  nu3=2,
  t_gsw=8,
  t_auto_expand_regev=4,
  t_auto_expand_gsw=16,
  t_auto_proj=16,
  t_regev_to_gsw=4,
  t_scal_to_vec=8,
  t_compress=18,
)

In [115]:
print("**** Respire512")

# Run the estimate once per gadget_correction setting, each under its own heading.
for heading, use_gadget_correction in (
  ("** Base estimate", False),
  ("** With gadget correction", True),
):
  print(heading)
  respire_error_estimate(
    **respire512,
    gadget_correction=use_gadget_correction,
    secret_norm_corection=False,
  )
  print()

**** Respire512
** Base estimate
c_reg bits 26.650363837275084
c_reg_prime bits 18.26268271113784
c_gsw bits (c_reg_prime component) 27.940754616250473
c_gsw bits (gadget component) 22.65176927141799
c_gsw bits 27.941226370751476
c_i bits (first dim) 38.19409525790025
0 41.426717021828175
1 41.9226233402875
2 42.21373485117618
3 42.42056775448843
4 42.581119980913094
5 42.712362505808876
6 42.823362453834214
7 42.9195377607724
8 43.004385726910435
c_i bits (fold) 43.004385726910435
c_i bits (proj) 44.11545101864022
pre compression 44.11545101864022
post compression width (gadget term) 0.34984358769348123
correctness = 2^(-39.821560277369024)

** With gadget correction
c_reg bits 25.8578826595184
c_reg_prime bits 17.480152437937132
c_gsw bits (c_reg_prime component) 27.158224343049767
c_gsw bits (gadget component) 21.859288026631457
c_gsw bits 27.158689636474936
c_i bits (first dim) 37.40161408014357
0 39.87850801363926
1 40.366775178848854
2 40.655302689506826
3 40.860836413000655
4 41

In [110]:
# Keyword arguments for respire_error_estimate (the "Respire1024" parameter set).
respire1024 = dict(
  sigma=6.4,
  d=2048,
  ds=1024,
  dr=256,
  p=257,
  q1=1028,
  q2=4169729,
  q=268369921 * 249561089,
  nu1=9,
  nu2=8,
  nu3=3,
  t_gsw=8,
  t_auto_expand_regev=4,
  t_auto_expand_gsw=16,
  t_auto_proj=16,
  t_regev_to_gsw=4,
  t_scal_to_vec=8,
  t_compress=22,
)

In [112]:
print("**** Respire1024")

# Run the estimate once per gadget_correction setting, each under its own heading.
for heading, use_gadget_correction in (
  ("** Base estimate", False),
  ("** With gadget correction", True),
):
  print(heading)
  respire_error_estimate(
    **respire1024,
    gadget_correction=use_gadget_correction,
    secret_norm_corection=False,
  )
  print()

# print("** ... and secret norm correction")
# respire_error_estimate(
#   **respire1024,
#   gadget_correction=True,
#   secret_norm_corection=True,
# )

**** Respire1024
** Base estimate
c_reg bits 26.650363837275084
c_reg_prime bits 18.26268271113784
c_gsw bits (c_reg_prime component) 27.940754616250473
c_gsw bits (gadget component) 22.65176927141799
c_gsw bits 27.941226370751476
c_i bits (first dim) 40.15317611187202
0 41.533644359895604
1 41.97837011651391
2 42.2514435702735
3 42.44905903119774
4 42.60401518008632
5 42.73149913704694
6 42.839800588246845
7 42.93394440861364
c_i bits (fold) 42.93394440861364
c_i bits (proj) 44.5738987927689
pre compression 44.5738987927689
post compression width (gadget term) 0.2284164640296712
correctness = 2^(-7.842528256800447)

** With gadget correction
c_reg bits 25.8578826595184
c_reg_prime bits 17.480152437937132
c_gsw bits (c_reg_prime component) 27.158224343049767
c_gsw bits (gadget component) 21.859288026631457
c_gsw bits 27.158689636474936
c_i bits (first dim) 39.36069493411534
0 40.149283730274526
1 40.51698550967265
2 40.75936394235473
3 40.94046281430651
4 41.08509762750265
5 41.2055196