In [None]:
# default_exp netlist

# Netlist

> SAX Netlist Models

In [None]:
# hide
# Redirect stderr to os.devnull for the rest of the kernel session so that
# warnings and other stderr output do not clutter the rendered notebook.
# NOTE(review): the devnull handle is never closed; this also hides real errors.
import os, sys; sys.stderr = open(os.devnull, "w")

The pydantic netlist models below were generated from the gdsfactory netlist JSON schema as follows:

In [None]:
# export
from enum import Enum
from functools import partial
from hashlib import md5
from typing import Any, Dict, List, Optional, Union

import black
import jax.numpy as jnp
import networkx as nx
import numpy as np
import orjson
from fastcore.basics import patch_to
from pydantic import BaseModel as _BaseModel
from pydantic import Extra, Field, validator
from sax.utils import clean_string, flatten_dict

In [None]:
# hide
import sax

In [None]:
# export


def hash_dict(dic):
    """Deterministically hash a (possibly nested) dict to a large integer.

    The dict is serialized with orjson using sorted keys (numpy arrays are
    serialized as well), the bytes are md5-digested, and the hex digest is
    parsed as a base-16 integer.
    """
    serialized = orjson.dumps(
        dic, option=orjson.OPT_SORT_KEYS | orjson.OPT_SERIALIZE_NUMPY
    )
    digest = md5(serialized).hexdigest()
    return int(digest, 16)


class BaseModel(_BaseModel):
    """Immutable, hashable pydantic base model shared by the netlist models.

    Unknown fields are ignored, instances are frozen, ``repr``/``str`` return
    the pydantic repr pretty-printed by ``black``, and the hash is derived from
    the model's ``dict()`` via ``hash_dict``.
    """

    class Config:
        extra = Extra.ignore
        allow_mutation = False
        frozen = True

    def __repr__(self):
        raw = super().__repr__()
        return black.format_str(raw, mode=black.Mode())

    def __str__(self):
        return repr(self)

    def __hash__(self):
        as_dict = self.dict()
        return hash_dict(as_dict)

In [None]:
import io
import json
from contextlib import redirect_stdout
from urllib.request import urlopen

import datamodel_code_generator as dcg

NETLIST_SCHEMA_URL = "https://raw.githubusercontent.com/gdsfactory/gdsfactory/master/gdsfactory/tests/schemas/netlist.json"

def download_schema(url=NETLIST_SCHEMA_URL):
    """Download and parse the gdsfactory netlist JSON schema.

    Args:
        url: location of the JSON schema (defaults to ``NETLIST_SCHEMA_URL``).

    Returns:
        The schema parsed into a Python dict.
    """
    # Use the response as a context manager so the connection is always
    # closed, even if decoding or JSON parsing fails.
    with urlopen(url) as response:
        json_text = response.read().decode()
    return json.loads(json_text)

def generate_models():
    """Generate pydantic model source code from the gdsfactory netlist schema.

    The schema is downloaded, re-serialized to JSON and fed to
    ``datamodel_code_generator.generate``; the generated code is captured from
    stdout. Every ``extra = Extra.forbid`` config line is then rewritten so the
    models ignore unknown fields and are immutable.

    Returns:
        The generated Python source as a string.
    """
    schema_text = json.dumps(download_schema())
    captured = io.StringIO()
    with redirect_stdout(captured):
        dcg.generate(input_=schema_text)
    immutable_config = "\n        ".join(
        ["extra = Extra.ignore", "allow_mutation = False", "frozen = True"]
    )
    return captured.getvalue().replace("extra = Extra.forbid", immutable_config)

# NOTE(review): downloads the schema over the network and runs the code
# generator on every execution; the printed source is what the models
# below were derived from.
print(generate_models())

The following models were autogenerated from that schema (with minimal manual changes):

In [None]:
# export


class ComponentModel(BaseModel):
    """A component instance: the component (a name string or a dict) plus optional settings."""

    class Config:
        extra = Extra.ignore
        allow_mutation = False
        frozen = True

    component: Union[str, Dict[str, Any]] = Field(..., title="Component")
    settings: Optional[Dict[str, Any]] = Field(None, title="Settings")

    # this was added:

    @validator("component")
    def validate_component_name(cls, value):
        """Reject component names containing ',' and normalize them with ``clean_string``.

        Only string components are checked. The field also accepts a dict, for
        which ``"," in value`` would test the dict's keys and ``clean_string``
        (presumably string-only) does not apply, so dicts are returned unchanged.

        Raises:
            ValueError: if a component name contains ','.
        """
        if not isinstance(value, str):
            return value
        if "," in value:
            raise ValueError(
                f"Invalid component string. Should not contain ','. Got: {value}"
            )
        return clean_string(value)
    

