In [12]:
from sklearn.model_selection import train_test_split

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, accuracy_score, precision_score, recall_score, f1_score
from sklearn.utils.class_weight import compute_class_weight
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
data=pd.read_csv('./vdjdb.txt',sep='\t', header=0)
data[data['species'] == 'MusMusculus']
data=data[data['vdjdb.score'] != 0]
columns_to_drop = ['antigen.species','antigen.gene','reference.id', 'method', 'meta','cdr3fix','web.method','web.method.seq','web.cdr3fix.nc','web.cdr3fix.unmp']
data.drop(columns=columns_to_drop, inplace=True)
data=data.dropna()
data= data.drop_duplicates()
data_beta_Mus = data[data['gene'] == 'TRB'].copy()
# 假设 `data` 是已加载的pandas DataFrame，并且包含了上述的列

class_counts = data_beta_Mus['antigen.epitope'].value_counts()

# Find classes that have only one instance
single_classes = class_counts[class_counts == 1].index

# Remove rows where 'antigen.epitope' belongs to classes with only one instance
data_beta_Mus_filtered = data_beta_Mus[~data_beta_Mus['antigen.epitope'].isin(single_classes)]

# 首先，根据`cdr3_a_aa`、`v_a_gene`和`j_a_gene`来计算距离矩阵
# 计算距离矩阵的具体代码将根据所选算法和数据类型有所不同

y = data_beta_Mus_filtered['antigen.epitope']

In [13]:
column_trans = ColumnTransformer(
    [
        ('one_hot_encoder_vj', OneHotEncoder(), ['v.segm', 'j.segm','mhc.a','mhc.b','mhc.class']),
        ('one_hot_encoder_cdr3', OneHotEncoder(handle_unknown='ignore'), ['cdr3'])
    ],
    remainder='drop'
)

X_encoded = column_trans.fit_transform(data_beta_Mus_filtered)
X_train, X_test, y_train, y_test = train_test_split(X_encoded, y, test_size=0.2, random_state=42)

In [14]:
# 5. 计算类别权重
from sklearn.utils.class_weight import compute_class_weight
class_weights = compute_class_weight('balanced', classes=np.unique(y_train), y=y_train)
weights = dict(zip(np.unique(y_train), class_weights))
#*
# 6. 训练随机森林模型
rf_classifier_weighted = RandomForestClassifier(random_state=30, class_weight=weights)
rf_classifier_weighted.fit(X_train, y_train)

# 7. 进行预测
y_pred_weighted = rf_classifier_weighted.predict(X_test)

# 8. 生成和显示分类报告
classification_report_weighted = classification_report(y_test, y_pred_weighted, zero_division=0)
#print(classification_report_weighted)

# 9. 计算和显示准确率、精确度、召回率和F1分数
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
accuracy = accuracy_score(y_test, y_pred_weighted)
precision = precision_score(y_test, y_pred_weighted, average='macro', zero_division=0)
recall = recall_score(y_test, y_pred_weighted, average='macro', zero_division=0)
f1 = f1_score(y_test, y_pred_weighted, average='macro', zero_division=0)

print(f'Accuracy: {accuracy:.2f}')
print(f'Precision: {precision:.2f}')
print(f'Recall: {recall:.2f}')
print(f'F1 Score: {f1:.2f}')
#onehot编码，随机森林预测

Accuracy: 0.63
Precision: 0.44
Recall: 0.44
F1 Score: 0.42


In [15]:
results_df = pd.DataFrame({'Actual': y_test, 'Predicted': y_pred_weighted})
print(results_df) 
mismatches = results_df[results_df['Actual'] != results_df['Predicted']]
print("Mismatched Predictions:")
print(mismatches)

           Actual   Predicted
75670  HPKVSSEVHI  HPKVSSEVHI
78590   HSKKKCDEL   HSKKKCDEL
82837    TTPESANL    STPESANL
21559   GILGFVFTL   GLCTLVAML
75983    FLKEKGGL    EIYKRWII
...           ...         ...
78158  EAAGIGILTV  ELAGIGILTV
76065  KRWIILGLNK  KRWIILGLNK
83809   RMFPNAPYL  KLVALGINAV
75951    EIYKRWII    FLKEKGGL
79973   FRCPRRFCF   CRVLCCYVL

[1305 rows x 2 columns]
Mismatched Predictions:
           Actual   Predicted
82837    TTPESANL    STPESANL
21559   GILGFVFTL   GLCTLVAML
75983    FLKEKGGL    EIYKRWII
26458  KLYGLDWAEL   GILGFVFTL
89648   YLQPRTFLL   GLCTLVAML
...           ...         ...
25795   NLVPMVATV   GILGFVFTL
78158  EAAGIGILTV  ELAGIGILTV
83809   RMFPNAPYL  KLVALGINAV
75951    EIYKRWII    FLKEKGGL
79973   FRCPRRFCF   CRVLCCYVL

[488 rows x 2 columns]


In [16]:
error_counts = mismatches['Actual'].value_counts()
print("Error Counts by Actual Category:")
print(error_counts)
error_counts.to_csv("error_counts.csv")

Error Counts by Actual Category:
NLVPMVATV          23
ELAGIGILTV         23
TTPESANL           19
YVLDHLIVV          18
YLQPRTFLL          16
                   ..
KLYGLDWAEL          1
GELIGILNAAKVPAD     1
FLGKIWPSHK          1
LPRRSGAAGA          1
FRCPRRFCF           1
Name: Actual, Length: 137, dtype: int64
