In [1]:
import os
import numpy as np
import pandas as pd
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from eg3d_dataset import EG3DDataset
from diffuser_utils.evaluate import vision_evaluate
from torchvision.models import convnext_base, convnext_small

from accelerate import Accelerator
from diffusers.models.vae import Encoder
from diffusers import UNet1DModel
from diffusers import DPMSolverMultistepScheduler

from eg3d_loss import EG3DLoss
from eg3d import EG3D

In [3]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Hyper-parameters and paths for the training run.

    Every attribute carries a type annotation so that ``@dataclass`` actually
    registers it as a field. Without annotations the decorator sees no fields,
    which makes ``TrainingConfig(num_epochs=1)`` raise ``TypeError`` and leaves
    ``repr(config)`` empty. Defaults are unchanged, so ``TrainingConfig()``
    behaves as before.
    """

    rgb: bool = True
    image_size: int = 512  # the generated image resolution
    train_batch_size: int = 8
    eval_batch_size: int = 8  # how many images to sample during evaluation
    num_dataloader_workers: int = 12  # how many subprocesses to use for data loading
    num_epochs: int = 200
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    scheduler_train_timesteps: int = 30
    eval_inference_steps: int = 30
    save_image_epochs: int = 10
    save_model_epochs: int = 10
    mixed_precision: str = 'fp16'  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = 'eg3d-latent-diffusion'

    eg3d_model_path: str = 'eg3d/eg3d_model/ffhqrebalanced512-128.pkl'
    eg3d_latent_vector_size: int = 512

    data_dir: str = 'data_color/'
    df_file: str = 'dataset.df'

    overwrite_output_dir: bool = True
    seed: int = 0

# Single shared configuration instance used by the cells below.
config = TrainingConfig()

In [None]:
# Image preprocessing: resize, then convert the PIL image to a float tensor.
# NOTE(review): with an int size, Resize only fixes the shorter edge, so this
# assumes the source images are square - TODO confirm against the dataset.
preprocess = transforms.Compose(
    [
        transforms.Resize(config.image_size),
        transforms.ToTensor(),
    ]
)

dataset = EG3DDataset(df_file=config.df_file, data_dir=config.data_dir, transform=preprocess, encode=False)

# 95% / 5% train/eval split. The split is seeded from config.seed so the same
# samples land in the eval subset after every kernel restart; an unseeded split
# would silently move eval samples into training when a run is resumed.
train_size = int(len(dataset) * 0.95)
eval_size = len(dataset) - train_size
train_dataset, eval_dataset = torch.utils.data.random_split(
    dataset,
    [train_size, eval_size],
    generator=torch.Generator().manual_seed(config.seed),
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
    num_workers=config.num_dataloader_workers,
)
# Shuffling is kept on the eval loader to preserve the existing behaviour.
eval_dataloader = torch.utils.data.DataLoader(
    eval_dataset,
    batch_size=config.eval_batch_size,
    shuffle=True,
    num_workers=config.num_dataloader_workers,
)

In [4]:
# Channel count of the input images: 3 for RGB, 1 otherwise.
in_channels = 3 if config.rgb else 1

# NOTE(review): the encoder is built with a fixed single input channel while
# the UNet below receives `in_channels`; with config.rgb = True the two do not
# obviously line up - confirm which component is meant to consume the images.
encoder = Encoder(in_channels=1, out_channels=1)

unetModel = UNet1DModel(
    # length of the 1-D sample the UNet works on (the EG3D latent vector size)
    sample_size=config.eg3d_latent_vector_size,
    in_channels=in_channels,
    out_channels=1,
    # how many ResNet layers to use per UNet block
    layers_per_block=2,
    # the number of output channels for each UNet block
    block_out_channels=(128, 128, 256, 256, 512, 512),
    # four plain downsampling blocks, one with self-attention, one plain
    down_block_types=("DownBlock1D",) * 4 + ("AttnDownBlock1D", "DownBlock1D"),
    # one plain upsampling block, one with self-attention, then four plain ones
    up_block_types=("UpBlock1D", "AttnUpBlock1D") + ("UpBlock1D",) * 4,
)

In [None]:
# Scheduler handed to train_loop, which uses it to add noise to the samples.
# The number of training timesteps comes from the shared config.
noise_scheduler = DPMSolverMultistepScheduler(
    num_train_timesteps=config.scheduler_train_timesteps,
)

In [5]:
# Smoke test: push one random single-channel 512x512 image through the encoder.
test_inp = torch.rand((1, 1, 512, 512))
# No gradients are needed for a shape check, so skip building the autograd graph.
with torch.no_grad():
    outputs = encoder(test_inp)
# Show only the output shape; printing the whole tensor floods the cell output.
print(outputs.shape)

: 

: 

In [None]:
# Chain the image encoder and the 1-D UNet into a single module.
# NOTE(review): train_loop calls model(noisy_images, timesteps, return_dict=False),
# but nn.Sequential forwards a single input from one stage to the next - confirm
# whether a wrapper module that routes `timesteps` to the UNet is intended here.
model = nn.Sequential(encoder, unetModel)

In [None]:
def train_loop(config, model, noise_scheduler, optimizer, loss_function, train_dataloader, lr_scheduler):
    """Train a noise-prediction model with Accelerate and log to TensorBoard.

    Each batch is noised at a random timestep, the model predicts the noise,
    and the MSE between prediction and true noise is minimised. After each
    epoch the main process builds a pipeline and, at the configured
    intervals, runs evaluation and saves the pipeline.

    Args:
        config: Settings object. Reads ``mixed_precision``,
            ``gradient_accumulation_steps``, ``output_dir``, ``num_epochs``,
            ``save_image_epochs`` and ``save_model_epochs``.
        model: Called as ``model(noisy_images, timesteps, return_dict=False)``;
            element 0 of the result is taken as the predicted noise.
        noise_scheduler: Must provide ``add_noise`` and
            ``config.num_train_timesteps``.
        optimizer: Optimizer over the model parameters.
        loss_function: Currently unused - the loss is always
            ``F.mse_loss``. Kept so existing callers keep working.
        train_dataloader: Yields dict batches with an ``'images'`` entry.
        lr_scheduler: Learning-rate scheduler, stepped alongside the optimizer.
    """
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        logging_dir=os.path.join(config.output_dir, "logs")
    )
    if accelerator.is_main_process:
        accelerator.init_trackers("train_eg3d_diffuser")

    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler
    )

    global_step = 0
    # Now you train the model
    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for step, batch in enumerate(train_dataloader):
            clean_images = batch['images']
            # Sample the noise directly on the batch's device instead of
            # creating it on the CPU and copying it over.
            noise = torch.randn(clean_images.shape, device=clean_images.device)
            bs = clean_images.shape[0]
            # One random timestep per sample. The value is read through
            # `.config`, the supported access path for scheduler settings.
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device
            ).long()
            noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

            with accelerator.accumulate(model):
                # Predict the noise residual
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                # Clip only on steps where gradients are synchronised, i.e.
                # when the accumulated gradient is complete and about to be applied.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # A new bar is created every epoch, so release the finished one.
        progress_bar.close()

        # After each epoch you optionally sample some demo images with evaluate() and save the model
        # NOTE(review): EG3DPipeline and evaluate are not defined or imported in
        # the visible part of this notebook - confirm they exist before running.
        if accelerator.is_main_process:
            pipeline = EG3DPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                evaluate(config, epoch, pipeline)

            if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
                pipeline.save_pretrained(config.output_dir)