In [1]:
import torch
import pandas as pd
import json
import os
import re
import nltk
from nltk.corpus import stopwords
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer, AutoModel

In [3]:
# initialize nltk stopwords
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

# set device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Device being used: {device}")

Device being used: mps


[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/fiatlux/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [4]:
data_path = "../data/processed/shuffled_10_data.csv"
test_path = "../data/processed/test_data.csv"
model_dir = "../src/model/"

# data preprocessing
df = pd.read_csv(data_path)
df_test = pd.read_csv(test_path)

def clean_text(text):
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\d+', '', text) 
    text = " ".join([word.strip() for word in text.split() if word not in stop_words])
    return text

if 'Text_Cleaned' not in df:
    df['Text_Cleaned'] = df['Text_combined'].apply(clean_text)

df['Terms'] = df['Terms'].fillna('non-autoregulatory')

columns_to_keep = ['batch_number', 'Text_Cleaned', 'Terms']
df_cleaned = df[columns_to_keep]

df_cleaned['Terms_List'] = df_cleaned['Terms'].apply(lambda x: [term.strip() for term in x.split(',')])
df_cleaned['Terms_List'] = df_cleaned['Terms_List'].apply(lambda x: list(set(x)))

mlb = MultiLabelBinarizer()
labels = mlb.fit_transform(df_cleaned['Terms_List'])
label_columns = mlb.classes_

df_test['Text_Cleaned'] = df_test['Text_combined'].apply(clean_text)

tokenizer = AutoTokenizer.from_pretrained('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['Terms_List'] = df_cleaned['Terms'].apply(lambda x: [term.strip() for term in x.split(',')])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['Terms_List'] = df_cleaned['Terms_List'].apply(lambda x: list(set(x)))


In [5]:
class PubMedBERTClassifier(torch.nn.Module):
    def __init__(self, n_classes):
        super(PubMedBERTClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

In [6]:
def load_model_and_thresholds(model_path, thresholds_path):
    model = PubMedBERTClassifier(n_classes=len(label_columns))
    state_dict = torch.load(model_path, map_location="cpu")
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    if os.path.exists(thresholds_path):
        with open(thresholds_path, "r") as f:
            thresholds = json.load(f)
    else:
        thresholds = [0.5] * len(label_columns)

    return model, thresholds

In [7]:
def preprocess_text(text, tokenizer, max_length=512):
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    return encoding

In [8]:
def predict(model, df, tokenizer, thresholds):
    predictions = []

    for _, row in df.iterrows():
        text = str(row['Text_Cleaned']).strip()
        encoding = preprocess_text(text, tokenizer)
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        with torch.no_grad():
            outputs = model(input_ids, attention_mask)
            probabilities = torch.sigmoid(outputs).cpu().numpy().flatten()

        predicted_labels = [label_columns[i] for i in range(len(label_columns)) if probabilities[i] >= thresholds[i]]
        if not predicted_labels:
            predicted_labels = ["non-autoregulatory"]

        predictions.append(", ".join(sorted(predicted_labels)))

    return predictions

In [9]:
df_test.head()

Unnamed: 0,AC,PMID,Title,Abstract,Terms,Text_combined,Text_Cleaned
0,Q6TFL4,27798626,Stabilizing mutations of KLHL24 ubiquitin liga...,Skin integrity is essential for protection fro...,autoubiquitination,Stabilizing mutations of KLHL24 ubiquitin liga...,stabilizing mutations klhl ubiquitin ligase ca...
1,Q92673,20015111,GGA autoinhibition revisited.,The cytosolic adaptors GGA1-3 mediate sorting ...,autoinhibition,GGA autoinhibition revisited. The cytosolic ad...,gga autoinhibition revisited cytosolic adaptor...
2,P0A964,2068106,Bacterial chemotaxis signaling complexes: form...,We have demonstrated that a complex of the pro...,autophosphorylation,Bacterial chemotaxis signaling complexes: form...,bacterial chemotaxis signaling complexes forma...
3,Q13557,14722083,Comparative analyses of the three-dimensional ...,Ca(2+)-calmodulin-dependent protein kinase II ...,autophosphorylation,Comparative analyses of the three-dimensional ...,comparative analyses threedimensional structur...
4,P74646,23449916,Biochemical analysis of three putative KaiC cl...,Cyanobacteria have been shown to have a circad...,autophosphorylation,Biochemical analysis of three putative KaiC cl...,biochemical analysis three putative kaic clock...


In [10]:
# initialize results df
results_df = pd.DataFrame()
results_df[['AC', 'PMID', 'Actual Terms']] = df_test[['AC', 'PMID', 'Terms']]

In [11]:
# predict
for i in range(1, 3):
    model_path = os.path.join(model_dir, f"best_model_batch_{i}.pt")
    thresholds_path = os.path.join(model_dir, f"best_thresholds_batch_{i}.json")

    model, thresholds = load_model_and_thresholds(model_path, thresholds_path)
    pred_column_name = f"Pred Batch {i}"
    results_df[pred_column_name] = predict(model, df_test, tokenizer, thresholds)

  state_dict = torch.load(model_path, map_location="cpu")
  state_dict = torch.load(model_path, map_location="cpu")


In [12]:
results_df

Unnamed: 0,AC,PMID,Actual Terms,Pred Batch 1,Pred Batch 2
0,Q6TFL4,27798626,autoubiquitination,autoubiquitination,autoubiquitination
1,Q92673,20015111,autoinhibition,"autoinhibition, autoinhibitory","autoinhibition, autoinhibitory"
2,P0A964,2068106,autophosphorylation,"autokinase, autophosphorylation",autophosphorylation
3,Q13557,14722083,autophosphorylation,autophosphorylation,autophosphorylation
4,P74646,23449916,autophosphorylation,autophosphorylation,autophosphorylation
5,Q9JK25,20519438,autoinhibition,"autoinhibition, autoinhibitory","autoinhibition, autoinhibitory"
6,O00418,22216903,autophosphorylation,autophosphorylation,autophosphorylation
7,Q9R1X4,9856465,autoregulatory,"autoregulation, autoregulatory",autoregulatory
8,Q65652,7871721,autocatalytic,"autocatalysis, autocatalytic","autocatalysis, autocatalytic"
9,O14965,19812038,autophosphorylation,autophosphorylation,autophosphorylation


| Batch | Micro-F1 | Macro-F1 | Weighted-F1 | Samples-F1 | Sample-Precision | Sample-Recall |
|-------|----------|----------|-------------|------------|------------------|---------------|
|   1   |  0.8749  |  0.6532  |    0.9073   |   0.8897   |      0.8766      |     0.9216    |
|   2   |  0.9380  |  0.7721  |    0.9456   |   0.9455   |      0.9406      |     0.9570    |
|   3   |  0.9327  |  0.7529  |    0.9492   |   0.9409   |      0.9336      |     0.9578    |
|   4   |  0.9618  |  0.8444  |    0.9694   |   0.9665   |      0.9631      |     0.9751    |
|   5   |  0.9692  |  0.9292  |    0.9702   |   0.9699   |      0.9707      |     0.9700    |
|   6   |  0.9767  |  0.9363  |    0.9773   |   0.9790   |      0.9788      |     0.9810    |
|   7   |  0.9818  |  0.9781  |    0.9821   |   0.9820   |      0.9826      |     0.9826    |
|   8   |  0.9814  |  0.9073  |    0.9825   |   0.9837   |      0.9836      |     0.9845    |
|   9   |  0.9870  |  0.9695  |    0.9873   |   0.9883   |      0.9886      |     0.9897    |
|  10   |  0.9879  |  0.9558  |    0.9889   |   0.9886   |      0.9883      |     0.9897    |