In [1]:
from sympy import *
from IPython.display import display, Latex, HTML, Markdown
init_printing()
from eqn_manip import *
from codegen_extras import *
import codegen_extras
from importlib import reload
from sympy.codegen.ast import Assignment, For, CodeBlock, real, Variable, Pointer, Declaration
from sympy.codegen.cnodes import void

## Cubic Spline solver - derivation and code generation

### Tridiagonal Solver
From Wikipedia: https://en.wikipedia.org/wiki/Tridiagonal_matrix_algorithm

In the future it would be good to derive these equations from Gaussian elimination (as on the Wikipedia page), but for now they are simply given.

In [2]:
# Problem size and row index
n = Symbol('n', integer=True)
i = Symbol('i', integer=True)

# Solution vector x, the modified coefficients c' and d' produced by the
# forward sweep, the sub-/main-/super-diagonals a, b, c and the right-hand side d
x, dp, cp, a, b, c, d = [
    IndexedBase(label, shape=(n,))
    for label in ('x', "d'", "c'", 'a', 'b', 'c', 'd')
]

In [3]:
# Forward sweep of the tridiagonal solver.
# Rows use the C++ index range 0 .. n-1 (in math notation this would be 1 .. n).
start, end = 0, n - 1

# First row
teq1 = Eq(cp[start], c[start] / b[start])
teq2 = Eq(dp[start], d[start] / b[start])

# General row i: both recurrences divide by the same pivot b[i] - c'[i-1]*a[i]
teq3 = Eq(dp[i], (d[i] - dp[i-1]*a[i]) / (b[i] - cp[i-1]*a[i]))
teq4 = Eq(cp[i], c[i] / (b[i] - cp[i-1]*a[i]))

display(teq1, teq2, teq3, teq4)



        c[0]
c'[0] = ────
        b[0]

        d[0]
d'[0] = ────
        b[0]

        -a[i]⋅d'[i - 1] + d[i]
d'[i] = ──────────────────────
        -a[i]⋅c'[i - 1] + b[i]

                 c[i]         
c'[i] = ──────────────────────
        -a[i]⋅c'[i - 1] + b[i]

In [4]:
# Backward substitution: the last unknown comes straight from d',
# every other unknown is recovered from the one after it.
teq5 = Eq(x[end], dp[end])
teq6 = Eq(x[i], dp[i] - cp[i]*x[i+1])
display(teq5, teq6)

x[n - 1] = d'[n - 1]

x[i] = -c'[i]⋅x[i + 1] + d'[i]

### Cubic Spline equations
Start with uniform knot spacing.  The derivation is easier to see than in the case with general knot spacing.

In [5]:
# Uniform knot spacing: t is the distance from the previous knot
t = Symbol('t')

# Number of knots and the knot index
n, i = Symbol('n', integer=True), Symbol('i', integer=True)
# Function values to be interpolated at the knots
y = IndexedBase('y', shape=(n,))

# Coefficients of the spline polynomial.
# NOTE: these reuse the names a, b, c, d of the tridiagonal solver above.  SymPy
# compares IndexedBase objects structurally (name and shape), so substitution
# keys built from these later also match the solver's a, b, c, d.
a, b, c, d = (IndexedBase(name, shape=(n,)) for name in 'abcd')

# Cubic spline polynomial
s = a + b*t + c*t**2 + d*t**3
display(Eq(y, s))

# The same polynomial for segment i
si = a[i] + b[i]*t + c[i]*t**2 + d[i]*t**3
display(Eq(y[i], si))

     3      2            
y = t ⋅d + t ⋅c + t⋅b + a

        3         2                     
y[i] = t ⋅d[i] + t ⋅c[i] + t⋅b[i] + a[i]

### Strategy
To eventually reduce the equations to a tridiagonal form, express the equations in terms of the second derivative ($E$).
See the MathWorld page for cubic splines, which derives the equations in terms of the first derivative ($D$).

http://mathworld.wolfram.com/CubicSpline.html

In [6]:
# Interpolation condition at the left end of segment i (t = 0)
sp1 = Eq(si.subs({t: 0}), y[i])
sp1

a[i] = y[i]

In [7]:
# Interpolation condition at the right end of segment i (t = 1 for unit spacing)
sp2 = Eq(si.subs({t: 1}), y[i + 1])
sp2

a[i] + b[i] + c[i] + d[i] = y[i + 1]

