In [26]:
from deepchem.molnet import load_bace_classification, load_bbbp
import numpy as np

from simcse import SimCSE
import torch
from torch import nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

from args_parser import parse_args
import sys
import pandas as pd

datasets = {
        "bace": load_bace_classification,
        "bbbp": load_bbbp
        }

sys.argv = ['']
args = parse_args()

args.samples_per_class=250
args.n_augment = 0


input_dim = 512
output_dim = args.num_labels
set_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_, datasets, _ = datasets.get(args.dataset_name)(reload=False)
(train_dataset, valid_dataset, test_dataset) = datasets

12/23/2022 00:11:22 - INFO - deepchem.molnet.load_function.molnet_loader -   About to featurize bace_c dataset.
12/23/2022 00:11:22 - INFO - deepchem.data.data_loader -   Loading raw samples now.
12/23/2022 00:11:22 - INFO - deepchem.data.data_loader -   shard_size: 8192
12/23/2022 00:11:22 - INFO - deepchem.utils.data_utils -   About to start loading CSV from /var/folders/s4/4l6cbdrn7cq2m4vs14xz9qpm0000gn/T/bace.csv
12/23/2022 00:11:22 - INFO - deepchem.utils.data_utils -   Loading shard 1 of size 8192.
12/23/2022 00:11:22 - INFO - deepchem.data.data_loader -   About to featurize shard.
12/23/2022 00:11:22 - INFO - deepchem.feat.base_classes -   Featurizing datapoint 0
12/23/2022 00:11:29 - INFO - deepchem.feat.base_classes -   Featurizing datapoint 1000
12/23/2022 00:11:32 - INFO - deepchem.data.data_loader -   TIMING: featurizing shard 0 took 9.532 s
12/23/2022 00:11:32 - INFO - deepchem.data.datasets -   TIMING: dataset construction took 9.732 s
12/23/2022 00:11:32 - INFO - deepche

In [27]:
output_dim

2

In [28]:

model = SimCSE("shahrukhx01/muv2x-simcse-smole-bert")
train_indices = []

def embed_smiles(smiles):
    embeddings = model.encode(smiles)
    return embeddings


train_labels = [y[0] for y in train_dataset.y]
label_df = pd.DataFrame(train_labels, columns=["labels"])
if args.samples_per_class > 0:
    np.random.seed()
    tp = np.random.choice(
        list(label_df[label_df["labels"] == 1].index),
        args.samples_per_class,
        replace=False,
    )
    tn = np.random.choice(
        list(label_df[label_df["labels"] == 0].index),
        args.samples_per_class,
        replace=False,
    )
    train_indices = list(tp) + list(tn)
    
np.random.seed()

train_smiles = train_dataset.ids[train_indices]
train_embeddings = embed_smiles(smiles=list(train_smiles))
train_labels = [y[0] for y in train_dataset.y[train_indices]]

val_smiles = valid_dataset.ids
val_embeddings = embed_smiles(smiles=list(val_smiles))
val_labels = [y[0] for y in valid_dataset.y]

test_smiles = test_dataset.ids
test_embeddings = embed_smiles(smiles=list(test_smiles))
test_labels = [y[0] for y in test_dataset.y]

def mixup_augment(embedding1, embedding2, label1, label2, lamda):
    embedding_output = lam * embedding1 + (1.0 - lam) * embedding2
    label_output = lam * label1 + (1.0 - lam) * label2
    return (embedding_output, label_output)

augmented_embeds, augmented_labels = [], []
if args.n_augment:
    for idx, (train_embedding, train_label) in enumerate(zip(train_embeddings, train_labels)):
        train_embeddings_idx = np.array([_idx for _idx in range(len(train_embeddings)) if _idx!=idx])
        for i in range(args.n_augment):
            np.random.seed()
            lam = np.random.beta(args.alpha, args.alpha)
            embedding2_idx = np.random.choice(train_embeddings_idx, replace=False)
            embedding2 = train_embeddings[embedding2_idx, :]
            label2 = train_labels[embedding2_idx]
            aug_embed, aug_label =  mixup_augment(embedding1=train_embedding, embedding2=embedding2, label1=train_label+1, label2=label2+1, lamda=lam)
            aug_label = aug_label-1
            augmented_embeds.append(aug_embed)
            augmented_labels.append(aug_label)

train_embeddings_augmented, train_labels_augmented = None, None

if len(augmented_embeds):
    augmented_embeds = torch.stack(augmented_embeds)
    train_embeddings_augmented = torch.cat([train_embeddings, augmented_embeds])
    train_labels_augmented = train_labels + augmented_labels
else:
    train_embeddings_augmented, train_labels_augmented = train_embeddings, train_labels
    
train_data = TensorDataset(
            train_embeddings_augmented,
            torch.Tensor(train_labels_augmented)
        )
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(
    train_data, sampler=train_sampler, batch_size=args.batch_size
)


val_data = TensorDataset(
            val_embeddings,
            torch.Tensor(val_labels)
        )
val_sampler = RandomSampler(val_data)
val_dataloader = DataLoader(
    val_data, sampler=val_sampler, batch_size=len(val_data)
)


