# Fashion Attribute Signal Extraction

This notebook implements Step 5 from the project checklist: "Build structured attribute signals from text (materials + design cues)"

The goal is to extract structured fashion attributes from product text descriptions to enable better similarity matching for garment design and materials.

In [51]:
# Import required libraries
import os
import re
import random
from typing import List, Dict, Set
from collections import Counter, defaultdict

import pandas as pd
import numpy as np
from datasets import Dataset, DatasetDict, load_from_disk
from tqdm import tqdm

# Set random seed for reproducibility
random.seed(42)
np.random.seed(42)

## Step 1: Load Dataset and Define Attribute Schema

Load the cleaned dataset and define the attribute extraction schema based on the project requirements.

In [None]:
# Load the cleaned dataset
print("Loading cleaned dataset...")
ds = load_from_disk("../data/processed/hf")
train_ds = ds["train"]
print(f"Dataset loaded with {len(train_ds)} examples")
print(f"Features: {train_ds.features}")

# Define attribute schema based on project requirements
ATTRIBUTE_SCHEMA = {
    "material": {
        "keywords": [
            "cotton", "denim", "knit", "leather", "polyester", "silk", "wool", 
            "linen", "nylon", "spandex", "rayon", "acrylic", "velvet", "satin",
            "chiffon", "lace", "mesh", "jersey", "fleece", "tulle", "organza"
        ],
        "description": "Fabric type and material composition"
    },
    "pattern": {
    "keywords": [
        # Original
        "solid", "stripe", "plaid", "check", "polka dot", "floral", "graphic",
        "animal print", "camouflage", "color-block", "ombre", "tie-dye",
        "paisley", "houndstooth", "gingham", "argyle", "chevron",
        # Missing terms
        "pure color", "plain", "solid color", "single color", "monochrome",
        "geometric", "abstract", "motif", "print", "design", "patterned"
    ]
},
    "neckline": {
    "keywords": [
        # Basic styles
        "crew", "v-neck", "scoop", "boat", "square", "halter", "sweetheart",
        "off-the-shoulder", "turtleneck", "mock neck", "collar", "peter pan",
        "cowl", "asymmetric", "keyhole", "round", "stand", "lapel", "v-shape",
        # Compound terms  
        "crew neck", "v neck", "scoop neck", "boat neck", "square neck", 
        "sweetheart neck", "halter neck", "turtle neck", "mockneck", 
        "mandarin collar", "notched collar", "shirt collar", "lapel collar",
        "stand collar", "round neck", "v-shape neck"
    ]
},
    "sleeve": {
        "keywords": [
            "sleeveless", "short", "long", "3/4", "cap", "bell", "puff",
            "raglan", "bishop", "flutter", "batwing", "dolman", "cold shoulder"
        ],
        "description": "Sleeve length and style"
    },
    "closure": {
        "keywords": [
            "button", "zipper", "snap", "hook", "tie", "velcro", "magnetic",
            "elastic", "drawstring", "frog", "buckle", "toggle"
        ],
        "description": "Fastening and closure mechanisms"
    }
}

print("Attribute schema defined with categories:", list(ATTRIBUTE_SCHEMA.keys()))

Loading cleaned dataset...
Dataset loaded with 38283 examples
Features: {'image': Image(mode=None, decode=True), 'category1': Value('string'), 'category2': Value('string'), 'category3': Value('float64'), 'text': Value('string'), 'item_ID': Value('string')}
Attribute schema defined with categories: ['material', 'pattern', 'neckline', 'sleeve', 'closure']


## Step 2: Attribute Extraction Functions

Implement rule-based attribute extraction functions that search for keywords in the text descriptions.

In [None]:
def extract_with_target_anchoring(text: str, target_category: str, schema: Dict) -> Dict[str, List[str]]:
    """
    Extract attributes by anchoring to the target garment throughout the text.
    """
    target_scope = get_garment_scope(target_category)
    text_lower = text.lower()
    
    # Define anchor patterns for each garment type
    anchor_patterns = {
        'upper': [
            r'\b(?:the\s+)?(?:shirt|blouse|top|jacket|coat|sweater|cardigan|vest|t-shirt)\b',
            r'\b(?:her|his|their|the)\s+(?:shirt|blouse|top|jacket|coat|sweater)\b'
        ],
        'lower': [
            r'\b(?:the\s+)?(?:pants|jeans|shorts|skirt|leggings|trousers)\b', 
            r'\b(?:her|his|their|the)\s+(?:pants|jeans|shorts|skirt|leggings)\b'
        ],
        'full': [
            r'\b(?:the\s+)?(?:dress|jumpsuit|romper|overall|gown)\b',
            r'\b(?:her|his|their|the)\s+(?:dress|jumpsuit|romper|overall)\b'
        ]
    }
    
    attributes = {}
    for attr_name, attr_config in schema.items():
        found_keywords = []
        
        for keyword in attr_config["keywords"]:
            keyword_lower = keyword.lower()
            
            # Check if keyword appears near target garment anchors
            if is_keyword_anchored_to_target(keyword_lower, text_lower, anchor_patterns[target_scope]):
                if keyword not in found_keywords:  # Avoid duplicates
                    found_keywords.append(keyword)
        
        attributes[attr_name] = found_keywords
    
    return attributes

