In [None]:
# Reload edited project modules (e.g. Code.Utils.util_methods) automatically before each cell runs
%load_ext autoreload
%autoreload 2

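# SELFIES featurization with one-hot encoding

Converts the molecules of the cleaned mass-spectra dataset to SELFIES, one-hot encodes them, and saves train/test splits with the spectra as X and the one-hot SELFIES as y.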

In [None]:
import os
from dotenv import load_dotenv
import selfies as sf
from tqdm import tqdm
from selfies import EncoderError
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import math
import json
from collections import deque
import gc




from Code.Utils.util_methods import NNUtils

In [None]:
# Locate the project root (all dataset / dictionary paths below are built from it)
# and load the environment variables from its .env file.
start_dir = os.getcwd()  # or provide a specific file path
try:
    base = NNUtils.find_project_root(start_dir)
except FileNotFoundError as err:
    base = None
    print(err)
else:
    print(f"Project root found: {base}")

# NOTE(review): when the root is not found, base is None and this looks for 'None/.env'.
load_dotenv(f'{base}/.env')

In [None]:
# TODO: set a flag to False if it is not needed.
# NEW_ORDERED_DICT: build a new vocabulary (SELFIES token -> index) from only the tokens present in this
#   dataset, numbered in order of first occurrence, and save it to selfies_group_dict.json.
# NEW_RANDOM_DICT: same, but the token order is the (run-dependent) iteration order of a set.
# If both are False, the existing selfies_group_dict.json is loaded instead.
NEW_ORDERED_DICT = True
NEW_RANDOM_DICT = False
# KEEP IT FALSE. If True, the spectra (X data) are divided by 999 to scale them between 0 and 1
# (assumes intensities top out at 999 — confirm).
NORMALIZE_SPECTRA = False

# Only one "new dictionary" mode can be honoured: with both enabled the ordered dictionary
# would silently overwrite the random one, so reject that configuration up front.
assert not (NEW_ORDERED_DICT and NEW_RANDOM_DICT), "Set at most one of NEW_ORDERED_DICT / NEW_RANDOM_DICT to True"

## Load the cleaned data

In [None]:
# Load the cleaned mass-spectra table, then show its size and first rows.
csv_path = f"{base}/Dataset/Mass_spectra/cleaned_df.csv"
cleaned_df = NNUtils.read_big_csv(csv_path)
print(cleaned_df.shape)
cleaned_df.head()

## Add the SELFIES representation

In [None]:
# Add a SELFIES column to a copy of the cleaned data, then drop rows that cannot be used:
#  - rows whose SELFIES is the marker string "Invalid" (presumably SMILES the helper could not encode),
#  - rows whose SELFIES contains a "." (in SMILES/SELFIES a dot separates disconnected fragments).
cleaned_df_selfies = NNUtils.add_selfies_to_df(cleaned_df.copy())

# boolean-mask filtering keeps exactly the non-"Invalid" rows even if the index has duplicate labels
is_invalid = cleaned_df_selfies["SELFIES"] == "Invalid"
cleaned_df_selfies = cleaned_df_selfies[~is_invalid].reset_index(drop=True)

# compute the dot mask once and reuse it for both the count and the filter
contains_dot = cleaned_df_selfies["SELFIES"].str.contains(".", regex=False)
count_with_dot = contains_dot.sum()
print(f"Number of SELFIES containing a dot: {count_with_dot}")
cleaned_df_selfies = cleaned_df_selfies[~contains_dot].reset_index(drop=True)

print('Invalid SMILES and SELFIES with a dot removed')
print(cleaned_df_selfies.shape)
cleaned_df_selfies.head()

## Create a DF with unique SELFIES

In [None]:
# One row per distinct SELFIES string; the vocabulary and length statistics below are computed on it.
selfies_df = (
    cleaned_df_selfies[["SELFIES"]]
    .drop_duplicates()
    .reset_index(drop=True)
)
print(selfies_df.shape)
selfies_df

## Create a dictionary mapping each SELFIES token to an index

In [None]:
# Optionally build a vocabulary in arbitrary order: each distinct SELFIES token gets an index >= 1
# and '[nop]' is reserved as index 0.
selfies_group_dict = {}
if NEW_RANDOM_DICT:
    # get_alphabet_from_selfies returns a set, so the token order (and therefore every index)
    # can change from one run to the next
    unique_selfies_groups = sf.get_alphabet_from_selfies(selfies_df["SELFIES"])

    selfies_group_dict = {group: idx + 1 for idx, group in enumerate(unique_selfies_groups)}
    selfies_group_dict['[nop]'] = 0

    # Save the mapping, exactly as the ordered-dictionary branch does: a random order cannot be
    # regenerated later, so without this file the encodings could not be decoded consistently.
    with open(f'{base}/Code/Full_systems/Selfies_Mol/Featurization/selfies_group_dict.json', "w") as file:
        json.dump(selfies_group_dict, file, indent=4)
selfies_group_dict

In [None]:
# Neither "new dictionary" option is enabled: reuse the vocabulary saved by a previous run.
use_saved_dict = not NEW_ORDERED_DICT and not NEW_RANDOM_DICT
if use_saved_dict:
    dict_path = f'{base}/Code/Full_systems/Selfies_Mol/Featurization/selfies_group_dict.json'
    with open(dict_path, 'r') as json_file:
        selfies_group_dict = json.load(json_file)
selfies_group_dict

## Get the frequency of each SELFIES component

In [None]:
# Count how often every SELFIES token occurs across the unique SELFIES strings.
# Splitting on "][" removes the inner brackets, which are restored afterwards,
# e.g. "[C][=O]" -> "[C", "=O]" -> "[C]", "[=O]".
normalized_keys = [
    f"[{fragment.strip('][')}]"
    for selfies in selfies_df["SELFIES"]
    for fragment in selfies.split("][")
]
key_frequencies = Counter(normalized_keys)

# the +1 counts the '[nop]' token that the vocabulary adds on top of the tokens found here
print(f'There are {len(key_frequencies)+1} different selfies parts')
key_frequencies


In [None]:
if NEW_ORDERED_DICT:
    # '[nop]' is index 0; the remaining tokens are numbered 1, 2, ... in the iteration order
    # of key_frequencies (i.e. the order in which each token was first seen).
    selfies_group_dict = {'[nop]': 0}
    selfies_group_dict.update(
        (token, position) for position, token in enumerate(key_frequencies, start=1)
    )

    dict_path = f'{base}/Code/Full_systems/Selfies_Mol/Featurization/selfies_group_dict.json'
    with open(dict_path, "w") as file:
        json.dump(selfies_group_dict, file, indent=4)
selfies_group_dict

## Get the frequency of each SELFIES length

In [None]:
# Length of each unique SELFIES, measured in tokens (number of "]["-separated fragments).
selfies_length = selfies_df["SELFIES"].apply(lambda x: len(x.split("][")))

# Count how many SELFIES have each length; most_common() orders the entries by frequency (most frequent first).
length_frequencies = dict(Counter(selfies_length).most_common())
length_frequencies

In [None]:
# Same counts, ordered by SELFIES length (ascending) instead of by frequency
dict(sorted(length_frequencies.items()))

In [None]:
# Bar chart of how many unique SELFIES have each length (in tokens); a dotted red guide line
# marks every length that has a bar.
max_length = max(length_frequencies.keys())
min_length = 0

fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(
    length_frequencies.keys(),
    length_frequencies.values(),
    width=0.6,
    label=f'SELFIES length, max Length: {max_length}',
)

# draw a guide line only where a bar of non-zero height exists
for rect in bars:
    if rect.get_height() > 0:
        ax.axvline(x=rect.get_x() + rect.get_width() / 2, color='red', linestyle='dotted', zorder=0)

ax.set_xticks(range(min_length, max_length + 1, 10))
ax.set_xlabel('Number of Fragments in SELFIES', fontsize=12)
ax.set_ylabel('Frequency', fontsize=12)
ax.set_title('Frequency of SELFIES Lengths', fontsize=14)
ax.legend()
fig.tight_layout()
plt.show()

Remove the SELFIES longer than `MAX_SELFIES_LENGTH` tokens (read from the `.env` file, e.g. 60); this value becomes the maximum SELFIES length.

In [None]:
# Maximum SELFIES length (in tokens) kept in the dataset; configured in the project's .env file.
_max_len_raw = os.getenv("MAX_SELFIES_LENGTH")
if _max_len_raw is None:
    # int(None) would fail with an unhelpful TypeError; name the missing setting instead
    raise RuntimeError("MAX_SELFIES_LENGTH is not set; add it to the project's .env file")
MAX_SELFIES_LENGTH = int(_max_len_raw)
MAX_SELFIES_LENGTH

In [None]:
# keep only rows whose SELFIES has at most MAX_SELFIES_LENGTH tokens (same "]["-split length as above)
within_limit = cleaned_df_selfies["SELFIES"].apply(lambda x: len(x.split("][")) <= MAX_SELFIES_LENGTH)
cleaned_df_selfies_cropped = cleaned_df_selfies[within_limit].reset_index(drop=True)
print(cleaned_df_selfies_cropped.shape)
print(f'{cleaned_df_selfies.shape[0]-cleaned_df_selfies_cropped.shape[0]} rows removed')

In [None]:
# Display the length-filtered dataset
cleaned_df_selfies_cropped

## Featurize the SELFIES

In [None]:
# Vocabulary (SELFIES token -> index) that the one-hot encoding below will use
selfies_group_dict

In [None]:
# Vocabulary size: number of one-hot bits per SELFIES position
len(selfies_group_dict)

In [None]:
# One-hot encode every SELFIES, padded to MAX_SELFIES_LENGTH tokens and flattened, so each molecule
# becomes one list of MAX_SELFIES_LENGTH * len(selfies_group_dict) bits.
selfies_batch = cleaned_df_selfies_cropped['SELFIES'].astype(str).tolist()
hot = sf.batch_selfies_to_flat_hot(
    selfies_batch,
    vocab_stoi=selfies_group_dict,
    pad_to_len=MAX_SELFIES_LENGTH,
)
del selfies_batch

In [None]:
# The un-cropped frames are no longer needed; release them before building the large one-hot DataFrame.
del cleaned_df_selfies
del cleaned_df
gc.collect()

In [None]:
# hot holds one flat one-hot list per SELFIES: print (number of SELFIES, bits per SELFIES)
n_encoded, n_bits = len(hot), len(hot[0])
print(n_encoded, n_bits)

In [None]:
# selfies_emb_df = pd.DataFrame(hot, columns=[f'bit_{i}' for i in range(MAX_SELFIES_LENGTH*len(selfies_group_dict))])
# print('hot deleted')
# selfies_emb_df['SELFIES'] = cleaned_df_selfies_cropped['SELFIES']
# print(selfies_emb_df.shape)
# selfies_emb_df

In [None]:
# df_cols = [f'bit_{i}' for i in range(MAX_SELFIES_LENGTH*len(selfies_group_dict))]
# df_cols.append('SELFIES')
# selfies_emb_df = pd.DataFrame(columns=df_cols)
# for encoding in tqdm(range(len(hot))):
#     row = hot[encoding]
#     row.append(cleaned_df_selfies_cropped.iloc[encoding]['SELFIES'])
#     selfies_emb_df.loc[encoding] = row

In [None]:
# Assemble the one-hot encodings into a DataFrame with one column per bit ('bit_0', 'bit_1', ...)
# plus the source 'SELFIES' string, and pickle it to the current working directory.
# Rows are gathered in batches of `batch_size`, so each pd.DataFrame construction stays small.
df_cols = [f'bit_{i}' for i in range(MAX_SELFIES_LENGTH * len(selfies_group_dict))]
df_cols.append('SELFIES')

batch_size = 5000

# Extract the SELFIES column once: calling .iloc[i] on the wide spectra frame builds a full row
# Series for every molecule, which made the original loop very slow.
selfies_values = cleaned_df_selfies_cropped['SELFIES'].tolist()
if len(selfies_values) != len(hot):
    raise ValueError(
        f"{len(hot)} one-hot encodings but {len(selfies_values)} SELFIES rows; they must be aligned"
    )

temp_rows = []   # rows of the batch currently being filled
dataframes = []  # completed batch DataFrames

for bits, selfies_str in tqdm(zip(hot, selfies_values), total=len(hot)):
    # build a new row rather than appending to the list stored in `hot`, so `hot` is left unmodified
    temp_rows.append([*bits, selfies_str])

    # when the batch is full, turn it into a DataFrame and start a new one
    if len(temp_rows) == batch_size:
        dataframes.append(pd.DataFrame(temp_rows, columns=df_cols))
        temp_rows = []

# handle any remaining rows after the loop ends
if temp_rows:
    dataframes.append(pd.DataFrame(temp_rows, columns=df_cols))

del hot, temp_rows
gc.collect()

# pd.concat raises on an empty list, so fall back to an empty frame with the expected columns
if dataframes:
    selfies_emb_df = pd.concat(dataframes, ignore_index=True)
else:
    selfies_emb_df = pd.DataFrame(columns=df_cols)
del dataframes  # the concatenated copy is all that is needed from here on

selfies_emb_df.to_pickle(f'{os.getcwd()}/one_hot_selfies_encoding.pkl')
print('one_hot_selfies_encoding.pkl saved')

# display the resulting dataframe
print(selfies_emb_df.shape)
selfies_emb_df.head()


In [None]:
# Vocabulary size (one-hot width per SELFIES position)
len(selfies_group_dict)

## Remove duplicates

In [None]:
# selfies_emb_unique_df = selfies_emb_df.drop_duplicates().reset_index(drop=True)
# print(selfies_emb_unique_df.shape)
# selfies_emb_unique_df

## Save the featurized SELFIES (disabled — the one-hot DataFrame is already pickled to `one_hot_selfies_encoding.pkl` above)

In [None]:
# selfies_emb_unique_df.to_csv(f'{base}/Dataset/Embeddings/{os.getenv("SELFIES_EMBEDDING")}', index=False)
# print(f'Saved the SELFIES embeddings into {base}/Dataset/Embeddings/{os.getenv("SELFIES_EMBEDDING")}')

## Combine the SELFIES embeddings with the cleaned data

In [None]:
# Attach the one-hot columns to the spectra/metadata rows. axis=1 concatenation pairs rows by index,
# so verify first that both frames list the same SELFIES in the same order; otherwise rows would be
# silently misaligned or padded with NaN.
if cleaned_df_selfies_cropped['SELFIES'].tolist() != selfies_emb_df['SELFIES'].tolist():
    raise ValueError("SELFIES rows of the dataset and of the one-hot frame are not aligned")

ms_emb_df = pd.concat([cleaned_df_selfies_cropped, selfies_emb_df.drop(columns=['SELFIES'])], axis=1)
print(ms_emb_df.shape)
ms_emb_df

## Separate X and y

In [None]:
# Features (X): every column whose name contains 'mz' (presumably the m/z spectrum bins — confirm naming).
spectra_columns = list(filter(lambda name: 'mz' in name, ms_emb_df.columns))
X = ms_emb_df[spectra_columns]
X = X / 999 if NORMALIZE_SPECTRA else X
print(X.shape)
X

In [None]:
# Targets (y): the one-hot bit columns ('bit_0', 'bit_1', ...) produced by the encoding step.
embedding_columns = [name for name in ms_emb_df.columns if 'bit_' in name]
y = ms_emb_df.loc[:, embedding_columns]
print(y.shape)
y

## Split into train and test sets

In [None]:
# Split features and targets with the project helper (split strategy/ratio are defined in NNUtils).
# NOTE(review): the next cell assumes the returned frames keep ms_emb_df's index labels — confirm.
X_train, X_test, y_train, y_test = NNUtils.divide_big_train_and_test_data(X, y)

In [None]:
# SELFIES strings of the test rows, looked up by the test set's index labels (saved with the splits below).
selfies_test = ms_emb_df['SELFIES'].loc[X_test.index]
selfies_test.head()

In [None]:
# Training features: shape and a preview
print(f"{X_train.shape}")
X_train.head(n=5)

In [None]:
# Training targets: shape and a preview
print(f"{y_train.shape}")
y_train.head(n=5)

In [None]:
# Test features: shape and a preview
print(f"{X_test.shape}")
X_test.head(n=5)

In [None]:
# Test targets: shape and a preview
print(f"{y_test.shape}")
y_test.head(n=5)

In [None]:
# Give every split a fresh 0..n-1 index (the original labels were only needed to look up selfies_test).
X_train, X_test, y_train, y_test = (
    frame.reset_index(drop=True) for frame in (X_train, X_test, y_train, y_test)
)
selfies_test = selfies_test.reset_index(drop=True)

In [None]:
# Feature (input) and target (output) widths
input_size, output_size = X_train.shape[1], y_train.shape[1]
input_size, output_size

In [None]:
# Pickle every split to the path configured in .env. All variables are checked first, so a missing
# one fails before anything is written instead of after a partial save.
split_outputs = {
    'X_TRAIN': X_train,
    'Y_TRAIN': y_train,
    'X_TEST': X_test,
    'Y_TEST': y_test,
    'SELFIES_X_TEST': selfies_test,
}
missing_vars = [name for name in split_outputs if not os.getenv(name)]
if missing_vars:
    raise RuntimeError(f"Environment variables not set (check the .env file): {', '.join(missing_vars)}")

for env_name, data in split_outputs.items():
    out_path = os.getenv(env_name)
    data.to_pickle(out_path)
    print(f"{out_path} saved")

In [None]:
# Signals that the notebook ran to completion
print('done')