In [2]:
import torch
import matplotlib.pyplot as plt
from classeg.extensions.unstable_diffusion.forward_diffusers.diffusers import LinearDiffuser

In [3]:
# Two random batches of 4 samples (3-channel and 1-channel, 128x128); only the
# shapes matter for the checks below.
xt_im = torch.randn(4, 3, 128, 128)
xt_seg = torch.randn(4, 1, 128, 128)

# Per-sample minimum: collapse every non-batch axis into one, then reduce over it.
# `min(dim=...)` returns a (values, indices) pair, hence `.values` below.
flat_im = xt_im.reshape(len(xt_im), -1)
mini = flat_im.min(dim=-1)
print(mini.values.shape)
print(xt_im.shape)

torch.Size([4])
torch.Size([4, 3, 128, 128])


In [39]:
import numpy as np
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
    """Build the array of beta values for a named schedule.

    Args:
        beta_schedule: Schedule name: "quad", "linear", "const", "jsd" or "sigmoid".
        beta_start: First beta value (not used by "const" or "jsd").
        beta_end: Last beta value (not used by "jsd").
        num_diffusion_timesteps: Number of entries in the returned array.

    Returns:
        1-D ``np.ndarray`` with ``num_diffusion_timesteps`` entries.

    Raises:
        NotImplementedError: If ``beta_schedule`` is not one of the known names.
    """
    n = num_diffusion_timesteps

    def _quad():
        # Interpolate linearly between the square roots, then square back.
        return np.linspace(beta_start ** 0.5, beta_end ** 0.5, n, dtype=np.float64) ** 2

    def _linear():
        return np.linspace(beta_start, beta_end, n, dtype=np.float64)

    def _const():
        return beta_end * np.ones(n, dtype=np.float64)

    def _jsd():
        # 1/T, 1/(T-1), 1/(T-2), ..., 1
        return 1.0 / np.linspace(n, 1, n, dtype=np.float64)

    def _sigmoid():
        # Logistic curve sampled on [-6, 6], rescaled into [beta_start, beta_end].
        grid = np.linspace(-6, 6, n)
        logistic = 1 / (np.exp(-grid) + 1)
        return logistic * (beta_end - beta_start) + beta_start

    builders = (
        ("quad", _quad),
        ("linear", _linear),
        ("const", _const),
        ("jsd", _jsd),
        ("sigmoid", _sigmoid),
    )

    # Builders are only invoked for the matching name, so an unused schedule's
    # arithmetic (e.g. the square roots in "quad") is never evaluated.
    for name, build in builders:
        if beta_schedule == name:
            betas = build()
            break
    else:
        raise NotImplementedError(beta_schedule)

    assert betas.shape == (num_diffusion_timesteps,)
    return betas

def compute_alpha(beta, t):
    """Look up the running product of ``1 - beta`` at the given timesteps.

    A zero is prepended to ``beta`` so that entry 0 of the running product is
    exactly 1. ``t`` is shifted by one to match, which makes ``t == -1`` a
    valid lookup that returns 1.

    Args:
        beta: 1-D tensor of beta values.
        t: 1-D integer index tensor (as required by ``index_select``); entries
            may range from -1 to ``len(beta) - 1``.

    Returns:
        Tensor of shape ``(len(t), 1, 1, 1)``.
    """
    leading_zero = torch.zeros(1, device=beta.device)
    padded = torch.cat((leading_zero, beta), dim=0)
    running_product = torch.cumprod(1 - padded, dim=0)
    picked = torch.index_select(running_product, 0, t + 1)
    return picked.view(-1, 1, 1, 1)

In [40]:
# Reference linear schedule from `get_beta_schedule`. The ndarray gets its own
# name so that `betas` only ever refers to the torch tensor used further down.
betas_np = get_beta_schedule(
    "linear",
    beta_start=0.0015,
    beta_end=0.0195,
    num_diffusion_timesteps=100,
)
print(betas_np.shape)
betas = torch.from_numpy(betas_np)
print(betas)

# The project's diffuser built with the same step count and beta range, printed
# for a side-by-side comparison with the reference above.
our = LinearDiffuser(100, 0.0015, 0.0195)
print(our._betas)

