In [None]:
import math
import numpy as np
import numpy
from PIL import Image, ImageDraw
from PIL import ImagePath
import pandas as pd
import os
from os import path
from tqdm import tqdm
import json
import cv2
import matplotlib.pyplot as plt
import urllib

In [None]:
import tensorflow as tf
from tensorflow import keras
import matplotlib.pyplot as plt
import imgaug.augmenters as iaa
from skimage.measure import label, regionprops
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from multiprocessing import Pool, cpu_count
import community  # Louvain algorithm package

os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

In [None]:
# List the entries directly under the 'data' folder.
os.listdir('data')

# Step 1: Data Preprocessing


## Load the annotated cell images and masks

In [None]:
# Folder of mask images; passed to extract_protein_information in later cells.
mask_dir = 'data/bwmask'

In [None]:
def get_file_names(root_dir):
    """Collect image, mask and label file paths from the sub-folders of ``root_dir``.

    Files found under a sub-folder named ``bwmask`` are mask paths, files under
    ``label`` are label paths, and files under any other sub-folder are treated
    as images. Entries of ``root_dir`` that are not directories are ignored.

    Args:
        root_dir: Directory whose immediate sub-folders hold the files.

    Returns:
        A DataFrame with columns ``image``, ``bwmask`` and ``label``. Row ``k``
        pairs the k-th path (in sorted order) of each group; because rows are
        built with ``zip`` the table is as long as the shortest group.
    """
    image_paths = []
    mask_paths = []
    label_paths = []
    for folder in sorted(os.listdir(root_dir)):
        folder_path = path.join(root_dir, folder)
        # A stray file next to the sub-folders (e.g. a README or .DS_Store)
        # would make the inner listdir raise NotADirectoryError.
        if not path.isdir(folder_path):
            continue
        for frame in sorted(os.listdir(folder_path)):
            frame_path = path.join(folder_path, frame)
            if folder == 'bwmask':
                mask_paths.append(frame_path)
            elif folder == 'label':
                label_paths.append(frame_path)
            else:
                image_paths.append(frame_path)

    data_df = pd.DataFrame(
        data=zip(image_paths, mask_paths, label_paths),
        columns=['image', 'bwmask', 'label'],
    )
    return data_df

In [None]:
# Build the table of image/mask/label paths and show its last ten rows.
data_df = get_file_names('data/')
data_df.tail(10)

### Display the images per row

In [None]:
def strip_path(path):
  """Return the last component of a path, i.e. the bare file name.

  Args:
    path: The path to the file.

  Returns:
    The file name with all leading directories removed.
  """
  return os.path.basename(path)


def display_images(row):
    """
    Show the image, mask and label referenced by one row side by side.

    Each panel is titled with the file name and annotated with the shape of
    the loaded array.

    Args:
        row: Mapping (for example one DataFrame row) whose 'image', 'bwmask'
            and 'label' entries are file paths.

    """
    # Resolve the three paths up front so a missing key fails before any figure is created.
    file_paths = [row[column] for column in ('image', 'bwmask', 'label')]

    fig, axs = plt.subplots(1, 3, figsize=(10, 4))

    for ax, file_path in zip(axs, file_paths):
        picture = Image.open(file_path)
        pixels = numpy.array(picture)
        ax.imshow(picture)
        ax.set_title(strip_path(file_path))
        ax.axis('off')
        # Caption placed just below the panel, in axes coordinates.
        ax.text(0.5, -0.15, f'Shape: {pixels.shape}', ha='center', transform=ax.transAxes)

    plt.tight_layout()

    plt.show()

# Display images from the first row of the DataFrame
# (data_df is the path table built by get_file_names in an earlier cell).
display_images(data_df.iloc[0])


