## Imports and Configuration

In [1]:
!pip install torch_geometric pyvis

Defaulting to user installation because normal site-packages is not writeable


In [1]:
import os
import ast
import torch
import random
import numpy as np
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import eigsh
#from torch_geometric.data import Data, DataLoader, Batch
from torch_geometric.data import Data, Batch
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATv2Conv, SAGEConv, GCNConv
import torch.nn.functional as F

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
RUN_PATH = './raw_graph'

# Seed every RNG this notebook draws from (random.shuffle in load_all_data,
# numpy, torch weight init / DataLoader shuffling) so re-runs are comparable.
# NOTE: some CUDA scatter kernels remain nondeterministic even when seeded.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

## Parse Files

In [2]:
def parse_formatted(filepath):
    """Parse a ``*_formatted.txt`` netlist dump.

    The file mixes two kinds of lines:
      * ``key = value`` lines holding numeric scalars. ``lx``/``ly``/``ux``/``uy``
        go into ``bounds`` and every other key goes into ``globals_``.
      * lines whose stripped text starts with ``{``, each a Python dict literal for
        one net, with a ``'driver'`` dict and a ``'sinks'`` list of dicts that each
        carry an ``'id'``.

    Returns:
        node_ids: sorted list of the unique driver and sink ids.
        records:  list of the parsed net dicts, in file order.
        globals_: dict of non-bound scalars, as floats.
        bounds:   dict of whichever of lx/ly/ux/uy were present, as floats.
    """
    globals_ = {}
    bounds = {}
    records = []
    with open(filepath, 'r') as f:
        lines = f.read().splitlines()
    for line in lines:
        stripped = line.strip()
        if stripped.startswith('{'):
            # Record lines are never key=value lines, even when an instance name
            # inside them contains '='. literal_eval only evaluates literals.
            records.append(ast.literal_eval(stripped))
        elif '=' in line:
            key, val = line.split('=', 1)
            key = key.strip()
            target = bounds if key in {'lx', 'ly', 'ux', 'uy'} else globals_
            target[key] = float(val.strip())
    drivers = [r['driver']['id'] for r in records]
    sinks = [s['id'] for r in records for s in r['sinks']]
    node_ids = sorted(set(drivers + sinks))
    return node_ids, records, globals_, bounds

def parse_label_formatted(label_path, node_ids, bounds):
    """Read ground-truth placements and return them normalized to the die area.

    Only rows with exactly three whitespace-separated fields whose first field is
    all digits (``id x y``) are used; header and other rows are skipped. Each
    coordinate is mapped with ``(v - lower) / (upper - lower)`` using
    ``bounds['lx'|'ly'|'ux'|'uy']`` and then clipped into [0, 1].

    Args:
        label_path: path to a ``*_label_formatted.txt`` file.
        node_ids:   node ids, in the row order the output must follow.
        bounds:     dict with float 'lx', 'ly', 'ux', 'uy'.

    Returns:
        float32 tensor of shape (len(node_ids), 2). Ids without a label row get
        (0.0, 0.0), as does every id when the file has no label rows.
    """
    id2raw = {}
    with open(label_path, 'r') as f:
        for row in f:
            parts = row.strip().split()
            if len(parts) == 3 and parts[0].isdigit():
                # A later row for the same id overwrites an earlier one.
                id2raw[int(parts[0])] = (float(parts[1]), float(parts[2]))

    coords = np.zeros((len(node_ids), 2), dtype=np.float64)
    if id2raw:
        lower = np.array([bounds['lx'], bounds['ly']], dtype=np.float64)
        upper = np.array([bounds['ux'], bounds['uy']], dtype=np.float64)
        label_ids = list(id2raw)
        raw = np.array([id2raw[i] for i in label_ids], dtype=np.float64)
        # clamp if outside lower or upper bound
        normed = np.clip((raw - lower) / (upper - lower), 0.0, 1.0)
        id2row = dict(zip(label_ids, normed))
        # map back to node order
        for out_row, nid in enumerate(node_ids):
            if nid in id2row:
                coords[out_row] = id2row[nid]
    return torch.tensor(coords, dtype=torch.float)


## Graph Construction, Spectral Features, Relative Loss, and LR Schedule

