# 🧬 ProtCNet: Protein Contact Network Analyzer

**ProtCNet** is an interactive bioinformatics tool designed to analyze and visualize **inter-chain contact networks** in macromolecular structures, either from the **Protein Data Bank (PDB)** or from structure predictions such as AlphaFold models.

Given a PDB file (uploaded or retrieved by ID), this notebook:

- Parses the 3D structure to extract atomic coordinates by chain or atom type.
- Computes contacts between chains or residues using a distance cutoff (default: 5 Å).
- Visualizes molecular interactions through:
  - An **interactive network graph** (`ipysigma`)
  - A **heatmap** of chain–chain contact counts (`seaborn`)
  - A **residue-level network** (if enabled)

Additional options allow you to:
- Filter by **hydrophobic** or **electrostatic** residues
- Choose between **Cα**, **Cβ**, or **all atoms**
- Track **intra-chain** and **inter-chain** contacts

---

### 📦 Dependencies

- `BioPython`, `NetworkX`, `Seaborn`, `Matplotlib`, `ipysigma`, `ipywidgets`, `SciPy`, `requests` (for downloading structures by PDB ID)

---

**Use ProtCNet to explore macromolecular complexes, map interaction networks, identify candidate interfaces, and support structural biology research and discovery.**

Developed by Victor Klein-Sousa (@vkleinsousa) under the MIT License.

GitHub: https://github.com/VKleinSousa/ProtCNet

In [None]:
# Install dependencies
!pip install biopython
!pip install ipysigma

In [None]:
#import dependencies
import os
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from Bio import PDB
from Bio.PDB import PDBParser, MMCIFParser
from ipysigma import Sigma
from scipy.spatial import KDTree
from IPython.display import display
import ipywidgets as widgets
import time
import csv



In [None]:
# @title
# Inputs
from IPython.display import display, Markdown
import ipywidgets as widgets

# --- 🧪 Analysis Parameters ---

# File input
pdb_id_widget = widgets.Text(value='5IV5', description='PDB ID:')
upload_widget = widgets.FileUpload(accept='.pdb,.cif', multiple=False)

# Contact computation
cutoff_widget = widgets.FloatText(value=5.0, description='Cutoff (Å):')
contact_mode_widget = widgets.Dropdown(
    options=[('Inter-chain', 'inter'), ('Intra-chain', 'intra'), ('All contacts', 'all')],
    value='inter',
    description='Contact Type:'
)
sequence_distance_widget = widgets.IntText(
    value=15,
    description='Seq. Distance:',
    tooltip='Minimum residue separation for intra-chain contacts'
)

# Atom selection and filtering
atom_selector = widgets.Dropdown(
    options=[('All atoms', 'all'), ('Cα only', 'ca'), ('Cβ only', 'cb')],
    value='ca',
    description='Atoms:'
)
interaction_filter_widget = widgets.Dropdown(
    options=[
        ('All contacts', 'all'),
        ('Electrostatic only', 'electrostatic'),
        ('Hydrophobic only', 'hydrophobic')
    ],
    value='all',
    description='Interaction:'
)
track_residues_widget = widgets.Checkbox(
    value=True,
    description="Track contacting residues"
)

residue_level_net_widget = widgets.Checkbox(
    value=True,
    description="Plot Network at residue level"
)
# Residues with fewer contacts than this are greyed out in the residue-level network.
# The threshold is a count, so an integer field is used; 'description_width': 'initial'
# stops the long label from being truncated by the default description width.
residue_level_cutoff_widget = widgets.IntText(
    value=10,
    description='Minimal number of contacts.',
    style={'description_width': 'initial'}
)
# --- 🎨 Visualization Style Settings ---

network_file_widget = widgets.Text(value='network_visualization.html', description='Network File:')
edge_style_widget = widgets.Dropdown(
    options=[
        ('Rectangle', 'rectangle'),
        ('Line', 'line'),
        ('Curve', 'curve')
    ],
    value='curve',
    description='Edge Style:'
)

# --- 📌 Display Layout ---

display(Markdown("## 🧪 Contact Analysis Parameters"))
display(pdb_id_widget, upload_widget, cutoff_widget, contact_mode_widget,
        atom_selector, interaction_filter_widget, track_residues_widget, residue_level_net_widget)
