In [1]:
!pwd

/root/ensemble_commit/10


In [1]:

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from xgboost import XGBClassifier
from sklearn.metrics import classification_report, accuracy_score

df = pd.read_json(r'../datasets/patch_db.json', encoding='utf_8_sig')
df.dropna(inplace=True)
label2id={'non-security':0,'security':1}
df = df.replace({"category": label2id})
df

train, test = train_test_split(df, test_size=0.3, random_state=42)
test, val = train_test_split(test, test_size=0.5, random_state=42)
train,_ = train_test_split(train, train_size=0.1, random_state=42)
test, val = train_test_split(test, train_size=0.1, random_state=42)
val, _ = train_test_split(val, train_size=0.1, random_state=42)

train.reset_index(inplace=True)
test.reset_index(inplace=True)
val.reset_index(inplace=True)

In [3]:
df

Unnamed: 0,CVE_ID,CWE_ID,category,commit_id,commit_message,diff_code,owner,repo,source
0,,,non-security,540958e2f5a87b81aa5f55ce40b3e2869754f97d,commit 540958e2f5a87b81aa5f55ce40b3e2869754f97...,diff --git a/drivers/staging/comedi/drivers/cb...,stoth68000,media-tree,wild
1,,,non-security,64d240b721b21e266ffde645ec965c3b6d1c551f,commit 64d240b721b21e266ffde645ec965c3b6d1c551...,diff --git a/drivers/target/target_core_file.c...,stoth68000,media-tree,wild
2,,,non-security,f181dd278274f50e689ebd13237010a90b430164,commit f181dd278274f50e689ebd13237010a90b43016...,diff --git a/include/paths.h b/include/paths.h...,openbsd,src,wild
3,,,non-security,0abdc3723b5d33dde698ab941325edec2819c128,commit 0abdc3723b5d33dde698ab941325edec2819c12...,diff --git a/gnu/usr.bin/binutils/ld/lexsup.c ...,openbsd,src,wild
4,,,non-security,d7930d7f820e5dd6b07b823f155aeb943b525e16,commit d7930d7f820e5dd6b07b823f155aeb943b525e1...,diff --git a/src/expat_erl.c b/src/expat_erl.c...,esl,MongooseIM,wild
...,...,...,...,...,...,...,...,...,...
35810,CVE-2013-0217,399,security,7d5145d8eb2b9791533ffe4dc003b129b9696c48,From 7d5145d8eb2b9791533ffe4dc003b129b9696c48 ...,diff --git a/drivers/net/xen-netback/netback.c...,torvalds,linux,cve
35811,CVE-2018-18311,119,security,34716e2a6ee2af96078d62b065b7785c001194be,From 34716e2a6ee2af96078d62b065b7785c001194be ...,diff --git a/util.c b/util.c\nindex 7282dd9cfe...,Perl,perl5,cve
35812,CVE-2019-12984,476,security,385097a3675749cbc9e97c085c0e5dfe4269ca51,From 385097a3675749cbc9e97c085c0e5dfe4269ca51 ...,diff --git a/net/nfc/netlink.c b/net/nfc/netli...,torvalds,linux,cve
35813,CVE-2013-0865,119,security,f3d16706060ab6ae6dc78f15359fab3fd87c9495,From f3d16706060ab6ae6dc78f15359fab3fd87c9495 ...,diff --git a/libavcodec/vqavideo.c b/libavcode...,,,cve


In [5]:
train['category'].value_counts()

category
non-security    1644
security         863
Name: count, dtype: int64

In [2]:
len(train), len(test), len(val)

(2507, 537, 483)

In [2]:
test['diff_code']

0      diff --git a/net/ipv4/netfilter/ipt_recent.c b...
1      diff --git a/drivers/char/ipmi/ipmi_msghandler...
2      diff --git a/drivers/char/tty_ioctl.c b/driver...
3      diff --git a/src/rgw/rgw_admin.cc b/src/rgw/rg...
4      diff --git a/libavcodec/mjpegdec.c b/libavcode...
                             ...                        
532    diff --git a/drivers/media/video/tuner-xc2028....
533    diff --git a/src/os/bluestore/BlockDevice.cc b...
534    diff --git a/src/core/manager.c b/src/core/man...
535    diff --git a/ext/wddx/tests/bug72750.phpt b/ex...
536    diff --git a/src/modules/extra/m_ssl_gnutls.cp...
Name: diff_code, Length: 537, dtype: object

In [3]:
# 假设你有一个DataFrame df，其中包含'commit_message'和'label'列
# df = pd.DataFrame({'commit_message': [...], 'label': [...]})

# 1. 数据预处理和特征提取
vectorizer = TfidfVectorizer()
X_train = vectorizer.fit_transform(train['diff_code'])
y_train = train['category']

# 3. 训练XGBoost模型
clf = XGBClassifier()
clf.fit(X_train, y_train)

# 4. 评估模型
y_pred = clf.predict(vectorizer.transform(test['diff_code']))
print("Accuracy:", accuracy_score(test['category'], y_pred))
print("Classification Report:\n", classification_report(test['category'], y_pred,digits=4))


Accuracy: 0.7411545623836127
Classification Report:
               precision    recall  f1-score   support

           0     0.7770    0.8757    0.8234       370
           1     0.6167    0.4431    0.5157       167

    accuracy                         0.7412       537
   macro avg     0.6968    0.6594    0.6695       537
weighted avg     0.7271    0.7412    0.7277       537



In [5]:
from joblib import load,dump

In [6]:
dump(clf, 'XGBoost_patchDB.joblib')

['XGBoost_patchDB.joblib']