# 性特異的な遺伝子モジュールを探索する

# セットアップ

In [None]:
# Move up to the top directory of the repository (the one containing LICENSE)
import os
from pathlib import Path

print(os.getcwd())

# Look for LICENSE in the current directory and then in each ancestor.
# The search is bounded by the filesystem root: a bare
# `while not found: os.chdir('../')` loop never terminates when no ancestor
# contains LICENSE, because chdir("..") at the root is a no-op.
start_dir = Path.cwd()
for candidate in (start_dir, *start_dir.parents):
    if (candidate / "LICENSE").exists():
        os.chdir(candidate)
        break
else:
    # The working directory is left untouched when the search fails.
    raise FileNotFoundError(
        f"LICENSE not found in {start_dir} or any of its parent directories"
    )

print(os.getcwd())

In [None]:
from pathlib import Path
from pprint import pprint
from collections import defaultdict, Counter
from itertools import combinations
import csv
import numpy as np
import pandas as pd
import polars as pl
from matplotlib import pyplot as plt
import seaborn as sns
import networkx as nx

# Short aliases for interactive use.
P = print  # P(...) is the built-in print
PP = pprint  # PP(...) is pprint.pprint
C = Counter  # C(...) is collections.Counter


# 実験

In [None]:
import json
import pandas as pd
# Load the TSUMUGI table. The "List of shared phenotypes" column is stored as
# JSON text, so it is decoded while reading.
# Author's note: this takes about 45 seconds.
df_tsumugi = pd.read_csv(
    "data/TSUMUGI_raw_data.csv.gz",
    converters={"List of shared phenotypes": json.loads},
)

In [None]:
# Keep rows whose Jaccard similarity exceeds 0.1 and that share more than
# two phenotypes.
df_tsumugi_filtered = df_tsumugi[
    df_tsumugi["Jaccard Similarity"].gt(0.1)
    & df_tsumugi["Number of shared phenotype"].gt(2)
]

In [None]:
# Show the filtered table as the cell output.
df_tsumugi_filtered

In [None]:
import pandas as pd
import re

# Expand the phenotype lists so that each shared phenotype gets its own row,
# then discard rows whose phenotype entry is missing.
df_long = (
    df_tsumugi_filtered
    .explode("List of shared phenotypes")
    .dropna(subset=["List of shared phenotypes"])
)

def extract_phenotype_and_sex(entry):
    """Split a shared-phenotype annotation into phenotype name and sex.

    An entry has the form ``"<phenotype> (<tag>, <tag>, ...)"``. The trailing
    parenthesised group is split on commas and the first tag equal to
    ``"Male"`` or ``"Female"`` is reported.

    Args:
        entry: Annotation string. Leading/trailing whitespace is ignored.

    Returns:
        ``{"Phenotype": <name>, "Sex": "Male" | "Female"}``, or ``None`` when
        the entry is not a string, has no trailing ``(...)`` group, or the
        group carries no sex tag.
    """
    if not isinstance(entry, str):
        return None

    # Anchor on the LAST parenthesised group (which may not itself contain
    # parentheses). A non-greedy split at the first " (" would truncate
    # phenotype names that contain parentheses of their own.
    match = re.match(r"^(.*) \(([^()]*)\)$", entry.strip())
    if match is None:
        return None

    phenotype_part, annotation_part = match.groups()
    for tag in (part.strip() for part in annotation_part.split(",")):
        if tag in {"Male", "Female"}:
            return {"Phenotype": phenotype_part.strip(), "Sex": tag}
    return None

# Parse every annotation; entries without a Male/Female tag yield None.
extracted = df_long["List of shared phenotypes"].apply(extract_phenotype_and_sex)

# Keep only the rows that parsed successfully. Both pieces are numbered from
# zero so the column-wise concat pairs them row by row.
sex_rows = df_long[extracted.notnull()].reset_index(drop=True)
parsed_columns = pd.DataFrame(extracted.dropna().tolist())
df_extracted = pd.concat([sex_rows, parsed_columns], axis=1)

# Final shape: one row per gene pair / phenotype / sex.
df_result = df_extracted[["Gene1", "Gene2", "Phenotype", "Sex"]]