In [None]:
def tile(filename, dir_in, dir_out, d, image_or_mask):
    """Cut one image into non-overlapping ``d`` x ``d`` tiles and save every tile.

    Only whole tiles are produced: any remainder of fewer than ``d`` pixels on
    the right or bottom edge is dropped. Each tile is written to ``dir_out``
    as ``<name>_<row offset>_<column offset><ext>``.

    Args:
        filename: Name of the source file inside ``dir_in``.
        dir_in: Directory that contains the source file.
        dir_out: Directory the tiles are written to (created if missing).
        d: Tile edge length in pixels.
        image_or_mask: Prefix used for the keys of the returned records,
            e.g. ``'image'`` or ``'bwmask'``.

    Returns:
        A list with one dict per tile holding ``<prefix>_name``,
        ``<prefix>_path`` and ``sliced_<prefix>_path``.
    """
    # `product` was used here without being imported anywhere in the notebook.
    from itertools import product

    records = []
    name, ext = os.path.splitext(filename)
    source_path = os.path.join(dir_in, filename)
    os.makedirs(dir_out, exist_ok=True)

    # Context manager releases the file handle once all tiles are written.
    with Image.open(source_path) as img:
        w, h = img.size
        # i walks rows (top offsets), j walks columns (left offsets).
        for i, j in product(range(0, h - h % d, d), range(0, w - w % d, d)):
            out = os.path.join(dir_out, f'{name}_{i}_{j}{ext}')
            img.crop((j, i, j + d, i + d)).save(out)
            records.append({
                f'{image_or_mask}_name': filename,
                f'{image_or_mask}_path': source_path,
                f'sliced_{image_or_mask}_path': out,
            })
    return records


In [None]:

def process_images(data_df, dir_in, dir_out, d):
    """Tile every image/mask pair listed in ``data_df`` and index the tiles.

    Tiles are written to ``<dir_out>/image`` and ``<dir_out>/bwmask``.

    Args:
        data_df: Table with ``image`` and ``bwmask`` path columns.
        dir_in: Directory in which the files (by base name) are looked up.
        dir_out: Root directory for the tiled output.
        d: Tile edge length in pixels.

    Returns:
        A DataFrame with one row per tile, combining the mask record and the
        image record produced by ``tile`` for the same tile position.
    """
    image_out = os.path.join(dir_out, "image")
    mask_out = os.path.join(dir_out, "bwmask")
    # Create the output folders once instead of re-checking on every row.
    os.makedirs(image_out, exist_ok=True)
    os.makedirs(mask_out, exist_ok=True)

    records = []
    for _, row in tqdm(data_df.iterrows(), total=len(data_df)):
        # NOTE(review): image and mask are both read from the same dir_in using
        # only their base names; confirm this matches the on-disk layout.
        image_tiles = tile(os.path.basename(row['image']), dir_in, image_out, d, 'image')
        mask_tiles = tile(os.path.basename(row['bwmask']), dir_in, mask_out, d, 'bwmask')

        # One merged record per tile position. The previous code appended the
        # whole mask list on every iteration, duplicating each row many times.
        for mask_tile, image_tile in zip(mask_tiles, image_tiles):
            records.append({**mask_tile, **image_tile})

    processed_df = pd.DataFrame(records)
    return processed_df


In [None]:
# Example (disabled): define dir_in and dir_out before enabling these lines.
# processed_df = process_images(data_df, dir_in, dir_out, d=256)
# processed_df

## Extract protein localization and interaction information from the masks


In [None]:
from skimage.measure import label, regionprops
import matplotlib.pyplot as plt