def is_keyword_anchored_to_target(keyword: str, text: str, anchor_patterns: List[str]) -> bool:
    """
    Check if a keyword is syntactically linked to the target garment.
    """
    import re
    
    # Find all anchor positions
    anchor_positions = []
    for pattern in anchor_patterns:
        for match in re.finditer(pattern, text):
            anchor_positions.append(match.start())
    
    if not anchor_positions:
        return False
    
    # Find keyword positions
    keyword_positions = []
    start = 0
    while True:
        pos = text.find(keyword, start)
        if pos == -1:
            break
        keyword_positions.append(pos)
        start = pos + 1
    
    # Check if any keyword is within reasonable distance of any anchor
    max_distance = 100  # characters
    
    for kw_pos in keyword_positions:
        for anchor_pos in anchor_positions:
            if abs(kw_pos - anchor_pos) <= max_distance:
                # Additional check: no competing garment mentions in between
                between_text = text[min(kw_pos, anchor_pos):max(kw_pos, anchor_pos)]
                competing_garments = ['pants', 'jeans', 'shirt', 'blouse', 'skirt', 'dress']
                has_competition = any(garment in between_text for garment in competing_garments 
                                    if garment not in anchor_patterns[0])  # Don't count target garment as competition
                
                if not has_competition:
                    return True
    
    return False

def extract_garment_specific_attributes(text: str, target_category: str, schema: Dict) -> Dict[str, List[str]]:
    """
    Extract attributes specific to the target garment type.
    """
    # Use the existing get_garment_scope function
    target_scope = get_garment_scope(target_category)
    
    # Split text into garment-specific segments
    segments = segment_text_by_garments(text, target_scope)
    
    # Extract only from relevant segments
    attributes = {}
    for attr_name, attr_config in schema.items():
        found_keywords = []
        for segment in segments:
            for keyword in attr_config["keywords"]:
                if is_attribute_linked_to_garment(keyword, segment, target_scope):
                    found_keywords.append(keyword)
        
        attributes[attr_name] = list(set(found_keywords))  # Remove duplicates
    
    return attributes

def segment_text_by_garments(text: str, target_scope: str) -> List[str]:
    """
    Improved segmentation that keeps more context while avoiding contamination.
    """
    sentences = [s.strip() for s in text.split('.') if s.strip()]
    
    if target_scope == 'full':
        return sentences
    
    relevant_segments = []
    garment_indicators = {
        'upper': ['shirt', 'blouse', 'top', 'jacket', 'coat', 'sweater', 't-shirt', 'cardigan', 'vest'],
        'lower': ['pants', 'jeans', 'shorts', 'skirt', 'leggings', 'trousers', 'legging']
    }
    
    for sentence in sentences:
        sentence_lower = sentence.lower()
        
        # Include sentences that mention target garment
        if any(indicator in sentence_lower for indicator in garment_indicators[target_scope]):
            relevant_segments.append(sentence)
        # For upper garments, include general description sentences that don't mention lower garments
        elif target_scope == 'upper' and not any(lower_ind in sentence_lower for lower_ind in garment_indicators['lower']):
            # Additional check: look for upper garment attributes
            upper_attrs = ['sleeve', 'neck', 'collar', 'button', 'zipper']
            if any(attr in sentence_lower for attr in upper_attrs):
                relevant_segments.append(sentence)
    
    return relevant_segments if relevant_segments else sentences


def is_attribute_linked_to_garment(keyword: str, segment: str, target_scope: str) -> bool:
    """
    Check if an attribute keyword is syntactically linked to the target garment.
    """
    segment_lower = segment.lower()
    keyword_lower = keyword.lower()
    
    # Find garment mentions in the segment
    garment_indicators = {
    'upper': ['shirt', 'blouse', 'top', 'jacket', 'coat', 'sweater', 't-shirt', 'cardigan'],
    'lower': ['pants', 'jeans', 'shorts', 'skirt', 'leggings', 'trousers'],
    'full': ['dress', 'jumpsuit', 'romper', 'overall']
    }
    
    # Look for explicit attribute-garment relationships
    # Pattern: "garment is/with/has [attribute]"
    for indicator in garment_indicators[target_scope]:
        if indicator in segment_lower:
            # Check for connecting words
            patterns = [
                f"{indicator}.*?{keyword_lower}",
                f"{keyword_lower}.*?{indicator}",
                f"{indicator}.*?(is|with|has).*?{keyword_lower}",
                f"{keyword_lower}.*?(is|with|has).*?{indicator}"
            ]
            for pattern in patterns:
                if re.search(pattern, segment_lower, re.IGNORECASE):
                    return True
    
    return False

