# Data splitting

This notebook generates the train-test splits, both with and without the SMOTE oversampling technique.

In [1]:
import os
from collections import Counter
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE

In [2]:
import warnings

warnings.filterwarnings(action="ignore", category=pd.errors.PerformanceWarning)

# Load combined dataset

In [3]:
# Load one combined table per fingerprint type, keyed by fingerprint name
# (insertion order follows the tuple below).
fingerprint_df_dict = {
    fingerprint_name: pd.read_csv(
        f"../data/fingerprints/combined_{fingerprint_name}.tsv", sep="\t"
    )
    for fingerprint_name in ("ecfp4", "rdkit", "maccs", "mhfp6", "erg", "chem_phys")
}

len(fingerprint_df_dict)

6

# Split the dataset into train and test sets

The chosen split ratio is 80-20 (80% training, 20% testing). Because the SMOTE step is run on integer-encoded labels, the classes are converted into integers as follows:
* gram-negative - 1
* gram-positive - 2
* acid-fast - 3
* fungi - 4
* inactive - 5

Because the classes are highly imbalanced, we apply SMOTE to rebalance them. SMOTE is applied to the training split only; the test split is saved untouched, so evaluation is done on real compounds.

In [4]:
# Root output folder; each sample type gets its own sub-folder below it.
SPLITS_DIR = "../data/splits"
os.makedirs(name=SPLITS_DIR, exist_ok=True)

In [5]:
def smote_base_sampling(
    df: pd.DataFrame,
    name: str,
    sample_type: str,
    labels: dict,
    test_size: float = 0.2,
    random_state: int = 42,
) -> None:
    """Write a stratified train/test split of ``df`` plus a SMOTE-balanced train set.

    Files are written to ``{SPLITS_DIR}/{sample_type}/``:

    * ``{name}_train.csv`` / ``{name}_test.csv``: the original stratified split.
    * ``{name}_smote_train.csv``: the training split oversampled with SMOTE.
      The ``cmp_id`` column is dropped before resampling, so it is not present
      in this file.

    Args:
        df: Frame with a ``label`` column, a ``cmp_id`` column and feature columns.
        name: Prefix for the output file names (e.g. the fingerprint name).
        sample_type: Sub-folder of ``SPLITS_DIR`` to write into (created if needed).
        labels: Mapping from each label value to the integer code used during
            resampling. Must be one-to-one so the codes can be mapped back.
        test_size: Fraction of rows held out for testing (default 0.2 -> 80/20).
        random_state: Seed for both the split and SMOTE, so outputs are reproducible.

    Raises:
        ValueError: If ``labels`` maps two labels to the same code, or if the
            training split contains label values that are not keys of ``labels``.
    """
    sample_dir = f"{SPLITS_DIR}/{sample_type}"
    os.makedirs(sample_dir, exist_ok=True)

    # The codes are inverted later to restore label names; a non-injective
    # mapping would silently merge classes.
    if len(set(labels.values())) != len(labels):
        raise ValueError(f"labels must map to distinct codes, got {labels}")

    print(f"Processing {name} dataset")

    # Stratify so the label distribution is the same in both splits.
    train, test = train_test_split(
        df,
        test_size=test_size,
        random_state=random_state,
        shuffle=True,
        stratify=df["label"],
    )

    # Save the original (non-resampled) splits.
    train.to_csv(f"{sample_dir}/{name}_train.csv", index=False)
    test.to_csv(f"{sample_dir}/{name}_test.csv", index=False)

    print("Original dataset shape %s" % Counter(train["label"]))

    # Encode labels into a separate Series rather than assigning back into
    # `train` (a slice of `df`, which may trigger SettingWithCopyWarning).
    y_train = train["label"].map(labels)
    unmapped = set(train.loc[y_train.isna(), "label"])
    if unmapped:
        # Without this check the NaNs would only surface later as an opaque
        # error from SMOTE.
        raise ValueError(
            f"Labels not present in the mapping: {sorted(map(str, unmapped))}"
        )
    X_train = train.drop(columns=["label", "cmp_id"])

    # Oversample the minority classes of the training split only.
    sm = SMOTE(random_state=random_state)
    smote_sampled_train, smote_sampled_labels = sm.fit_resample(X_train, y_train)

    # Map the integer codes back to the original label names.
    inverse_labels = {code: label for label, code in labels.items()}
    smote_sampled_train["label"] = smote_sampled_labels
    smote_sampled_train["label"] = smote_sampled_train["label"].map(inverse_labels)

    print("SMOTE dataset shape %s" % Counter(smote_sampled_train["label"]))

    # Save the SMOTE-balanced training split.
    smote_sampled_train.to_csv(f"{sample_dir}/{name}_smote_train.csv", index=False)
    print("\n")

In [6]:
# Integer codes for the five combined classes. Defined once, since the mapping
# does not depend on the fingerprint being processed.
COMBINED_LABELS = {
    "gram-negative": 1,
    "gram-positive": 2,
    "acid-fast": 3,
    "fungi": 4,
    "inactive": 5,
}

