In [1]:
# import af_analysis
# from af_analysis import analysis
# from af_analysis.data import Data

# def calculate_all_metrics_per_query(query_group_data, **kwargs):
#     """
#     단일 query 그룹에 대해 모든 메트릭을 순차적으로 계산
    
#     Parameters
#     ----------
#     query_group_data : tuple
#         (query_name, query_df) 형태의 데이터
#     kwargs : dict
#         각종 메트릭 계산 옵션들
        
#     Returns
#     -------
#     result_df : pd.DataFrame
#         모든 메트릭이 계산된 query 데이터
#     """
#     query_name, query_df = query_group_data
    
#     try:
#         # # Data 객체 생성
#         temp_data = Data(data_dict={
#             'pdb': query_df['pdb'].tolist(),
#             'query': query_df['query'].tolist(),
#             'data_file': query_df['data_file'].tolist()
#             # ... 기타 필요한 컬럼들
#         })
        
#         # # 모든 메트릭 계산 (순차 실행)
#         # temp_data = (temp_data
#         #             .extract_chain_columns(verbose=False)
#         #             .analyze_chains(verbose=False))
        
#         # # Analysis 함수들
#         # from af_analysis import analysis
#         # analysis.add_interface_metrics(temp_data, verbose=False)
#         # analysis.pdockq(temp_data, verbose=False)
#         # analysis.pdockq2(temp_data, verbose=False)
        
#         # # Data 메서드들
#         # temp_data = (temp_data
#         #             .add_pitm_pis(cutoff=8.0, verbose=False)
#         #             .add_chain_rmsd(align_chain='A', rmsd_chain='H')
#         #             .add_rmsd_scale())
        
#         # Rosetta 메트릭 (multiprocessing 비활성화)
#         temp_data = (temp_data.add_rosetta_metrics(n_jobs=1, verbose=False))
        
#         # DockQ (옵션)
#         if kwargs.get('calculate_dockq', False):
#             temp_data.prep_dockq(native_dir=kwargs.get('native_dir'), verbose=False)
#             analysis.calculate_dockQ(temp_data, 
#                                    rec_chains='A', lig_chains='H',
#                                    native_rec_chains='A', native_lig_chains='H',
#                                    verbose=False)
        
#         return query_name, temp_data.df
        
#     except Exception as e:
#         print(f"Error processing query {query_name}: {str(e)}")
#         return query_name, None


# def calculate_all_metrics_parallel(data_obj, n_jobs=None, **kwargs):
#     """
#     모든 query에 대해 병렬로 메트릭 계산
    
#     Parameters
#     ----------
#     data_obj : Data
#         전체 데이터 객체
#     n_jobs : int, optional
#         사용할 CPU 코어 수. None이면 자동 설정
#     **kwargs : dict
#         메트릭 계산 옵션들
        
#     Returns
#     -------
#     updated_data : Data
#         모든 메트릭이 계산된 데이터 객체
#     """
#     import multiprocessing as mp
#     from functools import partial
#     import pandas as pd
#     from tqdm.auto import tqdm
    
#     # Query별로 데이터 그룹화
#     query_groups = list(data_obj.df.groupby('query'))
    
#     if n_jobs is None:
#         n_jobs = min(mp.cpu_count(), len(query_groups))
    
#     print(f"Processing {len(query_groups)} queries using {n_jobs} CPU cores...")
    
#     # 병렬 처리
#     worker_func = partial(calculate_all_metrics_per_query, **kwargs)
    
#     if n_jobs > 1:
#         ctx = mp.get_context("fork")
#         with ctx.Pool(n_jobs) as pool:
#             results = list(tqdm(
#                 pool.imap(worker_func, query_groups),
#                 total=len(query_groups),
#                 desc="Processing queries"
#             ))
#     else:
#         results = [worker_func(group) for group in tqdm(query_groups, desc="Processing queries")]
    
#     # 결과 통합
#     successful_results = []
#     failed_queries = []
    
#     for query_name, result_df in results:
#         if result_df is not None:
#             successful_results.append(result_df)
#         else:
#             failed_queries.append(query_name)
    
#     if failed_queries:
#         print(f"Failed to process {len(failed_queries)} queries: {failed_queries}")
    
