In [1]:
import drjit as dr
import mitsuba as mi
import numpy as np # used for ray events char matrix
from nfa import NFA
from parse import Verifier

# Pick the Mitsuba variant before any mi.* array type (mi.Int32, mi.UInt32, ...) is touched:
# LLVM JIT backend, autodiff enabled, RGB colour representation.
mi.set_variant("llvm_ad_rgb")



In [6]:
# Regex operators:
#   A?  0 or 1 occurrence of A
#   A*  0 or more occurrences of A
#   A+  1 or more occurrences of A
#   .   wildcard: matches any single event
# Events are encoded as digit characters, so "1*2.3?4+" reads:
# any number of 1s, then a 2, then any event, then an optional 3, then at least one 4.

# Build the NFA for the regex.
regex = "1*2.3?4+"
nfa = NFA(regex)
nfa.regex_to_nfa()
verifier = Verifier()

# Sentinel values shared by the event matrix and the per-ray state vectors.
# Plain Python ints (not 1-element mi.Int32 arrays) so they can be placed inside a
# NumPy matrix and compared with scalars; Dr.Jit broadcasts them in dr.select.
NO_EVENT = -1      # padding: the ray has no event at this step
KILLED_STATE = -1  # the ray's event sequence can no longer match
ACCEPT_STATE = 0   # the ray has reached an accepting NFA state
# NOTE(review): assumes no NFA node uses ID 0 — the transition loop only advances states >= 1.


In [10]:

# Event matrix: one row per ray, one column per step of the rendering loop.
# Events use the digit encoding of `regex`: a=1, b=2, c=3, d=4, s=5
# (row 0 below is the char sequence "bsd").
# Rows are right-padded with -1 (the NO_EVENT sentinel) once a ray's path ends;
# a literal int is used because a 1-element mi.Int32 cannot be an entry of a NumPy matrix.
# A NumPy matrix is used because mi.Int32 is a 1-D dynamic array type.
events = np.array(
    [
        [2, 5, 4, -1, -1],  # "bsd"
        [1, 1, 2, -1, -1],  # "aab"
        [1, 2, 5, 4, -1],   # "absd"
        [2, 2, 4, 4, 3],    # "bbddc"
    ],
    dtype=np.int32,
)


In [None]:
# Draft of the per-step simulation (the Interface-based cell further down supersedes it).
# Every ray starts at NFA node 1; dr.zeros would start them at 0, which is ACCEPT_STATE,
# a state the transition loop never advances.
n_rays = events.shape[0]
curr_state_batch = dr.full(mi.Int32, 1, n_rays)

for i in range(events.shape[1]):
    # Event batch i = the i-th event of every ray, i.e. column i of `events`,
    # converted from NumPy to a Dr.Jit array so it can be used with dr.select.
    event_batch = mi.Int32(events[:, i])
    # FIXME: batch_tansition only exists as a method of Interface (defined in a later
    # cell), so this call raises NameError on a fresh kernel.
    curr_state_batch = batch_tansition(curr_state_batch, event_batch, nfa.node_count)

In [69]:
class Interface:
    """Advance one NFA state per ray, one batch of events at a time.

    Each ray's state is an int: an NFA node ID (the transition loop handles IDs
    >= 1), ACCEPT_STATE once an accepting node has been reached, or KILLED_STATE
    once the ray's events can no longer match. Call set_up() (or assign ``nfa``
    and ``verifier`` directly) before transitioning.
    """

    # Sentinels as plain ints: they broadcast in dr.select and compare directly
    # with the Python ints produced by iterating a Dr.Jit array.
    NO_EVENT = -1
    KILLED_STATE = -1
    ACCEPT_STATE = 0

    def __init__(self, regex):
        """Store the regex; the NFA and verifier are created by set_up()."""
        self.nfa = None
        self.verifier = None
        self.regex = regex

    def set_up(self):
        """Build the NFA for ``self.regex`` and create the verifier."""
        # Use this instance's regex, not the notebook-global `regex`.
        self.nfa = NFA(self.regex)
        self.nfa.regex_to_nfa()
        self.verifier = Verifier()

    def batch_tansition(self, state_batch, event_batch, nfa_state_num):
        """Advance every ray by one event.

        At a given step, different rays are in different states and receive
        different events; both are given as per-ray vectors.

        Args:
            state_batch: mi.Int32 with the current state of each ray.
            event_batch: mi.Int32, NumPy array or int sequence with the next
                event of each ray (NO_EVENT where a ray has none); same length
                as ``state_batch``.
            nfa_state_num: number of NFA states; states 1 .. nfa_state_num - 1
                are advanced, other values (ACCEPT_STATE, KILLED_STATE) are
                carried over unchanged.

        Returns:
            mi.Int32 with the next state of each ray.
        """
        # dr.select needs a Dr.Jit array, not e.g. a NumPy column of `events`.
        event_batch = mi.Int32(event_batch)
        next_state_batch = state_batch

        # Handle rays grouped by current state so each group shares one start node.
        for idx in range(1, nfa_state_num):
            mask = dr.eq(state_batch, idx)
            # Rays in other states see NO_EVENT; their results are discarded by the select below.
            filtered_event = dr.select(mask, event_batch, self.NO_EVENT)
            updated_state = mi.Int32(self.state_transition_given_event(idx, filtered_event))
            next_state_batch = dr.select(mask, updated_state, next_state_batch)

        return next_state_batch

    # Correctly spelled alias; the original name is kept for existing callers.
    batch_transition = batch_tansition

    def state_transition_given_event(self, start_node_idx, filtered_event):
        """Compute the next state for each event, starting from one NFA node.

        Args:
            start_node_idx: int ID of the NFA node every considered ray is in.
            filtered_event: iterable of int events (NO_EVENT for rays that are
                not in ``start_node_idx`` or have no event).

        Returns:
            list of int with exactly one next state per input event.
        """
        start_node = self.nfa.get_node(start_node_idx)
        return [self._next_state(start_node, event) for event in filtered_event]

    def _next_state(self, start_node, event):
        """Return the state reached from ``start_node`` after a single event."""
        if event == self.NO_EVENT:
            return self.KILLED_STATE

        # NOTE(review): events are ints while the regex is a string of digit
        # characters — confirm Verifier.verify_one compares the two correctly.
        next_node_set = self.verifier.verify_one(event, start_node)
        if not next_node_set:
            # No transition (None or empty set): the ray can no longer match.
            return self.KILLED_STATE

        # Accepting takes precedence. Exactly one state is produced per event so the
        # result stays aligned with the rays (previously ACCEPT_STATE was appended
        # in addition to the next node, shifting all later entries).
        # NOTE(review): ACCEPT_STATE is sticky (state 0 is never advanced again), so
        # events after an accepting prefix are ignored — confirm this is intended.
        if self.verifier.has_accepted_state(next_node_set):
            return self.ACCEPT_STATE

        # NOTE(review): only the first node of the NFA successor set is tracked,
        # so alternative branches of the NFA are dropped.
        return next_node_set[0].node_ID
        



        

