In [None]:
!git clone https://github.com/ToelUl/Lattice-gauge-equivariant-CNN.git

!cp -r Lattice-gauge-equivariant-CNN/lge ./

In [None]:
import numpy as np
import torch
import lge

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# SU(2) spin $\frac{1}{2}$ group equivariant and invariant model

## Initialize the group

In [42]:
su2_group = lge.SU2Group().to(device)
su2_rep_dim = su2_group.rep_dim
lie_algebra = su2_group.algebra
print(f"SU(2) representation dimension: {su2_rep_dim}")
print(f"SU(2) Lie algebra dimension: {lie_algebra.lie_alg_dim}")
print(f'Generator of SU(2) Lie algebra : {lie_algebra.generators()}')
print(f'Identity element of SU(2) group: {su2_group.identity}')

SU(2) representation dimension: 2
SU(2) Lie algebra dimension: 3
Generator of SU(2) Lie algebra : [tensor([[0.-0.0000j, 0.-0.5000j],
        [0.-0.5000j, 0.-0.0000j]]), tensor([[ 0.0000-0.j, -0.5000+0.j],
        [ 0.5000-0.j,  0.0000-0.j]]), tensor([[0.-0.5000j, 0.-0.0000j],
        [0.-0.0000j, 0.+0.5000j]])]
Identity element of SU(2) group: tensor([[1.+0.j, 0.+0.j],
        [0.+0.j, 1.+0.j]], device='cuda:0')


## Initialize the equivariant and invariant neural network

### 2D lattice

In [43]:
L = 4 # Lattice size
dims = [L, L] # Dimensions of the lattice
hidden_sizes = [5, 5] # Hidden layer sizes. Number of output channels (Wilson loops and links) per lattice site.
kernel_size = 2 # Kernel size or range for convolution, which means the maximum size of Wilson loop.
out_channels = len(dims) # Number of output channels (Wilson loops and links) per lattice site.

su2_gauge_equivariant_model_2d = lge.LgeConvNet(
    dims=dims,
    hidden_sizes=hidden_sizes,
    kernel_size=kernel_size,
    out_channels=out_channels,
    group=su2_group,
    gauge_invariant=False,
    use_norm=True,
    use_act_fn=True,
    threshold=1e-6
).to(device)

Check the model's structure

In [44]:
lge.check_model(su2_gauge_equivariant_model_2d)

LgeConvNet(
  (group): SU2Group(
    (algebra): SU2LieAlgebra()
  )
  (plaquette_layer): Plaquette(
    (group): SU2Group(
      (algebra): SU2LieAlgebra()
    )
  )
  (input_conv): LConvBilin()
  (after_input): Sequential(
    (0): LgeReLU()
    (1): TrNorm()
  )
  (hidden_block): Sequential(
    (0): LConvBilin()
    (1): LgeReLU()
    (2): TrNorm()
    (3): LConvBilin()
    (4): LgeReLU()
    (5): TrNorm()
  )
)
Total number of trainable parameters: 1184


### 4D lattice

In [45]:
L = 4 # Lattice size
dims = [L, L, L, L] # Dimensions of the lattice
hidden_sizes = [5, 5] # Hidden layer sizes. Number of output channels (Wilson loops and links) per lattice site.
kernel_size = 2 # Kernel size or range for convolution, which means the maximum size of Wilson loop.
out_channels = len(dims) # Number of output channels (Wilson loops and links) per lattice site.

su2_gauge_equivariant_model_4d = lge.LgeConvNet(
    dims=dims,
    hidden_sizes=hidden_sizes,
    kernel_size=kernel_size,
    out_channels=out_channels,
    group=su2_group,
    gauge_invariant=False,
    use_norm=True,
    use_act_fn=True,
    threshold=1e-6
).to(device)

Check the model's structure

In [46]:
lge.check_model(su2_gauge_equivariant_model_4d)

