### Importing Required Libraries

In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import StratifiedKFold, train_test_split, GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score,
    classification_report, confusion_matrix
)
import os

### Loading Data

In [4]:
drug_interaction=pd.read_csv("drug_id_5693.csv")

cell_feature=pd.read_csv("cell_feature.csv")

drug_feature=pd.read_csv("feature197_300.csv")

### Drug Interaction Dataset Overview

- This dataset contains records of drug pairs tested on specific cell lines.
- Columns:
  - `g_id1`: Identifier for the first drug in the pair.
  - `g_id2`: Identifier for the second drug in the pair.
  - `cell`: The cell line where the drug pair was tested (e.g., `MDAMB468`, `BT549`).
  - `label`: Synergy status of the drug pair on the given cell line.
    - `0` indicates **synergistic** interaction.
    - `1` indicates **non-synergistic** interaction.

In [6]:
drug_interaction.head(3)

Unnamed: 0,g_id1,g_id2,cell,label
0,192,115,MDAMB468,1
1,16,50,BT549,0
2,162,93,BT549,0


In [7]:
drug_interaction.shape  #The `drug_interaction` dataset contains 5693 rows ie 5693 different interaction dataset 


(5693, 4)

### Class Distribution of Synergy Status

In [9]:
drug_interaction.label.value_counts()

label
0    4349
1    1344
Name: count, dtype: int64

### Cell Line Feature Data

- This table contains molecular features associated with each **cell line**.
- The first column identifies the **cell line** (e.g., `MDAMB468`).
- Subsequent columns correspond to gene expression levels 

In [11]:
cell_feature.head(3)

Unnamed: 0,cell,ENSG00000116237,ENSG00000162413,ENSG00000171603,ENSG00000160049,ENSG00000065526,ENSG00000117118,ENSG00000053371,ENSG00000076864,ENSG00000070831,...,ENSG00000156299,ENSG00000142166,ENSG00000159228,ENSG00000159231,ENSG00000183527,ENSG00000182093,ENSG00000182240,ENSG00000157617,ENSG00000160208,ENSG00000141959
0,MDAMB468,32.57,7.69,19.51,11.42,12.93,82.54,19.83,0.07,1.25,...,10.26,15.01,100.27,23.4,7.32,13.72,2.9,5.94,35.27,92.34
1,BT549,40.52,17.08,59.1,32.71,15.08,131.08,26.65,0.0,4.8,...,3.71,6.37,90.65,26.82,2.1,11.53,0.03,1.53,27.28,37.67
2,BT549,40.52,17.08,59.1,32.71,15.08,131.08,26.65,0.0,4.8,...,3.71,6.37,90.65,26.82,2.1,11.53,0.03,1.53,27.28,37.67


### Drug Feature Dataset

- This dataset contains numerical features representing chemical or structural properties of drugs.
- Columns labeled generically (e.g., `Column1.1` to `Column1.300`) correspond to various extracted features or descriptors.
- Each row represents a single drug’s feature vector.

In [13]:
drug_feature.head(3)

Unnamed: 0,Column1.1,Column1.2,Column1.3,Column1.4,Column1.5,Column1.6,Column1.7,Column1.8,Column1.9,Column1.10,...,Column1.291,Column1.292,Column1.293,Column1.294,Column1.295,Column1.296,Column1.297,Column1.298,Column1.299,Column1.300
0,0,0,0,0,0,1,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,1,0,1,0,0,1,0,1,1,0,...,0,0,0,1,1,0,1,0,0,0
2,0,0,0,1,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0


### Dataset Summary: Unique Drugs and Cell Lines

- **Total unique drugs:** The count of distinct drugs involved in all drug pairs, combining both drug identifiers (`g_id1` and `g_id2`).
- **Total unique cell lines:** The count of distinct cell lines used in experiments.

In [15]:
# Calculate total unique drugs across both drug columns
total_unique_drugs = len(np.unique(np.concatenate([drug_interaction['g_id1'].unique(), drug_interaction['g_id2'].unique()])))

