**Immune2vec: Embedding B/T Cell Receptor Sequences Using Natural Language Processing**

In NLP, the term “embedding” refers to the representation of symbolic information in text at the word-level, phrase-level, and even sentence-level, in terms of real number vectors.

*Following https://www.frontiersin.org/journals/immunology/articles/10.3389/fimmu.2021.680687/full#f2*

In [70]:
import pandas as pd
import numpy as np
from gensim.models import Word2Vec
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import classification_report, accuracy_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, cross_val_score 
from sklearn.model_selection import StratifiedKFold


• Data set

The first step in creating the model is building an adequate corpus for word2vec training. All scores of over 0, alpha and beta cdr3 sequences between length 12 and 14, species type HomoSapien

In [20]:
file_path = "vdjdb_full.txt"
df = pd.read_csv(file_path, delimiter='\t')
df = df.drop_duplicates()
df = df[(df['vdjdb.score'] > 0)]
df = df[['cdr3.alpha','cdr3.beta','species','antigen.epitope','antigen.gene','vdjdb.score']]
print(df.shape)

  df = pd.read_csv(file_path, delimiter='\t')


(9300, 6)


In [21]:
df['cdr3.alpha.length'] = df['cdr3.alpha'].apply(lambda x: len(x) if pd.notnull(x) and not isinstance(x, float) else 0)
df = df[(df['cdr3.alpha.length'] >= 12) & (df['cdr3.alpha.length'] <= 14)]

df['cdr3.beta.length'] = df['cdr3.beta'].apply(lambda x: len(x) if pd.notnull(x) and not isinstance(x, float) else 0)
df = df[(df['cdr3.beta.length'] >= 12) & (df['cdr3.beta.length'] <= 14)]

df = df[df['species'] == 'HomoSapiens']

df['cdr3combined'] = df['cdr3.alpha'].fillna('') + df['cdr3.beta'].fillna('')

df = df.reset_index(drop=True)

In [22]:
print(df.iloc[0])
print(df.shape)

cdr3.alpha                         CAYRPPGTYKYIF
cdr3.beta                         CASSALASLNEQFF
species                              HomoSapiens
antigen.epitope                         FLKEKGGL
antigen.gene                                 Nef
vdjdb.score                                    2
cdr3.alpha.length                             13
cdr3.beta.length                              14
cdr3combined         CAYRPPGTYKYIFCASSALASLNEQFF
Name: 0, dtype: object
(555, 9)


**Split all sequences to non-overlapping n-grams**

cdr3 sequences are already in amino acid format so dont need to be translated, next step is to process them into a non-overlapping n-gram method 

In [24]:
def split_into_ngrams(sequence, n):
    return [sequence[i:i+n] for i in range(0, len(sequence), n)]

n = 3  #as per paper
df['cdr3.alpha.ngrams'] = df['cdr3.alpha'].apply(lambda x: split_into_ngrams(x, n) if pd.notnull(x) else x)
df['cdr3.beta.ngrams'] = df['cdr3.beta'].apply(lambda x: split_into_ngrams(x, n) if pd.notnull(x) else x)

In [50]:
combined_ngrams = df['cdr3.alpha.ngrams'] + df['cdr3.beta.ngrams']

In [51]:
print(combined_ngrams)

0       [CAY, RPP, GTY, KYI, F, CAS, SAL, ASL, NEQ, FF]
1       [CAG, PTG, GSY, IPT, F, CAS, SVA, LAT, GEQ, YF]
2      [CVV, SAI, TND, YKL, SF, CAS, SLI, EGG, TEA, FF]
3       [CAV, QPG, AGG, FKT, IF, CAS, SLI, EGL, EQY, F]
4        [CAS, QSN, TGN, QFY, F, CAS, SLI, EQQ, PQH, F]
                             ...                       
550    [CAL, SEA, GAN, SKL, TF, CAS, SLL, AGG, DTQ, YF]
551    [CAG, QLD, SGT, YKY, IF, CAS, SPA, GWD, TEA, FF]
552    [CLV, GGD, NQG, GKL, IF, CAS, SQR, QGG, NTI, YF]
553        [CAV, NAR, NAG, NML, TF, CAS, SFD, GET, QYF]
554      [CAV, EGG, SNY, KLT, F, CSV, GAG, GSG, ELF, F]
Length: 555, dtype: object


In [25]:
print(df.iloc[0])

cdr3.alpha                         CAYRPPGTYKYIF
cdr3.beta                         CASSALASLNEQFF
species                              HomoSapiens
antigen.epitope                         FLKEKGGL
antigen.gene                                 Nef
vdjdb.score                                    2
cdr3.alpha.length                             13
cdr3.beta.length                              14
cdr3combined         CAYRPPGTYKYIFCASSALASLNEQFF
cdr3.alpha.ngrams        [CAY, RPP, GTY, KYI, F]
cdr3.beta.ngrams        [CAS, SAL, ASL, NEQ, FF]
Name: 0, dtype: object


The skip-gram model works by predicting the context (surrounding words) given a target word.

In [52]:
from gensim.models import Word2Vec

#combine the sequences
sequences = combined_ngrams.tolist()

model = Word2Vec(sequences, 
                 min_count=2,  #ignore words that appear fewer than twice
                 window=28,  #context window of 28 to fit max length of cdr3
                 vector_size=100,  #embed into 100-dimensional vectors
                 sg=1)  #skip-gram algorithm

model.save("word2vec.model")

In [53]:
print(sequences[0])

['CAY', 'RPP', 'GTY', 'KYI', 'F', 'CAS', 'SAL', 'ASL', 'NEQ', 'FF']


In [57]:
print(len((max(sequences, key=lambda x: len(x)))))

10


