In [1]:
%load_ext autoreload
%autoreload 2
from torch_geometric.data import Data, Batch
from ogb.utils.features import (atom_to_feature_vector,
 bond_to_feature_vector) 
from rdkit import Chem
from rdkit.Chem import rdDetermineBonds
import numpy as np
import torch
import os 
import sys

In [2]:
def get_mol_objects(filename):
    mol_list = []
    energy_list = []
   # energies = 
    with open(filename, "r") as f:
        lines = f.readlines()
        file_str  = "".join(lines)
        atom_num = lines[0]
    xyz_list = file_str.split(atom_num)[1:]
    for i in range(len(xyz_list)):
        x = Chem.MolFromXYZBlock(f'{atom_num.strip()}\n' + xyz_list[i])
        Chem.rdDetermineBonds.DetermineConnectivity(x)
        energy = float(xyz_list[i].split('\n')[0].strip())
        energy_list.append(energy)
        mol_list.append(x)
    return mol_list, energy_list

In [3]:
import contextlib
import os
import re
import subprocess
import warnings

def _get_energy(file):
    normal_termination = False
    with open(file) as f:
        for l in f:
            if "TOTAL ENERGY" in l:
                try:
                    energy = float(re.search(r"[+-]?(?:\d*\.)?\d+", l).group())
                except:
                    return np.nan
            if "normal termination of xtb" in l:
                normal_termination = True
    if normal_termination:
        return energy
    else:
        return np.nan

def run_gfn_xtb(
    filepath,
    filename,
    gfn_version="gfnff",
    opt=False,
    gfn_xtb_config: str = None,
    remove_scratch=True,
):
    """
    Runs GFN_XTB/FF given a directory and either a coord file or all coord files will be run

    :param filepath: Directory containing the coord file
    :param filename: if given, the specific coord file to run
    :param gfn_version: GFN_xtb version (default is 2)
    :param opt: optimization or single point (default is opt)
    :param gfn_xtb_config: additional xtb config (default is None)
    :param remove_scratch: remove xtb files
    :return:
    """
    xyz_file = os.path.join(filepath, filename)

    # optimization vs single point
    if opt:
        opt = "--opt"
    else:
        opt = ""

    # cd to filepath
    starting_dir = os.getcwd()
    os.chdir(filepath)

    file_name = str(xyz_file.split(".")[0])
    cmd = "xtb --{} {} {} {}".format(
        str(gfn_version), xyz_file, opt, str(gfn_xtb_config or "")
    )

    # run XTB
    with open(file_name + ".out", "w") as fd:
        subprocess.run(cmd, shell=True, stdout=fd, stderr=subprocess.STDOUT)

    # check XTB results
    if os.path.isfile(os.path.join(filepath, "NOT_CONVERGED")):
        # optimization not converged
        warnings.warn(
            "xtb --{} for {} is not converged, using last optimized step instead; proceed with caution".format(
                str(gfn_version), file_name
            )
        )

        # remove files
        if remove_scratch:
            os.remove(os.path.join(filepath, "NOT_CONVERGED"))
            os.remove(os.path.join(filepath, "xtblast.xyz"))
            os.remove(os.path.join(filepath, file_name + ".out"))
        energy = np.nan

    elif opt and not os.path.isfile(os.path.join(filepath, "xtbopt.xyz")):
        # other abnormal optimization convergence
        warnings.warn(
            "xtb --{} for {} abnormal termination, likely scf issues, using initial geometry instead; proceed with caution".format(
                str(gfn_version), file_name
            )
        )
        if remove_scratch:
            os.remove(os.path.join(filepath, file_name + ".out"))
        energy = np.nan

    else:
        # normal convergence
        # get energy
        energy = _get_energy(file_name + ".out")
        if remove_scratch:
            with contextlib.suppress(FileNotFoundError):
                os.remove(os.path.join(filepath, file_name + ".out"))
                os.remove(os.path.join(filepath, "gfnff_charges"))
                os.remove(os.path.join(filepath, "gfnff_adjacency"))
                os.remove(os.path.join(filepath, "gfnff_topo"))
                os.remove(os.path.join(filepath, "xtbopt.log"))
                os.remove(os.path.join(filepath, "xtbopt.xyz"))
                os.remove(os.path.join(filepath, "xtbtopo.mol"))
                os.remove(os.path.join(filepath, "wbo"))
                os.remove(os.path.join(filepath, "charges"))
                os.remove(os.path.join(filepath, "xtbrestart"))
    os.chdir(starting_dir)
    return energy

In [4]:
def get_mol_path(mol_idx, solvation=False):
    if solvation:
        return os.path.join(os.getcwd(), 'conformers', f'molecule_{mol_idx}', 'solvation')
    else:
        return os.path.join(os.getcwd(), 'conformers', f'molecule_{mol_idx}', 'vacuum')
get_mol_path(0)

'/home/radoslavralev/Documents/Thesis/gfn-diffusion/energy_sampling/notebooks/conformers/molecule_0/vacuum'

In [5]:
torch.ones(32, 69).reshape(-1, 23, 3).shape

torch.Size([32, 23, 3])

In [6]:
def xyz_mol2graph(xyz_mol):
    """
    Converts SMILES string to graph Data object
    :input: SMILES string (str)
    :return: graph object
    """

    mol = xyz_mol
    mol = Chem.AddHs(mol)
    #mol = Chem.RemoveHs(mol)
    pos = mol.GetConformer().GetPositions()
    #print(mol)
    # atoms
    atom_features_list = []
    for atom in mol.GetAtoms():
        atom_features_list.append([atom.GetAtomicNum()])
    x = np.array(atom_features_list, dtype = np.int64)

    # bonds
    num_bond_features = 3  # bond type, bond stereo, is_conjugated
    if len(mol.GetBonds()) > 0: # mol has bonds
        edges_list = []
        edge_features_list = []
        for bond in mol.GetBonds():
            i = bond.GetBeginAtomIdx()
            j = bond.GetEndAtomIdx()

            edge_feature = bond_to_feature_vector(bond)

            # add edges in both directions
            edges_list.append((i, j))
            edge_features_list.append(edge_feature)
            edges_list.append((j, i))
            edge_features_list.append(edge_feature)

        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = np.array(edges_list, dtype = np.int64).T

        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = np.array(edge_features_list, dtype = np.int64)

    else:   # mol has no bonds
        print('Mol has no bonds :()')
        edge_index = np.empty((2, 0), dtype = np.int64)
        edge_attr = np.empty((0, num_bond_features), dtype = np.int64)

    graph = dict()
    graph['edge_index'] = edge_index
    graph['edge_feat'] = edge_attr
    graph['node_feat'] = x
    graph['pos'] = pos
    graph['num_nodes'] = len(x)
    graph['num_bonds'] = len(mol.GetBonds())

    return graph 

In [7]:
def prep_input(graph, pos=None, device=None):
    datalist = []
    for xyz in pos:
        if pos is not None:
            graph['pos'] = xyz  
        #print('Number of atoms:', graph['node_feat'].shape[0])
        data = Data(
            atoms=torch.from_numpy(graph['node_feat'][:, 0]), 
            edge_index=torch.from_numpy(graph['edge_index']), 
            edge_attr=torch.from_numpy(graph['edge_feat']), 
            pos=graph['pos'],).to(device)
        data.validate(raise_on_error=True)
        datalist.append(data)
    #batch = Batch.from_data_list(datalist).to(device)
    return datalist

In [8]:
def extract_graphs(filename):
    mol_objects, mol_ens = get_mol_objects(filename)
    datalist = []
    for mol, en in zip(mol_objects, mol_ens):
        graph = xyz_mol2graph(mol)
        if graph['num_bonds'] == 0:
            continue
        data = Data(
            atoms=torch.from_numpy(graph['node_feat'][:, 0]), 
            edge_index=torch.from_numpy(graph['edge_index']), 
            edge_attr=torch.from_numpy(graph['edge_feat']), 
            pos=torch.from_numpy(graph['pos']),
            y=torch.tensor(en))
        data.validate(raise_on_error=True)
        datalist.append(data)
    return datalist

In [9]:
# calculate free energy for all pairs of conformers

def solvation_en(index=0):
    solvent_dir = os.path.join('../..', 'conformation_sampling', 'conformers', f'molecule_{index}', 'solvation', 'crest_conformers.xyz')
    vacuum_dir = os.path.join('../..', 'conformation_sampling', 'conformers', f'molecule_{index}', 'vacuum', 'crest_conformers.xyz')
    _, solv_en = get_mol_objects(solvent_dir)
    _, vac_en = get_mol_objects(vacuum_dir)
    #print(solv_en, vac_en)
    # avg_solv_en = np.mean(np.exp(np.array(solv_en)))
    # avg_vac_en = np.mean(np.exp(np.array(vac_en))) 
    # free_en = avg_vac_en - avg_solv_en
    l = min(len(solv_en), len(vac_en))
    free_en = -1/627.583*np.logaddexp.reduce(-(np.array(solv_en[:l])-np.array(vac_en[:l])) * 627.503) - np.log(l)
    return free_en
    
for i in range(600):
    free_en = solvation_en(i)
    if free_en:
        print(free_en)
    if i == 10:
        break
    
    

-5.070985140121123
-2.717450109829989
-0.7029199965511594
-1.9523127350697063
-0.700240312316485
-3.3073722504625738
-0.003604080517829243
-0.0033940972884697737
-1.7962577726122022
-2.7146895664696427
-3.6762836631174647


In [10]:
solvent_data = []
vacuum_data = []
for dir in os.listdir(os.path.join(os.getcwd(),'..', '..', 'conformation_sampling', 'conformers')):
    solvent_dir = os.path.join(os.getcwd(), '..', '..','conformation_sampling', 'conformers', dir, 'solvation', 'crest_conformers.xyz')
    vacuum_dir = os.path.join(os.getcwd(),'..', '..','conformation_sampling', 'conformers', dir, 'vacuum', 'crest_conformers.xyz')
    solvent_graphs = extract_graphs(solvent_dir)
    vacuum_graphs = extract_graphs(vacuum_dir)
    solvent_data.extend(solvent_graphs)
    vacuum_data.extend(vacuum_graphs)

In [11]:
atom_counts = {}
chemical_letters = {1: 'H', 6: 'C', 7: 'N', 8: 'O', 9: 'F', 15: 'P', 16: 'S', 17: 'Cl', 35: 'Br', 53: 'I'}

for v in solvent_data:
    atoms = v.atoms.tolist()
    for atom in atoms:
        atom_type = chemical_letters[atom]
        if atom_type in atom_counts:
            atom_counts[atom_type] += 1
        else:
            atom_counts[atom_type] = 1

atom_counts

{'C': 397716,
 'O': 81001,
 'H': 796336,
 'N': 8659,
 'Cl': 4919,
 'F': 1495,
 'I': 309,
 'Br': 888,
 'S': 23004,
 'P': 11046}

In [67]:
ANI2x.species_converter

SpeciesConverter()

In [12]:
# count the number of molecules in solvent_data that have only HCNOFSCl atoms
count = 0
for v in solvent_data:
    atoms = v.atoms.tolist()
    atom_types = [chemical_letters[atom] for atom in atoms]
    if all(atom in ['H', 'C', 'N', 'O', 'F', 'S', 'Cl'] for atom in atom_types):
        count += 1
count

36246

In [13]:
# select only molecules that have only HCNOFSCl atoms
solvent_data = [v for v in solvent_data if all(atom in ['H', 'C', 'N', 'O', 'F', 'S', 'Cl'] for atom in [chemical_letters[atom] for atom in v.atoms.tolist()])]

In [55]:
vacuum_data = [v for v in vacuum_data if all(atom in ['H', 'C', 'N', 'O', 'F', 'S', 'Cl'] for atom in [chemical_letters[atom] for atom in v.atoms.tolist()])]

In [56]:
import torch
import torchani
import os
import math
import torch.utils.tensorboard
import tqdm
import pickle

# helper function to convert energy unit from Hartree to kcal/mol
from torchani.units import hartree2kcalmol

ANI2x = torchani.models.ANI2x(periodic_table_index=True)
# device to run the training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def iter_data(data_list):
    for data in data_list:
        species = data.atoms
        coordinates = data.pos
        energy = data.y
        species = [chemical_letters[atom.item()] for atom in species]
        ret = {'species': species, 'coordinates': coordinates, 'energies': energy}
        yield ret

def to_torchani_iterator(data_list):
    return torchani.data.TransformableIterable(torchani.data.IterableAdapter(lambda: iter_data(data_list)))


/home/radoslavralev/miniconda3/envs/gfn-energy/lib/python3.11/site-packages/torchani/resources/


In [57]:
batch_size = 64

species_order = ANI2x.species
num_species = len(species_order)
aev_computer = ANI2x.aev_computer#torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = ANI2x.energy_shifter


In [58]:
pickled_dataset_path = 'dataset.pkl'


# if os.path.isfile(pickled_dataset_path):
#     print(f'Unpickling preprocessed dataset found in {pickled_dataset_path}')
#     with open(pickled_dataset_path, 'rb') as f:
#         dataset = pickle.load(f)
#     training = dataset['training'].collate(batch_size).cache()
#     validation = dataset['validation'].collate(batch_size).cache()
#     energy_shifter.self_energies = dataset['self_energies'].to(device)
# else:
training, validation = to_torchani_iterator(vacuum_data)\
                                    .subtract_self_energies(energy_shifter, species_order)\
                                    .species_to_indices(species_order)\
                                    .shuffle()\
                                    .split(0.8, None)
with open(pickled_dataset_path, 'wb') as f:
    pickle.dump({'training': training,
                    'validation': validation,
                    'self_energies': energy_shifter.self_energies.cpu()}, f)
training = training.collate(batch_size).cache()
validation = validation.collate(batch_size).cache()
print('Self atomic energies: ', energy_shifter.self_energies)

Self atomic energies:  tensor([-0.5111, -2.1399, -2.8945, -4.0611, -3.2814, -4.7150, -4.5564],
       dtype=torch.float64)


In [59]:
nn = ANI2x.neural_networks.to(device)
print(nn)

Ensemble(
  (0-7): 8 x ANIModel(
    (H): Sequential(
      (0): Linear(in_features=1008, out_features=256, bias=True)
      (1): CELU(alpha=0.1)
      (2): Linear(in_features=256, out_features=192, bias=True)
      (3): CELU(alpha=0.1)
      (4): Linear(in_features=192, out_features=160, bias=True)
      (5): CELU(alpha=0.1)
      (6): Linear(in_features=160, out_features=1, bias=True)
    )
    (C): Sequential(
      (0): Linear(in_features=1008, out_features=224, bias=True)
      (1): CELU(alpha=0.1)
      (2): Linear(in_features=224, out_features=192, bias=True)
      (3): CELU(alpha=0.1)
      (4): Linear(in_features=192, out_features=160, bias=True)
      (5): CELU(alpha=0.1)
      (6): Linear(in_features=160, out_features=1, bias=True)
    )
    (N): Sequential(
      (0): Linear(in_features=1008, out_features=192, bias=True)
      (1): CELU(alpha=0.1)
      (2): Linear(in_features=192, out_features=160, bias=True)
      (3): CELU(alpha=0.1)
      (4): Linear(in_features=160, ou

In [60]:
model = torchani.nn.Sequential(aev_computer, nn).to(device)

In [61]:
AdamW = torch.optim.AdamW(model.parameters(), lr=1e-3)
AdamW_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(AdamW, factor=0.5, patience=5, threshold=0)

In [62]:
latest_checkpoint = 'latest.pt'

In [63]:
if os.path.isfile(latest_checkpoint):
    checkpoint = torch.load(latest_checkpoint)
    nn.load_state_dict(checkpoint['nn'])
    AdamW.load_state_dict(checkpoint['AdamW'])
    AdamW_scheduler.load_state_dict(checkpoint['AdamW_scheduler'])

In [64]:
def validate():
    # run validation
    mse_sum = torch.nn.MSELoss(reduction='sum')
    total_mse = 0.0
    count = 0
    model.train(False)
    with torch.no_grad():
        for properties in validation:
            species = properties['species'].to(device)
            coordinates = properties['coordinates'].to(device).float()
            true_energies = properties['energies'].to(device).float()
            _, predicted_energies = model((species, coordinates))
            total_mse += mse_sum(predicted_energies, true_energies).item()
            count += predicted_energies.shape[0]
    model.train(True)
    return hartree2kcalmol(math.sqrt(total_mse / count))


In [65]:
tensorboard = torch.utils.tensorboard.SummaryWriter()


In [66]:
mse = torch.nn.MSELoss(reduction='none')

print("training starting from epoch", AdamW_scheduler.last_epoch + 1)
max_epochs = 150
early_stopping_learning_rate = 5E-6
best_model_checkpoint = 'best.pt'

for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
    rmse = validate()
    print('RMSE:', rmse, 'at epoch', AdamW_scheduler.last_epoch + 1)

    learning_rate = AdamW.param_groups[0]['lr']

    if learning_rate < early_stopping_learning_rate:
        break

    # checkpoint
    if AdamW_scheduler.is_better(rmse, AdamW_scheduler.best):
        torch.save(nn.state_dict(), best_model_checkpoint)

    AdamW_scheduler.step(rmse)

    tensorboard.add_scalar('validation_rmse', rmse, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
    tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)

    for i, properties in tqdm.tqdm(
        enumerate(training),
        total=len(training),
        desc="epoch {}".format(AdamW_scheduler.last_epoch)
    ):
        species = properties['species'].to(device)
        coordinates = properties['coordinates'].to(device).float()
        true_energies = properties['energies'].to(device).float()
        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
        _, predicted_energies = model((species, coordinates))

        loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()

        AdamW.zero_grad()
        loss.backward()
        AdamW.step()

        # write current batch loss to TensorBoard
        tensorboard.add_scalar('batch_loss', loss, AdamW_scheduler.last_epoch * len(training) + i)

    torch.save({
        'nn': nn.state_dict(),
        'AdamW': AdamW.state_dict(),
        'AdamW_scheduler': AdamW_scheduler.state_dict(),
    }, latest_checkpoint)

training starting from epoch 1
RMSE: 101.01377577878378 at epoch 1


epoch 1: 100%|██████████| 442/442 [00:15<00:00, 29.18it/s]


RMSE: 3.89293588957423 at epoch 2


epoch 2: 100%|██████████| 442/442 [00:14<00:00, 29.84it/s]


RMSE: 6.997254509163164 at epoch 3


epoch 3: 100%|██████████| 442/442 [00:14<00:00, 29.65it/s]


RMSE: 4.954335097643003 at epoch 4


epoch 4: 100%|██████████| 442/442 [00:15<00:00, 29.45it/s]


RMSE: 2.172582992088832 at epoch 5


epoch 5: 100%|██████████| 442/442 [00:15<00:00, 29.07it/s]


RMSE: 2.845150646282102 at epoch 6


epoch 6: 100%|██████████| 442/442 [00:15<00:00, 29.41it/s]


RMSE: 5.100510738382582 at epoch 7


epoch 7: 100%|██████████| 442/442 [00:14<00:00, 29.63it/s]


RMSE: 1.1739863846087772 at epoch 8


epoch 8: 100%|██████████| 442/442 [00:15<00:00, 29.37it/s]


RMSE: 3.113753087515694 at epoch 9


epoch 9: 100%|██████████| 442/442 [00:14<00:00, 29.55it/s]


RMSE: 1.116588343382995 at epoch 10


epoch 10: 100%|██████████| 442/442 [00:15<00:00, 29.26it/s]


RMSE: 5.265427607282203 at epoch 11


epoch 11: 100%|██████████| 442/442 [00:15<00:00, 29.36it/s]


RMSE: 2.1433374008347177 at epoch 12


epoch 12: 100%|██████████| 442/442 [00:15<00:00, 28.69it/s]


RMSE: 1.4139769777046318 at epoch 13


epoch 13: 100%|██████████| 442/442 [00:15<00:00, 29.04it/s]


RMSE: 29.22308247343612 at epoch 14


epoch 14: 100%|██████████| 442/442 [00:15<00:00, 29.15it/s]


RMSE: 2.8811220807479345 at epoch 15


epoch 15: 100%|██████████| 442/442 [00:15<00:00, 28.99it/s]


RMSE: 1.5246435340700901 at epoch 16


epoch 16: 100%|██████████| 442/442 [00:15<00:00, 28.81it/s]


RMSE: 1.0642020430142651 at epoch 17


epoch 17: 100%|██████████| 442/442 [00:15<00:00, 29.11it/s]


RMSE: 1.0494452159975038 at epoch 18


epoch 18: 100%|██████████| 442/442 [00:15<00:00, 28.38it/s]


RMSE: 1.013750895121347 at epoch 19


epoch 19: 100%|██████████| 442/442 [00:15<00:00, 28.57it/s]


RMSE: 1.3260028814415854 at epoch 20


epoch 20: 100%|██████████| 442/442 [00:15<00:00, 27.84it/s]


RMSE: 1.6707552101763181 at epoch 21


epoch 21: 100%|██████████| 442/442 [00:15<00:00, 28.55it/s]


RMSE: 1.4106482207034368 at epoch 22


epoch 22: 100%|██████████| 442/442 [00:15<00:00, 27.83it/s]


RMSE: 1.404929433651668 at epoch 23


epoch 23: 100%|██████████| 442/442 [00:16<00:00, 26.94it/s]


RMSE: 1.8751688753375695 at epoch 24


epoch 24: 100%|██████████| 442/442 [00:16<00:00, 26.97it/s]


RMSE: 2.1445795280521978 at epoch 25


epoch 25: 100%|██████████| 442/442 [00:15<00:00, 28.30it/s]


RMSE: 0.8904583497509705 at epoch 26


epoch 26: 100%|██████████| 442/442 [00:15<00:00, 29.09it/s]


RMSE: 1.0895404936865016 at epoch 27


epoch 27: 100%|██████████| 442/442 [00:15<00:00, 29.17it/s]


RMSE: 0.8488081070669322 at epoch 28


epoch 28: 100%|██████████| 442/442 [00:15<00:00, 28.20it/s]


RMSE: 1.0656490513258547 at epoch 29


epoch 29: 100%|██████████| 442/442 [00:16<00:00, 26.82it/s]


RMSE: 2.4811407354081174 at epoch 30


epoch 30: 100%|██████████| 442/442 [00:15<00:00, 27.99it/s]


RMSE: 2.496989961682674 at epoch 31


epoch 31: 100%|██████████| 442/442 [00:15<00:00, 28.73it/s]


RMSE: 2.078021766110859 at epoch 32


epoch 32: 100%|██████████| 442/442 [00:15<00:00, 29.05it/s]


RMSE: 1.698058521819944 at epoch 33


epoch 33: 100%|██████████| 442/442 [00:16<00:00, 26.22it/s]


RMSE: 1.2743277512645481 at epoch 34


epoch 34: 100%|██████████| 442/442 [00:16<00:00, 26.97it/s]


RMSE: 0.7011514032865791 at epoch 35


epoch 35: 100%|██████████| 442/442 [00:16<00:00, 26.64it/s]


RMSE: 0.7703878685920543 at epoch 36


epoch 36: 100%|██████████| 442/442 [00:16<00:00, 26.14it/s]


RMSE: 1.2451485278607215 at epoch 37


epoch 37: 100%|██████████| 442/442 [00:16<00:00, 26.87it/s]


RMSE: 0.9332580575949206 at epoch 38


epoch 38: 100%|██████████| 442/442 [00:16<00:00, 26.97it/s]


RMSE: 0.9395814944699116 at epoch 39


epoch 39: 100%|██████████| 442/442 [00:16<00:00, 26.78it/s]


RMSE: 1.3913587602143522 at epoch 40


epoch 40: 100%|██████████| 442/442 [00:16<00:00, 26.54it/s]


RMSE: 0.7083518344090434 at epoch 41


epoch 41: 100%|██████████| 442/442 [00:16<00:00, 26.84it/s]


RMSE: 0.6274154011066335 at epoch 42


epoch 42: 100%|██████████| 442/442 [00:16<00:00, 27.19it/s]


RMSE: 0.5841222733766287 at epoch 43


epoch 43: 100%|██████████| 442/442 [00:16<00:00, 27.17it/s]


RMSE: 0.7378869674813192 at epoch 44


epoch 44: 100%|██████████| 442/442 [00:16<00:00, 26.75it/s]


RMSE: 1.0546523373016834 at epoch 45


epoch 45: 100%|██████████| 442/442 [00:16<00:00, 26.55it/s]


RMSE: 0.8809915287829629 at epoch 46


epoch 46: 100%|██████████| 442/442 [00:16<00:00, 26.84it/s]


RMSE: 1.0453446327993081 at epoch 47


epoch 47: 100%|██████████| 442/442 [00:16<00:00, 26.87it/s]


RMSE: 0.7916499062958458 at epoch 48


epoch 48: 100%|██████████| 442/442 [00:16<00:00, 26.30it/s]


RMSE: 0.9599279053568147 at epoch 49


epoch 49: 100%|██████████| 442/442 [00:16<00:00, 27.24it/s]


RMSE: 0.5757114034214689 at epoch 50


epoch 50: 100%|██████████| 442/442 [00:17<00:00, 25.73it/s]


RMSE: 0.5719797677594074 at epoch 51


epoch 51: 100%|██████████| 442/442 [00:16<00:00, 27.12it/s]


RMSE: 0.5620439368856072 at epoch 52


epoch 52: 100%|██████████| 442/442 [00:15<00:00, 27.67it/s]


RMSE: 0.5581026082883808 at epoch 53


epoch 53: 100%|██████████| 442/442 [00:16<00:00, 27.19it/s]


RMSE: 0.5549664756010663 at epoch 54


epoch 54: 100%|██████████| 442/442 [00:16<00:00, 26.90it/s]


RMSE: 0.548026777710352 at epoch 55


epoch 55: 100%|██████████| 442/442 [00:15<00:00, 27.72it/s]


RMSE: 0.5498600395867311 at epoch 56


epoch 56: 100%|██████████| 442/442 [00:15<00:00, 28.13it/s]


RMSE: 0.5466613830135534 at epoch 57


epoch 57: 100%|██████████| 442/442 [00:15<00:00, 28.69it/s]


RMSE: 0.545958108790662 at epoch 58


epoch 58: 100%|██████████| 442/442 [00:14<00:00, 29.48it/s]


RMSE: 0.5440918334810315 at epoch 59


epoch 59: 100%|██████████| 442/442 [00:15<00:00, 27.73it/s]


RMSE: 0.5428712246749068 at epoch 60


epoch 60: 100%|██████████| 442/442 [00:16<00:00, 27.32it/s]


RMSE: 0.5416378287632302 at epoch 61


epoch 61: 100%|██████████| 442/442 [00:15<00:00, 27.92it/s]


RMSE: 0.5405016685108109 at epoch 62


epoch 62: 100%|██████████| 442/442 [00:16<00:00, 27.51it/s]


RMSE: 0.5394546546894233 at epoch 63


epoch 63: 100%|██████████| 442/442 [00:15<00:00, 28.70it/s]


RMSE: 0.5384592219859087 at epoch 64


epoch 64:  88%|████████▊ | 390/442 [00:13<00:01, 28.59it/s]


KeyboardInterrupt: 