In [None]:
!pip install huggingface-hub

from huggingface_hub import login
# token="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"

login(token=token)

[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [None]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from pyspark.sql.functions import collect_list
from google.cloud import storage

# GCS client plus a handle on the project bucket (`bucket` is not used by the visible cells).
BUCKET_NAME = 'gcs-dsci-fryou-fy-dev-prd'
storage_client = storage.Client()
bucket = storage_client.get_bucket(BUCKET_NAME)

In [None]:
# Parquet snapshot of catalog embeddings read by the next cell.
# `version` is not referenced by any of the visible cells.
version = 1
path = (
    "gs://gcs-dsci-fryou-fy-dev-prd/catalog_vqvae/"
    "vqvae_catalog_embeddings_2025-03-01"
)

In [None]:
# Catalog embeddings; joined to the k-means output below on `search_catalog_embedding`.
df = spark.read.format("parquet").load(path)

In [None]:
# k-means assignment output, read by the join cell below.
output_path = "gs://gcs-dsci-fryou-fy-dev-prd/catalog_vqvae/" + "kmeans_50000"

In [None]:
# Distinct catalog -> sub-sub-category mapping used to label clustered catalogs.
taxonomy_query = "select distinct catalog_id, sscat_id, scat_id, sscat from gold.product_info"
taxonomy = spark.sql(taxonomy_query)

In [None]:
search_catalog_50 = spark.read.parquet(output_path)
# NOTE(review): this joins on the embedding column itself. If the embeddings are
# float arrays, exact-equality matching is fragile; join on an id column if one exists.
result_50 = search_catalog_50.join(df, "search_catalog_embedding", "inner")
intersected_result_50 = (
    result_50
    .join(taxonomy, "catalog_id", "inner")
    .select("cluster", "catalog_id", "sscat_id", "sscat")
)

In [None]:
# Image-caption summaries per catalog, restricted to catalogs that have a cluster and taxonomy.
catalog_captions = spark.sql("select distinct catalog_id, summary from ds_silver.product_image_captions")
catalog_summary = catalog_captions.join(intersected_result_50, "catalog_id", "inner")

In [None]:
# One row per cluster: every member summary and every member sscat (duplicates kept).
df_grouped = (
    catalog_summary
    .groupBy("cluster")
    .agg(
        collect_list("summary").alias("summary_list"),
        collect_list("sscat").alias("sscat"),
    )
)

In [None]:
# Per-cluster word map. Uses its own variable instead of reassigning `output_path`:
# the k-means join cell above reads `output_path`, so reassigning it here made a
# re-run of that cell silently load the word map instead of the cluster assignments.
wordmap_path = "gs://gcs-dsci-fryou-fy-dev-prd/catalog_vqvae/kmeans_50000_wordmap"
df_cluster_wordmap = spark.read.parquet(wordmap_path)

In [None]:
# Keep a single word-map row per cluster (which duplicate survives is not specified by Spark).
wordmap_per_cluster = df_cluster_wordmap.dropDuplicates(["cluster"])
df_llm = df_grouped.join(wordmap_per_cluster, "cluster", "inner")

In [None]:
SAMPLE_CLUSTERS = 10  # size of the trial sample fed to the LLM
df_llm_1 = df_llm.limit(SAMPLE_CLUSTERS)

In [None]:
from huggingface_hub import HfApi, HfFolder
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import json
import re
from pyspark.sql import Row, SparkSession
from pyspark.sql.functions import col

class SimpleCategoryGenerator:
    """Label a product cluster with Category / Sub-Category / Cluster Name using a causal LM.

    The model is primed with the opening of a JSON object (``{"Category": ``) and asked to
    continue it. The completion is parsed with ``json.loads``. Incomplete completions are
    patched by ``_complete_json``; unparseable ones are salvaged by ``_extract_json_fields``.
    """

    def __init__(self, model_name="meta-llama/Meta-Llama-3-8B-Instruct", token=None):
        """Load the tokenizer and fp16 model weights from the Hugging Face Hub.

        Args:
            model_name: Hub model id.
            token: Hugging Face access token, forwarded to ``from_pretrained``.
        """
        print(f"Loading model: {model_name}")

        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            device_map="auto",
            token=token
        )

        # generate() needs a pad id; fall back to EOS when the tokenizer defines none.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    @staticmethod
    def _distinct(values):
        """Return ``values`` as strings, de-duplicated in first-seen order, without empty/None entries."""
        return list(dict.fromkeys(str(v) for v in (values or []) if v))

    def generate_categories(self, product_description, category_hints=None, features=None):
        """Generate a category hierarchy and a cluster name for one product/cluster.

        Args:
            product_description: Free text; only the first 250 characters reach the prompt.
                ``None`` is treated as an empty description.
            category_hints: Candidate category names. Duplicates and empty values are removed,
                and at most 5 are shown to the model.
            features: Keyword features, de-duplicated the same way (at most 5 shown).

        Returns:
            dict parsed from the model's JSON. When the completion cannot be parsed,
            ``_extract_json_fields`` builds a dict with "Category", "Sub-Category" and
            "Cluster Name".
        """
        product_description = product_description or ""
        # Hints typically come from a collect_list over cluster members, so the same
        # value repeats many times. Repeats would fill all 5 prompt slots with one name
        # and make the category_hints[1] fallback in _complete_json a copy of [0].
        category_hints = self._distinct(category_hints)
        features = self._distinct(features)

        category_str = ", ".join(category_hints[:5]) if category_hints else "Unknown"
        features_str = ", ".join(features[:5]) if features else "Unknown"

        # prompt
        prompt = f"""Create a JSON object for this e-commerce product with proper categorization:

            Product: {product_description[:250]}
            Categories: {category_str}
            Features: {features_str}

            Return JSON with these hierarchical fields:
            - "Category": Top-level department or main product group (e.g., Electronics, Clothing, Home & Garden, Beauty, Sports & Outdoors)
            - "Sub-Category": Specific type within that category (must be more specific than Category)
            - "Cluster Name": Descriptive name (5-7 words) highlighting key selling points

            IMPORTANT: Category and Sub-Category must be different. If product is "Men's Running Shoes", 
            Category might be "Footwear" or "Sports Equipment" and Sub-Category would be "Athletic Shoes" or "Running Shoes".

            Examples of good Category → Sub-Category pairs:
            • Clothing → Men's T-Shirts
            • Kitchen → Coffee Makers
            • Electronics → Bluetooth Speakers
            • Beauty → Hair Styling Tools
            • Furniture → Office Chairs

            JSON:
            {{
            "Category": """

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        with torch.no_grad():
            outputs = self.model.generate(
                # Pass the whole encoding so attention_mask is included. pad_token may
                # equal eos_token, so the model cannot reliably infer the mask itself.
                **inputs,
                max_new_tokens=100,
                do_sample=True,
                temperature=0.5,
                top_p=0.9,
                pad_token_id=self.tokenizer.pad_token_id
            )

        prompt_length = inputs["input_ids"].shape[1]
        new_tokens = outputs[0][prompt_length:]
        generated_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        # The model often keeps writing after it closes the object (notes, more
        # examples). Keep only the text up to the first closing brace; otherwise
        # json.loads fails on the trailing text and every answer goes through the
        # lossy regex fallback.
        closing = generated_text.find("}")
        if closing != -1:
            generated_text = generated_text[:closing + 1]

        # Re-attach the opening that the prompt supplied so the text is a JSON object.
        full_json = '{\n  "Category": ' + generated_text

        if not self._is_complete_json(full_json):
            full_json = self._complete_json(full_json, category_hints, features)

        try:
            return json.loads(full_json)
        except json.JSONDecodeError:
            return self._extract_json_fields(full_json)

    def _is_complete_json(self, json_str):
        """Check if the JSON contains all required fields and is properly closed"""
        return ('"Category"' in json_str and
                '"Sub-Category"' in json_str and
                '"Cluster Name"' in json_str and
                json_str.strip().endswith("}"))

    def _complete_json(self, partial_json, category_hints, features):
        """Append any missing fields, plus the closing brace, to a partial JSON object string.

        Fallback values are derived from the hints/features and JSON-escaped, so quotes
        or backslashes in hint text cannot break parsing.

        Args:
            partial_json: Text starting with ``{`` that may lack fields or the closing brace.
            category_hints: Distinct category names (may be empty or None).
            features: Distinct feature keywords (may be empty or None).

        Returns:
            The patched string, always ending in ``}``. It is still not guaranteed to
            parse, e.g. if the model stopped mid-string.
        """
        category_hints = category_hints or []
        features = features or []

        # Remove a premature closing brace and any trailing comma so the appended
        # fields keep the object syntactically valid. The brace is re-added at the end.
        body = partial_json.rstrip()
        if body.endswith("}"):
            body = body[:-1].rstrip()
        body = body.rstrip(",").rstrip()

        if '"Sub-Category"' not in body:
            if len(category_hints) > 1:
                sub_category = category_hints[1]
            elif category_hints:
                sub_category = f"{category_hints[0]} Products"  # 'Products' differentiates it from Category
            else:
                sub_category = "Various"
            body += ',\n  "Sub-Category": ' + json.dumps(sub_category)

        if '"Cluster Name"' not in body:
            if len(features) >= 2:
                cluster_name = f"{features[0]} {features[1]} Collection"
            elif len(features) == 1:
                cluster_name = f"Premium {features[0]} Products"
            else:
                cluster_name = "Quality Product Collection"
            body += ',\n  "Cluster Name": ' + json.dumps(cluster_name)

        return body + '\n}'

    def _extract_json_fields(self, text):
        """Extract fields using regex when JSON parsing fails.

        Only simple quoted values are recognised (no escaped quotes). Missing fields
        are filled from the ones found, or with generic defaults.
        """
        result = {}

        for field in ["Category", "Sub-Category", "Cluster Name"]:
            match = re.search(f'"{field}":\\s*"([^"]+)"', text)
            if match:
                result[field] = match.group(1)

        if "Category" not in result and "Sub-Category" not in result:
            result["Category"] = "General Merchandise"
            result["Sub-Category"] = "Various Products"
        elif "Category" in result and "Sub-Category" not in result:
            result["Sub-Category"] = f"{result['Category']} Products"
        elif "Category" not in result and "Sub-Category" in result:
            # Use the first word of the sub-category as a rough top-level category.
            words = result["Sub-Category"].split()
            if words:
                result["Category"] = words[0]
            else:
                result["Category"] = "General Merchandise"

        if "Cluster Name" not in result:
            if "Sub-Category" in result:
                result["Cluster Name"] = f"Premium {result['Sub-Category']} Collection"
            else:
                result["Cluster Name"] = "Quality Product Collection"

        return result

