In [0]:
import numpy as np
import tensorflow as tf
import matplotlib as plt

In [0]:
class tt():
  
    def __init__(self, psi_list):
    
        self.psi = psi_list
        self.lefts = {}
        self.rights = {}
        self.mids = {}
        self.norm = None
      
    def make_mids(self):
    
        for i in range(1, len(self.psi)-1):
            left = tf.einsum('ink,jnl->ijlk', self.psi[i], tf.conj(self.psi[i]))
            self.mids[str(i) + ' ' + str(i)] = left
            for j in range(i+1, len(self.psi)-1):
                vertice = tf.einsum('ink,jnl->ijlk', self.psi[j], tf.conj(self.psi[j]))
                left = tf.einsum('ijkl,lknm->ijnm', left, vertice)
                self.mids[str(i) + ' ' + str(j)] = left
  
  
  
    def make_lefts(self):
        left = tf.einsum('ij,ik->kj', self.psi[0], tf.conj(self.psi[0]))
        self.lefts['0'] = left
        for n in range(1, len(self.psi)-1):
            self.lefts[str(n)] = (tf.einsum('ij,jikl->kl', left, self.mids['1' + ' ' + str(n)]))
          
          
          
      
    def make_rights(self):
        right = tf.einsum('ij,kj->ik', self.psi[-1], tf.conj(self.psi[-1]))
        self.rights[str(len(self.psi) - 1)] = right
        for n in range(1, len(self.psi)-1):
            self.rights[str(n)] = (tf.einsum('ijkl,lk->ij', self.mids[str(n) + ' ' + str(len(self.psi)-2)], right))
          
          
          
          
    def norm_psi(self):
        self.norm = tf.einsum('ij,ji->',self.lefts['0'],self.rights['1'])
        
    def double_derivative(self):
        result = {}
        
        for j in range(0, len(self.psi)-1):
            if j == 0:
                right = self.rights['1']
                result['0' + ' ' + str(j)] = tf.einsum('ij,kl->jlik', tf.eye(tf.shape(psi[0])[0], dtype=tf.complex64), right)
            elif j == 1:
                right = self.rights['2']
                result['1' + ' ' + str(j)] = tf.einsum('ij,klm,nm->ikjln', self.psi[0], tf.conj(self.psi[1]), right)
            else:
                right = self.rights[str(j + 1)]
                mid = self.mids['1 ' + str(j-1)]
                tf.einsum('ij,jklm,lqr,sr->ikmqs', self.psi[0], mid, tf.conj(self.psi[j]), right)
                
        
        for i in range(1, len(self.psi)):
            if i == len(self.psi) - 1:
                left = self.lefts[str(i - 1)]
                result[str(i) + ' ' + str(i)] = tf.einsum('ij,lk->ikjl', left, tf.eye(tf.shape(psi[0])[0], dtype=tf.complex64))
            elif i == len(self.psi) - 2:
                left = self.lefts[str(i - 1)]
                result[str(i) + ' ' + str(len(self.psi) - 1)] = tf.einsum('ij,jlm,no->ilnmo', left, self.psi[i], tf.conj(self.psi[i + 1]))
            else:
                left = self.lefts[str(i - 1)]
                mid = self.mids[str(i+1) + ' ' + str(len(self.psi) - 2)]
                result[str(i) + ' ' + str(len(self.psi) - 1)] = tf.einsum('ij,jkl,lmno,np->ikmop', left, self.psi[i], mid, tf.conj(self.psi[len(self.psi) - 1]))
        
        for i in range(1, len(self.psi)-1):
            left = self.lefts[str(i-1)]
            for j in range(i, len(self.psi)-1):
                if i == j:
                    right = self.rights[str(j + 1)]
                    result[str(i) + ' ' + str(j)] = tf.einsum('ij,kl,mn->iknjlm', left, tf.eye(tf.shape(psi[0])[0], dtype=tf.complex64), right)
                elif i + 1 == j:
                    right = self.rights[str(j + 1)]
                    result[str(i) + ' ' + str(j)] = tf.einsum('ij,jkl,mno,ao->ikmlna', left, self.psi[i], tf.conj(self.psi[j]), right)
                else:
                    right = self.rights[str(j + 1)]
                    mid = self.mids[str(i+1) + ' ' + str(j-1)]
                    result[str(i) + ' ' + str(j)] = tf.einsum('ij,jkl,lmno,nqr,sr->ikmoqs', left, self.psi[i], mid, tf.conj(self.psi[j]), right)
        return result
      
    def check_double_derivative(self, result):
        l = []
        for k, v in result.items():
            ind = list(map(int, k.split()))
            conj_psi = tf.conj(self.psi[ind[0]])
            conj_psi = tf.reshape(conj_psi, shape=(-1,))
            psi = self.psi[ind[1]]
            psi = tf.reshape(psi, shape=(-1,))
            V = tf.reshape(v, shape=(tf.shape(conj_psi)[0], tf.shape(psi)[0]))
            l.append(tf.einsum('ij,j,i->', V, psi, conj_psi))
        return l
              
    def check_left_right(self):
      l = []
      for i in range(1, len(self.psi)-1):
          l.append(tf.einsum('ij,ji->', self.lefts[str(i)], self.rights[str(i + 1)]))
      return l

In [0]:
tf.reset_default_graph()

psi = []
psi.append(tf.constant(np.random.randn(2, 4) + np.random.randn(2, 4) * 1j, dtype=tf.complex64))
for i in range(10):
    psi.append(tf.constant(np.random.randn(4, 2, 4) + np.random.randn(4, 2, 4) * 1j, dtype=tf.complex64))
psi.append(tf.constant(np.random.randn(4, 2) + np.random.randn(4, 2) * 1j, dtype=tf.complex64))
state = tt(psi)
state.make_mids()
state.make_lefts()
state.make_rights()
state.norm_psi()

out_mids = state.mids
out_lefts = state.lefts
out_rights = state.rights
out_norm = state.norm
result = state.double_derivative()
out = state.check_double_derivative(result)

In [35]:
sess = tf.Session()
sess.run(out)

[(41410740000000-589824j),
 (41410745000000-2457600j),
 (41410737000000-983040j),
 (41410740000000-1343488j),
 (41410740000000+1835008j),
 (41410732000000-98304j),
 (41410737000000+1703936j),
 (41410745000000+1703936j),
 (41410745000000+2204672j),
 (41410745000000+32768j),
 (41410737000000+2424832j),
 (41410745000000+2834432j),
 (41410737000000-1572864j),
 (41410745000000-458752j),
 (41410740000000+917504j),
 (41410745000000-491520j),
 (41410745000000-1081344j),
 (41410745000000-2523136j),
 (41410740000000-393216j),
 (41410745000000-1769472j),
 (41410745000000-2654208j),
 (41410745000000-2818048j),
 (41410732000000-2686976j),
 (41410732000000-983040j),
 (41410732000000+507904j),
 (41410737000000-671744j),
 (41410740000000+2277376j),
 (41410732000000+2031616j),
 (41410732000000+901120j),
 (41410737000000-1048576j),
 (41410737000000-999424j),
 (41410740000000+1048576j),
 (41410732000000+2097152j),
 (41410740000000+1835008j),
 (41410753000000+1572864j),
 (41410740000000+2359296j),
 (41410