In [6]:
import torch

from diffusion import get_named_beta_schedule, Diffusion
from unet import UNetModel
from trainer import TrainLoop
from datasets import load_data

In [10]:
# Scratch tensors for checking how a per-timestep coefficient broadcasts over an image batch.
x = torch.randn(16, 1, 32, 32)  # 16 single-channel 32x32 samples
eps = torch.randn_like(x)  # noise tensor with the same shape as x
t = torch.randint(low=0, high=1000, size=(x.size(0),))  # one index in [0, 1000) per batch item
alpha = torch.rand(1000)  # one random coefficient per index

In [12]:
# Look up each sample's coefficient, then add three trailing singleton axes so the
# resulting (batch, 1, 1, 1) tensor broadcasts across every pixel of x.
out = alpha[t][:, None, None, None] * x

In [3]:
# --- Device -------------------------------------------------------------------
# Prefer the second GPU ("cuda:1") as before, but fall back to "cuda:0" when only
# one GPU is visible, and to the CPU when CUDA is unavailable. Unconditionally
# selecting "cuda:1" breaks on single-GPU machines as soon as the model is moved.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# --- Optimisation hyperparameters (passed to TrainLoop) -------------------------
batch_size = 16
lr = 2e-4
weight_decay = 0.0
lr_anneal_steps = 10000
ema_rate = "0.9999"  # kept as a string and handed to TrainLoop unchanged

# --- UNet architecture (passed to UNetModel) ------------------------------------
image_size = 64
num_channels = 192
num_res_blocks = 3
channel_mult = (1, 2, 3, 4)
learn_sigma = False
class_cond = False
use_checkpoint = False
attention_resolutions = "32,16,8"
num_heads = 1
num_head_channels = 64
num_heads_upsample = -1
use_scale_shift_norm = True
dropout = 0.1
resblock_updown = True
use_fp16 = False
use_new_attention_order = True

# Convert each resolution in the comma-separated string into the integer ratio
# image_size // resolution; this tuple-to-be is what UNetModel receives.
attention_ds = [image_size // int(res) for res in attention_resolutions.split(",")]


In [4]:
# Request the schedule named "linear" over 1000 diffusion timesteps; the result is
# handed to Diffusion in the next cell.
betas = get_named_beta_schedule("linear", num_diffusion_timesteps=1000)

In [5]:
# Diffusion object configured solely from the beta schedule; it is later passed
# to TrainLoop together with the model.
diffusion = Diffusion(
    betas=betas
)

In [6]:
# Build the UNet from the configuration cell and move it to the selected device.
model = UNetModel(
    image_size=image_size,
    in_channels=1,  # single-channel images
    model_channels=num_channels,
    # With learn_sigma the network is expected to emit a second set of channels
    # (the variance term) next to the mean, i.e. twice the input channel count.
    # For 1-channel input that is 2; the former value 6 only fits 3-channel data.
    # NOTE(review): assumes Diffusion follows the usual 2x-channels convention
    # for learned sigma — confirm before enabling learn_sigma.
    out_channels=(1 if not learn_sigma else 2),
    num_res_blocks=num_res_blocks,
    attention_resolutions=tuple(attention_ds),
    dropout=dropout,
    channel_mult=channel_mult,
    num_classes=None,  # no class conditioning is passed to the model
    use_checkpoint=use_checkpoint,
    use_fp16=use_fp16,
    num_heads=num_heads,
    num_head_channels=num_head_channels,
    num_heads_upsample=num_heads_upsample,
    use_scale_shift_norm=use_scale_shift_norm,
    resblock_updown=resblock_updown,
    use_new_attention_order=use_new_attention_order,
)
# Last expression of the cell: moves the parameters to `device` and displays the model repr.
model.to(device)

UNetModel(
  (time_embed): Sequential(
    (0): Linear(in_features=192, out_features=768, bias=True)
    (1): SiLU()
    (2): Linear(in_features=768, out_features=768, bias=True)
  )
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(1, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1-3): 3 x TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 192, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (h_upd): Identity()
        (x_upd): Identity()
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=768, out_features=384, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32, 192, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Dropout(p=0.1, inplace=False)
          (3): Conv2d(192, 192, kernel_size=(3, 3)

In [8]:
# Create the training data source at the configured resolution and batch size.
# The repr displayed later in the notebook shows this is a generator object, so it
# is advanced with next() and does not support len().
data = load_data(
    image_size=image_size,
    batch_size=batch_size,
)

In [None]:
# Assemble the training loop from the model, diffusion object, data generator and
# optimisation hyperparameters defined above, then start it via run_loop().
# NOTE(review): this cell carries no recorded execution count (In [None]) in the
# saved notebook, so there is no evidence here that training was completed.
TrainLoop(
    model=model,
    diffusion=diffusion,
    data=data,
    batch_size=batch_size,
    lr=lr,
    ema_rate=ema_rate,
    weight_decay=weight_decay,
    lr_anneal_steps=lr_anneal_steps
).run_loop()


In [9]:
# Advance the data generator once; the item unpacks into the image batch and a
# second value that is discarded.
# NOTE(review): `_` is also IPython's "last output" variable, so binding it here
# shadows that convenience — consider a named throwaway instead.
batch, _ = next(data)

In [10]:
# Sanity-check the batch dimensions; expected to be
# (batch_size, 1, image_size, image_size) given the configuration above.
batch.shape

torch.Size([16, 1, 64, 64])

In [15]:
import matplotlib.pyplot as plt

In [17]:
# `data` is a generator, so it has no length. The bare `data.` left in this cell was
# an incomplete expression (a SyntaxError); demonstrate the limitation explicitly
# and show the error message instead of letting the cell fail.
try:
    len(data)
except TypeError as exc:
    print(f"TypeError: {exc}")

TypeError: object of type 'generator' has no len()

In [19]:
# Display the object's repr to see what kind of object load_data returned.
data

<generator object load_data at 0x7f7c0169a3c0>

In [22]:
import torchvision
import torchvision.transforms as transforms

In [28]:
# `np` was used below without ever being imported, which raises NameError on a
# fresh kernel; import it here, and import the `torch.utils.data` submodule that
# is actually used rather than just the `torch.utils` package.
import numpy as np
import torch.utils.data

# MNIST training split, downloaded to ../data if missing and preprocessed per item.
mnist_train = torchvision.datasets.MNIST(
    root='../data',
    train=True,
    download=True,
    transform=transforms.Compose([
        # Resize to the resolution the model was configured with.
        transforms.Resize((image_size, image_size)),
        # Convert the PIL image to a float tensor scaled to [0, 1].
        transforms.ToTensor(),
        # NOTE(review): horizontally flipping digits yields mirror-image glyphs —
        # confirm this augmentation is intended for MNIST.
        transforms.RandomHorizontalFlip(p=0.5),
        # (value - 0.5) / 0.5 maps [0, 1] onto [-1, 1].
        transforms.Normalize(mean=[0.50], std=[0.5])
    ])
)

# Keep a random 1000-image subset; indices are drawn without replacement from a
# permutation of the full index range (unseeded, so the subset differs per run).
subset_indices = np.random.permutation(len(mnist_train))[:1000].tolist()
dataset = torch.utils.data.Subset(mnist_train, subset_indices)

In [1]:
# Number of items in the MNIST subset.
# NOTE(review): the recorded NameError and the execution count In [1] suggest this
# cell was run in a fresh kernel before the cell that defines `dataset` — re-run
# the notebook top to bottom to refresh the output.
len(dataset)

NameError: name 'dataset' is not defined