In [None]:
import torch
from torch.utils.data import DataLoader
import numpy as np

import src.training as training
import src.models as models

# Fix the random seeds up front so the run is reproducible
seed = 14
torch.manual_seed(seed)
np.random.seed(seed)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

### Load training data

- The emulator must be finetuned on a small sample of spectra, $\mathcal{O}(100)$, drawn from the target $N(z)$.

In [None]:
# Spectra file and sample size used to finetune the MAML model
filepath = '../cl_ee_mcmc_dndz_nsamples=30000.h5'
n_finetune = 100  # number of training samples used for finetuning

# Only n_train is set explicitly; n_val and n_test are passed as None
train_data, test_data, ScalerY, ScalerX = training.load_train_test_val(
    filepath=filepath,
    n_train=n_finetune,
    n_val=None,
    n_test=None,
    seed=seed,
    device=device,
)

# Slicing the whole training set yields its (inputs, targets) pair
X_train, y_train = train_data[:]

### Next we need to load the model and apply the trained weights

- The parameters below are as configured for the paper. They have not been extensively optimised, so feel free to experiment with them.

In [None]:
# Input/output widths are taken from axis 1 of the finetuning tensors
in_size, out_size = X_train.shape[1], y_train.shape[1]

# Instantiate the network architecture
model = models.FastWeightCNN(
    input_size=in_size, output_size=out_size,
    latent_dim=(16, 16),
    dropout_rate=0.2,
)

# Wrap the network in a MetaLearner, which drives training and finetuning
# NOTE(review): loss_fn is given the MSELoss class, not an instance —
# presumably MetaLearner instantiates it; confirm against src.training.
metalearner = training.MetaLearner(
    model=model,
    outer_lr=0.01, inner_lr=0.001,          # training parameters
    loss_fn=torch.nn.MSELoss,
    beta1=0.9, beta2=0.999, epsilon=1e-8,   # Adam parameters
    seed=seed,                              # random seed for reproducibility
    device=device,
)

# Load and apply the trained meta-weights.
# A forward slash is used as the separator: it resolves on every platform
# (Windows included), whereas the previous backslash only worked on Windows
# and formed the invalid string escape sequence "\W".
weight_path = 'weights/WEIGHTS_5batch_500samples_20tasks_14seed.pt'
# torch.load unpickles the file, so only point this at weight files you trust.
# map_location places the loaded tensors on the device selected earlier.
metalearner.model.load_state_dict(
    torch.load(weight_path, map_location=device)
)

### Now we finetune the model

- Task-specific weights are stored separately and passed in at runtime, so the meta-weights are never overwritten.

In [None]:
# Number of adaptation steps handed to finetune()
finetune_epochs = 64

# finetune() returns a pair; only the first element (the task-specific
# weights) is kept, the second is discarded.
fast_weights, _ = metalearner.finetune(
    X_train, y_train,
    adapt_steps=finetune_epochs,
    use_new_adam=True,  # begin finetuning from a fresh Adam optimizer state
)

### Finally, we can test the trained model on the rest of the sample spectra

In [None]:
# Send the test data to a torch dataloader
# Can modify batch size as needed depending on system resources
test_loader = DataLoader(test_data, batch_size=5000, shuffle=False)

# Evaluate the finetuned model on the test data. eval() is called once before
# the loop; the task-specific fast weights are passed in at call time.
metalearner.model.eval()
pred_batches = []
with torch.no_grad():  # don't compute gradients during inference
    for X_batch, _ in test_loader:
        pred_batches.append(metalearner.model(X_batch, params=fast_weights))

# Concatenate once at the end instead of re-copying a growing tensor per batch
if pred_batches:
    y_pred = torch.cat(pred_batches, dim=0)
else:
    # Empty loader: keep the empty tensor the loop used to start from
    y_pred = torch.tensor([]).to(device)

print('Total predictions:', y_pred.shape)
# Inverse transform the data
y_pred = ScalerY.inverse_transform(y_pred)
y_pred_np = y_pred.cpu().numpy()

# Get the test targets by slicing the whole dataset, exactly as is done for the
# training set. Integer indexing (test_data[1]) selects a single sample rather
# than the targets.
X_test, y_test = test_data[:]
# NOTE(review): unlike y_pred, y_test is not passed through
# ScalerY.inverse_transform before exponentiating — confirm that test_data
# stores unscaled targets, otherwise the two arrays are not comparable.
y_test_np = y_test.cpu().numpy()

# Exponentiate the data
y_pred_np = np.exp(y_pred_np)
y_test_np = np.exp(y_test_np)