#### 前置作業 Preliminary work
##### 匯入相關模組與設定 Import related modules and settings

In [None]:
import time
import os
import numpy as np
import pandas as pd
from scipy.sparse.linalg import svds
from pybiomart import Server
import pyarrow as pa
import plotly.express as px

# Progress banner for the seven-step pipeline.
print(" [i] Start analyzing, total 7 steps...")
print(" [i] Setting up...")

# Environment tweaks applied in one go:
#   * QT_QPA_PLATFORM: force Qt to use X11 (xcb) on Linux.
#   * OPENBLAS/MKL_NUM_THREADS: cap the BLAS thread pools at one thread
#     (the original note gives limiting NumPy's resource usage as the reason).
# NOTE(review): numpy is already imported above, so the thread caps may be
# read too late to take effect -- confirm, and move them before the imports
# if they matter.
os.environ.update(
    {
        "QT_QPA_PLATFORM": "xcb",
        "OPENBLAS_NUM_THREADS": "1",
        "MKL_NUM_THREADS": "1",
    }
)

# Fixed seeds for reproducibility.
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so
# setting it here presumably only affects child processes -- confirm.
os.environ["PYTHONHASHSEED"] = "42"
np.random.seed(42)

#### 連線到ensembl資料庫擷取基因資料，分批次讀取gct檔案 
##### Connect to the Ensembl database to retrieve gene data and read the GCT file in chunks


In [None]:
# HGNC symbols that start with RPL/RPS but are excluded from the result
# (presumably read-through fusions and S6-kinase family genes -- confirm).
_EXCLUDED_RP_SYMBOLS = (
    "RPL17-C18orf32", "RPL36A-HNRNPH2",
    "RPS6KA1", "RPS6KA5", "RPS19BP1", "RPS6KL1", "RPS6KA4", "RPS6KA3",
    "RPS6KB2", "RPS6KA6", "RPS6KC1", "RPS10-NUDT3", "RPS6KA2", "RPS6KB1",
)


def _select_ribosomal_protein_genes(genes_df):
    """Reduce a BioMart gene table to the ribosomal protein genes of interest.

    Expects the columns 'Gene stable ID', 'HGNC symbol' and 'Gene type'.
    """
    symbol_col = "HGNC symbol"

    # Rows without an HGNC symbol cannot be matched by name.
    genes_df = genes_df.dropna(subset=[symbol_col])

    # FAU and UBA52 are kept explicitly; they do not match the RPL/RPS prefix.
    extra_genes = genes_df[genes_df[symbol_col].isin(["FAU", "UBA52"])]

    # Protein-coding genes whose symbol starts with RPL or RPS.
    has_rp_prefix = genes_df[symbol_col].str.startswith(("RPL", "RPS"), na=False)
    is_protein_coding = genes_df["Gene type"] == "protein_coding"
    ribosomal = genes_df[has_rp_prefix & is_protein_coding]

    # Keep one row (the first) per HGNC symbol. This also removes fully
    # duplicated rows, so no separate full-row de-duplication is needed.
    combined = pd.concat([ribosomal, extra_genes])
    combined = combined.drop_duplicates(subset=[symbol_col], keep="first")

    return combined[~combined[symbol_col].isin(_EXCLUDED_RP_SYMBOLS)]


def load_hgnc_protein_coding(genes_df=None):
    """Return a dict mapping Ensembl gene IDs to HGNC symbols for ribosomal protein genes.

    Args:
        genes_df: Optional pre-fetched gene table with the columns
            'Gene stable ID', 'HGNC symbol' and 'Gene type'. When omitted
            (the default), the table is queried from Ensembl BioMart through
            pybiomart, which requires network access.

    Returns:
        dict of {Ensembl gene ID: HGNC symbol} containing the protein-coding
        RPL*/RPS* genes plus FAU and UBA52, one entry per symbol, minus the
        symbols listed in ``_EXCLUDED_RP_SYMBOLS``.
    """
    if genes_df is None:
        print(" [i] Querying HGNC protein coding genes via pybiomart...")
        # Connect to the server and select the human gene dataset.
        server = Server(host='http://www.ensembl.org')
        dataset = server.marts['ENSEMBL_MART_ENSEMBL'].datasets['hsapiens_gene_ensembl']

        # Fetch Ensembl ID, HGNC symbol and biotype for every gene.
        genes_df = dataset.query(attributes=['ensembl_gene_id', 'hgnc_symbol', 'gene_biotype'])

    genes_df = _select_ribosomal_protein_genes(genes_df)
    print(genes_df)

    return dict(zip(genes_df['Gene stable ID'], genes_df['HGNC symbol']))

