Fine-tuning
-----------
A common paradigm in deep learning is to pretrain a model on relevant data and then *fine-tune* it on smaller
application-focused examples.

Here we apply this idea to a large dataset recorded from many people, treating each person as a
fine-tuning opportunity. This is very similar to the unaligned/DA/DG fine-tuning case from
Kostas and Rudzicz 2020 (https://doi.org/10.1088/1741-2552/abb7a7).

To keep things simple, we reuse the configuration and, as far as possible, the code from the `Basics`
example. Return to that example if anything here is unclear.

```yaml
Configuratron:
  preload: True

use_gpu: False
test_fraction: 0.5

mmidb:
  name: "Physionet MMIDB"
  toplevel: /path/to/eegmmidb
  tmin: 0
  tlen: 6
  data_max: 0.001
  data_min: -0.001
  events:
    - T1
    - T2
  exclude_sessions:
    - "*R0[!48].edf"  # equivalently "*R0[1235679].edf"
    - "*R1[!2].edf"   # equivalently "*R1[0134].edf"
  exclude_people:
    - S088
    - S090
    - S092
    - S100
  train_params:
    epochs: 7
    batch_size: 4
  lr: 0.0001
  fine_lr: 0.00001
  folds: 5
```
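
The two `exclude_sessions` patterns can be checked by hand with a short sketch. It assumes the patterns follow Unix shell-style matching as implemented by Python's `fnmatch`, and that recordings are named `S<person>R<run>.edf` with runs `R01` through `R14`; the file names generated here are hypothetical stand-ins, not read from disk.

```python
from fnmatch import fnmatch

# Patterns copied from the `exclude_sessions` entry of the configuration.
EXCLUDE_SESSIONS = ["*R0[!48].edf", "*R1[!2].edf"]

# Hypothetical file names for a single person (assumed naming scheme).
filenames = ["S001R{:02d}.edf".format(run) for run in range(1, 15)]


def is_excluded(filename, patterns=EXCLUDE_SESSIONS):
    """Return True if the file name matches any exclusion pattern."""
    return any(fnmatch(filename, pattern) for pattern in patterns)


# Keep only the recordings that no pattern excludes.
kept = [name for name in filenames if not is_excluded(name)]
print(kept)
```

Under these assumptions only runs 4, 8 and 12 survive for each person, which agrees with the 315 recordings from 105 people reported when the dataset is loaded.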

Below, we start with the same code as before to load our dataset and import what we need to build a TIDNet model for classification.

In [8]:
from dn3.configuratron import ExperimentConfig
from dn3.trainable.processes import StandardClassification
from dn3.trainable.models import TIDNet

from torch.optim.lr_scheduler import CosineAnnealingLR

# We load a lot of files, so suppress MNE's verbose logging
import mne
mne.set_log_level(False)

config_filename = 'my_config.yml'
experiment = ExperimentConfig(config_filename)
ds_config = experiment.datasets['mmidb']

dataset = ds_config.auto_construct_dataset()


Adding additional configuration entries: dict_keys(['train_params', 'folds', 'lr'])
Configuratron found 1 datasets.


Scanning /Volumes/Data/MMI. If there are a lot of files, this may take a while...: 100%|██████████| 4/4 [00:03<00:00,  1.05it/s, extension=.gdf]


Creating dataset of 315 Preloaded Epoched recordings from 105 people.


Loading Physionet MMIDB: 100%|██████████| 105/105 [00:10<00:00,  9.88person/s]

>> Physionet MMIDB | DSID: None | 105 people | 4408 trials | 90 channels | 1536 samples/trial | 256Hz | 0 transforms
Constructed 1 channel maps
Used by 315 recordings:
EEG (original(new)): Fc5.(FC5) Fc3.(FC3) Fc1.(FC1) Fcz.(FCZ) Fc2.(FC2) Fc4.(FC4) Fc6.(FC6) C5..(C5) C3..(C3) C1..(C1) Cz..(CZ) C2..(C2) C4..(C4) C6..(C6) Cp5.(CP5) Cp3.(CP3) Cp1.(CP1) Cpz.(CPZ) Cp2.(CP2) Cp4.(CP4) Cp6.(CP6) Fp1.(FP1) Fpz.(FPZ) Fp2.(FP2) Af7.(AF7) Af3.(AF3) Afz.(AFZ) Af4.(AF4) Af8.(AF8) F7..(F7) F5..(F5) F3..(F3) F1..(F1) Fz..(FZ) F2..(F2) F4..(F4) F6..(F6) F8..(F8) Ft7.(FT7) Ft8.(FT8) T7..(T7) T8..(T8) T9..(T9) T10.(T10) Tp7.(TP7) Tp8.(TP8) P7..(P7) P5..(P5) P3..(P3) P1..(P1) Pz..(PZ) P2..(P2) P4..(P4) P6..(P6) P8..(P8) Po7.(PO7) Po3.(PO3) Poz.(POZ) Po4.(PO4) Po8.(PO8) O1..(O1) Oz..(OZ) O2..(O2) Iz..(IZ) 
EOG (original(new)): 
REF (original(new)): 
EXTRA (original(new)): 
Heuristically Assigned: Fc5.(FC5)  Fc3.(FC3)  Fc1.(FC1)  Fcz.(FCZ)  Fc2.(FC2)  Fc4.(FC4)  Fc6.(FC6)  C5..(C5)  C3..(C3)  C1..(C1)  Cz.




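This time, we will also create two functions that exhibit the two different (though not necessarily mutually exclusive)
ways one might adapt from one domain to a slightly different one: freezing and fine-tuning. Freezing keeps the learned
feature layers fixed and trains only the classifier, while fine-tuning updates all of the weights.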

In [10]:
def frozen_tuning(training_data, testing_data, model):
    model.freeze_features()
    # The configuration defines `lr` and `fine_lr`; there is no `rate` entry
    tune_process = StandardClassification(model, learning_rate=ds_config.lr, cuda=experiment.use_gpu)
    tune_process.fit(training_data, **ds_config.train_params)
    # We unfreeze so that the model can be subsequently trained again
    model.freeze_features(unfreeze=True)
    return tune_process.evaluate(testing_data)['Accuracy']

def fine_tuning(training_data, testing_data, model):
    tune_process = StandardClassification(model, learning_rate=ds_config.fine_lr, cuda=experiment.use_gpu)
    tune_process.fit(training_data, **ds_config.train_params)
    return tune_process.evaluate(testing_data)['Accuracy']
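
To make the difference between the two strategies concrete, here is a toy sketch in plain Python. It is not DN3's or PyTorch's implementation: `Param`, `set_frozen` and `sgd_step` are hypothetical names, the gradients are made-up constants, and each "layer" is a single number. In PyTorch the same effect is typically obtained by toggling `requires_grad` on the parameters.

```python
from dataclasses import dataclass


@dataclass
class Param:
    """A single scalar weight with a flag saying whether it may be updated."""
    value: float
    trainable: bool = True


def set_frozen(params, prefix, frozen=True):
    """Freeze (or unfreeze) every parameter whose name starts with `prefix`."""
    for name, param in params.items():
        if name.startswith(prefix):
            param.trainable = not frozen


def sgd_step(params, grads, lr):
    """Apply one gradient-descent step, skipping frozen parameters."""
    for name, param in params.items():
        if param.trainable:
            param.value -= lr * grads[name]


params = {"features.w": Param(1.0), "classifier.w": Param(1.0)}
grads = {"features.w": 0.5, "classifier.w": 0.5}  # made-up gradients

# Freezing: only the classifier moves, here with the general learning rate.
set_frozen(params, "features")
sgd_step(params, grads, lr=1e-4)
after_frozen = {name: p.value for name, p in params.items()}

# Fine-tuning: unfreeze and move every weight with a smaller learning rate.
set_frozen(params, "features", frozen=False)
sgd_step(params, grads, lr=1e-5)
after_tuned = {name: p.value for name, p in params.items()}

print(after_frozen)
print(after_tuned)
```

After the frozen step the feature weight is unchanged; only after unfreezing does it start to move, and then by a much smaller amount per step because of the lower learning rate.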


Now we'll write a helper that compares the tuned performance for three possible scenarios:

    1. Freeze the features and train a new classifier
    2. The same as above, but then fine-tune *all weights*, including the new final layer
    3. Fine-tune all of the general model's weights from the start

In [11]:
PERFORMANCE_COLUMNS = ['Tuned', 'Frozen', 'Frozen then Tuned']
def tuning_performance_comparison(training_data, testing_data, model):
    just_tune_model = model.clone()
    just_tune_performance = fine_tuning(training_data, testing_data, just_tune_model)

    freeze_performance = frozen_tuning(training_data, testing_data, model)
    freeze_then_tune = fine_tuning(training_data, testing_data, model)

    return dict(zip(PERFORMANCE_COLUMNS, (just_tune_performance, freeze_performance, freeze_then_tune)))

Now everything runs much as in our basic process, with person-specific performance reporting, except that this
time we compare the different tuning techniques instead of just evaluating the model on our test person.
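
Before the loop itself, it may help to see what the two levels of splitting look like. The sketch below is a simplified stand-in, not DN3's `lmso` or `split`: `lmso_folds` and `split_trials` are hypothetical helpers, the person IDs are placeholders, and real implementations may shuffle or choose the validation people differently.

```python
def lmso_folds(people, folds):
    """Yield (training, validation, test) person lists, one triple per fold.

    People are dealt into `folds` groups; on fold i, group i is held out for
    testing, the next group for validation, and the rest is used for training.
    """
    chunks = [people[i::folds] for i in range(folds)]
    for i in range(folds):
        val_idx = (i + 1) % folds
        training = [person for j, chunk in enumerate(chunks)
                    if j not in (i, val_idx) for person in chunk]
        yield training, chunks[val_idx], chunks[i]


def split_trials(trials, test_frac):
    """Split one person's trials into a tuning part and a held-out test part."""
    n_test = round(len(trials) * test_frac)
    return trials[:len(trials) - n_test], trials[len(trials) - n_test:]


# 105 placeholder person IDs and 5 folds, as in the configuration.
people = ["person_{:03d}".format(i) for i in range(105)]
sizes = [(len(tr), len(va), len(te)) for tr, va, te in lmso_folds(people, folds=5)]
print(sizes)

# One test person with 42 trials, split with test_fraction = 0.5.
tune_train, tune_test = split_trials(list(range(42)), test_frac=0.5)
print(len(tune_train), len(tune_test))
```

With 105 people and 5 folds, every fold trains on 63 people and holds out 21 each for validation and testing; each held-out person's 42 trials are then halved into a tuning set and a test set.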

In [12]:
import tqdm
results = []
for training, validation, test in tqdm.tqdm(dataset.lmso(ds_config.folds), total=ds_config.folds,
                                            desc="LMSO", unit='fold'):
    tidnet = TIDNet.from_dataset(dataset)
    general_process = StandardClassification(tidnet, cuda=experiment.use_gpu, learning_rate=ds_config.lr)

    # General training
    tqdm.tqdm.write("General training...")
    general_process.fit(training_dataset=training, validation_dataset=validation, **ds_config.train_params)

    # Tuning
    tqdm.tqdm.write("Fine tuning...")
    for _, _, test_thinker in test.loso():
        # Now split the test_thinker further for training and testing (the middle return value would be validation)
        tune_train, _, tune_test = test_thinker.split(test_frac=experiment.test_fraction, validation_frac=0)

        performance = tuning_performance_comparison(tune_train, tune_test, tidnet.clone())
        best_perf = max(performance.values())
        tqdm.tqdm.write("Evaluated person {}, Best performance: {:.2%}".format(test_thinker.person_id, best_perf))

        summary = {'Person':test_thinker.person_id,
                   "Before Tuning": general_process.evaluate(test_thinker)['Accuracy'],
                   'Best Result': best_perf}
        summary.update(performance)
        results.append(summary)

Training:   >> Physionet MMIDB | DSID: None | 63 people | 2646 trials | 90 channels | 1536 samples/trial | 256Hz | 0 transforms
Validation: >> Physionet MMIDB | DSID: None | 21 people | 880 trials | 90 channels | 1536 samples/trial | 256Hz | 0 transforms
Test:       >> Physionet MMIDB | DSID: None | 21 people | 882 trials | 90 channels | 1536 samples/trial | 256Hz | 0 transforms

Creating TIDNet using: 90 channels x 1536 samples at 256Hz | 2 targets
Apple M-series GPU detected: training and model execution will be performed on MPS.
General training...
Loading data with 0 additional workers

Training: End of Epoch 1 | Accuracy: 63.09% | loss: 0.696 | lr: 3.539e-05 | momentum: 0.917 | epoch: 1.000 |
Validation: End of Epoch 1 | Accuracy: 78.75% | loss: 0.444 |
Best loss. Retaining checkpoint...

Training: End of Epoch 2 | Accuracy: 82.07% | loss: 0.394 | lr: 9.699e-05 | momentum: 0.853 | epoch: 2.000 |
Validation: End of Epoch 2 | Accuracy: 81.25% | loss: 0.379 |
Best loss. Retaining checkpoint...

Training: End of Epoch 3 | Accuracy: 85.14% | loss: 0.342 | lr: 9.021e-05 | momentum: 0.860 | epoch: 3.000 |
Validation: End of Epoch 3 | Accuracy: 83.52% | loss: 0.353 |
Best loss. Retaining checkpoint...

Training: End of Epoch 4 | Accuracy: 86.91% | loss: 0.315 | lr: 6.883e-05 | momentum: 0.881 | epoch: 4.000 |
Validation: End of Epoch 4 | Accuracy: 83.98% | loss: 0.348 |
Best loss. Retaining checkpoint...

Training: End of Epoch 5 | Accuracy: 87.93% | loss: 0.293 | lr: 4.167e-05 | momentum: 0.908 | epoch: 5.000 |
Validation: End of Epoch 5 | Accuracy: 85.00% | loss: 0.334 |
Best loss. Retaining checkpoint...

Training: End of Epoch 6 | Accuracy: 89.71% | loss: 0.253 | lr: 1.707e-05 | momentum: 0.933 | epoch: 6.000 |
Validation: End of Epoch 6 | Accuracy: 85.23% | loss: 0.334 |
Best loss. Retaining checkpoint...

Training: End of Epoch 7 | Accuracy: 90.13% | loss: 0.240 | lr: 2.565e-06 | momentum: 0.947 | epoch: 7.000 |
Validation: End of Epoch 7 | Accuracy: 85.23% | loss: 0.333 |
Best loss. Retaining checkpoint...
Loading best model...
Fine tuning...
Training:   >> Physionet MMIDB | DSID: None | 19 people | 798 trials | 90 channels | 1536 samples/trial | 256Hz | 0 transforms
Validation: Person S021 - 42 trials | 0 transforms
Test:       Person S001 - 42 trials | 0 transforms

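AttributeError: 'DatasetConfig' object has no attribute 'fine_lr'

The run shown here stopped at the first tuning step because the configuration file it loaded did not define
`fine_lr`: the loading log lists only `train_params`, `folds` and `lr` as additional entries. With `fine_lr` set
under `mmidb`, as in the configuration at the top of this example, `fine_tuning` can read it.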
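
An error like the one above only surfaces after the full general-training phase, so it can be worth checking for the extra configuration entries before any training starts. The following is a minimal sketch: `missing_entries` is a hypothetical helper, and a `SimpleNamespace` stands in for the loaded dataset configuration, on the assumption that the extra entries are exposed as plain attributes (as the error message suggests).

```python
from types import SimpleNamespace

# Extra entries that the code in this example reads from the dataset config.
REQUIRED_ENTRIES = ("lr", "fine_lr", "folds", "train_params")


def missing_entries(config, required=REQUIRED_ENTRIES):
    """Return the names in `required` that `config` does not provide."""
    return [name for name in required if not hasattr(config, name)]


# Stand-in for a configuration loaded from a file that lacks `fine_lr`.
stand_in_config = SimpleNamespace(lr=1e-4, folds=5,
                                  train_params={"epochs": 7, "batch_size": 4})

missing = missing_entries(stand_in_config)
if missing:
    print("Configuration is missing:", ", ".join(missing))
```

Calling the same helper on the real `ds_config` right after loading it would flag a missing `fine_lr` within seconds rather than after several epochs of training.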

Let's use a `DataFrame` this time to compare the performances a little more effectively.

In [None]:
from pandas import DataFrame
results = DataFrame(results)
for tune_option in PERFORMANCE_COLUMNS:
    results[tune_option + ' Improvement'] = results[tune_option] - results['Before Tuning']
print(results.describe())