In [None]:
!pip install --upgrade libraft-cu12==25.6.*
!pip install --upgrade pylibraft-cu12==25.6.*
!pip install --upgrade rmm-cu12==25.6.*
!pip install --upgrade cuml-cu12==25.6.*
!pip install --upgrade raft-dask-cu12==25.6.*
!pip install --upgrade cudf-cu12==25.6.*
!pip install --upgrade cuvs-cu12==25.6.*
!pip install --upgrade pylibcugraph-cu12==25.6.*

!pip install pyarrow>=21.0.0
!pip install "pydantic<2.12,>=2.0"
!pip install dask==2024.12.1
!pip install scikit-learn<1.6.0,>=1.0.0
!pip install dask==2025.5.0


In [None]:

!pip install sentence-transformers open-clip-torch transformers pandas pillow beautifulsoup4 tqdm


In [None]:
import os
import re
import torch
import pandas as pd
from PIL import Image
from tqdm import tqdm
from bs4 import BeautifulSoup
from sentence_transformers import SentenceTransformer
import open_clip

In [None]:
# Use the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(DEVICE)


In [None]:
# ---- CONFIG ----

# Input data: product CSV and the folder holding one "<sample_id>.jpg" per row.
DATA_CSV_PATH = "/kaggle/input/amazon-images/train.csv"
IMAGE_FOLDER = "/kaggle/input/amazon-images/images/images"

# Gemma text embedder and the per-chunk token budget used to split long texts.
GEMMA_MODEL_NAME = "google/embeddinggemma-300m"
GEMMA_MAX_TOKENS = 2048

# LAION OpenCLIP model (text + image towers) and its fixed text sequence length.
LAION_MODEL_ID = "hf-hub:laion/CLIP-ViT-L-14-laion2B-s32B-b82K"
LAION_TOKEN_LIMIT = 77

# Rows embedded per batch; every batch is appended to OUTPUT_CSV_PATH.
BATCH_SIZE = 32
OUTPUT_CSV_PATH = "all_embeddings.csv"

In [None]:
# --- PREPROCESSING FUNCTIONS ---

def is_html(text):
    """Report whether ``str(text)`` contains at least one ``<...>`` tag-like span.

    Any object is accepted; it is converted with ``str`` before matching.
    Returns a real ``bool``.
    """
    tag_match = re.search(r'<[^>]+>', str(text))
    return tag_match is not None

def clean_text_for_laion(text):
    '''Cleaner for LAION OpenCLIP text embedding.

    Steps: convert to ``str``; drop characters that cannot be UTF-8 encoded
    (e.g. lone surrogates); if ``is_html`` finds a tag, keep only the text
    content; replace every run of non-ASCII characters (emoji, corrupted
    unicode) with a space; collapse whitespace and strip. Case is preserved.

    Returns:
        The cleaned string (may be empty).
    '''
    text = str(text).encode('utf-8', 'ignore').decode(errors='ignore')
    if is_html(text):
        # separator=' ' keeps text from adjacent elements apart
        # ("<li>Red</li><li>Blue</li>" -> "Red Blue" rather than "RedBlue").
        text = BeautifulSoup(text, 'lxml').get_text(separator=' ')
    # Remove emojis & corrupted unicode
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

def clean_text_for_gemma(text):
    '''Cleaner for Gemma embedding input.

    Steps: convert to ``str`` and lowercase; if ``is_html`` finds a tag, keep
    only the text content; replace every run of non-ASCII characters with a
    space; collapse whitespace and strip.

    Returns:
        The cleaned, lowercased string (may be empty).
    '''
    text = str(text).lower()
    if is_html(text):
        # separator=' ' keeps text from adjacent elements apart
        # ("<li>red</li><li>blue</li>" -> "red blue" rather than "redblue").
        text = BeautifulSoup(text, 'lxml').get_text(separator=' ')
    text = re.sub(r'[^\x00-\x7F]+', ' ', text)  # Remove non-ascii for simplicity
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


In [None]:
# --- Gemma SentenceTransformer Chunking ---

def chunk_text_for_sentence_transformer(text, tokenizer, max_tokens):
    """Split ``text`` into consecutive pieces of at most ``max_tokens`` tokens.

    Args:
        text: the string to split.
        tokenizer: object providing ``tokenize(text)`` and
            ``convert_tokens_to_string(tokens)``.
        max_tokens: maximum number of tokens per chunk; must be positive.

    Returns:
        List of chunk strings in original order. Empty when ``text`` yields
        no tokens.

    Raises:
        ValueError: if ``max_tokens`` is not positive (the chunk window would
            never advance).
    """
    if max_tokens <= 0:
        raise ValueError(f"max_tokens must be positive, got {max_tokens}")
    tokens = tokenizer.tokenize(text)
    return [
        tokenizer.convert_tokens_to_string(tokens[start:start + max_tokens])
        for start in range(0, len(tokens), max_tokens)
    ]

