In [None]:
!pip install torch==1.11.0  # version recommended by source

In [None]:
!pip install git+https://github.com/gretelai/gretel-synthetics.git

In [None]:
from pathlib import Path
from pickle import dump, load

import matplotlib.dates as md
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from skimage.util.shape import view_as_windows
from sklearn.preprocessing import MinMaxScaler

from gretel_synthetics.timeseries_dgan.config import DGANConfig, OutputType
from gretel_synthetics.timeseries_dgan.dgan import DGAN

In [None]:
# Data source: https://www.kaggle.com/code/xiaxiaxu/predictmachinefailureinadvance/data
# parse_dates makes `timestamp` datetime64 at load time instead of strings.
sensor = pd.read_csv("sensor.csv", parse_dates=["timestamp"])

# Columns to model: the sensors kept because they had high variance in an
# earlier PCA analysis, plus the machine_status label.
COLS = ["sensor_25", "sensor_11", "sensor_36", "sensor_34", "machine_status"]

# .copy() makes `data` an independent frame. Without it, `data` is flagged as a
# slice of `sensor`, and later column assignments / dropna(inplace=True) emit
# pandas' SettingWithCopyWarning.
data = sensor[["timestamp"] + COLS].copy()
data.head()

In [None]:
# Make sure `timestamp` is datetime64; the date-axis plots below need it.
# pd.to_datetime is a no-op on a column that is already datetime64, so this
# cell is safe to re-run. If `data` was built as a slice of `sensor` without
# .copy(), pandas emits a SettingWithCopyWarning on this assignment. The
# conversion of `data` still takes effect, so there is no need to run it twice.
data["timestamp"] = pd.to_datetime(data["timestamp"])
data.dtypes

In [None]:
# Which values can machine_status take, and how often does each occur?
# dropna=False keeps missing statuses visible (NaNs are dropped in the next cell).
data["machine_status"].value_counts(dropna=False)

In [None]:
# Drop every row with a missing value in any selected column. Reassigning,
# instead of using inplace=True, avoids a SettingWithCopyWarning when `data` is a
# slice of `sensor`. Re-running the cell is harmless.
data = data.dropna(axis=0)

# Invariant: no NaNs remain.
assert not data.isna().any().any(), "NaNs remain after dropna"
len(data)

In [None]:
# Use only a contiguous stretch of rows, intended to be centered around two
# failures. The bounds are positional (iloc) after dropna. If the NaN filtering
# above changes, the same positions select different rows.
# TODO: confirm the range still brackets the failures.
FAILURE_WINDOW_START, FAILURE_WINDOW_END = 16000, 26080

# .copy() so later edits (e.g. re-enabling scaling) modify an independent frame
# rather than a slice of `data`.
data_around_failures = data.iloc[FAILURE_WINDOW_START:FAILURE_WINDOW_END].copy()
print(len(data_around_failures))
data_around_failures.head()

In [None]:
# The sensor columns: everything in COLS except the machine_status label, in COLS
# order. Deriving the list keeps it in sync with COLS.
sensor_cols = [c for c in COLS if c != "machine_status"]

# No manual scaling here: sensor values stay in their raw units. A MinMaxScaler
# on these columns was tried earlier and disabled; the DGANConfig below is
# created with apply_feature_scaling=True.

In [None]:
# Plot the 4 sensors over the selected period. Values are raw (unscaled); no
# scaling is applied to data_around_failures.
fig, ax = plt.subplots()
for c in sensor_cols:
    ax.plot(data_around_failures["timestamp"], data_around_failures[c], label=c)

ax.tick_params(axis="x", labelrotation=90)
ax.legend()
ax.set(xlabel="Date", ylabel="Sensor Value", title="Raw sensor readings (data_around_failures)")
plt.show()

In [None]:
# Model inputs: only the sensor columns (timestamp and machine_status are left
# out). Selecting them in sensor_cols order guarantees that feature index k
# corresponds to sensor_cols[k], which the plots and histograms below rely on.
features = data_around_failures[sensor_cols].to_numpy()
print(features.shape)

In [None]:
# Slice the series into overlapping windows of WINDOW_LEN consecutive rows,
# starting every WINDOW_STEP rows. Window k covers rows
# [k * WINDOW_STEP, k * WINDOW_STEP + WINDOW_LEN) of data_around_failures.
WINDOW_LEN = 240
WINDOW_STEP = 10

