# BERT Pipeline for Drug-Target Interaction (DTI) Prediction

This notebook implements an end-to-end pipeline for predicting drug-target interactions using a BERT-based model. The workflow includes:

- *Data Preprocessing:* Cleaning, merging, and filtering raw CSV datasets.
- *Frequent Consecutive Subsequence (FCS) Extraction:* Extracting and ranking meaningful subsequences from drug SMILES and protein sequences.
- *Sequence Encoding:* Converting sequences into token indices using generated token dictionaries.
- *BERT Input Construction and Masking:* Creating unified BERT inputs with special tokens and applying dynamic masking.
- *Model Training and Evaluation:* Training a BERT-based classifier and evaluating its performance using ROC-AUC.


In [137]:
%pip install datasets

Note: you may need to restart the kernel to use updated packages.


##Importing required Libraries

In [138]:
import pandas as pd
from collections import Counter
import re
from transformers import BertTokenizer, BertForMaskedLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from torch.utils.data import DataLoader
# from datasets import load_dataset, Dataset
import torch
import random
import json
from tqdm import tqdm # For processing bars
import os

In [139]:
#Define file paths
DRUGBANK_CSV = "drug_smiles.csv"
PROTEIN_CSV = "pdb_sequences.csv"
INTERACTIONS_CSV = "confirmed_interactions.csv"
DRUG_ENCODED_CSV = "drug_encoded.csv"
PROTEIN_ENCODED_CSV = "protein_encoded.csv"
DRUG_FCS_CSV = "drug_smiles_fcs_freq_100.csv"
PROTEIN_FCS_CSV = "protein_fcs_freq_100.csv"
MERGED_ENCODING_CSV = "merged_encodings.csv"

# Frequent Consecutive Subsequence (FCS) Extraction

This section extracts all consecutive subsequences (using a sliding window) from the input strings (drug SMILES and protein sequences). The extraction function also prunes low-frequency subsequences periodically.


In [140]:
def extract_fcs_subsequences_stream(strings, min_length=2, max_length=None,
                                      prune_every=1000, prune_threshold=5,
                                      preserve_short_length=5):
    """
    Extract frequent consecutive subsequences from a list of strings in a streaming fashion.
    Uses different pruning criteria: always keep subsequences shorter than `preserve_short_length`
    while pruning longer ones based on prune_threshold.

    :param strings: Iterable of strings (e.g., drug SMILES or protein sequences).
    :param min_length: Minimum subsequence length.
    :param max_length: Maximum subsequence length to consider (if None, use full length available).
    :param prune_every: After processing this many sequences, prune the counter.
    :param prune_threshold: For subsequences of length >= preserve_short_length,
                            remove those with counts lower than this threshold during pruning.
    :param preserve_short_length: Subsequences with length less than this value will be preserved regardless.
    :return: A Counter mapping each subsequence to its frequency.
    """
    subseq_counter = Counter()

    for idx, s in enumerate(strings, 1):
        n = len(s)
        for i in range(n):
            current_max = n - i if max_length is None else min(max_length, n - i)
            for l in range(min_length, current_max + 1):
                subseq = s[i:i+l]
                subseq_counter[subseq] += 1

        # Prune periodically to free memory:
        if idx % prune_every == 0:
            # For short subsequences (length < preserve_short_length), keep them always;
            # for longer ones, only keep if count >= prune_threshold.
            subseq_counter = Counter({
                k: v for k, v in subseq_counter.items()
                if len(k) < preserve_short_length or v >= prune_threshold
            })
            print(f"Processed {idx} sequences, counter pruned to {len(subseq_counter)} keys.")

    return subseq_counter

# Filtering and Ranking Subsequences

We filter the extracted subsequences by applying a minimum frequency threshold and then rank them. The result is saved into a DataFrame containing each subsequence, its frequency, and rank.


In [141]:
def filter_and_rank_fcs(fcs_counts, min_frequency=10):
    """
    Filter and rank subsequences that appear with frequency >= min_frequency.

    :param fcs_counts: Counter mapping subsequences to frequency.
    :param min_frequency: Minimum frequency for a subsequence to be included.
    :return: A pandas DataFrame with columns [Subsequence, Frequency, Rank].
    """
    filtered_items = [(subseq, freq) for subseq, freq in fcs_counts.items() if freq >= min_frequency]
    filtered_items.sort(key=lambda x: x[1], reverse=True)

    data = []
    for rank, (subseq, freq) in enumerate(filtered_items, start=1):
        data.append({"Subsequence": subseq, "Frequency": freq, "Rank": rank})

    return pd.DataFrame(data)

In [143]:
drugbank_csv = DRUGBANK_CSV          # CSV file for drug SMILES
smiles_column = "smiles"                  # Column name in drugbank.csv

protein_csv = PROTEIN_CSV         # CSV file for protein sequences
protein_column = "sequence"               # Column name in protein_sequences.csv

# Parameters for subsequence extraction
min_length = 2              # Generate subsequences of length 2 or more
max_length = 10             # Maximum subsequence length (adjust as needed)
min_freq_threshold = 5    # Final frequency threshold for inclusion

# Pruning parameters:
prune_every = 1000          # Prune every 1000 sequences processed
prune_threshold = 5         # For subsequences with length >= preserve_short_length, prune if count < 5
preserve_short_length = 5   # Always preserve subsequences with length < 5

# --- PROCESS DRUG SMILES ---
drug_output_file = DRUG_FCS_CSV #path defined
if not os.path.exists(drug_output_file):
    df_drug = pd.read_csv(drugbank_csv)
    if smiles_column not in df_drug.columns:
        raise ValueError(f"No '{smiles_column}' column found in {drugbank_csv}!")

    smiles_list = df_drug[smiles_column].dropna().astype(str).tolist()
    print("Extracting subsequences from drug SMILES...")

    smiles_fcs_counts = extract_fcs_subsequences_stream(
        smiles_list,
        min_length=min_length,
        max_length=max_length,
        prune_every=prune_every,
        prune_threshold=prune_threshold,
        preserve_short_length=preserve_short_length
    )
    smiles_fcs_df = filter_and_rank_fcs(smiles_fcs_counts, min_frequency=min_freq_threshold)
    smiles_fcs_df.to_csv(drug_output_file, index=False)

    print(f"[DRUG SMILES] Total subsequences (frequency ≥ {min_freq_threshold}): {len(smiles_fcs_df)}")
    print(smiles_fcs_df.head(10))
else:
    print(f"{drug_output_file} already exists. Skipping drug SMILES subsequence extraction.")

# --- PROCESS PROTEIN SEQUENCES ---
protein_output_file = PROTEIN_FCS_CSV #path defined
if not os.path.exists(protein_output_file):
    df_protein = pd.read_csv(protein_csv)
    if protein_column not in df_protein.columns:
        raise ValueError(f"No '{protein_column}' column found in {protein_csv}!")

    protein_list = df_protein[protein_column].dropna().astype(str).tolist()
    print("Extracting subsequences from protein sequences...")

    protein_fcs_counts = extract_fcs_subsequences_stream(
        protein_list,
        min_length=min_length,
        max_length=max_length,
        prune_every=prune_every,
        prune_threshold=prune_threshold,
        preserve_short_length=preserve_short_length
    )
    protein_fcs_df = filter_and_rank_fcs(protein_fcs_counts, min_frequency=min_freq_threshold)
    protein_fcs_df.to_csv(protein_output_file, index=False)

    print(f"[PROTEIN SEQUENCES] Total subsequences (frequency ≥ {min_freq_threshold}): {len(protein_fcs_df)}")
    print(protein_fcs_df.head(10))
