In [1]:
import tensorflow as tf
import tensorflow_datasets as tfds
import tensornetwork as tn
import numpy as np
from tensorflow.keras.datasets import mnist
import time

# Make NumPy the default TensorNetwork backend: every tn.Node created below
# without an explicit backend argument holds NumPy arrays.
tn.set_default_backend("numpy")

In [None]:
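# The following code constructs an MPS shaped like this
# (L is the label site at index rank // 2; every other site takes a
# `feature`-dimensional input):
#
#                        label_size
#                            |
# bond_dim                   |
# A1 --------- A2 ---------- L --------- A3 --------- An
# |            |                         |            |
# |            |                         |            |
#           feature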
def construct_MPS(bond_dim, feature, rank, label_size, dtype=np.float32):
    """Build a random MPS with `rank` feature sites and one label site.

    Args:
      bond_dim: bond dimension used for every virtual bond.
      feature: physical dimension of each feature site.
      rank: number of feature sites; the chain has `rank + 1` sites in total.
      label_size: physical dimension of the label site, which is inserted at
        index `rank // 2`.
      dtype: numpy dtype of the random tensors (defaults to float32, as before).

    Returns:
      The `tn.InfiniteMPS` produced by `tn.InfiniteMPS.random`.
    """
    label_site = rank // 2
    # Physical dims: feature sites on both sides of the single label site.
    physical_dims = ([feature] * label_site
                     + [label_size]
                     + [feature] * (rank - label_site))
    # InfiniteMPS.random takes one more bond dimension than there are sites:
    # `physical_dims` has rank + 1 entries, so `bond_dims` has rank + 2.
    bond_dims = [bond_dim] * (rank + 2)
    # NOTE(review): the diagram depicts an open chain (A1 ... An), which would
    # normally be tn.FiniteMPS; InfiniteMPS is kept here — confirm intent.
    return tn.InfiniteMPS.random(d=physical_dims, D=bond_dims, dtype=dtype)

In [None]:
class MPSModle(tf.Module):
    """Module that owns a randomly initialised MPS built by `construct_MPS`.

    NOTE(review): the class name looks like a typo for "MPSModel"; it is kept
    as-is so existing references keep resolving.
    """

    def __init__(self, bond_dim, feature, rank, label_size, name=None):
        super().__init__(name=name)
        # Store the MPS on the instance so it outlives __init__.
        # NOTE(review): its tensors are not wrapped in tf.Variable here, so
        # they will not appear in `self.variables` — presumably still TODO.
        self.mps = construct_MPS(bond_dim, feature, rank, label_size)

In [23]:
class InputModule(tf.Module):
    """One MPS feature-site core held as a trainable variable.

    Calling the module contracts the core's middle (feature) leg with an
    input vector and returns the resulting TensorNetwork node.
    """

    def __init__(self, tensor, name=None):
        super().__init__(name=name)
        # Expected layout of `tensor`: [l, feature, r].
        self.input_tensor = tf.Variable(tensor, trainable=True)

    def __call__(self, input):
        """Contract the core's feature leg with `input`.

        Assumes `input` is a 1-D vector of length `feature`, which leaves a
        node with edges [l, r].
        """
        # NOTE(review): `.numpy()` hands TensorNetwork a detached copy of the
        # variable, so gradients recorded by tf.GradientTape cannot reach
        # `input_tensor` through this call — confirm before training.
        site = tn.Node(self.input_tensor.numpy())
        vector = tn.Node(input)
        site.edges[1] ^ vector.edges[0]
        return site @ vector

class OutputModule(tf.Module):
    """The label-site MPS core held as a trainable variable."""

    def __init__(self, tensor, name=None):
        super().__init__(name=name)
        # Expected layout of `tensor`: [l, label_size, r].
        self.output_tensor = tf.Variable(tensor, trainable=True)

    def __call__(self):
        """Return the stored core wrapped in a new TensorNetwork node."""
        # NOTE(review): `.numpy()` copies the value out of TensorFlow, so
        # gradients cannot flow back to `output_tensor` through this node —
        # confirm before using this module for training.
        snapshot = self.output_tensor.numpy()
        return tn.Node(snapshot)

class InputRegion(tf.Module):
    """A chain of InputModule cores that maps `rank` feature vectors to one node.

    Each core (expected layout [l, feature, r]) is wrapped in its own
    InputModule so it is tracked as a trainable variable of this module.
    """

    def __init__(self, tensors, name=None):
        super().__init__(name=name)
        # tensors: sequence of arrays, each laid out as [l, feature, r].
        input_Modules = [InputModule(tensor) for tensor in tensors]

        self.tensors = tensors
        self.input_Modules = input_Modules

    def __call__(self, input):
        """Contract every core with its feature vector and multiply the chain.

        Args:
          input: array of shape [rank, feature]; row i feeds core i.

        Returns:
          A tn.Node with edges [l, r]: the ordered product of all per-site
          [l, r] nodes (l of the first site, r of the last site).

        Raises:
          AssertionError: if the number of input rows differs from the number
            of cores.
          ValueError: if the region has no cores.
        """
        assert input.shape[0] == len(self.input_Modules), (
            "expected {} input rows, got {}".format(
                len(self.input_Modules), input.shape[0]))
        if not self.input_Modules:
            raise ValueError("InputRegion has no cores to contract")
        # Each module contracts its core's feature leg, leaving an [l, r] node.
        nodes = [module(row) for module, row in zip(self.input_Modules, input)]
        # Fold left to right, feeding the running product back in. After each
        # contraction the result's edges are [first site's l, current site's r].
        result = nodes[0]
        for node in nodes[1:]:
            result.edges[1] ^ node.edges[0]
            result = result @ node
        return result

In [31]:
# Smoke tests: build each module on random data and check the output shapes.
rng = np.random.default_rng(0)  # fixed seed so re-running reproduces outputs
# NOTE(review): the origin of 28*27 sites is not documented here — confirm.
N_REGION_SITES = 28 * 27

print("######################################")
print("FOR inputModule Test")
input_tensor = rng.random((3, 2, 3))
inputM = InputModule(input_tensor)
print(inputM.variables)
input_result = inputM(np.ones(2))
print(input_result.tensor.shape)
assert input_result.tensor.shape == (3, 3)

print("\n######################################")
print("FOR outputModule Test")
output_tensor = rng.random((3, 10, 3))
outputM = OutputModule(output_tensor)
print(outputM().tensor.shape[1])
assert outputM().tensor.shape == (3, 10, 3)

print("\n######################################")
print("FOR inputRegion Test")
tensors = [rng.random((3, 2, 3)) for _ in range(N_REGION_SITES)]
inputRegion = InputRegion(tensors)
print(len(inputRegion.variables))
assert len(inputRegion.variables) == N_REGION_SITES

# Named `region_input` rather than `input` to avoid shadowing the builtin.
region_input = rng.random((N_REGION_SITES, 2))
region_result = inputRegion(region_input)
print(region_result.tensor.shape)
assert region_result.tensor.shape == (3, 3)

######################################
FOR inputModule Test
(<tf.Variable 'Variable:0' shape=(3, 2, 3) dtype=float64, numpy=
array([[[0.90062223, 0.23863093, 0.44086617],
        [0.87259808, 0.2783695 , 0.46173209]],

       [[0.29503336, 0.47423814, 0.7647958 ],
        [0.03496632, 0.39925278, 0.50692447]],

       [[0.81228597, 0.01524813, 0.18350953],
        [0.92846533, 0.37227338, 0.00361462]]])>,)
(3, 3)

######################################
FOR outputModule Test
10

######################################
FOR inputRegion Test
756
(3, 3)
