Skip to content

Commit

Permalink
Make Plots an optional dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
mfherbst committed Oct 7, 2020
1 parent d9880b5 commit 380965e
Show file tree
Hide file tree
Showing 9 changed files with 106 additions and 81 deletions.
5 changes: 2 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PeriodicTable = "7b2266bf-644c-5ea3-82d8-af4bbd25a884"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Polynomials = "f27b6e38-b328-58d1-80ce-0feddd5e7a45"
Primes = "27ebfcd6-29c5-5fa9-bf4b-fb8fc14df3ae"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
Expand Down Expand Up @@ -55,7 +54,6 @@ NLsolve = "4"
Optim = "0.22, 1"
OrderedCollections = "1"
PeriodicTable = "1"
Plots = "1"
Polynomials = "1"
Primes = "0.4, 0.5"
ProgressMeter = "1"
Expand All @@ -73,8 +71,9 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
DoubleFloats = "497a8b3b-efae-58df-a0af-a86822472b78"
IntervalArithmetic = "d1acc4aa-44c8-5952-acd4-ba5d80a2a253"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Aqua", "DoubleFloats", "IntervalArithmetic", "Random", "KrylovKit"]
test = ["Test", "Aqua", "DoubleFloats", "IntervalArithmetic", "Plots", "Random", "KrylovKit"]
8 changes: 1 addition & 7 deletions examples/collinear_magnetism.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,7 @@ idown = iup + length(scfres.basis.kpoints) ÷ 2
# around the Fermi level, where the spin-up and spin-down DOS differ.

using Plots
εs = range(minimum(minimum(scfres.eigenvalues)) - .5,
maximum(maximum(scfres.eigenvalues)) + .5, length=1000)
Dup = DOS.(εs, Ref(basis), Ref(scfres.eigenvalues), spins=(1, )) # DOS spin-up
Ddown = DOS.(εs, Ref(basis), Ref(scfres.eigenvalues), spins=(2, )) # DOS spin-down
q = plot(εs, Dup, label="DOS :up", color=:blue)
plot!(q, εs, Ddown, label="DOS :down", color=:red)
vline!(q, [scfres.εF], label="εF", color=:green, lw=1.5)
plot_dos(scfres)

# Similarly the band structure shows clear differences between both spin components.
plot_bandstructure(scfres, kline_density=3, unit=:eV)
7 changes: 1 addition & 6 deletions examples/metallic_systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,4 @@ scfres.energies

# The fact that magnesium is a metal is confirmed
# by plotting the density of states around the Fermi level.

εs = range(minimum(minimum(scfres.eigenvalues)) - .5,
maximum(maximum(scfres.eigenvalues)) + .5, length=1000)
Ds = DOS.(εs, Ref(basis), Ref(scfres.eigenvalues))
q = plot(εs, Ds, label="DOS")
vline!(q, [scfres.εF], label="εF")
plot_dos(scfres)
3 changes: 2 additions & 1 deletion src/DFTK.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,13 +156,13 @@ include("external/pymatgen.jl")

export high_symmetry_kpath
export compute_bands
export plot_band_data
export plot_bandstructure
include("postprocess/band_structure.jl")

export DOS
export LDOS
export NOS
export plot_dos
include("postprocess/DOS.jl")
export compute_χ0
export apply_χ0
Expand All @@ -184,6 +184,7 @@ function __init__()
@require DoubleFloats="497a8b3b-efae-58df-a0af-a86822472b78" begin
!isdefined(DFTK, :GENERIC_FFT_LOADED) && include("fft_generic.jl")
end
@require Plots="91a5bcdd-55d7-5caf-9e0b-520d859cae80" include("plotting.jl")
end

end # module DFTK
88 changes: 88 additions & 0 deletions src/plotting.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import Plots

# This is needed to flag that the plots-dependent code has been loaded
const PLOTS_LOADED = true