def get_gemma_embedding(text, gemma_model):
    """Embed ``text`` with the Gemma SentenceTransformer, averaging over chunks.

    Long texts are split into token chunks via
    ``chunk_text_for_sentence_transformer``; each chunk is encoded and the
    chunk embeddings are averaged with equal weight.

    Args:
        text: cleaned input string (may be empty).
        gemma_model: a loaded ``SentenceTransformer``.

    Returns:
        1-D float32 numpy array (one embedding vector).
    """
    tokenizer = gemma_model.tokenizer
    # Each chunk is re-tokenized by encode(), which adds special tokens again.
    # Reserve room for them (and respect the model's own max_seq_length) so a
    # full chunk is not silently truncated.
    limit = min(GEMMA_MAX_TOKENS, getattr(gemma_model, "max_seq_length", None) or GEMMA_MAX_TOKENS)
    budget = max(1, limit - tokenizer.num_special_tokens_to_add())
    chunks = chunk_text_for_sentence_transformer(text, tokenizer, budget)
    if not chunks:
        # Empty text yields no chunks; encoding [] and averaging would produce a
        # 0-d NaN instead of a vector, so encode the (empty) text itself.
        chunks = [text]
    chunk_embeds = gemma_model.encode(chunks, device=DEVICE)
    # Average the chunk embeddings
    return torch.as_tensor(chunk_embeds).float().mean(dim=0).numpy()

In [None]:
#  ----------------------

def get_laion_text_embedding_batch(texts, tokenizer, laion_model):
    """Return L2-normalised OpenCLIP text embeddings for a batch of strings.

    Args:
        texts: sequence of cleaned strings.
        tokenizer: tokenizer returned by ``open_clip.get_tokenizer``.
        laion_model: the OpenCLIP model providing ``encode_text``.

    Returns:
        numpy array with one unit-norm row per input text.
    """
    # Let the OpenCLIP tokenizer build the id sequences: it wraps each text in
    # start/end-of-text tokens, truncates to the context length and zero-pads.
    # tokenizer.encode() returns bare BPE ids without those markers, and the
    # CLIP text tower pools at the end-of-text position (argmax token id), so
    # hand-padded sequences produced embeddings from arbitrary tokens.
    text_tensor = tokenizer(list(texts), context_length=LAION_TOKEN_LIMIT).to(DEVICE)
    with torch.no_grad():
        emb = laion_model.encode_text(text_tensor)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb.cpu().numpy()


def get_laion_image_embedding_batch(image_paths, laion_model, image_preprocess):
    """Return L2-normalised OpenCLIP image embeddings for a batch of file paths.

    Missing or unreadable images are replaced by an all-zero input tensor
    (errors are printed, not raised), so the output always has one row per path.

    Args:
        image_paths: sequence of image file paths.
        laion_model: the OpenCLIP model providing ``encode_image``.
        image_preprocess: transform mapping a PIL RGB image to a tensor.

    Returns:
        numpy array with one unit-norm row per path.
    """
    loaded = []  # preprocessed tensor, or None for a missing/unreadable image
    for path in image_paths:
        if not os.path.exists(path):
            loaded.append(None)
            continue
        try:
            # ``with`` closes the file handle; preprocessing happens while open.
            with Image.open(path) as image:
                loaded.append(image_preprocess(image.convert("RGB")))
        except Exception as e:
            print(f"Error loading image {path}: {e}")
            loaded.append(None)

    # Shape the zero placeholder like the real inputs; fall back to 3x224x224
    # only when nothing in this batch could be loaded.
    template = next((t for t in loaded if t is not None), None)
    placeholder = torch.zeros_like(template) if template is not None else torch.zeros(3, 224, 224)
    image_tensor = torch.stack([placeholder if t is None else t for t in loaded]).to(DEVICE)
    with torch.no_grad():
        emb = laion_model.encode_image(image_tensor)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb.cpu().numpy()

In [None]:
#-----------------------------