display(Markdown("If Contact Type != Inter-chain:"))
display(sequence_distance_widget)
display(Markdown("If Plot Network at residue level:"))
display(residue_level_cutoff_widget)
display(Markdown("## 🎨 Network Visualization Settings"))
display(edge_style_widget, network_file_widget)


In [None]:
# @title
# Core Functions
display(Markdown("## 💻 Core Functions"))
def parse_structure(file_path, atom_mode="ca", residue_filter="all", with_residues=False, model_index=0):
    """
    Parse a PDB or mmCIF file and extract atomic coordinates grouped by chain.

    Args:
        file_path: path ending in ".pdb" or ".cif" (extension is matched case-insensitively).
        atom_mode: "ca" keeps Cα atoms; "cb" keeps Cβ atoms (Cα for glycine, which has
            no Cβ); any other value keeps all atoms.
        residue_filter: "hydrophobic" or "electrostatic" keeps only those residue types;
            any other value keeps every residue.
        with_residues: when True, also return per-atom residue annotations.
        model_index: index of the model to read (default 0 = first model). Multi-model
            files (e.g. NMR ensembles) contain alternative copies of the same chains, so
            merging them would create contacts between models. Pass None to merge all
            models (the previous behaviour).

    Returns:
        - chain_atoms: {chain_id: [atom_coords]}
        - atom_residue_map (only when with_residues=True): {chain_id: [(res_id, resname, resseq)]},
          index-aligned with chain_atoms, where res_id = "A:101", resname = "GLY", resseq = 101

    Raises:
        ValueError: the file extension is not .pdb or .cif.
        IndexError: model_index is out of range for a non-empty structure.
    """
    lowered_path = file_path.lower()
    if lowered_path.endswith(".pdb"):
        parser = PDB.PDBParser(QUIET=True)
    elif lowered_path.endswith(".cif"):
        parser = PDB.MMCIFParser(QUIET=True)
    else:
        raise ValueError("Unsupported file format. Only .pdb or .cif allowed.")

    structure = parser.get_structure("structure", file_path)
    all_models = list(structure)
    if model_index is None:
        models = all_models
    else:
        models = [all_models[model_index]] if all_models else []

    chain_atoms = {}
    atom_residue_map = {}

    hydrophobic = {'ALA', 'VAL', 'LEU', 'ILE', 'MET', 'PHE', 'TRP', 'PRO'}
    charged = {'ASP', 'GLU', 'ARG', 'LYS', 'HIS'}

    for model in models:
        for chain in model:
            chain_id = chain.id
            atoms = chain_atoms.setdefault(chain_id, [])
            residues = atom_residue_map.setdefault(chain_id, []) if with_residues else None

            for residue in chain:
                hetfield, resseq, _icode = residue.id
                # Waters ('W' hetero flag) are not part of the macromolecule; in
                # all-atom mode they would otherwise create spurious contacts.
                if hetfield == "W":
                    continue

                resname = residue.get_resname().strip()
                res_id = f"{chain_id}:{resseq}"

                # Residue filtering
                if residue_filter == "hydrophobic" and resname not in hydrophobic:
                    continue
                if residue_filter == "electrostatic" and resname not in charged:
                    continue

                # Atom name to keep for this residue (None = keep every atom).
                if atom_mode == "ca":
                    wanted = "CA"
                elif atom_mode == "cb":
                    wanted = "CA" if resname == "GLY" else "CB"
                else:
                    wanted = None

                for atom in residue:
                    if wanted is not None and atom.get_id() != wanted:
                        continue
                    # Calcium ions are also named "CA"; their element is "CA", not "C".
                    if wanted == "CA" and atom.element == "CA":
                        continue

                    atoms.append(atom.coord)
                    if residues is not None:
                        residues.append((res_id, resname, resseq))

    if with_residues:
        return chain_atoms, atom_residue_map
    return chain_atoms


