In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import mean_squared_error, r2_score
import joblib


# Load the dataset used for DockQ regression.
df = pd.read_csv('../data/final_data_20250408.csv')

# Remove columns that must not be used as model inputs.
# NOTE(review): judging by their names these look like identifiers, file paths,
# raw per-chain containers and DockQ component metrics (Fnat, iRMS, LRMS, ...)
# that would leak the label -- confirm against the data dictionary.
cols_to_drop = ['pdb', 'query', 'seed', 'sample', 'data_file',
                'chain_iptm', 'chain_pair_iptm', 'chain_pair_pae_min', 'chain_ptm',
                'format', 'model_path', 'native_path',
                'Fnat', 'Fnonnat', 'rRMS', 'iRMS', 'LRMS']
df = df.drop(columns=cols_to_drop)

# Report and remove missing values BEFORE deriving X and y.
# Dropping rows from `df` after the split has no effect on X / y, because they
# are separate objects by then, so NaN rows would still reach the model.
print("결측치 개수:\n", df.isnull().sum())
df = df.dropna()  # alternatively, impute with a suitable value

# Separate the features (X) from the regression target (y = 'DockQ').
X = df.drop(columns=['DockQ'])
y = df['DockQ']

# DockQ is treated as a continuous value, so this is a regression problem.

# Split the data first (train : validation : test = 70 : 15 : 15).
X_train, X_temp, y_train, y_temp = train_test_split(X, y, test_size=0.3, random_state=42)
X_val, X_test, y_val, y_test = train_test_split(X_temp, y_temp, test_size=0.5, random_state=42)

# Feature scaling (not required for a Random Forest, kept for comparability
# with other models). The scaler is fitted on the TRAINING split only and then
# applied to validation/test, so no statistics of held-out data leak into the
# preprocessing step.
scaler = StandardScaler()
numeric_features = X.select_dtypes(include=['float64', 'int64']).columns

# Work on explicit copies so the column assignment below never writes into a
# view of X (avoids SettingWithCopyWarning and keeps X itself untouched).
X_train, X_val, X_test = X_train.copy(), X_val.copy(), X_test.copy()
X_train[numeric_features] = scaler.fit_transform(X_train[numeric_features])
X_val[numeric_features] = scaler.transform(X_val[numeric_features])
X_test[numeric_features] = scaler.transform(X_test[numeric_features])

# Baseline: a default 100-tree forest, fitted on the training split.
rf = RandomForestRegressor(random_state=42, n_estimators=100).fit(X_train, y_train)

# Score the baseline on the validation split.
y_val_pred = rf.predict(X_val)
val_r2 = r2_score(y_val, y_val_pred)
val_mse = mean_squared_error(y_val, y_val_pred)
print(f"검증 세트 MSE: {val_mse:.4f}, R²: {val_r2:.4f}")

# Hyperparameter tuning via exhaustive grid search.
# The grid has 3 * 4 * 3 * 3 = 108 candidates; with 5-fold CV that is 540 fits
# plus the final refit, so expect this step to take a while.
param_grid = dict(
    n_estimators=[50, 100, 200],
    max_depth=[None, 10, 20, 30],
    min_samples_split=[2, 5, 10],
    min_samples_leaf=[1, 2, 4],
)

# Shuffled, seeded folds so the search is repeatable.
kf = KFold(n_splits=5, shuffle=True, random_state=42)

# Candidates are ranked by negative MSE (higher is better); all cores are used.
grid_search = GridSearchCV(
    estimator=RandomForestRegressor(random_state=42),
    param_grid=param_grid,
    scoring='neg_mean_squared_error',
    cv=kf,
    n_jobs=-1,
).fit(X_train, y_train)

print(f"최적 하이퍼파라미터: {grid_search.best_params_}")

# Final check: evaluate the refitted best estimator on the held-out test split.
best_rf = grid_search.best_estimator_
y_test_pred = best_rf.predict(X_test)

# Compute both regression metrics against the same predictions.
test_mse, test_r2 = (
    metric(y_test, y_test_pred) for metric in (mean_squared_error, r2_score)
)
print(f"테스트 세트 MSE: {test_mse:.4f}, R²: {test_r2:.4f}")

# Feature importances of the tuned forest, ordered from most to least important.
feature_importances = best_rf.feature_importances_
sorted_idx = feature_importances.argsort()[::-1]

# Bar chart with one bar per feature, labelled with the feature name.
n_features = X.shape[1]
ranked_importances = feature_importances[sorted_idx]
ranked_names = X.columns[sorted_idx]