# Calculate total unique cell lines
total_unique_cells = len(cell_feature['cell'].unique())

# Print the results
print(f"Total unique drugs in dataset: {total_unique_drugs}")
print(f"Total unique cell lines in dataset: {total_unique_cells}")


Total unique drugs in dataset: 197
Total unique cell lines in dataset: 12


In [16]:
cell_features = cell_feature.drop_duplicates(subset='cell').set_index('cell')

In [17]:
cell_features = pd.DataFrame(cell_features, index=cell_features.index, columns=cell_features.columns)
cell_features

Unnamed: 0_level_0,ENSG00000116237,ENSG00000162413,ENSG00000171603,ENSG00000160049,ENSG00000065526,ENSG00000117118,ENSG00000053371,ENSG00000076864,ENSG00000070831,ENSG00000133216,...,ENSG00000156299,ENSG00000142166,ENSG00000159228,ENSG00000159231,ENSG00000183527,ENSG00000182093,ENSG00000182240,ENSG00000157617,ENSG00000160208,ENSG00000141959
cell,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
MDAMB468,32.57,7.69,19.51,11.42,12.93,82.54,19.83,0.07,1.25,0.05,...,10.26,15.01,100.27,23.4,7.32,13.72,2.9,5.94,35.27,92.34
BT549,40.52,17.08,59.1,32.71,15.08,131.08,26.65,0.0,4.8,0.0,...,3.71,6.37,90.65,26.82,2.1,11.53,0.03,1.53,27.28,37.67
ZR751,30.69,4.72,5.34,6.82,7.33,78.93,23.91,0.0,4.27,0.0,...,2.18,6.97,1.24,8.24,0.87,13.65,0.0,0.77,6.8,36.37
KPL1,28.36,5.36,21.52,11.99,15.61,86.26,43.19,2.87,2.69,0.0,...,7.24,4.86,54.44,5.98,3.97,13.16,0.98,0.24,16.99,155.4
HS578T,59.48,13.82,90.56,8.09,16.15,102.79,57.86,0.3,1.88,0.0,...,1.34,4.15,143.63,31.28,2.28,23.1,0.86,4.04,11.78,199.62
HUH7,40.89,12.81,23.74,17.67,10.69,73.27,25.64,0.0,1.07,0.0,...,0.01,6.38,41.33,0.21,2.93,13.88,0.0,0.77,25.3,74.72
MCF7,24.32,4.44,14.09,9.4,9.81,89.6,38.03,0.89,2.38,0.01,...,4.17,7.59,48.21,7.2,9.17,18.07,0.0,0.66,19.66,180.48
A549,28.65,4.2,21.2,10.78,7.25,92.97,17.06,0.0,2.88,0.05,...,0.1,8.71,129.35,4.6,6.91,17.92,0.41,1.71,24.41,12.68
UO31,30.97,13.79,50.58,21.08,9.54,90.3,43.36,0.0,1.04,0.43,...,2.22,6.61,66.56,4.97,1.85,23.63,2.6,2.38,19.86,62.48
HCC1187,38.51,30.22,16.61,8.38,17.67,55.27,13.43,0.33,0.96,0.11,...,7.95,6.18,69.91,23.58,2.43,6.08,0.67,4.27,22.68,116.02


### Preparing Cell Features for Integration with Drug Interaction Data


In [19]:
cell_feature_dict = cell_features.to_dict(orient='index')

In [20]:
cell_feature_dict.keys()

dict_keys(['MDAMB468', 'BT549', 'ZR751', 'KPL1', 'HS578T', 'HUH7', 'MCF7', 'A549', 'UO31', 'HCC1187', 'ACHN', 'X786O'])

### Appending Cell Features to Drug Interaction Dataset
- Merge or join cell line feature data with the drug interaction dataset.
- The merging key is the `cell` identifier common to both datasets.

In [22]:
drug_interaction['cell_features'] = drug_interaction['cell'].map(cell_feature_dict)

