In [1]:
from pathlib import Path

from autogluon.tabular import TabularDataset, TabularPredictor
from mordred import Calculator, descriptors
import pandas as pd
from rdkit.Chem import MolFromSmiles

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Working directory for the input CSV files and the cached parquet tables.
# It is created on first run; an already-existing directory is left as is.
data_dir = Path() / "data"
data_dir.mkdir(parents=False, exist_ok=True)

In [3]:
# Shared mordred descriptor calculator, built from the imported `descriptors`
# module and constructed with ignore_3D=True. Reused by every cell that
# computes descriptors from RDKit molecules.
calc = Calculator(descriptors, ignore_3D=True)

In [None]:
# Training features are cached as parquet: load the cache when it exists,
# otherwise build it from the AqSolDBc CSV and write it out for next time.
train_file = data_dir / "train.parquet"
if train_file.exists():
    train_df = pd.read_parquet(train_file)
else:
    train_data = pd.read_csv(data_dir / "AqSolDBc.csv")
    # Parse every curated SMILES string into an RDKit molecule and keep only
    # the rows for which a molecule object was obtained (non-null rdkit_mol).
    train_data = train_data.assign(rdkit_mol=train_data["SmilesCurated"].apply(MolFromSmiles))
    train_data = train_data[train_data["rdkit_mol"].notna()]
    # Descriptor table for the surviving molecules, with mordred's missing
    # markers replaced via fill_missing().
    train_descs: pd.DataFrame = calc.pandas(train_data["rdkit_mol"]).fill_missing()
    # Target column first, descriptor columns after it.
    train_df = pd.concat([train_data[["ExperimentalLogS"]], train_descs], axis=1)
    train_df.to_parquet(train_file)

In [None]:
# Test features are cached as parquet: load the cache when it exists,
# otherwise build it from the OChemUnseen CSV and write it out for next time.
test_file = data_dir / "test.parquet"
if test_file.exists():
    test_df = pd.read_parquet(test_file)
else:
    test_data = pd.read_csv(data_dir / "OChemUnseen.csv")
    # Parse every SMILES string into an RDKit molecule and keep only the rows
    # for which a molecule object was obtained (non-null rdkit_mol).
    test_data = test_data.assign(rdkit_mol=test_data["SMILES"].apply(MolFromSmiles))
    test_data = test_data[test_data["rdkit_mol"].notna()]
    # Descriptor table for the surviving molecules, with mordred's missing
    # markers replaced via fill_missing().
    test_descs: pd.DataFrame = calc.pandas(test_data["rdkit_mol"]).fill_missing()
    # Target column first, descriptor columns after it.
    test_df = pd.concat([test_data[["LogS"]], test_descs], axis=1)
    test_df.to_parquet(test_file)

In [21]:
# Wrap the feature tables for AutoGluon.
# The predictor is trained with label="ExperimentalLogS", but the test table
# stores its target under "LogS". Rename it here so predictor.evaluate() can
# find the label column. DataFrame.rename ignores keys that are absent, so
# this is a no-op when the column has already been renamed.
train_data = TabularDataset(train_df)
test_data = TabularDataset(test_df.rename(columns={"LogS": "ExperimentalLogS"}))

In [None]:
# Configure the AutoGluon predictor for the "ExperimentalLogS" target (with
# file logging enabled), then fit it on the training table using one GPU.
# `predictor` is bound only once fitting has returned.
unfitted_predictor = TabularPredictor(label="ExperimentalLogS", log_to_file=True)
predictor = unfitted_predictor.fit(train_data, num_gpus=1)

In [None]:
# Evaluate the fitted predictor on the test table; whatever evaluate()
# returns is shown as this cell's output.
predictor.evaluate(test_data)

In [4]:
# Recompute the test-set descriptors straight from the OChemUnseen CSV
# (no parquet cache involved) so the raw descriptor table can be inspected.
test_data = pd.read_csv(data_dir / "OChemUnseen.csv")
# Parse SMILES into RDKit molecules and keep only rows with a molecule object.
test_data = test_data.assign(rdkit_mol=test_data["SMILES"].apply(MolFromSmiles))
test_data = test_data[test_data["rdkit_mol"].notna()]
test_descs: pd.DataFrame = calc.pandas(test_data["rdkit_mol"]).fill_missing()

[12:02:35] Explicit valence for atom # 1 P, 6, is greater than permitted
100%|██████████| 2250/2250 [00:41<00:00, 54.59it/s] 
  t[t.applymap(is_missing)] = value


In [5]:
# Display the descriptor table as produced by calc.pandas(...).fill_missing().
test_descs

