In [1]:
# %% [code]
!pip install git+https://github.com/huggingface/transformers accelerate

Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-bncwbh58
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /tmp/pip-req-build-bncwbh58
  Resolved https://github.com/huggingface/transformers to commit 3927ffed31e3c0d2929bf98bd05b7c61fcc48b62
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting huggingface-hub==1.0.0.rc5 (from transformers==5.0.0.dev0)
  Downloading huggingface_hub-1.0.0rc5-py3-none-any.whl.metadata (14 kB)
Collecting tokenizers<=0.23.0,>=0.22.0 (from transformers==5.0.0.dev0)
  Downloading tokenizers-0.22.1-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=2.0.0->accelerate)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.1

In [2]:
# %% [code]
import torch
from transformers import AutoModel, AutoProcessor
from PIL import Image
import pandas as pd
from tqdm import tqdm
import requests
from io import BytesIO
import numpy as np
import re
import torch.nn.functional as F

# %% [code]
def preprocess_product_text(text, func=str.lower):
    """
    Processes product text to extract and format specific fields in the order:
    Units, Value, Item Name.
    """
    if not isinstance(text, str):
        return "" # Return empty string for non-string inputs
        
    # Remove emojis
    emoji_pattern = re.compile(
        "[" 
        "\U0001F600-\U0001F64F"  # emoticons
        "\U0001F300-\U0001F5FF"  # symbols & pictographs
        "\U0001F680-\U0001F6FF"  # transport & map symbols
        "\U0001F1E0-\U0001F1FF"  # flags (iOS)
        "\U00002700-\U000027BF"  # dingbats
        "\U0001F900-\U0001F9FF"  # Supplemental Symbols and Pictographs
        "\U00002600-\U000026FF"  # miscellaneous symbols
        "\U00002B00-\U00002BFF"  # miscellaneous symbols and arrows
        "]+", flags=re.UNICODE)
    text = emoji_pattern.sub(r'', text)
    
    # Extract required fields
    item_name_match = re.search(r'Item Name:\s*(.*)', text, re.IGNORECASE)
    item_name = func(item_name_match.group(1).strip()) if item_name_match else ''
    
    value_match = re.search(r'Value:\s*(.*)', text, re.IGNORECASE)
    value = func(value_match.group(1).strip()) if value_match else ''
    
    units_match = re.search(r'Units:\s*(.*)', text, re.IGNORECASE)
    units = func(units_match.group(1).strip()) if units_match else ''
    
    # Construct the output string in the desired order
    output_lines = [
        f"Units: {units}",
        f"Value: {value}",
        f"Item Name: {item_name}"
    ]
    
    return '\n'.join(output_lines)

# %% [code]
MODEL_ID = "google/siglip2-giant-opt-patch16-384"
SAVE_PATH = "./siglip2-base-patch16-256"
DATA_PATH = "/kaggle/input/amlc2025/student_resource/dataset/train.csv"
TEXT_COLUMN = 'catalog_content'
BATCH_SIZE = 128 # Increased batch size for better efficiency
DEBUG = False
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SHARD_NUMBER = 8 # Change this for each Kaggle account
TOTAL_SHARDS = 20 # Set this to the total number of accounts/shards

# %% [code]
# Load the checkpoint named by MODEL_ID; device_map="auto" delegates weight placement to accelerate.
model = AutoModel.from_pretrained(MODEL_ID, device_map="auto")
model.eval()  # inference mode for layers such as dropout; eval() mutates and returns the same module

# The processor bundles the tokenizer and image preprocessor that match this checkpoint.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# %% [code]
def shard_bounds(total_rows, shard_number, total_shards):
    """Return the half-open row range (start, end) covered by a 1-indexed shard.

    Each shard gets ``total_rows // total_shards`` rows; the last shard also absorbs the
    remainder so that no row is dropped.

    Raises:
        ValueError: if ``shard_number`` is outside ``1..total_shards``, which would otherwise
            silently produce an empty or misplaced slice.
    """
    if total_shards < 1 or not 1 <= shard_number <= total_shards:
        raise ValueError(f"shard_number must be in [1, {total_shards}], got {shard_number}")
    rows_per_shard = total_rows // total_shards
    start = (shard_number - 1) * rows_per_shard
    end = total_rows if shard_number == total_shards else start + rows_per_shard
    return start, end


total_df = pd.read_csv(DATA_PATH)
total_rows = len(total_df)
rows_per_shard = total_rows // TOTAL_SHARDS
# Shard-specific names so they are not confused with any per-batch cursor used later.
shard_start, shard_end = shard_bounds(total_rows, SHARD_NUMBER, TOTAL_SHARDS)
df = total_df.iloc[shard_start:shard_end].reset_index(drop=True)
print(f"Processing shard {SHARD_NUMBER}/{TOTAL_SHARDS}: {len(df)} rows")

