In [1]:
import numpy as np
from numpy.testing import assert_almost_equal, assert_equal
import scipy.fftpack as fft
import matplotlib.pyplot as plt
import math

from scipy.stats import norm, chi
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import chi
from scipy.optimize import curve_fit

np.random.seed(123)

# Scaling

```
FFHQ_128
2024-10-16 10:29:04 root         INFO     [LBM_Base_Corruptor.py:137] Preprocessing (lbm) 0
2024-10-16 10:30:17 root         INFO     [LBM_Base_Corruptor.py:137] Preprocessing (lbm) 1000
On an A100 GPU it takes ~1 min to process 1000 images at 128x128 resolution (steps=uniform_rand(0,100)), thus the dataset with 70 000 images is processed in ~1h 10min.

python train_corrupted.py --config configs/ffhq/res_256/ffhq_256_lbm_ns_config_lin_visc.py
2024-10-20 21:09:15 root         INFO     [LBM_Base_Corruptor.py:116] Preprocessing (lbm) 0
2024-10-20 21:11:11 root         INFO     [LBM_Base_Corruptor.py:116] Preprocessing (lbm) 1000

On an A100 GPU it takes ~2 min to process 1000 images at 256x256 resolution (steps=uniform_rand(0,100)), thus the dataset with 70 000 images is processed in ~2h 20min.


python train_corrupted.py --config configs/ffhq/res_512/ffhq_512_lbm_ns_config_lin_visc.py
2024-10-20 21:07:30 root         INFO     [LBM_Base_Corruptor.py:116] Preprocessing (lbm) 0
2024-10-20 21:10:09 root         INFO     [LBM_Base_Corruptor.py:116] Preprocessing (lbm) 1000


On an A100 GPU it takes ~2.5 min to process 1000 images at 512x512 resolution (steps=uniform_rand(0,100)), thus the dataset with 70 000 images is processed in ~3h.


python train_corrupted.py --config configs/ffhq/res_1024/ffhq_1024_lbm_ns_config_lin_visc.py
2024-10-20 21:05:58 root         INFO     [LBM_Base_Corruptor.py:116] Preprocessing (lbm) 0
2024-10-20 21:11:39 root         INFO     [LBM_Base_Corruptor.py:116] Preprocessing (lbm) 1000

On an A100 GPU it takes ~6 min to process 1000 images at 1024x1024 resolution (steps=uniform_rand(0,100)), thus the dataset with 70 000 images is processed in ~7h.
```

In [3]:
# Image side length in pixels, and the time to preprocess the whole 70k-image
# dataset at that resolution. The raw figures are divided by 60 and plotted on
# an axis labelled in hours, so they are presumably minutes.
img_resolution = np.array([128, 256, 512, 1024])
time = np.array([85, 135, 185, 393]) / 60

fig = plt.figure(figsize=(6, 6))
# [left, bottom, width, height] as fractions of the figure size
ax = fig.add_axes([0.2, 0.2, 0.7, 0.7])

ax.plot(img_resolution, time, 'rx', label='Time to process 70k dataset')
ax.grid(True, which="both", ls="--")
ax.set_xlabel("img size")
ax.set_ylabel("time [h]")
ax.set_xticks(img_resolution)  # one tick per measured resolution

ax.legend()
plt.show()

# Schedulers

In [4]:

def exp_schedule(min_value, max_value, n, dtype=float):
    """Return `n` geometrically spaced values from `min_value` to `max_value`.

    The grid is uniform in log-space, exponentiated, and finally cast to
    `dtype` (an integer dtype truncates the fractional part).
    """
    log_lo = np.log(min_value)
    log_hi = np.log(max_value)
    log_grid = np.linspace(log_lo, log_hi, n)
    return np.exp(log_grid).astype(dtype)

def lin_schedule(min_value, max_value, n, dtype=float):
    """Return `n` evenly spaced values from `min_value` to `max_value`, cast to `dtype`."""
    grid = np.linspace(min_value, max_value, n)
    return grid.astype(dtype)

