<a href="https://colab.research.google.com/github/SzymonNowakowski/diffusions/blob/master/rr_compute.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# rr function

### Authors: Piotr Pokarowski (R code) and Szymon Nowakowski (Python port and tests)



# `rr` function

In [110]:
import torch
# Notebook-wide result precision used by the reference checks below.
dtype = torch.float32

# One-element zero tensor; the cells below use it for its `.device`.
noise = torch.zeros(1, dtype=dtype, device="cpu")

# Print enough digits to compare against the R reference output by eye.
torch.set_printoptions(sci_mode=False, precision=11)


def rr(
    num_steps: int,
    res_dtype: torch.dtype,
    device: torch.device,
    sigma_min: float = 0.002,
    sigma_max: float = 80.0,
    rho: float = 7.0,
):
    """
    Implementation of the R logic in PyTorch for a Karras sigma grid.

    The grid is ``linspace(sigma_max**(1/rho), sigma_min**(1/rho), num_steps)**rho``,
    i.e. R's ``seq(0.002^(1/7), 80^(1/7), len=TT)^7`` (with the default
    arguments) ordered from the largest sigma to the smallest.

    Args:
        num_steps: number of grid points (``TT`` in the R code); must be >= 2.
        res_dtype: dtype of the returned tensors. All arithmetic is done in
            float64 and cast at the end.
        device: device on which the tensors are created.
        sigma_min, sigma_max, rho: Karras grid parameters. The defaults
            reproduce the constants 0.002, 80 and 7 of the R code.

    Returns:
        rrFLOW, rrMSE, rrML, betaFLOW, betaMSE, betaML. Each has length
        ``num_steps``: one value per transition between consecutive grid
        points (largest sigma first), followed by an appended 0.0. The rr*
        vectors are the ratios annotated ``r_old / r_new`` in the R code;
        betaFLOW is identically zero.

    Raises:
        ValueError: if ``num_steps < 2`` (the grid would have no transition).
    """
    num_steps = int(num_steps)
    if num_steps < 2:
        raise ValueError(f"num_steps must be at least 2, got {num_steps}.")
    dtype = torch.float64  # compute in double precision, cast to res_dtype at the end

    # s1 = seq(0.002^(1/7), 80^(1/7), len=TT)^7   (R, ascending)
    # Built here in descending order: s1[0] is the largest sigma.
    start = sigma_max ** (1.0 / rho)
    end = sigma_min ** (1.0 / rho)
    s1_root = torch.linspace(start, end, num_steps, dtype=dtype, device=device)
    s1 = s1_root ** rho

    # roO = 1/s1  (ascending here, because s1 is descending)
    ro0 = 1.0 / s1

    # R drops one end of roO to pair consecutive values. The Python vector is
    # reversed with respect to R, so the slices swap ends:
    #   roO[-TT] (R: all but the last  -> the larger values)  -> ro0[1:]
    #   roO[-1]  (R: all but the first -> the smaller values) -> ro0[:-1]
    ro0_head = ro0[1:]   # roO[-TT]: ro after each transition
    ro0_tail = ro0[:-1]  # roO[-1]:  ro before each transition

    # gaO = (roO[-TT]/roO[-1])^2
    ga0 = (ro0_head / ro0_tail) ** 2.0  # (ro_new / ro_old)^2

    # fpred = sqrt(1+4/roO[-1]^2)
    fpred = torch.sqrt(1.0 + 4.0 / (ro0_tail ** 2.0))

    # fpred = fpred*max(sqrt((gaO-1)/2)/fpred) +1e-10
    # After this rescaling 2*fpred^2 >= gaO - 1 (up to rounding; the 1e-10
    # adds margin), which keeps the second square root in `denom` real.
    scale = torch.max(torch.sqrt((ga0 - 1.0) / 2.0) / fpred)
    fpred = fpred * scale + 1e-10

    # eta2 = (gaO-1)^2 / (fpred*sqrt(2*gaO) + sqrt(2*fpred^2+1-gaO))^2
    numer = (ga0 - 1.0) ** 2.0
    denom = fpred * torch.sqrt(2.0 * ga0) + torch.sqrt(2.0 * fpred ** 2.0 + 1.0 - ga0)
    eta2 = numer / (denom ** 2.0)  # eta^2

    # ga = 1/(1 - eta2)
    ga = 1.0 / (1.0 - eta2)  # (ro_new / ro_old)^2

    # rrFLOW = 1/sqrt(gaO)     # r_old / r_new
    rrFLOW = 1.0 / torch.sqrt(ga0)

    # rrMSE = 1/sqrt(ga*gaO)   # r_old / r_new
    rrMSE = 1.0 / torch.sqrt(ga * ga0)

    # rrML = 1/gaO             # r_old / r_new
    rrML = 1.0 / ga0

    # betaMSE = sqrt(eta2)/roO[-TT]
    betaMSE = torch.sqrt(eta2) / ro0_head

    # betaML = sqrt(1 - 1/gaO)/roO[-TT]
    betaML = torch.sqrt(1.0 - 1.0 / ga0) / ro0_head

    # FLOW: all zeros (same length as betaMSE/betaML before appending zero)
    betaFLOW = torch.zeros_like(betaMSE)

    # append 0.0 at the end so every vector has length num_steps
    zero = torch.zeros(1, dtype=dtype, device=device)

    rrFLOW   = torch.cat([rrFLOW,   zero], dim=0).to(dtype=res_dtype)
    rrMSE    = torch.cat([rrMSE,    zero], dim=0).to(dtype=res_dtype)
    rrML     = torch.cat([rrML,     zero], dim=0).to(dtype=res_dtype)
    betaFLOW = torch.cat([betaFLOW, zero], dim=0).to(dtype=res_dtype)
    betaMSE  = torch.cat([betaMSE,  zero], dim=0).to(dtype=res_dtype)
    betaML   = torch.cat([betaML,   zero], dim=0).to(dtype=res_dtype)

    # --- ML to FLOW validity check: rrML = 1/gaO = rrFLOW^2 by construction
    assert torch.allclose(rrFLOW ** 2, rrML, atol=1e-6), "FLOW^2 to ML rr mismatch"

    return rrFLOW, rrMSE, rrML, betaFLOW, betaMSE, betaML