plt.figure(figsize=(10, 6))
plt.bar(range(n_features), ranked_importances)
plt.xticks(ticks=range(n_features), labels=ranked_names, rotation=90)
plt.title('Random Forest 특성 중요도')
plt.tight_layout()
plt.show()

# Print the ten most important features as (name, importance) pairs.
top_idx = sorted_idx[:10]
top_features = list(zip(X.columns[top_idx], feature_importances[top_idx]))
print("상위 10개 특성:")
for feature, importance in top_features:
    print(f"  {feature}: {importance:.4f}")

# 5-fold cross-validation of the tuned model over the full feature matrix.
# The original call referenced `rf_model` and `X_scaled`, neither of which is
# defined anywhere in this notebook (NameError on a fresh kernel). The tuned
# estimator `best_rf` and the feature matrix `X` are the in-scope equivalents.
# cross_val_score clones the estimator for every fold, so `best_rf` itself is
# left untouched.
# NOTE(review): assumes the tuned model (not the baseline `rf`) was intended.
cv = KFold(n_splits=5, shuffle=True, random_state=42)
cv_scores = cross_val_score(best_rf, X, y, cv=cv, scoring='neg_mean_squared_error')

# Scores are negative MSE, so flip the sign of the mean for reporting.
print(f"교차 검증 MSE: {-cv_scores.mean():.4f} (±{cv_scores.std():.4f})")

# Pairwise Pearson correlation between all feature columns.
correlation = X.corr()

# Annotated heatmap of the correlation matrix.
# NOTE(review): with many features the per-cell annotations may be illegible
# at this figure size -- consider annot=False if the plot is too dense.
plt.figure(figsize=(12, 10))
sns.heatmap(
    correlation,
    cmap='coolwarm',
    annot=True,
    linewidths=0.5,
)
plt.title("특성 간 상관관계")
plt.tight_layout()
plt.show()

# List strongly correlated feature pairs (|r| > 0.9).
# Self-correlation is excluded by POSITION (the matrix diagonal) rather than by
# testing `correlation != 1.0`: the value test silently drops distinct features
# that are perfectly correlated with each other, and can let the diagonal slip
# through when floating-point rounding makes it differ slightly from 1.0.
off_diagonal = ~np.eye(len(correlation), dtype=bool)
high_corr = (correlation.abs() > 0.9) & off_diagonal

# For every feature, collect the other features it is strongly correlated with.
# Each pair therefore appears twice, once under each of its two features.
high_corr_features = []
for col in high_corr.columns:
    high_corr_pairs = high_corr.index[high_corr[col].to_numpy()].tolist()
    if high_corr_pairs:
        high_corr_features.append((col, high_corr_pairs))

if high_corr_features:
    print("높은 상관관계를 가진 특성들:")
    for feature, corr_features in high_corr_features:
        print(f"  {feature}: {corr_features}")

# Persist the tuned model and the fitted scaler to the current working
# directory so they can be reused without retraining.
joblib.dump(best_rf, 'best_rf_model.pkl')
joblib.dump(scaler, 'scaler.pkl')

# Example: loading the saved artifacts and predicting on new data.
# The same scaler must be applied to the same numeric columns before predicting.
# NOTE: joblib files are pickle-based -- only load files from a trusted source.
# loaded_model = joblib.load('best_rf_model.pkl')
# loaded_scaler = joblib.load('scaler.pkl')
# scaled_data = loaded_scaler.transform(new_data[numeric_features])
# predictions = loaded_model.predict(scaled_data)

결측치 개수:
 fraction_disordered    0
has_clash              0
ipTM                   0
pTM                    0
ranking_confidence     0
                      ..
chain_pair_pae_AL      0
avg_pair_pae           0
avg_model_plddt        0
chain_pair_pae_HL      0
avg_internal_pae       0
Length: 61, dtype: int64
검증 세트 MSE: 0.0019, R²: 0.9778


KeyboardInterrupt: 

In [5]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
import pandas as pd
import numpy as np

# Lift pandas' display limits so wide frames and long cell values are shown in full.
for display_option in ('display.max_columns', 'display.max_colwidth'):
    pd.set_option(display_option, None)

# Columns excluded from the feature matrix.
# NOTE(review): by name these look like identifiers, paths and DockQ component
# metrics -- confirm against the data dictionary.
cols_to_drop = [
    'pdb', 'query', 'seed', 'sample', 'data_file',
    'chain_iptm', 'chain_pair_iptm', 'chain_pair_pae_min', 'chain_ptm',
    'format', 'model_path', 'native_path',
    'Fnat', 'Fnonnat', 'rRMS', 'iRMS', 'LRMS',
]

