<a href="https://colab.research.google.com/github/ayush-vatsal/Caption-Studio/blob/main/AI_Image_Captions_for_Social_Media.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Run all cells (the installation and download logs can be ignored)

### Necessary Imports

In [1]:
! pip install -q -U evaluate
! pip install -q -U jiwer
! pip install -q -U --upgrade huggingface_hub
!pip install -q -U gradio

!pip install -q -U bitsandbytes
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U git+https://github.com/huggingface/peft.git
!pip install -q -U git+https://github.com/huggingface/accelerate.git
!pip install -q -U einops
!pip install -q -U safetensors
!pip install -q -U torch
!pip install -q -U xformers
!pip install -q -U datasets

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m81.4/81.4 kB[0m [31m5.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m486.2/486.2 kB[0m [31m28.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m110.5/110.5 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.5/212.5 kB[0m [31m17.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.3/134.3 kB[0m [31m15.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.8/268.8 kB[0m [31m23.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.2/2.2 MB[0m [31m39.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m19.8/19.8 MB[0m [31m58.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━

### [GIT: A Generative Image-to-text Transformer for Vision and Language](https://arxiv.org/abs/2205.14100)

Model used to generate a text description of the input image.

[Model overview](https://huggingface.co/docs/transformers/model_doc/git)

In [2]:
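# Never hardcode Hugging Face tokens; read them from the environment, e.g. access_token = os.environ["HF_TOKEN"] (any token previously committed here should be revoked)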

# from huggingface_hub import login
# login()

In [3]:
from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
import torch
from PIL import Image
import requests
import traceback

In [4]:
class Image2Text:
    """
    Image-description stage built on the GIT large model finetuned on COCO
    ("microsoft/git-large-coco").

    Produces a literal description of an image, which the LLM stage later turns
    into a social-media caption.
    """

    def __init__(self):
        # Load the GIT coco model (downloaded from the Hugging Face Hub on first use)
        preprocessor_git_large_coco = AutoProcessor.from_pretrained("microsoft/git-large-coco")
        model_git_large_coco = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.preprocessor = preprocessor_git_large_coco
        self.model = model_git_large_coco
        self.model.to(self.device)


    def image_description(
        self,
        image_url,
        max_length=50,
        temperature=0.1,
        use_sample_image=False,
        timeout=30,
    ):
        """
        Download an image and generate a description for it.

        -----
        Parameters
        image_url: str
            URL of the image to describe. Ignored when `use_sample_image` is True.
        max_length: int
            The max length of the generated description (forwarded to `generate`).
        temperature: float
            Currently unused: it is never forwarded to `generate`. Kept for
            backward compatibility.
        use_sample_image: bool
            If True, describe a fixed COCO validation image instead of `image_url`.
        timeout: float
            Seconds to wait for the image download before giving up.

        -----
        Returns
        str or None
            The generated description with surrounding whitespace stripped, or None
            if description generation raised (the error is printed).

        -----
        Raises
        requests.RequestException
            If the image cannot be downloaded, including non-2xx HTTP responses.
        PIL.UnidentifiedImageError
            If the downloaded bytes are not a readable image.
        """
        import io

        if use_sample_image:
            image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"

        # Fail fast on unreachable hosts / HTTP errors instead of hanging forever or
        # handing an HTML error page to PIL; the `with` releases the connection.
        with requests.get(image_url, timeout=timeout) as response:
            response.raise_for_status()
            image = Image.open(io.BytesIO(response.content))

        # Generate a description for the image using the GIT coco model
        try:
            return self._generate_description(image, max_length, False).strip()
        except Exception as e:
            # Best-effort: report the failure and return None to the caller.
            print(e)
            traceback.print_exc()
            return None


    def _generate_description(
        self,
        image,
        max_length=50,
        use_float_16=False,
    ):
        """
        Generate a description for an already-loaded image.

        -----
        Parameters
        image: PIL.Image
            The image to describe.
        max_length: int
            The max length of the generated description (forwarded to `generate`).
        use_float_16: bool
            Run generation under float16 autocast. Only applied when the model is on
            CUDA; ignored on CPU. Faster, but may lead to slightly worse results.

        -----
        Returns
        str
            The first decoded sequence, with special tokens removed (not stripped).
        """
        import contextlib

        pixel_values = self.preprocessor(images=image, return_tensors="pt").pixel_values.to(self.device)

        # Mixed precision is only meaningful on the GPU; on CPU fall back to a no-op context.
        use_amp = use_float_16 and str(self.device).startswith("cuda")
        amp_context = (
            torch.autocast(device_type="cuda", dtype=torch.float16)
            if use_amp
            else contextlib.nullcontext()
        )
        with amp_context:
            generated_ids = self.model.generate(
                pixel_values=pixel_values,
                max_length=max_length,
            )
        return self.preprocessor.batch_decode(generated_ids, skip_special_tokens=True)[0]

### Finetuned Falcon-7B Model

A decoder-only LLM finetuned for generating social-media-worthy captions. The base model is finetuned using [QLoRA](https://arxiv.org/abs/2305.14314) and the Hugging Face [`peft`](https://github.com/huggingface/peft) library.

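[Model card](https://huggingface.co/tiiuae/falcon-7b) · base checkpoint used in this notebook: [falcon-7b-instruct](https://huggingface.co/tiiuae/falcon-7b-instruct)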

In [5]:
import json
import os
from pprint import pprint

import bitsandbytes as bnb
import pandas as pd

import torch
import torch.nn as nn
import transformers
from datasets import load_dataset
from huggingface_hub import notebook_login
from peft import (
  LoraConfig ,
  PeftConfig ,
  PeftModel ,
  get_peft_model ,
  prepare_model_for_kbit_training,
)
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, AutoConfig
from peft import LoraConfig, get_peft_model


# Restrict this process to GPU 0.
# NOTE(review): this only takes effect if set before CUDA is initialised in the kernel —
# confirm no earlier cell has already touched the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /usr/local/lib/python3.10/dist-packages/bitsandbytes/libbitsandbytes_cuda118.so
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so
CUDA SETUP: Highest compute capability among GPUs detected: 7.5
CUDA SETUP: Detected CUDA version 118
CUDA SETUP: Loading binary /usr/local/lib/python3.10/dist-packages/bitsandbytes/libbitsandbytes_cuda118.so...


  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
Either way, this might cause trouble in the future:
If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.
  warn(msg)


In [6]:
class Social_Media_Captioner:
    def __init__(self, use_finetuned: bool=True, temp=0.1):
        self.use_finetuned = use_finetuned
        self.MODEL_NAME = "vilsonrodrigues/falcon-7b-instruct-sharded"
        self.peft_model_name = "ayush-vatsal/caption_qlora_finetune"
        self.model_loaded = False
        self.device = "cuda:0"

        self._load_model()

        self.generation_config = self.model.generation_config
        self.generation_config.max_new_tokens = 50
        self.generation_config.temperature = temp
        self.generation_config.top_p = 0.7
        self.generation_config.num_return_sequences = 1
        self.generation_config.pad_token_id = self.tokenizer.eos_token_id
        self.generation_config.eos_token_id = self.tokenizer.eos_token_id

        self.cache: list[dict] = [] # [{"image_decription": "A man", "caption": ["A man"]}]


    def _load_model(self):
        try:
            self.bnb_config = BitsAndBytesConfig(
                load_in_4bit = True,
                bnb_4bit_use_double_quant = True,
                bnb_4bit_quant_type= "nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
                )
            self.model = AutoModelForCausalLM.from_pretrained(
                self.MODEL_NAME,
                device_map = "auto",
                trust_remote_code = True,
                quantization_config = self.bnb_config
                )

            # Defining the tokenizers
            self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
            self.tokenizer.pad_token = self.tokenizer.eos_token

            if self.use_finetuned:
                # LORA Config Model
                self.lora_config = LoraConfig(
                    r=16,
                    lora_alpha=32,
                    target_modules=["query_key_value"],
                    lora_dropout=0.05,
                    bias="none",
                    task_type="CAUSAL_LM"
                )
                self.model = get_peft_model(self.model, self.lora_config)

                # Fitting the adapters
                self.peft_config = PeftConfig.from_pretrained(self.peft_model_name)
                self.model = AutoModelForCausalLM.from_pretrained(
                    self.peft_config.base_model_name_or_path,
                    return_dict = True,
                    quantization_config = self.bnb_config,
                    device_map= "auto",
                    trust_remote_code = True
                    )
                self.model = PeftModel.from_pretrained(self.model, self.peft_model_name)

                # Defining the tokenizers
                self.tokenizer = AutoTokenizer.from_pretrained(self.peft_config.base_model_name_or_path)
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model_loaded = True
            print("Model Loaded successfully")

        except Exception as e:
            print(e)
            self.model_loaded = False


    def inference(self, input_text: str, use_cached=True, cache_generation=True) -> str | None:
        if not self.model_loaded:
            raise Exception("Model not loaded")

        try:
            prompt = Social_Media_Captioner._prompt(input_text)
            if use_cached:
                for item in self.cache:
                    if item['image_description'] == input_text:
                        return item['caption']

            encoding = self.tokenizer(prompt, return_tensors = "pt").to(self.device)
            with torch.inference_mode():
                outputs = self.model.generate(
                    input_ids = encoding.input_ids,
                    attention_mask = encoding.attention_mask,
                    generation_config = self.generation_config
                )
                generated_caption = (self.tokenizer.decode(outputs[0], skip_special_tokens=True).split('Caption: "')[-1]).split('"')[0]

                if cache_generation:
                    for item in self.cache:
                        if item['image_description'] == input_text:
                            item['caption'].append(generated_caption)
                            break
                    else:
                        self.cache.append({
                            'image_description': input_text,
                            'caption': [generated_caption]
                        })

                return generated_caption
        except Exception as e:
            print(e)
            return None


    def _prompt(input_text="A man walking alone in the road"):
        if input_text is None:
            raise Exception("Enter a valid input text to generate a valid prompt")

        return f"""
            Convert the given image description to social media worthy caption
            Description: {input_text}
            Caption:
            """.strip()

    @staticmethod
    def get_trainable_parameters(model):
        trainable_params = 0
        all_param = 0
        for _, param in model.named_parameters():
            all_param += param.numel()
            if param.requires_grad:
                trainable_params += param.numel()
        return f"trainable_params: {trainable_params} || all_params: {all_param} || Percentage of trainable params: {100*trainable_params / all_param}"


    def __repr__(self):
        return f"""
        Base Model Name: {self.MODEL_NAME}
        PEFT Model Name: {self.peft_model_name}
        Using PEFT Finetuned Model: {self.use_finetuned}
        Model: {self.model}

        ------------------------------------------------------------

        {Social_Media_Captioner.get_trainable_parameters(self.model)}
        """

## Assembly Line

In this section, we assemble the **image-to-text** (image description) and **description-to-caption** parts.

In [7]:
class Captions:
    """End-to-end pipeline: image -> GIT description -> Falcon social-media caption."""

    def __init__(self, use_finetuned_LLM: bool=True, temp_LLM=0.1):
        self.image_to_text = Image2Text()
        self.LLM = Social_Media_Captioner(use_finetuned_LLM, temp_LLM)

    def generate_captions(
        self,
        image=None,
        image_url=None,
        max_length_GIT=50,
        temperature_GIT=0.1,
        use_sample_image_GIT=False,
        use_cached_LLM=True,
        cache_generation_LLM=True
    ):
        """
        Caption an image given either as a loaded image or as a URL.

        Parameters
        image: an already-loaded image (e.g. PIL.Image), used when no URL/sample is requested.
        image_url: URL of the image to download and caption; takes precedence over `image`.
        max_length_GIT, temperature_GIT: forwarded to the GIT description stage.
        use_sample_image_GIT: caption the built-in sample image (no `image`/`image_url` needed).
        use_cached_LLM, cache_generation_LLM: forwarded to `Social_Media_Captioner.inference`.

        Returns
        The caption string, or None if a stage failed (errors are printed by that stage).

        Raises
        ValueError: if no image source was given.
        """
        if image_url or use_sample_image_GIT:
            image_description = self.image_to_text.image_description(
                image_url,
                max_length=max_length_GIT,
                temperature=temperature_GIT,
                use_sample_image=use_sample_image_GIT,
            )
        elif image is not None:
            # Strip to match the URL path, which returns a stripped description.
            image_description = self.image_to_text._generate_description(image, max_length=max_length_GIT).strip()
        else:
            raise ValueError("Provide an image, an image_url, or set use_sample_image_GIT=True")

        return self.LLM.inference(image_description, use_cached=use_cached_LLM, cache_generation=cache_generation_LLM)

## Get your captions

In [8]:
# Instantiating the pipeline loads both models (GIT and Falcon-7B with the caption adapter).
# The first run downloads several GB of weights from the Hugging Face Hub (see output below).
caption_generator = Captions()

Downloading (…)rocessor_config.json:   0%|          | 0.00/503 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/453 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/2.82k [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/1.58G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/141 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/773 [00:00<?, ?B/s]

Downloading (…)/configuration_RW.py:   0%|          | 0.00/2.61k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/tiiuae/falcon-7b-instruct:
- configuration_RW.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Downloading (…)main/modelling_RW.py:   0%|          | 0.00/47.5k [00:00<?, ?B/s]

A new version of the following files was downloaded from https://huggingface.co/tiiuae/falcon-7b-instruct:
- modelling_RW.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.


Downloading (…)fetensors.index.json:   0%|          | 0.00/16.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/15 [00:00<?, ?it/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/1.68G [00:00<?, ?B/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/1.99G [00:00<?, ?B/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/1.82G [00:00<?, ?B/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/1.99G [00:00<?, ?B/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/1.99G [00:00<?, ?B/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/1.82G [00:00<?, ?B/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/1.99G [00:00<?, ?B/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/1.99G [00:00<?, ?B/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/1.82G [00:00<?, ?B/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/1.99G [00:00<?, ?B/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/1.99G [00:00<?, ?B/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/1.82G [00:00<?, ?B/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/1.99G [00:00<?, ?B/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/1.99G [00:00<?, ?B/s]

Downloading (…)of-00015.safetensors:   0%|          | 0.00/828M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

Downloading (…)neration_config.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/180 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.73M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/281 [00:00<?, ?B/s]

Downloading (…)/adapter_config.json:   0%|          | 0.00/436 [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

Downloading adapter_model.bin:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

Model Loaded successfully


In [9]:
# def setup(im_shape=((400, 600))):
#     image_url = input("Enter the image url: ")
#     captions = caption_generator.generate_captions(image_url=image_url)

#     raw_image = Image.open(requests.get(image_url, stream=True).raw).convert('RGB')
#     display(raw_image.resize(im_shape))
#     print(f"\nCaption: ## {captions} ##\n")

In [10]:
import gradio as gr

def setup(image):
    """Gradio callback: caption an uploaded PIL image with the global `caption_generator`."""
    if image is None:
        # Gradio passes None when the user submits without uploading an image.
        return "Please upload an image."
    return caption_generator.generate_captions(image = image)

# `gr.inputs` / `gr.outputs` are deprecated (and removed in Gradio 4.x, which the
# unpinned `pip install -U gradio` installs); the top-level components take the same arguments.
iface = gr.Interface(
    fn=setup,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Textbox(label="Caption"),
)

iface.launch()

  inputs=gr.inputs.Image(type="pil", label="Upload Image"),
  inputs=gr.inputs.Image(type="pil", label="Upload Image"),
  outputs=gr.outputs.Textbox(label="Caption")


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
Note: opening Chrome Inspector may crash demo inside Colab notebooks.

To create a public link, set `share=True` in `launch()`.


<IPython.core.display.Javascript object>

