<a href="https://colab.research.google.com/github/GavinAbercrombie/medical-safety/blob/main/code/medical_safety_convai.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Risk-graded Safety for Handling Medical Queries in Conversational AI**

Code from the AACL 2022 paper.

Please cite as:

Gavin Abercrombie and Verena Rieser. 2022. Risk-graded Safety for Handling Medical Queries in Conversational AI. In Proceedings of The 2nd Conference of the Asia-Pacific Chapter of the Association for Computational Linguistics. Association for Computational Linguistics.



❗Mount google drive before using this notebook\
❗Set Runtime to GPU

**1. Install and import the necessary libraries** 

In [15]:
! pip install transformers

import csv
from spacy.tokenizer import Tokenizer
from spacy.lang.en import English
from random import shuffle
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, mean_squared_error
from sklearn.metrics import confusion_matrix
from imblearn.metrics import macro_averaged_mean_absolute_error

import numpy as np
import csv
import ast
from collections import Counter
import re
import os

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import f1_score

from keras import backend as K
import tensorflow as tf

# set random seed:
seed_val = 42
import random 
import torch

from keras.preprocessing.sequence import pad_sequences
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler, TensorDataset
from torch.nn import BCEWithLogitsLoss
from tqdm import tqdm
from transformers import (
    WEIGHTS_NAME,
    AdamW,
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
    get_linear_schedule_with_warmup,
)

**2. Import the data**

In [10]:
medical = list(csv.reader(open('drive/MyDrive/medicheck-expert.csv')))
negative = list(csv.reader(open('drive/MyDrive/medicheck-negative.csv')))

In [11]:
#@title 3. Select settings from the drop-down menus:

input = 'queries' #@param ['queries', 'responses']
labelling = 'ordinal' #@param ['binary', 'ternary','ordinal']

**4. Pre-process the labels**

In [12]:
labels = [0 for _ in range(len(negative))]

negative = [[i[0], j] for i, j in zip(negative, labels)]

if input == 'queries':
  if labelling == 'binary':
      medical_processed = [[m[0], int(m[1])] if m[1] == '0' else [m[0], 1] for m in medical[1:]]
  else:
      medical_processed = []
      for m in medical[1:]:
          if m[1] == '0':
              i = [m[0], [0, 0, 0]]
              medical_processed.append(i)
          elif m[1] == '1':
              i = [m[0], [1, 0, 0]]
              medical_processed.append(i)
          elif m[1] == '2':
              i = [m[0], [1, 1, 0]]
              medical_processed.append(i)
          elif m[1] == '3':
              i = [m[0], [1, 1, 1]]
              medical_processed.append(i)
      negative = [[i[0], [0, 0, 0]] for i in negative]