test_data = TensorDataset(
            test_embeddings,
            torch.Tensor(test_labels)
        )
test_sampler = RandomSampler(test_data)
test_dataloader = DataLoader(
    test_data, sampler=test_sampler, batch_size=len(test_data)
)


class MolNet(nn.Module):
    """
    This class is created to specify the Neural Network on which vectorized datasets we have created previously
    is trained on, validated and later tested.
    It consist of one input layer, one output layer and multiple hidden layers.
    ...
    """
    def __init__(self, input_dim, output_dim, dropout=0.5):
        super(MolNet, self).__init__()
        # Layer definitions
        self.layers = nn.Sequential(
        nn.Linear(input_dim, 1024),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(512, 256),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(256, 128),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(128, output_dim)
        )

    def forward(self, x):
        # Forward pass
        return self.layers(x)

model_mlp = MolNet(input_dim=input_dim, output_dim=output_dim).to(set_device)
criterion = nn.CrossEntropyLoss().to(set_device)
optimizer = getattr(optim, "Adam")(model_mlp.parameters(), lr=args.lr)


from sklearn.metrics import roc_auc_score
from torch import sigmoid
from torch.nn.functional import softmax


def flat_auroc_score(preds, labels):
    """
    Function to calculate the roc_auc_score of our predictions vs labels
    """
    pred_flat = softmax(preds, dim=1)[:, 1]
    # labels_flat = np.argmax(labels, axis=1)
    return roc_auc_score(labels, pred_flat.detach().cpu().numpy())


best_model = None
best_accuracy = 0.0
train_loss_history, recall_train_history = [], []
validation_loss_history, recall_validation_history = list(), list()
for epoch in range(0, args.epoch):
        model_mlp.train()
        train_loss_scores = []
        training_acc_scores = []
        y_pred, y_true= list(), list()
        predictions = []
        for batch, targets in train_dataloader:
            
            ## perform forward pass  
            batch = batch.type(torch.FloatTensor).to(set_device)
            pred = model_mlp(batch) 
            
            preds = torch.max(pred, 1)[1]

            ## accumulate predictions per batch for the epoch
            y_pred += list([x.item() for x in preds.detach().cpu().numpy()])
            targets = torch.LongTensor([x.item() for x in list(targets)])
            y_true +=  list([x.item() for x in targets.detach().cpu().numpy()])
            
            ## compute loss and perform backward pass
            loss = criterion(pred.to(set_device), targets.to(set_device)) ## compute loss 
            optimizer.zero_grad()
            loss.backward() 
            optimizer.step()
            
            predictions.append(pred)

            ## accumulate train loss
            train_loss_scores.append(loss.item())
        
        ## accumulate loss, recall, f1, precision per epoch
        train_loss_history.append((sum(train_loss_scores)/len(train_loss_scores)))
        #recall = flat_auroc_score(predictions, y_true)
        #recall_train_history.append(recall)
        print(f'Training =>  Epoch : {epoch+1} | Loss : {train_loss_history[-1]}') 
              #| AUROC score: {recall_train_history[-1]}')
        
        model_mlp.eval()
        predictions = None
        with torch.no_grad():
            validation_loss_scores = list()
            y_true_val, y_pred_val= list(), list()

            ## perform validation pass
            for batch, targets in val_dataloader:
                ## perform forward pass  
                batch = batch.type(torch.FloatTensor).to(set_device)
                pred = model_mlp(batch) 
                predictions = pred
                preds = torch.max(pred, 1)[1]
                
                ## accumulate predictions per batch for the epoch
                y_pred_val += list([x.item() for x in preds.detach().cpu().numpy()])
                targets = torch.LongTensor([x.item() for x in list(targets)])
                y_true_val +=  list([x.item() for x in targets.detach().cpu().numpy()])
                
                ## computing validate loss
                loss = criterion(pred.to(set_device), targets.to(set_device)) ## compute loss 

                ## accumulate validate loss
                validation_loss_scores.append(loss.item())
                
            
            ## accumulate loss, recall, f1, precision per epoch
            validation_loss_history.append((sum(validation_loss_scores)/len(validation_loss_scores)))
            recall = flat_auroc_score(predictions, y_true_val)
            recall_validation_history.append(recall)

            print(f'Validation =>  Epoch : {epoch+1} | Loss : {validation_loss_history[-1]} | AUROC score: {recall_validation_history[-1]} ')
            
            if recall_validation_history[-1]>best_accuracy:
                best_accuracy = recall_validation_history[-1]
                print('Selecting the model...')
                best_model = model_mlp

