Skip to content
This repository has been archived by the owner. It is now read-only.

WIP: Use contextual dispatch for replacing functions #334

Open
wants to merge 3 commits into
base: master
Choose a base branch
from

Conversation

@vchuravy
Copy link
Member

@vchuravy vchuravy commented Jan 29, 2019

On 1.1 Cassette should be performant enough for these kinds of transforms.

Fixes https://github.com/JuliaGPU/CUDAnative.jl/issues/27

@maleadt did you have a branch similar to this around?

@vchuravy vchuravy requested a review from maleadt Jan 29, 2019
@vchuravy
Copy link
Member Author

@vchuravy vchuravy commented Jan 29, 2019

bors try

Loading

bors bot added a commit that referenced this issue Jan 29, 2019
@bors
Copy link
Contributor

@bors bors bot commented Jan 29, 2019

try

Build failed

Loading

@maleadt
Copy link
Member

@maleadt maleadt commented Jan 30, 2019

Yes, https://github.com/JuliaGPU/CUDAnative.jl/compare/tb/cassette
Didn't work because of plenty allocations, invokes, dispatches, etc. Is your approach different in that regard?
Also, #265.

Loading

@vchuravy
Copy link
Member Author

@vchuravy vchuravy commented Jan 30, 2019

bors try

Loading

bors bot added a commit that referenced this issue Jan 30, 2019
@vchuravy
Copy link
Member Author

@vchuravy vchuravy commented Jan 30, 2019

As bors tells us apparently not ;)

@jrevels https://gitlab.com/JuliaGPU/CUDAnative.jl/-/jobs/153739960 is full of interesting cases.

Loading

@bors

This comment has been hidden.

src/context.jl Outdated Show resolved Hide resolved
Loading
@vchuravy
Copy link
Member Author

@vchuravy vchuravy commented Jan 30, 2019

bors try

Loading

bors bot added a commit that referenced this issue Jan 30, 2019
@bors

This comment has been hidden.

@maleadt maleadt changed the title Use contextual dispatch for replacing functions WIP: Use contextual dispatch for replacing functions Jan 31, 2019
@maleadt
Copy link
Member

@maleadt maleadt commented Jan 31, 2019

Yeah, as I feared... Let's mark this WIP then 🙁

Loading

@vchuravy
Copy link
Member Author

@vchuravy vchuravy commented Feb 1, 2019

bors try

Loading

bors bot added a commit that referenced this issue Feb 1, 2019
@bors
Copy link
Contributor

@bors bors bot commented Feb 1, 2019

try

Build failed

Loading

@maleadt
Copy link
Member

@maleadt maleadt commented Feb 1, 2019

Same error count; inlining doesn't help.
That said, many stack traces point to getindex again, so maybe there's only a small number of errors remaining. I'll have another go at reducing vadd when I have some time.

Loading

@vchuravy
Copy link
Member Author

@vchuravy vchuravy commented Feb 4, 2019

I was planning on grabbing Jarrett this week and see if we can figure it out. (I am in the progress to add GPU support to Cthulhu so that should make it easier)

Loading

@vchuravy
Copy link
Member Author

@vchuravy vchuravy commented Feb 7, 2019

bors try

Loading

bors bot added a commit that referenced this issue Feb 7, 2019
@bors
Copy link
Contributor

@bors bors bot commented Feb 7, 2019

try

Build failed

Loading

@vchuravy
Copy link
Member Author

@vchuravy vchuravy commented Feb 7, 2019

Ok! The debugging session with Jarrett proved fruitful, we are down to 10ish failures :)

Loading

@maleadt
Copy link
Member

@maleadt maleadt commented Feb 8, 2019

Cool! What were the changes?

Loading

@jrevels
Copy link

@jrevels jrevels commented Feb 8, 2019

Cool! What were the changes?

We applied my usual Cassette issue workaround of "isolate the problematic thing and make it a contextual primitive (i.e. don't overdub into it)". The problematic thing here was the @pure function datatype_align.