class PortEnum(Enum):
    """Named anchor values accepted by ``PlacementModel.port`` (besides free strings).

    Autogenerated from the gdsfactory netlist schema.
    NOTE(review): the two-letter codes look like compass anchors
    (c=center, n/s/e/w=north/south/east/west) — confirm against gdsfactory.
    """

    ce = "ce"
    cw = "cw"
    nc = "nc"
    ne = "ne"
    nw = "nw"
    sc = "sc"
    se = "se"
    sw = "sw"
    center = "center"
    cc = "cc"


class PlacementModel(BaseModel):
    """Placement of a single instance (autogenerated from the gdsfactory netlist schema).

    Not used by the circuit-building code visible in this notebook.
    NOTE(review): the ``str`` alternatives for the coordinates presumably allow
    symbolic references (e.g. to other instances' ports) — confirm against the
    gdsfactory schema.
    """

    class Config:
        extra = Extra.ignore
        allow_mutation = False
        frozen = True

    x: Optional[Union[str, float]] = Field(0, title="X")
    y: Optional[Union[str, float]] = Field(0, title="Y")
    xmin: Optional[Union[str, float]] = Field(None, title="Xmin")
    ymin: Optional[Union[str, float]] = Field(None, title="Ymin")
    xmax: Optional[Union[str, float]] = Field(None, title="Xmax")
    ymax: Optional[Union[str, float]] = Field(None, title="Ymax")
    dx: Optional[float] = Field(0, title="Dx")
    dy: Optional[float] = Field(0, title="Dy")
    # either a free string or one of the PortEnum anchor names
    port: Optional[Union[str, PortEnum]] = Field(None, title="Port")
    # NOTE(review): units presumably degrees — confirm against gdsfactory
    rotation: Optional[int] = Field(0, title="Rotation")
    mirror: Optional[bool] = Field(False, title="Mirror")


class RouteModel(BaseModel):
    """A route definition (autogenerated from the gdsfactory netlist schema).

    Kept for completeness only: ``NetlistModel`` does not include a ``routes``
    field, because routes are irrelevant for SAX.
    """

    class Config:
        extra = Extra.ignore
        allow_mutation = False
        frozen = True

    links: Dict[str, str] = Field(..., title="Links")
    settings: Optional[Dict[str, Any]] = Field(None, title="Settings")
    routing_strategy: Optional[str] = Field(None, title="Routing Strategy")


class NetlistModel(BaseModel):
    """A flat (single-level) netlist: instances, connections, external ports and placements.

    Instance names, and the instance part of the ``"instance,port"`` strings
    used by ``connections`` and ``ports``, are normalized with ``clean_string``
    so that they stay consistent with each other.
    """

    class Config:
        extra = Extra.ignore
        allow_mutation = False
        frozen = True

    instances: Dict[str, ComponentModel] = Field(..., title="Instances")
    connections: Optional[Dict[str, str]] = Field(None, title="Connections")
    ports: Optional[Dict[str, str]] = Field(None, title="Ports")
    placements: Optional[Dict[str, PlacementModel]] = Field(None, title="Placements")

    # these were removed (irrelevant for SAX):

    # routes: Optional[Dict[str, RouteModel]] = Field(None, title='Routes')
    # name: Optional[str] = Field(None, title='Name')
    # info: Optional[Dict[str, Any]] = Field(None, title='Info')
    # settings: Optional[Dict[str, Any]] = Field(None, title='Settings')
    # pdk: Optional[str] = Field(None, title='Pdk')

    # these are extra additions:

    @validator("instances", pre=True)
    def coerce_string_instance_into_component_model(cls, instances):
        """Allow the shorthand ``{"name": "component"}`` for instances.

        A bare string value is expanded to ``{"component": value, "settings": {}}``
        before field validation. Input without ``.items()`` is passed through
        untouched, so pydantic reports a regular validation error instead of
        an AttributeError escaping from this validator.
        """
        if not hasattr(instances, "items"):
            return instances
        return {
            name: (
                {"component": inst, "settings": {}} if isinstance(inst, str) else inst
            )
            for name, inst in instances.items()
        }

    @staticmethod
    def clean_instance_string(value):
        """Validate and normalize an instance name.

        Raises:
            ValueError: if the name contains ','.
        """
        if "," in value:
            raise ValueError(
                f"Invalid instance string. Should not contain ','. Got: {value}"
            )
        return clean_string(value)

    @validator("instances")
    def validate_instance_names(cls, instances):
        """Normalize every instance name (the dict keys)."""
        return {cls.clean_instance_string(k): v for k, v in instances.items()}

    @classmethod
    def clean_connection_string(cls, value):
        """Normalize the instance part of an ``"instance,port"`` reference.

        The port part (after the last ',') is kept as-is.

        Raises:
            ValueError: if ``value`` has no ',' separator, or if the instance
                part itself contains a ',' (more than one separator).
        """
        comp, sep, port = value.rpartition(",")
        if not sep:
            # previously this silently produced ",<port>"
            raise ValueError(
                f"Invalid connection string. Should be of the form 'instance,port'. Got: {value}"
            )
        comp = cls.clean_instance_string(comp)
        return f"{comp},{port}"

    @validator("connections")
    def validate_connection_names(cls, connections):
        """Normalize both sides of every connection; an explicit None is kept."""
        if connections is None:
            return None
        return {
            cls.clean_connection_string(k): cls.clean_connection_string(v)
            for k, v in connections.items()
        }

    @validator("ports")
    def validate_port_names(cls, ports):
        """Normalize the ``"instance,port"`` targets of the external ports.

        This keeps them consistent with the cleaned instance names. The
        external port names (keys) are left untouched, and an explicit None
        is kept.
        """
        if ports is None:
            return None
        return {k: cls.clean_connection_string(v) for k, v in ports.items()}

