# Frequency Analysis
This notebook consolidates the loading, manual correction, filtering, annotation, and visualization of per-patient mutation-frequency data, and exports per-patient frequency tables and plots.

## Section 1: Imports

In [None]:
# 1. Imports & Dependencies
import pandas as pd
import numpy as np
import re
import os
import glob
from pathlib import Path
from Bio import SeqIO, Seq
from fpdf import FPDF
import plotly.graph_objects as go
import plotly.express as px
from tqdm import tqdm


## Section 2: Configurations

In [None]:
# 2. Paths
# Input files. The ".../.../..." prefixes look like redacted placeholders for
# site-specific directories — set them before running.
metadata_path = Path(".../.../.../chronic_metadata.csv")
base_dir = Path(".../.../.../mutation_analysis")
problematic_path = Path(".../.../.../Recurrent_mutations.tsv")
manual_changes_path = Path(".../.../.../mutation_manual.csv")
suspicious_mutations_path = Path(
    ".../.../.../suspected_muts_with_soft_filtering_trim_40.csv"
)
patients_tp_path = Path(".../.../.../patients_timepoints_to_analyze.csv")
# NOTE: "consenss" is a misspelling, kept so existing references keep working.
consenss_fasta_path = Path(".../.../.../wuhan_ref.fasta")

# Output folders.
plot_folder = Path(".../.../.../Frequency_Graphs")
freqs_folder = Path(".../.../.../Frequency_Tables")

# 3. ORF Coordinates & Protein Map
# (start, end) nucleotide positions, 1-based and inclusive, on the reference.
# Insertion order matters: position lookups take the first ORF that contains
# the position, and ORF7a / ORF7b overlap.
ORF_COORDS = {
    'ORF1a': (266, 13467),
    'ORF1b': (13468, 21556),
    'S': (21563, 25384),
    'ORF3a': (25393, 26220),
    'E': (26245, 26472),
    'M': (26523, 27191),
    'ORF6': (27202, 27387),
    'ORF7a': (27394, 27759),
    'ORF7b': (27756, 27887),
    'ORF8': (27894, 28259),
    'N': (28274, 29533),
    'ORF10': (29558, 29674),
}
# Protein product name -> short gene label.
protein_abbreviations = {
    'envelope protein': 'E',
    'membrane glycoprotein': 'M',
    'nucleocapsid phosphoprotein': 'N',
    'ORF10 protein': 'ORF10',
    'orf1ab polyprotein': 'ORF1ab',
    'ORF3a protein': 'ORF3a',
    'ORF6 protein': 'ORF6',
    'ORF7a protein': 'ORF7a',
    'ORF8 protein': 'ORF8',
    'surface glycoprotein': 'S',
}

# 4. Timepoint Adjustments
# Days added to every sample timepoint of a patient to convert it to absolute
# days; patients not listed here are shifted by 0.
days_to_add = {
    "N1": 18, "N2": 29, "N3": 60,
    "N4": 0, "N7": 41, "N8": 4,
    "P3": 30, "P4": 26, "P5": 5,
}

# 5. Plot Colors & Styles
# Colours handed out, in order, to high-frequency mutations (cycled if needed).
PREDEFINED_COLORS = [
    "#E69F00", "#56B4E9", "#009E73", "#F0E442", "#0072B2",
    "#D55E00", "#CC79A7", "#000000", "#CC6677", "#882255",
    "#44AA99", "#117733", "#332288", "#AA4499", "#88CCEE",
]

# Font sizes.
FONT_SIZE_MAIN = 45
FONT_SIZE_AXIS_TITLE = 40
FONT_SIZE_TICK = 30
FONT_SIZE_LEGEND = 30

# Trace styling.
MARKER_SIZE = 10
LINE_WIDTH = 4
OPACITY = 0.5               # opacity of low-frequency (non-legend) traces
LOW_FREQ_COLOR = "#A9A9A9"  # colour of low-frequency (non-legend) traces

# Figure geometry and colours.
PLOT_HEIGHT = 900
PLOT_WIDTH = 1200
GRID_COLOR = 'lightgrey'
TEXT_COLOR = 'black'


## Section 4: Load & Merge Mutation Files

