In [None]:
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import torch
import torchinfo
from sklearn.preprocessing import StandardScaler

from src.inference import *
from src.misc import *
from src.models.autoencoders import *
from src.plots import *
from src.train import *

In [None]:
# Parameters
D = 8                  # number of signals (features) per time step
window_size = 40       # length of each sliding window fed to the autoencoder
total_samples = 10000  # length of the synthetic series
train_samples = 9000   # leading samples used to fit the scaler and train the model
SEED = 42

# Seed the RNGs so Restart & Run All is reproducible.
# NOTE(review): assumes generate_synthetic_data draws from NumPy's global RNG — confirm.
np.random.seed(SEED)
torch.manual_seed(SEED)

# 1. Data
X = generate_synthetic_data(T=total_samples, D=D)

# Fit the scaler on the training portion only so statistics of the
# held-out tail do not leak into preprocessing; then scale everything.
scaler = StandardScaler().fit(X[:train_samples])
X_scaled = scaler.transform(X)

# Train on the leading `train_samples` points; the remaining tail stays unseen.
X_train = sliding_windows(X_scaled[:train_samples], window_size)
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)


In [None]:
# Sanity-check the windowed training tensor: expected (batch, seq_len, features),
# i.e. (n_windows, window_size, D).
assert X_train_tensor.ndim == 3, X_train_tensor.shape
assert tuple(X_train_tensor.shape[1:]) == (window_size, D), X_train_tensor.shape
X_train_tensor.shape

In [None]:
# 2. Model
# Other autoencoders available in src.models.autoencoders that can be swapped in:
#   LSTMAutoencoder(input_dim=D, hidden_dim=4)
#   CNNAutoencoder(input_dim=D)
# `n1` is presumably the width of the first hidden layer — confirm in src.models.autoencoders.
hidden_units = 25
model = FlattenedAutoencoder(
    d=D,
    t=window_size,
    n1=hidden_units,
)



In [None]:
# Layer-by-layer summary for a dummy batch shaped like the training windows.
# Derive the shape from the config instead of hard-coding (40, 8), so the
# summary stays valid if `window_size` or `D` change.
batch_size = 16
torchinfo.summary(model, input_size=(batch_size, window_size, D))

In [None]:
# 3. Train the autoencoder on the windowed, scaled training tensor.
n_epochs = 50
train_autoencoder(
    model,
    X_train_tensor,
    epochs=n_epochs,
)



In [None]:
# 4. Inference over the full raw series. The fitted `scaler` is handed over,
# presumably so online_inference normalizes each window as in training —
# confirm in src.inference.
scores = online_inference(
    model,
    X,
    scaler,
    window_size,
)


In [None]:
# 5. Plot
# k-sigma rule: flag scores more than THRESHOLD_SIGMAS standard deviations
# above the mean. nan-aware statistics because `scores` may contain NaNs
# (presumably for steps before the first full window — confirm in online_inference).
THRESHOLD_SIGMAS = 3
threshold = np.nanmean(scores) + THRESHOLD_SIGMAS * np.nanstd(scores)

fig, ax = plt.subplots(figsize=(20, 12))
ax.plot(scores, label='Anomaly Score')
ax.axhline(threshold, color='r', linestyle='--', label='Threshold')
ax.set(title='Online Anomaly Detection Score', xlabel='Sample index', ylabel='Anomaly score')
ax.legend()
ax.grid()
plt.show()


In [None]:
# Compare one window of the raw signal with the model's reconstruction.
# Reuse the configured `window_size`: re-defining it here would silently
# diverge from the window length the model was built and trained with.
window_start = 293  # arbitrary window chosen for visual inspection
window_end = window_start + window_size
assert window_end <= len(X), "inspection window runs past the end of X"

x_window = X[window_start:window_end]
x_window_scaled = scaler.transform(x_window)
# Add a leading batch dimension of 1 for the model.
x_input = torch.tensor(x_window_scaled, dtype=torch.float32).unsqueeze(0)

model.eval()  # switch off train-mode behaviour (dropout / batch-norm), if the model has any
with torch.no_grad():
    x_output = model(x_input).squeeze(0).numpy()

# The raw window is already in original units; only the reconstruction
# needs to be mapped back from the scaled space.
x_original = x_window
x_reconstructed = scaler.inverse_transform(x_output)

plot_reconstruction(x_original, x_reconstructed, window_start=window_start)


In [None]:
# Interactive overview of every raw signal in X (one trace per column).
# `go` (plotly.graph_objects) is imported in the top-of-notebook imports cell.
signals_fig = go.Figure()
for col in range(X.shape[1]):
    signals_fig.add_trace(go.Scatter(y=X[:, col], mode='lines', name=f'Signal {col+1}'))

signals_fig.update_layout(
    title='Each Signal in X Timeseries',
    xaxis_title='Time',
    yaxis_title='Value',
    legend_title='Signals',
    height=600,
    width=1000
)
signals_fig.show()