In [3]:
import requests
import numpy as np
import pandas as pd
from PIL import Image

import torch
from torch import nn
from torchvision import transforms
from eg3d_dataset import EG3DDataset, EG3DImageProcessor
from gen_samples import vision_evaluate



In [4]:
from PIL import Image
import requests
from transformers import BlipProcessor, BlipForConditionalGeneration, BlipTextModel

# Smoke test: caption one COCO validation image with the pretrained BLIP base model.
# NOTE(review): if `from_pretrained` raises FileNotFoundError on `.../refs/main`,
# the local Hugging Face cache entry for this model is likely incomplete — clearing
# that cache directory and re-downloading is the usual fix.
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# requests never times out by default; bound the wait, and fail on HTTP errors
# instead of handing an error page body to PIL.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# Image-only input: the processor returns pixel values, no text token ids.
inputs = processor(images=image, return_tensors="pt")

outputs = model.generate(**inputs)
print(processor.decode(outputs[0], skip_special_tokens=True))

FileNotFoundError: [Errno 2] No such file or directory: '/Users/alexanderkorte/.cache/huggingface/hub/models--Salesforce--blip-image-captioning-base/refs/main'

In [24]:
# NOTE(review): in the caption cell `outputs` is the result of `model.generate(...)`,
# used there as a sequence of token ids (`outputs[0]`). The keys displayed here
# (`loss`, `decoder_logits`, `image_embeds`, ...) presumably come from a forward-pass
# output object created in a cell that is no longer in this notebook — hidden state.
outputs.keys()

odict_keys(['loss', 'decoder_logits', 'image_embeds', 'last_hidden_state'])

In [27]:
# TODO(review): disabled experiment — delete or restore:
# top_k_top_p_filtering(outputs.decoder_logits[0])
# Shape of the first example's decoder logits; presumably (positions, vocab size).
outputs.decoder_logits[0].shape

torch.Size([16, 30524])

In [26]:
# Greedy read-out: take the highest-scoring id at each row of the first example's
# decoder logits, then turn that id sequence back into text.
classes = list(map(lambda position_logits: position_logits.argmax().item(), outputs.decoder_logits[0]))
processor.decode(classes)

'twoss two twosdinindss _erer [SEP] [SEP] [SEP]'

In [55]:
# Decoder-logit shape and image-embedding shape, printed on separate lines.
print(outputs.decoder_logits.shape, outputs.image_embeds.shape, sep="\n")

torch.Size([1, 2, 30524])
torch.Size([1, 577, 768])


In [19]:
# Inspect which tensors the processor returns for a padded batch of text-only inputs.
processor(text=["testing testin 123 dsafjdksalf", "testing 2 but different"], return_tensors="pt", padding=True).keys()

dict_keys(['input_ids', 'attention_mask'])

In [2]:
from dataclasses import dataclass

@dataclass
class TrainingConfig:
    """Hyperparameters and paths for training the EG3D latent interpreter.

    Every attribute carries a type annotation, so ``@dataclass`` turns it into a
    real field: ``TrainingConfig()`` keeps all defaults, and individual values can
    be overridden by keyword, e.g. ``TrainingConfig(num_epochs=5)``.
    """

    # NOTE(review): the dataset cell builds EG3DDataset with image_size=128, not this value.
    image_size: int = 512  # the generated image resolution
    train_batch_size: int = 128
    eval_batch_size: int = 128  # how many images to sample during evaluation
    num_dataloader_workers: int = 12  # how many subprocesses to use for data loading
    num_epochs: int = 30
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-4
    # NOTE(review): no LR scheduler is built in this notebook (its step call is commented out).
    lr_warmup_steps: int = 500
    # NOTE(review): not read by any visible cell; confirm whether vision_evaluate uses them.
    scheduler_train_timesteps: int = 1000
    eval_inference_steps: int = 1000
    save_image_epochs: int = 10  # evaluate/sample every N epochs (and on the last epoch)
    save_model_epochs: int = 30  # checkpoint every N epochs (and on the last epoch)
    mixed_precision: str = 'no'  # `no` for float32, `fp16` for automatic mixed precision
    output_dir: str = 'vision-eg3d-latent-interpreter'

    data_dir: str = 'data/'
    df_file: str = 'dataset.df'

    overwrite_output_dir: bool = True
    seed: int = 0


config = TrainingConfig()

In [3]:
preprocess = EG3DImageProcessor()

