### Spatially-Aware Just-in-Time Autoregressive Diffusion

Convert an image into patches

Apply noise to the patches jointly across the spatial and temporal domains.

Predict the noise 

Lock in the center cell and save it as part of the output image.

150-step diffusion, with the window laid out as:

| past | current | future |
|------|---------|--------|
| 150  | 1       | 150    |

Technically this performs two-step attention: the first pass attends only to the past, producing a conditioning sequence.

The conditioning sequence is then fed into a second pass that attends over the future positions to generate the diffusion results.

This speeds up autoregressive diffusion inference from $O(nm)$ to $O(n+m)$.

# TODO

Implement EMA (exponential moving average) of the model weights
Double check parts deviating from Lucid's implementation
- difformer line 293 to 310
Implement position embedding / absolute position embedding
- Sinusoidal embeddings need to be updated to work with the multi-timestep logic



In [1]:
from PIL import Image
import numpy as np
import scipy.io
import gc
from tqdm import tqdm
import numpy as np

# Runtime target for the model and tensors.
device = "cuda"

# Patch geometry and sampling-window configuration.
PATCH_SIZE = 8
SAMPLE_STEPS = 256
SAMPLE_SIZE = 256
WINDOW_SIZE = SAMPLE_STEPS  # the attention window is tied to the number of sample steps

# Dataset locations, relative to the notebook's working directory.
label_path = "../data/jpg/imagelabels.mat"
data_path = "../data/jpg/image_00001.jpg"


In [2]:
# import os, sys

# path = "../data/jpg/"
# dirs = os.listdir( path )

# def resize():
#     for item in dirs:
#         if ".jpg" in item and "resized" not in item:
#             if os.path.isfile(path+item):
#                 im = Image.open(path+item)
#                 f, e = os.path.splitext(path+item)
#                 imResize = im.resize((200,200), Image.LANCZOS)
#                 imResize.save(f + '_resized.jpg', 'JPEG', quality=90)

# resize()

In [3]:
from torch.utils.data import Dataset, DataLoader
from einops import rearrange

def img_norm(img):
    """Divide pixel values by 255, mapping the range 0..255 onto 0..1."""
    max_intensity = 255
    normalized = img / max_intensity
    return normalized

def img_crop(img, patch_size):
    """Crop an image so both spatial dimensions are multiples of ``patch_size``.

    The first two axes are treated as height and width; any trailing axes
    (e.g. a channel axis) are kept untouched, so both ``(H, W)`` and
    ``(H, W, C)`` arrays are accepted.

    Args:
        img: Array with at least two dimensions (indexable like a NumPy array).
        patch_size: Side length of the square patches.

    Returns:
        ``(cropped, plen)`` where ``cropped`` is the top-left crop of ``img``
        and ``plen`` is the number of non-overlapping
        ``patch_size x patch_size`` patches that tile it.
    """
    rows = img.shape[0] // patch_size
    cols = img.shape[1] // patch_size
    # Slicing only the first two axes works for 2-D (grayscale) arrays too;
    # for 3-D arrays it is identical to ``img[:h, :w, :]``.
    return img[: rows * patch_size, : cols * patch_size], rows * cols

