# Load and clean the dataset

In [None]:
import pandas as pd
from rdkit import Chem
from rdkit.Chem import MACCSkeys, Descriptors,AllChem
from rdkit.Chem.MolStandardize import rdMolStandardize
from matplotlib.colors import ListedColormap
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import GroupShuffleSplit
import rdkit.Chem.Descriptors3D
from morfeus import xtb, read_xyz
import os

In [None]:
file_path = '20240621_Dataset_Raw_Exp.xlsx'

solubility_df = pd.read_excel(file_path, sheet_name='Solubility', engine='openpyxl')
solvents_df = pd.read_excel(file_path, sheet_name='Solvents', engine='openpyxl')
drugs_df = pd.read_excel(file_path, sheet_name='Drugs', engine='openpyxl')


In [None]:
solubility_df.shape

In [None]:
lab = solubility_df[solubility_df['DOI'] == 'Lab']
lab.shape

In [None]:
literature = solubility_df[solubility_df['DOI'] != 'Lab']
literature.shape

In [None]:
literature

In [None]:
def duplicate_removal(df):
    
    df['Mono solvent'] = 'No'

    df.loc[(df['Solvent_1_weight_fraction'] == 1) | (df['Solvent_1_mol_fraction'] == 1), 'Mono solvent'] = df['Solvent_1']
    df.loc[(df['Solvent_1_weight_fraction'] == 0) | (df['Solvent_1_mol_fraction'] == 0), 'Mono solvent'] = df['Solvent_2']

    df_no_duplicates = df[df['Mono solvent'] != 'No'].drop_duplicates(subset=['Drug', 'Mono solvent', 'Temperature (K)'])

    result_df = pd.concat([df[df['Mono solvent'] == 'No'], df_no_duplicates], ignore_index=True)
    
    result_df = result_df.drop(['Mono solvent'], axis = 1)

    return result_df

In [None]:
literature = duplicate_removal(literature)
literature.shape

In [None]:
literature['Solvent_1_Fraction'] = literature['Solvent_1_weight_fraction'].combine_first(literature['Solvent_1_mol_fraction'])
literature['Solvent_2_Fraction'] = 1 - literature['Solvent_1_Fraction']

In [None]:
literature

In [None]:
def check_and_swap(group):
    correlation_1 = group['Solvent_1_Fraction'].corr(group['Solubility (mol/mol)'])
    correlation_2 = group['Solvent_2_Fraction'].corr(group['Solubility (mol/mol)'])

    if correlation_1 < correlation_2:

        temp_solvent = group['Solvent_1'].copy()
        group['Solvent_1'] = group['Solvent_2']
        group['Solvent_2'] = temp_solvent


        group['Solvent_1_Fraction'] = 1 - group['Solvent_1_Fraction']
        

        if 'Solvent_1_weight_fraction' in group and group['Solvent_1_weight_fraction'].notna().all():
            group['Solvent_1_weight_fraction'] = 1 - group['Solvent_1_weight_fraction']

        if 'Solvent_1_mol_fraction' in group and group['Solvent_1_mol_fraction'].notna().all():
            group['Solvent_1_mol_fraction'] = 1 - group['Solvent_1_mol_fraction']

    return group


literature = literature.groupby(['Drug', 'Solvent_1', 'Solvent_2', 'Temperature (K)']).apply(check_and_swap).reset_index(drop=True)


In [None]:
literature

In [None]:
literature = literature.drop(['Solvent_1_Fraction', 'Solvent_2_Fraction'], axis = 1)

In [None]:
literature

In [None]:
duplicates_mask = literature.duplicated(subset=['Drug', 'Solvent_1', 'Solvent_1_weight_fraction', 'Solvent_1_mol_fraction', 'Solvent_2', 'Temperature (K)'], keep=False)

df_duplicates = literature[duplicates_mask]

df_duplicates = df_duplicates.reset_index(drop=True)

df_duplicates_sorted = df_duplicates.sort_values(by=['Drug', 'Solvent_1', 'Solvent_1_weight_fraction', 'Solvent_1_mol_fraction', 'Solvent_2', 'Temperature (K)']).reset_index(drop=True)

df_duplicates_sorted.shape[0]



# Generate features

## Tautomers, charge state, and diastereomers

