In [1]:
import torch
from tqdm import tqdm
import numpy as np
from torch import nn
from torch.nn import functional as F

In [2]:
# Column [0, 1] shaped (1, 2, 1), then broadcast along a leading axis of
# size 3 -> result has shape (3, 2, 1).
ar = torch.arange(2).reshape(1, -1, 1).expand(3, -1, 1)
ar

tensor([[[0],
         [1]],

        [[0],
         [1]],

        [[0],
         [1]]])

In [3]:
def extract(v, t, x_shape):
    """
    Look up per-sample coefficients and shape them for broadcasting.

    Gathers ``v[t[i]]`` along dim 0 for every index in ``t``, casts the result
    to float32, and views it as ``[t.shape[0], 1, ..., 1]`` so that it has the
    same number of dimensions as ``x_shape``.
    """
    batch_size = t.shape[0]
    trailing_ones = [1] * (len(x_shape) - 1)
    coeffs = v.gather(0, t).float()
    return coeffs.view([batch_size] + trailing_ones)

class ResBlock(nn.Module):
    """Residual MLP block: two ``Linear -> SiLU`` stages plus a skip connection.

    Computes ``attn(block2(block1(x)) + x)`` where ``attn`` is currently always
    the identity.

    Args:
        attn: Reserved flag for an attention layer applied after the residual
            sum. No attention block is defined in this file, so passing a
            truthy value is rejected up front instead of producing a module
            whose ``forward`` fails later on a missing ``self.attn``.
        dim: Width of both linear layers; the last dimension of the input must
            equal ``dim``. Defaults to 400, the width used previously.

    Raises:
        NotImplementedError: if ``attn`` is truthy.
    """

    def __init__(self, attn=False, dim=400):
        super().__init__()
        if attn:
            raise NotImplementedError(
                "ResBlock(attn=True) is not supported: no attention block is implemented"
            )
        self.block1 = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
        )
        self.block2 = nn.Sequential(
            nn.Linear(dim, dim),
            nn.SiLU(),
        )
        # Placeholder so forward() has a uniform shape once attention exists.
        self.attn = nn.Identity()

    def forward(self, x):
        """Apply both stages, add the input back, then the (identity) attention."""
        h = self.block1(x)
        h = self.block2(h)
        h = h + x
        h = self.attn(h)
        return h