It turns out that while Cassette propagates purity to the compiler correctly, the compiler is (probably rightfully) pessimistic and just bails out on purity optimization for generated functions (i.e. overdub). ref JuliaLang/julia#31012, which is my naive attempt at changing the compiler to allow this sort of thing. If that lands, we can remove the extra contextual primitive definition here.

Loading

end
end

contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...)
Copy link
Member Author

@vchuravy vchuravy Feb 11, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
contextualize(f::F) where F = (args...) -> Cassette.overdub(cudactx, f, args...)
contextualize(f::F) where F = (args...) -> (Cassette.overdub(cudactx, f, args...); return nothing)

We could go back to automatically returning nothing here

Loading

@vchuravy
Copy link
Member Author

@vchuravy vchuravy commented Feb 11, 2019

bors try

Loading

bors bot added a commit that referenced this issue Feb 11, 2019
@maleadt
Copy link
Member

@maleadt maleadt commented Jul 23, 2019

Rebased and removed the 265 hacks now that we have JuliaLang/julia#32237.
Nothing wrong with depending on a day-old alpha, right?

EDIT: oh wow everything broke. @vchuravy and ideas? Looks like it the transform "works", just plenty of dynamic invocations:

julia> kernel(a) = (a[1] = 1; nothing)
kernel (generic function with 2 methods)

julia> kernel([0])

julia> CUDAnative.contextualize(kernel)([0])

julia> @cuda kernel(cu([0]))
ERROR: InvalidIRError: compiling #148(CuDeviceArray{Float32,1,CUDAnative.AS.Global}) resulted in invalid LLVM IR
Reason: unsupported dynamic function invocation (call to overdub)
Stacktrace:
 [1] Val at essentials.jl:694
 [2] setindex! at /home/tbesard/Julia/pkg/CUDAnative/src/device/array.jl:84
 [3] kernel at REPL[5]:1
 [4] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75

Seems something has changed with Val. Adding it to the list of pure overdub exceptions (@inline Cassette.overdub(::CUDACtx, ::typeof(Base.Val), x) = return Base.Val(x)) doesn't help.
Are those still needed anyhow now that we have JuliaLang/julia#31012?

Loading

@maleadt
Copy link
Member

@maleadt maleadt commented Jul 24, 2019

MWE:

@generated bar(::Val{align}) where {align} = :(42)
foo(i) = i+bar(Val(1))

using Cassette

function transform(ctx, ref)
    CI = ref.code_info
    noinline = any(@nospecialize(x) ->
                       Core.Compiler.isexpr(x, :meta) &&
                       x.args[1] == :noinline,
                   CI.code)
    CI.inlineable = !noinline

    CI.ssavaluetypes = length(CI.code)
    Core.Compiler.validate_code(CI)
    return CI
end
const InlinePass = Cassette.@pass transform

Cassette.@context Ctx
const ctx = Cassette.disablehooks(Ctx(pass = InlinePass))
contextualize(f::F) where F = (args...) -> Cassette.overdub(ctx, f, args...)

using InteractiveUtils
code_llvm(foo, Tuple{Int})
code_llvm(contextualize(foo), Tuple{Int})

Doesn't need the addition, but otherwise we get a const jlapi function. Also doesn't need the inlining pass, but otherwise the LLVM contains a call to overdub, while it now clearly shows a dynamic invocation:

;  @ /home/tbesard/Julia/wip2.jl:2 within `foo'
define i64 @julia_foo_15985(i64) {
top:
; ┌ @ int.jl:53 within `+'
   %1 = add i64 %0, 42
; └
  ret i64 %1
}

