In [None]:
from rdkit import DataStructs
from rdkit.DataStructs import ExplicitBitVect
from rdkit.Chem import MACCSkeys
import xlsxwriter
from rdkit.Chem import Draw
from io import BytesIO
import tkinter as tk
from tkinter import ttk, messagebox
from tkinter import *
import numpy as np
import os
import random
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import dgl
from dgllife.utils import smiles_to_bigraph
from dgllife.utils import AttentiveFPAtomFeaturizer, AttentiveFPBondFeaturizer
from dgllife.data import MoleculeCSVDataset
from dgllife.model.gnn import AttentiveFPGNN
from dgllife.model.readout import AttentiveFPReadout
from sklearn.preprocessing import StandardScaler
from rdkit.Chem import AllChem
from rdkit import Chem
from rdkit.Chem import MolStandardize
import pandas as pd
from PIL import Image, ImageTk
from pathlib import Path

# Main Window
class WelcomeWindow:
    """Landing window offering the retrieval and the prediction tool.

    The window shows a background image with two buttons. Clicking a button
    hides this window and opens the chosen tool in a new ``Toplevel``.
    """

    def __init__(self, master):
        """Build the welcome screen inside ``master``.

        Args:
            master: The Tk window that hosts the welcome screen.
        """
        self.master = master
        master.title('dye-predictor')
        master.config(background='#FEFEFE')
        master.resizable(False, False)

        # resize() returns an independent, fully loaded copy, so the file
        # handle can be released right away.
        with Image.open("./Data/p1.png") as image:
            img = image.resize((1020, 600))
        # Stored on self so the PhotoImage is not garbage-collected while shown.
        self.my_img = ImageTk.PhotoImage(img)
        self.background_label = tk.Label(master, image=self.my_img)
        self.background_label.grid(row=0, column=0, rowspan=4)

        self.retrieval_button = tk.Button(master, text="Fluorescent Dye Retrieval", command=self.open_search_wind, font=('Arial', 15), bg='#FFFFFF', relief=tk.RAISED, width=25, borderwidth=4)
        self.retrieval_button.grid(row=2, column=0, pady=20, sticky="s")
        self.prediction_button = tk.Button(master, text="Fluorescent Dye Prediction", command=self.open_predict_wind, font=('Arial', 15), bg='#FFFFFF', relief=tk.RAISED, width=25, borderwidth=4)
        self.prediction_button.grid(row=3, column=0, pady=20, sticky="n")

    def _open_tool(self, app_class):
        """Hide the welcome window and open ``app_class`` in a new Toplevel.

        Args:
            app_class: Callable taking the new window as its only argument
                (``search_app`` or ``predict_app``).
        """
        self.master.withdraw()
        window = tk.Toplevel(self.master)
        # The welcome window is hidden from now on; without this handler,
        # closing the tool window would leave the event loop running with no
        # visible window. Closing the tool window therefore ends the program.
        window.protocol("WM_DELETE_WINDOW", self.master.destroy)
        app_class(window)

    def open_search_wind(self):
        """Switch from the welcome screen to the retrieval tool."""
        self._open_tool(search_app)

    def open_predict_wind(self):
        """Switch from the welcome screen to the prediction tool."""
        self._open_tool(predict_app)

#######################################################################################################################
#######################################################################################################################
#######################################################################################################################
# Search Feature Code