In [None]:
def find_diastereomers(df,Compound):
    possible_diastereomers = []
    
    for index, row in df.iterrows():
        smiles = row['SMILES']
        #smiles = row['SMILES']
        drug_name = row[Compound]
        

        mol = Chem.MolFromSmiles(smiles)
        
        if mol:
            chiral_centers = Chem.FindMolChiralCenters(mol, includeUnassigned=True)
            

            if len(chiral_centers) > 1:
                possible_diastereomers.append(drug_name)
    
    return possible_diastereomers

In [None]:
possible_diastereomers = find_diastereomers(drugs_df, 'Drug')
print("possible_diastereomers:")
for drug in possible_diastereomers:
    print(drug)

In [None]:
possible_diastereomers = find_diastereomers(solvents_df, 'Solvent')
print("possible_diastereomers:")
for solvent in possible_diastereomers:
    print(solvent)

In [None]:
tautomer_enumerator = rdMolStandardize.TautomerEnumerator()
uncharger = rdMolStandardize.Uncharger()

def standardize_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    mol = tautomer_enumerator.Canonicalize(mol)
    mol = uncharger.uncharge(mol)
    return Chem.MolToSmiles(mol)

In [None]:
drugs_df['standardized_SMILES'] = drugs_df['SMILES'].apply(standardize_smiles)

In [None]:
solvents_df['standardized_SMILES'] = solvents_df['SMILES'].apply(standardize_smiles)

## XTB_features

In [None]:
def calculate_xtb_features(smile, max_tries=500):
    print('')
    print(smile)
    for attempt in range(max_tries):
        try:

            mol = Chem.MolFromSmiles(smile)
            mol_3d = Chem.AddHs(mol)
            AllChem.EmbedMolecule(mol_3d)

            AllChem.UFFOptimizeMolecule(mol_3d, maxIters=500000)
            xyz_file = "molecule.xyz"

            with open(xyz_file, "w") as f:
                f.write(Chem.MolToXYZBlock(mol_3d))
        
            atoms, coordinates = read_xyz(xyz_file)

            xtb_instance = xtb.XTB(atoms, coordinates)
            
            xtb_ea = xtb_instance.get_ea()
            xtb_global_descriptor = xtb_instance.get_global_descriptor('electrophilicity')
            xtb_homo = xtb_instance.get_homo()
            xtb_ip = xtb_instance.get_ip()
            xtb_lumo = xtb_instance.get_lumo()
            
            return [xtb_ea, xtb_global_descriptor, xtb_homo, xtb_ip, xtb_lumo]
        
        except Exception as e:
            if attempt%5==0:
                print(f"Attempt {attempt + 1} failed")
            if attempt == max_tries - 1:
                raise Exception(f"Failed to calculate XTB features after {max_tries} attempts.")

In [None]:
xtb_drug_features_list = []

for smile in drugs_df['standardized_SMILES']:
    xtb_drug_features = calculate_xtb_features(smile)
    xtb_drug_features_list.append(xtb_drug_features)
    xtb_drug_features_df = pd.DataFrame(xtb_drug_features_list, columns=['xtb_ea', 'xtb_global_descriptor', 'xtb_homo', 'xtb_ip', 'xtb_lumo'])

xtb_drug_features_df.to_excel("xtb_drug_features.xlsx")


In [None]:
xtb_solvent_features_list = []

for smile in solvents_df['standardized_SMILES']:
    xtb_solvent_features = calculate_xtb_features(smile)
    xtb_solvent_features_list.append(xtb_solvent_features)
    xtb_solvent_features_df = pd.DataFrame(xtb_solvent_features_list, columns=['xtb_ea', 'xtb_global_descriptor', 'xtb_homo', 'xtb_ip', 'xtb_lumo'])

xtb_solvent_features_df.to_excel("xtb_solvent_features.xlsx")


## Other features

