In [1]:
import requests
import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch import nn
from torchvision import transforms
from eg3d_dataset import EG3DDataset, EG3DImageProcessor
from gen_samples import vision_evaluate

from transformers import CLIPImageProcessor, CLIPVisionModel, CLIPVisionConfig

In [2]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Hyperparameters and paths used by the training notebook.

    Every attribute carries a type annotation so that it is a genuine
    dataclass field: values can be overridden per run through the generated
    constructor (e.g. ``TrainingConfig(num_epochs=5)``), and the generated
    ``repr``/``==`` reflect the actual settings. Calling ``TrainingConfig()``
    with no arguments yields the defaults listed below.
    """

    image_size: int = 512  # the generated image resolution
    train_batch_size: int = 128
    eval_batch_size: int = 128  # how many images to sample during evaluation
    num_dataloader_workers: int = 12  # how many subprocesses to use for data loading
    num_epochs: int = 60
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    lr_warmup_steps: int = 500
    scheduler_train_timesteps: int = 1000
    eval_inference_steps: int = 1000
    save_image_epochs: int = 10
    save_model_epochs: int = 30
    mixed_precision: str = 'no'  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = 'vision-eg3d-latent-interpreter'

    data_dir: str = 'data_color/'
    df_file: str = 'dataset.df'

    overwrite_output_dir: bool = True
    seed: int = 0

config = TrainingConfig()

In [3]:
# Image preprocessing object from the project-local eg3d_dataset module.
preprocess = EG3DImageProcessor()

# NOTE(review): image_size is hard-coded to 128 here while config.image_size is 512 —
# confirm which resolution the model is actually meant to see.
dataset = EG3DDataset(df_file=config.df_file, data_dir=config.data_dir, image_size=128, transform=preprocess, encode=False)

# Hold out 5% of the samples for evaluation.
train_size = int(len(dataset) * 0.95)
eval_size = len(dataset) - train_size
# Seed the split so the same samples land in the eval set on every kernel restart;
# otherwise eval samples from one run may have been training samples in another.
split_generator = torch.Generator().manual_seed(config.seed)
train_dataset, eval_dataset = torch.utils.data.random_split(dataset, [train_size, eval_size], generator=split_generator)

train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=config.num_dataloader_workers)
# Evaluation order is kept fixed so metrics and sampled batches are comparable across epochs.
eval_dataloader = torch.utils.data.DataLoader(eval_dataset, batch_size=config.eval_batch_size, shuffle=False, num_workers=config.num_dataloader_workers)

In [7]:
# The vision model is constructed from a config object here rather than
# loaded through from_pretrained.
vision_config = CLIPVisionConfig(
    hidden_size=512,
    num_hidden_layers=32,
    num_attention_heads=32,
)
model = CLIPVisionModel(vision_config)

# Image processor handed to train_loop alongside the model.
processor = CLIPImageProcessor(size=512)

In [8]:
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# image = Image.open(requests.get(url, stream=True).raw)

# inputs = processor(images=[dataset[0]['images'], dataset[1]['images']], return_tensors="pt")

# outputs = model(**inputs)
# last_hidden_state = outputs.last_hidden_state
# pooled_output = outputs.pooler_output  # pooled CLS states

In [9]:
# pooled_output.shape

In [10]:
# AdamW over every model parameter; the step size comes from the training config.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=config.learning_rate,
)

# Smooth L1 regression loss with mean reduction; train_loop reads this as a global.
loss_function = nn.SmoothL1Loss(reduction='mean')

In [11]:
from accelerate import Accelerator

from tqdm.auto import tqdm
import os

In [12]:
def train_loop(config, model, processor, optimizer, train_dataloader, eval_dataloader):
    """Train ``model`` to regress latent vectors from images, evaluating after every epoch.

    Each batch is indexed with ``'images'`` (passed to the model as
    ``pixel_values``) and ``'latent_vectors'`` (the target compared against the
    model's ``pooler_output``). The loss is the module-level ``loss_function``.

    Args:
        config: Settings object; this function reads ``mixed_precision``,
            ``gradient_accumulation_steps``, ``output_dir``, ``num_epochs``,
            ``save_image_epochs`` and ``save_model_epochs``.
        model: Network called as ``model(pixel_values=...)`` whose output
            exposes ``pooler_output``.
        processor: Forwarded unchanged to ``vision_evaluate``.
        optimizer: Optimizer stepped once per training batch.
        train_dataloader: Iterable of training batches.
        eval_dataloader: Iterable of evaluation batches.

    Side effects:
        Logs ``train_loss`` per step and ``eval_loss`` per epoch to the
        accelerator's trackers, calls ``vision_evaluate`` periodically, and
        writes ``model.pth`` into ``config.output_dir``.
    """
    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        logging_dir=os.path.join(config.output_dir, "logs")
    )
    if accelerator.is_main_process:
        # torch.save below does not create missing directories.
        os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("clip_latent_interpreter")

    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )

    global_step = 0

    for epoch in range(config.num_epochs):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        model.train()

        for batch in train_dataloader:
            images = batch['images']
            latent_vectors = batch['latent_vectors']

            with accelerator.accumulate(model):
                latent_vectors_pred = model(pixel_values=images).pooler_output

                loss = loss_function(latent_vectors_pred, latent_vectors)
                accelerator.backward(loss)

                optimizer.step()
                # lr_scheduler.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"train_loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        # A new bar is created every epoch, so release the finished one.
        progress_bar.close()

        model.eval()
        eval_losses = []
        for batch in eval_dataloader:
            # Read inputs and targets from the *evaluation* batch; reusing the
            # variables left over from the last training batch would report a
            # training loss under the name "eval_loss".
            images = batch['images']
            latent_vectors = batch['latent_vectors']
            with torch.no_grad():
                latent_vectors_pred = model(pixel_values=images).pooler_output

                loss = loss_function(latent_vectors_pred, latent_vectors)
            eval_losses.append(loss.detach().item())

        # Skip the metric instead of dividing by zero when there is nothing to evaluate.
        if eval_losses:
            avg_eval_loss = sum(eval_losses) / len(eval_losses)
            accelerator.log({"eval_loss": avg_eval_loss}, step=global_step)

        # After each epoch you optionally sample some demo images with evaluate() and save the model
        if accelerator.is_main_process:
            is_last_epoch = epoch == config.num_epochs - 1

            if (epoch + 1) % config.save_image_epochs == 0 or is_last_epoch:
                # Rendering is only requested on the final epoch.
                vision_evaluate(config, epoch, processor, model, eval_dataloader, render=is_last_epoch)

            if (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch:
                torch.save({
                    'epoch': epoch,
                    # Unwrap so the saved keys do not depend on any wrapper added by prepare().
                    'model_state_dict': accelerator.unwrap_model(model).state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'loss': loss,
                }, os.path.join(config.output_dir, 'model.pth'))

    # Finalise the trackers started by init_trackers.
    accelerator.end_training()

In [14]:
from accelerate import notebook_launcher
# Positional arguments for train_loop, listed in the order of its signature.
args = (
    config,
    model,
    processor,
    optimizer,
    train_dataloader,
    eval_dataloader,
)

# Launch the training function from the notebook with a single process.
notebook_launcher(train_loop, args=args, num_processes=1)

Launching training on one GPU.


  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/61 [00:00<?, ?it/s]

  0%|          | 0/61 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [15]:
import glob
from IPython.display import display, HTML

# Collect the sample tables stored under <output_dir>/samples.
sample_dfs = sorted(glob.glob(f"{config.output_dir}/samples/*.df"))
if not sample_dfs:
    # Fail with an actionable message instead of a bare IndexError on [-1].
    raise FileNotFoundError(
        f"No sample files found matching {config.output_dir}/samples/*.df"
    )

# NOTE(review): sorted() orders names lexicographically, so the last entry is only the
# newest file if the file names sort that way (e.g. zero-padded epochs) — confirm.
# NOTE: read_pickle unpickles the file; only load files produced by a trusted run.
df = pd.read_pickle(sample_dfs[-1])

display(df)

Unnamed: 0,latent_vectors,latent_vectors_pred
0,"[-0.46562985, 1.334632, 1.3463336, -0.21872601...","[0.34967417, 0.33072296, -0.028596513, 0.03216..."
1,"[0.29773626, 0.31208906, 1.0516183, -1.5730159...","[0.5148991, 0.08528648, 0.065244794, -0.153878..."
2,"[0.6659024, 0.89474744, 2.2324452, -1.1894681,...","[0.32946944, 0.102209136, -0.16398558, -0.0494..."
3,"[0.3788743, 2.4546614, 0.36312538, -2.1257122,...","[0.276742, 0.17762648, 0.034863807, -0.0208411..."
4,"[0.31123757, 0.17900072, 1.3598925, -1.8150185...","[0.17728442, 0.33686158, -0.044201493, 0.02935..."
...,...,...
123,"[1.2418977, -0.9253846, -0.13435465, -0.433270...","[0.07940714, 0.016563637, -0.10800985, -0.0164..."
124,"[-0.057110276, -0.43176275, -0.21717091, 0.219...","[0.2501511, 0.008996903, -0.17652346, -0.08992..."
125,"[1.8535191, 0.035906322, -0.27818182, -1.50999...","[0.21476893, 0.0716762, -0.104455546, -0.10830..."
126,"[-0.006629809, -0.6767008, -0.6187405, -0.5521...","[0.16533367, -0.07044939, -0.19103174, -0.1112..."