In [58]:
#model = Word2Vec.load("word2vec.model")

#convert sequence to vector
def sequence_to_vec(ngrams, model):
    #convert n-gram to a vector and compute the weighted average
    vectors = [model.wv[ngram] for ngram in ngrams if ngram in model.wv]
    if not vectors:
        return np.zeros(model.vector_size)
    weights = [1/len(vectors)] * len(vectors)  # Equal weights for all vectors
    return np.average(vectors, axis=0, weights=weights)

vectors = [sequence_to_vec(seq, model) for seq in sequences]

df_vectors = pd.DataFrame({
    'sequence': sequences,
    'vector': vectors
})

In [59]:
print(vectors[0])
print(df_vectors.iloc[0])
print(df_vectors.shape)

[-0.0596909   0.05990946  0.00394271  0.10615421  0.05276418 -0.17461887
  0.15643714  0.20113581 -0.06299914 -0.06218642 -0.02066001 -0.19236816
 -0.01941576  0.12300529 -0.12521967  0.02035039  0.13009885 -0.13047655
 -0.08348431 -0.34395748  0.02176093  0.02647692  0.07282719 -0.10732127
  0.08304158  0.04530009 -0.12766941  0.09297457 -0.19410612  0.11232468
  0.04355018  0.12572232  0.16832486 -0.05851794 -0.16557175  0.17660795
  0.02422735 -0.11960185 -0.09953406 -0.240232   -0.07601622 -0.0340294
  0.05615477 -0.11082582  0.12403981  0.07402097 -0.11344352 -0.16191149
  0.11868708  0.16883141  0.03865691 -0.03383769 -0.11027782  0.03236344
 -0.0980185   0.15631883  0.01923436  0.01115263 -0.24942973  0.16013008
  0.01325665  0.01061799  0.06939732 -0.05010538 -0.22545882  0.00169031
  0.0251941   0.03644635 -0.04589688  0.20283756 -0.01166326 -0.00256912
  0.09452346 -0.00419985  0.07762411 -0.02682507  0.08999039 -0.05872625
 -0.10756232  0.03124608 -0.02999053  0.01442881 -0.

In [63]:
X = np.stack(df_vectors['vector'].values)  
y = df['antigen.epitope'].values  

# Split the data into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=111)

# Train a SVM classifier
clf = SVC()
clf.fit(X_train, y_train)

# Predict the labels for the test set
y_pred = clf.predict(X_test)

# Print a classification report
print(classification_report(y_test, y_pred))

# Print the accuracy
print(clf.score(X_test, y_test))

                 precision    recall  f1-score   support

      ALGIGILTV       0.00      0.00      0.00         1
      ALYGFVPVL       0.00      0.00      0.00         1
      AVGVGKSAL       0.00      0.00      0.00         1
      CINGVCWTV       0.00      0.00      0.00         1
      CLGGLLTMV       0.00      0.00      0.00         1
      CVNGSCFTV       0.00      0.00      0.00         1
     EAAGIGILTV       0.00      0.00      0.00         1
     EFFWDANDIY       0.00      0.00      0.00         1
     ELAGIGILTV       0.00      0.00      0.00         2
     EMLFSHGLVK       0.00      0.00      0.00         1
    EPLPQGQLTAY       0.00      0.00      0.00         1
      FGDHPGHSY       0.00      0.00      0.00         1
       FLKEKGGL       0.00      0.00      0.00         1
     GADGVGKSAL       0.00      0.00      0.00         1
GELIGILNAAKVPAD       0.00      0.00      0.00         1
      GILGFVFTL       0.63      0.87      0.73        30
      GLCTLVAML       0.00    

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [71]:
param_grid = {
    'n_estimators': [5, 10, 25, 100, 200],
    'max_depth': [None,2,5,10, 15, 20],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4],
    'max_features': ['sqrt', 'log2']
}


rf = RandomForestClassifier()
stratified_kfold = StratifiedKFold(n_splits=5)

#stratified k-fold cross-validation 
#aims to ensure that each fold is representative of all strata of the data 
grid_search = GridSearchCV(estimator=rf, param_grid=param_grid, cv=stratified_kfold)
grid_search.fit(X_train, y_train)


best_params = grid_search.best_params_
best_score = grid_search.best_score_


rf_best = RandomForestClassifier(**best_params)
rf_best.fit(X_train, y_train)


y_pred = rf_best.predict(X_test)
classification_rep = classification_report(y_test, y_pred)
accuracy = accuracy_score(y_test, y_pred)

print("Best Parameters:", best_params)
print("Best Score:", best_score)
print("Classification Report:\n", classification_rep)
print("Accuracy:", accuracy)




Best Parameters: {'max_depth': None, 'max_features': 'log2', 'min_samples_leaf': 1, 'min_samples_split': 5, 'n_estimators': 100}
Best Score: 0.5181307456588355
Classification Report:
                  precision    recall  f1-score   support

      ALGIGILTV       0.00      0.00      0.00         1
      ALYGFVPVL       0.00      0.00      0.00         1
      AVGVGKSAL       0.00      0.00      0.00         1
      CINGVCWTV       0.00      0.00      0.00         1
      CLGGLLTMV       0.00      0.00      0.00         1
      CVNGSCFTV       0.00      0.00      0.00         1
     EAAGIGILTV       0.00      0.00      0.00         1
     EFFWDANDIY       0.00      0.00      0.00         1
     ELAGIGILTV       0.00      0.00      0.00         2
     EMLFSHGLVK       0.00      0.00      0.00         1
    EPLPQGQLTAY       0.50      1.00      0.67         1
      FGDHPGHSY       0.00      0.00      0.00         1
       FLKEKGGL       0.00      0.00      0.00         1
     GADGVGKSAL  

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
