In [6]:
using Distributed
# addprocs(44);

In [1]:
using pulse_input_DDM
using pulse_input_DDM.Parameters:@unpack

In [2]:
using BenchmarkTools

In [3]:
using CUDA

In [None]:
using Profile
using ProfileSVG
ProfileSVG.set_default(maxdepth = 500, maxframes = 3000)

In [4]:
n = 53 
dt = 1e-2
ntrials = 50000

σ2_i = 5.
B = 8.
λ = -1.
σ2_a = 20.
σ2_s = 0.5
ϕ = 0.8
τ_ϕ = 0.05

bias=0.
lapse=0.0

θ = θchoice(θz=θz(σ2_i = σ2_i, B = B, λ = λ, σ2_a = σ2_a, σ2_s = σ2_s,
    ϕ = ϕ, τ_ϕ = τ_ϕ), bias= bias, lapse=lapse);

In [7]:
rng = 1
clicks = synthetic_clicks(ntrials, rng, tmin = 1., tmax = 1.)
binned_clicks = bin_clicks.(clicks,centered=true,dt=dt)
inputs = map((clicks, binned_clicks)-> pulse_input_DDM.choiceinputs(clicks=clicks, binned_clicks=binned_clicks, 
    dt=dt, centered=true), clicks, binned_clicks)
 
ntrials = length(inputs)
rng = pulse_input_DDM.sample(pulse_input_DDM.Random.seed!(rng), 1:ntrials, ntrials; replace=false)

#choices = rand.(Ref(θ), inputs, rng)
choices = pmap((inputs, rng) -> rand(θ, inputs, rng), inputs, rng)

data = pulse_input_DDM.choicedata.(inputs, choices);


In [None]:
model = choiceDDM(θ=θ, data=data, n=n)

In [None]:
@benchmark loglikelihood(model)

In [11]:
function data_for_gpu(inputs, θ, ntrials)
    
    @unpack σ2_s = θ.θz
    
    mus = zeros(Float32, ceil(Int, 1/dt)+1, ntrials)
    sigmas = zeros(Float32, ceil(Int, 1/dt), ntrials)

    leftclicks = map(input -> map(x->round.(Int,x), (input.clicks.L ./dt)), inputs) 
    rightclicks = map(input -> map(x->round.(Int,x), (input.clicks.R ./dt)), inputs) 

    for i = 1:ntrials
        leftidx = unique(leftclicks[i])[2:end-1]
        rightidx = unique(rightclicks[i])[2:end-1]
        mus[leftidx,i] .= Float32(-1)
        mus[rightidx,i] .= Float32(1)
        sigmas[leftidx, i] .= Float32(σ2_s)
        sigmas[rightidx, i] .= Float32(σ2_s)

    end
    
    return cu(mus), cu(sigmas)
    
end



function transition_M_element_one!(i::Int, j::Int, σ2::TT, λ::TT, μ::TT, dx::UU,
        xc::Union{CuDeviceVector{TT}, Vector{TT}}, n::Int, dt::Float32) where {TT,UU <: Any}

    #if s <= xc[1]
      # F[1,j] += ps[k]
    # elseif s >= xc[n]
      # F[n,j] += ps[k]

    mu = exp(λ*dt)*xc[j] + μ * expm1_div_x(λ*dt)
    norm = sqrt(2*pi*σ2)/dx

    res ::Float32 = 0.

    if ((i == 1 && j == 1) || (i == n && j == n))
       return Float32(1.0)
    end

    if j == 1 || j == n
      return Float32(0.0)
    end

    s = xc[2] + (i-2)*dx

    val = Float32(exp(-0.5 * (s-mu)^2/σ2)/norm)

    if i == 1
      return Float32(0.5 * val) + Float32(sqrt((pi*σ2)/2) * (erf(Float32((xc[1] - mu)/sqrt(2*σ2))) + 1) / (dx* norm))
    end

    if i == n
      return Float32(0.5 * val) + Float32(sqrt((pi*σ2)/2) * (erfc(Float32((xc[n] - mu)/sqrt(2*σ2)))) / (dx * norm))
    end
    
    return val
    
