In [None]:
import os
import pandas as pd
import torch
import numpy as np
from PIL import Image
import requests
from io import BytesIO
from torch.utils.data import Dataset, DataLoader
from transformers import VisionEncoderDecoderConfig, VisionEncoderDecoderModel, ViTFeatureExtractor, AutoTokenizer
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import re

# Define unit symbols and their full names
# Unit spellings that may follow a number in extracted text, per entity type,
# mapped to the canonical full unit name to report. Key ORDER is significant:
# find_value_with_symbol tries the symbols of an entity in insertion order and
# returns on the first one that matches.
_LENGTH_UNITS = {'cm': 'centimeter', 'mm': 'millimeter', 'in': 'inch', 'm': 'meter', 'ft': 'foot'}

# Shared by item_weight (all entries) and maximum_weight_recommendation
# (all entries except 'gb' and 'l'), in this exact order.
_WEIGHT_UNITS = {
    'g': 'gram', 'mg': 'milligram', 'kg': 'kilogram', 'oz': 'ounce', 'lb': 'pound',
    't': 'ton', 'µg': 'microgram', 'ml': 'milliliter', 'GB': 'gigabyte', 'gb': 'gigabyte',
    'ct': 'carat', 'L': 'liter', 'l': 'liter', 'nit': 'nit', 'in': 'inch', 'qt': 'quart',
    'W': 'watt', 'mm': 'millimeter', 'cm': 'centimeter', 'in³': 'cubic inch', 'in3': 'cubic inch',
}

unit_symbols = {
    # Each dimension gets its own independent copy of the length table.
    'height': dict(_LENGTH_UNITS),
    'width': dict(_LENGTH_UNITS),
    'depth': dict(_LENGTH_UNITS),
    'item_volume': {
        'cup': 'cup', 'gal': 'gallon', 'oz': 'ounce', 'ml': 'milliliter',
        'ft³': 'cubic foot', 'ft3': 'cubic foot', 'fl oz': 'fluid ounce', 'dl': 'deciliter',
        'in³': 'cubic inch', 'in3': 'cubic inch', 'l': 'liter', 'qt': 'quart',
        'pt': 'pint', 'cl': 'centiliter',
    },
    'wattage': {
        'W': 'watt', 'w': 'watt', 'hp': 'horsepower', 'kW': 'kilowatt',
        'kWh': 'kilowatt hour', 'mAh': 'milliampere hour',
    },
    'voltage': {'V': 'volt'},
    'item_weight': dict(_WEIGHT_UNITS),
    'maximum_weight_recommendation': {
        abbr: full for abbr, full in _WEIGHT_UNITS.items() if abbr not in ('gb', 'l')
    },
}

class ImageTextDataset(Dataset):
    """Dataset of (remote image, target text) pairs for vision-to-text fine-tuning.

    Each row of ``dataframe`` is read for an ``image_link`` (URL) and a
    ``text`` column. The image is downloaded on every ``__getitem__`` call
    (no caching) and converted with ``feature_extractor``. The text is
    tokenized into fixed-length label ids with ``tokenizer``.
    """

    def __init__(self, dataframe, feature_extractor, tokenizer, max_target_length=128, download_timeout=30):
        self.dataframe = dataframe
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        # Labels are padded/truncated to exactly this many token ids.
        self.max_target_length = max_target_length
        # Seconds to wait for an image server; without a timeout a stalled
        # server would block the data loader (and training) indefinitely.
        self.download_timeout = download_timeout

    def __len__(self):
        return len(self.dataframe)

    @staticmethod
    def _mask_padding(label_ids, pad_token_id):
        """Return ``label_ids`` with every ``pad_token_id`` replaced by -100.

        -100 is the default ``ignore_index`` of ``torch.nn.CrossEntropyLoss``.
        Masking padding this way keeps the padded tail of each label sequence
        out of the training loss. If ``pad_token_id`` is None, the ids are
        returned unchanged (as a new list).
        """
        if pad_token_id is None:
            return list(label_ids)
        return [-100 if token_id == pad_token_id else token_id for token_id in label_ids]

    def __getitem__(self, idx):
        """Return ``{"pixel_values": tensor, "labels": LongTensor}`` for row ``idx``.

        Raises:
            requests.HTTPError: if the image URL returns an error status.
            requests.Timeout: if the download exceeds ``download_timeout``.
        """
        item = self.dataframe.iloc[idx]
        image_url = item['image_link']
        text = item['text']

        # Download and decode the image; the context manager releases the connection.
        with requests.get(image_url, timeout=self.download_timeout) as response:
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        pixel_values = self.feature_extractor(images=image, return_tensors="pt").pixel_values

        # Tokenize the target text to a fixed length, then hide padding from the loss.
        labels = self.tokenizer(text,
                                padding="max_length",
                                max_length=self.max_target_length,
                                truncation=True).input_ids
        labels = self._mask_padding(labels, self.tokenizer.pad_token_id)

        return {
            # Drop only the leading batch dimension added by return_tensors="pt".
            "pixel_values": pixel_values.squeeze(0),
            "labels": torch.tensor(labels)
        }

