# Important things to know

* Run this notebook on Kaggle with the GPU T4 x2 accelerator.
* You will need a Hugging Face account. Sign up here: https://huggingface.co
* You will need a Weights & Biases (wandb) account as well. Sign up here: https://wandb.ai/

In [1]:
!pip install transformers
!pip install evaluate
!pip install datasets
!pip install requests
!pip install pandas

Collecting evaluate
  Obtaining dependency information for evaluate from https://files.pythonhosted.org/packages/70/63/7644a1eb7b0297e585a6adec98ed9e575309bb973c33b394dae66bc35c69/evaluate-0.4.1-py3-none-any.whl.metadata
  Downloading evaluate-0.4.1-py3-none-any.whl.metadata (9.4 kB)
Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)
Installing collected packages: evaluate
Successfully installed evaluate-0.4.1


In [2]:
import torch

# Use the first GPU (index 0) if CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


Log in with your Hugging Face access token. You can find it under Profile >> Access Tokens.

In [3]:
from huggingface_hub import notebook_login

notebook_login()


Then you need to install Git-LFS.

In [4]:
!apt install git-lfs

Reading package lists... Done
Building dependency tree... Done
Reading state information... Done
git-lfs is already the newest version (3.0.2-1ubuntu0.2).
0 upgraded, 0 newly installed, 0 to remove and 46 not upgraded.


In [5]:
from transformers.utils import send_example_telemetry

send_example_telemetry("protein_language_modeling_notebook", framework="pytorch")

# Fine-Tuning Protein Language Models

In this notebook, we're going to do some transfer learning to fine-tune some large, pre-trained protein language models on tasks of interest. 

The specific model we're going to use is ESM-2, a state-of-the-art protein language model at the time of writing. The citation for this model is [Lin et al, 2022](https://www.biorxiv.org/content/10.1101/2022.07.20.500902v1).

There are several ESM-2 checkpoints with differing model sizes. Larger models will generally have better accuracy, but they require more GPU memory and will take much longer to train. The available ESM-2 checkpoints (at time of writing) are:

| Checkpoint name | Num layers | Num parameters |
|------------------------------|----|----------|
| `esm2_t48_15B_UR50D`         | 48 | 15B     |
| `esm2_t36_3B_UR50D`          | 36 | 3B      |
| `esm2_t33_650M_UR50D`        | 33 | 650M    |
| `esm2_t30_150M_UR50D`        | 30 | 150M    |
| `esm2_t12_35M_UR50D`         | 12 | 35M     |
| `esm2_t6_8M_UR50D`           | 6  | 8M      |

Note that the larger checkpoints may be very difficult to train without a large cloud GPU like an A100 or H100, and the largest 15B parameter checkpoint will probably be impossible to train on **any** single GPU! Also, note that memory usage for attention during training will scale as `O(batch_size * num_layers * seq_len^2)`, so larger models on long sequences will use quite a lot of memory! We will use the `esm2_t12_35M_UR50D` checkpoint for this notebook, which should train on any Colab instance or modern GPU.
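
To make the scaling rule above concrete, here is a minimal sketch that compares relative attention cost for a few settings. The function name `relative_attention_memory` is hypothetical, and the result is a unitless proportionality, not a byte count.

```python
def relative_attention_memory(batch_size: int, num_layers: int, seq_len: int) -> int:
    """Unitless cost proportional to batch_size * num_layers * seq_len**2."""
    return batch_size * num_layers * seq_len ** 2


# Baseline: 12 layers (as in esm2_t12_35M_UR50D), batch of 8, sequences of 256 tokens.
base = relative_attention_memory(batch_size=8, num_layers=12, seq_len=256)

# Doubling the sequence length quadruples the attention cost.
longer = relative_attention_memory(batch_size=8, num_layers=12, seq_len=512)

# Moving to a 33-layer checkpoint at the same length scales the cost by 33/12.
deeper = relative_attention_memory(batch_size=8, num_layers=33, seq_len=256)

print(longer / base)  # 4.0
print(deeper / base)  # 2.75
```

This only covers the attention term from the formula above. Weights, optimizer state and other activations add to the total, so treat it as a rough guide when choosing a checkpoint and a maximum sequence length.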

In [7]:
model_checkpoint = "facebook/esm2_t12_35M_UR50D"

# Sequence classification / Protein SubCellular Localisation

One of the most common tasks you can perform with a language model is **sequence classification**. In sequence classification, we classify an entire protein into a category, from a list of two or more possibilities. There's no limit on the number of categories you can use, or the specific problem you choose, as long as it's something the model could in theory infer from the raw protein sequence.
For our project, we will classify proteins by their subcellular localization: given a protein's sequence, can we predict whether it will be found in the chloroplast or not?

## Data preparation

In this section, we're going to gather some training data from UniProt. Our goal is to create a pair of lists: `sequences` and `labels`. `sequences` will be a list of protein sequences, which will just be strings like "MNKL...", where each letter represents a single amino acid in the complete protein. `labels` will be a list of the category for each sequence. The categories will just be integers: `1` means the protein is found in the **chloroplast**, and `0` means it is **not found in the chloroplast**.

In other words, if `sequences[i]` is a protein sequence then `labels[i]` should be its corresponding category. These will form the **training data** we're going to use to teach the model the task we want it to do.


In [8]:
import requests

query_url ="https://rest.uniprot.org/uniprotkb/stream?compressed=true&fields=accession%2Creviewed%2Cid%2Ccc_subcellular_location%2Csequence&format=tsv&query=%28A.thaliana%29+AND+%28model_organism%3A3702%29+AND+%28reviewed%3Atrue%29"

This query URL might seem mysterious, but it isn't! To get it, we searched for `(A.thaliana) AND (model_organism:3702) AND (reviewed:true)` on UniProt to get a list of reviewed *Arabidopsis thaliana* proteins,
then selected 'Download', and set the format to compressed TSV and the columns to include `Sequence` and `Subcellular location [CC]`, since those contain the data we care about for this task. The length filter is applied later, in pandas.

Once that's done, selecting `Generate URL for API` gives you a URL you can pass to Requests.
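
As an illustration, the same URL can also be assembled from its decoded parts with the standard library. This is a minimal sketch: the parameter values are simply decoded from `query_url` above, and the names `base_url`, `params` and `rebuilt_url` are introduced here for the example.

```python
from urllib.parse import urlencode

base_url = "https://rest.uniprot.org/uniprotkb/stream"

# Human-readable versions of the percent-encoded parameters in query_url.
params = {
    "compressed": "true",
    "fields": "accession,reviewed,id,cc_subcellular_location,sequence",
    "format": "tsv",
    "query": "(A.thaliana) AND (model_organism:3702) AND (reviewed:true)",
}

# urlencode percent-encodes parentheses, commas and colons, and turns spaces into '+'.
rebuilt_url = f"{base_url}?{urlencode(params)}"
print(rebuilt_url)
```

Writing the query this way makes it easier to see which columns are requested and to change the search terms without editing an encoded string by hand.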

In [9]:
uniprot_request = requests.get(query_url)

To get this data into Pandas, we use a `BytesIO` object, which Pandas will treat like a file.

In [10]:
from io import BytesIO
import pandas

bio = BytesIO(uniprot_request.content)

df = pandas.read_csv(bio, compression='gzip', sep='\t')
df

Unnamed: 0,Entry,Reviewed,Entry Name,Subcellular location [CC],Sequence
0,A0A0A7EPL0,reviewed,PIAL1_ARATH,SUBCELLULAR LOCATION: Nucleus {ECO:0000305}.,MVIPATSRFGFRAEFNTKEFQASCISLANEIDAAIGRNEVPGNIQE...
1,A0A178VEK7,reviewed,DUO1_ARATH,SUBCELLULAR LOCATION: Nucleus {ECO:0000255|PRO...,MRKMEAKKEEIKKGPWKAEEDEVLINHVKRYGPRDWSSIRSKGLLQ...
2,A0A178WF56,reviewed,CSTM3_ARATH,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...,MAQYHQQHEMKQTMAETQYVTAPPPMGYPVMMKDSPQTVQPPHEGQ...
3,A0A1I9LMX5,reviewed,PCEP9_ARATH,SUBCELLULAR LOCATION: [C-terminally encoded pe...,MKLLSITLTSIVISMVFYQTPITTEARSLRKTNDQDHFKAGFTDDF...
4,A0A1I9LN01,reviewed,LAF3_ARATH,SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ...,MTGWYEFPVMIGFVSAAVFLLISVAYLPLLNDLYWSTLKSLTPPAG...
...,...,...,...,...,...
16296,Q9ZVR0,reviewed,PP2B6_ARATH,,MGQKLGVDSRQKIRQVLGSSSKVQKHDVESIGGGGGEIVPGHSPFD...
16297,Q9ZVR1,reviewed,PP2B5_ARATH,,MGQKHGVDTRGKGAEFCGCWEILTEFINGSSASFDDLPDDCLAIIS...
16298,Q9ZVR3,reviewed,PP2B4_ARATH,,MNTQILSQKTRYSAYIVYKTIYRFHGFKHIGVGFIGHGTPKAKRWE...
16299,Q9ZW38,reviewed,FBK36_ARATH,,MASISETSDDGSNGGDPNQKPEEPHKNPQEGKEEENQNEKPKEDDH...
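
To see the `BytesIO` pattern in isolation, without contacting UniProt, here is a small self-contained sketch. The two rows in `tsv_text` are made-up placeholder data, not real UniProt entries.

```python
import gzip
from io import BytesIO

import pandas as pd

# Made-up, tab-separated data standing in for a UniProt response body.
tsv_text = "Entry\tSequence\nX00001\tMNKL\nX00002\tMAAK\n"

# The API returns gzip-compressed bytes; we mimic that here.
payload = gzip.compress(tsv_text.encode("utf-8"))

# BytesIO wraps the bytes in a file-like object that pandas can read from.
toy_df = pd.read_csv(BytesIO(payload), compression="gzip", sep="\t")
print(toy_df)
```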


Nice! Now we have some proteins and their subcellular locations. Let's start filtering this down. First, let's remove the rows (proteins) without subcellular location information.

In [11]:
df = df.dropna()  # Drop rows (proteins) that have a missing value in any column
df

Unnamed: 0,Entry,Reviewed,Entry Name,Subcellular location [CC],Sequence
0,A0A0A7EPL0,reviewed,PIAL1_ARATH,SUBCELLULAR LOCATION: Nucleus {ECO:0000305}.,MVIPATSRFGFRAEFNTKEFQASCISLANEIDAAIGRNEVPGNIQE...
1,A0A178VEK7,reviewed,DUO1_ARATH,SUBCELLULAR LOCATION: Nucleus {ECO:0000255|PRO...,MRKMEAKKEEIKKGPWKAEEDEVLINHVKRYGPRDWSSIRSKGLLQ...
2,A0A178WF56,reviewed,CSTM3_ARATH,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...,MAQYHQQHEMKQTMAETQYVTAPPPMGYPVMMKDSPQTVQPPHEGQ...
3,A0A1I9LMX5,reviewed,PCEP9_ARATH,SUBCELLULAR LOCATION: [C-terminally encoded pe...,MKLLSITLTSIVISMVFYQTPITTEARSLRKTNDQDHFKAGFTDDF...
4,A0A1I9LN01,reviewed,LAF3_ARATH,SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ...,MTGWYEFPVMIGFVSAAVFLLISVAYLPLLNDLYWSTLKSLTPPAG...
...,...,...,...,...,...
16242,Q9SYZ7,reviewed,U496A_ARATH,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...,MGNQTSKKSQETSAKSVHYTTELRSYAAACKADTELQSFDTCLQAR...
16259,Q9XIK5,reviewed,Y1045_ARATH,SUBCELLULAR LOCATION: Nucleus {ECO:0000250}.,MAQNKNLNLELSLSQYVEDDPWVLKKKLSDSDLYYSAQLYLPKQEM...
16273,Q9ZU26,reviewed,Y1197_ARATH,SUBCELLULAR LOCATION: Nucleus {ECO:0000250}.,MAQELDLELGLAPYDPWVLKKNLTESDLNNGFIILPKQDFEKIIRQ...
16275,Q9ZU96,reviewed,Y2168_ARATH,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...,MEMKQMKFLTHQAFFSSVRSGDLSQLQQLVDNLTGDELIDESSPCS...


In [12]:
df = df[df['Sequence'].str.len() < 512]  # Keep only sequences shorter than 512 residues (because of resource limits)
df

Unnamed: 0,Entry,Reviewed,Entry Name,Subcellular location [CC],Sequence
1,A0A178VEK7,reviewed,DUO1_ARATH,SUBCELLULAR LOCATION: Nucleus {ECO:0000255|PRO...,MRKMEAKKEEIKKGPWKAEEDEVLINHVKRYGPRDWSSIRSKGLLQ...
2,A0A178WF56,reviewed,CSTM3_ARATH,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...,MAQYHQQHEMKQTMAETQYVTAPPPMGYPVMMKDSPQTVQPPHEGQ...
3,A0A1I9LMX5,reviewed,PCEP9_ARATH,SUBCELLULAR LOCATION: [C-terminally encoded pe...,MKLLSITLTSIVISMVFYQTPITTEARSLRKTNDQDHFKAGFTDDF...
5,A0A1P8AQ95,reviewed,STMP4_ARATH,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...,MTKNMTKKKMGLMSPNIAAFVLPMLLVLFTISSQVEVVESTGRKLS...
9,A0JQ18,reviewed,SOP14_ARATH,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...,MAAKTSNLVALLLSLFLLLLSISSQVGLGEAKRNLRNNLRLDCVSH...
...,...,...,...,...,...
16241,Q9SYL8,reviewed,Y1786_ARATH,SUBCELLULAR LOCATION: Nucleus {ECO:0000250}.,MAEEQREISHENNVSLGSAETAIPLTNVSISPTKKEEQKTVYLVLF...
16242,Q9SYZ7,reviewed,U496A_ARATH,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...,MGNQTSKKSQETSAKSVHYTTELRSYAAACKADTELQSFDTCLQAR...
16259,Q9XIK5,reviewed,Y1045_ARATH,SUBCELLULAR LOCATION: Nucleus {ECO:0000250}.,MAQNKNLNLELSLSQYVEDDPWVLKKKLSDSDLYYSAQLYLPKQEM...
16273,Q9ZU26,reviewed,Y1197_ARATH,SUBCELLULAR LOCATION: Nucleus {ECO:0000250}.,MAQELDLELGLAPYDPWVLKKNLTESDLNNGFIILPKQDFEKIIRQ...


In [13]:
# saving the dataframe
df.to_csv('UneditedDataUniprotProteins.csv')

Now we'll make one dataframe of proteins whose subcellular location column contains the lowercase string `chloroplast`, and a second of proteins whose column does not. Note that `str.contains` is case-sensitive by default.

In [14]:
chloroplastic = df['Subcellular location [CC]'].str.contains("chloroplast")
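
The toy example below illustrates how `str.contains` treats capitalisation. The three strings are invented for the demonstration, and the third, capitalised entry is hypothetical.

```python
import pandas as pd

# Invented location strings; the third one is a hypothetical capitalised variant.
locations = pd.Series([
    "SUBCELLULAR LOCATION: Plastid, chloroplast stroma.",
    "SUBCELLULAR LOCATION: Nucleus.",
    "SUBCELLULAR LOCATION: Chloroplast.",
])

# Default behaviour: case-sensitive match.
case_sensitive = locations.str.contains("chloroplast")

# Case-insensitive, literal (non-regex) match.
case_insensitive = locations.str.contains("chloroplast", case=False, regex=False)

print(case_sensitive.tolist())
print(case_insensitive.tolist())
```

In the rows shown in this notebook the annotation reads `Plastid, chloroplast ...`, so the lowercase pattern matches. Passing `case=False` is a cheap safeguard if you adapt the filter to other location terms.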


In [15]:
chloroplastic_df = df[chloroplastic]
chloroplastic_df

Unnamed: 0,Entry,Reviewed,Entry Name,Subcellular location [CC],Sequence
12,A1A6H3,reviewed,RBSK_ARATH,"SUBCELLULAR LOCATION: Plastid, chloroplast str...",MMKGISSVSQSINYNPYIEFNRPQLQISTVNPNPAQSRFSRPRSLR...
13,A1A6M1,reviewed,PTAC5_ARATH,"SUBCELLULAR LOCATION: Plastid, chloroplast str...",MASSSLPLSLPFPLRSLTSTTRSLPFQCSPLFFSIPSSIVCFSTQN...
17,A2RVM0,reviewed,TIC32_ARATH,"SUBCELLULAR LOCATION: Plastid, chloroplast inn...",MWFFGSKGASGFSSRSTAEEVTHGVDGTGLTAIVTGASSGIGVETA...
70,B9DFK5,reviewed,RETIC_ARATH,"SUBCELLULAR LOCATION: Plastid, chloroplast mem...",MAGCAMNLQFSSVVKVRNEISSFGICNRDFVFRDLAKAMKVPVLRI...
74,B9DFZ0,reviewed,NTH2_ARATH,"SUBCELLULAR LOCATION: Plastid, chloroplast str...",MILTGAASTFPIVARVLNAMNRRMYAATTLSSAKSISAESLNLRSD...
...,...,...,...,...,...
15147,Q9SSR1,reviewed,Y1259_ARATH,"SUBCELLULAR LOCATION: Plastid, chloroplast {EC...",MAILIPASFGRLTITSRAQVRVRVSASANQRTIRRDSVDWVKETSS...
15221,Q9SW33,reviewed,TL1Y_ARATH,"SUBCELLULAR LOCATION: Plastid, chloroplast thy...",MSLVASLQLILPPRPRSTKLLCSLQSPKQEQELSSTSPPISLLPKL...
15419,Q9ZW12,reviewed,TRNH5_ARATH,"SUBCELLULAR LOCATION: Plastid, chloroplast {EC...",MVLDMASHLYTNPPQNLHFISSSSSLKPHLCLSFKRINPKHKSSSS...
16220,Q9STN5,reviewed,Y4833_ARATH,"SUBCELLULAR LOCATION: Plastid, chloroplast {EC...",MERSASVGVNDGRFGGNQFYSPSFSSSSSSSSMRHVNYSCGSCGYE...


In [16]:
chloroplastic_non_df = df[~chloroplastic]
chloroplastic_non_df

Unnamed: 0,Entry,Reviewed,Entry Name,Subcellular location [CC],Sequence
1,A0A178VEK7,reviewed,DUO1_ARATH,SUBCELLULAR LOCATION: Nucleus {ECO:0000255|PRO...,MRKMEAKKEEIKKGPWKAEEDEVLINHVKRYGPRDWSSIRSKGLLQ...
2,A0A178WF56,reviewed,CSTM3_ARATH,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...,MAQYHQQHEMKQTMAETQYVTAPPPMGYPVMMKDSPQTVQPPHEGQ...
3,A0A1I9LMX5,reviewed,PCEP9_ARATH,SUBCELLULAR LOCATION: [C-terminally encoded pe...,MKLLSITLTSIVISMVFYQTPITTEARSLRKTNDQDHFKAGFTDDF...
5,A0A1P8AQ95,reviewed,STMP4_ARATH,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...,MTKNMTKKKMGLMSPNIAAFVLPMLLVLFTISSQVEVVESTGRKLS...
9,A0JQ18,reviewed,SOP14_ARATH,SUBCELLULAR LOCATION: Cell membrane {ECO:00002...,MAAKTSNLVALLLSLFLLLLSISSQVGLGEAKRNLRNNLRLDCVSH...
...,...,...,...,...,...
16237,Q9SYB0,reviewed,TAUE2_ARATH,SUBCELLULAR LOCATION: Membrane {ECO:0000255}; ...,MRNNFVPIILSFIIFLTPSIAEQEPSILSPVDQLLNKTSSYLDFST...
16241,Q9SYL8,reviewed,Y1786_ARATH,SUBCELLULAR LOCATION: Nucleus {ECO:0000250}.,MAEEQREISHENNVSLGSAETAIPLTNVSISPTKKEEQKTVYLVLF...
16242,Q9SYZ7,reviewed,U496A_ARATH,SUBCELLULAR LOCATION: Membrane {ECO:0000305}; ...,MGNQTSKKSQETSAKSVHYTTELRSYAAACKADTELQSFDTCLQAR...
16259,Q9XIK5,reviewed,Y1045_ARATH,SUBCELLULAR LOCATION: Nucleus {ECO:0000250}.,MAQNKNLNLELSLSQYVEDDPWVLKKKLSDSDLYYSAQLYLPKQEM...


We're almost done! Now, let's make a list of sequences from each dataframe and generate the associated labels. We'll use `1` as the label for chloroplast proteins and `0` as the label for all other proteins.

In [17]:
chloroplastic_sequences = chloroplastic_df["Sequence"].tolist()
chloroplastic_labels = [1 for protein in chloroplastic_sequences]

In [18]:
chloroplastic_non_df_sequences = chloroplastic_non_df["Sequence"].tolist()
chloroplastic_non_df_labels = [0 for protein in chloroplastic_non_df_sequences]

Now we can concatenate these lists together to get the `sequences` and `labels` lists that will form our final training data. 

In [19]:
sequences = chloroplastic_sequences + chloroplastic_non_df_sequences
labels = chloroplastic_labels + chloroplastic_non_df_labels

# Quick check to make sure we got it right
len(sequences) == len(labels)

True

### If the check above returns `True`, the data has been loaded correctly

## Splitting the data

Since the data we're loading isn't prepared for us as a machine learning dataset, we'll have to split the data into train and test sets. We can use sklearn's function for that:

In [20]:
from sklearn.model_selection import train_test_split

train_sequences, test_sequences, train_labels, test_labels = train_test_split(sequences, labels, test_size=0.25, shuffle=True)
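
The split above is random, so the proportion of chloroplast proteins can differ slightly between the train and test sets, and the result changes from run to run. As an optional variation (not what this notebook does), here is a sketch on toy data that uses the `stratify` and `random_state` arguments of `train_test_split`. All `toy_*` names are placeholders.

```python
from sklearn.model_selection import train_test_split

# Placeholder data: 20 fake sequences, 5 positives and 15 negatives.
toy_sequences = [f"SEQ{i}" for i in range(20)]
toy_labels = [1] * 5 + [0] * 15

toy_train_seqs, toy_test_seqs, toy_train_labels, toy_test_labels = train_test_split(
    toy_sequences,
    toy_labels,
    test_size=0.25,
    shuffle=True,
    stratify=toy_labels,  # keep the class ratio similar in both splits
    random_state=42,      # fixed seed so the split is reproducible
)

print(len(toy_train_seqs), len(toy_test_seqs))
print(sum(toy_train_labels), sum(toy_test_labels))
```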

## Tokenizing the data

All inputs to neural nets must be numerical. The process of converting strings into numerical indices suitable for a neural net is called **tokenization**. For natural language this can be quite complex, as usually the network's vocabulary will not contain every possible word, which means the tokenizer must handle splitting rarer words into pieces, as well as all the complexities of capitalization and unicode characters and so on.

With proteins, however, things are very easy. In protein language models, each amino acid is converted to a single token. Every model in `transformers` comes with an associated `tokenizer` that handles tokenization for it, and protein language models are no different.

In [21]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


Let's try a single sequence to see what the outputs from our tokenizer look like:

In [22]:
tokenizer(train_sequences[0])

{'input_ids': [0, 20, 8, 18, 17, 15, 7, 14, 17, 12, 14, 6, 5, 14, 5, 4, 8, 5, 4, 4, 15, 7, 8, 7, 12, 6, 6, 4, 6, 7, 19, 5, 4, 11, 17, 8, 4, 19, 17, 7, 13, 6, 6, 21, 10, 5, 7, 20, 18, 17, 10, 4, 11, 6, 12, 15, 9, 15, 7, 19, 14, 9, 6, 11, 21, 18, 20, 7, 14, 22, 18, 9, 10, 14, 12, 12, 19, 13, 7, 10, 5, 10, 14, 19, 4, 7, 9, 8, 11, 11, 6, 8, 21, 13, 4, 16, 20, 7, 15, 12, 6, 4, 10, 7, 4, 11, 10, 14, 20, 6, 13, 10, 4, 14, 16, 12, 19, 10, 11, 4, 6, 9, 17, 19, 8, 9, 10, 7, 4, 14, 8, 12, 12, 21, 9, 11, 4, 15, 5, 7, 7, 5, 16, 19, 17, 5, 8, 16, 4, 12, 11, 16, 10, 9, 5, 7, 8, 10, 9, 12, 10, 15, 12, 4, 11, 9, 10, 5, 8, 17, 18, 13, 12, 5, 4, 13, 13, 7, 8, 12, 11, 11, 4, 11, 18, 6, 15, 9, 18, 11, 5, 5, 12, 9, 5, 15, 16, 7, 5, 5, 16, 9, 5, 9, 10, 5, 15, 18, 12, 7, 9, 15, 5, 9, 16, 13, 10, 10, 8, 5, 7, 12, 10, 5, 16, 6, 9, 5, 15, 8, 5, 16, 4, 12, 6, 16, 5, 12, 5, 17, 17, 16, 5, 18, 12, 11, 4, 10, 15, 12, 9, 5, 5, 10, 9, 12, 5, 16, 11, 12, 5, 16, 8, 5, 17, 15, 7, 19, 4, 8, 8, 17, 13, 4, 4, 4, 17, 4, 16, 

This looks good! We can see that our sequence has been converted into `input_ids`, which is the tokenized sequence, and an `attention_mask`. The attention mask handles the case when we have sequences of variable length - in those cases, the shorter sequences are padded with blank "padding" tokens, and the attention mask is padded with 0s to indicate that those tokens should be ignored by the model.
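
The idea of per-residue tokenization, padding and attention masks can be sketched in plain Python. This is a toy illustration only: `TOY_VOCAB` and its ids are made up and are not the real ESM-2 vocabulary, and wrapping each sequence in start and end tokens is an assumption of the sketch.

```python
# Hypothetical toy vocabulary: NOT the real ESM-2 vocabulary or ids.
TOY_VOCAB = {"<cls>": 0, "<pad>": 1, "<eos>": 2, "M": 3, "N": 4, "K": 5, "L": 6}


def toy_encode(sequence):
    """Map each amino-acid letter to one id, wrapped in start/end tokens."""
    ids = [TOY_VOCAB[residue] for residue in sequence]
    return [TOY_VOCAB["<cls>"]] + ids + [TOY_VOCAB["<eos>"]]


def toy_pad(batch):
    """Pad every sequence to the longest one and build matching attention masks."""
    longest = max(len(ids) for ids in batch)
    input_ids, attention_mask = [], []
    for ids in batch:
        n_pad = longest - len(ids)
        input_ids.append(ids + [TOY_VOCAB["<pad>"]] * n_pad)
        # 1 marks a real token, 0 marks padding the model should ignore.
        attention_mask.append([1] * len(ids) + [0] * n_pad)
    return {"input_ids": input_ids, "attention_mask": attention_mask}


toy_batch = toy_pad([toy_encode("MNKL"), toy_encode("MK")])
print(toy_batch["input_ids"])       # [[0, 3, 4, 5, 6, 2], [0, 3, 5, 2, 1, 1]]
print(toy_batch["attention_mask"])  # [[1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 0, 0]]
```

The real tokenizer does the same kind of bookkeeping, using the vocabulary and special tokens shipped with the checkpoint.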

So now, let's tokenize our whole dataset. Note that we don't need to do anything with the labels, as they're already in the format we need.

In [23]:
train_tokenized = tokenizer(train_sequences)
test_tokenized = tokenizer(test_sequences)
print("Tokenisation successful")

Tokenisation successful


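If an indexing error appears at this point or during training, shorten the sequences you feed in, for example by using a stricter length filter above or by passing `truncation=True` and a `max_length` to the tokenizer.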

## Dataset creation

Now we want to turn this data into a dataset that PyTorch can load samples from. We can use the Hugging Face `Dataset` class for this:

In [24]:
from datasets import Dataset
train_dataset = Dataset.from_dict(train_tokenized)
test_dataset = Dataset.from_dict(test_tokenized)

train_dataset

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 6231
})

The labels have not been added to our datasets yet, so let's add them now!

In [25]:
train_dataset = train_dataset.add_column("labels", train_labels)
test_dataset = test_dataset.add_column("labels", test_labels)
train_dataset
test_dataset

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 2077
})

### If everything is correct up to this point, we can move on to fine-tuning the model

## Model loading

If everything above ran correctly, this cell should run without any problems.

In [26]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

num_labels = max(train_labels + test_labels) + 1  # Add 1 since 0 can be a label
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)


Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t12_35M_UR50D and are newly initialized: ['classifier.dense.weight', 'classifier.dense.bias', 'classifier.out_proj.weight', 'classifier.out_proj.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


These warnings are telling us that the model is discarding some weights that it used for language modelling (the `lm_head`) and adding some newly initialized weights for sequence classification (the `classifier`). This is exactly what we expect when we want to fine-tune a language model on a sequence classification task, so we can safely ignore them.

Next, we initialize our `TrainingArguments`. These control the various training hyperparameters, and will be passed to our `Trainer`.

In [28]:
model_name = model_checkpoint.split("/")[-1]
batch_size = 8

args = TrainingArguments(
    f"{model_name}-finetuned-localization",
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    push_to_hub=True,
)




In [30]:
from evaluate import load
import numpy as np

metric = load("accuracy")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return metric.compute(predictions=predictions, references=labels)
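
To show what `compute_metrics` does with the model's raw outputs, here is a small sketch using NumPy only. The logits and labels are invented for the example. It also computes a majority-class baseline, which is a useful reference when one class is much more common than the other.

```python
import numpy as np

# Invented logits for 6 proteins and 2 classes (column 0 = label 0, column 1 = label 1).
toy_logits = np.array([
    [2.0, -1.0],
    [0.5, 0.2],
    [-0.3, 1.2],
    [1.5, 0.1],
    [0.9, -0.4],
    [0.2, 0.1],
])
toy_labels = np.array([0, 0, 1, 0, 1, 0])

# The predicted class is the index of the largest logit in each row.
toy_predictions = np.argmax(toy_logits, axis=1)

# Accuracy is the fraction of predictions that match the labels.
toy_accuracy = float((toy_predictions == toy_labels).mean())

# Accuracy obtained by always predicting the most frequent class.
majority_baseline = float(np.bincount(toy_labels).max() / len(toy_labels))

print(toy_predictions.tolist())  # [0, 0, 1, 0, 0, 0]
print(toy_accuracy, majority_baseline)
```

Comparing a model's accuracy against this baseline on the real test labels gives a fairer picture of how much it has learned.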


And at last we're ready to initialize our `Trainer`:

In [31]:
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

Why do we pass along the `tokenizer` when we already preprocessed our data? This is because we will use it one last time to make all the samples we gather the same length by applying padding, which requires knowing the model's preferences regarding padding (to the left or right? with which token?). The `tokenizer` has a pad method that will do all of this right for us, and the `Trainer` will use it. You can customize this part by defining and passing your own `data_collator` which will receive samples like the dictionaries seen above and will need to return a dictionary of tensors.

We can now finetune our model by just calling the `train` method:

In [32]:
trainer.train()

wandb: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc




| Epoch | Training Loss | Validation Loss | Accuracy |
|-------|---------------|-----------------|----------|
| 1     | No log        | 0.146773        | 0.955705 |
| 2     | 0.153800      | 0.135787        | 0.958113 |
| 3     | 0.073300      | 0.132532        | 0.961001 |




TrainOutput(global_step=1170, training_loss=0.10612111784454085, metrics={'train_runtime': 755.2856, 'train_samples_per_second': 24.75, 'train_steps_per_second': 1.549, 'total_flos': 1828684459715508.0, 'train_loss': 0.10612111784454085, 'epoch': 3.0})

Output from a separate run:

`TrainOutput(global_step=1170, training_loss=0.11126196201031024, metrics={'train_runtime': 745.437, 'train_samples_per_second': 25.077, 'train_steps_per_second': 1.57, 'total_flos': 1827808076607552.0, 'train_loss': 0.11126196201031024, 'epoch': 3.0})`
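
The `global_step=1170` value can be cross-checked with a little arithmetic. The sketch below assumes that training ran on two GPUs (the T4 x2 setup recommended at the top), so the effective batch size is twice `per_device_train_batch_size`.

```python
import math

num_train_rows = 6231      # from the train_dataset summary above
per_device_batch_size = 8  # batch_size in TrainingArguments
num_devices = 2            # assumption: both T4 GPUs were used
num_epochs = 3

# Each optimizer step consumes one batch per device.
effective_batch_size = per_device_batch_size * num_devices
steps_per_epoch = math.ceil(num_train_rows / effective_batch_size)
total_steps = steps_per_epoch * num_epochs

print(steps_per_epoch)  # 390
print(total_steps)      # 1170
```

This would also explain the `No log` entry for epoch 1: if the training loss is logged every 500 steps (the `Trainer` default, assuming it was not changed), nothing has been logged yet when the first epoch ends at step 390.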

After three epochs, the model reaches a validation accuracy of about 96%. With a larger starting checkpoint and more effort to ensure that the training data categories are cleanly separable, accuracy could likely be pushed higher.