In [1]:
import cheetah
import matplotlib.pyplot as plt
import numpy as np
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import wandb

### Generate dataset

This time we create a dataset in the form of a PyTorch `Dataset` object. This is always good practice to keep your code tidy, but more importantly it will help you when your data inevitably scales and, for example, no longer fits into memory. Similar classes also exist in TensorFlow and Keras, they are not exclusive to PyTorch.

The data generation is otherwise the same as in the TensorFlow example:
 - Inputs are beam parameters of the incoming beam, k1 and the quadrupole's length.
 - Outputs are the beam parameters of the outgoing beam.
 
 We start by defining what our dataset is supposed to look like and how it is generated.

In [2]:
class QuadrupoleDataset(Dataset):
    
    def __init__(self, n=5000, normalize=False, x_scaler=None, y_scaler=None):
        self.n = n
        self.normalize = normalize
        self.x_scaler = x_scaler
        self.y_scaler = y_scaler
        
        self.x, self.y = self.generate_samples(self.n)
                
        if normalize:
            if not self.x_scaler:
                self.x_scaler = StandardScaler().fit(self.x)
            self.x = self.x_scaler.transform(self.x)
            
            if not self.y_scaler:
                self.y_scaler = StandardScaler().fit(self.y)
            self.y = self.y_scaler.transform(self.y)
        
        self.x = self.x.astype(np.float32)
        self.y = self.y.astype(np.float32)
                
        
    def __len__(self):
        return len(self.y)
    
    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]
    
    def generate_samples(self, n):
        parameter_keys = ["energy", "mu_x", "mu_xp", "mu_y", "mu_yp", "sigma_p", "sigma_s",
                          "sigma_x", "sigma_xp", "sigma_y", "sigma_yp"]
        parameters = [{
            "mu_x": np.random.uniform(-1e-3, 1e-3),
            "mu_y": np.random.uniform(-1e-3, 1e-3),
            "mu_xp": np.random.uniform(-1e-4, 1e-4),
            "mu_yp": np.random.uniform(-1e-4, 1e-4),
            "sigma_x": np.random.uniform(1e-5, 5e-4),
            "sigma_y": np.random.uniform(1e-5, 5e-4),
            "sigma_xp": np.random.uniform(1e-6, 5e-5),
            "sigma_yp": np.random.uniform(1e-6, 5e-5),
            "sigma_s": np.random.uniform(1e-6, 5e-5),
            "sigma_p": np.random.uniform(1e-4, 1e-3),
            "energy": np.random.uniform(80e6, 160e6)
        } for _ in range(n)]
        
        beams = [cheetah.ParameterBeam.from_parameters(**p) for p in parameters]

        x1 = np.array([[b.parameters[k] for k in parameter_keys] for b in beams])
        x0 = np.array([[np.random.uniform(0.1, 0.3), np.random.uniform(-15.0, 15.0)] for _ in range(n)])
        x = np.hstack([x0, x1])
        
        y = []
        for incoming, setting in zip(beams, x0):
            quadrupole = cheetah.Quadrupole(length=setting[0], k1=setting[1])
            outgoing = quadrupole(incoming)
            y.append(outgoing)
        y = np.array([[b.parameters[k] for k in parameter_keys] for b in y])
                
        return x, y

### Defining the Model

Below we define our neural network model. This is where you will really start to see PyTorch-Lighting in action. We start in vanilla PyTorch, though, by first defining the actual model in `MLPRegressor` and then the "ML System" in `QuadrupoleSurrogate`. This is a best practice the reasons for which are more apparent when training more complex architectures, such as Autoencoders and GANs. If you want to know more about this, I recommend reading the *Style Guide* page in the PyTorch-Lightning documentation.

**Note:** In this simple example where the model is just a simple MLP, it would be much quicker to use a PyTorch `Sequential` model rather that define the `MLP Regressor` class. We are going the slightly longer route because I thought it a good opportunity to illustrate how easy PyTorch makes it to play with the lower-level details of a model, which helps when you want to implement highly customised models.

In [3]:
class MLPRegressor(nn.Module):
    
    def __init__(self, width=32):
        super().__init__()
        self.layer_1 = nn.Linear(13, width)
        self.layer_2 = nn.Linear(width, width)
        self.output_layer = nn.Linear(width, 11)
    
    def forward(self, x):
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.output_layer(x)

        return x

Now we define the "ML System" using PyTorch-Lighting. Were it not for this awesome package, you would have to write the entire trainig loop yourself. With it you only need to define what one training step looks like in `training_step`.

In [4]:
class QuadrupoleSurrogate(pl.LightningModule):
    
    def __init__(self, width=32):
        super().__init__()
        self.regressor = MLPRegressor(width=width)
    
    def forward(self, x):
        y = self.regressor(x)        
        return y
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.regressor(x)
        loss = F.mse_loss(y_hat, y)
        
        self.log("train_loss", loss)
        
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.regressor(x)
        loss = F.mse_loss(y_hat, y)
        
        self.log("val_loss", loss)
        
        return loss
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

### Train

Now we can go and train.

