# MPS Hamiltonian Learning with TensorKrowch

In [426]:
import numpy as np
import torch
from torchvision import transforms, datasets
import tensorkrowch as tk

import jax
import jax.numpy as jnp

import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
import glob
import yaml


### Things to do:

This will be a pytorch version of the Hamiltonian learning problem, where the NN has a layer with an MPS structure from tensorkrowch. There are a few things that I need to sort out:
1. This NN doesn't train from data, it starts from an ansatz and modifies it until it finds the optimal solution
2. I still need to run the dynamics for every epoch, which consist on:
    2.1. Taking initial state and applyting rotations in the X,Y and Z directions, with the option to customize which rotations I want to apply
    2.2. Doing a time evolution of the resulting state under a Hamiltonian. The Hamiltonian contains only interaction terms, and it also must be customizable
3. After running the Hamiltonian, we extract bitstring probabilities and compute nll loss function with the input data, which are the generated bitstrings 

## Data loading

In [427]:
def load_config(config_path):
    '''Load configuration from YAML file'''
    with open(config_path, 'r') as file:
        config = yaml.safe_load(file)
    return config

def load_experimental_data(config):
    """Load experimental/simulated data"""
    N = config["L"]
    chi = config['bond_dimension']
    T_max = config["t_max"]
    search_pattern = f"../data/experimental_data_quantum_sampling_L{N}_Chi_{chi}_*_counts.csv"
    files = glob.glob(search_pattern)

    if not files:
        raise FileNotFoundError(f"No data found for L={N}")

    config_file = files[0]
    file_core = config_file.replace(".csv", "").replace("../data/experimental_data_quantum_sampling_", "")
    
    print(f"\n{'='*60}")
    print(f"LOADING DATA: {file_core}")
    print(f"{'='*60}")
    
    df_counts = pd.read_csv(f"../data/experimental_data_quantum_sampling_{file_core}.csv")
        
    # Remove leading single quote if present
    if df_counts['bitstring'].astype(str).str.startswith("'").all():
        df_counts['bitstring'] = df_counts['bitstring'].str[1:]
    
    # Now extract values
    bitstrings = df_counts['bitstring'].values.astype(str)
    counts_shots = df_counts['count'].values.astype(np.int32)
    
    return bitstrings, counts_shots


def local_probability_tensor(strings, counts):
    '''Calculates probabilities of each qubit, returning vector of size L
    containing the probs of each qubit being 0 or 1'''
    L = len(strings[0])
    total_counts = sum(counts)
    
    prob_matrix = torch.zeros((1, L, 2)) #First index is batch. Needed for feeding into NN
    
    for bitstring, count in zip(strings, counts):
        for qubit in range(L):
            bit_value = int(bitstring[qubit])
            prob_matrix[0, qubit, bit_value] += count
    
    # Normalize by total counts
    prob_matrix /= total_counts
    
    return prob_matrix

## Useful functions

In [428]:
def paulis(dtype=torch.complex64):
    '''Creates single-qubit basis operators'''
    sx = torch.tensor([[0., 1.], [1., 0.]], dtype=dtype)
    sy = torch.tensor([[0., -1j], [1j, 0.]], dtype=dtype)
    sz = torch.tensor([[1., 0.], [0., -1.]], dtype=dtype)
    id2 = torch.eye(2, dtype=dtype)
    return sx, sy, sz, id2

def kron_n(ops):
    '''Tensor product of a list of operators'''
    out = ops[0]
    for A in ops[1:]:
        out = torch.kron(out, A)
    return out


def x_rotation(theta, dtype=torch.complex64):
    sx = torch.tensor([[0., 1.], [1., 0.]], dtype=dtype)
    return torch.matrix_exp(-1j * theta / 2 * sx)

def y_rotation(theta, dtype=torch.complex64):
    sy = torch.tensor([[0., -1j], [1j, 0.]], dtype=dtype)
    return torch.matrix_exp(-1j * theta / 2 * sy)

def z_rotation(theta, dtype=torch.complex64):
    sz = torch.tensor([[1., 0.], [0., -1.]], dtype=dtype)
    return torch.matrix_exp(-1j * theta / 2 * sz)

