# In [None]:
# run_analysis.py
# Standalone Python script to run the DeepMethyGene analysis from the command line.

import os
import json
import ast
import argparse
from typing import List, Dict, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# ==============================================================================
# UTILITY AND MODEL DEFINITIONS
# (Copied directly from iGEM-9.1-SLC7A5.ipynb for a self-contained script)
# ==============================================================================

def _standardize_chr(x: str) -> str:
    """Ensures chromosome names start with 'chr'."""
    x = str(x).strip()
    return x if x.startswith("chr") else f"chr{x}"

def load_promoters_hg19(promoter_file: str) -> pd.DataFrame:
    """Read the tab-separated hg19 promoter annotation into a tidy table.

    The input must provide ``gene name``, ``chrID`` and ``start`` columns. They are
    renamed to ``gene`` / ``seqnames`` / ``start``, and:

    * gene symbols are upper-cased and whitespace-trimmed;
    * chromosome names are trimmed and prefixed with ``chr`` when missing;
    * rows whose ``start`` is not numeric are dropped and ``start`` becomes int64;
    * fully duplicated rows (across all original columns) are removed.

    Returns:
        DataFrame with exactly the columns ``["gene", "seqnames", "start"]``.
    """
    table = pd.read_csv(promoter_file, sep="\t", dtype={"chrID": str}).rename(
        columns={"gene name": "gene", "chrID": "seqnames"}
    )
    chroms = table["seqnames"].astype(str).str.strip()
    table = table.assign(
        gene=table["gene"].astype(str).str.upper().str.strip(),
        seqnames=chroms.where(chroms.str.startswith("chr"), "chr" + chroms),
        start=pd.to_numeric(table["start"], errors="coerce").astype("Int64"),
    )
    cleaned = table.dropna(subset=["start"]).astype({"start": "int64"}).drop_duplicates()
    return cleaned[["gene", "seqnames", "start"]]

def resolve_input_probes(
    gene: str,
    long_df: pd.DataFrame,
    mapped_csv: Optional[str] = "mapped_filteredgenes_data.csv",
) -> List[str]:
    """Resolve the ordered list of probe IDs that forms the model's input vector.

    Resolution order:
      1. ``mapped_csv`` (if given and present): the ``cpg_probe_ids`` cell of the first
         row whose ``gene`` matches case-insensitively, parsed with ``ast.literal_eval``.
         Entries whose string form is ``"nan"`` are dropped.
      2. Otherwise ``long_df``: the unique ``probe_id`` values for the gene, sorted by
         ``(chr, start)`` when both columns exist (``chr`` compares as a string), else in
         order of first appearance.

    Any failure while reading or parsing ``mapped_csv`` is reported as a warning and
    resolution falls through to step 2 (best-effort by design).

    Raises:
        ValueError: if step 1 yields nothing and ``long_df`` has no rows for ``gene``.
    """
    gene_u = gene.upper()
    if mapped_csv and os.path.exists(mapped_csv):
        try:
            mapped = pd.read_csv(mapped_csv)
            hits = mapped[mapped["gene"].astype(str).str.upper() == gene_u]
            if not hits.empty and "cpg_probe_ids" in hits.columns:
                parsed = ast.literal_eval(str(hits.iloc[0]["cpg_probe_ids"]))
                # A cell holding a single quoted ID parses to a bare str; without this
                # guard it would be iterated character by character into bogus IDs.
                if isinstance(parsed, str):
                    parsed = [parsed]
                out = [str(v) for v in parsed if str(v) != "nan"]
                if out:
                    print(f"[Info] Resolved probe order from '{mapped_csv}' (n={len(out)})")
                    return out
        except Exception as e:  # deliberately broad: any failure falls back to long_df
            print(f"[Warning] Could not parse mapped CSV: {e}")

    sub = long_df[long_df["gene"].astype(str).str.upper() == gene_u].copy()
    if sub.empty:
        raise ValueError(f"Could not resolve input probes for {gene_u}")
    if {"chr", "start", "probe_id"}.issubset(sub.columns):
        sub = sub.drop_duplicates("probe_id").sort_values(["chr", "start"])
        out = sub["probe_id"].astype(str).tolist()
        print(f"[Info] Resolved probe order from long_df by genomic sort (n={len(out)})")
        return out
    return sub.drop_duplicates("probe_id")["probe_id"].astype(str).tolist()

