In [None]:
#### IMPORTS ####
using ProgressMeter
using Plots
# Select the ggplot2 theme as the default look for the plots made below.
theme(:ggplot2)

In [None]:
"""
    dispersion(t, kx_arr, ky_arr)

Evaluate `-2t (cos kx + cos ky)` on every point of the grid `kx_arr × ky_arr` and return
the values as a flat vector with `ky` running fastest: entry `j * length(ky_arr) + i`
belongs to `(kx_arr[j+1], ky_arr[i])`.
"""
function dispersion(t, kx_arr, ky_arr)
    # The last `for` clause is the innermost loop, so ky varies fastest.
    return [-2 * t * (cos(kx) + cos(ky)) for kx in kx_arr for ky in ky_arr]
end


"""
    getDOS(num_kspace, dispersionArray)

Estimate a density-of-states weight at every point of a `num_kspace × num_kspace` grid
as `(1/δk) / |∇E|`, using forward differences with periodic wrap-around in both
directions.

`dispersionArray` is flat with index `j * num_kspace + i` (`j = 0:num_kspace-1`,
`i = 1:num_kspace`), the layout produced by `dispersion`. The step
`delta_k = 2π / num_kspace` cancels between numerator and denominator, so each entry
equals `1 / sqrt(ΔE_i^2 + ΔE_j^2)` up to rounding.

Entries where both differences vanish evaluate to `Inf`; they are replaced by the largest
non-`Inf` entry. If every entry is `Inf` (e.g. a constant `dispersionArray`), the `Inf`s
are returned unchanged instead of throwing on an empty `maximum`.
"""
function getDOS(num_kspace, dispersionArray)
    delta_k = 2 * pi / num_kspace
    kspaceDos = 1 / delta_k
    numPoints = num_kspace * num_kspace
    densityOfStates = zeros(numPoints)
    # Each (j, i) writes only its own entry, so rows can be processed in parallel.
    Threads.@threads for j in 0:num_kspace-1
        rowStart = j * num_kspace
        # First index of the next row, wrapping the last row back onto the first.
        nextRowStart = (rowStart + num_kspace) % numPoints
        for i in 1:num_kspace
            E_xy = dispersionArray[rowStart + i]
            E_xpy = dispersionArray[rowStart + i % num_kspace + 1]  # next i, wrapped
            E_xyp = dispersionArray[nextRowStart + i]               # next j, wrapped
            dE_xydk = sqrt((E_xpy - E_xy)^2 / delta_k^2 + (E_xyp - E_xy)^2 / delta_k^2)
            densityOfStates[rowStart + i] = kspaceDos / dE_xydk
        end
    end
    nonInfEntries = filter(!=(Inf), densityOfStates)
    # Guard: `maximum` of an empty collection throws.
    if !isempty(nonInfEntries)
        replace!(densityOfStates, Inf => maximum(nonInfEntries))
    end
    return densityOfStates
end


"""
    getContour(dispersionArray, num_kspace, energy)

Return the flat indices of grid points lying on the constant-energy contour `energy`.

`dispersionArray` is read as `num_kspace` rows of `num_kspace` consecutive entries. In
each row, point `i` is kept when `E - energy` changes sign (or is zero) between `i` and
its periodic successor `i % num_kspace + 1` in the same row. Only sign changes along the
within-row index are detected. Indices are returned row by row, ascending within each
row, as a `Vector{Int}`.
"""
function getContour(dispersionArray, num_kspace, energy)
    # One typed bucket per row: threads push into disjoint vectors, and the flattened
    # result is a concrete Vector{Int} instead of Vector{Any}.
    contourPoints = [Int[] for _ in 1:num_kspace]
    Threads.@threads for j in 0:num_kspace-1
        rowStart = j * num_kspace
        energyDiffArr = @view(dispersionArray[rowStart+1:rowStart+num_kspace]) .- energy
        for i in 1:num_kspace
            if energyDiffArr[i] * energyDiffArr[i % num_kspace + 1] <= 0
                push!(contourPoints[j+1], rowStart + i)
            end
        end
    end
    return collect(Iterators.flatten(contourPoints))
