In [3]:
import pymatgen
import mp_api

MP_API_KEY =  # Materials Project API key here

# Plotting scripts

In [4]:
import plotly
import plotly.io
import plotly.graph_objects as go

import numpy as np
from matscipy.neighbours import neighbour_list

def plot_sphere(x ,y ,z, radius, color='black',  resolution = 20):
    """
    Parameters
    ----------
    x,y,z : coordinates of center of sphere
    radius : radius of sphere
    color : color of sphere
    resolution: resolution of sphere
    """
    u = np.linspace(0, 2*np.pi, resolution)
    v = np.linspace(0, 2*np.pi, resolution)
    u, v = np.meshgrid(u,v)
    X = radius * np.cos(u)*np.sin(v) + x
    Y = radius * np.sin(u)*np.sin(v) + y
    Z = radius * np.cos(v) + z
    return [go.Surface(x=X, y=Y, z=Z, colorscale=[[0,color],[1,color]], showscale=False)]

def plot_cell(coords, atoms, lattice, cell, atoms_style):
    """
    plots atoms in a unit cell
    Parameters
    ----------
    coords: coordinates of atoms
    atoms: atom types
    lattice: lattice parameters
    cell: unit cell to plot
    atoms_style: colors and sizes to plot atoms
    """
    plot_coords = coords+cell@lattice
    atoms_plot = []
    for i, atom in enumerate(atoms):
        atoms_plot+=plot_sphere(plot_coords[i,0], plot_coords[i,1], plot_coords[i,2], 
                                atoms_style[atom][1], atoms_style[atom][0])
    return atoms_plot

def plot_bond(coords, lattice, cell, cells, sender, receiver, sender_shift):
    """
    plots bond
    Parameters
    ----------
    coords: coordinates of atoms
    lattice: lattice parameters
    cell: unit cell receiver atom is in
    cells: cells plotted
    sender: starting atom
    receiver: ending atom
    sender_shift: unit cells sender shifted from cell of receiver
    """
    coords1 = coords[receiver]+cell@lattice
    shifted_send = cell+sender_shift
    coords2 = coords[sender]+shifted_send@lattice
    if min(np.linalg.norm(cells-shifted_send, axis=1))<0.1:
        return [go.Scatter3d(x=[coords1[0],coords2[0]],
                            y=[coords1[1],coords2[1]],
                            z=[coords1[2],coords2[2]],
                            line={'color':'grey','width':10})]
    else:
        return []
    
def plot_structure(structure, atoms_style = {'Ba':['black',0.2],'Ti':['red',0.3],'O':['blue',0.4]},
                   lattices = np.mgrid[-1:1,-1:1,-1:1].reshape(3,-1).T, cutoff = 2.7, size=5,
                   shift = np.array([[0.,0,0]])):
    """
    plots crystal from Structure object
    Parameters
    ----------
    structure: Structure object containing structure data
    atoms_style: colors and sizes to plot atoms
    lattices: unit cells to plot
    cutoff: max distance between atoms considered bond
    size: size of plot
    shift: overall translation of plot
    """
    receivers, senders, senders_unit_shifts = neighbour_list(
        quantities="ijS",
        pbc=structure.pbc,
        cell=structure.lattice.matrix,
        positions=structure.cart_coords,
        cutoff=cutoff,
    )
    return plot_crystal(structure.cart_coords, structure.labels, structure.lattice.matrix,
                 receivers, senders, senders_unit_shifts,
                atoms_style = atoms_style, lattices=lattices, size=size, shift=shift)

def plot_data(data, atoms_style = {'Ba':['black',0.2],'Ti':['red',0.3],'O':['blue',0.4]},
                   lattices = np.mgrid[-1:1,-1:1,-1:1].reshape(3,-1).T,
                   size = 5, shift = np.array([[0.,0,0]])):
    """
    plots crystal from Data object
    Parameters
    ----------
    data: Data object containing structure data
    atoms_style: colors and sizes to plot atoms
    lattices: unit cells to plot
    size: size of plot
    shift: overall translation of plot
    """
    atoms = []
    one_hot = {'Ba':torch.tensor([1,0,0.]), 'Ti':torch.tensor([0,1,0.]), 'O':torch.tensor([0,0,1.])}
    for atom_encoded in data['z']:
        atom_encoded = atom_encoded[:3]
        minimum = 1000
        min_key = None
        for atom in one_hot:
            if torch.norm(atom_encoded-one_hot[atom])<minimum:
                min_key = atom
                minimum = torch.norm(atom_encoded-one_hot[atom])
        atoms.append(min_key)
    return plot_crystal(data['pos'].detach().numpy(), atoms, data['lattice'].detach().numpy(),
                 data['dst'].detach().numpy(), data['src'].detach().numpy(), data['shifts'].detach().numpy(),
                atoms_style = atoms_style, lattices=lattices, size=size, shift=shift)

def plot_crystal(coords, atoms, lattice, receivers, senders, senders_unit_shifts, atoms_style = {'Ba':['black',0.2],'Ti':['red',0.3],'O':['blue',0.4]}, 
              lattices = np.mgrid[-1:1,-1:1,-1:1].reshape(3,-1).T, size = 5, shift = np.array([[0.,0,0]])):
    """
    plots crystal
    Parameters
    ----------
    coords: coordinates of atoms
    atoms: atom types
    lattice: lattice parameters
    receivers: ending atoms for bond
    senders: starting atoms for bond
    sender_unit_shifts: unit cells sender shifted from cell of receiver
    atoms_style: colors and sizes to plot atoms
    lattices: unit cells to plot
    size: size of plot
    shift: overall translation of plot
    """
    atoms_plot = []

    for cell in lattices:
        atoms_plot += plot_cell(coords+shift, atoms, lattice, cell, atoms_style)

    bonds_plot = []

    for cell in lattices:
        for receiver, sender, sender_shift in zip(receivers, senders, senders_unit_shifts):
            bonds_plot += plot_bond(coords+shift, lattice, cell, lattices,
                                    sender, receiver, sender_shift)

    camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=1, y=1, z=0.5)
    )

    layout = go.Layout(scene_xaxis_visible=False, 
                       scene_yaxis_visible=False, 
                       scene_zaxis_visible=False, 
                       scene_camera = camera,
                       scene=dict(aspectmode='cube',)
                      )

    fig = go.Figure(data = atoms_plot+bonds_plot, layout=layout)

    fig.update_layout(showlegend=False,
        scene = dict(
            xaxis = dict(nticks=4, range=[-size,size],),
            yaxis = dict(nticks=4, range=[-size,size],),
            zaxis = dict(nticks=4, range=[-size,size],),),
        width=700,
        height=800,
        margin=dict(r=0, l=0, b=0, t=0))

    return fig