"""
Plot the trace of an SCF, i.e. the absolute error of the total energy at
each iteration versus the converged energy in a semilog plot. By default
a new plot canvas is generated, but an existing one can be passed and reused
along with `kwargs` for the call to `plot!`.
"""
function ScfPlotTrace(plt=Plots.plot(yaxis=:log); kwargs...)
energies = Float64[]
function callback(info)
if info.stage == :finalize
minenergy = minimum(energies[max(1, end-5):end])
error = abs.(energies .- minenergy)
error[error .== 0] .= NaN
extra = ifelse(:mark in keys(kwargs), (), (mark=:x, ))
Plots.plot!(plt, error; extra..., kwargs...)
display(plt)
else
push!(energies, info.energies.total)
end
end
end


function plot_band_data(band_data; εF=nothing,
klabels=Dict{String, Vector{Float64}}(), unit=:eV, kwargs...)
eshift = isnothing(εF) ? 0.0 : εF
data = prepare_band_data(band_data, klabels=klabels)

# For each branch, plot all bands, spins and errors
p = Plots.plot(xlabel="wave vector")
for ibranch = 1:data.n_branches
kdistances = data.kdistances[ibranch]
for spin in data.spins, iband = 1:data.n_bands
yerror = nothing
if hasproperty(data, :λerror)
yerror = data.λerror[ibranch][spin][iband, :] ./ unit_to_au(unit)
end
energies = (data.λ[ibranch][spin][iband, :] .- eshift) ./ unit_to_au(unit)

color = (spin == :up) ? :blue : :red
Plots.plot!(p, kdistances, energies; color=color, label="", yerror=yerror,
kwargs...)
end
end

# X-range: 0 to last kdistance value
Plots.xlims!(p, (0, data.kdistances[end][end]))
Plots.xticks!(p, data.ticks["distance"],
[replace(l, raw"$\mid$" => " | ") for l in data.ticks["label"]])

ylims = [-4, 4]
!isnothing(εF) && is_metal(band_data, εF) && (ylims = [-10, 10])
ylims = round.(ylims * units.eV ./ unit_to_au(unit), sigdigits=2)
if isnothing(εF)
Plots.ylabel!(p, "eigenvalues ($(string(unit))")
else
Plots.ylabel!(p, "eigenvalues - ε_f ($(string(unit)))")
Plots.ylims!(p, ylims...)
end

p
end


function plot_dos(basis, eigenvalues; εF=nothing)
n_spin = basis.model.n_spin_components
εs = range(minimum(minimum(eigenvalues)) - .5,
maximum(maximum(eigenvalues)) + .5, length=1000)