#     # DataFrame 통합
#     if successful_results:
#         combined_df = pd.concat(successful_results, ignore_index=True)
#         data_obj.df = combined_df
#         print(f"Successfully processed {len(successful_results)} queries")
    
#     return data_obj


# # 사용 예시
# def main():
#     # 데이터 로드
#     input_data=Data(directory='/home/cseomoon/project/ABAG/2025_H_L_A/20250504_seeds_10/negative/af3_results/results1/6x97_7o9w')
#     # my_data.df
#     # data = Data(csv="your_data.csv")
    
#     # 병렬 메트릭 계산
#     my_data = calculate_all_metrics_parallel(
#         input_data, 
#         n_jobs=48,  # use 48 CPU cores
#         calculate_dockq=False,
#     )
    
#     # 결과 저장
#     my_data.export_file("results_with_all_metrics_only_rosetta.csv")

# main()

In [8]:
import time
import logging
from contextlib import contextmanager
import af_analysis
from af_analysis import analysis
from af_analysis.data import Data

# Raise the pdb_numpy loggers to WARNING so their INFO-level messages are hidden.
for _logger_name in ("pdb_numpy", "pdb_numpy.coor", "pdb_numpy.analysis"):
    logging.getLogger(_logger_name).setLevel(logging.WARNING)



@contextmanager
def time_tracker(description):
    """Context manager that prints how long the wrapped block took.

    Prints a "Starting" line on entry. On normal exit it prints a
    "Completed" line with the elapsed seconds. If the wrapped block raises,
    it prints a "Failed" line with the elapsed seconds instead and lets the
    exception propagate unchanged.

    Parameters
    ----------
    description
        Label for the step; interpolated into every printed line.
    """
    # perf_counter() is monotonic, so the measured interval cannot be skewed
    # by wall-clock adjustments the way time.time() can.
    start = time.perf_counter()
    print(f"🔄 Starting: {description}")
    try:
        yield
    except Exception:
        # Report where and after how long the step died, then re-raise so the
        # caller's own error handling still sees the original exception.
        elapsed = time.perf_counter() - start
        print(f"❌ Failed: {description} ({elapsed:.2f}s)")
        raise
    else:
        elapsed = time.perf_counter() - start
        print(f"✅ Completed: {description} ({elapsed:.2f}s)")

def calculate_all_metrics_per_query_optimized(query_group_data, **kwargs):
    """Run the metric pipeline for a single query group, timing each stage.

    Parameters
    ----------
    query_group_data : tuple
        ``(query_name, query_df)`` pair. Every column of ``query_df`` is
        copied into a new ``Data`` object before the stages run.
    **kwargs
        ``calculate_dockq`` (default ``False``) enables the DockQ stage;
        ``native_dir`` is forwarded to ``prep_dockq`` for that stage.

    Returns
    -------
    tuple
        ``(query_name, df)`` where ``df`` is the ``.df`` attribute of the
        processed ``Data`` object, or ``(query_name, None)`` if any stage
        raised. In the failure case the error message and traceback are
        printed rather than propagated.
    """
    query_name, query_df = query_group_data

    print(f"\n🚀 Processing query: {query_name} ({len(query_df)} models)")
    started_at = time.time()

    try:
        # Build the Data object from every column present in the group.
        with time_tracker("Data object creation"):
            columns = {
                column: query_df[column].tolist()
                for column in query_df.columns
            }
            metrics = Data(data_dict=columns)

        # 1. Basic AF3 metrics
        with time_tracker("Basic AF3 metrics"):
            metrics = metrics.extract_chain_columns(verbose=False)
            metrics = metrics.analyze_chains(verbose=False)
            metrics = metrics.add_h3_l3_plddt(verbose=False)

        # 2. Interface metrics
        with time_tracker("Interface metrics"):
            analysis.add_interface_metrics(metrics, verbose=False)

        # 3. PPI metrics (the quickest ones first)
        with time_tracker("pDockQ calculations"):
            analysis.pdockq(metrics, verbose=False)
            analysis.pdockq2(metrics, verbose=False)

        with time_tracker("LIS matrix"):
            analysis.LIS_matrix(metrics, verbose=False)

        # 4. piTM/pIS (moderate run time)
        with time_tracker("piTM/pIS calculation"):
            metrics.add_pitm_pis(cutoff=8.0, verbose=False)

        # 5. RMSD metrics
        with time_tracker("RMSD calculations"):
            metrics = metrics.add_chain_rmsd(align_chain='A', rmsd_chain='H')
            metrics = metrics.add_rmsd_scale()

        # 6. Rosetta metrics (the most time-consuming stage, so it runs last).
        #    The binding-energy step is currently disabled:
        # with time_tracker("Rosetta binding energy"):
        #     metrics.calculate_binding_energy()

        with time_tracker("Rosetta interface metrics"):
            metrics.add_rosetta_metrics(n_jobs=1, verbose=False)

        # 7. DockQ (optional)
        if kwargs.get('calculate_dockq', False):
            chain_selection = {
                'rec_chains': 'A',
                'lig_chains': 'H',
                'native_rec_chains': 'A',
                'native_lig_chains': 'H',
            }
            with time_tracker("DockQ calculation"):
                metrics.prep_dockq(
                    native_dir=kwargs.get('native_dir'), verbose=False
                )
                analysis.calculate_dockq(
                    metrics, **chain_selection, verbose=False
                )

        elapsed = time.time() - started_at
        print(f"✨ Query {query_name} completed in {elapsed:.2f}s")

        return query_name, metrics.df

    except Exception as err:
        # Best-effort contract: report the failure and hand back None so a
        # caller looping over many queries can keep going.
        elapsed = time.time() - started_at
        print(f"❌ Error processing query {query_name} after {elapsed:.2f}s: {err}")
        import traceback
        print(traceback.format_exc())
        return query_name, None