In [3]:
def build_edge_index(node_ids, records, bidirectional=True):
    """Build a COO ``edge_index`` that links each net's driver to each of its sinks.

    Ids are replaced by their position in ``node_ids``. Each (driver, sink) pair
    emits driver->sink, immediately followed by sink->driver when
    ``bidirectional`` is true.

    Returns:
        long tensor of shape (2, E), or an empty (2, 0) long tensor when there
        are no edges.
    """
    position = {node: pos for pos, node in enumerate(node_ids)}
    pairs = []
    for net in records:
        drv = position[net['driver']['id']]
        for sink in net['sinks']:
            snk = position[sink['id']]
            pairs.extend([(drv, snk), (snk, drv)] if bidirectional else [(drv, snk)])
    if len(pairs) == 0:
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(pairs, dtype=torch.long).t().contiguous()

def build_adjacency(N, edge_index):
    """Return an N x N scipy CSR adjacency matrix with weight 1.0 per edge.

    ``edge_index`` is a (2, E) tensor of (source, target) rows and is copied to
    CPU first. ``csr_matrix`` sums repeated (source, target) pairs.
    """
    rows, cols = edge_index.cpu().numpy()
    weights = np.ones(rows.shape[0])
    return csr_matrix((weights, (rows, cols)), shape=(N, N))

def compute_laplacian_eigenvectors(adj, k=10, normalized=True):
    """Return the k lowest non-trivial eigenvectors of the graph Laplacian.

    Args:
        adj: square scipy sparse adjacency matrix (N x N).
        k: number of eigenvectors wanted; capped at N - 1.
        normalized: if True use L = I - D^{-1/2} A D^{-1/2}, where zero-degree
            nodes get a 0 scaling; otherwise use L = D - A.

    Returns:
        array of shape (N, min(k, N - 1)) holding the eigenvectors of the 2nd to
        (k+1)-th smallest eigenvalues, in ascending eigenvalue order. The
        smallest one is dropped. Returns a float32 (N, 0) array when N <= 1.
        Eigenvector signs are arbitrary.
    """
    N = adj.shape[0]
    k_eff = min(k, max(N - 1, 0))
    if k_eff < 1:
        return np.zeros((N, 0), dtype=np.float32)
    diag_idx = np.arange(N)
    deg = np.asarray(adj.sum(axis=1)).flatten()
    if normalized:
        # Compute 1/sqrt(deg) only where deg > 0, avoiding a divide-by-zero warning.
        inv_s = np.zeros(N)
        np.divide(1.0, np.sqrt(deg), out=inv_s, where=deg > 0)
        D = csr_matrix((inv_s, (diag_idx, diag_idx)), shape=adj.shape)
        # Sparse identity: csr_matrix(np.eye(N)) would first allocate a dense N x N array.
        identity = csr_matrix((np.ones(N), (diag_idx, diag_idx)), shape=adj.shape)
        L = identity - D @ adj @ D
    else:
        D = csr_matrix((deg, (diag_idx, diag_idx)), shape=adj.shape)
        L = D - adj
    try:
        vals, vecs = eigsh(L, k=k_eff + 1, which='SM')
    except Exception:
        # ARPACK can fail to converge, or reject k >= N for sparse input;
        # fall back to the exact dense solver (O(N^3)).
        vals, vecs = np.linalg.eigh(L.toarray())
    # Sort explicitly so the dropped column is always the smallest eigenvalue's vector.
    order = np.argsort(vals, kind='stable')
    return vecs[:, order[1:k_eff + 1]]

def compute_relative_loss(out, data, criterion):
    """Penalize mismatch between predicted and true edge lengths.

    For every edge (s, t) in ``data.edge_index``, take the Euclidean distance (over
    dim 1) between rows s and t of ``out``, and likewise of ``data.y``. Return
    ``criterion(predicted_lengths, true_lengths)``.
    """
    src, tgt = data.edge_index
    predicted_len = (out[src] - out[tgt]).norm(dim=1)
    true_len = (data.y[src] - data.y[tgt]).norm(dim=1)
    return criterion(predicted_len, true_len)

