In [7]:
import os
import gc
import psutil
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
import h5py
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, classification_report, confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm
import warnings
from torch.cuda.amp import autocast, GradScaler
import json
import time

warnings.filterwarnings('ignore')

# Memory monitoring function
def print_memory_usage(printing=True):
    process = psutil.Process(os.getpid())
    memory_info = process.memory_info()
    memory_gb = memory_info.rss / 1024**3
    if printing:
        print(f"Memory usage: {memory_gb:.2f} GB")
    return memory_gb

# Set memory limit and device
MEMORY_LIMIT_GB = 7.5
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Initialize GradScaler for mixed precision
scaler = GradScaler()

class EarlyStopping:
    """Early stopping to prevent overfitting."""
    def __init__(self, patience=7, min_delta=0.0001, restore_best_weights=True):
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None
        
    def __call__(self, val_loss, model):
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1
            
        if self.counter >= self.patience:
            if self.restore_best_weights:
                model.load_state_dict(self.best_weights)
            return True
        return False
    
    def save_checkpoint(self, model):
        """Save model checkpoint."""
        self.best_weights = model.state_dict().copy()

class TreeNode:
    def __init__(self, feature_vector, label, node_index, parent_index):
        self.feature_vector = feature_vector
        self.label = label
        self.node_index = node_index
        self.parent_index = parent_index
        self.children = []

class LazyFigmaTreeDataset(Dataset):
    """Lazy loading dataset for Figma tree structures."""
    
    def __init__(self, data_path, tree_metadata, label_encoder, cache_size=100, expected_dim=831):
        """
        Args:
            data_path: Path to the HDF5 data file
            tree_metadata: List of dicts with tree info (tree_id, node_indices)
            label_encoder: Fitted label encoder
            cache_size: Number of trees to keep in memory cache
            expected_dim: Expected feature vector dimension
        """
        self.data_path = data_path
        self.tree_metadata = tree_metadata
        self.label_encoder = label_encoder
        self.cache_size = cache_size
        self.expected_dim = expected_dim
        self.cache = {}
        self.cache_order = []
        
        self._file_handle = None
        self._open_file()
    
    def _open_file(self):
        """Open file handle for HDF5."""
        self._file_handle = h5py.File(self.data_path, 'r')
    
    def _close_file(self):
        """Close file handle."""
        if self._file_handle is not None:
            self._file_handle.close()
            self._file_handle = None
    
    def __del__(self):
        """Cleanup file handle."""
        self._close_file()
    
    def _load_tree_from_file(self, tree_metadata):
        """Load a single tree from file."""
        tree_id = tree_metadata['tree_id']
        node_indices = tree_metadata['node_indices']
        
        if self._file_handle is None:
            self._open_file()
        
        tree_data = {
            'tree_id': self._file_handle['tree_id'][node_indices],
            'node_index': self._file_handle['node_index'][node_indices],
            'parent_index': self._file_handle['parent_index'][node_indices],
            'feature_vector': self._file_handle['feature_vector'][node_indices],
            'tag': [s.decode('utf-8') for s in self._file_handle['tag'][node_indices]]
        }
        
        nodes = {}
        for i in range(len(node_indices)):
            node_index = tree_data['node_index'][i]
            parent_index = tree_data['parent_index'][i]
            feature_vector = tree_data['feature_vector'][i]
            if feature_vector.shape[0] != self.expected_dim:
                raise ValueError(
                    f"Invalid feature vector dimension for tree_id={tree_id}, "
                    f"node_index={node_index}: got {feature_vector.shape[0]}, expected {self.expected_dim}"
                )
            tag = tree_data['tag'][i]
            label = self.label_encoder.transform([tag])[0]
            nodes[node_index] = TreeNode(feature_vector, label, node_index, parent_index)
        
        for node in nodes.values():
            if node.parent_index != -1:
                parent = nodes.get(node.parent_index)
                if parent:
                    parent.children.append(node)
        
        root = next((node for node in nodes.values() if node.parent_index == -1), None)
        return root
    
    def _manage_cache(self, tree_id, tree_data):
        """Manage LRU cache for trees."""
        if tree_id in self.cache:
            self.cache_order.remove(tree_id)
        
        self.cache[tree_id] = tree_data
        self.cache_order.append(tree_id)
        
        while len(self.cache) > self.cache_size:
            oldest_tree_id = self.cache_order.pop(0)
            del self.cache[oldest_tree_id]
    
    def __len__(self):
        return len(self.tree_metadata)
    
    def __getitem__(self, idx):
        tree_metadata = self.tree_metadata[idx]
        tree_id = tree_metadata['tree_id']
        
        if tree_id in self.cache:
            self.cache_order.remove(tree_id)
            self.cache_order.append(tree_id)
            tree_root = self.cache[tree_id]
        else:
            tree_root = self._load_tree_from_file(tree_metadata)
            self._manage_cache(tree_id, tree_root)
        
        return {'tree_root': tree_root, 'tree_id': tree_id}

