### Apply Binary Classification Using Trained Models

This notebook applies the trained binary classifiers to new data. The trained models are loaded from the `classifiers` folder (e.g. `./classifiers/MAIT_classifier.pth`).

##### Set the classification type (`'PUBLIC_PRIVATE'` or `'MAIT'`)

In [1]:
# Which trained classifier to apply: 'PUBLIC_PRIVATE' or 'MAIT'.
CLASSIFICATION = 'PUBLIC_PRIVATE'

In [2]:
import os

# Run from the project root so the relative "./classifiers" and "./CDR3_data"
# paths used below resolve. Only step up one level when the "classifiers"
# directory is not visible from here, so re-running this cell does not keep
# climbing the directory tree.
if not os.path.isdir("classifiers"):
    os.chdir("..")
os.getcwd()

/home/romi/projects/cvc


In [3]:
import pandas as pd
import torch

# Map each supported classification type to (checkpoint path, output column).
# An unknown value raises instead of silently falling back to the
# public/private classifier.
_CLASSIFIERS = {
    'MAIT': ("./classifiers/MAIT_classifier.pth", "MAIT_cell"),
    'PUBLIC_PRIVATE': ("./classifiers/private_public_classifier.pth", "Public_Private_Label"),
}
if CLASSIFICATION not in _CLASSIFIERS:
    raise ValueError(
        f"Unknown CLASSIFICATION {CLASSIFICATION!r}; expected one of {sorted(_CLASSIFIERS)}"
    )
CLASSIFICATION_MODEL_PATH, label = _CLASSIFIERS[CLASSIFICATION]

##### Load and prepare data

In [4]:
# Load the data to classify and the table the predictions will be attached to.
# The two paths may point to different files; when they are the same file,
# read it only once.
data_dir_to_classify = "./CDR3_data/MAIT_cell_data_embeddings_8_datasets_embeddings.csv"
df_data_dir = "./CDR3_data/MAIT_cell_data_embeddings_8_datasets_embeddings.csv"
# NOTE(review): the "_pub_priv" suffix is used even when CLASSIFICATION == 'MAIT'.
output_path = "./CDR3_data/MAIT_cell_data_embeddings_8_datasets_embeddings_pub_priv.csv"
df_for_classification = pd.read_csv(data_dir_to_classify)
if df_data_dir == data_dir_to_classify:
    # Deep copy so later edits to df_for_classification never leak into data_df.
    data_df = df_for_classification.copy()
else:
    data_df = pd.read_csv(df_data_dir)