2025-04-10 12:07:59.770088: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-10 12:07:59.834834: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [None]:
# Authentication with Hugging Face.
# The token is read from the HF_TOKEN environment variable (e.g. set from a secret
# scope or the cluster environment) and must never be pasted into this cell. Any token
# that was previously committed here should be revoked.
# `token` is also read as a global by process_batch_with_llm, which runs on Spark executors.
import os

token = os.environ.get("HF_TOKEN")
if not token:
    raise RuntimeError("Set the HF_TOKEN environment variable before running this cell.")
api = HfApi()
user = api.whoami(token=token)
print(f"Authenticated as: {user['name']}")

def process_batch_with_llm(iterator, batch_size=5):
    """Label every row of one Spark partition with the LLM (``mapPartitions`` callback).

    A SimpleCategoryGenerator (a full model load) is built once per partition. It is
    built only after the partition is known to contain a row, so empty partitions
    (which repartition() can produce for small inputs) do not pay for a model load.
    Rows are handed to process_mini_batch in groups of ``batch_size``.

    Args:
        iterator: Iterator over the partition's Rows.
        batch_size: Rows passed to process_mini_batch per call (default 5).

    Returns:
        list of output Rows, in input order ([] for an empty partition).
    """
    rows = iter(iterator)
    first_row = next(rows, None)
    if first_row is None:
        return []

    generator = SimpleCategoryGenerator(token=token)

    results = []
    current_batch = [first_row]
    for row in rows:
        # Flush a full batch before adding the next row, so no batch exceeds batch_size.
        if len(current_batch) >= batch_size:
            results.extend(process_mini_batch(current_batch, generator))
            current_batch = []
        current_batch.append(row)

    if current_batch:
        results.extend(process_mini_batch(current_batch, generator))

    return results

