# Constructing TRIM input from raw RPL8A sequence

In [None]:
import pandas as pd
import os

# ================= Configuration =================

# Input / output paths, resolved relative to the working directory.
# INPUT_FILE is the raw Excel sheet read by generate_model_input();
# OUTPUT_FILE is the CSV that function writes.
INPUT_FILE = 'RPL8A_raw.xlsx'
OUTPUT_FILE = 'RPL8A_input_5U.csv'
# Sequences
# Constant part of the 5'UTR that is prepended to every variant.
# Original note: RPL8A UTR + start 7bp in WT 5'UTR.
# The name says 140 bp; each variant is expected to add 10 bp of mutated
# positions after it (generate_model_input reports len(constant) + 10 as the
# expected total).
# NOTE(review): the 140 bp length is not checked anywhere in this cell —
# confirm with len(CONST_5UTR_140BP).
CONST_5UTR_140BP = "CCGACGCAAACAAATTGGAAAAACCAACGCAAAAAAAAAAAGACGCTAAATTGTTTATAAAGGCGAGGAATTTGTATCTATCAATTACTATTCCAGTTGTCAGTTTACATTGCTTACCCTCTATTATCACATCAAAACAA" 
# Start of the YFP reporter CDS; written unchanged into every output row.
# NOTE(review): despite the "50BP" name this literal appears to be 51 nt
# (17 codons, ATG ... GTC) — confirm which length the model expects.
CONST_CDS_50BP = "ATGGTGAGCAAGGGCGAGGAGCTGTTCACCGGGGTGGTGCCCATCCTGGTC"

# ===========================================

def generate_model_input():
    """Build the model input CSV from the raw variant spreadsheet.

    Reads ``INPUT_FILE`` (Excel), which must contain the columns
    ``sequence_variant`` and ``mean_protein_abundance``, and writes
    ``OUTPUT_FILE`` with one row per usable variant:

    * ``tx_id``             -- ``RPL8A_<row index in the raw sheet>``
    * ``utr5_sequence``     -- ``CONST_5UTR_140BP`` + whitespace-stripped variant
    * ``cds_sequence``      -- ``CONST_CDS_50BP`` (identical for every row)
    * ``split``             -- always ``'test'``
    * ``SystematicName``    -- always ``'RPL8A'``
    * ``Protein_abundance`` -- ``mean_protein_abundance`` from the raw sheet

    Rows without a ``sequence_variant`` are skipped (with a warning) instead
    of being written with the text ``"nan"`` appended to the UTR.

    Returns:
        None. Problems (missing file, unreadable workbook, missing columns,
        no usable rows) are reported with ``print`` and nothing is written.
    """
    # Check input file
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found.")
        return

    print(f"Loading {INPUT_FILE} ...")
    try:
        df_raw = pd.read_excel(INPUT_FILE)
    except Exception as e:
        # Any reader failure (bad format, missing engine, ...) ends the run
        # with a message rather than a traceback.
        print(f"Error: Excel loading failed {e}")
        return

    # Check necessary columns
    required_columns = ['sequence_variant', 'mean_protein_abundance']
    missing_columns = [col for col in required_columns if col not in df_raw.columns]
    if missing_columns:
        print(
            f"Error: Missing required columns in Excel file: {missing_columns}. "
            f"Needing: {required_columns}"
        )
        return

    # A missing variant would otherwise be stringified to "nan" and glued onto
    # the constant UTR, producing a bogus sequence.
    has_variant = df_raw['sequence_variant'].notna()
    skipped = int((~has_variant).sum())
    if skipped:
        print(f"Warning: skipping {skipped} rows without a sequence_variant.")
        df_raw = df_raw[has_variant]

    if df_raw.empty:
        print(f"Error: No usable rows in {INPUT_FILE}.")
        return

    print(f"Successfully loaded {len(df_raw)} data, constructing model input...")

    # utr5_sequence: constant part + per-variant mutated tail
    utr_sequences = [
        CONST_5UTR_140BP + str(variant).strip()
        for variant in df_raw['sequence_variant']
    ]

    # Plain lists/arrays are used so the columns line up by position; tx_id
    # still carries the original sheet index so IDs stay stable when rows are
    # skipped.
    df_output = pd.DataFrame({
        'tx_id': [f"RPL8A_{i}" for i in df_raw.index],
        'utr5_sequence': utr_sequences,
        'cds_sequence': CONST_CDS_50BP,
        'split': 'test',
        'SystematicName': 'RPL8A',
        'Protein_abundance': df_raw['mean_protein_abundance'].to_numpy(),
    })

    expected_len = len(CONST_5UTR_140BP) + 10
    sample_utr_len = len(utr_sequences[0])
    print(f"5'UTR length generated: {sample_utr_len} bp (Expecting: {expected_len} bp)")
    # The sample above only covers the first row; report how many rows deviate.
    off_length = sum(len(seq) != expected_len for seq in utr_sequences)
    if off_length:
        print(f"Warning: {off_length} of {len(utr_sequences)} sequences are not {expected_len} bp long.")

    df_output.to_csv(OUTPUT_FILE, index=False)
    print(f"Finishing processing! File saved to: {OUTPUT_FILE}")

