# Network Graph & Heatmap Visualizations

This notebook creates three visualizations:
1. **Network Graph**: Shows trial relationships as a graph
2. **Similarity Heatmap**: Matrix view of trial similarities
3. **Focused Network** (optional): The most connected trials, with similarity scores on the edges

We'll use a small subset (5 seed trials plus up to 3 similar trials each, so at most 20 trials) for fast generation.

In [None]:
import os

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle
from neo4j import GraphDatabase

# Plot styling shared by every figure in this notebook
sns.set_style('white')
plt.rcParams['figure.figsize'] = (12, 10)

# Neo4j connection details.
# Read from the environment so real credentials do not have to be committed
# with the notebook; the fallbacks are the original local-development values,
# so an unconfigured environment behaves exactly as before.
URI = os.environ.get("NEO4J_URI", "neo4j://127.0.0.1:7687")
AUTH = (
    os.environ.get("NEO4J_USER", "neo4j"),
    os.environ.get("NEO4J_PASSWORD", "12345678"),
)

## Step 1: Get a Sample of Trials

In [None]:
def find_similar_trials_jaccard(driver, trial_id, top_n=10):
    """
    Rank other trials by Jaccard similarity to the given trial.

    Each trial (a ``SubjectNode``) is described by the set of distinct
    ``ObjectNode`` neighbours it reaches over ``RELATIONSHIP`` edges. The
    score for a pair is ``|intersection| / |union|`` of those two sets.

    Args:
        driver: Object whose ``session()`` returns a context manager exposing
            ``run(query, **params)``.
        trial_id: ``name`` property of the ``SubjectNode`` to compare against.
        top_n: Upper bound on the number of rows the query returns.

    Returns:
        A list of ``(trial_name, similarity)`` tuples in the order the query
        yields them (highest similarity first); only scores above 0 are kept.
    """
    query = """
    MATCH (input:SubjectNode {name: $trial_id})
    MATCH (input)-[:RELATIONSHIP]-(inputNeighbor:ObjectNode)
    WITH input, COLLECT(DISTINCT inputNeighbor) AS inputNeighbors
    
    MATCH (other:SubjectNode)
    WHERE other <> input
    
    MATCH (other)-[:RELATIONSHIP]-(otherNeighbor:ObjectNode)
    WITH input, inputNeighbors, other, COLLECT(DISTINCT otherNeighbor) AS otherNeighbors
    
    WITH input, other,
         inputNeighbors,
         otherNeighbors,
         [n IN inputNeighbors WHERE n IN otherNeighbors] AS intersection
    WITH input, other,
         SIZE(intersection) AS intersectionSize,
         SIZE(inputNeighbors) + SIZE(otherNeighbors) - SIZE(intersection) AS unionSize
    
    WITH other.name AS similarTrial,
         CASE WHEN unionSize = 0 THEN 0.0 
              ELSE toFloat(intersectionSize) / toFloat(unionSize) 
         END AS similarity
    
    WHERE similarity > 0
    RETURN similarTrial, similarity
    ORDER BY similarity DESC
    LIMIT $top_n
    """

    ranked = []
    with driver.session() as session:
        # Consume the result while the session is still open.
        for row in session.run(query, trial_id=trial_id, top_n=top_n):
            ranked.append((row["similarTrial"], row["similarity"]))
    return ranked

# Collect a small working set of trials: 5 seed trials plus, for each seed,
# its 3 most similar trials. Produces `seed_trials`, `all_trials` and
# `similarity_edges` for the cells below.
print("Fetching sample trials from database...")

driver = None
try:
    driver = GraphDatabase.driver(URI, auth=AUTH)
    
    # Get 5 seed trials
    # NOTE(review): the query has no ORDER BY, so which five nodes come back
    # is left to the database -- confirm this is acceptable for reproducibility.
    with driver.session() as session:
        result = session.run("""
            MATCH (n:SubjectNode)
            RETURN n.name AS trial_id
            LIMIT 5
        """)
        seed_trials = [record["trial_id"] for record in result]
    
    print(f"Selected {len(seed_trials)} seed trials: {seed_trials}")
    
    # For each seed, get top 3 similar trials
    all_trials = set(seed_trials)
    similarity_edges = []  # (seed, similar_trial, jaccard_score) triples
    
    for trial in seed_trials:
        print(f"\nFinding similar trials for {trial}...")
        similar = find_similar_trials_jaccard(driver, trial, top_n=3)
        
        for sim_trial, score in similar:
            all_trials.add(sim_trial)
            similarity_edges.append((trial, sim_trial, score))
            print(f"  - {sim_trial}: {score:.4f}")
    
    print(f"\n✓ Total unique trials collected: {len(all_trials)}")
    print(f"✓ Total similarity edges: {len(similarity_edges)}")
    
