In [1]:
import time
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from PIL import Image
import torch
import json

# Set device to GPU if available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Load LLaVA-v1.6 model and processor
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf")
model = LlavaNextForConditionalGeneration.from_pretrained("llava-hf/llava-v1.6-vicuna-13b-hf", torch_dtype=torch.float32, low_cpu_mem_usage=True,load_in_4bit=True
                                                          #,use_flash_attention_2=True
                                                         )

# Function to load the shuffled dataset
def load_dataset(dataset_file):
    data = []
    with open(dataset_file, 'r') as file:
        for line in file:
            try:
                data.append(eval(line.strip()))  # Parse each line as a Python dictionary
            except SyntaxError as e:
                print(f"Error parsing line: {line}, Error: {e}")
    return data

# Function to process each entry in the dataset and generate text output using LLaVA
def generate_output_from_dataset(dataset):
    results = {}
    prompt = "Based on the image and caption, analyze if they reflect symptoms of anxiety or depression. "
    count=0
    start =time.time()
    for entry in dataset[:10000]:
        
        count+=1
        if(count%10==0):
            
            print(count+10000, time.time()-start)
            start =time.time()
        
        image_path = entry['image_path']
        caption = entry['caption']
        
        try:
            # Load image
            image = Image.open(image_path).convert("RGB")

            # Create a conversation input for LLaVA (image + caption)
            conversation = [
                {"role": "user", 
                 "content": [{"type": "text", "text": prompt + f" Caption: {caption}"}, {"type": "image"}]}
            ]
            prompt_text = processor.apply_chat_template(conversation, add_generation_prompt=False)

            # Preprocess the image and text
            inputs = processor(images=image, text=prompt_text, return_tensors="pt").to(device)

            # Generate the output from the model
            output = model.generate(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask, max_new_tokens=300)
            
            # Decode the output and store the result
            result = processor.decode(output[0], skip_special_tokens=True)
            results[image_path] = result
            #print(f"Generated output for {image_path}: {result}")

        except Exception as e:
            print(f"Error processing {image_path}: {e}")

    return results

# Load the shuffled dataset
dataset_file = "SHUFFLED_DATASET.txt"  # Replace with the path to your shuffled dataset
dataset = load_dataset(dataset_file)

# Generate text output for the dataset using LLaVA
generated_outputs = generate_output_from_dataset(dataset)

# Save the generated outputs to a file
with open('LLMgenerated_outputs.json', 'w') as f:
    json.dump(generated_outputs, f)

# Print a sample output
for image_path, output in generated_outputs.items():
    print(f"Image: {image_path}, Generated Output: {output}")
    break  # Just print the first example


cuda


The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.


10010 86.46623754501343


KeyboardInterrupt: 