class search_app:
    def __init__(self, master):
        """Build the retrieval window inside ``master``.

        Args:
            master: The top-level window that hosts the retrieval UI.
        """
        self.master = master
        master.title('Fluorescent Dye Retrieval')
        master.resizable(False, False)
        # The window this one was opened from is hidden, so closing this
        # window must end the whole application; otherwise the event loop
        # would keep running with no visible window.
        master.protocol("WM_DELETE_WINDOW", master.nametowidget('.').destroy)

        # Header banner. resize() returns a loaded copy, so the file can be
        # closed immediately; the PhotoImage is stored on self to keep it alive.
        with Image.open("./Data/p2.png") as image1:
            img1 = image1.resize((1025, 100))
        self.my_img1 = ImageTk.PhotoImage(img1)
        self.background_label1 = tk.Label(master, image=self.my_img1)
        self.background_label1.grid(row=0, column=0, columnspan=10)
        # White canvas used as the background behind the input widgets.
        self.canvas1 = tk.Canvas(master, width=1025, height=320, bg="#FFFFFF")
        self.canvas1.grid(row=1, column=0, rowspan=8, columnspan=10)

        # Insert the input entry
        self.label_smiles1 = tk.Label(master, text="Please input a SMILES for the target molecule", font=('Arial', 12, "bold"), fg='#233C6F', bg='#FFFFFF')
        self.label_smiles1.grid(row=1, column=0, pady=5, columnspan=10)
        self.entry_smiles1 = tk.Entry(master, width=100)
        self.entry_smiles1.grid(row=2, column=0, pady=5, columnspan=10)
        self.entry_smiles1.insert(0, "Such as: CCOC(=O)c1ccccc1-c1c2ccc(=[N+](CC)CC)cc-2oc2c1c(=O)oc1cc3c(cc12)CCCN3C")

        # Insert Function Selection Button
        # selected_option: 1 = direct retrieval, 2 = similarity, 3 = substructure.
        self.label_function1 = tk.Label(master, text="Please select a function", font=('Arial', 12, "bold"), fg='#233C6F', bg='#FFFFFF')
        self.label_function1.grid(row=3, column=0, pady=5, columnspan=10)
        self.selected_option = tk.IntVar(value=1)
        self.query_button1 = tk.Radiobutton(master, text="Direct Retrieval", variable=self.selected_option, value=1, font=('Arial', 13, "bold"), bg='#FFFFFF')
        self.query_button1.grid(row=4, column=4, pady=5, columnspan=10, sticky="w")
        self.similarity_search_button1 = tk.Radiobutton(master, text="Similarity Search", variable=self.selected_option, value=2, font=('Arial', 13, "bold"), bg='#FFFFFF')
        self.similarity_search_button1.grid(row=5, column=4, pady=5, columnspan=10, sticky="w")
        self.sub_search_button1 = tk.Radiobutton(master, text="Substructure Search", variable=self.selected_option, value=3, font=('Arial', 13, "bold"), bg='#FFFFFF')
        self.sub_search_button1.grid(row=6, column=4, pady=5, columnspan=10, sticky="w")

        # Insert Run Button
        self.run_button1 = tk.Button(master, text="Run", command=self.run_search, font=('Arial', 13, "bold"), bg='#24A5C8', relief=tk.RAISED, width=10)
        self.run_button1.grid(row=7, column=0, pady=5, columnspan=10)

        # Insert Function Transformation Button
        self.to_prediction_button = tk.Button(master, text="Go to Prediction", command=self.return_to_prediction, font=('Arial', 13), bg='#FFFFFF', relief=tk.RAISED, width=20)
        self.to_prediction_button.grid(row=8, column=0, columnspan=10)

        # Insert Tips
        with Image.open("./Data/p4.png") as image3:
            img3 = image3.resize((1025, 200))
        self.my_img3 = ImageTk.PhotoImage(img3)
        self.background_label3 = tk.Label(master, image=self.my_img3)
        self.background_label3.grid(row=9, column=0, columnspan=10)

    def return_to_prediction(self):
        """Replace the retrieval window with a fresh prediction window.

        The new window is attached to the application root rather than to
        this window, so this window can be destroyed instead of being hidden.
        Hiding it would leave one more invisible window (and its images)
        alive on every switch between the two tools.
        """
        app_root = self.master.nametowidget('.')
        window = tk.Toplevel(app_root)
        # The root stays hidden, so closing the new window ends the program.
        window.protocol("WM_DELETE_WINDOW", app_root.destroy)
        predict_app(window)
        if self.master is app_root:
            # Destroying the root would end the application; only hide it.
            self.master.withdraw()
        else:
            self.master.destroy()
        