In [11]:
def bond_angles_structure(structure, cutoff = 2.7):
    """
    computes bond angles from Structure object
    Parameters
    ----------
    structure: Structure object containing structure data
    cutoff: max distance between atoms considered bond
    """
    receivers, senders, senders_unit_shifts = neighbour_list(
        quantities="ijS",
        pbc=structure.pbc,
        cell=structure.lattice.matrix,
        positions=structure.cart_coords,
        cutoff=cutoff,
    )
    return bond_angles(structure.cart_coords, structure.labels, structure.lattice.matrix,
                       receivers, senders, senders_unit_shifts)

def bond_angles_data(data):
    """
    plots crystal from Data object
    Parameters
    ----------
    data: Data object containing structure data
    """
    atoms = []
    one_hot = {'Ba':torch.tensor([1,0,0.]), 'Ti':torch.tensor([0,1,0.]), 'O':torch.tensor([0,0,1.])}
    for atom_encoded in data['z']:
        atom_encoded = atom_encoded[:3]
        minimum = 1000
        min_key = None
        for atom in one_hot:
            if torch.norm(atom_encoded-one_hot[atom])<minimum:
                min_key = atom
                minimum = torch.norm(atom_encoded-one_hot[atom])
        atoms.append(min_key)
    return bond_angles(data['pos'].detach().numpy(), atoms, data['lattice'].detach().numpy(),
                 data['dst'].detach().numpy(), data['src'].detach().numpy(), data['shifts'].detach().numpy())

def bond_angles(coords, atoms, lattice, receivers, senders, senders_unit_shifts):
    """
    computes bond angles
    Parameters
    ----------
    coords: coordinates of atoms
    atoms: atom types
    lattice: lattice parameters
    receivers: ending atoms for bond
    senders: starting atoms for bond
    sender_unit_shifts: unit cells sender shifted from cell of receiver
    """
    angles = {}
    for i,a1 in enumerate(receivers):
        for j,a2 in enumerate(receivers[:i]):
            if a1==a2:
                coords1 = coords[a1]
                shifted_send1 = senders_unit_shifts[i]
                coords2 = coords[senders[i]]+shifted_send1@lattice
                vec1 = coords1-coords2

                coords1 = coords[a2]
                shifted_send2 = senders_unit_shifts[j]
                coords2 = coords[senders[j]]+shifted_send2@lattice
                vec2 = coords1-coords2

                angle = np.arccos(vec1@vec2/np.linalg.norm(vec1)/np.linalg.norm(vec2))/np.pi*180
                angles[(senders[i],a1,senders[j],tuple(shifted_send1),tuple(shifted_send2))] = angle
    return angles

In [12]:
def bond_length_structure(structure, cutoff = 2.7):
    """
    computes bond angles from Structure object
    Parameters
    ----------
    structure: Structure object containing structure data
    cutoff: max distance between atoms considered bond
    """
    receivers, senders, senders_unit_shifts = neighbour_list(
        quantities="ijS",
        pbc=structure.pbc,
        cell=structure.lattice.matrix,
        positions=structure.cart_coords,
        cutoff=cutoff,
    )
    return bond_length(structure.cart_coords, structure.labels, structure.lattice.matrix,
                       receivers, senders, senders_unit_shifts)

def bond_length_data(data):
    """
    plots crystal from Data object
    Parameters
    ----------
    data: Data object containing structure data
    """
    atoms = []
    one_hot = {'Ba':torch.tensor([1,0,0.]), 'Ti':torch.tensor([0,1,0.]), 'O':torch.tensor([0,0,1.])}
    for atom_encoded in data['z']:
        atom_encoded = atom_encoded[:3]
        minimum = 1000
        min_key = None
        for atom in one_hot:
            if torch.norm(atom_encoded-one_hot[atom])<minimum:
                min_key = atom
                minimum = torch.norm(atom_encoded-one_hot[atom])
        atoms.append(min_key)
    return bond_length(data['pos'].detach().numpy(), atoms, data['lattice'].detach().numpy(),
                 data['dst'].detach().numpy(), data['src'].detach().numpy(), data['shifts'].detach().numpy())

def bond_length(coords, atoms, lattice, receivers, senders, senders_unit_shifts):
    """
    computes bond angles
    Parameters
    ----------
    coords: coordinates of atoms
    atoms: atom types
    lattice: lattice parameters
    receivers: ending atoms for bond
    senders: starting atoms for bond
    sender_unit_shifts: unit cells sender shifted from cell of receiver
    """
    bonds = {}
    for i in range(len(senders)):
        a, b = senders[i], receivers[i]
        if a>b:
            coords2 = coords[b]
            shifted_send = senders_unit_shifts[i]
            coords1 = coords[a]+shifted_send@lattice
            vec = coords1-coords2

            bonds[(a,b,tuple(shifted_send))] = np.linalg.norm(vec)
    avg = sum([bonds[bond] for bond in bonds])/len(bonds)
    var = sum([(bonds[bond]-avg)**2 for bond in bonds])/len(bonds)
    return bonds, avg, var

### Searching for materials in materials project

In [13]:
import numpy as np
from mp_api.client import MPRester
# from mp_api import MPID

with MPRester(MP_API_KEY) as mpr:
    test_list = mpr.get_material_ids("BaTiO3")
test_list

Retrieving MaterialsDoc documents:   0%|          | 0/11 [00:00<?, ?it/s]

[MPID(mp-2998),
 MPID(mp-5020),
 MPID(mp-5777),
 MPID(mp-5933),
 MPID(mp-5986),
 MPID(mp-19990),
 MPID(mp-504715),
 MPID(mp-558125),
 MPID(mp-644497),
 MPID(mp-995191),
 MPID(mp-1076932)]

### Visualizing some BaTiO3 crystals

In [14]:
with MPRester(MP_API_KEY) as mpr:
    structure = mpr.get_structure_by_material_id("mp-5986",conventional_unit_cell=True)
    print(structure)
plot_structure(structure)
bond_angles_structure(structure)
bond_length_structure(structure)

Retrieving MaterialsDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

Full Formula (Ba1 Ti1 O3)
Reduced Formula: BaTiO3
abc   :   3.990379   3.990379   4.102655
angles:  90.000000  90.000000  90.000000
pbc   :       True       True       True
Sites (5)
  #  SP      a    b         c
