In [1]:
import pandas as pd
import torch
from transformers import BertTokenizer, BertForSequenceClassification, AdamW
from torch.utils.data import TensorDataset, DataLoader, random_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score

In [None]:
# Load the labelled intent dataset; later cells read its "query" and "intent" columns.
df = pd.read_csv(filepath_or_buffer="intent_classification_dataset_v1.csv")

In [3]:
# Map each intent string to an integer class id. The fitted encoder is kept
# around so predictions can later be decoded back to intent names.
label_encoder = LabelEncoder()
label_encoder.fit(df["intent"])
df["intent_label"] = label_encoder.transform(df["intent"])

In [4]:
# Size the classification head from the fitted label encoder instead of a
# hard-coded count, so the model always matches the intents present in the data.
NUM_LABELS = len(label_encoder.classes_)

# Pretrained uncased BERT tokenizer plus a sequence-classification model whose
# classifier layer is freshly initialised and therefore must be fine-tuned.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    num_labels=NUM_LABELS,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
def tokenize_data(texts):
    """Tokenize a batch of texts with the notebook's BERT tokenizer.

    Args:
        texts: The text (or list of texts) to encode.

    Returns:
        The tokenizer output as PyTorch tensors, padded to the longest
        sequence in the batch and truncated where the tokenizer requires it.
    """
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    return encoded

In [6]:
# Encode every query once up front; the resulting tensors feed the TensorDataset.
tokens = tokenize_data(df["query"].tolist())

# Force int64 class ids: the classification loss expects long targets, and the
# dtype inherited from NumPy's default integer can be int32 on some platforms.
labels = torch.tensor(df["intent_label"].to_numpy(), dtype=torch.long)

In [7]:
# Fixed seed so the train/validation split and the training batch order are
# identical after every "Restart Kernel -> Run All".
SEED = 42

dataset = TensorDataset(tokens["input_ids"], tokens["attention_mask"], labels)

# 80/20 train/validation split; the validation set takes the remainder so the
# two sizes always sum to len(dataset).
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SEED),
)

# Only the training loader shuffles; validation order does not affect accuracy.
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = DataLoader(val_dataset, batch_size=16)

In [8]:
# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Use PyTorch's AdamW: the implementation shipped with transformers is
# deprecated in favour of torch.optim.AdamW. eps and weight_decay are set
# explicitly to the legacy transformers defaults so optimisation is unchanged.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=5e-5,
    eps=1e-6,
    weight_decay=0.0,
)



In [9]:
# Fine-tune for a fixed number of epochs and report the mean batch loss per epoch.
EPOCHS = 3
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    for batch in train_loader:
        # Unpack into `batch_labels`, not `labels`, so the full-dataset `labels`
        # tensor built earlier is not overwritten by the last mini-batch
        # (which would break re-running the dataset cell).
        input_ids, attention_mask, batch_labels = (t.to(device) for t in batch)
        optimizer.zero_grad()
        # Passing labels makes the model compute the classification loss itself.
        outputs = model(input_ids, attention_mask=attention_mask, labels=batch_labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch+1}, Loss: {total_loss / len(train_loader)}")

Epoch 1, Loss: 0.6995056395977736
Epoch 2, Loss: 0.024732803255319596
Epoch 3, Loss: 0.008703706609085203


In [10]:
# Score the fine-tuned model on the held-out validation split.
model.eval()
preds, true_labels = [], []
with torch.no_grad():  # inference only, no gradients needed
    for batch in val_loader:
        # `batch_labels` keeps the full-dataset `labels` tensor from being
        # replaced by the last validation mini-batch.
        input_ids, attention_mask, batch_labels = (t.to(device) for t in batch)
        outputs = model(input_ids, attention_mask=attention_mask)
        # Highest logit per row is the predicted class id.
        predictions = torch.argmax(outputs.logits, dim=1)
        preds.extend(predictions.cpu().numpy())
        true_labels.extend(batch_labels.cpu().numpy())

# NOTE(review): a perfect score on a random split can indicate near-duplicate
# queries shared between train and validation -- confirm against the dataset.
accuracy = accuracy_score(true_labels, preds)
print(f"Validation Accuracy: {accuracy * 100:.2f}%")

Validation Accuracy: 100.00%


In [15]:
def predict_intent(query):
    """Return the predicted intent name for a single query.

    The query is tokenized, run through the fine-tuned model in eval mode with
    gradients disabled, and the highest-scoring class id is decoded back to
    its intent string with the fitted label encoder.

    Args:
        query: The user query to classify. The result is read with ``.item()``,
            so this is expected to be one string rather than a batch.

    Returns:
        The intent label for the highest-scoring class.
    """
    model.eval()
    encoded = tokenizer(query, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        logits = model(encoded["input_ids"], attention_mask=encoded["attention_mask"]).logits
    class_index = torch.argmax(logits, dim=1).cpu().item()
    decoded = label_encoder.inverse_transform([class_index])
    return decoded[0]

# Smoke-test the classifier on hand-written queries; the intent each one is
# expected to map to is noted beside it.
example_queries = [
    "Schedule a blood test for next week.",  # book_test
    "Cancel my MRI appointment for tomorrow.",  # cancel_appointment
    "Analyze my blood report and give insights.",  # analyze_report
    "I need to upload my X-ray report.",  # upload_record
    "Retrieve my last cholesterol test report.",  # retrieve_record
    "Book an appointment for a full body checkup.",  # book_test
    "I want to cancel my dental checkup appointment.",  # cancel_appointment
    "Give me a detailed analysis of my health report.",  # analyze_report
    "Where can I upload my latest medical documents?",  # upload_record
    "Fetch my recent blood test results.",  # retrieve_record
]
for example in example_queries:
    print(predict_intent(example))


book_test
cancel_appointment
analyze_report
upload_record
retrieve_record
book_test
cancel_appointment
analyze_report
upload_record
retrieve_record
