## Training


In [2]:
!pip install -r requirements.txt


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m23.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


In [1]:
import PIL
import torch
import os
from torchvision import transforms
from diffusers import AutoencoderKL

from torch.utils.data import Dataset
from torchvision.transforms import functional
from PIL import Image
import pandas as pd
from tqdm import tqdm
import math
import glob

In [2]:
# Project-local model wrappers and the custom inference pipeline.
from src.model import UNet, CLIPVision
from src.pipeline import MyPipline
# Model identifier passed to every pretrained-component loader in the cells below.
pretrained_name = 'lambdalabs/sd-image-variations-diffusers'


In [3]:
# Load the `vae` subfolder of the checkpoint at revision v2.0.
# The training loop uses this model to encode images into latents.
vae = AutoencoderKL.from_pretrained(
    pretrained_name,
    subfolder="vae",
    revision="v2.0",
)

In [4]:
# Build the project's UNet wrapper from the `unet` subfolder at revision v2.0.
# This is the network that gets fine-tuned below.
unet2 = UNet(
    path=pretrained_name,
    subfolder="unet",
    revision="v2.0",
)

In [5]:
# Load the image encoder from the `image_encoder` subfolder at revision v2.0.
# Its `image_embeds` output conditions the UNet during training.
clip = CLIPVision.from_pretrained(
    pretrained_name,
    subfolder="image_encoder",
    revision="v2.0",
)

In [6]:
# Build the project's pipeline from the UNet and the image encoder.
# NOTE(review): `sd_pipe` is not referenced again in the visible notebook; later
# cells construct their own `MyPipline` instances — confirm whether this cell is needed.
sd_pipe = MyPipline(unet2, clip)

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


In [7]:
from utils import prepare_inference_im
# Conditioning images used for sampling; only the first element is used later on.
inp_images = prepare_inference_im('./test_im')
# Sanity check of the first prepared tensor's shape.
print(inp_images[0].shape)

torch.Size([1, 3, 224, 224])


In [None]:
# Fine-tune only the parameters whose name contains 'attentions'; freeze the rest.
param_need_grads = []
for name, param in unet2.named_parameters():
    trainable = 'attentions' in name
    param.requires_grad = trainable
    if trainable:
        param_need_grads.append(name)

# NOTE(review): the division by 2 presumably counts weight/bias pairs — confirm.
print(len(param_need_grads) / 2)

416.0

In [15]:
# NOTE(review): the module is spelled `cofig` — confirm this matches the file name
# under src/ (it may be a typo of `config`).
from src.cofig import TrainingConfig
# Hyper-parameters read by the dataset, optimizer, scheduler and training loop below.
config = TrainingConfig()

In [16]:
# Pose data handed to the dataset below.
# NOTE(review): read_pickle executes arbitrary code embedded in the file — only
# load pickles from a trusted source.
relevant_poses = pd.read_pickle('./relevant_poses.pkl')

In [17]:
from src.dataset import FFHQDataset

# NOTE(review): '/content/ffhq' is an absolute, machine-specific path — consider
# moving it into the config.
dataset_ffhq = FFHQDataset(
    '/content/ffhq',
    relevant_poses,
    size=config.image_size,
)
# Shuffled training batches of `config.train_batch_size` samples.
train_dataloader_ffhq = torch.utils.data.DataLoader(
    dataset_ffhq,
    batch_size=config.train_batch_size,
    shuffle=True,
)

In [18]:
from diffusers import DDPMScheduler
import torch.nn.functional as F

# Scheduler with 1000 training timesteps; the training loop uses it to add noise to latents.
noise_scheduler = DDPMScheduler(num_train_timesteps=1000)


In [19]:
# AdamW over the UNet's parameters with the configured learning rate.
# NOTE(review): every parameter is handed to the optimizer, including those with
# requires_grad=False — consider passing only the trainable ones.
optimizer = torch.optim.AdamW(
    unet2.parameters(),
    lr=config.learning_rate,
)

In [20]:
from diffusers.optimization import get_cosine_schedule_with_warmup

# Cosine learning-rate schedule with warm-up, spanning one scheduler step per
# batch over all configured epochs.
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=len(train_dataloader_ffhq) * config.num_epochs,
)