p = Plots.plot()
spinlabels = spin_components(basis.model)
colors = [:blue, :red]
for σ in 1:n_spin
D = DOS.(εs, Ref(basis), Ref(eigenvalues), spins=(σ, ))
label = n_spin > 1 ? "DOS $(spinlabels[σ]) spin" : "DOS"
Plots.plot!(p, εs, D, label=label, color=colors[σ])
end
if !isnothing(εF)
Plots.vline!(p, [εF], label="εF", color=:green, lw=1.5)
end
p
end
plot_dos(scfres; kwargs...) = plot_dos(scfres.basis, scfres.eigenvalues; εF=scfres.εF, kwargs...)
5 changes: 5 additions & 0 deletions src/postprocess/DOS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,8 @@ function LDOS(ε, basis, eigenvalues, ψ; smearing=basis.model.smearing,
end
return sum(ρs[iσ] forin spins)
end

"""
Plot the density of states over a reasonable range
"""
function plot_dos end
47 changes: 5 additions & 42 deletions src/postprocess/band_structure.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using PyCall
import Plots

# Functionality for computing band structures, mostly using pymatgen

Expand Down Expand Up @@ -102,47 +101,6 @@ function is_metal(band_data, εF, tol=1e-4)
false
end


function plot_band_data(band_data; εF=nothing,
klabels=Dict{String, Vector{Float64}}(), unit=:eV, kwargs...)
eshift = isnothing(εF) ? 0.0 : εF
data = prepare_band_data(band_data, klabels=klabels)

# For each branch, plot all bands, spins and errors
p = Plots.plot(xlabel="wave vector")
for ibranch = 1:data.n_branches
kdistances = data.kdistances[ibranch]
for spin in data.spins, iband = 1:data.n_bands
yerror = nothing
if hasproperty(data, :λerror)
yerror = data.λerror[ibranch][spin][iband, :] ./ unit_to_au(unit)
end
energies = (data.λ[ibranch][spin][iband, :] .- eshift) ./ unit_to_au(unit)

color = (spin == :up) ? :blue : :red
Plots.plot!(p, kdistances, energies; color=color, label="", yerror=yerror,
kwargs...)
end
end

# X-range: 0 to last kdistance value
Plots.xlims!(p, (0, data.kdistances[end][end]))
Plots.xticks!(p, data.ticks["distance"],
[replace(l, raw"$\mid$" => " | ") for l in data.ticks["label"]])

ylims = [-4, 4]
!isnothing(εF) && is_metal(band_data, εF) && (ylims = [-10, 10])
ylims = round.(ylims * units.eV ./ unit_to_au(unit), sigdigits=2)
if isnothing(εF)
Plots.ylabel!(p, "eigenvalues ($(string(unit))")
else
Plots.ylabel!(p, "eigenvalues - ε_f ($(string(unit)))")
Plots.ylims!(p, ylims...)
end

p
end

function detexify_kpoint(string)
# For some reason Julia doesn't support this naively: https://github.com/JuliaLang/julia/issues/29849
replacements = ("\\Gamma" => "Γ",
Expand All @@ -163,6 +121,10 @@ are plotted in `:eV` unless a different `unit` is selected.
"""
function plot_bandstructure(basis, ρ, ρspin, n_bands;
εF=nothing, kline_density=20, unit=:eV, kwargs...)
if !isdefined(DFTK, :PLOTS_LOADED)
error("Plots not loaded. Run 'using Plots' before calling plot_bandstructure.")
end

# Band structure calculation along high-symmetry path
kcoords, klabels, kpath = high_symmetry_kpath(basis.model; kline_density=kline_density)
println("Computing bands along kpath:")
Expand All @@ -173,6 +135,7 @@ function plot_bandstructure(basis, ρ, ρspin, n_bands;
if kline_density 10
plotargs = (markersize=2, markershape=:circle)
end

plot_band_data(band_data; εF=εF, klabels=klabels, unit=unit, plotargs...)
end
function plot_bandstructure(scfres; n_bands=nothing, kwargs...)
Expand Down
23 changes: 2 additions & 21 deletions src/scf/scf_callbacks.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,5 @@
"""
Plot the trace of an SCF, i.e. the absolute error of the total energy at
each iteration versus the converged energy in a semilog plot. By default
a new plot canvas is generated, but an existing one can be passed and reused
along with `kwargs` for the call to `plot!`.
"""
function ScfPlotTrace(plt=plot(yaxis=:log); kwargs...)
energies = Float64[]
function callback(info)
if info.stage == :finalize
minenergy = minimum(energies[max(1, end-5):end])
error = abs.(energies .- minenergy)
error[error .== 0] .= NaN
extra = ifelse(:mark in keys(kwargs), (), (mark=:x, ))
plot!(plt, error; extra..., kwargs...)
display(plt)
else
push!(energies, info.energies.total)
end
end
end
# For ScfPlotTrace() see DFTK.jl/src/plotting.jl, which is conditionally loaded upon
# Plots.jl is included.

"""
Default callback function for `self_consistent_field`, which prints a convergence table
Expand Down
1 change: 0 additions & 1 deletion src/scf/self_consistent_field.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
using Plots
include("scf_callbacks.jl")

function default_n_bands(model)
Expand Down

0 comments on commit 380965e

Please sign in to comment.