# Finding Similar Clinical Trials Using Graph Embeddings (Alternative Method)

This notebook builds graph embeddings from **custom random walks + TF-IDF**. It needs **no Node2Vec and no gensim**, which avoids the NumPy compatibility issues that can come with those packages.

## Install Required Libraries (Minimal Dependencies)

In [None]:
!pip install neo4j networkx scikit-learn

In [None]:
from neo4j import GraphDatabase
import networkx as nx
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import random
import numpy as np

# Connection details for LOCAL Neo4j
# URI points at a Neo4j instance on localhost, port 7687.
URI = "neo4j://127.0.0.1:7687"
# Passed as `auth=` to GraphDatabase.driver; presumably (username, password).
# NOTE(review): credentials are hard-coded here - consider loading them from
# the environment before sharing this notebook.
AUTH = ("neo4j", "12345678")

## Step 1: Fetch Graph from Neo4j

In [None]:
def fetch_graph_from_neo4j(driver):
    """
    Fetch the entire graph from Neo4j and convert it to an undirected NetworkX graph.

    Every pair of nodes joined by a ``RELATIONSHIP`` edge is read and keyed by
    the nodes' ``name`` property. Each node stores its first Neo4j label (or
    ``"Unknown"`` when it has none) in the ``label`` node attribute.

    Rows in which either endpoint has no ``name`` are skipped, because
    NetworkX does not accept ``None`` as a node.

    Args:
        driver: An open Neo4j driver. A session is opened and closed here;
            the driver itself is left open for the caller to close.

    Returns:
        nx.Graph: The graph built from the query results.
    """
    query = """
    MATCH (n1)-[r:RELATIONSHIP]-(n2)
    RETURN n1.name AS node1, n2.name AS node2,
           labels(n1) AS labels1, labels(n2) AS labels2
    """

    G = nx.Graph()

    with driver.session() as session:
        result = session.run(query)

        print("Loading graph from Neo4j...")
        # Counts result rows that were added; the final totals printed below
        # come from the graph itself, which de-duplicates repeated edges.
        edge_count = 0
        skipped = 0

        for record in result:
            node1 = record["node1"]
            node2 = record["node2"]

            # A missing `name` property is returned as None, which would make
            # G.add_node raise and abort the whole load.
            if node1 is None or node2 is None:
                skipped += 1
                continue

            labels1 = list(record["labels1"])
            labels2 = list(record["labels2"])

            # Add nodes with their labels
            G.add_node(node1, label=labels1[0] if labels1 else "Unknown")
            G.add_node(node2, label=labels2[0] if labels2 else "Unknown")

            # Add edge
            G.add_edge(node1, node2)
            edge_count += 1

            if edge_count % 5000 == 0:
                print(f"Loaded {edge_count} edges...")

        if skipped:
            print(f"Skipped {skipped} rows with a missing node name")

        print(f"\n✓ Graph loaded successfully!")
        print(f"Total nodes: {G.number_of_nodes()}")
        print(f"Total edges: {G.number_of_edges()}")

    return G

## Step 2: Custom Random Walk Implementation

In [None]:
def random_walk(G, start_node, walk_length=30, rng=None):
    """
    Perform a single uniform random walk starting from start_node.

    Args:
        G: Graph object exposing ``neighbors(node)`` (e.g. a NetworkX graph).
        start_node: Node the walk starts at; always the first element returned.
        walk_length: Maximum number of nodes in the walk, start node included.
        rng: Optional ``random.Random``-like object providing ``choice``.
            Pass a seeded instance to make the walk reproducible. When
            omitted, the module-level ``random`` generator is used.

    Returns:
        list: The visited nodes in order. It is shorter than ``walk_length``
        when the walk reaches a node that has no neighbors.
    """
    chooser = random if rng is None else rng
    walk = [start_node]

    # The start node already occupies one slot, hence walk_length - 1 steps.
    for _ in range(walk_length - 1):
        neighbors = list(G.neighbors(walk[-1]))

        # Dead end: nothing to step to, so the walk stops early.
        if not neighbors:
            break

        walk.append(chooser.choice(neighbors))

    return walk