LgeConvNet(
  (group): SU2Group(
    (algebra): SU2LieAlgebra()
  )
  (plaquette_layer): Plaquette(
    (group): SU2Group(
      (algebra): SU2LieAlgebra()
    )
  )
  (input_conv): LConvBilin()
  (after_input): Sequential(
    (0): LgeReLU()
    (1): TrNorm()
  )
  (hidden_block): Sequential(
    (0): LConvBilin()
    (1): LgeReLU()
    (2): TrNorm()
    (3): LConvBilin()
    (4): LgeReLU()
    (5): TrNorm()
  )
)
Total number of trainable parameters: 1702


### Gauge invariant model with using the trace of the Wilson loop.

In [47]:
su2_gauge_invariant_model_4d = lge.LgeConvNet(
    dims=dims,
    hidden_sizes=hidden_sizes,
    kernel_size=kernel_size,
    out_channels=out_channels,
    group=su2_group,
    gauge_invariant=True,
    use_norm=True,
    use_act_fn=True,
    threshold=1e-6
).to(device)

Check the model's structure

In [48]:
lge.check_model(su2_gauge_invariant_model_4d)

LgeConvNet(
  (group): SU2Group(
    (algebra): SU2LieAlgebra()
  )
  (plaquette_layer): Plaquette(
    (group): SU2Group(
      (algebra): SU2LieAlgebra()
    )
  )
  (input_conv): LConvBilin()
  (after_input): Sequential(
    (0): LgeReLU()
    (1): TrNorm()
  )
  (hidden_block): Sequential(
    (0): LConvBilin()
    (1): LgeReLU()
    (2): TrNorm()
    (3): LConvBilin()
    (4): LgeReLU()
    (5): TrNorm()
    (6): LTrace()
  )
)
Total number of trainable parameters: 1702


### Baseline model

In [49]:
class BaselineModel(torch.nn.Module):
    def __init__(self, channels):
        super(BaselineModel, self).__init__()
        self.mlp1 = torch.nn.Linear(channels, 3)
        self.mlp2 = torch.nn.Linear(3, channels)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        shape = tuple(x.shape)
        x = torch.flatten(x, start_dim=1)
        x = self.mlp1(x)
        x = self.relu(x)
        x = self.mlp2(x)
        x = x.view(*shape)
        return x

baseline_model = BaselineModel(channels=np.prod(dims)*len(dims)*su2_rep_dim*su2_rep_dim*2).to(device)

lge.check_model(baseline_model)

BaselineModel(
  (mlp1): Linear(in_features=8192, out_features=3, bias=True)
  (mlp2): Linear(in_features=3, out_features=8192, bias=True)
  (relu): ReLU()
)
Total number of trainable parameters: 57347


## Initialize the gauge link

In [50]:
batch_size = 10
dim = len(dims) # Number of dimensions
num_spacial_points = np.prod(dims) # Number of spacial points
sample_size = batch_size * num_spacial_points * dim
link = su2_group.random_element(
    sample_size=sample_size,
    apply_map=True, # Apply the exponential map.
)
print(f"link shape: {link.shape}")
print(f'Check is unitary: {su2_group.is_unitary(link)}')

# Reshape the link to the desired shape. The shape should be (batch_size, num_spacial_points, dim, u1_rep_dim, u1_rep_dim).
link = link.view(batch_size, num_spacial_points, dim, su2_rep_dim, su2_rep_dim)
print(f"link shape: {link.shape}")
print(f'Check is unitary: {su2_group.is_unitary(link)}')
print(f'check is determinant 1: {su2_group.is_determinant_one(link)}')

link shape: torch.Size([10240, 2, 2])
Check is unitary: True
link shape: torch.Size([10, 256, 4, 2, 2])
Check is unitary: True
check is determinant 1: True


Check the type of link

In [51]:
print(f'Link type: {link.type()}')

Link type: torch.cuda.ComplexFloatTensor


The input to the lattice gauge equivariant model should be a real tensor