;  @ /home/tbesard/Julia/wip2.jl:22 within `#7'
define i64 @"julia_#7_16073"(i64) {
top:
  %1 = alloca %jl_value_t addrspace(10)*, i32 2
; ┌ @ /home/tbesard/Julia/wip2.jl:2 within `foo'
; │┌ @ essentials.jl:694 within `Val'
    %2 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %1, i32 0
    store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140490737042584 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %2
    %3 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %1, i32 1
    store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140490826585840 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %3
    %4 = call nonnull %jl_value_t addrspace(10)* @jl_apply_generic(%jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140490737042232 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %1, i32 2)
; │└
; │┌ @ int.jl:53 within `+'
; ││┌ @ /home/tbesard/Julia/pkg/Cassette/src/overdub.jl:465 within `_overdub_fallback'
; │││┌ @ /home/tbesard/Julia/pkg/Cassette/src/context.jl:445 within `fallback'
; ││││┌ @ /home/tbesard/Julia/pkg/Cassette/src/context.jl:447 within `call'
       %5 = add i64 %0, 42
; └└└└└
  ret i64 %5
}

Bisected to JuliaLang/julia#31012

Loading

@maleadt
Copy link
Member

@maleadt maleadt commented Jul 25, 2019

265 looks broken:

automatic recompilation: Test Failed at /home/tbesard/Julia/pkg/CUDAnative/test/device/execution.jl:320
  Expression: (Array(arr))[] == 2
   Evaluated: 1 == 2
Stacktrace:
 [1] top-level scope at /home/tbesard/Julia/pkg/CUDAnative/test/device/execution.jl:320
 [2] top-level scope at /home/tbesard/Julia/julia-dev/build/release/usr/share/julia/stdlib/v1.3/Test/src/Test.jl:1114
 [3] top-level scope at /home/tbesard/Julia/pkg/CUDAnative/test/device/execution.jl:304
 [4] top-level scope at /home/tbesard/Julia/julia-dev/build/release/usr/share/julia/stdlib/v1.3/Test/src/Test.jl:1114
 [5] top-level scope at /home/tbesard/Julia/pkg/CUDAnative/test/device/execution.jl:154
 [6] top-level scope at /home/tbesard/Julia/julia-dev/build/release/usr/share/julia/stdlib/v1.3/Test/src/Test.jl:1114
 [7] top-level scope at /home/tbesard/Julia/pkg/CUDAnative/test/device/execution.jl:5
non-isbits arguments: Error During Test at /home/tbesard/Julia/pkg/CUDAnative/test/device/execution.jl:324

Also the jl_f_tuple failure, and worse, the stack trace points to overdub:

non-isbits arguments: Error During Test at /home/tbesard/Julia/pkg/CUDAnative/test/device/execution.jl:324
  Got exception outside of a @test
  InvalidIRError: compiling #148(Type{Int64}, Int64) resulted in invalid LLVM IR
  Reason: unsupported call to the Julia runtime (call to jl_f_tuple)
  Stacktrace:
   [1] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75
  Reason: unsupported call to the Julia runtime (call to jl_f_getfield)
  Stacktrace:
   [1] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75

This one has a better on-device stack trace:

pow: Error During Test at /home/tbesard/Julia/pkg/CUDAnative/test/device/cuda.jl:46
  Got exception outside of a @test
  InvalidIRError: compiling #148(CuDeviceArray{Float32,0,CUDAnative.AS.Global}, Float32, Int64) resulted in invalid LLVM IR
  Reason: unsupported call to the Julia runtime (call to jl_f_tuple)
  Stacktrace:
   [1] overdub at /home/tbesard/Julia/pkg/Cassette/src/overdub.jl:508
   [2] checked_trunc_sint at boot.jl:582
   [3] toInt32 at boot.jl:619
   [4] Int32 at boot.jl:709
   [5] pow at /home/tbesard/Julia/pkg/CUDAnative/src/device/cuda/math.jl:193
   [6] pow_kernel at /home/tbesard/Julia/pkg/CUDAnative/test/device/cuda.jl:50
   [7] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75
  Reason: unsupported call to the Julia runtime (call to jl_f_getfield)
  Stacktrace:
   [1] throw_inexacterror at boot.jl:560
   [2] overdub at /home/tbesard/Julia/pkg/Cassette/src/overdub.jl:0
   [3] checked_trunc_sint at boot.jl:582
   [4] toInt32 at boot.jl:619
   [5] Int32 at boot.jl:709
   [6] pow at /home/tbesard/Julia/pkg/CUDAnative/src/device/cuda/math.jl:193
   [7] pow_kernel at /home/tbesard/Julia/pkg/CUDAnative/test/device/cuda.jl:50
   [8] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75

Some more, that fail for all their types:

T = Int32: Error During Test at /home/tbesard/Julia/pkg/CUDAnative/test/device/cuda.jl:780
  Got exception outside of a @test
  InvalidIRError: compiling #148(Type{Int32}, CuDeviceArray{Int32,1,CUDAnative.AS.Global}) resulted in invalid LLVM IR
  Reason: unsupported call to the Julia runtime (call to jl_f_tuple)
  Stacktrace:
   [1] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75
  Reason: unsupported call to the Julia runtime (call to jl_f_getfield)
  Stacktrace:
   [1] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75
T = UInt64: Error During Test at /home/tbesard/Julia/pkg/CUDAnative/test/device/cuda.jl:880
  Got exception outside of a @test
  InvalidIRError: compiling #148(Type{UInt64}, CuDeviceArray{UInt64,1,CUDAnative.AS.Global}) resulted in invalid LLVM IR
  Reason: unsupported call to the Julia runtime (call to jl_f_tuple)
  Stacktrace:
   [1] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75
  Reason: unsupported call to the Julia runtime (call to jl_f_getfield)
  Stacktrace:
   [1] #148 at /home/tbesard/Julia/pkg/CUDAnative/src/context.jl:75
Test Summary:                                                  | Pass  Fail  Error  Total
CUDAnative                                                     |  348     5     32    385
  base interface                                               |                    No tests
  pointer                                                      |   20                  20
  code generation                                              |   90                  90
  code generation (relying on a device)                        |    6                   6
  execution                                                    |   70     5      1     76
    @cuda                                                      |    8     4            12
      low-level interface                                      |                    No tests
      launch configuration                                     |                    No tests
      compilation params                                       |    1                   1
      reflection                                               |    6     3             9
      shared memory                                            |                    No tests
      streams                                                  |                    No tests
      external kernels                                         |                    No tests
      calling device function                                  |                    No tests
    argument passing                                           |   27     1      1     29
      manually allocated                                       |    1                   1
      scalar through single-value array                        |    1                   1
      scalar through single-value array, using device function |    1                   1
      tuples                                                   |    1                   1
      ghost function parameters                                |    2                   2
      immutables                                               |    1                   1
      automatic recompilation                                  |    1     1             2
      non-isbits arguments                                     |                 1      1
      splatting                                                |    3                   3
      object invoke                                            |    1                   1
      closures                                                 |    1                   1
      conversions                                              |    8                   8
      argument count                                           |    4                   4
      keyword arguments                                        |                    No tests
      captured values                                          |    2                   2
    exceptions                                                 |   17                  17
    shmem divergence bug                                       |    7                   7
    dynamic parallelism                                        |   10                  10
    cooperative groups                                         |    1                   1
  pointer                                                      |   41                  41
  device arrays                                                |   20                  20
  CUDA functionality                                           |   87           31    118
    indexing                                                   |    1                   1
    math                                                       |    9            1     10
      pow                                                      |    8            1      9
    formatted output                                           |    6                   6
    @cuprint                                                   |   24                  24
    assertion                                                  |                    No tests
    shared memory                                              |   14                  14
    data movement and conversion                               |    5                   5
    clock and nanosleep                                        |                    No tests
    parallel synchronization and communication                 |   16                  16
    libcudadevrt                                               |                    No tests
    atomics (low-level)                                        |   12                  12
    atomics (high-level)                                       |                30     30
      add                                                      |                 6      6
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
        T = Float32                                            |                 1      1
        T = Float64                                            |                 1      1
      sub                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      and                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      or                                                       |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      xor                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      max                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      min                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
  examples                                                     |    6                   6

Loading

@vchuravy
Copy link
Member Author

@vchuravy vchuravy commented Jul 25, 2019

Will look into these after my talk (especially the 265 one is disappointing). Having a reproducer with MiniCassette would be great.

Loading

@maleadt
Copy link
Member

@maleadt maleadt commented Jul 25, 2019

Yeah no hurry, just reporting in that your fix works. Once merged we'll have CI at least.

Loading

@vchuravy vchuravy reopened this Jul 26, 2019
@maleadt
Copy link
Member

@maleadt maleadt commented Jul 26, 2019

Never mind, Pkg and I were confused about which version of Cassette we were using.
Let's try again.

Loading

@maleadt maleadt force-pushed the vc/cassette branch 2 times, most recently from b0ef268 to e8de056 Jul 26, 2019
@maleadt
Copy link
Member

@maleadt maleadt commented Jul 26, 2019

Turns out we weren't properly contextualizing all methods. Penalty: 20 additional failures.

Test Summary:                                                  | Pass  Fail  Error  Total
CUDAnative                                                     |  323    24     33    380
  base interface                                               |                    No tests
  pointer                                                      |   20                  20
  code generation                                              |   63    21      1     85
    LLVM IR                                                    |   30     7            37
      basic reflection                                         |    4     1             5
      unbound typevars                                         |          1             1
      exceptions                                               |    3                   3
      sysimg                                                   |    1                   1
      child functions                                          |          1             1
      JuliaLang/julia#21121                                    |    1                   1
      kernel functions                                         |   14     2            16
        wrapper function aggregate rewriting                   |          2             2
        property_annotations                                   |   14                  14
      LLVM D32593                                              |                    No tests
      kernel names                                             |    2     2             4
      PTX TBAA                                                 |    5                   5
      tracked pointers                                         |                    No tests
      CUDAnative.jl#278                                        |                    No tests
    PTX assembly                                               |   16     8            24
      basic reflection                                         |    3                   3
      child functions                                          |          1             1
      kernel functions                                         |    6     2             8
        property_annotations                                   |    5                   5
      idempotency                                              |                    No tests
      child function reuse                                     |          2             2
      child function reuse bis                                 |                    No tests
      indirect sysimg function use                             |    2                   2
      compile for host after PTX                               |    1                   1
      LLVM intrinsics                                          |                    No tests
      kernel names                                             |    2     2             4
      exception arguments                                      |                    No tests
      GC and TLS lowering                                      |    2                   2
      float boxes                                              |          1             1
    errors                                                     |   17     6      1     24
      recursion                                                |          2             2
      base intrinsics                                          |          1      1      2
      non-isbits arguments                                     |   10                  10
      invalid LLVM IR                                          |    2                   2
      invalid LLVM IR (ccall)                                  |    2                   2
      delayed bindings                                         |    2                   2
      dynamic call (invoke)                                    |          2             2
      dynamic call (apply)                                     |    1     1             2
  code generation (relying on a device)                        |    6                   6
  execution                                                    |   72     3      1     76
    @cuda                                                      |    9     3            12
      low-level interface                                      |                    No tests
      launch configuration                                     |                    No tests
      compilation params                                       |    1                   1
      reflection                                               |    6     3             9
      shared memory                                            |                    No tests
      streams                                                  |                    No tests
      external kernels                                         |                    No tests
      calling device function                                  |                    No tests
    argument passing                                           |   28            1     29
      manually allocated                                       |    1                   1
      scalar through single-value array                        |    1                   1
      scalar through single-value array, using device function |    1                   1
      tuples                                                   |    1                   1
      ghost function parameters                                |    2                   2
      immutables                                               |    1                   1
      automatic recompilation                                  |    2                   2
      non-isbits arguments                                     |                 1      1
      splatting                                                |    3                   3
      object invoke                                            |    1                   1
      closures                                                 |    1                   1
      conversions                                              |    8                   8
      argument count                                           |    4                   4
      keyword arguments                                        |                    No tests
      captured values                                          |    2                   2
    exceptions                                                 |   17                  17
    shmem divergence bug                                       |    7                   7
    dynamic parallelism                                        |   10                  10
    cooperative groups                                         |    1                   1
  pointer                                                      |   41                  41
  device arrays                                                |   20                  20
  CUDA functionality                                           |   87           31    118
    indexing                                                   |    1                   1
    math                                                       |    9            1     10
      pow                                                      |    8            1      9
    formatted output                                           |    6                   6
    @cuprint                                                   |   24                  24
    assertion                                                  |                    No tests
    shared memory                                              |   14                  14
    data movement and conversion                               |    5                   5
    clock and nanosleep                                        |                    No tests
    parallel synchronization and communication                 |   16                  16
    libcudadevrt                                               |                    No tests
    atomics (low-level)                                        |   12                  12
    atomics (high-level)                                       |                30     30
      add                                                      |                 6      6
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
        T = Float32                                            |                 1      1
        T = Float64                                            |                 1      1
      sub                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      and                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      or                                                       |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      xor                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      max                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
      min                                                      |                 4      4
        T = Int32                                              |                 1      1
        T = Int64                                              |                 1      1
        T = UInt32                                             |                 1      1
        T = UInt64                                             |                 1      1
  examples                                                     |    6                   6
ERROR: LoadError: Some tests did not pass: 323 passed, 24 failed, 33 errored, 0 broken.

Loading

@maleadt
Copy link
Member

@maleadt maleadt commented Jan 20, 2020

Squashed and rebased. Added a contextualize toggle to disable for tests that rely in naming. Maybe that mechanism would also be a way to merge this in a disabled state.

Remaining failures are almost all dynamic calls to jl_f_tuple and jl_f_getfield.

Loading

@maleadt
Copy link
Member

@maleadt maleadt commented Jan 20, 2020

At least one source of those issues is the dynamic dispatch that gets introduced when passing a type. MWE:

using Cassette
Cassette.@context Noop
contextualize(f::F) where F = (args...) -> Cassette.overdub(Noop(), f, args...)

function main()
    a = [0]


    function kernel(ptr)
        unsafe_store!(ptr, 1)
        return
    end

    contextualize(kernel)(pointer(a))
    code_llvm(contextualize(kernel), Tuple{Ptr{Int}})


    function kernel(T, ptr)
        unsafe_store!(ptr, T(1))
        return
    end

    contextualize(kernel)(Int, pointer(a))
    code_llvm(contextualize(kernel), Tuple{Type{Int}, Ptr{Int}})
end
define void @"julia_#34_19961"(i64) {
top:
  %1 = inttoptr i64 %0 to i64*
  store i64 1, i64* %1, align 1
  ret void
}

define void @"julia_#34_19962"(%jl_value_t addrspace(10)* nonnull, i64) {
top:
  %2 = alloca %jl_value_t addrspace(10)*, i32 2
  %gcframe = alloca %jl_value_t addrspace(10)*, i32 3
  %3 = bitcast %jl_value_t addrspace(10)** %gcframe to i8*
  call void @llvm.memset.p0i8.i32(i8* %3, i8 0, i32 24, i32 0, i1 false)
  %thread_ptr = call i8* asm "movq %fs:0, $0", "=r"()
  %ptls_i8 = getelementptr i8, i8* %thread_ptr, i64 -15712
  %ptls = bitcast i8* %ptls_i8 to %jl_value_t***
  %4 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 0
  %5 = bitcast %jl_value_t addrspace(10)** %4 to i64*
  store i64 2, i64* %5
  %6 = getelementptr %jl_value_t**, %jl_value_t*** %ptls, i32 0
  %7 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 1
  %8 = bitcast %jl_value_t addrspace(10)** %7 to %jl_value_t***
  %9 = load %jl_value_t**, %jl_value_t*** %6
  store %jl_value_t** %9, %jl_value_t*** %8
  %10 = bitcast %jl_value_t*** %6 to %jl_value_t addrspace(10)***
  store %jl_value_t addrspace(10)** %gcframe, %jl_value_t addrspace(10)*** %10
  %11 = bitcast %jl_value_t*** %ptls to i8*
  %12 = call noalias nonnull %jl_value_t addrspace(10)* @jl_gc_pool_alloc(i8* %11, i32 1400, i32 16) #1
  %13 = bitcast %jl_value_t addrspace(10)* %12 to %jl_value_t addrspace(10)* addrspace(10)*
  %14 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)* addrspace(10)* %13, i64 -1
  store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140060992907120 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)* addrspace(10)* %14
  %15 = bitcast %jl_value_t addrspace(10)* %12 to i64 addrspace(10)*
  store i64 %1, i64 addrspace(10)* %15, align 8
  %16 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
  store %jl_value_t addrspace(10)* %12, %jl_value_t addrspace(10)** %16
  %17 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %2, i32 0
  store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140060989456256 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %17
  %18 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %2, i32 1
  store %jl_value_t addrspace(10)* %12, %jl_value_t addrspace(10)** %18
  %19 = call nonnull %jl_value_t addrspace(10)* @jl_f_tuple(%jl_value_t addrspace(10)* addrspacecast (%jl_value_t* null to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %2, i32 2)
  %20 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 2
  store %jl_value_t addrspace(10)* %19, %jl_value_t addrspace(10)** %20
  %21 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %2, i32 0
  store %jl_value_t addrspace(10)* %19, %jl_value_t addrspace(10)** %21
  %22 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %2, i32 1
  store %jl_value_t addrspace(10)* addrspacecast (%jl_value_t* inttoptr (i64 140060897026208 to %jl_value_t*) to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %22
  %23 = call nonnull %jl_value_t addrspace(10)* @jl_f_getfield(%jl_value_t addrspace(10)* addrspacecast (%jl_value_t* null to %jl_value_t addrspace(10)*), %jl_value_t addrspace(10)** %2, i32 2)
  %24 = bitcast %jl_value_t addrspace(10)* %23 to i64* addrspace(10)*
  %25 = load i64*, i64* addrspace(10)* %24, align 8
  store i64 1, i64* %25, align 1
  %26 = getelementptr %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %gcframe, i32 1
  %27 = load %jl_value_t addrspace(10)*, %jl_value_t addrspace(10)** %26
  %28 = getelementptr %jl_value_t**, %jl_value_t*** %ptls, i32 0
  %29 = bitcast %jl_value_t*** %28 to %jl_value_t addrspace(10)**
  store %jl_value_t addrspace(10)* %27, %jl_value_t addrspace(10)** %29
  ret void
}

code_warntype looks identical:

Variables
  #self#::Core.Compiler.Const(var"#34#35"{var"#kernel#50"}(var"#kernel#50"()), false)
  args::Tuple{Ptr{Int64}}

Body::Nothing
1 ─ %1 = Cassette.overdub::Core.Compiler.Const(Cassette.overdub, false)
│   %2 = Main.Noop()::Core.Compiler.Const(Cassette.Context{nametype(Noop),Nothing,Nothing,Cassette.var"##PassType#424",Nothing,Nothing}(nametype(Noop)(), nothing, nothing, Cassette.var"##PassType#424"(), nothing, nothing), false)
│   %3 = Core.getfield(#self#, :f)::Core.Compiler.Const(var"#kernel#50"(), false)
│   %4 = Core.tuple(%2, %3)::Core.Compiler.Const((Cassette.Context{nametype(Noop),Nothing,Nothing,Cassette.var"##PassType#424",Nothing,Nothing}(nametype(Noop)(), nothing, nothing, Cassette.var"##PassType#424"(), nothing, nothing), var"#kernel#50"()), false)
│   %5 = Core._apply(%1, %4, args)::Core.Compiler.Const(nothing, false)
└──      return %5
Variables
  #self#::Core.Compiler.Const(var"#34#35"{var"#kernel#50"}(var"#kernel#50"()), false)
  args::Core.Compiler.PartialStruct(Tuple{DataType,Ptr{Int64}}, Any[Core.Compiler.Const(Int64, false), Ptr{Int64}])

Body::Nothing
1 ─ %1 = Cassette.overdub::Core.Compiler.Const(Cassette.overdub, false)
│   %2 = Main.Noop()::Core.Compiler.Const(Cassette.Context{nametype(Noop),Nothing,Nothing,Cassette.var"##PassType#424",Nothing,Nothing}(nametype(Noop)(), nothing, nothing, Cassette.var"##PassType#424"(), nothing, nothing), false)
│   %3 = Core.getfield(#self#, :f)::Core.Compiler.Const(var"#kernel#50"(), false)
│   %4 = Core.tuple(%2, %3)::Core.Compiler.Const((Cassette.Context{nametype(Noop),Nothing,Nothing,Cassette.var"##PassType#424",Nothing,Nothing}(nametype(Noop)(), nothing, nothing, Cassette.var"##PassType#424"(), nothing, nothing), var"#kernel#50"()), false)
│   %5 = Core._apply(%1, %4, args)::Core.Compiler.Const(nothing, false)
└──      return %5

And there's no inference failure when looking with Ctulhu:

│ ─ %-1  = invoke #34(::Ptr{Int64})::Core.Compiler.Const(nothing, false)
CodeInfo(
    @ /tmp/wip.jl:3 within `#34'