def fine_tune_model(train_df, val_df, model_name="microsoft/trocr-base-handwritten", output_dir="./fine_tuned_model"):
    """Fine-tune a pretrained VisionEncoderDecoder OCR model on image/text pairs.

    Args:
        train_df: DataFrame with ``image_link`` and ``text`` columns used for training.
        val_df: DataFrame with the same columns, evaluated once per epoch.
        model_name: Hugging Face model id (or local path) to start from.
        output_dir: Directory for Trainer checkpoints and for the final saved
            model, tokenizer and feature extractor.

    Returns:
        ``(model, feature_extractor, tokenizer)`` after training.
    """
    # Load pre-trained model and tokenizer
    model = VisionEncoderDecoderModel.from_pretrained(model_name)
    feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Training with labels builds decoder inputs by shifting the labels right.
    # That requires decoder_start_token_id and pad_token_id on the top-level
    # config. Some checkpoints define them only on the decoder sub-config
    # (NOTE(review): confirm for the chosen checkpoint). Fill in only the
    # values that are missing; never override values the checkpoint sets.
    config = model.config
    if config.decoder_start_token_id is None:
        start_id = getattr(config.decoder, "decoder_start_token_id", None)
        config.decoder_start_token_id = start_id if start_id is not None else tokenizer.cls_token_id
    if config.pad_token_id is None:
        pad_id = getattr(config.decoder, "pad_token_id", None)
        config.pad_token_id = pad_id if pad_id is not None else tokenizer.pad_token_id

    # Prepare datasets
    train_dataset = ImageTextDataset(train_df, feature_extractor, tokenizer)
    val_dataset = ImageTextDataset(val_df, feature_extractor, tokenizer)

    # Define training arguments
    training_args = Seq2SeqTrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        learning_rate=5e-5,
        per_device_train_batch_size=4,
        per_device_eval_batch_size=4,
        num_train_epochs=3,
        weight_decay=0.01,
        save_total_limit=3,
    )

    # Initialize trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        tokenizer=tokenizer,
    )

    # Fine-tune the model
    trainer.train()

    # Save the fine-tuned model alongside the preprocessing objects it needs.
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
    feature_extractor.save_pretrained(output_dir)

    return model, feature_extractor, tokenizer

def extract_text_from_image(image_url, model, feature_extractor, tokenizer, timeout=30):
    """Download the image at ``image_url`` and return the text ``model`` generates for it.

    Args:
        image_url: URL of the image to read.
        model: Encoder-decoder model used via ``model.generate``.
        feature_extractor: Converts the PIL image to ``pixel_values``.
        tokenizer: Decodes the generated token ids (special tokens skipped).
        timeout: Seconds to wait for the image server before failing.

    Returns:
        The decoded text of the first generated sequence.

    Raises:
        requests.HTTPError: if the URL returns an error status.
        requests.Timeout: if the download exceeds ``timeout``.
    """
    with requests.get(image_url, timeout=timeout) as response:
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")

    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # The model may live on an accelerator (e.g. after Trainer training), so
    # inputs must be moved to the model's device, or generate() fails with a
    # device mismatch.
    pixel_values = pixel_values.to(model.device)
    generated_ids = model.generate(pixel_values)
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    return generated_text