def get_dataset(root, label_path, patch_size, sample_steps):
    """Load every labelled image and build one training record per image.

    Args:
        root: Directory prefix holding ``image_XXXXX_resized.jpg`` files. It is
            concatenated directly, so it must end with a path separator.
        label_path: MATLAB ``.mat`` file whose ``'labels'`` entry holds one
            label per image (image ``k`` uses label index ``k - 1``).
        patch_size: Patch side length forwarded to ``img_crop``.
        sample_steps: Number of zeros left-padded onto each mask; the unpadded
            mask also carries ``sample_steps`` trailing slots.

    Returns:
        A list of dicts with keys ``'img'`` (cropped array), ``'label'`` (raw
        label from the .mat file), ``'mask'`` (1-D ``np.ndarray``) and
        ``'plen'`` (number of patches in the cropped image).
    """
    labels = scipy.io.loadmat(label_path)['labels'][0]
    dataset = []
    patch_counts = []
    for i, label in enumerate(tqdm(labels)):
        fp = root + "image_" + f"{i + 1:05d}" + "_resized.jpg"
        # The context manager closes the file handle once the pixels have
        # been copied into the NumPy array (the old code never closed it).
        with Image.open(fp) as pil_image:
            image, plen = img_crop(np.array(pil_image), patch_size)
        patch_counts.append(plen)

        # Mask layout before padding: index 0 = 1, indices 1..plen = patch ids
        # 3..plen+2, then index plen is set to 2, with zeros after that.
        mask = [0] * (plen + sample_steps)
        mask[1:plen + 1] = list(range(3, plen + 3))
        mask[0] = 1
        # NOTE(review): this overwrites the last patch id (index plen) rather
        # than writing after it (index plen + 1) -- confirm this is intended.
        mask[plen] = 2
        mask = np.pad(mask, (sample_steps, 0), mode="constant", constant_values=0)

        dataset.append(
            {
                'img': image,
                'label': label,
                'mask': mask,
                'plen': plen,
            }
        )

    # Report the largest patch count seen; skipped when there were no labels
    # (``max`` of an empty list would raise).
    if patch_counts:
        print(max(patch_counts))
    return dataset

# Oxford flowers dataset 
class FlowerDataset(Dataset):
    """Oxford flowers dataset: returns cropped images, masks, 0-based labels and a random offset.

    Records are built eagerly by ``get_dataset`` at construction time.
    """

    def __init__(self, patch_size = 8, sample_steps = 99,
                 label_path = "../data/jpg/imagelabels.mat",
                 root = "../data/jpg/"):
        self.sample_steps = sample_steps
        self.dataset = get_dataset(root, label_path, patch_size, sample_steps)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        # Offset is drawn uniformly from [sample_steps + 1, sample_steps + plen - 1).
        low = self.sample_steps + 1
        high = self.sample_steps + record['plen'] - 1
        offset = np.random.randint(low, high)
        return dict(
            img=record['img'],
            mask=record['mask'],
            label=record['label'] - 1,  # stored labels are shifted down by one
            offset=offset,
        )
    
trainset = FlowerDataset(patch_size=PATCH_SIZE, sample_steps=SAMPLE_STEPS)
# shuffle=True reshuffles every epoch; DataLoader defaults to shuffle=False,
# which would replay the images in the same fixed file order each epoch.
trainloader = DataLoader(trainset, batch_size=48, shuffle=True)

100%|██████████| 8189/8189 [00:03<00:00, 2440.24it/s]

625





In [4]:
from difformer import ArSpImageDiffusion
import torch

# Build the diffusion wrapper; the inner model config only overrides `dim`.
model = ArSpImageDiffusion(
    model=dict(dim=1024),
    num_classes=102,
    patch_size=PATCH_SIZE,
    sample_steps=SAMPLE_STEPS,
    sample_size=SAMPLE_SIZE,
    window_size=WINDOW_SIZE,
)
# Left as the cell's final expression so the notebook echoes the module repr.
model.to(device)