best_model.eval()
predictions = None
with torch.no_grad():
    validation_loss_scores = list()
    y_true_val, y_pred_val= list(), list()

    ## perform validation pass
    for batch, targets in test_dataloader:
        ## perform forward pass  
        batch = batch.type(torch.FloatTensor).to(set_device)
        pred = best_model(batch) 
        predictions = pred
        preds = torch.max(pred, 1)[1]

        ## accumulate predictions per batch for the epoch
        y_pred_val += list([x.item() for x in preds.detach().cpu().numpy()])
        targets = torch.LongTensor([x.item() for x in list(targets)])
        y_true_val +=  list([x.item() for x in targets.detach().cpu().numpy()])

        ## computing validate loss
        loss = criterion(pred.to(set_device), targets.to(set_device)) ## compute loss 

        ## accumulate validate loss
        validation_loss_scores.append(loss.item())


    ## accumulate loss, recall, f1, precision per epoch
    validation_loss_history.append((sum(validation_loss_scores)/len(validation_loss_scores)))
    recall = flat_auroc_score(predictions, y_true_val)
    recall_validation_history.append(recall)

    print(f'Test => AUROC score: {recall_validation_history[-1]} ')

Some weights of the model checkpoint at shahrukhx01/muv2x-simcse-smole-bert were not used when initializing BertModel: ['mlp.dense.bias', 'mlp.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at shahrukhx01/muv2x-simcse-smole-bert and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
100%|██████████| 8/8 [00:26<00:00,  3.28s/it]
100%|██████████| 3/3 [00:06<00:00,  2.

Training =>  Epoch : 1 | Loss : 0.695096093416214
Validation =>  Epoch : 1 | Loss : 0.688141942024231 | AUROC score: 0.5943496801705758 
Selecting the model...
Training =>  Epoch : 2 | Loss : 0.6863942503929138
Validation =>  Epoch : 2 | Loss : 0.6837015151977539 | AUROC score: 0.6186922530206113 
Selecting the model...


  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])


Training =>  Epoch : 3 | Loss : 0.6469144523143768
Validation =>  Epoch : 3 | Loss : 0.6984891295433044 | AUROC score: 0.6592039800995024 
Selecting the model...
Training =>  Epoch : 4 | Loss : 0.5581906318664551
Validation =>  Epoch : 4 | Loss : 0.737140417098999 | AUROC score: 0.6901208244491827 
Selecting the model...


  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])


Training =>  Epoch : 5 | Loss : 0.507535719871521
Validation =>  Epoch : 5 | Loss : 0.748183012008667 | AUROC score: 0.6743070362473348 
Training =>  Epoch : 6 | Loss : 0.44030154645442965
Validation =>  Epoch : 6 | Loss : 0.6792479157447815 | AUROC score: 0.7036247334754797 
Selecting the model...


  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])


Training =>  Epoch : 7 | Loss : 0.4372164040803909
Validation =>  Epoch : 7 | Loss : 0.676410973072052 | AUROC score: 0.71090973702914 
Selecting the model...
Training =>  Epoch : 8 | Loss : 0.4102799415588379
Validation =>  Epoch : 8 | Loss : 0.685738205909729 | AUROC score: 0.7205046197583511 
Selecting the model...


  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])


Training =>  Epoch : 9 | Loss : 0.3826268076896667
Validation =>  Epoch : 9 | Loss : 0.6415461897850037 | AUROC score: 0.7133972992181947 
Training =>  Epoch : 10 | Loss : 0.3587747007608414
Validation =>  Epoch : 10 | Loss : 0.7046258449554443 | AUROC score: 0.71090973702914 


  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])


Training =>  Epoch : 11 | Loss : 0.3310509964823723
Validation =>  Epoch : 11 | Loss : 0.7430493831634521 | AUROC score: 0.7100213219616206 
Training =>  Epoch : 12 | Loss : 0.3415596142411232
Validation =>  Epoch : 12 | Loss : 0.7512920498847961 | AUROC score: 0.7093105899076049 


  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])


Training =>  Epoch : 13 | Loss : 0.31924952417612074
Validation =>  Epoch : 13 | Loss : 0.7765250205993652 | AUROC score: 0.7151741293532339 
Training =>  Epoch : 14 | Loss : 0.27114118486642835
Validation =>  Epoch : 14 | Loss : 0.7479512691497803 | AUROC score: 0.7187277896233119 


  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])


Training =>  Epoch : 15 | Loss : 0.2824199602007866
Validation =>  Epoch : 15 | Loss : 0.7969406843185425 | AUROC score: 0.7192608386638237 
Training =>  Epoch : 16 | Loss : 0.30133824050426483
Validation =>  Epoch : 16 | Loss : 0.7900192141532898 | AUROC score: 0.7022032693674485 


  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])


Training =>  Epoch : 17 | Loss : 0.2845265582203865
Validation =>  Epoch : 17 | Loss : 0.9110234975814819 | AUROC score: 0.7162402274342574 
Training =>  Epoch : 18 | Loss : 0.27125139385461805
Validation =>  Epoch : 18 | Loss : 0.8660262227058411 | AUROC score: 0.714641080312722 


  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])


Training =>  Epoch : 19 | Loss : 0.24421080946922302
Validation =>  Epoch : 19 | Loss : 0.9373683333396912 | AUROC score: 0.697228144989339 
Training =>  Epoch : 20 | Loss : 0.2786246299743652
Validation =>  Epoch : 20 | Loss : 0.8990718722343445 | AUROC score: 0.6982942430703625 
Test => AUROC score: 0.7673913043478261 


  targets = torch.LongTensor([x.item() for x in list(targets)])
  targets = torch.LongTensor([x.item() for x in list(targets)])