These are manual additions:

In [None]:
# export
    
class RecursiveNetlistModel(BaseModel):
    """Hierarchical netlist: a mapping from cell/model name to its flat ``NetlistModel``.

    Implemented as a pydantic custom-root model, so the mapping lives in
    ``__root__``.
    """

    __root__: Dict[str, NetlistModel]

    class Config:
        extra = Extra.ignore
        allow_mutation = False
        frozen = True

In [None]:
import gdsfactory as gf
from gdsfactory.components import mzi
from gdsfactory.get_netlist import get_netlist_recursive, get_netlist_dict, get_netlist_yaml, get_netlist

@gf.cell
def twomzi():
    """Two MZIs (delta_length 10 and 20) in series: o1 of the second attached to o2 of the first."""
    c = gf.Component()
    mzi10, mzi20 = mzi(delta_length=10), mzi(delta_length=20)
    ref10 = c << mzi10
    ref20 = c << mzi20
    ref20.connect("o1", ref10.ports["o2"])
    return c

comp  = twomzi()
display(comp)
# recursive netlist of the two-MZI cell, extracted with full component settings
recnet = RecursiveNetlistModel.parse_obj(get_netlist_recursive(comp, get_netlist_func=partial(get_netlist_dict, full_settings=True)))
# NOTE(review): hard-coded cell name; presumably depends on how gdsfactory names
# mzi(delta_length=10) — confirm if gdsfactory's naming changes.
flatnet = recnet.__root__['mzi_delta_length10']

In [None]:
# export
def create_dag(
    net: RecursiveNetlistModel,
    models: Optional[Dict[str, Any]] = None,
):
    """Build the dependency graph of a recursive netlist.

    Every netlist name and every component it instantiates becomes a node.
    An edge ``parent -> component`` is added for each instance. Netlists whose
    name is a key of ``models`` are treated as opaque models, so their
    instances are not expanded.

    Args:
        net: recursive netlist (name -> flat netlist).
        models: optional mapping of already-available models, keyed by name.

    Returns:
        A ``networkx.DiGraph`` of the dependencies.

    Raises:
        TypeError: if ``models`` is not a dict.
    """
    if models is None:
        models = {}
    if not isinstance(models, dict):
        raise TypeError(f"models should be a dict. Got: {type(models).__name__}")

    g = nx.DiGraph()
    for model_name, flatnet in net.dict()["__root__"].items():
        # add_node is idempotent: a no-op if the name was already added as a dependency
        g.add_node(model_name)
        if model_name in models:
            continue  # a user-supplied model replaces this sub-netlist
        for instance in flatnet["instances"].values():
            # add_edge also inserts the component node when it is new
            g.add_edge(model_name, instance["component"])
    return g

In [None]:
# export
def find_root(g):
    """Return the nodes of ``g`` without incoming edges (in-degree 0), in iteration order."""
    roots = []
    for node, in_deg in g.in_degree():
        if in_deg == 0:
            roots.append(node)
    return roots