def pick_sample(df_long: pd.DataFrame, gene: str, sample: Optional[str]) -> str:
    """Choose which sample to analyse for ``gene``.

    Gene matching is case-insensitive. Sample IDs are compared as strings.

    Args:
        df_long: long-format table with ``gene`` and ``sample`` columns.
        gene: gene symbol to look up.
        sample: requested sample ID, or ``None`` to take the lexicographically first
            sample available for the gene.

    Returns:
        The chosen sample ID as a string.

    Raises:
        ValueError: if the gene has no rows, or the requested sample is not among them.
    """
    gene_samples = df_long.loc[df_long["gene"].str.upper() == gene.upper(), "sample"]
    if gene_samples.empty:
        raise ValueError(f"No rows for gene {gene} in long data.")

    available = sorted(set(gene_samples.astype(str)))
    if not available:
        raise ValueError(f"No samples available for gene {gene}.")

    if sample is None:
        return available[0]

    wanted = str(sample)
    if wanted in available:
        return wanted
    raise ValueError(f"Sample {wanted} not found for {gene}.")

def promoter_mask(
    input_probes: List[str], sub_long: pd.DataFrame, promoters: pd.DataFrame,
    gene: str, window_bp: int
) -> np.ndarray:
    """Flag the input probes that lie within ``window_bp`` of one of the gene's promoter starts.

    Args:
        input_probes: probe IDs in model-input order (padding IDs such as ``"__PAD__"``
            simply never match).
        sub_long: rows with ``probe_id``, ``chr`` and ``start`` columns giving probe positions.
        promoters: table with ``gene``, ``seqnames`` and ``start`` columns. ``gene`` is
            compared against the upper-cased ``gene`` argument as-is.
        gene: gene symbol (case-insensitive).
        window_bp: half-width of the window; a probe matches when
            ``|probe_start - promoter_start| <= window_bp`` on the same chromosome.

    Returns:
        Boolean array aligned with ``input_probes``.
    """
    gene_u = gene.upper()
    mask = np.zeros(len(input_probes), dtype=bool)

    gp = promoters[promoters["gene"] == gene_u]
    if gp.empty:
        return mask

    tss_by_chr: Dict[str, List[int]] = {}
    for chrom, tss in zip(gp["seqnames"].astype(str), gp["start"]):
        tss_by_chr.setdefault(chrom, []).append(int(tss))

    # Key positions by str(probe_id) so lookups agree with the string probe IDs used
    # for the input vector (vector_from_long also casts probe_id to str). Keying by the
    # raw value made numeric probe IDs never match, silently yielding an all-False mask.
    probes = sub_long.drop_duplicates("probe_id")
    pos = {
        str(pid): (chrom, start)
        for pid, chrom, start in zip(probes["probe_id"], probes["chr"], probes["start"])
    }

    for i, pid in enumerate(input_probes):
        info = pos.get(str(pid))
        if info is None:
            continue
        chrom, start = str(info[0]), int(info[1])
        if any(abs(start - tss) <= window_bp for tss in tss_by_chr.get(chrom, ())):
            mask[i] = True
    return mask

# --- PyTorch Model Definition ---
class ResidualBlock(nn.Module):
    """Two same-width 1D convolutions wrapped by an identity skip connection.

    Both convolutions use ``kernel_size=3`` with ``padding=1`` (stride 1), so the sequence
    length is unchanged and the block input can be added element-wise to ``conv2``'s output.
    LeakyReLU (slope 0.01) is applied after ``conv1`` and after the residual sum.
    """

    def __init__(self, channels):
        super().__init__()
        # Attribute names become state_dict keys ("conv1.weight", ...); they must stay
        # stable so previously saved weights still load.
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size=3, padding=1)
        self.act = nn.LeakyReLU(0.01)

    def forward(self, x):
        shortcut = x
        hidden = self.act(self.conv1(x))
        return self.act(self.conv2(hidden) + shortcut)

