# BERT Pipeline for Drug-Target Interaction (DTI) Prediction

This notebook implements an end-to-end pipeline for predicting drug-target interactions using a BERT-based model. The workflow includes:

- *Data Preprocessing:* Cleaning, merging, and filtering raw CSV datasets.
- *Frequent Consecutive Subsequence (FCS) Extraction:* Extracting and ranking meaningful subsequences from drug SMILES and protein sequences.
- *Sequence Encoding:* Converting sequences into token indices using generated token dictionaries.
- *BERT Input Construction and Masking:* Creating unified BERT inputs with special tokens and applying dynamic masking.
- *Model Training and Evaluation:* Training a BERT-based classifier and evaluating its performance using ROC-AUC.


In [484]:
%pip install datasets

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.2.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


##Importing required Libraries

In [485]:
import pandas as pd
from collections import Counter
import re
from transformers import BertTokenizer, BertForMaskedLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from torch.utils.data import DataLoader
from datasets import load_dataset, Dataset
import torch
import random
import json
from tqdm import tqdm # For processing bars
import os
import numpy as np

In [486]:
#Define file paths
DRUGBANK_CSV = "drug_smiles.csv"
PROTEIN_CSV = "pdb_sequences.csv"
INTERACTIONS_CSV = "confirmed_interactions.csv"
DRUG_ENCODED_CSV = "drug_encoded.csv"
PROTEIN_ENCODED_CSV = "protein_encoded.csv"
DRUG_FCS_CSV = "drug_smiles_fcs_freq_100.csv"
PROTEIN_FCS_CSV = "protein_fcs_freq_100.csv"
MERGED_ENCODING_CSV = "merged_encodings.csv"
TARGET_LABELS_CSV = "target_labels.csv"

###Frequent Consecutive Subsequence (FCS) Extraction

This section extracts all consecutive subsequences (using a sliding window) from the input strings (drug SMILES and protein sequences). The extraction function also prunes low-frequency subsequences periodically.


In [487]:
def extract_fcs_subsequences_stream(strings, min_length=2, max_length=None,
                                      prune_every=1000, prune_threshold=5,
                                      preserve_short_length=5):
    """
    Extract frequent consecutive subsequences from a list of strings in a streaming fashion.
    Uses different pruning criteria: always keep subsequences shorter than `preserve_short_length`
    while pruning longer ones based on prune_threshold.

    :param strings: Iterable of strings (e.g., drug SMILES or protein sequences).
    :param min_length: Minimum subsequence length.
    :param max_length: Maximum subsequence length to consider (if None, use full length available).
    :param prune_every: After processing this many sequences, prune the counter.
    :param prune_threshold: For subsequences of length >= preserve_short_length,
                            remove those with counts lower than this threshold during pruning.
    :param preserve_short_length: Subsequences with length less than this value will be preserved regardless.
    :return: A Counter mapping each subsequence to its frequency.
    """
    subseq_counter = Counter()

    for idx, s in enumerate(strings, 1):
        n = len(s)
        for i in range(n):
            current_max = n - i if max_length is None else min(max_length, n - i)
            for l in range(min_length, current_max + 1):
                subseq = s[i:i+l]
                subseq_counter[subseq] += 1

        # Prune periodically to free memory:
        if idx % prune_every == 0:
            # For short subsequences (length < preserve_short_length), keep them always;
            # for longer ones, only keep if count >= prune_threshold.
            subseq_counter = Counter({
                k: v for k, v in subseq_counter.items()
                if len(k) < preserve_short_length or v >= prune_threshold
            })
            print(f"Processed {idx} sequences, counter pruned to {len(subseq_counter)} keys.")

    return subseq_counter

###Filtering and Ranking Subsequences

We filter the extracted subsequences by applying a minimum frequency threshold and then rank them. The result is saved into a DataFrame containing each subsequence, its frequency, and rank.


In [488]:
def filter_and_rank_fcs(fcs_counts, min_frequency=100):
    """
    Filter and rank subsequences that appear with frequency >= min_frequency.

    :param fcs_counts: Counter mapping subsequences to frequency.
    :param min_frequency: Minimum frequency for a subsequence to be included.
    :return: A pandas DataFrame with columns [Subsequence, Frequency, Rank].
    """
    filtered_items = [(subseq, freq) for subseq, freq in fcs_counts.items() if freq >= min_frequency]
    filtered_items.sort(key=lambda x: x[1], reverse=True)

    data = []
    for rank, (subseq, freq) in enumerate(filtered_items, start=1):
        data.append({"Subsequence": subseq, "Frequency": freq, "Rank": rank})

    return pd.DataFrame(data)