Unnamed: 0,ABC,ABCGG,nAcid,nBase,SpAbs_A,SpMax_A,SpDiam_A,SpAD_A,SpMAD_A,LogEE_A,...,SRW10,TSRW10,MW,AMW,WPath,WPol,Zagreb1,Zagreb2,mZagreb1,mZagreb2
0,4.680200,4.942478,0,3,6.828427,2.0,4.0,6.828427,0.97549,2.765108,...,7.655864,32.211905,105.040702,8.753392,48,4,26.0,24.0,4.472222,1.666667
1,9.397338,9.404508,0,0,15.191508,2.44949,4.898979,15.191508,1.168578,3.466567,...,9.687009,44.780241,184.084792,7.363392,220,23,64.0,77.0,6.395833,3.083333
2,6.542301,6.236096,1,0,11.189957,2.193993,4.387987,11.189957,1.243329,3.089765,...,8.590258,37.289972,121.029503,8.644964,88,9,40.0,43.0,3.472222,2.111111
3,6.692130,6.867470,0,0,11.069268,2.069782,4.139564,11.069268,1.106927,3.123657,...,8.086103,37.158205,142.099380,5.920807,142,8,38.0,37.0,5.222222,2.500000
4,6.910910,7.103173,0,0,10.184668,2.101003,4.202006,10.184668,1.018467,3.132559,...,8.258163,37.645869,142.099380,5.920807,135,8,40.0,39.0,5.833333,2.333333
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2246,43.292641,32.789840,0,0,72.640956,2.463758,4.916238,72.640956,1.320745,4.943034,...,10.950508,110.700274,755.411880,6.994554,10908,97,296.0,356.0,16.972222,12.083333
2247,44.040187,32.322816,0,0,73.870449,2.460699,4.913694,73.870449,1.319115,4.958474,...,10.962926,110.351740,766.416631,6.967424,11620,99,300.0,359.0,17.222222,12.305556
2248,43.292641,32.789840,0,0,72.640956,2.463758,4.916238,72.640956,1.320745,4.943034,...,10.950508,110.700274,754.391479,7.184681,10908,97,296.0,356.0,16.972222,12.083333
2249,5.656854,6.174783,0,0,10.491415,2.052881,4.105762,10.491415,1.165713,3.008468,...,7.914618,35.304688,134.094294,5.830187,102,8,32.0,32.0,4.361111,2.500000


In [8]:
# Coerce every descriptor column to a numeric dtype; entries that cannot be
# parsed as numbers become NaN (errors="coerce").
test_descs = test_descs.apply(lambda column: pd.to_numeric(column, errors="coerce"))

In [9]:
# Display the descriptor table again, now after the numeric coercion step.
test_descs

Unnamed: 0,ABC,ABCGG,nAcid,nBase,SpAbs_A,SpMax_A,SpDiam_A,SpAD_A,SpMAD_A,LogEE_A,...,SRW10,TSRW10,MW,AMW,WPath,WPol,Zagreb1,Zagreb2,mZagreb1,mZagreb2
0,4.680200,4.942478,0,3,6.828427,2.000000,4.000000,6.828427,0.975490,2.765108,...,7.655864,32.211905,105.040702,8.753392,48,4,26.0,24.0,4.472222,1.666667
1,9.397338,9.404508,0,0,15.191508,2.449490,4.898979,15.191508,1.168578,3.466567,...,9.687009,44.780241,184.084792,7.363392,220,23,64.0,77.0,6.395833,3.083333
2,6.542301,6.236096,1,0,11.189957,2.193993,4.387987,11.189957,1.243329,3.089765,...,8.590258,37.289972,121.029503,8.644964,88,9,40.0,43.0,3.472222,2.111111
3,6.692130,6.867470,0,0,11.069268,2.069782,4.139564,11.069268,1.106927,3.123657,...,8.086103,37.158205,142.099380,5.920807,142,8,38.0,37.0,5.222222,2.500000
4,6.910910,7.103173,0,0,10.184668,2.101003,4.202006,10.184668,1.018467,3.132559,...,8.258163,37.645869,142.099380,5.920807,135,8,40.0,39.0,5.833333,2.333333
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2246,43.292641,32.789840,0,0,72.640956,2.463758,4.916238,72.640956,1.320745,4.943034,...,10.950508,110.700274,755.411880,6.994554,10908,97,296.0,356.0,16.972222,12.083333
2247,44.040187,32.322816,0,0,73.870449,2.460699,4.913694,73.870449,1.319115,4.958474,...,10.962926,110.351740,766.416631,6.967424,11620,99,300.0,359.0,17.222222,12.305556
2248,43.292641,32.789840,0,0,72.640956,2.463758,4.916238,72.640956,1.320745,4.943034,...,10.950508,110.700274,754.391479,7.184681,10908,97,296.0,356.0,16.972222,12.083333
2249,5.656854,6.174783,0,0,10.491415,2.052881,4.105762,10.491415,1.165713,3.008468,...,7.914618,35.304688,134.094294,5.830187,102,8,32.0,32.0,4.361111,2.500000


