Bounds checking + shared memory sometimes causes invalid results #4

Open
maleadt opened this Issue Sep 16, 2016 · 28 comments

maleadt commented Sep 16, 2016

The cause seems to be an added checkbounds call, if that even makes sense.

Repro:

using CUDAdrv, CUDAnative

@target ptx function kernel(arr::Ptr{Int32})
    temp = @cuStaticSharedMem(Int32, (2, 1))
    tx = Int(threadIdx().x)

    if tx == 1
        for i = 1:2
            # THIS BREAKS STUFF: checkbounds(temp, i)
            Base.pointerset(temp.ptr, 1, i, 8)
        end
    end
    sync_threads()

    Base.pointerset(arr, Base.pointerref(temp.ptr, tx, 8), tx, 8)

    return nothing
end

dev = CuDevice(0)
ctx = CuContext(dev)

d_arr = CuArray(Int32, (2, 1))
@cuda (1,2) kernel(d_arr.ptr)
println(Array(d_arr))

destroy(ctx)

Result without checkbounds: [1; 1]. With: [1; 0].
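As a sanity check, this is what the kernel is supposed to compute, modelled sequentially on the CPU (Python; purely illustrative, with made-up names, and no GPU execution semantics are claimed):

```python
# Illustrative CPU model of the kernel above: "thread 1" fills both shared
# slots, then (after the barrier) every thread copies its own slot to the
# output array. Both output slots should therefore end up as 1.
def emulate(nthreads=2):
    shared = [None, None]      # stands in for @cuStaticSharedMem(Int32, (2, 1))
    out = [None] * nthreads

    # tx == 1 branch: a single thread initializes shared memory
    for i in range(2):
        shared[i] = 1          # Base.pointerset(temp.ptr, 1, i, 8)

    # after sync_threads(), every thread reads its own element
    for tx in range(nthreads):
        out[tx] = shared[tx]

    return out

print(emulate())  # expected: [1, 1]
```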

cc @cfoket

@maleadt maleadt self-assigned this Sep 16, 2016

@maleadt maleadt added the bug label Sep 16, 2016

maleadt commented Oct 5, 2016

No longer reproducible with this code, but rodinia/lud.jl still fails with --check-bounds=yes, probably due to the same underlying issue.


maleadt commented Oct 26, 2016

New repro, again using shared memory + bounds checking, but this time the invalid value is the result of __shfl_down (not touching shared memory at all):

using CUDAdrv, CUDAnative

function kernel(ptr::Ptr{Cint})
    shared = @cuStaticSharedMem(Cint, 4)

    lane = (threadIdx().x-1) % warpsize

    if lane == 0
        @boundscheck Base.checkbounds(shared, threadIdx().x)
        unsafe_store!(shared.ptr, 0, threadIdx().x)
    end

    sync_threads()

    val = shfl_down(Cint(32), 1, 4)
    if lane == 0
        unsafe_store!(ptr, val)
    end

    return
end

dev = CuDevice(0)
ctx = CuContext(dev)

gpu_val = CuArray(Cint, 1)
@cuda dev (1,4) kernel(gpu_val.ptr)
val = Array(gpu_val)[1]
println(val)

destroy(ctx)

Returns 0 with checkbounds, 32 without.
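Since every lane feeds the same constant into the shuffle, lane 0 should receive 32 regardless of bounds checking. A toy model of shfl_down's lane indexing (Python; illustrative only — it models an out-of-segment source as a self-copy, and ignores inactive lanes, whose results the hardware leaves undefined):

```python
def shfl_down(values, delta, width):
    # lane i reads from lane i + delta within its width-sized segment;
    # an out-of-segment source leaves the lane's own value unchanged
    n = len(values)
    out = []
    for i in range(n):
        src = i + delta
        seg_end = (i // width + 1) * width
        out.append(values[src] if src < min(seg_end, n) else values[i])
    return out

lanes = [32, 32, 32, 32]           # every lane contributes Cint(32)
print(shfl_down(lanes, 1, 4)[0])   # lane 0 expected value: 32
```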

@maleadt maleadt changed the title from Shared memory changes not visible for all threads to Bounds checking + shared memory sometimes causes invalid results Oct 26, 2016


maleadt commented Oct 26, 2016

Managed to reduce this to two versions of LLVM IR, executed using the following snippet:

using CUDAdrv, CUDAnative, LLVM

dev = CuDevice(0)
ctx = CuContext(dev)

for ir_fn in ["bug-working.ll", "bug-broken.ll"]
    gpu_val = CuArray(Cint[42])

    ir = readstring(ir_fn)
    mod = parse(LLVM.Module, ir)
    fn = "kernel"
    entry = get(functions(mod), "kernel")
    ptx = CUDAnative.mcgen(mod, entry, v"3.0")

    cuda_mod = CuModule(ptx)
    cuda_fun = CuFunction(cuda_mod, fn)

    cudacall(cuda_fun, 1, 4, (Ptr{Cint},), gpu_val.ptr)

    val = Array(gpu_val)[1]
    println(val)
end

destroy(ctx)

Working IR:

target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
target triple = "nvptx64-nvidia-cuda"

@shmem = internal addrspace(3) global [4 x i32] zeroinitializer, align 4

define void @kernel(i32*) {
top:
  %1 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %2 = and i32 %1, 31
  %3 = icmp eq i32 %2, 0
  br i1 %3, label %lane0_boundscheck, label %sync_shfl

lane0_boundscheck:
  %4 = icmp ugt i32 %1, 3
  br i1 %4, label %lane0_oob, label %lane0_shmem

lane0_oob:
  tail call void @llvm.trap()
  unreachable

sync_shfl:
  tail call void @llvm.nvvm.barrier0()
  %5 = tail call i32 @llvm.nvvm.shfl.down.i32(i32 32, i32 1, i32 7199)
  br i1 %3, label %lane0_writeback, label %end

lane0_shmem:
  %6 = getelementptr [4 x i32], [4 x i32] addrspace(3)* @shmem, i32 0, i32 %1
  store i32 0, i32 addrspace(3)* %6, align 8
  br label %sync_shfl

lane0_writeback:
  store i32 %5, i32* %0, align 8
  br label %end

end:
  ret void
}

declare void @llvm.trap()
declare i32 @llvm.nvvm.shfl.down.i32(i32, i32, i32)
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
declare void @llvm.nvvm.barrier0()

Broken IR:

target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
target triple = "nvptx64-nvidia-cuda"

@shmem = internal addrspace(3) global [4 x i32] zeroinitializer, align 4

define void @kernel(i32*) {
top:
  %1 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %2 = and i32 %1, 31
  %3 = icmp eq i32 %2, 0
  br i1 %3, label %lane0_boundscheck, label %sync_shfl

lane0_boundscheck:
  %4 = icmp ugt i32 %1, 3
  br i1 %4, label %lane0_oob, label %lane0_shmem

sync_shfl:
  tail call void @llvm.nvvm.barrier0()
  %5 = tail call i32 @llvm.nvvm.shfl.down.i32(i32 32, i32 1, i32 7199)
  br i1 %3, label %lane0_writeback, label %end

lane0_oob:
  tail call void @llvm.trap()
  unreachable

lane0_shmem:
  %6 = getelementptr [4 x i32], [4 x i32] addrspace(3)* @shmem, i32 0, i32 %1
  store i32 0, i32 addrspace(3)* %6, align 8
  br label %sync_shfl

lane0_writeback:
  store i32 %5, i32* %0, align 8
  br label %end

end:
  ret void
}

declare void @llvm.trap()
declare i32 @llvm.nvvm.shfl.down.i32(i32, i32, i32)
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
declare void @llvm.nvvm.barrier0()

That's right: the only difference between the two is the placement of the out-of-bounds (oob) basic block...
cc @cfoket


maleadt commented Oct 26, 2016

One layer deeper...

Working PTX:

.version 3.2
.target sm_30
.address_size 64

.visible .entry kernel(
        .param .u64 output  // single int output
)
{
        .reg .pred      %p<4>;
        .reg .b32       %r<6>;
        .reg .b64       %rd<6>;
        .shared .align 4 .b8 shmem[16];     // 4 integers
        ld.param.u64    %rd1, [output];

        // calculate lane, check if 0
        mov.u32 %r1, %tid.x;
        and.b32         %r2, %r1, 31;
        setp.ne.s32     %p1, %r2, 0;
        @%p1 bra        BB_SHFL;

        // bounds check for shmem access
        setp.lt.u32     %p2, %r1, 4;
        @%p2 bra        BB_SHMEM;
        bra.uni         BB_OOB;
BB_SHMEM:
        mul.wide.s32    %rd2, %r1, 4;
        mov.u64         %rd3, shmem;
        add.s64         %rd4, %rd3, %rd2;
        mov.u32         %r4, 0;
        st.shared.u32   [%rd4], %r4;
BB_SHFL:
        setp.eq.s32     %p3, %r2, 0;
        bar.sync        0;
        mov.u32         %r5, 32;
        shfl.down.b32 %r3, %r5, 1, 7199;
        @%p3 bra        BB_WRITEBACK;
        bra.uni         BB_END;
BB_WRITEBACK:
        cvta.to.global.u64      %rd5, %rd1;
        st.global.u32   [%rd5], %r3;
BB_END:
        ret;
BB_OOB:
        trap;
}

Broken PTX:

.version 3.2
.target sm_30
.address_size 64

.visible .entry kernel(
        .param .u64 output  // single int output
)
{
        .reg .pred      %p<4>;
        .reg .b32       %r<6>;
        .reg .b64       %rd<6>;
        .shared .align 4 .b8 shmem[16];     // 4 integers
        ld.param.u64    %rd1, [output];

        // calculate lane, check if 0
        mov.u32 %r1, %tid.x;
        and.b32         %r2, %r1, 31;
        setp.ne.s32     %p1, %r2, 0;
        @%p1 bra        BB_SHFL;

        // bounds check for shmem access
        setp.gt.u32     %p2, %r1, 3;
        @%p2 bra        BB_OOB;
        bra.uni         BB_SHMEM;
BB_SHMEM:
        mul.wide.s32    %rd2, %r1, 4;
        mov.u64         %rd3, shmem;
        add.s64         %rd4, %rd3, %rd2;
        mov.u32         %r4, 0;
        st.shared.u32   [%rd4], %r4;
BB_SHFL:
        setp.eq.s32     %p3, %r2, 0;
        bar.sync        0;
        mov.u32         %r5, 32;
        shfl.down.b32 %r3, %r5, 1, 7199;
        @%p3 bra        BB_WRITEBACK;
        bra.uni         BB_END;
BB_WRITEBACK:
        cvta.to.global.u64      %rd5, %rd1;
        st.global.u32   [%rd5], %r3;
BB_END:
        ret;
BB_OOB:
        trap;
}

Loader:

using CUDAdrv

dev = CuDevice(0)
ctx = CuContext(dev)

fn = "kernel"

for name in ["bug-working", "bug-broken"]
    gpu_val = CuArray(Cint[42])

    ptx = readstring("$name.ptx")

    cuda_mod = CuModule(ptx)
    cuda_fun = CuFunction(cuda_mod, fn)

    cudacall(cuda_fun, 1, 4, (Ptr{Cint},), gpu_val.ptr)

    val = Array(gpu_val)[1]
    println(val)
end

destroy(ctx)

Only difference: the bounds-check branch (>3 or <4):

$ diff bug-working.ptx bug-broken.ptx
22,24c22,24
<         setp.lt.u32     %p2, %r1, 4;
<         @%p2 bra        BB_SHMEM;
<         bra.uni         BB_OOB;
---
>         setp.gt.u32     %p2, %r1, 3;
>         @%p2 bra        BB_OOB;
>         bra.uni         BB_SHMEM;

Probably an assembler bug.
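The two predicates really are complementary, so the assembler had a semantically identical branch to compile in both cases. A quick check (Python, illustrative):

```python
# The working PTX tests %r1 < 4 (branch to BB_SHMEM), the broken PTX tests
# %r1 > 3 (branch to BB_OOB) with the targets swapped. For unsigned values
# these predicates are exact complements, so both variants dispatch every
# thread index to the same basic block.
def working(tid):
    return "BB_SHMEM" if tid < 4 else "BB_OOB"

def broken(tid):
    return "BB_OOB" if tid > 3 else "BB_SHMEM"

samples = list(range(64)) + [2**31, 2**32 - 1]
assert all(working(t) == broken(t) for t in samples)
print("branch targets agree on all samples")
```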


maleadt commented Oct 26, 2016

Alternative loader, using ptxas to generate a cubin (in order to play with ptxas optimization flags, though they don't seem to matter):

using CUDAdrv

dev = CuDevice(0)
ctx = CuContext(dev)

fn = "kernel"

for name in ["kernel-working", "kernel-broken"]
    gpu_val = CuArray(Cint[42])

    run(`ptxas -arch=sm_61 -o $name.cubin $name.ptx`)

    cuda_mod = CuModule(read("$name.cubin"))
    cuda_fun = CuFunction(cuda_mod, fn)

    cudacall(cuda_fun, 1, 4, (Ptr{Cint},), gpu_val.ptr)

    val = Array(gpu_val)[1]
    println(val)
end

destroy(ctx)

maleadt commented Oct 26, 2016

Almost definitely looks like an assembler bug. See the following annotated & prettified Pascal SASS (sm_61):

Working version:

kernel:
.text.kernel:
        MOV R1, c[0x0][0x20];
        S2R R2, SR_TID.X;
        SSY `(BB_SHFL);         // push BB_SHFL on reconvergence stack

        // calculate lane, check if 0
        LOP32I.AND R0, R2, 0x1f;
        ISETP.NE.AND P0, PT, R0, RZ, PT;
    @P0 SYNC                    // not lane 0, pop BB_SHFL from reconvergence stack

        // bounds check for shmem access
        ISETP.LT.U32.AND P0, PT, R2, 0x4, PT;
    @P0 BRA `(BB_SHMEM);

//BB_OOB:
        BPT.TRAP 0x1;
        EXIT;
BB_SHMEM:
        SHL R2, R2, 0x2;
        STS [R2], RZ;
        SYNC                    // pop BB_SHFL from reconvergence stack
BB_SHFL:
        // check if lane 0
  {     ISETP.EQ.AND P0, PT, R0, RZ, PT;
        BAR.SYNC 0x0;        }
        // shuffle unconditionally
        MOV32I R0, 0x20;
        SHFL.DOWN PT, R0, R0, 0x1, 0x1c1f;
   @!P0 EXIT;                  // not lane 0, exit
//BB_WRITEBACK:
        MOV R2, c[0x0][0x140];
        MOV R3, c[0x0][0x144];
        STG.E [R2], R0;
        EXIT;
.BB_END:
        BRA `(.BB_END);

Broken version:

kernel:
.text.kernel:
        MOV R1, c[0x0][0x20];
        S2R R2, SR_TID.X;

        // calculate lane, check if 0
        LOP32I.AND R0, R2, 0x1f;
        ISETP.NE.AND P0, PT, R0, RZ, PT;
    @P0 BRA `(BB_SHFL);         // not lane 0, branch to BB_SHFL

        // bounds check for shmem access
        ISETP.GT.U32.AND P0, PT, R2, 0x3, PT;
    @P0 BRA `(BB_OOB);

//BB_SHMEM:
        SHL R2, R2, 0x2;
        STS [R2], RZ;
BB_SHFL:
        // check if lane 0
 {      ISETP.EQ.AND P0, PT, R0, RZ, PT;
        BAR.SYNC 0x0;        }
        // shuffle unconditionally
        MOV32I R0, 0x20;
        SHFL.DOWN PT, R0, R0, 0x1, 0x1c1f;
   @!P0 EXIT;                  // not lane 0, exit
//BB_WRITEBACK:
        MOV R2, c[0x0][0x140];
        MOV R3, c[0x0][0x144];
        STG.E [R2], R0;
        EXIT;
BB_OOB:
        BPT.TRAP 0x1;
        EXIT;
.L_3:
        BRA `(.L_3);
.L_18:

The broken version clearly messes up its reconvergence stack, not pushing anything on it despite multiple conditional branches (for some info on how this works, see this paper by Bialas and Strzelecki)...
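To make the failure mode concrete, here is a heavily simplified sketch of SIMT reconvergence (Python; a conceptual model under stated assumptions, not NVIDIA's documented semantics). A shuffle only exchanges data between lanes that are active together, so if lanes 1-3 reach the shuffle in a different divergent group than lane 0, lane 0's source lane is inactive and the result is garbage (observed here as 0):

```python
# Toy model: what lane 0 receives from shfl.down depends on whether its
# source lane (lane 0 + delta) is active in the same divergent group.
# With proper SSY/SYNC reconvergence, all four lanes execute the shuffle
# as one group; without it, lane 0 executes it alone.
def shfl_down_group(active, contrib, delta):
    # active: set of lanes executing the shuffle together with lane 0
    # contrib: the value each lane contributes (Cint(32) in the repro)
    src = 0 + delta
    return contrib if src in active else 0  # 0 stands in for "undefined"

reconverged = shfl_down_group(active={0, 1, 2, 3}, contrib=32, delta=1)
diverged    = shfl_down_group(active={0},          contrib=32, delta=1)
print(reconverged, diverged)  # 32 0, matching the observed results
```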


maleadt commented Oct 26, 2016

And a C++ loader, for reporting purposes.

#include <stdio.h>

#include <cuda.h>

#define CHECK(err) __check(err, __FILE__, __LINE__)
inline void __check(CUresult err, const char *file, const int line) {
  if (CUDA_SUCCESS != err) {
    const char *name, *descr;
    cuGetErrorName(err, &name);
    cuGetErrorString(err, &descr);
    fprintf(stderr, "CUDA error #%s: %s at %s:%i\n", name, descr, file, line);
    abort();
  }
}

int test(const char *path) {
  CUmodule mod;
  CHECK(cuModuleLoad(&mod, path));

  CUfunction fun;
  CHECK(cuModuleGetFunction(&fun, mod, "kernel"));

  int *gpu_val;
  CHECK(cuMemAlloc((CUdeviceptr*) &gpu_val, sizeof(int)));

  void *args[1] = {&gpu_val};
  CHECK(cuLaunchKernel(fun, 1, 1, 1, 4, 1, 1, 0, NULL, args, NULL));

  int val;
  CHECK(cuMemcpyDtoH(&val, (CUdeviceptr) gpu_val, sizeof(int)));

  CHECK(cuModuleUnload(mod));

  return val;
}

int main() {
  CHECK(cuInit(0));

  CUdevice dev;
  CHECK(cuDeviceGet(&dev, 0));

  CUcontext ctx;
  CHECK(cuCtxCreate(&ctx, 0, dev));

  printf("working: %d\n", test("kernel-working.ptx"));
  printf("broken: %d\n", test("kernel-broken.ptx"));

  CHECK(cuCtxDestroy(ctx));

  return 0;
}

Will probably submit this to NVIDIA soon, unless anybody still spots us doing something wrong.

@maleadt maleadt added the upstream label Oct 27, 2016


maleadt commented Oct 27, 2016

Reported this repro to NVIDIA, bug #1833004. Will disable bounds checking for the time being.

maleadt added a commit that referenced this issue Oct 27, 2016


vchuravy commented Oct 27, 2016

Could we fix this on the LLVM side? Any bugfix to the assembler is going to be deployed slowly.


maleadt commented Oct 27, 2016

I haven't figured out exactly which PTX pattern triggers the SASS emission bug; probably the branch to a trap BB. I've asked NVIDIA for some background on the bug (if they deem it one), so I'm going to wait for their response before sinking more time into this.


maleadt commented Nov 10, 2016

Status update from NVIDIA:

The following items have been modified for this bug:
 - Status changed from "Open - pending review" to "Open - Fix being tested"

... but I haven't got access to their bug tracker (I'm only on its CC list), so I can't look at or ask for more details 😕


jmaebe commented Nov 10, 2016

At least you know it is in fact their fault :)


maleadt commented Jan 6, 2017

The following items have been modified for this bug:

  • Status changed from "Open - Fix being tested" to "Closed - Fixed"

No idea how, or starting from which version, though (they still don't allow me access to the bug tracker).

maleadt added a commit that referenced this issue Mar 30, 2017


maleadt commented Apr 27, 2017

Revisited this issue. It seems to still be there, at least on NVIDIA driver 375.39, but I found that it only reproduces on sm_61 hardware or newer. I haven't heard back from NVIDIA, so I don't know which driver includes the fix, and the only system with sm_61 hardware I have is locked to driver 375.39...

Anyone with sm_61 hw on more recent drivers care to test this? I've updated the repro scripts too.


vchuravy commented Apr 27, 2017

I only have access to sm_60, but I could test it on that.


maleadt commented Apr 27, 2017

Great! Please send me the output (verify the bug is still there), SASS files generated by ptx.jl (remove existing ones first), and the driver version. No hurry though, it's not like we can do much about it. But given some extra data points, it might be possible to re-enable bounds checking...

vchuravy commented Apr 28, 2017

@vchuravy vchuravy self-assigned this May 8, 2017


maleadt commented Jul 3, 2017

Bug still there on 375.66 (current long-lived).

jlebar commented Oct 27, 2017

It looks like you've discovered https://bugs.llvm.org/show_bug.cgi?id=27738, or something related. Unfortunately we've gotten zero movement from NVIDIA on this in the ~1.5 years since we discovered it ourselves and brought it to their attention. It's possible that CUDA 9's ptxas will be better, but I don't expect a proper fix except inasmuch as "buy a Volta card and use the new sync intrinsics" is a fix.

Yours is the cleanest reduction of this bug I've seen, btw.

jlebar commented Mar 29, 2018

FYI, @timshen91 is rolling out an incomplete fix for this in LLVM, and working on the full fix. He'll post details in the bug.

Empirically, the partial fix he has in hand fixes this problem for everything we've seen on our end. We'd be curious to hear if it fixes anything for you all.

@maleadt

maleadt commented Mar 30, 2018

Oh cool, thanks for the ping! I'll have a look about reproducing, since it's a while ago since I last looked at this. We also mentioned this issue to NVIDIA and they were going to look into giving us more info; if that happens I'll update here.

@timshen91

timshen91 commented Mar 30, 2018

The partial fix is https://reviews.llvm.org/D45008 and https://reviews.llvm.org/D45070. Once they are committed, I'll update with the revision number that needs to be synced past.

@timshen91

timshen91 commented Apr 2, 2018

Any LLVM checkout at revision r328885 or later should include my partial fix.

I tried to reproduce the bug with the 367.48 nvcc and ptxas (but with a newer driver) and failed. I'll wait a short while for @maleadt and see what happens. :)

@maleadt

maleadt commented Apr 3, 2018

Similarly, I had to revert to 375.66, as I could not reproduce the issue on 384.111 (Debian stable BPO).

Testing on r329021, it seems like the bug is still there though (on sm_61).
I'll recreate a full non-Julia MWE here so that you can test for yourself:

working.ll:

target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
target triple = "nvptx64-nvidia-cuda"

@shmem = internal addrspace(3) global [4 x i32] zeroinitializer, align 4

define void @kernel(i32*) {
top:
  %1 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %2 = and i32 %1, 31
  %3 = icmp eq i32 %2, 0
  br i1 %3, label %lane0_boundscheck, label %sync_shfl

lane0_boundscheck:
  %4 = icmp ugt i32 %1, 3
  br i1 %4, label %lane0_oob, label %lane0_shmem

lane0_oob:
  tail call void @llvm.trap()
  unreachable

sync_shfl:
  tail call void @llvm.nvvm.barrier0()
  %5 = tail call i32 @llvm.nvvm.shfl.down.i32(i32 32, i32 1, i32 7199)
  br i1 %3, label %lane0_writeback, label %end

lane0_shmem:
  %6 = getelementptr [4 x i32], [4 x i32] addrspace(3)* @shmem, i32 0, i32 %1
  store i32 0, i32 addrspace(3)* %6, align 8
  br label %sync_shfl

lane0_writeback:
  store i32 %5, i32* %0, align 8
  br label %end

end:
  ret void
}

declare void @llvm.trap()
declare i32 @llvm.nvvm.shfl.down.i32(i32, i32, i32)
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
declare void @llvm.nvvm.barrier0()

!nvvm.annotations = !{!0}
!0 = !{void (i32*)* @kernel, !"kernel", i32 1}

broken.ll (only difference is the ordering of the lane0_oob and sync_shfl BBs):

target datalayout = "e-p:64:64:64-i1:8:8-i8:8:8-i16:16:16-i32:32:32-i64:64:64-f32:32:32-f64:64:64-v16:16:16-v32:32:32-v64:64:64-v128:128:128-n16:32:64"
target triple = "nvptx64-nvidia-cuda"

@shmem = internal addrspace(3) global [4 x i32] zeroinitializer, align 4

define void @kernel(i32*) {
top:
  %1 = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %2 = and i32 %1, 31
  %3 = icmp eq i32 %2, 0
  br i1 %3, label %lane0_boundscheck, label %sync_shfl

lane0_boundscheck:
  %4 = icmp ugt i32 %1, 3
  br i1 %4, label %lane0_oob, label %lane0_shmem

sync_shfl:
  tail call void @llvm.nvvm.barrier0()
  %5 = tail call i32 @llvm.nvvm.shfl.down.i32(i32 32, i32 1, i32 7199)
  br i1 %3, label %lane0_writeback, label %end

lane0_oob:
  tail call void @llvm.trap()
  unreachable

lane0_shmem:
  %6 = getelementptr [4 x i32], [4 x i32] addrspace(3)* @shmem, i32 0, i32 %1
  store i32 0, i32 addrspace(3)* %6, align 8
  br label %sync_shfl

lane0_writeback:
  store i32 %5, i32* %0, align 8
  br label %end

end:
  ret void
}

declare void @llvm.trap()
declare i32 @llvm.nvvm.shfl.down.i32(i32, i32, i32)
declare i32 @llvm.nvvm.read.ptx.sreg.tid.x()
declare void @llvm.nvvm.barrier0()

!nvvm.annotations = !{!0}
!0 = !{void (i32*)* @kernel, !"kernel", i32 1}

Compile to PTX:

llc -mcpu=sm_30 working.ll -o working.ptx
llc -mcpu=sm_30 broken.ll -o broken.ptx

Loader for PTX code:

#include <stdio.h>

#include <cuda.h>

#define CHECK(err) __check(err, __FILE__, __LINE__)
inline void __check(CUresult err, const char *file, const int line) {
  if (CUDA_SUCCESS != err) {
    const char *name, *descr;
    cuGetErrorName(err, &name);
    cuGetErrorString(err, &descr);
    fprintf(stderr, "CUDA error #%s: %s at %s:%i\n", name, descr, file, line);
    abort();
  }
}

int test(const char *path) {
  CUmodule mod;
  CHECK(cuModuleLoad(&mod, path));

  CUfunction fun;
  CHECK(cuModuleGetFunction(&fun, mod, "kernel"));

  int *gpu_val;
  CHECK(cuMemAlloc((CUdeviceptr*) &gpu_val, sizeof(int)));

  void *args[1] = {&gpu_val};
  CHECK(cuLaunchKernel(fun, 1, 1, 1, 4, 1, 1, 0, NULL, args, NULL));

  int val;
  CHECK(cuMemcpyDtoH(&val, (CUdeviceptr) gpu_val, sizeof(int)));

  CHECK(cuModuleUnload(mod));

  return val;
}

int main() {
  CHECK(cuInit(0));

  CUdevice dev;
  CHECK(cuDeviceGet(&dev, 0));

  CUcontext ctx;
  CHECK(cuCtxCreate(&ctx, 0, dev));

  printf("working: %d\n", test("working.ptx"));
  printf("broken: %d\n", test("broken.ptx"));

  CHECK(cuCtxDestroy(ctx));

  return 0;
}

Output:

$ clang++ ptx_loader.cpp -o ptx_loader -lcuda
$ ./ptx_loader
working: 32
broken: 0

Note that the generated PTX does differ between LLVM 6.0 and LLVM ToT, but the difference is identical for the working and broken versions:

--- working_6.0.ptx        2018-04-03 10:34:01.000000000 +0200
+++ working_ToT.ptx        2018-04-03 09:57:20.000000000 +0200
@@ -39,12 +39,12 @@
        mov.u32         %r5, 32;
        shfl.down.b32 %r3, %r5, 1, 7199;
        @%p3 bra        LBB0_5;
-// %bb.6:                               // %end
-       ret;
+       bra.uni         LBB0_6;
 LBB0_5:                                 // %lane0_writeback
        ld.param.u64    %rd2, [kernel_param_0];
        cvta.to.global.u64      %rd1, %rd2;
        st.global.u32   [%rd1], %r3;
+LBB0_6:                                 // %end
        ret;
 LBB0_2:                                 // %lane0_oob
        trap;
@jlebar

jlebar commented Apr 3, 2018

I had to revert to 375.66, as I could not reproduce the issue on 384.111 (Debian stable BPO).

I suspect that this is because the driver contains a copy of ptxas, so changing the driver version changes the ptxas version you're using. If you compiled all the way to SASS for your GPU (dunno if your frontend does this) ahead of time using ptxas, then the driver version shouldn't matter.

I can link you to how we do this in XLA if it'd be helpful.

Will leave the analysis here to @timshen91.
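A minimal sketch of the ahead-of-time flow described above, assuming the CUDA toolkit's ptxas is on PATH and an sm_61 GPU (the exact flags may need adjusting for your setup):

```shell
# Compile the PTX to SASS with the toolkit's ptxas instead of relying on
# the ptxas copy embedded in the driver.
llc -mcpu=sm_61 working.ll -o working.ptx
ptxas --gpu-name sm_61 working.ptx -o working.cubin
```

cuModuleLoad() accepts the resulting .cubin the same way it accepts a .ptx file, so the loader above should work unchanged when pointed at the cubin path.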

@timshen91

timshen91 commented Apr 4, 2018

I also reproduced the ptxas miscompile on sm_61 with ptxas 8.0. I modified the launcher to call kernel<<<...>>>(...), and link the pre-compiled ptx into the launcher.

It looks like the lane0_oob block breaks the region structure (roughly a single-entry, single-exit set of basic blocks) of the program control flow graph (CFG). It has a trap instruction.

I attempted four different variations:
a) add a ret after trap.
b) add a bra.uni THE_RET_BLOCK after trap.
c) At ptx level, "inline" the trapping block into the predecessor(s).
d) replace the trap with a ret.

(a) and (b) attempted to fix the control flow graph (CFG) region structure, but they didn't work. Both (c) and (d) work, but I can't extract a principled heuristic from (c) or (d). Hopefully the new ptxas fixes this kind of issue once and for all.

@maleadt

maleadt commented Apr 4, 2018

I suspect that this is because the driver contains a copy of ptxas, so changing the driver version changes the ptxas version you're using.

Yeah, I've been deliberately using the driver for this because I assume it to be faster than having to call ptxas (we generate code at run-time, so we care about compiler performance). But with issues like this one, #165 (device support of the driver's embedded ptxas not matching that of CUDA's ptxas, despite reporting the same version), and the fact that it's not possible to probe the embedded compiler's version in order to work around or guard against bugs like this one, maybe I should consider the manual approach.

It has a trap instruction.

Right, I assume this breaks the structured CFG requirement. I'll just avoid emitting trap for now, thanks for looking into alternatives though.

By the way, any suggestions on similar fatal error reporting mechanisms? trap isn't ideal, both because of this issue, and because it leaves CUDA in an unrecoverable state.
I guess XLA doesn't require such functionality though.

@jlebar
jlebar commented Apr 4, 2018

By the way, any suggestions on similar fatal error reporting mechanisms? trap isn't ideal, both because of this issue, and because it leaves CUDA in an unrecoverable state. I guess XLA doesn't require such functionality though.

XLA doesn't require this functionality at the moment, but we have talked about adding an assert/trap instruction to XLA. Our idea for implementing it was to use a global variable. Which is ugly for sure. But I'm not sure how to do the global variable and prevent future kernels from running. That's really what trap is for. I guess we could dereference a null pointer or something, although who knows what ptxas will do when it sees that. :-/