elif input == 'responses':
  if labelling == 'binary':
      medical_processed = []
      for m in medical[1:]:
          if m[3] == '0':
              i = [m[2], 0]
              medical_processed.append(i)
          elif m[3] == '1':
              i = [m[2], 1]
              medical_processed.append(i)
          elif m[3] == '2':
              i = [m[2], 1]
              medical_processed.append(i)
          elif m[3] == '3':
              i = [m[2], 1]
              medical_processed.append(i) 
          if m[6] == '0':
              i = [m[5], 0]
              medical_processed.append(i)
          elif m[6] == '1':
              i = [m[5], 1]
              medical_processed.append(i)
          elif m[6] == '2':
              i = [m[5], 1]
              medical_processed.append(i)
          elif m[6] == '3':
              i = [m[5], 1]
              medical_processed.append(i) 
          if m[9] == '0':
              i = [m[8], 0]
              medical_processed.append(i)
          elif m[9] == '1':
              i = [m[8], 1]
              medical_processed.append(i)
          elif m[9] == '2':
              i = [m[8], 1]
              medical_processed.append(i)
          elif m[9] == '3':
              i = [m[8], 1]
              medical_processed.append(i) 

  elif labelling == 'ternary':
      medical_processed = []
      for m in medical[1:]:
          if m[3] == 'X':
              i = [m[2], [1, 0, 0]]
              medical_processed.append(i) 
          elif m[3] == '0':
              i = [m[2], [0, 1, 0]]
              medical_processed.append(i)
          elif m[3] == '1':
              i = [m[2], [0, 0, 1]]
              medical_processed.append(i)
          elif m[3] == '2':
              i = [m[2], [0, 0, 1]]
              medical_processed.append(i)
          elif m[3] == '3':
              i = [m[2], [0, 0, 1]]
              medical_processed.append(i) 
          if m[6] == 'X':
              i = [m[6], [1, 0, 0]]
              medical_processed.append(i)            
          elif m[6] == '0':
              i = [m[5], [0, 1, 0]]
              medical_processed.append(i)
          elif m[6] == '1':
              i = [m[5], [0, 0, 1]]
              medical_processed.append(i)
          elif m[6] == '2':
              i = [m[5], [0, 0, 1]]
              medical_processed.append(i)
          elif m[6] == '3':
              i = [m[5], [0, 0, 1]]
              medical_processed.append(i)
          if m[9] == 'X':
              i = [m[9], [1, 0, 0]]
              medical_processed.append(i)              
          elif m[9] == '0':
              i = [m[8], [0, 1, 0]]
              medical_processed.append(i)
          elif m[9] == '1':
              i = [m[8], [0, 0, 1]]
              medical_processed.append(i)
          elif m[9] == '2':
              i = [m[8], [0, 0, 1]]
              medical_processed.append(i)
          elif m[9] == '3':
              i = [m[8], [0, 0, 1]]
              medical_processed.append(i) 

  else:
      medical_processed = []
      for m in medical[1:]:
          if m[3] == '0':
              i = [m[2], [0, 0, 0]]
              medical_processed.append(i)
          elif m[3] == '1':
              i = [m[2], [1, 0, 0]]
              medical_processed.append(i)
          elif m[3] == '2':
              i = [m[2], [1, 1, 0]]
              medical_processed.append(i)
          elif m[3] == '3':
              i = [m[2], [1, 1, 1]]
              medical_processed.append(i) 
          if m[6] == '0':
              i = [m[5], [0, 0, 0]]
              medical_processed.append(i)
          elif m[6] == '1':
              i = [m[5], [1, 0, 0]]
              medical_processed.append(i)
          elif m[6] == '2':
              i = [m[5], [1, 1, 0]]
              medical_processed.append(i)
          elif m[6] == '3':
              i = [m[5], [1, 1, 1]]
              medical_processed.append(i) 
          if m[9] == '0':
              i = [m[8], [0, 0, 0]]
              medical_processed.append(i)
          elif m[9] == '1':
              i = [m[8], [1, 0, 0]]
              medical_processed.append(i)
          elif m[9] == '2':
              i = [m[8], [1, 1, 0]]
              medical_processed.append(i)
          elif m[9] == '3':
              i = [m[8], [1, 1, 1]]
              medical_processed.append(i) 

**5. Mix and shuffle the data and create train, test, split sets**

In [13]:
if input == 'queries':
  all_data = medical_processed + other
else:
  all_data = medical_processed

shuffle(all_data)

len_data = len(all_data)

X_train, y_train  = [x[0] for x in all_data[:int(0.8*len_data)]], [y[1] for y in all_data[:int(0.8*len_data)]]
X_valid, y_valid  = [x[0] for x in all_data[int(0.8*len_data):int(0.9*len_data)]], [y[1] for y in all_data[int(0.8*len_data):int(0.9*len_data)]]
X_test, y_test    = [x[0] for x in all_data[int(0.9*len_data):]], [y[1] for y in all_data[int(0.9*len_data):]]

**6. Fine-tune the model, make predictions, and evaluate**

In [16]:
# functions for processing the results

all_y_trues, all_y_preds = [], []
ordinal_map = {-1: [-1,-1,-1], 0: [0,0,0], 1: [1,0,0], 2: [1,1,0], 3: [1,1,1]}
ordinal_map_reverse = {str(value): key for key, value in ordinal_map.items()}

def ordinal_results(y_test, y_pred):
    y_pred = [[int(x) for x in y] for y in y_pred]
    y_pred = [ordinal_map_reverse[str(y)] if str(y) in ordinal_map_reverse else 1 for y in y_pred]
    y_test = [list(y) for y in y_test]
    y_test = [ordinal_map_reverse[str(y)] for y in y_test]
    return [accuracy_score(y_test, y_pred), precision_score(y_test, y_pred, average='macro'), recall_score(y_test, y_pred, average='macro'), \
            f1_score(y_test, y_pred, average='macro'), f1_score(y_test, y_pred, average='micro'), mean_squared_error(y_test, y_pred), \
            macro_averaged_mean_absolute_error(y_test, y_pred)]