def cosine_beta_schedule(min_value, max_value, n, s=0.008, dtype=float):
    """
    Rescaled cosine schedule as proposed in https://arxiv.org/abs/2102.09672

    Builds a squared-cosine curve on `n` grid points, normalised so that its
    first entry is 1, and derives per-step betas from consecutive ratios.
    Both arrays are then mapped affinely by `value * (max_value - min_value)
    + min_value` and cast to `dtype`.

    Returns:
        (betas_scaled, alphas_scaled): the rescaled betas (`n - 1` entries,
        because they are built from consecutive pairs) and the rescaled
        cumulative alphas (`n` entries).
    """
    span = max_value - min_value

    grid = np.linspace(0, n, n)
    phase = ((grid / n) + s) / (1 + s) * np.pi * 0.5
    alphas_cumprod = np.cos(phase) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]

    # beta_t = 1 - alpha_bar_t / alpha_bar_{t-1}
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])

    # Affine rescale of both curves into [min_value, max_value]
    betas_scaled = betas * span + min_value
    alphas_scaled = alphas_cumprod * span + min_value

    return betas_scaled.astype(dtype), alphas_scaled.astype(dtype)

def inv_cosine_aplha_schedule(min_value, max_value, n, s=0.008, dtype=float):
    """
    Inspired by the schedule proposed in https://arxiv.org/abs/2102.09672

    Evaluates the normalised squared-cosine curve on `n` grid points, maps it
    affinely by `value * (max_value - min_value) + min_value`, reverses the
    order of the entries, and casts the result to `dtype`.
    """
    grid = np.linspace(0, n, n)
    phase = ((grid / n) + s) / (1 + s) * np.pi * 0.5
    alphas_cumprod = np.cos(phase) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]

    # Rescale the cumulative alphas, then reverse them
    alphas_inv_scaled = alphas_cumprod * (max_value - min_value) + min_value
    return alphas_inv_scaled[::-1].astype(dtype)

def tanh_schedule(min_value, max_value, n, steepness = 0.005, dtype=float):
    """Return an S-shaped schedule of `n` values built from tanh.

    tanh is sampled on a fixed grid spanning [-500, 500], shifted and halved
    into the (0, 1) range, mapped affinely by
    `value * (max_value - min_value) + min_value`, and cast to `dtype`.
    `steepness` scales the tanh argument.
    """
    grid = np.linspace(-500, 500, n)
    unit_curve = (np.tanh(steepness * grid) + 1) / 2
    scaled = unit_curve * (max_value - min_value) + min_value
    return scaled.astype(dtype)

def log_schedule(min_value, max_value, n, log_base=2.0, dtype=int):
    """Return `n` values spaced evenly in log-space between two bounds.

    The exponents of `min_value` and `max_value` in base `log_base` are
    computed, `np.logspace` interpolates between them, and the result is cast
    to `dtype`.

    Round-off in `log_base ** log(max_value)` can land just below
    `max_value`, and an integer cast would then truncate the final entry to
    `max_value - 1`. Nudging the exponent by a fixed multiple of machine
    epsilon is unreliable: for exponents >= 4 the nudge is smaller than the
    float spacing and may be rounded away. The end points are therefore set
    to the requested bounds explicitly.

    Args:
        min_value: First value of the schedule; must be > 0.
        max_value: Last value of the schedule; must be > 0.
        n: Number of points.
        log_base: Base of the logarithmic spacing.
        dtype: Output dtype (an integer dtype truncates the fractional part).

    Returns:
        NumPy array with `n` entries of type `dtype`.

    Raises:
        ValueError: If `min_value` or `max_value` is not strictly positive.
    """
    if min_value <= 0 or max_value <= 0:
        raise ValueError("log_schedule requires min_value > 0 and max_value > 0")

    start_pow = np.emath.logn(log_base, min_value)
    final_pow = np.emath.logn(log_base, max_value)

    schedule = np.logspace(start_pow, final_pow, n, base=log_base)

    # Pin the end points so the cast below cannot truncate them off by one.
    if schedule.size > 0:
        schedule[0] = min_value
    if schedule.size > 1:
        schedule[-1] = max_value

    return schedule.astype(dtype)