In [23]:
drug_interaction

Unnamed: 0,g_id1,g_id2,cell,label,cell_features
0,192,115,MDAMB468,1,"{'ENSG00000116237': 32.57, 'ENSG00000162413': ..."
1,16,50,BT549,0,"{'ENSG00000116237': 40.52, 'ENSG00000162413': ..."
2,162,93,BT549,0,"{'ENSG00000116237': 40.52, 'ENSG00000162413': ..."
3,55,10,BT549,0,"{'ENSG00000116237': 40.52, 'ENSG00000162413': ..."
4,24,20,BT549,0,"{'ENSG00000116237': 40.52, 'ENSG00000162413': ..."
...,...,...,...,...,...
5688,194,45,BT549,0,"{'ENSG00000116237': 40.52, 'ENSG00000162413': ..."
5689,44,91,BT549,0,"{'ENSG00000116237': 40.52, 'ENSG00000162413': ..."
5690,33,86,BT549,0,"{'ENSG00000116237': 40.52, 'ENSG00000162413': ..."
5691,69,68,HS578T,0,"{'ENSG00000116237': 59.48, 'ENSG00000162413': ..."


### Expanding and Setting Cell Features in Drug Interaction Dataset

- The `cell_features` column in the `drug_interaction` dataframe contains nested data (e.g., a list or dictionary).
- `apply(pd.Series)` is used to expand each element of `cell_features` into separate columns.
- The original `cell_features` column is dropped.

In [25]:
cell_features_expanded = drug_interaction['cell_features'].apply(pd.Series)

drug_interaction = pd.concat([drug_interaction.drop(columns=['cell_features']), cell_features_expanded], axis=1)
drug_interaction

Unnamed: 0,g_id1,g_id2,cell,label,ENSG00000116237,ENSG00000162413,ENSG00000171603,ENSG00000160049,ENSG00000065526,ENSG00000117118,...,ENSG00000156299,ENSG00000142166,ENSG00000159228,ENSG00000159231,ENSG00000183527,ENSG00000182093,ENSG00000182240,ENSG00000157617,ENSG00000160208,ENSG00000141959
0,192,115,MDAMB468,1,32.57,7.69,19.51,11.42,12.93,82.54,...,10.26,15.01,100.27,23.40,7.32,13.72,2.90,5.94,35.27,92.34
1,16,50,BT549,0,40.52,17.08,59.10,32.71,15.08,131.08,...,3.71,6.37,90.65,26.82,2.10,11.53,0.03,1.53,27.28,37.67
2,162,93,BT549,0,40.52,17.08,59.10,32.71,15.08,131.08,...,3.71,6.37,90.65,26.82,2.10,11.53,0.03,1.53,27.28,37.67
3,55,10,BT549,0,40.52,17.08,59.10,32.71,15.08,131.08,...,3.71,6.37,90.65,26.82,2.10,11.53,0.03,1.53,27.28,37.67
4,24,20,BT549,0,40.52,17.08,59.10,32.71,15.08,131.08,...,3.71,6.37,90.65,26.82,2.10,11.53,0.03,1.53,27.28,37.67
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5688,194,45,BT549,0,40.52,17.08,59.10,32.71,15.08,131.08,...,3.71,6.37,90.65,26.82,2.10,11.53,0.03,1.53,27.28,37.67
5689,44,91,BT549,0,40.52,17.08,59.10,32.71,15.08,131.08,...,3.71,6.37,90.65,26.82,2.10,11.53,0.03,1.53,27.28,37.67
5690,33,86,BT549,0,40.52,17.08,59.10,32.71,15.08,131.08,...,3.71,6.37,90.65,26.82,2.10,11.53,0.03,1.53,27.28,37.67
5691,69,68,HS578T,0,59.48,13.82,90.56,8.09,16.15,102.79,...,1.34,4.15,143.63,31.28,2.28,23.10,0.86,4.04,11.78,199.62


In [26]:
drug_interaction.shape

(5693, 958)

