Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Use contextual dispatch for replacing functions #334

Open
wants to merge 19 commits into
base: master
from

Conversation

@vchuravy
Copy link
Member

commented Jan 29, 2019

On 1.1 Cassette should be performant enough for these kinds of transforms.

Fixes #27

@maleadt did you have a branch similar to this around?

@vchuravy vchuravy requested a review from maleadt Jan 29, 2019

@vchuravy

This comment has been minimized.

Copy link
Member Author

commented Jan 29, 2019

bors try

bors bot added a commit that referenced this pull request Jan 29, 2019

@bors

This comment has been minimized.

Copy link
Contributor

commented Jan 29, 2019

try

Build failed

@maleadt

This comment has been minimized.

Copy link
Member

commented Jan 30, 2019

Yes, https://github.com/JuliaGPU/CUDAnative.jl/compare/tb/cassette
Didn't work because of plenty allocations, invokes, dispatches, etc. Is your approach different in that regard?
Also, #265.

@vchuravy

This comment has been minimized.

Copy link
Member Author

commented Jan 30, 2019

bors try

bors bot added a commit that referenced this pull request Jan 30, 2019

@vchuravy

This comment has been minimized.

Copy link
Member Author

commented Jan 30, 2019

As bors tells us apparently not ;)

@jrevels https://gitlab.com/JuliaGPU/CUDAnative.jl/-/jobs/153739960 is full of interesting cases.

@bors

This comment was marked as resolved.

Copy link
Contributor

commented Jan 30, 2019

try

Build failed

src/context.jl Outdated Show resolved Hide resolved
@vchuravy

This comment has been minimized.

Copy link
Member Author

commented Jan 30, 2019

bors try

bors bot added a commit that referenced this pull request Jan 30, 2019

@bors

This comment was marked as resolved.

Copy link
Contributor

commented Jan 30, 2019

try

Build failed

@maleadt maleadt changed the title Use contextual dispatch for replacing functions WIP: Use contextual dispatch for replacing functions Jan 31, 2019

@maleadt

This comment has been minimized.

Copy link
Member

commented Jan 31, 2019

Yeah, as I feared... Let's mark this WIP then 🙁

@vchuravy

This comment has been minimized.

Copy link
Member Author

commented Feb 1, 2019

bors try

bors bot added a commit that referenced this pull request Feb 1, 2019

@bors

This comment has been minimized.

Copy link
Contributor

commented Feb 1, 2019

try

Build failed

@maleadt

This comment has been minimized.

Copy link
Member

commented Feb 1, 2019

Same error count; inlining doesn't help.
That said, many stack traces point to getindex again, so maybe there's only a small number of errors remaining. I'll have another go at reducing vadd when I have some time.

@vchuravy

This comment has been minimized.

Copy link
Member Author

commented Feb 4, 2019

I was planning on grabbing Jarrett this week and see if we can figure it out. (I am in the progress to add GPU support to Cthulhu so that should make it easier)

@vchuravy vchuravy force-pushed the vc/cassette branch from e0ed898 to abbc3fd Feb 7, 2019

@vchuravy

This comment has been minimized.

Copy link
Member Author

commented Feb 7, 2019

bors try

bors bot added a commit that referenced this pull request Feb 7, 2019

@bors

This comment has been minimized.

Copy link
Contributor

commented Feb 7, 2019

try

Build failed

@vchuravy

This comment has been minimized.

Copy link
Member Author

commented Feb 7, 2019

Ok! The debugging session with Jarrett proved fruitful, we are down to 10ish failures :)

@maleadt

This comment has been minimized.

Copy link
Member

commented Feb 8, 2019

Cool! What were the changes?

@jrevels

This comment has been minimized.

Copy link

commented Feb 8, 2019

Cool! What were the changes?

We applied my usual Cassette issue workaround of "isolate the problematic thing and make it a contextual primitive (i.e. don't overdub into it)". The problematic thing here was the @pure function datatype_align.

It turns out that while Cassette propagates purity to the compiler correctly, the compiler is (probably rightfully) pessimistic and just bails out on purity optimization for generated functions (i.e. overdub). ref JuliaLang/julia#31012, which is my naive attempt at changing the compiler to allow this sort of thing. If that lands, we can remove the extra contextual primitive definition here.