# Quick look at the first rows.
print(df_result.head())


In [None]:
# def extract_phenotype_and_sex(entry):
#     entry = entry.strip()  # ← 空白や改行を除去
#     match = re.match(r"^(.*?)\s*\((.*?)\)$", entry)
#     if match:
#         phenotype_part, annotation_part = match.groups()
#         parts = [p.strip() for p in annotation_part.split(",")]
#         for p in parts:
#             if p in {"Male", "Female"}:
#                 return {"Phenotype": phenotype_part.strip(), "Sex": p}
#     return None

# entry = "increased total body fat amount (Homo, Female, Early)"
# extract_phenotype_and_sex(entry)

In [None]:
# Count the rows per Sex value.
df_result["Sex"].value_counts()

In [None]:
# Gene modules keyed by (phenotype, sex); keys without any module are omitted.
modules = {}

# Build one graph per phenotype/sex combination from its gene pairs.
for key, pair_rows in df_result.groupby(['Phenotype', 'Sex']):
    graph = nx.Graph()
    graph.add_edges_from(pair_rows[['Gene1', 'Gene2']].values)

    # Each connected component with more than one gene is a module.
    components = [
        sorted(nodes)
        for nodes in nx.connected_components(graph)
        if len(nodes) > 1
    ]
    if components:
        modules[key] = components
# # 結果の表示
# for (phenotype, sex), comps in modules.items():
#     print(f"\n[Phenotype: {phenotype}, Sex: {sex}]")
#     for i, comp in enumerate(comps, 1):
#         print(f"  Module {i}: {comp}")

In [None]:
# Number of (phenotype, sex) keys in the modules mapping.
len(modules)

In [None]:
# Print the whole modules mapping.
P(modules)

In [None]:
# Preview: print the modules of the first (phenotype, sex) key only.
if modules:
    (phenotype, sex), comps = next(iter(modules.items()))
    print(f"\n[Phenotype: {phenotype}, Sex: {sex}]")
    for i, comp in enumerate(comps, 1):
        print(f"  Module {i}: {comp}")

## ✅ 目的

各表現型における雌雄特異的な遺伝子モジュールの数をプロットする

In [None]:
from plotnine import ggplot, aes, geom_bar, labs, theme, element_text, position_stack, coord_flip, ggsave


In [None]:
# Input: the `modules` mapping built in an earlier cell.
# Tally, for every phenotype, how many modules each sex contributes.
phenotype_module_counts = defaultdict(lambda: {'Male': 0, 'Female': 0})
for (phenotype, sex), mod_lists in modules.items():
    phenotype_module_counts[phenotype][sex] += len(mod_lists)

# Wide table (one row per phenotype), then long form (one row per
# phenotype/sex combination).
df = (
    pd.DataFrame.from_dict(phenotype_module_counts, orient='index')
    .fillna(0)
    .reset_index()
    .rename(columns={'index': 'Phenotype'})
)
df_long = pd.melt(df, id_vars='Phenotype', var_name='Sex', value_name='Module Count')

# Attach each phenotype's summed module count and order the rows by it.
df_long['Total'] = df_long.groupby('Phenotype')['Module Count'].transform('sum')
df_sorted = df_long.sort_values(by='Total', ascending=False)

# Restrict to the ten phenotypes with the largest totals.
top10_phenotypes = df_sorted.groupby('Phenotype')['Total'].max().nlargest(10).index
df_top10 = df_sorted[df_sorted['Phenotype'].isin(top10_phenotypes)]

# Horizontal stacked bars (ggplot2-style, via plotnine), one bar per phenotype.
bar_mapping = aes(x='reorder(Phenotype, Total)', y='Module Count', fill='Sex')
bar_labels = labs(
    x='Phenotype',
    y='Number of Gene Modules',
    title='Gene Modules per Phenotype by Sex',
)
plot = (
    ggplot(df_top10, bar_mapping)
    + geom_bar(stat='identity', position=position_stack())
    + coord_flip()
    + bar_labels
    + theme(axis_text_y=element_text(ha='right'))
)
ggsave(plot, filename="notebooks/data/number_of_sex_specific_modules.svg")

