# **Utilising SMILE Embeddings**

## Import Libraries

In [1]:
import xml.etree.ElementTree as ET
import pandas as pd
from rdkit import Chem
from rdkit.Chem import MolToSmiles
from transformers import RobertaTokenizer, RobertaModel
import torch
import numpy as np
from tqdm import tqdm
import faiss

  from .autonotebook import tqdm as notebook_tqdm


## Parse DrugBank

In [2]:
def parse_drugbank_xml(xml_path):
    ns = {'db': 'http://www.drugbank.ca'}  # DrugBank namespace
    tree = ET.parse(xml_path)
    root = tree.getroot()

    drugs = []

    for drug in root.findall('db:drug', ns):
        name = drug.find('db:name', ns)
        if name is not None:
            name = name.text.strip()

        # Look for SMILES under <calculated-properties>
        smiles = None
        properties = drug.find('db:calculated-properties', ns)
        if properties is not None:
            for prop in properties.findall('db:property', ns):
                kind = prop.find('db:kind', ns)
                if kind is not None and kind.text == 'SMILES':
                    value = prop.find('db:value', ns)
                    if value is not None:
                        smiles = value.text.strip()
                        break

        if name and smiles:
            drugs.append((name, smiles))

    return pd.DataFrame(drugs, columns=['drug_name', 'smiles'])

In [3]:
df_drugs = parse_drugbank_xml("../Dataset/full database.xml")

In [4]:
df_drugs.head()

