In [1]:
import torch
import torch.nn as nn
import math

class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention followed by a 4x-wide MLP.

    Both sub-layers are residual; dropout (p=0.25) is applied to each sub-layer
    output before it is added back. The attention is built with
    ``batch_first=True``, so tensors are (batch, seq, embed_size). No attention
    mask is applied.
    """

    def __init__(self, embed_size, head_count):
        super().__init__()
        # Creation order is kept as-is: it fixes parameter-init RNG consumption
        # and the state_dict key names.
        self.attention = nn.MultiheadAttention(embed_size, head_count, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        hidden_size = 4 * embed_size
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(hidden_size, embed_size),
        )
        self.dropout = nn.Dropout(0.25)

    def forward(self, x):
        """Apply attention and MLP sub-layers with pre-normalization and residuals."""
        attn_in = self.norm1(x)
        attn_out, _attn_weights = self.attention(attn_in, attn_in, attn_in)
        x = x + self.dropout(attn_out)
        mlp_out = self.feed_forward(self.norm2(x))
        return x + self.dropout(mlp_out)
    
class Transformer_learnable(nn.Module):
    """Transformer predicting 2-way logits from a coupling prompt plus a token prefix.

    Each input row is ``[J1, J2, J3, tok_0, ..., tok_{m-1}]``. The three
    couplings are normalized to sum to one and embedded by an MLP into a single
    prompt vector. That vector is prepended to the token sequence, where each
    token is a word embedding plus a learned absolute position embedding. The
    stack of TransformerBlocks runs over the whole sequence, and ``fc_out`` maps
    the last position to 2 logits.

    Args:
        vocab_size: number of distinct token ids.
        posi_size: maximum number of tokens (excluding the prompt) per input.
        embed_size: model width.
        num_layers: number of TransformerBlocks.
        head_count: attention heads per block.
    """

    def __init__(self, vocab_size, posi_size, embed_size, num_layers, head_count):
        super(Transformer_learnable, self).__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.word_embedding = nn.Embedding(vocab_size, embed_size)
        # Learned (not sinusoidal) absolute position embeddings.
        self.position_embedding = nn.Embedding(posi_size, embed_size)
        # MLP from the 3 normalized couplings to one embedding vector. The
        # Sequential indices (0, 2, 4) appear in state_dict keys, so this layout
        # must stay unchanged for saved checkpoints to load.
        self.para_embedding = nn.Sequential(
            nn.Linear(3, 2*embed_size),
            nn.ReLU(),
            nn.Linear(2*embed_size, embed_size * 2),
            nn.ReLU(),
            nn.Linear(embed_size * 2, embed_size)
        )

        self.layers = nn.ModuleList(
            [TransformerBlock(embed_size, head_count) for _ in range(num_layers)]
        )
        self.fc_out = nn.Linear(embed_size, 2)
        # NOTE(review): not used in forward(); kept for compatibility with code
        # that references the attribute.
        self.tokendrop = nn.Dropout(0.25)

    def forward(self, inputs, device, mask=None):
        """Return (batch, 2) logits computed from the last position of the sequence.

        Args:
            inputs: 2-D tensor (batch, 3 + m). Columns 0-2 are the couplings;
                columns 3.. are token ids (any numeric dtype, cast to long).
            device: device on which to run.
            mask: accepted for API compatibility but ignored; attention always
                covers the full sequence.

        Raises:
            ValueError: if the number of tokens m exceeds ``posi_size``.
        """
        # Normalize the couplings so only their ratios matter, then embed them as
        # a single prompt token of shape (batch, 1, embed_size).
        para_inputs = inputs[:, :3].reshape(-1, 3).to(device).float()
        para_inputs = para_inputs / torch.sum(para_inputs, dim=1, keepdim=True)
        para_context = self.para_embedding(para_inputs).unsqueeze(1)

        input_tokens = inputs[:, 3:].to(torch.long).to(device)
        batch_size, token_count = input_tokens.shape[:2]
        max_positions = self.position_embedding.num_embeddings
        if token_count > max_positions:
            raise ValueError(
                f"got {token_count} tokens but the position embedding only covers {max_positions}"
            )
        token_embeddings = self.word_embedding(input_tokens)
        # Learned absolute positions 0..token_count-1, created directly on the target device.
        positions = torch.arange(token_count, device=device).expand(batch_size, token_count)
        out = token_embeddings + self.position_embedding(positions)
        out = torch.cat((para_context, out), dim=1)

        for layer in self.layers:
            out = layer(out)
        # The prediction is read from the last position only: (batch, 2).
        return self.fc_out(out[:, -1, :])
        
def generate_GPTmeasureoutput(J1, J2, J3, model, num_qubits, sample_size, device):
    """Autoregressively sample measurement records from ``model``.

    Returns a float tensor of shape (sample_size, 3 + 2 * num_qubits) laid out
    as [J1, J2, J3, op_0, out_0, ..., op_{n-1}, out_{n-1}]. For each qubit, an
    operator token is drawn uniformly from {2, 3, 4}. The outcome bit is then
    drawn from Bernoulli(p), where p = softmax(model(prefix, device))[:, 1] and
    the prefix ends with that operator token. ``model`` is called without
    gradients; the caller is responsible for putting it in eval mode.
    """
    width = 3 + 2 * num_qubits
    # Float buffer so the coupling columns can hold J1, J2, J3.
    records = torch.zeros((sample_size, width), device=device)
    records[:, :3] = torch.tensor([J1, J2, J3], device=device).expand(sample_size, 3)

    with torch.no_grad():
        for qubit in range(num_qubits):
            op_col = 3 + 2 * qubit
            records[:, op_col] = torch.randint(2, 5, (sample_size,), device=device, dtype=torch.long)
            logits = model(records[:, :op_col + 1], device)
            p_one = torch.nn.functional.softmax(logits, dim=1)[:, 1]
            records[:, op_col + 1] = torch.bernoulli(p_one.clamp(min=0, max=1)).to(torch.long)
    return records

In [None]:
from torch.utils.data import Dataset

class myDataset(Dataset):
    """Map-style Dataset that forwards indexing and ``len()`` to a wrapped collection.

    Args:
        measureResults: any indexable, sized collection of measurement sequences
            (e.g. a tensor whose rows are sequences).
    """

    def __init__(self, measureResults):
        super(myDataset, self).__init__()
        self.results = measureResults

    def __len__(self):
        return len(self.results)

    def __getitem__(self, index):
        return self.results[index]
    
from torch.utils.data import DataLoader
# Training configuration and data loading.
batch_size = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# NOTE(review): the slice keeps rows from index 6*240000 = 1,440,000 onward; presumably this
# selects one qubit-count section of the combined "4to10bit" dataset — confirm.
# The whole training tensor is moved to `device` here, so the DataLoader indexes and
# stacks rows of a tensor that already lives on that device.
train_dataset = torch.load('symmodified_NNCTFI_pbc_dataset_4to10bit_10000samples_each.pt')[6*240000:].to(device)
train_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True)
# The test data stays where torch.load put it; the training/eval loops move each batch to `device`.
test_sampleResults = torch.load('symmodified_NNCTFI_pbc_dataset_15and9_every500_10bit.pt')
test_dataset = myDataset(measureResults=test_sampleResults)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
print(train_dataset.size())

In [None]:
from torch.utils.tensorboard import SummaryWriter
from torch import nn, optim
# Module-level TensorBoard writer shared by train_recursive and test_recursive.
# With no log_dir argument, SummaryWriter logs under ./runs/<datetime>_<hostname>.
writer = SummaryWriter()
def train_recursive(model, dataloader, optimizer, loss_fn, epoch, device):
    """Run one training epoch of teacher-forced next-outcome prediction.

    For every even column ``c >= 4`` of a batch, the model receives columns
    ``[:c]`` and is trained to predict column ``c`` (cast to long). The
    per-position losses are summed and back-propagated once per batch.
    NOTE(review): this assumes rows laid out as
    ``[J1, J2, J3, op_0, out_0, op_1, out_1, ...]``, as produced by
    generate_GPTmeasureoutput — confirm against the dataset builder. The TFI
    variant of this loop started at column 2 instead of 4.

    Prints the epoch's mean per-batch loss and logs it to the module-level
    TensorBoard ``writer`` under ``Loss/train`` at step ``epoch + 1``.

    Returns:
        The mean per-batch summed loss (float).

    Raises:
        ValueError: if ``dataloader`` is empty, or a batch has no target column
            (fewer than 5 columns). Without these checks the function fails with a
            ZeroDivisionError, or an AttributeError from calling ``backward`` on
            the integer 0.
    """
    model.train()
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("train_recursive: dataloader is empty")

    total_loss = 0.0
    for batch in dataloader:
        batch = batch.to(device)
        # Targets are the outcome columns 4, 6, 8, ...; each prefix ends on the
        # operator token that precedes its target.
        target_columns = range(4, batch.shape[1], 2)
        if len(target_columns) == 0:
            raise ValueError(
                f"train_recursive: batch with {batch.shape[1]} columns has no target column (need >= 5)"
            )
        optimizer.zero_grad()
        loss = sum(
            loss_fn(model(batch[:, :col], device), batch[:, col].to(torch.long))
            for col in target_columns
        )
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    average_loss = total_loss / num_batches
    print(f"Epoch {epoch + 1} Train Loss: {average_loss:.6f}")
    writer.add_scalar("Loss/train", average_loss, epoch + 1)
    return average_loss


def test_recursive(model, dataloader, loss_fn, epoch, device):
    model.eval()
    total_loss = 0
    num_batches = len(dataloader)

    with torch.no_grad():
        for dataset in dataloader:
            dataset = dataset.to(device)
            batch_size, token_count = dataset.shape[0], dataset.shape[1]
            loss = 0
            for cur_count in range(4, token_count):
                if cur_count % 2 == 0:
                    expected_next_token_indexes = dataset[:, cur_count]
                    model_input = dataset[:, :cur_count]
                    outputs = model(model_input, device)
                    loss += loss_fn(outputs, expected_next_token_indexes.to(torch.long))
            total_loss += loss.item()

    average_loss = total_loss / num_batches
    print(f"Epoch {epoch + 1} Test Loss: {average_loss:.6f}")
    writer.add_scalar("Loss/test", average_loss, epoch + 1)

    return average_loss

# Model / optimisation hyper-parameters.
# NOTE(review): posi_size must be >= the longest token prefix fed to the model;
# 20 presumably = 2 tokens per qubit for 10 qubits — confirm.
vocab_size = 5
posi_size = 20
embed_size = 128
num_layers = 4
head_count = 4

model = Transformer_learnable(vocab_size, posi_size, embed_size, num_layers, head_count)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()
epochs = 80
best_test_loss = float('inf')
# Defined up front so the final checkpoint can be written even when epochs == 0.
current_test_loss = float('nan')

for epoch in range(epochs):
    print(f"\nEpoch {epoch + 1}/{epochs}\n{'-'*30}")
    train_recursive(model, train_loader, optimizer, loss_fn, epoch, device)
    current_test_loss = test_recursive(model, test_loader, loss_fn, epoch, device)
    # Save the best model based on test loss
    if current_test_loss < best_test_loss:
        best_test_loss = current_test_loss
        # Store the loss *value* as a plain float. Pickling the loss module here
        # would make the checkpoint unloadable with torch.load(weights_only=True),
        # the default since PyTorch 2.6.
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_test_loss
        }, 'GPT_symmodified_NNCTFI_prompt_pbc_dataset_15and9_10000_10bit_801e-4_pn_jump_nodropoutlayernorm_large_best.pth')
        print("Saved Best Model")

writer.flush()
writer.close()
# Save the final model (its 'loss' is the last epoch's test loss).
torch.save({
    'epoch': epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': current_test_loss
}, 'GPT_symmodified_NNCTFI_prompt_pbc_dataset_15and9_10000_10bit_801e-4_pn_jump_nodropoutlayernorm_large.pth')

In [None]:
import torch
import torch.nn as nn
import numpy as np
from qiskit.quantum_info import partial_trace, DensityMatrix

# Single-qubit Pauli matrices (complex64), keyed by the integer codes that
# correlation_func uses: 0 -> I, 1 -> X, 2 -> Y, 3 -> Z.
I = torch.eye(2, dtype=torch.cfloat)
X = torch.tensor([[0., 1.], [1., 0.]], dtype=torch.cfloat)
Y = torch.tensor([[0., -1.j], [1.j, 0.]], dtype=torch.cfloat)
Z = torch.tensor([[1., 0.], [0., -1.]], dtype=torch.cfloat)
operator_map = dict(enumerate((I, X, Y, Z)))

def correlation_func(sqe_op):
    """Kronecker product of the single-qubit operators selected by ``sqe_op``.

    ``sqe_op[k]`` (anything ``int()`` accepts, e.g. a float from ``np.zeros``)
    picks the operator on site k from ``operator_map``. Site 0 is the leftmost
    Kronecker factor. A single code returns the ``operator_map`` entry itself.
    An empty ``sqe_op`` raises IndexError; an unknown code raises KeyError.
    """
    factors = [operator_map[int(code)] for code in sqe_op]
    product = factors[0]
    for factor in factors[1:]:
        product = torch.kron(product, factor)
    return product

def symmetry_proj(rho, num_qubits):
    """Average ``rho`` over conjugation by {I, X_even, X_odd, X_all} and renormalize.

    Returns ``(rho + Ge rho Ge^H + Go rho Go^H + Ga rho Ga^H) / trace(...)``.
    Here Ge applies X on sites 0, 2, 4, ..., Go applies X on sites 1, 3, ...,
    and Ga applies X on every site (operators built by correlation_func).

    The identity term is just ``rho``, so no identity string is built. The
    normalization uses ``torch.trace`` rather than ``np.trace``: that avoids a
    NumPy round-trip and also works for tensors on a GPU.
    """
    def x_string(sites):
        # Pauli-X on the selected sites, identity elsewhere.
        codes = np.zeros(num_qubits)
        codes[sites] = 1
        return correlation_func(codes)

    x_even = x_string(slice(0, None, 2))
    x_odd = x_string(slice(1, None, 2))
    x_all = x_string(slice(None))
    rho_ = rho + x_even @ rho @ x_even.H + x_odd @ rho @ x_odd.H + x_all @ rho @ x_all.H
    return rho_ / torch.trace(rho_)


# Hamiltonian of NNCTFI (or ZZ2_X_ZXZ), notice that in the paper, the coefficients are g1, g2, g3
def Ham_ZZ2_X_ZXZ(J1,J2,J3,num_qubits):
    """Dense NNCTFI (ZZ2 + X + ZXZ) Hamiltonian on a periodic chain.

        H = -J1 * sum_i Z_i Z_{i+2} - J2 * sum_i X_i - J3 * sum_i Z_{i-1} X_i Z_{i+1}

    All site indices are taken mod ``num_qubits``. Terms are built with
    correlation_func (site 0 = leftmost Kronecker factor) and accumulated in
    the order: ZZ terms, X terms, ZXZ terms. Within the ZZ and ZXZ groups the
    bulk terms come first and the wrap-around terms last.

    Returns:
        Complex (2**num_qubits, 2**num_qubits) tensor.

    Raises:
        ValueError: if num_qubits < 3. For smaller rings the wrapped terms land
            on coinciding sites (e.g. ZZ on a 2-site ring collapses to a single Z).
    """
    if num_qubits < 3:
        raise ValueError(f"Ham_ZZ2_X_ZXZ needs num_qubits >= 3, got {num_qubits}")
    n = num_qubits

    def pauli_string(assignments):
        # assignments: (site, code) pairs; sites are wrapped onto the ring.
        codes = np.zeros(n)
        for site, code in assignments:
            codes[site % n] = code
        return correlation_func(codes)

    ham = 0.+0.j
    # Next-nearest-neighbour ZZ: pairs (i, i+2) for i = 0..n-1 (the last two wrap).
    for i in range(n):
        ham -= J1 * pauli_string(((i, 3), (i + 2, 3)))
    # Transverse field on every site.
    for i in range(n):
        ham -= J2 * pauli_string(((i, 1),))
    # ZXZ cluster terms centred on c = 1..n (centres n-1 and n wrap around).
    for c in range(1, n + 1):
        ham -= J3 * pauli_string(((c - 1, 3), (c, 1), (c + 1, 3)))
    return ham

def truth_cal_averge_two_point_correlation(J1,J2,J3,num_bits,d):
    """Exact ground-state value of (1/n) * sum_i <Z_i Z_{(i+d) mod n}>.

    Diagonalizes the dense Hamiltonian from Ham_ZZ2_X_ZXZ with
    torch.linalg.eigh (ascending eigenvalues, so column 0 is the ground state).
    It then forms rho = |g><g| and averages tr(rho Z_i Z_{i+d}) over all sites.
    Returns a 0-d real tensor.
    """
    eigenvectors = torch.linalg.eigh(Ham_ZZ2_X_ZXZ(J1, J2, J3, num_bits))[1]
    ground_col = eigenvectors[:, 0].view(-1, 1)
    rho = ground_col @ ground_col.H

    def zz_operator(site):
        codes = np.zeros(num_bits)
        codes[site] = 3
        codes[(site + d) % num_bits] = 3
        return correlation_func(codes)

    total = sum(torch.trace(rho @ zz_operator(site)) for site in range(num_bits))
    return total.real / (num_bits)

def sqe_cal_averge_two_point_pbc(sqe, d, device):
    """Shadow estimate of the site-averaged <Z_i Z_{i+d}> with periodic wrap.

    ``sqe`` is a (T, 2 * num_bits) tensor of (operator token, outcome) pairs.
    A pair (4, 1) gives a per-site value of +3 and (4, 0) gives -3; any other
    pair gives 0. The result is
    sum_{t,i} z[t, i] * z[t, (i + d)] / (num_bits * T), where the shift is done
    by concatenating ``z[:, d:]`` and ``z[:, :d]``.
    """
    sqe = sqe.to(device)
    n_rows = sqe.shape[0]
    num_bits = int(sqe.shape[1] / 2)
    n_pairs = sqe.shape[1] // 2
    op_tokens = sqe[:, 0:2 * n_pairs:2]
    outcomes = sqe[:, 1:2 * n_pairs:2]

    z_hat = torch.zeros(n_rows, n_pairs, dtype=sqe.dtype, device=device)
    is_z = op_tokens == 4
    z_hat[is_z & (outcomes == 1)] = 3
    z_hat[is_z & (outcomes == 0)] = -3

    shifted = torch.cat((z_hat[:, d:], z_hat[:, :d]), dim=1)
    return torch.sum(z_hat * shifted) / (num_bits * n_rows)

def truth_cal_groundenergy(J1,J2,J3,num_bits):
    """Exact ground-state energy of the dense Ham_ZZ2_X_ZXZ Hamiltonian, as a Python float.

    torch.linalg.eigh returns eigenvalues in ascending order, so entry 0 is the
    ground-state energy.
    """
    spectrum = torch.linalg.eigh(Ham_ZZ2_X_ZXZ(J1, J2, J3, num_bits)).eigenvalues
    return spectrum[0].item()

def sqe_cal_groundenergy(sqe,J1,J2,J3,device):
    """Shadow estimate of the energy of Ham_ZZ2_X_ZXZ, averaged over the T rows of ``sqe``.

    ``sqe`` is a (T, 2 * n) tensor of (operator token, outcome) pairs. For the
    per-site single-qubit estimates, token 2 gives x_hat and token 4 gives
    z_hat; outcome 1 maps to +3, outcome 0 to -3, and any other pair to 0. The
    returned value is

        -(J1 * sum ZZ2 + J2 * sum X + J3 * sum ZXZ) / T

    where ZZ2 and ZXZ include the two periodic wrap-around terms each.
    """
    sqe = sqe.to(device)
    T = sqe.shape[0]
    n_pairs = sqe.shape[1] // 2
    op_tokens = sqe[:, 0:2 * n_pairs:2]
    outcomes = sqe[:, 1:2 * n_pairs:2]

    def shadow(op_token):
        values = torch.zeros(T, n_pairs, dtype=sqe.dtype, device=device)
        selected = op_tokens == op_token
        values[selected & (outcomes == 1)] = 3
        values[selected & (outcomes == 0)] = -3
        return values

    x_hat = shadow(2)
    z_hat = shadow(4)

    # Bulk terms.
    zz_bulk = z_hat[:, :-2] * z_hat[:, 2:]
    zxz_bulk = z_hat[:, :-2] * x_hat[:, 1:-1] * z_hat[:, 2:]
    # Periodic wrap-around terms.
    zz_wrap_a = z_hat[:, -2] * z_hat[:, 0]
    zz_wrap_b = z_hat[:, -1] * z_hat[:, 1]
    zxz_wrap_a = z_hat[:, -2] * x_hat[:, -1] * z_hat[:, 0]
    zxz_wrap_b = z_hat[:, -1] * x_hat[:, 0] * z_hat[:, 1]

    zz_total = torch.sum(zz_bulk) + torch.sum(zz_wrap_a) + torch.sum(zz_wrap_b)
    zxz_total = torch.sum(zxz_bulk) + torch.sum(zxz_wrap_a) + torch.sum(zxz_wrap_b)
    return -(J1 * zz_total + J2 * torch.sum(x_hat) + J3 * zxz_total) / T

# true value of Renyi entropy
def truth_cal_renyi(J1,J2,J3,num_bits,res_num_bits):
    """Exact Renyi-2 entropy of ``res_num_bits`` qubits of the symmetrized ground state.

    Steps:
    1. Take the lowest eigenvector of Ham_ZZ2_X_ZXZ.
    2. Symmetrize |g><g| with symmetry_proj.
    3. Trace out Qiskit qubits 0 .. num_bits-res_num_bits-1 with partial_trace.
    4. Return -log2(purity) as a 1-element float64 tensor.

    NOTE(review): under Qiskit's little-endian ordering these traced qubits
    should be the rightmost Kronecker factors, so the kept qubits are the
    leftmost sites — confirm.
    """
    eigenvectors = torch.linalg.eigh(Ham_ZZ2_X_ZXZ(J1, J2, J3, num_bits)).eigenvectors
    ground_col = eigenvectors[:, 0].view(-1, 1)
    rho_sym = symmetry_proj(ground_col @ ground_col.H, num_bits)
    traced_out = list(range(int(num_bits - res_num_bits)))
    reduced = partial_trace(DensityMatrix(rho_sym.numpy()), traced_out)
    purity = reduced.purity().real
    return -torch.log2(torch.from_numpy(np.array([purity])))

# Lookup tables for the classical-shadow Renyi-2 estimate computed from
# transformer-generated sequences (only the leftmost qubits are needed).
#
# (operator token, outcome bit) -> index 0..5, ordered by token (2, 3, 4) then bit (0, 1).
mapping = {(op, bit): 2 * (op - 2) + bit for op in (2, 3, 4) for bit in (0, 1)}
# Pairwise kernel f[a, b] (float32):
#   5   for the same token and the same bit,
#   -4  for the same token with opposite bits,
#   0.5 for different tokens.
precomputed_f_matrix = torch.full((6, 6), 0.5) + torch.block_diag(
    *[torch.tensor([[4.5, -4.5], [-4.5, 4.5]])] * 3
)
def obtain_mappedseq(sequences, device):
    """Convert (operator token, outcome bit) pairs into indices 0..5 given by ``mapping``.

    Args:
        sequences: integer tensor of shape (T, 2 * num_bits) whose columns
            alternate operator token and outcome bit.
        device: device for the lookup table and the result.

    Returns:
        int64 tensor of shape (T, num_bits).

    Raises:
        ValueError: if the number of columns is odd, or some pair is not a key of
            ``mapping``. Without this check an unmapped pair becomes -1, which
            negative indexing silently reads as the last row of the f-matrix.
    """
    T, length = sequences.shape
    if length % 2 != 0:
        raise ValueError(
            f"expected an even number of columns (operator, outcome pairs), got {length}"
        )
    num_bits = length // 2
    # reshape (unlike view) also accepts inputs whose strides don't allow a view.
    pairs = sequences.to(device).reshape(T, num_bits, 2)
    mapping_tensor = torch.full((5, 6), -1, dtype=torch.int64, device=device)
    for (op_token, outcome), index in mapping.items():
        mapping_tensor[op_token, outcome] = index
    mapped_sequences = mapping_tensor[pairs[..., 0], pairs[..., 1]]
    if (mapped_sequences < 0).any():
        raise ValueError("sequence contains an (operator, outcome) pair that is not in `mapping`")
    return mapped_sequences
def seq_cal_renyi(seq, device, chunk_size=None):
    """Classical-shadow estimate of the Renyi-2 entropy of the measured qubits.

    Uses the U-statistic

        purity = 1 / (T (T - 1)) * sum_{i != j} prod_k f[a_ik, a_jk]

    where f = precomputed_f_matrix and a = obtain_mappedseq(seq), and returns
    -log2(purity).

    The pair sum is accumulated in row chunks instead of materializing the full
    (T, T, num_bits) tensor, which takes 4 * T * T * num_bits bytes in float32
    (~18 GB for T = 30000, num_bits = 5). Chunk sums are accumulated in float64.
    The i == j terms are subtracted afterwards instead of building a T x T mask.

    Args:
        seq: integer tensor (T, 2 * num_bits) of (operator, outcome) token pairs.
        device: device (or device string) to compute on.
        chunk_size: rows of the pair matrix processed per step. ``None`` picks a
            size that keeps each chunk around 2**24 f-values.

    Returns:
        0-d float32 tensor. It is NaN/inf when the purity estimate is <= 0,
        which can happen for small T.

    Raises:
        ValueError: if fewer than two sequences are given (T * (T - 1) == 0).
    """
    device = torch.device(device)
    seq = seq.to(device)
    T = seq.shape[0]
    if T < 2:
        raise ValueError(f"seq_cal_renyi needs at least two sequences, got {T}")

    codes = obtain_mappedseq(seq, device)  # (T, num_bits)
    f_matrix = precomputed_f_matrix.to(device)
    num_bits = codes.shape[1]
    if chunk_size is None:
        chunk_size = max(1, (1 << 24) // max(1, T * num_bits))

    total = torch.zeros((), dtype=torch.float64, device=device)
    for start in range(0, T, chunk_size):
        rows = codes[start:start + chunk_size]
        # (rows, 1, n) vs (1, T, n) -> (rows, T, n) -> product over qubits -> (rows, T)
        pair_products = f_matrix[rows.unsqueeze(1), codes.unsqueeze(0)].prod(dim=-1)
        total += pair_products.sum(dtype=torch.float64)
    # Remove the i == j terms.
    diagonal = f_matrix[codes, codes].prod(dim=-1).sum(dtype=torch.float64)
    purity = (total - diagonal) / (T * (T - 1))

    return (-torch.log2(purity)).to(f_matrix.dtype)

# Median of means calculation, similar to functions for TFI model
def median_of_means_renyi(sqe, device, num_parts):
    """Median-of-means Renyi-2 estimate from shadow sequences.

    The rows of ``sqe`` (a 2-D tensor) are split into ``num_parts`` consecutive
    groups of ``T // num_parts`` rows; the last ``T % num_parts`` rows are
    dropped. seq_cal_renyi is evaluated on each group, and the median of the
    group estimates is returned as a 0-d tensor. torch.median returns the lower
    of the two middle values when ``num_parts`` is even.
    """
    n_rows, _n_cols = sqe.shape
    group_size = n_rows // num_parts
    group_estimates = [
        seq_cal_renyi(sqe[k * group_size:(k + 1) * group_size], device).item()
        for k in range(num_parts)
    ]
    return torch.median(torch.tensor(group_estimates))

def MoM_energy(J1,J2,J3,sqe,device,num_parts):
    """Median-of-means shadow energy estimate.

    The rows of ``sqe`` (a 2-D tensor) are split into ``num_parts`` consecutive
    groups of ``T // num_parts`` rows; the remainder is dropped.
    sqe_cal_groundenergy is evaluated on each group, and the median of the group
    values is returned as a 0-d tensor. For even ``num_parts``, torch.median
    returns the lower middle value.
    """
    n_rows, _n_cols = sqe.shape
    group_size = n_rows // num_parts
    group_estimates = [
        sqe_cal_groundenergy(sqe[k * group_size:(k + 1) * group_size], J1, J2, J3, device).item()
        for k in range(num_parts)
    ]
    return torch.median(torch.tensor(group_estimates))

def MoM_twopoint(sqe,d,device,num_parts):
    """Median-of-means estimate of the averaged periodic <Z_i Z_{i+d}>.

    The rows of ``sqe`` (a 2-D tensor) are split into ``num_parts`` consecutive
    groups of ``T // num_parts`` rows; the remainder is dropped.
    sqe_cal_averge_two_point_pbc is evaluated on each group, and the median of
    the group values is returned as a 0-d tensor. For even ``num_parts``,
    torch.median returns the lower middle value.
    """
    n_rows, _n_cols = sqe.shape
    group_size = n_rows // num_parts
    group_estimates = [
        sqe_cal_averge_two_point_pbc(sqe[k * group_size:(k + 1) * group_size], d, device).item()
        for k in range(num_parts)
    ]
    return torch.median(torch.tensor(group_estimates))

In [None]:
# code of prediction
import torch
import numpy as np

# Model hyper-parameters; load_state_dict below requires them to match the checkpoint.
num_bits = 10
vocab_size = 5
embed_size = 128
posi_size = 20
num_layers = 4
head_count = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Transformer_learnable(vocab_size, posi_size, embed_size, num_layers, head_count)
# map_location remaps tensors saved during a GPU run onto `device`, so the checkpoint
# also loads on a CPU-only machine.
# NOTE(review): a checkpoint whose 'loss' entry is a pickled nn.CrossEntropyLoss may be
# rejected under torch.load's weights_only=True default (PyTorch >= 2.6) — confirm.
checkpoint = torch.load('GPT_symmodified_NNCTFI_prompt_pbc_dataset_15and9_10000_10bit_801e-4_pn_jump_nodropoutlayernorm_large_best.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
print(device)
model.eval()

def generate_uniform_simplex_samples(n_points):
    """Regular grid on the simplex J1 + J2 + J3 = 1 with J1, J2, J3 >= 0.

    Returns:
        float64 array of shape ((n_points + 1) * (n_points + 2) // 2, 3). Its
        rows are (i/n, j/n, (n - i - j)/n) for i = 0..n and j = 0..n-i, in that
        loop order.

    J3 is computed from the integer remainder ``n - i - j`` rather than as
    ``1 - J1 - J2``. Points on the J3 = 0 edge are therefore exactly 0 instead
    of a floating-point residue (e.g. ``1 - 0.7 - 0.3 == 5.55e-17``).

    Raises:
        ValueError: if n_points < 1.
    """
    if n_points < 1:
        raise ValueError(f"n_points must be >= 1, got {n_points}")
    samples = [
        [i / n_points, j / n_points, (n_points - i - j) / n_points]
        for i in range(n_points + 1)
        for j in range(n_points - i + 1)
    ]
    return np.array(samples)

# Number of grid divisions along each coupling axis (n_points=20 -> 231 grid points).
n_points = 20
# Generate uniform samples in the simplex J1 + J2 + J3 = 1
simplex_samples = generate_uniform_simplex_samples(n_points)
energy_list = []
twopoint_list = []
renyi_5_list = []
renyi_4_list = []
renyi_3_list = []

for sample in simplex_samples:
    # Unpacking a 1-D float32 tensor yields three 0-d tensors.
    J1, J2, J3 = torch.from_numpy(sample).to(torch.float)

    # Generate sequences and calculate energy and two-point correlation
    # 200000 samples of 10 qubits; [:, 3:] drops the coupling columns -> (200000, 20) long.
    # Both estimates use median-of-means over 10 groups; the two-point distance is d=2.
    seq = generate_GPTmeasureoutput(J1, J2, J3, model, 10, 200000, device)[:, 3:].to(torch.long)
    energy = MoM_energy(J1, J2, J3, seq, device, 10)
    twopoint = MoM_twopoint(seq, 2, device, 10)
    energy_list.append(energy)
    twopoint_list.append(twopoint)
    torch.cuda.empty_cache()
    # Generate sequences and calculate Renyi entropy
    # Only the first 5 qubits are sampled (300000 samples). Dropping the last 2 / 4
    # columns removes the last 1 / 2 (operator, outcome) pairs, giving 4- and 3-qubit subsystems.
    seq_renyi = generate_GPTmeasureoutput(J1, J2, J3, model, 5, 300000, device)[:, 3:].to(torch.long)
    renyi_5 = median_of_means_renyi(seq_renyi, device, 10)
    renyi_4 = median_of_means_renyi(seq_renyi[:, :-2], device, 10)
    renyi_3 = median_of_means_renyi(seq_renyi[:, :-4], device, 10)
    renyi_5_list.append(renyi_5)
    renyi_4_list.append(renyi_4)
    renyi_3_list.append(renyi_3)
    torch.cuda.empty_cache()

# One value per simplex grid point, in the order of simplex_samples.
energy_list = torch.tensor(energy_list)
twopoint_list = torch.tensor(twopoint_list)
renyi_5_list = torch.tensor(renyi_5_list)
renyi_4_list = torch.tensor(renyi_4_list)
renyi_3_list = torch.tensor(renyi_3_list)

# Save results
torch.save(energy_list, 'NNCTFI_energy_list.pt')
torch.save(twopoint_list, 'NNCTFI_twopoint_list.pt')
torch.save(renyi_5_list, 'NNCTFI_renyi_5_list.pt')
torch.save(renyi_4_list, 'NNCTFI_renyi_4_list.pt')
torch.save(renyi_3_list, 'NNCTFI_renyi_3_list.pt')

In [None]:
# Prediction code for showing results as a 3D scatter plot
import plotly.graph_objects as go

# Model hyper-parameters; load_state_dict below requires them to match the checkpoint.
num_bits=10

vocab_size = 5
embed_size = 128
posi_size = 20
num_layers = 4
head_count = 4
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Transformer_learnable(vocab_size, posi_size, embed_size, num_layers, head_count)
# map_location remaps tensors saved during a GPU run onto `device`, so the checkpoint
# also loads on a CPU-only machine.
# NOTE(review): a checkpoint whose 'loss' entry is a pickled nn.CrossEntropyLoss may be
# rejected under torch.load's weights_only=True default (PyTorch >= 2.6) — confirm.
checkpoint = torch.load('GPT_symmodified_NNCTFI_prompt_pbc_dataset_15and9_10000_10bit_801e-4_pn_jump_nodropoutlayernorm_large_best.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
print(device)
model.eval()

def generate_uniform_simplex_samples(n_points):
    """Regular grid on the simplex J1 + J2 + J3 = 1 with J1, J2, J3 >= 0.

    Returns:
        float64 array of shape ((n_points + 1) * (n_points + 2) // 2, 3). Its
        rows are (i/n, j/n, (n - i - j)/n) for i = 0..n and j = 0..n-i, in that
        loop order.

    J3 is computed from the integer remainder ``n - i - j`` rather than as
    ``1 - J1 - J2``. Points on the J3 = 0 edge are therefore exactly 0 instead
    of a floating-point residue (e.g. ``1 - 0.7 - 0.3 == 5.55e-17``).

    Raises:
        ValueError: if n_points < 1.
    """
    if n_points < 1:
        raise ValueError(f"n_points must be >= 1, got {n_points}")
    samples = [
        [i / n_points, j / n_points, (n_points - i - j) / n_points]
        for i in range(n_points + 1)
        for j in range(n_points - i + 1)
    ]
    return np.array(samples)

# Number of grid divisions along each coupling axis (n_points=20 -> 231 grid points).
n_points = 20

# Generate uniform samples in the simplex J1 + J2 + J3 = 1
simplex_samples = generate_uniform_simplex_samples(n_points)

# Prepare lists to hold the values
J1_list = []
J2_list = []
J3_list = []
twopoint_list = []

# Estimate the d=2 two-point ZZ correlation at every grid point from transformer samples.
for sample in simplex_samples:
    J1, J2, J3 = torch.from_numpy(sample).to(torch.float)
    J1_list.append(J1)
    J2_list.append(J2)
    J3_list.append(J3)
    seq = generate_GPTmeasureoutput(J1, J2, J3, model, 10, 200000, device)[:, 3:].to(torch.long)
    twopoint = MoM_twopoint(seq, 2, device, 10)
    twopoint_list.append(twopoint)
    torch.cuda.empty_cache()
# Convert lists to tensors
J1_list = torch.tensor(J1_list)
J2_list = torch.tensor(J2_list)
J3_list = torch.tensor(J3_list)
twopoint_list = torch.tensor(twopoint_list)

# 3D scatter of the grid points, coloured by the estimated two-point correlation.
# Plain NumPy arrays are passed so plotly does not depend on torch-tensor conversion.
fig = go.Figure(data=[go.Scatter3d(
    x=J1_list.numpy(),
    y=J2_list.numpy(),
    z=J3_list.numpy(),
    mode='markers',
    marker=dict(
        size=5,
        color=twopoint_list.numpy(),
        colorscale='Viridis',  # Color scale
        colorbar=dict(title='Two-point ZZ (d=2)')
    )
)])

# Update layout for a better look
fig.update_layout(
    title='Two-point ZZ correlation (d=2) over the J1 + J2 + J3 = 1 grid',
    scene=dict(
        xaxis_title='J1',
        yaxis_title='J2',
        zaxis_title='J3'
    )
)
fig.show()