def extract_protein_information(mask_dir, batch_size=5, min_area_threshold=50):
    """Yield, batch by batch, a table of labelled mask regions and their overlaps.

    Every file in ``mask_dir`` is opened as an image and its connected
    components are labelled. Each region whose ``area`` is at least
    ``min_area_threshold`` becomes one row. For every pair of kept regions
    *within the same mask* whose bounding boxes intersect (according to
    ``do_boxes_intersect``) an extra row is emitted with ``protein_id``
    ``'interaction_<a>_<b>'`` and ``-1`` placeholders in the geometry columns.

    Args:
        mask_dir: Directory containing the mask image files.
        batch_size: Number of mask files summarised per yielded DataFrame.
        min_area_threshold: Minimum ``region.area`` for a region to be kept.

    Yields:
        pandas.DataFrame with columns ``image_name``, ``protein_id``, ``area``,
        ``centroid_x``, ``centroid_y`` and ``bbox``; region rows of the batch
        come first, interaction rows last.
    """
    # Sorted so that batch contents are reproducible; os.listdir order is arbitrary.
    mask_files = sorted(os.listdir(mask_dir))

    for start in range(0, len(mask_files), batch_size):
        region_rows = []
        interaction_rows = []

        for mask_file in mask_files[start:start + batch_size]:
            image_name, _ = os.path.splitext(mask_file)

            with Image.open(os.path.join(mask_dir, mask_file)) as img:
                mask = np.array(img)

            # Connected-component labelling; label ids restart for every mask.
            labeled_mask = label(mask)

            image_rows = []
            for region in regionprops(labeled_mask):
                if region.area < min_area_threshold:  # filter out noise
                    continue
                y, x = region.centroid
                image_rows.append({
                    'image_name': image_name,
                    'protein_id': region.label,
                    'area': region.area,
                    'centroid_x': x,
                    'centroid_y': y,
                    'bbox': region.bbox
                })
            region_rows.extend(image_rows)

            # Compare regions of this mask only. Comparing across files mixed
            # unrelated coordinate systems and tagged every interaction with
            # the name of the last file in the batch.
            for first_idx, first in enumerate(image_rows):
                for second in image_rows[first_idx + 1:]:
                    if do_boxes_intersect(first['bbox'], second['bbox']):
                        interaction_rows.append({
                            'image_name': image_name,
                            'protein_id': f"interaction_{first['protein_id']}_{second['protein_id']}",
                            'area': -1,  # Indicate that it's an interaction region
                            'centroid_x': -1,
                            'centroid_y': -1,
                            'bbox': (-1, -1, -1, -1)
                        })

        yield pd.DataFrame(region_rows + interaction_rows)
        
def do_boxes_intersect(bbox1, bbox2):
    """Tell whether two ``(y_min, x_min, y_max, x_max)`` boxes intersect.

    The comparisons are strict, so boxes whose edges share a coordinate are
    reported as intersecting.
    """
    top_a, left_a, bottom_a, right_a = bbox1
    top_b, left_b, bottom_b, right_b = bbox2
    # The boxes are disjoint as soon as they are separated along one axis.
    apart_horizontally = right_a < left_b or left_a > right_b
    apart_vertically = bottom_a < top_b or top_a > bottom_b
    return not (apart_horizontally or apart_vertically)



In [None]:
# Preview the first rows of every batch; batch_data still holds the last batch after the loop.
for batch_data in extract_protein_information(mask_dir, batch_size=64, min_area_threshold=50):
    print(batch_data.head())

In [None]:
# batch_data is the loop variable left over from the batch loop in an earlier cell (its last batch).
batch_data.head()

# Step 2: Graph Network Generation

In [None]:
import networkx as nx

def create_cell_network_graph(protein_data):
    """Build an undirected graph from the region/interaction table.

    Rows whose ``protein_id`` starts with ``'interaction_'`` become edges
    between the two ids encoded in the name; every other row becomes a node
    carrying ``area`` and ``centroid`` (x, y) attributes.

    Args:
        protein_data: DataFrame with ``protein_id``, ``area``, ``centroid_x``
            and ``centroid_y`` columns.

    Returns:
        networkx.Graph keyed by ``str(protein_id)``.
    """
    # NOTE(review): nodes are keyed by protein_id alone; if ids repeat across
    # images, rows from different images collapse into one node. Confirm.
    G = nx.Graph()
    edges = []

    for _, row in protein_data.iterrows():
        protein_id = str(row['protein_id'])
        if protein_id.startswith('interaction_'):
            # Interaction rows describe an edge. Adding them as nodes created
            # bogus vertices with area -1 and centroid (-1, -1).
            _, node1, node2 = protein_id.split('_')
            edges.append((node1, node2))
        else:
            G.add_node(protein_id, area=row['area'], centroid=(row['centroid_x'], row['centroid_y']))

    # Edges go in last so that region nodes keep their insertion order and attributes.
    G.add_edges_from(edges)

    return G


### Create a graph data structure to represent the cell network

In [None]:
# Concatenate all batches into one table, then build the graph from it.
protein_data = pd.concat(list(extract_protein_information(mask_dir, batch_size=64, min_area_threshold=50)))
cell_network_graph = create_cell_network_graph(protein_data)
cell_network_graph

In [None]:
# Visualization of the cell_network_graph
# Node positions are read from the 'centroid' attribute stored on each node.
pos = nx.get_node_attributes(cell_network_graph, 'centroid')  # Get positions of nodes

