In [2]:
import numpy as np
import pandas as pd
from datasets import load_dataset

ds = load_dataset("scikit-fingerprints/MoleculeNet_ESOL")
dataset = pd.DataFrame(ds['train'])
dataset.head()

Unnamed: 0,SMILES,label
0,OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)...,-0.77
1,Cc1occc1C(=O)Nc2ccccc2,-3.3
2,CC(C)=CCCC(C)=CC(=O),-2.06
3,c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43,-7.87
4,c1ccsc1,-1.33


In [3]:
from tqdm import tqdm
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

chemberta = AutoModelForMaskedLM.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
tokenizer = AutoTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MTR")

def featurize_ChemBERTa(smiles_list, padding=True):
    embeddings_cls = torch.zeros(len(smiles_list), 600)
    embeddings_mean = torch.zeros(len(smiles_list), 600)

    with torch.no_grad():
        for i, smiles in enumerate(tqdm(smiles_list)):
            encoded_input = tokenizer(smiles, return_tensors="pt",padding=padding,truncation=True)
            model_output = chemberta(**encoded_input)
            
            embedding = model_output[0][::,0,::]
            embeddings_cls[i] = embedding
            
            embedding = torch.mean(model_output[0],1)
            embeddings_mean[i] = embedding
            
    return embeddings_cls.numpy(), embeddings_mean.numpy()

# Featurize the dataset
X_cls, X_mean = featurize_ChemBERTa(dataset['SMILES'].tolist(), padding=True)

# save the dataset as parquet
dataset['X_cls'] = list(X_cls)
dataset['X_mean'] = list(X_mean)
display(dataset.head())
dataset.to_parquet("esol.parquet", index=False)

Some weights of RobertaForMaskedLM were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MTR and are newly initialized: ['lm_head.bias', 'lm_head.decoder.bias', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.layer_norm.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 1128/1128 [00:02<00:00, 478.91it/s]


Unnamed: 0,SMILES,label,X_cls,X_mean
0,OCC3OC(OCC2OC(OC(C#N)c1ccccc1)C(O)C(O)C2O)C(O)...,-0.77,"[0.40536433, 0.0, -0.8746475, -0.11098191, -0....","[0.5421948, 0.0, -0.17677982, -0.25464058, 0.1..."
1,Cc1occc1C(=O)Nc2ccccc2,-3.3,"[-0.20269528, 0.0, -0.20217846, -0.0020512005,...","[0.40014043, 0.0, 0.08028915, -0.57318765, 0.5..."
2,CC(C)=CCCC(C)=CC(=O),-2.06,"[0.10683021, 0.0, -0.59511095, -0.49970043, 0....","[0.48607704, 0.0, -0.30101377, -0.4096281, 0.2..."
3,c1ccc2c(c1)ccc3c2ccc4c5ccccc5ccc43,-7.87,"[-0.10439029, 0.0, 0.56547326, -0.36930147, 0....","[0.29917368, 0.0, -0.031156043, -0.56387055, 0..."
4,c1ccsc1,-1.33,"[-0.12677687, 0.0, 0.07676783, -0.11666492, 0....","[0.16924359, 0.0, 0.22438815, -0.44767582, 0.7..."


In [None]:
# loading the aqsolDB dataset
from rdkit import Chem
aqsoldb = pd.read_csv("curated-solubility-dataset.csv")

dataset = aqsoldb[['SMILES', 'Solubility']]
dataset = dataset.rename(columns={"Solubility": "label"})

def check_smiles(smiles):
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return False
        else:
            return True
    except:
        return False
    
# check the smiles data for validity
dataset['valid'] = dataset['SMILES'].apply(check_smiles)
dataset = dataset[dataset['valid'] == True]
print(f"Number of invalid SMILES found: {len(aqsoldb) - len(dataset)}")
dataset = dataset.drop(columns=['valid'])

# Featurize the dataset
X_cls, X_mean = featurize_ChemBERTa(dataset['SMILES'].tolist(), padding=True)
dataset['X_cls'] = list(X_cls)
dataset['X_mean'] = list(X_mean)
display(dataset.head())

# save the dataset as parquet
dataset.to_parquet("aqsoldb.parquet", index=False)

[12:56:52] Explicit valence for atom # 5 N, 4, is greater than permitted
[12:56:52] Explicit valence for atom # 5 N, 4, is greater than permitted


Number of invalid SMILES found: 2


100%|██████████| 9980/9980 [00:24<00:00, 400.07it/s]


Unnamed: 0,SMILES,label,X_cls,X_mean
0,[Br-].CCCCCCCCCCCCCCCCCC[N+](C)(C)C,-3.616127,"[-0.05037888, 0.0, 0.21057384, -0.0044510528, ...","[-0.015998222, 0.0, -0.3155203, -0.082074456, ..."
1,O=C1Nc2cccc3cccc1c23,-3.254767,"[0.13276795, 0.0, 0.26397175, -0.23781067, -0....","[0.38055447, 0.0, 0.15346204, -0.69117033, 0.5..."
2,Clc1ccc(C=O)cc1,-2.177078,"[0.17114204, 0.0, -0.14478922, 0.27940676, 0.2...","[0.45670402, 0.0, -0.23538736, -0.3581017, 0.2..."
3,[Zn++].CC(c1ccccc1)c2cc(C(C)c3ccccc3)c(O)c(c2)...,-3.924409,"[0.21158491, 0.0, 0.498009, 0.7779276, 0.11697...","[0.49604198, 0.0, -0.06561358, -0.35666656, 0...."
4,C1OC1CN(CC2CO2)c3ccc(Cc4ccc(cc4)N(CC5CO5)CC6CO...,-4.662065,"[0.39181775, 0.0, -0.23251453, 0.24793407, -0....","[0.38087612, 0.0, -0.17673464, -0.21271648, 0...."