if __name__ == "__main__":
    generate_model_input()

# Extracting data with top/bottom TE

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os

# ================= Configuration =================
# 1. Input CSV: the model input table with a prediction column added.
INPUT_FILE = 'RPL8A_input_5U.with_pred.csv' 

# 2. Column configuration: measured value and model prediction that are
#    compared against each other.
COL_MEASURED = 'Protein_abundance'
COL_PREDICTED = 'pred_TE_5U'

# 3. Quantile thresholds (bottom 10% / top 10%); applied separately to the
#    measured and to the predicted column.
PERCENTILE_LOW = 0.10
PERCENTILE_HIGH = 0.90

# 4. Prefix of the two output files
#    (<prefix>_High_Consistent.csv and <prefix>_Low_Consistent.csv).
OUTPUT_PREFIX = 'Extracted_Data'
# ===========================================

def extract_consistent_extremes():
    """Save the rows on which measurement and prediction agree at the extremes.

    Loads ``INPUT_FILE``, drops rows lacking either ``COL_MEASURED`` or
    ``COL_PREDICTED``, computes the ``PERCENTILE_LOW`` / ``PERCENTILE_HIGH``
    quantiles of each column independently, and writes two CSV files:

    * ``<OUTPUT_PREFIX>_High_Consistent.csv`` -- rows at or above the high
      quantile in *both* columns;
    * ``<OUTPUT_PREFIX>_Low_Consistent.csv``  -- rows at or below the low
      quantile in *both* columns.

    Returns:
        None. A missing file or missing column is reported with ``print``
        and nothing is written.
    """
    # Load data
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found.")
        return

    df = pd.read_csv(INPUT_FILE)
    print(f"{len(df)} successfully loaded")

    # Without this check a missing column surfaces as a KeyError traceback
    # from dropna() below.
    missing_columns = [c for c in (COL_MEASURED, COL_PREDICTED) if c not in df.columns]
    if missing_columns:
        print(f"Error: Missing required columns in {INPUT_FILE}: {missing_columns}")
        return

    df = df.dropna(subset=[COL_MEASURED, COL_PREDICTED])
    if df.empty:
        # Quantiles of an empty column are NaN, so both selections come out
        # empty; the (header-only) files are still written as before.
        print("Warning: No rows with both measured and predicted values; output files will be empty.")

    # Quantile thresholds of the measured values
    meas_low_thresh = df[COL_MEASURED].quantile(PERCENTILE_LOW)
    meas_high_thresh = df[COL_MEASURED].quantile(PERCENTILE_HIGH)

    # Quantile thresholds of the predictions
    pred_low_thresh = df[COL_PREDICTED].quantile(PERCENTILE_LOW)
    pred_high_thresh = df[COL_PREDICTED].quantile(PERCENTILE_HIGH)

    # Condition A: high in both measurement and prediction
    is_high = (df[COL_MEASURED] >= meas_high_thresh) & (df[COL_PREDICTED] >= pred_high_thresh)
    df_high_consistent = df[is_high].copy()
    # Condition B: low in both measurement and prediction
    is_low = (df[COL_MEASURED] <= meas_low_thresh) & (df[COL_PREDICTED] <= pred_low_thresh)
    df_low_consistent = df[is_low].copy()

    # Save files
    file_high = f"{OUTPUT_PREFIX}_High_Consistent.csv"
    file_low = f"{OUTPUT_PREFIX}_Low_Consistent.csv"
    df_high_consistent.to_csv(file_high, index=False)
    df_low_consistent.to_csv(file_low, index=False)
    print("Extraction finished!")
    print(f"  {len(df_high_consistent)} rows -> {file_high}")
    print(f"  {len(df_low_consistent)} rows -> {file_low}")

if __name__ == "__main__":
    extract_consistent_extremes()

# Using top/bottom 10% data for heatmap

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import os

# ================= Configuration =================
# Full prediction table. get_wt_info() looks up the wild-type row (tx_id ==
# WT_ID) in this file and raises ValueError if there is none.
# NOTE(review): the generator in this file that writes RPL8A_input_5U.csv
# names its rows RPL8A_<index>, so the RPL8A_WT row is presumably added to
# this file elsewhere — confirm it is present.
FULL_DATA_FILE = 'RPL8A_input_5U.with_pred.csv' 