except Exception as e:
    # Surface the failure in the cell output, then re-raise so later cells do
    # not run against missing data.
    print(f"Error: {e}")
    raise
finally:
    # Close the driver on every path; previously it leaked whenever a query
    # raised before reaching driver.close().
    if driver is not None:
        driver.close()

## Step 2: Create Similarity Matrix for Heatmap

In [None]:
# Build the square similarity matrix used by the heatmap.
# Rows/columns follow `trials_list` (sorted trial names).
print("Building similarity matrix...\n")

trials_list = sorted(all_trials)
n = len(trials_list)

# Initialize matrix with zeros
similarity_matrix = np.zeros((n, n))

# Create trial index mapping
trial_to_idx = {trial: i for i, trial in enumerate(trials_list)}

# Fill matrix with the seed -> similar-trial scores already collected.
# Only one direction is stored here; the loop below mirrors or looks up the rest.
for trial1, trial2, score in similarity_edges:
    similarity_matrix[trial_to_idx[trial1], trial_to_idx[trial2]] = score

# Compute missing similarities (optional - can be slow)
print("Computing additional similarities for complete matrix...")

driver = None
try:
    driver = GraphDatabase.driver(URI, auth=AUTH)
    
    computed = 0
    for i, trial1 in enumerate(trials_list):
        # Scores for trial1 are fetched at most once per row (and only if the
        # row has a gap) instead of once per missing cell: the query result
        # does not depend on trial2.
        row_scores = None
        for j, trial2 in enumerate(trials_list):
            if i == j or similarity_matrix[i, j] != 0:
                continue
            
            # Jaccard is symmetric, so reuse the mirrored cell when it is known
            if similarity_matrix[j, i] > 0:
                similarity_matrix[i, j] = similarity_matrix[j, i]
                continue
            
            if row_scores is None:
                # NOTE: the query is limited to the top `n` matches across the
                # whole database, so a pair that ranks below those stays at 0.
                row_scores = dict(find_similar_trials_jaccard(driver, trial1, top_n=n))
            
            if trial2 in row_scores:
                similarity_matrix[i, j] = row_scores[trial2]
                computed += 1
    
    print(f"✓ Computed {computed} additional similarities")
    
except Exception as e:
    # Deliberately best-effort: the heatmap can still be drawn from whatever
    # was filled in before the failure.
    print(f"Warning: {e}")
    print("Continuing with partial matrix...")
finally:
    # Release the driver even when a query failed part-way through
    if driver is not None:
        driver.close()

# Set diagonal to 1 (trial is 100% similar to itself)
np.fill_diagonal(similarity_matrix, 1.0)

print(f"\n✓ Similarity matrix shape: {similarity_matrix.shape}")
print(f"✓ Non-zero entries: {np.count_nonzero(similarity_matrix)} / {n*n}")

## Step 3: Create Similarity Heatmap

In [None]:
# Draw the similarity matrix as an annotated heatmap and save it to disk.
# An explicit figure/axes pair is used so every call targets this plot rather
# than whatever pyplot currently considers "current".
fig, ax = plt.subplots(figsize=(14, 12))

# Shorter tick labels for readability: keep only the last 8 characters
labels = [trial[-8:] for trial in trials_list]

sns.heatmap(similarity_matrix,
            annot=True,
            fmt='.2f',
            cmap='YlOrRd',
            xticklabels=labels,
            yticklabels=labels,
            linewidths=0.5,
            linecolor='gray',
            cbar_kws={'label': 'Jaccard Similarity'},
            vmin=0,
            vmax=1,
            square=True,
            ax=ax)

ax.set_title('Clinical Trial Similarity Heatmap (Jaccard Index)',
             fontsize=16, fontweight='bold', pad=20)
ax.set_xlabel('Trial ID (last 8 digits)', fontsize=12, labelpad=10)
ax.set_ylabel('Trial ID (last 8 digits)', fontsize=12, labelpad=10)