In [None]:
# export
def find_leaves(g):
    """Return the nodes of ``g`` without outgoing edges (out-degree 0), in iteration order."""
    leaves = []
    for node, out_deg in g.out_degree():
        if out_deg == 0:
            leaves.append(node)
    return leaves

In [None]:
class s:
    """Tiny symbolic helper: wraps a string and renders ``a * b`` as ``"a×b"``."""

    def __init__(self, s):
        self.s = s

    def __mul__(self, other):
        product = f"{self.s}×{other.s}"
        return s(product)

    def __str__(self):
        return self.s

    def __repr__(self):
        return self.s

In [None]:
def bend_euler(
    angle=90.0,
    p=0.5,
    # cross_section="strip",
    # direction="ccw",
    # with_bbox=True,
    # with_arc_floorplan=True,
    # npoints=720,
):
    """Ideal bend model: unit transmission between ports o1 and o2.

    ``angle`` and ``p`` do not affect the result. They presumably mirror the
    gdsfactory component's settings so netlist settings can be forwarded here.
    The S-parameters are made reciprocal with ``sax.reciprocal``.
    """
    transmission = {("o1", "o2"): 1.0}
    return sax.reciprocal(transmission)

In [None]:
def mmi1x2(
    width=0.5,
    width_taper=1.0,
    length_taper=10.0,
    length_mmi=5.5,
    width_mmi=2.5,
    gap_mmi=0.25,
    # cross_section= strip,
    # taper= {function= taper},
    # with_bbox= True,
):
    """Idealized 1x2 MMI: o1 splits into o2 and o3 with amplitude sqrt(0.45) each.

    That is 45% of the power per output. The geometric parameters do not
    affect the result. The S-parameters are made reciprocal with
    ``sax.reciprocal``.
    """
    amplitude = 0.45**0.5
    return sax.reciprocal({
        ("o1", "o2"): amplitude,
        ("o1", "o3"): amplitude,
    })

In [None]:
def mmi2x2(
    width=0.5,
    width_taper=1.0,
    length_taper=10.0,
    length_mmi=5.5,
    width_mmi=2.5,
    gap_mmi=0.25,
    # cross_section= strip,
    # taper= {function= taper},
    # with_bbox= True,
):
    """Idealized 2x2 MMI coupler with amplitude sqrt(0.45) on every path.

    The bar paths (o1->o3, o2->o4) are real; the cross paths (o1->o4, o2->o3)
    carry a factor of 1j. The geometric parameters do not affect the result.
    The S-parameters are made reciprocal with ``sax.reciprocal``.
    """
    amplitude = 0.45**0.5
    bar, cross = amplitude, 1j * amplitude
    return sax.reciprocal({
        ("o1", "o3"): bar,
        ("o1", "o4"): cross,
        ("o2", "o3"): cross,
        ("o2", "o4"): bar,
    })

In [None]:
def straight(
    length=0.01,
    npoints=2,
    with_bbox=True,
    cross_section=None
):
    """Ideal straight waveguide model: unit transmission between ports o1 and o2.

    None of the arguments affect the result: the model does not depend on
    ``length``, and ``cross_section`` is ignored. The parameters are accepted
    so that the settings of a gdsfactory ``straight`` instance can be passed
    to this model. The S-parameters are made reciprocal with
    ``sax.reciprocal``.
    """
    # A default cross-section dict used to be built here but was never read
    # (dead code); it is removed.
    return sax.reciprocal({
        ('o1', 'o2'): 1.0
    })

In [None]:
# sanity check: evaluate the straight model with its default settings
straight()

In [None]:
# Leaf models for the two-MZI example, keyed by component name.
models = {
    'straight': straight,
    'bend_euler': bend_euler,
    'mmi1x2': mmi1x2,
    # a whole sub-netlist can also be replaced by a model, e.g.:
    #"mzi_delta_length20": mmi2x2
}

In [None]:
# Dependency graph of the two-MZI netlist; its leaves are the components that need models.
g = create_dag(recnet, models)
nx.draw_planar(g, with_labels=True)

In [None]:
def _validate_models(models, dag):
    """Check that every leaf of ``dag`` has a model and return a shallow copy of ``models``.

    Raises:
        ValueError: listing the missing, given and required model names when
            a leaf node has no entry in ``models``.
    """
    required_models = find_leaves(dag)
    missing_models = [name for name in required_models if name not in models]
    if not missing_models:
        return dict(models)  # shallow copy
    model_diff = {
        "Missing Models": missing_models,
        "Given Models": list(models),
        "Required Models": required_models,
    }
    pretty_diff = black.format_str(repr(model_diff), mode=black.Mode())
    raise ValueError(
        "Missing models. The following models are still missing to build the circuit:\n"
        f"{pretty_diff}"
    )