(100,)
tensor([0.0015, 0.0017, 0.0019, 0.0020, 0.0022, 0.0024, 0.0026, 0.0028, 0.0030,
        0.0031, 0.0033, 0.0035, 0.0037, 0.0039, 0.0040, 0.0042, 0.0044, 0.0046,
        0.0048, 0.0050, 0.0051, 0.0053, 0.0055, 0.0057, 0.0059, 0.0060, 0.0062,
        0.0064, 0.0066, 0.0068, 0.0070, 0.0071, 0.0073, 0.0075, 0.0077, 0.0079,
        0.0080, 0.0082, 0.0084, 0.0086, 0.0088, 0.0090, 0.0091, 0.0093, 0.0095,
        0.0097, 0.0099, 0.0100, 0.0102, 0.0104, 0.0106, 0.0108, 0.0110, 0.0111,
        0.0113, 0.0115, 0.0117, 0.0119, 0.0120, 0.0122, 0.0124, 0.0126, 0.0128,
        0.0130, 0.0131, 0.0133, 0.0135, 0.0137, 0.0139, 0.0140, 0.0142, 0.0144,
        0.0146, 0.0148, 0.0150, 0.0151, 0.0153, 0.0155, 0.0157, 0.0159, 0.0160,
        0.0162, 0.0164, 0.0166, 0.0168, 0.0170, 0.0171, 0.0173, 0.0175, 0.0177,
        0.0179, 0.0180, 0.0182, 0.0184, 0.0186, 0.0188, 0.0190, 0.0191, 0.0193,
        0.0195], dtype=torch.float64)
tensor([0.0015, 0.0017, 0.0019, 0.0020, 0.0022, 0.0024, 0.0026, 0.0028, 0.0

In [41]:
# Strided subset of the schedule: keep every `skip`-th step.
timesteps, num_timesteps = 100, 20
skip = timesteps // num_timesteps
seq = range(0, timesteps, skip)

# `seq_next[k]` is the step preceding `seq[k]` in the strided sequence;
# -1 is the sentinel placed before step 0.
seq_next = [-1, *seq[:-1]]

print(seq)
print(seq_next)

range(0, 100, 5)
[-1, 0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70, 75, 80, 85, 90]


In [55]:
# Walk the strided schedule backwards and print, side by side, two ways of
# obtaining the cumulative alpha products: `compute_alpha` on the reference
# betas versus a direct lookup into the diffuser's `_alpha_bars`.
#
# A trailing 1 is appended so that index -1 (the sentinel in `seq_next`)
# resolves to 1 through ordinary negative indexing. The padded table does not
# depend on the loop variables, so it is built once, and `new_ones` makes the
# appended element inherit the dtype/device of `_alpha_bars`.
alpha_bars = torch.cat([our._alpha_bars, our._alpha_bars.new_ones(1)], dim=0)

for i, j in zip(reversed(seq), reversed(seq_next)):
    print(i, j)
    t = torch.ones(1) * i
    next_t = torch.ones(1) * j

    # Reference values computed from the betas.
    at = compute_alpha(betas, t.long())
    at_next = compute_alpha(betas, next_t.long())
    print(our._alpha_bars.shape)
    print(alpha_bars.shape)

    # Same quantities read straight from the diffuser's table.
    att = alpha_bars[t.long()]
    att_next = alpha_bars[next_t.long()]
    print(at, att)
    print(at_next, att_next)

    # One minus the ratio of the two looked-up alpha products; not used further
    # in this cell.
    beta_t = 1 - att / att_next



95 90
torch.Size([100])
torch.Size([101])
tensor([[[[0.3756]]]], dtype=torch.float64) tensor([0.3756])
tensor([[[[0.4121]]]], dtype=torch.float64) tensor([0.4121])
90 85
torch.Size([100])
torch.Size([101])
tensor([[[[0.4121]]]], dtype=torch.float64) tensor([0.4121])
tensor([[[[0.4502]]]], dtype=torch.float64) tensor([0.4502])
85 80
torch.Size([100])
torch.Size([101])
tensor([[[[0.4502]]]], dtype=torch.float64) tensor([0.4502])
tensor([[[[0.4895]]]], dtype=torch.float64) tensor([0.4895])
80 75
torch.Size([100])
torch.Size([101])
tensor([[[[0.4895]]]], dtype=torch.float64) tensor([0.4895])
tensor([[[[0.5297]]]], dtype=torch.float64) tensor([0.5297])
75 70
torch.Size([100])
torch.Size([101])
tensor([[[[0.5297]]]], dtype=torch.float64) tensor([0.5297])
tensor([[[[0.5706]]]], dtype=torch.float64) tensor([0.5706])
70 65
torch.Size([100])
torch.Size([101])
tensor([[[[0.5706]]]], dtype=torch.float64) tensor([0.5706])
tensor([[[[0.6119]]]], dtype=torch.float64) tensor([0.6119])
65 60
torch.Size