# `rr_time_schedule` function

In [111]:
def rr_time_schedule(num_steps: int, time_schedule, res_dtype: torch.dtype, device: torch.device):
    """
    Implementation of the R logic in PyTorch for a schedule given in time.

    Time ``t`` is linked to sigma by the R relation
    ``tt = pnorm(-log(sigma), 0.4, 1)``, so here ``sigma = exp(-Phi^{-1}(t))``
    where ``Phi`` is the CDF of Normal(0.4, 1). Larger times give smaller sigmas.

    Args:
        num_steps: number of time points (``TT`` in the R code); must be >= 2.
        time_schedule: ``num_steps`` times in the open interval (0, 1), as a
            tensor or any sequence accepted by ``torch.as_tensor``; or ``None``
            to infer the times of the default Karras grid (sigma 80 -> 0.002,
            rho = 7). Times should be increasing (sigmas decreasing); otherwise
            some outputs become NaN.
        res_dtype: dtype of the returned tensors (arithmetic is done in float64).
        device: device on which the tensors are created.

    Returns:
        rrFLOW, rrMSE, rrML, betaFLOW, betaMSE, betaML, sigmas.
        The first six have length ``num_steps``: one value per transition,
        followed by an appended 0.0. ``sigmas`` has length ``num_steps + 1``:
        the sigma of every time point, followed by 0.0.

    Raises:
        ValueError: if ``num_steps < 2``, ``len(time_schedule) != num_steps``,
            or a time lies outside the open interval (0, 1).
    """
    num_steps = int(num_steps)
    dtype = torch.float64  # compute in double precision, cast to res_dtype at the end

    if num_steps < 2:
        raise ValueError(f"num_steps must be at least 2, got {num_steps}.")

    # check length of time_schedule and num_steps
    if time_schedule is not None and len(time_schedule) != num_steps:
        raise ValueError(
            f"Expected time_schedule of length {num_steps}, "
            f"got {len(time_schedule)}."
        )

    # Normal(0.4, 1) links time and sigma in both directions:
    #   tt = pnorm(-log(sigma), 0.4, 1)   (R)
    normal_dist = torch.distributions.Normal(
        loc=torch.tensor(0.4, dtype=dtype, device=device),
        scale=torch.tensor(1.0, dtype=dtype, device=device),
    )

    sigma = None  # set only when the times are inferred from the Karras grid
    if time_schedule is None:
        # s1 = seq(0.002^(1/7), 80^(1/7), len=TT)^7, built largest sigma first
        start = 80.0 ** (1.0 / 7.0)
        end = 0.002 ** (1.0 / 7.0)
        sigma_root = torch.linspace(start, end, num_steps, dtype=dtype, device=device)
        sigma = sigma_root ** 7.0
        time_schedule = normal_dist.cdf(-torch.log(sigma))
        print("Inferred time:", time_schedule)

    # Common dtype/device for all paths; also accepts lists / numpy arrays.
    time_schedule = torch.as_tensor(time_schedule, dtype=dtype, device=device)

    # t <= 0 maps to sigma = inf and t >= 1 to sigma = 0; both poison every
    # ratio below with inf/NaN, so reject them (NaN times fail this check too).
    if not bool(((time_schedule > 0.0) & (time_schedule < 1.0)).all()):
        raise ValueError("time_schedule values must lie strictly between 0 and 1.")

    s1 = torch.exp(-normal_dist.icdf(time_schedule))  # sigma from time

    if sigma is not None:
        # Round-trip sanity check for the inferred schedule.
        assert torch.allclose(sigma, s1, atol=1e-6), "sigma -> time -> sigma mismatch"

    # roO = 1/s1  (ascending here, because s1 is descending)
    ro0 = 1.0 / s1

    # The Python vector is reversed with respect to R, so the slices swap ends:
    #   roO[-TT] (R: all but the last  -> the larger values)  -> ro0[1:]
    #   roO[-1]  (R: all but the first -> the smaller values) -> ro0[:-1]
    ro0_head = ro0[1:]   # roO[-TT]: ro after each transition
    ro0_tail = ro0[:-1]  # roO[-1]:  ro before each transition

    # gaO = (roO[-TT]/roO[-1])^2
    ga0 = (ro0_head / ro0_tail) ** 2.0  # (ro_new / ro_old)^2

    # fpred = sqrt(1+4/roO[-1]^2)
    fpred = torch.sqrt(1.0 + 4.0 / (ro0_tail ** 2.0))

    # fpred = fpred*max(sqrt((gaO-1)/2)/fpred) +1e-10
    # After this rescaling 2*fpred^2 >= gaO - 1 (up to rounding; the 1e-10
    # adds margin), which keeps the second square root in `denom` real.
    scale = torch.max(torch.sqrt((ga0 - 1.0) / 2.0) / fpred)
    fpred = fpred * scale + 1e-10

    # eta2 = (gaO-1)^2 / (fpred*sqrt(2*gaO) + sqrt(2*fpred^2+1-gaO))^2
    numer = (ga0 - 1.0) ** 2.0
    denom = fpred * torch.sqrt(2.0 * ga0) + torch.sqrt(2.0 * fpred ** 2.0 + 1.0 - ga0)
    eta2 = numer / (denom ** 2.0)  # eta^2

    # ga = 1/(1 - eta2)
    ga = 1.0 / (1.0 - eta2)  # (ro_new / ro_old)^2

    # rrFLOW = 1/sqrt(gaO)     # r_old / r_new
    rrFLOW = 1.0 / torch.sqrt(ga0)

    # rrMSE = 1/sqrt(ga*gaO)   # r_old / r_new
    rrMSE = 1.0 / torch.sqrt(ga * ga0)

    # rrML = 1/gaO             # r_old / r_new
    rrML = 1.0 / ga0

    # betaMSE = sqrt(eta2)/roO[-TT]
    betaMSE = torch.sqrt(eta2) / ro0_head

    # betaML = sqrt(1 - 1/gaO)/roO[-TT]
    betaML = torch.sqrt(1.0 - 1.0 / ga0) / ro0_head

    # FLOW: all zeros (same length as betaMSE/betaML before appending zero)
    betaFLOW = torch.zeros_like(betaMSE)

    # append 0.0 at the end: rr*/beta* get length num_steps, sigmas num_steps + 1
    zero = torch.zeros(1, dtype=dtype, device=device)

    s1       = torch.cat([s1,       zero], dim=0).to(dtype=res_dtype)
    rrFLOW   = torch.cat([rrFLOW,   zero], dim=0).to(dtype=res_dtype)
    rrMSE    = torch.cat([rrMSE,    zero], dim=0).to(dtype=res_dtype)
    rrML     = torch.cat([rrML,     zero], dim=0).to(dtype=res_dtype)
    betaFLOW = torch.cat([betaFLOW, zero], dim=0).to(dtype=res_dtype)
    betaMSE  = torch.cat([betaMSE,  zero], dim=0).to(dtype=res_dtype)
    betaML   = torch.cat([betaML,   zero], dim=0).to(dtype=res_dtype)

    # --- ML to FLOW validity check: rrML = 1/gaO = rrFLOW^2 by construction
    assert torch.allclose(rrFLOW ** 2, rrML, atol=1e-6), "FLOW^2 to ML rr mismatch"

    return rrFLOW, rrMSE, rrML, betaFLOW, betaMSE, betaML, s1




