In [None]:
%pip install selene-sdk logomaker dash

In [None]:
%pip install torch_fftconv

This cell imports all the required Python libraries and modules for the analysis and visualization. It includes libraries for numerical operations (`numpy`, `pandas`, `torch`), plotting (`matplotlib`, `plotly`, `logomaker`), genomic data handling (`selene_sdk`), and building the interactive Dash application (`dash`). The PyTorch device (GPU if available, otherwise CPU) is selected later, in the cell that loads the model.

In [None]:
import os
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import selene_sdk
import logomaker
import matplotlib.pyplot as plt
import base64
from io import BytesIO
from dash import Dash, html, dcc, Input, Output, no_update
import tempfile
import os

In [None]:
from torch_fftconv import FFTConv1d

# Simple model: motif convolution -> Softplus -> long FFT convolution -> sigmoid
class SimpleNetModified_DA_SSE(nn.Module):
    """Two-layer 1D convolutional network with two sigmoid output channels.

    A width-51 convolution produces 40 channels, Softplus is applied, and a
    width-601 FFT convolution maps those 40 channels to 2 outputs. Both
    convolutions are padded to keep the sequence length; the output is then
    cropped by 500 positions at each end.
    """

    def __init__(self, input_channels=4):
        """Build the layers; ``input_channels`` is the number of input channels."""
        super().__init__()
        # The attribute names double as state_dict keys, so they must not change.
        self.conv = nn.Conv1d(
            in_channels=input_channels,
            out_channels=40,
            kernel_size=51,
            padding=25,
        )
        self.activation = nn.Softplus()
        self.deconv = FFTConv1d(40, 2, kernel_size=601, padding=300)

    def forward(self, x):
        """Return the cropped sigmoid outputs for input ``x``.

        The result is indexed as ``[:, :, 500:-500]``, i.e. it is a 3-D tensor
        whose last axis is 1000 positions shorter than the input's.
        """
        motif_activations = self.activation(self.conv(x))
        # Each output channel gets its own sigmoid (no softmax across channels).
        per_channel_prob = torch.sigmoid(self.deconv(motif_activations))
        return per_channel_prob[:, :, 500:-500]

This cell defines two helper functions:

1.  `add_exon_rectangle`: This function adds a shaded rectangle to a Plotly figure to visually represent the exon region within the sequence.
2.  `plot_motif_logo_to_file`: This function takes a position weight matrix (PWM) for a motif, generates a sequence logo using `logomaker`, and saves it as a PNG image to a temporary file.

In [None]:
# ------------------------------
# Shade the exon span on one subplot
# ------------------------------
def add_exon_rectangle(fig, acceptor_idx, donor_idx, row, col, y_min, y_max, color="rgba(0, 255, 0, 0.3)"):
    """
    Draw a borderless filled rectangle marking the exon on a single subplot.

    Parameters:
    - fig: plotly figure that receives the shape
    - acceptor_idx: x-coordinate of the acceptor site (one vertical edge)
    - donor_idx: x-coordinate of the donor site (the other vertical edge)
    - row: subplot row
    - col: subplot column
    - y_min: y value of one horizontal edge of the rectangle
    - y_max: y value of the other horizontal edge of the rectangle
    - color: fill color, including transparency
    """
    rectangle = {
        "type": "rect",
        "x0": acceptor_idx,
        "x1": donor_idx,
        "y0": y_min,
        "y1": y_max,
        "fillcolor": color,
        "line": {"width": 0},  # no outline, fill only
    }
    fig.add_shape(row=row, col=col, **rectangle)

