In [None]:
# REDUCED or KEEP
# TJM :
- canonical_form (KEEP)
split_qr_contract_r_to_neighbour
- move_orthogonalization_center (KEEP)
_move_orth_center_to_neighbour
split_qr_contract_r_to_neighbour

# TDVP :
- _orthogonalize_init (move_orthogonalization_center(KEEP))
- _split_updated_site(split_node_qr(KEEP))
- _move_orth_and_update_cache_for_path (move_orthogonalization_center(KEEP))

# subspace expansion
- canonical_form (REDUCED)
# check ttn3.canon
# change move_orthogonalization_center to KEEP


In [None]:
def _get_field(container, key):
    """Return ``container[key]`` for dicts, otherwise ``getattr(container, key)``."""
    if isinstance(container, dict):
        return container[key]
    return getattr(container, key)


def _set_field(container, key, value):
    """Store ``value`` as ``container[key]`` for dicts, otherwise as an attribute."""
    if isinstance(container, dict):
        container[key] = value
    else:
        setattr(container, key, value)


def _coerce_like(text, reference):
    """
    Convert the user-entered ``text`` to the type of ``reference``.

    ``bool`` is handled explicitly: ``bool(text)`` is True for every non-empty
    string, so a plain ``type(reference)(text)`` would turn "False" into True.
    Accepted boolean spellings (case-insensitive): true/false, 1/0, yes/no,
    y/n, on/off.

    Raises:
        ValueError: If ``text`` cannot be interpreted as that type.
    """
    if isinstance(reference, bool):
        lowered = text.strip().lower()
        if lowered in ("true", "1", "yes", "y", "on"):
            return True
        if lowered in ("false", "0", "no", "n", "off"):
            return False
        raise ValueError(f"Cannot interpret {text!r} as a boolean.")
    return type(reference)(text)


def _prompt_range_bound(param, label, current):
    """Prompt for one bound ("MIN"/"MAX") of a range parameter; keep ``current`` on empty or invalid input."""
    user_input = input(f"Enter new {label} value for '{param}' (current: {current}) or press Enter to keep unchanged: ")
    if not user_input.strip():
        return current
    try:
        return _coerce_like(user_input, current)
    except ValueError:
        print(f"Invalid input for '{param}' {label}. Keeping the current value: {current}")
        return current


def update_expansion_params(self, params_to_update: List[str]) -> bool:
    """
    Interactively prompt the user to update expansion parameters.

    Parameters are looked up in ``self.config.Expansion_params``. Each level
    may be a dict (indexed by key) or an object (accessed by attribute).

    Supported value types:
    - a 2-tuple of numbers, treated as a (min, max) range. Both bounds are
      prompted, and the update is rejected unless min < max.
    - int / float / bool. Booleans are parsed from true/false, 1/0, yes/no,
      y/n or on/off.

    Any other type is skipped. Empty or invalid input keeps the current value.
    Missing parameters are reported and skipped.

    Args:
        params_to_update (List[str]): Parameter names to update. Nested
            parameters use dot notation, e.g. 'SVDParameters.max_bond_dim'.

    Returns:
        bool: True if the loop finished (updates applied or skipped). False
        if the user aborted with Ctrl-C or stdin reached EOF.
    """
    print("\n--- Update Expansion Parameters ---")
    try:
        for param in params_to_update:
            *parents, last_key = param.split('.')
            current_obj = self.config.Expansion_params
            try:
                # Walk down to the object that directly holds the target field.
                for key in parents:
                    current_obj = _get_field(current_obj, key)
                existing_value = _get_field(current_obj, last_key)

                # Range parameters such as rel_tot_bond: (min, max)
                if (isinstance(existing_value, tuple) and len(existing_value) == 2
                        and all(isinstance(x, (int, float)) for x in existing_value)):
                    print(f"Updating '{param}' which is a range (min, max): {existing_value}")
                    new_min = _prompt_range_bound(param, "MIN", existing_value[0])
                    new_max = _prompt_range_bound(param, "MAX", existing_value[1])
                    if new_min >= new_max:
                        print(f"Error: For '{param}', MIN ({new_min}) must be less than MAX ({new_max}). Keeping previous values: {existing_value}")
                    else:
                        new_tuple = (new_min, new_max)
                        _set_field(current_obj, last_key, new_tuple)
                        print(f"Updated '{param}' to {new_tuple}")

                # Scalar numeric values (bool is a subclass of int and lands here too)
                elif isinstance(existing_value, (int, float)):
                    user_input = input(f"Enter new value for '{param}' (current: {existing_value}) or press Enter to keep unchanged: ")
                    if user_input.strip():
                        try:
                            new_value = _coerce_like(user_input, existing_value)
                            _set_field(current_obj, last_key, new_value)
                            print(f"Updated '{param}' to {new_value}")
                        except ValueError:
                            print(f"Invalid input for '{param}'. Keeping the current value: {existing_value}")
                    else:
                        print(f"Keeping the current value for '{param}': {existing_value}")

                else:
                    print(f"Skipping update for '{param}' with unsupported type: {type(existing_value).__name__}")
            except (KeyError, AttributeError):
                print(f"Parameter '{param}' not found in config. Skipping.")
    except (KeyboardInterrupt, EOFError):
        # EOFError: input() hit end of stdin (e.g. non-interactive run); treat like an abort.
        print("\n--- Update Aborted by User ---")
        return False
    print("--- Update Complete ---\n")
    return True