# Files to plot, mapped to their figure titles. plot_heatmap() picks the
# colormap from the words "High" / "Low" in the title and derives the output
# file name from the title and the file name.
TARGET_FILES = {
    'Extracted_Data_High_Consistent.csv': 'Mutation Impact: High TE Group (Top 10%)',
    'Extracted_Data_Low_Consistent.csv':  'Mutation Impact: Low TE Group (Bottom 10%)'
}

# Column configuration
COL_ID = 'tx_id'            # row identifier
COL_SEQ = 'utr5_sequence'   # sequence compared position by position against WT
COL_TE = 'pred_TE_5U'       # predicted value used for the delta
WT_ID = 'RPL8A_WT'          # identifier of the wild-type row
AUTO_CROP = True            # drop heatmap columns (positions) with no data at all
# ===========================================

def get_wt_info(full_file):
    """Look up the wild-type sequence and its predicted TE.

    Args:
        full_file: Path of a CSV containing the columns ``COL_ID``,
            ``COL_SEQ`` and ``COL_TE``.

    Returns:
        ``(sequence, te)``: the whitespace-stripped sequence (``str``) and the
        TE (``float``) of the first row whose ``COL_ID`` equals ``WT_ID``.

    Raises:
        FileNotFoundError: *full_file* does not exist.
        ValueError: no row with ``COL_ID == WT_ID`` is present.
    """
    if os.path.exists(full_file):
        table = pd.read_csv(full_file)
        wt_rows = table.loc[table[COL_ID] == WT_ID]

        if wt_rows.shape[0] == 0:
            raise ValueError(f"Sequence of {WT_ID} not found in {full_file}")

        # Only the first match is used if the ID occurs more than once.
        first_hit = wt_rows.iloc[0]
        sequence = str(first_hit[COL_SEQ]).strip()
        te_value = float(first_hit[COL_TE])
        return sequence, te_value

    raise FileNotFoundError(f"Sequence files not found: {full_file}")

def plot_heatmap(file_path, title, wt_seq, wt_te):
    """Plot the mean predicted-TE change per (position, substituted base).

    Every row of *file_path* whose ``COL_SEQ`` has the same length as
    *wt_seq* is compared with *wt_seq* position by position. For each
    position where the row carries a different base (one of A/C/G/T), the
    row's ``COL_TE - wt_te`` is attributed to that (position, base) cell; the
    heatmap shows the mean per cell. The figure is saved as a PNG in the
    working directory.

    Args:
        file_path: CSV with the columns ``COL_SEQ`` and ``COL_TE``.
        title: Group title. The text before ``":"`` becomes part of the file
            name, and the words ``"High"`` / ``"Low"`` select the colormap.
        wt_seq: Wild-type sequence used as the reference.
        wt_te: Wild-type TE used as the baseline of the delta.

    Returns:
        None. A missing file, or a file containing no substitution relative
        to the wild type, is reported with ``print`` and skipped.
    """
    if not os.path.exists(file_path):
        print(f"Jumped: File {file_path} not found")
        return

    df = pd.read_csv(file_path)
    seq_len = len(wt_seq)
    bases = ['A', 'C', 'G', 'T']
    # deltas[pos][base] -> delta TE of every row carrying `base` at `pos`
    deltas = {pos: {b: [] for b in bases} for pos in range(seq_len)}

    # Calculating delta TE
    for _, row in df.iterrows():
        curr_seq = str(row[COL_SEQ]).strip()
        curr_te = float(row[COL_TE])

        # Length-mismatched rows cannot be aligned position by position, and a
        # missing prediction would turn every mean it touches into NaN.
        if len(curr_seq) != seq_len or np.isnan(curr_te):
            continue

        delta_te = curr_te - wt_te
        for pos, (wt_base, base) in enumerate(zip(wt_seq, curr_seq)):
            if base != wt_base and base in bases:
                deltas[pos][base].append(delta_te)

    # Rows = substituted base, columns = 1-based position
    heatmap_df = pd.DataFrame(
        {
            pos + 1: [np.mean(deltas[pos][b]) if deltas[pos][b] else np.nan for b in bases]
            for pos in range(seq_len)
        },
        index=bases,
        dtype=float,
    )

    if AUTO_CROP:
        heatmap_df = heatmap_df.dropna(axis=1, how='all')
        print(f"  Cropped to {heatmap_df.shape[1]}bp active area.")

    # np.nanmin / np.nanmax fail on an empty array and are meaningless on an
    # all-NaN one, so stop before trying to scale the colormap.
    if heatmap_df.empty or not heatmap_df.notna().to_numpy().any():
        print(f"Jumped: No substitutions relative to WT found in {file_path}")
        return

    curr_min = np.nanmin(heatmap_df.values)
    curr_max = np.nanmax(heatmap_df.values)
    print(f"  Dynamic range: [{curr_min:.3f}, +{curr_max:.3f}]")

    if "High" in title:
        # Sequential map: min (white) -> max (red)
        use_cmap = "Reds"
    elif "Low" in title:
        # Reversed sequential map: min (dark blue) -> max (white)
        use_cmap = "Blues_r"
    else:
        use_cmap = "vlag"

    # Plot
    plt.figure(figsize=(12, 3))
    sns.set_theme(style="white")
    sns.heatmap(heatmap_df,
                cmap=use_cmap,
                vmin=curr_min,
                vmax=curr_max,
                annot=False,
                fmt=".2f",
                linewidths=1,
                linecolor='white',
                # raw string: "\D" is not a valid escape in a normal literal
                cbar_kws={'label': r'Mean $\Delta$TE'},
                square=True)

    plt.xlabel("Mutation Position", fontsize=16, fontweight="bold")
    plt.ylabel("Nucleotide", fontsize=16, fontweight="bold")
    plt.yticks(rotation=0)

    # File-name tag: third "_"-separated token of the file's base name
    # ("High" / "Low" for the default inputs). Working on the base name keeps
    # directory components out of it; the whole stem is the fallback when the
    # name has fewer than three tokens.
    stem = os.path.splitext(os.path.basename(file_path))[0]
    stem_parts = stem.split("_")
    group_tag = stem_parts[2] if len(stem_parts) > 2 else stem
    safe_name = title.split(":")[0].replace(" ", "_") + "_" + group_tag + ".png"
    plt.tight_layout()
    plt.savefig(safe_name, dpi=300, bbox_inches='tight')
    print(f"  Figure saved to {safe_name}")
    print("-" * 30)

