# Combining operators
Oftentimes we want to combine operators to make a new operator.

A chain operator can be thought of as multiplying two matrices together or, equivalently, as applying one operator after the other.

Note
 - How we check that the intermediate space between the two matrices is the same
 - How we take advantage of storing our domain and range
 - How we create an intermediate space vector
 - How the add is used in the forward and adjoint

In [None]:
%load_ext autoreload
%autoreload 2
import sys

! pip install  "sep_plot @ git+https://github.com/SEP-software/sep-plot.git@3fac86108f59c822193cbd6f28687fecce5e298b" 
import generic_solver


In [None]:
from generic_solver._pyOperator import Operator
from sep_python import FloatVector
import numba

class chainOp(Operator):
    """Chain of two operators, op = op2 op1.

    The forward applies op1 and then op2; the adjoint applies the two
    adjoints in the opposite order.  The vector that lives between the two
    operators is allocated once, in the constructor, and reused on every call.
    """

    def __init__(self, op1, op2):
        """Initialize a chain operator

             op = op2 op1

            op1 - Operator applied first in the forward
            op2 - Operator applied second in the forward

        Raises an Exception when op1.range and op2.domain do not pass
        check_same.
        """
        self._op1 = op1
        self._op2 = op2
        # The output space of op1 must be the input space of op2.
        if not self._op1.range.check_same(self._op2.domain):
            raise Exception("Spaces don't match")
        # The chain maps from the domain of op1 to the range of op2.
        super().__init__(self._op1.domain, self._op2.range)
        # Scratch vector for the intermediate space between the two operators.
        self._vec = self._op1.range.clone()

    def forward(self, add, model, data):
        """Forward operation: data (+)= op2 op1 model.

        The intermediate vector is always computed with add=False; the
        caller's add flag is passed only to the operator that writes data.
        """
        self._op1.forward(False, model, self._vec)
        self._op2.forward(add, self._vec, data)

    def adjoint(self, add, model, data):
        """Adjoint operation: model (+)= op1' op2' data.

        The intermediate vector is always computed with add=False; the
        caller's add flag is passed only to the operator that writes model.
        """
        self._op2.adjoint(False, self._vec, data)
        self._op1.adjoint(add, model, self._vec)





## Stack operator

Another way we might combine operators is to have them share a domain but have different ranges. We will call this a "stack operator".

Note

- Domains match
- How the adjoint works in terms of add and order of operation

In [5]:
from generic_solver import superVector
class stockOp(Operator):
    """Stack operator: two operators that share a domain, outputs stacked.

             | op1 |
        op = |     |
             | op2 |

    The range is a superVector built from op1.range and op2.range;
    data.vecs[0] belongs to op1 and data.vecs[1] belongs to op2.
    """

    def __init__(self, op1, op2):
        """Initialize a stack operator

            op1 - Operator writing the first component of the data
            op2 - Operator writing the second component of the data

        Raises an Exception when the two domains do not pass check_same.
        """
        self._op1 = op1
        self._op2 = op2
        # Both operators read the same model, so their domains must agree.
        if not self._op1.domain.check_same(self._op2.domain):
            raise Exception("Spaces don't match")
        super().__init__(self._op1.domain, superVector(op1.range, op2.range))

    def forward(self, add, model, data):
        """Forward operation: each operator fills its own data component.

        The two outputs are separate vectors, so the caller's add flag is
        passed unchanged to both operators.
        """
        self._op1.forward(add, model, data.vecs[0])
        self._op2.forward(add, model, data.vecs[1])

    def adjoint(self, add, model, data):
        """Adjoint operation: model (+)= op1' data.vecs[0] + op2' data.vecs[1].

        Both adjoints write into the same model.  The first call honours the
        caller's add flag; the second must always add so that it does not
        erase the contribution of the first.
        """
        self._op2.adjoint(add, model, data.vecs[1])
        self._op1.adjoint(True, model, data.vecs[0])



In [40]:
class BoxcarF(Operator):
    """Boxcar smoothing: each output is the mean of 2*halflen+1 inputs.

    Indices that fall outside the vector are clamped to the nearest end
    sample.
    """

    def __init__(self, mod, dat, halflen):
        """
        Initialize a boxcar convolution (smoothing)

            mod, dat - sepVector
            halflen - Half length of smoothing box

        Raises an Exception when mod or dat is not a FloatVector.
        """
        super().__init__(mod, dat)
        if not isinstance(mod, FloatVector) or not isinstance(dat, FloatVector):
            raise Exception("Expecting model and data to be sepVectors")
        self._halflen = halflen
        # Number of samples along axis 0 of the data; used for both loops
        # and for clamping indices.
        self._nd = dat.get_hyper().axes[0].n

    def forward(self, add, mod, dat):
        """
        Forward operation

            add - When False dat is zeroed before the result is accumulated
            mod - Input vector
            dat - Output vector
        """
        self.checkDomainRange(mod, dat)
        if not add:
            dat.zero()
        # Normalisation weight: one over the number of samples in the box.
        sc = 1. / (1 + 2. * self._halflen)
        last = self._nd - 1
        for i in range(self._nd):
            tmp = 0
            for ib in range(i - self._halflen, i + self._halflen + 1):
                # Clamp the index so the box reuses the end samples.
                tmp += mod[max(0, min(last, ib))]
            # Multiply (not divide) by the weight so the box averages.
            dat[i] += tmp * sc

    def adjoint(self, add, mod, dat):
        """
        Adjoint operation.

            add - When False mod is zeroed before the result is accumulated
            mod - Output vector
            dat - Input vector
        """
        self.checkDomainRange(mod, dat)
        if not add:
            mod.zero()

        sc = 1. / (1 + 2. * self._halflen)
        last = self._nd - 1
        for i in range(self._nd):
            # Spread each data sample back to the clamped model indices the
            # forward read from, with the same weight.
            contrib = dat[i] * sc
            for ib in range(i - self._halflen, i + self._halflen + 1):
                mod[max(0, min(last, ib))] += contrib
    

In [49]:
from sep_python import get_sep_vector
from sep_plot import Dots
import numpy as np
import holoviews as hv
# --- Vectors -----------------------------------------------------------
# A 20-sample float32 input, plus two outputs cloned from it while it is
# still all zeros.
inp = get_sep_vector(np.zeros(20, dtype=np.float32))
p1 = inp.clone()
p2 = inp.clone()

# --- Operators ---------------------------------------------------------
# Two half-length-3 boxcars chained together, then that chain chained with
# itself (four boxcars in a row).
op1 = BoxcarF(inp, inp, 3)
op2 = BoxcarF(inp, inp, 3)
cop = chainOp(op1, op2)
cop2 = chainOp(cop, cop)

# --- Impulse response --------------------------------------------------
# Put a single spike in the input and push it through both chains.
inp[10] = 1
cop.forward(False, inp, p1)
cop2.forward(False, inp, p2)

# Show the input and both outputs stacked in a single column.
hv.Layout(Dots(inp) + Dots(p1) + Dots(p2)).cols(1)

  layout_plot = gridplot(
  layout_plot = gridplot(


In [43]:
# Print the full slice of p1 to inspect its sample values.
print(p1[:])

[-0.02040816  0.          0.02040816  0.04081633  0.06122449  0.06122449
  0.06122449  0.04081633  0.02040816]