def get_primary_attribute_smart(attr_list: List[str], text: str, attr_type: str) -> str:
    """
    Select primary attribute with text-position priority, not semantic interpretation.
    """
    if not attr_list:
        return ""
    
    if len(attr_list) == 1:
        return attr_list[0]
    
    text_lower = text.lower()
    
    # Priority: First occurrence in text, not semantic similarity
    for attr in attr_list:
        if attr.lower() in text_lower:
            # Find the earliest position of this attribute in text
            attr_pos = text_lower.find(attr.lower())
            if attr_pos != -1:
                return attr
    
    # Fallback to first in list
    return attr_list[0]

def apply_negative_constraints(attributes: Dict[str, List[str]], target_scope: str) -> Dict[str, List[str]]:
    """
    Apply negative constraints to prevent impossible attribute combinations.
    """
    filtered = attributes.copy()
    
    # Sleeve constraints for upper garments
    if target_scope == 'upper':
        sleeve_attrs = filtered.get('sleeve', [])
        # Prevent impossible combinations like sleeveless + long
        if 'sleeveless' in sleeve_attrs and any(long in sleeve_attrs for long in ['long', '3/4']):
            # Keep sleeveless, remove conflicting attributes
            filtered['sleeve'] = ['sleeveless']
    
    # Add more constraints as needed...
    
    return filtered



def get_primary_attribute(attr_list: List[str], priority_order: List[str] = None) -> str:
    """
    Select primary attribute from list (first match or based on priority).
    
    Args:
        attr_list: List of found attributes
        priority_order: Optional priority ordering
    
    Returns:
        Primary attribute string or empty string
    """
    if not attr_list:
        return ""
    
    if priority_order:
        for priority in priority_order:
            if priority in attr_list:
                return priority
    
    return attr_list[0]

def get_garment_scope(target_category: str) -> str:
    """
    Determine the garment scope based on category2 value.
    
    Args:
        target_category: The category2 value from the dataset (e.g., 'blouses')
    
    Returns:
        Garment scope: 'upper', 'lower', or 'full'
    """
    # Map category2 values to garment scopes
    GARMENT_TYPES = {
        'upper': [
            'blouses', 'shirts', 't-shirts', 'tops', 'jackets', 'coats', 
            'sweaters', 'cardigans', 'blazers', 'vests', 'hoodies'
        ],
        'lower': [
            'pants', 'jeans', 'shorts', 'skirts', 'leggings', 'trousers',
            'capris', 'culottes'
        ],
        'full': [
            'dresses', 'jumpsuits', 'rompers', 'overalls', 'gowns'
        ]
    }
    
    target_lower = target_category.lower()
    for scope, categories in GARMENT_TYPES.items():
        if target_lower in categories:
            return scope
    
    # Default fallback - most fashion items are upper body focused
    return 'upper'

def process_example_improved(example: Dict) -> Dict:
    """
    Process a single example to extract garment-specific attributes.
    
    Args:
        example: Dataset example with 'text' and 'category2' fields
    
    Returns:
        Example with added attribute fields
    """
    text = example.get("text", "")
    target_category = example.get("category2", "")
    
    # Use the new garment-specific extraction
    extracted = extract_garment_specific_attributes(text, target_category, ATTRIBUTE_SCHEMA)
    
    # Apply negative constraints
    extracted = apply_negative_constraints(extracted, get_garment_scope(target_category))
    
    # Add to example
    for attr_name in ATTRIBUTE_SCHEMA.keys():
        attr_key = f"attr_{attr_name}"
        primary_key = f"attr_{attr_name}_primary"
        
        example[attr_key] = extracted[attr_name]
        example[primary_key] = get_primary_attribute_smart(extracted[attr_name], text, attr_name)
    
    return example


# Test the improved extraction on a sample
sample_example = train_ds[0]
print("Sample text:", sample_example["text"])
print("Target category:", sample_example["category2"])
processed_sample = process_example_improved(sample_example.copy())
print("Extracted attributes:")
for attr_name in ATTRIBUTE_SCHEMA.keys():
    attr_key = f"attr_{attr_name}"
    primary_key = f"attr_{attr_name}_primary"
    print(f"  {attr_name}: {processed_sample[attr_key]} (primary: {processed_sample[primary_key]})")

