In [None]:
import gc
import os
import ssl

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score, roc_curve, auc, precision_recall_curve
from sklearn.model_selection import train_test_split
from torch.nn import CrossEntropyLoss
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BertTokenizerFast, BertForSequenceClassification, Trainer, TrainingArguments, MarianMTModel, MarianTokenizer

In [None]:
# Point clients that honour SSL_CERT_FILE at OpenSSL's default CA bundle
# (presumably needed for HTTPS model/tokenizer access in this environment — confirm).
os.environ['SSL_CERT_FILE'] = ssl.get_default_verify_paths().openssl_cafile


def _standardize_split(frame, text_col, label_col='Misdiag'):
    """Return a copy of `frame` with `text_col` renamed to 'text', a fresh
    0..n-1 index, and `label_col` cast to int (no in-place mutation)."""
    out = frame.rename(columns={text_col: 'text'}).reset_index(drop=True)
    out[label_col] = out[label_col].astype(int)
    return out


# NOTE(review): absolute local paths (G:/...) tie this notebook to one machine;
# consider a configurable DATA_DIR.
# Labelled complaints (column '主诉内容') used for training and validation.
df = pd.read_csv('G:/labeled_misdiagnosis.csv')
train_validation = df[['主诉内容', 'Misdiag']]

# Held-out test set: the union of two tab-separated files (text column 'Complain').
test_data1 = pd.read_csv('./data/pretrained/train.txt', sep='\t')
test_data2 = pd.read_csv('./data/pretrained/test.txt', sep='\t')
test = pd.concat([test_data1, test_data2], ignore_index=True)

# 80/20 train/validation split with a fixed seed for reproducibility.
train, validation = train_test_split(train_validation, test_size=0.20, random_state=42)

# Give every split the same schema: 'text' + integer 'Misdiag'.
train = _standardize_split(train, '主诉内容')
validation = _standardize_split(validation, '主诉内容')
test = _standardize_split(test, 'Complain')

# Positive (misdiagnosed) training rows: the pool used for back-translation augmentation.
misdiag_data = train[train['Misdiag'] == 1]

In [None]:
def _load_marian_pair(model_dir, tokenizer_dir):
    """Load a MarianMT model and its tokenizer (model first) from local directories."""
    return MarianMTModel.from_pretrained(model_dir), MarianTokenizer.from_pretrained(tokenizer_dir)


# Outbound leg of the back-translation round trip: Chinese -> English.
target_model_name = './data_augumentation/chinese_to_english'
target_tokenizer_name = target_model_name
target_model, target_tokenizer = _load_marian_pair(target_model_name, target_tokenizer_name)

# Return leg: English -> Chinese.
output_model_name = './data_augumentation/english_to_chinese'
output_tokenizer_name = output_model_name
output_model, output_tokenizer = _load_marian_pair(output_model_name, output_tokenizer_name)

In [None]:
def translate(text, model, tokenizer, language="zh"):
    """Translate `text` (a string or list of strings) with a seq2seq model.

    The translation direction is fixed by the `model`/`tokenizer` pair; the
    `language` argument is accepted but not used.

    Returns:
        list of decoded strings, special tokens stripped.
    """
    batch = tokenizer(text, truncation=True, padding='longest', return_tensors="pt")
    token_ids = model.generate(**batch)
    return tokenizer.batch_decode(token_ids, skip_special_tokens=True)

def back_translate(texts, source_lang="en", target_lang="fr"):
    """Paraphrase `texts` by translating them to a pivot language and back.

    Uses the module-level pairs `target_model`/`target_tokenizer` for the
    outbound leg and `output_model`/`output_tokenizer` for the return leg.
    `source_lang`/`target_lang` are only forwarded to `translate`'s `language`
    argument; the actual language pair is whatever those checkpoints were
    trained for (judging by their directory names, zh -> en -> zh — confirm).

    Returns:
        list of back-translated strings.
    """
    pivot_texts = translate(texts, target_model, target_tokenizer, language=target_lang)
    return translate(pivot_texts, output_model, output_tokenizer, language=source_lang)

In [None]:
chinese_texts = misdiag_data['text'].tolist()
# Back-translate each positive complaint individually. Each call returns a list
# (presumably one sentence per input), so augumentated_texts is a list of lists.
augumentated_texts = []
for source_text in tqdm(chinese_texts, ncols=100):
    augumentated_texts.append(back_translate(source_text, source_lang="zh", target_lang="en"))
    gc.collect()  # release memory between generate() calls

In [None]:
# Concatenate the per-text result lists into one flat list of strings.
flat_list = []
for translated_batch in augumentated_texts:
    flat_list.extend(translated_batch)

