In [None]:
import os
import ssl
import torch
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from transformers import BertTokenizerFast, BertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, average_precision_score

In [None]:
from sklearn.model_selection import train_test_split

# Export the path of OpenSSL's default CA bundle file through SSL_CERT_FILE.
os.environ['SSL_CERT_FILE'] = ssl.get_default_verify_paths().openssl_cafile

# Labelled records: keep only the chief-complaint column (主诉内容) and the label.
df = pd.read_csv('G:/labeled_misdiagnosis.csv')
train_validation = df[['主诉内容', 'Misdiag']]

# The test set is the row-wise concatenation of two tab-separated files.
test_data1 = pd.read_csv('./code/data/pretrained/train.txt', sep='\t')
test_data2 = pd.read_csv('./code/data/pretrained/test.txt', sep='\t')
test = pd.concat([test_data1, test_data2], ignore_index=True)

# 80/20 train/validation split with a fixed seed.
train, validation = train_test_split(train_validation, test_size=0.20, random_state=42)

# Every split ends up with a 'text' column, a fresh 0..n-1 index and an integer label.
train = (
    train.rename(columns={'主诉内容': 'text'})
    .reset_index(drop=True)
    .astype({'Misdiag': int})
)
validation = (
    validation.rename(columns={'主诉内容': 'text'})
    .reset_index(drop=True)
    .astype({'Misdiag': int})
)
# The test files name their text column 'Complain' instead.
test = (
    test.rename(columns={'Complain': 'text'})
    .reset_index(drop=True)
    .astype({'Misdiag': int})
)

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset that tokenizes one DataFrame row per item.

    Args:
        data: DataFrame with a ``text`` column (input string) and a
            ``Misdiag`` column (class label convertible to ``int``).
        tokenizer: Callable invoked as
            ``tokenizer(text, return_tensors='pt', max_length=..., padding='max_length', truncation=True)``;
            it must return a mapping of tensors whose leading batch
            dimension can be squeezed away.
        max_length: Length forwarded to the tokenizer for padding/truncation.
    """

    def __init__(self, data, tokenizer, max_length=128):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the wrapped DataFrame."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize row ``idx`` and return its tensors plus a ``labels`` entry."""
        row = self.data.iloc[idx]

        text = row['text']
        if not isinstance(text, str):
            # Missing cells (None/NaN) become an empty string and other scalars are
            # stringified, so the tokenizer always receives a str.
            text = '' if pd.isna(text) else str(text)

        encoded = self.tokenizer(
            text,
            return_tensors='pt',
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
        )
        # Remove the batch dimension added by return_tensors='pt'.
        item = {key: value.squeeze(0) for key, value in encoded.items()}
        # Force an integer (long) label regardless of the column's dtype.
        item['labels'] = torch.tensor(int(row['Misdiag']), dtype=torch.long)
        return item

In [None]:
# Tokenizer and classifier are both loaded from the same local checkpoint directory.
model_name = './code/bert_pretrained'
tokenizer = BertTokenizerFast.from_pretrained(model_name)
# Two output labels: one per value of the binary 'Misdiag' target.
model = BertForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
)

In [None]:
# Wrap each split with the shared tokenizer (default max_length).
train_dataset, validation_dataset, test_dataset = (
    CustomDataset(split_frame, tokenizer)
    for split_frame in (train, validation, test)
)

In [None]:
def collate_fn(batch):
    """Stack a list of per-sample tensor dicts into one dict of batched tensors.

    The field names are taken from the first sample; every sample must
    provide a tensor under each of those names.
    """
    first_sample = batch[0]
    stacked = {}
    for field in first_sample.keys():
        stacked[field] = torch.stack([sample[field] for sample in batch])
    return stacked
training_args = TrainingArguments(
    # Output locations for results and logs.
    output_dir='./code/bert_pretrained1/results',
    logging_dir='./code/bert_pretrained1/logs',
    # Schedule and batch sizes.
    num_train_epochs=5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    evaluation_strategy="steps",
    # Optimisation hyper-parameters.
    learning_rate=1e-4,
    warmup_steps=500,
    weight_decay=0.01,
)
def compute_metrics(eval_pred):
    """Compute classification metrics from raw logits and reference labels.

    Args:
        eval_pred: Object exposing ``predictions`` (logits, one column per
            class; a tuple is accepted and its first element is used) and
            ``label_ids`` (integer class labels).

    Returns:
        dict with ``accuracy``, weighted ``precision``/``recall``/``f1``,
        ``auroc`` and ``auprc``. The last two are computed from the softmax
        probability of class 1. ``auroc`` is NaN when the labels contain
        fewer than two distinct classes, because ROC AUC is undefined then.
    """
    predictions, true_labels = eval_pred.predictions, eval_pred.label_ids
    if isinstance(predictions, tuple):
        # Some models return extra outputs next to the logits; keep the logits.
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    true_labels = np.asarray(true_labels)

    # convert logits to probabilities
    probabilities = torch.nn.functional.softmax(torch.from_numpy(predictions), dim=-1).numpy()
    positive_scores = probabilities[:, 1]

    # Hard class decisions, computed once and shared by all threshold metrics.
    predicted_labels = np.argmax(predictions, axis=-1)

    accuracy = accuracy_score(true_labels, predicted_labels)
    precision = precision_score(true_labels, predicted_labels, average='weighted', zero_division=1)
    recall = recall_score(true_labels, predicted_labels, average='weighted')
    f1 = f1_score(true_labels, predicted_labels, average='weighted')

    # Guard against single-class label sets so one degenerate evaluation
    # batch cannot abort training with an exception from roc_auc_score.
    if np.unique(true_labels).size < 2:
        auroc = float('nan')
    else:
        auroc = roc_auc_score(true_labels, positive_scores)
    auprc = average_precision_score(true_labels, positive_scores, average='weighted')

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'auroc': auroc,
        'auprc': auprc,
    }
# Fine-tune on the training split, evaluating on the validation split.
# No data_collator is passed here.
trainer = Trainer(
    model=model,
    args=training_args,
    compute_metrics=compute_metrics,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
)
trainer.train()

# Score the held-out test set with the same metric function used during training.
test_results = trainer.predict(test_dataset)
test_metrics = compute_metrics(test_results)

print("Test Set Metrics:")
for metric_name in test_metrics:
    print(f"{metric_name}: {test_metrics[metric_name]}")