In [1]:
import torch

In [2]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to the local project modules
# imported below are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
from time_embedding import get_timestep_embedding

### Timestep Embedding

In [4]:
# 100 integer (int64) timesteps, all equal to zero.
t = torch.zeros(100, dtype=torch.long)

In [5]:
# Embed the 100 timesteps into 64 dimensions; the result has one row per
# timestep, i.e. shape (100, 64).
emb = get_timestep_embedding(t, 64)
emb.shape

torch.Size([100, 64])

In [6]:
# For t = 0 every displayed row is identical: the leading columns are 0 and
# the trailing columns are 1. That is what a [sin | cos] layout would give
# (sin 0 = 0, cos 0 = 1) — presumably how time_embedding builds it; confirm there.
emb

tensor([[0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        ...,
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.],
        [0., 0., 0.,  ..., 1., 1., 1.]])

### Downsampling

In [7]:
from down_sampling import Downsample

In [8]:
# Downsampling layer under test. The argument 3 is presumably the channel
# count — it matches dim 1 of the test tensor used in the next cell.
down_sampler = Downsample(3)

In [9]:
# Invoke the layer by calling it, matching `downsample(img)` in the ResNet
# section, instead of calling `.forward(...)` directly. For a torch.nn.Module
# the call path also runs any registered hooks, which `.forward` would skip.
reduced_sample = down_sampler(torch.ones((1, 3, 10, 10)))
reduced_sample.shape  # spatial dims are halved (10x10 -> 5x5); channels unchanged

torch.Size([1, 3, 5, 5])

### Upsampling

In [10]:
from up_sampling import Upsample

In [11]:
# Upsampling layer under test. The argument 3 is presumably the channel
# count — it matches dim 1 of the test tensor below.
up_sampler = Upsample(3)
# Invoke the layer by calling it, matching `upsample(h)` in the ResNet
# section, instead of calling `.forward(...)` directly. For a torch.nn.Module
# the call path also runs any registered hooks, which `.forward` would skip.
increased_sample = up_sampler(torch.ones((1, 3, 10, 10)))
increased_sample.shape  # spatial dims are doubled (10x10 -> 20x20); channels unchanged

torch.Size([1, 3, 20, 20])

## ResNet

In [19]:
from resnet import Nin
from resnet import ResNetBlock

In [20]:
# Shared fixtures for the ResNet-section smoke tests.
# Random batch of 10 feature maps with 64 channels and 16x16 spatial size.
img = torch.randn(10, 64, 16, 16)

# Ten integer timesteps in 0..9 (uniform floats in [0, 10), truncated),
# embedded into 512 dimensions.
t = torch.rand(10).mul(10).long()
temb = get_timestep_embedding(t, 512)

# Layer that maps 64 channels to 128.
nin = Nin(in_dim=64, out_dim=128)

# Resampling layers built for the 64-channel input.
upsample = Upsample(64)
downsample = Downsample(64)

In [21]:
# Down/up round trip: per the printed shapes, Downsample halves the spatial
# dims (16 -> 8) and Upsample restores them (8 -> 16); channels stay at 64.
print(img.shape)
h = downsample(img)
print(h.shape)

# `img` is rebound to the upsampled result, so later cells operate on this
# tensor rather than on the original random input.
img = upsample(h)
print(img.shape)

torch.Size([10, 64, 16, 16])
torch.Size([10, 64, 8, 8])
torch.Size([10, 64, 16, 16])


In [22]:
# Call the layer itself rather than `.forward(...)`, the same way the other
# layers in this section are invoked; for a torch.nn.Module the call path also
# runs any registered hooks, which a direct `.forward` call would skip.
# NOTE(review): this cell rebinds `img` (64 -> 128 channels), so it is not
# idempotent — re-running it alone would pass the 128-channel result back
# into a layer built with in_dim=64.
img = nin(img)
print(img.shape)

torch.Size([10, 128, 16, 16])


In [23]:
# Two ResNet blocks applied in sequence, each also given the timestep
# embedding `temb`. Judging by the printed shapes, the first two constructor
# arguments are the input and output channel counts; the third (0.1) is
# presumably a dropout rate — TODO confirm against resnet.ResNetBlock.
# NOTE(review): `img` is rebound here (128 -> 64 channels), so re-running this
# cell alone would feed a 64-channel tensor into a block built for 128.
resn = ResNetBlock(128, 128, 0.1)
img = resn(img, temb)
print(img.shape)

# Second block reduces the channel count from 128 to 64.
resn = ResNetBlock(128, 64, 0.1)
img = resn(img, temb)
print(img.shape)

torch.Size([10, 128, 16, 16])
torch.Size([10, 64, 16, 16])


In [28]:
from resnet import AttentionBlock

In [29]:
# Attention block applied to the 64-channel feature map from the previous
# section; the printed shape shows the output keeps the input's shape.
att = AttentionBlock(64)
img = att(img)
print(img.shape)

torch.Size([10, 64, 16, 16])


# U-Net

In [30]:
from unet import UNet

In [31]:
# End-to-end smoke test: push a random batch of 10 single-channel 32x32
# inputs through a UNet built with its default arguments.
# NOTE(review): `t` is not defined in this cell — it is the batch of 10
# timesteps created in the ResNet section, so that cell must have run first.
img = torch.randn((10, 1, 32, 32))
model = UNet()
img = model(img, t)  # `img` now holds the model output, not the input

In [32]:
# The model output has the same shape as the input batch: (10, 1, 32, 32).
img.shape

torch.Size([10, 1, 32, 32])

In [33]:
# Total element count across every parameter tensor registered on the model.
nbr_params = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {nbr_params}")

Number of parameters: 35713281