In [52]:
link_real = su2_group.convert_to_real(link)
print(f'Link real type: {link_real.type()}')
print(f'Link real shape: {link_real.shape}')

Link real type: torch.cuda.FloatTensor
Link real shape: torch.Size([10, 256, 4, 2, 2, 2])


## Apply the gauge transformation

### Global gauge transformation

In [53]:
left_act = su2_group.left_action_on_Cn
right_act = su2_group.right_action_on_Cn

Choose a random element from the group

In [54]:
global_g = su2_group.random_element().to(device)
print(f'Random element shape: {global_g.shape}')
print(f'Check is unitary: {su2_group.is_unitary(global_g)}')
print(f'check is determinant 1: {su2_group.is_determinant_one(global_g)}')

Random element shape: torch.Size([1, 2, 2])
Check is unitary: True
check is determinant 1: True


Left action of the group element on the link.
$g \cdot link$.

In [55]:
global_g_link = left_act(global_g, link)
print(f'g_link shape: {global_g_link.shape}')

g_link shape: torch.Size([10, 256, 4, 2, 2])


Because the group is not abelian, do the right action of the group element on the $g \cdot link$ should not return the original link.

<br>$g \cdot link \cdot g^{-1} \neq link$

In [56]:
global_g_link_g_inv = right_act(global_g, global_g_link)
print(torch.allclose(link, global_g_link_g_inv))

False


And the way to apply the gauge transformation on gauge link is

<br> $g(x) \cdot link(x)_{\mu} \cdot g^{-1}(x+\hat{\mu})$

<br>We can use the function `gauge_trans_to_gauge_link` to apply the gauge transformation on gauge link.

In [57]:
global_g_link_g_inv_ = lge.gauge_trans_to_gauge_link(
    u=link,
    global_group_element=global_g,
    local_group_elements=None,
    dims=None
)
# The output should be the same
print(torch.allclose(global_g_link_g_inv, global_g_link_g_inv_))

True


We can use the function `gauge_trans_to_wilson_loop` to apply the gauge transformation on closed loop.

In [58]:
w = lge.generate_wilson_loops(dims=dims, u=link, group=su2_group)
print(f'Wilson loops shape: {w.shape}')

global_g_w_g_inv = lge.gauge_trans_to_wilson_loop(
    w=w,
    global_group_element=global_g,
    local_group_elements=None
)
print(f'Global g*w*g_inv shape: {global_g_w_g_inv.shape}')

Wilson loops shape: torch.Size([10, 256, 6, 2, 2])
Global g*w*g_inv shape: torch.Size([10, 256, 6, 2, 2])


### Local gauge transformation

Choose random elements from the group for each spacial point.

In [59]:
spacial_points = batch_size * num_spacial_points
local_g = su2_group.random_element(sample_size=spacial_points, apply_map=True).to(device)
print(f'Random element shape: {local_g.shape}')
local_g = local_g.view(batch_size, num_spacial_points, 1, su2_rep_dim, su2_rep_dim)
print(f'Random element shape: {local_g.shape}')

Random element shape: torch.Size([2560, 2, 2])
Random element shape: torch.Size([10, 256, 1, 2, 2])


$g(x) \cdot W(x)$

In [60]:
w = lge.generate_wilson_loops(dims=dims, u=link, group=su2_group)
print(f'Wilson loops shape: {w.shape}')

local_g_w = left_act(local_g, w)
print(f'local_g_w shape: {local_g_w.shape}')

Wilson loops shape: torch.Size([10, 256, 6, 2, 2])
local_g_w shape: torch.Size([10, 256, 6, 2, 2])


Non-abelian group

<br>$g(x) \cdot W(x) \cdot g^{-1}(x) \neq W(x)$

In [61]:
local_g_w_g_inv = right_act(local_g, local_g_w)
print(torch.allclose(w, local_g_w_g_inv))

False


We can use the function `gauge_trans_to_wilson_loop` to apply the gauge transformation on closed loop.

