In [None]:
import Pkg
Pkg.add("KernelAbstractions")
Pkg.add("CUDA")

[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.11/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.11/Manifest.toml`
[32m[1m   Resolving[22m[39m package versions...
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.11/Project.toml`
[32m[1m  No Changes[22m[39m to `~/.julia/environments/v1.11/Manifest.toml`


In [20]:
using CUDA
using Statistics


# Streaming MTX loader (UNCHANGED)

"""
    load_mtx_as_csr_stream(filename::String)

Read a Matrix Market coordinate file line by line and return `(rowptr, colind)`
describing the symmetrised graph in CSR form.

Every entry `u v [value...]` contributes the edge `u -> v` and, when `u != v`, the
mirrored edge `v -> u`. Only the first two fields of an entry line are used, so
files that carry a value column are accepted.

# Returns
- `rowptr::Vector{Int32}` of length `nrows + 1` holding 0-based offsets
  (`rowptr[1] == 0`): the neighbours of vertex `v` are
  `colind[rowptr[v] + 1 : rowptr[v + 1]]`.
- `colind::Vector{Int32}` with the 1-based neighbour ids.

# Throws
- `ArgumentError` if the size line is missing, an entry line has fewer than two
  fields, or the number of stored edges does not fit in `Int32`.
"""
function load_mtx_as_csr_stream(filename::String)
    open(filename, "r") do io
        # The first line that is neither blank nor a '%' comment is the size header.
        dims = SubString{String}[]
        while !eof(io)
            candidate = strip(readline(io))
            (isempty(candidate) || startswith(candidate, "%")) && continue
            dims = split(candidate)
            break
        end
        length(dims) >= 3 || throw(ArgumentError(
            "missing 'nrows ncols nnz' size line in $filename"))

        nrows = parse(Int32, dims[1])
        # Parsed as a machine Int so that doubling it below cannot overflow Int32.
        nnz = parse(Int, dims[3])

        # nnz is only a capacity hint; the vectors grow if the file holds more entries.
        capacity = 2 * max(nnz, 0)
        edges_u = Int32[]
        edges_v = Int32[]
        sizehint!(edges_u, capacity)
        sizehint!(edges_v, capacity)

        for raw in eachline(io)
            entry = strip(raw)
            (isempty(entry) || startswith(entry, "%")) && continue

            fields = split(entry)
            length(fields) >= 2 || throw(ArgumentError(
                "malformed entry line in $filename: \"$entry\""))
            # Row and column are the first two fields; anything after is a value.
            u = parse(Int32, fields[1])
            v = parse(Int32, fields[2])

            push!(edges_u, u)
            push!(edges_v, v)
            if u != v
                # Store the reverse direction so the adjacency is symmetric.
                push!(edges_u, v)
                push!(edges_v, u)
            end
        end

        edge_count = length(edges_u)
        # Offsets are accumulated in Int32; refuse inputs that would wrap silently.
        edge_count <= typemax(Int32) || throw(ArgumentError(
            "$edge_count stored edges do not fit in Int32 CSR offsets"))

        # Build CSR: count the out-degree of every vertex, then prefix-sum in place.
        rowptr = zeros(Int32, nrows + 1)
        for u in edges_u
            rowptr[u + 1] += 1
        end
        for i in 1:nrows
            rowptr[i + 1] += rowptr[i]
        end

        # Scatter the neighbours; `cursor[u]` is the number of slots of row u
        # already filled plus the row's 0-based start offset.
        colind = Vector{Int32}(undef, edge_count)
        cursor = copy(rowptr)
        for k in 1:edge_count
            u, v = edges_u[k], edges_v[k]
            cursor[u] += 1
            colind[cursor[u]] = v
        end

        return rowptr, colind
    end
end


#  Binary loader / saver (UNCHANGED)

"""
    save_csr_binary(rowptr, colind, rowptr_file, colind_file)

Dump `rowptr` to `rowptr_file` and `colind` to `colind_file` as raw binary,
overwriting existing files. Returns the value of the second `write`, i.e. the
number of bytes written for `colind`.
"""
function save_csr_binary(rowptr, colind, rowptr_file, colind_file)
    open(stream -> write(stream, rowptr), rowptr_file, "w")
    return open(stream -> write(stream, colind), colind_file, "w")
end

# Read a whole file of raw Int32 values into a freshly allocated Vector{Int32}.
function _read_int32_vector(path)
    nbytes = filesize(path)
    nbytes % sizeof(Int32) == 0 || throw(ArgumentError(
        "size of $path ($nbytes bytes) is not a multiple of $(sizeof(Int32))"))
    return read!(path, Vector{Int32}(undef, nbytes ÷ sizeof(Int32)))
end

"""
    load_csr_binary(rowptr_file, colind_file)

Load the two raw `Int32` dumps written by `save_csr_binary` and return
`(rowptr, colind)`.

Both results are plain `Vector{Int32}` (the same container type the MTX loader
returns) rather than lazy reinterpret wrappers around byte buffers.

# Throws
- `ArgumentError` if a file size is not a multiple of 4 bytes.
"""
function load_csr_binary(rowptr_file, colind_file)
    rowptr = _read_int32_vector(rowptr_file)
    colind = _read_int32_vector(colind_file)
    return rowptr, colind
end


# PROPER Grid-Stride Kernel (Like your original but with variable threads)

"""
    cc_kernel_grid_stride!(rowptr, colind, label, changed, n)

CUDA kernel performing one sweep of minimum-label propagation over vertices
`1:n`.

Each thread starts at its global 1-based thread index and advances by the total
number of launched threads, so any launch size covers all `n` vertices. For each
vertex `v` it takes the minimum of `label[v]` and the labels of its neighbours;
when that is smaller, it stores it in `label[v]` and sets `changed[v] = 1`.

`rowptr` is expected to hold 0-based offsets (`rowptr[1] == 0`), as produced by
`load_mtx_as_csr_stream`: the neighbours of `v` are
`colind[rowptr[v] + 1 : rowptr[v + 1]]`.
"""
function cc_kernel_grid_stride!(rowptr, colind, label, changed, n)
    # blockIdx/threadIdx are 1-based, so this is already a valid vertex id (1, 2, ...).
    v = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    stride = blockDim().x * gridDim().x

    while v <= n
        current = label[v]
        best = current
        # 0-based CSR offsets -> 1-based positions in colind.
        for e in (rowptr[v] + 1):rowptr[v + 1]
            # NOTE(review): neighbour labels may be updated by other threads during
            # the same sweep; this appears to rely on labels only ever decreasing.
            best = min(best, label[colind[e]])
        end
        # label[v] is written only here, by the one thread that owns v in this sweep.
        if best < current
            label[v] = best
            changed[v] = 1
        end
        v += stride
    end
    return nothing
end


#  FAST GPU Driver (Actually fast - like your original)

"""
    connected_components_gpu_fast(rowptr_h, colind_h;
                                  threads_per_block=256, total_threads=nothing,
                                  verbose=false, check_interval=5)

Label connected components on the GPU by repeating `cc_kernel_grid_stride!`
sweeps until a sweep changes no label.

# Arguments
- `rowptr_h`, `colind_h`: host-side CSR arrays; the vertex count is
  `length(rowptr_h) - 1`.

# Keywords
- `threads_per_block`: threads per CUDA block.
- `total_threads`: requested total thread count; defaults to `min(n, 262144)`.
- `verbose`: print the launch configuration and the convergence iteration.
- `check_interval`: test for convergence every this many sweeps (default 5).

# Returns
`(labels, component_count, iterations)` where `labels` is a host `Vector{Int32}`.
An empty graph yields `(Int32[], 0, 0)`.

# Throws
- `ArgumentError` if `check_interval < 1`.
"""
function connected_components_gpu_fast(rowptr_h, colind_h;
                                       threads_per_block=256,
                                       total_threads=nothing,
                                       verbose=false,
                                       check_interval=5)
    check_interval >= 1 || throw(ArgumentError(
        "check_interval must be at least 1, got $check_interval"))

    n = Int32(length(rowptr_h) - 1)
    # Nothing to label; also avoids launching a kernel with zero blocks.
    n <= 0 && return Int32[], 0, 0

    if total_threads === nothing
        total_threads = min(n, 262144)
    end

    # At least one block, at most 65535; the kernel's stride loop covers the rest.
    blocks = max(1, min(65535, ceil(Int, total_threads / threads_per_block)))

    if verbose
        vertices_per_thread = ceil(Int, n / total_threads)
        println("Configuration:")
        println("  Total threads: $total_threads")
        println("  Blocks: $blocks × $threads_per_block = $(blocks*threads_per_block) total threads")
        println("  Vertices per thread: ~$vertices_per_thread")
    end

    # Transfer to GPU
    rowptr  = CuArray(rowptr_h)
    colind  = CuArray(colind_h)
    label   = CuArray(Int32.(1:n))
    changed = CuArray(zeros(Int32, n))

    iter = 0
    while true
        iter += 1
        fill!(changed, 0)

        @cuda threads=threads_per_block blocks=blocks cc_kernel_grid_stride!(
            rowptr, colind, label, changed, n
        )
        CUDA.synchronize()

        if iter % check_interval == 0
            # Reduce the WHOLE flag array on the device; only the scalar result is
            # copied back. Looking at a prefix of `changed` could report convergence
            # while vertices outside that prefix are still being relabelled.
            if sum(changed) == 0
                if verbose
                    println("Converged in $iter iterations")
                end
                break
            end
        end
    end

    # Get results
    labels_h = Array(label)
    cc_count = length(unique(labels_h))

    return labels_h, cc_count, iter
end


#  ULTIMATE FAST VERSION: fixed iteration count, no per-iteration convergence check (inputs are still uploaded once and labels downloaded once)

"""
    connected_components_gpu_ultimate(rowptr_h, colind_h;
                                      threads_per_block=256, total_threads=65536,
                                      fixed_iterations=10)

Run exactly `fixed_iterations` sweeps of minimum-label propagation on the GPU,
with no convergence test.

The labels are a correct component labelling only if `fixed_iterations` sweeps
are enough for the given graph; this function does not verify that.

# Arguments
- `rowptr_h`, `colind_h`: host-side CSR arrays; `rowptr_h` is expected to hold
  0-based offsets (`rowptr_h[1] == 0`), as produced by `load_mtx_as_csr_stream`.

# Returns
`(labels, distinct_label_count, fixed_iterations)` with `labels` a host array.
"""
function connected_components_gpu_ultimate(rowptr_h, colind_h;
                                           threads_per_block=256,
                                           total_threads=65536,
                                           fixed_iterations=10)
    n = Int32(length(rowptr_h) - 1)

    # Transfer to GPU
    rowptr = CuArray(rowptr_h)
    colind = CuArray(colind_h)
    label = CuArray(Int32.(1:n))

    # At least one block so the launch is always valid.
    blocks = max(1, min(65535, ceil(Int, total_threads / threads_per_block)))

    # One propagation sweep without a `changed` array: every vertex stores the
    # minimum of its own label and its neighbours' labels unconditionally.
    function cc_kernel_simple!(rowptr, colind, label, n)
        # blockIdx/threadIdx are 1-based, so this is already a vertex id.
        v = (blockIdx().x - 1) * blockDim().x + threadIdx().x
        stride = blockDim().x * gridDim().x

        while v <= n
            best = label[v]
            # 0-based CSR offsets -> 1-based positions in colind.
            for e in (rowptr[v] + 1):rowptr[v + 1]
                best = min(best, label[colind[e]])
            end
            label[v] = best
            v += stride
        end
        return nothing
    end

    for _ in 1:fixed_iterations
        @cuda threads=threads_per_block blocks=blocks cc_kernel_simple!(
            rowptr, colind, label, n
        )
        CUDA.synchronize()
    end

    # Get results
    labels_h = Array(label)
    cc_count = length(unique(labels_h))

    return labels_h, cc_count, fixed_iterations
end

# REAL Benchmark Function (Testing different thread counts)

"""
    benchmark_thread_scaling(rowptr_h, colind_h; warmup_runs=1, measure_runs=2)

Time `connected_components_gpu_ultimate` (10 fixed iterations) for a series of
total-thread counts and print a per-configuration report plus a summary table
sorted by average time.

Each measurement is the wall-clock time of one full call, which includes the
host-to-device uploads and the final label download done by that function.

# Keywords
- `warmup_runs`: untimed calls per configuration before measuring.
- `measure_runs`: timed calls per configuration.

# Returns
A vector of `(description, total_threads, vertices_per_thread, avg_time,
std_time)` tuples, sorted by `avg_time`.
"""
function benchmark_thread_scaling(rowptr_h, colind_h;
                                  warmup_runs=1,
                                  measure_runs=2)
    n = Int32(length(rowptr_h) - 1)
    println("\n" * "="^60)
    println("THREAD SCALING BENCHMARK")
    println("Graph: $n vertices, $(length(colind_h)) edges")
    println("="^60)

    # (total_threads, description)
    configs = Tuple{Int,String}[
        (1024, "1K threads"),
        (2048, "2K threads"),
        (4096, "4K threads"),
        (8192, "8K threads"),
        (16384, "16K threads"),
        (32768, "32K threads"),
        (65536, "64K threads"),
        (131072, "128K threads"),
        (262144, "256K threads"),
        (524288, "512K threads"),
        (1048576, "1M threads"),
    ]

    # One thread per vertex, capped at 4M threads.
    push!(configs, (Int(min(n, 1048576 * 4)), "Max threads (like original)"))

    results = Tuple{String,Int,Int,Float64,Float64}[]
    baseline_components = nothing

    println("\nUsing ULTIMATE version (fixed 10 iterations, no GPU-CPU transfers)")

    for (total_threads, desc) in configs
        # Skip unusable (empty graph) or excessively large thread counts.
        if total_threads < 1 || total_threads > 10_000_000
            continue
        end

        println("\n▶ $desc")
        vertices_per_thread = ceil(Int, n / total_threads)
        println("   Each thread processes ~$vertices_per_thread vertices")

        # Warmup (also triggers kernel compilation outside the timed region)
        for _ in 1:warmup_runs
            connected_components_gpu_ultimate(
                rowptr_h, colind_h;
                total_threads=total_threads,
                fixed_iterations=10
            )
        end

        # Measurement
        times = Float64[]

        for run in 1:measure_runs
            t_start = time()
            _, cc_count, _ = connected_components_gpu_ultimate(
                rowptr_h, colind_h;
                total_threads=total_threads,
                fixed_iterations=10
            )
            elapsed = time() - t_start
            push!(times, elapsed)

            # The first measured run is the reference; later disagreement means
            # the fixed iteration count did not give a stable labelling.
            if baseline_components === nothing
                baseline_components = cc_count
            elseif cc_count != baseline_components
                @warn "Component count differs from the first measured run" desc cc_count baseline_components
            end

            println("   Run $run: $(round(elapsed, digits=3))s")
        end

        avg_time = mean(times)
        std_time = std(times)  # NaN when measure_runs == 1
        throughput = n / avg_time / 1e6

        push!(results, (desc, total_threads, vertices_per_thread, avg_time, std_time))

        println("   Average: $(round(avg_time, digits=3)) ± $(round(std_time, digits=3)) s")
        println("   Throughput: $(round(throughput, digits=2)) M vertices/sec")
    end

    # Summary table
    println("\n" * "="^70)
    println("THREAD SCALING RESULTS")
    println("="^70)
    println("Rank | Threads      | V/T   | Time (s)   | Throughput (M/s)")
    println("-"^70)

    sort!(results, by=entry -> entry[4])  # Sort by average time

    for (rank, (desc, thread_count, vpt, mean_seconds, spread)) in enumerate(results)
        throughput = n / mean_seconds / 1e6
        println(rpad("$rank", 5) * "| " *
                rpad(desc, 12) * "| " *
                rpad(string(vpt), 6) * "| " *
                rpad("$(round(mean_seconds, digits=3))", 11) * "| " *
                rpad("$(round(throughput, digits=2))", 17))
    end

    return results
end


#  Compare ALL methods

# Baseline kernel: one thread per vertex, a single propagation step per launch.
# `rowptr` holds 0-based offsets (rowptr[1] == 0) into the 1-based `colind`.
function _original_cc_kernel!(rowptr, colind, label, changed, n)
    # blockIdx/threadIdx are 1-based, so this is already a vertex id.
    v = (blockIdx().x - 1) * blockDim().x + threadIdx().x
    if v <= n
        current = label[v]
        best = current
        for e in (rowptr[v] + 1):rowptr[v + 1]
            best = min(best, label[colind[e]])
        end
        if best < current
            label[v] = best
            changed[v] = 1
        end
    end
    return nothing
end

# Baseline driver: launches one thread per vertex and, every 5 iterations, copies
# the entire `changed` array to the host to test for convergence.
function _original_connected_components_gpu(rowptr_h, colind_h)
    n = Int32(length(rowptr_h) - 1)
    rowptr  = CuArray(rowptr_h)
    colind  = CuArray(colind_h)
    label   = CuArray(Int32.(1:n))
    changed = CuArray(zeros(Int32, n))

    threads = 256
    blocks = max(1, ceil(Int, n / threads))  # never launch zero blocks

    iter = 0
    while true
        iter += 1
        fill!(changed, 0)

        @cuda threads=threads blocks=blocks _original_cc_kernel!(
            rowptr, colind, label, changed, n
        )
        CUDA.synchronize()

        if iter % 5 == 0
            # Full-array copy is deliberate: it is part of what this baseline measures.
            if sum(Array(changed)) == 0
                break
            end
        end
    end

    labels_h = Array(label)
    cc_count = length(unique(labels_h))

    return labels_h, cc_count, iter
end

"""
    compare_all_methods(rowptr, colind)

Run the three GPU connected-components variants on the same host CSR arrays,
print time, iteration count and component count for each, and print the
pairwise time ratios.

The helpers are top-level functions rather than closures so that their local
`rowptr`, `colind` and `n` cannot rebind this function's arguments.

The first timed method also pays for any compilation triggered by its first
call, because no warm-up run is performed here.

# Returns
`(t_original, t_fast, t_ultimate)` wall-clock seconds.
"""
function compare_all_methods(rowptr, colind)
    println("\n" * "="^60)
    println("COMPARING ALL METHODS")
    println("="^60)

    # Method 1: Original style (but in CUDA.jl)
    println("\n1. Original style (1 thread per vertex, with checks):")
    t1 = time()
    labels1, count1, iter1 = _original_connected_components_gpu(rowptr, colind)
    t2 = time()
    println("   Time: $(round(t2-t1, digits=3))s")
    println("   Iterations: $iter1")
    println("   Components: $count1")

    # Method 2: Our "fast" version
    println("\n2. Fast version (grid-stride, sampled checks):")
    t3 = time()
    labels2, count2, iter2 = connected_components_gpu_fast(
        rowptr, colind;
        total_threads=65536,
        verbose=false
    )
    t4 = time()
    println("   Time: $(round(t4-t3, digits=3))s")
    println("   Iterations: $iter2")
    println("   Components: $count2")

    # Method 3: Ultimate version (fastest)
    println("\n3. Ultimate version (fixed iterations, no checks):")
    t5 = time()
    labels3, count3, iter3 = connected_components_gpu_ultimate(
        rowptr, colind;
        total_threads=65536,
        fixed_iterations=10
    )
    t6 = time()
    println("   Time: $(round(t6-t5, digits=3))s")
    println("   Iterations: $iter3")
    println("   Components: $count3")

    # Only the component counts are compared, not the label vectors themselves.
    if count1 == count2 && count2 == count3
        println("\n✅ All methods give correct result: $count1 components")
    else
        println("\n❌ Results differ!")
    end

    # Speedup analysis
    println("\n" * "="^60)
    println("SPEEDUP ANALYSIS")
    println("="^60)
    println("Original → Fast: $(round((t2-t1)/(t4-t3), digits=2))x faster")
    println("Original → Ultimate: $(round((t2-t1)/(t6-t5), digits=2))x faster")
    println("Fast → Ultimate: $(round((t4-t3)/(t6-t5), digits=2))x faster")

    return (t2-t1, t4-t3, t6-t5)
end


# Main Execution

"""
    main(; benchmark=false, compare=false, single_run=true,
           mtx_file=..., rowptr_file=..., colind_file=...)

Load the graph and run one of three modes.

The graph is read from the binary CSR files when both exist; otherwise it is
parsed from `mtx_file` and the binary files are written for next time.

# Keywords
- `benchmark`: run `benchmark_thread_scaling` (checked first).
- `compare`: run `compare_all_methods` (checked second).
- `single_run`: run `connected_components_gpu_ultimate` once and print a summary
  (checked last).
- `mtx_file`, `rowptr_file`, `colind_file`: input/cache paths; the defaults are
  the locations previously hard-coded in this function.

# Returns
The result of the selected mode for `benchmark`/`compare`, otherwise `nothing`.
"""
function main(; benchmark=false, compare=false, single_run=true,
              mtx_file="/content/sample_data/graph.mtx",
              rowptr_file="/content/sample_data/friendster_rowptr.bin",
              colind_file="/content/sample_data/friendster_colind.bin")
    # Load graph
    if isfile(rowptr_file) && isfile(colind_file)
        println("Loading CSR from binary files...")
        t0 = time()
        rowptr, colind = load_csr_binary(rowptr_file, colind_file)
        t1 = time()
        println("Loaded in $(round(t1-t0, digits=3)) s")
    else
        println("Binary CSR not found. Loading MTX...")
        t0 = time()
        rowptr, colind = load_mtx_as_csr_stream(mtx_file)
        t1 = time()
        println("Loaded MTX in $(round(t1-t0, digits=3)) s")

        println("Saving CSR binary...")
        save_csr_binary(rowptr, colind, rowptr_file, colind_file)
    end

    n = length(rowptr) - 1
    println("\nGraph: $n vertices, $(length(colind)) edges")

    if benchmark
        return benchmark_thread_scaling(rowptr, colind)
    elseif compare
        return compare_all_methods(rowptr, colind)
    elseif single_run
        println("\n" * "="^50)
        println("RUNNING ULTIMATE VERSION")
        println("="^50)

        t2 = time()
        labels, cc_count, iterations = connected_components_gpu_ultimate(
            rowptr, colind;
            total_threads=65536,
            fixed_iterations=10
        )
        t3 = time()

        println("\nRESULTS:")
        println("Algorithm time: $(round(t3-t2, digits=3)) s")
        # Measured from the start of loading, so it also covers saving the cache.
        println("Total time (with loading): $(round(t3-t0, digits=3)) s")
        println("Iterations: $iterations")
        println("Connected components: $cc_count")
        println("First 10 labels: ", labels[1:min(10, end)])
    end
    return nothing
end

# Cell entry point: announce the run, then execute the thread-scaling benchmark mode.
println("Starting thread scaling benchmark...")
main(; benchmark=true)

# For single run:
# main(single_run=true)

# For comparison:
# main(compare=true)

Starting thread scaling benchmark...
Loading CSR from binary files...
Loaded in 0.166 s

Graph: 65608366 vertices, 9300602 edges

THREAD SCALING BENCHMARK
Graph: 65608366 vertices, 9300602 edges

Using ULTIMATE version (fixed 10 iterations, no GPU-CPU transfers)

▶ 1K threads
   Each thread processes ~64071 vertices


LoadError: InterruptException: