In [None]:
import numpy as np
import matplotlib.pyplot as plt
import tensornetwork as tn
%matplotlib inline

# Default size applied to every figure created further down.
plt.rc('figure', figsize=(9, 5))

# Problem 1: Tensor contractions

In [None]:
# Seed the generator so the random tensors (and every result derived from
# them) are reproducible under "Restart Kernel -> Run All".
SEED = 0
rng = np.random.default_rng(SEED)

dim = 20  # size of each of the four tensor legs
A = rng.random((dim,) * 4)  # uniform entries in [0, 1)
B = rng.random((dim,) * 4)
A.shape, B.shape

Let's contract the tensors A and B using a fast BLAS matrix multiplication.
(Run `np.show_config()` to check that your NumPy is linked against BLAS for the speed-up.)

In [None]:
# Wrap the raw arrays in network nodes.
An, Bn = tn.Node(A, name='A'), tn.Node(B, name='B')
# Join the pairs of legs (axis of A, axis of B) that will be contracted.
for axis_a, axis_b in ((1, 0), (3, 3)):
    tn.connect(An[axis_a], Bn[axis_b])
# Draw the resulting two-node network.
tn.to_graphviz([An, Bn])

In [None]:
# Contraction consumes the original network, therefore we contract
# replicas and leave An/Bn untouched.
An_copy, Bn_copy = tn.replicate_nodes([An, Bn])
ABn = tn.contract_between(An_copy, Bn_copy)
# Reference result with plain NumPy: sum over the shared indices m and n.
AB_einsum = np.einsum('imjn,mkln->ijkl', A, B)
assert np.allclose(ABn.tensor, AB_einsum)

In [None]:
# We can also name the axes of each node and refer to them by name.
# `axis_names` is documented as a list of strings (one per axis), so the
# index string is split into single-letter names instead of being passed
# as one bare string.
An = tn.Node(A, name='A', axis_names=list("imjn"))
Bn = tn.Node(B, name='B', axis_names=list("mkln"))
# connecting the nodes along the axes that share a name
tn.connect(An["m"], Bn["m"], name="m")
tn.connect(An["n"], Bn["n"], name="n")
tn.to_graphviz([An, Bn])

In [None]:
# Let's also check the speed.
# Note: the timed statement replicates the nodes on every run, so the
# reported time includes tn.replicate_nodes as well as the contraction.
%timeit tn.contract_between(*tn.replicate_nodes([An, Bn]))

In [None]:
# We can speed things up using "jax",
# it can even be used to put things on the GPU
tn.set_default_backend("jax")
# We can also name the nodes
# declaring the tensors as axis
An = tn.Node(A, name='A', axis_names="imjn")
Bn = tn.Node(B, name='B', axis_names="mkln")
# connecting the nodes
tn.connect(An["m"], Bn["m"], name="m")
tn.connect(An["n"], Bn["n"], name="n")
%timeit tn.contract_between(*tn.replicate_nodes([An, Bn]))

# Problem 2: Compression

In [None]:
# `scipy.misc.face` was deprecated in SciPy 1.10 and removed in SciPy 1.12;
# the sample image is now provided by `scipy.datasets` (which needs the
# optional `pooch` package to fetch the file). The provider is bound to the
# name `misc` so that later `misc.face(...)` calls work on old and new SciPy.
try:
    from scipy import datasets as misc  # SciPy >= 1.10
except ImportError:
    from scipy import misc  # older SciPy releases
face = misc.face(gray=True)  # grayscale version of the sample image
plt.imshow(face, cmap='gray');

In [None]:
# Wrap the image (cast to float) in a network node.
facen = tn.Node(face.astype(float), name='face')

# Upper limit on the number of singular values kept by the split.
num = 150
first_edge, second_edge = facen[0], facen[1]
us, vhs, truncs = tn.split_node(
    facen,
    left_edges=[first_edge],
    right_edges=[second_edge],
    max_singular_values=num,
)
# Contract the two factors again to obtain the truncated image.
face_trunc = tn.contract_between(us, vhs)

