In [None]:
import os
import json
import pandas as pd
import numpy as np
from sentence_transformers import SentenceTransformer
import torch
import torch.nn as nn
import re
import h5py
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm

class FigmaHTMLFeatureExtractor:
    """Turn Figma-style JSON node trees into per-node feature vectors and labels.

    Each JSON file is walked depth-first (pre-order). Every node yields one
    feature vector made of a node-type embedding, a sentence embedding of TEXT
    content, a sentence embedding of icon-like node names, and a handful of
    numeric layout/style features. ``CONTAINER`` nodes are bracketed by
    ``B_CONTAINER`` / ``E_CONTAINER`` entries and every file is terminated by an
    ``E_WEBSITE`` entry.
    """

    # Output formats understood by ``_save_features``.
    SUPPORTED_OUTPUT_FORMATS = ("hdf5", "csv", "parquet")
    # Fallback when the sentence model does not report its embedding size
    # (384 is the output size of 'all-MiniLM-L6-v2').
    _DEFAULT_SENTENCE_EMBEDDING_DIM = 384

    def __init__(
        self,
        semantic_model_name: str = 'all-MiniLM-L6-v2',
        node_type_embedding_dim: int = 50,
        output_format: str = 'hdf5',  # 'hdf5', 'csv' or 'parquet'
        seed: Optional[int] = None,
    ):
        """Load the sentence encoder and set up embeddings, tag mapping and stats.

        Args:
            semantic_model_name: SentenceTransformer model used for both text
                content and icon-name embeddings.
            node_type_embedding_dim: Size of the node-type embedding.
            output_format: One of ``SUPPORTED_OUTPUT_FORMATS``.
            seed: If given, the node-type embedding is initialised from a
                private seeded generator so repeated runs produce identical
                features. ``None`` keeps the previous (global-RNG) behaviour.

        Raises:
            ValueError: if ``output_format`` is not supported.
        """
        # Validate cheap arguments before loading the (slow) sentence model.
        if output_format not in self.SUPPORTED_OUTPUT_FORMATS:
            raise ValueError(
                f"Unsupported output_format {output_format!r}; "
                f"expected one of {self.SUPPORTED_OUTPUT_FORMATS}"
            )

        # Models and embeddings configuration
        self.semantic_model = SentenceTransformer(semantic_model_name)
        # Ask the model for its dimension so the zero-padding used for nodes
        # without text/name always matches the encoder output length.
        model_dim = self.semantic_model.get_sentence_embedding_dimension()
        self.text_embedding_dim = model_dim or self._DEFAULT_SENTENCE_EMBEDDING_DIM
        self.node_name_embedding_dim = self.text_embedding_dim  # same model encodes node names

        # Initialize node type embedding
        self.node_types = self._get_node_types()
        self.node_type_to_idx = {node_type: idx for idx, node_type in enumerate(self.node_types)}
        self.node_type_embedding_dim = node_type_embedding_dim
        self.node_type_embedding_layer = nn.Embedding(len(self.node_types), self.node_type_embedding_dim)
        if seed is not None:
            # nn.Embedding draws its weights from N(0, 1) using the global RNG,
            # so features otherwise differ between runs. Re-draw them from the
            # same distribution with a private seeded generator.
            generator = torch.Generator().manual_seed(seed)
            with torch.no_grad():
                self.node_type_embedding_layer.weight.copy_(
                    torch.randn(len(self.node_types), self.node_type_embedding_dim, generator=generator)
                )

        # Tag mapping and cleaning configuration
        self.tag_mapping = self._get_tag_mapping()
        self.custom_tag_removal_pattern = self._get_custom_tag_removal_pattern()
        self.default_tag = "DIV"
        self.icon_like_node_types = {"VECTOR", "INSTANCE", "COMPONENT", "SHAPE", "SVG_ICON"}

        # Output configuration
        self.output_format = output_format

        # Counters for statistics
        self.stats = {
            "files_processed": 0,
            "nodes_processed": 0,
            "json_errors": 0,
            "tag_mappings": {},
            "unique_node_types": set()
        }

    def _get_node_types(self) -> List[str]:
        """Define the node types for embedding (unknown types map to UNKNOWN_TYPE)."""
        return [
            "TEXT", "RECTANGLE", "GROUP", "ELLIPSE", "FRAME", "VECTOR", "STAR", "LINE", 
            "POLYGON", "BOOLEAN_OPERATION", "SLICE", "COMPONENT", "INSTANCE", "COMPONENT_SET", 
            "DOCUMENT", "CANVAS", "SECTION", "SHAPE_WITH_TEXT", "STICKY", "TABLE", "WASHI_TAPE", 
            "CONNECTOR", "HIGHLIGHT", "WIDGET", "EMBED", "LINK", "LINK_UNFURL", "MEDIA", "CODE_BLOCK", 
            "STAMP", "COMMENT", "FREEFORM", "TIMELINE", "STICKER", "SHAPE", "ARROW", "CALL_OUT", 
            "FLOW", "TEXT_AREA", "TEXT_FIELD", "BUTTON", "CHECKBOX", "RADIO", "TOGGLE", "SLIDER", 
            "DROPDOWN", "COMBOBOX", "LIST", "TABLE_CELL", "TABLE_ROW", "TABLE_COLUMN", "TABLE_SECTION", 
            "TABLE_HEADER", "TABLE_FOOTER", "TABLE_BODY", "TABLE_CAPTION", "TABLE_COLGROUP", "TABLE_COL", 
            "TABLE_THEAD", "TABLE_TBODY", "TABLE_TFOOT", "TABLE_TR", "TABLE_TH", "TABLE_TD", 
            "UNKNOWN_TYPE"
        ]

    def _get_tag_mapping(self) -> Dict[str, str]:
        """Define mapping for tag consolidation (upper-case raw tag -> canonical tag)."""
        return {
            "ARTICLE": "DIV", "DIV": "DIV", "FIGURE": "DIV", "FOOTER": "DIV", "HEADER": "DIV", "NAV": "DIV", "MAIN": "DIV", "IFRAME": "DIV",
            "BODY" : "DIV", "FORM" : "DIV", "TABLE": "DIV", "THEAD":"DIV" , "TBODY": "DIV", "SECTION": "DIV","ASIDE":"DIV", 
            
            "UL" : "LIST", "OL" : "LIST", "DL": "LIST",
            
            "H1": "P", "H2": "P", "H3": "P", "H4": "P", "H5": "P", "H6": "P","SUP": "P","SUB": "P", "BIG": "P",
            "P": "P", "CAPTION": "P", "FIGCAPTION": "P", "B": "P", "EM": "P", "I": "P", "TD": "P", "TH": "P", "TR": "P","PRE":"P",
            "U": "P", "TIME": "P", "TXT": "P", "ABBR": "P","SMALL": "P","STRONG": "P","SUMMARY": "P","SPAN": "P", "LABEL": "P","LI":"P","DD":"P",
            "A":"P","BLOCKQUOTE":"P","CODE":"P",
            
            "PICTURE": "IMG" , "VIDEO": "IMG",
            "SELECT": "INPUT","TEXTAREA": "INPUT",
            "VECTOR": "SVG","ICON":"SVG",
            
            "UNK": "CONTAINER"
        }

    def _get_custom_tag_removal_pattern(self) -> str:
        """Define regex pattern for removing problematic tags.

        Any tag containing '-' or ':' (e.g. custom elements) or equal to one of
        the listed words is replaced by the default tag.
        NOTE(review): 'NOCSRIPT' and 'STONG' look like misspellings; confirm
        whether they are intentional matches for misspelled tags in the data.
        """
        return r'[-:]|\b(DETAILS|CANVAS|FIELDSET|COLGROUP|COL|CNX|ADDRESS|CITE|S|DEL|LEGEND|BDI|LOGO|OBJECT|OPTGROUP|CENTER|FRONT|Q|SEARCH|SLOT|AD|ADSLOT|BLINK|BOLD|COMMENTS|DATA|DIALOG|EMBED|EMPHASIS|FONT|H7|HGROUP|INS|INTERACTION|ITALIC|ITEMTEMPLATE|MATH|MENU|MI|MN|MO|MROW|MSUP|NOBR|OFFER|PATH|PROGRESS|STRIKE|SWAL|TEXT|TITLE|TT|VAR|VEV|W|WBR|COUNTRY|ESI:INCLUDE|HTTPS:|LOGIN|NOCSRIPT|PERSONAL|STONG|CONTENT|DELIVERY|LEFT|MSUBSUP|KBD|ROOT|PARAGRAPH|BE|AI2SVELTEWRAP|BANNER|PHOTO1)\b'

    def clean_and_map_tag(self, raw_tag: str) -> str:
        """Clean and map a raw HTML tag to a canonical form.

        The tag is upper-cased and mapped through ``tag_mapping``. If the mapped
        value matches the removal pattern, it is replaced by ``default_tag``.
        The result is then mapped once more. Every tag whose final form differs
        from its upper-cased input is counted in ``stats["tag_mappings"]``.
        Empty tags return ``default_tag`` directly.
        """
        if not raw_tag:
            return self.default_tag
        raw_tag = raw_tag.upper()
        cleaned_tag = self.tag_mapping.get(raw_tag, raw_tag)
        # The re module caches compiled patterns, so this does not recompile per call.
        if re.search(self.custom_tag_removal_pattern, cleaned_tag, re.IGNORECASE):
            cleaned_tag = self.default_tag
        final_tag = self.tag_mapping.get(cleaned_tag, cleaned_tag)
        if final_tag != raw_tag:
            mappings = self.stats["tag_mappings"]
            mappings[raw_tag] = mappings.get(raw_tag, 0) + 1
        return final_tag

    def determine_bioes_label(self, base_tag: str) -> str:
        """Return the sequence label for a canonical tag.

        ``CONTAINER`` opens a span and is labelled ``B_CONTAINER``; the matching
        ``E_CONTAINER`` is emitted by ``extract_features``. Every other tag is
        used verbatim as its own label.
        """
        return "B_CONTAINER" if base_tag == "CONTAINER" else base_tag

    @staticmethod
    def _to_float(value, default: float = 0.0) -> float:
        """Convert a JSON value to float, returning *default* for null/non-numeric values."""
        if value is None:
            return default
        try:
            return float(value)
        except (TypeError, ValueError):
            return default

    @classmethod
    def _fill_color(cls, fills) -> List[float]:
        """Return the RGBA channels of the first fill, or zeros when it has no colour.

        Uses ``.get`` rather than indexing: a fill without a ``color`` key
        (presumably image/gradient paints) used to raise KeyError and drop the
        whole file.
        """
        no_color = [0.0, 0.0, 0.0, 0.0]
        if not isinstance(fills, list) or not fills or not isinstance(fills[0], dict):
            return no_color
        color = fills[0].get("color")
        if not color or not isinstance(color, dict):
            return no_color
        return [cls._to_float(color.get(channel), 0.0) for channel in ("r", "g", "b", "a")]

    def _node_type_embedding(self, node_type_str: str) -> np.ndarray:
        """Return the embedding row for a node type (UNKNOWN_TYPE for unseen types)."""
        node_type_idx = self.node_type_to_idx.get(
            node_type_str, self.node_type_to_idx.get("UNKNOWN_TYPE", 0)
        )
        # Index the detached weight directly: same values as calling the layer,
        # without building an autograd graph for every node.
        return self.node_type_embedding_layer.weight.detach()[node_type_idx].cpu().numpy().copy()

    def _node_feature_vector(self,
                             node_data_item: Dict,
                             current_body_width: float,
                             parent_node_height: Optional[float],
                             depth: int,
                             position_in_siblings: int,
                             total_siblings: int) -> np.ndarray:
        """Build the flat feature vector for a single node (children are not visited).

        Layout: [node-type emb | text emb | node-name emb | 13 numeric features].
        Also records the node's type in ``stats["unique_node_types"]``.
        """
        node_dict = node_data_item.get("node") or {}

        # Node type embedding
        node_type_str = node_dict.get("type") or "UNKNOWN_TYPE"
        self.stats["unique_node_types"].add(node_type_str)
        node_type_emb = self._node_type_embedding(node_type_str)

        # Text embedding: only TEXT nodes with non-blank characters are encoded.
        text_content = str(node_dict.get("characters") or "").strip()
        if node_type_str == "TEXT" and text_content:
            text_emb = self.semantic_model.encode(text_content)
        else:
            text_emb = np.zeros(self.text_embedding_dim)

        # Node name embedding: only for icon-like node types or names mentioning "icon".
        node_name = str(node_data_item.get("name") or "").strip()
        is_icon_like = node_type_str in self.icon_like_node_types or "icon" in node_name.lower()
        if node_name and is_icon_like:
            node_name_emb = self.semantic_model.encode(node_name)
        else:
            node_name_emb = np.zeros(self.node_name_embedding_dim)

        # Numerical & structural features. x/width are normalised by the page
        # body width, y/height by the parent node's height.
        eps = 1e-6
        has_body_width = current_body_width > 0
        has_parent_height = bool(parent_node_height) and parent_node_height > 0

        node_width = self._to_float(node_dict.get("width"))
        node_height = self._to_float(node_dict.get("height"))
        x_position = self._to_float(node_dict.get("x"))
        y_position = self._to_float(node_dict.get("y"))

        aspect_ratio = node_width / (node_height + eps) if node_height > 0 else 0
        normalized_width = node_width / (current_body_width + eps) if has_body_width else 0
        normalized_x = x_position / (current_body_width + eps) if has_body_width else 0
        normalized_height = node_height / (parent_node_height + eps) if has_parent_height else 0
        normalized_y = y_position / (parent_node_height + eps) if has_parent_height else 0

        normalized_depth = min(depth / 20.0, 1.0)  # depth scaled by 20 and clipped at 1
        normalized_position = position_in_siblings / (total_siblings + eps)

        bg_color = self._fill_color(node_dict.get("fills"))
        font_size = self._to_float(node_dict.get("fontSize")) / 100.0
        flex_direction = 1 if node_dict.get("flexDirection", "") == "column" else 0

        return np.concatenate([
            node_type_emb,
            text_emb,
            node_name_emb,
            [normalized_width, normalized_height, aspect_ratio,
             normalized_x, normalized_y,
             normalized_depth, normalized_position,
             *bg_color, font_size, flex_direction]
        ])

    def extract_features(self, 
                        node_data_item: Dict, 
                        current_body_width: float,
                        sequence_id: str,
                        parent_node_height: Optional[float] = None,
                        parent_base_tag: Optional[str] = None,
                        depth: int = 0,
                        position_in_siblings: int = 0,
                        total_siblings: int = 1) -> List[Dict]:
        """
        Extract features from a node and its children recursively (pre-order).

        Returns a list of ``{"feature_vector": np.ndarray, "tag": str}`` dicts:
        the node itself, then all of its descendants, then, for ``B_CONTAINER``
        nodes, an all-zeros ``E_CONTAINER`` entry. The closing entry is emitted
        even when the container has no children, so B/E markers always balance.

        ``sequence_id`` and ``parent_base_tag`` are passed down the recursion
        but are not otherwise used.
        """
        node_dict = node_data_item.get("node") or {}
        raw_tag = node_data_item.get("tag", "UNK")
        # A JSON null tag is treated like a missing one.
        raw_tag = "UNK" if raw_tag is None else str(raw_tag).upper()

        base_tag = self.clean_and_map_tag(raw_tag)
        bioes_label = self.determine_bioes_label(base_tag)

        feature_vector = self._node_feature_vector(
            node_data_item,
            current_body_width,
            parent_node_height,
            depth,
            position_in_siblings,
            total_siblings,
        )
        features_and_labels_list = [{"feature_vector": feature_vector, "tag": bioes_label}]
        self.stats["nodes_processed"] += 1

        # Process children if any; they are normalised against this node's height.
        children = node_data_item.get("children") or []
        node_height = self._to_float(node_dict.get("height"))
        total_children = len(children)
        for child_idx, child_node in enumerate(children):
            features_and_labels_list.extend(self.extract_features(
                node_data_item=child_node,
                current_body_width=current_body_width,
                sequence_id=sequence_id,
                parent_node_height=node_height,
                parent_base_tag=base_tag,
                depth=depth + 1,
                position_in_siblings=child_idx,
                total_siblings=total_children
            ))

        # Close every B_CONTAINER, including leaf containers.
        if bioes_label == "B_CONTAINER":
            features_and_labels_list.append({
                "feature_vector": np.zeros_like(feature_vector),
                "tag": "E_CONTAINER"
            })

        return features_and_labels_list

    def process_file(self, file_path: str) -> Optional[List[Dict]]:
        """Extract the feature sequence of one JSON file, terminated by ``E_WEBSITE``.

        Returns ``None`` (and increments ``stats["json_errors"]``) if the file
        cannot be read or processed. The error is printed rather than raised
        so one bad file does not abort a directory run.
        """
        try:
            with open(file_path, "r", encoding="utf-8") as f:
                json_data = json.load(f)
            # splitext strips only the final extension (str.replace removed every ".json").
            sequence_id = os.path.splitext(os.path.basename(file_path))[0]
            root_node_info = json_data.get("node") or {}
            body_width = self._to_float(root_node_info.get("width"), 1000.0) or 1000.0
            features = self.extract_features(
                node_data_item=json_data,
                current_body_width=body_width,
                sequence_id=sequence_id
            )

            # specify end of website with an all-ones marker vector
            features.append({
                "feature_vector": np.ones_like(features[0]["feature_vector"]),
                "tag": "E_WEBSITE",
            })
            return features
        except Exception as e:  # deliberate batch boundary: report and skip the file
            print(f"Error processing {file_path}: {str(e)}")
            self.stats["json_errors"] += 1
            return None

    def process_directory(self, input_dir: str, output_path: str) -> None:
        """Process and save features from all JSON files incrementally.

        Files are processed in sorted name order so the output row order is
        reproducible. The first successfully processed file overwrites
        *output_path*; later files are appended.
        """
        if not os.path.isdir(input_dir):
            print(f"Error: Input directory '{input_dir}' does not exist.")
            return
        json_files = sorted(f for f in os.listdir(input_dir) if f.endswith(".json"))
        if not json_files:
            print(f"No JSON files found in {input_dir}")
            return
        print(f"Found {len(json_files)} JSON files to process.")

        first_file = True  # first successful file overwrites, the rest append

        for file_name in tqdm(json_files, desc="Processing files"):
            file_path = os.path.join(input_dir, file_name)
            features = self.process_file(file_path)
            if features:
                self._save_features(features, output_path, append=not first_file)
                self.stats["files_processed"] += 1
                first_file = False

        self._print_stats()

    def _save_features(self, features: List[Dict], output_path: str, append: bool = False) -> None:
        """Save extracted features to *output_path*, appending when *append* is true.

        * ``csv``: one row per entry; the vector is stored as a list literal.
        * ``parquet``: ``append`` is a fastparquet-specific option, so that
          engine is selected explicitly.
        * ``hdf5``: resizable ``feature_vector`` (float32, N x D) and ``tag``
          (``S100``) datasets plus ``num_samples``/``feature_dim`` attributes.

        Raises:
            ValueError: if ``self.output_format`` is not supported (previously
                such data was silently discarded).
        """
        if not features:
            return

        if self.output_format in ("csv", "parquet"):
            df = pd.DataFrame({
                "feature_vector": [np.asarray(item["feature_vector"]).tolist() for item in features],
                "tag": [item["tag"] for item in features]
            })
            if self.output_format == "csv":
                df.to_csv(output_path, index=False, header=not append, mode='a' if append else 'w')
            else:
                df.to_parquet(output_path, index=False, engine="fastparquet", append=append)
        elif self.output_format == "hdf5":
            self._write_hdf5(features, output_path, append)
        else:
            raise ValueError(
                f"Unsupported output_format {self.output_format!r}; "
                f"expected one of {self.SUPPORTED_OUTPUT_FORMATS}"
            )

    @staticmethod
    def _write_hdf5(features: List[Dict], output_path: str, append: bool) -> None:
        """Write or append feature vectors and tags to resizable HDF5 datasets."""
        # The original computed dtype=np.float32 but never applied it; cast here.
        vectors = np.vstack([np.asarray(item["feature_vector"], dtype=np.float32) for item in features])
        tags = np.array([item["tag"] for item in features], dtype='S100')

        with h5py.File(output_path, 'a' if append else 'w') as h5_file:
            for name, data in (("feature_vector", vectors), ("tag", tags)):
                if name in h5_file:
                    # Resize existing dataset and append
                    dataset = h5_file[name]
                    old_size = dataset.shape[0]
                    dataset.resize((old_size + data.shape[0],) + dataset.shape[1:])
                    dataset[old_size:] = data
                else:
                    h5_file.create_dataset(name, data=data, maxshape=(None,) + data.shape[1:], chunks=True)

            # feature_dim must come from the 2-D vector dataset; it used to be
            # read from the 1-D tag array and therefore was always 1.
            h5_file.attrs['num_samples'] = h5_file["feature_vector"].shape[0]
            h5_file.attrs['feature_dim'] = h5_file["feature_vector"].shape[1]

    def _print_stats(self) -> None:
        """Print processing statistics."""
        print("\n--- Processing Statistics ---")
        print(f"Files processed: {self.stats['files_processed']}")
        print(f"Nodes processed: {self.stats['nodes_processed']}")
        print(f"JSON errors: {self.stats['json_errors']}")
        print(f"Unique node types: {len(self.stats['unique_node_types'])}")
        # Types outside the vocabulary all share the UNKNOWN_TYPE embedding.
        unknown_types = sorted(
            str(t) for t in self.stats["unique_node_types"] if t not in self.node_type_to_idx
        )
        if unknown_types:
            print(f"Node types outside the embedding vocabulary: {unknown_types}")

if __name__ == "__main__":
    # Dataset locations, relative to the notebook's working directory.
    # An alternative input directory used during experiments: "../experimental_json".
    input_dir = "../modified_json_data"
    output_path = "figma_dataset_custom.csv"

    # Extractor settings; output_format must match the extension of output_path.
    extractor_config = dict(
        semantic_model_name='all-MiniLM-L6-v2',
        node_type_embedding_dim=50,
        output_format='csv',
    )
    extractor = FigmaHTMLFeatureExtractor(**extractor_config)
    extractor.process_directory(input_dir, output_path)

Found 1370 JSON files to process.


Processing files:   4%|▎         | 49/1370 [01:22<47:36,  2.16s/it] 

In [None]:
# # Read the CSV file
# df = pd.read_csv(output_path)

# # Get unique values in the 'tag' column
# unique_tags = df['tag'].unique()

# # Print the unique tags
# print(unique_tags)

['B_CONTAINER' 'P' 'DIV' 'HR' 'E_CONTAINER' 'BUTTON' 'SVG' 'LIST' 'IMG'
 'INPUT']
