# Finetune Passage Embeddings with GenQ

## Part 2: Finetune and Deploy the Embedding Model

In [None]:
%%writefile scripts/requirements.txt

sentence-transformers

In [None]:
%%writefile scripts/train.py

from sentence_transformers import InputExample, datasets, models, SentenceTransformer, losses
import boto3
import logging
import sys
import argparse
import os


def parse_pairs(text):
    """Parse tab-separated ``query<TAB>passage`` lines into (query, passage) tuples.

    Lines that contain no tab are skipped. Only the FIRST tab on a line is
    treated as the separator, so a passage that itself contains tabs is kept
    intact instead of raising on unpacking. A trailing carriage return (from
    CRLF-terminated files) is stripped from each line.

    Args:
        text: full decoded contents of one training file.

    Returns:
        A list of ``(query, passage)`` string tuples, in file order.
    """
    pairs = []
    for line in text.split('\n'):
        query, tab, passage = line.rstrip('\r').partition('\t')
        if not tab:
            # no separator on this line (blank line, header, garbage): skip it
            continue
        pairs.append((query, passage))
    return pairs


def list_object_keys(s3_client, bucket, prefix):
    """Return every object key under ``prefix`` in ``bucket``.

    Uses the client's paginator so listings larger than a single
    ``list_objects_v2`` response are fully traversed, and tolerates pages
    without a ``Contents`` entry (empty prefix) by returning an empty list.
    """
    paginator = s3_client.get_paginator('list_objects_v2')
    keys = []
    for page in paginator.paginate(Bucket=bucket, Prefix=prefix):
        keys.extend(obj['Key'] for obj in page.get('Contents', []))
    return keys


if __name__ == "__main__":

    parser = argparse.ArgumentParser()

    # hyperparameters sent by the client are passed as command-line arguments to the script.
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--train_batch_size", type=int, default=12)
    parser.add_argument("--model_name", type=str, default='distilbert-base-uncased')

    # Data, model, and output directories
    parser.add_argument("--bucket", type=str)
    parser.add_argument("--model_dir", type=str, default=os.environ["SM_MODEL_DIR"])
    parser.add_argument("--training_dir", type=str, default=os.environ["SM_HP_TRAINING_DIR"])

    # parse_known_args: ignore any extra arguments the launcher may append
    args, _ = parser.parse_known_args()

    # Set up logging
    logger = logging.getLogger(__name__)

    logging.basicConfig(
        level=logging.INFO,
        handlers=[logging.StreamHandler(sys.stdout)],
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )

    logger.info("loading dataset from s3")

    # load datasets: every object under the training prefix is read as UTF-8
    # text and parsed into (query, passage) training examples
    s3_client = boto3.client('s3')

    pairs = []
    for key in list_object_keys(s3_client, args.bucket, args.training_dir):
        body = s3_client.get_object(Bucket=args.bucket, Key=key)['Body'].read().decode('utf-8')
        pairs.extend(
            InputExample(texts=[query, passage])
            for query, passage in parse_pairs(body)
        )

    logger.info(f"done. {len(pairs)} pairs loaded.")

    if not pairs:
        # fail with an actionable message instead of an obscure error later on
        raise ValueError(
            f"no training pairs found under s3://{args.bucket}/{args.training_dir}"
        )

    batch_size = args.train_batch_size

    loader = datasets.NoDuplicatesDataLoader(
        pairs, batch_size=batch_size
    )

    logger.info(f"loading model: {args.model_name}")

    # transformer encoder followed by mean pooling over token embeddings
    base = models.Transformer(args.model_name)
    pooler = models.Pooling(
        base.get_word_embedding_dimension(),
        pooling_mode_mean_tokens=True
    )

    model = SentenceTransformer(modules=[base, pooler])

    epochs = args.epochs
    # warm up the learning rate over the first 10% of all training steps
    warmup_steps = int(len(loader) * epochs * 0.1)

    loss = losses.MultipleNegativesRankingLoss(model)

    # No output_path here: fit() treats it as a local filesystem path, so an
    # "s3://..." value would only create a oddly named local directory. The
    # model is written once, explicitly, to model_dir below.
    model.fit(
        train_objectives=[(loader, loss)],
        epochs=epochs,
        warmup_steps=warmup_steps,
        show_progress_bar=True
    )

    # model_dir defaults to SM_MODEL_DIR, the directory this job exports as its model artifact
    model.save(args.model_dir)

In [None]:
from sagemaker.huggingface import HuggingFace
import sagemaker
from datetime import datetime

# Execution role and default bucket of the current SageMaker session
role = sagemaker.get_execution_role()
sess = sagemaker.Session()
bucket = sess.default_bucket()

prefix = "pubmed-finetuning"
# key prefix of the training data inside the bucket (shared by the
# hyperparameters and the 'train' input channel below)
training_data_prefix = f'{prefix}/data/training'

# hyperparameters which are passed to the training job
hyperparameters = dict(
    epochs=1,
    train_batch_size=24,
    model_name='distilbert-base-uncased',
    bucket=bucket,
    training_dir=training_data_prefix,
)

