In [None]:
# # Plot signals and their image representations
# fig, axes = plt.subplots(5, 5, figsize=(20, 20))
# fig.suptitle('Synthetic Signals and Their Image Representations')

# for i in range(5):
#     # Plot signal
#     axes[0,i].plot(signals[i].numpy())
#     axes[0,i].set_title(f'Signal {i+1}')
#     axes[0,i].set_xlabel('Time')
#     axes[0,i].set_ylabel('Amplitude')
    
#     # Get all transformation images
#     img = images[i].numpy()
    
#     # Plot MTF (image channel 0)
#     im = axes[1,i].imshow(img[0], cmap='viridis')
#     axes[1,i].set_title(f'MTF {i+1}')
#     axes[1,i].axis('off')
    
#     # Plot GAF (image channel 1)
#     im = axes[2,i].imshow(img[1], cmap='viridis')
#     axes[2,i].set_title(f'GAF {i+1}')
#     axes[2,i].axis('off')
    
#     # Plot RP (image channel 2)
#     im = axes[3,i].imshow(img[2], cmap='viridis')
#     axes[3,i].set_title(f'RP {i+1}')
#     axes[3,i].axis('off')
    
#     # Plot SWT (image channel 3)
#     im = axes[4,i].imshow(img[3], cmap='viridis')
#     axes[4,i].set_title(f'SWT {i+1}')
#     axes[4,i].axis('off')
    
    
# plt.tight_layout()
# plt.show()

In [None]:
import logging
import os
from datetime import datetime

import torch

from model import *

# Single source of truth for the run: data generation, model shape and
# optimisation settings are all read from here by the cells below.
# The device falls back to CPU when no CUDA GPU is visible.
config = dict(
    # --- runtime / bookkeeping ---
    device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    output_dir='./outputs/',
    run_name='run_probabilistic_2d',
    # --- synthetic data ---
    n_samples=300_000,
    batch_size=32,
    n_features=4,
    n_timesteps=600,
    image_size=64,
    transform_method='all',
    seed=42,
    # --- model ---
    latent_dim=32,
    n_components=4,
    # --- optimisation ---
    learning_rate=1e-3,
    num_epochs=5,
    beta=0.1,
)

# Create output directory. The path depends only on run_name, so re-running
# with the same run_name reuses (and may overwrite files in) the same folder.
# NOTE(review): `timestamp` is not used below; presumably it was meant to make
# output_dir unique per run — confirm before relying on it.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
output_dir = os.path.join(config["output_dir"], config["run_name"])
os.makedirs(output_dir, exist_ok=True)

# Setup logging (writes into output_dir via the project helper).
log_file = setup_logging(output_dir)
# Record every hyperparameter so each log file is self-describing; the previous
# placeholder referenced an `args` object that does not exist in this notebook.
logging.info(f'Starting training with config: {config}')

# Set device
logging.info(f'Using device: {config["device"]}')

logging.info('Generating synthetic data...')

# Sample counts for the two loaders (80/20). They are built with different
# seeds rather than by splitting one shared dataset.
train_size = int(0.8 * config["n_samples"])
val_size = config["n_samples"] - train_size


def _make_loader(n_samples, seed, num_workers, shuffle):
    """Build a synthetic-data loader.

    Every data setting comes from ``config``; only the sample count, seed,
    worker count and shuffling differ between the train and val loaders.
    """
    return get_dataloader(
        n_samples=n_samples,
        batch_size=config["batch_size"],
        n_features=config["n_features"],
        n_timesteps=config["n_timesteps"],
        image_size=config["image_size"],
        transform_method=config["transform_method"],
        seed=seed,
        num_workers=num_workers,
        shuffle=shuffle,
    )


train_loader = _make_loader(train_size, config["seed"], num_workers=12, shuffle=True)
val_loader = _make_loader(val_size, config["seed"] + 1, num_workers=4, shuffle=False)

logging.info(f'Generated {train_size} training samples and {val_size} validation samples')


# Build the 2D (image-input) VAE.
# NOTE(review): input_channels=4 presumably equals the number of image
# transforms stacked per sample when transform_method='all' (the plotting cell
# indexes img[0]..img[3]) — confirm if the transform set changes.
model = VAE(
    latent_dim=config["latent_dim"],
    input_channels=4,
    image_size=config["image_size"],
    signal_length=config["n_timesteps"],
    n_components=config["n_components"],
)
# Alternative 1D baseline, not used in this run:
# model = VAE1D(latent_dim=..., signal_length=..., n_components=...)