# Columns that only feed the prompt; they are removed from the labelled output rows.
_LLM_INPUT_FIELDS = ("summary_list", "tokens", "sscat")


def _as_list(row, field):
    """Return ``row.<field>`` as a list: lists pass through, a truthy scalar is wrapped, otherwise []."""
    if not hasattr(row, field):
        return []
    value = getattr(row, field)
    if isinstance(value, list):
        return value
    return [value] if value else []


def _as_text(value):
    """Coerce an LLM output value to str ('' for None) so every output row has one string type per label column."""
    return "" if value is None else str(value)


def _labelled_row(row, category="", sub_category="", cluster_name=""):
    """Copy ``row`` without the prompt-input columns and append the three label columns."""
    fields = {k: v for k, v in row.asDict().items() if k not in _LLM_INPUT_FIELDS}
    fields["category"] = category
    fields["sub_category"] = sub_category
    fields["cluster_name"] = cluster_name
    return Row(**fields)


def process_mini_batch(batch, generator):
    """Label each row of ``batch`` with ``generator`` and return new Rows.

    The prompt inputs for each row are built from its optional columns:
      - up to 5 non-empty entries of ``summary_list``, newline-joined, as the description;
      - the ``sscat`` value(s) as category hints;
      - the ``tokens`` value(s) as features.
    Output rows keep every other column and add ``category``, ``sub_category`` and
    ``cluster_name`` as strings. Non-string or None model values are coerced, so
    Spark's schema inference sees one type per column.

    Errors are handled on a best-effort basis: the error is printed and the row is
    emitted with empty labels, so one bad row does not fail the whole partition.

    Args:
        batch: Iterable of Spark Rows.
        generator: Object exposing ``generate_categories(product_description=...,
            category_hints=..., features=...)`` that returns a dict (e.g. SimpleCategoryGenerator).

    Returns:
        list of Rows, one per input row, in input order.
    """
    results = []
    for row in batch:
        try:
            summaries = [s for s in _as_list(row, "summary_list")[:5] if s]
            result = generator.generate_categories(
                product_description="\n".join(summaries),
                category_hints=_as_list(row, "sscat"),
                features=_as_list(row, "tokens"),
            )
            results.append(_labelled_row(
                row,
                category=_as_text(result.get("Category")),
                sub_category=_as_text(result.get("Sub-Category")),
                cluster_name=_as_text(result.get("Cluster Name")),
            ))
        except Exception as e:
            print(f"Error processing row: {e}")
            results.append(_labelled_row(row))

    return results