In [62]:
local_g_w_g_inv_ = lge.gauge_trans_to_wilson_loop(
    w=w,
    global_group_element=None,
    local_group_elements=local_g,
)
print(f'Local g*w*g_inv shape: {local_g_w_g_inv.shape}')

# The output should be the same
print(torch.allclose(local_g_w_g_inv, local_g_w_g_inv_))

Local g*w*g_inv shape: torch.Size([10, 256, 6, 2, 2])
True


And the way to apply the local gauge transformation on gauge link is

<br> $g(x) \cdot link(x)_{\mu} \cdot g^{-1}(x+\hat{\mu})$

<br> We can use the function `gauge_trans_to_gauge_link` to apply the gauge transformation on gauge link.

In [63]:
local_g_link_g_inv = lge.gauge_trans_to_gauge_link(
    u=link,
    global_group_element=None,
    local_group_elements=local_g,
    dims=dims
)
print(f'Local g*link*g_inv shape: {local_g_link_g_inv.shape}')

Local g*link*g_inv shape: torch.Size([10, 256, 4, 2, 2])


### Gauge transformation with both global and local elements

In [64]:
g_w_g_inv = lge.gauge_trans_to_wilson_loop(
    w=w,
    global_group_element=global_g,
    local_group_elements=local_g,
)
print(f'g*w*g_inv shape: {g_w_g_inv.shape}')

g*w*g_inv shape: torch.Size([10, 256, 6, 2, 2])


In [65]:
g_link_g_inv = lge.gauge_trans_to_gauge_link(
    u=link,
    global_group_element=global_g,
    local_group_elements=local_g,
    dims=dims
)
print(f'g*link*g_inv shape: {g_link_g_inv.shape}')

g*link*g_inv shape: torch.Size([10, 256, 4, 2, 2])


## Check gauge equivariance and invariance

The group action on operators like the gauge link is defined as:

<br> $\mathcal{L}_{g(x), \mu} \cdot \hat{\mathcal{O}}_{\mu} \equiv g(x) \cdot \hat{\mathcal{O}}_{\mu} \cdot g^{-1}(x+\hat{\mu})$

<br>The group action on operators like the Wilson loop is defined as:

<br> $\mathcal{L}_{g(x)} \cdot W(x) \equiv g(x) \cdot W(x) \cdot g^{-1}(x)$

<br> The global gauge equivariance of the model is defined as:

<br> $\mathcal{L}_{g} \cdot Model\left(link_{\mu}\right) =  Model\left(\mathcal{L}_{g, \mu} \cdot link_{\mu}\right)$

<br> Where $g$ is a global group element.

<br> The global gauge invariance of the model is defined as:

<br> $Model\left(link_{\mu}\right) =  Model\left(\mathcal{L}_{g,\mu} \cdot link_{\mu}\right)$

<br>The local gauge equivariance of the model is defined as:

<br> $\mathcal{L}_{g(x)} \cdot Model\left(link(x)_{\mu}\right) =  Model\left(\mathcal{L}_{g(x),\mu} \cdot link(x)_{\mu}\right)$

<br> The local gauge invariance of the model is defined as:

<br> $Model\left(link(x)_{\mu}\right) =  Model\left(\mathcal{L}_{g(x),\mu} \cdot link(x)_{\mu}\right)$

### Check global gauge equivariance

The Wilson loop is always invariant under a global gauge transformation, and is always gauge equivalent when didn't trace the Wilson loop.

In [66]:
link = su2_group.random_element(sample_size=sample_size).view(batch_size, num_spacial_points, dim, su2_rep_dim, su2_rep_dim)
link_real = su2_group.convert_to_real(link)
global_g = su2_group.random_element().to(device)
global_g_link = su2_group.left_action_on_Cn(global_g, link)
global_g_link_g_inv = right_act(global_g, global_g_link)
print(f'Global g*link*g_inv shape: {global_g_link_g_inv.shape}')

Global g*link*g_inv shape: torch.Size([10, 256, 4, 2, 2])