# NOTE(review): image_size=128 here differs from config.image_size (512) — confirm which is intended.
dataset = EG3DDataset(df_file=config.df_file, data_dir=config.data_dir, image_size=128, transform=preprocess, encode=False)

# 95% / 5% train/eval split. The split draws from its own generator seeded with
# config.seed, so the same samples land in the eval set on every run without
# touching the global RNG.
train_size = int(len(dataset) * 0.95)
eval_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(config.seed)
train_dataset, eval_dataset = torch.utils.data.random_split(
    dataset, [train_size, eval_size], generator=split_generator
)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=config.num_dataloader_workers
)
# Order does not affect an averaged eval loss; a fixed order keeps per-epoch
# evaluation outputs comparable row-for-row.
eval_dataloader = torch.utils.data.DataLoader(
    eval_dataset, batch_size=config.eval_batch_size, shuffle=False, num_workers=config.num_dataloader_workers
)

In [4]:
from transformers import CLIPImageProcessor, CLIPVisionConfig, CLIPVisionModel

# CLIP vision tower built from a config (not from_pretrained): 16 layers, 16 heads,
# 512-dim hidden size. Note this rebinds `model`/`processor`, which earlier cells
# used for BLIP.
# NOTE(review): image_size/patch_size stay at CLIPVisionConfig defaults — confirm the
# dataset transform produces pixel_values at that resolution.
vision_config = CLIPVisionConfig(hidden_size=512, num_hidden_layers=16, num_attention_heads=16)
model = CLIPVisionModel(vision_config)
# NOTE(review): CLIPImageProcessor resizes to `size` and then center-crops to its
# default crop_size — confirm that cropping is intended.
processor = CLIPImageProcessor(size=512)

In [5]:
# url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# image = Image.open(requests.get(url, stream=True).raw)

# inputs = processor(images=[dataset[0]['images'], dataset[1]['images']], return_tensors="pt")

# outputs = model(**inputs)
# last_hidden_state = outputs.last_hidden_state
# pooled_output = outputs.pooler_output  # pooled CLS states

In [6]:
# pooled_output.shape

In [7]:
# AdamW over all parameters of the current `model`, at the configured learning rate.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=config.learning_rate)

In [8]:
from accelerate import Accelerator

from tqdm.auto import tqdm
import os

In [9]:
def train_loop(config, model, processor, optimizer, train_dataloader, eval_dataloader, loss_function=None):
    """Train ``model`` to regress latent vectors from images, with per-epoch eval.

    Each step feeds ``batch['images']`` to the model as ``pixel_values`` and uses
    its ``pooler_output`` as the predicted latent vector, scored against
    ``batch['latent_vectors']``. After every epoch the mean loss over
    ``eval_dataloader`` is logged; on the epochs selected by ``config`` (and the
    last one) samples are produced via ``vision_evaluate`` and a checkpoint is
    written to ``<output_dir>/model.pth``.

    Args:
        config: Reads ``mixed_precision``, ``gradient_accumulation_steps``,
            ``output_dir``, ``num_epochs``, ``save_image_epochs``, ``save_model_epochs``.
        model: Vision model whose forward output has ``pooler_output``
            (e.g. the ``CLIPVisionModel`` built in this notebook).
        processor: Image processor; only forwarded to ``vision_evaluate``.
        optimizer: Optimizer over ``model``'s parameters.
        train_dataloader: Yields dicts with ``'images'`` and ``'latent_vectors'``.
        eval_dataloader: Same batch format as ``train_dataloader``.
        loss_function: Callable ``(pred, target) -> scalar tensor``; defaults to
            mean squared error.
    """
    if loss_function is None:
        # NOTE(review): assumes latent_vectors is a float tensor with the same shape
        # as pooler_output — confirm against EG3DDataset.
        loss_function = torch.nn.functional.mse_loss

    # Initialize accelerator and tensorboard logging
    accelerator = Accelerator(
        mixed_precision=config.mixed_precision,
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        log_with="tensorboard",
        logging_dir=os.path.join(config.output_dir, "logs"),
    )
    if accelerator.is_main_process:
        # Checkpoints are saved into output_dir below; make sure it exists.
        os.makedirs(config.output_dir, exist_ok=True)
        accelerator.init_trackers("blip_latent_interpreter")

    model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader
    )

    global_step = 0

    for epoch in range(config.num_epochs):
        is_last_epoch = epoch == config.num_epochs - 1

        # ---- training ----
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)
        progress_bar.set_description(f"Epoch {epoch}")
        model.train()

        for batch in train_dataloader:
            images = batch['images']
            latent_vectors = batch['latent_vectors']

            with accelerator.accumulate(model):
                latent_vectors_pred = model(pixel_values=images).pooler_output
                loss = loss_function(latent_vectors_pred, latent_vectors)
                accelerator.backward(loss)

                # Under `accumulate`, Accelerate's prepared optimizer only applies
                # the update on gradient-sync steps.
                optimizer.step()
                optimizer.zero_grad()

            progress_bar.update(1)
            logs = {"train_loss": loss.detach().item()}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)
            global_step += 1

        progress_bar.close()

        # ---- evaluation: score every eval batch on its own images and targets ----
        model.eval()
        eval_losses = []
        with torch.no_grad():
            for batch in eval_dataloader:
                latent_vectors_pred = model(pixel_values=batch['images']).pooler_output
                eval_losses.append(loss_function(latent_vectors_pred, batch['latent_vectors']).item())
        # An empty eval split logs NaN instead of raising ZeroDivisionError.
        avg_eval_loss = sum(eval_losses) / len(eval_losses) if eval_losses else float('nan')
        accelerator.log({"eval_loss": avg_eval_loss}, step=global_step)

        # ---- sampling and checkpointing (main process only) ----
        if accelerator.is_main_process:
            if (epoch + 1) % config.save_image_epochs == 0 or is_last_epoch:
                # Rendering is only requested on the final epoch.
                vision_evaluate(config, epoch, processor, model, eval_dataloader, render=is_last_epoch)

            if (epoch + 1) % config.save_model_epochs == 0 or is_last_epoch:
                torch.save({
                    'epoch': epoch,
                    # Unwrap so saved keys match the bare model, not an Accelerate wrapper.
                    'model_state_dict': accelerator.unwrap_model(model).state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    # Mean eval loss for this epoch, as a plain float.
                    'loss': avg_eval_loss,
                }, os.path.join(config.output_dir, 'model.pth'))

    # Flush and close the tensorboard tracker.
    accelerator.end_training()