def main():
    """Load the wild-type reference, then draw one heatmap per target file.

    Any exception raised while fetching the wild-type row is printed as
    ``Error: ...`` and ends the run without plotting.
    """
    divider = "-" * 30
    try:
        wt_seq, wt_te = get_wt_info(FULL_DATA_FILE)
        print(f"WT TE: {wt_te:.4f}")
        print(divider)
    except Exception as e:
        print(f"Error: {e}")
        return

    # Dict iteration follows insertion order, i.e. the order of TARGET_FILES.
    for fpath in TARGET_FILES:
        plot_heatmap(fpath, TARGET_FILES[fpath], wt_seq, wt_te)

if __name__ == "__main__":
    main()

# Generating TRIM input with double mutation

In [None]:
import pandas as pd

# Wild-type reference construct
wt_utr = "CCGACGCAAACAAATTGGAAAAACCAACGCAAAAAAAAAAAGACGCTAAATTGTTTATAAAGGCGAGGAATTTGTATCTATCAATTACTATTCCAGTTGTCAGTTTACATTGCTTACCCTCTATTATCACATCAAAACAACTAATTCGAA"
wt_cds = "ATGGTGAGCAAGGGCGAGGAGCTGTTCACCGGGGTGGTGCCCATCCTGGTC"
wt_name = "RPL8A"
wt_split = "test"

print(f"Full length of 5'UTR: {len(wt_utr)} bp")

# Mutation window in 1-based ("bio") coordinates, both ends inclusive: 141..150
Target_Start_Bio = 141
Target_End_Bio = 150

# Convert to a 0-based, half-open Python range [idx_start, idx_end):
#   bio 141 -> index 140 (first mutated position)
#   bio 150 -> index 150 used as the exclusive end, clamped to the UTR length
idx_start = Target_Start_Bio - 1
idx_end = min(Target_End_Bio, len(wt_utr))


print(f"Mutation Window (Python Index): {idx_start} - {idx_end}")
print(f"Raw Sequence in mutation window: {wt_utr[idx_start:idx_end]}")
print("-" * 30)


bases = ['A', 'C', 'G', 'T']
window_indices = range(idx_start, idx_end)


def _input_row(tx_id, utr_sequence):
    """Return one model-input row; everything except ID and UTR is constant."""
    return {
        'tx_id': tx_id,
        'utr5_sequence': utr_sequence,
        'cds_sequence': wt_cds,
        'split': wt_split,
        'SystematicName': wt_name
    }


def _double_mutant_row(i, j, b1, b2):
    """Row with base b1 at index i and b2 at index j (requires i < j)."""
    mutated_utr = wt_utr[:i] + b1 + wt_utr[i + 1:j] + b2 + wt_utr[j + 1:]
    # IDs use 1-based positions: RPL8A_m<pos1><base1>_m<pos2><base2>
    return _input_row(f"{wt_name}_m{i + 1}{b1}_m{j + 1}{b2}", mutated_utr)


