# In [None]:
import os
import shutil

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from chemprop import data, featurizers, models, nn
from lightning import pytorch as pl
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_curve, auc, precision_recall_curve
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Every artefact of this run (checkpoints, plots, prediction CSVs) goes here.
results_dir = 'results'
os.makedirs(results_dir, exist_ok=True)

# Labelled training data: a SMILES string and a toxicity label on each row.
input_path = "smiles_10449_train_test.csv"
smiles_column = 'SMILES'
target_columns = ['Toxicity']
df_input = pd.read_csv(input_path)

# Turn each CSV row into a chemprop datapoint; y is that row's target vector
# (one entry per column in target_columns).
smis = df_input[smiles_column].values
ys = df_input[target_columns].values
all_data = [
    data.MoleculeDatapoint.from_smi(smiles, target)
    for smiles, target in zip(smis, ys)
]

# Random 80/10/10 split over the parsed molecules. The split helpers return
# nested lists (presumably one entry per split replicate); [0] below selects
# the first and only one.
mols = [datapoint.mol for datapoint in all_data]
train_indices, val_indices, test_indices = data.make_split_indices(
    mols, "random", (0.8, 0.1, 0.1)
)
train_data, val_data, test_data = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)

# Wrap each split in a MoleculeDataset that shares a single graph featurizer.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
train_dset, val_dset, test_dset = (
    data.MoleculeDataset(split[0], featurizer)
    for split in (train_data, val_data, test_data)
)

# The training loader keeps build_dataloader's default shuffling; validation and
# test loaders preserve split order so predictions can be matched back to rows.
train_loader = data.build_dataloader(train_dset, num_workers=0)
val_loader, test_loader = (
    data.build_dataloader(dset, num_workers=0, shuffle=False)
    for dset in (val_dset, test_dset)
)

# Define custom model
class MyMPNNModel(models.MPNN):
    """Chemprop MPNN trained with Adam and a ReduceLROnPlateau LR schedule.

    Only ``configure_optimizers`` is overridden. The scheduler monitors
    ``val_loss``, which is assumed to be logged by the parent ``models.MPNN``
    during validation.
    """

    def __init__(self, mp, agg, ffn, batch_norm, metric_list):
        # NOTE(review): this constructor only forwards its arguments, but because
        # it calls super(), Lightning's save_hyperparameters() in the parent appears
        # to also record these argument names (mp, agg, ffn, batch_norm,
        # metric_list) as hyperparameters. load_from_checkpoint() on this subclass
        # presumably relies on that to rebuild the model, so keep the signature
        # stable. Verify before renaming or removing parameters.
        super().__init__(mp, agg, ffn, batch_norm, metric_list)

    def configure_optimizers(self):
        """Return an Adam optimizer and an epoch-wise plateau LR scheduler.

        The learning rate starts at 1e-3. It is multiplied by 0.5 once
        ``val_loss`` has failed to improve for more than 5 consecutive
        scheduler steps (one step per epoch).
        """
        optimizer = Adam(self.parameters(), lr=1e-3)
        lr_scheduler = {
            # `verbose` is deprecated in ReduceLROnPlateau (PyTorch >= 2.2),
            # so it is no longer passed.
            'scheduler': ReduceLROnPlateau(optimizer, factor=0.5, patience=5),
            'monitor': 'val_loss',
            # Lightning's default interval, stated explicitly: step once per epoch.
            'interval': 'epoch',
            'frequency': 1,
        }
        return {
            'optimizer': optimizer,
            'lr_scheduler': lr_scheduler,
        }

# Initialize model: bond message passing, mean aggregation over atoms, and a
# single-task binary classification head.
mp = nn.BondMessagePassing()
agg = nn.MeanAggregation()
ffn = nn.BinaryClassificationFFN(n_tasks=1)
batch_norm = False
metric_list = None
mpnn = MyMPNNModel(mp, agg, ffn, batch_norm, metric_list)

# Set up training configuration. The checkpoint callback keeps only the
# checkpoint with the lowest val_loss. It is bound to a name so its
# best_model_path can be read once training is done.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath=results_dir, monitor='val_loss', mode='min', save_top_k=1
)
trainer = pl.Trainer(
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
    accelerator="cpu",
    devices=1,
    max_epochs=20,
    callbacks=[checkpoint_callback],
)

# Train the model
trainer.fit(mpnn, train_loader, val_loader)

