In [1]:
from environment import Molecule_Environment
import numpy as np
from reinforce import *
import shutil
import os.path
import matplotlib.pyplot as plt


In [2]:
# Build the molecule environment and the policy network / optimizer used for training.
# NOTE(review): `script_dir`, `Policy`, `device` and `optim` are not defined in this
# notebook; they presumably come from `from reinforce import *` — confirm, and prefer
# importing them explicitly in the import cell.
number_of_atoms = 7
env = Molecule_Environment(
    n_atoms=number_of_atoms,
    chemical_symbols=["B"],
    dimensions=(21, 21, 21),
    resolution=np.array([0.1, 0.1, 0.1]),
    # Use os.path (imported above) rather than an alias that is not defined here.
    ref_spectra_path=os.path.join(script_dir, 'references', 'reference_custom_1.dat'),
    print_spectra=0,
)
# The policy's input size is the number of grid cells (product of env.dimensions).
# Cast to a plain int so it is a safe layer size regardless of numpy/torch versions.
flatten_dimensions = int(np.prod(env.dimensions))
state_size = flatten_dimensions
action_size = len(env.actions)
policy = Policy(state_size, action_size).to(device)
optimizer = optim.Adam(policy.parameters(), lr=0.01)
    



In [4]:
# Start training from a clean output directory.
dir_path = "ir"  # Change this to an absolute path if needed, e.g., r"C:\path\to\ir"
if os.path.exists(dir_path):
    print(f"Directory {dir_path} exists, attempting to remove it.")
    try:
        # Do NOT pass ignore_errors=True: it silences every failure, which made the
        # except-branch unreachable and printed "removed successfully" even when the
        # directory was still there.
        shutil.rmtree(dir_path)
        print(f"Directory {dir_path} removed successfully.")
    except OSError as e:
        print(f"An error occurred while trying to remove the directory: {e}")
else:
    print(f"Directory {dir_path} does not exist.")

# Train the policy with REINFORCE; `scores` holds whatever `reinforce` returns per run.
scores = reinforce(policy, optimizer, env=env, n_episodes=100)



Directory ir does not exist.


KeyboardInterrupt: 

In [5]:
# Roll out the trained policy greedily (epsilon=0.0) to generate one molecule.
# NOTE(review): the loop has no step cap and relies on env.step eventually returning
# done=True; a saved run of this cell ended in KeyboardInterrupt — consider a max-steps guard.
state = env.reset()
flattened_state = get_flattened_state(state)
done = False

while not done:
    action_idx, _ = policy.act(flattened_state, epsilon=0.0)
    action = env.actions[action_idx]  # Convert action index to coordinates
    state, _, done = env.step(action)
    flattened_state = get_flattened_state(state)

# Compute the spectrum of the generated structure and keep the reference intensities.
ref_spectra_y = env.ref_spectra[:, 1]
# Grid cells equal to 1 in env.state are taken as atom sites (presumably occupied cells);
# their integer indices are scaled by the grid resolution to get positions.
atom_pos = np.where(env.state == 1)
coords_atom = list(zip(*atom_pos))
positions = np.array(coords_atom) * env.resolution
spectra = spectra_from_arrays(positions=positions, chemical_symbols=env.chem_symbols, name=env.name, writing=False)
print(positions)
spectra_y = spectra[:, 1]
print(policy)



KeyboardInterrupt: 

In [None]:
# Sanity check for `calculate_distance`.
# Reference pair: two points placed symmetrically about the origin with z = 0.
point1, point2 = [-0.475, -0.475, 0.0], [0.475, 0.475, 0.0]
print(calculate_distance(point1, point2))

# Apply the same helper to the first two atoms of the generated structure
# (`positions` is produced by the rollout cell).
print(calculate_distance(positions[0], positions[1]))


In [None]:

# Overlay the generated spectrum on the reference spectrum. Columns are indexed directly
# from the arrays so this cell does not depend on helper variables from earlier cells.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(env.ref_spectra[:, 0], env.ref_spectra[:, 1], label='Reference Spectrum')
ax.plot(spectra[:, 0], spectra[:, 1], label='Generated Spectrum', linestyle='--')
ax.set_xlabel(r'Wavenumber (cm$^{-1}$)')  # mathtext renders the superscript properly
ax.set_ylabel('Intensity')
ax.set_title('Reference vs. generated spectrum')
ax.legend()
plt.show()

# Show the generated atom positions within the simulation grid.
resolution = env.resolution
grid_dimensions = env.dimensions
print(grid_dimensions)
plot_3d_structure(positions, resolution, grid_dimensions)