# Generating Counterfactuals Directly from LLMs

## Some utils.

### Seeds

In [None]:
import openai
import torch
import torch.nn.functional as F
import numpy as np
import random


def set_seed():
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    torch.cuda.manual_seed_all(42)

### GNN related

In [1]:
def load_gnn(model_path):
    """_summary_

    Args:
        model_path (str): path to the model
    """
    model = torch.load(model_path)
    model.eval()
    return model

### Generate counterfactual SMILES

In [2]:


def generate_counterfactual_smiles(smiles, label):
    """
    Generate a counterfactual SMILES string for a given SMILES and label semantics.

    :param smiles: The original SMILES string of the molecule.
    :param label: The desired label semantics for the counterfactual molecule.
    :return: A counterfactual SMILES string that satisfies the desired label semantics.
    """
    prompt = f"Minimally edit {smiles} to be a {label} and output its SMILES representation only."

    response = openai.ChatCompletion.create(
        model="gpt-4.0-turbo",
        messages=[
            {"role": "system", "content": "You are a highly knowledgeable chemistry assistant."},
            {"role": "user", "content": prompt}
        ]
    )

    # Extracting the output from the response
    output = response['choices'][0]['message']['content'].strip()

    return output



### Evaluation Related

In [3]:
def compute_proximity(graphs, cf_graph):
    """_summary_

    Args:
        graphs (dicitonary): keys: adj, node, edge
        cf_graph (disctionary): keys: adj, node, edge
    """
    proximity = 0
    for key in graphs.keys():
        graphs[key] = torch.tensor(graphs[key])
        cf_graph[key] = torch.tensor(cf_graph[key])
        if key == 'edge':
            proximity += torch.sum(graphs[key]- cf_graph[key])/2
        else:
            proximity += torch.sum(graphs[key]- cf_graph[key])
    proximity /= len(graphs)
    return proximity

## Arguments


In [4]:
import argparse
def parse_args():
    parser = argparse.ArgumentParser(description='Arguments for counterfactual genration directly from the LLMs')
    parser.add_argument('--dataset', type=str, help='name of the dataset')
    return parser.parse_args()
args = parse_args()

usage: ipykernel_launcher.py [-h] [--dataset DATASET]
ipykernel_launcher.py: error: unrecognized arguments: --f=/home/xxx/.local/share/jupyter/runtime/kernel-v2-2485384zZsuXQPyBKQY.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


## Main Code Starts

## load dataset

In [None]:
# output the list of SMILES strings and their desired label semantics
from utils.datasets import Dataset
set_seed()
dataset = Dataset(dataset=args.dataset, generate_text=args.generate_text)
_, _, test_data = dataset.get_dataloaders()
test_smiles_list = test_data.smiles

In [None]:
label_semantics = {
    'AIDS':
    'Clintox':
}

## generate counterfactual

In [None]:
# iterate all the SMILES strings 

counterfactual_smiles_list = []
for smiles in test_smiles_list:
    counterfactual_smiles = generate_counterfactual_smiles(smiles, label_semantics)
    counterfactual_smiles_list.append(counterfactual_smiles)

## Save counterfactuals

In [None]:
# save counterfactuals as csv, each row contains one SMILES.
import pandas as pd

df = pd.DataFrame(counterfactual_smiles_list, columns=['counterfactual_smiles'])

df.to_csv('../exp_results/rebuttal/gce_dir_llm_'+{dataset}+'.csv', index=False)

## Evaluate counterfactuals

In [None]:
from model.model_evaluation import evaluate_gce_model

# evaluate the counterfactuals

# (1) Transform the SMILES to the graph
# (2) Filter the graphs that are not chemical feasible
# (2) Load the GNN weights for the certain dataset
# (3) Evaluate the validity and proximity

from utils.data_load import get_graphs_from_smiles
graphs, max_nodes, smiles, graph_labels = get_graphs_from_smiles(counterfactual_smiles_list, dataset)
# the graphs are already filtered by the chemical feasibility
gnn = load_gnn(dataset)
predictions = gnn.predict(graphs)
validity = sum(predictions) / len(test_smiles_list)
proximity = compute_proximity(gnn, graphs, graph_labels)

## Save the results

In [None]:
# save the validity and proximity of the model to the csv file
df = pd.DataFrame({'validity': validity, 'proximity': proximity})
df.to_csv('../exp_results/rebuttal/gce_dir_llm_'+{dataset}+'_evaluation.csv', index=False)