In [1]:
import torch
import pandas as pd
import json
import os
import re
import nltk
from nltk.corpus import stopwords
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer, AutoModel

In [2]:
# initialize nltk stopwords
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

# set device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Device being used: {device}")

Device being used: mps


[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/fiatlux/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [3]:
data_path = "../data/processed/train_data.csv"
test_path = "../data/processed/test_data.csv"
model_dir = "../src/model/"

# data preprocessing
df = pd.read_csv(data_path)
df_test = pd.read_csv(test_path)

def clean_text(text):
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\d+', '', text) 
    text = " ".join([word.strip() for word in text.split() if word not in stop_words])
    return text

if 'Text_Cleaned' not in df:
    df['Text_Cleaned'] = df['Text_combined'].apply(clean_text)

df['Terms'] = df['Terms'].fillna('non-autoregulatory')

columns_to_keep = ['batch_number', 'Text_Cleaned', 'Terms']
df_cleaned = df[columns_to_keep]

df_cleaned['Terms_List'] = df_cleaned['Terms'].apply(lambda x: [term.strip() for term in x.split(',')])
df_cleaned['Terms_List'] = df_cleaned['Terms_List'].apply(lambda x: list(set(x)))

mlb = MultiLabelBinarizer()
labels = mlb.fit_transform(df_cleaned['Terms_List'])
label_columns = mlb.classes_

df_test['Text_Cleaned'] = df_test['Text_combined'].apply(clean_text)

tokenizer = AutoTokenizer.from_pretrained('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['Terms_List'] = df_cleaned['Terms'].apply(lambda x: [term.strip() for term in x.split(',')])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['Terms_List'] = df_cleaned['Terms_List'].apply(lambda x: list(set(x)))


In [4]:
class PubMedBERTClassifier(torch.nn.Module):
    def __init__(self, n_classes):
        super(PubMedBERTClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

In [5]:
def load_model_and_thresholds(model_path, thresholds_path):
    model = PubMedBERTClassifier(n_classes=len(label_columns))
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    if os.path.exists(thresholds_path):
        with open(thresholds_path, "r") as f:
            thresholds = json.load(f)
    else:
        thresholds = [0.5] * len(label_columns)

    return model, thresholds

In [6]:
def preprocess_text(text, tokenizer, max_length=512):
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    return encoding

In [7]:
def predict(model, df, tokenizer, thresholds):
    predictions = []

    for _, row in df.iterrows():
        text = str(row['Text_Cleaned']).strip()
        encoding = preprocess_text(text, tokenizer)
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask)
            probabilities = torch.sigmoid(outputs).cpu().numpy().flatten()

        predicted_labels = [label_columns[i] for i in range(len(label_columns)) if probabilities[i] >= thresholds[i]]
        if not predicted_labels:
            predicted_labels = ["non-autoregulatory"]

        predictions.append(", ".join(sorted(predicted_labels)))

    return predictions

In [8]:
df_test.head()

Unnamed: 0,AC,PMID,Title,Abstract,if_contain_keyterm,Terms,Text_combined,Text_Cleaned
0,Q2G2U4,17827301,New insights into the WalK/WalR (YycG/YycF) es...,The highly conserved WalK/WalR (also known as ...,0,autolysis,New insights into the WalK/WalR (YycG/YycF) es...,new insights walkwalr yycgyycf essential signa...
1,P00520,20072125,Targeting Bcr-Abl by combining allosteric with...,In an effort to find new pharmacological modal...,0,autophosphorylation,Targeting Bcr-Abl by combining allosteric with...,targeting bcrabl combining allosteric atpbindi...
2,Q8RXD3,15998807,The AIP2 E3 ligase acts as a novel negative re...,The phytohormone abscisic acid (ABA) mediates ...,0,autoubiquitination,The AIP2 E3 ligase acts as a novel negative re...,aip e ligase acts novel negative regulator aba...
3,B0FLN1,18281398,Isolation and characterization of an autoinduc...,The opportunistic human pathogen Acinetobacter...,0,autoinduction,Isolation and characterization of an autoinduc...,isolation characterization autoinducer synthas...
4,P00519,16543148,Organization of the SH3-SH2 unit in active and...,The tyrosine kinase c-Abl is inactivated by in...,0,autoinhibition,Organization of the SH3-SH2 unit in active and...,organization shsh unit active inactive forms c...


In [9]:
# initialize results df
results_df = pd.DataFrame()
results_df[['AC', 'PMID', 'Actual Terms']] = df_test[['AC', 'PMID', 'Terms']]

In [10]:
# predict
for i in range(1, 5):
    model_path = os.path.join(model_dir, f"best_model_batch_{i}.pt")
    thresholds_path = os.path.join(model_dir, f"best_thresholds_batch_{i}.json")

    model, thresholds = load_model_and_thresholds(model_path, thresholds_path)
    pred_column_name = f"Pred Batch {i}"
    results_df[pred_column_name] = predict(model, df_test, tokenizer, thresholds)

  state_dict = torch.load(model_path, map_location="cpu")


RuntimeError: Error(s) in loading state_dict for PubMedBERTClassifier:
	size mismatch for classifier.weight: copying a param with shape torch.Size([15, 768]) from checkpoint, the shape in current model is torch.Size([11, 768]).
	size mismatch for classifier.bias: copying a param with shape torch.Size([15]) from checkpoint, the shape in current model is torch.Size([11]).

In [19]:
results_df

Unnamed: 0,AC,PMID,Actual Terms,Pred Batch 1,Pred Batch 2,Pred Batch 3,Pred Batch 4
0,Q6TFL4,27798626,autoubiquitination,autoubiquitination,autoubiquitination,autoubiquitination,autoubiquitination
1,Q92673,20015111,autoinhibition,"autoinhibition, autoinhibitory","autoinhibition, autoinhibitory","autoinhibition, autoinhibitory",autoinhibition
2,P0A964,2068106,autophosphorylation,"autokinase, autophosphorylation","autocatalysis, autophosphorylation",autophosphorylation,autophosphorylation
3,Q13557,14722083,autophosphorylation,autophosphorylation,autophosphorylation,autophosphorylation,autophosphorylation
4,P74646,23449916,autophosphorylation,autophosphorylation,autophosphorylation,autophosphorylation,autophosphorylation
5,Q9JK25,20519438,autoinhibition,"autoinhibition, autoinhibitory","autoinhibition, autoinhibitory",autoinhibition,autoinhibition
6,O00418,22216903,autophosphorylation,autophosphorylation,autophosphorylation,autophosphorylation,autophosphorylation
7,Q9R1X4,9856465,autoregulatory,"autoregulation, autoregulatory",autoregulatory,autoregulatory,autoregulatory
8,Q65652,7871721,autocatalytic,"autocatalysis, autocatalytic",autocatalytic,autocatalytic,autocatalytic
9,O14965,19812038,autophosphorylation,autophosphorylation,autophosphorylation,autophosphorylation,autophosphorylation


| Batch | Micro-F1 | Macro-F1 | Weighted-F1 | Samples-F1 | Sample-Precision | Sample-Recall |
|-------|----------|----------|-------------|------------|------------------|---------------|
|   1   |  0.8573  |  0.5921  |    0.9067   |   0.8808   |      0.8608      |     0.9290    |
|   2   |  0.9047  |  0.6748  |    0.9306   |   0.9215   |      0.9105      |     0.9471    |
|   3   |  0.9355  |  0.8356  |    0.9417   |   0.9381   |      0.9342      |     0.9477    |
|   4   |  0.9550  |  0.8430  |    0.9605   |   0.9613   |      0.9581      |     0.9692    |
|   5   |----------|----------|-------------|------------|------------------|---------------|