def compute_contacts_with_residues(
    chain_atoms_dict,
    atom_residue_map,
    cutoff=5.0,
    mode="inter",
    sequence_distance=5,
    progress=None,
    status_label=None
):
    """
    Compute residue-level contacts between and/or within chains.

    Two residues are in contact when any pair of their selected atoms lies within
    `cutoff` of each other. Each unordered residue pair is reported once per chain pair.

    Parameters:
        - chain_atoms_dict: {chain_id: [np.array([x,y,z]), ...]}
        - atom_residue_map: {chain_id: [(residue_id_str, resname, seqpos), ...]},
          index-aligned with chain_atoms_dict (one entry per atom)
        - cutoff: distance threshold (same units as the coordinates)
        - mode: 'inter' (different chains only), 'intra' (same chain only); any other
          value behaves like 'all'
        - sequence_distance: for intra-chain pairs, residues whose sequence numbers
          differ by less than this are ignored
        - progress: optional object with an int `.value` (e.g. IntProgress), moved from 40 to 80
        - status_label: optional object with a str `.value` (e.g. Label)

    Returns:
        {(chain1, chain2): {"count": number_of_residue_pairs,
                            "residue_pairs": sorted [(res1_id, res2_id), ...]}}
        Only chain pairs with at least one contact are included.
    """
    from scipy.spatial import KDTree

    contact_map = {}
    chain_ids = list(chain_atoms_dict.keys())
    n_chains = len(chain_ids)
    if mode == "inter":
        total_pairs = n_chains * (n_chains - 1) // 2
    elif mode == "intra":
        total_pairs = n_chains
    else:
        total_pairs = n_chains * (n_chains + 1) // 2
    progress_counter = 0

    for i, c1 in enumerate(chain_ids):
        for c2 in chain_ids[i:]:
            is_same_chain = c1 == c2

            # Filter mode
            if mode == "inter" and is_same_chain:
                continue
            if mode == "intra" and not is_same_chain:
                continue

            coords1 = np.array(chain_atoms_dict[c1])
            coords2 = np.array(chain_atoms_dict[c2])
            residue_contacts = set()

            if coords1.size and coords2.size:
                res1 = atom_residue_map[c1]
                res2 = atom_residue_map[c2]
                tree = KDTree(coords2)
                # One vectorised query: neighbours[k] lists the coords2 indices within
                # `cutoff` of coords1[k].
                neighbours = tree.query_ball_point(coords1, r=cutoff)

                for idx1, close_indices in enumerate(neighbours):
                    res1_id, _, seq1 = res1[idx1]
                    for idx2 in close_indices:
                        res2_id, _, seq2 = res2[idx2]

                        if is_same_chain:
                            # For intra-chain, skip close-in-sequence contacts
                            if abs(seq1 - seq2) < sequence_distance:
                                continue
                            # Within one chain each pair is found as (x, y) and (y, x);
                            # store one canonical orientation so it is counted once.
                            if (seq2, res2_id) < (seq1, res1_id):
                                residue_contacts.add((res2_id, res1_id))
                                continue

                        residue_contacts.add((res1_id, res2_id))

            if residue_contacts:
                contact_map[(c1, c2)] = {
                    "count": len(residue_contacts),
                    "residue_pairs": sorted(residue_contacts)
                }

            # Count every processed chain pair (including empty ones) so the
            # progress indicator reaches its end value.
            progress_counter += 1
            if progress is not None and total_pairs:
                progress.value = 40 + int((progress_counter / total_pairs) * 40)
            if status_label is not None:
                status_label.value = f"Computing contacts... ({progress_counter}/{total_pairs})"

    return contact_map


def compute_contacts(chain_atoms, cutoff=5.0, progress=None, status_label=None):
    """
    Compute inter-chain contact counts using a KDTree for spatial efficiency.

    For each chain pair (chain1, chain2), with chain1 appearing before chain2 in
    `chain_atoms`, the value is the number of chain1 atoms that have at least one
    chain2 atom within `cutoff`. Note that this count is taken from chain1's side only.

    Args:
        chain_atoms: {chain_id: [xyz coordinates, ...]}
        cutoff: distance threshold (same units as the coordinates).
        progress: optional object with an int `.value` (e.g. IntProgress), moved from 40 to 80.
        status_label: optional object with a str `.value` (e.g. Label).

    Returns:
        {(chain1, chain2): count} for pairs with at least one contact.
    """
    contact_map = {}
    chain_ids = list(chain_atoms.keys())
    total_pairs = (len(chain_ids) * (len(chain_ids) - 1)) // 2
    progress_counter = 0

    for i, chain1 in enumerate(chain_ids):
        for chain2 in chain_ids[i + 1:]:
            coords1 = np.array(chain_atoms[chain1])
            coords2 = np.array(chain_atoms[chain2])

            # Only non-empty (n, 3) coordinate arrays can be compared.
            valid = all(
                c.ndim == 2 and c.shape[0] > 0 and c.shape[1] == 3
                for c in (coords1, coords2)
            )
            if valid:
                tree = KDTree(coords2)
                # Vectorised query: one neighbour list per chain1 atom.
                hits = tree.query_ball_point(coords1, r=cutoff)
                num_contacts = sum(1 for neighbours in hits if neighbours)
                if num_contacts > 0:
                    contact_map[(chain1, chain2)] = num_contacts

            # Count skipped (empty) pairs too, so progress reaches its end value.
            progress_counter += 1
            if progress is not None:
                progress.value = 40 + int((progress_counter / total_pairs) * 40)
            if status_label is not None:
                status_label.value = f"Computing contacts... ({progress_counter}/{total_pairs})"

    return contact_map


