In [1]:
import torch
import torch.nn as nn
from huggingface_hub import login
from transformers import AutoProcessor, CLIPVisionModel
from transformers import AutoTokenizer, AutoModelForCausalLM
HUGGINGFACE_TOKEN = "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"
login(HUGGINGFACE_TOKEN)

  from .autonotebook import tqdm as notebook_tqdm


The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.
Token is valid (permission: read).
Your token has been saved to C:\Users\Lsemi\.cache\huggingface\token
Login successful


In [10]:
#CAPTION_PATH = "/content/Flickr8k.token.txt"
CAPTION_PATH = "./Flicker/Flickr8k_text/Flickr8k.token.txt"
#IMAGES_FILE_PATH = "/content/Flicker8k_Dataset"
IMAGES_FILE_PATH = "./Flicker/Flickr8k_Dataset/Flicker8k_Dataset"
# SAVED_PATH = "/content/saved_model/adaptor_caption.pt"
SAVED_PATH = "./gemma2-9b/adaptor_caption.pt"

BATCH_SIZE = 8
NUM_ITERATION = 2000
SAVE_EVERY = 200
LEARNING_RATE = 1e-4
TRAIN_DATA_NUM = 7500

device = 'cpu'
if torch.cuda.is_available() :
  device = 'cuda'

In [3]:
print(device)

cuda


In [4]:
class MyAdaptor(nn.Module) :
  def __init__(self, vis_token_embedding_size, word_embedding_size) :
    super(MyAdaptor, self).__init__()
    self.vis_token_embedding_size = vis_token_embedding_size
    self.word_embedding_size = word_embedding_size

    self.adapter_linear = nn.Linear(self.vis_token_embedding_size, self.word_embedding_size)

  def forward(self, img_output) :
    self.adapter_linear.to(img_output.device)
    img_embed = self.adapter_linear(img_output)
    return img_embed