end


"""
    getRenorm(args)

Compute the one-step correction to the couplings `J(k1, k2)` for a fixed `k1` against
every `k2` in `k2_arr`.

`args` is the tuple
`(omega, k1, Ek1, k2_arr, Ek2_arr, cutoffPoints, energyCutoff, densityOfStates,
kondoJArray, flags_k1_arr, deltaD)`.

For each `k2` whose flag is non-zero, the denominators

    omega - |energyCutoff|/2 - Ek1/2 - Ek2/2 + (J(k1,q) + J(k2,q))/4,   q ∈ cutoffPoints

are formed. If all are negative, the correction is
`-Σ_q J(k1,q) J(k2,q) densityOfStates[q] deltaD / denominator_q`. Otherwise the pair is
frozen by setting `flags_k1_arr[k2] = 0` (the passed vector is mutated), and its
correction stays zero. Pairs whose flag is already `0` are skipped.

Returns `(k1, flags_k1_arr, renormArr)`, where `renormArr[i]` is the correction for
`k2_arr[i]`.
"""
function getRenorm(args)
    omega, k1, Ek1, k2_arr, Ek2_arr, cutoffPoints, energyCutoff, densityOfStates, kondoJArray, flags_k1_arr, deltaD = args
    renormArr = zeros(size(k2_arr))
    J_k1 = kondoJArray[k1, cutoffPoints]
    # Loop-invariant pieces, computed once instead of once per k2.
    dosCutoff = densityOfStates[cutoffPoints]
    baseDenominator = omega - abs(energyCutoff) / 2 - Ek1 / 2
    for (index, (k2, Ek2)) in enumerate(zip(k2_arr, Ek2_arr))
        # Frozen pairs get no correction; skip them before slicing J(k2, ·).
        flags_k1_arr[k2] == 0 && continue
        J_k2 = kondoJArray[k2, cutoffPoints]
        denominators = (baseDenominator - Ek2 / 2) .+ (J_k1 .+ J_k2) ./ 4
        if all(<(0), denominators)
            renormArr[index] = -sum(J_k1 .* J_k2 .* dosCutoff .* deltaD ./ denominators)
        else
            # Some denominator is not negative: stop renormalising this pair.
            flags_k1_arr[k2] = 0
        end
    end
    return k1, flags_k1_arr, renormArr
end


