# Resuming Training

When you use `Trainer` locally, all model checkpoints are saved to the local disk. However, we can design a solution that syncs checkpoints to external storage.
This notebook shows a way to upload checkpoints to either an S3 bucket or a GCS bucket, and it can be extended to other storage options.
It also includes example code to retrieve checkpoints from that storage in order to resume training.


1) [Preparation](#preparation)

2) [Checkpoint Callbacks](#checkpoint-callbacks)

3) [Checkpoint downloader utility](#checkpoint-downloader)

4) [Model training example](#model-training)


## Preparation

Let's first install the required packages and import the necessary libraries.

In [None]:
!pip install -r requirements.txt --quiet

In [None]:
# Import the necessary libraries
from datasets import load_dataset
import evaluate
from transformers import (AutoTokenizer, 
                          DataCollatorWithPadding,
                          TrainingArguments,
                          Trainer,
                          TrainerCallback,
                          AutoModelForSequenceClassification)

from transformers.trainer_callback import TrainerControl, TrainerState

import numpy as np
import os
from tqdm import tqdm

# Import AWS and GCP libraries
import boto3
from google.cloud import storage

Next, we define some variables used throughout the notebook. You can change them to suit your needs.

Note that you need to create the buckets beforehand (check the links in the README.md).

In [None]:
# Create S3 and GCS clients
s3_client = boto3.client('s3')
gcs_client = storage.Client(project="hf-notebooks")

# Set the training name
training_name = "trainer-demo-checkpointing"

# S3 bucket name (To update)
s3_bucket_name = "hf-demo-s3-checkpointing-sagemaker"

# GCS bucket name (To update)
gcs_bucket_name = "hf-demo-gcs-checkpointing"

## Checkpoint callbacks

Let's define our first callback. 

It uploads the model checkpoints to an S3 bucket every time they are saved locally (more info on the [trainer callback and the on_save method](https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/callback#transformers.TrainerCallback.on_save)).
We can define how frequently checkpoints are saved, and therefore uploaded, in the training arguments, as we will see next. Typically, we can save them every 100 steps like this:

```python
TrainingArguments(...,
                  save_strategy = "steps",
                  save_steps = 100,
                  ...)
```

We can retrieve the step from the [TrainerState](https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/callback#transformers.TrainerState) state argument.
Model checkpoints are first saved locally under the name `{training_name}/checkpoint-{step}`, and we can simply upload them to S3. To reduce costs, we could instead upload only every fifth saved checkpoint, for example:

```python
# Upload only every fifth checkpoint (assumes save_steps is an integer number of steps)
if (state.global_step // args.save_steps) % 5 == 0:
    # Code to save
```
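
The every-fifth-save condition can be isolated in a small helper that is easy to check on its own. This is a minimal sketch, assuming `save_steps` is a positive integer number of steps (as in `save_steps = 100` above); `should_upload` and `every_n_saves` are hypothetical names, not part of transformers.

```python
def should_upload(global_step: int, save_steps: int, every_n_saves: int = 5) -> bool:
    """Return True when the current save is the n-th, 2n-th, ... local save."""
    if save_steps <= 0 or every_n_saves <= 0:
        raise ValueError("save_steps and every_n_saves must be positive")
    # Index of the current save: 1 for the first checkpoint, 2 for the second, ...
    save_index = global_step // save_steps
    return save_index > 0 and save_index % every_n_saves == 0


# With save_steps=100, only the checkpoints at steps 500 and 1000 are uploaded
uploaded = [step for step in range(100, 1100, 100) if should_upload(step, 100)]
print(uploaded)  # [500, 1000]
```

Inside `on_save`, the check would read `if should_upload(state.global_step, args.save_steps):` placed before the upload loop.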

In [None]:
class SaveCheckpointsToS3Callback(TrainerCallback):
    '''
    This class is a callback that saves the model checkpoints to S3
    '''
    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        '''
        Function called when the model is saved. It uploads the model checkpoint to S3
        '''

        # Get checkpoint folder
        model_checkpoint = "{}/checkpoint-{}".format(training_name, state.global_step)
        
        # Upload all the checkpoint files to S3, using the local path as the object key
        for filename in os.listdir(model_checkpoint):
            filename_path = "/".join([model_checkpoint,filename])
            s3_client.upload_file(filename_path, s3_bucket_name, filename_path)

The second callback is quite similar, except that we now upload to a GCS bucket.

In [None]:
class SaveCheckpointsToGCSCallback(TrainerCallback):
    '''
    This class is a callback that saves the model checkpoints to GCS
    '''
    def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        '''
        Function called when the model is saved. It uploads the model checkpoint to GCS
        '''

        # Need to install gcloud if not already installed (https://cloud.google.com/sdk/docs/install)
        # Need to "gcloud auth application-default login" before running this
        # Get checkpoint folder
        model_checkpoint = "{}/checkpoint-{}".format(training_name, state.global_step)
        
        bucket = gcs_client.get_bucket(gcs_bucket_name)
        
        
        for filename in os.listdir(model_checkpoint):
            filename_path = "/".join([model_checkpoint,filename])
            
            blob = bucket.blob(filename_path)
            blob.upload_from_filename(filename_path)
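
Both callbacks list the checkpoint folder with `os.listdir`, which assumes the folder is flat and contains only files. If a checkpoint ever contains subfolders, one option is to walk it recursively and derive the object key from the relative path. The sketch below illustrates that idea under stated assumptions: `iter_checkpoint_files` and its `base_dir` parameter are hypothetical names, and the demo folder is created only for illustration.

```python
import os
import tempfile


def iter_checkpoint_files(checkpoint_dir, base_dir="."):
    """Yield (local_path, object_key) pairs for every file under checkpoint_dir."""
    for root, _dirs, files in os.walk(checkpoint_dir):
        for name in files:
            local_path = os.path.join(root, name)
            # Build the key relative to base_dir and always separate with "/"
            relative_path = os.path.relpath(local_path, start=base_dir)
            yield local_path, relative_path.replace(os.sep, "/")


# Demonstration on a throwaway folder that mimics a checkpoint layout
with tempfile.TemporaryDirectory() as tmp:
    checkpoint = os.path.join(tmp, "checkpoint-100")
    os.makedirs(os.path.join(checkpoint, "nested"))
    for relative in ["config.json", os.path.join("nested", "extra.bin")]:
        with open(os.path.join(checkpoint, relative), "w") as handle:
            handle.write("demo")
    keys = sorted(key for _path, key in iter_checkpoint_files(checkpoint, base_dir=tmp))

print(keys)  # ['checkpoint-100/config.json', 'checkpoint-100/nested/extra.bin']
```

In the S3 callback, each pair could then be passed to `s3_client.upload_file(local_path, s3_bucket_name, object_key)`, and in the GCS callback to `bucket.blob(object_key).upload_from_filename(local_path)`.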

## Checkpoint downloader 

This Python class, CloudCheckpointLoader, is designed to download model checkpoints from either Amazon S3 or Google Cloud Storage (GCS).

The `download_checkpoints` method is responsible for downloading the checkpoints from the specified bucket. It first checks the `bucket_type`, then calls the appropriate helper to find the checkpoint files, downloads them, and returns the local checkpoint path.

The `__get_last_checkpoint_from_s3` and `__get_last_checkpoint_from_gcs` methods are helper methods that retrieve the most recently uploaded checkpoint from S3 and GCS respectively. You can change them to retrieve the desired checkpoints.
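
The helpers in this notebook select the checkpoint by object timestamp. An alternative, sketched here only as an illustration and not what the class below does, is to parse the step number out of keys shaped like `{training_name}/checkpoint-{step}/...` and keep the highest one. `latest_checkpoint_keys` is a hypothetical name, and the example keys are made up.

```python
import re

# Matches the "checkpoint-<step>/" folder anywhere in an object key
CHECKPOINT_PATTERN = re.compile(r"(?:^|/)checkpoint-(\d+)/")


def latest_checkpoint_keys(keys):
    """Return (checkpoint_dir, files) for the checkpoint with the highest step."""
    steps = {}
    for key in keys:
        match = CHECKPOINT_PATTERN.search(key)
        if match:
            steps.setdefault(int(match.group(1)), []).append(key)
    if not steps:
        raise ValueError("No checkpoint-<step>/ folder found in the given keys")
    last_step = max(steps)
    return "checkpoint-{}".format(last_step), sorted(steps[last_step])


example_keys = [
    "trainer-demo-checkpointing/checkpoint-100/config.json",
    "trainer-demo-checkpointing/checkpoint-1000/config.json",
    "trainer-demo-checkpointing/checkpoint-1000/trainer_state.json",
    "trainer-demo-checkpointing/checkpoint-200/config.json",
]
checkpoint_dir, checkpoint_files = latest_checkpoint_keys(example_keys)
print(checkpoint_dir)  # checkpoint-1000
```

Comparing steps as integers keeps `checkpoint-100` and `checkpoint-1000` apart, which a plain substring test on the folder name would not.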

In [None]:
class CloudCheckpointLoader():
    '''
    This class is used to download the model checkpoints from S3 or GCS.
    '''

    def __init__(self, bucket_type="s3"):
        '''
        Initializes a new instance of the CloudCheckpointLoader class.

        Parameters:
            bucket_type (str): The type of bucket to download checkpoints from. Valid values are "s3" or "gcs".
        '''
        # Set the bucket type (s3 or gcs)
        self.bucket_type = bucket_type

        # Check if the bucket type is valid
        assert self.bucket_type in ["s3", "gcs"], "Invalid bucket type. Please choose either s3 or gcs"
    

    def download_checkpoints(self):
        '''
        Downloads the checkpoints from the specified bucket.

        Returns:
            str: The path to the downloaded checkpoints.
        '''
        
        if self.bucket_type == "s3":

            print("Downloading checkpoints from S3...")
            checkpoint_dir, checkpoint_files = self.__get_last_checkpoint_from_s3()
            path_to_checkpoint = "/".join([training_name,checkpoint_dir])
            os.makedirs(path_to_checkpoint, exist_ok=True)

            for file in tqdm(checkpoint_files):
                s3_client.download_file(s3_bucket_name, file, file)


        elif self.bucket_type == "gcs":

            bucket = gcs_client.get_bucket(gcs_bucket_name)

            print("Downloading checkpoints from GCS...")
            checkpoint_dir, checkpoint_files = self.__get_last_checkpoint_from_gcs(bucket)
            path_to_checkpoint = "/".join([training_name,checkpoint_dir])
            os.makedirs(path_to_checkpoint, exist_ok=True)

            for file in tqdm(checkpoint_files):
                blob = bucket.blob(file)
                blob.download_to_filename(file)

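        else:

            raise ValueError("Invalid bucket type. Please choose either s3 or gcs")

        # Return the local checkpoint path so the caller can load the model from it
        return path_to_checkpoint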


    def __get_last_checkpoint_from_s3(self):
        '''
        Retrieves the last checkpoint from S3.

        Returns:
            tuple: A tuple containing the checkpoint directory and a list of checkpoint files.
        '''
        # List the objects in the bucket (list_objects_v2 returns at most 1,000 keys per call)
        response = s3_client.list_objects_v2(Bucket=s3_bucket_name)
        
        # Sort the objects by the last modified date (datetime objects compare directly)
        sorted_content = sorted(response["Contents"], 
                                key=lambda obj: obj['LastModified'])
        
        # Get the keys of the sorted objects
        sorted_keys =  [obj['Key'] for obj in sorted_content]

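        # Collect all files from the last checkpoint (keys look like training_name/checkpoint-N/file)
        checkpoint_dir = sorted_keys[-1].split("/")[1]
        # Match the full folder name so that checkpoint-100 does not also match checkpoint-1000
        checkpoint_files = [key for key in sorted_keys if "/{}/".format(checkpoint_dir) in key]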

        # Return the last checkpoint files
        return checkpoint_dir, checkpoint_files


    def __get_last_checkpoint_from_gcs(self, bucket):
        '''
        Retrieves the last checkpoint from GCS.

        Parameters:
            bucket (object): The GCS bucket object.

        Returns:
            tuple: A tuple containing the checkpoint directory and a list of checkpoint files.
        '''
        # List all the objects in the bucket
        blobs = bucket.list_blobs()
        
        # Sort the objects by creation time (datetime objects compare directly)
        sorted_content = sorted(list(blobs), 
                                key=lambda obj: obj.time_created)
        
        # Get the keys of the sorted objects
        sorted_keys =  [obj.name for obj in sorted_content]

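        # Collect all files from the last checkpoint (names look like training_name/checkpoint-N/file)
        checkpoint_dir = sorted_keys[-1].split("/")[1]
        # Match the full folder name so that checkpoint-100 does not also match checkpoint-1000
        checkpoint_files = [key for key in sorted_keys if "/{}/".format(checkpoint_dir) in key]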

        # Return the last checkpoint files
        return checkpoint_dir, checkpoint_files

## Model training 

The example training code below is from the excellent Hugging Face [tutorial](https://huggingface.co/learn/nlp-course/chapter3/3). It serves as a simple baseline to which we add our callbacks.

We start by preparing the dataset.

In [None]:
# Prepare the dataset downloading from the Hub, tokenizing and preparing the data collator
raw_datasets = load_dataset("glue", "mrpc")
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


def tokenize_function(example):
    return tokenizer(example["sentence1"], example["sentence2"], truncation=True)
                

tokenized_datasets = raw_datasets.map(tokenize_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

We then pass our callbacks to the `Trainer` class. We can also start training from a model checkpoint downloaded from the S3 or GCS bucket.

In [None]:
# Define the training arguments
# In particular, the save strategy is set to steps and the save steps is set to 100
training_args = TrainingArguments(training_name, 
                                  evaluation_strategy="epoch", 
                                  save_strategy = "steps", 
                                  save_steps = 100)


# Load the model from the hub or a custom checkpoint
load_from_checkpoint = False
if load_from_checkpoint:

    # Download the model from S3
    cloud_checkpoint_loader = CloudCheckpointLoader(bucket_type="s3")
    custom_checkpoint = cloud_checkpoint_loader.download_checkpoints()
    model = AutoModelForSequenceClassification.from_pretrained(custom_checkpoint, 
                                                               num_labels=2, 
                                                               local_files_only=True)
else:

    model = AutoModelForSequenceClassification.from_pretrained(checkpoint, 
                                                               num_labels=2)




# Instantiate the checkpoint callbacks
save_checkpoints_to_s3 = SaveCheckpointsToS3Callback()
save_checkpoints_to_gcs = SaveCheckpointsToGCSCallback()

# Add the checkpoint callbacks to the trainer
callbacks = [save_checkpoints_to_s3, save_checkpoints_to_gcs]


# Define the compute_metrics function for the evaluation
def compute_metrics(eval_preds):
    metric = evaluate.load("glue", "mrpc")
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)




trainer = Trainer(
    model,
    training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks = callbacks
    )


In [None]:
# Start the training
trainer.train()
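
Loading a checkpoint with `from_pretrained` restores the model weights only, so the optimizer and the step counter start fresh. `Trainer.train` also accepts a `resume_from_checkpoint` argument pointing at a checkpoint folder, which additionally restores the saved optimizer, scheduler and trainer state. Before resuming, it can help to confirm which step a downloaded checkpoint corresponds to. The sketch below assumes the folder contains the `trainer_state.json` file that `Trainer` writes next to the weights; `read_checkpoint_step` is a hypothetical helper, and the demo folder is a stand-in created for illustration.

```python
import json
import os
import tempfile


def read_checkpoint_step(checkpoint_path):
    """Return the global step stored in a checkpoint's trainer_state.json."""
    state_file = os.path.join(checkpoint_path, "trainer_state.json")
    with open(state_file) as handle:
        return json.load(handle)["global_step"]


# Demonstration on a stand-in folder; a real checkpoint is written by Trainer
with tempfile.TemporaryDirectory() as tmp:
    demo_checkpoint = os.path.join(tmp, "checkpoint-300")
    os.makedirs(demo_checkpoint)
    with open(os.path.join(demo_checkpoint, "trainer_state.json"), "w") as handle:
        json.dump({"global_step": 300}, handle)
    step = read_checkpoint_step(demo_checkpoint)

print(step)  # 300

# In the notebook, resuming would then look like this (not executed here):
# trainer.train(resume_from_checkpoint=custom_checkpoint)
```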