In [429]:
def prepare_initial_state(L, kind, dtype=torch.complex64):
    """Prepare initial quantum states for L qubits."""
    if kind == 'all_zeros':
        psi0 = torch.zeros(2**L, dtype=dtype)
        psi0[0] = 1.0
        
    elif kind == 'all_plus':
        plus = torch.ones(2, dtype=dtype) / np.sqrt(2)
        psi0 = plus
        for _ in range(L - 1):
            psi0 = torch.kron(psi0, plus)
            
    else:
        raise ValueError(f"Initial state '{kind}' not recognized. "
                        f"Use 'all_zeros' or 'all_plus'")
    return psi0

In [430]:
class OperatorClass:
    '''Class that contains a list of all the operator types the Hamiltonian will have
       The operators will be applied to each qubit, and we will allow for the construction of any
       combination of Pauli strings 
    '''
    def __init__(self, L, dtype=torch.complex64):

        self.L = L
        self.dim = 2**L
        self.pauli_basis = {}
        self.pauli_basis['X'], self.pauli_basis['Y'], self.pauli_basis['Z'], self.pauli_basis['I'] = paulis(dtype)
        self.operators = []
    
    def __len__(self):
        return len(self.operators)
    
    def __getitem__(self, idx):
        return self.operators[idx]
    
    def add_operators(self, pauli_string:str):
        #e.g. 'X','Y','ZZ'
        '''Adds one type of operator at a time. It loops through all the qubits, 
        and for each position does the tensor product of the whole chain, with the 
        required qubits substituted by the operators of the string'''

        if len(pauli_string) > self.L:
            raise ValueError(f"Pauli string '{pauli_string}' longer than system size {self.L}")
        
        if not all(char in 'XYZI' for char in pauli_string):
            raise ValueError(f"Invalid character in '{pauli_string}'. Use only X, Y, Z, I")
        
        for i in range(self.L - len(pauli_string) + 1):
                #Create identity operators for each qubit
                ops = [self.pauli_basis['I']]*self.L
                for j, char in enumerate(pauli_string):
                     #Build string
                     ops[i+j] = self.pauli_basis[char]
                self.operators.append(kron_n(ops))
        print(f"{pauli_string} terms added to the Hamiltonian")


        

## Manual NN

These were the functions defined by Marcin

In [431]:
def mlp_forward(params, x):
    h = x
    for layer in params[:-1]:
        h = np.tanh(h @ layer["W"] + layer["b"])
    last = params[-1]
    return h @ last["W"] + last["b"]


def init_mlp_params(layer_sizes, scale=0.1):
    params = []
    for i, (m, n) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
        # Initialize weights with scaled normal distribution
        W = scale * torch.randn((m, n))
        # Initialize biases to zero
        b = torch.zeros((n,))
        params.append({"W": W, "b": b})
    return params

## Pytorch NN

Equivalent of Marcin's functions but with Pytorch syntax

In [432]:
class MLP(nn.Module):
    def __init__(self, layer_sizes):
        super().__init__()
        self.layers = nn.ModuleList()
        
        # Create linear layers
        for i, (in_dim, out_dim) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
            self.layers.append(nn.Linear(in_dim, out_dim))
        
        # Custom initialization (similar to your function)
        self._initialize_parameters()
    
    def _initialize_parameters(self, scale=0.1):
        """Initialize weights with normal distribution and biases to zero."""
        for layer in self.layers:
            nn.init.normal_(layer.weight, mean=0.0, std=scale)
            nn.init.zeros_(layer.bias)
    
    def forward(self, x):
        # Apply all but last layer with tanh activation
        for i, layer in enumerate(self.layers[:-1]):
            x = torch.tanh(layer(x))
        
        # Last layer - linear only (no activation)
        x = self.layers[-1](x)
        return x

## TensorKrowch NN

My pytorch NN integrating MPS layer from TensorKrowch

In [433]:
#So, for input layer I should put number of possible bitstrings, and for output number of parameters to train. In the middle idk yet, but we'll see
class MPS_MLP(nn.Module):
    def __init__(self, L, chi, num_params, num_dims = []):
        super().__init__()
        self.layers = nn.ModuleList()

        layer_sizes = num_dims + [num_params]

        # 1. MPS input layer: processes L×2 features → first hidden size (will be output size if no middle layers)
        mps = tk.models.MPSLayer(
            n_features=L,
            in_dim=2,
            out_dim=layer_sizes[0],  
            bond_dim=chi
        )
        self.layers.append(mps)
        
        # 2. Middle layers (optional)
        for i in range(len(layer_sizes)-1):
            self.layers.append(nn.Linear(
                layer_sizes[i],
                layer_sizes[i+1]
            ))
        
        if len(layer_sizes) > 1:
        # 3. Final output layer: hidden → num_params
            self.layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1]))
        
        # Custom initialization (similar to your function)
        self._initialize_parameters()
    
    def _initialize_parameters(self, scale=0.1):
        """Initialize weights with normal distribution and biases to zero."""
        for layer in self.layers:
            if isinstance(layer, nn.Linear): #MPSLayer initializes itself
                nn.init.normal_(layer.weight, mean=0.0, std=scale)
                nn.init.zeros_(layer.bias)
    
    def forward(self, x):
        # Apply all but last layer with tanh activation
        for i, layer in enumerate(self.layers[:-1]):
            x = torch.tanh(layer(x))
        
        # Last layer - linear only (no activation)
        x = self.layers[-1](x)
        return x

