# **Running TCRLang-paired.**

This simply involves using the TCRLang paired weights. All that needs to be changed is the model_to_use when the model is first initialised!

All functionality should be the same **except for the "align = True"** mode. We are currently working on this issue.

In [1]:
import numpy as np
import torch
import ablang2

# **0. Sequence input and its format**

This takes as input either the individual beta variable domain (TRB), alpha variable domain (TRA), or the paired TCR.

Each record (antibody) needs to be a list with the TRB as the first element and the TRA as the second. If either the TRB or TRA is not known, leave an empty string.

An asterisk (\*) is used for masking. It is recommended to mask residues which you are interested in mutating.

**NB:** It is important that the TRB and TRA sequence is ordered correctly.

In [2]:
# Let's use the famous JM22 TCR to begin with

seq1 = [
    'GGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLIYYSQIVNDFQKGDIAEGYSVSREKKESFPLTVTSAQKNPTAFYLCASSIRSSYEQYFGPGTRLTVTEDLKN', # TRB sequence
    'QLLEQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVTVVTGGEVKKLKRLTFQFGDARKDSSLHITAAQPGDTGLYLCAGAGSQGNLIFGKGTKLSVKP' # TRA sequence
]
seq2 = [
    'GITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLIYYSQIVNDFQKGDIAEGYSVSREKKESFPLTVTSAQKNPTAFYLCASSIRSSYEQYFGPGTRLTVTEDLKN',
    'PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK'
]
seq3 = [
    'GGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLIYYSQIVNDFQKGDIAEGYSVSREKKESFPLTVTSAQKNPTAFYLCASSIRSSYEQYFGPGTRLTVTEDLKN',
    '' # The TRA sequence is not known, so an empty string is left instead. 
]
seq4 = [
    '',
    'QLLEQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVTVVTGGEVKKLKRLTFQFGDARKDSSLHITAAQPGDTGLYLCAGAGSQGNLIFGKGTKLSVKP'
]
seq5 = [
    'GITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLIYYSQIVNDFQKGDIAEGYSVSREKKESFPLTVTSAQKNPTAFYLCASSIRSS*EQYFGPGTRLTVTEDLKN', # (*) is used to mask certain residues
    'QLLEQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVTVVTGGEVKKLKRLTFQFGD*RKDSSLHITAAQPGDTGLYLCAG*GSQGNLIFGKGTKLSVKP'
]

all_seqs = [seq1, seq2, seq3, seq4, seq5]
only_both_chains_seqs = [seq1, seq2, seq5]

# **1. How to use TCRLang-paired**

TCRLang-paired can be downloaded and used in its raw form as seen below. For convenience, we have also developed different "modes" which can be used for specific use cases (see Section 2) 

In [3]:
# Download and initialise the model
tcrlang = ablang2.pretrained(model_to_use='tcrlang-paired', # This is all that needs to be changed.
                             random_init=False, 
                             ncpu=1, 
                             device='cpu')

# Tokenize input sequences
seq = f"{seq1[0]}|{seq1[1]}" # TRB first, TRA second, with | used to separated the two sequences 
tokenized_seq = tcrlang.tokenizer([seq], pad=True, w_extra_tkns=False, device="cpu")
        
# Generate rescodings
with torch.no_grad():
    rescoding = tcrlang.AbRep(tokenized_seq).last_hidden_states

# Generate logits/likelihoods
with torch.no_grad():
    likelihoods = tcrlang.AbLang(tokenized_seq)

# **2. Different modes for specific usecases**

ablang2 has already been implemented for a variety of different usecases. The benefit of these modes is that they handle extra tokens such as start, stop and separation tokens.

1. seqcoding: Generates sequence representations for each sequence
2. rescoding: Generates residue representations for each residue in each sequence
3. likelihood: Generates likelihoods for each amino acid at each position in each sequence
4. probability: Generates probabilities for each amino acid at each position in each sequence
5. pseudo_log_likelihood: Returns the pseudo log likelihood for a sequence (based on masking each residue one at a time)
6. confidence: Returns a fast calculation of the log likelihood for a sequence (based on a single pass with no masking)
7. restore: Restores masked residues

### **ablang2 can also align the resulting representations using ANARCI**

This can be done for 'rescoding', 'likelihood', and 'probability'. This is done by setting the argument "align=True".

**NB**: Align can only be used on input with the same format, i.e. either all beta, all alpha, or all both beta and alpha.

### **The align argument can also be used to restore variable missing lengths**

For this, use "align=True" with the 'restore' mode.

In [4]:
tcrlang = ablang2.pretrained()

valid_modes = [
    'seqcoding', 'rescoding', 'likelihood', 'probability',
    'pseudo_log_likelihood', 'confidence', 'restore' 
]

## **seqcoding** 

The seqcodings represents each sequence as a 480 sized embedding. It is derived from averaging across each rescoding embedding for a given sequence, including extra tokens. 

**NB:** Seqcodings can also be derived in other ways like using the sum or averaging across only parts of the input such as the CDRs. For such cases please use and adapt the below rescoding.

In [5]:
tcrlang(all_seqs, mode='seqcoding')