---  ----  ---  ---  --------
  0  Ba    0.5  0.5  0.582541
  1  Ti    0    0    0.100043
  2  O     0    0.5  0.064221
  3  O     0.5  0    0.064221
  4  O     0    0    0.548972


({(4, 1, (0, 0, -1)): 2.2608542032899206,
  (2, 1, (0, -1, 0)): 2.000595051111635,
  (3, 1, (-1, 0, 0)): 2.000595051111635,
  (2, 1, (0, 0, 0)): 2.000595051111635,
  (3, 1, (0, 0, 0)): 2.000595051111635,
  (4, 1, (0, 0, 0)): 1.8418011867100792},
 2.0175059324077567,
 0.015205741704478434)

In [16]:
with MPRester(MP_API_KEY) as mpr:
    structure = mpr.get_structure_by_material_id("mp-2998",conventional_unit_cell=True)
    print(structure)
plot_structure(structure)
bond_angles_structure(structure),bond_length_structure(structure)

Retrieving MaterialsDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

Full Formula (Ba1 Ti1 O3)
Reduced Formula: BaTiO3
abc   :   4.007682   4.007682   4.007682
angles:  90.000000  90.000000  90.000000
pbc   :       True       True       True
Sites (5)
  #  SP      a    b    c
---  ----  ---  ---  ---
  0  Ba    0.5  0.5  0.5
  1  Ti    0    0    0
  2  O     0.5  0    0
  3  O     0    0.5  0
  4  O     0    0    0.5


