# Using pretext tasks for active learning

In this notebook, I use a simple pretext task to train an Inception network.
The pre-trained network can later serve as a starting point, with a different classification head, for training an ECG classifier.
Furthermore, the pretext task losses of the model can be used as a selection criterion at the beginning of a new active learning cycle.

## Setup
Utilize more of the GPU memory so that I can use bigger batches.

In [None]:
from deepal_for_ecg.util import improve_gpu_capacity

improve_gpu_capacity()

## Pretext task
To train the network in a self-supervised fashion, the labels have to come from the data itself.
A pretext task for learning ECG representations, suggested by Sarkar and Etemad (2022), is transformation recognition, in which the network must recognize which of six signal transformations has been applied to an ECG signal.
The six signal transformations are noise addition, scaling, negation, temporal inversion, permutation, and time-warping.
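
To make the pretext task concrete, here is a minimal NumPy sketch of the six transformations and of how labels can be derived from them. It is not the implementation in `deepal_for_ecg.data.augmentation`: all `*_sketch` functions, their parameters (`sigma`, `factor`, `num_segments`, `gamma`), and `build_pretext_dataset` are hypothetical. It also assumes that the untransformed signal forms its own class, which would give seven classes in total.

```python
import numpy as np

rng = np.random.default_rng(seed=0)

# All functions expect signals of shape (num_signals, time_steps, channels).

def noise_addition_sketch(x, sigma=0.05):
    """Add zero-mean Gaussian noise (sigma is an assumed value)."""
    return x + rng.normal(0.0, sigma, size=x.shape).astype(x.dtype)

def scaling_sketch(x, factor=1.5):
    """Multiply the amplitude by a constant factor (assumed value)."""
    return x * factor

def negation_sketch(x):
    """Flip the sign of the signal."""
    return -x

def temporal_inversion_sketch(x):
    """Reverse the signal along the time axis."""
    return x[:, ::-1, :]

def permutation_sketch(x, num_segments=4):
    """Cut the time axis into segments and shuffle their order."""
    segments = np.array_split(x, num_segments, axis=1)
    order = rng.permutation(num_segments)
    return np.concatenate([segments[i] for i in order], axis=1)

def time_warping_sketch(x, gamma=1.3):
    """Resample along a smooth, monotone distortion of the time axis."""
    t = np.linspace(0.0, 1.0, x.shape[1])
    warped_t = t ** gamma  # keeps the endpoints 0 and 1 fixed
    out = np.empty_like(x)
    for b in range(x.shape[0]):
        for c in range(x.shape[2]):
            out[b, :, c] = np.interp(warped_t, t, x[b, :, c])
    return out

TRANSFORMS = [
    noise_addition_sketch, scaling_sketch, negation_sketch,
    temporal_inversion_sketch, permutation_sketch, time_warping_sketch,
]

def build_pretext_dataset(signals):
    """Stack the original and all transformed copies.

    Label 0 marks the original signal (an assumption), labels 1 to 6
    identify the transformation that was applied.
    """
    versions = [signals] + [transform(signals) for transform in TRANSFORMS]
    data = np.concatenate(versions, axis=0)
    labels = np.repeat(np.arange(len(versions)), signals.shape[0])
    return data, labels

# Toy data: 2 random "ECGs" with 100 time steps and 12 channels.
toy_signals = rng.normal(size=(2, 100, 12)).astype(np.float32)
toy_x, toy_y = build_pretext_dataset(toy_signals)
```

The labels cost nothing to produce because they only record which function generated each copy, which is what makes the task self-supervised.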

## Data

### Load the data

In [None]:
import numpy as np

from deepal_for_ecg.data.load import PTBXLDataLoader

loader = PTBXLDataLoader(saved_data_base_dir="../data/saved/", load_saved_data=True)
loader.load_data()

signal_data = loader.X_train.astype(np.float32)

In [None]:
# since we do not need the other data we reduce the memory footprint by deleting the loader
del loader

### Generate the transformed data

Use the `TransformationRecognitionDataModule` to generate the transformed data and split it into training, validation, and test datasets.


In [None]:
from deepal_for_ecg.data.tranformation_recognition import TransformationRecognitionDataModule

data_module = TransformationRecognitionDataModule()
# uncomment the following lines on the first run to generate the data
# data_module.generate_and_split_data(signal_data)
# data_module.prepare_datasets()


### Inspect the data transformation

To better understand the transformations, apply each of them to a random signal and visualize the result.

In [None]:
import numpy as np
from deepal_for_ecg.data.augmentation import noise_addition, scaling, negation, temporal_inversion, permutation, time_warping

sample_idx = np.random.randint(signal_data.shape[0])
channel_idx = np.random.randint(signal_data.shape[-1])

selected_signal = np.expand_dims(signal_data[sample_idx], axis=0)

noisy_data = noise_addition(selected_signal)
scaled_data = scaling(selected_signal)
negation_data = negation(selected_signal)
temporal_inversed_data = temporal_inversion(selected_signal)
permuted_data = permutation(selected_signal)
time_warped_data = time_warping(selected_signal)

### Visual inspection

Visualize the transformations of a random channel of a random signal.

In [None]:
from matplotlib import pyplot as plt

fig = plt.figure(figsize=(20, 12))
fig.suptitle(f"Original Signal and Transformed Signals (signal {sample_idx}, channel {channel_idx})")

signals = [
    noisy_data[0, :, channel_idx], 
    scaled_data[0, :, channel_idx], 
    negation_data[0, :, channel_idx], 
    temporal_inversed_data[0, :, channel_idx], 
    permuted_data[0, :, channel_idx], 
    time_warped_data[0, :, channel_idx]
]
transformation_labels = ["noisy", "scaled", "negated", "temporal_inversed", "permuted", "time_warped"]

for i, (signal, label) in enumerate(zip(signals, transformation_labels)):
    plt.subplot(3, 2, i+1)
    plt.plot(signal_data[sample_idx, :, channel_idx], label='original', color='blue')
    plt.plot(signal, label=label, color='orange', alpha=0.5)
    # plt.title(f'Transformed Signal {i+1}')
    if i % 2 == 0:
        plt.ylabel('Amplitude')
    if i >= 4:
        plt.xlabel('Time Steps')
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()

## Model

Build the Inception network with the classification head for the pretext task, configured with seven output classes.

In [None]:
from deepal_for_ecg.models.classification_heads import simple_classification_head
from deepal_for_ecg.models.inception_network import InceptionNetworkBuilder, InceptionNetworkConfig

config = InceptionNetworkConfig(create_classification_head=simple_classification_head, num_classes=7)
builder = InceptionNetworkBuilder()
model = builder.build_model(config)

## Training

Now I train the Inception network with the pretext task to create a good representation model.

### Improve GPU capacity


### Training loop

The training loop, including saving the best model according to the validation set.


In [None]:
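import tensorflow as tf

batch_size = 128
epochs = 100

# Adam optimizer with a learning rate of 0.0001
# (the string 'adam' would use the Keras default of 0.001)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss='CategoricalCrossentropy',
    metrics=['accuracy'],
)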

In [None]:
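from deepal_for_ecg.data.augmentation import random_crop as rc

# randomly crop each signal, then batch the datasets
train_dataset = data_module.train_dataset.map(lambda x, y: (rc(x), y)).batch(batch_size)
validation_dataset = data_module.validation_dataset.map(lambda x, y: (rc(x), y)).batch(batch_size)

# the datasets are already batched, so batch_size is not passed to fit() again
model.fit(train_dataset, epochs=epochs, validation_data=validation_dataset)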

### Test the model

In [None]:
# re-create the batched training dataset to inspect a single batch
batched_dataset = data_module.train_dataset.map(lambda x, y: (rc(x), y)).batch(batch_size)

In [None]:
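# print the shapes of the signals and labels in the first batch
for signal_batch, label_batch in batched_dataset.take(1):
    print(signal_batch.shape)
    print(label_batch.shape)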

## Pretext task loss selection vs. random selection
In this section I want to validate that the first selection of samples can be improved by using the pretext task loss.

### Loading the best pretext task model

### Calculate final losses

Calculate the average loss over all transformations for the best pretext task model.
TODO: Check whether the average was really used.
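
As a rough illustration of this step, the following NumPy sketch computes one averaged categorical cross-entropy value per signal. It assumes that the one-hot labels and the softmax predictions for all versions of a signal are grouped along the second axis. The function names and the array layout are hypothetical and not taken from the notebook.

```python
import numpy as np

def categorical_cross_entropy(y_true, y_pred, eps=1e-7):
    """Cross-entropy per example for one-hot labels and predicted probabilities."""
    y_pred = np.clip(y_pred, eps, 1.0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)

def average_pretext_loss(y_true, y_pred):
    """Average the loss over all versions of each signal.

    Both inputs have the assumed shape (num_signals, num_versions, num_classes);
    the result has shape (num_signals,).
    """
    return categorical_cross_entropy(y_true, y_pred).mean(axis=1)

# Toy example: 2 signals, 2 versions each, 2 classes.
toy_true = np.array([[[1, 0], [0, 1]],
                     [[1, 0], [0, 1]]], dtype=float)
toy_pred = np.array([[[0.9, 0.1], [0.2, 0.8]],   # confident and correct
                     [[0.5, 0.5], [0.5, 0.5]]])  # undecided
toy_losses = average_pretext_loss(toy_true, toy_pred)
```

In the toy example the second signal receives the higher average loss, because the model is undecided on both of its versions.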

### Split unlabeled pools
Similar to Yi et al. (2022), split the unlabeled pool into multiple smaller pools, one to select data from in each iteration.
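
One possible way to do the split is sketched below. It assumes that the samples are first ordered by descending pretext task loss and then cut into pools of nearly equal size. The ordering is an assumption that the notebook does not spell out, and `split_into_pools` is a hypothetical helper.

```python
import numpy as np

def split_into_pools(losses, num_pools):
    """Return a list of index arrays, one per pool.

    Sample indices are sorted by descending loss (assumed ordering) and
    then split into num_pools chunks of nearly equal size.
    """
    order = np.argsort(-np.asarray(losses), kind="stable")
    return np.array_split(order, num_pools)

# Toy example: 6 samples split into 3 pools.
toy_losses = np.array([0.2, 1.5, 0.7, 0.1, 0.9, 1.1])
toy_pools = split_into_pools(toy_losses, num_pools=3)
```

Each pool holds indices into the original unlabeled data, so the signals themselves never have to be copied.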

### Select initial samples
Select the initial samples uniformly from the first unlabeled pool.
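
Uniform selection can be read in two ways: evenly spaced positions within the ordered pool, or uniform random sampling without replacement. The sketch below shows both. The helper names are hypothetical, and which variant is intended here is an assumption left open.

```python
import numpy as np

def select_evenly_spaced(pool_indices, num_samples):
    """Pick samples at evenly spaced positions of the ordered pool."""
    positions = np.linspace(0, len(pool_indices) - 1, num=num_samples)
    return np.asarray(pool_indices)[positions.round().astype(int)]

def select_uniform_random(pool_indices, num_samples, seed=0):
    """Pick samples uniformly at random without replacement."""
    rng = np.random.default_rng(seed)
    return rng.choice(pool_indices, size=num_samples, replace=False)

# Toy example: a pool of 10 sample indices, 4 samples to select.
toy_pool = np.arange(100, 110)
toy_even = select_evenly_spaced(toy_pool, num_samples=4)
toy_random = select_uniform_random(toy_pool, num_samples=4)
```

If the pool is sorted by loss, the evenly spaced variant covers the whole loss range of the pool, while the random variant only does so on average.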

### Train the ECG classification network from scratch with these samples

### Fine-tune the ECG classifier from the pre-trained model with these samples

### Train the ECG classifier from scratch with randomly chosen samples

### Fine-tune the ECG classifier from the pre-trained model with randomly chosen samples

### Compare the results