In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import torch
import ablang2

In [2]:
# Example inputs used throughout the notebook. Each entry is a two-element list
# of amino-acid strings; presumably ordered [heavy chain, light chain] — the
# first cell below joins the two elements with '|' before tokenizing.

# seq1: both chains at full length.
seq1 = [
    'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS',
    'DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'
]
# seq2: the seq1 chains with residues cut off — the first string lacks the
# trailing 'VTVSS' and the second lacks the leading 'DIQLTQSPLSL'.
seq2 = [
    'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTT',
    'PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'
]
# seq3: only the first chain of seq1; the second slot is an empty string.
seq3 = [
    'EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS',
    ''
]
# seq4: only the second chain of seq1; the first slot is an empty string.
seq4 = [
    '',
    'DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'
]
# seq5: the seq1 chains with some residues replaced by '*' — presumably the
# mask character that the 'restore' mode fills in (see the restore outputs).
seq5 = [
    'EVQ***SGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCAR**PGHGAAFMDVWGTGTTVTVSS',
    'DIQLTQSPLSLPVTLGQPASISCRSS*SLEASDTNIYLSWFQQRPGQSPRRLIYKI*NRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'
]

# Every example, including the single-chain ones (seq3, seq4).
all_seqs = [seq1, seq2, seq3, seq4, seq5]
# Only the examples in which both slots are non-empty; used for the
# align=True calls, which need inputs that all share the same format.
only_both_chains_seqs = [seq1, seq2, seq5]

# **1. How to use AbLang2**

AbLang2 can be downloaded and used in its raw form as seen below. For convenience, we have also developed different "modes" which can be used for specific use cases (see Section 2).

In [3]:
# Download and initialise the model (everything is kept on the CPU here)
ablang = ablang2.pretrained(
    model_to_use='ablang2-paired',
    random_init=False,
    ncpu=1,
    device='cpu',
)

# Tokenize input sequences: the two chains of seq1 are joined with '|' and
# passed as a batch containing that single paired sequence
seq = "|".join(seq1)
tokenized_seq = ablang.tokenizer(
    [seq],
    pad=True,
    w_extra_tkns=False,
    device="cpu",
)

# Run both forward passes with gradient tracking switched off
with torch.no_grad():
    # Generate rescodings
    rescoding = ablang.AbRep(tokenized_seq).last_hidden_states
    # Generate logits/likelihoods
    likelihoods = ablang.AbLang(tokenized_seq)

# **2. Different modes for specific use cases**

AbLang2 has already been implemented for a variety of different use cases. The benefit of these modes is that they handle extra tokens such as start, stop and separation tokens.

1. seqcoding: Generates sequence representations for each sequence
2. rescoding: Generates residue representations for each residue in each sequence
3. likelihood: Generates likelihoods for each amino acid at each position in each sequence
4. probability: Generates probabilities for each amino acid at each position in each sequence
5. pseudo_log_likelihood: Returns the pseudo log likelihood for a sequence (based on masking each residue one at a time)
6. confidence: Returns a fast calculation of the log likelihood for a sequence (based on a single pass with no masking)
7. restore: Restores masked residues

### **AbLang2 can also align the resulting representations using ANARCI**

This can be done for 'rescoding', 'likelihood', and 'probability'. This is done by setting the argument "align=True".

**NB**: Align can only be used on input with the same format, i.e. either all heavy, all light, or all both heavy and light.

### **The align argument can also be used to restore variable missing lengths**

For this, use "align=True" with the 'restore' mode.

In [4]:
# Re-create the model with the library defaults; the mode-based calls in the
# rest of the notebook go through this object.
ablang = ablang2.pretrained()

# Names accepted by the `mode` argument, in the order they are introduced above.
valid_modes = [
    'seqcoding',
    'rescoding',
    'likelihood',
    'probability',
    'pseudo_log_likelihood',
    'confidence',
    'restore',
]

## **seqcoding** 

The seqcodings represent each sequence as a 480-dimensional embedding. Each one is derived by averaging the rescoding embeddings of the given sequence, excluding any extra tokens.

In [5]:
# One embedding per input entry; the recorded output below has five rows,
# one for each element of all_seqs (single-chain entries included).
ablang(all_seqs, mode='seqcoding')