# Each node's marker size is its stored 'area' attribute.
node_sizes = [data['area'] for node, data in cell_network_graph.nodes(data=True)]
nx.draw_networkx_nodes(cell_network_graph, pos, node_size=node_sizes, node_color='skyblue', alpha=0.7)

nx.draw_networkx_edges(cell_network_graph, pos, edge_color='gray', alpha=0.5)
# Label every node with its own identifier.
labels = {node: str(node) for node in cell_network_graph.nodes()}
nx.draw_networkx_labels(cell_network_graph, pos, labels, font_size=10, font_color='black')

# Set plot properties
plt.title("Cell Network Graph Visualization")
plt.axis('off')
plt.show()

# Network Analysis using Louvain Algorithm

In [None]:
from community import community_louvain
# import community as community_louvain  # Louvain algorithm package

# Partition the cell network into communities (treated as protein complexes)
# with the Louvain method from the python-louvain package. A fixed
# random_state keeps community ids, and the colours derived from them in the
# later plots, identical across runs.
partition = community_louvain.best_partition(cell_network_graph, random_state=42)

# Group the nodes by the community id they were assigned to.
communities = {}
for node, community_id in partition.items():
    communities.setdefault(community_id, []).append(node)

# Print the communities
print("Identified Communities (Protein Complexes):")
for community_id, nodes in communities.items():
    print(f"Community {community_id}: {nodes}")


# Network Visualization


#### Use NetworkX or other libraries like Matplotlib or Plotly to visualize the cell network
#### Visualize protein complexes as separate subgraphs or colors for better understanding


In [None]:

# Filter nodes based on degree centrality
degree_centrality = nx.degree_centrality(cell_network_graph)
important_nodes = {node for node, centrality in degree_centrality.items() if centrality > 0.05}  # Adjust the threshold as needed

# Visualization of the cell_network_graph with protein complexes as subgraphs
pos = nx.get_node_attributes(cell_network_graph, 'centroid')  # Get positions of nodes

# Create separate node lists for important and non-important nodes
important_nodes_list = [node for node in cell_network_graph.nodes() if node in important_nodes]
non_important_nodes_list = [node for node in cell_network_graph.nodes() if node not in important_nodes]

# Draw important nodes with their area as node size and color based on communities
important_node_sizes = [cell_network_graph.nodes[node]['area'] for node in important_nodes_list]
important_node_colors = [partition[node] for node in important_nodes_list]
nx.draw_networkx_nodes(cell_network_graph, pos, nodelist=important_nodes_list, node_size=important_node_sizes, node_color=important_node_colors, cmap='viridis', alpha=0.7)

# Draw non-important nodes with a fixed size and color (if desired)
nx.draw_networkx_nodes(cell_network_graph, pos, nodelist=non_important_nodes_list, node_size=30, node_color='gray', alpha=0.7)

# Draw the edges. NetworkX provides no `bundled_edges` function (the previous
# call raised AttributeError), so the graph's own edges are drawn directly.
nx.draw_networkx_edges(cell_network_graph, pos, edge_color='gray', alpha=0.5)

# Build the per-community subgraphs here so this cell runs top to bottom
# without needing names that are only defined in a later cell.
num_communities = max(partition.values(), default=-1) + 1
protein_complexes = [cell_network_graph.subgraph([node for node in cell_network_graph.nodes if partition[node] == community_id])
                        for community_id in range(num_communities)]

# Draw protein complexes as subgraphs with distinct colors.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which newer Matplotlib removed.
colors = plt.get_cmap('tab20', max(num_communities, 1))
for i, protein_complex in enumerate(protein_complexes):
    nx.draw(protein_complex, pos, node_size=30, node_color=[colors(i)], alpha=0.7, edge_color='gray', linewidths=0.5)

# Label only important nodes
important_node_labels = {node: str(node) for node in important_nodes}
nx.draw_networkx_labels(cell_network_graph, pos, labels=important_node_labels, font_size=10, font_color='black')

# Set plot properties
plt.title("Cell Network Graph with Protein Complexes (Filtered)")
plt.axis('off')
plt.show()



In [None]:
from matplotlib.collections import LineCollection

