#### Loading Minimal LLaVA
* Language Model decoder (Llama3.1)
* Vision Encoder (CLIP)
* Vision projector (MLP with randomly initialized weights)
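
To make the role of the third component concrete, here is a minimal sketch of an MLP projector in PyTorch. It is not the `mini_llava` implementation: the class name `VisionProjector`, the two-layer shape, and the sizes `VISION_DIM` and `LLM_DIM` are assumptions chosen only for illustration.

```python
import torch
import torch.nn as nn

# Hypothetical sizes, chosen only for illustration (not read from mini_llava).
VISION_DIM = 1024  # assumed width of the vision encoder's patch features
LLM_DIM = 2048     # assumed width of the language model's token embeddings


class VisionProjector(nn.Module):
    """Two-layer MLP that maps vision features into the LLM embedding space."""

    def __init__(self, vision_dim: int, llm_dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(vision_dim, llm_dim),
            nn.GELU(),
            nn.Linear(llm_dim, llm_dim),
        )

    def forward(self, patch_features: torch.Tensor) -> torch.Tensor:
        return self.net(patch_features)


# nn.Linear initializes its weights randomly, so the projector starts untrained.
projector = VisionProjector(VISION_DIM, LLM_DIM)

# One image represented as 16 patch features: (batch, num_patches, vision_dim).
patch_features = torch.randn(1, 16, VISION_DIM)
image_embeds = projector(patch_features)
print(image_embeds.shape)  # same batch and patch count, last dim is now LLM_DIM
```

Because the projector starts from random weights, its outputs carry no meaning for the language model until it has been trained.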


<div align="center">
  <img src="data/mini-llava.png" width="800" alt="Mini-LLaVA">
  <p><em>Mini-LLaVA handles text, image and video inputs</em></p>
</div>

In [None]:
from mini_llava import LlavaLlamaForCausalLM # Register the llava models into 'transformers'
from transformers import AutoConfig, AutoTokenizer
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
tokenizer.add_special_tokens({'pad_token': '[PAD]'})
config = AutoConfig.for_model("llava_llama", trust_remote_code=True)
llava_model = LlavaLlamaForCausalLM.from_pretrained_lm(config).to(device) # Initialize with Llama3.1 weights & CLIP encoder

#### Let's check that Llama3.1 generates text correctly

In [2]:
from mini_llava import generate_text, generation_config

# Example usage
prompt = "Once upon a time in a galaxy far, far away,"
text = generate_text(prompt, llava_model, tokenizer, device, generation_config)
print(f"Prompt: {prompt}")
print(f"Generated text: {text}")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Prompt: Once upon a time in a galaxy far, far away,
Generated text:  there was a young Jedi named Kael. Kael was a skilled warrior and a master of the Force, but he was also a bit of a rebel. He had a tendency to disobey orders and follow his own path, which often put him at odds with his fellow Jedi.
One day, Kael received a message from an unknown sender, claiming to have information about a powerful Sith Lord who was hiding in a distant corner of the galaxy. The message read: "Meet me on the planet of Zorvath at midnight. Come alone."
Kael was intrigued by the message and decided to investigate. He packed his bags, said goodbye to his fellow Jedi, and set off for Zorvath.
As he arrived on the planet, Kael noticed that the air was thick with an eerie, pulsating energy. He could feel the presence of the dark side all around him, and he knew that he was getting close to his target.
At midnight, Kael made his way to the designated meeting point, a large, ancient temple in the heart of t

<div align="center">
  <img src="data/cat.jpg" width="500" alt="Cat image">
  <p><em>We want Mini-LLaVA to recognize the cat in this image</em></p>
</div>

### Before training, the model can't see

In [18]:
from mini_llava import data_args, LazyProcessor

proc = LazyProcessor(tokenizer=tokenizer, data_args=data_args, image_processor=llava_model.get_model().vision_tower.image_processor)

img_path = "data/cat.jpg"
query = "What is in the image?"
proc.query(question = query, 
           # media_paths = []) # text-only
           media_paths = [img_path]) # interleaved text & image chat

generated_texts = proc.get_response(llava_model, tokenizer, generation_config)
print("Mini-LLaVA Response: \n\n  ", generated_texts[0]) # Model has no idea what's in the image yet

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Mini-LLaVA Response: 

   The image of the image.


<div align="center">
  <p><em>What is in the image?</em></p> 
  <img src="data/cat.jpg" width="400" alt="Cat image">
  <p><em>Mini-LLaVA before training: The image of the image.</em></p>
</div>
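
For intuition about what "interleaved text & image" means, here is a minimal sketch of how LLaVA-style models commonly splice image embeddings into the token sequence. It is a sketch under assumptions, not the `LazyProcessor` code: the placeholder name `<image>`, the helper `splice_image_embeddings`, and the string stand-ins for embedding vectors are all hypothetical.

```python
IMAGE_PLACEHOLDER = "<image>"  # hypothetical placeholder token name


def splice_image_embeddings(tokens, text_embed, image_embeds):
    """Replace each placeholder token with one image's projected embeddings.

    tokens:       list of text tokens, with a placeholder where an image goes
    text_embed:   function mapping a text token to its embedding
    image_embeds: one list of patch embeddings per image, in order of appearance
    """
    sequence = []
    images = iter(image_embeds)
    for tok in tokens:
        if tok == IMAGE_PLACEHOLDER:
            # One placeholder expands into many embeddings (one per patch).
            sequence.extend(next(images))
        else:
            sequence.append(text_embed(tok))
    return sequence


# Strings stand in for embedding vectors so the result is easy to read.
tokens = ["What", "is", "in", IMAGE_PLACEHOLDER, "?"]
image_embeds = [["img_0", "img_1", "img_2"]]  # a single image with 3 patches
sequence = splice_image_embeddings(tokens, lambda t: f"txt({t})", image_embeds)
print(sequence)
```

The language model then consumes this mixed sequence like any other list of embeddings. With an untrained projector, the image positions look like noise to it, which is consistent with the uninformative answer above.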

### Pre-train the vision projector on a visual question-answer dataset (~8K)
* The projector learns to 'translate' image features into embeddings that the LLM can understand.
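
As a rough illustration of projector-only training, the sketch below freezes everything except the projector and takes one optimizer step. It uses a toy stand-in model: the class `ToyLlava` and the submodule names `vision_tower`, `mm_projector`, and `language_model` are assumptions for this example and do not show what `train_mini_llava` does internally.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class ToyLlava(nn.Module):
    """Tiny stand-in with hypothetical submodule names (not the mini_llava classes)."""

    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(8, 16)     # stands in for the vision encoder
        self.mm_projector = nn.Linear(16, 32)    # stands in for the vision projector
        self.language_model = nn.Linear(32, 10)  # stands in for the LLM decoder

    def forward(self, x):
        return self.language_model(self.mm_projector(self.vision_tower(x)))


model = ToyLlava()

# Freeze every parameter except those belonging to the projector.
for name, param in model.named_parameters():
    param.requires_grad = name.startswith("mm_projector")

trainable = [n for n, p in model.named_parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)

# Snapshots used to check which weights actually change.
frozen_before = model.language_model.weight.detach().clone()
proj_before = model.mm_projector.weight.detach().clone()

# One training step on random data.
x = torch.randn(4, 8)
target = torch.randint(0, 10, (4,))
loss = nn.functional.cross_entropy(model(x), target)
loss.backward()
optimizer.step()

print("trainable:", trainable)
print("LLM unchanged:", torch.equal(model.language_model.weight, frozen_before))
print("projector changed:", not torch.equal(model.mm_projector.weight, proj_before))
```

Gradients still flow through the frozen language model to reach the projector. Only the projector's weights are updated.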

In [None]:
from mini_llava import prepare_docci_data, DataCollatorForSupervisedDataset, LazySupervisedDataset
from mini_llava import train_mini_llava as train
from torch.utils.data import DataLoader 

data_args = prepare_docci_data("data/docci_converted.json", "data/docci")
dataset = LazySupervisedDataset(data_args=data_args, tokenizer=tokenizer, image_processor=llava_model.get_model().vision_tower.image_processor)
collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
dataloader = DataLoader(dataset, collate_fn=collator, batch_size=4, num_workers=1)

train(llava_model, tokenizer, dataloader, use_lora=True) # use_lora=True is easily the better choice (many lower-level optimizations happen here ...)

### After training, Mini-LLaVA already recognizes the cat in the image

In [16]:
from mini_llava import LazyProcessor

proc = LazyProcessor(tokenizer=tokenizer, data_args=data_args, image_processor=llava_model.get_model().vision_tower.image_processor)

img_path = "data/cat.jpg"
query = "What is in the image?"
proc.query(question = query, 
           # media_paths = []) # for testing text-only chat
           media_paths = [img_path]) # for interleaved text & image chat

generated_texts = proc.get_response(llava_model, tokenizer)
print("Mini-LLaVA Response: \n\n  ", generated_texts[0]) # Model sees the cat now (!) Note that we've only trained the adaptor here.

Mini-LLaVA Response: 

   A close up view of a small, white, ceramic cat figurine sitting on a small, white, ceramic plate. The cat is facing left and has its front paws on the plate. The cat has a small, red nose and a small, black mouth. The cat has a small, white collar around its neck. The plate has a small, white rim around the edge. The plate is sitting on a small, white tablecloth. The tablecloth has a small, white border around the edge. The tablecloth is sitting on a small, white table. The table has a small, white leg on the right side. The table is sitting on a small, white floor. The floor is made of small, white tiles. The tiles are arranged in a grid pattern. The tiles are the same size as the table. The tiles are the same size as the tablecloth. The tiles are the same size as the plate. The tiles are the same size as the cat. The tiles are the same size as the table leg. The tiles are the same size as the table. The tiles are the same size as the tablecloth. The tiles ar

<div align="center">
  <p><em>What is in the image?</em></p> 
  <img src="data/cat.jpg" width="400" alt="Cat image">
  <p style="margin-left: 40px; margin-right: 40px;"><em>Mini-LLaVA after pre-training: Cat! </em></p>
</div>