In [1]:
using Pkg
Pkg.activate("../../FastAI.jl/")
Pkg.instantiate()

  Activating project at `~/Desktop/dev/FastAI.jl`


In [2]:
using Flux

In [2]:
using FastAI, FastTimeSeries, Flux

┌ Info: Precompiling FastAI [5d0beca9-ade8-49ae-ad0b-a3cf890e669f]
└ @ Base loading.jl:1423
┌ Info: Precompiling FastTimeSeries [5337c758-7610-4451-a331-8357b11df7c6]
└ @ Base loading.jl:1423


# TimeSeries Classification

In [3]:
data, blocks = load(datarecipes()["ecg5000"]);

`getobs` gets us a sample from the `TimeSeriesDataset`. It returns a tuple with the input time series and the corresponding label.

In [4]:
input, class = sample = getobs(data, 25)

(Float32[-0.28834122 -2.2725453 … 1.722784 1.2959242], "1")

Now we create a learning task for time-series classification, that is, using a time series to predict its label. We will use the `TimeSeriesRow` block as the input and the `Label` block as the target.

In [5]:
task = SupervisedTask(
    blocks,
    (
        OneHot(),
        setup(TSPreprocessing, blocks[1], data[1].table)
    )
)

SupervisedTask(TimeSeriesRow -> Label{SubString{String}})

The encodings passed in transform samples into formats suitable as inputs and outputs for a model.

Let's check that samples from the created data container conform to the blocks of the learning task:

In [6]:
checkblock(task.blocks.sample, sample)

true

To get an overview of the learning task we created, and as a sanity check, we can use `describetask`. It shows which encodings are applied to which blocks, and how the predicted ŷ values are decoded.

In [7]:
describetask(task)

**`SupervisedTask` summary**

Learning task for the supervised task with input `TimeSeriesRow` and target `Label{SubString{String}}`. Compatible with `model`s that take in `TimeSeriesRow` and output `OneHotLabel{SubString{String}}`.

Encoding a sample (`encodesample(task, context, sample)`) is done through the following encodings:

|          Encoding |              Name |      `blocks.input` |                      `blocks.target` |
| -----------------:| -----------------:| -------------------:| ------------------------------------:|
|                   | `(input, target)` |     `TimeSeriesRow` |           `Label{SubString{String}}` |
|          `OneHot` |                   |                     | **`OneHotLabel{SubString{String}}`** |
| `TSPreprocessing` |          `(x, y)` | **`TimeSeriesRow`** |                                      |


In [8]:
encoded_sample = encodesample(task, Training(), sample)

(Float32[-0.28937635 -2.2807038 … 1.7289687 1.3005764], Bool[1, 0, 0, 0, 0])
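To make the encoded output above more concrete, here is a minimal plain-Julia sketch of the two transformations. It assumes that `TSPreprocessing` standardizes a series with a mean and standard deviation computed from the data, and that the class list is `["1", "2", "3", "4", "5"]`. The helper names `onehot_sketch` and `standardize_sketch` are hypothetical and not part of FastAI.jl.

```julia
using Statistics

# Hypothetical helper: mark the position of `label` within `classes`.
onehot_sketch(label, classes) = collect(classes .== label)

# Hypothetical helper: shift and scale a series with precomputed statistics.
standardize_sketch(x, mu, sigma) = (x .- mu) ./ sigma

classes = ["1", "2", "3", "4", "5"]   # assumed class names
y = onehot_sketch("1", classes)        # one-hot vector with the first entry set

x = Float32[1.0 2.0 3.0 4.0]          # toy 1×4 series, not ECG5000 data
mu, sigma = mean(x), std(x)
xnorm = standardize_sketch(x, mu, sigma)
```

After standardization the toy series has mean close to 0 and standard deviation close to 1. That the encoded ECG sample above differs only slightly from the raw one fits a dataset whose values are already close to that scale.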

### Visualization Tools for TimeSeries

In [9]:
sample = getobs(data, 1)

(Float32[-0.11252183 -2.8272038 … 0.92528623 0.19313742], "1")

In [None]:
showsample(task, sample)

In [None]:
showblock(blocks[1], sample[1])

### Training

We will use a `StackedLSTM` as the backbone model, with a `Dense` layer on top of it as the classification head. `taskmodel` knows how to assemble this by looking at the data blocks used.

In [13]:
backbone = FastTimeSeries.Models.StackedLSTM(1, 16, 10, 2);

In [15]:
model = FastAI.taskmodel(task, backbone);

We can use `tasklossfn` to get a loss function suitable for our task.

In [16]:
lossfn = tasklossfn(task)

logitcrossentropy (generic function with 1 method)
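As a rough illustration of what a logit cross-entropy computes for a single sample, the following plain-Julia sketch applies a numerically stable log-softmax to raw model outputs and compares them with a one-hot target. The function names and the example logits are made up for illustration; this is not Flux's implementation, which also handles batches and aggregation.

```julia
# Numerically stable log-softmax over a vector of logits.
function logsoftmax_sketch(logits)
    shifted = logits .- maximum(logits)
    return shifted .- log(sum(exp.(shifted)))
end

# Cross-entropy between raw logits and a one-hot target.
logitcrossentropy_sketch(logits, onehot) = -sum(onehot .* logsoftmax_sketch(logits))

logits = Float32[2.0, 0.5, -1.0, 0.0, 0.1]   # made-up model outputs for 5 classes
target = Bool[1, 0, 0, 0, 0]                  # one-hot target for the first class
loss = logitcrossentropy_sketch(logits, target)
```

Because the softmax is folded into the loss, the model can output unnormalized scores.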

Next we create a pair of training and validation data loaders. They take care of batching and loading the data in parallel in the background.

In [17]:
traindl, validdl = taskdataloaders(data, task, 16);
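Conceptually, batching with a batch size of 16 groups observation indices into chunks of at most 16 that are then loaded and stacked together. The sketch below shows only that grouping, in plain Julia and with a toy number of observations. It leaves aside shuffling, the train/validation split, parallel loading, and whether a trailing partial batch is kept, all of which the real data loaders decide for themselves.

```julia
batchsize = 16
nobs = 50                       # toy number of observations, not the ECG5000 size

# Each element is a range of indices that would form one batch.
batches = collect(Iterators.partition(1:nobs, batchsize))

for (i, idxs) in enumerate(batches)
    println("batch ", i, ": observations ", first(idxs), " to ", last(idxs))
end
```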

We will use the `ADAM` optimizer with a learning rate of 0.002 for this task.

In [18]:
optimizer = ADAM(0.002)

ADAM(0.002, (0.9, 0.999), 1.0e-8, IdDict{Any, Any}())
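The printed fields are the learning rate, the two moment decay rates, and a small constant for numerical stability. A single update of the textbook Adam rule with these values can be sketched in plain Julia as follows. The function name `adam_step_sketch` and the gradient are made up for illustration, and Flux's own implementation may differ in detail.

```julia
# One textbook Adam update for a parameter vector `theta` given gradient `g`
# at step `t` (starting from 1), with running moment estimates `m` and `v`.
function adam_step_sketch(theta, g, m, v, t; eta = 0.002, beta = (0.9, 0.999), epsilon = 1e-8)
    m = beta[1] .* m .+ (1 - beta[1]) .* g          # first-moment estimate
    v = beta[2] .* v .+ (1 - beta[2]) .* g .^ 2     # second-moment estimate
    mhat = m ./ (1 - beta[1]^t)                     # bias-corrected moments
    vhat = v ./ (1 - beta[2]^t)
    theta = theta .- eta .* mhat ./ (sqrt.(vhat) .+ epsilon)
    return theta, m, v
end

theta, m, v = [1.0, -2.0], zeros(2), zeros(2)
g = [0.5, -0.25]                                    # made-up gradient
theta, m, v = adam_step_sketch(theta, g, m, v, 1)
```

On the first step each parameter moves by roughly the learning rate in the direction opposite to its gradient, regardless of the gradient's magnitude.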

We create callbacks to move training to the GPU when one is available and to track accuracy during training.

In [19]:
callbacks = [ToGPU(), Metrics(accuracy)];
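For intuition about the reported metric, accuracy for a batch can be sketched in plain Julia as the fraction of samples whose highest-scoring class matches the one-hot target. The helper names and the numbers below are made up for illustration and are not the library's implementation.

```julia
# Index of the largest entry in each column (one column per sample).
predicted_class(scores) = [argmax(col) for col in eachcol(scores)]

# Fraction of samples whose predicted class matches the one-hot target.
accuracy_sketch(scores, targets) =
    sum(predicted_class(scores) .== predicted_class(targets)) / size(scores, 2)

# Made-up logits: 3 classes (rows) × 3 samples (columns).
scores = Float32[2.0 0.1 0.3
                 0.5 1.5 0.2
                 0.1 0.2 0.1]
targets = Bool[1 0 0
               0 1 0
               0 0 1]
acc = accuracy_sketch(scores, targets)
```

Here the first two samples are classified correctly and the third is not, so two out of three predictions match.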

With an optimizer and a loss function in place, we can now create a `Learner` and start training.

In [20]:
learner = Learner(model, lossfn; data=(traindl, validdl), optimizer=optimizer, callbacks = callbacks);

In [21]:
fitonecycle!(learner, 10, 0.002)

┌ Info: The GPU function is being called but the GPU is not accessible. 
│ Defaulting back to the CPU. (No action is required if you want to run on the CPU).
└ @ Flux /Users/saksham/.julia/packages/Flux/js6mP/src/functor.jl:192
Epoch 1 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:40


┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   1.0 │ 0.95453 │  0.65725 │
└───────────────┴───────┴─────────┴──────────┘


Epoch 1 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   1.0 │ 0.36429 │   0.9082 │
└─────────────────┴───────┴─────────┴──────────┘


Epoch 2 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:03


┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   2.0 │ 0.30034 │   0.9205 │
└───────────────┴───────┴─────────┴──────────┘


Epoch 2 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   2.0 │ 0.28543 │  0.91211 │
└─────────────────┴───────┴─────────┴──────────┘


Epoch 3 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:03


┌───────────────┬───────┬────────┬──────────┐
│         Phase │ Epoch │   Loss │ Accuracy │
├───────────────┼───────┼────────┼──────────┤
│ TrainingPhase │   3.0 │ 0.2677 │  0.92825 │
└───────────────┴───────┴────────┴──────────┘


Epoch 3 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   3.0 │ 0.26776 │  0.91895 │
└─────────────────┴───────┴─────────┴──────────┘


Epoch 4 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:03


┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   4.0 │ 0.23461 │   0.9355 │
└───────────────┴───────┴─────────┴──────────┘


Epoch 4 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   4.0 │ 0.27086 │  0.92285 │
└─────────────────┴───────┴─────────┴──────────┘


Epoch 5 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:03


┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   5.0 │ 0.22571 │   0.9375 │
└───────────────┴───────┴─────────┴──────────┘


Epoch 5 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   5.0 │ 0.24774 │  0.93457 │
└─────────────────┴───────┴─────────┴──────────┘


Epoch 6 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:03


┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   6.0 │ 0.21649 │   0.9385 │
└───────────────┴───────┴─────────┴──────────┘


Epoch 6 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   6.0 │ 0.24026 │  0.93359 │
└─────────────────┴───────┴─────────┴──────────┘


Epoch 7 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:03


┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   7.0 │ 0.21095 │  0.93825 │
└───────────────┴───────┴─────────┴──────────┘


Epoch 7 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   7.0 │ 0.23704 │  0.93262 │
└─────────────────┴───────┴─────────┴──────────┘


Epoch 8 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:03


┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   8.0 │ 0.20555 │  0.93975 │
└───────────────┴───────┴─────────┴──────────┘


Epoch 8 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   8.0 │ 0.24263 │  0.93359 │
└─────────────────┴───────┴─────────┴──────────┘


Epoch 9 TrainingPhase(): 100%|██████████████████████████| Time: 0:00:03


┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │   9.0 │ 0.20291 │  0.94075 │
└───────────────┴───────┴─────────┴──────────┘


Epoch 9 ValidationPhase(): 100%|████████████████████████| Time: 0:00:00


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │   9.0 │ 0.23519 │  0.93457 │
└─────────────────┴───────┴─────────┴──────────┘


Epoch 10 TrainingPhase(): 100%|█████████████████████████| Time: 0:00:03


┌───────────────┬───────┬─────────┬──────────┐
│         Phase │ Epoch │    Loss │ Accuracy │
├───────────────┼───────┼─────────┼──────────┤
│ TrainingPhase │  10.0 │ 0.19846 │    0.942 │
└───────────────┴───────┴─────────┴──────────┘


Epoch 10 ValidationPhase(): 100%|███████████████████████| Time: 0:00:00


┌─────────────────┬───────┬─────────┬──────────┐
│           Phase │ Epoch │    Loss │ Accuracy │
├─────────────────┼───────┼─────────┼──────────┤
│ ValidationPhase │  10.0 │ 0.23493 │  0.93457 │
└─────────────────┴───────┴─────────┴──────────┘
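The `fitonecycle!` call above trains for 10 epochs with a peak learning rate of 0.002. As a rough sketch of what a one-cycle schedule looks like, the plain-Julia code below warms the learning rate up to the peak and then anneals it down again. The cosine interpolation and the values `div = 25`, `divfinal = 1e5` and `pct_start = 0.25` are assumptions chosen for illustration, and the function names are hypothetical. Consult the library documentation for the schedule it actually uses.

```julia
# Cosine interpolation from `a` to `b` as `p` goes from 0 to 1.
cosine_interp(a, b, p) = b + (a - b) * (1 + cos(pi * p)) / 2

# Learning rate at training progress `p` in [0, 1] (assumed schedule shape).
function onecycle_lr_sketch(p; lrmax = 0.002, div = 25, divfinal = 1e5, pct_start = 0.25)
    if p < pct_start
        # Warm-up: rise from lrmax / div to lrmax.
        return cosine_interp(lrmax / div, lrmax, p / pct_start)
    else
        # Decay: fall from lrmax to lrmax / divfinal.
        return cosine_interp(lrmax, lrmax / divfinal, (p - pct_start) / (1 - pct_start))
    end
end

# Sample the schedule at 21 evenly spaced points of training progress.
lrs = [onecycle_lr_sketch(p) for p in 0:0.05:1]
```

Under these assumptions the learning rate peaks a quarter of the way through training and is close to zero by the end.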


We can save the model for later inference using `savetaskmodel`:

In [23]:
savetaskmodel("tsclassification.jld2", task, learner.model; force = true)