In [None]:
# 6. Load and Merge
# Expected layout: <base_dir>/<patient>/<timepoint>/*merged.csv. The patient ID
# and the integer timepoint are taken from the two parent directory names.
# glob() returns files in arbitrary, filesystem-dependent order; sorting makes
# the concatenated table reproducible across runs and machines.
all_merged_files = sorted(glob.glob(str(base_dir / "*/*/*merged.csv")))
if not all_merged_files:
    # Fail here with a clear message rather than with a KeyError in the later
    # cells, which all index the patient/timepoint/mutation columns.
    raise FileNotFoundError(f"No '*merged.csv' files found under {base_dir}/*/*/")

df_list = []
for file_path in all_merged_files:
    parts   = Path(file_path).parts
    patient = parts[-3]
    tp      = int(parts[-2])
    df      = pd.read_csv(file_path)
    df.insert(0, "patient",     patient)
    df.insert(1, "timepoint",   tp)
    df_list.append(df)

combined_mutations_df = pd.concat(df_list, ignore_index=True)


## Section 5: Apply Manual Corrections

In [None]:
# 7. Apply Manual Changes
# Each manual entry that carries a `fixed_freq` either overrides `final_freq`
# on the matching (patient, timepoint, mutation) rows or, if there is no such
# row, is appended as a new row. Touched rows get manually_changed=True.
manual_df = pd.read_csv(manual_changes_path).rename(columns={'patient_id': 'patient'})
# Keep an existing flag column if the merged tables already carry one.
combined_mutations_df['manually_changed'] = combined_mutations_df.get('manually_changed', False)

for _, entry in manual_df.iterrows():
    hits = (
        (combined_mutations_df['patient'] == entry['patient'])
        & (combined_mutations_df['timepoint'] == entry['timepoint'])
        & (combined_mutations_df['mutation'] == entry['mutation'])
    )
    fixed = entry.get('fixed_freq')
    if pd.isna(fixed):
        continue  # no corrected frequency supplied for this entry

    if hits.any():
        combined_mutations_df.loc[hits, ['final_freq', 'manually_changed']] = [fixed, True]
        continue

    # No existing record: append one, leaving every other column as None.
    new_row = dict.fromkeys(combined_mutations_df.columns)
    new_row.update(
        patient=entry['patient'],
        timepoint=entry['timepoint'],
        mutation=entry['mutation'],
        final_freq=fixed,
        manually_changed=True,
    )
    combined_mutations_df = pd.concat(
        [combined_mutations_df, pd.DataFrame([new_row])],
        ignore_index=True,
    )

# Drop one known artefact
combined_mutations_df = combined_mutations_df[
    combined_mutations_df['mutation'] != 'T22204+GAGCCAGAA'
]


## Section 6: Filter Relevant Patients & Timepoints

In [None]:
# 8. Load whitelist and filter
# The whitelist lists, per patient, the timepoints to analyse as list-like
# strings (e.g. "[0, 14]"). Parse them into lists of ints, dropping any token
# that is not a non-negative integer.
patients_tp_df = pd.read_csv(patients_tp_path)
patients_tp_df['timepoints'] = (
    patients_tp_df['timepoints']
    .astype(str)
    .str.strip('[]')
    .str.split(',')
    .apply(lambda L: [int(x) for x in L if x.strip().isdigit()])
)
pt_dict = dict(zip(patients_tp_df['patient'], patients_tp_df['timepoints']))


def is_relevant(r):
    """Return True if row `r`'s (patient, timepoint) pair is whitelisted."""
    return r['patient'] in pt_dict and r['timepoint'] in pt_dict[r['patient']]


# Same test as `is_relevant`, but as one set lookup per row. This replaces a
# row-wise DataFrame.apply, which is slow on large tables and does not reliably
# return a boolean mask when the frame is empty.
allowed_pairs = {(pt, tp) for pt, tps in pt_dict.items() for tp in tps}
relevant_mask = pd.Series(
    [
        (pt, tp) in allowed_pairs
        for pt, tp in zip(combined_mutations_df['patient'], combined_mutations_df['timepoint'])
    ],
    index=combined_mutations_df.index,
    dtype=bool,
)
filtered_mutations_df = combined_mutations_df.loc[relevant_mask].copy()

# Remove mutations never >0 in any timepoint: keep only (patient, mutation)
# pairs that reach final_freq > 0 at least once.
present = (
    filtered_mutations_df
    .query("final_freq > 0")
    .groupby('patient')['mutation']
    .unique()
    .explode()
    .reset_index()
)
filtered_mutations_df = filtered_mutations_df.merge(
    present, on=['patient','mutation'], how='inner'
)


## Section 7: Load Reference & Define Annotation Functions


