In [1]:
import torch
import lightning as L
from dataset import AEDataModule
from model import AutoEncoder, VariationalAutoEncoder
from callbacks import SaveBest, SaveEveryNEpochs, BetaWarmUp
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger
from views import GAIN_PIDS as gain_list

import numpy as np

import seaborn as sns
import matplotlib.pyplot as plt

# Configuration: data paths, model architecture, and training hyperparameters

In [2]:
# --- Data ---
data_dir = "../data"
file_name = "20251028_all_data.npz"
nprofiles = None  # forwarded to AEDataModule; None presumably means "all profiles" — confirm in dataset.py
normalization_strategy = 'minmax'  # Options: 'minmax', 'zscore', 'robust', 'none'

# --- Model ---
model_kind = 'AE'  # 'AE' or 'VAE'
input_dim = 360
geometry = [64, 32, 32]  # encoder layer widths; the last entry is the latent size
activation = torch.nn.ReLU()
beta = 1e-7   # only used when model_kind == 'VAE'
gamma = 1e-4  # only used when model_kind == 'VAE'

# --- Training ---
batch_size = 32
max_epochs = 100
learning_rate = 3e-4
device = 'cpu'  # Lightning accelerator; use 'gpu' or 'auto' to train on a detected GPU

# Seed Python, NumPy and torch (and dataloader workers) so training runs are reproducible.
seed = 42
L.seed_everything(seed, workers=True)

# Training function

In [3]:
def train_autoencoder(data_dir, file_name, input_dim=360, geometry=None,
                      beta=1, gamma=1, batch_size=32, max_epochs=100,
                      normalization_strategy='minmax', learning_rate=3e-4,
                      activation=None, nprofiles=None, model_kind='AE',
                      device=None, patience=10, save_every_n_epochs=10):
    """Train an AutoEncoder or VariationalAutoEncoder and evaluate it on val/test.

    Parameters
    ----------
    data_dir, file_name : str
        Location of the data file handed to ``AEDataModule``.
    input_dim : int
        Input size of the model (also used in the logger name).
    geometry : sequence of int, optional
        Encoder layer widths (last entry is the latent size). Defaults to
        ``[64, 32, 16, 8]``. A copy is passed to the model.
    beta, gamma : float
        Loss weights forwarded to ``VariationalAutoEncoder``; ignored for 'AE'.
    batch_size, max_epochs, learning_rate, normalization_strategy, nprofiles :
        Forwarded to the data module / model / trainer.
    activation : torch.nn.Module instance or class, optional
        Activation handed to the model. A class is instantiated; ``None``
        gives a fresh ``torch.nn.ReLU()``.
    model_kind : {'AE', 'VAE'}
        Which model to build.
    device : str, optional
        Lightning ``accelerator``. ``None`` falls back to the notebook-level
        ``device`` variable (or ``'cpu'`` if it is not defined).
    patience : int
        ``EarlyStopping`` patience on ``val/loss``.
    save_every_n_epochs : int
        Period passed to ``SaveEveryNEpochs``.

    Returns
    -------
    tuple
        ``(model, data_module)`` — the trained model and the data module used.

    Raises
    ------
    ValueError
        If ``model_kind`` is not 'AE' or 'VAE' (checked before any data loads).
    """
    # Fail fast on a typo; otherwise `model` would be unbound further down.
    if model_kind not in ('AE', 'VAE'):
        raise ValueError(f"model_kind must be 'AE' or 'VAE', got {model_kind!r}")

    # Avoid a shared mutable default and never hand the caller's list to the model.
    geometry = [64, 32, 16, 8] if geometry is None else list(geometry)

    # The model stores the activation as a submodule, so it needs an instance.
    if activation is None:
        activation = torch.nn.ReLU()
    elif isinstance(activation, type):
        activation = activation()

    # Backward compatibility: this function used to read the global `device`.
    if device is None:
        device = globals().get('device', 'cpu')

    # Data: load, split, then drop the PIDs listed in GAIN_PIDS from training.
    data_module = AEDataModule(data_dir, file_name, batch_size, normalization_strategy, nprofiles=nprofiles)
    data_module.prepare_data()
    data_module.setup()
    data_module.exclude_pids(gain_list)

    if model_kind == 'AE':
        model = AutoEncoder(input_dim=input_dim, geometry=geometry,
                            learning_rate=learning_rate, activation=activation)
    else:
        model = VariationalAutoEncoder(input_dim=input_dim, geometry=geometry,
                                       beta=beta, gamma=gamma,
                                       learning_rate=learning_rate, activation=activation)
    logger_name = f"{model_kind}{input_dim}"

    logger = TensorBoardLogger("W7-X_QXT", name=logger_name)

    trainer = L.Trainer(
        logger=logger,
        max_epochs=max_epochs,
        accelerator=device,
        callbacks=[
            SaveBest(monitor="val/loss", logger=logger),
            SaveEveryNEpochs(save_every_n_epochs, logger=logger),
            EarlyStopping(monitor="val/loss", patience=patience, mode="min"),
            # BetaWarmUp(start_epoch=50, initial_beta=0, final_beta=0.1, warmup_epochs=100),
        ],
        devices=1,
    )

    trainer.fit(model, data_module)

    # Note: these evaluate the final weights, not the checkpoint saved by SaveBest.
    if data_module.val_data is not None:
        trainer.validate(model, datamodule=data_module)
    if data_module.test_data is not None:
        trainer.test(model, datamodule=data_module)

    return model, data_module