# Load the dataset and discard the excluded columns in one step.
df = pd.read_csv('../data/final_data_20250408.csv').drop(columns=cols_to_drop)

# 'DockQ' is the label; every remaining column is a feature.
y = df['DockQ']
X = df.drop(columns=['DockQ'])

# Disabled earlier experiment, kept for reference only.
# NOTE(review): it uses a classifier and accuracy although DockQ is stored as a
# float column -- revisit before re-enabling.

# # 4. Split into training and test data
# X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# # 5. Create and train the Random Forest model
# model = RandomForestClassifier(random_state=42)
# model.fit(X_train, y_train)

# # 6. Predict and evaluate
# y_pred = model.predict(X_test)
# accuracy = accuracy_score(y_test, y_pred)

# print("모델 정확도:", accuracy)

In [6]:
# Display the feature matrix (the cell's last expression is rendered by the notebook).
X

Unnamed: 0,fraction_disordered,has_clash,ipTM,pTM,ranking_confidence,pdockq_AH,pdockq_AL,pdockq_HL,pdockq,pdockq2_AH,pdockq2_AL,pdockq2_HL,pdockq2,mpdockq_AH,mpdockq_AL,mpdockq_HL,mpdockq,LIS_AH,LIS_AL,LIS_HL,avg_LIS,contacts_AH,interface_plddt_AH,interface_pae_AH,contacts_AL,interface_plddt_AL,interface_pae_AL,contacts_HL,interface_plddt_HL,interface_pae_HL,total_contacts,avg_interface_plddt,avg_interface_pae,model_avg_RMSD,query_avg_RMSD,model_rmsd_scale,iptm_A,iptm_H,iptm_L,ptm_A,ptm_H,ptm_L,chain_pair_iptm_AH,chain_pair_iptm_AL,chain_pair_iptm_HL,chain_pair_pae_min_AH,chain_pair_pae_min_AL,chain_pair_pae_min_HL,chain_plddt_L,chain_plddt_A,chain_pae_A,chain_plddt_H,chain_pair_pae_AH,chain_pae_H,chain_pae_L,chain_pair_pae_AL,avg_pair_pae,avg_model_plddt,chain_pair_pae_HL,avg_internal_pae
0,0.00,0.0,0.91,0.93,0.91,0.342220,0.077264,0.516534,0.455998,0.767638,0.902289,0.815250,0.694185,0.262,0.262,0.262001,0.262000,0.623716,0.631027,0.744084,0.666276,2873.0,91.748091,3.567034,1069.0,93.798798,2.458242,3915.0,96.140033,2.542372,7857.0,93.895641,2.855883,2.075217,2.101325,0.506251,0.85,0.90,0.87,0.82,0.91,0.91,0.88,0.82,0.92,0.880,1.055,0.895,94.209845,90.717220,2.734350,94.231186,5.019568,2.469001,2.403128,5.114726,4.538768,93.052750,3.482009,2.535493
1,0.00,0.0,0.91,0.93,0.91,0.331281,0.069762,0.527177,0.442246,0.755151,0.894533,0.885665,0.739335,0.262,0.262,0.262001,0.262000,0.622075,0.626972,0.746965,0.665337,2773.0,92.729248,3.512828,1103.0,93.796718,2.546296,3905.0,96.453397,2.088824,7781.0,94.326454,2.715983,1.957483,2.101325,0.535395,0.85,0.90,0.87,0.82,0.92,0.91,0.88,0.82,0.92,0.875,1.080,0.885,94.247836,90.933359,2.737291,94.720042,4.885757,2.293457,2.432452,5.195561,4.485196,93.300412,3.374271,2.487733
2,0.00,0.0,0.91,0.93,0.91,0.335378,0.064479,0.528147,0.441553,0.835029,0.904203,0.888190,0.761060,0.262,0.262,0.262001,0.262000,0.626143,0.630800,0.753669,0.670204,2745.0,92.975432,3.425616,1092.0,93.480522,2.424333,3912.0,96.475105,2.061224,7749.0,94.310353,2.637058,2.216276,2.101325,0.473395,0.85,0.90,0.87,0.82,0.92,0.91,0.88,0.83,0.92,0.875,1.050,0.880,94.487360,90.730714,2.740211,94.831354,4.873263,2.243352,2.336347,5.119666,4.420462,93.349809,3.268457,2.439970
3,0.00,0.0,0.91,0.93,0.91,0.344821,0.070719,0.532266,0.453196,0.769181,0.893866,0.883787,0.761893,0.262,0.262,0.262001,0.262000,0.608838,0.619515,0.744474,0.657609,2884.0,91.220964,3.647286,1036.0,94.041255,2.546474,3875.0,96.353924,2.112004,7795.0,93.872048,2.768588,1.732211,2.101325,0.595401,0.84,0.90,0.87,0.81,0.91,0.91,0.87,0.82,0.92,0.890,1.090,0.890,94.234102,90.429131,2.804522,94.422267,5.058343,2.354819,2.405847,5.253095,4.574639,93.028500,3.412478,2.521729
4,0.00,0.0,0.90,0.92,0.90,0.354432,0.075200,0.545661,0.461509,0.718113,0.873801,0.831357,0.686014,0.262,0.262,0.262001,0.262000,0.593848,0.601764,0.731953,0.642521,2937.0,90.863139,3.636828,1218.0,92.478649,2.760969,3991.0,96.026966,2.280572,8146.0,93.122918,2.892790,1.688816,2.101325,0.607563,0.83,0.89,0.86,0.81,0.91,0.90,0.87,0.80,0.91,0.895,1.160,0.920,93.894233,90.203494,2.849254,94.221490,5.270345,2.445614,2.508474,5.498757,4.777462,92.773072,3.563284,2.601114
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3645,0.00,0.0,0.84,0.74,0.82,0.232218,0.070009,0.555710,0.599316,0.561245,0.624653,0.814576,0.618818,0.262,0.262,0.262001,0.262002,0.289527,0.285897,0.625503,0.400309,2719.0,87.269937,5.841875,954.0,89.932940,4.209048,4938.0,93.207609,2.761589,8611.0,90.136829,4.270837,28.091750,43.520733,0.705894,0.84,0.85,0.84,0.77,0.86,0.84,0.84,0.83,0.85,2.635,2.555,1.305,90.857188,85.486119,12.073501,90.432776,17.676980,3.139595,3.401286,18.382169,13.648970,88.925361,4.887762,6.204794
3646,0.00,0.0,0.84,0.72,0.81,0.248864,0.070228,0.557309,0.600589,0.554938,0.611623,0.812720,0.605055,0.262,0.262,0.262001,0.262002,0.288403,0.288247,0.623829,0.400160,2717.0,87.278721,5.304783,931.0,90.191917,4.250952,4930.0,93.243791,2.778444,8578.0,90.238143,4.111393,28.121604,43.520733,0.705452,0.83,0.84,0.84,0.76,0.85,0.84,0.84,0.83,0.85,2.605,2.695,1.330,91.019183,85.245260,12.405681,90.483141,18.143902,3.176724,3.406830,18.707529,13.921269,88.915861,4.912376,6.329745
3647,0.00,0.0,0.84,0.74,0.82,0.241769,0.069642,0.545479,0.591900,0.569695,0.616918,0.812042,0.622400,0.262,0.262,0.262001,0.262002,0.289766,0.288060,0.625000,0.400942,2779.0,87.267436,5.845484,994.0,89.867173,4.114359,4965.0,93.040888,2.766102,8738.0,90.058499,4.241982,27.989834,43.520733,0.707400,0.84,0.84,0.84,0.78,0.85,0.84,0.84,0.83,0.85,2.525,2.640,1.320,90.792427,85.401828,12.051977,90.342624,17.581776,3.161998,3.407699,18.237571,13.571983,88.845627,4.896603,6.207225
3648,0.01,0.0,0.84,0.74,0.82,0.228089,0.071001,0.546140,0.593414,0.580567,0.636103,0.820262,0.631419,0.262,0.262,0.262001,0.262002,0.292963,0.292603,0.626221,0.403929,2635.0,87.829805,5.163556,954.0,90.394853,4.123590,4894.0,93.352547,2.759766,8483.0,90.525735,4.015637,28.251209,43.520733,0.703538,0.84,0.85,0.84,0.78,0.86,0.84,0.84,0.84,0.85,2.585,2.635,1.300,91.144831,85.785195,11.876426,90.732411,17.332290,3.133910,3.389435,17.963919,13.393566,89.220812,4.884489,6.133257


In [7]:
# Display the DockQ target series (the cell's last expression is rendered by the notebook).
y

0       0.704191
1       0.695233
2       0.706899
3       0.737216
4       0.723623
          ...   
3645    0.451226
3646    0.423402
3647    0.478954
3648    0.445480
3649    0.447185
Name: DockQ, Length: 3650, dtype: float64