In [None]:
def load_wuhan_ref(path_to_fasta):
    """Read a single-record FASTA file and return its sequence upper-cased.

    SeqIO.read expects exactly one record in the file.
    """
    reference = SeqIO.read(path_to_fasta, "fasta")
    sequence_text = str(reference.seq)
    return sequence_text.upper()

# --- Deletion annotation using custom logic ---
def process_deletion(mutation: str, consensus_sequence: str, protein_dict: dict) -> str:
    """Rewrite an in-frame nucleotide deletion in protein (codon) notation.

    `mutation` is expected to look like ``"A21764-TACATG"``. The first number
    is taken as the anchor position, and the letters after ``-`` are taken as
    the deleted bases. A non-empty, in-frame deletion (length a multiple of 3)
    whose anchor lies inside an ORF of `protein_dict` becomes one of:

    - ``"<ORF>:Δ<first>"`` for a single deleted codon;
    - ``"<ORF>:Δ<first>-<last>"`` (inclusive, 1-based codon numbers within
      the ORF) for several deleted codons.

    Anything else is returned unchanged: no ``-``, a frameshift, an empty
    deletion, or an anchor outside every ORF.

    Args:
        mutation: Mutation string.
        consensus_sequence: Not used; kept for interface compatibility.
        protein_dict: ORF name -> (start, end), 1-based inclusive. The first
            ORF containing the anchor is used.

    Returns:
        The protein-level label, or `mutation` unchanged.
    """
    if "-" not in mutation:
        return mutation

    deletion_part = mutation.split("-")[1]
    num_deleted_nucleotides = len(deletion_part)
    # Only non-empty in-frame deletions map onto whole codons.
    if num_deleted_nucleotides == 0 or num_deleted_nucleotides % 3 != 0:
        return mutation

    match = re.search(r'\d+', mutation)
    if not match:
        return mutation
    start_pos = int(match.group())

    protein = next(
        (prot for prot, (start, end) in protein_dict.items() if start <= start_pos <= end),
        None,
    )
    if protein is None:
        return mutation

    deleted_amino_acids = num_deleted_nucleotides // 3
    # Assumes the number is the last retained (anchor) base and the deletion
    # follows it. The first deleted codon is then the one after the anchor's
    # codon: the anchor's 0-based codon index + 2 gives a 1-based number.
    # NOTE(review): exact for codon-aligned deletions (anchor on a codon's 3rd
    # base); for other phases this is the conventional approximation.
    aa_start_pos = (start_pos - protein_dict[protein][0]) // 3 + 2
    if deleted_amino_acids == 1:
        aa_range = f"{aa_start_pos}"
    else:
        # Inclusive range: N deleted codons span first .. first + N - 1.
        aa_range = f"{aa_start_pos}-{aa_start_pos + deleted_amino_acids - 1}"

    return f"{protein}:Δ{aa_range}"

