# In [None]:  (Jupyter notebook cell marker left over from the export)
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import h5py
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm

class FigmaDataset(Dataset):
    """Map-style dataset that serves one Figma node sequence per index.

    Features and labels are looked up by sequence ID; the order of
    ``sequence_ids`` decides which integer index maps to which sequence.
    """

    def __init__(self, features, labels, sequence_ids):
        """
        Store the lookup tables that back the dataset.

        Args:
            features (dict): Dictionary mapping sequence IDs to feature tensors
            labels (dict): Dictionary mapping sequence IDs to label tensors
            sequence_ids (list): List of unique sequence IDs
        """
        self.sequence_ids = sequence_ids
        self.features = features
        self.labels = labels

    def __len__(self):
        """Return the number of sequence IDs held by the dataset."""
        return len(self.sequence_ids)

    def __getitem__(self, idx):
        """Return the 'features', 'labels' and 'seq_id' of the idx-th sequence."""
        key = self.sequence_ids[idx]
        sample = dict(
            features=self.features[key],
            labels=self.labels[key],
            seq_id=key,
        )
        return sample

def collate_fn(batch):
    """Collate variable-length sequence samples into padded batch tensors.

    Args:
        batch (list): Samples given as dicts with 'features', 'labels' and
            'seq_id' entries.

    Returns:
        dict: 'features' zero-padded to the longest sequence in the batch,
            'labels' padded with -100, 'lengths' holding the unpadded
            sequence lengths and 'seq_ids' in batch order.
    """
    feature_list, label_list, seq_ids = [], [], []
    for sample in batch:
        feature_list.append(sample['features'])
        label_list.append(sample['labels'])
        seq_ids.append(sample['seq_id'])

    # Unpadded length of every sequence, and the longest one as a plain int
    lengths = torch.tensor([len(feat) for feat in feature_list])
    longest = int(max(lengths))

    num_seqs = len(batch)
    # The feature width is taken from the first sample
    width = feature_list[0].shape[1]

    padded_features = torch.zeros(num_seqs, longest, width)
    # -100 matches the ignore_index of CrossEntropyLoss, so padding is skipped by the loss
    padded_labels = torch.full((num_seqs, longest), -100, dtype=torch.long)

    # Copy every sequence into the leading slots of its row
    for row, (feat, lab) in enumerate(zip(feature_list, label_list)):
        used = feat.shape[0]
        padded_features[row, :used] = feat
        padded_labels[row, :used] = lab

    return dict(
        features=padded_features,
        labels=padded_labels,
        lengths=lengths,
        seq_ids=seq_ids,
    )

class FigmaBLSTM(nn.Module):
    """Bidirectional LSTM model for Figma tag prediction.

    Produces one logit vector per time step (node) of every sequence in the
    batch.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2, dropout=0.3):
        """
        Initialize the BLSTM model.

        Args:
            input_dim (int): Dimension of input features
            hidden_dim (int): Dimension of hidden layers
            output_dim (int): Number of output classes (HTML tags)
            num_layers (int): Number of LSTM layers
            dropout (float): Dropout probability
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # LSTM layer. Inter-layer dropout is only requested when there is more
        # than one layer to put it between.
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0
        )

        # Attention mechanism.
        # NOTE: forward() does not use this module. It is kept because removing
        # it would change the state_dict keys and break loading of checkpoints
        # that were saved with it.
        self.attention = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1)
        )

        # Output layer (input is 2 * hidden_dim: forward and backward states)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)

        # Dropout applied to the LSTM outputs before classification
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, lengths):
        """
        Forward pass through the network.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, seq_len, input_dim)
            lengths (torch.Tensor): Lengths of each sequence in the batch

        Returns:
            torch.Tensor: Output predictions of shape (batch_size, seq_len, output_dim).
                ``seq_len`` always equals the padded length of ``x``, even when
                every sequence in the batch is shorter than that.
        """
        seq_len = x.size(1)

        # Pack padded sequence so the LSTM skips the padding positions
        packed_input = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )

        # LSTM forward pass
        packed_output, _ = self.lstm(packed_input)

        # Unpack output. total_length pins the time dimension to the input's
        # padded length; without it the result is only as long as the longest
        # sequence in the batch and may no longer line up with the labels.
        output, _ = nn.utils.rnn.pad_packed_sequence(
            packed_output, batch_first=True, total_length=seq_len
        )

        # Apply dropout
        output = self.dropout(output)

        # Linear layer to get logits
        logits = self.fc(output)

        return logits

