## How to Train and Predict with AlignAIR v2.0 Models in a Jupyter Environment

This notebook demonstrates how to train and predict with AlignAIR's unified v2.0 architecture in a Jupyter environment. We will cover the following steps:

1. **Training the Model**: We will train a SingleChainAlignAIR or MultiChainAlignAIR model using a sample dataset.
2. **Saving the Model**: After training, we will save the model weights for future use.
3. **Loading Pretrained Model Weights**: We will load the saved model weights.
4. **Using the Loaded Model**: Finally, we will use the loaded model to make predictions on new sequences.

### What's New in v2.0
- **Unified Architecture**: `SingleChainAlignAIR` and `MultiChainAlignAIR` replace chain-specific models
- **Dynamic GenAIRR Integration**: Built-in dataconfigs for major receptor types
- **Multi-Chain Support**: Native support for mixed receptor analysis
- **Streamlined API**: Simplified training and prediction workflows



In [None]:
import os
import pandas as pd
import tensorflow as tf
import numpy as np

# Import v2.0 unified architecture components
from AlignAIR.Models.SingleChainAlignAIR import SingleChainAlignAIR
from AlignAIR.Models.MultiChainAlignAIR import MultiChainAlignAIR
from AlignAIR.Data.SingleChainDataset import SingleChainDataset
from AlignAIR.Data.MultiChainDataset import MultiChainDataset
from AlignAIR.Data.MultiDataConfigContainer import MultiDataConfigContainer
from AlignAIR.Trainers import Trainer

# GenAIRR integration for dynamic data configuration
from GenAIRR.data import (
    builtin_heavy_chain_data_config,
    builtin_kappa_chain_data_config,
    builtin_lambda_chain_data_config
)

# Legacy imports for backward compatibility
from AlignAIR.Metadata import RandomDataConfigGenerator
from AlignAIR.PostProcessing.HeuristicMatching import HeuristicReferenceMatcher

# Training The Model

In this section, we will train a SingleChainAlignAIR or MultiChainAlignAIR model using the unified v2.0 architecture. The training process involves the following steps:

1. **Dataset Preparation**: Load and prepare the training dataset using the new unified dataset classes.
2. **Model Selection**: Choose between SingleChainAlignAIR (optimized for single receptor type) or MultiChainAlignAIR (supports multiple receptor types).
3. **Training**: Train the model using the prepared dataset.
4. **Saving the Model**: Save the trained model weights for future use.

### Model Selection Guide
- **SingleChainAlignAIR**: Use when training on a single receptor type (e.g., only IGH sequences)
- **MultiChainAlignAIR**: Use when training on mixed receptor types (e.g., IGH + IGK + IGL)


### Dataset Requirements
Before loading the training dataset, ensure it contains the following columns:

- **sequence**: The nucleotide sequence.
- **v_sequence_start**: Start position of the V gene segment.
- **v_sequence_end**: End position of the V gene segment.
- **d_sequence_start**: Start position of the D gene segment.
- **d_sequence_end**: End position of the D gene segment.
- **j_sequence_start**: Start position of the J gene segment.
- **j_sequence_end**: End position of the J gene segment.
- **v_call**: V gene call.
- **d_call**: D gene call.
- **j_call**: J gene call.
- **mutation_rate**: Mutation rate in the sequence.
- **indels**: Insertions and deletions in the sequence.
- **productive**: Whether the sequence is productive or not.
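
A quick pre-flight check can catch a missing column before a long training run starts. The helper below, `missing_required_columns`, is hypothetical and not part of AlignAIR. It uses only the standard library and assumes the file is CSV or TSV with a header row. FASTA inputs have no header and are not covered.

```python
import csv
import io

# Column names from the "Dataset Requirements" list.
REQUIRED_COLUMNS = [
    "sequence", "v_sequence_start", "v_sequence_end", "d_sequence_start",
    "d_sequence_end", "j_sequence_start", "j_sequence_end", "v_call",
    "d_call", "j_call", "mutation_rate", "indels", "productive",
]


def missing_required_columns(header_line, delimiter=","):
    """Return the required columns absent from a CSV/TSV header line (hypothetical helper)."""
    header = next(csv.reader(io.StringIO(header_line), delimiter=delimiter))
    present = {name.strip() for name in header}
    return [col for col in REQUIRED_COLUMNS if col not in present]
```

For a real file, pass the first line, e.g. `missing_required_columns(open(path).readline(), delimiter="\t")` for a TSV. An empty list means every required column is present.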

In [1]:
# These are the required columns for your training dataset
REQUIRED_COLUMNS = ['sequence', 'v_sequence_start', 'v_sequence_end', 'd_sequence_start',
                    'd_sequence_end', 'j_sequence_start', 'j_sequence_end', 'v_call',
                    'd_call', 'j_call', 'mutation_rate', 'indels', 'productive']
REQUIRED_COLUMNS


## Loading the Training Dataset

### Option 1: Single-Chain Training (SingleChainDataset)

To train on a single receptor type, use the `SingleChainDataset` class:

1. **Specify the Dataset Path**: Ensure you have the correct path to your dataset file (TSV, CSV, or FASTA format).
2. **Create Data Configuration**: Use built-in GenAIRR dataconfigs or load custom ones.
3. **Instantiate SingleChainDataset**: Create an instance with the dataset path and data configuration.

### Option 2: Multi-Chain Training (MultiChainDataset)

To train on multiple receptor types, use the `MultiChainDataset` class with `MultiDataConfigContainer`:

1. **Create MultiDataConfigContainer**: Combine multiple GenAIRR dataconfigs for different chain types.
2. **Instantiate MultiChainDataset**: The dataset will automatically handle mixed receptor types.
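
To decide between the two options, it helps to check how many loci a dataset actually contains. The sketch below is a hypothetical heuristic, not an AlignAIR function. It assumes IMGT-style allele names in `v_call`, where the first three characters give the locus (e.g. `IGHV1-2*02` gives `IGH`). These prefixes happen to match the keys used in the `MultiDataConfigContainer` example.

```python
from collections import Counter


def infer_loci(v_calls):
    """Count loci (first three characters of the first V call per row); hypothetical heuristic."""
    loci = Counter()
    for call in v_calls:
        call = call.strip()
        if call:
            first = call.split(",")[0].strip()  # rows may hold several comma-separated calls
            loci[first[:3]] += 1
    return loci


def suggest_dataset_class(v_calls):
    """Suggest 'SingleChainDataset' for one locus and 'MultiChainDataset' for several."""
    return "SingleChainDataset" if len(infer_loci(v_calls)) <= 1 else "MultiChainDataset"
```

For example, `suggest_dataset_class(["IGHV1-2*02", "IGKV1-5*03"])` suggests the multi-chain dataset because two loci are present.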


In [None]:
# Example 1: Single-Chain Training Dataset (e.g., Heavy Chain only)
dataset_path = '/path/to/your/dataset.csv'  # replace with your path, can be tsv, csv or fasta
dataconfig_instance = builtin_heavy_chain_data_config()  # or builtin_kappa_chain_data_config(), builtin_lambda_chain_data_config()

# For single-chain training
train_dataset = SingleChainDataset(
    data_path=dataset_path,
    dataconfig=dataconfig_instance,
    use_streaming=True,
    max_sequence_length=576
)

# Example 2: Multi-Chain Training Dataset (mixed receptor types)
# Create a multi-dataconfig container for multiple chain types
multi_dataconfig = MultiDataConfigContainer({
    'IGH': builtin_heavy_chain_data_config(),
    'IGK': builtin_kappa_chain_data_config(), 
    'IGL': builtin_lambda_chain_data_config()
})

# For multi-chain training (when dataset contains mixed receptor types)
multi_train_dataset = MultiChainDataset(
    data_path=dataset_path,
    multi_dataconfig=multi_dataconfig,
    use_streaming=True,
    max_sequence_length=576
)

# Use train_dataset for SingleChainAlignAIR or multi_train_dataset for MultiChainAlignAIR

## Setting Up the Trainer

In this section, we will set up the `Trainer` class to train our unified AlignAIR v2.0 model. Follow these steps:

1. **Choose Model Architecture**: Select between `SingleChainAlignAIR` or `MultiChainAlignAIR` based on your dataset.

2. **Initialize Trainer**: Create an instance of the `Trainer` class with the following parameters:
   - `model`: Either `SingleChainAlignAIR` (for single receptor type) or `MultiChainAlignAIR` (for multi-chain analysis)
   - `dataset`: The corresponding dataset object (`SingleChainDataset` or `MultiChainDataset`)
   - `epochs`: Number of epochs (e.g., 1)
   - `steps_per_epoch`: Number of steps per epoch (e.g., 512)
   - `verbose`: Verbosity level (e.g., 1 for detailed logging)
   - `classification_metric`: List of AUC metrics (only used for logging)
   - `regression_metric`: Binary cross-entropy loss (only used for logging)
   - `optimizers_params`: Dictionary with optimizer parameters (e.g., gradient clipping)
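
The `{"clipnorm": 1}` value presumably ends up as the Keras optimizer argument of the same name. That is an assumption about how `Trainer` uses `optimizers_params`. In Keras, `clipnorm` clips each weight's gradient individually so that its L2 norm is no higher than the given value. The plain-Python sketch below shows that rescaling on a single gradient vector, for illustration only.

```python
import math


def clip_by_norm(grad, clipnorm):
    """Rescale a gradient vector so that its L2 norm is at most `clipnorm`."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm == 0.0 or norm <= clipnorm:
        return list(grad)  # already within the limit: unchanged
    scale = clipnorm / norm  # shrink uniformly, preserving direction
    return [g * scale for g in grad]
```

For example, `clip_by_norm([3.0, 4.0], 1.0)` has input norm 5 and returns approximately `[0.6, 0.8]`. The direction is kept, but the norm drops to 1.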


In [None]:
# Option 1: Single-Chain Model Training
trainer_single = Trainer(
    model=SingleChainAlignAIR,
    dataset=train_dataset,  # SingleChainDataset instance
    epochs=1,
    steps_per_epoch=max(1, train_dataset.data_length // 10),
    verbose=1,
    classification_metric=[tf.keras.metrics.AUC(), tf.keras.metrics.AUC(), tf.keras.metrics.AUC()],
    regression_metric=tf.keras.losses.binary_crossentropy,
    optimizers_params={"clipnorm": 1},
)

# Option 2: Multi-Chain Model Training
trainer_multi = Trainer(
    model=MultiChainAlignAIR,
    dataset=multi_train_dataset,  # MultiChainDataset instance  
    epochs=1,
    steps_per_epoch=max(1, multi_train_dataset.data_length // 10),
    verbose=1,
    classification_metric=[tf.keras.metrics.AUC(), tf.keras.metrics.AUC(), tf.keras.metrics.AUC()],
    regression_metric=tf.keras.losses.binary_crossentropy,
    optimizers_params={"clipnorm": 1},
)

# Choose the appropriate trainer based on your use case
trainer = trainer_single  # or trainer_multi for multi-chain analysis



In [16]:
# Train the model
trainer.train()



### Saving Your Trained Model Weights for Future Use

In [None]:
trainer.model.save_weights('your/path/model_name')
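
Before loading later, you may want to confirm that the weights were written. The behaviour depends on the TensorFlow/Keras version. The check below assumes tf.keras 2.x, where `save_weights` with a path that has no `.h5` suffix writes a TensorFlow checkpoint. That checkpoint consists of a `<prefix>.index` file plus `<prefix>.data-XXXXX-of-YYYYY` shards. Keras 3 instead requires a `.weights.h5` file name, so this hypothetical helper does not apply there.

```python
import os


def tf_checkpoint_exists(prefix):
    """Return True if '<prefix>.index' and at least one '<prefix>.data-*' shard exist.

    Assumes the tf.keras 2.x TensorFlow checkpoint layout (hypothetical helper).
    """
    if not os.path.isfile(prefix + ".index"):
        return False
    directory, name = os.path.split(prefix)
    return any(f.startswith(name + ".data-") for f in os.listdir(directory or "."))
```

Usage: `tf_checkpoint_exists('your/path/model_name')`.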

## Configuring the Model for Inference

In this section, we will set up the `Trainer` class to configure our unified AlignAIR v2.0 model for inference. Follow these steps:

1. **Ensure Consistency**: Make sure the dataset object uses the same `DataConfig` (or `MultiDataConfigContainer`) as during training, so the rebuilt model has the same output layers as the saved weights. A small dataset sample (e.g., 10 samples) is enough for configuration.

2. **Initialize Trainer**: Create an instance of the `Trainer` class matching your training setup:
   - For single-chain: Use `SingleChainAlignAIR` with `SingleChainDataset`
   - For multi-chain: Use `MultiChainAlignAIR` with `MultiChainDataset`

3. **Build the Model**: Call the `build` method with the input shape (e.g., a tokenized sequence of shape `(576, 1)`, where 576 must match the `max_sequence_length` used in training).

4. **Load Model Weights**: Load the pre-trained weights from the checkpoint path (`MODEL_CHECKPOINT`).

After these steps, `trainer_inference.model` holds the pre-trained weights and is ready for inference.
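
The `(576, 1)` input shape presumably means one integer token per nucleotide position, padded to a fixed length. The sketch below illustrates that idea only. The `TOKENS` mapping, the pad value `0` and right-padding are all assumptions; AlignAIR's own tokenizer may use a different mapping or pad on both sides.

```python
# Hypothetical nucleotide-to-integer mapping; AlignAIR's tokenizer may differ.
TOKENS = {"A": 1, "T": 2, "G": 3, "C": 4, "N": 5}
PAD = 0
MAX_LEN = 576  # should equal the max_sequence_length used in training


def tokenize(sequence, max_len=MAX_LEN):
    """Encode a nucleotide string as a (max_len, 1) nested list, right-padded with PAD."""
    seq = sequence.upper()
    if len(seq) > max_len:
        raise ValueError(f"sequence length {len(seq)} exceeds max_len={max_len}")
    ids = [TOKENS.get(base, TOKENS["N"]) for base in seq]  # unknown bases map to N
    ids += [PAD] * (max_len - len(ids))
    return [[i] for i in ids]  # one feature per position gives shape (max_len, 1)
```

For example, `tokenize("ACGT")` returns 576 rows whose first four entries are `[1], [4], [3], [2]`, followed by padding.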

In [None]:
# Configure trainer for inference - ensure same setup as training
# Option 1: Single-Chain Model Inference
trainer_inference = Trainer(
    model=SingleChainAlignAIR,  # Match your training model
    dataset=train_dataset,  # Same dataconfig as training, small sample for configuration
    epochs=1,
    steps_per_epoch=512,
    verbose=1,
    classification_metric=[tf.keras.metrics.AUC(), tf.keras.metrics.AUC(), tf.keras.metrics.AUC()],
    regression_metric=tf.keras.losses.binary_crossentropy,
    optimizers_params={"clipnorm": 1},
)

# Option 2: Multi-Chain Model Inference (if you trained with MultiChainAlignAIR)
# trainer_inference = Trainer(
#     model=MultiChainAlignAIR,
#     dataset=multi_train_dataset,  # Same multi_dataconfig as training
#     epochs=1,
#     steps_per_epoch=512,
#     verbose=1,
#     classification_metric=[tf.keras.metrics.AUC(), tf.keras.metrics.AUC(), tf.keras.metrics.AUC()],
#     regression_metric=tf.keras.losses.binary_crossentropy,
#     optimizers_params={"clipnorm": 1},
# )

# Build the model so trained weights can be mounted
trainer_inference.model.build({'tokenized_sequence': (576, 1)})
MODEL_CHECKPOINT = 'your/path/model_name'
trainer_inference.model.load_weights(MODEL_CHECKPOINT)

## Running the Prediction Pipeline

In this section, we will set up and run a prediction pipeline using the `AlignAIR` library. This pipeline processes input data, makes predictions, and performs various post-processing steps to generate final results. Follow these steps:

1. **Import Necessary Modules**: Import the required modules and classes from the `AlignAIR` library for preprocessing, model loading, batch processing, and post-processing tasks.

2. **Create Logger**: Create a logger named `PipelineLogger` to log the process. Logging helps in tracking progress and debugging issues.

3. **Instantiate PredictObject**: Create an instance of the `PredictObject` class with the necessary arguments and the logger. This object will hold all the predicted information and processed results throughout the pipeline.

4. **Define Pipeline Steps**: Define the pipeline as a list of steps, each represented by an instance of a specific class from the `AlignAIR` library. These steps include:
   - Loading configuration
   - Extracting file names
   - Counting samples
   - Loading models
   - Processing and predicting batches
   - Cleaning up raw predictions
   - Correcting segmentations
   - Applying thresholds to distill assignments
   - Aligning predicted segments with germline sequences
   - Translating alleles to IMGT format
   - Finalizing post-processing and saving results as a CSV file

5. **Execute Pipeline**: Run the pipeline by executing each step sequentially. The `execute` method of each step processes the `predict_object` and updates it with the results of that step. This ensures that the data flows through all necessary stages to produce the final output.

By following these steps, you will be able to set up and run the prediction pipeline to generate the desired results.
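
The pipeline follows a simple pattern. Each step object exposes `execute(predict_object)`, which reads from and writes to one shared object and then returns it for the next step. The toy classes below (`ToyPredictObject`, `ToyStep`, `CountSamples`, `UppercaseSequences`, `run_pipeline`) are hypothetical stand-ins that mimic this pattern. They are not AlignAIR's implementation.

```python
from dataclasses import dataclass, field


@dataclass
class ToyPredictObject:
    """Hypothetical stand-in for PredictObject: a shared container for inputs and results."""
    sequences: list
    results: dict = field(default_factory=dict)


class ToyStep:
    """Minimal step interface: subclasses update the shared object and return it."""

    def __init__(self, name):
        self.name = name

    def execute(self, predict_object):
        raise NotImplementedError


class CountSamples(ToyStep):
    def execute(self, predict_object):
        predict_object.results["n_samples"] = len(predict_object.sequences)
        return predict_object


class UppercaseSequences(ToyStep):
    def execute(self, predict_object):
        predict_object.results["clean"] = [s.upper() for s in predict_object.sequences]
        return predict_object


def run_pipeline(steps, predict_object):
    """Run steps in order; each step's returned object is passed to the next."""
    for step in steps:
        predict_object = step.execute(predict_object)
    return predict_object
```

Because every step shares the same signature, steps can be added, removed or reordered by editing a single list.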

In [None]:
import argparse
"""
Here we load all the parameters needed for using the complete AlignAIR v2.0 suite, including the post-processing and pre-processing steps. 
This is usually done via CLI, thus we imitate the parameters one would pass in the command line and load all of them into an argparse namespace.

Note: In v2.0, --chain-type has been replaced with --genairr-dataconfig for dynamic GenAIRR integration.
"""

args = argparse.Namespace(
    mode=None,
    config_file='',  # this is for the YAML file mode and is not relevant here
    model_checkpoint='/path/to/AlignAIRR_S5F_OGRDB_V8_S5F_576_Balanced_V2',  # checkpoint of trained model weights
    save_path='/path/to/output/',  # directory for the saved results
    genairr_dataconfig='HUMAN_IGH_OGRDB',  # NEW v2.0: GenAIRR dataconfig (replaces chain_type)
    sequences='/path/to/sample_HeavyChain_dataset.csv',  # target sequences (csv/tsv/FASTA)
    max_input_size=576,  # max input size, must match the trained model
    batch_size=8,  # maximum number of samples per batch processed by the model
    v_allele_threshold=0.1,  # threshold for v allele call likelihood consideration
    d_allele_threshold=0.1,  # threshold for d allele call likelihood consideration
    j_allele_threshold=0.1,  # threshold for j allele call likelihood consideration
    v_cap=3,  # maximum number of v allele calls based on likelihood and threshold
    d_cap=3,  # maximum number of d allele calls based on likelihood and threshold
    j_cap=3,  # maximum number of j allele calls based on likelihood and threshold
    translate_to_asc=False,  # translate ASCs to IMGT allele names if ASCs were derived
    fix_orientation=True,  # check and orient reversed sequences properly
    custom_orientation_pipeline_path=None  # path to custom orientation pipeline if needed
)

# Available GenAIRR dataconfigs in v2.0:
# - 'HUMAN_IGH_OGRDB' (Heavy chain)
# - 'HUMAN_IGK_OGRDB' (Kappa light chain)  
# - 'HUMAN_IGL_OGRDB' (Lambda light chain)
# - 'HUMAN_TCRB_IMGT' (TCR Beta chain)
# - Custom dataconfig path for your own species/references
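
The `*_allele_threshold` and `*_cap` arguments control how many allele calls survive post-processing. Judging only from the step name `MaxLikelihoodPercentageThresholdApplicationStep`, a plausible reading is the following: keep alleles whose likelihood is at least `threshold` times the highest likelihood, rank them, and keep at most `cap`. The function below is a hypothetical sketch of that reading, not AlignAIR's code.

```python
def select_alleles(likelihoods, threshold=0.1, cap=3):
    """Keep alleles with likelihood >= threshold * max likelihood, best first, at most `cap`.

    `likelihoods` maps allele name -> model likelihood (hypothetical input format).
    """
    if not likelihoods:
        return []
    best = max(likelihoods.values())
    kept = [(name, p) for name, p in likelihoods.items() if p >= threshold * best]
    kept.sort(key=lambda item: item[1], reverse=True)  # highest likelihood first
    return [name for name, _ in kept[:cap]]
```

Raising the threshold keeps only alleles close to the best call. Lowering the cap limits ambiguity even when many alleles pass the threshold.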

In [29]:
from AlignAIR.PostProcessing.Steps.allele_threshold_step import (
    MaxLikelihoodPercentageThresholdApplicationStep,
    ConfidenceMethodThresholdApplicationStep,
)
from AlignAIR.PostProcessing.Steps.clean_up_steps import CleanAndArrangeStep
from AlignAIR.PostProcessing.Steps.finalization_and_packaging_steps import FinalizationStep
from AlignAIR.PostProcessing.Steps.germline_alignment_steps import AlleleAlignmentStep
from AlignAIR.PostProcessing.Steps.segmentation_correction_steps import SegmentCorrectionStep
from AlignAIR.PostProcessing.Steps.translate_to_imgt_step import TranslationStep
from AlignAIR.PredictObject.PredictObject import PredictObject
from AlignAIR.Preprocessing.Steps.batch_processing_steps import BatchProcessingStep
from AlignAIR.Preprocessing.Steps.dataconfig_steps import ConfigLoadStep
from AlignAIR.Preprocessing.Steps.file_steps import FileNameExtractionStep, FileSampleCounterStep
from AlignAIR.Preprocessing.Steps.model_loading_steps import ModelLoadingStep
import logging
from AlignAIR.Step.Step import Step

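# Create a logger to log the process; basicConfig attaches a console handler
# so that INFO-level progress messages are actually displayed
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('PipelineLogger')
Step.set_logger(logger)

# Set up the predict object; all predicted information and processed results are stored here
predict_object = PredictObject(args, logger=logger)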

# define the steps in the prediction pipeline
steps = [
    ConfigLoadStep("Load Config"),
    FileNameExtractionStep('Get File Name'),
    FileSampleCounterStep('Count Samples in File'),
    ModelLoadingStep('Load Models'),
    BatchProcessingStep("Process and Predict Batches"),
    CleanAndArrangeStep("Clean Up Raw Prediction"),
    SegmentCorrectionStep("Correct Segmentations"),
    MaxLikelihoodPercentageThresholdApplicationStep("Apply Max Likelihood Threshold to Distill Assignments"),
    AlleleAlignmentStep("Align Predicted Segments with Germline"),
    TranslationStep("Translate ASC's to IMGT Alleles"),
    FinalizationStep("Finalize Post Processing and Save Csv")
]

# Run the pipeline: each step updates predict_object and passes it to the next step
for step in steps:
    predict_object = step.execute(predict_object)



Processing V Likelihoods:   0%|          | 0/1000 [00:00<?, ?it/s]

Processing J Likelihoods:   0%|          | 0/1000 [00:00<?, ?it/s]

Processing D Likelihoods:   0%|          | 0/1000 [00:00<?, ?it/s]

Matching V Germlines:   0%|          | 0/1000 [00:00<?, ?it/s]

Matching J Germlines:   0%|          | 0/1000 [00:00<?, ?it/s]

Matching D Germlines:   0%|          | 0/1000 [00:00<?, ?it/s]

In [None]:
# raw prediction made by the model before any processing can be found here:
predict_object.results['predictions']