In [8]:
# E[i] denotes the second derivative at the start of segment i (t = 0)
E = IndexedBase('E', shape=(n,))
sp3 = Eq(E[i], si.diff(t, t).subs(t, 0))
sp3

E[i] = 2⋅c[i]

In [9]:
# E[i+1] denotes the second derivative at the end of segment i (t = 1)
sp4 = Eq(E[i + 1], si.diff(t, t).subs(t, 1))
sp4

E[i + 1] = 2⋅c[i] + 6⋅d[i]

In [10]:
# First derivatives must agree at the shared knot: end of segment i (t = 1)
# versus start of segment i+1 (t = 0)
sp5 = Eq(si.diff(t).subs(t, 1), si.diff(t).subs(t, 0).subs(i, i + 1))
sp5

b[i] + 2⋅c[i] + 3⋅d[i] = b[i + 1]

### For general spacing of the knots

In [11]:
# General knot spacing: segment i has length L[i] = x[i+1] - x[i],
# and t is the distance from knot i (0 <= t <= L[i]).
L = IndexedBase('L', shape=(n,))
x = IndexedBase('x', shape=(n,))
t = Symbol('t')

# Cubic polynomial on segment i
si = a[i] + b[i]*t + c[i]*t**2 + d[i]*t**3

In [12]:
# Interpolation condition at knot i (t = 0)
sp1 = Eq(si.subs({t: 0}), y[i])
sp1

a[i] = y[i]

In [13]:
# Interpolation condition at knot i+1 (t = L[i])
sp2 = Eq(si.subs({t: L[i]}), y[i + 1])
sp2

    3            2                                   
L[i] ⋅d[i] + L[i] ⋅c[i] + L[i]⋅b[i] + a[i] = y[i + 1]

In [14]:
# E[i] denotes the second derivative at the start of segment i (t = 0)
E = IndexedBase('E', shape=(n,))
sp3 = Eq(E[i], si.diff(t, t).subs(t, 0))
sp3

E[i] = 2⋅c[i]

In [15]:
# E[i+1] denotes the second derivative at the end of segment i (t = L[i])
sp4 = Eq(E[i + 1], si.diff(t, t).subs(t, L[i]))
sp4

E[i + 1] = 6⋅L[i]⋅d[i] + 2⋅c[i]

In [16]:
 # Solve for spline coefficients in terms of E's
sln = solve([sp1,sp2,sp3,sp4], [a[i],b[i],c[i],d[i]])
sln

⎧                                            2                                
⎪                    (E[i + 1] + 2⋅E[i])⋅L[i]                                 
⎪                  - ───────────────────────── + y[i + 1] - y[i]              
⎨                                6                                      E[i]  
⎪a[i]: y[i], b[i]: ─────────────────────────────────────────────, c[i]: ────, 
⎪                                       L[i]                             2    
⎩                                                                             

                     ⎫
                     ⎪
                     ⎪
      E[i + 1] - E[i]⎬
d[i]: ───────────────⎪
           6⋅L[i]    ⎪
                     ⎭

In [17]:
# The same solution shifted to segment i+1
sln1 = dict(
    (key.subs(i, i + 1), val.subs(i, i + 1)) for key, val in sln.items()
)
sln1

⎧                                                                2            
⎪                                (2⋅E[i + 1] + E[i + 2])⋅L[i + 1]             
⎪                              - ───────────────────────────────── - y[i + 1] 
⎨                                                6                            
⎪a[i + 1]: y[i + 1], b[i + 1]: ───────────────────────────────────────────────
⎪                                                       L[i + 1]              
⎩                                                                             

                                                              ⎫
                                                              ⎪
+ y[i + 2]                                                    ⎪
                      E[i + 1]            -E[i + 1] + E[i + 2]⎬
──────────, c[i + 1]: ────────, d[i + 1]: ────────────────────⎪
                         2                     6⋅L[i + 1]     ⎪
                                                              

In [18]:
# Continuity of the first derivative at knot i+1: end of segment i (t = L[i])
# equals start of segment i+1 (t = 0).  These equations form the tridiagonal system.
sp5 = Eq(si.diff(t).subs(t, L[i]), si.diff(t).subs(i, i + 1).subs(t, 0))
sp5

      2                                     
3⋅L[i] ⋅d[i] + 2⋅L[i]⋅c[i] + b[i] = b[i + 1]

In [19]:
# Eliminate the coefficients of segments i and i+1, then expand
sp6 = sp5.subs(sln)
sp6 = sp6.subs(sln1)
sp7 = expand(sp6)
sp7