In [None]:
def generate_feats(df, xtb_features):
    example = Chem.MolFromSmiles('C')
    example = Chem.AddHs(example)
    AllChem.EmbedMolecule(example)
    threeD_columns = [col for col in rdkit.Chem.Descriptors3D.CalcMolDescriptors3D(example)]
    
    maccs_list = []
    descriptors_list = []
    extra_list = []
    descriptors3d_list = []

    descriptor_names = [desc[0] for desc in Descriptors._descList]
    descriptor_functions = [desc[1] for desc in Descriptors._descList]

    for smile in df['standardized_SMILES']:

        mol = Chem.MolFromSmiles(smile)
        
        mol_3d = Chem.AddHs(mol)
        AllChem.EmbedMolecule(mol_3d)
        
        fingerprint = MACCSkeys.GenMACCSKeys(mol)
        maccs_list.append(list(fingerprint)[1:])
        
        descriptor_values = [func(mol) for func in descriptor_functions]
        descriptors_list.append(descriptor_values)
        
        volume = AllChem.ComputeMolVolume(mol_3d)
        AllChem.ComputeGasteigerCharges(mol_3d)
        dipole = sum(float(mol_3d.GetAtomWithIdx(i).GetProp('_GasteigerCharge')) * mol_3d.GetConformer().GetAtomPosition(i).x for i in range(mol_3d.GetNumAtoms()))
        extra_desc = [volume, dipole]
        extra_list.append(extra_desc)

        descriptors3d_values = rdkit.Chem.Descriptors3D.CalcMolDescriptors3D(mol_3d)
        descriptors3d_list.append(descriptors3d_values)


    maccs_df = pd.DataFrame(maccs_list, columns=[f'Maccs_{i}' for i in range(1, 167)])
    descriptors_df = pd.DataFrame(descriptors_list, columns=[name for name in descriptor_names])
    extra_desc_df = pd.DataFrame(extra_list, columns=['Volume', 'Dipole Moment'])

    descriptors3d_df = pd.DataFrame(descriptors3d_list, columns=threeD_columns)


    combined_df = pd.concat([df, maccs_df, descriptors_df], axis=1)

    return combined_df


In [None]:
xtb_drug_features_df = pd.read_excel("xtb_drug_features.xlsx")

drug_feats = generate_feats(drugs_df, xtb_drug_features_df)


In [None]:
xtb_solvent_features_df = pd.read_excel("xtb_solvent_features.xlsx")
solvent_feats = generate_feats(solvents_df, xtb_solvent_features_df)
solvent_feats

## Add features to the dataset

In [None]:
def enhance_solubility_data(solubility_df, drugs_df, solvents_df, drug_features, solvent_features):

    drug_col_rename_map = {feature: f"Drug_{feature}" for feature in drug_features}
    selected_drug_features = drugs_df[drug_features + ['Drug']]
    renamed_drug_features = selected_drug_features.rename(columns=drug_col_rename_map)
    enhanced_df = solubility_df.merge(renamed_drug_features, on='Drug', how='left')

    solvent_1_col_rename_map = {feature: f"Solvent_1_{feature}" for feature in solvent_features}
    solvent_1_col_rename_map['Solvent'] = 'Solvent_1'
    selected_solvent_1_features = solvents_df[solvent_features + ['Solvent']].rename(columns=solvent_1_col_rename_map)
    enhanced_df = enhanced_df.merge(selected_solvent_1_features, left_on='Solvent_1', right_on='Solvent_1', how='left')

    solvent_2_col_rename_map = {feature: f"Solvent_2_{feature}" for feature in solvent_features}
    solvent_2_col_rename_map['Solvent'] = 'Solvent_2'
    selected_solvent_2_features = solvents_df[solvent_features + ['Solvent']].rename(columns=solvent_2_col_rename_map)
    enhanced_df = enhanced_df.merge(selected_solvent_2_features, left_on='Solvent_2', right_on='Solvent_2', how='left')

    return enhanced_df


In [None]:
maccs_feats = [f'Maccs_{i}' for i in range(1, 167)]
rdkit_feats = [desc[0] for desc in Descriptors._descList]
xtb_feats = ['xtb_ea', 'xtb_global_descriptor', 'xtb_homo', 'xtb_ip', 'xtb_lumo']
extra_feats = ['Volume', 'Dipole Moment']

example = Chem.MolFromSmiles('C')
example = Chem.AddHs(example)
AllChem.EmbedMolecule(example)
threeD_feats = [col for col in rdkit.Chem.Descriptors3D.CalcMolDescriptors3D(example)]




