#76 Add Sliding Window to DAGMM #77

Closed · wants to merge 3 commits
48 changes: 22 additions & 26 deletions src/algorithms/dagmm.py
@@ -17,19 +17,6 @@ def to_var(x, volatile=False):
     return Variable(x, volatile=volatile)


-class CustomDataLoader(object):
-    """Wrap the given features so they can be put into a torch DataLoader"""
-
-    def __init__(self, X):
-        self.X = X
-
-    def __len__(self):
-        return self.X.shape[0]
-
-    def __getitem__(self, index):
-        return np.float32(self.X[index])
-
-
 class DAGMM_Module(nn.Module):
     """Residual Block."""

@@ -171,13 +158,14 @@ def loss_function(self, x, x_hat, z, gamma, lambda_energy, lambda_cov_diag):

 class DAGMM(Algorithm):
     def __init__(self, num_epochs=5, lambda_energy=0.1, lambda_cov_diag=0.005, lr=1e-4, batch_size=700, gmm_k=3,
-                 normal_percentile=80):
+                 normal_percentile=80, sequence_length=5):
         super().__init__(__name__, "DAGMM")
         self.num_epochs = num_epochs
         self.lambda_energy = lambda_energy
         self.lambda_cov_diag = lambda_cov_diag
         self.lr = lr
         self.batch_size = batch_size
+        self.sequence_length = sequence_length
         self.gmm_k = gmm_k  # Number of Gaussian mixtures
         self.normal_percentile = normal_percentile  # Up to which percentile data should be considered normal
         self.dagmm, self.optimizer, self.train_energy, self._threshold = None, None, None, None
@@ -201,15 +189,18 @@ def fit(self, X: pd.DataFrame, _):
         """Learn the mixture probability, mean and covariance for each component k.
         Store the computed energy based on the training data and the aforementioned parameters."""
         X = X.dropna()
-        data_loader = DataLoader(dataset=CustomDataLoader(X.values), batch_size=self.batch_size, shuffle=False)
-        self.dagmm = DAGMM_Module(n_features=X.shape[1], n_gmm=self.gmm_k)
+        data = X.values
+        # Each point is a flattened window and thus has as many features as sequence_length * features
+        multi_points = [data[i:i + self.sequence_length].flatten() for i in range(len(data) - self.sequence_length + 1)]
+        data_loader = DataLoader(dataset=multi_points, batch_size=self.batch_size, shuffle=True, drop_last=True)
+        self.dagmm = DAGMM_Module(n_features=self.sequence_length * X.shape[1], n_gmm=self.gmm_k)
         self.optimizer = torch.optim.Adam(self.dagmm.parameters(), lr=self.lr)
         self.dagmm.eval()

         for _ in range(self.num_epochs):
             for input_data in data_loader:
                 input_data = to_var(input_data)
-                self.dagmm_step(input_data)
+                self.dagmm_step(input_data.float())

         n = 0
         mu_sum = 0
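
Note: the window construction above can be sanity-checked in isolation. A minimal sketch with a toy array standing in for X.values (sizes here are assumptions for illustration, not the defaults):

import numpy as np

data = np.arange(12).reshape(6, 2)  # 6 time steps, 2 features
sequence_length = 3
# Same expression as in fit: each window of 3 consecutive rows becomes one flattened point
multi_points = [data[i:i + sequence_length].flatten() for i in range(len(data) - sequence_length + 1)]
assert len(multi_points) == len(data) - sequence_length + 1         # 4 overlapping windows
assert multi_points[0].shape == (sequence_length * data.shape[1],)  # 3 * 2 = 6 features per point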
@@ -218,7 +209,7 @@

         for input_data in data_loader:
             input_data = to_var(input_data)
-            _, _, z, gamma = self.dagmm(input_data)
+            _, _, z, gamma = self.dagmm(input_data.float())
             phi, mu, cov = self.dagmm.compute_gmm_params(z, gamma)

             batch_gamma_sum = torch.sum(gamma, dim=0)
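
Aside: compute_gmm_params itself is outside this diff. For readers following along, the standard DAGMM estimates (Zong et al., 2018) that such a function returns can be sketched as below — a reconstruction from the paper, not necessarily this repo's exact code:

import torch

def gmm_params(z, gamma):
    # z: (N, D) latent codes, gamma: (N, K) soft mixture memberships
    sum_gamma = torch.sum(gamma, dim=0)                              # (K,)
    phi = sum_gamma / gamma.size(0)                                  # mixture weights
    # mu_k = sum_i gamma_ik z_i / sum_i gamma_ik
    mu = torch.sum(gamma.unsqueeze(-1) * z.unsqueeze(1), dim=0) / sum_gamma.unsqueeze(-1)
    z_centered = z.unsqueeze(1) - mu.unsqueeze(0)                    # (N, K, D)
    # cov_k = sum_i gamma_ik (z_i - mu_k)(z_i - mu_k)^T / sum_i gamma_ik
    cov = torch.sum(gamma.unsqueeze(-1).unsqueeze(-1) * z_centered.unsqueeze(-1) * z_centered.unsqueeze(-2),
                    dim=0) / sum_gamma.unsqueeze(-1).unsqueeze(-1)
    return phi, mu, cov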
@@ -236,7 +227,7 @@
         train_energy = []
         for input_data in data_loader:
             input_data = to_var(input_data)
-            _, _, z, _ = self.dagmm(input_data)
+            _, _, z, _ = self.dagmm(input_data.float())
             sample_energy, _ = self.dagmm.compute_energy(z, phi=train_phi, mu=train_mu, cov=train_cov,
                                                          size_average=False)
             train_energy.append(sample_energy.data.cpu().numpy())
Member commented:
You need to take the mean of the sample energies here as well.
I implemented it like this in the other branch:

train_length = len(data_loader)*self.batch_size + self.sequence_length - 1
train_energy = np.full((self.sequence_length, train_length), np.nan)
for i1, ts_batch in enumerate(data_loader):
    _, _, z, _ = self.dagmm(to_var(ts_batch).float())
    sample_energies, _ = self.dagmm.compute_energy(z, phi=train_phi, mu=train_mu, cov=train_cov,
                                                   size_average=False)
    for i2, sample_energy in enumerate(sample_energies):
        index = i1 * self.batch_size + i2
        window_elements = list(range(index, index + self.sequence_length, 1))
        train_energy[index % self.sequence_length, window_elements] = sample_energy.data.cpu().numpy()
self.train_energy = np.nanmean(train_energy, axis=0)

One issue here is that we have batches larger than 1, so we can't do it the way we do in predict. In addition, we're dropping the last batch, so the energy array isn't quite the same length as the data.
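
To see why the nanmean recovers one score per original point, here is a toy walk-through of the same row-rotation scheme (scalar energies and tiny sizes, chosen for illustration): window idx writes its energy into row idx % sequence_length across the sequence_length columns it covers, consecutive windows land in different rows, and averaging over rows combines every window containing a point.

import numpy as np

sequence_length, n_points = 2, 5
energies = [10.0, 20.0, 30.0, 40.0]            # one toy energy per window (n_points - sequence_length + 1 windows)
scores = np.full((sequence_length, n_points), np.nan)
for idx, e in enumerate(energies):
    # window idx covers points idx .. idx + sequence_length - 1
    scores[idx % sequence_length, idx:idx + sequence_length] = e
print(np.nanmean(scores, axis=0))              # [10. 15. 25. 35. 40.]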

@@ -246,17 +237,22 @@ def fit(self, X: pd.DataFrame, _):
     def predict(self, X: pd.DataFrame):
         """Using the learned mixture probability, mean and covariance for each component k, compute the energy on the
         given data."""
+        self.dagmm.eval()
         X = X.dropna()
-        test_energy = []
-        data_loader = DataLoader(dataset=CustomDataLoader(X.values), batch_size=self.batch_size, shuffle=False)
-        for input_data in data_loader:
-            input_data = to_var(input_data)
-            _, _, z, _ = self.dagmm(input_data)
+        data = X.values
+        multi_points = [data[i:i + self.sequence_length].flatten() for i in range(len(data) - self.sequence_length + 1)]
+        data_loader = DataLoader(dataset=multi_points, batch_size=1, shuffle=False)
+        test_energy = np.full((self.sequence_length, len(data)), np.nan)
+
+        for idx, long_point in enumerate(data_loader):
+            _, _, z, _ = self.dagmm(to_var(long_point).float())
             sample_energy, _ = self.dagmm.compute_energy(z, size_average=False)
-            test_energy.append(sample_energy.data.cpu().numpy())
+            window_elements = np.arange(idx, idx + self.sequence_length, 1)
+            test_energy[idx % self.sequence_length, window_elements] = sample_energy.data.cpu().numpy()

-        test_energy = np.concatenate(test_energy, axis=0)
+        test_energy = np.nanmean(test_energy, axis=0)
         combined_energy = np.concatenate([self.train_energy, test_energy], axis=0)

         self._threshold = np.percentile(combined_energy, self.normal_percentile)
         if np.isnan(self._threshold):
             raise Exception("Threshold is NaN")
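
For reference, a hedged end-to-end sketch of the new parameter (toy frames and hypothetical sizes; it assumes only the fit/predict signatures visible above — predict's return value lies beyond the shown hunk):

import numpy as np
import pandas as pd

X_train = pd.DataFrame(np.random.randn(200, 3))
X_test = pd.DataFrame(np.random.randn(50, 3))

model = DAGMM(num_epochs=2, batch_size=32, sequence_length=5)
model.fit(X_train, None)        # second argument is unused, per the signature above
result = model.predict(X_test)

One caveat worth noting: fit builds its DataLoader with drop_last=True, so batch_size must stay below the number of windows (len(X) - sequence_length + 1), otherwise the training loop receives no batches at all.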
2 changes: 1 addition & 1 deletion src/evaluation/evaluator.py
@@ -68,7 +68,7 @@ def evaluate(self):
                     score = det.predict(X_test)
                     self.results[(ds.name, det.name)] = score
                 except Exception as e:
-                    self.logger.error(f"An exception occured while training {det.name} on {ds}: {e}")
+                    self.logger.error(f"An exception occurred while training {det.name} on {ds}: {e}")
                     self.logger.error(traceback.format_exc())
                     self.results[(ds.name, det.name)] = np.zeros_like(y_test)

3 changes: 2 additions & 1 deletion tests/test_DAGMM.py
@@ -9,13 +9,14 @@

 class DAGMMTestCase(unittest.TestCase):
     def test_kdd_cup(self):
-        evaluator = Evaluator([KDDCup()], [DAGMM(num_epochs=10)])
+        evaluator = Evaluator([KDDCup()], [DAGMM(num_epochs=10, sequence_length=1)])
         df_evaluation = pd.DataFrame(
             columns=["dataset", "algorithm", "accuracy", "precision", "recall", "F1-score", "F0.1-score"])
         for _ in range(5):
             evaluator.evaluate()
             df = evaluator.benchmarks()
             df_evaluation = df_evaluation.append(df)
+        print(df_evaluation.to_string())
         assert (df_evaluation == 0).sum().sum() == 0  # No zeroes in the DataFrame
         assert df_evaluation['F1-score'].std() > 0  # Not always the same value
         # Values reported in the paper -1% each
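
Pinning sequence_length=1 keeps this test comparable to the pre-window behaviour: a window of one is just the original row, as a quick check of the same expression used in fit confirms (toy data):

import numpy as np

data = np.random.randn(8, 4)
sequence_length = 1
multi_points = [data[i:i + sequence_length].flatten() for i in range(len(data) - sequence_length + 1)]
assert len(multi_points) == len(data)
assert np.array_equal(np.stack(multi_points), data)  # length-1 windows are the rows themselves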