<a href="https://colab.research.google.com/github/Abhinav1004/Colab-Notebooks/blob/main/test_llava.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Upgrade bitsandbytes, the library behind the 4-bit BitsAndBytesConfig used below.
!pip install -U bitsandbytes



In [None]:
import io
import re

import numpy as np # Needed for np.exp()
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, LlavaForConditionalGeneration, BitsAndBytesConfig

# --- Configuration ---
# LLaVA-1.5 checkpoints are loaded through the dedicated
# LlavaForConditionalGeneration class below rather than a generic Auto* class.
MODEL_ID = "llava-hf/llava-1.5-7b-hf"

# Half precision on GPU; bfloat16 when falling back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.bfloat16

# 4-bit weight quantization to save GPU memory, with DTYPE as the compute dtype.
# bnb_4bit_quant_type (e.g. "nf4") and bnb_4bit_use_double_quant are left at the
# library defaults; set them here to experiment.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=DTYPE,
)

# --- Helper Function for Image Loading ---
def load_image_from_url(url: str, timeout: float = 10, headers=None) -> Image.Image:
    """Download an image and return it as an RGB PIL image.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds to wait for the server; passed to ``requests.get``.
        headers: Optional extra HTTP headers, merged over (and able to override)
            the default ``User-Agent`` sent with the request.

    Returns:
        The decoded image converted to RGB. If the download fails or the payload
        cannot be decoded as an image, an error is printed and a black 336x336
        placeholder is returned instead. Callers therefore always get an image,
        so check the printed message before trusting a prediction made on it.
    """
    # Wikimedia (and some other hosts) refuse requests that lack a descriptive
    # User-Agent; this is the likely cause of the 403 on Example 2's image URL.
    request_headers = {"User-Agent": "llava-tweet-relevance-notebook/0.1 (Colab; python-requests)"}
    if headers:
        request_headers.update(headers)

    try:
        # The context manager releases the connection even if raise_for_status() fires.
        with requests.get(url, headers=request_headers, timeout=timeout) as response:
            response.raise_for_status()  # Raise an exception for bad status codes
            payload = response.content
        return Image.open(io.BytesIO(payload)).convert("RGB")
    except (requests.exceptions.RequestException, OSError) as e:
        # OSError covers PIL's UnidentifiedImageError, e.g. an HTML error page
        # returned with a 200 status instead of an image.
        print(f"Error loading image from URL: {e}")
        # Return a dummy black image if loading fails
        return Image.new('RGB', (336, 336), color = 'black')

# --- Step 1: Initialize the Processor and Model (Direct Loading) ---
print(f"Loading LLaVA model: {MODEL_ID} on device: {DEVICE}...")

# The processor wraps both the tokenizer and the image preprocessor, so the
# prompt text and the pixels are prepared together in one call.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# LLaVA needs LlavaForConditionalGeneration; loading it through
# AutoModelForCausalLM fails with an "Unrecognized configuration class" error.
# device_map="auto" lets transformers choose device placement. .eval() returns
# the module itself, so it is chained onto the load.
model = LlavaForConditionalGeneration.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
    device_map="auto",
    # Newer transformers releases warn that this kwarg is renamed to `dtype`.
    torch_dtype=DTYPE,
).eval()

# --- Step 2: Define the Prediction Function (MODIFIED TO RETURN CONFIDENCE) ---

def parse_relevance_label(generated_text: str) -> str:
    """Map LLaVA's free-text answer onto a classification label.

    The answer is stripped and upper-cased first. The negative pattern is tested
    BEFORE the positive one because "RELEVANT" is a substring of "NOT RELEVANT";
    testing it first would label every negative answer as relevant. Word
    boundaries stop "IRRELEVANT" from matching the positive label; it is treated
    as a negative answer instead.

    Returns:
        "NOT RELEVANT (Other Topic)", "RELEVANT (Garbage Complaint)", or
        "UNCLEAR: <normalized answer>" when neither label is recognized.
    """
    text = generated_text.strip().upper()
    if re.search(r"\b(?:NOT[\s_-]*|IR)RELEVANT\b", text):
        return "NOT RELEVANT (Other Topic)"
    if re.search(r"\bRELEVANT\b", text):
        return "RELEVANT (Garbage Complaint)"
    return f"UNCLEAR: {text}"


