In [1]:
import duckdb
import pandas as pd
# import findspark
# findspark.init()
import pyspark
from pyspark.conf import SparkConf

from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
from pyspark.sql.types import ArrayType, FloatType, IntegerType, StructType, StructField, StringType, LongType
from pyspark.ml.feature import StringIndexer, OneHotEncoder, Word2Vec
from imblearn.over_sampling import SMOTE
import rdkit
from rdkit import Chem
from rdkit.Chem import AllChem

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import average_precision_score
# from sklearn.preprocessing import OneHotEncoder

import mlflow

import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader

import sklearn
import time
# import h2o
# from h2o.estimators import H2OGradientBoostingEstimator
# from h2o.frame import H2OFrame
# from pysparkling import H2OContext
# import pysparkling

In [2]:
# schema = StructType([
#     StructField("id", LongType(), True),
#     StructField("buildingblock1_smiles", StringType(), True),
#     StructField("buildingblock2_smiles", StringType(), True),
#     StructField("buildingblock3_smiles", StringType(), True),
#     StructField("molecule_smiles", StringType(), True),
#     StructField("protein_name", StringType(), True),
#     StructField("binds", LongType(), True)
# ])
# train_1 = spark.read.parquet("train_1.parquet", schema=schema)
# train_0 = spark.read.parquet("train_0.parquet", schema=schema)

In [3]:
def vector_to_array(v):
    return v.toArray().tolist()
vector_to_array_udf = udf(vector_to_array, ArrayType(FloatType()))

def smiles_to_fingerprint(smiles, radius = 3, nBits = 25):
    mol = Chem.MolFromSmiles(smiles)
    if mol:
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits) # 10, 15
        return [bit for bit in fp]
    else:
        return [0] * nBits

# Register the UDF
smiles_to_fingerprint_udf = udf(smiles_to_fingerprint, ArrayType(IntegerType()))

# Register the UDF using pandas_udf
@pandas_udf(ArrayType(IntegerType()), PandasUDFType.SCALAR)
def smiles_to_fingerprint_udf(smiles_series: pd.Series) -> pd.Series:
    return smiles_series.apply(smiles_to_fingerprint)



# ### Distinct Counts
# # buildingblock 1: 271
# # buildingblock 2: 693
# # buildingblock 3: 872
# # molecule_smiles: 29,656
# # 4 repeats for buildingblock triplets but they are binded to different target proteins



In [4]:
spark = SparkSession.builder.appName('belka') \
    .config("spark.executor.cores", "4") \
    .config("spark.executor.memory", "4g") \
    .config("spark.cores.max", "16") \
    .config("spark.executor.memoryOverhead", "1g") \
    .config("spark.driver.memoryOverhead", "1g") \
    .getOrCreate()
    # .config("spark.ext.h2o.backend.cluster.mode", "internal")\
    # .config("spark.executor.instances", "1")\
    # .config("spark.executor.memory", "2g")\
    # .config("spark.driver.memory", "2g")\
train_data = spark.read.parquet("train.parquet") # train_data.count() # 295,246,830
sample_data = train_data.sample(.0001)

train_1 = spark.read.parquet("train_1.parquet")

# pandas_sample = sample_data.toPandas()
# train_rdd = train_data.rdd
# h2o.init()
# hc = H2OContext.getOrCreate()

# bind1_data = train_data.where(F.col("binds") == 1)
# bind0_data = train_data.where(F.col("binds") == 0)

24/06/22 16:51:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


In [5]:
### Vectorize inputs with Morgan Fingerprint Bit Embedding
radius = 3
vector_size = 25
encoded_data = (
    train_1.withColumn("encoded_buildingblock1", smiles_to_fingerprint_udf(F.col("buildingblock1_smiles") ))
                .withColumn("encoded_buildingblock2", smiles_to_fingerprint_udf(F.col("buildingblock2_smiles") ))
                .withColumn("encoded_buildingblock3", smiles_to_fingerprint_udf(F.col("buildingblock3_smiles")))
                .withColumn("encoded_molecule_vector", smiles_to_fingerprint_udf(F.col("molecule_smiles")))
                .withColumn("encoded_sEH", F.when(F.col("protein_name") == "sEH", 1).otherwise(0))
                .withColumn("encoded_HSA", F.when(F.col("protein_name") == "HSA", 1).otherwise(0))
                .withColumn("encoded_BRD4", F.when(F.col("protein_name") == "BRD4", 1).otherwise(0))
)