In [434]:
# def compute_rotations_direct(psi, params, L):
#     """
#     Apply rotations by reshaping the state vector and applying 2x2 matrices.
#     This is the most efficient method for single-qubit rotations.
#     """
#     # Map parameter keys to rotation functions
#     rot_funcs = {
#         'rot_x': lambda theta: x_rotation(theta, dtype=psi.dtype),
#         'rot_y': lambda theta: y_rotation(theta, dtype=psi.dtype),
#         'rot_z': lambda theta: z_rotation(theta, dtype=psi.dtype)
#     }
    
#     # Initialize per-qubit rotations as identity
#     per_qubit_rots = [torch.eye(2, dtype=psi.dtype, device=psi.device) for _ in range(L)]
    
#     # Accumulate all rotations for each qubit
#     for key, rot_func in rot_funcs.items():
#         if key in params:
#             rot_list = params[key]
#             for i in range(L):
#                 if i < len(rot_list):
#                     per_qubit_rots[i] = rot_func(rot_list[i]) @ per_qubit_rots[i]
    
#     # Apply rotations qubit by qubit using tensor reshaping
#     # This avoids constructing the full 2^L × 2^L matrix
#     for i in range(L):
#         # Reshape state to separate the i-th qubit
#         shape = [2] * L
#         psi_reshaped = psi.view(*shape)
        
#         # Move the i-th dimension to the front
#         psi_reshaped = psi_reshaped.movedim(i, 0)
        
#         # Apply 2x2 rotation to each 2-element slice
#         original_shape = psi_reshaped.shape
#         psi_reshaped = psi_reshaped.view(2, -1)
#         psi_reshaped = per_qubit_rots[i] @ psi_reshaped
        
#         # Restore shape and put dimension back
#         psi_reshaped = psi_reshaped.view(original_shape)
#         psi_reshaped = psi_reshaped.movedim(0, i)
        
#         psi = psi_reshaped.reshape(-1)
    
#     return psi

In [None]:
def compute_rotations(psi, params, L):
    sx, sy, sz, id2 = paulis()

    rot_funcs = {
    'rot_x': lambda theta: x_rotation(theta, dtype=psi.dtype),
    'rot_y': lambda theta: y_rotation(theta, dtype=psi.dtype),
    'rot_z': lambda theta: z_rotation(theta, dtype=psi.dtype)
    }
    
    for key, rot_func in rot_funcs.items():
        if key in params:
            for i in range(L):
                rot = kron_n([id2]*i + [rot_func(params[key][i])] + [id2]*(L-i-1))
                #print(f'rot {key} in qubit {i} of angle {params[key][i]}')
                psi = rot@psi
                         
    return psi  
 

def rk4_step(state, H, t, dt, rhs_fun):
    dt_c = torch.asarray(dt, dtype=state.dtype)
    k1 = rhs_fun(H, t, state)
    k2 = rhs_fun(H, t + 0.5*dt_c, state + 0.5*dt_c*k1)
    k3 = rhs_fun(H, t + 0.5*dt_c, state + 0.5*dt_c*k2)
    k4 = rhs_fun(H, t + dt_c, state + dt_c*k3)
    state_next = state + (dt_c/6.0)*(k1 + 2*k2 + 2*k3 + k4)

    
    if state.ndim == 1:  # State vector
        norm = torch.linalg.norm(state_next)
        return state_next / (norm + 1e-12)
    else:  # Density matrix
        state_next = 0.5 * (state_next + state_next.conj().T)
        trace = torch.trace(state_next).real
        return state_next / (trace + 1e-12)