Sample text: This woman wears a short-sleeve T-shirt with color block patterns and a long pants. The T-shirt is with cotton fabric and its neckline is crew. The pants are with cotton fabric and solid color patterns. This female is wearing a ring on her finger. There is an accessory on her wrist.
Target category: blouses
Extracted attributes:
  material: ['cotton'] (primary: cotton)
  pattern: [] (primary: )
  neckline: ['crew'] (primary: crew)
  sleeve: ['long', 'short'] (primary: long)
  closure: [] (primary: )


## Step 3: Process Sample Dataset

Process a small sample of the dataset to test and validate the extraction pipeline.

In [80]:
# Process a sample of the dataset for testing
sample_size = min(1000, len(train_ds))
print(f"Processing sample of {sample_size} examples...")

# Take a sample for testing
sample_indices = random.sample(range(len(train_ds)), sample_size)
sample_examples = [train_ds[i] for i in sample_indices]

# Process sample
processed_sample = []
for example in tqdm(sample_examples, desc="Processing sample"):
    processed_sample.append(process_example_improved(example.copy()))  # FIXED

# Convert to Dataset for analysis
sample_ds = Dataset.from_list(processed_sample)
print(f"Sample processing complete. Shape: {sample_ds.shape}")

# Basic statistics
print("\nAttribute extraction statistics on sample:")
for attr_name in ATTRIBUTE_SCHEMA.keys():
    attr_key = f"attr_{attr_name}"
    primary_key = f"attr_{attr_name}_primary"
    
    # Count examples with any attributes
    has_attr = sum(1 for ex in processed_sample if ex[attr_key])
    
    # Count examples with primary attribute
    has_primary = sum(1 for ex in processed_sample if ex[primary_key])
    
    print(f"{attr_name}: {has_attr}/{sample_size} examples with attributes ({has_primary} with primary)")

Processing sample of 1000 examples...


Processing sample: 100%|██████████| 1000/1000 [00:39<00:00, 25.61it/s]


Sample processing complete. Shape: (1000, 16)

Attribute extraction statistics on sample:
material: 673/1000 examples with attributes (673 with primary)
pattern: 572/1000 examples with attributes (572 with primary)
neckline: 147/1000 examples with attributes (147 with primary)
sleeve: 468/1000 examples with attributes (468 with primary)
closure: 0/1000 examples with attributes (0 with primary)


## Step 4: Analyze Extraction Quality

Analyze the quality of attribute extraction and identify common patterns or issues.

In [81]:
# Analyze extraction quality
def analyze_extraction_quality(processed_examples: List[Dict]) -> Dict:
    """
    Analyze the quality of attribute extraction.
    
    Args:
        processed_examples: List of processed examples with attributes
    
    Returns:
        Dictionary with quality metrics
    """
    total_examples = len(processed_examples)
    
    # Initialize counters
    attr_counts = defaultdict(Counter)
    primary_counts = defaultdict(Counter)
    coverage_stats = {}
    
    for attr_name in ATTRIBUTE_SCHEMA.keys():
        attr_key = f"attr_{attr_name}"
        primary_key = f"attr_{attr_name}_primary"
        
        # Count all extracted attributes
        all_attrs = []
        primary_attrs = []
        
        for ex in processed_examples:
            all_attrs.extend(ex[attr_key])
            if ex[primary_key]:
                primary_attrs.append(ex[primary_key])
        
        attr_counts[attr_name] = Counter(all_attrs)
        primary_counts[attr_name] = Counter(primary_attrs)
        
        # Coverage statistics
        examples_with_attr = sum(1 for ex in processed_examples if ex[attr_key])
        examples_with_primary = sum(1 for ex in processed_examples if ex[primary_key])
        
        coverage_stats[attr_name] = {
            "examples_with_attributes": examples_with_attr,
            "examples_with_primary": examples_with_primary,
            "coverage_percentage": examples_with_attr / total_examples * 100,
            "primary_coverage_percentage": examples_with_primary / total_examples * 100,
            "unique_attributes_found": len(attr_counts[attr_name]),
            "total_attribute_mentions": sum(attr_counts[attr_name].values())
        }
    
    return {
        "coverage_stats": coverage_stats,
        "attr_counts": dict(attr_counts),
        "primary_counts": dict(primary_counts),
        "total_examples": total_examples
    }

# Analyze the sample
quality_analysis = analyze_extraction_quality(processed_sample)