encoded_data = encoded_data.select(*[(F.col("encoded_buildingblock1")[i]).alias(f"buildingblock1_{i}") for i in range(25)],
                   *[(F.col("encoded_buildingblock2")[i]).alias(f"buildingblock2_{i}") for i in range(25)],
                   *[(F.col("encoded_buildingblock3")[i]).alias(f"buildingblock3_{i}") for i in range(25)],
                   *[(F.col("encoded_molecule_vector")[i]).alias(f"molecule_{i}") for i in range(25)],
                   "encoded_sEH", "encoded_HSA", "encoded_BRD4", "binds").withColumn("row_id", F.monotonically_increasing_id())


# print(bind1_data.count()) # 1_589_906
# print(bind0_data.count()) # 293_656_924

# 293656924 / 1589906 # 184.70080872705682

In [6]:
sample_data = encoded_data.where(F.col("row_id").between(90 * 15899, (90 + 1) * 15899))
sample_pandas = sample_data.toPandas()
sample_pandas

24/06/22 16:52:00 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Unnamed: 0,buildingblock1_0,buildingblock1_1,buildingblock1_2,buildingblock1_3,buildingblock1_4,buildingblock1_5,buildingblock1_6,buildingblock1_7,buildingblock1_8,buildingblock1_9,...,molecule_20,molecule_21,molecule_22,molecule_23,molecule_24,encoded_sEH,encoded_HSA,encoded_BRD4,binds,row_id


In [22]:
# for i in range(10,101): # 10
#     start = time.time()
#     print(i)
#     sample_data = encoded_data.where(F.col("row_id").between(i * 15899, (i + 1) * 15899))
#     sample_data.write.mode("overwrite").parquet(f"train_1_encoded_shards/shard_{i}.parquet")
#     final_df = pd.concat([final_df, sample_data.toPandas()])
#     end = time.time()

#     print("Time: ", end - start)



10


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1135.8215065002441
11


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1119.677745103836
12


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1135.662543296814
13


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1155.328647851944
14


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1134.4454526901245
15


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  6865.955460548401
16


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  10510.841964006424
17


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  9102.605480194092
18


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  7371.999340772629
19


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  5526.186974525452
20


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1110.5066299438477
21


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1109.251111984253
22


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1994.456913948059
23


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1135.1038093566895
24


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1301.161777973175
25


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1128.7823162078857
26


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  4212.000355482101
27


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1113.4462685585022
28


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1132.5642528533936
29


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1168.1129925251007
30


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1201.5664229393005
31


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1205.1166591644287
32


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1195.6964123249054
33


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  4331.7315328121185
34


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1110.5847544670105
35


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1104.5225512981415
36


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  6106.799744844437
37


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1115.0264163017273
38


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1109.7100522518158
39


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1103.033879995346
40


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1106.4022641181946
41


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1108.8896322250366
42


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1108.1209032535553
43


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1109.966866493225
44


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1109.3919959068298
45


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1838.933402299881
46


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1109.9444992542267
47


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1107.3715753555298
48


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1104.2289423942566
49


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1113.8230459690094
50


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1108.8121361732483
51


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1106.781572818756
52


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1104.8064997196198
53


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1109.6672103404999
54


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1111.1917114257812
55


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1119.7706236839294
56


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  30906.950053453445
57


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1110.5252883434296
58


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1099.2023537158966
59


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  16487.03115606308
60


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  5866.04102563858
61


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1111.412573337555
62


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1107.9913268089294
63


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1120.854809999466
64


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1120.9390816688538
65


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1218.3782892227173
66


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1105.2769041061401
67


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1110.6209576129913
68


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1121.4725608825684
69


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1135.702567100525
70


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  2020.2108478546143
71


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1108.3789794445038
72


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1102.4425683021545
73


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1111.6059274673462
74


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1111.3015139102936
75


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1106.5981607437134
76


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1111.4516022205353
77


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1105.8392493724823
78


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1099.0650384426117
79


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1148.594084262848
80


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  2122.4433088302612
81


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1215.7850637435913
82


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1251.8876020908356
83


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1291.4766175746918
84


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  17597.65839934349
85


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  22401.201187372208
86


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  2215.68683218956
87


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1134.52690243721
88


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1131.0876786708832
89


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1144.6194996833801
90


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  5114.908013820648
91


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1149.761355638504
92


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1135.1712493896484
93


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1131.5196342468262
94


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  10498.966329574585
95


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1131.8907692432404
96


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1131.56200838089
97


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1148.216982126236
98


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1133.3615539073944
99


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1122.606997013092
100


  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series


