In [185]:
import torch
from dalle_pytorch import DiscreteVAE, DALLE
from torchsummary import summary
import torch
import torch.nn as nn

In [169]:
# Discrete VAE used as the image tokenizer.
vae = DiscreteVAE(
    image_size=256,
    num_layers=3,            # downsampling stages, e.g. 256 / (2 ** 3) -> 32 x 32 feature map
    num_tokens=8192,         # visual vocabulary size; the paper used 8192, smaller projects may shrink it
    codebook_dim=512,        # width of each codebook entry
    hidden_dim=64,           # channel width of the hidden layers
    num_resnet_blocks=2,     # how many resnet blocks to stack
    temperature=0.9,         # gumbel-softmax temperature; lower values give a harder discretization
    straight_through=False,  # straight-through gumbel softmax; unclear which setting works better
)
# Place the model's parameters on the GPU.
vae = vae.to('cuda')

In [145]:
class ImageCompressor(nn.Module):
    """
    Convolutional image compressor meant to retain local information for the GAN.

    Architecture (all layers live in ``self.layers``, an ``nn.Sequential``):
      1. three stride-2 convolutions, each followed by ReLU (8x spatial reduction),
      2. ``num_blocks`` size-preserving 3x3 convolutions, each followed by ReLU,
      3. one more stride-2 convolution to 128 channels,
      4. a 1x1 convolution projecting to ``num_tokens`` channels.

    Input:  tensor of shape (N, in_channels, H, W)
    Output: tensor of shape (N, num_tokens, H', W') -- a spatial map, not a flat
            vector. With the default arguments a (3, 360, 240) image yields
            (4096, 22, 15).

    Args:
        num_blocks: number of 3x3 conv + ReLU pairs in the middle stack (>= 0).
        num_tokens: number of output channels produced by the final 1x1 conv.
        hidden_dim: channel width used by every layer before the last downsample.
        in_channels: number of channels of the input image.

    Raises:
        ValueError: if ``num_blocks`` is negative.
    """

    def __init__(self, num_blocks=5, num_tokens=4096, hidden_dim=64, in_channels=3):
        super().__init__()

        if num_blocks < 0:
            raise ValueError(f"num_blocks must be >= 0, got {num_blocks}")

        # Downsampling stem: each convolution halves the spatial resolution.
        layers = [
            nn.Conv2d(in_channels=in_channels, out_channels=hidden_dim, kernel_size=4, stride=2, padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=4, stride=2, padding=(1, 1)),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=2, stride=2),
            nn.ReLU(),
        ]

        # Resolution-preserving 3x3 convolutions.
        for _ in range(num_blocks):
            layers.append(nn.Conv2d(hidden_dim, hidden_dim, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
            layers.append(nn.ReLU())

        # Final downsample followed by a 1x1 projection to the token dimension.
        # NOTE(review): there is no activation between these two convolutions, so
        # together they form a single linear map -- confirm this is intended.
        # Inserting a layer here would shift the Sequential indices and therefore
        # the state-dict keys, so it is deliberately left unchanged.
        layers.append(nn.Conv2d(in_channels=hidden_dim, out_channels=128, kernel_size=4, stride=2, padding=(1, 1)))
        layers.append(nn.Conv2d(128, num_tokens, kernel_size=1))

        self.layers = nn.Sequential(*layers)

    def forward(self, image):
        """Run ``image`` of shape (N, in_channels, H, W) through the conv stack."""
        return self.layers(image)

# Build the compressor, move it to the GPU and print a per-layer summary
# for a 3-channel 360 x 240 input.
c = ImageCompressor()
c = c.to('cuda')
summary(c, input_size=(3, 360, 240))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 180, 120]           3,136
              ReLU-2         [-1, 64, 180, 120]               0
            Conv2d-3           [-1, 64, 90, 60]          65,600
              ReLU-4           [-1, 64, 90, 60]               0
            Conv2d-5           [-1, 64, 45, 30]          16,448
              ReLU-6           [-1, 64, 45, 30]               0
            Conv2d-7           [-1, 64, 45, 30]          36,928
              ReLU-8           [-1, 64, 45, 30]               0
            Conv2d-9           [-1, 64, 45, 30]          36,928
             ReLU-10           [-1, 64, 45, 30]               0
           Conv2d-11           [-1, 64, 45, 30]          36,928
             ReLU-12           [-1, 64, 45, 30]               0
           Conv2d-13           [-1, 64, 45, 30]          36,928
             ReLU-14           [-1, 64,

In [170]:
# Per-layer summary of the VAE encoder alone for a 3 x 256 x 256 input.
summary(vae.encoder, input_size=(3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           3,136
              ReLU-2         [-1, 64, 128, 128]               0
            Conv2d-3           [-1, 64, 64, 64]          65,600
              ReLU-4           [-1, 64, 64, 64]               0
            Conv2d-5           [-1, 64, 32, 32]          65,600
              ReLU-6           [-1, 64, 32, 32]               0
            Conv2d-7           [-1, 64, 32, 32]          36,928
              ReLU-8           [-1, 64, 32, 32]               0
            Conv2d-9           [-1, 64, 32, 32]          36,928
             ReLU-10           [-1, 64, 32, 32]               0
           Conv2d-11           [-1, 64, 32, 32]           4,160
         ResBlock-12           [-1, 64, 32, 32]               0
           Conv2d-13           [-1, 64, 32, 32]          36,928
             ReLU-14           [-1, 64,

In [171]:
# Per-layer summary of the full VAE for a 3 x 256 x 256 input.
summary(vae, input_size=(3, 256, 256))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 128, 128]           3,136
              ReLU-2         [-1, 64, 128, 128]               0
            Conv2d-3           [-1, 64, 64, 64]          65,600
              ReLU-4           [-1, 64, 64, 64]               0
            Conv2d-5           [-1, 64, 32, 32]          65,600
              ReLU-6           [-1, 64, 32, 32]               0
            Conv2d-7           [-1, 64, 32, 32]          36,928
              ReLU-8           [-1, 64, 32, 32]               0
            Conv2d-9           [-1, 64, 32, 32]          36,928
             ReLU-10           [-1, 64, 32, 32]               0
           Conv2d-11           [-1, 64, 32, 32]           4,160
         ResBlock-12           [-1, 64, 32, 32]               0
           Conv2d-13           [-1, 64, 32, 32]          36,928
             ReLU-14           [-1, 64,

In [184]:
# Smoke test: push one random image batch through the VAE and inspect the output shape.
images = torch.randn(1, 3, 256, 256).to('cuda')
# Call the module itself instead of .forward() so nn.Module's hook machinery runs.
# NOTE: despite its name, `encoded` has the same shape as the input
# (1, 3, 256, 256), i.e. it is an image-shaped tensor, not a compact code.
encoded = vae(images)
encoded.shape

torch.Size([1, 3, 256, 256])

In [186]:
# DALL-E transformer built on top of the VAE defined earlier in the notebook.
dalle = DALLE(
    dim=1024,
    vae=vae,                 # image sequence length and image token count are inferred from it
    num_text_tokens=10000,   # text vocabulary size
    text_seq_len=256,        # length of the text sequence
    depth=12,                # should aim to be 64
    heads=16,                # number of attention heads
    dim_head=64,             # dimension of each attention head
    attn_dropout=0.1,        # dropout applied in attention
    ff_dropout=0.1,          # dropout applied in the feedforward layers
)