"""
    main(num_kspace, t, J_init)

Run the cutoff-lowering renormalisation of the momentum-resolved coupling `J(k1, k2)` on a
`num_kspace × num_kspace` momentum grid. Returns `(kondoJArray, dispersionArray)`.

Both momentum axes are `range(-pi, stop=pi, length=num_kspace)`, with both endpoints
included. NOTE(review): on a 2π-periodic band the first and last grid lines therefore
coincide — confirm that duplication is intended.

The cutoff is swept from `bandwidth = maximum(dispersionArray)` down to
`deltaD = bandwidth / (num_kspace - 1)` in steps of `deltaD`. At each step:

1. The previous layer of couplings is copied into the next one.
2. Every pair involving a point on the current cutoff contour is frozen (`flags = 0`).
   The sweep stops once every pair is frozen.
3. For every `k1` with `|E(k1)| < energyCutoff`, the `getRenorm` correction is added to
   `J(k1, k2)` over the same inner points. Any coupling whose sign flipped relative to
   the previous layer is set to zero.

`kondoJArray` has size `(num_kspace^2, num_kspace^2, L)` with `L ≥ num_kspace`
(normally `L == num_kspace`). Layer 1 holds `J_init`, and layer `s + 1` holds the
couplings after step `s`. Layers the sweep does not reach repeat the last computed layer.
"""
function main(num_kspace, t, J_init)
    kx_arr = range(-pi, stop=pi, length=num_kspace)
    ky_arr = copy(kx_arr)
    dispersionArray = dispersion(t, kx_arr, ky_arr)
    densityOfStates = getDOS(num_kspace, dispersionArray)

    bandwidth = maximum(dispersionArray)
    deltaD = bandwidth / (num_kspace - 1)
    cutoffs = range(bandwidth, stop=deltaD, step=-deltaD)
    # One layer for the initial couplings plus one per cutoff step, never fewer than
    # `num_kspace`, so indexing layer stepIndex+1 can never run past the array.
    numLayers = max(num_kspace, length(cutoffs) + 1)
    numPoints = num_kspace^2
    kondoJArray = Array{Float64}(undef, numPoints, numPoints, numLayers)
    kondoJArray[:, :, 1] .= J_init
    lastLayer = 1
    flags = fill(1, numPoints, numPoints)
    @showprogress for (stepIndex, energyCutoff) in collect(enumerate(cutoffs))
        kondoJArray[:, :, stepIndex+1] .= @view kondoJArray[:, :, stepIndex]
        lastLayer = stepIndex + 1
        cutoffPoints = getContour(dispersionArray, num_kspace, energyCutoff)
        flags[cutoffPoints, :] .= 0
        flags[:, cutoffPoints] .= 0
        if all(==(0), flags)
            break
        end
        omega = -energyCutoff / 2
        innerIndicesArr = findall(E -> abs(E) < energyCutoff, dispersionArray)
        innerEnergiesArr = dispersionArray[innerIndicesArr]
        # Layer `stepIndex` is only read during this step, so all tasks can share a view
        # instead of each copying the whole numPoints × numPoints matrix.
        previousJ = @view kondoJArray[:, :, stepIndex]
        Threads.@threads for (k1, Ek1) in collect(zip(innerIndicesArr, innerEnergiesArr))
            args = (omega, k1, Ek1, innerIndicesArr, innerEnergiesArr,
                cutoffPoints, energyCutoff, densityOfStates,
                previousJ, flags[k1, :], deltaD)
            _, flags_k1_arr, renormArr = getRenorm(args)
            # k1 values are distinct, so each task writes only its own row k1.
            flags[k1, :] = flags_k1_arr
            kondoJArray[k1, innerIndicesArr, stepIndex+1] += renormArr
            # Zero couplings whose sign flipped. Index through views: the form
            # `A[k1, idx, s][mask] .= 0` would only modify a temporary copy.
            oldRow = @view kondoJArray[k1, innerIndicesArr, stepIndex]
            newRow = @view kondoJArray[k1, innerIndicesArr, stepIndex+1]
            newRow[oldRow .* newRow .< 0] .= 0
        end
    end
    # Fill layers the sweep did not reach instead of leaving uninitialised memory. After
    # an early stop every pair is frozen, so further steps would not change J anyway.
    for layer in lastLayer+1:numLayers
        kondoJArray[:, :, layer] .= @view kondoJArray[:, :, lastLayer]
    end
    return kondoJArray, dispersionArray
end

In [None]:
# Run the flow on a 31×31 grid (t = 1, initial coupling 0.1), then plot on a log axis
# how J(k1, k2) changes across the stored layers for a few (k1, k2) pairs.
p = plot()
num_kspace = 31
t = 1
@time kondoJArray, dispersionArray = main(num_kspace, t, 0.1);
# Every 30th point of the E = t contour, paired with the 10th point of the E = 0 contour.
kFpoints = getContour(dispersionArray, num_kspace, t)[begin:30:end]
k2points = getContour(dispersionArray, num_kspace, 0)[[10]]
# Outer loop over k2points so p1 varies fastest, matching Iterators.product order.
for p2 in k2points, p1 in kFpoints
    plot!(p, kondoJArray[p1, p2, :]; yaxis=:log, linewidth=2, thickness_scaling=1.5)
end
display(p)