# Data splitting

This notebook generates the train-test splits for every fingerprint dataset, both with and without the SMOTE oversampling technique.

In [1]:
import os
from collections import Counter
import pandas as pd
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from imblearn.over_sampling import SMOTE

In [2]:
import warnings

warnings.filterwarnings(action="ignore", category=pd.errors.PerformanceWarning)

# Load dataset

In [3]:
# Feature sets to load; each one is read from
# ../data/fingerprints/combined_<name>.tsv (tab-separated).
fingerprint_names = ["ecfp4", "rdkit", "maccs", "mhfp6", "erg", "chem_phys"]

# A comprehension keeps the loop variable out of the notebook namespace.
fingerprint_df_dict = {
    fingerprint_name: pd.read_csv(
        os.path.join("../data/fingerprints", f"combined_{fingerprint_name}.tsv"),
        sep="\t",
    )
    for fingerprint_name in fingerprint_names
}

len(fingerprint_df_dict)

6

# Split the dataset into train-test

The split ratio chosen for this purpose is 80-20 (80% training, 20% testing), stratified by label so both splits keep the same class proportions. Before applying SMOTE, the class labels are encoded as integers:
* gram-negative - 1
* gram-positive - 2
* acid-fast - 3
* fungi - 4

Because the classes are highly imbalanced, we apply SMOTE to the training split only, to rebalance the classes; the test split is left untouched so that no synthetic samples reach the evaluation data.

In [4]:
# Output folder for all split CSVs; re-running the cell is harmless.
os.makedirs(os.path.join("..", "data", "splits"), exist_ok=True)

In [5]:
def smote_base_sampling(
    df: pd.DataFrame,
    name: str,
    output_dir: str = "../data/splits",
    test_size: float = 0.2,
    random_state: int = 42,
) -> None:
    """Create a stratified train/test split and a SMOTE-balanced training set.

    Writes three CSV files into ``output_dir``:

    * ``{name}_train.csv``: the original (not resampled) training split.
    * ``{name}_test.csv``: the held-out test split, which is never resampled.
    * ``{name}_smote_train.csv``: the SMOTE-resampled training split. Its
      features exclude ``cmp_id``, so this file has the feature columns plus
      ``label`` and no ``cmp_id`` column.

    Args:
        df: Dataset with a ``label`` column (one of ``gram-negative``,
            ``gram-positive``, ``acid-fast``, ``fungi``) and a ``cmp_id``
            column. All other columns are used as features for SMOTE.
        name: Prefix for the output file names (e.g. the fingerprint name).
        output_dir: Directory the CSV files are written to; created if missing.
        test_size: Fraction of rows held out for testing (0.2 -> 80/20 split).
        random_state: Seed used for both the split and SMOTE.

    Raises:
        ValueError: If ``df["label"]`` contains a value outside the four
            known classes (checked before any file is written).
    """
    print(f"\n Processing {name} dataset")

    # Single source of truth for the class encoding; the inverse is derived.
    label_to_int = {"gram-negative": 1, "gram-positive": 2, "acid-fast": 3, "fungi": 4}
    int_to_label = {code: label for label, code in label_to_int.items()}

    # Fail early and clearly: an unmapped label would otherwise become NaN and
    # only break inside SMOTE, after the split files were already written.
    unknown_labels = set(df["label"].unique()) - set(label_to_int)
    if unknown_labels:
        raise ValueError(
            f"Unexpected labels in {name} dataset: {sorted(map(str, unknown_labels))}"
        )

    os.makedirs(output_dir, exist_ok=True)

    # Stratify so both splits keep the same label distribution as `df`.
    train, test = train_test_split(
        df,
        test_size=test_size,
        random_state=random_state,
        shuffle=True,
        stratify=df["label"],
    )

    # Save the original (non-resampled) splits.
    train.to_csv(os.path.join(output_dir, f"{name}_train.csv"), index=False)
    test.to_csv(os.path.join(output_dir, f"{name}_test.csv"), index=False)

    print("Original dataset shape %s" % Counter(train["label"]))

    # Encode the labels without assigning into `train`: it is a subset of `df`,
    # and writing into it may raise SettingWithCopyWarning.
    # NOTE(review): imblearn also accepts string targets. The integer codes are
    # kept because changing the encoding could change the order in which SMOTE
    # processes the classes, and hence the synthetic samples it generates.
    X_train = train.drop(columns=["label", "cmp_id"])
    y_train = train["label"].map(label_to_int)

    # Oversample the training split only; the test split stays untouched.
    sm = SMOTE(random_state=random_state)
    smote_sampled_train, smote_sampled_labels = sm.fit_resample(X_train, y_train)

    # Decode back to class names. `.to_numpy()` makes the assignment positional,
    # so it does not depend on the index attached to SMOTE's outputs.
    smote_sampled_train["label"] = (
        pd.Series(smote_sampled_labels).map(int_to_label).to_numpy()
    )

    print("SMOTE dataset shape %s" % Counter(smote_sampled_train["label"]))

    # Save the SMOTE-resampled training split.
    smote_sampled_train.to_csv(
        os.path.join(output_dir, f"{name}_smote_train.csv"), index=False
    )