Convert the tensor to real type for the model

In [67]:
global_g_link_g_inv_real = su2_group.convert_to_real(global_g_link_g_inv)
print(f'Global g*link*g_inv real shape: {global_g_link_g_inv_real.shape}')

Global g*link*g_inv real shape: torch.Size([10, 256, 4, 2, 2, 2])


Apply the model to the link and the transformed link.

<br>$Model \left(g \cdot link \cdot g^{-1} \right)$

<br>$Model \left( link \right)$

In [68]:
model_global_g_link_g_inv = su2_gauge_equivariant_model_4d(global_g_link_g_inv_real)
baseline_model_global_g_link_g_inv = baseline_model(global_g_link_g_inv_real)

model_link = su2_gauge_equivariant_model_4d(link_real)
baseline_model_link = baseline_model(link_real)

print('Lattice gauge equivariant model: ')
print(f'Model(g*link*g_inv) shape: {model_global_g_link_g_inv.shape}')
print(f'Model(link) shape: {model_link.shape}')
print('='*100)
print('Baseline model: ')
print(f'Model(g*link*g_inv) shape: {baseline_model_global_g_link_g_inv.shape}')
print(f'Model(link) shape: {baseline_model_link.shape}')

Lattice gauge equivariant model: 
Model(g*link*g_inv) shape: torch.Size([10, 256, 4, 2, 2, 2])
Model(link) shape: torch.Size([10, 256, 4, 2, 2, 2])
Baseline model: 
Model(g*link*g_inv) shape: torch.Size([10, 256, 4, 2, 2, 2])
Model(link) shape: torch.Size([10, 256, 4, 2, 2, 2])


Convert the output to complex type and compute

<br>$g \cdot Model(link) \cdot g^{-1}$

In [69]:
model_global_g_link_g_inv = su2_group.convert_to_complex(model_global_g_link_g_inv)
baseline_model_global_g_link_g_inv = su2_group.convert_to_complex(baseline_model_global_g_link_g_inv)

model_link = su2_group.convert_to_complex(model_link)
baseline_model_link = su2_group.convert_to_complex(baseline_model_link)

global_g_model_link_g_inv = right_act(global_g, left_act(global_g, model_link))
baseline_global_g_model_link_g_inv = right_act(global_g, left_act(global_g, baseline_model_link))

print('Lattice gauge equivariant model: ')
print(f'Model(g*link*g_inv) shape: {model_global_g_link_g_inv.shape}')
print(f'Model(link) shape: {model_link.shape}')
print('='*100)
print('Baseline model: ')
print(f'Model(g*link*g_inv) shape: {baseline_model_global_g_link_g_inv.shape}')
print(f'Model(link) shape: {baseline_model_link.shape}')

Lattice gauge equivariant model: 
Model(g*link*g_inv) shape: torch.Size([10, 256, 4, 2, 2])
Model(link) shape: torch.Size([10, 256, 4, 2, 2])
Baseline model: 
Model(g*link*g_inv) shape: torch.Size([10, 256, 4, 2, 2])
Model(link) shape: torch.Size([10, 256, 4, 2, 2])


#### Check the equivariance
<br> $g \cdot Model \left( link \right) \cdot g^{-1} = Model \left( g \cdot link \cdot g^{-1} \right)$

<br> But due to the numerical precision, the equality may not hold exactly.

<br> We can check the difference between the two outputs.

<br> $||g \cdot Model \left( link \right) \cdot g^{-1} - Model \left( g \cdot link \cdot g^{-1} \right)|| = \epsilon$

In [70]:
print('Lattice gauge equivariant model: ')
print(torch.allclose(global_g_model_link_g_inv, model_global_g_link_g_inv, atol=1e-3))
print('Error :', torch.norm(global_g_model_link_g_inv - model_global_g_link_g_inv).item())
print('='*100)
print('Baseline model: ')
print(torch.allclose(baseline_global_g_model_link_g_inv, baseline_model_global_g_link_g_inv, atol=1e-3))
print('Error :', torch.norm(baseline_global_g_model_link_g_inv - baseline_model_global_g_link_g_inv).item())

