# Visualization of a region as a graph and a 3D model

In [17]:
import os
import pickle
import itertools
from concurrent.futures import ProcessPoolExecutor, as_completed
from tqdm import tqdm
import numpy as np
import pandas as pd
import igraph as ig
import networkx as nx

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from knots_tools import get_loop_pair_infos
import modvis

In [18]:
import pyvista as pv
# NOTE(review): this assigns a plain attribute on the pyvista module itself. Recent
# PyVista versions document the trame Jupyter-extension switch as
# pv.global_theme.trame.jupyter_extension_enabled — confirm this line has the
# intended effect for the installed version.
pv.jupyter_extension_enabled = True
# Render PyVista plots inside the notebook through the trame backend.
pv.set_jupyter_backend('trame')

## A synthetic structure for a simple schematic

In [None]:
# Colours for the minor's segments (light-to-blue gradient built by modvis).
minor_palette = modvis.model_colors_palette("light:#1953c8", pad_l=3)
# Colours indexed by bead-group id in the loop-pair views below: entry 0 is the
# faded background, entries 1 and 2 (red / green) mark the two highlighted loops.
background_color = minor_palette[0]
link_palette = [background_color, '#bd2828', '#2abb30', background_color]
# Preview both palettes.
for _palette_to_show in (minor_palette, link_palette):
    sns.palplot(_palette_to_show)

In [None]:
# Build a small synthetic MultiDiGraph for the schematic figure.
# Nodes 0..17 form a chain; each node gets a segment id (the list below) and a
# coordinate spaced 100_000 apart starting at 100_000_000 (presumably bp — confirm).
# Segment 0 occurs only at the two chain ends, presumably flanks outside the minor.
G = nx.MultiDiGraph()
for u, g in enumerate(
    [0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5, 5, 5, 6, 6, 0]
):
    G.add_node(u, segment=g, coord=u * 100_000 + 100_000_000)
n = G.number_of_nodes()
# Strand edges link each node to the next one along the chain.
for u, v in zip(range(n - 1), range(1, n)):
    G.add_edge(u, v, is_contact=False, is_strand=True, end_segments=None)
# Node indices per segment 1..6:
# [1, 2, 3], [4, 5], [6, 7, 8], [9, 10], [11, 12, 13, 14], [15, 16]
# Contacts between different segments, tagged end_segments=(0, 0). The trailing
# numbers appear to group them by the segment of the first endpoint.
# NOTE(review): node 12 belongs to segment 5 per the list above, yet (12, 16) is
# labelled "# 4" — confirm the intended endpoint or label.
for u, v in [
    (2, 7), (2, 9), (1, 11), (1, 16),  # 1
    (4, 10), (5, 11), (4, 16),  # 2
    (8, 11), (7, 15), # 3
    (12, 16),  # 4
]:
    G.add_edge(u, v, is_contact=True, is_strand=False, end_segments=(0, 0))
# Additional contacts with end_segments=None (presumably not part of the
# highlighted minor structure — confirm against modvis.plot_graph_with_minor).
for u, v in [
    (0, 2), (0, 12), (1, 2), (3, 14), (7, 12), (13, 17)
]:
    G.add_edge(u, v, is_contact=True, is_strand=False, end_segments=None)

plt.figure(figsize=(7.5, 3), dpi=300)
# NOTE(review): highlighted_loops keys 3 and 9 look like indices into the
# (0, 0)-tagged contacts in insertion order, i.e. (1, 16) and (12, 16) — confirm
# how modvis.plot_graph_with_minor interprets them.
modvis.plot_graph_with_minor(
    G, highlighted_loops={
        3: matplotlib.colors.to_hex(link_palette[1]),
        9: matplotlib.colors.to_hex(link_palette[2])
    },
    segments_palette=minor_palette,
    node_size=100,
    arc_alpha=0.5, arc_linewidth=2.0,
    selected_arc_alpha=1, selected_arc_linewidth=2.0,
    strand_linewidth=3.0,
    draw_strands=True,
    tick_y_coords=(-0.002, -0.004, -0.007),
    tick_linewidth=2.0
)
plt.tight_layout()
# plt.show()
# plt.savefig('./plots/graph_schema.svg')

