In [None]:
import pandas as pd
import re
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import pytesseract
from PIL import Image
import requests
from io import BytesIO

# Allowed units for each entity type; predictions outside these sets are rejected.
# Entities measuring the same physical quantity share one unit vocabulary.
_LENGTH_UNITS = {"centimetre", "foot", "millimetre", "metre", "inch", "yard"}
_WEIGHT_UNITS = {"milligram", "kilogram", "microgram", "gram", "ounce", "ton", "pound"}
_VOLUME_UNITS = {
    "cubic foot", "microlitre", "cup", "fluid ounce", "centilitre",
    "imperial gallon", "pint", "decilitre", "litre", "millilitre",
    "quart", "cubic inch", "gallon",
}

# Each entity gets its own copy of the set so entries stay independent.
entity_unit_map = {
    **{name: set(_LENGTH_UNITS) for name in ("width", "depth", "height")},
    **{name: set(_WEIGHT_UNITS) for name in ("item_weight", "maximum_weight_recommendation")},
    "voltage": {"millivolt", "kilovolt", "volt"},
    "wattage": {"kilowatt", "watt"},
    "item_volume": set(_VOLUME_UNITS),
}

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)
print(f"Using device: {device}")

# Preprocess text to extract numbers and units
def extract_number(text):
    """Extract the first number found in the text.

    Args:
        text: Arbitrary input; anything that is not a ``str`` is treated as
            containing no number.

    Returns:
        The first unsigned integer or decimal in ``text`` as a string
        (e.g. ``"12.5"``), or ``""`` when there is none.
    """
    if not isinstance(text, str):
        return ""
    # The fractional part is only consumed when digits follow the dot, so a
    # sentence-ending period ("5. Keep dry") is not glued onto the number.
    match = re.search(r'\d+(?:\.\d+)?', text)
    return match.group(0) if match else ""

def extract_units(text):
    """Extract the unit from the text.

    The unit is the last run of space-separated alphabetic words, so
    multi-word units such as "cubic foot" or "fluid ounce" are returned
    whole rather than truncated to their final word.

    Args:
        text: Arbitrary input; anything that is not a ``str`` yields ``""``.

    Returns:
        The trailing unit phrase with internal whitespace collapsed to single
        spaces, or ``""`` when the text contains no letters.
    """
    if not isinstance(text, str):
        return ""
    # Each match is one or more alphabetic words separated only by blanks;
    # digits and punctuation end a phrase.
    phrases = re.findall(r'[a-zA-Z]+(?:[ \t]+[a-zA-Z]+)*', text)
    if not phrases:
        return ""
    return " ".join(phrases[-1].split())

# Download image from URL
def download_image(image_url, timeout=10):
    """Download an image from a URL.

    Args:
        image_url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout a single stalled connection would block the whole run.

    Returns:
        The opened PIL image, or ``None`` if the download or decoding failed.
    """
    try:
        response = requests.get(image_url, timeout=timeout)
        # Treat HTTP error statuses (404, 500, ...) as failures instead of
        # trying to decode an error page as an image.
        response.raise_for_status()
        return Image.open(BytesIO(response.content))
    except Exception as e:
        # Deliberate best-effort boundary: one bad URL must not abort the
        # batch, so the failure is reported and signalled with None.
        print(f"Error downloading image: {e}")
        return None

# Extract text from image using OCR
def extract_text_from_image(image):
    """Run OCR (via pytesseract) on an image and return the recognised text.

    Returns the text with surrounding whitespace removed, or an empty string
    if OCR raised for any reason (the error is printed).
    """
    try:
        ocr_output = pytesseract.image_to_string(image)
        stripped = ocr_output.strip()
    except Exception as e:
        print(f"Error extracting text from image: {e}")
        return ""
    else:
        return stripped

# Modify process_image to download image and extract text
def process_image(row):
    """Download the image referenced by a row and OCR it.

    Reads ``image_link`` and ``entity_name`` from the row and uses the row's
    ``name`` as its index. Returns a dict with keys ``index``,
    ``entity_name`` and ``extracted_text``; the text is empty when the image
    could not be downloaded.
    """
    image_url = row['image_link']
    # Start from the "nothing extracted" result and fill in the text on success.
    result = {
        'index': row.name,
        'entity_name': row['entity_name'],
        'extracted_text': '',
    }

    image = download_image(image_url)
    if not image:
        print(f"Failed to process image URL: {image_url}")
        return result

    result['extracted_text'] = extract_text_from_image(image)
    return result

