My hard_mish (and hard_mish2; I switched the names around from the original) approximations are as fast as ReLU for all values, and are closer to the original Mish than the approximation I forked from: the derivative is continuous everywhere, including at -4.0 (or at -1.0 for hard_mish2), while the fork's is not at its lower join.
julia> function hard_mish2(x) # x*(x+1)^2 "between" the ReLU branches
           l = x + one(x)
           if x >= zero(x)
               return x
           elseif x <= -one(x)
               return zero(x)
           else
               return l^2*x
           end
       end

julia> function hard_mish(x) # x*(0.25*x+1)^2 "between" the ReLU branches
           l = convert(typeof(x), 0.25)*x + one(x)
           if x >= zero(x)
               return x
           elseif x <= convert(typeof(x), -4)
               return zero(x)
           else
               return l^2*x
           end
       end
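As a quick check of the smoothness claim (my own sketch, not from the post): each inner branch's derivative can be written in factored form and evaluated at the joins, where it must match the slopes of the outer ReLU-like branches (1 at 0; 0 at -4 or -1).

```julia
# Derivatives of the inner branches, factored by hand (my arithmetic):
# hard_mish:  d/dx x*(0.25x+1)^2 = (0.25x+1)*(0.75x+1)
# hard_mish2: d/dx x*(x+1)^2     = (x+1)*(3x+1)
dmish(x)  = (0.25x + 1) * (0.75x + 1)
dmish2(x) = (x + 1) * (3x + 1)

# At the joins the slopes match the outer branches (x has slope 1, 0 has slope 0):
@assert dmish(0.0)  == 1.0 && dmish(-4.0)  == 0.0
@assert dmish2(0.0) == 1.0 && dmish2(-1.0) == 0.0
```

Note the second derivatives (0.375x + 1 and 6x + 4) do not vanish at the lower joins, so the approximations are C¹ there, not C².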
julia> hard_mish(x::Float16) = convert(Float16, hard_mish(Float32(x))) # converting via Float64 is as fast, but thinking of GPUs; and I do not fully trust the timing, as Float32 takes a few more instructions
julia> using BenchmarkTools
julia> @btime hard_mish(-1.2) # near the minimum, which is similar to the original Mish's, unlike the forked hard-mish's
0.024 ns (0 allocations: 0 bytes)
-0.5879999999999999
julia> @btime hard_mish(Float16(-1.2))
0.025 ns (0 allocations: 0 bytes)
Float16(-0.588)
julia> @btime hard_mish2(Float16(-0.2)) # near the minimum, but the minimum is much higher than the original Mish's
0.024 ns (0 allocations: 0 bytes)
Float16(-0.1279)
julia> @btime hard_mish2(-0.5)
0.024 ns (0 allocations: 0 bytes)
-0.125
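For reference, the exact interior minima can be derived by hand (my arithmetic, not from the thread): setting the inner-branch derivatives to zero gives x = -4/3 with value -16/27 ≈ -0.593 for hard_mish, and x = -1/3 with value -4/27 ≈ -0.148 for hard_mish2, which the spot values above sit near.

```julia
# Interior minima from (0.25x+1)(0.75x+1) = 0 and (x+1)(3x+1) = 0 (my arithmetic):
x1 = -4/3   # hard_mish minimizer
x2 = -1/3   # hard_mish2 minimizer
@assert isapprox(x1 * (0.25 * x1 + 1)^2, -16/27)
@assert isapprox(x2 * (x2 + 1)^2, -4/27)
```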
julia> @code_native hard_mish2(10.0)
.text
; ┌ @ REPL[78]:1 within `hard_mish'
vxorps %xmm1, %xmm1, %xmm1
; │ @ REPL[78]:3 within `hard_mish'
; │┌ @ operators.jl:350 within `>='
; ││┌ @ float.jl:460 within `<='
vucomisd %xmm1, %xmm0
; │└└
jae L53
movabsq $139696212938512, %rax # imm = 0x7F0D8F260F10
vmovsd (%rax), %xmm2 # xmm2 = mem[0],zero
; │ @ REPL[78]:5 within `hard_mish'
; │┌ @ float.jl:460 within `<='
vucomisd %xmm0, %xmm2
; │└
jae L57
movabsq $.rodata.cst8, %rax
vaddsd (%rax), %xmm0, %xmm1
; │ @ REPL[78]:8 within `hard_mish'
; │┌ @ intfuncs.jl:296 within `literal_pow'
; ││┌ @ float.jl:405 within `*'
vmulsd %xmm1, %xmm1, %xmm1
; │└└
; │┌ @ float.jl:405 within `*'
vmulsd %xmm0, %xmm1, %xmm0
; │└
retq
L53:
vmovaps %xmm0, %xmm1
; │ @ REPL[78]:4 within `hard_mish'
L57:
vmovaps %xmm1, %xmm0
retq
nop
; └
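An aside (my own sketch, not benchmarked here): the two compares above can also be avoided with a clamp plus a single ifelse, which may suit SIMD or GPU kernels; it computes the same values as hard_mish.

```julia
# Branch-free hard_mish variant (my sketch): clamp saturates the cubic's
# argument so values below -4 map to 0, and ifelse selects x above 0.
function hard_mish_branchless(x)
    t = clamp(x, convert(typeof(x), -4), zero(x))
    c = t * (convert(typeof(x), 0.25) * t + one(x))^2
    return ifelse(x > zero(x), x, c)
end
```

`ifelse` evaluates both operands, so there is no data-dependent branch left for the compiler to emit.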
See plots here (the extreme values are substituted with ReLU, i.e. below -1 or -4, and above 0.0):
My hard_mish is also as fast for Float16, and hard_mish2 now is too (it was orders of magnitude slower for that type without the special-casing, while it was, and still is, as fast for machine floats).
I find mine likely to be better (at least not slower) with its third-order polynomial than with the second-order parabola in the original:
Formula: (x/2) * min(2, max(0, x + 2))
Figure 1. Hard Mish Activation Function
| Architecture | Swish | H-Mish | Mish | ReLU |
|---|---|---|---|---|
| ResNet-20 | 90.42% | 92.57% | 92.68% | 91.8% |
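For comparison, the forked H-Mish formula above translates directly to Julia (my transcription of the formula, not code from the fork):

```julia
# Forked H-Mish, f(x) = (x/2) * min(2, max(0, x+2)) (my transcription):
hmish(x) = x / 2 * min(2one(x), max(zero(x), x + 2))

# Spot checks against the piecewise behaviour:
@assert hmish(1.0) == 1.0    # x >= 0: min saturates at 2, so f(x) = x
@assert hmish(-3.0) == 0.0   # x <= -2: max gives 0
@assert hmish(-1.0) == -0.5  # inner parabola x*(x+2)/2
```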