# --- Core mutation annotation logic ---
def annotate_mutation(mutation, ref_seq, protein_dict):
    """Classify one nucleotide mutation string against the reference.

    Recognised formats (1-based positions on ``ref_seq``):
      * substitution ``A123G``
      * insertion    ``A123+TTG``
      * deletion     ``A123-TTG`` (the deleted bases may be empty)

    Args:
        mutation: Mutation string in one of the formats above.
        ref_seq: Reference genome sequence.
        protein_dict: ORF name -> (start, end), 1-based inclusive. The first
            ORF containing a position is used.

    Returns:
        ``(orf, label, mut_type)``, where ``mut_type`` is one of:
          * ``'non-syn'``: a substitution that changes the amino acid.
            ``label`` is ``"<ref_aa><codon_no><alt_aa>"`` (codon number within
            the ORF).
          * ``'syn'``: a substitution that keeps the amino acid, whose codon
            runs past the end of ``ref_seq``, or that lies outside every ORF
            (then ``orf`` is None). ``label`` is the input string.
          * ``'indel'``: an insertion (``label`` is the input) or a deletion
            (``label`` comes from ``process_deletion``, and ``orf`` from its
            ``"<ORF>:"`` prefix if present).
          * ``'mismatch'``: a substitution whose reference base disagrees with
            ``ref_seq``.
          * ``'out_of_bounds'``: the position lies outside ``ref_seq``.
          * ``'unknown'``: the string matches none of the formats.
    """
    genome_len = len(ref_seq)

    def in_genome(position):
        return position is not None and 1 <= position <= genome_len

    def first_orf(position):
        # (name, start) of the first ORF whose span contains `position`.
        for name, (orf_start, orf_end) in protein_dict.items():
            if orf_start <= position <= orf_end:
                return name, orf_start
        return None

    substitution = re.match(r'^([ACGT])(\d+)([ACGT])$', mutation)
    if substitution:
        ref_base, pos_text, alt_base = substitution.groups()
        pos = int(pos_text)
        if not in_genome(pos):
            return None, mutation, 'out_of_bounds'
        if ref_seq[pos - 1].upper() != ref_base:
            return None, mutation, 'mismatch'
        hit = first_orf(pos)
        if hit is None:
            # Intergenic substitutions are reported as synonymous.
            return None, mutation, 'syn'
        orf, orf_start = hit
        offset = pos - orf_start
        codon_index = offset // 3
        within_codon = offset % 3
        codon_start = orf_start + codon_index * 3
        ref_codon = ref_seq[codon_start - 1:codon_start + 2]
        if len(ref_codon) != 3:
            # Truncated codon at the end of the reference; cannot translate.
            return orf, mutation, 'syn'
        alt_codon = ref_codon[:within_codon] + alt_base + ref_codon[within_codon + 1:]
        ref_aa = Seq.Seq(ref_codon).translate()
        alt_aa = Seq.Seq(alt_codon).translate()
        if ref_aa == alt_aa:
            return orf, mutation, 'syn'
        return orf, f"{ref_aa}{codon_index + 1}{alt_aa}", 'non-syn'

    insertion = re.match(r'^([ACGT])(\d+)\+([ACGT]+)$', mutation)
    if insertion:
        pos = int(insertion.group(2))
        if not in_genome(pos):
            return None, mutation, 'out_of_bounds'
        hit = first_orf(pos)
        return (hit[0] if hit else None), mutation, 'indel'

    deletion = re.match(r'^([ACGT])(\d+)-([ACGT]*)$', mutation)
    if deletion:
        pos = int(deletion.group(2))
        if not in_genome(pos):
            return None, mutation, 'out_of_bounds'
        formatted = process_deletion(mutation, ref_seq, protein_dict)
        if ":" in formatted:
            return formatted.split(":")[0], formatted, 'indel'
        return None, formatted, 'indel'

    return None, mutation, 'unknown'

def annotate_dataframe(df, ref_seq, protein_dict):
    """Annotate each mutation in ``df`` using ``annotate_mutation``.

    Rows whose ``mutation`` string starts or ends with ``'N'`` are set aside
    without annotation. The remaining rows gain four columns:

      * ``manual_prot``: ORF from ``annotate_mutation``, falling back to the
        row's ``protein_y`` value when no ORF was found;
      * ``manual_mut_AA``: ``"<ORF>:<change>"`` for non-synonymous
        substitutions, otherwise the label returned by ``annotate_mutation``;
      * ``mut_type``: mutation class from ``annotate_mutation``;
      * ``con_base``: reference base at the first number in the mutation
        string, or ``'N/A'`` if there is none or it is out of range.

    Args:
        df: Table with at least a ``mutation`` column (``protein_y`` optional).
        ref_seq: Reference genome sequence.
        protein_dict: ORF name -> (start, end), 1-based inclusive.

    Returns:
        ``(annotated_df, n_df)``: copies of the non-'N' rows (annotated) and of
        the 'N' rows (unchanged). Together they partition ``df``.
    """
    # Use one shared mask so the two outputs partition the input: a mutation
    # that starts OR ends with 'N' goes to n_df; everything else is annotated.
    has_n = df['mutation'].str.startswith('N') | df['mutation'].str.endswith('N')
    n_df = df[has_n].copy()
    df = df[~has_n].copy()

    annotation_cols = ['manual_prot', 'manual_mut_AA', 'mut_type', 'con_base']
    if df.empty:
        # DataFrame.apply on an empty frame does not yield the four columns,
        # so create them explicitly.
        for col in annotation_cols:
            df[col] = pd.Series(dtype=object, index=df.index)
        return df, n_df

    def apply_annotation(row):
        orf, mut_aa, mut_type = annotate_mutation(row['mutation'], ref_seq, protein_dict)

        # Fallback to protein_y if ORF is missing
        if orf is None and pd.notna(row.get('protein_y')):
            orf = row['protein_y']

        # Build final manual_mut_AA
        if mut_type == 'non-syn' and orf is not None:
            manual_mut_AA = f"{orf}:{mut_aa}"
        else:
            manual_mut_AA = mut_aa

        # Reference base at the first number in the mutation string.
        match = re.search(r'(\d+)', row['mutation'])
        if match:
            pos = int(match.group())
            if 1 <= pos <= len(ref_seq):
                con_base = ref_seq[pos - 1]
            else:
                con_base = 'N/A'
        else:
            con_base = 'N/A'

        return pd.Series([orf, manual_mut_AA, mut_type, con_base])

    df[annotation_cols] = df.apply(apply_annotation, axis=1)
    return df, n_df

