In [7]:
import torch
# Probe for Apple's Metal Performance Shaders (MPS) backend and, when it is
# present, allocate a one-element tensor on it as a smoke test.
if not torch.backends.mps.is_available():
    print("MPS device not found.")
else:
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)

tensor([1.], device='mps:0')


In [8]:
from runners.diffusion import *
import numpy as np
import math

In [13]:
def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings.

    This matches the implementation in Denoising Diffusion Probabilistic
    Models (from Fairseq / tensor2tensor): all sine features come first,
    followed by all cosine features. It differs slightly from the description
    in Section 3.5 of "Attention Is All You Need".

    Args:
        timesteps: 1-D tensor of timesteps, shape (N,). Cast to float
            internally.
        embedding_dim: width of the returned embedding.

    Returns:
        float32 tensor of shape (N, embedding_dim), created on
        ``timesteps.device``. When ``embedding_dim`` is odd, the last column
        is zero padding.

    Raises:
        AssertionError: if ``timesteps`` is not 1-D.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # With half_dim == 1 (embedding_dim of 2 or 3) the plain `half_dim - 1`
    # denominator is zero; clamp it so the single frequency is exp(0) == 1.
    log_step = math.log(10000) / max(half_dim - 1, 1)
    # Build the frequency table on the same device as the input so that
    # MPS/CUDA timesteps do not trigger a cross-device multiply.
    freqs = torch.exp(
        torch.arange(half_dim, dtype=torch.float32, device=timesteps.device)
        * -log_step
    )
    # (N, 1) * (1, half_dim) -> (N, half_dim) via broadcasting.
    angles = timesteps.float()[:, None] * freqs[None, :]
    emb = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad the last column
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb

# Sanity check: embed the integer timesteps 1..5 into 4 dimensions.
sample_timesteps = torch.arange(1, 6)
get_timestep_embedding(sample_timesteps, embedding_dim=4)

tensor([[ 8.4147e-01,  1.0000e-04,  5.4030e-01,  1.0000e+00],
        [ 9.0930e-01,  2.0000e-04, -4.1615e-01,  1.0000e+00],
        [ 1.4112e-01,  3.0000e-04, -9.8999e-01,  1.0000e+00],
        [-7.5680e-01,  4.0000e-04, -6.5364e-01,  1.0000e+00],
        [-9.5892e-01,  5.0000e-04,  2.8366e-01,  1.0000e+00]])

In [17]:
# Indexing with None (spelled np.newaxis here) inserts a length-1 axis at the
# position where it appears; this is the trick used to broadcast in
# get_timestep_embedding.
arr = np.array([[0, 1], [2, 3], [0, 5]])
print(arr[:, np.newaxis])  # shape (3, 1, 2)
print(arr[np.newaxis, :])  # shape (1, 3, 2)

[[[0 1]]

 [[2 3]]

 [[0 5]]]
[[[0 1]
  [2 3]
  [0 5]]]


In [42]:
# Build a small 10-step linear beta schedule and inspect the quantities
# derived from it. get_beta_schedule comes from runners.diffusion; its result
# is passed to torch.from_numpy below, so it is expected to be a NumPy array.
betas_np = get_beta_schedule(
    beta_schedule="linear",
    beta_start=0.02,
    beta_end=0.1,
    num_diffusion_timesteps=10,
)
print(betas_np.shape)

# `betas` is only ever bound to a tensor, so re-running later cells never sees
# the NumPy version under the same name.
betas = torch.from_numpy(betas_np)
num_timesteps = betas.shape[0]

alphas = 1.0 - betas
alphas_cumprod = alphas.cumprod(dim=0)
# Cumulative product shifted right by one step, with 1.0 in the first slot.
# The leading one is created with the schedule's dtype so the concatenation
# does not depend on implicit float32 -> float64 promotion.
alphas_cumprod_prev = torch.cat(
    [torch.ones(1, dtype=alphas_cumprod.dtype), alphas_cumprod[:-1]], dim=0
)

# First entry is exactly 0 because alphas_cumprod_prev[0] == 1.
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)