def run_ex_with_pause(self, evaluation_time: Union[int, float] = 1, filepath: str = "", pgbar: bool = True, interactive: bool = False):
    """
    Time-evolve the state, expanding it periodically.

    After every time step, check whether ``step + 1`` is a multiple of
    ``Expansion_params["expansion_steps"] + 1``. If so, the state is replaced
    by the result of ``adjust_tol_and_expand``. It is then re-orthogonalised,
    and its partial tree cache is reset and rebuilt. Bond dimensions are
    recorded after every completed step. Results are written to ``filepath``
    at the end, including after an interactive abort.

    Args:
        evaluation_time (Union[int, float], optional): Evaluation duration. Defaults to 1.
        filepath (str, optional): Path to save results. Defaults to "".
        pgbar (bool, optional): If True, shows a progress bar. Defaults to True.
        interactive (bool, optional): If True, prompts for ``rel_tot_bond``
            before each expansion. Aborting the prompt stops the run.
            Defaults to False.
    """
    self.init_results(evaluation_time)
    tol = self.config.Expansion_params["tol"]

    for step in self.create_run_tqdm(pgbar):
        self.evaluate_and_save_results(evaluation_time, step)
        self.run_one_time_step()

        # The period is re-read from the config every step.
        period = self.config.Expansion_params["expansion_steps"] + 1
        is_expansion_step = (step + 1) % period == 0

        if is_expansion_step:
            if interactive:
                print("\n--- Pausing before Expansion Step ---")
                proceed = self.update_expansion_params(["rel_tot_bond"])
                if not proceed:
                    print("Aborting the run as per user request.")
                    break
                print("--- Resuming Execution ---\n")

            expanded, tol = self.adjust_tol_and_expand(tol)
            self.state = expanded
            self._orthogonalize_init(force_new=True)
            # Reset the cache before rebuilding it from the new state.
            self.partial_tree_cache = PartialTreeCachDict()
            self.partial_tree_cache = self._init_partial_tree_cache()

        self.record_bond_dimensions()

    self.save_results_to_file(filepath)