def create_tree_metadata(data_path, max_trees=None):
    """Create metadata for tree structures."""
    print(f"Creating tree metadata from {data_path}...")
    
    with h5py.File(data_path, 'r') as f:
        tree_ids = f['tree_id'][:]
        node_indices = f['node_index'][:]
        tags = [s.decode('utf-8') for s in f['tag'][:]]
        
        unique_tree_ids = np.unique(tree_ids)
        tree_metadata = []
        all_tags = set()
        
        for tree_id in unique_tree_ids:
            tree_node_indices = np.where(tree_ids == tree_id)[0]
            tree_metadata.append({
                'tree_id': tree_id,
                'node_indices': tree_node_indices
            })
            for idx in tree_node_indices:
                all_tags.add(tags[idx])
            
            if max_trees and len(tree_metadata) >= max_trees:
                break
    
    print(f"Found {len(tree_metadata)} trees")
    print(f"Found {len(all_tags)} unique tags")
    return tree_metadata, sorted(list(all_tags))

def tree_collate_fn(batch):
    """Collate function for tree-based dataset."""
    trees = [item['tree_root'] for item in batch]
    tree_ids = [item['tree_id'] for item in batch]
    return {'trees': trees, 'tree_ids': tree_ids}

class TreeLSTMCell(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super(TreeLSTMCell, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.iou = nn.Linear(input_dim + hidden_dim * 2, 3 * hidden_dim)
        self.f = nn.Linear(hidden_dim * 2, hidden_dim)
        
    def forward(self, x, h_children, c_children):
        print(f"x shape: {x.shape}")  # Debug
        h_sum = sum(h_children) if h_children else torch.zeros(self.hidden_dim * 2).to(device)
        print(f"h_sum shape: {h_sum.shape}")  # Debug
        combined = torch.cat([x, h_sum], dim=-1)
        print(f"Combined input shape: {combined.shape}")  # Debug
        iou = self.iou(combined)
        i, o, u = torch.chunk(iou, 3, dim=-1)
        i, o, u = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(u)
        
        f_list = [torch.sigmoid(self.f(h_c)) for h_c in h_children] if h_children else []
        c = sum(f * c for f, c in zip(f_list, c_children)) if h_children else torch.zeros_like(u)
        c = i * u + c
        h = o * torch.tanh(c)
        return h, c

class TreeBLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(TreeBLSTM, self).__init__()
        self.forward_cell = TreeLSTMCell(input_dim, hidden_dim)
        self.backward_cell = TreeLSTMCell(input_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)
        self.dropout = nn.Dropout(0.3)
    
    def forward(self, tree):
        h_forward, c_forward, h_backward, c_backward = self._forward_backward(tree)
        h_combined = torch.cat([h_forward, h_backward], dim=-1)
        h_combined = self.dropout(h_combined)
        logits = self.fc(h_combined)
        return logits, h_combined
    
    def _forward_backward(self, node):
        if not node.children:
            x = torch.tensor(node.feature_vector, dtype=torch.float32).to(device)
            h_f, c_f = self.forward_cell(x, [], [])
            h_b, c_b = self.backward_cell(x, [], [])
            return h_f, c_f, h_b, c_b
        else:
            h_children_f = []
            c_children_f = []
            h_children_b = []
            c_children_b = []
            
            for child in node.children:
                h_f, c_f, h_b, c_b = self._forward_backward(child)
                h_children_f.append(h_f)
                c_children_f.append(c_f)
                h_children_b.append(h_b)
                c_children_b.append(c_b)
            
            x = torch.tensor(node.feature_vector, dtype=torch.float32).to(device)
            h_f, c_f = self.forward_cell(x, h_children_f, c_children_f)
            h_b, c_b = self.backward_cell(x, h_children_b, c_children_b)
            return h_f, c_f, h_b, c_b

def safe_tree_training_step(model, tree, criterion, optimizer, device):
    """Safe training step for tree-based model."""
    try:
        optimizer.zero_grad()
        
        with autocast():
            logits, _ = model(tree)
            label = torch.tensor([tree.label], dtype=torch.long).to(device)
            loss = criterion(logits, label)
        
        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        
        return loss.item()
    
    except RuntimeError as e:
        if "out of memory" in str(e):
            print(f"WARNING: Out of memory error: {e}")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return None
        else:
            raise e

def evaluate_tree_model(model, data_loader, criterion, device, label_encoder, phase="Validation"):
    """Evaluate tree-based model."""
    model.eval()
    total_loss = 0.0
    all_predictions = []
    all_labels = []
    successful_trees = 0
    
    with torch.no_grad():
        progress_bar = tqdm(data_loader, desc=f"{phase} Evaluation")
        for batch in progress_bar:
            trees = batch['trees']
            for tree in trees:
                try:
                    with autocast():
                        logits, _ = model(tree)
                        label = torch.tensor([tree.label], dtype=torch.long).to(device)
                        loss = criterion(logits, label)
                    
                    total_loss += loss.item()
                    successful_trees += 1
                    
                    prediction = torch.argmax(logits, dim=1).cpu().numpy()[0]
                    label = label.cpu().numpy()[0]
                    all_predictions.append(prediction)
                    all_labels.append(label)
                    
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                    
                    progress_bar.set_postfix({
                        'loss': f'{loss.item():.4f}',
                        'mem': f'{print_memory_usage(printing=False):.1f}GB'
                    })
                
                except RuntimeError as e:
                    if "out of memory" in str(e):
                        print(f"Skipping {phase.lower()} tree due to OOM")
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()
                        continue
                    else:
                        raise e
    
    avg_loss = total_loss / successful_trees if successful_trees > 0 else float('inf')
    
    if len(all_predictions) > 0:
        accuracy = accuracy_score(all_labels, all_predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, all_predictions, average='weighted', zero_division=0
        )
        
        precision_pc, recall_pc, f1_pc, support_pc = precision_recall_fscore_support(
            all_labels, all_predictions, average=None, zero_division=0, labels=range(len(label_encoder.classes_))
        )
        
        report = classification_report(
            all_labels, all_predictions, 
            target_names=label_encoder.classes_, 
            zero_division=0,
            output_dict=True
        )
        
        metrics = {
            'loss': avg_loss,
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'precision_per_class': precision_pc,
            'recall_per_class': recall_pc,
            'f1_per_class': f1_pc,
            'support_per_class': support_pc,
            'classification_report': report,
            'predictions': all_predictions,
            'labels': all_labels
        }
    else:
        metrics = {
            'loss': avg_loss,
            'accuracy': 0.0,
            'precision': 0.0,
            'recall': 0.0,
            'f1': 0.0
        }
    
    return metrics

def plot_confusion_matrix(y_true, y_pred, class_names, title="Confusion Matrix", save_path=None):
    """Plot and save confusion matrix."""
    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(12, 10))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.title(title)
    plt.xlabel('Predicted')
    plt.ylabel('Actual')
    plt.xticks(rotation=45)
    plt.yticks(rotation=0)
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Confusion matrix saved to {save_path}")
    
    plt.show()

def plot_training_history(train_losses, val_losses, train_accuracies, val_accuracies, save_path=None):
    """Plot training history."""
    epochs = range(1, len(train_losses) + 1)
    
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
    
    ax1.plot(epochs, train_losses, 'b-', label='Training Loss')
    ax1.plot(epochs, val_losses, 'r-', label='Validation Loss')
    ax1.set_title('Training and Validation Loss')
    ax1.set_xlabel('Epoch')
    ax1.set_ylabel('Loss')
    ax1.legend()
    ax1.grid(True)
    
    ax2.plot(epochs, train_accuracies, 'b-', label='Training Accuracy')
    ax2.plot(epochs, val_accuracies, 'r-', label='Validation Accuracy')
    ax2.set_title('Training and Validation Accuracy')
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.legend()
    ax2.grid(True)
    
    plt.tight_layout()
    
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Training history saved to {save_path}")
    
    plt.show()

def save_detailed_results(metrics, label_encoder, output_dir, phase="test"):
    """Save detailed evaluation results."""
    results_path = os.path.join(output_dir, f"{phase}_results.json")
    
    json_metrics = {key: value.tolist() if isinstance(value, np.ndarray) else value 
                    for key, value in metrics.items() if key not in ['predictions', 'labels']}
    json_metrics['class_names'] = label_encoder.classes_.tolist()
    
    with open(results_path, 'w') as f:
        json.dump(json_metrics, f, indent=2)
    print(f"Detailed results saved to {results_path}")
    
    if 'precision_per_class' in metrics:
        per_class_df = pd.DataFrame({
            'class': label_encoder.classes_,
            'precision': metrics['precision_per_class'],
            'recall': metrics['recall_per_class'],
            'f1_score': metrics['f1_per_class'],
            'support': metrics['support_per_class']
        })
        per_class_path = os.path.join(output_dir, f"{phase}_per_class_metrics.csv")
        per_class_df.to_csv(per_class_path, index=False)
        print(f"Per-class metrics saved to {per_class_path}")

def train_tree_model(data_path, model_config, max_trees=None):
    output_dir = './models'
    os.makedirs(output_dir, exist_ok=True)
    
    # Remove existing checkpoints
    for checkpoint in ["figma_tree_blstm_model.pt", "figma_tree_blstm_final_model.pt"]:
        checkpoint_path = os.path.join(output_dir, checkpoint)
        if os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)
            print(f"Removed existing checkpoint: {checkpoint_path}")
    
    tree_metadata, all_tags = create_tree_metadata(data_path, max_trees)
    
    label_encoder = LabelEncoder()
    label_encoder.fit(all_tags)
    num_classes = len(label_encoder.classes_)
    
    with h5py.File(data_path, 'r') as f:
        input_dim = f['feature_vector'][0].shape[0]
        print(f"First feature vector dimension: {input_dim}")
        feature_vectors = f['feature_vector'][:]
        dims = np.array([fv.shape[0] for fv in feature_vectors])
        unique_dims = np.unique(dims)
        print(f"Unique feature vector dimensions: {unique_dims}")
        if len(unique_dims) > 1:
            tree_ids = f['tree_id'][:]
            node_indices = f['node_index'][:]
            tags = [s.decode('utf-8') for s in f['tag'][:]]
            for dim in unique_dims:
                count = np.sum(dims == dim)
                print(f"Found {count} feature vectors with dimension {dim}")
            anomalous_indices = np.where(dims != input_dim)[0]
            for idx in anomalous_indices:
                print(f"Anomaly: tree_id={tree_ids[idx]}, node_index={node_indices[idx]}, tag={tags[idx]}, dim={dims[idx]}")
            raise ValueError("Inconsistent feature vector dimensions in dataset")
    
    print(f"Number of classes: {num_classes}")
    
    train_metadata, temp_metadata = train_test_split(tree_metadata, test_size=0.3, random_state=42)
    val_metadata, test_metadata = train_test_split(temp_metadata, test_size=0.5, random_state=42)
    
    print(f"Train trees: {len(train_metadata)}")
    print(f"Validation trees: {len(val_metadata)}")
    print(f"Test trees: {len(test_metadata)}")
    
    cache_size = min(50, len(train_metadata) // 4)
    
    train_dataset = LazyFigmaTreeDataset(data_path, train_metadata, label_encoder, cache_size=cache_size, expected_dim=input_dim)
    val_dataset = LazyFigmaTreeDataset(data_path, val_metadata, label_encoder, cache_size=cache_size//2, expected_dim=input_dim)
    test_dataset = LazyFigmaTreeDataset(data_path, test_metadata, label_encoder, cache_size=cache_size//2, expected_dim=input_dim)
    
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, collate_fn=tree_collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, collate_fn=tree_collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=tree_collate_fn)
    
    # Initialize model
    model = TreeBLSTM(input_dim=input_dim, hidden_dim=model_config['hidden_dim'], output_dim=num_classes).to(device)
    print(f"Model initialized with input_dim={input_dim}, hidden_dim={model_config['hidden_dim']}")
    
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=model_config['learning_rate'])
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)
    
    early_stopping = EarlyStopping(patience=model_config['early_stopping_patience'])
    
    train_losses = []
    val_losses = []
    train_accuracies = []
    val_accuracies = []
    
    print("Starting training with tree-based BLSTM...")
    start_time = time.time()
    
    for epoch in range(model_config['epochs']):
        print(f"Epoch {epoch+1}/{model_config['epochs']}")
        model.train()
        train_loss = 0.0
        successful_trees = 0
        
        for batch in tqdm(train_loader, desc="Training"):
            trees = batch['trees']
            for tree in trees:
                loss = safe_tree_training_step(model, tree, criterion, optimizer, device)
                if loss is not None:
                    train_loss += loss
                    successful_trees += 1
        
        avg_train_loss = train_loss / successful_trees if successful_trees > 0 else float('inf')
        
        val_metrics = evaluate_tree_model(model, val_loader, criterion, device, label_encoder, "Validation")
        train_metrics = evaluate_tree_model(model, train_loader, criterion, device, label_encoder, "Training")
        
        train_losses.append(avg_train_loss)
        val_losses.append(val_metrics['loss'])
        train_accuracies.append(train_metrics['accuracy'])
        val_accuracies.append(val_metrics['accuracy'])
        
        print(f"Train Loss: {avg_train_loss:.4f}, Train Acc: {train_metrics['accuracy']:.4f}")
        print(f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.4f}")
        
        model_path = os.path.join(output_dir, "figma_tree_blstm_model.pt")
        torch.save({
            'model_state_dict': model.state_dict(),
            'label_encoder': label_encoder,
            'model_config': model_config,
            'input_dim': input_dim,
            'num_classes': num_classes,
            'tree_metadata': tree_metadata,
            'train_losses': train_losses,
            'val_losses': val_losses,
            'train_accuracies': train_accuracies,
            'val_accuracies': val_accuracies
        }, model_path)
        
        scheduler.step(val_metrics['loss'])
        if early_stopping(val_metrics['loss'], model):
            print(f"Early stopping at epoch {epoch+1}")
            break
        
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    
    total_time = time.time() - start_time
    print(f"Training completed in {total_time:.2f} seconds")
    
    plot_training_history(train_losses, val_losses, train_accuracies, val_accuracies, 
                         os.path.join(output_dir, "training_history.png"))
    
    test_metrics = evaluate_tree_model(model, test_loader, criterion, device, label_encoder, "Test")
    print(f"Test Accuracy: {test_metrics['accuracy']:.4f}, F1: {test_metrics['f1']:.4f}")
    
    save_detailed_results(test_metrics, label_encoder, output_dir, "test")
    if len(test_metrics['predictions']) > 0:
        plot_confusion_matrix(test_metrics['labels'], test_metrics['predictions'], label_encoder.classes_,
                             save_path=os.path.join(output_dir, "test_confusion_matrix.png"))
    
    final_model_path = os.path.join(output_dir, "figma_tree_blstm_final_model.pt")
    torch.save({
        'model_state_dict': model.state_dict(),
        'label_encoder': label_encoder,
        'model_config': model_config,
        'input_dim': input_dim,
        'test_metrics': test_metrics
    }, final_model_path)
    
    train_dataset._close_file()
    val_dataset._close_file()
    test_dataset._close_file()
    
    return model, label_encoder, test_metrics

