Fine-tuning
-----------
More and more state-of-the-art deep neural-network classifiers are first *pretrained* on immense general-purpose datasets, then *fine-tuned* on smaller, application-focused datasets.

Here we apply this premise to a large dataset recorded from many people, treating each person as a fine-tuning opportunity. This closely mirrors the unaligned (domain adaptation/domain generalization, DA/DG) fine-tuning case in the Kostas and Rudzicz 2020 paper (under review).
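To make the person-wise premise concrete, the stdlib-only sketch below partitions *people* (rather than individual trials) into folds, so that nobody held out for testing is seen during general training. The `person_folds` helper, its round-robin grouping, and its choice of the next group as validation are hypothetical illustrations, not DN3's own `lmso` implementation.

```python
import random


def person_folds(people, folds, seed=0):
    """Hypothetical helper: yield (train, validation, test) lists of *people*.

    Each fold holds out one group of people for testing and the next group
    for validation; everyone else is used for general (pre)training.
    Assumes folds >= 2 so that the test and validation groups differ.
    """
    people = list(people)
    random.Random(seed).shuffle(people)
    # Deal people round-robin into `folds` roughly equal, disjoint groups
    groups = [people[i::folds] for i in range(folds)]
    for i in range(folds):
        j = (i + 1) % folds
        test, validation = groups[i], groups[j]
        train = [p for k, g in enumerate(groups) if k not in (i, j) for p in g]
        yield train, validation, test


subjects = [f"S{n:03d}" for n in range(1, 11)]  # placeholder subject IDs
for train, validation, test in person_folds(subjects, folds=5):
    # Each held-out person would then be tuned and evaluated individually
    print(len(train), len(validation), test)
```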

To keep things simple, we reuse nearly the same configuration, and as much of the code as possible, from the `Basics` example. Refer back to it if anything here is unclear.

```yaml
DN3:
  datasets:
    - mmidb

training_configuration:
  use_gpu: False
  folds: 5
  epochs: 10
  batch_size: 16
  fine_tuning:
    test_fraction: 0.5
    epochs: 2
    rate: 1.0e-5  # keep the decimal point: YAML 1.1 parsers such as PyYAML read "1e-5" as a string

mmidb:
  name: "Physionet MMIDB"
  toplevel: /path/to/the/toplevel/folder
  tmin: 0
  tlen: 6
  events:
    - T1
    - T2
  exclude_sessions:
    - "*R0[!6].edf"    # equivalently "*R0[12345789].edf"
    - "*R1[!04].edf"   # equivalently "*R1[123].edf"
  exclude_people:
    - S088
    - S090
    - S092
    - S100
```
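The `exclude_sessions` entries are shell-style wildcards. Assuming they are matched like Python's standard `fnmatch` patterns (an assumption about DN3's matching, not something stated above), this quick check lists which of one subject's 14 MMIDB runs are kept; `S001` is only an example subject name.

```python
from fnmatch import fnmatch

# Patterns copied from the `exclude_sessions` entry of the configuration above
exclude_sessions = ["*R0[!6].edf", "*R1[!04].edf"]

# One subject's runs R01..R14 (S001 is just an example name)
runs = [f"S001R{n:02d}.edf" for n in range(1, 15)]

# A run is kept only if it matches none of the exclusion patterns
kept = [r for r in runs if not any(fnmatch(r, pat) for pat in exclude_sessions)]
print(kept)  # ['S001R06.edf', 'S001R10.edf', 'S001R14.edf']
```

Swapping in the "equivalent" patterns from the comments (`*R0[12345789].edf` and `*R1[123].edf`) keeps the same three runs.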

Below, we start with the same code as in `Basics` to load our dataset and prepare a TIDNet model for classification.

```python
from dn3.configuratron import ExperimentConfig
from dn3.trainable.processes import StandardClassification
from dn3.trainable.models import TIDNet

# We load a lot of files, so silence MNE's informational logging
import mne
mne.set_log_level(False)

config_filename = 'my_config.yml'
experiment = ExperimentConfig(config_filename)
dataset = experiment.datasets['mmidb']

dataset = dataset.auto_construct_dataset()
```


Scanning ../tests/test_dataset. If there are a lot of files, this may take a while...: 100%|██████████| 4/4 [00:00<00:00, 91.52it/s, extension=.gdf]
Loading Physionet MMIDB: 100%|██████████| 105/105 [00:23<00:00,  4.48person/s]


Found 1 datasets.
Creating dataset of 420 Epoched recordings from 105 people.


This time, we also create two functions that implement two different (though not mutually exclusive) ways to adapt a model from one domain to a slightly different one: freezing and fine-tuning.
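As a rough illustration of the difference, here is a NumPy sketch of a single gradient step for a toy two-layer linear model, where `W` stands in for the feature extractor and `v` for the classifier. The `sgd_step` helper and its `freeze_features` flag are hypothetical; in PyTorch, freezing is commonly done by setting `requires_grad = False` on the feature parameters.

```python
import numpy as np


def sgd_step(W, v, x, target, lr, freeze_features):
    """Hypothetical helper: one SGD step on 0.5 * (y - target)**2 for y = v @ (W @ x).

    W plays the role of the feature extractor and v the classifier head.
    Returns updated (W, v); W is returned unchanged when freeze_features is True.
    """
    h = W @ x                      # features
    y = v @ h                      # scalar prediction
    err = y - target               # derivative of the loss w.r.t. y
    grad_v = err * h               # gradient for the classifier head
    grad_W = err * np.outer(v, x)  # gradient for the feature weights
    new_v = v - lr * grad_v
    new_W = W if freeze_features else W - lr * grad_W
    return new_W, new_v


rng = np.random.default_rng(0)
W0 = rng.normal(size=(4, 3))    # "pretrained" feature weights
v0 = rng.normal(size=4)         # "pretrained" classifier head
x, t = rng.normal(size=3), 1.0  # one example from the new domain

W_frozen, v_frozen = sgd_step(W0, v0, x, t, lr=0.01, freeze_features=True)
W_tuned, v_tuned = sgd_step(W0, v0, x, t, lr=0.01, freeze_features=False)

print(np.array_equal(W_frozen, W0))  # True: features untouched
print(np.array_equal(W_tuned, W0))   # False: all weights adapted
```

On this first step the classifier update is identical in both cases; the difference is only whether the features move too.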

```python
def frozen_tuning(training_data, testing_data, model):
    # Freeze the feature-extraction layers so that only the remaining (classifier) weights are updated
    model.freeze_features()
    tune_process = StandardClassification(model, cuda=experiment.training_configuration.use_gpu,
                                          learning_rate=experiment.training_configuration.fine_tuning.rate)
    tune_process.fit(training_data, epochs=experiment.training_configuration.fine_tuning.epochs,
                     batch_size=experiment.training_configuration.batch_size)
    # We unfreeze so that the model can be subsequently trained again
    model.freeze_features(unfreeze=True)
    return tune_process.evaluate(testing_data)


def fine_tuning(training_data, testing_data, model):
    # Update all of the model's weights, starting from their current values
    tune_process = StandardClassification(model, cuda=experiment.training_configuration.use_gpu,
                                          learning_rate=experiment.training_configuration.fine_tuning.rate)
    tune_process.fit(training_data, epochs=experiment.training_configuration.fine_tuning.epochs,
                     batch_size=experiment.training_configuration.batch_size)
    return tune_process.evaluate(testing_data)
```

Next, we write a helper that compares the tuned performance in three possible scenarios:

1. Freeze the features and tune only the classifier
2. The same as above, but then fine-tune *all weights*, including the final layer
3. Fine-tune all of the generally trained weights directly, from the start

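```python
def tuning_performance_comparison(training_data, testing_data, model):
    # Tune copies so that the generally trained `model` is never modified and
    # every test person starts tuning from the same weights
    just_tune_model = model.clone()
    just_tune_performance = fine_tuning(training_data, testing_data, just_tune_model)

    freeze_model = model.clone()
    freeze_performance = frozen_tuning(training_data, testing_data, freeze_model)
    # Continue from the frozen-tuned weights, now adjusting all weights
    freeze_then_tune = fine_tuning(training_data, testing_data, freeze_model)

    return dict(tuned=just_tune_performance, frozen=freeze_performance, freeze_then_tune=freeze_then_tune)
```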
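The toy sketch below illustrates why each scenario should begin from its own copy of the generally trained model: tuning a shared model in place would carry one person's adaptation into the next person's comparison. `ToyModel` is a hypothetical stand-in for a DN3 model, its `fit` just nudges weights instead of training, and `copy.deepcopy` plays the role of `clone()`.

```python
import copy


class ToyModel:
    """Hypothetical stand-in for a model with 'features' and 'classifier' weights."""

    def __init__(self):
        self.weights = {"features": [0.0, 0.0], "classifier": [0.0]}
        self.frozen = False

    def freeze_features(self, unfreeze=False):
        self.frozen = not unfreeze

    def fit(self, step):
        # Pretend training: shift every trainable weight by `step`
        for name, ws in list(self.weights.items()):
            if name == "features" and self.frozen:
                continue
            self.weights[name] = [w + step for w in ws]


def compare_scenarios(general_model, step=1.0):
    tuned = copy.deepcopy(general_model)   # scenario 3: tune everything
    tuned.fit(step)

    frozen = copy.deepcopy(general_model)  # scenario 1: frozen features
    frozen.freeze_features()
    frozen.fit(step)
    frozen.freeze_features(unfreeze=True)
    frozen_weights = copy.deepcopy(frozen.weights)

    frozen.fit(step)                       # scenario 2: then tune everything
    return dict(tuned=tuned.weights, frozen=frozen_weights,
                freeze_then_tune=frozen.weights)


general = ToyModel()
out = compare_scenarios(general)
print(out)
print(general.weights)  # the general model is left as it was
```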

The rest runs much like the `Basics` process, with per-person performance reporting, except that this time we compare the different tuning techniques rather than simply evaluating the model on each test person.
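With `test_fraction: 0.5`, each test person's data is split roughly in half: one part for tuning and the other for evaluating the tuned models. A minimal sketch of such a split, assuming the trials are shuffled first; the hypothetical `split_person` helper is not DN3's `split`, which may order or round differently.

```python
import random


def split_person(trials, test_fraction=0.5, seed=0):
    """Hypothetical helper: shuffle one person's trials, then return
    (tuning_trials, test_trials) with `test_fraction` of them held out."""
    trials = list(trials)
    random.Random(seed).shuffle(trials)
    n_test = round(len(trials) * test_fraction)
    return trials[n_test:], trials[:n_test]


# Ten placeholder trial indices for one person
tune, test = split_person(range(10), test_fraction=0.5)
print(len(tune), len(test))  # 5 5
```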

```python
results = list()
for training, validation, test in dataset.lmso(experiment.training_configuration.folds):
    tidnet = TIDNet.from_dataset(dataset, targets=2)
    process = StandardClassification(tidnet, cuda=experiment.training_configuration.use_gpu)

    # General training
    process.fit(training_dataset=training, validation_dataset=validation,
                epochs=experiment.training_configuration.epochs,
                batch_size=experiment.training_configuration.batch_size)

    # Tuning
    for _, _, test_thinker in test.loso():
        # First split the test_thinker further for training and testing (the middle return value would be validation)
        tune_train, _, tune_test = test_thinker.split(
            test_frac=experiment.training_configuration.fine_tuning.test_fraction, validation_frac=0)

        results.append(tuning_performance_comparison(tune_train, tune_test, tidnet))
```


This time, let's use a pandas `DataFrame` to compare the performances more conveniently.
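If `evaluate` returns a dictionary of metrics (an assumption; check what your DN3 version returns), each entry of `results` is a nested dictionary, and `pandas.json_normalize` can flatten it into one numeric column per scenario and metric. The values below are hypothetical placeholders, not experimental results.

```python
import pandas as pd

# Hypothetical placeholder metrics (NOT experimental results): one dict per
# test person, mapping each tuning scenario to a dict of metrics
results = [
    {"tuned": {"Accuracy": 0.60}, "frozen": {"Accuracy": 0.55}, "freeze_then_tune": {"Accuracy": 0.65}},
    {"tuned": {"Accuracy": 0.70}, "frozen": {"Accuracy": 0.65}, "freeze_then_tune": {"Accuracy": 0.75}},
]

# Flatten the nested dicts into columns such as "tuned.Accuracy"
table = pd.json_normalize(results)
print(table.describe())
```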

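```python
from pandas import json_normalize

# Each entry of `results` maps a scenario to whatever `evaluate` returned;
# json_normalize flattens nested metric dicts into columns (e.g. "tuned.Accuracy")
results = json_normalize(results)

print(results)
print(results.describe())
```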