In [38]:
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Phi-3.5-vision checkpoint on the Hugging Face Hub. trust_remote_code=True lets
# transformers execute the model/processor code published in that repository.
model_id = "microsoft/Phi-3.5-vision-instruct"

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",          # let accelerate place the weights on available devices
    torch_dtype=torch.float16,  # half precision to reduce memory use
    trust_remote_code=True,
)
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
)

# Decoder-only generation expects left padding when prompts of different
# lengths are batched, so generated tokens follow directly after each prompt.
processor.tokenizer.padding_side = 'left'

# Chat-template fragments used by run_example to build a single-turn prompt.
user_prompt, assistant_prompt, prompt_suffix = '<|user|>\n', '<|assistant|>\n', "<|end|>\n"


from PIL import ImageOps

def _pad_to_aspect_ratio(img, target_aspect_ratio=1.2):
    """Pad ``img`` with black borders so its width/height ratio approaches ``target_aspect_ratio``.

    The image is never cropped or rescaled; only borders are added. Padding is
    split between opposite edges, and when the required amount is odd the extra
    pixel goes to the right (or bottom) edge so no padding is lost.
    """
    width, height = img.size
    aspect_ratio = width / height

    if aspect_ratio < target_aspect_ratio:
        # Too narrow for the target ratio: add columns on the left and right.
        extra = int(target_aspect_ratio * height) - width
        left = extra // 2
        return ImageOps.expand(img, (left, 0, extra - left, 0))

    if aspect_ratio > target_aspect_ratio:
        # Too wide for the target ratio: add rows on the top and bottom.
        extra = int(width / target_aspect_ratio) - height
        top = extra // 2
        return ImageOps.expand(img, (0, top, 0, extra - top))

    return img


@torch.inference_mode()
def run_example(image_path, text_input=None, target_aspect_ratio=1.2, max_new_tokens=1000):
    """Ask the vision model one question about one image and return its answer.

    Uses the module-level ``model``, ``processor``, ``user_prompt``,
    ``prompt_suffix`` and ``assistant_prompt``.

    Args:
        image_path: Path to the image file. The image is converted to RGB.
        text_input: Question text placed after the ``<|image_1|>`` token.
            ``None`` is treated as an empty question, not the literal "None".
        target_aspect_ratio: Width/height ratio the image is padded towards
            before it is passed to the processor.
        max_new_tokens: Upper bound on the number of generated tokens
            (defaults to 1000, the previously hard-coded value).

    Returns:
        The decoded continuation for the single prompt. Prompt tokens are
        stripped and special tokens are skipped.
    """
    # Close the file handle promptly; convert() returns an independent copy.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert("RGB")

    image = _pad_to_aspect_ratio(image, target_aspect_ratio=target_aspect_ratio)

    # Avoid leaking the string "None" into the prompt when no question is given.
    question = "" if text_input is None else text_input
    prompt = f"{user_prompt}<|image_1|>\n{question}{prompt_suffix}{assistant_prompt}"

    # Tokenize text + image and move every tensor to the model's device.
    inputs = processor(prompt, image, return_tensors="pt").to(model.device)

    # Autograd is already disabled by the inference_mode decorator.
    generate_ids = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        eos_token_id=processor.tokenizer.eos_token_id,
    )

    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_length = inputs['input_ids'].shape[1]
    response = processor.batch_decode(
        generate_ids[:, prompt_length:],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]

    return response



Loading checkpoint shards: 100%|██████████| 2/2 [00:03<00:00,  1.55s/it]


In [39]:
# Smoke test: ask one free-form question about a single training image.
# NOTE(review): the recorded output shows the model reading "No hallucinations"
# as "hallucinogens"; consider rephrasing the question.
image_path = '/home/images/train/21jpeKcd2pL.jpg'
text_input = "Describe the Voltage, No hallucinations"
output_text = run_example(image_path, text_input=text_input)
print("Model Inference Output: ", output_text)

Model Inference Output:  Sorry, I cannot answer this question. The image shows a single AA battery with a yellow casing, lying on a plain surface. There is text printed on the battery, but it is not clear enough to read without potential hallucinogens affecting perception. However, the OCR text provided reads "VOLTAGE 1.5V", which gives us the voltage information without the need for hallucinogens.


In [40]:
from tqdm import tqdm
import pandas as pd
import os

# Annotation CSV, image directory and output location for this partial run.
annotations_file = "/home/student_resource3/dataset/train.csv"
IMAGE_DIR = "/home/images/train"
RESULTS_CSV = "/home/student_resource3/dataset/inference_results_partial.csv"

# iloc window of rows to process (end exclusive): positions 300..319, i.e. 20 rows.
START_ROW, END_ROW = 300, 320

df = pd.read_csv(annotations_file)

# One dict per successfully processed row; the DataFrame is built once at the end.
results = []

# Run on a slice of the DataFrame only, to keep the experiment quick.
data = df.iloc[START_ROW:END_ROW]

for idx, row in tqdm(data.iterrows(), total=len(data), desc="Inference Progress"):
    try:
        image_id = os.path.basename(row['image_link'])
        entity_name = row.get('entity_name')
        entity_value = row.get('entity_value')

        # Series.get() returns None only when the column is absent; a blank
        # CSV cell is read as NaN. pd.isna() catches both cases.
        if pd.isna(entity_name):
            print(f"Missing entity_name at row {idx}, skipping...")
            continue

        image_path = os.path.join(IMAGE_DIR, image_id)

        # Ask for the requested entity formatted as "<value> <unit>".
        # NOTE(review): the wording "if no unit is present unit give ..." is
        # garbled; rewording it may improve answers, but would change results.
        text_input = f"Extract the value of {entity_name} from the image. Ensure the extracted value is followed by its corresponding unit, if no unit is present unit give the next best one with unit. Format the result as '<|value|> <|unit|>'."

        output_text = run_example(image_path, text_input)

        results.append({
            'image_id': image_id,
            'entity_name': entity_name,
            'entity_value': entity_value,
            'result': output_text,
        })

    except Exception as e:
        # Best-effort batch run: report the failing row and continue with the rest.
        print(f"Error at row {idx}: {e}")

# Convert results to a DataFrame and save them to CSV.
results_df = pd.DataFrame(results)
results_df.to_csv(RESULTS_CSV, index=False)
print("Inference results for partial dataset saved to CSV.")


Inference Progress: 100%|██████████| 20/20 [00:12<00:00,  1.56it/s]

Inference results for partial dataset saved to CSV.