# Validate the predicted unit
def validate_unit(entity_name, unit):
    """Return ``unit`` if it is allowed for ``entity_name``, else ``""``.

    Entities missing from ``entity_unit_map`` have no allowed units, so every
    unit is rejected for them.
    """
    allowed = entity_unit_map.get(entity_name, set())
    if unit in allowed:
        return unit
    return ""

# Load pre-trained BERT model for unit prediction
def load_bert_model(num_labels):
    """Build a BERT sequence classifier and its tokenizer.

    Both are loaded from the 'bert-base-uncased' checkpoint; the classifier is
    created with ``num_labels`` output classes and placed on the module-level
    ``device``. Returns ``(model, tokenizer)``.
    """
    print("Loading BERT model and tokenizer...")
    checkpoint = 'bert-base-uncased'
    classifier = BertForSequenceClassification.from_pretrained(checkpoint, num_labels=num_labels)
    bert_tokenizer = BertTokenizer.from_pretrained(checkpoint)

    # Put the weights on the selected device before handing the model back
    classifier.to(device)

    return classifier, bert_tokenizer

# Train the model
def train_model(model, dataloader, epochs=10):
    """Fine-tune the model with AdamW (lr=5e-5) for ``epochs`` passes.

    Each batch from ``dataloader`` must unpack into
    ``(input_ids, attention_mask, labels)``. The mean batch loss is printed
    after every epoch. Nothing is returned; the model is updated in place.
    """
    print("Training the BERT model...")
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
    model.train()
    for epoch_no in range(1, epochs + 1):
        batch_losses = []
        progress = tqdm(dataloader, desc=f"Epoch {epoch_no}/{epochs}")
        for input_ids, attention_mask, labels in progress:
            optimizer.zero_grad()
            # Move the batch onto the same device as the model
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device)
            # Passing labels makes the model compute the loss itself
            loss = model(input_ids, attention_mask=attention_mask, labels=labels).loss
            loss.backward()
            optimizer.step()
            batch_losses.append(loss.item())
        avg_loss = sum(batch_losses) / len(dataloader)
        print(f"Epoch {epoch_no} complete. Average Loss: {avg_loss:.4f}")

# Predict units for extracted text
def predict_units(extracted_df, bert_model, tokenizer, id_to_unit):
    """Predict units for the extracted texts.

    Args:
        extracted_df: DataFrame with ``index``, ``entity_name`` and
            ``extracted_text`` columns (as produced by ``process_image``).
        bert_model: Sequence classifier whose class ids are keys of
            ``id_to_unit``.
        tokenizer: Tokenizer matching ``bert_model``.
        id_to_unit: Mapping from class id to unit name.

    Returns:
        DataFrame with one row per input row and columns ``index`` and
        ``prediction``. A prediction is ``"<number> <unit>"`` when a number
        was found and the predicted unit is valid for the entity, otherwise
        ``""``.
    """
    predictions = []

    # Inference must run in eval mode, otherwise dropout left on by training
    # makes predictions noisy. The previous mode is restored on exit.
    was_training = bert_model.training
    bert_model.eval()
    try:
        for _, row in tqdm(extracted_df.iterrows(), total=len(extracted_df), desc="Predicting units"):
            index = row['index']
            entity_name = row['entity_name']
            extracted_text = row['extracted_text']

            # Non-string values (e.g. NaN) and blank text carry no prediction
            if not isinstance(extracted_text, str) or not extracted_text.strip():
                predictions.append({'index': index, 'prediction': ''})
                continue

            # Extract numerical value; without one no valid "<number> <unit>"
            # can be formed, so the model call is skipped.
            extracted_num = extract_number(extracted_text)
            if not extracted_num:
                predictions.append({'index': index, 'prediction': ''})
                continue

            # Tokenize and encode. OCR output can be arbitrarily long, so it
            # is truncated to the tokenizer's maximum length to stay within
            # the model's supported sequence length.
            encoded_input = tokenizer(extracted_text, truncation=True, return_tensors='pt').to(device)
            with torch.no_grad():
                outputs = bert_model(**encoded_input)
            predicted_unit_id = torch.argmax(outputs.logits, dim=1).item()

            # Map prediction to unit and validate
            predicted_unit = id_to_unit[predicted_unit_id]
            validated_unit = validate_unit(entity_name, predicted_unit)

            # Format result for item_value
            item_value = f"{extracted_num} {validated_unit}" if validated_unit else ""
            predictions.append({'index': index, 'prediction': item_value})
    finally:
        bert_model.train(was_training)

    return pd.DataFrame(predictions)

