### This tutorial shows how to train ESM on the Secondary Structure Prediction task.

#### First, let's import the required modules.

In [None]:
from protbench import applications
from protbench import models
from protbench import embedder
from protbench.utils import EmbeddingsDatasetFromDisk
from transformers import TrainingArguments, Trainer

#### The second step is to initialize our model, which takes two sub-steps.
* First, initialize the wrapper object. It contains utility functions that abstract away the differences between pretrained models (e.g. Ankh, ESM, ProtTrans), so you can use the same script for different models.
* Second, initialize the model and the tokenizer.
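The two-step pattern above can be sketched in plain Python: a lightweight wrapper is built first, and the heavy model load happens second. This is a hypothetical illustration, not protbench's code. `BackboneWrapper`, `REGISTRY`, `make_wrapper` and the `"prot_t5_xl"` entry are invented for this sketch, and the loaders return placeholder strings instead of real models. The embedding sizes are the hidden sizes of ESM-2 650M (1280) and ProtT5-XL (1024).

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional, Tuple


@dataclass
class BackboneWrapper:
    """Hypothetical wrapper hiding model-specific details behind one interface."""
    family: str
    checkpoint: str
    embedding_dim: int
    loader: Callable[[str], object]  # model-specific loading logic
    model: Optional[object] = field(default=None, repr=False)

    def initialize(self) -> None:
        # Step 2: the expensive part (loading weights/tokenizer) happens only here.
        self.model = self.loader(self.checkpoint)


# Hypothetical registry keyed by (family, checkpoint); loaders return placeholders.
REGISTRY: Dict[Tuple[str, str], Dict] = {
    ("esm2", "esm2_650M"): {"embedding_dim": 1280, "loader": lambda ckpt: f"<esm2 model {ckpt}>"},
    ("prottrans", "prot_t5_xl"): {"embedding_dim": 1024, "loader": lambda ckpt: f"<prottrans model {ckpt}>"},
}


def make_wrapper(family: str, checkpoint: str) -> BackboneWrapper:
    # Step 1: cheap construction; nothing is loaded yet.
    spec = REGISTRY[(family, checkpoint)]
    return BackboneWrapper(family, checkpoint, spec["embedding_dim"], spec["loader"])


wrapper = make_wrapper("esm2", "esm2_650M")
print(wrapper.model)          # None: not loaded yet
wrapper.initialize()
print(wrapper.embedding_dim)  # 1280
print(wrapper.model)          # <esm2 model esm2_650M>
```

Downstream code only touches `embedding_dim` and `model`. Switching backbones therefore only means changing the two strings passed to `make_wrapper`.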

In [None]:
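# Step 1: create the wrapper for the ESM-2 650M checkpoint.
wrapper = applications.initialize_model_from_checkpoint("esm2", "esm2_650M")
# Step 2: load the pretrained model and its tokenizer.
wrapper.initialze_model_from_checkpoint()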

#### The third step is to load the dataset, which in this tutorial is secondary structure prediction.
* In this tutorial we extract the embeddings of each sequence from ESM and then feed them to a ConvBERT downstream model. That is why we set `from_embeddings=True`: the inputs to ConvBERT are the embeddings, not the sequences.

In [None]:
ssp_task = applications.SSP3("ssp3_casp12", from_embeddings=True, tokenizer=wrapper.tokenizer)

#### The fourth step is to load the sequences and labels.

In [None]:
train_seqs, y_train = ssp_task.load_train_data()
val_seqs, y_val = ssp_task.load_eval_data()

#### The fifth step is to load the tokenization function.

##### You might ask: why not just use `wrapper.tokenizer`? What is the difference between `wrapper.tokenizer` and this function?

##### This function returns an object that wraps `wrapper.tokenizer`. Some models have non-HuggingFace tokenizers, and others need extra preprocessing of the sequences before tokenization. For example, ProtTrans T5 requires replacing the rare or ambiguous amino acids "U", "Z", "O" and "B" with "X". The wrapper performs whatever model-specific steps are needed, if any. Also, when extracting embeddings we only need the `input_ids`, which we can request by setting `return_input_ids_only=True`.
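The wrapping idea can be sketched as a small callable object. This is a hypothetical stand-in, not protbench's class. `TokenizationFunction`, `toy_tokenizer` and `VOCAB` are invented here, and the toy vocabulary does not match any real model's token ids. Like the object in the tutorial, it keeps the underlying tokenizer reachable as `.tokenizer`.

```python
import re
from typing import Dict, List, Union

# Hypothetical toy vocabulary; real tokenizers define their own ids and special tokens.
VOCAB: Dict[str, int] = {"<cls>": 0, "<pad>": 1, "<eos>": 2, "X": 3}
for i, aa in enumerate("ACDEFGHIKLMNPQRSTVWY", start=4):
    VOCAB[aa] = i


def toy_tokenizer(seq: str) -> Dict[str, List[int]]:
    """Stand-in for a HuggingFace-style tokenizer that adds start/end tokens."""
    ids = [VOCAB["<cls>"]] + [VOCAB[aa] for aa in seq] + [VOCAB["<eos>"]]
    return {"input_ids": ids, "attention_mask": [1] * len(ids)}


class TokenizationFunction:
    """Hypothetical callable wrapping a tokenizer plus model-specific preprocessing."""

    def __init__(self, tokenizer, replace_rare: bool = False, return_input_ids_only: bool = False):
        self.tokenizer = tokenizer  # the underlying tokenizer stays accessible
        self.replace_rare = replace_rare
        self.return_input_ids_only = return_input_ids_only

    def __call__(self, seq: str) -> Union[List[int], Dict[str, List[int]]]:
        if self.replace_rare:
            # ProtT5-style preprocessing: map U, Z, O and B to X.
            seq = re.sub(r"[UZOB]", "X", seq)
        encoded = self.tokenizer(seq)
        # For embedding extraction only the input ids are needed.
        return encoded["input_ids"] if self.return_input_ids_only else encoded


tokenize = TokenizationFunction(toy_tokenizer, replace_rare=True, return_input_ids_only=True)
print(tokenize("MKUB"))  # [0, 14, 12, 3, 3, 2]
```

The extraction code always calls `tokenize(seq)` the same way. Only the wrapper's constructor arguments change from model to model.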

In [None]:
tokenizer = wrapper.load_default_tokenization_function(return_input_ids_only=True)

#### The sixth step is to initialize the embedding extractor, which takes two steps.

* First, initialize an instance that stores the paths of the train, validation and test embeddings.
* Second, initialize the `ComputeEmbeddingsWrapper` instance, which expects the following:
    * model: Your pretrained model, which is ESM in this tutorial.
    * tokenization_fn: Your tokenization function, which must be callable.
    * forward_options: If your model takes extra arguments besides the `input_ids`, pass them here as a dictionary; they are passed to the model along with the `input_ids`.
    * post_processing_function: If your model returns multiple outputs, a function that takes these outputs and returns only the embeddings.
    * device: The device to run the model on.
    * pad_token_id: The ID of the padding token.
    * low_memory: If `True`, `ComputeEmbeddingsWrapper` saves the embeddings to disk; otherwise it keeps them in memory. When it is `True`, `save_directories` must be a `SaveDirectories` instance, as in the cell below.
    * save_directories: The instance that stores the paths for the embeddings.
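Here is a minimal, dependency-free sketch of what an extraction loop with these options could look like. It is not `ComputeEmbeddingsWrapper`'s actual implementation. The following are all assumptions made for illustration:

* `compute_embeddings`, `toy_model` and `toy_tokenize` are hypothetical.
* Plain lists replace tensors.
* The batch size is arbitrary.
* JSON files stand in for the `.npy` files the tutorial describes.

The `device` option is omitted because nothing here runs on an accelerator.

```python
import json
import os
import tempfile
from typing import Callable, Dict, List, Optional, Sequence


def toy_model(input_ids: List[List[int]], **forward_options):
    """Hypothetical backbone returning (embeddings, extra_output); 2-dim vector per token."""
    embeddings = [[[float(t), 2.0 * t] for t in seq] for seq in input_ids]
    return embeddings, "extra-output-placeholder"


def compute_embeddings(
    seqs: Sequence[str],
    model: Callable,
    tokenization_fn: Callable[[str], List[int]],
    forward_options: Dict,
    post_processing_function: Callable,
    pad_token_id: int,
    low_memory: bool,
    save_dir: Optional[str] = None,
    batch_size: int = 2,
) -> List[List[List[float]]]:
    """Hypothetical extraction loop; returns embeddings only when low_memory is False."""
    if low_memory and save_dir is None:
        raise ValueError("low_memory=True requires a save directory")
    in_memory = []
    for start in range(0, len(seqs), batch_size):
        batch_ids = [tokenization_fn(s) for s in seqs[start:start + batch_size]]
        # Pad every sequence in the batch to the same length with pad_token_id.
        max_len = max(len(ids) for ids in batch_ids)
        padded = [ids + [pad_token_id] * (max_len - len(ids)) for ids in batch_ids]
        # Extra keyword arguments go to the model alongside the input ids.
        # (A real backbone would also need an attention mask; omitted in this toy.)
        outputs = model(padded, **forward_options)
        # Keep only the embeddings from the model's outputs.
        embeddings = post_processing_function(outputs)
        for offset, (ids, emb) in enumerate(zip(batch_ids, embeddings)):
            emb = emb[:len(ids)]  # drop the padded positions
            if low_memory:
                # One file per sequence, named by its index in the input list.
                with open(os.path.join(save_dir, f"{start + offset}.json"), "w") as f:
                    json.dump(emb, f)
            else:
                in_memory.append(emb)
    return in_memory


def toy_tokenize(seq: str) -> List[int]:
    return [ord(c) - ord("A") + 1 for c in seq]  # toy ids: A -> 1, B -> 2, ...


def keep_embeddings(outputs):
    return outputs[0]  # the post-processing step: discard the extra output


print(compute_embeddings(["AB"], toy_model, toy_tokenize, {}, keep_embeddings,
                         pad_token_id=0, low_memory=False))
# [[[1.0, 2.0], [2.0, 4.0]]]

with tempfile.TemporaryDirectory() as d:
    compute_embeddings(["AB", "ABC", "A"], toy_model, toy_tokenize, {}, keep_embeddings,
                       pad_token_id=0, low_memory=True, save_dir=d)
    print(sorted(os.listdir(d)))  # ['0.json', '1.json', '2.json']
```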

In [None]:
save_dirs = embedder.SaveDirectories(parent_dir="./", train_dir="ssp_train_embeddings", validation_dir="ssp_val_embeddings")

embedding_extractor = embedder.ComputeEmbeddingsWrapper(
    model=wrapper.model,
    tokenization_fn=tokenizer,
    forward_options={},
    post_processing_function=wrapper.embeddings_postprocessing_fn,
    device="cpu",
    pad_token_id=tokenizer.tokenizer.pad_token_id,
    low_memory=True,
    save_directories=save_dirs,
)

##### Now let's run it. This creates two directories containing the embeddings for each sequence.
##### Each embedding is saved as a `.npy` file with a numerical name. This is convenient: when we load the embeddings from disk, we load them by index, and the index is the file name.
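The index-as-file-name scheme can be sketched with the standard library. JSON is used here purely as a dependency-free stand-in for NumPy's `.npy` format. `save_by_index`, `load_by_index` and `numeric_listing` are hypothetical helpers, not protbench functions. The sketch also shows one pitfall: a plain string sort of the directory puts `10.json` before `2.json`.

```python
import json
import os
import tempfile
from typing import List


def save_by_index(directory: str, embeddings: List[List[List[float]]]) -> None:
    # File i holds the embedding of sequence i, mirroring the numeric naming scheme.
    for idx, emb in enumerate(embeddings):
        with open(os.path.join(directory, f"{idx}.json"), "w") as f:
            json.dump(emb, f)


def load_by_index(directory: str, idx: int) -> List[List[float]]:
    # Random access: the dataset index is the file name, so no lookup table is needed.
    with open(os.path.join(directory, f"{idx}.json")) as f:
        return json.load(f)


def numeric_listing(directory: str) -> List[str]:
    # Plain string sorting is lexicographic ("10.json" sorts before "2.json"),
    # so sort on the integer stem to recover the original sequence order.
    return sorted(os.listdir(directory), key=lambda name: int(os.path.splitext(name)[0]))


with tempfile.TemporaryDirectory() as d:
    save_by_index(d, [[[float(i)]] for i in range(12)])
    print(sorted(os.listdir(d))[:3])  # ['0.json', '1.json', '10.json']
    print(numeric_listing(d)[:3])     # ['0.json', '1.json', '2.json']
    print(load_by_index(d, 10))       # [[10.0]]
```

Loading by index, as in `load_by_index`, avoids the sorting issue entirely. This is why the numeric naming is convenient.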

In [None]:
embedding_extractor(train_seqs=train_seqs, val_seqs=val_seqs)

#### Now let's initialize the downstream model.

In [None]:
convbert_model = models.ConvBert(
    input_dim=wrapper.embedding_dim,
    nhead=4,
    hidden_dim=wrapper.embedding_dim // 2,
    num_layers=1,
    kernel_size=7,
    dropout=0.1,
    pooling=None,
)

#### `models.utils.initialize_model` is a simple wrapper that connects all the models together.

#### Let's go through each argument.
* task: Your target task, which is secondary structure prediction in this case.
* embedding_dim: Pretrained model embedding dimension.
* from_embeddings: Whether the inputs are embeddings or sequences.
* backbone: ESM model that we are using in this tutorial.
* downstream_model: Our ConvBERT downstream model that we are using in this tutorial.
* pooling: Pooling function, needed only for sequence-level classification/regression. Secondary structure is predicted per residue, so we pass `None`.
* embedding_postprocessing_fn: Embedding post-processing function, in case your pretrained model returns multiple outputs along with the embeddings.

#### Why does this function need the task?
* Each task has a method that returns its task head (an `nn.Module`), which is used during training (try: `ssp_task.load_task_head(wrapper.embedding_dim)`).

#### If we are extracting the embeddings, why do we need to pass the pretrained model?
* You can ignore this parameter when loading from embeddings; we pass it here only to show all the expected parameters.
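Here is one way the arguments above could fit together, written as plain function composition. This is a hypothetical sketch, not the body of `models.utils.initialize_model`. `initialize_model_sketch`, `double` and `sum_head` are invented, and the toy callables stand in for the ConvBERT model and the task head. It shows why the backbone is irrelevant when `from_embeddings=True`: that branch never calls it.

```python
from typing import Callable, List, Optional

Matrix = List[List[float]]  # one vector per residue


def initialize_model_sketch(
    head: Callable,
    downstream_model: Callable,
    from_embeddings: bool,
    backbone: Optional[Callable] = None,
    embedding_postprocessing_fn: Optional[Callable] = None,
    pooling: Optional[Callable] = None,
) -> Callable:
    """Hypothetical composition: backbone -> downstream model -> (pooling) -> head."""
    if not from_embeddings and backbone is None:
        raise ValueError("a backbone is required when inputs are sequences")

    def forward(inputs):
        x = inputs
        if not from_embeddings:
            # Raw sequences: run the pretrained model and keep only the embeddings.
            x = backbone(x)
            if embedding_postprocessing_fn is not None:
                x = embedding_postprocessing_fn(x)
        x = downstream_model(x)
        if pooling is not None:
            # Sequence-level tasks collapse the residues into one vector.
            x = pooling(x)
        return head(x)

    return forward


def double(x: Matrix) -> Matrix:
    # Toy downstream model.
    return [[2.0 * v for v in row] for row in x]


def sum_head(x: Matrix) -> List[float]:
    # Toy per-residue head: one score per residue.
    return [sum(row) for row in x]


model = initialize_model_sketch(sum_head, double, from_embeddings=True)
print(model([[1.0, 2.0], [3.0, 4.0]]))  # [6.0, 14.0]
```

With `pooling=None` the output keeps one entry per residue, which is what a per-residue task like secondary structure prediction needs.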

In [None]:
final_model = models.utils.initialize_model(
    task=ssp_task,
    embedding_dim=wrapper.embedding_dim,
    from_embeddings=ssp_task.from_embeddings,
    backbone=wrapper.model,
    downstream_model=convbert_model,
    pooling=None,
    embedding_postprocessing_fn=wrapper.embeddings_postprocessing_fn,
)

#### Let's initialize our datasets. `EmbeddingsDatasetFromDisk` expects the path to the embeddings and the labels.

##### What is shifting?
* Some models add start-of-sequence and end-of-sequence tokens. If you do not want the embeddings of these tokens to be included in training, you can slice them off. ESM uses both, so we remove them by setting `shift_left=1` and `shift_right=1`.
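Shifting amounts to a slice over the token axis. Here is a minimal sketch, assuming one embedding per token with the special tokens at both ends, as with ESM-2's `<cls>` and `<eos>`. `shift_slice` is a hypothetical helper, not the dataset's actual code, and token strings stand in for embedding vectors.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def shift_slice(token_embeddings: Sequence[T], shift_left: int, shift_right: int) -> List[T]:
    """Hypothetical helper: drop `shift_left` leading and `shift_right` trailing positions."""
    # Use len - shift_right rather than -shift_right: x[1:-0] would be empty.
    return list(token_embeddings[shift_left:len(token_embeddings) - shift_right])


# Token strings stand in for their embedding vectors.
tokens = ["<cls>", "M", "K", "V", "<eos>"]
labels = [0, 1, 2]  # one label per residue

residues = shift_slice(tokens, shift_left=1, shift_right=1)
print(residues)  # ['M', 'K', 'V']
assert len(residues) == len(labels)  # embeddings now align with per-residue labels
```

After slicing, position `i` of the embeddings corresponds to residue `i` of the labels. This alignment is what a per-residue loss requires.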

In [None]:
train_dataset = EmbeddingsDatasetFromDisk("ssp_train_embeddings", y_train, shift_left=1, shift_right=1)
val_dataset = EmbeddingsDatasetFromDisk("ssp_val_embeddings", y_val, shift_left=1, shift_right=1)

##### We use the Hugging Face `Trainer` because it is easy to use and fits our pipeline.

In [None]:
args = TrainingArguments(output_dir="ssp_experiment", metric_for_best_model=ssp_task.metric_for_best_model)

trainer = Trainer(
    model=final_model,
    args=args,
    data_collator=ssp_task.collate_fn,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=ssp_task.metrics_fn,
    preprocess_logits_for_metrics=ssp_task.preprocessing_fn,
)

# Start training; checkpoints are written to `output_dir`.
trainer.train()
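The tutorial does not show what `ssp_task.collate_fn`, `ssp_task.preprocessing_fn` and `ssp_task.metrics_fn` do internally. Below is a hypothetical sketch of a common pattern for per-residue tasks:

* Pad the embeddings with zeros and the labels with `-100`, the default `ignore_index` of PyTorch's `CrossEntropyLoss`.
* Reduce the logits to predicted class ids in a hook with the `(logits, labels)` signature of `preprocess_logits_for_metrics`.
* Score accuracy only on unpadded positions.

The function names and the `"embeddings"`/`"labels"` keys are assumptions, and plain lists replace tensors and NumPy arrays.

```python
from typing import Dict, List

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def collate_per_residue(batch: List[Dict]) -> Dict[str, List]:
    """Hypothetical collate: pad embeddings with zeros and labels with IGNORE_INDEX."""
    max_len = max(len(item["embeddings"]) for item in batch)
    dim = len(batch[0]["embeddings"][0])
    out = {"embeddings": [], "labels": [], "attention_mask": []}
    for item in batch:
        n = len(item["embeddings"])
        pad = max_len - n
        out["embeddings"].append(item["embeddings"] + [[0.0] * dim for _ in range(pad)])
        out["labels"].append(item["labels"] + [IGNORE_INDEX] * pad)
        out["attention_mask"].append([1] * n + [0] * pad)
    return out


def preprocess_logits(logits: List[List[List[float]]], labels) -> List[List[int]]:
    # Same (logits, labels) signature as Trainer's preprocess_logits_for_metrics hook:
    # keep only the predicted class per residue instead of the full logits.
    return [[max(range(len(scores)), key=scores.__getitem__) for scores in seq] for seq in logits]


def per_residue_accuracy(pred_ids: List[List[int]], labels: List[List[int]]) -> float:
    # Padded positions carry IGNORE_INDEX and are skipped.
    pairs = [(p, l) for ps, ls in zip(pred_ids, labels) for p, l in zip(ps, ls) if l != IGNORE_INDEX]
    return sum(p == l for p, l in pairs) / len(pairs)


batch = collate_per_residue([
    {"embeddings": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]], "labels": [0, 1, 2]},
    {"embeddings": [[0.7, 0.8]], "labels": [1]},
])
print(batch["labels"])  # [[0, 1, 2], [1, -100, -100]]

# 3-class logits (one score per secondary-structure state) for each padded position.
logits = [
    [[2.0, 0.1, 0.1], [0.1, 2.0, 0.1], [2.0, 0.1, 0.1]],
    [[0.1, 2.0, 0.1], [0.0, 0.0, 9.0], [9.0, 0.0, 0.0]],
]
preds = preprocess_logits(logits, batch["labels"])
print(per_residue_accuracy(preds, batch["labels"]))  # 0.75
```

The predictions at the two padded positions are ignored. Only the four real residues count, and three of them are correct. Reducing logits to class ids before metrics also keeps the Trainer from accumulating full logit tensors during evaluation.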

#### That's it! Next, check out `ankh_tutorial.ipynb`; it follows exactly the same steps, differing only in the model name used when loading.