In [None]:
!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q -U langchain langchain-community chromadb sentence-transformers

Restart the kernel (not a factory reset!) so that the updated transformers library is picked up.

In [1]:
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

import torch
from PIL import Image

from transformers import BitsAndBytesConfig

# Quantization settings handed to the model loader below:
# 4-bit weights using the "nf4" quantization type, with float16 as the compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

# Processor for the same checkpoint as the model; it is later called with a prompt and an image.
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
# Pad on the left and reuse the EOS token id as the pad token id.
processor.tokenizer.padding_side = "left"
processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id

# Earlier variant kept for reference: passes load_in_4bit directly instead of a quantization config.
# model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf", load_in_4bit=True, device_map="auto")
# Load the quantized model; device placement is delegated to device_map="auto".
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    quantization_config=quantization_config,
    device_map="auto"
)


2024-03-31 12:07:03.965081: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-03-31 12:07:03.965221: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-03-31 12:07:04.119977: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


preprocessor_config.json:   0%|          | 0.00/754 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.85k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/493k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.80M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/41.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/552 [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/70.2k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/4 [00:00<?, ?it/s]

model-00001-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00002-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00003-of-00004.safetensors:   0%|          | 0.00/4.92G [00:00<?, ?B/s]

model-00004-of-00004.safetensors:   0%|          | 0.00/380M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

In [2]:
from datasets import load_dataset
import json
from pathlib import Path
from tqdm import tqdm

# Dataset layout: the competition root holds one folder per split plus the QA files.
root_path = Path('/kaggle/input/qa-over-slides-searching-for-the-information')
train_path = root_path / 'train/train'
test_path = root_path / 'test/test'

# Cache of generated slide summaries, keyed by deck name.
deck_image_summaries = {}

# qa_train.jsonl holds one JSON object per line.
with open(root_path / 'qa_train.jsonl') as json_file:
    qa_data = list(map(json.loads, json_file))


In [3]:
def image_summarize(image, prompt, max_new_tokens=128):
    """Generate text for a single image with the loaded LLaVA model.

    :param image: Image handed to the processor as ``images``
    :param prompt: Full prompt text; it must contain the ``[/INST]`` marker,
        because the answer is cut out of the decoded text after that marker
    :param max_new_tokens: Maximum number of tokens to generate
    :return: Decoded text following the first ``[/INST]`` marker, stripped
    :raises IndexError: If the decoded output contains no ``[/INST]`` marker
    """
    # Pass text and images by keyword: the positional order of these two
    # arguments is not the same across transformers releases.
    # Move the inputs to the model's own device instead of assuming "cuda".
    inputs = processor(text=prompt, images=image, return_tensors="pt").to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        pad_token_id=processor.tokenizer.eos_token_id,
    )
    decoded = processor.decode(output[0], skip_special_tokens=True)
    # The decoded sequence still starts with the prompt; keep only what follows it.
    return decoded.split("[/INST]")[1].strip()
    

def generate_img_summaries(images):
    """
    Generate a retrieval-oriented summary for every image

    :param images: Iterable of images accepted by ``image_summarize``
    :return: List of summaries, one per input image and in the same order;
             an image whose summarization failed gets an empty string
    """

    # Store image summaries
    image_summaries = []

    # Prompt
    prompt = """[INST] <image>\nYou are an assistant tasked with summarizing images for retrieval.
These summaries will be embedded and used to retrieve the raw image.
Give a concise summary of the image that is well optimized for retrieval. [/INST]"""

    # Apply summarization to images
    for i, image in enumerate(images):
        try:
            summary = image_summarize(image, prompt)
        except Exception as exc:  # not a bare except: Ctrl-C / kernel interrupt must still work
            print(f"Failed to summarize image {i+1}: {exc!r}")
            # Keep a placeholder so summaries stay index-aligned with `images`;
            # dropping the entry would shift every later summary onto the wrong image.
            summary = ""
        image_summaries.append(summary)

    return image_summaries