# The wild type comes first, followed by every position pair (i < j) in the
# window combined with every pair of bases.
# NOTE(review): all 16 base combinations are emitted per position pair,
# including those where b1 and/or b2 equals the wild-type base, so some rows
# labelled as double mutants are single mutants or the wild-type sequence
# itself — confirm this is intended.
generated_rows = [_input_row('RPL8A_WT', wt_utr)]
generated_rows.extend(
    _double_mutant_row(i, j, b1, b2)
    for i in window_indices
    for j in range(i + 1, idx_end)
    for b1 in bases
    for b2 in bases
)

# Number of generated mutant rows (the wild-type row is not counted)
count = len(generated_rows) - 1

df_input = pd.DataFrame(generated_rows)

output_file = "RPL8A_double_mutations_input.csv"
df_input.to_csv(output_file, index=False)

print(f"{len(df_input)} Data generated。")
print(f"File saved: {output_file}")

# Plot Lollipop Figure

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import re

# ================= Configuration =================
# Prediction table for the double-mutation scan (must contain tx_id and SCORE_COL).
FILE_PATH = 'RPL8A_double_mutations_input.with_pred.csv'
# Column holding the predicted value that is plotted on the y axis.
SCORE_COL = 'pred_TE_5U'
# y value of the dashed "WT" reference line and the base of every stem.
# If set to None, the code below replaces it with the mean SCORE_COL of the
# parsed double-mutant rows.
# NOTE(review): 0.1 is hard-coded — presumably the predicted TE of the
# wild type; confirm against the RPL8A_WT row.
WT_BASELINE = 0.1
# The fixed mutation (1-based position + substituted base); every plotted
# point is a double mutant that contains it.
TARGET_POS = 149
TARGET_ALT = 'C'
sns.set(style="whitegrid", font_scale=1.1)

# ================= Data loading =================
df = pd.read_csv(FILE_PATH)

# Double-mutant IDs look like <name>_m<pos1><alt1>_m<pos2><alt2>, with 1-based
# positions and A/C/G/T substituted bases.
pattern = re.compile(r'_m(\d+)([ACGT])_m(\d+)([ACGT])')
def parse_id(tx_id):
    """Return ``(pos1, alt1, pos2, alt2)`` parsed from *tx_id*.

    IDs that do not match the double-mutant pattern (e.g. the wild-type row)
    and non-string values (e.g. NaN) give ``(None, None, None, None)``.
    """
    match = pattern.search(tx_id) if isinstance(tx_id, str) else None
    if match is None:
        return (None, None, None, None)
    return (int(match.group(1)), match.group(2), int(match.group(3)), match.group(4))

# The columns are filled one by one instead of unpacking zip(*...), which
# raises ValueError when the table has no rows.
parsed_ids = [parse_id(tx_id) for tx_id in df['tx_id']]
for field_idx, field_name in enumerate(('Pos1', 'Alt1', 'Pos2', 'Alt2')):
    df[field_name] = [fields[field_idx] for fields in parsed_ids]
# Keep only the rows that parsed as double mutants
df = df.dropna(subset=['Pos1']).copy()

# Ensure the position columns are integers
df['Pos1'] = df['Pos1'].astype(int)
df['Pos2'] = df['Pos2'].astype(int)

# NOTE(review): this fallback is the mean over the parsed double mutants, not
# the prediction of the wild-type row, although the plot labels the baseline
# "WT" — confirm that is intended.
if WT_BASELINE is None:
    WT_BASELINE = df[SCORE_COL].mean()

# Lollipop plot data: every double mutant that contains the target mutation
cond1 = (df['Pos1'] == TARGET_POS) & (df['Alt1'] == TARGET_ALT)
cond2 = (df['Pos2'] == TARGET_POS) & (df['Alt2'] == TARGET_ALT)

df_plot = df[cond1 | cond2].copy()

# Identify the partner mutation (the one that is not at TARGET_POS)
def get_partner(row):
    """Return ``(position, base)`` of the mutation paired with the target."""
    if row['Pos1'] == TARGET_POS:
        return row['Pos2'], row['Alt2']
    else:
        return row['Pos1'], row['Alt1']

# Built from lists so that an empty selection yields empty columns rather
# than a failing tuple unpack.
partners = [get_partner(row) for _, row in df_plot.iterrows()]
df_plot['Partner_Pos'] = [partner_pos for partner_pos, _ in partners]
df_plot['Partner_Base'] = [partner_base for _, partner_base in partners]

if df_plot.empty:
    print(f"Warning: no double mutant containing {TARGET_POS}{TARGET_ALT} found in {FILE_PATH}")