ArSpImageDiffusion(
  (model): ArSpDiffusion(
    (label_embedding): Embedding(102, 192)
    (proj_in): Linear(in_features=192, out_features=1024, bias=True)
    (transformer): Decoder(
      (layers): ModuleList(
        (0): ModuleList(
          (0): ModuleList(
            (0): LayerNorm(
              (ln): LayerNorm((1024,), eps=1e-05, elementwise_affine=False)
            )
            (1-2): 2 x None
          )
          (1): Attention(
            (to_q): Linear(in_features=1024, out_features=512, bias=False)
            (to_k): Linear(in_features=1024, out_features=512, bias=False)
            (to_v): Linear(in_features=1024, out_features=512, bias=False)
            (attend): Attend(
              (attn_dropout): Dropout(p=0.0, inplace=False)
            )
            (to_out): Linear(in_features=512, out_features=1024, bias=False)
          )
          (2): Residual()
        )
        (1): ModuleList(
          (0): ModuleList(
            (0): LayerNorm(
              (l

In [5]:
from einops import rearrange, repeat, reduce, pack, unpack

_diffusion = model.model.diffusion
sigmas = _diffusion.sample_schedule()
spatial = _diffusion.sample_spatial(80)

# Index the schedule with `spatial`, then add a leading batch axis of size 1
# and a trailing singleton axis.
sigma = repeat(sigmas[spatial], "d -> b d 1", b=1)
print(sigma.shape)


# gammas = torch.where(
#     (sigmas >= self.S_tmin) & (sigmas <= self.S_tmax),
#     min(self.S_churn / sample_steps, sqrt(2) - 1),
#     0.
# )


# sigma = repeat(sigmas[spatial], "d -> b d 1", b = shape[0])
# gamma = repeat(gammas[spatial], "d -> b d 1", b = shape[0])
# sigma_next = sigma - 1

torch.Size([1, 256, 1])


In [6]:
def train(model, dataloader, optimizer):
    """Run one training epoch and return the mean per-batch loss.

    Each batch must be a dict with ``'img'``, ``'mask'``, ``'label'`` and
    ``'offset'`` tensors (the keys produced by ``FlowerDataset``). All but
    ``'offset'`` are moved to the module-level ``device``.

    Args:
        model: Module called as ``model(img, mask, offset, label)`` that
            returns a scalar loss tensor.
        dataloader: Iterable of batch dicts.
        optimizer: Optimizer over ``model``'s parameters.

    Returns:
        Mean of ``loss.item()`` over all batches, or ``float('nan')`` when the
        dataloader yields no batches (instead of dividing by zero).
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for batch in tqdm(dataloader):
        img = batch['img'].to(device).float()
        mask = batch['mask'].to(device).int()
        label = batch['label'].to(device).int()
        # NOTE(review): offset is deliberately left on its original device;
        # presumably the model only uses it for indexing -- confirm.
        offset = batch['offset'].int()

        optimizer.zero_grad()
        loss = model(img, mask, offset, label)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float("nan")
    return running_loss / num_batches

In [7]:
def inference(model, num_images, num_classes=102):
    """Sample ``num_images`` images for every class label and save them as JPEGs.

    Images are written to ``./results/<label>_<index>.jpg``; the directory is
    created if it does not exist.

    Args:
        model: Diffusion wrapper exposing ``sample(batch_size=..., label=...)``.
        num_images: Number of images to sample per class.
        num_classes: Labels ``0 .. num_classes - 1`` are sampled. Defaults to
            102, matching the ``num_classes`` the model is constructed with.
    """
    import os

    model.eval()
    os.makedirs("./results", exist_ok=True)
    # Sampling needs no autograd graph; this avoids holding activations.
    with torch.no_grad():
        # range(num_classes) covers every label; a hard-coded range(101)
        # would silently skip the last class (label 101).
        for class_idx in range(num_classes):
            label = torch.tensor(class_idx, device=device)
            for j in range(num_images):
                sampled = model.sample(batch_size=1, label=label)
                # NOTE(review): Image.fromarray(..., 'RGB') expects an
                # (H, W, 3) uint8 array -- confirm model.sample returns that.
                img = Image.fromarray(sampled.squeeze().cpu().numpy(), 'RGB')
                img.save("./results/" + str(class_idx) + "_" + str(j) + ".jpg")


In [8]:
import torch.optim as optim

epochs = 500
optimizer = optim.Adam(model.parameters(), lr=0.001)

for e in range(epochs):
    loss = train(model, trainloader, optimizer)
    print(f"{e}  avg loss:{loss:.3f}")

    # Every 10th epoch, excluding epoch 0, write sample images to disk.
    if e > 0 and e % 10 == 0:
        inference(model, 1)

 19%|█▉        | 33/171 [00:25<01:47,  1.28it/s]