In [None]:

    def compute_divergence(self, state1: 'SecondOrderOneSiteTDVP.State', state2: 'SecondOrderOneSiteTDVP.State') -> float:
        """
        Measure how far two states differ in the first tracked observable.

        Returns ``abs(<O>_1 - <O>_2)``, where ``O = self.operators[0]`` and
        ``<O>_k`` is ``expectation_value(state_k, O)``.
        """
        observable = self.operators[0]
        difference = expectation_value(state1, observable) - expectation_value(state2, observable)
        return abs(difference)

    def get_time_step(self, specific_time) -> int:
        """
        Get the time step index corresponding to a specific time.

        The index is ``floor(specific_time / time_step_size)``, with one
        exception: a quotient within floating-point round-off of an integer is
        snapped to that integer. For example, ``0.3 / 0.1 == 2.9999999999999996``
        yields index 3, not 2. The result is clamped to ``num_time_steps - 1``.

        Args:
            specific_time (Union[int, float]): The specific time for which to get the time step index.

        Returns:
            int: The time step index corresponding to the specific time.

        Raises:
            ValueError: If the specific_time is negative or exceeds the final_time.
        """
        import math

        if specific_time < 0:
            raise ValueError("specific_time cannot be negative.")
        if specific_time > self._final_time:
            raise ValueError("specific_time cannot exceed final_time.")

        ratio = specific_time / self._time_step_size
        nearest = round(ratio)
        # Guard against round-off pushing an exact multiple just below an integer.
        if math.isclose(ratio, nearest, rel_tol=1e-9, abs_tol=1e-12):
            time_step_index = int(nearest)
        else:
            time_step_index = math.floor(ratio)

        # Ensure the time_step_index does not exceed the maximum number of steps
        max_steps = self._num_time_steps
        if time_step_index >= max_steps:
            return max_steps - 1  # Adjust to the last valid index
        return time_step_index

    def _evolve_in_trial(self, new_state = None):
        """
        Evolve a deep copy of the solver by one time step and return that copy.

        ``self`` is not modified. The copy's state is re-orthogonalised and its
        partial tree cache is reset and rebuilt before ``run_one_time_step``.

        Args:
            new_state: Optional state to evolve instead of the copy of
                ``self.state``. A deep copy of it is installed. The result of
                ``_orthogonalize_init`` is not assigned, so it presumably works
                in place, and without the copy the caller's object would be
                modified (``run_ex`` reuses ``expanded_state`` afterwards).

        Returns:
            The evolved copy of the solver; the evolved state is its ``.state``.
        """
        trial_self = deepcopy(self)
        if new_state is not None:
            trial_self.state = deepcopy(new_state)
        trial_self._orthogonalize_init(force_new=True)
        trial_self.partial_tree_cache = PartialTreeCachDict()
        trial_self.partial_tree_cache = trial_self._init_partial_tree_cache()
        trial_self.run_one_time_step()
        return trial_self
        
    def run_ex(self, evaluation_time: Union[int, float] = 1, filepath: str = "", pgbar: bool = True):
        """
        Time-evolve the state with divergence-controlled bond expansion.

        Plain time steps are taken in two cases: before the step index of
        ``Expansion_params["InitExpST"]``, and once ``adjust_tol_and_expand``
        has returned ``should_expand=False``. Otherwise, each step evolves both
        the current state and an expanded state on trial copies of the solver,
        then compares them with ``compute_divergence``:

        * divergence > ``ConvThresh``: the unexpanded evolution is accepted.
          The evolved expanded state is kept as the candidate for the next step.
        * otherwise the state is expanded again and re-evolved. This repeats
          until the divergence exceeds ``ConvThresh`` or ``ConvThreshUP``,
          whichever comes first. The last evolved expanded solver is then
          accepted, so the time step is never dropped.

        Args:
            evaluation_time (Union[int, float], optional): Evaluation duration. Defaults to 1.
            filepath (str, optional): If non-empty, results are written there via
                ``save_results_to_file`` after the run. Defaults to "".
            pgbar (bool, optional): If True, shows a progress bar. Defaults to True.
        """
        self.init_results(evaluation_time)
        tol = self.config.Expansion_params["tol"]
        use_previous_evolved = False
        previous_evolved_expanded = None
        should_expand = True

        for i in self.create_run_tqdm(pgbar):
            # NOTE(review): this branch saves results *before* stepping, while the
            # expansion branch below saves them *after* accepting an evolved state.
            # Confirm which time index evaluate_and_save_results expects.
            if i < self.get_time_step(self.config.Expansion_params["InitExpST"]) or not should_expand:
                self.evaluate_and_save_results(evaluation_time, i)
                self.run_one_time_step()
                self.accepted_states.append(self.state)
                continue

            # Reference: evolve the current (unexpanded) state on a trial copy.
            evolved_unexpanded_self = self._evolve_in_trial()

            # Candidate: reuse last step's evolved expanded state, or expand now.
            if use_previous_evolved:
                expanded_state = previous_evolved_expanded
            else:
                expanded_state, tol, should_expand = self.adjust_tol_and_expand(tol)

            evolved_expanded_self = self._evolve_in_trial(expanded_state)
            divergence = self.compute_divergence(evolved_unexpanded_self.state, evolved_expanded_self.state)

            if divergence > self.config.Expansion_params["ConvThresh"]:
                self.__dict__.update(evolved_unexpanded_self.__dict__)
                self.evaluate_and_save_results(evaluation_time, i)
                use_previous_evolved = True
                previous_evolved_expanded = evolved_expanded_self.state
                self.divergence_list.append(divergence)
                self.accepted_states.append(self.state)
                print(divergence)
            else:
                print("state not converged , divergence:", divergence)
                trial = 0
                # NOTE(review): no cap on the number of trials; this loops until the
                # divergence leaves the <= ConvThresh band.
                while divergence <= self.config.Expansion_params["ConvThresh"]:
                    # Expand further, starting from the previous expansion.
                    self.state = expanded_state
                    expanded_state, tol, _ = self.adjust_tol_and_expand(tol)
                    evolved_expanded_self = self._evolve_in_trial(expanded_state)

                    divergence = self.compute_divergence(evolved_unexpanded_self.state, evolved_expanded_self.state)
                    trial += 1
                    print(f"Trial : {trial} , divergence: {divergence}")
                    if divergence > self.config.Expansion_params["ConvThreshUP"]:
                        print("state converged after expansion")
                        break
                else:
                    # Divergence exceeded ConvThresh without exceeding ConvThreshUP.
                    # Previously nothing was accepted here and step i was lost.
                    print("divergence exceeded ConvThresh but not ConvThreshUP; accepting last expansion, divergence:", divergence)
                self.__dict__.update(evolved_expanded_self.__dict__)
                self.evaluate_and_save_results(evaluation_time, i)
                self.divergence_list.append(divergence)
                self.accepted_states.append(self.state)
                use_previous_evolved = False

        if filepath:
            self.save_results_to_file(filepath)
                     
