In [60]:
import pandas as pd
from sklearn.model_selection import train_test_split
from transformers import RobertaTokenizer, RobertaModel, RobertaConfig, AdamW, get_linear_schedule_with_warmup , BertModel
import torch.nn as nn
import torch
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import roc_auc_score



In [25]:
dataset = pd.read_csv("/sise/home/adamu/thesis_new/datasets/final_dataset_2_0.75.csv")
esm_features = pd.read_csv("/sise/home/adamu/thesis_new/feature_extraction/outputs/esm_features.csv")
uniprot_mapping = pd.read_csv("/sise/home/adamu/thesis_new/feature_extraction/data/idmapping_2024_05_09.tsv",  delimiter = "\t")
uniprot_mapping

Unnamed: 0,From,Entry,Reviewed,Entry Name,Protein names,Gene Names,Organism,Length,Sequence
0,Q15788,Q15788,reviewed,NCOA1_HUMAN,Nuclear receptor coactivator 1 (NCoA-1) (EC 2....,NCOA1 BHLHE74 SRC1,Homo sapiens (Human),1441,MSGLGDSSSDPANPDSHKRKGSPCDTLASSTEKRRREQENKYLEEL...
1,O14867,O14867,reviewed,BACH1_HUMAN,Transcription regulator protein BACH1 (BTB and...,BACH1,Homo sapiens (Human),736,MSLSENSVFAYESSVHSTNVLLSLNDQRKKDVLCDVTIFVEGQRFR...
2,Q13797,Q13797,reviewed,ITA9_HUMAN,Integrin alpha-9 (Integrin alpha-RLC),ITGA9,Homo sapiens (Human),1035,MGGPAAPRGAGRLRALLLALVVAGIPAGAYNLDPQRPVHFQGPADS...
3,Q6PL18,Q6PL18,reviewed,ATAD2_HUMAN,ATPase family AAA domain-containing protein 2 ...,ATAD2 L16 PRO2000,Homo sapiens (Human),1390,MVVLRSSLELHNHSAASATGSLDLSSDFLSLEHIGRRRLRSAGAAQ...
4,O00512,O00512,reviewed,BCL9_HUMAN,B-cell CLL/lymphoma 9 protein (B-cell lymphoma...,BCL9,Homo sapiens (Human),1426,MHSSNPKVRSSPSGNTQSSPKSKQEVMVRPPTVMSPSGNPQLDSKF...
...,...,...,...,...,...,...,...,...,...
291,O14786,O14786,reviewed,NRP1_HUMAN,Neuropilin-1 (Vascular endothelial cell growth...,NRP1 NRP VEGF165R,Homo sapiens (Human),923,MERGLPLLCAVLALVLAPAGAFRNDKCGDTIKIESPGYLTSPGYPH...
292,Q13501,Q13501,reviewed,SQSTM_HUMAN,Sequestosome-1 (EBI3-associated protein of 60 ...,SQSTM1 ORCA OSIL,Homo sapiens (Human),440,MASLTVKAYLLGKEDAAREIRRFSFCCSPEPEAEAEAAAGPGPCER...
293,Q92843,Q92843,reviewed,B2CL2_HUMAN,Bcl-2-like protein 2 (Bcl2-L-2) (Apoptosis reg...,BCL2L2 BCLW KIAA0271,Homo sapiens (Human),193,MATPASAPDTRALVADFVGYKLRQKGYVCGAGPGEGPAADPLHQAM...
294,P51587,P51587,reviewed,BRCA2_HUMAN,Breast cancer type 2 susceptibility protein (F...,BRCA2 FACD FANCD1,Homo sapiens (Human),3418,MPIGSKERPTFFEIFKTRCNKADLGPISLNWFEELSSEAPPYNSEP...


In [26]:
def convert_uniprot_ids(dataset, mapping_df):
    # Create a dictionary from the mapping dataframe
    mapping_dict = mapping_df.set_index('From')['Entry'].to_dict()

    # Map the uniprot_id1 and uniprot_id2 columns to their respective Entry values
    dataset['uniprot_id1'] = dataset['uniprot_id1'].map(mapping_dict)
    dataset['uniprot_id2'] = dataset['uniprot_id2'].map(mapping_dict)
    return dataset.drop_duplicates()
    
def merge_datasets(dataset, features_df):
    # Merge features for uniprot_id1
    dataset = dataset.merge(features_df, how='left', left_on='uniprot_id1', right_on='UniProt_ID', suffixes=('', '_id1'))
    dataset = dataset.drop(columns=['UniProt_ID'])
    
    # Merge features for uniprot_id2
    features_df_renamed = features_df.add_suffix('_id2')
    features_df_renamed = features_df_renamed.rename(columns={'UniProt_ID_id2': 'UniProt_ID'})
    dataset = dataset.merge(features_df_renamed, how='left', left_on='uniprot_id2', right_on='UniProt_ID', suffixes=('', '_id2'))
    dataset = dataset.drop(columns=['UniProt_ID', 'uniprot_id1', 'uniprot_id2'])
    
    return dataset.drop_duplicates()

In [27]:
dataset = convert_uniprot_ids(dataset, uniprot_mapping)
dataset = merge_datasets(dataset, esm_features)
dataset

Unnamed: 0,smiles,label,Feature_0,Feature_1,Feature_2,Feature_3,Feature_4,Feature_5,Feature_6,Feature_7,...,Feature_1270_id2,Feature_1271_id2,Feature_1272_id2,Feature_1273_id2,Feature_1274_id2,Feature_1275_id2,Feature_1276_id2,Feature_1277_id2,Feature_1278_id2,Feature_1279_id2
0,OC(=O)c1nc(sc1-c1ccc(cc1)-c1ccccc1)-c1ccc2CCCN...,0,0.033387,-0.091496,-0.021745,0.011464,-0.045568,-0.076223,0.067172,0.002492,...,-0.046108,0.018945,-0.160452,0.058708,-0.045173,-0.135907,0.036617,-0.136259,-0.023708,0.164988
1,CN1C(=O)N(c2cc(Cl)cc(Cl)c2)C(=O)[C@]12CN(c1nc3...,0,0.018931,-0.050427,-0.001708,0.052669,-0.015171,-0.051038,0.085986,-0.013006,...,0.095369,-0.000298,-0.187148,0.049039,-0.007977,-0.004218,-0.013003,-0.156473,0.049008,0.133108
2,CC(C)NS(=O)(=O)c1ccc(OCC(=O)N2CCOCC2)cc1,0,-0.020703,-0.077155,-0.054688,0.065730,-0.108772,-0.025299,0.091760,-0.013818,...,0.052764,0.007230,-0.032889,0.013346,-0.002336,-0.061765,0.122240,-0.107535,-0.056424,0.064282
3,CC(=O)N1CCN(C(=O)/C=C/c2ccc(Sc3ccccc3C(N)=O)c(...,1,0.022582,-0.052505,-0.042596,0.089144,-0.089632,-0.131044,0.148719,0.032024,...,0.013739,0.025413,-0.184884,0.035220,-0.037794,-0.026719,0.093608,-0.122195,-0.045078,0.117260
4,CC(C)C1=C(SC2=N[C@](C)([C@H](N12)c1ccc(Cl)cc1)...,1,0.033985,-0.059724,-0.006932,0.144283,-0.005031,-0.037092,0.083921,0.026034,...,0.084858,0.003144,-0.058713,0.010643,0.022048,-0.019426,0.027595,-0.099573,0.011354,0.013198
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
109226,Clc1cc(OCCN2CCOCC2)ccc1Nc1nc2c(-c3nnc[nH]3)ccc...,0,0.033387,-0.091496,-0.021745,0.011464,-0.045568,-0.076223,0.067172,0.002492,...,-0.046108,0.018945,-0.160452,0.058708,-0.045173,-0.135907,0.036617,-0.136259,-0.023708,0.164988
109227,COC(=O)c1cc(O)cc(OC)c1Oc1cc(C)cc(O)c1C(=O)O,0,-0.038923,-0.071458,0.011513,0.078479,-0.190378,0.070974,0.128713,0.020999,...,0.095667,0.027029,-0.136574,-0.025870,0.009829,-0.081369,0.001253,-0.073539,0.061196,0.029591
109228,CCCC1=CN(C[C@H](NC(=O)OCc2ccccc2)C(=O)O)C(=O)N...,0,0.033387,-0.091496,-0.021745,0.011464,-0.045568,-0.076223,0.067172,0.002492,...,-0.046108,0.018945,-0.160452,0.058708,-0.045173,-0.135907,0.036617,-0.136259,-0.023708,0.164988
109229,O=C(CSc1nc2ccccc2c(=O)n1Cc1ccc2c(c1)OCO2)NCc1c...,1,0.016134,-0.060721,-0.051823,0.028336,-0.062652,-0.053298,0.085150,-0.030032,...,-0.046108,0.018945,-0.160452,0.058708,-0.045173,-0.135907,0.036617,-0.136259,-0.023708,0.164988


In [32]:
dataset[dataset.Feature_0.isna() | dataset.Feature_0_id2.isna()]

Unnamed: 0,smiles,label,Feature_0,Feature_1,Feature_2,Feature_3,Feature_4,Feature_5,Feature_6,Feature_7,...,Feature_1270_id2,Feature_1271_id2,Feature_1272_id2,Feature_1273_id2,Feature_1274_id2,Feature_1275_id2,Feature_1276_id2,Feature_1277_id2,Feature_1278_id2,Feature_1279_id2


In [63]:
dataset.iloc[:, 2:] = dataset.iloc[:, 2:].astype(float)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
# Extract SMILES strings, PPI features, and labels
smiles_list = dataset['smiles'].tolist()
ppi_features = dataset.iloc[:, 2:].values  # Excluding SMILES and label columns
labels = dataset['label'].values

tokenizer = RobertaTokenizer.from_pretrained("DeepChem/ChemBERTa-77M-MTR")
encoded_smiles = tokenizer(smiles_list, truncation=True, padding=True, return_tensors="pt")


cuda


In [64]:
class CustomDataset(Dataset):
    def __init__(self, encoded_smiles, ppi_features, labels):
        self.encoded_smiles = encoded_smiles
        self.ppi_features = torch.tensor(ppi_features, dtype=torch.float32)
        self.labels = torch.tensor(labels, dtype=torch.float32)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {
            "input_ids": self.encoded_smiles["input_ids"][idx],
            "attention_mask": self.encoded_smiles["attention_mask"][idx],
            "ppi_features": self.ppi_features[idx],
            "labels": self.labels[idx]
        }
        return item

In [65]:
# Splitting the data
train_data, temp_data, train_labels, temp_labels = train_test_split(dataset, labels, test_size=0.2, random_state=42, stratify=labels)
valid_data, test_data, valid_labels, test_labels = train_test_split(temp_data, temp_labels, test_size=0.5, random_state=42, stratify=temp_labels)

# Convert these splits into their respective Datasets and DataLoaders
train_dataset = CustomDataset(tokenizer(train_data['smiles'].tolist(), truncation=True, padding=True, return_tensors="pt"), train_data.iloc[:, 2:].values, train_labels)
valid_dataset = CustomDataset(tokenizer(valid_data['smiles'].tolist(), truncation=True, padding=True, return_tensors="pt"), valid_data.iloc[:, 2:].values, valid_labels)
test_dataset = CustomDataset(tokenizer(test_data['smiles'].tolist(), truncation=True, padding=True, return_tensors="pt"), test_data.iloc[:, 2:].values, test_labels)

train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=64, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [66]:
train_data.iloc[:, 2:]

Unnamed: 0,Feature_0,Feature_1,Feature_2,Feature_3,Feature_4,Feature_5,Feature_6,Feature_7,Feature_8,Feature_9,...,Feature_1270_id2,Feature_1271_id2,Feature_1272_id2,Feature_1273_id2,Feature_1274_id2,Feature_1275_id2,Feature_1276_id2,Feature_1277_id2,Feature_1278_id2,Feature_1279_id2
63826,-0.009516,-0.048620,0.004722,0.037495,0.089705,-0.022223,0.022028,0.152187,0.045285,0.051748,...,0.012767,-0.027525,-0.085463,0.038302,0.034502,-0.081834,-0.010867,-0.085387,0.033624,0.077056
98240,-0.020703,-0.077155,-0.054688,0.065730,-0.108772,-0.025299,0.091760,-0.013818,-0.079494,0.198468,...,-0.035412,0.052223,-0.199088,0.039508,-0.027858,-0.104056,0.072464,-0.095414,0.010231,0.057766
74598,-0.005542,-0.096082,-0.049218,-0.004587,0.096639,-0.117399,0.065587,0.079753,-0.064333,0.058297,...,0.027155,-0.031982,-0.049737,0.152361,0.032702,-0.069526,0.100826,0.034708,-0.079781,0.043650
102131,0.058166,-0.144322,-0.046411,0.098785,-0.164210,-0.044439,0.309844,-0.310033,-0.193687,0.088806,...,0.153227,0.024963,-0.114406,-0.013927,0.006024,-0.051427,0.034106,-0.097461,0.034687,-0.044699
82490,-0.027515,-0.093644,0.023019,0.072452,-0.087972,-0.038408,0.183902,0.003826,-0.036904,0.164225,...,0.072114,0.023437,-0.090315,0.004398,0.011356,-0.108461,0.031974,-0.039432,-0.032934,0.044267
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3348,-0.027515,-0.093644,0.023019,0.072452,-0.087972,-0.038408,0.183902,0.003826,-0.036904,0.164225,...,0.072114,0.023437,-0.090315,0.004398,0.011356,-0.108461,0.031974,-0.039432,-0.032934,0.044267
71060,0.014477,-0.060971,0.028221,0.096667,-0.034841,-0.071785,0.078922,-0.039060,-0.020456,0.097615,...,0.017056,-0.000085,-0.180716,0.031813,-0.013119,-0.175944,-0.003636,-0.079875,-0.004251,0.074700
100504,0.063274,-0.029659,0.015889,0.101089,-0.016113,-0.101207,0.000601,0.111697,0.024047,0.033692,...,0.024946,-0.011196,-0.193725,-0.000243,-0.027750,-0.031218,0.054268,-0.062203,-0.028887,0.091523
94930,0.019827,-0.132773,-0.000028,0.141837,-0.097555,0.025639,0.215369,-0.009062,0.020522,0.144291,...,0.055858,0.055593,-0.091637,0.041687,-0.016587,-0.035697,0.025542,-0.037288,0.011203,0.017735


In [67]:
import torch.nn.functional as F

class ChemBERTaWithPPI(nn.Module):
    def __init__(self, model_name, ppi_feature_size, hidden_size1=1024, hidden_size2=512, hidden_size3 = 256):
        super(ChemBERTaWithPPI, self).__init__()
        self.chemberta = BertModel.from_pretrained(model_name)
        self.ppi_fc = nn.Linear(ppi_feature_size, hidden_size1)  # Ensure output matches BERT's dimension
        self.hidden_layer1 = nn.Linear(1408, hidden_size1)  # First hidden layer
        self.hidden_layer2 = nn.Linear(hidden_size1, hidden_size2)  # New added hidden layer
        self.hidden_layer3 = nn.Linear(hidden_size2, hidden_size3)  # New added hidden layer
        self.classifier = nn.Linear(hidden_size3, 1)  # Adjusted for the output of the second hidden layer

    def forward(self, input_ids, attention_mask, ppi_features):
        ppi_out = self.ppi_fc(ppi_features)
        
        # Expand the dimensions of ppi_out to match bert_output[0]
        ppi_out = ppi_out.unsqueeze(1).expand(-1, input_ids.size(1), -1)
        
        bert_output = self.chemberta(input_ids=input_ids, attention_mask=attention_mask)
        integrated_output = torch.cat((bert_output[0], ppi_out), dim=-1)  # Concatenate along the last dimension
        
        # Average pooling over the sequence length dimension
        pooled_output = integrated_output.mean(dim=1)
        
        # Passing through the first hidden layer with a ReLU activation function
        hidden_output1 = F.relu(self.hidden_layer1(pooled_output))
        
        # Passing through the second hidden layer with a ReLU activation function
        hidden_output2 = F.relu(self.hidden_layer2(hidden_output1))

        # Passing through the second hidden layer with a ReLU activation function
        hidden_output3 = F.relu(self.hidden_layer3(hidden_output2))
        
        logits = self.classifier(hidden_output3)
        return logits


In [68]:
# Assuming you have your data loaded in a DataLoader named `dataloader`
model_name = "DeepChem/ChemBERTa-77M-MTR"
model = ChemBERTaWithPPI(model_name, ppi_feature_size=2560).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
criterion = nn.BCEWithLogitsLoss()

You are using a model of type roberta to instantiate a model of type bert. This is not supported for all configurations of models and can yield errors.
Some weights of BertModel were not initialized from the model checkpoint at DeepChem/ChemBERTa-77M-MTR and are newly initialized: ['encoder.layer.0.attention.self.query.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.1.attention.output.LayerNorm.weight', 'embeddings.position_embeddings.weight', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.1.output.dense.bias', 'embeddings.word_embeddings.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.2.intermediate.dense.bias', 'pooler.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.bias', 'embeddings.LayerNorm.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.2.attention.self.value.weight', 'pooler.dense.weight', 'encoder.layer.2.attention.o

In [None]:
from tqdm import tqdm

num_epochs = 25
patience = 5
best_auc = 0
epochs_without_improvement = 0

for epoch in range(num_epochs):
    # Training loop
    model.train()
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}", leave=True)
    for batch in progress_bar:
        # Move batch data to the chosen device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        ppi_features = batch["ppi_features"].to(device)
        labels = batch["labels"].to(device)
        
        optimizer.zero_grad()
        logits = model(input_ids, attention_mask, ppi_features)
        loss = criterion(logits.squeeze(-1), labels)
        loss.backward()
        optimizer.step()
        
        # Optionally, update the progress bar description with the current loss
        progress_bar.set_postfix({'loss': loss.item()})
    # Validation loop
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for batch in valid_dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            ppi_features = batch["ppi_features"].to(device)
            labels = batch["labels"].to(device)
            
            logits = model(input_ids, attention_mask, ppi_features)
            predictions = torch.sigmoid(logits).squeeze(-1)
            all_preds.extend(predictions.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    auc = roc_auc_score(all_labels, all_preds)
    print(f"Epoch {epoch + 1}/{num_epochs} - Validation AUC: {auc:.4f}")

    # Early stopping
    if auc > best_auc:
        best_auc = auc
        epochs_without_improvement = 0
        # Optionally, save the best model
        # torch.save(model.state_dict(), 'best_model.pth')
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement == patience:
            print("Early stopping due to no improvement in validation AUC.")
            
    

Epoch 1/25: 100%|██████████| 1366/1366 [04:57<00:00,  4.59it/s, loss=0.689]


Epoch 1/25 - Validation AUC: 0.6262


Epoch 2/25: 100%|██████████| 1366/1366 [04:58<00:00,  4.58it/s, loss=0.452]


Epoch 2/25 - Validation AUC: 0.7269


Epoch 3/25: 100%|██████████| 1366/1366 [04:57<00:00,  4.60it/s, loss=0.375]


Epoch 3/25 - Validation AUC: 0.7973


Epoch 4/25: 100%|██████████| 1366/1366 [04:56<00:00,  4.60it/s, loss=0.443]


Epoch 4/25 - Validation AUC: 0.8432


Epoch 5/25: 100%|██████████| 1366/1366 [04:57<00:00,  4.60it/s, loss=0.43] 


Epoch 5/25 - Validation AUC: 0.8603


Epoch 6/25: 100%|██████████| 1366/1366 [04:57<00:00,  4.58it/s, loss=0.615]


Epoch 6/25 - Validation AUC: 0.8538


Epoch 7/25:  10%|▉         | 130/1366 [00:28<04:30,  4.58it/s, loss=0.288]