# ASSERTS FOR THE KARRAS SIGMA SCHEDULE COMPUTATION

In [112]:
# Reference vectors produced by Pokarowski's R implementation for the default
# Karras grid (sigma 80 -> 0.002, rho = 7). Each ends with the trailing 0.0
# that rr() appends. The plain vectors are used with num_steps = 32, the
# *_63 vectors with num_steps = 63.
r_vals_FLOW = torch.tensor([  #rrFLOW by Pokar - sigma ratios
        0.8366359, 0.8327429, 0.8286604, 0.8243743, 0.8198691,
        0.8151275, 0.8101306, 0.8048572, 0.7992839, 0.7933844,
        0.7871296, 0.7804864, 0.7734178, 0.7658819, 0.7578312,
        0.7492115, 0.7399610, 0.7300084, 0.7192719, 0.7076561,
        0.6950504, 0.6813246, 0.6663258, 0.6498723, 0.6317477,
        0.6116921, 0.5893917, 0.5644651, 0.5364472, 0.5047683,
        0.4687320, 0.0000000
    ], dtype=dtype, device=noise.device)

# FLOW has no stochastic term: every beta is zero.
betas_diffusion_FLOW = torch.tensor([
  0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000,
  0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000,
  0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000,
  0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000,
  0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000,
  0.0000000, 0.0000000, 0.0000000, 0.0000000, 0.0000000,
  0.0000000, 0.0000000
], dtype=dtype, device=noise.device)

r_vals_MSE = torch.tensor([  #rrMSE by Pokar
  0.8366357, 0.8327426, 0.8286600, 0.8243737, 0.8198681,
  0.8151260, 0.8101281, 0.8048533, 0.7992774, 0.7933737,
  0.7871113, 0.7804548, 0.7733622, 0.7657820, 0.7576482,
  0.7488696, 0.7393110, 0.7287571, 0.7168593, 0.7030898,
  0.6868220, 0.6676873, 0.6459559, 0.6222139, 0.5965456,
  0.5682467, 0.5359994, 0.4978241, 0.4503007, 0.3852130,
  0.2197134, 0.0000000
], dtype=dtype, device=noise.device)

betas_diffusion_MSE = torch.tensor([
  0.043344183, 0.044376738, 0.045459380, 0.046595754, 0.047789821,
  0.049045871, 0.050368504, 0.051762576, 0.053233054, 0.054784720,
  0.056421549, 0.058145500, 0.059954137, 0.061835954, 0.063761203,
  0.065663974, 0.067407924, 0.068724701, 0.069120044, 0.067790560,
  0.063734535, 0.056347248, 0.046276620, 0.035376323, 0.025486142,
  0.017531299, 0.011610107, 0.007427481, 0.004594182, 0.002757327,
  0.001766672, 0.000000000
], dtype=dtype, device=noise.device)

