SSLTSC

Codebase for our paper "Deep Semi-supervised Learning (SSL) for Time Series Classification (TSC)", to appear at ICMLA '21.

tl;dr: performance gains of semi-supervised models translate well from image to time series classification:

[Results figure]

General

This framework allows you to evaluate the performance of SSL algorithms originally designed for image classification on time series classification problems, and to compare them with different baseline models.

This PyTorch-based codebase allows you to run experiments in a reproducible manner and to track and visualize individual experiments via mlflow. The core of the framework consists of two sub-packages: dl4d for data loading and sampling in a semi-supervised manner, and ssltsc, which contains the different backbone architectures, baseline models and semi-supervised learning strategies. Hyperparameters and general arguments for the model runs are controlled via the config files in ssltsc/experiments/config_files, each of which specifies a single experiment. Hyperparameter tuning builds on the same config file syntax and uses Hyperband as implemented in optuna.
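
As a rough orientation, the sketch below shows how Hyperband pruning is typically wired up with optuna; it is not the repository's actual tuning code, and the objective function, search space and timeout value are placeholders (the real logic lives in tune.py and the config files).

import optuna

def objective(trial: optuna.Trial) -> float:
    # Sample hyperparameters, standing in for a config-driven search space.
    lr = trial.suggest_float("lr", 1e-5, 1e-1, log=True)
    weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-2, log=True)

    score = 0.0
    for step in range(100):
        # Placeholder for one training/validation round of the actual model.
        score = 1.0 - (lr + weight_decay) / (step + 1)
        # Report intermediate results so Hyperband can stop weak trials early.
        trial.report(score, step)
        if trial.should_prune():
            raise optuna.TrialPruned()
    return score

study = optuna.create_study(
    direction="maximize",
    pruner=optuna.pruners.HyperbandPruner(min_resource=1, max_resource=100, reduction_factor=3),
)
study.optimize(objective, timeout=36000)  # timeout in seconds, analogous to the --time_budget argument shown below
print(study.best_params)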

All models in this repository were also developed and validated on image classification datasets (Cifar10, SVHN) to verify the correctness of the code. This means you can use the framework not only for semi-supervised time series classification but also as a starting point for semi-supervised image classification.

The core functionality of this framework is covered by a series of unit tests. Run python -m unittest discover -s tests from the parent level of this repository to execute them via the unittest framework. CI will be integrated on top of these tests soon.
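
For orientation, discovery picks up any tests/test_*.py module that defines unittest.TestCase classes; the file name and assertion below are purely illustrative and not part of the repository.

# tests/test_example.py -- hypothetical module found by `unittest discover`
import unittest

class TestBatchShapes(unittest.TestCase):
    def test_batch_shape(self):
        # Placeholder check standing in for an actual data-loading assertion.
        batch = [[0.0] * 128 for _ in range(32)]
        self.assertEqual((len(batch), len(batch[0])), (32, 128))

if __name__ == "__main__":
    unittest.main()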

The following UML diagram gives a detailed overview of the different components of this framework: [UML diagram]

Get Started

Install the requirements in a clean Python environment via pip install -r requirements.txt. Then install the ssltsc module by running pip install -e . from the parent level of this repository.

Examples

The following examples show how to train or tune different algorithms on different datasets using this framework. Datasets are downloaded on the fly to the data folder the first time they are used. All code snippets should be run from ssltsc/experiments.

To train a MixMatch model with an FCN backbone on the pamap2 dataset for 1000 update steps, storing the results in the mlflow experiment hello_mixmatch_fcn, run:

python run.py --config config_files/mixmatch.yaml --n_steps 1000 --dataset pamap2 --backbone FCN --mlflow_name hello_mixmatch_fcn

To verify the correct implementation of the Virtual Adversarial Training (VAT) model on cifar10 with a wideresnet28 backbone, run:

python run.py --config config_files/vat.yaml --dataset cifar10 --backbone wideresnet28

To run a Random Forest baseline on the SITS dataset using only 250 labelled samples, based on features extracted via tsfresh, run:

python run_baseline.py --config config_files/randomforest.yaml --dataset sits --num_labels 250

And finally, to tune the hyperparameters of the Mean Teacher model on the crop dataset with 1000 labelled samples for 10 hours (36000 seconds), run:

python tune.py --config config_files/meanteacher.yaml --num_labels 1000 --time_budget 36000

Integrated Algorithms and Datasets

Algorithms

All algorithms are stored in ssltsc.models. The semi-supervised algorithms currently implemented in this framework include MixMatch, Virtual Adversarial Training (VAT) and the Mean Teacher.

The following baseline models are also available:

  • Supervised baseline model
  • Random Forest (based on features extracted via tsfresh)
  • Logistic Regression (based on features extracted via tsfresh)

Datasets

All integrated datasets can be found in dl4d.datasets. The framework currently contains several TSC datasets, including pamap2, SITS and crop, as well as the following standard image classification datasets to validate the implementation:

  • Cifar10
  • SVHN