calculated_feats = maccs_feats + rdkit_feats


drug_feats_to_extract = ['Collected_Melting_temp (K)', 'Predicted_Melting_temp (K)', 'Drugs@FDA', 'SMILES'] + calculated_feats 
solvent_feats_to_extract = ['Collected_Melting_temp (K)']  + calculated_feats 

In [None]:
enhanced_literature = enhance_solubility_data(literature, drug_feats, solvent_feats, drug_feats_to_extract, solvent_feats_to_extract)
enhanced_literature.shape



In [None]:
enhanced_lab = enhance_solubility_data(lab, drug_feats, solvent_feats, drug_feats_to_extract, solvent_feats_to_extract)
enhanced_lab.shape



# Unit conversion

In [None]:
def calculate_fractions(df):
    for index, row in df.iterrows():

        molecular_weight_1 = row['Solvent_1_ExactMolWt']
        molecular_weight_2 = row['Solvent_2_ExactMolWt']

        if pd.isna(row['Solvent_1_mol_fraction']) and not pd.isna(row['Solvent_1_weight_fraction']):
            weight_fraction_1 = float(row['Solvent_1_weight_fraction'])
            weight_fraction_2 = 1 - weight_fraction_1
            mole_fraction_1 = (weight_fraction_1 / molecular_weight_1) / ((weight_fraction_1 / molecular_weight_1) + (weight_fraction_2 / molecular_weight_2))
            df.at[index, 'Solvent_1_mol_fraction'] = mole_fraction_1

        elif not pd.isna(row['Solvent_1_mol_fraction']) and pd.isna(row['Solvent_1_weight_fraction']):
            mole_fraction_1 = float(row['Solvent_1_mol_fraction'])
            mole_fraction_2 = 1 - mole_fraction_1
            weight_fraction_1 = (mole_fraction_1 * molecular_weight_1) / ((mole_fraction_1 * molecular_weight_1) + (mole_fraction_2 * molecular_weight_2))
            df.at[index, 'Solvent_1_weight_fraction'] = weight_fraction_1
            
    return df


enhanced_literature = calculate_fractions(enhanced_literature)
enhanced_lab = calculate_fractions(enhanced_lab)



In [None]:
def calculate_logs(df_updated):
    
    total_moles = 1

    df_updated['mol0'] = total_moles * df_updated['Solubility (mol/mol)']
    df_updated['mol1'] = (total_moles - df_updated['mol0']) * df_updated['Solvent_1_mol_fraction']
    df_updated['mol2'] = (total_moles - df_updated['mol0']) * (1 - df_updated['Solvent_1_mol_fraction'])
    df_updated['total_mol'] = df_updated['mol0'] + df_updated['mol1'] + df_updated['mol2']
    print(df_updated[['total_mol']].describe())

    df_updated['mass0'] = df_updated['mol0'] * df_updated['Drug_ExactMolWt']
    df_updated['mass1'] = df_updated['mol1'] * df_updated['Solvent_1_ExactMolWt']
    df_updated['mass2'] = df_updated['mol2'] * df_updated['Solvent_2_ExactMolWt']
    df_updated['total_mass'] = df_updated['mass0'] + df_updated['mass1'] + df_updated['mass2']

    df_updated['Solubility (g/g)'] = df_updated['mass0'] / df_updated['total_mass']
    df_updated['Solubility (g/100g)'] = df_updated['Solubility (g/g)'] * 100
    df_updated['LogS'] = np.log10(df_updated['Solubility (g/100g)'])

    df_updated = df_updated.drop(['mol0', 'mol1', 'mol2', 'mass0','mass1','mass2','total_mass','total_mol','Solubility (g/g)'], axis = 1)
    
    return df_updated

In [None]:
updated_literature = calculate_logs(enhanced_literature)


In [None]:
updated_lab = calculate_logs(enhanced_lab)


In [None]:
columns_with_missing_values = updated_lab.columns[updated_lab.isna().any()] + updated_lab.columns[updated_lab.isna().any()]
len(columns_with_missing_values)


# Train/test split

In [None]:
quantiles = updated_literature['LogS'].quantile([0.00, 0.25, 0.50, 0.75, 1.00])