def plot_and_save_heatmap(contact_map, output_file):
    """
    Draw a symmetric chain-by-chain contact heatmap and write it to `output_file`.

    Args:
        contact_map: {(chain1, chain2): count} or {(chain1, chain2): {"count": n, ...}}
        output_file: image path passed to matplotlib's savefig (300 dpi, tight bounding box).
    """
    chain_labels = sorted({chain for pair in contact_map for chain in pair})
    position = {chain: k for k, chain in enumerate(chain_labels)}
    size = len(chain_labels)
    matrix = np.zeros((size, size))

    for (first, second), entry in contact_map.items():
        value = float(entry["count"] if isinstance(entry, dict) else entry)
        row, col = position[first], position[second]
        matrix[row, col] = value
        matrix[col, row] = value  # mirror so the map is symmetric

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(matrix, xticklabels=chain_labels, yticklabels=chain_labels,
                cmap="Blues", square=True, ax=ax)
    ax.set_xlabel("Chain ID")
    ax.set_ylabel("Chain ID")
    ax.set_title("Contact Heatmap")
    fig.savefig(output_file, dpi=300, bbox_inches='tight')
    plt.close(fig)
    print(f"✅ Heatmap saved as {output_file}")


def get_chain_palette(n):
    """Return `n` colours from seaborn's "hls" palette, as hex strings."""
    import seaborn as sns

    palette = sns.color_palette("hls", n)
    return palette.as_hex()

def plot_and_save_interactive_network(contact_map, output_file, edge_style):
    """
    Create and save an interactive chain-level network visualization using ipysigma.

    Args:
        contact_map: {(chain1, chain2): count} or {(chain1, chain2): {"count": n, ...}}
        output_file: path of the HTML file to write.
        edge_style: passed to ipysigma as `default_edge_type`.

    Each node is a chain and its size is the total contact count over all pairs the
    chain takes part in (previously it was only the weight of the last pair seen).
    Each edge's weight is the pair's contact count.
    """
    G = nx.Graph()

    # Extract all unique chain IDs
    chains = sorted({c for pair in contact_map for c in pair})
    chain_palette = get_chain_palette(len(chains))
    palette = dict(zip(chains, chain_palette))

    node_totals = dict.fromkeys(chains, 0)
    for (chain1, chain2), contact_info in contact_map.items():
        weight = int(contact_info["count"] if isinstance(contact_info, dict) else contact_info)

        node_totals[chain1] += weight
        if chain2 != chain1:
            node_totals[chain2] += weight

        if G.has_edge(chain1, chain2):
            G[chain1][chain2]['weight'] += weight
        else:
            G.add_edge(chain1, chain2, weight=weight)

    # Assign chain ID as the categorical color label (not actual hex color);
    # `palette` maps each category to its hex colour.
    for chain, total in node_totals.items():
        G.add_node(chain, color=chain, node_size=int(total))

    Sigma.write_html(
        G,
        output_file,
        fullscreen=True,
        node_color="color",
        node_metrics=["louvain"],
        node_size="node_size",
        node_size_range=(3, 20),
        max_categorical_colors=30,
        edge_size="weight",
        edge_size_range=(5, 15),
        default_edge_type=edge_style,
        node_border_color_from="node",
        default_node_label_size=16,
        node_color_palette=palette
    )
    print(f"✅ Interactive network saved as {output_file}")




