In [1]:
import os
import json
import tqdm
import pandas as pd
from PIL import Image
from concurrent.futures import ThreadPoolExecutor

import torch
from torch.utils.data import Dataset, DataLoader
from peft import LoraConfig, get_peft_model
from transformers import AutoProcessor, BitsAndBytesConfig, Idefics2ForConditionalGeneration
from transformers import TrainingArguments, Trainer

In [2]:
# Inference target: the first CUDA GPU when one is visible, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda:0"
else:
    DEVICE = "cpu"

# Fine-tuned checkpoint whose weights are evaluated below.
MODEL_PATH = "./models/trained/train/checkpoint-1625"

In [3]:
# Load the processor from the base Idefics2 repo and the fine-tuned weights from MODEL_PATH.
processor = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    do_image_splitting=False,
    use_fast=True
)
# Batched generation with a decoder-only model needs left padding, so that every prompt
# ends exactly where generation begins. Set it explicitly instead of relying on the
# default stored in the tokenizer config.
processor.tokenizer.padding_side = "left"

model = Idefics2ForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    device_map=DEVICE
)

# Activate any PEFT adapters attached to the loaded model.
# NOTE(review): assumes the checkpoint at MODEL_PATH contains adapter weights — confirm.
model.enable_adapters()


Chat templates should be in a 'chat_template.json' file but found key='chat_template' in the processor's config. Make sure to move your template to its own file.
Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


In [4]:
def load_image(image_path):
    """Open ``image_path`` and return it decoded as an RGB image.

    Args:
        image_path: Path of the image file to read.

    Returns:
        ``(image_path, image)`` on success, where ``image`` is an RGB
        ``PIL.Image.Image``. Returns ``None`` if the file cannot be opened or
        decoded; the error is printed, so callers must handle ``None``.
    """
    try:
        # The context manager closes the underlying file as soon as convert() has
        # produced its converted copy, instead of leaving that to Pillow/GC.
        with Image.open(image_path) as img:
            return (image_path, img.convert('RGB'))
    except Exception as e:
        # Best-effort: one unreadable file should not abort the whole bulk load.
        print(f"Error loading image {image_path}: {e}")
        return None

def process_images_concurrently(folder_path, max_workers=8, extensions=('.png', '.jpg', '.jpeg')):
    """Load every image file in ``folder_path`` using a thread pool.

    Args:
        folder_path: Directory to scan (not recursive).
        max_workers: Number of loader threads.
        extensions: File suffixes to load. They are matched case-insensitively,
            so files such as ``photo.JPG`` are not silently skipped.

    Returns:
        dict mapping ``os.path.join(folder_path, filename)`` to the image returned
        by ``load_image``. Files for which ``load_image`` returns ``None`` (it could
        not open them) are left out and counted, instead of crashing the run.
    """
    suffixes = tuple(ext.lower() for ext in extensions)
    image_paths = [
        os.path.join(folder_path, name)
        for name in os.listdir(folder_path)
        if name.lower().endswith(suffixes)
    ]

    images = {}
    failed = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        loaded = executor.map(load_image, image_paths)
        # Wrap the lazy map in tqdm for a progress bar.
        for result in tqdm.tqdm(loaded, total=len(image_paths), desc="Loading images"):
            if result is None:
                # load_image already printed the reason for this file.
                failed += 1
                continue
            path, image = result
            images[path] = image
    if failed:
        print(f"Skipped {failed} of {len(image_paths)} images that failed to load")
    return images
    
# Preload all test images into memory once, keyed by their path under ./processed/test.
images_dict = process_images_concurrently(
    "./processed/test",
    max_workers=8,
)

Loading images: 100%|██████████| 90666/90666 [01:00<00:00, 1509.10it/s]


In [5]:
class InferenceDataset(Dataset):
    """Map-style dataset pairing each test row with its preloaded image.

    Each item is a dict with keys ``'image'``, ``'entity_name'`` and ``'id'``.

    Args:
        df: DataFrame with at least ``'image_link'`` and ``'entity_name'`` columns.
        images: Mapping from image path to loaded image. When ``None`` (the default),
            the notebook-level ``images_dict`` is looked up at access time.
        image_dir: Directory the images were loaded from. It is joined with the
            basename of each row's ``image_link`` to form the lookup key.
    """

    def __init__(self, df, images=None, image_dir="./processed/test"):
        self.df = df
        self.images = images
        self.image_dir = image_dir

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        image_path = os.path.join(self.image_dir, os.path.basename(row['image_link']))
        # Resolve the global lazily (no `global` statement needed for a read), so a
        # re-run of the image-loading cell is picked up.
        images = self.images if self.images is not None else images_dict
        try:
            image = images[image_path]
        except KeyError:
            # The mapping only holds images that were found and loaded successfully.
            raise KeyError(f"No preloaded image for {image_path!r} (row {idx})") from None
        return {
            'image': image,
            'entity_name': row['entity_name'],
            # NOTE(review): this is the positional row number, not df.index or an
            # 'index' column — confirm it matches what the submission expects.
            'id': idx
        }