def determine_class(x):
    if x <= quantiles[0.25]:
        return f"[{round(quantiles[0.00],2)}, {round(quantiles[0.25],2)}]"
    elif x <= quantiles[0.5]:
        return f"({round(quantiles[0.25],2)}, {round(quantiles[0.5],2)}]"
    elif x <= quantiles[0.75]:
        return f"({round(quantiles[0.5],2)}, {round(quantiles[0.75],2)}]"
    else:
        return f"({round(quantiles[0.75],2)}, {round(quantiles[1.0],2)}]"


updated_literature['Class'] = updated_literature['LogS'].apply(determine_class)

updated_lab['Class'] = None

print(round(quantiles[0.00],2),round(quantiles[0.25],2),round(quantiles[0.5],2),round(quantiles[0.75],2), round(quantiles[1.00],2),)

In [None]:
def drug_solvent_system(row):

    sorted_solvents = sorted([row['Solvent_1'], row['Solvent_2']])

    return f"{row['Drug']}-{sorted_solvents[0]}/{sorted_solvents[1]}"

updated_literature['Drug-solvent system'] = updated_literature.apply(drug_solvent_system, axis=1)
updated_lab['Drug-solvent system'] = updated_lab.apply(drug_solvent_system, axis=1)

In [None]:
def solvent_system(row):

    sorted_solvents = sorted([row['Solvent_1'], row['Solvent_2']])

    return f"{sorted_solvents[0]}/{sorted_solvents[1]}"

updated_literature['Solvent system'] = updated_literature.apply(solvent_system, axis=1)
updated_lab['Solvent system'] = updated_lab.apply(solvent_system, axis=1)

In [None]:
updated_literature.reset_index(drop=True, inplace=True)

gss = GroupShuffleSplit(n_splits=1, test_size=0.25, random_state=19680611)


for train_idx, test_idx in gss.split(updated_literature, groups=updated_literature['Drug-solvent system']):
    updated_literature.loc[train_idx, 'Type'] = 'Train'
    updated_literature.loc[test_idx, 'Type'] = 'Test'

    
train_subset = updated_literature[updated_literature['Type'] == 'Train']
test_subset = updated_literature[updated_literature['Type'] == 'Test']
updated_lab['Type'] = 'Lab'

In [None]:
train_subset.shape

In [None]:
test_subset.shape

In [None]:
updated_lab.shape

In [None]:
overlapping_groups = set(train_subset['Drug-solvent system']).intersection(set(test_subset['Drug-solvent system']))
len(overlapping_groups)

In [None]:
def analyze_solute_data(df):
    
    unique_solute = df['Drug'].nunique()
    
    
    unique_solvents = pd.concat([df['Solvent_1'], df['Solvent_2']]).nunique()

    
    
    unique_solvent_systems = df['Solvent system'].nunique()
    
    
    
    unique_drug_solvent_systems = df['Drug-solvent system'].nunique()

    unique_DOIs = df['DOI'].nunique()

    
    print(f"Number of unique solute: {unique_solute}")
    print(f"Number of unique solvents: {unique_solvents}")
    print(f"Number of unique solvent systems: {unique_solvent_systems}")
    print(f"Number of unique drug-solvent systems: {unique_drug_solvent_systems}")
    print(f"Number of paper: {unique_DOIs}")
    print(f"Number of data entries: {df.shape[0]}")



In [None]:
analyze_solute_data(updated_literature)

In [None]:
analyze_solute_data(train_subset)

In [None]:
analyze_solute_data(test_subset)

In [None]:
analyze_solute_data(updated_lab)

# Plot