### Mapping Drug Features to Drug Interaction Dataset and Replacing IDs with Corresponding Features

In [28]:
drug_feature_dict = drug_feature.to_dict(orient='index')

drug_interaction['g_id1_features'] = drug_interaction['g_id1'].map(drug_feature_dict)
drug_interaction['g_id2_features'] = drug_interaction['g_id2'].map(drug_feature_dict)


In [29]:
drug_interaction.head()

Unnamed: 0,g_id1,g_id2,cell,label,ENSG00000116237,ENSG00000162413,ENSG00000171603,ENSG00000160049,ENSG00000065526,ENSG00000117118,...,ENSG00000159228,ENSG00000159231,ENSG00000183527,ENSG00000182093,ENSG00000182240,ENSG00000157617,ENSG00000160208,ENSG00000141959,g_id1_features,g_id2_features
0,192,115,MDAMB468,1,32.57,7.69,19.51,11.42,12.93,82.54,...,100.27,23.4,7.32,13.72,2.9,5.94,35.27,92.34,"{'Column1.1': 0, 'Column1.2': 0, 'Column1.3': ...","{'Column1.1': 1, 'Column1.2': 1, 'Column1.3': ..."
1,16,50,BT549,0,40.52,17.08,59.1,32.71,15.08,131.08,...,90.65,26.82,2.1,11.53,0.03,1.53,27.28,37.67,"{'Column1.1': 0, 'Column1.2': 0, 'Column1.3': ...","{'Column1.1': 0, 'Column1.2': 0, 'Column1.3': ..."
2,162,93,BT549,0,40.52,17.08,59.1,32.71,15.08,131.08,...,90.65,26.82,2.1,11.53,0.03,1.53,27.28,37.67,"{'Column1.1': 0, 'Column1.2': 0, 'Column1.3': ...","{'Column1.1': 0, 'Column1.2': 1, 'Column1.3': ..."
3,55,10,BT549,0,40.52,17.08,59.1,32.71,15.08,131.08,...,90.65,26.82,2.1,11.53,0.03,1.53,27.28,37.67,"{'Column1.1': 1, 'Column1.2': 0, 'Column1.3': ...","{'Column1.1': 0, 'Column1.2': 0, 'Column1.3': ..."
4,24,20,BT549,0,40.52,17.08,59.1,32.71,15.08,131.08,...,90.65,26.82,2.1,11.53,0.03,1.53,27.28,37.67,"{'Column1.1': 0, 'Column1.2': 0, 'Column1.3': ...","{'Column1.1': 0, 'Column1.2': 0, 'Column1.3': ..."


### Expanding and Setting Cell Features in Drug Interaction Dataset


In [31]:
g_id1_features_expanded = drug_interaction['g_id1_features'].apply(pd.Series).add_prefix('g1_')

g_id2_features_expanded = drug_interaction['g_id2_features'].apply(pd.Series).add_prefix('g2_')

drug_interaction = pd.concat([drug_interaction.drop(columns=['g_id1_features', 'g_id2_features']), 
                       g_id1_features_expanded, 
                       g_id2_features_expanded], axis=1)


In [32]:
drug_interaction.head()

Unnamed: 0,g_id1,g_id2,cell,label,ENSG00000116237,ENSG00000162413,ENSG00000171603,ENSG00000160049,ENSG00000065526,ENSG00000117118,...,g2_Column1.291,g2_Column1.292,g2_Column1.293,g2_Column1.294,g2_Column1.295,g2_Column1.296,g2_Column1.297,g2_Column1.298,g2_Column1.299,g2_Column1.300
0,192,115,MDAMB468,1,32.57,7.69,19.51,11.42,12.93,82.54,...,1,1,0,0,0,0,0,1,0,0
1,16,50,BT549,0,40.52,17.08,59.1,32.71,15.08,131.08,...,0,0,0,0,0,0,0,0,0,0
2,162,93,BT549,0,40.52,17.08,59.1,32.71,15.08,131.08,...,1,1,0,0,0,1,0,1,0,0
3,55,10,BT549,0,40.52,17.08,59.1,32.71,15.08,131.08,...,0,0,0,0,0,0,1,0,1,0
4,24,20,BT549,0,40.52,17.08,59.1,32.71,15.08,131.08,...,0,0,0,1,1,0,1,0,0,0