## Section 8: Annotate & Prepare Final Table

In [None]:
# 12. Annotate
# Load the reference genome that annotate_dataframe uses for codon lookup and
# consensus bases.
ref_seq = load_wuhan_ref(consenss_fasta_path)
annotated_df, n_df = annotate_dataframe(filtered_mutations_df, ref_seq, ORF_COORDS)

# 13. Simplify & reorder
# rename() ignores keys that are absent, so no membership checks are needed.
annotated_df = annotated_df.rename(
    columns={'manual_mut_AA': 'manual_mut', 'manual_prot': 'protein'}
)

final_mutations_df = (
    annotated_df
    [['patient', 'timepoint', 'POS', 'mutation', 'manual_mut',
      'protein', 'mut_type', 'final_freq', 'con_base']]
    .reset_index(drop=True)
)

# 14. Adjust to absolute dates: shift each timepoint by the patient's offset
# from days_to_add (0 for patients not listed). Vectorised, so it also works on
# an empty table.
final_mutations_df['timepoint'] = (
    final_mutations_df['timepoint']
    + final_mutations_df['patient'].map(days_to_add).fillna(0).astype(int)
)
patients_tp_df['adjusted_timepoints'] = (
    patients_tp_df.apply(
        lambda r: [tp + days_to_add.get(r['patient'], 0) for tp in r['timepoints']],
        axis=1
    )
)


## Section 9: Export Timepoint Tables


In [None]:
def _extract_pos(mut):
    """Return the first integer found in ``str(mut)``, or None if it has no digits."""
    match = re.search(r'\d+', str(mut))
    return int(match.group()) if match else None


def _build_patient_timepoint_table(patient_df, timepoints):
    """Lay out one patient's mutations as rows and ``timepoints`` as columns.

    Each cell holds the first ``final_freq`` recorded for that
    (mutation, timepoint). It is 0.0 when there is no record and NaN when the
    recorded value is -1. Rows are sorted by ``POS``.
    """
    table = patient_df[['POS', 'mutation', 'manual_mut', 'protein', 'mut_type']].drop_duplicates().copy()

    # One pass over the patient's rows instead of re-filtering the frame for
    # every (mutation, timepoint) cell; setdefault keeps the first occurrence.
    first_freq = {}
    for mut, tp, freq in zip(patient_df['mutation'], patient_df['timepoint'], patient_df['final_freq']):
        if pd.notna(mut):
            first_freq.setdefault((mut, tp), freq)

    for tp in timepoints:
        values = []
        for mut in table['mutation']:
            freq = first_freq.get((mut, tp), 0.0)
            values.append(np.nan if freq == -1 else freq)
        table[str(tp)] = pd.Series(values, index=table.index, dtype=float)

    return table.sort_values(by='POS', ascending=True)


def build_timepoints_csvs(final_mutations_df, patients_tp_df, output_dir):
    """Write one mutation-by-timepoint frequency table per patient.

    For each row of ``patients_tp_df``, the table is built as follows:
      * rows are the patient's unique (POS, mutation, manual_mut, protein,
        mut_type) combinations;
      * there is one column per entry of ``adjusted_timepoints``, named with
        the timepoint as a string;
      * cells follow the rules of ``_build_patient_timepoint_table``.

    Despite the function's name, each table is written as an Excel file,
    ``<output_dir>/<patient>_mutation_table.xlsx``. This needs an Excel writer
    engine for pandas, such as openpyxl.

    The input frame is not modified. Missing ``POS`` values are filled from the
    mutation string in a working copy.

    Args:
        final_mutations_df: Rows with patient, timepoint, POS, mutation,
            manual_mut, protein, mut_type and final_freq columns.
        patients_tp_df: Rows with patient and adjusted_timepoints (a list).
        output_dir: Destination folder; created if it does not exist.
    """
    os.makedirs(output_dir, exist_ok=True)

    mutations_df = final_mutations_df
    if mutations_df['POS'].isnull().any():
        # Work on a copy so the caller's frame is left untouched.
        mutations_df = mutations_df.copy()
        mutations_df['POS'] = mutations_df['POS'].fillna(
            mutations_df['mutation'].apply(_extract_pos)
        )

    for _, row in patients_tp_df.iterrows():
        patient = row['patient']
        patient_df = mutations_df[mutations_df['patient'] == patient]
        result_df = _build_patient_timepoint_table(patient_df, row['adjusted_timepoints'])
        # Save as an Excel workbook
        output_path = os.path.join(output_dir, f"{patient}_mutation_table.xlsx")
        result_df.to_excel(output_path, index=False)

    print(f"Mutation timepoint tables saved to {output_dir}")