In [4]:
import uuid

from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.schema.document import Document
from langchain.prompts import PromptTemplate
from langchain.schema.output_parser import StrOutputParser
from langchain.schema.runnable import RunnableLambda, RunnablePassthrough
from langchain.storage import InMemoryStore

import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def create_multi_vector_retriever(image_summaries, images):
    """
    Create a retriever that indexes image summaries but returns the raw images

    :param image_summaries: Text summaries; entry ``i`` is paired with ``images[i]``
    :param images: Raw images stored in the docstore and returned on retrieval
    :return: MultiVectorRetriever backed by a Chroma vectorstore and an InMemoryStore
    """

    # Initialize the storage layer
    store = InMemoryStore()
    id_key = "doc_id"

    # One fresh id per summary; the same id links the summary to its image.
    img_ids = [str(uuid.uuid4()) for _ in image_summaries]
    summary_img = [
        Document(page_content=summary, metadata={id_key: img_id})
        for img_id, summary in zip(img_ids, image_summaries)
    ]

    model_name = "colbert-ir/colbertv2.0"
    embedding_function = HuggingFaceEmbeddings(model_name=model_name)

    # NOTE(review): presumably drops the default collection left behind by a previous
    # call so that summaries of another deck are not searched - confirm.
    Chroma().delete_collection()
    # The vectorstore that indexes the summaries (this call already embeds and adds them)
    vectorstore = Chroma.from_documents(summary_img, embedding_function)

    # Create the multi-vector retriever
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
    )

    # The summaries are NOT added to the vectorstore a second time here:
    # from_documents has indexed them, and adding them again would duplicate every
    # summary in the index and embed each one twice.

    # Store the raw images under the ids referenced by the summaries
    retriever.docstore.mset(list(zip(img_ids, images)))

    return retriever

In [5]:
def llava_generate(_dict, prompt=None):
    """Answer the question in ``_dict`` from the first image in ``_dict['image']``.

    :param _dict: Mapping with an ``'image'`` sequence and a ``'question'`` value
    :param prompt: Accepted but never read; the prompt is always built here
    :return: Result of ``image_summarize`` for the first image and the built prompt
    """
    first_image = _dict['image'][0]
    question_prompt = f"""[INST] <image>
You are an analyst tasked with answering questions about visual content.
You will be given a image from a slide deck / presentation.
Use this information to answer the user question.
Give a short, precise answer.
User-provided question: {_dict['question']}
Answer: [/INST]"""
    return image_summarize(first_image, question_prompt, max_new_tokens=256)

In [6]:
def multi_modal_rag_chain(retriever):
    """
    Assemble the multi-modal RAG chain around ``retriever``

    :param retriever: Runnable placed under the ``"image"`` key of the chain input
    :return: Chain of the input mapping, ``llava_generate`` and a string output parser
    """
    # Input mapping: the retriever fills "image", the chain input is passed through as "question".
    chain_inputs = {
        "image": retriever,
        "question": RunnablePassthrough(),
    }
    answer_step = RunnableLambda(llava_generate)

    return chain_inputs | answer_step | StrOutputParser()

In [14]:
def predict(sample, deck_image_summaries, data_dir_path=train_path):
    """Answer one question about a slide deck.

    :param sample: Record providing ``'question'`` and ``'deck_name'``
    :param deck_image_summaries: Dict cache of summaries per deck name; a missing
        deck is summarized and stored in it (the dict is mutated)
    :param data_dir_path: Directory that contains one sub-folder per deck
    :return: Output of the RAG chain invoked with the question
    """
    question = sample['question']
    deck_name = sample['deck_name']

    # Load every image found in the deck's folder.
    deck_dir = data_dir_path.joinpath(deck_name)
    images = load_dataset("imagefolder", data_dir=deck_dir)['train']['image']

    # Summaries are generated only once per deck and reused afterwards.
    if deck_name not in deck_image_summaries:
        deck_image_summaries[deck_name] = generate_img_summaries(images)

    # A retriever is built for every call from the cached summaries.
    retriever = create_multi_vector_retriever(deck_image_summaries[deck_name], images)

    return multi_modal_rag_chain(retriever).invoke(question)

