# Kronecker Expression Parser

## Problem Setup

Kronecker products and sums of them are frequently used in quantum physics, especially when defining a Hamiltonian. However, the mathematical notation for such expressions differs from what linear algebra libraries (e.g. NumPy, Eigen, Julia's LinearAlgebra) expect. Basically, we want the machine to understand the following expression:

$$
\sigma_1^1\sigma_1^4 + \sigma_1^3\sigma_1^5 + \sigma_2^1 (\sigma_1^2 + \sigma_1^3)
$$

Here we denote the Pauli matrices (together with the identity) as

$$
\sigma_0 = I \quad \sigma_1 = X \quad \sigma_2 = Y \quad \sigma_3 = Z
$$

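and the superscript labels the site a matrix acts on. Every site that is not mentioned carries an identity, so with five sites this expression actually means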

$$
\sigma_1^1\sigma_0^2\sigma_0^3\sigma_1^4\sigma_0^5 + \sigma_0^1\sigma_0^2\sigma_1^3\sigma_0^4\sigma_1^5 + \sigma_2^1 \sigma_1^2\sigma_0^3\sigma_0^4\sigma_0^5 + \sigma_2^1\sigma_0^2\sigma_1^3\sigma_0^4\sigma_0^5
$$
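
As a quick sanity check of this padding rule, the first term can be assembled by hand with the builtin `kron`. This is only a sketch: the dense matrices `I2` and `X` are illustrative names assumed here, not the sparse constants defined later in the tutorial.

```julia
using LinearAlgebra  # provides the kron methods for matrices

# Illustrative dense matrices (names assumed for this sketch)
I2 = [1 0; 0 1]   # plays the role of σ₀
X  = [0 1; 1 0]   # plays the role of σ₁

# σ₁¹σ₁⁴ on five sites: X on sites 1 and 4, identity on sites 2, 3 and 5
term = kron(kron(kron(kron(X, I2), I2), X), I2)

# Five two-level sites give a 2^5 × 2^5 matrix
size(term)
```

Writing four nested `kron` calls for a single term is exactly the kind of boilerplate the parser is meant to generate for us.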

Therefore, our task is to parse the original expression into this machine-understandable form. Thanks to Julia's powerful metaprogramming features, we can do this elegantly through code generation, with no abstraction overhead at run time.

## Step 1: Define some shorthands

The builtin `kron` function only takes two matrices at a time, so to ease later tests we will define some shorthands. First, we define the $\otimes$ operator, which Julia's parser already accepts as an infix operator.

In [1]:
⊗(A, B) = kron(A, B)

⊗ (generic function with 1 method)

Then we implement a function called `kronprod` that takes the Kronecker product over an iterator of matrices.

In [None]:
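function kronprod(itr)
    # `iterate` replaces the `start`/`next`/`done` protocol removed in Julia 1.0
    nxt = iterate(itr)
    nxt === nothing && throw(ArgumentError("kronprod needs at least one matrix"))
    pd, state = nxt
    nxt = iterate(itr, state)
    while nxt !== nothing
        val, state = nxt
        pd = kron(pd, val)
        nxt = iterate(itr, state)
    end
    return pd
end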

To make calls clearer, we also add a method that takes the matrices as variadic arguments.

In [None]:
kronprod(m::AbstractMatrix...) = kronprod(m)
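
The three ways of writing a multi-matrix Kronecker product should agree. The sketch below checks this, using `foldl(kron, itr)` as a compact stand-in for the iterator-based `kronprod` above (an assumption made to keep the example short); `I2` and `X` are illustrative dense matrices.

```julia
using LinearAlgebra

⊗(A, B) = kron(A, B)

# Compact stand-in for the iterator-based kronprod (assumed for this sketch)
kronprod(itr) = foldl(kron, itr)
kronprod(m::AbstractMatrix...) = kronprod(m)

I2 = [1 0; 0 1]
X  = [0 1; 1 0]

a = kronprod(X, I2, X)     # variadic form
b = kronprod([X, I2, X])   # any iterable of matrices
c = X ⊗ I2 ⊗ X             # infix form, evaluated left to right

a == b == c
```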

Then we define some basic constants. You can also use my package [QMTK.jl](http://rogerluo.me/QMTK.jl): its `Consts` module provides most of the constants you will need in quantum physics, based on the NIST CODATA project.


In [None]:
using SparseArrays  # `sparse` lives in the SparseArrays standard library since Julia 0.7

const sigmax = sparse([0 1;1 0])
const sigmay = sparse([0 -im;im 0])
const sigmaz = sparse([1 0;0 -1])
const sigmai = sparse([1 0;0 1])

const σ₀ = sigmai
const σ₁ = sigmax
const σ₂ = sigmay
const σ₃ = sigmaz
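
A short check that these matrices behave like Pauli matrices can catch sign mistakes early. The sketch uses lowercase local names (`sx`, `sy`, `sz`, `id`), which are assumptions chosen to avoid redefining the constants above.

```julia
using SparseArrays

sx = sparse([0 1; 1 0])
sy = sparse([0 -im; im 0])
sz = sparse([1 0; 0 -1])
id = sparse([1 0; 0 1])

# Each Pauli matrix squares to the identity
sx * sx == id
sy * sy == id

# and the product of the first two gives i times the third: σ₁σ₂ = iσ₃
sx * sy == im * sz
```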


## Step 2: Analyse the expression

Let's analyse what we are going to parse again. The expression looks like

$$
\sigma_1^1\sigma_1^4 + \sigma_1^3\sigma_1^5 + \sigma_2^1 (\sigma_1^2 + \sigma_1^3)
$$


mathematically, but superscripts are awkward to write in code, so we use `[]` instead. The expression then looks like

In [None]:
ex = :(σ₁[1] * σ₁[4] + σ₁[3] * σ₁[5] + σ₂[1] * (σ₁[2] + σ₁[3]))

1. We will use $*$/$\otimes$ to denote the Kronecker product.
2. We won't consider multiplication by scalar factors in this tutorial, which means `:(2 * σ₁[1] * σ₁[4])` will not be parsed.
3. We allow the distributive law of the Kronecker product, and such products should be expanded into plain expressions.

Thus, there are three kinds of objects that we should consider

- notation `x[order]`
- an expression block that contains only Kronecker products, such as $\sigma_1^1\sigma_1^4$
- binary operators like `+`, `-`
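
These three kinds of objects are visible directly in the AST that Julia builds for the quoted expression. The sketch below only inspects `head` and `args`, which are the standard fields of `Expr`; the variable names `term` and `site` are illustrative.

```julia
ex = :(σ₁[1] * σ₁[4] + σ₁[3] * σ₁[5] + σ₂[1] * (σ₁[2] + σ₁[3]))

ex.head            # :call
ex.args[1]         # :+  (one n-ary call holding all three terms)
length(ex.args)    # 4: the operator followed by three operands

term = ex.args[2]  # :(σ₁[1] * σ₁[4]), a Kronecker product block
term.args[1]       # :*

site = term.args[2]        # :(σ₁[1]), the x[order] notation
Meta.isexpr(site, :ref)    # true
site.args                  # Any[:σ₁, 1]
```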


## Step 3: The Matrix Expression

We first have to create an abstraction for the notation of a matrix expression with a specific order

In [None]:
mutable struct MatExpr
    expr
    index::Int
end

MatExpr(p::Pair) = MatExpr(p.first, p.second)

We also have to offer a constructor from an expression. It only accepts an expression tagged as a reference

In [None]:
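function MatExpr(e::Expr)
    # only `matrix[index]` (a `:ref` expression) is accepted
    if Meta.isexpr(e, :ref)
        return MatExpr(e.args[1], e.args[2])
    else
        # `ParseError` lives in the `Meta` module on Julia 1.0+
        throw(Meta.ParseError(
            "expects square brackets after matrix: matrix[index]"
        ))
    end
end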
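
A quick usage sketch of this constructor follows. It repeats the definitions in condensed form so that it runs on its own, and it swaps in `ArgumentError` for the error type, which is an assumption of this sketch rather than the tutorial's choice.

```julia
mutable struct MatExpr
    expr
    index::Int
end

# Condensed variant of the constructor above (ArgumentError is assumed here)
function MatExpr(e::Expr)
    Meta.isexpr(e, :ref) || throw(ArgumentError("expects matrix[index], got $e"))
    return MatExpr(e.args[1], e.args[2])
end

m = MatExpr(:(σ₁[4]))
(m.expr, m.index)    # (:σ₁, 4)
```

Anything that is not of the form `matrix[index]`, for example `:(a + b)`, is rejected with an exception instead of being silently accepted.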

Here we can reconsider what a type is in Julia. A type defines a set, and the functions defined on it describe properties of that set. The set of `MatExpr` is ordered by site index, so we define a binary relation on it called `isless`. This function belongs to the standard interface, but don't worry about overloading it by accident: Julia requires you to import such a function (or qualify it, as in `Base.isless`) before adding methods to it. Otherwise you either create an unrelated function in the current module or get an error.

In [None]:
import Base: isless
isless(x::MatExpr, y::MatExpr) = isless(x.index, y.index)
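
Once `isless` is defined, the generic ordering machinery in Base works on `MatExpr` without further code. A minimal sketch, repeating the struct so that it is self-contained (the sample values are illustrative):

```julia
mutable struct MatExpr
    expr
    index::Int
end

Base.isless(x::MatExpr, y::MatExpr) = isless(x.index, y.index)

terms = [MatExpr(:σ₂, 5), MatExpr(:σ₁, 1), MatExpr(:σ₃, 3)]

ordered = sort(terms)          # sort falls back to isless
[t.index for t in ordered]     # [1, 3, 5]
terms[2] < terms[1]            # true, since < also falls back to isless
```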

## Step 4: Kronecker Product Block

A Kronecker product block is a collection of `MatExpr` together with a total length, which tells us how many identities we need to insert. We will call it the **Kronecker Product Block**. We separate it from the other binary operators because such a block has to know the total number of matrices it contains, which the other binary operators do not.

However, the Kronecker product is not commutative, which means

$$
\sigma_1\sigma_2\sigma_3
$$

is different from

$$
\sigma_2\sigma_1\sigma_3
$$

We therefore have to keep the `MatExpr` in a list ordered by index. Let's review some sorting algorithms:


- **Heap sort**: complexity $O(n\log n)$; inserting into a heap costs $O(\log n)$
- **Quick sort**: complexity $O(n^2)$ in the worst case, $O(n\log n)$ in the best and average case

Since we keep inserting into the list while reading the expression, we store our `MatExpr` on a heap. All the heap utilities we need are included in the package `DataStructures`.

In [None]:
using DataStructures

abstract type KronExpr end

mutable struct KronProd <: KronExpr
    len::Int
    args::Vector{MatExpr}
end

KronProd() = KronProd(0, MatExpr[])

function setlen!(ex::KronProd, len::Int)
    ex.len = len
    return ex
end

function KronProd(seq::Vector{MatExpr})
    total = 0
    for each in seq
        if total < each.index
            total = each.index
        end
    end
    return KronProd(total, heapify(seq))
end

## Step 5: Binary Expression

In this tutorial, we only consider binary operators like `+` and `-`, but we do not want to write a separate type for each operator with nearly identical content. Instead, we use a parametric type with two abstract types as its tags

In [None]:
abstract type DirectPlus end
abstract type DirectMinus end

mutable struct BinaryExpr{OP, LHS <: KronExpr, RHS <: KronExpr} <: KronExpr
    len::Int
    left::LHS
    right::RHS
end

function setlen!(ex::BinaryExpr, len::Int)
    ex.len = len  # keep the node's own length in sync with its children
    setlen!(ex.left, len)
    setlen!(ex.right, len)
    return ex
end

function BinaryExpr(::Type{OP}, lhs::LHS, rhs::RHS) where {OP, LHS, RHS}
    len = max(lhs.len, rhs.len)
    setlen!(lhs, len)
    setlen!(rhs, len)
    BinaryExpr{OP, LHS, RHS}(len, lhs, rhs)
end
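
The tag-type technique can be seen in isolation on a toy example. Everything in the sketch below (`Plus`, `Minus`, `Node`, `evaluate`) is hypothetical and much simpler than `BinaryExpr`; it only shows how a type parameter that never appears in a field still drives dispatch.

```julia
abstract type Plus end
abstract type Minus end

# Hypothetical toy node: the operator lives in the type parameter, not in a field
struct Node{OP}
    left::Int
    right::Int
end

# Pass the tag as a value of type Type{OP}, as the BinaryExpr constructor does
Node(::Type{OP}, l, r) where {OP} = Node{OP}(l, r)

# One method per tag, selected at compile time
evaluate(n::Node{Plus})  = n.left + n.right
evaluate(n::Node{Minus}) = n.left - n.right

evaluate(Node(Plus, 5, 3))    # 8
evaluate(Node(Minus, 5, 3))   # 2
```

Because the tags are abstract types with no instances, they cost nothing at run time and a new operator only needs a new tag plus the methods that differ.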

## Step 6: Pretty Printing

Before we start parsing, we prettify the display of each type to make debugging and monitoring easier. This requires overloading the `show` interface of `Base`

In [None]:
import Base: show

function show(io::IO, ex::MatExpr)
    print(io, "$(ex.expr)[$(ex.index)]")
end

function show(io::IO, ex::KronProd)
    seq = copy(ex.args)
    print(io, "kron($(ex.len), ")
    count = 1
    while !isempty(seq)
        each = heappop!(seq)
        print(io, "$each")
        if count < length(ex.args)
            print(io, " ⊗ ")
        end
        count += 1
    end
    print(io, ")")
end

function show(io::IO, ex::BinaryExpr{DirectPlus})
    print(io, "$(ex.left)")
    print(io, " ⊕ ")
    print(io, "$(ex.right)")
end

function show(io::IO, ex::BinaryExpr{DirectMinus})
    print(io, "$(ex.left)")
    print(io, " ⊖ ")
    print(io, "$(ex.right)")
end

## Step 7: Parse to Kronecker Expression

A Kronecker expression consists of the two kinds of basic blocks we defined: `KronProd` and `BinaryExpr`. We will parse the native Julia AST (Abstract Syntax Tree) into a tree with these two kinds of nodes, and then regenerate a Julia AST (`Expr`) that actually evaluates the expression with the function `kron`. Therefore, we first need to construct such a tree. Since efficiency hardly matters at code generation time, we simply implement this recursively.


The builtin parser entry point is called `parse` (`Meta.parse` since Julia 0.7). We will call ours `kronparse`. By default, `kronparse` is the identity.

In [None]:
kronparse(expr) = expr

When the input is an AST, which is the `Expr` type in Julia, we dispatch it to different factory functions or constructors according to its expression tag. There are five different conditions:

1. the expression is a reference `x[1]`, which means it is a Kronecker product block with a single matrix
2. the expression is not a call expression, which means it is not an operator, so we just return it
3. the expression is a call of one of the binary operators `+`, `-`, so we send it to the factory function `make_binary`
4. the expression is a call of one of the product operators `*`, `⋅`, `⊗`, so we send it to the factory function `make_kronprod`
5. the expression does not satisfy any condition above, so we throw an `ErrorException` with the detailed information of this expression.

In [None]:
function kronparse(expr::Expr)
    if Meta.isexpr(expr, :ref)
        return KronProd([MatExpr(expr)])
    end

    if !Meta.isexpr(expr, :call)
        return expr
    end

    if expr.args[1] in (:+, :-)
        return make_binary(expr)
    elseif expr.args[1] in (:*, :⋅, :⊗)
        return make_kronprod(expr)
    else
        throw(ErrorException("Invalid Expression $expr"))
    end
end
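
The branch order above can be tried out without the factory functions. The sketch below uses a hypothetical helper `classify` that returns a symbol naming the branch `kronparse` would take; it is an illustration, not part of the parser.

```julia
# Hypothetical helper mirroring the branch order of kronparse
function classify(expr)
    expr isa Expr || return :passthrough           # non-Expr input is returned as is
    Meta.isexpr(expr, :ref) && return :matrix      # x[1]
    Meta.isexpr(expr, :call) || return :passthrough
    op = expr.args[1]
    op in (:+, :-) && return :binary               # goes to make_binary
    op in (:*, :⋅, :⊗) && return :kronprod         # goes to make_kronprod
    throw(ErrorException("Invalid Expression $expr"))
end

classify(:σ₁)                 # :passthrough
classify(:(σ₁[1]))            # :matrix
classify(:(σ₁[1] + σ₂[2]))    # :binary
classify(:(σ₁[1] ⊗ σ₂[2]))    # :kronprod
```

A call to any other function, such as `:(f(σ₁[1]))`, ends in the final branch and raises the `ErrorException`.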

#### Parse Binary Operators

We then implement those factory functions. First, `make_binary` dispatches the expression to one of two factory functions according to its operator.

In [None]:
function make_binary(expr::Expr)
    if expr.args[1] == :+
        return make_plus(expr)
    else
        return make_minus(expr)
    end
end

Now, before we implement all the details, we first define some very useful utilities.

In [None]:
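isprod(expr) = false
isplus(expr) = false
isminus(expr) = false

# only call expressions carry an operator in `args[1]`
isprod(expr::Expr) = Meta.isexpr(expr, :call) && expr.args[1] in (:*, :⋅, :⊗)
isplus(expr::Expr) = Meta.isexpr(expr, :call) && expr.args[1] == :+
isminus(expr::Expr) = Meta.isexpr(expr, :call) && expr.args[1] == :-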

The factory functions simply forward to a `BinaryExpr` constructor, so that multiple dispatch on the tag type does the rest

In [None]:
make_plus(expr::Expr) = BinaryExpr(DirectPlus, expr)
make_minus(expr::Expr) = BinaryExpr(DirectMinus, expr)

Each `BinaryExpr` is a binary tree node. However, Julia parses a chain of `+` into a single call with a list of operands to reduce recursion (a chain of `-` is already nested), so we re-parse such calls into a tree structure here to ease later use.

In [None]:
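function BinaryExpr(::Type{OP}, expr::Expr) where OP
    seq = copy(expr.args[2:end])

    # unary `+x` / `-x` has a single operand and is not supported here
    length(seq) < 2 && throw(ArgumentError("unary operator in $expr is not supported"))

    if length(seq) == 2
        return BinaryExpr(
            OP,
            kronparse(seq[1]),
            kronparse(seq[2])
        )
    else
        # n-ary call, e.g. +(a, b, c): fold it as (a + b) + c
        return BinaryExpr(
            OP,
            BinaryExpr(
                OP,
                Expr(:call, expr.args[1], seq[1:end-1]...)
            ),
            kronparse(seq[end])
        )
    end
end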
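
The difference between the flat and the nested form is easy to see on a small example. The sketch below uses placeholder symbols `a`, `b`, `c` and rebuilds the nested form of a `+` chain with a left fold, which is the shape the constructor above produces recursively.

```julia
plus  = :(a + b + c)
minus = :(a - b - c)

plus.args        # Any[:+, :a, :b, :c], one flat call with three operands
minus.args[2]    # :(a - b), subtraction is already nested

# Fold the flat operand list into a left-leaning binary tree
nested = foldl((l, r) -> Expr(:call, :+, l, r), plus.args[2:end])
nested == :((a + b) + c)   # true
```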

#### Parse Kronecker Product

For a Kronecker product block, we have to expand brackets and merge the matrices inside a bracket into the enclosing block. For convenience, we let `KronProd` implement part of the heap interface through a few methods.

In [None]:
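# `merge` and `merge!` are defined later, but they also exist in Base,
# so we import them here before adding methods
import Base: push!, pop!, copy, merge, merge!

copy(c::KronProd) = KronProd(c.len, copy(c.args))
copy(ex::BinaryExpr{OP}) where OP = BinaryExpr(OP, copy(ex.left), copy(ex.right))

push!(c::KronExpr, m::Expr) = push!(c, MatExpr(m))
pop!(c::KronProd) = heappop!(c.args)

function push!(ex::KronProd, m::MatExpr)
    heappush!(ex.args, m)
    if ex.len < m.index
        ex.len = m.index
    end
    return ex
end

# pushing into a sum distributes the matrix over both branches:
# (a + b) * m = a * m + b * m
function push!(ex::BinaryExpr, m::MatExpr)
    push!(ex.left, m)
    push!(ex.right, m)
    ex.len = max(ex.len, m.index)
    return ex
end

function make_kronprod(expr::Expr)
    ex = KronProd()
    for each in expr.args[2:end]
        if Meta.isexpr(each, :ref)
            push!(ex, each)
        else
            ex = merge(ex, kronparse(each))
        end
    end
    return ex
end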

Now we consider the cases that can occur in a Kronecker product expression:

- $(\sigma_1^1\sigma_1^5)(\sigma_1^2\sigma_2^6) = \sigma_1^1\sigma_1^2\sigma_1^5\sigma_2^6$ merge two plain blocks into one block
- $(\sigma_1^1 + \sigma_2^3)\sigma_1^2 = \sigma_1^1\sigma_1^2 + \sigma_1^2\sigma_2^3$ merge a plain block on the right into every block inside the bracket.
- $\sigma_1^1 (\sigma_1^2 + \sigma_2^3) = \sigma_1^1\sigma_1^2 + \sigma_1^1\sigma_2^3$ merge a plain block on the left into every block inside the bracket.
- $(\sigma_1^1\sigma_2^2 + \sigma_1^1\sigma_1^2)(\sigma_1^3\sigma_1^4 + \sigma_2^3\sigma_2^4)$ expand the product of two binary expressions into a sum of single Kronecker product blocks.
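
Written out, the last case distributes every block of the first bracket over every block of the second, which gives four plain blocks:

$$
(\sigma_1^1\sigma_2^2 + \sigma_1^1\sigma_1^2)(\sigma_1^3\sigma_1^4 + \sigma_2^3\sigma_2^4) = \sigma_1^1\sigma_2^2\sigma_1^3\sigma_1^4 + \sigma_1^1\sigma_2^2\sigma_2^3\sigma_2^4 + \sigma_1^1\sigma_1^2\sigma_1^3\sigma_1^4 + \sigma_1^1\sigma_1^2\sigma_2^3\sigma_2^4
$$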


In [None]:
# lhs = lhs * rhs
function merge!(lhs::KronProd, rhs::KronProd)
    for each in rhs.args
        push!(lhs, each)
    end

    lhs.len = max(lhs.len, rhs.len)
    return lhs
end

merge(lhs::KronProd, rhs::KronProd) = merge!(copy(lhs), rhs)

# lhs := (lhs[1] + lhs[2]) * rhs = lhs[1] * rhs + lhs[2] * rhs
function merge!(lhs::BinaryExpr, rhs::KronProd)
    merge!(lhs.left, rhs)
    merge!(lhs.right, rhs)

    lhs.len = max(lhs.len, rhs.len)
    return lhs
end

# (lhs[1] + lhs[2]) * rhs without mutating lhs: work on a copy
function merge(lhs::BinaryExpr, rhs::KronProd)
    merge!(copy(lhs), rhs)
end

#   lhs * (rhs[1] + rhs[2])
# = lhs * rhs[1] + lhs * rhs[2]
# = rhs[1] * lhs + rhs[2] * lhs
# swapping is safe because every block is kept sorted by site index
merge(lhs::KronProd, rhs::BinaryExpr) = merge(rhs, lhs)


#   (lhs[1] OP1 lhs[2]) * (rhs[3] OP2 rhs[4])
# = (lhs[1] * rhs[3] OP2 lhs[1] * rhs[4]) OP1 (lhs[2] * rhs[3] OP2 lhs[2] * rhs[4])
function merge(lhs::BinaryExpr{OP1}, rhs::BinaryExpr{OP2}) where {OP1, OP2}
    # keep each half grouped so that OP1 applies to the whole second bracket,
    # e.g. (a - b) * (c - d) = (a*c - a*d) - (b*c - b*d)
    upper = BinaryExpr(OP2, merge(lhs.left, rhs.left), merge(lhs.left, rhs.right))
    lower = BinaryExpr(OP2, merge(lhs.right, rhs.left), merge(lhs.right, rhs.right))
    return BinaryExpr(OP1, upper, lower)
end
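
The sign bookkeeping for two bracketed expressions is easy to get wrong, so it is worth checking numerically. The sketch below relies only on the bilinearity of `kron`; the matrices `A`, `B`, `C`, `D` are arbitrary illustrative values.

```julia
using LinearAlgebra

A = [1 2; 3 4]; B = [0 1; 1 0]
C = [2 0; 0 5]; D = [1 1; 0 1]

lhs = kron(A - B, C - D)

# The last term picks up the product of both minus signs
rhs = kron(A, C) - kron(A, D) - kron(B, C) + kron(B, D)

# Same expansion, grouped as (A⊗C - A⊗D) - (B⊗C - B⊗D)
grouped = (kron(A, C) - kron(A, D)) - (kron(B, C) - kron(B, D))

lhs == rhs == grouped
```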


Now we have finished `kronparse`, and we can parse any valid expression into a Kronecker expression tree. Let's try an example.

In [None]:
ex = kronparse(:(σ₁[1] * σ₁[3] * (σ₁[2] + σ₁[5])))

We are almost there!!!

## Step 8: Expand Kronecker Expression to Julia's Expr

This is simply a tree traversal, and since we have already parsed the expression into different types, we can write it with multiple dispatch.


#### KronProd


In [None]:
function toexpr(ex::KronProd)
    seq = []
    args = copy(ex.args)  # pop from a copy so that `ex` is left untouched
    previous = 0; ind = 0
    while !isempty(args)
        each = heappop!(args)
        val, ind = each.expr, each.index
        # pad the gap before this site with identities
        for i = previous+1:ind-1
            push!(seq, :σ₀)
        end
        previous = ind
        push!(seq, val)
    end

    # pad the tail up to the total length
    for i = ind+1:ex.len
        push!(seq, :σ₀)
    end
    return vec2kron(seq)
end
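
The padding logic can be tried on its own, without the heap. The helper `pad_identities` below is hypothetical: it takes `index => symbol` pairs that are assumed to be sorted with distinct indices, and returns the full sequence of symbols.

```julia
# Hypothetical helper: `sites` holds index => symbol pairs, sorted by index
function pad_identities(sites, len; filler = :σ₀)
    seq = Any[]
    previous = 0
    for (ind, val) in sites
        append!(seq, fill(filler, ind - previous - 1))  # gap before this site
        push!(seq, val)
        previous = ind
    end
    append!(seq, fill(filler, len - previous))          # trailing identities
    return seq
end

pad_identities([1 => :σ₁, 4 => :σ₁], 5)   # Any[:σ₁, :σ₀, :σ₀, :σ₁, :σ₀]
```

This reproduces the first term of the expanded expression from the problem setup.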

#### BinaryExpr

In [None]:
toexpr(ex::BinaryExpr{DirectPlus}) = :($(toexpr(ex.left)) + $(toexpr(ex.right)))
toexpr(ex::BinaryExpr{DirectMinus}) = :($(toexpr(ex.left)) - $(toexpr(ex.right)))


We have to turn the generated matrix sequence into a Julia `Expr` that calls the `kron` function (we will not use `kronprod`, since we want to generate source code directly)

In [None]:
function vec2kron(seq::Vector)
    # start from the first matrix so that a single-site sequence also works
    ex = seq[1]
    for each in seq[2:end]
        ex = Expr(:call, :kron, ex, each)
    end
    return ex
end
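
The nesting this produces is a plain left fold over the sequence. The sketch below builds the same shape with `foldl` on an illustrative three-site sequence, so the structure of the generated code can be inspected without the rest of the parser.

```julia
seq = [:σ₁, :σ₀, :σ₁]

# Left fold: each step wraps the previous expression in another kron call
ex = foldl((acc, m) -> Expr(:call, :kron, acc, m), seq)

ex == :(kron(kron(σ₁, σ₀), σ₁))   # true
```

The result is an ordinary `Expr`, so it can be handed to `eval` or returned from a macro like any other generated code.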

Now let's see what this produces

In [None]:
toexpr(ex)