for fingerprint_name, df in tqdm(fingerprint_df_dict.items(), desc="combined"):
    smote_base_sampling(
        df, fingerprint_name, sample_type="combined", labels=COMBINED_LABELS
    )

  0%|          | 0/6 [00:00<?, ?it/s]

Processing ecfp4 dataset
Original dataset shape Counter({'gram-positive': 21148, 'inactive': 15654, 'gram-negative': 9083, 'fungi': 7631, 'acid-fast': 5845})
SMOTE dataset shape Counter({'gram-positive': 21148, 'inactive': 21148, 'fungi': 21148, 'gram-negative': 21148, 'acid-fast': 21148})


 17%|█▋        | 1/6 [00:28<02:23, 28.65s/it]



Processing rdkit dataset
Original dataset shape Counter({'gram-positive': 21148, 'inactive': 15654, 'gram-negative': 9083, 'fungi': 7631, 'acid-fast': 5845})
SMOTE dataset shape Counter({'gram-positive': 21148, 'inactive': 21148, 'fungi': 21148, 'gram-negative': 21148, 'acid-fast': 21148})


 33%|███▎      | 2/6 [00:57<01:54, 28.66s/it]



Processing maccs dataset
Original dataset shape Counter({'gram-positive': 21148, 'inactive': 15654, 'gram-negative': 9083, 'fungi': 7631, 'acid-fast': 5845})
SMOTE dataset shape Counter({'gram-positive': 21148, 'inactive': 21148, 'fungi': 21148, 'gram-negative': 21148, 'acid-fast': 21148})


 50%|█████     | 3/6 [01:05<00:58, 19.37s/it]



Processing mhfp6 dataset
Original dataset shape Counter({'gram-positive': 21148, 'inactive': 15654, 'gram-negative': 9083, 'fungi': 7631, 'acid-fast': 5845})
SMOTE dataset shape Counter({'gram-positive': 21148, 'inactive': 21148, 'fungi': 21148, 'gram-negative': 21148, 'acid-fast': 21148})


 67%|██████▋   | 4/6 [02:44<01:41, 50.58s/it]



Processing erg dataset
Original dataset shape Counter({'gram-positive': 21148, 'inactive': 15654, 'gram-negative': 9083, 'fungi': 7631, 'acid-fast': 5845})
SMOTE dataset shape Counter({'gram-positive': 21148, 'inactive': 21148, 'fungi': 21148, 'gram-negative': 21148, 'acid-fast': 21148})


 83%|████████▎ | 5/6 [03:00<00:38, 38.13s/it]



Processing chem_phys dataset
Original dataset shape Counter({'gram-positive': 21148, 'inactive': 15654, 'gram-negative': 9083, 'fungi': 7631, 'acid-fast': 5845})
SMOTE dataset shape Counter({'gram-positive': 21148, 'inactive': 21148, 'fungi': 21148, 'gram-negative': 21148, 'acid-fast': 21148})


100%|██████████| 6/6 [03:01<00:00, 30.23s/it]








# Pathogen-class-specific training files

In [7]:
# Binary labels shared by every pathogen-class file. Defined once, since they
# do not change between iterations.
PATHOGEN_LABELS = {
    "active": 1,
    "inactive": 0,
}

for pathogen_class in ["gram-negative", "gram-positive", "acid-fast", "fungi"]:
    # Only the keys (fingerprint names) of `fingerprint_df_dict` are used here;
    # the per-class fingerprint tables are read from their own files.
    # `desc` labels each progress bar with its pathogen class.
    for fingerprint_name in tqdm(fingerprint_df_dict, desc=pathogen_class):
        pathogen_df = pd.read_csv(
            f"../data/fingerprints/{pathogen_class}_{fingerprint_name}.tsv", sep="\t"
        )

        smote_base_sampling(
            pathogen_df,
            fingerprint_name,
            sample_type=pathogen_class,
            labels=PATHOGEN_LABELS,
        )

  0%|          | 0/6 [00:00<?, ?it/s]

Processing ecfp4 dataset
Original dataset shape Counter({'inactive': 23532, 'active': 9266})
SMOTE dataset shape Counter({'inactive': 23532, 'active': 23532})


 17%|█▋        | 1/6 [00:12<01:00, 12.17s/it]



Processing rdkit dataset
Original dataset shape Counter({'inactive': 23532, 'active': 9266})
SMOTE dataset shape Counter({'inactive': 23532, 'active': 23532})


 33%|███▎      | 2/6 [00:25<00:50, 12.66s/it]



Processing maccs dataset
Original dataset shape Counter({'inactive': 23532, 'active': 9266})
SMOTE dataset shape Counter({'inactive': 23532, 'active': 23532})


 50%|█████     | 3/6 [00:28<00:24,  8.18s/it]



Processing mhfp6 dataset
Original dataset shape Counter({'inactive': 23532, 'active': 9266})
SMOTE dataset shape Counter({'inactive': 23532, 'active': 23532})


 67%|██████▋   | 4/6 [01:15<00:47, 23.83s/it]



Processing erg dataset
Original dataset shape Counter({'inactive': 23532, 'active': 9266})
SMOTE dataset shape Counter({'inactive': 23532, 'active': 23532})


 83%|████████▎ | 5/6 [01:23<00:17, 17.85s/it]