array([[-0.08125024, -0.01384698, -0.15913074, ...,  0.28860582,
        -0.12494163,  0.04056989],
       [-0.13111236,  0.03872783, -0.03324484, ...,  0.28661499,
        -0.10136888,  0.00161366],
       [-0.05338498, -0.06573963, -0.11988864, ...,  0.31983912,
        -0.07272346,  0.06720409],
       [-0.08249289,  0.08119008, -0.26181517, ...,  0.22905781,
        -0.17356422, -0.01206601],
       [-0.0763039 , -0.00637079, -0.16056736, ...,  0.29032119,
        -0.13395997,  0.03960718]])

## **rescoding / likelihood / probability**

The rescodings represents each residue as a 480 sized embedding. The likelihoods represents each residue as the predicted logits for each character in the vocabulary. The probabilities represents the normalised likelihoods.

**NB:** The output includes extra tokens (start, stop and separation tokens) in the format "<TRB_seq>|<TRA_seq>". The length of the output is therefore 5 longer than the TRB and TRA.

**NB:** By default the representations are derived using a single forward pass. To prevent the predicted likelihood and probability to be affected by the input residue at each position, setting the "stepwise_masking" argument to True can be used. This will run a forward pass for each position with the residue at that position masked. This is much slower than running a single forward pass.

In [6]:
tcrlang(all_seqs, mode='rescoding', stepwise_masking = False)

[array([[-0.51347226, -0.28576213,  0.10309762, ...,  0.5203557 ,
          0.0117211 , -0.08264501],
        [-0.19843523,  0.1543197 , -0.5031867 , ..., -0.2870497 ,
         -0.21901898,  0.36414275],
        [-0.16083395,  0.20545605, -0.21582198, ..., -0.00274766,
          0.066175  ,  0.27734903],
        ...,
        [ 0.19438688,  0.2340737 ,  0.0407359 , ...,  0.03316897,
          0.18425798,  0.14009582],
        [-0.09048033, -0.41594166,  0.3686235 , ...,  0.05291507,
         -0.13554473, -0.09198374],
        [-0.04215955, -0.4688292 , -0.04049325, ...,  0.05855337,
          0.08600137,  0.13561374]], dtype=float32),
 array([[-0.33639285,  0.06262851, -0.09385429, ...,  0.29438573,
          0.09021386, -0.03847658],
        [-0.13315135,  0.18713355,  0.07811087, ...,  0.5782139 ,
         -0.22035252,  0.03181488],
        [ 0.3239961 , -0.01685584, -0.5550718 , ...,  0.36060256,
          0.42027324,  0.03702496],
        ...,
        [-0.14616522,  0.15133138, -0.2

## **Pseudo log likelihood and Confidence scores**

The pseudo log likelihood and confidence represents two methods for calculating the uncertainty for the input sequence.

- pseudo_log_likelihood: For each position, the pseudo log likelihood is calculated when predicting the masked residue. The final score is an average across the whole input. This is similar to the approach taken in the ESM-2 paper for calculating pseudo perplexity [(Lin et al., 2023)](https://doi.org/10.1126/science.ade2574).

- confidence: For each position, the log likelihood is calculated without masking the residue. The final score is an average across the whole input. 

**NB:** The **confidence is fast** to compute, requiring only a single forward pass per input. **Pseudo log likelihood is slow** to calculate, requiring L forward passes per input, where L is the length of the input.

**NB:** It is recommended to use **pseudo log likelihood for final results** and **confidence for exploratory work**.

In [7]:
results = tcrlang(all_seqs, mode='pseudo_log_likelihood')
np.exp(-results) # convert to pseudo perplexity

array([20.41889  ,  6.8523793, 21.07971  , 20.479687 , 18.358845 ],
      dtype=float32)

In [8]:
results = tcrlang(all_seqs, mode='confidence')
np.exp(-results)

array([2.2753055, 1.528476 , 2.1577282, 2.5318768, 2.156193 ],
      dtype=float32)

## **restore**

This mode can be used to restore masked residues. 

In [9]:
restored = tcrlang(only_both_chains_seqs, mode='restore')
restored

array(['<GGITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLIYYSQIVNDFQKGDIAEGYSVSREKKESFPLTVTSAQKNPTAFYLCASSIRSSYEQYFGPGTRLTVTEDLKN>|<QLLEQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVTVVTGGEVKKLKRLTFQFGDARKDSSLHITAAQPGDTGLYLCAGAGSQGNLIFGKGTKLSVKP>',
       '<GITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLIYYSQIVNDFQKGDIAEGYSVSREKKESFPLTVTSAQKNPTAFYLCASSIRSSYEQYFGPGTRLTVTEDLKN>|<PVTLGQPASISCRSSQSLEASDTNIYLSWFQQRPGQSPRRLIYKISNRDSGVPDRFSGSGSGTHFTLRISRVEADDVAVYYCMQGTHWPPAFGQGTKVDIK>',
       '<GITQSPKYLFRKEGQNVTLSCEQNLNHDAMYWYRQDPGQGLRLIYYSQIVNDFQKGDIAEGYSVSREKKESFPLTVTSAQKNPTAFYLCASSIRSSREQYFGPGTRLTVTEDLKN>|<QLLEQSPQFLSIQEGENLTVYCNSSSVFSSLQWYRQEPGEGPVLLVTVVTGGEVKKLKRLTFQFGDSRKDSSLHITAAQPGDTGLYLCAGRGSQGNLIFGKGTKLSVKP>'],
      dtype='<U230')