In [33]:
drug_interaction.drop(columns=['g_id1','g_id2','cell'],inplace=True)
drug_interaction.head()

Unnamed: 0,label,ENSG00000116237,ENSG00000162413,ENSG00000171603,ENSG00000160049,ENSG00000065526,ENSG00000117118,ENSG00000053371,ENSG00000076864,ENSG00000070831,...,g2_Column1.291,g2_Column1.292,g2_Column1.293,g2_Column1.294,g2_Column1.295,g2_Column1.296,g2_Column1.297,g2_Column1.298,g2_Column1.299,g2_Column1.300
0,1,32.57,7.69,19.51,11.42,12.93,82.54,19.83,0.07,1.25,...,1,1,0,0,0,0,0,1,0,0
1,0,40.52,17.08,59.1,32.71,15.08,131.08,26.65,0.0,4.8,...,0,0,0,0,0,0,0,0,0,0
2,0,40.52,17.08,59.1,32.71,15.08,131.08,26.65,0.0,4.8,...,1,1,0,0,0,1,0,1,0,0
3,0,40.52,17.08,59.1,32.71,15.08,131.08,26.65,0.0,4.8,...,0,0,0,0,0,0,1,0,1,0
4,0,40.52,17.08,59.1,32.71,15.08,131.08,26.65,0.0,4.8,...,0,0,0,1,1,0,1,0,0,0


### Splitting Features and Target Variable

- `X` contains all columns except the `label`, i.e., it includes drug features and cell features used for prediction.
- `y` stores the `label` column, which represents the synergy class — 0 for synergistic and 1 for non-synergistic interactions.


In [35]:
X = drug_interaction.drop(columns=['label'])
y = drug_interaction['label']

### Splitting Dataset into Training and Testing Sets

- Split the feature matrix `X` and target vector `y` into training and testing subsets.


In [37]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

### Training with Logistic Regression Model and Evaluating via Stratified K-Fold Cross-Validation

In [39]:
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

accuracy_list = []
precision_list = []
recall_list = []
f1_list = []


### Logistic Regression Training and Evaluation Using Stratified K-Fold Cross-Validation

This code performs 5-fold stratified cross-validation to evaluate the Logistic Regression model with balanced class weights.

- The dataset is split into 5 folds maintaining the original class distribution in each fold.
- For each fold:
  - Train on 4 folds and test on the remaining fold.
  - Calculate performance metrics: Accuracy, Weighted Precision, Weighted Recall, and Weighted F1-Score.
  - Display the confusion matrix with a heatmap visualization.
  - Print a detailed classification report to analyze per-class performance.
- This approach helps provide a robust estimate of model performance while addressing potential class imbalance.

### Training and Evaluating Logistic Regression with Optimized Hyperparameters


In [44]:
output_dir = 'img_(drug_synergy_prediction without Feature extraction(without quantile and min max scaling)'
fold = 1
for train_index, test_index in skf.split(X, y):
    X_train, X_test = X.iloc[train_index], X.iloc[test_index]
    y_train, y_test = y.iloc[train_index], y.iloc[test_index]

    model = LogisticRegression(C=0.01, class_weight={0: 1, 1: 3}, max_iter=10000)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    acc = accuracy_score(y_test, y_pred)
    prec = precision_score(y_test, y_pred, average='weighted')
    rec = recall_score(y_test, y_pred, average='weighted')
    f1 = f1_score(y_test, y_pred, average='weighted')

    accuracy_list.append(acc)
    precision_list.append(prec)
    recall_list.append(rec)
    f1_list.append(f1)

    print(f"\nFold {fold}:")
    print(f"  Accuracy           = {acc:.4f}")
    print(f"  Weighted Precision = {prec:.4f}")
    print(f"  Weighted Recall    = {rec:.4f}")
    print(f"  Weighted F1-score  = {f1:.4f}")

    cm = confusion_matrix(y_test, y_pred)
    print("  Confusion Matrix:")
    print(cm)

    plt.figure(figsize=(6, 5))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False)
    plt.title(f'Confusion Matrix - Fold {fold}')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
      

    
    plt.savefig(os.path.join(output_dir, f'confusion_matrix_fold_{fold}_after_tuning.png'))    
    plt.close()

    print("\nDetailed Classification Report:")
    print(classification_report(y_test, y_pred, digits=4))

    fold += 1