### setup

In [5]:
# Range of solver steps covered by every schedule, and the schedule length.
min_solver_steps = 1
max_solver_steps = 250
n_elements = 100  # aka denoising steps

# All schedules are cast to int so they can be compared on equal footing.
y_lin_sched = lin_schedule(min_solver_steps, max_solver_steps, n_elements, dtype=int)
y_exp_sched = exp_schedule(min_solver_steps, max_solver_steps, n_elements, dtype=int)
y_log_sched = log_schedule(
    min_solver_steps, max_solver_steps, n_elements,
    log_base=10.0, dtype=int,
)

y_tanh_sched = tanh_schedule(
    min_solver_steps, max_solver_steps, n_elements, dtype=int)
y_inv_cos_alpha_shed = inv_cosine_aplha_schedule(
    min_solver_steps, max_solver_steps, n_elements, dtype=int)

In [6]:
fig = plt.figure(figsize=(14, 8))

# Axes with room left for the title and labels;
# [left, bottom, width, height] as fractions of the figure size
ax = fig.add_axes([0.2, 0.2, 0.7, 0.7])

# (marker, label, values) for each schedule; every schedule is drawn twice:
# as-is (green) and with duplicate values removed (blue).
schedules_to_plot = [
    # ('<', "y_lin_sched", y_lin_sched),
    ('x', "y_exp_sched", y_exp_sched),
    ('P', "y_log_sched", y_log_sched),
    ('X', "y_tanh_sched", y_tanh_sched),
    ('>', "y_inv_cos_alpha_shed", y_inv_cos_alpha_shed),
]
for marker, sched_label, sched_values in schedules_to_plot:
    ax.plot(sched_values, marker, label=sched_label, color='green')
    ax.plot(np.unique(sched_values), marker, label=f"unique {sched_label}", color='blue')

# Grid and labels
ax.grid(True, which="both", ls="--")
ax.set_xlabel("input time")
ax.set_ylabel("scheduler")

ax.legend()


In [7]:
# Number of distinct values in the exp and log schedules. Both counts go in
# one tuple because a notebook cell only displays its last expression; as two
# separate statements the first count was silently discarded.
len(np.unique(y_exp_sched)), len(np.unique(y_log_sched))

## some other ideas

which are **not** supercool

In [8]:
from functools import reduce


def fibonacci_gen(n):
    """Return the first `n` Fibonacci numbers (0, 1, 1, 2, ...) as a list.

    For `n < 2` the list is shortened accordingly (`[]` for `n <= 0`,
    `[0]` for `n == 1`) instead of always containing the two seed values.
    """
    sequence = [0, 1]
    while len(sequence) < n:
        sequence.append(sequence[-1] + sequence[-2])
    return sequence[:max(n, 0)]


def triangular_numbers_gen(n):
    """Return the n-th triangular number n*(n+1)/2 (as a float, due to `/`)."""
    return n*(n+1)/2


fibonacci_numbers = fibonacci_gen(18)
triangular_numbers = np.array([triangular_numbers_gen(k) for k in range(50)]).astype(int)

print(fibonacci_numbers)
print(triangular_numbers)

fig = plt.figure(figsize=(8, 8))

# Axes with room left for the title and labels;
# [left, bottom, width, height] as fractions of the figure size
ax = fig.add_axes([0.2, 0.2, 0.7, 0.7])

# Both integer sequences against their index
ax.plot(fibonacci_numbers, 'gx', label='fibonacci_numbers')
ax.plot(triangular_numbers, 'bx', label='triangular_numbers')

# Grid and labels
ax.grid(True, which="both", ls="--")
ax.set_xlabel("input time")
ax.set_ylabel("tweaked time")

ax.legend()

In [9]:
K = 100
blur_sigma_max = 200
blur_sigma_min = 1