def plot_networks_with_selfloops(contact_map, output_html, edge_style, output_tiff, dpi=300):
    """
    Generates two network visualizations:
    1. An interactive HTML file without self-loops using ipysigma.
    2. A high-resolution TIFF image with self-loops (drawn as red circles) using NetworkX and Matplotlib.

    Parameters:
    - contact_map: dict
        Dictionary where keys are (chain1, chain2) and values are either a count or dict with "count".
        Keys with chain1 == chain2 (intra-chain contacts) become self-loops.
    - output_html: str
        Path to the interactive HTML file.
    - edge_style: str
        Passed to ipysigma as `default_edge_type`.
    - output_tiff: str
        Path to the static .tiff image file.
    - dpi: int
        Resolution for the static image (default 300).

    Node size is the total contact count of the chain over all its pairs.
    """
    # --- Build graph ---
    G = nx.Graph()

    # Extract unique chain IDs
    chains = sorted({c for pair in contact_map for c in pair})

    chain_palette = get_chain_palette(len(chains))
    palette = dict(zip(chains, chain_palette))

    node_totals = dict.fromkeys(chains, 0)
    for (chain1, chain2), contact_info in contact_map.items():
        weight = int(contact_info["count"] if isinstance(contact_info, dict) else contact_info)

        node_totals[chain1] += weight
        if chain2 != chain1:
            node_totals[chain2] += weight

        if G.has_edge(chain1, chain2):
            G[chain1][chain2]['weight'] += weight
        else:
            G.add_edge(chain1, chain2, weight=weight)

    # The chain ID is the categorical colour key; `palette` maps it to a hex colour.
    for chain, total in node_totals.items():
        G.add_node(chain, color=chain, node_size=int(total))

    # --- Interactive HTML without self-loops ---
    G_no_selfloops = G.copy()
    G_no_selfloops.remove_edges_from(list(nx.selfloop_edges(G_no_selfloops)))

    Sigma.write_html(
        G_no_selfloops,
        output_html,
        fullscreen=True,
        node_color="color",  # category name
        node_metrics=["louvain"],
        node_size="node_size",
        node_size_range=(3, 20),
        max_categorical_colors=30,
        edge_size="weight",
        edge_size_range=(5, 15),
        default_edge_type=edge_style,
        node_border_color_from="node",
        default_node_label_size=16,
        node_color_palette=palette  # category -> color
    )
    print(f"✅ Interactive network saved as {output_html}")

    # --- Static .tiff with self-loops ---
    pos = nx.spring_layout(G, seed=42)  # Reproducible layout

    fig = plt.figure(figsize=(10, 10), dpi=dpi)

    # Use same node color palette
    node_colors = [palette.get(node, "#999999") for node in G.nodes()]
    nx.draw_networkx_nodes(G, pos, node_size=300, node_color=node_colors)
    nx.draw_networkx_labels(G, pos, font_size=10)

    # Draw inter-chain edges; self-loops are drawn separately below.
    other_edges = [(u, v) for u, v in G.edges() if u != v]
    nx.draw_networkx_edges(G, pos, edgelist=other_edges, width=1.5, edge_color='gray')

    # Draw self-loops (intra-chain contacts) as red circles around the node
    ax = plt.gca()
    for node in G.nodes():
        if G.has_edge(node, node):
            ax.add_patch(plt.Circle(pos[node], 0.05, color='red', fill=False, linewidth=1.5))

    plt.axis('off')
    plt.tight_layout()
    plt.savefig(output_tiff, format='tiff', dpi=dpi)
    plt.close(fig)
    print(f"✅ High-resolution network image saved as {output_tiff}")


