In [1]:
import os
import logging
import numpy as np

from scipy.spatial.transform import Rotation as R
from torchvision.transforms import ToTensor
from torchvision import transforms
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import train_test_split

from PIL import Image
from pathlib import Path
from tqdm import tqdm
import json

import sys
import os

In [2]:
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [3]:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/diffusion')

In [4]:
sys.path.append('../')
from data.renders_data import load_metadatas, load_render, BackgroundColor, NormTorchToPil, RenderDataset, makeDataLoader, RandomHue, LatentDataset
from diffusion.ddim import evaluate, train_loop

In [5]:
def get_preprocessor(size):
    return transforms.Compose([
        transforms.Resize(size),
        BackgroundColor((0.5, 0.5, 0.5)),
        transforms.Normalize([0.5], [0.5])
    ])

In [6]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    image_size = 32  # the generated image resolution
    train_batch_size = 16
    eval_batch_size = 16  # how many images to sample during evaluation
    num_epochs = 20
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500
    save_image_epochs = 1
    save_image_batches = 50
    save_model_epochs = 1
    mixed_precision = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
    output_dir = 'ddim-cars-angle_30-res_128-2'  # the model namy locally and on the HF Hub

    push_to_hub = False  # whether to upload the saved model to the HF Hub
    hub_private_repo = False  
    overwrite_output_dir = True  # overwrite the old model when re-running the notebook
    seed = 0

config = TrainingConfig()

In [7]:
renders_path = Path('/mnt/ML/novel_view_synthesis/render/renders_20_views_all_classes_256')
latents_path = Path('/mnt/ML/novel_view_synthesis/diffusion/latents')
data_transforms = get_preprocessor((config.image_size, config.image_size))
data_augmentations = None #RandomHue()
data = load_metadatas(renders_path)

In [8]:
for metadatas in data:
    for metadata in metadatas:
        metadata['latent_path'] = str(latents_path / metadata['asset_id'] / f"latent_{metadata['id']:05d}.pt")

In [9]:
renders_path = Path('/mnt/ML/Datasets/shapenet renders/renders_old')
cat_id = '02958343'
cat_path = renders_path / cat_id
asset_ids = [str(d.relative_to(renders_path)) for d in cat_path.iterdir()]
data = [[{ 'path': renders_path / asset_id, 'rgba_path': 'color_angle30_res256x256.png' }] for asset_id in asset_ids]

In [10]:
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained('vae-cars-360_res_64/checkpoints/checkpoint_0004.ckp')
vae.enable_slicing()
vae.requires_grad_(False)
vae.cuda()
_ = vae.eval()

In [11]:
def save_latents(vae, metadatas_list, batch_size=16):
    metadatas = [metadata for metadatas in metadatas_list for metadata in metadatas]
    dataloader, size = makeDataLoader(metadatas, config, batch_size, RenderDataset, data_transforms, shuffle=False)
    
    for step, images in tqdm(enumerate(dataloader), total=len(dataloader)):
        images = images.to(vae.device)
        with torch.no_grad():
            posterior = vae.encode(images).latent_dist
            latents = posterior.mode()
        for i, latent in enumerate(latents):
            metadata = metadatas[step*batch_size+i]
            Path(metadata['latent_path']).parent.mkdir(parents=True, exist_ok=True)
            torch.save(latent, metadata['latent_path'])
            
def compute_std(metadatas_list, batch_size=16):
    metadatas = [metadata for metadatas in metadatas_list for metadata in metadatas]
    dataloader, size = makeDataLoader(metadatas, config, batch_size, LatentDataset, shuffle=True)
    latents = next(iter(dataloader))
    std = latents.std(unbiased=False)
    return std
    
#save_latents(vae, data, batch_size=16)

In [12]:
# std = compute_std(data, batch_size=64)
# latent = torch.load(data[7][0]['latent_path'])
# print(latent.shape)
# dec = vae.decode(latent.unsqueeze(0)).sample
# print(latent.mean())
# display(NormTorchToPil(latent[:3] / std / 2))
# display(NormTorchToPil(dec[0]))