# 63-step references (blank lines only group the rows in tens).
r_vals_MSE_63 = torch.tensor([  # rrMSE by Pokar
  0.915197, 0.9141584, 0.9130933, 0.9120015, 0.9108819,
  0.9097335, 0.9085551, 0.9073455, 0.9061035, 0.9048278,

  0.9035169, 0.9021693, 0.9007836, 0.8993580, 0.8978907,
  0.8963800, 0.8948237, 0.8932197, 0.8915658, 0.8898594,

  0.8880979, 0.8862783, 0.8843976, 0.8824521, 0.8804380,
  0.8783509, 0.8761857, 0.8739365, 0.8715964, 0.8691570,

  0.8666081, 0.8639367, 0.8611263, 0.8581559, 0.8549980,
  0.8516176, 0.8479704, 0.8440028, 0.8396537, 0.8348607,

  0.8295707, 0.8237538, 0.8174154, 0.8105968, 0.8033628,
  0.7957759, 0.7878728, 0.7796501, 0.7710623, 0.7620285,

  0.7524396, 0.7421621, 0.7310375, 0.7188746, 0.7054347,
  0.6904060, 0.6733564, 0.6536437, 0.6302156, 0.6010721,

  0.5611755, 0.4592289, 0.0000000
], dtype=dtype, device=noise.device)

betas_diffusion_MSE_63 = torch.tensor([
  0.039073243, 0.039551953, 0.040042504, 0.040545336, 0.041060907,
  0.041589700, 0.042132218, 0.042688988, 0.043260559, 0.043847507,

  0.044450430, 0.045069949, 0.045706711, 0.046361379, 0.047034633,
  0.047727167, 0.048439675, 0.049172840, 0.049927320, 0.050703719,

  0.051502552, 0.052324197, 0.053168821, 0.054036279, 0.054925978,
  0.055836674, 0.056766194, 0.057711046, 0.058665866, 0.059622650,

  0.060569693, 0.061490129, 0.062359987, 0.063145664, 0.063800839,
  0.064263066, 0.064450726, 0.064261832, 0.063577149, 0.062270922,

  0.060231804, 0.057392975, 0.053763558, 0.049447122, 0.044634229,
  0.039567833, 0.034495387, 0.029627320, 0.025114529, 0.021045563,

  0.017456520, 0.014345626, 0.011687116, 0.009442081, 0.007565980,
  0.006013398, 0.004740838, 0.003708284, 0.002880117, 0.002226155,

  0.001725404, 0.001470735, 0.000000000
], dtype=dtype, device=noise.device)

# rrML = rrFLOW^2 (rr() asserts this relation on its own output).
r_vals_ML = torch.tensor([  #rrML by Pokar - sigma ratios squared
  0.6999597, 0.6934608, 0.6866781, 0.6795930, 0.6721853,
  0.6644328, 0.6563115, 0.6477951, 0.6388547, 0.6294589,
  0.6195730, 0.6091590, 0.5981751, 0.5865751, 0.5743081,
  0.5613179, 0.5475422, 0.5329123, 0.5173520, 0.5007772,
  0.4830950, 0.4642033, 0.4439901, 0.4223341, 0.3991052,
  0.3741673, 0.3473825, 0.3186209, 0.2877756, 0.2547910,
  0.2197097, 0.0000000
], dtype=dtype, device=noise.device)

betas_diffusion_ML = torch.tensor([
  36.662013635, 30.858902915, 25.852908864, 21.552093680, 17.872989770,
  14.740026867, 12.084983057, 9.846459379, 7.969377645, 6.404501137,
  5.107977807, 4.040905633, 3.168919747, 2.461800968, 1.893105349,
  1.439814351, 1.082005249, 0.802541356, 0.586781653, 0.422309410,
  0.298679353, 0.207182933, 0.140631255, 0.093155182, 0.060022151,
  0.037469196, 0.022551663, 0.013007109, 0.007133814, 0.003683362,
  0.001766681, 0.000000000
], dtype=dtype, device=noise.device)

r_vals_ML_63 = torch.tensor([  # rrML by Pokar
  0.8375871, 0.8356859, 0.8337398, 0.8317472, 0.8297064,
  0.8276157, 0.8254732, 0.8232769, 0.8210249, 0.8187149,

  0.8163448, 0.8139120, 0.8114143, 0.8088489, 0.8062130,
  0.8035037, 0.8007179, 0.7978524, 0.7949037, 0.7918681,

  0.7887418, 0.7855207, 0.7822003, 0.7787762, 0.7752433,
  0.7715965, 0.7678302, 0.7639385, 0.7599150, 0.7557531,

  0.7514454, 0.7469843, 0.7423615, 0.7375682, 0.7325948,
  0.7274312, 0.7220663, 0.7164883, 0.7106844, 0.7046408,

  0.6983426, 0.6917737, 0.6849164, 0.6777517, 0.6702590,
  0.6624157, 0.6541970, 0.6455762, 0.6365238, 0.6270075,

  0.6169919, 0.6064379, 0.5953028, 0.5835392, 0.5710949,
  0.5579123, 0.5439274, 0.5290698, 0.5132614, 0.4964158,

  0.4784378, 0.4592230, 0.0000000
], dtype=dtype, device=noise.device)

betas_diffusion_ML_63 = torch.tensor([
  29.506336512, 27.130884030, 24.919304383, 22.862229095, 20.950711687,
  19.176213924, 17.530592329, 16.006084966, 14.595298502, 13.291195523,

  12.087082127, 10.976595773, 9.953693395, 9.012639773, 8.147996166,
  7.354609196, 6.627599987, 5.962353561, 5.354508476, 4.799946715,

  4.294783827, 3.835359299, 3.418227182, 3.040146953, 2.698074608,
  2.389154001, 2.110708411, 1.860232339, 1.635383539, 1.433975272,

  1.253968783, 1.093466009, 0.950702499, 0.824040553, 0.711962584,
  0.613064683, 0.526050405, 0.449724754, 0.382988384, 0.324831999,

  0.274330953, 0.230640051, 0.192988549, 0.160675344, 0.133064358,
  0.109580108, 0.089703466, 0.072967598, 0.058954091, 0.047289246,

  0.037640559, 0.029713366, 0.023247663, 0.018015084, 0.013816053,
  0.010477088, 0.007848268, 0.005800850, 0.004225037, 0.003027900,

  0.002131431, 0.001470751, 0.000000000
], dtype=dtype, device=noise.device)