# Build the Ensembl-ID -> HGNC-symbol lookup once at module level;
# load_gct reads this mapping when it filters GCT rows.
ensembl_to_symbol = load_hgnc_protein_coding()
_rp_gene_count = len(ensembl_to_symbol)
print(f" [i] Ribosomal Protein gene count: {_rp_gene_count}")

def load_gct(filepath, gene_ids=None, chunk_size=1000):
    """Stream a GCT file and keep only the rows for the requested genes.

    The file is read in chunks so that only the matching rows are held in
    memory; the header is detected by reading just the first three lines.

    Args:
        filepath: Path to the tab-separated ``.gct`` file. The column header
            is expected on the third line.
        gene_ids: Iterable (or dict keyed by) Ensembl gene IDs without the
            version suffix. Defaults to the module-level ``ensembl_to_symbol``.
        chunk_size: Number of rows pandas reads per chunk.

    Returns:
        DataFrame with one row per sample and one column per matched gene
        (the transposed GCT matrix, 'Description' column removed).

    Raises:
        ValueError: If the file has fewer than three lines or no row matches
            any of the requested gene IDs.
    """
    if gene_ids is None:
        gene_ids = ensembl_to_symbol
    wanted_ids = set(gene_ids)

    print(" [1] Detecting column names...")
    with open(filepath, "r") as f:
        # Skip the two leading lines; only the third (header) line is needed,
        # so the rest of the file is never loaded here.
        f.readline()
        f.readline()
        header_line = f.readline()
    if not header_line:
        raise ValueError(f"GCT file has no header line (expected on line 3): {filepath}")
    gene_id_column = header_line.strip().split("\t")[0]
    print(f" [2] Detected Gene_ID column: {gene_id_column}")

    print(" [3] Streaming GCT data...")
    print(" [i] This may take a few minutes, please wait...")
    kept_chunks = []
    with pd.read_csv(filepath, sep="\t", skiprows=2, chunksize=chunk_size) as reader:
        for chunk in reader:
            chunk = chunk.rename(columns={gene_id_column: "Gene_ID"})
            chunk["Gene_ID"] = chunk["Gene_ID"].astype(str)
            chunk = chunk.set_index("Gene_ID")

            # Strip the version suffix (".X") so IDs match the lookup keys.
            chunk.index = chunk.index.str.split('.').str[0]

            matched = chunk[chunk.index.isin(wanted_ids)]
            if not matched.empty:
                kept_chunks.append(matched)

    print(" [4] Merging processed data...")
    if not kept_chunks:
        raise ValueError(f"No rows in {filepath} matched the requested gene IDs")
    # pd.concat upcasts when chunks inferred different numeric dtypes for a
    # column, so per-chunk dtype inference cannot break the merge.
    final_table = pd.concat(kept_chunks)
    if "Description" in final_table.columns:
        final_table = final_table.drop(columns=["Description"])

    # Transpose so that samples are rows and genes are columns.
    return final_table.T

#### 讀取組織對應元資料
##### Read metadata of corresponding tissues

In [None]:
def load_metadata_labels(metadata_path, df, target_tissues=None):
    """Load GTEx sample annotations and derive a tissue label per sample.

    Args:
        metadata_path: Path to the tab-separated sample-attributes file; only
            the 'SAMPID' and 'SMTSD' columns are read.
        df: pandas Index (or Series) of sample IDs, e.g. the row index of the
            matrix returned by ``load_gct``.
        target_tissues: Optional collection of tissue names to keep; every
            other tissue is relabelled "Other".

    Returns:
        (tissue_labels, metadata_filtered):
            tissue_labels holds the 'SMTSD' value of each sample in ``df``
            that has metadata, in the order of ``df``. It is a pandas object
            when ``target_tissues`` is None and a plain list otherwise.
            metadata_filtered is the metadata restricted to those samples,
            indexed by 'SAMPID', with an added 'GTEX_ID' column (the first
            two dash-separated fields of the sample ID).
    """
    print(" [i] Reading tissue metadata...")
    metadata_df = pd.read_csv(metadata_path, sep="\t", low_memory=False, usecols=["SAMPID", "SMTSD"])
    # na=False keeps rows with a missing SAMPID from breaking the boolean mask.
    metadata_df = metadata_df[metadata_df["SAMPID"].str.startswith("GTEX-", na=False)]
    metadata_df = metadata_df.set_index("SAMPID")

    # Only look up samples that actually have metadata; unknown IDs are
    # skipped instead of raising KeyError.
    known_samples = df[df.isin(metadata_df.index)]
    metadata_filtered = metadata_df.loc[known_samples].copy()
    metadata_filtered["GTEX_ID"] = metadata_filtered.index.str.split("-").str[:2].str.join("-")

    # Map by full sample ID. Mapping by the GTEX_ID prefix would give every
    # sample sharing that prefix the same (last seen) tissue.
    tissue_by_sample = metadata_filtered["SMTSD"].to_dict()
    tissue_labels = df.map(tissue_by_sample).dropna()

    if target_tissues is not None:
        tissue_labels = [label if label in target_tissues else "Other" for label in tissue_labels]

    return tissue_labels, metadata_filtered