In [None]:
# Render the figure as the cell output.
plot

## ✅ 目的

以下の条件をすべて満たすペアを抽出したい：

* 同じ遺伝子群の一部（≧3遺伝子）を含む2つのモジュール間で、
* 片方がMale、もう片方がFemale
* 表現型（Phenotype）が異なる

つまり：

“同じ遺伝子群が、性別によって異なる表現型に関与している” ことを示唆する遺伝子モジュールペアを抽出したい。

In [None]:
from itertools import combinations
from collections import defaultdict

def find_sex_diff_phenotype_modules(modules, min_shared_genes=3):
    """Find module pairs that share genes across sexes and phenotypes.

    Args:
        modules: Mapping of ``(phenotype, sex)`` to a list of gene lists.
        min_shared_genes: Minimum number of genes two modules must share.

    Returns:
        A list of dicts with keys ``'shared_genes'``, ``'module1'`` and
        ``'module2'``. A pair is reported only when the two modules differ in
        both sex and phenotype. Pairs appear in the order produced by
        ``itertools.combinations`` over the flattened modules.
    """
    # One (phenotype, sex, gene set) record per module.
    flat_modules = [
        (phenotype, sex, set(genes))
        for (phenotype, sex), mod_lists in modules.items()
        for genes in mod_lists
    ]

    def _describe(record):
        phenotype, sex, genes = record
        return {'phenotype': phenotype, 'sex': sex, 'genes': sorted(genes)}

    matches = []
    for first, second in combinations(flat_modules, 2):
        differs_in_both = first[1] != second[1] and first[0] != second[0]
        if not differs_in_both:
            continue  # same sex or same phenotype

        overlap = first[2] & second[2]
        if len(overlap) < min_shared_genes:
            continue

        matches.append({
            'shared_genes': sorted(overlap),
            'module1': _describe(first),
            'module2': _describe(second),
        })

    return matches

# Usage example on a tiny hand-made input.
test_modules = {
    ("hoge", "Male"): [["A", "B", "C", "D"]],
    ("fuga", "Female"): [["A", "B", "C"]],
}

interesting_pairs = find_sex_diff_phenotype_modules(test_modules)

# Print each match: the shared genes, then both modules.
for i, pair in enumerate(interesting_pairs, 1):
    print(f"\n=== Match {i} ===")
    print("Shared genes:", pair['shared_genes'])
    for heading, key in (("→ Module 1:", 'module1'), ("→ Module 2:", 'module2')):
        print(heading, pair[key]['phenotype'], "/", pair[key]['sex'])
        print("   Genes:", pair[key]['genes'])


In [None]:
interesting_pairs = find_sex_diff_phenotype_modules(modules)

# Print every match except those where either module's phenotype is
# "abnormal behavior". Match numbers follow the position in
# interesting_pairs, so skipped pairs leave gaps in the numbering.
for i, pair in enumerate(interesting_pairs, 1):
    module_phenotypes = (pair['module1']['phenotype'], pair['module2']['phenotype'])
    if "abnormal behavior" in module_phenotypes:
        continue
    print(f"\n=== Match {i} ===")
    print("Shared genes:", pair['shared_genes'])
    for heading, key in (("→ Module 1:", 'module1'), ("→ Module 2:", 'module2')):
        print(heading, pair[key]['phenotype'], "/", pair[key]['sex'])
        print("   Genes:", pair[key]['genes'])

In [None]:
# Number of module pairs returned by find_sex_diff_phenotype_modules.
len(interesting_pairs)

In [None]:
# Show every pair record as the cell output.
interesting_pairs

In [None]:
# Same TSUMUGI file, read with polars; the phenotype-list column is decoded
# from JSON text in place (the expression keeps the column's name).
df = pl.read_csv("data/TSUMUGI_raw_data.csv.gz").with_columns(
    pl.col("List of shared phenotypes").str.json_decode()
)

In [None]:
# Show the polars DataFrame as the cell output.
df