E[i + 1]⋅L[i]   E[i]⋅L[i]   y[i + 1]   y[i]     E[i + 1]⋅L[i + 1]   E[i + 2]⋅L
───────────── + ───────── + ──────── - ──── = - ───────────────── - ──────────
      3             6         L[i]     L[i]             3                   6 

[i + 1]   y[i + 1]   y[i + 2]
─────── - ──────── + ────────
          L[i + 1]   L[i + 1]

In [20]:
# Regroup so the E terms sit on the left and the y terms on the right
# (see the displayed result), then scale by 6 to clear the fractions.
sp8 = divide_terms(sp7, [E[i], E[i + 1], E[i + 2]], [y[i], y[i + 1], y[i + 2]])
display(sp8)
sp9 = mult_eqn(sp8, 6)
display(sp9)

# The index i of the cubic spline equations is not the same i used by the
# tridiagonal solver, so the two have to be aligned.  Row 0 of the solver holds
# the first boundary condition; shifting i -> i-1 makes row i = 1 the first
# continuity equation.
sp9 = sp9.subs(i, i - 1)

E[i + 1]⋅L[i + 1]   E[i + 1]⋅L[i]   E[i + 2]⋅L[i + 1]   E[i]⋅L[i]     y[i + 1]
───────────────── + ───────────── + ───────────────── + ───────── = - ────────
        3                 3                 6               6           L[i]  

   y[i]   y[i + 1]   y[i + 2]
 + ──── - ──────── + ────────
   L[i]   L[i + 1]   L[i + 1]

                                                                          6⋅y[
2⋅E[i + 1]⋅L[i + 1] + 2⋅E[i + 1]⋅L[i] + E[i + 2]⋅L[i + 1] + E[i]⋅L[i] = - ────
                                                                             L

i + 1]   6⋅y[i]   6⋅y[i + 1]   6⋅y[i + 2]
────── + ────── - ────────── + ──────────
[i]       L[i]     L[i + 1]     L[i + 1] 

In [21]:
# Coefficients of E[i-1], E[i], E[i+1] in the shifted continuity equation:
# the sub-diagonal, diagonal and super-diagonal entries of a general row
symlist = [E[i - 1], E[i], E[i + 1], E[i + 2]]
coeff1, coeff2, coeff3 = [get_coeff_for(sp9.lhs, sym, symlist) for sym in symlist[:3]]
display(coeff1, coeff2, coeff3)

L[i - 1]

2⋅L[i - 1] + 2⋅L[i]

L[i]

In [22]:
# Now get the coefficients for the boundary conditions (first row and last row)

# Natural BC: zero second derivative at both ends
bc_natural_start = Eq(E[start], 0)
display(bc_natural_start)
bc_natural_end = Eq(E[end], 0)
display(bc_natural_end)

# The coefficients and RHS for this BC are trivial, but derive them the same way
# as the other BCs.  The off-diagonal entry of each row is queried with both
# unknowns of that row in the symbol list (start: E[0], E[1]; end: E[n-2], E[n-1]),
# so the queried symbol is always part of the list.
bc_natural_start_coeff1 = get_coeff_for(bc_natural_start.lhs, E[start], [E[start]])
display(bc_natural_start_coeff1)
bc_natural_start_coeff2 = get_coeff_for(bc_natural_start.lhs, E[start+1], [E[start], E[start+1]])
display(bc_natural_start_coeff2)
bc_natural_end_coeff1 = get_coeff_for(bc_natural_end.lhs, E[end-1], [E[end-1], E[end]])
display(bc_natural_end_coeff1)
bc_natural_end_coeff2 = get_coeff_for(bc_natural_end.lhs, E[end], [E[end]])
bc_natural_end_coeff2

E[0] = 0

E[n - 1] = 0

1

0

0

1

In [23]:
# BC: first derivative yp0 prescribed at the start of the range, i.e. the slope
# of segment 0 at t = 0 with the spline coefficients eliminated via sln
yp0 = Symbol('yp0')
eqbc1 = Eq(diff(si, t).subs(t, 0).subs(sln).subs(i, 0), yp0)
display(eqbc1)

# Separate the E terms from the y / yp0 terms and clear the 1/6 factors
eqbc1b = divide_terms(expand(eqbc1), [E[0], E[1]], [y[0], y[1], yp0])
eqbc1c = mult_eqn(eqbc1b, 6)
display(eqbc1c)