def adjust_learning_rate(optimizer, epoch, milestones=(80, 150, 300), gamma=0.5):
    """Step decay: multiply every param group's lr by ``gamma`` at each milestone epoch.

    Meant to be called once per epoch. The decay is applied in place to
    ``optimizer.param_groups``, so calling it twice for the same milestone epoch
    decays twice.

    Args:
        optimizer: object with a ``param_groups`` list of dicts holding ``'lr'``.
        epoch: current epoch index.
        milestones: epochs at which to decay (default keeps the original 80/150/300).
        gamma: multiplicative decay factor.
    """
    if epoch in milestones:
        for param_group in optimizer.param_groups:
            param_group['lr'] = param_group['lr'] * gamma

## Load Data

In [4]:
def _build_graph_data(formatted_path, label_path, design_name, num_eig=10):
    """Parse one (netlist, label) file pair into a PyG ``Data`` graph moved to ``device``."""
    node_ids, records, globals_, bounds = parse_formatted(formatted_path)
    # Input coordinates per id: drivers first, then sinks (a sink entry overwrites
    # a driver entry with the same id).
    orig_coords = {rec['driver']['id']: (rec['driver']['x'], rec['driver']['y']) for rec in records}
    for rec in records:
        for s in rec['sinks']:
            orig_coords[s['id']] = (s['x'], s['y'])
    edges = build_edge_index(node_ids, records)
    feats = torch.tensor(
        compute_laplacian_eigenvectors(build_adjacency(len(node_ids), edges), num_eig),
        dtype=torch.float,
    )
    # Graphs with <= num_eig nodes yield fewer eigenvectors; zero-pad so every graph
    # has the same feature width (the model is built with a fixed in_channels).
    if feats.size(1) < num_eig:
        feats = F.pad(feats, (0, num_eig - feats.size(1)))
    labels = parse_label_formatted(label_path, node_ids, bounds)
    # need to globally normalize these values (hand-picked affine scalings below)
    u_vec = torch.tensor([
        (globals_['Core Aspect Ratio'] - 0.5) / 0.4,
        (globals_['Utilization'] - 40.0) / 28.0,
        (globals_['Place Density'] - 0.2) / 0.3,
        (globals_['core_width'] / 1000000),
        (globals_['core_height'] / 1000000)
    ], dtype=torch.float).unsqueeze(0)
    data = Data(x=feats * 100, edge_index=edges, u=u_vec, y=labels)
    data = data.to(device)
    data.design_name = design_name
    data.node_ids = node_ids
    data.bounds = bounds
    data.orig_coords = orig_coords
    data.fixed_ids = [rec['driver']['id'] for rec in records if rec['driver'].get('is_fixed')]
    data.fixed_ids += [s['id'] for rec in records for s in rec['sinks'] if s.get('is_fixed')]
    return data


def load_all_data(root_dir, train_list, test_list, design_filter=None, batch_size=16, shuffle=True):
    """Walk ``root_dir`` for ``*_formatted.txt`` files and build train/test DataLoaders.

    A file's design key is its name minus the last six '_'-separated fields, e.g.
    'aes_nangate45_aes_run_1_1_1_formatted.txt' -> 'aes_nangate45'. A file is used
    only when:
      * its key passes ``design_filter`` (if given),
      * its key is in ``train_list`` or ``test_list`` (train wins if in both), and
      * a matching ``*_label_formatted.txt`` exists next to it.

    Returns:
        (train_loader, test_loader): PyG DataLoaders that exclude the non-tensor
        attributes (orig_coords, node_ids, bounds, fixed_ids, records) from collation.
    """
    suffix = '_formatted.txt'
    label_suffix = '_label_formatted.txt'
    exclude_keys = ('orig_coords', 'node_ids', 'bounds', 'fixed_ids', 'records')
    train_data_list = []
    test_data_list = []
    for dp, _, files in os.walk(root_dir):
        for fname in files:
            if not fname.endswith(suffix) or fname.endswith(label_suffix):
                continue
            curr_file = fname.rsplit('_', 6)[0]
            if design_filter and curr_file not in design_filter:
                continue
            # Skip designs that belong to neither split before doing any parsing.
            if curr_file not in train_list and curr_file not in test_list:
                continue
            fp = os.path.join(dp, fname)
            # Swap only the file-name suffix, never text inside the directory path.
            lf = fp[:-len(suffix)] + label_suffix
            if not os.path.exists(lf):
                continue
            data = _build_graph_data(fp, lf, fname[:-len(suffix)])
            if curr_file in train_list:
                train_data_list.append(data)
            else:
                test_data_list.append(data)
    if shuffle:
        random.shuffle(train_data_list)
        random.shuffle(test_data_list)
    train_loader = DataLoader(train_data_list, batch_size=batch_size, shuffle=shuffle, exclude_keys=list(exclude_keys))
    test_loader = DataLoader(test_data_list, batch_size=batch_size, shuffle=shuffle, exclude_keys=list(exclude_keys))
    return train_loader, test_loader