# Lollipop chart: one stem + marker per partner mutation, drawn on a single
# explicit Axes.
fig, ax = plt.subplots(figsize=(7, 3))
colors = {'A': '#d62728', 'C': '#1f77b4', 'G': "#fbff0e", 'T': '#2ca02c'}

# Stems run from the baseline to each predicted value
ax.vlines(x=df_plot['Partner_Pos'], ymin=WT_BASELINE, ymax=df_plot[SCORE_COL],
          color='grey', alpha=0.4, linewidth=1)
# Markers, coloured by the partner's substituted base
sns.scatterplot(
    data=df_plot, x='Partner_Pos', y=SCORE_COL, hue='Partner_Base',
    palette=colors, s=150, zorder=5, edgecolor='k', alpha=0.7, ax=ax
)
# Dashed reference line at the baseline
ax.axhline(y=WT_BASELINE, color='black', linestyle='--', alpha=0.6, label='WT')
ax.set_xlabel(f'Position of Partner Mutation (with {TARGET_POS}{TARGET_ALT})', fontsize=16)
ax.set_ylabel('Predicted TE', fontsize=16)

# One tick per position, padded by one position on either side
min_pos = df_plot['Partner_Pos'].min()
max_pos = df_plot['Partner_Pos'].max()
if pd.notna(min_pos) and pd.notna(max_pos):
    ax.set_xticks(range(int(min_pos) - 1, int(max_pos) + 2))
ax.tick_params(axis='x', bottom=True, length=5, width=1, color='black')
# Legend placed outside the plotting area, to the right
ax.legend(title='Partner Base', loc='upper left', bbox_to_anchor=(1, 1))
fig.tight_layout()
filename = f'Lollipop_Partner_Screening_{TARGET_POS}{TARGET_ALT}.png'
fig.savefig(filename, dpi=300)

print(f"Lollipop Figure saved as {filename}")

# Double mutation scan and plot

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import re

# ================= Configuration =================
# Prediction table for the double-mutation scan (must contain tx_id and SCORE_COL).
FILE_PATH = 'RPL8A_double_mutations_input.with_pred.csv'
# Column holding the predicted value shown in the heatmap.
SCORE_COL = 'pred_TE_5U'
# NOTE(review): WT_BASELINE is defined here but not referenced by the heatmap
# code that follows in this cell.
WT_BASELINE = 0.1

sns.set(style="white", font_scale=1.1)

# ================= Data extracting and analysis =================
df = pd.read_csv(FILE_PATH)

# Double-mutant IDs look like <name>_m<pos1><alt1>_m<pos2><alt2>, with 1-based
# positions and A/C/G/T substituted bases.
pattern = re.compile(r'_m(\d+)([ACGT])_m(\d+)([ACGT])')
def parse_id(tx_id):
    """Return ``(pos1, alt1, pos2, alt2)`` parsed from *tx_id*.

    IDs that do not match the double-mutant pattern (e.g. the wild-type row)
    and non-string values (e.g. NaN) give ``(None, None, None, None)``.
    """
    match = pattern.search(tx_id) if isinstance(tx_id, str) else None
    if match is None:
        return (None, None, None, None)
    return (int(match.group(1)), match.group(2), int(match.group(3)), match.group(4))

# The columns are filled one by one instead of unpacking zip(*...), which
# raises ValueError when the table has no rows.
parsed_ids = [parse_id(tx_id) for tx_id in df['tx_id']]
for field_idx, field_name in enumerate(('Pos1', 'Alt1', 'Pos2', 'Alt2')):
    df[field_name] = [fields[field_idx] for fields in parsed_ids]
# Keep only the rows that parsed as double mutants
df = df.dropna(subset=['Pos1']).copy()
df['Pos1'] = df['Pos1'].astype(int)
df['Pos2'] = df['Pos2'].astype(int)

# 1-based positions 141..150; the matrix starts out entirely NaN.
all_positions = range(141, 151)
heatmap_matrix = pd.DataFrame(index=all_positions, columns=all_positions, dtype=float)

def get_combined_score(p_a, p_b):
    """Return the highest SCORE_COL among double mutants at positions p_a and p_b.

    The pair is matched in either order (Pos1/Pos2 or Pos2/Pos1). Returns NaN
    when no row has that pair of positions.
    """
    cond1 = (df['Pos1'] == p_a) & (df['Pos2'] == p_b)
    cond2 = (df['Pos1'] == p_b) & (df['Pos2'] == p_a)

    matches = df[cond1 | cond2]
    if matches.empty:
        return np.nan
    return matches[SCORE_COL].max()

# Fill the lower triangle (diagonal included); the upper triangle keeps the
# NaN it was initialised with.
for y_pos in all_positions:
    for x_pos in all_positions:
        if y_pos >= x_pos:
            heatmap_matrix.at[y_pos, x_pos] = get_combined_score(y_pos, x_pos)

