In [None]:
import os
import subprocess

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn

In [None]:
import sys

# Make the parent directory importable (presumably so the local `cace` checkout
# imported in the next cell is found — TODO confirm). `+=` extends sys.path in place.
sys.path += ['../']

In [None]:
import cace
from cace.representations.cace_representation import Cace

In [None]:
import numpy as np
from ase import Atoms
from ase.optimize import FIRE
from ase.visualize import view
from ase.md import Langevin
from ase import units
import numpy as np
import time
from ase.io import read,write

In [None]:
from cace.calculators import CACECalculator

In [None]:
# Load the trained CACE model. The file is unpickled as a whole module object
# (its `.parameters()` are used below), so only load checkpoints you trust; on
# torch>=2.6 loading a full pickled module may also require `weights_only=False`.
# map_location='cpu' lets a checkpoint saved from a GPU load on a CPU-only
# machine and keeps the model on the same device as the CPU calculator used here.
cace_nnp = torch.load('best_model.pth', map_location='cpu')

In [None]:
# Total number of scalar entries across all parameters with requires_grad=True.
trainable_params = 0
for param in cace_nnp.parameters():
    if param.requires_grad:
        trainable_params += param.numel()
print(f"Number of trainable parameters: {trainable_params}")

In [None]:
# ASE calculator wrapping the loaded CACE model, evaluated on the CPU.
# NOTE(review): the in-memory model object (not a file path) is passed as
# `model_path` — presumably CACECalculator accepts either; confirm.
calculator = CACECalculator(
    model_path=cace_nnp,
    device='cpu',
    energy_key='CACE_energy',  # presumably the model-output key for energies
    forces_key='CACE_forces',  # presumably the model-output key for forces
)

In [None]:
# Heavy-atom compositions to sample: element -> atom count. Hydrogen counts are
# enumerated separately by the sweep/analysis loops from these counts.
# NOTE(review): some compositions repeat (C5O1N1 x3, C4O1N2 x2, C6O1 x2). In the
# relaxation sweep a repeat acts as another pass for any H count whose output file
# still has fewer than the target number of frames; the analysis loops just
# recompute the same entry. Confirm the repeats are intentional.
num_heavy_atoms_list = [
    # common composition
    dict(C=5, O=1, N=1, S=0, Cl=0),
    dict(C=4, O=1, N=2, S=0, Cl=0),
    dict(C=6, O=1, N=0, S=0, Cl=0),
    dict(C=5, O=1, N=1, S=0, Cl=0),
    dict(C=4, O=1, N=2, S=0, Cl=0),
    dict(C=5, O=2, N=0, S=0, Cl=0),
    dict(C=5, O=1, N=1, S=0, Cl=0),
    dict(C=6, O=1, N=0, S=0, Cl=0),
    dict(C=5, O=0, N=2, S=0, Cl=0),
]

In [None]:
# Random-structure relaxation sweep.
# For every heavy-atom composition and every H count in the sweep, scatter atoms
# at random in a box, relax with the CACE calculator, shake, relax again, and
# append converged geometries to `relaxation-final-<name>.xyz`.
min_distance = 0.7         # minimum initial distance between any two atoms (ASE length units)
box_size = 5               # positions drawn from [0, 2*box_size) x [0, box_size) x [0, box_size)
n_trials = 32              # random starting structures per H count per pass
min_structures_done = 16   # skip an H count whose output file already has this many frames