# Row-0 entries: diagonal (E[0]) and super-diagonal (E[1])
bc_firstd_start_coeff1, bc_firstd_start_coeff2 = [
    get_coeff_for(eqbc1c.lhs, unknown, [E[0], E[1]]) for unknown in (E[0], E[1])
]
display(bc_firstd_start_coeff1, bc_firstd_start_coeff2)

                      2                    
  (2⋅E[0] + E[1])⋅L[0]                     
- ───────────────────── - y[0] + y[1]      
            6                              
───────────────────────────────────── = yp₀
                 L[0]                      

                                   6⋅y[0]   6⋅y[1]
-2⋅E[0]⋅L[0] - E[1]⋅L[0] = 6⋅yp₀ + ────── - ──────
                                    L[0]     L[0] 

-2⋅L[0]

-L[0]

In [24]:
# In the generated algorithm each BC input selects the condition at its end:
#  - a value below bc_cutoff is used as the first derivative there;
#  - a value at or above bc_cutoff selects the natural BC (zero second derivative).
bc_cutoff = 0.99e30

tbc_start_coeff1, tbc_start_coeff2 = [
    Piecewise((firstd, yp0 < bc_cutoff), (natural, True))
    for firstd, natural in [
        (bc_firstd_start_coeff1, bc_natural_start_coeff1),
        (bc_firstd_start_coeff2, bc_natural_start_coeff2),
    ]
]
display(tbc_start_coeff1, tbc_start_coeff2)

# Named temporaries for the first-row coefficients; bc_eqs collects their definitions
sym_bc_start_coeff1, sym_bc_start_coeff2 = Symbol('bc_start1'), Symbol('bc_start2')
bc_eqs = [
    Eq(sym_bc_start_coeff1, tbc_start_coeff1),
    Eq(sym_bc_start_coeff2, tbc_start_coeff2),
]


⎧-2⋅L[0]  for yp₀ < 9.9e+29
⎨                          
⎩   1         otherwise    

⎧-L[0]  for yp₀ < 9.9e+29
⎨                        
⎩  0        otherwise    

In [25]:
# BC: first derivative ypn prescribed at the end of the range, i.e. the slope of
# the last segment (n-2) at t = L[n-2] with the spline coefficients eliminated
ypn = Symbol('ypn')
eqbc2 = Eq(diff(si, t).subs(t, L[end-1]).subs(sln).subs(i, end-1), ypn)
display(eqbc2)

# Separate the E terms from the y / ypn terms and clear the 1/6 factors
eqbc2b = divide_terms(expand(eqbc2), [E[end-1], E[end]], [y[end-1], y[end], ypn])
eqbc2c = mult_eqn(eqbc2b, 6)
display(eqbc2b, eqbc2c)

# Last-row entries: sub-diagonal (E[n-2]) and diagonal (E[n-1])
bc_firstd_end_coeff1, bc_firstd_end_coeff2 = [
    get_coeff_for(eqbc2c.lhs, unknown, [E[end-1], E[end]]) for unknown in (E[end-1], E[end])
]
display(bc_firstd_end_coeff1, bc_firstd_end_coeff2)

                                                                   2          
                                   (E[n - 1] + 2⋅E[n - 2])⋅L[n - 2]           
                                 - ───────────────────────────────── + y[n - 1
(E[n - 1] - E[n - 2])⋅L[n - 2]                     6                          
────────────────────────────── + ─────────────────────────────────────────────
              2                                           L[n - 2]            

                                      
                                      
] - y[n - 2]                          
                                      
──────────── + E[n - 2]⋅L[n - 2] = ypn
                                      

E[n - 1]⋅L[n - 2]   E[n - 2]⋅L[n - 2]         y[n - 1]   y[n - 2]
───────────────── + ───────────────── = ypn - ──────── + ────────
        3                   6                 L[n - 2]   L[n - 2]

                                                  6⋅y[n - 1]   6⋅y[n - 2]
2⋅E[n - 1]⋅L[n - 2] + E[n - 2]⋅L[n - 2] = 6⋅ypn - ────────── + ──────────
                                                   L[n - 2]     L[n - 2] 

L[n - 2]

2⋅L[n - 2]

In [26]:
# Conditional (first-derivative vs. natural) coefficients for the last row
tbc_end_coeff1 = Piecewise((bc_firstd_end_coeff1, ypn < bc_cutoff), (bc_natural_end_coeff1, True))
display(tbc_end_coeff1)
tbc_end_coeff2 = Piecewise((bc_firstd_end_coeff2, ypn < bc_cutoff), (bc_natural_end_coeff2, True))
display(tbc_end_coeff2)

