In [106]:
import os
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import BertTokenizer, BertModel, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
import pandas as pd

In [107]:
hc = [
    'what is the role of a primary care physician',
    'how do vaccines work in our body',
    'how does a healthy diet impact your overall health',
    'what is the best way to maintain a healthy weight',
    'can you explain the different types of health insurance plans',
    'what are the benefits of regular exercise',
    'how does stress affect the body',
    'how often should i get a general health checkup',
    'how important is sleep for health',
    'what are the common symptoms of diabetes',
    'can you explain the health risks of obesity',
    'what is telemedicine',
    'how does mental health affect physical health',
    'what are some ways to prevent heart disease',
    'can you explain the concept of herd immunity',
    'what are the most common mental health disorders',
    'how does smoking affect the lungs',
    'what are the health risks of high cholesterol',
    'what is the difference between type one and type two diabetes',
    'can you explain the health benefits of yoga',
    'what is the role of genetics in health',
    'what are the early signs of alzheimers disease',
    'how can i protect my skin from the sun',
    'how does hypertension affect the body',
    'what are some strategies for managing chronic pain',
    'what is the human microbiome and how does it affect health',
    'what is the role of the world health organization',
    'how can diet affect mental health',
    'what are the early symptoms of parkinsons disease',
    'can you explain the different types of cancer',
    'how important is early detection in cancer',
    'what is the connection between gut health and mood',
    'what is the effect of climate change on health',
    'what are the benefits of breastfeeding',
    'how can a pregnant woman ensure the health of her baby',
    'what are the health risks of alcohol consumption',
    'what are the symptoms of a stroke',
    'how is hiv transmitted',
    'what are some preventive measures against the common cold',
    'how can i improve my dental health',
    'what are the benefits of regular eye checkups',
    'what is osteoporosis and how can it be prevented',
    'what are the risk factors for heart disease',
    'what are the symptoms of a heart attack in women',
    'how can i increase my bone density',
    'what is palliative care',
    'what are the common causes of hair loss',
    'what is the role of antioxidants in health',
    'how does air pollution affect our health',
    'what are some strategies to quit smoking',
    'how does regular physical activity benefit mental health',
    'what is the difference between a dietitian and a nutritionist',
    'how is autism diagnosed',
    'how can you manage the symptoms of adhd',
    'what are the health benefits of drinking water',
    'how does exposure to secondhand smoke affect health',
    'what are the symptoms of lung cancer',
    'what is the role of probiotics in gut health',
    'how does menopause affect a womans health',
    'what are the health benefits of green tea',
    'what is the function of the liver in the human body',
    'what are the early signs of liver disease',
    'what is the impact of high blood sugar levels',
    'how is depression diagnosed and treated',
    'what is the difference between dementia and alzheimers',
    'how can lifestyle changes help manage type two diabetes',
    'what is bipolar disorder',
    'what is the difference between an allergy and an intolerance'
]

nhc = [
    'what is the capital of australia',
    'who wrote the book pride and prejudice',
    'what is the tallest mountain in the world',
    'how deep is the mariana trench',
    'what is the speed of light',
    'who painted the mona lisa',
    'what is quantum physics',
    'who is the current president of the united states',
    'how old is the earth',
    'what is the largest planet in our solar system',
    'who won the oscar for best picture in 2020',
    'what is the longest river in the world',
    'who discovered penicillin',
    'what is the square root of 144',
    'what is the largest ocean in the world',
    'how far is the moon from the earth',
    'who invented the telephone',
    'what is the population of china',
    'how many countries are there in the world',
    'who composed the fifth symphony',
    'what is the tallest building in the world',
    'what is the chemical formula for water',
    'who is the richest person in the world',
    'what is the highest waterfall in the world',
    'who is the current prime minister of the united kingdom',
    'what is the largest country by land area',
    'what is the fastest animal on land',
    'who directed the movie titanic',
    'how old is the universe',
    'what is the hottest planet in our solar system',
    'who won the world cup in soccer in 2018',
    'what is the capital of canada',
    'who wrote the novel 1984',
    'what is the theory of relativity',
    'who is the current secretary general of the united nations',
    'how old is the great wall of china',
    'what is the smallest planet in our solar system',
    'who painted the starry night',
    'what is the boiling point of water at sea level',
    'who is the current chancellor of germany',
    'how long is the amazon river',
    'who discovered gravity',
    'what is the cube root of 27',
    'what is the smallest ocean in the world',
    'how far is the sun from the earth',
    'who invented the light bulb',
    'what is the population of india',
    'how many states are there in the united states',
    'who composed the magic flute',
    'what is the longest bridge in the world',
    'what is the chemical formula for carbon dioxide',
    'who is the current queen of the united kingdom'
]