def plot_residue_level_network(contact_map, output_file="network_residue_level.html", min_contacts=10):
    """
    Generates and saves a residue-level interactive contact network using ipysigma.

    Args:
        contact_map: dict from compute_contacts_with_residues; entries that are not
            dicts (plain chain-level counts) carry no residue pairs and are skipped.
        output_file: HTML output path
        min_contacts: residues with fewer graph neighbours (degree) than this are drawn
            small and light gray; the others are coloured by chain.
    """
    import networkx as nx
    from ipysigma import Sigma

    low_degree_color = "#cccccc"  # light gray
    G = nx.Graph()

    # Build the graph in one pass: each residue pair becomes an edge whose weight
    # counts how many times the pair occurs across the chain pairs of the map.
    for data in contact_map.values():
        if not isinstance(data, dict):
            continue
        for res1, res2 in data.get("residue_pairs", []):
            if G.has_edge(res1, res2):
                G[res1][res2]["weight"] += 1
            else:
                G.add_edge(res1, res2, weight=1)

    # Colour palette: one colour per chain (residue ids look like "A:101"), plus an
    # explicit entry for the gray category used for low-degree residues.
    chains = sorted({residue.split(":")[0] for residue in G.nodes()})
    chain_palette = get_chain_palette(len(chains))
    palette = dict(zip(chains, chain_palette))
    palette[low_degree_color] = low_degree_color

    # Style nodes based on degree
    for node in G.nodes():
        if G.degree(node) < min_contacts:
            color, size = low_degree_color, 2   # smaller, gray
        else:
            color, size = node.split(":")[0], 5  # default size, chain colour
        G.nodes[node].update(label=node, color=color, node_size=int(size))

    Sigma.write_html(
        G,
        output_file,
        fullscreen=True,
        node_color="color",
        node_size="node_size",
        node_size_range=(2, 8),
        edge_size="weight",
        edge_size_range=(1, 5),
        default_edge_type="curve",
        node_border_color_from="node",
        default_node_label_size=12,
        node_color_palette=palette  # for valid color categories
    )

    print(f"✅ Residue-level network saved as {output_file}")



def analyze_structure(file_path, cutoff=None, network_file=None, heatmap_file="heatmap.tiff", edge_style=None):
    """
    Pipeline: parse file, compute inter-chain contacts, generate visualizations.

    Args:
        file_path: .pdb or .cif structure file.
        cutoff: distance cutoff in Å; None reads `cutoff_widget` at call time
            (the old default was frozen when the function was defined).
        network_file: HTML output path; None reads `network_file_widget`.
        heatmap_file: heatmap image path (default 'heatmap.tiff').
        edge_style: ipysigma edge type; None reads `edge_style_widget`.

    Errors are printed rather than raised.
    """
    if cutoff is None:
        cutoff = cutoff_widget.value
    if network_file is None:
        network_file = network_file_widget.value
    if edge_style is None:
        edge_style = edge_style_widget.value

    if not os.path.exists(file_path):
        print(f"❌ File not found: {file_path}")
        return

    print(f"🔍 Analyzing {file_path} with cutoff {cutoff} Å ...")
    try:
        chain_atoms = parse_structure(file_path, atom_mode=atom_selector.value)
        contact_map = compute_contacts(chain_atoms, cutoff)

        if not contact_map:
            print("⚠️ No inter-chain contacts found.")
            return

        plot_and_save_interactive_network(contact_map, network_file, edge_style)
        plot_and_save_heatmap(contact_map, heatmap_file)
    except Exception as e:
        print(f"❌ Error during analysis: {e}")


def export_contacts_table(contact_map, output_file, structure_path):
    """
    Export residue-residue contacts to TSV including residue names and sequence numbers.

    Args:
        contact_map: dict with 'residue_pairs' as (res1_id, res2_id), where ids look
            like "A:101"; entries without residue pairs (e.g. plain integer counts)
            are skipped.
        output_file: output path to save the TSV
        structure_path: path to the PDB or CIF file (extension matched
            case-insensitively) used to look up residue names

    Raises:
        ValueError: unsupported file extension.
    """
    lowered_path = structure_path.lower()
    if lowered_path.endswith(".pdb"):
        parser = PDBParser(QUIET=True)
    elif lowered_path.endswith(".cif"):
        parser = MMCIFParser(QUIET=True)
    else:
        raise ValueError("Unsupported file format.")

    structure = parser.get_structure("structure", structure_path)

    # Lookup "A:1002" -> "ALA". The first occurrence in file order wins, and waters
    # are skipped because they can reuse a polymer residue's number in the same chain
    # and would otherwise overwrite its name.
    residue_names = {}
    for model in structure:
        for chain in model:
            for residue in chain:
                if residue.id[0] == "W":
                    continue
                key = f"{chain.id}:{residue.id[1]}"
                residue_names.setdefault(key, residue.get_resname().strip())

    with open(output_file, "w", newline='') as f:
        writer = csv.writer(f, delimiter='\t')
        writer.writerow(["chain1", "res1", "resname1", "chain2", "res2", "resname2"])

        for data in contact_map.values():
            if not isinstance(data, dict) or "residue_pairs" not in data:
                continue
            for res1, res2 in data["residue_pairs"]:
                # The residue number is part of the id, so it is always known even
                # when the name lookup misses.
                chain_id1, resseq1 = res1.rsplit(":", 1)
                chain_id2, resseq2 = res2.rsplit(":", 1)
                writer.writerow([
                    chain_id1, resseq1, residue_names.get(res1, "UNK"),
                    chain_id2, resseq2, residue_names.get(res2, "UNK"),
                ])

    print(f"✅ Contact table saved to: {output_file}")




