# Imports

# %%
import os
import csv
import sys 
import torch


from openmm import unit 
from tqdm import tqdm


sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))
from energies import MoleculeFromSMILES_XTB
from utils import logmeanexp

# Constants

# %%
input_file = '../database.txt'
# Number of torsion-angle configurations passed to each energy() call.
batch_size = 32
# Temperature in kelvin; defined first so every derived constant uses it.
T = 298.15
# Boltzmann constant in hartree per kelvin, so beta below is in 1/hartree.
kB = unit.BOLTZMANN_CONSTANT_kB.value_in_unit(unit.hartree/unit.kelvin)
# Derived from T (instead of repeating the literal) so the two cannot drift apart.
beta = 1/(kB*T)
# kcal/mol per hartree: CODATA 2018 Hartree energy * Avogadro constant / 4184 J/kcal.
hartree_to_kcal = 627.5094740631

# %%
# Map each SMILES string to its (experimental value, experimental uncertainty)
# pair, both kept as the strings read from the semicolon-delimited database.
molecules_smiles = {}
# newline='' is what the csv module expects for file objects it reads.
with open(input_file, 'r', encoding='utf-8', newline='') as infile:
    reader = csv.reader(infile, delimiter=';')
    for row in reader:
        # Blank lines arrive as empty rows; '#' lines are header/comments.
        if not row or row[0].startswith('#'):
            continue
        # Strip padding around the ';' delimiter so it does not end up inside
        # the SMILES key (assumes fields may be written as '; value; ...').
        smiles = row[1].strip()
        experimental_val = row[3].strip()
        experimental_uncertainty = row[4].strip()
        molecules_smiles[smiles] = experimental_val, experimental_uncertainty

# %%
# Filled by the scan loop: SMILES -> (computed FED, experimental value, experimental uncertainty).
results_dict = dict()

# %%
# Maximum number of leading torsions scanned on a grid; any further torsions are
# held at their default values. Set once, outside the loop, so the results
# filename (which uses `n`) reflects the scan setting rather than the torsion
# count of whichever molecule happened to be processed last.
n = 2


def _torsion_grid(default_values, n_scan, n_points=72):
    """Return all grid combinations for the first ``n_scan`` torsions.

    Each scanned torsion takes ``n_points`` evenly spaced angles in [0, 360)
    (5-degree steps for the default 72); 360 is dropped because it duplicates 0.
    The remaining torsions are appended as constant columns holding
    ``default_values[n_scan:]`` (presumably also in degrees -- confirm).

    Returns a tensor with ``n_points ** n_scan`` rows, scanned columns first.
    """
    angles = torch.linspace(0, 360, n_points + 1)[:-1]
    # 'ij' indexing + reshape enumerates combinations with the last torsion fastest.
    mesh = torch.meshgrid(*([angles] * n_scan), indexing='ij')
    scanned = torch.stack(mesh, dim=-1).reshape(-1, n_scan)
    # Match the grid dtype so the concatenation never mixes dtypes.
    fixed = torch.as_tensor(default_values[n_scan:], dtype=scanned.dtype)
    # unsqueeze + expand broadcasts the tail to every row (also for an empty tail).
    fixed = fixed.unsqueeze(0).expand(scanned.shape[0], -1)
    return torch.cat((scanned, fixed), dim=1)


def _batched_energies(model, configs, chunk_size):
    """Evaluate ``model.energy`` over ``configs`` in chunks of ``chunk_size`` rows.

    The trailing partial chunk is evaluated too. Results are stored as float64
    so absolute energies are not rounded before the Boltzmann average.
    """
    out = torch.zeros(configs.shape[0], dtype=torch.float64)
    for start in tqdm(range(0, configs.shape[0], chunk_size)):
        stop = start + chunk_size  # slicing clamps at the end of the tensor
        out[start:stop] = model.energy(configs[start:stop])
    return out


def _free_energy(energies):
    """Return -kT * ln(mean(exp(-beta * E))) over the grid, in kcal/mol.

    Assumes ``energies`` are in hartree, matching the units of ``beta``.
    """
    return -logmeanexp(-beta * energies) * (hartree_to_kcal * kB * T)


for smiles, (exp_val, exp_unc) in molecules_smiles.items():
    ground_truth = float(exp_val)
    uncertainty = float(exp_unc)

    v_energy = MoleculeFromSMILES_XTB(smiles, temp=T, solvate=False)
    default_ta_vals = v_energy.rd_conf.get_freely_rotatable_tas_values()
    torsions = v_energy.tas
    if len(torsions) == 0:
        print('No torsions found for {}'.format(smiles))
        continue
    # Build the solvated model only for molecules that are actually scanned.
    s_energy = MoleculeFromSMILES_XTB(smiles, temp=T, solvate=True)

    n_scan = min(n, len(torsions))
    configs = _torsion_grid(default_ta_vals, n_scan)

    v_free_energy = _free_energy(_batched_energies(v_energy, configs, batch_size))
    s_free_energy = _free_energy(_batched_energies(s_energy, configs, batch_size))

    # Convert to a plain float so prints and the CSV show a number rather than
    # "tensor(...)" (assumes logmeanexp reduces to a single value).
    fed = float(s_free_energy - v_free_energy)
    results_dict[smiles] = fed, ground_truth, uncertainty
    print('Error for {}: {}'.format(smiles, fed - ground_truth))


# %%
# Report how many molecules produced a result; the others were skipped because
# no torsions were found. Printed explicitly because a bare expression is only
# displayed when run as an interactive cell.
print('Computed free energies for {} of {} molecules'.format(len(results_dict), len(molecules_smiles)))

# %%
# Write one CSV row per molecule. Only fed_Z is computed here; fed_Z_learned
# and fed_Z_lb are written as "0 ± 0" placeholders.
import datetime

os.makedirs('../fed_results', exist_ok=True)
# Stamp the run with the current time (same day-month-year hour-minute format
# the file previously hard-coded as a fixed date).
timestamp = datetime.datetime.now().strftime('%d-%m-%Y %H-%M')
# NOTE(review): `n` is a notebook global set by the scan cell -- confirm it holds
# the scan cap, not the torsion count of the last molecule processed.
with open(f'../fed_results/rotation_fed_n={n}.csv', 'w', encoding='utf-8', newline='') as f:
    writer = csv.writer(f, lineterminator='\n')
    writer.writerow(['SMILES', 'experimental_val', 'fed_Z_learned', 'fed_Z', 'fed_Z_lb', 'timestamp'])
    for smiles, (fed, ground_truth, uncertainty) in results_dict.items():
        # float() turns a 0-d tensor into a plain number instead of "tensor(...)".
        writer.writerow([
            smiles,
            f'{ground_truth} ± {uncertainty}',
            '0 ± 0',
            f'{float(fed)} ± 0',
            '0 ± 0',
            timestamp,
        ])