In [9]:
import pandas as pd
import os

base_path = "/home/cseomoon/project/ABAG/DB/ABAG_structure/AF3/native"
query = "6x97"

# Build the Data object from the directory <base_path>/<query>.
target = os.path.join(base_path, query)
my_data = Data(directory=target)
df = my_data.df

# Run the metric pipeline once per 'query' group and write one CSV per group.
for query_name, group_df in df.groupby("query"):
    print(query_name)
    # DockQ is enabled here; native_dir is passed through for that stage.
    name, result_df = calculate_all_metrics_per_query_optimized(
        (query_name, group_df),
        calculate_dockq=True,
        native_dir="/home/cseomoon/project/ABAG/DB/ABAG_structure/original_pdb",
    )

    # The pipeline returns None in place of the DataFrame when it failed.
    if result_df is None:
        print(f"[Warning] {name} 처리 실패")
        continue

    # Save the per-query results.
    result_df.to_csv(f"metrics_{name}.csv", index=False)

INFO:root:Reading /home/cseomoon/project/ABAG/DB/ABAG_structure/AF3/native/6x97


6x97

🚀 Processing query: 6x97 (50 models)
🔄 Starting: Data object creation
✅ Completed: Data object creation (0.04s)
🔄 Starting: Basic AF3 metrics
✅ Completed: Basic AF3 metrics (41.19s)
🔄 Starting: Interface metrics


  analysis.add_interface_metrics(temp_data, verbose=False)


✅ Completed: Interface metrics (80.95s)
🔄 Starting: pDockQ calculations
✅ Completed: pDockQ calculations (15.95s)
🔄 Starting: LIS matrix
✅ Completed: LIS matrix (4.82s)
🔄 Starting: piTM/pIS calculation


INFO:af_analysis.data:처리 중: 6x97


✅ Completed: piTM/pIS calculation (10.94s)
🔄 Starting: RMSD calculations


INFO:af_analysis.data:6x97: 평균 RMSD = 28.63 Å
INFO:af_analysis.data:스케일링된 RMSD 관련 열들이 추가되었습니다: scaled_rmsd_ratio, scaled_model_RMSD, scaled_query_RMSD


✅ Completed: RMSD calculations (87.07s)
🔄 Starting: Rosetta interface metrics
✅ Completed: Rosetta interface metrics (435.80s)
🔄 Starting: DockQ calculation
✅ Completed: DockQ calculation (118.73s)
✨ Query 6x97 completed in 795.50s