for num_heavy_atoms in num_heavy_atoms_list:
    # H counts step down by 2 from this bound while they stay above 5.
    n_h_max = num_heavy_atoms['C'] * 2 + num_heavy_atoms['N'] - num_heavy_atoms['Cl'] - 2
    for n_h in range(n_h_max, 5, -2):

        num_atoms = {'H': n_h}
        num_atoms.update(num_heavy_atoms)
        print(num_atoms)

        name_now = (f"C{num_atoms['C']}O{num_atoms['O']}N{num_atoms['N']}"
                    f"S{num_atoms['S']}Cl{num_atoms['Cl']}H{num_atoms['H']}")
        file_name = 'relaxation-final-' + name_now + '.xyz'

        # Resume support: skip H counts that already have enough relaxed frames.
        if os.path.isfile(file_name) and len(read(file_name, ':')) >= min_structures_done:
            continue

        symbols = []
        for element in ['C', 'H', 'N', 'O', 'S', 'Cl']:
            symbols += [element] * num_atoms[element]

        for _ in range(n_trials):

            # Rejection-sample positions so no pair starts closer than min_distance.
            positions = []
            while len(positions) < len(symbols):
                new_pos = np.random.rand(3) * np.array([box_size * 2, box_size, box_size])
                if all(np.linalg.norm(new_pos - p) >= min_distance for p in positions):
                    positions.append(new_pos)

            atoms = Atoms(symbols=symbols, positions=positions)
            # `Atoms.set_calculator` is deprecated in ASE; assign `.calc` directly.
            atoms.calc = calculator

            # First relaxation.
            FIRE(atoms, logfile=None).run(fmax=0.001, steps=500)

            # Shake: uniform random displacement in [-0.75, 0.75) per coordinate.
            atoms.positions += 1.5 * (np.random.rand(*atoms.positions.shape) - 0.5)

            # Second, tighter relaxation with a fresh optimizer. Reusing the first
            # FIRE instance would carry over its velocity/time-step state and, on
            # ASE versions where `nsteps` accumulates across `run` calls, would
            # give this stage fewer than the requested 1000 steps.
            converged = FIRE(atoms, logfile=None).run(fmax=0.0002, steps=1000)

            if converged:
                write(file_name, atoms, append=True)

In [None]:
import pandas as pd

In [None]:
# Per-composition pass rates: for every relaxed composition run `check.sh` once
# (its result is cached as check-<name>.csv) and compute the percentage of
# structures passing each check column, plus the combined `all_pass`.
n_meta_cols = 3  # leading CSV columns that are not checks — TODO confirm against check.sh output

percentages = {}
for num_heavy_atoms in num_heavy_atoms_list:
    n_h_max = num_heavy_atoms['C'] * 2 + num_heavy_atoms['N'] - num_heavy_atoms['Cl'] - 2
    for n_h in range(n_h_max, 5, -2):

        num_atoms = {'H': n_h}
        num_atoms.update(num_heavy_atoms)
        print(num_atoms)

        name_now = (f"C{num_atoms['C']}O{num_atoms['O']}N{num_atoms['N']}"
                    f"S{num_atoms['S']}Cl{num_atoms['Cl']}H{num_atoms['H']}")

        file_name = 'relaxation-final-' + name_now + '.xyz'
        if not os.path.isfile(file_name):
            continue

        check_file = 'check-' + name_now + '.csv'
        if not os.path.isfile(check_file):
            # check=True: a failing check.sh stops here instead of on the missing CSV.
            subprocess.run(['bash', 'check.sh', name_now], check=True)

        # Missing check results count as failures.
        df = pd.read_csv(check_file).fillna(False)
        # A structure passes overall only if it passes every check column; the
        # leading metadata columns are excluded (a value like 0 there would
        # otherwise fail the row).
        df['all_pass'] = df.iloc[:, n_meta_cols:].all(axis=1)

        total_rows = len(df)
        if total_rows == 0:
            continue  # empty check file: no percentage to report (avoids 0/0)

        percentages[name_now] = {
            column: int(round(df[column].sum() / total_rows * 100))
            for column in df.columns[n_meta_cols:]
        }

In [None]:
# Pass rates aggregated by heavy-atom composition: the check results of every
# H count that shares the same heavy atoms are concatenated, then per-column
# pass percentages are computed. `df` maps heavy-atom label -> combined checks.
n_meta_cols = 3  # leading CSV columns that are not checks — TODO confirm against check.sh output