if __name__ == "__main__":
    model_config = {
        'hidden_dim': 128,
        'dropout': 0.3,
        'learning_rate': 0.001,
        'batch_size': 1,
        'epochs': 50,
        'early_stopping_patience': 10
    }
    
    DATA_PATH = "figma_dataset_tree_blstm.h5"  # Use corrected dataset
    MAX_TREES = 2000
    
    model, label_encoder, test_metrics = train_tree_model(DATA_PATH, model_config, max_trees=MAX_TREES)
    print("Training completed successfully!")

Using device: cuda
Creating tree metadata from figma_dataset_tree_blstm.h5...
Found 1370 trees
Found 11 unique tags
First feature vector dimension: 831
Unique feature vector dimensions: [831]
Number of classes: 11
Train trees: 959
Validation trees: 205
Test trees: 206
Model initialized with input_dim=831, hidden_dim=128
Starting training with tree-based BLSTM...
Epoch 1/50


Training:   0%|          | 0/959 [00:00<?, ?it/s]

x shape: torch.Size([831])
h_sum shape: torch.Size([256])
Combined input shape: torch.Size([1087])
x shape: torch.Size([831])
h_sum shape: torch.Size([256])
Combined input shape: torch.Size([1087])
x shape: torch.Size([831])
h_sum shape: torch.Size([128])
Combined input shape: torch.Size([959])





RuntimeError: mat1 and mat2 shapes cannot be multiplied (1x959 and 1087x384)