def run_one_time_step_trial(self, new_state=None):
    """
    Run one time step on a deep copy of the solver and return the evolved state.

    ``self`` is not modified. When ``new_state`` is given, a deep copy of it is
    installed, so the caller's object is not modified either. ``run_ex_2``
    passes its own live state here.

    Args:
        new_state: State to evolve instead of the copy of ``self.state``.
            If None (the default), the copied solver is stepped as-is.
            Otherwise the copied state is re-orthogonalised, and the partial
            tree cache is reset and rebuilt, before the step.

    Returns:
        The copy's ``state`` after one call to ``run_one_time_step``.
    """
    trial_self = deepcopy(self)
    if new_state is not None:
        trial_self.state = deepcopy(new_state)
        trial_self._orthogonalize_init(force_new=True)
        trial_self.partial_tree_cache = PartialTreeCachDict()
        trial_self.partial_tree_cache = trial_self._init_partial_tree_cache()
    trial_self.run_one_time_step()
    return trial_self.state

def update_state(self, new_state):
    """
    Make ``new_state`` the solver's current state and rebuild its derived data.

    The state is re-orthogonalised with ``_orthogonalize_init(force_new=True)``.
    The partial tree cache is reset to an empty ``PartialTreeCachDict`` and
    then rebuilt with ``_init_partial_tree_cache``.
    """
    self.state = new_state
    self._orthogonalize_init(force_new=True)
    # Reset first, then rebuild (kept in case the builder reads the current cache).
    empty_cache = PartialTreeCachDict()
    self.partial_tree_cache = empty_cache
    rebuilt_cache = self._init_partial_tree_cache()
    self.partial_tree_cache = rebuilt_cache