Lattice gauge equivariant model: 
True
Error : 0.00039339708746410906
Baseline model: 
False
Error : 67.69098663330078


### Check global gauge invariance

Apply the invariant model to the link and the transformed link.

<br>$Model \left(g \cdot link \cdot g^{-1} \right)$

<br>$Model \left( link \right)$

In [71]:
invariant_model_global_g_link_g_inv = su2_gauge_invariant_model_4d(global_g_link_g_inv_real)
baseline_global_g_model_link_g_inv = baseline_model(global_g_link_g_inv_real)

invariant_model_link = su2_gauge_invariant_model_4d(link_real)
baseline_model_link = baseline_model(link_real)

print('Lattice gauge invariant model: ')
print(f'Invariant model(g*link*g_inv) shape: {invariant_model_global_g_link_g_inv.shape}')
print(f'Invariant model(link) shape: {invariant_model_link.shape}')
print('='*100)
print('Baseline model: ')
print(f'Baseline model(g*link*g_inv) shape: {baseline_global_g_model_link_g_inv.shape}')
print(f'Baseline model(link) shape: {baseline_model_link.shape}')

Lattice gauge invariant model: 
Invariant model(g*link*g_inv) shape: torch.Size([10, 256, 4, 2])
Invariant model(link) shape: torch.Size([10, 256, 4, 2])
Baseline model: 
Baseline model(g*link*g_inv) shape: torch.Size([10, 256, 4, 2, 2, 2])
Baseline model(link) shape: torch.Size([10, 256, 4, 2, 2, 2])


#### Check the invariance
<br> $Model \left( link \right) = Model \left( g \cdot link \cdot g^{-1} \right)$

<br> But due to the numerical precision, the equality may not hold exactly.

<br> We can check the difference between the two outputs.

<br> $||Model \left( link \right) - Model \left( g \cdot link \cdot g^{-1} \right)|| = \epsilon$

In [72]:
print('Lattice gauge invariant model: ')
print(torch.allclose(invariant_model_link, invariant_model_global_g_link_g_inv, atol=1e-3))
print('Error:', torch.norm(invariant_model_link - invariant_model_global_g_link_g_inv).item())
print('='*100)
print('Baseline model: ')
print(torch.allclose(baseline_model_link, baseline_global_g_model_link_g_inv, atol=1e-3))
print('Error:', torch.norm(baseline_model_link - baseline_global_g_model_link_g_inv).item())

Lattice gauge invariant model: 
True
Error: 0.0002274117578053847
Baseline model: 
False
Error: 14.020371437072754


### Check local gauge equivariance

The Wilson loop is always invariant under a local gauge transformation, and is always local gauge equivalent when didn't trace the Wilson loop.

In [73]:
link = su2_group.random_element(sample_size=sample_size).view(batch_size, num_spacial_points, dim, su2_rep_dim, su2_rep_dim)
link_real = su2_group.convert_to_real(link)
local_g = su2_group.random_element(sample_size=spacial_points).view(batch_size, num_spacial_points, 1, su2_rep_dim, su2_rep_dim).to(device)
local_g_link_g_inv = lge.gauge_trans_to_gauge_link(
    u=link,
    global_group_element=None,
    local_group_elements=local_g,
    dims=dims
)
print(f'Local g*link*g_inv shape: {local_g_link_g_inv.shape}')

Local g*link*g_inv shape: torch.Size([10, 256, 4, 2, 2])


Convert the tensor to real type for the model

In [74]:
local_g_link_g_inv_real = su2_group.convert_to_real(local_g_link_g_inv)
print(f'Local g*link*g_inv real shape: {local_g_link_g_inv_real.shape}')

Local g*link*g_inv real shape: torch.Size([10, 256, 4, 2, 2, 2])


Apply the model to the link and the transformed link.

