In [None]:
import scanpy as sc
import scirpy as ir

import mudata as md
import muon as mu
import awkward as ak

import numpy as np
import math

import torch
import torch.nn as nn
from torch import Tensor
from torch.utils.data import Dataset, DataLoader

from typing import Callable, Union, List, Dict, Tuple
from collections import deque
from scipy.sparse import issparse
import umap
import matplotlib.pyplot as plt
import logging
logging.basicConfig(level=logging.INFO)


def get_data(file_path):
    """Read a 10x VDJ file and a 10x expression matrix into one MuData.

    Args:
        file_path: Indexable with at least two entries. ``file_path[0]`` is
            passed to ``ir.io.read_10x_vdj`` and ``file_path[1]`` to
            ``sc.read_10x_h5``.

    Returns:
        A ``MuData`` holding the expression data under ``"gex"`` and the
        receptor data under ``"airr"``.
    """
    airr = ir.io.read_10x_vdj(file_path[0])
    gex = sc.read_10x_h5(file_path[1])
    # Gene symbols may repeat in the matrix; make them unique before the
    # variable names are used as an index.
    gex.var_names_make_unique()
    modalities = {"gex": gex, "airr": airr}
    return mu.MuData(modalities)


def filter_cell(mdata):
    """Filter the expression modality and restrict it to highly variable genes.

    The ``"gex"`` modality of ``mdata`` is filtered in place: cells with fewer
    than 200 detected genes and genes seen in fewer than 3 cells are removed.
    The top 5000 highly variable genes (``flavor="cell_ranger"``) are then
    selected.

    Args:
        mdata: Container exposing ``"gex"`` and ``"airr"`` modalities.

    Returns:
        A new ``MuData`` whose ``"gex"`` modality is the highly-variable-gene
        subset and whose ``"airr"`` modality is passed through unchanged.
        Cells dropped from ``"gex"`` are *not* dropped from ``"airr"`` here.
    """
    gex = mdata['gex']

    # Basic quality filters on cells, then on genes.
    sc.pp.filter_cells(gex, min_genes=200)
    sc.pp.filter_genes(gex, min_cells=3)

    # NOTE: no normalisation, log1p, scaling, PCA, neighbour graph, UMAP or
    # Leiden clustering is performed in this function.
    sc.pp.highly_variable_genes(gex, flavor="cell_ranger", n_top_genes=5000)

    hv_mask = gex.var.highly_variable
    modalities = {'gex': gex[:, hv_mask], 'airr': mdata['airr']}
    return mu.MuData(modalities)


def _has_required_chain_fields(chain):
    """Tell whether a chain record carries CDR3/V/J values and is productive.

    A missing attribute is treated like ``None`` for ``cdr3_aa``, ``v_call``
    and ``j_call``; a missing ``productive`` attribute counts as productive.
    """
    if getattr(chain, 'cdr3_aa', None) is None:
        return False
    if getattr(chain, 'v_call', None) is None:
        return False
    if getattr(chain, 'j_call', None) is None:
        return False
    return getattr(chain, 'productive', True)


def filter_chain(mdata):
    """Keep only receptor cells with exactly two complete, productive chains.

    Steps:
      1. Run ``ir.tl.chain_qc`` (presumably what fills
         ``obs['chain_pairing']`` -- confirm against the scirpy version in
         use) and keep ``"airr"`` cells labelled ``'single pair'``.
      2. Within each remaining cell, drop chains lacking ``cdr3_aa``,
         ``v_call`` or ``j_call``, or whose ``productive`` flag is falsy.
      3. Keep cells that still have exactly two chains.

    Args:
        mdata: Container exposing ``"gex"`` and ``"airr"`` modalities.

    Returns:
        A ``MuData`` whose ``"airr"`` modality has been filtered as above.
        The ``"gex"`` modality is not subset by this function.
    """
    ir.tl.chain_qc(mdata)
    airr_all = mdata['airr']
    single_pair = airr_all[airr_all.obs['chain_pairing'] == 'single pair']
    mdata = mu.MuData({"gex": mdata['gex'], "airr": single_pair})

    if mdata.is_view:
        mdata = mdata.copy()

    airr = mdata['airr']

    # Rebuild the per-cell chain lists, keeping only usable chains.
    kept_chains = []
    for cell_chains in airr.obsm['airr']:
        kept_chains.append(
            [chain for chain in cell_chains if _has_required_chain_fields(chain)]
        )
    airr.obsm['airr'] = ak.Array(kept_chains)

    # Positional indices of cells that still have a full pair of chains.
    paired_rows = [
        row
        for row, cell_chains in enumerate(airr.obsm['airr'])
        if len(cell_chains) == 2
    ]
    mdata.mod['airr'] = airr[paired_rows]
    return mdata


def align_modality_cells(mdata):
    """Restrict both modalities to the cells they share, in ``"gex"`` order.

    Args:
        mdata: Container exposing ``"gex"`` and ``"airr"`` modalities.

    Returns:
        A new ``MuData`` in which both modalities are indexed by the
        ``"gex"`` cell names that also occur in ``"airr"``.
    """
    gex = mdata['gex']
    airr = mdata['airr']

    # Boolean mask over gex cells, so the shared names keep gex ordering.
    in_airr = gex.obs_names.isin(airr.obs_names)
    common_cells = gex.obs_names[in_airr]

    aligned = {
        'gex': gex[common_cells, :],
        'airr': airr[common_cells, :],
    }
    return md.MuData(aligned)


def check_modality_alignment(mdata):
    """Check that ``"gex"`` and ``"airr"`` list the same cells in the same order.

    A short report is printed in every case: the cell counts, then either the
    reason for failure (with up to five mismatching positions) or a success
    message.

    Args:
        mdata: Mapping-like object whose ``"gex"`` and ``"airr"`` entries
            expose ``obs_names.values`` as 1-D arrays of cell identifiers.

    Returns:
        ``True`` when both modalities have the same number of cells and the
        identifiers match position by position, ``False`` otherwise.
    """
    gex_cells = mdata["gex"].obs_names.values
    airr_cells = mdata["airr"].obs_names.values

    print(f"模态细胞数对比: gex({len(gex_cells)}) vs airr({len(airr_cells)})")

    # The length check must come before any element-wise comparison:
    # comparing arrays of different lengths raises instead of yielding a
    # mask, which would hide the count-mismatch report below.
    if len(gex_cells) != len(airr_cells):
        print("⚠️ 错误：两个模态的细胞数量不一致")
        return False

    mismatch_idx = np.where(gex_cells != airr_cells)[0]
    if mismatch_idx.size:
        print(f"⚠️ 错误：发现{mismatch_idx.size}个位置标识符不匹配")
        print("前5个不匹配样例：")
        for i in mismatch_idx[:5]:
            print(f"位置 {i}: gex='{gex_cells[i]}' vs airr='{airr_cells[i]}'")
        return False

    print("✅ 所有检查通过，模态完全对齐")
    return True
