In [1]:
import torch
from tqdm import tqdm
import numpy as np
from torch import nn
from torch.nn import functional as F

In [2]:
def extract(v, t, x_shape):
    """
    Extract some coefficients at specified timesteps, then reshape to
    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    """
    out = torch.gather(v, dim=0, index=t).float()
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))

class ResBlock(nn.Module):
    def __init__(self, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            # nn.LazyBatchNorm1d(),
            nn.Linear(400, 400),
            nn.SiLU(),
        )
        self.block2 = nn.Sequential(
            # nn.LazyBatchNorm1d(),
            nn.Linear(400, 400),
            nn.SiLU(),
        )
        if attn:
            # self.attn = AttnBlock(out_ch)
            pass
        else:
            self.attn = nn.Identity()

    def forward(self, x):
        h = self.block1(x)
        h = self.block2(h)
        h = h + x
        h = self.attn(h)
        return h
class HyperbolicDiffusion(nn.Module):

    def __init__(self ,T = 1000,beta_1=1e-4, beta_T=0.02):
        super(HyperbolicDiffusion, self).__init__()

        self.denoise_net = nn.Sequential(
            nn.Linear(21, 400),
            ResBlock(),
            ResBlock(),
            ResBlock(),
            ResBlock(),
            nn.Linear(400, 400),
            nn.SiLU(),
            nn.Linear(400, 400),
            nn.SiLU(),
            nn.Linear(400, 20),
        )
        self.T = T

        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())

        alphas = 1. - self.betas
        self.register_buffer(
            'sqrt_alphas', torch.sqrt(alphas))
        alphas_bar = torch.cumprod(alphas, dim=0)
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))


    def forward(self,x):

        loss = self.compute_loss(x)

        return loss
    def sample(self,h):
        z_t = torch.randn_like(h)
        for t in reversed(range(self.T)):
            Time = torch.ones((h.size(0),),dtype=torch.int64, device=h.device) * t
            noise = torch.randn_like(h)
            pred_noise = self.denoise_net(torch.concat([Time[...,None,None].repeat(1,h.size(1),1),z_t],dim=2))
            sqrt_one_minus_alphas_bar = extract(self.sqrt_one_minus_alphas_bar, Time, h.shape)
            sqrt_alphas = extract(self.sqrt_alphas, Time, h.shape)
            betas = extract(self.betas, Time, h.shape)
            z_t = z_t-betas/sqrt_one_minus_alphas_bar * pred_noise
            z_t = z_t/sqrt_alphas+betas*noise
            print('t:',t,' z_t:',z_t[0])


    def compute_loss(self, h):
        t = torch.randint(self.T,size=(h.shape[0],), device=h.device)
        noise = torch.randn_like(h)
        x_t = (
                extract(self.sqrt_alphas_bar, t, h.shape) * h +
                extract(self.sqrt_one_minus_alphas_bar, t, h.shape) * noise)
        # if(t[0]>950):
        #     print(x_t[0])
        t = t[...,None,None].repeat(1,h.size(1),1)

        pred_noise = self.denoise_net(torch.concat([t,x_t],dim=2))
        loss = F.mse_loss(pred_noise, noise, reduction='mean')

        return loss

In [3]:
model = HyperbolicDiffusion()

optimizer = torch.optim.Adam(model.parameters(),0.0005)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=2000,
    gamma=float(0.9)
)
x = torch.ones((200,1,20),dtype=torch.float32) *3
# x = torch.tensor([[1.5794, 2.5078, 0.0000, 0.0000, 1.7158, 1.0340, 0.0000, 1.3385, 0.0000,
#         0.8585, 0.9584, 0.1076, 0.4893, 0.0000, 0.0000, 0.6972, 1.2082, 2.7626,
#         0.0000, 0.0000]],device='cuda').repeat(200,1)

for i in tqdm(range(10000)):
    optimizer.zero_grad()
    loss = model(x)
    if i%1000 ==0:
        print(loss,lr_scheduler.get_last_lr())
    loss.backward()
    optimizer.step()
    lr_scheduler.step()

  0%|          | 4/10000 [00:00<04:44, 35.08it/s]

tensor(29.1013, grad_fn=<MseLossBackward0>) [0.0005]


 10%|█         | 1005/10000 [00:34<04:44, 31.62it/s]

tensor(0.4409, grad_fn=<MseLossBackward0>) [0.0005]


 20%|██        | 2004/10000 [01:05<04:16, 31.19it/s]

tensor(0.1344, grad_fn=<MseLossBackward0>) [0.00045000000000000004]


 30%|███       | 3004/10000 [01:36<03:36, 32.38it/s]

tensor(0.0458, grad_fn=<MseLossBackward0>) [0.00045000000000000004]


 40%|████      | 4004/10000 [02:07<03:13, 30.99it/s]

tensor(0.0325, grad_fn=<MseLossBackward0>) [0.00040500000000000003]


 50%|█████     | 5004/10000 [02:38<02:33, 32.45it/s]

tensor(0.0195, grad_fn=<MseLossBackward0>) [0.00040500000000000003]


 60%|██████    | 6004/10000 [03:08<02:04, 32.21it/s]

tensor(0.0245, grad_fn=<MseLossBackward0>) [0.0003645]


 70%|███████   | 7006/10000 [03:39<01:27, 34.06it/s]

