In [1]:
import os
import torch
import torchvision
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm

from model import UNet
from diffusion_discrete import DiscreteDiffusion, generate_betas

In [2]:
# Use the first CUDA device when PyTorch reports one; otherwise fall back to the CPU.
gpu = torch.cuda.is_available()
if gpu:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("device:", device)

device: cpu


In [3]:
# Seed PyTorch's global RNG before the model is constructed.
torch.manual_seed(50)

# training parameters
num_epochs = 10
batch_size = 128
lr = 2e-4

# Move the network onto `device` before building the optimizer: the training
# loop sends every batch to `device`, so leaving the model on the CPU would
# cause a device-mismatch error as soon as a GPU is selected.
model = UNet(image_channels=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

In [4]:
# MNIST images are 28x28; resize them to 32x32 so each side is a power of 2.
img_size = 32

# Preprocessing applied to every sample: resize first, then convert to a tensor.
preprocessing_steps = [transforms.Resize(img_size), transforms.ToTensor()]
transform = transforms.Compose(preprocessing_steps)

# Training split of MNIST, downloaded into ./data when it is not already there.
trainset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Reshuffled mini-batches of `batch_size` samples.
train_loader = DataLoader(dataset=trainset, batch_size=batch_size, shuffle=True)

In [5]:
# Beta schedule handed to the diffusion process below (1000 steps, 'cosine' type).
betas = generate_betas(
    type='cosine',
    start=0.02,
    stop=1.,
    num_steps=1000,
)

In [6]:
# Discrete diffusion process built on the `betas` schedule; one option per line
# so individual settings are easy to tweak.
diffusion = DiscreteDiffusion(
    betas=betas,
    transition_mat_type='uniform',
    transition_bands=None,
    num_bits=8,
    model_prediction='x_start',
    model_output='logistic_pars',
    loss_type='hybrid',
    hybrid_coeff=0.001,
)

  self.betas = betas = torch.tensor(betas, dtype=torch.float64)


In [7]:
# Train for `num_epochs` passes over `train_loader` and report the mean batch
# loss at the end of every epoch.
model.train()
for e in range(1, num_epochs + 1):
    train_loss = 0  # running sum of the batch losses in this epoch
    train_loss_vals = []  # individual batch losses in this epoch
    for batch_idx, (x, _) in tqdm(enumerate(train_loader), total=len(train_loader)):
        # ToTensor() yields floats in [0, 1]; casting those straight to int
        # truncates almost every pixel to 0. Rescale to integer levels first.
        # NOTE(review): assumes DiscreteDiffusion(num_bits=8) expects pixel
        # values in [0, 255] -- confirm against diffusion_discrete.
        x = (x * 255).round().int().to(device)
        optimizer.zero_grad()
        # NOTE(review): the same rng=25 is passed on every step; confirm that
        # training_losses does not treat it as a seed to be re-applied each call.
        loss = diffusion.training_losses(model, x_start=x, rng=25)
        loss.backward()
        optimizer.step()
        batch_loss = loss.item()
        train_loss += batch_loss
        # Record the loss of this batch, not the running total.
        train_loss_vals.append(batch_loss)
    # Average over the number of batches (batch_idx is the last zero-based
    # index, so dividing by it is off by one and fails for a single batch).
    num_batches = max(len(train_loader), 1)
    print("\tEpoch,", e, "complete!", "\tLoss: ", train_loss / num_batches)
        

  0%|          | 0/469 [00:32<?, ?it/s]


TypeError: linspace() received an invalid combination of arguments - got (device=torch.device, steps=int, stop=float, start=float, ), but expected one of:
 * (Tensor start, Tensor end, int steps, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (Number start, Tensor end, int steps, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (Tensor start, Number end, int steps, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (Number start, Number end, int steps, *, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