def find_value_with_symbol(text, entity_name, symbol_table=None):
    """Extract the first "<number> <unit>" mention for ``entity_name`` from ``text``.

    For ``'item_weight'``, a labelled weight (e.g. "Net Wt. 16 oz") is tried
    first, with a case-insensitive unit lookup. Otherwise, each unit symbol
    registered for the entity is tried in the table's insertion order. The
    first symbol found directly after a number wins. A symbol never matches
    where it is only the prefix of a longer registered symbol (so "5 kWh" is
    not read as "5 kW").

    Args:
        text: Text to search, e.g. OCR output.
        entity_name: Key into the symbol table, e.g. ``'item_weight'``.
        symbol_table: Optional ``{entity_name: {symbol: full_name}}`` mapping;
            defaults to the module-level ``unit_symbols``.

    Returns:
        ``"<value> <full unit name>"``, or ``''`` when nothing matches.
        Whole numbers gain a ".0" suffix and a decimal comma becomes a dot.
    """
    if symbol_table is None:
        symbol_table = unit_symbols

    def _format(raw):
        # Treat a decimal comma as a decimal point; whole numbers get ".0".
        number = raw.replace(',', '.')
        return f"{float(number):.1f}" if '.' not in number else number

    if entity_name == 'item_weight':
        net_weight_pattern = r'(?:net\s*(?:wt\.?|weight)|(?:total\s*)?weight|wt\.?|w\.)\s*[:()]?\s*(\d+(?:[.,]\d+)?)\s*([a-zA-Z]+)'
        match = re.search(net_weight_pattern, text, re.IGNORECASE)
        if match:
            value, unit = match.groups()
            # A lone "c" unit is treated as grams (presumably an OCR misread of "g").
            if unit.lower() == 'c':
                unit = 'g'
            weight_units = symbol_table.get('item_weight', {})
            full_unit_name = next((full for abbr, full in weight_units.items()
                                   if abbr.lower() == unit.lower()), unit)
            return f"{_format(value)} {full_unit_name}"

    symbols = symbol_table.get(entity_name, {})
    for symbol, unit_name in symbols.items():
        # Forbid matching a symbol that is only the prefix of a longer
        # registered one ('kW' in "5 kWh", 'g' in "5 gb", 'in' in "5 in3").
        # Otherwise the shorter symbol, tried first, would shadow the longer one.
        tails = [other[len(symbol):] for other in symbols
                 if other != symbol and other.startswith(symbol)]
        guard = f"(?!{'|'.join(map(re.escape, tails))})" if tails else ''
        pattern = rf'(\d+(?:[.,]\d+)?)\s*{re.escape(symbol)}{guard}'
        match = re.search(pattern, text)
        if match:
            return f"{_format(match.group(1))} {unit_name}"

    return ''

def process_dataframe(df, model, feature_extractor, tokenizer):
    """Run OCR plus unit extraction for every row of ``df``.

    Each row's ``image_link`` is read with ``extract_text_from_image``. The
    text is parsed for the row's ``entity_name`` with ``find_value_with_symbol``.
    The results are written to ``df['extracted_value']``: ``df`` is mutated
    in place and also returned.

    An empty ``df`` simply gains an empty ``extracted_value`` column.
    ``df.apply(..., axis=1)`` returns a DataFrame for empty input, and
    assigning that to one column raises.
    """
    df['extracted_value'] = [
        find_value_with_symbol(
            extract_text_from_image(image_link, model, feature_extractor, tokenizer),
            entity_name,
        )
        for image_link, entity_name in zip(df['image_link'], df['entity_name'])
    ]
    return df

if __name__ == "__main__":
    # Load every input up front so a missing or malformed CSV is reported
    # before the (long) fine-tuning run rather than after it.
    train_df = pd.read_csv('train_data.csv')
    val_df = pd.read_csv('val_data.csv')
    test_df = pd.read_csv('test_data.csv')

    # Fine-tune the model
    model, feature_extractor, tokenizer = fine_tune_model(train_df, val_df)

    # Extract values for the test set
    result_df = process_dataframe(test_df, model, feature_extractor, tokenizer)

    # Save the results
    result_df.to_csv('output_results.csv', index=False)
    print("Processing complete. Results saved to 'output_results.csv'")