class FigmaHTMLPredictor:
    """Class for training and predicting HTML tags from Figma features."""
    
    def __init__(self, model_config=None):
        """
        Initialize the predictor.

        Args:
            model_config (dict, optional): Configuration for the model. May be
                partial: any of 'hidden_dim', 'num_layers', 'dropout',
                'learning_rate', 'batch_size' and 'epochs' that are missing
                fall back to the defaults below. The dict passed in is copied,
                not stored.
        """
        defaults = {
            'hidden_dim': 256,
            'num_layers': 2,
            'dropout': 0.3,
            'learning_rate': 0.001,
            'batch_size': 16,
            'epochs': 20
        }

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = None
        self.label_encoder = LabelEncoder()
        # Overlay the caller's values on the defaults so that a partial config
        # does not cause a KeyError later when a missing key is read.
        self.model_config = {**defaults, **(model_config or {})}
    
    def load_data(self, data_path, test_size=0.2, random_state=42):
        """
        Load and preprocess data.

        Reads the node table, fits ``self.label_encoder`` on the 'tag' column
        (replacing any previous fit), groups the rows into one sequence per
        'sequence_id' and splits the sequences into train and validation sets.

        Args:
            data_path (str): Path to the data file (CSV or HDF5)
            test_size (float): Proportion of data to use for testing
            random_state (int): Random seed for splitting

        Returns:
            tuple: train_loader, val_loader, num_classes, input_dim

        Raises:
            ValueError: If the file extension is not supported or the file
                contains no records.
        """
        print(f"Loading data from {data_path}...")

        # Load based on file extension
        if data_path.endswith('.csv'):
            df = pd.read_csv(data_path)
            # Convert string representation of feature vectors to numpy arrays.
            # SECURITY: eval() executes whatever expression the CSV cell holds,
            # so only load files from a trusted source. Switch to
            # ast.literal_eval once the stored format is confirmed to be plain
            # Python literals.
            df['feature_vector'] = df['feature_vector'].apply(lambda x: np.array(eval(x)))
        elif data_path.endswith(('.hdf5', '.h5')):
            with h5py.File(data_path, 'r') as f:
                def read_text(name):
                    """Read a byte-string dataset and decode it to a list of str."""
                    return [s.decode('utf-8') for s in f[name][:]]

                # Convert HDF5 to pandas DataFrame
                df = pd.DataFrame({
                    'sequence_id': read_text('sequence_id'),
                    'node_id': read_text('node_id'),
                    'feature_vector': list(f['feature_vector'][:]),
                    'tag': read_text('tag'),
                    'raw_tag': read_text('raw_tag'),
                    'node_type': read_text('node_type'),
                    'depth': f['depth'][:],
                    'position': f['position'][:],
                    'parent_tag': read_text('parent_tag')
                })
        else:
            raise ValueError(f"Unsupported file format: {data_path}")

        # Fail with a clear message instead of an IndexError on iloc[0] below
        if df.empty:
            raise ValueError(f"No records found in {data_path}")

        print(f"Loaded {len(df)} records with {df['sequence_id'].nunique()} unique sequences")

        # Encode labels
        self.label_encoder.fit(df['tag'].unique())
        print(f"Found {len(self.label_encoder.classes_)} unique tags: {self.label_encoder.classes_}")

        # Get input dimension from first feature vector
        input_dim = len(df['feature_vector'].iloc[0])
        print(f"Input feature dimension: {input_dim}")

        # Group by sequence_id to create sequences
        sequences = {}
        for seq_id, group in df.groupby('sequence_id'):
            feature_vectors = np.stack(group['feature_vector'].values)
            labels = self.label_encoder.transform(group['tag'].values)
            sequences[seq_id] = {
                'features': torch.FloatTensor(feature_vectors),
                'labels': torch.LongTensor(labels)
            }

        # Split into train and validation sets (whole sequences, not single nodes)
        train_seq_ids, val_seq_ids = train_test_split(
            list(sequences.keys()),
            test_size=test_size,
            random_state=random_state
        )

        print(f"Training on {len(train_seq_ids)} sequences, validating on {len(val_seq_ids)} sequences")

        # Create datasets
        train_features = {seq_id: sequences[seq_id]['features'] for seq_id in train_seq_ids}
        train_labels = {seq_id: sequences[seq_id]['labels'] for seq_id in train_seq_ids}
        val_features = {seq_id: sequences[seq_id]['features'] for seq_id in val_seq_ids}
        val_labels = {seq_id: sequences[seq_id]['labels'] for seq_id in val_seq_ids}

        train_dataset = FigmaDataset(train_features, train_labels, train_seq_ids)
        val_dataset = FigmaDataset(val_features, val_labels, val_seq_ids)

        # Create data loaders; only the training loader is shuffled
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.model_config['batch_size'],
            shuffle=True,
            collate_fn=collate_fn
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.model_config['batch_size'],
            shuffle=False,
            collate_fn=collate_fn
        )

        return train_loader, val_loader, len(self.label_encoder.classes_), input_dim
    
    def build_model(self, input_dim, num_classes):
        """
        Create a FigmaBLSTM from ``self.model_config`` and move it to ``self.device``.

        Args:
            input_dim (int): Dimension of input features
            num_classes (int): Number of output classes

        Returns:
            FigmaBLSTM: The model
        """
        cfg = self.model_config
        network = FigmaBLSTM(
            input_dim=input_dim,
            hidden_dim=cfg['hidden_dim'],
            output_dim=num_classes,
            num_layers=cfg['num_layers'],
            dropout=cfg['dropout']
        ).to(self.device)

        # Report the total number of parameter elements
        param_count = sum(p.numel() for p in network.parameters())
        print(f"Model built with {param_count} parameters")
        return network
    
    def train(self, data_path, output_dir='./models', model_name='figma_blstm_model.pt'):
        """
        Train the model.

        Each epoch runs one training pass and one validation pass. The
        checkpoint file is overwritten whenever the validation loss improves,
        so the file on disk holds the best epoch, while ``self.model`` keeps
        the weights of the last epoch.

        Args:
            data_path (str): Path to the data file
            output_dir (str): Directory to save the model
            model_name (str): Name of the model file

        Returns:
            dict: Training history with the per-epoch lists 'train_loss',
                'val_loss' and 'val_accuracy'
        """
        # Create output directory if it doesn't exist
        os.makedirs(output_dir, exist_ok=True)

        # Load data
        train_loader, val_loader, num_classes, input_dim = self.load_data(data_path)

        # Build model
        self.model = self.build_model(input_dim, num_classes)

        # Define loss function and optimizer; -100 is the label padding value
        criterion = nn.CrossEntropyLoss(ignore_index=-100)
        optimizer = optim.Adam(self.model.parameters(), lr=self.model_config['learning_rate'])
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

        best_val_loss = float('inf')
        history = {'train_loss': [], 'val_loss': [], 'val_accuracy': []}
        total_epochs = self.model_config['epochs']
        model_path = os.path.join(output_dir, model_name)

        print("Starting training...")
        for epoch in range(total_epochs):
            epoch_tag = f"Epoch {epoch+1}/{total_epochs}"

            avg_train_loss = self._train_one_epoch(
                train_loader, criterion, optimizer, f"{epoch_tag} [Train]"
            )
            history['train_loss'].append(avg_train_loss)

            avg_val_loss, val_accuracy = self._evaluate(
                val_loader, criterion, f"{epoch_tag} [Val]"
            )
            history['val_loss'].append(avg_val_loss)
            history['val_accuracy'].append(val_accuracy)

            print(epoch_tag)
            print(f"  Train Loss: {avg_train_loss:.4f}")
            print(f"  Val Loss: {avg_val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}")

            # Update learning rate (halved after the validation loss plateaus)
            scheduler.step(avg_val_loss)

            # Save best model
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_loss': best_val_loss,
                    'label_encoder': self.label_encoder,
                    'model_config': self.model_config
                }, model_path)
                print(f"  Model saved to {model_path}")

        print("Training completed!")
        return history

    def _train_one_epoch(self, loader, criterion, optimizer, desc):
        """Run one optimisation pass over ``loader`` and return the mean per-batch loss."""
        self.model.train()
        running_loss = 0.0
        num_batches = 0

        progress_bar = tqdm(loader, desc=desc)
        for batch in progress_bar:
            features = batch['features'].to(self.device)
            labels = batch['labels'].to(self.device)

            # Forward pass; lengths stay on the CPU as provided by the loader
            logits = self.model(features, batch['lengths'])

            # Flatten (batch, time) so the loss sees one row per node
            loss = criterion(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

            # Backward pass and optimize
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)  # Gradient clipping
            optimizer.step()

            # Update statistics
            loss_value = loss.item()
            running_loss += loss_value
            num_batches += 1
            progress_bar.set_postfix({'loss': loss_value})

        return running_loss / num_batches

    def _evaluate(self, loader, criterion, desc):
        """Return (mean per-batch loss, accuracy over non-padding nodes) for ``loader``."""
        self.model.eval()
        running_loss = 0.0
        num_batches = 0
        total_correct = 0
        total_samples = 0

        with torch.no_grad():
            progress_bar = tqdm(loader, desc=desc)
            for batch in progress_bar:
                features = batch['features'].to(self.device)
                labels = batch['labels'].to(self.device)

                logits = self.model(features, batch['lengths'])
                logits_flat = logits.reshape(-1, logits.size(-1))
                labels_flat = labels.reshape(-1)

                loss = criterion(logits_flat, labels_flat)
                loss_value = loss.item()
                running_loss += loss_value
                num_batches += 1

                # Accuracy is computed on real nodes only; padding carries -100
                mask = (labels_flat != -100)
                valid = mask.sum().item()
                if valid > 0:
                    predicted = torch.argmax(logits_flat[mask], dim=1)
                    total_correct += (predicted == labels_flat[mask]).sum().item()
                    total_samples += valid

                progress_bar.set_postfix({'loss': loss_value})

        avg_loss = running_loss / num_batches
        accuracy = total_correct / total_samples if total_samples > 0 else 0
        return avg_loss, accuracy
    
    def load_model(self, model_path):
        """
        Load a trained model.

        Restores the model configuration, the label encoder and the network
        weights from a checkpoint, and leaves ``self.model`` on
        ``self.device`` in eval mode.

        Args:
            model_path (str): Path to the model file
        """
        # The checkpoint stores a pickled LabelEncoder next to the weights, so
        # it cannot be read by the weights-only unpickler that torch.load uses
        # by default from PyTorch 2.6 on.
        # SECURITY: full unpickling can execute arbitrary code; only load
        # checkpoints from a trusted source.
        try:
            checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        except TypeError:
            # Older PyTorch releases do not accept the weights_only keyword
            checkpoint = torch.load(model_path, map_location=self.device)

        # Load model configuration (keep the current one if none was saved)
        self.model_config = checkpoint.get('model_config', self.model_config)

        # Load label encoder
        self.label_encoder = checkpoint['label_encoder']

        # Recover the layer sizes from the saved weight shapes
        input_dim = checkpoint['model_state_dict']['lstm.weight_ih_l0'].size(1)
        num_classes = checkpoint['model_state_dict']['fc.weight'].size(0)

        # Build model
        self.model = FigmaBLSTM(
            input_dim=input_dim,
            hidden_dim=self.model_config['hidden_dim'],
            output_dim=num_classes,
            num_layers=self.model_config['num_layers'],
            dropout=self.model_config['dropout']
        )

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model = self.model.to(self.device)
        self.model.eval()

        print(f"Model loaded from {model_path}")
    
    def predict(self, features, sequence_id='test_sequence'):
        """
        Predict HTML tags for a sequence of features.
        
        Args:
            features (np.ndarray): Feature vectors of shape (seq_len, feature_dim)
            sequence_id (str): Identifier for the sequence
            
        Returns:
            list: Predicted HTML tags
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")
        
        # Convert to tensor
        features_tensor = torch.FloatTensor(features).unsqueeze(0)  # Add batch dimension
        lengths = torch.tensor([features_tensor.size(1)])
        
        # Move to device
        features_tensor = features_tensor.to(self.device)
        
        # Get predictions
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(features_tensor, lengths)
            predictions = torch.argmax(outputs, dim=2).squeeze(0)
        
        # Convert to class labels
        predicted_labels = self.label_encoder.inverse_transform(predictions.cpu().numpy())
        
        return predicted_labels
    
    def predict_batch(self, data_loader):
        """
        Predict HTML tags for every sequence served by a data loader.

        Args:
            data_loader (DataLoader): DataLoader containing the sequences; its
                batches must provide 'features', 'labels', 'lengths' and
                'seq_ids'

        Returns:
            dict: Dictionary mapping sequence IDs to a dict with
                'predicted_tags' and 'true_tags', both trimmed to the
                sequence's real length. A sequence ID that occurs more than
                once keeps only its last result.

        Raises:
            ValueError: If no model has been built or loaded yet.
        """
        if self.model is None:
            raise ValueError("Model not loaded. Call load_model() first.")

        results = {}
        classes = self.label_encoder.classes_
        self.model.eval()

        with torch.no_grad():
            for batch in tqdm(data_loader, desc="Predicting"):
                lengths = batch['lengths']

                # Forward pass
                outputs = self.model(batch['features'].to(self.device), lengths)

                # Copy the whole batch to host memory once instead of doing a
                # device-to-host transfer for every sequence.
                pred_batch = torch.argmax(outputs, dim=2).cpu().numpy()
                true_batch = batch['labels'].cpu().numpy()

                # Process each sequence in the batch
                for seq_id, seq_len, pred_row, true_row in zip(
                    batch['seq_ids'], lengths.tolist(), pred_batch, true_batch
                ):
                    # Convert indices to tags, dropping the padded tail
                    pred_tags = self.label_encoder.inverse_transform(pred_row[:seq_len])
                    true_tags = [
                        classes[idx] if idx != -100 else "UNKNOWN"
                        for idx in true_row[:seq_len]
                    ]

                    results[seq_id] = {
                        'predicted_tags': pred_tags,
                        'true_tags': true_tags
                    }

        return results

def main():
    """Train the BLSTM, reload the saved checkpoint and print sample predictions."""
    data_path = "figma_dataset_custom.csv"  # Path to the output from FigmaHTMLFeatureExtractor
    output_dir = "./models"

    # Predictor with an explicit hyper-parameter configuration
    predictor = FigmaHTMLPredictor({
        'hidden_dim': 256,
        'num_layers': 2,
        'dropout': 0.3,
        'learning_rate': 0.001,
        'batch_size': 16,
        'epochs': 20
    })

    # Training writes the best checkpoint into output_dir
    print("Training BLSTM model...")
    predictor.train(data_path, output_dir)

    # Reload the checkpoint that training wrote
    print("\nLoading trained model...")
    checkpoint_path = os.path.join(output_dir, "figma_blstm_model.pt")
    predictor.load_model(checkpoint_path)

    # Only the validation loader of the fresh split is needed here
    loaders = predictor.load_data(data_path, test_size=0.2)
    test_loader = loaders[1]

    print("\nMaking predictions on test data...")
    results = predictor.predict_batch(test_loader)

    # Show the first three sequences, ten nodes each at most
    print("\nSample predictions:")
    preview = list(results.items())[:3]
    for number, (seq_id, seq_results) in enumerate(preview, start=1):
        print(f"\nSequence {number} (ID: {seq_id}):")

        actual = seq_results['true_tags']
        guessed = seq_results['predicted_tags']

        shown = min(10, len(actual))
        for j in range(shown):
            print(f"  Node {j+1}: True={actual[j]}, Predicted={guessed[j]}")

    print("\nModel training and evaluation completed!")

# Script entry point: run the train/predict demo only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()