end




function matrix_vector_dot!(σ2::CuDeviceVector{TT}, λ:: Float32, μ::CuDeviceVector{TT}, dx::Float32,
        xc::CuDeviceVector{TT}, n::Int, dt::Float32,
        P, P_out) where {TT<:Any}

    sdata = @cuStaticSharedMem(Float32,1024)

    tid = threadIdx().x
    tid_global = (threadIdx().y - 1) * blockDim().x + threadIdx().x

    i_trial = blockIdx().z

    for i = 0:blockDim().y:n
      j = tid
      sdata[tid] = 0.0
      i_local = i + threadIdx().y
      
      if j<=n && i_local <= n
          # initialize
          P_out[blockIdx().y] = 0.0

          sdata[tid_global] = transition_M_element_one!(i_local, j, σ2[i_trial], λ, μ[i_trial], dx, xc, n, dt)*P[i_trial, j]
          #sdata[tid_global] = Float32(1.0)

          j += blockDim().x
          while j<=n
              sdata[tid_global] += transition_M_element_one!(i_local, j, σ2[i_trial], λ, μ[i_trial], dx, xc, n, dt)*P[i_trial, j]
              #sdata[tid_global] += Float32(1.0)
              j += blockDim().x
          end
      end

      sync_threads()

      # do reduction in shared mem
      s = blockDim().x ÷ 2
      while s > 0
        if ((tid-1) < s)
            sdata[tid_global] += sdata[tid_global+s]
        end
        sync_threads()
        s = (s÷2)
      end

      if i_local <= n
        # write result
        if tid==1
            P_out[i_trial,i_local] = sdata[1+(threadIdx().y - 1)*blockDim().x]
        end
      end

    end
    return nothing
end
     



function propagate_P(P::CuArray{TT,2}, σ2::CuArray{TT}, λ, μ::CuArray{TT}, n::Int, dx::VV, xc::CuArray{TT}, dt::Float64) where {TT,VV <: Any}
    
    ntrials = size(P, 1)
    P_out = CUDA.zeros(TT,ntrials,n)

    thread_x = 16
    thread_y = 16

    @cuda threads=(thread_x,thread_y) blocks=(Int(round(n/thread_x))+1, Int(round(n/thread_y))+1, ntrials) matrix_vector_dot!(σ2, Float32(λ), μ , Float32(dx),
        xc, n, Float32(dt), P, P_out)
    
    sync.threads()
    return P_out
end