else:
    print(f"{protein_output_file} already exists. Skipping protein sequences subsequence extraction.")


Extracting subsequences from drug SMILES...
[DRUG SMILES] Total subsequences (frequency ≥ 5): 4229
  Subsequence  Frequency  Rank
0          cc        997     1
1          O)        640     2
2          CC        635     3
3          C(        538     4
4          [C        516     5
5         [C@        516     6
6          C@        516     7
7         ccc        510     8
8          H]        469     9
9          (C        469    10
Extracting subsequences from protein sequences...
[PROTEIN SEQUENCES] Total subsequences (frequency ≥ 5): 46284
  Subsequence  Frequency  Rank
0          LL       1645     1
1          LK       1247     2
2          LA       1242     3
3          AL       1192     4
4          AA       1181     5
5          GL       1148     6
6          DL       1117     7
7          LV       1109     8
8          EL       1086     9
9          LG       1078    10


# Sequence Encoding

In this section, we encode each sequence (drug SMILES or protein sequence) by mapping every valid subsequence to its corresponding token index using a token dictionary. If a subsequence is missing in the dictionary, a default value of 0 is used.


In [86]:
def encode_sequence(seq, token_dict, min_subseq_len=2):
    """
    Encodes a sequence (drug SMILES or protein sequence) by scanning all consecutive subsequences
    of length >= min_subseq_len and mapping them to their token indices (or 0 if not found).
    """
    return [token_dict.get(seq[i:j], 0)
            for i in range(len(seq))
            for j in range(i + min_subseq_len, len(seq) + 1)]


In [87]:
drug_df = pd.read_csv(DRUGBANK_CSV)             # Must contain columns "Drug id" and "smiles"
protein_df = pd.read_csv(PROTEIN_CSV)           # Must contain columns "pbd id" and "sequence"

# Load token dictionaries for drugs and proteins
drug_dict_df = pd.read_csv(DRUG_FCS_CSV)    # Columns: "Subsequence", "Rank"
protein_dict_df = pd.read_csv(PROTEIN_FCS_CSV)       # Columns: "Subsequence", "Rank"


In [88]:
import os
import pandas as pd

# Only run encoding if either output file doesn't exist.
if not os.path.exists(DRUG_ENCODED_CSV) or not os.path.exists(PROTEIN_ENCODED_CSV):
    # Create Python dictionaries for mapping
    drug_token_dict = dict(zip(drug_dict_df["Subsequence"], drug_dict_df["Rank"]))
    protein_token_dict = dict(zip(protein_dict_df["Subsequence"], protein_dict_df["Rank"]))

    # Encode drug SMILES and protein sequences separately.
    # You can adjust min_subseq_len if needed.
    drug_df["Encoded"] = drug_df["smiles"].apply(lambda x: encode_sequence(x, drug_token_dict, min_subseq_len=2))
    protein_df["Encoded"] = protein_df["sequence"].apply(lambda x: encode_sequence(x, protein_token_dict, min_subseq_len=2))

    # print("Encoded drug SMILES:")
    # print(drug_df[["Drug id", "Encoded"]].head())
    # print("Encoded protein sequences:")
    # print(protein_df[["pdb_id", "Encoded"]].head())

    # Save the intermediate results to CSV for further inspection.
    drug_df.to_csv(DRUG_ENCODED_CSV, index=False)
    protein_df.to_csv(PROTEIN_ENCODED_CSV, index=False)
else:
    print(f"{DRUG_ENCODED_CSV} and {PROTEIN_ENCODED_CSV} already exist. Skipping encoding step.")

drug_encoded.csv and protein_encoded.csv already exist. Skipping encoding step.


In [89]:
# --- Step 1: Combine Encoded Sequences into BERT Input Format ---
# Assume you have already encoded sequences stored in separate CSV files.
# Here, we load them. We assume that the encoding is stored as a JSON-encoded list in a column.
# If they are stored as space-separated strings, adjust the parsing accordingly.

drug_encoded_df = pd.read_csv(DRUG_ENCODED_CSV)       # Columns: "Drug id", "Encoded"
protein_encoded_df = pd.read_csv(PROTEIN_ENCODED_CSV)   # Columns: "pbd id", "Encoded"

In [90]:
print("Encoded Proteins:\n",protein_encoded_df[["pdb_id", "Encoded"]].head())
print("Encoded Drugs:\n",drug_encoded_df[["Drug id","Encoded"]].head())

Encoded Proteins:
   pdb_id                                            Encoded
0   1NSI  [289, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
1   1DJL  [307, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
2   1AB2  [45, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
3   4BSJ  [292, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0...
4   1FLT  [21, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
Encoded Drugs:
    Drug id                                            Encoded
0  DB00131  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
1  DB00140  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
2  DB00148  [78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...
3  DB00159  [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...
4  DB00182  [44, 47, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...


In [91]:
# Convert the encoded column from string to list. Here we assume it's stored as JSON.
drug_encoded_df["Encoded"] = drug_encoded_df["Encoded"].apply(json.loads)
protein_encoded_df["Encoded"] = protein_encoded_df["Encoded"].apply(json.loads)

In [92]:
pdb_drug_map = pd.read_csv(INTERACTIONS_CSV)
if "Unnamed: 0" in pdb_drug_map.columns:
    pdb_drug_map = pdb_drug_map.drop(columns=["Unnamed: 0"])

# Save the cleaned DataFrame back to the CSV file
pdb_drug_map.to_csv(INTERACTIONS_CSV, index=False)

# Print the first few rows to verify
print(pdb_drug_map.head())

   Drug ID PDB ID
0  DB05383   1NSI
1  DB08814   1NSI
2  DB08814   1MDI
3  DB09092   1DJL
4  DB09092   3ERY


In [93]:
# Split the "Drug IDs" string by "; " (or just ";" if that's what's in your file)
#pdb_drug_map["Drug IDs"] = pdb_drug_map["Drug IDs"].str.split("; ")

In [94]:
# Explode the list so each drug gets its own row
#pdb_drug_map = pdb_drug_map.explode("Drug IDs").reset_index(drop=True)

In [95]:
# Load the files
confirmed_interactions = pd.read_csv(INTERACTIONS_CSV)
drug_encodings = pd.read_csv(DRUG_ENCODED_CSV)
protein_encodings = pd.read_csv(PROTEIN_ENCODED_CSV)

In [96]:
def gen_neg_samples(confirmed_interactions):
    """
    Generate negative samples for drug-target interactions
    
    Args:
        confirmed_interactions: DataFrame with columns ['Drug id', 'pdb_id']
    Returns:
        DataFrame with negative samples
    """
    unique_prots = confirmed_interactions["pdb_id"].unique()
    unique_drugs = confirmed_interactions["Drug id"].unique()
    
    # Convert confirmed interactions to set for faster lookup
    confirmed_pairs = set(zip(confirmed_interactions["Drug id"], 
                            confirmed_interactions["pdb_id"]))
    
    # Number of negative samples to generate (2x positive samples)
    n_samples = len(confirmed_interactions) * 2
    neg_samples = []
    
    with tqdm(total=n_samples, desc="Generating negative samples") as pbar:
        while len(neg_samples) < n_samples:
            drug = random.choice(unique_drugs)
            prot = random.choice(unique_prots)
            
            # Check if this pair is not in confirmed interactions
            if (drug, prot) not in confirmed_pairs:
                neg_samples.append([drug, prot])
                pbar.update(1)
    
    return pd.DataFrame(neg_samples, columns=["Drug id", "pdb_id"])

In [97]:
# Rename columns in confirmed_interactions for consistency (if needed)
confirmed_interactions.rename(columns={'Drug ID': 'Drug id', 'PDB ID': 'pdb_id'}, inplace=True)

In [98]:
neg_inter = gen_neg_samples(confirmed_interactions)
print(neg_inter.head())

Generating negative samples: 100%|██████████| 1936/1936 [00:00<00:00, 559896.06it/s]

   Drug id pdb_id
0  DB01144   1I7G
1  DB00710   1FVR
2  DB09088   3QNZ
3  DB03756   1F0J
4  DB14924   1HZE





In [99]:
confirmed_interactions['label'] = 1
neg_inter['label'] = 0

# Combine positive and negative samples
confirmed_interactions = pd.concat([confirmed_interactions, neg_inter], ignore_index=True)
confirmed_interactions = confirmed_interactions.sample(frac=1, random_state=42).reset_index(drop=True)

confirmed_interactions.head()

Unnamed: 0,Drug id,pdb_id,label
0,DB00786,1BQQ,1
1,DB01044,3FOE,1
2,DB00591,1AII,1
3,DB15493,3P1N,0
4,DB00485,4V11,0


In [100]:
# Merge the drug encodings based on Drug id
merged_df = confirmed_interactions.merge(drug_encodings, on='Drug id', how='inner')

In [101]:
# Merge the resulting DataFrame with protein encodings based on pdb_id
merged_df = merged_df.merge(protein_encodings, on='pdb_id', how='inner')

In [102]:
merged_df.columns

Index(['Drug id', 'pdb_id', 'label', 'smiles', 'Encoded_x', 'sequence',
       'Encoded_y'],
      dtype='object')

In [103]:
# Save the final merged file
if not os.path.exists(MERGED_ENCODING_CSV):
    merged_df.to_csv(MERGED_ENCODING_CSV, index=False)
    print(f"File {MERGED_ENCODING_CSV} created.")
else:
    print(f"File {MERGED_ENCODING_CSV} already exists. Skipping save.")

File merged_encodings.csv already exists. Skipping save.


In [104]:
merged_df = pd.read_csv(MERGED_ENCODING_CSV)

In [105]:
# Display the first 5 lines
print(merged_df.head())

   Drug id pdb_id  label                                             smiles  \
0  DB00786   1BQQ      1  CNC(=O)[C@@H](NC(=O)[C@H](CC(C)C)[C@H](O)C(=O)...   
1  DB01364   1BQQ      0                         CN[C@H](C)[C@H](O)c1ccccc1   
2  DB00968   1BQQ      0                   C[C@@](N)(Cc1ccc(O)c(O)c1)C(=O)O   
3  DB06637   1BQQ      0                                          Nc1ccncc1   
4  DB00786   1L9K      0  CNC(=O)[C@@H](NC(=O)[C@H](CC(C)C)[C@H](O)C(=O)...   

                                           Encoded_x  \
0  [78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   
1  [78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   
2  [44, 47, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...   
3  [0, 0, 0, 0, 0, 0, 0, 0, 17, 49, 55, 0, 0, 0, ...   
4  [78, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,...   

                                            sequence  \
0  IQGLKWQHNEITFCIQNYTPKVGEYATYEAIRKAFRVWESATPLRF...   
1  IQGLKWQHNEITFCIQNYTPKVGEYATYEAIRKAFRVWESATPLRF...   
2  IQGLKWQHNEITFCIQN

In [106]:
# Define special token IDs
CLS_TOKEN = "[CLS]"
SEP_TOKEN = "[SEP]"

In [107]:
token = "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"
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", use_auth_token=token)



In [108]:
# Initialize DataCollator for masking
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15
)

In [109]:
# def create_bert_input(drug_tokens, target_tokens):
#     # Ensure tokens are strings and remove any existing [SEP] or [CLS]
#     drug_tokens = [str(t) for t in drug_tokens if str(t) not in {CLS_TOKEN, SEP_TOKEN}]
#     target_tokens = [str(t) for t in target_tokens if str(t) not in {CLS_TOKEN, SEP_TOKEN}]

#     # Debug prints to inspect the tokens:
#     #print("Drug tokens:", drug_tokens)
#     #print("Target tokens:", target_tokens)
    
#     # Build the sequence: [CLS] drug tokens [SEP] target tokens [SEP]
#     return [CLS_TOKEN] + drug_tokens + [SEP_TOKEN] + target_tokens + [SEP_TOKEN]

def create_bert_input(drug_encoding, protein_encoding):
    """
    Creates a clean BERT input sequence from drug and protein encodings
    Args:
        drug_encoding: encoded drug sequence
        protein_encoding: encoded protein sequence
    Returns:
        list: cleaned sequence with proper BERT format
    """
    # Convert JSON strings to lists if needed
    if isinstance(drug_encoding, str):
        drug_encoding = json.loads(drug_encoding)
    if isinstance(protein_encoding, str):
        protein_encoding = json.loads(protein_encoding)
    
    # Clean and format sequences
    drug_seq = [str(x) for x in drug_encoding if x != 0]  # Remove padding zeros
    protein_seq = [str(x) for x in protein_encoding if x != 0]  # Remove padding zeros
    
    # Create BERT sequence with proper tokens
    sequence = (
        ['[CLS]'] +  # Start token
        drug_seq +   # Drug sequence
        ['[SEP]'] +  # Separator token
        protein_seq +  # Protein sequence
        ['[SEP]']   # End token
    )
    
    return sequence


In [110]:
sequence = create_bert_input(merged_df["Encoded_x"].iloc[0], merged_df["Encoded_y"].iloc[0])
sequence

['[CLS]',
 '78',
 '83',
 '4',
 '22',
 '23',
 '24',
 '15',
 '18',
 '19',
 '11',
 '12',
 '2',
 '54',
 '66',
 '67',
 '5',
 '6',
 '27',
 '31',
 '32',
 '7',
 '28',
 '33',
 '34',
 '29',
 '35',
 '36',
 '13',
 '14',
 '51',
 '9',
 '52',
 '43',
 '70',
 '83',
 '4',
 '22',
 '23',
 '24',
 '15',
 '18',
 '19',
 '11',
 '12',
 '2',
 '54',
 '66',
 '67',
 '5',
 '6',
 '38',
 '39',
 '7',
 '40',
 '41',
 '13',
 '14',
 '51',
 '9',
 '52',
 '43',
 '10',
 '3',
 '4',
 '58',
 '85',
 '10',
 '30',
 '21',
 '25',
 '21',
 '54',
 '66',
 '67',
 '5',
 '6',
 '38',
 '39',
 '7',
 '40',
 '41',
 '13',
 '14',
 '51',
 '9',
 '52',
 '43',
 '46',
 '68',
 '2',
 '77',
 '25',
 '4',
 '22',
 '23',
 '24',
 '15',
 '18',
 '19',
 '11',
 '12',
 '2',
 '63',
 '2',
 '77',
 '25',
 '4',
 '58',
 '85',
 '10',
 '30',
 '21',
 '84',
 '10',
 '30',
 '21',
 '25',
 '[SEP]',
 '160',
 '205',
 '6',
 '2',
 '342',
 '331',
 '294',
 '209',
 '79',
 '113',
 '197',
 '348',
 '328',
 '160',
 '277',
 '201',
 '195',
 '114',
 '225',
 '49',
 '52',
 '76',
 '206',
 '183',


In [111]:
def process_chunk(chunk, tokenizer=None, data_collator=None):
    """
    Process a chunk of drug-target pairs using create_bert_input function
    """
    if tokenizer is None:
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    if data_collator is None:
        data_collator = DataCollatorForLanguageModeling(
            tokenizer=tokenizer, 
            mlm=True, 
            mlm_probability=0.15
        )
    
    processed_sequences = []
    
    # Process each row in the chunk
    for idx in range(len(chunk)):
        # Get drug and protein encodings
        drug_encoding = chunk['Encoded_x'].iloc[idx]
        protein_encoding = chunk['Encoded_y'].iloc[idx]
        
        # Create BERT input sequence
        sequence = create_bert_input(drug_encoding, protein_encoding)
        
        # Convert sequence to string format
        sequence_str = ' '.join(sequence)
        
        # Apply masking using data_collator
        features = [{"input_ids": tokenizer.convert_tokens_to_ids(sequence)}]
        masked = data_collator(features)
        
        # Convert masked ids back to tokens and join
        masked_sequence = ' '.join(map(str, masked['input_ids'][0].tolist()))
        processed_sequences.append(masked_sequence)
    
    return pd.DataFrame({"Masked_Input": processed_sequences})


In [112]:
if not os.path.exists('final_bert_inputs_masked.h5'):
    chunksize = 10  # Reduced chunk size
    
    # Calculate maximum string length from first few chunks
    print("Calculating maximum sequence length...")
    max_len = 0
    sample_chunks = pd.read_csv("merged_encodings.csv", chunksize=chunksize, nrows=50)
    
    for chunk in sample_chunks:
        processed = process_chunk(chunk, tokenizer, data_collator)
        chunk_max = processed['Masked_Input'].str.len().max()
        max_len = max(max_len, chunk_max)
    
    # Add safety margin to max_len
    max_len = 20000
    print(f"Maximum sequence length (with buffer): {max_len}")
    
    # Set min_itemsize with the calculated max_len
    min_itemsize = {
        'Masked_Input': max_len,
        'index': 100  # for index column if needed
    }
    
    with pd.HDFStore('final_bert_inputs_masked.h5', mode='w') as store:
        # Get total chunks for progress bar
        total_chunks = sum(1 for _ in pd.read_csv("merged_encodings.csv", chunksize=chunksize))
        
        # Process chunks with progress bar
        for i, chunk in enumerate(tqdm(pd.read_csv("merged_encodings.csv", chunksize=chunksize), 
                                     total=total_chunks, 
                                     desc="Processing chunks")):
            try:
                processed_chunk = process_chunk(chunk, tokenizer, data_collator)
                
                # Store with explicit min_itemsize
                store.append('df', 
                           processed_chunk, 
                           format='table', 
                           data_columns=True,
                           min_itemsize=min_itemsize,
                           index=False)  # Disable index storage
                
                # Memory cleanup
                del processed_chunk
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                
                # Progress update every 10 chunks
                if (i + 1) % 10 == 0:
                    print(f"\nProcessed {i+1}/{total_chunks} chunks")
                    
            except Exception as e:
                print(f"\nError in chunk {i}: {str(e)}")
                print("Continuing with next chunk...")
                continue
            
    print("\nProcessing complete! Data saved to final_bert_inputs_masked.h5")
else:
    print("File already exists. Skipping processing.")

# Verify the saved data
def verify_data():
    with pd.HDFStore('final_bert_inputs_masked.h5', 'r') as store:
        print(f"\nVerification:")
        print(f"Number of stored sequences: {len(store['df'])}")
        print(f"Columns in storage: {store['df'].columns.tolist()}")
        
verify_data()
###
### Took 80 minutes to run the above code
###


File already exists. Skipping processing.

Verification:
Number of stored sequences: 2894
Columns in storage: ['Masked_Input']


# Model Training
Training the BERT model by iterating over the DataLoader for a set number of epochs.
Before training we are optimizing the input texts and labels to pass through Data Loader


In [113]:
%pip install h5py

Note: you may need to restart the kernel to use updated packages.


In [114]:
import h5py

In [115]:
def print_hdf5_structure(obj, indent=0):
    for key in obj.keys():
        item = obj[key]
        print("  " * indent + f"{key}: {type(item)}")
        if isinstance(item, h5py.Group):
            print_hdf5_structure(item, indent+1)

with h5py.File('final_bert_inputs_masked.h5', 'r') as f:
    print("HDF5 file structure:")
    print_hdf5_structure(f)

HDF5 file structure:
df: <class 'h5py._hl.group.Group'>
  table: <class 'h5py._hl.dataset.Dataset'>


In [116]:
with h5py.File('final_bert_inputs_masked.h5', 'r') as f:
    # Access the final table dataset
    dataset = f["df"]["table"]

    # Print the first 10 entries
    print("First 10 entries:")
    print(dataset[:10])
    


First 10 entries:
[(0, b'101 6275 103 103 2570 2603 2484 2321 2324 2539 2340 2260 1016 5139 5764 6163 1019 1020 2676 103 3590 1021 2654 3943 4090 2756 3486 103 2410 2403 4868 1023 103 4724 3963 6640 1018 2570 2603 2484 2321 2324 2539 2340 2260 1016 5139 5764 6163 1019 1020 4229 4464 103 103 4601 2410 2403 4868 1023 4720 4724 2184 1017 1018 5388 103 2184 2382 2538 2423 2538 5139 5764 6163 1019 1020 4229 4464 1021 2871 4601 2410 2403 4868 1023 4720 4724 4805 6273 1016 6255 2423 1018 2570 2603 2484 2321 103 2539 2340 103 1016 103 1016 6255 2423 1018 5388 5594 2184 2382 2538 15114 2184 103 2538 2423 102 8148 16327 1020 1016 100 103 28135 19348 6535 103 103 100 25256 8148 25578 16345 17317 12457 14993 4749 4720 6146 18744 103 9402 18512 22018 2423 18545 13029 16666 4868 7558 103 6255 24622 27234 15170 103 9402 12457 4029 2459 12862 20311 103 2538 15028 25504 103 19936 20713 13029 10550 9800 25586 21679 103 16798 11899 5388 6079 13427 19988 103 25103 15017 2382 9800 103 28489 25143 4700 1407

In [117]:
with h5py.File('final_bert_inputs_masked.h5', 'r') as f:
    dataset = f["df"]["table"]
    num_entries = dataset.shape[0]
    print("Number of entries:", num_entries)

Number of entries: 2894


In [118]:
# --- Step 1: Load Masked Inputs from HDF5 and Pair IDs from CSV ---
masked_df = pd.read_hdf("final_bert_inputs_masked.h5", key="df/table")
pairs_df = pd.read_csv("merged_encodings.csv")  # Should include columns "Drug id" and "pbd id"

In [119]:
# Assign the masked inputs (assumed column name "Masked_Input") from masked_df to pairs_df
pairs_df["Masked_Input"] = masked_df["Masked_Input"]
print(pairs_df["Masked_Input"].head())

0    101 6275 103 103 2570 2603 2484 2321 2324 2539...
1    101 6275 1019 1020 4229 4464 1021 2871 4601 24...
2    101 4008 4700 4466 1019 1020 2676 1021 103 275...
3    101 2459 4749 103 4413 5187 1015 4293 1015 656...
4    101 6275 9687 1018 2570 103 2484 103 2324 2539...
Name: Masked_Input, dtype: object


In [120]:
# --- Step 2: Load the Label Matrix ---
# Ensure that the CSV file now has "Drug id" as header in the first column
label_matrix = pd.read_csv("target_labels.csv")
print(label_matrix.head())

   Drug id  1EVU  1NSI  1DJL  1AB2  1ALS  1CFG  1EG0  1OZ5  4BSJ  ...  1TVB  \
0  DB11300     1     0     0     0     0     1     0     0     0  ...     0   
1  DB11311     1     0     0     0     0     0     0     0     0  ...     0   
2  DB11571     1     0     0     0     0     1     0     0     0  ...     0   
3  DB13151     1     0     0     0     0     1     0     0     0  ...     0   
4  DB05383     0     1     0     0     0     0     0     0     0  ...     0   

   1T5Q  6U6U  7RY7  2MDP  1G2C  4BPU  2X18  2N80  2KR6  
0     0     0     0     0     0     0     0     0     0  
1     0     0     0     0     0     0     0     0     0  
2     0     0     0     0     0     0     0     0     0  
3     0     0     0     0     0     0     0     0     0  
4     0     0     0     0     0     0     0     0     0  

[5 rows x 721 columns]


In [121]:
from torch.utils.data import Dataset

class DtiDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Convert sequence list back to string if needed
        text = ' '.join(map(str, self.texts[idx])) if isinstance(self.texts[idx], list) else str(self.texts[idx])
        label = int(self.labels[idx])
        
        # Tokenize the text
        encoding = self.tokenizer(
            text,
            add_special_tokens=True,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        
        # Remove batch dimension added by tokenizer
        return {
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'labels': torch.tensor(label, dtype=torch.long)
        }



#### Load BERT Model and Optimizer
Initialize the BERT tokenizer and the `BertForSequenceClassification` model for predicting drug-target interactions (using 2 labels for single classification). Set up the AdamW optimizer with a learning rate of 2e-5 to update model weights during training.

# Create Dataset and DataLoader
Extract the masked sequences and map each drug ID to its corresponding binary label (derived from the label matrix). Use these to create a PyTorch DataLoader, which facilitates efficient batch processing during model training.


In [122]:
# --- Step 5: Create Dataset and DataLoader ---
# Extract texts from the pairs DataFrame:
texts = pairs_df["Masked_Input"].tolist()
print(texts[0])
# For each drug id in pairs_df, get its label vector from the label_matrix.
# Here we assume the label_matrix’s index matches the "Drug id" values.

# Convert a multi-label vector into a single label (if that is your intended approach)

labels = pairs_df["label"].tolist()
# print(label_matrix.head())
#labels = pairs_df["Drug id"].apply(lambda drug: label_matrix.loc[drug].values.tolist()).tolist()

# Now create the dataset with the extracted lists
print(labels[0])
dataset = DtiDataset(texts, labels, tokenizer)
train_loader = DataLoader(dataset, batch_size=16, shuffle=True)
# print(dataset[1])


101 6275 103 103 2570 2603 2484 2321 2324 2539 2340 2260 1016 5139 5764 6163 1019 1020 2676 103 3590 1021 2654 3943 4090 2756 3486 103 2410 2403 4868 1023 103 4724 3963 6640 1018 2570 2603 2484 2321 2324 2539 2340 2260 1016 5139 5764 6163 1019 1020 4229 4464 103 103 4601 2410 2403 4868 1023 4720 4724 2184 1017 1018 5388 103 2184 2382 2538 2423 2538 5139 5764 6163 1019 1020 4229 4464 1021 2871 4601 2410 2403 4868 1023 4720 4724 4805 6273 1016 6255 2423 1018 2570 2603 2484 2321 103 2539 2340 103 1016 103 1016 6255 2423 1018 5388 5594 2184 2382 2538 15114 2184 103 2538 2423 102 8148 16327 1020 1016 100 103 28135 19348 6535 103 103 100 25256 8148 25578 16345 17317 12457 14993 4749 4720 6146 18744 103 9402 18512 22018 2423 18545 13029 16666 4868 7558 103 6255 24622 27234 15170 103 9402 12457 4029 2459 12862 20311 103 2538 15028 25504 103 19936 20713 13029 10550 9800 25586 21679 103 16798 11899 5388 6079 13427 19988 103 25103 15017 2382 9800 103 28489 25143 4700 14078 8574 12457 24331 15407 

In [131]:
import torch
import torch.nn as nn
from transformers import BertModel, BertConfig

class CustomBertModel(nn.Module):
    def __init__(self, num_labels=2):
        super().__init__()
        
        # Load BERT configuration
        config = BertConfig.from_pretrained('bert-base-uncased')
        
        # BERT base model
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        self.num_labels = num_labels
        
        # Paper-specific architecture
        self.batch_norm1 = nn.BatchNorm1d(config.hidden_size)
        
        # First transformation block (768 -> 1024)
        self.transform1 = nn.Sequential(
            nn.Linear(config.hidden_size, 1024),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.BatchNorm1d(1024)
        )
        
        # Second transformation block (1024 -> 512)
        self.transform2 = nn.Sequential(
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.BatchNorm1d(512)
        )
        
        # Third transformation block (512 -> 256)
        self.transform3 = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.BatchNorm1d(256)
        )
        
        # Output layer
        self.classifier = nn.Linear(256, num_labels)
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        # Get BERT outputs
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        
        # Get pooled output and apply first batch norm
        pooled_output = outputs[1]
        pooled_output = self.batch_norm1(pooled_output)
        
        # Apply transformation blocks
        x = self.transform1(pooled_output)
        x = self.transform2(x)
        x = self.transform3(x)
        
        # Get logits
        logits = self.classifier(x)
        
        # Calculate loss if labels provided
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
        
        return {
            'loss': loss,
            'logits': logits
        }

In [124]:
%pip install ipywidgets

Collecting fqdn (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.6.0->jupyter-server<3,>=2.4.0->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets)
  Downloading fqdn-1.5.1-py3-none-any.whl.metadata (1.4 kB)
Collecting isoduration (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.6.0->jupyter-server<3,>=2.4.0->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets)
  Downloading isoduration-20.11.0-py3-none-any.whl.metadata (5.7 kB)
Collecting uri-template (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.6.0->jupyter-server<3,>=2.4.0->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets)
  Downloading uri_template-1.3.0-py3-none-any.whl.metadata (8.8 kB)
Collecting webcolors>=1.11 (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.6.0->jupyter-server<3,>=2.4.0->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets)
  Downloading webcolors-24.11.1-py3-none-any.whl.metadata (2.2 kB)
Downloading webcolors-24.11.1-py3-none-any.whl (14 kB)
Downloa

In [136]:
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm
import numpy as np
from transformers import BertForSequenceClassification, AdamW
import torch.optim as optim

train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts[:-500], labels[:-500], test_size=0.2, random_state=42
)

test_texts = texts[-500:]
test_labels = labels[-500:]

# -------------------------
# Create Datasets and DataLoader
# -------------------------

# Optimizer with weight decay fix
num_epochs = 5
model = CustomBertModel(num_labels=2)
# optimizer = optim.AdamW(model.parameters(), lr=1e-5)


train_dataset = DtiDataset(train_texts, train_labels, tokenizer, max_length=128)
val_dataset = DtiDataset(val_texts, val_labels, tokenizer, max_length=128)
test_dataset = DtiDataset(test_texts, test_labels, tokenizer, max_length=128)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
print(f"Train Loaded...")
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)
print(f"Val Loaded...")
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
print(f"Test Loaded...")

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of validation samples: {len(val_dataset)}")
print(f"Number of testing samples: {len(test_dataset)}")

def evaluate(model, dataloader):
    model.eval()
    correct = total = 0
    total_loss = 0
    progress_bar = tqdm(dataloader, desc='Evaluating')
    
    with torch.no_grad():
        for batch in progress_bar:
            outputs = model(**batch)
            loss = outputs['loss']
            total_loss += loss.item()
            logits = outputs['logits']
            preds = logits.argmax(dim=1)
            correct += (preds == batch['labels']).sum().item()
            total += len(batch['labels'])
            
            current_accuracy = correct / total
            progress_bar.set_postfix({
                'loss': f'{total_loss/(progress_bar.n+1):.4f}',
                'accuracy': f'{current_accuracy:.4f}'
            })
    
    return total_loss / len(dataloader), correct / total

# -------------------------
# (Optional) Iterate over a batch from train_loader to inspect the data
# -------------------------
for batch in train_loader:
    print("Batch keys:", batch.keys())
    print("Input IDs shape:", batch["input_ids"].shape)
    print("Attention mask shape:", batch["attention_mask"].shape)
    print("Labels:", batch["labels"])
    break

# -------------------------
# Model Initialization and Training (Example)
# -------------------------
best_val_accuracy = 0

## Training loop with progress bars
# Device setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

# Improved optimizer settings
optimizer = optim.AdamW(model.parameters(), 
                       lr=2e-5,  # Reduced learning rate
                       weight_decay=0.01,  # L2 regularization
                       eps=1e-8)

# Add learning rate scheduler
from transformers import get_linear_schedule_with_warmup
num_training_steps = len(train_loader) * num_epochs
warmup_steps = num_training_steps // 10  # 10% warmup

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=num_training_steps
)

# Modified training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch + 1}/{num_epochs}')
    
    for batch in progress_bar:
        # Move batch to device
        batch = {k: v.to(device) for k, v in batch.items()}
        
        # Forward pass
        outputs = model(**batch)
        loss = outputs['loss']
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        
        # Track metrics
        total_loss += loss.item()
        logits = outputs['logits']
        preds = logits.argmax(dim=1)
        correct += (preds == batch['labels']).sum().item()
        total += len(batch['labels'])
        
        # Update progress bar
        progress_bar.set_postfix({
            'loss': f'{total_loss/(progress_bar.n+1):.4f}',
            'accuracy': f'{correct/total:.4f}',
            'lr': f'{scheduler.get_last_lr()[0]:.2e}'
        })

    # Evaluate and save
    val_loss, val_accuracy = evaluate(model, val_loader)
    
    # print(f"\nEpoch {epoch + 1}:")
    # print(f"Train Loss: {total_loss/len(train_loader):.4f}")
    # print(f"Train Accuracy: {correct/total:.4f}")
    # print(f"Val Loss: {val_loss:.4f}")
    # print(f"Val Accuracy: {val_accuracy:.4f}")
    
    # Save best model
    if val_accuracy > best_val_accuracy:
        best_val_accuracy = val_accuracy
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_acc': val_accuracy,
            'val_loss': val_loss
        }, 'best_model.pth')
        print(f"New best model saved! (Validation Accuracy: {val_accuracy:.4f})")
    print(f"\nEpoch {epoch + 1} - Average loss: {train_loss:.4f} - Val loss: {val_loss:.4f} - Val accuracy: {val_accuracy:.4f}")

Train Loaded...
Val Loaded...
Test Loaded...
Number of training samples: 1923
Number of validation samples: 481
Number of testing samples: 500
Batch keys: dict_keys(['input_ids', 'attention_mask', 'labels'])
Input IDs shape: torch.Size([64, 128])
Attention mask shape: torch.Size([64, 128])
Labels: tensor([0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 1,
        1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1,
        1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0])


Epoch 1/5:   0%|          | 0/31 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]

New best model saved! (Validation Accuracy: 0.6133)

Epoch 1 - Average loss: 0.7345 - Val loss: 0.6690 - Val accuracy: 0.6133


Epoch 2/5:   0%|          | 0/31 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]


Epoch 2 - Average loss: 0.7345 - Val loss: 0.6802 - Val accuracy: 0.5967


Epoch 3/5:   0%|          | 0/31 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/8 [00:00<?, ?it/s]


Epoch 3 - Average loss: 0.7345 - Val loss: 0.7467 - Val accuracy: 0.3784


Epoch 4/5:   0%|          | 0/31 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [135]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score

def test(model, dataloader):
    model.eval()
    all_predictions = []
    all_labels = []
    all_scores = []
    total_loss = 0
    progress_bar = tqdm(dataloader, desc='Evaluating')
    
    with torch.no_grad():
        for batch in progress_bar:
            outputs = model(**batch)
            loss = outputs.loss
            total_loss += loss.item()
            
            # Get predictions and probabilities
            logits = outputs.logits
            probs = torch.softmax(logits, dim=1)
            preds = logits.argmax(dim=1)
            
            # Store predictions, scores (probability of positive class), and labels
            all_predictions.extend(preds.cpu().numpy())
            all_scores.extend(probs[:, 1].cpu().numpy())
            all_labels.extend(batch['labels'].cpu().numpy())
            
            # Update progress bar
            current_accuracy = accuracy_score(all_labels, all_predictions)
            progress_bar.set_postfix({
                'loss': f'{total_loss/(progress_bar.n+1):.4f}',
                'accuracy': f'{current_accuracy:.4f}'
            })

    # Calculate metrics
    metrics = {
        'loss': total_loss / len(dataloader),
        'accuracy': accuracy_score(all_labels, all_predictions),
        'precision': precision_score(all_labels, all_predictions, zero_division=0),
        'recall': recall_score(all_labels, all_predictions, zero_division=0),
        'f1': f1_score(all_labels, all_predictions, zero_division=0),
        'auc_roc': roc_auc_score(all_labels, all_scores),
        'auprc': average_precision_score(all_labels, all_scores)
    }
    
    # Print metrics
    print("\nTest Metrics:")
    for metric, value in metrics.items():
        print(f"{metric.replace('_', ' ').title()}: {value:.4f}")
    
    return metrics

# Evaluate model on test set
print("\nEvaluating on test set...")
model.load_state_dict(torch.load('best_model.pth'))
test_metrics = test(model, test_loader)


Evaluating on test set...


  model.load_state_dict(torch.load('best_model.pth'))


RuntimeError: Error(s) in loading state_dict for CustomBertModel:
	Missing key(s) in state_dict: "bert.embeddings.word_embeddings.weight", "bert.embeddings.position_embeddings.weight", "bert.embeddings.token_type_embeddings.weight", "bert.embeddings.LayerNorm.weight", "bert.embeddings.LayerNorm.bias", "bert.encoder.layer.0.attention.self.query.weight", "bert.encoder.layer.0.attention.self.query.bias", "bert.encoder.layer.0.attention.self.key.weight", "bert.encoder.layer.0.attention.self.key.bias", "bert.encoder.layer.0.attention.self.value.weight", "bert.encoder.layer.0.attention.self.value.bias", "bert.encoder.layer.0.attention.output.dense.weight", "bert.encoder.layer.0.attention.output.dense.bias", "bert.encoder.layer.0.attention.output.LayerNorm.weight", "bert.encoder.layer.0.attention.output.LayerNorm.bias", "bert.encoder.layer.0.intermediate.dense.weight", "bert.encoder.layer.0.intermediate.dense.bias", "bert.encoder.layer.0.output.dense.weight", "bert.encoder.layer.0.output.dense.bias", "bert.encoder.layer.0.output.LayerNorm.weight", "bert.encoder.layer.0.output.LayerNorm.bias", "bert.encoder.layer.1.attention.self.query.weight", "bert.encoder.layer.1.attention.self.query.bias", "bert.encoder.layer.1.attention.self.key.weight", "bert.encoder.layer.1.attention.self.key.bias", "bert.encoder.layer.1.attention.self.value.weight", "bert.encoder.layer.1.attention.self.value.bias", "bert.encoder.layer.1.attention.output.dense.weight", "bert.encoder.layer.1.attention.output.dense.bias", "bert.encoder.layer.1.attention.output.LayerNorm.weight", "bert.encoder.layer.1.attention.output.LayerNorm.bias", "bert.encoder.layer.1.intermediate.dense.weight", "bert.encoder.layer.1.intermediate.dense.bias", "bert.encoder.layer.1.output.dense.weight", "bert.encoder.layer.1.output.dense.bias", "bert.encoder.layer.1.output.LayerNorm.weight", "bert.encoder.layer.1.output.LayerNorm.bias", "bert.encoder.layer.2.attention.self.query.weight", "bert.encoder.layer.2.attention.self.query.bias", "bert.encoder.layer.2.attention.self.key.weight", "bert.encoder.layer.2.attention.self.key.bias", "bert.encoder.layer.2.attention.self.value.weight", "bert.encoder.layer.2.attention.self.value.bias", "bert.encoder.layer.2.attention.output.dense.weight", "bert.encoder.layer.2.attention.output.dense.bias", "bert.encoder.layer.2.attention.output.LayerNorm.weight", "bert.encoder.layer.2.attention.output.LayerNorm.bias", "bert.encoder.layer.2.intermediate.dense.weight", "bert.encoder.layer.2.intermediate.dense.bias", "bert.encoder.layer.2.output.dense.weight", "bert.encoder.layer.2.output.dense.bias", "bert.encoder.layer.2.output.LayerNorm.weight", "bert.encoder.layer.2.output.LayerNorm.bias", "bert.encoder.layer.3.attention.self.query.weight", "bert.encoder.layer.3.attention.self.query.bias", "bert.encoder.layer.3.attention.self.key.weight", "bert.encoder.layer.3.attention.self.key.bias", "bert.encoder.layer.3.attention.self.value.weight", "bert.encoder.layer.3.attention.self.value.bias", "bert.encoder.layer.3.attention.output.dense.weight", "bert.encoder.layer.3.attention.output.dense.bias", "bert.encoder.layer.3.attention.output.LayerNorm.weight", "bert.encoder.layer.3.attention.output.LayerNorm.bias", "bert.encoder.layer.3.intermediate.dense.weight", "bert.encoder.layer.3.intermediate.dense.bias", "bert.encoder.layer.3.output.dense.weight", "bert.encoder.layer.3.output.dense.bias", "bert.encoder.layer.3.output.LayerNorm.weight", "bert.encoder.layer.3.output.LayerNorm.bias", "bert.encoder.layer.4.attention.self.query.weight", "bert.encoder.layer.4.attention.self.query.bias", "bert.encoder.layer.4.attention.self.key.weight", "bert.encoder.layer.4.attention.self.key.bias", "bert.encoder.layer.4.attention.self.value.weight", "bert.encoder.layer.4.attention.self.value.bias", "bert.encoder.layer.4.attention.output.dense.weight", "bert.encoder.layer.4.attention.output.dense.bias", "bert.encoder.layer.4.attention.output.LayerNorm.weight", "bert.encoder.layer.4.attention.output.LayerNorm.bias", "bert.encoder.layer.4.intermediate.dense.weight", "bert.encoder.layer.4.intermediate.dense.bias", "bert.encoder.layer.4.output.dense.weight", "bert.encoder.layer.4.output.dense.bias", "bert.encoder.layer.4.output.LayerNorm.weight", "bert.encoder.layer.4.output.LayerNorm.bias", "bert.encoder.layer.5.attention.self.query.weight", "bert.encoder.layer.5.attention.self.query.bias", "bert.encoder.layer.5.attention.self.key.weight", "bert.encoder.layer.5.attention.self.key.bias", "bert.encoder.layer.5.attention.self.value.weight", "bert.encoder.layer.5.attention.self.value.bias", "bert.encoder.layer.5.attention.output.dense.weight", "bert.encoder.layer.5.attention.output.dense.bias", "bert.encoder.layer.5.attention.output.LayerNorm.weight", "bert.encoder.layer.5.attention.output.LayerNorm.bias", "bert.encoder.layer.5.intermediate.dense.weight", "bert.encoder.layer.5.intermediate.dense.bias", "bert.encoder.layer.5.output.dense.weight", "bert.encoder.layer.5.output.dense.bias", "bert.encoder.layer.5.output.LayerNorm.weight", "bert.encoder.layer.5.output.LayerNorm.bias", "bert.encoder.layer.6.attention.self.query.weight", "bert.encoder.layer.6.attention.self.query.bias", "bert.encoder.layer.6.attention.self.key.weight", "bert.encoder.layer.6.attention.self.key.bias", "bert.encoder.layer.6.attention.self.value.weight", "bert.encoder.layer.6.attention.self.value.bias", "bert.encoder.layer.6.attention.output.dense.weight", "bert.encoder.layer.6.attention.output.dense.bias", "bert.encoder.layer.6.attention.output.LayerNorm.weight", "bert.encoder.layer.6.attention.output.LayerNorm.bias", "bert.encoder.layer.6.intermediate.dense.weight", "bert.encoder.layer.6.intermediate.dense.bias", "bert.encoder.layer.6.output.dense.weight", "bert.encoder.layer.6.output.dense.bias", "bert.encoder.layer.6.output.LayerNorm.weight", "bert.encoder.layer.6.output.LayerNorm.bias", "bert.encoder.layer.7.attention.self.query.weight", "bert.encoder.layer.7.attention.self.query.bias", "bert.encoder.layer.7.attention.self.key.weight", "bert.encoder.layer.7.attention.self.key.bias", "bert.encoder.layer.7.attention.self.value.weight", "bert.encoder.layer.7.attention.self.value.bias", "bert.encoder.layer.7.attention.output.dense.weight", "bert.encoder.layer.7.attention.output.dense.bias", "bert.encoder.layer.7.attention.output.LayerNorm.weight", "bert.encoder.layer.7.attention.output.LayerNorm.bias", "bert.encoder.layer.7.intermediate.dense.weight", "bert.encoder.layer.7.intermediate.dense.bias", "bert.encoder.layer.7.output.dense.weight", "bert.encoder.layer.7.output.dense.bias", "bert.encoder.layer.7.output.LayerNorm.weight", "bert.encoder.layer.7.output.LayerNorm.bias", "bert.encoder.layer.8.attention.self.query.weight", "bert.encoder.layer.8.attention.self.query.bias", "bert.encoder.layer.8.attention.self.key.weight", "bert.encoder.layer.8.attention.self.key.bias", "bert.encoder.layer.8.attention.self.value.weight", "bert.encoder.layer.8.attention.self.value.bias", "bert.encoder.layer.8.attention.output.dense.weight", "bert.encoder.layer.8.attention.output.dense.bias", "bert.encoder.layer.8.attention.output.LayerNorm.weight", "bert.encoder.layer.8.attention.output.LayerNorm.bias", "bert.encoder.layer.8.intermediate.dense.weight", "bert.encoder.layer.8.intermediate.dense.bias", "bert.encoder.layer.8.output.dense.weight", "bert.encoder.layer.8.output.dense.bias", "bert.encoder.layer.8.output.LayerNorm.weight", "bert.encoder.layer.8.output.LayerNorm.bias", "bert.encoder.layer.9.attention.self.query.weight", "bert.encoder.layer.9.attention.self.query.bias", "bert.encoder.layer.9.attention.self.key.weight", "bert.encoder.layer.9.attention.self.key.bias", "bert.encoder.layer.9.attention.self.value.weight", "bert.encoder.layer.9.attention.self.value.bias", "bert.encoder.layer.9.attention.output.dense.weight", "bert.encoder.layer.9.attention.output.dense.bias", "bert.encoder.layer.9.attention.output.LayerNorm.weight", "bert.encoder.layer.9.attention.output.LayerNorm.bias", "bert.encoder.layer.9.intermediate.dense.weight", "bert.encoder.layer.9.intermediate.dense.bias", "bert.encoder.layer.9.output.dense.weight", "bert.encoder.layer.9.output.dense.bias", "bert.encoder.layer.9.output.LayerNorm.weight", "bert.encoder.layer.9.output.LayerNorm.bias", "bert.encoder.layer.10.attention.self.query.weight", "bert.encoder.layer.10.attention.self.query.bias", "bert.encoder.layer.10.attention.self.key.weight", "bert.encoder.layer.10.attention.self.key.bias", "bert.encoder.layer.10.attention.self.value.weight", "bert.encoder.layer.10.attention.self.value.bias", "bert.encoder.layer.10.attention.output.dense.weight", "bert.encoder.layer.10.attention.output.dense.bias", "bert.encoder.layer.10.attention.output.LayerNorm.weight", "bert.encoder.layer.10.attention.output.LayerNorm.bias", "bert.encoder.layer.10.intermediate.dense.weight", "bert.encoder.layer.10.intermediate.dense.bias", "bert.encoder.layer.10.output.dense.weight", "bert.encoder.layer.10.output.dense.bias", "bert.encoder.layer.10.output.LayerNorm.weight", "bert.encoder.layer.10.output.LayerNorm.bias", "bert.encoder.layer.11.attention.self.query.weight", "bert.encoder.layer.11.attention.self.query.bias", "bert.encoder.layer.11.attention.self.key.weight", "bert.encoder.layer.11.attention.self.key.bias", "bert.encoder.layer.11.attention.self.value.weight", "bert.encoder.layer.11.attention.self.value.bias", "bert.encoder.layer.11.attention.output.dense.weight", "bert.encoder.layer.11.attention.output.dense.bias", "bert.encoder.layer.11.attention.output.LayerNorm.weight", "bert.encoder.layer.11.attention.output.LayerNorm.bias", "bert.encoder.layer.11.intermediate.dense.weight", "bert.encoder.layer.11.intermediate.dense.bias", "bert.encoder.layer.11.output.dense.weight", "bert.encoder.layer.11.output.dense.bias", "bert.encoder.layer.11.output.LayerNorm.weight", "bert.encoder.layer.11.output.LayerNorm.bias", "bert.pooler.dense.weight", "bert.pooler.dense.bias", "batch_norm1.weight", "batch_norm1.bias", "batch_norm1.running_mean", "batch_norm1.running_var", "transform1.0.weight", "transform1.0.bias", "transform1.3.weight", "transform1.3.bias", "transform1.3.running_mean", "transform1.3.running_var", "transform2.0.weight", "transform2.0.bias", "transform2.3.weight", "transform2.3.bias", "transform2.3.running_mean", "transform2.3.running_var", "transform3.0.weight", "transform3.0.bias", "transform3.3.weight", "transform3.3.bias", "transform3.3.running_mean", "transform3.3.running_var", "classifier.weight", "classifier.bias". 
	Unexpected key(s) in state_dict: "epoch", "model_state_dict", "optimizer_state_dict", "val_acc", "val_loss". 

In [None]:

print(pairs_df.columns.tolist())

['Drug id', 'pdb_id', 'label', 'smiles', 'Encoded_x', 'sequence', 'Encoded_y', 'Masked_Input']


### Model Evalutation
Due to small dataset and overfitting, accuracy is 1.

In [None]:
# from tqdm import tqdm
# model.eval()
# correct = total = 0
# with torch.no_grad():
#     for batch in tqdm(train_loader, desc="Evaluating"):
#         outputs = model(**batch)
#         preds = outputs.logits.argmax(dim=1)
#         correct += (preds == batch['labels']).sum().item()
#         total += len(batch['labels'])
# accuracy = correct / total
# print(f"Accuracy: {accuracy:.4f}")



Evaluating: 100%|██████████| 61/61 [06:02<00:00,  5.94s/it]

Accuracy: 1.0000





# Prediction Function
Takes a drug ID and a PDB ID as inputs, constructs the BERT input sequence, tokenizes it, and uses the trained model to predict whether an interaction exists. The function returns "yes" or "no" based on the model's output.

In [None]:
# def predict_interaction(drug_id, pdb_id, tokenizer=tokenizer, model=model):
#     input_str = f"[CLS] {drug_id} [SEP] {pdb_id} [SEP]"
#     inputs = tokenizer(input_str, return_tensors='pt', truncation=True, max_length=512)
#     outputs = model(**inputs)
#     prediction = outputs.logits.argmax(dim=1).item()
#     return "yes" if prediction == 1 else "no"
# predict_interaction('DB01254','1GQ5')

'yes'