In [None]:
# Fail early if a leaf of the DAG has no model; rebinds `models` to a shallow copy.
models = _validate_models(models, g)

In [None]:
from sax.backends import circuit_backends
from sax.typing_ import Settings, SType
from sax.utils import get_settings, _replace_kwargs, merge_dicts

def _flat_circuit(instances, connections, ports, models, backend):
    """Build a circuit function for a single (non-recursive) netlist.

    Args:
        instances: instance name -> instance model, with ``component`` and an
            optional ``settings`` dict.
        connections: ``"inst,port" -> "inst,port"`` mapping for the backend.
        ports: external port name -> ``"inst,port"`` mapping for the backend.
        models: component name -> model function.
        backend: key into ``circuit_backends``.

    Returns:
        A function ``_circuit(**settings)`` returning the evaluated circuit (an
        ``SType``). Its defaults are the model defaults overridden by the
        settings stored in the netlist. Those merged defaults are advertised in
        the signature and are also applied at call time for anything not
        passed explicitly.
    """
    evaluate_circuit = circuit_backends[backend]

    inst2model = {k: models[inst.component] for k, inst in instances.items()}

    model_settings = {name: get_settings(model) for name, model in inst2model.items()}
    netlist_settings = {
        # `settings` is optional on an instance: treat None as "no overrides"
        name: {
            k: v
            for k, v in (inst.settings or {}).items()
            if k in model_settings[name]
        }
        for name, inst in instances.items()
    }
    default_settings = merge_dicts(model_settings, netlist_settings)

    def _circuit(**settings: Settings) -> SType:
        # Start from the netlist-aware defaults, not the bare model defaults.
        # `**settings` is empty when nothing is passed, so a rewritten signature
        # alone would not apply the netlist settings.
        settings = merge_dicts(default_settings, settings)
        instances: Dict[str, SType] = {}
        for inst_name, model in inst2model.items():
            instances[inst_name] = model(**settings.get(inst_name, {}))
        S = evaluate_circuit(instances, connections, ports)
        return S

    _replace_kwargs(_circuit, **default_settings)

    return _circuit


# Quick check on a single flat netlist using the raw (not mode-wrapped) models.
# NOTE(review): this `circuit` binding is shadowed by the `circuit` builder
# function defined in a later cell.
circuit = _flat_circuit(flatnet.instances, flatnet.connections, flatnet.ports, models, "default")

In [None]:
def _validate_circuit_backend(backend):
    """Lower-case ``backend`` and check that it names a known circuit backend.

    Raises:
        KeyError: if the name is not a key of ``circuit_backends``.
    """
    normalized = backend.lower()
    if normalized in circuit_backends:
        return normalized
    allowed = ', '.join(circuit_backends.keys())
    raise KeyError(
        f"circuit backend {normalized} not found. Allowed circuit backends: "
        f"{allowed}."
    )


def _validate_modes(modes) -> List[str]:
    if modes is None:
        return ["te"]
    elif not modes:
        return ["te"]
    elif isinstance(modes, str):
        return [modes]
    elif all(isinstance(m, str) for m in modes):
        return modes
    else:
        raise ValueError(f"Invalid modes given: {modes}")


def _validate_net(
    net: Union[NetlistModel, RecursiveNetlistModel]
) -> RecursiveNetlistModel:
    """Return ``net`` as a recursive netlist.

    A flat ``NetlistModel`` is wrapped under the name ``"top_level"``; anything
    else is returned unchanged.
    """
    if not isinstance(net, NetlistModel):
        return net
    return RecursiveNetlistModel(__root__={"top_level": net})


def _validate_dag(dag):
    """Check that the dependency graph is acyclic and has exactly one top level.

    Returns:
        ``dag`` unchanged.

    Raises:
        ValueError: if ``dag`` contains a cycle, has several nodes without
            incoming edges (several top levels), or has no such node at all.
    """
    # `dag.is_directed()` only reports the graph *type*, so cycles need an
    # explicit check. Run it first, because a fully cyclic netlist has no
    # in-degree-0 node and would otherwise be misreported as empty.
    if not nx.is_directed_acyclic_graph(dag):
        raise ValueError("Netlist dependency cycles detected!")
    nodes = find_root(dag)
    if len(nodes) > 1:
        raise ValueError(f"Multiple top_levels found in netlist: {nodes}")
    if not nodes:
        raise ValueError("Netlist does not contain any nodes.")
    return dag