array([[ 0.74964871,  1.1840472 ,  1.00861522, ...,  1.15228194,
         0.85474925,  0.86772948],
       [-0.2517527 ,  0.21268045,  0.0752421 , ...,  0.18923415,
        -0.15258079, -0.08107855],
       [-0.27464662,  0.16865194,  0.08387676, ...,  0.18583655,
        -0.14459687, -0.16536156],
       [-0.19567645,  0.16862158, -0.04988689, ...,  0.10978557,
        -0.14783345, -0.10016351],
       [ 0.70623552,  1.17424116,  1.05676274, ...,  1.15829591,
         0.83368663,  0.84597218]])

## **rescoding / likelihood / probability**

The rescodings represent each residue as a 480-dimensional embedding. The likelihoods represent each residue as the predicted logits for each character in the vocabulary. The probabilities are the normalised likelihoods.

In [6]:
# Without align=True the result is displayed as a Python list of float32
# arrays (see the recorded output), one array per input entry.
ablang(all_seqs, mode='rescoding')

[array([[-0.5768882 ,  0.38245377, -0.21792021, ...,  0.01250281,
         -0.08844489, -0.32367533],
        [-0.14759329,  0.39639032, -0.38226995, ..., -0.10119925,
         -0.41469547, -0.00319326],
        [-0.15368716,  0.16587661, -0.30081886, ...,  0.02159324,
         -0.2850579 , -0.12827396],
        ...,
        [-0.1435836 ,  0.31243888, -0.30157977, ..., -0.13289277,
         -0.45353436, -0.07878845],
        [ 0.17538942,  0.24394313,  0.20141156, ...,  0.14587337,
         -0.38479012,  0.07409145],
        [-0.23031712, -0.354873  ,  0.19606796, ..., -0.12833637,
          0.3110731 , -0.3265107 ]], dtype=float32),
 array([[-0.50541353,  0.38347134, -0.10992067, ..., -0.05231511,
         -0.13636601, -0.34830102],
        [-0.06784626,  0.69349885, -0.4212396 , ..., -0.24805343,
         -0.39583787, -0.10972748],
        [-0.07713673,  0.31808612, -0.24827132, ...,  0.05780765,
         -0.24981117, -0.23789679],
        ...,
        [ 0.19134362,  0.21744648,  0.2

## **Align rescoding/likelihood/probability output**

For the 'rescoding', 'likelihood', and 'probability' modes, the output can also be aligned using the argument "align=True".

This is done using the antibody numbering tool ANARCI, and requires manually installing **Pandas** and **[ANARCI](https://github.com/oxpig/ANARCI)**.

**NB**: Align can only be used on input with the same format, i.e. either all heavy, all light, or all both heavy and light.

In [7]:
# Aligned likelihoods for the paired examples only: align=True needs inputs
# that all share the same format.
results = ablang(only_both_chains_seqs, align=True, mode='likelihood')

# Show the numbering labels, the aligned sequences and the aligned embeddings,
# each starting on its own line.
print(
    results.number_alignment,
    results.aligned_seqs,
    results.aligned_embeds,
    sep='\n',
)

['<' '1 ' '2 ' '3 ' '4 ' '5 ' '6 ' '7 ' '8 ' '9 ' '11 ' '12 ' '13 ' '14 '
 '15 ' '16 ' '17 ' '18 ' '19 ' '20 ' '21 ' '22 ' '23 ' '24 ' '25 ' '26 '
 '27 ' '28 ' '29 ' '30 ' '35 ' '36 ' '37 ' '38 ' '39 ' '40 ' '41 ' '42 '
 '43 ' '44 ' '45 ' '46 ' '47 ' '48 ' '49 ' '50 ' '51 ' '52 ' '53 ' '54 '
 '55 ' '56 ' '57 ' '58 ' '59 ' '62 ' '63 ' '64 ' '65 ' '66 ' '67 ' '68 '
 '69 ' '70 ' '71 ' '72 ' '74 ' '75 ' '76 ' '77 ' '78 ' '79 ' '80 ' '81 '
 '82 ' '83 ' '84 ' '85 ' '86 ' '87 ' '88 ' '89 ' '90 ' '91 ' '92 ' '93 '
 '94 ' '95 ' '96 ' '97 ' '98 ' '99 ' '100 ' '101 ' '102 ' '103 ' '104 '
 '105 ' '106 ' '107 ' '108 ' '109 ' '110 ' '111 ' '112A' '112 ' '113 '
 '114 ' '115 ' '116 ' '117 ' '118 ' '119 ' '120 ' '121 ' '122 ' '123 '
 '124 ' '125 ' '126 ' '127 ' '128 ' '>' '|' '<' '1 ' '2 ' '3 ' '4 ' '5 '
 '6 ' '7 ' '8 ' '9 ' '10 ' '11 ' '12 ' '13 ' '14 ' '15 ' '16 ' '17 ' '18 '
 '19 ' '20 ' '21 ' '22 ' '23 ' '24 ' '25 ' '26 ' '27 ' '28 ' '29 ' '30 '
 '31 ' '32 ' '34 ' '35 ' '36 ' '37 ' '38 ' '39 ' '40 

array([[[  9.31621552,  -3.42184424,  -3.59398293, ..., -14.73707485,
          -6.8935895 ,  -0.23662716],
        [ -3.54718328,  -5.8486681 ,  -4.02423763, ..., -12.9396677 ,
          -9.56145287,  -4.48474121],
        [-11.94997597,  -2.2455442 ,  -5.69481659, ..., -15.1963892 ,
         -17.97455025, -12.56952667],
        ...,
        [ -8.94505119,  -0.42261413,  -4.95588017, ..., -16.66817665,
         -15.22247696, -10.37267685],
        [-11.65150261,  -5.44477367,  -2.95585799, ..., -16.25555801,
          -9.75158882, -11.75897026],
        [  1.79469967,  -1.95846725,  -3.59784651, ..., -14.95585823,
          -7.47080421,  -0.95226705]],

       [[  8.55518723,  -3.83663583,  -2.33596039, ..., -13.87456799,
          -8.14840603,  -0.42472461],
        [ -4.4070158 ,  -5.53201628,  -3.69397473, ..., -12.97877884,
          -9.86258984,  -4.95414734],
        [-11.95642948,  -3.86210847,  -5.80935097, ..., -14.89213085,
         -16.94556236, -11.36959457],
        ...,


## **pseudo_log_likelihood / confidence**

The pseudo_log_likelihood and confidence modes are two methods for calculating an uncertainty score for the input sequence.

In [8]:
# One pseudo log likelihood per entry in all_seqs
results = ablang(all_seqs, mode='pseudo_log_likelihood')
# Negate, then exponentiate: converts the scores to pseudo perplexity
np.exp(np.negative(results))

array([1.995193 , 2.017602 , 2.1375413, 1.8546418, 2.0021744],
      dtype=float32)

In [9]:
# One confidence score per entry in all_seqs
results = ablang(all_seqs, mode='confidence')
# Same negate-and-exponentiate transform as applied to the
# pseudo_log_likelihood scores
np.exp(np.negative(results))

array([1.2699332, 1.1272193, 1.3212233, 1.2203734, 1.1848254],
      dtype=float32)

## **restore**

This mode can be used to restore masked residues and, with "align=True", fragmented regions.

In [10]:
# Fill in masked ('*') residues. In the recorded output the sequences come back
# wrapped as '<...>|<...>', and the truncated ends of the second pair are NOT
# extended in this mode.
restored = ablang(only_both_chains_seqs, mode='restore')
restored

array(['<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>',
       '<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTT>|<PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>',
       '<EVQLVQSGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDPPGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>'],
      dtype='<U238')

In [11]:
# Restore with alignment, which is meant to also fill in missing stretches of
# variable length (here: the truncated ends of the second pair).
# NOTE(review): this line used to carry the remark "This doesn't work yet", but
# the recorded output below shows the second pair with its termini filled in —
# confirm the current status before relying on it.
restored = ablang(only_both_chains_seqs, mode='restore', align = True)
restored

array(['<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>',
       '<EVQLLESGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDVPGHGAAFMDVWGTGTTVTVSS>|<DVVMTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>',
       '<QVQLVQSGGEVKKPGASVKVSCRASGYTFRNYGLTWVRQAPGQGLEWMGWISAYNGNTNYAQKFQGRVTLTTDTSTSTAYMELRSLRSDDTAVYFCARDPPGHGAAFMDVWGTGTTVTVSS>|<DIQLTQSPLSLPVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>'],
      dtype='<U238')