In [8]:
pip install plotly




In [22]:
import pandas as pd
import networkx as nx
import plotly.graph_objects as go

# Directed graph that will hold every entity (node) and relationship (edge)
G = nx.DiGraph()

# Read the source table; the helper functions below look columns up by name
file_path = '/content/Task1_pubmed_secondary_metabolites_Bacteria.csv'
data = pd.read_csv(file_path)

# Define functions to add nodes and edges
def add_nodes_from_column(column_name, node_type):
    """Add one node to ``G`` for every distinct non-missing value in a column.

    Args:
        column_name: Name of the column in the module-level ``data`` frame.
        node_type: Value stored in each node's ``type`` attribute.

    A value that is already a node has its ``type`` attribute overwritten
    with ``node_type``.
    """
    column = data[column_name]
    present_values = column[column.notna()]
    for value in present_values.unique():
        G.add_node(value, type=node_type)

def add_edges_from_columns(source_col, target_col, edge_type):
    """Add a directed edge to ``G`` for every row that has both endpoints.

    Args:
        source_col: Column of the module-level ``data`` frame holding the
            edge source.
        target_col: Column of ``data`` holding the edge target.
        edge_type: Value stored in each edge's ``type`` attribute.

    Rows where either value is missing are skipped. Adding an edge that
    already exists only updates its ``type`` attribute.
    """
    # Walk the two relevant columns in lockstep instead of using iterrows(),
    # which builds a Series for every row and coerces the whole row to a
    # common dtype even though only two of its values are needed.
    for source, target in zip(data[source_col], data[target_col]):
        if pd.notna(source) and pd.notna(target):
            G.add_edge(source, target, type=edge_type)

# Register one node per distinct value, tagged with the kind of entity it is
for entity_column, entity_kind in (
    ('Gene', 'gene'),
    ('Chemical', 'chemical'),
    ('Disease', 'disease'),
    ('Species', 'species'),
):
    add_nodes_from_column(entity_column, entity_kind)

# Connect entities that occur together in a row: (source, target, relation)
for from_column, to_column, relation in (
    ('Gene', 'Disease', 'association'),   # Gene-Disease associations
    ('Chemical', 'Gene', 'interaction'),  # Chemical-Gene interactions
    ('Species', 'Gene', 'expression'),    # Species-Gene expression
):
    add_edges_from_columns(from_column, to_column, relation)

# Rank every node by its raw degree in G (the count is not normalised, so
# this is plain degree rather than degree centrality) and keep the 200
# highest-ranked nodes.
deg_centrality = dict(G.degree())
ranked_by_degree = sorted(
    deg_centrality.items(), key=lambda pair: pair[1], reverse=True
)
top_nodes = [node for node, _ in ranked_by_degree[:200]]

# Independent copy of the subgraph induced by those nodes
H = G.subgraph(top_nodes).copy()

# Spring layout; the fixed seed makes the node positions reproducible
pos = nx.spring_layout(H, k=0.3, seed=42)

# Create Plotly plot

# Define color map for different edge types
edge_color_map = {
    'association': 'rgba(255, 0, 0, 0.3)',  # Lighter red for associations
    'interaction': 'rgba(0, 255, 0, 0.3)',  # Lighter green for interactions
    'expression': 'rgba(0, 0, 255, 0.3)'    # Lighter blue for expressions
}
default_edge_color = 'rgba(0, 0, 0, 0.3)'  # Fallback for unknown edge types

# Collect edge coordinates. Every segment is terminated with None so that a
# 'lines' trace breaks between edges instead of joining the end of one edge
# to the start of the next. Segments are also grouped by edge type because a
# Scatter line takes a single colour, so each type needs its own trace.
edge_x = []
edge_y = []
edge_colors = []    # One colour per edge, in H.edges() order
edge_segments = {}  # edge type -> (x coordinates, y coordinates)
for source, target, edge_type in H.edges(data='type'):
    x0, y0 = pos[source]
    x1, y1 = pos[target]
    edge_x.extend([x0, x1, None])
    edge_y.extend([y0, y1, None])
    edge_colors.append(edge_color_map.get(edge_type, default_edge_color))
    type_x, type_y = edge_segments.setdefault(edge_type, ([], []))
    type_x.extend([x0, x1, None])
    type_y.extend([y0, y1, None])

node_x = [pos[node][0] for node in H.nodes()]
node_y = [pos[node][1] for node in H.nodes()]

# Define color map for different node types
node_color_map = {
    'gene': 'blue',
    'chemical': 'green',
    'disease': 'red',
    'species': 'purple'
}

# Nodes without a 'type' attribute (or with an unknown one) fall back to grey
node_colors = [
    node_color_map.get(H.nodes[n].get('type'), 'grey') for n in H.nodes()
]
# Marker size scales with the degree stored in deg_centrality
node_sizes = [deg_centrality[node] * 1.5 for node in H.nodes()]

# Create figure
fig = go.Figure()

# Add edges: one thin line trace per relationship type, in that type's colour
for edge_type, (type_x, type_y) in edge_segments.items():
    fig.add_trace(go.Scatter(
        x=type_x, y=type_y,
        mode='lines',
        line=dict(
            width=0.5,
            color=edge_color_map.get(edge_type, default_edge_color)
        ),
        name=f'Relationships: {edge_type}'
    ))

# Add nodes
fig.add_trace(go.Scatter(
    x=node_x, y=node_y,
    mode='markers+text',
    text=[f'{n}' for n in H.nodes()],
    textposition='top center',
    marker=dict(size=node_sizes, color=node_colors, opacity=0.8),
    name='Entities'
))

# Human-readable legend text for each node type
legend_labels = dict(
    gene='Genes (Blue)',
    chemical='Chemicals (Green)',
    disease='Diseases (Red)',
    species='Species (Purple)',
)

# The node trace uses per-point colours, so add one point-less trace per
# node type purely to get a coloured legend entry for it.
for entity_kind in node_color_map:
    legend_entry = go.Scatter(
        x=[None], y=[None],  # No visible point; only the legend entry shows
        mode='markers',
        marker={'size': 15, 'color': node_color_map[entity_kind]},
        name=legend_labels[entity_kind],
    )
    fig.add_trace(legend_entry)

# Figure-wide settings: title, legend, axes without grid or zero lines,
# white background, and a fixed 800x800 canvas.
layout_settings = {
    'title': 'Biological Knowledge Graph (Top 200 Nodes)',
    'showlegend': True,
    'xaxis': {'showgrid': False, 'zeroline': False},
    'yaxis': {'showgrid': False, 'zeroline': False},
    'paper_bgcolor': 'white',
    'width': 800,
    'height': 800,
}
fig.update_layout(**layout_settings)

# Write the figure to an HTML file, then display it
output_html = 'biological_knowledge_graph_200_resized.html'
fig.write_html(output_html)
fig.show()