Fold 1:
  Accuracy           = 0.6804
  Weighted Precision = 0.7382
  Weighted Recall    = 0.6804
  Weighted F1-score  = 0.6996
  Confusion Matrix:
[[617 253]
 [111 158]]

Detailed Classification Report:
              precision    recall  f1-score   support

           0     0.8475    0.7092    0.7722       870
           1     0.3844    0.5874    0.4647       269

    accuracy                         0.6804      1139
   macro avg     0.6160    0.6483    0.6185      1139
weighted avg     0.7382    0.6804    0.6996      1139


Fold 2:
  Accuracy           = 0.7138
  Weighted Precision = 0.7622
  Weighted Recall    = 0.7138
  Weighted F1-score  = 0.7298
  Confusion Matrix:
[[644 226]
 [100 169]]

Detailed Classification Report:
              precision    recall  f1-score   support

           0     0.8656    0.7402    0.7980       870
           1     0.4278    0.6283    0.5090       269

    accuracy                         0.7138      1139
   macro avg     0.6467    0.6842    0.6535  

### Calculating and Printing Average Performance Metrics Across All 5 Folds

In [46]:
print("\n=== Average Across 5 Folds ===")
print(f"wt.Avg Accuracy       = {np.mean(accuracy_list):.4f}")
print(f"wt.Avg Macro Precision= {np.mean(precision_list):.4f}")
print(f"wt.Avg Macro Recall   = {np.mean(recall_list):.4f}")
print(f"wt.Avg Macro F1-score = {np.mean(f1_list):.4f}")


=== Average Across 5 Folds ===
wt.Avg Accuracy       = 0.6900
wt.Avg Macro Precision= 0.7442
wt.Avg Macro Recall   = 0.6900
wt.Avg Macro F1-score = 0.7079


### Calculation of ROC AUC and Precision-Recall AUC (PR AUC)

In [63]:
model = LogisticRegression(C=0.01, class_weight={0: 1, 1: 3}, max_iter=10000)
model.fit(X_train, y_train)

In [64]:
y_pred = model.predict(X_test)
y_probs = model.predict_proba(X_test)[:, 1]

In [65]:
from sklearn.metrics import roc_auc_score

roc_auc = roc_auc_score(y_test, y_probs)
print("ROC-AUC Score:", roc_auc)


ROC-AUC Score: 0.6751200892091268


In [66]:
from sklearn.metrics import roc_curve
import matplotlib.pyplot as plt

fpr, tpr, thresholds = roc_curve(y_test, y_probs)

plt.figure(figsize=(8,6))
plt.plot(fpr, tpr, label=f'Logistic Regression (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], 'k--', label='Random Classifier')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(output_dir, f'rocauc.png'))    
plt.close()



In [67]:
from sklearn.metrics import average_precision_score

pr_auc = average_precision_score(y_test, y_probs)
print("PR-AUC Score:", pr_auc)


PR-AUC Score: 0.39252006547183216


In [68]:
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt

precision, recall, thresholds = precision_recall_curve(y_test, y_probs)

plt.figure(figsize=(8,6))
plt.plot(recall, precision, label=f'Logistic Regression (PR-AUC = {pr_auc:.2f})')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.legend()
plt.grid(True)
plt.savefig(os.path.join(output_dir, f'prauc.png'))    
plt.close()