In [8]:
def normalize_text(s):
    """Lower-case ``s``, strip punctuation, drop the articles a/an/the and collapse whitespace."""
    import string, re

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    # Order matters: punctuation is removed before the article regex runs.
    return white_space_fix(remove_articles(remove_punc(lower(s))))

def compute_exact_match(prediction, truth):
    """Return 1 if both strings are equal after ``normalize_text``, else 0."""
    return int(normalize_text(prediction) == normalize_text(truth))

def compute_f1(prediction, truth):
    """Token-level F1 between a predicted and a reference answer.

    Both strings are normalized with ``normalize_text`` and split on whitespace.
    Overlap is counted as a multiset intersection, so a token repeated in both
    answers counts as often as it appears in both.

    :param prediction: Predicted answer text
    :param truth: Reference answer text
    :return: F1 score in [0, 1]; 1 or 0 (int) when either side has no tokens
    """
    from collections import Counter

    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens)

    # Counter & Counter keeps the minimum count per token. A plain set intersection
    # would count repeated tokens once while the denominators count every occurrence,
    # which scores identical answers with repeated tokens below 1.
    num_common = sum((Counter(pred_tokens) & Counter(truth_tokens)).values())

    # if there are no common tokens then f1 = 0
    if num_common == 0:
        return 0

    prec = num_common / len(pred_tokens)
    rec = num_common / len(truth_tokens)

    return 2 * (prec * rec) / (prec + rec)

In [9]:
# n_samples = 10

# pred_answers = []
# f1_scores = []

# for sample in tqdm(qa_data[0:n_samples]):
#     pred_answer = predict(sample, deck_image_summaries)
#     true_answer = sample['answer']
#     f1_score = compute_f1(pred_answer, true_answer)
    
#     pred_answers.append(pred_answer)
#     f1_scores.append(f1_score)

In [10]:
# for i, sample in enumerate(qa_data[0:n_samples]):
#     print("\tSample", i)
#     print("QUESTION:\n\t", sample['question'])
#     print("PREDICT ANSWER:\n\t", pred_answers[i])
#     print("TRUE ANSWER:\n\t", sample['answer'])
#     print("F1 SCORE:", f1_scores[i])
#     print("---------------------------")

In [11]:
import pandas as pd

# Read the test questions and report the frame's dimensions.
test = pd.read_csv(root_path / 'qa_test.csv')
print("Test shape:", test.shape)

# Order rows by deck so that questions about the same deck are adjacent.
test_sorted = test.sort_values('deck_name')

Test shape: (1014, 3)


In [12]:
import csv

# Directory that receives the CSV files written during inference.
output_path = Path('/kaggle/working/')

# Number of rows of the sorted test frame handled by the first inference loop.
n_samples_test = 10

# Accumulators shared by the inference loops: predictions in processing order
# and the per-deck summary cache passed to predict().
test_pred_answers = []
test_deck_image_summaries = {}

In [20]:
# Answer the first n_samples_test questions of the deck-sorted test set.
# Results are appended to the CSV files after every sample, so finished work
# is already on disk if a later sample fails.
for i in tqdm(range(n_samples_test)):
    sample = test_sorted.iloc[i]
    pred_answer = predict(sample, test_deck_image_summaries, data_dir_path=test_path)
    test_pred_answers.append(pred_answer)

    # newline='' is what the csv module requires for files handed to a writer;
    # without it, line endings and embedded newlines are not handled reliably.
    with open(output_path.joinpath('qa_test_deck_summaries.csv'), 'a', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=['deck_name', 'slide_summaries'])
        writer.writerow({'deck_name': sample['deck_name'],
                         'slide_summaries': test_deck_image_summaries[sample['deck_name']]})

    # No header row is written; the reader has to supply the column names.
    with open(output_path.joinpath('qa_sample_submission.csv'), 'a', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=['ID', 'pred_answer'])
        writer.writerow({'ID': sample['ID'], 'pred_answer': pred_answer})

  0%|          | 0/10 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 10%|█         | 1/10 [00:10<01:30, 10.08s/it]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 20%|██        | 2/10 [00:14<00:55,  6.94s/it]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Downloading and preparing dataset image_folder/default to /root/.cache/huggingface/datasets/image_folder/default-a809ff933bc29276/0.0.0/ee92df8e96c6907f3c851a987be3fd03d4b93b247e727b69a8e23ac94392a091...
                