print("\nExtraction Quality Analysis:")
print("=" * 50)
for attr_name, stats in quality_analysis["coverage_stats"].items():
    print(f"\n{attr_name.upper()}:")
    print(f"  Coverage: {stats['examples_with_attributes']}/{quality_analysis['total_examples']} examples ({stats['coverage_percentage']:.1f}%)")
    print(f"  Primary: {stats['examples_with_primary']}/{quality_analysis['total_examples']} examples ({stats['primary_coverage_percentage']:.1f}%)")
    print(f"  Unique attributes found: {stats['unique_attributes_found']}")
    print(f"  Total mentions: {stats['total_attribute_mentions']}")
    
    # Show top attributes
    top_attrs = quality_analysis["primary_counts"][attr_name].most_common(5)
    if top_attrs:
        print(f"  Top primary attributes: {', '.join([f'{attr}({count})' for attr, count in top_attrs])}")


Extraction Quality Analysis:

MATERIAL:
  Coverage: 673/1000 examples (67.3%)
  Primary: 673/1000 examples (67.3%)
  Unique attributes found: 4
  Total mentions: 673
  Top primary attributes: cotton(553), chiffon(75), knit(34), denim(11)

PATTERN:
  Coverage: 572/1000 examples (57.2%)
  Primary: 572/1000 examples (57.2%)
  Unique attributes found: 7
  Total mentions: 758
  Top primary attributes: solid color(180), graphic(179), pure color(155), stripe(34), floral(22)

NECKLINE:
  Coverage: 147/1000 examples (14.7%)
  Primary: 147/1000 examples (14.7%)
  Unique attributes found: 4
  Total mentions: 168
  Top primary attributes: crew(63), round(55), suspenders(29)

SLEEVE:
  Coverage: 468/1000 examples (46.8%)
  Primary: 468/1000 examples (46.8%)
  Unique attributes found: 3
  Total mentions: 490
  Top primary attributes: long(259), short(157), sleeveless(52)

CLOSURE:
  Coverage: 0/1000 examples (0.0%)
  Primary: 0/1000 examples (0.0%)
  Unique attributes found: 0
  Total mentions: 0


## Step 5: Iterative Refinement

Review extraction results and refine the attribute schema or extraction logic as needed.

In [82]:
# Manual review of extraction results
def review_extractions(processed_examples: List[Dict], num_samples: int = 10) -> None:
    """
    Manually review extraction results for quality assessment.
    
    Args:
        processed_examples: List of processed examples
        num_samples: Number of samples to review
    """
    print(f"\nManual Review of {num_samples} Extraction Results:")
    print("=" * 60)
    
    # Random sample for review
    review_indices = random.sample(range(len(processed_examples)), min(num_samples, len(processed_examples)))
    
    for i, idx in enumerate(review_indices):
        ex = processed_examples[idx]
        print(f"\nExample {i+1}:")
        print(f"Text: {ex['text'][:300]}...")
        print("Extracted attributes:")
        
        for attr_name in ATTRIBUTE_SCHEMA.keys():
            attr_key = f"attr_{attr_name}"
            primary_key = f"attr_{attr_name}_primary"
            attrs = ex[attr_key]
            primary = ex[primary_key]
            if attrs or primary:
                print(f"  {attr_name}: {attrs} (primary: {primary})")
        
        print("-" * 60)

# Review sample extractions
review_extractions(processed_sample, num_samples=5)


Manual Review of 5 Extraction Results:

Example 1:
Text: Her sweater has long sleeves, cotton fabric and pure color patterns. The neckline of it is stand. This lady wears a three-point pants. The pants are with cotton fabric and floral patterns. The person is wearing leggings....
Extracted attributes:
  material: ['cotton'] (primary: cotton)
  pattern: ['floral'] (primary: floral)
------------------------------------------------------------

Example 2:
Text: The T-shirt this gentleman wears has short sleeves and it is with cotton fabric and striped patterns. The neckline of the T-shirt is lapel....
Extracted attributes:
  material: ['cotton'] (primary: cotton)
  pattern: ['stripe'] (primary: stripe)
  sleeve: ['short'] (primary: short)
------------------------------------------------------------

Example 3:
Text: The person wears a sleeveless tank shirt with pure color patterns. The tank shirt is with cotton fabric. The pants the person wears is of three-point length. The pants are wi

## Step 6: Optimize Extraction Logic

Based on the analysis, optimize the extraction logic and attribute schema.

