# 

In [1]:
import requests
import time
import matplotlib.pyplot as plt
import subprocess
import tempfile
import shutil
import os

import ipywidgets as widgets
from IPython.display import display, clear_output

# ======== Settings ========
# Base URL of the Ensembl REST API used for gene lookup and sequence retrieval.
ENSEMBL_REST = "https://rest.ensembl.org"
# Request headers sent with the gene lookup call (a JSON response is parsed).
HEADERS = {"Content-Type": "application/json"}
# Gene symbols offered in the dropdown.
AVAILABLE_GENES = [
    "CGA", "FSHB", "FSHR", "GNRH1", "GNRHR", "FOXL2", "NR5A1", "PITX1", "LHX3",
    "INHBA", "INHBB", "ACVR1B", "SMAD2", "SMAD3", "SMAD4", "INHA", "FST", "GNAS",
    "ADCY3", "PRKAR2B", "CREB1", "CYP19A1", "AMH", "ESR1", "ESR2", "LHCGR", "SOX9",
    "CLDN11", "INSL3", "TGFBR1", "TGFBR2", "BMP2", "BMP4"
]
# Short species key -> species name used in Ensembl REST URLs.
SPECIES = {"human": "homo_sapiens", "mouse": "mus_musculus"}
# Length, in base pairs, of the upstream region fetched as the promoter.
UPSTREAM = 1000
# Path to the MEME-format motif file. It can be overridden with the PWM_PATH
# environment variable so the notebook also runs on other machines; when the
# variable is unset or empty the original local path is used.
PWM_PATH = os.environ.get("PWM_PATH") or r"C:\Users\ПК\Desktop\Project\data\motifs\pwm.meme"

# ======== Functions ========
# Seconds to wait for an Ensembl response before giving up.
_ENSEMBL_TIMEOUT = 30


def get_gene_coordinates(gene_name, species):
    """Look up a gene's genomic location by symbol via the Ensembl REST API.

    Args:
        gene_name: Gene symbol, e.g. "FSHB".
        species: Species name as used in Ensembl URLs, e.g. "homo_sapiens".

    Returns:
        A dict with the keys "seq_region_name", "start", "end" and "strand"
        copied from the lookup response, or None when the request fails,
        the response is not valid JSON, or any of those fields is missing.
    """
    url = f"{ENSEMBL_REST}/lookup/symbol/{species}/{gene_name}?expand=1"
    try:
        res = requests.get(url, headers=HEADERS, timeout=_ENSEMBL_TIMEOUT)
    except requests.RequestException:
        # Network problems are reported the same way as an HTTP error.
        return None
    if not res.ok:
        return None
    try:
        data = res.json()
        return {
            "seq_region_name": data["seq_region_name"],
            "start": data["start"],
            "end": data["end"],
            "strand": data["strand"],
        }
    except (ValueError, KeyError):
        return None


def _promoter_region(coords):
    """Build the Ensembl region string for the UPSTREAM bp before the gene.

    Ensembl reports start < end regardless of strand, so the 5' end of the
    gene is "start" on the forward strand and "end" on the reverse strand.
    Returns None when no upstream base exists (gene at the very beginning
    of the sequence region).
    """
    name = coords["seq_region_name"]
    if coords["strand"] == 1:
        last = coords["start"] - 1
        first = max(1, coords["start"] - UPSTREAM)  # never go below position 1
        strand = 1
    else:
        # Reverse strand: upstream lies at higher coordinates than "end";
        # requesting strand -1 yields the sequence in the gene's orientation.
        first = coords["end"] + 1
        last = coords["end"] + UPSTREAM
        strand = -1
    if last < first:
        return None
    return f"{name}:{first}..{last}:{strand}"


def get_promoter_sequence(gene_name, species):
    """Fetch the UPSTREAM bp immediately 5' of a gene as plain text.

    Args:
        gene_name: Gene symbol, e.g. "FSHB".
        species: Species name as used in Ensembl URLs, e.g. "homo_sapiens".

    Returns:
        The body of the Ensembl sequence/region response (text/plain), or
        None when the gene cannot be located or the sequence request fails.
    """
    coords = get_gene_coordinates(gene_name, species)
    if not coords:
        return None
    region = _promoter_region(coords)
    if region is None:
        return None
    url = f"{ENSEMBL_REST}/sequence/region/{species}/{region}"
    try:
        res = requests.get(
            url,
            headers={"Content-Type": "text/plain"},
            timeout=_ENSEMBL_TIMEOUT,
        )
    except requests.RequestException:
        return None
    if not res.ok:
        return None
    return res.text