({(3, 1, 4, (0, -1, 0), (0, 0, -1)): 90.0,
  (2, 1, 4, (-1, 0, 0), (0, 0, -1)): 90.0,
  (2, 1, 3, (-1, 0, 0), (0, -1, 0)): 90.0,
  (2, 1, 4, (0, 0, 0), (0, 0, -1)): 90.0,
  (2, 1, 3, (0, 0, 0), (0, -1, 0)): 90.0,
  (2, 1, 2, (0, 0, 0), (-1, 0, 0)): 180.0,
  (3, 1, 4, (0, 0, 0), (0, 0, -1)): 90.0,
  (3, 1, 3, (0, 0, 0), (0, -1, 0)): 180.0,
  (3, 1, 2, (0, 0, 0), (-1, 0, 0)): 90.0,
  (3, 1, 2, (0, 0, 0), (0, 0, 0)): 90.0,
  (4, 1, 4, (0, 0, 0), (0, 0, -1)): 180.0,
  (4, 1, 3, (0, 0, 0), (0, -1, 0)): 90.0,
  (4, 1, 2, (0, 0, 0), (-1, 0, 0)): 90.0,
  (4, 1, 2, (0, 0, 0), (0, 0, 0)): 90.0,
  (4, 1, 3, (0, 0, 0), (0, 0, 0)): 90.0,
  (1, 2, 1, (1, 0, 0), (0, 0, 0)): 180.0,
  (1, 3, 1, (0, 1, 0), (0, 0, 0)): 180.0,
  (1, 4, 1, (0, 0, 1), (0, 0, 0)): 180.0},
 ({(4, 1, (0, 0, -1)): 2.0038408178341407,
   (3, 1, (0, -1, 0)): 2.0038408178341407,
   (2, 1, (-1, 0, 0)): 2.0038408178341407,
   (2, 1, (0, 0, 0)): 2.0038408178341407,
   (3, 1, (0, 0, 0)): 2.0038408178341407,
   (4, 1, (0, 0, 0)): 2.003

# Atoms matching algorithm

Matches atoms based on how similar it is in relation to the other atoms. This takes care of any translation issues as long as the crystals share the same lattice. This is described in Appendix H.3.1.

In [7]:
import scipy.optimize

def local_signature(coords, lattice, labels, index):
    """
    computes signature (as described in Appendix H.3.1) of atom in structure
    Parameters
    ----------
    coords: coordinates of atoms
    lattice: lattice parameters
    labels: atom types
    index: index of atom to compute signature for
    """
    nearby_cells = np.mgrid[-1:2,-1:2,-1:2].reshape(3,-1).T
    signature = dict()
    for coord, atom in zip(coords, labels):
        diffs = (coord-coords[index]).reshape([-1,3])+nearby_cells@lattice
        if atom not in signature:
            signature[atom] = []
        signature[atom].append(diffs)
    return signature

def compare_signatures(sig1, sig2):
    """
    compares signature (as described in Appendix H.3.1) of two structures
    Parameters
    ----------
    sig1: signatures of first structure
    sig2: signatures of second structure
    """
    cost = 0
    for key in sig1:
        diffs_list1 = np.tile(np.array(sig1[key]),[27,len(sig1[key]),1,1,1]).transpose([1,2,0,3,4])
        diffs_list2 = np.tile(np.array(sig2[key]),[27,len(sig1[key]),1,1,1]).transpose([2,1,3,0,4])
        try:
            sig_diff = np.linalg.norm(diffs_list1-diffs_list2,axis=4)**2
        except:
            return np.inf
        sig_diff = sig_diff.reshape([diffs_list1.shape[0],diffs_list1.shape[1],-1])
        sig_diff = np.min(sig_diff,axis = 2)
        rows, cols = scipy.optimize.linear_sum_assignment(sig_diff)
        cost += sig_diff[rows,cols].sum()
    return cost

def get_local_signatures(coords, atoms, lattice):
    """
    computes local signatures (as described in Appendix H.3.1) of all atoms in structure
    Parameters
    ----------
    coords: coordinates of atoms
    atoms: atom types
    lattice: lattice parameters
    """
    sigs = []
    for index in range(len(atoms)):
        sigs.append(local_signature(coords, lattice, atoms, index))
    return sigs
    
def atoms_match(coords1, atoms1, lattice1, coords2, atoms2, lattice2):
    """
    matches atoms in different structures (as described in Appendix H.3.1)
    Parameters
    ----------
    coords1: coordinates of atoms in structure 1
    atoms1: atom types in structure 1
    lattice1: lattice parameters for structure 1
    coords2: coordinates of atoms in structure 2
    atoms2: atom types in structure 2
    lattice2: lattice parameters for structure 2
    """
    cost_matrix = []
    sigs1, sigs2 = get_local_signatures(coords1, atoms1, lattice1), get_local_signatures(coords2, atoms2, lattice2)
    for a1, s1 in zip(atoms1, sigs1):
        cost_matrix.append([])
        for a2, s2 in zip(atoms2, sigs2):
            cost_matrix[-1].append(compare_signatures(s1,s2) if a1==a2 else np.inf)
    cost_matrix = np.array(cost_matrix)
    matches1, matches2 = scipy.optimize.linear_sum_assignment(cost_matrix)
    return matches1, matches2, cost_matrix, sum(cost_matrix[matches1,matches2])

def atoms_match_structure(structure1, structure2):
    """
    matches atoms in different structures (as described in Appendix H.3.1) from Structure objects
    Parameters
    ----------
    structure1: Structure object for structure 1
    structure2: Structure object for structure 2
    """
    return atoms_match(structure1.cart_coords,structure1.labels,structure1.lattice.matrix,
                        structure2.cart_coords,structure2.labels,structure2.lattice.matrix)

def atoms_match_structure_same_lattice(structure1, structure2, lattice):
    """
    matches atoms in different structures (as described in Appendix H.3.1) from Structure objects
    but with same lattice
    Parameters
    ----------
    structure1: Structure object for structure 1
    structure2: Structure object for structure 2
    lattice: lattice parameters
    """
    return atoms_match(structure1.frac_coords@lattice,structure1.labels,lattice,
                        structure2.frac_coords@lattice,structure2.labels,lattice)

def matching_neighbors_cells(coords1, lattice1, coords2, lattice2, matches1, matches2):
    """
    generates information about lattice shifts from matching algorithm (as described in Appendix H.3.1)
    Parameters
    ----------
    coords1: coordinates of atoms in structure 1
    atoms1: atom types in structure 1
    lattice1: lattice parameters for structure 1
    coords2: coordinates of atoms in structure 2
    atoms2: atom types in structure 2
    lattice2: lattice parameters for structure 2
    """
    nearby_cells = np.mgrid[-1:2,-1:2,-1:2].reshape(3,-1).T
    neighbors_cells1, neighbors_cells2 = [], []
    for a1, a2 in zip(matches1,matches2):
        neighbors_cells1.append([])
        neighbors_cells2.append([])
        for b1, b2 in zip(matches1,matches2):
            diffs1 = np.tile((coords1[b1]-coords1[a1]).reshape([-1,3])+nearby_cells@lattice1,[27,1,1])
            diffs2 = np.tile((coords2[b2]-coords2[a2]).reshape([-1,3])+nearby_cells@lattice2,[27,1,1])
            atom_diffs = diffs1.transpose([1,0,2])-diffs2
            atom_diffs = np.linalg.norm(atom_diffs,axis=2)
            i1,i2 = np.unravel_index(np.argmin(atom_diffs),atom_diffs.shape)
            neighbors_cells1[-1].append(nearby_cells[i1])
            neighbors_cells2[-1].append(nearby_cells[i2])
    return neighbors_cells1, neighbors_cells2

def compare_structures(coords1, lattice1, coords2, lattice2,
                       matches1, matches2, neighbors_cells1, neighbors_cells2):
    """
    compares differences in two structure
    Parameters
    ----------
    coords1: coordinates of atoms in structure 1
    lattice1: lattice parameters for structure 1
    coords2: coordinates of atoms in structure 2
    lattice2: lattice parameters for structure 2
    matches1: indices of atoms in a match in structure 1
    matches2: indices of atoms in a match in structure 2
    neighbors_cells1: lattice shift information for matching for atoms in structure 1
    neighbors_cells2: lattice shift information for matching for atoms in structure 2
    """
#     nearby_cells = torch.tensor(np.mgrid[-1:2,-1:2,-1:2].reshape(3,-1).T,dtype = torch.float) if torch_version else np.mgrid[-1:2,-1:2,-1:2].reshape(3,-1).T
    loss = 0.
    for a1, a2, n1s, n2s in zip(matches1,matches2, neighbors_cells1, neighbors_cells2):
        for b1, b2, n1, n2 in zip(matches1,matches2, n1s, n2s):
            diff = coords1[b1]-coords1[a1]+n1@lattice1-(coords2[b2]-coords2[a2]+n2@lattice2)
            loss += sum(diff*diff)
    return loss

Testing that it works

In [8]:
with MPRester(MP_API_KEY) as mpr:
    structure1 = mpr.get_structure_by_material_id("mp-2998",conventional_unit_cell=True)
    structure2 = mpr.get_structure_by_material_id("mp-5777",conventional_unit_cell=False)
    
### adding arbitrary translations
new_struct1_coords = structure1.cart_coords+np.array([0,0,0])
new_struct1_coords = new_struct1_coords@np.linalg.inv(structure1.lattice.matrix)
new_struct1_coords = (new_struct1_coords%1)@structure1.lattice.matrix

matches1, matches2, c_mat, l = atoms_match(new_struct1_coords,structure1.labels,structure1.lattice.matrix,
                        structure2.cart_coords,structure2.labels,structure2.lattice.matrix)
neighbors_cells1, neighbors_cells2 = matching_neighbors_cells(new_struct1_coords,structure1.lattice.matrix,
                        structure2.cart_coords,structure2.lattice.matrix, matches1, matches2)
mismatch = compare_structures(new_struct1_coords,structure1.lattice.matrix,
                        structure2.cart_coords,structure2.lattice.matrix,
                        matches1, matches2, neighbors_cells1, neighbors_cells2)
print(matches1, matches2, c_mat, l, mismatch)

Retrieving MaterialsDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving MaterialsDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

[0 1 2 3 4] [0 1 4 3 2] [[9.65896188e-03            inf            inf            inf
             inf]
 [           inf 5.41668809e-02            inf            inf
             inf]
 [           inf            inf 2.22070171e+01 2.23301461e+01
  2.08871093e-02]
 [           inf            inf 2.17885994e+01 1.95269950e-02
  2.27580321e+01]
 [           inf            inf 2.22152256e-02 2.15688169e+01
  2.24905535e+01]] 0.1264551727155693 0.1264551727155693


In [9]:
from torch_geometric.data import Data
import torch

def create_data(structure, cutoff = 2.7, lattice = None, breaker = None):
    """
    creates Data object from structure
    Parameters
    ----------
    structure: Structure object for crystal structure
    cutoff: radial cutoff for bonds
    lattice: lattice parameters to override structure lattice
    breaker: symmetry breaking object shared across nodes
    """
    if lattice is None:
        lattice = structure.lattice.matrix
    receivers, senders, senders_unit_shifts = neighbour_list(
        quantities="ijS",
        pbc=structure.pbc,
        cell=lattice,
        positions=structure.frac_coords@lattice,
        cutoff=cutoff,
    )
    
    one_hot = {'Ba':[1,0,0.], 'Ti':[0,1,0.], 'O':[0,0,1.]}
    if breaker is None:
        z = torch.tensor([one_hot[atom] for atom in structure.labels], dtype=torch.float)
    else:
        z = torch.tensor([one_hot[atom]+list(breaker) for atom in structure.labels], dtype=torch.float)
    
    return Data(pos = torch.tensor(structure.frac_coords@lattice, dtype=torch.float),
                x = breaker,
                z = z,
                lattice = torch.tensor(lattice, dtype=torch.float),
                src = torch.tensor(senders, dtype=torch.long), dst = torch.tensor(receivers, dtype=torch.long),
                shifts = torch.tensor(senders_unit_shifts, dtype=torch.float))

def create_matching_data(structure1, structure2, lattice = None):
    """
    creates matching data (as described in Appendix H.3.1) to compute loss efficiently
    Parameters
    ----------
    structure1: Structure object for structure 1
    structure2: Structure object for structure 2
    lattice: lattice parameters
    """
    if lattice is None:
        m1, m2, c_mat, l = atoms_match_structure(structure1, structure2)
        n1s, n2s = matching_neighbors_cells(structure1.cart_coords,structure1.lattice.matrix,
                        structure2.cart_coords,structure2.lattice.matrix, m1, m2)
    else:
        m1, m2, c_mat, l = atoms_match_structure_same_lattice(structure1, structure2, lattice)
        n1s, n2s = matching_neighbors_cells(structure1.frac_coords@lattice, lattice,
                        structure2.frac_coords@lattice, lattice, m1, m2)
    print(c_mat, l)
    n1s, n2s = torch.tensor(np.array(n1s), dtype=torch.float32), torch.tensor(np.array(n2s), dtype = torch.float32)
    return Data(matches1 = m1, matches2 = m2, neighbors_cells1 = n1s, neighbors_cells2 = n2s)

def get_matched_structure_mismatch(matching_data):
    """
    creates function to compare structures given matching data from matching algorithm
    Parameters
    ----------
    matching_data: matching data
    """
    def struct_mismatch(structure1, structure2):
        return compare_structures(structure1['pos'],structure1['lattice'],
                        structure2['pos'],structure2['lattice'],
                        matching_data['matches1'], matching_data['matches2'],
                        matching_data['neighbors_cells1'], matching_data['neighbors_cells2'])
    return struct_mismatch

In [10]:
lattice = None
lattice = np.array([[4.,0,0],[0,4,0],[0,0,4]])
cutoff = 2.7

with MPRester(MP_API_KEY) as mpr:
    initial_structure = mpr.get_structure_by_material_id("mp-2998",conventional_unit_cell=True)
    target_structure = mpr.get_structure_by_material_id("mp-5986",conventional_unit_cell=True)
initial_structure_data = create_data(initial_structure, lattice=lattice, cutoff = cutoff)
target_structure_data = create_data(target_structure, lattice=lattice, cutoff = cutoff)
matching_data = create_matching_data(initial_structure, target_structure, lattice = lattice)
# plot_data(initial_structure_data).show()
# plot_data(target_structure_data).show()

Retrieving MaterialsDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

Retrieving MaterialsDoc documents:   0%|          | 0/1 [00:00<?, ?it/s]

[[ 0.0336725          inf         inf         inf         inf]
 [        inf  0.08769583         inf         inf         inf]
 [        inf         inf 24.02962254  0.02962254 22.4689723 ]
 [        inf         inf  0.02962254 24.02962254 22.4689723 ]
 [        inf         inf 23.16333326 23.16333326  0.0672035 ]] 0.247816911080844


In [11]:
compare_structures(initial_structure_data['pos'],initial_structure_data['lattice'],
                        target_structure_data['pos'],target_structure_data['lattice'],
                        matching_data['matches1'], matching_data['matches2'],
                        matching_data['neighbors_cells1'], matching_data['neighbors_cells2'])
structure_mismatch = get_matched_structure_mismatch(matching_data)
structure_mismatch(initial_structure_data, target_structure_data)

tensor(0.2478)

# Symmetry breaking set

In [12]:
import e3nn
import e3nn.o3 as o3
import torch
import torch.optim
from e3nn.io import SphericalTensor

from e3nn.nn.models.gate_points_2102 import *
from torch_cluster import radius_graph

device='cpu'

In [120]:
def Oh_rotations_default_orient(i):
    """
    generates rotational components of normalizer of O)h
    Parameters
    ----------
    i - index for accessing elements in normalizer group
    """
    group_elements = [(torch.tensor([1.,0,0]),torch.tensor(0.)), ## identity
                      (torch.tensor([1.,0,0]),torch.tensor(np.pi/2)), ## deg 90 rot about 4-fold axes
                      (torch.tensor([1.,0,0]),torch.tensor(3*np.pi/2)),
                      (torch.tensor([0.,1,0]),torch.tensor(np.pi/2)),
                      (torch.tensor([0.,1,0]),torch.tensor(3*np.pi/2)),
                      (torch.tensor([0.,0,1]),torch.tensor(np.pi/2)),
                      (torch.tensor([0.,0,1]),torch.tensor(3*np.pi/2)),
                      (torch.tensor([1.,0,0]),torch.tensor(np.pi)), ## deg 180 rot about 4-fold axes
                      (torch.tensor([0.,1,0]),torch.tensor(np.pi)),
                      (torch.tensor([0.,0,1]),torch.tensor(np.pi)),
                      (torch.tensor([1/2**0.5,1/2**0.5,0]),torch.tensor(np.pi)), ## deg 180 rot about 2-fold axes
                      (torch.tensor([1/2**0.5,-1/2**0.5,0]),torch.tensor(np.pi)),
                      (torch.tensor([1/2**0.5,0,1/2**0.5]),torch.tensor(np.pi)),
                      (torch.tensor([1/2**0.5,0,-1/2**0.5]),torch.tensor(np.pi)),
                      (torch.tensor([0.,1/2**0.5,1/2**0.5]),torch.tensor(np.pi)),
                      (torch.tensor([0.,1/2**0.5,-1/2**0.5]),torch.tensor(np.pi)),
                      (torch.tensor([1.,1,1])/3**0.5,torch.tensor(2*np.pi/3)), ## deg 120 rot about 3-fold axes
                      (torch.tensor([1.,1,1])/3**0.5,torch.tensor(4*np.pi/3)),
                      (torch.tensor([1.,1,-1])/3**0.5,torch.tensor(2*np.pi/3)),
                      (torch.tensor([1.,1,-1])/3**0.5,torch.tensor(4*np.pi/3)),
                      (torch.tensor([1.,-1,1])/3**0.5,torch.tensor(2*np.pi/3)),
                      (torch.tensor([1.,-1,1])/3**0.5,torch.tensor(4*np.pi/3)),
                      (torch.tensor([-1.,1,1])/3**0.5,torch.tensor(2*np.pi/3)),
                      (torch.tensor([-1.,1,1])/3**0.5,torch.tensor(4*np.pi/3))
                     ]
    return group_elements[i]

def equi_symmetry_breaker(i=-1, axis=torch.tensor([0,0,1.]), angle=torch.tensor(0.)):
    """
    creates equivariant symmetry breaker
    Parameters
    ----------
    i - index for accessing elements in SBS. Chooses randomly if negative
    axis - axis we rotate about
    angle- angle by which we rotate our SBS
    """
    i = int(torch.randint(0,24,(1,))[0].detach()) if i<0 else i
    irreps = o3.Irreps('1o')
    rot = irreps.D_from_axis_angle(axis,angle)
#     s = SphericalTensor(lmax = 4, p_val = 1, p_arg = -1)
#     element = s.sum_of_diracs(torch.tensor([[1.,0,0],[0,1,0]],dtype=torch.float32),torch.ones([2]))[-9:]
#     element = torch.cat([torch.tensor([0.,0,1/3]),element])
    
    element = torch.tensor([0.,0,1])
    
    axis, angle = Oh_rotations_default_orient(i)
    group_rot = irreps.D_from_axis_angle(axis,angle)
    return element@group_rot.T@rot.T

In [39]:
import plotly
import plotly.graph_objects as go
import plotly.io as pio

pio.renderers.default = "notebook"
pio.renderers

Renderers configuration
-----------------------
    Default renderer: 'notebook'
    Available renderers:
        ['plotly_mimetype', 'jupyterlab', 'nteract', 'vscode',
         'notebook', 'notebook_connected', 'kaggle', 'azure', 'colab',
         'cocalc', 'databricks', 'json', 'png', 'jpeg', 'jpg', 'svg',
         'pdf', 'browser', 'firefox', 'chrome', 'chromium', 'iframe',
         'iframe_connected', 'sphinx_gallery', 'sphinx_gallery_png']

In [143]:
breaker = equi_symmetry_breaker(0)*0
print(breaker)
initial_structure_data = create_data(initial_structure, lattice=lattice, cutoff = cutoff, breaker = breaker)
target_structure_data = create_data(target_structure, lattice=lattice, cutoff = cutoff, breaker = breaker)

plot_data(initial_structure_data).show()
plot_data(target_structure_data).show()

tensor([-0., 0., 0.])


# Custom network

Modified from builtin model in e3nn so that we directly tell it what the edges are rather than infering it from the positions. Because we are dealing with crystals, we also incorporate periodic boundary conditions.

In [125]:
from typing import Dict, Union

from e3nn import o3
from e3nn.math import soft_one_hot_linspace
from e3nn.nn import FullyConnectedNet, Gate
from e3nn.o3 import FullyConnectedTensorProduct, TensorProduct
from e3nn.util.jit import compile_mode

from torch_geometric.data import Data
from torch_cluster import radius_graph
from torch_scatter import scatter

class CustomNetwork(torch.nn.Module):
    r"""equivariant neural network
    Parameters
    ----------
    irreps_in : `e3nn.o3.Irreps` or None
        representation of the input features
        can be set to ``None`` if nodes don't have input features
    irreps_hidden : `e3nn.o3.Irreps`
        representation of the hidden features
    irreps_out : `e3nn.o3.Irreps`
        representation of the output features
    irreps_node_attr : `e3nn.o3.Irreps` or None
        representation of the nodes attributes
        can be set to ``None`` if nodes don't have attributes
    irreps_edge_attr : `e3nn.o3.Irreps`
        representation of the edge attributes
        the edge attributes are :math:`h(r) Y(\vec r / r)`
        where :math:`h` is a smooth function that goes to zero at ``max_radius``
        and :math:`Y` are the spherical harmonics polynomials
    layers : int
        number of gates (non linearities)
    number_of_basis : int
        number of basis on which the edge length are projected
    radial_layers : int
        number of hidden layers in the radial fully connected network
    radial_neurons : int
        number of neurons in the hidden layers of the radial fully connected network
    num_neighbors : float
        typical number of nodes at a distance ``max_radius``
    num_nodes : float
        typical number of nodes in a graph
    """

    def __init__(
        self,
        irreps_in,
        irreps_hidden,
        irreps_out,
        irreps_node_attr,
        irreps_edge_attr,
        layers,
        max_radius,
        number_of_basis,
        radial_layers,
        radial_neurons,
        num_neighbors,
        num_nodes,
        reduce_output=True,
    ) -> None:
        super().__init__()
        self.max_radius = max_radius
        self.number_of_basis = number_of_basis
        self.num_neighbors = num_neighbors
        self.num_nodes = num_nodes
        self.reduce_output = reduce_output

        self.irreps_in = o3.Irreps(irreps_in) if irreps_in is not None else None
        self.irreps_hidden = o3.Irreps(irreps_hidden)
        self.irreps_out = o3.Irreps(irreps_out)
        self.irreps_node_attr = o3.Irreps(irreps_node_attr) if irreps_node_attr is not None else o3.Irreps("0e")
        self.irreps_edge_attr = o3.Irreps(irreps_edge_attr)

        self.input_has_node_in = irreps_in is not None
        self.input_has_node_attr = irreps_node_attr is not None

        irreps = self.irreps_in if self.irreps_in is not None else o3.Irreps("0e")

        act = {
            1: torch.nn.functional.silu,
            -1: torch.tanh,
        }
        act_gates = {
            1: torch.sigmoid,
            -1: torch.tanh,
        }

        self.layers = torch.nn.ModuleList()

        for _ in range(layers):
            irreps_scalars = o3.Irreps(
                [
                    (mul, ir)
                    for mul, ir in self.irreps_hidden
                    if ir.l == 0 and tp_path_exists(irreps, self.irreps_edge_attr, ir)
                ]
            )
            irreps_gated = o3.Irreps(
                [(mul, ir) for mul, ir in self.irreps_hidden if ir.l > 0 and tp_path_exists(irreps, self.irreps_edge_attr, ir)]
            )
            ir = "0e" if tp_path_exists(irreps, self.irreps_edge_attr, "0e") else "0o"
            irreps_gates = o3.Irreps([(mul, ir) for mul, _ in irreps_gated])

            gate = Gate(
                irreps_scalars,
                [act[ir.p] for _, ir in irreps_scalars],  # scalar
                irreps_gates,
                [act_gates[ir.p] for _, ir in irreps_gates],  # gates (scalars)
                irreps_gated,  # gated tensors
            )
            conv = Convolution(
                irreps,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                gate.irreps_in,
                number_of_basis,
                radial_layers,
                radial_neurons,
                num_neighbors,
            )
            irreps = gate.irreps_out
            self.layers.append(Compose(conv, gate))

        self.layers.append(
            Convolution(
                irreps,
                self.irreps_node_attr,
                self.irreps_edge_attr,
                self.irreps_out,
                number_of_basis,
                radial_layers,
                radial_neurons,
                num_neighbors,
            )
        )

    def forward(self, data: Union[Data, Dict[str, torch.Tensor]]) -> torch.Tensor:
        """evaluate the network
        Parameters
        ----------
        data : `torch_geometric.data.Data` or dict
            data object containing
            - ``pos`` the position of the nodes (atoms)
            - ``x`` the input features of the nodes, optional
            - ``z`` the attributes of the nodes, for instance the atom type, optional
            - ``batch`` the graph to which the node belong, optional
        """
        if "batch" in data:
            batch = data["batch"]
        else:
            batch = data["pos"].new_zeros(data["pos"].shape[0], dtype=torch.long)

#         edge_index = radius_graph(data["pos"], self.max_radius, batch)
#         edge_src = edge_index[0]
#         edge_dst = edge_index[1]
        
        edge_src = data["src"]
        edge_dst = data["dst"]
        
        edge_vec = data["pos"][edge_src] + data["shifts"]@data["lattice"] - data["pos"][edge_dst]
        edge_sh = o3.spherical_harmonics(self.irreps_edge_attr, edge_vec, True, normalization="component")
        edge_length = edge_vec.norm(dim=1)
        edge_length_embedded = soft_one_hot_linspace(
            x=edge_length, start=0.0, end=self.max_radius, number=self.number_of_basis, basis="gaussian", cutoff=False
        ).mul(self.number_of_basis**0.5)
        edge_attr = smooth_cutoff(edge_length / self.max_radius)[:, None] * edge_sh

        if self.input_has_node_in and "x" in data:
            assert self.irreps_in is not None
            x = data["x"]
        else:
            assert self.irreps_in is None
            x = data["pos"].new_ones((data["pos"].shape[0], 1))

        if self.input_has_node_attr and "z" in data:
            z = data["z"]
        else:
            assert self.irreps_node_attr == o3.Irreps("0e")
            z = data["pos"].new_ones((data["pos"].shape[0], 1))

        for lay in self.layers:
            x = lay(x, z, edge_src, edge_dst, edge_attr, edge_length_embedded)

        if self.reduce_output:
            return scatter(x, batch, dim=0).div(self.num_nodes**0.5)
        else:
            return x

# Training

In [140]:
model_kwargs = {
    "irreps_in": None,
    "irreps_hidden": [(mul, (l, p)) for l, mul in enumerate([16,8,4]) for p in [1,-1]],
    "irreps_out": "1x1o",
    "irreps_node_attr": '3x0e+1o',
    "irreps_edge_attr": SphericalTensor(lmax = 4, p_val = 1, p_arg = -1),
    "layers": 3,
    "max_radius": 2.7,
    "number_of_basis": 10,
    "radial_layers": 3,
    "radial_neurons": 16,
    "num_neighbors": 3,
    "num_nodes": 6,
    "reduce_output": False,
}


model = CustomNetwork(**model_kwargs).to(device)

In [141]:
def get_model_shifted_structure(data, model):
    """
    creates Data object for structure distorted by model outputs
    Parameters
    ----------
    data: input Data object to model
    model: neural network
    """
    output = model(data)
    return Data(pos = data['pos']+output,
                z = data['z'],
#                 lattice = data['lattice']+torch.mean(torch.reshape(output[:,3:],[-1,3,3]),0),
                lattice = data['lattice'],
                src = data['src'], dst = data['dst'],
                shifts = data['shifts'])

In [147]:
import copy

breaker = equi_symmetry_breaker(0)*0
initial_structure_data = create_data(initial_structure, lattice=lattice, cutoff = cutoff, breaker = breaker)
target_structure_data = create_data(target_structure, lattice=lattice, cutoff = cutoff, breaker = breaker)

iters = 300

losses = []

input_opt = torch.optim.Adam(model.parameters(),lr=3e-2)
min_loss = 10000
best_model = model
for i in range(iters):
    out = model(initial_structure_data)
    
    output_structure_data = get_model_shifted_structure(initial_structure_data, model)
    loss = structure_mismatch(output_structure_data,target_structure_data)
    if loss<min_loss:
        min_loss = loss
        best_model = copy.deepcopy(model)
    loss.backward()
    input_opt.step()
    input_opt.zero_grad()
    losses.append(float(loss))
    if i%50==0:
        print(f'loss: {loss}, iter: {i}, min_loss: {min_loss}')

loss: 0.24781672656536102, iter: 0, min_loss: 0.24781672656536102



KeyboardInterrupt



Visualizing training loss.

In [145]:
x = list(range(len(losses)))
go.Figure(data = go.Scatter(x=x,y=[np.log10(loss) for loss in losses])).show()
go.Figure(data = go.Scatter(x=x,y=losses)).show()

In [149]:
breaker = equi_symmetry_breaker(0)*0
initial_structure_data = create_data(initial_structure, lattice=lattice, cutoff = cutoff, breaker = breaker)
output_structure_data = get_model_shifted_structure(initial_structure_data, best_model)

bond_length_data(output_structure_data)

({(4, 1, (0.0, 0.0, -1.0)): 2.0,
  (3, 1, (0.0, -1.0, 0.0)): 2.0,
  (2, 1, (-1.0, 0.0, 0.0)): 2.0,
  (2, 1, (0.0, 0.0, 0.0)): 2.0,
  (3, 1, (0.0, 0.0, 0.0)): 2.0,
  (4, 1, (0.0, 0.0, 0.0)): 2.0},
 2.0,
 0.0)

In [24]:
breaker = equi_symmetry_breaker(0)
initial_structure_data = create_data(initial_structure, lattice=lattice, cutoff = cutoff, breaker = breaker)
output_structure_data = get_model_shifted_structure(initial_structure_data, best_model)
structure_mismatch(output_structure_data,target_structure_data)

tensor(1.1369e-13, grad_fn=<AddBackward0>)

# Plotting our training results

Here, we plot and save the results of our model.

### Initial structure

In [132]:
import os
path = os.getcwd()
shift = np.array([[5.,-5,1.5]])*0
fig = plot_data(initial_structure_data, shift=shift)

camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=1.1, y=0.9, z=0.5)
    )
fig.update_layout(scene_camera = camera)

# fig.write_image(path+'/BaTiO3_initial.png')
fig.show()

Cropping image

In [None]:
from PIL import Image
import os

img = Image.open(path+'/BaTiO3_initial.png')
left = 40
right = 630
top = 160
bottom = 720

img_res = img.crop((left, top, right, bottom))
# img_res.show()
img_res.save(path+'/BaTiO3_initial.png')

### Model distorted structures

In [133]:
def draw_arrow(vector,pos = [0,0,0], color = 'black', arrow_size=0.3):
    """
    Parameters
    ----------
    vector - direction and size our arrow points
    pos - base the arrow starts at
    color - color of arrow
    arrow_size - size of head of arrow
    """
    return [go.Scatter3d(x=[pos[0],pos[0]+vector[0]],
                             y=[pos[1],pos[1]+vector[1]], 
                             z=[pos[2],pos[2]+vector[2]], mode="lines", line=dict(color=color,width=5)),
              go.Cone(x=[pos[0]+vector[0]*1.1], y=[pos[1]+vector[1]*1.1], z=[pos[2]+vector[2]*1.1],
                      u=[vector[0]*arrow_size], v=[vector[1]*arrow_size], w=[vector[2]*arrow_size],
                      sizemode="absolute", anchor='tip',
                      colorscale = [[0, color],[1, color]], showscale=False)]

Plotting symmetry breaking objects

In [134]:
index = 0
indices = [0,1,2,3,4,8]
for i, index in enumerate(indices):
    fig = go.Figure()
    shift = np.array([3.,-7,-0.5])*0
    for j in [0,1,2,3,4,8]:
        breaker = equi_symmetry_breaker(j).numpy()
        breaker_plot = draw_arrow(breaker*1.5,pos = shift ,arrow_size=0.7,color='gray')

        [fig.add_trace(b_plot) for b_plot in breaker_plot]

    breaker = equi_symmetry_breaker(index).numpy()
    breaker_plot = draw_arrow(breaker*1.5,pos = shift ,arrow_size=0.7,color='blue')

    [fig.add_trace(b_plot) for b_plot in breaker_plot]

    size=6
    shift = 0

    camera = dict(
            up=dict(x=0, y=0, z=1),
            center=dict(x=0, y=0, z=0),
            eye=dict(x=1.1, y=0.9, z=0.5)
        )
    fig.update_layout(showlegend=False,
            scene_camera = camera,
            scene_xaxis_visible=False, 
            scene_yaxis_visible=False, 
            scene_zaxis_visible=False,
        scene = dict(
            xaxis = dict(nticks=4, range=[-size+shift,size+shift],),
            yaxis = dict(nticks=4, range=[-size-shift,size-shift],),
            zaxis = dict(nticks=4, range=[-size,size],),
            aspectmode = 'cube'),
        width=900,
        height=900,
        margin=dict(r=0, l=0, b=0, t=0))
    fig.write_image(path+f'/axes_{i}.png')
# fig.show()

Plotting model outputs given the distinct symmetry breaking objects

In [136]:
import os

path = os.getcwd()

indices = [0,1,2,3,4,8]
# indices = [0]

for i, index in enumerate(indices):
    breaker = equi_symmetry_breaker(index)
    initial_structure_data = create_data(initial_structure, lattice=lattice, cutoff = cutoff, breaker = breaker)
    lattices = np.mgrid[-1:1,-1:1,-1:1].reshape([3,-1]).T
    size, shift =5., np.array([[5.,-5,1.5]])*0
    fig = plot_data(get_model_shifted_structure(initial_structure_data, best_model), lattices = lattices, size=size, shift=shift)
    

    camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=1.1, y=0.9, z=0.5)
    )
    fig.update_layout(scene_camera = camera)
