In [1]:
import os
import hydra
from omegaconf import DictConfig, OmegaConf

import sys
sys.path.append('../../')

from rlprompt.models import (ImagePromptModelConfig, SinglePromptModelConfig,
                             make_image_prompt_model, make_single_prompt_model)
from rlprompt.modules import SQLModuleConfig, make_sql_module
from rlprompt.trainers import TrainerConfig, make_trainer
from rlprompt.utils.utils import (colorful_print, compose_hydra_config_store,
                                  get_hydra_output_dir)

from ipg_helpers import (ImagePromptGenerationRewardConfig,
                         ImagePromptGenerationDatasetConfig,
                         make_image_prompt_generation_reward,
                         make_image_prompot_generation_dataset)


# Config definitions that together make up the default configuration.
config_list = [
    ImagePromptGenerationRewardConfig,
    ImagePromptGenerationDatasetConfig,
    ImagePromptModelConfig,
    SinglePromptModelConfig,
    SQLModuleConfig,
    TrainerConfig,
]

# Merge every entry of config_list and store the result under the name 'base_ipg'.
cs = compose_hydra_config_store('base_ipg', config_list)

# Build the image-conditioned policy model, wrap it as a single-prompt model,
# and create the dataset. Each factory is handed the config object imported
# above as-is (presumably its default field values are used -- TODO confirm).
policy_model = make_image_prompt_model(ImagePromptModelConfig)
prompt_model = make_single_prompt_model(policy_model, SinglePromptModelConfig)
dataset = make_image_prompot_generation_dataset(ImagePromptGenerationDatasetConfig)

from torch.utils.data import Dataset, DataLoader
def get_train_dataloader(dataset: Dataset, config: "DictConfig",
                         batch_size: int = 8) -> DataLoader:
    """Build the training ``DataLoader`` for ``dataset``.

    Args:
        dataset: Dataset to draw training batches from.
        config: Object exposing ``train_shuffle`` and ``train_drop_last``
            attributes; only these two fields are read.
        batch_size: Number of samples per batch. Defaults to 8, the value
            that used to be hard-coded, so existing two-argument calls
            behave exactly as before.

    Returns:
        A ``DataLoader`` over ``dataset`` configured with the shuffle and
        drop-last settings from ``config`` and the requested batch size.
    """
    return DataLoader(dataset,
                      shuffle=config.train_shuffle,
                      batch_size=batch_size,
                      drop_last=config.train_drop_last)

# Training dataloader over `dataset`; the shuffle / drop-last flags are read
# from TrainerConfig (its `train_shuffle` and `train_drop_last` attributes).
train_dataloader = get_train_dataloader(dataset=dataset, config=TrainerConfig)


Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.7.self_attn.k_proj.weight', 'text_model.encoder.layers.1.self_attn.v_proj.weight', 'text_model.encoder.layers.2.mlp.fc1.bias', 'text_model.encoder.layers.7.layer_norm1.weight', 'text_model.encoder.layers.1.layer_norm2.weight', 'text_model.encoder.layers.4.layer_norm1.bias', 'text_model.encoder.layers.9.mlp.fc1.bias', 'text_model.encoder.layers.0.self_attn.v_proj.bias', 'text_model.encoder.layers.7.layer_norm2.bias', 'text_model.encoder.layers.5.self_attn.v_proj.bias', 'text_model.encoder.layers.10.mlp.fc2.weight', 'text_model.encoder.layers.11.self_attn.k_proj.weight', 'text_model.encoder.layers.11.layer_norm2.weight', 'text_model.encoder.layers.3.self_attn.v_proj.bias', 'text_model.encoder.layers.4.self_attn.out_proj.weight', 'text_model.encoder.layers.2.self_attn.out_proj.weight', 'text_model.encoder.layers.11.self_attn.out_proj.weight', 't

loading annotations into memory...
Done (t=0.02s)
creating index...
index created!


In [2]:
# Fetch exactly one batch from the training dataloader. `batch` is assigned
# explicitly (rather than leaking out of a for/break loop with an unused
# enumerate index) because later cells read `batch['image']`.
batch = next(iter(train_dataloader))