# %% [code]
# Per-request timeout in seconds: without one, a single stalled server blocks the whole shard forever.
IMAGE_TIMEOUT_S = 30
# Token length for padding/truncation; recommended for SigLIP 2.
MAX_TEXT_LENGTH = 64


def _load_image(image_url):
    """Download ``image_url`` and decode it as an RGB PIL image.

    Raises:
        ValueError: if ``image_url`` is not an http(s) URL string.
        requests.RequestException: on connection errors, timeouts or non-2xx responses.
        OSError (incl. PIL.UnidentifiedImageError): if the payload is not a decodable image.
    """
    if not isinstance(image_url, str) or not image_url.startswith(("http://", "https://")):
        raise ValueError("Invalid image URL")
    # The context manager releases the connection even if decoding below fails.
    with requests.get(image_url, timeout=IMAGE_TIMEOUT_S) as response:
        response.raise_for_status()
        # Decode from the fully read body: unlike `response.raw`, `.content` already has any
        # gzip/deflate Content-Encoding removed.
        return Image.open(BytesIO(response.content)).convert("RGB")


# In DEBUG mode embed at most the first 200 rows (fewer if the shard is smaller, avoiding empty batches).
lim = min(200, len(df)) if DEBUG else len(df)

all_image_embeddings = []
all_text_embeddings = []
all_ids = []

for batch_start in tqdm(range(0, lim, BATCH_SIZE)):
    batch_df = df.iloc[batch_start : min(batch_start + BATCH_SIZE, lim)]

    batch_texts = []
    images_to_process = []
    for _, row in batch_df.iterrows():
        batch_texts.append(preprocess_product_text(row[TEXT_COLUMN]))

        image_url = row.get("image_link", "N/A")
        try:
            image = _load_image(image_url)
        except Exception as e:
            # Best effort: keep image/text/id rows aligned by substituting a black placeholder.
            # tqdm.write prints without corrupting the progress bar.
            tqdm.write(f"Failed to process image {image_url}. Using a black dummy image instead. Error: {e}")
            image = Image.new('RGB', (224, 224), color='black')
        images_to_process.append(image)

    # Use the processor for both text and images
    inputs = processor(
        text=batch_texts,
        images=images_to_process,
        return_tensors="pt",
        padding="max_length",  # Recommended for SigLIP 2
        truncation=True,
        max_length=MAX_TEXT_LENGTH,
    ).to(model.device)

    with torch.no_grad():
        # Get both image and text embeddings from a single model call
        outputs = model(**inputs)
        # L2-normalize so dot products are cosine similarities (a no-op for already unit-norm outputs).
        image_embeddings = F.normalize(outputs.image_embeds, p=2, dim=-1)
        text_embeddings = F.normalize(outputs.text_embeds, p=2, dim=-1)

    # One device->host copy per batch instead of one per row; .float() keeps .numpy() working
    # even if the model runs in a reduced-precision dtype such as bfloat16.
    all_image_embeddings.extend(image_embeddings.float().cpu().numpy())
    all_text_embeddings.extend(text_embeddings.float().cpu().numpy())
    all_ids.extend(batch_df["sample_id"].tolist())

    del inputs, outputs, image_embeddings, text_embeddings, images_to_process

# %% [code]
# Convert lists of embeddings to 2D numpy arrays
# Collapse the per-row lists into 2-D (rows x dim) embedding arrays and a 1-D id array.
all_image_embeddings, all_text_embeddings = map(np.stack, (all_image_embeddings, all_text_embeddings))
all_ids = np.array(all_ids)

# Persist each array under a shard-suffixed filename so shards can be merged later.
np.save(f"image_embeddings_{SHARD_NUMBER}.npy", all_image_embeddings)
np.save(f"text_embeddings_{SHARD_NUMBER}.npy", all_text_embeddings)
np.save(f"sample_ids_{SHARD_NUMBER}.npy", all_ids)

for _label, _saved in (
    ("Saved image embeddings shape:", all_image_embeddings),
    ("Saved text embeddings shape:", all_text_embeddings),
    ("Saved sample IDs shape:", all_ids),
):
    print(_label, _saved.shape)

# %% [code]
# Optionally save the model and processor
# model.save_pretrained(SAVE_PATH)
# processor.save_pretrained(SAVE_PATH)

config.json:   0%|          | 0.00/537 [00:00<?, ?B/s]

model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/2.49G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.99G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

preprocessor_config.json:   0%|          | 0.00/394 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/4.24M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/34.4M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/636 [00:00<?, ?B/s]

Processing shard 8/20: 3750 rows


100%|██████████| 30/30 [25:59<00:00, 51.97s/it]

Saved image embeddings shape: (3750, 1536)
Saved text embeddings shape: (3750, 1536)
Saved sample IDs shape: (3750,)



