In [1]:
!pip install --upgrade transformers
!pip install --upgrade bert_score

[0m

In [None]:
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import numpy as np
from transformers import LlavaProcessor, LlavaForConditionalGeneration
import torch
import os
import json
import gc
from tqdm import tqdm
from bert_score import score as bert_score
from datasets import load_dataset
import re  # Import the regular expressions module

# Check if CUDA is available
print(torch.cuda.is_available())
print(torch.version.cuda)

# Create data and result directories
os.makedirs("../data", exist_ok=True)
os.makedirs("./result", exist_ok=True)

# Load the dataset
data_id = "Trelis/chess_pieces"
dataset = load_dataset(data_id, cache_dir="../data")

# Load the model and processor
model_id = "llava-hf/llava-1.5-7b-hf"

torch.backends.cuda.matmul.allow_tf32 = True
print(f"TF32 Enabled: {torch.backends.cuda.matmul.allow_tf32}")

# Use the processor and model classes that match the model
processor = LlavaProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="sdpa"
)

# Get and set parameters from the model configuration
patch_size = model.config.vision_config.patch_size
vision_feature_select_strategy = getattr(model.config, 'vision_feature_select_strategy', 'default')
processor.patch_size = patch_size
processor.vision_feature_select_strategy = vision_feature_select_strategy

# Initialize variables for evaluation results
results = []
bleu_scores = []
bert_scores_p = []
bert_scores_r = []
bert_scores_f1 = []
correct_predictions = 0
total_predictions = 0

# Define the smoothing function
smoothie = SmoothingFunction().method4

# Evaluate samples in the dataset
for i, example in enumerate(tqdm(dataset["test"], desc="Processing examples")):
    try:
        print(f"Example {i + 1}/{len(dataset['test'])}")

        image = example.get("image")
        if image is None:
            print(f"Skipping example {i + 1}: No image found")
            continue

        conversation = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Describe the chess pieces in the image, including their types and colors."},
                    {"type": "image"},
                ],
            },
        ]

        prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

        # Print the prompt
        print(f"Prompt for Example {i + 1}:\n{prompt}\n")

        inputs = processor(images=image, text=prompt, return_tensors="pt")
        inputs = {key: value.to(model.device) for key, value in inputs.items()}

        expected_output = example.get("caption")
        if expected_output is None:
            print(f"Skipping example {i + 1}: No caption found")
            continue

        # Generate model output
        output = model.generate(
            **inputs,
            max_new_tokens=128,
            no_repeat_ngram_size=2  # Prevent repetition
            # early_stopping=True  # Remove this parameter
        )
        full_output = processor.decode(output[0], skip_special_tokens=True)

        # Print the full model output
        print(f"Full Model Output for Example {i + 1}:\n{full_output}\n")

        # Extract the assistant's reply using regular expressions
        match = re.search(r'ASSISTANT:\s*(.*)', full_output, re.DOTALL)
        if match:
            model_output = match.group(1).strip()
        else:
            model_output = full_output.strip()

        # Print the processed model output
        print(f"Processed Model Output for Example {i + 1}:\n{model_output}\n")

        # Check if the prediction is correct using BERT Score's F1 value as the criterion
        P, R, F1 = bert_score([model_output], [expected_output], lang="en", rescale_with_baseline=True)

        if F1.item() > 0.5:  # Adjust the threshold as needed
            is_correct = True
            correct_predictions += 1
        else:
            is_correct = False

        # Compute BLEU score using the smoothing function
        reference = [expected_output.split()]
        hypothesis = model_output.split()
        bleu = sentence_bleu(reference, hypothesis, smoothing_function=smoothie)
        bleu_scores.append(bleu)

        # Store BERT Scores
        bert_scores_p.append(P.item())
        bert_scores_r.append(R.item())
        bert_scores_f1.append(F1.item())

        # Store the result for this sample
        results.append({
            "Example": i + 1,
            "Expected Output": expected_output,
            "Model Output": model_output,
            "BLEU Score": bleu,
            "BERT Score": {
                "P": P.item(),
                "R": R.item(),
                "F1": F1.item()
            },
            "Correct Prediction": is_correct
        })

        total_predictions += 1

        # Release resources
        gc.collect()
        torch.cuda.empty_cache()

    except Exception as e:
        print(f"Error processing example {i + 1}: {e}")

# Compute overall results
if total_predictions > 0:
    accuracy = correct_predictions / total_predictions
    average_bleu = np.mean(bleu_scores)
    average_bert_score = {
        "P": np.mean(bert_scores_p),
        "R": np.mean(bert_scores_r),
        "F1": np.mean(bert_scores_f1)
    }

    # Define output_data
    output_data = {
        "results": results,
        "summary": {
            "Accuracy": accuracy,
            "Average BLEU Score": average_bleu,
            "Average BERT Score": average_bert_score
        }
    }

    # Save results to a JSON file
    safe_data_id = data_id.replace("/", "_")
    output_file = f"./result/llava-1.5-7b-hf_{safe_data_id}_results.json"

    with open(output_file, "w") as f:
        json.dump(output_data, f, indent=4)

    print(f"Accuracy: {accuracy * 100:.2f}%")
    print(f"Average BLEU Score: {average_bleu:.4f}")
    print(f"Average BERT Score: P={average_bert_score['P']:.4f}, R={average_bert_score['R']:.4f}, F1={average_bert_score['F1']:.4f}")
else:
    print("No valid examples processed. Please check the dataset or the processing steps.")

True
12.1
TF32 Enabled: True


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Processing examples:   0%|          | 0/3 [00:00<?, ?it/s]

Example 1/3
Prompt for Example 1:
USER: <image>
Describe the chess pieces in the image, including their types and colors. ASSISTANT:

Full Model Output for Example 1:
USER:  
Describe the chess pieces in the image, including their types and colors. ASSISTANT: The chest pieces on the wooden table include a knight, a rook, and a pawn. The knights are white, the rooks are black, while the paawns are brown.

Processed Model Output for Example 1:
The chest pieces on the wooden table include a knight, a rook, and a pawn. The knights are white, the rooks are black, while the paawns are brown.



Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Processing examples:  33%|███▎      | 1/3 [00:02<00:05,  2.66s/it]

Example 2/3
Prompt for Example 2:
USER: <image>
Describe the chess pieces in the image, including their types and colors. ASSISTANT:

Full Model Output for Example 2:
USER:  
Describe the chess pieces in the image, including their types and colors. ASSISTANT: The chest pieces on the wooden table include a knight, a rook, and a pawn. The knights are brown, the rooks are black, while the paons are white.

Processed Model Output for Example 2:
The chest pieces on the wooden table include a knight, a rook, and a pawn. The knights are brown, the rooks are black, while the paons are white.



Some weights of RobertaModel were not initialized from the model checkpoint at roberta-large and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