df_for_classification

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,760,761,762,763,764,765,766,767,Sequences,MAIT_cell
0,-0.149587,-0.197047,-0.606274,0.246161,0.146642,-0.208196,0.247288,-0.079729,0.198015,-0.245093,...,-0.638298,0.011018,-0.319213,0.780746,-0.412318,0.522235,0.514297,0.207887,CASSVAGLLYEQYF,MAIT_cell
1,-0.003837,0.004939,-0.305185,-0.467737,0.554620,0.526504,0.037872,-0.058776,0.078017,-0.378901,...,0.325271,-0.620283,-0.451306,-0.204477,-0.042363,0.488556,0.090419,0.102915,CASSHPPGADLGGQPQHF,MAIT_cell
2,0.027054,0.238266,0.138490,-0.477724,0.421007,0.106373,-0.682291,-0.285417,0.483916,-0.346199,...,0.507948,-0.131854,0.230915,0.139384,-0.132491,0.416187,0.232041,-0.580679,CAWSVPPVQGDRTQHF,MAIT_cell
3,-0.149578,0.100374,-0.122322,-0.478098,0.454484,0.019130,0.048994,0.235329,0.132402,-0.505863,...,0.282709,0.410364,0.008110,-0.087051,-0.256943,0.420582,-0.317372,0.073733,CSARDLDSLTNGYTF,MAIT_cell
4,-0.540540,0.291595,-0.353379,-0.261123,0.426946,0.357447,-0.488219,-0.435053,0.388957,-0.267924,...,0.328469,-0.103942,0.323352,0.735837,0.272375,0.473237,0.894415,-0.243902,CAWSGEPSQAQYF,MAIT_cell
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5011,-0.329775,0.793744,-1.056300,0.181084,0.043187,-0.084797,0.497528,0.468565,0.060517,0.038454,...,-0.269269,-0.119622,-0.516974,0.407369,0.187545,0.344213,0.109828,0.269770,CASTGENNSPLHF,non-MAIT_cell
5012,-0.100326,0.304981,0.190182,-0.059698,0.107292,0.121101,-0.251405,-0.306332,0.125169,-0.237075,...,-0.172579,-0.399773,-0.378466,0.207969,-0.128755,0.727024,0.299680,0.027128,CASSVDWSGPGNTGELFF,non-MAIT_cell
5013,-0.056023,-0.094397,-0.433103,-0.062313,0.069507,-0.040174,0.057963,-0.075291,0.027573,-0.167322,...,-0.157909,0.064306,-0.047353,0.050179,-0.237840,0.765852,-0.496809,0.229118,CSARALAGGTNEQFF,non-MAIT_cell
5014,-0.422617,0.676036,-1.092209,0.034845,0.449366,0.100408,0.365332,0.275192,0.020950,0.030260,...,-0.000810,-0.096166,-0.774795,0.810294,-0.045002,0.009765,0.297485,0.518775,CASSFQGGDQPQHF,non-MAIT_cell


In [None]:
# Drop the leftover index column ("Unnamed: 0") if the CSV was written with its
# index. Every remaining column is fed to the model as a feature, so an extra
# index column would change the input width. errors="ignore" makes this a
# no-op when the column is absent, and the cell is safe to re-run.
df_for_classification = df_for_classification.drop(columns=["Unnamed: 0"], errors="ignore")

In [5]:
# Keep only the embedding features: drop the sequence text, the MAIT label and
# any prediction column for the current classifier (e.g. when re-classifying a
# file this notebook already wrote). Reassign instead of dropping in place so
# the cell can be re-run, and ignore missing columns so unlabeled new data
# (no "MAIT_cell" column) also works.
non_feature_cols = list(dict.fromkeys(["Sequences", "MAIT_cell", label]))
df_for_classification = df_for_classification.drop(columns=non_feature_cols, errors="ignore")
df_for_classification

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
0,-0.149587,-0.197047,-0.606274,0.246161,0.146642,-0.208196,0.247288,-0.079729,0.198015,-0.245093,...,0.221753,0.178292,-0.638298,0.011018,-0.319213,0.780746,-0.412318,0.522235,0.514297,0.207887
1,-0.003837,0.004939,-0.305185,-0.467737,0.554620,0.526504,0.037872,-0.058776,0.078017,-0.378901,...,-0.394406,-0.488040,0.325271,-0.620283,-0.451306,-0.204477,-0.042363,0.488556,0.090419,0.102915
2,0.027054,0.238266,0.138490,-0.477724,0.421007,0.106373,-0.682291,-0.285417,0.483916,-0.346199,...,-0.040932,-0.713956,0.507948,-0.131854,0.230915,0.139384,-0.132491,0.416187,0.232041,-0.580679
3,-0.149578,0.100374,-0.122322,-0.478098,0.454484,0.019130,0.048994,0.235329,0.132402,-0.505863,...,-0.385887,-0.625642,0.282709,0.410364,0.008110,-0.087051,-0.256943,0.420582,-0.317372,0.073733
4,-0.540540,0.291595,-0.353379,-0.261123,0.426946,0.357447,-0.488219,-0.435053,0.388957,-0.267924,...,0.051595,-0.938021,0.328469,-0.103942,0.323352,0.735837,0.272375,0.473237,0.894415,-0.243902
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5011,-0.329775,0.793744,-1.056300,0.181084,0.043187,-0.084797,0.497528,0.468565,0.060517,0.038454,...,0.523287,-0.355889,-0.269269,-0.119622,-0.516974,0.407369,0.187545,0.344213,0.109828,0.269770
5012,-0.100326,0.304981,0.190182,-0.059698,0.107292,0.121101,-0.251405,-0.306332,0.125169,-0.237075,...,0.001577,-0.410296,-0.172579,-0.399773,-0.378466,0.207969,-0.128755,0.727024,0.299680,0.027128
5013,-0.056023,-0.094397,-0.433103,-0.062313,0.069507,-0.040174,0.057963,-0.075291,0.027573,-0.167322,...,-0.025749,-0.564122,-0.157909,0.064306,-0.047353,0.050179,-0.237840,0.765852,-0.496809,0.229118
5014,-0.422617,0.676036,-1.092209,0.034845,0.449366,0.100408,0.365332,0.275192,0.020950,0.030260,...,0.410969,0.098768,-0.000810,-0.096166,-0.774795,0.810294,-0.045002,0.009765,0.297485,0.518775