def generate_walks(G, num_walks=100, walk_length=30):
    """
    Build the random-walk corpus: num_walks walks from every node of G.

    Walks are produced node by node, in the order G.nodes() yields them,
    and a progress line is printed after every 1000 nodes.

    Returns a list of walks, each walk being a list of nodes.
    """
    print(f"\nGenerating random walks...")
    print(f"Parameters: num_walks={num_walks}, walk_length={walk_length}")

    all_nodes = list(G.nodes())
    total = len(all_nodes)
    walks = []

    for position, source in enumerate(all_nodes, start=1):
        if position % 1000 == 0:
            print(f"Generated walks for {position}/{total} nodes...")

        # All walks for one source are appended before moving to the next.
        walks.extend(
            random_walk(G, source, walk_length) for _ in range(num_walks)
        )

    print(f"✓ Generated {len(walks)} random walks!")
    return walks

## Step 3: Create Embeddings Using TF-IDF

In [None]:
def create_node_contexts(walks, G):
    """
    Turn random walks into one text "document" per node.

    For every occurrence of a node in a walk, all the *other* nodes of that
    walk are appended to the node's context. The context is then joined with
    single spaces. A node that never gathered any context gets itself as its
    document.

    Returns a dict mapping each node of G to its document.
    """
    print("\nCreating node context documents...")

    # Start every graph node with an empty context so none is left out.
    gathered = {}
    for graph_node in G.nodes():
        gathered[graph_node] = []

    for path in walks:
        for center in path:
            gathered[center] += [other for other in path if other != center]

    # Space-separated node names; nodes with no context fall back to themselves.
    node_documents = {
        name: ' '.join(companions) if companions else name
        for name, companions in gathered.items()
    }

    print(f"✓ Created context documents for {len(node_documents)} nodes")
    return node_documents


def generate_tfidf_embeddings(node_documents, max_features=1000):
    """
    Generate TF-IDF embeddings from node context documents.

    Args:
        node_documents: Mapping of node -> text document for that node.
        max_features: Vocabulary size limit handed to ``TfidfVectorizer``.
            Defaults to 1000, the value this notebook previously hard-coded.

    Returns:
        dict: Mapping of node -> that node's row of the TF-IDF matrix.
        An empty mapping is returned when ``node_documents`` is empty.
    """
    print("\nGenerating TF-IDF embeddings...")

    # The vectorizer cannot be fitted on an empty corpus, so bail out early
    # instead of surfacing its error for what is simply "nothing to embed".
    if not node_documents:
        print("✓ No documents supplied; no embeddings generated")
        return {}

    nodes = list(node_documents)
    documents = [node_documents[node] for node in nodes]

    # Create TF-IDF vectorizer
    vectorizer = TfidfVectorizer(max_features=max_features)
    tfidf_matrix = vectorizer.fit_transform(documents)

    # Row i of the matrix belongs to nodes[i] because both follow the same order.
    embeddings = {node: tfidf_matrix[i] for i, node in enumerate(nodes)}

    print(f"✓ Generated embeddings with {tfidf_matrix.shape[1]} dimensions")
    return embeddings

## Step 4: Find Similar Trials

In [None]:
def get_subject_nodes(G):
    """
    Return the nodes whose 'label' attribute is 'SubjectNode' (clinical trials),
    in the order the graph yields them.
    """
    trials = []
    for name, attributes in G.nodes(data=True):
        if attributes.get('label') == 'SubjectNode':
            trials.append(name)
    return trials