# Run rr() on the 32-step Karras grid and check every output against R.
rrFLOW, rrMSE, rrML, betaFLOW, betaMSE, betaML = rr(32, dtype, noise.device)

for _got, _want, _msg in (
    (rrFLOW, r_vals_FLOW, "FLOW rr mismatch"),
    (betaFLOW, betas_diffusion_FLOW, "FLOW beta mismatch"),
    (rrMSE, r_vals_MSE, "MSE rr mismatch"),
    (betaMSE, betas_diffusion_MSE, "MSE beta mismatch"),
    (rrML, r_vals_ML, "ML rr mismatch"),
    (betaML, betas_diffusion_ML, "ML beta mismatch"),
):
    assert torch.allclose(_got, _want, atol=1e-6), _msg

# 63-step grid: only MSE and ML references are available.
rrFLOW_63, rrMSE_63, rrML_63, betaFLOW_63, betaMSE_63, betaML_63 = rr(63, dtype, noise.device)

for _got, _want, _msg in (
    (rrMSE_63, r_vals_MSE_63, "MSE-63 rr mismatch"),
    (betaMSE_63, betas_diffusion_MSE_63, "MSE-63 beta mismatch"),
    (rrML_63, r_vals_ML_63, "ML-63 rr mismatch"),
    (betaML_63, betas_diffusion_ML_63, "ML-63 beta mismatch"),
):
    assert torch.allclose(_got, _want, atol=1e-6), _msg

del _got, _want, _msg  # keep the notebook namespace clean




# ASSERTS FOR THE TIME-SCHEDULE COMPUTATION

In [113]:
# Go through rr_time_schedule() with the inferred (Karras) time grid: the
# results must match the same R reference vectors as rr().
rrFLOW, rrMSE, rrML, betaFLOW, betaMSE, betaML, sigmas = rr_time_schedule(32, None, dtype, noise.device)

for _got, _want, _msg in (
    (rrFLOW, r_vals_FLOW, "FLOW rr mismatch"),
    (betaFLOW, betas_diffusion_FLOW, "FLOW beta mismatch"),
    (rrMSE, r_vals_MSE, "MSE rr mismatch"),
    (betaMSE, betas_diffusion_MSE, "MSE beta mismatch"),
    (rrML, r_vals_ML, "ML rr mismatch"),
    (betaML, betas_diffusion_ML, "ML beta mismatch"),
):
    assert torch.allclose(_got, _want, atol=1e-6), _msg

print(sigmas)

# 63-step grid: only MSE and ML references are available.
rrFLOW_63, rrMSE_63, rrML_63, betaFLOW_63, betaMSE_63, betaML_63, sigmas = rr_time_schedule(63, None, dtype, noise.device)

for _got, _want, _msg in (
    (rrMSE_63, r_vals_MSE_63, "MSE-63 rr mismatch"),
    (betaMSE_63, betas_diffusion_MSE_63, "MSE-63 beta mismatch"),
    (rrML_63, r_vals_ML_63, "ML-63 rr mismatch"),
    (betaML_63, betas_diffusion_ML_63, "ML-63 beta mismatch"),
):
    assert torch.allclose(_got, _want, atol=1e-6), _msg

del _got, _want, _msg  # keep the notebook namespace clean




Inferred time: tensor([0.00000086768, 0.00000207565, 0.00000492068, 0.00001154588,
        0.00002677640, 0.00006128103, 0.00013816598, 0.00030630017,
        0.00066625860, 0.00141861567, 0.00294902136, 0.00596794564,
        0.01171969095, 0.02225477496, 0.04070762205, 0.07142796389,
        0.11969670848, 0.19069271787, 0.28751902428, 0.40860637840,
        0.54564491526, 0.68383024002, 0.80569831529, 0.89762413826,
        0.95535065549, 0.98455061250, 0.99597422059, 0.99925996828,
        0.99991173810, 0.99999387887, 0.99999978692, 0.99999999696],
       dtype=torch.float64)
