In [None]:
# Install PyTorch first (CUDA 11.8)
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 -q

# Install PyTorch Geometric and extensions using pre-built wheels
!pip install torch-geometric -q
!pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.7.0+cu118.html -q

# Install other utilities
!pip install matplotlib numpy scipy pandas psutil -q

print("All packages installed successfully!")

In [None]:
# Core libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# PyTorch Geometric
from torch_geometric.data import Data
from torch_geometric.nn import GATConv, GCNConv
from torch_geometric.loader import DataLoader

# Numerical and visualization
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.patches import Rectangle

# System utilities
import os
import time
import gc
import psutil
from scipy.spatial import KDTree
from collections import defaultdict, Counter

# Pick the compute device: CUDA when PyTorch can see a GPU, CPU otherwise
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}", f"   PyTorch version: {torch.__version__}", sep="\n")

# Report the name and total memory of GPU 0 when running on CUDA
if device.type == 'cuda':
    print(
        f"   GPU: {torch.cuda.get_device_name(0)}",
        f"   GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB",
        sep="\n",
    )

print("\nAll libraries imported successfully!")

In [None]:
# ============================================
# CONFIGURE YOUR CIRCUITNET PATHS HERE
# ============================================

# Dataset root. Set the CIRCUITNET_BASE environment variable to point the
# notebook at a different copy of the dataset without editing this cell; the
# literal below is only the fallback used when the variable is not set.
CIRCUITNET_BASE = os.environ.get(
    "CIRCUITNET_BASE",
    r"H:\Labs\Generative Ai\Ayush1\Ayush\CircuitNet",
)

# Data paths (all derived from CIRCUITNET_BASE)
PLACEMENT_PATH = os.path.join(CIRCUITNET_BASE, "instance_placement_micron-002", "instance_placement_micron")
NODE_ATTR_PATH = os.path.join(CIRCUITNET_BASE, "graph_features", "graph_information", "node_attr")
NET_ATTR_PATH = os.path.join(CIRCUITNET_BASE, "graph_features", "graph_information", "net_attr")
PIN_ATTR_PATH = os.path.join(CIRCUITNET_BASE, "graph_features", "graph_information", "pin_attr")
# NOTE: CONGESTION_PATH is defined here but is not part of the check below.
CONGESTION_PATH = os.path.join(CIRCUITNET_BASE, "congestion")

# Verify paths exist
print("Verifying dataset paths...\n")

paths_to_check = {
    "Base Directory": CIRCUITNET_BASE,
    "Placement Data": PLACEMENT_PATH,
    "Node Attributes": NODE_ATTR_PATH,
    "Net Attributes": NET_ATTR_PATH,
    "Pin Attributes": PIN_ATTR_PATH,
}

# all_exist ends up True only if every entry of paths_to_check is present on disk
all_exist = True
for name, path in paths_to_check.items():
    exists = os.path.exists(path)
    status = "[OK]" if exists else "[MISSING]"
    print(f"{status} {name}: {path}")
    if not exists:
        all_exist = False

print()
if all_exist:
    print("All paths verified! Ready to load data.")
else:
    print("WARNING: Some paths are missing. Please check your CircuitNet installation.")
    print("\nDownload from: https://drive.google.com/drive/folders/1GjW-1LBx1563bg3pHQGvhcEyK2A9sYUB")

In [None]:
# Load up to MAX_SAMPLES CircuitNet graphs and split them 80/20 in load order
MAX_SAMPLES = 500  # Increased from 100 for better generalization

print(f"Loading {MAX_SAMPLES} samples from CircuitNet...\n")
circuitnet_dataset = load_circuitnet_dataset(max_samples=MAX_SAMPLES)

if not circuitnet_dataset:
    print("Failed to load dataset. Check your paths!")
else:
    # First 80% of the samples train the model, the remainder is held out
    split_idx = int(len(circuitnet_dataset) * 0.8)
    cn_train, cn_test = circuitnet_dataset[:split_idx], circuitnet_dataset[split_idx:]

    # Per-sample figures are taken from the first sample only (hence the "~")
    print("\n".join([
        "\nDataset Statistics:",
        f"   Total samples: {len(circuitnet_dataset)}",
        f"   Training samples: {len(cn_train)}",
        f"   Test samples: {len(cn_test)}",
        f"   Cells per sample: ~{circuitnet_dataset[0].num_cells:,}",
        f"   Edges per sample: ~{circuitnet_dataset[0].edge_index.shape[1]:,}",
        "\nDataset ready for training!",
    ]))

In [None]:
class VLSIPlacementGNN(nn.Module):
    """
    Graph Attention Network for VLSI cell placement prediction.

    Architecture (v2):
    - Input: `input_dim` features per cell (16 by default)
    - Linear input projection + LayerNorm + ReLU
    - `num_layers` GAT layers (multi-head attention, heads concatenated back to
      `hidden_dim`), each followed by LayerNorm and a residual connection
    - Two-layer MLP output head with no Sigmoid; the result is clamped to [0, 1]
    - Output: `output_dim` values per cell (normalized x, y by default)

    Args:
        input_dim: Number of features per node.
        hidden_dim: Width of every hidden representation. Must be divisible by
            `heads` so the concatenated attention heads match the LayerNorm and
            residual width.
        output_dim: Number of values predicted per node.
        num_layers: Number of GAT layers.
        heads: Number of attention heads per GAT layer.

    Raises:
        ValueError: If `heads` is not positive or does not divide `hidden_dim`.
    """

    def __init__(self, input_dim=16, hidden_dim=128, output_dim=2, num_layers=4, heads=4):
        super().__init__()

        # Each GAT layer emits heads * (hidden_dim // heads) channels. If the
        # division is not exact that width is smaller than hidden_dim and the
        # LayerNorm / residual add in forward() would fail on shape mismatch.
        if heads <= 0 or hidden_dim % heads != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by a positive heads ({heads})"
            )

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.num_layers = num_layers
        self.heads = heads

        # Input projection
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.input_norm = nn.LayerNorm(hidden_dim)

        # GAT layers with attention + layer norms for residual connections
        self.gat_layers = nn.ModuleList()
        self.layer_norms = nn.ModuleList()
        for _ in range(num_layers):
            self.gat_layers.append(
                GATConv(hidden_dim, hidden_dim // heads, heads=heads, dropout=0.1, concat=True)
            )
            self.layer_norms.append(nn.LayerNorm(hidden_dim))

        # Output head: no Sigmoid here; forward() clamps the raw output instead
        self.output_proj = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim // 2, output_dim)
        )

    def forward(self, data):
        """
        Predict per-node outputs for one graph.

        Args:
            data: Object exposing `x` (node features) and `edge_index`.

        Returns:
            Tensor with one row per node and `output_dim` columns, clamped to [0, 1].
        """
        x, edge_index = data.x, data.edge_index

        # Input projection
        x = self.input_proj(x)
        x = self.input_norm(x)
        x = F.relu(x)

        # GAT layers with residual connections
        for i, (gat_layer, layer_norm) in enumerate(zip(self.gat_layers, self.layer_norms)):
            residual = x
            x = gat_layer(x, edge_index)
            x = layer_norm(x)
            # ReLU + dropout are applied after every layer except the last one
            if i < len(self.gat_layers) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=0.1, training=self.training)
            # Residual connection: add the layer input back to its output
            x = x + residual

        # Output projection
        out = self.output_proj(x)

        # Clamp to [0, 1] instead of applying a Sigmoid.
        # NOTE: clamp passes the gradient through unchanged for values inside
        # [0, 1] but yields zero gradient for values outside that range, so a
        # prediction that leaves the range gets no gradient from this path.
        out = out.clamp(0.0, 1.0)

        return out