In [10]:
from accelerate import notebook_launcher
# Run train_loop through Accelerate's notebook launcher in a single process.
args = (
    config, model, processor, optimizer,
    train_dataloader, eval_dataloader,
)
notebook_launcher(train_loop, args=args, num_processes=1)

Launching training on one GPU.


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

In [11]:
import glob
from IPython.display import display, HTML

# Load the last sample dump under <output_dir>/samples (presumably written by
# vision_evaluate during training) and display it.
# NOTE(review): "last" means last in lexicographic filename order — confirm the
# files are named so that this matches epoch order (e.g. zero-padded epochs).
sample_pattern = f"{config.output_dir}/samples/*.df"
sample_dfs = sorted(glob.glob(sample_pattern))
if not sample_dfs:
    raise FileNotFoundError(f"No sample files match {sample_pattern!r}; run the training cell first.")
# pickle can execute code on load — only read files this project wrote itself.
df = pd.read_pickle(sample_dfs[-1])

display(df)

Unnamed: 0,latent_vectors,latent_vectors_pred
0,"[0.43638438, -0.35027096, 0.65112334, -0.14421...","[-0.7743806, 1.1886537, -0.03107796, -0.628278..."
1,"[-0.9565933, 0.9568855, 2.2118726, 0.9116337, ...","[-1.1879879, 0.63656104, -0.44836017, 0.114511..."
2,"[-0.41382802, 0.45954362, 0.4524543, -0.719719...","[-0.7957045, 0.5445316, -0.21407354, 0.2051533..."
3,"[-0.07948193, -0.773863, 0.26276428, -0.192911...","[-0.84935576, 0.829156, 0.021630026, 0.7263233..."
4,"[0.7090109, -0.97619194, -0.50082564, -1.24826...","[-1.273112, 0.8009871, -1.1427329, 0.37143478,..."
...,...,...
98,"[0.8456882, -1.0896301, 0.110036835, 1.014014,...","[-0.6623965, 1.1237195, -0.42268917, -0.576029..."
99,"[0.0018373622, -1.0448452, 0.47130919, -0.1664...","[-0.8078432, -0.28316444, -0.0044523203, 0.279..."
100,"[0.46234643, -1.7402277, -0.09604135, -1.29059...","[-1.4029849, -0.3401298, 0.043911677, 0.465329..."
101,"[2.3822598, -1.3872718, 0.60999435, -0.8259504...","[-0.5495116, 0.44137225, -0.62052745, 0.230265..."
