In [1]:
%load_ext autoreload
%autoreload 1

In [2]:
%aimport zizi_pipeline
%aimport zizi_vae_pipeline

In [3]:
from diffusers import AutoencoderKL, DDPMScheduler, VQModel
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset, Dataset

import os

from PIL import Image

from accelerate import Accelerator
from tqdm.auto import tqdm
from pathlib import Path

%matplotlib inline
import matplotlib.pyplot as plt

from zizi_pipeline import get_dataloader, TrainingConfig

In [4]:
torch.manual_seed(15926)

<torch._C.Generator at 0x11f16df30>

In [5]:
config = TrainingConfig("data/pink-me/", "output/scratch-vae")

In [6]:
vae = AutoencoderKL(
    sample_size=config.image_size,
    block_out_channels=(128,256,512,512),
    down_block_types=(
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D",
        "DownEncoderBlock2D"
    ),
    up_block_types=(
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D",
        "UpDecoderBlock2D"
    ),
    latent_channels=4,
    layers_per_block=2
).to("mps")
vae.requires_grad_(True)

AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_blocks): ModuleList(
      (0): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0-1): 2 x ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (nonlinearity): SiLU()
          )
        )
        (downsamplers): ModuleList(
          (0): Downsample2D(
            (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
          )
        )
      )
      (1): DownEncoderBlock2D(
        (resnets): ModuleList(
          (0): ResnetBlock2D(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (c

In [7]:
dataloader = get_dataloader(config)

In [8]:
from torch.utils.data import Subset

In [9]:
batch_set = Subset(dataloader.dataset, range(0,8))

In [10]:
batch_loader = DataLoader(batch_set, batch_size=1)

In [11]:
for step, img in enumerate(batch_loader):
    print(img['images'].shape)
    predicted_img = vae(img['images'].to("mps"))
    print(f"prediction: {predicted_img}")

torch.Size([1, 3, 512, 512])
prediction: DecoderOutput(sample=tensor([[[[ 0.0399,  0.0371,  0.0517,  ...,  0.0811,  0.0757,  0.0525],
          [ 0.0208,  0.0102,  0.0298,  ...,  0.0829,  0.0656,  0.0437],
          [ 0.0070,  0.0108,  0.0389,  ...,  0.0790,  0.0637,  0.0221],
          ...,
          [ 0.0262,  0.0301,  0.0805,  ...,  0.0514,  0.0254,  0.0079],
          [ 0.0096,  0.0248,  0.0644,  ...,  0.0190, -0.0038, -0.0048],
          [-0.0447, -0.0321, -0.0210,  ..., -0.0733, -0.0904, -0.0604]],

         [[ 0.0313,  0.0277,  0.0303,  ...,  0.0290,  0.0308,  0.0248],
          [ 0.0446,  0.0490,  0.0453,  ...,  0.0242,  0.0291,  0.0185],
          [ 0.0487,  0.0460,  0.0277,  ...,  0.0131,  0.0308,  0.0213],
          ...,
          [ 0.0298,  0.0239,  0.0304,  ...,  0.0884,  0.0608,  0.0243],
          [ 0.0304,  0.0285,  0.0339,  ...,  0.0941,  0.0746,  0.0371],
          [ 0.0566,  0.0455,  0.0596,  ...,  0.0699,  0.0544,  0.0230]],

         [[ 0.0088,  0.0079,  0.0091,  .

KeyboardInterrupt: 