function P_all_trials(θ, n, dt, μs, σs)
    
    @unpack σ2_i, B, λ, σ2_a, σ2_s = θ.θz
    
    P,M,xc,dx = pulse_input_DDM.initialize_latent_model(σ2_i, B, λ, σ2_a, n, dt)
    xc = cu(xc)

    ntrial = size(μs)[1]
    T = size(μs)[2]
    P = cu(Array(repeat(P, 1, ntrials)'))
    
    # walk over time steps
    @inbounds for t = 1:T
        print(t)
        P = propagate_P(P, σs[:,t] .+ Float32(σ2_a), λ, μs[:,t], n, dx, xc, dt)
    end
    
    return P
    
end


function choicelik(choice, pright, pleft)
    if choice 
        return pright
    else 
        return pleft
    end
end


function loglikelihood_gpu(μs, σs, θ, data, n, dt)
    
    P = P_all_trials(θ, n, dt, μs, σs)
    
    choice = map(data->data.choice, data)
    P_left = sum(P[1:floor(Int,n/2)])
    P_right = sum(P[ceil(Int,n/2):n])
    likelihood = pmap((choice, pright, pleft) -> choicelik(choice, pright, pleft), choice, P_right, P_left)
    
    return log.(likelihood)

end


    


loglikelihood_gpu (generic function with 1 method)

In [12]:
μs, σs =  data_for_gpu(inputs, θ, ntrials);


In [None]:
@unpack σ2_i, B, λ, σ2_a, σ2_s = θ.θz
    
P,M,xc,dx = pulse_input_DDM.initialize_latent_model(σ2_i, B, λ, σ2_a, n, dt)
xc = cu(xc)

ntrial = size(μs)[1]
T = size(μs)[2]
P = cu(Array(repeat(P, 1, ntrials)'))


t = 1

propagate_P(P, σs[:,t] .+ Float32(σ2_a), λ, μs[:,t], n, dx, xc, dt)

In [13]:
@benchmark loglikelihood_gpu(μs, σs, θ, data, n, dt)

1

LoadError: [91mInvalidIRError: compiling kernel matrix_vector_dot!(CuDeviceArray{Float32,1,1}, Float32, CuDeviceArray{Float32,1,1}, Float32, CuDeviceArray{Float32,1,1}, Int64, Float32, CuDeviceArray{Float32,2,1}, CuDeviceArray{Float32,2,1}) resulted in invalid LLVM IR[39m
[91mReason: unsupported dynamic function invocation (call to -)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:50[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to literal_pow)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:50[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to *)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:50[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to /)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:50[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to exp)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:50[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to Float32)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:50[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to *)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:53[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to Float32)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:53[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to -)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:53[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to /)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:53[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to *)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:57[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to Float32)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:57[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to -)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:57[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to /)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:57[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported use of an undefined name (use of 'expm1_div_x')[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:35[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:35[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to *)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:35[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to +)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:35[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported use of an undefined name (use of 'erf')[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:53[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:53[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to +)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:53[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported use of an undefined name (use of 'erfc')[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:57[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:57[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to +)[39m
[91mStacktrace:[39m
[91m [1] [1mtransition_M_element_one![22m at [1mIn[11]:57[22m[39m
[91m [2] [1mmultiple call sites[22m at [1munknown:0[22m[39m
[91mReason: unsupported dynamic function invocation (call to *)[39m
[91mStacktrace:[39m
[91m [1] [1mmatrix_vector_dot![22m at [1mIn[11]:87[22m[39m
[91mReason: unsupported dynamic function invocation (call to convert)[39m
[91mStacktrace:[39m
[91m [1] [1msetindex![22m at [1m/home/dikshag/.julia/packages/CUDA/mbPFj/src/device/array.jl:101[22m[39m
[91m [2] [1mmatrix_vector_dot![22m at [1mIn[11]:87[22m[39m
[91mReason: unsupported dynamic function invocation (call to *)[39m
[91mStacktrace:[39m
[91m [1] [1mmatrix_vector_dot![22m at [1mIn[11]:92[22m[39m
[91mReason: unsupported dynamic function invocation (call to +)[39m
[91mStacktrace:[39m
[91m [1] [1mmatrix_vector_dot![22m at [1mIn[11]:92[22m[39m
[91mReason: unsupported dynamic function invocation (call to convert)[39m
[91mStacktrace:[39m
[91m [1] [1msetindex![22m at [1m/home/dikshag/.julia/packages/CUDA/mbPFj/src/device/array.jl:101[22m[39m
[91m [2] [1mmatrix_vector_dot![22m at [1mIn[11]:92[22m[39m

In [None]:
data

In [None]:
@profview loglikelihood(model)
ProfileSVG.save("loglikelihood_prof.svg")


In [None]:
@profview gradient(model)

In [None]:
ProfileSVG.save("gradient_prof.svg")


In [None]:
fit = vcat(trues(9));
lb = [0., 2.,  -5., 0.,   0.,  0., 0.005, -5.0, 0.0]
ub = [30., 100., 5., 200., 10., 1.2,  1., 5.0, 1.0];
options = choiceoptions(fit=fit, lb=lb, ub=ub)
x_generative = collect(pulse_input_DDM.Flatten.flatten(θ_generative));

x0 = vcat([0.1, 15., -0.1, 20.,  0.5, 0.2,  0.008], [0.,  0.01]);

model = choiceDDM(θ=pulse_input_DDM.Flatten.reconstruct(θchoice(),x0), data=data_50K, n=n)

@profview optimize(model, options, iterations = 2)

In [None]:
ProfileSVG.save("optimize_prof.svg")