Unnamed: 0,drug_name,smiles
0,Bivalirudin,CC[C@H](C)[C@H](NC(=O)[C@H](CCC(O)=O)NC(=O)[C@...
1,Leuprolide,CCNC(=O)[C@@H]1CCCN1C(=O)[C@H](CCCNC(N)=N)NC(=...
2,Goserelin,CC(C)C[C@H](NC(=O)[C@@H](COC(C)(C)C)NC(=O)[C@H...
3,Gramicidin D,CC(C)C[C@@H](NC(=O)CNC(=O)[C@@H](NC=O)C(C)C)C(...
4,Desmopressin,NC(=O)CC[C@@H]1NC(=O)[C@H](CC2=CC=CC=C2)NC(=O)...


## Load Tokenizer + Model

In [6]:
tokenizer = RobertaTokenizer.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model = RobertaModel.from_pretrained("seyonec/ChemBERTa-zinc-base-v1")
model.eval()

def canonicalize_smiles(smiles):
    mol = Chem.MolFromSmiles(smiles)
    return MolToSmiles(mol, canonical=True) if mol else None

def get_embedding(smiles):
    inputs = tokenizer(smiles, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    return outputs.last_hidden_state.mean(dim=1).squeeze().numpy()

In [7]:
embeddings = []
valid_names = []

for idx, row in tqdm(df_drugs.iterrows(), total=len(df_drugs)):
    canon_smiles = canonicalize_smiles(row['smiles'])
    if canon_smiles:
        try:
            emb = get_embedding(canon_smiles)
            embeddings.append(emb)
            valid_names.append(row['drug_name'])
        except Exception as e:
            continue  # skip failed ones

 14%|█▍        | 1687/11925 [00:29<03:07, 54.64it/s][12:05:16] Explicit valence for atom # 13 Cl, 5, is greater than permitted
 15%|█▍        | 1758/11925 [00:30<02:35, 65.37it/s][12:05:17] SMILES Parse Error: syntax error while parsing: OS(O)(O)C1=CC=C(C=C1)C-1=C2\C=CC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC=C(C=C1)S(O)(O)O)C1=CC=C(C=C1)S([O-])([O-])[O-])\C1=CC=C(C=C1)S(O)(O)[O-]
[12:05:17] SMILES Parse Error: check for mistakes around position 84:
[12:05:17] C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC=C(C=C1)S(O
[12:05:17] ~~~~~~~~~~~~~~~~~~~~^
[12:05:17] SMILES Parse Error: Failed parsing SMILES 'OS(O)(O)C1=CC=C(C=C1)C-1=C2\C=CC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC=C(C=C1)S(O)(O)O)C1=CC=C(C=C1)S([O-])([O-])[O-])\C1=CC=C(C=C1)S(O)(O)[O-]' for input: 'OS(O)(O)C1=CC=C(C=C1)C-1=C2\C=CC(=N2)\C(=C2/N\C(\C=C2)=C(/C2=N/C(/C=C2)=C(\C2=CC=C\-1N2)C1=CC=C(C=C1)S(O)(O)O)C1=CC=C(C=C1)S([O-])([O-])[O-])\C1=CC=C(C=C1)S(O)(O)[O-]'
 20%|█▉        | 2382/11925 [00:40<02

In [8]:
embedding_matrix = np.vstack(embeddings)

## Faiss

In [9]:
dimension = embedding_matrix.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embedding_matrix)

In [10]:
topk = 4  # include original drug, remove later

results = {}

for i in range(len(embedding_matrix)):
    _, indices = index.search(embedding_matrix[i].reshape(1, -1), topk)
    similar_names = [valid_names[idx] for idx in indices[0] if idx != i]  # exclude self
    results[valid_names[i]] = similar_names[:3]

## Final Output

In [11]:
alt_df = pd.DataFrame([
    {'drug': drug, 'alternative_1': alts[0], 'alternative_2': alts[1], 'alternative_3': alts[2]}
    for drug, alts in results.items() if len(alts) >= 3
])

In [12]:
alt_df 

Unnamed: 0,drug,alternative_1,alternative_2,alternative_3
0,Bivalirudin,Semaglutide,Avexitide,PP-F11N lutetium Lu-177
1,Leuprolide,Buserelin,Deslorelin,Nerofe
2,Goserelin,Nafarelin,Triptorelin,Ganirelix
3,Gramicidin D,Nerofe,Echinomycin,Reltecimod
4,Desmopressin,Lypressin,Selepressin,Ozarelix
...,...,...,...,...
11908,Alogabat,4-(6-CYCLOHEXYLMETHOXY-9H-PURIN-2-YLAMINO)--BE...,Mizolastine,N-cyclopropyl-4-methyl-3-{2-[(2-morpholin-4-yl...
11909,Ropsacitinib,Regadenoson,Vistusertib,Golidocitinib
11910,taletrectinib,RU90395,Cadazolid,Carotegrast methyl
11911,Tolebrutinib,Tirabrutinib,Ibrutinib,Edralbrutinib


In [13]:
## Save alternatives to CSV
alt_df.to_csv("drug_alternatives.csv", index=False)
print("Alternative drugs saved to drug_alternatives.csv")

Alternative drugs saved to drug_alternatives.csv


# **Using BioBert Embeddings**

## Import Libraries

In [59]:
import pandas as pd
import numpy as np
import xml.etree.ElementTree as ET
import networkx as nx
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from tqdm import tqdm
from collections import defaultdict
import re
import base64
import os
from google import genai
from google.genai import types

## Parse DrugBank XML file

In [60]:
device = "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

Using device: mps


In [None]:
def parse_drugbank(xml_file):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    
    drugs = []
    for drug in root.findall("{http://www.drugbank.ca}drug"):
        drug_name = drug.find("{http://www.drugbank.ca}name")
        drug_id = drug.find("{http://www.drugbank.ca}drugbank-id")
        indications = drug.find("{http://www.drugbank.ca}indication")

        # Handle missing values safely
        drug_name = drug_name.text.strip() if drug_name is not None else "Unknown"
        drug_id = drug_id.text.strip() if drug_id is not None else "Unknown"
        indications = indications.text.strip() if (indications is not None and indications.text) else "No Indications"

        drugs.append({"Drug": drug_name, "DrugID": drug_id, "Indications": indications})
    
    return pd.DataFrame(drugs)

In [62]:
# Parse DrugBank XML and create a DataFrame
drug_data = parse_drugbank("../Dataset/full database.xml")
print(drug_data.head())  

                  Drug   DrugID  \
0            Lepirudin  DB00001   
1            Cetuximab  DB00002   
2         Dornase alfa  DB00003   
3  Denileukin diftitox  DB00004   
4           Etanercept  DB00005   

                                         Indications  
0  Lepirudin is indicated for anticoagulation in ...  
1  Cetuximab indicated for the treatment of local...  
2  Used as adjunct therapy in the treatment of cy...  
3         For treatment of cutaneous T-cell lymphoma  
4  Etanercept is indicated for the treatment of m...  


In [63]:
# Initialize BioBERT/SciBERT model for text embeddings on GPU
bert_model = SentenceTransformer("all-mpnet-base-v2", device=device)

def get_drug_embedding(drug_name, indications):
    """Generate an embedding for a drug based on its name and indications."""
    text = f"{drug_name}: {indications}"
    return bert_model.encode(text, convert_to_numpy=True)

# Compute embeddings for all drugs
drug_data["Embeddings"] = drug_data.apply(lambda row: get_drug_embedding(row["Drug"], row["Indications"]), axis=1)

# Convert embeddings into a NumPy matrix for fast similarity computation
embeddings_matrix = np.vstack(drug_data["Embeddings"].values)

In [None]:
def get_top_n_similar_drugs_in_rows(top_n=3):
    rows = []

    for idx, row in drug_data.iterrows():
        drug_name = row["Drug"]
        target_embedding = row["Embeddings"]

        similarities = cosine_similarity([target_embedding], embeddings_matrix)[0]
        
        temp_df = drug_data.copy()
        temp_df["Similarity"] = similarities

        # Remove the original drug from candidates
        temp_df = temp_df[temp_df["Drug"] != drug_name]

        # Get top-N most similar drugs
        top_similars = temp_df.sort_values(by="Similarity", ascending=False).head(top_n).reset_index(drop=True)

        # Build a single row with the original drug and its top-N alternatives
        row_data = {
            "Original Drug": drug_name,
            "Original Indications": row["Indications"]
        }

        for i in range(top_n):
            row_data[f"Alternative {i+1}"] = top_similars.loc[i, "Drug"]
            row_data[f"Similarity {i+1}"] = round(top_similars.loc[i, "Similarity"], 4)
            row_data[f"Indications {i+1}"] = top_similars.loc[i, "Indications"]

        rows.append(row_data)

    return pd.DataFrame(rows)

In [67]:
structured_alternatives_df = get_top_n_similar_drugs_in_rows(top_n=3)
structured_alternatives_df.head()

Unnamed: 0,Original Drug,Original Indications,Alternative 1,Similarity 1,Indications 1,Alternative 2,Similarity 2,Indications 2,Alternative 3,Similarity 3,Indications 3
0,Lepirudin,Lepirudin is indicated for anticoagulation in ...,Dalteparin,0.6823,Dalteparin is used as a prophylaxis for deep-v...,Heparin,0.6634,Unfractionated heparin is indicated for prophy...,Bivalirudin,0.6611,For treatment of heparin-induced thrombocytope...
1,Cetuximab,Cetuximab indicated for the treatment of local...,Cemiplimab,0.7153,Cemiplimab is indicated to treat:\r\n\r\n- **L...,Ramucirumab,0.7102,Ramucirumab is indicated for the treatment of ...,Cediranib,0.6955,"For the treatment of liver cancer, advanced no..."
2,Dornase alfa,Used as adjunct therapy in the treatment of cy...,Denufosol,0.5917,For use as an inhaled treatment for cystic fib...,Cystic fibrosis transmembrane conductance regu...,0.5874,No Indications,ALTU-135,0.5492,Investigated for use/treatment in cystic fibro...
3,Denileukin diftitox,For treatment of cutaneous T-cell lymphoma,Vorinostat,0.6477,For the treatment of cutaneous manifestations ...,Romidepsin,0.6033,Romidepsin is indicated for the treatment of c...,Brentuximab vedotin,0.5806,Brentuximab vedotin is indicated in adult pati...
4,Etanercept,Etanercept is indicated for the treatment of m...,Etarfolatide,0.6439,No Indications,Abatacept,0.6404,Abatacept is indicated in adult patients for t...,Alefacept,0.6347,"As an immunosuppressive drug, Alefacept can be..."


# **Validation Framework**

### Import Libraries

In [68]:
import pandas as pd
import requests
import time
import re
import json
from bs4 import BeautifulSoup
import google.generativeai as genai
from typing import List, Dict, Tuple, Optional
import os
from tqdm import tqdm

### Validation Framework

In [71]:
class PatentValidationFramework:
    def __init__(self, df: pd.DataFrame, api_key: str, query_col: str = "query", alternatives_cols: List[str] = None):
        """
        Initialize the framework with a DataFrame containing queries and alternatives.
        
        Args:
            df: DataFrame with drug queries and alternatives
            api_key: Google AI API key
            query_col: Column name containing the query compound
            alternatives_cols: List of column names containing alternative compounds
        """
        self.df = df
        self.query_col = query_col
        self.alternatives_cols = alternatives_cols if alternatives_cols else [col for col in df.columns if col != query_col]
        
        # Configure Gemini API
        genai.configure(api_key=api_key)
        
        # Select models
        self.descriptor_model = genai.GenerativeModel('gemini-2.0-flash-lite')
        self.similarity_model = genai.GenerativeModel('gemini-2.0-flash-lite')
        
        # Results storage
        self.patent_data = {}
        self.functional_descriptors = {}
        self.similarity_results = {}
        self.query_results = {}
        self.alternative_results = {}
        
        # API rate limiting
        self.pubchem_delay = 0.5  # seconds between PubChem API calls
        self.scholar_delay = 2.0  # seconds between Google Scholar requests
        self.gemini_delay = 2.5   # seconds between Gemini API calls
        
    def get_pubchem_cid(self, compound_name: str) -> List[str]:
        """
        Convert compound name to PubChem CID.
        
        Args:
            compound_name: Name of the compound
            
        Returns:
            List of PubChem CIDs
        """
        try:
            url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/{compound_name}/cids/JSON"
            response = requests.get(url)
            response.raise_for_status()
            data = response.json()
            
            if "IdentifierList" in data and "CID" in data["IdentifierList"]:
                return [str(cid) for cid in data["IdentifierList"]["CID"]]
            return []
        except Exception as e:
            print(f"Error retrieving CID for {compound_name}: {e}")
            return []
        finally:
            time.sleep(self.pubchem_delay)
    
    def get_patent_ids(self, cid: str, max_patents: int = 10) -> List[str]:
        """
        Get patent IDs associated with a PubChem CID.
        
        Args:
            cid: PubChem Compound ID
            max_patents: Maximum number of patents to retrieve
            
        Returns:
            List of patent IDs
        """
        try:
            url = f"https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/{cid}/xrefs/PatentID/JSON"
            response = requests.get(url)
            response.raise_for_status()
            data = response.json()
            
            if "InformationList" in data and "Information" in data["InformationList"]:
                info = data["InformationList"]["Information"][0]
                if "PatentID" in info:
                    return info["PatentID"][:max_patents]
            return []
        except Exception as e:
            print(f"Error retrieving patent IDs for CID {cid}: {e}")
            return []
        finally:
            time.sleep(self.pubchem_delay)
    
    def scrape_patent_info(self, patent_id: str) -> Dict[str, str]:
        """
        Scrape patent information from Google Patents.
        
        Args:
            patent_id: Patent identifier
            
        Returns:
            Dictionary with patent title, abstract, and description
        """
        try:
            patent_id = patent_id.replace("-", "")
            
            url = f"https://patents.google.com/patent/{patent_id}"
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
            }
            response = requests.get(url, headers=headers)
            response.raise_for_status()
            
            soup = BeautifulSoup(response.text, 'html.parser')
            
            # Extract title, abstract, and description
            # Actual implementation would need to match Google Patents HTML structure
            title_elem = soup.find("span", {"itemprop": "title"})
            title = title_elem.text.strip() if title_elem else ""
            
            abstract_elem = soup.find("div", {"class": "abstract"})
            abstract = abstract_elem.text.strip() if abstract_elem else ""
            
            description_elem = soup.find("div", {"class": "description"})
            description = description_elem.text.strip() if description_elem else ""
            
            return {
                "title": title,
                "abstract": abstract,
                "description": description[:5000]  # Limit description length
            }
        except Exception as e:
            print(f"Error scraping patent info for {patent_id}: {e}")
            return {"title": "", "abstract": "", "description": ""}
        finally:
            time.sleep(self.scholar_delay)
    
    def generate_functional_descriptors(self, patent_info: Dict[str, str], compound_name: str) -> List[str]:
        """
        Generate functional descriptors using Gemini.
        
        Args:
            patent_info: Dictionary containing patent title, abstract, and description
            compound_name: Name of the compound
            
        Returns:
            List of functional descriptors
        """
        prompt = f"""
        You are a pharmaceutical expert analyzing patent information for the compound {compound_name}.
        
        Patent Title: {patent_info['title']}
        Patent Abstract: {patent_info['abstract']}
        Patent Description: {patent_info['description'][:2000]}...
        
        Based solely on the patent information above, provide 1-3 brief functional descriptors 
        (1-3 words each) for the compound {compound_name}. Focus on its therapeutic function, 
        mechanism of action, or treatment target. Be concise and specific.
        
        Format your response as a comma-separated list without explanations or additional text.
        """
        
        try:
            response = self.descriptor_model.generate_content(prompt)
            descriptors_text = response.text.strip()
            
            # Clean up response to extract just the comma-separated list
            cleaned_text = re.sub(r'^[\s\S]*?([\w\s-]+(?:,\s*[\w\s-]+)*)[\s\S]*$', r'\1', descriptors_text)
            descriptors = [d.strip() for d in cleaned_text.split(',')]
            
            # Clean up descriptors
            cleaned_descriptors = []
            for d in descriptors:
                # Remove any unwanted characters and enforce length limits
                d = re.sub(r'[^\w\s-]', '', d)
                if len(d.split()) <= 3 and d not in cleaned_descriptors:
                    cleaned_descriptors.append(d)
            
            time.sleep(self.gemini_delay)
            return cleaned_descriptors
        except Exception as e:
            print(f"Error generating descriptors for {compound_name}: {e}")
            return []
    
    def determine_functional_similarity(self, query_descriptors: List[str], 
                                       alt_descriptors: List[str],
                                       query_name: str,
                                       alt_name: str) -> Dict:
        """
        Determine functional similarity using Gemini.
        
        Args:
            query_descriptors: List of descriptors for the query compound
            alt_descriptors: List of descriptors for the alternative compound
            query_name: Name of the query compound
            alt_name: Name of the alternative compound
            
        Returns:
            Dictionary with similarity assessment
        """
        prompt = f"""
        You are analyzing the functional similarity between two pharmaceutical compounds:
        
        Query compound: {query_name}
        Functional descriptors: {', '.join(query_descriptors)}
        
        Alternative compound: {alt_name}
        Functional descriptors: {', '.join(alt_descriptors)}
        
        Based on these functional descriptors, determine whether these compounds have similar functionality.
        Consider mechanism of action, therapeutic targets, and clinical applications.
        
        Provide your assessment as a JSON with the following structure:
        {{
            "is_similar": true/false,
            "similarity_score": [0-100],
            "explanation": "Brief explanation of your reasoning",
            "shared_functions": ["list", "of", "shared", "functions"]
        }}
        
        Only provide the JSON output, nothing else.
        """
        
        try:
            response = self.similarity_model.generate_content(prompt)
            similarity_assessment = response.text.strip()
            
            # Extract the JSON part if there's any text around it
            json_match = re.search(r'{.*}', similarity_assessment, re.DOTALL)
            if json_match:
                similarity_assessment = json_match.group(0)
            
            # Convert to Python dictionary
            try:
                result = json.loads(similarity_assessment)
                time.sleep(self.gemini_delay)
                return result
            except json.JSONDecodeError:
                print(f"Error parsing JSON from similarity assessment: {similarity_assessment}")
                return {
                    "is_similar": False,
                    "similarity_score": 0,
                    "explanation": "Error processing response",
                    "shared_functions": []
                }
                
        except Exception as e:
            print(f"Error determining similarity between {query_name} and {alt_name}: {e}")
            return {
                "is_similar": False,
                "similarity_score": 0,
                "explanation": f"Error: {str(e)}",
                "shared_functions": []
            }
    
    def process_compound(self, compound_name: str) -> Dict:
        """
        Process a single compound through the entire pipeline.
        
        Args:
            compound_name: Name of the compound
            
        Returns:
            Dictionary with processing results
        """
        results = {"name": compound_name, "cids": [], "patents": [], "descriptors": []}
        
        # Step 1: Get PubChem CIDs
        cids = self.get_pubchem_cid(compound_name)
        results["cids"] = cids
        
        if not cids:
            print(f"No CIDs found for {compound_name}")
            return results
        
        # Step 2: Get patent IDs (up to 10 per CID)
        all_patent_ids = []
        for cid in cids[:3]:  # Limit to first 3 CIDs to avoid excessive API calls
            patent_ids = self.get_patent_ids(cid)
            all_patent_ids.extend(patent_ids)
        
        # Deduplicate and limit to 10 total
        unique_patent_ids = list(set(all_patent_ids))[:10]
        results["patents"] = unique_patent_ids
        
        if not unique_patent_ids:
            print(f"No patents found for {compound_name}")
            return results
        
        # Step 3: Scrape patent info and generate descriptors
        all_descriptors = []
        for patent_id in unique_patent_ids[:3]:  # Limit to first 3 patents
            patent_info = self.scrape_patent_info(patent_id)
            if any(patent_info.values()):  # If we got any useful info
                descriptors = self.generate_functional_descriptors(patent_info, compound_name)
                all_descriptors.extend(descriptors)
        
        # Deduplicate descriptors
        unique_descriptors = list(set(all_descriptors))
        results["descriptors"] = unique_descriptors
        
        return results
    
    def run_pipeline(self, sample_size: Optional[int] = None) -> Dict:
        """
        Run the full validation pipeline on all compounds in the DataFrame.
        
        Args:
            sample_size: Optional number of rows to process (for testing)
            
        Returns:
            Dictionary with validation results
        """
        # Process the dataframe
        df_to_process = self.df.head(sample_size) if sample_size else self.df
        
        # Process query compounds
        print("Processing query compounds...")
        query_results = {}
        for idx, row in tqdm(df_to_process.iterrows(), total=len(df_to_process)):
            query_name = row[self.query_col]
            if query_name not in query_results:
                query_results[query_name] = self.process_compound(query_name)
        
        self.query_results = query_results
        
        # Process alternative compounds
        print("Processing alternative compounds...")
        alt_results = {}
        for idx, row in tqdm(df_to_process.iterrows(), total=len(df_to_process)):
            for alt_col in self.alternatives_cols:
                alt_name = row[alt_col]
                if pd.notna(alt_name) and alt_name not in alt_results:
                    alt_results[alt_name] = self.process_compound(alt_name)
        
        self.alternative_results = alt_results
        
        # Determine functional similarity
        print("Determining functional similarity...")
        similarity_results = []
        for idx, row in tqdm(df_to_process.iterrows(), total=len(df_to_process)):
            query_name = row[self.query_col]
            query_descriptors = query_results.get(query_name, {}).get("descriptors", [])
            
            for alt_col in self.alternatives_cols:
                alt_name = row[alt_col]
                if pd.notna(alt_name):
                    alt_descriptors = alt_results.get(alt_name, {}).get("descriptors", [])
                    
                    if query_descriptors and alt_descriptors:
                        similarity = self.determine_functional_similarity(
                            query_descriptors, alt_descriptors, query_name, alt_name
                        )
                        
                        similarity_results.append({
                            "query": query_name,
                            "alternative": alt_name,
                            "is_similar": similarity.get("is_similar", False),
                            "similarity_score": similarity.get("similarity_score", 0),
                            "explanation": similarity.get("explanation", ""),
                            "shared_functions": similarity.get("shared_functions", [])
                        })
        
        self.similarity_results = similarity_results
        
        # Create a summary DataFrame
        summary_df = pd.DataFrame(similarity_results)
        
        # Return all results
        return {
            "query_results": query_results,
            "alternative_results": alt_results,
            "similarity_results": similarity_results,
            "summary_df": summary_df
        }
    
    def save_results(self, output_dir: str = "validation_results"):
        """Save all results to CSV files."""
        os.makedirs(output_dir, exist_ok=True)
        
        # Convert nested dictionaries to DataFrames
        query_df = pd.DataFrame([
            {"name": name, "cids": ",".join(data["cids"]), 
             "patents": ",".join(data["patents"]), 
             "descriptors": ",".join(data["descriptors"])}
            for name, data in self.query_results.items()
        ])
        
        alt_df = pd.DataFrame([
            {"name": name, "cids": ",".join(data["cids"]), 
             "patents": ",".join(data["patents"]), 
             "descriptors": ",".join(data["descriptors"])}
            for name, data in self.alternative_results.items()
        ])
        
        sim_df = pd.DataFrame(self.similarity_results)
        
        # Save to CSV
        query_df.to_csv(f"{output_dir}/query_compounds.csv", index=False)
        alt_df.to_csv(f"{output_dir}/alternative_compounds.csv", index=False)
        sim_df.to_csv(f"{output_dir}/similarity_results.csv", index=False)
        
        print(f"Results saved to {output_dir}/")

def main():
    GOOGLE_API_KEY = "AIzaSyAdfhL-mt4l_Yt2Dz5eaWRNfbZcPQxzn6Q"
    
    df = alt_df.copy()
    
    # Initialize validation framework
    framework = PatentValidationFramework(df, api_key=GOOGLE_API_KEY, query_col="drug")
    
    # Run validation pipeline (with a small sample for testing)
    results = framework.run_pipeline(sample_size=5)
    
    # Print summary
    summary_df = results["summary_df"]
    print(f"Total compounds analyzed: {len(summary_df)}")
    print(f"Functionally similar compounds: {summary_df['is_similar'].sum()}")
    print(f"Average similarity score: {summary_df['similarity_score'].mean():.2f}")
    
    # Save results
    framework.save_results()

if __name__ == "__main__":
    main()

Processing query compounds...


 40%|████      | 2/5 [01:17<02:01, 40.64s/it]

Error retrieving patent IDs for CID 5311128: 504 Server Error: PUGREST.Timeout for url: https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/5311128/xrefs/PatentID/JSON


 60%|██████    | 3/5 [01:48<01:13, 36.55s/it]

No patents found for Goserelin


100%|██████████| 5/5 [03:01<00:00, 36.33s/it]


Processing alternative compounds...


  0%|          | 0/5 [00:00<?, ?it/s]

Error retrieving patent IDs for CID 91976991: 404 Client Error: PUGREST.NotFound for url: https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/cid/91976991/xrefs/PatentID/JSON


100%|██████████| 5/5 [06:55<00:00, 83.02s/it]


Determining functional similarity...


100%|██████████| 5/5 [00:41<00:00,  8.26s/it]

Total compounds analyzed: 11
Functionally similar compounds: 1
Average similarity score: 12.27
Results saved to validation_results/