print(alphas)
print(alphas_cumprod)
print(alphas_cumprod_prev)
print("posterior_variance is ", posterior_variance)
print(betas.log())

(10,)
tensor([0.9800, 0.9711, 0.9622, 0.9533, 0.9444, 0.9356, 0.9267, 0.9178, 0.9089,
        0.9000], dtype=torch.float64)
tensor([0.9800, 0.9517, 0.9157, 0.8730, 0.8245, 0.7714, 0.7148, 0.6560, 0.5963,
        0.5366], dtype=torch.float64)
tensor([1.0000, 0.9800, 0.9517, 0.9157, 0.8730, 0.8245, 0.7714, 0.7148, 0.6560,
        0.5963], dtype=torch.float64)
posterior_variance is  tensor([0.0000, 0.0120, 0.0217, 0.0310, 0.0402, 0.0495, 0.0588, 0.0682, 0.0776,
        0.0871], dtype=torch.float64)
tensor([-3.9120, -3.5443, -3.2760, -3.0647, -2.8904, -2.7420, -2.6127, -2.4983,
        -2.3957, -2.3026], dtype=torch.float64)


In [53]:
def compute_alpha(beta, t, verbose=True):
    """
    Look up the cumulative product of (1 - beta) at the given timesteps.

    A zero is prepended to ``beta`` before the cumulative product is taken,
    so index ``t + 1`` of the padded product corresponds to timestep ``t``
    and ``t == -1`` selects the leading 1.0.

    Args:
        beta: 1-D tensor holding the beta schedule.
        t: 1-D integer tensor of timestep indices (each >= -1).
        verbose: when True (the default, matching the exploratory notebook
            output) print the intermediate tensors.

    Returns:
        Tensor of shape (len(t), 1, 1, 1) with the dtype and device of
        ``beta``.
    """
    # Create the padding element with beta's dtype and device so the
    # concatenation works for float64 and non-CPU (MPS/CUDA) schedules alike.
    leading_zero = torch.zeros(1, dtype=beta.dtype, device=beta.device)
    padded_beta = torch.cat([leading_zero, beta], dim=0)

    cumulative = (1 - padded_beta).cumprod(dim=0)
    if verbose:
        print(cumulative)

    selected = cumulative.index_select(0, t + 1)
    if verbose:
        print(selected)

    # Trailing singleton axes: one value per entry of t, shaped (len(t), 1, 1, 1).
    reshaped = selected.view(-1, 1, 1, 1)
    if verbose:
        print(reshaped)
    return reshaped

# Eight copies of timestep index 2, kept as floats here and cast to integers
# only at the call site because index_select needs an integer index tensor.
t = torch.full((8,), 2.0)
print("t is", t)
print("the original betas is ", betas)
al = compute_alpha(betas, t.long())
print("the al is ", al)

t is tensor([2., 2., 2., 2., 2., 2., 2., 2.])
the original betas is  tensor([0.0200, 0.0289, 0.0378, 0.0467, 0.0556, 0.0644, 0.0733, 0.0822, 0.0911,
        0.1000], dtype=torch.float64)
tensor([1.0000, 0.9800, 0.9517, 0.9157, 0.8730, 0.8245, 0.7714, 0.7148, 0.6560,
        0.5963, 0.5366], dtype=torch.float64)
tensor([0.9157, 0.9157, 0.9157, 0.9157, 0.9157, 0.9157, 0.9157, 0.9157],
       dtype=torch.float64)
tensor([[[[0.9157]]],


        [[[0.9157]]],


        [[[0.9157]]],


        [[[0.9157]]],


        [[[0.9157]]],


        [[[0.9157]]],


        [[[0.9157]]],


        [[[0.9157]]]], dtype=torch.float64)
the al is  tensor([[[[0.9157]]],


        [[[0.9157]]],


        [[[0.9157]]],


        [[[0.9157]]],


        [[[0.9157]]],


        [[[0.9157]]],


        [[[0.9157]]],


        [[[0.9157]]]], dtype=torch.float64)