In [None]:
build_timepoints_csvs(
    final_mutations_df=final_mutations_df,
    patients_tp_df=patients_tp_df,
    output_dir=freqs_folder,
)

## Section 10: Plotting Functions & Main Loop

In [None]:
 # === Fill missing POS ===
# Missing positions (e.g. rows added from the manual-corrections table) are
# recovered from the first run of digits in the mutation string.
if final_mutations_df['POS'].isna().any():
    def extract_pos(mut):
        """Return the first integer found in `mut`, or None when it has none."""
        digits = re.search(r'\d+', str(mut))
        return None if digits is None else int(digits.group())

    final_mutations_df['POS'] = final_mutations_df['POS'].fillna(
        final_mutations_df['mutation'].apply(extract_pos)
    )

# Plotting works on a copy; 'final_after_manual' is the frequency column the
# plotting helpers read.
df = final_mutations_df.copy()
df['final_after_manual'] = df['final_freq']

# === Manual patient groups ===
# Group label -> patient IDs.
grouped_patients = {
    "21-37 Days": ["N4", "P4", "N8", "N7"],
    "68-105 Days": ["N2", "N3", "P3", "P5"],
    "241 Days": ["N1"],
}

def remove_negative_frequencies(df):
    """Keep only rows whose 'final_after_manual' frequency is >= 0.

    Negative values (such as the -1 sentinel) and NaN are both dropped, since
    NaN never compares >= 0.
    """
    keep = df['final_after_manual'].ge(0)
    return df.loc[keep]

def remove_zero_na_mutations(df):
    """Drop every mutation (grouped by 'manual_mut') with no usable frequency.

    A mutation is kept only if at least one of its 'final_after_manual' values
    differs from 0 AND at least one is non-NaN.
    """
    def has_signal(group):
        freqs = group['final_after_manual']
        nonzero = bool((freqs != 0).any())
        return nonzero and bool(freqs.notna().any())

    return df.groupby('manual_mut').filter(has_signal)

def clean_and_fill_missing_timepoints(patient_df, patient):
    """Drop -1 rows and zero-fill timepoints at which a mutation has no record.

    For each mutation (``manual_mut``):
      * rows with ``final_after_manual == -1`` are removed;
      * a filler row with frequency 0 is added for every timepoint of this
        patient at which the mutation has no row at all (not even a -1 row).

    Filler rows copy ``mut_type`` and ``POS`` from the mutation's first row.
    They only populate patient, timepoint, final_after_manual, manual_mut,
    mut_type and POS; other columns become NaN.

    Args:
        patient_df: All rows of a single patient.
        patient: Patient identifier written into the filler rows.

    Returns:
        The kept original rows, grouped by mutation in order of first
        appearance, followed by the filler rows, with a fresh index. If there
        is nothing to return, an empty frame with the input's columns.
    """
    if patient_df.empty:
        return patient_df.reset_index(drop=True)

    # All timepoints sampled for this patient, across every mutation.
    all_timepoints = sorted(patient_df['timepoint'].unique())

    kept_frames = []
    filler_rows = []
    # sort=False keeps mutations in order of first appearance.
    for mut, mut_df in patient_df.groupby('manual_mut', sort=False):
        kept_frames.append(mut_df[mut_df['final_after_manual'] != -1])
        seen_tps = set(mut_df['timepoint'])
        mut_type = mut_df['mut_type'].iloc[0]
        pos = mut_df['POS'].iloc[0]
        filler_rows.extend(
            {
                'patient': patient,
                'timepoint': tp,
                'final_after_manual': 0,
                'manual_mut': mut,
                'mut_type': mut_type,
                'POS': pos,
            }
            for tp in all_timepoints
            if tp not in seen_tps
        )

    if filler_rows:
        kept_frames.append(pd.DataFrame(filler_rows))
    if not kept_frames:
        # e.g. every manual_mut is NaN, so groupby yields no groups
        return patient_df.iloc[0:0].reset_index(drop=True)

    return pd.concat(kept_frames, ignore_index=True)