#### 將元資料中多餘的資料截除，只留下有匹配的樣本資訊
##### Remove redundant data from the metadata, leaving only matching sample information

In [None]:
def get_matched_samples(df_gct, df_metadata, tissue_column="SMTSD"):
    """Restrict the expression matrix and the metadata to their shared samples.

    Args:
        df_gct: Expression DataFrame whose columns are sample IDs.
        df_metadata: Metadata DataFrame indexed by sample ID.
        tissue_column: Metadata column used for the tissue summary printout.

    Returns:
        (filtered_df, filtered_metadata): ``df_gct`` limited to the shared
        sample columns and a copy of ``df_metadata`` limited to the same
        samples, both in the column order of ``df_gct``.
    """
    gct_sample_ids = df_gct.columns.astype(str)
    metadata_sample_ids = df_metadata.index.astype(str)

    # Intersect while preserving the GCT column order. A plain set
    # intersection yields a hash-dependent order, which would make the
    # sample order (and everything downstream) differ between runs.
    metadata_lookup = set(metadata_sample_ids)
    common_samples = [
        sample_id
        for sample_id in dict.fromkeys(gct_sample_ids)
        if sample_id in metadata_lookup
    ]

    print(f" [i] GCT 樣本數：{len(gct_sample_ids)}")
    print(f" [i] Metadata 樣本數：{len(metadata_sample_ids)}")
    print(f" [i] GCT 與 Metadata 交集樣本數：{len(common_samples)}")

    # Filter both the GCT matrix and the metadata to the shared samples.
    filtered_df = df_gct[common_samples]
    filtered_metadata = df_metadata.loc[common_samples].copy()

    print(f" [i] 可用的組織種類數（{tissue_column}）：{filtered_metadata[tissue_column].nunique()}")
    print(f" [i] 可用組織列表：\n{filtered_metadata[tissue_column].value_counts().sort_values(ascending=False)}")

    return filtered_df, filtered_metadata


#### 將gct檔案樣本資訊與元資料對齊
##### Align gct file sample information with metadata

In [None]:
def align_metadata_and_gct(filtered_df, filtered_metadata):
    """Put the expression columns and the metadata rows in the same sample order.

    Args:
        filtered_df: Expression DataFrame whose columns are sample IDs.
        filtered_metadata: Metadata DataFrame indexed by sample ID.

    Returns:
        (gct_matched, metadata_matched): the expression columns and metadata
        rows for the samples present in both, ordered by the metadata.
    """
    # Keep only metadata rows whose sample ID is also an expression column.
    shared_ids = filtered_metadata.index.intersection(filtered_df.columns)
    metadata_matched = filtered_metadata.loc[shared_ids]

    # Then pick exactly those samples (columns) from the expression matrix.
    gct_matched = filtered_df.loc[:, metadata_matched.index]

    # Report whether both sides are now in one-to-one, same-order alignment.
    same_order = bool((metadata_matched.index == gct_matched.columns).all())
    print(" [!] 是否完全對齊？", same_order)
    print(" [!] 樣本數：", len(metadata_matched))

    return gct_matched, metadata_matched

#### 使用sklearn預處理數據
##### Preprocessing data using sklearn

In [None]:
def preprocess_data(df):
    """Scale the numeric columns of *df* with sklearn's StandardScaler.

    Args:
        df: DataFrame to scale; non-numeric columns are ignored.

    Returns:
        (data_scaled, shape): the scaled array returned by
        ``StandardScaler.fit_transform`` and its shape tuple.
    """
    print(" [5] Pre-processing data...")
    from sklearn.preprocessing import StandardScaler

    # Only numeric columns can be scaled.
    numeric_only = df.select_dtypes(include=[np.number])
    standardized = StandardScaler().fit_transform(numeric_only)
    return standardized, standardized.shape

#### 使用t-SNE方法將資料降為二維
##### Use t-SNE to reduce the data to two dimensions

