In [1]:
import google.generativeai as genai
import os
from PIL import Image
import pandas as pd
from tqdm.notebook import tqdm
import time
import numpy as np

# Read the Gemini API key from the environment; never commit keys to a notebook.
# (Raises KeyError with the variable name if GOOGLE_API_KEY is not set.)
genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

In [None]:
# Waterbirds metadata table. Later cells read its 'split', 'img_filename',
# 'y' and 'place' columns.
df = pd.read_csv(filepath_or_buffer="../datasets/waterbird/metadata.csv")
df.head(n=5)

In [3]:
# Integer label encodings. These presumably match the 'y' (bird type) and 'place'
# (background) columns of metadata.csv; confirm against the dataset docs.
LANDBIRD, WATERBIRD = 0, 1
LAND, WATER = 0, 1

In [None]:
# Split ids in metadata.csv: 0 -> train, 1 -> val, 2 -> test.
# reset_index() returns a new 0..n-1 indexed frame and keeps the original row
# index in an 'index' column. This avoids mutating a filtered slice in place,
# which pandas may flag with SettingWithCopyWarning, and keeps the cell
# idempotent.
train_df = df[df['split'] == 0].reset_index()
val_df = df[df['split'] == 1].reset_index()
test_df = df[df['split'] == 2].reset_index()

len(train_df), len(test_df), len(val_df)

In [5]:
# Row window [start, end) of the test split to run inference on.
start, end = 0, 400
# Re-derive the window from the full metadata rather than from test_df, so
# re-running this cell (or changing the window) never slices an already-sliced
# frame. This reproduces the test split built earlier: split == 2, then
# reset_index() with the 'index' column kept.
test_df = df[df['split'] == 2].reset_index().iloc[start:end]

In [6]:
# System prompt asking for one bird type per image, comma-separated. The
# inference loop pairs the i-th answer with the i-th image.
# NOTE: editing this text changes model behaviour; results are only comparable
# across runs that used the same prompt.
SYSTEM_INSTRUCTION = "Are these waterbirds or landbirds? Just give the bird type for each image, separated by commas."
model = genai.GenerativeModel(model_name="gemini-1.5-pro",
                              system_instruction=SYSTEM_INSTRUCTION)

In [None]:
# Classify test_df images with Gemini, BATCH_SIZE images per request. Each
# record keeps the ground-truth label ('y') and background ('place') alongside
# the model's answer for that image.
BATCH_SIZE = 5       # images sent per generate_content call
REQUEST_DELAY = 10   # seconds to wait between requests (API rate limiting)


def _run_batch(batch):
    """Send one batch of images to `model` and return one record per image.

    The model replies with comma-separated answers, and the i-th answer is
    paired with the i-th image. If the number of answers differs from the
    number of images, the pairing is ambiguous. In that case every image in
    the batch gets ``predicted=None`` rather than a possibly misaligned label.

    Args:
        batch: list of dicts with 'img_filename', 'img_label' and 'place'.

    Returns:
        List of dicts with the same keys plus 'predicted', one per batch entry.
    """
    images = []
    try:
        for data in batch:
            images.append(Image.open(os.path.join("../datasets/waterbird", data['img_filename'])))
        response = model.generate_content(images)
    finally:
        # Image.open keeps the underlying file open; release the handles once
        # the request has been made (or has failed).
        for img in images:
            img.close()

    response = response.text.strip()
    print(response)

    predictions = [prediction.strip() for prediction in response.split(",")]
    if len(predictions) != len(batch):
        print(f"WARNING: expected {len(batch)} predictions, got {len(predictions)}; marking batch as unparsed")
        predictions = [None] * len(batch)

    return [
        {
            'img_filename': entry['img_filename'],
            'img_label': entry['img_label'],
            'place': entry['place'],
            'predicted': prediction,
        }
        for prediction, entry in zip(predictions, batch)
    ]


inference_data = []
batch = []

for _, row in tqdm(test_df.iterrows(), total=len(test_df)):
    batch.append({
        'img_filename': row['img_filename'],
        'img_label': int(row['y']),
        'place': int(row['place']),
    })

    # Batch by size, not by row index, so batching works for any `start`.
    if len(batch) == BATCH_SIZE:
        inference_data.extend(_run_batch(batch))
        batch = []
        time.sleep(REQUEST_DELAY)

# Flush the final partial batch; len(test_df) need not be a multiple of BATCH_SIZE.
if batch:
    inference_data.extend(_run_batch(batch))
    batch = []

In [None]:
# Number of per-image records produced by the inference loop; compare with len(test_df).
num_records = len(inference_data)
num_records

In [20]:
# Persist the per-image records. The file name carries the requested inclusive
# row window [start, end - 1] of the test split.
# NOTE: inference_data is a list of dicts, so np.save stores it as an object
# (pickled) array. Reload with np.load(path, allow_pickle=True).
output_path = f'./inference_data_{start}_{end - 1}.npy'
np.save(output_path, inference_data)

In [16]:
import torch
from transformers.generation.utils import GenerationMixin
from transformers.models.auto.modeling_auto import AutoModel
from transformers.models.paligemma.modeling_paligemma import PaliGemmaPreTrainedModel, PaliGemmaMultiModalProjector, PaliGemmaCausalLMOutputWithPast
from transformers import PaliGemmaConfig, SiglipVisionConfig, GemmaConfig
from transformers.cache_utils import Cache
from transformers.modeling_utils import PreTrainedModel
from typing import Optional, Union, List, Tuple

In [None]:
from transformers import PaliGemmaForConditionalGeneration

In [None]:
# Scratch cell intentionally left empty: a bare, undefined name here would raise
# NameError on Restart & Run All.