In [21]:
def make_grid(images, rows, cols):
    """Tile ``images`` row by row into a single RGB image of ``rows`` x ``cols`` cells.

    Every cell has the size of the first image. Image ``i`` is pasted into
    row ``i // cols`` and column ``i % cols``.
    """
    tile_w, tile_h = images[0].size
    canvas = Image.new('RGB', size=(cols * tile_w, rows * tile_h))
    for idx, tile in enumerate(images):
        row, col = divmod(idx, cols)
        canvas.paste(tile, box=(col * tile_w, row * tile_h))
    return canvas

def evaluate(config, epoch, pipeline):
    """Sample images from ``pipeline`` and save them as a single grid image.

    The grid is written to ``<config.output_dir>/samples/<epoch:04d>.png``.

    Args:
        config: Object providing ``seed`` and ``output_dir``.
        epoch: Integer used to name the output file.
        pipeline: Callable accepting a ``generator`` keyword and returning a
            sequence of images that ``make_grid`` can tile.
    """
    images = pipeline(
        # batch_size = config.eval_batch_size,
        generator=torch.manual_seed(config.seed),
    )

    # Size the grid from the number of images actually returned. Deriving it from
    # the training batch size gave an empty canvas for batch size 1 and a grid
    # unrelated to the image count otherwise.
    count = len(images)
    cols = 1
    while cols * cols < count:
        cols += 1
    rows = -(-count // cols)  # ceiling division
    image_grid = make_grid(images, rows=rows, cols=cols)

    test_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")

def save_images(im, epoch):
    """Save the first element of ``im`` as ``<output_dir>/samples/<epoch:04d>.png``.

    The output directory is read from the notebook-level ``config`` object,
    not from an argument.
    """
    samples_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(samples_dir, exist_ok=True)
    first_image = im[0]
    first_image.save(f"{samples_dir}/{epoch:04d}.png")


In [9]:
# Build the project's pipeline from the UNet and the image encoder.
# NOTE(review): `pipe` is not referenced again in the visible notebook, and this
# duplicates the earlier `sd_pipe` construction — confirm whether it is needed.
pipe = MyPipline(unet2, clip)

`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.


In [28]:
from accelerate import Accelerator

def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler, vae, vit, inp_images):
    """Train ``model`` to predict the noise added to VAE latents.

    For every batch the loop encodes ``batch["pixel_values"]`` with ``vae``,
    embeds ``batch["pix_for_vit"]`` with ``vit``, adds scheduler noise at a random
    timestep and minimises the MSE between the predicted and the true noise.
    On the main process, every ``config.save_image_epochs`` epochs (and after the
    last epoch) it saves the model weights and one sample image.

    Args:
        config: Object providing ``mixed_precision``, ``gradient_accumulation_steps``,
            ``output_dir``, ``num_epochs`` and ``save_image_epochs``.
        model: Network called as ``model(noisy_latents, timesteps, img_emb, poses)``;
            its output is indexed with ``"sample"``.
        noise_scheduler: Scheduler providing ``add_noise`` and
            ``config.num_train_timesteps``.
        optimizer: Optimizer for ``model``.
        train_dataloader: Yields dict batches with the keys ``"pixel_values"``,
            ``"pix_for_vit"`` and ``"diff_pose"``.
        lr_scheduler: Learning-rate scheduler stepped right after the optimizer.
        vae: Autoencoder used only for encoding; kept in eval mode.
        vit: Image encoder whose output exposes ``image_embeds``.
        inp_images: Sequence whose first element conditions the sample image.
    """
    # NOTE(review): recent accelerate releases replace `logging_dir` with
    # `project_dir` — confirm against the installed version.
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",  # change to wandb
        logging_dir=os.path.join(config.output_dir, "logs")
    )
    if accelerator.is_main_process:
        accelerator.init_trackers("train_example")

    # The objects must be unpacked in the same order they were passed to prepare().
    model, optimizer, train_dataloader, lr_scheduler, vae, vit = accelerator.prepare(
        model, optimizer, train_dataloader, lr_scheduler, vae, vit
    )

    global_step = 0

    vae.eval()
    # NOTE(review): `vit` runs in train mode with autograd enabled, but whether its
    # parameters are in `optimizer` is decided by the caller — confirm it is meant
    # to be trained.
    vit.train()

    model.train()

    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")

        for batch in train_dataloader:
            poses = batch['diff_pose']

            # The latents are treated as constants, so skip building the VAE's
            # autograd graph instead of detaching it afterwards.
            with torch.no_grad():
                latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
            # NOTE(review): 0.18215 looks like the Stable Diffusion latent scaling
            # factor — confirm it matches this VAE checkpoint.
            latents = latents * 0.18215

            img_emb = vit(batch['pix_for_vit']).image_embeds.reshape(batch['pix_for_vit'].shape[0], 1, 768)

            # Sample noise to add to the latents
            noise = torch.randn(latents.shape).to(latents.device)
            bs = latents.shape[0]

            # Sample a random timestep for each image
            timesteps = torch.randint(
                0, noise_scheduler.config.num_train_timesteps, (bs,), device=latents.device
            ).long()

            # The forward diffusion process
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            with accelerator.accumulate(model):

                # Predict the noise residual
                noise_pred = model(noisy_latents, timesteps, img_emb, poses)["sample"]
                loss = F.mse_loss(noise_pred, noise)
                accelerator.backward(loss)

                # Clip only on the micro-step that synchronises gradients, i.e. once
                # the accumulated gradient is complete.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # Save the model and one demo image at the configured epochs
        if accelerator.is_main_process:
            unwrapped_model = accelerator.unwrap_model(model)
            # Use the encoder passed to this function rather than a notebook global,
            # so the loop also works when imported from a module.
            pipeline = MyPipline(unet=unwrapped_model, image_encoder=accelerator.unwrap_model(vit))

            if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
                # Best effort: a failure here is reported but does not stop training.
                # KeyboardInterrupt is not an Exception subclass, so Ctrl-C still
                # propagates instead of being swallowed.
                try:
                    # Save the weights first so a sampling failure cannot cost the checkpoint.
                    torch.save(unwrapped_model.state_dict(), f'model_{epoch % 3}.pth')
                    im = pipeline(inp_images[0], torch.Tensor([[0., -1.]]))
                    save_images(im, epoch)
                except Exception as err:
                    print(err.args)


In [None]:
from accelerate import notebook_launcher
# NOTE(review): this import rebinds `train_loop`, so the function defined in the
# cell above is replaced by the one from train.py — confirm which one is intended.
from train import train_loop

# Release cached CUDA memory before training starts.
torch.cuda.empty_cache()

# Positional arguments, in the parameter order of the `train_loop` defined in this notebook.
args = (config, unet2, noise_scheduler, optimizer, train_dataloader_ffhq, lr_scheduler, vae, clip, inp_images)

# Launch the training function from the notebook in a single process.
notebook_launcher(train_loop, args, num_processes=1)

## Eval

In [None]:
# Restore fine-tuned UNet weights. `map_location='cpu'` deserialises the tensors on
# the CPU regardless of the device they were saved from; load_state_dict then
# copies them into the model's own parameters.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
unet2.load_state_dict(torch.load('./model_new0.pth', map_location='cpu'))

<All keys matched successfully>

In [None]:
# Restore image-encoder weights. `map_location='cpu'` deserialises the tensors on
# the CPU regardless of the device they were saved from; load_state_dict then
# copies them into the model's own parameters.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
clip.load_state_dict(torch.load('./vit_10.pth', map_location='cpu'))

<All keys matched successfully>

In [None]:
# Move the image encoder, the VAE and the UNet to the GPU for inference.
clip.to('cuda')
vae.to('cuda')
unet2.to('cuda')

In [None]:
# Inference pipeline built from the UNet and image encoder with the weights loaded above.
pipeline = MyPipline(unet2, clip)

In [None]:
import numpy as np

# Sweep the second pose component over 16 evenly spaced values in [-0.7, 0.7].
# The input image, seed, guidance scale and step count are the same for every
# sample, so only the pose differs between grid cells.
# NOTE(review): the output file name suggests this component is a rotation — confirm.
images = []
for s in np.linspace(-0.7, 0.7, 16):
    pose = torch.Tensor([[0., s]])
    im = pipeline(
        inp_images[0],
        pose,
        generator=torch.manual_seed(10),
        guidance_scale=13,
        num_inference_steps=100,
    )
    images.append(im[0])

# 16 samples laid out as a 4 x 4 grid.
image_grid = make_grid(images, rows=4, cols=4)

# Save the images
image_grid.save("./rotation_test_8.png")

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]

  0%|          | 0/100 [00:00<?, ?it/s]