end
end

contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...)

This comment has been minimized.

Copy link
@vchuravy

vchuravy Feb 11, 2019

Author Member
Suggested change
contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...)
contextualize(f::F) where F = (args...) -> (Cassette.overdub(cudactx, f, args...); return nothing)

We could go back to automatically returning nothing here

@vchuravy

This comment has been minimized.

Copy link
Member Author

commented Feb 11, 2019

bors try

bors bot added a commit that referenced this pull request Feb 11, 2019

@maleadt maleadt force-pushed the vc/cassette branch 2 times, most recently from 7ceda2b to 29b9e2e Jul 23, 2019

@maleadt

This comment has been minimized.

Copy link
Member

commented Jul 23, 2019

Rebased and removed the 265 hacks now that we have JuliaLang/julia#32237.
Nothing wrong with depending on a day-old alpha, right?

EDIT: oh wow everything broke. @vchuravy and ideas? Looks like it the transform "works", just plenty of dynamic invocations:

julia> kernel(a) = (a[1] = 1; nothing)
kernel (generic function with 2 methods)

julia> kernel([0])

julia> CUDAnative.contextualize(kernel)([0])

julia> @cuda kernel(cu([0]))
ERROR: InvalidIRError: compiling #148(CuDeviceArray{Float32,1,CUDAnative.AS.Global}) resulted in invalid LLVM IR
Reason: unsupported dynamic function invocation (call to overdub)
Stacktrace:
 [1] Val at essentials.jl:694
 [2] setindex! at /home/tbesard/Julia/pkg/CUDAnative/src/device/array.jl:84
 [3] kernel at REPL[5]:1
 [4] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75

Seems something has changed with Val. Adding it to the list of pure overdub exceptions (@inline Cassette.overdub(::CUDACtx, ::typeof(Base.Val), x) = return Base.Val(x)) doesn't help.
Are those still needed anyhow now that we have JuliaLang/julia#31012?

@maleadt

This comment has been minimized.

Copy link
Member

commented Jul 24, 2019

MWE:

@generated bar(::Val{align}) where {align} = :(42)
foo(i) = i+bar(Val(1))

using Cassette

function transform(ctx, ref)
    CI = ref.code_info
    noinline = any(@nospecialize(x) ->
                       Core.Compiler.isexpr(x, :meta) &&
                       x.args[1] == :noinline,
                   CI.code)
    CI.inlineable = !noinline

    CI.ssavaluetypes = length(CI.code)
    Core.Compiler.validate_code(CI)
    return CI
end
const InlinePass = Cassette.@pass transform

Cassette.@context Ctx
const ctx = Cassette.disablehooks(Ctx(pass = InlinePass))
contextualize(f::F) where F = (args...) -> Cassette.overdub(ctx, f, args...)

using InteractiveUtils
code_llvm(foo, Tuple{Int})
code_llvm(contextualize(foo), Tuple{Int})

Doesn't need the addition, but otherwise we get a const jlapi function. Also doesn't need the inlining pass, but otherwise the LLVM contains a call to overdub, while it now clearly shows a dynamic invocation:

;  @ /home/tbesard/Julia/wip2.jl:2 within `foo'
define i64 @julia_foo_15985(i64) {
top:
; ┌ @ int.jl:53 within `+'
   %1 = add i64 %0, 42
; └
  ret i64 %1
}