def download_pdb(pdb_id, output_path_base, timeout=30):
    """
    Downloads a PDB structure file in .pdb format. If not available, tries .cif format.

    Args:
        pdb_id (str): The PDB ID (e.g., "1TUP"); surrounding whitespace is ignored and
            only letters, digits and underscores are accepted.
        output_path_base (str): Base path without extension.
        timeout (float): Seconds to wait for each HTTP request before giving up.

    Returns:
        str or None: Path to the downloaded file, or None if the ID is invalid or both
        downloads failed.
    """
    import re

    pdb_id = pdb_id.strip().upper()
    # The ID is interpolated into a URL, so reject anything that is not a plain identifier.
    if not re.fullmatch(r"[A-Z0-9_]+", pdb_id):
        print(f"❌ Invalid PDB ID: {pdb_id!r}")
        return None

    import requests

    # Try .pdb format first, then .cif as fallback
    for ext in ("pdb", "cif"):
        url = f"https://files.rcsb.org/download/{pdb_id}.{ext}"
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
        except requests.exceptions.RequestException:
            if ext == "pdb":
                print(f"⚠️  Failed to download {pdb_id}.pdb. Trying .cif format...")
            continue

        output_path = f"{output_path_base}.{ext}"
        # Write the raw bytes so the file is stored exactly as served.
        with open(output_path, 'wb') as f:
            f.write(response.content)
        print(f"✅ Successfully downloaded {pdb_id}.{ext}")
        return output_path

    print(f"❌ Failed to download both .pdb and .cif for {pdb_id}")
    return None



In [None]:
# @title
display(Markdown("## 📝 Main"))


def _resolve_input_file():
    """
    Return a local path to the structure to analyse, or None after setting `progress_label`.

    An uploaded file takes precedence over the PDB ID field. Both FileUpload value
    layouts are handled: a dict {name: {"content": ...}} (ipywidgets 7) and a tuple
    of dicts with "name" and "content" keys (ipywidgets 8).
    """
    uploaded = upload_widget.value
    if uploaded:
        if isinstance(uploaded, dict):
            name, info = next(iter(uploaded.items()))
        else:
            info = uploaded[0]
            name = info["name"]
        # Keep only the base name so the upload name cannot point outside /tmp.
        file_path = os.path.join("/tmp", os.path.basename(name))
        with open(file_path, 'wb') as f:
            f.write(info['content'])
        return file_path

    pdb_id = pdb_id_widget.value.strip()
    if pdb_id:
        file_path = download_pdb(pdb_id, f"/tmp/{pdb_id.upper()}")
        if file_path is None:
            progress_label.value = "❌ File download failed."
        return file_path

    progress_label.value = "❌ No input file provided."
    return None