# ================= Plot =================
fig, ax = plt.subplots(figsize=(8, 6))
# Hide everything strictly above the diagonal (those cells were never filled)
mask = np.triu(np.ones(heatmap_matrix.shape, dtype=bool), k=1)
colorbar_options = {'label': 'Max Predicted TE', 'shrink': 0.8, 'pad': 0.0}
sns.heatmap(
    heatmap_matrix,
    mask=mask,
    cmap='coolwarm',
    annot=True,
    fmt=".2f",
    linewidths=1,
    linecolor='white',
    cbar_kws=colorbar_options,
    ax=ax
)
# Flip the row order and move the x ticks and x label to the top edge
ax.invert_yaxis()
ax.xaxis.tick_top()
ax.xaxis.set_label_position('top')
plt.yticks(rotation=0)
ax.set_ylabel('Mutation Position 1 (Y)', fontsize=16, labelpad=10)
ax.set_xlabel('Mutation Position 2 (X)', fontsize=16, labelpad=10)
# Keep the tick labels but drop the tick marks themselves
ax.tick_params(left=False, top=False)

fig.tight_layout()
fig.savefig('Global_Interaction_Heatmap.png', dpi=300)
print("Fig saved.")

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.lines import Line2D

# ================= Configuration =================
# Table with measured Protein_abundance and utr5_sequence per variant.
INPUT_FILE = 'RPL8A_input_5U.with_pred.csv'
# Wild-type 5'UTR used as the reference when counting mismatches.
WT_SEQUENCE = "CCGACGCAAACAAATTGGAAAAACCAACGCAAAAAAAAAAAGACGCTAAATTGTTTATAAAGGCGAGGAATTTGTATCTATCAATTACTATTCCAGTTGTCAGTTTACATTGCTTACCCTCTATTATCACATCAAAACAACTAATTCGAA"
# Maximum number of mismatches allowed outside positions 148/149; sequences
# with more than this are classified as 'Background'.
MAX_EXTRA_MUTATIONS = 5

# Marker / line colours of the three groups
COLOR_BG = '#b0b0b0'       # background
COLOR_SINGLE = '#1f77b4'   # single-mutation group
COLOR_DOUBLE = '#d62728'   # double-mutation group

# ================= Data Preprocessing =================
def count_extra_mutations(seq, wt_seq, ignore_indices):
    """Count positions where *seq* differs from *wt_seq*.

    Only the overlapping prefix (the length of the shorter sequence) is
    compared, and 0-based positions listed in *ignore_indices* are skipped.
    """
    return sum(
        1
        for idx, (observed, reference) in enumerate(zip(seq, wt_seq))
        if idx not in ignore_indices and observed != reference
    )

def classify_sequence(seq, wt_seq, max_extra_muts):
    """Classify *seq* as ``'Double'``, ``'Single'`` or ``'Background'``.

    * Sequences shorter than 150 characters are ``'Background'``.
    * Sequences with more than *max_extra_muts* mismatches to *wt_seq*
      outside positions 148 and 149 (1-based) are ``'Background'``.
    * ``'Double'``: A at position 148 and G at position 149.
    * ``'Single'``: C at position 149 with the wild-type base at 148.
    * Anything else is ``'Background'``.
    """
    if len(seq) < 150:
        return 'Background'

    # 0-based indices of the 1-based positions 148 and 149
    idx_148, idx_149 = 147, 148
    observed_pair = (seq[idx_148], seq[idx_149])
    wt_148 = wt_seq[idx_148]

    # The two positions of interest do not count towards the mismatch budget
    n_extra = count_extra_mutations(seq, wt_seq, ignore_indices={idx_148, idx_149})
    if n_extra > max_extra_muts:
        return 'Background'

    if observed_pair == ('A', 'G'):
        return 'Double'
    if observed_pair == (wt_148, 'C'):
        return 'Single'
    return 'Background'