In [6]:
# Convert the whole embedding table to a single float32 tensor of shape
# (n_rows, n_features) in one step, instead of building one tensor per row and
# stacking them (which is slower and fails on an empty table).
stack_tensor = torch.as_tensor(df_for_classification.to_numpy(), dtype=torch.float32)

##### Load the trained model

In [7]:
# defining the network
from torch import nn
from torch.nn import functional as F

class Net(nn.Module):
    """Feed-forward binary classifier: input -> 128 -> 32 -> 1 with a sigmoid output.

    The layer attribute names (fc1, fc2, fc3) must match the keys stored in the
    saved checkpoints, so they are kept as-is.
    """

    def __init__(self, input_shape):
        # input_shape: number of input features per row (an int, despite the name).
        super().__init__()
        self.fc1 = nn.Linear(input_shape, 128)
        self.fc2 = nn.Linear(128, 32)
        self.fc3 = nn.Linear(32, 1)

    def forward(self, x):
        """Return one score in [0, 1] per input row, shaped (n, 1)."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return torch.sigmoid(self.fc3(hidden))

In [8]:
# Load the saved checkpoint for the selected classifier.
# map_location="cpu" lets a checkpoint saved on a GPU be loaded on a CPU-only
# machine; this notebook never moves the model or inputs to another device.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
PATH = CLASSIFICATION_MODEL_PATH
checkpoint = torch.load(PATH, map_location="cpu")
# create model: one input per remaining feature column
model = Net(df_for_classification.shape[1])
# create optimizer: not needed for inference; it is recreated only so the
# optimizer state stored in the checkpoint can be restored alongside the weights.
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)

In [9]:
# Restore the trained weights, then the optimizer state saved with them.
model_state = checkpoint['model_state_dict']
model.load_state_dict(model_state)
optimizer_state = checkpoint['optimizer_state_dict']
optimizer.load_state_dict(optimizer_state)
# Switch to evaluation mode (Net has no dropout/batch-norm, so this is a
# safeguard); eval() returns the module, so its architecture is displayed.
model.eval()

Net(
  (fc1): Linear(in_features=768, out_features=128, bias=True)
  (fc2): Linear(in_features=128, out_features=32, bias=True)
  (fc3): Linear(in_features=32, out_features=1, bias=True)
)

##### Make predictions

In [10]:
# Decision threshold on the sigmoid output: scores above it become class 1.
PRED_THRESHOLD = 0.5

# Score every row; no_grad skips autograd bookkeeping during inference.
# Call the module itself rather than .forward() so any registered hooks run.
with torch.no_grad():
    predictions = model(stack_tensor)

# Threshold to 0./1. (same dtype as the scores), then flatten the (n, 1)
# output column into a plain list of Python ints.
binary_preds = (predictions > PRED_THRESHOLD).to(predictions.dtype)
pred_list = binary_preds.flatten().int().tolist()
pred_list

[1,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 1,
 1,
 0,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 0,
 0,
 0,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 1,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 1,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 1,
 1,
 0,
 0,
 1,
 1,
 0,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 1,
 1,
 1,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 1,
 1,
 1,
 0,
 0,
 1,
 0,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 1,
 1,


In [11]:
import warnings

# Attach the predicted class (0/1) for each row as the `label` column.
# Assumes data_df rows are in the same order as the classified rows.
# When CLASSIFICATION == 'MAIT' the label column is "MAIT_cell", which the
# input table may already carry as ground truth; warn so the overwrite is
# not silent.
if label in data_df.columns:
    warnings.warn(
        f"Column {label!r} already exists in data_df and will be overwritten "
        "with model predictions."
    )
data_df[label] = pred_list
data_df

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,761,762,763,764,765,766,767,Sequences,MAIT_cell,Public_Private_Label
0,-0.149587,-0.197047,-0.606274,0.246161,0.146642,-0.208196,0.247288,-0.079729,0.198015,-0.245093,...,0.011018,-0.319213,0.780746,-0.412318,0.522235,0.514297,0.207887,CASSVAGLLYEQYF,MAIT_cell,1
1,-0.003837,0.004939,-0.305185,-0.467737,0.554620,0.526504,0.037872,-0.058776,0.078017,-0.378901,...,-0.620283,-0.451306,-0.204477,-0.042363,0.488556,0.090419,0.102915,CASSHPPGADLGGQPQHF,MAIT_cell,0
2,0.027054,0.238266,0.138490,-0.477724,0.421007,0.106373,-0.682291,-0.285417,0.483916,-0.346199,...,-0.131854,0.230915,0.139384,-0.132491,0.416187,0.232041,-0.580679,CAWSVPPVQGDRTQHF,MAIT_cell,0
3,-0.149578,0.100374,-0.122322,-0.478098,0.454484,0.019130,0.048994,0.235329,0.132402,-0.505863,...,0.410364,0.008110,-0.087051,-0.256943,0.420582,-0.317372,0.073733,CSARDLDSLTNGYTF,MAIT_cell,0
4,-0.540540,0.291595,-0.353379,-0.261123,0.426946,0.357447,-0.488219,-0.435053,0.388957,-0.267924,...,-0.103942,0.323352,0.735837,0.272375,0.473237,0.894415,-0.243902,CAWSGEPSQAQYF,MAIT_cell,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5011,-0.329775,0.793744,-1.056300,0.181084,0.043187,-0.084797,0.497528,0.468565,0.060517,0.038454,...,-0.119622,-0.516974,0.407369,0.187545,0.344213,0.109828,0.269770,CASTGENNSPLHF,non-MAIT_cell,1
5012,-0.100326,0.304981,0.190182,-0.059698,0.107292,0.121101,-0.251405,-0.306332,0.125169,-0.237075,...,-0.399773,-0.378466,0.207969,-0.128755,0.727024,0.299680,0.027128,CASSVDWSGPGNTGELFF,non-MAIT_cell,0
5013,-0.056023,-0.094397,-0.433103,-0.062313,0.069507,-0.040174,0.057963,-0.075291,0.027573,-0.167322,...,0.064306,-0.047353,0.050179,-0.237840,0.765852,-0.496809,0.229118,CSARALAGGTNEQFF,non-MAIT_cell,1
5014,-0.422617,0.676036,-1.092209,0.034845,0.449366,0.100408,0.365332,0.275192,0.020950,0.030260,...,-0.096166,-0.774795,0.810294,-0.045002,0.009765,0.297485,0.518775,CASSFQGGDQPQHF,non-MAIT_cell,1


In [12]:
# Write the labeled table (features, sequences and the new prediction column)
# to CSV without the pandas row index.
data_df.to_csv(path_or_buf=output_path, index=False)