# Rewrite XYZ loading using ASE

In [8]:
import py3Dmol
import numpy as np
import matplotlib.pyplot as plt
# from IPython.display import display
from ipyfilechooser import FileChooser
# import re
import ase
from ase.io import read
import plotly.graph_objects as go
import copy
from pprint import pprint



class ClusterNeighbor(object):
    """Pair-distance and coordination-number analysis of an atomic cluster.

    The structure is stored in ``self.atoms`` by :meth:`load_xyz`.  Analysis
    results are cached as instance attributes (``pairs_group``, ``CN``,
    ``cluster_size`` ...) and are discarded whenever a new structure is loaded.
    """

    # Attributes derived from ``self.atoms``.  load_xyz drops them so that a
    # re-used instance never reports results of a previously loaded structure.
    _DERIVED_ATTRIBUTES = (
        'pairs_index', 'pairs_element', 'pairs', 'pairs_unique', 'distance_all',
        'pairs_types', 'pairs_group', 'CN_distances', 'CN', 'cluster_size',
    )

    def __init__(self):
        # Nothing to set up: every attribute is created by load_xyz / the
        # analysis methods (several of them test for presence with hasattr).
        pass

    def _calculate_distance(self, coord1, coord2):
        """Return the Euclidean distance between two 3-component coordinates."""
        delta = np.asarray(coord2, dtype=float)[:3] - np.asarray(coord1, dtype=float)[:3]
        return np.sqrt(np.dot(delta, delta))

    def load_xyz(self, from_file=True, xyz_path=None, atom_object=None):
        """Load a structure and build the per-element bookkeeping.

        Parameters
        ----------
        from_file : bool
            When exactly ``True`` the structure is read from ``xyz_path`` with
            ``ase.io.read``; otherwise ``atom_object`` is used as-is.
        xyz_path : path-like, optional
            File to read when ``from_file`` is ``True``.
        atom_object : optional
            Already-built atoms object used when ``from_file`` is not ``True``.

        Raises
        ------
        ValueError
            If the argument required by the chosen mode is ``None``.
        """
        if from_file is True:
            if xyz_path is None:
                raise ValueError("xyz_path is required when from_file is True")
            self.atoms = read(xyz_path)
        else:
            if atom_object is None:
                raise ValueError("atom_object is required when from_file is not True")
            self.atoms = atom_object

        # Forget everything computed for a previously loaded structure.
        for attribute in self._DERIVED_ATTRIBUTES:
            self.__dict__.pop(attribute, None)

        self.elements = [atom.symbol for atom in self.atoms]
        # Single pass: count each element and record the indices where it occurs.
        self.elements_num = {}
        self.element_index_group = {}
        for index, symbol in enumerate(self.elements):
            self.elements_num[symbol] = self.elements_num.get(symbol, 0) + 1
            self.element_index_group.setdefault(symbol, []).append(index)
        self.center = self.atoms.get_center_of_mass()

    def view_xyz(self, style_all=None, highlight_atom1="O", highlight_atom2="Pb", label=True):
        """Render the cluster with py3Dmol.

        Parameters
        ----------
        style_all : dict, optional
            Style applied to every atom; a thin grey stick plus small sphere
            is used when omitted.
        highlight_atom1, highlight_atom2 : str
            Atom selections that additionally get a red (radius 0.5) and a
            blue (radius 0.3) sphere style respectively.
        label : bool
            When true, every atom is labelled with its index.
        """
        # XYZ text: atom count, blank comment line, then one line per atom.
        atom_lines = [
            f"{atom.symbol} {atom.position[0]} {atom.position[1]} {atom.position[2]}\n"
            for atom in self.atoms
        ]
        self.xyz_string = f"{len(self.atoms)}\n\n" + "".join(atom_lines)

        self.view = py3Dmol.view(width=500, height=500)
        self.view.addModel(self.xyz_string, 'xyz',)
        if style_all is None:
            style_all = {'stick': {'radius': .1, 'alpha': 0.2, 'color': 'gray'},
                         'sphere': {'radius': .3}
                         }
        self.view.setStyle(style_all)

        self.view.addStyle({'atom': highlight_atom1},
                           {'sphere': {'color': 'red', 'radius': 0.5}})
        self.view.addStyle({'atom': highlight_atom2},
                           {'sphere': {'color': 'blue', 'radius': 0.3}})
        self.view.setBackgroundColor('0xeeeeee')
        if label:
            for i, atom_i in enumerate(self.atoms):
                self.view.addLabel(f"{i}", {'position': {'x': atom_i.position[0], 'y': atom_i.position[1], 'z': atom_i.position[2]},
                                            'fontColor': 'k', 'fontSize': 12, 'backgroundColor': 'white', 'backgroundOpacity': 0.5})
        self.view.zoomTo()
        self.view.show()
        # NOTE(review): this runs after show(); verify that py3Dmol / 3Dmol.js
        # actually supports a title() call - TODO confirm.
        self.view.title(self.atoms.get_chemical_formula())

    def get_cluster_size(self):
        """Print and return half of the largest interatomic distance."""
        self.cluster_size = self.atoms.get_all_distances().max() / 2
        print(f"Cluster size is {self.cluster_size} A")
        return self.cluster_size

    def shrink_cluster_size(self, new_radius=None):
        """Return a new ``ase.Atoms`` with only the atoms within ``new_radius`` of ``self.center``.

        Parameters
        ----------
        new_radius : float, optional
            Cut-off distance from the centre of mass.  Defaults to the cluster
            size minus 0.1; the cluster size is computed first if it has not
            been computed yet.
        """
        if new_radius is None:
            if not hasattr(self, 'cluster_size'):
                self.get_cluster_size()
            new_radius = self.cluster_size - 0.1

        # Distance of every atom from the centre, computed in one shot.
        offsets = np.asarray(self.atoms.get_positions(), dtype=float) - np.asarray(self.center, dtype=float)
        radii = np.sqrt((offsets ** 2).sum(axis=1))
        # An atom is dropped only when it is strictly farther than new_radius.
        keep_indices = [i for i, radius_i in enumerate(radii) if not radius_i > new_radius]

        atoms_smaller = ase.Atoms(self.atoms[keep_indices])
        return atoms_smaller

    def get_pairs(self):
        """Enumerate every unordered atom pair and group the pairs by bond type.

        Returns
        -------
        dict
            ``{bond_type: {'pairs_index', 'pairs', 'pairs_unique', 'distance'}}``
            where ``bond_type`` is the two element symbols sorted and joined
            with ``-`` and the four lists are aligned entry by entry.
        """
        n_atoms = len(self.atoms)
        symbols = [atom.symbol for atom in self.atoms]
        # The full matrix is computed once, not once per pair.
        distance_matrix = self.atoms.get_all_distances()

        self.pairs_index = [(i, j) for i in range(n_atoms) for j in range(i + 1, n_atoms)]
        self.pairs_element = [sorted([symbols[i], symbols[j]]) for i, j in self.pairs_index]
        self.pairs = [f"{atom_i}-{atom_j}" for atom_i, atom_j in self.pairs_element]
        self.pairs_unique = [f"{symbols[i]}({i})-{symbols[j]}({j})" for i, j in self.pairs_index]
        self.distance_all = [distance_matrix[i][j] for i, j in self.pairs_index]
        self.pairs_types = set(self.pairs)

        self.pairs_group = {key: {'pairs_index': [],
                                  'pairs': [],
                                  'pairs_unique': [],
                                  'distance': []} for key in self.pairs_types}

        for pair_label, pair_index, pair_unique, distance in zip(
                self.pairs, self.pairs_index, self.pairs_unique, self.distance_all):
            group = self.pairs_group[pair_label]
            group['pairs_index'].append(pair_index)
            group['pairs'].append(pair_label)
            group['pairs_unique'].append(pair_unique)
            group['distance'].append(distance)
        return self.pairs_group

    def get_CN(self, center_atom=None, CN_atom=None, error_bar=0.01):
        """Compute coordination numbers of ``CN_atom`` around ``center_atom``.

        All centre-to-neighbour distances are sorted and split into groups
        wherever two consecutive distances differ by more than ``error_bar``.
        Each group contributes ``{mean distance: count / number of centre atoms}``.

        Parameters
        ----------
        center_atom : str, optional
            Element symbol of the central atoms; defaults to the first atom's symbol.
        CN_atom : str, optional
            Element symbol of the neighbours; defaults to the second atom's symbol.
        error_bar : float
            Largest gap between consecutive sorted distances that still
            belongs to the same group.

        Returns
        -------
        dict
            ``self.CN``: results of every call so far, keyed by
            ``"<center_atom>-<CN_atom>"``.

        Raises
        ------
        KeyError
            If either element is not present in the loaded structure.
        """
        # Kept from the original: after get_CN the pairs_* attributes exist too.
        if not hasattr(self, 'pairs_types'):
            self.get_pairs()

        if not hasattr(self, 'CN_distances'):
            self.CN_distances = {}
            self.CN = {}

        if center_atom is None:
            center_atom = self.atoms[0].symbol

        if CN_atom is None:
            CN_atom = self.atoms[1].symbol

        for symbol in (center_atom, CN_atom):
            if symbol not in self.element_index_group:
                raise KeyError(f"element {symbol!r} is not present in the cluster")

        bond_type = f"{center_atom}-{CN_atom}"
        center_atom_index = self.element_index_group[center_atom]
        CN_atom_index = self.element_index_group[CN_atom]

        distances_all = np.asarray([self.atoms.get_distances(atom_i, CN_atom_index) for atom_i in center_atom_index])
        distance_sorted = np.sort(distances_all.flatten())
        # Zero distances are an atom paired with itself (same-element case).
        distance_sorted = distance_sorted[distance_sorted != 0]

        if distance_sorted.size == 0:
            # No neighbours at all: report an empty result, not a nan key.
            self.CN_distances[bond_type] = []
            self.CN[bond_type] = {}
            return self.CN

        split_points = np.where(np.diff(distance_sorted) > error_bar)[0] + 1
        shells = np.split(distance_sorted, split_points)
        n_center = self.elements_num[center_atom]
        self.CN_distances[bond_type] = shells
        # Plain float keys so the printed dict is readable under any numpy version.
        self.CN[bond_type] = {float(np.average(shell)): shell.shape[0] / n_center for shell in shells}
        return self.CN

    def plot_hist(self, binsize=0.2):
        """Show overlaid plotly histograms of pair distances, one trace per bond type.

        Parameters
        ----------
        binsize : float
            Histogram bin width along the distance axis.
        """
        if not hasattr(self, 'pairs_types'):
            self.get_pairs()
        fig = go.Figure()
        for key_i, group_i in self.pairs_group.items():
            fig.add_trace(go.Histogram(x=group_i['distance'], name=key_i, opacity=0.6,
                                       xbins={'size': binsize}, marker={'line': {'color': 'white', 'width': 2}}))

        fig.update_layout(
            xaxis_title_text='Distances [A]', yaxis_title_text='pairs',
            plot_bgcolor='rgba(0.02,0.02,0.02,0.02)',  # Transparent plot background
            xaxis={'tickmode': 'auto'}, barmode='overlay',  # Overlay histograms
            width=600, height=400)
        fig.show()

