In [1]:
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, TensorDataset, random_split
from torch.functional import F
import torch.nn as nn

sys.path.append('../')
from utils_modules.models import Expander, vector_to_Cov
from utils_modules.vicreg import vicreg_loss
import utils_modules.data as utils_data
import utils_modules.baryons_toy_Pk as utils_toy_Pk


In [2]:
# Pick the compute device: CUDA when torch reports it as available, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print('Device: %s' % device)

Device: cuda


## Generate parameters and power spectra

In [None]:
# --- wavenumber grid -------------------------------------------------------
# Lower and upper wavenumber limits (kmin is in h/Mpc).
kmin, kmax = 7e-3, 1

# kF is set equal to kmin and serves as the spacing of the k grid.
kF = kmin
k_bins = int((kmax - kmin)/kF)

# Sampled wavenumbers: integer multiples of kF, from 3*kF up to (k_bins+1)*kF.
k = kF*np.arange(3, k_bins + 2)

# Number of modes in each k-bin.
Nk = 4.0*np.pi*k**2*kF/kF**3

# --- model parameters ------------------------------------------------------
predict_D = True
# Forwarded to utils_toy_Pk.generate_params; original note: whether to fix
# A_value for kpivot or not.
Pk_continuous = True

# --- dataset layout --------------------------------------------------------
dset_size = 1000
train_frac = 0.8
valid_frac = 0.1
test_frac = 0.1

# Seed shared by the parameter generation and the dataset split.
seed = 1
# Second axis length of the reshaped parameter array (dset_size, splits, -1).
splits = 10

In [None]:
# Draw the model parameters and arrange them as (dset_size, splits, n_params).
params = utils_toy_Pk.generate_params(
    dset_size,
    splits,
    seed=seed,
    predict_D=predict_D,
    Pk_continuous=Pk_continuous,
).reshape(dset_size, splits, -1)

# Build the power-spectrum array on the k grid from those parameters.
Pk = utils_toy_Pk.get_Pk_arr(
    k,
    Nk,
    params,
    predict_D=predict_D,
    seed=seed,
)

## Train an encoder model

In [None]:
def _vicreg_epoch(net, mlp_net, loader, weights, optimizer=None):
    """Run one pass over ``loader`` and return the sample-weighted mean losses.

    Each batch ``x`` is indexed as ``x[:, 0]`` and ``x[:, 1]`` to obtain the two
    views that are pushed through ``mlp_net(net(.))`` and compared with
    ``vicreg_loss``. The second element yielded by the loader is ignored.

    Parameters
    ----------
    net, mlp_net : torch.nn.Module
        The two networks, applied as ``mlp_net(net(view))``.
    loader : iterable
        Yields ``(x, y)`` batches.
    weights : tuple
        ``(inv_weight, var_weight, cov_weight)`` forwarded to ``vicreg_loss``.
    optimizer : torch.optim.Optimizer or None
        When given, the networks are put in train mode and an optimisation
        step is taken per batch; when ``None`` the pass is evaluation only
        (eval mode, gradients disabled).

    Returns
    -------
    tuple of float
        ``(loss, inv, var, cov)`` averaged over all samples seen.
    """
    training = optimizer is not None
    net.train(training)
    mlp_net.train(training)

    sums = [0., 0., 0., 0.]
    points = 0
    with torch.set_grad_enabled(training):
        for x, _ in loader:
            x = x.float().to(device)
            bsz = x.shape[0]

            emb_q = mlp_net(net(x[:, 0].contiguous()))
            emb_k = mlp_net(net(x[:, 1].contiguous()))
            loss, inv, var, cov = vicreg_loss(emb_q, emb_k, *weights)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # weight every batch by its size so the final mean is per sample
            for i, term in enumerate((loss, inv, var, cov)):
                sums[i] += term.detach()*bsz
            points += bsz

    # convert to plain Python floats so printing, comparing and logging do not
    # depend on tensor/device semantics
    return tuple(float(s/points) for s in sums)


