Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

graph building and file I/O with Xtals.jl #112

Merged
merged 19 commits into from
Sep 7, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/Documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- uses: actions/checkout@v2
- uses: julia-actions/setup-julia@latest
with:
version: '1.4'
version: '1.6'
- name: Install dependencies
env:
PYTHON: ""
Expand Down
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ GraphPlot = "a2cc645c-3eea-5389-862e-a155d0052231"
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
LightGraphs = "093fc24a-ae57-5d10-9952-331d41423f4d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Serialization = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
SimpleWeightedGraphs = "47aef6b3-ad0c-573a-a1e2-d07658019622"
Xtals = "ede5f01d-793e-4c47-9885-c447d1f18d6d"
thazhemadam marked this conversation as resolved.
Show resolved Hide resolved

thazhemadam marked this conversation as resolved.
Show resolved Hide resolved
[compat]
CSV = "0.7, 0.8"
Expand All @@ -32,7 +34,7 @@ JSON = "0.21"
LightGraphs = "1"
PyCall = "1"
SimpleWeightedGraphs = "1"
julia = "1.4, 1.5, 1.6"
julia = "1.6"
thazhemadam marked this conversation as resolved.
Show resolved Hide resolved

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
5 changes: 1 addition & 4 deletions src/atoms/atomgraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ AtomGraph(adj::Array{R}, elements::Vector{String}, id = "") where {R<:Real} =
AtomGraph(SimpleWeightedGraph(adj), elements, id)

"""
AtomGraph(input_file_path, id = splitext(input_file_path)[begin]; output_file_path = nothing, featurization = nothing, overwrite_file = false, use_voronoi = false, cutoff_radius = 8.0, max_num_nbr = 12, dist_decay_func = inverse_square, normalize_weights = true)
AtomGraph(input_file_path, id = splitext(input_file_path)[begin]; output_file_path = nothing, featurization = nothing, overwrite_file = false, use_voronoi = false, cutoff_radius = 8.0, max_num_nbr = 12, dist_decay_func = inverse_square)

Construct an AtomGraph object from a structure file.

Expand All @@ -67,7 +67,6 @@ Construct an AtomGraph object from a structure file.
- `cutoff_radius::Real = 8.0`: If not using Voronoi neighbor lists, longest allowable distance to a neighbor, in Angstroms
- `max_num_nbr::Integer = 12`: If not using Voronoi neighbor lists, largest allowable number of neighbors
- `dist_decay_func = inverse_square`: Function by which to assign edge weights according to distance between neighbors
- `normalize_weights::Bool = true`: Whether to normalize weights such that the largest is 1.0

# Note
`max_num_nbr` is a "soft" limit – if multiple neighbors are at the same distance, the full neighbor list may be longer.
Expand All @@ -81,7 +80,6 @@ function AtomGraph(
cutoff_radius::Real = 8.0,
max_num_nbr::Integer = 12,
dist_decay_func::Function = inverse_square,
normalize_weights::Bool = true,
)

local ag
Expand All @@ -103,7 +101,6 @@ function AtomGraph(
cutoff_radius = cutoff_radius,
max_num_nbr = max_num_nbr,
dist_decay_func = dist_decay_func,
normalize_weights = normalize_weights,
)
ag = AtomGraph(adj_mat, elements, id)
catch
Expand Down
2 changes: 1 addition & 1 deletion src/codecs/OneHotOneCold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ end

"A flexible version of Flux.onehot that can handle both categorical and continuous-valued encoding."
function build_onehot_vec(val, bins, categorical)
local bin_index , onehot_vec
local bin_index, onehot_vec
if categorical
onehot_vec = [0.0 for i = 1:length(bins)]
bin_index = findfirst(isequal(val), bins)
Expand Down
2 changes: 1 addition & 1 deletion src/featurizations/graphnodefeaturization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function GraphNodeFeaturization(
categorical::Union{Vector{Bool},Bool,Nothing} = nothing,
)
num_features = length(feature_names)
local lookup_table_here , logspaced_here , categorical_here , nbins_here
local lookup_table_here, logspaced_here, categorical_here, nbins_here
if isnothing(lookup_table)
lookup_table_here = atom_data_df
else
Expand Down
2 changes: 1 addition & 1 deletion src/utils/elementfeature_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ function get_bins(
logspaced::Bool = default_log(feature_name, lookup_table),
categorical::Bool = default_categorical(feature_name, lookup_table),
)
local bins , min_val , max_val
local bins, min_val, max_val

if categorical
if feature_name in categorical_feature_names
Expand Down
123 changes: 93 additions & 30 deletions src/utils/graph_building.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,23 @@ export inverse_square, exp_decay
using PyCall
using ChemistryFeaturization
using Serialization
using Xtals
using NearestNeighbors
#rc[:paths][:crystals] = @__DIR__ # so that Xtals.jl knows where things are

# options for decay of bond weights with distance...
# user can of course write their own as well
inverse_square(x) = x^-2.0
exp_decay(x) = exp(-x)

"""
Function to build graph from a file storing a crystal structure (currently supports anything ase.io.read can read in). Returns an AtomGraph object.
Build graph from a file storing a crystal structure (currently supports anything Xtals.jl can read in). Returns the weight matrix and elements used for constructing an `AtomGraph`.