# K points uniform in log(sigma), exponentiated back to sigma
x = np.linspace(np.log(blur_sigma_min), np.log(blur_sigma_max), num=K)
tweaked_time0 = np.exp(x)

K_range = np.arange(K)

In [10]:
def integerize_array(N, target_len=25):
  """Return the sorted, de-duplicated values ceil(exp(k)) for k = 0..N-1,
  adjusting N until exactly `target_len` values remain.

  The previous version called `int(arr)` on the whole array inside the
  adjustment loop, which raises TypeError for any array with more than one
  element, so only inputs that already produced the target length worked.

  Args:
    N: Initial number of exponents to try.
    target_len: Required number of distinct values (default 25, the
      previously hard-coded length). Must be non-negative. NOTE(review):
      assumed small enough that exp() does not overflow float64 -- confirm
      if large targets are needed.

  Returns:
    1-D float NumPy array with `target_len` ascending, distinct values.

  Raises:
    ValueError: If `target_len` is negative.
  """
  if target_len < 0:
    raise ValueError("target_len must be non-negative")

  def _unique_ceil_exp(count):
    # np.unique returns its result sorted, so no separate sort is needed.
    return np.unique(np.ceil(np.exp(np.arange(count))))

  arr = _unique_ceil_exp(N)

  # Adjust N until the requested number of elements is reached
  while len(arr) != target_len:
    N += 1 if len(arr) < target_len else -1
    arr = _unique_ceil_exp(N)

  return arr


def integerize_array2(arr):
  """Round every entry of `arr` up and return the distinct values, ascending.

  The result keeps the floating-point dtype produced by np.ceil.
  """
  rounded_up = np.ceil(arr)
  # np.unique both sorts and removes duplicates
  return np.unique(rounded_up)


# v1: built from scratch, starting the search with N=25
tweaked_time1 = integerize_array(25)
# v2: derived from the reference exponential schedule
tweaked_time2 = integerize_array2(tweaked_time0)

for shown_schedule in (tweaked_time0, tweaked_time1.astype(int), tweaked_time2):
    print(shown_schedule)

len(tweaked_time2)

In [11]:
# Print the version of the `python` executable found on the shell PATH.
# NOTE(review): this may differ from the interpreter the kernel runs on -- confirm.
!python --version

In [12]:
fig = plt.figure(figsize=(8, 8))

# Axes with room left for the title and labels;
# [left, bottom, width, height] as fractions of the figure size
ax = fig.add_axes([0.2, 0.2, 0.7, 0.7])

# Reference exponential schedule vs. its integerized, de-duplicated variant
ax.plot(tweaked_time0, 'rx', label='reference scheduler')
ax.plot(tweaked_time2, 'gx', label='scheduler v2')

# Grid and labels
ax.grid(True, which="both", ls="--")
ax.set_xlabel("input time")
ax.set_ylabel("tweaked time")

ax.legend()

# Time scheduler


In [13]:
# Configuration for the time-scheduler experiments below.
# Disabled alternative (256x256):
# #ffhq128 - estimate
# blur_sigma_max = 128 # blur_sigma_max=128 for 256x256 ffhq in default ihd config --> Fo = 0.001953125
# L = 256

# NOTE(review): with the sigma / (L*L) definition used by calc_Fo in this
# notebook, sigma=16 and L=128 give Fo = 16/16384 ~= 9.77e-4; the value
# 0.001953125 quoted in the comment corresponds to sigma=32 -- confirm which
# of the two is intended.
blur_sigma_max = 16 # blur_sigma_max=32 for 128x128 ffhq config --> same Fo = 0.001953125
L = 128  # image side length in pixels (presumably; matches the 128x128 comment above)

final_lbm_step = 500  # upper bound passed to exp_schedule for corrupt_sched
max_fwd_steps = 200 # max_fwd_steps = 200 in default ihd config for 256x256 ffhq