sym_bc_end_coeff1 = Symbol('bc_end1')
sym_bc_end_coeff2 = Symbol('bc_end2')
# Drop any earlier definitions of these symbols before adding them, so that
# re-running this cell does not emit duplicate declarations in the generated code.
bc_eqs = [eq for eq in bc_eqs if eq.lhs not in (sym_bc_end_coeff1, sym_bc_end_coeff2)]
bc_eqs += [Eq(sym_bc_end_coeff1, tbc_end_coeff1), Eq(sym_bc_end_coeff2, tbc_end_coeff2)]

⎧L[n - 2]  for ypn < 9.9e+29
⎨                           
⎩   0          otherwise    

⎧2⋅L[n - 2]  for ypn < 9.9e+29
⎨                             
⎩    1           otherwise    

In [27]:
# Conditional right-hand sides for the first and last rows
rhs_start = Piecewise((eqbc1c.rhs, yp0 < bc_cutoff), (bc_natural_start.rhs, True))
display(rhs_start)
rhs_end = Piecewise((eqbc2c.rhs, ypn < bc_cutoff), (bc_natural_end.rhs, True))
display(rhs_end)

sym_rhs_start = Symbol('rhs_start')
sym_rhs_end = Symbol('rhs_end')
# Replace rather than blindly append, so that re-running this cell does not
# duplicate the rhs_start / rhs_end declarations in the generated code.
bc_eqs = [eq for eq in bc_eqs if eq.lhs not in (sym_rhs_start, sym_rhs_end)]
bc_eqs += [Eq(sym_rhs_start, rhs_start), Eq(sym_rhs_end, rhs_end)]
bc_eqs

⎧        6⋅y[0]   6⋅y[1]                   
⎪6⋅yp₀ + ────── - ──────  for yp₀ < 9.9e+29
⎨         L[0]     L[0]                    
⎪                                          
⎩           0                 otherwise    

⎧        6⋅y[n - 1]   6⋅y[n - 2]                   
⎪6⋅ypn - ────────── + ──────────  for ypn < 9.9e+29
⎨         L[n - 2]     L[n - 2]                    
⎪                                                  
⎩               0                     otherwise    

⎡                                                                             
⎢           ⎧-2⋅L[0]  for yp₀ < 9.9e+29             ⎧-L[0]  for yp₀ < 9.9e+29 
⎢bcₛₜₐᵣₜ₁ = ⎨                          , bcₛₜₐᵣₜ₂ = ⎨                        ,
⎢           ⎩   1         otherwise                 ⎩  0        otherwise     
⎣                                                                             

                                                                              
           ⎧L[n - 2]  for ypn < 9.9e+29            ⎧2⋅L[n - 2]  for ypn < 9.9e
 bc_end1 = ⎨                           , bc_end2 = ⎨                          
           ⎩   0          otherwise                ⎩    1           otherwise 
                                                                              

                ⎧        6⋅y[0]   6⋅y[1]                               ⎧      
+29             ⎪6⋅yp₀ + ────── - ──────  for yp₀ < 9.9e+29            ⎪6⋅ypn 
   , rhsₛₜₐᵣₜ = ⎨         L[0]     L[0]           

### Substitute cubic spline equations into the tridiagonal solver

In [28]:

# Tridiagonal entries for the first row, a general row i and the last row.
# The first and last rows come from the boundary conditions; the first row has
# no sub-diagonal entry and the last row no super-diagonal entry.
subslist = {
    diag[row]: value
    for diag, values in [
        (a, (0, coeff1, sym_bc_end_coeff1)),
        (b, (sym_bc_start_coeff1, coeff2, sym_bc_end_coeff2)),
        (c, (sym_bc_start_coeff2, coeff3, 0)),
        (d, (sym_rhs_start, sp9.rhs, sym_rhs_end)),
    ]
    for row, value in zip((start, i, end), values)
}

# Replace knot spacing with differences between knot locations: L[k] = x[k+1] - x[k]
subsL = {L[k]: x[k + 1] - x[k] for k in (i, i + 1, i - 1, start, start + 1, end - 1)}
subslist

