# Handling Local Data
To load datasets that are stored on your own machine or on a remote server, we can still use the `load_dataset()` function. This time, we pass the type of loading script as the first argument, along with a `data_files` argument that specifies the path to one or more files.

!["load_dataset()"](data/chapter_5/load_dataset.png "load_dataset()")

### Loading a local dataset

| Data format | Loading script | Example |
|-------------|----------------|---------|
| CSV & TSV |`csv`|`load_dataset("csv", data_files="my_file.csv")`|
| Text files |`text`|`load_dataset("text", data_files="my_file.txt")`|
| JSON & JSON Lines |`json`|`load_dataset("json", data_files="my_file.json")`|
| Pickled DataFrames |`pandas`|`load_dataset("pandas", data_files="my_dataframe.pkl")`|

For this example, let's use the [SQuAD-it](https://github.com/crux82/squad-it/) dataset, which is a large-scale **JSON** dataset for question answering in Italian. It's hosted on GitHub, so let's first download the compressed files `SQuAD_it-train.json.gz` and `SQuAD_it-test.json.gz` into our `data/chapter_5` directory using `wget`, and then decompress them using `gzip`:

In [None]:
!cd data/chapter_5 && wget https://github.com/crux82/squad-it/raw/master/SQuAD_it-train.json.gz
!cd data/chapter_5 && wget https://github.com/crux82/squad-it/raw/master/SQuAD_it-test.json.gz

!cd data/chapter_5 && gzip -dkv SQuAD_it-*.json.gz

Now that we have our data in the `JSON` format, we can simply use the `load_dataset()` function. We just need to know whether we’re dealing with **ordinary JSON** (*similar to a nested dictionary*) or **JSON Lines** (*line-separated JSON*). Like many question answering datasets, **SQuAD-it** uses the *nested format*, with all the text stored in a **data field**. This means we can load the dataset by specifying the `field='data'` argument:

In [None]:
from datasets import load_dataset

squad_it_dataset = load_dataset("json", data_files="data/chapter_5/SQuAD_it-train.json", field="data")

squad_it_dataset
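
To make the difference between the two layouts concrete, here is a minimal sketch that uses only Python's standard `json` module. The records are made up for illustration and are not taken from SQuAD-it; the point is only that nested JSON keeps its records under a top-level key (here `data`), while JSON Lines stores one object per line.

```python
import json

# Ordinary (nested) JSON: a single document with the records under a "data" key.
# These records are invented for illustration.
nested_text = '{"version": "1.1", "data": [{"title": "Roma"}, {"title": "Milano"}]}'
nested_records = json.loads(nested_text)["data"]

# JSON Lines: one self-contained JSON object per line, with no enclosing list.
jsonl_text = '{"title": "Roma"}\n{"title": "Milano"}\n'
jsonl_records = [json.loads(line) for line in jsonl_text.splitlines() if line.strip()]

# Both layouts carry the same records once parsed.
print(nested_records == jsonl_records)
```

Selecting the `"data"` key by hand here plays the same role as the `field="data"` argument passed to `load_dataset()`.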

As we can see, by default, loading local files creates a `DatasetDict` object with only a **train** split. What we really want is to include both the **train** and **test** splits in a single `DatasetDict` object so we can apply `Dataset.map()` functions across both splits at once. To do this, we can pass a dictionary to the `data_files` argument that maps each split name to a file associated with that split, in the form
```python
data_files={"train":"path to the training data", "test":"path to the testing data"}
```
For SQuAD-it, this looks like:

In [None]:
data_files = {
    "train":"data/chapter_5/SQuAD_it-train.json",
    "test":"data/chapter_5/SQuAD_it-test.json"
}
squad_it_dataset = load_dataset("json", data_files=data_files, field="data")
squad_it_dataset

The loading scripts in Datasets actually support automatic decompression of the input files, so we could have skipped the use of gzip by pointing the `data_files` argument directly to the compressed files:
```python
data_files = {
    "train": "data/chapter_5/SQuAD_it-train.json.gz", 
    "test": "data/chapter_5/SQuAD_it-test.json.gz"
}
squad_it_dataset = load_dataset("json", data_files=data_files, field="data")
```
This can be useful if you don’t want to manually decompress many `GZIP` files. The automatic decompression also applies to other common formats like `ZIP` and `TAR`, so you just need to point `data_files` to the compressed files.

> The `data_files` argument is also quite flexible and can be either *a single file path*, *a list of file paths*, or *a dictionary* that maps split names to file paths. You can also *glob files* that match a *specified pattern* according to the rules used by the `Unix shell` (e.g., you can glob all the `JSON` files in a directory as a single split by setting `data_files="*.json"`). See the [Datasets documentation](https://huggingface.co/docs/datasets/loading#local-and-remote-files) for more details.
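
As a rough illustration of shell-style pattern matching, the sketch below uses `fnmatch` from Python's standard library on a list of hypothetical file names. It only demonstrates the matching rules; how `load_dataset()` resolves patterns internally is not shown here and may differ in detail.

```python
from fnmatch import fnmatch

# Hypothetical file names, used only to illustrate the pattern rules.
file_names = ["train.json", "test.json", "notes.txt", "train.json.gz"]

# "*" matches any run of characters, and the pattern must match the whole name.
json_files = [name for name in file_names if fnmatch(name, "*.json")]
print(json_files)
```

Note that `train.json.gz` is not selected, because the pattern has to match the entire file name rather than just a part of it.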

### Loading a remote dataset

Fortunately, loading *remote files* is just as simple as loading *local* ones!
<br />
Instead of providing a path to *local files*, we point the `data_files` argument to **one or more URLs** where the *remote files* are stored.

In [None]:
url =  "https://github.com/crux82/squad-it/raw/master/"

data_files = {
    "train": url + "SQuAD_it-train.json.gz",
    "test": url + "SQuAD_it-test.json.gz",
}

squad_it_dataset = load_dataset("json", data_files=data_files, field="data")
squad_it_dataset

# Data Manipulation

The `DatasetDict` object comes with many methods for manipulating the original dataset.
<br />
For this example, we’ll use the [Drug Review Dataset](https://archive.ics.uci.edu/ml/datasets/Drug+Review+Dataset+%28Drugs.com%29) that’s hosted on the [UC Irvine Machine Learning Repository](https://archive.ics.uci.edu/ml/index.php), which contains patient reviews on various drugs, along with the condition being treated and a 10-star rating of the patient’s satisfaction.

In [None]:
!cd data/chapter_5/ && wget "https://archive.ics.uci.edu/ml/machine-learning-databases/00462/drugsCom_raw.zip"
!cd data/chapter_5/ && unzip drugsCom_raw.zip

As we can see, the data is in the `TSV` format, which is a variant of `CSV` that uses tabs instead of commas as the separator. So, when loading these files using `load_dataset()`, we specify `csv` as the *loading script* and, most importantly, pass the `delimiter="\t"` argument:

In [None]:
from datasets import load_dataset

data_files = {
    "train" : "data/chapter_5/drugsComTrain_raw.tsv",
    "test" : "data/chapter_5/drugsComTest_raw.tsv"
}

drug_dataset = load_dataset("csv", data_files=data_files, delimiter="\t")

Now that we have the `DatasetDict` object, we can create a random sample to get a quick feel for the type of data we’re working with. To do so, we simply chain the `Dataset.shuffle()` and `Dataset.select()` functions: the first randomly shuffles the data (we can pass the `seed` argument to reproduce the same shuffle later), and the second selects the first *n* examples:

In [None]:
drug_sample = drug_dataset["train"].shuffle(seed=42).select(range(1000))

drug_sample[:3]

From the sample above, we can see that a few pre-processing steps are needed before tokenising this data or passing it to the model:
  + The `Unnamed: 0` column needs to be renamed to `patient_id`.
  + The `condition` column includes a mix of *uppercase* and *lowercase* labels.
  + The `review` values are of varying length and contain a mix of line separators (`\r\n`) as well as HTML character codes like `&#039;`.

So, we can use built-in methods: `rename_column()` to rename the column, `filter()` to drop rows whose `condition` is missing, and `map()` to lowercase the `condition` values and unescape the HTML character codes.

In [None]:
import html

# rename the column name
drug_dataset = drug_dataset.rename_column(
    original_column_name="Unnamed: 0",
    new_column_name="patient_id"
)

# map condition column values to lowercase
def lowercase_condition(data):
    # without batched=True in map(), use: return {"condition": data["condition"].lower()}
    return {"condition": [row.lower() for row in data["condition"]]}


# let's first remove all the rows with null values, otherwise the above
# function will throw an error
drug_dataset = drug_dataset.filter(lambda x: x["condition"] is not None)

# map lowercase
drug_dataset = drug_dataset.map(lowercase_condition, batched=True)


# unescape all the HTML special characters in our corpus
drug_dataset =  drug_dataset.map(
    lambda x: {"review": [html.unescape(row) for row in x["review"]]},
    batched=True
)


drug_dataset["train"][:2]
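
To see what a function passed to `map()` with `batched=True` actually receives, here is a small sketch in plain Python. The batch below is hand-written in the "dictionary of lists" layout with made-up values, and the helper names are only for illustration; no `datasets` call is involved.

```python
import html

def lowercase_condition(batch):
    # With batched=True, each column arrives as a list of values.
    return {"condition": [value.lower() for value in batch["condition"]]}

def unescape_review(batch):
    # html.unescape turns character references such as &#039; back into characters.
    return {"review": [html.unescape(text) for text in batch["review"]]}

# A hand-written batch with made-up values.
batch = {
    "condition": ["Birth Control", "ADHD"],
    "review": ["It&#039;s working well", "No side effects"],
}

# Each function returns the columns it wants to overwrite.
batch.update(lowercase_condition(batch))
batch.update(unescape_review(batch))
print(batch)
```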

>In Python, `lambda` functions are small functions that you can define without explicitly naming them. They take the general form `lambda <arguments> : <expression>`,
where `lambda` is one of Python’s special keywords, `<arguments>` is a comma-separated list of parameters that define the *inputs* to the function, and `<expression>` is the single expression whose value the function returns. For example, we can define a simple lambda function that squares a number as follows: `lambda x : x * x`
To apply this function to an input, we need to wrap it and the input in parentheses:
`(lambda x: x * x)(3) -> 9`
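
A short runnable sketch of the same idea, using only built-ins. The variable names are chosen for illustration:

```python
# Calling a lambda directly by wrapping it and its input in parentheses.
nine = (lambda x: x * x)(3)

# Passing a lambda to another function, which is the most common use.
squares = list(map(lambda x: x * x, [1, 2, 3]))

# A lambda with two comma-separated arguments.
triangle_area = (lambda base, height: 0.5 * base * height)(4, 8)

print(nine, squares, triangle_area)
```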

### From Datasets to DataFrames and back

We can use the `set_format()` method of the `DatasetDict` object to change its output format to one of *Pandas*, *NumPy*, *PyTorch*, *TensorFlow*, or *JAX*. To switch back to the default format, we simply call the `reset_format()` method.

In [None]:
drug_dataset.set_format("pandas")

drug_dataset["train"][:3]

In [None]:
drug_dataset.reset_format()

drug_dataset["train"][:3]

### Creating a validation set
Each `Dataset` inside the `DatasetDict` also provides a `Dataset.train_test_split()` method, based on the well-known functionality from `scikit-learn`, which we can use to further split the data into train, validation and test sets.


In [None]:
# 80-20 percent train-validation split on the training dataset
drug_dataset_clean = drug_dataset["train"].train_test_split(train_size=0.8, seed=41)

# name the 20% split data as the validation
drug_dataset_clean["validation"] = drug_dataset_clean.pop("test")

# Add the original test dataset
drug_dataset_clean["test"] = drug_dataset["test"]

drug_dataset_clean
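
The idea behind a seeded split can be sketched in plain Python. This is only an illustration under simple assumptions (shuffle the row indices, then cut them at the requested fraction). The `split_indices` helper is hypothetical and is not how `train_test_split()` is implemented.

```python
import random

def split_indices(num_rows, train_size, seed):
    """Shuffle row indices reproducibly and cut them into two disjoint parts."""
    indices = list(range(num_rows))
    random.Random(seed).shuffle(indices)   # same seed, same order
    cut = int(num_rows * train_size)
    return indices[:cut], indices[cut:]

train_idx, validation_idx = split_indices(num_rows=10, train_size=0.8, seed=41)
print(len(train_idx), len(validation_idx))
```

Because the two parts come from one shuffled list, no row can end up in both the training and the validation set.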

### Saving a dataset
To save a dataset to disk:

| Data format | Function |
|-------------|----------|
|*Arrow*|`Dataset.save_to_disk()`|
|*CSV*|`Dataset.to_csv()`|
|*JSON*|`Dataset.to_json()`|

For example, let’s save our cleaned dataset in the Arrow format:

In [None]:
drug_dataset_clean.save_to_disk("data/chapter_5/drug-reviews")

!ls data/chapter_5/drug-reviews/*

Once the dataset is saved, we can load it by using the `load_from_disk()` function as follows:

In [None]:
from datasets import load_from_disk

drug_dataset_reloaded = load_from_disk("data/chapter_5/drug-reviews")
drug_dataset_reloaded

For the **CSV** and **JSON** formats, we have to store each split as a separate file. One way to do this is by iterating over the keys and values in the `DatasetDict` object. This saves each split in JSON Lines format, where each row in the dataset is stored as a single line of JSON.

In [None]:
for split, dataset in drug_dataset_clean.items():
    dataset.to_json(f"data/chapter_5/drug-reviews-{split}.jsonl")

And to load the data we can simply use the `load_dataset()` function:

In [None]:
data_files = {
    "train": "data/chapter_5/drug-reviews-train.jsonl",
    "validation": "data/chapter_5/drug-reviews-validation.jsonl",
    "test": "data/chapter_5/drug-reviews-test.jsonl",
}
drug_dataset_reloaded = load_dataset("json", data_files=data_files)

drug_dataset_clean = drug_dataset_reloaded

drug_dataset_clean

## Example

Let's train a classifier that can predict the patient condition based on the drug review.

1. Download and prepare the data (there is no need to run the cell below if the `drug_dataset_reloaded` object is still in memory; alternatively, just re-run the cell above if the saved files are in place):

In [None]:
from datasets import load_dataset
import html

# download the data
!cd data/chapter_5/ && curl -O "https://archive.ics.uci.edu/ml/machine-learning-databases/00462/drugsCom_raw.zip"
!cd data/chapter_5/ && unzip -o drugsCom_raw.zip


# load the data
data_files = {
    "train" : "data/chapter_5/drugsComTrain_raw.tsv",
    "test" : "data/chapter_5/drugsComTest_raw.tsv"
}

drug_dataset = load_dataset(
    "csv",
    data_files=data_files,
    delimiter='\t'
)


# filter out the data
## rename the column
drug_dataset = drug_dataset.rename_column(
    original_column_name="Unnamed: 0",
    new_column_name="patient_id"
)

## map condition column values to lowercase
def lowercase_condition(data):
    return {"condition": [row.lower() for row in data["condition"]]}

## filter out the null rows
drug_dataset = drug_dataset.filter(lambda x: x["condition"] is not None)

## map lowercase
drug_dataset = drug_dataset.map(lowercase_condition, batched=True)


# unescape all the special characters
drug_dataset = drug_dataset.map(
    lambda x: {"review": [html.unescape(review_row) for review_row in x["review"]]},
    batched=True
)

drug_dataset = drug_dataset.map(
    lambda x: {"condition": [html.unescape(condition_row) for condition_row in x["condition"]]},
    batched=True
)


# Creating validation set
drug_dataset_clean = drug_dataset["train"].train_test_split(
    train_size=0.8,
    seed=41
)

drug_dataset_clean["validation"] = drug_dataset_clean.pop("test")


drug_dataset_clean["test"] = drug_dataset["test"]

drug_dataset_clean

We have quite a lot of data, `200k+` rows in total. For the sake of making the training process faster, let's take only a randomly shuffled `5%` sample of each split for training and evaluating the model.
> Note: if you would like to train the model on the whole dataset, simply skip the cell below.

In [None]:
# to only take 5% of the data per split 
pct = 0.05

drug_dataset_clean["train"] = drug_dataset_clean["train"].shuffle(seed=42).select(range(int(pct*len(drug_dataset_clean["train"]))))
drug_dataset_clean["validation"] = drug_dataset_clean["validation"].shuffle(seed=42).select(range(int(pct*len(drug_dataset_clean["validation"]))))
drug_dataset_clean["test"] = drug_dataset_clean["test"].shuffle(seed=42).select(range(int(pct*len(drug_dataset_clean["test"]))))

drug_dataset_clean

Since there are many distinct values in the `condition` column, let's first look at all of them, keep only the 5 most frequent conditions as the *labels* for this *classification task*, and mark all the rest as `other`. We can use the `Counter` class from the `collections` module to get the distribution over the `condition` column:

In [None]:
from collections import Counter

# simple case – one label per example
train_counts = Counter(drug_dataset_clean["train"]["condition"])
print(f"Train dataset condition distribution:\n\t{train_counts}")

We can see that `birth control`, `depression`, `pain`, `anxiety` and `acne` are the 5 conditions that occur most often in our dataset.
In machine learning, **stratification** means splitting the data so that every split preserves roughly the same label distribution. Evaluation results are only trustworthy when the **train**, **validation** and **test** splits follow the same pattern, so always check that the distribution is stratified when splitting the data or choosing `labels`. Let's make sure that is the case here: for the `validation` and `test` datasets we would also like to see that `birth control`, `depression`, `pain`, `anxiety` and `acne` are the top 5 conditions.

> Note: Normally, you first choose the labels, refine the data, and only then split it into train, validation and test sets using *stratification* over the refined labels. Here we did it the other way around.

In [None]:
validation_counts = Counter(drug_dataset_clean["validation"]["condition"])
print(f"Validation dataset condition distribution:\n\t{validation_counts}")

test_counts = Counter(drug_dataset_clean["test"]["condition"])
print(f"Test dataset condition distribution:\n\t{test_counts}") 
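
Instead of comparing the printed counters by eye, the check can be written down explicitly. The sketch below uses tiny made-up label lists and a hypothetical `top_labels` helper. With the real data, the lists would be the `condition` columns of each split.

```python
from collections import Counter

def top_labels(labels, k):
    """Return the set of the k most frequent labels."""
    return {label for label, _ in Counter(labels).most_common(k)}

# Tiny made-up splits, for illustration only.
train_labels = ["pain", "acne", "pain", "anxiety", "acne", "pain", "gout"]
validation_labels = ["acne", "pain", "pain", "acne", "anxiety", "gout", "acne"]

# The order of the top labels may differ between splits; here we only compare the sets.
same_top = top_labels(train_labels, 2) == top_labels(validation_labels, 2)
print(same_top)
```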

Now we know the splits will stay comparable when we keep only the top 5 `condition` values for the classification task. Let's create a new column `labels` that copies the `condition` value if it is one of these 5 conditions and otherwise holds `other`.

In [None]:
allowed_conditions = ['birth control', 'depression', 'pain', 'anxiety', 'acne']

drug_dataset_clean = drug_dataset_clean.map(
    lambda x: {"labels": [c if c in allowed_conditions else 'other' for c in x['condition']]},
    batched=True
)

drug_dataset_clean

Now let's use the built-in `unique()` method to see which values are present in the `labels` column:

In [None]:
drug_dataset_clean["train"].unique('labels')

In [None]:
train_label_counts = Counter(drug_dataset_clean["train"]["labels"])
print(f"train dataset labels distribution:\n\t{train_label_counts}") 

Now, since this is a *multi-class classification* task (each review gets exactly one label), we need to convert the text values in the `labels` column (`other`, `birth control`, `depression`, `pain`, `anxiety` and `acne`) into discrete numerical values that the model can use as **labels**. Luckily, the `DatasetDict` object has a `class_encode_column()` method that handles this for us and returns the encoded dataset:

In [None]:
# encode the labels to the right form
drug_dataset_clean = drug_dataset_clean.class_encode_column("labels")

In [None]:
print(drug_dataset_clean["train"].features["labels"])

label_names = drug_dataset_clean["train"].features["labels"].names

print(drug_dataset_clean["train"].unique("labels"))
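
Conceptually, class encoding is a lookup table from class names to integer ids. The sketch below shows that idea in plain Python. It assumes the names are ordered alphabetically before ids are assigned, which is an assumption for illustration and not a statement about the library's internals.

```python
# Made-up string labels, for illustration only.
labels = ["pain", "other", "acne", "other", "birth control"]

# Assumption: class names are sorted alphabetically before ids are assigned.
names = sorted(set(labels))
str2int = {name: index for index, name in enumerate(names)}

encoded = [str2int[label] for label in labels]   # strings to integers
decoded = [names[index] for index in encoded]    # integers back to strings
print(names, encoded)
```

Keeping the list of names around, as `label_names` does in the cell above, is what makes it possible to turn predicted ids back into readable conditions.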

Now, whenever we are dealing with customer reviews, it is good practice to check the number of words in each review. A review might be just a single word like “Great!” or a full-blown essay with thousands of words, and depending on the use case you’ll need to handle these extremes differently. In our case, some reviews contain just a single word, which may be fine for **sentiment analysis** but is not informative when predicting the condition.
<br />
So, to compute the number of words in each review, we’ll use a rough heuristic based on splitting each text on whitespace, and then use the `filter()` function to keep only reviews that contain more than **30 words**:

In [None]:
print("Before review filtering:")
print(drug_dataset_clean)

# returns a new column holding the word count of each review
def compute_review_length(data):
    return {"review_length": [len(row.split()) for row in data["review"]]}

# map the review_length column
drug_dataset_clean  = drug_dataset_clean.map(
    compute_review_length,
    batched=True
)

# keep only rows whose review_length is greater than 30
drug_dataset_clean = drug_dataset_clean.filter(
    lambda x: x["review_length"] > 30
)

print("After review filtering:")
drug_dataset_clean 
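
The boundary behaviour of this filter is easy to check on a few synthetic reviews. The sketch below is plain Python with made-up texts of 1, 30 and 31 words. It uses the same whitespace heuristic and the same `> 30` condition as the cell above.

```python
def review_length(text):
    # Rough heuristic: split on any whitespace and count the pieces.
    return len(text.split())

# Made-up reviews of 1, 30 and 31 words.
reviews = ["Great!", "word " * 30, "word " * 31]
lengths = [review_length(text) for text in reviews]

# Same condition as the filter: strictly more than 30 words.
kept = [text for text in reviews if review_length(text) > 30]
print(lengths, len(kept))
```

A review of exactly 30 words is dropped, because the comparison is strict.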

Now that we have cleaned our data, let's load the tokeniser and the model, tokenise the data, and drop the columns the model does not need. When initialising the model, we also have to pass the `num_labels=6` argument because we are training the model for a multi-class classification task and there are `6` labels in total:

In [None]:
from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
from pprint import pprint

checkpoint = "bert-base-uncased"
tokeniser = AutoTokenizer.from_pretrained(checkpoint)
# initialise the model and also specify the number of labels
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=6)
data_collator = DataCollatorWithPadding(tokenizer=tokeniser)

def tokenisation_function(data):
    return tokeniser(data['review'], truncation=True)

tokenised_datasets = drug_dataset_clean.map(
    tokenisation_function,
    batched=True
)


pprint(tokenised_datasets)


tokenised_datasets = tokenised_datasets.remove_columns(
    column_names=['patient_id', 'drugName', 'condition', 'review', 'rating', 'date', 'usefulCount', 'review_length']
)

tokenised_datasets.set_format("torch")

pprint(tokenised_datasets)

Split the data into batches using `DataLoader`:

In [None]:
train_batch_size = 32
eval_batch_size = min(128, len(tokenised_datasets["validation"]))
test_batch_size = min(128, len(tokenised_datasets["test"]))

train_dataloader = DataLoader(
    dataset=tokenised_datasets["train"],
    batch_size=train_batch_size,
    shuffle=True,
    collate_fn=data_collator
)

eval_dataloader = DataLoader(
    dataset=tokenised_datasets["validation"],
    batch_size=eval_batch_size,
    collate_fn=data_collator
)

test_dataloader = DataLoader(
    dataset=tokenised_datasets["test"],
    batch_size=test_batch_size,
    collate_fn=data_collator
)


print(f"So there are, {len(train_dataloader)} batches of size {train_batch_size} in the training dataset,\n\t{len(eval_dataloader)} batches of size {eval_batch_size} in the evaluation dataset, and\n\t {len(test_dataloader)} batches of size {test_batch_size} in the test dataset")

Initialise the *accelerator*, *optimiser* and *learning rate scheduler* objects:

In [None]:
from accelerate import Accelerator
from torch.optim import AdamW
from transformers import get_scheduler

optimiser = AdamW(
    params=model.parameters(),
    lr=2e-5
)

accelerator = Accelerator()
train_dl, eval_dl, test_dl, model, optimiser = accelerator.prepare(
    train_dataloader,
    eval_dataloader,
    test_dataloader,
    model,
    optimiser
)

num_epochs = 5
num_training_steps = num_epochs * len(train_dl)

# warm up over the first 1% of training; steps are counted in whole numbers
num_warmup_steps = int(0.01 * num_training_steps)

lr_schedular = get_scheduler(
    name="linear",
    optimizer=optimiser,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps
)

print(f"Total training steps {num_training_steps}")
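
The shape of a linear schedule with warmup can be sketched in plain Python. This is an illustration of the general idea only: it assumes the learning rate rises linearly from 0 to its peak during warmup and then decays linearly to 0. The `linear_lr` function is hypothetical and is not the `transformers` implementation, and the step counts are example numbers.

```python
def linear_lr(step, peak_lr, num_warmup_steps, num_training_steps):
    """Learning rate at a given step: linear warmup, then linear decay to zero."""
    if step < num_warmup_steps:
        return peak_lr * step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    decay_steps = max(1, num_training_steps - num_warmup_steps)
    return peak_lr * max(0.0, remaining / decay_steps)

# Example numbers chosen for illustration only.
total_steps = 1000
warmup_steps = int(0.01 * total_steps)   # 1% of training, as a whole number of steps

for step in (0, 5, 10, 505, 1000):
    print(step, linear_lr(step, 2e-5, warmup_steps, total_steps))
```

The rate starts at zero, reaches its peak at the end of warmup, and is back at zero on the final step.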

This time there is no predefined evaluation metric for the dataset, so we have to decide which metrics to use when evaluating our model.
<br />
For a *classification* task, commonly used metrics are **Accuracy**, **Precision**, **Recall** and **F1 Score**. The latter three metrics are derived from a **Confusion Matrix**, which is an `N X N` matrix, where `N` is the *number of classes or categories* to be predicted. For a binary problem, each cell of the *confusion matrix* holds one of these 4 values:
+ **True Positives (TP)** : It is the case where we predicted Yes and the real output was also Yes.
+ **True Negatives (TN)**: It is the case where we predicted No and the real output was also No.
+ **False Positives (FP)**: It is the case where we predicted Yes but it was actually No.
+ **False Negatives (FN)**: It is the case where we predicted No but it was actually Yes. 

For example, suppose we have a binary classification problem with the labels `Yes` and `No`. Here `N = 2`, therefore we get a `2 X 2` *confusion matrix*. Now let's say we tested our model on 165 samples and the resulting *confusion matrix* looks like this:

|              |Predicted No|Predicted Yes|
|--------------|------------|-------------|
|**Actual No** |50|10|
|**Actual Yes**|5|100|

Therefore, out of the 165 predictions, `100` were **TP** (bottom right), `50` were **TN** (top left), `10` were **FP** (top right), and `5` were **FN** (bottom left).
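
The same table can be written as a nested list, which makes it easy to read off the four counts and derive the accuracy. This is a small sketch using the numbers from the example, with rows for the actual class and columns for the predicted class.

```python
# Rows are actual classes, columns are predicted classes, in the order [No, Yes].
confusion = [
    [50, 10],   # actual No:  50 predicted No (TN), 10 predicted Yes (FP)
    [5, 100],   # actual Yes:  5 predicted No (FN), 100 predicted Yes (TP)
]

tn, fp = confusion[0]
fn, tp = confusion[1]

total = tn + fp + fn + tp
accuracy = (tp + tn) / total   # share of all predictions that were correct
print(total, round(accuracy, 3))
```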


These values are useful because we can use them to calculate **Precision**, **Recall** and **F1 Score**:
+ **Precision**: It measures how many of the positive predictions made by the model are actually correct. It's useful when the cost of false positives is high such as in medical diagnoses where predicting a disease when it’s not present can have serious consequences. Therefore, *Precision* helps ensure that when the model predicts a positive outcome, it’s likely to be correct.
$$
\text{Precision} = \frac{TP}{TP+FP}
$$
+ **Recall**: *Recall*, or *sensitivity*, measures how many of the actual positive cases were correctly identified by the model. It is important when missing a positive case (*false negative*) is more costly than raising a false alarm (as in disease detection).
$$
\text{Recall} = \frac{TP}{TP+FN}
$$
+ **F1 Score**: The *F1 Score* is the *harmonic mean* of *precision* and *recall*. It is useful when we need a balance between *precision* and *recall*, as it combines both into a single number. A *high F1 score* means the model performs well on both metrics, i.e., the model is performing well. Its range is `[0,1]`:
$$
\text{F1 Score}=2\times\frac{\text{Precision}\times\text{Recall}}{\text{Precision}+\text{Recall}}
$$
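
As a worked example, the three metrics can be computed for the 165-sample confusion matrix from earlier (TP = 100, TN = 50, FP = 10, FN = 5). The F1 score is computed as the harmonic mean of precision and recall.

```python
tp, tn, fp, fn = 100, 50, 10, 5   # counts from the 165-sample example

precision = tp / (tp + fp)        # 100 / 110
recall = tp / (tp + fn)           # 100 / 105
f1 = 2 * precision * recall / (precision + recall)   # harmonic mean of the two

print(round(precision, 3), round(recall, 3), round(f1, 3))
```

The harmonic mean always lies between the two values and is pulled towards the smaller one, so a model cannot reach a high F1 score by doing well on only one of them.
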
Now, when you have multiple classes, you still often want a single precision/recall/F1 number, but how you combine per-class scores depends on whether you care more about rare classes, common classes, or every example equally. Here’s what each averaging strategy does:

+ **Weighted**: Compute each class’s score, then average them but weight by how many true examples each class has - so common labels count more.

+ **Micro**: Pool all true/false positives and negatives across every example, then compute one overall score - every prediction is equal (large classes dominate).

+ **Macro**: Compute each class’s score and then take the simple average, so every class counts the same, no matter how many examples it has.
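
The three strategies can be compared on a tiny made-up 3-class example. The sketch below computes precision by hand in plain Python. It treats the precision of a class that is never predicted as 0, which is a convention assumed here for illustration.

```python
from collections import Counter

# Made-up labels: class 0 is common (6 examples), class 2 is rare (1 example).
y_true = [0, 0, 0, 0, 0, 0, 1, 1, 1, 2]
y_pred = [0, 0, 0, 0, 0, 1, 1, 1, 0, 0]
classes = sorted(set(y_true))

tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)   # correct predictions per class
fp = Counter(p for t, p in zip(y_true, y_pred) if t != p)   # wrong predictions per predicted class
support = Counter(y_true)                                   # true examples per class

def safe_div(numerator, denominator):
    # Assumed convention: an undefined ratio (class never predicted) counts as 0.
    return numerator / denominator if denominator else 0.0

per_class = {c: safe_div(tp[c], tp[c] + fp[c]) for c in classes}

macro = sum(per_class.values()) / len(classes)
weighted = sum(per_class[c] * support[c] for c in classes) / len(y_true)
micro = safe_div(sum(tp.values()), sum(tp.values()) + sum(fp.values()))

print(round(macro, 3), round(weighted, 3), round(micro, 3))
```

The rare class 2 is never predicted, which drags the macro average down the most, while the weighted and micro averages are dominated by the common class 0.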

> NOTE: On imbalanced data, a model can reach **high accuracy** simply by favouring the common classes while still missing a large number of instances of the rare ones (**low recall**). That's why **accuracy** alone is not a good metric for evaluating a model, and reporting **Recall**, **Precision** and **F1 score** alongside it, where possible, is good practice.
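
A small made-up example shows how accuracy can look good while recall is poor. The "model" below is a stand-in that always predicts the majority class. The data is synthetic and chosen only to make the effect visible.

```python
# Made-up imbalanced data: 95 negatives (0) and 5 positives (1).
y_true = [0] * 95 + [1] * 5
# A degenerate "model" that always predicts the majority class.
y_pred = [0] * 100

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
recall = tp / (tp + fn)

print(accuracy, recall)
```

The accuracy is 95% even though not a single positive case was found.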

Luckily, the `evaluate` library provides a `combine()` function, where we can specify which metrics to use for the evaluation, and when calling `compute()` we can pass the `average` argument to specify which averaging strategy to use:

In [None]:
from evaluate import combine
import torch
from livelossplot import PlotLosses
from tqdm.notebook import tqdm

def perform_evaluation():
    """
    Perform evaluation on the validation set
    """
    # Set model to evaluation mode
    model.eval()

    eval_epoch_loss = []

    # initialising evaluation
    eval_metric = combine(
        evaluations=[
            "accuracy",
            "precision",
            "recall",
            "f1"
        ]
    )

    for batch in eval_dl:
        # Disable gradient computation for evaluation (saves memory and computation)
        with torch.no_grad():
            outputs = model(**batch)
            # Store loss inside no_grad for memory efficiency
            eval_epoch_loss.append(outputs.loss.item())

            # Get predictions for metrics (logits already created without gradients)
            logits = outputs.logits
            refs = batch["labels"]
            preds = torch.argmax(logits, dim=-1)

            # Add batch to evaluation metric
            eval_metric.add_batch(
                predictions=accelerator.gather(preds),
                references=accelerator.gather(refs)
            )
    
    eval_avg_loss = sum(eval_epoch_loss) / len(eval_epoch_loss)
    eval_pred_stats = eval_metric.compute(
        kwargs={
            "average":"weighted"
        }
    )

    return eval_avg_loss, eval_pred_stats

Once the model is trained, we would like to evaluate its performance on the test data, because testing on untouched data gives a more honest measure of how our model will perform on new examples and keeps us from fooling ourselves by tuning to the same data we used for training and validation.
<br />
So, let's write the evaluation function for the test data. This time we also ask for the *confusion_matrix* from `compute()`, along with the other metrics, to further evaluate the model on the test data:

> Note: when evaluating the model on the test data we don't need to look at the loss value

In [None]:
def test_evaluation():
    model.eval()

    test_metric = combine(
        evaluations=[
            "accuracy",
            "precision",
            "recall",
            "f1",
            "confusion_matrix"
        ]
    )

    for batch in test_dl:
        with torch.no_grad():
            outputs = model(**batch)

            logits = outputs.logits
            refs = batch["labels"]
            preds = torch.argmax(logits, dim=-1)

            test_metric.add_batch(
                predictions=accelerator.gather(preds),
                references=accelerator.gather(refs)
            )
    
    return test_metric.compute(
        kwargs={
            "average":"weighted",
            "labels": list(range(len(label_names)))
        }
    )

In [None]:
import matplotlib.pyplot as plt

def plot_cm(cm):
    plt.figure(figsize=(8,6))
    plt.imshow(cm, interpolation="nearest", cmap=plt.cm.Blues)
    plt.xticks(range(len(label_names)), label_names, rotation=45)
    plt.yticks(range(len(label_names)), label_names)
    plt.colorbar()
    plt.xlabel("Predicted label")
    plt.ylabel("True label")
    plt.title("Confusion Matrix")
    plt.tight_layout()
    plt.show()

The main training function:

In [None]:
import pandas as pd

progress_bar = tqdm(range(num_training_steps))

def training_function():
    # initialise the plotter for the learning curve
    plotter = PlotLosses(mode='notebook')

    for epoch in range(num_epochs):
        # ensure model is in training mode
        model.train()
        train_epoch_loss = []

        # metrics for training data
        train_metric = combine(
            evaluations=[
                "accuracy",
                "precision",
                "recall",
                "f1"
            ]   
        )

        for batch in train_dl:
            # Forward Pass (keep gradient attached)
            outputs = model(**batch)
            loss = outputs.loss

            # Backward Pass (while gradients are still attached)
            accelerator.backward(loss)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimiser.step()
            lr_schedular.step()
            optimiser.zero_grad()

            # metric computation
            with torch.no_grad():
                # detach loss for metric computation 
                train_epoch_loss.append(loss.detach().item())

                # detach logits for metric computation
                logits = outputs.logits.detach()
                # no need to detach labels (they don't have gradients)
                refs = batch['labels']
                preds = torch.argmax(logits, dim=-1)

                # add batch to the train metric
                train_metric.add_batch(
                    predictions=accelerator.gather(preds),
                    references=accelerator.gather(refs)
                )
            
            progress_bar.update(1)

        # compute training metrics
        train_avg_loss = sum(train_epoch_loss)/len(train_epoch_loss)
        train_pred_stats = train_metric.compute(
            kwargs={
                "average":"weighted"
            }
        )

        # evaluation phase
        eval_avg_loss, eval_pred_stats = perform_evaluation()

        plotter.update({
            'loss': train_avg_loss,
            'val_loss': eval_avg_loss,
            'acc': train_pred_stats['accuracy'],
            'val_acc': eval_pred_stats['accuracy'],
            'precision': train_pred_stats['precision'],
            'val_precision': eval_pred_stats['precision'],
            'recall': train_pred_stats['recall'],
            'val_recall': eval_pred_stats['recall'],
            'f1': train_pred_stats['f1'],
            'val_f1': eval_pred_stats['f1'],
        })
        plotter.send() 

    # perform evaluation on the test dataset
    test_pred_stats = test_evaluation()
    plotter.update({
        'test_acc': test_pred_stats['accuracy'],
        'test_precision': test_pred_stats['precision'],
        'test_recall': test_pred_stats['recall'],
        'test_f1': test_pred_stats['f1'],
    })
    plotter.send()
    # plot the confusion matrix on the test data
    plot_cm(test_pred_stats["confusion_matrix"])

Let's launch the training loop:

In [None]:
from accelerate import notebook_launcher

notebook_launcher(training_function, num_processes=1)