def _best_fimo_scores(tsv_path):
    """Return {motif_id: highest score} parsed from a FIMO ``fimo.tsv`` file.

    Comment lines (starting with "#"), the header line and rows with fewer
    than nine tab-separated fields are skipped. The motif ID is read from
    the first column and the score from the seventh.
    """
    best = {}
    with open(tsv_path) as f:
        for line in f:
            if line.startswith("#") or "sequence_name" in line:
                continue
            parts = line.strip().split("\t")
            if len(parts) < 9:
                continue
            motif_id = parts[0]
            score = float(parts[6])
            # Compare against the stored value only when there is one:
            # scores may be negative, so 0 must not be used as a floor.
            if motif_id not in best or score > best[motif_id]:
                best[motif_id] = score
    return best


def fimo_predict(sequence, pwm_path="pwm.meme"):
    """Scan a sequence with FIMO (run through Docker) and score each motif.

    The sequence and the motif file are written to a temporary directory
    that is mounted into the ``memesuite/memesuite`` container as ``/data``.
    FIMO is run with ``--thresh 1`` so that matches are not filtered out by
    p-value before the best score per motif is selected.

    Args:
        sequence: Nucleotide sequence; embedded newlines are removed.
        pwm_path: Path to the MEME-format motif file to scan with.

    Returns:
        A dict mapping motif ID to the highest score reported for it, or an
        empty dict when the FIMO run exits with an error or produces no
        ``fimo.tsv``.
    """
    with tempfile.TemporaryDirectory() as tmpdir:
        fasta_path = os.path.join(tmpdir, "input.fasta")
        pwm_target = os.path.join(tmpdir, "pwm.meme")
        fimo_result = os.path.join(tmpdir, "fimo_out", "fimo.tsv")

        with open(fasta_path, "w") as f:
            f.write(">query_sequence\n")
            f.write(sequence.replace("\n", "") + "\n")

        # The container only sees the mounted directory, so copy the motifs in.
        shutil.copyfile(pwm_path, pwm_target)

        try:
            subprocess.run([
                "docker", "run", "--rm",
                "-v", f"{tmpdir}:/data",
                "memesuite/memesuite",
                "fimo",
                "--thresh", "1",
                "--oc", "/data/fimo_out",
                "/data/pwm.meme",
                "/data/input.fasta"
            ], check=True)
        except subprocess.CalledProcessError:
            return {}

        if not os.path.exists(fimo_result):
            return {}

        # Parse before leaving the ``with`` block: the directory is deleted on exit.
        return _best_fimo_scores(fimo_result)

def compare_predictions(human_preds, mouse_preds):
    """Rank motifs by the absolute difference between two score dicts.

    A motif missing from one dict is treated as having score 0 there.

    Args:
        human_preds: Mapping of motif ID to score for the first species.
        mouse_preds: Mapping of motif ID to score for the second species.

    Returns:
        A dict of motif ID to ``|human - mouse|``, ordered from the largest
        difference to the smallest. Equal differences are ordered by motif
        ID so the ranking is reproducible from run to run.
    """
    all_tfs = set(human_preds) | set(mouse_preds)
    diff = {tf: abs(human_preds.get(tf, 0) - mouse_preds.get(tf, 0)) for tf in all_tfs}
    # Negate the difference for descending order; the ID breaks ties.
    return dict(sorted(diff.items(), key=lambda item: (-item[1], item[0])))

def parse_pwm_names(pwm_path):
    """Map motif IDs to names using the ``MOTIF`` lines of a motif file.

    Every line beginning with "MOTIF" is split on whitespace; when it has
    at least three fields, the second field becomes the key and the third
    the value. Lines with fewer fields are ignored.

    Args:
        pwm_path: Path to the motif file to read.

    Returns:
        A dict of motif ID to name.
    """
    names = {}
    with open(pwm_path, "r") as handle:
        for raw in handle:
            if not raw.startswith("MOTIF"):
                continue
            fields = raw.strip().split()
            if len(fields) < 3:
                continue
            names[fields[1]] = fields[2]
    return names

def plot_tf_differences(differences, id_to_name=None):
    """Draw a bar chart of the first ten entries of ``differences``.

    Args:
        differences: Mapping of motif ID to score difference; the first ten
            entries in iteration order are plotted.
        id_to_name: Optional mapping of motif ID to a display name. IDs
            without an entry, or all IDs when the mapping is empty or None,
            are labelled with the ID itself.
    """
    tfs = list(differences)[:10]
    values = [differences[tf] for tf in tfs]
    labels = [id_to_name.get(tf, tf) for tf in tfs] if id_to_name else tfs
    # Plot on integer positions and label the ticks separately: passing the
    # labels as x values would place motifs that share a name on the same
    # position, hiding bars behind each other.
    positions = list(range(len(tfs)))
    plt.figure(figsize=(10, 5))
    plt.bar(positions, values, color='purple')
    plt.title("TF-различия между человеком и мышью")
    plt.ylabel("|human - mouse| score")
    plt.xticks(positions, labels, rotation=45)
    plt.tight_layout()
    plt.show()

