# Pre-processing: undersampling the negatives and building-block split


In [None]:
import os
import random
from random import sample

import polars as pl
import polars.selectors as cs

In [None]:
# Directory that holds train.parquet; the later cells write their outputs here too.
data_path = "/home/ec2-user/SageMaker/dataset/"

# Lazy scan: nothing is read from disk until a query is collected.
data_set = pl.scan_parquet(os.path.join(data_path, "train.parquet"))

# explain() only returns the query plan as a string, so it has to be printed
# to be visible when it is not the last expression of the cell.
print(data_set.explain(streaming=True))

# Distinct protein names; the per-protein sampling loop iterates over these.
unique_protein_names = data_set.select("protein_name").unique().collect()

## Sample the dataset per protein, keeping 5 negatives for every positive


In [None]:
# Number of negatives kept for every positive, and the seed that makes the
# random draw repeatable across notebook runs.
NEGATIVES_PER_POSITIVE = 5
SAMPLING_SEED = 42

# For each protein: keep all positives, draw a random subset of the negatives,
# and write the result to "<protein>.parquet" inside data_path.
for protein in unique_protein_names["protein_name"]:
    print(f"Processing protein {protein}")
    protein_data = data_set.filter(pl.col("protein_name") == protein)

    # Materialise the positives once; their row count sizes the negative sample,
    # so there is no need for a separate counting query over the same rows.
    positive_data = protein_data.filter(pl.col("binds") == 1).collect()
    negative_data = protein_data.filter(pl.col("binds") == 0).collect()

    # sample() draws without replacement and raises when asked for more rows
    # than exist, so cap the request at the number of available negatives.
    n_negatives = min(
        positive_data.height * NEGATIVES_PER_POSITIVE, negative_data.height
    )
    negative_data_sampled = negative_data.sample(n=n_negatives, seed=SAMPLING_SEED)

    combined_data = pl.concat([positive_data, negative_data_sampled])
    output_path = os.path.join(data_path, f"{protein}.parquet")
    combined_data.write_parquet(output_path)

## Concatenate dataframes per protein into a big train dataframe


In [None]:
# Read back the per-protein files using the same "<protein>.parquet" naming the
# sampling loop used when writing them. Hard-coded file names can drift from
# the protein names actually present in the data (including letter case, which
# matters on a case-sensitive filesystem).
per_protein_frames = [
    pl.scan_parquet(os.path.join(data_path, f"{protein}.parquet")).collect()
    for protein in unique_protein_names["protein_name"]
]

# Stack the per-protein tables vertically into one training table.
combined_data = pl.concat(per_protein_frames)

## One-hot encode the protein name


In [None]:
# File that receives the encoded, subsampled training table.
encoded_file = "train_subsampled.parquet"

# Expand protein_name into one indicator column per distinct protein.
combined_data = combined_data.to_dummies(columns="protein_name")

# Persist the encoded table, then show it for a quick visual check.
output_path = os.path.join(data_path, encoded_file)
combined_data.write_parquet(output_path)
print(combined_data)

# Create a validation set based on a building-block split


In [None]:
# Load the subsampled training table fully into memory.
subsampled_path = os.path.join(data_path, "train_subsampled.parquet")
train_subsampled = pl.scan_parquet(subsampled_path).collect()

In [None]:
# Share of the unique building blocks kept on the training side, and the seed
# that makes the split repeatable across notebook runs.
TRAIN_BB_FRACTION = 0.90
SPLIT_SEED = 42

# Get unique building blocks across the three building-block columns
unique_smiles = set(
    train_subsampled["buildingblock1_smiles"].to_list()
    + train_subsampled["buildingblock2_smiles"].to_list()
    + train_subsampled["buildingblock3_smiles"].to_list()
)

# Hold out 10% of the building blocks for validation
total_bbs = len(unique_smiles)
train_size = int(total_bbs * TRAIN_BB_FRACTION)

# Set iteration order for strings differs between interpreter sessions, so
# sort first and draw from a seeded generator to get the same split every run.
train_bbs = random.Random(SPLIT_SEED).sample(sorted(unique_smiles), train_size)
test_bbs = list(unique_smiles - set(train_bbs))

print("Train set size:", len(train_bbs))
print("Test set size:", len(test_bbs))

# Flag a row as validation when any of its building-block columns contains a
# held-out building block; the flag is stored in the boolean column "val".
train_subsampled = train_subsampled.with_columns(
    pl.any_horizontal(cs.contains("buildingblock").is_in(test_bbs)).alias("val")
)

In [None]:
# Count the rows per value of the "val" flag and express each count as a
# share of all rows.
val_flags = train_subsampled["val"]
train_val_counts = val_flags.value_counts()
n_rows = len(val_flags)
val_proportions = train_val_counts["count"] / n_rows
print(val_proportions)

In [None]:
# Split the table on the boolean "val" column: flagged rows go to validation,
# the remaining rows stay in training.
is_val = train_subsampled["val"]
val_df = train_subsampled.filter(is_val)
train_df = train_subsampled.filter(~is_val)

In [None]:
# Write the training side of the split to disk.
train_split_file = "train_subsampled_bb_split.parquet"
output_path = os.path.join(data_path, train_split_file)
train_df.write_parquet(output_path)

In [None]:
# Write the validation side of the split to disk.
val_split_file = "val_subsampled_bb_split.parquet"
output_path = os.path.join(data_path, val_split_file)
val_df.write_parquet(output_path)