window_shape = (WINDOW_LEN, features.shape[1])  # span every feature column
windowed_data = view_as_windows(features, window_shape, step=WINDOW_STEP)

# Because the window spans all columns, view_as_windows returns shape
# (n_windows, 1, WINDOW_LEN, n_features). Drop only that singleton axis.
# np.squeeze would also drop the window axis if there were exactly one window.
windowed_data = windowed_data[:, 0]
windowed_data.shape  # (n_windows, WINDOW_LEN, n_features)

In [None]:
# Show a few of the 4-hour training windows, each plotted against its real
# timestamps from data_around_failures.


def plot_hours(f, ind, step=10):
    """Plot one window of sensor values on the current matplotlib axes.

    Parameters
    ----------
    f : array of shape (window_len, len(sensor_cols))
        One window, e.g. ``windowed_data[ind]``. Column k is plotted with
        label ``sensor_cols[k]``.
    ind : int
        Window index. With windows built by ``view_as_windows(..., step=step)``,
        window ``ind`` starts at row ``ind * step`` of ``data_around_failures``.
        Its timestamps are therefore rows ``[ind * step, ind * step + len(f))``.
    step : int, default 10
        Stride between consecutive windows. It must match the ``step`` used to
        build the windows.

    If ``ind`` does not map to ``len(f)`` real timestamps (e.g. an index into
    synthetic samples, which can exceed the number of real windows), the x-axis
    falls back to the position within the window. Without this fallback,
    ``plt.plot`` would raise on mismatched x/y lengths.
    """
    start, stop = ind * step, ind * step + len(f)
    all_timestamps = data_around_failures["timestamp"]
    has_timestamps = 0 <= start and stop <= len(all_timestamps)
    # .iloc makes the positional slicing explicit.
    x = all_timestamps.iloc[start:stop] if has_timestamps else np.arange(len(f))

    for i, c in enumerate(sensor_cols):
        plt.plot(x, f[:, i], label=c)

    if has_timestamps:
        # Hour-of-day ticks only make sense on a datetime axis.
        ax = plt.gca()
        ax.xaxis.set_major_locator(md.HourLocator(byhour=range(2, 24, 3)))
        ax.xaxis.set_major_formatter(md.DateFormatter("%H:%M"))
    plt.legend(prop={"size": 7})


figure = plt.figure(figsize=(10, 10))
n_windows = len(windowed_data)
# 3x3 grid of randomly chosen real windows (drawn with torch's RNG).
for panel in range(1, 10):
    window_idx = int(torch.randint(n_windows, size=(1,)))
    figure.add_subplot(3, 3, panel)
    plot_hours(windowed_data[window_idx], window_idx)
plt.show()

In [None]:
# Training is recommended on a GPU; check whether torch can see one.
cuda_available = torch.cuda.is_available()
cuda_available

In [None]:
# Set up the DGAN config. Each training example is one window of windowed_data,
# which has shape (n_windows, max_sequence_len, n_features).

# Seed torch and numpy so runs are repeatable. This assumes DGAN draws from these
# global RNGs, which is not visible here. GPU kernels may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

SAMPLE_LEN = 20  # trying a larger sample_len
# DGANConfig requires max_sequence_len to be a multiple of sample_len. Fail fast
# here instead of deep inside model construction.
assert windowed_data.shape[1] % SAMPLE_LEN == 0, (
    f"max_sequence_len={windowed_data.shape[1]} is not divisible by sample_len={SAMPLE_LEN}"
)

config = DGANConfig(
    max_sequence_len=windowed_data.shape[1],
    sample_len=SAMPLE_LEN,
    batch_size=min(1000, windowed_data.shape[0]),  # never larger than the number of windows
    apply_feature_scaling=True,
    apply_example_scaling=False,
    use_attribute_discriminator=False,  # train_numpy below is called without attributes
    generator_learning_rate=1e-4,
    discriminator_learning_rate=1e-4,
    epochs=10000,
)

model = DGAN(config)