def ternary_results(y_test, y_pred):
    preds = [list(y).index(max(y)) for y in y_pred]
    tests = [list(y).index(max(y)) for y in y_test] # if sum(y) != -3 else -1 for y in y_test]
    y_pred, y_test = [], []
    for i, j in zip(preds, tests):
        y_pred.append(i)
        y_test.append(j)
    return [accuracy_score(y_test, y_pred), precision_score(y_test, y_pred, average='macro'), recall_score(y_test, y_pred, average='macro'), \
            f1_score(y_test, y_pred, average='macro'), f1_score(y_test, y_pred, average='micro'), mean_squared_error(y_test, y_pred), \
            macro_averaged_mean_absolute_error(y_test, y_pred)]

## code below adapted from https://osf.io/re4gd/, 
## which, in turn, uses code from https://mccormickml.com/2019/07/22/BERT-fine-tuning/

train_labels, valid_labels, test_labels = y_train, y_valid, y_test

eval_task = labelling

def load_train_test_data(tokenizer):
    ''' Function to load training and test data '''
    train_input_ids = []
    valid_input_ids = []
    test_input_ids = []

    for sent in X_train:
        encoded_sent = tokenizer.encode(
            sent,  # Sentence to encode.
            add_special_tokens=True,  # Add '[CLS]' and '[SEP]'
            max_length=512,)
        train_input_ids.append(encoded_sent)

    for sent in X_valid:
        encoded_sent = tokenizer.encode(
            sent,  # Sentence to encode.
            add_special_tokens=True,  # Add '[CLS]' and '[SEP]'
            max_length=512,)

        valid_input_ids.append(encoded_sent)

    for sent in X_test:
        encoded_sent = tokenizer.encode(
            sent,  # Sentence to encode.
            add_special_tokens=True,  # Add '[CLS]' and '[SEP]'
            max_length=512,)

        test_input_ids.append(encoded_sent)

    train_input_ids = pad_sequences(train_input_ids, maxlen=100, dtype="long",
                          value=tokenizer.pad_token_id, truncating="pre", padding="pre")
    
    valid_input_ids = pad_sequences(valid_input_ids, maxlen=100, dtype="long",
                          value=tokenizer.pad_token_id, truncating="pre", padding="pre")

    test_input_ids = pad_sequences(test_input_ids, maxlen=100, dtype="long",
                                    value=tokenizer.pad_token_id, truncating="pre", padding="pre")
    train_attention_masks = []
    valid_attention_masks = []
    test_attention_masks = []

    for sent in train_input_ids:
        att_mask = [int(token_id > 0) for token_id in sent]

        train_attention_masks.append(att_mask)

    for sent in valid_input_ids:
        att_mask = [int(token_id > 0) for token_id in sent]

        valid_attention_masks.append(att_mask)

    for sent in test_input_ids:
        att_mask = [int(token_id > 0) for token_id in sent]

        test_attention_masks.append(att_mask)

    return train_input_ids, train_labels, train_attention_masks, valid_input_ids, valid_labels, valid_attention_masks, test_input_ids, test_labels, test_attention_masks

def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

config_class, model_class, tokenizer_class = (BertConfig, BertForSequenceClassification, BertTokenizer)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if eval_task == 'binary':
    n_labels = 2 # The number of output labels is 2 for binary classification.
else: 
    n_labels = len(y_train[0])

model_name = "bert-base-uncased"

tokenizer = tokenizer_class.from_pretrained(model_name)

model = BertForSequenceClassification.from_pretrained(
     model_name, 
     num_labels = n_labels, 
     output_attentions = False,
     output_hidden_states = False,
)

model.to(device)

train_inputs, train_labels, train_masks, valid_inputs, valid_labels, valid_masks, test_inputs, test_labels, test_masks = load_train_test_data(tokenizer)

train_inputs = torch.tensor(train_inputs)
validation_inputs = torch.tensor(valid_inputs)
test_inputs = torch.tensor(test_inputs)

train_labels = torch.tensor(train_labels)
validation_labels = torch.tensor(valid_labels)
test_labels = torch.tensor(test_labels)

train_masks = torch.tensor(train_masks)
validation_masks = torch.tensor(valid_masks)
test_masks = torch.tensor(test_masks)

batch_size = 32

train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)

test_data = TensorDataset(test_inputs, test_masks, test_labels)
test_sampler = SequentialSampler(test_data)
test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=batch_size)

optimizer = AdamW(model.parameters(),
                  lr=1e-5,
                  eps=1e-8)

epochs = 3

total_steps = len(train_dataloader) * epochs