Authenticated as: Arshimeesho


In [None]:
def process_dataframe(df, chunk_size=10, num_partitions=4, order_by=None):
    """Run the LLM labeller (process_batch_with_llm) over ``df`` and return the labelled rows.

    Inputs with at most ``chunk_size`` rows are mapped in a single pass. Larger inputs
    are sliced into ``chunk_size``-row chunks, each repartitioned into ``num_partitions``,
    and the per-chunk results are unioned.

    Args:
        df: Spark DataFrame whose Rows process_batch_with_llm accepts.
        chunk_size: Maximum rows per chunk; must be >= 1.
        num_partitions: Partitions per chunk. Each partition is handed to its own
            process_batch_with_llm call.
        order_by: Optional column name (or list of names) to sort by before slicing.
            Spark does not promise a stable row order across the separate chunk
            queries. Without this, chunks may overlap or skip rows when ``df`` is
            non-deterministic (e.g. a ``limit`` over a shuffle).

    Returns:
        DataFrame of labelled rows, with the schema inferred from the output Rows.

    Raises:
        ValueError: if ``chunk_size`` < 1, or ``df`` has no rows (no schema can be inferred).
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    def _label(frame):
        # toDF() infers the schema by evaluating the first output row. Caching the
        # mapped RDD lets the later action reuse those LLM results (unless evicted)
        # instead of loading the model and generating again.
        return frame.rdd.mapPartitions(process_batch_with_llm).cache().toDF()

    total_count = df.count()
    print(f"Total rows to process: {total_count}")
    if total_count == 0:
        raise ValueError("process_dataframe: input DataFrame is empty; nothing to label")

    if total_count <= chunk_size:
        print(f"Processing all {total_count} rows at once")
        return _label(df)

    print(f"Processing in chunks of {chunk_size} rows")
    source = df.orderBy(order_by) if order_by is not None else df
    num_chunks = (total_count + chunk_size - 1) // chunk_size

    chunk_results = []
    for i in range(num_chunks):
        print(f"Processing chunk {i+1} of {num_chunks}")
        # offset() must come before limit(). limit(n).offset(k) skips k rows of an
        # n-row result, which is empty whenever k >= n (every chunk after the first).
        chunk_df = (
            source
            .offset(i * chunk_size)
            .limit(chunk_size)
            .repartition(num_partitions)
        )
        chunk_results.append(_label(chunk_df))
        print(f"Chunk {i+1} completed")

    # Union the chunk results directly. A distinct loop name avoids shadowing ``df``.
    final_df = chunk_results[0]
    for chunk_result in chunk_results[1:]:
        final_df = final_df.union(chunk_result)

    return final_df


# Label a 10-cluster sample (df_llm has one row per cluster) in chunks of 5 rows,
# and cache the result so display() does not recompute the LLM pipeline.
df_llm_1 = df_llm.limit(10)
df_final = process_dataframe(df_llm_1, chunk_size=5).cache()
print("Processing complete. Displaying results...")
display(df_final)