tensor([80.00000000000, 66.93087005615, 55.73620986938, 46.18638992310,
        38.07487487793, 31.21641349792, 25.44535636902, 20.61406135559,
        16.59137535095, 13.26121807098, 10.52124404907,  8.28158187866,
         6.46366214752,  4.99911117554,  3.82872891426,  2.90153026581,
         2.17385983467,  1.60857141018,  1.17427062988,  0.84461987019,
         0.59770041704,  0.41543191671,  0.283044010

# ASSERTS AGAINST POKAROWSKI'S R REFERENCE VECTORS FOR THE NEW TIME SCHEDULE

In [114]:
# --- R reference vectors (with trailing 0 added) ---
# These rebind the same names as the Karras reference cell; re-running the
# earlier assert cells after this one would compare against these vectors.
# The reference rrFLOW is ~0.735 at every step.

r_vals_FLOW = torch.tensor([
    0.7345387, 0.7348875, 0.7349970, 0.7349791, 0.7349975, 0.7349982,
    0.7349997, 0.7349998, 0.7349999, 0.7350000, 0.7350000, 0.7350000,
    0.7350000, 0.7350000, 0.7350000, 0.7350000, 0.7350000, 0.7350000,
    0.7350000, 0.7350000, 0.7350000, 0.7350000, 0.7350000, 0.7349999,
    0.7349998, 0.7350000, 0.7350000, 0.7350009, 0.7350127, 0.7349762,
    0.7348165, 0.0
], dtype=dtype, device=noise.device)

r_vals_MSE = torch.tensor([
    0.7345372, 0.7348848, 0.7349920, 0.7349699, 0.7349803, 0.7349664,
    0.7349409, 0.7348910, 0.7347987, 0.7346281, 0.7343136, 0.7337365,
    0.7326850, 0.7307932, 0.7274657, 0.7218307, 0.7128428, 0.6996880,
    0.6824207, 0.6623045, 0.6413822, 0.6216069, 0.6042523, 0.5898395,
    0.5783554, 0.5695043, 0.5628810, 0.5580995, 0.5550202, 0.5518533,
    0.5399612, 0.0
], dtype=dtype, device=noise.device)

r_vals_ML = torch.tensor([
    0.5395471, 0.5400596, 0.5402206, 0.5401943, 0.5402213, 0.5402223,
    0.5402245, 0.5402247, 0.5402249, 0.5402250, 0.5402249, 0.5402249,
    0.5402250, 0.5402250, 0.5402250, 0.5402250, 0.5402250, 0.5402250,
    0.5402250, 0.5402250, 0.5402250, 0.5402250, 0.5402250, 0.5402248,
    0.5402247, 0.5402250, 0.5402249, 0.5402263, 0.5402436, 0.5401900,
    0.5399553, 0.0
], dtype=dtype, device=noise.device)

betas_diffusion_MSE = torch.tensor([
    0.143822161, 0.143632291, 0.143571233, 0.143577768, 0.143562001, 0.143550840,
    0.143530086, 0.143493101, 0.143424819, 0.143298828, 0.143066769, 0.142640844,
    0.141864697, 0.140468443, 0.138012435, 0.133852989, 0.127217927, 0.117505125,
    0.104751697, 0.089885799, 0.074409376, 0.059756709, 0.046859975, 0.036094978,
    0.027439548, 0.020660068, 0.015446663, 0.011488729, 0.008507573, 0.006299481,
    0.004753566, 0.0
], dtype=dtype, device=noise.device)

betas_diffusion_ML = torch.tensor([
    48.842260047, 35.873584716, 26.362361470, 19.376339474, 14.241143273, 10.467202884,
    7.693372483, 5.654626059, 4.156149044, 3.054769256, 2.245255411, 1.650262638,
    1.212942967, 0.891513082, 0.655262111, 0.481617644, 0.353988964, 0.260181889,
    0.191233685, 0.140556758, 0.103309216, 0.075932271, 0.055810220, 0.041020514,
    0.030150073, 0.022160298, 0.016287819, 0.011971544, 0.008799071, 0.006467485,
    0.004753627, 0.0
], dtype=dtype, device=noise.device)

betas_diffusion_FLOW = torch.zeros_like(betas_diffusion_MSE)  # FLOW always zero


# The 32 time points in (0, 1) that the references above were computed for;
# fed to rr_time_schedule() in the next cell. Values are given to 8 decimals.
const_FLOW = torch.tensor([0.00000031, 0.00000146, 0.00000626, 0.00002449, 0.00008753,
      0.00028591, 0.00085415, 0.00233565, 0.00585138, 0.01344545,
      0.02837619, 0.05509628, 0.09862199, 0.16315867, 0.25025935,
      0.35726125, 0.47691245, 0.59869879, 0.71153182, 0.80668639,
      0.87972935, 0.93076593, 0.96322515, 0.98201601, 0.99191762,
      0.99666674, 0.99874006, 0.99956394, 0.99986193, 0.99996003,
      0.99998943, 0.99999745], dtype=dtype, device=noise.device)

# Evaluate rr_time_schedule() on the const_FLOW time grid and compare with R.
rrFLOW, rrMSE, rrML, betaFLOW, betaMSE, betaML, sigmas = rr_time_schedule(32, const_FLOW, dtype, noise.device)

for _label, _value in (
    ("SIGMAS =", sigmas),
    ("rrFLOW =", rrFLOW),
    ("rrMSE  =", rrMSE),
    ("rrML   =", rrML),
    ("betaFLOW =", betaFLOW),
    ("betaMSE  =", betaMSE),
    ("betaML   =", betaML),
):
    print(_label, _value)

# "UWAGA!" is Polish for "WARNING!": show how far rrMSE drifts from the R values.
print("\n\nUWAGA! rrMSE - R-reference =", rrMSE - r_vals_MSE)

# Per-quantity absolute tolerances. They are much looser than in the Karras
# checks (rrMSE only to 0.03) -- presumably because const_FLOW is rounded to
# 8 decimals; confirm before tightening.
for _got, _want, _atol, _msg in (
    (rrFLOW, r_vals_FLOW, 0.0011, "FLOW rr mismatch"),
    (betaFLOW, betas_diffusion_FLOW, 1e-6, "FLOW beta mismatch"),
    (rrMSE, r_vals_MSE, 0.03, "MSE rr mismatch"),
    (betaMSE, betas_diffusion_MSE, 1e-3, "MSE beta mismatch"),
    (rrML, r_vals_ML, 0.0017, "ML rr mismatch"),
    (betaML, betas_diffusion_ML, 1e-5, "ML beta mismatch"),
):
    assert torch.allclose(_got, _want, atol=_atol), _msg

del _label, _value, _got, _want, _atol, _msg  # keep the notebook namespace clean


SIGMAS = tensor([97.99151611328, 71.97856140137, 52.89614105225, 38.87850570679,
        28.57489013672, 21.00247383118, 15.43677997589, 11.34602832794,
         8.33932876587,  6.12940597534,  4.50511360168,  3.31125831604,
         2.43377470970,  1.78882431984,  1.31478595734,  0.96636766195,
         0.71028023958,  0.52205592394,  0.38371112943,  0.28202766180,
         0.20729035139,  0.15235839784,  0.11198344827,  0.08230778575,
         0.06049625948,  0.04446476698,  0.03268143535,  0.02402106859,
         0.01765457727,  0.01297888160,  0.00953371730,  0.00701598777,
         0.00000000000])
rrFLOW = tensor([0.73453867435, 0.73488748074, 0.73499703407, 0.73497915268,
        0.73499751091, 0.73499822617, 0.73499965668, 0.73499983549,
        0.73499995470, 0.73500001431, 0.73499995470, 0.73499995470,
        0.73500001431, 0.73500001431, 0.73499995470, 0.73500001431,
        0.73500001431, 0.73500001431, 0.73499995470, 0.73500007391,
        0.73499995470, 0.73500019312, 0.7

# BINARY SEARCH FOR A TIME SCHEDULE WITH A CONSTANT PER-STEP rrMSE

In [115]:

from typing import Callable, Optional, List

# Working device/precision for the binary-search cells. Note: this rebinds the
# notebook-global `dtype` (float32 in the first cell) to float64, and later
# cells that use `dtype` pick up float64.
dtype, device = torch.float64, "cpu"

def binary_search_next_t(
    t_i: float,
    t_max: float,
    q: float,
    rr_time_schedule: Callable,
    tol_q: float = 1e-8,
    tol_t: float = 1e-8,
    max_iter: int = 64,
    *,
    close_at_tmax: bool = False,
    dtype: torch.dtype = torch.float64,
    device="cpu",
) -> Optional[float]:
    """
    Find the next time step t_{i+1} in (t_i, t_max] with rrMSE(t_i, t_{i+1}) ~= q.

    rrMSE(t_i, t) is the first rrMSE value returned by ``rr_time_schedule`` for
    the two-point schedule ``[t_i, t]``. The search assumes it decreases (at
    least approximately) as t grows, and bisects on that basis.

    Args:
        t_i: current time.
        t_max: largest admissible time.
        q: target rrMSE for the step.
        rr_time_schedule: callable with the signature of ``rr_time_schedule``
            (num_steps, time_schedule, res_dtype, device) returning a 7-tuple
            whose second item is rrMSE.
        tol_q: accept a candidate once |rrMSE - q| <= tol_q.
        tol_t: minimal meaningful advance beyond t_i.
        max_iter: maximal number of bisection steps.
        close_at_tmax: when even the step to t_max keeps rrMSE above q, return
            t_max (close the interval) instead of None.
        dtype, device: dtype/device of the two-point schedule tensor (also
            passed as ``res_dtype``/``device``). Previously these silently came
            from the notebook globals of the same names.

    Returns:
        t_{i+1} as a float, or None if time cannot be meaningfully advanced
        (t_i >= t_max, the step to t_max still has rrMSE > q + tol_q and
        ``close_at_tmax`` is False, or the search did not move past t_i).
    """

    def rr_mse(t_candidate: float) -> float:
        # rrMSE of the single transition t_i -> t_candidate.
        schedule = torch.tensor([t_i, t_candidate], dtype=dtype, device=device)
        outputs = rr_time_schedule(schedule.numel(), schedule, dtype, device)
        return float(outputs[1][0].item())  # outputs[1] is rrMSE

    # Already at (or beyond) t_max: nothing left to search.
    if t_i >= t_max:
        return None

    # Value for the largest admissible step.
    val_at_tmax = rr_mse(t_max)
    print("Checked val at t_max vs t_i= ", val_at_tmax, t_max, t_i)

    # Even the largest step keeps rrMSE above q: either stop (conservative
    # default) or take t_max as the final step.
    if val_at_tmax > q + tol_q:
        return t_max if close_at_tmax else None

    left, right = t_i, t_max
    best_t: Optional[float] = None
    best_val: Optional[float] = None

    for _ in range(max_iter):
        mid = 0.5 * (left + right)

        # Guard against spinning once t no longer moves away from t_i.
        if mid <= t_i + tol_t:
            break

        val_mid = rr_mse(mid)

        # Remember the best approximation of q seen so far.
        if best_val is None or abs(val_mid - q) < abs(best_val - q):
            best_val = val_mid
            best_t = mid

        # Hit q within tol_q: done.
        if abs(val_mid - q) <= tol_q:
            best_t = mid
            break

        # rrMSE decreases as t grows: a value above q means the step is still
        # too short, so move the left bound up; otherwise shrink from the right.
        # (A NaN value compares False and therefore shrinks from the right.)
        if val_mid > q:
            left = mid
        else:
            right = mid

    # Nothing useful found, or t did not move: stop.
    if best_t is None or best_t <= t_i + tol_t:
        return None

    return best_t


def build_rr_schedule(
    t_min: float,
    t_max: float,
    q: float,
    rr_time_schedule: Callable,
    tol_q: float = 1e-8,
    tol_t: float = 1e-8,
    max_iter_per_step: int = 64,
    *,
    close_at_tmax: bool = False,
    dtype: torch.dtype = torch.float64,
    device="cpu",
) -> torch.Tensor:
    """
    Build a time schedule t_0, t_1, ..., t_K with rrMSE ~= q on every step.

    - t_0 = t_min;
    - each t_{i+1} is chosen by ``binary_search_next_t`` so that rrMSE ~= q;
    - stops when no further step can be added or t_max (within tol_t) is reached.

    Args:
        t_min, t_max: first time point and upper time bound.
        q: target rrMSE per step.
        rr_time_schedule: callable forwarded to ``binary_search_next_t``.
        tol_q, tol_t, max_iter_per_step: forwarded search tolerances/limits.
        close_at_tmax: forwarded; when True the last step may end exactly at t_max.
        dtype, device: used for the search tensors and the returned tensor.

    Returns:
        1-D tensor with the time points.
    """
    times: List[float] = [float(t_min)]

    while True:
        next_t = binary_search_next_t(
            t_i=times[-1],
            t_max=t_max,
            q=q,
            rr_time_schedule=rr_time_schedule,
            tol_q=tol_q,
            tol_t=tol_t,
            max_iter=max_iter_per_step,
            close_at_tmax=close_at_tmax,
            dtype=dtype,
            device=device,
        )

        # No further step can be added.
        if next_t is None:
            break

        times.append(next_t)

        # Practically reached t_max: done.
        if next_t >= t_max - tol_t:
            break

    return torch.tensor(times, dtype=dtype, device=device)


In [116]:
# Build a schedule whose every step has rrMSE ~= q, starting at the same first
# time point as const_FLOW and stopping just short of t = 1.
t_min, t_max = 3.1e-07, 1.0 - 1e-8
q = 0.53  # target rrMSE per step

time_schedule = build_rr_schedule(t_min, t_max, q, rr_time_schedule)

print(time_schedule)
print(time_schedule.shape)


Checked val at t_max vs t_i=  6.267818419097071e-10 0.99999999 3.1e-07
Checked val at t_max vs t_i=  1.182080007272628e-09 0.99999999 1.5249074087049667e-06
Checked val at t_max vs t_i=  2.229423835397254e-09 0.99999999 6.807143120187713e-06
Checked val at t_max vs t_i=  4.204973262106784e-09 0.99999999 2.758782397844881e-05
Checked val at t_max vs t_i=  7.931444828492391e-09 0.99999999 0.00010156088732334243
Checked val at t_max vs t_i=  1.496138265417226e-08 0.99999999 0.0003398275990529135
Checked val at t_max vs t_i=  2.8222940961556806e-08 0.99999999 0.0010342647985298581
Checked val at t_max vs t_i=  5.324138356090966e-08 0.99999999 0.002865676560418642
Checked val at t_max vs t_i=  1.004401884444652e-07 0.99999999 0.0072361228361146526
Checked val at t_max vs t_i=  1.894855373477007e-07 0.99999999 0.01667358784844426
Checked val at t_max vs t_i=  3.574818087146861e-07 0.99999999 0.03511428844868349
Checked val at t_max vs t_i=  6.744335294570665e-07 0.99999999 0.0677200549946635

In [120]:


def _print_rr_outputs(outputs):
    """Print the seven tensors returned by rr_time_schedule, sigmas first."""
    *ratios_and_betas, sig = outputs
    print("SIGMAS =", sig)
    labels = ("rrFLOW =", "rrMSE  =", "rrML   =", "betaFLOW =", "betaMSE  =", "betaML   =")
    for label, tensor in zip(labels, ratios_and_betas):
        print(label, tensor)


# 32-point schedule; its first entries match the t_i values printed by the
# binary-search cell.
const_MSE = torch.tensor([
        0.00000031000, 0.00000152491, 0.00000680714, 0.00002758782,
        0.00010156089, 0.00033982760, 0.00103426480, 0.00286567656,
        0.00723612284, 0.01667358785, 0.03511428845, 0.06772005499,
        0.11988855043, 0.19541879341, 0.29437281786, 0.41168606231,
        0.53753868016, 0.65971283439, 0.76703717366, 0.85235079529,
        0.91371837718, 0.95366305078, 0.97719058493, 0.98973035991,
        0.99577814813, 0.99841748213, 0.99945975191, 0.99983218712,
        0.99995260825, 0.99998783983, 0.99999716669, 0.99999940081,
    ], dtype=dtype, device=noise.device)
rrFLOW, rrMSE, rrML, betaFLOW, betaMSE, betaML, sigmas = rr_time_schedule(32, const_MSE, dtype, noise.device)
_print_rr_outputs((rrFLOW, rrMSE, rrML, betaFLOW, betaMSE, betaML, sigmas))

# Only the first transition, evaluated on its own.
const_MSE = torch.tensor([0.00000031000, 0.00000152491], dtype=dtype, device=noise.device)
rrFLOW, rrMSE, rrML, betaFLOW, betaMSE, betaML, sigmas = rr_time_schedule(2, const_MSE, dtype, noise.device)
_print_rr_outputs((rrFLOW, rrMSE, rrML, betaFLOW, betaMSE, betaML, sigmas))


SIGMAS = tensor([97.99151314225, 71.33847477844, 51.93492614926, 37.80898319283,
        27.52520012213, 20.03853615917, 14.58819279989, 10.62030522791,
         7.73165563642,  5.62869878597,  4.09773168407,  2.98317697969,
         2.17177348858,  1.58106612982,  1.15102708270,  0.83795567729,
         0.61003752817,  0.44411154120,  0.32331627173,  0.23537648286,
         0.17135570680,  0.12474813919,  0.09081750874,  0.06611577452,
         0.04813274132,  0.03504096842,  0.02551006710,  0.01857150543,
         0.01352018424,  0.00984278823,  0.00716561583,  0.00521662183,
         0.00000000000], dtype=torch.float64)
rrFLOW = tensor([0.72800666599, 0.72800724028, 0.72800687314, 0.72800688613,
        0.72800691985, 0.72800691049, 0.72800691447, 0.72800691416,
        0.72800691736, 0.72800692307, 0.72800690960, 0.72800692127,
        0.72800692067, 0.72800691950, 0.72800691651, 0.72800691576,
        0.72800691875, 0.72800691208, 0.72800691905, 0.72800691351,
        0.7280069133