## Model

In [5]:
class PlacementGNN(torch.nn.Module):
    """Per-node placement regressor: a graph-conv stack followed by an MLP head.

    ``num_layers`` convolutions of type ``conv_type`` ('gat' -> GATv2Conv,
    'sage' -> SAGEConv, 'gcn' -> GCNConv) map node features to
    ``hidden_channels``, with a ReLU after every layer. Each node embedding is then
    concatenated with its graph's global vector ``u[batch]`` (``global_channels``
    wide) and passed through ``post_lin`` + ReLU and ``out_lin``. The output has
    2 values per node.
    """

    def __init__(self, in_channels, hidden_channels=64, num_layers=3, global_channels=3, conv_type='sage'):
        super().__init__()
        conv_cls = {'gat': GATv2Conv, 'sage': SAGEConv, 'gcn': GCNConv}[conv_type]
        # Only the first layer sees the raw features; later layers map hidden -> hidden.
        input_dims = [in_channels if layer == 0 else hidden_channels for layer in range(num_layers)]
        self.convs = torch.nn.ModuleList([conv_cls(dim, hidden_channels) for dim in input_dims])
        self.post_lin = torch.nn.Linear(hidden_channels + global_channels, hidden_channels)
        self.out_lin = torch.nn.Linear(hidden_channels, 2)

    def forward(self, x, edge_index, batch, u, edge_attr=None):
        """Return a (num_nodes, 2) tensor; ``batch`` maps each node to its row of ``u``."""
        h = x
        for conv in self.convs:
            # Only GATv2Conv receives edge_attr; the other conv types take (x, edge_index).
            conv_args = (h, edge_index, edge_attr) if isinstance(conv, GATv2Conv) else (h, edge_index)
            h = F.relu(conv(*conv_args))
        node_and_global = torch.cat([h, u[batch]], dim=1)
        hidden = F.relu(self.post_lin(node_and_global))
        return self.out_lin(hidden)

In [6]:
 def save_checkpoint(state, is_best, filepath='checkpoint.pth', best_filepath='model_best.pth.tar'):
     """Save ``state`` with ``torch.save`` and, if ``is_best``, copy it to a second file.

     Args:
         state: any object ``torch.save`` can serialize (e.g. a dict containing a state_dict).
         is_best: when true, the freshly written checkpoint is also copied to ``best_filepath``.
         filepath: destination of the latest checkpoint.
         best_filepath: destination of the best-so-far copy.
     """
     # shutil is not imported anywhere else in this notebook; without this import,
     # the is_best branch raised NameError.
     import shutil

     torch.save(state, filepath)
     if is_best:
         shutil.copyfile(filepath, best_filepath)

## Training

In [6]:
!unzip "{RUN_PATH}.zip" -d '.'

Archive:  ./raw_graph.zip
   creating: ./raw_graph/
  inflating: ./raw_graph/aes_nangate45_aes_run_1_1_1_formatted.txt  
  inflating: ./raw_graph/aes_nangate45_aes_run_1_1_1_label_formatted.txt  
  inflating: ./raw_graph/aes_nangate45_aes_run_1_1_2_formatted.txt  
  inflating: ./raw_graph/aes_nangate45_aes_run_1_1_2_label_formatted.txt  
  inflating: ./raw_graph/aes_nangate45_aes_run_1_1_3_formatted.txt  
  inflating: ./raw_graph/aes_nangate45_aes_run_1_1_3_label_formatted.txt  
  inflating: ./raw_graph/aes_nangate45_aes_run_1_1_4_formatted.txt  
  inflating: ./raw_graph/aes_nangate45_aes_run_1_1_4_label_formatted.txt  
  inflating: ./raw_graph/aes_nangate45_aes_run_1_2_1_formatted.txt  
  inflating: ./raw_graph/aes_nangate45_aes_run_1_2_1_label_formatted.txt  
  inflating: ./raw_graph/aes_nangate45_aes_run_1_2_2_formatted.txt  
  inflating: ./raw_graph/aes_nangate45_aes_run_1_2_2_label_formatted.txt  
  inflating: ./raw_graph/aes_nangate45_aes_run_1_2_3_formatted.txt  
  inflating: ./