class HyperbolicDiffusion(nn.Module):
    """Noise-prediction diffusion model over tensors of shape (batch, seq, feat_dim).

    The denoiser is an MLP that receives the integer timestep concatenated as
    one extra leading feature to the noisy input and predicts the noise that
    was added. A linear beta schedule with ``T`` steps is precomputed and
    stored in buffers.

    Args:
        T: Number of diffusion steps.
        beta_1: First value of the linear beta schedule.
        beta_T: Last value of the linear beta schedule.
        feat_dim: Size of the last dimension of the data. Defaults to 20, the
            width that was previously hard-coded.

    NOTE(review): the raw timestep (0..T-1) is fed to the network unscaled;
    normalising or embedding it may train better, but that is left unchanged.
    """

    def __init__(self, T=1000, beta_1=1e-4, beta_T=0.02, feat_dim=20):
        super().__init__()

        # Input has one extra feature: the timestep concatenated in front of x_t.
        self.denoise_net = nn.Sequential(
            nn.Linear(feat_dim + 1, 400),
            ResBlock(),
            ResBlock(),
            ResBlock(),
            ResBlock(),
            nn.Linear(400, 400),
            nn.SiLU(),
            nn.Linear(400, 400),
            nn.SiLU(),
            nn.Linear(400, feat_dim),
        )
        self.T = T

        # Schedule is kept in float64; extract() casts to float32 on lookup.
        self.register_buffer(
            'betas', torch.linspace(beta_1, beta_T, T).double())

        alphas = 1. - self.betas
        self.register_buffer(
            'sqrt_alphas', torch.sqrt(alphas))
        alphas_bar = torch.cumprod(alphas, dim=0)
        self.register_buffer(
            'sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer(
            'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

    def forward(self, x):
        """Return the scalar training loss for a batch ``x`` (see compute_loss)."""
        return self.compute_loss(x)

    def _predict_noise(self, x_t, t):
        """Run the denoiser on ``x_t`` with timestep ``t`` (shape ``(batch,)``) prepended as a feature."""
        # (batch,) -> (batch, seq, 1), cast so the concatenation is in x_t's dtype.
        t_feat = t.to(x_t.dtype)[:, None, None].expand(-1, x_t.size(1), 1)
        return self.denoise_net(torch.cat([t_feat, x_t], dim=2))

    @torch.no_grad()  # sampling needs no gradients; avoids a graph spanning all T steps
    def sample(self, h, verbose=False):
        """Generate samples by running the reverse chain from Gaussian noise.

        Args:
            h: Tensor of shape (batch, seq, feat_dim). Only its shape, dtype
                and device are used; its values are ignored.
            verbose: If true, print the first sample after every step
                (the previous unconditional behaviour).

        Returns:
            Tensor shaped like ``h`` holding the final denoised sample.
        """
        z_t = torch.randn_like(h)
        for t in reversed(range(self.T)):
            time = torch.full((h.size(0),), t, dtype=torch.int64, device=h.device)
            pred_noise = self._predict_noise(z_t, time)
            sqrt_one_minus_alphas_bar = extract(self.sqrt_one_minus_alphas_bar, time, h.shape)
            sqrt_alphas = extract(self.sqrt_alphas, time, h.shape)
            betas = extract(self.betas, time, h.shape)
            # Posterior mean: (x_t - beta_t / sqrt(1 - alpha_bar_t) * eps) / sqrt(alpha_t)
            mean = (z_t - betas / sqrt_one_minus_alphas_bar * pred_noise) / sqrt_alphas
            if t > 0:
                # sigma_t = sqrt(beta_t): the noise std, not beta_t itself.
                z_t = mean + torch.sqrt(betas) * torch.randn_like(h)
            else:
                # No noise is injected on the final step.
                z_t = mean
            if verbose:
                print('t:', t, ' z_t:', z_t[0])
        return z_t

    def compute_loss(self, h):
        """Noise-prediction MSE at one uniformly random timestep per batch element.

        Args:
            h: Clean data of shape (batch, seq, feat_dim).

        Returns:
            Scalar tensor: mean squared error between predicted and true noise.
        """
        t = torch.randint(self.T, size=(h.shape[0],), device=h.device)
        noise = torch.randn_like(h)
        # Forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
        x_t = (
                extract(self.sqrt_alphas_bar, t, h.shape) * h +
                extract(self.sqrt_one_minus_alphas_bar, t, h.shape) * noise)

        pred_noise = self._predict_noise(x_t, t)
        loss = F.mse_loss(pred_noise, noise, reduction='mean')

        return loss

In [4]:
# Toy training run: fit the diffusion model to a constant data set.
SEED = 0              # fixed so the run is reproducible after a kernel restart
LEARNING_RATE = 0.0005
LR_STEP_SIZE = 2000   # scheduler decays the learning rate every this many steps
LR_GAMMA = 0.9        # multiplicative decay factor
NUM_STEPS = 10000
LOG_EVERY = 1000

torch.manual_seed(SEED)
model = HyperbolicDiffusion()

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
lr_scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=LR_STEP_SIZE,
    gamma=LR_GAMMA,
)
# Training data: 200 identical sequences of length 1 with 20 features, all equal to 3.
x = torch.ones((200, 1, 20), dtype=torch.float32) * 3
# Stale alternative input from an earlier experiment; it is 2-D (200, 20) and
# does not match the (batch, seq, feat) layout the current model expects.
# x = torch.tensor([[1.5794, 2.5078, 0.0000, 0.0000, 1.7158, 1.0340, 0.0000, 1.3385, 0.0000,
#         0.8585, 0.9584, 0.1076, 0.4893, 0.0000, 0.0000, 0.6972, 1.2082, 2.7626,
#         0.0000, 0.0000]],device='cuda').repeat(200,1)