# Instantiate the 16-feature network and place it on the selected device
model = VLSIPlacementGNN(input_dim=16, hidden_dim=128, output_dim=2, num_layers=4, heads=4)
model = model.to(device)

# Total number of trainable scalar parameters
num_params = sum(param.numel() for param in model.parameters() if param.requires_grad)

# Summary report
print("Model Architecture (v2 - Anti-Collapse):\n")
print(model)
print("\n".join([
    "\nModel Statistics:",
    f"   Total parameters: {num_params:,}",
    "   Input features: 16 dimensions (industry-relevant)",
    "   Key improvements over v1:",
    "     - Residual connections (prevents gradient vanishing)",
    "     - LayerNorm (stabilizes training)",
    "     - No Sigmoid (prevents center collapse)",
    "     - Clamp [0,1] output (preserves gradients at boundaries)",
    f"   Device: {device}",
    "\nModel created successfully!",
]))

In [None]:
# Training loop with resume capability
train_losses = []
test_losses = []
start_epoch = 0

# Checkpoint to resume from. This must be the same file the save cell writes
# (under ...\Ayush1\Ayush\); the previous value pointed at ...\Ayush\ and so a
# checkpoint written by this notebook was never found on the next run.
model_save_path = r"H:\Labs\Generative Ai\Ayush1\Ayush\vlsi_placement_model.pth"

if os.path.exists(model_save_path):
    print("=" * 80)
    print("LOADING EXISTING MODEL")
    print("=" * 80)

    # map_location lets a checkpoint saved on a GPU machine load on a CPU-only one
    checkpoint = torch.load(model_save_path, map_location=device)

    # Check if saved model has same architecture (input_dim may differ)
    saved_config = checkpoint.get('model_config', {})
    saved_input_dim = saved_config.get('input_dim', 10)
    current_input_dim = model.input_dim

    if saved_input_dim != current_input_dim:
        print(f"   WARNING: Saved model has input_dim={saved_input_dim}, current model has input_dim={current_input_dim}")
        print(f"   Architecture changed (new industry-relevant features). Training from scratch.")
        print("=" * 80)
        print()
    else:
        # Restore weights, optimizer state, epoch counter and loss history
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        train_losses = checkpoint.get('train_losses', [])
        test_losses = checkpoint.get('test_losses', [])

        print(f"Model loaded from: {model_save_path}")
        print(f"   Previous epochs completed: {start_epoch}")
        print(f"   Previous train loss: {train_losses[-1]:.6f}" if train_losses else "   No previous train loss")
        print(f"   Previous test loss: {test_losses[-1]:.6f}" if test_losses else "   No previous test loss")
        print(f"   Resuming training from epoch {start_epoch + 1}")
        print("=" * 80)
        print()
else:
    print("=" * 80)
    print("STARTING TRAINING FROM SCRATCH")
    print("=" * 80)
    print()

print("Starting Training...\n")
print("=" * 80)

start_time = time.time()
# Defined up front so the summary below still works if the loop body never runs
elapsed = 0.0

for epoch in range(start_epoch, start_epoch + NUM_EPOCHS):
    model.train()
    epoch_loss = 0
    batch_count = 0

    # Training: one optimizer step per graph (each sample is its own batch)
    for data in cn_train:
        data = data.to(device)
        optimizer.zero_grad()

        # Forward pass
        pred = model(data)
        loss = criterion(pred, data.y)

        # Backward pass
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        batch_count += 1

    avg_train_loss = epoch_loss / batch_count
    train_losses.append(avg_train_loss)

    # Evaluation on the held-out samples, without gradient tracking
    model.eval()
    test_loss = 0
    test_count = 0

    with torch.no_grad():
        for data in cn_test:
            data = data.to(device)
            pred = model(data)
            loss = criterion(pred, data.y)
            test_loss += loss.item()
            test_count += 1

    avg_test_loss = test_loss / test_count
    test_losses.append(avg_test_loss)

    # Print progress
    elapsed = time.time() - start_time
    print(f"Epoch {epoch+1:2d}/{start_epoch + NUM_EPOCHS} | "
          f"Train Loss: {avg_train_loss:.6f} | "
          f"Test Loss: {avg_test_loss:.6f} | "
          f"Time: {elapsed:.1f}s")

print("=" * 80)
print(f"\nTraining complete! Total time: {elapsed/60:.1f} minutes")
print(f"   Final train loss: {train_losses[-1]:.6f}")
print(f"   Final test loss: {test_losses[-1]:.6f}")
print(f"   Total epochs completed: {start_epoch + NUM_EPOCHS}")

In [None]:
# Save model
model_save_path = r"H:\Labs\Generative Ai\Ayush1\Ayush\vlsi_placement_model.pth"

# Make sure the target folder exists before writing the checkpoint
os.makedirs(os.path.dirname(model_save_path), exist_ok=True)

# Calculate total epochs (start_epoch + NUM_EPOCHS)
total_epochs = start_epoch + NUM_EPOCHS

torch.save({
    'epoch': total_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'train_losses': train_losses,
    'test_losses': test_losses,
    # Read the configuration from the live model instead of repeating literals,
    # so the input_dim comparison done on resume reflects what was actually
    # trained. `heads` is taken from the first GAT layer (4 if there is none).
    'model_config': {
        'input_dim': model.input_dim,
        'hidden_dim': model.hidden_dim,
        'output_dim': model.output_dim,
        'num_layers': model.num_layers,
        'heads': model.gat_layers[0].heads if len(model.gat_layers) > 0 else 4
    }
}, model_save_path)

print(f"Model saved to: {model_save_path}")
print(f"   Total epochs completed: {total_epochs}")
print(f"   File size: {os.path.getsize(model_save_path) / 1e6:.2f} MB")
print(f"\nTo load later:")
# !r prints the path as a valid Python literal; pasting the raw Windows path
# inside plain quotes would turn sequences such as "\v" into escape characters.
print(f"   checkpoint = torch.load({model_save_path!r})")
print(f"   model.load_state_dict(checkpoint['model_state_dict'])")

