## Supervised Learning with ITensor: MPS for Classification

[Supervised Learning with Quantum-Inspired Tensor Networks](https://arxiv.org/pdf/1605.05775), *NeurIPS 2016.*

```julia
using Random, Statistics
using ITensors, ITensorMPS
```

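*Why tensor networks for machine learning?* A tensor train (matrix product state, MPS) defines a linear model over an exponentially large feature space using only $O(N d m^2)$ parameters. Here $N$ is the number of input features, $d$ the dimension of each local feature vector, and $m$ the bond dimension.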

*Goal:* Classify $8\times 8$ grayscale images into two classes (here, vertical versus horizontal bars). The same approach extends to multi-class problems such as MNIST.

```julia
# synthetic data generation: 8×8 images containing a single bright bar

const H, W = 8, 8 # image height and width in pixels

# one image with a vertical or horizontal bar of ones, never on the border
function make_bar(imgtype::Symbol; rng=Random.default_rng())
    X = fill(0.0, H, W)
    if imgtype == :vertical
        j = rand(rng, 2:W-1)
        X[:, j] .= 1.0
    elseif imgtype == :horizontal
        i = rand(rng, 2:H-1)
        X[i, :] .= 1.0
    end
    return X
end

# balanced, shuffled dataset with n_per images of each class
function synth_dataset(n_per=200; rng=Random.default_rng())
    Xv = [make_bar(:vertical; rng=rng) for _ in 1:n_per]
    Xh = [make_bar(:horizontal; rng=rng) for _ in 1:n_per]
    X = vcat(Xv, Xh)
    y = vcat(fill(1, n_per), fill(0, n_per)) # 1 = vertical, 0 = horizontal
    perm = randperm(rng, length(X))          # same permutation for images and labels
    return X[perm], y[perm]
end

X, y = synth_dataset(200)
N = H * W # number of pixels, i.e. number of MPS sites
```

```
64
```

*How are tensors useful?* Let $x = (x_1, x_2, \ldots, x_N) \in \mathbb R^N$ be an input data point (an image with $N$ pixels). Even a binary vector of length $N$ has $2^N$ possible configurations ($2^{64} \approx 1.8 \times 10^{19}$ for an $8\times 8$ image), far too many to work with directly.

The trick is to *lift* each scalar $x_j$ to a small local feature vector and to use a *tensor product* to build a structured, high-dimensional feature $\Phi(x)$ that can be represented efficiently as a matrix product state (MPS).

*Local feature map:*
$$ \phi: \mathbb R \to \mathbb R^d,\quad x_j \mapsto \phi(x_j) = \left(\phi_1(x_j), \ldots, \phi_d(x_j)\right) $$
Combining the local maps gives a tensor-product (rank-1, order-$N$) feature map:
$$ \Phi(x)_{s_1, s_2, \ldots, s_N} = \left[ \phi(x_1) \otimes \cdots \otimes \phi(x_N) \right]_{s_1, s_2, \ldots, s_N} = \prod_{j=1}^N \phi_{s_j}(x_j),\quad s_j \in \{1, 2, \ldots, d\} $$

For grayscale images, a simple $d=2$ choice is
$$ \phi(x_j) = \left( \cos\left(\frac{\pi}{2} x_j\right), \sin\left(\frac{\pi}{2} x_j\right) \right) \in \mathbb R^2 $$
which is *normalized*, i.e. $\|\phi(x_j)\|_2 = 1$.

*Why is this helpful?* $\Phi(x)$ lives in a $d^N$-dimensional space, which is huge even for modest $N$. However, $\Phi(x)$ is a product of single-site vectors, i.e. an MPS with bond dimension 1, so it can be stored and contracted using only $N d$ numbers.
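To make the $d^N$ growth concrete, the following plain-Julia sketch (no ITensor involved) materializes $\Phi(x)$ as a dense vector with `kron` for a toy input of $N = 4$ pixels. `local_map`, `dense_feature`, and the toy input are hypothetical names for this illustration only; building $\Phi(x)$ densely like this is feasible only for tiny $N$.

```julia
using LinearAlgebra

# d = 2 local feature map from the text, returned as a vector
local_map(xj::Real) = [cospi(xj / 2), sinpi(xj / 2)]

# dense Φ(x) = φ(x₁) ⊗ φ(x₂) ⊗ ⋯ ⊗ φ(x_N); its length is 2^N
function dense_feature(x::AbstractVector{<:Real})
    Φ = [1.0]
    for xj in x
        Φ = kron(Φ, local_map(xj)) # each factor doubles the length
    end
    return Φ
end

x_toy = [0.3, 1.0, 0.5, 0.25] # a toy "image" with N = 4 pixels
Φ_toy = dense_feature(x_toy)

length(Φ_toy) # 16 == 2^4
norm(Φ_toy)   # ≈ 1, because each φ(x_j) has unit norm
```

Because the 2-norm of a Kronecker product is the product of the factors' norms, $\|\Phi(x)\|_2 = 1$ holds for every input, not just this one.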

```julia
@inline feature(x::Real) = (cospi(0.5 * x), sinpi(0.5 * x)) # d = 2 feature map
```
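The synthetic images are strictly binary, so this feature map reduces to one-hot vectors: a dark pixel ($x_j = 0$) maps to $(1, 0)$ and a bright pixel ($x_j = 1$) to $(0, 1)$. A quick standalone check, using `feature_demo`, a hypothetical copy of the map above:

```julia
feature_demo(x::Real) = (cospi(0.5 * x), sinpi(0.5 * x)) # same d = 2 map as `feature`

dark   = feature_demo(0.0) # dark pixel: (1, 0)
bright = feature_demo(1.0) # bright pixel: (0, 1), up to floating-point rounding
gray   = feature_demo(0.5) # gray pixel: equal weight on both components

# every output has unit Euclidean norm
all(v -> isapprox(hypot(v...), 1.0), (dark, bright, gray))
```

Consequently, for these images $\Phi(x)$ is a single standard basis vector of the $2^N$-dimensional space, and $f(x) = W \cdot \Phi(x)$ simply reads off one entry $W_{s_1(x), \ldots, s_N(x)}$ of the weight tensor, with $s_j(x) = 1$ for dark and $s_j(x) = 2$ for bright pixels.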

We classify data points with a *linear model* in the lifted space.
$$ f(x) = W \cdot \Phi(x) = \sum_{s_1, s_2, \ldots, s_N} W_{s_1, s_2, \ldots, s_N} \Phi(x)_{s_1, s_2, \ldots, s_N} $$
where $W$ is a weight tensor of the same order and dimension as $\Phi(x)$. To avoid working with the full $W$, we represent it as an MPS with low bond dimension $m$:
$$ W_{s_1, s_2, \ldots, s_N} = \sum_{\{\alpha\}} A^{[1]}_{s_1, \alpha_1} A^{[2]}_{\alpha_1, s_2, \alpha_2} \cdots A^{[N]}_{\alpha_{N-1}, s_N} $$
This reduces the number of parameters from $d^N$ to $O(N d m^2)$.
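For the images here ($N = 64$, $d = 2$, $m = 8$) the gap is easy to quantify. The sketch below assumes an open-boundary MPS in which every internal bond has dimension exactly $m$ (in practice the bonds next to the chain ends can be smaller), so the MPS figure is an upper-bound-style estimate; `dense_params` and `mps_params` are hypothetical helper names.

```julia
# entries of a dense order-N weight tensor with local dimension d
dense_params(d, N) = big(d)^N

# open-boundary MPS: two boundary cores of size d×m, N-2 bulk cores of size m×d×m
mps_params(d, N, m) = 2 * d * m + (N - 2) * d * m^2

dense_params(2, 64)  # 18446744073709551616, about 1.8e19
mps_params(2, 64, 8) # 7968
```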

```julia
# one 2-dimensional site index per pixel
sites = siteinds("Qubit", N) # "Qubit" sites have dimension 2

# random weight MPS with bond dimension m
m = 8
Wmps = randomMPS(sites, m) # normalized random MPS (our weight tensor)

# turn an image into one single-site feature ITensor per pixel, matching `sites`
function image_to_features(img::AbstractMatrix{<:Real}, sites::Vector{<:Index})
    feats = ITensor[]
    v = vec(img) # column-major flattening: pixel n ↔ site n
    for n in 1:length(v)
        s = sites[n]
        a, b = feature(v[n])
        phi = ITensor(s)
        phi[s => 1] = a
        phi[s => 2] = b
        push!(feats, phi)
    end
    return feats
end

# contract an MPS with per-site feature tensors, sweeping left to right
function _chain_contract(W::MPS, feats::Vector{ITensor})
    @assert length(feats) == length(W)
    T = W[1] * feats[1]
    for n in 2:length(feats)
        Wn_contracted = W[n] * feats[n] # contract the physical index first (cheap)
        T = T * Wn_contracted           # then contract the shared link index
    end
    return T
end

# score f(x) = W ⋅ Φ(x)
function score(W::MPS, img::AbstractMatrix)
    feats = image_to_features(img, siteinds(W))
    T = _chain_contract(W, feats)
    return scalar(T) # every index is contracted, so T is a rank-0 tensor
end

pred(s) = s > 0 ? 1 : 0 # binary prediction from the sign of the score
acc(W, Xs, ys) = mean(pred(score(W, x)) == y for (x, y) in zip(Xs, ys)) # accuracy
```
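`_chain_contract` contracts each physical index with its feature vector before touching the link indices. In plain matrix terms, each site then contributes a matrix $M^{[j]} = \sum_{s} \phi_s(x_j) A^{[j]}_{s}$ of size at most $m \times m$, and $f(x) = M^{[1]} M^{[2]} \cdots M^{[N]}$, costing $O(N d m^2)$ per evaluation instead of touching all $d^N$ entries of $W$. The dependency-free sketch below stores cores as `(m_left, d, m_right)` arrays (a layout chosen for this illustration, not ITensor's) and checks the matrix-product evaluation against a brute-force sum over all $d^N$ configurations for tiny $N$; all function names are hypothetical.

```julia
using Random

# random MPS cores A[j] of size (m_left, d, m_right); the two boundary bonds have size 1
random_cores(rng, N, d, m) = [randn(rng, j == 1 ? 1 : m, d, j == N ? 1 : m) for j in 1:N]

# f(x) = W ⋅ Φ(x): contract the physical index first, then multiply along the chain
function mps_score(A, φs)
    v = ones(1, 1)                                            # 1×1 left boundary
    for (Aj, φj) in zip(A, φs)
        Mj = sum(φj[s] * Aj[:, s, :] for s in eachindex(φj))  # m_left × m_right
        v = v * Mj                                            # row vector × matrix
    end
    return v[1, 1]
end

# brute force: sum over all d^N index configurations (tiny N only)
function brute_force_score(A, φs)
    d = length(first(φs))
    total = 0.0
    for cfg in Iterators.product(ntuple(_ -> 1:d, length(A))...)
        W_entry = prod(A[j][:, cfg[j], :] for j in eachindex(A))[1, 1] # W_{s1…sN}
        total += W_entry * prod(φs[j][cfg[j]] for j in eachindex(A))   # × Φ(x)_{s1…sN}
    end
    return total
end

rng_demo = MersenneTwister(42)
A_demo = random_cores(rng_demo, 5, 2, 3)
φs_demo = [[cospi(x / 2), sinpi(x / 2)] for x in rand(rng_demo, 5)]
mps_score(A_demo, φs_demo) ≈ brute_force_score(A_demo, φs_demo)
```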

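To factor a large tensor $\Psi_{s_1, \ldots, s_N}$ into an MPS, we use a sequence of SVDs (the tensor-train SVD, or TT-SVD).

The algorithm is as follows:
1. Reshape $\Psi$ into a matrix whose rows are indexed by $s_1$ and whose columns are indexed by $(s_2, \ldots, s_N)$.
2. Compute the SVD $U\Sigma V^T$, optionally truncated to the $m$ largest singular values, and set $A^{[1]} = U\Sigma$.
3. Reshape $V^T$ so that its rows are indexed by $(\alpha_1, s_2)$ and its columns by $(s_3, \ldots, s_N)$, then repeat steps 2 and 3 on this matrix (setting $A^{[j]} = U\Sigma$ at step $j$) until all $N$ tensors are extracted; the final remainder is $A^{[N]}$.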
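The SVD sweep can be sketched directly on a dense Julia array. This is a minimal illustration, not ITensor's implementation: `tt_svd` and `tt_entry` are hypothetical names, cores are stored as `(r_left, d, r_right)` arrays, each core is taken as $U\Sigma$ with $V^T$ carried forward, and truncation keeps at most `maxdim` singular values above a relative `cutoff` per bond.

```julia
using LinearAlgebra, Random

# TT-SVD sketch: factor a dense tensor Ψ of size d₁×d₂×…×d_N into cores
# A[j] of size (r_{j-1}, d_j, r_j); each core is UΣ and Vᵀ is carried forward.
function tt_svd(Ψ::Array{Float64}; maxdim::Int=typemax(Int), cutoff::Float64=1e-12)
    dims = size(Ψ)
    N = length(dims)
    cores = Vector{Array{Float64,3}}(undef, N)
    r = 1                                     # current left bond dimension
    C = reshape(Ψ, dims[1], :)                # rows s₁, columns (s₂, …, s_N)
    for j in 1:N-1
        F = svd(C)
        # keep at most maxdim singular values above cutoff (relative to the largest)
        k = clamp(count(>(cutoff * F.S[1]), F.S), 1, maxdim)
        cores[j] = reshape(F.U[:, 1:k] * Diagonal(F.S[1:k]), r, dims[j], k) # A^[j] = UΣ
        C = reshape(F.Vt[1:k, :], k * dims[j+1], :) # rows (α_j, s_{j+1}), columns the rest
        r = k
    end
    cores[N] = reshape(C, r, dims[N], 1)      # the final remainder is A^[N]
    return cores
end

# one entry Ψ[s₁, …, s_N] recovered from the cores as a product of matrices
tt_entry(cores, idx) = prod(cores[j][:, idx[j], :] for j in eachindex(cores))[1, 1]

# without truncation the factorization is exact up to round-off
Ψ_demo = randn(MersenneTwister(7), 2, 2, 2, 2, 2)
cores_full = tt_svd(Ψ_demo)
maxerr = maximum(abs(tt_entry(cores_full, Tuple(I)) - Ψ_demo[I]) for I in CartesianIndices(Ψ_demo))

# a product (rank-1) tensor compresses to bond dimension 1 on every bond
u_demo = [1.0, 2.0]
Ψ_prod = [u_demo[a] * u_demo[b] * u_demo[c] for a in 1:2, b in 1:2, c in 1:2]
[size(A, 3) for A in tt_svd(Ψ_prod)] # [1, 1, 1]
```

Passing a small `maxdim` trades accuracy for size, which is how a bond dimension $m$ is imposed on an arbitrary tensor.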

For binary labels $y_n \in \{0, 1\}$, we minimize the *squared loss* over the $N_T$ training examples:
$$ C(W) = \frac{1}{2} \sum_{n=1}^{N_T} \left(f(x_n) - y_n\right)^2 $$
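Because $f$ is linear in each individual core, the gradient of $C$ with respect to a single core $A^{[j]}$ has a simple closed form, which is what a one-site gradient step needs. Write $x_{n,j}$ for pixel $j$ of example $n$, take $\alpha_0$ and $\alpha_N$ as trivial (dimension 1), and let the environments be the contractions of all other sites with their features:
$$ L^{(n)}_{\alpha_{j-1}} = \sum_{\substack{s_1, \ldots, s_{j-1} \\ \alpha_1, \ldots, \alpha_{j-2}}} \prod_{k=1}^{j-1} A^{[k]}_{\alpha_{k-1}, s_k, \alpha_k}\, \phi_{s_k}(x_{n,k}), \qquad R^{(n)}_{\alpha_j} = \sum_{\substack{s_{j+1}, \ldots, s_N \\ \alpha_{j+1}, \ldots, \alpha_{N-1}}} \prod_{k=j+1}^{N} A^{[k]}_{\alpha_{k-1}, s_k, \alpha_k}\, \phi_{s_k}(x_{n,k}) $$
Then
$$ f(x_n) = \sum_{\alpha_{j-1}, s_j, \alpha_j} L^{(n)}_{\alpha_{j-1}}\, A^{[j]}_{\alpha_{j-1}, s_j, \alpha_j}\, \phi_{s_j}(x_{n,j})\, R^{(n)}_{\alpha_j} \quad\Longrightarrow\quad \frac{\partial C}{\partial A^{[j]}_{\alpha_{j-1}, s_j, \alpha_j}} = \sum_{n=1}^{N_T} \left(f(x_n) - y_n\right) L^{(n)}_{\alpha_{j-1}}\, \phi_{s_j}(x_{n,j})\, R^{(n)}_{\alpha_j} $$
and a gradient step with learning rate $\eta$ is $A^{[j]} \leftarrow A^{[j]} - \eta\, \partial C / \partial A^{[j]}$, with the sum over $n$ restricted to a mini-batch. At the chain ends ($j = 1$ or $j = N$) the missing environment is simply 1. In `update_site!` below, `L`, `ϕ`, and `R` play the roles of $L^{(n)}$, $\phi(x_{n,j})$, and $R^{(n)}$, and `e` equals $y_n - f(x_n)$, so `G` accumulates this gradient over the batch.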

```julia
# one-site training loop

Xφ = [image_to_features(x, sites) for x in X] # precompute features once

# left environment: contraction of sites 1 … j-1 with their features
function left_env(W::MPS, feats::Vector{ITensor}, j::Int)
    j <= 1 && return nothing
    T = W[1] * feats[1]
    for n in 2:(j-1)
        T = T * (W[n] * feats[n])
    end
    return T
end

# right environment: contraction of sites j+1 … N with their features
function right_env(W::MPS, feats::Vector{ITensor}, j::Int)
    N = length(feats)
    j >= N && return nothing
    T = W[N] * feats[N]
    for n in (N-1):-1:(j+1)
        T = (W[n] * feats[n]) * T
    end
    return T
end

# one-site gradient step on the squared loss at site j, using a mini-batch
function update_site!(W::MPS, j; batch, lr=0.1)
    G = zero(W[j])
    for t in batch
        feats = Xφ[t]
        yhat = scalar(_chain_contract(W, feats)) # current score f(x_t)
        e = y[t] - yhat                          # residual

        L = left_env(W, feats, j)
        R = right_env(W, feats, j)
        ϕ = feats[j]

        # ∂f/∂W[j] = L ⊗ ϕ ⊗ R; an environment is absent at a chain end
        contrib = (L === nothing ? ϕ : (L * ϕ))
        contrib = (R === nothing ? contrib : (contrib * R))
        G += (-e) * contrib                      # accumulate (f - y) ∂f/∂W[j]
    end

    W[j] = W[j] - lr * G
    return W
end

# sweep training: one mini-batch step per site, sweeping left to right each epoch
function train_one_site!(W::MPS; epochs=4, batchsize=128, rng=Random.default_rng())
    idxs = collect(1:length(X))
    for ep in 1:epochs
        shuffle!(rng, idxs)
        for j in 1:length(W)
            batch = @view idxs[1:min(batchsize, length(idxs))]
            update_site!(W, j; batch=batch)
            circshift!(idxs, -min(batchsize, length(idxs))) # next site sees the next batch
        end
        @info "epoch=$ep  acc=$(acc(W, X, y))"
    end
    return W
end

@info "initial acc" acc(Wmps, X, y)
train_one_site!(Wmps; epochs=10, batchsize=32, rng=Random.default_rng())
@info "final acc"   acc(Wmps, X, y)
```

```
┌ Info: initial acc
│   acc(Wmps, X, y) = 0.665
┌ Info: epoch=1  acc=0.665
┌ Info: epoch=2  acc=0.665
┌ Info: epoch=3  acc=0.665
```
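The accuracy stays at 0.665 through every logged epoch. One plausible contributor, an assumption not verified against this run, is scale: `randomMPS` returns a normalized weight tensor, and since $\sum_{s} W_{s}^2 = 1$ over $2^N$ entries, the root-mean-square entry of $W$ is exactly $2^{-N/2}$. With the one-hot features of binary images, $f(x)$ is a single such entry, so initial scores are typically of order $2^{-32} \approx 2.3 \times 10^{-10}$, and one-site gradients built from equally small environments plausibly barely move them. A second design choice worth checking is that the squared loss pushes class-0 scores toward 0, which is exactly where `pred` places its threshold. The sketch below checks the RMS identity on a small dense vector; `rms_entry` and `W_demo` are hypothetical names.

```julia
using LinearAlgebra, Random, Statistics

# RMS entry of any normalized tensor with d^N entries: sqrt(1 / d^N) = d^(-N/2)
rms_entry(d, N) = d^(-N / 2)

# numerical check on a dense, normalized random vector with 2^12 entries
W_demo = randn(MersenneTwister(0), 2^12)
W_demo ./= norm(W_demo)   # normalize, as the random weight MPS is normalized
sqrt(mean(abs2, W_demo))  # equals rms_entry(2, 12) = 2^-6 = 0.015625 up to rounding

rms_entry(2, 64)          # 2^-32 ≈ 2.3e-10 for the 8×8 images
```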

```julia
function classification_summary(W::MPS, Xs, ys)
    predictions = [pred(score(W, x)) for x in Xs]
    correct = predictions .== ys

    n_correct = sum(correct)
    n_total = length(ys)
    accuracy = n_correct / n_total

    # per-class breakdown
    vertical_indices = findall(ys .== 1)
    horizontal_indices = findall(ys .== 0)

    vertical_correct = sum(correct[vertical_indices])
    horizontal_correct = sum(correct[horizontal_indices])

    println("="^60)
    println("classification summary")
    println("="^60)
    println("overall accuracy: $(round(accuracy*100, digits=2))% ($n_correct/$n_total)")
    println()
    println("vertical bars (class 1):")
    println("  correct: $vertical_correct / $(length(vertical_indices))")
    println("  accuracy: $(round(vertical_correct/length(vertical_indices)*100, digits=2))%")
    println()
    println("horizontal bars (class 0):")
    println("  correct: $horizontal_correct / $(length(horizontal_indices))")
    println("  accuracy: $(round(horizontal_correct/length(horizontal_indices)*100, digits=2))%")
    println("="^60)

    # show the first few misclassified indices
    misclassified = findall(.!correct)
    if !isempty(misclassified)
        println("\nmisclassified indices (first 10): ", misclassified[1:min(10, length(misclassified))])
    else
        println("\nno misclassifications! perfect accuracy!")
    end
end

classification_summary(Wmps, X, y)
```

```
classification summary
overall accuracy: 66.5% (266/400)

vertical bars (class 1):
  correct: 168 / 200
  accuracy: 84.0%

horizontal bars (class 0):
  correct: 98 / 200
  accuracy: 49.0%

misclassified indices (first 10): [2, 5, 8, 11, 12, 16, 26, 28, 33, 38]
```
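The per-class numbers are easier to read as a confusion matrix. From the printed counts, 102 horizontal images are predicted as vertical and 32 vertical images as horizontal, so the model predicts class 1 for 168 + 102 = 270 of the 400 images, i.e. it is biased toward the vertical class. A small helper that builds the matrix from predictions and labels, under the hypothetical name `confusion`:

```julia
# 2×2 confusion matrix for labels in {0, 1}:
# rows = true class (0, 1), columns = predicted class (0, 1)
function confusion(preds::AbstractVector{<:Integer}, ys::AbstractVector{<:Integer})
    C = zeros(Int, 2, 2)
    for (p, t) in zip(preds, ys)
        C[t + 1, p + 1] += 1
    end
    return C
end

confusion([1, 0, 1, 1], [1, 0, 0, 1]) # [1 1; 0 2]
```

In the notebook this would be called as `confusion([pred(score(Wmps, x)) for x in X], y)`; the counts printed above correspond to `[98 102; 32 168]`. Because the two classes are balanced, the balanced accuracy $(84.0\% + 49.0\%)/2 = 66.5\%$ coincides with the overall accuracy.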