In [108]:
def load_data():
    text = hc + nhc
    labels = [1] * len(hc) + [0] * len(nhc)
    return text, labels

texts, labels = load_data()

In [109]:
class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
            self.texts = texts
            self.labels = labels
            self.tokenizer = tokenizer
            self.max_length = max_length
    def __len__(self):
        return len(self.texts)
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(text, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True)
        return {'input_ids': encoding['input_ids'].flatten(), 'attention_mask': encoding['attention_mask'].flatten(), 'label': torch.tensor(label)}

In [110]:
class BERTClassifier(nn.Module):
    def __init__(self, bert_model_name, num_classes):
        super(BERTClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
            outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
            pooled_output = outputs.pooler_output
            x = self.dropout(pooled_output)
            logits = self.fc(x)
            return logits

In [111]:
def train(model: BERTClassifier, data_loader, optimizer, scheduler, device):
    model.train()
    for batch in data_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)
        outputs = model(input_ids, attention_mask)
        loss = nn.CrossEntropyLoss()(outputs, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()

In [112]:
def evaluate(model, data_loader, device):
    model.eval()
    predictions = []
    actual_labels = []
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
            _, preds = torch.max(outputs, dim=1)
            predictions.extend(preds.cpu().tolist())
            actual_labels.extend(labels.cpu().tolist())
    return accuracy_score(actual_labels, predictions), classification_report(actual_labels, predictions)

In [113]:
def predict(text, model, tokenizer, device, max_length=128):
    model.eval()
    encoding = tokenizer(text, return_tensors='pt', max_length=max_length, padding='max_length', truncation=True)
    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        _, preds = torch.max(outputs, dim=1)
        return "medical" if preds.item() == 1 else "non-medical"

In [114]:
# Set up parameters
bert_model_name = 'bert-base-uncased'
num_classes = 2
max_length = 128
batch_size = 16
num_epochs = 4
learning_rate = 2e-5

In [115]:
train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.2, random_state=42)

In [116]:
tokenizer = BertTokenizer.from_pretrained(bert_model_name)
train_dataset = TextClassificationDataset(train_texts, train_labels, tokenizer, max_length)
val_dataset = TextClassificationDataset(val_texts, val_labels, tokenizer, max_length)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)

In [117]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BERTClassifier(bert_model_name, num_classes).to(device)

In [118]:
optimizer = AdamW(model.parameters(), lr=learning_rate)
total_steps = len(train_dataloader) * num_epochs
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)



In [119]:
for epoch in range(num_epochs):
        print(f"Epoch {epoch + 1}/{num_epochs}")
        train(model, train_dataloader, optimizer, scheduler, device)
        accuracy, report = evaluate(model, val_dataloader, device)
        print(f"Validation Accuracy: {accuracy:.4f}")
        print(report)

Epoch 1/4
Validation Accuracy: 0.9583
              precision    recall  f1-score   support

           0       1.00      0.88      0.93         8
           1       0.94      1.00      0.97        16

    accuracy                           0.96        24
   macro avg       0.97      0.94      0.95        24
weighted avg       0.96      0.96      0.96        24

Epoch 2/4
Validation Accuracy: 0.9583
              precision    recall  f1-score   support

           0       0.89      1.00      0.94         8
           1       1.00      0.94      0.97        16

    accuracy                           0.96        24
   macro avg       0.94      0.97      0.95        24
weighted avg       0.96      0.96      0.96        24

Epoch 3/4
Validation Accuracy: 0.9583
              precision    recall  f1-score   support

           0       0.89      1.00      0.94         8
           1       1.00      0.94      0.97        16

    accuracy                           0.96        24
   macro avg  

In [120]:
torch.save(model.state_dict(), "bert_classifier.pth")

In [121]:
test_text = "Why is my blood pressure high?"
classification = predict(test_text, model, tokenizer, device)
print(test_text)
print(f"Predicted: {classification}")

Why is my blood pressure high?
Predicted: medical