## Examples

In [12]:
import ipywidgets as widgets
from IPython.display import display

# Three demo buttons, followed by the output area their clicks write into
button1, button2, button3 = [
    widgets.Button(description=text) for text in ("Button 1", "Button 2", "Button 3")
]
output = widgets.Output()


def on_button_clicked(b):
    """Print which button was clicked into the shared output widget."""
    with output:
        print(f"{b.description} clicked!")


# Every button shares the same click handler
for demo_button in (button1, button2, button3):
    demo_button.on_click(on_button_clicked)

# Lay the buttons and the output area out side by side (HBox is a horizontal box)
box = widgets.HBox([button1, button2, button3, output])
display(box)


HBox(children=(Button(description='Button 1', style=ButtonStyle()), Button(description='Button 2', style=Butto…

In [33]:
import ipywidgets as widgets
from IPython.display import display, clear_output

# One shared file chooser: every handler below reads the selected path from it.
fc = FileChooser()

# Buttons and the output area that the handlers print into
button_load = widgets.Button(description="Load xyz")
button_show = widgets.Button(description="Show")
button_hist = widgets.Button(description="Plot histogram")
button_getCN = widgets.Button(description="Calculate CN")
button_clear = widgets.Button(description="Clear")

output = widgets.Output()


def read_file():
    """Return the text of the file selected in ``fc``.

    Prints a notice and returns ``None`` when no file is selected.
    """
    if fc.selected is None:
        print("No file selected.")
        return None
    with open(fc.selected, 'r') as file:
        return file.read()


def _loaded_cluster():
    """Return the module-level ``cluster``, or ``None`` (with a notice) if none is loaded."""
    loaded = globals().get("cluster")
    if loaded is None:
        print("No cluster loaded yet - select an xyz file and press 'Load xyz'.")
    return loaded


def load_xyz(b):
    """Load the selected xyz file into the module-level ``cluster``.

    ``cluster`` is only rebound after the file was read successfully, so a
    failed load leaves the previously loaded cluster in place.
    """
    global cluster
    if fc.selected is None:
        print("No file selected.")
        return
    print("load", fc.selected)
    loaded = ClusterNeighbor()
    loaded.load_xyz(xyz_path=fc.selected)
    cluster = loaded


def on_button_load_clicked(b):
    """'Load xyz' button: load whatever file is selected in the chooser."""
    with output:
        print("load xyz")
        load_xyz(b)


def on_button_show_clicked(b):
    """'Show' button: currently only logs the click."""
    with output:
        print("Show xyz")
        # TODO: the 3D view call (cluster.view_xyz(...)) is still disabled here.


def on_button_hist_clicked(b):
    """'Plot histogram' button: pair-distance histogram of the loaded cluster."""
    with output:
        loaded = _loaded_cluster()
        if loaded is not None:
            loaded.plot_hist(binsize=0.1)


def on_button_getCN_clicked(b):
    """'Calculate CN' button: compute and pretty-print the coordination numbers."""
    with output:
        loaded = _loaded_cluster()
        if loaded is not None:
            print("Calculate CN")
            loaded.get_CN()
            pprint(loaded.CN)


def on_button_clear_clicked(b):
    """'Clear' button: wipe the output area."""
    with output:
        clear_output()  # Clear the output area


# Linking buttons to the event handlers
button_load.on_click(on_button_load_clicked)
button_clear.on_click(on_button_clear_clicked)
button_show.on_click(on_button_show_clicked)
button_hist.on_click(on_button_hist_clicked)
button_getCN.on_click(on_button_getCN_clicked)

# Buttons in one horizontal row; the chooser and the output area are displayed
# too, otherwise nothing the handlers print would ever be visible.
box = widgets.HBox([button_load, button_show, button_hist, button_getCN, button_clear])
display(fc, box, output)


HBox(children=(Button(description='Show', style=ButtonStyle()), Button(description='Plot histogram', style=But…

In [80]:
# Same viewer options for both renderings below
view_options = dict(highlight_atom1="Pb", highlight_atom2="O", label=True)

# Full cluster, read from the path held by the file chooser
cluster = ClusterNeighbor()
cluster.load_xyz(xyz_path=fc.value)
cluster.view_xyz(**view_options)
cluster.get_cluster_size()
cluster.atoms.get_angle(1, 4, 0)  # not displayed: only a cell's last expression is shown
print("center of mass:", cluster.center)

# Reduced cluster: only atoms within 3.6 of the full cluster's centre of mass
cluster_small = ClusterNeighbor()
new_cluster = cluster.shrink_cluster_size(new_radius=3.6)
cluster_small.load_xyz(from_file=False, atom_object=new_cluster)
cluster_small.view_xyz(**view_options)
cluster_small.get_cluster_size()
print("center of mass:", cluster_small.center)


Cluster size is 5.926372577426853 A
center of mass: [-3.01051652e-16  1.12298166e-16  1.83279193e-16]


Cluster size is 3.3900003925275586 A
center of mass: [-4.71534994e-18  3.44525012e-19  6.52398545e-17]


# plot histogram

In [81]:
# Pair-distance histograms of the full and the reduced cluster, same bin width
for neighbor in (cluster, cluster_small):
    neighbor.plot_hist(binsize=0.1)

# get CNs

In [86]:
error_bar = 0.02
# Pb-centred coordination numbers (Pb and O neighbours) for each cluster in turn
for neighbor in (cluster, cluster_small):
    for partner in ('Pb', 'O'):
        neighbor.get_CN(center_atom='Pb', CN_atom=partner, error_bar=error_bar)
    pprint(neighbor.CN)


{'Pb-O': {2.164006661818935: 5.2,
          3.759464103600722: 2.4,
          4.0161487745681885: 2.4,
          4.450974934970656: 4.266666666666667,
          4.861048023898767: 0.9333333333333333,
          5.062180336008798: 3.7333333333333334,
          5.262128217289388: 1.6,
          5.917697886611758: 3.2,
          6.229254614092598: 2.6666666666666665,
          6.541817358979241: 2.1333333333333333,
          6.659154270243802: 1.3333333333333333,
          7.113771779844665: 0.5333333333333333,
          7.340720537707306: 2.4,
          7.472377961033678: 1.6,
          7.609254687726023: 0.5333333333333333,
          7.752545874266088: 1.0666666666666667,
          7.860522288892568: 1.6,
          8.08286153850564: 1.0666666666666667,
          8.342552820814575: 0.5333333333333333,
          8.536897202414814: 0.8,
          8.582446934276872: 0.26666666666666666,
          8.768876438013004: 0.5333333333333333,
          9.073544419622824: 0.26666666666666666,
       

In [87]:
error_bar = 0.02
# Pb-Pb and Pb-O coordination numbers, first for the full then for the reduced cluster
for analysed in [cluster, cluster_small]:
    analysed.get_CN(center_atom='Pb', CN_atom='Pb', error_bar=error_bar)
    analysed.get_CN(center_atom='Pb', CN_atom='O', error_bar=error_bar)
    pprint(analysed.CN)


{'Pb-O': {2.164006661818935: 5.2,
          3.759464103600722: 2.4,
          4.0161487745681885: 2.4,
          4.450974934970656: 4.266666666666667,
          4.861048023898767: 0.9333333333333333,
          5.062180336008798: 3.7333333333333334,
          5.262128217289388: 1.6,
          5.917697886611758: 3.2,
          6.229254614092598: 2.6666666666666665,
          6.541817358979241: 2.1333333333333333,
          6.659154270243802: 1.3333333333333333,
          7.113771779844665: 0.5333333333333333,
          7.340720537707306: 2.4,
          7.472377961033678: 1.6,
          7.609254687726023: 0.5333333333333333,
          7.752545874266088: 1.0666666666666667,
          7.860522288892568: 1.6,
          8.08286153850564: 1.0666666666666667,
          8.342552820814575: 0.5333333333333333,
          8.536897202414814: 0.8,
          8.582446934276872: 0.26666666666666666,
          8.768876438013004: 0.5333333333333333,
          9.073544419622824: 0.26666666666666666,
       

In [88]:
# Re-draw the reduced cluster with index labels, then show the distance between atoms 7 and 8
cluster_small.view_xyz(label=True, highlight_atom2="O", highlight_atom1="Pb")
cluster_small.atoms.get_distance(7, 8)

3.3900003925275586

# test single

In [89]:
# Analyse only the first seven atoms of the full cluster
first_seven_atoms = cluster.atoms[:7]
tolerance = 0.01

test2 = ClusterNeighbor()
test2.load_xyz(from_file=False, atom_object=first_seven_atoms)
test2.view_xyz(highlight_atom1="Pb", highlight_atom2="O", label=True)
test2.get_cluster_size()
test2.get_pairs()
test2.get_CN(center_atom='Pb', CN_atom='O', error_bar=tolerance)
test2.CN_distances  # not displayed: only a cell's last expression is shown
test2.get_CN(center_atom='O', CN_atom='O', error_bar=tolerance)
# Last expression: its return value (the accumulated CN dict) is the cell output
test2.get_CN(center_atom='O', CN_atom='Pb', error_bar=tolerance)

Cluster size is 2.1692845084652332 A


{'Pb-O': {2.1534510867921863: 2.0, 2.169284413894355: 4.0},
 'O-O': {2.707596995825078: 0.6666666666666666,
  3.056656089842477: 2.6666666666666665,
  3.3899997023095034: 0.6666666666666666,
  4.30690217358428: 0.3333333333333333,
  4.338568827788602: 0.6666666666666666},
 'O-Pb': {2.1534510867921863: 0.3333333333333333,
  2.169284413894355: 0.6666666666666666}}

In [63]:
# Analyse only the first twelve atoms of the full cluster
first_twelve_atoms = cluster.atoms[:12]
tolerance = 0.02

test2 = ClusterNeighbor()
test2.load_xyz(from_file=False, atom_object=first_twelve_atoms)
test2.view_xyz(highlight_atom1="Pb", highlight_atom2="O", label=True)
test2.get_cluster_size()
test2.get_pairs()
test2.get_CN(center_atom='Pb', CN_atom='O', error_bar=tolerance)
# O-O and Pb-Pb are left disabled in this cell:
# test2.get_CN(center_atom='O', CN_atom='O', error_bar=0.02)
# test2.get_CN(center_atom='Pb', CN_atom='Pb', error_bar=0.02)
# Last expression: its return value (the accumulated CN dict) is the cell output
test2.get_CN(center_atom='O', CN_atom='Pb', error_bar=tolerance)



Cluster size is 3.6503248720208448 A


{'Pb-O': {2.1640067315448177: 6.0,
  4.016148350175418: 2.0,
  5.262128024409925: 2.0},
 'O-Pb': {2.1640067315448177: 1.2,
  4.016148350175418: 0.4,
  5.262128024409925: 0.4}}

In [None]:
# Save the atoms currently held by test2 to an xyz file in the working directory
test2.atoms.write('reduced.xyz')

In [None]:
# Static (matplotlib) histogram of pair distances, one series per bond type.
# The data comes from the cluster's own pair table: `output` is the
# ipywidgets.Output widget created in an earlier cell and cannot be indexed.
pairs_group = cluster.pairs_group if hasattr(cluster, 'pairs_group') else cluster.get_pairs()

fig, ax = plt.subplots(figsize=(10, 3))
for key_i, group_i in pairs_group.items():
    ax.hist(group_i['distance'], bins=80, alpha=0.3, edgecolor='white', label=key_i)
ax.set_xlabel("Distance [Å]")
ax.set_ylabel("Number of pairs")
ax.legend()
ax.set_ylim(0, 180)  # limit carried over from the original cell
fig.tight_layout()
plt.show()

75/number of Pd num_atoms (coordinations )

1. CN of Pd atoms and CN of oxygen atoms
2. Calculate the radius of the cluster: the longest distance from (0,0,0)
3. Increase the radius of the cluster (no need to shrink/expand the bond lengths), chop the histogram, and calculate the CN again
4. CIF file --> cluster; increase the cluster size

Display local file.

In [66]:
# Fetch PDB entry 1ycr, draw it as a white cartoon and add a VDW surface
# coloured by the 'b' property on a sinebow gradient (range 0-70)
cartoon_style = {'cartoon': {'color': 'white'}}
surface_style = {'opacity': 0.7,
                 'colorscheme': {'prop': 'b', 'gradient': 'sinebow', 'min': 0, 'max': 70}}

view = py3Dmol.view(query='pdb:1ycr')
view.setStyle(cartoon_style)
view.addSurface(py3Dmol.VDW, surface_style)

<py3Dmol.view at 0x135b79f90>

In [67]:
import base64

import requests

# Download entry 5lgo in MMTF format.  The timeout keeps the cell from hanging
# forever, and raise_for_status stops an HTTP error body from being handed to
# the viewer as if it were structure data.
r = requests.get('https://mmtf.rcsb.org/v1.0/full/5lgo', timeout=30)
r.raise_for_status()

view = py3Dmol.view()
# The binary payload is passed to the viewer as base64 text
view.addModel(base64.b64encode(r.content).decode(), 'mmtf')
view.addUnitCell()
view.zoomTo()


<py3Dmol.view at 0x135a6cf40>