In [1]:
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel, Transformer2DModel, AutoencoderKL
from matplotlib import pyplot as plt

# Seed torch's RNGs (CPU and all CUDA devices) so DataLoader shuffling, noise and
# timestep sampling below are repeatable across runs.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

Using device: cuda


In [2]:
from pathlib import Path

In [None]:
# Bedroom image archive (fast.ai S3 mirror). Nothing in this notebook downloads it:
# extract it by hand into test_data/ so the images end up under `path` below
# (assumes the archive unpacks to a `bedroom/` folder).
BEDROOM_URL = 'https://s3.amazonaws.com/fast-ai-imageclas/bedroom.tgz'
BEDROOM_URL

In [3]:
path_data = Path('test_data')
path_data.mkdir(exist_ok=True)
path = path_data/'bedroom'

# The data is not downloaded automatically. Warn early instead of silently building
# an empty dataset further down.
if not path.is_dir():
    print(f'Warning: {path} not found; extract bedroom.tgz '
          f'(https://s3.amazonaws.com/fast-ai-imageclas/bedroom.tgz) into {path_data} first.')

In [4]:
from PIL import Image
import numpy as np
bs = 64  # NOTE(review): not referenced below; the DataLoaders hard-code batch_size=128.

def to_img(f):
    """Load an image file as an H x W x 3 float array scaled to [0, 1].

    The image is converted to RGB, so grayscale / RGBA / palette files still yield
    three channels (callers transpose with ``(2, 0, 1)``).

    The file is closed as soon as its pixels have been read. This matters because the
    dataset opens hundreds of thousands of files from many DataLoader workers.
    """
    with Image.open(f) as im:
        return np.array(im.convert('RGB')) / 255

In [5]:
from glob import glob
class ImagesDS:
    """Map-style dataset over the image files matching a glob pattern.

    Items are ``(image, 0)``:

    - ``image`` is a float32 CHW tensor, cropped from the top-left corner to at most
      256x256 and scaled to [-1, 1] (the pixel range the Stable Diffusion VAE expects).
    - The label is always 0 (unconditional data).

    Args:
        spec: glob pattern (str or Path); ``**`` matches recursively.
    """
    def __init__(self, spec):
        # Derived from `spec`. Previously this read the notebook-global `path`.
        self.path = Path(spec)
        # sorted() fixes the index -> file mapping (raw glob order depends on the
        # filesystem), so row i of a latent cache always refers to the same file.
        self.files = sorted(glob(str(spec), recursive=True))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        # to_img returns an H x W x C array in [0, 1] (assuming 8-bit images).
        # Move channels first and crop.
        # NOTE(review): images smaller than 256 px on a side yield smaller tensors,
        # which the default collate cannot stack. Assumes all images are >= 256x256.
        img = to_img(self.files[i]).transpose(2, 0, 1)[:, :256, :256]
        # Map [0, 1] -> [-1, 1] for the VAE encoder.
        x = torch.from_numpy(img).float() * 2 - 1
        return x, 0

In [6]:
# Recursively match every .jpg under `path` (plain string: the pattern has no placeholders).
dataset = ImagesDS(path / '**/*.jpg')

In [7]:
# Fail fast if the glob matched nothing (e.g. the archive hasn't been extracted into
# `path`). Otherwise later cells fail with much less obvious errors.
assert len(dataset) > 0, f'No .jpg files found under {path}'
len(dataset)

303125

In [8]:
# Ordered (shuffle=False) loader: batches come out in dataset.files order, so row i
# of anything written from it corresponds to dataset.files[i].
train_dataloader = DataLoader(
    dataset,
    batch_size=128,
    num_workers=32,
    shuffle=False,
)

In [9]:
# Frozen pretrained SD VAE, used here only to encode images into latents.
# `.to(device)` rather than `.cuda()` so the notebook also runs on CPU-only machines.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-ema").to(device).requires_grad_(False)

In [10]:
# Latent cache lives next to the images. It is derived from `path` so it follows the
# data directory instead of duplicating the literal.
mmpath = path/'data.npmm'

In [11]:
n_items = len(dataset)
n_items  # must equal the first dimension of mmshape (rows of the latent cache)

303125

In [12]:
# One 4x32x32 latent per image (256x256 crops -> 32x32, i.e. 8x spatial downsampling).
# Use the real dataset size rather than a hard-coded count so the cache shape can't drift.
mmshape = (len(dataset), 4, 32, 32)


In [13]:
from tqdm.notebook import tqdm

In [14]:
# Encode every image once with the VAE and cache the latent means on disk.
# Disabled by default (set CACHE_LATENTS = True to build it). The training loop below
# re-encodes on the fly and does not read this cache.
CACHE_LATENTS = False
if CACHE_LATENTS and not mmpath.exists():
    # Write to a temporary file and rename only once complete. Otherwise an interrupted
    # run would leave a partial file that the `exists()` check treats as finished.
    tmp_path = mmpath.with_name(mmpath.name + '.tmp')
    latents = np.memmap(tmp_path, np.float32, mode='w+', shape=mmshape)
    # Build an explicitly ordered loader here instead of relying on whichever
    # `train_dataloader` is currently defined: row i must map to dataset.files[i].
    cache_loader = DataLoader(dataset, batch_size=128, shuffle=False, num_workers=32)
    i = 0
    with torch.no_grad():
        for batch, _ in tqdm(cache_loader):
            n = len(batch)
            latents[i:i+n] = vae.encode(batch.to(device)).latent_dist.mean.cpu().numpy()
            i += n
    assert i == mmshape[0], f'wrote {i} rows, expected {mmshape[0]}'
    latents.flush()
    del latents  # release the memmap before renaming the file
    tmp_path.replace(mmpath)

