If you're opening this notebook on Colab, you will probably need to install 🤗 Transformers as well as some other libraries. Uncomment the following cell and run it.

In [None]:
#! pip install transformers evaluate datasets requests pandas scikit-learn

If you're opening this notebook locally, make sure your environment has the latest version of those libraries installed.

To be able to share your model with the community and generate results like the one shown in the picture below via the inference API, there are a few more steps to follow.

First, store your authentication token from the Hugging Face website (sign up here if you haven't already!), then execute the following cell and paste your token when prompted:

In [None]:
from huggingface_hub import notebook_login

notebook_login()

Then you need to install Git-LFS. Uncomment the following instructions:

In [None]:
# !apt install git-lfs

# Fine-Tuning Protein Language Models

In this notebook, we're going to do some transfer learning to fine-tune some large, pre-trained protein language models on tasks of interest. If that sentence feels a bit intimidating to you, don't panic - there's also a blog post (link when it's done!) that explains the concepts here in much more detail. We highly recommend starting there if you're a biologist who isn't familiar with the state-of-the-art in language modeling (or if you're a machine learning engineer who isn't quite sure what a protein is!)

The specific model we're going to use is ESM-2, which is the state-of-the-art protein language model at the time of writing (September 2022). The citation for this model is [Lin et al, 2022](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1).

There are several ESM-2 checkpoints with differing model sizes. Larger models will generally have better accuracy, but they require more GPU memory and will take much longer to train. The available ESM-2 checkpoints (at time of writing) are:

| Checkpoint name | Num layers | Num parameters |
|------------------------------|----|----------|
| `esm2_t48_15B_UR50D`         | 48 | 15B     |
| `esm2_t36_3B_UR50D`          | 36 | 3B      | 
| `esm2_t33_650M_UR50D`        | 33 | 650M    | 
| `esm2_t30_150M_UR50D`        | 30 | 150M    | 
| `esm2_t12_35M_UR50D`         | 12 | 35M     | 
| `esm2_t6_8M_UR50D`           | 6  | 8M      | 

Note that the larger checkpoints may be very difficult to train without a large cloud GPU like an A100 or H100, and the largest 15B parameter checkpoint will probably be impossible to train on **any** single GPU! We will use the `esm2_t12_35M_UR50D` checkpoint for this notebook, which should train on any Colab instance or modern GPU.
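To put these sizes in perspective, here is a back-of-the-envelope memory estimate. It is only a sketch: it assumes 32-bit weights and an Adam-style optimizer, which together need roughly 16 bytes per parameter (weights, gradients and two optimizer moments), and it ignores activations entirely, so real usage will be higher. The parameter counts are the nominal figures from the table, and the helper name is hypothetical.

```python
# Nominal parameter counts taken from the checkpoint table above.
CHECKPOINT_PARAMS = {
    "esm2_t48_15B_UR50D": 15e9,
    "esm2_t36_3B_UR50D": 3e9,
    "esm2_t33_650M_UR50D": 650e6,
    "esm2_t30_150M_UR50D": 150e6,
    "esm2_t12_35M_UR50D": 35e6,
    "esm2_t6_8M_UR50D": 8e6,
}

# Assumption: fp32 weights (4) + fp32 gradients (4) + two Adam moments (8).
BYTES_PER_PARAM_TRAINING = 16


def estimate_training_gb(num_params, bytes_per_param=BYTES_PER_PARAM_TRAINING):
    """Rough lower bound on training memory in GB (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


for name, num_params in CHECKPOINT_PARAMS.items():
    print(f"{name}: at least ~{estimate_training_gb(num_params):.2f} GB")
```

Under these assumptions the 15B checkpoint needs about 240 GB before counting activations, well beyond a single 80 GB accelerator, while the 35M checkpoint needs well under 1 GB.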

In [1]:
model_checkpoint = "Rocketknight1/esm2_t12_35M_UR50D"

# Sequence classification

One of the most common tasks you can perform with a language model is **sequence classification**. In sequence classification, we classify an entire protein into a category, from a list of two or more possibilities. There's no limit on the number of categories you can use, or the specific problem you choose, as long as it's something the model could in theory infer from the raw protein sequence. To keep things simple for this example, though, let's try classifying proteins by their cellular localization - given their sequence, can we predict if they're going to be found in the cytosol (the fluid inside the cell) or embedded in the cell membrane?

## Data preparation

In this section, we're going to gather some training data from open web databases. Our goal is to create a pair of lists: `sequences` and `labels`. `sequences` will be a list of protein sequences, which will just be strings like "MNKL...", where each letter represents a single amino acid in the complete protein. `labels` will be a list of the category for each sequence. The categories will just be integers, with 0 representing the first category, 1 representing the second and so on. In other words, if `sequences[i]` is a protein sequence then `labels[i]` should be its corresponding category. These will form the **training data** we're going to use to teach the model the task we want it to do.
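As a minimal sketch of the format described above, here are toy `sequences` and `labels` lists together with a sanity check. The sequences are made up rather than real proteins, and `check_training_data` is a hypothetical helper, not part of any library.

```python
# Toy data only: these short strings are made up, not real proteins.
sequences = ["MNKLAV", "MSTGEE", "MLLPVW"]
labels = [0, 0, 1]

AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")  # the 20 standard residues


def check_training_data(sequences, labels):
    """Hypothetical sanity check for the sequences/labels format."""
    if len(sequences) != len(labels):
        raise ValueError("sequences and labels must have the same length")
    for seq in sequences:
        unexpected = set(seq) - AMINO_ACIDS
        if unexpected:
            raise ValueError(f"unexpected characters {sorted(unexpected)} in {seq!r}")
    # Labels should be the integers 0..k-1 with no gaps.
    if sorted(set(labels)) != list(range(max(labels) + 1)):
        raise ValueError("labels should be consecutive integers starting at 0")
    return True


print(check_training_data(sequences, labels))
```

Real database entries occasionally contain other letters such as `X` or `U`, so you may prefer to relax the character check for your own data.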

If you're adapting this notebook for your own use, this will probably be the main section you want to change! You can do whatever you want here, as long as you create those two lists by the end of it. If you want to follow along with this example, though, first we'll need to `import requests` and set up our query to UniProt.

In [2]:
import requests

query_url ="https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Csequence%2Ccc_subcellular_location&format=tsv&query=%28%28organism_id%3A9606%29%20AND%20%28reviewed%3Atrue%29%20AND%20%28length%3A%5B100%20TO%201000%5D%29%29"

This query URL might seem mysterious, but it isn't! To get it, we searched for `(organism_id:9606) AND (reviewed:true) AND (length:[100 TO 1000])` on UniProt to get a list of reasonably-sized human proteins, then selected 'Download', and set the format to TSV and the columns to `Sequence` and `Subcellular location [CC]`, since those contain the data we care about for this task.

Once that's done, selecting `Generate URL for API` gives you a URL you can pass to Requests. Alternatively, if you're not on Colab you can just download the data through the web interface and open the file locally.
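If you would rather build the URL in code than copy it from the web interface, the standard library can percent-encode the same search for you. This is a sketch: the parameter names and values are simply read off the query URL above, and `BASE_URL`, `params` and `rebuilt_url` are names chosen here for illustration.

```python
from urllib.parse import quote, urlencode

BASE_URL = "https://rest.uniprot.org/uniprotkb/stream"

# Parameters read off the query URL used in this notebook.
params = {
    "compressed": "true",
    "fields": "accession,sequence,cc_subcellular_location",
    "format": "tsv",
    "query": "((organism_id:9606) AND (reviewed:true) AND (length:[100 TO 1000]))",
}

# quote_via=quote encodes spaces as %20 (the default, quote_plus, would use "+").
rebuilt_url = f"{BASE_URL}?{urlencode(params, quote_via=quote)}"
print(rebuilt_url)
```

Keeping the search as a readable string makes it easier to adapt, for example by changing the organism ID or the length range.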

In [3]:
uniprot_request = requests.get(query_url)

To get this data into Pandas, we use a `BytesIO` object, which Pandas will treat like a file. If you downloaded the data as a file you can skip this bit and just pass the filepath directly to `read_csv`.

In [4]:
from io import BytesIO
import pandas

bio = BytesIO(uniprot_request.content)

df = pandas.read_csv(bio, compression='gzip', sep='\t')
df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
0,A0A087X1C5,MGLEALVPLAMIVAIFLLLVDLMHRHQRWAARYPPGPLPLPGLGNL...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
1,A0A0B4J2F2,MVIMSEFSADPAGQGQGQQKPLRVGFYDIERTLGKGNFAVVKLARH...,
2,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
3,A0A1B0GTW7,MLLLLLLLLLLPPLVLRVAASRCLHDETQKSVSLLRPPFSQLPSKS...,SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ...
4,A0A5B9,DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...,SUBCELLULAR LOCATION: Cell membrane {ECO:00003...
...,...,...,...
17258,Q9NZ38,MAFPGQSDTKMQWPEVPALPLLSSLCMAMVRKSSALGKEVGRRSEG...,
17259,Q9UF83,MRRPSTASLTRTPSRASPTRMPSRASLKMTPFRASLTKMESTALLR...,
17260,Q9UFV3,MAETYRRSRQHEQLPGQRHMDLLTGYSKLIQSRLKLLLHLGSQPPV...,
17261,X6R8D5,MGRKEHESPSQPHMCGWEDSQKPSVPSHGPKTPSCKGVKAPHSSRP...,


Nice! Now we have some proteins and their subcellular locations. Let's start filtering this down. First, let's ditch the rows without subcellular location information.

In [5]:
df = df.dropna()  # Drop proteins (rows) with any missing value

Now we'll make one dataframe of proteins whose subcellular location column contains `Cytosol` or `Cytoplasm`, and a second of proteins whose column contains `Membrane` or `Cell membrane`. The matching is case-sensitive, which is why both membrane spellings are searched. To avoid overlap, each dataframe keeps only the proteins that don't match the other set of search terms.

In [6]:
cytosolic = df['Subcellular location [CC]'].str.contains("Cytosol") | df['Subcellular location [CC]'].str.contains("Cytoplasm")
membrane = df['Subcellular location [CC]'].str.contains("Membrane") | df['Subcellular location [CC]'].str.contains("Cell membrane")
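To see how the two masks combine, here is a small sketch on a made-up dataframe. The entries and location strings are invented for illustration, and the variable names are hypothetical. It uses a regular expression alternation (`"A|B"`) in place of two `str.contains` calls joined with `|`, which should select the same rows.

```python
import pandas as pd

# Made-up rows that mimic the shape of the UniProt column.
toy = pd.DataFrame({
    "Entry": ["T1", "T2", "T3", "T4", "T5"],
    "Subcellular location [CC]": [
        "SUBCELLULAR LOCATION: Cytoplasm.",
        "SUBCELLULAR LOCATION: Membrane; Multi-pass membrane protein.",
        "SUBCELLULAR LOCATION: Cell membrane.",
        "SUBCELLULAR LOCATION: Cytoplasm. Cell membrane.",
        "SUBCELLULAR LOCATION: Nucleus.",
    ],
})

location = toy["Subcellular location [CC]"]
is_cytosolic = location.str.contains("Cytosol|Cytoplasm")      # case-sensitive
is_membrane = location.str.contains("Membrane|Cell membrane")  # case-sensitive

# Keep only rows that match one mask and not the other.
cytosolic_only = toy[is_cytosolic & ~is_membrane]
membrane_only = toy[is_membrane & ~is_cytosolic]

print(cytosolic_only["Entry"].tolist())
print(membrane_only["Entry"].tolist())
```

Rows that match both masks (`T4`) or neither (`T5`) end up in no dataframe, which is the intended behaviour for a clean two-class dataset.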

In [7]:
cytosolic_df = df[cytosolic & ~membrane]
cytosolic_df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
15,A0MZ66,MNSSDEEKQLQLITSLKEQAIGEYEDLRAENQKTKEKCDKIRQERD...,SUBCELLULAR LOCATION: Perikaryon {ECO:0000250|...
25,A1E959,MKIIILLGFLGATLSAPLIPQRLMSASNSNELLLNLNNGQLLPLQL...,SUBCELLULAR LOCATION: Secreted {ECO:0000250|Un...
30,A1L4K1,MEEESGEELGLDRSTPKDFHFYHMDLYDSEDRLHLFPEENTRMRKV...,SUBCELLULAR LOCATION: Nucleus {ECO:0000250|Uni...
31,A1X283,MPPRRSIVEVKVLDVQKRRVPNKHYVYIIRVTWSSGSTEAIYRRYS...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000250}....
32,A1XBS5,MMRRTLENRNAQTKQLQTAVSNVEKHFGELCQIFAAYVRKTARLRD...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...
...,...,...,...
16862,Q96M43,MVVSADPLSSERAEMNILEINQELRSQLAESNQQFRDLKEKFLITQ...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000305}.
16900,Q9BYD9,MNHCQLPVVIDNGSGMIKAGVAGCREPQFIYPNIIGRAKGQSRAAQ...,"SUBCELLULAR LOCATION: Cytoplasm, cytoskeleton ..."
16943,Q9NPB0,MEQRLAEFRAARKRAGLAAQPPAASQGAQTPGEKAEAAATLKAAPG...,SUBCELLULAR LOCATION: Cytoplasmic vesicle memb...
16955,Q9NUJ7,MGGQVSASNSFSRLHCRNANEDWMSALCPRLWDVPLHHLSIPGSHD...,SUBCELLULAR LOCATION: Cytoplasm {ECO:0000269|P...


In [8]:
membrane_df = df[membrane & ~cytosolic]
membrane_df

Unnamed: 0,Entry,Sequence,Subcellular location [CC]
2,A0A0K2S4Q6,MTQRAGAAMLPSALLLLCVPGCLTVSGPSTVMGAVGESLSVQCRYE...,SUBCELLULAR LOCATION: [Isoform 1]: Membrane {E...
3,A0A1B0GTW7,MLLLLLLLLLLPPLVLRVAASRCLHDETQKSVSLLRPPFSQLPSKS...,SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ...
4,A0A5B9,DLKNVFPPKVAVFEPSEAEISHTQKATLVCLATGFYPDHVELSWWV...,SUBCELLULAR LOCATION: Cell membrane {ECO:00003...
5,A0AV02,MTQMSQVQELFHEAAQQDALAQPQPWWKTQLFMWEPVLFGTWDGVF...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
10,A0FGR8,MTANRDAALSSHRHPGCAQRPRTPTFASSSQRRSAFGFDDGNFPGL...,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...
...,...,...,...
17131,Q6UWF5,MQIQNNLFFCCYTVMSAIFKWLLLYSLPALCFLLGTQESESFHSKA...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
17203,Q8N8V8,MLLKVRRASLKPPATPHQGAFRAGNVIGQLIYLLTWSLFTAWLRPP...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
17242,Q96N68,MQGQGALKESHIHLPTEQPEASLVLQGQLAESSALGPKGALRPQAQ...,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...
17248,Q9H0A3,MMNNTDFLMLNNPWNKLCLVSMDFCFPLDFVSNLFWIFASKFIIVT...,SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ...


We're almost done! Now, let's make a list of sequences from each df and generate the associated labels. We'll use `0` as the label for cytosolic proteins and `1` as the label for membrane proteins.

In [9]:
cytosolic_sequences = cytosolic_df["Sequence"].tolist()
cytosolic_labels = [0 for protein in cytosolic_sequences]

In [10]:
membrane_sequences = membrane_df["Sequence"].tolist()
membrane_labels = [1 for protein in membrane_sequences]

Now we can concatenate these lists together to get the `sequences` and `labels` lists that will form our final training data. Don't worry - they'll get shuffled during training!

In [11]:
sequences = cytosolic_sequences + membrane_sequences
labels = cytosolic_labels + membrane_labels

# Quick check to make sure we got it right
len(sequences) == len(labels)

True

Phew!

## Splitting the data

Since the data we're loading isn't prepared for us as a machine learning dataset, we'll have to split the data into train and test sets ourselves! We can use scikit-learn's `train_test_split` function for that:

In [12]:
from sklearn.model_selection import train_test_split

train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)
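For intuition about what a shuffled split does, here is a rough pure-Python sketch. It is a hypothetical stand-in, not scikit-learn's implementation: the helper name, the fixed seed and the choice to round the test size up are all assumptions made for this illustration.

```python
import math
import random


def simple_train_test_split(sequences, labels, test_size=0.25, seed=0):
    """Hypothetical helper: shuffle indices once, then slice off the test set."""
    indices = list(range(len(sequences)))
    random.Random(seed).shuffle(indices)  # shuffling indices keeps pairs aligned
    n_test = math.ceil(len(indices) * test_size)
    test_idx, train_idx = indices[:n_test], indices[n_test:]
    return (
        [sequences[i] for i in train_idx],
        [sequences[i] for i in test_idx],
        [labels[i] for i in train_idx],
        [labels[i] for i in test_idx],
    )


# Toy data: the label is the parity of the number embedded in each name.
toy_sequences = [f"SEQ{i}" for i in range(8)]
toy_labels = [i % 2 for i in range(8)]

train_x, test_x, train_y, test_y = simple_train_test_split(toy_sequences, toy_labels)
print(len(train_x), len(test_x))
```

In scikit-learn itself, passing `random_state` makes the split reproducible, and `stratify=labels` keeps the class proportions similar in both sets.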

## Tokenizing the data

All inputs to neural nets must be numerical. The process of converting strings into numerical indices suitable for a neural net is called **tokenization**. For natural language this can be quite complex, as usually the network's vocabulary will not contain every possible word, which means the tokenizer must handle splitting rarer words into pieces, as well as all the complexities of capitalization and unicode characters and so on.

With proteins, however, things are very easy. In protein language models, each amino acid is converted to a single token. Every model in `transformers` comes with an associated `tokenizer` that handles tokenization for it, and protein language models are no different. Let's get our tokenizer!

In [13]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

Let's try a single sequence to see what the outputs from our tokenizer look like:

In [14]:
tokenizer(train_sequences[0])

{'input_ids': [0, 20, 4, 10, 18, 12, 16, 15, 18, 8, 16, 5, 8, 8, 15, 12, 4, 15, 19, 8, 18, 14, 7, 6, 4, 10, 11, 8, 10, 11, 13, 12, 4, 8, 4, 15, 20, 8, 4, 16, 16, 17, 18, 8, 14, 23, 14, 10, 14, 22, 4, 8, 8, 8, 18, 14, 5, 19, 20, 8, 15, 11, 16, 23, 19, 21, 11, 8, 14, 23, 8, 18, 15, 15, 16, 16, 15, 16, 5, 4, 4, 5, 10, 14, 8, 8, 11, 12, 11, 19, 4, 11, 13, 8, 14, 15, 14, 5, 4, 23, 7, 11, 4, 5, 6, 4, 12, 14, 18, 7, 5, 14, 14, 4, 7, 20, 4, 20, 11, 15, 11, 19, 12, 14, 12, 4, 5, 18, 11, 16, 20, 5, 19, 6, 5, 8, 18, 4, 8, 18, 4, 6, 6, 12, 10, 22, 6, 18, 5, 4, 14, 9, 6, 8, 14, 5, 15, 14, 13, 19, 4, 17, 4, 5, 8, 8, 5, 5, 14, 4, 18, 18, 8, 22, 18, 5, 18, 4, 12, 8, 9, 10, 4, 8, 9, 5, 12, 7, 11, 7, 12, 20, 6, 20, 6, 7, 5, 18, 21, 4, 9, 4, 18, 4, 4, 14, 21, 19, 14, 17, 22, 18, 15, 5, 4, 10, 12, 7, 7, 11, 4, 4, 5, 11, 18, 8, 18, 12, 12, 11, 4, 7, 7, 15, 8, 8, 18, 14, 9, 15, 6, 21, 15, 10, 14, 6, 16, 7, 2], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

This looks good! We can see that our sequence has been converted into `input_ids`, which is the tokenized sequence, and an `attention_mask`. The attention mask handles the case when we have sequences of variable length - in those cases, the shorter sequences are padded with blank "padding" tokens, and the attention mask is padded with 0s to indicate that those tokens should be ignored by the model.
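To make the one-token-per-residue idea concrete, here is a toy tokenizer. The vocabulary below is hypothetical and the real ESM-2 ids may differ. The only detail borrowed from the output above is that the sequence is wrapped in a start token (id 0) and an end token (id 2).

```python
# Hypothetical vocabulary for illustration; the real ESM-2 ids may differ.
SPECIAL_TOKENS = ["<cls>", "<pad>", "<eos>", "<unk>"]
AMINO_ACIDS = "ACDEFGHIKLMNPQRSTVWY"
VOCAB = {token: i for i, token in enumerate(SPECIAL_TOKENS + list(AMINO_ACIDS))}


def toy_tokenize(sequence):
    """Map each residue to one id, wrapped in start and end tokens."""
    ids = [VOCAB["<cls>"]]
    ids += [VOCAB.get(residue, VOCAB["<unk>"]) for residue in sequence]
    ids.append(VOCAB["<eos>"])
    # A single unpadded sequence is all real tokens, so the mask is all 1s.
    return {"input_ids": ids, "attention_mask": [1] * len(ids)}


encoded = toy_tokenize("MNKL")
print(encoded)
```

A sequence of N residues therefore produces N + 2 ids, which is worth remembering when reasoning about maximum sequence lengths.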

So now, let's tokenize our whole dataset. Note that we don't need to do anything with the labels, as they're already in the format we need.

In [15]:
train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)

## Dataset creation

Now we want to turn this data into a dataset that PyTorch can load samples from. We can use the Hugging Face `Dataset` class from the `datasets` library for this, although if you prefer you can also use `torch.utils.data.Dataset`, at the cost of some more boilerplate code.

In [16]:
from datasets import Dataset
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 5652
})

This looks good, but we're missing our labels! Let's add those on as an extra column to the datasets.

In [17]:
train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)
train_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 5652
})

Looks good! We're ready for training.

## Model loading

Next, we want to load our model. Make sure to use exactly the same model as you used when loading the tokenizer, or your model might not understand the tokenization scheme you're using!

In [18]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

num_labels = max(train_labels + test_labels) + 1  # Add 1 since 0 can be a label
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)

Some weights of the model checkpoint at Rocketknight1/esm2_t12_35M_UR50D were not used when initializing EsmForSequenceClassification: ['lm_head.layer_norm.bias', 'lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.bias']
- This IS expected if you are initializing EsmForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing EsmForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at Rocketknight1/esm2_t12_35M_UR50D and are newly initialized: ['classifier.out_proj.weight', 'classifier.dense.bias', 'classifier.out_proj.bias', 'classifier.dens

These warnings are telling us that the model is discarding some weights that it used for language modelling (the `lm_head`) and adding some weights for sequence classification (the `classifier`). This is exactly what we expect when we want to fine-tune a language model on a sequence classification task!

Next, we initialize our `TrainingArguments`. These control the various training hyperparameters, and will be passed to our `Trainer`.

In [19]:
model_name = model_checkpoint.split("/")[-1]
batch_size = 8

args = TrainingArguments(
    f"{model_name}-finetuned-localization",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
)

Next, we define the metric we will use to evaluate our models and write a `compute_metrics` function. We can load this from the `evaluate` library.

In [20]:
from evaluate import load
import numpy as np

metric = load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)
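To check what `compute_metrics` is doing, the same calculation can be done by hand with NumPy. The logits and labels below are made up for illustration: each row holds one score per class, `argmax` picks the highest-scoring class, and accuracy is the fraction of predictions that match the references.

```python
import numpy as np

# Made-up model outputs: one row per protein, one column per class.
toy_logits = np.array([
    [2.0, -1.0],
    [0.1, 0.3],
    [-0.5, 1.5],
    [1.2, 0.7],
])
toy_references = np.array([0, 1, 1, 1])

# Pick the class with the highest score in each row.
toy_predictions = np.argmax(toy_logits, axis=1)

# Accuracy is the fraction of predictions that equal the reference label.
toy_accuracy = float((toy_predictions == toy_references).mean())
print(toy_predictions, toy_accuracy)
```

Here three of the four predictions are correct, so the accuracy is 0.75.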

And at last we're ready to initialize our `Trainer`:

In [21]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

Cloning https://huggingface.co/Rocketknight1/esm2_t12_35M_UR50D-finetuned-localization into local empty directory.


You might wonder why we pass along the `tokenizer` when we already preprocessed our data. This is because we will use it one last time to make all the samples we gather the same length by applying padding, which requires knowing the model's preferences regarding padding (to the left or right? with which token?). The `tokenizer` has a pad method that will do all of this right for us, and the `Trainer` will use it. You can customize this part by defining and passing your own `data_collator` which will receive samples like the dictionaries seen above and will need to return a dictionary of tensors.
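The padding step described above can be sketched in plain Python. This is an illustration only, not the `Trainer`'s actual collator: `pad_batch` is a hypothetical function, the default `pad_token_id=1` is an assumption, and a real `data_collator` would return tensors rather than lists.

```python
def pad_batch(samples, pad_token_id=1, padding_side="right"):
    """Pad a list of tokenized samples to the length of the longest one."""
    max_len = max(len(sample["input_ids"]) for sample in samples)
    batch = {"input_ids": [], "attention_mask": []}
    for sample in samples:
        n_pad = max_len - len(sample["input_ids"])
        id_pad = [pad_token_id] * n_pad
        mask_pad = [0] * n_pad  # 0 tells the model to ignore these positions
        if padding_side == "right":
            batch["input_ids"].append(sample["input_ids"] + id_pad)
            batch["attention_mask"].append(sample["attention_mask"] + mask_pad)
        else:
            batch["input_ids"].append(id_pad + sample["input_ids"])
            batch["attention_mask"].append(mask_pad + sample["attention_mask"])
    # Labels need no padding for sequence classification; pass them through.
    if "labels" in samples[0]:
        batch["labels"] = [sample["labels"] for sample in samples]
    return batch


toy_samples = [
    {"input_ids": [0, 5, 6, 2], "attention_mask": [1, 1, 1, 1], "labels": 0},
    {"input_ids": [0, 5, 2], "attention_mask": [1, 1, 1], "labels": 1},
]
toy_batch = pad_batch(toy_samples)
print(toy_batch)
```

Padding each batch only to its own longest sequence, rather than to a global maximum, avoids wasting computation on padding tokens.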

We can now finetune our model by just calling the `train` method:

In [22]:
trainer.train()

***** Running training *****
  Num examples = 5652
  Num Epochs = 3
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 2121
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter: ········
wandb: W&B syncing is set to `offline` in this directory.  Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.


Epoch,Training Loss,Validation Loss,Accuracy
1,0.2586,0.198971,0.940552
2,0.1769,0.176735,0.943737
3,0.1246,0.18033,0.947452


***** Running Evaluation *****
  Num examples = 1884
  Batch size = 8
Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-localization/checkpoint-707
Configuration saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-707/config.json
Model weights saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-707/pytorch_model.bin
tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-707/tokenizer_config.json
Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-707/special_tokens_map.json
tokenizer config file saved in esm2_t12_35M_UR50D-finetuned-localization/tokenizer_config.json
Special tokens file saved in esm2_t12_35M_UR50D-finetuned-localization/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 1884
  Batch size = 8
Saving model checkpoint to esm2_t12_35M_UR50D-finetuned-localization/checkpoint-1414
Configuration saved in esm2_t12_35M_UR50D-finetuned-localization/checkpoint-1414/config.jso

TrainOutput(global_step=2121, training_loss=0.1768098774522838, metrics={'train_runtime': 678.7337, 'train_samples_per_second': 24.982, 'train_steps_per_second': 3.125, 'total_flos': 2687138428873008.0, 'train_loss': 0.1768098774522838, 'epoch': 3.0})
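The step counts in the log above can be reproduced with a little arithmetic. This sketch assumes a single device, no gradient accumulation (the log reports 1 accumulation step), and that the final partial batch of each epoch is kept rather than dropped.

```python
import math

num_train_examples = 5652  # "Num examples" in the training log
batch_size = 8             # per-device batch size, single device
num_epochs = 3

# The last batch of each epoch is smaller, so round up.
steps_per_epoch = math.ceil(num_train_examples / batch_size)
total_steps = steps_per_epoch * num_epochs

print(steps_per_epoch, total_steps)
```

This gives 707 steps per epoch, matching the `checkpoint-707` and `checkpoint-1414` directories saved at the end of the first two epochs, and 2121 steps in total, matching `global_step=2121`.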

Nice! After three epochs we have a model accuracy of 94-95%. Note that we didn't do a lot of work to filter the training data or tune hyperparameters for this experiment, and also that we used one of the smallest ESM-2 models. With a larger starting model and more effort to ensure that the training data categories were cleanly separable, accuracy could almost certainly go a lot higher!