In [None]:
# Install library
!pip install -U transformers
# !pip install accelrate
# Download Dataset

!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
!wget https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
!unzip /content/Flickr8k_Dataset.zip
!unzip /content/Flickr8k_text.zip

!echo "Downloaded Flickr8k dataset successfully."
!mkdir saved_model
!mkdir uploaded

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: __MACOSX/Flicker8k_Dataset/._3429581486_4556471d1a.jpg  
  inflating: Flicker8k_Dataset/3429641260_2f035c1813.jpg  
  inflating: __MACOSX/Flicker8k_Dataset/._3429641260_2f035c1813.jpg  
  inflating: Flicker8k_Dataset/3429956016_3c7e3096c2.jpg  
  inflating: __MACOSX/Flicker8k_Dataset/._3429956016_3c7e3096c2.jpg  
  inflating: Flicker8k_Dataset/3430100177_5864bf1e73.jpg  
  inflating: __MACOSX/Flicker8k_Dataset/._3430100177_5864bf1e73.jpg  
  inflating: Flicker8k_Dataset/3430287726_94a1825bbf.jpg  
  inflating: __MACOSX/Flicker8k_Dataset/._3430287726_94a1825bbf.jpg  
  inflating: Flicker8k_Dataset/3430526230_234b3550f6.jpg  
  inflating: __MACOSX/Flicker8k_Dataset/._3430526230_234b3550f6.jpg  
  inflating: Flicker8k_Dataset/3430607596_7e4f74e3ff.jpg  
  inflating: __MACOSX/Flicker8k_Dataset/._3430607596_7e4f74e3ff.jpg  
  inflating: Flicker8k_Dataset/343073813_df822aceac.jpg  
  inflating: __MACOSX/Flicker8k_D

In [None]:
import os
from transformers import AutoProcessor, CLIPVisionModel
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
from huggingface_hub import login
import urllib.request
from google.colab import userdata
HuggingFace_Token = userdata.get('HUGGINGFACE')
login(token=HuggingFace_Token)

# Building The Model Architecture

In [None]:
class MyAdaptor(nn.Module) :
  def __init__(self, vis_token_embedding_size, word_embedding_size) :
    super(MyAdaptor, self).__init__()
    self.vis_token_embedding_size = vis_token_embedding_size
    self.word_embedding_size = word_embedding_size

    self.adapter_linear = nn.Linear(self.vis_token_embedding_size, self.word_embedding_size)

  def forward(self, img_output) :
    self.adapter_linear.to(img_output.device)
    img_embed = self.adapter_linear(img_output)
    return img_embed