fig
#     fig.write_image(path+f'/BaTiO3_model_only_tilted_{i}.png')
# fig.write_image(path+'/triangular_prism_all.png')

Cropping images of structure and symmetry breaking objects and putting them together

In [None]:
from PIL import Image
import os

path = os.getcwd()
for i in range(6):
    img = Image.open(path+f'/BaTiO3_model_only_tilted_{i}.png')
    left = 40
    right = 630
    top = 160
    bottom = 720

    img_res = img.crop((left, top, right, bottom))
#     img_res.show()
    
    img2 = Image.open(path+f'/axes_{i}.png')
    left = 300
    right = 550
    top = 160
    bottom = 720

    img2_res = img2.crop((left, top, right, bottom))
    
    new_image = Image.new('RGB',(img_res.size[0]+img2_res.size[0], img_res.size[1]), (250,250,250))
    new_image.paste(img_res,(0,0))
    new_image.paste(img2_res,(img_res.size[0],0))
#     new_image.show()
    new_image.save(path+f'/BaTiO3_model_tilted_{i}.png')

### Target structure

In [None]:
import os
path = os.getcwd()

shift = np.array([[5.,-5,1.5]])*0
fig = plot_data(target_structure_data, shift=shift)

camera = dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0, y=0, z=0),
        eye=dict(x=1.1, y=0.9, z=0.5)
    )
fig.update_layout(scene_camera = camera)
fig.write_image(path+'/BaTiO3_target.png')
# fig.show()

Cropping image

In [None]:
from PIL import Image
import os

img = Image.open(path+'/BaTiO3_target.png')
left = 40
right = 630
top = 160
bottom = 720

img_res = img.crop((left, top, right, bottom))
img_res.save(path+'/BaTiO3_target.png')
# img_res.show()

In [None]:
key = 'lattice'
get_model_shifted_structure(initial_structure_data)[key], target_structure_data[key]