In [14]:
def soft_density_loss(x, y, cell_area=0.0001, bin_size=0.05, density_threshold=0.4, sigma=0.01):
    """Differentiable bin-overflow penalty for points in the unit square.

    Each point (x[i], y[i]) is clamped to [0, 1] and spread over a regular grid of
    bins, with centres at bin_size/2, 3*bin_size/2, ... below 1.0, using an
    isotropic Gaussian of width ``sigma``. A bin's density is its summed kernel
    weight times ``cell_area``, divided by the bin area. The loss is the total
    amount by which bin densities exceed ``density_threshold``.

    Returns:
        0-dim tensor on ``x.device``.
    """
    px = x.clamp(0.0, 1.0)
    py = y.clamp(0.0, 1.0)
    axis = torch.arange(bin_size / 2, 1.0, bin_size, device=x.device)
    # (num_bins, 2) bin centres, enumerated in meshgrid 'ij' (row-major) order.
    centres = torch.cartesian_prod(axis, axis)

    # (num_points, num_bins) squared distances from every point to every bin centre.
    dist2 = (px.unsqueeze(1) - centres[:, 0].unsqueeze(0)) ** 2 + (py.unsqueeze(1) - centres[:, 1].unsqueeze(0)) ** 2
    weights = torch.exp(-dist2 / (2 * sigma ** 2))

    density = weights.sum(dim=0) * cell_area / (bin_size * bin_size)
    return (density - density_threshold).clamp(min=0.0).sum()

In [33]:
def train(train_dataset, model, criterion, optimizer, epoch, density_weight=0.0):
    """Run one training epoch over ``train_dataset`` (a PyG DataLoader).

    Per batch the loss is
        0.2 * criterion(out, y)                           (absolute positions)
      + 0.8 * compute_relative_loss(out, data, criterion) (edge lengths)
      + density_weight * soft_density_loss(...)           (only if density_weight != 0)
    The density term is skipped at weight 0 because it builds a (nodes x bins)
    matrix that would only be multiplied by zero.
    Every 10th epoch, prints the mean loss (the label says "MSE", but the value is
    the combined loss above).

    Returns:
        mean loss per graph over the epoch (0.0 for an empty dataset).
    """
    train_loss = 0.0
    model.train()
    for data in train_dataset:
        data = data.to(device)

        optimizer.zero_grad()
        out = model(data.x, data.edge_index, data.batch, data.u)
        loss = 0.2 * criterion(out, data.y) + 0.8 * compute_relative_loss(out, data, criterion)
        if density_weight:
            density = soft_density_loss(out[:, 0], out[:, 1], cell_area=0.0001, bin_size=0.05, density_threshold=0.6, sigma=0.01)
            loss = loss + density_weight * density

        loss.backward()
        optimizer.step()

        # Weight each batch by its graph count so the average is per graph.
        train_loss += loss.item() * data.num_graphs

    num_graphs = len(train_dataset.dataset)
    avg_loss = train_loss / num_graphs if num_graphs else 0.0
    if epoch % 10 == 0:
        print(f'Epoch {epoch} Training loss: MSE: {avg_loss:.4f}')
    return avg_loss