scheduler = get_linear_schedule_with_warmup(optimizer,
                                            num_warmup_steps=0,  # Default value in run_glue.py
                                            num_training_steps=total_steps)

for epoch_i in tqdm(range(0, epochs), desc="Training"):

    gold_labels = []
    predicted_labels = []

    total_loss = 0

    model.train()

    for step, batch in tqdm(enumerate(train_dataloader), desc="Batch"):

        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        if eval_task == 'binary':
            b_labels = batch[2].to(device)
        else:
            b_labels = batch[2].float().to(device)

        model.zero_grad()

        outputs = model(b_input_ids,
                        token_type_ids=None,
                        attention_mask=b_input_mask,
                        labels=b_labels)
        
        if eval_task == 'binary':
            loss = outputs[0]
        else: # masked loss function to account for missing labels (-1s)
            loss_func = BCEWithLogitsLoss(reduction='none') 
            logits = outputs['logits']
            loss = loss_func(logits,b_labels.type_as(logits)) #convert labels to float for calculation
            mask = torch.cuda.FloatTensor([[1 if l != -1 else 0 for l in y] for y in b_labels])
            loss = torch.mean(loss*mask)

        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()

        scheduler.step()

    model.eval()

    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0

    for batch in validation_dataloader:
        batch = tuple(t.to(device) for t in batch)

        b_input_ids, b_input_mask, b_labels = batch

        with torch.no_grad():
            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask)

        if eval_task == 'binary':
            logits = outputs[0]
        else:
            logits = outputs['logits']
      
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()

        tmp_eval_accuracy = flat_accuracy(logits, label_ids)

        eval_accuracy += tmp_eval_accuracy

        nb_eval_steps += 1

        pred_flat = np.argmax(logits, axis=1).flatten()
        labels_flat = label_ids.flatten()
        
        gold_labels.extend(labels_flat)
        predicted_labels.extend(pred_flat) 

def evaluate(model, test_loader):
    y_pred = []
    y_true = []

    model.eval()

    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0

    # Evaluate data for one epoch
    for batch in test_loader:

        batch = tuple(t.to(device) for t in batch)

        b_input_ids, b_input_mask, b_labels = batch

        with torch.no_grad():
            outputs = model(b_input_ids,
                            token_type_ids=None,
                            attention_mask=b_input_mask)
           
        logits = outputs[0]

        if eval_task == 'binary':
            logits = logits.detach().cpu().numpy()
            label_ids = b_labels.to('cpu').numpy()
            tmp_eval_accuracy = flat_accuracy(logits, label_ids)
            eval_accuracy += tmp_eval_accuracy
            nb_eval_steps += 1
            pred_flat = np.argmax(logits, axis=1).flatten()
            labels_flat = label_ids.flatten()
            y_true.extend(labels_flat)
            y_pred.extend(pred_flat) 
        else:
            b_logit_pred = outputs[0]
            pred_label = (torch.sigmoid(b_logit_pred)>0.5).float()
            pred_label = pred_label.detach().cpu().numpy()
            logits = logits.detach().cpu().numpy()
            label_ids = b_labels.to('cpu').numpy()
            pred_list = [list(l) for l in pred_label]
            labels_list = [list(l) for l in label_ids]
            y_true.extend(labels_list)      
            y_pred.extend(pred_list)

    if eval_task == 'binary':
        results = [accuracy_score(y_test, y_pred), precision_score(y_test, y_pred), recall_score(y_test, y_pred), \
        f1_score(y_test, y_pred, average='macro'), f1_score(y_test, y_pred, average='micro'), macro_averaged_mean_absolute_error(y_test, y_pred)]
    elif eval_task == 'ternary':
        results = ternary_results(y_true, y_pred)
    else:
        results = ordinal_results(y_true,y_pred) 
        all_y_trues.extend(y_true)
        all_y_preds.extend(y_pred)

    return results

print('\nResults score for', eval_task)
metrics = ['Accuracy','Precision', 'Recall', 'F1 macro', 'F1 micro', 'MSE', 'Macro MSE']
results = evaluate(model, test_dataloader)
for i in range(7):
    print(metrics[i], results[i])

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at


Results score for ordinal
Accuracy Precision Recall F1 macro F1 micro MSE Macro MSE
[0.8801369863013698, 0.4410639695874595, 0.4725081699346405, 0.4539586054455943, 0.8801369863013698, 0.16095890410958905, 0.7774918300653595]


  _warn_prf(average, modifier, msg_start, len(result))