def _make_singlemode_or_multimode(net, modes, models):
    """Adapt a netlist and its models to the requested modes.

    Uses ``_make_singlemode`` when exactly one mode is requested and
    ``_make_multimode`` otherwise.

    Returns:
        ``(connections, ports, models)`` from the chosen helper.
    """
    if len(modes) != 1:
        return _make_multimode(net, modes, models)
    return _make_singlemode(net, modes[0], models)


def _make_singlemode(net, mode, models):
    models = {k: singlemode(m, mode=mode) for k, m in models.items()}
    return net.connections, net.ports, models


def _make_multimode(net, mode, models):
    models = {k: multimode(m, modes=modes) for k, m in models.items()}
    connections = {
        f"{p1}@{mode}": f"{p2}@{mode}"
        for p1, p2 in net.connections.items()
        for mode in modes
    }
    ports = {
        f"{p1}@{mode}": f"{p2}@{mode}" for p1, p2 in net.ports.items() for mode in modes
    }
    return connections, ports, models


from sax.multimode import singlemode, multimode
from sax.typing_ import Model
from typing import Tuple, List


def circuit(
    net: Union[NetlistModel, RecursiveNetlistModel],
    models: Dict[str, Model],
    modes: Optional[List[str]]=None,
    backend: str="default",
) -> Tuple[Model, Dict[str, Model]]:
    """Build a (possibly hierarchical) circuit model from a netlist.

    Sub-netlists are turned into circuit functions bottom-up, in reverse
    topological order of the dependency graph, so each one is available as a
    model to the netlists that instantiate it.

    Args:
        net: flat or recursive netlist; a flat one is wrapped as ``"top_level"``.
        models: model functions for every leaf component; a sub-netlist whose
            name is a key here is used as-is instead of being expanded.
        modes: mode names; None/empty means ``["te"]``. A single mode goes
            through ``_make_singlemode``, several through ``_make_multimode``.
        backend: circuit backend name (lower-cased before lookup).

    Returns:
        ``(circuit, models)``: the top-level circuit function and the dict of
        mode-adapted models and sub-circuits it was built from.

    Raises:
        ValueError: for missing models, invalid modes, or a graph with multiple
            or no top-level netlists.
        KeyError: for an unknown backend.
    """

    recnet: RecursiveNetlistModel = _validate_net(net)
    dag: nx.DiGraph = _validate_dag(create_dag(recnet, models))  # directed acyclic graph
    models = _validate_models(models, dag)
    modes = _validate_modes(modes)
    backend = _validate_circuit_backend(backend)

    # circuit of the most recently built netlist; the root comes last
    circuit = None
    # user models seen since the last netlist was built, still to be mode-adapted
    new_models = {}
    # mode-adapted models and sub-circuits built so far
    current_models = {}
    # reversed topological order: dependencies come before the netlists using them,
    # and the single root (in-degree 0) comes last
    model_names = list(nx.topological_sort(dag))[::-1]
    for model_name in model_names:
        if model_name in models:
            # user-supplied model: queue it for mode adaptation
            new_models[model_name] = models[model_name]
            continue

        flatnet = recnet.__root__[model_name]

        # adapt the queued models and this netlist's connections/ports to the modes
        connections, ports, new_models = _make_singlemode_or_multimode(
            flatnet, modes, new_models
        )
        current_models.update(new_models)
        new_models = {}

        # register this sub-circuit so parent netlists can instantiate it
        current_models[model_name] = circuit = _flat_circuit(
            flatnet.instances, connections, ports, current_models, backend
        )
    
    # NOTE(review): fails if the top level itself was supplied in `models`
    assert circuit is not None
    return circuit, current_models

    # NOTE(review): unreachable leftover draft below (after return); can be removed.
    # if modes is None or isinstance(modes, str):
    #    instances, connections, ports, sm_mm = _make_singlemode(net, modes)
    # else:
    #    instances, connections, ports,

In [None]:
# NOTE(review): this rebinds `circuit` (the builder defined above) to the returned
# top-level circuit function, so re-running this cell raises a TypeError; a
# distinct name would keep the cell idempotent.
circuit, _ = circuit(recnet, models)

In [None]:
# evaluate the top-level circuit with its default settings
circuit()