### Spatially-Aware Just-in-Time Autoregressive Diffusion

Convert an image into patches

Apply noise to the patches jointly across the spatial and temporal domains

Predict the noise 

Lock in the cell in the center and save that as the image

133 seq_len

32 step diffusion

100   1        99

past  current  future

300 max seq_len for images

Sample a random position from the range -32 to 368

In [1]:
from PIL import Image
import numpy as np
import scipy.io
import gc

# Locations of one sample image and of the label file.
data_path = "../data/jpg/image_00001.jpg"
label_path = "../data/jpg/imagelabels.mat"

# Open the sample image and report the shape of its pixel array.
image = Image.open(data_path)
#image.show()
i = np.array(image)
print(i.shape)

# Read the 'labels' entry of the .mat file and keep its first row.
mat = scipy.io.loadmat(label_path)['labels'][0]
print(mat)

# Collapse the labels to the set of distinct values.
mat = set(mat)
print(mat)



(500, 591, 3)
[77 77 77 ... 62 62 62]
{np.uint8(1), np.uint8(2), np.uint8(3), np.uint8(4), np.uint8(5), np.uint8(6), np.uint8(7), np.uint8(8), np.uint8(9), np.uint8(10), np.uint8(11), np.uint8(12), np.uint8(13), np.uint8(14), np.uint8(15), np.uint8(16), np.uint8(17), np.uint8(18), np.uint8(19), np.uint8(20), np.uint8(21), np.uint8(22), np.uint8(23), np.uint8(24), np.uint8(25), np.uint8(26), np.uint8(27), np.uint8(28), np.uint8(29), np.uint8(30), np.uint8(31), np.uint8(32), np.uint8(33), np.uint8(34), np.uint8(35), np.uint8(36), np.uint8(37), np.uint8(38), np.uint8(39), np.uint8(40), np.uint8(41), np.uint8(42), np.uint8(43), np.uint8(44), np.uint8(45), np.uint8(46), np.uint8(47), np.uint8(48), np.uint8(49), np.uint8(50), np.uint8(51), np.uint8(52), np.uint8(53), np.uint8(54), np.uint8(55), np.uint8(56), np.uint8(57), np.uint8(58), np.uint8(59), np.uint8(60), np.uint8(61), np.uint8(62), np.uint8(63), np.uint8(64), np.uint8(65), np.uint8(66), np.uint8(67), np.uint8(68), np.uint8(69), np.u

In [2]:
import numpy as np
# np.resize repeats the input's elements as needed to fill the requested shape.
np.resize(np.zeros((1, 1)), new_shape=(3, 2))

array([[0., 0.],
       [0., 0.],
       [0., 0.]])

In [3]:
from einops import rearrange

# Base length of the per-image mask built in get_dataset; future_len zero
# slots are added on each side of it there.
MAX_LEN = 682

def img_norm(img):
    """Return ``img`` divided by 255 (maps 0-255 pixel data onto 0-1)."""
    full_scale = 255
    return img / full_scale

def img_crop(img, patch_size):
    """Trim ``img`` so its first two axes are whole multiples of ``patch_size``.

    Rows and columns are dropped from the bottom and right edges only; the
    third axis is returned untouched.
    """
    rows, cols = img.shape[0], img.shape[1]
    # Subtracting the remainder leaves the largest multiple of patch_size.
    keep_rows = rows - rows % patch_size
    keep_cols = cols - cols % patch_size
    return img[:keep_rows, :keep_cols, :]

def get_dataset(root, label_path, patch_size, future_len):
    """Load every labelled image, cut it into patches and build its mask.

    Images are expected at ``root + "image_NNNNN.jpg"`` with a 1-based,
    zero-padded five-digit index, one per entry of the label array.

    Mask layout per image (1-D integer array of length
    ``MAX_LEN + 2 * future_len``):

    * ``future_len`` leading zeros,
    * a single ``1`` (presumably a start marker -- confirm with the model code),
    * ``3, 4, ..., plen + 2``: one value per patch, i.e. patch index + 3,
    * a single ``2`` (presumably an end marker -- confirm with the model code),
    * zeros up to the full length.

    Args:
        root: Path prefix of the image files. The file name is appended
            directly, so it must end with a path separator.
        label_path: Path to the .mat file whose ``'labels'`` entry holds the
            labels; only its first row is used.
        patch_size: Side length in pixels of the square patches.
        future_len: Number of zero slots added before and after the
            ``MAX_LEN`` base mask.

    Returns:
        A list with one dict per image holding ``'patches'`` (one flattened
        patch per row), ``'label'`` and ``'mask'``.

    Raises:
        ValueError: If an image yields more patches than fit in the mask
            together with the two marker slots.
    """
    labels = scipy.io.loadmat(label_path)['labels'][0]
    mask_len = MAX_LEN + future_len
    dataset = []
    patch_counts = []
    for i, label in enumerate(labels):
        fp = f"{root}image_{i + 1:05d}.jpg"
        # The context manager releases the file handle as soon as the pixels
        # have been copied into an array.
        with Image.open(fp) as image:
            pixels = np.array(image)
        pixels = img_norm(img_crop(pixels, patch_size))

        patches = rearrange(pixels, '(h p1) (w p2) c ->  (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)
        plen = len(patches)
        if plen + 2 > mask_len:
            # Without this check the slice assignment below would silently
            # grow the list and produce masks of inconsistent length.
            raise ValueError(
                f"{fp} yields {plen} patches, which does not fit in a mask of "
                f"{mask_len} slots together with its two marker values"
            )
        patch_counts.append(plen)

        mask = [0] * mask_len
        mask[0] = 1
        mask[1:plen + 1] = range(3, plen + 3)
        # The marker goes in the slot AFTER the last patch so that no patch
        # index is overwritten.
        mask[plen + 1] = 2
        mask = np.pad(mask, (future_len, 0), mode="constant", constant_values=0)

        dataset.append(
            {
                'patches': patches,
                'label': label,
                'mask': mask
            }
        )

    # Summary of the last image and of the patch-count distribution; skipped
    # when there are no labels because none of these values exist then.
    if dataset:
        print(patches.shape)
        print(len(mask))
        print(mask)
        print(len(dataset))
        print(max(patch_counts))
        print(min(patch_counts))
        print(np.average(patch_counts))
    return dataset

In [4]:
from torch.utils.data import Dataset, DataLoader

class FlowerDataset(Dataset):
    """Map-style dataset over the per-image records built by get_dataset.

    Every item returns the stored patches, mask and label of one image plus
    a freshly drawn random integer position.
    """

    def __init__(self,
                 patch_size = 32,
                 future_len = 99,
                 label_path = "../data/jpg/imagelabels.mat", 
                 root = "../data/jpg/"):
        """Build all records eagerly and remember ``future_len``."""
        self.dataset = get_dataset(
            root=root,
            label_path=label_path,
            patch_size=patch_size,
            future_len=future_len,
        )
        self.future_len = future_len

    def __len__(self):
        """Number of records held by the dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return record ``idx`` together with a random position ``'pos'``.

        ``'pos'`` is drawn uniformly from
        ``[0, num_patches + 2 * future_len)`` on every call.
        """
        record = self.dataset[idx]
        patches = record['patches']
        pos_upper = patches.shape[0] + 2 * self.future_len
        return {
            'patches': patches,
            'mask': record['mask'],
            'label': record['label'],
            'pos': np.random.randint(0, pos_upper),
        }
    
# Patch sequences are not padded or resized to a common length, so samples
# are loaded one per batch for now.
trainset = FlowerDataset()
trainloader = DataLoader(dataset=trainset, batch_size=1, shuffle=True)

(300, 3072)
880
[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
   0   0   0   0   0   0   0   0   0   1   3   4   5   6   7   8   9  10
  11  12  13  14  15  16  17  18  19  20  21  22  23  24  25  26  27  28
  29  30  31  32  33  34  35  36  37  38  39  40  41  42  43  44  45  46
  47  48  49  50  51  52  53  54  55  56  57  58  59  60  61  62  63  64
  65  66  67  68  69  70  71  72  73  74  75  76  77  78  79  80  81  82
  83  84  85  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100
 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
 137 138 139 140 141 142 143 144 14

In [16]:
import torch
# One special-value embedding, shape (1, 1, 3).
a = torch.tensor([[[1, 2, 3]]])
# Ten all-zero stand-ins for image patches, shape (1, 10, 3).
b = torch.zeros(1, 10, 3)

# Join both along dim 1 so special values and patches share one table.
c = torch.cat((a, b), dim=1)

# Fetch rows 0, 2 and 3 of that table with a single index tensor.
c[0, torch.tensor([0, 2, 3])]

tensor([[1., 2., 3.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [9]:
import torch
from vit import ViT

# Hyperparameters for a ViT forward-pass smoke test.
vit_kwargs = dict(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
)
v = ViT(**vit_kwargs)

# A single random 3-channel 256x256 input.
img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

In [8]:
from diffusion import ImageAutoregressiveDiffusion
import torch

# Settings for the inner network, passed through as a plain dict.
inner_model_kwargs = dict(
    dim = 1024,
    depth = 12,
    heads = 12,
)

model = ImageAutoregressiveDiffusion(
    model = inner_model_kwargs,
    image_size = 64,
    patch_size = 8,
)

# Three random 3-channel 64x64 inputs.
images = torch.randn(3, 3, 64, 64)

# One training step: forward pass for the loss, then backpropagate.
loss = model(images)
loss.backward()

# Sample as many images as were fed in and check the shapes agree.
sampled = model.sample(batch_size = 3)

assert sampled.shape == images.shape

sampling time step: 100%|██████████| 32/32 [00:00<00:00, 532.84it/s]
sampling time step: 100%|██████████| 32/32 [00:00<00:00, 541.82it/s]
sampling time step: 100%|██████████| 32/32 [00:00<00:00, 534.10it/s]
sampling time step: 100%|██████████| 32/32 [00:00<00:00, 563.23it/s]
sampling time step: 100%|██████████| 32/32 [00:00<00:00, 533.07it/s]
sampling time step: 100%|██████████| 32/32 [00:00<00:00, 579.72it/s]
sampling time step: 100%|██████████| 32/32 [00:00<00:00, 573.33it/s]
sampling time step: 100%|██████████| 32/32 [00:00<00:00, 580.69it/s]
sampling time step: 100%|██████████| 32/32 [00:00<00:00, 539.44it/s]
sampling time step: 100%|██████████| 32/32 [00:00<00:00, 556.54it/s]
sampling time step: 100%|██████████| 32/32 [00:00<00:00, 538.63it/s]
sampling time step: 100%|██████████| 32/32 [00:00<00:00, 532.27it/s]
sampling time step: 100%|██████████| 32/32 [00:00<00:00, 558.25it/s]
sampling time step: 100%|██████████| 32/32 [00:00<00:00, 563.41it/s]
sampling time step: 100%|█████████