In [None]:
# Every back-translated sentence derives from a misdiagnosed complaint, so it keeps label 1.
new_data = pd.DataFrame(
    dict(text=flat_list, Misdiag=[1 for _ in flat_list])
)

In [None]:
# Stack the augmented positives under the original training rows.
# DataFrame.append was deprecated in pandas 1.4 and removed in 2.0; pd.concat is the
# supported equivalent (ignore_index=True renumbers rows 0..n-1 as before).
# NOTE(review): re-running this cell appends the augmented rows a second time.
train = pd.concat([train, new_data], ignore_index=True)

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset yielding tokenized text plus its label.

    `data` must be a DataFrame with a 'text' column and a 'Misdiag' label
    column; rows are addressed positionally (via `.iloc`). Each item is a
    dict of 1-D tensors from the tokenizer plus a 'labels' tensor.
    """

    def __init__(self, data, tokenizer, max_length=128):
        self.data, self.tokenizer, self.max_length = data, tokenizer, max_length

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        encoded = self.tokenizer(
            record['text'],
            return_tensors='pt',
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
        )
        # The tokenizer returns a batch of one; drop that leading axis so the
        # collator can stack items into a batch itself.
        sample = {}
        for name, tensor in encoded.items():
            sample[name] = tensor.squeeze(0)
        sample['labels'] = torch.tensor(record['Misdiag'])
        return sample

In [None]:
# Fine-tune from a locally stored BERT checkpoint with a 2-way classification head.
model_name = './code/bert_pretrained3'
tokenizer = BertTokenizerFast.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)

# Wrap each split so items come out tokenized with the shared tokenizer.
train_dataset, validation_dataset, test_dataset = (
    CustomDataset(split, tokenizer) for split in (train, validation, test)
)

In [None]:
def collate_fn(batch):
    """Stack a list of per-example tensor dicts into one dict of batched tensors.

    Keys are taken from the first example; every example must provide them.
    NOTE(review): the Trainer in this notebook is not given `data_collator=`,
    so this function appears to be unused.
    """
    keys = list(batch[0])
    return {key: torch.stack([example[key] for example in batch]) for key in keys}
# Fine-tuning hyper-parameters; checkpoints and logs go under ./code/bert_pretrained/.
training_args = TrainingArguments(
    output_dir='./code/bert_pretrained/results',
    logging_dir='./code/bert_pretrained/logs',
    num_train_epochs=5,
    learning_rate=1e-4,
    warmup_steps=500,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    # NOTE(review): newer transformers releases rename this to `eval_strategy`;
    # keep it in sync with the installed version.
    evaluation_strategy="steps",
)
def compute_metrics(eval_pred):
    """Compute classification metrics from model logits.

    Args:
        eval_pred: object exposing `.predictions` (logits; indexed below as
            (n_examples, n_classes) — presumably a float NumPy array) and
            `.label_ids` (gold labels). Works for the Trainer's evaluation
            callback and for the output of `Trainer.predict`.

    Returns:
        dict with accuracy, weighted precision/recall/F1, and AUROC/AUPRC based on
        the probability of class 1. AUROC is NaN when the labels contain only one
        class, where it is undefined.
    """
    logits, true_labels = eval_pred.predictions, eval_pred.label_ids

    # Softmax over the class axis turns logits into probabilities.
    probabilities = torch.nn.functional.softmax(torch.from_numpy(logits), dim=-1).numpy()
    positive_scores = probabilities[:, 1]
    # Hard predictions are computed once and shared by the threshold metrics.
    predicted_labels = np.argmax(logits, axis=-1)

    accuracy = accuracy_score(true_labels, predicted_labels)
    # zero_division=1: a class that is never predicted scores precision 1.
    precision = precision_score(true_labels, predicted_labels, average='weighted', zero_division=1)
    # NOTE: support-weighted recall is mathematically equal to accuracy.
    recall = recall_score(true_labels, predicted_labels, average='weighted')
    f1 = f1_score(true_labels, predicted_labels, average='weighted')

    # roc_auc_score raises ValueError when only one class is present, which would
    # abort evaluation during training; report NaN instead.
    if np.unique(true_labels).size < 2:
        auroc = float('nan')
    else:
        auroc = roc_auc_score(true_labels, positive_scores)
    # `average` is ignored by average_precision_score for binary labels, so it is omitted.
    auprc = average_precision_score(true_labels, positive_scores)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auroc': auroc,
        'auprc': auprc,
    }
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer.train()

# Score the held-out test set with the weights held in memory after training.
test_results = trainer.predict(test_dataset)
test_metrics = compute_metrics(test_results)
print("Test Set Metrics:")
for metric_name, metric_value in test_metrics.items():
    print(f"{metric_name}: {metric_value}")

In [None]:
# Reload a saved fine-tuning checkpoint and score the test set with it.
# NOTE(review): this checkpoint lives under ./code/bert_pretrained3/results, whereas the
# TrainingArguments in this notebook write to ./code/bert_pretrained/results — confirm
# which run is intended.
model = BertForSequenceClassification.from_pretrained('./code/bert_pretrained3/results/checkpoint-7500')
# Predict through a Trainer that wraps the reloaded model; calling `trainer.predict`
# would instead score the in-memory weights of the training run and ignore this checkpoint.
checkpoint_trainer = Trainer(model=model, args=training_args, tokenizer=tokenizer)
predictions = checkpoint_trainer.predict(test_dataset)
# Convert logits to class probabilities.
probabilities = torch.nn.functional.softmax(torch.from_numpy(predictions.predictions), dim=-1).numpy()
# Distribution of P(Misdiag = 1) over the test set.
plt.hist(probabilities[:, 1], bins=10)
plt.title('Predicted probability of misdiagnosis (test set)')
plt.xlabel('P(Misdiag = 1)')
plt.ylabel('Count')
plt.show()

In [None]:
# Binarise P(Misdiag = 1) at a fixed cut-off.
threshold = 0.5
positive_scores = probabilities[:, 1]
predicted_labels = (positive_scores >= threshold).astype(int)

# Score the thresholded predictions against the gold test labels.
# NOTE(review): support-weighted recall always equals accuracy, and weighted averages can
# hide weak minority-class performance; consider also reporting positive-class metrics.
y_true = test['Misdiag']
accuracy = accuracy_score(y_true, predicted_labels)
precision = precision_score(y_true, predicted_labels, average='weighted', zero_division=1)
recall = recall_score(y_true, predicted_labels, average='weighted')
f1 = f1_score(y_true, predicted_labels, average='weighted')
auroc = roc_auc_score(y_true, positive_scores)
auprc = average_precision_score(y_true, positive_scores, average='weighted')

for metric_label, metric_value in (
    ("Accuracy", accuracy),
    ("Precision", precision),
    ("Recall", recall),
    ("F1 Score", f1),
    ("AUROC", auroc),
    ("AUPRC", auprc),
):
    print(f"{metric_label}: {metric_value}")

In [None]:
# Persist the thresholded test-set predictions, one row per test example.
predictions_frame = pd.DataFrame(predicted_labels)
predictions_frame.to_csv('G:/bert_data_augumentation_free_text.csv')

In [None]:
# ROC and precision-recall curves for P(Misdiag = 1) on the test set.
# Curve arrays use their own names so the scalar `precision` / `recall` metrics stay intact.
curve_precision, curve_recall, _ = precision_recall_curve(test['Misdiag'], probabilities[:, 1])
# Trapezoidal area under the PR curve; this can differ slightly from average_precision_score.
pr_auc = auc(curve_recall, curve_precision)
fpr, tpr, _ = roc_curve(test['Misdiag'], probabilities[:, 1])
roc_auc = auc(fpr, tpr)

lw = 2  # line width shared by every curve
fig, (roc_ax, pr_ax) = plt.subplots(1, 2)

# ROC curve
roc_ax.plot(fpr, tpr, color='darkorange', lw=lw, label='ROC curve (area = %0.2f)' % roc_auc)
roc_ax.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')  # chance-level diagonal
roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel('False Positive Rate')
roc_ax.set_ylabel('True Positive Rate')
roc_ax.set_title('Receiver Operating Characteristic')
roc_ax.legend(loc="lower right")

# Precision-Recall curve
pr_ax.plot(curve_recall, curve_precision, color='blue', lw=lw, label='PR curve (area = %0.2f)' % pr_auc)
pr_ax.set_xlim([0.0, 1.0])
pr_ax.set_ylim([0.0, 1.05])
pr_ax.set_xlabel('Recall')
pr_ax.set_ylabel('Precision')
pr_ax.set_title('Precision-Recall Curve')
pr_ax.legend(loc="lower left")

fig.tight_layout()
plt.show()

In [None]:
# Persist every split. If the augmentation cells have run, `train` also holds the
# back-translated positive rows.
for split_frame, csv_path in (
    (train, 'G:/train.csv'),
    (validation, 'G:/validation.csv'),
    (test, 'G:/test.csv'),
):
    split_frame.to_csv(csv_path)