class AdaptiveRegressionCNN(nn.Module):
    """Per-gene CNN regressor whose width adapts to the number of input probes.

    Architecture: ``conv1`` (1 -> w1 channels) -> ``res1`` -> ``conv2`` (w1 -> w2) -> ``res2``
    -> flatten -> ``fc1`` (512) -> ``fc2`` (1), with LeakyReLU(0.01) activations, where
    ``w1 = clamp(input_size // 10, 1, 64)`` and ``w2 = clamp(input_size // 20, 1, 32)``.

    ``conv1`` has one input channel, and ``fc1``'s width is fixed from ``input_size``, so
    ``forward`` expects a tensor of shape ``(batch, 1, input_size)``. It returns one
    value per batch row, because the trailing dim of size 1 is squeezed.
    """

    def __init__(self, input_size):
        super().__init__()
        width1 = min(64, max(1, input_size // 10))
        width2 = min(32, max(1, input_size // 20))
        # Creation order matters for seeded parameter initialisation; keep it.
        self.conv1 = nn.Conv1d(1, width1, kernel_size=3, padding=1)
        self.res1 = ResidualBlock(width1)
        self.conv2 = nn.Conv1d(width1, width2, kernel_size=3, padding=1)
        self.res2 = ResidualBlock(width2)
        self.act = nn.LeakyReLU(0.01)
        # Dry-run one dummy sample through the conv stack to size fc1.
        # NOTE: torch.randn advances the global RNG before fc1/fc2 are initialised.
        with torch.no_grad():
            dummy = torch.randn(1, 1, input_size)
            self._to_linear = self._features(dummy).view(1, -1).shape[1]
        self.fc1 = nn.Linear(self._to_linear, 512)
        self.fc2 = nn.Linear(512, 1)

    def _features(self, x):
        """Apply the convolutional / residual stack (no flattening)."""
        x = self.act(self.conv1(x))
        x = self.res1(x)
        x = self.act(self.conv2(x))
        return self.res2(x)

    def forward(self, x):
        flat = self._features(x).view(x.size(0), -1)
        hidden = self.act(self.fc1(flat))
        return self.fc2(hidden).squeeze(-1)

def load_gene_model_cnn_only(gene: str, models_dir: str, device: str = "cpu"):
    """Build the per-gene CNN and load its saved weights.

    Two files named after the upper-cased gene symbol are read from ``models_dir``:
    ``<GENE>.json`` (must contain an ``input_size`` key) and ``<GENE>.pt`` (a state_dict
    that is loaded into an ``AdaptiveRegressionCNN`` of that input size).

    Returns:
        ``(model, input_size)``, with the model moved to ``device`` and set to eval mode.

    Raises:
        FileNotFoundError: if either file is missing.
        ValueError: if the metadata JSON has no ``input_size`` key.
    """
    stem = gene.upper()

    meta_path = os.path.join(models_dir, stem + ".json")
    if not os.path.exists(meta_path):
        raise FileNotFoundError(f"Metadata for {stem} not found")
    with open(meta_path, "r") as handle:
        meta = json.load(handle)
    if "input_size" not in meta:
        raise ValueError("meta missing 'input_size'")
    input_size = int(meta["input_size"])

    weights_path = os.path.join(models_dir, stem + ".pt")
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"Weights for {stem} not found")

    model = AdaptiveRegressionCNN(input_size=input_size)
    # NOTE(review): torch.load unpickles the file, so only load weight files you trust.
    # On PyTorch versions that support it, weights_only=True restricts this.
    weights = torch.load(weights_path, map_location=device)
    model.load_state_dict(weights)
    model.to(device)
    model.eval()
    print(f"[Info] Loaded model '{stem}.pt' with input size {input_size}")
    return model, input_size

def vector_from_long(
    df_long: pd.DataFrame, gene: str, sample: str, input_probes: List[str], fill_missing: float = 0.0
) -> Tuple[np.ndarray, pd.DataFrame]:
    """Build the float32 model input vector for one gene/sample, ordered by ``input_probes``.

    Args:
        df_long: long-format table with ``gene``, ``sample``, ``probe_id`` and ``m_value`` columns.
        gene: gene symbol (case-insensitive).
        sample: sample ID, compared with ``==`` against the ``sample`` column.
        input_probes: probe IDs (strings) defining the vector order.
        fill_missing: value used for probes that are absent for this sample or whose
            ``m_value`` is NaN.

    Returns:
        ``(x, sub)``, where ``x[i]`` is the M-value of ``input_probes[i]`` and ``sub`` holds
        the gene/sample rows, de-duplicated on ``probe_id`` (first occurrence kept).

    Raises:
        ValueError: if there are no rows for this gene/sample.
    """
    gene_u = gene.upper()
    sub = df_long[(df_long["gene"].str.upper() == gene_u) & (df_long["sample"] == sample)]
    if sub.empty:
        raise ValueError(f"No rows for {gene_u} / sample {sample}")
    sub = sub.drop_duplicates("probe_id").copy()
    # Skip NaN measurements so they fall back to fill_missing. A NaN in the input
    # vector would otherwise propagate through the CNN and make every prediction NaN.
    m_by_probe = {
        str(pid): m for pid, m in zip(sub["probe_id"], sub["m_value"]) if not pd.isna(m)
    }
    x = np.array([m_by_probe.get(pid, fill_missing) for pid in input_probes], dtype=np.float32)
    return x, sub

def _predict_tensor(model: nn.Module, x_arr_1d: np.ndarray, device: str = "cpu") -> float:
    """Performs a prediction on a single 1D numpy array."""
    x = torch.from_numpy(x_arr_1d).float().unsqueeze(0).unsqueeze(1).to(device)
    with torch.no_grad(): y = model(x).cpu().numpy().reshape(-1)[0]
    return float(y)

# ==============================================================================
# MAIN ANALYSIS FUNCTION
# ==============================================================================

def _load_long_table(long_csv: str) -> pd.DataFrame:
    """Read the long-format methylation CSV and normalise the columns the analysis uses.

    Raises:
        ValueError: if a required column is missing, so the CLI reports it cleanly
            instead of failing later with an uncaught ``KeyError``.
    """
    df = pd.read_csv(long_csv)
    required = ["gene", "sample", "probe_id", "chr", "start", "m_value"]
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise ValueError(f"Long CSV '{long_csv}' is missing required columns: {missing}")
    df["gene"] = df["gene"].astype(str).str.upper().str.strip()
    df["sample"] = df["sample"].astype(str)
    df["chr"] = df["chr"].astype(str).map(_standardize_chr)
    df["start"] = pd.to_numeric(df["start"], errors="coerce").astype("Int64")
    return df.dropna(subset=["start"]).astype({"start": "int64"})


def _align_to_input_size(
    x: np.ndarray, probes: List[str], input_size: int
) -> Tuple[np.ndarray, List[str]]:
    """Truncate, or zero-pad with ``"__PAD__"`` probes, so the input matches ``input_size``."""
    n = len(probes)
    if n == input_size:
        return x, probes
    print(f"[Warning] Probe list length ({n}) differs from model input size ({input_size}). Will truncate/pad.")
    if n > input_size:
        return x[:input_size], probes[:input_size]
    pad_n = input_size - n
    return (
        np.concatenate([x, np.zeros(pad_n, dtype=np.float32)]),
        probes + ["__PAD__"] * pad_n,
    )


def _predict_with_edit(
    model: nn.Module, x_base: np.ndarray, mask: np.ndarray, level: float, device: str
) -> float:
    """Predict after setting the masked positions of a copy of ``x_base`` to ``level``."""
    x_edit = x_base.copy()
    x_edit[mask] = level
    return _predict_tensor(model, x_edit, device)


def _save_figure(fig, path: str) -> None:
    """Save ``fig`` to ``path`` and always close it, even if saving fails."""
    try:
        fig.savefig(path, bbox_inches="tight")
    finally:
        plt.close(fig)
    print(f"Saved plot: {path}")


def _save_line_plot(x, y, xlabel: str, title: str, path: str, invert_x: bool = False) -> None:
    """Draw a marker line plot of predicted expression against ``x`` and save it."""
    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(x, y, marker="o")
    ax.set_xlabel(xlabel)
    ax.set_ylabel("Predicted Expression")
    ax.set_title(title)
    ax.grid(True, linestyle=":")
    if invert_x:
        ax.invert_xaxis()
    _save_figure(fig, path)


def _save_csv(df: pd.DataFrame, path: str) -> None:
    """Write ``df`` to ``path`` without the index and log it."""
    df.to_csv(path, index=False)
    print(f"Saved data: {path}")


def perform_analysis(args):
    """Run the three in-silico promoter-methylation experiments for ``args.gene``.

    Reads ``args.long_csv``, ``args.promoter_file``, ``args.mapped_csv`` and the model in
    ``args.weights_dir``, then writes into ``args.output_dir``:

    * Task 1: bar plot of baseline vs. promoter (±2000bp) set to M=+10 / M=-10;
    * Task 2: plot + CSV of the prediction with promoter probes set to -10 for windows
      of 1000..10000bp;
    * Task 3: plot + CSV of the prediction as the ±2000bp promoter level sweeps +10..-10.

    ``args.cpu`` forces CPU even when CUDA is available.

    Raises:
        FileNotFoundError / ValueError: propagated from loading and input resolution.
    """
    # --- Setup ---
    target_gene = args.gene.upper()
    device = "cuda" if torch.cuda.is_available() and not args.cpu else "cpu"
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)

    # --- Load data ---
    print("Loading and preprocessing data files...")
    df_long = _load_long_table(args.long_csv)
    promoters = load_promoters_hg19(args.promoter_file)

    # --- Resolve inputs for the target gene ---
    input_probes = resolve_input_probes(target_gene, df_long, args.mapped_csv)
    sample_id = pick_sample(df_long, target_gene, sample=None)
    print(f"Using auto-selected sample: {sample_id}")
    x_base, sub = vector_from_long(df_long, target_gene, sample_id, input_probes, fill_missing=0.0)

    # --- Load model and align the input vector to its expected size ---
    model, input_size = load_gene_model_cnn_only(target_gene, models_dir=args.weights_dir, device=device)
    x_base, input_probes = _align_to_input_size(x_base, input_probes, input_size)

    # --- Task 1: 3-scenario bar plot ---
    print("\n--- Running Task 1: 3-Scenario Prediction ---")
    # The ±2000bp promoter mask is shared by Task 1 and Task 3.
    pmask_2kb = promoter_mask(input_probes, sub, promoters, target_gene, window_bp=2000)
    if not pmask_2kb.any():
        print(
            f"[Warning] No input probes lie within ±2000bp of a {target_gene} promoter; "
            "promoter edits in Task 1 and Task 3 will not change the prediction."
        )
    preds = {
        "Baseline": _predict_tensor(model, x_base.copy(), device),
        "Promoter\nHypermethylated": _predict_with_edit(model, x_base, pmask_2kb, 10.0, device),
        "Promoter\nHypomethylated": _predict_with_edit(model, x_base, pmask_2kb, -10.0, device),
    }

    fig, ax = plt.subplots(figsize=(8, 6))
    # Pass real lists: matplotlib does not reliably accept dict views as data.
    ax.bar(list(preds.keys()), list(preds.values()), color=["#3b82f6", "#ef4444", "#22c55e"])
    ax.set_ylabel("Predicted Expression")
    ax.set_title(f"Task 1: {target_gene} Expression Changes (Sample: {sample_id})")
    ax.tick_params(axis="x", labelrotation=15)
    ax.grid(axis="y", linestyle="--", alpha=0.7)
    _save_figure(fig, os.path.join(args.output_dir, f"{target_gene}_task1_scenarios.png"))

    # --- Task 2: Sweep window hypomethylation ---
    print("\n--- Running Task 2: Hypomethylation Window Sweep ---")
    rows = []
    for window_bp in range(1000, 10001, 1000):
        mask = promoter_mask(input_probes, sub, promoters, target_gene, window_bp=window_bp)
        rows.append({"window_bp": window_bp, "y_pred": _predict_with_edit(model, x_base, mask, -10.0, device)})
    df2 = pd.DataFrame(rows)

    _save_line_plot(
        df2["window_bp"], df2["y_pred"],
        xlabel="Promoter Window Size (±bp)",
        title=f"Task 2: {target_gene} Expression vs. Hypo-methylated Window",
        path=os.path.join(args.output_dir, f"{target_gene}_task2_window_sweep.png"),
    )
    _save_csv(df2, os.path.join(args.output_dir, f"{target_gene}_task2_window_sweep.csv"))

    # --- Task 3: Sweep promoter methylation level ---
    print("\n--- Running Task 3: Promoter Methylation Level Sweep ---")
    levels = np.arange(10.0, -10.5, -0.5)
    df3 = pd.DataFrame({
        "promoter_level": levels,
        "y_pred": [_predict_with_edit(model, x_base, pmask_2kb, float(level), device) for level in levels],
    })

    _save_line_plot(
        df3["promoter_level"], df3["y_pred"],
        xlabel="Promoter CpG M-value (within ±2000bp)",
        title=f"Task 3: {target_gene} Expression vs. Promoter Methylation Level",
        path=os.path.join(args.output_dir, f"{target_gene}_task3_level_sweep.png"),
        invert_x=True,  # show hypermethylation (high M-values) on the left
    )
    _save_csv(df3, os.path.join(args.output_dir, f"{target_gene}_task3_level_sweep.csv"))

    print(f"\nAnalysis for {target_gene} complete. Outputs are in '{args.output_dir}'.")


# ==============================================================================
# COMMAND LINE INTERFACE
# ==============================================================================
if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser(
        description="Run DeepMethyGene analysis for a specific gene.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )

    # --- Required Arguments ---
    parser.add_argument(
        "--gene",
        type=str,
        required=True,
        help="The target gene symbol to analyze (e.g., SLC7A5)."
    )

    # --- File Path Arguments ---
    parser.add_argument(
        "--long-csv",
        type=str,
        default="data/m_arrays_for_edit.csv",
        help="Path to the long-format methylation data CSV."
    )
    parser.add_argument(
        "--mapped-csv",
        type=str,
        default="data/mapped_filteredgenes_data.csv",
        help="Path to the mapped genes data with probe order."
    )
    parser.add_argument(
        "--promoter-file",
        type=str,
        default="data/hg19_promoter.txt",
        help="Path to the hg19 promoter annotation file."
    )
    parser.add_argument(
        "--weights-dir",
        type=str,
        default="Gene Wise Model Weights",
        help="Directory containing the pre-trained model weights (.pt and .json files)."
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="predictions_out",
        help="Directory where output plots and CSVs will be saved."
    )

    # --- Computation Arguments ---
    parser.add_argument(
        "--cpu",
        action="store_true",
        help="Force the use of CPU even if a GPU is available."
    )

    args = parser.parse_args()

    try:
        perform_analysis(args)
    except (FileNotFoundError, ValueError) as e:
        # Report on stderr and exit non-zero so shell scripts and pipelines can
        # detect the failure; previously the script exited with status 0.
        print(f"\nERROR: {e}", file=sys.stderr)
        print("Please check your file paths and ensure the gene exists in the data.", file=sys.stderr)
        sys.exit(1)