# Classification - Fine-tuning FLAN-T5 XL

The following sections walk through fine-tuning the FLAN-T5 XL model with SageMaker JumpStart. The goal is to improve the model's ability to accurately classify the complex and ambiguous queries in our dataset.

First, we import the necessary libraries and set the SageMaker inference instance type, the model ID, and the model version.

In [None]:
import json
import utils

inference_instance_type = "ml.g5.2xlarge"

model_id, model_version = "huggingface-text2text-flan-t5-xl", "2.0.0"
base_endpoint_name = model_id

Next, we initialize a base predictor using the `utils` module. This predictor serves the base model and gives us a baseline for comparing performance before and after fine-tuning.


In [None]:
base_predictor = utils.get_predictor(
    endpoint_name=base_endpoint_name,
    model_id=model_id,
    model_version=model_version,
    inference_instance_type=inference_instance_type,
)

To validate our deployed model, we conduct a preliminary test using a straightforward prompt: "What is the capital of France?" This helps ensure the model's basic functionality.

In [None]:
prompt = "What is the capital of France?\nResponse:"
response = utils.flant5(base_predictor, user=prompt, max_tokens=2)
print(utils.parse_output(response))

## Data Preparation

Preparing for fine-tuning requires organizing several files, including the dataset and template files. The dataset is structured to align with the required input format for fine-tuning. For example, each record in our training dataset adheres to the following structure:

```json
{"query": "customer query", "response": "main-intent:sub-intent"}
```
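
As a minimal sketch of what "adhering to the structure" means, the check below parses one JSONL line and verifies that it has both keys and a `main-intent:sub-intent` response. The helper name `validate_record` and the sample labels are illustrative assumptions, not part of the notebook's `utils` module.

```python
import json


def validate_record(line: str) -> dict:
    """Parse one JSONL line and check it matches the expected record shape."""
    record = json.loads(line)
    missing = {"query", "response"} - set(record)
    if missing:
        raise ValueError(f"missing keys: {sorted(missing)}")
    # The response must look like "main-intent:sub-intent".
    main_intent, sep, sub_intent = record["response"].partition(":")
    if not (main_intent and sep and sub_intent):
        raise ValueError(f"unexpected response format: {record['response']!r}")
    return record


# Hypothetical sample record; the intent labels are made up for illustration.
sample = '{"query": "How do I reset my password?", "response": "technical_support:password_reset"}'
record = validate_record(sample)
print(record["response"].split(":"))
```

Running a check like this over every line before uploading can surface malformed records early, before a training job is started.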

In [None]:
intent_dataset_file = "data/intent_dataset.jsonl"
intent_dataset_train_file = "data/intent_dataset_train.jsonl"
intent_dataset_test_file = "data/intent_dataset_test.jsonl"
ft_template_file = "data/template.json"
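
The notebook does not show how the train and test files were produced from the full dataset. One possible approach, sketched here under that assumption, is a seeded shuffle followed by a simple split. The function names and the 80/20 ratio are hypothetical.

```python
import json
import random


def split_records(records, test_fraction=0.2, seed=42):
    """Shuffle records deterministically and split them into (train, test)."""
    shuffled = list(records)
    random.Random(seed).shuffle(shuffled)
    n_test = max(1, int(len(shuffled) * test_fraction))
    return shuffled[n_test:], shuffled[:n_test]


def write_jsonl(path, records):
    """Write one JSON object per line (not called here, to avoid touching disk)."""
    with open(path, "w") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")


# Made-up records standing in for the real intent dataset.
records = [{"query": f"example query {i}", "response": "main:sub"} for i in range(10)]
train, test = split_records(records)
print(len(train), len(test))
```

If some intents are rare, a split stratified by intent may give more reliable per-intent test numbers than this plain random split.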

The following cell creates a template file that the JumpStart framework uses to fine-tune the model. The template has two fields, `prompt` and `completion`, which are used to pass labeled data to the model during fine-tuning.

In [None]:
template = {
    "prompt": utils.FT_PROMPT,
    "completion": "{response}",
}

with open(ft_template_file, "w") as f:
    json.dump(template, f)
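
To illustrate how the two template fields relate to a dataset record, the sketch below fills the placeholders with Python's `str.format`. This mimics the substitution conceptually and is not JumpStart's actual implementation. `EXAMPLE_PROMPT` is a hypothetical stand-in for `utils.FT_PROMPT`, whose text is not shown here; the only thing assumed about it is that it contains a `{query}` placeholder.

```python
# Hypothetical stand-in for utils.FT_PROMPT (the real prompt text is not shown).
EXAMPLE_PROMPT = "Classify the intent of the customer query.\nQuery: {query}\nResponse:"

template = {"prompt": EXAMPLE_PROMPT, "completion": "{response}"}

# Made-up record in the dataset's {"query", "response"} shape.
record = {
    "query": "How do I reset my password?",
    "response": "technical_support:password_reset",
}


def render(template, record):
    """Fill every template field with values taken from one dataset record."""
    return {field: text.format(**record) for field, text in template.items()}


example = render(template, record)
print(example["prompt"])
print(example["completion"])
```

The rendered `prompt` is what the model sees as input, and the rendered `completion` is the label it is trained to produce.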

The training data and the template file are then uploaded to an S3 bucket, setting the stage for the actual fine-tuning process.

In [None]:
train_data_location = utils.upload_train_and_template_to_s3(
    bucket_prefix="intent_dataset_flant5",
    train_path=intent_dataset_train_file,
    template_path=ft_template_file,
)

## Fine-tuning the model

We configure the `JumpStartEstimator`, specifying our chosen model and other parameters such as `instance_type` and the hyperparameters (in this example, we train for 5 epochs). This estimator drives the fine-tuning process.

In [None]:
from sagemaker.jumpstart.estimator import JumpStartEstimator

estimator = JumpStartEstimator(
    model_id=model_id,
    disable_output_compression=True,
    instance_type="ml.g5.24xlarge",
    role=utils.get_role_arn(),
)

estimator.set_hyperparameters(
    instruction_tuned="True", epochs="5", max_input_length="1024"
)

estimator.fit({"training": train_data_location})

If you experience a disconnection during training, you can rejoin the ongoing training job by using the code below:

```py
estimator = JumpStartEstimator.attach(
    training_job_name="job-name",
    model_id=model_id,
)
estimator.logs()
```

You can locate the `training_job_name` in the AWS console or by using [awscli](https://awscli.amazonaws.com/v2/documentation/api/latest/reference/sagemaker/list-training-jobs.html).
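
As an alternative to the console, the job name can also be looked up programmatically. The sketch below assumes a boto3 SageMaker client is passed in and uses its `list_training_jobs` call. The helper name and the `"flan-t5"` name filter are illustrative assumptions, since the actual job name depends on how the job was launched.

```python
def latest_training_job_name(sagemaker_client, name_contains):
    """Return the newest training job whose name contains the given text, or None."""
    response = sagemaker_client.list_training_jobs(
        NameContains=name_contains,
        SortBy="CreationTime",
        SortOrder="Descending",
        MaxResults=1,
    )
    summaries = response["TrainingJobSummaries"]
    return summaries[0]["TrainingJobName"] if summaries else None


# Usage (requires AWS credentials; the "flan-t5" filter is a guess):
# import boto3
# job_name = latest_training_job_name(boto3.client("sagemaker"), "flan-t5")
```

The returned name can then be passed as `training_job_name` to `JumpStartEstimator.attach`.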


When fine-tuning is complete, we deploy the fine-tuned model to an endpoint.

In [None]:
finetuned_endpoint_name = "flan-t5-xl-ft-infoext"
finetuned_model_name = finetuned_endpoint_name

If you have already deployed the endpoint, you can run the following code instead of redeploying it:

```python
finetuned_predictor = utils.get_predictor(
    endpoint_name=finetuned_endpoint_name,
)
```

In [None]:
# Deploying the finetuned model to an endpoint
finetuned_predictor = estimator.deploy(
    endpoint_name=finetuned_endpoint_name,
    model_name=finetuned_model_name,
)

Now, let's compare the fine-tuned model with the base model on the ambiguous queries that we saw in the previous section.

In [None]:
ambiguous_queries = [
    {
        "query": "I want to change my coverage plan. But I'm not seeing where to do this on the online site. Could you please show me how?",
        "main_intent": "technical_support",
        "sub_intent": "portal_navigation",
    },
    {
        "query": "I'm unhappy with the current benefits of my plan and I'm considering canceling unless there are better alternatives. What can you offer?",
        "main_intent": "customer_retention",
        "sub_intent": "free_product_upgrade",
    },
]
for query in ambiguous_queries:
    question = query["query"]
    print("query:", question, "\n")
    print(
        "expected intent:  ", f"{query['main_intent']}:{query['sub_intent']}"
    )

    prompt = utils.FT_PROMPT.format(query=question)
    response = utils.flant5(base_predictor, user=prompt, max_tokens=13)
    print("base model:  ", utils.parse_output(response))

    response = utils.flant5(finetuned_predictor, user=prompt, max_tokens=13)
    print("finetuned model:  ", utils.parse_output(response))
    print("-" * 80)

As you can see, the fine-tuned model can accurately classify ambiguous queries.

Finally, we evaluate the fine-tuned model on the test dataset to measure its overall accuracy and its accuracy for each intent.
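
As a rough illustration of these two metrics, the sketch below computes overall accuracy and per-intent accuracy from lists of expected and predicted labels. It is not the implementation of `utils.evaluate_model`. Grouping by main intent and the sample labels are assumptions made for the example.

```python
from collections import defaultdict


def accuracy_report(expected, predicted):
    """Return (overall accuracy, accuracy per main intent) for exact-match labels."""
    if not expected or len(expected) != len(predicted):
        raise ValueError("expected and predicted must be non-empty and equal in length")
    correct = defaultdict(int)
    total = defaultdict(int)
    for exp, pred in zip(expected, predicted):
        main_intent = exp.split(":")[0]  # group results by the main intent
        total[main_intent] += 1
        correct[main_intent] += int(exp == pred)
    overall = sum(correct.values()) / sum(total.values())
    per_intent = {intent: correct[intent] / total[intent] for intent in total}
    return overall, per_intent


# Made-up labels for illustration only.
expected = [
    "technical_support:portal_navigation",
    "customer_retention:free_product_upgrade",
    "technical_support:password_reset",
]
predicted = [
    "technical_support:portal_navigation",
    "customer_retention:discount",
    "technical_support:password_reset",
]
overall, per_intent = accuracy_report(expected, predicted)
print(round(overall, 2), per_intent)
```

A per-intent breakdown like this helps show whether a high overall score hides weak performance on a particular intent.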


In [None]:
test_dataset = utils.load_dataset(intent_dataset_test_file)

res = utils.evaluate_model(
    predictor=finetuned_predictor,
    llm=utils.flant5,
    dataset=test_dataset,
    prompt_template=utils.FT_PROMPT,
    response_formatter=utils.flant5_output_intent_formatter,
)

utils.print_eval_result(res, test_dataset)

> To evaluate the base model, we can pass `base_predictor` as the `predictor` argument of the `utils.evaluate_model` function.

In this notebook, we improved our model's performance through fine-tuning. By optimizing a smaller model (FLAN-T5 XL) for complex classification tasks, we attained better accuracy than the in-context learning approach used with much larger models.