### Summary notes
- Strong logistic regression benchmarks for "Relevant" predictions (>80% accuracy offline). It outperforms XGBoost and BERT [run in Google Colab], though none of the models had extensive hyperparameter tuning.
- Could improve further by tidying up the input labels (e.g. removing all vision / general neural network articles).
- Should also implement full cross-validation to confirm performance is consistent across samples (random tests suggest it does meaningfully differ).
- Adding query and source to the title feature, both as embeddings and as dummy variables, improves performance only very marginally (so they are left out of the first production version).
- In production, a logistic regression model currently classifies whether an article is relevant (Y or N), using word embeddings of the article title as the input feature.
- Predicting "Read" is harder because it covers <5% of cases and the rationale is probably inconsistent (a basic logistic regression model just predicts everything as not read), so it is more of an outlier-detection problem (may need a different approach).
- Deployment was difficult under the file size constraints of basic (free) AWS Lambda file uploads, which required uploading all Python dependencies (huge files for sentence transformers / embeddings). This was replaced with an AWS ECR deployment approach (details in README.md), which incurs a small hosting cost (~1 USD / month).
- Production performance of the v1 model up to 6th Dec 2023: it is better at identifying truly relevant articles (true positives, measured by sensitivity) than truly non-relevant articles (true negatives, measured by specificity). This is the preferred weighting, as false positives are preferable to false negatives.
    - Accuracy: 0.8700389105058366
    - Sensitivity: 0.9118303571428571
    - Specificity: 0.8476702508960573
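One possible lever for the stated preference (false positives over false negatives) is the decision threshold: `LogisticRegression.predict` effectively applies a 0.5 cut-off to `predict_proba`, and lowering that cut-off trades specificity for sensitivity. The sketch below is illustrative only; the helper name `sens_spec_at_threshold` and the toy probabilities are hypothetical, not project data.

```python
import numpy as np
from sklearn.metrics import confusion_matrix


def sens_spec_at_threshold(y_true, y_prob, threshold):
    """Return (sensitivity, specificity) when predicting 1 for prob >= threshold."""
    y_pred = (np.asarray(y_prob) >= threshold).astype(int)
    # labels=[0, 1] keeps the matrix 2x2 even if only one class is predicted
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    sensitivity = tp / (tp + fn) if (tp + fn) else float("nan")
    specificity = tn / (tn + fp) if (tn + fp) else float("nan")
    return sensitivity, specificity


# Toy labels and predicted probabilities (hypothetical, for illustration only)
y_true = np.array([1, 1, 1, 0, 0, 0, 0, 1])
y_prob = np.array([0.9, 0.6, 0.35, 0.4, 0.2, 0.55, 0.1, 0.45])

for t in (0.5, 0.3):
    sens, spec = sens_spec_at_threshold(y_true, y_prob, t)
    print(f"threshold={t}: sensitivity={sens:.2f}, specificity={spec:.2f}")
```

On this toy data, dropping the threshold from 0.5 to 0.3 raises sensitivity and lowers specificity; any real threshold would need to be chosen on held-out data.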

#### Next steps
- Review model performance again on new queries
- Add back in query as a feature (vectors and dummy vars) and conduct a more granular performance review by query
- Extend multi-feature approach to BERT + explore more tuning to improve performance
- Research outlier-detection techniques or class weighting to improve "Read" prediction
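For the class-weighting option, a minimal sketch assuming scikit-learn's built-in `class_weight="balanced"` is a reasonable first experiment: it reweights each class by `n_samples / (n_classes * class_count)`, so a ~5% "Read" class is upweighted roughly tenfold. The data below is synthetic and only mimics the imbalance; it is not the project dataset.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import recall_score
from sklearn.utils.class_weight import compute_class_weight

rng = np.random.default_rng(0)
# Synthetic, illustrative data: 95 "not read" rows and 5 "read" rows (~5% positives)
X_neg = rng.normal(0.0, 1.0, size=(95, 2))
X_pos = rng.normal(3.0, 1.0, size=(5, 2))
X_demo = np.vstack([X_neg, X_pos])
y_demo = np.array([0] * 95 + [1] * 5)

# "balanced" weights each class by n_samples / (n_classes * class_count)
weights = compute_class_weight(class_weight="balanced", classes=np.array([0, 1]), y=y_demo)
print(dict(zip([0, 1], weights.round(3))))

plain = LogisticRegression().fit(X_demo, y_demo)
balanced = LogisticRegression(class_weight="balanced").fit(X_demo, y_demo)

# Recall on the rare class (training data only, so optimistic)
print("plain recall (train):", recall_score(y_demo, plain.predict(X_demo)))
print("balanced recall (train):", recall_score(y_demo, balanced.predict(X_demo)))
```

Whether this helps on the real "Read" labels is untested; if the rationale behind "Read" is inconsistent, reweighting may mainly add false positives.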

### General setup

In [None]:
#packages
from datetime import datetime
import numpy as np
import pandas as pd
import pickle
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split, cross_val_predict, StratifiedKFold, cross_validate
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, roc_auc_score, recall_score, precision_score
import torch
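from transformers import BertTokenizer, BertForSequenceClassification
from torch.optim import AdamW  # transformers' own AdamW is deprecated; torch's AdamW is the standard replacement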
from torch.utils.data import DataLoader, Dataset
from torch.nn import functional as F
import xgboost as xgb

In [None]:
#helper functions (could move)

def convert_outcome_to_bin(x):
    if x == "Y":
        return 1
    else: 
        return 0

def evaluate(target, prediction, prediction_prob):
    # Calculate the accuracy of the model
    accuracy = accuracy_score(target, prediction)

    # Calculate the AUC score
    auc_score = roc_auc_score(target, prediction_prob)

    # Calculate the confusion matrix (labels fixed so it is always 2x2, rows = true 0/1, cols = predicted 0/1)
    conf_matrix = confusion_matrix(target, prediction, labels=[0, 1])

    # Calculate sensitivity and specificity
    true_negative = conf_matrix[0, 0]
    false_positive = conf_matrix[0, 1]
    false_negative = conf_matrix[1, 0]
    true_positive = conf_matrix[1, 1]

    sensitivity = true_positive / (true_positive + false_negative)
    specificity = true_negative / (true_negative + false_positive)

    print(f"Accuracy: {accuracy:.2f}")
    print(f"AUC: {auc_score:.2f}")
    print(f"Sensitivity: {sensitivity:.2f}")
    print(f"Specificity: {specificity:.2f}")

    # Print the confusion matrix
    print("Confusion Matrix:")
    print(conf_matrix)

#TODO: abstract to work for any target outcome (currently hard-coded to the relevant columns)
def flag_pred_error(x):
    target = x["relevant_true"]
    pred = x["relevant_pred"]
    
    if pred == target:
        return 0
    else:
        return 1
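As a cross-check on `evaluate`, sensitivity and specificity can also be read off `recall_score` (already imported above but unused): sensitivity is recall of the positive class and specificity is recall of the negative class (`pos_label=0`). The labels below are toy values for illustration.

```python
from sklearn.metrics import recall_score

# Toy labels (illustrative only): 4 positives, 5 negatives
y_true_demo = [1, 1, 1, 1, 0, 0, 0, 0, 0]
y_pred_demo = [1, 1, 1, 0, 0, 0, 0, 0, 1]

# Sensitivity = TP / (TP + FN) = recall of class 1
sensitivity_demo = recall_score(y_true_demo, y_pred_demo, pos_label=1)
# Specificity = TN / (TN + FP) = recall of class 0
specificity_demo = recall_score(y_true_demo, y_pred_demo, pos_label=0)

print(f"Sensitivity: {sensitivity_demo:.2f}")
print(f"Specificity: {specificity_demo:.2f}")
```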

In [None]:
#load data
# input_agent_file = "paper_agent_list_260723.csv"
input_agent_file = "paper_agent_list_061223.csv"
agent_list = pd.read_csv(input_agent_file)
agent_list

In [None]:
#tidy abstract so it can be used as an input feature (missing abstracts become empty strings rather than "nan")
agent_list["Abstract_clean"] = agent_list["Abstract"].fillna("").astype(str)

In [None]:
#convert target outcomes to binary features

agent_list["relevant_bin"] = agent_list["Relevant?"].apply(convert_outcome_to_bin)
print(agent_list["relevant_bin"].value_counts())

agent_list["read_bin"] = agent_list["Read?"].apply(convert_outcome_to_bin)
agent_list["read_bin"].value_counts()

## Reviewing production models

In [None]:
## raw performance metrics

#filter to only data where predictions were made
agent_list_relevant_pred_df = agent_list[pd.notna(agent_list["Relevant_pred"])]

#quick accuracy calc
total_preds = len(agent_list_relevant_pred_df)
agrees = np.sum(agent_list_relevant_pred_df["Relevant_pred"] == agent_list_relevant_pred_df["Relevant?"])
print("Accuracy: ", agrees / total_preds)

# Create a confusion matrix
conf_matrix = confusion_matrix(agent_list_relevant_pred_df['Relevant?'], agent_list_relevant_pred_df['Relevant_pred'], labels=['Y', 'N'])

# Extract values from the confusion matrix
true_positives = conf_matrix[0, 0]
false_negatives = conf_matrix[0, 1]
true_negatives = conf_matrix[1, 1]
false_positives = conf_matrix[1, 0]

# Calculate sensitivity (true positive rate)
sensitivity = true_positives / (true_positives + false_negatives)

# Calculate specificity (true negative rate)
specificity = true_negatives / (true_negatives + false_positives)

print("Sensitivity:", sensitivity)
print("Specificity:", specificity)

In [None]:
## reviewing disagreements

pred_disagreements = agent_list_relevant_pred_df[agent_list_relevant_pred_df["Relevant_pred"] != agent_list_relevant_pred_df["Relevant?"]]


#visual evaluation if you want to explore more specific performance details
#could also calculate metrics by query and by source
# for idx, row in pred_disagreements.iterrows():
#     print(row["Title"] + " > " + row["Query"])
#     print("\n")
# pred_disagreements[["Title", "Source", "Query", "Relevant_pred", "Relevant?", "Read?"]]
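For the per-query breakdown mentioned in the comment above, a minimal sketch with `groupby` and named aggregation could look like the following. The toy frame reuses the notebook's column names (`Query`, `Relevant?`, `Relevant_pred`) but its values are invented for illustration; in the notebook the same code could run on `agent_list_relevant_pred_df`.

```python
import pandas as pd

# Toy frame mimicking the production columns (values are illustrative only)
df_demo = pd.DataFrame({
    "Query": ["menopause symptoms", "menopause symptoms", "ChatGPT for healthcare",
              "ChatGPT for healthcare", "ChatGPT for healthcare"],
    "Relevant?": ["Y", "N", "Y", "Y", "N"],
    "Relevant_pred": ["Y", "Y", "Y", "N", "N"],
})

# True where the production prediction agrees with the manual label
df_demo["correct"] = df_demo["Relevant_pred"] == df_demo["Relevant?"]

# One row per query: number of predictions and accuracy, worst queries first
per_query = (
    df_demo.groupby("Query")
           .agg(n_preds=("correct", "size"), accuracy=("correct", "mean"))
           .sort_values("accuracy")
)
print(per_query)
```

Swapping `"Query"` for `"Source"` gives the per-source view.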

## Training offline models

### Word count matrix as feature

In [None]:
#prep data

input_label = "Title"
# input_label = "Abstract_clean" 
    #marginally more accurate (much better sensitivity, i.e. misses fewer relevant articles); an optimal model should aim to include title, abstract and original query.
    #Interestingly, exactly the same number of errors (43), which suggests labeling inconsistency?
outcome_label = "relevant_bin"
# outcome_label = "read_bin"
# Split the dataset into features (X) and labels (y)
X = agent_list[input_label].values
y = agent_list[outcome_label].values

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Initialize CountVectorizer to convert text data into numerical features
vectorizer = CountVectorizer(stop_words='english')

# Fit and transform the training data
X_train_vectorized = vectorizer.fit_transform(X_train)

# Transform the testing data using the same vectorizer
X_test_vectorized = vectorizer.transform(X_test)
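The split above is not stratified, which matters most for the rare "Read" outcome (`read_bin`): a random 20% test set can end up with very few positives. A minimal sketch, assuming the only change is passing `stratify=y` to `train_test_split`; the arrays here are synthetic and suffixed `_demo` so they do not overwrite the notebook's variables.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Illustrative imbalanced labels: 95 negatives, 5 positives (~5%, similar to "Read")
X_demo = np.arange(100).reshape(-1, 1)
y_demo = np.array([0] * 95 + [1] * 5)

# stratify keeps the class proportions (approximately) equal in train and test
X_train_demo, X_test_demo, y_train_demo, y_test_demo = train_test_split(
    X_demo, y_demo, test_size=0.2, random_state=42, stratify=y_demo
)
print("positives in train:", y_train_demo.sum(), "| positives in test:", y_test_demo.sum())
```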

In [None]:
# Initialize the logistic regression model
lr_model_cm = LogisticRegression()

# Train the model using the vectorized training data
lr_model_cm.fit(X_train_vectorized, y_train)

# Predict on the train set
y_pred_train = lr_model_cm.predict(X_train_vectorized)
y_pred_train_prob = lr_model_cm.predict_proba(X_train_vectorized)[:, 1]  # Probability of class 1 (Relevant)

# Predict on the test set
y_pred = lr_model_cm.predict(X_test_vectorized)
y_pred_prob = lr_model_cm.predict_proba(X_test_vectorized)[:, 1]  # Probability of class 1 (Relevant)

In [None]:
#evaluate model
evaluate(y_test, y_pred, y_pred_prob)

In [None]:
#surface incorrect examples
#NOTE: some of this suggests labeling error

test_examples_dict = {"titles": X_test, "relevant_true": y_test, "relevant_pred":y_pred}
test_examples_df = pd.DataFrame(test_examples_dict)

test_examples_df["pred_error"] = test_examples_df.apply(flag_pred_error, axis=1)
test_examples_df_errors = test_examples_df[test_examples_df["pred_error"] == 1]
test_examples_df_errors

In [None]:
#NOTE: I actually agree with many of these predictions, which further suggests the input labels will likely need refining for best performance
test_examples_df_errors[test_examples_df_errors["relevant_true"] == 0]["titles"].values

In [None]:
#fit model (cross validation test - whole dataset)

# Encode data
X_vectorized = vectorizer.fit_transform(X)

# Initialize the logistic regression model
lr_model_cv = LogisticRegression()

scoring = ['accuracy', 'roc_auc', 'precision', 'recall']
#good discussion of why not to select the final model instance from CV (CV is positioned as an evaluation tool only) - https://stats.stackexchange.com/questions/52274/how-to-choose-a-predictive-model-after-k-fold-cross-validation
    #it also suggests training the final model on the whole dataset (maximising data use)
outputs = cross_validate(lr_model_cv, X_vectorized, y, cv=7, scoring=scoring, return_train_score=True)
outputs
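In the cell above, `vectorizer` is fit on all of `X` before cross-validation, so each held-out fold's words already shape the vocabulary. One way to avoid that leakage, sketched here as an assumption rather than the notebook's method, is to wrap `CountVectorizer` and `LogisticRegression` in a scikit-learn `Pipeline` and pass raw titles to `cross_validate` with a `StratifiedKFold`. The toy titles and labels stand in for `agent_list["Title"]` / `relevant_bin`.

```python
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.pipeline import make_pipeline

# Toy titles/labels (illustrative only)
titles_demo = [
    "evaluating large language models", "benchmarking LLM reasoning",
    "menopause symptom trajectories", "genetics of menopause timing",
    "convolutional networks for image segmentation", "vision transformers for detection",
    "reinforcement learning for robotics", "graph neural networks for molecules",
] * 3
labels_demo = [1, 1, 1, 1, 0, 0, 0, 0] * 3

# The vectorizer is re-fit inside each training fold, so held-out titles never shape the vocabulary
pipe_demo = make_pipeline(CountVectorizer(stop_words="english"), LogisticRegression())
cv_demo = StratifiedKFold(n_splits=3, shuffle=True, random_state=42)
scores_demo = cross_validate(pipe_demo, titles_demo, labels_demo, cv=cv_demo,
                             scoring=["accuracy", "roc_auc"], return_train_score=True)
print(scores_demo["test_accuracy"].mean(), scores_demo["test_roc_auc"].mean())
```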

In [None]:
#fit model on all data

lr_model_wc = LogisticRegression().fit(X_vectorized, y)

#save model
#NOTE: the fitted CountVectorizer (vectorizer) is not saved here, so production would also need it to transform new titles
today_date = datetime.today().strftime('%Y_%m_%d')
lr_model_wc_filename = 'pa_lr_model_wc_all_data_' + today_date + ".pkl"
with open(lr_model_wc_filename, 'wb') as f:
    pickle.dump(lr_model_wc, f)

#test load (comment out once checked)
with open(lr_model_wc_filename, 'rb') as f:
    lr_model_wc_loaded = pickle.load(f)
loaded_result = lr_model_wc_loaded.score(X_vectorized, y)
print(loaded_result)

lr_model_wc_loaded.predict(X_vectorized[2].reshape(1,-1))
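Because the pickled word-count model expects already-vectorized input, a hypothetical alternative is to pickle a fitted `Pipeline` so the vocabulary and classifier travel together and the loaded object accepts raw title strings. The filename and toy data below are assumptions for illustration.

```python
import os
import pickle
import tempfile

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

# Toy titles/labels (illustrative only)
titles_demo = ["evaluating large language models", "menopause symptom trajectories",
               "vision transformers for detection", "graph neural networks for molecules"]
labels_demo = [1, 1, 0, 0]

# Bundle the fitted vocabulary with the classifier
pipe_demo = make_pipeline(CountVectorizer(stop_words="english"), LogisticRegression())
pipe_demo.fit(titles_demo, labels_demo)

path_demo = os.path.join(tempfile.gettempdir(), "pa_lr_pipeline_demo.pkl")  # hypothetical filename
with open(path_demo, "wb") as f:
    pickle.dump(pipe_demo, f)

with open(path_demo, "rb") as f:
    pipe_loaded = pickle.load(f)

# The loaded pipeline scores raw strings directly, no separate vectorizer needed
print(pipe_loaded.predict(["new menopause symptom study"]))
```

Note that pickles are tied to the scikit-learn version that wrote them, so the production environment would need a matching version.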

### Word embeddings as feature 

In [None]:
#prep data

input_label = "Title"
# input_label = "Abstract_clean" 
    #Takes significantly longer to embed, which matters for the application (the accuracy gain would need to be an order of magnitude larger to be worth computing on the fly)
    #Actually performs worse on this subset
outcome_label = "relevant_bin"
# Split the dataset into features (X) and labels (y)
X = agent_list[input_label].values
y = agent_list[outcome_label].values

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

#generate sentence embeddings
encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# encode train 
X_train_vectorized = encoder.encode(X_train)

# Encode the testing data using the same encoder
X_test_vectorized = encoder.encode(X_test)

#### logistic regression

In [None]:
#fit model (basic train test split to review cases / errors)

# Initialize the logistic regression model
lr_model_we = LogisticRegression()

# Train the model using the vectorized training data
lr_model_we.fit(X_train_vectorized, y_train)

# Predict on the train set
y_pred_train = lr_model_we.predict(X_train_vectorized)
y_pred_train_prob = lr_model_we.predict_proba(X_train_vectorized)[:, 1]  # Probability of class 1 (Relevant)

# Predict on the test set
y_pred = lr_model_we.predict(X_test_vectorized)
y_pred_prob = lr_model_we.predict_proba(X_test_vectorized)[:, 1]  # Probability of class 1 (Relevant)

#evaluate
evaluate(y_test, y_pred, y_pred_prob)

In [None]:
#save model
today_date = datetime.today().strftime('%Y_%m_%d')
lr_model_we_filename = 'pa_lr_model_we_tt_data_' + today_date + ".pkl"
with open(lr_model_we_filename, 'wb') as f:
    pickle.dump(lr_model_we, f)

In [None]:
#test load (comment out once checked)
with open(lr_model_we_filename, 'rb') as f:
    lr_model_we_loaded = pickle.load(f)
loaded_result = lr_model_we_loaded.score(X_test_vectorized, y_test)
print(loaded_result)

lr_model_we_loaded.predict(X_test_vectorized[0].reshape(1,-1))

In [None]:
#inspect errors

test_examples_dict = {"titles": X_test, "relevant_true": y_test, "relevant_pred":y_pred}
test_examples_df = pd.DataFrame(test_examples_dict)

test_examples_df["pred_error"] = test_examples_df.apply(flag_pred_error, axis=1)
test_examples_df_errors = test_examples_df[test_examples_df["pred_error"] == 1]
test_examples_df_errors

In [None]:
#fit model (cross validation test - whole dataset)

# Encode data
X_vectorized = encoder.encode(X)

# Initialize the logistic regression model
lr_model_cv = LogisticRegression()

scoring = ['accuracy', 'roc_auc', 'precision', 'recall']
#good discussion of why not to select the final model instance from CV (CV is positioned as an evaluation tool only) - https://stats.stackexchange.com/questions/52274/how-to-choose-a-predictive-model-after-k-fold-cross-validation
    #it also suggests training the final model on the whole dataset (maximising data use)
outputs = cross_validate(lr_model_cv, X_vectorized, y, cv=7, scoring=scoring, return_train_score=True, return_estimator=True)
outputs
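The raw `cross_validate` dict is hard to scan, and the summary notes flag that performance may differ across samples. A small sketch that condenses the per-fold scores into a mean and standard deviation per metric is below; it uses `make_classification` as a synthetic stand-in for the 384-dimensional all-MiniLM-L6-v2 title embeddings, so the numbers themselves mean nothing.

```python
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

# Synthetic stand-in for the title embeddings (illustrative only)
X_demo, y_demo = make_classification(n_samples=210, n_features=20, random_state=0)
outputs_demo = cross_validate(LogisticRegression(max_iter=1000), X_demo, y_demo, cv=7,
                              scoring=["accuracy", "roc_auc", "precision", "recall"],
                              return_train_score=True)

# Keep only score arrays (drop fit_time / score_time), then one row per metric
summary_demo = (
    pd.DataFrame({k: v for k, v in outputs_demo.items() if k.startswith(("test_", "train_"))})
      .agg(["mean", "std"])
      .T
      .round(3)
)
print(summary_demo)
```

A large `std` on the `test_*` rows relative to the mean would support the "does meaningfully differ" observation.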

In [None]:
#save model fit on entire dataset (refer to CV for estimated performance)

lr_model_we_all = LogisticRegression()
lr_model_we_all.fit(X_vectorized, y)

#save model
today_date = datetime.today().strftime('%Y_%m_%d')
lr_model_we_all_filename = 'pa_lr_model_we_' + today_date + ".pkl"
with open(lr_model_we_all_filename, 'wb') as f:
    pickle.dump(lr_model_we_all, f)

In [None]:
#test load (comment out once checked)
with open(lr_model_we_all_filename, 'rb') as f:
    lr_model_we_all_loaded = pickle.load(f)
loaded_result = lr_model_we_all_loaded.score(X_vectorized, y)
print(loaded_result)

lr_model_we_all_loaded.predict(X_vectorized[2].reshape(1,-1))

#### XGBoost

In [None]:
# XGBClassifier is the idiomatic choice for a binary outcome (predict_proba gives class probabilities)
xgb_model_we = xgb.XGBClassifier(objective="binary:logistic", random_state=42)

xgb_model_we.fit(X_train_vectorized, y_train)

# Predict on the train set (probability of class 1, thresholded at 0.5)
y_pred_train_prob = xgb_model_we.predict_proba(X_train_vectorized)[:, 1]
y_pred_train = (y_pred_train_prob >= 0.5).astype(int)

# Predict on the test set
y_pred_prob = xgb_model_we.predict_proba(X_test_vectorized)[:, 1]
y_pred = (y_pred_prob >= 0.5).astype(int)

In [None]:
#save model
today_date = datetime.today().strftime('%Y_%m_%d')
xgb_model_we_filename = 'pa_xgb_model_we_' + today_date + ".pkl"
with open(xgb_model_we_filename, 'wb') as f:
    pickle.dump(xgb_model_we, f)

#test load (uncomment to check)

# with open(xgb_model_we_filename, 'rb') as f:
#     xgb_model_we_loaded = pickle.load(f)
# loaded_result = xgb_model_we_loaded.score(X_test_vectorized, y_test)
# print(loaded_result)

In [None]:
#evaluate
evaluate(y_test, y_pred, y_pred_prob)

In [None]:
#inspect errors

test_examples_dict = {"titles": X_test, "relevant_true": y_test, "relevant_pred":y_pred}
test_examples_df = pd.DataFrame(test_examples_dict)

test_examples_df["pred_error"] = test_examples_df.apply(flag_pred_error, axis=1)
test_examples_df_errors = test_examples_df[test_examples_df["pred_error"] == 1]
test_examples_df_errors

### Add more features

In [None]:
#add queries and sources - as word embeddings

input_labels = ["Title", "Query", "Source"]
outcome_label = "relevant_bin"
# Split the dataset into features (X) and labels (y)
X = agent_list[input_labels].values
y = agent_list[outcome_label].values

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

X_train_titles = [sample[0] for sample in X_train]
X_train_queries = [sample[1] for sample in X_train]
X_train_sources = [sample[2] for sample in X_train]

X_test_titles = [sample[0] for sample in X_test]
X_test_queries = [sample[1] for sample in X_test]
X_test_sources = [sample[2] for sample in X_test]

# generate sentence embeddings
encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# encode train 
X_train_titles_vectorized = encoder.encode(X_train_titles)
X_train_queries_vectorized = encoder.encode(X_train_queries)
X_train_sources_vectorized = encoder.encode(X_train_sources)
X_train_vectorized = np.concatenate((X_train_titles_vectorized, X_train_queries_vectorized, X_train_sources_vectorized), axis=1)


# Encode the testing data using the same encoder
X_test_titles_vectorized = encoder.encode(X_test_titles)
X_test_queries_vectorized = encoder.encode(X_test_queries)
X_test_sources_vectorized = encoder.encode(X_test_sources)
X_test_vectorized = np.concatenate((X_test_titles_vectorized, X_test_queries_vectorized, X_test_sources_vectorized), axis=1)
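The title/query/source encoding above is repeated for train, test and (below) the full dataset. A hypothetical helper, `encode_columns`, could encode each column separately and concatenate the blocks along the feature axis. It assumes only that the encoder maps a list of n strings to an (n, d) array, which is how `SentenceTransformer.encode` is used in this notebook; the `toy_encode` function is a stand-in so the sketch runs without downloading a model.

```python
from typing import Callable, Sequence

import numpy as np


def encode_columns(rows: Sequence[Sequence[str]],
                   encode: Callable[[list], np.ndarray]) -> np.ndarray:
    """Encode each column of `rows` separately, then concatenate along the feature axis."""
    columns = list(zip(*rows))  # transpose rows -> columns (title, query, source, ...)
    blocks = [np.asarray(encode(list(col))) for col in columns]
    return np.concatenate(blocks, axis=1)


# Hypothetical stand-in encoder: 2 features per string (character length, word count)
def toy_encode(texts):
    return np.array([[len(t), len(t.split())] for t in texts], dtype=float)


rows_demo = [("LLM evaluation survey", "large language model evaluation", "arxiv"),
             ("Menopause and sleep", "menopause symptoms", "pubmed")]
features_demo = encode_columns(rows_demo, toy_encode)
print(features_demo.shape)  # (2, 6)
```

In the notebook this would hypothetically become `X_train_vectorized = encode_columns(X_train, encoder.encode)`; missing (NaN) queries or sources would need converting to strings first.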

In [None]:
#fit model

# Initialize the logistic regression model
lr_model_we_multi = LogisticRegression()

# Train the model using the vectorized training data
lr_model_we_multi.fit(X_train_vectorized, y_train)

# Predict on the train set
y_pred_train = lr_model_we_multi.predict(X_train_vectorized)
y_pred_train_prob = lr_model_we_multi.predict_proba(X_train_vectorized)[:, 1]  # Probability of class 1 (Relevant)

# Predict on the test set
y_pred = lr_model_we_multi.predict(X_test_vectorized)
y_pred_prob = lr_model_we_multi.predict_proba(X_test_vectorized)[:, 1]  # Probability of class 1 (Relevant)

#evaluate
evaluate(y_test, y_pred, y_pred_prob)

In [None]:
#fit model (cross validation test - whole dataset)

# Encode data
X_titles = [sample[0] for sample in X]
X_queries = [sample[1] for sample in X]
X_sources = [sample[2] for sample in X]

X_titles_vectorized = encoder.encode(X_titles)
X_queries_vectorized = encoder.encode(X_queries)
X_sources_vectorized = encoder.encode(X_sources)
X_vectorized = np.concatenate((X_titles_vectorized, X_queries_vectorized, X_sources_vectorized), axis=1)

# Initialize the logistic regression model
lr_model_cv = LogisticRegression()

scoring = ['accuracy', 'roc_auc', 'precision', 'recall']
#good discussion of why not to select the final model instance from CV (CV is positioned as an evaluation tool only) - https://stats.stackexchange.com/questions/52274/how-to-choose-a-predictive-model-after-k-fold-cross-validation
    #it also suggests training the final model on the whole dataset (maximising data use)
outputs = cross_validate(lr_model_cv, X_vectorized, y, cv=7, scoring=scoring, return_train_score=True)
outputs

In [None]:
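#in-sample (training) accuracy only, so optimistic; refer to the CV results above for expected performance
lr_model_all = LogisticRegression().fit(X_vectorized, y)
quick_accuracy_check = lr_model_all.score(X_vectorized, y)
print(quick_accuracy_check)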

In [None]:
#add queries - as dummy vars (setup dummy vars)

def create_query_dummy_vars(query, query_col_name):
    if query == query_col_name:
        return 1
    else:
        return 0

dummy_query_vals = ["large language model evaluation", "menopause symptoms", "ChatGPT for healthcare", "menopause prediction", "menopause genetics"]
    
for query in dummy_query_vals:
    agent_list[query] = agent_list["Query"].apply(create_query_dummy_vars, args=(query, ))
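An alternative to the row-wise `apply` is `pd.get_dummies`, sketched here under the assumption that the column set should stay pinned to the known queries: `reindex` keeps the same five columns in the same order, and an unseen query (e.g. a new production query) becomes an all-zero row instead of adding a column. The `_demo` names are illustrative.

```python
import pandas as pd

dummy_query_vals_demo = ["large language model evaluation", "menopause symptoms",
                         "ChatGPT for healthcare", "menopause prediction", "menopause genetics"]
queries_demo = pd.Series(["menopause symptoms", "ChatGPT for healthcare", "some new query"])

# One 0/1 column per known query; unseen queries become all-zero rows
dummies_demo = (
    pd.get_dummies(queries_demo)
      .reindex(columns=dummy_query_vals_demo, fill_value=0)
      .astype(int)
)
print(dummies_demo)
```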

In [None]:
#add queries - as dummy vars (create vectors)

input_labels = ["Title"] + dummy_query_vals
outcome_label = "relevant_bin"
# Split the dataset into features (X) and labels (y)
X = agent_list[input_labels].values
y = agent_list[outcome_label].values

# Split the data into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

X_train_titles = [sample[0] for sample in X_train]
X_train_queries = [list(sample[1:]) for sample in X_train]

X_test_titles = [sample[0] for sample in X_test]
X_test_queries = [list(sample[1:]) for sample in X_test]

# generate sentence embeddings
encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

# encode train 
X_train_titles_vectorized = encoder.encode(X_train_titles)
# X_train_queries_vectorized = encoder.encode(X_train_queries)
X_train_vectorized = np.concatenate((X_train_titles_vectorized, X_train_queries), axis=1)


# Encode the testing data using the same encoder
X_test_titles_vectorized = encoder.encode(X_test_titles)
# X_test_queries_vectorized = encoder.encode(X_test_queries)
X_test_vectorized = np.concatenate((X_test_titles_vectorized, X_test_queries), axis=1)

In [None]:
#fit model

# Initialize the logistic regression model
lr_model_we_multi = LogisticRegression()

# Train the model using the vectorized training data
lr_model_we_multi.fit(X_train_vectorized, y_train)

# Predict on the train set
y_pred_train = lr_model_we_multi.predict(X_train_vectorized)
y_pred_train_prob = lr_model_we_multi.predict_proba(X_train_vectorized)[:, 1]  # Probability of class 1 (Relevant)

# Predict on the test set
y_pred = lr_model_we_multi.predict(X_test_vectorized)
y_pred_prob = lr_model_we_multi.predict_proba(X_test_vectorized)[:, 1]  # Probability of class 1 (Relevant)

#evaluate
evaluate(y_test, y_pred, y_pred_prob)

### BERT / transformer-based fine-tuning (in Google Colab)

In [None]:
#Some reference resources (also explored this in mimic-diag-prediction)
    #Huggingface
        #Intro - https://huggingface.co/docs/transformers/tasks/sequence_classification
        #Fine-tuning example (but not a custom dataset or easily digestible) - https://huggingface.co/docs/transformers/training + https://colab.research.google.com/github/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/training.ipynb
        #Text classification example (again not a custom dataset, but could be useful as reference) - https://github.com/huggingface/notebooks/blob/main/examples/text_classification.ipynb
    #Other new examples (generally too high level or incomplete, e.g. not showing how to load a custom dataset)
        #https://www.thepythoncode.com/article/finetuning-bert-using-huggingface-transformers-python?utm_content=cmp-true
        #https://lajavaness.medium.com/regression-with-text-input-using-bert-and-transformers-71c155034b13
        #https://saturncloud.io/blog/bert-text-classification-using-pytorch-a-guide-for-data-scientists/
    #Legacy examples (from notes, none of these worked comprehensively)
        #https://colab.research.google.com/drive/1PHv-IRLPCtv7oTcIGbsgZHqrB5LPvB7S
        #https://colab.research.google.com/github/prateekjoshi565/Fine-Tuning-BERT/blob/master/Fine_Tuning_BERT_for_Spam_Classification.ipynb
        #reminder that this might not be feasible on a CPU - may need to switch to Google Colab (test runs there are much, much quicker)

#### ChatGPT output for the prompt below (v2)

Provide a complete example in Python of fine-tuning a BERT model on a custom dataset to predict a binary outcome. For example, given a text title, predict whether the title is relevant (Yes or No). 

##### Full version based on the original (code in Google Colab runs quickly on the full dataset)
https://colab.research.google.com/drive/1feZ5mjrQyHe5tOyTWbmzNxqgFHENjdwV

In [None]:
# Assuming you have your custom dataset in two lists: titles and labels (0 or 1)
#NOTE: shortened to the first 100 rows until proven to work as expected
input_label = "Title"
outcome_label = "relevant_bin"
titles = agent_list[input_label].values[0:100]
labels = agent_list[outcome_label].values[0:100]

# Split the data into training and testing sets
train_titles, test_titles, train_labels, test_labels = train_test_split(titles, labels, test_size=0.2, random_state=42)

# Create a custom Dataset class for loading the data
class CustomDataset(Dataset):
    def __init__(self, titles, labels, tokenizer, max_length):
        self.titles = titles
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.titles)

    def __getitem__(self, idx):
        title = self.titles[idx]
        label = self.labels[idx]

        encoding = self.tokenizer.encode_plus(
            title,
            add_special_tokens=True,
            max_length=self.max_length,
            return_tensors='pt',
            padding='max_length',
            truncation=True
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# Set the maximum sequence length for BERT input
max_length = 128

# Load the pre-trained BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Create the custom train and test datasets
train_dataset = CustomDataset(train_titles, train_labels, tokenizer, max_length)
test_dataset = CustomDataset(test_titles, test_labels, tokenizer, max_length)

# Create data loaders for training and testing
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Load the pre-trained BERT model for sequence classification
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# Set up the optimizer and the device (uses the GPU if available, otherwise the CPU)
optimizer = AdamW(model.parameters(), lr=2e-5)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device: ", device)

# Move the model to the appropriate device
model.to(device)

# Training loop
num_epochs = 5

for epoch in range(num_epochs):
    print("Epoch #: ", epoch)
    model.train()
    total_loss = 0

    for idx, batch in enumerate(train_loader):
        print("Train batch #: ", idx)
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()

        loss.backward()
        optimizer.step()

    avg_train_loss = total_loss / len(train_loader)

    # Evaluation on the test set
    model.eval()
    num_correct = 0
    all_preds = []
    all_probs = []
    all_labels = []

    with torch.no_grad():
        for idx, batch in enumerate(test_loader):
            print("Test batch #: ", idx)
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)
            
            outputs = model(input_ids, attention_mask=attention_mask)
            print("Test outputs: ", outputs)
            logits = outputs.logits
            probs = F.softmax(logits, dim=1)
            preds = probs.argmax(dim=1)
            print("Test labels: ", labels)
            print("Test preds: ", preds)

            num_correct += torch.sum(preds == labels).item()
            
            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())  # probability of class 1 (Relevant), used for AUC
            all_labels.extend(labels.cpu().numpy())

    accuracy = num_correct / len(test_dataset)

    print(f'Epoch {epoch+1}/{num_epochs}, Avg. training loss: {avg_train_loss:.4f}, Test accuracy: {accuracy:.4f}')
    
    # Calculate AUC from class-1 probabilities (AUC on hard 0/1 predictions discards the ranking information)
    auc_score = roc_auc_score(all_labels, all_probs)
    # Fix the label order so the matrix is always 2x2, even if only one class is predicted
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    print(f'AUC: {auc_score:.4f}')
    print('Confusion Matrix:')
    print(conf_matrix)
    
    # Calculate sensitivity and specificity
    tn, fp, fn, tp = conf_matrix.ravel()

    sensitivity = tp / (tp + fn)
    specificity = tn / (tn + fp)

    print(f'Sensitivity (True Positive Rate): {sensitivity:.4f}')
    print(f'Specificity (True Negative Rate): {specificity:.4f}')

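# Save the fine-tuned model (and the tokenizer, so both can be reloaded from the same directory)
model.save_pretrained('fine_tuned_bert_model')
tokenizer.save_pretrained('fine_tuned_bert_model')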