In [1]:
import warnings

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForMaskedLM
from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments

warnings.filterwarnings("ignore")

In [2]:
# Read the table of medical terms from the CSV that sits next to the notebook.
MEDICAL_WORDS_CSV = "medical_words.csv"
df = pd.read_csv(MEDICAL_WORDS_CSV)

In [3]:
# Tokenizer and masked-LM weights are pulled from the same checkpoint,
# so the identifier is spelled out once and reused for both.
PUBMEDBERT_CHECKPOINT = 'microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext'
tokenizer = BertTokenizer.from_pretrained(PUBMEDBERT_CHECKPOINT)
model = BertForMaskedLM.from_pretrained(PUBMEDBERT_CHECKPOINT)

Some weights of the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [4]:
# Tokenize medical words
# pandas reads blank CSV cells as NaN (a float), and tokenizer.encode only
# accepts text, so missing entries are dropped and the remainder coerced to
# str before encoding. Special tokens are left out here; the dataset adds
# them later through prepare_for_model.
medical_words_clean = df['Medical Word'].dropna().astype(str)
tokenized_medical_words = [
    tokenizer.encode(word, add_special_tokens=False)
    for word in medical_words_clean
]

In [5]:
# Fine-tuning the model
class MedicalWordsDataset(Dataset):
    """Map-style dataset that turns pre-tokenized words into model inputs.

    Each item is a dict with ``input_ids``, ``attention_mask`` and ``labels``
    tensors, all padded/truncated to ``max_length``.

    Args:
        tokenized_texts: Sequence of token-id lists (without special tokens).
        tokenizer: Object exposing ``prepare_for_model``; used to add special
            tokens, pad and truncate each id list.
        max_length: Fixed length every example is padded/truncated to.
            Defaults to 512, the value that was previously hard-coded; a much
            smaller value is sufficient when the inputs are single words.
    """

    def __init__(self, tokenized_texts, tokenizer, max_length=512):
        self.tokenized_texts = tokenized_texts
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of tokenized texts."""
        return len(self.tokenized_texts)

    def __getitem__(self, idx):
        """Build the model inputs for the example at position ``idx``."""
        encoded = self.tokenizer.prepare_for_model(
            self.tokenized_texts[idx],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
        )
        # Flatten to 1-D regardless of whether a batch axis was prepended.
        input_ids = encoded["input_ids"].reshape(-1)
        attention_mask = encoded["attention_mask"].reshape(-1)

        # Labels must be an independent copy: sharing storage with input_ids
        # would let any in-place edit of one silently change the other.
        labels = input_ids.clone()
        # -100 is the index the masked-LM loss ignores; without it every
        # padding position would be scored as a prediction target.
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

In [6]:
# Wrap the token-id lists so the Trainer can index them as training examples.
train_dataset = MedicalWordsDataset(
    tokenized_texts=tokenized_medical_words,
    tokenizer=tokenizer,
)

In [7]:
# Settings for the fine-tuning run, gathered in one mapping so they can be
# inspected or tweaked before the TrainingArguments object is built.
training_options = {
    "output_dir": "./results",          # checkpoints are written here
    "overwrite_output_dir": True,
    "num_train_epochs": 5,
    "per_device_train_batch_size": 8,
    "save_steps": 10_000,
    "save_total_limit": 2,
    "logging_dir": "./logs",
}
training_args = TrainingArguments(**training_options)

In [8]:
# Masked-language-model training only teaches the model something when some
# input tokens are hidden. Without a collator the inputs and labels are the
# same ids, so the objective reduces to copying visible tokens. The collator
# below randomly masks tokens per batch and rebuilds `labels` so that only the
# masked positions are scored.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=data_collator,
)

trainer.train()

  0%|          | 0/905 [00:00<?, ?it/s]

: 

In [None]:
# Function to predict medical words
def predict_medical_words(prompt, tokenizer, model, top_k=5, medical_vocab=None):
    """Suggest medical words for the first ``[MASK]`` placeholder in a prompt.

    The model's ``top_k`` highest-scoring tokens for the first masked position
    are decoded and then filtered down to those present in the medical
    vocabulary.

    Args:
        prompt: Text containing the literal placeholder ``"[MASK]"``.
        tokenizer: Tokenizer providing ``mask_token``, ``mask_token_id``,
            ``encode`` and ``decode``.
        model: Masked-LM model; called with the encoded prompt and expected to
            return an object with a ``logits`` attribute. It is switched to
            eval mode as a side effect.
        top_k: Number of highest-scoring candidate tokens to consider.
        medical_vocab: Iterable of known medical words. When omitted, the
            notebook-level ``df["Medical Word"]`` column is used.

    Returns:
        List of decoded candidate words (in descending score order) that occur
        in the medical vocabulary; may be empty.

    Raises:
        ValueError: If the prompt has no ``"[MASK]"`` placeholder, or the mask
            token is absent from the encoded prompt.
    """
    if "[MASK]" not in prompt:
        raise ValueError('prompt must contain a "[MASK]" placeholder')

    if medical_vocab is None:
        medical_vocab = df["Medical Word"]
    # A set gives O(1) membership tests. Comparison is case-insensitive
    # because an uncased checkpoint only ever decodes lower-case tokens, so
    # capitalised vocabulary entries could otherwise never match.
    known_words = {
        entry.strip().casefold() for entry in medical_vocab if isinstance(entry, str)
    }

    prompt = prompt.replace("[MASK]", tokenizer.mask_token)
    # Keep the inputs on the same device as the model; after training the
    # model may no longer be on the CPU.
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    mask_positions = torch.where(input_ids == tokenizer.mask_token_id)[1]
    if mask_positions.numel() == 0:
        raise ValueError("mask token not found in the encoded prompt")

    model.eval()
    with torch.no_grad():
        logits = model(input_ids).logits

    # Only the first masked position is scored. Ranking raw logits gives the
    # same order as ranking their softmax, so the normalisation is skipped.
    first_mask_logits = logits[0, mask_positions[0], :]
    k = min(top_k, first_mask_logits.shape[-1])
    top_token_ids = torch.topk(first_mask_logits, k).indices.tolist()

    candidates = [tokenizer.decode([token_id]) for token_id in top_token_ids]
    return [word for word in candidates if word.strip().casefold() in known_words]

In [None]:
# Demo: ask the model to fill the blank and show which of its suggestions
# are known medical words.
prompt = "The patient has [MASK]."
medical_words = predict_medical_words(prompt=prompt, tokenizer=tokenizer, model=model)
print(f"Medical Words: {medical_words}")