###################################################
# Pre-search Preparation
    
    def run_search(self):
        """Run the search mode selected in the UI for the entered SMILES.

        The entry text is split on commas, each token is standardized with
        RDKit, and the valid molecules are written to
        ``Process_data/que-data.csv`` (SMILES) and
        ``Process_data/que-morgan.csv`` (Morgan fingerprint bits). Depending
        on the selected radio button the method then performs a direct
        retrieval (all valid molecules), a similarity search (first valid
        molecule) or a substructure search (first valid molecule).

        Any exception raised while searching is reported in an error dialog.
        """
        try:
            smiles_input = self.entry_smiles1.get().strip()
            # Validate before doing any work or touching the disk.
            if not smiles_input:
                messagebox.showerror("Error", "Please enter a SMILES string.")
                return

            normalizer = MolStandardize.Standardizer()

            def standardize_smiles(smiles):
                """Return the standardized SMILES, or None if it cannot be parsed."""
                try:
                    mol = Chem.MolFromSmiles(smiles)
                    if mol is None:
                        return None
                    return Chem.MolToSmiles(normalizer.standardize(mol))
                except Exception:
                    # Standardization failures are treated as "invalid molecule".
                    return None

            def generate_morgan_fingerprint(smiles, radius=2, n_bits=1024):
                """Return the Morgan fingerprint bit vector of ``smiles``."""
                mol = Chem.MolFromSmiles(smiles)
                if mol is None:
                    raise ValueError(f"Cannot build a fingerprint for: {smiles}")
                return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)

            tokens = [token.strip() for token in smiles_input.split(',')]
            standardized = [standardize_smiles(token) for token in tokens if token]
            valid_smiles = [smi for smi in standardized if smi is not None]
            if not valid_smiles:
                # Shown in a dialog (not only printed) and the search is aborted:
                # every search mode needs at least one valid molecule.
                messagebox.showerror("Error", "The input molecule cannot be recognized. Please re-enter!")
                return

            os.makedirs('Process_data', exist_ok=True)
            os.makedirs('Results', exist_ok=True)

            # Intermediate files are still written, but the in-memory values
            # are used below instead of re-reading them from disk.
            pd.DataFrame({'SMILES': valid_smiles}).to_csv('Process_data/que-data.csv', index=False)
            query_fps = [generate_morgan_fingerprint(smi) for smi in valid_smiles]
            pd.DataFrame([list(fp) for fp in query_fps]).to_csv('Process_data/que-morgan.csv', index=False)

            mode = self.selected_option.get()
            if mode == 1:
                self._direct_retrieval(smiles_input, valid_smiles)
            elif mode == 2:
                self._similarity_search(smiles_input, query_fps[0])
            elif mode == 3:
                self._substructure_search(smiles_input, valid_smiles[0])

        except Exception as e:
            messagebox.showerror("Error", f"An error occurred: {str(e)}")

    def _direct_retrieval(self, smiles_input, valid_smiles):
        """Save database rows whose SMILES equals one of ``valid_smiles``.

        Args:
            smiles_input: Raw entry text, used only in the info dialog.
            valid_smiles: Standardized SMILES strings to look up.
        """
        messagebox.showinfo("Query", f"Searching for molecule: {smiles_input}")
        database = pd.read_csv('./Data/Data_all_name.csv')
        hits = database[database['SMILES'].isin(valid_smiles)]
        hits.to_csv('Results/Target_search.csv', index=False)
        print("Filtering complete, results have been saved to Target_search.csv")

    def _similarity_search(self, smiles_input, query_fp, rank=100):
        """Save the ``rank`` database molecules most similar to ``query_fp``.

        Args:
            smiles_input: Raw entry text, used only in the info dialog.
            query_fp: Morgan fingerprint bit vector of the query molecule.
            rank: Number of top hits to keep.
        """
        messagebox.showinfo("Similarity Search", f"Performing molecule similarity search: {smiles_input}")

        def array_to_bitvector(fp_array):
            """Convert a 0/1 array into an RDKit ExplicitBitVect."""
            bitvect = ExplicitBitVect(len(fp_array))
            for i, bit in enumerate(fp_array):
                if bit == 1:
                    bitvect.SetBit(i)
            return bitvect

        db_bits = pd.read_csv('./Data/Data_all_morgan.csv')
        db_names = pd.read_csv('./Data/Data_all_name.csv')

        # Calculating the Tanimoto similarity between the target molecule and all molecules in the database.
        similarities = [
            DataStructs.TanimotoSimilarity(query_fp, array_to_bitvector(row))
            for row in db_bits.to_numpy()
        ]
        # Row positions of the best hits; db_bits and db_names are indexed
        # by the same positions.
        top_indices = pd.Series(similarities, dtype=float).nlargest(rank).index
        top_molecules = db_names.iloc[top_indices].fillna('NA')
        top_molecules.insert(0, 'Synonym', range(1, len(top_molecules) + 1))
        top_molecules.to_csv('Results/Similarity_search.csv', index=False)
        print("Similarity calculation and data extraction are complete. The results have been saved to Similarity_search.csv")

        self._mol_to_excel("Results/Similarity_search")
        print('Image generation is complete.')

    def _substructure_search(self, smiles_input, substructure_smiles):
        """Save all database molecules that contain the given substructure.

        Args:
            smiles_input: Raw entry text, used only in the info dialog.
            substructure_smiles: Standardized SMILES of the query substructure.

        Raises:
            ValueError: If ``substructure_smiles`` cannot be parsed.
        """
        messagebox.showinfo("Substructure Search", f"Performing molecule Substructure search: {smiles_input}")
        print(substructure_smiles)

        substructure = Chem.MolFromSmiles(substructure_smiles)
        if substructure is None:
            raise ValueError("The SMILES format of the substructure is incorrect. Please check your input！")

        database = pd.read_csv("./Data/Data_all.csv")

        def contains_substructure(smiles):
            """Return True when ``smiles`` parses and contains the substructure."""
            if not isinstance(smiles, str):
                # Missing cells are read as NaN and cannot match.
                return False
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                return False
            return mol.HasSubstructMatch(substructure)

        database['Contains_Substructure'] = database["SMILES"].apply(contains_substructure).astype(bool)
        # copy() so that insert() works on an independent frame, not on a
        # slice of ``database``.
        matching_molecules = database[database['Contains_Substructure']].copy()
        matching_molecules.insert(0, 'Synonym', range(1, len(matching_molecules) + 1))
        matching_molecules.to_csv('./Results/Sub_search.csv', index=False)
        print("Molecules containing the substructure have been saved to the 'Sub_search.csv' file.")

        self._mol_to_excel("Results/Sub_search")
        print('Image generation is complete.')

    @staticmethod
    def _mol_to_excel(file):
        """Render ``{file}.csv`` as ``{file}.xlsx`` with a structure image per row.

        Column A receives a drawing of the molecule, column B the SMILES
        (second CSV column), and the following columns the CSV values at the
        same positions, up to the number of header entries.

        Args:
            file: Path of the result file without extension.
        """
        header = ['Synonym', 'SMILES', 'Solvent', 'Ex (nm)', 'Em (nm)', 'ST (nm)', 'QY', 'AC(cm-1M-1)']
        item_style = {
            'align': 'center',
            'valign': 'vcenter',
            'top': 2,
            'left': 2,
            'right': 2,
            'bottom': 2,
            'text_wrap': 1
        }
        header_style = {
            'bold': 1,
            'valign': 'vcenter',
            'align': 'center',
            'top': 2,
            'left': 2,
            'right': 2,
            'bottom': 2
        }

        df = pd.read_csv(f'{file}.csv')
        # Never index past the columns that actually exist in the CSV.
        last_col = min(len(header), df.shape[1])

        workbook = xlsxwriter.Workbook(f'{file}.xlsx')
        try:
            ItemStyle = workbook.add_format(item_style)
            HeaderStyle = workbook.add_format(header_style)
            worksheet = workbook.add_worksheet()

            worksheet.set_column('A:A', 38)
            worksheet.set_column('B:B', 40)
            worksheet.set_column('C:I', 20)

            for col, title in enumerate(header):
                worksheet.write(0, col, title, HeaderStyle)

            for i in range(df.shape[0]):
                row = i + 1
                structure_smi = df.iloc[i, 1]
                worksheet.set_row(row, 185)

                mol = Chem.MolFromSmiles(structure_smi) if isinstance(structure_smi, str) else None
                if mol is not None:
                    # Rows whose SMILES cannot be drawn keep an empty image cell.
                    img_data_structure = BytesIO()
                    Draw.MolToImage(mol).save(img_data_structure, format='PNG')
                    worksheet.insert_image(row, 0, 'f', {'x_scale': 0.9, 'y_scale': 0.8, 'image_data': img_data_structure, 'positioning': 1})

                worksheet.write(row, 1, 'NA' if pd.isna(structure_smi) else structure_smi, ItemStyle)

                for j in range(2, last_col):
                    cell_value = df.iloc[i, j]
                    if pd.isna(cell_value):
                        cell_value = 'NA'
                    worksheet.write(row, j, cell_value, ItemStyle)
        finally:
            # Always release the workbook, even when a row fails to render.
            workbook.close()
            