In [None]:
plt.figure(figsize=(2, 2), dpi=300)

# Complete graph on 6 nodes (0..5); the nodes are labelled 1..6 when drawn.
G = nx.complete_graph(6)

# Per-edge drawing overrides. Keys are written as (smaller, larger) node index so
# they match the tuples produced by G.edges() when filtering out styled edges.
edge_styles = {
    (0, 5): {'style': 'dashed'},  # connection between 1 and 6 (dashed line)
    (0, 3): {'edge_color': 'red', 'width': 2},  # connection between 1 and 4 (red line)
    (4, 5): {'edge_color': 'green', 'width': 2},  # connection between 5 and 6 (green line)
}

def hexagonal_layout(n, theta=0, clockwise=True):
    """Place ``n`` nodes evenly on the unit circle.

    Node 0 sits at angle ``theta``. The following nodes are spaced ``2*pi/n``
    apart, going clockwise by default and counter-clockwise otherwise.

    Returns a dict mapping node index -> ``np.array([x, y])``.
    """
    direction = -1.0 if clockwise else 1.0
    step_angles = np.linspace(0, 2 * np.pi, n, endpoint=False)
    node_angles = direction * step_angles + theta  # theta rotates the whole polygon
    layout = {}
    for node, phi in enumerate(node_angles):
        layout[node] = np.array([np.cos(phi), np.sin(phi)])
    return layout

pos = hexagonal_layout(6, theta=4 * np.pi / 3)
# White nodes outlined with the segment colours minor_palette[1:], labelled 1..6.
nx.draw_networkx_nodes(G, pos, node_color='white', edgecolors=minor_palette[1:], linewidths=1, node_size=300)
nx.draw_networkx_labels(G, pos, labels={i: str(i+1) for i in range(6)}, font_size=8, font_color='black')
# Plain black edges: every edge that has no entry in edge_styles.
nx.draw_networkx_edges(
    G, pos,
    edgelist=[e for e in G.edges() if e not in edge_styles],
    edge_color='black'
)

# Styled edges, one call per edge so each one gets its own style overrides.
for (u, v), style in edge_styles.items():
    nx.draw_networkx_edges(G, pos, edgelist=[(u, v)], **style)

plt.tight_layout()
plt.margins(0.1)
plt.axis('off')
plt.axis('equal')
# savefig does not create missing directories.
os.makedirs('./plots', exist_ok=True)
plt.savefig('./plots/K6.svg')
plt.show()

## Select a CCD of interest

In [6]:
# Dataset, chromosome and CCD index of the region to visualize.
example_dataset, example_chromosome, example_ccd_id = 'GM12878', 'chr7', 100

In [None]:
DATA_DIR = './data'
GRAPHS_DIR = os.path.join(DATA_DIR, 'graphs')
MODELS_DIR = './models'

model_name = f'{example_dataset}_{example_chromosome}_{example_ccd_id:04d}'

# Suffix of a prepared example model directory in MODELS_DIR, or None to use a
# directory named after the selected CCD. Prepared examples exist for
# GM12878 / chr7 / CCD 100: 'bigloop_0_9', 'small_3_8', 'twisted_5_9'.
# The directory is derived from model_name, so choosing another CCD never reads
# or writes inside a prepared example directory of a different CCD.
PREPARED_EXAMPLE = 'bigloop_0_9'
if PREPARED_EXAMPLE is None:
    model_working_dir = os.path.join(MODELS_DIR, model_name)
else:
    model_working_dir = os.path.join(MODELS_DIR, f'{model_name}_{PREPARED_EXAMPLE}')
model_file = os.path.join(model_working_dir, f'{model_name}.pkl')
print(model_working_dir)

In [8]:
def load_graph(path):
    """Unpickle and return the graph stored at ``path``, relative to GRAPHS_DIR.

    Uses ``pickle``, so only load files from a trusted source.
    """
    graph_path = os.path.join(GRAPHS_DIR, path)
    with open(graph_path, 'rb') as graph_file:
        return pickle.load(graph_file)