def run_training(fmodel, floss,
                 net, mlp_net,
                 optimizer, scheduler,
                 train_loader, valid_loader,
                 epochs, verbose=True,
                 inv_weight = 1, var_weight = 0, cov_weight = 0):
    """Train ``net`` and ``mlp_net`` with the VICReg loss and log the losses.

    Before training, the validation loss of the incoming networks is measured
    and used as the initial "best" value. After each epoch the validation loss
    is re-evaluated; whenever it improves, the state dict of ``net`` (only
    ``net``, not ``mlp_net``) is written to ``fmodel``.

    Parameters
    ----------
    fmodel : str
        Path where the best ``net.state_dict()`` is saved.
    floss : str
        Path of the text log. It is truncated at epoch 0 and appended to
        afterwards. Each row is
        ``epoch loss_train loss_valid inv var cov`` where the last three
        columns are the loss components measured on the validation set.
    net, mlp_net : torch.nn.Module
        Networks applied as ``mlp_net(net(view))``.
    optimizer : torch.optim.Optimizer
        Optimizer stepped once per training batch.
    scheduler
        Learning-rate scheduler; ``scheduler.step(loss_valid)`` is called once
        per epoch with the validation loss.
    train_loader, valid_loader : iterable
        Yield ``(x, y)`` batches; ``x[:, 0]`` and ``x[:, 1]`` are the two views.
    epochs : int
        Number of training epochs.
    verbose : bool
        Print the per-epoch training and validation losses.
    inv_weight, var_weight, cov_weight
        Weights forwarded to ``vicreg_loss``.

    Returns
    -------
    (net, mlp_net)
        The same module objects, in their state after the last epoch (not
        reloaded from the best checkpoint).
    """
    weights = (inv_weight, var_weight, cov_weight)

    # validation loss of the untrained networks is the value to beat
    min_loss_valid = _vicreg_epoch(net, mlp_net, valid_loader, weights)[0]
    if verbose: print('Min validation loss: ', min_loss_valid)

    for epoch in range(epochs):

        # one optimisation pass; only the total training loss is reported
        loss_train = _vicreg_epoch(net, mlp_net, train_loader, weights,
                                   optimizer)[0]

        # validation loss and its components
        loss_valid, inv_loss, var_loss, cov_loss = _vicreg_epoch(
            net, mlp_net, valid_loader, weights)

        # save model if it is better
        if loss_valid < min_loss_valid:
            if verbose:
                print('saving model;  epoch %d; %.4e %.4e'\
                      %(epoch, loss_train, loss_valid))
            torch.save(net.state_dict(), fmodel)
            min_loss_valid = loss_valid
        else:
            if verbose:
                print('epoch %d; %.4e %.4e'\
                      %(epoch,loss_train,loss_valid))

        # start a fresh log on the first epoch, append afterwards
        with open(floss, 'w' if epoch == 0 else 'a') as f:
            f.write('%d %.4e %.4e %.4e %.4e %.4e\n'%(epoch, loss_train, loss_valid,
                                                     inv_loss, var_loss, cov_loss))
        scheduler.step(loss_valid)

    return net, mlp_net

In [None]:
# training parameters
epochs     = 500
batch_size = 256

# Dataset pairing the log power spectra with their parameters; the third
# argument is the augmentation object handed to customDataset.
log_Pk = torch.tensor(np.log(Pk))
dset = utils_toy_Pk.customDataset(
    log_Pk,
    torch.tensor(params),
    utils_toy_Pk.AugmentationTransformations(),
)

# Seeded train / validation / test split with sizes taken from the fractions.
split_sizes = [int(frac*dset_size) for frac in (train_frac, valid_frac, test_frac)]
split_rng = torch.Generator().manual_seed(seed)
train_dset, valid_dset, test_dset = random_split(dset, split_sizes,
                                                 generator=split_rng)

# Batch loaders: training and validation batches are shuffled, test is not.
train_loader = DataLoader(train_dset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dset, batch_size=batch_size, shuffle=True)
test_loader  = DataLoader(test_dset, batch_size=batch_size, shuffle=False)


In [None]:
# VICReg loss weights; they also tag the output file names.
inv, var, cov = 15, 15, 1
fmodel = 'trained_models/VICReg_{:d}_{:d}_{:d}.pt'.format(inv, var, cov)
fout   = 'trained_models/VICReg_{:d}_{:d}_{:d}.txt'.format(inv, var, cov)

# Encoder network (built with the Expander class): one layer size of `hidden`
# followed by six of `last_layer`.
# presumably k.shape[0] is the input size - confirm against utils_modules.models
hidden     = 16
last_layer = 32
args_net   = [hidden] + [last_layer]*6
net = Expander(args_net, k.shape[0], bn = True).to(device)

# Expander head applied on top of the encoder output during training:
# three layer sizes of 4*last_layer.
mlp_exp_units = [4*last_layer]*3
inference_net = Expander(mlp_exp_units, last_layer, bn = True).to(device)

# Optimizer over the parameters of both networks, without weight decay.
lr = 1e-3
wd = 0.
trainable = list(net.parameters()) + list(inference_net.parameters())
optimizer = torch.optim.Adam(trainable,
                             lr=lr,
                             betas=(0.9, 0.999),
                             eps=1e-8,
                             amsgrad=False,
                             weight_decay=wd)

# Scheduler that multiplies the learning rate by 0.1 when the monitored
# value (passed to scheduler.step) stops decreasing.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                       mode='min',
                                                       factor=0.1)



In [None]:
# Launch the training run.
# NOTE(review): the epoch count is the literal 300 here, while the
# training-parameters cell sets `epochs = 500` - confirm which one is intended.
net, mlp = run_training(
    fmodel=fmodel,
    floss=fout,
    net=net,
    mlp_net=inference_net,
    optimizer=optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    valid_loader=valid_loader,
    epochs=300,
    inv_weight=inv,
    var_weight=var,
    cov_weight=cov,
)


In [None]:
# Read back the training log written to `fout`; per row the columns are
# epoch, training loss, validation loss, followed by the loss components.
# ndmin=2 keeps the array two-dimensional even when the log holds one row,
# so the column slicing below does not fail for a single-epoch run.
losses = np.loadtxt(fout, ndmin=2)
epoch_idx   = losses[:, 0]
train_curve = losses[:, 1]
valid_curve = losses[:, 2]

fig, ax = plt.subplots()
ax.plot(epoch_idx, train_curve, label = 'Training loss')
ax.plot(epoch_idx, valid_curve, label = 'Validation loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend(loc = 'best')

# best (minimum) training and validation loss over the run
print(train_curve.min(), valid_curve.min())