In [58]:
# Optimized extraction with improved patterns
def extract_attributes_optimized(text: str, schema: Dict) -> Dict[str, List[str]]:
    """
    Optimized attribute extraction with better pattern matching.
    
    Args:
        text: Product description text
        schema: Attribute schema with keywords
    
    Returns:
        Dictionary mapping attribute names to lists of found keywords
    """
    text_lower = text.lower()
    attributes = {}
    
    for attr_name, attr_config in schema.items():
        found_keywords = []
        
        for keyword in attr_config["keywords"]:
            # More flexible matching - handle compound words and variations
            keyword_lower = keyword.lower()
            
            # Exact word match
            pattern = r'\b' + re.escape(keyword_lower) + r'\b'
            if re.search(pattern, text_lower):
                found_keywords.append(keyword)
                continue
            
            # Handle compound keywords with spaces
            if ' ' in keyword_lower:
                # Allow some flexibility in spacing
                flexible_pattern = r'\b' + re.escape(keyword_lower).replace(' ', r'\s+') + r'\b'
                if re.search(flexible_pattern, text_lower):
                    found_keywords.append(keyword)
                    continue
            
            # Handle common variations (e.g., "t-shirt" for "tee")
            if attr_name == "material" and keyword_lower == "cotton":
                if any(term in text_lower for term in ["cotton", "cotton fabric"]):
                    found_keywords.append(keyword)
            
        # Remove duplicates while preserving order
        seen = set()
        unique_keywords = []
        for keyword in found_keywords:
            if keyword not in seen:
                unique_keywords.append(keyword)
                seen.add(keyword)
        
        attributes[attr_name] = unique_keywords
    
    return attributes

# Update the process function to use optimized extraction
def process_example_optimized(example: Dict) -> Dict:
    """
    Process example with optimized attribute extraction.
    """
    text = example.get("text", "")
    
    # Extract attributes with optimized function
    extracted = extract_attributes_optimized(text, ATTRIBUTE_SCHEMA)
    
    # Add to example
    for attr_name in ATTRIBUTE_SCHEMA.keys():
        attr_key = f"attr_{attr_name}"
        primary_key = f"attr_{attr_name}_primary"
        
        example[attr_key] = extracted[attr_name]
        example[primary_key] = get_primary_attribute(extracted[attr_name])
    
    return example

# Test optimized extraction on same sample
print("Testing optimized extraction...")
optimized_sample = [process_example_optimized(ex.copy()) for ex in sample_examples]

# Compare results
optimized_quality = analyze_extraction_quality(optimized_sample)
print("\nOptimized Extraction Results:")
print("=" * 40)
for attr_name in ATTRIBUTE_SCHEMA.keys():
    old_stats = quality_analysis["coverage_stats"][attr_name]
    new_stats = optimized_quality["coverage_stats"][attr_name]
    
    improvement = new_stats["coverage_percentage"] - old_stats["coverage_percentage"]
    print(f"{attr_name}: {new_stats['coverage_percentage']:.1f}% (was {old_stats['coverage_percentage']:.1f}%, +{improvement:.1f}%)")

Testing optimized extraction...

Optimized Extraction Results:
material: 94.9% (was 94.9%, +0.0%)
pattern: 75.5% (was 75.5%, +0.0%)
neckline: 21.1% (was 21.1%, +0.0%)
sleeve: 71.7% (was 71.7%, +0.0%)
closure: 0.0% (was 0.0%, +0.0%)


## Step 7: Save Enhanced Dataset

What this block does:
Provides code to process the full dataset (commented out for safety)
Saves the enhanced dataset with extracted attributes
Shows final summary and example extraction
Key things to notice:
Processes full dataset only when ready (commented by default)
Saves in HuggingFace dataset format for easy loading
Provides final verification of extraction quality

In [None]:
# Step 7: Save Enhanced Dataset

def save_enhanced_dataset():
    """
    Process the full dataset and save with extracted attributes.
    COMMENTED OUT FOR SAFETY - Uncomment only when ready to process full dataset.
    """
    print("WARNING: This will process the entire dataset and may take significant time and resources.")
    print("Only run this when you're satisfied with the extraction quality.")
    
    # Uncomment the following lines when ready:
    
    # print("Processing full dataset...")
    # full_processed = []
    # for example in tqdm(train_ds, desc="Processing full dataset"):
    #     processed_example = process_example_optimized(example.copy())
    #     full_processed.append(processed_example)
    
    # # Convert to Dataset
    # enhanced_ds = Dataset.from_list(full_processed)
    
    # # Save to disk
    # output_path = "../data/processed/hf_with_attributes"
    # os.makedirs(output_path, exist_ok=True)
    # enhanced_ds.save_to_disk(output_path)
    # print(f"Enhanced dataset saved to: {output_path}")
    
    # return enhanced_ds
    
    return None

# For demonstration, save the sample dataset instead
def save_sample_enhanced_dataset():
    """
    Save the processed sample dataset for testing.
    """
    print("Saving enhanced sample dataset...")
    
    # Create sample dataset
    enhanced_sample_ds = Dataset.from_list(optimized_sample)
    
    # Save to disk
    output_path = "../data/processed/hf_with_attributes_sample"
    os.makedirs(output_path, exist_ok=True)
    enhanced_sample_ds.save_to_disk(output_path)
    print(f"Enhanced sample dataset saved to: {output_path}")
    
    return enhanced_sample_ds