tensor(0.0211, grad_fn=<MseLossBackward0>) [0.0003645]


 80%|████████  | 8006/10000 [04:09<00:57, 34.69it/s]

tensor(0.0142, grad_fn=<MseLossBackward0>) [0.00032805000000000003]


 90%|█████████ | 9006/10000 [04:39<00:29, 34.23it/s]

tensor(0.0116, grad_fn=<MseLossBackward0>) [0.00032805000000000003]


100%|██████████| 10000/10000 [05:09<00:00, 32.30it/s]


In [4]:
model.sample(x)

t: 999  z_t: tensor([[-0.6249,  0.7678, -2.3885, -0.9662,  0.4536, -2.5512, -1.2631,  1.4147,
          0.5552,  0.9829,  0.2241, -0.8008, -0.6211,  0.4476, -1.5509, -0.5470,
         -1.0335,  0.0652, -0.1625,  0.4889]], grad_fn=<SelectBackward0>)
t: 998  z_t: tensor([[-0.5814,  0.7482, -2.3631, -0.9541,  0.4319, -2.5093, -1.2416,  1.3771,
          0.5566,  0.9797,  0.2426, -0.8164, -0.6293,  0.4373, -1.5410, -0.5960,
         -1.0115,  0.0732, -0.1630,  0.4835]], grad_fn=<SelectBackward0>)
t: 997  z_t: tensor([[-0.5558,  0.7089, -2.3532, -0.9626,  0.4372, -2.5053, -1.2375,  1.3546,
          0.5375,  0.9572,  0.2554, -0.7941, -0.6261,  0.4299, -1.4999, -0.5931,
         -0.9962,  0.0843, -0.1732,  0.4591]], grad_fn=<SelectBackward0>)
t: 996  z_t: tensor([[-0.5398,  0.6711, -2.3351, -0.9499,  0.4235, -2.4639, -1.2199,  1.3578,
          0.5228,  0.9585,  0.2477, -0.7827, -0.6622,  0.4086, -1.5021, -0.5688,
         -0.9928,  0.1075, -0.1758,  0.4810]], grad_fn=<SelectBackward0>)
t: 9

In [5]:
# def extract(v, t, x_shape):
#     """
#     Extract some coefficients at specified timesteps, then reshape to
#     [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
#     """
#     out = torch.gather(v, dim=0, index=t).float()
#     return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
#
#
# class HyperbolicDiffusion(nn.Module):
#
#     def __init__(self, T=1000, beta_1=1e-4, beta_T=0.02):
#         super(HyperbolicDiffusion, self).__init__()
#
#         self.denoise_net = nn.Sequential(
#             nn.Linear(31, 300),
#             nn.ReLU(),
#             nn.Linear(300, 300),
#             nn.ReLU(),
#             nn.ReLU(),
#             nn.Linear(300, 300),
#             nn.ReLU(),
#             nn.Linear(300, 30),
#         )
#         self.T = T
#
#         self.register_buffer(
#             'betas', torch.linspace(beta_1, beta_T, T).double())
#
#         alphas = 1. - self.betas
#         self.register_buffer(
#             'sqrt_alphas', torch.sqrt(alphas))
#         alphas_bar = torch.cumprod(alphas, dim=0)
#         self.register_buffer(
#             'sqrt_alphas_bar', torch.sqrt(alphas_bar))
#         self.register_buffer(
#             'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
#
#     def forward(self):
#         loss = self.compute_loss(torch.ones((20, 30), dtype=torch.float32))
#
#         return loss
#
#     def sample(self, h):
#         z_t = torch.randn_like(h)
#         for t in reversed(range(self.T)):
#             Time = torch.ones((h.size(0),), dtype=torch.int64) * t
#             noise = torch.randn_like(h)
#             pred_noise = self.denoise_net(torch.concat([Time[..., None], z_t], dim=1))
#             sqrt_one_minus_alphas_bar = extract(self.sqrt_one_minus_alphas_bar, Time, h.shape)
#             sqrt_alphas = extract(self.sqrt_alphas, Time, h.shape)
#             betas = extract(self.betas, Time, h.shape)
#             z_t = z_t - betas / sqrt_one_minus_alphas_bar * pred_noise
#             z_t = z_t / sqrt_alphas + betas * noise
#             print('t:', t, ' z_t:', z_t[0])
#
#     def compute_loss(self, h):
#         t = torch.randint(self.T, size=(h.shape[0],), device=h.device)
#         noise = torch.randn_like(h)
#         x_t = (
#                 extract(self.sqrt_alphas_bar, t, h.shape) * h +
#                 extract(self.sqrt_one_minus_alphas_bar, t, h.shape) * noise)
#         t = t[..., None]
#         pred_noise = self.denoise_net(torch.concat([t, x_t], dim=1))
#         loss = F.mse_loss(pred_noise, noise, reduction='mean')
#
#         return loss
#
#
# model = HyperbolicDiffusion()
#
# optimizer = torch.optim.Adam(model.parameters(), 0.001)
# lr_scheduler = torch.optim.lr_scheduler.StepLR(
#     optimizer,
#     step_size=2000,
#     gamma=float(0.9)
# )
# for i in range(10000):
#     optimizer.zero_grad()
#     loss = model().sum()
#     print(loss, lr_scheduler.get_last_lr())
#     loss.backward()
#     optimizer.step()
#     lr_scheduler.step()