## Example script to run a learnable wavelet scattering network on a subset of the CAMELs data

Simple example script. **Make sure you have run the `wget` command in the README to download the example data file**

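We take a small sample of the CAMELs dataset (1k $M_\mathrm{tot}$ maps), and train a "SN" (2 layers of wavelet convolutions, 8 wavelet filters each, and pass the output to a CNN). **Note:** as configured below, this notebook instead loads `Maps_Mtot_IllustrisTNG_LH_z=0.00.npy` with `splits = 15` and trains a MaxViT model from `timm` (`arch = "maxvit"`); switch to the commented-out `../test_data/` paths and `splits = 1` to use the 1k-map example file.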

In [2]:
# Weights & Biases (wandb) is used for experiment tracking throughout this notebook.
# Uncomment the next line to install it if it is missing from the environment.
#!pip install wandb -qqq
import wandb
wandb.login()  # last expression of the cell, so its return value is shown as the cell output

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmichaelhealy[0m ([33mwaveletcapstone[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
import time, sys, os
import numpy as np
import torch
import torch.backends.cudnn as cudnn

import matplotlib.pyplot as plt

## Learnable scattering modules
from learnable_wavelets.models import models_factory, sn_hybrid_models, camels_models
from learnable_wavelets.datasets import camels_dataset 

## MaxViT
import timm

from itertools import product
from mpl_toolkits import mplot3d
%matplotlib inline

In [4]:
## Training hyperparameters (used by the cells further down)
epochs      = 80     ## number of passes over the training set
batch_size  = 32     ## maps per mini-batch
num_workers = 10     ## number of workers to load data
lr          = 0.005  ## learning rate for neural weights
lr_sn       = 0.01   ## learning rate for scattering layers
wd          = 0.18   ## weight decay

In [5]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
use_cuda = torch.cuda.is_available()
if use_cuda:
    print("CUDA Available")
    device = torch.device('cuda')
else:
    print('CUDA Not Available')
    device = torch.device('cpu')
cudnn.benchmark = True      #May train faster but cost more memory

# Beta coefficients handed to the AdamW optimiser as betas=(beta1, beta2)
beta1 = 0.5
beta2 = 0.999

## Location of the CAMELS files. Set the CAMELS_DATA_DIR environment variable to point
## somewhere else (e.g. "../test_data" for the example file downloaded via the README);
## the default keeps the path this notebook was originally run with.
data_dir   = os.environ.get("CAMELS_DATA_DIR", "/scratch/cp3759/camels_data")
fparams    = os.path.join(data_dir, "params_IllustrisTNG.txt")   ## Simulation parameter file
#fparams    = "../test_data/params_IllustrisTNG.txt" ## Simulation parameter file
fmaps      = [os.path.join(data_dir, "Maps_Mtot_IllustrisTNG_LH_z=0.00.npy")]   ## Simulated maps, must be a list as we can take multiple maps
#fmaps      = ["../test_data/maps_Mtot_1k.npy"]   ## Simulated maps, must be a list as we can take multiple maps

fmaps_norm      = [None]
splits          = 15      ## Number of maps taken from each sim (range between 1 and 15, must match the dataset size)
                         ## i.e. splits=1 for 1k maps, splits=15 for 15k maps, or any integer value in between
seed            = 123    ## seed for the test/valid/train split
monopole        = True   ## Keep the monopole of the maps (True) or remove it (False)
rot_flip_in_mem = False  ## Whether rotations and flippings are kept in memory (faster but takes more memory if true)
smoothing       = 0      ## Smooth the maps with a Gaussian filter? 0 for no
arch            = "maxvit"   ## Which model architecture to use
features        = 4    ## Number of variables to train the model on. This can be 2, 4, 6 or 12, depending on whether
                         ## you want to a) train on both cosmological and IGM parameters and b) also ask the
                         ## network to estimate uncertainties on these parameters

CUDA Available


In [6]:
channels=len(fmaps)   ## one input channel per map file

## Set up indices to use for the loss function:
##   g -> columns of the targets / network output that hold the parameter values (means)
##   h -> columns of the network output that hold the predicted errors
##        (None when the network only predicts the means)
if features==2:
    g, h = [0,1], None
elif features==4:
    g, h = [0,1], [2,3]
elif features==6:
    g, h = [0,1,2,3,4,5], None
elif features==12:
    g, h = [0,1,2,3,4,5], [6,7,8,9,10,11]
else:
    # Fail here instead of leaving g/h undefined (or stale from an earlier run of this cell)
    raise ValueError(f"features must be 2, 4, 6 or 12, got {features!r}")

In [7]:
## Hyperparameters and run settings stored alongside the W&B run
config = {
    "learning rate": lr,
    "scattering learning rate": lr_sn,
    "wd": wd,
    "channels": channels,
    "epochs": epochs,
    "batch size": batch_size,
    "network": arch,
    "features": features,
    "splits": splits,
}

## Open the W&B run that the rest of the notebook logs to
wandb.init(
    project="Capstone Maxvit Base 15k With Transform",
    entity="michaelhealy",
    config=config,
)

[34m[1mwandb[0m: Currently logged in as: [33mmichaelhealy[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
## Generate torch datasets
## The three loaders are built with the same arguments and differ only in the
## split name passed as the first argument.
loader_options = dict(num_workers=num_workers, rot_flip_in_mem=rot_flip_in_mem, verbose=True)

# get training set
print('\nPreparing training set')
train_loader = camels_dataset.create_dataset_multifield(
    'train', seed, fmaps, fparams, batch_size, splits, fmaps_norm, **loader_options)

# get validation set
print('\nPreparing validation set')
valid_loader = camels_dataset.create_dataset_multifield(
    'valid', seed, fmaps, fparams, batch_size, splits, fmaps_norm, **loader_options)

# get test set
print('\nPreparing test set')
test_loader = camels_dataset.create_dataset_multifield(
    'test', seed, fmaps, fparams, batch_size, splits, fmaps_norm, **loader_options)


Preparing training set
Found 1 channels
Reading data...
4.598e+09 < F(all|orig) < 3.186e+15
9.663 < F(all|resc)  < 15.503
-2.931 < F(all|norm) < 8.946





Preparing validation set
Found 1 channels
Reading data...
4.598e+09 < F(all|orig) < 3.186e+15
9.663 < F(all|resc)  < 15.503
-2.931 < F(all|norm) < 8.946

Preparing test set
Found 1 channels
Reading data...
4.598e+09 < F(all|orig) < 3.186e+15
9.663 < F(all|resc)  < 15.503
-2.931 < F(all|norm) < 8.946


In [9]:
## Record the size of the training set and the map files used in the W&B config
num_train_maps = len(train_loader.dataset.x)
dataset_info = {"no. training maps": num_train_maps, "fields": fmaps}
wandb.config.update(dataset_info)

In [10]:
## Create the model: one input channel per map file, one output per feature
model_size = 'Base'   ## label used in the printed message and plot titles
data_size = '15k'     ## label used in plot titles
model = timm.models.maxvit_base_224(in_chans=channels,num_classes=features)
## Alternative backbones:
#model = timm.models.maxvit_xlarge_224(in_chans=channels,num_classes=features)
#model = timm.models.maxxvit_rmlp_tiny_rw_256(in_chans=channels,num_classes=features)

## Total number of elements over all model parameters, logged to W&B
n_parameters = 0
for weight in model.parameters():
    n_parameters += weight.numel()
wandb.config.update({"learnable_parameters": n_parameters})

model = model.to(device) ## Put model on the appropriate device
print(f"MaxViT {model_size} model created")

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


MaxViT Base model created


The model cell above counts the model's parameters and records the total in the W&B config as `learnable_parameters`.

Regular Model Loop Below

In [11]:
## Let W&B watch the model
wandb.watch(model, log_freq=1)

## AdamW over all model parameters
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=lr,
    betas=(beta1, beta2),
    weight_decay=wd,
)
## Learning-rate schedule driven by the validation loss passed to scheduler.step()
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.3, patience=10
)

In [44]:
## Train and valid loops
##
## Per-parameter losses, each averaged over the batch:
##   loss1 = (predicted mean - truth)**2
##   loss2 = ((predicted mean - truth)**2 - predicted error**2)**2
## The optimised quantity is mean(log(loss1) + log(loss2)); loss2 is only used when
## the network also predicts errors (features == 4 or 12).
## The per-epoch losses are appended to train_losses / val_losses and logged to W&B.
start = time.time()
train_losses = []
val_losses = []
for epoch in range(epochs):
    log_dic={}

    # train
    train_loss1 = torch.zeros(len(g), device=device)
    train_loss2 = torch.zeros(len(g), device=device)
    points = 0
    model.train()
    for x, y in train_loader:
        x    = x[:,:,16:-16,16:-16] #trim 16 pixels from every edge of each map
        bs   = x.shape[0]         #batch size
        x    = x.to(device)       #maps
        y    = y.to(device)[:,g]  #parameters
        p    = model(x)           #NN output
        y_NN = p[:,g]             #posterior mean
        loss1 = torch.mean((y_NN - y)**2,                axis=0)
        if features==4 or features==12:
            e_NN = p[:,h]         #posterior std
            loss2 = torch.mean(((y_NN - y)**2 - e_NN**2)**2, axis=0)
            loss  = torch.mean(torch.log(loss1) + torch.log(loss2))
            # detach: the running sums are only reported, so they must not keep
            # every batch's autograd history alive for the whole epoch
            train_loss2 += loss2.detach()*bs
        else:
            loss = torch.mean(torch.log(loss1))
        train_loss1 += loss1.detach()*bs
        points      += bs
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # epoch-level training loss from the batch-size-weighted sums
    train_loss = torch.log(train_loss1/points)
    if features==4 or features==12:
        train_loss+=torch.log(train_loss2/points)
    train_loss = torch.mean(train_loss).item()

    # do validation: cosmo alone & all params
    valid_loss1 = torch.zeros(len(g), device=device)
    valid_loss2 = torch.zeros(len(g), device=device)
    points = 0
    model.eval()
    with torch.no_grad():
        for x, y in valid_loader:
            x    = x[:,:,16:-16,16:-16]
            # evaluate every map together with its mirror image along the last axis;
            # the targets are repeated to match the doubled batch
            x = torch.cat((x,x.flip(3)))
            y = torch.cat((y,y))
            bs    = x.shape[0]         #batch size (after doubling)
            x     = x.to(device)       #maps
            y     = y.to(device)[:,g]  #parameters
            p     = model(x)           #NN output
            y_NN  = p[:,g]             #posterior mean
            loss1 = torch.mean((y_NN - y)**2,                axis=0)
            if features==4 or features==12:
                e_NN  = p[:,h]         #posterior std
                loss2 = torch.mean(((y_NN - y)**2 - e_NN**2)**2, axis=0)
                valid_loss2 += loss2*bs
            valid_loss1 += loss1*bs
            points     += bs

    valid_loss = torch.log(valid_loss1/points)
    if features==4 or features==12:
        valid_loss+=torch.log(valid_loss2/points)
    valid_loss = torch.mean(valid_loss).item()

    # the scheduler monitors the validation loss
    scheduler.step(valid_loss)
    log_dic["training_loss"]=train_loss
    log_dic["valid_loss"]=valid_loss
    wandb.log(log_dic)

    # verbose
    print('%03d %.3e %.3e '%(epoch, train_loss, valid_loss), end='')
    print("")
    train_losses.append(train_loss)
    val_losses.append(valid_loss)

stop = time.time()
print('Time take (h):', "{:.4f}".format((stop-start)/3600.0))

## Model performance metrics on test set
## Allocate the arrays that the test loop fills in: one row per test map and
## one column per parameter index in g.
num_maps = test_loader.dataset.size
results_shape = (num_maps, len(g))
params_true = np.zeros(results_shape, dtype=np.float32)   # true parameter values
params_NN   = np.zeros(results_shape, dtype=np.float32)   # network predictions (means)
errors_NN   = np.zeros(results_shape, dtype=np.float32)   # network-predicted errors



OutOfMemoryError: CUDA out of memory. Tried to allocate 56.00 MiB (GPU 0; 15.78 GiB total capacity; 14.35 GiB already allocated; 10.19 MiB free; 14.62 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [33]:
# Debug check: show the maps mirrored along their last axis.
# NOTE(review): `x` is whatever batch an earlier loop left in the kernel, so this cell
# does not run on a fresh kernel — consider removing it.
x.flip(3)

tensor([[[[-0.3374, -0.3721, -0.4578,  ...,  0.4679,  0.4256,  0.5475],
          [-0.3277, -0.3488, -0.4178,  ...,  0.1402, -0.0104,  0.0495],
          [-0.2833, -0.3978, -0.4694,  ..., -0.0963, -0.1806, -0.2189],
          ...,
          [ 0.2521,  0.3420,  0.4284,  ...,  0.3993,  0.5832,  0.7669],
          [ 0.3780,  0.4490,  0.5075,  ...,  1.1881,  1.4582,  1.8287],
          [ 0.5650,  0.6055,  0.6545,  ...,  2.7351,  3.3737,  2.9904]]],


        [[[ 0.3346,  0.4174,  0.4182,  ...,  0.2565,  0.1676,  0.7016],
          [ 0.1852,  0.3971,  0.5053,  ..., -0.0156,  0.0219,  0.2959],
          [ 0.1229,  0.3446,  0.4948,  ..., -0.0509,  0.0646,  0.0090],
          ...,
          [ 0.5662,  0.5961,  0.6921,  ...,  0.6172,  0.6550,  0.5550],
          [ 0.7911,  0.9035,  0.9196,  ...,  0.6272,  0.7788,  0.6418],
          [ 0.8559,  0.9488,  0.9067,  ...,  0.6228,  0.8599,  0.8876]]],


        [[[-0.5225, -0.6725, -0.7858,  ...,  0.8485,  1.4605,  0.9476],
          [-0.5509, -0.685

In [41]:
# Debug check: the targets stacked twice along the first axis, as done for the doubled validation batch.
# NOTE(review): relies on `y` left over from an earlier loop — consider removing this cell.
torch.cat((y,y))

tensor([[0.9955, 0.9125],
        [0.8825, 0.9225],
        [0.3335, 0.6895],
        [0.9495, 0.8235],
        [0.2495, 0.2235],
        [0.9625, 0.1065],
        [0.4485, 0.8425],
        [0.8185, 0.8905],
        [0.4975, 0.8775],
        [0.3215, 0.0145],
        [0.8675, 0.0385],
        [0.7075, 0.0985],
        [0.7325, 0.6325],
        [0.1755, 0.2465],
        [0.9955, 0.9125],
        [0.8825, 0.9225],
        [0.3335, 0.6895],
        [0.9495, 0.8235],
        [0.2495, 0.2235],
        [0.9625, 0.1065],
        [0.4485, 0.8425],
        [0.8185, 0.8905],
        [0.4975, 0.8775],
        [0.3215, 0.0145],
        [0.8675, 0.0385],
        [0.7075, 0.0985],
        [0.7325, 0.6325],
        [0.1755, 0.2465]], device='cuda:0')

In [43]:
# Debug check: the maps followed by their mirror images, stacked along the first axis.
# NOTE(review): relies on `x` left over from an earlier loop — consider removing this cell.
torch.cat((x,x.flip(3)))

tensor([[[[ 0.5475,  0.4256,  0.4679,  ..., -0.4578, -0.3721, -0.3374],
          [ 0.0495, -0.0104,  0.1402,  ..., -0.4178, -0.3488, -0.3277],
          [-0.2189, -0.1806, -0.0963,  ..., -0.4694, -0.3978, -0.2833],
          ...,
          [ 0.7669,  0.5832,  0.3993,  ...,  0.4284,  0.3420,  0.2521],
          [ 1.8287,  1.4582,  1.1881,  ...,  0.5075,  0.4490,  0.3780],
          [ 2.9904,  3.3737,  2.7351,  ...,  0.6545,  0.6055,  0.5650]]],


        [[[ 0.7016,  0.1676,  0.2565,  ...,  0.4182,  0.4174,  0.3346],
          [ 0.2959,  0.0219, -0.0156,  ...,  0.5053,  0.3971,  0.1852],
          [ 0.0090,  0.0646, -0.0509,  ...,  0.4948,  0.3446,  0.1229],
          ...,
          [ 0.5550,  0.6550,  0.6172,  ...,  0.6921,  0.5961,  0.5662],
          [ 0.6418,  0.7788,  0.6272,  ...,  0.9196,  0.9035,  0.7911],
          [ 0.8876,  0.8599,  0.6228,  ...,  0.9067,  0.9488,  0.8559]]],


        [[[ 0.9476,  1.4605,  0.8485,  ..., -0.7858, -0.6725, -0.5225],
          [ 1.2489,  1.498

In [None]:
## Training / validation loss per epoch
# Plot against the epochs that were actually recorded, so the cell still works
# when training stopped before `epochs` iterations.
epochs_run = range(len(train_losses))
fig_epochs, ax_epochs = plt.subplots()
ax_epochs.plot(epochs_run, train_losses, label='Training Loss', linestyle="-.")
ax_epochs.plot(epochs_run, val_losses, label='Validation Loss', linestyle=":")
ax_epochs.set_title(f'{model_size} Model with {data_size} data: Epochs')
ax_epochs.legend()
ax_epochs.set_xlabel("Epochs")
ax_epochs.set_ylabel('Loss')

## Parameter Looping
# NOTE(review): `train_fin_losses` / `val_fin_losses` are not defined in the cells shown
# here (presumably produced by a learning-rate sweep). The plot is drawn on its own
# figure and only when those results exist, instead of raising a NameError.
if "train_fin_losses" in globals() and "val_fin_losses" in globals():
    fig_lr, ax_lr = plt.subplots()
    ax_lr.plot(lr, train_fin_losses, label='Training Loss', linestyle="-.")
    ax_lr.plot(lr, val_fin_losses, label='Validation Loss', linestyle=":")
    ax_lr.set_title(f'{model_size} Model with {data_size} data: Learning Rate')
    ax_lr.legend()
    ax_lr.set_xlabel("Learning rate")
    ax_lr.set_ylabel('Loss')

In [None]:
# get test loss
## Runs the trained model over the test set, accumulates the same per-parameter
## losses as the training loop and stores truth / prediction / predicted error
## for every test map in params_true / params_NN / errors_NN.
test_loss1 = torch.zeros(len(g), device=device)
test_loss2 = torch.zeros(len(g), device=device)
points = 0
model.eval()
with torch.no_grad():
    for x, y in test_loader:
        x     = x[:,:,16:-16,16:-16] #trim 16 pixels from every edge of each map
        bs    = x.shape[0]         #batch size
        x     = x.to(device)       #send data to device
        y     = y.to(device)[:,g]  #send data to device
        p     = model(x)           #prediction for mean and variance
        y_NN  = p[:,g]             #prediction for mean
        loss1 = torch.mean((y_NN - y)**2,                axis=0)
        if features==4 or features==12:
            e_NN  = p[:,h]         #posterior std
            loss2 = torch.mean(((y_NN - y)**2 - e_NN**2)**2, axis=0)
            test_loss2 += loss2*bs
        test_loss1 += loss1*bs

        # save results to their corresponding arrays
        params_true[points:points+bs] = y.cpu().numpy()
        params_NN[points:points+bs]   = y_NN.cpu().numpy()
        if features==4 or features==12:
            errors_NN[points:points+bs]   = np.abs(e_NN.cpu().numpy())
        points    += bs

# The loss is computed once, after all batches. The error term is only added when
# the network predicts errors; otherwise test_loss2 is all zeros and its log is -inf.
test_loss = torch.log(test_loss1/points)
if features==4 or features==12:
    test_loss += torch.log(test_loss2/points)
test_loss = torch.mean(test_loss).item()
print('Test loss = %.3e\n'%test_loss)

# de-normalize
## Hard-coded lower / upper limits used to map the normalised values back to parameter
## units, ordered as Omega_m, sigma_8, A_SN1, A_AGN1, A_SN2, A_AGN2.
## NOTE(review): presumably the parameter ranges of the simulations — confirm.
minimum = np.array([0.1, 0.6, 0.25, 0.25, 0.5, 0.5])
maximum = np.array([0.5, 1.0, 4.00, 4.00, 2.0, 2.0])

## Drop feedback parameters if they aren't included
minimum=minimum[g]
maximum=maximum[g]
params_true = params_true*(maximum - minimum) + minimum
params_NN   = params_NN*(maximum - minimum) + minimum

## Names of the parameters that were trained on, in column order
param_names = ["Omega_m", "sigma_8", "A_SN1", "A_AGN1", "A_SN2", "A_AGN2"][:len(g)]

## Mean relative error of the predictions, in percent, per parameter
test_error = 100*np.mean(np.abs(params_true - params_NN)/params_true,axis=0)
for idx, name in enumerate(param_names):
    # %-7s pads the short names so the "=" signs line up
    print('Error %-7s = %.3f'%(name, test_error[idx]))
    wandb.run.summary["Error " + name]=test_error[idx]
if features>4:
    print("")

## Mean predicted error relative to the predicted value, in percent
## (only when the network also outputs errors)
if features==4 or features==12:
    errors_NN   = errors_NN*(maximum - minimum)
    mean_error = 100*(np.absolute(np.mean(errors_NN/params_NN, axis=0)))
    for idx, name in enumerate(param_names):
        print('Bayesian error %-7s = %.3f'%(name, mean_error[idx]))
        wandb.run.summary["Predicted error " + name]=mean_error[idx]
    if features==12:
        print("")


## Predicted vs. true value for every trained parameter, one panel per parameter.
## The black line marks perfect agreement; error bars are drawn when the network
## also predicts errors (features == 4 or 12). Each panel is annotated with that
## parameter's own mean relative error.
panel_labels = [r"$\Omega_m$", r"$\sigma_8$",
                r"$A_\mathrm{SN1}$", r"$A_\mathrm{AGN1}$",
                r"$A_\mathrm{SN2}$", r"$A_\mathrm{AGN2}$"]

if features<5:
    n_rows, fig_size = 1, (9,6)     # cosmological parameters only
else:
    n_rows, fig_size = 3, (14,20)   # cosmological + feedback parameters

# squeeze=False keeps axarr 2-D in both layouts so the same loop serves both
f, axarr = plt.subplots(n_rows, 2, figsize=fig_size, squeeze=False)
for idx, ax in enumerate(axarr.flat):
    true_vals = params_true[:,idx]
    pred_vals = params_NN[:,idx]
    diagonal  = np.linspace(true_vals.min(), true_vals.max(), 100)
    ax.plot(diagonal, diagonal, color="black")
    if features==4 or features==12:
        ax.errorbar(true_vals, pred_vals, errors_NN[:,idx], marker="o", ls="none")
    else:
        ax.plot(true_vals, pred_vals, marker="o", ls="none")

    ax.set_xlabel("True " + panel_labels[idx])
    ax.set_ylabel("Predicted " + panel_labels[idx])
    ax.text(0.1,0.9,"%.3f %% error" % test_error[idx],fontsize=12,transform=ax.transAxes)

## Upload the figure and close the W&B run
figure=wandb.Image(f)
wandb.log({"performance": figure})
wandb.finish()


In [None]:
## Play a sound when this final cell runs
import IPython.display
done_sound = IPython.display.Audio("success_retro.wav", autoplay=True)
done_sound