In [None]:
# Re-import edited modules (such as the local `src` package) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

# Conditional Latent Diffusion Training

In this notebook, we train a simple `LatentDiffusionConditional` model at low resolution (64 by 64).

Training takes about 20 hours to reach reasonable results.

---

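The cells below load paired images from `datasets/randomMIDI/PianoViolin11025/jpeg_amp_only`: the `mix` folders provide the condition images and the `ins3` folders the targets, with matching file names. To train on the Maps dataset from the pix2pix paper instead, download and extract it with: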
```bash
wget http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/maps.tar.gz
tar -xvf maps.tar.gz
```

In [42]:
!pip install pillow

Found existing installation: Pillow 9.4.0
Uninstalling Pillow-9.4.0:
  Would remove:
    /usr/local/lib/python3.9/site-packages/PIL/*
    /usr/local/lib/python3.9/site-packages/Pillow-9.4.0.dist-info/*
Proceed (Y/n)? ^C
[31mERROR: Operation cancelled by user[0m[31m
[0m

In [43]:
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as T
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset
import pytorch_lightning as pl

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import imageio
from skimage import io
import os

from src import *

# Default size (in inches) for every matplotlib figure drawn in this notebook.
mpl.rcParams.update({'figure.figsize': (8, 8)})

In [44]:
import kornia
from kornia.utils import image_to_tensor
import kornia.augmentation as KA

class SimpleImageDataset(Dataset):
    """Dataset of images stored in a folder, optionally paired with condition images.

    Image files are discovered in ``target_dir`` (extensions listed in
    ``SUPPORTED_FORMATS``, matched case-insensitively) and sorted by name so the
    index -> file mapping is deterministic across runs. When ``condition_dir``
    is given, the file with the same name is loaded from it as the condition.

    Args:
        target_dir: folder containing the target images.
        condition_dir: folder containing condition images whose file names match
            those in ``target_dir``. Required when ``return_pair`` is True.
        transforms: optional sequence of kornia augmentations; they are wrapped
            in a ``KA.container.AugmentationSequential``.
        paired: when True and ``condition_dir`` is given, the augmentations are
            applied jointly to target and condition (two ``'input'`` data keys);
            otherwise only the target image is augmented.
        return_pair: when True, ``__getitem__`` returns
            ``(condition_image, target_image)``; otherwise only ``target_image``.

    Raises:
        ValueError: if ``return_pair`` is True but ``condition_dir`` is None.
    """

    SUPPORTED_FORMATS = ('webp', 'jpg', 'jpeg', 'png')

    def __init__(self,
                 target_dir,
                 condition_dir=None,
                 transforms=None,
                 paired=True,
                 return_pair=False):
        if return_pair and condition_dir is None:
            raise ValueError('return_pair=True requires a condition_dir')

        self.target_dir = target_dir
        self.condition_dir = condition_dir
        self.transforms = transforms
        self.paired = paired
        self.return_pair = return_pair

        # Joint augmentation only makes sense when there is a condition image to pair with.
        self._joint_transforms = self.paired and self.condition_dir is not None

        # set up transforms
        if self.transforms is not None:
            data_keys = 2 * ['input'] if self._joint_transforms else ['input']
            self.input_T = KA.container.AugmentationSequential(
                *self.transforms,
                data_keys=data_keys,
                same_on_batch=False
            )

        # collect supported image files; sorted so indices are reproducible
        self.files = sorted(
            name for name in os.listdir(self.target_dir)
            if name.rsplit('.', 1)[-1].lower() in self.SUPPORTED_FORMATS
        )

    def __len__(self):
        return len(self.files)

    @staticmethod
    def _load_image(path):
        """Read an image file and convert it to a tensor scaled by 1/255 via kornia's image_to_tensor."""
        return image_to_tensor(io.imread(path)) / 255

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        file_name = self.files[idx]
        target_image = self._load_image(os.path.join(self.target_dir, file_name))
        condition_image = None
        if self.condition_dir is not None:
            condition_image = self._load_image(os.path.join(self.condition_dir, file_name))

        if self.transforms is not None:
            if self._joint_transforms:
                # AugmentationSequential returns one output per data key; [0] drops the
                # leading batch dimension added for the single (unbatched) sample.
                out = self.input_T(target_image, condition_image)
                target_image, condition_image = out[0][0], out[1][0]
            else:
                target_image = self.input_T(target_image)[0]

        if self.return_pair:
            return condition_image, target_image
        return target_image

In [46]:
!pip install pillow

[33mDEPRECATION: Configuring installation scheme with distutils config files is deprecated and will no longer work in the near future. If you are using a Homebrew or Linuxbrew Python, please see discussion at https://github.com/Homebrew/homebrew-core/issues/76621[0m[33m
[33mDEPRECATION: Configuring installation scheme with distutils config files is deprecated and will no longer work in the near future. If you are using a Homebrew or Linuxbrew Python, please see discussion at https://github.com/Homebrew/homebrew-core/issues/76621[0m[33m
[0m

In [47]:
import torchvision.transforms as T
import PIL

# Root of the paired dataset: each split has a `mix` folder (condition images)
# and an `ins3` folder (target images) with matching file names.
DATA_ROOT = 'datasets/randomMIDI/PianoViolin11025/jpeg_amp_only'

# Optional kornia augmentations for training, disabled by default. The list is
# named `train_transforms` (not `T`) so it does not shadow `torchvision.transforms as T`.
USE_AUGMENTATION = False
CROP_SIZE = 256
train_transforms = [
    KA.RandomCrop((2 * CROP_SIZE, 2 * CROP_SIZE)),
    KA.Resize((CROP_SIZE, CROP_SIZE), antialias=True),
    KA.RandomVerticalFlip(),
] if USE_AUGMENTATION else None

train_ds = SimpleImageDataset(target_dir=f'{DATA_ROOT}/train/ins3',
                              condition_dir=f'{DATA_ROOT}/train/mix',
                              transforms=train_transforms,
                              return_pair=True)

test_ds = SimpleImageDataset(target_dir=f'{DATA_ROOT}/test/ins3',
                             condition_dir=f'{DATA_ROOT}/test/mix',
                             return_pair=True)

# With return_pair=True every sample is (condition, target).
img1, img2 = train_ds[0]

plt.subplot(1, 2, 1)
plt.imshow(img1.permute(1, 2, 0))
plt.title('Condition (mix)')
plt.subplot(1, 2, 2)
plt.imshow(img2.permute(1, 2, 0))
plt.title('Target (ins3)');

8.2.0


AttributeError: module 'PIL.TiffTags' has no attribute 'LONG8'

### Model Training

In [None]:
# Training hyperparameters, kept as named constants so they are easy to find and tune.
LEARNING_RATE = 1e-4
BATCH_SIZE = 8

model = LatentDiffusionConditional(train_ds,
                                   lr=LEARNING_RATE,
                                   batch_size=BATCH_SIZE)

...but first, let's check whether the `AutoEncoder` used by the model (`model.ae`) can reconstruct our samples well.

**You should always test your autoencoder in this way when using latent diffusion models on a new dataset.**

In [None]:
# Round-trip one training target (`img2`, from `train_ds[0]`) through the autoencoder.
# This is a visual check only, so no gradients are needed.
with torch.no_grad():
    reconstruction = model.ae(img2.unsqueeze(0))[0].detach().cpu()

plt.subplot(1, 2, 1)
plt.imshow(img2.permute(1, 2, 0))
plt.title('Input')
plt.subplot(1, 2, 2)
plt.imshow(reconstruction.permute(1, 2, 0))
plt.title('AutoEncoder Reconstruction');

In [None]:
# `max_steps` is an integer step budget. `accelerator`/`devices` replace the
# `gpus=` argument, which Lightning deprecated in 1.7 and removed in 2.0.
# EMA comes from `src`; 0.9999 is presumably its decay rate — confirm in src.
trainer = pl.Trainer(
    max_steps=200_000,
    callbacks=[EMA(0.9999)],
    accelerator='gpu',
    devices=[0],
)

In [None]:
# No dataloader is passed here: the model was constructed with `train_ds`,
# presumably supplying its own training dataloader — confirm in src.
trainer.fit(model=model)

In [None]:
# Take the first test pair and repeat its condition 4 times along a new batch axis,
# so the model draws 4 samples for the same condition.
input, output = test_ds[0]
batch_input = torch.cat([input.unsqueeze(0)] * 4, dim=0)

out = model.cuda()(batch_input, verbose=True)

In [None]:
# One row: the condition, every generated sample, then the ground-truth target.
n_samples = len(out)
fig, axes = plt.subplots(1, n_samples + 2)

axes[0].imshow(input.permute(1, 2, 0))
axes[0].set_title('Input')
for ax, sample in zip(axes[1:-1], out):
    ax.imshow(sample.detach().cpu().permute(1, 2, 0))
axes[-1].imshow(output.permute(1, 2, 0))
axes[-1].set_title('Ground Truth')

for ax in axes:
    ax.axis('off')

By default, the model samples with its built-in `DDPM` sampler, as above.

However, you can just as well pass a `DDIM` sampler to reduce the number of inference steps:

In [None]:
STEPS = 200  # number of DDIM sampling steps

# Four copies of the first test condition form the batch.
input, output = test_ds[0]
batch_input = torch.stack([input] * 4)

ddim_sampler = DDIM_Sampler(STEPS, model.model.num_timesteps)
out = model.cuda()(batch_input, sampler=ddim_sampler, verbose=True)

In [None]:
# Panels left to right: condition, each DDIM sample (untitled), ground-truth target.
panels = ([(input, 'Input')]
          + [(sample.detach().cpu(), None) for sample in out]
          + [(output, 'Ground Truth')])

for position, (image, title) in enumerate(panels, start=1):
    plt.subplot(1, len(panels), position)
    plt.imshow(image.permute(1, 2, 0))
    if title is not None:
        plt.title(title)
    plt.axis('off')