def exp_schedule(min_value, max_value, n, dtype=np.float32):
    """Return `n` geometrically spaced values from `min_value` to `max_value`.

    NOTE: this redefinition shadows the `exp_schedule` declared earlier in the
    notebook; the only difference is the default `dtype` (np.float32 here).
    """
    log_lo = np.log(min_value)
    log_hi = np.log(max_value)
    log_grid = np.linspace(log_lo, log_hi, n)
    return np.exp(log_grid).astype(dtype)

# Blur (sigma) schedule, one entry per forward step.
ihd_blur_schedule = exp_schedule(0.5, blur_sigma_max, max_fwd_steps)

# Integer schedule of LBM steps, one entry per forward step.
corrupt_sched = exp_schedule(1, final_lbm_step, max_fwd_steps, dtype=int)

# Constant diffusivity of 1/6, one entry per forward step. It was previously
# sized with `max_solver_steps`, a value left over from the scheduler
# comparison cell, which made it longer than the schedules it is paired with.
niu_sched = np.full(max_fwd_steps, 1/6, dtype=np.float32)

niu0 = niu_sched[0]

In [14]:
def calc_Fo(sigma, L):
    """Return Fo for a blur `sigma` and image size `L`, computed as sigma / L**2."""
    return sigma / (L * L)

# Fo that corresponds to the largest blur in the schedule
Fo = calc_Fo(sigma=blur_sigma_max, L=L)
# print(f"Fo = {Fo}")

def get_timesteps_from_sigma(diffusivity, sigma):
    """Return the number of timesteps needed to reach blur `sigma`.

    Inverts sigma = sqrt(2 * diffusivity * tc) for tc and truncates to int.
    """
    return int(sigma * sigma / (2 * diffusivity))

# Timesteps needed to reach the largest blur at the initial diffusivity
lbm_iter = get_timesteps_from_sigma(diffusivity=niu0, sigma=blur_sigma_max)
# print(f"lbm_iter = {lbm_iter}")

def get_sigma_from_Fo(Fo, L):
    """Return the blur sigma for a given Fo and image size `L` (sigma = Fo * L**2)."""
    return Fo * L * L

# Round trip: sigma -> Fo -> sigma must give back the original blur
sigma_roundtrip = get_sigma_from_Fo(Fo, L)
assert_almost_equal(sigma_roundtrip, blur_sigma_max)

print(f"Fo={Fo}\n sigma={sigma_roundtrip} \n"
      f" lbm_iter={lbm_iter}\n L ={L}")
############################################################

def get_timesteps_from_Fo_niu_L(Fo, diffusivity, L):
    """Return the number of timesteps that corresponds to `Fo`.

    Fo is first converted to a blur (sigma = Fo * L**2); then
    sigma = sqrt(2 * diffusivity * tc) is inverted for tc and truncated to int.
    """
    sigma = Fo * L * L
    return int(sigma * sigma / (2 * diffusivity))

# Both conversion routes (via Fo and directly via sigma) must agree
steps_via_Fo = get_timesteps_from_Fo_niu_L(Fo, niu0, L)
steps_via_sigma = get_timesteps_from_sigma(niu0, blur_sigma_max)
assert_equal(steps_via_Fo, steps_via_sigma)

def recalculate_blur_schedule(blur_schedule, niu_sched, L):
    """Recalculates the blur schedule from sigmas to timesteps.

    A table with one row per schedule entry is printed as a side effect.

    Args:
        blur_schedule: Iterable of sigmas.
        niu_sched: Indexable diffusivity schedule; entry `i` is used for the
            i-th sigma, so it must be at least as long as `blur_schedule`.
        L: The size of the image.

    Returns:
        A tuple `(timesteps, Fo)` of NumPy arrays. `timesteps` has one entry
        per sigma; `Fo` starts with 0.0 and is therefore one entry longer.
    """
    timesteps_list = []
    Fo_list = [0.]

    print("iter \t\t Fo \t sigma \t\t timesteps")
    for step, sigma in enumerate(blur_schedule):
        Fo = calc_Fo(sigma, L)
        # TODO: this is messy
        # niu has a schedule
        # dFo = Fo_n - Fo_n_1
        # dt_lbm = get_timesteps_from_Fo_niu_L(dFo, niu_sched[step], L)

        timesteps = get_timesteps_from_Fo_niu_L(Fo, niu_sched[step], L)
        print(f"{step} \t\t {Fo:.2e} \t {sigma:.2f} \t\t {timesteps}")
        Fo_list.append(Fo)
        timesteps_list.append(timesteps)

    return np.array(timesteps_list), np.array(Fo_list)