In [None]:
def plot_stacked_bar_chart(ann, ax, df_no_outliers, fontsize = 10):
    
    df = df_no_outliers.copy()

    
    solubility_order = ['[-5.48, -0.43]', '(-0.43, 0.36]', '(0.36, 1.0]', '(1.0, 1.99]']
      

    
    quantile_edges = df['Temperature (K)'].quantile([0, 0.25, 0.5, 0.75, 1]).values
    quantile_labels = [f"{quantile_edges[i]:.2f}K to {quantile_edges[i+1]:.2f}K" for i in range(len(quantile_edges)-1)]

    
    df['Temp Quantile Bin'] = pd.cut(df['Temperature (K)'], bins=quantile_edges, labels=quantile_labels, include_lowest=True)

    
    grouped_by_solubility_and_quantile = df.groupby(['Class', 'Temp Quantile Bin']).size().unstack()
    grouped_quantile_ordered = grouped_by_solubility_and_quantile.reindex(solubility_order)

    
    lightest = "#D3D4D9"
    light = "#A8A9B2"
    medium = "#515265"
    dark = "#3D3E4C"
    
    lightest = "gainsboro"
    light = "darkgrey"
    medium = "gray"
    dark = "#3D3E4C"

    
    new_colors = [lightest, light, medium, dark]
    new_cmap = ListedColormap(new_colors)

   
    grouped_quantile_ordered.plot(kind='bar', stacked=True, colormap=new_cmap, ax=ax)

    
    column_totals = grouped_quantile_ordered.sum(axis=1)

   
    for i, bar in enumerate(ax.patches):
        bar_height = bar.get_height()
        bar_base = bar.get_y()
        class_index = i // len(quantile_labels)
        percentage = (bar_height / column_totals.iloc[class_index]) * 100

        font_color = 'white' if bar_base + bar_height > column_totals.iloc[class_index] * 0.7 else 'black'

        if percentage > 0:
            ax.annotate(f'{percentage:.1f}%', 
                        (bar.get_x() + bar.get_width() / 2, bar_base + bar_height / 2),
                        ha='center', va='center', xytext=(0, 5), textcoords='offset points',
                        color=font_color, fontsize = fontsize)


    ax.set_xlabel('Log S (g/100g)', fontsize = fontsize)
    ax.set_ylabel('Number of data points', fontsize = fontsize)
    ax.tick_params(axis='both', which='major', labelsize=fontsize, color='black', length=5)

    handles, labels = ax.get_legend_handles_labels()
    number_of_columns = len(quantile_labels) // 2 + (len(quantile_labels) % 2 > 0)  
    legend = ax.legend(handles, quantile_labels, title='Temperature', 
                       bbox_to_anchor=(0.5, 1.12), loc='upper center', ncol=number_of_columns, 
                       frameon=False, fontsize=fontsize)
    plt.setp(legend.get_title(), fontsize=fontsize)  
    
    ax.set_facecolor('none')
    
    for spine in ['top', 'right', 'bottom', 'left']:
        ax.spines[spine].set_color('black')
        ax.spines[spine].set_linewidth(0.5)

    plt.tight_layout(rect=[0, 0, 1, 0.95])
    ax.set_xticklabels(ax.get_xticklabels(), rotation=0)
    ax.annotate(ann, xy=(0.0, 1.04), xycoords="axes fraction", va="top", ha="left", fontsize = fontsize)


In [None]:
black = '#515265'
red = '#DD706E'
yellow = '#FAAF3A'
blue = '#3A93C2'

def plot_boxplot(ann, ax, df, cols, feature_name=None, fontsize=14):

    custom_palette = {"Train": blue, "Test": red}


    if len(cols) == 2:
        melted_df = pd.melt(df, id_vars='Type', value_vars=cols, var_name='Variable', value_name='Value')
        sns.boxplot(x='Type', y='Value', data=melted_df, palette=custom_palette, ax=ax)
        col_to_plot = 'Value'
    elif len(cols) == 1:
        sns.boxplot(x='Type', y=cols[0], data=df, palette=custom_palette, ax=ax)
        col_to_plot = cols[0]
    else:
        raise ValueError("Please provide one or two column names.")
    
    title = feature_name if feature_name else col_to_plot
    ax.set_ylabel(title, fontsize=fontsize)
    ax.set_xlabel('')
    

    for spine in ['top', 'right', 'bottom', 'left']:
        ax.spines[spine].set_color('black')
        ax.spines[spine].set_linewidth(0.5)


    ax.set_facecolor('none')  
    ax.tick_params(axis='y', which='both', length=5, color='black')  
    ax.grid(False)

    ax.annotate(ann, xy=(0, 1.08), xycoords="axes fraction", va="top", ha="left", fontsize=fontsize)


    ax.xaxis.label.set_size(fontsize)
    ax.yaxis.label.set_size(fontsize)
    

    ax.tick_params(axis='both', which='major', labelsize=fontsize)

    ax.tick_params(axis='y', which='both', labelsize=fontsize)