def run_ex_2(self, evaluation_time: Union[int, float] = 1, filepath: str = "", pgbar: bool = True, interactive: bool = False):
    """
    Time-evolve the state, choosing between expanded and unexpanded evolution each step.

    Before the step index of ``Expansion_params["InitExpST"]``, plain time
    steps are taken. After that, each step works on copies, leaving
    ``self.state`` untouched until a result is accepted. The current state is
    trial-evolved. Either last step's evolved expanded state or a fresh
    expansion (``adjust_tol_and_expand``) is trial-evolved alongside it. The
    two results are compared with ``compute_convergence``:

    * convergence < ``ConvThresh``: the unexpanded result is installed. The
      evolved expanded state is kept as the candidate for the next step.
    * otherwise the evolved expanded result is installed.

    Args:
        evaluation_time (Union[int, float], optional): Evaluation duration. Defaults to 1.
        filepath (str, optional): If non-empty, results are written there via
            ``save_results_to_file`` after the run. Defaults to "".
        pgbar (bool, optional): If True, shows a progress bar. Defaults to True.
        interactive (bool, optional): Currently unused; kept for signature
            compatibility with ``run_ex_with_pause``. Defaults to False.
    """
    self.init_results(evaluation_time)
    tol = self.config.Expansion_params["tol"]
    use_previous_evolved = False
    previous_evolved_expanded = None

    for i in self.create_run_tqdm(pgbar):
        if i < self.get_time_step(self.config.Expansion_params["InitExpST"]):
            self.evaluate_and_save_results(evaluation_time, i)
            self.run_one_time_step()
            continue

        # Evolve a copy of the current state so the live self.state is not
        # modified by the trial (expansion below must start from it).
        evolved_unexpanded_state = self.run_one_time_step_trial(deepcopy(self.state))

        # Candidate: reuse last step's evolved expanded state, or expand now.
        if use_previous_evolved:
            expanded_state = previous_evolved_expanded
        else:
            expanded_state, tol = adjust_tol_and_expand(deepcopy(self.state), self.hamiltonian, tol, self.config)

        evolved_expanded_state = self.run_one_time_step_trial(expanded_state)

        # NOTE(review): compute_convergence / convergence_list are not defined in
        # this cell; confirm they exist on the solver class.
        convergence = self.compute_convergence(evolved_unexpanded_state, evolved_expanded_state)

        if convergence < self.config.Expansion_params["ConvThresh"]:
            use_previous_evolved = True
            previous_evolved_expanded = evolved_expanded_state
            self.update_state(evolved_unexpanded_state)
        else:
            print("state not converged , convergence:", convergence)
            # The expanded state was already evolved above. Re-running the trial
            # would repeat the work, or take a second step if the state is
            # updated in place.
            self.update_state(evolved_expanded_state)
            use_previous_evolved = False
        self.evaluate_and_save_results(evaluation_time, i)
        self.convergence_list.append(convergence)

    if filepath:
        self.save_results_to_file(filepath)