for i in tqdm(range(NUM_STEPS)):
    optimizer.zero_grad()
    loss = model(x)
    if i % LOG_EVERY == 0:
        # tqdm.write keeps the progress bar intact; .item() logs a plain float
        # instead of a graph-attached tensor repr.
        tqdm.write(f"step {i}: loss={loss.item():.4f} lr={lr_scheduler.get_last_lr()}")
    loss.backward()
    optimizer.step()
    lr_scheduler.step()

  0%|          | 7/10000 [00:00<05:08, 32.40it/s]

tensor(42.3208, grad_fn=<MseLossBackward0>) [0.0005]


  0%|          | 17/10000 [00:00<05:34, 29.85it/s]


KeyboardInterrupt: 

In [None]:
# Run the reverse-diffusion loop. `x` is used only for its shape, dtype and
# device: the chain starts from fresh Gaussian noise drawn with randn_like.
model.sample(x)

In [None]:
# def extract(v, t, x_shape):
#     """
#     Extract some coefficients at specified timesteps, then reshape to
#     [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
#     """
#     out = torch.gather(v, dim=0, index=t).float()
#     return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
#
#
# class HyperbolicDiffusion(nn.Module):
#
#     def __init__(self, T=1000, beta_1=1e-4, beta_T=0.02):
#         super(HyperbolicDiffusion, self).__init__()
#
#         self.denoise_net = nn.Sequential(
#             nn.Linear(31, 300),
#             nn.ReLU(),
#             nn.Linear(300, 300),
#             nn.ReLU(),
#             nn.ReLU(),
#             nn.Linear(300, 300),
#             nn.ReLU(),
#             nn.Linear(300, 30),
#         )
#         self.T = T
#
#         self.register_buffer(
#             'betas', torch.linspace(beta_1, beta_T, T).double())
#
#         alphas = 1. - self.betas
#         self.register_buffer(
#             'sqrt_alphas', torch.sqrt(alphas))
#         alphas_bar = torch.cumprod(alphas, dim=0)
#         self.register_buffer(
#             'sqrt_alphas_bar', torch.sqrt(alphas_bar))
#         self.register_buffer(
#             'sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))
#
#     def forward(self):
#         loss = self.compute_loss(torch.ones((20, 30), dtype=torch.float32))
#
#         return loss
#
#     def sample(self, h):
#         z_t = torch.randn_like(h)
#         for t in reversed(range(self.T)):
#             Time = torch.ones((h.size(0),), dtype=torch.int64) * t
#             noise = torch.randn_like(h)
#             pred_noise = self.denoise_net(torch.concat([Time[..., None], z_t], dim=1))
#             sqrt_one_minus_alphas_bar = extract(self.sqrt_one_minus_alphas_bar, Time, h.shape)
#             sqrt_alphas = extract(self.sqrt_alphas, Time, h.shape)
#             betas = extract(self.betas, Time, h.shape)
#             z_t = z_t - betas / sqrt_one_minus_alphas_bar * pred_noise
#             z_t = z_t / sqrt_alphas + betas * noise
#             print('t:', t, ' z_t:', z_t[0])
#
#     def compute_loss(self, h):
#         t = torch.randint(self.T, size=(h.shape[0],), device=h.device)
#         noise = torch.randn_like(h)
#         x_t = (
#                 extract(self.sqrt_alphas_bar, t, h.shape) * h +
#                 extract(self.sqrt_one_minus_alphas_bar, t, h.shape) * noise)
#         t = t[..., None]
#         pred_noise = self.denoise_net(torch.concat([t, x_t], dim=1))
#         loss = F.mse_loss(pred_noise, noise, reduction='mean')
#
#         return loss
#
#
# model = HyperbolicDiffusion()
#
# optimizer = torch.optim.Adam(model.parameters(), 0.001)
# lr_scheduler = torch.optim.lr_scheduler.StepLR(
#     optimizer,
#     step_size=2000,
#     gamma=float(0.9)
# )
# for i in range(10000):
#     optimizer.zero_grad()
#     loss = model().sum()
#     print(loss, lr_scheduler.get_last_lr())
#     loss.backward()
#     optimizer.step()
#     lr_scheduler.step()