In [13]:
valid_data, train_data = data[:20], data[20:]
train_data = [metadata for metadatas in train_data for metadata in metadatas]
valid_data = [metadata for metadatas in valid_data for metadata in metadatas]
# train_loader, train_size = makeDataLoader(train_data, config, config.train_batch_size, LatentDataset, None)
# valid_loader, val_size = makeDataLoader(valid_data, config, config.train_batch_size, LatentDataset, None, shuffle=False)
train_loader, train_size = makeDataLoader(train_data, config, config.train_batch_size, RenderDataset, data_transforms)
valid_loader, val_size = makeDataLoader(valid_data, config, config.train_batch_size, RenderDataset, data_transforms, shuffle=False)

In [14]:
from diffusers import UNet2DModel

def load_model(path):
    model = UNet2DModel.from_pretrained(path).cuda()
    return model

model = UNet2DModel(
    sample_size=config.image_size,  # the target image resolution
    in_channels=3,  # the number of input channels, 3 for RGB images
    out_channels=3,  # the number of output channels
    layers_per_block=2,  # how many ResNet layers to use per UNet block
    block_out_channels=(128, 128, 256, ),  # the number of output channes for each UNet block
    attention_head_dim=8,
    down_block_types=( 
        "DownBlock2D", 
        "AttnDownBlock2D",
        "AttnDownBlock2D",
    ), 
    up_block_types=(
        "AttnUpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D", 
      ),
)

#model = load_model('ddim-cars-angle_60_res_128/unet')

model = model.cuda()

In [15]:
from diffusers import DDIMScheduler

noise_scheduler = DDIMScheduler(num_train_timesteps=1000, clip_sample=False, set_alpha_to_one=False, steps_offset=1)

In [16]:
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)

In [17]:
from diffusers.optimization import get_cosine_schedule_with_warmup

lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=(train_size * config.num_epochs),
)

In [18]:
from accelerate import notebook_launcher
from ldm import train_loop

args = (config, model, noise_scheduler, optimizer, train_loader, lr_scheduler)#, vae, std)

notebook_launcher(train_loop, args, num_processes=0)

Launching training on one GPU.


Training epoch 0:: 100%|████████████████████████████████████████████████████████████████████████████████████████| 217/217 [01:15<00:00,  2.86it/s]
Training epoch 1:: 100%|████████████████████████████████████████████████████████████████████████████████████████| 217/217 [01:12<00:00,  2.99it/s]
Training epoch 2:: 100%|████████████████████████████████████████████████████████████████████████████████████████| 217/217 [01:12<00:00,  3.01it/s]
Training epoch 3:: 100%|████████████████████████████████████████████████████████████████████████████████████████| 217/217 [01:13<00:00,  2.97it/s]
Training epoch 4:: 100%|████████████████████████████████████████████████████████████████████████████████████████| 217/217 [01:12<00:00,  2.98it/s]
Training epoch 5:: 100%|████████████████████████████████████████████████████████████████████████████████████████| 217/217 [01:13<00:00,  2.96it/s]
Training epoch 6:: 100%|████████████████████████████████████████████████████████████████████████████████████████| 217/

In [22]:
from ldm import evaluate
noise_scheduler = DDIMScheduler(num_train_timesteps=1000)
evaluate(config, model, noise_scheduler, 100, 0, num_inference_steps=50)

In [20]:
from ldm import process_image, repaint, sample_images

image = process_image(Path(data[0]['path'] / data[0]['rgba_path']), data_transforms)
r = repaint(model, noise_scheduler, image, 1.0, 1000)
NormTorchToPil(r)

TypeError: list indices must be integers or slices, not str

In [None]:
image = valid_loader.dataset[0].cuda()
image = image.unsqueeze(0)
with torch.no_grad():
    z = vae.encode(image).latent_dist.mode()
    rec = vae(image).sample
display(NormTorchToPil(image[0][:1].repeat(3, 1, 1)))
display(NormTorchToPil(-0.2*z[0][3:4].repeat(3, 1, 1)))

In [None]:
from ldm import sample_images

images = sample_images(config, model, noise_scheduler, vae=vae, num_inference_steps=100)
images = [NormTorchToPil(img) for img in images]
for image in images:
    display(image)