In [489]:
drugbank_csv = DRUGBANK_CSV          # CSV file for drug SMILES
smiles_column = "smiles"                  # Column name in drugbank.csv

protein_csv = PROTEIN_CSV         # CSV file for protein sequences
protein_column = "sequence"               # Column name in protein_sequences.csv

# Parameters for subsequence extraction
min_length = 2              # Generate subsequences of length 2 or more
max_length = 10             # Maximum subsequence length (adjust as needed)
min_freq_threshold = 5    # Final frequency threshold for inclusion

# Pruning parameters:
prune_every = 1000          # Prune every 1000 sequences processed
prune_threshold = 5         # For subsequences with length >= preserve_short_length, prune if count < 5
preserve_short_length = 5   # Always preserve subsequences with length < 5

# --- PROCESS DRUG SMILES ---
drug_output_file = DRUG_FCS_CSV #path defined
if not os.path.exists(drug_output_file):
    df_drug = pd.read_csv(drugbank_csv)
    if smiles_column not in df_drug.columns:
        raise ValueError(f"No '{smiles_column}' column found in {drugbank_csv}!")

    smiles_list = df_drug[smiles_column].dropna().astype(str).tolist()
    print("Extracting subsequences from drug SMILES...")

    smiles_fcs_counts = extract_fcs_subsequences_stream(
        smiles_list,
        min_length=min_length,
        max_length=max_length,
        prune_every=prune_every,
        prune_threshold=prune_threshold,
        preserve_short_length=preserve_short_length
    )
    smiles_fcs_df = filter_and_rank_fcs(smiles_fcs_counts, min_frequency=min_freq_threshold)
    smiles_fcs_df.to_csv(drug_output_file, index=False)

    print(f"[DRUG SMILES] Total subsequences (frequency ≥ {min_freq_threshold}): {len(smiles_fcs_df)}")
    print(smiles_fcs_df.head(10))
else:
    print(f"{drug_output_file} already exists. Skipping drug SMILES subsequence extraction.")

# --- PROCESS PROTEIN SEQUENCES ---
protein_output_file = PROTEIN_FCS_CSV #path defined
if not os.path.exists(protein_output_file):
    df_protein = pd.read_csv(protein_csv)
    if protein_column not in df_protein.columns:
        raise ValueError(f"No '{protein_column}' column found in {protein_csv}!")

    protein_list = df_protein[protein_column].dropna().astype(str).tolist()
    print("Extracting subsequences from protein sequences...")

    protein_fcs_counts = extract_fcs_subsequences_stream(
        protein_list,
        min_length=min_length,
        max_length=max_length,
        prune_every=prune_every,
        prune_threshold=prune_threshold,
        preserve_short_length=preserve_short_length
    )
    protein_fcs_df = filter_and_rank_fcs(protein_fcs_counts, min_frequency=min_freq_threshold)
    protein_fcs_df.to_csv(protein_output_file, index=False)

    print(f"[PROTEIN SEQUENCES] Total subsequences (frequency ≥ {min_freq_threshold}): {len(protein_fcs_df)}")
    print(protein_fcs_df.head(10))
else:
    print(f"{protein_output_file} already exists. Skipping protein sequences subsequence extraction.")


drug_smiles_fcs_freq_100.csv already exists. Skipping drug SMILES subsequence extraction.
protein_fcs_freq_100.csv already exists. Skipping protein sequences subsequence extraction.


###Sequence Encoding

In this section, we encode each sequence (drug SMILES or protein sequence) by mapping every valid subsequence to its corresponding token index using a token dictionary. If a subsequence is missing in the dictionary, a default value of 0 is used.


In [490]:
def encode_sequence(seq, token_dict, min_subseq_len=2):
    """
    Encodes a sequence (drug SMILES or protein sequence) by scanning all consecutive subsequences
    of length >= min_subseq_len and mapping them to their token indices (or 0 if not found).
    """
    return [token_dict.get(seq[i:j], 0)
            for i in range(len(seq))
            for j in range(i + min_subseq_len, len(seq) + 1)]


In [491]:
drug_df = pd.read_csv(DRUGBANK_CSV)             # Must contain columns "Drug id" and "smiles"
protein_df = pd.read_csv(PROTEIN_CSV)  # Must contain columns "pbd id" and "sequence"

# Load token dictionaries for drugs and proteins
drug_dict_df = pd.read_csv(DRUG_FCS_CSV)    # Columns: "Subsequence", "Rank"
protein_dict_df = pd.read_csv(PROTEIN_FCS_CSV)       # Columns: "Subsequence", "Rank"


In [492]:
import os
import pandas as pd