Time:  1134.525916814804


In [17]:
full_train_1 = spark.read.parquet("train_1_encoded_full.parquet")

dummy_row = full_train_1.limit(1).collect()[0]
dummy_row = dummy_row.asDict()
dummy_row["binds"] = 0

In [21]:
sample_pd = full_train_1.sample(.1).toPandas()
sample_pd = sample_pd.append(dummy_row, ignore_index=True)
train_x = sample_pd.drop('binds', axis=1)
train_y = sample_pd['binds']

smote = SMOTE(sampling_strategy={1: 500_000}, k_neighbors=5)
smote_x, smote_y = smote.fit_resample(train_x, train_y)

smote_x['binds'] = smote_y
drop_index = smote_y[smote_y == 0]
smote_x.drop(drop_index)

  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  df[column_name] = series
  sample_pd = sample_pd.append(dummy_row, ignore_index=True)
  smote_x['binds'] = smote_y


Unnamed: 0,buildingblock1_0,buildingblock1_1,buildingblock1_2,buildingblock1_3,buildingblock1_4,buildingblock1_5,buildingblock1_6,buildingblock1_7,buildingblock1_8,buildingblock1_9,...,molecule_20,molecule_21,molecule_22,molecule_23,molecule_24,encoded_sEH,encoded_HSA,encoded_BRD4,row_id,binds
1,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,0,0,419600,1
2,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,0,0,139780,1
3,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,0,0,1,12258,1
4,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,0,0,233760,1
5,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,0,0,168819,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
499996,1,1,1,1,0,1,1,1,1,1,...,1,1,1,1,1,0,0,0,68719489960,1
499997,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,0,0,1,8589942399,1
499998,1,1,1,1,0,1,1,1,1,1,...,1,1,1,1,1,0,0,0,68719486977,1
499999,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,0,0,58190,1


In [22]:
drop_index

159094    0
Name: binds, dtype: int64

In [50]:
smote_x

  smote_x['binds'] = smote_y


Unnamed: 0,buildingblock1_0,buildingblock1_1,buildingblock1_2,buildingblock1_3,buildingblock1_4,buildingblock1_5,buildingblock1_6,buildingblock1_7,buildingblock1_8,buildingblock1_9,...,molecule_19,molecule_20,molecule_21,molecule_22,molecule_23,molecule_24,encoded_sEH,encoded_HSA,encoded_BRD4,binds
0,1,1,1,1,0,1,1,1,1,1,...,1,1,1,1,1,1,1,0,0,0
1,1,1,1,1,0,1,1,1,1,1,...,1,1,1,1,1,1,1,0,0,0
2,1,1,1,1,0,1,1,1,1,1,...,1,1,1,1,1,1,1,0,0,0
3,1,1,1,1,0,1,1,1,1,1,...,1,1,1,1,1,1,0,1,0,0
4,1,1,1,1,0,1,1,1,1,1,...,1,1,1,1,1,1,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
58487,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,0,0,0,1
58488,1,1,1,1,0,1,1,1,1,1,...,1,1,1,1,1,1,0,0,1,1
58489,1,1,1,1,0,1,1,1,1,1,...,1,1,1,1,1,1,0,0,1,1
58490,1,1,1,1,1,1,1,1,1,1,...,1,1,1,1,1,1,1,0,0,1


In [24]:
# ### Vectorize Inputs
# vector_size = 20

# word2Vec1 = Word2Vec(vectorSize=vector_size, minCount=0, inputCol="buildingblock1_array", outputCol="buildingblock1_vector")
# model1 = word2Vec1.fit(sample_train_1)
# result1 = model1.transform(sample_train_1)