In [None]:
# train_autoencoder returns (model, data_module). Unpack both so that `model` is
# the trained network rather than a tuple, because the next cell calls model.eval().
model, data_module = train_autoencoder(
    data_dir, file_name,
    input_dim=input_dim, geometry=geometry, beta=beta, gamma=gamma,
    batch_size=batch_size, max_epochs=max_epochs,
    normalization_strategy=normalization_strategy, learning_rate=learning_rate,
    activation=activation, nprofiles=nprofiles, model_kind=model_kind,
)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


The file db_excluded_pid.npz already exists in path ../data, loading it...
Data loaded successfully.


/home/IPP-HGW/orluca/.venv/ptl/lib/python3.11/site-packages/lightning/pytorch/trainer/setup.py:177: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.

  | Name       | Type       | Params | Mode 
--------------------------------------------------
0 | activation | ReLU       | 0      | train
1 | encoder    | Sequential | 26.2 K | train
2 | decoder    | Sequential | 26.6 K | train
--------------------------------------------------
52.8 K    Trainable params
0         Non-trainable params
52.8 K    Total params
0.211     Total estimated model params size (MB)
9         Modules in train mode
0         Modules in eval mode


Data is already normalized, returning original data


Sanity Checking: |                                                      | 0/? [00:00<?, ?it/s]

/home/IPP-HGW/orluca/.venv/ptl/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
/home/IPP-HGW/orluca/.venv/ptl/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:433: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Training: |                                                             | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

Validation: |                                                           | 0/? [00:00<?, ?it/s]

# Evaluate the model on the test set

In [None]:
# eval() switches layers such as dropout/batch-norm (if any) to inference behaviour.
# Gradient tracking is disabled separately by torch.no_grad() below.
model.eval()
model_device = next(model.parameters()).device

# Fresh data module for evaluation.
# NOTE(review): get_pids(gain_list) presumably restricts the data to the PIDs that
# were excluded from training — confirm in dataset.py.
data_module = AEDataModule(data_dir, file_name, batch_size, normalization_strategy, nprofiles=nprofiles)
data_module.prepare_data()
data_module.setup()
data_module.get_pids(gain_list)

max_plots = 32  # number of test profiles to plot; set to None to plot all of them

# Accumulate every batch; previously only the last batch survived the loop.
measured, reconstructed, pids, times = [], [], [], []
with torch.no_grad():
    for batch in data_module.test_dataloader():
        x = batch['profile'].to(model_device)
        # assumes model(x) returns the reconstruction tensor (true for the AE) — confirm for VAE
        reconstructed.append(model(x).cpu())
        measured.append(x.cpu())
        pids.extend(batch['pid'])
        times.extend(batch['time'])

if not measured:
    raise RuntimeError("The test dataloader yielded no batches; nothing to evaluate.")

x_test = torch.cat(measured)
y_test = torch.cat(reconstructed)

n_plots = len(x_test) if max_plots is None else min(max_plots, len(x_test))
for i in range(n_plots):
    pid_i = pids[i].item() if torch.is_tensor(pids[i]) else pids[i]
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(x_test[i].numpy(), ls='-.', marker='o', color='blue', label='meas.')
    ax.plot(y_test[i].numpy(), ls='--', marker='>', color='orange', label='reco.')
    ax.set_xlabel("Diode number")
    ax.set_ylabel("Brilliance [a.u.]")
    ax.set_title(f"PID {pid_i} @ t = {float(times[i]):.3f} s")
    ax.legend()
    plt.show()
    plt.close(fig)  # avoid accumulating open figures in the loop

        