# Only run encoding if either output file doesn't exist.
if not os.path.exists(DRUG_ENCODED_CSV) or not os.path.exists(PROTEIN_ENCODED_CSV):
    # Create Python dictionaries for mapping
    drug_token_dict = dict(zip(drug_dict_df["Subsequence"], drug_dict_df["Rank"]))
    protein_token_dict = dict(zip(protein_dict_df["Subsequence"], protein_dict_df["Rank"]))

    # Encode drug SMILES and protein sequences separately.
    # You can adjust min_subseq_len if needed.
    drug_df["Encoded"] = drug_df["smiles"].apply(lambda x: encode_sequence(x, drug_token_dict, min_subseq_len=2))
    protein_df["Encoded"] = protein_df["sequence"].apply(lambda x: encode_sequence(x, protein_token_dict, min_subseq_len=2))

    # print("Encoded drug SMILES:")
    # print(drug_df[["Drug id", "Encoded"]].head())
    # print("Encoded protein sequences:")
    # print(protein_df[["pdb_id", "Encoded"]].head())

    # Save the intermediate results to CSV for further inspection.
    drug_df.to_csv(DRUG_ENCODED_CSV, index=False)
    protein_df.to_csv(PROTEIN_ENCODED_CSV, index=False)
else:
    print(f"{DRUG_ENCODED_CSV} and {PROTEIN_ENCODED_CSV} already exist. Skipping encoding step.")

drug_encoded.csv and protein_encoded.csv already exist. Skipping encoding step.


TRIAL!!!

In [493]:
# Load target_labels (binary matrix with no headers)
target_labels = pd.read_csv(TARGET_LABELS_CSV, header=None)

  target_labels = pd.read_csv(TARGET_LABELS_CSV, header=None)


In [494]:
# Extract drug IDs (first column) and protein IDs (first row)
drug_ids = target_labels.iloc[1:, 0].values  # Skip the first row
protein_ids = target_labels.iloc[0, 1:].values  # Skip the first column

In [495]:
# Convert the binary matrix into a DataFrame with proper row/col names
interaction_matrix = target_labels.iloc[1:, 1:].astype(int).values
interaction_df = pd.DataFrame(interaction_matrix, index=drug_ids, columns=protein_ids)

RANDOM SAMPLING

In [496]:
# Get all (drug, protein) pairs where interaction = 1 (positive samples)
positive_pairs = np.argwhere(interaction_matrix == 1)
negative_pairs = np.argwhere(interaction_matrix == 0)

In [497]:
positive_samples = [(str(drug_ids[i]), str(protein_ids[j])) for i, j in positive_pairs]
negative_samples = [(str(drug_ids[i]), str(protein_ids[j])) for i, j in negative_pairs]

In [498]:
# Randomly sample 500 positives and 500 negatives
positive_sampled = np.random.choice(len(positive_samples), 500, replace=False)
negative_sampled = np.random.choice(len(negative_samples), 500, replace=False)

In [499]:
positive_samples = [positive_samples[i] for i in positive_sampled]
negative_samples = [negative_samples[i] for i in negative_sampled]

In [500]:
# Sample an equal number of negative samples as positives
#num_pos = len(positive_samples)
#negative_indices = np.random.choice(len(negative_samples), num_pos, replace=False)
#negative_samples = [negative_samples[i] for i in negative_indices]  # Extract (Drug_ID, Protein_ID) pairs properly

In [501]:
df_selected = pd.DataFrame(positive_samples + negative_samples, columns=["Drug id", "pdb_id"])
df_selected["Interaction"] = [1] * 500 + [0] * 500  # Labels

In [502]:
# --- Step 1: Combine Encoded Sequences into BERT Input Format ---
# Assume you have already encoded sequences stored in separate CSV files.
# Here, we load them. We assume that the encoding is stored as a JSON-encoded list in a column.
# If they are stored as space-separated strings, adjust the parsing accordingly.

drug_encoded_df = pd.read_csv(DRUG_ENCODED_CSV)       # Columns: "Drug id", "Encoded"
protein_encoded_df = pd.read_csv(PROTEIN_ENCODED_CSV)   # Columns: "pbd id", "Encoded"

In [503]:
print("Encoded Proteins:\n",protein_encoded_df[["pdb_id", "Encoded"]].head())
print("Encoded Drugs:\n",drug_encoded_df[["Drug id","Encoded"]].head())

Encoded Proteins:
   pdb_id                                            Encoded