def _load_processed_ids(output_csv_path):
    """Return the ``sample_id`` values (as strings) already present in ``output_csv_path``.

    Returns an empty set when the file does not exist yet.
    """
    if not os.path.exists(output_csv_path):
        return set()
    processed_df = pd.read_csv(output_csv_path, usecols=["sample_id"])
    processed_ids = set(processed_df["sample_id"].astype(str).tolist())
    print(f"Resuming from {len(processed_ids)} already processed samples.")
    return processed_ids


def _embed_and_append_batch(rows, gemma_model, laion_model, laion_tokenizer, image_preprocess):
    """Embed one batch of CSV rows and append the results to OUTPUT_CSV_PATH.

    The CSV header is written only when the output file does not exist yet,
    so successive batches (and resumed runs) extend a single file.
    """
    sample_ids = [row["sample_id"] for row in rows]
    prices = [row.get("price", None) for row in rows]
    texts_gemma = [clean_text_for_gemma(row["catalog_content"]) for row in rows]
    texts_laion = [clean_text_for_laion(row["catalog_content"]) for row in rows]
    image_paths = [os.path.join(IMAGE_FOLDER, str(row["sample_id"]) + ".jpg") for row in rows]

    # One averaged Gemma vector per text, stacked into a (batch, dim) array.
    gemma_emb = torch.stack(
        [torch.as_tensor(get_gemma_embedding(text, gemma_model)) for text in texts_gemma]
    ).numpy()
    laion_text_emb = get_laion_text_embedding_batch(texts_laion, laion_tokenizer, laion_model)
    laion_image_emb = get_laion_image_embedding_batch(image_paths, laion_model, image_preprocess)

    # Column order must be identical for every batch: later batches are
    # appended without a header.
    batch_df = pd.concat(
        [
            pd.DataFrame({"sample_id": sample_ids, "price": prices}),
            pd.DataFrame(gemma_emb, columns=[f"gemma_{j}" for j in range(gemma_emb.shape[1])]),
            pd.DataFrame(laion_text_emb, columns=[f"laion_text_{j}" for j in range(laion_text_emb.shape[1])]),
            pd.DataFrame(laion_image_emb, columns=[f"laion_image_{j}" for j in range(laion_image_emb.shape[1])]),
        ],
        axis=1,
    )

    # Append batch results to CSV (create with header if it does not exist)
    write_header = not os.path.exists(OUTPUT_CSV_PATH)
    batch_df.to_csv(OUTPUT_CSV_PATH, mode="w" if write_header else "a", header=write_header, index=False)
    print(f"Saved {len(batch_df)} rows to {OUTPUT_CSV_PATH}")


def main():
    """Embed every row of DATA_CSV_PATH and append the results to OUTPUT_CSV_PATH.

    Each output row holds ``sample_id``, ``price`` (None when the input has no
    price column), the chunk-averaged Gemma text embedding (``gemma_*``), the
    LAION text embedding (``laion_text_*``) and the LAION image embedding
    (``laion_image_*``). Rows whose ``sample_id`` already appears in the output
    CSV are skipped, so an interrupted run can be resumed.
    """
    # Load models
    gemma_model = SentenceTransformer(GEMMA_MODEL_NAME, device=DEVICE)
    # create_model_and_transforms returns (model, train_transform, eval_transform).
    # The train transform applies random augmentation, which would make image
    # embeddings non-deterministic, so use the eval transform for inference.
    laion_model, _, image_preprocess = open_clip.create_model_and_transforms(LAION_MODEL_ID, device=DEVICE)
    # open_clip returns the model in training mode; switch to inference behaviour.
    laion_model.eval()
    laion_tokenizer = open_clip.get_tokenizer(LAION_MODEL_ID)

    # Load CSV data and the ids already written by a previous run
    df = pd.read_csv(DATA_CSV_PATH)
    processed_ids = _load_processed_ids(OUTPUT_CSV_PATH)

    pending_rows = []
    for _, row in tqdm(df.iterrows(), total=len(df)):
        if str(row["sample_id"]) in processed_ids:
            continue
        pending_rows.append(row)
        if len(pending_rows) == BATCH_SIZE:
            _embed_and_append_batch(pending_rows, gemma_model, laion_model, laion_tokenizer, image_preprocess)
            pending_rows = []

    # Flush the final partial batch after the loop. (Testing for the last
    # DataFrame index inside the loop lost this batch whenever the last row was
    # skipped as already processed.)
    if pending_rows:
        _embed_and_append_batch(pending_rows, gemma_model, laion_model, laion_tokenizer, image_preprocess)


if __name__ == "__main__":
    main()