# Pickled CCD graph for the selected example, plus its drawable form and minor.
example_graph_filename = f'ccd_graph_{example_dataset}_{example_chromosome}_{example_ccd_id:04d}.pkl'
example_graph = load_graph(example_graph_filename)
drawable_graph, example_minor = modvis.make_drawable_graph(example_graph)

In [None]:
# Reuse a previously saved model if present; CREATE_MODEL tells the next cell
# whether modeling still has to run.
model_exists = os.path.exists(model_file)
if not model_exists:
    print(f'Model not loaded, path "{model_file}" does not exist.')
    CREATE_MODEL = True
else:
    coords, restraints, bead_groups = modvis.load_model(model_file)
    print(f'Loaded model from {model_file}')
    CREATE_MODEL = False

## ...Or create a new model here

In [None]:
# Then run modeling
def run_modeling_on_graph(
    graph,
    model_working_dir=None,
    bead_spacing_bp=500,
    padding_bp=0,
    overwrite=False,
    *,
    config_file_path='./spring_model_config.ini',
    modeling_command='./run_sm_docker_CPU.sh',
):
    """Convert ``graph`` into beads and restraints, then run the spring model.

    Parameters
    ----------
    graph
        Graph accepted by ``modvis.beads_and_restraints_from_graph``.
    model_working_dir
        Working directory handed to ``modvis.SpringModelAPI``. ``None`` is also
        accepted; the caller uses it to request a temporary directory.
    bead_spacing_bp, padding_bp
        Forwarded to ``modvis.beads_and_restraints_from_graph``.
    overwrite
        Forwarded to ``SpringModelAPI.run_modeling``.
    config_file_path, modeling_command
        Forwarded to ``modvis.SpringModelAPI``. The defaults are the paths this
        function previously hard-coded.

    Returns
    -------
    tuple
        All values returned by ``run_modeling``, followed by ``restraints`` and
        ``bead_groups``.
    """
    bead_coords, restraints, bead_groups = modvis.beads_and_restraints_from_graph(
        graph, bead_spacing_bp=bead_spacing_bp, padding_bp=padding_bp
    )
    spring_model_runner = modvis.SpringModelAPI(
        config_file_path=config_file_path,
        modeling_command=modeling_command,
        model_working_dir=model_working_dir,
    )
    modeling_result = spring_model_runner.run_modeling(
        bead_coords, restraints, bead_groups, overwrite=overwrite
    )
    # Star-unpacking works whether run_modeling returns a tuple or a list
    # (tuple '+' would raise TypeError for a list).
    return (*modeling_result, restraints, bead_groups)

# Uncomment to force re-modeling even when a saved model was loaded above.
# CREATE_MODEL = True
if CREATE_MODEL:
    try:
        coords, init_str_points, restraints, bead_groups = run_modeling_on_graph(
            drawable_graph, bead_spacing_bp=500, padding_bp=20_000,
            # Pass None instead to use a temporary dir for the modeling files
            # (faster, but those files are not kept).
            model_working_dir=model_working_dir,
            # True forces the model to be re-run. With False, existing modeling
            # files presumably trigger the FileExistsError handled below.
            overwrite=True,
        )
        print(f'Created model, saving in {model_working_dir}...')
        modvis.save_model(model_file, coords, restraints, bead_groups)
    except FileExistsError as e:
        print(f'Model already exists in {model_working_dir}:', e)
else:
    print('Model already exists, skipping modeling step')

## Gaussian linking integral

Find a "nice" link to visualize: preferably one with a low linking number and without internal overlaps.

In [None]:
# Per-loop-pair table (including the absolute linking number) for the model.
infos = get_loop_pair_infos(coords, restraints)
# Show the first ten pairs that are actually linked.
linked_pairs = infos.query('abs_linking_number > 0')
linked_pairs.head(10)

## Select a loop pair for visualization

In [12]:
# Indices (rows of `restraints`) of the two loops to highlight.
example_loop_pair = 0, 9

## Visualize the model in 3D

