In [1]:
import torch
from dalle_pytorch import DiscreteVAE, DALLE
from torchsummary import summary
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets
from dataloader import CustomImageDataLoader
import matplotlib.pyplot as plt

In [2]:
# Hyper-parameters of the discrete VAE, kept in one place so they are easy to tweak.
vae_config = dict(
    image_size=512,          # inputs are 512 x 512 images
    num_layers=3,            # number of downsamples: 512 / (2 ** 3) -> 64 x 64 feature map
    num_tokens=8192,         # number of visual tokens (8192 in the paper; can be smaller for downsized projects)
    codebook_dim=512,        # codebook dimension
    hidden_dim=64,           # hidden dimension
    num_resnet_blocks=2,     # number of resnet blocks
    temperature=0.9,         # gumbel softmax temperature; lower means harder discretization
    straight_through=False,  # straight-through for gumbel softmax; unclear which setting is better
)

# Build the model and move its parameters onto the GPU.
vae = DiscreteVAE(**vae_config)
vae = vae.to('cuda')

enc_chans_io, dec_chans_io: [(3, 64), (64, 64), (64, 64)] [(64, 64), (64, 64), (64, 64)]
dec_layers: 512 64
normalization: ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


In [3]:
# Smoke test: push one random 512 x 512 RGB image through the VAE.
# The input is created directly on the GPU instead of being copied over.
img = torch.randn(1, 3, 512, 512, device='cuda')

# Gradients are not needed for this check. Without no_grad() the result `f`
# would keep the whole autograd graph (and its activations) alive on the GPU
# for as long as the variable exists in the kernel.
with torch.no_grad():
    f = vae(img)

torch.Size([1, 8192, 64, 64])
temp: 0.9
sampled torch.Size([1, 512, 64, 64])


In [3]:
# Print the per-layer table of output shapes and parameter counts
# for a single 3 x 512 x 512 input.
summary(vae, input_size=(3, 512, 512))

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 64, 256, 256]           3,136
              ReLU-2         [-1, 64, 256, 256]               0
            Conv2d-3         [-1, 64, 128, 128]          65,600
              ReLU-4         [-1, 64, 128, 128]               0
            Conv2d-5           [-1, 64, 64, 64]          65,600
              ReLU-6           [-1, 64, 64, 64]               0
            Conv2d-7           [-1, 64, 64, 64]          36,928
              ReLU-8           [-1, 64, 64, 64]               0
            Conv2d-9           [-1, 64, 64, 64]          36,928
             ReLU-10           [-1, 64, 64, 64]               0
           Conv2d-11           [-1, 64, 64, 64]           4,160
         ResBlock-12           [-1, 64, 64, 64]               0
           Conv2d-13           [-1, 64, 64, 64]          36,928
             ReLU-14           [-1, 64,

In [4]:
# Folder that is handed to the dataset as `path_dataset`.
# This is an absolute, machine-specific location; adjust it when running elsewhere.
path_images = "F:/DATASETS/original/images"

In [9]:
# Wrap the image folder in the project's dataset class.
dataset = CustomImageDataLoader(path_dataset=path_images)

# Loader settings: batches of 4, served in dataset order from the main
# process, using the default collation and no pinned host memory.
loader_options = dict(
    batch_size=4,
    shuffle=False,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
)

dataloader = DataLoader(dataset, **loader_options)

In [6]:
def padd_image(image, source=(240, 360), objective=(512, 512)):
    """Zero-pad a batch of 3-channel images so the content is centred on a larger canvas.

    The input is first reshaped to ``(batch, 3, source[0], source[1])`` and is
    then surrounded with zeros along the last two dimensions until it measures
    ``(batch, 3, objective[0], objective[1])``. When the amount of padding
    along a dimension is odd, the extra row/column goes on the trailing side.

    Args:
        image: Tensor whose first dimension is the batch size and which holds
            exactly ``3 * source[0] * source[1]`` values per sample.
            NOTE(review): the reshape only reinterprets the existing element
            order, it does not permute axes -- if the dataset yields
            channel-last images a ``permute`` is needed first; confirm against
            the dataset.
        source: Sizes of dimensions 2 and 3 of the unpadded images.
        objective: Sizes of dimensions 2 and 3 after padding.

    Returns:
        The padded tensor of shape ``(batch, 3, objective[0], objective[1])``,
        located on the same device as ``image``.

    Raises:
        ValueError: If ``objective`` is smaller than ``source`` in either dimension.
        RuntimeError: If ``image`` does not hold the expected number of elements.
    """
    batch_size = image.shape[0]

    # reshape replaces the deprecated non-inplace Tensor.resize.
    image = image.reshape(batch_size, 3, source[0], source[1])

    pad_dim2 = objective[0] - source[0]
    pad_dim3 = objective[1] - source[1]
    if pad_dim2 < 0 or pad_dim3 < 0:
        raise ValueError(
            f"objective {tuple(objective)} must not be smaller than source {tuple(source)}"
        )

    def zero_block(size_dim2, size_dim3):
        # Create the padding on the image's device so torch.cat also works for GPU inputs.
        return torch.zeros(batch_size, 3, size_dim2, size_dim3, device=image.device)

    # Split each padding amount into a leading and a trailing part; the
    # trailing part takes the remainder so odd amounts still reach `objective`.
    lead_dim2, trail_dim2 = pad_dim2 // 2, pad_dim2 - pad_dim2 // 2
    lead_dim3, trail_dim3 = pad_dim3 // 2, pad_dim3 - pad_dim3 // 2

    # First grow dimension 2 to its final size ...
    padded = torch.cat(
        (zero_block(lead_dim2, source[1]), image, zero_block(trail_dim2, source[1])),
        dim=2,
    )
    # ... then grow dimension 3, using blocks that already span the full dimension 2.
    padded = torch.cat(
        (zero_block(objective[0], lead_dim3), padded, zero_block(objective[0], trail_dim3)),
        dim=3,
    )

    return padded

In [7]:
%%time
print('test')

# %time

test
CPU times: total: 0 ns
Wall time: 0 ns


In [11]:
%%time

i = 0

for batch in dataloader:
    padded = (padd_image(batch['image'].resize(4, 3, 360, 240))).to('cuda')

    loss = vae(padded, return_loss=True)
    print(loss)
    loss.backward()

    i += 1

    if i == 100:
        break


RuntimeError: CUDA out of memory. Tried to allocate 512.00 MiB (GPU 0; 8.00 GiB total capacity; 6.70 GiB already allocated; 0 bytes free; 6.76 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF