In [1]:
from gemmi import cif
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import gemmi

In [2]:
# Load the atom table and the bond table produced by an earlier extraction step.
# NOTE(review): pandas' default NA parsing turns literal strings such as "NA" into NaN.
# If an atom_id / type_symbol can legitimately be "NA" (e.g. sodium), those rows would look like
# missing data and their molecules would be removed below -- TODO confirm against the CSV;
# passing keep_default_na=False (plus explicit na_values) would avoid it.
atom_df = pd.read_csv('../data/atom_df_extended.csv')
bond_df = pd.read_csv('../data/bond_df.csv')

In [3]:
print(atom_df.shape)
print(bond_df.shape)

(2041187, 8)
(2121450, 4)


In [4]:
atom_df.head()

Unnamed: 0,comp_id,atom_id,type_symbol,model_Cartn_x,model_Cartn_y,model_Cartn_z,is_hydrogen,bonded_hydrogens
0,0,C,C,32.88,-0.09,51.314,False,0.0
1,0,O,O,32.16,0.18,50.105,False,0.0
2,0,OA,O,34.147,-0.94,51.249,False,0.0
3,0,CB,C,33.872,-2.227,50.459,False,3.0
4,0,OXT,O,32.419,0.429,52.564,False,1.0


In [5]:
bond_df.head()

Unnamed: 0,comp_id,atom_id_1,atom_id_2,value_order
0,0,C,OXT,SING
1,0,O,C,DOUB
2,0,OA,C,SING
3,0,CB,OA,SING
4,0,CB,HB,SING


# Data preprocessing: Remove NAN rows / molecules where NAN occurs at least once

### remove nan rows from atom_df and bond_df

In [6]:
# Initial row counts
initial_atom_row_count = len(atom_df)
initial_bond_row_count = len(bond_df)

# Step 1: Rows without a comp_id cannot be attributed to any molecule, so drop them first.
# Doing this before collecting ids keeps NaN itself out of `nan_comp_ids`
# (NaN != NaN, so NaN entries are not reliably de-duplicated in a set and would inflate the count).
atom_df = atom_df.dropna(subset=['comp_id'])
bond_df = bond_df.dropna(subset=['comp_id'])

# Step 2: Collect the comp_id of every molecule that has a NaN in any remaining column
nan_atom_rows = atom_df[atom_df.isna().any(axis=1)]
nan_bond_rows = bond_df[bond_df.isna().any(axis=1)]
nan_comp_ids = set(nan_atom_rows['comp_id']) | set(nan_bond_rows['comp_id'])

# Step 3: Remove the affected molecules entirely (all their atoms and all their bonds),
# not just the individual rows that contain the NaN
comp_ids_to_delete = set(nan_comp_ids)
atom_df = atom_df[~atom_df['comp_id'].isin(comp_ids_to_delete)]
bond_df = bond_df[~bond_df['comp_id'].isin(comp_ids_to_delete)]

# Final row counts
final_atom_row_count = len(atom_df)
final_bond_row_count = len(bond_df)

# Step 4: Print results
print(f"Identified {len(nan_comp_ids)} unique comp_ids with NaN values.")
print(f"Deleted {initial_atom_row_count - final_atom_row_count} rows from atom_df.")
print(f"Deleted {initial_bond_row_count - final_bond_row_count} rows from bond_df.")

Identified 205 unique comp_ids with NaN values.
Deleted 15625 rows from atom_df.
Deleted 16588 rows from bond_df.


### remove "?" values from atom_df and bond_df

In [7]:
def clean_dataframes(atom_df, bond_df):
    """Drop every molecule whose atom rows contain the placeholder "?".

    Only atom_df is scanned for "?"; the comp_ids found there are then removed
    from both frames.

    Returns:
        (atom_df_cleaned, bond_df_cleaned, comp_ids_to_delete), where
        comp_ids_to_delete is the array of unique comp_ids that were removed.
    """
    # True for each atom row that has "?" in at least one column
    placeholder_mask = atom_df.isin(["?"]).any(axis=1)
    comp_ids_to_delete = atom_df.loc[placeholder_mask, 'comp_id'].unique()

    def _without_flagged(frame):
        # Keep only the rows whose molecule was not flagged
        flagged = frame['comp_id'].isin(comp_ids_to_delete)
        return frame[~flagged]

    return _without_flagged(atom_df), _without_flagged(bond_df), comp_ids_to_delete

# Remember the sizes before cleaning so the number of dropped rows can be reported
initial_atom_row_count, initial_bond_row_count = len(atom_df), len(bond_df)

# Remove every molecule that has a "?" placeholder in its atom rows
atom_df, bond_df, comp_ids_to_delete = clean_dataframes(atom_df, bond_df)

# Sizes after cleaning
final_atom_row_count, final_bond_row_count = len(atom_df), len(bond_df)

# Report what was removed
print(f"Identified {len(comp_ids_to_delete)} unique comp_ids with '?' values in atom_df.")
print(f"Deleted {initial_atom_row_count - final_atom_row_count} rows from atom_df.")
print(f"Deleted {initial_bond_row_count - final_bond_row_count} rows from bond_df.")

Identified 1625 unique comp_ids with '?' values in atom_df.
Deleted 86867 rows from atom_df.
Deleted 89122 rows from bond_df.


# Data preprocessing: Modify columns

In [8]:
# Add an `atomic_number` column derived from the element symbol in `type_symbol`
# (e.g. carbon, symbol "C", has atomic number 6).
# The gemmi lookup is done once per distinct symbol instead of once per row.
symbol_to_atomic_number = {
    symbol: gemmi.Element(symbol).atomic_number
    for symbol in atom_df['type_symbol'].unique()
}
atom_df['atomic_number'] = atom_df['type_symbol'].map(symbol_to_atomic_number)

In [9]:
# Map the textual bond order to its integer value
value_order_mapping = {
    'SING': 1,
    'DOUB': 2,
    'TRIP': 3
}

# Translate the 'value_order' column; values that are not in the mapping are kept unchanged.
# `map` is used instead of `replace` so the result dtype is inferred directly rather than through
# `replace`'s implicit downcasting (deprecated in recent pandas -- presumably the warning that was
# captured under this cell).
bond_df['value_order'] = bond_df['value_order'].map(
    lambda order: value_order_mapping.get(order, order)
)

  bond_df['value_order'] = bond_df['value_order'].replace(value_order_mapping)


# Restrict dataset to a certain amount of bonded hydrogens (and if wanted to a certain central atom)

In [10]:
# Number of hydrogen atoms to train a model for and the central atom type
num_hydrogen = 1
central_atom = 'C'

### Only for a given central atom type and a given number of bonded hydrogens

In [11]:
atom_df_filtered = atom_df.loc[(atom_df['type_symbol'] == central_atom) & (atom_df['bonded_hydrogens'] == num_hydrogen)]
atom_df_filtered.head()

Unnamed: 0,comp_id,atom_id,type_symbol,model_Cartn_x,model_Cartn_y,model_Cartn_z,is_hydrogen,bonded_hydrogens,atomic_number
10,1,C02,C,25.498,13.476,26.66,False,1.0,6
14,1,C06,C,27.3,11.861,26.289,False,1.0,6
31,1,C16,C,27.938,12.754,21.659,False,1.0,6
35,1,C18,C,30.921,13.77,23.757,False,1.0,6
40,1,C23,C,35.552,13.034,23.155,False,1.0,6


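### Alternative: all central atom types with a given number of bonded hydrogens (this cell overwrites `atom_df_filtered` from the previous cell, so run only one of the two filter cells)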

In [8]:
atom_df_filtered = atom_df.loc[(atom_df['bonded_hydrogens'] == num_hydrogen)]
atom_df_filtered.head()

Unnamed: 0,comp_id,atom_id,type_symbol,model_Cartn_x,model_Cartn_y,model_Cartn_z,is_hydrogen,bonded_hydrogens,atomic_number
4,0,OXT,O,32.419,0.429,52.564,False,1.0,8
10,1,C02,C,25.498,13.476,26.66,False,1.0,6
14,1,C06,C,27.3,11.861,26.289,False,1.0,6
31,1,C16,C,27.938,12.754,21.659,False,1.0,6
35,1,C18,C,30.921,13.77,23.757,False,1.0,6


# Start making a big table where atoms and bonds are merged

In [12]:
# A bond row names its two atoms in atom_id_1 and atom_id_2, and a selected central atom can
# appear in either column, so the bonds are collected with two separate merges.

# Merge on comp_id and atom_id_1: bonds in which the central atom is listed first
merged_1 = atom_df_filtered.merge(bond_df, left_on=['comp_id', 'atom_id'], right_on=['comp_id', 'atom_id_1'], suffixes=('', '_bond_1'))

# Merge on comp_id and atom_id_2: bonds in which the central atom is listed second
merged_2 = atom_df_filtered.merge(bond_df, left_on=['comp_id', 'atom_id'], right_on=['comp_id', 'atom_id_2'], suffixes=('', '_bond_2'))

In [13]:
# Concatenate the two merged DataFrames
result_df = pd.concat([merged_1, merged_2], ignore_index=True)
result_df.head()

Unnamed: 0,comp_id,atom_id,type_symbol,model_Cartn_x,model_Cartn_y,model_Cartn_z,is_hydrogen,bonded_hydrogens,atomic_number,atom_id_1,atom_id_2,value_order
0,1,C02,C,25.498,13.476,26.66,False,1.0,6,C02,C03,1
1,1,C02,C,25.498,13.476,26.66,False,1.0,6,C02,H021,1
2,1,C06,C,27.3,11.861,26.289,False,1.0,6,C06,H061,1
3,1,C16,C,27.938,12.754,21.659,False,1.0,6,C16,C17,1
4,1,C16,C,27.938,12.754,21.659,False,1.0,6,C16,H161,1


In [14]:
# Sort the result_df by comp_id first and then by atom_id
result_df_sorted = result_df.sort_values(by=['comp_id', 'atom_id']).reset_index(drop=True)
result_df_sorted.head()

Unnamed: 0,comp_id,atom_id,type_symbol,model_Cartn_x,model_Cartn_y,model_Cartn_z,is_hydrogen,bonded_hydrogens,atomic_number,atom_id_1,atom_id_2,value_order
0,1,C02,C,25.498,13.476,26.66,False,1.0,6,C02,C03,1
1,1,C02,C,25.498,13.476,26.66,False,1.0,6,C02,H021,1
2,1,C02,C,25.498,13.476,26.66,False,1.0,6,C01,C02,2
3,1,C06,C,27.3,11.861,26.289,False,1.0,6,C06,H061,1
4,1,C06,C,27.3,11.861,26.289,False,1.0,6,C01,C06,1


In [15]:
# Function to determine the correct join column and perform the join
def join_atom_coordinates(atom_df_extended, atom_df):
    """Attach the bonded partner's atomic number and coordinates to every bond row.

    Args:
        atom_df_extended: one row per (central atom, bond); must contain
            'comp_id', 'atom_id', 'atom_id_1', 'atom_id_2' plus the central
            atom's own columns that are selected for the output below.
        atom_df: atom table with 'comp_id', 'atom_id', 'atomic_number' and the
            three 'model_Cartn_*' coordinate columns.

    Returns:
        DataFrame with the central-atom columns, the bond columns, the partner
        coordinates ('model_Cartn_x2/y2/z2') and 'atomic_number_joined'.
        The join is an inner merge, so bond rows whose partner atom is not
        present in atom_df are dropped.
    """
    # Work on a copy so the caller's frame does not gain the helper column
    df = atom_df_extended.copy()

    # The partner is whichever end of the bond is not the central atom.
    # Vectorised with np.where: avoids a Python call per row and also works on an empty frame.
    df['join_on_atom_id'] = np.where(
        df['atom_id'] == df['atom_id_1'], df['atom_id_2'], df['atom_id_1']
    )

    # Look the partner atom up in the atom table; overlapping column names coming
    # from the partner side get the '_joined' suffix
    merged_df = df.merge(
        atom_df[['comp_id', 'atom_id', 'atomic_number', 'model_Cartn_x', 'model_Cartn_y', 'model_Cartn_z']],
        left_on=['comp_id', 'join_on_atom_id'],
        right_on=['comp_id', 'atom_id'],
        suffixes=('', '_joined')
    )

    # Give the partner coordinates their final names
    merged_df = merged_df.rename(
        columns={
            'model_Cartn_x_joined': 'model_Cartn_x2',
            'model_Cartn_y_joined': 'model_Cartn_y2',
            'model_Cartn_z_joined': 'model_Cartn_z2'
        }
    )

    # Keep only the necessary columns (drops the helper and the duplicated partner atom_id)
    output_columns = [
        'comp_id', 'atom_id', 'type_symbol', 'model_Cartn_x', 'model_Cartn_y', 'model_Cartn_z', 'is_hydrogen',
        'bonded_hydrogens', 'atomic_number', 'atom_id_1', 'atom_id_2', 'value_order',
        'model_Cartn_x2', 'model_Cartn_y2', 'model_Cartn_z2', 'atomic_number_joined'
    ]
    return merged_df[output_columns]

In [16]:
# Apply the function and print the result
result_df_final = join_atom_coordinates(result_df_sorted, atom_df)
result_df_final.head()

Unnamed: 0,comp_id,atom_id,type_symbol,model_Cartn_x,model_Cartn_y,model_Cartn_z,is_hydrogen,bonded_hydrogens,atomic_number,atom_id_1,atom_id_2,value_order,model_Cartn_x2,model_Cartn_y2,model_Cartn_z2,atomic_number_joined
0,1,C02,C,25.498,13.476,26.66,False,1.0,6,C02,C03,1,26.077,13.812,27.91,6
1,1,C02,C,25.498,13.476,26.66,False,1.0,6,C02,H021,1,24.573,13.972,26.319,1
2,1,C02,C,25.498,13.476,26.66,False,1.0,6,C01,C02,2,26.108,12.501,25.848,6
3,1,C06,C,27.3,11.861,26.289,False,1.0,6,C06,H061,1,27.782,11.099,25.653,1
4,1,C06,C,27.3,11.861,26.289,False,1.0,6,C01,C06,1,26.108,12.501,25.848,6


In [17]:
result_df_final.to_csv('../data/centralatom-C_numberhydrogens-1/big-table.csv')

In [2]:
# Load the DataFrame from the CSV file
file_path = '../data/centralatom-C_numberhydrogens-1/big-table.csv'
result_df_final = pd.read_csv(file_path, index_col=0)

# Display the loaded DataFrame
result_df_final.head()

Unnamed: 0,comp_id,atom_id,type_symbol,model_Cartn_x,model_Cartn_y,model_Cartn_z,is_hydrogen,bonded_hydrogens,atomic_number,atom_id_1,atom_id_2,value_order,model_Cartn_x2,model_Cartn_y2,model_Cartn_z2,atomic_number_joined
0,1,C02,C,25.498,13.476,26.66,False,1.0,6,C02,C03,1,26.077,13.812,27.91,6
1,1,C02,C,25.498,13.476,26.66,False,1.0,6,C02,H021,1,24.573,13.972,26.319,1
2,1,C02,C,25.498,13.476,26.66,False,1.0,6,C01,C02,2,26.108,12.501,25.848,6
3,1,C06,C,27.3,11.861,26.289,False,1.0,6,C06,H061,1,27.782,11.099,25.653,1
4,1,C06,C,27.3,11.861,26.289,False,1.0,6,C01,C06,1,26.108,12.501,25.848,6


# Now create the feature vector

### Feature vectors that start with the central atom and use positions RELATIVE to the central atom (this is the variant we want)

In [3]:
# Build one feature vector per central atom from the merged atom/bond table
def create_feature_vectors_relative(df):
    """Collect all bond rows of each central atom into one flat feature vector.

    Rows are grouped by (comp_id, atom_id). Each vector starts with the central
    atom's atomic number, followed by five values per bond row: the partner's
    atomic number, the bond's value order and the partner's x/y/z offset from
    the central atom.

    Returns:
        List of feature vectors, ordered by first appearance of each central atom in df.
    """
    vectors_by_atom = {}

    for _, bond_row in tqdm(df.iterrows(), total=df.shape[0]):
        atom_key = (bond_row['comp_id'], bond_row['atom_id'])

        # The first row seen for a central atom opens its vector with the atom's own atomic number
        vector = vectors_by_atom.setdefault(atom_key, [bond_row['atomic_number']])

        # Partner position expressed relative to the central atom
        offset_x = float(bond_row['model_Cartn_x2']) - float(bond_row['model_Cartn_x'])
        offset_y = float(bond_row['model_Cartn_y2']) - float(bond_row['model_Cartn_y'])
        offset_z = float(bond_row['model_Cartn_z2']) - float(bond_row['model_Cartn_z'])

        vector += [
            bond_row['atomic_number_joined'],
            bond_row['value_order'],
            offset_x,
            offset_y,
            offset_z,
        ]

    # Dict insertion order gives the order of the returned vectors
    return list(vectors_by_atom.values())

In [4]:
# Generate the feature vectors
feature_vectors = create_feature_vectors_relative(result_df_final)

100%|███████████████████████████████████████████████████████████████████████| 930505/930505 [00:54<00:00, 17200.03it/s]


In [7]:
# First feature vector, built with create_feature_vectors_relative (positions are offsets from the central atom)
feature_vectors[0]

[6,
 6,
 1,
 0.5790000000000006,
 0.3359999999999985,
 1.25,
 1,
 1,
 -0.9250000000000007,
 0.49599999999999866,
 -0.3410000000000011,
 6,
 2,
 0.6099999999999994,
 -0.9750000000000014,
 -0.8120000000000012]

# Separate Hydrogens (y) from the feature vector (X)

In [8]:
# Split every feature vector into model input (X) and target (y):
# the relative position of the bonded hydrogen becomes y, everything else stays in X.
X, y = [], []
for vec in tqdm(feature_vectors):
    assert vec[0] != 1, "Central atom is hydrogen atom, Aborted program"
    X_auxiliary = [vec[0]]  # the vector starts with the central atom's atomic number
    found_hydrogen = False
    # Each bonding occupies 5 values: atomic number, value order, dx, dy, dz
    for i in range(1, len(vec), 5):
        if vec[i] == 1: # hydrogen case
            assert vec[i+1] == 1, "Hydrogen atom does not have one bonding as it should have"
            assert not found_hydrogen, "Found more than one bonding to hydrogen"
            y.append(vec[i+2:i+5])
            found_hydrogen = True
        else:
            X_auxiliary.extend(vec[i:i+5])
    # Without this check a vector that lacks its hydrogen bonding would add a row to X but none
    # to y, silently shifting every later (X, y) pair against each other.
    assert found_hydrogen, "No bonding to hydrogen found for this central atom"
    X.append(X_auxiliary)

100%|██████████████████████████████████████████████████████████████████████| 286717/286717 [00:00<00:00, 443937.81it/s]


In [13]:
X[2]

[6,
 6,
 1,
 1.3820000000000014,
 0.4039999999999999,
 0.48799999999999955,
 7,
 1,
 -0.9800000000000004,
 0.18700000000000117,
 1.1020000000000003,
 6,
 1,
 -0.3739999999999988,
 0.8170000000000002,
 -1.2509999999999977]

In [14]:
y[2]

[-0.009000000000000341, -1.077, -0.29199999999999804]

# Sort the bondings in each input vector X by the atomic number of the bonded atom (e.g. first all bondings to atomic number 6, then all bondings to 7, ...)
- must be done **before zero padding!**
- makes the input more permutation invariant

In [16]:
def sort_bondings(X):
    """Reorder the bondings of one feature vector by the bonded atom's atomic number.

    X holds the central atom at index 0 followed by 5-value bonding blocks whose
    first value is the bonded atom's atomic number. The sort is stable, so
    bondings with the same atomic number keep their original relative order.

    Returns:
        A new list: the central atom followed by the reordered bonding blocks.
    """
    head, tail = X[0], X[1:]

    # Cut the remainder into consecutive 5-value bonding blocks
    blocks = []
    for start in range(0, len(tail), 5):
        blocks.append(tail[start:start + 5])

    # Order the blocks by their leading value (the bonded atom's atomic number)
    blocks.sort(key=lambda block: block[0])

    # Flatten back into a single vector
    reordered = [head]
    for block in blocks:
        reordered.extend(block)
    return reordered

In [17]:
X_sorted = [sort_bondings(vec) for vec in X]

In [19]:
X[2]

[6,
 6,
 1,
 1.3820000000000014,
 0.4039999999999999,
 0.48799999999999955,
 7,
 1,
 -0.9800000000000004,
 0.18700000000000117,
 1.1020000000000003,
 6,
 1,
 -0.3739999999999988,
 0.8170000000000002,
 -1.2509999999999977]

In [20]:
X_sorted[2]

[6,
 6,
 1,
 1.3820000000000014,
 0.4039999999999999,
 0.48799999999999955,
 6,
 1,
 -0.3739999999999988,
 0.8170000000000002,
 -1.2509999999999977,
 7,
 1,
 -0.9800000000000004,
 0.18700000000000117,
 1.1020000000000003]

In [21]:
X = X_sorted

In [51]:
import pickle

# Save the sorted inputs (X_sorted) together with the hydrogen targets (y) as one pickled tuple
with open('../data/centralatom-C_numberhydrogens-1/data-for-analysis.pkl', 'wb') as f:
    pickle.dump((X_sorted, y), f)

### identify largest feature vector (for zero padding later)
- Result:
    - for all central atoms: 26 (Interpretation: Each bonding consumes 5 values and we have the atomic number of the central atom in the beginning --> 5 Bondings is max.)
    - for central atom C: 21

# Zero Padding

In [22]:
# Length of the longest feature vector in X (0 when X is empty)
max_length = max((len(vec) for vec in X), default=0)
print("Max feature vector length is: ", max_length)

Max feature vector length is:  21


In [23]:
# Pad the lists with zeros to the maximum length
X_padded = [lst + [0] * (max_length - len(lst)) for lst in X]

In [24]:
X_padded = np.asarray(X_padded)

In [25]:
X_padded.shape

(286717, 21)

# Append descriptor to the vector (code is not for one hot encoded features!)

In [26]:
from dscribe.descriptors import CoulombMatrix
from ase import Atoms

# Size the matrix for the central atom plus the largest number of bondings a vector can hold:
# each bonding takes 5 slots after the leading central-atom slot, hence (max_length - 1) / 5.
# NOTE(review): permutation='none' presumably keeps the atoms in the order they are passed in --
# confirm against the dscribe documentation.
cm = CoulombMatrix(n_atoms_max=int((max_length-1)/5) + 1, permutation='none')

In [27]:
X_padded_descriptor = []
for padded_vec in tqdm(X_padded):
    # The central atom sits at the origin because all stored positions are relative to it
    element_names = [gemmi.Element(int(padded_vec[0])).name]
    coordinates = [np.array([0, 0, 0])]

    # Walk over the 5-value bonding blocks; a leading 0 marks zero padding
    for start in range(1, len(padded_vec), 5):
        if padded_vec[start] == 0:
            continue
        element_names.append(gemmi.Element(int(padded_vec[start])).name)
        coordinates.append(padded_vec[start + 2:start + 5])

    # Describe the central atom and its bonded atoms with the Coulomb matrix
    neighbourhood = Atoms(symbols=element_names, positions=coordinates)
    coulomb_features = cm.create(neighbourhood)

    # New vector = zero-padded feature vector followed by the descriptor values
    X_padded_descriptor.append([*padded_vec, *coulomb_features])

100%|███████████████████████████████████████████████████████████████████████| 286717/286717 [00:21<00:00, 13063.73it/s]


In [28]:
X[0]

[6,
 6,
 1,
 0.5790000000000006,
 0.3359999999999985,
 1.25,
 6,
 2,
 0.6099999999999994,
 -0.9750000000000014,
 -0.8120000000000012]

In [29]:
X_padded[0]

array([ 6.   ,  6.   ,  1.   ,  0.579,  0.336,  1.25 ,  6.   ,  2.   ,
        0.61 , -0.975, -0.812,  0.   ,  0.   ,  0.   ,  0.   ,  0.   ,
        0.   ,  0.   ,  0.   ,  0.   ,  0.   ])

In [30]:
np.array(X_padded_descriptor[0])

array([ 6.        ,  6.        ,  1.        ,  0.579     ,  0.336     ,
        1.25      ,  6.        ,  2.        ,  0.61      , -0.975     ,
       -0.812     ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        , 36.8581052 , 25.3884195 , 25.57072939,  0.        ,
        0.        , 25.3884195 , 36.8581052 , 14.73193646,  0.        ,
        0.        , 25.57072939, 14.73193646, 36.8581052 ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ])

# Encoding 
Options:
- atomic number
    - One hot encoding (because ordinality makes no sense) or
    - One hot encoding of most important atoms and leave one encoding for all the other atoms, e.g. 0, 0, 0, ..., 1 is all the other atoms
    - leave as it is hoping that it works
- value order (bonding)
    - One hot encoding for value order of bonding (although ordinality makes sense here)
    - leave as it is

### prepare for one hot encoding of atomic number value (count distinct values)
- Result: 38 different values --> large one hot encoding --> worth it?

In [28]:
# Count how often each atomic number occurs, central atoms and bonded atoms alike
atomic_number_counts = {}

for vec in feature_vectors:
    # Index 0 holds the central atom; every 5-value bonding block starts with the bonded atom's
    # atomic number. A neutral loop name is used on purpose: rebinding `central_atom` here would
    # overwrite the configuration value of the same name that the filtering cell relies on.
    atomic_numbers_in_vec = [vec[0]] + [vec[i] for i in range(1, len(vec), 5)]
    for atomic_number in atomic_numbers_in_vec:
        atomic_number_counts[atomic_number] = atomic_number_counts.get(atomic_number, 0) + 1

print(atomic_number_counts)
print(sorted(list(atomic_number_counts.keys())))
print(len(atomic_number_counts.keys()))

{6: 831357, 1: 286717, 7: 54595, 8: 39999, 15: 356, 16: 3011, 5: 109, 17: 109, 34: 36, 44: 97, 9: 719, 76: 9, 53: 4, 77: 3, 26: 54, 35: 31, 14: 8, 27: 1, 46: 1, 78: 2, 75: 4}
[1, 5, 6, 7, 8, 9, 14, 15, 16, 17, 26, 27, 34, 35, 44, 46, 53, 75, 76, 77, 78]
21


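### Variant 1: one-hot encoding for the atomic number of the bonded atoms and for the value order
- 5-dimensional one-hot encoding of each bonded atom's atomic number: 10000 for C, 01000 for N, 00100 for O, 00010 for S, 00001 for all other elements (the central atom at index 0 stays a plain atomic number)
- 3-dimensional one-hot encoding of the value order: 100 for 1, 010 for 2, 001 for 3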

In [33]:
# One-hot encoding functions
def one_hot_encode_centralatom(value):
    """One-hot encode an atomic number into 5 slots.

    Slots: 6 (C), 7 (N), 8 (O), 16 (S), and a final slot for every other
    non-zero value. The value 0 (zero padding) maps to the all-zero vector.
    """
    encoding = [0, 0, 0, 0, 0]
    if value == 0:  # zero padding stays all zeros
        return encoding

    # Atomic numbers that have a dedicated slot, in slot order
    dedicated = (6, 7, 8, 16)
    for slot, atomic_number in enumerate(dedicated):
        if value == atomic_number:
            encoding[slot] = 1
            return encoding

    # Anything else shares the last slot
    encoding[4] = 1
    return encoding

def one_hot_encode_valueorder(value):
    """One-hot encode a bond value order into 3 slots.

    1 -> [1, 0, 0], 2 -> [0, 1, 0], 3 -> [0, 0, 1]; the value 0 (zero padding)
    maps to the all-zero vector.

    Raises:
        ValueError: for any other value. Returning None here would only make the
            caller fail later with an unrelated error when it tries to extend a
            list with the result.
    """
    if value == 0:  # for the zero padded values
        return [0, 0, 0]
    elif value == 1:
        return [1, 0, 0]
    elif value == 2:
        return [0, 1, 0]
    elif value == 3:
        return [0, 0, 1]
    raise ValueError(f"Unsupported bond value order: {value!r}")

In [37]:
# Positions of the atomic numbers inside the padded vector, one per 5-value bonding block.
# Despite the name these start at index 1, so they address the bonded atoms; index 0 is not included.
centralatom_indices_to_encode = list(range(1, max_length, 5))
print("Central atom indices: ", centralatom_indices_to_encode)

# Positions of the value orders: the slot right after each of the atomic numbers above
valueorder_indices_to_encode = list(range(2, max_length, 5))
print("Value order indices: ", valueorder_indices_to_encode)

Central atom indices:  [1, 6, 11, 16]
Value order indices:  [2, 7, 12, 17]


In [39]:
len(X_padded_descriptor[0])

46

In [42]:
# Build the encoded samples: atomic-number and value-order slots are expanded to one-hot
# vectors, every other entry is copied over unchanged
X_encoded = []
for descriptor_vec in tqdm(X_padded_descriptor):
    encoded = []
    for position, entry in enumerate(descriptor_vec):
        if position in centralatom_indices_to_encode:
            encoded += one_hot_encode_centralatom(entry)
        elif position in valueorder_indices_to_encode:
            encoded += one_hot_encode_valueorder(entry)
        else:
            encoded += [entry]
    X_encoded.append(encoded)

100%|███████████████████████████████████████████████████████████████████████| 286717/286717 [00:05<00:00, 50231.14it/s]


In [43]:
len(X_encoded[0])

70

In [44]:
# Convert to numpy array
X_encoded = np.array(X_encoded)

In [45]:
X_encoded.shape

(286717, 70)

In [46]:
X_encoded[0]

array([ 6.        ,  1.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  1.        ,  0.        ,  0.        ,  0.579     ,
        0.336     ,  1.25      ,  1.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  1.        ,  0.        ,
        0.61      , -0.975     , -0.812     ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
       36.8581052 , 25.3884195 , 25.57072939,  0.        ,  0.        ,
       25.3884195 , 36.8581052 , 14.73193646,  0.        ,  0.        ,
       25.57072939, 14.73193646, 36.8581052 ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
        0.        ,  0.        ,  0.        ,  0.        ,  0.  

# Split to training and testing dataset and save

In [48]:
from sklearn.model_selection import train_test_split

# Split the data: 10% is held out for testing; random_state fixes the shuffle so the split is reproducible.
# NOTE(review): the split is made per central atom, so atoms of one molecule (comp_id) may end up
# in both sets -- confirm this is intended, otherwise split by comp_id.
X_train, X_test, y_train, y_test = train_test_split(X_encoded, y, test_size=0.1, random_state=42)

In [49]:
# Write each split to its own pickle file as an (X, y) tuple: training first, then testing
pickle_targets = [
    ('../data/centralatom-C_numberhydrogens-1/training-validation.pkl', (X_train, y_train)),
    ('../data/centralatom-C_numberhydrogens-1/testing.pkl', (X_test, y_test)),
]
for pickle_path, split in pickle_targets:
    with open(pickle_path, 'wb') as f:
        pickle.dump(split, f)

# Auxiliary

#### Good to know: these two calls are inverses of each other (element symbol → atomic number and atomic number → element symbol)

In [34]:
print(gemmi.Element('C').atomic_number)
print(gemmi.Element(6).name)

6
C