In [None]:
# Every feature column is a real-valued sensor reading.
feature_types = [OutputType.CONTINUOUS for _ in range(windowed_data.shape[2])]
model.train_numpy(windowed_data, feature_types=feature_types)
# Observed runtime: about 7 minutes on Small + GPU.

In [None]:
# Draw 1000 synthetic windows from the trained generator (this runs almost
# instantly). generate_numpy returns a pair; only the features are used below.
synthetic_attributes, synthetic_features = model.generate_numpy(1000)

In [None]:
# Plot a few synthetic windows. Synthetic samples have no real timestamps, so
# the x-axis is the time step within the window. Looking up real timestamps by
# synthetic index (as plot_hours does) would be meaningless, and indices past
# the number of real windows have no matching timestamp slice at all.
figure = plt.figure(figsize=(10, 10))
for i in range(1, 10):
    sample_idx = torch.randint(len(synthetic_features), size=(1,)).item()
    point = synthetic_features[sample_idx]
    ax = figure.add_subplot(3, 3, i)
    for k, c in enumerate(sensor_cols):
        ax.plot(point[:, k], label=c)
    ax.set_xlabel("time step")
    ax.legend(prop={"size": 7})
plt.show()

In [None]:
# Compare (non-temporal) correlations between the 4 sensors. Synthetic windows
# are flattened so every time step of every window becomes one row.
synthetic_df = pd.DataFrame(
    synthetic_features.reshape(-1, synthetic_features.shape[2]), columns=sensor_cols
)

# Select the real columns by sensor_cols so both matrices share the same
# row/column order and can be subtracted element-wise.
real_corr = data_around_failures[sensor_cols].corr()
synthetic_corr = synthetic_df.corr()

print("Correlation in real data:")
display(real_corr)
print("Correlation in synthetic data:")
display(synthetic_corr)
print("Absolute difference |synthetic - real|:")
display((synthetic_corr - real_corr).abs())

In [None]:
# Compare the distribution of sensor_34 values (all time steps of all windows),
# real vs. synthetic. The feature index is looked up from sensor_cols, so it
# always matches the label.
k_34 = sensor_cols.index("sensor_34")
plt.hist(
    [windowed_data[:, :, k_34].flatten(), synthetic_features[:, :, k_34].flatten()],
    label=["real", "synthetic"],
    bins=25,
    density=True,
)
plt.legend()
plt.xlabel("Sensor 34 Values")
plt.ylabel("Density")
plt.show()

In [None]:
# Compare the distribution of sensor_25 values (all time steps of all windows),
# real vs. synthetic. The feature index is looked up from sensor_cols, so it
# always matches the label.
k_25 = sensor_cols.index("sensor_25")
plt.hist(
    [windowed_data[:, :, k_25].flatten(), synthetic_features[:, :, k_25].flatten()],
    label=["real", "synthetic"],
    bins=25,
    density=True,
)
plt.legend()
plt.xlabel("Sensor 25 Values")
plt.ylabel("Density")
plt.show()

In [None]:
# Compare the distribution of sensor_11 values (all time steps of all windows),
# real vs. synthetic. The feature index is looked up from sensor_cols, so it
# always matches the label.
k_11 = sensor_cols.index("sensor_11")
plt.hist(
    [windowed_data[:, :, k_11].flatten(), synthetic_features[:, :, k_11].flatten()],
    label=["real", "synthetic"],
    bins=25,
    density=True,
)
plt.legend()
plt.xlabel("Sensor 11 Values")
plt.ylabel("Density")
plt.show()

In [None]:
# Compare the distribution of sensor_36 values (all time steps of all windows),
# real vs. synthetic. The feature index is looked up from sensor_cols, so it
# always matches the label.
k_36 = sensor_cols.index("sensor_36")
plt.hist(
    [windowed_data[:, :, k_36].flatten(), synthetic_features[:, :, k_36].flatten()],
    label=["real", "synthetic"],
    bins=25,
    density=True,
)
plt.legend()
plt.xlabel("Sensor 36 Values")
plt.ylabel("Density")
plt.show()

In [None]:
# Save the trained model for future use. Create the target directory first,
# because writing a file into a missing directory raises FileNotFoundError.
MODEL_PATH = Path("../models") / "dgan_model_1.pt"
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
model.save(str(MODEL_PATH))

# To reload later:
# model = DGAN.load(str(MODEL_PATH))