percentages = {}
df = {}
for num_heavy_atoms in num_heavy_atoms_list:
    # Label built from non-zero heavy-atom counts only, e.g. 'C6O1' or 'C5O1N1'.
    heavy_name = ''.join(f'{ele}{num_heavy_atoms[ele]}'
                         for ele in ['C', 'O', 'N', 'S', 'Cl']
                         if num_heavy_atoms[ele] > 0)

    frames = []
    n_h_max = num_heavy_atoms['C'] * 2 + num_heavy_atoms['N'] - num_heavy_atoms['Cl'] - 2
    for n_h in range(n_h_max, 5, -2):

        num_atoms = {'H': n_h}
        num_atoms.update(num_heavy_atoms)

        name_now = (f"C{num_atoms['C']}O{num_atoms['O']}N{num_atoms['N']}"
                    f"S{num_atoms['S']}Cl{num_atoms['Cl']}H{num_atoms['H']}")

        file_name = 'relaxation-final-' + name_now + '.xyz'
        if not os.path.isfile(file_name):
            continue

        check_file = 'check-' + name_now + '.csv'
        if not os.path.isfile(check_file):
            # check=True: a failing check.sh stops here instead of on the missing CSV.
            subprocess.run(['bash', 'check.sh', name_now], check=True)
        frames.append(pd.read_csv(check_file))
        print(heavy_name, name_now)

    # Concatenate once (instead of growing a DataFrame inside the loop).
    combined = pd.concat(frames, ignore_index=True) if frames else pd.DataFrame()
    if len(combined) == 0:
        df.pop(heavy_name, None)
        continue

    # Missing check results count as failures; the overall pass flag only looks
    # at the check columns, not the leading metadata columns.
    combined = combined.fillna(False)
    combined['all_pass'] = combined.iloc[:, n_meta_cols:].all(axis=1)
    df[heavy_name] = combined

    total_rows = len(combined)
    percentages[heavy_name] = {
        column: int(round(combined[column].sum() / total_rows * 100))
        for column in combined.columns[n_meta_cols:]
    }

In [None]:
# Reference pass rates for the QM7b dataset, computed the same way as for the
# generated compositions. NOTE(review): this CSV is scored from column index 2
# (vs. 3 for the generated sets) — confirm it has one fewer metadata column.
qm7b_meta_cols = 2

for name_now in ['QM7b']:

    check_file = 'check-' + name_now + '.csv'
    if not os.path.isfile(check_file):
        # check=True: a failing check.sh stops here instead of on the missing CSV.
        subprocess.run(['bash', 'check.sh', name_now], check=True)

    # Missing check results count as failures; the overall pass flag only looks
    # at the scored columns, not the leading metadata columns.
    df = pd.read_csv(check_file).fillna(False)
    df['all_pass'] = df.iloc[:, qm7b_meta_cols:].all(axis=1)

    total_rows = len(df)
    if total_rows == 0:
        continue  # empty check file: no percentage to report (avoids 0/0)

    percentages[name_now] = {
        column: int(round(df[column].sum() / total_rows * 100))
        for column in df.columns[qm7b_meta_cols:]
    }

In [None]:
import matplotlib.pyplot as plt

# Heat map of check pass rates: one row per dataset / heavy-atom composition,
# one column per check (including the combined `all_pass`).
columns = ['QM7b',
           'C5O1N1',
           'C4O1N2',
           'C6O1',
           'C5O2Cl1',
           'C6O1N1',
           'C6O1S1',
           'C7Cl1',
           'C7O1Cl1',
           'C7O1N1',
           'C6O1N1Cl1',
           ]
# Several of these labels are not produced from `num_heavy_atoms_list` in this
# notebook, so on a fresh kernel they are absent from `percentages`. Plot only
# the labels that have results instead of failing with KeyError.
missing = [label for label in columns if label not in percentages]
if missing:
    print('No results for these labels, skipping:', missing)
columns = [label for label in columns if label in percentages]

# Check names come from the first computed entry; keep only those every plotted
# label has (the QM7b and composition tables are scored from different columns).
rows = [check for check in next(iter(percentages.values()))
        if all(check in percentages[label] for label in columns)]

# data[i][j] = pass percentage of check rows[j] for label columns[i]
data = [[percentages[label][check] for check in rows] for label in columns]

fig, ax = plt.subplots(figsize=(5, 5))
matrix = ax.imshow(data, cmap='viridis', interpolation='nearest')

cbar = fig.colorbar(matrix, ax=ax, location='right', orientation='vertical', pad=0.06)
cbar.set_label("pass [%]")

ax.set_yticks(range(len(columns)))
ax.set_yticklabels(columns, rotation=0, ha='right', va='bottom')
ax.xaxis.set_ticks_position('top')  # check names along the top edge
ax.set_xticks(range(len(rows)))
ax.set_xticklabels(rows, rotation=90)

# Annotate every cell with its integer percentage.
for i in range(len(columns)):
    for j in range(len(rows)):
        ax.text(j, i, f'{data[i][j]}', va='center', ha='center', color='red')

plt.tight_layout()
fig.subplots_adjust(left=0.25)  # room for the right-aligned row labels

fig.savefig('qm7b-pass-2.pdf')

plt.show()