In [1]:
from preprocess import create_glob_set, add_support
from model import load_model
from utils import Trainer, Evaluator, Predictor, NCDiscovery
import pandas as pd
import os

In [2]:
# Experiment configuration shared by every step of this notebook.
params = dict(
    name='BaseSet',        # dataset used for base training (also the db/processed/<name>/ folder)
    tfname='TransferSet',  # dataset used for transfer learning
    headshape=[2048],      # presumably the head's hidden layer widths — TODO confirm in model.py
    dr=0.5,                # presumably dropout rate — TODO confirm in model.py
    lr=0.0001,
    nsupport=100,
    niter=100,
)

# Common "<dr>_<head widths>" tag shared by all run names, e.g. "0.50_2048".
_run_tag = f"{params['dr']:.2f}_{'-'.join(str(width) for width in params['headshape'])}"

base_name = f"{params['name']}_{_run_tag}"
# NOTE(review): 'tfname' appears both as prefix and after "_tf". Checkpoints are
# loaded from model/model_<tf_name>.pt later on, so the name is kept verbatim.
tf_name = f"{params['tfname']}_{_run_tag}_tf{params['tfname']}"
fs_name = f"FewShot_{_run_tag}"

In [3]:
print("=== 1. 데이터 준비 ===")


def _sim_split_paths(set_name):
    """Return the train/test/data CSV paths expected under db/processed/<set_name>/."""
    return [
        f"db/processed/{set_name}/{set_name}_sim_{split}.csv"
        for split in ("train", "test", "data")
    ]


base_files = _sim_split_paths('BaseSet')
transfer_files = _sim_split_paths('TransferSet')

# Generate a dataset only when at least one of its expected CSVs is missing;
# BaseSet is handled before TransferSet.
for set_name, expected in (('BaseSet', base_files), ('TransferSet', transfer_files)):
    if not all(os.path.exists(path) for path in expected):
        print(f"{set_name} 데이터 생성 중...")
        create_glob_set(set_name, n_pos=7000, n_neg=7000, test_size=0.1)
    else:
        print(f"{set_name} 데이터 존재 - 건너뜀")

=== 1. 데이터 준비 ===
BaseSet 데이터 존재 - 건너뜀
TransferSet 데이터 존재 - 건너뜀


In [4]:
print("=== 2. 기본 모델 훈련 ===")
model_path = f"model/model_{base_name}.pt"
if not os.path.exists(model_path):
    # No checkpoint on disk yet: train from scratch and keep the loss history.
    print("새 모델 훈련 시작")
    model, losses, _ = Trainer.train_cycle(params)
    losses.to_csv(f"result/{base_name}_losses.csv")
    print("훈련 완료")
else:
    print(f"기존 모델 발견: {model_path}")
    model = load_model(model_path)
    print("기존 모델 로드 완료")
    # Placeholder: no loss history is available when reusing a checkpoint.
    losses = pd.DataFrame()

print("=== 3. 기본 모델 테스트 ===")
# Always evaluate the checkpoint stored at model_path, whichever branch ran above.
model = load_model(model_path)
y_proba, y_matrix = Trainer.test_cycle(
    model, params, save_name=base_name, thresh=0.5, seed=777
)

=== 2. 기본 모델 훈련 ===
기존 모델 발견: model/model_BaseSet_0.50_2048.pt
기존 모델 로드 완료
=== 3. 기본 모델 테스트 ===
test_file exist


In [None]:
print("=== 4. 전이 학습 === ")
tf_model_path = f"model/model_{tf_name}.pt"

if os.path.exists(tf_model_path):
    # Reuse the previously transfer-trained checkpoint instead of retraining.
    print(f"기존 전이학습 모델 발견: {tf_model_path}")
    tmodel = load_model(tf_model_path)
    print("기존 전이학습 모델 로드 완료")
else:
    print("전이 학습 시작")
    # Start from the base checkpoint produced in step 2. The transfer checkpoint
    # (tf_model_path) does not exist in this branch, so loading it would always fail.
    model = load_model(f"model/model_{base_name}.pt")
    # NOTE(review): step 5 reloads tf_model_path from disk, so transfer_train_cycle
    # presumably saves its result there — TODO confirm in utils.Trainer.
    tmodel, losses, _ = Trainer.transfer_train_cycle(model, params)
    print("전이 학습 완료")


In [6]:

print("=== 5. 전이 학습 테스트 ===")
# Evaluate the transfer-learned checkpoint with Trainer.tf_test_cycle.
tf_checkpoint = f"model/model_{tf_name}.pt"
tmodel = load_model(tf_checkpoint)
y_proba, y_matrix = Trainer.tf_test_cycle(
    tmodel,
    params,
    save_name=tf_name,
    thresh=0.5,
    seed=777,
)

=== 5. 전이 학습 테스트 ===
test_file exist