In [None]:
def _draw_layout_panel(ax, positions, cell_widths, edge_index, face_color, title,
                       chip_width_microns, chip_height_microns):
    """Draw one layout panel: a rectangle per cell plus a sample of the edges."""
    ax.set_xlim(0, chip_width_microns)
    ax.set_ylim(0, chip_height_microns)
    ax.set_aspect('equal')
    ax.set_facecolor('#1a1a1a')

    # One rectangle per cell, centered on its position; height is 0.8 * width
    for (x, y), cell_width in zip(positions, cell_widths):
        cell_height = cell_width * 0.8
        ax.add_patch(Rectangle((x - cell_width/2, y - cell_height/2),
                               cell_width, cell_height,
                               facecolor=face_color, edgecolor='white',
                               alpha=0.7, linewidth=0.5))

    # Draw connections (sample): every 5th edge among the first 500
    for i in range(0, min(500, edge_index.shape[1]), 5):
        src, dst = edge_index[:, i]
        ax.plot([positions[src, 0], positions[dst, 0]],
                [positions[src, 1], positions[dst, 1]],
                'yellow', alpha=0.2, linewidth=0.3)

    ax.set_title(title, fontsize=14, color='white', pad=20)
    ax.set_xlabel('X Position (µm)', fontsize=12, color='white')
    ax.set_ylabel('Y Position (µm)', fontsize=12, color='white')
    ax.tick_params(colors='white')
    ax.grid(True, alpha=0.2, color='gray')


