<a href="https://colab.research.google.com/github/amir-asari/Qwen-VL-Basic/blob/main/Qwen_Zeroshot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os

### --- 1. Environment Setup and Installation ---
This cell installs/upgrades the required libraries and then forces a runtime restart.

The restart ensures the newly installed bitsandbytes is the version that gets imported, which 4-bit quantization depends on.


In [None]:
import subprocess
import sys
import tempfile

# Marker written once installation succeeds. os._exit() below kills this
# process, so an environment variable set here cannot survive into the
# restarted kernel; a file on disk does survive a kernel restart.
_INSTALL_MARKER = os.path.join(tempfile.gettempdir(), "qwen_vl_packages_installed")

# Each entry is the argument list of one `pip install -q ...` invocation:
# 1. transformers from source for Qwen2.5-VL compatibility (plus accelerate).
# 2. the latest bitsandbytes for 4-bit quantization.
# 3. utility libraries, including 'datasets' for CIFAR-100.
_PIP_INSTALLS = [
    ["git+https://github.com/huggingface/transformers.git", "accelerate"],
    ["-U", "bitsandbytes"],
    ["qwen-vl-utils", "pillow", "requests", "datasets"],
]

try:
    # Skip installation if this kernel (env var) or a previous kernel on the
    # same machine (marker file) already completed it.
    if not (os.environ.get('QWEN_VL_RESTARTED') or os.path.exists(_INSTALL_MARKER)):
        print("Installing necessary packages...")
        for pip_args in _PIP_INSTALLS:
            # Install into the kernel's own interpreter; check=True raises
            # CalledProcessError on failure instead of silently continuing.
            subprocess.run(
                [sys.executable, "-m", "pip", "install", "-q", *pip_args],
                check=True,
            )

        # Record success on disk so the restarted kernel skips reinstalling.
        with open(_INSTALL_MARKER, "w", encoding="utf-8") as marker:
            marker.write("installed\n")
        os.environ['QWEN_VL_RESTARTED'] = 'true'

        print("Installation complete. Restarting the runtime to load the new packages...")
        # os._exit() does not flush stdio buffers; flush so the message is shown.
        sys.stdout.flush()
        # Terminate the kernel process so the notebook frontend restarts it and
        # later cells import the freshly installed packages.
        os._exit(0)

except (subprocess.CalledProcessError, OSError) as e:
    print(f"Initial installation failed: {e}")

### --- 2. Imports and Model Loading (with 4-bit Quantization) ---

In [None]:
import torch
import warnings
import random # Added for sampling the dataset
from PIL import Image
from datasets import load_dataset # Added for loading CIFAR-100
from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor, BitsAndBytesConfig

# Keep notebook output readable by silencing library warnings.
warnings.filterwarnings('ignore')

# The 3B checkpoint is used for compatibility and speed on T4 GPUs.
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"

# float16 when a CUDA GPU is present, otherwise float32 on the CPU.
device, dtype = ("cuda", torch.float16) if torch.cuda.is_available() else ("cpu", torch.float32)

print("\n".join([
    "--- Environment Setup ---",
    f"Device: {device}",
    f"Loading Model: {MODEL_ID} (with 4-bit Quantization)",
    "-" * 50,
]))

model = None
try:
    # NF4 4-bit weights with float16 compute; quantizing is what keeps the
    # model from running into CUDA out-of-memory errors.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    # Tokenizer and multimodal processor for the same checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    processor = AutoProcessor.from_pretrained(MODEL_ID)

    # Let device_map="auto" place the quantized weights, then switch the
    # model to inference (eval) mode.
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        MODEL_ID,
        quantization_config=bnb_config,
        torch_dtype=dtype,
        device_map="auto",
    )
    model.eval()
    print("Model loaded successfully in 4-bit precision.")
except Exception as load_error:
    # Leave `model` as None so downstream cells report the failure gracefully.
    print(f"Error loading model or libraries: {load_error}")
    print("Please ensure you have a T4 GPU enabled and the runtime has been restarted after installation.")
    model = None

### --- 3. Zero-Shot Classification Function ---