In [None]:
# Data release number; it is interpolated into the statistical-results file
# name when the CSV path is built.
# NOTE(review): stored as a float, so a release such as 22.10 would format
# as "22.1" — a string may be safer; confirm before changing.
RELEASE = 22.1

In [None]:
# Read the statistical-results CSV of the selected release.
# Author's note: this takes about 30 seconds.
path_data = Path("data") / "impc" / f"statistical-results-ALL-{RELEASE}.csv"
data = pd.read_csv(path_data)

In [None]:
# Columns needed for the sex-specificity analysis.
columns = [
    "marker_symbol",
    "mp_term_name",
    "p_value",
    "sex_effect_p_value",
    "female_ko_effect_p_value",
    "male_ko_effect_p_value",
    "zygosity",
    "effect_size",
]

data = data[columns]

# Significance masks at p < 0.0001: overall, female-only and male-only.
threshold = 0.0001
filter_pvalue = data["p_value"] < threshold
filter_female_ko_pvalue = data["female_ko_effect_p_value"] < threshold
filter_male_ko_pvalue = data["male_ko_effect_p_value"] < threshold

# Keep a row when ANY of the three tests is significant. The female mask must
# be part of the union; otherwise results that are significant only in
# females are silently discarded.
data_filtered = data[filter_pvalue | filter_female_ko_pvalue | filter_male_ko_pvalue]

# Drop rows without a phenotype term name or without an effect size.
data_filtered = data_filtered.dropna(subset=["mp_term_name", "effect_size"])

data_filtered

In [None]:
data_annotated = data_filtered.copy()

threshold = 0.0001

# Building blocks for the two sex-specific patterns.
sex_effect_sig = data_annotated["sex_effect_p_value"] < threshold
female_sig = data_annotated["female_ko_effect_p_value"] < threshold
male_sig = data_annotated["male_ko_effect_p_value"] < threshold
female_not_sig = data_annotated["female_ko_effect_p_value"] > threshold
male_not_sig = data_annotated["male_ko_effect_p_value"] > threshold

# A row is labelled with a sex when the sex effect is significant, the
# knockout effect is significant in that sex, and it is above the threshold
# in the other sex.
conditions = [
    sex_effect_sig & female_sig & male_not_sig,
    sex_effect_sig & male_sig & female_not_sig,
]

# Label for each condition, in the same order.
choices = ["female", "male"]

# Rows matching neither condition get None.
data_annotated["sex"] = np.select(conditions, choices, default=None)
data_annotated = data_annotated.reset_index(drop=True)

# Report the release and the number of rows per label.
print(RELEASE)
print(data_annotated["sex"].value_counts())

In [None]:
# Show the annotated table as the cell output.
data_annotated

In [None]:
# Narrow the table to gene symbol, phenotype name and the sex label.
columns_sex = ["marker_symbol", "mp_term_name", "sex"]
data_sex = data_annotated.loc[:, columns_sex]
P(len(data_sex))

In [None]:
# Keep only the rows that received a sex label.
data_sex_filtered = data_sex.dropna(subset=["sex"])
P(len(data_sex_filtered))

In [None]:
# Show the sex-labelled rows as the cell output.
data_sex_filtered

In [None]:
from itertools import combinations
import networkx as nx

df = data_sex_filtered.copy()

# Gene modules keyed by (phenotype, sex).
modules = defaultdict(list)

# One module per phenotype/sex combination.
for (phenotype, sex), group in df.groupby(['mp_term_name', 'sex']):
    # Rows without a gene symbol cannot contribute to a module (and a missing
    # value cannot be sorted together with strings).
    genes = group['marker_symbol'].dropna().unique()

    if len(genes) < 2:
        continue  # a single gene does not form a module

    # All genes sharing the phenotype/sex form one module. A complete graph
    # over them always has exactly one connected component, so building it
    # would only add quadratic cost without changing the result.
    modules[(phenotype, sex)].append(sorted(genes))

# Print every module.
for (phenotype, sex), comps in modules.items():
    print(f"\n[Phenotype: {phenotype}, Sex: {sex}]")
    for i, comp in enumerate(comps, 1):
        print(f"  Module {i}: {comp}")