In [6]:
# Allowed output units per entity type. Entity types sharing a unit family are built
# with a comprehension; the set literal is re-evaluated per key, so every key owns an
# independent set.
entity_unit_map = {
    # Linear dimensions.
    **{
        dimension: {'centimetre', 'foot', 'inch', 'metre', 'millimetre', 'yard'}
        for dimension in ('width', 'depth', 'height')
    },
    # Mass-like quantities.
    **{
        mass_key: {'gram', 'kilogram', 'microgram', 'milligram', 'ounce', 'pound', 'ton'}
        for mass_key in ('item_weight', 'maximum_weight_recommendation')
    },
    'voltage': {'kilovolt', 'millivolt', 'volt'},
    'wattage': {'kilowatt', 'watt'},
    'item_volume': {
        'centilitre', 'cubic foot', 'cubic inch', 'cup', 'decilitre',
        'fluid ounce', 'gallon', 'imperial gallon', 'litre', 'microlitre',
        'millilitre', 'pint', 'quart',
    },
}

In [7]:
def _format_units(units):
    """Render a collection of unit names in ``repr(set)`` style, in sorted order.

    ``str`` hashes are randomized per interpreter process (see PYTHONHASHSEED), so
    formatting the set directly lists its units in a different order from run to
    run. Sorting makes the prompt text reproducible while keeping the
    ``{'a', 'b'}`` appearance of ``str(set)``.
    """
    return "{" + ", ".join(repr(unit) for unit in sorted(units)) + "}"


def collate_fn(batch):
    """Build a model-ready batch from dataset items.

    Args:
        batch: list of dicts with ``'image'``, ``'entity_name'`` and ``'id'`` keys.

    Returns:
        ``(inputs, entity_names, ids)``: ``inputs`` is the processor output (PyTorch
        tensors, padded to the longest prompt), and the two lists are aligned with
        its rows.
    """
    images_ = [item['image'] for item in batch]
    entity_names = [item['entity_name'] for item in batch]
    ids = [item['id'] for item in batch]

    questions = [f'What is the {name.replace("_", " ")} of the product?' for name in entity_names]

    # NOTE(review): this text (including its leading newline and indentation) reaches the
    # model verbatim; presumably it must match the fine-tuning prompt — confirm.
    system_prompts = [f"""
    1. Report the value and unit exactly as they appear in the image.
    2. If the feature is not visible respond with an empty string.
    3. Provide your answer in the format: "value unit" (e.g., "500 gram" or "2.5 inch").
    4. Acceptable units: {_format_units(entity_unit_map[name])}
    """ for name in entity_names]

    texts = []
    images = []
    for system_prompt, question, image in zip(system_prompts, questions, images_):
        messages = [
            {
                "role": "user",
                "content": [
                    {'type': 'text', 'text': system_prompt},
                    {'type': 'image'},
                    {'type': 'text', 'text': question}
                ]
            }
        ]
        text = processor.apply_chat_template(messages, add_generation_prompt=True)
        texts.append(text.strip())
        # One image per sample, wrapped in a list as the processor expects per prompt.
        images.append([image])

    inputs = processor(text=texts, images=images, return_tensors="pt", padding=True)

    return inputs, entity_names, ids


In [12]:
def generate_responses(model, dataloader, max_new_tokens=30, device=None):
    """Run batched generation over ``dataloader`` and collect the decoded answers.

    Args:
        model: Generation model; it is switched to eval mode here.
        dataloader: Yields ``(inputs, entity_names, ids)`` tuples, where ``inputs``
            is a mapping of tensors with an ``"input_ids"`` entry.
        max_new_tokens: Upper bound on tokens generated per sample.
        device: Device each batch is moved to; defaults to the notebook-level ``DEVICE``.

    Returns:
        list of dicts ``{'id', 'entity_name', 'generated_text'}``, one per sample,
        in dataloader order.
    """
    if device is None:
        device = DEVICE
    model.eval()
    results = []
    with torch.no_grad():
        for batch, entity_names, ids in tqdm.tqdm(dataloader, desc="Processing batches"):
            batch = {k: v.to(device) for k, v in batch.items()}

            generated_ids = model.generate(
                **batch,
                max_new_tokens=max_new_tokens,
            )

            # generate() returns the (padded) prompt followed by the new tokens, so
            # dropping the first input_ids.size(1) columns keeps only the answer.
            prompt_len = batch["input_ids"].size(1)
            generated_texts = processor.batch_decode(generated_ids[:, prompt_len:], skip_special_tokens=True)
            results.extend(
                {'id': id_, 'entity_name': entity_name, 'generated_text': generated_text}
                for id_, entity_name, generated_text in zip(ids, entity_names, generated_texts)
            )
    return results


In [13]:
# DataLoader sizing for the inference run.
BATCH_SIZE = 65   # samples per generate() call
NUM_WORKERS = 32  # DataLoader worker processes

# Test split whose rows are answered below.
df = pd.read_csv('./dataset/test.csv')

In [16]:
dataset = InferenceDataset(df)
# shuffle=False (the DataLoader default) keeps batches in df row order.
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    collate_fn=collate_fn,
)

In [17]:
# Full pass over the test set (about 15 s per batch in the recorded run).
results = generate_responses(model=model, dataloader=dataloader)

Processing batches: 100%|██████████| 462/462 [1:55:28<00:00, 15.00s/it]


In [18]:
# Save results
results_df = pd.DataFrame(results)
results_df.to_csv('inference_results.csv', index=False)
print("Inference completed. Results saved to 'inference_results.csv'")


Inference completed. Results saved to 'inference_results.csv'