# Rotate labels for better readability
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
plt.setp(ax.get_yticklabels(), rotation=0)

fig.tight_layout()
fig.savefig('similarity_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Heatmap saved as 'similarity_heatmap.png'")

## Step 4: Create Network Graph

In [None]:
# Build an undirected similarity network from the collected trials and draw it.
# Leaves `G`, `threshold` and `degrees` defined for the cells that follow.
print("Building network graph...\n")

G = nx.Graph()

# Add every collected trial as a node, including ones that end up with no edge
G.add_nodes_from(all_trials)

# Add edges (only above threshold to reduce clutter)
threshold = 0.2  # Only show edges with similarity > 0.2

for trial1, trial2, score in similarity_edges:
    if score > threshold:
        G.add_edge(trial1, trial2, weight=score)

# Degree = number of drawn connections per trial
degrees = dict(G.degree())
node_count = G.number_of_nodes()
# Guard the average: an empty sample would otherwise raise ZeroDivisionError
average_degree = sum(degrees.values()) / node_count if node_count else 0.0

print("Network Statistics:")
print(f"  - Nodes: {node_count}")
print(f"  - Edges: {G.number_of_edges()}")
print(f"  - Average degree: {average_degree:.2f}")

# Node sizes grow with degree; the constant keeps isolated nodes visible
node_sizes = [degrees[node] * 300 + 200 for node in G.nodes()]

# Edge widths scale with similarity
edges = G.edges()
weights = [G[u][v]['weight'] for u, v in edges]
edge_widths = [w * 4 for w in weights]

# Use different colors for seed trials vs discovered trials
seed_trial_set = set(seed_trials)  # set membership instead of scanning the list per node
node_colors = ['#FF6B6B' if node in seed_trial_set else '#4ECDC4' for node in G.nodes()]

# Create plot
fig, ax = plt.subplots(figsize=(16, 14))

# Spring layout; the fixed seed keeps node positions reproducible between runs
pos = nx.spring_layout(G, k=2, iterations=50, seed=42)

# Draw edges first (so they appear behind nodes)
nx.draw_networkx_edges(G, pos,
                       width=edge_widths,
                       alpha=0.6,
                       edge_color='gray',
                       ax=ax)

# Draw nodes
nx.draw_networkx_nodes(G, pos,
                       node_size=node_sizes,
                       node_color=node_colors,
                       alpha=0.9,
                       edgecolors='black',
                       linewidths=2,
                       ax=ax)

# Draw labels with shorter names (last 8 characters)
labels_dict = {node: node[-8:] for node in G.nodes()}
nx.draw_networkx_labels(G, pos,
                        labels=labels_dict,
                        font_size=9,
                        font_weight='bold',
                        ax=ax)

ax.set_title('Clinical Trial Similarity Network\n(Node size = # of connections, Edge width = similarity strength)',
             fontsize=16, fontweight='bold', pad=20)

# Add legend (Line2D comes from the notebook's import cell)
legend_elements = [
    Line2D([0], [0], marker='o', color='w', label='Seed Trial',
           markerfacecolor='#FF6B6B', markersize=12, markeredgecolor='black', markeredgewidth=2),
    Line2D([0], [0], marker='o', color='w', label='Similar Trial',
           markerfacecolor='#4ECDC4', markersize=12, markeredgecolor='black', markeredgewidth=2)
]
ax.legend(handles=legend_elements, loc='upper right', fontsize=11)

ax.axis('off')
fig.tight_layout()
fig.savefig('network_graph.png', dpi=300, bbox_inches='tight', facecolor='white')
plt.show()

print("\n✓ Network graph saved as 'network_graph.png'")

## Step 5: Create Enhanced Network with Edge Labels (Optional)

In [None]:
# Alternative view: a smaller network with the similarity score printed on
# each edge. Keeps only the most connected trials for clarity.
top_k = 8  # maximum number of trials to keep in the focused view

top_trials = sorted(degrees.items(), key=lambda x: x[1], reverse=True)[:top_k]
top_trial_ids = [trial for trial, degree in top_trials]

print(f"Creating focused network with top {len(top_trial_ids)} trials...\n")
print("Most connected trials:")
for trial, degree in top_trials:
    print(f"  - {trial}: {degree} connections")

# Independent copy of the induced subgraph, so the full graph stays untouched
G_small = G.subgraph(top_trial_ids).copy()

fig, ax = plt.subplots(figsize=(14, 12))

# Layout (fixed seed for reproducible positions)
pos_small = nx.spring_layout(G_small, k=3, iterations=50, seed=42)

# Draw edges
edges_small = G_small.edges()
weights_small = [G_small[u][v]['weight'] for u, v in edges_small]
edge_widths_small = [w * 5 for w in weights_small]

nx.draw_networkx_edges(G_small, pos_small,
                       width=edge_widths_small,
                       alpha=0.5,
                       edge_color='gray',
                       ax=ax)

# Draw nodes; degrees are recomputed because the subgraph drops some edges
degrees_small = dict(G_small.degree())
node_sizes_small = [degrees_small[node] * 500 + 300 for node in G_small.nodes()]
node_colors_small = ['#FF6B6B' if node in seed_trials else '#4ECDC4' for node in G_small.nodes()]

nx.draw_networkx_nodes(G_small, pos_small,
                       node_size=node_sizes_small,
                       node_color=node_colors_small,
                       alpha=0.9,
                       edgecolors='black',
                       linewidths=2.5,
                       ax=ax)

# Draw labels (full trial names; there is room in the smaller graph)
nx.draw_networkx_labels(G_small, pos_small,
                        font_size=10,
                        font_weight='bold',
                        ax=ax)

# Draw edge labels (similarity scores)
edge_labels = {(u, v): f"{G_small[u][v]['weight']:.2f}" for u, v in edges_small}
nx.draw_networkx_edge_labels(G_small, pos_small,
                             edge_labels,
                             font_size=8,
                             bbox=dict(boxstyle='round,pad=0.3',
                                       facecolor='yellow',
                                       alpha=0.7,
                                       edgecolor='none'),
                             ax=ax)

# Report the real number of trials shown: fewer than `top_k` may exist
ax.set_title(f'Focused Network: Top {len(top_trial_ids)} Most Connected Trials\n(Edge labels show Jaccard similarity)',
             fontsize=16, fontweight='bold', pad=20)

# Legend
legend_elements = [
    Line2D([0], [0], marker='o', color='w', label='Seed Trial',
           markerfacecolor='#FF6B6B', markersize=12, markeredgecolor='black', markeredgewidth=2),
    Line2D([0], [0], marker='o', color='w', label='Similar Trial',
           markerfacecolor='#4ECDC4', markersize=12, markeredgecolor='black', markeredgewidth=2)
]
ax.legend(handles=legend_elements, loc='upper right', fontsize=12)

ax.axis('off')
fig.tight_layout()
fig.savefig('network_graph_focused.png', dpi=300, bbox_inches='tight', facecolor='white')
plt.show()

print("\n✓ Focused network saved as 'network_graph_focused.png'")

## Step 6: Summary

In [None]:
# Print a recap of the generated files and basic network statistics.
banner = "=" * 70

print("\n" + banner)
print("VISUALIZATION SUMMARY")
print(banner)

# The focused view keeps at most 8 trials; report how many there actually are
focused_count = min(8, len(degrees))

print("\n📊 Generated Visualizations:")
print(f"   1. similarity_heatmap.png - Matrix view of {len(trials_list)}x{len(trials_list)} similarities")
print(f"   2. network_graph.png - Full network with {G.number_of_nodes()} nodes")
print(f"   3. network_graph_focused.png - Focused view of top {focused_count} trials")

print("\n🎨 Key Features:")
print("   - Heatmap uses color intensity to show similarity strength")
print("   - Network node size represents number of connections")
print("   - Network edge width represents similarity strength")
print("   - Red nodes = seed trials, Teal nodes = discovered similar trials")

print("\n📈 Network Statistics:")
print(f"   - Total trials analyzed: {len(all_trials)}")
print(f"   - Similarity threshold: {threshold} (edges below this are hidden)")
if degrees:
    most_connected = max(degrees, key=degrees.get)
    print(f"   - Average connections per trial: {sum(degrees.values()) / len(degrees):.2f}")
    print(f"   - Most connected trial: {most_connected} ({degrees[most_connected]} connections)")
else:
    # An empty network would otherwise fail on the division and on max()
    print("   - No trials in the network; connection statistics unavailable")

print("\n" + banner)
print("✓ All visualizations generated successfully!")
print(banner)