# Transfer learning using ART

Hi! In this tutorial, we will walk you through the process of using ART to perform transfer learning. We will use the [Yelp Reviews](https://huggingface.co/datasets/yelp_review_full) dataset and the `prajjwal1/bert-tiny` model from HuggingFace. We will train a classifier to predict the rating of a review (five classes, labeled 0 to 4) and use ART to carry out transfer learning step by step. Most of the code follows [HF's tutorial](https://huggingface.co/docs/transformers/training), with some modifications to make it work with ART.

Just to remind you: the main goal of ART is to follow [Karpathy's recipe for training neural networks](https://karpathy.github.io/2019/04/25/recipe/).

We'll do everything in a script: your task is to fill in `run.py` according to the instructions in this tutorial.


In [None]:
!pip install art nltk wordcloud

## Data Analysis

First, we need to download the data and analyze it. We'll use HuggingFace's `datasets` library for this, and we'll wrap the dataset in Lightning's `LightningDataModule` to make it easier to use with PyTorch Lightning. We prepared the dataset for you in `dataset.py`; check it out, it's nothing more than a simple Lightning DataModule. The `main()` function in `run.py` is already set up to download the data and show you a sample from it:

In [None]:
!python run.py

<details>

<summary>Correct main()</summary>

```py
def main():
    data = YelpReviews()
    print(data.dataset["train"][100])
```
</details>

Now we can become one with the data. We want some statistics that will be helpful throughout the whole project, so we prepared a data analysis step for you in `steps.py`.
Now it's your turn! Fill in the `...` placeholders in `steps.py` to compute the data statistics. You can use the datamodule from `dataset.py` to get the data.

<details>

<summary>Hint for filling in TextDataAnalysis</summary>

* For `number_of_classes`, count the unique targets in the dataset
* For `class_names`, you can get the names from the targets as well
* For the class counts, count the examples of each class in the targets
* For `number_of_unique_words`, calculate the number of unique words in the dataset
* You can make use of `Counter` from the `collections` library and `np.unique` from `numpy`

* Notice the `log_params` function: it's an important function that logs the parameters you specify to the logger. You can use it to log the results of your analysis.
</details>
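
The hints above boil down to a few lines of plain Python. Below is a minimal, stdlib-only sketch on a handful of made-up reviews (the `reviews` list is hypothetical; in `steps.py` the texts and labels come from the datamodule's train dataloader, and `np.unique` can replace the `set` used here):

```python
from collections import Counter

# Hypothetical toy reviews standing in for the Yelp train split: (text, label) pairs
reviews = [
    ("Great food and friendly staff", 4),
    ("Terrible service, cold food", 0),
    ("Okay place, nothing special", 2),
    ("Great service", 4),
]

texts = [text for text, _ in reviews]
targets = [label for _, label in reviews]

# The unique targets give both the number of classes and their names
unique_targets = sorted(set(targets))
number_of_classes = len(unique_targets)
class_names = [str(label) for label in unique_targets]

# Counter maps each class to the number of reviews with that label
number_of_reviews_in_each_class = Counter(targets)

# Whitespace tokenization: punctuation stays attached, so "service," != "service"
unique_words = set()
for text in texts:
    unique_words.update(text.split())
number_of_unique_words = len(unique_words)

print(number_of_classes, class_names, number_of_reviews_in_each_class, number_of_unique_words)
```

Splitting on whitespace keeps the sketch dependency-free; a tokenizer (e.g. from `nltk`) would give a cleaner word count.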

<details>

<summary>Correct TextDataAnalysis</summary>

```python
# Imports needed at the top of steps.py (MatplotLibSaver is ART's figure saver,
# assumed to be imported already in the provided file)
from collections import Counter

import matplotlib.pyplot as plt
import numpy as np
from wordcloud import WordCloud

    def do(self, previous_states):
        targets = []
        texts = []

        # Loop through batches of the YelpReviews datamodule's train dataloader
        for batch in self.datamodule.train_dataloader():
            # 'label' contains the review scores; convert them to plain ints
            targets.extend(int(label) for label in batch['label'])
            # 'text' contains the review text
            texts.extend(batch['text'])

        # Number of unique classes (review scores) in the targets
        number_of_classes = len(np.unique(targets))

        # Class names: the sorted unique scores as strings
        class_names = [str(i) for i in sorted(np.unique(targets))]

        # Number of reviews in each class
        class_counts = Counter(targets)

        # Number of unique (whitespace-separated) words
        unique_words = set()
        for text in texts:
            unique_words.update(text.split())
        number_of_unique_words = len(unique_words)

        # Create a word cloud and save it with ART's saver
        wordcloud = WordCloud().generate(' '.join(texts))
        fig = plt.figure(figsize=(12, 12))
        plt.imshow(wordcloud, interpolation='bilinear')
        plt.axis("off")
        MatplotLibSaver().save(
            fig, self.get_step_id(), self.name, "wordcloud"
        )

        # Every result checked in run.py must be stored here
        self.results.update(
            {
                "number_of_classes": number_of_classes,
                "class_names": class_names,
                "number_of_reviews_in_each_class": class_counts,
                "number_of_unique_words": number_of_unique_words,
            }
        )
```
</details>

Once you've done that, modify the `main()` function as follows:
* read the data
* start the ART project
* add our data analysis step, with checks that its results exist
* run all the steps (for now we have just one)

<details>

<summary>Hint for filling main()</summary>

* To read the data, just initialize the `YelpReviews` class from `dataset.py`, as we did in the previous step.
* To start the ART project, initialize the `ArtProject` class from `art.project`. All it needs is a name for your project (let it be "yelpreviews") and the data you've already read.
* To add the step, call the `add_step` method of `ArtProject`. Pass it an instance of the step (in this case `TextDataAnalysis` from `steps.py`, which you've just filled in) and a list of checks to perform. We want to check that each result of our step exists, so we use the `CheckResultExists` class from `art.checks` for each of ["number_of_classes", "class_names", "number_of_reviews_in_each_class", "number_of_unique_words"].
* To run the project, call the `run_all` method of the initialized project. It runs all the steps you've added to the project.

</details>

<details>

<summary>Correct main()</summary>

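```py
from art.checks import CheckResultExists
from art.project import ArtProject
from dataset import YelpReviews
from steps import TextDataAnalysis


def main():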
    data = YelpReviews()
    project = ArtProject("yelpreviews", data)
    project.add_step(
        TextDataAnalysis(),
        [
            CheckResultExists("number_of_classes"),
            CheckResultExists("class_names"),
            CheckResultExists("number_of_reviews_in_each_class"),
            CheckResultExists("number_of_unique_words")
        ])
    project.run_all()
```
</details>

In [None]:
!python run.py

If you see output like the one below, and `wordcloud.png` in the `checkpoints` folder, we're good to go!
```
Steps status:
data_analysis_Data analysis: Completed. Results:
        number_of_classes: 5
        class_names: ['0', '1', '2', '3', '4']
        number_of_reviews_in_each_class: Counter({1: 240, 2: 208, 4: 189, 0: 189, 3: 174})
```


**Extra tasks**
* Try to write your own check that verifies the word cloud exists
* Try to calculate more statistics that you find useful, save them in the results, and add checks in `run.py` to verify that they exist!
* Try to log the results with the `log_params` function in `steps.py`
* Try to check whether the number of unique words is greater than 500


## Preparation of metrics in our project

Now that we have the data, we can work on models that solve the sentiment analysis problem! We'll start with a simple baseline. But before that, we need to define the metrics that we'll use throughout the entire experiment:

* Set the number of classes: you can hard-code it yourself or use `number_of_classes` from the results of the previous step
* Define the metrics: we'll use Accuracy, Precision, Recall, and CrossEntropyLoss. Initialize each of them in a list called `METRICS`
* Pass that list to the project: `project.register_metrics(METRICS)`
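
To make the three classification metrics concrete, here is a minimal pure-Python sketch of macro averaging. The `macro_metrics` helper is hypothetical, not the torchmetrics implementation, and it scores a class with no predictions or no targets as 0, which libraries may handle differently. It also shows why, for multiclass problems, macro accuracy equals macro recall: per-class accuracy is TP / (TP + FN). This is consistent with the identical MulticlassAccuracy and MulticlassRecall scores the baseline reports later on.

```python
def macro_metrics(preds, targets, num_classes):
    """Macro-averaged accuracy, precision and recall for multiclass predictions."""
    pairs = list(zip(preds, targets))
    per_class_precision, per_class_recall = [], []
    for c in range(num_classes):
        tp = sum(1 for p, t in pairs if p == c and t == c)
        fp = sum(1 for p, t in pairs if p == c and t != c)
        fn = sum(1 for p, t in pairs if p != c and t == c)
        # Classes without predictions / targets score 0 here (a simplification)
        per_class_precision.append(tp / (tp + fp) if tp + fp else 0.0)
        per_class_recall.append(tp / (tp + fn) if tp + fn else 0.0)
    macro_precision = sum(per_class_precision) / num_classes
    macro_recall = sum(per_class_recall) / num_classes
    # Per-class accuracy in the multiclass setting is TP / (TP + FN), i.e. recall,
    # so its macro average equals macro recall
    macro_accuracy = macro_recall
    return {"accuracy": macro_accuracy, "precision": macro_precision, "recall": macro_recall}


scores = macro_metrics(preds=[0, 1, 1, 1, 2, 0], targets=[0, 0, 1, 1, 2, 2], num_classes=3)
print(scores)
```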

<details>

<summary>Correct main()</summary>

```py
from torch import nn
from torchmetrics import Accuracy, Precision, Recall


def main():
    data = YelpReviews()
    project = ArtProject("yelpreviews", data)

    project.add_step(TextDataAnalysis(), [
                     CheckResultExists("number_of_classes"),
                     CheckResultExists("class_names"),
                     CheckResultExists("number_of_reviews_in_each_class"),])
    NUM_CLASSES = 5
    METRICS = [
        Accuracy(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        Precision(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        Recall(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        nn.CrossEntropyLoss()
    ]
    project.register_metrics(METRICS)

    project.run_all()
```
</details>

In [None]:
!python run.py

At this stage you should see that the first step was skipped, because we have already executed it.
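
ART persists step results between script runs (that's how the second `python run.py` can skip the analysis step). To make the add-step / check / skip cycle concrete, here is a hypothetical in-memory toy. `MiniProject` and `result_exists` are made up for illustration and are not ART's API; in particular, this toy keeps results only in memory.

```python
class MiniProject:
    """Hypothetical toy runner: runs steps in order, checks their results, skips finished ones."""

    def __init__(self):
        self.steps = []
        # step name -> results; ART keeps results between runs, this toy keeps them in memory
        self.completed = {}

    def add_step(self, name, step_fn, checks):
        self.steps.append((name, step_fn, checks))

    def run_all(self):
        statuses = {}
        for name, step_fn, checks in self.steps:
            if name in self.completed:
                statuses[name] = "Skipped"  # results already exist, no need to recompute
                continue
            results = step_fn()
            failed = [check.__name__ for check in checks if not check(results)]
            if failed:
                raise ValueError(f"{name}: failed checks {failed}")
            self.completed[name] = results
            statuses[name] = "Completed"
        return statuses


def result_exists(key):
    """Build a check that passes when `key` is among the step's results."""
    def check(results):
        return key in results
    check.__name__ = f"result_exists({key!r})"
    return check


project = MiniProject()
project.add_step("data_analysis", lambda: {"number_of_classes": 5},
                 [result_exists("number_of_classes")])
first_run = project.run_all()   # runs the step
second_run = project.run_all()  # skips it: its results are already stored
```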

**But why do we need metrics defined for the project?**

Take a look at ART's `MetricsCalculator`. It calculates the registered metrics for every step in the project, which lets us compare different models on the same dataset: we register the metrics with the project once, and compare the scores at the end.
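
The result names in the baseline output (e.g. `MulticlassAccuracy-HeuristicBaseline-validate-Evaluate Baseline`) hint at how this works: each registered metric is computed per model, stage and step. Below is a hypothetical toy, not ART's `MetricsCalculator`, that mirrors that naming scheme. The second model and step names (`MyModel`, `MyStep`) are made up.

```python
def accuracy(preds, targets):
    """Fraction of predictions that match the targets."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)


class ToyMetricsCalculator:
    """Hypothetical: compute every registered metric, keyed by metric, model, stage and step."""

    def __init__(self, metrics):
        self.metrics = metrics  # metric name -> function(preds, targets)

    def compute(self, preds, targets, model_name, stage, step_name):
        return {
            f"{metric_name}-{model_name}-{stage}-{step_name}": fn(preds, targets)
            for metric_name, fn in self.metrics.items()
        }


calculator = ToyMetricsCalculator({"Accuracy": accuracy})
targets = [0, 1, 0, 2]
baseline_scores = calculator.compute(
    [0, 1, 1, 1], targets, "HeuristicBaseline", "validate", "Evaluate Baseline"
)
model_scores = calculator.compute([0, 1, 0, 1], targets, "MyModel", "validate", "MyStep")
```

Because both calls use the same registered metrics and targets, the two dictionaries are directly comparable.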


## Baselines

In every project, we have to start with baselines! We prepared one baseline for you in `models/simple_baseline.py`.

The baseline, like every other model used in ART, has to inherit from `ArtModule`, which is a wrapper for PyTorch Lightning's `LightningModule`. `ArtModule` has a few useful methods that we'll use in our project. The most important one is the integration with the previously mentioned `MetricsCalculator`, but we'll come to that later, when we develop the first deep learning model. For now, we use `ml_parse_data`, which parses data specifically for non-deep-learning training (we don't use PyTorch there), and the `baseline_train` method, which "trains" the model. In our case, it just calculates probabilities for each class and returns them. We'll use it for comparison with our deep learning models. Pay attention to the `ml_parse_data` return format: it's a dictionary `{INPUT: X, TARGET: y}`.
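
`HeuristicBaseline`'s actual heuristic lives in `models/simple_baseline.py` and may differ. As a stand-in, here is a hypothetical class-prior baseline in plain Python. `ClassPriorBaseline`, its methods, and the `INPUT`/`TARGET` string values are made up; `parse_data` and `fit` only play the roles of `ml_parse_data` and `baseline_train`.

```python
from collections import Counter

# Hypothetical stand-ins for the INPUT / TARGET keys mentioned above
INPUT, TARGET = "input", "target"


class ClassPriorBaseline:
    """Hypothetical baseline: predicts the training-set class frequencies for every review."""

    def __init__(self, num_classes):
        self.num_classes = num_classes
        self.class_probabilities = None

    def parse_data(self, reviews):
        # Plays the role of ml_parse_data: plain Python lists, no PyTorch
        return {INPUT: [text for text, _ in reviews], TARGET: [label for _, label in reviews]}

    def fit(self, reviews):
        # Plays the role of baseline_train: "training" is just counting labels
        targets = self.parse_data(reviews)[TARGET]
        counts = Counter(targets)
        self.class_probabilities = [counts[c] / len(targets) for c in range(self.num_classes)]
        return self.class_probabilities

    def predict(self, texts):
        # The text is ignored: every review gets the same probability vector
        return [list(self.class_probabilities) for _ in texts]


train = [("great!", 4), ("meh", 1), ("bad", 1), ("fine", 2)]
baseline = ClassPriorBaseline(num_classes=5)
probabilities = baseline.fit(train)
```

An input-independent baseline like this sets the floor any real model has to beat.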

Add the baseline to the project and run it:

* Assign the baseline class to a variable; do not instantiate it!
* Add the `EvaluateBaseline` step to the project, with checks that a score exists for each metric
* Run the project
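
Why pass the class rather than an instance? A plausible reason, which is an assumption about ART's design and not stated in this tutorial, is that each step builds its own fresh model from the class, so no state carries over from one step to the next. The hypothetical toy below (`CountingModel` and `run_step` are made up) shows the difference:

```python
class CountingModel:
    """Hypothetical model with internal state (how many examples it has seen)."""

    def __init__(self):
        self.seen = 0

    def train_on(self, batch):
        self.seen += len(batch)
        return self.seen


def run_step(model_class, batch):
    """Hypothetical step: builds a fresh model from the class it receives, then trains it."""
    model = model_class()
    return model.train_on(batch)


# Passing the class: each step starts from a clean model
first = run_step(CountingModel, [1, 2, 3])
second = run_step(CountingModel, [1, 2, 3])

# Reusing one instance instead: state leaks from the first "step" into the second
shared = CountingModel()
shared.train_on([1, 2, 3])
leaked = shared.train_on([1, 2, 3])
```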


<details>

<summary>Hints for evaluating the baseline</summary>

* To register the step, call the `add_step` function with `step=EvaluateBaseline(baseline)` (`EvaluateBaseline` comes from `art.steps`) and `checks=[CheckScoreExists(metric=metric) for metric in METRICS]`
</details>

<details>

<summary>Correct main()</summary>

```py
def main():
    data = YelpReviews()
    project = ArtProject("yelpreviews", data)

    project.add_step(TextDataAnalysis(), [
                     CheckResultExists("number_of_classes"),
                     CheckResultExists("class_names"),
                     CheckResultExists("number_of_reviews_in_each_class"),])
    NUM_CLASSES = 5
    METRICS = [
        Accuracy(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        Precision(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        Recall(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        nn.CrossEntropyLoss()
    ]
    project.register_metrics(METRICS)

    baseline = HeuristicBaseline
    project.add_step(
        step=EvaluateBaseline(baseline),
        checks=[CheckScoreExists(metric=metric) for metric in METRICS],
    )

    project.run_all()
```
</details>

In [None]:
!python run.py

The correct output should look like this:
```
Steps status:

data_analysis_Data analysis: Skipped. Results:
        number_of_classes: 5
        class_names: ['0', '1', '2', '3', '4']
        number_of_reviews_in_each_class: {'4': 189, '1': 240, '3': 174, '0': 189, '2': 208}

        
HeuristicBaseline_2_Evaluate Baseline: Completed. Results:
        MulticlassAccuracy-HeuristicBaseline-validate-Evaluate Baseline: 0.30702152848243713
        MulticlassPrecision-HeuristicBaseline-validate-Evaluate Baseline: 0.3316725790500641
        MulticlassRecall-HeuristicBaseline-validate-Evaluate Baseline: 0.30702152848243713
```

**Extra tasks**
* Try to write your own baseline in `models/baseline2.py` and evaluate it in the project

## Training the proper model

As you might already have guessed from the choice of tokenizer, we chose bert-tiny (`prajjwal1/bert-tiny`) for this problem. This dataset is hard: expect to reach only about 45% accuracy on the test set.

We prepared the model for you in `models/bert.py`. It's a simple model that uses the `prajjwal1/bert-tiny` checkpoint from HuggingFace. We use `AutoTokenizer` to tokenize the text and `AutoModelForSequenceClassification`, which takes a sequence of tokens and returns logits for each class; for this checkpoint it resolves to `BertForSequenceClassification`. We start from `prajjwal1/bert-tiny` because it's a small pretrained BERT: its encoder already produces useful general-purpose text representations, while the classification head on top is freshly initialized. That makes it a good starting point for our model. We could train only the last layer (a linear classifier) and get reasonable results, but we'll also try fine-tuning the whole model to see whether that does better.

Notice a few things:
* We use `ArtModule` as a wrapper for `LightningModule`: it lets us use the `MetricsCalculator` to calculate metrics for each step in the project
* Look at `compute_loss()`: it just takes the loss already calculated by the `MetricsCalculator`, which is passed inside the data dictionary. Pure ART magic!
* Pay attention to the format in which predictions and data are returned, as in the baselines

Before we train the final model, we'll run some experiments:
* Check the loss at initialization: add `CheckLossOnInit` to the project
* Overfit one batch with an unfrozen backbone: add `OverfitOneBatch` to the project
* Overfit the entire training set with an unfrozen backbone: add `Overfit` to the project


Then, if these steps succeed, we can train on the entire dataset: first with a frozen backbone, then with an unfrozen backbone and a reduced learning rate. Just add `TransferLearning` to the project.
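
In PyTorch, freezing usually means setting `requires_grad = False` on the backbone's parameters (iterating over `model.named_parameters()`). The sketch below only computes which parameter names would be trained in each phase. `trainable_parameters` is a hypothetical helper, and matching `freeze_names` by substring is an assumption about how `TransferLearning` interprets them; the names themselves are an abridged list from a Hugging Face `BertForSequenceClassification`, though a wrapper module may prefix them.

```python
# Parameter names of a Hugging Face BertForSequenceClassification (abridged)
PARAM_NAMES = [
    "bert.embeddings.word_embeddings.weight",
    "bert.encoder.layer.0.attention.self.query.weight",
    "bert.pooler.dense.weight",
    "classifier.weight",
    "classifier.bias",
]


def trainable_parameters(param_names, freeze_names, frozen_phase):
    """Hypothetical helper: which parameters get updated in each transfer-learning phase.

    Assumes a parameter is frozen when its name contains one of `freeze_names`.
    """
    if not frozen_phase:
        # Phase 2: the whole model is trained (usually with a lower learning rate)
        return list(param_names)
    # Phase 1: only parameters outside the frozen backbone are trained
    return [name for name in param_names
            if not any(frozen in name for frozen in freeze_names)]


phase_1 = trainable_parameters(PARAM_NAMES, ["bert"], frozen_phase=True)
phase_2 = trainable_parameters(PARAM_NAMES, ["bert"], frozen_phase=False)
```

In the frozen phase only the freshly initialized classification head learns, so the pretrained encoder is not disturbed by large early gradients from a random head.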

<details>

<summary>Hints for checking loss on initialization</summary>

* Assign the model class without instantiating it: `model = YelpReviewsModel`
* Calculate the expected cross-entropy loss at initialization for the number of classes in the dataset
* Use the logarithm: with uniform predictions over C classes, the expected loss is -log(1/C)
* Add the `CheckLossOnInit` step to check whether the initial loss is as expected
* Initialize `CheckLossOnInit` with the model (`CheckLossOnInit(model)`) and pass one check in the list
* Use `CheckScoreCloseTo`: the cross-entropy loss defined in `METRICS` should be close to the expected loss
</details>
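
The arithmetic behind the expected loss, another item from Karpathy's recipe ("verify loss @ init"): with C classes and uniform initial predictions, cross-entropy is -ln(1/C) = ln C, which for C = 5 is about 1.609. A stdlib check follows; the `cross_entropy` helper is written here for illustration, and treating `CheckScoreCloseTo`'s `rel_tol` like `math.isclose` is an assumption.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one example, -log(softmax(logits)[target]), computed stably."""
    max_logit = max(logits)
    log_sum_exp = max_logit + math.log(sum(math.exp(z - max_logit) for z in logits))
    return log_sum_exp - logits[target]


NUM_CLASSES = 5
# With uniform predictions, every class gets probability 1 / NUM_CLASSES
EXPECTED_LOSS = -math.log(1 / NUM_CLASSES)  # = ln(5), about 1.609

# Equal logits (the idealized version of a freshly initialized head) give uniform probabilities
loss_at_init = cross_entropy([0.0] * NUM_CLASSES, target=2)

# A rel_tol=0.1 comparison in the spirit of CheckScoreCloseTo (assumed semantics)
print(math.isclose(1.7, EXPECTED_LOSS, rel_tol=0.1))  # True
print(math.isclose(2.0, EXPECTED_LOSS, rel_tol=0.1))  # False
```

An initial loss far above ln C usually means the model starts out confidently wrong, e.g. a badly scaled head.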

<details>

<summary><b>CheckLossOnInit in main()</b></summary>

```py
import math  # needed for EXPECTED_LOSS; other imports as in the previous steps


def main():
    data = YelpReviews()
    project = ArtProject("yelpreviews", data)

    project.add_step(
        TextDataAnalysis(),
        [
            CheckResultExists("number_of_classes"),
            CheckResultExists("class_names"),
            CheckResultExists("number_of_reviews_in_each_class"),
            CheckResultExists("number_of_unique_words")
        ])
    NUM_CLASSES = 5
    METRICS = [
        Accuracy(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        Precision(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        Recall(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        nn.CrossEntropyLoss()
    ]
    project.register_metrics(METRICS)

    baseline = HeuristicBaseline
    project.add_step(
        step=EvaluateBaseline(baseline),
        checks=[CheckScoreExists(metric=metric) for metric in METRICS],
    )

    model = YelpReviewsModel

    # Cross-entropy of a uniform prediction over NUM_CLASSES classes: ln(5), about 1.609
    EXPECTED_LOSS = -math.log(1 / NUM_CLASSES)
    project.add_step(
        CheckLossOnInit(model),
        [CheckScoreCloseTo(metric=METRICS[3],
                           value=EXPECTED_LOSS, rel_tol=0.1)]
    )

    project.run_all()
```

</details>

<details>

<summary>Hints for overfitting one batch</summary>

* Use the `OverfitOneBatch` step and set the number of steps (`number_of_steps`) to 40-50
* Use `CheckScoreLessThan` to check whether the loss is less than a value you specify
* The value could be e.g. `0.05`, but everyone can define overfitting differently
</details>
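
The idea behind `OverfitOneBatch`, straight from Karpathy's recipe: a model with enough capacity should drive the loss on one fixed batch close to zero; if it can't, something in the model or training loop is broken. Here is a toy illustration in pure Python: a linear softmax classifier trained with plain gradient descent on a hypothetical batch of five one-hot examples. It is not how ART implements the step, and `lr=5.0` is just a value that works for this toy.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def overfit_one_batch(batch_x, batch_y, num_classes, number_of_steps=40, lr=5.0):
    """Fit a linear softmax classifier to ONE fixed batch; return the loss before each update."""
    num_features = len(batch_x[0])
    weights = [[0.0] * num_features for _ in range(num_classes)]  # weights[class][feature]
    history = []
    for _ in range(number_of_steps):
        grads = [[0.0] * num_features for _ in range(num_classes)]
        total_loss = 0.0
        for x, y in zip(batch_x, batch_y):
            logits = [sum(w * xi for w, xi in zip(weights[c], x)) for c in range(num_classes)]
            probs = softmax(logits)
            total_loss -= math.log(probs[y])
            for c in range(num_classes):
                # Gradient of cross-entropy w.r.t. logit c is p_c - 1[c == y]
                g = probs[c] - (1.0 if c == y else 0.0)
                for f in range(num_features):
                    grads[c][f] += g * x[f]
        n = len(batch_x)
        history.append(total_loss / n)
        for c in range(num_classes):
            for f in range(num_features):
                weights[c][f] -= lr * grads[c][f] / n
    return history


# Hypothetical batch: 5 one-hot "reviews", one per class
batch_x = [[1.0 if i == j else 0.0 for j in range(5)] for i in range(5)]
batch_y = [0, 1, 2, 3, 4]
losses = overfit_one_batch(batch_x, batch_y, num_classes=5, number_of_steps=40)
```

The loss starts at ln(5), matching the loss-on-init check, and falls steadily below the `0.05` threshold used above.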

<details>

<summary><b>OverfitOneBatch main()</b></summary>

```py
import math  # needed for EXPECTED_LOSS; other imports as in the previous steps


def main():
    data = YelpReviews()
    project = ArtProject("yelpreviews", data)

    project.add_step(
        TextDataAnalysis(),
        [
            CheckResultExists("number_of_classes"),
            CheckResultExists("class_names"),
            CheckResultExists("number_of_reviews_in_each_class"),
            CheckResultExists("number_of_unique_words")
        ])
    NUM_CLASSES = 5
    METRICS = [
        Accuracy(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        Precision(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        Recall(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        nn.CrossEntropyLoss()
    ]
    project.register_metrics(METRICS)

    baseline = HeuristicBaseline
    project.add_step(
        step=EvaluateBaseline(baseline),
        checks=[CheckScoreExists(metric=metric) for metric in METRICS],
    )

    model = YelpReviewsModel

    # Cross-entropy of a uniform prediction over NUM_CLASSES classes: ln(5), about 1.609
    EXPECTED_LOSS = -math.log(1 / NUM_CLASSES)
    project.add_step(
        CheckLossOnInit(model),
        [CheckScoreCloseTo(metric=METRICS[3],
                           value=EXPECTED_LOSS, rel_tol=0.1)]
    )

    project.add_step(
        step=OverfitOneBatch(model, number_of_steps=40),
        checks=[CheckScoreLessThan(metric=METRICS[3], value=0.05)],
    )

    project.run_all()
```

</details>

<details>

<summary>Hints for overfitting the entire trainset</summary>

* Use the `Overfit` step and set the number of epochs (`max_epochs`) to 40-50
* Use `CheckScoreLessThan` to check whether the loss is less than a value you specify
* The value could be e.g. `0.1`, but everyone can define overfitting differently
</details>

<details>

<summary><b>Overfit main()</b></summary>

```py
import math  # needed for EXPECTED_LOSS; other imports as in the previous steps


def main():
    data = YelpReviews()
    project = ArtProject("yelpreviews", data)

    project.add_step(
        TextDataAnalysis(),
        [
            CheckResultExists("number_of_classes"),
            CheckResultExists("class_names"),
            CheckResultExists("number_of_reviews_in_each_class"),
            CheckResultExists("number_of_unique_words")
        ])
    NUM_CLASSES = 5
    METRICS = [
        Accuracy(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        Precision(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        Recall(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        nn.CrossEntropyLoss()
    ]
    project.register_metrics(METRICS)

    baseline = HeuristicBaseline
    project.add_step(
        step=EvaluateBaseline(baseline),
        checks=[CheckScoreExists(metric=metric) for metric in METRICS],
    )

    model = YelpReviewsModel

    # Cross-entropy of a uniform prediction over NUM_CLASSES classes: ln(5), about 1.609
    EXPECTED_LOSS = -math.log(1 / NUM_CLASSES)
    project.add_step(
        CheckLossOnInit(model),
        [CheckScoreCloseTo(metric=METRICS[3],
                           value=EXPECTED_LOSS, rel_tol=0.1)]
    )

    project.add_step(
        step=OverfitOneBatch(model, number_of_steps=40),
        checks=[CheckScoreLessThan(metric=METRICS[3], value=0.05)],
    )

    project.add_step(
        step=Overfit(model, max_epochs=50),
        checks=[CheckScoreLessThan(metric=METRICS[3], value=0.1)],
    )

    project.run_all()
```

</details>

<details>

<summary>Hints for performing full transfer learning</summary>

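* To register the step, call `add_step` with `TransferLearning(model, ...)`, passing `freezed_trainer_kwargs` and `unfreezed_trainer_kwargs` (Lightning `Trainer` arguments such as `max_epochs`, `check_val_every_n_epoch` and `callbacks` for the frozen and unfrozen phases) and `freeze_names=["bert"]`, so that the BERT backbone is frozen in the first phase
* Use `CheckScoreGreaterThan` on the accuracy metric to check whether the final model beats a threshold you specify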
</details>
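
The final `main()` also uses an early-stopping callback with `patience=6`. Below is a simplified sketch of the patience logic in "min" mode; `early_stopping_index` is a hypothetical helper, not Lightning's `EarlyStopping` implementation. Note that with `max_epochs=3` and `check_val_every_n_epoch=2`, each phase validates only about once, far fewer times than `patience=6`, so early stopping is unlikely to stop training with these settings; raise `max_epochs` if you want it to matter.

```python
def early_stopping_index(val_losses, patience, min_delta=0.0):
    """Return the index of the validation check at which training would stop, or None.

    Stops once the monitored loss has failed to improve (by more than `min_delta`)
    for `patience` consecutive checks.
    """
    best = float("inf")
    wait = 0
    for i, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss  # improvement: remember it and reset the counter
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return i
    return None


print(early_stopping_index([1.0, 0.8, 0.9, 0.85, 0.7], patience=2))  # 3
print(early_stopping_index([1.0], patience=6))  # None
```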

<details>

<summary><b>Final main()</b></summary>

```py
import math  # needed for EXPECTED_LOSS; other imports as in the previous steps


def main():
    data = YelpReviews()
    project = ArtProject("yelpreviews", data)

    project.add_step(
        TextDataAnalysis(),
        [
            CheckResultExists("number_of_classes"),
            CheckResultExists("class_names"),
            CheckResultExists("number_of_reviews_in_each_class"),
            CheckResultExists("number_of_unique_words")
        ])
    NUM_CLASSES = 5
    METRICS = [
        Accuracy(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        Precision(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        Recall(num_classes=NUM_CLASSES, average='macro', task='multiclass'),
        nn.CrossEntropyLoss()
    ]
    project.register_metrics(METRICS)

    baseline = HeuristicBaseline
    project.add_step(
        step=EvaluateBaseline(baseline),
        checks=[CheckScoreExists(metric=metric) for metric in METRICS],
    )

    model = YelpReviewsModel

    # Cross-entropy of a uniform prediction over NUM_CLASSES classes: ln(5), about 1.609
    EXPECTED_LOSS = -math.log(1 / NUM_CLASSES)
    project.add_step(
        CheckLossOnInit(model),
        [CheckScoreCloseTo(metric=METRICS[3],
                           value=EXPECTED_LOSS, rel_tol=0.1)]
    )

    project.add_step(
        step=OverfitOneBatch(model, number_of_steps=40),
        checks=[CheckScoreLessThan(metric=METRICS[3], value=0.05)],
    )

    project.add_step(
        step=Overfit(model, max_epochs=50),
        checks=[CheckScoreLessThan(metric=METRICS[3], value=0.1)],
    )

    early_stopping = EarlyStopping('CrossEntropyLoss-validate', patience=6)
    project.add_step(TransferLearning(model,
                                      freezed_trainer_kwargs={"max_epochs": 3,
                                                              "check_val_every_n_epoch": 2,
                                                              "callbacks": [early_stopping]},
                                      unfreezed_trainer_kwargs={"max_epochs": 3,
                                                                "check_val_every_n_epoch": 2,
                                                                "callbacks": [early_stopping]},
                                      freeze_names=["bert"],
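                                      # `logger` is not defined in main(); to log, create one
                                      # (see the Extra tasks below) and pass logger=logger here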
                                      ),
                     # About 45% accuracy is expected, so a 0.80 threshold would always fail
                     [CheckScoreGreaterThan(metric=METRICS[0], value=0.40)])

    project.run_all()
```

</details>

In [None]:
!python run.py

## Conclusions

Congratulations! You've just performed transfer learning on the Yelp Reviews dataset! You can check the results in the `checkpoints` folder. You should see that the model achieves roughly 45% accuracy on the test set. That's not a lot, but it's a good result for this dataset.

**Extra tasks**
* Add a logger: Neptune and Wandb are currently supported. Initialize a logger and pass it to every step via the `logger` argument.
* Experiment with other checks
* Write your own step (and check, if needed) to perform further research

<details>
<summary> Loggers usage </summary>

```py
from art.loggers import NeptuneLoggerAdapter, WandbLoggerAdapter

# Pick ONE of the two loggers
logger = NeptuneLoggerAdapter(project="your_project_name_in_Neptune")
# or, for Weights & Biases:
# logger = WandbLoggerAdapter(project="your_project_name_in_Wandb")
```
</details>