def recalculate_blur_schedule_v2_WIP(blur_schedule, niu_sched, L):
    """Work-in-progress variant of `recalculate_blur_schedule`.

    Its body was a verbatim copy of `recalculate_blur_schedule`, so it now
    delegates to that function until the planned changes are implemented.

    Args:
        blur_schedule: Iterable of sigmas.
        niu_sched: Indexable diffusivity schedule (entry `i` pairs with the
            i-th sigma).
        L: The size of the image.

    Returns:
        Whatever `recalculate_blur_schedule` returns for the same arguments.
    """
    # TODO: this is messy
    # niu has a schedule
    # dFo = Fo_n - Fo_n_1
    # dt_lbm = get_timesteps_from_Fo_niu_L(dFo, niu_sched[iter], L)
    return recalculate_blur_schedule(blur_schedule, niu_sched, L)


def get_Fo_from_tc_niu_L(tc, diffusivity, L):
    """Return Fo reached after `tc` timesteps at the given diffusivity.

    Uses sigma = sqrt(2 * diffusivity * tc) and Fo = sigma / L**2.
    """
    blur_sigma = np.sqrt(2. * diffusivity * tc)
    return blur_sigma / (L * L)

def calc_Fo_schedule(dtc_sched, niu_sched, L):
    """Accumulate Fo over a schedule of timestep increments.

    Each `(dtc, niu)` pair is converted to an Fo increment, which is added to
    the running total. A table with one row per pair is printed as a side
    effect.

    Returns:
        NumPy array of cumulative Fo values, starting with 0.0 (one entry
        more than the number of processed pairs).
    """
    Fo_list = [0.]

    print("iter \t\t Fo \t dtc \t\t niu")
    for step, (dtc, niu) in enumerate(zip(dtc_sched, niu_sched), start=1):
        Fo = Fo_list[-1] + get_Fo_from_tc_niu_L(dtc, niu, L)
        Fo_list.append(Fo)
        print(f"{step} \t\t {Fo:.2e} \t {dtc:.2f} \t\t {niu:.2e}")

    return np.array(Fo_list)



In [15]:
lbm_ihd_timesteps_schedule, ihd_Fo_schedule = recalculate_blur_schedule(
    ihd_blur_schedule, niu_sched, L)

# ---- Figure 1: LBM timesteps per schedule entry ----
fig = plt.figure(figsize=(8, 8))
# [left, bottom, width, height] as fractions of the figure size
ax = fig.add_axes([0.2, 0.2, 0.7, 0.7])

ax.plot(lbm_ihd_timesteps_schedule, 'rx', label='lbm_ihd_timesteps_schedule')
ax.plot(np.unique(lbm_ihd_timesteps_schedule), 'gx', label='unique lbm_ihd_timesteps_schedule')

ax.grid(True, which="both", ls="--")
ax.set_xlabel("input time")
ax.set_ylabel("lbm steps")

ax.legend()
# plt.close()


# ---- Figure 2: Fo per schedule entry ----
fig = plt.figure(figsize=(8, 8))
# [left, bottom, width, height] as fractions of the figure size
ax = fig.add_axes([0.2, 0.2, 0.7, 0.7])

ax.plot(ihd_Fo_schedule, 'r<', label='ihd_Fo_schedule')
ax.plot(np.unique(ihd_Fo_schedule), 'gx', label='unique ihd_Fo_schedule')

