# Fine-Tuning the GPT-J-6B Model Using Transfer Learning

## Overview
This notebook walks you through fine-tuning a pre-trained large language model (LLM) with domain-specific knowledge.

The domain-specific dataset we will use to fine-tune the model consists of United Kingdom (U.K.) Supreme Court case documents. We will tune the model on roughly 693 legal documents.

## Dataset info
* <strong>Page count:</strong> ~17,718
* <strong>Word count:</strong> 10,015,333
* <strong>Characters (no spaces):</strong> 49,897,639

The entire dataset is publicly available and can be downloaded [here](https://zenodo.org/record/7152317#.ZCSfaoTMI2y).

## Considerations when fine-tuning the model
You can fine-tune the model on a subset of the dataset instead of the whole thing. Set the _**doc_count**_ variable in the _**Data Prep**_ section to the number of documents you want to include, and only that many documents are used for fine-tuning. The smaller the value, the faster fine-tuning completes.
    
## Training/Tuning Time estimates

Here are the estimated training times based on the total number of case documents in the training dataset. All training was run on one *ml.p3dn.24xlarge* instance.

* <strong>250 documents:</strong> 1 hour 41 minutes
* <strong>500 documents:</strong> 2 hours 57 minutes
* <strong>693 documents:</strong> 5 hours 30 minutes
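To get a rough idea of how long a different `doc_count` might take, you can interpolate between the three measurements above. The sketch below converts them to minutes and interpolates linearly; the function name and the linear model are our own assumptions. The measurements themselves are not linear (going from 500 to 693 documents took 153 minutes, nearly as long as the first 500 documents), likely because document length varies, so treat the results as ballpark figures only.

```python
# Measured fine-tuning times on one ml.p3dn.24xlarge instance, in minutes
# (1 h 41 min, 2 h 57 min, 5 h 30 min).
MEASURED_MINUTES = {250: 101, 500: 177, 693: 330}


def estimate_training_minutes(doc_count: int) -> float:
    """Hypothetical helper: linearly interpolate between the measured points.

    Values outside the measured range are extrapolated from the nearest
    segment, so they are even rougher than values inside it.
    """
    points = sorted(MEASURED_MINUTES.items())
    # Find the segment that brackets doc_count; if none does, the loop
    # finishes with the last segment selected.
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        if doc_count <= x1:
            break
    slope = (y1 - y0) / (x1 - x0)
    return y0 + slope * (doc_count - x0)


for n in (100, 250, 400, 693):
    print(f"{n} documents: ~{estimate_training_minutes(n) / 60:.1f} hours")
```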


## GPTJ-6B base model

Steps you will go through to test the base model:

1. Install the required notebook libraries
2. Configure the notebook to use SageMaker
3. Retrieve the base model container
4. Deploy the inference endpoint
5. Call the inference endpoint to retrieve results from the LLM

## Fine-tuned model

Steps you will go through to test the fine-tuned model:

1. Download dataset
2. Prep the dataset
3. Retrieve model container
4. Set hyperparameters for fine-tuning
5. Start training/tuning job
6. Deploy inference endpoint for the fine-tuned model
7. Call inference endpoint for the fine-tuned model
8. Parse endpoint results

### Final Step
* Be sure to delete all models and endpoints when you are done to avoid incurring unnecessary charges
    
### Disclaimer
This notebook demonstrates how you can fine-tune an LLM using transfer learning. Even though the model is fine-tuned on actual U.K. Supreme Court case documents, you should not use this notebook for legal advice.
    
    

## Install Prerequisites

In [None]:
!pip install --upgrade sagemaker --quiet

## SageMaker SDK configurations

In [1]:
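import sagemaker, boto3, json
from sagemaker.session import Session

# SageMaker session, execution role, and region used throughout the notebook
sagemaker_session = Session()
aws_role = sagemaker_session.get_caller_identity_arn()
aws_region = boto3.Session().region_name
sess = sagemaker_session  # alias for the same session; no need to create a second one

print(aws_role)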



arn:aws:iam::938247108506:role/service-role/AmazonSageMaker-ExecutionRole-20230622T124468


## Deploying the inference endpoint for the GPT-J-6B base model

In this section we deploy the Hugging Face GPT-J-6B base model so that we can compare its inference results with those of the model we fine-tune later.

The fine-tuned model will be trained on UK Supreme Court case documents.

In [3]:
model_id, model_version = "huggingface-textgeneration1-gpt-j-6b", "*"

In [None]:
from sagemaker import image_uris, model_uris, script_uris
from sagemaker.model import Model
from sagemaker.predictor import Predictor
from sagemaker.utils import name_from_base

endpoint_name = name_from_base(f"jumpstart-example-{model_id}")

inference_instance_type = "ml.g5.12xlarge"

# Retrieve the inference docker container uri.
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=inference_instance_type,
)

print(deploy_image_uri)

# Retrieve the model uri.
model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="inference"
)

print(model_uri)

# Create the SageMaker model instance. Note that we need to pass Predictor class when we deploy model through Model class,
# for being able to run inference through the sagemaker API.
model = Model(
    image_uri=deploy_image_uri,
    model_data=model_uri,
    role=aws_role,
    predictor_cls=Predictor,
    name=endpoint_name,
)

# Deploy the model to a real-time inference endpoint.
base_model_predictor = model.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    endpoint_name=endpoint_name,
)

print(base_model_predictor)

## Inference Helper functions
Define two helper functions: one calls the inference endpoint, and the other parses its JSON response.

In [None]:
import json
import boto3

def query_endpoint_with_json_payload(encoded_json, endpoint_name):
    client = boto3.client("runtime.sagemaker")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType="application/json", Body=encoded_json
    )
    return response


def parse_response_multiple_texts(query_response):
    # Read the streamed response body, decode the JSON, and return its first element
    model_predictions = json.loads(query_response["Body"].read())
    return model_predictions[0]
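The helpers above index into the response with `[0]` and then `[0]["generated_text"]`, which implies a nested JSON list. The sketch below is a hypothetical `extract_generated_texts` helper that flattens that structure so every returned sequence is kept, which matters if you raise `num_return_sequences`. The response shape is inferred from the notebook's own indexing, not from the model's documentation.

```python
import io
import json
from typing import List


def extract_generated_texts(body_bytes: bytes) -> List[str]:
    """Return every generated_text string from the endpoint's JSON body.

    Assumes the nested shape implied by this notebook's indexing:
    [[{"generated_text": "..."}, ...], ...], presumably one inner list per
    prompt with num_return_sequences entries each.
    """
    predictions = json.loads(body_bytes)
    return [
        candidate["generated_text"]
        for per_prompt in predictions
        for candidate in per_prompt
    ]


# Offline example: invoke_endpoint returns the body as a stream, mimicked here with io.BytesIO.
fake_response = {"Body": io.BytesIO(b'[[{"generated_text": "Hello"}]]')}
print(extract_generated_texts(fake_response["Body"].read()))
```

With a live endpoint you would pass `query_response["Body"].read()` instead of the stand-in bytes.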

## Call GPTJ-6B inference endpoint
In this section we call the SageMaker inference endpoint that hosts the base model and print the text it returns.

In [None]:
parameters = {
    "max_length": 500,
    "num_return_sequences": 1,
    "top_k": 250,
    "top_p": 0.8,
    "do_sample": True,
    "temperature": 1,
}

res_gpt_before_finetune = []

for quota_text in [
    "Tell me about the Matrimonial and Family Proceedings Act 1984",
]:
    payload = {"text_inputs": f"{quota_text}:", **parameters}

    query_response = query_endpoint_with_json_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name=endpoint_name
    )
    generated_texts = parse_response_multiple_texts(query_response)[0]["generated_text"]
    res_gpt_before_finetune.append(generated_texts)
    print(generated_texts)
    print("\n")
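To see exactly what the loop above sends to the endpoint, the sketch below builds the same request body with a hypothetical `build_request_body` helper. The prompt goes under `text_inputs` (with a trailing colon, as in the loop), and each generation parameter sits next to it as a top-level key, mirroring `{"text_inputs": ..., **parameters}`. Which parameter names the endpoint actually honours is determined by the JumpStart model container, not by this helper.

```python
import json

# Generation settings copied from the cell above.
parameters = {
    "max_length": 500,
    "num_return_sequences": 1,
    "top_k": 250,
    "top_p": 0.8,
    "do_sample": True,
    "temperature": 1,
}


def build_request_body(prompt: str, params: dict) -> bytes:
    """Serialize a prompt plus generation parameters as UTF-8 JSON,
    the same way the inference loop builds its payload."""
    payload = {"text_inputs": f"{prompt}:", **params}
    return json.dumps(payload).encode("utf-8")


body = build_request_body("Tell me about the Matrimonial and Family Proceedings Act 1984", parameters)
print(body.decode("utf-8"))
```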

### Base model results
The output above is what the base model returns before fine-tuning. It can only draw on the data it saw during pre-training. The goal is for the model to give better answers once it has been fine-tuned on case law and has that additional context.

## Clean-up

Delete the SageMaker endpoint and its attached resources once you no longer need them. Inference endpoints incur charges for as long as they are running.

In [None]:
base_model_predictor.delete_model()
base_model_predictor.delete_endpoint()
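If you run these steps as a script rather than cell by cell, a context manager can guarantee the clean-up happens even when an inference call fails. The sketch below is an assumption-laden illustration: `temporary_endpoint` and the `RecordingPredictor` stand-in are hypothetical, and it relies only on the `delete_model()` and `delete_endpoint()` methods used in the cell above.

```python
from contextlib import contextmanager


@contextmanager
def temporary_endpoint(predictor):
    """Yield a deployed predictor, then delete its model and endpoint on exit,
    even if code inside the with-block raises."""
    try:
        yield predictor
    finally:
        try:
            predictor.delete_model()
        finally:
            # Runs even if deleting the model fails, so the billed endpoint is not left running
            predictor.delete_endpoint()


class RecordingPredictor:
    """Hypothetical stand-in with the two delete methods, for trying this offline."""

    def __init__(self):
        self.calls = []

    def delete_model(self):
        self.calls.append("delete_model")

    def delete_endpoint(self):
        self.calls.append("delete_endpoint")


demo = RecordingPredictor()
with temporary_endpoint(demo) as predictor:
    pass  # run inference here
print(demo.calls)  # ['delete_model', 'delete_endpoint']
```

With a real endpoint you would wrap the predictor returned by `model.deploy(...)` and run the inference code inside the `with` block. In a notebook, where work is spread across cells, the explicit clean-up cell above is the more practical choice.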

# Fine-Tuning the GPT-J-6B base model via transfer learning

## Data Prep

Download the dataset. This may take several minutes.

In [None]:
!wget https://zenodo.org/record/7152317/files/dataset.zip

In [None]:
# unzipping compressed datasets

print("unzipping file")

!unzip -q dataset.zip

print("finished unzipping file")

## Creating Dataset

In [4]:
import os

# Directory containing the U.K. Supreme Court judgement text files (one file per case)
directory_path = 'dataset/UK-Abs/train-data/judgement'

# Combined training file that all selected documents are written into
train_file = 'dataset/train.txt'
new_file_path = train_file

bucket_name = 'shelside-sagemaker' # change this to your bucket name and make sure it exists in S3
training_folder = r'train' # the training folder in your bucket

# Number of documents to include in the fine-tuning dataset
doc_count = 694
doc_in_dataset = 0

# Append the content of each document to the combined file until doc_count is reached.
# Sorting the directory listing selects the same subset on every run.
with open(new_file_path, 'w') as new_file:
    for filename in sorted(os.listdir(directory_path)):
        if doc_in_dataset >= doc_count:
            break

        # Build the full path and skip anything that is not a regular file (e.g. subdirectories)
        file_path = os.path.join(directory_path, filename)
        if not os.path.isfile(file_path):
            continue

        with open(file_path, 'r') as file:
            text_content = file.read()

        # Write the document followed by a separator line; count only documents actually written
        new_file.write(text_content)
        new_file.write("\n-----------------------------------------------------------------\n")
        doc_in_dataset += 1

print(doc_in_dataset)
print("Training dataset created")

694
Training dataset created
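Before uploading, it can be worth sanity-checking the combined file. The sketch below is a hypothetical `summarize_training_file` helper that splits `train.txt` on the separator line written above and counts documents, words and non-whitespace characters. It assumes the separator never appears inside a judgement. The totals will not necessarily match the Dataset info figures, which may have been counted differently or over a different set of files.

```python
import os
import tempfile

# Separator line written between documents by the cell above.
SEPARATOR = "\n-----------------------------------------------------------------\n"


def summarize_training_file(path: str) -> dict:
    """Count documents, words and non-whitespace characters in the combined file."""
    with open(path, "r", encoding="utf-8") as f:
        text = f.read()
    # Every document is followed by a separator, so drop the empty trailing piece.
    documents = [doc for doc in text.split(SEPARATOR) if doc.strip()]
    body = " ".join(documents)
    return {
        "documents": len(documents),
        "words": len(body.split()),
        "chars_no_spaces": sum(1 for ch in body if not ch.isspace()),
    }


# Offline example on a tiny two-document file.
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "train.txt")
    with open(demo_path, "w", encoding="utf-8") as f:
        f.write("First judgement text." + SEPARATOR + "Second one." + SEPARATOR)
    print(summarize_training_file(demo_path))
```

In the notebook you would call `summarize_training_file(train_file)`.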


## Upload training data to S3

In [5]:
# Upload the training data to S3 so the model can be fine-tuned on it
sagemaker_session.upload_data(train_file,
                              bucket=bucket_name, 
                              key_prefix=training_folder)

print("Training data uploaded to S3")

Training data uploaded to S3


## Set up the model to fine-tune

When selecting your instance type below, make sure your account quota allows you to run at least one instance of that type. For some GPU-based instances you may need to request a quota increase first. Spot instances have a separate quota.

You can request a service quota increase [here](https://us-east-1.console.aws.amazon.com/servicequotas/home/services).

In [6]:
model_id, model_version = "huggingface-textgeneration1-gpt-j-6b", "*"

from sagemaker import image_uris, model_uris, script_uris, hyperparameters

# Alternative training instance types, subject to your account quota:
# training_instance_type = "ml.g5.24xlarge"
# training_instance_type = "ml.g5.12xlarge"
# training_instance_type = "ml.g4dn.12xlarge"

training_instance_type = "ml.p3dn.24xlarge"

# Retrieve the docker image
train_image_uri = image_uris.retrieve(
    region=None,
    framework=None,
    model_id=model_id,
    model_version=model_version,
    image_scope="training",
    instance_type=training_instance_type,
)

# Retrieve the training script
train_source_uri = script_uris.retrieve(
    model_id=model_id, model_version=model_version, script_scope="training"
)

print(train_source_uri)
# Retrieve the pre-trained model tarball to further fine-tune
train_model_uri = model_uris.retrieve(
    model_id=model_id, model_version=model_version, model_scope="training"
)

print(train_model_uri)

s3://jumpstart-cache-prod-us-east-1/source-directory-tarballs/huggingface/transfer_learning/textgeneration1/prepack/v1.2.0/sourcedir.tar.gz
s3://jumpstart-cache-prod-us-east-1/huggingface-training/train-huggingface-textgeneration1-gpt-j-6b.tar.gz


## Configure storage locations

In [7]:
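# JumpStart's public bucket of sample training datasets (not used here; this
# notebook trains on the U.K. Supreme Court data uploaded above)
data_bucket = f"jumpstart-cache-prod-{aws_region}"
data_prefix = "training-datasets"
print(data_bucket)

# S3 locations of our own data; bucket_name and training_folder must match the upload step
bucket_name = "shelside-sagemaker"
training_dataset_s3_path = f"s3://{bucket_name}/{training_folder}/"
# Not passed to the training job below, so no separate validation channel is used
validation_dataset_s3_path = f"s3://{bucket_name}/validation/"

print(training_dataset_s3_path)
print(validation_dataset_s3_path)

# The session's default bucket (not used below; training output goes to bucket_name)
output_bucket = sess.default_bucket()
output_prefix = "training"

# Training artifacts, including the fine-tuned model, are written here
s3_output_location = f"s3://{bucket_name}/{output_prefix}/output"

print(s3_output_location)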


jumpstart-cache-prod-us-east-1
s3://shelside-sagemaker/train/
s3://shelside-sagemaker/validation/
s3://shelside-sagemaker/training/output


## Spot Training configuration
If **use_spot_instances** is set to **True** below, the training job runs on spot instances.

In [9]:
from sagemaker.utils import name_from_base
training_job_name = name_from_base(f"ssides-hugging-face-{model_id}-transfer-learning")

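# Set to True to train on spot instances
use_spot_instances = False
max_run = 36000 # maximum training time, in seconds
# For spot training, max_wait (training time plus time spent waiting for spot capacity)
# must be >= max_run; this allows up to 2 extra hours of waiting
max_wait = max_run + 7200 if use_spot_instances else None # in seconds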

checkpoint_s3_uri = None

if use_spot_instances:
    checkpoint_s3_uri = f's3://{bucket_name}/{output_prefix}/checkpoints/{training_job_name}'
    
print (f'Checkpoint uri: {checkpoint_s3_uri}')

Checkpoint uri: None
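With managed spot training, SageMaker requires `max_wait` (the training time plus any time spent waiting for spot capacity) to be greater than or equal to `max_run`. A minimal sketch of deriving one from the other follows; `spot_max_wait` and its one-hour default buffer are illustrative assumptions, not SageMaker defaults.

```python
from typing import Optional


def spot_max_wait(max_run: int, use_spot: bool, extra_wait: int = 3600) -> Optional[int]:
    """Return a max_wait value (seconds) for the Estimator.

    extra_wait is an arbitrary buffer for time spent waiting for spot capacity.
    """
    if not use_spot:
        return None  # max_wait is only relevant for spot training
    return max_run + extra_wait


print(spot_max_wait(36000, use_spot=True))   # 39600
print(spot_max_wait(36000, use_spot=False))  # None
```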


## Set hyperparameters
This section retrieves the model's default fine-tuning hyperparameters and overrides the ones you want to change before starting the training job.

In [10]:
from sagemaker import hyperparameters

# Retrieve the default hyper-parameters for fine-tuning the model
hyperparameters = hyperparameters.retrieve_default(model_id=model_id, model_version=model_version)

# [Optional] Override default hyperparameters with custom values
hyperparameters["epoch"] = "3"
hyperparameters["per_device_train_batch_size"] = "4"
hyperparameters["instruction_tuned"] = False

print(hyperparameters)

{'epoch': '3', 'learning_rate': '6e-06', 'per_device_train_batch_size': '4', 'per_device_eval_batch_size': '8', 'warmup_ratio': '0.1', 'instruction_tuned': False, 'train_from_scratch': 'False', 'fp16': 'True', 'bf16': 'False', 'evaluation_strategy': 'steps', 'eval_steps': '20', 'gradient_accumulation_steps': '2', 'logging_steps': '10', 'weight_decay': '0.2', 'load_best_model_at_end': 'True', 'max_train_samples': '-1', 'max_val_samples': '-1', 'seed': '10', 'max_input_length': '-1', 'validation_split_ratio': '0.2', 'train_data_split_seed': '0', 'preprocessing_num_workers': 'None', 'max_steps': '-1', 'gradient_checkpointing': 'True', 'early_stopping_patience': '3', 'early_stopping_threshold': '0.0', 'adam_beta1': '0.9', 'adam_beta2': '0.999', 'adam_epsilon': '1e-08', 'max_grad_norm': '1.0', 'label_smoothing_factor': '0', 'logging_first_step': 'False', 'logging_nan_inf_filter': 'True', 'save_strategy': 'steps', 'save_steps': '500', 'save_total_limit': '1', 'dataloader_drop_last': 'False',

## Train with Automatic Model Tuning (HPO)
This section configures Automatic Model Tuning (AMT). AMT is only used if you change **use_amt = False** to **use_amt = True**; by default it is disabled for this example.

In [11]:
from sagemaker.tuner import ContinuousParameter

# Use AMT for tuning and selecting the best model
use_amt = False

# Define objective metric, based on which the best model will be selected.
amt_metric_definitions = {
    "metrics": [{"Name": "eval:loss", "Regex": "'eval_loss': ([0-9]+\.[0-9]+)"}],
    "type": "Minimize",
}

# You can select from the hyperparameters supported by the model, and configure ranges of values to be searched for training the optimal model.(https://docs.aws.amazon.com/sagemaker/latest/dg/automatic-model-tuning-define-ranges.html)
hyperparameter_ranges = {
    "learning_rate": ContinuousParameter(0.00001, 0.0001, scaling_type="Logarithmic")
}

# Increase the total number of training jobs run by AMT, for increased accuracy (and training time).
max_jobs = 6
# Change parallel training jobs run by AMT to reduce total training time, constrained by your account limits.
# if max_jobs=max_parallel_jobs then Bayesian search turns to Random.
max_parallel_jobs = 2
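`scaling_type="Logarithmic"` tells AMT to search the learning-rate range on a log scale rather than a linear one. The sketch below only illustrates what log-uniform sampling over `[1e-5, 1e-4]` looks like; it is not how AMT's Bayesian strategy actually picks candidates, and `sample_log_uniform` is a hypothetical helper.

```python
import math
import random


def sample_log_uniform(low: float, high: float, rng: random.Random) -> float:
    """Draw a value whose logarithm is uniform in [log(low), log(high)]."""
    return math.exp(rng.uniform(math.log(low), math.log(high)))


rng = random.Random(0)
samples = [sample_log_uniform(1e-5, 1e-4, rng) for _ in range(10_000)]

# On a log scale, about half the draws fall below the geometric midpoint
# sqrt(1e-5 * 1e-4) ~= 3.16e-5, rather than below the arithmetic midpoint 5.5e-5.
below_geo_mid = sum(s < math.sqrt(1e-5 * 1e-4) for s in samples) / len(samples)
print(f"fraction below 3.16e-5: {below_geo_mid:.2f}")
```

With linear scaling, only about a quarter of the draws would land below 3.16e-5, so small learning rates would be explored far less.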

## Start Training
Here we start the SageMaker training job that fine-tunes the model. How long it takes depends on the amount of training data, the training instance type, and the number of training instances.

If your training job fails because you exceeded your quota for that instance type, you can request a quota increase [here](https://us-east-1.console.aws.amazon.com/servicequotas/home/services/sagemaker/quotas). You can request increases for both regular training instances and spot instances.



In [None]:
from sagemaker.estimator import Estimator
from sagemaker.tuner import HyperparameterTuner

# Regexes SageMaker applies to the training logs to record metrics (raw strings avoid escape warnings)
metric_definitions = [
    {"Name": "train:loss", "Regex": r"'loss': ([0-9]+\.[0-9]+)"},
    {"Name": "eval:loss", "Regex": r"'eval_loss': ([0-9]+\.[0-9]+)"},
    {"Name": "eval:runtime", "Regex": r"'eval_runtime': ([0-9]+\.[0-9]+)"},
    {"Name": "eval:samples_per_second", "Regex": r"'eval_samples_per_second': ([0-9]+\.[0-9]+)"},
    {"Name": "eval:eval_steps_per_second", "Regex": r"'eval_steps_per_second': ([0-9]+\.[0-9]+)"},
]

# Create SageMaker Estimator instance
tg_estimator = Estimator(
    role=aws_role,
    image_uri=train_image_uri,
    source_dir=train_source_uri,
    model_uri=train_model_uri,
    entry_point="transfer_learning.py",
    instance_count=1,
    instance_type=training_instance_type,
    hyperparameters=hyperparameters,
    output_path=s3_output_location,
    base_job_name=training_job_name,
    metric_definitions=metric_definitions,
    checkpoint_s3_uri=checkpoint_s3_uri,
    use_spot_instances=use_spot_instances,
    max_run=max_run,
    max_wait=max_wait,
)

if use_amt:
    hp_tuner = HyperparameterTuner(
        tg_estimator,
        amt_metric_definitions["metrics"][0]["Name"],
        hyperparameter_ranges,
        amt_metric_definitions["metrics"],
        max_jobs=max_jobs,
        max_parallel_jobs=max_parallel_jobs,
        objective_type=amt_metric_definitions["type"],
        base_tuning_job_name=training_job_name,
        # instruction_tuned is a training hyperparameter (set above), not a HyperparameterTuner argument
    )

    # Launch a SageMaker Tuning job to search for the best hyperparameters
    hp_tuner.fit({"train": training_dataset_s3_path })
else:
    # Launch a SageMaker Training job by passing s3 path of the training data
    tg_estimator.fit(
        {"train": training_dataset_s3_path}, logs=True
    )

INFO:sagemaker:Creating training-job with name: ssides-hugging-face-huggingface-textgen-2023-08-04-16-25-01-097


2023-08-04 16:25:01 Starting - Starting the training job...
2023-08-04 16:25:18 Starting - Preparing the instances for training............
2023-08-04 16:27:04 Downloading - Downloading input data..................................................................
2023-08-04 16:38:32 Training - Training image download completed. Training in progress..
bash: cannot set terminal process group (-1): Inappropriate ioctl for device
bash: no job control in this shell
2023-08-04 16:38:33,780 sagemaker-training-toolkit INFO     Imported framework sagemaker_pytorch_container.training
2023-08-04 16:38:33,837 sagemaker-training-toolkit INFO     No Neurons detected (normal if no neurons installed)
2023-08-04 16:38:33,845 sagemaker_pytorch_container.training INFO     Block until all host DNS lookups succeed.
2023-08-04 16:38:33,847 sagemaker_pytorch_container.training INFO     Invoking user training script.
2023-08-04 16:38:36,448 sagemaker-tr

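SageMaker scrapes the training job's log output with the regexes in `metric_definitions` and records the captured group as the metric value. The sketch below roughly reproduces that matching locally for two of the patterns, so you can check a regex before launching a multi-hour job. `extract_metrics` is a hypothetical helper, and the log lines are made up in the style of a Hugging Face Trainer log dictionary.

```python
import re

# Same patterns as two of the metric_definitions entries, as raw strings.
METRIC_REGEXES = {
    "train:loss": r"'loss': ([0-9]+\.[0-9]+)",
    "eval:loss": r"'eval_loss': ([0-9]+\.[0-9]+)",
}


def extract_metrics(log_line: str) -> dict:
    """Apply each regex to one log line and return the captured values as floats."""
    found = {}
    for name, pattern in METRIC_REGEXES.items():
        match = re.search(pattern, log_line)
        if match:
            found[name] = float(match.group(1))
    return found


# Made-up log lines for illustration.
print(extract_metrics("{'loss': 2.3456, 'learning_rate': 5e-06, 'epoch': 0.5}"))
print(extract_metrics("{'eval_loss': 2.1, 'eval_runtime': 12.3, 'epoch': 1.0}"))
```

Note that `'loss'` does not match inside `'eval_loss'`, because the pattern requires a quote directly before `loss`. Also note that `[0-9]+\.[0-9]+` needs digits on both sides of a decimal point, so a value logged as an integer or in scientific notation (such as `5e-06`) would not be captured.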
## Review Training metrics

In [None]:
from sagemaker import TrainingJobAnalytics

if use_amt:
    training_job_name = hp_tuner.best_training_job()
else:
    training_job_name = tg_estimator.latest_training_job.job_name

df = TrainingJobAnalytics(training_job_name=training_job_name).dataframe()
df.head(10)
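`TrainingJobAnalytics(...).dataframe()` returns one row per recorded metric value. To see where training ended up, you can keep only the latest value of each metric. The sketch below assumes the dataframe has `timestamp`, `metric_name` and `value` columns. The helper `latest_metric_values` is hypothetical and takes plain dictionaries, so you would pass `df.to_dict("records")`.

```python
from typing import Dict, Iterable, Mapping, Tuple


def latest_metric_values(records: Iterable[Mapping]) -> Dict[str, float]:
    """Return the most recent value of each metric from timestamp/metric_name/value rows."""
    latest: Dict[str, Tuple[float, float]] = {}
    for row in records:
        name, ts = row["metric_name"], row["timestamp"]
        # Keep the row with the largest timestamp seen so far for this metric
        if name not in latest or ts >= latest[name][0]:
            latest[name] = (ts, float(row["value"]))
    return {name: value for name, (_, value) in latest.items()}


# Offline example with made-up numbers.
sample = [
    {"timestamp": 0.0, "metric_name": "train:loss", "value": 2.9},
    {"timestamp": 60.0, "metric_name": "train:loss", "value": 2.4},
    {"timestamp": 60.0, "metric_name": "eval:loss", "value": 2.6},
]
print(latest_metric_values(sample))
```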

## Deploy & run Inference on the fine-tuned model

In [None]:
inference_instance_type = "ml.g5.12xlarge"

# Retrieve the docker container uri for inference
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,
    image_scope="inference",
    model_id=model_id,
    model_version=model_version,
    instance_type=inference_instance_type,
)

endpoint_name_after_finetune = name_from_base(f"ssides-hugging-face-{model_id}-")

# Deploy to SageMaker endpoint
finetuned_predictor = (hp_tuner if use_amt else tg_estimator).deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    image_uri=deploy_image_uri,
    endpoint_name=endpoint_name_after_finetune,
)

## Inference Helper functions
Redefine the two helper functions for calling the endpoint and parsing its response so this section also works on its own, for example after a kernel restart.

In [None]:
import json
import boto3


def query_endpoint_with_payload(encoded_json, endpoint_name):
    client = boto3.client("runtime.sagemaker")
    response = client.invoke_endpoint(
        EndpointName=endpoint_name, ContentType="application/json", Body=encoded_json
    )
    return response


def parse_response_multiple_texts(query_response):
    # Read the streamed response body, decode the JSON, and return its first element
    model_predictions = json.loads(query_response["Body"].read())
    return model_predictions[0]

In [None]:
parameters = {
    "max_length": 500,
    "num_return_sequences": 1,
    "top_k": 250,
    "top_p": 0.8,
    "do_sample": True,
    "temperature": 1,
}

res_gpt_finetune = []
    
for quota_text in [
    "Tell me about the Matrimonial and Family Proceedings Act 1984",
]:
    payload = {"text_inputs": f"{quota_text}:", **parameters}

    query_response = query_endpoint_with_payload(
        json.dumps(payload).encode("utf-8"), endpoint_name_after_finetune
    )
    generated_texts = parse_response_multiple_texts(query_response)[0]["generated_text"]
    res_gpt_finetune.append(generated_texts)
    print(generated_texts)
    print("\n")

In [None]:
# Delete the SageMaker endpoint and the attached resources
finetuned_predictor.delete_model()
finetuned_predictor.delete_endpoint()