In [95]:
import torch
import pandas as pd
import json
import os
import re
import nltk
from nltk.corpus import stopwords
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer, AutoModel

In [96]:
# initialize nltk stopwords
nltk.download('stopwords')
stop_words = set(stopwords.words('english'))

# set device
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(f"Device being used: {device}")

Device being used: mps


[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/fiatlux/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [97]:
data_path = "../data/processed/shuffled_10_data.csv"
test_path = "../data/processed/test_data.csv"
model_path = "../src/model/best_model_batch_1.pt"
thresholds_path = "../src/model/best_thresholds_batch_1.json"

In [114]:
# data preprocessing
df = pd.read_csv(data_path)
df_test = pd.read_csv(test_path)
# df_test = df_test[5:] # for testing

def clean_text(text):
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    text = re.sub(r'\d+', '', text) 
    text = " ".join([word.strip() for word in text.split() if word not in stop_words])
    return text

if 'Text_Cleaned' not in df:
    df['Text_Cleaned'] = df['Text_combined'].apply(clean_text)

df['Terms'] = df['Terms'].fillna('non-autoregulatory')

columns_to_keep = ['batch_number', 'Text_Cleaned', 'Terms']
df_cleaned = df[columns_to_keep]

df_cleaned['Terms_List'] = df_cleaned['Terms'].apply(lambda x: [term.strip() for term in x.split(',')])
df_cleaned['Terms_List'] = df_cleaned['Terms_List'].apply(lambda x: list(set(x)))

mlb = MultiLabelBinarizer()
labels = mlb.fit_transform(df_cleaned['Terms_List'])
label_columns = mlb.classes_

labels_df = pd.DataFrame(labels, columns=label_columns)

df_cleaned = pd.concat([df_cleaned, labels_df], axis=1)

df_test['Text_Cleaned'] = df_test['Text_combined'].apply(clean_text)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['Terms_List'] = df_cleaned['Terms'].apply(lambda x: [term.strip() for term in x.split(',')])
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  df_cleaned['Terms_List'] = df_cleaned['Terms_List'].apply(lambda x: list(set(x)))


In [115]:
class PubMedBERTClassifier(torch.nn.Module):
    def __init__(self, n_classes):
        super(PubMedBERTClassifier, self).__init__()
        self.bert = AutoModel.from_pretrained('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(self.bert.config.hidden_size, n_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits

In [116]:
def load_model_and_thresholds():
    # Load model to CPU first
    model = PubMedBERTClassifier(n_classes=len(label_columns))
    state_dict = torch.load(model_path, map_location="cpu")  # Always load to CPU first
    model.load_state_dict(state_dict)
    
    # Move to MPS
    model = model.to(device)
    model.eval()

    # Load thresholds
    if os.path.exists(thresholds_path):
        with open(thresholds_path, "r") as f:
            thresholds = json.load(f)
        print(f"Thresholds loaded from {thresholds_path}")
    else:
        print(f"Thresholds file not found. Using default 0.5 for all labels.")
        thresholds = [0.5] * len(label_columns)

    return model, thresholds

In [117]:
tokenizer = AutoTokenizer.from_pretrained('microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext')

def preprocess_text(text, tokenizer, max_length=512):
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )
    return encoding

In [118]:
def predict(model, df, tokenizer, thresholds):
    predictions = []
    accuracies = []

    for idx, row in df.iterrows():
        text = str(row['Text_Cleaned']).strip()
        actual_labels = str(row['Terms']).split(",") if pd.notna(row['Terms']) else ["non-autoregulatory"]

        # Ensure no empty strings in actual_labels
        actual_labels = [label.strip() for label in actual_labels if label.strip()]

        if not text:
            predictions.append(["non-autoregulatory"])
            accuracies.append(1.0 if "non-autoregulatory" in actual_labels else 0.0)
            continue

        encoding = preprocess_text(text, tokenizer)
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        try:
            with torch.no_grad():
                outputs = model(input_ids, attention_mask)
                probabilities = torch.sigmoid(outputs).cpu().numpy().flatten()

            if len(probabilities) != len(label_columns):
                print(f"Warning: Probabilities length {len(probabilities)} does not match label columns length {len(label_columns)}")
                predictions.append(["error"])
                accuracies.append(0.0)
                continue

            # apply prediction logic based on thresholds
            predicted_labels = [
                label_columns[i] for i in range(len(label_columns))
                if probabilities[i] >= thresholds[i]
            ]

            # if no predicted labels, mark as "non-autoregulatory"
            if not predicted_labels:
                predicted_labels = ["non-autoregulatory"]

        except Exception as e:
            print(f"Error during prediction for index {idx}: {e}")
            predicted_labels = ["error"]

        # calculate accuracy
        actual_set = set(actual_labels)
        predicted_set = set(predicted_labels)

        # Intersection over Union (IoU) for multi-label accuracy
        intersection = len(actual_set & predicted_set)
        union = len(actual_set | predicted_set)
        accuracy = intersection / union if union != 0 else 1.0

        predictions.append(sorted(predicted_set))
        accuracies.append(round(accuracy, 2))

    return predictions, accuracies

In [119]:
# load model and thresholds
model, thresholds = load_model_and_thresholds()

  state_dict = torch.load(model_path, map_location="cpu")  # Always load to CPU first


Thresholds loaded from ../src/model/best_thresholds_batch_1.json


In [120]:
# prdict
df_test['Predicted_Terms'], df_test['Accuracy'] = predict(model, df_test, tokenizer, thresholds)

In [121]:
df_test[['Terms', 'Predicted_Terms', 'Accuracy']]

Unnamed: 0,Terms,Predicted_Terms,Accuracy
0,autoubiquitination,[autoubiquitination],1.0
1,autoinhibition,"[autoinhibition, autoinhibitory]",0.5
2,autophosphorylation,[autophosphorylation],1.0
3,autophosphorylation,[autophosphorylation],1.0
4,autophosphorylation,[autophosphorylation],1.0
5,autoinhibition,"[autoinhibition, autoinhibitory]",0.5
6,autophosphorylation,[autophosphorylation],1.0
7,autoregulatory,[autoregulatory],1.0
8,autocatalytic,[autocatalytic],1.0
9,autophosphorylation,[autophosphorylation],1.0
