In [None]:
#imports
from rdkit import Chem
from rdkit.Chem import AllChem
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# t-SNE Chemical Data Split

In [None]:
def tsne_split(data, rs=42, radius=2, n_bits=32):
    """
    Splits the input data into training and testing sets based on a t-SNE embedding.

    Every SMILES string is turned into a Morgan fingerprint bit vector, the
    fingerprints are embedded in two dimensions with t-SNE, and each row is
    assigned to the training set when its second t-SNE coordinate is greater
    than 0 (to the testing set otherwise).

    Parameters:
    - data (pd.DataFrame): A DataFrame containing a 'SMILES' column and a binary 'activity' column.
    - rs (int, optional): Random state for TSNE. Defaults to 42.
    - radius (int, optional): Morgan fingerprint radius. Defaults to 2.
    - n_bits (int, optional): Length of the fingerprint bit vector. Defaults to 32.

    Returns:
    - tuple: (train_data, test_data, tsne_coords_df). The first two are row subsets of
      `data`; the third holds the columns 'tSNE_1', 'tSNE_2', 'activity' and 'set'
      ('train' or 'test') in the same row order as `data`, with a fresh default index.

    Raises:
    - KeyError: If the 'SMILES' or 'activity' column is missing.
    - ValueError: If `data` has no rows or contains SMILES that cannot be parsed.
    """
    # Fail fast: check both required columns before the expensive fingerprint / t-SNE work,
    # so a missing 'activity' column is not discovered only after the embedding has run.
    missing = [col for col in ('SMILES', 'activity') if col not in data.columns]
    if missing:
        raise KeyError(f"tsne_split requires column(s) {missing} in the input DataFrame")
    if len(data) == 0:
        raise ValueError("tsne_split received an empty DataFrame; nothing to split")

    # Parse SMILES first. MolFromSmiles returns None when parsing fails, and non-string
    # entries (e.g. NaN) cannot be parsed at all, so collect both and report them together.
    mols = []
    invalid = []
    for position, sm in enumerate(data['SMILES']):
        mol = Chem.MolFromSmiles(sm) if isinstance(sm, str) else None
        if mol is None:
            invalid.append((position, sm))
        mols.append(mol)
    if invalid:
        preview = ", ".join(f"row {pos}: {sm!r}" for pos, sm in invalid[:5])
        raise ValueError(
            f"{len(invalid)} SMILES value(s) could not be parsed (showing up to 5): {preview}"
        )

    # Convert molecules to MorganFingerprint bit vectors
    bitvectors = np.array(
        [AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits) for mol in mols]
    )

    # Perform t-SNE
    tsne = TSNE(random_state=rs, init='pca')
    bit_tsne = tsne.fit_transform(bitvectors)

    # Split into training and testing based on the sign of the t-SNE y-axis.
    # NOTE(review): the intent is a roughly 50/50 split; the actual ratio depends on
    # the embedding, so check it (e.g. with activity_counter) after splitting.
    train_indices = bit_tsne[:, 1] > 0
    test_indices = ~train_indices
    train_data = data.iloc[train_indices]
    test_data = data.iloc[test_indices]

    # Create DataFrame for t-SNE coordinates and assign sets.
    # `.values` copies the labels positionally, ignoring the index of `data`.
    tsne_coords_df = pd.DataFrame(bit_tsne, columns=["tSNE_1", "tSNE_2"])
    tsne_coords_df['activity'] = data['activity'].values
    tsne_coords_df['set'] = np.where(train_indices, 'train', 'test')

    return train_data, test_data, tsne_coords_df


def activity_counter(curr_df):
    """
    Prints a short class-balance summary for a DataFrame.

    The number of actives is the sum of the 'activity' column; the number of
    inactives is the row count minus that sum.

    Parameters:
    - curr_df (pd.DataFrame): DataFrame containing the binary 'activity' column.
    """
    n_rows = len(curr_df)
    n_active = curr_df['activity'].sum()
    n_inactive = n_rows - n_active

    # One report line per quantity, emitted with a single print call.
    report_lines = (
        f"Total samples: {n_rows}",
        f"Actives: {n_active}",
        f"Inactives: {n_inactive}",
    )
    print("\n".join(report_lines))


### Import Data

In [None]:
# Load the bioassay table. tsne_split reads its 'SMILES' and 'activity' columns,
# so the CSV must provide both. Replace the placeholder path with your own file.
bioassay_data_df = pd.read_csv('your/path/here.csv')

### Apply t-SNE Split

In [None]:
# Returns the training rows, the testing rows, and a frame of t-SNE coordinates
# with 'activity' and 'set' labels (used below for plotting).
train_df, test_df, viz_df = tsne_split(bioassay_data_df, rs=42)

In [None]:
# Class balance (total / actives / inactives) of the training subset.
activity_counter(train_df)

In [None]:
# Class balance (total / actives / inactives) of the testing subset.
activity_counter(test_df)

In [None]:
# Optional: uncomment the lines below to save train_df and test_df as CSV files
# train_df.to_csv('train.csv')
# test_df.to_csv('test.csv')

# Visualization of the Split

In [None]:
def visualize_split(curr_df):
    """
    Draws a scatter plot of the t-SNE embedding showing the train/test split.

    Points are plotted in four groups, one per combination of set and activity:
    the marker shape distinguishes train from test, and the colour distinguishes
    inactive (0) from active (1). A dotted red horizontal line is drawn at y = 0.

    Parameters:
    - curr_df (pd.DataFrame): DataFrame containing the columns 'tSNE_1', 'tSNE_2',
      'activity' and 'set'.
    """
    plt.figure(figsize=[10, 8], dpi=300)

    # (value in 'set', marker, label suffix) and (value in 'activity', colour, label prefix).
    # The nesting order below yields: Inactive Train, Active Train, Inactive Test, Active Test.
    set_styles = (('train', 'd', 'Train'), ('test', 'o', 'Test'))
    activity_styles = ((0, 'm', 'Inactive'), (1, 'c', 'Active'))

    for set_name, marker_shape, set_label in set_styles:
        in_set = curr_df['set'] == set_name
        for activity_value, point_color, activity_label in activity_styles:
            group_mask = (curr_df['activity'] == activity_value) & in_set
            sns.scatterplot(
                x='tSNE_1',
                y='tSNE_2',
                data=curr_df[group_mask],
                color=point_color,
                marker=marker_shape,
                s=50,
                alpha=1,
                label=f'{activity_label} {set_label}',
            )

    # Reference line at y = 0 on the second t-SNE axis.
    plt.axhline(y=0, color='red', linestyle='dotted')
    plt.legend(title='Data Type', title_fontsize='13', fontsize='11', loc='best', frameon=True, shadow=True)
    plt.title('Train-Test t-SNE Chemical Split', fontsize=16, fontweight='bold')
    plt.xlabel('t-SNE Dimension 1', fontsize=14)
    plt.ylabel('t-SNE Dimension 2', fontsize=14)
    plt.grid(True, which='both', linestyle='--', linewidth=0.5)
    plt.tight_layout()
    plt.show()


In [None]:
# Plot the embedding returned by tsne_split, grouped by set and activity.
visualize_split(viz_df)