# ------------------------------
# Function to plot motif logo and save to temporary file
# ------------------------------
def plot_motif_logo_to_file(motifpwm, title=None, filename=None):
    """Render a motif as a sequence logo, save it as a PNG and return the path.

    Parameters:
    - motifpwm: 2-D array-like with one row per motif position and four
      columns, interpreted in the order A, C, G, T.
    - title: optional title drawn above the logo.
    - filename: destination path. When None, a new ``.png`` temporary file is
      created with ``tempfile.mkstemp``; it is not deleted automatically, so
      the caller owns its cleanup.

    Returns:
    - The path of the written PNG file.

    The figure is always closed, even if drawing or saving raises, so that
    repeated calls (one per motif) do not accumulate open matplotlib figures.
    """
    # Small canvas: the logos are shown as thumbnails in a tooltip
    fig, ax = plt.subplots(figsize=(3.5, 1.2))
    try:
        pwm_frame = pd.DataFrame(motifpwm, columns=['A', 'C', 'G', 'T'])
        logo = logomaker.Logo(pwm_frame,
                              shade_below=.5,
                              fade_below=.5,
                              font_name='Arial',
                              ax=ax,
                              color_scheme='classic')

        # Hide every spine, then re-enable only the left and bottom ones
        logo.style_spines(visible=False)
        logo.style_spines(spines=['left', 'bottom'], visible=True, linewidth=0.5)

        # Integer position labels, rotated to fit the narrow figure
        logo.style_xticks(rotation=90, fmt='%d', anchor=0)

        # Small tick labels on both axes
        ax.tick_params(axis='x', labelsize=5, pad=1)
        ax.tick_params(axis='y', labelsize=5, pad=1)

        # Thin spine lines
        for spine in ax.spines.values():
            spine.set_linewidth(0.5)

        if title is not None:
            ax.set_title(title, fontsize=7, pad=3)

        # No y-axis label
        ax.set_ylabel("")

        # Operate on this figure explicitly instead of pyplot's "current figure"
        fig.tight_layout(pad=0.5)

        if filename is None:
            # mkstemp returns an open descriptor; close it and let savefig
            # reopen the path by name.
            fd, filename = tempfile.mkstemp(suffix='.png')
            os.close(fd)

        fig.savefig(filename, format='png', dpi=150, bbox_inches='tight',
                    facecolor='white', edgecolor='none')
    finally:
        plt.close(fig)

    return filename

The following cells load the human genome sequence using `selene_sdk`, initialize the splicing model (`SimpleNetModified_DA_SSE`), and load its pre-trained weights. Further below, an example exon is defined by its genomic coordinates (set directly in the code), the genomic sequence around the exon is extracted, and it is converted into a tensor format suitable for the model.

In [None]:
# ------------------------------
# Load genome and model
# ------------------------------
# Download and unzip the genome file
!wget https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_49/GRCh38.primary_assembly.genome.fa.gz
!gunzip GRCh38.primary_assembly.genome.fa.gz

genome = selene_sdk.sequences.Genome(
    input_path='GRCh38.primary_assembly.genome.fa'
)



--2025-11-03 22:20:37--  https://ftp.ebi.ac.uk/pub/databases/gencode/Gencode_human/release_49/GRCh38.primary_assembly.genome.fa.gz
Resolving ftp.ebi.ac.uk (ftp.ebi.ac.uk)... 193.62.193.165
Connecting to ftp.ebi.ac.uk (ftp.ebi.ac.uk)|193.62.193.165|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 845635028 (806M) [application/x-gzip]
Saving to: ‘GRCh38.primary_assembly.genome.fa.gz.1’


2025-11-03 22:21:12 (23.8 MB/s) - ‘GRCh38.primary_assembly.genome.fa.gz.1’ saved [845635028/845635028]

gzip: GRCh38.primary_assembly.genome.fa already exists; do you wish to overwrite (y or n)? n
	not overwritten


In [None]:
# Choose the device first: the input tensor is later moved to `device`, so the
# model has to live there too or the forward pass fails on a GPU machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

net = SimpleNetModified_DA_SSE()

# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
state_dict = torch.load('model.rep7.pth', map_location=torch.device('cpu'))

