## How to Train and Predict with an AlignAIR Model in a Jupyter Environment
Introduction

This notebook demonstrates how to train and predict with an AlignAIR model in a Jupyter environment. We will cover the following steps:

1. **Training the Model**: We will train a HeavyChain AlignAIR model using a sample dataset.
2. **Saving the Model**: After training, we will save the model weights for future use.
3. **Loading Pretrained Model Weights**: We will load the saved model weights.
4. **Using the Loaded Model**: Finally, we will use the loaded model to make predictions on new sequences.



In [3]:
import os
from AlignAIR.Metadata import RandomDataConfigGenerator
from AlignAIR.Models.LightChain import LightChainAlignAIRR
import pandas as pd
from GenAIRR.data import (builtin_heavy_chain_data_config,
                          builtin_kappa_chain_data_config,
                          builtin_lambda_chain_data_config)
from AlignAIR.Data import HeavyChainDataset, LightChainDataset
from AlignAIR.Models.HeavyChain import HeavyChainAlignAIRR
from AlignAIR.Trainers import Trainer
import tensorflow as tf
import numpy as np
from AlignAIR.PostProcessing.HeuristicMatching import HeuristicReferenceMatcher

# Training The Model

In this section, we will train a HeavyChain AlignAIR model using a sample dataset. The training process involves the following steps:

1. **Dataset Preparation**: Load and prepare the training dataset.
2. **Model Initialization**: Define and initialize the model.
3. **Training**: Train the model using the prepared dataset.
4. **Saving the Model**: Save the trained model weights for future use.

We will start with an example of how a HeavyChain AlignAIR model is trained. The LightChain training process is essentially the same, with slight differences.


### Dataset Requirements
Before loading the training dataset, ensure it contains the following columns:

- **sequence**: The nucleotide sequence.
- **v_sequence_start**: Start position of the V gene segment.
- **v_sequence_end**: End position of the V gene segment.
- **d_sequence_start**: Start position of the D gene segment.
- **d_sequence_end**: End position of the D gene segment.
- **j_sequence_start**: Start position of the J gene segment.
- **j_sequence_end**: End position of the J gene segment.
- **v_call**: V gene call.
- **d_call**: D gene call.
- **j_call**: J gene call.
- **mutation_rate**: Mutation rate in the sequence.
- **indels**: Insertions and deletions in the sequence.
- **productive**: Whether the sequence is productive or not.

In [1]:
# these are the required columns in your training dataset
['sequence', 'v_sequence_start', 'v_sequence_end', 'd_sequence_start',
 'd_sequence_end', 'j_sequence_start', 'j_sequence_end', 'v_call',
 'd_call', 'j_call', 'mutation_rate', 'indels', 'productive']

['sequence',
 'v_sequence_start',
 'v_sequence_end',
 'd_sequence_start',
 'd_sequence_end',
 'j_sequence_start',
 'j_sequence_end',
 'v_call',
 'd_call',
 'j_call',
 'mutation_rate',
 'indels',
 'productive']
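
A quick pre-flight check can catch a missing column before training starts. The sketch below reads only the header row of a delimited file using the standard library. `REQUIRED_COLUMNS` mirrors the list above, while `missing_columns` is a hypothetical helper and not part of AlignAIR.

```python
import csv
import io

# Column names copied from the requirements list above.
REQUIRED_COLUMNS = [
    "sequence", "v_sequence_start", "v_sequence_end",
    "d_sequence_start", "d_sequence_end",
    "j_sequence_start", "j_sequence_end",
    "v_call", "d_call", "j_call",
    "mutation_rate", "indels", "productive",
]


def missing_columns(handle, delimiter=","):
    """Return the required columns absent from the file's header row.

    Hypothetical helper for illustration; not part of AlignAIR.
    """
    header = next(csv.reader(handle, delimiter=delimiter), [])
    present = {name.strip() for name in header}
    return [name for name in REQUIRED_COLUMNS if name not in present]


# Example with an in-memory file whose header lacks the last required column.
example = io.StringIO(",".join(REQUIRED_COLUMNS[:-1]) + "\n")
print(missing_columns(example))
```

For a TSV file, pass `delimiter="\t"`. With a file on disk, open it with `open(path, newline="")` and pass the handle.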

## Loading the Training Dataset

To begin, we need to load our training dataset into the `HeavyChainDataset` class. Follow these steps:

1. **Specify the Dataset Path**: Ensure you have the correct path to your dataset file (TSV, CSV, or FASTA format). Replace `heavy_chain_dataset_path` with the actual file path.

2. **Create Data Configuration**: Use the `builtin_heavy_chain_data_config()` function to load the built-in BCR HeavyChain configuration instance. The alleles in your dataset must match the reference alleles in this configuration.
(Alternatively, you can load a pickled custom `DataConfig` instance created with GenAIRR for your own data, species, or reference.)

3. **Instantiate HeavyChainDataset**: Create an instance of the `HeavyChainDataset` class with the dataset path and data configuration. Set `batch_read_file` to `True` for efficient handling of large datasets and define the `max_sequence_length` (e.g., 576).

This setup prepares your dataset for further analysis or model training.

In [None]:
# Load your training dataset into a HeavyChainDataset instance
heavy_chain_dataset_path = '/path/to/your/dataset.csv'  # replace with your path; can be TSV, CSV, or FASTA
# Make sure the DataConfig matches your dataset: the alleles in your dataset
# should share the same V, D, and J reference as the DataConfig object.
dataconfig_instance = builtin_heavy_chain_data_config()
train_dataset = HeavyChainDataset(data_path=heavy_chain_dataset_path,
                                  dataconfig=dataconfig_instance,
                                  batch_read_file=True,
                                  max_sequence_length=576)
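
Because `max_sequence_length` fixes the model's input size, it can be useful to know how many sequences in a file exceed it before training. How `HeavyChainDataset` itself treats over-length sequences is not covered here. The sketch below only counts them, and `count_too_long` is a hypothetical helper that is not part of AlignAIR.

```python
import csv
import io


def count_too_long(handle, max_sequence_length=576, delimiter=","):
    """Count rows whose 'sequence' value is longer than max_sequence_length.

    Hypothetical helper for illustration; not part of AlignAIR.
    Returns (number_of_long_sequences, total_rows).
    """
    reader = csv.DictReader(handle, delimiter=delimiter)
    too_long = 0
    total = 0
    for row in reader:
        total += 1
        if len(row["sequence"]) > max_sequence_length:
            too_long += 1
    return too_long, total


# Tiny in-memory example: one short sequence and one over-length sequence.
example = io.StringIO("sequence\n" + "ACGT" * 10 + "\n" + "A" * 600 + "\n")
print(count_too_long(example))
```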


## Setting Up the Trainer

In this section, we will set up the `Trainer` class to train our `HeavyChainAlignAIRR` model. Follow these steps:

1. **Ensure Consistency**: Make sure the `train_dataset` object was created with the `DataConfig` that matches your data. The same `DataConfig` must be used again later when the model is rebuilt for inference.

2. **Initialize Trainer**: Create an instance of the `Trainer` class with the following parameters:
   - `model`: The `HeavyChainAlignAIRR` model class (can be replaced with `LightChainAlignAIRR` or any future AlignAIR model).
   - `dataset`: The `train_dataset` object, for example a `HeavyChainDataset` or `LightChainDataset` instance.
   - `epochs`: Number of epochs (e.g., 1).
   - `steps_per_epoch`: Number of steps per epoch (e.g., 512, or a value derived from the dataset size).
   - `verbose`: Verbosity level (e.g., 1 for detailed logging).
   - `classification_metric`: List of AUC metrics (only used for logging).
   - `regression_metric`: Binary cross-entropy loss (only used for logging).
   - `optimizers_params`: Dictionary with optimizer parameters (e.g., gradient clipping).



In [None]:
# define a Trainer instance which will handle the initialization and training process of the model
trainer = Trainer(
    model=HeavyChainAlignAIRR,
    dataset=train_dataset,
    epochs=1,
    steps_per_epoch=max(1, train_dataset.data_length // 10),
    verbose=1,
    classification_metric=[tf.keras.metrics.AUC(), tf.keras.metrics.AUC(), tf.keras.metrics.AUC()],
    regression_metric=tf.keras.losses.binary_crossentropy,
    optimizers_params={"clipnorm": 1},
)
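
Two of the values in the `Trainer` call above are easy to reason about in isolation. The expression `max(1, train_dataset.data_length // 10)` guarantees at least one step per epoch even for very small datasets, and Keras optimizers document `clipnorm` as rescaling each weight's gradient so that its L2 norm does not exceed the given value. The pure-Python sketch below illustrates both. `steps_for` and `clip_by_norm` are hypothetical names used only here, and `clip_by_norm` mimics the documented behaviour rather than calling TensorFlow.

```python
import math


def steps_for(data_length, divisor=10):
    """Mirror the expression used for steps_per_epoch in the Trainer call."""
    return max(1, data_length // divisor)


def clip_by_norm(gradient, clipnorm=1.0):
    """Rescale a gradient vector so its L2 norm is at most clipnorm."""
    norm = math.sqrt(sum(g * g for g in gradient))
    if norm <= clipnorm or norm == 0.0:
        return list(gradient)
    scale = clipnorm / norm
    return [g * scale for g in gradient]


print(steps_for(5), steps_for(1000))  # a 5-row dataset still gets one step
print(clip_by_norm([3.0, 4.0]))       # norm 5.0 is scaled down to norm 1.0
print(clip_by_norm([0.3, 0.4]))       # norm 0.5 is already within the limit
```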




In [16]:
# Train the model
trainer.train()



### Saving Your Trained Model Weights for Future Use

In [None]:
trainer.model.save_weights('your/path/model_name')
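
Creating the target directory up front avoids depending on whether the save routine creates it for you. The sketch below shows one way to do that with the standard library. `prepare_checkpoint_path` and the directory names are hypothetical.

```python
import os
import tempfile


def prepare_checkpoint_path(directory, model_name):
    """Create `directory` if needed and return the checkpoint path prefix.

    Hypothetical helper for illustration; not part of AlignAIR.
    """
    os.makedirs(directory, exist_ok=True)
    return os.path.join(directory, model_name)


# Demonstration in a temporary directory so nothing is left behind.
with tempfile.TemporaryDirectory() as tmp:
    checkpoint_path = prepare_checkpoint_path(os.path.join(tmp, "checkpoints"), "model_name")
    print(os.path.isdir(os.path.dirname(checkpoint_path)))
    # In the notebook one would then call:
    # trainer.model.save_weights(checkpoint_path)
```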

## Configuring the Trainer for Inference

In this section, we will set up the `Trainer` class to configure our `HeavyChainAlignAIRR` model for inference. Follow these steps:

1. **Ensure Consistency**: Make sure the `train_dataset` object has the same `DataConfig` as during training to ensure consistency. Use a small dataset sample (e.g., 10 samples) for configuration.

2. **Initialize Trainer**: Create an instance of the `Trainer` class with the following parameters:
   - `model`: The `HeavyChainAlignAIRR` model.
   - `dataset`: The `train_dataset` object.
   - `epochs`: Number of epochs (e.g., 1).
   - `steps_per_epoch`: Number of steps per epoch (e.g., 512).
   - `verbose`: Verbosity level (e.g., 1 for detailed logging).
   - `classification_metric`: List of AUC metrics.
   - `regression_metric`: Binary cross-entropy loss.
   - `optimizers_params`: Dictionary with optimizer parameters (e.g., gradient clipping).

3. **Build the Model**: Use the `build` method to define the model architecture with the input shape (e.g., tokenized sequence of shape (576, 1)).

4. **Load Model Weights**: Load the pre-trained model weights from the specified checkpoint path (`MODEL_CHECKPOINT`).

This setup prepares the `Trainer` for configuring the model for inference with the pre-trained weights.

In [None]:
# Define the trainer the same way it was defined for training: the train_dataset passed to it must have the same DataConfig associated with it.
# The dataset is not used for inference, only to define the model, so a small sample (e.g., 10 rows) loaded with the training DataConfig is enough.
trainer = Trainer(
    model=HeavyChainAlignAIRR,
    dataset=train_dataset,
    epochs=1,
    steps_per_epoch=512,
    verbose=1,
    classification_metric=[tf.keras.metrics.AUC(), tf.keras.metrics.AUC(), tf.keras.metrics.AUC()],
    regression_metric=tf.keras.losses.binary_crossentropy,
    optimizers_params={"clipnorm": 1},
)

# build the model so the trained weights can be loaded
trainer.model.build({'tokenized_sequence': (576, 1)})
MODEL_CHECKPOINT = 'your/path/model_name'
trainer.model.load_weights(MODEL_CHECKPOINT)
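
The comments above suggest building the inference-time `Trainer` from a small sample of about 10 rows. One way to produce such a file, sketched with the standard library, is to copy the header and the first few rows of the training CSV. `write_sample` and the file names are hypothetical.

```python
import csv
import itertools
import os
import tempfile


def write_sample(source_path, sample_path, n_rows=10):
    """Copy the header and the first n_rows data rows of a CSV file.

    Hypothetical helper for illustration; not part of AlignAIR.
    Returns the number of data rows written.
    """
    with open(source_path, newline="") as src, open(sample_path, "w", newline="") as dst:
        reader = csv.reader(src)
        writer = csv.writer(dst)
        header = next(reader, None)
        if header is None:
            return 0
        writer.writerow(header)
        written = 0
        for row in itertools.islice(reader, n_rows):
            writer.writerow(row)
            written += 1
        return written


# Demonstration on a throwaway 25-row file with placeholder values.
with tempfile.TemporaryDirectory() as tmp:
    full = os.path.join(tmp, "full.csv")
    with open(full, "w", newline="") as fh:
        writer = csv.writer(fh)
        writer.writerow(["sequence", "v_call"])
        writer.writerows([["ACGT", "placeholder"]] * 25)
    print(write_sample(full, os.path.join(tmp, "sample.csv")))
```

The resulting sample file would then be passed as `data_path` to `HeavyChainDataset`, together with the same `DataConfig` that was used for training.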

## Running the Prediction Pipeline

In this section, we will set up and run a prediction pipeline using the `AlignAIR` library. This pipeline processes input data, makes predictions, and performs various post-processing steps to generate final results. Follow these steps:

1. **Import Necessary Modules**: Import the required modules and classes from the `AlignAIR` library for preprocessing, model loading, batch processing, and post-processing tasks.

2. **Create Logger**: Create a logger named `PipelineLogger` to log the process. Logging helps in tracking progress and debugging issues.

3. **Instantiate PredictObject**: Create an instance of the `PredictObject` class with the necessary arguments and the logger. This object will hold all the predicted information and processed results throughout the pipeline.

4. **Define Pipeline Steps**: Define the pipeline as a list of steps, each represented by an instance of a specific class from the `AlignAIR` library. These steps include:
   - Loading configuration
   - Extracting file names
   - Counting samples
   - Loading models
   - Processing and predicting batches
   - Cleaning up raw predictions
   - Correcting segmentations
   - Applying thresholds to distill assignments
   - Aligning predicted segments with germline sequences
   - Translating alleles to IMGT format
   - Finalizing post-processing and saving results as a CSV file

5. **Execute Pipeline**: Run the pipeline by executing each step sequentially. The `execute` method of each step processes the `predict_object` and updates it with the results of that step. This ensures that the data flows through all necessary stages to produce the final output.

By following these steps, you will be able to set up and run the prediction pipeline to generate the desired results.
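
The execution pattern described above, where each step receives the shared object and returns it updated, can be illustrated without AlignAIR. The classes below (`Step`, `CountStep`, `UppercaseStep`) are toy stand-ins invented for this sketch and do not reflect the library's internal implementation.

```python
import logging


class Step:
    """Toy base class: a named unit of work that transforms a shared state."""

    def __init__(self, name, logger=None):
        self.name = name
        self.logger = logger or logging.getLogger("ToyPipeline")

    def process(self, state):
        raise NotImplementedError

    def execute(self, state):
        # Log the step name, then hand the shared state to the step's logic.
        self.logger.info("Running step: %s", self.name)
        return self.process(state)


class CountStep(Step):
    def process(self, state):
        state["count"] = len(state["sequences"])
        return state


class UppercaseStep(Step):
    def process(self, state):
        state["sequences"] = [s.upper() for s in state["sequences"]]
        return state


toy_logger = logging.getLogger("ToyPipeline")
state = {"sequences": ["acgt", "ttga"]}
for step in [CountStep("Count", toy_logger), UppercaseStep("Uppercase", toy_logger)]:
    state = step.execute(state)
print(state)
```

Keeping every step behind the same `execute` interface means steps can be reordered, swapped, or removed by editing the list alone.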

In [None]:
import argparse
"""
Here we load all the parameters needed for using the complete AlignAIR suite, including the post-processing and pre-processing steps. 
This is usually done via Docker or CLI, thus we imitate the parameters one would pass in the command line and load all of them into an argparse namespace.
"""

args = argparse.Namespace(
    mode=None,
    config_file='',  # used only in YAML-file mode; not relevant here
    model_checkpoint=r'C:\Users\tomas\Desktop\AlignAIRR\tests\AlignAIRR_S5F_OGRDB_V8_S5F_576_Balanced_V2',  # checkpoint of the trained model weights
    save_path='/Users/tomas/Downloads/',  # path for the saved results
    chain_type='heavy',  # type of chain, i.e. heavy/light
    sequences=r'C:\Users\tomas\Desktop\AlignAIRR\tests\sample_HeavyChain_dataset.csv',  # the target sequences; can be a CSV/TSV/FASTA file; CSV and TSV files must have a column called "sequence"
    lambda_data_config='D', # if custom lambda dataconfig is required else leave as "D"
    kappa_data_config='D', # if custom kappa dataconfig is required else leave as "D"
    heavy_data_config='D', # if custom heavy chain dataconfig is required else leave as "D"
    max_input_size=576, # max input size, has to match the max_size of the trained model
    batch_size=8, # the maximum number of samples per batch processed by the model
    v_allele_threshold=0.1, # the threshold for v allele call likelihood consideration
    d_allele_threshold=0.1, # the threshold for d allele call likelihood consideration
    j_allele_threshold=0.1, # the threshold for j allele call likelihood consideration
    v_cap=3, # the maximum number of v allele calls the model will select based on the likelihood predicted and the threshold
    d_cap=3, # the maximum number of d allele calls the model will select based on the likelihood predicted and the threshold
    j_cap=3, # the maximum number of j allele calls the model will select based on the likelihood predicted and the threshold
    translate_to_asc=False,  # in case ASCs were derived for the DataConfig, this will translate the ASCs to IMGT allele names
    fix_orientation=True,  # controls whether preprocessing should check for reversed sequences and orient them properly
    custom_orientation_pipeline_path=None  # if you have a custom model, you will need to create a custom orientation pipeline and specify its path here
)
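
To build intuition for the `*_allele_threshold` and `*_cap` arguments, the sketch below applies one plausible selection rule: keep alleles whose predicted likelihood is at least `threshold` times the highest likelihood, then keep at most `cap` of them. This reading is an assumption based on the name of the `MaxLikelihoodPercentageThresholdApplicationStep` used in the pipeline and may differ from what AlignAIR actually does. `select_alleles`, the allele labels, and the scores are all hypothetical.

```python
def select_alleles(likelihoods, threshold=0.1, cap=3):
    """Pick allele names from a {name: likelihood} mapping.

    Assumed rule (not confirmed against AlignAIR): keep alleles scoring at
    least `threshold` * max likelihood, ranked high to low, at most `cap`.
    """
    if not likelihoods:
        return []
    top = max(likelihoods.values())
    ranked = sorted(likelihoods.items(), key=lambda item: item[1], reverse=True)
    kept = [name for name, score in ranked if score >= threshold * top]
    return kept[:cap]


# Placeholder allele labels and made-up likelihoods, purely for illustration.
scores = {
    "allele_a": 0.80,
    "allele_b": 0.30,
    "allele_c": 0.05,
    "allele_d": 0.09,
    "allele_e": 0.20,
}
print(select_alleles(scores))
```

Under this rule, raising the threshold makes the calls stricter, while the cap bounds how many calls are returned for ambiguous sequences.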

In [29]:
from AlignAIR.PostProcessing.Steps.allele_threshold_step import MaxLikelihoodPercentageThresholdApplicationStep, \
    ConfidenceMethodThresholdApplicationStep
from AlignAIR.PostProcessing.Steps.clean_up_steps import CleanAndArrangeStep
from AlignAIR.PostProcessing.Steps.finalization_and_packaging_steps import FinalizationStep
from AlignAIR.PostProcessing.Steps.germline_alignment_steps import AlleleAlignmentStep
from AlignAIR.PostProcessing.Steps.segmentation_correction_steps import SegmentCorrectionStep
from AlignAIR.PostProcessing.Steps.translate_to_imgt_step import TranslationStep
from AlignAIR.PredictObject.PredictObject import PredictObject
from AlignAIR.Preprocessing.Steps.batch_processing_steps import BatchProcessingStep
from AlignAIR.Preprocessing.Steps.dataconfig_steps import ConfigLoadStep
from AlignAIR.Preprocessing.Steps.file_steps import FileNameExtractionStep, FileSampleCounterStep
from AlignAIR.Preprocessing.Steps.model_loading_steps import ModelLoadingStep
import logging

# create a logger to log the process
logger = logging.getLogger('PipelineLogger')
# set up the predict object; all predicted information and processed results will be saved here
predict_object = PredictObject(args, logger=logger)

# define the steps in the prediction pipeline
steps = [
    ConfigLoadStep("Load Config", logger),
    FileNameExtractionStep('Get File Name', logger),
    FileSampleCounterStep('Count Samples in File', logger),
    ModelLoadingStep('Load Models', logger),
    BatchProcessingStep("Process and Predict Batches", logger),
    CleanAndArrangeStep("Clean Up Raw Prediction", logger),
    SegmentCorrectionStep("Correct Segmentations", logger),
    MaxLikelihoodPercentageThresholdApplicationStep("Apply Max Likelihood Threshold to Distill Assignments", logger),
    AlleleAlignmentStep("Align Predicted Segments with Germline", logger),
    TranslationStep("Translate ASC's to IMGT Alleles", logger),
    FinalizationStep("Finalize Post Processing and Save Csv", logger)
]

# run the pipeline
for step in steps:
    predict_object = step.execute(predict_object)
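
`logging.getLogger('PipelineLogger')` returns a logger with no handler of its own, so under Python's default logging configuration, messages below WARNING level are not shown. If the steps log at INFO level (an assumption; the log levels used by AlignAIR are not shown here), a handler can be attached before running the pipeline. A minimal sketch using only the standard library:

```python
import logging
import sys

logger = logging.getLogger("PipelineLogger")
logger.setLevel(logging.INFO)

# Attach a single stream handler; the guard avoids duplicate output
# when the notebook cell is run more than once.
if not logger.handlers:
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s"))
    logger.addHandler(handler)

logger.info("Logger configured")
```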




In [None]:
# the raw predictions made by the model, before any post-processing, can be found here:
predict_object.results['predictions']