In [3]:
import pandas as pd
import faiss
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import pickle
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
import random
import json
from accelerate import Accelerator

# NOTE(review): presumably disables the tokenizers library's own parallelism so it
# does not clash with the DataLoader worker processes used below — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def seed_everything(seed=42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and every CUDA device),
    records the seed in ``PYTHONHASHSEED`` and puts cuDNN in deterministic mode.

    Args:
        seed (int): Seed value shared by all generators.
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all GPUs, not only the current one.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode lets cuDNN pick algorithms by timing them, which can vary
    # between runs; it has to be off for reproducible results.
    torch.backends.cudnn.benchmark = False


seed_everything(42)


### Recall images based on the option

In [4]:
class ImageEmbedder:
    """Pairs an image-to-vector model with the routine that prepares its input."""

    def __init__(self, model, preprocessor):
        """Keep references to both components.

        Args:
            model: Projects a prepared image to a vector.
            preprocessor: Loads and prepares an image for ``model``; exposed as
                the ``processor`` attribute.
        """
        self.model, self.processor = model, preprocessor

def BLIP_BASELINE():
    """Build the BLIP retrieval baseline.

    Loads the BLIP ITM model from ``./BLIP/chatir_weights.ckpt``, moves it to
    CUDA when available (CPU otherwise) and switches it to eval mode.

    Returns:
        tuple: ``(dialog_encoder, image_embedder)``. ``dialog_encoder`` maps text
        to L2-normalised projected text features; ``image_embedder`` is an
        ``ImageEmbedder`` holding the image projection and the image loader.
    """
    import sys
    from torchvision import transforms
    from torchvision.transforms.functional import InterpolationMode

    # The local BLIP checkout has to be on sys.path before it is imported.
    sys.path.insert(0, './BLIP')
    from BLIP.models.blip_itm import blip_itm

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Weights: download from Google Drive, see README.md
    model = blip_itm(
        pretrained='./BLIP/chatir_weights.ckpt',
        med_config='BLIP/configs/med_config.json',
        image_size=224,
        vit='base',
    )
    model = model.to(device).eval()

    # Image side: raw image --> image feature
    normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711),
    )
    transform_test = transforms.Compose([
        transforms.Resize((224, 224), interpolation=InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        normalize,
    ])

    def blip_project_img(image):
        # NOTE(review): `image` is passed to the encoder as-is; the caller is
        # responsible for batching it and placing it on the model's device.
        first_token = model.visual_encoder(image)[:, 0, :]
        return F.normalize(model.vision_proj(first_token), dim=-1)

    def blip_prep_image(path):
        return transform_test(Image.open(path).convert('RGB'))

    image_embedder = ImageEmbedder(blip_project_img, blip_prep_image)

    # Text side: dialog --> feature in the same space as the image features
    def dialog_encoder(dialog):
        tokens = model.tokenizer(
            dialog,
            padding='longest',
            truncation=True,
            max_length=200,
            return_tensors="pt",
        ).to(device)

        encoded = model.text_encoder(
            tokens.input_ids,
            attention_mask=tokens.attention_mask,
            return_dict=True,
            mode='text',
        )

        projected = model.text_proj(encoded.last_hidden_state[:, 0, :])
        return F.normalize(projected, dim=-1)

    return dialog_encoder, image_embedder

class TextDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves raw text strings one at a time."""

    def __init__(self, texts: list):
        """
        Args:
            texts (list of str): List of text strings.
        """
        self.texts = texts

    def __len__(self):
        """Number of stored texts."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Return the ``idx``-th text in a dict under the ``'text'`` key."""
        return {'text': self.texts[idx]}

def encode_text(dataset, model, batch_size=100, num_workers=4):
    """Encode every text of ``dataset`` with ``model`` and gather the features on CPU.

    Args:
        dataset: Map-style dataset whose items are dicts with a ``'text'`` entry.
        model: Callable that receives one batch (a list of strings) and returns
            a tensor of features with the batch on dim 0.
        batch_size (int): Number of texts handed to ``model`` per call.
        num_workers (int): DataLoader worker processes; ``0`` loads the data in
            the calling process.

    Returns:
        torch.Tensor: All batch features concatenated along dim 0, on CPU, in
        dataset order.

    Note:
        ``torch.cat`` raises when ``dataset`` is empty because no batch is produced.
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers,
        shuffle=False,
    )
    if num_workers > 0:
        # DataLoader only accepts prefetch_factor together with worker processes.
        loader_kwargs['prefetch_factor'] = 2
    data_loader = torch.utils.data.DataLoader(dataset, **loader_kwargs)

    with torch.no_grad():
        all_features = [model(batch['text']).cpu() for batch in data_loader]
    return torch.cat(all_features)

def retrieve_topk_images(query: list,
                         topk=10,
                         faiss_model=None,
                         text_encoder=None,
                         id2image=None,
                         ):
    """Find the ``topk`` nearest indexed images for each query text.

    Args:
        query: List of query strings.
        topk: Number of neighbours to return per query.
        faiss_model: Index object exposing ``search(vectors, k)``.
        text_encoder: Callable passed to ``encode_text`` to embed the queries.
        id2image: Mapping from index id to image path.

    Returns:
        tuple: ``(image_paths, indices)``. ``image_paths`` is a list (one entry
        per query) of lists of paths, with ``'path/not/found'`` for ids missing
        from ``id2image``; ``indices`` is the array of ids returned by the search.
    """
    query_vec = encode_text(TextDataset(query), text_encoder).numpy()
    # Rescale every row to unit length, in place, before searching.
    query_vec /= np.linalg.norm(query_vec, axis=1, keepdims=True)

    _, neighbours = faiss_model.search(query_vec, topk)
    indices = np.array(neighbours)

    image_paths = []
    for row in indices:
        image_paths.append([id2image.get(idx, 'path/not/found') for idx in row])
    return image_paths, indices


In [None]:
# Load the retrieval artefacts from ./checkpoints: the FAISS index, the
# id -> image path mapping and the image embedding matrix.
faiss_model = faiss.read_index('./checkpoints/blip_faiss.index')
# NOTE(review): pickle.load can execute arbitrary code stored in the file —
# only load checkpoint files from a trusted source.
with open('./checkpoints/id2image.pickle', 'rb') as f:
    id2image = pickle.load(f)
    
with open('./checkpoints/blip_image_embedding.pickle', 'rb') as f:
    image_vector = pickle.load(f)
# Text encoder and image embedder built on the BLIP model.
dialog_encoder, image_embedder = BLIP_BASELINE()

### Generate Question

In [6]:

import matplotlib.pyplot as plt 
import numpy as np
def image_show(images):
    """Display up to 16 images on a 4x4 grid.

    Args:
        images: Iterable of objects that ``np.array`` can convert to an image
            array (e.g. PIL images). Only the first 16 are drawn.
    """
    fig, axs = plt.subplots(4, 4, figsize=(15, 15))
    cells = axs.flatten()

    for ax, img in zip(cells, images):
        ax.imshow(np.array(img))

    # Hide the axes of every cell, including those left without an image, so a
    # short list does not leave empty frames with ticks in the figure.
    for ax in cells:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [7]:
from peft import get_peft_model, LoraConfig, TaskType, PeftModel
from safetensors import safe_open
from pllava import PllavaProcessor, PllavaForConditionalGeneration, PllavaConfig
from accelerate import init_empty_weights, dispatch_model, infer_auto_device_map, load_checkpoint_in_model
from accelerate.utils import get_balanced_memory

def load_pllava(repo_id, num_frames, use_lora=False, weight_dir=None, lora_alpha=32, use_multi_gpus=False, pooling_shape=(16,12,12)):
    """Load a PLLaVA model and its processor for inference.

    Args:
        repo_id: Path or hub id of the base model (also tried first for the processor).
        num_frames: Number of frames/images per sample written into the config.
            When it is ``0`` the pooling shape is forced to ``(0, 12, 12)``.
        use_lora: Wrap the language model with LoRA adapters (rank 128 on
            ``q_proj``/``v_proj``) and read the config from ``weight_dir``.
        weight_dir: Folder with extra ``.safetensors`` weights to load on top of
            the base model. Required when ``use_lora`` is true.
        lora_alpha: LoRA alpha; the effective scaling is ``lora_alpha / 128``.
        use_multi_gpus: Spread the model over the available devices with accelerate.
        pooling_shape: Pooling shape forwarded to ``PllavaConfig``.

    Returns:
        tuple: ``(model, processor)`` with the model in eval mode and the
        processor/tokenizer set to left padding.

    Raises:
        ValueError: If ``use_lora`` is true but ``weight_dir`` is ``None``.
    """
    if use_lora and weight_dir is None:
        # Fail early with the intended message instead of handing None to from_pretrained.
        raise ValueError("pass a folder to your lora weight")

    if num_frames == 0:
        # Override the argument instead of passing pooling_shape a second time
        # through **kwargs, which raised "got multiple values for keyword argument".
        pooling_shape = (0, 12, 12)
    config = PllavaConfig.from_pretrained(
        repo_id if not use_lora else weight_dir,
        pooling_shape=pooling_shape,
        num_frames=num_frames,
    )

    with torch.no_grad():
        model = PllavaForConditionalGeneration.from_pretrained(repo_id, config=config, torch_dtype=torch.bfloat16)

    try:
        processor = PllavaProcessor.from_pretrained(repo_id)
    except Exception:
        # Best-effort fallback when repo_id ships no processor files.
        processor = PllavaProcessor.from_pretrained('llava-hf/llava-1.5-7b-hf')

    processor.padding_side = 'left'
    processor.tokenizer.padding_side = 'left'

    if use_lora:
        print("Use lora")
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM, inference_mode=False, target_modules=["q_proj", "v_proj"],
            r=128, lora_alpha=lora_alpha, lora_dropout=0.
        )
        print("Lora Scaling:", lora_alpha/128)
        model.language_model = get_peft_model(model.language_model, peft_config)
        print("Finish use lora")

    if weight_dir is not None:
        state_dict = {}
        save_fnames = os.listdir(weight_dir)
        shard_fnames = [fn for fn in save_fnames if fn.startswith('model-0')]
        # The single-file checkpoint is used only when it exists and no
        # 'model-0*' shard files sit next to it.
        use_full = bool(shard_fnames) or "model.safetensors" not in save_fnames

        if not use_full:
            print("Loading weight from", weight_dir, "model.safetensors")
            checkpoint_fnames = ["model.safetensors"]
        else:
            print("Loading weight from", weight_dir)
            checkpoint_fnames = shard_fnames

        for fn in checkpoint_fnames:
            with safe_open(f"{weight_dir}/{fn}", framework="pt", device="cpu") as f:
                for k in f.keys():
                    state_dict[k] = f.get_tensor(k)

        # strict=False: the checkpoint may hold only part of the parameters;
        # the printed message lists missing/unexpected keys.
        if 'model' in state_dict.keys():
            msg = model.load_state_dict(state_dict['model'], strict=False)
        else:
            msg = model.load_state_dict(state_dict, strict=False)
        print(msg)

    if use_multi_gpus:
        max_memory = get_balanced_memory(
            model,
            max_memory=None,
            no_split_module_classes=["LlamaDecoderLayer"],
            dtype='bfloat16',
            low_zero=False,
        )

        device_map = infer_auto_device_map(
            model,
            max_memory=max_memory,
            no_split_module_classes=["LlamaDecoderLayer"],
            dtype='bfloat16'
        )

        dispatch_model(model, device_map=device_map)
        print(model.hf_device_map)

    model = model.eval()

    return model, processor



In [None]:
# Load the PLLaVA generator with LoRA weights; the config and the extra weights
# are both read from the same folder as the base model here.
model, processor = load_pllava(repo_id='MODELS/pllava-7b', 
            num_frames=5, # num_images = 5
            use_lora=True, 
            weight_dir='MODELS/pllava-7b', 
            lora_alpha=4, 
            use_multi_gpus=False, 
            pooling_shape=(5,12,12),
            )


In [9]:
def load_image(image_file):
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file (str): File path, or a URL starting with ``http://`` or ``https://``.

    Returns:
        PIL.Image.Image: The image converted to RGB.

    Raises:
        urllib.error.URLError: If the URL cannot be fetched (HTTP errors included).
    """
    # Imported here because the visible notebook never imports them at the top
    # (the original referenced `requests` and `io` without importing either).
    import io
    from urllib.request import urlopen

    # Match the URL scheme exactly so a local file named e.g. "http_cat.png"
    # is still read from disk.
    if image_file.startswith(("http://", "https://")):
        with urlopen(image_file, timeout=30) as response:
            payload = response.read()
        image = Image.open(io.BytesIO(payload)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image

In [10]:
from collections import defaultdict
def downsample_retrieved_images(images_path:list, strategy='clustering', image_vector=None, retrieve_image_indices:list=None):
    """Reduce the retrieved image paths of one query to a small sample.

    Args:
        images_path: Retrieved image paths for one query, best match first.
        strategy: One of

            * ``'clustering'`` — k-means with 5 clusters over the embeddings of
              the retrieved images, then one random path per non-empty cluster.
            * ``'interval'`` — one random path from each of the five consecutive
              windows of 10 within the first 50 paths (needs at least 41 paths,
              otherwise ``random.choice`` fails on an empty window).
            * ``'topk_cos_similiarity'`` — ``images_path`` is returned unchanged.
        image_vector: Embedding matrix indexed by image id; only used by
            ``'clustering'``.
        retrieve_image_indices: Ids of the retrieved images, aligned with
            ``images_path``; only used by ``'clustering'``.

    Returns:
        list: The sampled image paths.

    Raises:
        ValueError: If ``strategy`` is not one of the names above.
    """
    if strategy == 'clustering':
        topk_embedding = image_vector[retrieve_image_indices]

        kmeans = faiss.Kmeans(d=topk_embedding.shape[1], k=5, niter=10, verbose=False)
        kmeans.train(topk_embedding)
        # Nearest centroid of every retrieved embedding = its cluster label.
        _, clustering_label = kmeans.index.search(topk_embedding, 1)
        clustering_label = clustering_label.flatten()

        label_to_paths = defaultdict(list)
        for path, label in zip(images_path, clustering_label):
            label_to_paths[label].append(path)

        sampled_paths_list = [random.choice(paths) for paths in label_to_paths.values()]
    elif strategy == 'interval':
        sampled_paths_list = []
        for i in range(0, 50, 10):
            sampled_paths_list.append(random.choice(images_path[i:i+10]))
    elif strategy == 'topk_cos_similiarity':
        sampled_paths_list = images_path
    else:
        # Previously fell through to an UnboundLocalError on the return below.
        raise ValueError(f"Unknown downsampling strategy: {strategy!r}")
    return sampled_paths_list

In [None]:
import re 
import json
def extract_json(text):
    """Return the widest ``{...}`` span contained in ``text``.

    The match is greedy and crosses newlines, i.e. it runs from the first ``{``
    to the last ``}``.

    Returns:
        str: The matched span, or the sentinel ``'parse incorrectly'`` when
        ``text`` holds no such span.
    """
    match = re.search(r'{.*}', text, re.DOTALL)
    return match.group() if match else 'parse incorrectly'


def extract_llm_ouptut(text):
    """Pull the generated question out of a model response.

    Args:
        text (str): Raw model output expected to contain a JSON object with the
            key ``'Question to differentiate the pictures'``.

    Returns:
        The value stored under that key, or ``''`` when no JSON object can be
        parsed or the key is missing.
    """
    ext_json = extract_json(text)
    try:
        return json.loads(ext_json)['Question to differentiate the pictures']
    except (json.JSONDecodeError, KeyError):
        # Only parse failures and a missing key count as "no question"; anything
        # else (e.g. KeyboardInterrupt) must not be swallowed.
        return ''
# Manual check 1: the text has no braces, so extract_json finds no match and
# extract_llm_ouptut returns ''. The result is not shown because this call is
# not the last expression of the cell.
text = """
"What is the common object that appears in all",
"What is he distinguishing feature that can help differentiate the picture":"Color",
"Question to differentiate the pict":"What is the color of the tennis ball in each picture?"
"""
extract_llm_ouptut(text)
# Manual check 2: a well-formed JSON object that contains the expected key, so
# the question string is returned (and displayed as the cell output).
text = """{
                    "What is the common object that appears in all five pictures": "a large cake",
                    "What is the distinguishing feature that can help differentiate the picture": "the color of the cake",
                    "Question to differentiate the pictures": "What color is the cake in each of the pictures?"
                   }"""
extract_llm_ouptut(text)

### Batch inference

In [10]:

class MultiModalDataset(torch.utils.data.Dataset):
    """Dataset producing one encoded (prompt, image group) sample per option."""

    def __init__(self, option, images_path, processor):
        """
        Args:
            option: Sequence of option descriptions, one per sample.
            images_path: Sequence aligned with ``option``; each entry holds the
                image paths shown to the model for that sample.
            processor: A processor for the model.
        """
        self.option, self.images_path, self.processor = option, images_path, processor

    def __len__(self):
        """Number of samples (one per option)."""
        return len(self.option)

    def __getitem__(self, idx):
        """Open the sample's images, build its prompt and run the processor on both."""
        sample_paths = self.images_path[idx]
        option = self.option[idx]

        frames = []
        for image_path in sample_paths:
            frames.append(Image.open(image_path).convert("RGB"))

        # The template is reproduced verbatim; `{option}` is the only placeholder.
        prompt = f"""You are Pllava, a large vision-language assistant. 
    # You are able to understand the video content that the user provides, and assist the user with a variety of tasks using natural language.
    # Follow the instructions carefully and explain your answers in detail based on the provided video.
    #  USER:<image>  USER: You need to find a common object that appears in all 5 pictures but has distinguishing features. Based on this object, ask a question to differentiate the pictures.

                Remember, you must ensure the question is specific, not abstract, and the answer should be directly obtainable by looking at the images.
                
                For example:
                Example 1: All 5 pictures have people, but the number of people differs. You can ask about the number of people.
                Example 2: All 5 pictures have cats, but the colors are different. You can ask about the color.
                Example 3: All 5 pictures have traffic lights, but their positions differ. You can ask about the position of the traffic lights.
                
                Ask a specific question based on the object that will help distinguish the pictures. The question is not overlapped with the description: {option}.
                Don't ask 2 questions each time. such as what is the attribute of a or b

                Output as the following format
                {{
                "What is the common object that appears in all five pictures":"",
                "What is he distinguishing feature that can help differentiate the picture":"",
                "Question to differentiate the pictures":""
                }}
                ""
                ASSISTANT:
                """

        encoded = self.processor(prompt, frames, return_tensors="pt")

        sample = {}
        sample['input_ids'] = encoded['input_ids'].squeeze(0)  # shape: [seq_len]
        sample['attention_mask'] = encoded['attention_mask'].squeeze(0)  # shape: [seq_len]
        sample['pixel_values'] = encoded['pixel_values']  # shape: [num_images, 3, 224, 224]
        return sample

def collate_fn(batch, pad_token_id=None):
    """Merge per-sample encodings into one left-padded batch.

    Padding goes on the left so that it matches the ``padding_side='left'``
    configured on the processor for generation; with right padding the pad
    tokens would sit between the prompt and the generated continuation.

    Args:
        batch: List of dicts with 1-D ``'input_ids'`` and ``'attention_mask'``
            tensors and a ``'pixel_values'`` tensor.
        pad_token_id: Fill value for ``input_ids``. Defaults to the pad token id
            of the notebook-level ``processor``.

    Returns:
        dict: ``'input_ids'`` and ``'attention_mask'`` of shape
        ``[batch_size, seq_len]`` and ``'pixel_values'`` concatenated along dim 0.
    """
    if pad_token_id is None:
        pad_token_id = processor.tokenizer.pad_token_id

    def left_pad(sequences, value):
        # pad_sequence pads on the right: pad the reversed sequences, then flip back.
        reversed_seqs = [seq.flip(0) for seq in sequences]
        padded = torch.nn.utils.rnn.pad_sequence(reversed_seqs, batch_first=True, padding_value=value)
        return padded.flip(1)

    input_ids = left_pad([item['input_ids'] for item in batch], pad_token_id)
    attention_mask = left_pad([item['attention_mask'] for item in batch], 0)
    pixel_values = torch.cat([item['pixel_values'] for item in batch], dim=0)  # flatten images

    return {'input_ids': input_ids, # shape: [batch_size, seq_len]
            'attention_mask': attention_mask, # shape: [batch_size, seq_len]
            'pixel_values': pixel_values, # shape: [batch_size * num_images, 3, 224, 224]
            }


In [13]:
# Read the options to query with; only the 'option' column is used in this cell.
data = pd.read_csv('recall_res.csv')
option = data['option'].tolist()
# Retrieve the 100 nearest images for every option in one batched call.
retrieved_images_path, retrieved_indices = retrieve_topk_images(option,
                                             topk=100,
                                             faiss_model=faiss_model,
                                             text_encoder=dialog_encoder,
                                             id2image=id2image,
                             )

# Downsample the retrieved images of each option to a handful of paths.
# With strategy='interval' only the first 50 of the 100 retrieved paths are
# sampled from, and image_vector / retrieve_image_indices are not used.
downsampled_images_paths = []
for img_path, idx in zip(retrieved_images_path, retrieved_indices):
    downsampled_images_path = downsample_retrieved_images(img_path, 
                                                       strategy='interval', 
                                                       image_vector=image_vector, 
                                                       retrieve_image_indices=idx
                                                      )
    downsampled_images_paths.append(downsampled_images_path)
    

In [14]:
# Attach the sampled image paths (one list per option) to the options table.
data['interval_sample_images'] = downsampled_images_paths

In [None]:
# Repeat every row five times, keeping the copies adjacent — presumably so that
# five responses are generated per option (confirm).
data_repeat_5times = data.loc[data.index.repeat(5)].reset_index(drop=True)
data_repeat_5times

In [20]:
# to_csv does not create missing parent folders, so make sure the output folder exists.
os.makedirs('./experiment_res', exist_ok=True)
# NOTE: the 'interval_sample_images' lists are written as their string form.
data_repeat_5times.to_csv('./experiment_res/data_repeat_5times.csv',index=False)

In [17]:
# Parallel lists fed to MultiModalDataset: option text and its sampled image paths.
# `option` is rebound here to the repeated (5x) list.
option = data_repeat_5times['option'].tolist()
images_path = data_repeat_5times['interval_sample_images'].tolist()

In [None]:
# Accelerator supplies the target device and wraps model/dataloader in the inference cell.
accelerator = Accelerator()

In [None]:
dataset = MultiModalDataset(option, images_path, processor)
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size=1,
                                         shuffle=False,
                                         pin_memory=True,
                                         num_workers=4,
                                         collate_fn=collate_fn,
                                         prefetch_factor=2
                                        )
model, dataloader = accelerator.prepare(model, dataloader)

# NOTE(review): the file name says "13b" while the model loaded in this notebook
# is 'MODELS/pllava-7b' — confirm the intended name.
output_path = './experiment_res/one_option_five_images_five_response_13b.jsonl'
os.makedirs(os.path.dirname(output_path), exist_ok=True)

output_res = []
for batch in tqdm(dataloader):
    batch = {k: v.to(accelerator.device) for k, v in batch.items()}

    # Decode the prompt from the model inputs so each saved record carries the
    # text that the post-processing cell parses the option from
    # (it reads record['option'] through extract_option).
    prompt_texts = processor.batch_decode(batch['input_ids'], skip_special_tokens=True, clean_up_tokenization_spaces=False)

    with torch.no_grad():
        output_token = model.generate(**batch, media_type='video',
                                      do_sample=True,
                                      max_new_tokens=1000,
                                      num_beams=1,
                                      min_length=1,
                                      top_p=0.9,
                                      repetition_penalty=1,
                                      length_penalty=1,
                                      temperature=0.9,
                                      )
    output_text = processor.batch_decode(output_token, skip_special_tokens=True, clean_up_tokenization_spaces=False)

    for prompt_text, out in zip(prompt_texts, output_text):
        # Keep only what follows the last 'ASSISTANT:' marker, i.e. the model's answer.
        llm_out = out.split('ASSISTANT:')[-1].strip()
        print(llm_out)
        processed_text = extract_llm_ouptut(str(llm_out))
        output_res.append(processed_text)
        print(processed_text)

        save_dict = {
            'option': prompt_text,
            'original_output': llm_out,
            'questions': processed_text
        }
        # Append one record at a time so finished results survive an interrupted run.
        with open(output_path, 'a') as f:
            f.write(json.dumps(save_dict) + '\n')

In [76]:

import pandas as pd
# Show every column and the full text of each cell when DataFrames are displayed below.
pd.set_option('display.max_columns',None)
pd.set_option('display.max_colwidth',None)

import json
def load_jsonl(filename):
    """Read a JSON Lines file into a list of parsed records.

    Args:
        filename: Path of a file holding one JSON document per line.

    Returns:
        list: The parsed records in file order. Blank lines are skipped.
    """
    with open(filename, "r", encoding="utf-8") as f:
        # Iterate the file lazily; blank lines (e.g. a trailing newline) are
        # skipped because json.loads would fail on them.
        return [json.loads(line) for line in f if line.strip()]

import re
def extract_option(text):
    """Recover the option description embedded in a prompt text.

    Looks for the first line containing
    ``'The question is not overlapped with the description: '`` and returns what
    follows on that line, stripped of surrounding whitespace and with its final
    character removed (the prompt template closes that sentence with a period).

    Returns:
        str: The option text, or ``''`` when the marker is not found.
    """
    hit = re.search('The question is not overlapped with the description: ([^\n]+)', text)
    if hit is None:
        return ''
    return hit.group(1).strip()[:-1]

In [77]:
# Load the records appended by the inference loop (one JSON object per line).
final_res = load_jsonl('./experiment_res/one_option_five_images_five_response_13b.jsonl')
# final_res[0]

In [78]:
# Build the response table: option text, parsed question, raw model output.
# NOTE(review): every record must carry an 'option' field holding the prompt
# text; records saved with only 'original_output' and 'questions' raise
# KeyError here.
options = [extract_option(record['option']) for record in final_res]
questions = [record['questions'] for record in final_res]
original_outputs = [record['original_output'] for record in final_res]
response = pd.DataFrame(
    list(zip(options, questions, original_outputs)),
    columns=['option', 'one_tune_questions', 'original_output'],
)


In [None]:
# Show the records whose raw model output is empty.
response[response['original_output']=='']

In [80]:
# Reload the repeated table and collapse it back to one row per option
# (the first occurrence of each option is kept).
repeat_data = pd.read_csv('./experiment_res/data_repeat_5times.csv',)
repeat_data.drop_duplicates(subset=['option'],inplace=True)

In [81]:
# Keep only the join key and the target image; 'target_image' must already be a
# column of the saved CSV.
repeat_data = repeat_data[['option','target_image']]


In [98]:
# Left join: every response row is kept; 'target_image' is NaN where the
# extracted option text matches no row of repeat_data.
merged_data= pd.merge(response, repeat_data, how='left', on=['option'])



In [105]:
# Keep rows with a parsed question, one row per (option, question) pair.
merged_data = merged_data[merged_data['one_tune_questions']!='']
merged_data = merged_data.drop_duplicates(subset=['option','one_tune_questions'])
merged_data.reset_index(drop=True, inplace=True)
# reset_index(drop=True) leaves no 'index' column behind, so dropping it
# unconditionally raised KeyError; tolerate its absence (this also makes the
# cell safe to re-run).
merged_data.drop(columns=['index'], inplace=True, errors='ignore')

In [None]:
# Display the cleaned table.
merged_data

In [84]:
# to_csv does not create missing parent folders, so make sure the output folder exists.
os.makedirs('./rank_res', exist_ok=True)
merged_data.to_csv('./rank_res/one_option_five_images_five_response_13b.csv',index=False)