In [15]:
from paintmind.stage1.diffusion_transfomers import DiT_S_8

In [16]:
# DiT-S/8 backbone from paintmind.
# learn_sigma=False: the training loop compares the output directly with the sampled
# noise via MSE, so presumably no learned-variance output is wanted.
# NOTE(review): assumes the DiT_S_8 defaults (input size / in_channels / num_classes)
# fit the 4x32x32 VAE latents and the constant label 0 used here. Confirm in paintmind.
model = DiT_S_8(learn_sigma=False)

In [17]:
# DDPM noise schedule: 1000 training timesteps with the 'squaredcos_cap_v2' beta schedule.
noise_scheduler = DDPMScheduler(
    beta_schedule='squaredcos_cap_v2',
    num_train_timesteps=1000,
)

In [18]:
from tqdm.notebook import tqdm

In [19]:
# Shuffled loader for training. The earlier loader already used batch_size=128;
# it is recreated here only to turn shuffling on.
train_dataloader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=32)

# How many runs through the data should we do?
n_epochs = 10

# Our network, explicitly in training mode
net = model.to(device)
net.train()

# Our loss function: MSE between predicted and true noise (epsilon prediction)
loss_fn = nn.MSELoss()

# The optimizer. OneCycleLR warms the LR up to max_lr, then anneals it over all steps.
opt = torch.optim.AdamW(net.parameters(), lr=3e-3, eps=1e-5)
lr_sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=3e-3, total_steps=n_epochs*len(train_dataloader))

# Multiply VAE latents by the VAE's scaling factor (diffusers' AutoencoderKL default is
# 0.18215) so they have roughly unit variance, like the noise the DDPM schedule mixes in.
# Sampling must divide by the same factor before `vae.decode`.
latent_scale = vae.config.scaling_factor
# torch.randint's upper bound is exclusive: use the scheduler's own step count so every
# timestep 0..num_train_timesteps-1 can be sampled (randint(0, 999) never drew 999).
n_train_timesteps = noise_scheduler.config.num_train_timesteps

# Keeping a record of the losses for later viewing
losses = []

# The training loop
for epoch in range(n_epochs):
    for x, y in tqdm(train_dataloader):

        # Encode images to latents with the frozen VAE.
        # NOTE: the SD VAE expects pixels in [-1, 1]; ensure the dataset provides that range.
        x = x.to(device)
        with torch.no_grad():
            latents = vae.encode(x).latent_dist.mean * latent_scale
        y = y.to(device)

        # Corrupt the latents at uniformly random timesteps in [0, n_train_timesteps)
        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, n_train_timesteps, (latents.shape[0],), device=device)
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        # Get the model prediction (labels y are passed through to the DiT)
        pred = net(noisy_latents, timesteps, y)

        # How close is the output to the noise?
        loss = loss_fn(pred, noise)

        # Backprop and update the params
        opt.zero_grad()
        loss.backward()
        opt.step()
        lr_sched.step()

        # Store the loss for later
        losses.append(loss.item())

    # Average of the most recent (up to 100) loss values to get an idea of progress
    recent = losses[-100:]
    avg_loss = sum(recent) / max(len(recent), 1)
    print(f'Finished epoch {epoch}. Average of the last 100 loss values: {avg_loss:05f}')

# View the loss curve
fig, ax = plt.subplots()
ax.plot(losses)
ax.set(title='DiT training loss', xlabel='step', ylabel='MSE loss')
plt.show()

  0%|          | 0/2369 [00:00<?, ?it/s]

In [None]:
# NOTE(review): this commented-out sampling sketch does not match the model trained
# above and will not work if simply uncommented:
#   - it starts from 3-channel noise (torch.randn(80, 3, 32, 32)), while the DiT is
#     trained on VAE latents whose cache shape is (N, 4, 32, 32);
#   - it builds labels 0..9, whereas ImagesDS always returns label 0;
#   - the loop produces latents, which need `vae.decode(...)` (undoing any latent
#     scaling applied during training) before they can be shown as images.
# x = torch.randn(80, 3, 32, 32).to(device)
# y = torch.tensor([[i]*8 for i in range(10)]).flatten().to(device)

# # Sampling loop
# for i, t in tqdm(enumerate(noise_scheduler.timesteps)):

#     # Get model pred
#     with torch.no_grad():
#         residual = net(x, torch.LongTensor([t]*80).cuda(), y)  # Again, note that we pass in our labels y

#     # Update sample with step
#     x = noise_scheduler.step(residual, t, x).prev_sample

# # Show the results
# fig, ax = plt.subplots(1, 1, figsize=(12, 12))
# ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=8)[0], cmap='Greys')
# plt.show()