# print("Done 1")

# word2Vec2 = Word2Vec(vectorSize=vector_size, minCount=0, inputCol="buildingblock2_array", outputCol="buildingblock2_vector")
# model2 = word2Vec2.fit(result1)
# result2 = model2.transform(result1)

# print("Done 2")

# word2Vec3 = Word2Vec(vectorSize=vector_size, minCount=0, inputCol="buildingblock3_array", outputCol="buildingblock3_vector")
# model3 = word2Vec3.fit(result2)
# result3 = model3.transform(result2)

# print("Done 3")

# word2Vec4 = Word2Vec(vectorSize=vector_size, minCount=0, inputCol="molecule_array", outputCol="molecule_vector")
# model4 = word2Vec4.fit(result3)
# result4 = model4.transform(result3)

# print("Done 4")



# vectorized_samples = result4.select('buildingblock1_vector', 'buildingblock2_vector', 'buildingblock3_vector', 
#                                     'molecule_vector', "encoded_sEH", "encoded_HSA", "encoded_BRD4", 'binds')

# array_samples =(
#     vectorized_samples
#     .withColumn("buildingblock1_vector", vector_to_array_udf(vectorized_samples["buildingblock1_vector"]))
#     .withColumn("buildingblock2_vector", vector_to_array_udf(vectorized_samples["buildingblock2_vector"]))
#     .withColumn("buildingblock3_vector", vector_to_array_udf(vectorized_samples["buildingblock3_vector"]))
#     .withColumn("molecule_vector", vector_to_array_udf(vectorized_samples["molecule_vector"]))
# )

# pandas_vector_df = array_samples.toPandas()

# ### Unpacking arrays
# rename_block1 = {i:f"buildingblock1_feature_{i}" for i in range(vector_size)}
# values_df = pd.DataFrame(pandas_vector_df['buildingblock1_vector'].tolist(), index=pandas_vector_df.index)
# result_df = pd.concat([pandas_vector_df.drop(columns=['buildingblock1_vector']), values_df], axis=1)
# result_df = result_df.rename(columns=rename_block1)

# rename_block2 = {i:f"buildingblock2_feature_{i}" for i in range(vector_size)}
# values_df = pd.DataFrame(result_df['buildingblock2_vector'].tolist(), index=result_df.index)
# result_df = pd.concat([result_df.drop(columns=['buildingblock2_vector']), values_df], axis=1)
# result_df = result_df.rename(columns=rename_block2)

# rename_block3 = {i:f"buildingblock3_feature_{i}" for i in range(vector_size)}
# values_df = pd.DataFrame(result_df['buildingblock3_vector'].tolist(), index=result_df.index)
# result_df = pd.concat([result_df.drop(columns=['buildingblock3_vector']), values_df], axis=1)
# result_df = result_df.rename(columns=rename_block3)

# rename_molecule = {i:f"molecule_feature_{i}" for i in range(vector_size)}
# values_df = pd.DataFrame(result_df['molecule_vector'].tolist(), index=result_df.index)
# result_df = pd.concat([result_df.drop(columns=['molecule_vector']), values_df], axis=1)
# result_df = result_df.rename(columns=rename_molecule)
# result_df

In [19]:
binds_distribution = train_data.select('binds').groupby('binds').count()
binds_distribution.show()
"""
Imbalanced labels. Most of the attempts do not bind at all.
0: 293656924
1: 1589906
"""

[Stage 22:>                                                       (0 + 14) / 28]

+-----+---------+
|binds|    count|
+-----+---------+
|    0|293656924|
|    1|  1589906|
+-----+---------+





'\nImbalanced labels. Most of the attempts do not bind at all.\n'

In [16]:
result_distribution = train_data.groupby(['protein_name', 'binds']).count()
result_distribution.show()
"""
sEH tends to be more likely to have a bind than the other two.
"""



+------------+-----+--------+
|protein_name|binds|   count|
+------------+-----+--------+
|         HSA|    0|98007200|
|         sEH|    1|  724532|
|         sEH|    0|97691078|
|        BRD4|    1|  456964|
|        BRD4|    0|97958646|
|         HSA|    1|  408410|
+------------+-----+--------+