_fig, axes = plt.subplots(ncols=2)
panels = (("original", face), ("compressed", face_trunc.tensor))
for ax, (title, image) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(image, cmap="gray")
# Entries stored in both factors relative to entries in the full image.
print("Compression:", (us.tensor.size + vhs.tensor.size)/face.size)

In [None]:
# The error is given by the truncated singular values; compare that with
# the norm of the actual difference between the image and its reconstruction.
residual = face - face_trunc.tensor

# Frobenius norm: root of the summed squares of the truncated values
truncs_squared = truncs**2
ferr_singular = np.sqrt(np.sum(truncs_squared))
ferr_norm = np.linalg.norm(residual, ord='fro')
print("Frobenius norm:", ferr_singular, ferr_norm)

# Spectral norm: first truncated singular value
# (presumably the largest one, assuming descending order -- verify)
serr_singular = truncs[0]
serr_norm = np.linalg.norm(residual, ord=2)
print("Spectral norm:", serr_singular, serr_norm)

In [None]:
# Alternatively, limit the truncation error directly instead of fixing the
# number of singular values that are kept.
facen = tn.Node(face.astype(float), name='face')

# Largest truncation error we accept (passed together with relative=True).
error = 1e-1
us, vhs, truncs = tn.split_node(
    facen,
    left_edges=[facen[0]],
    right_edges=[facen[1]],
    max_truncation_err=error,
    relative=True,
)
print(f"Truncated {truncs.size} values, that is kept {min(face.shape) - truncs.size}")
face_trunc = tn.contract_between(us, vhs)

_fig, axes = plt.subplots(ncols=2)
panels = (("original", face), ("compressed", face_trunc.tensor))
for ax, (title, image) in zip(axes, panels):
    ax.set_title(title)
    ax.imshow(image, cmap="gray")
print("Compression:", (us.tensor.size + vhs.tensor.size)/face.size)

# The error is given by the truncated singular values; compare with the
# norm of the actual difference between the image and its reconstruction.
residual = face - face_trunc.tensor

# Frobenius norm
truncs_squared = truncs**2
ferr_singular = np.sqrt(np.sum(truncs_squared))
ferr_norm = np.linalg.norm(residual, ord='fro')
print("Frobenius norm:", ferr_singular, ferr_norm,
      "relative", ferr_norm / np.linalg.norm(face, ord="fro"))

# Spectral norm
serr_singular = truncs[0]
serr_norm = np.linalg.norm(residual, ord=2)
print("Spectral norm:", serr_singular, serr_norm,
      "relative", serr_norm / np.linalg.norm(face, ord=2))

Let's also compress the color image.

In [None]:
# Load the sample image again, this time without the grayscale conversion.
face = misc.face(gray=False)
_fig, ax = plt.subplots()
ax.imshow(face);

In [None]:
facen = tn.Node(face.astype(float), name='face')

# truncate the singular values for compression
num = 150
# Split with the first axis on the left and all remaining axes on the right.
us, vhs, truncs = tn.split_node(facen, max_singular_values=num,
                                left_edges=facen[:1], right_edges=facen[1:])
face_trunc = tn.contract_between(us, vhs)
# The truncated reconstruction is real-valued and can fall slightly outside
# the valid 0..255 range. Round to the nearest integer (instead of truncating
# toward zero), clip explicitly, and convert to uint8 so that imshow receives
# a proper 8-bit RGB image and does not have to clip the data itself.
face_trunc_img = np.clip(np.rint(np.asarray(face_trunc.tensor)), 0, 255).astype(np.uint8)
__, axes = plt.subplots(ncols=2)
axes[0].set_title("original")
axes[0].imshow(face)
axes[1].set_title("compressed")
axes[1].imshow(face_trunc_img);
# Entries stored in both factors relative to entries in the full image.
print("Compression:", (us.tensor.size + vhs.tensor.size)/face.size)