class MyModel(nn.Module) :
  def __init__(self) :
    super(MyModel, self).__init__()
    self.model_language = AutoModelForCausalLM.from_pretrained("google/gemma-2-9b-it", torch_dtype=torch.bfloat16)
    self.tokenizer_language = AutoTokenizer.from_pretrained("google/gemma-2-9b-it", padding_side= 'right')
    self.image_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32").image_processor
    self.model_image = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")

    #self.word_embedding_size = 2304
    self.word_embedding_size = 3584
    self.num_vocab = 256000

    self.trigger_str_img = "<start_image>"
    self.num_vis_token_summary = 50
    self.vis_token_embedding_size = 768
    self.adaptor = MyAdaptor(self.vis_token_embedding_size,self.word_embedding_size )
    self.dummy_img_token = (" ".join(["the"]*self.num_vis_token_summary)).strip()

  def search_trigger_idx(self, text_token, trigger_str) :
    all_token = text_token
    all_string_now = ""
    all_token_now = []
    dummy_start_token = None
    for token_idx in range(len(all_token)) :
      token_now = int(all_token[token_idx].detach().cpu().numpy())
      all_token_now.append(token_now)
      token_as_string = self.tokenizer_language.batch_decode([all_token_now],skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

      if trigger_str in token_as_string :
        dummy_start_token = token_idx + 1
        break
    return dummy_start_token

  def get_image_embed(self, image_input) :
    img_output = self.model_image(image_input)['last_hidden_state']
    img_embed = self.adaptor(img_output)

    return img_embed


  def replace_embedding_hook(self, image_input) :
    image_feature = self.get_image_embed(image_input)
    assert len(image_feature) == 1

    def now_hook(model, input, output) :
      real_input = input[0]
      batch_size, token_len = real_input.shape
      if(token_len > 1) :
        assert batch_size == 1
        dummy_start_token = self.search_trigger_idx(real_input[0], self.trigger_str_img )

        temp = image_feature[0]
        output[:,dummy_start_token:dummy_start_token+self.num_vis_token_summary] = temp
      return output
    return now_hook



  def split_and_replace(self, now_input_tokens, replacement_embed, start_loc) :
    num_token = len(replacement_embed)

    start_embed = now_input_tokens[0:start_loc]
    end_embed = now_input_tokens[start_loc+num_token:]
    # デバッグ用ログ
    #print(f"start_embed: {start_embed.shape}, replacement_embed: {replacement_embed.shape}, end_embed: {end_embed.shape}")

    # 必要ならサイズを調整
    if start_embed.size(-1) != replacement_embed.size(-1) or end_embed.size(-1) != replacement_embed.size(-1):
        print(f"Size mismatch: start={start_embed.size(-1)}, replacement={replacement_embed.size(-1)}, end={end_embed.size(-1)}")
    replaced_embed = torch.cat((start_embed, replacement_embed.to(now_input_tokens.dtype), end_embed),0)

    return replaced_embed

  def forward_loss(self, image_input_raw, caption_output_raw) :
    instruction_now =  "<start_of_turn>user\n"
    instruction_now += f"<start_image> {self.dummy_img_token}\n<end_image>\n"
    instruction_now += f"Create a simple description of the image!\n<end_of_turn>\n<start_of_turn>model\n"

    image_input = self.image_processor(image_input_raw, return_tensors="pt")['pixel_values']
    image_input = image_input.to(device)

    caption_output = self.tokenizer_language(caption_output_raw,padding=True,return_tensors="pt")
    caption_output['input_ids'] = caption_output['input_ids'].to(device)
    caption_output['attention_mask'] = caption_output['attention_mask'].to(device)

    img_output = self.model_image(image_input)['last_hidden_state']
    img_embed = self.adaptor(img_output)

    all_text_with_prompt = [instruction_now + temp_text for temp_text in self.tokenizer_language.batch_decode(caption_output['input_ids'], skip_special_tokens=True)]
    all_tokens_with_prompt = self.tokenizer_language(all_text_with_prompt, padding=True, return_tensors="pt")
    all_tokens_with_prompt['input_ids'] = all_tokens_with_prompt['input_ids'].to(device).detach()
    all_tokens_with_prompt['attention_mask'] = all_tokens_with_prompt['attention_mask'].to(device).detach()

    all_token_prompt_embed = self.model_language.model.embed_tokens(all_tokens_with_prompt['input_ids'])
    prompt_len = len(self.tokenizer_language([instruction_now])['input_ids'][0])
    caption_label_now = all_tokens_with_prompt['input_ids'][:,prompt_len:]
    caption_label_now = F.one_hot(caption_label_now,self.num_vocab)
    attn_mask_now = all_tokens_with_prompt['attention_mask'][:,prompt_len:]

    all_replaced_feature = []
    for temp_idx in range(len(all_tokens_with_prompt['input_ids'])) :
      tokens_text_now = all_tokens_with_prompt['input_ids'][temp_idx].detach().cpu()
      dummy_location_caption = self.search_trigger_idx(tokens_text_now, self.trigger_str_img )
      image_replaced_prompt = self.split_and_replace(all_token_prompt_embed[temp_idx], img_embed[temp_idx], dummy_location_caption)

      all_replaced_feature.append(image_replaced_prompt)
    all_replaced_feature = torch.stack(all_replaced_feature)


    logits_now = self.model_language(inputs_embeds =all_replaced_feature, attention_mask=all_tokens_with_prompt['attention_mask'])

    logits_now = logits_now['logits']
    caption_prediction_now = logits_now[:,prompt_len-1:-1]
    caption_prediction_now = torch.softmax(caption_prediction_now,-1)
    caption_prediction_now = torch.maximum(caption_prediction_now,torch.as_tensor(1e-10).to(caption_prediction_now.dtype))
    caption_prediction_now = torch.minimum(caption_prediction_now,torch.as_tensor(1 - 1e-10).to(caption_prediction_now.dtype))


    loss_lm = -torch.sum(caption_label_now*torch.log(caption_prediction_now),-1)
    loss_lm = torch.sum(loss_lm*attn_mask_now,-1)/torch.sum(attn_mask_now,-1)
    loss_lm = torch.mean(loss_lm)

    return loss_lm

  def generate_aswer_image(self, input_string, pil_image, max_new_tokens = 32, do_sample=True, top_k=50, top_p=0.95, temperature =1 ) :

    input_with_dummy_prompt = self.tokenizer_language.apply_chat_template(input_string, tokenize=False, add_generation_prompt=True)
    input_with_dummy_prompt = input_with_dummy_prompt.replace("<image>", "<start_image> "+self.dummy_img_token+"\n<end_image>")
    dummy_input = self.tokenizer_language(input_with_dummy_prompt,padding=True,return_tensors="pt")
    dummy_input['input_ids'] = dummy_input['input_ids'].to(device)
    dummy_input['attention_mask'] = dummy_input['attention_mask'].to(device)
    assert len(dummy_input['input_ids']) == 1

    handler_image = None

    contains_image = False
    if self.trigger_str_img in input_with_dummy_prompt :
      image_input = self.image_processor([pil_image], return_tensors="pt")['pixel_values'].to(device)
      hook_now_image = self.replace_embedding_hook(image_input)
      contains_image = True
      handler_image = self.model_language.model.embed_tokens.register_forward_hook(hook_now_image)



    output_now = self.model_language.generate(**dummy_input,
                                              max_new_tokens = max_new_tokens,
                                              do_sample=do_sample,
                                              temperature=temperature,
                                              top_k=top_k,
                                              top_p=top_p,
                                              )
    output_string = self.tokenizer_language.batch_decode(output_now, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    if contains_image :
      handler_image.remove()

    return output_string.split("model\n")[-1]

In [5]:
model = MyModel()

Loading checkpoint shards: 100%|██████████| 4/4 [00:07<00:00,  1.83s/it]


In [13]:
model = model.to(torch.bfloat16)
model.adaptor.load_state_dict(torch.load(SAVED_PATH), strict=False)
model.to(device)

  model.adaptor.load_state_dict(torch.load(SAVED_PATH), strict=False)


KeyboardInterrupt: 