<br>$Model \left( g(x) \cdot link(x)_{\mu} \cdot g^{-1}(x+\hat{\mu}) \right)$

<br>$Model \left( link(x) \right)$

In [75]:
model_local_g_link_g_inv = su2_gauge_equivariant_model_4d(local_g_link_g_inv_real)
baseline_model_local_g_link_g_inv = baseline_model(local_g_link_g_inv_real)

model_link = su2_gauge_equivariant_model_4d(link_real)
baseline_model_link = baseline_model(link_real)

print('Lattice gauge equivariant model: ')
print(f'Model(g*link*g_inv) shape: {model_local_g_link_g_inv.shape}')
print(f'Model(link) shape: {model_link.shape}')
print('='*100)
print('Baseline model: ')
print(f'Model(g*link*g_inv) shape: {baseline_model_local_g_link_g_inv.shape}')
print(f'Model(link) shape: {baseline_model_link.shape}')

Lattice gauge equivariant model: 
Model(g*link*g_inv) shape: torch.Size([10, 256, 4, 2, 2, 2])
Model(link) shape: torch.Size([10, 256, 4, 2, 2, 2])
Baseline model: 
Model(g*link*g_inv) shape: torch.Size([10, 256, 4, 2, 2, 2])
Model(link) shape: torch.Size([10, 256, 4, 2, 2, 2])


Convert the output to complex type and compute

<br>$g(x) \cdot Model \left( link(x) \right) \cdot g^{-1}(x)$

In [76]:
model_local_g_link_g_inv = su2_group.convert_to_complex(model_local_g_link_g_inv)
baseline_model_local_g_link_g_inv = su2_group.convert_to_complex(baseline_model_local_g_link_g_inv)

model_link = su2_group.convert_to_complex(model_link)
baseline_model_link = su2_group.convert_to_complex(baseline_model_link)

local_g_model_link_g_inv = lge.gauge_trans_to_wilson_loop(
    w=model_link,
    global_group_element=None,
    local_group_elements=local_g,
)
baseline_local_g_model_link_g_inv = lge.gauge_trans_to_wilson_loop(
    w=baseline_model_link,
    global_group_element=None,
    local_group_elements=local_g,
)

print('Lattice gauge equivariant model: ')
print(f'Model(g*link*g_inv) shape: {model_local_g_link_g_inv.shape}')
print(f'Model(link) shape: {model_link.shape}')
print(f'Local g*Model(link)*g_inv shape: {local_g_model_link_g_inv.shape}')
print('='*100)
print('Baseline model: ')
print(f'Model(g*link*g_inv) shape: {baseline_model_local_g_link_g_inv.shape}')
print(f'Model(link) shape: {baseline_model_link.shape}')
print(f'Local g*Model(link)*g_inv shape: {baseline_local_g_model_link_g_inv.shape}')

Lattice gauge equivariant model: 
Model(g*link*g_inv) shape: torch.Size([10, 256, 4, 2, 2])
Model(link) shape: torch.Size([10, 256, 4, 2, 2])
Local g*Model(link)*g_inv shape: torch.Size([10, 256, 4, 2, 2])
Baseline model: 
Model(g*link*g_inv) shape: torch.Size([10, 256, 4, 2, 2])
Model(link) shape: torch.Size([10, 256, 4, 2, 2])
Local g*Model(link)*g_inv shape: torch.Size([10, 256, 4, 2, 2])


#### Check the equivariance
<br> $g(x) \cdot Model \left( link(x)_{\mu} \right) \cdot g^{-1}(x+\hat{\mu}) = Model \left( g(x) \cdot link(x)_{\mu} \cdot g^{-1}(x+\hat{\mu}) \right)$

<br> But due to the numerical precision, the equality may not hold exactly.

<br> We can check the difference between the two outputs.

<br> $||g(x) \cdot Model \left( link(x)_{\mu} \right) \cdot g^{-1}(x+\hat{\mu}) - Model \left( g(x) \cdot link(x)_{\mu} \cdot g^{-1}(x+\hat{\mu}) \right)|| = \epsilon$

