In [3]:
import numpy as np
import tensorflow as tf

import plotly.express as px
import plotly.graph_objects as go

In [4]:
def cartesian_product(*arrays):
    la = len(arrays)
    dtype = np.result_type(*arrays)
    arr = np.empty([len(a) for a in arrays] + [la], dtype=dtype)
    for i, a in enumerate(np.ix_(*arrays)):
        arr[...,i] = a
    return arr.reshape(-1, la)

In [5]:
num_quasimomentum_components = 99

max_momentum_site_component = 7
num_momentum_sites = (2*max_momentum_site_component + 1) ** 2

quasimomentum_component = np.linspace(0, 2, num_quasimomentum_components)
quasimomentum = cartesian_product(*(quasimomentum_component,)*2)

momentum_site_component = np.arange(-max_momentum_site_component, max_momentum_site_component + 1)
momentum_site = cartesian_product(*(momentum_site_component,)*2)

In [6]:
def render_momentum_site_potential(depth, height):
    momentum_site_potential = np.zeros((5, 5))

    def add(momentum_site, potential):
        momentum_site_potential[momentum_site[:, 0] + 2, momentum_site[:, 1] + 2] += potential

    add(np.array([[ 0,  0]                              ]), -1/4  * depth)
    add(np.array([[-2,  0], [+2,  0], [0,  -2], [ 0, +2]]), -1/8  * depth)
    add(np.array([[-2, -2], [-2, +2], [+2, -2], [+2, +2]]), -1/16 * depth)

    add(np.array([[ 0,  0]                              ]), -1/4  * height)
    add(np.array([[-1,  0], [+1,  0], [0,  -1], [ 0, +1]]), -1/8  * height)
    add(np.array([[-1, -1], [-1, +1], [+1, -1], [+1, +1]]), +1/16 * height)

    return momentum_site_potential

In [7]:
def spatial_from_momentum_site_potential(momentum_site_potential, num_points=100):
    spatial_potential = tf.math.real(tf.signal.rfft2d(tf.signal.ifftshift(tf.pad(momentum_site_potential, tf.constant(np.full((2, 2), num_points))))))
    return tf.concat([spatial_potential, tf.reverse(spatial_potential, [0, 1])], 1)

def tile_spatial_potential(spatial_potential, periods=1):
    return tf.tile(spatial_potential, np.full((2,), periods))

In [8]:
def compute_dispersion(momentum_site_potential):
    momentum = 2*momentum_site + quasimomentum[:, np.newaxis]
    kinetic = tf.linalg.diag(np.linalg.norm(momentum, axis=-1) ** 2)

    momentum_site_change = momentum_site[:, np.newaxis] - momentum_site
    max_momentum_site_component_change = 2 * max_momentum_site_component - 1
    momentum_site_to_index = (momentum_site_potential.shape[0] - 1) // 2 
    potential = np.pad(momentum_site_potential, max_momentum_site_component_change)[
        momentum_site_change[:, :, 0] + momentum_site_to_index + max_momentum_site_component_change, 
        momentum_site_change[:, :, 1] + momentum_site_to_index + max_momentum_site_component_change]

    energy = kinetic + potential[np.newaxis, ...]

    dispersion = tf.reshape(tf.linalg.eigvalsh(energy), (num_quasimomentum_components, num_quasimomentum_components, (2*max_momentum_site_component + 1) ** 2))

    return dispersion

In [9]:
def compute_tunneling_strengths(dispersion, band):
    return tf.signal.fftshift(tf.abs(tf.signal.fft2d(tf.cast(dispersion[:, :, band], tf.complex64))) / num_quasimomentum_components ** 2)

In [10]:
def crop_neighbor_tunneling_strengths(tunneling_strengths, radius=1):
    index_min = num_quasimomentum_components // 2 - radius
    index_max = num_quasimomentum_components // 2 + radius
    return tunneling_strengths[index_min:index_max+1, index_min:index_max+1]

# Lieb

In [11]:
spatial_potential = spatial_from_momentum_site_potential(momentum_site_potential)

figure = go.Figure(data=[go.Surface(z=tile_spatial_potential(spatial_potential, 1))])
figure.update_layout(autosize=False, width=750, height=750)
figure.show()

NameError: name 'momentum_site_potential' is not defined

In [13]:
momentum_site_potential = render_momentum_site_potential(30, 3)
momentum_site_potential

array([[-1.875 ,  0.    , -3.75  ,  0.    , -1.875 ],
       [ 0.    ,  0.1875, -0.375 ,  0.1875,  0.    ],
       [-3.75  , -0.375 , -8.25  , -0.375 , -3.75  ],
       [ 0.    ,  0.1875, -0.375 ,  0.1875,  0.    ],
       [-1.875 ,  0.    , -3.75  ,  0.    , -1.875 ]])

In [14]:
dispersion = compute_dispersion(momentum_site_potential)

In [17]:
figure = go.Figure(data=[go.Surface(x=quasimomentum_component, y=quasimomentum_component, z=dispersion[:, :, band], showscale=False) for band in range(0, 4)])
figure.update_layout(autosize=False, width=750, height=750)
figure.show()