In [1]:
import torch
from score_models import TimeConditionalScoreNet
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
def forward_pdf_std(t, sigma):
    # forward SDE: dx_t = \sigma^t*dw, 0<=t<=1
    # forward pdf: p(x_t|x_0) = N(x_t| x_0, (\sigma^{2t}-1) / (2*\log(\sigma)) * I )
    # t: (B, )
    # sigma: (,)

    return torch.sqrt((sigma**(2*t)-1.0) / (2.0*torch.log(sigma)))  # (B, )

def diffusion_coeff(t, sigma):
    # t: (B, )
    # sigma: (,)

    # forward SDE: dx_t = \sigma^t dw, 0<=t<=1  
    return sigma**t  # (B, )

In [3]:
n_epochs = 200
batch_size = 128
lr = 1e-4
sigma = 25.0
sigma = torch.tensor(sigma, dtype=torch.float, device=device)

dataset = datasets.MNIST("./data", train=True, transform=transforms.ToTensor(), download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [4]:
score_model = TimeConditionalScoreNet(1, forward_pdf_std, sigma)
score_model = score_model.to(device)

optimizer = torch.optim.Adam(score_model.parameters(), lr=lr)

In [5]:
for epoch in range(n_epochs):
    avg_loss = 0.0
    num_items = 0.0

    for x, _ in dataloader:
        x = x.to(device)
        loss = score_model.compute_loss(x)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    avg_loss += loss.item() * x.shape[0]
    num_items += x.shape[0]
    print('Epoch: {}, Loss: {:5f}'.format(epoch+1, avg_loss / num_items))
    torch.save(score_model.state_dict(), 'ckpt/ckpt.pt')


Epoch: 1, Loss: 314.374939
Epoch: 2, Loss: 158.344040
Epoch: 3, Loss: 109.345642
Epoch: 4, Loss: 82.262329
Epoch: 5, Loss: 71.648376
Epoch: 6, Loss: 59.525208
Epoch: 7, Loss: 47.975246
Epoch: 8, Loss: 41.609844
Epoch: 9, Loss: 34.429382
Epoch: 10, Loss: 43.747734
Epoch: 11, Loss: 28.584015
Epoch: 12, Loss: 33.076504
Epoch: 13, Loss: 30.180569
Epoch: 14, Loss: 22.704376
Epoch: 15, Loss: 34.027290
Epoch: 16, Loss: 27.341469
Epoch: 17, Loss: 34.070740
Epoch: 18, Loss: 31.403898
Epoch: 19, Loss: 33.557079
Epoch: 20, Loss: 30.518291
Epoch: 21, Loss: 17.882357
Epoch: 22, Loss: 23.403126
Epoch: 23, Loss: 21.626297
Epoch: 24, Loss: 22.596424
Epoch: 25, Loss: 22.478292
Epoch: 26, Loss: 22.190765
Epoch: 27, Loss: 23.334789
Epoch: 28, Loss: 22.203123
Epoch: 29, Loss: 21.526402
Epoch: 30, Loss: 21.350180
Epoch: 31, Loss: 25.736891
Epoch: 32, Loss: 25.540096
Epoch: 33, Loss: 19.078289
Epoch: 34, Loss: 19.630146
Epoch: 35, Loss: 20.734842
Epoch: 36, Loss: 17.619423
Epoch: 37, Loss: 18.307384
Epoch: 