# Fine-tuning Passage Embeddings with GenQ

## Part 2: Fine-tune and Deploy the Embedding Model

In Part 1, we used a generative model to create synthetic queries for each of our document passages. In this notebook, we use those (query, passage) pairs and a [Hugging Face Estimator](https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html#hugging-face-estimator) to fine-tune a general-purpose embedding model.

## 2.1. Create the custom `train.py` script

Our training script uses the [Sentence-Transformers](https://www.sbert.net/) library, which integrates well with Hugging Face and simplifies working with embedding models. To make sure the `sentence-transformers` package is available in the Hugging Face Deep Learning Container (DLC), we add a `requirements.txt` file to our source directory; SageMaker installs the listed packages before running the script.

In this case, we fine-tune [`distilbert-base-uncased`](https://huggingface.co/distilbert-base-uncased), a smaller and faster version of BERT base produced by knowledge distillation. The checkpoint outputs one embedding per token (its masked language modeling head is not used), so to obtain a single sentence embedding we add a mean pooling layer that averages the token embeddings.
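To make the pooling step concrete, the dependency-free sketch below averages token vectors while ignoring padding positions, as indicated by the attention mask. `mean_pool` is a hypothetical helper for illustration only; the real computation happens on PyTorch tensors inside `models.Pooling`.

```python
from typing import List


def mean_pool(token_embeddings: List[List[float]], attention_mask: List[int]) -> List[float]:
    """Average the embeddings of real tokens, skipping padding (mask == 0)."""
    dim = len(token_embeddings[0])
    summed = [0.0] * dim
    count = 0
    for vector, keep in zip(token_embeddings, attention_mask):
        if keep:
            for i, value in enumerate(vector):
                summed[i] += value
            count += 1
    # Guard against an all-padding input to avoid division by zero
    count = max(count, 1)
    return [value / count for value in summed]


# Three token vectors; the last position is padding and must not affect the result.
tokens = [[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]
mask = [1, 1, 0]
print(mean_pool(tokens, mask))  # [2.0, 3.0]
```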

The other key component of the `train.py` script below is the [`MultipleNegativesRankingLoss`](https://www.sbert.net/docs/package_reference/losses.html#multiplenegativesrankingloss) loss function. In each training step, the model receives a batch of (query, passage) pairs and must score each query's own passage above every other passage in the batch, which act as in-batch negatives. Larger batch sizes therefore supply more negatives per query and force the model to be more discriminative. Because Part 1 generated several queries per passage, we also use the [`NoDuplicatesDataLoader`](https://www.sbert.net/docs/package_reference/datasets.html#noduplicatesdataloader) with this loss function. It ensures that the same passage never appears twice in a batch, which would otherwise turn a correct passage into a false negative.
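The in-batch negatives idea can be sketched in plain Python: score every query against every passage in the batch, then apply cross-entropy with the query's own passage (the diagonal) as the correct class. This assumes cosine similarity multiplied by a scale of 20, which we believe matches the library's defaults; `mnrl_loss` is a hypothetical illustration, not the Sentence-Transformers code.

```python
import math
from typing import List


def cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm


def mnrl_loss(queries: List[List[float]], passages: List[List[float]], scale: float = 20.0) -> float:
    """Mean cross-entropy where passages[i] is the positive for queries[i]
    and every other passage in the batch is treated as a negative."""
    total = 0.0
    for i, q in enumerate(queries):
        # Scaled similarity of query i to every passage in the batch
        logits = [scale * cosine(q, p) for p in passages]
        # Numerically stable log-sum-exp over the batch
        m = max(logits)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        # Negative log-probability of the correct (diagonal) passage
        total += log_sum_exp - logits[i]
    return total / len(queries)


queries = [[1.0, 0.0], [0.0, 1.0]]
aligned = [[1.0, 0.1], [0.1, 1.0]]   # each query is closest to its own passage
swapped = [[0.1, 1.0], [1.0, 0.1]]   # each query is closest to the other passage
print(mnrl_loss(queries, aligned) < mnrl_loss(queries, swapped))  # True
```

With a batch of n pairs, each query is contrasted with n - 1 negatives, which is why larger batches make the task harder.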

In [1]:
%%writefile scripts/requirements.txt

sentence-transformers

Overwriting scripts/requirements.txt


In [2]:
%%writefile scripts/train.py

from sentence_transformers import InputExample, datasets, models, SentenceTransformer, losses
import boto3
import logging
import sys
import argparse
import os


if __name__ == '__main__':

    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument('--epochs', type=int, default=1)
    parser.add_argument('--train_batch_size', type=int, default=12)
    parser.add_argument('--model_name', type=str, default='distilbert-base-uncased')

    # Data, model, and output directories
    parser.add_argument('--bucket', type=str)
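    # .get() avoids a KeyError if a variable is missing (e.g. when running outside SageMaker)
    parser.add_argument('--model_dir', type=str, default=os.environ.get('SM_MODEL_DIR', '/opt/ml/model'))
    parser.add_argument('--training_dir', type=str, default=os.environ.get('SM_HP_TRAINING_DIR'))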

    args, _ = parser.parse_known_args()

    # Set up logging
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.getLevelName('INFO'),
        handlers=[logging.StreamHandler(sys.stdout)],
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    )

    logger.info('loading dataset from s3')

    # load datasets: read every object under the training prefix
    # (a paginator handles prefixes with more than 1,000 objects)
    s3_client = boto3.client('s3')
    paginator = s3_client.get_paginator('list_objects_v2')

    pairs = []
    for page in paginator.paginate(Bucket=args.bucket, Prefix=args.training_dir):
        for obj in page.get('Contents', []):
            body = s3_client.get_object(Bucket=args.bucket, Key=obj['Key'])['Body'].read().decode('utf-8')
            for line in body.splitlines():
                # each line is expected to hold "query<TAB>passage"; skip anything else
                if '\t' not in line:
                    continue
                q, p = line.split('\t', 1)
                pairs.append(InputExample(texts=[q, p]))

    logger.info(f'done. {len(pairs)} pairs loaded.')

    batch_size = args.train_batch_size

    loader = datasets.NoDuplicatesDataLoader(
        pairs, batch_size=batch_size
    )

    logger.info(f'loading model: {args.model_name}')

    base = models.Transformer(args.model_name)
    pooler = models.Pooling(
        base.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True
    )

    model = SentenceTransformer(modules=[base, pooler])

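    epochs = args.epochs
    # linearly warm up the learning rate over the first 10% of training steps
    warmup_steps = int(len(loader) * epochs * 0.1)

    # other passages in each batch serve as negatives for every query
    loss = losses.MultipleNegativesRankingLoss(model)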

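    # output_path expects a local directory, so we save explicitly below instead;
    # SageMaker uploads the contents of SM_MODEL_DIR to S3 as model.tar.gz when the job ends
    model.fit(
        train_objectives=[(loader, loss)],
        epochs=epochs,
        warmup_steps=warmup_steps,
        show_progress_bar=True
    )

    model.save(args.model_dir)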

Overwriting scripts/train.py
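The script relies on `NoDuplicatesDataLoader` because each passage appears in several training pairs. The sketch below shows the batching idea under an assumed greedy strategy: a pair is deferred to a later batch if its query or passage text is already in the current one. `no_duplicate_batches` is a hypothetical helper; the library's loader also shuffles and may differ in details.

```python
from typing import List, Tuple


def no_duplicate_batches(pairs: List[Tuple[str, str]], batch_size: int) -> List[List[Tuple[str, str]]]:
    """Greedily fill batches, deferring any pair whose query or passage
    text is already present in the current batch."""
    remaining = list(pairs)
    batches = []
    while remaining:
        batch, seen, deferred = [], set(), []
        for pair in remaining:
            if len(batch) < batch_size and not (set(pair) & seen):
                batch.append(pair)
                seen.update(pair)
            else:
                deferred.append(pair)
        batches.append(batch)
        # The first remaining pair is always accepted, so this loop terminates
        remaining = deferred
    return batches


pairs = [
    ("q1", "passage A"),
    ("q2", "passage A"),  # second query for the same passage
    ("q3", "passage B"),
    ("q4", "passage C"),
]
batches = no_duplicate_batches(pairs, batch_size=2)
print(batches)
```

Here `("q2", "passage A")` is pushed into the second batch, so neither batch contains "passage A" twice.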


**Initialize and launch the training job**

Next, we use a [`HuggingFace`](https://sagemaker.readthedocs.io/en/stable/frameworks/huggingface/sagemaker.huggingface.html#hugging-face-estimator) Estimator to configure our fine-tuning job with our custom entry point and hyperparameters.
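The `hyperparameters` dictionary is how the notebook talks to `train.py`: SageMaker passes each entry to the entry point as a `--key value` command-line argument, which the script's `argparse` parser picks up. The sketch below imitates that hand-off locally. `to_cli_args` is a hypothetical helper, the bucket name is a placeholder, and all values are rendered as strings, as they would be on a command line.

```python
import argparse
from typing import Dict, List


def to_cli_args(hyperparameters: Dict[str, object]) -> List[str]:
    """Render a hyperparameter dict as '--key value' arguments."""
    argv = []
    for key, value in hyperparameters.items():
        argv.extend([f"--{key}", str(value)])
    return argv


# Same argument definitions as train.py, minus the SageMaker env-var defaults
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=1)
parser.add_argument("--train_batch_size", type=int, default=12)
parser.add_argument("--model_name", type=str, default="distilbert-base-uncased")
parser.add_argument("--bucket", type=str)
parser.add_argument("--training_dir", type=str)

argv = to_cli_args({
    "epochs": 1,
    "train_batch_size": 24,
    "model_name": "distilbert-base-uncased",
    "bucket": "my-example-bucket",  # hypothetical bucket name
    "training_dir": "pubmed-finetuning/data/training",
})
args, _ = parser.parse_known_args(argv)
print(args.train_batch_size + 1)  # 25: the string "24" was converted by type=int
```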

In [5]:
from sagemaker.huggingface import HuggingFace
import sagemaker
from datetime import datetime

role = sagemaker.get_execution_role()
sess = sagemaker.Session()
bucket = sess.default_bucket()

prefix = 'pubmed-finetuning'

hyperparameters = {
    'epochs': 1,
    'train_batch_size': 24,
    'model_name': 'distilbert-base-uncased',
    'bucket': bucket,
    'training_dir': f'{prefix}/data/training'
}

huggingface_estimator = HuggingFace(
    entry_point='train.py',
    source_dir='./scripts',
    instance_type='ml.p3.2xlarge',
    instance_count=1,
    role=role,
    transformers_version='4.26',
    pytorch_version='1.13',
    py_version='py39',
    hyperparameters=hyperparameters
)

training_job_name = f"distilbert-finetuned-pubmed-{datetime.utcnow().isoformat().replace(':', '-').replace('.', '-')}"

huggingface_estimator.fit(
    {'train': f's3://{bucket}/{prefix}/data/training'},
    job_name=training_job_name
)

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.
INFO:sagemaker:Creating training-job with name: distilbert-finetuned-pubmed-2023-05-31T13-01-30-427842


2023-05-31 13:01:33 Starting - Starting the training job...
2023-05-31 13:01:58 Starting - Preparing the instances for training.........
2023-05-31 13:03:34 Downloading - Downloading input data
2023-05-31 13:03:34 Training - Downloading the training image........................
2023-05-31 13:07:31 Training - Training image download completed. Training in progress...
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2023-05-31 13:07:51,187 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training
2023-05-31 13:07:51,205 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)
2023-05-31 13:07:51,218 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.
2023-05-31 13:07:51,221 sagemaker_pytorch_container.training INFO     Invoking user training script.
2023-05-31 13:07:5

**Download model artifacts for the fine-tuned and baseline models**

The output of a Hugging Face training job is a `model.tar.gz` artifact containing the model weights and configuration. We copy it to the notebook instance and extract it for local evaluation.
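The download cell relies on the AWS CLI and `tar`. If you prefer to stay in Python, the standard-library `tarfile` module can handle the extraction step. The sketch below builds a tiny placeholder archive so it runs anywhere; its single `config.json` is a stand-in, not the real artifact contents.

```python
import tarfile
import tempfile
from pathlib import Path
from typing import List


def extract_model_archive(archive: Path, target_dir: Path) -> List[str]:
    """Extract a model.tar.gz into target_dir and return the top-level entry names."""
    target_dir.mkdir(parents=True, exist_ok=True)  # no error if the folder already exists
    with tarfile.open(archive, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Newer Pythons: the "data" filter rejects unsafe paths such as absolute or ../ entries
            tar.extractall(target_dir, filter="data")
        else:
            tar.extractall(target_dir)
    return sorted(p.name for p in target_dir.iterdir())


# Build a small placeholder archive to demonstrate the round trip.
workdir = Path(tempfile.mkdtemp())
(workdir / "config.json").write_text("{}")
archive = workdir / "model.tar.gz"
with tarfile.open(archive, "w:gz") as tar:
    tar.add(workdir / "config.json", arcname="config.json")

extracted = extract_model_archive(archive, workdir / "distilbert-finetuned-pubmed")
print(extracted)  # ['config.json']
```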

In [6]:
!mkdir -p distilbert-finetuned-pubmed
!aws s3 cp {huggingface_estimator.model_data} . && tar -xf model.tar.gz -C distilbert-finetuned-pubmed
!ls distilbert-finetuned-pubmed

download: s3://sagemaker-us-east-1-352937523354/distilbert-finetuned-pubmed-2023-05-31T13-01-30-427842/output/model.tar.gz to ./model.tar.gz
1_Pooling			   sentence_bert_config.json
config.json			   special_tokens_map.json
config_sentence_transformers.json  tokenizer_config.json
modules.json			   tokenizer.json
pytorch_model.bin		   vocab.txt
README.md


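For comparison, we also download the base model directly from the Hugging Face Hub and save it in the same Sentence-Transformers format. Don't forget to add the same mean pooling layer; without it, the model only produces per-token embeddings rather than sentence embeddings.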
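Once both models are saved locally, one simple way to compare them is to embed a query and a set of candidate passages with each model and check where the correct passage ranks. The helper below is a hypothetical, dependency-free sketch of the ranking step only; in practice the vectors would come from `SentenceTransformer(...).encode`, and the toy vectors here are made up for illustration.

```python
import math
from typing import List, Sequence


def rank_passages(query: Sequence[float], passages: List[Sequence[float]]) -> List[int]:
    """Return passage indices ordered from most to least cosine-similar to the query."""
    def cosine(a: Sequence[float], b: Sequence[float]) -> float:
        dot = sum(x * y for x, y in zip(a, b))
        return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

    scores = [cosine(query, p) for p in passages]
    return sorted(range(len(passages)), key=lambda i: scores[i], reverse=True)


query_vec = [0.9, 0.1, 0.0]
passage_vecs = [[0.0, 1.0, 0.0], [1.0, 0.0, 0.0], [0.5, 0.5, 0.0]]
print(rank_passages(query_vec, passage_vecs))  # [1, 2, 0]
```

A model that places the correct passage near the top of this ranking more often is the better retriever for this dataset.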

In [7]:
from sentence_transformers import models, SentenceTransformer

bert = models.Transformer('distilbert-base-uncased')
pooler = models.Pooling(
    bert.get_word_embedding_dimension(),
    pooling_mode_mean_tokens=True
)

model = SentenceTransformer(modules=[bert, pooler])
model.save('distilbert-embedder')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_projector.weight', 'vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


INFO:sentence_transformers.SentenceTransformer:Use pytorch device: cuda
INFO:sentence_transformers.SentenceTransformer:Save model to distilbert-embedder


## Bonus: Deploy the Fine-tuned Model to a SageMaker Endpoint

You can also deploy the fine-tuned model to a SageMaker endpoint for simple, scalable inference. To do so, you need to provide a custom `inference.py` that overrides the default handler in the Hugging Face DLC, so that the model is loaded with the [`SentenceTransformer`](https://www.sbert.net/docs/package_reference/SentenceTransformer.html) class and requests are served with its `encode` method.
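Because the inference toolkit calls `model_fn` once at startup and then `predict_fn` for each request, the request/response contract can be checked locally before paying for an endpoint. The sketch below reproduces the request handling of `predict_fn` but swaps in a hypothetical `StubModel` whose `encode` returns fixed 3-dimensional dummy vectors, so it runs without downloading any weights.

```python
from typing import Dict, List, Union


class _StubEmbeddings:
    """Stand-in for the NumPy array returned by SentenceTransformer.encode."""

    def __init__(self, rows: List[List[float]]):
        self._rows = rows

    def tolist(self) -> List[List[float]]:
        return self._rows


class StubModel:
    """Hypothetical stand-in for a SentenceTransformer; embeds each sentence as [len, 0, 1]."""

    def encode(self, sentences: List[str]) -> _StubEmbeddings:
        return _StubEmbeddings([[float(len(s)), 0.0, 1.0] for s in sentences])


def predict_fn(data: Dict[str, Union[str, List[str]]], model) -> Dict[str, list]:
    # Same request handling as inference.py: read "inputs", fall back to the raw payload
    sentences = data.pop("inputs", data)
    embeddings = model.encode(sentences)
    return {"vectors": embeddings.tolist()}


response = predict_fn({"inputs": ["this is a test sentence", "this is another sentence"]}, StubModel())
print(len(response["vectors"]))  # 2
```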

In [13]:
%%writefile scripts/inference.py

from sentence_transformers import SentenceTransformer


def model_fn(model_dir):
    # Load the fine-tuned model (transformer + mean pooling) from the extracted model.tar.gz
    return SentenceTransformer(model_dir)


def predict_fn(data, model):
    # Expect {"inputs": [...]}; fall back to the raw payload if the key is absent
    sentences = data.pop('inputs', data)

    # encode() handles tokenization, the forward pass, and pooling
    sentence_embeddings = model.encode(sentences)

    # Convert the NumPy array to nested lists so the response is JSON-serializable
    return {'vectors': sentence_embeddings.tolist()}


Overwriting scripts/inference.py


In [14]:
from sagemaker.huggingface import HuggingFaceModel

finetuned_model = HuggingFaceModel(
    # S3 URI of the model.tar.gz produced by the training job above
    model_data=huggingface_estimator.model_data,
    source_dir='./scripts',
    entry_point='inference.py',
    transformers_version='4.26',
    pytorch_version='1.13',
    py_version='py39',
    role=role
)

finetuned_predictor = finetuned_model.deploy(
    initial_instance_count=1,
    instance_type='ml.g4dn.xlarge'
)

# finetuned_predictor.delete_endpoint()

with open('.endpoint_name', 'w') as f:
    f.write(finetuned_predictor.endpoint_name)

INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole
INFO:sagemaker:Creating model with name: huggingface-pytorch-inference-2023-05-31-13-20-43-206
INFO:sagemaker:Creating endpoint-config with name huggingface-pytorch-inference-2023-05-31-13-20-43-976
INFO:sagemaker:Creating endpoint with name huggingface-pytorch-inference-2023-05-31-13-20-43-976


----------!
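The deployment cell writes the endpoint name to `.endpoint_name`, so other notebooks or scripts can call the endpoint without holding the `finetuned_predictor` object. A minimal sketch using the low-level `sagemaker-runtime` client follows, assuming the JSON request and response shape used by `inference.py`. `build_request`, `parse_response`, and `invoke` are hypothetical helpers, and only `invoke` needs `boto3` and AWS credentials.

```python
import json
from pathlib import Path
from typing import List


def build_request(sentences: List[str]) -> bytes:
    """Serialize sentences into the {"inputs": [...]} payload predict_fn expects."""
    return json.dumps({"inputs": sentences}).encode("utf-8")


def parse_response(body: bytes) -> List[List[float]]:
    """Extract the embedding vectors from the endpoint's JSON response."""
    return json.loads(body)["vectors"]


def invoke(sentences: List[str], endpoint_file: str = ".endpoint_name") -> List[List[float]]:
    """Call the deployed endpoint; requires boto3 and AWS credentials."""
    import boto3  # imported lazily so the helpers above work without boto3

    endpoint_name = Path(endpoint_file).read_text().strip()
    runtime = boto3.client("sagemaker-runtime")
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Accept="application/json",
        Body=build_request(sentences),
    )
    return parse_response(response["Body"].read())


payload = build_request(["this is a test sentence"])
print(parse_response(b'{"vectors": [[0.1, 0.2]]}'))  # [[0.1, 0.2]]
```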

In [15]:
finetuned_predictor.predict({
    'inputs': ['this is a test sentence', 'this is another sentence']
})

{'vectors': [[0.2859458327293396,
   -0.35708174109458923,
   -0.12256012856960297,
   0.037739675492048264,
   -0.02512454055249691,
   -0.029434809461236,
   1.1277031898498535,
   -0.24892261624336243,
   0.09099698066711426,
   0.25677916407585144,
   -0.5184693336486816,
   -0.6755648851394653,
   -0.028212156146764755,
   -0.32723119854927063,
   0.006148063577711582,
   -0.5086283683776855,
   0.14669542014598846,
   -0.5131253004074097,
   1.1572073698043823,
   -0.3071233332157135,
   1.1385725736618042,
   0.36879268288612366,
   -0.4401373267173767,
   -0.27648404240608215,
   0.7058352828025818,
   -0.4707321524620056,
   -0.03573893383145332,
   0.1498100310564041,
   -0.889588475227356,
   -0.17182762920856476,
   -0.2719089090824127,
   0.3588113486766815,
   -0.5920833945274353,
   0.1139451414346695,
   -0.290635347366333,
   -0.11189478635787964,
   0.24905870854854584,
   -0.2317727953195572,
   -0.021314473822712898,
   -0.03505370765924454,
   -0.031328968703746796

In [16]:
finetuned_predictor.delete_endpoint()

INFO:sagemaker:Deleting endpoint configuration with name: huggingface-pytorch-inference-2023-05-31-13-20-43-976
INFO:sagemaker:Deleting endpoint with name: huggingface-pytorch-inference-2023-05-31-13-20-43-976
