Cuda support (#103)
* adds second training set to train params of Gaussian dist via splitting

* adds cuda support for all algorithms

* Run LSTMAD, Donut and DAGMM with NNAutoencoder on CUDA (#105)

* Ignore device placement on ReEBM as well

* Adapt LSTMED

* Sort detector execution by framework

* Closes #89 #91 #92 #93 #94 #95 #96
xasetl authored and WGierke committed Jun 28, 2018
1 parent 8bab6b9 commit d5cced5
Showing 13 changed files with 274 additions and 208 deletions.
42 changes: 22 additions & 20 deletions main.py
@@ -3,11 +3,11 @@
import numpy as np
import pandas as pd

from experiments import run_pollution_experiment, run_missing_experiment, run_extremes_experiment, \
run_multivariate_experiment
from src.algorithms import DAGMM, Donut, RecurrentEBM, LSTMAD, LSTMED, LSTMAutoEncoder
from src.datasets import AirQuality, KDDCup, SyntheticDataGenerator
from src.evaluation.evaluator import Evaluator
from experiments import run_pollution_experiment, run_missing_experiment, run_extremes_experiment, \
run_multivariate_experiment

RUNS = 2

@@ -17,24 +17,34 @@ def main():
run_experiments()


def run_pipeline():
def get_detectors():
if os.environ.get("CIRCLECI", False):
datasets = [SyntheticDataGenerator.extreme_1()]
detectors = [RecurrentEBM(num_epochs=2), LSTMAD(num_epochs=5), Donut(num_epochs=5), DAGMM(num_epochs=2),
LSTMED(num_epochs=2), DAGMM(num_epochs=2, autoencoder_type=LSTMAutoEncoder)]
return [RecurrentEBM(num_epochs=2), Donut(num_epochs=5), LSTMAD(num_epochs=5), DAGMM(num_epochs=2),
LSTMED(num_epochs=2), DAGMM(num_epochs=2, autoencoder_type=LSTMAutoEncoder)]
else:
datasets = [
return [RecurrentEBM(num_epochs=15), Donut(), LSTMAD(), LSTMED(num_epochs=40),
DAGMM(sequence_length=1), DAGMM(sequence_length=15),
DAGMM(sequence_length=1, autoencoder_type=LSTMAutoEncoder),
DAGMM(sequence_length=15, autoencoder_type=LSTMAutoEncoder)]


def get_pipeline_datasets():
if os.environ.get("CIRCLECI", False):
return [SyntheticDataGenerator.extreme_1()]
else:
return [
SyntheticDataGenerator.extreme_1(),
SyntheticDataGenerator.variance_1(),
SyntheticDataGenerator.shift_1(),
SyntheticDataGenerator.trend_1(),
SyntheticDataGenerator.combined_1(),
SyntheticDataGenerator.combined_4(),
]
detectors = [RecurrentEBM(num_epochs=15), LSTMAD(), Donut(), LSTMED(num_epochs=40),
DAGMM(sequence_length=1), DAGMM(sequence_length=15),
DAGMM(sequence_length=1, autoencoder_type=LSTMAutoEncoder),
DAGMM(sequence_length=15, autoencoder_type=LSTMAutoEncoder)]


def run_pipeline():
datasets = get_pipeline_datasets()
detectors = get_detectors()

evaluator = Evaluator(datasets, detectors)
# perform multiple pipeline runs for more significant end results
@@ -84,19 +94,11 @@ def evaluate_on_real_world_data_sets():

def run_experiments(outlier_type='extreme_1', output_dir=None, steps=5):
output_dir = output_dir or os.path.join('reports/experiments', outlier_type)
detectors = get_detectors()
if os.environ.get("CIRCLECI", False):
detectors = [RecurrentEBM(num_epochs=2), LSTMAD(num_epochs=5), Donut(num_epochs=5),
LSTMED(num_epochs=2), DAGMM(num_epochs=2),
DAGMM(num_epochs=2, autoencoder_type=LSTMAutoEncoder)]
run_extremes_experiment(detectors, outlier_type, output_dir=os.path.join(output_dir, 'extremes'),
steps=1)
else:
detectors = [RecurrentEBM(num_epochs=15), LSTMAD(), Donut(), LSTMED(num_epochs=40),
DAGMM(sequence_length=1),
DAGMM(sequence_length=15),
DAGMM(sequence_length=1, autoencoder_type=LSTMAutoEncoder),
DAGMM(sequence_length=15, autoencoder_type=LSTMAutoEncoder)]

announce_experiment('Pollution')
run_pollution_experiment(detectors, outlier_type, output_dir=os.path.join(output_dir, 'pollution'),
steps=steps)
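The commit message also lists sorting detector execution by framework; that call site sits in lines collapsed above. Below is a minimal sketch of how the new framework attribute introduced in src/algorithms/algorithm.py could drive such an ordering; the helper sort_by_framework is hypothetical and not part of this commit.

    # Hypothetical helper, not part of the diff: run detectors grouped by framework.
    # A stable sort keeps the original order within each framework group.
    def sort_by_framework(detectors):
        return sorted(detectors, key=lambda d: d.framework)

    detectors = sort_by_framework(get_detectors())
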
1 change: 1 addition & 0 deletions requirements.txt
@@ -23,6 +23,7 @@ pandas
tqdm
scipy>=0.14.0
scikit-learn>=0.19.1
tensorflow
flake8
matplotlib
progressbar2
2 changes: 0 additions & 2 deletions src/algorithms/__init__.py
@@ -1,4 +1,3 @@
from .algorithm import Algorithm
from .autoencoder import NNAutoEncoder, LSTMAutoEncoder
from .dagmm import DAGMM
from .donut import Donut
@@ -8,7 +7,6 @@
from .rnn_ebm import RecurrentEBM

__all__ = [
'Algorithm',
'NNAutoEncoder',
'LSTMAutoEncoder',
'DAGMM',
6 changes: 5 additions & 1 deletion src/algorithms/algorithm.py
@@ -3,9 +3,10 @@


class Algorithm(metaclass=abc.ABCMeta):
def __init__(self, module_name, name):
def __init__(self, module_name, name, framework):
self.logger = logging.getLogger(module_name)
self.name = name
self.framework = framework

def __str__(self) -> str:
return self.name
@@ -39,3 +40,6 @@ def threshold(self, score):
:param score
:return threshold:
"""

class Frameworks:
PyTorch, Tensorflow = range(2)
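To illustrate the new constructor contract (this example is not part of the commit): every detector now passes a framework identifier to Algorithm.__init__, which the runner can use to order execution. The subclass below is hypothetical, and its fit/predict/threshold signatures are assumed to mirror the abstract interface used by the real detectors.

    from src.algorithms.algorithm import Algorithm, Frameworks

    class ExampleDetector(Algorithm):  # hypothetical PyTorch-based detector
        def __init__(self):
            super().__init__(__name__, 'ExampleDetector', Frameworks.PyTorch)

        def fit(self, X):
            pass  # training logic would go here

        def predict(self, X):
            return [0.0] * len(X)  # one anomaly score per sample

        def threshold(self, score):
            return 0.5  # cut-off above which a score counts as anomalous
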
31 changes: 19 additions & 12 deletions src/algorithms/autoencoder.py
@@ -1,16 +1,17 @@
import torch
import torch.nn as nn

from .cuda_utils import GPUWrapper


class AutoEncoder(nn.Module):
"""AutoEncoder class, forward needs to return (decoded, encoded)"""


class NNAutoEncoder(AutoEncoder):

def __init__(self, n_features=118, sequence_length=1, hidden_size=1):
super().__init__()

class NNAutoEncoder(AutoEncoder, GPUWrapper):
def __init__(self, n_features=118, sequence_length=1, hidden_size=1, gpu=0):
AutoEncoder.__init__(self)
GPUWrapper.__init__(self, gpu)
n_features = n_features * sequence_length

layers = []
@@ -23,6 +24,7 @@ def __init__(self, n_features=118, sequence_length=1, hidden_size=1):
layers += [nn.Linear(10, hidden_size)]

self._encoder = nn.Sequential(*layers)
self.to_device(self._encoder)

layers = []
layers += [nn.Linear(hidden_size, 10)]
@@ -34,6 +36,7 @@ def __init__(self, n_features=118, sequence_length=1, hidden_size=1):
layers += [nn.Linear(60, n_features)]

self._decoder = nn.Sequential(*layers)
self.to_device(self._decoder)

def forward(self, x):
x = x.view(x.shape[0], -1)
@@ -44,12 +47,13 @@ def forward(self, x):
return dec, enc


class LSTMAutoEncoder(AutoEncoder):
class LSTMAutoEncoder(AutoEncoder, GPUWrapper):
"""Autoencoder with Recurrent module. Inspired by LSTM-EncDec"""

def __init__(self, n_features: int, sequence_length: int, hidden_size: int = 1, n_layers: tuple = (3, 3),
use_bias: tuple = (True, True), dropout: tuple = (0.3, 0.3)):
super().__init__()
use_bias: tuple = (True, True), dropout: tuple = (0.3, 0.3), gpu: int = 0):
AutoEncoder.__init__(self)
GPUWrapper.__init__(self, gpu)

self.n_features = n_features
self.sequence_length = sequence_length
@@ -61,13 +65,16 @@ def __init__(self, n_features: int, sequence_length: int, hidden_size: int = 1,

self.encoder = nn.LSTM(self.n_features, self.hidden_size, batch_first=True,
num_layers=self.n_layers[0], bias=self.use_bias[0], dropout=self.dropout[0])
self.to_device(self.encoder)
self.decoder = nn.LSTM(self.n_features, self.hidden_size, batch_first=True,
num_layers=self.n_layers[1], bias=self.use_bias[1], dropout=self.dropout[1])
self.to_device(self.decoder)
self.hidden2output = nn.Linear(self.hidden_size, self.n_features)
self.to_device(self.hidden2output)

def _init_hidden(self, batch_size):
return (torch.zeros(self.n_layers[0], batch_size, self.hidden_size),
torch.zeros(self.n_layers[0], batch_size, self.hidden_size))
return (self.to_var(torch.zeros(self.n_layers[0], batch_size, self.hidden_size)),
self.to_var(torch.zeros(self.n_layers[0], batch_size, self.hidden_size)))

def forward(self, ts_batch):
batch_size = ts_batch.shape[0]
@@ -77,13 +84,13 @@ def forward(self, ts_batch):
_, enc_hidden = self.encoder(ts_batch.float(), enc_hidden) # .float() here or .double() for the model

# 2. Use hidden state as initialization for our Decoder-LSTM
dec_hidden = (enc_hidden[0], torch.zeros(self.n_layers[1], batch_size, self.hidden_size))
dec_hidden = (enc_hidden[0], self.to_var(torch.zeros(self.n_layers[1], batch_size, self.hidden_size)))

# 3. Also, use this hidden state to get the first output aka the last point of the reconstructed timeseries
# 4. Reconstruct timeseries backwards
# * Use true data for training decoder
# * Use hidden2output for prediction
output = torch.zeros(ts_batch.size())
output = self.to_var(torch.zeros(ts_batch.size()))
for i in reversed(range(ts_batch.shape[1])):
output[:, i, :] = self.hidden2output(dec_hidden[0][0, :])

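A minimal usage sketch for the CUDA-aware LSTM autoencoder (shapes and values are assumed, and forward() is taken to return (decoded, encoded) as the AutoEncoder docstring requires). Because GPUWrapper falls back to the CPU when CUDA is unavailable, the same code runs unchanged on a CPU-only machine.

    import torch
    from src.algorithms.autoencoder import LSTMAutoEncoder

    model = LSTMAutoEncoder(n_features=5, sequence_length=30, hidden_size=8, gpu=0)
    batch = model.to_var(torch.randn(16, 30, 5))  # (batch, sequence, features) on the selected device
    decoded, encoded = model(batch)               # reconstruction plus hidden encoding
    print(decoded.shape)                          # torch.Size([16, 30, 5])
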
28 changes: 28 additions & 0 deletions src/algorithms/cuda_utils.py
@@ -0,0 +1,28 @@
import tensorflow as tf
import torch
from tensorflow.python.client import device_lib
from torch.autograd import Variable


class GPUWrapper:
def __init__(self, gpu):
self.gpu = gpu

@property
def tf_device(self):
local_device_protos = device_lib.list_local_devices()
gpus = [x.name for x in local_device_protos if x.device_type == 'GPU']
return tf.device(gpus[self.gpu] if gpus else '/cpu:0')

@property
def torch_device(self):
return torch.device(f'cuda:{self.gpu}' if torch.cuda.is_available() else 'cpu')

def to_var(self, x, **kwargs):
"""PyTorch only: send Var to proper device."""
x = x.to(self.torch_device)
return Variable(x, **kwargs)

def to_device(self, model):
"""PyTorch only: send Model to proper device."""
model.to(self.torch_device)
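An illustrative use of GPUWrapper from both frameworks (not part of the diff): PyTorch code moves modules and tensors with to_device()/to_var(), while TensorFlow code pins graph construction with the tf_device context manager; both fall back to the CPU when no GPU is present. The snippet assumes the TensorFlow 1.x API used elsewhere in the repository.

    import tensorflow as tf
    import torch
    import torch.nn as nn
    from src.algorithms.cuda_utils import GPUWrapper

    wrapper = GPUWrapper(gpu=0)

    # PyTorch: place a model and its input on the selected device.
    linear = nn.Linear(4, 2)
    wrapper.to_device(linear)
    out = linear(wrapper.to_var(torch.randn(8, 4)))

    # TensorFlow: build ops on the same GPU, or on the CPU if none is available.
    with wrapper.tf_device:
        x = tf.zeros([8, 4])
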