# Arguments
## Required Arguments
- `file_path::String`: Path to ASE-readable file containing a molecule/crystal structure

## Keyword Arguments
- `normalize_weights::Bool=true`: Whether to rescale graph weights such that the maximum value is 1.0 (recommended)
- `use_voronoi::bool`: if true, use Voronoi method for neighbor lists, if false use cutoff method

(The rest of these parameters are only used if `use_voronoi==false`)
Expand All @@ -36,45 +38,56 @@ function build_graph(
cutoff_radius::Real = 8.0,
max_num_nbr::Integer = 12,
dist_decay_func::Function = inverse_square,
normalize_weights::Bool = true,
)
aseio = pyimport_conda("ase.io", "ase", "conda-forge")
atoms_object = aseio.read(file_path)

# list of atom symbols
atom_ids = [get(atoms_object, i - 1).symbol for i = 1:length(atoms_object)]

# check if any nonperiodic BC's
nonpbc = any(.!atoms_object.pbc)
local cant_voronoi = false
if nonpbc & use_voronoi
@warn "Voronoi edge weights are not supported if any direction in the structure is nonperiodic. Using cutoff weights method..."
cant_voronoi = true
end
c = Crystal(file_path)
atom_ids = String.(c.atoms.species)

if use_voronoi && !cant_voronoi
if use_voronoi
@info "Note that building neighbor lists and edge weights via the Voronoi method requires the assumption of periodic boundaries. If you are building a graph for a molecule, you probably do not want this..."
s = pyimport_conda("pymatgen.core.structure", "pymatgen", "conda-forge")
pmgase = pyimport_conda("pymatgen.io.ase", "pymatgen", "conda-forge")
aa = pmgase.AseAtomsAdaptor()
struc = aa.get_structure(atoms_object)
struc = s.Structure.from_file(file_path)
weight_mat = weights_voronoi(struc)
return weight_mat, atom_ids
else
nl = pyimport_conda("ase.neighborlist", "ase", "conda-forge")
is, js, dists = nl.neighbor_list("ijd", atoms_object, cutoff_radius)
weight_mat = weights_cutoff(
is .+ 1,
js .+ 1,
dists;
build_graph(
c;
cutoff_radius = cutoff_radius,
max_num_nbr = max_num_nbr,
dist_decay_func = dist_decay_func,
)
end

if normalize_weights
weight_mat = weight_mat ./ maximum(weight_mat)
end
end

"""
Build graph from a Crystal object. Currently only supports the "cutoff" method of neighbor list/weight calculation (not Voronoi).
This dispatch exists to support autodiff of graph-building.

# Arguments
## Required Arguments
- `crys::Crystal`: Crystal object representing the atomic geometry from which to build a graph

## Keyword Arguments
- `cutoff_radius::Real=8.0`: cutoff radius for atoms to be considered neighbors (in angstroms)
- `max_num_nbr::Integer=12`: maximum number of neighbors to include (even if more fall within cutoff radius)
- `dist_decay_func::Function=inverse_square`: function to determine falloff of graph edge weights with neighbor distance
"""
function build_graph(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we have a reference to this docstring also under Utilities/Graph Building?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't the @autodocs do that?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can't check right now because my local setup for building the docs is somehow broken again :/ but I thought the @autodocs block should do it?

crys::Crystal;
cutoff_radius::Real = 8.0,
max_num_nbr::Integer = 12,
dist_decay_func::Function = inverse_square,
)

return weight_mat, atom_ids
is, js, dists = neighbor_list(crys; cutoff_radius = cutoff_radius)
weight_mat = weights_cutoff(
is,
js,
dists;
max_num_nbr = max_num_nbr,
dist_decay_func = dist_decay_func,
)
return weight_mat, String.(crys.atoms.species)
end

"""
Expand Down Expand Up @@ -103,6 +116,9 @@ function weights_cutoff(is, js, dists; max_num_nbr = 12, dist_decay_func = inver

# average across diagonal, just in case
weight_mat = 0.5 .* (weight_mat .+ weight_mat')

# normalize weights
weight_mat = weight_mat ./ maximum(weight_mat)
thazhemadam marked this conversation as resolved.
Show resolved Hide resolved
end

"""
Expand Down Expand Up @@ -131,6 +147,53 @@ function weights_voronoi(struc)

# average across diagonal (because neighborness isn't strictly symmetric in the way we're defining it here)
weight_mat = 0.5 .* (weight_mat .+ weight_mat')

# normalize weights
weight_mat = weight_mat ./ maximum(weight_mat)
end


"""
Find all lists of pairs of atoms in `crys` that are within a distance of `cutoff_radius` of each other, respecting periodic boundary conditions.

Returns as is, js, dists to be compatible with ASE's output format for the analogous function.
"""
function neighbor_list(crys::Crystal; cutoff_radius::Real = 8.0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, can we have a reference to this docstring in the docs? 😄

n_atoms = crys.atoms.n

# make 3 x 3 x 3 supercell and find indices of "middle" atoms
# as well as index mapping from outer -> inner
supercell = replicate(crys, (3, 3, 3))

# check for size of cutoff radius relative to size of cell
min_celldim = min(crys.box.a, crys.box.b, crys.box.c)
if cutoff_radius >= min_celldim
@warn "Your cutoff radius is quite large relative to the size of your unit cell. This may cause issues with neighbor list generation, and will definitely cause a very dense graph. To avoid issues, I'm setting it to be approximately equal to the smallest unit cell dimension."
cutoff_radius = 0.99*min_celldim
rkurchin marked this conversation as resolved.
Show resolved Hide resolved
end

# todo: try BallTree, also perhaps other leafsize values
#tree = BruteTree(sc.atoms.coords.xf, PeriodicEuclidean([1.0, 1.0, 1.0]))
tree = BruteTree(Cart(supercell.atoms.coords, supercell.box).x)

is_raw = 13*n_atoms+1:14*n_atoms
js_raw =
inrange(tree, Cart(supercell.atoms.coords[is_raw], supercell.box).x, cutoff_radius)

index_map(i) = (i - 1) % n_atoms + 1 # I suddenly understand why some people dislike 1-based indexing

# this looks horrifying but it does do the right thing...
#ijraw_pairs = [p for p in Iterators.flatten([Iterators.product([p for p in zip(is_raw, js_raw)][n]...) for n in 1:4]) if p[1]!=p[2]]
split1 = map(zip(is_raw, js_raw)) do x
return [p for p in [(x[1], [j for j in js if j!=x[1]]...) for js in x[2]] if length(p)==2]
rkurchin marked this conversation as resolved.
Show resolved Hide resolved
end
ijraw_pairs = [(split1...)...]
get_pairdist((i,j)) = distance(supercell.atoms, supercell.box, i, j, false)
rkurchin marked this conversation as resolved.
Show resolved Hide resolved
dists = get_pairdist.(ijraw_pairs)
is = index_map.([t[1] for t in ijraw_pairs])
js = index_map.([t[2] for t in ijraw_pairs])

return is, js, dists
end

# TODO: graphs from SMILES via OpenSMILES.jl
Expand Down
7 changes: 0 additions & 7 deletions test/test_data/strucs/methane.xyz

This file was deleted.

11 changes: 0 additions & 11 deletions test/test_data/strucs/mp-195.poscar

This file was deleted.

Binary file removed test/test_data/strucs/mp-195.traj
Binary file not shown.
6 changes: 0 additions & 6 deletions test/test_data/strucs/mp-195.xyz

This file was deleted.

Binary file modified test/test_data/strucs/testgraph.jls
Binary file not shown.
43 changes: 12 additions & 31 deletions test/utils/GraphBuilding_tests.jl
Original file line number Diff line number Diff line change
@@ -1,41 +1,22 @@
using Test
using ChemistryFeaturization.Utils.GraphBuilding
using Xtals

@testset "GraphBuilding" begin
adj, els = build_graph(
abspath(@__DIR__, "..", "test_data", "strucs", "mp-195.cif"),
use_voronoi = true,
)
path1 = abspath(@__DIR__, "..", "test_data", "strucs", "mp-195.cif")
adj, els = build_graph(path1; use_voronoi = true)
wm_true = [0.0 1.0 1.0 1.0; 1.0 0.0 1.0 1.0; 1.0 1.0 0.0 1.0; 1.0 1.0 1.0 0.0]
@test adj == wm_true
@test els == ["Ho", "Pt", "Pt", "Pt"]
els_true = ["Ho", "Pt", "Pt", "Pt"]

adj, els = build_graph(
abspath(@__DIR__, "..", "test_data", "strucs", "mp-195.cif");
use_voronoi = false,
)
@test adj == wm_true
@test els == ["Ho", "Pt", "Pt", "Pt"]
@test els == els_true

# tests for some other file formats
info = Tuple{Matrix,Vector{String}}[]
for fp in ["mp-195.poscar", "mp-195.traj", "mp-195.xyz"]
push!(info, build_graph(abspath(@__DIR__, "..", "test_data", "strucs", fp)))
end
for t in info
@test t[1] == wm_true
@test t[2] == els
end
adj, els = build_graph(path1; use_voronoi = false)
@test adj == wm_true
@test els == els_true

# test for nonperiodic system
@test_logs (
:warn,
"Voronoi edge weights are not supported if any direction in the structure is nonperiodic. Using cutoff weights method...",
) build_graph(
abspath(@__DIR__, "..", "test_data", "strucs", "methane.xyz"),
use_voronoi = true,
)
adj, els = build_graph(abspath(@__DIR__, "..", "test_data", "strucs", "methane.xyz"))
@test all(isapprox.(adj[2:5, 1], 1.0, atol = 1e-4))
@test all(isapprox.(adj[3:2, 2], 0.375, atol = 1e-5))
# test that we get the same results building from a Crsytal object
adjc, elsc = build_graph(Crystal(path1))
@test adjc == wm_true
@test elsc == els_true
end