## Supervised Learning with ITensor: MPS for Classification

[Supervised Learning with Quantum-Inspired Tensor Networks](https://arxiv.org/pdf/1605.05775), *NeurIPS 2016.*

In [26]:
using Random, Statistics
using ITensors, ITensorMPS

*Why machine learning?* Tensor trains provide a linear model from a very large feature space, with only $O(N dm^2)$ parameters. 

*Goal:* Classify $8\times 8$ grayscale images into two classes, can be extended to multi-class classification (e.g. MNIST).

In [27]:
# synthetic data generation

const H, W = 8, 8
function make_bar(imgtype::Symbol; rng=Random.default_rng())
    X = fill(0.0, H, W)
    if imgtype == :vertical
        j = rand(rng, 2:W-1)
        X[:, j] .= 1.0
    elseif imgtype == :horizontal
        i = rand(rng, 2:H-1)
        X[i, :] .= 1.0
    end
    return X
end

function synth_dataset(n_per=200; rng=Random.default_rng())
    Xv = [make_bar(:vertical; rng=rng) for _ in 1:n_per]
    Xh = [make_bar(:horizontal; rng=rng) for _ in 1:n_per]
    X = vcat(Xv, Xh)
    y = vcat(fill(1, n_per), fill(0, n_per)) # 1 = vertical, 0 = horizontal
    shuffle = randperm(rng, length(X))
    return X[shuffle], y[shuffle]
end

X, y = synth_dataset(200)
N = H * W # number of pixels, i.e. number of MPS sites

64

*How are tensors useful?* Let $x = (x_1, x_2, \ldots, x_N) \in \mathbb R^N$ be an input data point (image with $N$ pixels). A length $N$ binary vector has $2^N$ possible configurations, which is infeasible to work with directly. 

The trick is the *lift* each scalar $x_j$ to a small local feature vector and use a *tensor product* to build a structured, high-dimensional feature $\Phi(x)$ that can be efficiently represented as a matrix product state (MPS).

*Local feature map:*
$$ \phi: \mathbb R \to \mathbb R^d,\quad x_j \mapsto \phi(x_j) = \left(\phi_1(x_j), \ldots, \phi_d(x_j)\right) $$
We develop a tensor-product (rank-1, order-$N$ tensor) feature map:
$$ \Phi(x)_{s_1, s_2, \ldots, s_N} = \left[ \phi(x_1) \otimes \cdots \otimes \phi(x_N) \right]_{s_1, s_2, \ldots, s_N} = \prod_{j=1}^N \phi_{s_j}(x_j),\quad s_j \in \{1, 2, \ldots, d\} $$

For grayscale images, a simple $d=2$ choice is
$$ \phi(x_j) = \left( \cos\left(\frac{\pi}{2} x_j\right), \sin\left(\frac{\pi}{2} x_j\right) \right) \in \mathbb R^2 $$
which is *normalized*, i.e. $\|\phi(x_j)\|_2 = 1$.

*Why is this helpful?* $\Phi(x)$ lives in a $d^N$-dimensional space, which is huge even for small $N$. However, the tensor-product structure allows us to represent $\Phi(x)$ efficiently as an MPS with low bond dimension.

In [28]:
@inline feature(x::Real) = (cospi(0.5 * x), sinpi(0.5 * x)) # d = 2 feature map

feature (generic function with 1 method)

We classify data points with a *linear model* in the lifted space.
$$ f(x) = W \cdot \Phi(x) = \sum_{s_1, s_2, \ldots, s_N} W_{s_1, s_2, \ldots, s_N} \Phi(x)_{s_1, s_2, \ldots, s_N} $$
where $W$ is a weight tensor of the same order and dimension as $\Phi(x)$. To avoid working with the full $W$, we represent it as an MPS with low bond dimension $m$:
$$ W_{s_1, s_2, \ldots, s_N} = \sum_{\{\alpha\}} A^{[1]}_{s_1, \alpha_1} A^{[2]}_{\alpha_1, s_2, \alpha_2} \cdots A^{[N]}_{\alpha_{N-1}, s_N} $$
This reduces the number of parameters from $d^N$ to $O(N d m^2)$.

In [29]:
# local 2-dim indices for each pixel
sites = siteinds("Qubit", N) # qubit sites have dim 2

# random weight MPS with bond dimension m
m = 8
Wmps = randomMPS(sites, m) # normalized random MPS (our weight tensor)

# turn an image into 1-site feature Itensors matching 'sites'
function image_to_features(img::AbstractMatrix{<:Real}, sites::Vector{<:Index})
    feats = ITensor[]
    v = vec(img)
    for n in 1:length(v)
        s = sites[n]
        a, b = feature(v[n])
        phi = ITensor(s)
        phi[s => 1] = a
        phi[s => 2] = b
        push!(feats, phi)
    end
    return feats
end

# contract features with Wmps to get a scalar score f(x)
function score(W::MPS, img::AbstractMatrix)
    feats = image_to_features(img, siteinds(W))
    T = nothing
    @inbounds for n in 1:length(feats)
        T = isnothing(T) ? W[n] * feats[n] : T * W[n] * feats[n]
    end
    return sum(Array(T)) # scalar output
end

pred(s) = s > 0 ? 1 : 0 # binary prediction from score
acc(W, Xs, ys) = mean(pred(score(W, x)) == y for (x, y) in zip(Xs, ys)) # accuracy

acc (generic function with 1 method)

To factor a big tensor $\Psi_{s_1, \ldots, s_N}$ to an MPS, we use a sequence of SVDs (tensor train SVD).

The standard algorithm for this is as follows
1. group $(s_1)$ and $(s_2, \ldots, s_N)$
2. SVD to get $U\Sigma V^T$ and set $A^{[1]} = U\Sigma$
3. absorb $V^T$ into the next tensor and repeat until all tensors are extracted.

For binary labels $y_n \in \{0, 1\}$, we use a *squared loss*
$C(W) = \frac{1}{2} \sum_{n=1}^{N_T} (f(x_n) - y_n)^2$

In [30]:
# one-site training loop

Xphi = [image_to_features(x, sites) for x in X] # precompute features

function left_env(W::MPS, feats::Vector{ITensor}, j::Int)
    T = ITensor()
    for n in 1:j-1
        T = n == 1 ? W[n] * feats[n] : T * W[n] * feats[n]
    end
    return T
end

function right_env(W::MPS, feats::Vector{ITensor}, j::Int)
    T = ITensor()
    for n in length(feats):-1:j+1
        T = n == length(feats) ? W[n] * feats[n] : W[n] * feats[n] * T
    end
    return T
end

# helper to make indices of A and B match for contraction
_align_to(to::ITensor, ref::ITensor) = replaceinds(to, uniqueinds(to) .=> uniqueinds(ref))

# one-site gradient step at site j using small mini-batch
function update_site!(W::MPS, j; batch, lr=5e-3)
    G = zero(W[j]) # gradient accumulator
    for t in batch
        feats = Xphi[t]
        
        # prediction
        T = nothing
        @inbounds for n in 1:length(feats)
            T = isnothing(T) > W[n] * feats[n] : T * W[n] * feats[n]
        end
        yhat = sum(Array(T))
        e = (y[t] - yhat) # error

        L = left_env(W, feats, j)
        R = right_env(W, feats, j)
        phi = feats[j]

        contrib = isemptyinds(L) ? phi : L * phi
        contrib = isemptyinds(R) ? contrib : contrib * R
        G += -e * _align_to(contrib, W[j])
    end
    new_Wj = W[j] - lr * G
    replacebond!(W, j, new_Wj; maxdim=m, cutoff=1e-12)
    return W
end

# sweep training
function train_one_site!(W::MPS; epochs=4, batchsize=128, rng=Random.default_rng())
    idxs = collect(1:length(X))
    for ep in 1:epochs
        shuffle!(rng, idxs)
        for j in 1:length(W)
            batch = @view idxs[1:min(batchsize, length(idxs))]
            update_site!(W, j; batch=batch)
            rotate!(idxs, min(batchsize, length(idxs)))
        end
        @info "epoch=$ep  acc=$(acc(W, X, y))"
    end
    return W
end

@info "initial acc" acc(Wmps, X, y)
train_one_site!(Wmps; epochs=3)
@info "final acc"   acc(Wmps, X, y)

┌ Info: initial acc
│   acc(Wmps, X, y) = 0.235
└ @ Main /Users/aniket/Documents/university/siam/julia-tensors-worshop/notebooks/jl_notebook_cell_df34fa98e69747e1a8f8a730347b8e2f_X14sZmlsZQ==.jl:66


MethodError: MethodError: no method matching *(::Nothing, ::ITensor)

Closest candidates are:
  *(::Any, ::Any, !Matched::Any, !Matched::Any...)
   @ Base operators.jl:587
  *(!Matched::ITensor, ::ITensor)
   @ ITensors ~/.julia/packages/ITensors/iPhbw/src/tensor_operations/tensor_algebra.jl:60
  *(!Matched::ChainRulesCore.NoTangent, ::Any)
   @ ChainRulesCore ~/.julia/packages/ChainRulesCore/Vsbj9/src/tangent_arithmetic.jl:64
  ...