def visualize_industry_layout(data, predictions, chip_width_microns=1000, chip_height_microns=1000, dpi=150):
    """
    Plot predicted and ground-truth layouts side by side and print error metrics.

    Args:
        data: Graph sample exposing `x` (node features), `y` (normalized target
            coordinates) and `edge_index`.
        predictions: Array of normalized (x, y) predictions, one row per cell.
        chip_width_microns: Width used to scale normalized x values.
        chip_height_microns: Height used to scale normalized y values.
        dpi: Resolution of the figure and of the saved PNG.

    Side effects:
        Saves the figure to 'industry_layout.png', shows it, and prints MSE/MAE
        computed on the scaled coordinates.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 8), dpi=dpi)
    # Titles and tick labels are white; give the figure the same dark background
    # that savefig uses so they are also readable in the inline display.
    fig.patch.set_facecolor('#1a1a1a')

    # Convert normalized coordinates to microns
    chip_scale = np.array([chip_width_microns, chip_height_microns])
    predicted_microns = predictions * chip_scale
    actual_microns = data.y.cpu().numpy() * chip_scale

    # Estimate cell size from node features (simplified): feature column 3
    # scaled by 20 and limited to [5, 50]; 10 when that column does not exist.
    # Computed once here instead of one tensor .item() call per cell per panel.
    num_drawn = len(predicted_microns)
    if data.x.shape[1] > 3:
        size_feature = data.x[:, 3].detach().cpu().numpy().astype(np.float64)
        cell_widths = np.clip(size_feature * 20, 5, 50)
    else:
        cell_widths = np.full(num_drawn, 10.0)

    edge_index = data.edge_index.cpu().numpy()

    # Plot 1: Predicted Layout
    _draw_layout_panel(axes[0], predicted_microns, cell_widths, edge_index,
                       'cyan', 'Predicted Layout (Micron Precision)',
                       chip_width_microns, chip_height_microns)

    # Plot 2: Actual Layout
    _draw_layout_panel(axes[1], actual_microns, cell_widths, edge_index,
                       'lime', 'Actual Layout (Ground Truth)',
                       chip_width_microns, chip_height_microns)

    plt.tight_layout()
    plt.savefig('industry_layout.png', dpi=dpi, facecolor='#1a1a1a')
    plt.show()

    # Calculate metrics on the scaled (micron) coordinates
    position_error = predicted_microns - actual_microns
    mse = np.mean(position_error ** 2)
    mae = np.mean(np.abs(position_error))

    print(f"\nLayout Accuracy Metrics:")
    print(f"   Mean Squared Error: {mse:.2f} µm²")
    print(f"   Mean Absolute Error: {mae:.2f} µm")
    print(f"   Average X Error: {np.mean(np.abs(position_error[:, 0])):.2f} µm")
    print(f"   Average Y Error: {np.mean(np.abs(position_error[:, 1])):.2f} µm")

# Visualize with test data
print("Creating industry-grade layout visualization...")
if len(cn_test) == 0:
    print("No test samples available - skipping layout visualization.")
else:
    # Sample 10 is used when it exists; smaller test splits fall back to their
    # last sample instead of raising IndexError.
    sample_index = min(10, len(cn_test) - 1)
    model.eval()
    with torch.no_grad():
        test_data = cn_test[sample_index].to(device)
        test_pred = model(test_data).cpu().numpy()

    visualize_industry_layout(test_data, test_pred)
    print("Industry-grade visualization complete!")

## Run GNN Placement Inference via GUI
Browse and select your **standard VLSI design files** — the GNN predicts cell placement from scratch:
- **Verilog Netlist** (`.v`) — cell instances, nets, connectivity *(required)*
- **LEF Library** (`.lef`) — cell dimensions & pin definitions *(required)*
- **SDC file** (`.sdc`) — timing constraints *(optional)*
- **Floorplan DEF** (`.def`) — die area & IO pad positions only, NOT cell placement *(optional)*
- **Model checkpoint** (`.pth`) — trained GNN weights *(required)*

The cell parses these files, converts them into the 16-feature graph representation, and the **GNN predicts placement coordinates** — no existing placement required.

In [2]:
import tkinter as tk
from tkinter import filedialog, messagebox, ttk
import re
import json
import threading

# ═══════════════════════════════════════════════════════════════════
# PARSERS: Convert standard VLSI files → GNN-compatible format
# ═══════════════════════════════════════════════════════════════════

def parse_verilog_netlist(verilog_path):
    """
    Parse a gate-level Verilog netlist (.v) into cell and pin tables.

    Recognizes named-port instantiations such as:
        NAND2X1 U123 (.A(net1), .B(net2), .Y(net3));
        CellType #(...) InstanceName ( .Pin(net), ... );

    Args:
        verilog_path: Path of the netlist file.

    Returns:
        Tuple (node_attr, pin_attr, cell_names, cell_types):
          - node_attr: object array [cell_names, cell_types]
          - pin_attr:  object array [pin_names, net_indices, node_indices];
            net indices refer to the alphabetically sorted net names
          - cell_names, cell_types: plain lists, one entry per instance

    Raises:
        ValueError: If no cell instantiation is found.
    """
    with open(verilog_path, 'r') as f:
        content = f.read()

    # Remove comments
    content = re.sub(r'//.*?\n', '\n', content)
    content = re.sub(r'/\*.*?\*/', '', content, flags=re.DOTALL)

    cells = []        # list of {name, type}
    nets = set()       # all net names
    pins = []          # list of {cell_idx, pin_name, net_name}

    # Pattern for module instantiation:
    #   CellType InstanceName ( .PinName(NetName), ... );
    # Also handles: CellType #(...) InstanceName ( ... );
    inst_pattern = re.compile(
        r'(\w+)\s+'                           # cell type
        r'(?:#\s*\([^)]*\)\s+)?'              # optional parameters #(...)
        r'(\w+)\s*\('                          # instance name (
        r'(.*?)\)\s*;',                        # pin connections );
        re.DOTALL
    )

    # Skip these keywords (not cell instantiations)
    skip_keywords = {
        'module', 'endmodule', 'input', 'output', 'inout', 'wire', 'reg',
        'assign', 'always', 'initial', 'begin', 'end', 'if', 'else',
        'case', 'endcase', 'for', 'while', 'function', 'endfunction',
        'task', 'endtask', 'generate', 'endgenerate', 'parameter',
        'localparam', 'supply0', 'supply1', 'tri', 'wand', 'wor'
    }

    # Pin connection pattern: .PinName(NetName)
    pin_pattern = re.compile(r'\.(\w+)\s*\(\s*([^)]*?)\s*\)')

    # Based literals such as 1'b0, 1'h0 or 4'd12 are constants, not nets
    literal_pattern = re.compile(r"(?:\d+\s*)?'[sS]?[bBoOdDhH]\s*[0-9a-fA-FxXzZ_?]+")

    for match in inst_pattern.finditer(content):
        cell_type = match.group(1)
        inst_name = match.group(2)
        pin_section = match.group(3)

        # Verilog keywords are lowercase and the language is case-sensitive, so
        # compare exactly: a library cell called REG or TRI is a real instance.
        if cell_type in skip_keywords:
            continue

        cell_idx = len(cells)
        cells.append({'name': inst_name, 'type': cell_type})

        # Extract pin connections
        for pin_match in pin_pattern.finditer(pin_section):
            pin_name = pin_match.group(1)
            net_name = pin_match.group(2).strip()

            # Unconnected pins (.A()) and constant ties carry no connectivity
            if not net_name or literal_pattern.fullmatch(net_name):
                continue

            # Bus selects are stripped: net[3:0] and net[2] both map to "net",
            # i.e. all bits of a bus are treated as one net.
            base_net = re.sub(r'\[.*?\]', '', net_name).strip()
            if base_net:
                nets.add(base_net)
                pins.append({
                    'cell_idx': cell_idx,
                    'pin_name': pin_name,
                    'net_name': base_net
                })

    if not cells:
        raise ValueError("No cell instantiations found in Verilog file. "
                         "Make sure it's a gate-level / structural netlist.")

    # Build net-to-index mapping
    net_list = sorted(nets)
    net_to_idx = {n: i for i, n in enumerate(net_list)}

    # Build node_attr equivalent: [cell_names, cell_types]
    cell_names = [c['name'] for c in cells]
    cell_types = [c['type'] for c in cells]
    node_attr = np.array([cell_names, cell_types], dtype=object)

    # Build pin_attr equivalent: [pin_names, net_indices, node_indices]
    pin_names_arr = [p['pin_name'] for p in pins]
    pin_nets_arr = [net_to_idx[p['net_name']] for p in pins]
    pin_nodes_arr = [p['cell_idx'] for p in pins]
    pin_attr = np.array([pin_names_arr, pin_nets_arr, pin_nodes_arr], dtype=object)

    print(f"   [Verilog] Parsed {len(cells)} cells, {len(net_list)} nets, {len(pins)} pins")
    return node_attr, pin_attr, cell_names, cell_types


def parse_def_file(def_path):
    """
    Parse a DEF file (.def) to extract cell placement coordinates.

    Handles COMPONENTS entries of the form:
        - INST_NAME CellType + PLACED ( X Y ) Orientation ;
        - INST_NAME CellType + FIXED ( X Y ) Orientation ;
        - INST_NAME CellType + COVER ( X Y ) Orientation ;
    Components without one of these placement clauses are skipped.

    Args:
        def_path: Path of the DEF file.

    Returns:
        dict: instance name -> [x, y, x + 1, y + 1], where (x, y) are the
        integers read from the file unchanged (no unit conversion is applied)
        and the +1 extent is a placeholder cell size.

    Raises:
        ValueError: If there is no COMPONENTS section or no placed component.
    """
    with open(def_path, 'r') as f:
        content = f.read()

    placement = {}

    # Find COMPONENTS section
    comp_match = re.search(r'COMPONENTS\s+\d+\s*;(.*?)END\s+COMPONENTS', content, re.DOTALL)
    if not comp_match:
        raise ValueError("No COMPONENTS section found in DEF file.")

    comp_section = comp_match.group(1)

    # One component statement: - INST_NAME CELL_TYPE ... + PLACED/FIXED/COVER ( X Y ) ORIENT
    comp_pattern = re.compile(
        r'\s*-\s+(\S+)\s+(\S+)'                      # instance name, cell type
        r'.*?\+\s*(?:PLACED|FIXED|COVER)\s*'         # placement status
        r'\(\s*(-?\d+)\s+(-?\d+)\s*\)\s*'            # ( X Y )
        r'(\w+)',                                    # orientation
        re.DOTALL
    )

    # Match statement by statement (DEF statements end with ';'). Matching over
    # the whole section let an entry without a placement clause run on into the
    # next statement, taking that component's coordinates and hiding it.
    for statement in comp_section.split(';'):
        match = comp_pattern.match(statement)
        if not match:
            continue
        inst_name = match.group(1)
        x = int(match.group(3))
        y = int(match.group(4))
        # Assume default cell size (will be overridden by LEF if available)
        placement[inst_name] = [x, y, x + 1, y + 1]

    if not placement:
        raise ValueError("No placed components found in DEF file.")

    print(f"   [DEF] Parsed {len(placement)} placed components")
    return placement


def parse_lef_file(lef_path):
    """
    Parse a LEF file (.lef) to extract cell library dimensions.

    Args:
        lef_path: Path of the LEF file.

    Returns:
        dict: macro name -> {'width', 'height', 'class', 'pins'}.
        width/height come from the SIZE statement (1.0 each when absent),
        class from the CLASS statement ("CORE" when absent), and pins is the
        list of PIN names in file order.
    """
    with open(lef_path, 'r') as f:
        content = f.read()

    cell_lib = {}

    # Find each MACRO definition. The (?!\S) after the back-reference requires
    # the closing "END <name>" to be the whole name; otherwise a pin such as
    # INV_OUT inside macro INV would end the macro body at "END INV_OUT".
    macro_pattern = re.compile(
        r'MACRO\s+(\S+)\s*\n(.*?)END\s+\1(?!\S)',
        re.DOTALL
    )

    for match in macro_pattern.finditer(content):
        macro_name = match.group(1)
        macro_body = match.group(2)

        # Extract SIZE width BY height
        size_match = re.search(r'\bSIZE\s+([\d.]+)\s+BY\s+([\d.]+)', macro_body)
        width, height = 1.0, 1.0
        if size_match:
            width = float(size_match.group(1))
            height = float(size_match.group(2))

        # Extract CLASS
        class_match = re.search(r'\bCLASS\s+(\w+)', macro_body)
        cell_class = class_match.group(1) if class_match else "CORE"

        # Extract PIN names (\b keeps words that merely end in "PIN" out)
        pin_names = re.findall(r'\bPIN\s+(\S+)', macro_body)

        cell_lib[macro_name] = {
            'width': width,
            'height': height,
            'class': cell_class,
            'pins': pin_names
        }

    print(f"   [LEF] Parsed {len(cell_lib)} cell definitions")
    return cell_lib


def parse_sdc_file(sdc_path):
    """
    Parse an SDC file (.sdc) to extract timing constraints.

    Each line is handled on its own (no backslash line-continuation support).

    Args:
        sdc_path: Path of the SDC file.

    Returns:
        dict with:
          - clock_period: float (default 10.0)
          - clock_name: str (default 'clk')
          - input_delays: dict port -> delay
          - output_delays: dict port -> delay
          - max_fanout: int or None
          - max_transition: float or None
    """
    with open(sdc_path, 'r') as f:
        lines = f.readlines()

    constraints = {
        'clock_period': 10.0,  # default 10ns
        'clock_name': 'clk',
        'input_delays': {},
        'output_delays': {},
        'max_fanout': None,
        'max_transition': None,
    }

    # A delay value is a whole token that is a number. Taking the first run of
    # digits anywhere in the line would read the "1" of "-clock clk1" as the delay.
    number_token = re.compile(r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?')

    def _port_delay(command, text):
        """Return (port, delay) for a `command ... value ... [get_ports port]` line, else None."""
        found = re.search(command + r'\s+(.*?)\[get_ports\s+(\S+)\]', text)
        if not found:
            return None
        for token in found.group(1).split():
            if number_token.fullmatch(token):
                return found.group(2), float(token)
        return None

    for line in lines:
        line = line.strip()
        if line.startswith('#') or not line:
            continue

        # create_clock -period VALUE -name NAME
        clk_match = re.search(r'create_clock\s+.*?-period\s+([\d.]+)', line)
        if clk_match:
            constraints['clock_period'] = float(clk_match.group(1))
            name_match = re.search(r'-name\s+(\S+)', line)
            if name_match:
                constraints['clock_name'] = name_match.group(1)

        # set_input_delay
        in_delay = _port_delay('set_input_delay', line)
        if in_delay:
            constraints['input_delays'][in_delay[0]] = in_delay[1]

        # set_output_delay
        out_delay = _port_delay('set_output_delay', line)
        if out_delay:
            constraints['output_delays'][out_delay[0]] = out_delay[1]

        # set_max_fanout
        fanout_match = re.search(r'set_max_fanout\s+([\d.]+)', line)
        if fanout_match:
            constraints['max_fanout'] = int(float(fanout_match.group(1)))

        # set_max_transition
        trans_match = re.search(r'set_max_transition\s+([\d.]+)', line)
        if trans_match:
            constraints['max_transition'] = float(trans_match.group(1))

    print(f"   [SDC] Clock: {constraints['clock_name']} @ {constraints['clock_period']}ns, "
          f"{len(constraints['input_delays'])} input delays, "
          f"{len(constraints['output_delays'])} output delays")
    return constraints


# ═══════════════════════════════════════════════════════════════════
# CONVERTER: Parsed data → GNN-compatible PyTorch Geometric Data
# ═══════════════════════════════════════════════════════════════════

def convert_to_gnn_data(node_attr, pin_attr, placement_dict, cell_lib=None, sdc_constraints=None):
    """
    Convert parsed VLSI files into the 16-feature PyTorch Geometric Data object
    that VLSIPlacementGNN expects.

    Args:
        node_attr: [2, N] array — cell names and types
        pin_attr:  [3, M] array — pin names, net indices, node indices
        placement_dict: dict cell_name → [x_min, y_min, x_max, y_max]; None or
            an empty dict means no cell has a known position
        cell_lib: (optional) LEF library — cell_type → {width, height, class, pins}
        sdc_constraints: (optional) SDC timing constraints; accepted for
            interface compatibility but not used when building the features

    Returns:
        Data with x (N x 16 features), edge_index, y (normalized coordinates)
        and extra attributes: num_cells, sample_name, design_name,
        original_coords, is_macro, areas, cell_names, cell_types, coord_min,
        coord_range.

    Notes:
        Cells without a placement entry get a random position drawn from the
        global NumPy random state, so results vary between runs unless the
        caller seeds it.
    """
    # Treat a missing placement (None) like an empty one instead of failing on
    # the membership test below.
    placement_dict = placement_dict or {}

    cell_names = list(node_attr[0])
    cell_types = list(node_attr[1])
    num_nodes = len(cell_names)

    # ── Coordinates & sizes from placement + LEF ──
    coords = np.zeros((num_nodes, 2), dtype=np.float32)
    sizes = np.zeros((num_nodes, 2), dtype=np.float32)

    for i, (name, ctype) in enumerate(zip(cell_names, cell_types)):
        if name in placement_dict:
            bbox = placement_dict[name]
            coords[i, 0] = (bbox[0] + bbox[2]) / 2.0
            coords[i, 1] = (bbox[1] + bbox[3]) / 2.0
            w = abs(bbox[2] - bbox[0])
            h = abs(bbox[3] - bbox[1])
        else:
            # Cell not placed — random initial position, unit size by default
            coords[i] = np.random.rand(2) * 1000
            w, h = 1.0, 1.0

        # If LEF available, use real dimensions. This applies to unplaced
        # cells as well: their size does not depend on having a position.
        if cell_lib and ctype in cell_lib:
            w = cell_lib[ctype]['width']
            h = cell_lib[ctype]['height']

        sizes[i] = [max(w, 0.001), max(h, 0.001)]

    areas = sizes[:, 0] * sizes[:, 1]

    # ── Normalize ──
    coord_min = coords.min(axis=0)
    coord_range = coords.max(axis=0) - coord_min + 1e-8
    coords_norm = (coords - coord_min) / coord_range
    sizes_norm = sizes / (sizes.max() + 1e-8)
    total_chip_area = coord_range[0] * coord_range[1] + 1e-8

    # ── Cell classification ──
    is_macro = np.zeros(num_nodes, dtype=np.float32)
    cell_category = np.zeros(num_nodes, dtype=np.float32)
    # Stays all-zero when the design contains a single cell type
    type_encoding = np.zeros(num_nodes, dtype=np.float32)

    unique_types = sorted(set(cell_types))
    type_to_idx = {t: i for i, t in enumerate(unique_types)}
    type_enc = np.array([type_to_idx[t] for t in cell_types], dtype=np.float32)
    if len(unique_types) > 1:
        type_encoding = type_enc / (len(unique_types) - 1)

    for i, (ct, a) in enumerate(zip(cell_types, areas)):
        # The second value returned by classify_cell is not used here
        macro, _filler, cat = classify_cell(ct, a)
        is_macro[i] = float(macro)
        cell_category[i] = cat / 3.0

    # ── Build edges from pin_attr (netlist connectivity) ──
    edge_index, n_orphans, n_nets, cell_connectivity = build_netlist_edges(
        pin_attr, node_attr, cell_names, star_threshold=10
    )

    # Fallback to KNN if no netlist edges
    if edge_index.shape[1] == 0:
        print("   WARNING: No netlist edges built, falling back to KNN")
        k_neighbors = min(8, num_nodes - 1)
        if num_nodes > 1:
            tree = KDTree(coords_norm)
            _, indices = tree.query(coords_norm, k=k_neighbors + 1)
            src = np.repeat(np.arange(num_nodes), k_neighbors)
            # Column 0 of the query result is dropped (nearest hit, normally the point itself)
            dst = indices[:, 1:].flatten()
            edge_index = np.stack([src, dst], axis=0).astype(np.int64)
        else:
            edge_index = np.array([[0], [0]], dtype=np.int64)
        for idx in range(num_nodes):
            cell_connectivity[idx] = {'pin_count': 0, 'net_count': 0, 'avg_fanout': 0, 'max_fanout': 0}

    # ── Connectivity features ──
    def _connectivity_column(key):
        """Per-cell value of `key` from cell_connectivity (0 when absent)."""
        return np.array(
            [cell_connectivity.get(i, {}).get(key, 0) for i in range(num_nodes)],
            dtype=np.float32,
        )

    pin_counts = _connectivity_column('pin_count')
    net_counts = _connectivity_column('net_count')
    avg_fanouts = _connectivity_column('avg_fanout')
    max_fanouts = _connectivity_column('max_fanout')

    pin_counts_norm = pin_counts / (pin_counts.max() + 1e-8)
    net_counts_norm = net_counts / (net_counts.max() + 1e-8)
    avg_fanouts_norm = avg_fanouts / (avg_fanouts.max() + 1e-8)
    max_fanouts_norm = max_fanouts / (max_fanouts.max() + 1e-8)

    relative_area = areas / (total_chip_area + 1e-8)
    relative_area_norm = relative_area / (relative_area.max() + 1e-8)

    aspect_ratio = sizes[:, 0] / (sizes[:, 1] + 1e-8)
    aspect_ratio = np.clip(aspect_ratio, 0, 10)
    aspect_ratio_norm = aspect_ratio / (aspect_ratio.max() + 1e-8)

    conn_importance = pin_counts * net_counts
    conn_importance_norm = conn_importance / (conn_importance.max() + 1e-8)

    # Average neighbor size: mean area of the destination cells of each
    # source's edges (0 for cells that are never a source)
    neighbor_area_avg = np.zeros(num_nodes, dtype=np.float32)
    if edge_index.shape[1] > 0:
        ei = np.asarray(edge_index)
        # Unbuffered scatter-add, replacing a Python loop over every edge
        np.add.at(neighbor_area_avg, ei[0], areas[ei[1]])
        degree = np.bincount(ei[0], minlength=num_nodes).astype(np.float32)
        degree[degree == 0] = 1
        neighbor_area_avg /= degree
    neighbor_area_norm = neighbor_area_avg / (neighbor_area_avg.max() + 1e-8)

    # ── Assemble 16-feature vector ──
    node_features = np.zeros((num_nodes, 16), dtype=np.float32)
    node_features[:, 0]  = coords_norm[:, 0]
    node_features[:, 1]  = coords_norm[:, 1]
    node_features[:, 2]  = sizes_norm[:, 0]
    node_features[:, 3]  = sizes_norm[:, 1]
    node_features[:, 4]  = np.log1p(areas)
    node_features[:, 4] /= (node_features[:, 4].max() + 1e-8)
    node_features[:, 5]  = type_encoding
    node_features[:, 6]  = is_macro
    node_features[:, 7]  = cell_category
    node_features[:, 8]  = pin_counts_norm
    node_features[:, 9]  = net_counts_norm
    node_features[:, 10] = avg_fanouts_norm
    node_features[:, 11] = max_fanouts_norm
    node_features[:, 12] = relative_area_norm
    node_features[:, 13] = aspect_ratio_norm
    node_features[:, 14] = neighbor_area_norm
    node_features[:, 15] = conn_importance_norm

    # ── Build Data object ──
    data = Data(
        x=torch.tensor(node_features, dtype=torch.float),
        edge_index=torch.tensor(edge_index, dtype=torch.long),
        y=torch.tensor(coords_norm, dtype=torch.float)  # original placement as reference
    )
    data.num_cells = num_nodes
    data.sample_name = "user_design"
    data.design_name = "user_design"
    data.original_coords = torch.tensor(coords, dtype=torch.float)
    data.is_macro = torch.tensor(is_macro, dtype=torch.float)
    data.areas = torch.tensor(areas, dtype=torch.float)
    data.cell_names = cell_names
    data.cell_types = cell_types
    # Kept so normalized coordinates can be mapped back to the input scale
    data.coord_min = coord_min
    data.coord_range = coord_range

    print(f"\n   Converted to GNN format:")
    print(f"      Cells: {num_nodes:,}")
    print(f"      Edges: {edge_index.shape[1]:,}")
    print(f"      Macros: {int(is_macro.sum())}")
    print(f"      Features: 16 per cell")

    return data


# ═══════════════════════════════════════════════════════════════════
# GUI: File selection + inference + visualization
# ═══════════════════════════════════════════════════════════════════

def _tensor_to_numpy(value):
    """Return a torch tensor as a CPU NumPy array; pass any other value through unchanged."""
    if torch.is_tensor(value):
        return value.detach().cpu().numpy()
    return value


def _save_placement_results(gnn_data, pred_coords, orig_coords, output_dir):
    """
    Write the predicted placement to ``output_dir`` as JSON and as a DEF-style file.

    Args:
        gnn_data: Graph object providing ``design_name``, ``num_cells``,
            ``cell_names`` and ``cell_types``.
        pred_coords: Array indexed as ``[cell, 0|1]`` with predicted x/y.
        orig_coords: Array indexed as ``[cell, 0|1]`` with the reference x/y.
        output_dir: Directory that receives both output files.

    Returns:
        Tuple ``(json_out, def_out)`` with the paths of the files written.
    """
    num_rows = len(pred_coords)

    result = {
        'design': gnn_data.design_name,
        'num_cells': int(gnn_data.num_cells),
        'cells': [
            {
                'name': gnn_data.cell_names[idx],
                'type': gnn_data.cell_types[idx],
                'predicted_x': float(pred_coords[idx, 0]),
                'predicted_y': float(pred_coords[idx, 1]),
                'original_x': float(orig_coords[idx, 0]),
                'original_y': float(orig_coords[idx, 1]),
            }
            for idx in range(num_rows)
        ],
    }

    json_out = os.path.join(output_dir, "gnn_placement_result.json")
    with open(json_out, 'w') as f:
        json.dump(result, f, indent=2)

    def_out = os.path.join(output_dir, "gnn_placement_result.def")
    with open(def_out, 'w') as f:
        f.write(f"DESIGN {gnn_data.design_name} ;\n")
        f.write(f"COMPONENTS {gnn_data.num_cells} ;\n")
        for idx in range(num_rows):
            # DEF coordinates are written as integers (truncated, as before).
            x_int = int(pred_coords[idx, 0])
            y_int = int(pred_coords[idx, 1])
            f.write(f"  - {gnn_data.cell_names[idx]} {gnn_data.cell_types[idx]}"
                    f" + PLACED ( {x_int} {y_int} ) N ;\n")
        f.write("END COMPONENTS\n")

    return json_out, def_out


def _plot_placement_comparison(gnn_data, pred_coords, orig_coords, output_dir):
    """
    Draw predicted vs. original placement side by side, save the figure as
    ``gnn_placement_comparison.png`` in ``output_dir`` and show it.
    """
    # Both arrays are identical for the two subplots, so compute them once.
    # .cpu() is required because the graph may have been moved to the GPU.
    macro_mask = _tensor_to_numpy(gnn_data.is_macro).astype(bool)
    edges = _tensor_to_numpy(gnn_data.edge_index)

    fig, axes = plt.subplots(1, 2, figsize=(16, 8), dpi=120)
    fig.patch.set_facecolor('#1a1a1a')

    panels = [
        (axes[0], pred_coords, 'GNN Predicted Placement', 'cyan'),
        (axes[1], orig_coords, 'Original Placement (from DEF)', 'lime'),
    ]
    for ax, coords_plot, title, color in panels:
        ax.set_facecolor('#1a1a1a')

        # Standard cells as small dots, macros as larger red squares
        ax.scatter(coords_plot[~macro_mask, 0], coords_plot[~macro_mask, 1],
                   s=3, c=color, alpha=0.6, label='Standard cells')
        if macro_mask.any():
            ax.scatter(coords_plot[macro_mask, 0], coords_plot[macro_mask, 1],
                       s=40, c='red', marker='s', alpha=0.9, label='Macros')

        # Draw only a sparse sample of edges (every 5th of the first 500)
        for edge_idx in range(0, min(500, edges.shape[1]), 5):
            src, dst = edges[0, edge_idx], edges[1, edge_idx]
            ax.plot([coords_plot[src, 0], coords_plot[dst, 0]],
                    [coords_plot[src, 1], coords_plot[dst, 1]],
                    'yellow', alpha=0.1, linewidth=0.3)

        ax.set_title(title, fontsize=14, color='white', pad=12)
        ax.set_xlabel('X (µm)', color='white')
        ax.set_ylabel('Y (µm)', color='white')
        ax.tick_params(colors='white')
        ax.legend(fontsize=9, loc='upper right')
        ax.grid(True, alpha=0.15, color='gray')

    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'gnn_placement_comparison.png'),
                dpi=150, facecolor='#1a1a1a', bbox_inches='tight')
    plt.show()


def run_gnn_inference_gui():
    """
    Launch a tkinter GUI to:
      1. Browse for Verilog netlist, DEF, LEF (optional), SDC (optional), Model
      2. Parse all files automatically
      3. Run GNN inference
      4. Display predicted placement

    The call blocks in ``root.mainloop()`` until the window is closed. On a
    successful run the window is destroyed and three files are written next to
    the selected Verilog file: ``gnn_placement_result.json``,
    ``gnn_placement_result.def`` and ``gnn_placement_comparison.png``.

    Returns:
        None.
    """
    import traceback  # only needed for error reporting inside the callbacks

    # ── State ──
    file_paths = {
        'verilog': None,
        'def': None,
        'lef': None,
        'sdc': None,
        'model': None,
    }

    # ── Root window ──
    root = tk.Tk()
    root.title("VLSI GNN Placement - File Input")
    root.geometry("720x560")
    root.configure(bg='#1e1e2e')
    root.resizable(False, False)

    style = ttk.Style()
    style.theme_use('clam')
    style.configure('Header.TLabel', background='#1e1e2e', foreground='#cdd6f4',
                    font=('Segoe UI', 16, 'bold'))
    style.configure('Sub.TLabel', background='#1e1e2e', foreground='#a6adc8',
                    font=('Segoe UI', 9))
    style.configure('File.TLabel', background='#313244', foreground='#cdd6f4',
                    font=('Consolas', 9), padding=5)
    style.configure('Browse.TButton', font=('Segoe UI', 9))
    style.configure('Run.TButton', font=('Segoe UI', 11, 'bold'))
    style.configure('Status.TLabel', background='#1e1e2e', foreground='#a6e3a1',
                    font=('Segoe UI', 10))

    # ── Header ──
    ttk.Label(root, text="VLSI GNN Placement Inference", style='Header.TLabel').pack(pady=(18, 2))
    ttk.Label(root, text="Select your design files below. Verilog + DEF are required; LEF and SDC are optional.",
              style='Sub.TLabel').pack(pady=(0, 14))

    # ── File rows ──
    file_frame = tk.Frame(root, bg='#1e1e2e')
    file_frame.pack(fill='x', padx=30)

    path_labels = {}

    # (key in file_paths, row label, file-dialog filters, required?)
    file_defs = [
        ('verilog', 'Verilog Netlist (.v)', [('Verilog', '*.v'), ('All', '*.*')], True),
        ('def',     'DEF Placement (.def)', [('DEF', '*.def'), ('All', '*.*')], True),
        ('lef',     'LEF Library (.lef)',    [('LEF', '*.lef'), ('All', '*.*')], False),
        ('sdc',     'SDC Constraints (.sdc)',[('SDC', '*.sdc'), ('All', '*.*')], False),
        ('model',   'Model Weights (.pth)',  [('PyTorch', '*.pth'), ('All', '*.*')], True),
    ]

    for key, label_text, ftypes, required in file_defs:
        row = tk.Frame(file_frame, bg='#1e1e2e')
        row.pack(fill='x', pady=4)

        req_tag = " *" if required else ""
        lbl = tk.Label(row, text=f"{label_text}{req_tag}", bg='#1e1e2e', fg='#cdd6f4',
                       font=('Segoe UI', 10), width=28, anchor='w')
        lbl.pack(side='left')

        path_var = tk.StringVar(value="No file selected")
        path_lbl = tk.Label(row, textvariable=path_var, bg='#313244', fg='#bac2de',
                            font=('Consolas', 9), width=38, anchor='w', relief='flat', padx=6, pady=3)
        path_lbl.pack(side='left', padx=(4, 6))
        path_labels[key] = path_var

        # Default arguments bind this row's values; a plain closure would see
        # only the last iteration's key/filters/variable.
        def make_browse(k=key, ft=ftypes, pv=path_var):
            def browse():
                p = filedialog.askopenfilename(filetypes=ft)
                if p:
                    file_paths[k] = p
                    pv.set(os.path.basename(p))
            return browse

        btn = ttk.Button(row, text="Browse", command=make_browse(), style='Browse.TButton')
        btn.pack(side='left')

    # ── Required note ──
    ttk.Label(file_frame, text="* = required", style='Sub.TLabel').pack(anchor='w', pady=(6, 0))

    # ── Status ──
    status_var = tk.StringVar(value="Ready — select files and click Run Inference")
    status_label = tk.Label(root, textvariable=status_var, bg='#1e1e2e', fg='#a6e3a1',
                            font=('Segoe UI', 10), wraplength=650, justify='left')
    status_label.pack(pady=(16, 4), padx=30, anchor='w')

    # ── Progress bar ──
    progress = ttk.Progressbar(root, mode='determinate', length=660)
    progress.pack(pady=(2, 12), padx=30)

    def report(value, message=None):
        """Update the progress bar (and optionally the status line) and redraw."""
        if message is not None:
            status_var.set(message)
        progress['value'] = value
        root.update_idletasks()

    # ── Run button ──
    def run_inference():
        """Button callback: validate selections, parse, infer, save, then visualize."""
        # Validate required files
        nice_names = {'verilog': 'Verilog Netlist', 'def': 'DEF File', 'model': 'Model Weights'}
        for key in ['verilog', 'def', 'model']:
            if not file_paths[key]:
                messagebox.showerror("Missing File", f"Please select a {nice_names[key]} file.")
                return

        report(0, "Parsing files...")

        # ── Phase 1: everything that needs the window to stay alive ──
        try:
            # Step 1: Parse Verilog
            report(10, "Step 1/5 — Parsing Verilog netlist...")
            node_attr, pin_attr, cell_names, cell_types = parse_verilog_netlist(file_paths['verilog'])

            # Step 2: Parse DEF
            report(25, "Step 2/5 — Parsing DEF placement...")
            placement_dict = parse_def_file(file_paths['def'])

            # Step 3: Parse LEF (optional)
            cell_lib = None
            if file_paths['lef']:
                report(35, "Step 3/5 — Parsing LEF library...")
                cell_lib = parse_lef_file(file_paths['lef'])

            # Step 4: Parse SDC (optional)
            sdc_constraints = None
            if file_paths['sdc']:
                report(40, "Step 4/5 — Parsing SDC constraints...")
                sdc_constraints = parse_sdc_file(file_paths['sdc'])

            # Step 5: Convert to GNN format
            report(50, "Step 5/5 — Converting to GNN format & running inference...")
            gnn_data = convert_to_gnn_data(
                node_attr, pin_attr, placement_dict,
                cell_lib=cell_lib,
                sdc_constraints=sdc_constraints
            )

            # Load model
            report(60)
            inference_model = VLSIPlacementGNN(
                input_dim=16, hidden_dim=128, output_dim=2,
                num_layers=4, heads=4
            ).to(device)

            # NOTE(review): torch.load unpickles the file; only load checkpoints
            # from a trusted source.
            ckpt = torch.load(file_paths['model'], map_location=device)
            # Accept both a training checkpoint ({'model_state_dict': ...}) and
            # a file that contains the bare state_dict itself.
            if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
                state_dict = ckpt['model_state_dict']
            else:
                state_dict = ckpt
            inference_model.load_state_dict(state_dict)
            inference_model.eval()

            report(75, "Running GNN inference...")
            with torch.no_grad():
                # PyG's Data.to() moves the stored tensors in place, so after
                # this call gnn_data itself lives on `device`.
                gnn_data_device = gnn_data.to(device)
                predictions = inference_model(gnn_data_device).cpu().numpy()

            report(90)

            # ── Save results ──
            output_dir = os.path.dirname(file_paths['verilog'])

            # Convert predictions back to micron coordinates. Tensors are
            # brought back to the CPU first; calling .numpy() directly on a
            # CUDA tensor raises.
            pred_coords = (predictions * _tensor_to_numpy(gnn_data.coord_range)
                           + _tensor_to_numpy(gnn_data.coord_min))
            orig_coords = _tensor_to_numpy(gnn_data.original_coords)

            json_out, def_out = _save_placement_results(
                gnn_data, pred_coords, orig_coords, output_dir
            )

            report(100, f"Done! Results saved to:\n  {json_out}\n  {def_out}")

        except Exception as e:
            status_var.set(f"ERROR: {str(e)}")
            progress['value'] = 0
            messagebox.showerror("Error", f"An error occurred:\n\n{str(e)}")
            traceback.print_exc()
            return

        # ── Close GUI and visualize in matplotlib ──
        root.destroy()

        # ── Phase 2: the window is gone, so errors here must NOT touch any Tk
        # widget or variable (that would raise TclError and hide the real cause).
        try:
            print("\n" + "=" * 70)
            print("GNN PLACEMENT INFERENCE COMPLETE")
            print("=" * 70)
            print(f"   Cells processed: {gnn_data.num_cells:,}")
            print(f"   Results saved to: {output_dir}")
            print(f"   JSON: {os.path.basename(json_out)}")
            print(f"   DEF:  {os.path.basename(def_out)}")
            print("=" * 70)

            # Plot predicted vs original
            _plot_placement_comparison(gnn_data, pred_coords, orig_coords, output_dir)

            # Error metrics
            delta = pred_coords - orig_coords
            mse = np.mean(delta ** 2)
            mae = np.mean(np.abs(delta))
            print(f"\nPlacement Metrics vs Original:")
            print(f"   MSE:  {mse:.2f}")
            print(f"   MAE:  {mae:.2f}")
            print(f"   Max displacement: {np.max(np.linalg.norm(delta, axis=1)):.2f}")
        except Exception:
            print("Visualization/metrics failed (placement results were already saved):")
            traceback.print_exc()

    run_btn = tk.Button(root, text="▶  Run Inference", command=run_inference,
                        bg='#89b4fa', fg='#1e1e2e', font=('Segoe UI', 12, 'bold'),
                        relief='flat', padx=30, pady=8, cursor='hand2',
                        activebackground='#74c7ec', activeforeground='#1e1e2e')
    run_btn.pack(pady=(4, 18))

    root.mainloop()


# ── Cell entry point: print the usage hint, then open the file-picker window ──
# (the call below blocks until the GUI window is closed)
print(
    "Launching file selection GUI...",
    "   Select your Verilog netlist, DEF, and model .pth file.",
    "   LEF and SDC are optional but improve accuracy.\n",
    sep="\n",
)
run_gnn_inference_gui()

Launching file selection GUI...
   Select your Verilog netlist, DEF, and model .pth file.
   LEF and SDC are optional but improve accuracy.