0   1NSI  [289, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
1   1DJL  [307, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
2   1AB2  [45, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
3   4BSJ  [292, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
4   1FLT  [21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
Encoded Drugs:
    Drug id                                            Encoded
0  DB00131  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
1  DB00140  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
2  DB00148  [78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
3  DB00159  [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
4  DB00182  [44, 47, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...


In [504]:
# Convert the encoded column from string to list. Here we assume it's stored as JSON.
drug_encoded_df["Encoded"] = drug_encoded_df["Encoded"].apply(json.loads)
protein_encoded_df["Encoded"] = protein_encoded_df["Encoded"].apply(json.loads)

In [505]:
df_filtered = df_selected[
    df_selected["Drug id"].isin(drug_encoded_df["Drug id"]) &
    df_selected["pdb_id"].isin(protein_encoded_df["pdb_id"])
]
print(f"Filtered df_selected shape: {df_filtered.shape}")  # Should be much smaller

Filtered df_selected shape: (327, 3)


In [506]:
from collections import Counter
print("Before merging, df_filtered labels:", Counter(df_filtered["Interaction"]))

Before merging, df_filtered labels: Counter({1: 252, 0: 75})


In [507]:
from collections import Counter

print("Original selected pairs:", Counter(df_selected["Interaction"]))

Original selected pairs: Counter({1: 500, 0: 500})


In [508]:
print("After filtering df_selected:", Counter(df_filtered["Interaction"]))

After filtering df_selected: Counter({1: 252, 0: 75})


In [509]:
merged_df = (
    df_filtered
    .merge(drug_encoded_df, on="Drug id", how="inner")
    .merge(protein_encoded_df, on="pdb_id", how="inner")
)

In [510]:
# Save merged dataset
merged_df.to_csv(MERGED_ENCODING_CSV, index=False)

In [511]:
print(f"Final merged dataset saved as {MERGED_ENCODING_CSV}.")
print(merged_df["Interaction"].value_counts())  # Should show counts for both 1 and 0
print(merged_df.head())

Final merged dataset saved as merged_encodings.csv.
Interaction
1    252
0     75
Name: count, dtype: int64
   Drug id pdb_id  Interaction  \
0  DB06589   1GQ5            1   
1  DB12147   1AGW            1   
2  DB01331   2UWX            1   
3  DB00308   1T0J            1   
4  DB12598   1SPJ            1   

                                              smiles  \
0  Cc1ccc(Nc2nccc(N(C)c3ccc4c(C)n(C)nc4c3)n2)cc1S...   
1  COc1cc(OC)cc(N(CCNC(C)C)c2ccc3ncc(-c4cnn(C)c4)...   
2  CO[C@@]1(NC(=O)Cc2cccs2)C(=O)N2C(C(=O)O)=C(COC...   
3       CCCCCCCN(CC)CCC[C@H](O)c1ccc(NS(C)(=O)=O)cc1   
4      N=C(N)c1ccc2cc(OC(=O)c3ccc(N=C(N)N)cc3)ccc2c1   

                                           Encoded_x  \
0  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...   
1  [81, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   
2  [81, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   
3  [3, 50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   
4  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...   

             

In [512]:
print("Before filtering:", df_selected["Interaction"].value_counts())
print("After filtering:", df_filtered["Interaction"].value_counts())

Before filtering: Interaction
1    500
0    500
Name: count, dtype: int64
After filtering: Interaction
1    252
0     75
Name: count, dtype: int64


In [513]:
print("Before merging:", df_filtered["Interaction"].value_counts())
print("After merging:", merged_df["Interaction"].value_counts())

Before merging: Interaction
1    252
0     75
Name: count, dtype: int64
After merging: Interaction
1    252
0     75
Name: count, dtype: int64


In [514]:
print(merged_df.head())

   Drug id pdb_id  Interaction  \
0  DB06589   1GQ5            1   
1  DB12147   1AGW            1   
2  DB01331   2UWX            1   
3  DB00308   1T0J            1   
4  DB12598   1SPJ            1   

                                              smiles  \
0  Cc1ccc(Nc2nccc(N(C)c3ccc4c(C)n(C)nc4c3)n2)cc1S...   
1  COc1cc(OC)cc(N(CCNC(C)C)c2ccc3ncc(-c4cnn(C)c4)...   
2  CO[C@@]1(NC(=O)Cc2cccs2)C(=O)N2C(C(=O)O)=C(COC...   
3       CCCCCCCN(CC)CCC[C@H](O)c1ccc(NS(C)(=O)=O)cc1   
4      N=C(N)c1ccc2cc(OC(=O)c3ccc(N=C(N)N)cc3)ccc2c1   

                                           Encoded_x  \
0  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...   
1  [81, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   
2  [81, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   
3  [3, 50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   
4  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...   

                                            sequence  \
0  GMLPRLCCLEKGPNGYGFHLHGEKGKLGQYIRLVEPGSPAEKAGLL...   
1  ELPEDP

In [515]:
print("Unique labels:", merged_df["Interaction"].unique())

Unique labels: [1 0]


In [516]:
# Save the final merged file
if not os.path.exists(MERGED_ENCODING_CSV):
    merged_df.to_csv(MERGED_ENCODING_CSV, index=False)
    print(f"File {MERGED_ENCODING_CSV} created.")
else:
    print(f"File {MERGED_ENCODING_CSV} already exists. Skipping save.")

File merged_encodings.csv already exists. Skipping save.


In [517]:
# Define special token IDs
CLS_TOKEN = "[CLS]"
SEP_TOKEN = "[SEP]"

In [518]:
token = "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"
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", use_auth_token=token)



In [519]:
# Initialize DataCollator for masking
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)

In [520]:
def create_bert_input(drug_tokens, protein_tokens):
    """
    Constructs the final BERT input sequence as:
      [CLS] + drug_tokens + [SEP] + protein_tokens + [SEP]
    """
    drug_str = " ".join(map(str, drug_tokens))
    protein_str = " ".join(map(str, protein_tokens))
    return f"{CLS_TOKEN} {drug_str} {SEP_TOKEN} {protein_str} {SEP_TOKEN}"

In [521]:
def process_chunk(chunk):
    # Convert encoded columns from JSON strings to lists if necessary
    chunk["Encoded_x"] = chunk["Encoded_x"].apply(lambda x: json.loads(x) if isinstance(x, str) else x)
    chunk["Encoded_y"] = chunk["Encoded_y"].apply(lambda x: json.loads(x) if isinstance(x, str) else x)

    # Create BERT input sequences: [CLS] drug_tokens [SEP] protein_tokens [SEP]
    chunk["BERT_Input"] = chunk.apply(lambda row: create_bert_input(row["Encoded_x"], row["Encoded_y"]), axis=1)

    # Tokenize the input sequences with a maximum length of 512 tokens
    tokenized = tokenizer(
        chunk["BERT_Input"].tolist(),
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt"
    )

    # Convert the tokenized tensor to a list of dicts so that each element has an "input_ids" key.
    features = [{"input_ids": ids} for ids in tokenized["input_ids"].tolist()]

    # Apply masking using the data collator.
    masked = data_collator(features)

    # Convert the masked output (a tensor) back to a space-separated string for storage.
    chunk["Masked_Input"] = [" ".join(map(str, seq)) for seq in masked["input_ids"].tolist()]

    return chunk[["Masked_Input", "Interaction"]]


In [522]:
import os
import pandas as pd

if not os.path.exists('final_bert_inputs_masked.h5'):
    chunksize = 1000  # Adjust this based on your available memory

    # Process and write output
    with pd.HDFStore('final_bert_inputs_masked.h5', mode='w') as store:
        for i, chunk in enumerate(pd.read_csv("merged_encodings.csv", chunksize=chunksize)):
            processed_chunk = process_chunk(chunk)
            store.append('df', processed_chunk, format='table', data_columns=True)
            print(f"Processed chunk {i+1}")

    print("Final BERT input sequences with masking saved to final_bert_inputs_masked.h5")
else:
    print("File 'final_bert_inputs_masked.h5' already exists. Skipping processing.")


File 'final_bert_inputs_masked.h5' already exists. Skipping processing.


In [523]:
import pandas as pd
from collections import Counter

# Read processed data from HDF5
with pd.HDFStore('final_bert_inputs_masked.h5', mode='r') as store:
    processed_data = store['df']  # Load the dataframe

# Check if Interaction labels exist
print("Processed Data Shape:", processed_data.shape)
print("Processed Data Columns:", processed_data.columns)
print("Label Distribution After Processing:", Counter(processed_data["Interaction"]))


Processed Data Shape: (279, 2)
Processed Data Columns: Index(['Masked_Input', 'Interaction'], dtype='object')
Label Distribution After Processing: Counter({1: 217, 0: 62})


###Model Training
Training the BERT model by iterating over the DataLoader for a set number of epochs.
Before training we are optimizing the input texts and labels to pass through Data Loader


In [524]:
%pip install h5py

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 23.2.1 -> 25.0.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [525]:
import h5py

In [526]:
def print_hdf5_structure(obj, indent=0):
    for key in obj.keys():
        item = obj[key]
        print("  " * indent + f"{key}: {type(item)}")
        if isinstance(item, h5py.Group):
            print_hdf5_structure(item, indent+1)

with h5py.File('final_bert_inputs_masked.h5', 'r') as f:
    print("HDF5 file structure:")
    print_hdf5_structure(f)

HDF5 file structure:
df: <class 'h5py._hl.group.Group'>
  _i_table: <class 'h5py._hl.group.Group'>
    Interaction: <class 'h5py._hl.group.Group'>
      abounds: <class 'h5py._hl.dataset.Dataset'>
      bounds: <class 'h5py._hl.dataset.Dataset'>
      indices: <class 'h5py._hl.dataset.Dataset'>
      indicesLR: <class 'h5py._hl.dataset.Dataset'>
      mbounds: <class 'h5py._hl.dataset.Dataset'>
      mranges: <class 'h5py._hl.dataset.Dataset'>
      ranges: <class 'h5py._hl.dataset.Dataset'>
      sorted: <class 'h5py._hl.dataset.Dataset'>
      sortedLR: <class 'h5py._hl.dataset.Dataset'>
      zbounds: <class 'h5py._hl.dataset.Dataset'>
    Masked_Input: <class 'h5py._hl.group.Group'>
      abounds: <class 'h5py._hl.dataset.Dataset'>
      bounds: <class 'h5py._hl.dataset.Dataset'>
      indices: <class 'h5py._hl.dataset.Dataset'>
      indicesLR: <class 'h5py._hl.dataset.Dataset'>
      mbounds: <class 'h5py._hl.dataset.Dataset'>
      mranges: <class 'h5py._hl.dataset.Dataset'>
   

In [527]:
with h5py.File('final_bert_inputs_masked.h5', 'r') as f:
    # Access the final table dataset
    dataset = f["df"]["table"]

    # Print the first 10 entries
    print("First 10 entries:")
    print(dataset[:10])


First 10 entries:
[(0, b'101 101 1017 1014 1014 1014 1014 1014 1014 1014 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 19628 1014 1014 1014 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 103 1014 1014 103 1014 1014 103 1014 1014 1014 1014 1014 6387 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 103 1014 103 103 103 1014 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 103 1014 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 103 1014 1014 1014 1014 1014 1014 103 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 16232 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 103 1014 1014 1014 103 1014 103 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 10

In [528]:
with h5py.File('final_bert_inputs_masked.h5', 'r') as f:
    dataset = f["df"]["table"]
    num_entries = dataset.shape[0]
    print("Number of entries:", num_entries)

Number of entries: 279


In [529]:
# --- Step 1: Load Masked Inputs from HDF5 and Pair IDs from CSV ---
masked_df = pd.read_hdf("final_bert_inputs_masked.h5", key="df/table")
pairs_df = pd.read_csv("merged_encodings.csv")  # Should include columns "Drug id" and "pbd id"

In [530]:
# Assign the masked inputs (assumed column name "Masked_Input") from masked_df to pairs_df
pairs_df["Masked_Input"] = masked_df["Masked_Input"]
print(pairs_df.head())

   Drug id pdb_id  Interaction  \
0  DB06589   1GQ5            1   
1  DB12147   1AGW            1   
2  DB01331   2UWX            1   
3  DB00308   1T0J            1   
4  DB12598   1SPJ            1   

                                              smiles  \
0  Cc1ccc(Nc2nccc(N(C)c3ccc4c(C)n(C)nc4c3)n2)cc1S...   
1  COc1cc(OC)cc(N(CCNC(C)C)c2ccc3ncc(-c4cnn(C)c4)...   
2  CO[C@@]1(NC(=O)Cc2cccs2)C(=O)N2C(C(=O)O)=C(COC...   
3       CCCCCCCN(CC)CCC[C@H](O)c1ccc(NS(C)(=O)=O)cc1   
4      N=C(N)c1ccc2cc(OC(=O)c3ccc(N=C(N)N)cc3)ccc2c1   

                                           Encoded_x  \
0  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...   
1  [81, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   
2  [81, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   
3  [3, 50, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   
4  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...   

                                            sequence  \
0  GMLPRLCCLEKGPNGYGFHLHGEKGKLGQYIRLVEPGSPAEKAGLL...   
1  ELPEDP

In [531]:
# --- Step 2: Load the Label Matrix ---
# Ensure that the CSV file now has "Drug id" as header in the first column
label_matrix = pd.read_csv(TARGET_LABELS_CSV)
print(label_matrix.head())

   Drug id  1EVU  1NSI  1DJL  1AB2  1ALS  1CFG  1EG0  1OZ5  4BSJ  ...  1TVB  \
0  DB11300     1     0     0     0     0     1     0     0     0  ...     0   
1  DB11311     1     0     0     0     0     0     0     0     0  ...     0   
2  DB11571     1     0     0     0     0     1     0     0     0  ...     0   
3  DB13151     1     0     0     0     0     1     0     0     0  ...     0   
4  DB05383     0     1     0     0     0     0     0     0     0  ...     0   

   1T5Q  6U6U  7RY7  2MDP  1G2C  4BPU  2X18  2N80  2KR6  
0     0     0     0     0     0     0     0     0     0  
1     0     0     0     0     0     0     0     0     0  
2     0     0     0     0     0     0     0     0     0  
3     0     0     0     0     0     0     0     0     0  
4     0     0     0     0     0     0     0     0     0  

[5 rows x 721 columns]


In [532]:
from torch.utils.data import Dataset
import torch
import ast

class DtiDataset(Dataset):
    def __init__(self, tokenized_texts, labels, max_length=128):
        self.tokenized_texts = tokenized_texts
        self.labels = labels
        self.max_length = max_length

    def __len__(self):
        return len(self.tokenized_texts)

    def __getitem__(self, idx):
        encoding = ast.literal_eval(self.tokenized_texts[idx])  # Convert string to dictionary
        item = {key: torch.tensor(val, dtype=torch.long) for key, val in encoding.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)  # Ensure label is tensor
        return item


#### Load BERT Model and Optimizer
Initialize the BERT tokenizer and the `BertForSequenceClassification` model for predicting drug-target interactions (using 2 labels for single classification). Set up the AdamW optimizer with a learning rate of 2e-5 to update model weights during training.

In [533]:
from transformers import BertForSequenceClassification
from torch.optim import AdamW

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2, use_auth_token="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")
optimizer = AdamW(model.parameters(), lr=2e-5)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


####Create Dataset and DataLoader
Extract the masked sequences and map each drug ID to its corresponding binary label (derived from the label matrix). Use these to create a PyTorch DataLoader, which facilitates efficient batch processing during model training.


In [534]:
# --- Step 5: Create Dataset and DataLoader ---
# Extract texts from the pairs DataFrame:
texts = pairs_df["Masked_Input"].tolist()
print(label_matrix.head())
# For each drug id in pairs_df, get its label vector from the label_matrix.
# Here we assume the label_matrix’s index matches the "Drug id" values.

# Convert a multi-label vector into a single label (if that is your intended approach)
def convert_to_binary(label_vector):
    return 1 if any(label_vector) else 0

if label_matrix.index.name != "Drug id":
    label_matrix = label_matrix.set_index("Drug id")

labels = pairs_df["Drug id"].apply(
    lambda drug: convert_to_binary(label_matrix.loc[drug].values.tolist())
).tolist()
# print(label_matrix.head())
#labels = pairs_df["Drug id"].apply(lambda drug: label_matrix.loc[drug].values.tolist()).tolist()

# Now create the dataset with the extracted lists
dataset = DtiDataset(texts, labels, tokenizer)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)


   Drug id  1EVU  1NSI  1DJL  1AB2  1ALS  1CFG  1EG0  1OZ5  4BSJ  ...  1TVB  \
0  DB11300     1     0     0     0     0     1     0     0     0  ...     0   
1  DB11311     1     0     0     0     0     0     0     0     0  ...     0   
2  DB11571     1     0     0     0     0     1     0     0     0  ...     0   
3  DB13151     1     0     0     0     0     1     0     0     0  ...     0   
4  DB05383     0     1     0     0     0     0     0     0     0  ...     0   

   1T5Q  6U6U  7RY7  2MDP  1G2C  4BPU  2X18  2N80  2KR6  
0     0     0     0     0     0     0     0     0     0  
1     0     0     0     0     0     0     0     0     0  
2     0     0     0     0     0     0     0     0     0  
3     0     0     0     0     0     0     0     0     0  
4     0     0     0     0     0     0     0     0     0  

[5 rows x 721 columns]


In [535]:
from sklearn.model_selection import train_test_split
import contextlib
from torch.cuda import amp

In [536]:
from collections import Counter

print("Merged dataset labels:", Counter(labels))  # Check overall dataset

Merged dataset labels: Counter({1: 327})


In [537]:
from collections import Counter

print("Before splitting:", Counter(processed_data["Interaction"]))


Before splitting: Counter({1: 217, 0: 62})


In [538]:
# Load processed tokenized data
with pd.HDFStore('final_bert_inputs_masked.h5', mode='r') as store:
    processed_data = store['df']  # Assuming 'df' is the key used when saving

print("Processed Data Shape:", processed_data.shape)
print("Processed Data Columns:", processed_data.columns)

# Extract tokenized inputs and labels
tokenized_inputs = processed_data['Masked_Input'].tolist()  # Pre-tokenized inputs
labels = processed_data['Interaction'].tolist()  # Labels (0 or 1)

Processed Data Shape: (279, 2)
Processed Data Columns: Index(['Masked_Input', 'Interaction'], dtype='object')


In [539]:
print(tokenized_inputs[:5])

['101 101 1017 1014 1014 1014 1014 1014 1014 1014 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 19628 1014 1014 1014 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 103 1014 1014 103 1014 1014 103 1014 1014 1014 1014 1014 6387 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 103 1014 103 103 103 1014 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 103 1014 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 103 1014 1014 1014 1014 1014 1014 103 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 16232 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 103 1014 1014 1014 103 1014 103 103 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 1014 

In [540]:
from sklearn.model_selection import train_test_split

train_inputs, test_inputs, train_labels, test_labels = train_test_split(
    tokenized_inputs, labels, test_size=0.2, random_state=42, stratify=labels
)

print(f"Train dataset size: {len(train_inputs)}")
print(f"Test dataset size: {len(test_inputs)}")

Train dataset size: 223
Test dataset size: 56


In [541]:
print("Train Labels Distribution:", Counter(train_labels))
print("Test Labels Distribution:", Counter(test_labels))

Train Labels Distribution: Counter({1: 173, 0: 50})
Test Labels Distribution: Counter({1: 44, 0: 12})


In [542]:
print(pd.Series(train_labels).value_counts())
print(pd.Series(test_labels).value_counts())

1    173
0     50
Name: count, dtype: int64
1    44
0    12
Name: count, dtype: int64


In [543]:
# Create dataset instances using pre-tokenized data
train_dataset = DtiDataset(train_inputs, train_labels)
test_dataset = DtiDataset(test_inputs, test_labels)

from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of testing samples: {len(test_dataset)}")


Number of training samples: 223
Number of testing samples: 56


In [544]:
# -------------------------
# Model Initialization and Optimizer
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
model.to(device)

optimizer = AdamW(model.parameters(), lr=2e-5)

Using device: cpu


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [545]:
# -------------------------
# Optional: Set up Mixed Precision Training (if using GPU)
# -------------------------
use_amp = torch.cuda.is_available()
scaler = amp.GradScaler() if use_amp else None

In [546]:
# -------------------------
# Training Loop with Progress Bar
# -------------------------
num_epochs = 1
model.train()

for epoch in range(num_epochs):
    epoch_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")
    for batch in progress_bar:
    # Move batch to the appropriate device
        batch = {k: v.to(device) for k, v in batch.items() if k in ["input_ids", "attention_mask", "labels"]}  # Exclude 'token_type_ids'
    
        optimizer.zero_grad()

        # Mixed precision training if available
        with amp.autocast() if use_amp else contextlib.nullcontext():
            outputs = model(**batch)  # Model expects only 'input_ids' and 'attention_mask'
            loss = outputs.loss


        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()

        epoch_loss += loss.item()
        progress_bar.set_postfix(loss=f"{loss.item():.4f}")
    avg_loss = epoch_loss / len(train_loader)
    print(f"Epoch {epoch+1} finished, average loss: {avg_loss:.4f}")

# -------------------------
# Evaluation Loop with Progress Bar
# -------------------------
model.eval()
correct = total = 0
with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating"):
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        preds = outputs.logits.argmax(dim=1)
        correct += (preds == batch['labels']).sum().item()
        total += len(batch['labels'])
accuracy = correct / total if total > 0 else 0
print(f"Test Accuracy: {accuracy:.4f}")

Epoch 1/1:   0%|          | 0/112 [00:00<?, ?it/s]


SyntaxError: invalid syntax (<unknown>, line 1)

In [None]:
# from tqdm import tqdm
# model.train()
# for epoch in range(1):  # Example: 3 epochs
#     for batch in tqdm(train_loader):
#         outputs = model(**batch)
#         loss = outputs.loss
#         loss.backward()
#         optimizer.step()
#         optimizer.zero_grad()

# print("\nTraining complete!")

### Model Evalutation
Due to small dataset and overfitting, accuracy is 1.

In [None]:
# from tqdm import tqdm
# model.eval()
# correct = total = 0
# with torch.no_grad():
#     for batch in tqdm(train_loader, desc="Evaluating"):
#         outputs = model(**batch)
#         preds = outputs.logits.argmax(dim=1)
#         correct += (preds == batch['labels']).sum().item()
#         total += len(batch['labels'])
# accuracy = correct / total
# print(f"Accuracy: {accuracy:.4f}")



###Prediction Function
Takes a drug ID and a PDB ID as inputs, constructs the BERT input sequence, tokenizes it, and uses the trained model to predict whether an interaction exists. The function returns "yes" or "no" based on the model's output.

In [None]:
def predict_interaction(drug_id, pdb_id, tokenizer=tokenizer, model=model):
     input_str = f"[CLS] {drug_id} [SEP] {pdb_id} [SEP]"
     inputs = tokenizer(input_str, return_tensors='pt', truncation=True, max_length=512)
     outputs = model(**inputs)
     prediction = outputs.logits.argmax(dim=1).item()
     return "yes" if prediction == 1 else "no"
predict_interaction('DB11300','1NSI')

'yes'