def save_legend_only_svg(color_dict, plot_folder, patient):
    """Export a standalone SVG that contains only a horizontal legend.

    There is one legend entry (line + marker) per ``label -> colour`` item in
    `color_dict`; the axes are hidden. The file is written to
    ``<plot_folder>/legend_<patient>.svg``, creating the folder if needed.
    Static export through ``write_image`` needs plotly's image backend
    (e.g. kaleido) to be installed.
    """
    # Data-less traces: they exist only to populate the legend.
    entries = [
        go.Scatter(
            x=[None],
            y=[None],
            mode='lines+markers',
            name=label,
            showlegend=True,
            line=dict(color=colour, width=LINE_WIDTH),
            marker=dict(color=colour, size=MARKER_SIZE),
            opacity=1,
        )
        for label, colour in color_dict.items()
    ]
    legend_fig = go.Figure(data=entries)

    legend_style = dict(
        orientation='h',
        x=0.5,
        y=0.5,
        xanchor='center',
        yanchor='middle',
        font=dict(size=FONT_SIZE_LEGEND, color=TEXT_COLOR),
    )
    legend_fig.update_layout(
        showlegend=True,
        legend=legend_style,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        margin=dict(l=0, r=0, t=0, b=0),
        height=200,
        width=1000,
        paper_bgcolor='white',
        plot_bgcolor='white',
    )

    os.makedirs(plot_folder, exist_ok=True)
    legend_fig.write_image(os.path.join(plot_folder, f"legend_{patient}.svg"))