In [None]:
# t-SNE dimensionality reduction
def perform_tsne_np(data, perplexity=30, learning_rate=200,
                    n_components=2, max_iter=2000, init="pca", random_state=42):
    """Embed *data* with sklearn's TSNE and return the embedding.

    All keyword arguments are forwarded unchanged to ``sklearn.manifold.TSNE``.

    Returns:
        The array produced by ``TSNE.fit_transform(data)``.
    """
    print(" [6] Performing t-SNE_NP...")
    from sklearn.manifold import TSNE

    tsne_options = {
        "n_components": n_components,
        "perplexity": perplexity,
        "learning_rate": learning_rate,
        "max_iter": max_iter,
        "init": init,
        "random_state": random_state,
    }
    embedding = TSNE(**tsne_options).fit_transform(data)
    return embedding

#### 使用plotly繪製t-SNE散點圖，製作互動圖、存檔
##### Use plotly to draw the t-SNE scatter plot as an interactive figure and save it to a file

In [None]:
# Draw the t-SNE figure
def plot_tsne(tsne_results, tissue_labels, output_path):
    """Draw the 2-D t-SNE embedding coloured by tissue, show it and save it.

    Args:
        tsne_results: 2-D array; column 0 is plotted on x, column 1 on y.
        tissue_labels: One tissue label per row of ``tsne_results``.
        output_path: File path handed to ``fig.write_image``.
    """
    print(" [7] Drawing t-SNE plot...")
    import seaborn as sns

    # One colour per distinct tissue, assigned in sorted tissue order.
    custom_order = sorted(set(tissue_labels))
    palette = sns.color_palette("hls", len(custom_order))
    color_map = {
        tissue: f"rgb({int(r*255)}, {int(g*255)}, {int(b*255)})"
        for tissue, (r, g, b) in zip(custom_order, palette)
    }
    fig = px.scatter(
        x=tsne_results[:, 0],
        y=tsne_results[:, 1],
        color=tissue_labels,
        color_discrete_map=color_map,
        hover_name=tissue_labels,
        title="t-SNE visualization of GTEx Data",
        template='plotly_white',
        width=2000,
        height=1600,
        # The data is passed as arrays, so plotly names the columns after the
        # arguments ("x", "y", "color"); the labels must use those keys to
        # take effect.
        labels={
            'x': 't-SNE dimension 1',
            'y': 't-SNE dimension 2',
            'color': 'Tissues',
        },
    )
    fig.update_layout(
        legend_title_text="Tissues",
        legend=dict(
            x=1,
            y=1,
            xanchor='left',
            yanchor='top'
        ),
        margin=dict(t=50, b=50, l=50, r=50)  # keep legend and title from being clipped
    )
    fig.show()  # interactive display; comment this out for non-interactive runs
    fig.write_image(output_path)
    print(f" [i] t-SNE figure is saved as: {output_path}")

#### 主執行流程
##### Main process

In [None]:
if __name__ == "__main__":
    # End-to-end run: load expression + metadata, align them, scale,
    # embed with t-SNE, plot, and report the elapsed wall-clock time.
    start = time.time()
    gct_file = r"/home/terry_0714/gct_data/GTEx_Analysis_2022-06-06_v10_RNASeQCv2.4.2_gene_tpm_non_lcm.gct"
    metadata_path = r"/home/terry_0714/gct_data/GTEx_Analysis_v10_Annotations_SampleAttributesDS.tsv"
    output_dir = r"/home/terry_0714/tsne_plot"
    os.makedirs(output_dir, exist_ok=True)

    # Read the GCT file (samples as rows) and the matching metadata.
    gct_df = load_gct(gct_file)
    _, metadata_df = load_metadata_labels(metadata_path, gct_df.index)
    filtered_df, filtered_metadata = get_matched_samples(gct_df.T, metadata_df)

    # Align sample order between expression matrix and metadata.
    gct_matched, metadata_matched = align_metadata_and_gct(filtered_df, filtered_metadata)
    label = metadata_matched["SMTSD"]

    # Pre-process (samples as rows); the returned shape is not needed here.
    processed_data, _ = preprocess_data(gct_matched.T)

    # t-SNE dimensionality reduction.
    tsne_results = perform_tsne_np(processed_data)

    # Plot and save.
    output_image = os.path.join(output_dir, "tsne_plot_p_30_lr_200.png")
    plot_tsne(tsne_results, label, output_image)

    print(" [i] t-SNE analysis is completed.")

    end = time.time()
    elapsed_time = int(end - start)
    hours, remainder = divmod(elapsed_time, 3600)
    minutes, seconds = divmod(remainder, 60)

    print(f" [i] Time spent: {hours} hours, {minutes} minutes and {seconds} seconds.")