def run_analysis():
    """
    Run the full ProtCNet pipeline using the current widget values.

    Stages: resolve the input structure, parse it, compute contacts, write the
    chain-level network (HTML, plus a TIFF when Contact Type is not inter-chain),
    the heatmap ('heatmap.tiff') and, when residue tracking is on, the contact TSV
    and the optional residue-level network. Status is reported through
    `progress_bar` / `progress_label`; a failing stage stops the run.
    """
    start_time = time.time()
    progress_bar.value = 0
    progress_label.value = "Starting analysis..."

    cutoff = cutoff_widget.value
    mode = contact_mode_widget.value  # e.g. 'inter', 'intra', 'all'
    sequence_distance = sequence_distance_widget.value
    residue_filter = interaction_filter_widget.value
    track_residues = track_residues_widget.value
    network_file = network_file_widget.value
    # Derive sibling output names from the network file name. Unlike
    # str.replace('.html', ...), this never maps the TIFF onto the HTML path when
    # the name has no '.html' extension.
    output_base = os.path.splitext(network_file)[0]

    file_path = _resolve_input_file()
    if file_path is None:
        return

    progress_bar.value = 20
    progress_label.value = "Parsing structure..."

    try:
        parsed = parse_structure(
            file_path,
            atom_mode=atom_selector.value,
            residue_filter=residue_filter,
            with_residues=track_residues
        )
    except Exception as e:
        progress_label.value = f"❌ Parsing failed: {e}"
        return

    if track_residues:
        chain_atoms, residue_map = parsed
    else:
        chain_atoms = parsed

    progress_bar.value = 40
    progress_label.value = "Starting contact calculations..."

    try:
        if track_residues:
            contact_map = compute_contacts_with_residues(
                chain_atoms,
                residue_map,
                cutoff=cutoff,
                mode=mode,
                sequence_distance=sequence_distance,
                progress=progress_bar,
                status_label=progress_label
            )
        else:
            # Without residue tracking only inter-chain contacts are computed;
            # `mode` and `sequence_distance` are not used on this path.
            contact_map = compute_contacts(
                chain_atoms,
                cutoff=cutoff,
                progress=progress_bar,
                status_label=progress_label
            )
    except Exception as e:
        progress_label.value = f"❌ Contact computation failed: {e}"
        return

    if not contact_map:
        progress_label.value = "⚠️ No contacts found."
        progress_bar.value = 100
        return

    progress_label.value = "Generating interactive network..."
    try:
        if mode == 'inter':
            plot_and_save_interactive_network(contact_map, network_file, edge_style_widget.value)
        else:
            plot_networks_with_selfloops(
                contact_map,
                network_file,
                edge_style_widget.value,
                f"{output_base}.tiff"
            )
        progress_bar.value = 90
    except Exception as e:
        progress_label.value = f"❌ Network plot failed: {e}"
        return

    progress_label.value = "Generating heatmap..."
    try:
        plot_and_save_heatmap(contact_map, 'heatmap.tiff')
        progress_bar.value = 95
    except Exception as e:
        progress_label.value = f"❌ Heatmap plot failed: {e}"
        return

    if track_residues:
        # Optional outputs: a failure here is reported but does not abort the run.
        try:
            progress_label.value = "Generating contacts table..."
            export_contacts_table(
                contact_map,
                f"{output_base}_contacts.tsv",
                file_path
            )
            if residue_level_net_widget.value:
                progress_label.value = "Generating residue level network"
                plot_residue_level_network(
                    contact_map,
                    output_file="network_residue_level.html",
                    min_contacts=residue_level_cutoff_widget.value
                )
        except Exception as e:
            print(f"⚠️ Failed to export contact table or residue network: {e}")

    progress_bar.value = 100
    elapsed = time.time() - start_time
    progress_label.value = f"✅ Analysis complete! Runtime: {elapsed:.2f} seconds"


In [None]:
# @title

# Button that triggers the analysis. Anything the pipeline prints while handling a
# click (e.g. "✅ ... saved as ...") is captured in `run_output`, an Output widget
# displayed below the progress bar, instead of depending on where the front-end
# routes prints from widget callbacks.
run_button = widgets.Button(description="Run Analysis")
run_output = widgets.Output()


def _on_run_clicked(_button):
    """Click handler: run the analysis with its printed output captured in `run_output`."""
    with run_output:
        run_analysis()


run_button.on_click(_on_run_clicked)
display(run_button)

# Add progress bar
progress_bar = widgets.IntProgress(value=0, min=0, max=100, description='Progress:', bar_style='info')
progress_label = widgets.Label(value="Ready")
display(progress_bar, progress_label, run_output)


In [None]:
from google.colab import files

# Download the interactive network, heatmap and optional outputs. Files that were not
# produced by the last run (e.g. the contacts TSV when residue tracking was off) are
# skipped instead of raising, which previously aborted the remaining downloads.
_network_file = network_file_widget.value
_output_base = os.path.splitext(_network_file)[0]

_outputs = [_network_file, 'heatmap.tiff', f"{_output_base}_contacts.tsv"]
if contact_mode_widget.value != 'inter':
    _outputs.append(f"{_output_base}.tiff")
if residue_level_net_widget.value:
    _outputs.append('network_residue_level.html')

for _path in _outputs:
    if os.path.exists(_path):
        files.download(_path)
    else:
        print(f"⚠️ Skipping {_path}: file not found")

Thanks for using ProtCNet — I hope it worked well for you!