1 ─ %1 = (getfield)(args, 1)::Ptr{Int64}
│  ┌ @ /tmp/wip.jl:10 within `kernel'
│  │┌ @ pointer.jl:118 within `unsafe_store!' @ pointer.jl:118
│  ││┌ @ /home/tim/Julia/pkg/Cassette/src/overdub.jl:481 within `_overdub_fallback'
│  │││┌ @ /home/tim/Julia/pkg/Cassette/src/context.jl:445 within `fallback'
│  ││││┌ @ /home/tim/Julia/pkg/Cassette/src/context.jl:447 within `call'
│  │││││      (pointerset)(%1, 1, 1, 1)::Ptr{Int64}
│  └└└└└
└──      return
)
Select a call to descend into or ↩ to ascend. [q]uit.
Toggles: [o]ptimize, [w]arn, [d]ebuginfo, [s]yntax highlight for LLVM/Native.
Show: [L]LVM IR, [N]ative code
Advanced: dump [P]arams cache.

 • ↩


│ ─ %-1  = invoke #34(::Type{Int64},::Ptr{Int64})::Core.Compiler.Const(nothing, false)
CodeInfo(
    @ /tmp/wip.jl:3 within `#34'
1 ─ %1 = (getfield)(args, 2)::Ptr{Int64}
│  ┌ @ /tmp/wip.jl:19 within `kernel'
│  │┌ @ pointer.jl:118 within `unsafe_store!' @ pointer.jl:118
│  ││┌ @ /home/tim/Julia/pkg/Cassette/src/overdub.jl:481 within `_overdub_fallback'
│  │││┌ @ /home/tim/Julia/pkg/Cassette/src/context.jl:445 within `fallback'
│  ││││┌ @ /home/tim/Julia/pkg/Cassette/src/context.jl:447 within `call'
│  │││││      (pointerset)(%1, 1, 1, 1)::Ptr{Int64}
│  └└└└└
└──      return
)
Select a call to descend into or ↩ to ascend. [q]uit.
Toggles: [o]ptimize, [w]arn, [d]ebuginfo, [s]yntax highlight for LLVM/Native.
Show: [L]LVM IR, [N]ative code
Advanced: dump [P]arams cache.

 • ↩

Loading

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Linked issues

Successfully merging this pull request may close these issues.

3 participants