In [1]:
import torch
from torch import nn

from matplotlib import pyplot as plt

from tqdm.notebook import tqdm

from datasets import get_dataset
from models import NCSN
from losses import NCSNLoss, GaussianPerturbation
from train import Trainer
from samplers import AnnealedLangevinDynamics

In [2]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(device)

mps


In [3]:
train_loader, test_loader = get_dataset("mnist", batch_size=128, num_workers=6, pin_memory=False)
num_cls = 10
shape = (28, 28, 1)

In [4]:
sigma0 = 0.01
r = 2
L = 12

sigmas = torch.as_tensor([sigma0 * r ** i for i in range(L)], device=device)
print(sigmas)

tensor([1.0000e-02, 2.0000e-02, 4.0000e-02, 8.0000e-02, 1.6000e-01, 3.2000e-01,
        6.4000e-01, 1.2800e+00, 2.5600e+00, 5.1200e+00, 1.0240e+01, 2.0480e+01],
       device='mps:0')


In [5]:
model = NCSN(shape, num_cls=num_cls, filters=16)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

In [6]:
perturbation = GaussianPerturbation()
loss_fn = NCSNLoss(perturbation, sigmas, coeff_func=lambda s: s**2, K=1)

In [7]:
trainer = Trainer(train_loader, model, device)

trainer.train(loss_fn, optimizer, verbose=True)

  0%|          | 0/468 [00:00<?, ?batch/s]

KeyboardInterrupt: 

In [None]:
sampler = AnnealedLangevinDynamics(model, sigmas, device)

In [None]:
img = sampler.sample(shape)
plt.imshow(img, cmap="gray")
plt.show()