In [34]:
def validate(test_dataset, model, criterion, density_weight=0.0):
    """Evaluate ``model`` graph-by-graph on ``test_dataset.dataset`` without gradients.

    Per graph the loss is
        0.2 * criterion(out, y) + 0.8 * compute_relative_loss(out, data, criterion)
      + density_weight * soft_density_loss(...)   (only computed if density_weight != 0)

    Returns:
        the SUM of per-graph losses; callers divide by the number of graphs.
    """
    test_loss = 0.0
    model.eval()
    with torch.no_grad():
        for data in test_dataset.dataset:
            data = data.to(device)

            # A single un-batched graph: every node belongs to graph 0, i.e. row 0 of data.u.
            batch_vec = torch.zeros(data.x.size(0), dtype=torch.long, device=device)
            out = model(data.x, data.edge_index, batch_vec, data.u)
            loss = 0.2 * criterion(out, data.y) + 0.8 * compute_relative_loss(out, data, criterion)
            if density_weight:
                density = soft_density_loss(out[:, 0], out[:, 1], cell_area=0.0001, bin_size=0.05, density_threshold=0.6, sigma=0.01)
                loss = loss + density_weight * density

            test_loss += loss.item()

    return test_loss

In [35]:
def train_wrapper(train_dataset, test_dataset, lr=5e-3, weight_decay=1e-4, epochs=120, save_path='gnn_all.pth'):
    """Train a SAGE ``PlacementGNN`` with Adam and save its final weights.

    The model is built as PlacementGNN(10, 64, 3, 5, 'sage'): 10 input channels
    match the 10 Laplacian eigen-features built in ``load_all_data``, and 5 global
    channels match its ``u`` vector. The lr is decayed by ``adjust_learning_rate``
    (its default milestones are 80/150/300, so with 120 epochs only epoch 80
    fires). Every 10th epoch, prints the per-graph validation loss.

    Returns:
        the trained model (its final ``state_dict`` is also written to ``save_path``).
    """
    # gat (0.0029) - lr = 8e-3 250, 350, 450 * 0.5, epochs = 500
    model = PlacementGNN(10, 64, 3, 5, 'sage').to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    criterion = torch.nn.MSELoss().to(device)
    num_test = len(test_dataset.dataset)

    for epoch in range(epochs):
        adjust_learning_rate(optimizer, epoch)
        train(train_dataset, model, criterion, optimizer, epoch)
        test_loss = validate(test_dataset, model, criterion)
        # Guard: an empty test split would otherwise divide by zero.
        if epoch % 10 == 0 and num_test:
            print(f'Validation loss: MSE: {test_loss/num_test:.4f}\n')

    print("Saving Model")
    torch.save(model.state_dict(), save_path)
    return model


## Inference

In [36]:
def infer_wrapper(train_dataset, test_dataset, model_path="./gnn_all.pth", out_dir="./pred"):
    """Load saved weights, predict placements for every test graph, and write one file per design.

    ``train_dataset`` is unused; it is kept so existing calls still work. For each
    graph in ``test_dataset.dataset``:
      * the normalized predictions are mapped back to die coordinates using
        ``data.bounds``;
      * nodes listed in ``data.fixed_ids`` are dropped;
      * lines ``<name> <x> <y>`` are written to
        ``<out_dir>/<design_name>_predictions.txt``.
    Instance names come from re-reading ``<RUN_PATH>/<design_name>_formatted.txt``.
    Finally prints the mean of 0.2 * criterion + 0.8 * relative loss over the test
    graphs (labelled "MSE").
    """
    model = PlacementGNN(10, 64, 3, 5, 'sage').to(device)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    # The model is already on `device`; the former hard-coded model.cuda() failed without a GPU.
    model.eval()

    os.makedirs(out_dir, exist_ok=True)
    criterion = torch.nn.MSELoss().to(device)
    total = 0.0

    with torch.no_grad():
        for data in test_dataset.dataset:
            formatted_fp = os.path.join(f"{RUN_PATH}", data.design_name + '_formatted.txt')
            _, records, _, _ = parse_formatted(formatted_fp)
            data = data.to(device)
            batch_vec = torch.zeros(data.x.size(0), dtype=torch.long, device=device)
            out = model(data.x, data.edge_index, batch_vec, data.u)
            # The density term had weight 0.0, so it is not computed.
            loss = 0.2 * criterion(out, data.y) + 0.8 * compute_relative_loss(out, data, criterion)
            total += loss.item()

            # Undo the [0, 1] normalization applied in parse_label_formatted.
            lx, ux = data.bounds['lx'], data.bounds['ux']
            ly, uy = data.bounds['ly'], data.bounds['uy']
            scale = torch.tensor([ux - lx, uy - ly], device=out.device)
            offset = torch.tensor([lx, ly], device=out.device)
            preds = (out * scale + offset).cpu().numpy()

            id2name = {}
            for rec in records:
                d = rec['driver']
                id2name[d['id']] = d.get('name', str(d['id']))
                for s in rec['sinks']:
                    id2name[s['id']] = s.get('name', str(s['id']))

            node_ids = data.node_ids.tolist() if torch.is_tensor(data.node_ids) else data.node_ids
            names = [id2name.get(nid, str(nid)) for nid in node_ids]
            if hasattr(data, 'fixed_ids') and data.fixed_ids:
                # Exclude fixed instances from the written predictions.
                mask = ~np.isin(node_ids, data.fixed_ids)
                names = [name for name, m in zip(names, mask) if m]
                preds = preds[mask]

            fname = os.path.join(out_dir, f"{data.design_name}_predictions.txt")
            with open(fname, "w", newline="") as f:
                f.write("InstanceName x_center y_center\n")
                for name, (xv, yv) in zip(names, preds):
                    f.write(f"{name} {xv:.4f} {yv:.4f}\n")
            print(f"Saved {fname}")

    num_test = len(test_dataset.dataset)
    if num_test:
        print(f'MSE: {total/num_test:.4f}')

