Skip to content

Commit

Permalink
- more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
JannickWeisshaupt committed Oct 24, 2018
1 parent d7fad19 commit 6180ae7
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 31 deletions.
40 changes: 10 additions & 30 deletions src/solid_state_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
p_table_rev = {el.__repr__(): i for i, el in enumerate(pt.elements)}
elemental_masses = [pt.mass.mass(el) for el in pt.elements]

def remove_duplicates_old(data, treshold=0.01):
def remove_duplicates(data, treshold=0.01):
if len(data) == 0:
return np.array([])
new_data = []
Expand Down Expand Up @@ -413,7 +413,7 @@ def parse_cif_file(self, filename):
sym_atom_array[counter, 3] = atom_spec
counter += 1

atom_array_finally = remove_duplicates_old(sym_atom_array)
atom_array_finally = remove_duplicates(sym_atom_array)
else:
atom_array_finally = atom_array

Expand Down Expand Up @@ -620,7 +620,6 @@ def construct_brillouin_vertices(crystal_structure):
voronoi = Voronoi(all_points)
wigner_points = voronoi.vertices


wigner_points_cleaned = []

for w_point in wigner_points:
Expand All @@ -632,39 +631,13 @@ def construct_brillouin_vertices(crystal_structure):
for i, w_point in enumerate(wigner_points_cleaned):
vertices_array[i, :] = w_point

return remove_duplicates_old(vertices_array)
return remove_duplicates(vertices_array)


def construct_convex_hull(w_points):
hull = ConvexHull(w_points)
return hull.simplices

# bonds = []
# for simplex in hull.simplices:
# simplex = np.append(simplex, simplex[0])
# for i in range(len(simplex)-1):
# bonds.append([simplex[i],simplex[i+1]])
#
# connections = []
# for i in range(w_points.shape[0]):
# con_temp = set()
# for bond in bonds:
# if i==bond[0]:
# con_temp.add(bond[1])
# elif i==bond[1]:
# con_temp.add(bond[0])
# connections.append(con_temp)
#
# shortest_connections = []
# for i,connection in enumerate(connections):
# connection = list(connection)
# dist = np.linalg.norm(w_points[connection,:]-w_points[i,:] ,axis=1)
# in_sort = np.argsort(dist)[:3]
# shortest_connections.append(np.array(connection)[in_sort])
# shortest_connections = connections

# return shortest_connections


def convert_hs_path_to_own(hs_path):
kpoints = hs_path.kpath['kpoints']
Expand All @@ -685,6 +658,7 @@ def convert_path(path, kpoints):
conv_path = convert_path(path, kpoints)
return conv_path


def calculate_standard_path(structure):
lattice = mg.Lattice(structure.lattice_vectors)
atoms = structure.atoms
Expand Down Expand Up @@ -766,6 +740,7 @@ def calculate_mass_matrix(structure,repeat=(1,1,1),edges=False,phonon_conv=False
np.fill_diagonal(mass_matrix,mass_list_coll)
return mass_matrix


def convert_pymatgen_structure(structure):
cart_coords = structure.cart_coords
atoms_cart = np.zeros((cart_coords.shape[0],4))
Expand All @@ -776,16 +751,19 @@ def convert_pymatgen_structure(structure):
crystal_structure = CrystalStructure(structure.lattice.matrix/bohr, atoms_cart,relative_coords=False)
return crystal_structure


def query_materials_database(structure_formula):
with MPRester(API_KEY) as m:
data = m.get_data(structure_formula)
return data


def get_materials_structure_from_id(id):
with MPRester(API_KEY) as m:
structure = m.get_structure_by_material_id(id)
return convert_pymatgen_structure(structure)


def replace_greek(x):
if x == '\\Gamma':
res = 'Gamma'
Expand All @@ -804,6 +782,7 @@ def get_materials_dos_from_id(id):
data[:,1] = data[:,1] + value
return DensityOfStates(data)


def get_materials_band_structure_from_id(id):
with MPRester(API_KEY) as m:
band_structure = m.get_bandstructure_by_material_id(id)
Expand Down Expand Up @@ -855,6 +834,7 @@ def get_materials_band_structure_from_id(id):
path_conv = calculate_path_length(conv_structure,path)
return BandStructure(sorted_bands,special_k_points=path_conv)


if __name__ == "__main__":
data = query_materials_database('LiBH4')
id = data[0]['material_id']
Expand Down
Empty file added tests/test_exciting.py
Empty file.
8 changes: 7 additions & 1 deletion tests/test_sst.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def compare_path(p,o):
break
return is_equal


def test_molecular_structure():
all_tests = []

Expand All @@ -47,6 +48,7 @@ def test_molecular_structure():

assert all(all_tests)


def test_crystal_structure():
dens = diamond.density()
all_tests = []
Expand Down Expand Up @@ -83,8 +85,12 @@ def test_standard_path():
assert compare_path(stored_path,path)



def test_brillouin_construction():
w_points = sst.construct_brillouin_vertices(diamond)
brillouin_edges = sst.construct_convex_hull(w_points)
assert (len(w_points) == 24) and (len(brillouin_edges) == 44)

if __name__ == "__main__":
test_brillouin_construction()
test_crystal_structure()
test_molecular_structure()

0 comments on commit 6180ae7

Please sign in to comment.