# strict=False tolerates key mismatches; report them instead of silently
# leaving some layers at their random initialization.
load_result = net.load_state_dict(state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print('WARNING: checkpoint/model key mismatch -',
          'missing:', list(load_result.missing_keys),
          'unexpected:', list(load_result.unexpected_keys))

net.to(device)
net.eval()



SimpleNetModified_DA_SSE(
  (conv): Conv1d(4, 40, kernel_size=(51,), stride=(1,), padding=(25,))
  (activation): Softplus(beta=1.0, threshold=20.0)
  (deconv): FFTConv1d(40, 2, kernel_size=(601,), stride=(1,), padding=(300,))
)

The following cells extract and normalize the convolution weights and generate a motif logo for each motif using the `plot_motif_logo_to_file` function. After the example exon is defined, they perform the forward pass of the input sequence through the loaded splicing model, calculating the model's predictions (`pred`) and the activations of the convolutional layer (`postact_motif`). They then compute the effect of each motif on the predicted donor and acceptor splice sites by convolving the motif activations with the deconvolutional weights. Finally, the per-motif effects are trimmed to the region where predictions are made and summed to get the total effect for both donor and acceptor sites.

In [None]:
# One motif per output channel of the first convolution (40 for this model);
# read it from the model so the two can never disagree.
NUM_MOTIFS = net.conv.out_channels

In [None]:
# ------------------------------
# Center the convolution weights
# ------------------------------
conv_weight = net.conv.weight.detach().cpu().numpy()
# Subtract the mean over axis 1 (the input-channel axis of the Conv1d weight)
# so that, at every kernel position, the four channel weights sum to zero.
channel_mean = conv_weight.mean(axis=1, keepdims=True)
conv_weight_norm = conv_weight - channel_mean

# ------------------------------
# Render one logo PNG per motif and remember the file paths
# ------------------------------
# Each kernel is transposed so that rows are positions and columns are the
# four input channels, as plot_motif_logo_to_file expects.
motif_logo_files = [
    plot_motif_logo_to_file(conv_weight_norm[motif_idx].T, title=f"Motif {motif_idx}")
    for motif_idx in range(NUM_MOTIFS)
]

### **Use chr1	234433410	234433571	- as an example**

---



Change the code below to the coordinate of interest

In [None]:
# ------------------------------
# Example exon input
# ------------------------------
idx = 0  # not used by the cells below; kept in case other code refers to it
# chr1	234433410	234433571

# Extract genomic coordinates (1-based start, as in the line above)
chrom = 'chr1'
start = 234433410
end = 234433571
strand = '-'
start = start - 1   # fix to 0-based index
exon_length = end - start

# Preview of the exon with 10 nt of context on each side.
# For a '-' strand exon the sequences come back in transcript orientation, so
# the genomic region *after* `end` is the upstream flank and the region
# *before* `start` is the downstream flank. Printing them in genomic order
# would show the donor-side "GT..." flank in front of the exon.
genomic_left = genome.get_sequence_from_coords(chrom, start - 10, start, strand)
exon_seq = genome.get_sequence_from_coords(chrom, start, end, strand)
genomic_right = genome.get_sequence_from_coords(chrom, end, end + 10, strand)

if strand == '-':
    upstream_flank, downstream_flank = genomic_right, genomic_left
else:
    upstream_flank, downstream_flank = genomic_left, genomic_right

print('Len Exon: ' + str(exon_length) + '  Exon preview: ',
      upstream_flank + '|' + exon_seq + '|' + downstream_flank)


Len Exon: 162  Exon preview:  GTAAGCCTCT|GTTTCAGATCTGGATCGTTGCCATTTATACCTGATGGTGTTAACTGAGCTTATAAATCTGCATTTGAAGGTTGGGTGGAAAAGGGGTAACCCTATCTGGAGAGTTATTTCTCTTTTGAAAAATGCATCCATTCAGCATCTTCAAGAGATGGACAGTGGACAG|CTTGTCTTAG


In [None]:
# Encode the exon plus `flank` positions of context on each side
flank = 550
seq = genome.get_encoding_from_coords(chrom, start - flank, end + flank, strand)

# Add a leading batch axis, then swap the last two axes so the tensor is laid
# out the way the model's Conv1d consumes it.
seq_tensor = torch.FloatTensor(seq).unsqueeze(0).permute(0, 2, 1).to(device)
seq_len = seq_tensor.size(2)

In [None]:
# ------------------------------
# Forward pass
# ------------------------------
trim = 500  # positions cropped from each end of the model output
with torch.no_grad():
    # First-layer activations, recomputed here because net() only returns the
    # final prediction.
    postact_motif = net.activation(net.conv(seq_tensor))
    pred = net(seq_tensor)

seq_len = postact_motif.size(2)
pred_len = pred.size(-1)

# X coordinates: every input position, and the cropped prediction window
x_full = np.arange(seq_len)
x_pred = np.arange(start=trim, stop=seq_len - trim)

In [None]:
# ------------------------------
# Compute motif effects for both outputs
# ------------------------------
dweight = net.deconv.weight.cpu().detach().numpy()
activations_np = postact_motif[0].cpu().numpy()


def _effects_for_output(out_channel):
    """Per-motif contribution tracks for one output channel of the deconv layer.

    np.convolve flips its kernel, so the kernel is reversed beforehand; the
    net result is a sliding dot product of each activation track with the
    unflipped kernel, at the same length as the activation track.
    """
    return np.array([
        np.convolve(activations_np[motif_idx], dweight[out_channel, motif_idx, ::-1], mode='same')
        for motif_idx in range(NUM_MOTIFS)
    ])


# Output channel 0 is read as donor and channel 1 as acceptor throughout the
# notebook. Keep only the window for which the model emits predictions.
effects_motifs_donor = _effects_for_output(0)[:, trim:-trim]
effects_motifs_acceptor = _effects_for_output(1)[:, trim:-trim]

# Total effect at each position = sum over motifs
effects_sum_donor = effects_motifs_donor.sum(axis=0)
effects_sum_acceptor = effects_motifs_acceptor.sum(axis=0)

In [None]:
# ------------------------------
# Determine true splice site positions
# ------------------------------
# The encoded window starts `flank` positions before the exon, so the exon's
# first base sits at index `flank` and its last base exon_length - 1 later.
acceptor_idx = flank
donor_idx = flank + exon_length - 1


# ------------------------------
# Calculate y-ranges for rectangle positioning
# ------------------------------
def _exon_band(values):
    """Return (data_min, data_max, band_bottom, band_top) for an exon marker.

    The band sits directly under the data: its top is the data minimum and its
    height is 10% of the data range.
    """
    lo, hi = values.min(), values.max()
    return lo, hi, lo - 0.1 * (hi - lo), lo


# Track 1: Motif activations
motif_activations = postact_motif[0, :, :].cpu().numpy()
(y_min_track1, y_max_track1,
 rect_y_min_track1, rect_y_max_track1) = _exon_band(motif_activations)

# Track 2: Motif effects
(y_min_track2_acceptor, y_max_track2_acceptor,
 rect_y_min_track2_acceptor, rect_y_max_track2_acceptor) = _exon_band(effects_motifs_acceptor)
(y_min_track2_donor, y_max_track2_donor,
 rect_y_min_track2_donor, rect_y_max_track2_donor) = _exon_band(effects_motifs_donor)

# Track 3: Sum of effects
(y_min_track3_acceptor, y_max_track3_acceptor,
 rect_y_min_track3_acceptor, rect_y_max_track3_acceptor) = _exon_band(effects_sum_acceptor)
(y_min_track3_donor, y_max_track3_donor,
 rect_y_min_track3_donor, rect_y_max_track3_donor) = _exon_band(effects_sum_donor)

# Track 4: Predicted signals (output channel 1 = acceptor, channel 0 = donor)
pred_acceptor = pred[0, 1, :].cpu().numpy()
pred_donor = pred[0, 0, :].cpu().numpy()
(y_min_track4_acceptor, y_max_track4_acceptor,
 rect_y_min_track4_acceptor, rect_y_max_track4_acceptor) = _exon_band(pred_acceptor)
(y_min_track4_donor, y_max_track4_donor,
 rect_y_min_track4_donor, rect_y_max_track4_donor) = _exon_band(pred_donor)

The previous cell determines the true splice site positions based on the exon information and calculates the appropriate y-axis ranges for positioning the exon rectangle annotation on each subplot based on the data being plotted in that subplot (motif activations, motif effects, sum of effects, and predicted signals). The next cell creates the main Plotly figure with multiple subplots to display the analysis results. It sets up the layout and titles for each subplot.

In [None]:
# ------------------------------
# Create the main figure (native hover labels are disabled on motif traces;
# the Dash tooltip replaces them)
# ------------------------------
fig = make_subplots(
    rows=4, cols=2,
    column_widths=[0.5, 0.5],
    row_heights=[0.25, 0.25, 0.25, 0.25],
    specs=[
        [{"colspan": 2}, None],  # row 1 is a single full-width panel
        [{}, {}],
        [{}, {}],
        [{}, {}]
    ],
    subplot_titles=(
        "Motif Activations",
        "Motif Effects (Acceptor)", "Motif Effects (Donor)",
        "Sum Motif Effects (Acceptor)", "Sum Motif Effects (Donor)",
        "Predicted Signals with True Sites (Acceptor)",
        "Predicted Signals with True Sites (Donor)"
    )
)

# Colors
colors_motif = [f"hsl({i*360/NUM_MOTIFS},70%,50%)" for i in range(NUM_MOTIFS)]
color_donor = "blue"
color_acceptor = "orange"
color_exon = "rgba(0, 255, 0, 0.8)"


def _add_motif_trace(values, motif_idx, name, row, col, showlegend=False):
    """Add one per-motif line to `fig`.

    `customdata` carries the motif index on every point so the Dash hover
    callback can pick the right logo. hoverinfo="none" suppresses plotly's own
    hover label while still firing hover events, so only the Dash tooltip is
    shown for these traces.
    """
    fig.add_trace(
        go.Scatter(
            x=x_pred,
            y=values,
            name=name,
            line=dict(color=colors_motif[motif_idx], width=1),
            legendgroup=f"motif_{motif_idx}",  # one legend entry toggles all three tracks
            showlegend=showlegend,
            customdata=[motif_idx] * len(x_pred),
            hoverinfo="none",
        ),
        row=row, col=col
    )


# ------------------------------
# Track 1: Motif activations (shared)
# ------------------------------
# postact_motif covers the whole input window, whereas x_pred only spans the
# cropped prediction window. Crop the activations the same way so x and y line
# up point for point; plotting the full-length array against x_pred would
# shift the whole track by `trim` positions.
activations_pred = postact_motif[0, :, trim:seq_len - trim].cpu().numpy()
assert activations_pred.shape[-1] == len(x_pred), "activation window does not match x_pred"

for i in range(NUM_MOTIFS):
    _add_motif_trace(activations_pred[i], i, f"Motif {i}", row=1, col=1, showlegend=True)

# Add exon rectangle to Track 1
add_exon_rectangle(
    fig, acceptor_idx, donor_idx, 1, 1,
    rect_y_min_track1, rect_y_max_track1, color_exon
)

# ------------------------------
# Track 2: Motif effects (separate)
# ------------------------------
for i in range(NUM_MOTIFS):
    _add_motif_trace(effects_motifs_acceptor[i], i, f"Motif {i} - Acceptor Effect", row=2, col=1)
    _add_motif_trace(effects_motifs_donor[i], i, f"Motif {i} - Donor Effect", row=2, col=2)

# Add exon rectangles to Track 2
add_exon_rectangle(fig, acceptor_idx, donor_idx, 2, 1,
                   rect_y_min_track2_acceptor, rect_y_max_track2_acceptor, color_exon)
add_exon_rectangle(fig, acceptor_idx, donor_idx, 2, 2,
                   rect_y_min_track2_donor, rect_y_max_track2_donor, color_exon)

# ------------------------------
# Track 3: Sum of effects
# ------------------------------
fig.add_trace(
    go.Scatter(
        x=x_pred,
        y=effects_sum_acceptor,
        line=dict(color=color_acceptor, width=2),
        name="Sum Acceptor Effects"
    ),
    row=3, col=1
)

fig.add_trace(
    go.Scatter(
        x=x_pred,
        y=effects_sum_donor,
        line=dict(color=color_donor, width=2),
        name="Sum Donor Effects"
    ),
    row=3, col=2
)

# Add exon rectangles to Track 3
add_exon_rectangle(fig, acceptor_idx, donor_idx, 3, 1,
                   rect_y_min_track3_acceptor, rect_y_max_track3_acceptor, color_exon)
add_exon_rectangle(fig, acceptor_idx, donor_idx, 3, 2,
                   rect_y_min_track3_donor, rect_y_max_track3_donor, color_exon)

# ------------------------------
# Track 4: Predicted signals + true sites
# ------------------------------
fig.add_trace(
    go.Scatter(
        x=x_pred,
        y=pred[0, 1, :].cpu().numpy(),
        line=dict(color=color_acceptor, width=2),
        name="Predicted Acceptor"
    ),
    row=4, col=1
)

fig.add_trace(
    go.Scatter(
        x=x_pred,
        y=pred[0, 0, :].cpu().numpy(),
        line=dict(color=color_donor, width=2),
        name="Predicted Donor"
    ),
    row=4, col=2
)

# Add exon rectangles to Track 4
add_exon_rectangle(fig, acceptor_idx, donor_idx, 4, 1,
                   rect_y_min_track4_acceptor, rect_y_max_track4_acceptor, color_exon)
add_exon_rectangle(fig, acceptor_idx, donor_idx, 4, 2,
                   rect_y_min_track4_donor, rect_y_max_track4_donor, color_exon)

# Synchronize all x-axes: without row/col, update_xaxes applies to every
# subplot's x-axis, so zooming one panel zooms them all.
fig.update_xaxes(
    matches='x',
    uirevision='constant',  # keep UI state (zoom/pan) during updates
)

# Overall size, title and legend
fig.update_layout(
    height=1200,
    width=1600,
    title_text="Splice Site Analysis: Donor vs Acceptor",
    showlegend=True,
    legend=dict(
        title="Motifs",
        orientation="v",
        uirevision='constant',  # keep legend UI state during updates
        x=1.05,  # just outside the right edge of the plotting area
        y=1,
        bgcolor="rgba(255,255,255,0.8)",
        bordercolor="rgba(0,0,0,0.2)",
        borderwidth=1
    )
)

This cell sets up the interactive Dash application. It defines the layout with the Plotly graph and a tooltip element. The callback function `display_hover` is triggered when the user hovers over the graph. If the hovered element is a motif trace (identified by the `customdata`), it retrieves the motif index, generates an HTML div containing the motif logo image (served from the `assets` directory), the position, and the value at the hovered point, and displays this content in the tooltip. The cell also copies the generated motif logo images to an `assets` directory so Dash can serve them. Finally, it runs the Dash application.

In [None]:
# ------------------------------
# Dash app: the figure plus a tooltip that the hover callback fills in
# ------------------------------
app = Dash(__name__)

# clear_on_unhover=True resets hoverData when the pointer leaves a point,
# which lets the callback hide the tooltip again.
splice_graph = dcc.Graph(id="splice-analysis-graph", figure=fig, clear_on_unhover=True)
hover_tooltip = dcc.Tooltip(id="graph-tooltip")

app.layout = html.Div(children=[splice_graph, hover_tooltip])

@app.callback(
    Output("graph-tooltip", "show"),
    Output("graph-tooltip", "bbox"),
    Output("graph-tooltip", "children"),
    Input("splice-analysis-graph", "hoverData"),
)
def display_hover(hoverData):
    """Show a tooltip with the motif logo for the hovered motif trace.

    Parameters:
    - hoverData: the graph's ``hoverData`` property; None when nothing is
      hovered.

    Returns:
    - ``(show, bbox, children)`` for the tooltip. ``show`` is False (and the
      other two are left untouched) when nothing is hovered, when the hovered
      trace carries no motif index in ``customdata``, or when that index does
      not correspond to a generated logo.
    """
    hidden = (False, no_update, no_update)

    if hoverData is None:
        return hidden

    # Only the first hovered point is used
    pt = hoverData["points"][0]
    bbox = pt["bbox"]

    # Motif traces store their motif index in customdata; other traces do not
    motif_index = pt.get('customdata')
    if motif_index is None:
        return hidden

    # Reject indices without a logo file (including negative ones, which would
    # otherwise build a URL for a file that does not exist)
    if not 0 <= motif_index < len(motif_logo_files):
        return hidden

    # Look up the hovered trace's name to tell the three motif tracks apart
    curve_number = pt['curveNumber']
    trace_name = (fig.data[curve_number].name if curve_number < len(fig.data) else '') or ''

    if 'Acceptor Effect' in trace_name:
        track_label = "Acceptor Effect"
    elif 'Donor Effect' in trace_name:
        track_label = "Donor Effect"
    else:
        track_label = "Activation"  # activation track, and the fallback
    hover_text = f"Motif {motif_index} - {track_label}"

    # Position and value; format the value only when it is numeric so a
    # missing 'y' cannot crash the callback with a format error
    x_pos = pt.get('x', 'N/A')
    y_val = pt.get('y')
    value_text = f"{y_val:.3f}" if isinstance(y_val, (int, float)) else 'N/A'

    # Tooltip content. Dash style dictionaries use camelCased keys.
    children = [
        html.Div([
            html.H3(f"{hover_text}", style={"color": "darkblue", "marginBottom": "5px"}),
            html.Img(
                # get_asset_url honours any URL prefix the app is served under
                src=app.get_asset_url(f"motif_{motif_index}.png"),
                style={"width": "400px", "height": "150px", "border": "1px solid #ccc"}
            ),
            html.P(f"Position: {x_pos}", style={"margin": "2px 0"}),
            html.P(f"Value: {value_text}", style={"margin": "2px 0"}),
        ], style={
            'width': '420px',
            'whiteSpace': 'normal',
            'background': 'white',
            'border': '1px solid #ccc',
            'padding': '10px',
            'borderRadius': '5px'
        })
    ]

    return True, bbox, children
# Copy motif logo files to assets directory for Dash to serve
import shutil  # imported once here rather than on every loop iteration

# exist_ok=True makes this safe to re-run and avoids a check-then-create race
os.makedirs('assets', exist_ok=True)

# The callback requests "motif_<index>.png", so the copies are named by index
for i, logo_file in enumerate(motif_logo_files):
    shutil.copy2(logo_file, f'assets/motif_{i}.png')

# Start the server on port 8050. No host= is passed, so Dash's default host is
# used (this does not bind to all interfaces).
# NOTE(review): debug=True turns on Dash's dev tools; keep it off if the app
# is ever exposed beyond the local machine.
app.run(debug=True, port=8050)

<IPython.core.display.Javascript object>