In [56]:
import torch
import pandas as pd
from nltk.tokenize import word_tokenize
from transformers import BertTokenizer, BertModel, BertForSequenceClassification, AdamW
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

In [2]:
# Load pre-trained BERT tokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [63]:
# Sample dataset of user prompts
user_prompts = [
    "I have a fever and headache.",
    "My throat is sore and I have a cough.",
    "I am diabetic and need insulin.",
    "I took some paracetamol for my headache.",
]

In [64]:
# medical vocabulary
df = pd.read_csv("medical_words.csv")
medical_vocab = df["Medical Word"].values.tolist()
user_prompts = medical_vocab
print("Loaded", len(medical_vocab), "medical words")

Loaded 1439 medical words


In [65]:
# Label each user prompt as 1 (indicating medical-related words are present)
labels = [1] * len(user_prompts)

In [66]:
# Tokenize the user prompts
input_ids = []
attention_masks = []

for prompt in user_prompts:
    encoded_dict = tokenizer.encode_plus(
        prompt,  # Sentence to encode
        add_special_tokens=True,  # Add '[CLS]' and '[SEP]'
        max_length=64,  # Pad & truncate all sentences.
        pad_to_max_length=True,
        truncation=True,
        return_attention_mask=True,  # Construct attn. masks.
        return_tensors="pt",  # Return pytorch tensors.
    )
    input_ids.append(encoded_dict["input_ids"])
    attention_masks.append(encoded_dict["attention_mask"])

In [67]:
# Convert the lists into tensors
input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
labels = torch.tensor(labels)

In [68]:
# Create the DataLoader for our dataset
dataset = TensorDataset(input_ids, attention_masks, labels)
batch_size = 4
dataloader = DataLoader(dataset, sampler=RandomSampler(dataset), batch_size=batch_size)

In [69]:
# Load pre-trained BERT model for sequence classification
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=2,  # 2 classes: medical-related words present or not
    output_attentions=False,
    output_hidden_states=False,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [70]:
# Run the model on the GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [71]:
# Define the optimizer and learning rate scheduler
optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

In [72]:
# Fine-tune BERT
epochs = 3
for epoch in range(epochs):
    model.train()
    total_loss = 0

    for batch in tqdm(dataloader, desc="Epoch {}".format(epoch + 1)):
        # Unpack the inputs from our dataloader
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        # Clear any previously calculated gradients
        model.zero_grad()

        # Perform a forward pass
        outputs = model(
            b_input_ids,
            token_type_ids=None,
            attention_mask=b_input_mask,
            labels=b_labels,
        )

        # Accumulate the training loss over all of the batches
        loss = outputs.loss
        total_loss += loss.item()

        # Perform a backward pass to calculate the gradients
        loss.backward()

        # Clip the norm of the gradients to 1.0 to prevent "exploding gradients"
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update parameters and take a step using the computed gradient
        optimizer.step()

    # Calculate the average loss over the training data
    avg_train_loss = total_loss / len(dataloader)

Epoch 1: 100%|██████████| 360/360 [13:57<00:00,  2.33s/it]
Epoch 2: 100%|██████████| 360/360 [13:50<00:00,  2.31s/it]
Epoch 3: 100%|██████████| 360/360 [13:36<00:00,  2.27s/it]


In [73]:
# Evaluate the fine-tuned model
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [78]:
# Sample user prompt for testing
new_prompt = "shoes cars bikes sandals."

In [79]:
# Tokenize the user prompt
inputs = tokenizer(new_prompt, return_tensors="pt")
inputs.to(device)

{'input_ids': tensor([[  101,  6007,  3765, 18105, 24617,  1012,   102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1]])}

In [80]:
# Make a prediction
with torch.no_grad():
    outputs = model(**inputs)

In [83]:
# Get the predicted class
predicted_class = torch.argmax(outputs.logits).item()
print(outputs.logits)

if predicted_class == 1:
    print("Medical-related words detected in the user prompt.")
else:
    print("No medical-related words detected in the user prompt.")

tensor([[-5.1778,  5.5801]])
Medical-related words detected in the user prompt.


In [62]:
# Load pre-trained BERT model (for feature extraction)
model = BertModel.from_pretrained("bert-base-uncased")
model.eval()

# Sample user prompt for testing
user_prompt = "I have a fever and headache."

# Tokenize the user prompt
tokens = tokenizer.tokenize(user_prompt)

# Convert tokens to IDs
token_ids = tokenizer.convert_tokens_to_ids(tokens)

# Convert token IDs to tensor
input_ids = torch.tensor([token_ids])

# Get the BERT model's hidden states (features)
with torch.no_grad():
    outputs = model(input_ids)
    hidden_states = outputs.last_hidden_state

# Retrieve the hidden states for medical-related tokens
medical_words = []
for i, token in enumerate(tokens):
    if token.startswith("##"):
        # Merge subword tokens
        medical_words[-1] += token[2:]
    else:
        # Check if token is in the medical vocabulary
        if token.lower() in medical_vocab:
            medical_words.append(token)

# Print the detected medical-related words
print("Medical-related words detected in the user prompt:", medical_words)

Medical-related words detected in the user prompt: []