;  @ /home/tbesard/Julia/wip2.jl:22 within `#7'
define i64 @"julia_#7_16073"(i64) {
top:
  %1 = alloca %jl_value_t addrspace(10)*, i32 2
; ┌ @ /home/tbesard/Julia/wip2.jl:2 within `foo'
; │┌ @ essentials.jl:694 within `Val'
    %2 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %1, i32 0
    store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140490737042584 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %2
    %3 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %1, i32 1
    store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140490826585840 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %3
    %4 = call nonnull %jl_value_t addrspace(10)* @jl_apply_generic(%jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140490737042232 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %1, i32 2)
; │└
; │┌ @ int.jl:53 within `+'
; ││┌ @ /home/tbesard/Julia/pkg/Cassette/src/overdub.jl:465 within `_overdub_fallback'
; │││┌ @ /home/tbesard/Julia/pkg/Cassette/src/context.jl:445 within `fallback'
; ││││┌ @ /home/tbesard/Julia/pkg/Cassette/src/context.jl:447 within `call'
       %5 = add i64 %0, 42
; └└└└└
  ret i64 %5
}

Bisected to JuliaLang/julia#31012

@maleadt

This comment has been minimized.

Copy link
Member

commented Jul 25, 2019

265 looks broken:

automatic recompilation: Test Failed at /home/tbesard/Julia/pkg/CUDAnative/test/device/execution.jl:320
  Expression: (Array(arr))[] == 2
   Evaluated: 1 == 2
Stacktrace:
 [1] top-level scope at /home/tbesard/Julia/pkg/CUDAnative/test/device/execution.jl:320
 [2] top-level scope at /home/tbesard/Julia/julia-dev/build/release/usr/share/julia/stdlib/v1.3/Test/src/Test.jl:1114
 [3] top-level scope at /home/tbesard/Julia/pkg/CUDAnative/test/device/execution.jl:304
 [4] top-level scope at /home/tbesard/Julia/julia-dev/build/release/usr/share/julia/stdlib/v1.3/Test/src/Test.jl:1114
 [5] top-level scope at /home/tbesard/Julia/pkg/CUDAnative/test/device/execution.jl:154
 [6] top-level scope at /home/tbesard/Julia/julia-dev/build/release/usr/share/julia/stdlib/v1.3/Test/src/Test.jl:1114
 [7] top-level scope at /home/tbesard/Julia/pkg/CUDAnative/test/device/execution.jl:5
non-isbits arguments: Error During Test at /home/tbesard/Julia/pkg/CUDAnative/test/device/execution.jl:324

Also the jl_f_tuple failure, and worse, the stack trace points to overdub:

non-isbits arguments: Error During Test at /home/tbesard/Julia/pkg/CUDAnative/test/device/execution.jl:324
  Got exception outside of a @test
  InvalidIRError: compiling #148(Type{Int64}, Int64) resulted in invalid LLVM IR
  Reason: unsupported call to the Julia runtime (call to jl_f_tuple)
  Stacktrace:
   [1] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75
  Reason: unsupported call to the Julia runtime (call to jl_f_getfield)
  Stacktrace:
   [1] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75

This one has a better on-device stack trace:

pow: Error During Test at /home/tbesard/Julia/pkg/CUDAnative/test/device/cuda.jl:46
  Got exception outside of a @test
  InvalidIRError: compiling #148(CuDeviceArray{Float32,0,CUDAnative.AS.Global}, Float32, Int64) resulted in invalid LLVM IR
  Reason: unsupported call to the Julia runtime (call to jl_f_tuple)
  Stacktrace:
   [1] overdub at /home/tbesard/Julia/pkg/Cassette/src/overdub.jl:508
   [2] checked_trunc_sint at boot.jl:582
   [3] toInt32 at boot.jl:619
   [4] Int32 at boot.jl:709
   [5] pow at /home/tbesard/Julia/pkg/CUDAnative/src/device/cuda/math.jl:193
   [6] pow_kernel at /home/tbesard/Julia/pkg/CUDAnative/test/device/cuda.jl:50
   [7] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75
  Reason: unsupported call to the Julia runtime (call to jl_f_getfield)
  Stacktrace:
   [1] throw_inexacterror at boot.jl:560
   [2] overdub at /home/tbesard/Julia/pkg/Cassette/src/overdub.jl:0
   [3] checked_trunc_sint at boot.jl:582
   [4] toInt32 at boot.jl:619
   [5] Int32 at boot.jl:709
   [6] pow at /home/tbesard/Julia/pkg/CUDAnative/src/device/cuda/math.jl:193
   [7] pow_kernel at /home/tbesard/Julia/pkg/CUDAnative/test/device/cuda.jl:50
   [8] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75

Some more, that fail for all their types:

T = Int32: Error During Test at /home/tbesard/Julia/pkg/CUDAnative/test/device/cuda.jl:780
  Got exception outside of a @test
  InvalidIRError: compiling #148(Type{Int32}, CuDeviceArray{Int32,1,CUDAnative.AS.Global}) resulted in invalid LLVM IR
  Reason: unsupported call to the Julia runtime (call to jl_f_tuple)
  Stacktrace:
   [1] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75
  Reason: unsupported call to the Julia runtime (call to jl_f_getfield)
  Stacktrace:
   [1] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75
T = UInt64: Error During Test at /home/tbesard/Julia/pkg/CUDAnative/test/device/cuda.jl:880
  Got exception outside of a @test
  InvalidIRError: compiling #148(Type{UInt64}, CuDeviceArray{UInt64,1,CUDAnative.AS.Global}) resulted in invalid LLVM IR
  Reason: unsupported call to the Julia runtime (call to jl_f_tuple)
  Stacktrace:
   [1] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75
  Reason: unsupported call to the Julia runtime (call to jl_f_getfield)
  Stacktrace:
   [1] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75
Test Summary:                                                  | Pass  Fail  Error  Total
CUDAnative                                                     |  348     5     32    385
  base interface                                               |                    No tests
  pointer                                                      |   20                  20
  code generation                                              |   90                  90
  code generation (relying on a device)                        |    6                   6
  execution                                                    |   70     5      1     76
    @cuda                                                      |    8     4            12
      low-level interface                                      |                    No tests
      launch configuration                                     |                    No tests
      compilation params                                       |    1                   1
      reflection                                               |    6     3             9
      shared memory                                            |                    No tests
      streams                                                  |                    No tests
      external kernels                                         |                    No tests
      calling device function                                  |                    No tests
    argument passing                                           |   27     1      1     29
      manually allocated                                       |    1                   1
      scalar through single-value array                        |    1                   1
      scalar through single-value array, using device function |    1                   1
      tuples                                                   |    1                   1
      ghost function parameters                                |    2                   2
      immutables                                               |    1                   1
      automatic recompilation                                  |    1     1             2
      non-isbits arguments                                     |                 1      1
      splatting                                                |    3                   3
      object invoke                                            |    1                   1
      closures                                                 |    1                   1
      conversions                                              |    8                   8
      argument count                                           |    4                   4
      keyword arguments                                        |                    No tests
      captured values                                          |    2                   2
    exceptions                                                 |   17                  17
    shmem divergence bug                                       |    7                   7
    dynamic parallelism                                        |   10                  10
    cooperative groups                                         |    1                   1
  pointer                                                      |   41                  41
  device arrays                                                |   20                  20
  CUDA functionality                                           |   87           31    118
    indexing                                                   |    1                   1
    math                                                       |    9            1     10
      pow                                                      |    8            1      9
    formatted output                                           |    6                   6
    @cuprint                                                   |   24                  24
    assertion                                                  |                    No tests
    shared memory                                              |   14                  14
    data movement and conversion                               |    5                   5
    clock and nanosleep                                        |                    No tests
    parallel synchronization and communication                 |   16                  16
    libcudadevrt                                               |                    No tests
    atomics (low-level)                                        |   12                  12
    atomics (high-level)                                       |                30     30
      add                                                      |                 6      6
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
        T = Float32                                            |                 1      1
        T = Float64                                            |                 1      1
      sub                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      and                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      or                                                       |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      xor                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      max                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      min                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
  examples                                                     |    6                   6
@vchuravy

This comment has been minimized.

Copy link
Member Author

commented Jul 25, 2019

Will look into these after my talk (especially the 265 one is disappointing). Having a reproducer with MiniCassette would be great.

@maleadt

This comment has been minimized.

Copy link
Member

commented Jul 25, 2019

Yeah no hurry, just reporting in that your fix works. Once merged we'll have CI at least.

@maleadt

This comment has been minimized.

Copy link
Member

commented Jul 26, 2019

Never mind, Pkg and I were confused about which version of Cassette we were using.
Let's try again.

@maleadt maleadt force-pushed the vc/cassette branch from 29b9e2e to b0ef268 Jul 26, 2019

@maleadt maleadt force-pushed the vc/cassette branch from b0ef268 to e8de056 Jul 26, 2019

@maleadt

This comment has been minimized.

Copy link
Member

commented Jul 26, 2019

Turns out we weren't properly contextualizing all methods. Penalty: 20 additional failures.

Test Summary:                                                  | Pass  Fail  Error  Total
CUDAnative                                                     |  323    24     33    380
  base interface                                               |                    No tests
  pointer                                                      |   20                  20
  code generation                                              |   63    21      1     85
    LLVM IR                                                    |   30     7            37
      basic reflection                                         |    4     1             5
      unbound typevars                                         |          1             1
      exceptions                                               |    3                   3
      sysimg                                                   |    1                   1
      child functions                                          |          1             1
      JuliaLang/julia#21121                                    |    1                   1
      kernel functions                                         |   14     2            16
        wrapper function aggregate rewriting                   |          2             2
        property_annotations                                   |   14                  14
      LLVM D32593                                              |                    No tests
      kernel names                                             |    2     2             4
      PTX TBAA                                                 |    5                   5
      tracked pointers                                         |                    No tests
      CUDAnative.jl#278                                        |                    No tests
    PTX assembly                                               |   16     8            24
      basic reflection                                         |    3                   3
      child functions                                          |          1             1
      kernel functions                                         |    6     2             8
        property_annotations                                   |    5                   5
      idempotency                                              |                    No tests
      child function reuse                                     |          2             2
      child function reuse bis                                 |                    No tests
      indirect sysimg function use                             |    2                   2
      compile for host after PTX                               |    1                   1
      LLVM intrinsics                                          |                    No tests
      kernel names                                             |    2     2             4
      exception arguments                                      |                    No tests
      GC and TLS lowering                                      |    2                   2
      float boxes                                              |          1             1
    errors                                                     |   17     6      1     24
      recursion                                                |          2             2
      base intrinsics                                          |          1      1      2
      non-isbits arguments                                     |   10                  10
      invalid LLVM IR                                          |    2                   2
      invalid LLVM IR (ccall)                                  |    2                   2
      delayed bindings                                         |    2                   2
      dynamic call (invoke)                                    |          2             2
      dynamic call (apply)                                     |    1     1             2
  code generation (relying on a device)                        |    6                   6
  execution                                                    |   72     3      1     76
    @cuda                                                      |    9     3            12
      low-level interface                                      |                    No tests
      launch configuration                                     |                    No tests
      compilation params                                       |    1                   1
      reflection                                               |    6     3             9
      shared memory                                            |                    No tests
      streams                                                  |                    No tests
      external kernels                                         |                    No tests
      calling device function                                  |                    No tests
    argument passing                                           |   28            1     29
      manually allocated                                       |    1                   1
      scalar through single-value array                        |    1                   1
      scalar through single-value array, using device function |    1                   1
      tuples                                                   |    1                   1
      ghost function parameters                                |    2                   2
      immutables                                               |    1                   1
      automatic recompilation                                  |    2                   2
      non-isbits arguments                                     |                 1      1
      splatting                                                |    3                   3
      object invoke                                            |    1                   1
      closures                                                 |    1                   1
      conversions                                              |    8                   8
      argument count                                           |    4                   4
      keyword arguments                                        |                    No tests
      captured values                                          |    2                   2
    exceptions                                                 |   17                  17
    shmem divergence bug                                       |    7                   7
    dynamic parallelism                                        |   10                  10
    cooperative groups                                         |    1                   1
  pointer                                                      |   41                  41
  device arrays                                                |   20                  20
  CUDA functionality                                           |   87           31    118
    indexing                                                   |    1                   1
    math                                                       |    9            1     10
      pow                                                      |    8            1      9
    formatted output                                           |    6                   6
    @cuprint                                                   |   24                  24
    assertion                                                  |                    No tests
    shared memory                                              |   14                  14
    data movement and conversion                               |    5                   5
    clock and nanosleep                                        |                    No tests
    parallel synchronization and communication                 |   16                  16
    libcudadevrt                                               |                    No tests
    atomics (low-level)                                        |   12                  12
    atomics (high-level)                                       |                30     30
      add                                                      |                 6      6
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
        T = Float32                                            |                 1      1
        T = Float64                                            |                 1      1
      sub                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      and                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      or                                                       |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      xor                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      max                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      min                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
  examples                                                     |    6                   6
ERROR: LoadError: Some tests did not pass: 323 passed, 24 failed, 33 errored, 0 broken.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
3 participants
You can’t perform that action at this time.