In [None]:
from tensorflow.keras.layers import Input, Dense, Lambda, LSTM, RepeatVector
from tensorflow.keras.models import Model
from tensorflow.keras.losses import mse
from tensorflow.keras import backend as K
import matplotlib.pyplot as plt

# VAE model
# Encoder: an LSTM summarizes each input sequence into a single vector, from
# which two Dense heads predict the parameters of the approximate posterior
# q(z|x): its mean and its log-variance.
# `train_sequences` and `time_steps` are defined outside this cell. Because the
# model input below is (time_steps, input_dim) and the VAE is later fit directly
# on `train_sequences`, that array is expected to be 3-D:
# (num_sequences, time_steps, num_features).
input_dim = train_sequences.shape[2]  # number of features per time step
latent_dim = 50  # size of the latent space (also the encoder LSTM width)

inputs = Input(shape=(time_steps, input_dim))
# No return_sequences here, so only the final hidden state (one vector per
# sequence) is passed on.
# NOTE(review): per the Keras LSTM docs, any activation other than 'tanh'
# prevents use of the fused cuDNN kernel; confirm 'relu' is intentional.
encoded = LSTM(latent_dim, activation='relu')(inputs)

# Posterior parameters. z_log_var is the log of the variance (sampling uses
# exp(0.5 * z_log_var) as the std), so this Dense output needs no positivity
# constraint.
z_mean = Dense(latent_dim)(encoded)
z_log_var = Dense(latent_dim)(encoded)

def sampling(args):
    """Sample latent vectors with the reparameterization trick.

    Args:
        args: A pair ``(z_mean, z_log_var)`` of tensors, each shaped
            ``(batch, latent_dim)``.

    Returns:
        ``z_mean + exp(z_log_var / 2) * eps`` with ``eps`` drawn from a
        standard normal distribution. Isolating the randomness in ``eps``
        lets gradients flow through ``z_mean`` and ``z_log_var``.
    """
    mu, log_var = args
    # The batch size is only known at run time; the latent width is static.
    noise_shape = (K.shape(mu)[0], K.int_shape(mu)[1])
    eps = K.random_normal(shape=noise_shape)
    sigma = K.exp(0.5 * log_var)
    return mu + sigma * eps

# Draw one latent sample per sequence via the reparameterization trick.
z = Lambda(sampling, output_shape=(latent_dim,))([z_mean, z_log_var])

# Decoder layers are created once as standalone objects (rather than applied
# inline) so the same trained weights can be reused by the stand-alone
# generator model built after training.
decoder_h = LSTM(latent_dim, activation='relu', return_sequences=True)
# With a sigmoid activation the LSTM outputs lie in (0, 1); this presumably
# assumes the input features were scaled to that range — TODO confirm against
# the preprocessing.
decoder_mean = LSTM(input_dim, activation='sigmoid', return_sequences=True)

# Repeat the latent vector at every time step so the decoder LSTMs can unroll
# it back into a (time_steps, input_dim) sequence.
h_decoded = RepeatVector(time_steps)(z)
x_decoded_h = decoder_h(h_decoded)
x_decoded_mean = decoder_mean(x_decoded_h)

# End-to-end VAE: sequence in, reconstructed sequence out.
vae = Model(inputs, x_decoded_mean)

# Loss function
# Reconstruction term: `mse` over the flattened batch is the mean squared error
# across every element; multiplying by input_dim * time_steps turns it into the
# per-sequence sum of squared errors, averaged over the batch.
reconstruction_loss = mse(K.flatten(inputs), K.flatten(x_decoded_mean))
reconstruction_loss *= input_dim * time_steps
# KL divergence between q(z|x) = N(z_mean, exp(z_log_var)) and the standard
# normal prior, summed over latent dimensions (one value per sequence).
kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
kl_loss = K.sum(kl_loss, axis=-1)
kl_loss *= -0.5
# Total VAE loss averaged over the batch. It is attached with add_loss, so the
# model is compiled and fit without separate targets.
# NOTE(review): add_loss on symbolic functional-model tensors together with the
# tensorflow.keras.backend helpers used here is a Keras 2 (TF <= 2.15) pattern;
# verify it against the installed TF/Keras version.
vae_loss = K.mean(reconstruction_loss + kl_loss)
vae.add_loss(vae_loss)
vae.compile(optimizer='adam')

# Train the VAE
# Only inputs are passed: the reconstruction target is the input itself (via
# the loss above). validation_split=0.2 holds out the last 20% of
# train_sequences for validation.
vae.fit(train_sequences, epochs=50, batch_size=32, validation_split=0.2)

# Generate new simulations
# Stand-alone decoder ("generator"): maps a latent vector straight to a
# sequence by reusing the trained decoder_h / decoder_mean layers (shared
# weights), bypassing the encoder and the sampling layer.
decoder_input = Input(shape=(latent_dim,))
_h_decoded = decoder_h(RepeatVector(time_steps)(decoder_input))
_x_decoded_mean = decoder_mean(_h_decoded)
generator = Model(decoder_input, _x_decoded_mean)

def generate_simulation(generator, num_samples=5, *, latent_dim=None, rng=None):
    """Decode random latent vectors into new synthetic sequences.

    Latent vectors are drawn from a standard normal distribution (the VAE's
    prior) and passed through ``generator``.

    Args:
        generator: Model mapping ``(n, latent_dim)`` latent vectors to
            sequences. It must expose ``predict`` and, unless ``latent_dim``
            is given, ``input_shape``.
        num_samples: Number of sequences to generate.
        latent_dim: Latent width. Defaults to ``generator.input_shape[-1]``,
            so it always matches the model being sampled.
        rng: Object with a ``normal(size=...)`` method, e.g. a
            ``numpy.random.Generator`` for reproducible draws. Defaults to the
            global ``numpy.random`` state.

    Returns:
        The output of ``generator.predict`` on the latent batch. For the Keras
        generator in this notebook that is
        ``(num_samples, time_steps, input_dim)``.
    """
    import numpy as np  # local import: numpy is not imported in this cell

    if latent_dim is None:
        # Read the width from the model instead of a notebook global.
        latent_dim = generator.input_shape[-1]
    if rng is None:
        rng = np.random
    latent_samples = rng.normal(size=(num_samples, latent_dim))
    return generator.predict(latent_samples)

# Generate and plot new wildfire simulations
vae_simulations = generate_simulation(generator, num_samples=5)

# Each generated simulation has shape (time_steps, input_dim). Plot every
# feature column as its own labelled line. Passing the whole 2-D array with one
# string label would give all input_dim lines that same label, filling the
# legend with identical entries.
for sim_number, sim in enumerate(vae_simulations, start=1):
    plt.figure(figsize=(14, 5))
    for feature_idx, series in enumerate(sim.T):
        plt.plot(series, label=f'Feature {feature_idx}')
    plt.legend()
    plt.title(f'Generated Wildfire Simulation {sim_number} with VAE')
    plt.xlabel('Time Steps')
    plt.ylabel('Normalized Feature Values')
    plt.show()
