# Verification notebook

## Instructions
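1. Run the cells in order, starting from a fresh kernel; the first cells define the differential-operator class `D` and the helper functions used throughout.
2. Once the "Symbols" cell has run, all symbols are registered, and the symbolic computations for the NLS and GP energies/momenta follow.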

In [1]:
%%javascript
// Raise the output auto-scroll threshold so that long outputs are presumably
// rendered in full rather than inside a scrolling box (classic Notebook front end).
IPython.OutputArea.auto_scroll_threshold = 9999;

<IPython.core.display.Javascript object>

In [2]:
from sympy import *
from IPython.display import display
import time
import pickle
import datetime
from collections import defaultdict

%env USE_SYMENGINE 1

def Equ(*args, **kwargs):
    """Build an ``Eq`` from the given arguments with ``evaluate=False`` forced.

    Any ``evaluate`` keyword supplied by the caller is overridden.
    """
    options = dict(kwargs, evaluate=False)
    return Eq(*args, **options)

env: USE_SYMENGINE=1


In [3]:
#courtesy: https://stackoverflow.com/questions/15463412/differential-operator-usable-in-matrix-form-in-python-module-sympy
from IPython.display import display
from sympy.core.decorators import call_highest_priority
from sympy import Expr, Matrix, Mul, Add, diff, Function
from sympy.core.numbers import Zero
%env USE_SYMENGINE 1

class D(Expr):
    """Non-commutative differential-operator symbol plus a registry of symbols.

    ``D(x, y)`` stands for "differentiate with respect to ``x`` and then ``y``".
    Products involving a ``D`` are kept unevaluated (unless ``self.evaluate`` is
    set) and are resolved later by ``evaluateExpr`` / ``processExpr``.

    The class also keeps four class-level lists that act as a global registry.
    Each commutative list is paired index-by-index with its ``_nc`` counterpart
    (``comm_to_non_comm`` / ``non_comm_to_comm`` rely on this via ``zip``).
    """
    # Operator priority consulted by ``call_highest_priority``.
    # NOTE(review): presumably chosen above sympy's default so that D's
    # ``__mul__`` / ``__rmul__`` take precedence -- confirm against sympy.
    _op_priority = 11.
    is_commutative = False
    # Commutative "differentiable" symbols (derivatives get their own symbols).
    diff_symbols = []
    # Non-commutative partners of ``diff_symbols`` (same index).
    diff_symbols_nc = []
    # Commutative symbols that are never given derivative symbols (coordinates etc.).
    non_diff_symbols = []
    # Non-commutative partners of ``non_diff_symbols`` (same index).
    non_diff_symbols_nc = []

    def __init__(self, *variables, **assumptions):
        """Store the differentiation variables; ``assumptions`` is accepted but unused."""
        super(D, self).__init__()
        # When False, ``D * expr`` stays an unevaluated ``Mul``.
        self.evaluate = False
        self.variables = variables

    def __repr__(self):
        """Return ``'D'`` followed by the tuple of differentiation variables."""
        return 'D%s' % str(self.variables)

    def __str__(self):
        """Same text as ``__repr__``."""
        return self.__repr__()

    @call_highest_priority('__mul__')
    def __rmul__(self, other):
        """``other * D``: keep the product as a ``Mul`` with the operator on the right."""
        return Mul(other, self)

    @call_highest_priority('__rmul__')
    def __mul__(self, other):
        """``D * other``.

        * ``D * D`` concatenates the variable tuples into a single ``D``.
        * ``D * Matrix`` multiplies every entry of a copy of the matrix from the left.
        * Otherwise the derivative is taken immediately if ``self.evaluate`` is
          true, else an unevaluated ``Mul(self, other)`` is returned.
        """
        if isinstance(other, D):
            variables = self.variables + other.variables
            return D(*variables)

        if isinstance(other, Matrix):
            other_copy = other.copy()
            for i, elem in enumerate(other):
                other_copy[i] = self * elem
            return other_copy

        if self.evaluate:
            return D.multi_deriv(other, *self.variables)
        else:
            return Mul(self, other)

    def __pow__(self, other):
        """Repeat this operator: the variable tuple is concatenated ``other`` times.

        For ``other <= 1`` the loop body never runs, so an operator with the
        same variables as ``self`` is returned.
        """
        variables = self.variables
        for i in range(other-1):
            variables += self.variables
        return D(*variables)

    @staticmethod
    def deriv(expr, xyz):
        """Differentiate ``expr`` once with respect to the direction ``xyz``.

        The result is the explicit derivative ``Derivative(expr, xyz).doit()``
        plus, for every registered differentiable symbol occurring in ``expr``,
        the partial derivative with respect to that symbol times the symbol
        that represents its ``xyz``-derivative (named via
        ``multi_var_deriv_name``). Derivative symbols that are not registered
        yet are registered on the fly.

        Returns 0 for objects without ``free_symbols``.
        """
        if not hasattr(expr, 'free_symbols'):
            return 0

        def custom_replace(expr):
            X = Wild('X')
            F = WildFunction('F')
            def replace_helper(F, X):
                # NOTE: ``deriv`` here resolves to a module-level name at call
                # time (class scope is not searched), not to ``D.deriv``.
                if F.args == (X,):
                    return F.func(deriv(X)) / deriv(X)
                else: return Derivative(F, X)
            return expr.replace(Derivative(F, X), replace_helper)
        # custom_replace makes functions that cannot be differentiated commute with differentiation.

        res = 0

        # Explicit dependence on the direction itself.
        res += custom_replace(Derivative(expr, xyz).doit())

        # Registered symbols present in expr, each with the realness flag that
        # its derivative symbol should inherit. For non-commutative symbols the
        # flag is taken from the commutative partner at the same index.
        original = [(sym, sym.is_real) for sym in D.diff_symbols if sym in expr.free_symbols]
        original_nc = [(sym_nc, sym_c.is_real) for sym_c, sym_nc in zip(D.diff_symbols, D.diff_symbols_nc) if sym_nc in expr.free_symbols]

        for sym, is_real in original:
            deriv_term = custom_replace(Derivative(expr, sym).doit())
            if deriv_term != 0:
                newName = multi_var_deriv_name(sym, xyz)

                #dsym = Symbol(newName, commutative=True, real=True)
                dsym = Symbol(newName, commutative=True, real=is_real)

                if dsym not in D.diff_symbols:
                    D.create_diff_symbol(newName, real=is_real)

                res += deriv_term * dsym

        for sym, is_real in original_nc:
            deriv_term = custom_replace(Derivative(expr, sym).doit())
            if deriv_term != 0:
                newName = multi_var_deriv_name(sym, xyz)

                dsym = Symbol(newName, commutative=False)

                if dsym not in D.diff_symbols_nc:
                    D.create_diff_symbol(newName, real=is_real)

                res += deriv_term * dsym

        return res

    @staticmethod
    def multi_deriv(expr, *variables):
        """Apply ``D.deriv`` successively for each entry of ``variables``, in order."""
        result = expr
        for xyz in variables:
            result = D.deriv(result, xyz)
        return result

    # Earlier version of create_diff_symbol, kept commented out by the author.
#     @staticmethod
#     def create_diff_symbol(name, real=None):
#         #new_symbol = Symbol(name, commutative=True, real=True)
#         new_symbol = Symbol(name, commutative=True, real=real)
#         if new_symbol not in D.diff_symbols:
#             D.diff_symbols.append(new_symbol)
#         new_symbol_nc = Symbol(name, commutative=False)
#         if new_symbol_nc not in D.diff_symbols_nc:
#             D.diff_symbols_nc.append(new_symbol_nc)
#         return new_symbol, new_symbol_nc

    @staticmethod
    def create_diff_symbol(name, real=None):
        """Create (and register) the commutative / non-commutative pair for ``name``.

        If neither symbol is registered both are appended. If only one of them
        is registered, the other is inserted at the index of the registered one.
        If both are registered nothing changes.

        Returns ``(commutative_symbol, non_commutative_symbol)``.
        """
        comm = Symbol(name, commutative=True, real=real)
        nonc = Symbol(name, commutative=False)
        ic = comm in D.diff_symbols
        inc = nonc in D.diff_symbols_nc
        if not ic and not inc:
            D.diff_symbols.append(comm)
            D.diff_symbols_nc.append(nonc)
        elif ic and not inc:
            i = D.diff_symbols.index(comm)
            D.diff_symbols_nc.insert(i, nonc)
        elif inc and not ic:
            # NOTE(review): if a commutative symbol with the same name but a
            # different ``real`` flag already sits at index i, this insert makes
            # the two lists differ in length -- confirm callers never do that.
            i = D.diff_symbols_nc.index(nonc)
            D.diff_symbols.insert(i, comm)
        return comm, nonc

    # Earlier version of create_non_diff_symbol, kept commented out by the author.
#     @staticmethod
#     def create_non_diff_symbol(name, real=None):
#         # new_symbol = Symbol(name, commutative=True, real=True)
#         new_symbol = Symbol(name, commutative=True, real=real)
#         if new_symbol not in D.non_diff_symbols:
#             D.non_diff_symbols.append(new_symbol)
#         new_symbol_nc = Symbol(name, commutative=False)
#         if new_symbol_nc not in D.non_diff_symbols_nc:
#             D.non_diff_symbols_nc.append(new_symbol_nc)
#         return new_symbol, new_symbol_nc

    @staticmethod
    def create_non_diff_symbol(name, real=None):
        """Same as ``create_diff_symbol`` but for the non-differentiable registries.

        Returns ``(commutative_symbol, non_commutative_symbol)``.
        """
        comm = Symbol(name, commutative=True, real=real)
        nonc = Symbol(name, commutative=False)
        ic = comm in D.non_diff_symbols
        inc = nonc in D.non_diff_symbols_nc
        if not ic and not inc:
            D.non_diff_symbols.append(comm)
            D.non_diff_symbols_nc.append(nonc)
        elif ic and not inc:
            i = D.non_diff_symbols.index(comm)
            D.non_diff_symbols_nc.insert(i, nonc)
        elif inc and not ic:
            i = D.non_diff_symbols_nc.index(nonc)
            D.non_diff_symbols.insert(i, comm)
        return comm, nonc

    @staticmethod
    def create_diff_symbols(*names, real=None):
        """Register several differentiable symbols at once.

        Returns a flat list ordered as
        ``[nc_1, comm_1, nc_2, comm_2, ...]`` -- note that the non-commutative
        variant comes first for every name.
        """
        new_symbols = []
        for name in names:
            new_symbol, new_symbol_nc = D.create_diff_symbol(name, real)
            new_symbols.append(new_symbol_nc)
            new_symbols.append(new_symbol)

        return new_symbols

    @staticmethod
    def create_non_diff_symbols(*names, real=None):
        """Register several non-differentiable symbols at once.

        Returns a flat list ordered as ``[nc_1, comm_1, nc_2, comm_2, ...]``.
        """
        new_symbols = []
        for name in names:
            new_symbol, new_symbol_nc = D.create_non_diff_symbol(name, real)
            new_symbols.append(new_symbol_nc)
            new_symbols.append(new_symbol)

        return new_symbols

    @staticmethod
    def comm_to_non_comm(expr):
        """Replace every registered commutative symbol by its non-commutative partner.

        Objects without a ``subs`` method are returned unchanged. Substitutions
        are applied one pair at a time, in registry order.
        """
        if not hasattr(expr, "subs"):
            return expr
        for (sym, sym_nc) in zip(D.diff_symbols, D.diff_symbols_nc):
            expr = expr.subs(sym, sym_nc)
        for (sym, sym_nc) in zip(D.non_diff_symbols, D.non_diff_symbols_nc):
            expr = expr.subs(sym, sym_nc)
        return expr

    @staticmethod
    def non_comm_to_comm(expr):
        """Inverse of ``comm_to_non_comm``: non-commutative symbols become commutative."""
        if not hasattr(expr, "subs"):
            return expr
        for (sym, sym_nc) in zip(D.diff_symbols, D.diff_symbols_nc):
            expr = expr.subs(sym_nc, sym)
        for (sym, sym_nc) in zip(D.non_diff_symbols, D.non_diff_symbols_nc):
            expr = expr.subs(sym_nc, sym)
        return expr

    @staticmethod
    def reset_symbols():
        """Empty all four registries by rebinding them to fresh lists."""
        D.diff_symbols = []
        D.diff_symbols_nc = []
        D.non_diff_symbols = []
        D.non_diff_symbols_nc = []
        

def var_deriv_name(var):
    """Return the name of the x-derivative of the symbol ``var``.

    A name that already contains ``')_{x'`` gets one more ``x`` appended inside
    the trailing braces; any other name ``n`` becomes ``'(n)_{x}'``.
    """
    name = var.name
    already_derived = ')_{x' in name
    return name[:-1] + 'x}' if already_derived else '(' + name + ')_{x}'

def multi_var_deriv_name(var, xyz):
    """Return the name of the derivative of ``var`` in the direction ``xyz``.

    If ``var.name`` already has the form ``'(base)_{letters}'`` the name of
    ``xyz`` is merged into ``letters`` and the characters are sorted; otherwise
    the result is ``'(name)_{xyz}'``.
    """
    name = var.name
    marker = ')_{'
    if marker not in name:
        return '(' + name + ')_{' + xyz.name + '}'
    split_at = name.rindex(marker) + len(marker)
    letters = sorted(name[split_at:-1] + xyz.name)
    return name[:split_at] + ''.join(letters) + '}'
    
def mydiff(expr, *variables):
    """Differentiate ``expr`` successively with respect to ``variables``.

    Dispatch:

    * ``D`` operator      -> a new ``D`` whose variables are the operator's
      variables followed by ``variables`` (the input operator is left untouched).
    * ``Matrix``          -> a copy whose entries are differentiated one by one.
    * ``conjugate``/``im``/``re`` -> the derivative of the argument, re-wrapped
      in the same function.
    * anything else       -> ``D.multi_deriv(expr, *variables)``.
    """
    if isinstance(expr, D):
        # Build a fresh operator instead of mutating ``expr.variables`` in place,
        # which would silently change the operator the caller still holds.
        return D(*(tuple(expr.variables) + tuple(variables)))
    if isinstance(expr, Matrix):
        expr_copy = expr.copy()
        for i, elem in enumerate(expr):
            # Differentiate the individual entry (previously the whole matrix
            # was passed here, so every entry received the wrong value).
            expr_copy[i] = D.multi_deriv(elem, *variables)
        return expr_copy
    if isinstance(expr, conjugate):
        return conjugate(D.multi_deriv(expr.args[0], *variables))
    if isinstance(expr, im):
        return im(D.multi_deriv(expr.args[0], *variables))
    if isinstance(expr, re):
        return re(D.multi_deriv(expr.args[0], *variables))

    return D.multi_deriv(expr, *variables)


def isFunction(expr):
    """Return True when ``expr`` has a non-empty ``args`` and a ``func`` attribute."""
    if not hasattr(expr, 'args'):
        return False
    if len(expr.args) == 0:
        return False
    return hasattr(expr, 'func')

# Indentation unit for the debug trace; one unit is appended per recursion level.
spacing = '|   '
def evaluateMul(expr, printing=False, space=spacing, postFunc=None):
    """Resolve the differential operators contained in a product.

    The (expanded) expression's arguments are scanned from right to left:

    * If the last argument is a ``D`` the whole expression evaluates to zero
      (an operator with nothing to its right).
    * When a ``D`` is found, everything to its right is differentiated with
      ``mydiff``, multiplied by the factors to its left, and the result is
      handed back to ``processExpr`` (a "restart").
    * Any other argument is processed recursively and the product is rebuilt;
      if the rebuilt product has fewer arguments than the initial one, the
      evaluation is restarted through ``processExpr`` as well.

    Parameters
    ----------
    expr : expression whose ``args`` are scanned (called with a ``Mul`` by ``processExpr``).
    printing : when true, a trace of every step is printed.
    space : prefix used for the trace lines of this recursion level.
    postFunc : optional callable applied to the result before it is returned.
    """
    if printing:
        print(space, 'evaluateMul', expr, expr.args)
    if hasattr(expr, 'expand'):
        expr = expr.expand()
    if expr.args:
        if printing:
            print(space, 'hasArgs')
        # A trailing operator has nothing to act on.
        if isinstance(expr.args[-1], D):
            if printing:
                print(space, 'finalD: zero')
            return Zero()
    # Snapshot of the factors; ``expr`` itself may be rebuilt inside the loop.
    initial_args = expr.args
    for i in range(len(expr.args)-1, -1, -1):
        arg = initial_args[i]
        if hasattr(arg, 'expand'):
            arg = arg.expand()
        if printing:
            print(space, 'arg', i, 'is', arg)
        if isinstance(arg, D):
            if printing:
                print(space, 'arg is D')
            left = Mul(*initial_args[:i])
            if printing:
                print(space, 'left', left)
            # Uses the current ``expr`` so that already processed factors are included.
            right = Mul(*expr.args[i+1:])
            if printing:
                print(space, 'right', right)
            right = mydiff(right, *arg.variables)
            if printing:
                print(space, 'new right', right)
            if printing:
                print(space, 'restart')
            return processExpr(left * right, printing=printing, space=space+spacing, postFunc=postFunc)
        else:
            if printing:
                print(space, 'arg is processed further')
            arg = processExpr(arg, printing=printing, space=space+spacing, postFunc=postFunc)
            left = Mul(*initial_args[:i])
            if printing:
                print(space, 'left', left)
            right = Mul(*expr.args[i+1:])
            if printing:
                print(space, 'right', right)
            expr = left * arg * right
            # Fewer arguments than at the start means the product changed shape,
            # so the positions in ``initial_args`` no longer line up: start over.
            if len(expr.args) < len(initial_args):
                return processExpr(expr, printing=printing, space=space+spacing, postFunc=postFunc)

    if printing:
        print(space, '--Mul-->', expr)
    return postFunc(expr) if postFunc else expr

def processExpr(expr, printing=False, space=spacing, postFunc=None):
    """Recursively resolve differential operators inside ``expr``.

    After expanding, the expression is dispatched on its type: matrices are
    processed entry by entry (in place on the expanded matrix), products go to
    ``evaluateMul``, a bare ``D`` becomes zero, and any other object with
    arguments is rebuilt from its processed arguments. ``postFunc``, when
    given, is applied to the final result.
    """
    def trace(*parts):
        # Debug output, prefixed with the indentation of this recursion level.
        if printing:
            print(space, *parts)

    def descend(handler, sub):
        # One level deeper with the same options.
        return handler(sub, printing=printing, space=space + spacing, postFunc=postFunc)

    trace('processExpr: ', expr)
    if hasattr(expr, 'expand'):
        expr = expr.expand()

    if isinstance(expr, Matrix):
        trace('Matrix')
        for position, entry in enumerate(expr):
            expr[position] = descend(processExpr, entry)
        trace('newExpr', expr)
    elif isinstance(expr, Mul):
        trace('Mul')
        expr = descend(evaluateMul, expr)
    elif isinstance(expr, D):
        trace('D')
        expr = Zero()
    elif isFunction(expr):
        trace('Function', expr.args)
        processed = [descend(processExpr, piece) for piece in expr.args]
        expr = expr.func(*processed)

    trace('------->', expr)
    if postFunc:
        return postFunc(expr)
    return expr

def evaluateExpr(expr, printing=False, space=spacing):
    """Evaluate an operator expression and return it in commutative symbols.

    The registered symbols are switched to their non-commutative variants, the
    operators are resolved by ``processExpr``, and the symbols are switched back.
    """
    non_commutative = D.comm_to_non_comm(expr)
    resolved = processExpr(non_commutative, printing=printing, space=space)
    return D.non_comm_to_comm(resolved)

def test_suite():
    """Smoke-test the operator machinery on two matrix-operator computations.

    Test 1 conjugates the operator ``Lax`` with ``M`` / ``M_conj`` and applies
    it to ``F``; test 2 applies the commutator of ``Pax`` and ``Lax`` to ``F``.
    Each result is displayed and compared with a reference computed by
    ``evaluateExpr``.

    Side effect: the symbol registries of ``D`` are reset before and after
    the run.
    """
    D.reset_symbols()
    # create_*_symbols returns [non_commutative, commutative, ...] per name.
    x, x_c, y, y_c, z, z_c = D.create_non_diff_symbols('x', 'y', 'z')
    # Raw strings: '\o' and '\p' are not valid escape sequences.
    qq, qq_c, qq_conj, qq_conj_c = D.create_diff_symbols('q', r'\overline{q}')
    uu, uu_c, ff, ff_c, gg, gg_c, vv, vv_c = D.create_diff_symbols('u', 'f', 'g', 'v')
    AA, AA_c, pphi, pphi_c = D.create_diff_symbols('A', r'\phi')

    theta = I / 2 * (pphi - pi / 2)

    M = 1 / sqrt(2) * Matrix([[exp(- theta), exp(theta)], [exp(- theta), - exp(theta)]])
    M_conj = 1 / sqrt(2) * Matrix([[exp(theta), exp(- theta)], [exp(theta), - exp(- theta)]])

    Lax = I * Matrix([[D(x), - qq], [qq_conj, - D(x)]]).subs(qq, AA * exp(I * pphi)).subs(qq_conj, AA * exp(- I * pphi))

    Pax = I * Matrix([[2 * D(x, x) - (AA**2 - 1), - AA * exp(I * pphi) * D(x) - D(x) * AA * exp(I * pphi)],
                    [AA * exp(- I * pphi) * D(x) + D(x) * AA * exp(- I * pphi), - 2 * D(x, x) + (AA**2 - 1)]])

    F = Matrix([[ff, 0], [0, gg]])

    # Reference results.
    res1 = evaluateExpr(Matrix([[AA - (D(x) * pphi - pphi * D(x)) / 2, I * D(x)],
                   [I * D(x), - AA - (D(x) * pphi - pphi * D(x)) / 2]]) * F)

    res2 = evaluateExpr(evaluateExpr(Matrix([[0, (D(x, x) - 2 * (AA**2 - 1)) * AA * exp(I * pphi)],
                [(D(x, x) - 2 * (AA**2 - 1)) * AA * exp(- I * pphi), 0]])) * F)

    old_expr = M * Lax * Transpose(M_conj) * F
    expr = old_expr.copy()
    new_expr = evaluateExpr(expr)
    display(simplify(new_expr))
    if simplify(new_expr - res1) == Matrix([[0, 0], [0, 0]]):
        print('Test 1: success')
    else:
        # Previously reported as "Test 2", which mislabelled a Test 1 failure.
        print('Test 1: failed')
    old_expr = (Pax * Lax - Lax * Pax) * F
    expr = old_expr.copy()
    new_expr = evaluateExpr(expr)
    display(simplify(new_expr))
    if simplify(new_expr - res2) == Matrix([[0, 0], [0, 0]]):
        print('Test 2: success')
    else:
        print('Test 2: failed')
    D.reset_symbols()

test_suite()

env: USE_SYMENGINE=1


Matrix([
[f*(-(\phi)_{x} + 2*A)/2,               I*(g)_{x}],
[              I*(f)_{x}, g*(-(\phi)_{x} - 2*A)/2]])

Test 1: success


Matrix([
[                                                                                                    0, g*((A)_{xx} + 2*I*(A)_{x}*(\phi)_{x} + I*(\phi)_{xx}*A - (\phi)_{x}**2*A - 2*A**3 + 2*A)*exp(I*\phi)],
[f*((A)_{xx} - 2*I*(A)_{x}*(\phi)_{x} - I*(\phi)_{xx}*A - (\phi)_{x}**2*A - 2*A**3 + 2*A)*exp(-I*\phi),                                                                                                    0]])

Test 2: success


In [4]:
from collections import deque
from multiset import Multiset

def get_var_name_from_deriv(sym):
    """Return the base-variable name encoded in a (derivative) symbol name.

    For a name containing both ``'('`` and ``')'`` this is the text between the
    first ``'('`` and the first ``')'``; otherwise the name itself.
    """
    name = sym.name
    open_at, close_at = name.find('('), name.find(')')
    has_parens = open_at != -1 and close_at != -1
    return name[open_at + 1:close_at] if has_parens else name
        
def get_multiindex_from_deriv(sym):
    """Return the derivative letters of a name like ``'(q)_{xxt}'`` (``'xxt'``).

    Names without both ``'('`` and ``')'`` yield the empty string.
    """
    name = sym.name
    close_at = name.find(')')
    if name.find('(') == -1 or close_at == -1:
        return ''
    # Skip the three characters ')_{' and drop the closing '}'.
    return name[close_at + 3:-1]
    
def get_order_from_deriv(sym):
    """Return the number of derivative letters in a name like ``'(q)_{xx}'``.

    Names without both ``'('`` and ``')'`` have order 0.
    """
    name = sym.name
    if name.find('(') == -1 or name.find(')') == -1:
        return 0
    # Everything after ')_{' except the final '}' is one letter per derivative.
    return len(name) - 4 - name.find(')_{')

def get_var_from_name(name, isComm):
    """Look up the registered order-zero differentiable symbol called ``name``.

    The commutative registry is searched when ``isComm`` is true, the
    non-commutative one otherwise. Returns 0 when nothing matches.
    """
    pool = D.diff_symbols if isComm else D.diff_symbols_nc
    for candidate in pool:
        if candidate.name == name and get_order_from_deriv(candidate) == 0:
            return candidate
    return 0

def get_dvar_from_deriv_name(deriv_name, isComm, isReal=None):
    """Return the registered differentiable symbol named ``deriv_name``.

    The commutative registry is searched when ``isComm`` is true, otherwise the
    non-commutative one. If no symbol with that name is registered, the pair is
    created (with ``real=isReal`` for the commutative variant) and the requested
    variant is returned.
    """
    registry = D.diff_symbols if isComm else D.diff_symbols_nc
    for candidate in registry:
        if candidate.name == deriv_name:
            return candidate
    # Only create the symbol when the lookup missed. ``create_diff_symbols``
    # takes the keyword ``real`` (not ``isReal``) and returns
    # [non_commutative, commutative] for a single name.
    created_nc, created_comm = D.create_diff_symbols(deriv_name, real=isReal)
    return created_comm if isComm else created_nc

def get_var_from_deriv(var):
    """Return the registered base symbol of the derivative symbol ``var``.

    The result has the same commutativity as ``var``; 0 is returned when the
    base symbol is not registered.
    """
    base_name = get_var_name_from_deriv(var)
    return get_var_from_name(base_name, var.is_commutative)
    
def deriv(poly):
    """Differentiate ``poly`` once along the module-level direction ``x``."""
    directions = [x]
    return multi_deriv(poly, directions)

def higher_deriv(poly, n):
    """Differentiate ``poly`` ``n`` times along the module-level direction ``x``."""
    directions = [x] * n
    return multi_deriv(poly, directions)
   
def multi_deriv(expr, xyz):
    """Differentiate ``expr`` with respect to one direction or a sequence of directions.

    * ``xyz`` a list/tuple: the derivatives are applied one after another
      (nested sequences are flattened by the recursion).
    * ``expr`` a ``Matrix``: a copy with every entry differentiated is returned.
    * Objects without ``free_symbols`` differentiate to 0.

    Otherwise the result is the explicit derivative
    ``Derivative(expr, xyz).doit()`` plus, for every registered differentiable
    symbol in ``expr``, the partial derivative with respect to that symbol times
    the symbol representing its ``xyz``-derivative. Missing derivative symbols
    are registered in ``D`` on demand.
    """
    if isinstance(xyz, (list, tuple)):
        result = expr
        for elem in xyz:
            result = multi_deriv(result, elem)
        return result

    if isinstance(expr, Matrix):
        expr_copy = expr.copy()
        for i, elem in enumerate(expr):
            expr_copy[i] = multi_deriv(elem, xyz)
        return expr_copy

    if not hasattr(expr, 'free_symbols'):
        return 0

    def custom_replace(inner):
        # Makes functions that cannot be differentiated commute with differentiation.
        X = Wild('X')
        F = WildFunction('F')
        def replace_helper(F, X):
            if F.args == (X,):
                return F.func(deriv(X)) / deriv(X)
            else:
                return Derivative(F, X)
        return inner.replace(Derivative(F, X), replace_helper)

    def derived_symbol(sym):
        """Registered symbol for d(sym)/d(xyz), created only if it is missing."""
        new_name = multi_var_deriv_name(sym, xyz)
        registry = D.diff_symbols if sym.is_commutative else D.diff_symbols_nc
        for known in registry:
            if known.name == new_name:
                return known
        if sym.is_commutative:
            is_real = sym.is_real
        else:
            # Mirror D.deriv: the realness of a non-commutative symbol is read
            # from its commutative partner at the same registry index, so the
            # new commutative symbol carries the same assumption as the base.
            is_real = None
            if sym in D.diff_symbols_nc:
                position = D.diff_symbols_nc.index(sym)
                if position < len(D.diff_symbols):
                    is_real = D.diff_symbols[position].is_real
        # create_diff_symbols returns [non_commutative, commutative].
        created_nc, created_comm = D.create_diff_symbols(new_name, real=is_real)
        return created_comm if sym.is_commutative else created_nc

    res = 0
    res += Derivative(expr, xyz).doit()

    # Snapshot: the registries may grow while derivative symbols are created.
    registered = set(D.diff_symbols) | set(D.diff_symbols_nc)
    symbols_to_iterate = list(registered & set(expr.free_symbols))

    for sym in symbols_to_iterate:
        deriv_term = custom_replace(Derivative(expr, sym).doit())
        if deriv_term != 0:
            res += deriv_term * derived_symbol(sym)

    return res

def single_subs(expr, var, sub, scale=1):
    """Substitute ``sub`` for ``var`` in ``expr``, including derivative symbols.

    If ``var`` is a registered differentiable symbol, every registered symbol
    in ``expr`` with the same base name whose derivative letters contain those
    of ``var`` is replaced by the corresponding derivative of ``sub`` (taken
    with ``multi_deriv``), multiplied by a scale factor. Otherwise a plain
    ``expr.subs(var, sub)`` is performed.

    ``scale`` is either a number (the factor is ``scale**n`` for ``n`` extra
    derivatives) or a dict mapping directions to factors; in the dict case each
    direction is additionally rescaled inside ``expr`` first.

    Objects without a ``subs`` method are returned unchanged.
    """
    if not hasattr(expr, 'subs'):
        return expr

    scale_is_map = isinstance(scale, dict)
    if scale_is_map:
        for direction, stretch in scale.items():
            expr = expr.subs(direction, stretch * direction)

    registered = set(D.diff_symbols) | set(D.diff_symbols_nc)
    if var not in registered:
        return expr.subs(var, sub)

    base_name = get_var_name_from_deriv(var)
    base_index = Multiset(get_multiindex_from_deriv(var))

    # Fixed list: multi_deriv below may register new symbols while we iterate.
    candidates = list(registered & set(expr.free_symbols))

    for sym in candidates:
        if get_var_name_from_deriv(sym) != base_name:
            continue
        sym_index = Multiset(get_multiindex_from_deriv(sym))
        if not base_index.issubset(sym_index):
            continue
        # Derivatives that still have to be applied to the replacement.
        extra = sym_index.difference(base_index)
        directions = [Symbol(letter, real=True) for letter in extra]
        if scale_is_map:
            factor = 1
            for direction in directions:
                factor *= scale.get(direction, 1)
        else:
            factor = scale**len(directions)
        expr = expr.subs(sym, factor * multi_deriv(sub, directions))

    return expr

def subs(expr, data, scale=1):
    """Apply ``single_subs`` for every ``(variable, replacement)`` pair in ``data``, in order."""
    result = expr
    for target, replacement in data:
        result = single_subs(result, target, replacement, scale=scale)
    return result

def variation(expr, sym):
    """Sum of ``(-1)**k * d^k/dx^k`` of the partial derivatives of ``expr``.

    The partial derivative is taken with respect to ``sym`` and with respect to
    every derivative symbol in ``expr`` that shares ``sym``'s base name; ``k``
    is the derivative order read from the symbol's name. The sum is simplified
    before it is returned. Objects without ``free_symbols`` give 0.
    """
    if not hasattr(expr, 'free_symbols'):
        return 0

    base_name = get_var_name_from_deriv(sym)
    tracked = [(sym, get_order_from_deriv(sym))]
    seen_names = [sym.name]

    for candidate in expr.free_symbols:
        name = candidate.name
        is_derivative = name.find('(') != -1 and name.find(')') != -1
        if not is_derivative or get_var_name_from_deriv(candidate) != base_name:
            continue
        if name in seen_names:
            continue
        tracked.append((candidate, get_order_from_deriv(candidate)))
        seen_names.append(name)

    total = 0
    for symbol, order in tracked:
        total += (-1)**order * higher_deriv(Derivative(expr, symbol).doit(), order)

    return simplify(total)

# Positive bookkeeping symbols used by polynomize/depolynomize; ``epsiloninv``
# stands in for 1/epsilon. Raw string: '\e' is not a valid escape sequence.
epsilon, epsiloninv = symbols(r'\epsilon, \eta', positive=True)

def polynomize(expr):
    """View ``expr`` as a ``Poly`` in ``epsilon`` and ``epsiloninv``, then set ``epsiloninv = 1/epsilon`` and simplify."""
    as_poly = Poly(expr, epsilon, epsiloninv)
    return simplify(as_poly.subs(epsiloninv, 1/epsilon))

def depolynomize(poly):
    """Rebuild an expression from a polynomial in ``epsilon`` and ``epsiloninv``.

    Every monomial contributes ``epsilon**a * epsiloninv**b * coefficient``
    where ``(a, b)`` are its first two exponents.
    """
    exponents = poly.monoms()
    coefficients = poly.coeffs()
    terms = (
        epsilon**exps[0] * epsiloninv**exps[1] * coefficients[position]
        for position, exps in enumerate(exponents)
    )
    return sum(terms, 0)

def poly_simplify(expr):
    """Round-trip ``expr`` through ``polynomize`` and ``depolynomize``."""
    collected = polynomize(expr)
    return depolynomize(collected)

def _count_x_deriv_factors(mon):
    """Count the x-derivative factors of one monomial (powers count with their exponent).

    A factor counts when it is a symbol whose name contains ``'_{x'`` or a
    positive power of such a symbol.
    """
    def weight(fac):
        if hasattr(fac, "name"):
            return 1 if "_{x" in fac.name else 0
        if isinstance(fac, Pow):
            # The base of a power need not be a symbol (e.g. sqrt(2)); such
            # bases have no ``name`` and simply do not count.
            if "_{x" in getattr(fac.base, "name", "") and fac.exp > 0:
                return fac.exp
        return 0

    # A bare symbol or a bare power is a one-factor monomial; treating a bare
    # power like a power inside a product keeps q_x**2 and 3*q_x**2 consistent.
    if hasattr(mon, "name") or isinstance(mon, Pow):
        return weight(mon)
    return sum(weight(fac) for fac in getattr(mon, "args", ()))


def extract_deriv(expr, k):
    """Return the sum of the monomials of ``expr`` that contain exactly ``k`` x-derivative factors.

    ``expr`` is simplified and expanded first. Derivative factors are
    recognised by ``'_{x'`` in the symbol name, so only derivative symbols whose
    letter list starts with ``x`` are counted. Returns 0 when no monomial matches.
    """
    expr = simplify(expr).expand()
    monomials = expr.args if isinstance(expr, Add) else [expr]
    res = 0
    for mon in monomials:
        if _count_x_deriv_factors(mon) == k:
            res += mon
    return res
        
def _scan_x_deriv_factors(mon, conj_name):
    """Return ``(count, has_conj)`` for one monomial.

    ``count`` is the number of x-derivative factors (powers count with their
    exponent); ``has_conj`` tells whether one of those factors is a derivative
    of the variable called ``conj_name``.
    """
    def inspect(fac):
        if hasattr(fac, "name"):
            if "_{x" in fac.name:
                return 1, get_var_name_from_deriv(fac) == conj_name
        elif isinstance(fac, Pow):
            base = fac.base
            # Non-symbol bases (e.g. sqrt(2)) have no ``name`` and do not count.
            if "_{x" in getattr(base, "name", "") and fac.exp > 0:
                return fac.exp, get_var_name_from_deriv(base) == conj_name
        return 0, False

    # A bare symbol or bare power is a one-factor monomial, so that q_x**2 is
    # counted the same way as the power inside 3*q_x**2.
    if hasattr(mon, "name") or isinstance(mon, Pow):
        factors = [mon]
    else:
        factors = getattr(mon, "args", ())

    count = 0
    has_conj = False
    for fac in factors:
        weight, is_conj = inspect(fac)
        count += weight
        has_conj = has_conj or is_conj
    return count, has_conj


def extract_deriv_alt(expr, k):
    """Like ``extract_deriv`` but keep only monomials containing a derivative of ``q_conj``.

    Returns the sum of the monomials of the simplified, expanded ``expr`` that
    have exactly ``k`` x-derivative factors, at least one of which is a
    derivative of the module-level symbol ``q_conj``. Returns 0 when no
    monomial matches.
    """
    expr = simplify(expr).expand()
    monomials = expr.args if isinstance(expr, Add) else [expr]
    res = 0
    for mon in monomials:
        count, has_conj = _scan_x_deriv_factors(mon, q_conj.name)
        if count == k and has_conj:
            res += mon
    return res

def group_by_orders(expr, syms):
    """Group the monomials of ``expr`` by their total degree in the given variables.

    The degree counts every registered commutative symbol that shares a base
    name with one of ``syms`` (the symbols themselves and their derivative
    symbols); powers count with their exponent. Returns a dict mapping each
    degree to the sum of its monomials, with keys in ascending order.
    """
    base_names = [get_var_name_from_deriv(s) for s in syms]
    tracked = [s for s in D.diff_symbols if get_var_name_from_deriv(s) in base_names]

    if hasattr(expr, "expand"):
        expr = simplify(expr).expand()
    monomials = expr.args if isinstance(expr, Add) else [expr]

    def weight(term):
        # Degree contributed by a single factor.
        if isinstance(term, Pow):
            return term.args[1] if term.args[0] in tracked else 0
        return 1 if term in tracked else 0

    groups = {}
    for mon in monomials:
        if isinstance(mon, Mul):
            degree = sum(weight(fac) for fac in mon.args)
        else:
            degree = weight(mon)
        groups[degree] = groups.get(degree, 0) + mon

    return {degree: groups[degree] for degree in sorted(groups)}

def integrate_by_parts(expr, syms):
    """Perform one integration-by-parts pass on every monomial of ``expr``.

    For each base symbol in ``syms`` the registered commutative symbols with
    the same base name are collected (sorted by derivative order). Every
    monomial is then rewritten using the first such family that occurs in it:
    derivatives are moved off a chosen "target" symbol onto the remaining
    factors, with the sign ``(-1)**dorder``. Boundary terms are not tracked.

    Returns the sum of the rewritten monomials.
    """
    # One list of registered symbols per base symbol, lowest order first.
    deriv_sym_lists = []
    for base in syms:
        base_name = get_var_name_from_deriv(base)
        group = [ds for ds in D.diff_symbols
                 if get_var_name_from_deriv(ds) == base_name]
        group.sort(key=get_order_from_deriv)
        deriv_sym_lists.append(group)

    if hasattr(expr, "expand"):
        expr = simplify(expr).expand()

    monomials = expr.args if isinstance(expr, Add) else [expr]
    res = 0

    for mon in monomials:
        # Start from the monomial itself; ``replacement`` below adds the correction.
        res += mon
        powers = mon.as_powers_dict()
        for group in deriv_sym_lists:
            present = [ds for ds in group if ds in powers]
            if not present:
                continue
            if len(present) == 1:
                # Single symbol of this family: move all of its derivatives.
                target_ds = present[0]
                coef = 1
                dorder = get_order_from_deriv(target_ds)
            else:
                total_order = sum(get_order_from_deriv(ds) for ds in present)
                if total_order % 2 == 0:
                    # Even total order: move all derivatives off the lowest-order symbol.
                    target_ds = min(present, key=get_order_from_deriv)
                    coef = 1
                    dorder = get_order_from_deriv(target_ds)
                else:
                    # Odd total order: take the highest-order symbol, move only the
                    # gap to the second-highest order, and apply half the correction.
                    target_ds = max(present, key=get_order_from_deriv)
                    coef = Rational(1, 2)
                    orders = sorted(get_order_from_deriv(ds) for ds in present)
                    dorder = orders[-1] - orders[-2]
            base_sym = get_var_from_deriv(target_ds)
            # Order left on the target after ``dorder`` derivatives are moved away.
            remaining_order = get_order_from_deriv(target_ds) - dorder
            remaining_factor = (
                higher_deriv(base_sym, remaining_order)
                if remaining_order else base_sym
            )
            # coef * (new form - old form); with coef == 1 the monomial is replaced.
            replacement = coef * (
                -mon
                + (-1)**dorder * remaining_factor * higher_deriv(simplify(mon / target_ds), dorder)
            )
            res += replacement
            # Only the first family found in the monomial is handled per pass.
            break
    return res

def integrate_by_parts_multiple(expr, syms, N):
    """Apply ``integrate_by_parts`` ``int(N)`` times in a row and return the result."""
    current = expr
    passes_left = int(N)
    while passes_left > 0:
        current = integrate_by_parts(current, syms)
        passes_left -= 1
    return current

def integrate_by_parts_auto(expr, syms, max_iter=20):
    """Repeat ``integrate_by_parts`` until the result stops changing.

    At most ``max_iter`` passes are made; the last expression before a pass
    that changed nothing (or the result of the final pass) is returned.
    """
    current = expr
    for _ in range(max_iter):
        candidate = simplify(integrate_by_parts(current, syms))
        unchanged = simplify(candidate - current) == 0
        if unchanged:
            return current
        current = candidate
    return current

def operator_from_bilinear(expr, syms):
    """Read off a matrix of differential operators from a bilinear expression.

    ``expr`` is first integrated by parts (``integrate_by_parts_auto``),
    simplified and expanded. For every monomial the first registered symbol
    belonging to one of ``syms`` selects the row, and the first such symbol in
    the remaining factors selects the column; the remaining coefficient
    (converted to non-commutative symbols) times ``D(xx)**order`` of the column
    symbol is added to that entry.

    Monomials that do not contain two such symbols contribute nothing, in the
    same way ``operator_from_linear`` skips monomials without a match.

    A warning is printed (but the computation continues) when ``syms`` does not
    have exactly two entries.
    """
    if len(syms) != 2:
        print("This function is only intended for bilinear forms")
    size = len(syms)
    op = Matrix([[0 for n in range(size)] for k in range(size)])
    sym_names = [get_var_name_from_deriv(sym) for sym in syms]
    # (index into syms, registered symbol) for every symbol of each family.
    deriv_syms = [(i, s) for i, sym_name in enumerate(sym_names) for s in D.diff_symbols if get_var_name_from_deriv(s) == sym_name]
    expr = integrate_by_parts_auto(expr, syms)

    if hasattr(expr, "expand"):
        expr = simplify(expr).expand()
    args = expr.args if isinstance(expr, Add) else [expr]

    for mon in args:
        if not hasattr(mon, "free_symbols"):
            continue
        first = next(((i, s) for (i, s) in deriv_syms if s in mon.free_symbols), None)
        if first is None:
            # e.g. a constant or zero monomial: previously this raised a
            # TypeError while unpacking None.
            continue
        ind1, sym1 = first
        # Computed once instead of once per candidate symbol.
        rest_symbols = simplify(mon / sym1).free_symbols
        second = next(((i, s) for (i, s) in deriv_syms if s in rest_symbols), None)
        if second is None:
            continue
        ind2, sym2 = second
        dorder = get_order_from_deriv(sym2)
        op[ind1, ind2] += D.comm_to_non_comm(mon / sym1 / sym2) * (D(xx)**dorder if dorder > 0 else 1)
    return op

def operator_from_linear(expr, sym):
    """Read off a differential operator from an expression linear in ``sym``.

    For every monomial of the simplified, expanded ``expr`` the first
    registered symbol with ``sym``'s base name is divided out; the remaining
    coefficient (in non-commutative symbols) times ``D(xx)**order`` of that
    symbol is added to the operator. Monomials without such a symbol are skipped.
    """
    base_name = get_var_name_from_deriv(sym)
    family = [s for s in D.diff_symbols if get_var_name_from_deriv(s) == base_name]

    if hasattr(expr, "expand"):
        expr = simplify(expr).expand()
    monomials = expr.args if isinstance(expr, Add) else [expr]

    operator = 0
    for mon in monomials:
        if not hasattr(mon, "free_symbols"):
            continue
        hit = next((s for s in family if s in mon.free_symbols), None)
        if not hit:
            continue
        order = get_order_from_deriv(hit)
        operator += D.comm_to_non_comm(mon / hit) * (D(xx)**order if order > 0 else 1)
    return operator

def operator_adjoint(op):
    """Return the formal adjoint of a differential-operator expression.

    * Matrices are transposed and the adjoint is applied entry by entry.
    * Sums are handled term by term (after expanding).
    * A product has its factors reversed (kept unevaluated) and is multiplied
      by ``(-1)**k``, where ``k`` is the total number of derivatives carried by
      its ``D`` factors.
    """
    if isinstance(op, Matrix):
        return op.T.applyfunc(operator_adjoint)
    if hasattr(op, "expand"):
        op = op.expand()
    if isinstance(op, Add):
        return sum(operator_adjoint(t) for t in op.args)

    def _count_order(f):
        # Number of derivatives contributed by a single factor.
        if isinstance(f, D):
            return len(f.variables)
        if getattr(f, 'is_Pow', False) and isinstance(f.base, D):
            try:
                return int(f.exp) * len(f.base.variables)
            except (TypeError, ValueError, OverflowError):
                # Exponent that cannot be converted to an int (e.g. symbolic):
                # count no derivatives, as before, but without a bare ``except``
                # that would also swallow KeyboardInterrupt/SystemExit.
                return 0
        return 0

    facs = list(op.args) if isinstance(op, Mul) else [op]
    k = sum(_count_order(f) for f in facs)
    r = Mul(*reversed(facs), evaluate=False)
    return (-1)**k * r

def func_deriv(expr, var, pvar):
    """Directional derivative of ``expr`` with respect to ``var`` in the direction ``pvar``.

    Computes ``d/ds expr[var -> var + s * pvar]`` at ``s = 0``, where the
    substitution is done with ``subs`` so that derivative symbols of ``var``
    are replaced consistently.

    The function now honours its ``var`` and ``pvar`` arguments; it used to
    ignore them and always perturb ``q_conj`` by ``p_conj``. Calls of the form
    ``func_deriv(expr, q_conj, p_conj)`` are unaffected.
    """
    # A Dummy cannot collide with a user symbol that happens to be called 's'.
    step = Dummy('s')
    perturbed = subs(expr, [(var, var + step * pvar)])
    return Derivative(perturbed, step).doit().subs(step, 0)


# Symbols

In [5]:
D.reset_symbols()
# To use the custom calculus functionality, every symbol has to be registered first.
# Each create_* call returns, per name, the non-commutative symbol followed by
# the commutative one (hence the doubled / single variable names below).

# Directions to take derivatives in
xx, x, yy, y, tt, t = D.create_non_diff_symbols('x', 'y', 't', real=True)
EEps, Eps = D.create_non_diff_symbols('e', real=True)

# Complex-valued differentiable symbols
qq, q, qq_conj, q_conj = D.create_diff_symbols('q', '\\tilde{q}')
pp, p, pp_conj, p_conj = D.create_diff_symbols('p', '\\tilde{p}')
ww, w, ww_conj, w_conj = D.create_diff_symbols('w', '\\tilde{w}')
vv, v, vv_conj, v_conj = D.create_diff_symbols('v', '\\tilde{v}')

# Real-valued differentiable symbols ('a' is registered once, together with 'A')
AA, A, aa, a = D.create_diff_symbols('A', 'a', real=True)
uu, u, pphi, phi, ff, f, gg, g = D.create_diff_symbols('u', '\\phi', 'f', 'g', real=True)
WW_plus, W_plus, WW_minus, W_minus, ww_plus, w_plus, ww_minus, w_minus = D.create_diff_symbols('W_+', 'W_-', 'w_+', 'w_-', real=True)
# '\\pm' with an escaped backslash: '\p' is not a valid escape sequence.
qq_pm, q_pm, qq_conj_pm, q_conj_pm = D.create_non_diff_symbols('q_\\pm', '\\tilde{q}_\\pm')
pphi_pm, phi_pm, pphi_minus, phi_minus, pphi_plus, phi_plus = D.create_non_diff_symbols('\\phi_\\pm', '\\phi_-', '\\phi_+', real=True)
ssig, sig, rr, r, ssigt, sigt, rrt, rt = D.create_diff_symbols('\\sigma', 'r', '\\tilde{\\sigma}', '\\tilde{r}')

# Plain constant symbols (not registered with D)
lam, z = symbols('\\lambda z')

# Demo 1: the same commutator written with the commutative symbol `q` and with the
# non-commutative symbol `qq`; only the latter survives simplify().
print('''
Above each symbol has two variants, one with a repeated name 'xx' or 'qq', and one with a single name 'x' or 'q'.
The symbols with the single name are commutative. The symbols with the repeated name are not commutative.
In order to use differential operators we need non-commutative symbols. For example note that the first expression is simplified to zero
while the second is not. We want the second behavior for differential operators, so here the 'double' symbols need to be used.
''')
display(simplify(D(xx) * q - q * D(xx)))
display(simplify(D(xx) * qq - qq * D(xx)))

# Demo 2: evaluateExpr turns the operator expression into an ordinary expression.
print('''
In order to evaluate this Differential operator as if it is being applied to the function one, we use
''')
display(evaluateExpr(D(xx) * qq - qq * D(xx)))

# Demo 3: applying the operator to a function, once inside evaluateExpr (non-commutative `ff`)
# and once with the factor `f` multiplied on after evaluation.
print('''
Note that outside of evaluateExpr we do not think of the expression as a differential operator anymore, instead just a function.
There we use 'single' variable names.
''')
display(evaluateExpr(((D(xx) * qq - qq * D(xx) * ff))))
display(evaluateExpr((D(xx) * qq - qq * D(xx)) * f))

# Demo 4: mixed derivatives D(xx, yy) combined with coordinate factors.
print('''
We can also take derivatives in several directions and use coordinates
''')
display(evaluateExpr(xx * D(xx, yy) * ff))
display(evaluateExpr(xx * D(xx, yy) * yy * ff))

# Demo 5: a 2x2 matrix of operators applied to a column of functions.
print('''
We can also consider matrix operators
''')
M = Matrix([[D(xx), lam], [- lam, D(xx)]])
V = Matrix([ff, gg])
display(evaluateExpr(M * V))

# Demo 6: the plain derivative helpers (deriv, higher_deriv, multi_deriv) acting on
# registered symbols and on ordinary SymPy expressions.
print('''
In the computations below we don't use the full differential operator machinery (yet).
Just taking derivatives of differentiable symbols can be done as below.
We also demonstrate that sympy has some inbuilt functions.
''')
display(deriv(f))
display(deriv(deriv(p)))
display(higher_deriv(exp(a), 3))
display(higher_deriv(exp(2 * x), 3))
display(multi_deriv(u, [x, y, t]))
display(multi_deriv(exp(Rational(1, 2) * x) * y**3, [x, x, y]))


Above each symbol has two variants, one with a repeated name 'xx' or 'qq', and one with a single name 'x' or 'q'.
The symbols with the single name are commutative. The symbols with the repeated name are not commutative.
In order to use differential operators we need non-commutative symbols. For example note that the first expression is simplified to zero
while the second is not. We want the second behavior for differential operators, so here the 'double' symbols need to be used.



0

-q*D(x) + D(x)*q


In order to evaluate this Differential operator as if it is being applied to the function one, we use



(q)_{x}


Note that outside of evaluateExpr we do not think of the expression as a differential operator anymore, instead just a function.
There we use 'single' variable names.



-(f)_{x}*q + (q)_{x}

(q)_{x}*f


We can also take derivatives in several directions and use coordinates



(f)_{xy}*x

(f)_{xy}*x*y + (f)_{x}*x


We can also consider matrix operators



Matrix([
[(f)_{x} + \lambda*g],
[(g)_{x} - \lambda*f]])


In the computations below we don't use the full differential operator machinery (yet).
Just taking derivatives of differentiable symbols can be done as below.
We also demonstrate that sympy has some inbuilt functions.



(f)_{x}

(p)_{xx}

(a)_{xxx}*exp(a) + 2*(a)_{xx}*(a)_{x}*exp(a) + (a)_{x}*((a)_{xx}*exp(a) + (a)_{x}**2*exp(a))

8*exp(2*x)

(u)_{txy}

3*y**2*exp(x/2)/4

# Expansions - General

In [6]:
def series_expansion(expr, x0, var, num):
    """Collect the coefficients of a truncated series of ``expr`` by power of ``var``.

    ``expr`` is passed through ``realify``, expanded with ``series`` around
    ``x0`` up to order ``num`` in the variable obtained from ``var`` by
    replacing ``z`` with ``tz`` and ``lam`` with ``tlam``, stripped of its
    order term and passed through ``complexify``.

    Returns a plain ``dict`` mapping each exponent of ``var`` to the sum of the
    terms carrying that exponent, with ``var`` set to 1 in each term.

    NOTE(review): when no term is collected the function returns the tuple
    ``(None, None)`` instead of a dict; callers should be checked against that.
    """
    expansion_variable = var.subs(z, tz).subs(lam, tlam)
    truncated = series(realify(expr), x0=x0, x=expansion_variable, n=num).removeO()
    coefficients_by_power = defaultdict(int)
    for summand in complexify(truncated).as_ordered_terms():
        power = summand.as_powers_dict().get(var, 0)
        coefficients_by_power[power] += summand.subs(var, 1)
    if len(coefficients_by_power) == 0:
        return None, None
    return dict(coefficients_by_power)

# Highest index of the sigma / r coefficient families created below.
N = 13


def _commutative_by_index(names):
    """Register ``names`` as differentiable symbols and index the commutative variants.

    ``D.create_diff_symbols`` returns the two variants of every name in turn;
    the odd positions are kept and stored under the indices 0..N.
    """
    both_variants = D.create_diff_symbols(*names)
    commutative = both_variants[1::2]
    return {index: commutative[index] for index in range(0, N+1)}


sig_syms = _commutative_by_index([f'\\sigma_{{{n}}}' for n in range(0, N+1)])
sigt_syms = _commutative_by_index([f'\\tilde{{\\sigma_{{{n}}}}}' for n in range(0, N+1)])

r_syms = _commutative_by_index([f'r_{{{n}}}' for n in range(0, N+1)])
R_syms = _commutative_by_index([f'R_{{{n}}}' for n in range(0, N+1)])
rt_syms = _commutative_by_index([f'\\tilde{{r_{{{n}}}}}' for n in range(0, N+1)])


# Expansions - NLS

In [7]:
def _sigma_placeholders():
    """Fresh list [sigma_0, ..., sigma_N] of placeholder symbols."""
    return [sig_syms[index] for index in range(N+1)]


def _next_sigma(sequence, n):
    """Entry n+1 of ``sequence``: q * (sequence[n] / q)_x + sum_{k=0..n} sequence[k] * sequence[n-k]."""
    total = q * deriv(sequence[n] / q)
    for k in range(0, n + 1):
        total += sequence[k] * sequence[n-k]
    return simplify(total)


sig_eq_NLS_defocusing = _sigma_placeholders()
sig_eq_NLS_focusing = _sigma_placeholders()
sig_eq_GP_defocusing = _sigma_placeholders()
sig_eq_GP_focusing = _sigma_placeholders()

# Seed values; the focusing and defocusing NLS sequences differ only in the sign of entry 1.
sig_eq_NLS_defocusing[0], sig_eq_NLS_defocusing[1] = 0, - q * q_conj
sig_eq_NLS_focusing[0], sig_eq_NLS_focusing[1] = 0, q * q_conj
sig_eq_GP_defocusing[0], sig_eq_GP_defocusing[1] = 0, 1 - q * q_conj


# Fill entries 2..N of both NLS sequences; the printed counter shows progress.
for n in range(1, N):
    print("n =", n)
    sig_eq_NLS_defocusing[n+1] = _next_sigma(sig_eq_NLS_defocusing, n)
    sig_eq_NLS_focusing[n+1] = _next_sigma(sig_eq_NLS_focusing, n)

n = 1
n = 2
n = 3
n = 4
n = 5
n = 6
n = 7
n = 8
n = 9
n = 10
n = 11
n = 12


In [8]:
# Express every GP coefficient as a binomial-weighted combination of the NLS
# coefficients of the same parity; odd entries additionally receive catalan(nn).
for n in range(N+1):
    nn = n // 2
    if n % 2 == 0:
        # Even index n == 2*nn: combine the even NLS entries 0, 2, ..., n.
        sig_eq_GP_defocusing[n] = sum(
            binomial(nn-1, nn-m) * 4**(nn-m) * sig_eq_NLS_defocusing[2*m]
            for m in range(nn+1)
        )
    else:
        # Odd index n == 2*nn+1: combine the odd NLS entries 1, 3, ..., n.
        odd_part = sum(
            binomial(nn-Rational(1,2), nn-m) * 4**(nn-m) * sig_eq_NLS_defocusing[2*m+1]
            for m in range(nn+1)
        )
        sig_eq_GP_defocusing[n] = odd_part + catalan(nn)


In [11]:
# Debug view: pair up the commutative and non-commutative registries of D entry by entry.
registered_pairs = list(zip(D.diff_symbols, D.diff_symbols_nc))
display(registered_pairs)

[(q, q),
 (\tilde{q}, \tilde{q}),
 (p, p),
 (\tilde{p}, \tilde{p}),
 (w, w),
 (\tilde{w}, \tilde{w}),
 (v, v),
 (\tilde{v}, \tilde{v}),
 (A, A),
 (a, a),
 (u, u),
 (\phi, \phi),
 (f, f),
 (g, g),
 (W_+, W_+),
 (W_-, W_-),
 (w_+, w_+),
 (w_-, w_-),
 (\sigma, \sigma),
 (r, r),
 (\tilde{\sigma}, \tilde{\sigma}),
 (\tilde{r}, \tilde{r}),
 ((q)_{x}, (q)_{x}),
 ((q)_{x}, (f)_{x}),
 ((f)_{x}, (f)_{xy}),
 ((f)_{xy}, (g)_{x}),
 ((g)_{x}, (p)_{x}),
 ((p)_{x}, (p)_{xx}),
 ((p)_{xx}, (a)_{x}),
 ((a)_{x}, (a)_{xx}),
 ((a)_{xx}, (a)_{xxx}),
 ((a)_{xxx}, (u)_{x}),
 ((u)_{x}, (u)_{xy}),
 ((u)_{xy}, (u)_{txy}),
 ((u)_{txy}, \sigma_{0}),
 (\sigma_{0}, \sigma_{1}),
 (\sigma_{1}, \sigma_{2}),
 (\sigma_{2}, \sigma_{3}),
 (\sigma_{3}, \sigma_{4}),
 (\sigma_{4}, \sigma_{5}),
 (\sigma_{5}, \sigma_{6}),
 (\sigma_{6}, \sigma_{7}),
 (\sigma_{7}, \sigma_{8}),
 (\sigma_{8}, \sigma_{9}),
 (\sigma_{9}, \sigma_{10}),
 (\sigma_{10}, \sigma_{11}),
 (\sigma_{11}, \sigma_{12}),
 (\sigma_{12}, \sigma_{13}),
 (\sigma_{13},

In [10]:
# Replacement rules used to reduce expressions in w: second derivatives of w and
# w_conj are replaced by cubic terms, and w_conj is identified with w.
subs_list_NLS_focusing = [
    (higher_deriv(w, 2), - 2 * w**2 * w_conj),
    (higher_deriv(w_conj, 2), - 2 * w_conj**2 * w),
    (w_conj, w)
]
# Same rules plus a replacement for (w_x)^2. This list is not used in this cell.
subs_list_NLS_alt_focusing = [
    (higher_deriv(w, 2), - 2 * w**2 * w_conj),
    (higher_deriv(w_conj, 2), - 2 * w_conj**2 * w),
    (w_conj, w),
    (deriv(w)**2, Eps**2 - w**4)
]
# ops[n] is built from the variation of sigma_n in q_conj along p_conj;
# ops_adjoint[n] is operator_adjoint applied to it.
ops = []
ops_adjoint = []
for n in range(0, N+1):
    print(f"-----{n}-----")
    op = operator_from_linear(func_deriv(sig_eq_NLS_focusing[n], q_conj, p_conj), p_conj)
    ops.append(op)
    op_adjoint = operator_adjoint(op)
    ops_adjoint.append(op_adjoint)
    if n > 1:
        # Check 1: residual of the sigma recursion itself; it is displayed as Eq(0, residual).
        X = sig_eq_NLS_focusing[n]
        X -= deriv(sig_eq_NLS_focusing[n-1]) - deriv(q) / q * sig_eq_NLS_focusing[n-1]
        for k in range(n):
            X -= sig_eq_NLS_focusing[k] * sig_eq_NLS_focusing[n-1-k]
        display(Equ(0, simplify(X)))
        
        # Check 2: the same recursion at the level of the operators, applied to ff.
        # NOTE(review): the operator recomputed on the next line is built from the same
        # expression as ops[n-1] - confirm whether ops[n-1] could be reused.
        Y = op
        Y -= D(xx) * operator_from_linear(func_deriv(sig_eq_NLS_focusing[n-1], q_conj, p_conj), p_conj)
        Y -= - deriv(qq) / qq  * ops[n-1]
        for k in range(n):
            Y -= D.comm_to_non_comm(sig_eq_NLS_focusing[k]) * ops[n-1-k]
            Y -= D.comm_to_non_comm(sig_eq_NLS_focusing[n-1-k]) * ops[k]
        display(Equ(0, simplify(evaluateExpr(Y * ff))))
        
        # Check 3: the corresponding identity for the adjoint operators: factors appear
        # in reversed order and D(xx) enters with the opposite sign.
        Z = ops_adjoint[n]
        Z -= - operator_adjoint(operator_from_linear(func_deriv(sig_eq_NLS_focusing[n-1], q_conj, p_conj), p_conj)) * D(xx)
        Z -= - ops_adjoint[n-1] * deriv(qq) / qq
        for k in range(n):
            Z -= ops_adjoint[n-1-k] * D.comm_to_non_comm(sig_eq_NLS_focusing[k])
            Z -= ops_adjoint[k] * D.comm_to_non_comm(sig_eq_NLS_focusing[n-1-k])
        display(Equ(0, simplify(evaluateExpr(Z * ff))))
                
    # Show the adjoint applied to ff with q and q_conj both set to w, reduced by two
    # substitute-and-simplify rounds (the rule list is repeated 5 times per round).
    display(
        simplify(subs(
        simplify(subs(
        subs(evaluateExpr(op_adjoint * ff), [(q, w), (q_conj, w)])
        , subs_list_NLS_focusing * 5))
        , subs_list_NLS_focusing * 5))
    )

-----0-----


0

-----1-----


f*w

-----2-----


Eq(0, 0)

Eq(0, 0)

Eq(0, 0)

-(w)_{x}*(f + w)

-----3-----


Eq(0, 0)

Eq(0, 0)

Eq(0, 0)

(\tilde{p})_{xx}*w + 2*(w)_{x}**2 + 2*f*w**3 - 12*f*w*((w)_{x}**2 - w**4)

-----4-----


Eq(0, \tilde{q}**2*q*((q)_{x} - (q)_{x}))

Eq(0, 2*f*q*((\tilde{q})_{x}*q + (q)_{x}*\tilde{q} - (q)_{x}*\tilde{q} - \tilde{r_{13}}*q))

Eq(0, 2*f*q*((\tilde{q})_{x}*q + (q)_{x}*\tilde{q} - (q)_{x}*\tilde{q} - \tilde{r_{13}}*q))

-(\tilde{p})_{xxx}*w - 3*(\tilde{p})_{xx}*(w)_{x} - 2*(w)_{x}*f*w**2 - 4*(w)_{x}*w**3 + 36*(w)_{x}*w*((w)_{x}**2 - w**4) - 4*\tilde{r_{13}}*f*w**2 + 2*f*w**3

-----5-----


Eq(0, 4*(\tilde{q})_{x}*\tilde{q}*q*((q)_{x} - (q)_{x}))

Eq(0, 2*q*(-(\tilde{p})_{xx}*\tilde{q}*f + 3*(\tilde{q})_{x}*(q)_{x}*f + 2*(\tilde{q})_{x}*(q)_{x}*q - 2*(\tilde{q})_{x}*(q)_{x}*f + (q)_{xx}*\tilde{q}*f + 2*(q)_{x}**2*\tilde{q} - 2*(q)_{x}*(q)_{x}*\tilde{q} - (q)_{x}*\tilde{r_{13}}*f - 2*(q)_{x}*\tilde{r_{13}}*q))

Eq(0, -6*(\tilde{p})_{xx}*\tilde{q}*f*q + 4*(\tilde{q})_{xxxx}*\tilde{q}*f*q - 4*(\tilde{q})_{xx}*f*q**2 + 6*(\tilde{q})_{x}*(q)_{x}*f*q - 4*(\tilde{q})_{x}*(q)_{x}*q**2 - 12*(\tilde{q})_{x}*(q)_{x}*f*q + 4*(\tilde{q})_{x}*f*q**2 + 2*(q)_{xx}*\tilde{q}*f*q - 4*(q)_{x}**2*\tilde{q}*q - 4*(q)_{x}*(q)_{x}*\tilde{q}*f + 4*(q)_{x}*(q)_{x}*\tilde{q}*q - 6*(q)_{x}*\tilde{r_{13}}*f*q + 4*(q)_{x}*\tilde{r_{13}}*q**2 + 4*(q)_{x}**2*\tilde{q}*f + 12*(q)_{x}*\tilde{r_{13}}*f*q)

(\tilde{p})_{xxxx}*w + 4*(\tilde{p})_{xxx}*(w)_{x} - 6*(\tilde{p})_{xx}*f*w**2 + 6*(\tilde{p})_{xx}*w**3 - 72*(\tilde{p})_{xx}*w*((w)_{x}**2 - w**4) - 8*(w)_{x}**2*f*w + 8*(w)_{x}**2*w**2 + 18*(w)_{x}*\tilde{r_{13}}*f*w + 12*(w)_{x}*\tilde{r_{13}}*w**2 + 6*(w)_{x}*f*w**2 - 12*(w)_{x}*f*((w)_{x}**2 - 9*w**4) - 8*(w)_{x}*w**3 + 10*f*w**5 - 144*f*w**3*((w)_{x}**2 - w**4)

-----6-----


Eq(0, q*(6*(\tilde{q})_{xx}*(q)_{x}*\tilde{q} - 6*(\tilde{q})_{xx}*(q)_{x}*\tilde{q} + 5*(\tilde{q})_{x}**2*(q)_{x} - 5*(\tilde{q})_{x}**2*(q)_{x} + 4*(q)_{x}*\tilde{q}**3*q - 4*(q)_{x}*\tilde{q}**3*q))

Eq(0, 2*q*(-3*(\tilde{p})_{xx}*(\tilde{q})_{x}*f + 3*(\tilde{p})_{xx}*(\tilde{q})_{x}*q - 3*(\tilde{p})_{xx}*(q)_{x}*\tilde{q} - 3*(\tilde{p})_{xx}*\tilde{r_{13}}*q - (\tilde{q})_{xxxxx}*\tilde{q}*f + 3*(\tilde{q})_{xx}*(q)_{x}*f - 3*(\tilde{q})_{xx}*(q)_{x}*f + 4*(\tilde{q})_{x}*(q)_{xx}*f + 8*(\tilde{q})_{x}*(q)_{x}**2 - 5*(\tilde{q})_{x}*(q)_{x}*(q)_{x} + 6*(\tilde{q})_{x}*\tilde{q}*f*q**2 + (q)_{xxx}*\tilde{q}*f + 3*(q)_{xx}*(q)_{x}*\tilde{q} - (q)_{xx}*\tilde{r_{13}}*f - 3*(q)_{x}**2*\tilde{r_{13}} + 6*(q)_{x}*\tilde{q}**2*f*q - 6*(q)_{x}*\tilde{q}**2*f*q - 6*\tilde{q}*\tilde{r_{13}}*f*q**2))


KeyboardInterrupt



In [None]:
# For n = 1..min(N, 8): combine the variations of sigma_{n+1}, sigma_n and the
# quadratic terms, integrate by parts with respect to p_conj and show the
# resulting factor in front of p_conj.
highest_order = min(N, 8)
for n in range(1, highest_order + 1):
    residual = func_deriv(sig_eq_NLS_focusing[n+1], q_conj, p_conj)
    residual -= - deriv(q) / q * func_deriv(sig_eq_NLS_focusing[n], q_conj, p_conj)
    for k in range(n+1):
        residual -= 2 * sig_eq_NLS_focusing[k] * func_deriv(sig_eq_NLS_focusing[n-k], q_conj, p_conj)
    residual = simplify(residual)
    print(f"---{n}---")
    moved_off_p = integrate_by_parts(simplify(residual), [p_conj])
    display(simplify(moved_off_p / p_conj))

In [None]:
# Scratch computation: simplify an expression built from the x-derivative of q_x / q and f.
log_derivative_slope = deriv(deriv(q) / q)
display(simplify(2*deriv(q) * log_derivative_slope * f + q * deriv(log_derivative_slope * f)))

In [None]:
# Reduction rules in w: second derivatives become cubic terms and w_conj is identified with w.
subs_list_NLS_focusing = [
    (higher_deriv(w, 2), - 2 * w**2 * w_conj),
    (higher_deriv(w_conj, 2), - 2 * w_conj**2 * w),
    (w_conj, w)
]
# The "alt" list additionally replaces (w_x)^2.
subs_list_NLS_alt_focusing = subs_list_NLS_focusing + [
    (deriv(w)**2, Eps**2 - w**4)
]

for n in range(0, min(N, 80)+1):
    # Variation of sigma_n in q_conj, moved off p_conj by integration by parts, then q = q_conj = w.
    linearised = func_deriv(sig_eq_NLS_focusing[n], q_conj, p_conj)
    expr2 = simplify(subs(simplify(integrate_by_parts(linearised, [p_conj]) / p_conj), [(q, w), (q_conj, w)]))
    display(Equ(sig_syms[n], expr2.expand() if hasattr(expr2, 'expand') else expr2))
    # Derivative of expr2 divided by w and by w_x, each reduced by three
    # substitute-and-simplify rounds.
    for weight in (1 / w, 1 / deriv(w)):
        reduced = deriv(weight * expr2)
        for reduction_round in range(3):
            reduced = simplify(subs(reduced, subs_list_NLS_alt_focusing * 5))
        display(Equ(sig_syms[n], reduced))

In [None]:
# NOTE: this cell reads stationary_NLS_focusing, which is only built in the cell
# that defines H_NLS_focusing further down; run that cell first.
#
# The keys of stationary_NLS_focusing run from -1 to N-1 (one per entry of
# sig_eq_NLS_focusing, shifted by one), so the loop must stop at N-1. The previous
# upper bound min(N, 80) made the last iteration fail with a KeyError.
for n in range(2, min(N - 1, 80) + 1):
    expr2 = simplify(subs(stationary_NLS_focusing[n], [(q, w), (q_conj, w)])).expand()
    display(Equ(sig_syms[n], expr2.expand() if hasattr(expr2, 'expand') else expr2))
    # Derivative of the previous stationary expression divided by w and by w_x,
    # each reduced by three substitute-and-simplify rounds.
    for weight in (1 / w, 1 / deriv(w)):
        reduced = deriv(weight * stationary_NLS_focusing[n-1])
        for reduction_round in range(3):
            reduced = simplify(subs(reduced, subs_list_NLS_alt_focusing * 5))
        display(Equ(sig_syms[n], reduced))

In [None]:
# Show the focusing NLS expression for sigma_0 .. sigma_min(N, 8).
# expr2 / expr3 hold the defocusing NLS and GP counterparts; swap them into the
# display call below to inspect those instead.
for n in range(min(N, 8)+1):
    expr1, expr2, expr3 = sig_eq_NLS_focusing[n], sig_eq_NLS_defocusing[n], sig_eq_GP_defocusing[n]
    # Plain Python numbers (such as the 0 stored at index 0) have no .expand().
    shown = expr1.expand() if hasattr(expr1, 'expand') else expr1
    display(Equ(sig_syms[n], shown))

In [None]:
def _densities_from_sigma(sig_eq):
    """Map index n-1 to -(-I)**(n-1) * sig_eq[n] for every entry of ``sig_eq``."""
    return {n-1: - (-I)**(n-1) * e for n, e in enumerate(sig_eq)}


def _rescaled(densities):
    """Apply the x -> sqrt(2) scaling through ``subs`` and the prefactor 1 / 2**(n//2).

    For n = -1 the exponent n//2 is -1, so the denominator passed to Rational is 0.5.
    """
    return {n: Rational(1,2**(n//2)) * subs(densities[n], [(q, q), (q_conj, q_conj)], scale={x: sqrt(2)}) for n in densities.keys()}


def _stationary_equations(densities):
    """Variation of every density in q_conj, evaluated at q = q_conj = w and simplified."""
    return {n: simplify(subs(variation(expr, q_conj), [(q, w), (q_conj, w)])) for n, expr in densities.items()}


H_NLS_focusing = _densities_from_sigma(sig_eq_NLS_focusing)
H_NLS_defocusing = _densities_from_sigma(sig_eq_NLS_defocusing)
H_GP_defocusing = _densities_from_sigma(sig_eq_GP_defocusing)

# Entry -1 is set by hand; the focusing case carries the opposite sign.
H_NLS_focusing[-1] = - Rational(1, 2) * I * (deriv(q) / q - deriv(q_conj) / q_conj)
H_NLS_defocusing[-1] = Rational(1, 2) * I * (deriv(q) / q - deriv(q_conj) / q_conj)
H_GP_defocusing[-1] = Rational(1, 2) * I * (deriv(q) / q - deriv(q_conj) / q_conj)

H_NLS_focusing = _rescaled(H_NLS_focusing)
H_NLS_defocusing = _rescaled(H_NLS_defocusing)
H_GP_defocusing = _rescaled(H_GP_defocusing)

# The printed markers show which of the three (slow) families is being processed.
print("-1-")
stationary_NLS_focusing = _stationary_equations(H_NLS_focusing)
print("-2-")
stationary_NLS_defocusing = _stationary_equations(H_NLS_defocusing)
print("-3-")
stationary_GP_defocusing = _stationary_equations(H_GP_defocusing)


## Critical points - defocusing case

In [None]:

# Hand-written densities (suffix GaPe) that are compared with the computed ones below.
gradient_product = deriv(q) * deriv(q_conj)
E_GaPe = gradient_product + Rational(1,2) * (1 - q * q_conj)**2
Q_GaPe = q * q_conj
Q_GaPe_2 = q * q_conj - 1
M_GaPe = I / 2 * (q_conj * deriv(q) - q * deriv(q_conj))
R_GaPe = (
    higher_deriv(q, 2) * higher_deriv(q_conj, 2)
    + 3 * q * q_conj * gradient_product
    + (q * deriv(q_conj) + q_conj * deriv(q))**2 / 2
    + q**3 * q_conj**3 / 2
)
# The "_2" variants differ from the plain ones by a constant shift only.
R_GaPe_2 = R_GaPe - Rational(1,2)
S_GaPe = R_GaPe - Rational(1, 2) * (3 - Eps**2) * Q_GaPe
S_GaPe_2 = R_GaPe_2 - Rational(1, 2) * (3 - Eps**2) * Q_GaPe_2

# For each pair, show the variations in q and q_conj of (computed - hand-written).
for error_label, mismatch in [
    ('\\text{Error for }Q_{\\text{GaPe}}', H_NLS_defocusing[0] - Q_GaPe),
    ('\\text{Error for }E_{\\text{GaPe}}', H_GP_defocusing[2] - E_GaPe),
    ('\\text{Error for }R_{\\text{GaPe}}', H_NLS_defocusing[4] - R_GaPe),
]:
    display(Equ(Symbol(error_label), Matrix([variation(mismatch, q), variation(mismatch, q_conj)])))


# Reduction rules for the defocusing case: second derivatives of w / w_conj are
# replaced by cubic-minus-linear terms, and w_conj is identified with w.
subs_list_GP_defocusing = [
    (higher_deriv(w, 2), w**2 * w_conj - w),
    (higher_deriv(w_conj, 2), w_conj**2 * w - w_conj),
    (w_conj, w)
]
# The "alt" list additionally replaces (w_x)^2.
subs_list_GP_alt_defocusing = subs_list_GP_defocusing + [
    (deriv(w)**2, Rational(1, 2) * ((w**2 - 1)**2-Eps**2))
]

# All LaTeX labels below are assembled from plain strings with every backslash
# escaped. The former f-strings contained '\delta', an invalid escape sequence
# that triggers a warning on recent Python versions. The rendered labels are unchanged.
defocusing_constraint = '\\Bigg\\vert_{\\substack{(w)_{xx} = w^3 - w \\\\ 2 (w)_x^2 = (w^2 - 1)^2 - \\mathcal{E}^2}}'

# 1) The densities themselves: E for even n, P for odd n, with superscript n//2.
for n in range(-1, min(N, 8)):
    kind = "E" if n%2==0 else "P"
    for model, densities in (("NLS", H_NLS_defocusing), ("GP", H_GP_defocusing)):
        label = '\\mathcal{' + kind + '}_{\\text{' + model + '}}^{' + str(n//2) + '}(q, \\tilde{q})'
        display(Equ(Symbol(label), densities[n].expand()))

# 2) Their variations evaluated at q = q_conj = w.
for n in range(-1, min(N, 8)):
    kind = "E" if n%2==0 else "P"
    for model, stationary in (("NLS", stationary_NLS_defocusing), ("GP", stationary_GP_defocusing)):
        label = '\\frac{\\delta \\mathcal{' + kind + '}_{\\text{' + model + '}}^{' + str(n//2) + '}(w)}{\\delta w}'
        display(Equ(Symbol(label), stationary[n].expand()))

# 3) The same variations reduced with the defocusing rules. The rule list is
#    repeated n times per round, so n <= 0 applies no rule at all.
for n in range(-1, min(N, 14)):
    kind = "E" if n%2==0 else "P"
    rules = subs_list_GP_alt_defocusing*n
    for model, stationary in (("NLS", stationary_NLS_defocusing), ("GP", stationary_GP_defocusing)):
        label = '\\frac{\\delta \\mathcal{' + kind + '}_{\\text{' + model + '}}^{' + str(n//2) + '}(w)}{\\delta w}' + defocusing_constraint
        display(Equ(Symbol(label), simplify(subs(simplify(subs(stationary[n], rules)), rules))))

# 4) x-derivative of the GP variation divided by w (even n) or by w_x (odd n), reduced again.
for n in range(-1, min(N, 14)):
    kind = "E" if n%2==0 else "P"
    factor1 = 1 / w if n%2==0 else 1 / deriv(w)
    # Not used in this cell; kept because it was defined here before.
    sqrtstring = "\\sqrt{{2}}"
    rules = subs_list_GP_alt_defocusing * n
    label = (
        '\\Bigg(\\frac{1}{' + ("w" if n%2==0 else "(w)_x") + '} '
        + '\\frac{\\delta \\mathcal{' + kind + '}_{\\text{GP}}^{' + str(n//2) + '}(w)}{\\delta w}'
        + '\\Bigg)_x \\Bigg\\vert_{(w)_{xx} = w^3 - w}'
    )
    display(Equ(Symbol(label), simplify(subs(
        deriv(factor1 * simplify(subs(stationary_GP_defocusing[n], rules)))
    , rules))
    ))
    
# for n in range(min(N+1, 14)):
#     if n%2==0:
#         display(latex(Equ(Symbol(f'\\Bigg(\\frac{{1}}{{(w)_x}} \\frac{{\\delta \\mathcal{{{"E" if n%2==1 else "P"}}}_{{\\text{{GP}}}}^{{{n//2}}}(w)}}{{\delta w}}\\Bigg)_x \\Bigg\\vert_{{(w)_{{xxx}} = 3 (w)_x w^2 - 2 (w)_x}}'), simplify(subs(
#             simplify(subs(
#                 deriv( 1 / deriv(w) * simplify(subs(
#                     stationary_GP_eqs[n]
#                 , subs_list2 * 5)))
#             , subs_list2 * 5))
#         , subs_list2 * 5))
#         )))
#     if n%2==1:
#         display(latex(Equ(Symbol(f'\\Bigg(\\frac{{1}}{{w}}\\Bigg(\\frac{{\\delta \\mathcal{{{"E" if n%2==1 else "P"}}}_{{\\text{{GP}}}}^{{{n//2}}}(w)}}{{\delta w}}\\Bigg)_x \Bigg)_x \\Bigg\\vert_{{(w)_{{xxx}} = 3 (w)_x w^2 - 2 (w)_x}}'), simplify(subs(
#             deriv(simplify(subs(
#                 1 / deriv(w) * deriv(simplify(subs(
#                     stationary_GP_eqs[n]
#                 , subs_list2 * 5)))
#             , subs_list2 * 5)))
#         , subs_list2 * 5))
#         )))
    



## Critical points - focusing case

In [None]:

# Reduction rules for the focusing case: second derivatives of w / w_conj are
# replaced by cubic terms, and w_conj is identified with w.
subs_list_NLS_focusing = [
    (higher_deriv(w, 2), - w**2 * w_conj),
    (higher_deriv(w_conj, 2), - w_conj**2 * w),
    (w_conj, w)
]
# The "alt" list additionally replaces (w_x)^2.
subs_list_NLS_alt_focusing = subs_list_NLS_focusing + [
    (deriv(w)**2, Rational(1, 2) * (Eps**2 - w**4))
]

# All LaTeX labels below are assembled from plain strings with every backslash
# escaped. The former f-strings contained '\delta', an invalid escape sequence
# that triggers a warning on recent Python versions.
focusing_constraint = '\\Bigg\\vert_{\\substack{(w)_{xx} = - w^3 \\\\ 2 (w_x)^2 = \\mathcal{E}^2 - w^4}}'

# 1) The densities themselves: E for even n, P for odd n, with superscript n//2.
for n in range(-1, min(N, 8)):
    kind = "E" if n%2==0 else "P"
    label = '\\mathcal{' + kind + '}_{\\text{NLS}}^{' + str(n//2) + '}(q, \\tilde{q})'
    display(Equ(Symbol(label), H_NLS_focusing[n].expand()))

# 2) Their variations evaluated at q = q_conj = w.
for n in range(-1, min(N, 8)):
    kind = "E" if n%2==0 else "P"
    label = '\\frac{\\delta \\mathcal{' + kind + '}_{\\text{NLS}}^{' + str(n//2) + '}(w)}{\\delta w}'
    display(Equ(Symbol(label), stationary_NLS_focusing[n].expand()))

# 3) The same variations reduced with the focusing rules. The rule list is
#    repeated n times per round, so n <= 0 applies no rule at all.
for n in range(-1, min(N, 14)):
    kind = "E" if n%2==0 else "P"
    rules = subs_list_NLS_alt_focusing*n
    label = '\\frac{\\delta \\mathcal{' + kind + '}_{\\text{NLS}}^{' + str(n//2) + '}(w)}{\\delta w}' + focusing_constraint
    display(Equ(Symbol(label), simplify(subs(simplify(subs(stationary_NLS_focusing[n], rules)), rules))))

# 4) x-derivative of the variation divided by w (even n) or by w_x (odd n), reduced again.
#    The label used to say "GP" although the focusing NLS quantities are shown;
#    it now says "NLS" like the other labels of this cell.
for n in range(-1, min(N, 14)):
    kind = "E" if n%2==0 else "P"
    factor1 = 1 / w if n%2==0 else 1 / deriv(w)
    # Not used in this cell; kept because it was defined here before.
    sqrtstring = "\\sqrt{{2}}"
    rules = subs_list_NLS_focusing * n
    label = (
        '\\Bigg(\\frac{1}{' + ("w" if n%2==0 else "(w)_x") + '} '
        + '\\frac{\\delta \\mathcal{' + kind + '}_{\\text{NLS}}^{' + str(n//2) + '}(w)}{\\delta w}'
        + '\\Bigg)_x \\Bigg\\vert_{(w)_{xx} = - w^3}'
    )
    display(Equ(Symbol(label),
        simplify(subs(deriv(factor1 * simplify(subs(stationary_NLS_focusing[n], rules))), rules))
    ))
    

## Second order variations - focusing case

In [None]:
# Perturb around w: q = w + f + i g and q_conj = w + f - i g.
subs_list = [(q, w + f + I * g), (q_conj, w + f - I * g)]
# Non-commutative copy of the focusing reduction rules, for use on operator expressions.
subs_list_NLS_alt_focusing_noncomm = [(D.comm_to_non_comm(tup[0]), D.comm_to_non_comm(tup[1])) for tup in subs_list_NLS_alt_focusing]

# The label is assembled from plain strings with every backslash escaped ('\delta'
# inside the former f-string was an invalid escape sequence). Its constraint now
# quotes the focusing relations that are actually substituted. It used to quote
# the defocusing GP relations.
focusing_constraint = '\\Bigg\\vert_{\\substack{(w)_{xx} = - w^3 \\\\ 2 (w_x)^2 = \\mathcal{E}^2 - w^4}}'

for n in range(0, N):
    # Part of the perturbed density that is quadratic in (f, g), turned into a 2x2 operator.
    bilinear = group_by_orders(subs(H_NLS_focusing[n], subs_list), [f, g]).get(2, 0)
    op = operator_from_bilinear(bilinear, [f, g])
    # Optional consistency check of `op` against `bilinear`:
    #diff = simplify(bilinear - (Matrix([f, g]).transpose() * evaluateExpr(op * Matrix([ff, gg])))[0,0])
    #display(Equ(0, simplify(variation(diff, f))**2 + simplify(variation(diff, g))**2 + simplify(variation(diff, w))**2))
    rules = subs_list_NLS_alt_focusing_noncomm*n
    new_op = simplify(subs(simplify(subs(op, rules)), rules)).expand()
    kind = "E" if n%2==0 else "P"
    label = '\\frac{\\delta \\mathcal{' + kind + '}_{\\text{NLS}}^{' + str(n//2) + '}(w)}{\\delta w \\delta w}' + focusing_constraint
    # The four operator entries are shown as one column: (0,0), (0,1), (1,0), (1,1).
    display(Equ(Symbol(label),
        Matrix([new_op[0,0], new_op[0,1], new_op[1,0], new_op[1,1]])))

## Second order variations - defocusing case

In [None]:
# This cell used to depend on hidden state: it read subs_list_GP_alt_defocusing_noncomm,
# which is only defined in the following cell, and the loop variable `n` left over from
# an earlier loop. Both are now set up explicitly so the cell runs in document order.
subs_list = [(q, w + f + I * g), (q_conj, w + f - I * g)]
subs_list_GP_alt_defocusing_noncomm = [(D.comm_to_non_comm(tup[0]), D.comm_to_non_comm(tup[1])) for tup in subs_list_GP_alt_defocusing]
# Number of times the rule list is repeated per reduction round. N - 1 is the value
# the leaked `n` had after the preceding `for n in range(0, N)` loop.
rule_repeats = N - 1

# Part of S_GaPe_2 that is quadratic in the perturbation (f, g).
bilinear = group_by_orders(subs(S_GaPe_2, subs_list), [f, g]).get(2, 0)
display(Equ(Symbol("bilinear"), bilinear))
display(Equ(Symbol("bilinear ibp"), simplify(integrate_by_parts_auto(bilinear, [f, g])).expand()))

op = operator_from_bilinear(bilinear, [f, g])
# Optional consistency check of `op` against `bilinear`:
#diff = simplify(bilinear - (Matrix([f, g]).transpose() * evaluateExpr(op * Matrix([ff, gg])))[0,0])
#display(Equ(0, simplify(variation(diff, f))**2 + simplify(variation(diff, g))**2 + simplify(variation(diff, w))**2))
rules = subs_list_GP_alt_defocusing_noncomm*rule_repeats
new_op = simplify(subs(simplify(subs(op, rules)), rules)).expand()

# The labels now name the functional that is actually varied. The previous labels
# were copied from the NLS loops and depended on the leaked `n`.
# NOTE(review): confirm the preferred notation for this functional.
second_variation_label = '\\frac{\\delta^2 S_{\\text{GaPe}}(w)}{\\delta w \\delta w}'
defocusing_constraint = '\\Bigg\\vert_{\\substack{(w)_{xx} = w^3 - w \\\\ 2 (w)_x^2 = (w^2 - 1)^2 - \\mathcal{E}^2}}'
# Operator entries (0,0), (0,1), (1,0), (1,1): first as computed, then after the reduction.
display(Equ(Symbol(second_variation_label),
    Matrix([op[0,0], op[0,1], op[1,0], op[1,1]])))
display(Equ(Symbol(second_variation_label + defocusing_constraint),
    Matrix([new_op[0,0], new_op[0,1], new_op[1,0], new_op[1,1]])))
    

In [None]:
# Perturb around w: q = w + f + i g and q_conj = w + f - i g.
subs_list = [(q, w + f + I * g), (q_conj, w + f - I * g)]
# Non-commutative copy of the defocusing reduction rules, for use on operator expressions.
subs_list_GP_alt_defocusing_noncomm = [(D.comm_to_non_comm(tup[0]), D.comm_to_non_comm(tup[1])) for tup in subs_list_GP_alt_defocusing]

# The label is assembled from plain strings with every backslash escaped ('\delta'
# inside the former f-strings was an invalid escape sequence). The second relation
# now reads 2 (w)_x^2 = ..., matching the rule that is substituted. It used to read
# 2 (w^2)_x.
defocusing_constraint = '\\Bigg\\vert_{\\substack{(w)_{xx} = w^3 - w \\\\ 2 (w)_x^2 = (w^2 - 1)^2 - \\mathcal{E}^2}}'

for n in range(0, N):
    kind = "E" if n%2==0 else "P"
    rules = subs_list_GP_alt_defocusing_noncomm*n
    # GP first, then NLS, as before. Each label now carries the name of the functional
    # it was computed from; previously the two names were swapped.
    for model, densities in (("GP", H_GP_defocusing), ("NLS", H_NLS_defocusing)):
        # Part of the perturbed density that is quadratic in (f, g), turned into a 2x2 operator.
        bilinear = group_by_orders(subs(densities[n], subs_list), [f, g]).get(2, 0)
        op = operator_from_bilinear(bilinear, [f, g])
        # Optional consistency check of `op` against `bilinear`:
        #diff = simplify(bilinear - (Matrix([f, g]).transpose() * evaluateExpr(op * Matrix([ff, gg])))[0,0])
        #display(Equ(0, simplify(variation(diff, f))**2 + simplify(variation(diff, g))**2 + simplify(variation(diff, w))**2))
        new_op = simplify(subs(simplify(subs(op, rules)), rules)).expand()
        label = '\\frac{\\delta \\mathcal{' + kind + '}_{\\text{' + model + '}}^{' + str(n//2) + '}(w)}{\\delta w \\delta w}' + defocusing_constraint
        # The four operator entries are shown as one column: (0,0), (0,1), (1,0), (1,1).
        display(Equ(Symbol(label),
            Matrix([new_op[0,0], new_op[0,1], new_op[1,0], new_op[1,1]])))