def run_filtered_scatter_plot():
    """Plot measured protein abundance, highlighting two mutation groups.

    Reads ``INPUT_FILE``, keeps the rows with a numeric ``Protein_abundance``
    and a ``utr5_sequence``, and classifies each sequence with
    ``classify_sequence`` as ``'Single'`` (149-C), ``'Double'``
    (148-A + 149-G) or ``'Background'``. All points are drawn in one jittered
    strip, one layer per group, with a dashed line at each group's mean.
    The figure is saved as ``Scatter_Mutation.png``.

    Returns:
        None.
    """
    df = pd.read_csv(INPUT_FILE)
    # Non-numeric abundances become NaN and are dropped with the missing ones
    df['Protein_abundance'] = pd.to_numeric(df['Protein_abundance'], errors='coerce')
    df = df.dropna(subset=['Protein_abundance', 'utr5_sequence']).copy()

    print(f"Classifying... (Only extra mutation <= {MAX_EXTRA_MUTATIONS} sequences reserved)")
    df['Type'] = df['utr5_sequence'].apply(lambda x: classify_sequence(x, WT_SEQUENCE, MAX_EXTRA_MUTATIONS))
    # A single empty category puts every group into the same strip
    df['X_Axis'] = ''
    df_bg = df[df['Type'] == 'Background']
    df_single = df[df['Type'] == 'Single']
    df_double = df[df['Type'] == 'Double']

    print(f"Single Mutation number (Single): {len(df_single)} 个")
    print(f"Double Mutation number: {len(df_double)} 个")

    # Labels follow what classify_sequence selects: 'Single' is C at
    # position 149, 'Double' is A at 148 together with G at 149.
    single_label = '149-C'
    double_label = '148-A, 149-G'
    output_file = 'Scatter_Mutation.png'

    sns.set(style="ticks", font_scale=1.2)
    plt.figure(figsize=(5, 6))

    # --- Layer 1: background points ---
    if not df_bg.empty:
        sns.stripplot(
            x='X_Axis', y='Protein_abundance', data=df_bg,
            color=COLOR_BG, size=4, jitter=0.25, alpha=0.3, zorder=1
        )

    # --- Layer 2: single mutation ---
    if not df_single.empty:
        sns.stripplot(
            x='X_Axis', y='Protein_abundance', data=df_single,
            color=COLOR_SINGLE, size=8, jitter=0.25,
            edgecolor='white', linewidth=1, zorder=10
        )

    # --- Layer 3: double mutation ---
    if not df_double.empty:
        sns.stripplot(
            x='X_Axis', y='Protein_abundance', data=df_double,
            color=COLOR_DOUBLE, size=9, jitter=0.25, marker='o',
            edgecolor='white', linewidth=1, zorder=20
        )

    # Horizontal extent of the mean lines and x position of their labels
    LINE_XMIN = -0.4
    LINE_XMAX = 0.4
    TEXT_X_POS = 0.42

    # 1. Background mean (skipped for an empty group, whose mean is NaN)
    if not df_bg.empty:
        bg_mean = df_bg['Protein_abundance'].mean()
        plt.hlines(y=bg_mean, xmin=LINE_XMIN, xmax=LINE_XMAX, colors='gray', linestyles='--', alpha=0.6, linewidth=1.5, zorder=5)
        plt.text(x=TEXT_X_POS, y=bg_mean, s='BG AVG',
                 color='gray', va='center', ha='left', fontsize=10)

    # 2. Single-mutation mean; the label is nudged down to avoid overlaps
    if not df_single.empty:
        single_mean = df_single['Protein_abundance'].mean()
        plt.hlines(y=single_mean, xmin=LINE_XMIN, xmax=LINE_XMAX, colors=COLOR_SINGLE, linestyles='--', alpha=0.8, linewidth=1.5, zorder=15)
        plt.text(x=TEXT_X_POS, y=single_mean-0.1, s=f'{single_label} AVG',
                 color=COLOR_SINGLE, va='center', ha='left', fontsize=10, fontweight='bold')

    # 3. Double-mutation mean; the label is nudged up to avoid overlaps
    if not df_double.empty:
        double_mean = df_double['Protein_abundance'].mean()
        plt.hlines(y=double_mean, xmin=LINE_XMIN, xmax=LINE_XMAX, colors=COLOR_DOUBLE, linestyles='--', alpha=0.8, linewidth=1.5, zorder=25)
        plt.text(x=TEXT_X_POS, y=double_mean+0.05, s="Double AVG",
                 color=COLOR_DOUBLE, va='center', ha='left', fontsize=10, fontweight='bold')

    # Hand-built legend: the strip plots themselves carry no labels
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor=COLOR_SINGLE, label=single_label, markersize=10),
        Line2D([0], [0], marker='o', color='w', markerfacecolor=COLOR_DOUBLE, label=double_label, markersize=10)
    ]
    plt.legend(handles=legend_elements, loc='upper center', frameon=False, bbox_to_anchor=(0.3, 1.1), ncol=2)
    plt.ylabel('Protein Abundance', fontsize=14)
    plt.xlabel('')
    # Extra room on the right for the mean-line labels
    plt.xlim(-0.5, 1.3)
    sns.despine(bottom=True)
    plt.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False)
    plt.tight_layout()
    plt.savefig(output_file, dpi=300)
    # Report the name the figure was actually saved under
    print(f"Figure Saved: {output_file}")

if __name__ == "__main__":
    run_filtered_scatter_plot()