In [75]:

# Simulate the rendering loop: at step i every ray receives its i-th event
# (column i of `events`) and advances its NFA state by one transition.
# All rays start at NFA node 1 (the previous dr.arange put ray 0 in state 0,
# i.e. ACCEPT_STATE, so it never advanced).
START_STATE = 1
n_rays = events.shape[0]

nfa = NFA()
nfa.set_regex(regex)
nfa.regex_to_nfa()

# Reuse this NFA and the verifier from the setup cell instead of calling
# interface.set_up(), which would build a second NFA.
interface = Interface(regex)
interface.nfa = nfa
interface.verifier = verifier

curr_state_batch = dr.full(mi.Int32, START_STATE, n_rays)
for i in range(events.shape[1]):
    event_slice = mi.Int32(events[:, i])  # NumPy column -> Dr.Jit array
    curr_state_batch = interface.batch_tansition(curr_state_batch, event_slice, nfa.node_count)

curr_state_batch


In [None]:
# Reference check with the sequential verifier on the same four rays as `events`,
# written as digit strings for the regex "1*2.3?4+" (a=1, b=2, c=3, d=4, s=5).
# The old letter strings ("bsd", "aab", ...) cannot match the digit regex.
reference_cases = [
    ("254", "match"),                        # "bsd"
    ("112", "early stop, doesn't match"),    # "aab"
    ("1254", "match"),                       # "absd"
    ("22443", "doesn't match"),              # "bbddc"
]

for sequence, expectation in reference_cases:
    matched, passed_nodes = verifier.verify_all(sequence, nfa.start_node)
    print(f"{sequence!r} (expected: {expectation}):", matched, passed_nodes)

In [26]:
N = 4

# Dr.Jit arrays built from Python lists
a = mi.Float([4.0, 8.0, 44.0])
b = mi.UInt32([0, 0, 0, 0])

# Length-N initialisers
c = dr.zeros(mi.UInt32, N)    # -> 0, 0, 0, 0
d = dr.full(mi.UInt32, 4, N)  # -> 4, 4, 4, 4
e = dr.arange(mi.UInt32, N)   # -> 0, 1, 2, 3

# Masks
m1 = dr.full(mi.Bool, True, N)  # all True
m2 = e < 2                      # -> [True, True, False, False]

# Iterating a Dr.Jit mask yields one Python bool per entry.
for m in m2:
    print(m)


True
True
False
False


In [25]:
# Scratch check (likely why `events` switched from chars to ints): NumPy had no
# `multiply` loop for (str, bool), so masking a char array with `a * m` raised
# UFuncTypeError in the NumPy version this notebook was run with.
a = np.array(['a', 'b', 'c', 'd'])
m = np.array([True, True, False, False])
try:
    print(a * m)
except TypeError as err:  # UFuncTypeError is a TypeError subclass
    print(f"{type(err).__name__}: {err}")

# Working alternative: keep masked-in entries and blank out the rest.
masked = np.where(m, a, '')
masked

UFuncTypeError: ufunc 'multiply' did not contain a loop with signature matching types (dtype('<U1'), dtype('bool')) -> None

In [28]:
# All-zero length-N state vector (displayed below).
state = dr.zeros(mi.UInt32, shape=N)
state

[0, 0, 0, 0]

In [32]:
next_state = dr.zeros(mi.UInt32, N)
event = mi.UInt32([1, 1, 1, 2])

# Epsilon function for node 0: where the event is 1 the target is 4, otherwise 2.
m_epsilon_0 = dr.eq(event, 1)
x = dr.select(m_epsilon_0, 4, 2)
print(x)

# Which entries are currently in node 0?
m_0 = dr.eq(state, 0)
print(m_0)

# Two ways to write x into the masked-in entries of next_state:
# (1) functional select ...
next_state = dr.select(m_0, x, next_state)
print(next_state)
# (2) ... or in-place masked assignment.
next_state[m_0] = x
print(next_state)


[4, 4, 4, 2]
[True, True, True, True]
[4, 4, 4, 2]
[4, 4, 4, 2]


In [None]:
# TODO: sketch of a vectorised NFA step, next_state = f(state, event).
# The original `def nfa(state, event) -> next_state` had no body (SyntaxError) and its
# name would have shadowed the `nfa` NFA instance, so it is stubbed under a new name.
def nfa_transition(state, event):
    """Return the next state of every ray given its current state and event (not implemented)."""
    raise NotImplementedError("vectorised NFA transition is not implemented yet")