In [6]:
# Produce the plain and SMOTE splits for every loaded fingerprint set.
for fp_name, fp_df in tqdm(fingerprint_df_dict.items()):
    smote_base_sampling(df=fp_df, name=fp_name)

  0%|          | 0/6 [00:00<?, ?it/s]


 Processing ecfp4 dataset
Original dataset shape Counter({'gram-positive': 27805, 'gram-negative': 12406, 'fungi': 11744, 'acid-fast': 7406})
SMOTE dataset shape Counter({'fungi': 27805, 'gram-negative': 27805, 'acid-fast': 27805, 'gram-positive': 27805})


 17%|█▋        | 1/6 [00:27<02:15, 27.02s/it]


 Processing rdkit dataset
Original dataset shape Counter({'gram-positive': 27805, 'gram-negative': 12406, 'fungi': 11744, 'acid-fast': 7406})
SMOTE dataset shape Counter({'fungi': 27805, 'gram-negative': 27805, 'acid-fast': 27805, 'gram-positive': 27805})


 33%|███▎      | 2/6 [00:54<01:50, 27.52s/it]


 Processing maccs dataset
Original dataset shape Counter({'gram-positive': 27805, 'gram-negative': 12406, 'fungi': 11744, 'acid-fast': 7406})
SMOTE dataset shape Counter({'fungi': 27805, 'gram-negative': 27805, 'acid-fast': 27805, 'gram-positive': 27805})


 50%|█████     | 3/6 [01:02<00:54, 18.22s/it]


 Processing mhfp6 dataset
Original dataset shape Counter({'gram-positive': 27805, 'gram-negative': 12406, 'fungi': 11744, 'acid-fast': 7406})
SMOTE dataset shape Counter({'fungi': 27805, 'gram-negative': 27805, 'acid-fast': 27805, 'gram-positive': 27805})


 67%|██████▋   | 4/6 [02:42<01:41, 50.67s/it]


 Processing erg dataset
Original dataset shape Counter({'gram-positive': 27805, 'gram-negative': 12406, 'fungi': 11744, 'acid-fast': 7406})
SMOTE dataset shape Counter({'fungi': 27805, 'gram-negative': 27805, 'acid-fast': 27805, 'gram-positive': 27805})


 83%|████████▎ | 5/6 [02:58<00:38, 38.21s/it]


 Processing chem_phys dataset
Original dataset shape Counter({'gram-positive': 27805, 'gram-negative': 12406, 'fungi': 11744, 'acid-fast': 7406})
SMOTE dataset shape Counter({'fungi': 27805, 'gram-negative': 27805, 'acid-fast': 27805, 'gram-positive': 27805})


100%|██████████| 6/6 [02:59<00:00, 29.98s/it]