def find_similar_trials_embedding(embeddings, G, trial_id, top_n=10):
    """
    Rank the other SubjectNode trials by cosine similarity to trial_id.

    Returns a list of (trial, similarity) tuples, highest similarity first,
    holding at most top_n entries. An empty list is returned when trial_id
    has no embedding. Trials without an embedding are left out of the ranking.
    """
    # Nothing to compare against if the query trial was never embedded.
    if trial_id not in embeddings:
        return []

    candidates = get_subject_nodes(G)
    reference = embeddings[trial_id]

    scored = [
        (other, float(cosine_similarity(reference, embeddings[other])[0][0]))
        for other in candidates
        if other != trial_id and other in embeddings
    ]

    # Best matches first, then keep only the requested number.
    ranked = sorted(scored, key=lambda pair: pair[1], reverse=True)
    return ranked[:top_n]


def check_trial_exists(G, trial_id):
    """
    Tell whether trial_id is a node of G that carries the 'SubjectNode' label.
    """
    if trial_id not in G.nodes():
        return False

    attributes = G.nodes[trial_id]
    return attributes.get('label') == 'SubjectNode'

## Step 5: Load Graph and Generate Embeddings (Run Once)

In [None]:
# Load graph from Neo4j
print("Step 1: Loading graph from Neo4j...")
# The context manager closes the driver even if fetching the graph fails,
# so a failed query does not leave a connection open.
with GraphDatabase.driver(URI, auth=AUTH) as driver:
    G = fetch_graph_from_neo4j(driver)

# Generate random walks
print("\nStep 2: Generating random walks...")
walks = generate_walks(G, num_walks=50, walk_length=30)

# Create node contexts
print("\nStep 3: Creating node contexts...")
node_documents = create_node_contexts(walks, G)

# Generate embeddings
print("\nStep 4: Generating embeddings...")
embeddings = generate_tfidf_embeddings(node_documents)

print("\n" + "="*70)
print("✓ READY! You can now search for similar trials.")
print("="*70)

## Test with Single Trial ID

In [None]:
# Look up the ten closest trials for one hard-coded trial ID.
trial_id = "NCT00752622"  # Change this to any trial ID you want

if check_trial_exists(G, trial_id):
    print(f"Finding similar trials to {trial_id} using Graph Embeddings...\n")

    similar_trials = find_similar_trials_embedding(embeddings, G, trial_id, top_n=10)

    if not similar_trials:
        print("No similar trials found.")
    else:
        print(f"Top 10 Similar Trials to {trial_id}:")
        print("=" * 60)
        # One ranked line per match: position, trial ID, similarity score.
        for rank, (match, score) in enumerate(similar_trials, start=1):
            print(f"{rank:2d}. {match}: {score:.4f}")
else:
    print(f"Trial {trial_id} not found in database!")

## Test with Multiple Trial IDs

In [None]:
# Run the similarity search for a fixed list of trial IDs.
trial_ids = ["NCT00385736", "NCT00386607", "NCT03518073"]

for trial_id in trial_ids:
    # Banner separating the output of each trial.
    print(f"\n{'='*70}")
    print(f"Trial ID: {trial_id}")
    print(f"{'='*70}")

    if check_trial_exists(G, trial_id):
        similar_trials = find_similar_trials_embedding(embeddings, G, trial_id, top_n=10)

        if not similar_trials:
            print("No similar trials found.")
        else:
            print(f"\nTop 10 Similar Trials:")
            for rank, (match, score) in enumerate(similar_trials, start=1):
                print(f"{rank:2d}. {match}: Similarity = {score:.4f}")
    else:
        # Unknown IDs are reported and the loop moves on to the next one.
        print(f"Trial {trial_id} not found in database!")

## Interactive Input

In [None]:
# Ask the user for one or more trial IDs and search for each of them.
trial_input = input("Enter trial ID(s) separated by commas: ")

# Trim whitespace around every entry and drop the empty ones.
trial_ids = [
    cleaned
    for cleaned in (piece.strip() for piece in trial_input.split(','))
    if cleaned
]