In [2]:
# Assemble the PaliGemma config from default-constructed vision and text sub-configs.
# NOTE(review): forward() requires the projected image-feature width to equal
# the text-embedding width, and the text vocabulary to contain
# config.image_token_index. Default SiglipVisionConfig()/GemmaConfig() may not
# match PaliGemmaConfig's own projection_dim/hidden_size/image_token_index.
# Confirm before running forward().
vision_config, text_config = SiglipVisionConfig(), GemmaConfig()
configuration = PaliGemmaConfig(vision_config=vision_config, text_config=text_config)

In [14]:
class CustomPaliGemmaForConditionalGeneration(PaliGemmaPreTrainedModel, GenerationMixin):
    """PaliGemma-style model that currently builds only the vision side.

    ``__init__`` creates a vision tower and a multimodal projector. No language
    model or token-embedding module is built here. ``forward`` embeds text,
    encodes images and splices the image features into the text embeddings, but
    does not (yet) run a language model or produce logits.
    """

    def __init__(self, config: PaliGemmaConfig):
        """Build the vision tower and projector from ``config``.

        Args:
            config: PaliGemma configuration. ``config.vision_config`` defines the
                vision tower, and ``config.text_config.vocab_size`` is cached as
                ``self.vocab_size``.
        """
        super().__init__(config)
        # Alternative (disabled): reuse the vision tower of an existing `model` object.
        # self.vision_tower = model.vision_tower
        # Built from config only; no checkpoint is loaded on this line.
        self.vision_tower = AutoModel.from_config(config=config.vision_config)
        self.multi_modal_projector = PaliGemmaMultiModalProjector(config)
        self.vocab_size = config.text_config.vocab_size
        self._attn_implementation = config._attn_implementation

        # Use -1 as a sentinel when the config defines no pad token.
        self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
        # PreTrainedModel hook, invoked once all submodules exist.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[List[torch.FloatTensor], Cache]] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        num_logits_to_keep: int = 0,
    ) -> Union[Tuple, PaliGemmaCausalLMOutputWithPast]:
        """Embed text tokens and splice projected image features into them.

        Args:
            input_ids: token ids. Exactly one of ``input_ids`` / ``inputs_embeds``
                must be given. When ``pixel_values`` is given, positions equal to
                ``config.image_token_index`` receive image features.
            pixel_values: images for ``self.vision_tower``. Cannot be combined
                with ``inputs_embeds``, because image slots are located via
                ``input_ids``.
            past_key_values: only used to offset the default ``cache_position``
                via ``get_seq_length()``.
            cache_position: positions of the current tokens. Defaults to a range
                that starts after any cached tokens.
            position_ids: defaults to ``cache_position + 1`` (1-indexed).
            token_type_ids, labels: only combined into ``is_training``, which
                is unused.
            attention_mask, use_cache, num_logits_to_keep: not read in this body.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds``
                are given, if ``pixel_values`` is combined with ``inputs_embeds``,
                or if the number of image-token slots does not match the image
                features.

        NOTE(review): the body ends after merging image features and has no
        ``return``, so it returns ``None`` despite the annotated return type.
        ``position_ids``, ``is_training``, ``output_attentions``,
        ``output_hidden_states`` and ``return_dict`` are computed but never used.
        TODO: add the language-model call, logits and loss.
        """

        # Exactly one of input_ids / inputs_embeds must be provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if pixel_values is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both pixel_values and inputs_embeds at the same time, and must specify either one"
            )

        # Fall back to config defaults. NOTE(review): these three values are not used below.
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # NOTE(review): is_training is not used below.
        is_training = token_type_ids is not None and labels is not None

        # NOTE(review): __init__ creates no embedding layer or language model, so
        # get_input_embeddings() presumably falls back to the PreTrainedModel base
        # implementation and may raise. Confirm before calling with input_ids.
        if inputs_embeds is None:
            inputs_embeds = self.get_input_embeddings()(input_ids)

        # Default cache positions continue after any tokens already in the KV cache.
        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length(
            ) if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(
                0) + 1  # Paligemma positions are 1-indexed

        # Merge text and images: encode the images, project the features, then
        # overwrite the image-token slots of the text embeddings with them.
        if pixel_values is not None:
            image_outputs = self.vision_tower(
                pixel_values.to(inputs_embeds.dtype))
            selected_image_feature = image_outputs.last_hidden_state
            image_features = self.multi_modal_projector(selected_image_feature)
            # Scale projected features by 1 / sqrt(config.hidden_size).
            image_features = image_features / (self.config.hidden_size**0.5)

            # Boolean mask of image-token slots, broadcast across the embedding dimension.
            special_image_mask = (
                input_ids == self.config.image_token_index).unsqueeze(-1)
            special_image_mask = special_image_mask.expand_as(
                inputs_embeds).to(inputs_embeds.device)
            # Every image-token slot must receive exactly one image-feature element.
            if inputs_embeds[special_image_mask].numel() != image_features.numel():
                image_tokens_in_text = torch.sum(
                    input_ids == self.config.image_token_index)
                raise ValueError(
                    f"Number of images does not match number of special image tokens in the input text. "
                    f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} "
                    "tokens from image embeddings."
                )
            image_features = image_features.to(
                inputs_embeds.device, inputs_embeds.dtype)
            # Fill the masked slots, in order, with the flattened image features.
            inputs_embeds = inputs_embeds.masked_scatter(
                special_image_mask, image_features)

In [15]:
# Instantiate the custom vision-side model from the PaliGemma config assembled above.
custom_model = CustomPaliGemmaForConditionalGeneration(config=configuration)