# Data Preparation for Fine-Tuning a Transformer Model on Azure Databricks

## Overview
This notebook prepares and saves the datasets used to fine-tune a transformer model. The data is loaded from JSONL files, combined, given integer label ids, and stored as Delta tables on Azure Databricks for subsequent machine learning workflows.

## Datasets
- **train_data.jsonl**: Training set to be used for model fine-tuning.
- **val_data.jsonl**: Validation set for model evaluation.
- **test_data.jsonl**: Test set for final assessment.

## Author
- Name: Alessandro Armillotta
- Date: 09/10/2025

## Steps
1. Load JSONL datasets from Azure Databricks Volumes.
2. Combine training and test datasets and prepare the labels.
3. Save processed DataFrames as Delta tables for downstream fine-tuning tasks.

In [0]:
from pyspark.sql import functions as F
import pandas as pd

### Step 1: Load JSONL datasets from Azure Databricks Volumes.

In [0]:
# Each file is in JSON Lines format: one JSON object per line.
test_data = spark.read.json("/Volumes/main/fine_tuning_transformer_model/files/test_data.jsonl")
train_data = spark.read.json("/Volumes/main/fine_tuning_transformer_model/files/train_data.jsonl")
val_data = spark.read.json("/Volumes/main/fine_tuning_transformer_model/files/val_data.jsonl")
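
For reference, JSONL stores one JSON object per line. The sketch below parses two made-up records with the standard library to show the record shape the rest of this notebook relies on, namely a `text` field and a `label` field. The sample records and the `parse_jsonl` helper are illustrative assumptions, not part of the original data or notebook.

```python
import json
from typing import Any, Dict, List

# Hypothetical sample content; the real files live in the Databricks Volume.
sample_jsonl = (
    '{"text": "great product", "label": "positive"}\n'
    '{"text": "arrived broken", "label": "negative"}\n'
)


def parse_jsonl(content: str) -> List[Dict[str, Any]]:
    """Parse JSON Lines text into a list of dicts, skipping blank lines."""
    return [json.loads(line) for line in content.splitlines() if line.strip()]


records = parse_jsonl(sample_jsonl)
```

Each parsed record is a dictionary with the same two keys, which is why the later cells can refer to the columns `text` and `label`.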

### Step 2: Combine training and test datasets and prepare the labels.

In [0]:
# Combine the test and training data into a single DataFrame.
# unionByName matches columns by name rather than by position.
union_df = test_data.unionByName(train_data)
union_df.count()

In [0]:
display(union_df)

In [0]:
# Create integer labels: many transformer models require labels to be integers.
# Temporarily add the validation data so the label set covers all three splits.
union_val_tmp = union_df.unionByName(val_data)

In [0]:
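# Count rows per label across the training, test, and validation data.
# Sorting by label keeps the id assignment deterministic across runs;
# groupBy alone returns rows in no guaranteed order.
labels_df = union_val_tmp.groupBy("label").count().orderBy("label")
labels = labels_df.collect()

# Build both directions of the mapping between integer ids and label strings.
id2label = {index: row.label for index, row in enumerate(labels)}
label2id = {row.label: index for index, row in enumerate(labels)}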
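
The two dictionaries are inverses of each other. Because the row order returned by a `groupBy` is not guaranteed, one way to keep ids stable across runs is to sort the labels before enumerating them. The following is a minimal pure-Python sketch of that idea; `build_label_maps` and the example label values are hypothetical and only illustrate the mapping logic.

```python
from typing import Dict, Iterable, Tuple


def build_label_maps(labels: Iterable[str]) -> Tuple[Dict[int, str], Dict[str, int]]:
    """Return (id2label, label2id) with ids assigned in sorted label order."""
    unique_sorted = sorted(set(labels))
    id2label = dict(enumerate(unique_sorted))
    label2id = {label: index for index, label in id2label.items()}
    return id2label, label2id


# Hypothetical label values for illustration only.
example_id2label, example_label2id = build_label_maps(
    ["spam", "ham", "spam", "other"]
)
```

With sorted input, the same set of labels always yields the same ids, regardless of the order in which rows arrive.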

In [0]:
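# Replace each label string with its integer id using a vectorized pandas UDF.
@F.pandas_udf('integer')
def replace_labels_with_ids(labels: pd.Series) -> pd.Series:
  # Raises KeyError if a label is missing from label2id.
  return labels.apply(lambda x: label2id[x])

# Add the label_id column to the combined training/test data.
union_df = union_df.select(
    replace_labels_with_ids(union_df.label).alias('label_id'),
    union_df.text,
    union_df.label,
)

# Add the label_id column to the validation data.
val_data = val_data.select(
    replace_labels_with_ids(val_data.label).alias('label_id'),
    val_data.text,
    val_data.label,
)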
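
A pandas UDF receives each batch of the column as a `pd.Series`, so the mapping step can be tried locally on a plain Series without a cluster. The sketch below is one possible variant that checks for unknown labels before mapping; `to_ids`, `example_label2id`, and the sample batch are hypothetical names and values, not the notebook's own.

```python
import pandas as pd

# Hypothetical mapping and batch, standing in for label2id and one UDF batch.
example_label2id = {"ham": 0, "spam": 1}
batch = pd.Series(["spam", "ham", "spam"])


def to_ids(labels: pd.Series, mapping: dict) -> pd.Series:
    """Map label strings to ids, failing loudly on labels missing from the mapping."""
    unknown = set(labels.unique()) - set(mapping)
    if unknown:
        raise ValueError(f"labels without an id: {sorted(unknown)}")
    # Series.map would produce NaN for missing keys, hence the check above.
    return labels.map(mapping).astype("int32")


ids = to_ids(batch, example_label2id)
```

Checking up front gives a readable error message listing the offending labels, instead of a `KeyError` raised from inside a Spark task.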


### Step 3: Save processed DataFrames as Delta tables for downstream fine-tuning tasks.

In [0]:
union_df.write.mode("overwrite").option("mergeSchema", "true").saveAsTable("main.fine_tuning_transformer_model.train_data")

In [0]:
val_data.count()

In [0]:
val_data.write.mode("overwrite").option("mergeSchema", "true").saveAsTable("main.fine_tuning_transformer_model.val_data")

In [0]:
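# Save the per-label row counts. Note that the integer ids held in label2id are not stored in this table.
labels_df.write.mode("overwrite").option("mergeSchema", "true").saveAsTable("main.fine_tuning_transformer_model.labels")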