In [None]:
def zero_shot_classify(image: Image.Image, candidate_labels: list):
    """Classify ``image`` by asking the VLM to pick one of ``candidate_labels``.

    The task is framed as visual question answering: the prompt lists every
    candidate label and asks the model to state exactly one of them.

    Args:
        image: PIL image to classify.
        candidate_labels: Label strings the model must choose from.

    Returns:
        The model's answer (only the newly generated text, with surrounding
        whitespace, quotes and periods removed), or a human-readable error
        message if the model is not loaded, no labels were given, or
        inference raised an exception.
    """
    if model is None:
        return "Model failed to load. Cannot run inference."
    if not candidate_labels:
        return "No candidate labels provided. Cannot run inference."

    labels_str = ", ".join(f"'{label}'" for label in candidate_labels)

    # The classification prompt is framed as a VQA task.
    question = (
        f"What is the most accurate classification of the object shown in this image? "
        f"Choose only one answer from the following candidates: {labels_str}. State only the chosen label."
    )

    print(f"\n--- Running Classification ---")
    print(f"Image object passed (Size: {image.size})")
    print(f"Candidate Labels: {len(candidate_labels)} total")

    # 1. Single-turn chat: the PIL image followed by the text question.
    conversation = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": question},
            ],
        }
    ]

    # 2. Tokenize, generate, and decode only the model's continuation.
    try:
        inputs = processor.apply_chat_template(
            conversation,
            tokenize=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        ).to(device)

        # Greedy decoding; the answer is a single short label.
        output_ids = model.generate(
            **inputs,
            do_sample=False,
            max_new_tokens=20,
        )

        # generate() returns prompt + continuation. Slice off the prompt tokens
        # before decoding: with skip_special_tokens=True the
        # "<|im_start|>assistant" marker is removed from the decoded text, so
        # splitting on it would never match and the whole transcript would be
        # returned instead of the label.
        prompt_len = inputs["input_ids"].shape[1]
        response_text = tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=True)

        # The prompt quotes each label, so the model may echo quotes or end
        # with a period; strip those along with whitespace.
        return response_text.strip().strip("'\".").strip()

    except Exception as e:
        return f"Error during model inference: {e}"

### --- 4. CIFAR-100 Dataset Loading and Execution ---

In [None]:
print("\n--- Loading CIFAR-100 Dataset (Test Split) ---")
# Load the CIFAR-100 test split (10,000 images spanning 100 fine-grained classes).
dataset = load_dataset('cifar100', split='test')

# The 100 fine-label class names serve as the zero-shot candidate labels.
candidate_labels = dataset.features['fine_label'].names

print(f"Total CIFAR-100 Test Images: {len(dataset)}")
print(f"Number of Candidate Labels: {len(candidate_labels)}")
print(f"Example Labels: {candidate_labels[:5]}...")

# How many test images to classify; each one runs a separate zero_shot_classify call.
NUM_SAMPLES = 5

# A private seeded RNG yields the same indices as random.seed(42) followed by
# random.sample, without resetting the global `random` state for later cells.
sample_indices = random.Random(42).sample(range(len(dataset)), min(NUM_SAMPLES, len(dataset)))


def _normalize_label(text):
    """Lower-case a label and treat underscores as spaces ('maple_tree' == 'Maple tree')."""
    return text.strip().lower().replace("_", " ")


print("\n--- Starting Zero-Shot Classification on Sample Images ---")

num_correct = 0
for i, index in enumerate(sample_indices, start=1):
    item = dataset[index]

    # Ground truth: the PIL image and the name of its fine-grained label.
    image_pil = item['img']
    expected_label = candidate_labels[item['fine_label']]

    # Perform classification using the PIL Image object.
    predicted_label = zero_shot_classify(image_pil, candidate_labels)
    if _normalize_label(predicted_label) == _normalize_label(expected_label):
        num_correct += 1

    print(f"\n[CIFAR-100 Test {i} / {len(sample_indices)}]")
    print(f"Expected Label:  {expected_label}")
    print(f"Predicted Label: {predicted_label}")
    print("-" * 50)

print(f"\nExact-match accuracy: {num_correct} / {len(sample_indices)}")