In [None]:
black = '#515265'
red = '#DD706E'
yellow = '#FAAF3A'
blue = '#3A93C2'

def plot_violin(ann, ax, dataset, cols, feature_name=None, fontsize=14):
    
    df = dataset.copy()
    custom_palette = {"Train": blue, "Test": red}
    if len(cols) == 2:
        data_to_plot = df[cols].values.flatten()
    elif len(cols) == 1:
        data_to_plot = df[cols[0]].values
    else:
        raise ValueError("Please provide one or two column names.")
    

    df['Type'] = pd.Categorical(df['Type'], categories=["Train", "Test"], ordered=True)

    stats = df.groupby('Type')[cols[0]].quantile([0.25, 0.5, 0.75]).unstack()
    stats.columns = ['Q1', 'Q2', 'Q3']
    print()
    print(cols)
    print(stats)
    print()

    # Create a split violin plot
    sns.violinplot(x='Type', y=cols[0], data=df, palette=custom_palette, ax=ax, split=True, inner="quartile", cut=0)
    
    title = feature_name if feature_name else cols[0]
    ax.set_ylabel(feature_name, fontsize=fontsize)
    ax.set_xlabel('', fontsize=fontsize)
    
    for spine in ['top', 'right', 'bottom', 'left']:
        ax.spines[spine].set_color('black')
        ax.spines[spine].set_linewidth(0.5)
    
    ax.set_facecolor('none')
    ax.tick_params(axis='y', which='both', length=5, color='black')
    ax.grid(False)
    
    ax.annotate(ann, xy=(0, 1.08), xycoords="axes fraction", va="top", ha="left", fontsize=fontsize)
    
    ax.xaxis.label.set_size(fontsize)
    ax.yaxis.label.set_size(fontsize)
    
    ax.tick_params(axis='both', which='major', labelsize=fontsize)
    ax.tick_params(axis='y', which='both', labelsize=fontsize)

In [None]:
fig = plt.figure(figsize=(18, 12))

grid = plt.GridSpec(2, 4, wspace=0.4, hspace=0.2)


ax1 = fig.add_subplot(grid[0:2, 0:2])
plot_stacked_bar_chart('a)',ax1, updated_literature, fontsize = 14)


ax2 = fig.add_subplot(grid[0, 2])
plot_violin('b)',ax2, updated_literature, ['Temperature (K)'], 'Temperature (K)')


ax3 = fig.add_subplot(grid[0, 3])
plot_violin('c)',ax3, updated_literature, ['Drug_ExactMolWt'], 'Solute_MW (Da)')


ax4 = fig.add_subplot(grid[1, 2])
plot_violin('d)',ax4, updated_literature, ['Solvent_1_Collected_Melting_temp (K)','Solvent_2_Collected_Melting_temp (K)'], 'Solvent_MP (K)')


ax5 = fig.add_subplot(grid[1, 3])
plot_violin('e)',ax5, updated_literature, ['LogS'], 'LogS (g/100g)')


fig.patch.set(facecolor='none')


fig.canvas.print_figure('Figure_2_Dataset_overview.png', dpi=600)


plt.show()


In [None]:
updated_literature = updated_literature.rename(columns={
    'Solvent_1_mol_fraction': 'Solvent_mol_fraction',
    'Solvent_1_weight_fraction': 'Solvent_mass_fraction'
})

In [None]:
updated_lab = updated_lab.rename(columns={
    'Solvent_1_mol_fraction': 'Solvent_mol_fraction',
    'Solvent_1_weight_fraction': 'Solvent_mass_fraction'
})

In [None]:
all_feat_df = pd.concat([updated_literature, updated_lab], ignore_index=True)
all_feat_df.shape


In [None]:
all_feat_df

In [None]:
all_feat_df[['Drug_ExactMolWt', 'Drug_MolLogP']].describe()

In [None]:
all_feat_df = all_feat_df.drop(['Web of Science Index','Solubility (mol/mol)', 'DOI', 'Solvent system','Drug_Drugs@FDA'], axis = 1)



In [None]:
all_feat_df.to_csv('Raw_dataset_dataset_20240705.csv', index=False)

In [None]:
all_feat_df['Type']