In [12]:
!pip install datasets nltk transformers torch PyTorch

clear_output()

In [1]:
import functools
import re
import warnings

import nltk
import pandas as pd
import torch
from datasets import load_dataset
from IPython.display import clear_output
from nltk.tokenize import RegexpTokenizer
from transformers import AutoTokenizer
from transformers import BertTokenizer, BertModel

# Suppress all warnings in cell printouts to keep the output clean.
# NOTE(review): 'ignore' with no category filter also hides deprecation and runtime warnings.
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


# Load dataset

In [2]:
# Fetch the medical question-pairs dataset
dataset = load_dataset("medical_questions_pairs")

# List the names of the splits the dataset provides
for split in dataset.keys():
    print(split)



train


# Parse and clean dataset


In [3]:
# Turn the 'train' split into a pandas DataFrame and preview the first rows
train_split = dataset['train']
df = pd.DataFrame(train_split)
df.head()

Unnamed: 0,dr_id,question_1,question_2,label
0,1,After how many hour from drinking an antibioti...,I have a party tonight and I took my last dose...,1
1,1,After how many hour from drinking an antibioti...,I vomited this morning and I am not sure if it...,0
2,1,Am I over weight (192.9) for my age (39)?,I am a 39 y/o male currently weighing about 19...,1
3,1,Am I over weight (192.9) for my age (39)?,What diet is good for losing weight? Keto or v...,0
4,1,Aspirin allergy - is it worth getting a bracelet?,How much Aspirin can I take for my headache wi...,0


In [4]:
# Column names kept in variables for convenience
question_1, question_2 = 'question_1', 'question_2'

In [5]:
# Keep only the question pairs and their labels (label 1 means the two questions match).
# Re-binding instead of inplace=True, together with errors='ignore', lets this cell be
# re-run without raising KeyError once 'dr_id' has already been removed.
df = df.drop(columns=['dr_id'], errors='ignore')

In [6]:
# (rows, columns) of the frame at this point, as a baseline before de-duplication
df.shape

(3048, 3)

Properties of a good set:
1. No paired duplicates (no two rows with the same q1 and q2 values as a pair). This includes inverse duplicates:
   q1,q2 and q2,q1 count as the same pair.
2. No NaN values.
3. Many-to-many relations are allowed: the same question may appear in several pairs.

In [7]:
# Check for NaN and remove
nan_cols = df.isnull().any()
print(nan_cols)

question_1    False
question_2    False
label         False
dtype: bool


In [8]:
# Remove paired duplicates 
df['isometric_pair'] = df.apply(lambda x: tuple(sorted([x[question_1], x[question_2]])), axis=1)

# Remove duplicates based on the normalized pairs
df = df.drop_duplicates(subset=['isometric_pair'])

# Drop the auxiliary column
df = df.drop(columns=['isometric_pair'])

In [9]:
df.shape  # compare with the shape shown before de-duplication to see if any rows were removed

(3048, 3)

Outcome: a clean dataset with no duplicate pairs

# Embedding: transform each question into vector

    1. Tokenize the question into words
    2. Embed the words as vectors with BERT
    3. Average the word vectors into a single vector representing the sentence (optionally a weighted average, e.g. by TF-IDF scores, to give more importance to certain words)
    4. Return the question as a vector
    5. Add the vectorized question to the table


In [22]:
def tokenize(question):
    r"""Split a question into lowercase word tokens.

    Tokens are maximal runs of word characters (regex ``\w+``), so punctuation
    and whitespace are discarded. BERT's own tokenizer is deliberately not used
    at this step because it preserves punctuation, which is not needed here.

    Args:
        question: The question text (``str``).

    Returns:
        list[str]: The lowercased tokens in order of appearance; an empty list
        when the text contains no word characters.
    """
    # Plain re.findall avoids constructing a tokenizer object on every call.
    # Lowercasing happens per token, after splitting, exactly as before.
    return [token.lower() for token in re.findall(r'\w+', question)]


In [29]:
@functools.lru_cache(maxsize=None)
def _load_bert(model_name):
    """Load the BERT tokenizer and model for ``model_name`` once and reuse them on later calls."""
    bert_tokenizer = BertTokenizer.from_pretrained(model_name)
    bert_model = BertModel.from_pretrained(model_name)
    bert_model.eval()  # inference only: make sure dropout stays disabled
    return bert_tokenizer, bert_model


def embed(tokens, model_name="bert-base-uncased"):
    """Embed a tokenized question as a single vector using BERT.

    The tokens are joined back into one space-separated string, re-tokenized by
    BERT's own tokenizer, passed through the model, and the last hidden states
    are averaged over the token dimension.

    Args:
        tokens: Iterable of word strings, e.g. the output of ``tokenize``.
        model_name: Name of the pretrained BERT checkpoint to use. Defaults to
            ``"bert-base-uncased"``, the checkpoint this function always used.

    Returns:
        numpy.ndarray: One vector representing the entire question.
    """
    # The tokenizer and model are cached, so repeated calls do not reload them from disk.
    bert_tokenizer, bert_model = _load_bert(model_name)

    # Convert custom tokens into a string for BERT's special format
    custom_text = ' '.join(tokens)

    # Tokenize this text using BERT's tokenizer
    bert_inputs = bert_tokenizer(custom_text, return_tensors="pt", padding=True, truncation=True)

    # Forward pass without autograd bookkeeping: only the activations are needed
    with torch.no_grad():
        bert_outputs = bert_model(**bert_inputs)

    # Average the per-token hidden states, then take the only item in the batch.
    # NOTE(review): the mean presumably also covers BERT's special tokens
    # ([CLS]/[SEP]) added by the tokenizer - confirm this is intended.
    embeddings = bert_outputs.last_hidden_state
    embedded_question = embeddings.mean(dim=1)[0]
    return embedded_question.detach().numpy()
    

In [30]:
# Smoke test: tokenize and embed one sample sentence, then print the resulting vector
text = "Hello, world! How are you doing today?"
tokens = tokenize(text)
embedded_question = embed(tokens)
print(embedded_question)

tensor([ 1.2027e-01, -4.8629e-01,  2.5507e-01, -2.2412e-01,  3.7272e-01,
        -4.1045e-01,  2.9779e-01,  8.1002e-01, -2.2958e-01, -4.4690e-01,
         2.2236e-01, -4.5837e-01,  1.4484e-02,  3.6392e-01, -4.3390e-01,
         4.2487e-01,  1.8755e-01,  3.8753e-01,  5.4501e-02,  5.4168e-01,
        -1.7867e-01, -2.4879e-01,  2.1380e-02,  3.6549e-01,  3.7245e-01,
        -2.8870e-01, -1.0254e-01, -4.9800e-02, -1.2759e-01, -4.6746e-01,
        -1.4162e-01,  1.6437e-01, -3.9243e-01,  5.7495e-02,  1.8970e-01,
        -1.9391e-01, -4.2002e-02,  1.2921e-01, -5.8789e-01,  3.4880e-01,
        -8.3487e-01, -6.1177e-01,  3.2349e-01,  3.6605e-01, -4.9621e-01,
        -1.2292e+00,  2.7390e-01, -3.4765e-02, -2.4528e-01,  2.0516e-01,
        -2.4259e-01,  3.6626e-01, -8.0943e-01, -2.7043e-02, -3.5219e-01,
         6.2473e-01, -1.8816e-01, -6.5396e-01, -4.7406e-01,  7.0624e-02,
         1.6397e-01, -3.8148e-01, -1.1909e-01, -4.4703e-01,  3.4184e-01,
         6.7240e-01, -8.7383e-02,  5.6854e-01, -6.1