def build_hamiltonian(L, theta, OPS_LIST):
    '''Creates Hamiltonian from list of operators and corresponding weights'''

    expected_shape = len(OPS_LIST)
    
    if len(theta) != expected_shape or len(OPS_LIST) != expected_shape:
        raise ValueError(f"Parameter/operator count mismatch")
    
    H = torch.zeros((2**L, 2**L), dtype=torch.complex64)
    for i in range(expected_shape):
        H += theta[i] * OPS_LIST.operators[i]
    
    return H


def schrodinger_rhs(H, t, psi):
    return -1j * (H @ psi)


def evolve_state(psi_t, H, t_grid):
    rhs_fun = schrodinger_rhs
    dt = t_grid[1] - t_grid[0]
    for i,t in enumerate(t_grid[0:-1]):
        dt = t_grid[i+1] - t_grid[i]
        psi_t = rk4_step(psi_t, H, t, dt, rhs_fun)
 
    return psi_t


def time_evolution(psi, theta, OPS_LIST, L, t_grid):

    H = build_hamiltonian(L, theta, OPS_LIST)

    psi_t = evolve_state(psi, H, t_grid)

    return psi_t


def physics_computation(params, psi0, OPS_LIST, CONFIG, t_grid):

    psi_rot = compute_rotations(psi0, params, CONFIG['L'])

    psi_t = time_evolution(psi_rot, params['theta'], OPS_LIST, CONFIG['L'], t_grid)
    
    return psi_t

def nll(psi, counts):
    #Propper format of data
    counts_torch = torch.from_numpy(counts)

    #Normalized probabilities
    probs = torch.abs(psi)**2
    probs = probs / probs.sum()

    #Avoid log(0) by clipping
    probs = torch.clip(probs, 1e-9, 1.0)

    #negative log likelihood (normalized)
    logp = torch.log(probs)
    ll = torch.sum(counts_torch * logp)
    loss_nll = -ll / torch.sum(counts_torch)

    return loss_nll

In [436]:
def create_parameter_dict(params, OPS_LIST, CONFIG):
    """
    Create parameter dictionary from model output based on configuration.
    
    Args:
        params: Tensor of shape (total_params,)
        OPS_LIST: List of Hamiltonian operators
        CONFIG: Dictionary with keys:
            - 'L': number of qubits
            - 'x_fields': bool (whether to include X rotations)
            - 'y_fields': bool (whether to include Y rotations)
            - 'z_fields': bool (whether to include Z rotations)
    
    Returns:
        Dictionary with keys: 'theta', 'rot_x', 'rot_y', 'rot_z' (only if active)
    """
    predicted_params = {}

    if params.ndim > 1:
        params = params.squeeze() 
    
    # Start index for slicing params
    idx = 0
    
    # 1. Hamiltonian parameters (theta)
    n_hamiltonian = len(OPS_LIST)
    predicted_params['theta'] = params[idx:idx + n_hamiltonian]
    idx += n_hamiltonian
    
    # 2. Rotation parameters based on configuration
    L = CONFIG['L']
    
    if CONFIG.get('x_fields', False):
        predicted_params['rot_x'] = params[idx:idx + L]
        idx += L
    
    if CONFIG.get('y_fields', False):
        predicted_params['rot_y'] = params[idx:idx + L]
        idx += L
    
    if CONFIG.get('z_fields', False):
        predicted_params['rot_z'] = params[idx:idx + L]
        idx += L
    
    # Verify we used all parameters
    if idx != params.shape[0]:
        raise ValueError(f"Parameter count mismatch. Expected {idx} parameters, "
                        f"but model output has {params.shape[0]}")
    
    return predicted_params

## Training algorithm

Function that contains steps of one epoch

1. Forward step of NN
2. Simulation of physics (rotations + H evolution)
3. Calculation of nll
4. Backpropagation


Function that loops for all epochs

In [437]:

def train_model(model, params, n_epochs, single_qubit_probs, psi0, OPS_LIST, CONFIG, t_grid_fine, learning_rate, counts_shots):

    #initialization
    loss_history = []
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

   #Loss with initial parameters 
    # optimizer.zero_grad()
    # psi_t = physics_computation(params, psi0, OPS_LIST, CONFIG, t_grid_fine)
    # loss = nll(psi_t, counts_shots)
    # loss.backward()
    # optimizer.step()
    
    # loss_history.append(loss.item())


    for epoch_i in range(n_epochs):
        optimizer.zero_grad()
        
        # Forward pass: NN predicts Hamiltonian parameters
        predicted_params = {}
        output_params = model(single_qubit_probs)

        predicted_params = create_parameter_dict(output_params, OPS_LIST, CONFIG)

        #Dynamics of obtained parameters
        psi_t = physics_computation(predicted_params, psi0, OPS_LIST, CONFIG, t_grid_fine)

        #compute loss
        loss = nll(psi_t, counts_shots)

        #backpropagate and update optimizer
        loss.backward()
        optimizer.step()

        loss_history.append(loss.item())

         
    return model, predicted_params, psi_t, loss_history