class MyModel(nn.Module) :
  def __init__(self) :
    super(MyModel, self).__init__()
    self.model_language = AutoModelForCausalLM.from_pretrained("google/gemma-2-2b-it", torch_dtype=torch.bfloat16)
    self.tokenizer_language = AutoTokenizer.from_pretrained("google/gemma-2-2b-it", padding_side= 'right')
    self.image_processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32").image_processor
    self.model_image = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")

    self.word_embedding_size = 2304
    self.num_vocab = 256000

    self.trigger_str_img = "<start_image>"
    self.num_vis_token_summary = 50
    self.vis_token_embedding_size = 768
    self.adaptor = MyAdaptor(self.vis_token_embedding_size,self.word_embedding_size )
    self.dummy_img_token = (" ".join(["the"]*self.num_vis_token_summary)).strip()

  def search_trigger_idx(self, text_token, trigger_str) :
    all_token = text_token
    all_string_now = ""
    all_token_now = []
    dummy_start_token = None
    for token_idx in range(len(all_token)) :
      token_now = int(all_token[token_idx].detach().cpu().numpy())
      all_token_now.append(token_now)
      token_as_string = self.tokenizer_language.batch_decode([all_token_now],skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

      if trigger_str in token_as_string :
        dummy_start_token = token_idx + 1
        break
    return dummy_start_token

  def get_image_embed(self, image_input) :
    img_output = self.model_image(image_input)['last_hidden_state']
    img_embed = self.adaptor(img_output)

    return img_embed


  def replace_embedding_hook(self, image_input) :
    image_feature = self.get_image_embed(image_input)
    assert len(image_feature) == 1

    def now_hook(model, input, output) :
      real_input = input[0]
      batch_size, token_len = real_input.shape
      if(token_len > 1) :
        assert batch_size == 1
        dummy_start_token = self.search_trigger_idx(real_input[0], self.trigger_str_img )

        temp = image_feature[0]
        output[:,dummy_start_token:dummy_start_token+self.num_vis_token_summary] = temp
      return output
    return now_hook



  def split_and_replace(self, now_input_tokens, replacement_embed, start_loc) :
    num_token = len(replacement_embed)

    start_embed = now_input_tokens[0:start_loc]
    end_embed = now_input_tokens[start_loc+num_token:]
    replaced_embed = torch.cat((start_embed, replacement_embed.to(now_input_tokens.dtype), end_embed),0)

    return replaced_embed

  def forward_loss(self, image_input_raw, caption_output_raw) :
    instruction_now =  "<start_of_turn>user\n"
    instruction_now += f"<start_image> {self.dummy_img_token}\n<end_image>\n"
    instruction_now += f"Create a simple description of the image!\n<end_of_turn>\n<start_of_turn>model\n"

    image_input = self.image_processor(image_input_raw, return_tensors="pt")['pixel_values']
    image_input = image_input.to(device)

    caption_output = self.tokenizer_language(caption_output_raw,padding=True,return_tensors="pt")
    caption_output['input_ids'] = caption_output['input_ids'].to(device)
    caption_output['attention_mask'] = caption_output['attention_mask'].to(device)

    img_output = self.model_image(image_input)['last_hidden_state']
    img_embed = self.adaptor(img_output)

    all_text_with_prompt = [instruction_now + temp_text for temp_text in self.tokenizer_language.batch_decode(caption_output['input_ids'], skip_special_tokens=True)]
    all_tokens_with_prompt = self.tokenizer_language(all_text_with_prompt, padding=True, return_tensors="pt")
    all_tokens_with_prompt['input_ids'] = all_tokens_with_prompt['input_ids'].to(device).detach()
    all_tokens_with_prompt['attention_mask'] = all_tokens_with_prompt['attention_mask'].to(device).detach()

    all_token_prompt_embed = self.model_language.model.embed_tokens(all_tokens_with_prompt['input_ids'])
    prompt_len = len(self.tokenizer_language([instruction_now])['input_ids'][0])
    caption_label_now = all_tokens_with_prompt['input_ids'][:,prompt_len:]
    caption_label_now = F.one_hot(caption_label_now,self.num_vocab)
    attn_mask_now = all_tokens_with_prompt['attention_mask'][:,prompt_len:]

    all_replaced_feature = []
    for temp_idx in range(len(all_tokens_with_prompt['input_ids'])) :
      tokens_text_now = all_tokens_with_prompt['input_ids'][temp_idx].detach().cpu()
      dummy_location_caption = self.search_trigger_idx(tokens_text_now, self.trigger_str_img )
      image_replaced_prompt = self.split_and_replace(all_token_prompt_embed[temp_idx], img_embed[temp_idx], dummy_location_caption)

      all_replaced_feature.append(image_replaced_prompt)
    all_replaced_feature = torch.stack(all_replaced_feature)


    logits_now = self.model_language(inputs_embeds =all_replaced_feature, attention_mask=all_tokens_with_prompt['attention_mask'])

    logits_now = logits_now['logits']
    caption_prediction_now = logits_now[:,prompt_len-1:-1]
    caption_prediction_now = torch.softmax(caption_prediction_now,-1)
    caption_prediction_now = torch.maximum(caption_prediction_now,torch.as_tensor(1e-10).to(caption_prediction_now.dtype))
    caption_prediction_now = torch.minimum(caption_prediction_now,torch.as_tensor(1 - 1e-10).to(caption_prediction_now.dtype))


    loss_lm = -torch.sum(caption_label_now*torch.log(caption_prediction_now),-1)
    loss_lm = torch.sum(loss_lm*attn_mask_now,-1)/torch.sum(attn_mask_now,-1)
    loss_lm = torch.mean(loss_lm)

    return loss_lm

  def generate_aswer_image(self, input_string, pil_image, max_new_tokens = 32, do_sample=True, top_k=50, top_p=0.95, temperature =1 ) :

    input_with_dummy_prompt = self.tokenizer_language.apply_chat_template(input_string, tokenize=False, add_generation_prompt=True)
    input_with_dummy_prompt = input_with_dummy_prompt.replace("<image>", "<start_image> "+self.dummy_img_token+"\n<end_image>")
    dummy_input = self.tokenizer_language(input_with_dummy_prompt,padding=True,return_tensors="pt")
    dummy_input['input_ids'] = dummy_input['input_ids'].to(device)
    dummy_input['attention_mask'] = dummy_input['attention_mask'].to(device)
    assert len(dummy_input['input_ids']) == 1

    handler_image = None

    contains_image = False
    if self.trigger_str_img in input_with_dummy_prompt :
      image_input = self.image_processor([pil_image], return_tensors="pt")['pixel_values'].to(device)
      hook_now_image = self.replace_embedding_hook(image_input)
      contains_image = True
      handler_image = self.model_language.model.embed_tokens.register_forward_hook(hook_now_image)



    output_now = self.model_language.generate(**dummy_input,
                                              max_new_tokens = max_new_tokens,
                                              do_sample=do_sample,
                                              temperature=temperature,
                                              top_k=top_k,
                                              top_p=top_p,
                                              )
    output_string = self.tokenizer_language.batch_decode(output_now, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]

    if contains_image :
      handler_image.remove()

    return output_string.split("model\n")[-1]

# Hyperparameter settings

In [None]:
CAPTION_PATH = "/content/Flickr8k.token.txt"
IMAGES_FILE_PATH = "/content/Flicker8k_Dataset"
SAVED_PATH = "/content/saved_model/adaptor_caption.pt"
#SAVED_PATH = "/content/drive/MyDrive/Model/gemma/adaptor_caption2.pt"

BATCH_SIZE = 8
NUM_ITERATION = 2000
SAVE_EVERY = 200
LEARNING_RATE = 1e-4
TRAIN_DATA_NUM = 7500

device = 'cpu'
if torch.cuda.is_available() :
  device = 'cuda'

# Utility Function
This function will help the project training and inferences run smoothly.

In [None]:
def check_model_nan(model):
  num_nan = 0
  for param in model.parameters():
    num_nan += torch.sum(torch.isnan(param))
  return num_nan > 0


def getLabelDictionary(file_path):
  file_now = open(file_path)
  all_string = file_now.read()
  all_string = all_string.split("\n")
  label_dictionary = {}
  for line_now in all_string:
    splitted_line = line_now.split("\t")
    if len(splitted_line) > 1:
      file_name_now = splitted_line[0].split('#')[0]
      number_now = splitted_line[0].split('#')[1]
      label_now = splitted_line[1]

      if file_name_now in label_dictionary.keys() :
        label_dictionary[file_name_now].append((label_now))
      else :
       label_dictionary.update({file_name_now : [label_now]})
  return label_dictionary

def count__model_param(model_now):
  counter = 0
  for param in model_now.parameters():
    counter += torch.sum(torch.ones_like(param))
  return counter

def sample_data_caption(file_list_now, caption_dict_now, n ):
  # base_path = "/content/Flicker8k_Dataset"
  base_path = IMAGES_FILE_PATH
  rand_idx = np.random.randint(0,len(file_list_now),n)

  all_image = []
  all_text = []
  for idx_now in rand_idx :
    file_now = base_path + "/" + file_list_now[idx_now]
    image_now = Image.open(file_now)
    all_image.append(image_now)

    text_list_now = caption_dict_now[file_list_now[idx_now]]
    selected_text_now_idx = np.random.randint(0,len(text_list_now))
    all_text.append(text_list_now[selected_text_now_idx])

  return all_image, all_text

# Model Initialization

In [13]:
label_dictionary = getLabelDictionary(CAPTION_PATH)
all_file = os.listdir(IMAGES_FILE_PATH)

model = MyModel()
model = model.to(device)
model = model.to(torch.bfloat16)

for param in model.parameters() :
  param.requires_grad = True
for param in model.model_language.parameters() :
  param.requires_grad = False
for param in model.model_image.parameters() :
  param.requires_grad = False

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

# Freeze the large language parameter and image encoder parameter

In [14]:
optim = torch.optim.AdamW(model.parameters(),LEARNING_RATE)

In [15]:
model.train()
for itr in range(NUM_ITERATION) :
  rand_image, rand_targets = sample_data_caption(all_file[0:TRAIN_DATA_NUM], label_dictionary, BATCH_SIZE)
  loss = model.forward_loss(rand_image, rand_targets)
  optim.zero_grad()
  loss.backward()
  optim.step()

  print((itr+1),"/", NUM_ITERATION,":", loss)

  if itr % SAVE_EVERY == (SAVE_EVERY - 1) :
    print("MODEL SAVED!")
    torch.save(model.adaptor.state_dict(),SAVED_PATH)

1 / 2000 : tensor(6.1875, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
2 / 2000 : tensor(5.3750, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
3 / 2000 : tensor(4.6562, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
4 / 2000 : tensor(4.4688, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
5 / 2000 : tensor(4.1875, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
6 / 2000 : tensor(4., device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
7 / 2000 : tensor(3.9531, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
8 / 2000 : tensor(3.0156, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
9 / 2000 : tensor(4., device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
10 / 2000 : tensor(3.2969, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
11 / 2000 : tensor(3.5625, device='cuda:0', dtype=torch.bfloat16, grad_fn=<MeanBackward0>)
12 / 2000 : tens

In [17]:
{"role":"user",
"content":"Create a short poem out of the <image>!"}

{'role': 'user', 'content': 'Create a short poem out of the <image>!'}

In [18]:
image_url = "https://i.ytimg.com/vi/FHytoCvj90w/maxresdefault.jpg"
saved_name = "./temp_img." + image_url.split(".")[-1]
urllib.request.urlretrieve(image_url,saved_name)
pil_image_now = Image.open(saved_name)
pil_image_now
output_string = model.generate_aswer_image([{
    "role":"user",
    "content":"<image>Create a simple poem of the image!"
}], pil_image_now, max_new_tokens=256)
print(output_string)

The 'batch_size' attribute of HybridCache is deprecated and will be removed in v4.49. Use the more precisely named 'self.max_batch_size' attribute instead.


Three furry friends , red and black ,
they jump and run , so black .
one brown one white with paws ,
are playing tug of war , on paws .
the dog pulls hard , the dog so strong ,
the third one barks , he wants to be gone .
he runs and jumps and plays all day ,
then runs back home to sleep , in a sunny ray .



The brown , red and white dog play tug of war ,
The little dog , so small , wants to be gone .



The furry friends , they run and play .
with smiles and barks , they go all day .
on sunshiny days , they get in a heap ,
the big dog barks , and makes friends , and takes a leap . 



I hope that this is a poem of the picture .



The dogs are playing together in a nice warm place .
The white dog , he jumps and tries to get away .
It's a beautiful day , sunny and light ,
the dogs play all day .



You know , I love dog dogs , they're my friends so true .
The dogs play all day and they have so much to do .



The white dog is having fun , he loves to play ,


In [45]:
from huggingface_hub import HfApi, create_repo, login
import torch
import os
from huggingface_hub import push_to_hub_keras

# Assuming "Obrempong77" is your Hugging Face username and "GemmaVission" is your model name
model_id = "Obrempong77/GemmaVission"
SAVED_PATH = "/content/saved_model/adaptor_caption.pt"

# Load and save the model
model.adaptor.load_state_dict(torch.load(SAVED_PATH, map_location=torch.device('cpu')))

# Create the directory to save the model
model_dir = f"/content/{model_id.split('/')[-1]}"
if os.path.exists(model_dir):
    os.system(f"rm -rf {model_dir}")
os.makedirs(model_dir, exist_ok=True)

# Save the PyTorch model
torch.save(model.adaptor.state_dict(), f"{model_dir}/pytorch_model.bin")

# Create a README.md file
with open(f"{model_dir}/README.md", "w") as f:
    f.write("# GemmaVission Adaptor Model\n\nThis is the fine-tuned adaptor model for GemmaVission.\n")

# Login to Hugging Face
from google.colab import userdata
HuggingFace_Token = userdata.get('HUGGINGFACE')
login(HuggingFace_Token)

# Create the repository on Hugging Face (if it doesn't exist)
api = HfApi()
api.create_repo(repo_id=model_id, exist_ok=True, token=HuggingFace_Token)

# Push the model to the Hugging Face Hub
from huggingface_hub import upload_folder

upload_folder(
    repo_id=model_id,
    folder_path=model_dir,
    token=HuggingFace_Token,
    commit_message="Upload fine-tuned adaptor model",
)

print("Model successfully pushed to Hugging Face Hub!")


  model.adaptor.load_state_dict(torch.load(SAVED_PATH, map_location=torch.device('cpu')))
- empty or missing yaml metadata in repo card


pytorch_model.bin:   0%|          | 0.00/7.09M [00:00<?, ?B/s]

Model successfully pushed to Hugging Face Hub!