In [77]:
print('Lattice gauge equivariant model: ')
print(torch.allclose(local_g_model_link_g_inv, model_local_g_link_g_inv, atol=1e-3))
print('Error :', torch.norm(local_g_model_link_g_inv - model_local_g_link_g_inv).item())
print('='*100)
print('Baseline model: ')
print(torch.allclose(baseline_local_g_model_link_g_inv, baseline_model_local_g_link_g_inv, atol=1e-3))
print('Error :', torch.norm(baseline_local_g_model_link_g_inv - baseline_model_local_g_link_g_inv).item())

Lattice gauge equivariant model: 
True
Error : 0.0005367920966818929
Baseline model: 
False
Error : 70.31221008300781


And the trace of the Wilson loop should be invariant under a local gauge transformation.

In [78]:
tr_local_g_model_link_g_inv = torch.einsum('bndii->bnd', local_g_model_link_g_inv)
tr_model_local_g_link_g_inv = torch.einsum('bndii->bnd', model_local_g_link_g_inv)
print(torch.norm(tr_local_g_model_link_g_inv - tr_model_local_g_link_g_inv))

tensor(0.0003, device='cuda:0', grad_fn=<LinalgVectorNormBackward0>)


### Local gauge invariance

Apply the invariant model to the link and the transformed link.

<br>$Model \left( g(x) \cdot link(x)_{\mu} \cdot g^{-1}(x+\hat{\mu}) \right)$

<br>$Model \left( link(x)_{\mu} \right)$

In [79]:
invariant_model_local_g_link_g_inv = su2_gauge_invariant_model_4d(local_g_link_g_inv_real)
baseline_model_local_g_link_g_inv = baseline_model(local_g_link_g_inv_real)

invariant_model_link = su2_gauge_invariant_model_4d(link_real)
baseline_model_link = baseline_model(link_real)

print('Lattice gauge invariant model: ')
print(f'Invariant model(g*link*g_inv) shape: {invariant_model_local_g_link_g_inv.shape}')
print(f'Invariant model(link) shape: {invariant_model_link.shape}')
print('='*100)
print('Baseline model: ')
print(f'Baseline model(g*link*g_inv) shape: {baseline_model_local_g_link_g_inv.shape}')
print(f'Baseline model(link) shape: {baseline_model_link.shape}')

Lattice gauge invariant model: 
Invariant model(g*link*g_inv) shape: torch.Size([10, 256, 4, 2])
Invariant model(link) shape: torch.Size([10, 256, 4, 2])
Baseline model: 
Baseline model(g*link*g_inv) shape: torch.Size([10, 256, 4, 2, 2, 2])
Baseline model(link) shape: torch.Size([10, 256, 4, 2, 2, 2])


#### Check the invariance
<br> $Model \left( link(x) \right) = Model \left( g(x) \cdot link(x)_{\mu} \cdot g^{-1}(x+\hat{\mu}) \right)$

<br> But due to the numerical precision, the equality may not hold exactly.

<br> We can check the difference between the two outputs.

<br> $||Model \left( link(x) \right) - Model \left( g(x) \cdot link(x)_{\mu} \cdot g^{-1}(x+\hat{\mu}) \right)|| = \epsilon$

In [80]:
print('Lattice gauge invariant model: ')
print(torch.allclose(invariant_model_link, invariant_model_local_g_link_g_inv, atol=1e-3))
print('Error:', torch.norm(invariant_model_link - invariant_model_local_g_link_g_inv).item())
print('='*100)
print('Baseline model: ')
print(torch.allclose(baseline_model_link, baseline_model_local_g_link_g_inv, atol=1e-3))
print('Error:', torch.norm(baseline_model_link - baseline_model_local_g_link_g_inv).item())

Lattice gauge invariant model: 
True
Error: 0.00035227451007813215
Baseline model: 
False
Error: 23.86435317993164