def plot_patient(patient_data, patient):
    """Plot one patient's per-mutation frequency trajectories and save them.

    Args:
        patient_data: Cleaned rows for a single patient. Needs the columns
            manual_mut, timepoint, final_after_manual, mut_type and POS.
        patient: Patient identifier used in the title and output file names.

    Styling rules:
      * Mutations are drawn in genome order (POS of each mutation's first row).
      * Mutations whose maximum frequency exceeds 0.5 get a colour from
        PREDEFINED_COLORS and a legend entry.
      * Other mutations are drawn in LOW_FREQ_COLOR at OPACITY with no legend
        entry.
      * 'syn'/'mismatch' mutations are dashed; indels use triangle markers.

    Several high-frequency (> 0.5) points at the same timepoint can share the
    same frequency rounded to two decimals. Such points are spread
    symmetrically over 6% of that frequency (±3%) so they do not hide each
    other.

    Writes ``png_plot_<patient>.png`` (scale 3), ``svg_plot_<patient>.svg``
    and ``legend_<patient>.svg`` into the module-level ``plot_folder``.

    Returns:
        ``(png_path, svg_path)``.
    """
    fig = go.Figure()

    # Draw mutations in genome order (POS of each mutation's first row).
    mutations = sorted(
        patient_data['manual_mut'].unique(),
        key=lambda x: patient_data[patient_data['manual_mut'] == x]['POS'].iloc[0]
    )

    # Mutations that ever exceed 50% get their own colour and legend entry.
    high_freq_mutations = [
        m for m in mutations
        if patient_data[patient_data['manual_mut'] == m]['final_after_manual'].max() > 0.5
    ]
    color_dict = {
        mut: PREDEFINED_COLORS[i % len(PREDEFINED_COLORS)]
        for i, mut in enumerate(high_freq_mutations)
    }

    # Step 1: vertical offsets that separate overlapping high-frequency points.
    jitter_lookup = {}  # {timepoint: {mutation: offset}}
    for tp in sorted(patient_data['timepoint'].unique()):
        tp_df = patient_data[
            (patient_data['timepoint'] == tp) &
            (patient_data['final_after_manual'] > 0.5)
        ]
        freq_bins = {}  # {rounded_freq: [(mutation, freq), ...]}
        for mut, group in tp_df.groupby('manual_mut'):
            f = group['final_after_manual'].iloc[0]
            freq_bins.setdefault(round(f, 2), []).append((mut, f))

        offsets = {}
        for rounded_val, entries in freq_bins.items():
            n = len(entries)
            if n < 2:
                # A lone point overlaps nothing; plot it at its true value.
                continue
            jitter_range = 0.06 * rounded_val  # total spread, i.e. ±3% of freq
            spacing = jitter_range / (n - 1)
            start = -jitter_range / 2          # centre the spread on the value
            for i, (mut, _f) in enumerate(sorted(entries)):
                offsets.setdefault(mut, start + i * spacing)
        jitter_lookup[tp] = offsets

    # Step 2: plot each mutation's trajectory with jittered frequencies.
    for mut in mutations:
        mut_df = patient_data[
            (patient_data['manual_mut'] == mut) &
            (patient_data['final_after_manual'] != -1)
        ].sort_values('timepoint').copy()

        if mut_df.empty:
            continue

        max_freq = mut_df['final_after_manual'].max()
        mut_type = mut_df['mut_type'].iloc[0].lower()
        line_style = 'dash' if mut_type in ['syn', 'mismatch'] else 'solid'
        marker_symbol = 'triangle-up' if mut_type == 'indel' else 'circle'
        show_in_legend = max_freq > 0.5
        name_in_legend = mut if show_in_legend else ''

        mut_df['plot_freq'] = [
            f + jitter_lookup.get(tp, {}).get(mut, 0)
            for f, tp in zip(mut_df['final_after_manual'], mut_df['timepoint'])
        ]
        color = color_dict.get(mut, LOW_FREQ_COLOR)

        fig.add_trace(go.Scatter(
            x=mut_df['timepoint'],
            y=mut_df['plot_freq'],
            mode='lines+markers',
            name=name_in_legend,
            showlegend=show_in_legend,
            line=dict(color=color, dash=line_style, width=LINE_WIDTH),
            marker=dict(color=color, size=MARKER_SIZE, symbol=marker_symbol),
            opacity=OPACITY if not show_in_legend else 1
        ))

    tick_vals = sorted(patient_data['timepoint'].unique())
    axis_title_font = dict(size=FONT_SIZE_AXIS_TITLE, color=TEXT_COLOR)
    fig.update_layout(
        title=f"Patient {patient}",
        title_x=0.5,
        title_y=0.99,
        font=dict(size=FONT_SIZE_MAIN, color=TEXT_COLOR),
        xaxis=dict(
            # title=dict(text=, font=) replaces the deprecated `titlefont`,
            # which newer plotly releases no longer accept.
            title=dict(text="Days", font=axis_title_font),
            tickfont=dict(size=FONT_SIZE_TICK, color=TEXT_COLOR),
            tickvals=tick_vals,
            ticktext=[str(t) for t in tick_vals],
            showgrid=True,
            gridcolor=GRID_COLOR
        ),
        yaxis=dict(
            title=dict(text="Frequency", font=axis_title_font),
            tickfont=dict(size=FONT_SIZE_TICK, color=TEXT_COLOR),
            range=[-0.05, 1.1],  # show slightly outside 0-1
            showgrid=True,
            gridcolor=GRID_COLOR
        ),
        legend=dict(
            orientation='h',
            y=-0.2,
            x=0.5,
            xanchor='center',
            font=dict(size=FONT_SIZE_LEGEND, color=TEXT_COLOR)
        ),
        margin=dict(l=40, r=40, t=120, b=100),
        height=PLOT_HEIGHT,
        width=PLOT_WIDTH,
        paper_bgcolor='white',
        plot_bgcolor='white'
    )

    os.makedirs(plot_folder, exist_ok=True)
    png_path = os.path.join(plot_folder, f"png_plot_{patient}.png")
    svg_path = os.path.join(plot_folder, f"svg_plot_{patient}.svg")
    fig.write_image(png_path, scale=3)
    fig.write_image(svg_path)

    save_legend_only_svg(color_dict, plot_folder, patient)
    return png_path, svg_path

In [None]:
# === execution ===
# Flatten the group mapping into (group, patient) pairs so tqdm can report
# overall progress.
all_patients = [
    (group_name, patient)
    for group_name, members in grouped_patients.items()
    for patient in members
]
for group_name, patient in tqdm(all_patients, desc="Generating plots and legends"):
    # Zero-fill missing timepoints, then drop negative/NaN frequencies and
    # mutations without any non-zero, non-NaN value.
    pdata = clean_and_fill_missing_timepoints(df[df['patient'] == patient].copy(), patient)
    pdata = remove_zero_na_mutations(remove_negative_frequencies(pdata))
    if pdata.empty:
        continue  # nothing left to plot for this patient
    plot_patient(pdata, patient)

print("✅ All patient plots and legends saved (PNG + SVG).")