def _trial_step_with_cache(self, new_state=None, new_ptc=None):
    """
    Evolve a deep copy of the solver by one step and return ``(state, partial_tree_cache)`` of the copy.

    - No arguments: the copy's own state and cache are stepped as-is.
    - ``new_state`` only: the state is installed and re-orthogonalised, and a
      fresh partial tree cache is built.
    - ``new_state`` and ``new_ptc``: both are installed as-is, skipping the
      rebuild. The cache is assumed to match the state, e.g. both come from an
      earlier call to this helper.

    ``self`` is not modified. The passed ``new_state``/``new_ptc`` objects are
    installed without copying and may be modified by the step.
    """
    trial_self = deepcopy(self)
    if new_state is not None:
        trial_self.state = new_state
        if new_ptc is None:
            trial_self._orthogonalize_init(force_new=True)
            trial_self.partial_tree_cache = PartialTreeCachDict()
            trial_self.partial_tree_cache = trial_self._init_partial_tree_cache()
        else:
            trial_self.partial_tree_cache = new_ptc
    trial_self.run_one_time_step()
    return trial_self.state, trial_self.partial_tree_cache


def run_ex_22(self, evaluation_time: Union[int, float] = 1, filepath: str = "", pgbar: bool = True, interactive: bool = False):
    """
    Variant of ``run_ex_2`` that carries the partial tree cache along with each trial state.

    Each accepted state is installed together with the cache produced by its
    trial step, instead of re-orthogonalising and rebuilding the cache.
    Trial steps go through ``_trial_step_with_cache``, which returns
    ``(state, cache)``. The visible ``run_one_time_step_trial`` takes exactly
    one argument and returns only a state, so it cannot serve the calls made
    here.

    Args:
        evaluation_time (Union[int, float], optional): Evaluation duration. Defaults to 1.
        filepath (str, optional): If non-empty, results are written there via
            ``save_results_to_file`` after the run. Defaults to "".
        pgbar (bool, optional): If True, shows a progress bar. Defaults to True.
        interactive (bool, optional): Currently unused; kept for signature
            compatibility. Defaults to False.
    """
    self.init_results(evaluation_time)
    tol = self.config.Expansion_params["tol"]
    use_previous_evolved = False
    previous_evolved_expanded = None
    previous_evolved_expanded_ptc = None

    for i in self.create_run_tqdm(pgbar):
        if i < self.get_time_step(self.config.Expansion_params["InitExpST"]):
            self.evaluate_and_save_results(evaluation_time, i)
            self.run_one_time_step()
            continue

        # Evolve the current state on a copy of the solver.
        evolved_unexpanded_state, evolved_unexpanded_ptc = _trial_step_with_cache(self)

        # Evolve the candidate: last step's evolved expanded state (with its cache), or a fresh expansion.
        if use_previous_evolved:
            evolved_expanded_state, evolved_expanded_ptc = _trial_step_with_cache(
                self, previous_evolved_expanded, previous_evolved_expanded_ptc)
        else:
            expanded_state, tol = adjust_tol_and_expand(deepcopy(self.state), self.hamiltonian, tol, self.config)
            evolved_expanded_state, evolved_expanded_ptc = _trial_step_with_cache(self, new_state=expanded_state)

        # NOTE(review): compute_convergence / convergence_list are not defined in
        # this cell; confirm they exist on the solver class.
        convergence = self.compute_convergence(evolved_unexpanded_state, evolved_expanded_state)

        if convergence < self.config.Expansion_params["ConvThresh"]:
            use_previous_evolved = True
            previous_evolved_expanded = evolved_expanded_state
            previous_evolved_expanded_ptc = evolved_expanded_ptc
            self.state = evolved_unexpanded_state
            self.partial_tree_cache = evolved_unexpanded_ptc
        else:
            print("state not converged , convergence:", convergence)
            self.state = evolved_expanded_state
            self.partial_tree_cache = evolved_expanded_ptc
            use_previous_evolved = False
        self.evaluate_and_save_results(evaluation_time, i)
        self.convergence_list.append(convergence)

    if filepath:
        self.save_results_to_file(filepath)