# Save the best model checkpoint under a stable name. trainer.save_checkpoint()
# would snapshot the *current* (final-epoch) weights, which are not necessarily
# the best ones, so copy the checkpoint the callback tracked as best instead.
best_model_path = os.path.join(results_dir, 'best_model.ckpt')
if checkpoint_callback.best_model_path:
    shutil.copyfile(checkpoint_callback.best_model_path, best_model_path)
else:
    # The callback recorded no best checkpoint; fall back to the final weights.
    trainer.save_checkpoint(best_model_path)

# Evaluate the best checkpoint (lowest val_loss) on the held-out test split.
# ckpt_path="best" makes Lightning restore the weights tracked by the trainer's
# ModelCheckpoint callback; without it the final-epoch weights would be scored.
# Lightning's predict loop already disables autograd, so no extra
# torch.inference_mode() context is needed.
test_preds = trainer.predict(mpnn, test_loader, ckpt_path="best")

# predict() yields one tensor per batch; stack them into a single array.
# Values are presumably sigmoid probabilities (the 0.5 threshold and the ROC
# curve below assume so).
test_preds = np.concatenate([pred.numpy() for pred in test_preds], axis=0)
# test_loader is not shuffled, so its order matches test_indices[0].
y_true = df_input.iloc[test_indices[0]][target_columns[0]].values.flatten()
y_pred = test_preds.flatten()
y_pred_binary = (y_pred >= 0.5).astype(int).tolist()

# Calculate performance metrics. zero_division=0 returns the same 0.0 as
# sklearn's default when a score is undefined (e.g. no positive predictions),
# but without emitting a warning.
accuracy = accuracy_score(y_true, y_pred_binary)
f1 = f1_score(y_true, y_pred_binary, zero_division=0)
precision = precision_score(y_true, y_pred_binary, zero_division=0)
recall = recall_score(y_true, y_pred_binary, zero_division=0)

print(f"Binary Accuracy: {accuracy:.4f}")
print(f"Binary F1 Score: {f1:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")

# Plot ROC curve from the continuous scores (not the thresholded labels)
fpr, tpr, _ = roc_curve(y_true, y_pred)
roc_auc = auc(fpr, tpr)
plt.figure()
plt.plot(fpr, tpr, label=f'ROC curve (AUC = {roc_auc:.2f})')
plt.plot([0, 1], [0, 1], linestyle='--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve')
plt.legend()
plt.savefig(os.path.join(results_dir, 'roc_curve.png'))

# Load unknown compounds for prediction
unknown_input_path = "Water_pollutants.csv"  # Path to unknown compounds dataset
unknown_df = pd.read_csv(unknown_input_path)
unknown_smis = unknown_df[smiles_column].values

# Reload the model saved at best_model_path for prediction
mpnn = MyMPNNModel.load_from_checkpoint(best_model_path)

# Prepare unknown compounds for prediction; they have no labels, hence y=None
unknown_data = [data.MoleculeDatapoint.from_smi(smi, None) for smi in unknown_smis]
unknown_dset = data.MoleculeDataset(unknown_data, featurizer)
unknown_loader = data.build_dataloader(unknown_dset, num_workers=0, shuffle=False)

# Predict unknown compounds. Lightning's predict loop already runs without
# autograd, so no torch.inference_mode() wrapper is needed.
unknown_preds = trainer.predict(mpnn, unknown_loader)

unknown_preds = np.concatenate([pred.numpy() for pred in unknown_preds], axis=0)
unknown_probs = unknown_preds.flatten()
unknown_preds_binary = (unknown_probs >= 0.5).astype(int).tolist()

# Use the 'ID' column as the compound name when present; otherwise fall back to
# the row index instead of failing with a KeyError.
if 'ID' in unknown_df.columns:
    compound_names = unknown_df['ID'].values
else:
    compound_names = unknown_df.index.values

# Save predictions for unknown compounds. The raw score is appended as a last
# column, so consumers that read the existing columns are unaffected.
unknown_output_df = pd.DataFrame({
    'Compound_Name': compound_names,
    'SMILES': unknown_df[smiles_column].values,
    'Predicted_Toxicity': unknown_preds_binary,
    'Predicted_Probability': unknown_probs,
})
unknown_output_path = os.path.join(results_dir, 'Water_pollutants_predictions.csv')
unknown_output_df.to_csv(unknown_output_path, index=False)

print(f"Prediction complete. Results saved to '{unknown_output_path}'.")