In [32]:
# Per-descriptor (column-wise) mean and sample standard deviation.
# DataFrame.std replaces the hand-rolled sqrt-of-variance; it uses the same
# default ddof=1 as DataFrame.var, so the statistic is unchanged.
means = test_descs.mean(axis=0)
stdevs = test_descs.std(axis=0)
stdevs

ABC          9.256814
ABCGG        6.812518
nAcid        0.757646
nBase        0.895059
SpAbs_A     15.043430
              ...    
WPol        20.989663
Zagreb1     62.169606
Zagreb2     73.318118
mZagreb1     4.869810
mZagreb2     2.672769
Length: 1613, dtype: float64

In [26]:
# A statistic that came out as NaN (e.g. a column whose values are all NaN)
# is replaced by 0 for both the mean and the standard deviation.
means = means.fillna(0.0)
stdevs = stdevs.fillna(0.0)

In [34]:
# Clip each descriptor column to mean ± 3 standard deviations. The clipped
# frame is only displayed here; test_descs itself is not reassigned.
lower_bound = means - 3 * stdevs
upper_bound = means + 3 * stdevs
test_descs.clip(lower=lower_bound, upper=upper_bound, axis=1)

Unnamed: 0,ABC,ABCGG,nAcid,nBase,SpAbs_A,SpMax_A,SpDiam_A,SpAD_A,SpMAD_A,LogEE_A,...,SRW10,TSRW10,MW,AMW,WPath,WPol,Zagreb1,Zagreb2,mZagreb1,mZagreb2
0,4.680200,4.942478,0.0,3.0,6.828427,2.000000,4.000000,6.828427,0.978147,2.765108,...,7.655864,32.211905,105.040702,8.753392,48.0,4.000000,26.000000,24.0000,4.472222,1.666667
1,9.397338,9.404508,0.0,0.0,15.191508,2.449490,4.898979,15.191508,1.168578,3.466567,...,9.687009,44.780241,184.084792,7.363392,220.0,23.000000,64.000000,77.0000,6.395833,3.083333
2,6.542301,6.236096,1.0,0.0,11.189957,2.193993,4.387987,11.189957,1.243329,3.089765,...,8.590258,37.289972,121.029503,8.644964,88.0,9.000000,40.000000,43.0000,3.472222,2.111111
3,6.692130,6.867470,0.0,0.0,11.069268,2.069782,4.139564,11.069268,1.106927,3.123657,...,8.086103,37.158205,142.099380,5.920807,142.0,8.000000,38.000000,37.0000,5.222222,2.500000
4,6.910910,7.103173,0.0,0.0,10.184668,2.101003,4.202006,10.184668,1.018467,3.132559,...,8.258163,37.645869,142.099380,5.920807,135.0,8.000000,40.000000,39.0000,5.833333,2.333333
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2246,41.327102,32.223296,0.0,0.0,67.409494,2.463758,4.916238,67.409494,1.320745,4.943034,...,10.950508,106.965762,755.411880,6.994554,10908.0,88.884099,275.540818,322.0508,16.972222,12.083333
2247,41.327102,32.223296,0.0,0.0,67.409494,2.460699,4.913694,67.409494,1.319115,4.958474,...,10.962926,106.965762,766.416631,6.967424,11620.0,88.884099,275.540818,322.0508,17.222222,12.084554
2248,41.327102,32.223296,0.0,0.0,67.409494,2.463758,4.916238,67.409494,1.320745,4.943034,...,10.950508,106.965762,754.391479,7.184681,10908.0,88.884099,275.540818,322.0508,16.972222,12.083333
2249,5.656854,6.174783,0.0,0.0,10.491415,2.052881,4.105762,10.491415,1.165713,3.008468,...,7.914618,35.304688,134.094294,5.830187,102.0,8.000000,32.000000,32.0000,4.361111,2.500000


In [33]:
# Count how many descriptors have a NaN upper clipping bound (True) versus a
# defined one (False).
upper_limits = means + 3 * stdevs
upper_limits.isna().value_counts()

False    1549
True       64
Name: count, dtype: int64