# Smoke-test generation on that batch. The trailing semicolon suppresses the
# notebook's display of the return value, so only the model's own debug
# prints appear as cell output.
prompt_model.generate(**batch, do_sample=False, top_k=1, top_p=1, max_new_tokens=None, infer=False);

greedy_search
step 0
attention_forward
attention_input torch.Size([8, 1, 768]) torch.Size([8, 50, 768])
attention_output: torch.Size([8, 1, 768])
logits: torch.Size([8, 1, 50257])
logits: torch.Size([8, 1, 50257])
action: torch.Size([8])
step 1
attention_forward
attention_input torch.Size([8, 1, 768]) torch.Size([8, 50, 768])
attention_output: torch.Size([8, 1, 768])
logits: torch.Size([8, 1, 50257])
logits: torch.Size([8, 1, 50257])
action: torch.Size([8])
step 2
attention_forward
attention_input torch.Size([8, 1, 768]) torch.Size([8, 50, 768])
attention_output: torch.Size([8, 1, 768])
logits: torch.Size([8, 1, 50257])
logits: torch.Size([8, 1, 50257])
action: torch.Size([8])
step 3
attention_forward
attention_input torch.Size([8, 1, 768]) torch.Size([8, 50, 768])
attention_output: torch.Size([8, 1, 768])
logits: torch.Size([8, 1, 50257])
logits: torch.Size([8, 1, 50257])
action: torch.Size([8])
step 4
attention_forward
attention_input torch.Size([8, 1, 768]) torch.Size([8, 50, 768])


In [3]:
from transformers import CLIPVisionModel

# Load CLIPVisionModel from the "openai/clip-vit-base-patch32" checkpoint,
# then move the loaded model onto the CUDA device.
clip_visual_model = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")
clip_visual_model = clip_visual_model.to('cuda')

Some weights of the model checkpoint at openai/clip-vit-base-patch32 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.5.self_attn.out_proj.weight', 'text_model.encoder.layers.9.self_attn.q_proj.bias', 'text_model.encoder.layers.6.self_attn.v_proj.bias', 'text_model.encoder.layers.1.self_attn.v_proj.weight', 'text_model.encoder.layers.6.layer_norm2.weight', 'text_model.encoder.layers.3.self_attn.q_proj.bias', 'text_model.encoder.layers.0.mlp.fc1.weight', 'text_model.encoder.layers.4.mlp.fc1.weight', 'text_model.encoder.layers.9.self_attn.out_proj.bias', 'text_model.encoder.layers.1.mlp.fc2.bias', 'text_model.encoder.layers.2.mlp.fc2.bias', 'text_model.encoder.layers.0.layer_norm1.weight', 'text_model.encoder.layers.7.mlp.fc1.weight', 'text_model.embeddings.position_embedding.weight', 'text_model.encoder.layers.10.layer_norm1.weight', 'text_model.encoder.layers.1.layer_norm2.bias', 'text_model.encoder.layers.8.mlp.fc2.weight', 'text_model.encoder.layers.7.mlp.

In [None]:
# Inspect the loaded vision model's configuration. The original cell held an
# unfinished attribute access (`clip_visual_model.`), which is a SyntaxError
# and would abort a Restart & Run All; showing the config is a valid stand-in.
clip_visual_model.config

In [5]:
# Run the CLIP vision model on the images of `batch` (the batch fetched from
# `train_dataloader` in an earlier cell). The input is moved to whatever
# device the model currently lives on, instead of repeating a hard-coded
# device string, so the two can never end up on different devices.
output = clip_visual_model(batch['image'].to(clip_visual_model.device))

In [7]:
# Show the shape of the vision model's last hidden state for the batch.
# NOTE(review): the saved output below reports a leading dimension of 16,
# while the dataloader built in this notebook uses batch_size=8, and the
# execution counts skip from 5 to 7 -- the stored output is likely from an
# earlier kernel state. Re-run top to bottom to refresh it.
output['last_hidden_state'].shape

torch.Size([16, 50, 768])