In [None]:
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
import ablang2

# **0. Sequence input and its format**

AbLang2 accepts as input an individual heavy-chain variable domain (VH), an individual light-chain variable domain (VL), or the paired variable domains (Fv).

Each record (antibody) must be a list with the VH as the first element and the VL as the second. If either the VH or the VL is unknown, use an empty string in its place.

An asterisk (\*) marks a masked residue. We recommend masking the residues you are interested in mutating.
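A minimal sketch, in plain Python, of how residues of interest could be masked before building a record, and how to list which positions are masked. The helper names `mask_residues` and `masked_positions` are hypothetical and not part of ablang2; positions are 0-based.

```python
def mask_residues(chain, positions):
    """Return a copy of `chain` with the residues at the given 0-based positions replaced by '*'."""
    residues = list(chain)
    for i in positions:
        residues[i] = "*"
    return "".join(residues)


def masked_positions(chain):
    """Return the 0-based indices of masked residues ('*') in a chain."""
    return [i for i, residue in enumerate(chain) if residue == "*"]


# Mask two positions in a short, made-up fragment.
fragment = mask_residues("EVQLLESGG", [3, 4])
print(fragment)                    # EVQ**ESGG
print(masked_positions(fragment))  # [3, 4]
```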

**NB:** The VH and VL sequences must be in the correct order: VH first, VL second.
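As a sanity check before calling the model, a record can be validated and joined into the `"VH|VL"` string used in the raw example in Section 1. This is a minimal sketch: `to_paired_string` is a hypothetical helper, not an ablang2 function, and it only checks the record's shape (two strings, at least one non-empty). It cannot tell whether the chains were swapped.

```python
def to_paired_string(record):
    """Validate a [VH, VL] record and join it as 'VH|VL'.

    Mirrors the f-string used in the raw example; an unknown chain is an empty string.
    """
    if not isinstance(record, (list, tuple)) or len(record) != 2:
        raise ValueError("Each record must be a list of exactly two strings: [VH, VL].")
    vh, vl = record
    if not isinstance(vh, str) or not isinstance(vl, str):
        raise TypeError("VH and VL must both be strings (use '' for an unknown chain).")
    if not vh and not vl:
        raise ValueError("At least one of VH or VL must be non-empty.")
    return f"{vh}|{vl}"


print(to_paired_string(["EVQLV", "DIQLT"]))  # EVQLV|DIQLT
print(to_paired_string(["EVQLV", ""]))       # EVQLV|
```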

In [None]:
seq1 = [
    'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS', # VH sequence
    'DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK' # VL sequence
]
seq2 = [
    'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTT',
    'PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'
]
seq3 = [
    'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS',
    '' # The VL sequence is not known, so an empty string is left instead. 
]
seq4 = [
    '',
    'DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'
]
seq5 = [
    'EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS', # (*) is used to mask certain residues
    'DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'
]

all_seqs = [seq1, seq2, seq3, seq4, seq5]
only_both_chains_seqs = [seq1, seq2, seq5]

# **1. How to use AbLang2**

AbLang2 can be downloaded and used in its raw form, as shown below. For convenience, we have also developed several "modes" for specific use cases (see Section 2).

In [None]:
# Download and initialise the model
ablang = ablang2.pretrained(model_to_use='ablang2-paired', random_init=False, ncpu=1, device='cpu')

# Tokenize input sequences
seq = f"{seq1[0]}|{seq1[1]}" # VH first, VL second, with | used to separated the two sequences 
tokenized_seq = ablang.tokenizer([seq], pad=True, w_extra_tkns=False, device="cpu")
        
# Generate rescodings
with torch.no_grad():
    rescoding = ablang.AbRep(tokenized_seq).last_hidden_states

# Generate logits/likelihoods
with torch.no_grad():
    likelihoods = ablang.AbLang(tokenized_seq)

# **2. Different modes for specific use cases**

AbLang2 already provides modes for a variety of use cases. The benefit of these modes is that they handle the extra tokens, such as the start, stop, and separation tokens, for you.

1. seqcoding: Generates sequence representations for each sequence
2. rescoding: Generates residue representations for each residue in each sequence
3. likelihood: Generates likelihoods for each amino acid at each position in each sequence
4. probability: Generates probabilities for each amino acid at each position in each sequence
5. pseudo_log_likelihood: Returns the pseudo log likelihood for a sequence (based on masking each residue one at a time)
6. confidence: Returns a fast calculation of the log likelihood for a sequence (based on a single pass with no masking)
7. restore: Restores masked residues

### **AbLang2 can also align the resulting representations using ANARCI**

This can be done for the 'rescoding', 'likelihood', and 'probability' modes by setting the argument "align=True".

**NB**: Align can only be used on input with the same format, i.e. either all heavy, all light, or all both heavy and light.

### **The align argument can also be used to restore missing regions of variable length**

For this, use "align=True" with the 'restore' mode.

In [None]:
ablang = ablang2.pretrained()

valid_modes = [
    'seqcoding', 'rescoding', 'likelihood', 'probability',
    'pseudo_log_likelihood', 'confidence', 'restore' 
]

## **seqcoding** 

Seqcoding represents each sequence as a 480-dimensional embedding. It is derived by averaging the rescoding embeddings across all positions of a given sequence, including the extra tokens.

**NB:** Seqcodings can also be derived in other ways, for example by summing, or by averaging across only part of the input such as the CDRs. For such cases, use the 'rescoding' mode described below and pool the residue embeddings as needed.
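The pooling options above can be sketched with NumPy, assuming the rescoding for one sequence is available as an array of shape `(length, 480)`. Random numbers stand in for real embeddings here. The `region` indices are hypothetical placeholders, not real CDR boundaries; in practice you would take them from a numbering scheme.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for one sequence's rescoding: 12 positions x 480 dimensions.
# (Random values for illustration only; not AbLang2 output.)
rescoding = rng.normal(size=(12, 480))

# Mean over all positions, as the seqcoding mode is described above.
seqcoding = rescoding.mean(axis=0)

# Alternative poolings: sum over all positions, or mean over a subset of positions.
sum_coding = rescoding.sum(axis=0)
region = np.arange(3, 8)  # hypothetical region of interest
region_coding = rescoding[region].mean(axis=0)

print(seqcoding.shape, sum_coding.shape, region_coding.shape)  # (480,) (480,) (480,)
```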

In [None]:
ablang(all_seqs, mode='seqcoding')

## **rescoding / likelihood / probability**

Rescoding represents each residue as a 480-dimensional embedding. Likelihood represents each residue as the predicted logits for every token in the vocabulary. Probability represents these likelihoods normalised into probabilities.
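To illustrate how logits relate to probabilities, here is a numerically stable softmax over the vocabulary axis. This is a minimal sketch that assumes softmax is the normalisation used. It is the standard choice for logits, but check the ablang2 source to confirm. The 4-token vocabulary and logit values are made up.

```python
import numpy as np


def softmax(logits, axis=-1):
    """Numerically stable softmax: normalise logits into probabilities along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Toy logits for 3 positions over a hypothetical 4-token vocabulary.
logits = np.array([[2.0, 1.0, 0.1, -1.0],
                   [0.0, 0.0, 0.0, 0.0],
                   [5.0, -2.0, 0.5, 1.0]])
probs = softmax(logits)
print(probs.sum(axis=-1))  # each row sums to 1 (up to floating-point rounding)
```

Subtracting the row maximum does not change the result but avoids overflow in `np.exp` for large logits.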

**NB:** The output includes the extra tokens (start, stop, and separation tokens), following the format "<VH_seq>|<VL_seq>". The output is therefore 5 positions longer than the VH and VL combined.
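The arithmetic behind "5 positions longer" can be checked directly. This sketch assumes, as the format string suggests, that `<` and `>` stand for the start and stop tokens around each chain and `|` for the separator. That gives 2 + 2 + 1 = 5 extra positions. The helper name `paired_layout` is hypothetical.

```python
def paired_layout(vh, vl):
    """Token layout described above: start/stop tokens around each chain, '|' between them."""
    return f"<{vh}>|<{vl}>"


vh, vl = "EVQLV", "DIQLT"
layout = paired_layout(vh, vl)
print(layout)                          # <EVQLV>|<DIQLT>
print(len(layout) - len(vh) - len(vl))  # 5
```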

**NB:** By default, the representations are derived from a single forward pass. To prevent the predicted likelihood and probability at each position from being influenced by the input residue at that position, set the "stepwise_masking" argument to True. This runs one forward pass per position, with the residue at that position masked, and is therefore much slower than a single forward pass.
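The idea behind stepwise masking can be sketched as a loop: mask position `i`, run the model, and keep only the row of logits for position `i`. This is an illustration under stated assumptions, not ablang2's implementation. `toy_logits` is a hypothetical stand-in for a forward pass, and `VOCAB` is a toy 20-letter vocabulary. The loop makes one call per position, which is why this option is slow.

```python
import numpy as np

VOCAB = list("ACDEFGHIKLMNPQRSTVWY")  # toy vocabulary: the 20 standard amino acids
MASK = "*"


def toy_logits(seq):
    """Hypothetical stand-in for a model forward pass: one row of logits per position."""
    rng = np.random.default_rng(sum(map(ord, seq)))
    return rng.normal(size=(len(seq), len(VOCAB)))


def stepwise_masked_logits(seq, predict=toy_logits):
    """Mask one position at a time and keep the prediction for that position only.

    Makes len(seq) calls to `predict`, one per position.
    """
    rows = []
    for i in range(len(seq)):
        masked = seq[:i] + MASK + seq[i + 1:]
        rows.append(predict(masked)[i])
    return np.stack(rows)


out = stepwise_masked_logits("EVQLV")
print(out.shape)  # (5, 20)
```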

In [None]:
ablang(all_seqs, mode='rescoding', stepwise_masking=False)

## **Align rescoding/likelihood/probability output**

For the 'rescoding', 'likelihood', and 'probability' modes, the output can also be aligned using the argument "align=True".

This is done using the antibody numbering tool ANARCI and requires manually installing **Pandas** and **[ANARCI](https://github.com/oxpig/ANARCI)**.

**NB**: Align can only be used on input with the same format, i.e. either all heavy, all light, or all both heavy and light.
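Conceptually, alignment places residues from different sequences at shared numbering positions and leaves gaps where a sequence has no residue. The sketch below illustrates only that idea in plain NumPy. It is not ANARCI, not ablang2's alignment code, and not the structure of `results.aligned_embeds`. The function `align_by_numbering`, the labels, and the `positions` list are all hypothetical.

```python
import numpy as np


def align_by_numbering(numberings, embeddings, positions):
    """Place per-residue vectors into a shared, ordered set of numbered positions.

    numberings: for each sequence, one numbering label per residue (e.g. '1', '3A').
    embeddings: for each sequence, an array of shape (n_residues, dim).
    positions:  the ordered list of all labels to align to.
    Positions that a sequence lacks are filled with NaN.
    """
    index = {label: j for j, label in enumerate(positions)}
    dim = embeddings[0].shape[1]
    aligned = np.full((len(embeddings), len(positions), dim), np.nan)
    for s, (labels, emb) in enumerate(zip(numberings, embeddings)):
        for r, label in enumerate(labels):
            aligned[s, index[label]] = emb[r]
    return aligned


positions = ["1", "2", "3", "3A", "4"]
numberings = [["1", "2", "3", "4"], ["1", "2", "3", "3A", "4"]]
embeddings = [np.ones((4, 2)), np.zeros((5, 2))]
aligned = align_by_numbering(numberings, embeddings, positions)
print(aligned.shape)  # (2, 5, 2); the first sequence has NaN at position '3A'
```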

In [None]:
results = ablang(only_both_chains_seqs, mode='likelihood', align=True)

print(results.number_alignment)
print(results.aligned_seqs)
print(results.aligned_embeds)

## **Pseudo log likelihood and Confidence scores**

Pseudo log likelihood and confidence are two methods for estimating the model's uncertainty about an input sequence.

- pseudo_log_likelihood: Each position is masked in turn, and the log likelihood of the original residue is computed from the model's prediction at the masked position. The final score is the average across the whole input. This is similar to the approach taken in the ESM-2 paper for calculating pseudo-perplexity [(Lin et al., 2023)](https://doi.org/10.1126/science.ade2574).

- confidence: For each position, the log likelihood of the residue is calculated without masking it. The final score is the average across the whole input.
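Both scores can be sketched as the mean log probability of the observed residues. They differ only in where the logits come from: a single unmasked pass for confidence, or one masked pass per position for pseudo log likelihood. `np.exp(-score)` then gives a (pseudo-)perplexity. This is a minimal sketch with toy probabilities, not AbLang2 output. It assumes log-softmax scoring and ignores how the library treats extra tokens.

```python
import numpy as np


def log_softmax(logits):
    """Numerically stable log-softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def mean_log_likelihood(logits, token_ids):
    """Average log probability of the observed tokens (one logits row per position)."""
    log_probs = log_softmax(logits)
    return log_probs[np.arange(len(token_ids)), token_ids].mean()


token_ids = np.array([0, 2, 1])  # observed residues as indices into a toy 3-token vocabulary
# Toy per-position distributions: from one unmasked pass, and from stepwise masked passes.
unmasked_logits = np.log(np.array([[0.7, 0.2, 0.1], [0.1, 0.1, 0.8], [0.2, 0.6, 0.2]]))
masked_logits = np.log(np.array([[0.5, 0.3, 0.2], [0.2, 0.3, 0.5], [0.3, 0.4, 0.3]]))

confidence = mean_log_likelihood(unmasked_logits, token_ids)
pll = mean_log_likelihood(masked_logits, token_ids)
print(np.exp(-confidence), np.exp(-pll))  # perplexity-style scores (lower is better)
```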

**NB:** The **confidence is fast** to compute, requiring only a single forward pass per input. **Pseudo log likelihood is slow** to calculate, requiring L forward passes per input, where L is the length of the input.

**NB:** It is recommended to use **pseudo log likelihood for final results** and **confidence for exploratory work**.

In [None]:
results = ablang(all_seqs, mode='pseudo_log_likelihood')
np.exp(-results) # convert to pseudo perplexity

In [None]:
results = ablang(all_seqs, mode='confidence')
np.exp(-results)  # convert to a perplexity-like score

## **restore**

This mode restores masked residues and, with "align=True", fragmented regions of variable length.
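A simple way to picture restoring masked residues is greedy decoding: at each `*`, take the highest-scoring residue from the model's logits. This minimal sketch assumes that behaviour for illustration only. The actual `restore` mode may work differently, and filling fragmented regions with `align=True` is not shown. `restore_masked` and the toy `VOCAB` are hypothetical.

```python
import numpy as np

VOCAB = list("ACDEFGHIKLMNPQRSTVWY")  # toy vocabulary: the 20 standard amino acids


def restore_masked(seq, logits, vocab=VOCAB, mask="*"):
    """Replace each masked position with the highest-scoring residue; keep the others."""
    restored = list(seq)
    for i, residue in enumerate(seq):
        if residue == mask:
            restored[i] = vocab[int(np.argmax(logits[i]))]
    return "".join(restored)


seq = "EV*L"
logits = np.zeros((len(seq), len(VOCAB)))
logits[2, VOCAB.index("Q")] = 5.0  # pretend the model strongly favours Q at the masked position
print(restore_masked(seq, logits))  # EVQL
```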

In [None]:
restored = ablang(only_both_chains_seqs, mode='restore')
restored

In [None]:
restored = ablang(only_both_chains_seqs, mode='restore', align=True)
restored