# Main function
def main():
    """Train the unit classifier, OCR the test images and write predictions.

    Reads the training and test CSVs from the Kaggle input directory,
    fine-tunes BERT to classify the unit of each ``entity_value``, reports
    accuracy on a held-out 20% split, then runs OCR and unit prediction on the
    test images and saves the result to ``test_out.csv``.
    """
    # Load and preprocess data
    train_data = pd.read_csv('/kaggle/input/amazon-ml-cleaned/train_clean.csv')
    test_data = pd.read_csv('/kaggle/input/amazon-ml/test.csv')
    # NOTE: for quick experiments, subsample here, e.g.
    # train_data.iloc[:50000] and test_data.iloc[:100].

    # Prepare features and labels
    train_data['numerical_value'] = train_data['entity_value'].apply(extract_number)
    train_data['unit'] = train_data['entity_value'].apply(extract_units)

    # Map units to numerical values. Sorting makes the id assignment
    # reproducible across runs (set iteration order of strings is not).
    unit_to_id = {unit: i for i, unit in enumerate(sorted(train_data['unit'].unique()))}
    id_to_unit = {i: unit for unit, i in unit_to_id.items()}

    # Convert units to numerical IDs
    train_data['unit_id'] = train_data['unit'].map(unit_to_id)

    # Split data
    X_train, X_val, y_train, y_val = train_test_split(
        train_data[['entity_value']],
        train_data['unit_id'],
        test_size=0.2,
        random_state=42
    )

    # Load BERT model
    bert_model, tokenizer = load_bert_model(num_labels=len(unit_to_id))

    # Tokenize training data
    def tokenize_data(texts, tokenizer):
        return tokenizer(texts, truncation=True, padding=True, return_tensors='pt')

    train_encodings = tokenize_data(X_train['entity_value'].tolist(), tokenizer)
    val_encodings = tokenize_data(X_val['entity_value'].tolist(), tokenizer)

    # Create dataloaders
    def create_dataloader(encodings, labels, batch_size=32, shuffle=True):
        dataset = torch.utils.data.TensorDataset(
            encodings['input_ids'], encodings['attention_mask'], torch.tensor(labels.values)
        )
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    train_loader = create_dataloader(train_encodings, y_train)
    # Evaluation does not depend on batch order, so no shuffling is needed
    val_loader = create_dataloader(val_encodings, y_val, shuffle=False)

    # Train the model
    train_model(bert_model, train_loader, epochs=5)  # Train for 5 epochs (increase if necessary)

    # Evaluate on the held-out split so that overfitting is visible; the
    # training loss alone says nothing about generalisation.
    bert_model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(val_loader, desc="Validating"):
            input_ids, attention_mask, labels = [b.to(device) for b in batch]
            logits = bert_model(input_ids, attention_mask=attention_mask).logits
            correct += (logits.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    if total:
        print(f"Validation accuracy: {correct / total:.4f}")

    # Process test images and convert results to DataFrame
    extracted_df = pd.DataFrame(test_data.apply(process_image, axis=1).tolist())

    # Predict units
    output_df = predict_units(extracted_df, bert_model, tokenizer, id_to_unit)

    # Save output
    output_df.to_csv('test_out.csv', index=False)
    print("Results saved to 'test_out.csv'")
    print(output_df.head())  # Print first few rows of the output CSV

# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

Using device: cuda
Loading BERT model and tokenizer...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Training the BERT model...


Epoch 1/5: 100%|██████████| 6442/6442 [09:04<00:00, 11.84it/s]


Epoch 1 complete. Average Loss: 0.0191


Epoch 2/5: 100%|██████████| 6442/6442 [09:04<00:00, 11.82it/s]


Epoch 2 complete. Average Loss: 0.0149


Epoch 3/5: 100%|██████████| 6442/6442 [09:04<00:00, 11.84it/s]


Epoch 3 complete. Average Loss: 0.0055


Epoch 4/5: 100%|██████████| 6442/6442 [09:04<00:00, 11.84it/s]


Epoch 4 complete. Average Loss: 0.0012


Epoch 5/5: 100%|██████████| 6442/6442 [09:04<00:00, 11.83it/s]


Epoch 5 complete. Average Loss: 0.0083