# ======== Interface ========
# Checkbox that switches between the predefined gene list and free-text entry.
use_custom_gene = widgets.Checkbox(
    value=False,
    description="Ввести вручную",
    indent=False
)

# Dropdown with the predefined gene symbols (shown by default).
gene_dropdown = widgets.Dropdown(
    options=AVAILABLE_GENES,
    description="Ген:"
)

# Free-text gene entry; starts hidden and is revealed by the checkbox.
gene_text = widgets.Text(
    placeholder="Введите название гена",
    description="Ген:",
    disabled=False,
    layout=widgets.Layout(width='250px', display='none')
)


def toggle_gene_input(change):
    """Show the text box when the checkbox is ticked, otherwise the dropdown.

    Args:
        change: Trait-change dict from the checkbox; ``change['new']`` is
            the new checkbox value.
    """
    manual = bool(change['new'])
    gene_dropdown.layout.display = 'none' if manual else ''
    gene_text.layout.display = '' if manual else 'none'


use_custom_gene.observe(toggle_gene_input, names='value')

# The layout is built once here; previously a first layout was created in the
# constructor and then immediately replaced by a second one.
run_button = widgets.Button(
    description="Запустить анализ",
    style={'button_color': '#A9A9A9'},
    layout=widgets.Layout(margin='10px 0 20px 0', width='200px')
)

# Fixed-height gap between the gene selector and the run button.
spacer = widgets.HTML(value="<div style='height:15px'></div>")

# Area that receives everything printed or plotted by the analysis.
output = widgets.Output()

# ======== Analysis entry point ========
def on_run_clicked(b):
    """Run the human-vs-mouse promoter comparison for the selected gene.

    Reads the gene from the text box or the dropdown (depending on the
    checkbox), fetches both promoter sequences, scores them with FIMO,
    then prints the three largest differences, draws the bar chart and
    prints a table for the top ten. All output goes to the ``output``
    widget, which is cleared first.

    Args:
        b: The button instance passed by ``on_click``; not used.
    """
    output.clear_output()
    with output:
        gene = gene_text.value.strip() if use_custom_gene.value else gene_dropdown.value
        if not gene:
            print("Пожалуйста, введите название гена.")
            return

        print(f"Анализ гена {gene}...")
        id_to_name = parse_pwm_names(PWM_PATH)

        print("Загрузка последовательностей...")
        seq_human = get_promoter_sequence(gene, SPECIES["human"])
        time.sleep(0.2)  # short pause between the two Ensembl requests
        seq_mouse = get_promoter_sequence(gene, SPECIES["mouse"])

        if not seq_human or not seq_mouse:
            print("Ошибка получения промоторных последовательностей.")
            return

        print("Предсказание связывания TF...")
        preds_human = fimo_predict(seq_human, pwm_path=PWM_PATH)
        preds_mouse = fimo_predict(seq_mouse, pwm_path=PWM_PATH)

        # An empty result means the FIMO run failed or found nothing; comparing
        # against it would present one species' raw scores as "differences".
        if not preds_human or not preds_mouse:
            print("Ошибка: FIMO не вернул результатов (проверьте Docker и файл мотивов).")
            return

        print("Сравнение паттернов...")
        differences = compare_predictions(preds_human, preds_mouse)

        print("Топ-3 TF с наибольшими различиями:")
        top_diffs = list(differences.items())[:10]
        for tf, diff in top_diffs[:3]:
            name = id_to_name.get(tf, tf)
            print(f"  {name} ({tf}): |Δ| = {diff:.2f}")

        plot_tf_differences(differences, id_to_name)

        print("\nСравнение скоров для топ-10 TF:")
        print("{:<20} {:>10} {:>10} {:>10}".format("TF", "Human", "Mouse", "|Δ|"))
        print("-" * 52)
        for tf, diff in top_diffs:
            name = id_to_name.get(tf, tf)
            h_score = preds_human.get(tf, 0)
            m_score = preds_mouse.get(tf, 0)
            print("{:<20} {:>10.2f} {:>10.2f} {:>10.2f}".format(name, h_score, m_score, diff))

# Run the analysis each time the button is clicked.
run_button.on_click(on_run_clicked)

# Render the controls in order, followed by the area that shows the results.
display(use_custom_gene, gene_dropdown, gene_text, spacer, run_button, output)


Checkbox(value=False, description='Ввести вручную', indent=False)

Dropdown(description='Ген:', options=('CGA', 'FSHB', 'FSHR', 'GNRH1', 'GNRHR', 'FOXL2', 'NR5A1', 'PITX1', 'LHX…

Text(value='', description='Ген:', layout=Layout(display='none', width='250px'), placeholder='Введите название…

HTML(value="<div style='height:15px'></div>")

Button(description='Запустить анализ', layout=Layout(margin='10px 0 20px 0', width='200px'), style=ButtonStyle…

Output()