# Get the number of protein complexes (communities); 0 when the partition is empty.
num_communities = max(partition.values(), default=-1) + 1

# Group nodes by community id in a single pass instead of rescanning every
# node once per community.
nodes_by_community = {}
for node, community_id in partition.items():
    nodes_by_community.setdefault(community_id, []).append(node)

# Separate protein complexes into subgraphs based on community IDs
protein_complexes = [cell_network_graph.subgraph(nodes_by_community.get(community_id, []))
                        for community_id in range(num_communities)]


In [None]:

# Filter nodes based on degree centrality
degree_centrality = nx.degree_centrality(cell_network_graph)
important_nodes = {node for node, centrality in degree_centrality.items() if centrality > 0.05}  # Adjust the threshold as needed

# Visualization of the cell_network_graph with protein complexes as subgraphs
pos = nx.get_node_attributes(cell_network_graph, 'centroid')  # Get positions of nodes

# Create separate node lists for important and non-important nodes
important_nodes_list = [node for node in cell_network_graph.nodes() if node in important_nodes]
non_important_nodes_list = [node for node in cell_network_graph.nodes() if node not in important_nodes]

# Draw important nodes with their area as node size and color based on communities
important_node_sizes = [cell_network_graph.nodes[node]['area'] for node in important_nodes_list]
important_node_colors = [partition[node] for node in important_nodes_list]
nx.draw_networkx_nodes(cell_network_graph, pos, nodelist=important_nodes_list, node_size=important_node_sizes, node_color=important_node_colors, cmap='viridis', alpha=0.7)

# Draw non-important nodes with a fixed size and color (if desired)
nx.draw_networkx_nodes(cell_network_graph, pos, nodelist=non_important_nodes_list, node_size=30, node_color='gray', alpha=0.7)

# Draw all edges as a single LineCollection: one (start, end) segment per edge.
edges = cell_network_graph.edges()
edge_positions = np.array([(pos[u], pos[v]) for u, v in edges], dtype='f')
lc = LineCollection(edge_positions, colors='gray', linewidths=0.5, alpha=0.5)
plt.gca().add_collection(lc)

# Draw protein complexes as subgraphs with distinct colors.
# plt.get_cmap replaces matplotlib.cm.get_cmap, which newer Matplotlib removed.
colors = plt.get_cmap('tab20', max(num_communities, 1))
for i, protein_complex in enumerate(protein_complexes):
    nx.draw(protein_complex, pos, node_size=30, node_color=[colors(i)], alpha=0.7, edge_color='gray', linewidths=0.5)

# Label only important nodes
important_node_labels = {node: str(node) for node in important_nodes}
nx.draw_networkx_labels(cell_network_graph, pos, labels=important_node_labels, font_size=10, font_color='black')

# Set plot properties
plt.title("Cell Network Graph with Protein Complexes (Filtered)")
plt.axis('off')
plt.show()

In [None]:
from multiprocessing import Pool, cpu_count
import community  # Louvain algorithm package

def filter_nodes_by_degree(graph, threshold):
    """Return the set of nodes whose degree centrality is strictly above ``threshold``."""
    centrality_by_node = nx.degree_centrality(graph)
    return {
        node
        for node, centrality in centrality_by_node.items()
        if centrality > threshold
    }

def create_subgraph_for_community(graph, community_id, node_to_community=None):
    """Return the subgraph of ``graph`` induced by the nodes of one community.

    Args:
        graph: Graph providing ``nodes`` and ``subgraph``.
        community_id: Community whose nodes are selected.
        node_to_community: Optional mapping node -> community id. When omitted,
            the notebook-level ``partition`` variable is used, which keeps
            existing two-argument calls working.

    Returns:
        Whatever ``graph.subgraph`` returns for the selected nodes.
    """
    # Prefer the explicit mapping so the result does not depend on hidden global state.
    mapping = partition if node_to_community is None else node_to_community
    return graph.subgraph([node for node in graph.nodes if mapping[node] == community_id])

def limit_cells_to_display(graph, num_cells_to_display):
    """Return at most ``num_cells_to_display`` nodes, ordered by descending degree."""
    ranked_nodes = sorted(graph.nodes(), key=lambda node: graph.degree(node), reverse=True)
    return ranked_nodes[:num_cells_to_display]