In [12]:
print("Starting data loading")

# Designs used for training vs. held-out evaluation. Passing their union as
# design_filter means files of any other design are skipped.
# Larger alternative: train_set = ["ibex_nangate45", "aes_nangate45", "gcd_asap7", "ibex_asap7", "jpeg_asap7", "jpeg_nangate45"]
train_set = ["gcd_nangate45", "aes_nangate45", "ibex_asap7"]
test_set = ["gcd_asap7"]

train_dataset, test_dataset = load_all_data(
    RUN_PATH,
    train_set,
    test_set,
    design_filter=train_set + test_set,
    batch_size=8,
)

print("Data Loaded!")

Starting data loading
Data Loaded!


In [37]:
# Train and save weights; run infer_wrapper(train_dataset, test_dataset) afterwards to write predictions.
train_wrapper(train_dataset=train_dataset, test_dataset=test_dataset);

Epoch 0 Training loss: MSE: 0.0310
Validation loss: MSE: 0.0719

Epoch 10 Training loss: MSE: 0.0149
Validation loss: MSE: 0.0547

Epoch 20 Training loss: MSE: 0.0135
Validation loss: MSE: 0.0485

Epoch 30 Training loss: MSE: 0.0126
Validation loss: MSE: 0.0515

Epoch 40 Training loss: MSE: 0.0121
Validation loss: MSE: 0.0536

Epoch 50 Training loss: MSE: 0.0113
Validation loss: MSE: 0.0557

Epoch 60 Training loss: MSE: 0.0112
Validation loss: MSE: 0.0586

Epoch 70 Training loss: MSE: 0.0110
Validation loss: MSE: 0.0574

Epoch 80 Training loss: MSE: 0.0106
Validation loss: MSE: 0.0532

Epoch 90 Training loss: MSE: 0.0104
Validation loss: MSE: 0.0551

Epoch 100 Training loss: MSE: 0.0103
Validation loss: MSE: 0.0552

Epoch 110 Training loss: MSE: 0.0103
Validation loss: MSE: 0.0541

Saving Model


In [27]:
# Write ./pred/<design>_predictions.txt for every test graph and print the mean test loss.
infer_wrapper(train_dataset=train_dataset, test_dataset=test_dataset)

Saved ./pred/gcd_asap7_gcd_run_1_2_2_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_5_5_1_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_5_2_3_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_2_3_2_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_3_5_4_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_5_3_3_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_4_1_3_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_2_3_3_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_4_5_3_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_2_3_4_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_3_5_1_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_3_3_3_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_5_3_4_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_5_1_1_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_5_1_3_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_1_5_3_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_4_1_1_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_4_4_2_predictions.txt
Saved ./pred/gcd_asap7_gcd_run_3_4_3_predictio

In [152]:
# Run: tar -czvf pred.tar.gz pred/
# Transfer and untar: tar -xzvf pred.tar.gz

/bin/bash: line 1: zip: command not found