We start by **loading data**. All we do is instatiate two of the datasets define above. Keep in mind that data should be normalised and that the normalisation should be based only on the training data!
`DataLoaders` are a further abstraction provided by PyTorch. They take care of shuffling and minibatching the data for you during training.

Then we **setup the training** by instantiating our model, a logger to W&B and the PyTorch-Lightning `Trainer` that handles the training loop for us and we are good to go!

PyTorch-Lightning will take care of saving our model every so often, when it improved or when we interupt the training. It's like magic! 🪄

In [22]:
training_data = QuadrupoleDataset(n=4000, normalize=True)
validation_data = QuadrupoleDataset(n=2000, normalize=True, x_scaler=training_data.x_scaler, y_scaler=training_data.y_scaler)
train_loader = DataLoader(training_data, batch_size=64, shuffle=True)
val_loader = DataLoader(validation_data, batch_size=64, shuffle=True)

surrogate = QuadrupoleSurrogate(width=32)

wandb_logger = WandbLogger(project="quadrupole-surrogate-pytorch", entity="msk-ipc")
trainer = pl.Trainer(logger=wandb_logger)

trainer.fit(surrogate, train_loader, val_loader)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmsk-ipc[0m (use `wandb login --relogin` to force relogin)


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name      | Type         | Params
-------------------------------------------
0 | regressor | MLPRegressor | 1.9 K 
-------------------------------------------
1.9 K     Trainable params
0         Non-trainable params
1.9 K     Total params
0.007     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]


Aborted!


### Try the Model

Now we can load a checkpoint of our trained model and try it. Luckily, PyTorch-Lightning makes this very easy.

In [12]:
loaded_surrogate = QuadrupoleSurrogate.load_from_checkpoint("quadrupole-surrogate-pytorch/1vc56z4w/checkpoints/epoch=271-step=17136.ckpt")
loaded_surrogate.eval()

QuadrupoleSurrogate(
  (regressor): MLPRegressor(
    (layer_1): Linear(in_features=13, out_features=32, bias=True)
    (layer_2): Linear(in_features=32, out_features=32, bias=True)
    (output_layer): Linear(in_features=32, out_features=11, bias=True)
  )
)

Next we create a new random sample input just to see if we can get a prediction that makes sense.

In [13]:
parameter_keys = ["energy", "mu_x", "mu_xp", "mu_y", "mu_yp", "sigma_p", "sigma_s",
                          "sigma_x", "sigma_xp", "sigma_y", "sigma_yp"]
parameters = {
    "mu_x": np.random.uniform(-1e-3, 1e-3),
    "mu_y": np.random.uniform(-1e-3, 1e-3),
    "mu_xp": np.random.uniform(-1e-4, 1e-4),
    "mu_yp": np.random.uniform(-1e-4, 1e-4),
    "sigma_x": np.random.uniform(1e-5, 5e-4),
    "sigma_y": np.random.uniform(1e-5, 5e-4),
    "sigma_xp": np.random.uniform(1e-6, 5e-5),
    "sigma_yp": np.random.uniform(1e-6, 5e-5),
    "sigma_s": np.random.uniform(1e-6, 5e-5),
    "sigma_p": np.random.uniform(1e-4, 1e-3),
    "energy": np.random.uniform(80e6, 160e6)
}
l = 0.2
k1 = 13.1

Here is the ground truth outgoing beam computed by Cheetah.

In [14]:
incoming = cheetah.ParameterBeam.from_parameters(**parameters)
outgoing = cheetah.Quadrupole(length=l, k1=k1)(incoming)
outgoing

ParameterBeam(mu_x=-0.000476, mu_xp=0.001560, mu_y=-0.000880, mu_yp=-0.002034, sigma_x=0.000218, sigma_xp=0.000697, sigma_y=0.000050, sigma_yp=0.000125, sigma_s=0.000011, sigma_p=0.000926, energy=123855842.779)

Now we do the prediction with our model. Don't forget to normalise the input and unnormalise the output! The results below should look pretty close to the true values. Keep in mind that this model has not been tuned yet.

**Note:** The code below could be a low nicer, but that's for another time.

In [15]:
X_try = np.array([[l, k1] + [parameters[k] for k in parameter_keys]])
X_try_scaled = training_data.x_scaler.transform(X_try)
X_try_scaled = torch.tensor(X_try_scaled, dtype=torch.float)

with torch.no_grad():
    y_try_scaled = loaded_surrogate(X_try_scaled)

y_try = training_data.y_scaler.inverse_transform(y_try_scaled)
outgoing = y_try.squeeze()

for k, v in zip(parameter_keys, outgoing):
    print(f"{k} = {v}")

energy = 124255287.58704847
mu_x = -0.0004363412847541228
mu_xp = 0.0016406806637992383
mu_y = -0.0009006608697546891
mu_yp = -0.0019867829847205914
sigma_p = 0.0009292942036982102
sigma_s = 1.1579115883921944e-05
sigma_x = 0.00021396877212374012
sigma_xp = 0.0006996545417383908
sigma_y = 5.687342307038288e-05
sigma_yp = 0.0001057605214651619