#######################################################################################################################
#######################################################################################################################
#######################################################################################################################
# Prediction Feature Code

# Size of the graph-level feature vector; used as the default ``graph_feat_size``
# of the model built in ``predict_app.run_predict`` and passed to it explicitly.
# NOTE(review): presumably has to match the value the saved ``.pth`` weights
# were trained with - confirm before changing.
graph_feat_size = 256
class predict_app:
    def __init__(self, master):
        """Build the prediction window inside ``master``.

        Args:
            master: The top-level window that hosts the prediction UI.
        """
        self.master = master
        master.title('Property_Prediction')
        master.resizable(False, False)
        # The window this one was opened from is hidden, so closing this
        # window must end the whole application; otherwise the event loop
        # would keep running with no visible window.
        master.protocol("WM_DELETE_WINDOW", master.nametowidget('.').destroy)

        # Header banner. resize() returns a loaded copy, so the file can be
        # closed immediately; the PhotoImage is stored on self to keep it alive.
        with Image.open("./Data/p3.png") as image2:
            img2 = image2.resize((1025, 100))
        self.my_img2 = ImageTk.PhotoImage(img2)
        self.background_label2 = tk.Label(master, image=self.my_img2)
        self.background_label2.grid(row=0, column=0, columnspan=10)
        # White canvas used as the background behind the input widgets.
        self.canvas2 = tk.Canvas(master, width=1020, height=320, bg="#FFFFFF")
        self.canvas2.grid(row=1, column=0, rowspan=8, columnspan=10)

        # Molecule type selector. Read-only so that only the three types
        # handled by run_predict can be chosen.
        self.target_label2 = tk.Label(master, text="Please select the molecule type", font=('Arial', 12, "bold"), fg='#233C6F', bg='#FFFFFF')
        self.target_label2.grid(row=1, column=0, pady=5, columnspan=10)
        self.target_combobox2 = ttk.Combobox(master, values=["xanthene", "cyanine", "all-types"], width=76, font=('Arial', 13, "bold"), state="readonly")
        self.target_combobox2.grid(row=2, column=0, pady=5, columnspan=10)
        self.target_combobox2.current(0)

        # Dye SMILES input.
        self.smiles_label2 = tk.Label(master, text="Please enter the fluorescent molecule SMILES", font=('Arial', 12, "bold"), fg='#233C6F', bg='#FFFFFF')
        self.smiles_label2.grid(row=3, column=0, pady=5, columnspan=10)
        self.smiles_entry2 = tk.Entry(master, width=100)
        self.smiles_entry2.grid(row=4, column=0, pady=5, columnspan=10)
        self.smiles_entry2.insert(0, "Such as: CCOC(=O)c1ccccc1-c1c2ccc(=[N+](CC)CC)cc-2oc2c1c(=O)oc1cc3c(cc12)CCCN3C")

        # Solvent SMILES input.
        self.solvent_label2 = tk.Label(master, text="Please enter the solvent molecule SMILES", font=('Arial', 12, "bold"), fg='#233C6F', bg='#FFFFFF')
        self.solvent_label2.grid(row=5, column=0, pady=5, columnspan=10)
        self.solvent_entry2 = tk.Entry(master, width=100)
        self.solvent_entry2.grid(row=6, column=0, pady=5, columnspan=10)
        self.solvent_entry2.insert(0, "Such as: CCO")

        # Action buttons.
        self.run_button2 = tk.Button(master, text="Run", command=self.run_predict, font=('Arial', 13, "bold"), bg='#24A5C8', relief=tk.RAISED, width=10)
        self.run_button2.grid(row=7, column=0, pady=5, columnspan=10)
        self.return_button2 = tk.Button(master, text='Go to search', command=self.return_to_search, font=('Arial', 12), bg='#FFFFFF', relief=tk.RAISED)
        self.return_button2.grid(row=8, column=0, pady=5, columnspan=10)

        # Insert Tips
        with Image.open("./Data/p5.png") as image4:
            img4 = image4.resize((1025, 170))
        self.my_img4 = ImageTk.PhotoImage(img4)
        self.background_label4 = tk.Label(master, image=self.my_img4)
        self.background_label4.grid(row=9, column=0, columnspan=10)

    def return_to_search(self):
        """Replace the prediction window with a fresh retrieval window.

        The new window is attached to the application root rather than to
        this window, so this window can be destroyed instead of being hidden.
        Hiding it would leave one more invisible window (and its images)
        alive on every switch between the two tools.
        """
        app_root = self.master.nametowidget('.')
        window = tk.Toplevel(app_root)
        # The root stays hidden, so closing the new window ends the program.
        window.protocol("WM_DELETE_WINDOW", app_root.destroy)
        search_app(window)
        if self.master is app_root:
            # Destroying the root would end the application; only hide it.
            self.master.withdraw()
        else:
            self.master.destroy()

    
    def run_predict(self):
        """Predict the four properties for the entered dye/solvent pairs.

        The dye and solvent entries are split on commas and paired by
        position. Each dye SMILES is standardized and turned into a graph;
        each solvent SMILES is turned into a Morgan fingerprint. Pairs with an
        unparsable dye or solvent are dropped. The weights for the selected
        molecule type are loaded, the model output is mapped back to the
        original label scale with a scaler fitted on the matching training
        CSV, and the result is written to ``Results/pred-results.csv`` and
        previewed in a dialog.

        Intermediate files are written to ``Process_data``. Input problems
        are reported with a warning dialog, any other exception with an error
        dialog. The graph cache file ``pred_dataset.bin`` is always removed.
        """
        label_cols = ['AM', 'EM', 'QY', 'LGAC']
        cache_path = Path('pred_dataset.bin')
        try:
            smiles = self.smiles_entry2.get().strip()
            solvent = self.solvent_entry2.get().strip()
            target = self.target_combobox2.get()

            # Validate the input before any expensive loading happens.
            if not smiles or not solvent:
                messagebox.showwarning("Input Error", "Please provide SMILES and Solvent input.")
                return

            smiles_list = [item.strip() for item in smiles.split(',')]
            solvent_list = [item.strip() for item in solvent.split(',')]
            if len(smiles_list) != len(solvent_list):
                messagebox.showwarning("Input Error", "The number of SMILES and Solvent entries does not match.")
                return

            # Training CSV (for the label scaler) and weights per molecule type.
            resources = {
                'xanthene': ('./Data/train_xanthene.csv', './Data/Model_xanthene.pth'),
                'cyanine': ('./Data/train_cyanine.csv', './Data/Model_cyanine.pth'),
                'all-types': ('./Data/train_all_types.csv', './Data/Model_all_types.pth'),
            }
            if target not in resources:
                messagebox.showwarning("Input Error", f"Unknown molecule type: {target}")
                return
            train_csv, model_path = resources[target]

            n_tasks = 4
            dropout_g = 0.4
            dropout_f = 0.5
            dropout_l = 0.4
            fp_size = 1024

            if torch.cuda.is_available():
                print('use GPU')
                device = 'cuda'
            else:
                print('use CPU')
                device = 'cpu'

            seed = 42
            random.seed(seed)
            np.random.seed(seed)
            os.environ['PYTHONHASHSEED'] = str(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)

            os.makedirs('Process_data', exist_ok=True)
            os.makedirs('Results', exist_ok=True)

            # Raw input as entered, kept as an intermediate artifact.
            pd.DataFrame({
                'SMILES': smiles_list,
                'Solvent': solvent_list,
                'AM': None,
                'EM': None,
                'QY': None,
                'LGAC': None
            }).to_csv('Process_data/pred-prep.csv', index=False)

            normalizer = MolStandardize.Standardizer()

            def standardize_smiles(smiles):
                """Return the standardized SMILES, or None if it cannot be parsed."""
                try:
                    mol = Chem.MolFromSmiles(smiles)
                    if mol is None:
                        return None
                    return Chem.MolToSmiles(normalizer.standardize(mol))
                except Exception:
                    # Standardization failures are treated as "invalid molecule".
                    return None

            def generate_morgan_fingerprint(smiles, radius=2, n_bits=fp_size):
                """Return the Morgan fingerprint bits as a list, or None if unparsable."""
                mol = Chem.MolFromSmiles(smiles)
                if mol is None:
                    return None
                return list(AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits))

            # Keep only pairs where both the dye and the solvent are usable;
            # an unparsable solvent would otherwise feed NaNs into the model.
            valid_smiles, valid_solvents, fingerprints = [], [], []
            for dye, solv in zip(smiles_list, solvent_list):
                std_dye = standardize_smiles(dye) if dye else None
                solvent_fp = generate_morgan_fingerprint(solv) if solv else None
                if std_dye is None or solvent_fp is None:
                    continue
                valid_smiles.append(std_dye)
                valid_solvents.append(solv)
                fingerprints.append(solvent_fp)

            dropped = len(smiles_list) - len(valid_smiles)
            if dropped == 0:
                print('All molecules are valid molecules.')
            else:
                print('There are', dropped, 'invalid molecules, which have been deleted.')
            if not valid_smiles:
                messagebox.showwarning("Input Error", "None of the SMILES/Solvent pairs could be parsed. Please check your input.")
                return

            # Labels are unknown at prediction time; float NaN columns let the
            # dataset build its label tensor and mask.
            pred_data = pd.DataFrame({'SMILES': valid_smiles, 'Solvent': valid_solvents})
            for col in label_cols:
                pred_data[col] = np.nan
            pred_data.to_csv('Process_data/pred-data.csv', index=False)
            pd.DataFrame(fingerprints).to_csv('Process_data/pred-morgan.csv', index=False)
            pred_fp = torch.from_numpy(np.asarray(fingerprints, dtype=np.float32))

            atom_featurizer = AttentiveFPAtomFeaturizer(atom_data_field='hv')
            bond_featurizer = AttentiveFPBondFeaturizer(bond_data_field='he')
            n_feats = atom_featurizer.feat_size('hv')
            e_feats = bond_featurizer.feat_size('he')

            # The scaler is fitted on the training labels only so that the
            # model output can be mapped back to the original units.
            train_data = pd.read_csv(train_csv)
            scaler = StandardScaler()
            scaler.fit(train_data[label_cols].to_numpy())

            # load=False: always rebuild the graphs. A cache file left over
            # from an earlier, failed run would belong to different molecules.
            dataset = MoleculeCSVDataset(pred_data,
                                         smiles_to_graph=smiles_to_bigraph,
                                         node_featurizer=atom_featurizer,
                                         edge_featurizer=bond_featurizer,
                                         smiles_column='SMILES',
                                         cache_file_path=str(cache_path),
                                         task_names=label_cols,
                                         load=False, init_mask=True, n_jobs=1)
            if len(dataset) != len(pred_data):
                # Graphs and fingerprints are paired by position; refuse to
                # continue if the dataset skipped a molecule.
                raise ValueError("Some molecules could not be converted to graphs. Please check the input SMILES.")
            # Each dataset item is (smiles, graph, label[, mask]); only the
            # graph is needed for prediction.
            samples = [(dataset[i][1], pred_fp[i]) for i in range(len(dataset))]

            class GraphFingerprintsModel(nn.Module):
                """AttentiveFP graph encoder combined with a fingerprint MLP."""

                def __init__(self, node_feat_size, edge_feat_size, fp_size,
                             graph_feat_size=graph_feat_size, num_layers=2, num_timesteps=2,
                             n_tasks=4, dropout_g=0, dropout_f=0, dropout_l=0):
                    super().__init__()

                    self.gnn = AttentiveFPGNN(node_feat_size=node_feat_size,
                                              edge_feat_size=edge_feat_size,
                                              num_layers=num_layers,
                                              graph_feat_size=graph_feat_size,
                                              dropout=dropout_g)
                    self.readout = AttentiveFPReadout(feat_size=graph_feat_size,
                                                      num_timesteps=num_timesteps,
                                                      dropout=dropout_g)

                    self.fp_fc = nn.Sequential(
                        nn.Linear(fp_size, 256),
                        nn.ReLU(),
                        nn.Dropout(dropout_f),
                        nn.Linear(256, graph_feat_size)
                    )

                    self.predict = nn.Sequential(
                        nn.Dropout(dropout_l),
                        nn.Linear(graph_feat_size * 2, 128),
                        nn.ReLU(),
                        nn.Linear(128, n_tasks)
                    )

                def forward(self, g, node_feats, edge_feats, fingerprints):
                    if edge_feats is None or 'he' not in g.edata.keys():
                        # Fall back to all-zero edge features.
                        num_edges = g.number_of_edges()
                        edge_feats = torch.zeros((num_edges, e_feats)).to(g.device)
                    node_feats = self.gnn(g, node_feats, edge_feats)
                    graph_feats = self.readout(g, node_feats)
                    fp_feats = self.fp_fc(fingerprints)
                    combined_feats = torch.cat([graph_feats, fp_feats], dim=1)
                    return self.predict(combined_feats)

            def collate_fn(batch):
                """Batch (graph, fingerprint) pairs for the model."""
                graphs, fps = zip(*batch)
                return dgl.batch(graphs), torch.stack(fps)

            model = GraphFingerprintsModel(node_feat_size=n_feats,
                                           edge_feat_size=e_feats,
                                           graph_feat_size=graph_feat_size,
                                           num_layers=2,
                                           num_timesteps=2,
                                           fp_size=fp_size,
                                           n_tasks=n_tasks,
                                           dropout_g=dropout_g,
                                           dropout_f=dropout_f,
                                           dropout_l=dropout_l).to(device)
            model.load_state_dict(torch.load(model_path, map_location=device))
            # Inference mode: without eval() the dropout layers stay active
            # and randomly zero features during prediction.
            model.eval()

            test_loader = DataLoader(samples, batch_size=32, collate_fn=collate_fn)
            batch_predictions = []
            with torch.no_grad():
                for graphs, fps in test_loader:
                    graphs = graphs.to(device)
                    fps = fps.to(device)
                    predictions = model(graphs, graphs.ndata['hv'], graphs.edata['he'], fps)
                    batch_predictions.append(predictions.cpu().numpy())
            test_predictions = np.vstack(batch_predictions)

            # Undo the label standardization to get values in original units.
            test_scale_predictions = scaler.inverse_transform(test_predictions)
            print(test_scale_predictions)

            result_df = pd.DataFrame(test_scale_predictions, columns=['Ex', 'Em', 'QY', 'Log(AC)'])
            result_df[['Ex', 'Em']] = result_df[['Ex', 'Em']].round(2)
            result_df.to_csv('Results/pred-results.csv', index=False)

            messagebox.showinfo("Prediction complete", f"The prediction results have been saved to 'pred-results.csv'. Below is a preview of the results:\n{result_df.head().to_string(index=False)}")
        except Exception as e:
            messagebox.showerror("Error", f"An error occurred: {str(e)}")
        finally:
            # Remove the graph cache on every exit path so that it can never
            # be picked up by a later run.
            if cache_path.exists():
                try:
                    cache_path.unlink()
                    print(f"The file {cache_path} has been deleted.")
                except OSError as e:
                    print(f"An error occurred while deleting the file: {e}")

#######################################################################################################################
#######################################################################################################################
#######################################################################################################################
# Main Code

def main():
    """Create the Tk root window, show the welcome screen and run the event loop."""
    window = tk.Tk()
    # Bound to a local name so the welcome screen object stays referenced
    # for as long as the event loop runs.
    welcome = WelcomeWindow(window)
    window.mainloop()


if __name__ == "__main__":
    main()

  from .autonotebook import tqdm as notebook_tqdm


Filtering complete, results have been saved to Target_search.csv
Similarity calculation and data extraction are complete. The results have been saved to Similarity_search.csv
Image generation is complete.
c1ccc2c(c1)Cc1ccccc1O2
Molecules containing the substructure have been saved to the 'Sub_search.csv' file.
Image generation is complete.
use GPU
All molecules are valid molecules.
Processing dgl graphs from scratch...




The file pred_dataset.bin has been deleted.
[[6.0486652e+02 6.6004700e+02 1.6649152e-01 4.8799920e+00]]