# Save sample dataset
enhanced_sample = save_sample_enhanced_dataset()

# Final summary and example extraction
print("\nFinal Summary:")
print("=" * 50)
print(f"Sample dataset processed: {len(enhanced_sample)} examples")
print("\nAttribute extraction coverage:")

for attr_name, stats in optimized_quality["coverage_stats"].items():
    print(f"  {attr_name}: {stats['coverage_percentage']:.1f}% ({stats['examples_with_attributes']}/{optimized_quality['total_examples']} examples)")

# Show example extraction
print("\nExample Extraction:")
print("-" * 30)
example_idx = 0
example = enhanced_sample[example_idx]
print(f"Text: {example['text'][:150]}...")
print("\nExtracted attributes:")
for attr_name in ATTRIBUTE_SCHEMA.keys():
    attr_key = f"attr_{attr_name}"
    primary_key = f"attr_{attr_name}_primary"
    attrs = example[attr_key]
    primary = example[primary_key]
    if attrs:
        print(f"  {attr_name}: {attrs} (primary: {primary})")
    else:
        print(f"  {attr_name}: None")

print("\n✅ Enhanced dataset ready for use!")
print("Next steps:")
print("1. Review extraction quality in validation step")
print("2. Uncomment and run full dataset processing when ready")
print("3. Use enhanced dataset for training similarity models")

## Step 8: Validation and Error Analysis

What this block does:
Provides a framework for manual validation of extraction accuracy
Randomly samples examples for human review
Tracks validation metrics and issues
Key things to notice:
Uses reproducible random sampling
Provides clear review guidelines for manual validation
Tracks different types of errors (missing, incorrect, hallucinated)

In [None]:
# Step 8: Validation and Error Analysis