for trial_id in trial_ids:
    print(f"\n{'='*70}")
    print(f"Trial ID: {trial_id}")
    print(f"{'='*70}")

    if check_trial_exists(G, trial_id):
        similar_trials = find_similar_trials_embedding(embeddings, G, trial_id, top_n=10)

        if not similar_trials:
            print("No similar trials found.")
        else:
            print(f"\n✓ Top 10 Similar Trials (Graph Embeddings):")
            for rank, (match, score) in enumerate(similar_trials, start=1):
                print(f"{rank:2d}. {match}: Similarity = {score:.4f}")
    else:
        # Unknown IDs are reported and the loop moves on to the next one.
        print(f"❌ Trial {trial_id} not found in database!")

## Compare with Jaccard Similarity (Optional)

In [None]:
def find_similar_trials_jaccard(driver, trial_id, top_n=10):
    """
    Find similar trials using Jaccard similarity (for comparison).

    The similarity between two SubjectNode trials is computed inside Neo4j as
    |shared ObjectNode neighbors| / |union of ObjectNode neighbors|.
    Trials with a similarity of 0 are excluded.

    Args:
        driver: An open Neo4j driver. A session is opened and closed here.
        trial_id: ``name`` of the SubjectNode to compare against.
        top_n: Maximum number of results to return.

    Returns:
        list: (trial name, similarity) tuples, highest similarity first.
    """
    # The second COLLECT aggregates `otherNeighbor`, the variable bound by the
    # MATCH directly above it. It previously named `otherNeighbors`, which is
    # not defined at that point in the query.
    query = """
    MATCH (input:SubjectNode {name: $trial_id})
    MATCH (input)-[:RELATIONSHIP]-(inputNeighbor:ObjectNode)
    WITH input, COLLECT(DISTINCT inputNeighbor) AS inputNeighbors
    
    MATCH (other:SubjectNode)
    WHERE other <> input
    
    MATCH (other)-[:RELATIONSHIP]-(otherNeighbor:ObjectNode)
    WITH input, inputNeighbors, other, COLLECT(DISTINCT otherNeighbor) AS otherNeighbors
    
    WITH input, other,
         inputNeighbors,
         otherNeighbors,
         [n IN inputNeighbors WHERE n IN otherNeighbors] AS intersection
    WITH input, other,
         SIZE(intersection) AS intersectionSize,
         SIZE(inputNeighbors) + SIZE(otherNeighbors) - SIZE(intersection) AS unionSize
    
    WITH other.name AS similarTrial,
         CASE WHEN unionSize = 0 THEN 0.0 
              ELSE toFloat(intersectionSize) / toFloat(unionSize) 
         END AS similarity
    
    WHERE similarity > 0
    RETURN similarTrial, similarity
    ORDER BY similarity DESC
    LIMIT $top_n
    """

    with driver.session() as session:
        result = session.run(query, trial_id=trial_id, top_n=top_n)
        # Materialize inside the session so records are read before it closes.
        return [(record["similarTrial"], record["similarity"]) for record in result]


# Compare both methods
trial_id = "NCT00385736"
# Single source for the result count, shared by both searches and the table.
top_n = 10

print(f"Comparing Jaccard vs Graph Embeddings for {trial_id}\n")

# Jaccard
# The context manager closes the driver even if the query raises.
with GraphDatabase.driver(URI, auth=AUTH) as driver:
    jaccard_results = find_similar_trials_jaccard(driver, trial_id, top_n=top_n)

# Embeddings
embedding_results = find_similar_trials_embedding(embeddings, G, trial_id, top_n=top_n)

# Display side by side
print("="*70)
print(f"{'JACCARD SIMILARITY':<35} | {'GRAPH EMBEDDINGS':<35}")
print("="*70)

# Either list may be shorter than top_n; missing rows are shown as "-".
for i in range(top_n):
    jaccard_str = f"{jaccard_results[i][0]}: {jaccard_results[i][1]:.4f}" if i < len(jaccard_results) else "-"
    embed_str = f"{embedding_results[i][0]}: {embedding_results[i][1]:.4f}" if i < len(embedding_results) else "-"
    print(f"{i+1:2d}. {jaccard_str:<32} | {embed_str:<32}")