In [7]:
print("=== 6. Few-shot 평가 ===")
# Start from the transfer-learned checkpoint and attach a FewshotSet support set.
tmodel = load_model(f"model/model_{tf_name}.pt")
support_set = {}
test_raw = pd.DataFrame()

# add_support returns the support set and a held-out test frame for FewshotSet.
# test_ratio=0.1 presumably reserves 10% for evaluation — TODO confirm in preprocess.add_support.
support_set, test_raw = add_support('FewshotSet', support_set, test_raw, test_ratio=0.1)
tmodel.support_pos = support_set

# Pin the same seed as steps 3 and 5 so few-shot metrics are reproducible and
# comparable across runs (test_cycle accepts `seed`, see step 3).
y_proba, y_matrix = Trainer.test_cycle(
    tmodel, params, save_name=fs_name, test_raw=test_raw, thresh=0.5, seed=777
)
print("Few-shot 평가 완료")

=== 6. Few-shot 평가 ===
test_file exist
Few-shot 평가 완료


In [8]:
print("=== 7. 베이스라인 비교 ===")
model = load_model(f"model/model_{base_name}.pt")
test_raw = pd.read_csv(f"db/processed/{params['name']}/{params['name']}_sim_test.csv")
# Feature columns are those whose name starts with 'X'.
x_cols = [col for col in test_raw.columns if col.startswith('X')]

# Take support size / iteration count from the config cell instead of repeating literals.
y_proba = Predictor.predict_baseline(
    model,
    test_raw[x_cols].values,
    n_support=params['nsupport'],
    iter_size=params['niter'],
    random_seed=777,
)

# Collapse to one score per top-level column label by taking the max across sub-columns.
# DataFrame.groupby(axis=1) is deprecated since pandas 2.1; transposing is the documented
# replacement and yields the same frame.
y_matrix = y_proba.T.groupby(level=0).max().T
y_matrix['LABEL'] = test_raw['class']
if 'SMILES' in test_raw.columns:
    y_matrix['SMILES'] = test_raw['SMILES']

# NOTE(review): thresh=0.6 here while steps 3, 5 and 6 use 0.5 — confirm the asymmetry
# is intended, otherwise the baseline is not compared on equal footing.
res = Evaluator.evaluation(y_matrix, test_raw['class'], thresh=0.6)
print("베이스라인 결과:")
res

=== 7. 베이스라인 비교 ===
베이스라인 결과:


Unnamed: 0,protein,accuracy,precision,recall,f1,auc,ap,count,tn,fp,fn,tp
0,ADORA2A,0.948416,0.142059,0.801181,0.241328,0.87556,0.115851,508,46642,2458,101,407
1,BRCA1,0.675073,0.028027,0.525522,0.053216,0.60162,0.022973,862,33036,15710,409,453
2,CNR1,0.757055,0.034302,0.830078,0.065881,0.793186,0.030227,512,37131,11965,87,425
3,DRD2,0.709603,0.047986,0.89801,0.091104,0.802255,0.044745,804,34480,14324,82,722
4,HTR1A,0.70152,0.036618,0.883281,0.070321,0.791224,0.033836,634,34241,14733,74,560
5,KCNH2,0.732825,0.024093,0.604128,0.046338,0.669175,0.018808,533,36032,13043,211,322
6,LMNA,0.649653,0.029816,0.347557,0.054921,0.503163,0.029473,1453,31723,16432,948,505
7,OPRM1,0.777637,0.046542,0.94709,0.088724,0.861384,0.044684,567,38040,11001,30,537
8,SLC6A4,0.801201,0.060313,0.910275,0.113129,0.854968,0.056151,691,39117,9800,62,629
9,TARDBP,0.622541,0.023983,0.361772,0.044984,0.495441,0.024359,1219,30442,17947,778,441


In [9]:
print("=== 8. 예측 진행 === ")
# Target definition: FTO compounds from the few-shot raw data.
fto_csv_path = "db/raw/FewshotSet/FTO.csv"

# Query library to screen against the target.
fooddb_compounds_df = pd.read_csv("db/saved_data/fooddb.csv")
model = load_model(f"model/model_{tf_name}.pt")
print(f"FoodDB 화합물 수: {len(fooddb_compounds_df)}")

# Screening cut-offs passed through to NCDiscovery.screen_compounds.
screen_options = dict(threshold=0.7, top_k=50)
candidates = NCDiscovery.screen_compounds(
    model=model,
    target_csv_path=fto_csv_path,
    query_df=fooddb_compounds_df,
    **screen_options,
)

n_candidates = len(candidates)
print(f"발견된 후보 화합물 수: {n_candidates}")
candidates.to_csv("fto_inhibitor_candidates.csv", index=False)

=== 8. 예측 진행 === 
FoodDB 화합물 수: 70413
발견된 후보 화합물 수: 33