def edge_bundling_edges(graph, cells_to_display):
    """Return the edges of ``graph`` whose two endpoints are both displayed.

    Args:
        graph: Object whose ``edges()`` yields ``(u, v)`` pairs.
        cells_to_display: Iterable of (hashable) nodes that are shown.

    Returns:
        List of ``(u, v)`` tuples in the order ``graph.edges()`` yields them.
    """
    # Set membership is O(1); testing against the list made this O(edges * nodes).
    displayed = set(cells_to_display)
    return [(u, v) for u, v in graph.edges() if u in displayed and v in displayed]

def get_node_sizes(graph, cells_to_display, important_nodes):
    """Return one marker size per displayed node.

    Important nodes use their stored 'area' attribute; every other node gets
    the fixed size 30.
    """
    default_size = 30  # A fixed size for non-important nodes
    return [
        graph.nodes[node]['area'] if node in important_nodes else default_size
        for node in cells_to_display
    ]

def draw_cell_network_with_protein_complexes(graph, partition, protein_complexes, num_cells_to_display=50, node_degree_threshold=0.05):
    """Draw the highest-degree nodes of ``graph`` plus the protein complexes with matplotlib.

    Args:
        graph: Graph whose nodes carry 'centroid' and 'area' attributes.
        partition: Mapping node -> community id.
        protein_complexes: Iterable of subgraphs, one per community.
        num_cells_to_display: Maximum number of nodes in the main layer.
        node_degree_threshold: Degree-centrality cut-off for "important" nodes.

    Returns:
        None. The figure is shown with ``plt.show()``.
    """
    # Imported locally so the function does not rely on another cell having run.
    from matplotlib.collections import LineCollection

    important_nodes = filter_nodes_by_degree(graph, node_degree_threshold)
    cells_to_display = limit_cells_to_display(graph, num_cells_to_display)
    pos = nx.get_node_attributes(graph, 'centroid')

    important_displayed = [node for node in cells_to_display if node in important_nodes]
    non_important_nodes_list = [node for node in cells_to_display if node not in important_nodes]

    # Important nodes: size from area, colour from community. Sizes, colours and
    # the node list all cover the same nodes; previously the colour list was
    # shorter than the node list whenever a displayed node was not important.
    if important_displayed:
        nx.draw_networkx_nodes(
            graph, pos,
            nodelist=important_displayed,
            node_size=get_node_sizes(graph, important_displayed, important_nodes),
            node_color=[partition[node] for node in important_displayed],
            cmap='viridis', alpha=0.7,
        )

    # Draw non-important nodes with a fixed size and color
    if non_important_nodes_list:
        nx.draw_networkx_nodes(graph, pos, nodelist=non_important_nodes_list, node_size=30, node_color='gray', alpha=0.7)

    # Draw edges between displayed nodes as a single LineCollection
    edges = edge_bundling_edges(graph, cells_to_display)
    if edges:
        edge_positions = np.array([(pos[u], pos[v]) for u, v in edges], dtype='f')
        lc = LineCollection(edge_positions, colors='gray', linewidths=0.5, alpha=0.5)
        plt.gca().add_collection(lc)

    # Draw protein complexes as subgraphs with distinct colors.
    # NOTE(review): this draws every node of every complex, not only the
    # displayed ones; confirm whether the limit should apply here too.
    num_communities = max(partition.values(), default=-1) + 1
    # plt.get_cmap replaces matplotlib.cm.get_cmap, which newer Matplotlib removed.
    colors = plt.get_cmap('tab20', max(num_communities, 1))
    for i, protein_complex in enumerate(protein_complexes):
        nx.draw(protein_complex, pos, node_size=30, node_color=[colors(i)], alpha=0.7, edge_color='gray', linewidths=0.5)

    # Label only important nodes
    important_node_labels = {node: str(node) for node in important_displayed}
    nx.draw_networkx_labels(graph, pos, labels=important_node_labels, font_size=10, font_color='black')

    # Set plot properties
    plt.title("Cell Network Graph with Protein Complexes (Limited Cells)")
    plt.axis('off')
    plt.show()