# create the Estimator
estimator_settings = {
    'entry_point': 'train.py',
    'source_dir': './scripts',
    'instance_type': 'ml.p3.2xlarge',
    'instance_count': 1,
    'role': role,
    'transformers_version': '4.26',
    'pytorch_version': '1.13',
    'py_version': 'py39',
    'hyperparameters': hyperparameters,
}
huggingface_estimator = HuggingFace(**estimator_settings)

# Job name = fixed prefix + UTC timestamp with ':' and '.' replaced by '-'
timestamp = datetime.utcnow().isoformat()
for separator in (':', '.'):
    timestamp = timestamp.replace(separator, '-')
training_job_name = f"distilbert-finetuned-pubmed-{timestamp}"

# Launch the training job with the training data as the 'train' channel
train_channel = {'train': f's3://{bucket}/{training_data_prefix}'}
huggingface_estimator.fit(train_channel, job_name=training_job_name)

In [None]:
# Fetch the trained model archive and unpack it locally.
# -p keeps the cell re-runnable: no error if the directory already exists.
!mkdir -p distilbert-finetuned-pubmed
!aws s3 cp {huggingface_estimator.model_data} . && tar -xf model.tar.gz -C distilbert-finetuned-pubmed
!ls distilbert-finetuned-pubmed

In [None]:
from sentence_transformers import models, SentenceTransformer

# Build an embedder from the pretrained 'distilbert-base-uncased' checkpoint
# (no training happens in this cell): transformer encoder + mean pooling.
bert = models.Transformer('distilbert-base-uncased')
embedding_dim = bert.get_word_embedding_dimension()
pooler = models.Pooling(embedding_dim, pooling_mode_mean_tokens=True)

# Chain the two modules and write the result to ./distilbert-embedder
model = SentenceTransformer(modules=[bert, pooler])
model.save('distilbert-embedder')

## Bonus: Deploy the Finetuned Model to a SageMaker Endpoint

In [None]:
%%writefile scripts/inference.py

from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F

# Helper: masked mean pooling over the token axis
def mean_pooling(model_output, attention_mask):
    """Average token embeddings, weighting each token by its attention mask.

    Args:
        model_output: indexable whose first element holds the token embeddings.
        attention_mask: mask tensor; it is given a trailing axis and expanded
            to the shape of the token embeddings before use.

    Returns:
        Sum of the masked embeddings along dimension 1 divided by the mask
        sum along dimension 1 (clamped to at least 1e-9 to avoid division
        by zero when a row is fully masked).
    """
    embeddings = model_output[0]
    weights = attention_mask.unsqueeze(-1).expand_as(embeddings).float()
    weighted_sum = (embeddings * weights).sum(dim=1)
    weight_total = weights.sum(dim=1).clamp(min=1e-9)
    return weighted_sum / weight_total


def model_fn(model_dir):
    """Load the tokenizer and the model stored in ``model_dir``.

    Args:
        model_dir: location handed to both ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple, in that order.
    """
    # The tokenizer is loaded first, then the model weights, both from model_dir
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_dir)
    encoder = AutoModel.from_pretrained(model_dir)
    return encoder, loaded_tokenizer

def predict_fn(data, model_and_tokenizer):
    """Embed the request's sentences and return L2-normalized vectors.

    Args:
        data: request mapping. The value under "inputs" is removed from it
            and used as the text to embed; if that key is absent, ``data``
            itself is handed to the tokenizer.
        model_and_tokenizer: the ``(model, tokenizer)`` pair to run with.

    Returns:
        ``{"vectors": [...]}`` where the value is the embeddings tensor
        converted to nested Python lists.
    """
    # unpack the model and its tokenizer
    encoder, tokenize = model_and_tokenizer

    # Tokenize with padding and truncation, returning PyTorch tensors
    texts = data.pop("inputs", data)
    batch = tokenize(texts, padding=True, truncation=True, return_tensors='pt')

    # Forward pass without tracking gradients
    with torch.no_grad():
        outputs = encoder(**batch)

    # Mask-aware mean pooling, then L2 normalization along dim 1
    pooled = mean_pooling(outputs, batch['attention_mask'])
    unit_vectors = F.normalize(pooled, p=2, dim=1)

    # plain lists so the response can be serialized as JSON
    return {"vectors": unit_vectors.tolist()}


In [None]:
from sagemaker.huggingface import HuggingFaceModel

# Wrap the trained artifact in a deployable model that serves requests with
# scripts/inference.py. The artifact location is taken from the estimator
# (the same value used to download the model earlier) rather than rebuilt
# by hand from the bucket and job name, so it stays correct if the
# estimator's output location is ever changed.
finetuned_model = HuggingFaceModel(
    model_data=huggingface_estimator.model_data,
    source_dir='./scripts',
    entry_point='inference.py',
    transformers_version='4.26',
    pytorch_version='1.13',
    py_version='py39',
    role=role
)

# Create the endpoint on a single instance
finetuned_predictor = finetuned_model.deploy(
    initial_instance_count=1,
    instance_type="ml.g4dn.xlarge"
)

# Uncomment to tear the endpoint down when it is no longer needed:
# finetuned_predictor.delete_endpoint()

# Persist the endpoint name to a local file for later use
with open('.endpoint_name', 'w') as f:
    f.write(finetuned_predictor.endpoint_name)