⎧                                                                             
⎨a[0]: 0, a[i]: L[i - 1], a[n - 1]: bc_end1, b[0]: bcₛₜₐᵣₜ₁, b[i]: 2⋅L[i - 1] 
⎩                                                                             

                                                                              
+ 2⋅L[i], b[n - 1]: bc_end2, c[0]: bcₛₜₐᵣₜ₂, c[i]: L[i], c[n - 1]: 0, d[0]: rh
                                                                              

              6⋅y[i + 1]   6⋅y[i]   6⋅y[i - 1]    6⋅y[i]                    ⎫
sₛₜₐᵣₜ, d[i]: ────────── - ────── + ────────── - ────────, d[n - 1]: rhs_end⎬
                 L[i]       L[i]     L[i - 1]    L[i - 1]                   ⎭

In [29]:
# Apply the cubic-spline entries (and knot differences) to the solver equations
teq2b = teq2.subs(subslist).subs(subsL)
teq3b = simplify(teq3.subs(subslist).subs(subsL))
teq4b = teq4.subs(subslist).subs(subsL)
# Last row: x[n-1] = d'[n-1], where d'[n-1] is the general d' recurrence
# evaluated at i = n-1 with the last-row coefficients
teq5b = Eq(teq5.lhs, teq3.rhs.subs(i, end).subs(subslist))
display(teq1.subs(subslist), teq2b, teq3b, teq4b, teq5b, teq6.subs(subslist))

        bcₛₜₐᵣₜ₂
c'[0] = ────────
        bcₛₜₐᵣₜ₁

        rhsₛₜₐᵣₜ
d'[0] = ────────
        bcₛₜₐᵣₜ₁

                                             2                                
        - (x[i + 1] - x[i])⋅(x[i - 1] - x[i]) ⋅d'[i - 1] + 6⋅(x[i + 1] - x[i])