In [None]:
# Render the matplotlib version using the graph, partition and complexes built in earlier cells.
draw_cell_network_with_protein_complexes(cell_network_graph, partition, protein_complexes)

In [None]:

def get_node_sizes(graph, cells_to_display, important_nodes):
    """Return one marker size per displayed node.

    Important nodes use their stored 'area' attribute; every other node gets
    the fixed size 30.
    """
    # NOTE(review): same definition as the get_node_sizes in the earlier cell.
    default_size = 30  # A fixed size for non-important nodes
    return [
        graph.nodes[node]['area'] if node in important_nodes else default_size
        for node in cells_to_display
    ]


In [None]:

def draw_cell_network_with_protein_complexes1(graph, partition, protein_complexes, num_cells_to_display=50, node_degree_threshold=0.05):
    """Plotly version: main network on the left, protein complexes on the right.

    Args:
        graph: Graph whose nodes carry 'centroid' and 'area' attributes.
        partition: Mapping node -> community id.
        protein_complexes: Sequence of subgraphs, one per community.
        num_cells_to_display: Maximum number of nodes in the left panel.
        node_degree_threshold: Degree-centrality cut-off for "important" nodes.

    Returns:
        None. The figure is shown with ``fig.show()``.
    """
    important_nodes = filter_nodes_by_degree(graph, node_degree_threshold)
    cells_to_display = limit_cells_to_display(graph, num_cells_to_display)
    pos = nx.get_node_attributes(graph, 'centroid')

    # Get node sizes based on importance
    node_sizes = get_node_sizes(graph, cells_to_display, important_nodes)

    # Create subplots for main graph and protein complexes
    fig = make_subplots(rows=1, cols=2, subplot_titles=["Cell Network Graph", "Protein Complexes"])

    # One colour value per displayed node, so colours line up with the sizes
    # and positions; previously only important nodes contributed a colour.
    node_colors = [partition[node] for node in cells_to_display]
    node_trace = go.Scatter(
        x=[pos[node][0] for node in cells_to_display],
        y=[pos[node][1] for node in cells_to_display],
        mode='markers',
        text=list(cells_to_display),
        marker=dict(size=node_sizes, color=node_colors, colorscale='viridis', colorbar=dict(title='Community')),
    )
    fig.add_trace(node_trace, row=1, col=1)

    # Draw edges: each segment is (start, end, None); None breaks the line.
    edge_x = []
    edge_y = []
    for u, v in edge_bundling_edges(graph, cells_to_display):
        x0, y0 = pos[u]
        x1, y1 = pos[v]
        edge_x.extend([x0, x1, None])
        edge_y.extend([y0, y1, None])
    edge_trace = go.Scatter(x=edge_x, y=edge_y, line=dict(width=0.5, color='gray'), hoverinfo='none', mode='lines')
    fig.add_trace(edge_trace, row=1, col=1)

    # Draw protein complexes as subgraphs with distinct colors.
    # plt.get_cmap replaces matplotlib.cm.get_cmap, which newer Matplotlib removed.
    colors = plt.get_cmap('tab20', max(len(protein_complexes), 1))
    for i, protein_complex in enumerate(protein_complexes):
        coords = [pos[node] for node in protein_complex.nodes()]
        if not coords:
            # zip(*[]) cannot be unpacked into x and y; nothing to draw anyway.
            continue
        complex_x, complex_y = zip(*coords)
        # Convert the matplotlib RGBA floats (0..1) into a CSS colour string for Plotly.
        red, green, blue, alpha = colors(i)
        css_color = f'rgba({int(round(red * 255))},{int(round(green * 255))},{int(round(blue * 255))},{alpha})'
        complex_trace = go.Scatter(x=complex_x, y=complex_y, mode='markers', marker=dict(size=10, color=css_color))
        fig.add_trace(complex_trace, row=1, col=2)

    # Hide ticks, grid and zero lines on both panels
    for col in (1, 2):
        fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=col)
        fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False, row=1, col=col)

    # Set layout properties
    fig.update_layout(title_text="Cell Network Graph with Protein Complexes (Limited Cells)", showlegend=False)
    fig.show()

# Call function: render the Plotly version with the graph, partition and complexes built in earlier cells.
draw_cell_network_with_protein_complexes1(cell_network_graph, partition, protein_complexes)
