# CaLM
## The Codon adaptation Language Model

Implementation of the [Codon adaptation Language Model (CaLM)](https://www.biorxiv.org/content/10.1101/2022.12.15.519894v1.abstract), a large protein language model has been trained on coding DNA, instead of amino acids. Embeddings from CaLM can be used to build automatic features for downstream machine learning tasks — like predicting a protein's function, the species it belongs to, or others. This Colab Notebook enables automatic embedding of an arbitrary number of proteins using the GPUs available at Google Colaboratory.  

CaLM is a research tool developed by Carlos Outeiral, at the Oxford Protein Informatics Group.

## Set up

In [1]:
! pip3 install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
! pip3 install einops
! pip3 install rotary_embedding_torch
! pip3 install biopython
! python setup.py install

Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cpu
running install
running bdist_egg
running egg_info
writing CaLM.egg-info/PKG-INFO
writing dependency_links to CaLM.egg-info/dependency_links.txt
writing requirements to CaLM.egg-info/requires.txt
writing top-level names to CaLM.egg-info/top_level.txt
reading manifest file 'CaLM.egg-info/SOURCES.txt'
writing manifest file 'CaLM.egg-info/SOURCES.txt'
installing library code to build/bdist.linux-x86_64/egg
running install_lib
running build_py
copying calm/pretrained.py -> build/lib/calm
creating build/bdist.linux-x86_64/egg
creating build/bdist.linux-x86_64/egg/calm
copying build/lib/calm/alphabet.py -> build/bdist.linux-x86_64/egg/calm
copying build/lib/calm/multihead_attention.py -> build/bdist.linux-x86_64/egg/calm
copying build/lib/calm/pipeline.py -> build/bdist.linux-x86_64/egg/calm
copying build/lib/calm/dataset.py -> build/bdist.linux-x86_64/egg/calm
copying build/lib/calm/__init__.py -> build/bdist.

## Useful functions

In [2]:
import string


def _split_into_codons(seq):
  for i in range(0, len(seq), 3):
    yield seq[i:i + 3]
        
def split_into_codons(seq):
  """Yield successive 3-letter chunks of a string/sequence."""
  return list(_split_into_codons(seq))

def randstr(N=20):
  """Returns a random alphanumerical string."""
  return ''.join(random.choice(string.ascii_uppercase + string.digits)
    for _ in range(N))
  
def is_dna(seq):
  """Checks if a sequence matches the DNA alphabet."""
  return set(seq.upper()).issubset({'A', 'T', 'C', 'G'})

def is_rna(seq):
  """Checks if a sequence matches the RNA alphabet."""
  return set(seq.upper()).issubset({'A', 'U', 'C', 'G'})

def check_valid_codons(seq):
  """Checks if a sequence has a number of symbols
  that is a multiple of three."""
  return len(seq) % 3 == 0

def check_start_codon(seq):
  """Checks that the start codon is correct."""
  if is_dna(seq):
    return seq[:3] == 'ATG'
  elif is_rna(seq):
    return seq[:3] == 'AUG'
  else:
    return False

def check_stop_codon(seq):
  """Checks that the stop codon is correct."""
  if is_dna(seq):
    return seq[-3:] in ['TGA', 'TAG', 'TAA']
  elif is_rna(seq):
    return seq[-3:] in ['UGA', 'UAG', 'UAA']
  else:
    return False

def check_no_interstitial_stop_codons(seq):
  codons = split_into_codons(seq)
  coding_positions = set(codons[:-1])
  if is_dna(seq):
    return {'TGA', 'TAG', 'TAA'} not in coding_positions
  elif is_rna(seq):
    return {'UGA', 'UAG', 'UAA'} not in coding_positions
  else:
    return False

## Embed a sequence

The following code exemplifies how to embed a sequence using CaLM. In order to embed a specific cDNA sequence, just modify the value of the `cDNA_sequence` variable.

In [3]:
cDNA_sequence = 'ATGGTATAGAGGCATTGA'
alphabet = 'DNA'

In [4]:
if cDNA_sequence == '':
  raise ValueError('No cDNA sequence was provided.')
if alphabet == 'DNA':
  if not is_dna(cDNA_sequence):
    raise ValueError('Provided DNA sequence did not pass sanity check.')
elif alphabet == 'RNA':
  if not is_rna(cDNA_sequence):
    raise ValueError('Provided RNA sequence did not pass sanity check.')
if not check_valid_codons(cDNA_sequence):
  raise ValueError('The number of nucleotides is not a multiple of three.')
if not check_start_codon(cDNA_sequence):
  raise ValueError('Invalid start codon.')
if not check_stop_codon(cDNA_sequence):
  raise ValueError('Invalid stop codon.')
if not check_no_interstitial_stop_codons(cDNA_sequence):
  raise ValueError('Provided sequence contains interstitial stop codons.')

In the paper we have used sequence representations, where the representations of all codons in a protein are averaged. This is the default output of CaLM:

In [7]:
from calm import CaLM

model = CaLM()
vector = model.embed_sequence(cDNA_sequence)
print('Embeddings shape: ', vector.shape)
print(vector)

Embeddings shape:  torch.Size([1, 768])
tensor([[-6.3635e-02,  9.2567e-02,  3.2795e-01, -9.2734e-02,  1.4253e-01,
          4.5020e-01,  3.4477e-02,  2.9611e-01, -2.6117e-01, -2.3144e-01,
          8.8645e-02, -9.5536e-03, -1.6554e-01, -5.2163e-03, -1.2918e-01,
          5.2430e-01,  1.1867e-01,  2.6617e-02,  3.1034e-02, -3.2623e-01,
          3.2302e-01,  7.4128e-01,  2.0493e-01,  3.0914e-01,  5.4850e-02,
         -5.8367e-02,  2.8206e-01,  9.0826e-02, -3.7607e-01, -5.7403e-02,
         -2.5932e-01, -3.9007e-02,  7.0167e-02, -4.0534e-02,  2.1532e-02,
          2.3047e-01,  8.3513e-02, -3.9460e-02,  2.4399e-02, -1.3421e-02,
         -2.1616e-02,  1.2546e-03, -5.6222e-02, -1.9381e-01, -5.3109e-02,
          5.6924e-02, -3.8164e-02,  9.5740e-02, -2.2428e-01,  6.0630e-02,
         -4.6556e-01, -1.8238e-01, -3.3394e-02,  1.6678e-02,  1.0147e-01,
          4.5168e-01, -4.5880e-02,  1.6808e-01, -8.8289e-02,  3.2918e-02,
         -4.5759e-02,  1.2530e-01, -1.2781e-01,  1.8653e-01,  1.9146e-01

It is also possible to obtain the raw, unaveraged embeddings:

In [6]:
vector = model.embed_sequence(cDNA_sequence, average=False)
print('Embeddings shape: ', vector.shape)
print(vector)

Embeddings shape:  torch.Size([1, 8, 768])
tensor([[[-0.0320, -0.0483,  0.5161,  ...,  0.5614, -0.1594,  0.4159],
         [ 0.0960,  0.2022,  0.1533,  ...,  0.4696,  0.0534, -0.1295],
         [-0.0489,  0.1945,  0.2206,  ...,  0.4299, -0.3417,  0.0037],
         ...,
         [-0.1136, -0.1073,  0.2894,  ...,  0.4902, -0.0640,  0.0426],
         [-0.2687,  0.1268,  0.3846,  ...,  0.2381, -0.1316,  0.3572],
         [ 0.1666, -0.0149,  0.4737,  ...,  0.2082, -0.1321,  0.4308]]],
       grad_fn=<TransposeBackward0>)