def classify_tweet_relevance(image: "Image.Image", tweet_text: str):
    """
    Zero-shot classify a tweet and its image as a garbage complaint or not, using LLaVA.

    Args:
        image: PIL image attached to the tweet.
        tweet_text: Tweet text, interpolated into the prompt verbatim (quotes are
            not escaped).

    Returns:
        A tuple ``(label, confidence_score)``:
        - ``label`` is "RELEVANT (Garbage Complaint)", "NOT RELEVANT (Other Topic)",
          or "UNCLEAR: <answer>" (see ``parse_relevance_label``).
        - ``confidence_score`` is the probability, in percent, that the model gave
          to the FIRST token it generated under greedy decoding. It reflects label
          confidence only insofar as that first token already distinguishes the
          two answers.
    """
    # LLaVA-1.5 expects the <image> placeholder inside a USER/ASSISTANT template.
    classification_prompt = (
        "USER: <image>\n"
        "Analyze the visual content of the image and the following tweet text.\n"
        "The classification task is: Determine if this post is a **complaint regarding garbage, waste, or overflowing bins**.\n"
        f"Tweet: \"{tweet_text}\"\n"
        "Answer with only the single, capitalized word: 'RELEVANT' or 'NOT RELEVANT'.\n"
        "ASSISTANT:"
    )

    # Send inputs to where device_map="auto" put the model. The dtype argument
    # casts only the floating-point tensors (pixel values), not the token ids.
    inputs = processor(
        text=classification_prompt,
        images=image,
        return_tensors="pt",
    ).to(model.device, DTYPE)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=10,
            # Greedy decoding; a temperature would be ignored and only triggers a warning.
            do_sample=False,
            pad_token_id=processor.tokenizer.eos_token_id,
            # Keep per-step logits so the confidence can be computed.
            output_scores=True,
            return_dict_in_generate=True,
        )

    # generate() returns prompt + continuation; keep only the new tokens.
    input_len = inputs["input_ids"].shape[1]
    generated_ids = outputs.sequences[0, input_len:]
    generated_text = processor.decode(generated_ids, skip_special_tokens=True)

    # outputs.scores[0] holds the logits for the first generated step (batch row 0).
    first_step_log_probs = torch.log_softmax(outputs.scores[0][0].float(), dim=-1)
    first_token_id = generated_ids[0].item()
    log_confidence = first_step_log_probs[first_token_id].item()
    confidence_score = float(np.exp(log_confidence) * 100)

    return parse_relevance_label(generated_text), confidence_score

# --- Step 3: Example Usage (MODIFIED TO HANDLE TWO RETURN VALUES) ---


def run_example(header: str, image_url: str, tweet_text: str):
    """Download one example image, classify it with its tweet, and print the result.

    Args:
        header: Banner line printed before the example.
        image_url: Image to classify. On download failure, load_image_from_url
            substitutes a black placeholder and the prediction is made on that.
        tweet_text: Tweet text paired with the image.

    Returns:
        ``(image, prediction, confidence)``, so each example's results remain
        available as module-level names for later cells.
    """
    print(header)
    image = load_image_from_url(image_url)
    prediction, confidence = classify_tweet_relevance(image, tweet_text)
    print(f"Tweet Text: {tweet_text}")
    print(f"Prediction: **{prediction}**")
    print(f"Confidence: **{confidence:.2f}%**")
    print("-" * 40)
    return image, prediction, confidence


# Example 1: a complaint-style tweet paired with LLaVA's sample scenic photo
# (view.jpg). Text and image disagree, which tests how much each one drives
# the prediction.
image_url_1 = "https://llava-vl.github.io/static/images/view.jpg"
tweet_text_1 = "The street looks awful today. This mess has been sitting here for three days! @LocalCityHall"
image_1, prediction_1, confidence_1 = run_example(
    "\n--- Example 1: Garbage Complaint (Prediction based on prompt + image content) ---",
    image_url_1,
    tweet_text_1,
)

# Example 2: a positive, non-complaint tweet paired with a Wikimedia photo whose
# file name ("Vuilnis" is Dutch for rubbish) suggests waste. This is again a
# text/image mismatch. Wikimedia rejects requests without a descriptive
# User-Agent, so a failed download here means the model sees a black placeholder.
image_url_2 = "https://upload.wikimedia.org/wikipedia/commons/f/f1/Vuilnis_bij_Essent_Milieu.jpg"
tweet_text_2 = "So thankful for this beautiful hike this morning. Nature is the best therapy!"
image_2, prediction_2, confidence_2 = run_example(
    "\n--- Example 2: Non-Complaint Tweet (Prediction based on prompt + image content) ---",
    image_url_2,
    tweet_text_2,
)

Loading LLaVA model: llava-hf/llava-1.5-7b-hf on device: cuda...


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]


--- Example 1: Garbage Complaint (Prediction based on prompt + image content) ---


The following generation flags are not valid and may be ignored: ['temperature']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Tweet Text: The street looks awful today. This mess has been sitting here for three days! @LocalCityHall
Prediction: **RELEVANT (Garbage Complaint)**
Confidence: **43.58%**
----------------------------------------

--- Example 2: Non-Complaint Tweet (Prediction based on prompt + image content) ---
Error loading image from URL: 403 Client Error: Forbidden for url: https://upload.wikimedia.org/wikipedia/commons/f/f1/Vuilnis_bij_Essent_Milieu.jpg
Tweet Text: So thankful for this beautiful hike this morning. Nature is the best therapy!
Prediction: **RELEVANT (Garbage Complaint)**
Confidence: **57.69%**
----------------------------------------