class AttributeValidator:
    """
    Framework for validating attribute extraction accuracy.
    """
    
    def __init__(self, dataset, sample_size: int = 100, random_seed: int = 42):
        """
        Initialize validator.
        
        Args:
            dataset: Dataset with extracted attributes
            sample_size: Number of examples to sample for validation
            random_seed: Random seed for reproducible sampling
        """
        self.dataset = dataset
        self.sample_size = min(sample_size, len(dataset))
        self.random_seed = random_seed
        self.validation_results = []
        
        # Sample examples for validation
        random.seed(random_seed)
        self.sample_indices = random.sample(range(len(dataset)), self.sample_size)
        
    def get_validation_sample(self) -> List[Dict]:
        """
        Get the validation sample for manual review.
        
        Returns:
            List of examples for validation
        """
        return [self.dataset[i] for i in self.sample_indices]
    
    def display_validation_guidelines(self):
        """
        Display guidelines for manual validation.
        """
        print("\nVALIDATION GUIDELINES")
        print("=" * 50)
        print("For each example, evaluate the extracted attributes:")
        print("\n1. MATERIAL:")
        print("   - Check if fabric types (cotton, denim, etc.) are correctly identified")
        print("   - Mark as MISSING if material mentioned but not extracted")
        print("   - Mark as INCORRECT if wrong material extracted")
        print("   - Mark as HALLUCINATED if material extracted but not in text")
        print("\n2. PATTERN:")
        print("   - Verify patterns (stripe, floral, solid, etc.) are accurate")
        print("   - Pay attention to color-block vs solid color distinctions")
        print("\n3. NECKLINE/SLEEVE:")
        print("   - Check if garment style attributes are correctly identified")
        print("   - Look for crew neck, v-neck, short/long sleeves, etc.")
        print("\n4. CLOSURE:")
        print("   - Verify buttons, zippers, snaps are correctly detected")
        print("\nRATING SCALE:")
        print("   CORRECT: All extracted attributes are accurate")
        print("   MOSTLY_CORRECT: >75% of attributes are accurate")
        print("   PARTIALLY_CORRECT: 25-75% of attributes are accurate")
        print("   POOR: <25% of attributes are accurate")
        print("\nRecord your assessment for each example.")
    
    def validate_sample_manually(self) -> Dict:
        """
        Interactive manual validation of the sample.
        
        Returns:
            Validation results dictionary
        """
        sample = self.get_validation_sample()
        results = []
        
        print(f"\nStarting manual validation of {len(sample)} examples...")
        print("(This is a framework - in practice, you'd have human reviewers)")
        
        # For demonstration, we'll do automated validation
        # In real scenario, this would be manual
        for i, example in enumerate(sample):
            print(f"\nExample {i+1}/{len(sample)}:")
            print(f"Text: {example['text'][:200]}...")
            
            # Display extracted attributes
            extracted_attrs = {}
            for attr_name in ATTRIBUTE_SCHEMA.keys():
                attr_key = f"attr_{attr_name}"
                primary_key = f"attr_{attr_name}_primary"
                attrs = example[attr_key]
                primary = example[primary_key]
                if attrs:
                    extracted_attrs[attr_name] = attrs
                    print(f"  {attr_name}: {attrs} (primary: {primary})")
            
            # Automated validation (in real scenario, this would be manual)
            validation_result = self._automated_validation(example, extracted_attrs)
            results.append(validation_result)
            
            # Show validation result
            print(f"  Validation: {validation_result['overall_rating']}")
            if validation_result['issues']:
                print(f"  Issues: {', '.join(validation_result['issues'])}")
            
            # Stop after a few examples for demonstration
            if i >= 4:
                print("\n... (showing first 5 examples for demonstration)")
                break
        
        return self._analyze_validation_results(results)
    
    def _automated_validation(self, example: Dict, extracted_attrs: Dict) -> Dict:
        """
        Automated validation for demonstration (in practice, this would be manual).
        """
        text = example['text'].lower()
        issues = []
        correct_count = 0
        total_attrs = len(extracted_attrs)
        
        for attr_name, attrs in extracted_attrs.items():
            if not attrs:
                continue
                
            # Check if extracted attributes appear in text
            for attr in attrs:
                if attr.lower() not in text:
                    issues.append(f"hallucinated_{attr_name}_{attr}")
                else:
                    correct_count += 1
        
        # Determine overall rating
        if total_attrs == 0:
            rating = "NO_ATTRIBUTES"
        elif issues:
            rating = "ISSUES_FOUND"
        else:
            rating = "APPEARS_CORRECT"
        
        return {
            "example_id": example.get("item_ID", "unknown"),
            "overall_rating": rating,
            "issues": issues,
            "correct_attributes": correct_count,
            "total_attributes": total_attrs
        }
    
    def _analyze_validation_results(self, results: List[Dict]) -> Dict:
        """
        Analyze validation results and compute metrics.
        
        Args:
            results: List of validation result dictionaries
        
        Returns:
            Analysis summary
        """
        if not results:
            return {}
        
        total_examples = len(results)
        ratings = Counter(r['overall_rating'] for r in results)
        
        # Collect all issues
        all_issues = []
        for r in results:
            all_issues.extend(r['issues'])
        
        issue_counts = Counter(all_issues)
        
        # Calculate accuracy metrics
        total_attributes = sum(r['total_attributes'] for r in results)
        correct_attributes = sum(r['correct_attributes'] for r in results)
        
        accuracy = correct_attributes / total_attributes if total_attributes > 0 else 0
        
        return {
            "total_examples_validated": total_examples,
            "rating_distribution": dict(ratings),
            "common_issues": dict(issue_counts.most_common(10)),
            "attribute_accuracy": accuracy,
            "total_attributes_checked": total_attributes,
            "correct_attributes": correct_attributes
        }

# Initialize validator
validator = AttributeValidator(enhanced_sample, sample_size=100)

# Display guidelines
validator.display_validation_guidelines()

# Run validation
validation_results = validator.validate_sample_manually()

# Display validation summary
print("\nVALIDATION SUMMARY")
print("=" * 30)
print(f"Examples validated: {validation_results['total_examples_validated']}")
print(f"Attribute accuracy: {validation_results['attribute_accuracy']:.1%}")
print(f"Total attributes checked: {validation_results['total_attributes_checked']}")
print(f"Correct attributes: {validation_results['correct_attributes']}")

print("\nRating distribution:")
for rating, count in validation_results['rating_distribution'].items():
    pct = count / validation_results['total_examples_validated'] * 100
    print(f"  {rating}: {count} ({pct:.1f}%)")

print("\nMost common issues:")
for issue, count in list(validation_results['common_issues'].items())[:5]:
    print(f"  {issue}: {count}")

# Recommendations based on validation
accuracy = validation_results['attribute_accuracy']
if accuracy > 0.8:
    print("\n✅ High accuracy! Ready for production use.")
elif accuracy > 0.6:
    print("\n⚠️ Moderate accuracy. Consider refining extraction rules.")
else:
    print("\n❌ Low accuracy. Significant improvements needed.")
    print("   Consider:")
    print("   - Adding more keywords to schema")
    print("   - Implementing more sophisticated matching")
    print("   - Using machine learning for extraction")

print("\nNext steps:")
print("1. Address identified issues")
print("2. Re-run validation after improvements")
print("3. Proceed to model training when accuracy meets requirements")