## Import Libraries

In [None]:
#check for gpu
import torch
import pandas as pd
import numpy as np
import xml.etree.ElementTree as ET
import torch
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import base64
import os
from google import genai
from google.genai import types

In [6]:
if torch.backends.mps.is_available():
   mps_device = torch.device("mps")
   x = torch.ones(1, device=mps_device)
   print (x)
else:
   print ("MPS device not found.")

tensor([1.], device='mps:0')


## Using DrugBank Cosine Similarity

In [None]:



# Check for GPU support on Mac
device = "mps" if torch.backends.mps.is_available() else "cpu"

# Load and parse DrugBank XML file
def parse_drugbank(xml_file):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    
    drugs = []
    for drug in root.findall("{http://www.drugbank.ca}drug"):
        drug_name = drug.find("{http://www.drugbank.ca}name")
        drug_id = drug.find("{http://www.drugbank.ca}drugbank-id")
        indications = drug.find("{http://www.drugbank.ca}indication")

        # Handle missing values safely
        drug_name = drug_name.text.strip() if drug_name is not None else "Unknown"
        drug_id = drug_id.text.strip() if drug_id is not None else "Unknown"
        indications = indications.text.strip() if (indications is not None and indications.text) else "No Indications"

        drugs.append({"Drug": drug_name, "DrugID": drug_id, "Indications": indications})
    
    return pd.DataFrame(drugs)

# Parse DrugBank XML and create a DataFrame
drug_data = parse_drugbank("../Dataset/full database.xml")
drug_data.head()  

# Initialize BioBERT/SciBERT model for text embeddings on GPU
bert_model = SentenceTransformer("all-mpnet-base-v2", device=device)

def get_drug_embedding(drug_name, indications):
    """Generate an embedding for a drug based on its name and indications."""
    text = f"{drug_name}: {indications}"
    return bert_model.encode(text, convert_to_numpy=True)

# Compute embeddings for all drugs
drug_data["Embeddings"] = drug_data.apply(lambda row: get_drug_embedding(row["Drug"], row["Indications"]), axis=1)

# Convert embeddings into a NumPy matrix for fast similarity computation
embeddings_matrix = np.vstack(drug_data["Embeddings"].values)

# Function to find similar drugs with indications
def find_similar_drugs(target_drug, top_n=5):
    if target_drug not in drug_data["Drug"].values:
        return f"Drug '{target_drug}' not found in DrugBank dataset."

    target_embedding = drug_data.loc[drug_data['Drug'] == target_drug, "Embeddings"].values[0]
    
    similarities = cosine_similarity([target_embedding], embeddings_matrix)[0]
    
    # Store results in a DataFrame for sorting
    results_df = drug_data.copy()
    results_df["Similarity"] = similarities
    
    similar_drugs = results_df.sort_values(by="Similarity", ascending=False)[1:top_n+1]  
    
    return similar_drugs[["Drug", "Similarity", "Indications"]]

print(f"Running on device: {device}")  

  from .autonotebook import tqdm as notebook_tqdm


Running on device: mps


In [4]:
similar_drugs = find_similar_drugs("Abarelix")
similar_drugs

Unnamed: 0,Drug,Similarity,Indications
1175,Estramustine,0.658857,For the palliative treatment of patients with ...
5487,Atrasentan,0.653181,Investigated for use/treatment in prostate can...
5829,Degarelix,0.640171,"In Canada and the US, degarelix is indicated f..."
5946,Triptorelin,0.636091,Triptorelin is indicated for the palliative tr...
9938,Apalutamide,0.635937,Apalutamide is indicated for the treatment of ...


**Gemini API**

In [7]:
def query_gemini(prompt):
    client = genai.Client(
        api_key= "AIzaSyAlWzliqQKSKyRW2xBM9op_fToJNyId7R0",
    )

    model = "gemini-2.0-flash"
    contents = [
        types.Content(
            role="user",
            parts=[
                types.Part.from_text(text=prompt),
            ],
        ),
    ]
    generate_content_config = types.GenerateContentConfig(
        temperature=0,
        top_p=0.95,
        top_k=40,
        max_output_tokens=8192,
        response_mime_type="text/plain",
    )

    for chunk in client.models.generate_content_stream(
        model=model,
        contents=contents,
        config=generate_content_config,
    ):
        print(chunk.text, end="")
    return chunk.text

In [10]:
def recommend_best_alternative(target_drug, similar_drugs):
    target_row = drug_data[drug_data['Drug'] == target_drug]
    prompt = f"The drug {target_drug} is used to treat {target_row['Indications'].values[0]}.\n\n"
    prompt += "Here are five alternative drugs with similar properties:\n"

    for i, (drug, indication) in enumerate(zip(similar_drugs['Drug'], similar_drugs['Indications'])):
        prompt += f"{i+1}. {drug} - {indication}\n"

    prompt += "\nBased on their indications, efficacy, and side effect profiles, which drug would you recommend. Only Give name?"

    return query_gemini(prompt)
  
find_similar_drugs("Abarelix")
recommend_best_alternative("Abarelix", find_similar_drugs("Abarelix"))

Degarelix


'arelix\n'