In [None]:
config_file = "/Users/omichel/Desktop/qilimanjaro/projects/retech/retech_2025/tensorkrowch_version/config/MPS_learning_configuration.yaml"

#load configuration
print(config_file)
CONFIG = load_config(config_file)

# Load data
bitstrings, counts_shots = load_experimental_data(CONFIG)

#Main parameters
L = CONFIG['L']
CHI = CONFIG['bond_dimension']
inital_state_kind = CONFIG['initial_state_kind']
dim = 2**L

#Reshape data into local probabilities
single_qubit_probs = local_probability_tensor(bitstrings, counts_shots)

psi0 = prepare_initial_state(L, inital_state_kind)

#Initialize and onfigure Hamiltonian Ansatz
OPS_LIST = OperatorClass(L)

OPS_LIST.add_operators('ZZ')

NUM_COEFFICIENTS = len(OPS_LIST)

#Initialize parameters
torch.manual_seed(CONFIG["seed_init"])

theta_init = torch.rand(NUM_COEFFICIENTS, dtype=torch.float32, requires_grad=True)
# Initialize NN
NN_INPUT_DIM = L

params = {"theta": theta_init}

# Add rotation parameters for each enabled field type
if CONFIG['x_fields']:
    params["rot_x"] = torch.rand(L, dtype=torch.float32, requires_grad=True)
if CONFIG['y_fields']:
    params["rot_y"] = torch.rand(L, dtype=torch.float32, requires_grad=True)
if CONFIG['z_fields']:
    params["rot_z"] = torch.rand(L, dtype=torch.float32, requires_grad=True)

# Update NN output dimension
NN_OUTPUT_DIM = NUM_COEFFICIENTS + sum(CONFIG[f'{axis}_fields'] for axis in ['x', 'y', 'z']) * L

model = MPS_MLP(NN_INPUT_DIM, CHI, NN_OUTPUT_DIM, num_dims = []) #num_dims is for optional intermediate layers
n_epochs = CONFIG['N_epochs']

t_grid_fine = torch.arange(0.0, CONFIG["t_max"] + CONFIG["dt"]/2, CONFIG["dt"])
learning_rate = CONFIG['learning_rate']

model, final_params, psi_final, loss = train_model(model, params, n_epochs, single_qubit_probs, psi0, OPS_LIST, CONFIG, t_grid_fine, learning_rate, counts_shots)


/Users/omichel/Desktop/qilimanjaro/projects/retech/retech_2025/tensorkrowch_version/config/MPS_learning_configuration.yaml

LOADING DATA: L4_Chi_2_R50000_counts
ZZ terms added to the Hamiltonian


  return tensor[index]


In [439]:
print(final_params)
print(loss)

{'theta': tensor([-0.2385, -0.1752,  0.5253], grad_fn=<SliceBackward0>), 'rot_x': tensor([-0.0557, -0.1260,  0.0406,  0.1731], grad_fn=<SliceBackward0>), 'rot_z': tensor([ 0.0161, -0.3199, -0.4579, -0.0341], grad_fn=<SliceBackward0>)}
[2.7725887298583984, 2.7725887298583984, 2.7725887298583984, 2.7725887298583984, 2.7725884914398193, 2.7725884914398193, 2.7725889682769775, 2.7725887298583984, 2.7725887298583984, 2.7725887298583984, 2.7725884914398193, 2.7725884914398193, 2.7725884914398193, 2.7725887298583984, 2.7725887298583984, 2.7725884914398193, 2.7725887298583984, 2.7725887298583984, 2.7725887298583984, 2.7725884914398193, 2.7725887298583984, 2.7725884914398193, 2.7725887298583984, 2.7725887298583984, 2.7725884914398193, 2.7725887298583984, 2.7725887298583984, 2.7725887298583984, 2.7725887298583984, 2.7725887298583984, 2.7725887298583984, 2.7725887298583984, 2.7725887298583984, 2.7725887298583984, 2.7725884914398193, 2.7725887298583984, 2.7725887298583984, 2.7725884914398193, 2.77