d'[i] = ──────────────────────────────────────────────────────────────────────
                           (x[i + 1] - x[i])⋅(x[i - 1] - x[i])⋅(-(x[i - 1] - x

                                                           
⋅(y[i - 1] - y[i]) + 6⋅(x[i - 1] - x[i])⋅(-y[i + 1] + y[i])
───────────────────────────────────────────────────────────
[i])⋅c'[i - 1] - 2⋅x[i + 1] + 2⋅x[i - 1])                  

                            x[i + 1] - x[i]                    
c'[i] = ───────────────────────────────────────────────────────
        -(-x[i - 1] + x[i])⋅c'[i - 1] + 2⋅x[i + 1] - 2⋅x[i - 1]

           -bc_end1⋅d'[n - 2] + rhs_end
x[n - 1] = ────────────────────────────
           -bc_end1⋅c'[n - 2] + bc_end2

x[i] = -c'[i]⋅x[i + 1] + d'[i]

In [30]:
# Common-subexpression elimination over the two general-row recurrences;
# the generated temporaries are named z0, z1, ...
subexpr, final_expr = cse(
    [simplify(eq) for eq in (teq3b, teq4b)],
    symbols=numbered_symbols('z'),
)
display(subexpr, final_expr)

[(z₀, -x[i]), (z₁, z₀ + x[i + 1]), (z₂, z₀ + x[i - 1]), (z₃, 2⋅x[i + 1]), (z₄,
 2⋅x[i - 1]), (z₅, z₂⋅c'[i - 1]), (z₆, -y[i])]

⎡             2                                                               
⎢        z₁⋅z₂ ⋅d'[i - 1] - 6⋅z₁⋅(z₆ + y[i - 1]) + 6⋅z₂⋅(z₆ + y[i + 1])       
⎢d'[i] = ──────────────────────────────────────────────────────────────, c'[i]
⎣                             z₁⋅z₂⋅(z₃ - z₄ + z₅)                            

                   ⎤
   -x[i + 1] + x[i]⎥
 = ────────────────⎥
    -z₃ + z₄ - z₅  ⎦

In [31]:
# Express the boundary-condition definitions in knot locations instead of spacings
bc_eqs2 = [bc_eq.subs(subsL) for bc_eq in bc_eqs]
bc_eqs2

⎡                                                                             
⎢           ⎧2⋅x[0] - 2⋅x[1]  for yp₀ < 9.9e+29             ⎧x[0] - x[1]  for 
⎢bcₛₜₐᵣₜ₁ = ⎨                                  , bcₛₜₐᵣₜ₂ = ⎨                 
⎢           ⎩       1             otherwise                 ⎩     0           
⎣                                                                             

                                                                              
yp₀ < 9.9e+29            ⎧x[n - 1] - x[n - 2]  for ypn < 9.9e+29            ⎧2
             , bc_end1 = ⎨                                      , bc_end2 = ⎨ 
otherwise                ⎩         0               otherwise                ⎩ 
                                                                              

                                                      ⎧           6⋅y[0]      
⋅x[n - 1] - 2⋅x[n - 2]  for ypn < 9.9e+29             ⎪6⋅yp₀ + ──────────── - 
                                         , rhsₛₜₐᵣ

In [32]:
# Storage mapping for the generated code:
#  - c' is kept in y2 (an array parameter of the generated function);
#  - d' is kept in the temporary vector u.
# The backward sweep then overwrites y2 in place: x[i] = d'[i] - c'[i]*x[i+1]
# reads y2[i] (still c'[i]) and y2[i+1] (already x[i+1], since i runs downward)
# before writing y2[i].
# In the future there should be some dependency analysis to verify this is a legal transformation.
# (An empty storage_subs would leave c', d' and x unrenamed.)
tmp = IndexedBase('u', shape=(n,))
y2 = IndexedBase('y2', shape=(n,))
storage_subs = {cp: y2, dp: tmp}

teq1c = teq1.subs(subslist).subs(storage_subs)
display(teq1c)
# teq2b already has subslist applied
teq2c = teq2b.subs(storage_subs)
display(teq2c)
teq3c = final_expr[0].subs(storage_subs)
display(teq3c)
teq4c = final_expr[1].subs(storage_subs)
display(teq4c)
# The solution x is written into y2 as well
teq5c = teq5b.subs(storage_subs).subs(x, y2)
display(teq5c)
teq6c = teq6.subs(storage_subs).subs(x, y2)
display(teq6c)

        bcₛₜₐᵣₜ₂
y2[0] = ────────
        bcₛₜₐᵣₜ₁

       rhsₛₜₐᵣₜ
u[0] = ────────
       bcₛₜₐᵣₜ₁

            2                                                       
       z₁⋅z₂ ⋅u[i - 1] - 6⋅z₁⋅(z₆ + y[i - 1]) + 6⋅z₂⋅(z₆ + y[i + 1])
u[i] = ─────────────────────────────────────────────────────────────
                            z₁⋅z₂⋅(z₃ - z₄ + z₅)                    

        -x[i + 1] + x[i]
y2[i] = ────────────────
         -z₃ + z₄ - z₅  

            -bc_end1⋅u[n - 2] + rhs_end 
y2[n - 1] = ────────────────────────────
            -bc_end1⋅y2[n - 2] + bc_end2

y2[i] = u[i] - y2[i + 1]⋅y2[i]

In [33]:
# Now for some code generation.
# To pick up edits to codegen_extras without restarting the kernel, run:
#   reload(codegen_extras); from codegen_extras import *

In [34]:
# Template type parameter T for the generated temporaries, pointers and scalar parameters
templateT = Type("T")

In [35]:
# Forward sweep over the general rows i = 1 .. n-2
fr = ARange(start+1, end, 1)

# Declare each common subexpression as a local of type T, then update u[i] and y2[i]
body = [
    Variable(sym, type=templateT).as_Declaration(value=expr.subs(storage_subs))
    for sym, expr in subexpr
]
body += [convert_eq_to_assignment(teq3c), convert_eq_to_assignment(teq4c)]
loop1 = For(i, fr, body)

In [36]:
# Backward sweep: i runs downward from n-2 towards 0
# (ARangeClosedEnd presumably includes the end point; check the generated loop)
br = ARangeClosedEnd(end - 1, start, -1)
loop2 = For(i, br, [convert_eq_to_assignment(teq6c)])

In [37]:
# Algorithm body: declare the temporary vector u (std::vector<T> of length n) and
# the boundary-condition values, then the first row, the forward sweep, the last
# row and the backward sweep.
# NOTE(review): the IndexedBase tmp appears to be what triggers the
# "Not supported in C++: IndexedBase" comment in the printed code; tmp.label may
# avoid it, depending on how VariableWithInit handles its symbol -- TODO confirm.
tmp_init = VariableWithInit("n", tmp, type=Type("std::vector<T>")).as_Declaration()
bc_tmps = [
    Variable(bc_eq.lhs, type=templateT).as_Declaration(value=bc_eq.rhs)
    for bc_eq in bc_eqs2
]
algo = CodeBlock(
    tmp_init,
    *bc_tmps,
    convert_eq_to_assignment(teq1c),
    convert_eq_to_assignment(teq2c),
    loop1,
    convert_eq_to_assignment(teq5c),
    loop2,
)

In [38]:
# Generate the inner part of the algorithm to check it.
# Use a descriptive name instead of reusing s, the name of the spline polynomial defined earlier.
ACP = ACodePrinter()
algo_code = ACP.doprint(algo)
print(algo_code)

// Not supported in C++:
// IndexedBase
std::vector<T> u(n);
T bc_start1 = ((yp0 < 9.9000000000000002e+29) ? (
   2*x[0] - 2*x[1]
)
: (
   1
));
T bc_start2 = ((yp0 < 9.9000000000000002e+29) ? (
   x[0] - x[1]
)
: (
   0
));
T bc_end1 = ((ypn < 9.9000000000000002e+29) ? (
   x[n - 1] - x[n - 2]
)
: (
   0
));
T bc_end2 = ((ypn < 9.9000000000000002e+29) ? (
   2*x[n - 1] - 2*x[n - 2]
)
: (
   1
));
T rhs_start = ((yp0 < 9.9000000000000002e+29) ? (
   6*yp0 + 6*y[0]/(-x[0] + x[1]) - 6*y[1]/(-x[0] + x[1])
)
: (
   0
));
T rhs_end = ((ypn < 9.9000000000000002e+29) ? (
   6*ypn - 6*y[n - 1]/(x[n - 1] - x[n - 2]) + 6*y[n - 2]/(x[n - 1] - x[n - 2])
)
: (
   0
));
y2[0] = bc_start2/bc_start1;
u[0] = rhs_start/bc_start1;
for (auto i = 1; i < n - 1; i += 1) {
   T z0 = -x[i];
   T z1 = z0 + x[i + 1];
   T z2 = z0 + x[i - 1];
   T z3 = 2*x[i + 1];
   T z4 = 2*x[i - 1];
   T z5 = z2*y2[i - 1];
   T z6 = -y[i];
   u[i] = (z1*z2*z2*u[i - 1] - 6*z1*(z6 + y[i - 1]) + 6*z2*(z6 + y[i + 1]))/(z1*z2*(z3 -

In [39]:
# Set up to create a template function.
# Array parameters are declared from the IndexedBase labels (plain Symbols).
# Passing the IndexedBase objects themselves appears to make the printer emit
# "// Not supported in C++: IndexedBase" comments ahead of the function.
tx = Pointer(x.label, type=templateT)
ty = Pointer(y.label, type=templateT)
ty2 = Pointer(y2.label, type=templateT)
yp0_var = Variable('yp0', type=templateT)
ypn_var = Variable('ypn', type=templateT)

tf = TemplateFunctionDefinition(void, "cubic_spline_solve", [tx, ty, n, yp0_var, ypn_var, ty2], [templateT], algo)

In [40]:
# Print the complete template function.
# Use a descriptive name instead of reusing s, the name of the spline polynomial defined earlier.
ACP = ACodePrinter()
tf_code = ACP.doprint(tf)
print(tf_code)

// Not supported in C++:
// IndexedBase
// IndexedBase
// IndexedBase
// IndexedBase
template<typename T>
void cubic_spline_solve(T * x, T * y, int n, T yp0, T ypn, T * y2){
   std::vector<T> u(n);
   T bc_start1 = ((yp0 < 9.9000000000000002e+29) ? (
      2*x[0] - 2*x[1]
   )
   : (
      1
   ));
   T bc_start2 = ((yp0 < 9.9000000000000002e+29) ? (
      x[0] - x[1]
   )
   : (
      0
   ));
   T bc_end1 = ((ypn < 9.9000000000000002e+29) ? (
      x[n - 1] - x[n - 2]
   )
   : (
      0
   ));
   T bc_end2 = ((ypn < 9.9000000000000002e+29) ? (
      2*x[n - 1] - 2*x[n - 2]
   )
   : (
      1
   ));
   T rhs_start = ((yp0 < 9.9000000000000002e+29) ? (
      6*yp0 + 6*y[0]/(-x[0] + x[1]) - 6*y[1]/(-x[0] + x[1])
   )
   : (
      0
   ));
   T rhs_end = ((ypn < 9.9000000000000002e+29) ? (
      6*ypn - 6*y[n - 1]/(x[n - 1] - x[n - 2]) + 6*y[n - 2]/(x[n - 1] - x[n - 2])
   )
   : (
      0
   ));
   y2[0] = bc_start2/bc_start1;
   u[0] = rhs_start/bc_start1;
   for (auto i = 1; i < n -