# Move the model to the target device *before* constructing the optimizer, as
# the torch.optim documentation recommends. If train_model also calls .to(),
# this is a harmless no-op.
model.to(config["device"])
optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])

# Train model; checkpoints/logs go to output_dir.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    num_epochs=config["num_epochs"],
    device=config["device"],
    log_dir=output_dir,
    beta=config["beta"],
)

logging.info('Training completed')


In [None]:
import os 
# Import required libraries
import numpy as np
import matplotlib.pyplot as plt
from synthetic_gen import get_dataloader
# Ignore pyts warnings about quantiles
import warnings
warnings.filterwarnings('ignore', category=UserWarning, module='pyts')


# Small batch for visually comparing targets with model outputs.
# n_features is passed explicitly, as in the training loaders, so the batch
# does not depend on get_dataloader's defaults. An explicit seed (same offset
# as the validation loader) makes the figure reproducible across runs.
dataloader = get_dataloader(
    n_samples=10,
    batch_size=5,
    n_features=config["n_features"],
    n_timesteps=config["n_timesteps"],
    image_size=config["image_size"],
    transform_method=config["transform_method"],
    seed=config["seed"] + 1,
)

# Get a batch of data
batch = next(iter(dataloader))
signals = batch['signal']
images = batch['image']

In [None]:
import torch 
import os 
from model import *

# Rebuild the architecture that produced the checkpoint, then load its weights.
# Previously a VAE was built and then immediately replaced by a VAE1D. The
# checkpoint from the 2D run (run_name 'run_probabilistic_2d') was therefore
# loaded into a model whose parameters presumably do not match it. The 2D VAE
# trained earlier in this notebook is the default; set
# config["model_type"] = "1d" to load a VAE1D checkpoint instead.
if config.get("model_type", "2d") == "1d":
    model = VAE1D(
        latent_dim=config["latent_dim"],
        signal_length=config["n_timesteps"],
        n_components=config["n_components"],
    )
else:
    model = VAE(
        latent_dim=config["latent_dim"],
        input_channels=4,
        image_size=config["image_size"],
        signal_length=config["n_timesteps"],
        n_components=config["n_components"],
    )

output_dir = os.path.join(config['output_dir'], config['run_name'])
checkpoint_path = os.path.join(output_dir, "best_model.pth")
logging.info(f"Loading model from {checkpoint_path}")

# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
state_dict = torch.load(checkpoint_path, map_location=config["device"])
model.load_state_dict(state_dict)
model.to(config["device"])
model.eval()  # switch to inference behaviour for any dropout/batch-norm layers
logging.info("Model loaded successfully")

In [None]:
# Generate predicted signals for the visualization batch.
# return_mean=False presumably draws samples rather than the mean, so seed the
# RNG for a reproducible figure. Autograd is disabled because this is
# inference only.
torch.manual_seed(config["seed"])
with torch.no_grad():
    signals = model.predict_signals(batch, return_mean=False)

if signals.ndim == 3:
    # Keep index 0 along dim 1 so each sample is a single 1D trace for plotting.
    # NOTE(review): assumes layout (batch, channel_or_draw, timesteps) — confirm.
    signals = signals[:, 0, :]

In [None]:
# Overlay each target signal with the model's prediction, one panel per sample.
n_panels = len(signals)
# squeeze=False keeps `axs` 2D, so indexing still works when n_panels == 1.
fig, axs = plt.subplots(1, n_panels, figsize=(20, 4), squeeze=False)
for i, ax in enumerate(axs[0]):
    # .cpu() is required before .numpy() when the tensors live on a CUDA device.
    ax.plot(batch['signal'][i].detach().cpu().numpy(), label='Target', alpha=0.8)
    ax.plot(signals[i].detach().cpu().numpy(), label='Predicted', alpha=0.8)
    ax.set_title(f'Sample {i + 1}')
    ax.set_xlabel('Time')
    ax.set_ylabel('Amplitude')
    ax.legend()
plt.tight_layout()
plt.show()

In [None]:
# mixture_weights, locations, scales, dofs, mu, log_var, z = model(batch['image'].to(device))

# mixture = model.create_mixture_distribution(mixture_weights, locations, scales, dofs)
        
# probs = mixture.log_prob(batch['signal'].to(device))

# probs.shape

In [None]:
# mixture

In [None]:
# batch['signal'].shape