In [None]:
def groups_loop_pair(restraints, total_beads, loop_A, loop_B):
    """Assign beads to groups so that two loops can be highlighted.

    ``restraints[i, 0]`` and ``restraints[i, 1]`` are taken as the first and last
    bead (inclusive) of loop ``i``. Presumably ``restraints`` is a 2-D integer
    array — confirm.

    Returns a list of ``(first_bead, last_bead, group_id)`` tuples:
    - loop ``loop_A`` becomes group 1 and loop ``loop_B`` group 2;
    - beads before both loops, between two disjoint loops, and after both loops
      become background group 0.

    Overlapping (interleaved or nested) loops are kept as two overlapping groups.
    """
    loop_groups = [
        (restraints[loop_A, 0], restraints[loop_A, 1], 1),
        (restraints[loop_B, 0], restraints[loop_B, 1], 2),
    ]
    # Order the loops along the chain, independent of which one is A or B.
    (lo_start, lo_end, _), (hi_start, _, _) = sorted(loop_groups, key=lambda grp: grp[0])
    last_loop_bead = max(end for _, end, _ in loop_groups)

    groups = []
    if lo_start > 0:
        groups.append((0, lo_start - 1, 0))
    groups.append(loop_groups[0])
    # Only disjoint loops leave a non-empty stretch of beads between them.
    if lo_end + 1 < hi_start:
        groups.append((lo_end + 1, hi_start - 1, 0))
    groups.append(loop_groups[1])
    if last_loop_bead < total_beads - 1:
        groups.append((last_loop_bead + 1, total_beads - 1, 0))
    return groups

norm_coords = modvis.scale_coords(coords)
plotter = pv.Plotter(shape=(1, 2))
_strand_radius = 0.015
_interaction_radius = 0.015
_interaction_color = 'orange'
# Rendering options shared by both panels.
_shared_plot_kwargs = dict(
    strand_radius=_strand_radius,
    interaction_radius=_interaction_radius,
    interaction_color=_interaction_color,
    plotter=plotter,
)

# Left panel: the whole model coloured by its own bead groups.
plotter.subplot(0, 0)
modvis.plot_model(
    norm_coords, restraints,
    bead_groups=bead_groups,
    palette=minor_palette,
    **_shared_plot_kwargs,
)

# Right panel: only the selected loop pair (groups 1 and 2) emphasized.
plotter.subplot(0, 1)
modvis.plot_model(
    norm_coords, restraints,
    bead_groups=groups_loop_pair(restraints, len(norm_coords), *example_loop_pair),
    selected_groups=[1, 2],
    nonselected_opacity=0.1,
    restraints_only_within_groups=True,
    continous=False,
    palette=link_palette,
    **_shared_plot_kwargs,
)

# Keep both panels' cameras in sync.
plotter.link_views()
plotter.show()

## Visualize structure as a graph

With the given minor colored and the selected edges highlighted.

In [14]:
# Rebuild the drawable graph and its minor from the loaded CCD graph.
drawable_and_minor = modvis.make_drawable_graph(example_graph)
drawable_graph, example_minor = drawable_and_minor

In [None]:
plt.figure(figsize=(6, 3), dpi=300)
# The highlighted loops use link_palette[1] / [2], the same colours as groups 1 / 2
# in the 3D loop-pair view.
modvis.plot_graph_with_minor(
    drawable_graph, highlighted_loops={
        example_loop_pair[0]: matplotlib.colors.to_hex(link_palette[1]),
        example_loop_pair[1]: matplotlib.colors.to_hex(link_palette[2])
    },
    segments_palette=minor_palette,
    visible_segments=list(range(1, 6 + 1)),
    node_size=1, arc_linewidth=1.0,
    tick_y_coords=(-0.001, -0.002, -0.004),
    tick_fontsize=8
)
# plt.title(f'|V| = {drawable_graph.number_of_nodes()}, |E| = {drawable_graph.number_of_edges()}', fontsize=14)
plt.tight_layout()
# plt.show()
# savefig does not create missing directories.
os.makedirs('./plots', exist_ok=True)
plt.savefig(f'./plots/graph_{model_name}_{example_loop_pair[0]}_{example_loop_pair[1]}.svg');