ax.grid(True, which="both", ls="--")
ax.set_xlabel("input time")
ax.set_ylabel("Fo")

ax.legend()
# plt.close()

In [16]:
# corrupt_sched = exp_schedule(1, final_lbm_step, max_fwd_steps, dtype=int)

# Per-step increments of LBM steps, and their running total with a leading 0
dcorrupt_sched = np.linspace(1, 10, max_solver_steps, dtype=int)
corrupt_sched = np.array([0., *np.cumsum(dcorrupt_sched)])  # scan add

# niu_sched = np.linspace(1E-3*1/6, 1/6, max_solver_steps).astype(np.float32)
niu_sched = exp_schedule(1E-4 / 6, 1 / 6, max_solver_steps).astype(np.float32)

Fo_schedule = calc_Fo_schedule(dcorrupt_sched, niu_sched, L)
# Fo_schedule_unique = calc_Fo_schedule(np.unique(corrupt_sched), niu_sched, L)

fig = plt.figure(figsize=(8, 8))
# [left, bottom, width, height] as fractions of the figure size
ax = fig.add_axes([0.2, 0.2, 0.7, 0.7])

# Cumulative Fo against cumulative LBM steps
ax.plot(corrupt_sched, Fo_schedule, 'rx', label='Fo_schedule')
# plt.plot(np.unique(Fo_schedule_unique), 'gx', label=f'Fo_schedule_unique')

ax.grid(True, which="both", ls="--")
ax.set_xlabel("lbm steps")
ax.set_ylabel("Fo")

ax.legend()



In [17]:
import numpy as np

def calculate_t_niu_array(Fo, niu_min, niu_max, L, max_dt=100):
  """Calculates a timestep increment `dt` and a diffusivity `niu` for each
  consecutive pair in `Fo`, knowing that dFo = np.sqrt(2*dt*niu)/(L*L).

  For every increment dFo = Fo[i] - Fo[i-1], `dt` is increased in steps of 1
  until the implied `niu` is not larger than `niu_max`. `dt` is never reset
  between increments, so the returned `dt` values are non-decreasing.

  If the implied `niu` is below `niu_min`, it is still reported in the
  `niu` array, while the "realizable" arrays report `niu_min` and the dFo that
  `niu_min` actually produces.

  Note: when `dt` exceeds `max_dt` without a valid `niu` being found, the
  search for that increment is abandoned and NO entry is appended for it, so
  the returned arrays can be shorter than `len(Fo) - 1`.

  Args:
    Fo: The NumPy array of `Fo` values.
    niu_min: The minimum value of `niu`.
    niu_max: The maximum value of `niu`.
    L: The size of the image.
    max_dt: Largest `dt` tried before an increment is given up (default 100,
      the previously hard-coded limit).

  Returns:
    Four NumPy arrays: `dt` values (stored as floats), the implied `niu`
    values, the realizable `niu` values (raised to `niu_min` where needed),
    and the realizable dFo values.
  """

  dt_values = []
  niu_values = []

  realizable_dFo = []
  niu_realizable_values = []

  dt = 1.
  for i in range(1, len(Fo)):
    dFo = Fo[i] - Fo[i-1]

    while True:
      # Invert dFo = sqrt(2*dt*niu)/(L*L) for niu at the current dt
      niu = ((L*L*dFo) ** 2) / (2 * dt)
      if niu <= niu_max:
        dt_values.append(dt)
        niu_values.append(niu)

        # Raise niu to niu_min where needed and report the dFo it yields
        niu_realizable = niu_min if niu < niu_min else niu
        niu_realizable_values.append(niu_realizable)
        realizable_dFo.append(np.sqrt(2 * dt * niu_realizable) / (L*L))
        break

      print(f"niu out of range at i= {i} \t dt={dt} niu={niu:.2e} \t dFo={dFo:.2e}, Fo = {Fo[i]:.2e}")
      dt += 1

      if dt > max_dt:
        break

  return np.array(dt_values), np.array(niu_values), np.array(niu_realizable_values), np.array(realizable_dFo)