Skip to content

Commit

Permalink
Adapt functions to new struct definition
Browse files Browse the repository at this point in the history
  • Loading branch information
bowenszhu committed Jun 18, 2024
1 parent 989907b commit 608e23a
Showing 1 changed file with 31 additions and 23 deletions.
54 changes: 31 additions & 23 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ sdict(kv...) = Dict{BasicSymbolic, Any}(kv...)
end

Base.@kwdef struct BasicSymbolic{T} <: Symbolic{T}
x::BasicSymbolicImpl
impl::BasicSymbolicImpl
metadata::Metadata = NO_METADATA
hash::RefValue{UInt} = Ref(EMPTY_HASH)
end
Expand All @@ -56,7 +56,7 @@ function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic})
end

function exprtype(x::BasicSymbolic)
@match x::BasicSymbolic begin
@match x.impl begin
Term => TERM
Add => ADD
Mul => MUL
Expand All @@ -71,6 +71,7 @@ end
# Same but different error messages
@noinline error_on_type() = error("Internal error: unreachable reached!")
@noinline error_sym() = error("Sym doesn't have a operation or arguments!")
@noinline error_const() = error("Const doesn't have a operation or arguments!")
@noinline error_property(E, s) = error("$E doesn't have field $s")

# We can think about bits later
Expand All @@ -94,13 +95,14 @@ symtype(x::Number) = typeof(x)

# We're returning a function pointer
@inline function operation(x::BasicSymbolic)
@match x::BasicSymbolic begin
@match x.impl begin
Term => x.f
Add => (+)
Mul => (*)
Div => (/)
Pow => (^)
Sym => error_sym()
Const => error_const()
_ => error_on_type()
end
end
Expand All @@ -109,7 +111,7 @@ end

function arguments(x::BasicSymbolic)
args = unsorted_arguments(x)
@match x::BasicSymbolic begin
@match x.impl begin
Add => @goto ADD
Mul => @goto MUL
_ => return args
Expand All @@ -132,50 +134,51 @@ end
unsorted_arguments(x) = arguments(x)
children(x::BasicSymbolic) = arguments(x)
function unsorted_arguments(x::BasicSymbolic)
@match x::BasicSymbolic begin
@match x.impl begin
Term => return x.arguments
Add => @goto ADDMUL
Mul => @goto ADDMUL
Div => @goto DIV
Pow => @goto POW
Sym => error_sym()
Const => error_const()
_ => error_on_type()
end

@label ADDMUL
E = exprtype(x)
args = x.arguments
args = x.impl.arguments
isempty(args) || return args
siz = length(x.dict)
idcoeff = E === ADD ? iszero(x.coeff) : isone(x.coeff)
siz = length(x.impl.dict)
idcoeff = E === ADD ? iszero(x.impl.coeff) : isone(x.impl.coeff)
sizehint!(args, idcoeff ? siz : siz + 1)
idcoeff || push!(args, x.coeff)
idcoeff || push!(args, x.impl.coeff)
if isadd(x)
for (k, v) in x.dict
for (k, v) in x.impl.dict
push!(args, applicable(*,k,v) ? k*v :
maketerm(k, *, [k, v]))
end
else # MUL
for (k, v) in x.dict
for (k, v) in x.impl.dict
push!(args, unstable_pow(k, v))
end
end
return args

@label DIV
args = x.arguments
args = x.impl.arguments
isempty(args) || return args
sizehint!(args, 2)
push!(args, x.num)
push!(args, x.den)
push!(args, x.impl.num)
push!(args, x.impl.den)
return args

@label POW
args = x.arguments
args = x.impl.arguments
isempty(args) || return args
sizehint!(args, 2)
push!(args, x.base)
push!(args, x.exp)
push!(args, x.impl.base)
push!(args, x.impl.exp)
return args
end

Expand Down Expand Up @@ -220,15 +223,17 @@ function _isequal(a, b, E)
if E === SYM
nameof(a) === nameof(b)
elseif E === ADD || E === MUL
coeff_isequal(a.coeff, b.coeff) && isequal(a.dict, b.dict)
coeff_isequal(a.impl.coeff, b.impl.coeff) && isequal(a.impl.dict, b.impl.dict)
elseif E === DIV
isequal(a.num, b.num) && isequal(a.den, b.den)
isequal(a.impl.num, b.impl.num) && isequal(a.impl.den, b.impl.den)
elseif E === POW
isequal(a.exp, b.exp) && isequal(a.base, b.base)
isequal(a.impl.exp, b.impl.exp) && isequal(a.impl.base, b.impl.base)
elseif E === TERM
a1 = arguments(a)
a2 = arguments(b)
isequal(operation(a), operation(b)) && _allarequal(a1, a2)
elseif E === CONST
isequal(a.impl.val, b.impl.val)
else
error_on_type()
end
Expand All @@ -246,6 +251,7 @@ const ADD_SALT = 0xaddaddaddaddadda % UInt
const SUB_SALT = 0xaaaaaaaaaaaaaaaa % UInt
const DIV_SALT = 0x334b218e73bbba53 % UInt
const POW_SALT = 0x2b55b97a6efb080c % UInt
const COS_SALT = 0xdc3d6b8f18b75e3c % UInt
function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
E = exprtype(s)
if E === SYM
Expand All @@ -255,13 +261,13 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
h = s.hash[]
!iszero(h) && return h
hashoffset = isadd(s) ? ADD_SALT : SUB_SALT
h′ = hash(hashoffset, hash(s.coeff, hash(s.dict, salt)))
h′ = hash(hashoffset, hash(s.impl.coeff, hash(s.impl.dict, salt)))
s.hash[] = h′
return h′
elseif E === DIV
return hash(s.num, hash(s.den, salt DIV_SALT))
return hash(s.impl.num, hash(s.impl.den, salt DIV_SALT))
elseif E === POW
hash(s.exp, hash(s.base, salt POW_SALT))
hash(s.impl.exp, hash(s.impl.base, salt POW_SALT))
elseif E === TERM
!iszero(salt) && return hash(hash(s, zero(UInt)), salt)
h = s.hash[]
Expand All @@ -271,6 +277,8 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
h′ = hashvec(arguments(s), hash(oph, salt))
s.hash[] = h′
return h′
elseif E === CONST
return hash(s.impl.val, salt COS_SALT)
else
error_on_type()
end
Expand Down

0 comments on commit 608e23a

Please sign in to comment.