Downloading data files #5:   0%|          | 0/1 [00:00<?, ?obj/s]

 30%|███       | 3/10 [03:50<11:55, 102.20s/it]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 40%|████      | 4/10 [03:57<06:28, 64.70s/it] 

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 50%|█████     | 5/10 [04:05<03:40, 44.10s/it]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 60%|██████    | 6/10 [04:09<02:02, 30.72s/it]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 70%|███████   | 7/10 [04:14<01:06, 22.24s/it]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 80%|████████  | 8/10 [04:19<00:33, 16.64s/it]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

 90%|█████████ | 9/10 [04:27<00:14, 14.14s/it]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 10/10 [04:32<00:00, 27.26s/it]


In [None]:
# Continue inference on rows [n_start, n_end) of the deck-sorted test set,
# appending to the same CSV files as the first batch.
n_start, n_end = 10, 100
# Wrapped in tqdm for progress reporting, like the first inference loop.
for i in tqdm(range(n_start, n_end)):
    sample = test_sorted.iloc[i]
    pred_answer = predict(sample, test_deck_image_summaries, data_dir_path=test_path)
    test_pred_answers.append(pred_answer)

    # newline='' is what the csv module requires for files handed to a writer;
    # without it, line endings and embedded newlines are not handled reliably.
    with open(output_path.joinpath('qa_test_deck_summaries.csv'), 'a', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=['deck_name', 'slide_summaries'])
        writer.writerow({'deck_name': sample['deck_name'],
                         'slide_summaries': test_deck_image_summaries[sample['deck_name']]})

    # No header row is written; the reader has to supply the column names.
    with open(output_path.joinpath('qa_sample_submission.csv'), 'a', newline='') as file:
        writer = csv.DictWriter(file, fieldnames=['ID', 'pred_answer'])
        writer.writerow({'ID': sample['ID'], 'pred_answer': pred_answer})

  0%|          | 0/90 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/1 [00:00<?, ?it/s]

  1%|          | 1/90 [00:04<07:22,  4.98s/it]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Downloading and preparing dataset image_folder/default to /root/.cache/huggingface/datasets/image_folder/default-7535e2526ce358a4/0.0.0/ee92df8e96c6907f3c851a987be3fd03d4b93b247e727b69a8e23ac94392a091...
                

In [None]:
# The inference loops write the submission file without a header row,
# so the column names are supplied explicitly when reading it back.
submission = pd.read_csv(
    output_path / 'qa_sample_submission.csv',
    header=None,
    names=['ID', 'pred_answer'],
)

In [27]:
# Show the submission frame read back from the CSV as this cell's output.
submission

Unnamed: 0,ID,pred_answer
0,299,"In the image, the ""No Delegation Tokens"" diagr..."
1,672,The two components of the Container request ar...
2,579,The transparent process of land acquisitions i...
3,580,The image shows a diagram of the institutional...
4,590,The image shows the following levels of instit...
5,485,Old Act Section 4 is in New Act Section 11.
6,839,The 2013 LARR Act made no provision about soci...
7,766,"SIA stands for ""Society of Industrial and Appl..."
8,765,The content of Chapter IV includes determinati...
9,982,The 1894 Act did not specify compensation for ...