Processing chem_phys dataset
Original dataset shape Counter({'inactive': 23532, 'active': 9266})
SMOTE dataset shape Counter({'inactive': 23532, 'active': 23532})


100%|██████████| 6/6 [01:23<00:00, 13.94s/it]






  0%|          | 0/6 [00:00<?, ?it/s]

Processing ecfp4 dataset
Original dataset shape Counter({'inactive': 22854, 'active': 14024})
SMOTE dataset shape Counter({'inactive': 22854, 'active': 22854})


 17%|█▋        | 1/6 [00:15<01:15, 15.19s/it]



Processing rdkit dataset
Original dataset shape Counter({'inactive': 22854, 'active': 14024})
SMOTE dataset shape Counter({'inactive': 22854, 'active': 22854})


 33%|███▎      | 2/6 [00:31<01:03, 15.77s/it]



Processing maccs dataset
Original dataset shape Counter({'inactive': 22854, 'active': 14024})
SMOTE dataset shape Counter({'inactive': 22854, 'active': 22854})


 50%|█████     | 3/6 [00:35<00:31, 10.54s/it]



Processing mhfp6 dataset
Original dataset shape Counter({'inactive': 22854, 'active': 14024})
SMOTE dataset shape Counter({'inactive': 22854, 'active': 22854})


 67%|██████▋   | 4/6 [01:31<00:56, 28.27s/it]



Processing erg dataset
Original dataset shape Counter({'inactive': 22854, 'active': 14024})
SMOTE dataset shape Counter({'inactive': 22854, 'active': 22854})


 83%|████████▎ | 5/6 [01:38<00:20, 20.90s/it]



Processing chem_phys dataset
Original dataset shape Counter({'inactive': 22854, 'active': 14024})
SMOTE dataset shape Counter({'inactive': 22854, 'active': 22854})


100%|██████████| 6/6 [01:39<00:00, 16.61s/it]






  0%|          | 0/6 [00:00<?, ?it/s]

Processing ecfp4 dataset
Original dataset shape Counter({'inactive': 6105, 'active': 3166})
SMOTE dataset shape Counter({'active': 6105, 'inactive': 6105})


 17%|█▋        | 1/6 [00:03<00:15,  3.09s/it]



Processing rdkit dataset
Original dataset shape Counter({'inactive': 6105, 'active': 3166})
SMOTE dataset shape Counter({'active': 6105, 'inactive': 6105})


 33%|███▎      | 2/6 [00:06<00:12,  3.17s/it]



Processing maccs dataset
Original dataset shape Counter({'inactive': 6105, 'active': 3166})


 50%|█████     | 3/6 [00:06<00:06,  2.00s/it]

SMOTE dataset shape Counter({'active': 6105, 'inactive': 6105})


Processing mhfp6 dataset
Original dataset shape Counter({'inactive': 6105, 'active': 3166})
SMOTE dataset shape Counter({'active': 6105, 'inactive': 6105})


 67%|██████▋   | 4/6 [00:19<00:12,  6.08s/it]



Processing erg dataset
Original dataset shape Counter({'inactive': 6105, 'active': 3166})
SMOTE dataset shape Counter({'active': 6105, 'inactive': 6105})


100%|██████████| 6/6 [00:21<00:00,  3.53s/it]




Processing chem_phys dataset
Original dataset shape Counter({'inactive': 6105, 'active': 3166})
SMOTE dataset shape Counter({'active': 6105, 'inactive': 6105})




  0%|          | 0/6 [00:00<?, ?it/s]

Processing ecfp4 dataset
Original dataset shape Counter({'inactive': 11025, 'active': 3426})
SMOTE dataset shape Counter({'active': 11025, 'inactive': 11025})


 17%|█▋        | 1/6 [00:05<00:26,  5.30s/it]



Processing rdkit dataset
Original dataset shape Counter({'inactive': 11025, 'active': 3426})
SMOTE dataset shape Counter({'active': 11025, 'inactive': 11025})


 33%|███▎      | 2/6 [00:10<00:21,  5.42s/it]



Processing maccs dataset
Original dataset shape Counter({'inactive': 11025, 'active': 3426})
SMOTE dataset shape Counter({'active': 11025, 'inactive': 11025})


 50%|█████     | 3/6 [00:11<00:09,  3.31s/it]



Processing mhfp6 dataset
Original dataset shape Counter({'inactive': 11025, 'active': 3426})
SMOTE dataset shape Counter({'active': 11025, 'inactive': 11025})


 67%|██████▋   | 4/6 [00:33<00:21, 10.64s/it]



Processing erg dataset
Original dataset shape Counter({'inactive': 11025, 'active': 3426})
SMOTE dataset shape Counter({'active': 11025, 'inactive': 11025})


 83%|████████▎ | 5/6 [00:36<00:08,  8.00s/it]



Processing chem_phys dataset
Original dataset shape Counter({'inactive': 11025, 'active': 3426})
SMOTE dataset shape Counter({'active': 11025, 'inactive': 11025})


100%|██████████| 6/6 [00:37<00:00,  6.18s/it]






